Compare commits
100 Commits
4cbd451af7
...
tzj/minfer
| Author | SHA1 | Date | |
|---|---|---|---|
| | 52b12a89e3 | | |
| | d35dd76e09 | | |
| | 2b61c5ab57 | | |
| | a709551072 | | |
| | 11a867f6fb | | |
| | af4da454ba | | |
| | ef37d4f1a8 | | |
| | c8a5ef04c0 | | |
| | 1c36d53570 | | |
| | 54fd302fa8 | | |
| | 1eb7521994 | | |
| | 51bd678335 | | |
| | 1ea5afd886 | | |
| | 829b311c02 | | |
| | dd0472aea8 | | |
| | a1c68a733e | | |
| | dc51972777 | | |
| | 232fcf043e | | |
| | aeed6ccdfb | | |
| | 6c55c4d2a3 | | |
| | 6e34efd58a | | |
| | 5acd5558d6 | | |
| | 193ef55d18 | | |
| | f173a3f7f5 | | |
| | 8035e4db3d | | |
| | 8ab53e7331 | | |
| | 2e96d1d97d | | |
| | f6ac4ccdde | | |
| | 4484a1482c | | |
| | e436ec861f | | |
| | 45efcf0db1 | | |
| | e09a2a5b10 | | |
| | a239bfb40d | | |
| | 29e102720b | | |
| | 726e4b58cf | | |
| | 8d19e61446 | | |
| | 4484ebbb77 | | |
| | 2c2383c786 | | |
| | f049971f84 | | |
| | c90dc196b2 | | |
| | 3da9b8aef2 | | |
| | a832d127b6 | | |
| | 39d12a0416 | | |
| | c16bfcf40f | | |
| | f3e4611e3b | | |
| | 7b5d3b34eb | | |
| | b760de84c5 | | |
| | f81b5ae8a9 | | |
| | e874229adc | | |
| | 4fe7dfb239 | | |
| | 9177b62d7f | | |
| | 3956a30b14 | | |
| | 59473fa432 | | |
| | 4467e1f654 | | |
| | 0437311068 | | |
| | 6da116de98 | | |
| | f5682ca4a7 | | |
| | a504bd873d | | |
| | 076656c9c2 | | |
| | b6b59b50ed | | |
| | 09b2136e9f | | |
| | 0d31b3f71f | | |
| | 05ce57ee8e | | |
| | 94a6e06d79 | | |
| | c717072f31 | | |
| | 73c9dc46ff | | |
| | 924a0d2bfa | | |
| | 0619accd1c | | |
| | 18bc433f09 | | |
| | aea3812230 | | |
| | 3100724666 | | |
| | 78a44f3536 | | |
| | 7c41032a2e | | |
| | f28b500120 | | |
| | be67fa8060 | | |
| | 4f35526457 | | |
| | da5e13e2bb | | |
| | dd31033732 | | |
| | ed3c8bb4b8 | | |
| | 5eb35982bf | | |
| | ad361c2c3b | | |
| | 4d1e40152d | | |
| | 832b352afa | | |
| | a50b4c2ac2 | | |
| | ca32ea6f93 | | |
| | edc006463b | | |
| | 999858e82f | | |
| | 47d237bb7e | | |
| | a5307fb124 | | |
| | d808970f2f | | |
| | bc92c1fdb8 | | |
| | 2866d4fd88 | | |
| | 5d722968ff | | |
| | d21b40f48f | | |
| | 42cf124343 | | |
| | 78050aef9f | | |
| | 4d8ae951c3 | | |
| | 1ab4676396 | | |
| | 512e1e5401 | | |
| | 6180055ed8 | | |
89
.claude/agents/gpu-monitor.md
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
---
|
||||||
|
name: gpu-monitor
|
||||||
|
description: "Use this agent when you need to continuously monitor GPU utilization and memory usage in the background while running other tasks. This is particularly useful for profiling long-running benchmarks, debugging memory leaks, or understanding GPU resource patterns during inference or training.\\n\\nExamples:\\n\\n<example>\\nContext: User wants to profile GPU memory during a benchmark run.\\nuser: \"Run bench_offload.py and monitor GPU usage\"\\nassistant: \"I'll first launch the gpu-monitor agent to track GPU metrics, then run the benchmark.\"\\n<commentary>\\nSince the user wants to understand GPU resource usage during the benchmark, use the Task tool to launch the gpu-monitor agent in the background with appropriate GPU IDs and sampling interval before running the benchmark.\\n</commentary>\\nassistant: \"Let me start the GPU monitor on GPU 0 with 1-second intervals\"\\n[Task tool call to launch gpu-monitor agent with gpus=\"0\" and interval=\"1\"]\\n</example>\\n\\n<example>\\nContext: User is debugging an OOM issue and needs memory tracking.\\nuser: \"I'm getting OOM errors, help me debug\"\\nassistant: \"Let me monitor the GPU memory usage while we reproduce the issue. I'll use the gpu-monitor agent to track memory consumption.\"\\n<commentary>\\nTo debug OOM issues, we need continuous memory monitoring. 
Use the Task tool to launch the gpu-monitor agent with a short interval (0.5s) to capture memory spikes.\\n</commentary>\\n[Task tool call to launch gpu-monitor agent with gpus=\"0,1\" and interval=\"0.5\"]\\n</example>\\n\\n<example>\\nContext: User wants to verify GPU utilization during multi-GPU inference.\\nuser: \"Check if all GPUs are being utilized during the inference\"\\nassistant: \"I'll start the GPU monitor to track utilization across all specified GPUs while running the inference.\"\\n<commentary>\\nTo verify multi-GPU utilization, launch the gpu-monitor agent targeting all relevant GPUs before starting the inference workload.\\n</commentary>\\n[Task tool call to launch gpu-monitor agent with gpus=\"0,1,2,3\" and interval=\"2\"]\\n</example>"
|
||||||
|
model: haiku
|
||||||
|
color: green
|
||||||
|
---
|
||||||
|
|
||||||
|
You are a GPU monitoring specialist responsible for tracking NVIDIA GPU metrics over time. Your sole purpose is to run nvidia-smi at specified intervals and record utilization and memory statistics.
|
||||||
|
|
||||||
|
## Your Task
|
||||||
|
|
||||||
|
You will receive two parameters:
|
||||||
|
1. **gpus**: Comma-separated GPU indices to monitor (e.g., "0", "0,1", "0,1,2,3")
|
||||||
|
2. **interval**: Sampling interval in seconds (e.g., "1", "0.5", "2")
|
||||||
|
|
||||||
|
## Execution Steps
|
||||||
|
|
||||||
|
1. **Parse Parameters**: Extract the GPU indices and interval from the user's request.
|
||||||
|
|
||||||
|
2. **Run Monitoring Loop**: Execute nvidia-smi repeatedly at the specified interval using a bash loop:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Example for GPUs 0,1 with 1-second interval
while true; do
    echo "=== $(date '+%Y-%m-%d %H:%M:%S') ==="
    # Leave the loop if nvidia-smi fails, so the error can be reported
    nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,memory.used,memory.total,temperature.gpu --format=csv,noheader -i 0,1 || break
    sleep 1
done
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Output Format**: Each sample should include:
|
||||||
|
- Timestamp
|
||||||
|
- GPU index
|
||||||
|
- GPU utilization (%)
|
||||||
|
- Memory utilization (%)
|
||||||
|
- Memory used (MiB)
|
||||||
|
- Memory total (MiB)
|
||||||
|
- Temperature (°C)
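
The fields listed above correspond to the columns requested with `--query-gpu` in step 2. A minimal sketch of turning one output line into a record, assuming the `csv,noheader` format keeps unit suffixes such as `%` and `MiB`; the `GpuSample` type and the `parse_sample` helper are hypothetical names used only for illustration:

```python
from dataclasses import dataclass


@dataclass
class GpuSample:
    timestamp: str
    index: int
    gpu_util_pct: int
    mem_util_pct: int
    mem_used_mib: int
    mem_total_mib: int
    temperature_c: int


def parse_sample(timestamp: str, line: str) -> GpuSample:
    """Parse one line such as '0, 45 %, 12 %, 10240 MiB, 40960 MiB, 55'."""
    fields = [field.strip() for field in line.split(",")]
    if len(fields) != 6:
        raise ValueError(f"expected 6 fields, got {len(fields)}: {line!r}")
    # Keep only the leading number of each field; drop '%' / 'MiB' suffixes.
    numbers = [int(field.split()[0]) for field in fields]
    return GpuSample(timestamp, *numbers)
```

A field that nvidia-smi cannot report would not parse as an integer here and raises `ValueError`, which a real monitor would probably want to catch and record.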
|
||||||
|
|
||||||
|
## Termination
|
||||||
|
|
||||||
|
This agent runs continuously until:
|
||||||
|
1. The main agent signals completion (you receive a stop signal)
|
||||||
|
2. The user explicitly requests stopping
|
||||||
|
3. An error occurs with nvidia-smi
|
||||||
|
|
||||||
|
## Result Reporting
|
||||||
|
|
||||||
|
When stopped, provide a summary:
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
## GPU Monitoring Summary
|
||||||
|
|
||||||
|
**Duration**: X minutes Y seconds
|
||||||
|
**Samples Collected**: N
|
||||||
|
**GPUs Monitored**: 0, 1, ...
|
||||||
|
|
||||||
|
### Statistics per GPU
|
||||||
|
|
||||||
|
| GPU | Avg Util | Max Util | Avg Mem Used | Max Mem Used |
|
||||||
|
|-----|----------|----------|--------------|---------------|
|
||||||
|
| 0 | X% | Y% | A MiB | B MiB |
|
||||||
|
| 1 | X% | Y% | A MiB | B MiB |
|
||||||
|
|
||||||
|
### Notable Events (if any)
|
||||||
|
- Timestamp: Memory spike to X MiB on GPU Y
|
||||||
|
- Timestamp: Utilization dropped to 0% on GPU Z
|
||||||
|
```
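
The per-GPU statistics in the summary table can be derived from the stored samples. A rough sketch, assuming each sample has already been reduced to a `(gpu_index, gpu_util_pct, mem_used_mib)` tuple; the `summarize` function and the dictionary keys are illustrative names, not part of the agent definition:

```python
from collections import defaultdict
from statistics import mean


def summarize(samples):
    """samples: iterable of (gpu_index, gpu_util_pct, mem_used_mib) tuples."""
    by_gpu = defaultdict(list)
    for gpu, util, mem in samples:
        by_gpu[gpu].append((util, mem))

    summary = {}
    for gpu, rows in sorted(by_gpu.items()):
        utils = [util for util, _ in rows]
        mems = [mem for _, mem in rows]
        summary[gpu] = {
            "avg_util": mean(utils),
            "max_util": max(utils),
            "avg_mem_used": mean(mems),
            "max_mem_used": max(mems),
        }
    return summary
```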
|
||||||
|
|
||||||
|
## Important Notes
|
||||||
|
|
||||||
|
- Use `nvidia-smi -i <gpu_ids>` to filter to specific GPUs
|
||||||
|
- Keep output concise during monitoring (one line per GPU per sample)
|
||||||
|
- If nvidia-smi fails, report the error and exit gracefully
|
||||||
|
- Do NOT consume excessive resources - sleep between samples
|
||||||
|
- Store samples in memory for final summary calculation
|
||||||
|
|
||||||
|
## Example Invocation
|
||||||
|
|
||||||
|
User says: "Monitor GPUs 0 and 2 with 0.5 second interval"
|
||||||
|
|
||||||
|
You execute:
|
||||||
|
```bash
|
||||||
|
while true; do
    # %3N (milliseconds) requires GNU date
    echo "=== $(date '+%Y-%m-%d %H:%M:%S.%3N') ==="
    # Leave the loop if nvidia-smi fails, so the error can be reported
    nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,memory.used,memory.total,temperature.gpu --format=csv,noheader -i 0,2 || break
    sleep 0.5
done
|
||||||
|
```
|
||||||
180
.claude/agents/nsys-profiler.md
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
---
|
||||||
|
name: nsys-profiler
|
||||||
|
description: "Use this agent when you need to profile kernel performance, analyze CUDA operations, or investigate GPU/CPU offload behavior in the nano-vllm framework. This agent handles nsys profiling from data collection to statistical analysis.\\n\\nExamples:\\n\\n<example>\\nContext: The user wants to understand kernel performance in GPU-only mode.\\nuser: \"I want to analyze the kernel execution time distribution in GPU-only mode\"\\nassistant: \"I'll use the nsys-profiler agent to profile GPU-only mode and analyze the kernel timing.\"\\n<commentary>\\nSince the user wants to analyze kernel performance, use the Task tool to launch the nsys-profiler agent with specific analysis requirements.\\n</commentary>\\nassistant: \"Now let me use the nsys-profiler agent to collect and analyze the profiling data\"\\n</example>\\n\\n<example>\\nContext: The user is debugging performance issues in the CPU offload pipeline.\\nuser: \"Why is there a gap between H2D transfers and kernel execution in offload mode?\"\\nassistant: \"I'll launch the nsys-profiler agent to profile the offload pipeline and analyze the timeline gaps.\"\\n<commentary>\\nSince the user is investigating pipeline behavior, use the nsys-profiler agent to collect nsys data and analyze CUDA API timing.\\n</commentary>\\n</example>\\n\\n<example>\\nContext: After implementing a new optimization, the user wants to verify performance improvement.\\nuser: \"Check if the new ring buffer implementation improves overlap between H2D and compute\"\\nassistant: \"I'll use the nsys-profiler agent to profile before and after, comparing the overlap metrics.\"\\n<commentary>\\nPerformance verification requires detailed kernel-level analysis, so launch the nsys-profiler agent to collect and compare profiling data.\\n</commentary>\\n</example>"
|
||||||
|
model: opus
|
||||||
|
color: green
|
||||||
|
---
|
||||||
|
|
||||||
|
You are an expert NVIDIA Nsight Systems (nsys) profiling analyst specializing in CUDA kernel performance analysis and GPU-CPU communication optimization. Your role is to collect profiling data using the framework's scripts and provide precise, actionable analysis based on the main agent's specific questions.
|
||||||
|
|
||||||
|
## Your Capabilities
|
||||||
|
|
||||||
|
1. **Profile Data Collection**: Execute profiling scripts to generate .nsys-rep files
|
||||||
|
2. **Statistical Analysis**: Extract kernel timing, memory transfer, and API call statistics
|
||||||
|
3. **Timeline Analysis**: Identify gaps, overlaps, and bottlenecks in execution
|
||||||
|
4. **Comparative Analysis**: Compare different configurations (GPU-only vs offload, different slot counts)
|
||||||
|
|
||||||
|
## Available Profiling Scripts
|
||||||
|
|
||||||
|
### CPU Offload Mode
|
||||||
|
```bash
|
||||||
|
bash scripts/profile_offload.sh [OPTIONS]
|
||||||
|
```
|
||||||
|
Options:
|
||||||
|
- `--dataset <name>`: RULER task name (default: niah_single_1)
|
||||||
|
- `--sample <index>`: Sample index (default: 0)
|
||||||
|
- `--gpu <id>`: GPU to use (default: 0)
|
||||||
|
- `--num-gpu-blocks <n>`: Ring buffer slots (default: 4)
|
||||||
|
- `--no-offload`: Disable CPU offload for comparison
|
||||||
|
|
||||||
|
### GPU-Only Mode
|
||||||
|
```bash
|
||||||
|
bash scripts/profile_gpu_only.sh [OPTIONS]
|
||||||
|
```
|
||||||
|
Similar options for profiling without CPU offload.
|
||||||
|
|
||||||
|
## Core Nsys Commands
|
||||||
|
|
||||||
|
### Profiling (handled by scripts)
|
||||||
|
```bash
|
||||||
|
# The scripts internally run:
|
||||||
|
nsys profile --trace=cuda,nvtx --output=<path> --force-overwrite true python <script.py>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Statistical Analysis
|
||||||
|
```bash
|
||||||
|
# CUDA API summary (H2D, D2H, kernel launches)
|
||||||
|
nsys stats --report cuda_api_sum <file>.nsys-rep
|
||||||
|
|
||||||
|
# GPU kernel summary (execution time per kernel)
|
||||||
|
nsys stats --report cuda_gpu_kern_sum <file>.nsys-rep
|
||||||
|
|
||||||
|
# Memory operations summary
|
||||||
|
nsys stats --report cuda_gpu_mem_time_sum <file>.nsys-rep
|
||||||
|
|
||||||
|
# NVTX ranges (custom markers)
|
||||||
|
nsys stats --report nvtx_sum <file>.nsys-rep
|
||||||
|
|
||||||
|
# Export to SQLite for advanced queries
|
||||||
|
nsys export --type=sqlite --output=<file>.sqlite <file>.nsys-rep
|
||||||
|
```
|
||||||
|
|
||||||
|
### Key Report Types
|
||||||
|
| Report | Purpose |
|--------|---------|
| `cuda_api_sum` | CPU-side CUDA API call timing |
| `cuda_gpu_kern_sum` | GPU kernel execution time |
| `cuda_gpu_mem_time_sum` | Memory transfer timing on GPU |
| `nvtx_sum` | Custom NVTX marker statistics |
| `cuda_api_trace` | Detailed API call trace |
| `cuda_gpu_trace` | Detailed GPU operation trace |
|
||||||
|
|
||||||
|
## Analysis Workflow
|
||||||
|
|
||||||
|
### Step 1: Collect Profile Data
|
||||||
|
```bash
|
||||||
|
# Example: Profile offload mode with 8 slots
|
||||||
|
bash scripts/profile_offload.sh --num-gpu-blocks 8 --sample 0
|
||||||
|
# Output: results/nsys/ruler_niah_single_1_sample0_offload_8slots_<timestamp>.nsys-rep
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2: Identify Output File
|
||||||
|
```bash
|
||||||
|
# Find the latest profile
|
||||||
|
ls -lt results/nsys/*.nsys-rep | head -1
|
||||||
|
```
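
The same lookup can be done from a script when the path needs to be passed to later steps. A small sketch, assuming profiles are written to `results/nsys/` as in Step 1; `latest_profile` is a hypothetical helper name:

```python
from pathlib import Path
from typing import Optional


def latest_profile(directory: str = "results/nsys") -> Optional[Path]:
    """Return the most recently modified .nsys-rep file, or None if there is none."""
    reports = list(Path(directory).glob("*.nsys-rep"))
    if not reports:
        return None
    return max(reports, key=lambda path: path.stat().st_mtime)
```

Returning `None` instead of raising leaves it to the caller to decide whether a missing profile is an error.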
|
||||||
|
|
||||||
|
### Step 3: Run Statistical Analysis
|
||||||
|
```bash
|
||||||
|
# Kernel timing analysis
|
||||||
|
nsys stats --report cuda_gpu_kern_sum results/nsys/<file>.nsys-rep
|
||||||
|
|
||||||
|
# Memory transfer analysis
|
||||||
|
nsys stats --report cuda_gpu_mem_time_sum results/nsys/<file>.nsys-rep
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 4: Interpret Results
|
||||||
|
Focus on:
|
||||||
|
- **Total kernel time** vs **total transfer time**
|
||||||
|
- **Kernel launch gaps** indicating synchronization issues
|
||||||
|
- **Memory bandwidth utilization**
|
||||||
|
- **Overlap efficiency** between compute and communication
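
The text does not define overlap efficiency precisely. One plausible definition, used here purely as an assumption, is the fraction of total transfer time during which at least one kernel is running. The sketch below computes that from `(start, end)` intervals, for example taken from a `cuda_gpu_trace` export; the function names are illustrative:

```python
def merge(intervals):
    """Merge overlapping (start, end) intervals into a sorted, disjoint list."""
    merged = []
    for start, end in sorted(intervals):
        if merged and start <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged


def overlap_efficiency(kernels, transfers):
    """Fraction of transfer time hidden behind kernel execution (assumed definition)."""
    kernels, transfers = merge(kernels), merge(transfers)
    total = sum(end - start for start, end in transfers)
    if total == 0:
        return 0.0
    hidden = 0
    # Both lists are disjoint after merging, so pairwise intersections never double count.
    for t_start, t_end in transfers:
        for k_start, k_end in kernels:
            hidden += max(0, min(t_end, k_end) - max(t_start, k_start))
    return hidden / total
```

A value near 1.0 would suggest transfers are mostly hidden behind compute; a value near 0.0 would suggest they run while the GPU is otherwise idle.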
|
||||||
|
|
||||||
|
## Common Analysis Patterns
|
||||||
|
|
||||||
|
### 1. Kernel Performance Breakdown
|
||||||
|
```bash
|
||||||
|
nsys stats --report cuda_gpu_kern_sum --format csv <file>.nsys-rep | \
    sort -t',' -k2,2 -rn | head -10  # Top 10 by total time (column 2 is Total Time)
|
||||||
|
```
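
Kernel names can contain commas, which makes splitting on `,` fragile. A hedged alternative is to parse the CSV properly. The sketch below assumes the report has columns named `Total Time (ns)` and `Name`; header names can differ between nsys versions, so both are parameters. `top_kernels` is a hypothetical helper:

```python
import csv
import io


def top_kernels(csv_text, n=10, time_col="Total Time (ns)", name_col="Name"):
    """Return the n kernels with the largest total time as (name, total_ns) pairs."""
    rows = csv.DictReader(io.StringIO(csv_text))
    totals = [(row[name_col], int(row[time_col])) for row in rows]
    return sorted(totals, key=lambda item: item[1], reverse=True)[:n]
```

The input is the CSV text itself, so any banner lines printed before the header row would need to be stripped first.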
|
||||||
|
|
||||||
|
### 2. H2D/D2H Transfer Analysis
|
||||||
|
```bash
|
||||||
|
nsys stats --report cuda_api_sum <file>.nsys-rep | grep -E "cudaMemcpy|cudaMemcpyAsync"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Flash Attention Kernel Analysis
|
||||||
|
```bash
|
||||||
|
nsys stats --report cuda_gpu_kern_sum <file>.nsys-rep | grep -i "flash\|fwd\|bwd"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Pipeline Overlap Check
|
||||||
|
Look for:
|
||||||
|
- `flash_fwd_kernel` execution during `cudaMemcpyAsync`
|
||||||
|
- Gap between consecutive kernel launches
|
||||||
|
|
||||||
|
## Output Format Requirements
|
||||||
|
|
||||||
|
When reporting results to the main agent, use this structured format:
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
## Nsys Analysis Results: [Analysis Topic]
|
||||||
|
|
||||||
|
### Profile Information
|
||||||
|
- **File**: <profile_file_path>
|
||||||
|
- **Mode**: GPU-only / Offload (<N> slots)
|
||||||
|
- **Dataset**: <dataset_name>, Sample <index>
|
||||||
|
|
||||||
|
### Key Findings
|
||||||
|
| Metric | Value | Notes |
|
||||||
|
|--------|-------|-------|
|
||||||
|
| Total kernel time | X ms | |
|
||||||
|
| Total H2D time | Y ms | |
|
||||||
|
| Overlap efficiency | Z% | |
|
||||||
|
|
||||||
|
### Top Kernels by Time
|
||||||
|
| Kernel | Count | Total (ms) | Avg (μs) |
|
||||||
|
|--------|-------|------------|----------|
|
||||||
|
| kernel_name | N | X.XX | Y.YY |
|
||||||
|
|
||||||
|
### Specific Analysis
|
||||||
|
[Answer to the main agent's specific question]
|
||||||
|
|
||||||
|
### Recommendations (if applicable)
|
||||||
|
1. [Actionable recommendation]
|
||||||
|
2. [Actionable recommendation]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Important Guidelines
|
||||||
|
|
||||||
|
1. **Always use the provided scripts** for profiling - do not run `nsys profile` directly
|
||||||
|
2. **Check GPU availability** before profiling (ask main agent for GPU ID if not specified)
|
||||||
|
3. **Use PYTHONPATH** for the worktree: `PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH`
|
||||||
|
4. **Report concisely** - focus on metrics relevant to the main agent's question
|
||||||
|
5. **Include file paths** so results can be reproduced or visualized in the Nsight Systems GUI
|
||||||
|
6. **For web searches** about nsys usage, use the available search tools and prefer the official NVIDIA documentation
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
- If profile script fails: Check GPU memory, CUDA version, and script parameters
|
||||||
|
- If stats command fails: Verify .nsys-rep file exists and is not corrupted
|
||||||
|
- If no data: Ensure the profiled operation actually ran (check sample index, dataset)
|
||||||
|
|
||||||
|
## Web Search Guidelines
|
||||||
|
|
||||||
|
When encountering unfamiliar nsys options or analysis techniques:
|
||||||
|
1. Search NVIDIA Nsight Systems documentation
|
||||||
|
2. Look for nsys CLI reference guides
|
||||||
|
3. Search for specific report type interpretations
|
||||||
|
|
||||||
|
Always validate search results against the actual nsys --help output.
|
||||||
158
.claude/commands/exec-plan.md
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
---
|
||||||
|
allowed-tools: Bash(CUDA_VISIBLE_DEVICES=*), Bash(PYTHONPATH=*), Bash(python*), Bash(git*), Bash(rm*), Bash(ls*), Bash(cat*), Bash(nvidia-smi*), Read, Edit, Write, Glob, Grep, TodoWrite, Task
|
||||||
|
argument-hint: --gpu <id> [--no-interrupt]
|
||||||
|
description: Execute task_plan.md refactoring with specified GPU, optionally without user interruption
|
||||||
|
---
|
||||||
|
|
||||||
|
# Execute Task Plan (exec-plan)
|
||||||
|
|
||||||
|
Refactor the code as specified in `task_plan.md`, and make sure the final goals of the plan are fully met.

## Parameters

Command format: `/exec-plan --gpu <id> [--no-interrupt]`

| Parameter | Description | Example |
|-----------|-------------|---------|
| `--gpu <id>` | **Required.** The GPU ID that may be used; all debugging must run on this GPU only | `--gpu 0`, `--gpu 2` |
| `--no-interrupt` | Optional. Do not pause execution: when a problem occurs, resolve or skip it automatically instead of asking the user | `--no-interrupt` |

## Current Arguments

```
$ARGUMENTS
```

## Pre-Execution Preparation

### 1. Parse Arguments

Parse from `$ARGUMENTS`:
- `GPU_ID`: taken from `--gpu <id>` or `-g <id>`
- `NO_INTERRUPT`: whether the `--no-interrupt` or `-n` flag is present

### 2. Validate Arguments

**Must validate**:
- `GPU_ID` must be a valid number
- Run `nvidia-smi -i <GPU_ID>` to confirm the GPU exists
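
The parsing and validation steps can be sketched as follows. This is only an illustration under assumptions: the command itself is executed by the agent from the prompt, and `parse_exec_plan_args` and `gpu_exists` are hypothetical helper names.

```python
import argparse
import shlex
import subprocess


def parse_exec_plan_args(arguments: str) -> argparse.Namespace:
    """Parse the raw $ARGUMENTS string into gpu_id and no_interrupt."""
    parser = argparse.ArgumentParser(prog="/exec-plan")
    parser.add_argument("-g", "--gpu", dest="gpu_id", type=int, required=True)
    parser.add_argument("-n", "--no-interrupt", dest="no_interrupt", action="store_true")
    args = parser.parse_args(shlex.split(arguments))
    if args.gpu_id < 0:
        parser.error("GPU_ID must be a non-negative integer")
    return args


def gpu_exists(gpu_id: int) -> bool:
    """Return True if `nvidia-smi -i <gpu_id>` exits successfully."""
    try:
        result = subprocess.run(
            ["nvidia-smi", "-i", str(gpu_id)],
            capture_output=True, text=True, check=False,
        )
    except FileNotFoundError:  # nvidia-smi is not installed on this machine
        return False
    return result.returncode == 0
```

For example, `parse_exec_plan_args("-g 1 -n")` yields `gpu_id == 1` and `no_interrupt == True`; invalid input makes `argparse` exit with a usage message.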

### 3. Read task_plan.md

Read `task_plan.md` in the project root and understand:
- The overall goal
- The phased plan (Phase 1, 2, 3...)
- The list of files to modify
- Risks and caveats
- The test plan

## Execution Flow

### Step 1: Create the Execution Plan

Use the TodoWrite tool to create a detailed execution plan that includes:
- Every Phase extracted from task_plan.md
- The subtasks of each Phase
- Test and verification steps

### Step 2: Refactor Phase by Phase

For each Phase in task_plan.md:

1. **Read the current code**: use Read/Grep to understand the existing implementation
2. **Apply the changes**: use Edit/Write to modify the code
3. **Verify the changes**: run the relevant tests

### Step 3: Run the Verification Tests

Run the test plan defined in task_plan.md to verify that the refactoring succeeded.

## GPU Restriction Rules

**Strict limit**: only the specified GPU may be used. Every GPU-related command must be prefixed with `CUDA_VISIBLE_DEVICES`:

```bash
# Correct
CUDA_VISIBLE_DEVICES=$GPU_ID PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python test.py

# Wrong - using any other GPU is prohibited
python test.py  # may use the default GPU 0
CUDA_VISIBLE_DEVICES=0,1 python test.py  # uses multiple GPUs
```
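
When commands are launched from a script rather than typed by hand, the same restriction can be applied through the child process environment. A minimal sketch, where `gpu_env` and `run_on_gpu` are hypothetical helpers rather than anything defined by this command:

```python
import os
import subprocess


def gpu_env(gpu_id: int, base_env=None) -> dict:
    """Return a copy of the environment pinned to exactly one GPU."""
    if not isinstance(gpu_id, int) or gpu_id < 0:
        raise ValueError(f"invalid GPU id: {gpu_id!r}")
    env = dict(os.environ if base_env is None else base_env)
    # Overwrite any inherited value so a multi-GPU setting cannot leak through.
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    return env


def run_on_gpu(gpu_id: int, command: list) -> subprocess.CompletedProcess:
    """Run a command with CUDA_VISIBLE_DEVICES restricted to gpu_id."""
    return subprocess.run(command, env=gpu_env(gpu_id), check=False)
```

Taking a single integer, not a string, makes it impossible to pass a list such as `0,1` by accident.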

## Interrupt Mode Rules

### When `--no-interrupt` Is Set

In the following situations, **do not stop to ask the user**; instead:

| Situation | Handling |
|-----------|----------|
| Test failure | Record the cause, attempt an automatic fix, and continue with the next step |
| Code conflict | Attempt a reasonable resolution and record it |
| Uncertain implementation detail | Choose the most reasonable option and continue |
| Execution error | Analyze the error, attempt a fix, and record the problem |

**Principles for automatic decisions**:
1. Prioritize functional correctness
2. Follow the existing code style
3. Prefer simple, direct implementations
4. Record every automatic decision in `progress.md`

### When `--no-interrupt` Is Not Set

In the following situations, you **may ask the user**:
- A choice between multiple implementation options is needed
- Tests keep failing and cannot be fixed automatically
- A problem or contradiction is found in task_plan.md

## Execution Records

### Progress File: progress.md

Update `progress.md` in real time with:

```markdown
## Execution Progress

### Phase X: [name]
- Status: [in progress/done/failed]
- Start time: [time]
- End time: [time]
- Files modified: [file list]
- Automatic decisions: [if any]
- Issues: [if any]
```
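
If the record is kept by a script, one entry per phase can be appended in the same layout as the template. A sketch only: the agent normally edits the file with its own tools, and `append_phase_record` and its parameters are hypothetical.

```python
from pathlib import Path


def append_phase_record(path, phase, name, status, started, finished="",
                        files=(), decisions=(), issues=()):
    """Append one phase entry to the progress file, creating it if needed."""

    def joined(items):
        return ", ".join(items) if items else "none"

    entry = "\n".join([
        "",
        f"### Phase {phase}: {name}",
        f"- Status: {status}",
        f"- Start time: {started}",
        f"- End time: {finished or 'n/a'}",
        f"- Files modified: {joined(files)}",
        f"- Automatic decisions: {joined(decisions)}",
        f"- Issues: {joined(issues)}",
        "",
    ])
    with Path(path).open("a", encoding="utf-8") as handle:
        handle.write(entry)
```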

### Findings File: findings.md

Record important findings made during execution in `findings.md`.

## Usage Examples

```bash
# Use GPU 2, interruptions allowed
/exec-plan --gpu 2

# Use GPU 0, run without interruption
/exec-plan --gpu 0 --no-interrupt

# Short form
/exec-plan -g 1 -n
```

## Completion Criteria

When execution finishes, make sure that:

1. **All Phases are complete**: every Phase in task_plan.md has been implemented
2. **Tests pass**: the entire test plan in task_plan.md passes
3. **Code quality**: the changes follow the project's coding conventions
4. **Documentation is updated**: progress.md contains the complete execution record

## Important Constraints

1. **GPU isolation**: never use any device other than the specified GPU
2. **Follow the plan**: execute strictly according to task_plan.md; make no changes outside the plan
3. **Incremental changes**: verify after each Phase instead of verifying everything at the end
4. **Rollback readiness**: before a major change, consider whether a git commit is needed as a save point
195
.claude/rules/agent-result-format.md
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
# Agent Result Format Rules
|
||||||
|
|
||||||
|
## Purpose
|
||||||
|
|
||||||
|
Minimize token usage when background agents return results to the main agent. Raw program output is verbose and wastes context window space.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Result Formatting Principle
|
||||||
|
|
||||||
|
**MUST** return **structured summaries** instead of raw output.
|
||||||
|
|
||||||
|
| Don't | Do |
|-------|----|
| Full program stdout/stderr | Key metrics only |
| Debug logs | Pass/Fail status |
| Verbose error stacks | Error summary + location |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Standard Result Templates
|
||||||
|
|
||||||
|
### 2.1 Test Results (RULER, Unit Tests, etc.)
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
## Test Results: [Task Name]
|
||||||
|
|
||||||
|
**Pass Rate**: X / Y (Z%)
|
||||||
|
|
||||||
|
### Failed Samples (if any)
|
||||||
|
| Sample | Expected | Got |
|
||||||
|
|--------|----------|-----|
|
||||||
|
| N | expected_value | actual_value |
|
||||||
|
|
||||||
|
### Passed Samples
|
||||||
|
[List sample IDs or "All N samples passed"]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example** (instead of raw test output):
|
||||||
|
```markdown
|
||||||
|
## Test Results: niah_single_1 (Samples 0-49)
|
||||||
|
|
||||||
|
**Pass Rate**: 50 / 50 (100%)
|
||||||
|
|
||||||
|
### Passed Samples
|
||||||
|
All 50 samples passed.
|
||||||
|
```
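
A summary in this shape can be produced mechanically from per-sample outcomes. The sketch below assumes each result is a `(sample_id, expected, actual, passed)` tuple; `format_test_results` is a hypothetical helper, not part of the test scripts:

```python
def format_test_results(task, results):
    """results: list of (sample_id, expected, actual, passed) tuples."""
    total = len(results)
    failed = [result for result in results if not result[3]]
    passed = total - len(failed)
    rate = 100 * passed / total if total else 0.0

    lines = [
        f"## Test Results: {task}",
        "",
        f"**Pass Rate**: {passed} / {total} ({rate:.0f}%)",
        "",
    ]
    if failed:
        lines += [
            "### Failed Samples",
            "| Sample | Expected | Got |",
            "|--------|----------|-----|",
        ]
        lines += [f"| {sid} | {exp} | {act} |" for sid, exp, act, _ in failed]
    else:
        lines.append(f"All {total} samples passed.")
    return "\n".join(lines)
```

Only failures are listed individually, which keeps the report size roughly constant when everything passes.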
|
||||||
|
|
||||||
|
### 2.2 Benchmark Results
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
## Benchmark Results: [Task Name]
|
||||||
|
|
||||||
|
| Metric | Value |
|
||||||
|
|--------|-------|
|
||||||
|
| Throughput | X tok/s |
|
||||||
|
| Latency (p50) | Y ms |
|
||||||
|
| Latency (p99) | Z ms |
|
||||||
|
| Memory Peak | W GB |
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.3 Build/Compile Results
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
## Build Results: [Target]
|
||||||
|
|
||||||
|
**Status**: SUCCESS / FAILED
|
||||||
|
|
||||||
|
### Errors (if any)
|
||||||
|
| File | Line | Error |
|
||||||
|
|------|------|-------|
|
||||||
|
| path/to/file.py | 123 | error message |
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.4 Investigation/Research Results
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
## Investigation: [Topic]
|
||||||
|
|
||||||
|
### Findings
|
||||||
|
1. Finding 1 (with file:line reference)
|
||||||
|
2. Finding 2
|
||||||
|
|
||||||
|
### Relevant Files
|
||||||
|
- path/to/file1.py: description
|
||||||
|
- path/to/file2.py: description
|
||||||
|
|
||||||
|
### Conclusion
|
||||||
|
[1-2 sentence summary]
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Mandatory Fields by Task Type
|
||||||
|
|
||||||
|
| Task Type | Required Fields |
|-----------|-----------------|
| Test Run | Pass/Fail count, failed sample details |
| Benchmark | Key metrics (throughput, latency, memory) |
| Build | Status, error locations |
| Search | File paths, line numbers, brief context |
| Verification | Before/After comparison, conclusion |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. What to EXCLUDE
|
||||||
|
|
||||||
|
**MUST NOT** include in results:
|
||||||
|
|
||||||
|
| Exclude | Reason |
|---------|--------|
| Full stack traces | Extract error type + location only |
| Model loading logs | Not relevant to result |
| Progress bars / tqdm output | Noise |
| Warnings (unless critical) | Noise |
| Repeated successful outputs | "All X passed" is sufficient |
| Timestamps | Usually not needed |
| Device info (unless debugging hardware) | Noise |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Agent Prompt Template
|
||||||
|
|
||||||
|
When spawning background agents, include this instruction:
|
||||||
|
|
||||||
|
```
|
||||||
|
When reporting results, use a structured summary format:
|
||||||
|
- For tests: Pass rate, failed sample details (expected vs actual)
|
||||||
|
- For benchmarks: Key metrics table
|
||||||
|
- Do NOT include raw program output, logs, or verbose debug info
|
||||||
|
- Focus on actionable information only
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Main Agent Instructions
|
||||||
|
|
||||||
|
When spawning a background agent for testing:
|
||||||
|
|
||||||
|
**Before** (verbose):
|
||||||
|
```
|
||||||
|
Run tests for samples 0-49 and report the output.
|
||||||
|
```
|
||||||
|
|
||||||
|
**After** (structured):
|
||||||
|
```
|
||||||
|
Run tests for samples 0-49. Report results as:
|
||||||
|
- Total pass/fail count
|
||||||
|
- For each failure: sample ID, expected value, actual value
|
||||||
|
- Do NOT include raw program output or logs
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Examples
|
||||||
|
|
||||||
|
### Bad (Wastes ~500 tokens):
|
||||||
|
```
|
||||||
|
The test output was:
|
||||||
|
Loading model from ~/models/Llama-3.1-8B-Instruct...
|
||||||
|
Model loaded in 12.3s
|
||||||
|
[niah_single_1] Sample 0: PASS | Expected: 1234567 | Got: : 1234567.<|eot_id|>
|
||||||
|
[niah_single_1] Sample 1: PASS | Expected: 2345678 | Got: : 2345678.<|eot_id|>
|
||||||
|
... (48 more lines) ...
|
||||||
|
```
|
||||||
|
|
||||||
|
### Good (Uses ~50 tokens):
|
||||||
|
```
|
||||||
|
## Test Results: niah_single_1 (Samples 0-49)
|
||||||
|
|
||||||
|
**Pass Rate**: 50 / 50 (100%)
|
||||||
|
|
||||||
|
All samples passed.
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Token Savings Estimate
|
||||||
|
|
||||||
|
| Result Type | Raw Output | Structured | Savings |
|-------------|------------|------------|---------|
| 50-sample test | ~1000 tokens | ~100 tokens | 90% |
| Benchmark run | ~500 tokens | ~80 tokens | 84% |
| Build failure | ~2000 tokens | ~200 tokens | 90% |
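
The savings column follows from the two token estimates as (raw - structured) / raw. A short worked check of that arithmetic, with `savings_percent` as an illustrative name:

```python
def savings_percent(raw_tokens: int, structured_tokens: int) -> float:
    """Percentage of tokens saved by returning the structured summary."""
    return 100 * (raw_tokens - structured_tokens) / raw_tokens


for label, raw, structured in [
    ("50-sample test", 1000, 100),
    ("Benchmark run", 500, 80),
    ("Build failure", 2000, 200),
]:
    print(f"{label}: {savings_percent(raw, structured):.0f}%")
```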
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Integration
|
||||||
|
|
||||||
|
This rule should be applied when:
|
||||||
|
1. Spawning agents via Task tool
|
||||||
|
2. Running background commands
|
||||||
|
3. Processing results from completed agents
|
||||||
|
|
||||||
|
Combine with `multi-gpu-debugging.md` for efficient parallel testing workflows.
|
||||||
74
.claude/rules/gpu-monitor.md
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
# GPU Memory Monitoring Rule
|
||||||
|
|
||||||
|
## Mandatory Rule

**All GPU memory monitoring tasks must use the `gpu-monitor` agent.** The following approaches are prohibited:

| ❌ Prohibited | Reason |
|---------------|--------|
| `nvidia-smi` loop + sleep | Blocks the main agent; cannot run in parallel |
| Background bash monitoring script | Hard to manage; output becomes jumbled |
| Manual polling | Inefficient; consumes context |

## Usage

```python
# Start GPU monitoring (runs in the background)
Task(
    subagent_type="gpu-monitor",
    prompt="Monitor GPU 0 with 0.5 second interval",
    run_in_background=True
)
```

## Parameters

| Parameter | Description | Example |
|-----------|-------------|---------|
| GPU ID | GPUs to monitor | `GPU 0`, `GPU 0,1` |
| interval | Sampling interval | `0.5 second`, `1 second` |
| Purpose | Reason for monitoring | `for RULER benchmark test` |

## Typical Usage

### 1. Single-GPU Benchmark
```
Monitor GPU 0 with 1 second interval for benchmark profiling
```

### 2. Debugging OOM
```
Monitor GPU 0 with 0.5 second interval to track memory peak during inference
```

### 3. Multi-GPU Training
```
Monitor GPU 0,1,2,3 with 2 second interval during training
```

## Retrieving Results

Monitoring results are written automatically to the output_file. Read them as follows:

```bash
# View the latest output
tail -50 /tmp/claude/.../tasks/<agent_id>.output

# Find peak values
grep -i "peak\|max" /tmp/claude/.../tasks/<agent_id>.output
```

## Running in Parallel with Tests

gpu-monitor runs in the background and does not block tests:

```python
# 1. Start monitoring (background)
Task(subagent_type="gpu-monitor", ..., run_in_background=True)

# 2. Run the test (foreground)
Bash("python tests/test_ruler.py ...")

# 3. After the test finishes, inspect the monitoring results
Bash("tail -50 <output_file>")
```
54
.claude/rules/gpu-vram-requirement.md
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# GPU VRAM Requirement Rule
|
||||||
|
|
||||||
|
## GPU-only 模式显存要求
|
||||||
|
|
||||||
|
**强制规则**:执行 GPU-only 代码(不启用 CPU offload)时,**必须**在 40GB 及以上显存的 GPU 上进行测试。
|
||||||
|
|
||||||
|
### 检测方法
|
||||||
|
|
||||||
|
在运行 GPU-only 测试之前,**必须**先检查 GPU 显存:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
nvidia-smi --query-gpu=index,name,memory.total --format=csv,noheader
|
||||||
|
```
|
||||||
|
|
||||||
|
### GPU 分类
|
||||||
|
|
||||||
|
| GPU 型号 | 显存 | GPU-only 测试 |
|
||||||
|
|----------|------|---------------|
|
||||||
|
| A100 40GB | 40GB | ✅ 允许 |
|
||||||
|
| A100 80GB | 80GB | ✅ 允许 |
|
||||||
|
| H100 80GB | 80GB | ✅ 允许 |
|
||||||
|
| A6000 | 48GB | ✅ 允许 |
|
||||||
|
| RTX 3090 | 24GB | ❌ **禁止**(仅 offload 模式) |
|
||||||
|
| RTX 4090 | 24GB | ❌ **禁止**(仅 offload 模式) |
|
||||||
|
|
||||||
|
### 执行流程
|
||||||
|
|
||||||
|
1. **检测 GPU 显存**(必须)
|
||||||
|
2. **显存 >= 40GB**:继续执行 GPU-only 测试
|
||||||
|
3. **显存 < 40GB**:**停止**,提示用户:
|
||||||
|
> "当前 GPU 显存为 XXX GB,不满足 GPU-only 模式的最低 40GB 要求。请使用 `--enable-offload` 参数启用 CPU offload 模式。"
|
||||||
|
|
||||||
|
### 代码示例
|
||||||
|
|
||||||
|
```python
# Before running a GPU-only benchmark
import subprocess

# With `nounits`, memory.total is printed in MiB, one line per GPU.
result = subprocess.run(
    ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
    capture_output=True, text=True,
    check=True,  # fail loudly if nvidia-smi itself errors out
)
# Note: only the first GPU listed is checked here.
vram_mb = int(result.stdout.strip().split('\n')[0])
if vram_mb < 40000:  # threshold used for "40GB-class" cards
    raise RuntimeError(f"GPU VRAM ({vram_mb}MB) < 40GB. Use --enable-offload for this GPU.")
```
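
The example above inspects only the first GPU that `nvidia-smi` lists. On a multi-GPU machine it may be more useful to know which GPU indices pass the check. The following is a minimal sketch, assuming the text comes from `nvidia-smi --query-gpu=index,name,memory.total --format=csv,noheader,nounits`; the function name `gpus_meeting_vram` and the 40000 MiB default are illustrative choices, not part of the project.

```python
def gpus_meeting_vram(csv_text: str, min_mib: int = 40000) -> list[int]:
    """Return the indices of GPUs whose total memory is at least min_mib.

    Expects one "index, name, memory_in_MiB" line per GPU (hypothetical helper).
    """
    eligible = []
    for line in csv_text.strip().splitlines():
        # Split the index off the front and the memory off the back,
        # so a comma inside the GPU name cannot break the parse.
        index, rest = line.split(",", 1)
        _name, mem = rest.rsplit(",", 1)
        if int(mem) >= min_mib:
            eligible.append(int(index))
    return eligible
```

The returned indices can then be passed to `CUDA_VISIBLE_DEVICES` for GPU-only runs, while the remaining GPUs stay on `--enable-offload`.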
|
||||||
|
|
||||||
|
### Scope

| Script | Rule applies? |
|--------|---------------|
| `bench.py` | ✅ VRAM check required |
| `bench_offload.py` | ❌ Not applicable (always uses offload) |
| `tests/test_*.py --enable-offload` | ❌ Not applicable |
| `tests/test_*.py` (without offload) | ✅ VRAM check required |
|
||||||
463
.claude/rules/multi-gpu-debugging.md
Normal file
@@ -0,0 +1,463 @@
|
|||||||
|
# Multi-GPU Debugging and Experimentation Rules
|
||||||
|
|
||||||
|
## Purpose
|
||||||
|
|
||||||
|
This rule governs GPU resource allocation and task execution strategy during debugging and experimentation on multi-GPU machines. The goal is to maximize debugging efficiency by:
|
||||||
|
- Running long validations on minimal GPUs (1-2)
|
||||||
|
- Using remaining GPUs for parallel hypothesis exploration
|
||||||
|
- Executing only one task/dataset for full validation during debugging
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Scenario Classification
|
||||||
|
|
||||||
|
### 1.1 Long-Running Validation (Triggers Conservative Allocation)
|
||||||
|
|
||||||
|
A task SHALL be classified as **long-running validation** if ANY of the following conditions apply:
|
||||||
|
|
||||||
|
| Condition | Threshold |
|
||||||
|
|-----------|-----------|
|
||||||
|
| Estimated runtime | > 20 minutes |
|
||||||
|
| Sample count | > 50 samples per task |
|
||||||
|
| Full dataset execution | Any complete validation.jsonl |
|
||||||
|
| Full training/fine-tuning | Any training run |
|
||||||
|
| Large-scale inference | > 10K tokens total |
|
||||||
|
|
||||||
|
**Examples:**
|
||||||
|
- Running all 100 samples of `niah_single_1`
|
||||||
|
- Full RULER benchmark (13 tasks × 100 samples)
|
||||||
|
- Complete model evaluation on any benchmark
|
||||||
|
|
||||||
|
### 1.2 Exploratory / Fast-Iteration Work (Allows Full GPU Use)
|
||||||
|
|
||||||
|
A task SHALL be classified as **exploratory** if ALL of the following apply:
|
||||||
|
|
||||||
|
| Condition | Threshold |
|
||||||
|
|-----------|-----------|
|
||||||
|
| Estimated runtime | < 10 minutes |
|
||||||
|
| Sample count | ≤ 10 samples |
|
||||||
|
| Purpose | Sanity check, minimal reproduction, hypothesis testing |
|
||||||
|
|
||||||
|
**Examples:**
|
||||||
|
- Testing 3-5 specific error samples
|
||||||
|
- Single-batch inference for debugging
|
||||||
|
- Verifying a code fix on minimal input
|
||||||
|
- Profiling a single forward pass
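
The two classification tables can be sketched as a small function. This is an illustration under assumptions: the names `TaskEstimate` and `classify` are hypothetical, the "> 10K tokens" and "Purpose" conditions are left out for brevity, and the rule does not say how to treat a task that falls between the two tables, so the sketch reports it as `"unclassified"`.

```python
from dataclasses import dataclass

# Thresholds copied from Sections 1.1 and 1.2; the constant names are hypothetical.
LONG_RUNNING_MINUTES = 20
LONG_RUNNING_SAMPLES = 50
EXPLORATORY_MINUTES = 10
EXPLORATORY_SAMPLES = 10


@dataclass
class TaskEstimate:
    runtime_minutes: float
    num_samples: int
    full_dataset: bool = False
    is_training: bool = False


def classify(task: TaskEstimate) -> str:
    """Return 'long-running', 'exploratory', or 'unclassified'."""
    # Long-running: ANY single condition is enough.
    if (
        task.runtime_minutes > LONG_RUNNING_MINUTES
        or task.num_samples > LONG_RUNNING_SAMPLES
        or task.full_dataset
        or task.is_training
    ):
        return "long-running"
    # Exploratory: ALL conditions must hold.
    if (
        task.runtime_minutes < EXPLORATORY_MINUTES
        and task.num_samples <= EXPLORATORY_SAMPLES
    ):
        return "exploratory"
    # Falls in the gap between the two tables.
    return "unclassified"
```

Checking the long-running conditions first keeps the conservative allocation as the outcome whenever any trigger fires.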
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. GPU Allocation Strategy
|
||||||
|
|
||||||
|
### 2.1 Core Allocation Rules
|
||||||
|
|
||||||
|
| Task Type | GPU Allocation | Remaining GPUs |
|
||||||
|
|-----------|----------------|----------------|
|
||||||
|
| Long-running validation | 1 GPU (default), max 2 GPUs | Reserved for exploration |
|
||||||
|
| Exploratory work | As needed, can use multiple | - |
|
||||||
|
|
||||||
|
### 2.2 Mandatory Constraints
|
||||||
|
|
||||||
|
1. **MUST NOT** occupy all available GPUs for a single long-running validation
|
||||||
|
2. **MUST** reserve at least 50% of GPUs (minimum 2) for parallel exploration when ≥4 GPUs available
|
||||||
|
3. **MUST** select GPUs based on this priority:
|
||||||
|
- Idle GPUs first (check with `nvidia-smi`)
|
||||||
|
- If load info unavailable, use lowest-numbered GPUs for validation
|
||||||
|
4. **MUST** avoid resource conflicts:
|
||||||
|
- Each task uses unique `CUDA_VISIBLE_DEVICES`
|
||||||
|
- Each task uses unique output directories
|
||||||
|
- Log files include GPU ID in filename
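
Constraint 4 can be enforced mechanically before anything is launched. The sketch below is one way to do it, assuming a `results/` root directory and a `<task>_gpu<id>` naming scheme; both are illustrative, as are the helper names `task_launch_spec` and `check_no_conflicts`.

```python
import os
from pathlib import Path


def task_launch_spec(task_name: str, gpu_id: int, root: str = "results") -> dict:
    """Build a per-task environment, output directory, and log path (hypothetical layout)."""
    env = dict(os.environ)
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)  # one GPU per task
    out_dir = Path(root) / f"{task_name}_gpu{gpu_id}"
    log_file = out_dir / f"{task_name}_gpu{gpu_id}.log"  # GPU ID in the filename
    return {"env": env, "out_dir": out_dir, "log_file": log_file}


def check_no_conflicts(specs: list[dict]) -> None:
    """Raise ValueError if two tasks share a GPU or an output directory."""
    gpus = [spec["env"]["CUDA_VISIBLE_DEVICES"] for spec in specs]
    dirs = [spec["out_dir"] for spec in specs]
    if len(set(gpus)) != len(gpus):
        raise ValueError(f"GPU assigned twice: {gpus}")
    if len(set(dirs)) != len(dirs):
        raise ValueError(f"output directory reused: {dirs}")
```

Each `env` dictionary can be handed to `subprocess.Popen(..., env=spec["env"])` so that a task only ever sees its own GPU.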
|
||||||
|
|
||||||
|
### 2.3 GPU Selection Algorithm
|
||||||
|
|
||||||
|
```
|
||||||
|
IF num_available_gpus >= 4:
|
||||||
|
validation_gpus = 1 (or 2 if justified)
|
||||||
|
exploration_gpus = remaining GPUs
|
||||||
|
ELSE IF num_available_gpus == 3:
|
||||||
|
validation_gpus = 1
|
||||||
|
exploration_gpus = 2
|
||||||
|
ELSE IF num_available_gpus == 2:
|
||||||
|
validation_gpus = 1
|
||||||
|
exploration_gpus = 1
|
||||||
|
ELSE:
|
||||||
|
validation_gpus = 1
|
||||||
|
exploration_gpus = 0 (sequential exploration)
|
||||||
|
```
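
The pseudocode above translates to Python fairly directly. This sketch assumes load information is unavailable, so it follows the fallback in Section 2.2 and gives the lowest-numbered GPUs to validation; the function name `split_gpus` is hypothetical.

```python
def split_gpus(available: list[int], validation_gpus: int = 1) -> tuple[list[int], list[int]]:
    """Split GPU indices into (validation, exploration) lists.

    validation_gpus may be 1 (default) or 2; a request for 2 is honoured
    only when at least 4 GPUs are available, as in the pseudocode.
    """
    if not available:
        raise ValueError("no GPUs available")
    if validation_gpus not in (1, 2):
        raise ValueError("validation_gpus must be 1 or 2")
    if len(available) < 4:
        validation_gpus = 1  # with fewer than 4 GPUs, validation always gets exactly one
    ordered = sorted(available)
    # With a single GPU the exploration list is empty: explore sequentially.
    return ordered[:validation_gpus], ordered[validation_gpus:]
```

With 4 or more GPUs and at most 2 used for validation, the exploration share is never below 50%, which matches constraint 2 in Section 2.2.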
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Task / Dataset Selection Policy
|
||||||
|
|
||||||
|
### 3.1 Single-Task Validation Rule
|
||||||
|
|
||||||
|
During debugging, when a long-running validation is required:
|
||||||
|
|
||||||
|
- **MUST** execute only ONE task/dataset fully
|
||||||
|
- **MUST NOT** run all tasks unless explicitly requested or conditions in Section 4 are met
|
||||||
|
|
||||||
|
### 3.2 Task Selection Priority
|
||||||
|
|
||||||
|
Select the single task based on this priority order:
|
||||||
|
|
||||||
|
| Priority | Criterion | Example |
|
||||||
|
|----------|-----------|---------|
|
||||||
|
| 1 | Task most likely to reproduce the bug | If error occurs in `niah_single_1`, use that |
|
||||||
|
| 2 | Smallest task covering critical paths | `niah_single_1` (100 samples) vs `niah_multikey_3` |
|
||||||
|
| 3 | Task with known error samples | Use task with documented failure cases |
|
||||||
|
| 4 | Most representative task | Single-key before multi-key for basic validation |
|
||||||
|
|
||||||
|
### 3.3 Other Tasks Handling
|
||||||
|
|
||||||
|
Tasks not selected for full validation:
|
||||||
|
- **MAY** receive lightweight sanity checks (≤5 samples)
|
||||||
|
- **MUST NOT** receive full end-to-end execution by default
|
||||||
|
- **SHOULD** be noted in execution plan for future validation
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Scale-Up Conditions
|
||||||
|
|
||||||
|
Expansion to more GPUs or multiple full tasks is **ALLOWED ONLY IF**:
|
||||||
|
|
||||||
|
| Condition | Justification Required |
|
||||||
|
|-----------|------------------------|
|
||||||
|
| Single-task validation completed successfully | Confirm fix works on one task first |
|
||||||
|
| Critical bug identified and fixed | Need cross-task verification |
|
||||||
|
| Cross-dataset consistency required | Clear technical justification needed |
|
||||||
|
| User explicitly requests full-scale | User override |
|
||||||
|
|
||||||
|
### 4.1 Default Behavior
|
||||||
|
|
||||||
|
- **DEFAULT**: Conservative, non-expansive
|
||||||
|
- **MUST** ask for confirmation before scaling up
|
||||||
|
- **MUST** document reason for scale-up in execution plan
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Execution Plan Transparency
|
||||||
|
|
||||||
|
### 5.1 Mandatory Pre-Execution Output
|
||||||
|
|
||||||
|
Before starting any validation, **MUST** output an execution plan containing:
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
## Execution Plan
|
||||||
|
|
||||||
|
### Task Classification
|
||||||
|
- Type: [Long-running validation / Exploratory]
|
||||||
|
- Reason: [Why classified this way]
|
||||||
|
|
||||||
|
### GPU Allocation
|
||||||
|
- Validation GPU(s): [GPU IDs]
|
||||||
|
- Reason: [Why these GPUs selected]
|
||||||
|
- Exploration GPU(s): [GPU IDs]
|
||||||
|
- Exploration tasks: [List of parallel hypotheses to test]
|
||||||
|
|
||||||
|
### Task Selection
|
||||||
|
- Full validation task: [Task name]
|
||||||
|
- Reason: [Why this task selected]
|
||||||
|
- Other tasks: [Skipped / Sanity-check only]
|
||||||
|
|
||||||
|
### Stopping Criteria
|
||||||
|
- Time limit: [X minutes]
|
||||||
|
- Success metric: [e.g., accuracy > 90%]
|
||||||
|
- Error threshold: [e.g., stop if >20 samples fail]
|
||||||
|
|
||||||
|
### Expected Output
|
||||||
|
- [What results will be produced]
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.2 Progress Checkpoints
|
||||||
|
|
||||||
|
For long-running validations, **SHOULD** report progress at:
|
||||||
|
- 25% completion
|
||||||
|
- 50% completion
|
||||||
|
- 75% completion
|
||||||
|
- Final results
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Configuration Defaults
|
||||||
|
|
||||||
|
### 6.1 Default Parameters
|
||||||
|
|
||||||
|
| Parameter | Default Value | Description |
|
||||||
|
|-----------|---------------|-------------|
|
||||||
|
| `LONG_RUNNING_THRESHOLD_MINUTES` | 20 | Runtime threshold for classification |
|
||||||
|
| `LONG_RUNNING_SAMPLE_THRESHOLD` | 50 | Sample count threshold |
|
||||||
|
| `MAX_VALIDATION_GPUS` | 2 | Maximum GPUs for long validation |
|
||||||
|
| `MIN_EXPLORATION_GPUS` | 2 | Minimum GPUs reserved for exploration (when ≥4 available) |
|
||||||
|
| `EXPLORATION_SAMPLE_LIMIT` | 10 | Max samples for exploratory tests |
|
||||||
|
| `SANITY_CHECK_SAMPLES` | 5 | Samples for non-selected tasks |
|
||||||
|
|
||||||
|
### 6.2 User Override
|
||||||
|
|
||||||
|
Users can override defaults by specifying in their request:
|
||||||
|
- "Use all GPUs for validation"
|
||||||
|
- "Run all tasks"
|
||||||
|
- "Increase validation GPUs to N"
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Async Monitoring (CRITICAL)
|
||||||
|
|
||||||
|
### 7.1 Non-Blocking Principle
|
||||||
|
|
||||||
|
**MUST NOT** block the main agent with `sleep` commands waiting for results:
|
||||||
|
- ❌ `sleep 300 && check_results` (blocks main agent)
|
||||||
|
- ✅ Launch background tasks, keep working, and process results when completion notifications arrive (see 7.2)
|
||||||
|
|
||||||
|
### 7.2 Continuous GPU Utilization
|
||||||
|
|
||||||
|
**MUST** maximize GPU utilization:
|
||||||
|
- When an agent completes a task, immediately assign new work
|
||||||
|
- Use `run_in_background: true` for all long-running agents
|
||||||
|
- Check agent completion via system notifications, not polling
|
||||||
|
|
||||||
|
### 7.3 Monitoring Strategy
|
||||||
|
|
||||||
|
```
|
||||||
|
CORRECT PATTERN:
|
||||||
|
1. Launch agents in background with run_in_background: true
|
||||||
|
2. Continue analysis, planning, or hypothesis generation
|
||||||
|
3. When agent completion notification arrives, process results
|
||||||
|
4. Immediately assign new tasks to freed GPUs
|
||||||
|
|
||||||
|
WRONG PATTERN:
|
||||||
|
1. Launch agents
|
||||||
|
2. sleep 300 # BLOCKS EVERYTHING!
|
||||||
|
3. Check results
|
||||||
|
4. GPU sits idle during sleep
|
||||||
|
```
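
Outside the agent framework, the same non-blocking pattern can be sketched with the Python standard library. This is an analogy rather than the agent API itself: `launch_all` and `run_task` are hypothetical names, and the done-callback stands in for the completion notification.

```python
import subprocess
from concurrent.futures import Future, ThreadPoolExecutor


def run_task(cmd: list[str]) -> int:
    """Run one command to completion and return its exit code."""
    return subprocess.run(cmd, capture_output=True).returncode


def launch_all(commands: dict[str, list[str]], on_done) -> list[Future]:
    """Start every command in the background and return immediately.

    on_done(name, exit_code) is called from a worker thread as each
    command finishes, so the caller never has to sleep or poll.
    """
    pool = ThreadPoolExecutor(max_workers=max(1, len(commands)))
    futures = []
    for name, cmd in commands.items():
        future = pool.submit(run_task, cmd)
        # Bind the task name now; the callback fires on completion.
        future.add_done_callback(lambda f, n=name: on_done(n, f.result()))
        futures.append(future)
    pool.shutdown(wait=False)  # do not block; submitted work keeps running
    return futures
```

After `launch_all` returns, the caller is free to analyze code or prepare the next batch, and new work can be submitted from inside `on_done` as GPUs free up.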
|
||||||
|
|
||||||
|
### 7.4 Between-Task Work
|
||||||
|
|
||||||
|
While waiting for agents, the main agent SHOULD:
|
||||||
|
- Analyze code for additional hypotheses
|
||||||
|
- Prepare next batch of tests
|
||||||
|
- Update documentation with interim findings
|
||||||
|
- Plan fix implementations based on emerging patterns
|
||||||
|
|
||||||
|
### 7.5 Idle GPU Utilization (CRITICAL)
|
||||||
|
|
||||||
|
**MUST** utilize idle GPUs for exploratory tests while waiting:
|
||||||
|
|
||||||
|
```
|
||||||
|
WRONG PATTERN:
|
||||||
|
1. Launch 2 agents on GPU 0-1
|
||||||
|
2. Wait for completion ← GPU 2-5 sit idle!
|
||||||
|
3. Process results
|
||||||
|
|
||||||
|
CORRECT PATTERN:
|
||||||
|
1. Launch 2 agents on GPU 0-1 for main validation
|
||||||
|
2. IMMEDIATELY launch exploratory tests on GPU 2-5:
|
||||||
|
- Test alternative configurations
|
||||||
|
- Verify edge cases
|
||||||
|
- Run sanity checks on other datasets
|
||||||
|
- Profile performance bottlenecks
|
||||||
|
3. Continue spawning new tasks as GPUs become free
|
||||||
|
4. Process results as they arrive
|
||||||
|
```
|
||||||
|
|
||||||
|
**Idle GPU Detection**:
|
||||||
|
```bash
|
||||||
|
# Check which GPUs are free
|
||||||
|
nvidia-smi --query-gpu=index,utilization.gpu,memory.used --format=csv
|
||||||
|
```
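
The CSV printed by that command can be turned into a list of free GPU indices. A minimal sketch follows; the function name `parse_idle_gpus` and the idle thresholds (5% utilization, 500 MiB used) are assumptions chosen for illustration, not values from this rule.

```python
import csv
import io


def parse_idle_gpus(csv_text: str, max_util: int = 5, max_mem_mib: int = 500) -> list[int]:
    """Return indices of GPUs at or below both thresholds.

    Expects the output of:
      nvidia-smi --query-gpu=index,utilization.gpu,memory.used --format=csv
    """
    idle = []
    reader = csv.reader(io.StringIO(csv_text), skipinitialspace=True)
    for row in reader:
        if not row or not row[0].strip().isdigit():
            continue  # skip the header line and blank lines
        index = int(row[0])
        util = int(row[1].split()[0])  # "37 %" -> 37
        mem = int(row[2].split()[0])   # "1024 MiB" -> 1024
        if util <= max_util and mem <= max_mem_mib:
            idle.append(index)
    return idle
```

A GPU reporting a non-numeric value such as `[N/A]` would raise `ValueError` here; a production version would need to decide how to treat that case.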
|
||||||
|
|
||||||
|
**Exploratory Test Ideas** (when main validation is running):
|
||||||
|
|
||||||
|
| GPU State | Suggested Task |
|
||||||
|
|-----------|----------------|
|
||||||
|
| Idle during single-task validation | Test same task with different config |
|
||||||
|
| Idle after quick test completes | Run related task (e.g., multikey after single-key) |
|
||||||
|
| Idle during long benchmark | Run profiling or memory analysis |
|
||||||
|
| Multiple GPUs idle | Parallelize hypothesis testing |
|
||||||
|
|
||||||
|
**Anti-Pattern**:
|
||||||
|
- ❌ "I'll wait for the 100-sample test to finish before doing anything else"
|
||||||
|
- ✅ "While GPU 0-1 run the 100-sample test, I'll use GPU 2-5 to test configs X, Y, Z"
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Code Modification Policy (CRITICAL)
|
||||||
|
|
||||||
|
### 8.1 Evidence-Before-Action Principle
|
||||||
|
|
||||||
|
**MUST NOT** modify code until sufficient evidence has been gathered:
|
||||||
|
|
||||||
|
| Phase | Action | Code Modification |
|
||||||
|
|-------|--------|-------------------|
|
||||||
|
| Hypothesis Formation | Identify potential causes | ❌ NO |
|
||||||
|
| Evidence Gathering | Run targeted tests | ❌ NO |
|
||||||
|
| Pattern Analysis | Analyze test results | ❌ NO |
|
||||||
|
| Root Cause Confirmation | Validate with multiple tests | ❌ NO |
|
||||||
|
| Solution Design | Design fix based on evidence | ❌ NO |
|
||||||
|
| **Implementation** | Apply targeted fix | ✅ YES |
|
||||||
|
|
||||||
|
### 8.2 Minimum Evidence Requirements
|
||||||
|
|
||||||
|
Before proposing ANY code modification:
|
||||||
|
|
||||||
|
1. **Reproducibility**: Bug must be reproducible with specific test cases
|
||||||
|
2. **Isolation**: Root cause must be isolated (not symptoms)
|
||||||
|
3. **Multiple Data Points**: At least 3 independent test runs confirming the issue
|
||||||
|
4. **Counter-Evidence**: Attempted to disprove the hypothesis
|
||||||
|
5. **Mechanism Understanding**: Clear understanding of WHY the bug occurs
|
||||||
|
|
||||||
|
### 8.3 Main Agent Behavior
|
||||||
|
|
||||||
|
The main agent **SHOULD**:
|
||||||
|
- Keep thinking and analyzing while background agents run tests
|
||||||
|
- Formulate and refine hypotheses based on incoming results
|
||||||
|
- Document findings in `findings.md` as evidence accumulates
|
||||||
|
- Wait for sufficient test coverage before proposing fixes
|
||||||
|
|
||||||
|
The main agent **MUST NOT**:
|
||||||
|
- Rush to modify code after seeing first failure
|
||||||
|
- Propose fixes based on speculation
|
||||||
|
- Change multiple things at once "just to be safe"
|
||||||
|
- Assume correlation implies causation
|
||||||
|
|
||||||
|
### 8.4 Evidence Documentation Template
|
||||||
|
|
||||||
|
Before any code modification, document in `findings.md`:
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
## Proposed Fix: [Brief Description]
|
||||||
|
|
||||||
|
### Evidence Summary
|
||||||
|
- Test A: [Result] - supports/contradicts hypothesis
|
||||||
|
- Test B: [Result] - supports/contradicts hypothesis
|
||||||
|
- Test C: [Result] - supports/contradicts hypothesis
|
||||||
|
|
||||||
|
### Root Cause Analysis
|
||||||
|
- What: [Specific bug behavior]
|
||||||
|
- Where: [File:line or function]
|
||||||
|
- Why: [Mechanism explanation]
|
||||||
|
- Confidence: [High/Medium/Low]
|
||||||
|
|
||||||
|
### Alternative Explanations Ruled Out
|
||||||
|
1. [Alternative A]: Ruled out because [reason]
|
||||||
|
2. [Alternative B]: Ruled out because [reason]
|
||||||
|
|
||||||
|
### Proposed Change
|
||||||
|
- File: [path]
|
||||||
|
- Change: [description]
|
||||||
|
- Expected Impact: [what should improve]
|
||||||
|
```
|
||||||
|
|
||||||
|
### 8.5 Anti-Patterns
|
||||||
|
|
||||||
|
| Don't | Do Instead |
|
||||||
|
|-------|------------|
|
||||||
|
| See error → immediately edit code | See error → gather more data → analyze → then edit |
|
||||||
|
| Fix based on single test failure | Reproduce failure 3+ times, understand pattern |
|
||||||
|
| Change code "to see what happens" | Form hypothesis first, design targeted experiment |
|
||||||
|
| Modify multiple files simultaneously | Isolate changes, verify each independently |
|
||||||
|
| Skip documentation of findings | Document every significant finding before changing code |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Example Scenario
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
- **Machine**: 8 GPUs (GPU 0-7)
|
||||||
|
- **Task**: Debug RULER chunked attention 20% error rate
|
||||||
|
- **Available tasks**: 6 RULER tasks (niah_single_1/2/3, niah_multikey_1/2/3)
|
||||||
|
- **Estimated full validation time**: ~2 hours for all tasks
|
||||||
|
|
||||||
|
### Execution Plan Output
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
## Execution Plan
|
||||||
|
|
||||||
|
### Task Classification
|
||||||
|
- Type: Long-running validation
|
||||||
|
- Reason: Full validation of 100 samples × 6 tasks would take ~2 hours
|
||||||
|
|
||||||
|
### GPU Allocation
|
||||||
|
- Validation GPU(s): GPU 0 (1 GPU)
|
||||||
|
- Reason: Single GPU sufficient for sequential 100-sample validation
|
||||||
|
- Exploration GPU(s): GPU 1, 2, 3, 4, 5, 6, 7 (7 GPUs)
|
||||||
|
- Exploration tasks:
|
||||||
|
1. GPU 1: Test 2-slot vs 4-slot ring buffer on error samples
|
||||||
|
2. GPU 2: Test N-way merge implementation
|
||||||
|
3. GPU 3: Test LSE precision fix
|
||||||
|
4. GPU 4: Profile merge accumulation error
|
||||||
|
5. GPU 5: Test with ruler_64k dataset (5 samples)
|
||||||
|
6. GPU 6: Test decode boundary conditions
|
||||||
|
7. GPU 7: Reserved for ad-hoc hypothesis testing
|
||||||
|
|
||||||
|
### Task Selection
|
||||||
|
- Full validation task: niah_single_1
|
||||||
|
- Reason: Has documented error samples (19 known failures), smallest single-key task
|
||||||
|
- Other tasks: Sanity-check only (5 samples each) after fix verified
|
||||||
|
|
||||||
|
### Stopping Criteria
|
||||||
|
- Time limit: 60 minutes for full validation
|
||||||
|
- Success metric: Error rate < 10% (down from 20%)
|
||||||
|
- Error threshold: Pause if new error pattern emerges (>5 consecutive failures)
|
||||||
|
|
||||||
|
### Expected Output
|
||||||
|
- Accuracy comparison: before vs after fix
|
||||||
|
- Error sample analysis: which samples still fail
|
||||||
|
- Hypothesis validation: which exploration branch identified the fix
|
||||||
|
```
|
||||||
|
|
||||||
|
### Execution Flow
|
||||||
|
|
||||||
|
1. **GPU 0**: Runs full `niah_single_1` validation (100 samples, ~40 min)
|
||||||
|
2. **GPU 1-7**: Run parallel exploration tasks (each ~5-15 min)
|
||||||
|
3. **Checkpoint at 50%**: Report GPU 0 progress + any discoveries from exploration
|
||||||
|
4. **On discovery**: If exploration GPU finds fix, pause validation, apply fix, restart
|
||||||
|
5. **Completion**: Report final results, decide if scale-up needed
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. Quick Reference Checklist
|
||||||
|
|
||||||
|
Before starting any debugging validation:
|
||||||
|
|
||||||
|
- [ ] Classified task type? (Long-running vs Exploratory)
|
||||||
|
- [ ] If long-running: Limited to 1-2 GPUs?
|
||||||
|
- [ ] If long-running: Selected single task for full validation?
|
||||||
|
- [ ] Remaining GPUs allocated for exploration?
|
||||||
|
- [ ] Execution plan output with all required sections?
|
||||||
|
- [ ] Stopping criteria defined?
|
||||||
|
- [ ] No user override requested? (Default conservative behavior)
|
||||||
|
|
||||||
|
Before proposing any code modification:
|
||||||
|
|
||||||
|
- [ ] Bug reproducible with specific test cases?
|
||||||
|
- [ ] Root cause isolated (not just symptoms)?
|
||||||
|
- [ ] At least 3 independent test runs confirming the issue?
|
||||||
|
- [ ] Alternative explanations ruled out?
|
||||||
|
- [ ] Mechanism of bug clearly understood?
|
||||||
|
- [ ] Evidence documented in findings.md?
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 11. Rule Violations
|
||||||
|
|
||||||
|
The following actions **VIOLATE** this rule:
|
||||||
|
|
||||||
|
1. Using all 6+ GPUs for a single 100-sample validation
|
||||||
|
2. Running full validation on all tasks without completing single-task first
|
||||||
|
3. Starting long validation without outputting execution plan
|
||||||
|
4. Not reserving GPUs for exploration when ≥4 GPUs available
|
||||||
|
5. Scaling up without meeting conditions in Section 4
|
||||||
|
6. **Modifying code before gathering sufficient evidence** (Section 8)
|
||||||
|
7. Proposing fixes based on single test failure or speculation
|
||||||
|
8. Changing multiple code locations simultaneously without isolation testing
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 12. Integration with Other Rules
|
||||||
|
|
||||||
|
This rule works alongside:
|
||||||
|
- `gpu-testing.md`: GPU type detection and basic allocation
|
||||||
|
- `planning-with-files.md`: Progress tracking for long validations
|
||||||
|
- `testing.md`: Test script conventions
|
||||||
|
|
||||||
|
When conflicts arise, this rule takes precedence for debugging scenarios.
|
||||||
89
.claude/rules/nsys-profiling.md
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
# Nsys Profiling Rule
|
||||||
|
|
||||||
|
## Mandatory Rule

**All nsys profiling tasks must go through the `scripts/profile_offload.sh` script**; running nsys directly is forbidden.

| Forbidden | Reason |
|-----------|--------|
| `nsys profile python tests/test_ruler.py ...` | Inconsistent arguments, scattered output paths |
| Hand-built nsys commands | Easy to miss a required argument |
|
||||||
|
|
||||||
|
## Usage

```bash
# Basic usage (defaults to 4 slots)
bash scripts/profile_offload.sh

# Set the number of GPU slots
bash scripts/profile_offload.sh --num-gpu-blocks 8

# Select a sample
bash scripts/profile_offload.sh --sample 5

# Select a dataset
bash scripts/profile_offload.sh --dataset niah_single_1

# Disable offload (for comparison runs)
bash scripts/profile_offload.sh --no-offload

# Combine options
bash scripts/profile_offload.sh --num-gpu-blocks 8 --sample 0 --gpu 1
```
|
||||||
|
|
||||||
|
## Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `--dataset` | `niah_single_1` | RULER task name |
| `--sample` | `0` | Sample index |
| `--gpu` | `0` | GPU to use |
| `--num-gpu-blocks` | `4` | Number of GPU ring buffer slots |
| `--no-offload` | - | Disable CPU offload |
|
||||||
|
|
||||||
|
## Output Files

Output files are written automatically to the `results/nsys/` directory:
|
||||||
|
|
||||||
|
```
|
||||||
|
results/nsys/ruler_<dataset>_sample<index>_offload_<slots>slots_<timestamp>.nsys-rep
|
||||||
|
```
|
||||||
|
|
||||||
|
Example: `ruler_niah_single_1_sample0_offload_8slots_20260127_031500.nsys-rep`
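
The naming pattern can be reproduced and parsed in Python, which helps when collecting reports for comparison. This is a sketch under assumptions: the helper names are hypothetical, the timestamp format is inferred from the single example above, and the name the script uses for `--no-offload` runs is not documented here, so it is not covered.

```python
import re
from datetime import datetime

# Mirrors ruler_<dataset>_sample<index>_offload_<slots>slots_<timestamp>.nsys-rep
REPORT_NAME = re.compile(
    r"^ruler_(?P<dataset>.+)_sample(?P<sample>\d+)"
    r"_offload_(?P<slots>\d+)slots_(?P<stamp>\d{8}_\d{6})\.nsys-rep$"
)


def nsys_report_name(dataset: str, sample: int, slots: int, when: datetime) -> str:
    """Build a report filename following the documented pattern."""
    stamp = when.strftime("%Y%m%d_%H%M%S")
    return f"ruler_{dataset}_sample{sample}_offload_{slots}slots_{stamp}.nsys-rep"


def parse_report_name(filename: str) -> dict:
    """Split a report filename back into its fields; raise ValueError on mismatch."""
    match = REPORT_NAME.match(filename)
    if match is None:
        raise ValueError(f"not an offload report name: {filename!r}")
    fields = match.groupdict()
    fields["sample"] = int(fields["sample"])
    fields["slots"] = int(fields["slots"])
    return fields
```

Parsing the slot count out of the name makes it straightforward to group reports by `--num-gpu-blocks` before running `nsys stats` on each.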
|
||||||
|
|
||||||
|
## Viewing Results

```bash
# Open in the GUI
nsight-sys results/nsys/<filename>.nsys-rep

# Command-line statistics
nsys stats --report cuda_api_sum results/nsys/<filename>.nsys-rep
nsys stats --report cuda_gpu_kern_sum results/nsys/<filename>.nsys-rep
```
|
||||||
|
|
||||||
|
## Typical Workflows

### 1. Compare different slot counts

```bash
# Profile with 4 slots (default)
bash scripts/profile_offload.sh --num-gpu-blocks 4

# Profile with 8 slots
bash scripts/profile_offload.sh --num-gpu-blocks 8

# Compare the results
nsys stats --report cuda_gpu_kern_sum results/nsys/*4slots*.nsys-rep
nsys stats --report cuda_gpu_kern_sum results/nsys/*8slots*.nsys-rep
```

### 2. Analyze pipeline overlap

```bash
# Generate a profile
bash scripts/profile_offload.sh --num-gpu-blocks 8

# Open the CUDA HW timeline in the nsight-sys GUI
# and check whether H2D copies overlap with flash_fwd_kernel
```
|
||||||
@@ -1,5 +1,39 @@
|
|||||||
# Sparse Policy Code Conventions
|
||||||
|
|
||||||
|
## Policy Must Not Be None (CRITICAL)

**Mandatory rule**: the `sparse_policy` parameter **must never be None**; it must be at least `FullAttentionPolicy`.
|
||||||
|
|
||||||
|
```python
# ❌ Wrong: allows None
sparse_policy = getattr(config, 'sparse_policy', None)

# ✅ Correct: handle None explicitly and default to FULL
sparse_policy_type = getattr(config, 'sparse_policy', None)
if sparse_policy_type is None:
    sparse_policy_type = SparsePolicyType.FULL
```
|
||||||
|
|
||||||
|
**Rationale**:

1. Uniform API: every code path computes attention through a policy
2. No null dereferences: `policy.xxx` calls need no None check
3. Simpler logic: no `if policy is not None` branches

**The only exception: the warmup phase**

During `model_runner.warmup_model()`, the kvcache_manager has not been allocated yet. In that case `attention.py` falls back to flash_attn:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Warmup handling in attention.py
|
||||||
|
if context.kvcache_manager is None:
|
||||||
|
# Warmup phase: use flash_attn directly
|
||||||
|
return flash_attn_varlen_func(...) if context.is_prefill else flash_attn_with_kvcache(...)
|
||||||
|
```
|
||||||
|
|
||||||
|
This is the only case in which kvcache_manager may be None. During real inference, a policy must be present.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Base Class Requirements (MANDATORY)

Every `SparsePolicy` subclass **must** satisfy the following requirements:
|
||||||
|
|||||||
90
.claude/rules/test-ruler.md
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
# test_ruler.py Usage Rules
|
||||||
|
|
||||||
|
## Mandatory Rule

**Read the documentation before running `test_ruler.py`**; running `--help` or guessing arguments is forbidden.

| Forbidden | Reason |
|-----------|--------|
| `python tests/test_ruler.py --help` | Wastes a round trip; the documentation already covers everything |
| Guessing argument formats | Error-prone and slow |
|
||||||
|
|
||||||
|
## Required Reading

**[`docs/test_ruler_usage_guide.md`](../docs/test_ruler_usage_guide.md)** covers:

- Full argument reference
- Verified command examples
- Guidance on choosing a GPU mode
- Guidance on setting max-model-len
|
||||||
|
|
||||||
|
## Quick Reference

### Standard Command Format
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=<GPU> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--model ~/models/<MODEL> \
|
||||||
|
--data-dir tests/data/ruler_<CTX> \
|
||||||
|
--datasets <TASK> \
|
||||||
|
--num-samples <N> \
|
||||||
|
--max-model-len <LEN> \
|
||||||
|
--enable-offload \
|
||||||
|
[--sparse-policy XATTN_BSA] \
|
||||||
|
[--sparse-threshold 0.9]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Common Arguments

| Argument | Purpose | Example |
|----------|---------|---------|
| `--datasets` | Tasks to run | `niah_single_1,qa_1` |
| `--num-samples` | Number of samples | `1`, `10`, `0` (all) |
| `--sample-indices` | Specific sample indices | `0,5,10` |
| `--enable-offload` | CPU offload | Required on RTX 3090 |
| `--sparse-policy` | Sparse policy | `XATTN_BSA` |
| `--json-output` | JSON output | For use by scripts |
| `--quiet` | Quiet mode | Less output |
|
||||||
|
|
||||||
|
### max-model-len Quick Reference

| Data directory | max-model-len |
|----------------|---------------|
| ruler_32k | 40960 |
| ruler_64k | 72000 |
| ruler_128k | 135000 |
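
Because `--data-dir` and `--max-model-len` have to agree, it can help to derive both from the context name. The sketch below builds the argument list for the standard command format; `build_ruler_argv` is a hypothetical helper, and only flags that appear in this document are used.

```python
# Values copied from the max-model-len table.
MAX_MODEL_LEN = {"ruler_32k": 40960, "ruler_64k": 72000, "ruler_128k": 135000}


def build_ruler_argv(model, ctx, dataset, num_samples=1,
                     enable_offload=True, sparse_policy=None):
    """Return the argv for tests/test_ruler.py with a matching max-model-len."""
    if ctx not in MAX_MODEL_LEN:
        raise ValueError(f"no documented max-model-len for {ctx!r}")
    argv = [
        "python", "tests/test_ruler.py",
        "--model", model,
        "--data-dir", f"tests/data/{ctx}",
        "--datasets", dataset,
        "--num-samples", str(num_samples),
        "--max-model-len", str(MAX_MODEL_LEN[ctx]),
    ]
    if enable_offload:
        argv.append("--enable-offload")
    if sparse_policy is not None:
        argv += ["--sparse-policy", sparse_policy]
    return argv
```

When the list is passed to `subprocess.run` without a shell, a `~` in the model path is not expanded, so the path would need `os.path.expanduser` first; `CUDA_VISIBLE_DEVICES` and `PYTHONPATH` go in the `env` argument.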
|
||||||
|
|
||||||
|
### Common Command Templates
|
||||||
|
|
||||||
|
**32K Offload + XAttn**:
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=<GPU> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--data-dir tests/data/ruler_32k \
|
||||||
|
--datasets niah_single_1 \
|
||||||
|
--num-samples 1 \
|
||||||
|
--max-model-len 40960 \
|
||||||
|
--enable-offload \
|
||||||
|
--sparse-policy XATTN_BSA
|
||||||
|
```
|
||||||
|
|
||||||
|
**64K Offload + XAttn**:
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=<GPU> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--data-dir tests/data/ruler_64k \
|
||||||
|
--datasets niah_single_1 \
|
||||||
|
--num-samples 1 \
|
||||||
|
--max-model-len 72000 \
|
||||||
|
--enable-offload \
|
||||||
|
--sparse-policy XATTN_BSA
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pre-Run Checklist

- [ ] Did the user specify a GPU? If not, ask
- [ ] RTX 3090/4090? `--enable-offload` is required
- [ ] Does data-dir match max-model-len?
- [ ] Need density statistics? Add `--sparse-policy XATTN_BSA`
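
The first three checklist items are mechanical enough to automate. A minimal sketch, assuming the GPU name is the string reported by `nvidia-smi`; the function name `preflight_problems` and its message wording are hypothetical.

```python
# Values copied from the max-model-len table; everything else is illustrative.
EXPECTED_MAX_MODEL_LEN = {"ruler_32k": 40960, "ruler_64k": 72000, "ruler_128k": 135000}
OFFLOAD_ONLY_GPUS = ("RTX 3090", "RTX 4090")


def preflight_problems(gpu_id, gpu_name, data_dir, max_model_len, enable_offload):
    """Return a list of checklist violations; an empty list means ready to run."""
    problems = []
    if gpu_id is None:
        problems.append("no GPU specified: ask the user")
    if any(tag in gpu_name for tag in OFFLOAD_ONLY_GPUS) and not enable_offload:
        problems.append(f"{gpu_name}: --enable-offload is required")
    ctx = data_dir.rstrip("/").split("/")[-1]  # e.g. "ruler_32k"
    expected = EXPECTED_MAX_MODEL_LEN.get(ctx)
    if expected is not None and max_model_len != expected:
        problems.append(f"{ctx} expects --max-model-len {expected}, got {max_model_len}")
    return problems
```

Returning every problem at once, rather than raising on the first, lets all of them be reported to the user in a single message.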
|
||||||
@@ -1,98 +1,108 @@
|
|||||||
# Testing
|
# Testing
|
||||||
|
|
||||||
## Test File Guidelines
|
## Test Code Style
|
||||||
|
|
||||||
### Naming Convention
|
All test code follows the style below:
|
||||||
|
|
||||||
- All test files must be named `test_*.py`
|
### File Structure
|
||||||
- Example: `test_offload_engine.py`, `test_ring_buffer.py`
|
|
||||||
|
|
||||||
### Purpose
|
|
||||||
|
|
||||||
Tests are **educational scripts** for understanding module behavior, NOT traditional unit tests:
|
|
||||||
- Focus on demonstrating how modules work
|
|
||||||
- Show the flow and interaction between components
|
|
||||||
- Help developers understand implementation details
|
|
||||||
|
|
||||||
### Code Style
|
|
||||||
|
|
||||||
1. **Script-based structure**: Write tests as executable scripts, not pytest-style functions
|
|
||||||
2. **Utility functions**: Extract reusable steps as helper functions at the top of the file
|
|
||||||
3. **Main flow as script**: The actual test/demonstration logic runs as top-level script code
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Example structure:
|
"""
|
||||||
|
Test: [module name]
|
||||||
|
|
||||||
|
[Brief description of what is tested and how the data flows]
|
||||||
|
"""
|
||||||
import torch
|
import torch
|
||||||
from nanovllm.kvcache import SomeModule
|
import sys
|
||||||
|
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
|
||||||
|
from nanovllm.xxx import xxx
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# Utility Functions
|
# Parameters
|
||||||
# ============================================================
|
# ============================================================
|
||||||
|
|
||||||
def verify(tensor, expected, name):
|
param1 = value1 # state the constraint here
|
||||||
actual = tensor.mean().item()
|
param2 = value2
|
||||||
assert abs(actual - expected) < 0.01, f"{name}: {actual} != {expected}"
|
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# Main Test Script
|
# Build inputs
|
||||||
# ============================================================
|
# ============================================================
|
||||||
|
|
||||||
# 1. Initialize
|
input_tensor = ... # use structured data so the result is easy to verify
|
||||||
module = SomeModule(param=value)
|
|
||||||
|
|
||||||
# 2. Test feature X
|
# ============================================================
|
||||||
result = module.do_something()
|
# Step N: [operation name]
|
||||||
assert result == expected_value
|
# ============================================================
|
||||||
|
|
||||||
# 3. Test feature Y
|
output = some_function(input_tensor, ...)
|
||||||
...
|
|
||||||
|
# Check: [explanation of the verification logic]
|
||||||
|
expected = ...
|
||||||
|
actual = output[...].item()
|
||||||
|
assert actual == expected, f"xxx: {actual} != {expected}"
|
||||||
|
|
||||||
print("test_xxx: PASSED")
|
print("test_xxx: PASSED")
|
||||||
```
|
```
|
||||||
|
|
||||||
### Comments
|
### Core Principles
|
||||||
|
|
||||||
- Keep comments concise and clear
|
| Principle | Description |
|
||||||
- Only add comments where the code isn't self-explanatory
|
|------|------|
|
||||||
- Use section headers (`# === Section ===`) to organize logical blocks
|
| **Minimize print** | Print only a final `PASSED`; do not print intermediate results |
|
||||||
|
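| **Structured data** | Use predictable inputs (all ones, alternating even/odd values, etc.) that can be verified by hand |
| **Comment the verification logic** | Before each assert, explain in a comment how the expected value is computed |
| **Separate sections with `====`** | Use `# ============` to separate parameters, inputs, and each step |
| **Verify with assert** | Compare results with assert, not print |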
|
||||||
|
|
||||||
### Output
|
### Output Conventions
|
||||||
|
|
||||||
- **Minimize print statements** - the code should be self-explanatory
|
```python
|
||||||
- Only print a final "PASSED" message at the end
|
# ✅ Correct
|
||||||
- Use `assert` for verification instead of printing results
|
assert actual == expected, f"xxx: {actual} != {expected}"
|
||||||
- If the user needs explanation, they will ask
|
print("test_xxx: PASSED")
|
||||||
|
|
||||||
|
# ❌ Incorrect
|
||||||
|
print(f"Output: {output}")
|
||||||
|
print(f"Expected: {expected}, actual: {actual}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Parameter Comments
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ✅ Correct: the comment states the constraint
|
||||||
|
seq_len = 512 # Triton requires seq_len >= stride * BLOCK_M
|
||||||
|
segment_size = 128 # Must be >= block_size
|
||||||
|
|
||||||
|
# ❌ Incorrect: the comment adds no information
|
||||||
|
seq_len = 512 # sequence length
|
||||||
|
```
|
||||||
|
|
||||||
|
### Verification Logic Comments
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ✅ Correct: explain the calculation
|
||||||
|
# Verify: anti-diagonal sum
|
||||||
|
# Q[odd]*K[even] + Q[even]*K[odd] = 2*1 + 1*2 = 4, for stride/2 pairs
|
||||||
|
expected = (2*1 + 1*2) * (stride // 2) * head_dim
|
||||||
|
|
||||||
|
# ❌ Incorrect: a bare formula with no explanation
|
||||||
|
expected = 4 * 2 * 128
|
||||||
|
```
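The hand-computed expected value in the anti-diagonal example can be cross-checked with a brute-force sketch in plain Python. This is only an illustration under stated assumptions that the source does not spell out: even positions hold 1, odd positions hold 2, each value is replicated across `head_dim` channels, and the anti-diagonal pairs `Q[i]` with `K[stride - 1 - i]`. `antidiagonal_sum` is a hypothetical helper, not the Triton kernel.

```python
def antidiagonal_sum(stride: int, head_dim: int) -> int:
    """Brute-force sum of Q[i] . K[stride - 1 - i] over one stride-sized group."""
    # Structured input (assumed): value 1 at even positions, 2 at odd positions.
    q = [[1 if i % 2 == 0 else 2] * head_dim for i in range(stride)]
    k = [[1 if i % 2 == 0 else 2] * head_dim for i in range(stride)]
    total = 0
    for i in range(stride):
        j = stride - 1 - i  # anti-diagonal partner; opposite parity when stride is even
        total += sum(a * b for a, b in zip(q[i], k[j]))
    return total


stride, head_dim = 4, 128
# Each (odd, even) + (even, odd) pair contributes 2*1 + 1*2 = 4 per channel,
# and there are stride/2 such pairs.
expected = (2 * 1 + 1 * 2) * (stride // 2) * head_dim
assert antidiagonal_sum(stride, head_dim) == expected
```

Writing the expected value as a formula of `stride` and `head_dim`, rather than a bare constant, keeps the assert valid when the test parameters change.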
|
||||||
|
|
||||||
## Running Tests
|
## Running Tests
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Run a specific test
|
# Run a single test
|
||||||
python tests/test_offload_engine.py
|
PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_xxx.py
|
||||||
|
|
||||||
# Run with specific GPU
|
# Run on a specific GPU
|
||||||
CUDA_VISIBLE_DEVICES=0 python tests/test_ring_buffer.py
|
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_xxx.py
|
||||||
```
|
```
|
||||||
|
|
||||||
## Benchmarks
|
## Benchmarks
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Standard GPU benchmark
|
python bench.py # GPU benchmark
|
||||||
python bench.py
|
python bench_offload.py # CPU offload benchmark
|
||||||
|
python bench_vllm.py # vLLM comparison
|
||||||
# CPU offload benchmark
|
|
||||||
python bench_offload.py
|
|
||||||
|
|
||||||
# vLLM comparison benchmark
|
|
||||||
python bench_vllm.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## Quick Verification
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Import test
|
|
||||||
python -c "from nanovllm import LLM"
|
|
||||||
|
|
||||||
# Run offload benchmark (tests CPU-primary ring buffer mode)
|
|
||||||
python bench_offload.py
|
|
||||||
```
|
```
|
||||||
|
|||||||
2
.gitignore
vendored
@@ -239,3 +239,5 @@ task_plan_*.md
|
|||||||
findings_*.md
|
findings_*.md
|
||||||
progress_*.md
|
progress_*.md
|
||||||
notes.md
|
notes.md
|
||||||
|
Snipaste*
|
||||||
|
.ralph-tui/session-meta.json
|
||||||
|
|||||||
12
.ralph-tui/config.toml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# Ralph TUI Configuration
|
||||||
|
# Generated by setup wizard
|
||||||
|
# See: ralph-tui config help
|
||||||
|
|
||||||
|
configVersion = "2.1"
|
||||||
|
tracker = "json"
|
||||||
|
agent = "claude"
|
||||||
|
maxIterations = 30
|
||||||
|
autoCommit = false
|
||||||
|
|
||||||
|
[trackerOptions]
|
||||||
|
[agentOptions]
|
||||||
51
CLAUDE.md
@@ -4,7 +4,7 @@ This file provides guidance to Claude Code when working with this repository.
|
|||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline LLM inference. Supports Qwen3 models with CPU offload for long-context inference.
|
Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline LLM inference. Supports Qwen3, Llama-3, and GLM-4 models with CPU offload for long-context inference.
|
||||||
|
|
||||||
## Documentation Index
|
## Documentation Index
|
||||||
|
|
||||||
@@ -15,12 +15,50 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
|
|||||||
| [`docs/sparse_policy_implementation_guide.md`](docs/sparse_policy_implementation_guide.md) | How to implement custom SparsePolicy: required methods, hooks, ring buffer pipeline pattern |
|
| [`docs/sparse_policy_implementation_guide.md`](docs/sparse_policy_implementation_guide.md) | How to implement custom SparsePolicy: required methods, hooks, ring buffer pipeline pattern |
|
||||||
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
|
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
|
||||||
| [`docs/xattention_algorithm_guide.md`](docs/xattention_algorithm_guide.md) | XAttention algorithm in detail: stride reshape, Triton kernels, BSA dependency, block selection algorithm |
|
| [`docs/xattention_algorithm_guide.md`](docs/xattention_algorithm_guide.md) | XAttention algorithm in detail: stride reshape, Triton kernels, BSA dependency, block selection algorithm |
|
||||||
|
| [`docs/xattn_kernels_guide.md`](docs/xattn_kernels_guide.md) | XAttention Triton kernels: flat_group_gemm (anti-diagonal sum), softmax_fuse_block_sum (block aggregation) |
|
||||||
|
| [`docs/xattn_kv_chunking_kernels.md`](docs/xattn_kv_chunking_kernels.md) | XAttention KV chunking: three-stage softmax, storage overhead analysis (O(S) vs O(S²)), peak GPU memory reduction (8x), independent Q/KV chunking |
|
||||||
|
| [`docs/xattn_chunked_prefill.md`](docs/xattn_chunked_prefill.md) | XAttention chunked prefill: API, usage, consistency requirements |
|
||||||
|
| [`docs/xattn_bsa_policy_design.md`](docs/xattn_bsa_policy_design.md) | XAttention BSA Policy: algorithm design, performance benchmarks (128K), memory management, density statistics |
|
||||||
|
| [`docs/xattn_density_benchmark.md`](docs/xattn_density_benchmark.md) | 📊 XAttention Density Benchmark: 4K-32K context, stride parameter, per-layer density analysis |
|
||||||
| [`docs/block_sparse_attn_interface.md`](docs/block_sparse_attn_interface.md) | BSA (Block Sparse Attention) interface documentation: function signatures, usage examples, constraints |
|
| [`docs/block_sparse_attn_interface.md`](docs/block_sparse_attn_interface.md) | BSA (Block Sparse Attention) interface documentation: function signatures, usage examples, constraints |
|
||||||
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
|
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
|
||||||
| [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
|
| [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
|
||||||
| [`docs/known_issues.md`](docs/known_issues.md) | Documented bugs and fixes: partial last block bug, block size 4096 race condition |
|
| [`docs/known_issues.md`](docs/known_issues.md) | Documented bugs and fixes: partial last block bug, block size 4096 race condition |
|
||||||
| [`docs/ruler_benchmark_results_32k.md`](docs/ruler_benchmark_results_32k.md) | RULER benchmark results (32K context): 13 tasks, 92.3% accuracy, CPU offload performance |
|
| [`docs/ruler_benchmark_results_32k.md`](docs/ruler_benchmark_results_32k.md) | RULER benchmark results (32K context): 13 tasks, 92.3% accuracy, CPU offload performance |
|
||||||
| [`docs/ruler_32k_chunked_offload_issue.md`](docs/ruler_32k_chunked_offload_issue.md) | ⚠️ OPEN ISSUE: 32K chunked offload accuracy problem (35% error rate in RULER) |
|
| [`docs/ruler_32k_chunked_offload_issue.md`](docs/ruler_32k_chunked_offload_issue.md) | ⚠️ OPEN ISSUE: 32K chunked offload accuracy problem (20% error rate in RULER) |
|
||||||
|
| [`docs/chunked_attention_solutions.md`](docs/chunked_attention_solutions.md) | 🔧 SOLUTIONS: Code analysis and solutions for chunked attention accuracy issues |
|
||||||
|
| [`docs/nsys_wrong_event_order_bug.md`](docs/nsys_wrong_event_order_bug.md) | 🐛 NSYS BUG: Debugging notes on out-of-order nsys timestamps triggered by the ring buffer pipeline |
|
||||||
|
| [`docs/cpu_scheduling_latency_analysis.md`](docs/cpu_scheduling_latency_analysis.md) | ⚡ PERF: CPU scheduling latency analysis, sources of inter-kernel gaps, directions for improving GPU utilization |
|
||||||
|
| [`docs/bench_offload_results.md`](docs/bench_offload_results.md) | 📊 BENCH: CPU offload performance results, Full vs XAttention comparison (32K/128K) |
|
||||||
|
| [`docs/cpu_offload_optimization_strategies.md`](docs/cpu_offload_optimization_strategies.md) | 🚀 OPT: CPU offload optimization strategies: chunk size, CUDA Graph, recent research (InfiniGen/ShadowKV) |
|
||||||
|
| [`docs/gpu_only_xattn_guide.md`](docs/gpu_only_xattn_guide.md) | 🚀 GPU-Only XAttention: memory preallocation, performance analysis (32K +15%, 64K +41%), CUDA Graph limitations |
|
||||||
|
| [`docs/xattn_performance_analysis.md`](docs/xattn_performance_analysis.md) | 📊 XAttention performance analysis: NVTX markers, effect of block size, estimate vs compute time comparison |
|
||||||
|
| [`docs/observer_architecture.md`](docs/observer_architecture.md) | 📊 Observer architecture: design of InferenceObserver (TTFT/TPOT) and MemoryObserver (H2D/D2H/D2D) |
|
||||||
|
| [`docs/memory_communication_benchmark.md`](docs/memory_communication_benchmark.md) | 📊 Communication volume test: Full vs XAttention communication volume (32K/64K), per-phase statistics |
|
||||||
|
| [`docs/estimate_block_size_performance.md`](docs/estimate_block_size_performance.md) | 🔥 PERF: block_size performance analysis for the estimate phase; softmax_fuse_block_sum optimum (512-1024); the current 4096 is 15x slower |
|
||||||
|
| [`docs/long_context_models_1m.md`](docs/long_context_models_1m.md) | 📚 REF: List of models with 1M+ context length (Qwen/GLM/InternLM/Llama/VL), recommended models ≤10B |
|
||||||
|
| [`docs/new_model_integration_guide.md`](docs/new_model_integration_guide.md) | 🔧 GUIDE: New model integration guide: config mapping, RoPE variants, EOS handling, weight conversion, verification checklist |
|
||||||
|
| [`docs/xattn_density_alignment_analysis.md`](docs/xattn_density_alignment_analysis.md) | 📊 ANALYSIS: Density alignment analysis of GPU-only vs Offload modes, chunked softmax boundary effects, root cause of the 5-7% difference |
|
||||||
|
| [`docs/xattn_kv_chunking_density_test.md`](docs/xattn_kv_chunking_density_test.md) | 🧪 TEST: XAttention KV chunking density verification; aligned at threshold=1.0, 10-13% difference at threshold<1.0 |
|
||||||
|
| [`docs/gpuonly_density_alignment_test.md`](docs/gpuonly_density_alignment_test.md) | ✅ TEST: Density alignment verification (GPU-only + Offload, 4K-64K); xattn_estimate and KV chunking match exactly |
|
||||||
|
| [`docs/xattn_memory_benchmark.md`](docs/xattn_memory_benchmark.md) | 📊 BENCH: XAttention memory benchmark; Qwen3-0.6B at 32K is feasible on 24GB of GPU memory (gpu-util=0.28) |
|
||||||
|
| [`docs/xattn_offload_stream_sync_fix.md`](docs/xattn_offload_stream_sync_fix.md) | 🐛 FIX: XAttention Offload stream synchronization bug; inconsistent K data between Pass1 and Pass2; wrapped in compute_stream |
|
||||||
|
| [`docs/xattn_density_types.md`](docs/xattn_density_types.md) | 📊 Compute vs Comm density: BSA block (128) vs CPU block (4096) granularity; the aggregation effect leads to comm=100% |
|
||||||
|
| [`docs/xattn_density_alignment_verification.md`](docs/xattn_density_alignment_verification.md) | ✅ VERIFIED: GPU-only vs Offload density alignment verification (0.37% difference at 32K, 0.09% at 64K) |
|
||||||
|
| [`docs/test_ruler_usage_guide.md`](docs/test_ruler_usage_guide.md) | 📖 GUIDE: test_ruler.py usage guide, RULER benchmark test commands, verified command examples |
|
||||||
|
| [`docs/xattn_offload_profiling_32k.md`](docs/xattn_offload_profiling_32k.md) | 📊 PROFILE: nsys analysis of XAttn vs Full at 32K; estimate takes 41%, find_blocks 37%, compute only 21% |
|
||||||
|
| [`docs/changelog_2026-02-05.md`](docs/changelog_2026-02-05.md) | 📋 CHANGELOG: GQA buffer OOM fix (saves 16GB), tests directory cleanup (-4306 lines) |
|
||||||
|
|
||||||
|
## Rules Index
|
||||||
|
|
||||||
|
| Rule | Purpose |
|
||||||
|
|------|---------|
|
||||||
|
| [`.claude/rules/multi-gpu-debugging.md`](.claude/rules/multi-gpu-debugging.md) | **Multi-GPU debugging**: GPU allocation (1-2 for validation, rest for exploration), single-task validation policy |
|
||||||
|
| [`.claude/rules/gpu-testing.md`](.claude/rules/gpu-testing.md) | GPU type detection, card assignment, needle test requirements |
|
||||||
|
| [`.claude/rules/sparse-policy.md`](.claude/rules/sparse-policy.md) | SparsePolicy implementation requirements |
|
||||||
|
| [`.claude/rules/planning-with-files.md`](.claude/rules/planning-with-files.md) | Planning file management for complex tasks |
|
||||||
|
| [`.claude/rules/gpu-monitor.md`](.claude/rules/gpu-monitor.md) | **GPU memory monitoring**: must use the gpu-monitor agent; manual nvidia-smi loops are prohibited |
|
||||||
|
| [`.claude/rules/test-ruler.md`](.claude/rules/test-ruler.md) | **test_ruler.py rules**: do not use --help; consult the documentation, which includes a quick reference and command templates |
|
||||||
|
|
||||||
## GPU Mutex for Multi-Instance Debugging
|
## GPU Mutex for Multi-Instance Debugging
|
||||||
|
|
||||||
@@ -75,6 +113,15 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py
|
|||||||
|
|
||||||
**Files**: `bench.py` (GPU), `bench_offload.py` (CPU offload), `bench_vllm.py` (comparison)
|
**Files**: `bench.py` (GPU), `bench_offload.py` (CPU offload), `bench_vllm.py` (comparison)
|
||||||
|
|
||||||
|
**Model selection for GPU-only tests**:
|
||||||
|
|
||||||
|
| GPU | GPU memory | GPU-only test model |
|
||||||
|
|-----|------|------------------|
|
||||||
|
| RTX 3090 | 24GB | **Qwen3-0.6B** (required; 7B+ models will OOM) |
|
||||||
|
| A100 | 40GB+ | Any of Qwen3-0.6B / 4B / 7B |
|
||||||
|
|
||||||
|
**Offload Mode Constraint**: When using `enable_cpu_offload=True`, only test with context length ≥ 32K. Shorter contexts don't exercise the chunked offload pipeline properly.
|
||||||
|
|
||||||
**Common Issues**:
|
**Common Issues**:
|
||||||
1. `max_num_batched_tokens < max_model_len`: Set equal for long context
|
1. `max_num_batched_tokens < max_model_len`: Set equal for long context
|
||||||
2. CUDA graph dimension mismatch: Ensure `input_len + output_len <= max_model_len`
|
2. CUDA graph dimension mismatch: Ensure `input_len + output_len <= max_model_len`
|
||||||
|
|||||||
59
bench.py
@@ -2,6 +2,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
from random import randint, seed
|
from random import randint, seed
|
||||||
from nanovllm import LLM, SamplingParams
|
from nanovllm import LLM, SamplingParams
|
||||||
|
from nanovllm.utils.observer import InferenceObserver
|
||||||
|
|
||||||
|
|
||||||
def bench_decode(llm, num_seqs, input_len, output_len):
|
def bench_decode(llm, num_seqs, input_len, output_len):
|
||||||
@@ -14,13 +15,17 @@ def bench_decode(llm, num_seqs, input_len, output_len):
|
|||||||
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
||||||
t = time.time() - t
|
t = time.time() - t
|
||||||
|
|
||||||
# Calculate metrics
|
# Get metrics from InferenceObserver
|
||||||
prefill_tokens = num_seqs * input_len
|
ttft_ms = InferenceObserver.ttft / 1e6
|
||||||
|
tpot_ms = InferenceObserver.tpot / 1e6
|
||||||
|
|
||||||
|
# Calculate throughput from observer metrics
|
||||||
decode_tokens = num_seqs * output_len
|
decode_tokens = num_seqs * output_len
|
||||||
decode_throughput = decode_tokens / t
|
decode_throughput = 1000.0 / tpot_ms if tpot_ms > 0 else 0 # tokens/s per sequence
|
||||||
|
|
||||||
print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s")
|
print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s")
|
||||||
print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)")
|
print(f" TTFT: {ttft_ms:.2f}ms, TPOT: {tpot_ms:.2f}ms")
|
||||||
|
print(f" Decode Throughput: {decode_throughput:.2f} tok/s (from observer)")
|
||||||
|
|
||||||
|
|
||||||
def bench_prefill(llm, num_seqs, input_len):
|
def bench_prefill(llm, num_seqs, input_len):
|
||||||
@@ -33,31 +38,69 @@ def bench_prefill(llm, num_seqs, input_len):
|
|||||||
t = time.time()
|
t = time.time()
|
||||||
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
||||||
t = time.time() - t
|
t = time.time() - t
|
||||||
|
|
||||||
|
# Get TTFT from InferenceObserver
|
||||||
|
ttft_ms = InferenceObserver.ttft / 1e6
|
||||||
|
ttft_s = ttft_ms / 1000.0
|
||||||
|
|
||||||
total_input_tokens = num_seqs * input_len
|
total_input_tokens = num_seqs * input_len
|
||||||
throughput = total_input_tokens / t
|
# Use observer TTFT for accurate prefill throughput
|
||||||
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
throughput_observer = total_input_tokens / ttft_s if ttft_s > 0 else 0
|
||||||
|
throughput_external = total_input_tokens / t
|
||||||
|
|
||||||
|
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len})")
|
||||||
|
print(f" External Time: {t:.2f}s, Throughput: {throughput_external:.2f}tok/s")
|
||||||
|
print(f" Observer TTFT: {ttft_ms:.2f}ms, Throughput: {throughput_observer:.2f}tok/s")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
import argparse
|
import argparse
|
||||||
|
from nanovllm.config import SparsePolicyType
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Benchmark nanovllm GPU performance")
|
parser = argparse.ArgumentParser(description="Benchmark nanovllm GPU performance")
|
||||||
|
parser.add_argument("--model", type=str, default="~/models/Llama-3.1-8B-Instruct",
|
||||||
|
help="Model path (default: ~/models/Llama-3.1-8B-Instruct)")
|
||||||
parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
|
parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
|
||||||
parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
|
parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
|
||||||
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
|
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
|
||||||
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
|
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
|
||||||
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
|
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
|
||||||
|
# Sparse policy option (GPU-only mode now supports policy routing)
|
||||||
|
parser.add_argument("--policy", type=str, default=None,
|
||||||
|
choices=["full", "xattn"],
|
||||||
|
help="Sparse policy: full (FullAttention), xattn (XAttention+BSA)")
|
||||||
|
parser.add_argument("--enable-policy", action="store_true",
|
||||||
|
help="Enable sparse policy routing (FullAttentionPolicy by default)")
|
||||||
|
parser.add_argument("--gpu-util", type=float, default=0.9,
|
||||||
|
help="GPU memory utilization (default: 0.9)")
|
||||||
|
parser.add_argument("--block-size", type=int, default=1024,
|
||||||
|
help="KV cache block size (default: 1024)")
|
||||||
|
parser.add_argument("--enforce-eager", action="store_true",
|
||||||
|
help="Disable CUDA graphs (default: False)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
path = os.path.expanduser(args.model)
|
||||||
max_len = args.max_len
|
max_len = args.max_len
|
||||||
|
|
||||||
|
# Configure sparse policy
|
||||||
|
if args.policy == "xattn":
|
||||||
|
sparse_policy = SparsePolicyType.XATTN_BSA
|
||||||
|
print(f"\n[nanovllm GPU + XAttention BSA] max_len={max_len}")
|
||||||
|
elif args.policy == "full" or args.enable_policy:
|
||||||
|
sparse_policy = SparsePolicyType.FULL
|
||||||
|
print(f"\n[nanovllm GPU + Policy] sparse_policy=FULL, max_len={max_len}")
|
||||||
|
else:
|
||||||
|
sparse_policy = None
|
||||||
print(f"\n[nanovllm GPU] max_len={max_len}")
|
print(f"\n[nanovllm GPU] max_len={max_len}")
|
||||||
|
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
path,
|
path,
|
||||||
enforce_eager=False,
|
enforce_eager=args.enforce_eager,
|
||||||
max_model_len=max_len,
|
max_model_len=max_len,
|
||||||
max_num_batched_tokens=max_len,
|
max_num_batched_tokens=max_len,
|
||||||
|
sparse_policy=sparse_policy,
|
||||||
|
gpu_memory_utilization=args.gpu_util,
|
||||||
|
kvcache_block_size=args.block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Warmup
|
# Warmup
|
||||||
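The metric arithmetic introduced in the `bench.py` changes above can be isolated in a small sketch. It assumes, based only on the `/ 1e6` conversion into variables named `*_ms`, that `InferenceObserver.ttft` and `InferenceObserver.tpot` are reported in nanoseconds. `observer_metrics` is a hypothetical helper for illustration and is not part of nanovllm.

```python
def observer_metrics(ttft_ns: float, tpot_ns: float, total_input_tokens: int) -> dict:
    """Convert raw observer timings (assumed nanoseconds) into benchmark metrics."""
    ttft_ms = ttft_ns / 1e6          # time to first token, in milliseconds
    tpot_ms = tpot_ns / 1e6          # time per output token, in milliseconds
    ttft_s = ttft_ms / 1000.0
    return {
        "ttft_ms": ttft_ms,
        "tpot_ms": tpot_ms,
        # Prefill throughput: all prompt tokens are processed before the first token appears.
        "prefill_tok_per_s": total_input_tokens / ttft_s if ttft_s > 0 else 0,
        # Decode throughput per sequence: one token every tpot_ms milliseconds.
        "decode_tok_per_s": 1000.0 / tpot_ms if tpot_ms > 0 else 0,
    }


metrics = observer_metrics(ttft_ns=2_000_000_000, tpot_ns=25_000_000, total_input_tokens=32768)
print(metrics)
```

Deriving decode throughput from TPOT excludes prefill time, which the earlier `decode_tokens / t` calculation included.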
|
|||||||
@@ -2,6 +2,15 @@ import os
|
|||||||
import time
|
import time
|
||||||
from random import randint, seed
|
from random import randint, seed
|
||||||
from nanovllm import LLM, SamplingParams
|
from nanovllm import LLM, SamplingParams
|
||||||
|
from nanovllm.utils.observer import InferenceObserver
|
||||||
|
from nanovllm.utils.memory_observer import MemoryObserver
|
||||||
|
|
||||||
|
|
||||||
|
def print_memory_stats():
|
||||||
|
"""Print MemoryObserver communication statistics"""
|
||||||
|
fmt = MemoryObserver._fmt_bytes
|
||||||
|
print(f"[Memory] Prefill H2D: {fmt(MemoryObserver.prefill_h2d_bytes)}, D2H: {fmt(MemoryObserver.prefill_d2h_bytes)}")
|
||||||
|
print(f" Decode H2D: {fmt(MemoryObserver.decode_h2d_bytes)}, D2H: {fmt(MemoryObserver.decode_d2h_bytes)}")
|
||||||
|
|
||||||
|
|
||||||
def bench_decode(llm, num_seqs, input_len, output_len):
|
def bench_decode(llm, num_seqs, input_len, output_len):
|
||||||
@@ -14,16 +23,18 @@ def bench_decode(llm, num_seqs, input_len, output_len):
|
|||||||
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
||||||
t = time.time() - t
|
t = time.time() - t
|
||||||
|
|
||||||
# Calculate metrics
|
# Get metrics from InferenceObserver
|
||||||
prefill_tokens = num_seqs * input_len
|
ttft_ms = InferenceObserver.ttft / 1e6
|
||||||
decode_tokens = num_seqs * output_len
|
tpot_ms = InferenceObserver.tpot / 1e6
|
||||||
|
|
||||||
# Approximate: assume prefill takes ~input_len/prefill_speed, rest is decode
|
# Calculate throughput from observer metrics
|
||||||
# For more accurate measurement, we'd need internal timing
|
decode_tokens = num_seqs * output_len
|
||||||
decode_throughput = decode_tokens / t # This includes prefill time, so it's a lower bound
|
decode_throughput = 1000.0 / tpot_ms if tpot_ms > 0 else 0 # tokens/s per sequence
|
||||||
|
|
||||||
print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s")
|
print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s")
|
||||||
print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)")
|
print(f" TTFT: {ttft_ms:.2f}ms, TPOT: {tpot_ms:.2f}ms")
|
||||||
|
print(f" Decode Throughput: {decode_throughput:.2f} tok/s (from observer)")
|
||||||
|
print_memory_stats()
|
||||||
|
|
||||||
|
|
||||||
def bench_prefill(llm, num_seqs, input_len):
|
def bench_prefill(llm, num_seqs, input_len):
|
||||||
@@ -36,9 +47,20 @@ def bench_prefill(llm, num_seqs, input_len):
|
|||||||
t = time.time()
|
t = time.time()
|
||||||
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
||||||
t = time.time() - t
|
t = time.time() - t
|
||||||
|
|
||||||
|
# Get TTFT from InferenceObserver
|
||||||
|
ttft_ms = InferenceObserver.ttft / 1e6
|
||||||
|
ttft_s = ttft_ms / 1000.0
|
||||||
|
|
||||||
total_input_tokens = num_seqs * input_len
|
total_input_tokens = num_seqs * input_len
|
||||||
throughput = total_input_tokens / t
|
# Use observer TTFT for accurate prefill throughput
|
||||||
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
throughput_observer = total_input_tokens / ttft_s if ttft_s > 0 else 0
|
||||||
|
throughput_external = total_input_tokens / t
|
||||||
|
|
||||||
|
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len})")
|
||||||
|
print(f" External Time: {t:.2f}s, Throughput: {throughput_external:.2f}tok/s")
|
||||||
|
print(f" Observer TTFT: {ttft_ms:.2f}ms, Throughput: {throughput_observer:.2f}tok/s")
|
||||||
|
print_memory_stats()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -46,40 +68,67 @@ def main():
|
|||||||
from nanovllm.config import SparsePolicyType
|
from nanovllm.config import SparsePolicyType
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Benchmark CPU offload performance")
|
parser = argparse.ArgumentParser(description="Benchmark CPU offload performance")
|
||||||
parser.add_argument("--enable-quest", action="store_true", help="Enable Quest sparse attention for decode")
|
parser.add_argument("--model", type=str, default="~/models/Llama-3.1-8B-Instruct",
|
||||||
|
help="Model path (default: ~/models/Llama-3.1-8B-Instruct)")
|
||||||
|
# Sparse policy selection (mutually exclusive)
|
||||||
|
sparse_group = parser.add_mutually_exclusive_group()
|
||||||
|
sparse_group.add_argument("--enable-quest", action="store_true",
|
||||||
|
help="Enable Quest sparse attention (decode only, prefill uses full)")
|
||||||
|
sparse_group.add_argument("--enable-xattn", action="store_true",
|
||||||
|
help="Enable XAttention BSA (prefill only, decode uses full)")
|
||||||
|
# Quest parameters
|
||||||
parser.add_argument("--topk", type=int, default=16, help="Top-K blocks for Quest (default: 16)")
|
parser.add_argument("--topk", type=int, default=16, help="Top-K blocks for Quest (default: 16)")
|
||||||
parser.add_argument("--threshold", type=int, default=4, help="Apply sparse only when blocks > threshold (default: 4)")
|
parser.add_argument("--threshold", type=int, default=4, help="Apply sparse only when blocks > threshold (default: 4)")
|
||||||
|
# XAttention parameters
|
||||||
|
parser.add_argument("--xattn-threshold", type=float, default=0.95,
|
||||||
|
help="XAttention cumulative attention threshold (default: 0.95)")
|
||||||
|
parser.add_argument("--xattn-stride", type=int, default=8,
|
||||||
|
help="XAttention Q/K downsampling stride (default: 8)")
|
||||||
|
# General parameters
|
||||||
parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
|
parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
|
||||||
parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
|
parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
|
||||||
parser.add_argument("--num-gpu-blocks", type=int, default=6, help="Number of GPU blocks (default: 6)")
|
parser.add_argument("--num-gpu-blocks", type=int, default=4, help="Number of GPU blocks (default: 4)")
|
||||||
|
parser.add_argument("--block-size", type=int, default=1024, help="KV cache block size (default: 1024)")
|
||||||
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
|
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
|
||||||
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
|
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
|
||||||
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
|
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
|
||||||
|
parser.add_argument("--enforce-eager", action="store_true", help="Disable CUDA Graphs (use eager mode)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
path = os.path.expanduser(args.model)
|
||||||
max_len = args.max_len
|
max_len = args.max_len
|
||||||
|
|
||||||
|
# Enable MemoryObserver for communication stats
|
||||||
|
MemoryObserver._enabled = True
|
||||||
|
|
||||||
# Setup policy configuration
|
# Setup policy configuration
|
||||||
if args.enable_quest:
|
if args.enable_quest:
|
||||||
sparse_policy = SparsePolicyType.QUEST
|
sparse_policy = SparsePolicyType.QUEST
|
||||||
print(f"\n[Quest Sparse Attention] topk={args.topk}, threshold={args.threshold}")
|
print(f"\n[Quest Sparse Attention] decode: Quest (topk={args.topk}, threshold={args.threshold}), prefill: Full")
|
||||||
|
elif args.enable_xattn:
|
||||||
|
sparse_policy = SparsePolicyType.XATTN_BSA
|
||||||
|
print(f"\n[XAttention BSA] prefill: XAttn (tau={args.xattn_threshold}, stride={args.xattn_stride}), decode: Full")
|
||||||
else:
|
else:
|
||||||
sparse_policy = SparsePolicyType.FULL
|
sparse_policy = SparsePolicyType.FULL
|
||||||
print("\n[Full Attention] baseline (no sparse)")
|
print("\n[Full Attention] baseline (no sparse)")
|
||||||
|
|
||||||
print(f"[Config] max_len={max_len}, num_gpu_blocks={args.num_gpu_blocks}")
|
print(f"[Config] max_len={max_len}, num_gpu_blocks={args.num_gpu_blocks}, block_size={args.block_size}")
|
||||||
|
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
path,
|
path,
|
||||||
enforce_eager=False,
|
enforce_eager=args.enforce_eager,
|
||||||
max_model_len=max_len,
|
max_model_len=max_len,
|
||||||
max_num_batched_tokens=max_len,
|
max_num_batched_tokens=max_len,
|
||||||
enable_cpu_offload=True,
|
enable_cpu_offload=True,
|
||||||
num_gpu_blocks=args.num_gpu_blocks,
|
num_gpu_blocks=args.num_gpu_blocks,
|
||||||
|
kvcache_block_size=args.block_size,
|
||||||
sparse_policy=sparse_policy,
|
sparse_policy=sparse_policy,
|
||||||
|
# Quest parameters
|
||||||
sparse_topk_blocks=args.topk,
|
sparse_topk_blocks=args.topk,
|
||||||
sparse_threshold_blocks=args.threshold,
|
sparse_threshold_blocks=args.threshold,
|
||||||
|
# XAttention parameters
|
||||||
|
sparse_threshold=args.xattn_threshold,
|
||||||
|
sparse_stride=args.xattn_stride,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Warmup
|
# Warmup
|
||||||
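The mutually exclusive `--enable-quest` / `--enable-xattn` flags added to `bench_offload.py` above rely on standard `argparse` behavior, sketched here in isolation. The policy names are returned as plain strings to keep the example self-contained; `select_policy` is a hypothetical helper, not code from the repository.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Sparse policy flag demo")
    # argparse rejects the command line if more than one flag in the group is given.
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--enable-quest", action="store_true")
    group.add_argument("--enable-xattn", action="store_true")
    return parser


def select_policy(argv: list) -> str:
    """Map parsed flags to a policy name, defaulting to full attention."""
    args = build_parser().parse_args(argv)
    if args.enable_quest:
        return "QUEST"
    if args.enable_xattn:
        return "XATTN_BSA"
    return "FULL"


print(select_policy([]))
print(select_policy(["--enable-xattn"]))
```

Passing both flags makes `parse_args` print a usage error and raise `SystemExit`, so the script never needs to check for conflicting policies itself.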
|
|||||||
@@ -1,5 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
os.environ["VLLM_USE_V1"] = "1"
|
import sys
|
||||||
|
|
||||||
|
# Parse --use-v1 flag before importing vllm
|
||||||
|
use_v1 = "--use-v1" in sys.argv
|
||||||
|
if use_v1:
|
||||||
|
os.environ["VLLM_USE_V1"] = "1"
|
||||||
|
sys.argv.remove("--use-v1")
|
||||||
|
else:
|
||||||
|
os.environ["VLLM_USE_V1"] = "0"
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from random import randint, seed
|
from random import randint, seed
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
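The `--use-v1` handling above scans `sys.argv` before `vllm` is imported, which suggests the environment variable has to be set before that import and that `argparse` should not see the flag afterwards. A minimal sketch of that pattern follows, written against an explicit argument list and environment mapping so it can run without vLLM installed. `consume_flag` is a hypothetical helper; unlike `sys.argv.remove`, it drops every occurrence of the flag.

```python
def consume_flag(argv: list, flag: str, env: dict, var: str = "VLLM_USE_V1") -> list:
    """Set env[var] to "1" or "0" depending on whether flag is present,
    and return argv with the flag removed so a later parser does not see it."""
    env[var] = "1" if flag in argv else "0"
    return [arg for arg in argv if arg != flag]


env = {}
remaining = consume_flag(["bench_vllm.py", "--use-v1", "--max-len", "4096"], "--use-v1", env)
print(remaining, env)
```

In a real script, `env` would be `os.environ` and the call would sit above the `from vllm import ...` line.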
@@ -44,24 +53,28 @@ def bench_prefill(llm, num_seqs, input_len):
|
|||||||
def main():
|
def main():
|
||||||
import argparse
|
import argparse
|
 parser = argparse.ArgumentParser(description="Benchmark vLLM performance (for comparison)")
+parser.add_argument("--model", type=str, default="~/models/Llama-3.1-8B-Instruct",
+                    help="Model path (default: ~/models/Llama-3.1-8B-Instruct)")
 parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
 parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
 parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
+parser.add_argument("--gpu-util", type=float, default=0.9, help="GPU memory utilization (default: 0.9)")
+parser.add_argument("--enforce-eager", action="store_true", help="Disable CUDA Graphs (use eager mode)")
 parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
 parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
 args = parser.parse_args()
 
-path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
+path = os.path.expanduser(args.model)
 max_len = args.max_len
 
-print(f"\n[vLLM] max_len={max_len}")
+print(f"\n[vLLM] max_len={max_len}, gpu_util={args.gpu_util}, enforce_eager={args.enforce_eager}")
 
 llm = LLM(
     path,
-    enforce_eager=False,
+    enforce_eager=args.enforce_eager,
     max_model_len=max_len,
     max_num_seqs=128,
-    gpu_memory_utilization=0.9,
+    gpu_memory_utilization=args.gpu_util,
 )
 
 # Warmup
199
docs/bench_offload_results.md
Normal file
@@ -0,0 +1,199 @@
# CPU Offload Benchmark Results

This document records the performance results of `bench_offload.py` under different configurations.

## Test Environment

| Parameter | Value |
|-----------|-------|
| GPU | NVIDIA A100-SXM4-80GB |
| Model | Llama-3.1-8B-Instruct |
| GPU slots | 4 |

## Sparse Policy Configuration

| Policy | Prefill | Decode | Notes |
|--------|---------|--------|-------|
| FULL | Full Attention | Full Attention | Baseline; loads all blocks |
| XATTN_BSA | XAttention (tau=0.95, stride=8) | Full Attention (fallback) | Sparse prefill |

## Test Results

### Block Size 4096 (recommended)

#### GPU-only mode

| Context | Full Attention | XAttention | XAttention vs. Full |
|---------|----------------|------------|---------------------|
| 32K | 4863 tok/s | 5587 tok/s | **+14.9%** ✅ |
| 64K | 3373 tok/s | 4766 tok/s | **+41.3%** ✅ |

#### CPU Offload mode (after optimization, 2026-01-28)

| Context | Full Attention | XAttention | XAttention vs. Full |
|---------|----------------|------------|---------------------|
| 32K | 4678 tok/s | 4398 tok/s | **-6.0%** |
| 64K | 3331 tok/s | 3203 tok/s | **-3.8%** |
| 128K | 2144 tok/s | 2196 tok/s | **+2.4%** ✅ |

#### CPU Offload mode (before optimization, 2026-01-27)

| Context | Full Attention | XAttention | XAttention vs. Full |
|---------|----------------|------------|---------------------|
| 32K | 4648 tok/s | 4002 tok/s | **-13.9%** ❌ |
| 64K | 3329 tok/s | 2642 tok/s | **-20.6%** ❌ |
| 128K | 2122 tok/s | 867 tok/s | **-59.1%** ❌ |

### Block Size 256 (small-block test)

#### CPU Offload mode (64K)

| Policy | Elapsed | Throughput | Relative performance |
|--------|---------|------------|----------------------|
| Full Attention | 401.04s | 163.41 tok/s | baseline |
| XAttention BSA | 390.35s | 167.89 tok/s | **+2.7%** ✅ |

### Block Size 1024 (historical test)

#### CPU Offload mode

| Context | Full Attention | XAttention | XAttention vs. Full |
|---------|----------------|------------|---------------------|
| 32K | 1587.74 tok/s | 1172.33 tok/s | -26% |
| 128K | 552.63 tok/s | 466.17 tok/s | -16% |
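
The relative-performance column in the tables above appears to be the XAttention throughput divided by the Full Attention throughput, minus one. A minimal sketch of that arithmetic follows; the helper name `relative_change_pct` is made up for this illustration and is not part of `bench_offload.py`.

```python
def relative_change_pct(baseline_tok_s: float, candidate_tok_s: float) -> float:
    """Percentage change of the candidate throughput relative to the baseline."""
    return (candidate_tok_s / baseline_tok_s - 1.0) * 100.0


# (Full Attention, XAttention) throughput pairs copied from the tables above.
gpu_only = {"32K": (4863, 5587), "64K": (3373, 4766)}
offload_after = {"32K": (4678, 4398), "64K": (3331, 3203), "128K": (2144, 2196)}

for label, rows in (("GPU-only", gpu_only), ("Offload (after)", offload_after)):
    for ctx, (full, xattn) in rows.items():
        print(f"{label:16s} {ctx:>4s}: {relative_change_pct(full, xattn):+.1f}%")
```

Recomputing the column this way is a quick consistency check when new rows are added to the tables.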

## Key Findings

### 1. GPU-only vs. CPU Offload

| Mode | Effect of XAttention | Reason |
|------|----------------------|--------|
| **GPU-only** | ✅ Clear speedup (+15% to +41%) | Compute is the bottleneck; sparse attention reduces FLOPs |
| **CPU Offload (after optimization)** | ✅ Small gain at long context | The `estimate_block_size` optimization reduces estimation overhead |
| **CPU Offload (before optimization)** | ❌ Slowdown (-14% to -59%) | Transfer is the bottleneck; sparse estimation adds extra overhead |

### 2. Effect of Block Size on Performance

| Block Size | 64K Full (Offload) | Characteristics |
|------------|--------------------|-----------------|
| 4096 | 3329 tok/s | ⭐ Best performance |
| 1024 | ~1500 tok/s | Moderate |
| 256 | 163 tok/s | Extremely slow (20x slower) |

**Reason**: smaller blocks mean more blocks, and more blocks mean more H2D transfer overhead.
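
The block-count effect can be put into rough numbers. The sketch below assumes (this is not confirmed by the source) that the prefill chunk size equals the KV block size and that every chunk reloads all earlier blocks from the CPU; the model shape is that of Llama-3.1-8B (32 layers, 8 KV heads, head_dim 128, 2-byte dtype). The function names are hypothetical.

```python
import math


def kv_bytes_per_token(num_layers=32, num_kv_heads=8, head_dim=128, dtype_bytes=2):
    """Bytes of K plus V stored per token across all layers."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def offload_prefill_stats(context_len, block_size):
    """Return (num_blocks, block_loads, bytes_moved) for one chunked prefill.

    Assumption: chunk i attends to the i earlier blocks, each of which is
    copied host-to-device again, so loads = 0 + 1 + ... + (n - 1).
    """
    num_blocks = math.ceil(context_len / block_size)
    block_loads = num_blocks * (num_blocks - 1) // 2
    bytes_moved = block_loads * block_size * kv_bytes_per_token()
    return num_blocks, block_loads, bytes_moved


for block_size in (4096, 1024, 256):
    n, loads, moved = offload_prefill_stats(65536, block_size)
    print(f"block={block_size:5d}  blocks={n:4d}  loads={loads:6d}  H2D={moved / 1e9:8.1f} GB")
```

Under these assumptions the transfer volume grows roughly as `context_len² / block_size`, so shrinking the block increases both the number of copies and the total bytes moved.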

### 3. XAttention's Deficit Reverses at Small Block Sizes

With block size = 256, XAttention is actually slightly ahead (+2.7%):
- 256 blocks (vs. 16 at block size 4096)
- The fraction of blocks skipped by sparsity becomes more noticeable
- Absolute performance is very poor, however, so this setting is not recommended

### 4. Effect of the estimate_block_size Optimization (2026-01-28)

```
XAttention relative performance in offload mode:
        Before     After      Change
32K:    -13.9%     -6.0%      +7.9pp
64K:    -20.6%     -3.8%      +16.8pp
128K:   -59.1%     +2.4%      +61.5pp ✅
```

What changed:
- `estimate_block_size` reduced from 4096 to 1024
- `softmax_fuse_block_sum` kernel share of time dropped from 48% to 1% (44x speedup)
- Selection strategy changed from mask + voting to score + threshold

Conclusions after the optimization:
- **At 128K context, XAttention overtakes Full Attention**
- Shorter contexts still pay a small overhead, but it is much smaller than before

## Conclusions

### Recommended Configuration (after optimization, 2026-01-28)

| Scenario | Recommended policy | Block Size |
|----------|--------------------|------------|
| GPU-only (enough VRAM) | XAttention | 4096 |
| CPU Offload (128K+) | XAttention | 4096 |
| CPU Offload (32K-64K) | Full Attention or XAttention | 4096 |

### When XAttention Applies (after optimization)

✅ **Good fit**:
- GPU-only mode (compute-bound)
- CPU Offload with long context (128K+), where it yields a net gain
- Longer contexts (64K+) benefit more

⚠️ **Neutral**:
- CPU Offload with medium context (32K-64K): 3-6% slower, which is acceptable

❌ **Not recommended**:
- Short contexts (<32K), where the benefit is not significant

## Commands

```bash
# GPU-only mode
CUDA_VISIBLE_DEVICES=0 python bench.py --max-len 65536 --block-size 4096 --gpu-util 0.7
CUDA_VISIBLE_DEVICES=0 python bench.py --max-len 65536 --block-size 4096 --gpu-util 0.7 --policy xattn

# CPU Offload mode (block-size 4096 recommended)
CUDA_VISIBLE_DEVICES=0 python bench_offload.py --max-len 65536 --block-size 4096
CUDA_VISIBLE_DEVICES=0 python bench_offload.py --max-len 65536 --block-size 4096 --enable-xattn

# CPU Offload mode (small block size test)
CUDA_VISIBLE_DEVICES=0 python bench_offload.py --max-len 65536 --block-size 256
CUDA_VISIBLE_DEVICES=0 python bench_offload.py --max-len 65536 --block-size 256 --enable-xattn

# Tune XAttention parameters
CUDA_VISIBLE_DEVICES=0 python bench_offload.py --enable-xattn --xattn-threshold 0.8 --xattn-stride 16
```

## FlashInfer Merge Optimization (2026-01-28)

The Triton implementation of `merge_attention_outputs` was replaced with FlashInfer's `cascade.merge_state`.

### Performance Comparison (Full Attention, block-size 4096)

| Context | Triton merge | FlashInfer merge | Gain |
|---------|--------------|------------------|------|
| 32K | 4678 tok/s | 4717 tok/s | **+0.8%** |
| 64K | 3331 tok/s | 3411 tok/s | **+2.4%** |
| 128K | 2144 tok/s | 2178 tok/s | **+1.6%** |

### Key Findings

1. **End-to-end gain is limited** (0.8% to 2.4%): the merge operation is not the main bottleneck
   - H2D transfer dominates (a 64K run transfers 64GB)
   - Attention compute is the other major cost
   - Merge accounts for a very small share of total time

2. **Merge kernel in isolation** (FlashInfer's advantage grows with sequence length):

| seq_len | heads | Triton (ms) | FlashInfer (ms) | Speedup |
|---------|-------|-------------|-----------------|---------|
| 4096 | 32 | 0.129 | 0.087 | **1.49x** |
| 8192 | 32 | 0.251 | 0.147 | **1.70x** |
| 16384 | 32 | 0.499 | 0.274 | **1.82x** |

3. **FlashInfer is slower on short sequences**: the format conversion (squeeze, transpose, contiguous) adds overhead

### Technical Details

- **LSE format difference**: FlashInfer uses log2, flash_attn uses ln
- **Conversion factors**: `LOG2_E = 1.4427` (ln → log2), `LN_2 = 0.6931` (log2 → ln)
- **FlashInfer attention JIT issue**: because of CUDA version compatibility problems, only `merge_state` is used
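
To make the LSE bookkeeping concrete, here is a pure-Python sketch of merging two partial attention results through their log-sum-exp values, plus the base conversion quoted above. It only illustrates the math; it is not FlashInfer's `merge_state` nor the repository's Triton kernel, and the names `attend`, `merge_states`, `ln_to_log2`, and `log2_to_ln` are invented for this example.

```python
import math

LOG2_E = math.log2(math.e)  # ln -> log2 factor, about 1.4427
LN_2 = math.log(2.0)        # log2 -> ln factor, about 0.6931


def attend(q, keys, values):
    """Softmax attention for one query; returns (output, natural-log LSE)."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
    return out, m + math.log(total)


def merge_states(o1, lse1, o2, lse2):
    """Combine two partial results whose LSEs use the natural log."""
    m = max(lse1, lse2)
    w1, w2 = math.exp(lse1 - m), math.exp(lse2 - m)
    out = [(w1 * a + w2 * b) / (w1 + w2) for a, b in zip(o1, o2)]
    return out, m + math.log(w1 + w2)


def ln_to_log2(lse):
    return lse * LOG2_E


def log2_to_ln(lse):
    return lse * LN_2


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

full_out, full_lse = attend(q, keys, values)
o1, l1 = attend(q, keys[:2], values[:2])
o2, l2 = attend(q, keys[2:], values[2:])
merged_out, merged_lse = merge_states(o1, l1, o2, l2)
```

Merging the two halves reproduces the full-attention output and LSE up to floating-point error, which is why per-block results can be accumulated one block at a time.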

### Code Locations

- `nanovllm/ops/chunked_attention.py`: `merge_attention_outputs_flashinfer()`
- `nanovllm/kvcache/sparse/full_policy.py`: 3 imports updated
- `nanovllm/kvcache/sparse/xattn_bsa.py`: 1 import updated

## Update Log

- 2026-01-28: **Replaced Triton merge with FlashInfer merge**; end-to-end gain of 0.8% to 2.4%
- 2026-01-28: **Re-tested after the estimate_block_size optimization**; XAttention overtakes Full at 128K (+2.4%)
- 2026-01-27: Added GPU-only vs. Offload comparison and block size impact analysis
- 2026-01-27: Initial tests, Llama-3.1-8B-Instruct, A100 80GB
94
docs/changelog_2026-02-05.md
Normal file
@@ -0,0 +1,94 @@
# Changelog 2026-02-05

## Bug Fixes

### XAttention Offload GQA Buffer OOM Fix

**Issue**: `docs/issue_xattn_offload_gqa_buffer_oom.md`

**Problem**: In XAttention BSA + CPU Offload mode, `alloc_policy_metadata()` allocated the GQA expansion buffers (`_k_expanded`, `_v_expanded`) that only GPU-only mode needs, causing OOM on a 24GB GPU (RTX 3090).

**Root Cause**:
- GQA buffer size: `2 × num_heads × max_seq_len × head_dim × dtype_size`
- For a 1M max_seq_len: 2 × 32 × 1048576 × 128 × 2 = **16 GB**
- `compute_chunked_prefill()` in offload mode does not need these buffers

**Fix** (commit `11a867f`):
1. `nanovllm/kvcache/sparse/policy.py`: added an `enable_cpu_offload` parameter to the base class
2. `nanovllm/kvcache/sparse/xattn_bsa.py`: skips GQA buffer allocation in offload mode
3. `nanovllm/engine/model_runner.py`: passes the `enable_cpu_offload` parameter

**Memory Savings**:
| max_model_len | Before fix | After fix |
|---------------|------------|-----------|
| 72K | +1.1 GB | 0 GB |
| 1M | +16 GB | 0 GB |
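
The buffer-size formula in the root cause can be checked with a few lines of Python. This is only a sketch of the arithmetic; `gqa_expansion_bytes` is an illustrative name, not a function in `xattn_bsa.py`, and the defaults (32 heads, head_dim 128, 2-byte dtype) are the values used in the calculation above.

```python
def gqa_expansion_bytes(max_seq_len, num_heads=32, head_dim=128, dtype_bytes=2):
    """Size of the expanded K and V buffers (the leading 2 covers K and V)."""
    return 2 * num_heads * max_seq_len * head_dim * dtype_bytes


for max_len in (72_000, 1_048_576):
    gib = gqa_expansion_bytes(max_len) / 2**30
    print(f"max_seq_len={max_len:>9,d}: {gib:.1f} GiB")
```

Both rows of the memory savings table (about 1.1 GB at 72K and 16 GB at 1M) follow from this formula.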

**Verification**:
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_64k \
    --datasets niah_single_1 \
    --num-samples 1 \
    --max-model-len 72000 \
    --enable-offload \
    --sparse-policy XATTN_BSA
```
- Log output: `[XAttn] Offload mode: skipping GQA expansion buffers`
- Test result: 100% accuracy

---

## Code Cleanup

### Tests Directory Cleanup

**Commits**: `a709551`, `2b61c5a`, `d35dd76`

Removed 16 redundant or outdated test files and kept the core tests:

**Files kept** (4):
| File | Purpose |
|------|---------|
| `test_ruler.py` | Core RULER benchmark (13 tasks, 100 samples) |
| `test_xattn_estimate_alignment.py` | XAttn kernel consistency check |
| `utils.py` | Shared utility functions |
| `__init__.py` | Package marker |

**Files removed** (16 files, -4306 lines):

| Category | File | Description / reason for removal |
|----------|------|----------------------------------|
| XAttn tests | `test_xattn_bsa.py` | Functionality covered by test_ruler |
| | `test_xattn_chunked.py` | Duplicates estimate_chunked |
| | `test_xattn_estimate_chunked.py` | Chunked prefill validation |
| | `test_xattn_kernels.py` | Triton kernel unit tests |
| | `test_xattn_kv_chunking_batch.py` | Batch validation |
| Needle tests | `test_needle.py` | Covered by the test_ruler NIAH tasks |
| | `test_needle_ref.py` | HF reference implementation |
| CUDA Graph | `test_chunk_attention_graph.py` | Superseded by graph_reuse |
| | `test_chunk_attention_graph_reuse.py` | Experimental feature |
| | `test_cudagraph_memory.py` | Memory analysis tool |
| Other | `test_gpuonly_density_alignment.py` | GPU-only density test |
| | `test_hierarchical_estimate.py` | Hierarchical estimation test |
| | `test_quest_policy.py` | Quest policy test |
| | `test_sequential.py` | State isolation test |
| | `bench_estimate_block_size.py` | Performance benchmark |
| | `modeling_qwen3.py` | Qwen3 reference model |

**Note**: All removed files can be restored from git history:
```bash
git checkout <commit-hash>^ -- tests/<filename>
```

---

## Summary

| Type | Count | Impact |
|------|-------|--------|
| Bug Fix | 1 | Saves 16GB of GPU memory (1M seq) |
| Files removed | 16 | -4306 lines of code |
| New docs | 1 | This file |
1078
docs/chunked_attention_solutions.md
Normal file
File diff suppressed because it is too large
300
docs/cpu_offload_optimization_strategies.md
Normal file
@@ -0,0 +1,300 @@
# CPU Offload Optimization Strategies

This document analyzes performance optimization strategies for the CPU offload scenario, covering both practical options and current research directions.

## Problem Recap

According to the [CPU scheduling latency analysis](cpu_scheduling_latency_analysis.md), the main problems of the current chunked attention pipeline are:

| Metric | Current | Theoretical |
|--------|---------|-------------|
| Flash kernel execution time | ~138 μs | - |
| Gap between flash kernels | ~942 μs | ~211 μs (H2D + merge only) |
| GPU utilization | **12.8%** | **39.5%** (theoretical upper bound) |
| Share of time idle on CPU scheduling | **77-81%** | 0% |

**Root cause of the bottleneck**: every block goes through a full Python loop iteration, which adds a large amount of CPU scheduling latency.

---

## Option 1: Increase the Chunk Size (recommended)

### Core Insight

**Merging several small chunks and using one large chunk directly are equivalent**:

```
Option A: merge 4 small chunks
[H2D 2K][H2D 2K][H2D 2K][H2D 2K] → concat → [Flash 8K] → merge

Option B: use one large chunk directly
[H2D 8K] → [Flash 8K] → merge

The computed result is exactly the same!
```

### Benefit Analysis

| Metric | Small chunk (2K) × 4 | Large chunk (8K) × 1 |
|--------|----------------------|----------------------|
| H2D transfers | 4 | 1 |
| Flash kernel calls | 4 | 1 |
| Merge calls | 4 | 1 |
| Python loop iterations | 4 | 1 |
| CPU scheduling overhead | 4 × ~300μs = 1200μs | 1 × ~300μs = 300μs |

**In essence**: the CPU scheduling latency comes from having too many loop iterations, and a larger chunk size reduces the iteration count directly.
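
The table above amounts to a simple linear model: one fixed CPU scheduling cost per loop iteration. The sketch below restates it in code; the ~300 μs per-iteration figure is taken from the table, while the function name and signature are assumptions made for illustration.

```python
def loop_overhead_us(total_tokens, chunk_tokens, per_iter_overhead_us=300.0):
    """Return (iterations, total CPU scheduling overhead in microseconds)."""
    iterations = -(-total_tokens // chunk_tokens)  # ceiling division
    return iterations, iterations * per_iter_overhead_us


for chunk in (2048, 4096, 8192):
    iters, overhead = loop_overhead_us(8192, chunk)
    print(f"chunk={chunk:5d}: {iters} iteration(s), ~{overhead:.0f} us scheduling overhead")
```

This model ignores the longer H2D and flash kernel times of larger chunks, which the trade-off discussion covers separately.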

### Trade-offs

1. **More GPU memory**
   - 2K chunk: ~4MB per slot (K+V)
   - 8K chunk: ~16MB per slot (K+V)
   - 4 slots = 64MB, which is negligible on an 80GB A100

2. **Each H2D transfer takes longer**
   - H2D 8K ≈ 350μs
   - Flash 8K ≈ 550μs
   - Because Flash > H2D, the transfer can still hide behind compute and the pipeline stays effective

### Configuration

```bash
# Try different block sizes
python bench_offload.py --kvcache-block-size 2048  # baseline
python bench_offload.py --kvcache-block-size 4096  # 2x
python bench_offload.py --kvcache-block-size 8192  # 4x
```

---

## Option 2: CUDA Graph (for the non-attention parts)

### Limits of CUDA Graph in the Offload Scenario

CUDA Graph presupposes that every operation is fixed at capture time and that data addresses do not change.

**What the offload scenario actually looks like**:
1. **H2D source addresses are dynamic** - each load reads from a different CPU block
2. **Load decisions happen at runtime** - which blocks must be loaded is dynamic
3. **The CPU has to coordinate** - synchronizing H2D and compute needs CPU involvement

```
Offload scenario:
┌─────────────────────────────────────────┐
│ Data lives on the CPU; loaded on demand │
│ [H2D_i] → [Compute] → [H2D_{i+n}] → ... │
│     ↑ dynamic; the CPU must schedule it │
└─────────────────────────────────────────┘

Even with a graph:
Python:  [wait_h2d] [replay] [launch_h2d] [wait_h2d] [replay] ...
             ↑ CPU               ↑ CPU        ↑ CPU

The CPU scheduling overhead remains; the graph only optimizes the compute in between.
```

**Conclusion**: CUDA Graph is not a silver bullet for the offload scenario.

### Where It Applies: MLP and Projection Layers

Per-layer compute flow of an LLM:

```
┌─────────────────────────────────────────────────────────────┐
│ [LayerNorm] → [QKV Proj] → [Attention] → [O Proj] → [Add]   │
│                                 ↑                           │
│                            KV Offload                       │
│ [LayerNorm] → [MLP: gate + up + down] → [Add]               │
└─────────────────────────────────────────────────────────────┘
```

| Component | Involves offload | Can use CUDA Graph |
|-----------|------------------|--------------------|
| LayerNorm | ❌ | ✅ |
| QKV Projection | ❌ | ✅ |
| **Attention** | ✅ | ❌ |
| Output Projection | ❌ | ✅ |
| MLP (FFN) | ❌ | ✅ |

**Only attention involves dynamic KV cache loading; everything else is "pure compute" and can use CUDA Graph.**

### Implementation Sketch

```python
# Pseudocode: `capture` stands for a graph-capture helper that is not shown
# here, and the static input/output buffers each graph needs are omitted.
class OptimizedLayer:
    def __init__(self, layer):
        # Graph 1: everything before attention
        self.graph_pre_attn = capture([
            layer.input_layernorm,
            layer.self_attn.q_proj,
            layer.self_attn.k_proj,
            layer.self_attn.v_proj,
        ])

        # Graph 2: everything after attention, plus the MLP
        self.graph_post_attn = capture([
            layer.self_attn.o_proj,
            # residual add
            layer.post_attention_layernorm,
            layer.mlp.gate_proj,
            layer.mlp.up_proj,
            layer.mlp.down_proj,
            # residual add
        ])

    def forward(self, hidden_states, kv_cache):
        # Pre-attention (CUDA Graph); produces q (and k, v) in static buffers
        self.graph_pre_attn.replay()

        # Attention with offload (dynamic, cannot be captured in a graph)
        attn_output = chunked_attention_with_offload(q, kv_cache)

        # Post-attention + MLP (CUDA Graph); consumes attn_output
        self.graph_post_attn.replay()
```

### Estimated Benefit

Typical per-layer launch overhead of the MLP operations:
- `gate_proj`, `up_proj`, `act_fn`, `gate * up`, `down_proj`, `residual add`
- ~30-50μs launch overhead per operation, ~200μs per layer in total
- With CUDA Graph: ~30μs per layer

**32 layers × 170μs saved ≈ 5.4ms**

---

## Option 3: Research Directions

### 1. InfiniGen - Speculative Prefetching (OSDI'24)

**Core idea**: there is no need to load the whole KV cache; prefetch only the "important" tokens.

```
Key insight: attention patterns of adjacent layers are highly similar
        ↓
Use the attention scores of layer L to predict which tokens layer L+1 needs
        ↓
Prefetch only the top-k most important KV entries (not all of them)
```

**Technique**:
- Run a "rehearsal" using the current layer's Q and part of the next layer's K
- Predict the next layer's attention distribution
- Asynchronously prefetch the tokens predicted to be important
- **Reduces wasted PCIe bandwidth rather than making transfers faster**

**Result**: up to **3x speedup**

**Reference**: [InfiniGen (OSDI'24)](https://www.usenix.org/conference/osdi24/presentation/lee)

### 2. ShadowKV - Low-Rank Compression + Sparse Offload (ICML'25 Spotlight)

**Core idea**: keep a compressed Key cache on the GPU, offload the Value cache to the CPU, and load only 1.56% of the KV.

```
Pre-filling:
┌─────────────────────────────────────────────────┐
│ Key Cache → SVD low-rank compression → on GPU   │
│ Value Cache → offloaded to CPU                  │
│ Compute a landmark (mean) for each chunk        │
│ Identify outlier tokens → kept on GPU           │
└─────────────────────────────────────────────────┘

Decoding:
┌─────────────────────────────────────────────────┐
│ Use landmarks to estimate attention scores fast │
│ Load only the top-k Values (1.56% sparse)       │
│ Combine with the on-GPU outliers for the result │
└─────────────────────────────────────────────────┘
```

**Result**: 6x larger batch size, **3.04x throughput improvement**

**Reference**: [ShadowKV (ByteDance)](https://github.com/ByteDance-Seed/ShadowKV)

### 3. Asynchronous L2 Cache Prefetching (2025)

**Core idea**: use the GPU L2 cache for prefetching, fetching the next batch of KV while computing.

```
Conventional:
Compute:  [Flash_i]            [Flash_{i+1}]
H2D:               [H2D_{i+1}]
                        ↑ waiting

L2 Prefetch:
Compute:  [Flash_i + Prefetch_{i+1} to L2]  [Flash_{i+1} L2 hit]
                   ↑ prefetch with memory bandwidth left idle during compute
```

**Technique**:
- Issue prefetch instructions from inside the Flash Attention kernel
- Use the memory bandwidth left idle during compute
- The next access hits directly in L2

**Result**: **2.15x attention kernel efficiency**, 1.97x end-to-end throughput

**Reference**: [Asynchronous KV Cache Prefetching (2025)](https://arxiv.org/abs/2504.06319)

### 4. KVPR - I/O-Aware Scheduling (ACL'25)

**Core idea**: compute the optimal ratio of recomputation to offloading.

```
Trade-off:
- Recompute: recompute the KV (spend GPU compute to save memory)
- Offload: load from the CPU (spend PCIe bandwidth to save compute)

KVPR: picks the optimal ratio dynamically based on the current load
    + uses prefetching to overlap data transfer with compute
```

**Reference**: [KVPR (ACL'25)](https://aclanthology.org/2025.findings-acl.997.pdf)

---

## Summary of Optimization Strategies

### Recommended Priorities

| Priority | Option | Core optimization | Implementation complexity | Expected gain |
|----------|--------|-------------------|---------------------------|---------------|
| **P0** | Larger chunk size | Fewer loop iterations | Very low (config change) | 2-4x |
| **P1** | MLP CUDA Graph | Less launch overhead | Medium | ~5ms/request |
| **P2** | InfiniGen-style prefetching | Load only important tokens | Medium-high | 2-3x |
| **P3** | ShadowKV-style compression | Key compression + sparsity | High | 3x |
| **P3** | C++ Extension | Removes Python overhead | High | 2-3x |

### Separation of Strategies

```
┌─────────────────────────────────────────────────────────────┐
│ Attention + Offload:                                        │
│   - Bottleneck: H2D transfer + CPU scheduling               │
│   - Fix: larger chunks / speculative prefetch / sparsity    │
│                                                             │
│ MLP + Proj + Norm:                                          │
│   - Bottleneck: kernel launch overhead                      │
│   - Fix: CUDA Graph                                         │
└─────────────────────────────────────────────────────────────┘

The two sets of optimizations are fully orthogonal and can be combined.
```

---

## Related Files

- `nanovllm/kvcache/sparse/full_policy.py`: chunked attention pipeline
- `nanovllm/kvcache/offload_engine.py`: H2D/D2H transfer management
- `docs/cpu_scheduling_latency_analysis.md`: problem analysis

## References

1. [InfiniGen: Efficient Generative Inference of Large Language Models with Dynamic KV Cache Management](https://www.usenix.org/conference/osdi24/presentation/lee) - OSDI'24
2. [ShadowKV: KV Cache in Shadows for High-Throughput Long-Context LLM Inference](https://github.com/ByteDance-Seed/ShadowKV) - ICML'25 Spotlight
3. [Accelerating LLM Inference Throughput via Asynchronous KV Cache Prefetching](https://arxiv.org/abs/2504.06319) - 2025
4. [KVPR: Efficient LLM Inference with I/O-Aware KV Cache](https://aclanthology.org/2025.findings-acl.997.pdf) - ACL'25
5. [LMCache: An Efficient KV Cache Layer for Enterprise-Scale LLM Inference](https://lmcache.ai/tech_report.pdf) - 2025
177
docs/cpu_scheduling_latency_analysis.md
Normal file
@@ -0,0 +1,177 @@
# CPU Scheduling Latency Analysis

## Overview

Analysis of an nsys profile shows a large amount of **CPU scheduling latency** in the chunked attention pipeline, which significantly lowers GPU utilization.

## Observations

### Test Environment
- GPU: NVIDIA A100-SXM4-80GB
- Model: Llama-3.1-8B-Instruct
- Test: RULER niah_single_1, 64K context
- Profile file: `ruler_8slots_test.nsys-rep`
- Time window: 92.982s - 93.038s

### Kernel Execution Time

| Kernel | Typical duration |
|--------|------------------|
| flash_fwd_kernel | ~138 μs |
| H2D memcpy (2MB) | ~87 μs |
| merge_lse_kernel | ~3.5 μs |
| merge_output_kernel | ~34 μs |

### Gaps Between Operations

Gaps observed in cuda_gpu_trace (Gap is the idle time since the previous operation ended):

```
Start (ms)   Dur (μs)   Gap (μs)   Type
------------------------------------------------------------
92984.680    138.3      378.3      flash_fwd_kernel  ← GAP!
92985.051    86.8       232.9      H2D memcpy        ← GAP!
92985.141    86.8       2.8        H2D memcpy
92985.587    135.9      360.0      flash_fwd_kernel  ← GAP!
92986.026    3.4        302.4      merge_lse         ← GAP!
92986.164    33.5       135.0      merge_output      ← GAP!
92986.371    86.9       173.4      H2D memcpy        ← GAP!
92986.461    86.8       2.7        H2D memcpy
92986.816    137.9      268.2      flash_fwd_kernel  ← GAP!
```

### Breakdown of the Gaps Between Flash Kernels

| Interval | Total time | Useful work | Idle time |
|----------|------------|-------------|-----------|
| Flash 1 → Flash 2 | 769 μs | ~174 μs (2x H2D) | ~595 μs (77%) |
| Flash 2 → Flash 3 | 1092 μs | ~211 μs (merge + H2D) | ~881 μs (81%) |
| Flash 3 → Flash 4 | 965 μs | ~211 μs (merge + H2D) | ~754 μs (78%) |

**Key finding**: roughly **77-81% of the time between flash kernels is idle time spent on CPU scheduling**.

## Where the Gaps Come From

### 1. Types of CPU Scheduling Latency

| Transition | Typical latency | Cause |
|------------|-----------------|-------|
| Kernel ends → next kernel starts | 100-400 μs | CPU prepares arguments and calls the CUDA driver |
| Flash ends → H2D starts | ~233 μs | Python code execution + CUDA launch |
| H2D ends → Flash starts | ~360 μs | Synchronization wait + kernel launch |
| Flash ends → merge starts | ~302 μs | Python code execution |

### 2. Code Locations That Introduce Latency

```python
# full_policy.py: compute_chunked_prefill

for block_idx in range(num_blocks):
    # 1. Wait for the H2D copy to finish (synchronization point)
    offload_engine.wait_slot_layer(current_slot)  # ← may introduce latency

    # 2. Fetch the KV data
    k_block, v_block = offload_engine.get_kv_for_slot(current_slot)

    # 3. Call flash attention (kernel launch)
    block_out, block_lse = flash_attn_with_kvcache(...)  # ← CPU scheduling latency

    # 4. Merge
    merge_output(...)  # ← CPU scheduling latency
    merge_lse(...)     # ← CPU scheduling latency

    # 5. Issue the next H2D copy (asynchronous)
    offload_engine.load_to_slot_layer(next_slot, ...)  # ← CPU scheduling latency
```

### 3. Why the Gap Between H2D Copies Is Small

The gap between back-to-back H2D memcpy operations is only ~2.7 μs because:
- They are issued consecutively on the same stream
- The CUDA driver can process them as a batch
- No Python code runs in between

## GPU Utilization

Based on the observations:

| Metric | Value |
|--------|-------|
| Average flash kernel execution time | 138 μs |
| Average gap between flash kernels | 942 μs |
| Flash kernel GPU utilization | 138 / (138 + 942) = **12.8%** |

If the CPU scheduling latency were eliminated (keeping only the necessary H2D + merge):

| Metric | Value |
|--------|-------|
| Necessary gap (2x H2D + merge) | ~211 μs |
| Theoretical GPU utilization | 138 / (138 + 211) = **39.5%** |

**Potential improvement**: 3x GPU utilization
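
The two utilization figures follow from the same ratio, busy time over busy time plus gap. A short sketch of that arithmetic, using the averages from the tables above (the function name `utilization` is illustrative only):

```python
def utilization(busy_us: float, gap_us: float) -> float:
    """Fraction of wall time the flash kernel is running."""
    return busy_us / (busy_us + gap_us)


observed = utilization(138, 942)  # measured average gap between flash kernels
ideal = utilization(138, 211)     # gap reduced to the necessary H2D + merge work

print(f"observed: {observed:.1%}, ideal: {ideal:.1%}, ratio: {ideal / observed:.2f}x")
```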

## Optimization Directions

### 1. CUDA Graph

Capture the whole per-block processing sequence as a CUDA Graph to eliminate the repeated kernel launch overhead.

```python
# Pseudocode
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # Pre-record the flash + merge operations
    block_out, block_lse = flash_attn_with_kvcache(...)
    merge_output(...)
    merge_lse(...)

# At runtime only a replay is needed
for block_idx in range(num_blocks):
    graph.replay()  # one launch per block instead of one per kernel
```

### 2. Custom Triton Kernel

Fuse flash + merge into a single kernel to reduce the number of kernel launches.

### 3. C++ Extension

Move the Python loop into C++ to reduce Python interpreter overhead.

### 4. Better Pipeline Overlap

Make sure each H2D transfer fully overlaps with the previous block's compute:

```
Block 0: [H2D slot0] [Flash slot0] [merge]
Block 1:             [H2D slot1]   [Flash slot1] [merge]
Block 2:                           [H2D slot2]   [Flash slot2] [merge]
```

## Verification

### 1. Analyze Gaps with nsys

```bash
# Generate the profile
bash scripts/profile_offload.sh --num-gpu-blocks 8

# Inspect the kernel trace
# START/END use the units of the Start column (nanoseconds by default);
# the values below correspond to the 92.982s - 93.038s window.
nsys stats --report cuda_gpu_trace --format csv <file>.nsys-rep | \
    awk -F',' -v START=92982000000 -v END=93038000000 'NR>1 && $1 >= START && $1 <= END'
```

### 2. Compute the Gaps

```python
# Computed from the trace data
prev_end = start + duration
gap = next_start - prev_end
```
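
The two-line recipe above can be turned into a small runnable script. The sketch below is a hypothetical helper, not a script from the repository: it hard-codes the rows of the trace excerpt shown earlier in this document as `(start_ms, duration_us, name)` tuples and computes the idle gaps from them.

```python
# (start_ms, duration_us, name) rows copied from the trace excerpt in this document.
TRACE = [
    (92984.680, 138.3, "flash_fwd_kernel"),
    (92985.051, 86.8, "H2D memcpy"),
    (92985.141, 86.8, "H2D memcpy"),
    (92985.587, 135.9, "flash_fwd_kernel"),
    (92986.026, 3.4, "merge_lse"),
    (92986.164, 33.5, "merge_output"),
    (92986.371, 86.9, "H2D memcpy"),
    (92986.461, 86.8, "H2D memcpy"),
    (92986.816, 137.9, "flash_fwd_kernel"),
]


def end_ms(row):
    """End timestamp of an event in milliseconds."""
    start_ms, duration_us, _ = row
    return start_ms + duration_us / 1000.0


def consecutive_gaps_us(rows):
    """Idle time between each event and the event right before it."""
    return [(cur[2], (cur[0] - end_ms(prev)) * 1000.0) for prev, cur in zip(rows, rows[1:])]


def same_kernel_gaps_us(rows, name):
    """Time from the end of one `name` event to the start of the next one.

    This span includes other work (H2D, merge) as well as idle time.
    """
    hits = [r for r in rows if r[2] == name]
    return [(cur[0] - end_ms(prev)) * 1000.0 for prev, cur in zip(hits, hits[1:])]


for name, gap in consecutive_gaps_us(TRACE):
    print(f"{name:18s} gap = {gap:6.1f} us")
print("flash-to-flash:", [round(g) for g in same_kernel_gaps_us(TRACE, "flash_fwd_kernel")])
```

Because the start timestamps in the excerpt are rounded to microseconds, the recomputed gaps can differ from the listed Gap column by a fraction of a microsecond.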

## Related Files

- `nanovllm/kvcache/sparse/full_policy.py`: pipeline implementation
- `nanovllm/kvcache/offload_engine.py`: H2D/D2H transfers
- `scripts/profile_offload.sh`: profiling script

## References

- [CUDA Graph documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-graphs)
- [nsys user guide](https://docs.nvidia.com/nsight-systems/UserGuide/index.html)
152
docs/cuda_graph_memory_guide.md
Normal file
@@ -0,0 +1,152 @@
# CUDA Graph Memory Guide

This document analyzes the memory behavior of CUDA Graph in LLM inference, based on measurements with the Qwen3-4B model.

## Overview

CUDA Graph captures a sequence of GPU kernel executions and replays it, which reduces CPU overhead and improves inference performance. This guide focuses on its memory characteristics.

## Performance Gain

| Mode | Decode throughput | Notes |
|------|-------------------|-------|
| Eager | ~25 tok/s | Kernels are re-scheduled on every inference step |
| CUDA Graph | ~70 tok/s | Replays a pre-recorded kernel sequence |
| **Speedup** | **2.80x** | |

## Memory by Phase

Results for Qwen3-4B (bf16) on an RTX 3090:

### Memory Change per Phase

| Phase | Memory (MB) | Delta | Notes |
|-------|-------------|-------|-------|
| Model load | 7672 | +7672 | Model weights |
| StaticCache allocation | 7816 | +144 | **Main overhead** |
| Warmup (3 runs) | 7825 | +8 | Activation buffers |
| Graph capture | 7833 | +8 | Stores the kernel sequence |
| Graph Replay | 7833 | **0** | No extra allocation |

### Key Findings

1. **Graph capture overhead is small**: only about 8 MB, used to store the kernel call sequence

2. **StaticCache is the main overhead**:
   ```
   size = num_layers × 2 × batch_size × num_kv_heads × max_cache_len × head_dim × dtype_size
   ```
   - Qwen3-4B (1024 tokens): 36 × 2 × 1 × 8 × 1024 × 128 × 2 = **144 MB**

3. **Graph Replay allocates nothing**: all tensor addresses are fixed at capture time, and replay only re-runs the kernels
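
The StaticCache size formula above is easy to evaluate directly. The following sketch plugs in the Qwen3-4B values quoted in this guide (36 layers, 8 KV heads, head_dim 128, 2-byte dtype); `static_cache_bytes` is an illustrative helper, not a library function.

```python
def static_cache_bytes(num_layers, batch_size, num_kv_heads, max_cache_len,
                       head_dim, dtype_bytes=2):
    """Bytes needed for a pre-allocated K and V cache (the factor 2 is K plus V)."""
    return num_layers * 2 * batch_size * num_kv_heads * max_cache_len * head_dim * dtype_bytes


for cache_len in (256, 512, 1024, 2048, 4096):
    size = static_cache_bytes(36, 1, 8, cache_len, 128)
    print(f"max_cache_len={cache_len:5d}: {size / 2**20:6.0f} MB")
```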

## Cache Length vs. Memory

| Cache length | Total overhead | Per 1K tokens |
|--------------|----------------|---------------|
| 256 | 53 MB | 206 MB |
| 512 | 89 MB | 174 MB |
| 1024 | 161 MB | 157 MB |
| 2048 | 305 MB | 149 MB |
| 4096 | 593 MB | 145 MB |

The overhead grows approximately linearly with cache length: each total is the StaticCache size (144 MB per 1024 tokens) plus about 17 MB of fixed overhead. The per-1K-token figure therefore falls from about 206 MB for short caches to about 145 MB for long ones.

## How CUDA Graph Works

### Core Requirement: Fixed Memory Addresses

CUDA Graph requires every tensor to have a fixed address at capture time; afterwards values can only be updated in place, for example with `copy_()`:

```python
# Allocate tensors with fixed addresses
static_input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
static_cache_position = torch.tensor([0], dtype=torch.long, device=device)

# Use these tensors during capture
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    outputs = model(input_ids=static_input_ids, ...)

# On replay, update values in place (addresses stay the same)
static_input_ids.copy_(new_token)      # update the input
static_cache_position.fill_(position)  # update the position
graph.replay()                         # replay
```

### StaticCache vs DynamicCache

| Property | DynamicCache | StaticCache |
|----------|--------------|-------------|
| Memory allocation | Grows on demand | Pre-allocated, fixed size |
| Address stability | Unstable | Stable |
| CUDA Graph compatible | ❌ | ✅ |
| Memory efficiency | High (on demand) | Low (pre-allocated) |

### Typical Workflow

```
1. Prefill (Eager)
   └── Use DynamicCache to handle variable-length input

2. Create the StaticCache
   └── Pre-allocate a cache of max_cache_len

3. Copy the prefill KV into the StaticCache
   └── Copy the DynamicCache contents to fixed addresses

4. Warmup (3 runs)
   └── Make sure all lazy initialization has finished

5. Capture Graph
   └── Record the decode kernel sequence

6. Decode Loop
   └── Update inputs → graph.replay() → read outputs
```

## Memory Cost of Graphs for Multiple Batch Sizes

If a separate graph is captured for each batch size (as in nanovllm's design), memory grows quickly:

| Batch size | StaticCache (1024 tokens) | Cumulative |
|------------|---------------------------|------------|
| 1 | 144 MB | 144 MB |
| 2 | 288 MB | 432 MB |
| 4 | 576 MB | 1,008 MB |
| 8 | 1,152 MB | 2,160 MB |
| 16 | 2,304 MB | 4,464 MB |
| ... | ... | ... |

This is because each batch size needs its own StaticCache. Real systems such as nanovllm avoid the problem by sharing a single KV cache through PagedAttention.
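
The cumulative column is a running sum of per-graph caches. The sketch below reproduces it, assuming the 144 MB per-sequence figure from the Qwen3-4B example and the power-of-two batch sizes shown in the table; the variable names are illustrative.

```python
from itertools import accumulate

PER_SEQUENCE_MB = 144  # StaticCache for one sequence at 1024 tokens
batch_sizes = [1, 2, 4, 8, 16]

per_graph_mb = [PER_SEQUENCE_MB * bs for bs in batch_sizes]
cumulative_mb = list(accumulate(per_graph_mb))

for bs, own, total in zip(batch_sizes, per_graph_mb, cumulative_mb):
    print(f"batch_size={bs:>2}: {own:>5} MB, cumulative {total:>5} MB")

# With power-of-two batch sizes the total is (2 * largest - 1) caches.
assert cumulative_mb[-1] == PER_SEQUENCE_MB * (2 * batch_sizes[-1] - 1)
```

Supporting every power-of-two batch size up to N therefore costs almost twice the cache of the largest graph alone, which is the overhead a shared paged cache removes.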

## Test Scripts

Test scripts are provided to verify the conclusions above:

```bash
# Basic memory analysis
CUDA_VISIBLE_DEVICES=0 python tests/test_cudagraph_memory.py

# Specify the cache length
CUDA_VISIBLE_DEVICES=0 python tests/test_cudagraph_memory.py --max-cache-len 2048

# Test cache-length scaling
CUDA_VISIBLE_DEVICES=0 python tests/test_cudagraph_memory.py --test-scaling
```

Performance comparison demo:

```bash
# Eager vs CUDA Graph performance comparison
CUDA_VISIBLE_DEVICES=0 python tests/data/test_cudagraph_demo.py --mode both
```

## Summary

| Item | Conclusion |
|------|------------|
| Speedup | ~2.8x decode throughput |
| Graph capture overhead | ~8 MB (small) |
| Main memory cost | StaticCache (proportional to cache_len) |
| Replay memory | Zero additional allocation |
| Core requirement | Fixed tensor addresses |
196
docs/cuda_graph_offload_guide.md
Normal file
@@ -0,0 +1,196 @@
# CUDA Graph Support for CPU Offload Mode

This document describes the CUDA graph implementation for the CPU offload decode path, which substantially improves decode throughput.

## Overview

CUDA graphs capture a sequence of GPU operations and replay them with minimal CPU overhead. In offload mode, we capture one graph per layer for the decode path, which raises decode throughput by about **4.2x** in the measurements below.

## Performance Results

| Metric | Eager Mode | CUDA Graph | Improvement |
|--------|------------|------------|-------------|
| Decode Throughput | ~12 tok/s | ~50 tok/s | **4.2x** |
| TPOT (time per output token) | ~80ms | ~19ms | **4.2x** |
| Prefill Throughput | ~8000 tok/s | ~8000 tok/s | Same |

## Architecture

### Why Standard CUDA Graph Capture Doesn't Work

The standard `capture_cudagraph()` captures the PagedAttention decode path:
- Uses block tables for scattered KV cache access
- `Attention.k_cache/v_cache` point to PagedAttention buffers

In offload mode, the decode path is different:
- Uses contiguous ring buffers for KV cache
- `Attention.k_cache/v_cache` dynamically point to ring buffer slices
- H2D transfers are interleaved with compute

### Per-Layer Graph Design

We capture one CUDA graph per transformer layer:

```
Offload Decode with CUDA Graphs

Initialization:
  capture_offload_cudagraph() captures 36 layer graphs
  Each graph: layer.forward() with ring buffer as cache

Decode Step:
  1. Embedding (eager, outside graph)
  2. For each layer:
     a. Wait for H2D load (outside graph)
     b. Copy decode KV to ring buffer (outside graph)
     c. Set Attention.k_cache = ring_buffer[buffer_idx]
     d. Set context (slot_mapping, context_lens)
     e. graph.replay() - layer forward
     f. synchronize()
     g. Copy layer_outputs -> hidden_states
     h. Copy new KV to decode buffer (outside graph)
     i. Start next layer H2D load
  3. Final norm and logits (eager)
```

### Ring Buffer Mapping

Each layer maps to a ring buffer slot:
```python
buffer_idx = layer_id % num_kv_buffers
```

With 4 buffers and 36 layers:
- Layer 0, 4, 8, ... use buffer 0
- Layer 1, 5, 9, ... use buffer 1
- Layer 2, 6, 10, ... use buffer 2
- Layer 3, 7, 11, ... use buffer 3
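
The modulo mapping can be tabulated in full with a short sketch. The helper `layers_per_buffer` is illustrative and not part of `model_runner.py`; it only restates `buffer_idx = layer_id % num_kv_buffers` for the 36-layer, 4-buffer case.

```python
from collections import defaultdict

def layers_per_buffer(num_layers, num_kv_buffers):
    """Group layer ids by the ring buffer slot they map to."""
    slots = defaultdict(list)
    for layer_id in range(num_layers):
        slots[layer_id % num_kv_buffers].append(layer_id)
    return dict(slots)

mapping = layers_per_buffer(num_layers=36, num_kv_buffers=4)
for buffer_idx, layers in mapping.items():
    print(f"buffer {buffer_idx}: {layers}")
```

Because consecutive layers always land in different slots, the load for an upcoming layer can target a slot other than the one the current layer is reading.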

## Implementation Details

### Graph Capture (`capture_offload_cudagraph`)

Location: `model_runner.py:1075-1164`

```python
def capture_offload_cudagraph(self):
    # Fixed-address tensors for graph I/O
    hidden_states = torch.randn(1, hidden_size, ...)
    residual = torch.randn(1, hidden_size, ...)
    layer_outputs = torch.zeros(1, hidden_size, ...)
    layer_residual = torch.zeros(1, hidden_size, ...)

    for layer_id in range(num_layers):
        buffer_idx = layer_id % num_buffers

        # Set Attention cache to ring buffer slice
        attn_module.k_cache = ring_buffer[buffer_idx:buffer_idx+1]
        attn_module.v_cache = ring_buffer[buffer_idx:buffer_idx+1]

        # Set context for contiguous mode
        set_context(is_prefill=False, slot_mapping=...,
                    context_lens=..., block_tables=None)

        # Warmup and capture
        with torch.cuda.graph(graph, pool):
            out_h, out_r = layer(positions, hidden_states, residual)
            layer_outputs.copy_(out_h)
            layer_residual.copy_(out_r)

        # Propagate state for next layer's capture
        hidden_states.copy_(layer_outputs)
        residual.copy_(layer_residual)
```

Key design decisions:
1. **Fixed-address tensors**: Graph inputs/outputs use pre-allocated tensors
2. **Include copy in graph**: `layer_outputs.copy_(out_h)` is captured
3. **State propagation**: Update hidden_states between layer captures
4. **Random initialization**: Use `randn` instead of zeros for realistic distributions

### Graph Replay (`run_layerwise_offload_decode`)

Location: `model_runner.py:844-1031`

```python
use_cuda_graph = not self.enforce_eager and hasattr(self, 'offload_graphs')

if use_cuda_graph:
    # Use fixed-address tensors
    graph_vars["positions"][0] = len(seq) - 1
    graph_vars["slot_mapping"][0] = context_len
    graph_vars["context_lens"][0] = context_len + 1
    graph_vars["hidden_states"].copy_(embedding)
    graph_vars["residual"].zero_()

for layer_id in range(num_layers):
    # H2D and buffer setup (outside graph)
    offload_engine.wait_buffer_load(current_buffer)
    attn_module.k_cache = ring_buffer[current_buffer:current_buffer+1]
    set_context(...)

    if use_cuda_graph:
        # Replay graph
        self.offload_graphs[layer_id].replay()
        torch.cuda.current_stream().synchronize()

        # Copy outputs to inputs for next layer
        if layer_id < num_layers - 1:
            graph_vars["hidden_states"].copy_(graph_vars["layer_outputs"])
            graph_vars["residual"].copy_(graph_vars["layer_residual"])
    else:
        # Eager execution
        hidden_states, residual = layer(positions, hidden_states, residual)
```

Key points:
1. **Synchronization required**: `synchronize()` after each graph replay
2. **Manual state propagation**: Copy layer_outputs to hidden_states between replays
3. **H2D outside graph**: Ring buffer loads happen before graph replay

## Limitations and Future Work

### Current Limitations

1. **Per-layer sync overhead**: Each layer requires synchronization
2. **No kernel fusion across layers**: Each layer is a separate graph
3. **Fixed batch size**: Only supports batch_size=1 for offload

### Future Optimization: Full-Decode Graph

Potential improvement: Capture entire decode step as single graph
- Complete all H2D loads before graph
- Single graph covers all 36 layers
- Better kernel fusion, less CPU overhead
- More complex to implement (handle buffer rotation inside graph)

## Testing

Run needle test with CUDA graph:
```bash
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python tests/test_needle.py \
    --input-len 32768 \
    --enable-offload \
    --use-cuda-graph
```

Run benchmark:
```bash
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python bench_offload.py \
    --input-len 16384 \
    --bench-all
```

## Files Modified

| File | Changes |
|------|---------|
| `model_runner.py:46-50` | Call `capture_offload_cudagraph()` for offload mode |
| `model_runner.py:69-73` | Clean up offload graph resources in `exit()` |
| `model_runner.py:844-1031` | Add CUDA graph support to `run_layerwise_offload_decode()` |
| `model_runner.py:1075-1164` | New `capture_offload_cudagraph()` method |
| `tests/test_needle.py` | Add `--use-cuda-graph` flag |
258
docs/estimate_block_size_performance.md
Normal file
@@ -0,0 +1,258 @@
# Estimate Block Size Performance Analysis

This document records how the `block_size` parameter affects the performance of the `softmax_fuse_block_sum` kernel in the XAttention estimate stage.

## Background

The estimate step in `select_blocks` currently uses the global `kvcache_block_size` (typically 4096):

```python
# xattn_bsa.py: select_blocks()
block_size = ctx.block_size  # from kvcache_manager.block_size (4096)
reshaped_block_size = block_size // self.stride  # 4096/8 = 512

block_sums = softmax_fuse_block_sum(
    attn_scores,
    reshaped_block_size,  # 512 - the slowest configuration measured!
    ...
)
```

As a result, the `softmax_fuse_block_sum` kernel runs with `reshaped_block_size=512`, which is the slowest point on the measured performance curve.

## Benchmark Results

### Test configuration

- GPU: NVIDIA A100-SXM4-80GB
- NUM_HEADS: 32
- HEAD_DIM: 128
- STRIDE: 8
- Test script: `tests/bench_estimate_block_size.py`

### softmax_fuse_block_sum timings

| block_size | reshaped | 16K context | 32K context | 64K context |
|------------|----------|-------------|-------------|-------------|
| 64 | 8 | 4.86ms | 18.36ms | 70.83ms |
| 128 | 16 | 0.83ms | 3.12ms | 16.83ms |
| 256 | 32 | 0.63ms | 2.41ms | 11.24ms |
| 512 | 64 | **0.38ms** | **1.52ms** | 9.54ms |
| 1024 | 128 | 0.42ms | 1.54ms | **6.01ms** |
| 2048 | 256 | 1.08ms | 3.24ms | 12.81ms |
| **4096** | **512** | 9.66ms | 25.36ms | **95.32ms** |

### Performance curve

```
softmax_fuse_block_sum time (64K context):

block_size=64    ████████████████████████████████████ 70.83ms
block_size=128   ████████ 16.83ms
block_size=256   █████ 11.24ms
block_size=512   ████ 9.54ms
block_size=1024  ███ 6.01ms  ◀── fastest
block_size=2048  ██████ 12.81ms
block_size=4096  ████████████████████████████████████████████████ 95.32ms  ◀── currently used
```

### Key findings

1. **The timing curve is U-shaped**: block_size values that are too small or too large both slow the kernel down
2. **The optimum is at 512-1024**: this corresponds to `reshaped_block_size` 64-128
3. **The current setting (4096) is the slowest point**: 95.32ms vs. the best 6.01ms at 64K context, about **15.9x slower**
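
The findings can be re-derived from the benchmark table. The sketch below only re-reads the numbers already reported; the dictionary layout and variable names are illustrative.

```python
# Timings in ms copied from the benchmark table: block_size -> (16K, 32K, 64K).
timings_ms = {
    64:   (4.86, 18.36, 70.83),
    128:  (0.83, 3.12, 16.83),
    256:  (0.63, 2.41, 11.24),
    512:  (0.38, 1.52, 9.54),
    1024: (0.42, 1.54, 6.01),
    2048: (1.08, 3.24, 12.81),
    4096: (9.66, 25.36, 95.32),
}
contexts = ("16K", "32K", "64K")

best = {}
for col, ctx in enumerate(contexts):
    fastest = min(timings_ms, key=lambda bs: timings_ms[bs][col])
    slowdown = timings_ms[4096][col] / timings_ms[fastest][col]
    best[ctx] = (fastest, slowdown)
    print(f"{ctx}: fastest block_size={fastest}, "
          f"4096 is {slowdown:.1f}x slower")
```

The fastest block_size shifts from 512 to 1024 as the context grows, while 4096 is the slowest setting in every column.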

## Interpreting the Performance Curve

```
Kernel time vs. block_size (64 ... 4096): U-shaped, optimum at 512-1024

Too small (left side of the curve):
  - many output blocks (q_len / block_size)
  - high grid scheduling overhead
  - little work per thread block

Too large (right side of the curve):
  - block_size is a tl.constexpr
  - higher register pressure (possible spills)
  - not enough shared memory
  - lower L1 cache efficiency
```

### Inside the Triton kernel

Key constraints in `softmax_fuse_block_sum_kernel`:

```python
# Data handled by each thread block
offs_q = tl.arange(0, block_size)  # block_size elements
m_i = tl.zeros([block_size], dtype=tl.float32)  # register allocation

# reshape
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
# With block_size=512 and segment_size=512 this is a (512, 1, 512) 3D tensor
```

When `block_size` is too large:
- each thread block needs more registers
- `tl.arange(0, block_size)` produces a larger vector
- the memory access pattern of the reshape gets worse

## Optimization Options

### Option 1: fixed estimate block_size

Use a fixed, smaller block_size for estimation in `select_blocks`:

```python
# Suggested change
ESTIMATE_BLOCK_SIZE = 1024  # or 512, instead of ctx.block_size

reshaped_block_size = ESTIMATE_BLOCK_SIZE // self.stride  # 128
```

**Pros**: simple and direct; expected speedup of about 15x

**Cons**: the estimate block granularity no longer matches the CPU block size, so a mapping is needed

### Option 2: two-level block structure

- The outer level uses `kvcache_block_size` (4096) to manage CPU blocks
- The inner level uses `estimate_block_size` (1024) for estimation
- Estimates are aggregated back to CPU block granularity

### Option 3: adaptive block_size

Choose the estimate block_size from the context length:

| Context Length | Recommended block_size |
|----------------|------------------------|
| < 16K | 512 |
| 16K - 64K | 1024 |
| > 64K | 1024 |

## Comparison with Real Profiling

### Nsys profiling data (64K context, block_size=4096)

| Stage | Share of time | Notes |
|-------|---------------|-------|
| softmax_fuse_block_sum | **48.1%** | Last chunk |
| flash_fwd_kernel | 30.7% | Actual attention computation |
| flat_group_gemm | 3.5% | Estimate GEMM |

### Expected effect

If the estimate block_size is changed from 4096 to 1024:

| Metric | Current (4096) | Optimized (1024) | Gain |
|--------|----------------|------------------|------|
| softmax kernel | 95.32ms | 6.01ms | **15.9x** |
| Estimate stage share | 48.1% | ~5% | Much lower |
| Total prefill time | ~2s (last chunk) | ~1.1s | ~1.8x |

## Test Commands

```bash
# Run the benchmark
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
    python tests/bench_estimate_block_size.py --gpu 0

# Run a single context length
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
    python tests/bench_estimate_block_size.py --gpu 0 --ctx-len 65536
```

## Related Files

| File | Description |
|------|-------------|
| `nanovllm/kvcache/sparse/xattn_bsa.py` | XAttention BSA policy implementation |
| `nanovllm/ops/xattn.py` | Triton kernels |
| `tests/bench_estimate_block_size.py` | Benchmark script |
| `docs/xattn_performance_analysis.md` | Overall XAttention performance analysis |

## Hierarchical Block Sum

Compute fine-grained block_sums with a small `estimate_block_size=1024`, then aggregate them to the CPU block level (4096).

### Mathematical equivalence

```
Approach 1 (block_size=4096): softmax_fuse_block_sum → [1, heads, 1, 1]
Approach 2 (block_size=1024): softmax_fuse_block_sum → [1, heads, 4, 4] → sum → [1, heads]

Verification: Max difference = 0.0 ✅ exactly equivalent
```
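
The equivalence holds because a block sum is just a sum of softmax probabilities, and sums can be regrouped. The following pure-Python sketch shows this for a single query row with toy sizes (16 and 4 standing in for 4096 and 1024); it is an illustration under those assumptions, not the Triton kernel or the repository's test.

```python
import math
import random

def softmax(row):
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def block_sums(probs, block):
    """Sum softmax probabilities over consecutive blocks of `block` keys."""
    return [sum(probs[i:i + block]) for i in range(0, len(probs), block)]

random.seed(0)
COARSE, FINE = 16, 4          # toy stand-ins for 4096 and 1024
RATIO = COARSE // FINE
scores = [random.gauss(0.0, 1.0) for _ in range(4 * COARSE)]
probs = softmax(scores)

coarse = block_sums(probs, COARSE)   # direct coarse-block sums
fine = block_sums(probs, FINE)       # fine-grained sums
aggregated = [sum(fine[i:i + RATIO]) for i in range(0, len(fine), RATIO)]

max_diff = max(abs(a - b) for a, b in zip(coarse, aggregated))
print(f"max difference: {max_diff:.2e}")
```

In floating point the two groupings may differ in the last few bits, so a check like this is best done with a small tolerance.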

### Verification code

`tests/test_hierarchical_estimate.py` - implemented with plain torch plus the xattn kernels

### Speedup

| Metric | Current (4096) | Optimized (1024) | Gain |
|--------|----------------|------------------|------|
| softmax kernel | 12.07 ms | 0.29 ms | **41x** |
| End-to-end estimate | 95 ms | ~6 ms | **15x** |

## ⚠️ Selection Strategy Change

**Important**: the hierarchical block sum scheme uses a new selection strategy:

| Property | Old strategy (mask + voting) | New strategy (score + threshold) |
|----------|------------------------------|----------------------------------|
| Input | `[batch, heads, q_blocks, k_blocks]` | `[batch, heads, num_cpu_blocks]` |
| Selection granularity | Per-q-block | Per-chunk |
| Aggregation | majority voting | threshold on scores |

The new strategy is simpler: it uses the scores produced by the hierarchical sum directly and avoids the mask generation and voting logic.
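
One plausible reading of "score + threshold" is a cumulative cut-off: keep the highest-scoring blocks until they cover a given share of the total score. The sketch below illustrates that idea only. The helper `select_by_cumulative_threshold` and the sample scores are hypothetical, and the exact rule used in `xattn_bsa.py` is not shown in this document.

```python
def select_by_cumulative_threshold(scores, threshold):
    """Return indices of the highest-scoring blocks whose combined share
    of the total score first reaches `threshold` (hypothetical helper)."""
    total = sum(scores)
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    selected, running = [], 0.0
    for idx in order:
        selected.append(idx)
        running += scores[idx]
        if running >= threshold * total:
            break
    return sorted(selected)

# Made-up per-CPU-block scores, already averaged over batch and heads.
scores_per_block = [0.04, 0.40, 0.02, 0.30, 0.06, 0.18]
print(select_by_cumulative_threshold(scores_per_block, threshold=0.9))
```

A higher threshold keeps more blocks and lowers sparsity; a lower one keeps only the dominant blocks.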

## Implementation Status ✅ (2026-01-28)

### Implemented

The hierarchical block sum scheme is implemented in `xattn_bsa.py`:

```python
class XAttentionBSAPolicy:
    def __init__(self, ..., estimate_block_size: int = 1024):
        self.estimate_block_size = estimate_block_size  # new parameter

    def select_blocks(self, ...):
        # Step 2: Hierarchical softmax_fuse_block_sum
        reshaped_est_bs = estimate_bs // self.stride  # 1024/8 = 128
        block_sums_fine = softmax_fuse_block_sum(attn_scores, reshaped_est_bs, ...)

        # Step 3: Aggregate to CPU block level
        block_sums_coarse = block_sums_fine.view(..., num_cpu_blocks, ratio).sum(dim=-1)
        cpu_block_scores = block_sums_coarse.sum(dim=2)

        # Step 4: Score + threshold selection (replaces mask + voting)
        scores_per_block = cpu_block_scores.mean(dim=(0, 1))
        # ... cumulative threshold selection
```

### Measured results (Nsys profiling)

| Kernel | Before | After | Improvement |
|--------|--------|-------|-------------|
| softmax_fuse_block_sum share of time | 48.1% | **1.1%** | **44x** |
| softmax_fuse_block_sum average time | ~2ms | 489us | **4x** |

### End-to-end performance (32K context)

| Metric | FULL Policy | XATTN Policy | Improvement |
|--------|-------------|--------------|-------------|
| Prefill throughput | 3511 tok/s | 3695 tok/s | +5% |
| TTFT | 9327 ms | 8863 ms | -5% |

## Conclusion

Using the global `kvcache_block_size=4096` in the estimate stage puts the `softmax_fuse_block_sum` kernel at the slowest point of its performance curve. Switching the estimate block_size to 512-1024 gives roughly a **15x** kernel speedup and substantially reduces the cost of the estimate stage.

**⚠️ Important change**: the selection strategy changes from `mask + majority voting` to `score + threshold`, which is simpler and more direct.
77
docs/gpu_only_sparse_integration.md
Normal file
@@ -0,0 +1,77 @@
# GPU-only Sparse Policy Integration

This document records the process of integrating sparse attention policies into GPU-only mode, together with performance comparisons.

## Background

The sparse policies (Quest, XAttention) are currently implemented only on the CPU offload path. The goal is to extend them to GPU-only mode to improve performance in long-context scenarios.

## Baseline Performance (before optimization)

**Test environment**:
- GPU: NVIDIA A100-SXM4-80GB
- Model: Llama-3.1-8B-Instruct
- Context length: 32K tokens
- Date: 2026-01-27

### Prefill Benchmark (32K context)

| Mode | Throughput | Time | KV cache allocation |
|------|------------|------|---------------------|
| **GPU-only (Full Attention)** | 4869.67 tok/s | 6.73s | 438 blocks (56GB GPU) |
| CPU Offload (Full Attention) | 1500.29 tok/s | 21.84s | 4 blocks GPU + 32 blocks CPU |

**Speed ratio**: GPU-only is **3.2x** faster than CPU Offload

### Configuration details

**GPU-only mode**:
```bash
CUDA_VISIBLE_DEVICES=0 python bench.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --max-len 32768
```

**CPU Offload mode**:
```bash
CUDA_VISIBLE_DEVICES=0 python bench_offload.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --max-len 32768
```

### KV cache configuration

| Parameter | GPU-only | CPU Offload |
|-----------|----------|-------------|
| block_size | 1024 tokens | 1024 tokens |
| per-token KV | 128 KB | 128 KB |
| per-block KV | 128 MB | 128 MB |
| GPU blocks | 438 | 4 |
| CPU blocks | 0 | 32 |
| Total memory | 56 GB | 4.6 GB |
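
The per-token and per-block figures follow from the model shape. The sketch below assumes Llama-3.1-8B-Instruct with 32 layers, 8 KV heads, head_dim 128 and 2-byte elements; the helper name `kv_bytes_per_token` is illustrative.

```python
KIB, MIB = 1024, 1024 ** 2

def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_size):
    """Bytes of K plus V stored for one token across all layers."""
    return num_layers * 2 * num_kv_heads * head_dim * dtype_size

# Assumed model shape: 32 layers, 8 KV heads, head_dim 128, 2-byte elements.
per_token = kv_bytes_per_token(32, 8, 128, 2)
per_block = per_token * 1024          # block_size = 1024 tokens

print(f"per-token KV: {per_token // KIB} KB")
print(f"per-block KV: {per_block // MIB} MB")
for name, blocks in (("GPU-only", 438), ("CPU Offload", 4 + 32)):
    print(f"{name}: {blocks} blocks = {blocks * per_block // MIB} MB")
```

The totals in the table correspond to these MB values divided by 1000, so they are slightly larger than the same quantities expressed in GiB.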

## Goals

Integrate the following sparse policies into GPU-only mode:

| Policy | Stage | Description |
|--------|-------|-------------|
| Quest | Decode | Top-K block selection based on query-key scores |
| XAttention BSA | Prefill | Block sparse attention with cumulative threshold |

## Implementation Progress

- [ ] Analyze the existing sparse policy code structure
- [ ] Design the GPU-only sparse policy interface
- [ ] Implement GPU-only Quest decode
- [ ] Implement GPU-only XAttention prefill
- [ ] Performance testing and comparison

## Performance After Optimization

*To be measured*

| Mode | Throughput | Speedup vs Full |
|------|------------|-----------------|
| GPU-only + Quest (decode) | TBD | TBD |
| GPU-only + XAttn (prefill) | TBD | TBD |
296
docs/gpu_only_xattn_guide.md
Normal file
@@ -0,0 +1,296 @@
# GPU-Only XAttention Guide

This document covers the implementation, memory optimization, and performance analysis of XAttention BSA in GPU-only mode.

## Overview

In GPU-only mode the entire KV cache lives on the GPU, with no CPU offload. XAttention speeds up the prefill stage through sparse attention.

### Execution paths compared

| Mode | Prefill method | Decode method | KV storage |
|------|----------------|---------------|------------|
| GPU-only Full | `compute_prefill()` | `compute_decode()` | GPU |
| GPU-only XAttn | `compute_prefill()` | `compute_decode()` | GPU |
| CPU Offload | `compute_chunked_prefill()` | `compute_chunked_decode()` | CPU + GPU |

## Architecture

### SparsePolicy interface

```python
class SparsePolicy:
    # GPU-only methods
    def compute_prefill(self, q, k, v, ...) -> Tensor
    def compute_decode(self, q, k_cache, v_cache, ...) -> Tensor

    # CPU offload methods
    def compute_chunked_prefill(self, q, k, v, ...) -> Tensor
    def compute_chunked_decode(self, q, ...) -> Tensor

    # Initialization methods
    def initialize(self, num_layers, ...) -> None  # CPU offload metadata
    def alloc_policy_metadata(self, num_heads, ...) -> None  # GPU-only buffers
```

### XAttentionBSAPolicy implementation

```
GPU-only prefill flow:

1. GQA expansion (using the preallocated buffer)
   K: [seq, kv_heads, dim] → K_exp: [1, heads, seq, dim]

2. XAttention estimation
   flat_group_gemm_fuse_reshape_kernel (Q@K^T)
   softmax_fuse_block_sum_kernel (block importance)
   → sparse mask

3. BSA sparse attention
   flash_fwd_block_kernel (computes only the selected blocks)
   → output
```

## Memory Preallocation

### Background

XAttention's `compute_prefill()` needs GQA expansion:

```python
# Before: dynamic allocation (~2GB for 64K)
K_exp = K.repeat_interleave(num_groups, dim=1)  # allocation 1
k_bsa = k.repeat_interleave(num_groups, dim=1)  # allocation 2 (duplicate!)
```

Allocating dynamically on every prefill causes:
- memory fragmentation
- allocation latency
- possible OOM

### Solution: alloc_policy_metadata()

Preallocate the buffers when the framework initializes:

```python
class XAttentionBSAPolicy(SparsePolicy):
    def alloc_policy_metadata(self, num_heads, num_kv_heads, head_dim,
                              max_seq_len, dtype, device):
        # Preallocate the GQA expansion buffers
        shape = (1, num_heads, max_seq_len, head_dim)
        self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
        self._v_expanded = torch.empty(shape, dtype=dtype, device=device)

    def compute_prefill(self, q, k, v, ...):
        seq_len = k.shape[0]
        # Use a slice of the preallocated buffer
        K_exp = self._k_expanded[:, :, :seq_len, :]
        # In-place GQA expansion
        K_exp.view(...).copy_(K.unsqueeze(2).expand(...))
        # Reuse the same buffer for BSA
        k_bsa = K_exp.squeeze(0).transpose(0, 1)
```
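
GQA expansion repeats each KV head so that every query head has a matching K/V slice. The pure-Python sketch below shows only the head layout that `repeat_interleave(num_groups)` along the head dimension produces; it uses string labels instead of tensors, and the 32 query heads and 8 KV heads are assumed example values.

```python
def expand_kv_heads(kv_heads, num_groups):
    """Repeat each KV head `num_groups` times consecutively, matching the
    layout of repeat_interleave(num_groups) along the head dimension."""
    return [head for head in kv_heads for _ in range(num_groups)]

num_heads, num_kv_heads = 32, 8          # assumed GQA configuration
num_groups = num_heads // num_kv_heads   # 4 query heads share one KV head

kv_labels = [f"kv{i}" for i in range(num_kv_heads)]
expanded = expand_kv_heads(kv_labels, num_groups)

# Query head h reads the KV head at index h // num_groups.
for h in (0, 3, 4, 31):
    print(f"query head {h:>2} -> {expanded[h]}")
```

Writing this layout into a preallocated buffer gives the same result as `repeat_interleave` without creating a new tensor on every prefill.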

### Memory usage

| Sequence length | Preallocated size | Notes |
|-----------------|-------------------|-------|
| 32K | 512 MB | `2 * 32 * 32768 * 128 * 2 bytes` |
| 64K | 1024 MB | `2 * 32 * 65536 * 128 * 2 bytes` |

Effect of the optimization at 64K:
- Before: ~2GB allocated dynamically (once in xattn_estimate, once in BSA)
- After: ~1GB preallocated (the same buffer is reused)

### Framework integration

```python
# model_runner.py - allocate_kv_cache()
def allocate_kv_cache(self):
    # ... KV cache allocation ...

    # GPU-only mode: preallocate policy buffers
    if not config.enable_cpu_offload:
        self.kvcache_manager.sparse_policy.alloc_policy_metadata(
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            max_seq_len=config.max_model_len,
            dtype=dtype,
            device=torch.device("cuda"),
        )
```

## Performance Analysis

### 32K prefill performance

| Policy | Throughput | Relative gain |
|--------|------------|---------------|
| Baseline | 4880 tok/s | - |
| Full | 4892 tok/s | +0.2% |
| **XAttention** | **5602 tok/s** | **+15%** |

### 64K prefill performance

| Policy | Throughput | Relative gain |
|--------|------------|---------------|
| Baseline | 3386 tok/s | - |
| Full | 3355 tok/s | -0.9% |
| **XAttention** | **4775 tok/s** | **+41%** |

### Kernel time breakdown (32K)

**XAttention:**
```
FFN GEMM:          3219 ms  (55%)
BSA Attention:     1231 ms  (21%)
XAttn Estimation:   415 ms  (7%)
Other:             1020 ms  (17%)
─────────────────────────────
Total:             5885 ms
```

**Full:**
```
FFN GEMM:          3244 ms  (48%)
Dense Attention:   2861 ms  (43%)
Other:              595 ms  (9%)
─────────────────────────────
Total:             6700 ms
```

### Where the speedup comes from

```
Dense Attention:   2861 ms
BSA Attention:     1231 ms  (saves 1630 ms, -57%)
XAttn Estimation:   415 ms  (extra overhead)
─────────────────────────────
Net saving:        1215 ms  (42% of attention time)
```
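
The shares and the net saving can be recomputed from the raw kernel times. This is a small sketch over the 32K numbers reported above; the dictionary layout and the helper `shares` are illustrative.

```python
# Kernel times in ms from the 32K breakdown.
xattn = {"FFN GEMM": 3219, "BSA Attention": 1231,
         "XAttn Estimation": 415, "Other": 1020}
full = {"FFN GEMM": 3244, "Dense Attention": 2861, "Other": 595}

def shares(breakdown):
    """Percentage of total time per kernel group, rounded to one decimal."""
    total = sum(breakdown.values())
    return {name: round(100 * ms / total, 1) for name, ms in breakdown.items()}

attention_saved = full["Dense Attention"] - xattn["BSA Attention"]
net_saved = attention_saved - xattn["XAttn Estimation"]

print("XAttention shares:", shares(xattn))
print("Full shares:      ", shares(full))
print(f"attention saved: {attention_saved} ms, net: {net_saved} ms "
      f"({100 * net_saved / full['Dense Attention']:.0f}% of dense attention)")
print(f"end-to-end: {sum(full.values())} ms -> {sum(xattn.values())} ms")
```

The FFN GEMM time is nearly identical in both runs, so the end-to-end gain comes almost entirely from the attention kernels, partly offset by the estimation step and a larger "Other" bucket.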

## CUDA Graph Limitations

### Why prefill cannot use CUDA Graph

CUDA Graph requires every operation to be fully determined at capture time:

| Must be fixed | Situation in prefill |
|---------------|----------------------|
| Tensor shapes | seq_len varies (1 ~ max_model_len) |
| Kernel grid | depends on seq_len |
| Memory addresses | intermediate tensor sizes change |

```python
# seq_len differs between requests
request_1: prefill(seq_len=1024)   # grid=(8, 32, 1)
request_2: prefill(seq_len=32768)  # grid=(256, 32, 1)
```
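
The two grids quoted above are consistent with a kernel that tiles the query sequence in blocks of 128 tokens and launches one program per head. That tile size is an assumption inferred from the numbers, and `prefill_grid` is a hypothetical helper, not a function from the codebase.

```python
import math

def prefill_grid(seq_len, num_heads=32, block_m=128):
    """Launch grid for a kernel that tiles queries in blocks of `block_m`
    (block_m=128 is an assumption inferred from the quoted grids)."""
    return (math.ceil(seq_len / block_m), num_heads, 1)

for seq_len in (1024, 32768):
    print(f"seq_len={seq_len:>5}: grid={prefill_grid(seq_len)}")
```

Since the first grid dimension changes with every request length, a graph captured for one length cannot be replayed for another.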

### Decode can use CUDA Graph

```python
# Decode processes only 1 token per step
q: [batch_size, 1, heads, dim]  # fixed shape
```

nanovllm captures one graph per batch_size ahead of time:

```python
def capture_cudagraph(self):
    for batch_size in [1, 2, 4, 8, ...]:
        with torch.cuda.graph(g):
            self.run_model(dummy_input, is_prefill=False)
        self.graphs[batch_size] = g
```

### Nsys profile results

```
XAttention 32K Prefill:
  Total kernels: 41,904
  Non-graph: 41,904 (100%)
  Graph: 0

Full 32K Prefill:
  Total kernels: 35,308
  Non-graph: 35,308 (100%)
  Graph: 0
```

**Both are 100% NON-GRAPH**, which is inherent to prefill.

## Profiling Tools

### Using profile.sh

```bash
# XAttention 32K
bash scripts/profile.sh --max-len 32768 --policy xattn

# Full 32K
bash scripts/profile.sh --max-len 32768 --policy full

# 64K (requires a lower gpu-util)
bash scripts/profile.sh --max-len 65536 --policy xattn --gpu-util 0.7
```

### Analyzing nsys results

```bash
# Show kernel statistics
nsys stats --report cuda_gpu_kern_sum results/nsys/<file>.nsys-rep

# Query detailed data with sqlite
sqlite3 results/nsys/<file>.sqlite "
SELECT
  (SELECT value FROM StringIds WHERE id = shortName) as kernel,
  COUNT(*) as count,
  SUM(end-start)/1e6 as total_ms
FROM CUPTI_ACTIVITY_KIND_KERNEL
GROUP BY shortName
ORDER BY total_ms DESC
LIMIT 10
"
```
|
||||||
|
|
||||||
|
## Usage Guide

### Enabling XAttention in GPU-only mode

```python
from nanovllm import LLM
from nanovllm.config import SparsePolicyType

llm = LLM(
    model_path,
    max_model_len=32768,
    sparse_policy=SparsePolicyType.XATTN_BSA,
    gpu_memory_utilization=0.9,  # may need to be lowered for 64K
)
```

### Command-line testing

```bash
# bench.py
python bench.py --max-len 32768 --policy xattn

# 64K requires a lower gpu-util
python bench.py --max-len 65536 --policy xattn --gpu-util 0.7
```

### Best practices

1. **32K and below**: use the default `gpu_memory_utilization=0.9`
2. **64K**: lower it to `gpu_memory_utilization=0.7`
3. **Decode**: XAttention automatically falls back to FullAttentionPolicy
4. **Paged KV Cache**: when `block_tables` is present, it automatically falls back to flash_attn

## Related Documents

- [Sparse Policy Architecture](sparse_policy_architecture.md)
- [XAttention Algorithm Guide](xattention_algorithm_guide.md)
- [BSA Interface Documentation](block_sparse_attn_interface.md)
246
docs/gpuonly_density_alignment_test.md
Normal file
@@ -0,0 +1,246 @@
# Density Alignment Test Results

Verifies the correctness of the three-stage KV chunking pipeline in GPU-only and Offload modes.

## Test Configuration

### GPU-only mode

- **Model**: Qwen3-0.6B (28 layers, 16 heads, 8 KV heads, head_dim=128)
- **Threshold**: 0.9
- **Block Size**: 128 tokens (BSA block)
- **Stride**: 8
- **Chunk Size**: 16384 tokens

### Offload mode

- **Model**: Llama-3.1-8B-Instruct (32 layers, 32 heads, 8 KV heads, head_dim=128)
- **Threshold**: 0.9
- **Block Size**: 128 tokens (BSA block)
- **Stride**: 4
- **Chunk Size**: 4096 tokens

## Three-Stage KV Chunking Alignment Test (2026-02-02)

### Purpose

Verify that the high-level `xattn_estimate` API produces exactly the same results as a manually implemented three-stage KV chunking pipeline.

### Three-stage pipeline

```
┌──────────────────────────────────────────────────────────────────────┐
│ Stage 1: softmax_compute_partial_stats                               │
│   └── each KV chunk independently computes partial stats (m_i, l_i)  │
│                                                                      │
│ Stage 2: merge_softmax_stats                                         │
│   └── merges all chunks on the host: (m_global, l_global)            │
│                                                                      │
│ Stage 3: softmax_normalize_and_block_sum                             │
│   └── normalizes with the global stats and computes block sums       │
└──────────────────────────────────────────────────────────────────────┘
```
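
The three stages above can be sketched in plain NumPy to show why merging per-chunk statistics reproduces a single-pass softmax. This is an illustrative sketch only: the function names `partial_stats`, `merge_stats`, and `normalize_and_block_sum` are hypothetical stand-ins, not the actual kernels in `nanovllm/ops/xattn.py`, and the shapes are toy-sized.

```python
import numpy as np

def partial_stats(scores):
    """Stage 1: per-chunk row maximum m_i and sum of exponentials l_i."""
    m = scores.max(axis=-1)
    l = np.exp(scores - m[:, None]).sum(axis=-1)
    return m, l

def merge_stats(stats):
    """Stage 2: combine the (m_i, l_i) pairs into global (m, l)."""
    ms = np.stack([m for m, _ in stats])   # [num_chunks, rows]
    ls = np.stack([l for _, l in stats])
    m_global = ms.max(axis=0)
    # Rescale each partial sum to the global maximum before adding.
    l_global = (ls * np.exp(ms - m_global)).sum(axis=0)
    return m_global, l_global

def normalize_and_block_sum(scores, m_global, l_global, block):
    """Stage 3: normalize with the global stats, then sum within KV blocks."""
    p = np.exp(scores - m_global[:, None]) / l_global[:, None]
    rows, cols = p.shape
    return p.reshape(rows, cols // block, block).sum(axis=-1)

rng = np.random.default_rng(0)
scores = rng.standard_normal((4, 32))      # toy [q_rows, kv_len] score matrix
chunks = np.split(scores, 4, axis=1)       # 4 KV chunks of 8 columns each

m_g, l_g = merge_stats([partial_stats(c) for c in chunks])
chunked = np.concatenate(
    [normalize_and_block_sum(c, m_g, l_g, block=4) for c in chunks], axis=1
)

# Single-pass reference: full softmax, then the same block sums.
p = np.exp(scores - scores.max(axis=1, keepdims=True))
p /= p.sum(axis=1, keepdims=True)
reference = p.reshape(4, 8, 4).sum(axis=-1)

max_diff = np.abs(chunked - reference).max()
print("max diff:", max_diff)
```

In this float64 sketch the two paths agree to within rounding; it does not attempt to reproduce the exact-zero difference reported for the real kernels.
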
### Test results

#### CHUNK_SIZE = 16384 (default)

| Context | Tokens | Q Chunks | KV Chunks | Density | Mask diff | attn_sums diff | Result |
|---------|--------|----------|-----------|---------|-----------|----------------|--------|
| 4K | 3,692 | 1 | 1 | 63.84% | 0 | 0.0 | ✅ |
| 8K | 7,892 | 1 | 1 | 64.98% | 0 | 0.0 | ✅ |
| 16K | 15,689 | 1 | 1 | 61.63% | 0 | 0.0 | ✅ |
| 32K | 32,485 | 2 | 2 | 50.21% | 0 | 0.0 | ✅ |
| **64K** | **64,891** | **4** | **4** | **37.00%** | **0** | **0.0** | ✅ |

#### CHUNK_SIZE = 4096 (more chunks)

| Context | Tokens | Q Chunks | KV Chunks | Density | xattn_estimate vs KV chunking | Result |
|---------|--------|----------|-----------|---------|-------------------------------|--------|
| 4K | 3,692 | 1 | 1 | 63.84% | 0.000000 | ✅ |
| 8K | 7,892 | 2 | 2 | 63.02% | 0.000000 | ✅ |
| 16K | 15,689 | 4 | 4 | 60.08% | 0.000000 | ✅ |
| 32K | 32,485 | 8 | 8 | 49.84% | 0.000000 | ✅ |
| **64K** | **64,891** | **16** | **16** | **36.91%** | **0.000000** | ✅ |

### 64K detailed verification (CHUNK_SIZE=4096)

With chunk_size=4096, a 64K sequence produces a 16×16 chunk matrix:

```
seq_len: 64891, q_chunk_num: 16, kv_chunk_num: 16

Q chunk 0: merged 16 KV chunks → attn_sum shape=[1, 32, 32, 512]
Q chunk 1: merged 16 KV chunks → attn_sum shape=[1, 32, 32, 512]
...
Q chunk 15: merged 16 KV chunks → attn_sum shape=[1, 32, 32, 512]
```

Each Q chunk has to merge the softmax stats of 16 KV chunks, which exercises `merge_softmax_stats` in a large-scale chunk-merge scenario.

### Verification metrics

| Metric | Expected | Actual (all lengths) |
|--------|----------|----------------------|
| attn_sums max diff | 0 | 0.000000e+00 |
| attn_sums mean diff | 0 | 0.000000e+00 |
| mask exact match | True | True |
| density diff | 0% | 0.000000% |

### Conclusion

✅ **Three-stage KV chunking is exactly equivalent to single-pass processing, with no loss of precision.**

- When seq_len < CHUNK_SIZE (16384): processed as a single chunk
- When seq_len >= CHUNK_SIZE: processed in multiple chunks and then merged; the result matches single-pass processing exactly

---

## Offload Mode Test (2026-02-02)

Tested with real KV cache data saved in Offload mode.

### Test results

| File | Tokens | Layer | Saved Density | Computed Density | Q/KV Chunks | Result |
|------|--------|-------|---------------|------------------|-------------|--------|
| `qkv_3688.pt` | 3.7K | 3 | 38.34% | 38.34% | 1/1 | ✅ PASSED |
| `qkv_7888.pt` | 7.9K | 3 | 29.06% | 27.56% | 2/2 | ✅ PASSED |
| `qkv_15685.pt` | 15.7K | 3 | 19.77% | 18.60% | 4/4 | ✅ PASSED |
| `qkv_32485.pt` | 32.5K | 5 | 15.71% | 15.62% | 8/8 | ✅ PASSED |
| `qkv_64891.pt` | 64.9K | 3 | 11.09% | 11.09% | 16/16 | ✅ PASSED |

### Layer 5 GPU-only test (threshold=0.9)

| Metric | Result |
|--------|--------|
| Q/K shape | `[1, 16, 21001, 128]` (21K tokens) |
| Density | 6.24% |
| xattn_estimate vs KV chunking | identical (0.0000%) |
| mask diff | 0 / 435600 blocks |
| attn_sums diff | max=0.0, mean=0.0 |

### Observations

1. **Density decreases as context grows**: 3.7K (38%) → 64.9K (11%)
2. **The xattn_estimate API matches three-stage KV chunking exactly**: the difference is 0.0000% at every length
3. **Saved density and computed density differ slightly**: the saved density may have been recorded under a different chunking, so the accumulation differs slightly

---

## Appendix: xattn_bsa vs xattn_estimate alignment

| Context | Tokens | Layer 0 Density | Compute Density | Min Layer | Result |
|---------|--------|-----------------|-----------------|-----------|--------|
| 4k | 3,692 | 63.8% | 52.9% | Layer 3 (31.3%) | ✅ PASSED |
| 8k | 7,892 | 65.0% | 52.5% | Layer 5 (27.3%) | ✅ PASSED |
| 16k | 15,689 | 61.6% | 47.8% | Layer 5 (23.5%) | ✅ PASSED |
| 32k | 32,485 | 50.2% | 40.1% | Layer 5 (18.5%) | ✅ PASSED |
| 64k | 64,891 | 37.0% | 29.6% | Layer 5 (12.4%) | ✅ PASSED |

## Density Formula

### Total (denominator)

```python
# Causal mask: Q block i can only see K blocks 0 through i
causal_mask[i, j] = (j <= i + q_offset_blocks)

# Total = number of blocks in the causal region × batch × heads
total = causal_mask.sum() * batch * heads
#     = (n * (n + 1) / 2) * 1 * 32   # n = valid_q_blocks, with q_offset_blocks = 0
```

### Selected (numerator)

```python
# Number of selected (mask=True) blocks inside the causal region
selected = (mask & causal_mask).sum()
```

### Density

```python
density = selected / total
```
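
As a worked example of the formula above, the following NumPy sketch computes density for a toy mask. The function name `block_density` and the `[batch, heads, q_blocks, k_blocks]` mask layout are assumptions made for illustration; the real implementation may organize the mask differently.

```python
import numpy as np

def block_density(mask, q_offset_blocks=0):
    """Density = selected causal blocks / all causal blocks.

    mask: bool array shaped [batch, heads, q_blocks, k_blocks] (assumed layout).
    """
    batch, heads, q_blocks, k_blocks = mask.shape
    i = np.arange(q_blocks)[:, None]
    j = np.arange(k_blocks)[None, :]
    causal = j <= i + q_offset_blocks          # Q block i sees K blocks 0..i
    total = causal.sum() * batch * heads       # denominator
    selected = (mask & causal).sum()           # numerator
    return selected / total

n = 4
full_mask = np.ones((1, 32, n, n), dtype=bool)
diag_mask = np.broadcast_to(np.eye(n, dtype=bool), (1, 32, n, n))

dense = block_density(full_mask)   # every causal block selected
diag = block_density(diag_mask)    # only the diagonal blocks selected
print(dense, diag)
```

With `n = 4` the causal region holds `n * (n + 1) / 2 = 10` blocks per head, so selecting only the 4 diagonal blocks gives a density of 4/10.
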
## Observations

1. **Density decreases as context grows**: Layer 0 density goes from 63.8% at 4k to 37.0% at 64k, because attention is more spread out in long sequences

2. **Layer 5 is usually the sparsest layer**: it has the lowest density at 8k through 64k; at 4k the minimum is Layer 3

3. **Layer 0 has the highest density**: the first layer's attention pattern is the densest, possibly related to the sink-token effect

4. **Threshold=0.9 corresponds to ~50% density**: at 32k context, threshold=0.9 means selecting the blocks that cover 90% of the attention mass; the resulting Layer 0 density is about 50%

## Usage

### Step 1: Enable debug saving

```python
# nanovllm/kvcache/sparse/xattn_bsa.py
_DEBUG_SAVE_MASK = True  # set to True
```

### Step 2: Run GPU-only inference

```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
    python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 1 \
    --max-model-len 40960 \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9
```

### Step 3: Run the KV chunking alignment check

```bash
# Use data saved in GPU-only mode
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
    python tests/test_xattn_estimate_alignment.py --gpuonly

# Use data saved in Offload mode (default)
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
    python tests/test_xattn_estimate_alignment.py

# Specify a custom data file
python tests/test_xattn_estimate_alignment.py --data-file /path/to/data.pt

# Batch-test all Offload data
for f in results/kvcache/qkv_*.pt; do
    echo "Testing: $(basename $f)"
    python tests/test_xattn_estimate_alignment.py --data-file "$f"
done
```

### Batch-testing all lengths

```bash
for ctx in 4k 8k 16k 32k 64k; do
    case $ctx in
        4k) max_len=5000 ;;
        8k) max_len=9000 ;;
        16k) max_len=17000 ;;
        32k) max_len=34000 ;;
        64k) max_len=65664 ;;
    esac

    echo "Testing $ctx..."
    python tests/test_ruler.py \
        --data-dir tests/data/ruler_$ctx \
        --max-model-len $max_len \
        --sparse-policy XATTN_BSA \
        --num-samples 1 --quiet

    python tests/test_xattn_estimate_alignment.py --gpuonly
done
```

## Related Files

- `nanovllm/kvcache/sparse/xattn_bsa.py`: XAttention BSA Policy implementation
- `nanovllm/ops/xattn.py`: the `xattn_estimate` function and the three-stage KV chunking kernels
- `tests/test_xattn_estimate_alignment.py`: KV chunking alignment verification script
209
docs/issue_xattn_offload_gqa_buffer_oom.md
Normal file
@@ -0,0 +1,209 @@
# Issue: XAttention Offload Mode GQA Buffer OOM

## Problem Description

A CUDA OOM error occurs when running large models such as GLM-4-9B with XAttention BSA (Block Sparse Attention) + CPU Offload mode.

### Error message

```
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB.
GPU 0 has a total capacity of 23.57 GiB of which 4.19 GiB is free.
```

### Reproduction environment

| Item | Value |
|------|-------|
| Model | GLM-4-9B-Chat-1M |
| GPU | RTX 3090 (24GB) |
| Context Length | 32K |
| sparse_policy | XATTN_BSA |
| enable_cpu_offload | true |
| max_model_len | 1048576 (1M) |

### Error location

```
File "nanovllm/kvcache/sparse/xattn_bsa.py", line 246, in alloc_policy_metadata
    self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
```

---

## Problem Analysis

### Memory allocation analysis

`alloc_policy_metadata()` allocates the following buffers when the KV cache is initialized:

| Buffer | Purpose | Size (GLM-4, 1M seq) |
|--------|---------|----------------------|
| `_prefill_mask_buffer` | BSA mask | ~32 MB |
| `_m_partial_buffer` | KV chunking m stats | ~32 MB |
| `_l_partial_buffer` | KV chunking l stats | ~32 MB |
| `_block_sums_buffer` | Block sums | ~64 MB |
| **`_k_expanded`** | GQA K expansion | **~8 GB** |
| **`_v_expanded`** | GQA V expansion | **~8 GB** |

### GQA buffer calculation

```python
shape = (1, num_heads, max_seq_len, head_dim)
#     = (1, 32, 1048576, 128)

# size = 1 × 32 × 1048576 × 128 × 2 bytes (fp16)
#      = 8,589,934,592 bytes
#      = 8 GiB per buffer
```
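
The arithmetic above can be reproduced with a few lines of Python. This is only a sizing sketch; the helper name `gqa_buffer_bytes` is hypothetical, and the 16K and 72K sequence lengths are taken from other parts of this issue rather than from the allocation code itself.

```python
GIB = 1024 ** 3
MIB = 1024 ** 2

def gqa_buffer_bytes(num_heads, seq_len, head_dim, bytes_per_elem=2, batch=1):
    """Size in bytes of one [batch, num_heads, seq_len, head_dim] buffer (fp16 by default)."""
    return batch * num_heads * seq_len * head_dim * bytes_per_elem

per_buffer_1m = gqa_buffer_bytes(32, 1_048_576, 128)   # max_seq_len = 1M
both_1m = 2 * per_buffer_1m                             # _k_expanded + _v_expanded
both_72k = 2 * gqa_buffer_bytes(32, 72_000, 128)        # max_model_len = 72000
per_buffer_16k = gqa_buffer_bytes(32, 16_384, 128)      # one 16K-token chunk

print(f"1M, one buffer : {per_buffer_1m / GIB:.2f} GiB")
print(f"1M, K + V      : {both_1m / GIB:.2f} GiB")
print(f"72K, K + V     : {both_72k / GIB:.2f} GiB")
print(f"16K, one buffer: {per_buffer_16k / MIB:.0f} MiB")
```

The single 8 GiB request matches the `Tried to allocate 8.00 GiB` line in the error message, which is why the first of the two expansion buffers already fails on a 24 GB card.
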
### Root cause

1. **Conflict with the design intent**: the doc comments for `_k_expanded` and `_v_expanded` explicitly say they are "for GPU-only mode"
2. **Incomplete condition check**: the code only checks `num_heads == num_kv_heads` to skip allocation and never checks for offload mode
3. **Offload mode does not need these buffers**: `compute_chunked_prefill()` uses a different compute path and does not depend on the pre-allocated GQA buffers

### Relevant code

```python
# xattn_bsa.py:238-247
# Only allocate GQA expansion buffers if GQA (num_heads != num_kv_heads)
if num_heads == num_kv_heads:
    logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
    return  # <-- only checks GQA, not offload mode

# Shape: [1, num_heads, max_seq_len, head_dim] for xattn_estimate format
shape = (1, num_heads, max_seq_len, head_dim)
self._k_expanded = torch.empty(shape, dtype=dtype, device=device)  # <-- OOM here
self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
```

---

## Proposed Solutions

### Option 1: Skip GQA buffer allocation in offload mode (recommended)

Add an offload-mode check to `alloc_policy_metadata()`:

```python
def alloc_policy_metadata(
    self,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    max_seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    enable_cpu_offload: bool = False,  # <-- new parameter
) -> None:
    # ... allocate the mask buffer and KV chunking buffers (needed in offload mode)

    # Skip GQA buffers in offload mode
    # Chunked prefill uses compute_chunked_prefill() which doesn't need these
    if enable_cpu_offload:
        logger.info("[XAttn] Offload mode: skipping GQA expansion buffers")
        return

    # GPU-only mode: pre-allocate GQA buffers for compute_prefill()
    if num_heads == num_kv_heads:
        logger.info("[XAttn] No GQA expansion needed")
        return

    shape = (1, num_heads, max_seq_len, head_dim)
    self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
    self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
```

**Files to modify**:
1. `nanovllm/kvcache/sparse/xattn_bsa.py`: the `alloc_policy_metadata()` method
2. `nanovllm/engine/model_runner.py`: pass `enable_cpu_offload` when calling `alloc_policy_metadata()`

### Option 2: Lazy allocation

Allocate the GQA buffers only on the first call to `compute_prefill()`; offload mode goes through `compute_chunked_prefill()` and never triggers the allocation.

```python
def compute_prefill(self, ...):
    # Lazy allocation on first use
    if self._k_expanded is None and num_heads != num_kv_heads:
        self._allocate_gqa_buffers(...)
    ...
```
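
To make the lazy-allocation idea concrete, here is a small self-contained sketch. It uses NumPy arrays in place of GPU tensors, and the class `LazyGqaBuffers` and its method `ensure_allocated` are hypothetical names invented for this example; they are not part of nanovllm.

```python
import numpy as np

class LazyGqaBuffers:
    """Hypothetical holder that allocates K/V expansion buffers on first use."""

    def __init__(self, num_heads, num_kv_heads, head_dim, max_seq_len, dtype=np.float16):
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.dtype = dtype
        self.k_expanded = None     # nothing is allocated at construction time
        self.v_expanded = None
        self.alloc_count = 0

    def ensure_allocated(self):
        """Allocate once, and only when GQA expansion is actually needed."""
        if self.num_heads == self.num_kv_heads:
            return False           # no GQA, no expansion buffers
        if self.k_expanded is None:
            shape = (1, self.num_heads, self.max_seq_len, self.head_dim)
            self.k_expanded = np.empty(shape, dtype=self.dtype)
            self.v_expanded = np.empty(shape, dtype=self.dtype)
            self.alloc_count += 1
        return True

gqa = LazyGqaBuffers(num_heads=32, num_kv_heads=8, head_dim=128, max_seq_len=256)
allocated_before_use = gqa.k_expanded is not None
gqa.ensure_allocated()   # first prefill call allocates
gqa.ensure_allocated()   # later calls reuse the same buffers

mha = LazyGqaBuffers(num_heads=8, num_kv_heads=8, head_dim=128, max_seq_len=256)
mha_needed = mha.ensure_allocated()
```

A code path that never calls `ensure_allocated()`, which is the role the offload path plays in this option, never pays for the buffers.
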
### Option 3: Cap the buffer size at chunk_size

Instead of pre-allocating for max_seq_len, allocate only chunk_size tokens:

```python
# Before: max_seq_len (1M tokens) -> 8 GB
# After: chunk_size (16K tokens) -> ~130 MB
buffer_len = self.chunk_size if enable_cpu_offload else max_seq_len
shape = (1, num_heads, buffer_len, head_dim)
```

---

## Verification

After the fix, run the following commands to verify:

```bash
cd /home/zijie/Code/COMPASS
GPULIST=0 ./scripts/run_ruler.sh glm4-9b-xattn-nanovllm synthetic xattn --task niah_single_1
```

Expected results:
- The OOM error from the 8 GB allocation no longer occurs
- The model loads and completes inference normally

---

## Related Documents

- `docs/xattn_bsa_policy_design.md`: XAttention BSA Policy design document
- `docs/gpu_only_xattn_guide.md`: GPU-Only XAttention guide

## Priority

**High**: blocks 9B+ models from using XAttention + Offload mode on GPUs with 24 GB of memory

---

## Fix Status

**✅ Fixed** (2026-02-05)

### What changed

Option 1 was adopted: GQA buffer allocation is skipped in offload mode.

1. `nanovllm/kvcache/sparse/policy.py`: added the `enable_cpu_offload` parameter to the base class
2. `nanovllm/kvcache/sparse/xattn_bsa.py`: implemented the offload-mode check that skips the GQA buffers
3. `nanovllm/engine/model_runner.py`: passes the `enable_cpu_offload` argument

### Verification results

```bash
# 64K offload test
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
    python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_64k \
    --datasets niah_single_1 \
    --num-samples 1 \
    --max-model-len 72000 \
    --enable-offload \
    --sparse-policy XATTN_BSA
```

- ✅ Log shows: `[XAttn] Offload mode: skipping GQA expansion buffers`
- ✅ Test passed: 100% accuracy
- ✅ Memory saved: ~16 GB (for 1M max_seq_len)

### Memory comparison

| Configuration | Before fix | After fix |
|---------------|------------|-----------|
| max_model_len=72K | +1.1 GB | 0 GB |
| max_model_len=1M | +16 GB | 0 GB |
184
docs/long_context_models_1m.md
Normal file
@@ -0,0 +1,184 @@
# Models with 1M+ Context Length

This document lists open-source models on Hugging Face that support a context length of 1M (1,048,576) tokens or more.

> Last updated: 2026-01-28

---

## I. Text-only language models (≤10B parameters)

### 1. Official models

| Vendor | Model | Context | Size | Downloads | Link |
|--------|-------|---------|------|-----------|------|
| **Qwen** | Qwen2.5-7B-Instruct-1M | 1M | 7B | 69.3K | [HF](https://hf.co/Qwen/Qwen2.5-7B-Instruct-1M) |
| **THUDM** | GLM-4-9B-Chat-1M | 1M | 9B | 5.0K | [HF](https://hf.co/zai-org/glm-4-9b-chat-1m) |
| **InternLM** | InternLM2.5-7B-Chat-1M | 1M | 7B | 322 | [HF](https://hf.co/internlm/internlm2_5-7b-chat-1m) |
| **NVIDIA** | Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct | 1M | 8B | 2.9K | [HF](https://hf.co/nvidia/Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct) |
| **LWM** | LWM-Text-1M | 1M | 7B | 75 | [HF](https://hf.co/LargeWorldModel/LWM-Text-1M) |
| **LWM** | LWM-Text-Chat-1M | 1M | 7B | 3.0K | [HF](https://hf.co/LargeWorldModel/LWM-Text-Chat-1M) |

### 2. Gradient AI extended series (based on Llama 3)

| Model | Context | Size | Downloads | Link |
|-------|---------|------|-----------|------|
| Llama-3-8B-Instruct-Gradient-1048k | **1M** | 8B | 44.8K | [HF](https://hf.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k) |
| Llama-3-8B-Instruct-Gradient-4194k | **4M** | 8B | 9 | [HF](https://hf.co/gradientai/Llama-3-8B-Instruct-Gradient-4194k) |

### 3. Community derivatives (Abliterated)

| Model | Context | Base model | Downloads | Link |
|-------|---------|------------|-----------|------|
| Qwen2.5-7B-Instruct-1M-abliterated | 1M | Qwen2.5-7B | 375 | [HF](https://hf.co/huihui-ai/Qwen2.5-7B-Instruct-1M-abliterated) |
| Nemotron-8B-UltraLong-1M-Abliterated | 1M | Nemotron-8B | 46 | [HF](https://hf.co/SicariusSicariiStuff/Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct_Abliterated) |

---

## II. Vision-language models (≤10B parameters)

### Qwen3 VL series

#### Instruct versions

| Model | Context | Size | Downloads | Link |
|-------|---------|------|-----------|------|
| Qwen3-VL-2B-Instruct-1M-GGUF | 1M | 2B | 824 | [HF](https://hf.co/unsloth/Qwen3-VL-2B-Instruct-1M-GGUF) |
| Qwen3-VL-4B-Instruct-1M-GGUF | 1M | 4B | 936 | [HF](https://hf.co/unsloth/Qwen3-VL-4B-Instruct-1M-GGUF) |
| Qwen3-VL-8B-Instruct-1M-GGUF | 1M | 8B | 962 | [HF](https://hf.co/unsloth/Qwen3-VL-8B-Instruct-1M-GGUF) |

#### Thinking (reasoning) versions

| Model | Context | Size | Downloads | Link |
|-------|---------|------|-----------|------|
| Qwen3-VL-2B-Thinking-1M-GGUF | 1M | 2B | 808 | [HF](https://hf.co/unsloth/Qwen3-VL-2B-Thinking-1M-GGUF) |
| Qwen3-VL-4B-Thinking-1M-GGUF | 1M | 4B | 666 | [HF](https://hf.co/unsloth/Qwen3-VL-4B-Thinking-1M-GGUF) |
| Qwen3-VL-8B-Thinking-1M-GGUF | 1M | 8B | 4.6K | [HF](https://hf.co/unsloth/Qwen3-VL-8B-Thinking-1M-GGUF) |

---

## III. Recommended models (≤10B)

| Use case | Recommended model | Rationale |
|----------|-------------------|-----------|
| **General chat** | Qwen2.5-7B-Instruct-1M | Officially supported, RULER score 93.1, Apache 2.0 |
| **Chinese-English bilingual** | GLM-4-9B-Chat-1M | From Tsinghua, optimized for Chinese |
| **Longest context** | Llama-3-8B-Gradient-4194k | Supports 4M context |
| **Multimodal** | Qwen3-VL-8B-Thinking-1M | Visual understanding plus reasoning |
| **Uncensored** | Qwen2.5-7B-Instruct-1M-abliterated | Safety restrictions removed |

---

## IV. VRAM requirements (reference)

| Model size | VRAM at 1M context | Notes |
|------------|--------------------|-------|
| 7B (FP16) | ~120GB | Requires multiple GPUs |
| 7B (INT4) | ~40GB | Feasible on a single A100 |
| 8B (FP16) | ~130GB | Requires multiple GPUs |
| 9B (FP16) | ~140GB | Requires multiple GPUs |

---

## V. Technical comparison

| Model family | Extension technique | RULER score | License |
|--------------|---------------------|-------------|---------|
| Qwen2.5-1M | Dual Chunk Attention | 93.1 | Apache 2.0 |
| GLM-4-1M | - | 89.9 | Custom |
| Gradient-Llama | Progressive extension | - | Llama 3 |
| Nemotron-1M | NVIDIA training | - | CC-BY-NC-4.0 |
| LWM-1M | RingAttention | - | Open source |

---

# Appendix: Large Models (>10B)

> The models below have more than 10B parameters and require more compute resources.

## A. Text-only language models (>10B)

### Official models

| Vendor | Model | Context | Size | Downloads | Link |
|--------|-------|---------|------|-----------|------|
| **Qwen** | Qwen2.5-14B-Instruct-1M | 1M | 14B | 4.7K | [HF](https://hf.co/Qwen/Qwen2.5-14B-Instruct-1M) |
| **MiniMax** | MiniMax-Text-01 | 1M | 456B MoE | 721 | [HF](https://hf.co/MiniMaxAI/MiniMax-Text-01) |
| **Gradient** | Llama-3-70B-Instruct-Gradient-1048k | 1M | 70B | 9 | [HF](https://hf.co/gradientai/Llama-3-70B-Instruct-Gradient-1048k) |

### Qwen3 Coder series (MoE)

| Model | Context | Total / active params | Downloads | Link |
|-------|---------|-----------------------|-----------|------|
| Qwen3-Coder-30B-A3B-Instruct-1M-GGUF | 1M | 30B / 3B | 13.1K | [HF](https://hf.co/unsloth/Qwen3-Coder-30B-A3B-Instruct-1M-GGUF) |
| Qwen3-Coder-480B-A35B-Instruct-1M | 1M | 480B / 35B | 50 | [HF](https://hf.co/unsloth/Qwen3-Coder-480B-A35B-Instruct-1M) |
| Qwen3-Coder-480B-A35B-Instruct-1M-GGUF | 1M | 480B / 35B | 1.7K | [HF](https://hf.co/unsloth/Qwen3-Coder-480B-A35B-Instruct-1M-GGUF) |
| Qwen3-Coder-42B-A3B-TOTAL-RECALL-1M | 1M | 42B / 3B | - | [HF](https://hf.co/DavidAU/Qwen3-Coder-42B-A3B-Instruct-TOTAL-RECALL-MASTER-CODER-M-1million-ctx) |

### Community derivatives

| Model | Context | Size | Downloads | Link |
|-------|---------|------|-----------|------|
| Qwen2.5-14B-Instruct-1M-abliterated | 1M | 14B | 147 | [HF](https://hf.co/huihui-ai/Qwen2.5-14B-Instruct-1M-abliterated) |

---

## B. Vision-language models (>10B)

### Meta Llama 4 series (MoE, multimodal)

| Model | Context | Total / active params | Downloads | Link |
|-------|---------|-----------------------|-----------|------|
| Llama-4-Scout-17B-16E-Instruct | **10M** | 109B / 17B | 180K | [HF](https://hf.co/meta-llama/Llama-4-Scout-17B-16E-Instruct) |
| Llama-4-Maverick-17B-128E-Instruct | **1M** | 400B / 17B | 32.6K | [HF](https://hf.co/meta-llama/Llama-4-Maverick-17B-128E-Instruct) |
| Llama-4-Scout-17B-16E | 10M | 109B / 17B | 8.4K | [HF](https://hf.co/meta-llama/Llama-4-Scout-17B-16E) |
| Llama-4-Maverick-17B-128E | 1M | 400B / 17B | 368 | [HF](https://hf.co/meta-llama/Llama-4-Maverick-17B-128E) |
| Llama-4-Maverick-17B-128E-Instruct-FP8 | 1M | 400B / 17B | 29.6K | [HF](https://hf.co/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8) |

### Qwen3 VL large-model series

#### Dense models

| Model | Context | Size | Downloads | Link |
|-------|---------|------|-----------|------|
| Qwen3-VL-32B-Instruct-1M-GGUF | 1M | 32B | 1.2K | [HF](https://hf.co/unsloth/Qwen3-VL-32B-Instruct-1M-GGUF) |
| Qwen3-VL-32B-Thinking-1M-GGUF | 1M | 32B | 452 | [HF](https://hf.co/unsloth/Qwen3-VL-32B-Thinking-1M-GGUF) |

#### MoE models

| Model | Context | Total / active params | Downloads | Link |
|-------|---------|-----------------------|-----------|------|
| Qwen3-VL-30B-A3B-Instruct-1M-GGUF | 1M | 30B / 3B | 821 | [HF](https://hf.co/unsloth/Qwen3-VL-30B-A3B-Instruct-1M-GGUF) |
| Qwen3-VL-30B-A3B-Thinking-1M-GGUF | 1M | 30B / 3B | 944 | [HF](https://hf.co/unsloth/Qwen3-VL-30B-A3B-Thinking-1M-GGUF) |
| Qwen3-VL-235B-A22B-Instruct-1M-GGUF | 1M | 235B / 22B | 581 | [HF](https://hf.co/unsloth/Qwen3-VL-235B-A22B-Instruct-1M-GGUF) |
| Qwen3-VL-235B-A22B-Thinking-1M-GGUF | 1M | 235B / 22B | 733 | [HF](https://hf.co/unsloth/Qwen3-VL-235B-A22B-Thinking-1M-GGUF) |

#### MXFP4 quantized versions

| Model | Context | Size | Downloads | Link |
|-------|---------|------|-----------|------|
| Qwen3-VL-30B-A3B-Instruct-1M-MXFP4_MOE-GGUF | 1M | 30B MoE | 689 | [HF](https://hf.co/noctrex/Qwen3-VL-30B-A3B-Instruct-1M-MXFP4_MOE-GGUF) |
| Qwen3-VL-30B-A3B-Thinking-1M-MXFP4_MOE-GGUF | 1M | 30B MoE | 565 | [HF](https://hf.co/noctrex/Qwen3-VL-30B-A3B-Thinking-1M-MXFP4_MOE-GGUF) |
| Qwen3-VL-235B-A22B-Instruct-1M-MXFP4_MOE-GGUF | 1M | 235B MoE | 136 | [HF](https://hf.co/noctrex/Qwen3-VL-235B-A22B-Instruct-1M-MXFP4_MOE-GGUF) |
| Qwen3-VL-235B-A22B-Thinking-1M-MXFP4_MOE-GGUF | 1M | 235B MoE | 244 | [HF](https://hf.co/noctrex/Qwen3-VL-235B-A22B-Thinking-1M-MXFP4_MOE-GGUF) |

---

## Summary Statistics

| Category | ≤10B models | >10B models | Max context |
|----------|-------------|-------------|-------------|
| Text-only language models | 10 | 8 | 4M |
| Vision-language models | 6 | 15 | 10M |
| **Total** | **16** | **23** | **10M** |

---

## References

- [Qwen2.5-1M official blog](https://qwenlm.github.io/blog/qwen2.5-1m/)
- [LongRoPE paper](https://huggingface.co/papers/2402.13753)
- [InfiniteHiP paper](https://huggingface.co/papers/2502.08910)
- [Top LLMs for Long Context Windows](https://www.siliconflow.com/articles/en/top-LLMs-for-long-context-windows)
120
docs/memory_communication_benchmark.md
Normal file
@@ -0,0 +1,120 @@
# Memory Communication Benchmark

GPU-CPU transfer volume results, comparing Full Policy with XAttention BSA Policy.

## Test Environment

- **Model**: Llama-3.1-8B-Instruct
- **GPU**: RTX 3090 (24GB)
- **Configuration**: `num_gpu_blocks=4`, `block_size=1024`, `enable_cpu_offload=True`
- **XAttention parameters**: `threshold=0.95`, `stride=8`

## 32K Context Results

| Metric | Full Policy | XAttention | Ratio |
|--------|-------------|------------|-------|
| **Prefill H2D** | 66.57 GB | 111.12 GB | **1.67x** |
| Prefill D2H | 4.29 GB | 4.29 GB | 1.00x |
| TTFT | 8473 ms | 10367 ms | 1.22x |

### XAttention Block Selection (32K)

| Metric | Value |
|--------|-------|
| Available blocks | 465 |
| Selected blocks | 374 |
| Selection density | 80.4% |

## 64K Context Results

| Metric | Full Policy | XAttention | Ratio |
|--------|-------------|------------|-------|
| **Prefill H2D** | 262.13 GB | 386.62 GB | **1.48x** |
| Prefill D2H | 8.46 GB | 8.46 GB | 1.00x |
| Decode H2D (32 tokens) | 262.13 GB | 262.13 GB | 1.00x |
| TTFT | 27081 ms | 33634 ms | 1.24x |

## Transfer-Volume Ratio (before the K-only optimization)

| Context length | XAttn/Full Prefill H2D ratio |
|----------------|------------------------------|
| 32K | 1.67x |
| 64K | 1.48x |

### Analysis (before optimization)

1. **Why XAttention transfers more data**:
   - Estimate stage: loads **K+V** for **100%** of the historical blocks (for attention-score estimation)
   - Compute stage: loads the **selected** blocks (about 70-80%)
   - Theoretical ratio: `1 + selection_density`

2. **Why the ratio is lower at 64K**:
   - With a longer context, the attention distribution is sparser
   - XAttention's block selection is more effective (a smaller fraction is selected)
   - The forced inclusion of the first/last blocks matters relatively less

3. **Decode transfer volume is identical**:
   - XAttention only supports the prefill stage
   - The decode stage falls back to Full Policy
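
The `1 + selection_density` estimate from the analysis above can be evaluated against the 32K block-selection numbers. The sketch below is a simplified model, and the function name `theoretical_h2d_ratio` is made up for this example; the measured 32K ratio (1.67x) is lower than what this model predicts, so treat the output as a rough upper-level estimate rather than a reproduction of the benchmark.

```python
def theoretical_h2d_ratio(selection_density):
    """Estimate pass loads every block once; compute pass loads the selected fraction."""
    return 1.0 + selection_density

# Numbers from the 32K block-selection table.
selected_blocks, available_blocks = 374, 465
density = selected_blocks / available_blocks
ratio = theoretical_h2d_ratio(density)

print(f"selection density: {density:.1%}")
print(f"theoretical XAttn/Full prefill H2D ratio: {ratio:.2f}x")
```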

---

## K-only Optimization (2026-01-28)

### Principle

The estimate stage of XAttention's `select_blocks` only needs K to compute attention scores:

```python
# flat_group_gemm_fuse_reshape only uses Q and K
attn_scores = flat_group_gemm_fuse_reshape(Q, K_chunk, stride, ...)
```

V is never used during estimation, but the previous code loaded both K and V, so half of the data transferred in the estimate stage was wasted.

### Implementation

1. **New method**: `OffloadEngine.load_k_only_to_slot_layer()` loads only K
2. **Modified `select_blocks`**: uses the new K-only loading method

### Results after optimization

| Context | Full Policy | XAttn (before) | XAttn (after) | Saving |
|---------|-------------|----------------|---------------|--------|
| 32K | 66.57 GB | 111.12 GB | **79.76 GB** | **28.2%** |
| 64K | 262.13 GB | 386.62 GB | **258.78 GB** | **33.1%** |

### Change in XAttn/Full ratio

| Context | Ratio before | Ratio after |
|---------|--------------|-------------|
| 32K | 1.67x | **1.20x** |
| 64K | 1.48x | **0.99x** |

### Conclusion

After the optimization, XAttention's transfer volume at 64K context is roughly on par with Full Policy (0.99x), and at 32K the ratio drops from 1.67x to 1.20x. This shows that the K-only optimization of the estimate stage is very effective.
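
The ratios and savings in the tables above follow directly from the measured prefill H2D volumes. A short sketch of that arithmetic is given below; the dictionary layout and the helper name `summarize` are choices made for this example, and the inputs are the rounded GB figures from the tables, so the last digit of a derived ratio can differ slightly from the reported one.

```python
# Prefill H2D volumes in GB, copied from the result tables.
h2d_gb = {
    "32K": {"full": 66.57, "before": 111.12, "after": 79.76},
    "64K": {"full": 262.13, "before": 386.62, "after": 258.78},
}

def summarize(row):
    """Derive XAttn/Full ratios and the relative saving from the K-only change."""
    return {
        "ratio_before": row["before"] / row["full"],
        "ratio_after": row["after"] / row["full"],
        "saving_pct": 100 * (row["before"] - row["after"]) / row["before"],
    }

summary = {ctx: summarize(row) for ctx, row in h2d_gb.items()}
for ctx, s in summary.items():
    print(f"{ctx}: {s['ratio_before']:.2f}x -> {s['ratio_after']:.2f}x, "
          f"saved {s['saving_pct']:.1f}%")
```
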

## Test Commands

```bash
# 32K Full Policy
python bench_offload.py --max-len 32768 --input-len 32000

# 32K XAttention
python bench_offload.py --max-len 32768 --input-len 32000 --enable-xattn

# 64K Full Policy
python bench_offload.py --max-len 65536 --input-len 64000

# 64K XAttention
python bench_offload.py --max-len 65536 --input-len 64000 --enable-xattn

# Including the decode benchmark
python bench_offload.py --max-len 65536 --input-len 64000 --bench-decode --output-len 32
```

## Related Documents

- [`observer_architecture.md`](observer_architecture.md): Observer architecture design
- [`xattn_bsa_policy_design.md`](xattn_bsa_policy_design.md): XAttention BSA algorithm design
323
docs/new_model_integration_guide.md
Normal file
@@ -0,0 +1,323 @@
# New Model Integration Guide

This document summarizes lessons learned and common issues when integrating a new model (such as GLM-4) into nanovllm.

## Integration Workflow Overview

```
1. Analyze the model configuration (config.json)
        ↓
2. Create the model file (nanovllm/models/<model>.py)
        ↓
3. Implement weight loading (nanovllm/utils/loader.py)
        ↓
4. Handle special components (RoPE, Attention, etc.)
        ↓
5. Handle tokenizer differences (EOS tokens, chat template)
        ↓
6. Verify output correctness
```

---

## 1. Config field mapping

Different models name their config fields differently, so a mapping is needed:

| Standard field | GLM-4 | Qwen | Llama | Description |
|----------------|-------|------|-------|-------------|
| `num_key_value_heads` | `multi_query_group_num` | `num_key_value_heads` | `num_key_value_heads` | Number of KV heads |
| `head_dim` | `kv_channels` | computed | computed | Dimension of each head |
| `intermediate_size` | `ffn_hidden_size` | `intermediate_size` | `intermediate_size` | FFN hidden size |
| `max_position_embeddings` | `seq_length` | `max_position_embeddings` | `max_position_embeddings` | Maximum position |
| `rope_theta` | `10000 * rope_ratio` | `rope_theta` | `rope_theta` | RoPE base frequency |

### Code example

```python
# Handle config differences in the model's __init__.
# `config` is the model's HuggingFace config object.
num_heads = config.num_attention_heads
hidden_size = config.hidden_size

# Each lookup tries the standard name first, then the GLM-4 name.
num_kv_heads = getattr(config, 'num_key_value_heads',
                       getattr(config, 'multi_query_group_num', num_heads))

head_dim = getattr(config, 'head_dim',
                   getattr(config, 'kv_channels', hidden_size // num_heads))

# None means neither field exists; validate before building the MLP.
intermediate_size = getattr(config, 'intermediate_size',
                            getattr(config, 'ffn_hidden_size', None))

max_position = getattr(config, 'max_position_embeddings',
                       getattr(config, 'seq_length', 4096))
```

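The fallback chain above can be exercised without loading a real model. The following is a minimal sketch, assuming stand-in config objects built with `SimpleNamespace`; the helper names `first_attr` and `normalize_config` and all numeric values are illustrative and are not taken from nanovllm or from any real `config.json`.

```python
from types import SimpleNamespace

_MISSING = object()


def first_attr(config, names, default=_MISSING):
    """Return the first attribute in `names` that exists and is not None."""
    for name in names:
        value = getattr(config, name, None)
        if value is not None:
            return value
    if default is _MISSING:
        raise AttributeError(f"none of {names} found in config")
    return default


def normalize_config(config):
    """Map model-specific field names onto the standard names."""
    num_heads = config.num_attention_heads
    hidden_size = config.hidden_size
    # GLM-4 expresses the RoPE base as a multiplier on 10000.
    rope_theta = getattr(config, "rope_theta", None)
    if rope_theta is None:
        rope_theta = 10000 * getattr(config, "rope_ratio", 1)
    return {
        "num_kv_heads": first_attr(
            config, ["num_key_value_heads", "multi_query_group_num"], num_heads),
        "head_dim": first_attr(
            config, ["head_dim", "kv_channels"], hidden_size // num_heads),
        "intermediate_size": first_attr(
            config, ["intermediate_size", "ffn_hidden_size"]),
        "max_position": first_attr(
            config, ["max_position_embeddings", "seq_length"], 4096),
        "rope_theta": rope_theta,
    }


# Illustrative configs only; the numbers are placeholders.
glm4_like = SimpleNamespace(
    num_attention_heads=32, hidden_size=4096, multi_query_group_num=2,
    kv_channels=128, ffn_hidden_size=13696, seq_length=131072, rope_ratio=500)
llama_like = SimpleNamespace(
    num_attention_heads=32, hidden_size=4096, num_key_value_heads=8,
    intermediate_size=14336, max_position_embeddings=131072, rope_theta=500000.0)

print(normalize_config(glm4_like))
print(normalize_config(llama_like))
```

Raising when `intermediate_size` cannot be resolved surfaces a missing field at load time instead of as a shape error deep inside the MLP.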
---

## 2. RoPE implementation differences

RoPE is the **most error-prone** part of model integration. Different models may use different RoPE variants:

### 2.1 Rotation style

| Type | Description | Used by |
|------|-------------|---------|
| **Half rotation** | Pairs each element of the first half with its counterpart in the second half: `[x0,x1,...] → [x0*cos-x_{d/2}*sin, ...]` | Llama, Qwen |
| **Interleaved rotation** | Pairs adjacent elements: `[x0,x1,...] → [x0*cos-x1*sin, x1*cos+x0*sin, ...]` | GLM-4 |

### 2.2 Rotated dimensions

| Type | Description | Used by |
|------|-------------|---------|
| **Full rotation** | Rotates the entire head_dim | Llama, Qwen |
| **Partial rotation** | Rotates only part of head_dim; the rest passes through unchanged | GLM-4 (rotary_dim = head_dim // 2) |

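To make the two tables concrete, here is a small pure-Python sketch of both rotation styles plus partial rotation, operating on a single head vector. It is a reference for the pairing logic only and is not the nanovllm implementation; the function names and the frequency formula `base ** (-2i/dim)` (the common RoPE convention) are assumptions of this sketch.

```python
import math


def _angle(pos, pair_idx, dim, base=10000.0):
    """Rotation angle of the pair_idx-th pair at position pos."""
    return pos * base ** (-2.0 * pair_idx / dim)


def rope_half(x, pos, base=10000.0):
    """Half rotation: element i is paired with element i + dim/2."""
    dim, half = len(x), len(x) // 2
    out = list(x)
    for i in range(half):
        theta = _angle(pos, i, dim, base)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[i], x[i + half]
        out[i] = a * c - b * s
        out[i + half] = b * c + a * s
    return out


def rope_interleaved(x, pos, base=10000.0):
    """Interleaved rotation: element 2i is paired with element 2i + 1."""
    dim = len(x)
    out = list(x)
    for i in range(dim // 2):
        theta = _angle(pos, i, dim, base)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = b * c + a * s
    return out


def rope_partial(x, pos, rotary_dim, rope_fn):
    """Partial rotation: rotate the first rotary_dim elements, pass the rest through."""
    return rope_fn(list(x[:rotary_dim]), pos) + list(x[rotary_dim:])


x = [1.0, 2.0, 3.0, 4.0]
print("half:       ", rope_half(x, pos=3))
print("interleaved:", rope_interleaved(x, pos=3))
print("partial:    ", rope_partial(x, pos=3, rotary_dim=2, rope_fn=rope_interleaved))
```

The two styles compute the same rotations on differently ordered pairs, so applying the wrong one produces finite, plausible-looking values that are nonetheless wrong. This is why a RoPE mismatch shows up as degraded text rather than as NaNs.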
### 2.3 GLM-4 RoPE implementation

```python
# Abridged: `...` marks elided arguments and the key path.
class GLM4RotaryEmbedding(nn.Module):
    def __init__(self, head_dim, rotary_dim, ...):
        super().__init__()
        # GLM-4 rotates only half of the dimensions
        self.rotary_dim = rotary_dim  # = head_dim // 2

    def forward(self, positions, query, key):
        # Split into the rotated part and the pass-through part
        q_rot = query[..., :self.rotary_dim]
        q_pass = query[..., self.rotary_dim:]

        # Apply interleaved RoPE to the rotated part only
        # (cos/sin are looked up from `positions`; lookup elided)
        q_rot = apply_rotary_emb_interleaved(q_rot, cos, sin)

        # Concatenate back together; key is handled the same way
        return torch.cat([q_rot, q_pass], dim=-1), ...
```

### 2.4 Debugging RoPE problems

**Symptom**: the model outputs garbage or repeats meaningless content (e.g. "The. The. The...")

**How to debug**:
```python
# Compare against the HuggingFace reference implementation.
# Where apply_rotary_pos_emb lives depends on the model's HF modeling file.
hf_q, hf_k = hf_model.apply_rotary_pos_emb(query, key, cos, sin)
my_q, my_k = my_rotary_emb(positions, query, key)

print(f"Q max diff: {(hf_q - my_q).abs().max()}")  # should be < 1e-5
print(f"K max diff: {(hf_k - my_k).abs().max()}")  # should be < 1e-5
```

---

## 3. Weight name mapping

Different models follow different weight naming conventions:

### 3.1 Common mappings

| Component | Llama/Qwen | GLM-4 |
|-----------|------------|-------|
| Attention QKV | `q_proj`, `k_proj`, `v_proj` | `query_key_value` (fused) |
| Attention Output | `o_proj` | `dense` |
| MLP Gate | `gate_proj` | `dense_h_to_4h` (part) |
| MLP Up | `up_proj` | `dense_h_to_4h` (part) |
| MLP Down | `down_proj` | `dense_4h_to_h` |
| LayerNorm | `input_layernorm` | `input_layernorm` |
| Post-Attention LN | `post_attention_layernorm` | `post_attention_layernorm` |

### 3.2 Implementing weight conversion

```python
def convert_glm4_weights(name, param, q_size, kv_size):
    """Convert a GLM-4 weight to the nanovllm layout.

    q_size = num_heads * head_dim; kv_size = num_kv_heads * head_dim.
    """
    # Fused QKV weight: split into q, k, v along the output dimension
    if "query_key_value" in name:
        q, k, v = param.split([q_size, kv_size, kv_size], dim=0)
        return {"q_proj": q, "k_proj": k, "v_proj": v}

    # Fused gate+up weight: first half is gate, second half is up
    if "dense_h_to_4h" in name:
        gate, up = param.chunk(2, dim=0)
        return {"gate_proj": gate, "up_proj": up}

    return {name: param}
```

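The split sizes used for the fused QKV weight follow directly from the head counts. The sketch below shows that arithmetic on a plain list of row indices so it runs without PyTorch; `qkv_split_sizes` and `split_rows` are hypothetical helpers, and the head counts are illustrative values rather than ones read from a real checkpoint.

```python
def qkv_split_sizes(num_heads, num_kv_heads, head_dim):
    """Row counts of the q, k and v slices inside a fused QKV weight."""
    q_size = num_heads * head_dim
    kv_size = num_kv_heads * head_dim
    return q_size, kv_size, kv_size


def split_rows(rows, sizes):
    """Split `rows` into consecutive chunks with the given sizes."""
    if sum(sizes) != len(rows):
        raise ValueError(f"expected {sum(sizes)} rows, got {len(rows)}")
    parts, start = [], 0
    for size in sizes:
        parts.append(rows[start:start + size])
        start += size
    return parts


# Illustrative GQA shape: 32 query heads, 2 KV heads, head_dim 128.
sizes = qkv_split_sizes(num_heads=32, num_kv_heads=2, head_dim=128)
fused_rows = list(range(sum(sizes)))  # stand-in for the rows of the fused weight
q, k, v = split_rows(fused_rows, sizes)
print(sizes, len(q), len(k), len(v))
```

Checking that the sizes sum to the fused weight's first dimension catches a wrong `num_kv_heads` early, before it turns into a silent mis-split.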
---

## 4. EOS token handling

### 4.1 The problem

Some models use **multiple EOS tokens**:

| Model | EOS token(s) | Notes |
|-------|--------------|-------|
| Llama | `128001` | Single EOS |
| Qwen | `151643` | Single EOS |
| GLM-4 | `[151329, 151336, 151338]` | Multiple: endoftext, user, observation |

**Problem**: `tokenizer.eos_token_id` returns only the first one, so the model does not stop at the other EOS tokens.

### 4.2 Solution

```python
# config.py - support multiple EOS tokens
eos: int | list[int] = -1

# llm_engine.py - read the full EOS list from hf_config
eos_from_config = getattr(config.hf_config, 'eos_token_id', None)
if eos_from_config is not None:
    config.eos = eos_from_config
else:
    config.eos = self.tokenizer.eos_token_id

# scheduler.py - use a set for efficient lookup
self.eos_set = set(eos) if isinstance(eos, list) else {eos}

# Check with `in` rather than `==`
if token_id in self.eos_set:
    ...  # stop generation
```

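The set-based check can be tried in isolation. This is a minimal sketch, assuming EOS is given either as one int or as a list of ints as in the config field above; `normalize_eos` and `truncate_at_eos` are hypothetical names, not nanovllm functions.

```python
def normalize_eos(eos):
    """Accept a single token id or an iterable of ids and return a frozenset."""
    if isinstance(eos, int):
        return frozenset({eos})
    return frozenset(eos)


def truncate_at_eos(token_ids, eos):
    """Return token_ids up to and including the first EOS token, if any."""
    eos_set = normalize_eos(eos)
    for index, token_id in enumerate(token_ids):
        if token_id in eos_set:
            return list(token_ids[:index + 1])
    return list(token_ids)


generated = [11, 22, 151336, 33]

# With the full GLM-4 EOS list, generation stops at 151336.
print(truncate_at_eos(generated, [151329, 151336, 151338]))

# With only the first EOS id, 151336 is missed and generation runs on.
print(truncate_at_eos(generated, 151329))
```

The second call reproduces the failure mode from section 4.1, where only `tokenizer.eos_token_id` is consulted.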
### 4.3 Debugging EOS problems

**Symptom**: the model always generates until max_tokens before stopping

**How to debug**:
```python
# Inspect the EOS configuration
print(f"tokenizer.eos_token_id: {tokenizer.eos_token_id}")
print(f"hf_config.eos_token_id: {config.hf_config.eos_token_id}")

# Look for EOS tokens in the output
output = llm.generate([prompt], params)
for eos_id in [151329, 151336, 151338]:
    if eos_id in output[0]['token_ids']:
        print(f"Found EOS {eos_id} at position {output[0]['token_ids'].index(eos_id)}")
```

---

## 5. Chat Template

Different models use different chat templates:

| Model | Template format |
|-------|-----------------|
| Llama-3 | `<\|begin_of_text\|><\|start_header_id\|>user<\|end_header_id\|>\n{content}<\|eot_id\|><\|start_header_id\|>assistant<\|end_header_id\|>\n` |
| Qwen | `<\|im_start\|>user\n{content}<\|im_end\|>\n<\|im_start\|>assistant\n` |
| GLM-4 | `[gMASK]<sop><\|user\|>\n{content}<\|assistant\|>\n` |

### Implementing template conversion

```python
def convert_to_model_prompt(prompt: str, model_type: str) -> str:
    """Convert a plain prompt into the model-specific chat format."""
    if model_type == "glm4":
        return f"[gMASK]<sop><|user|>\n{prompt}<|assistant|>\n"
    elif model_type == "llama3":
        return f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
    elif model_type == "qwen":
        # Template taken from the table above
        return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
    # Fail loudly instead of silently returning None
    raise ValueError(f"Unsupported model_type: {model_type}")
```

---

## 6. Verification checklist

After integrating a new model, verify it in the following order:

### 6.1 Weight loading check

```python
# Check that every weight was actually loaded
for name, param in model.named_parameters():
    if param.abs().sum() == 0:
        print(f"WARNING: {name} is all zeros!")
```

### 6.2 Single-layer output check

```python
# Compare the embedding layer output
my_emb = my_model.embed_tokens(input_ids)
hf_emb = hf_model.model.embed_tokens(input_ids)
print(f"Embedding diff: {(my_emb - hf_emb).abs().max()}")  # < 1e-5

# Compare the first layer output
my_out = my_model.layers[0](my_emb, ...)
hf_out = hf_model.model.layers[0](hf_emb, ...)
print(f"Layer 0 diff: {(my_out - hf_out).abs().max()}")  # < 1e-4
```

### 6.3 Generation quality check

```python
# Simple question-answer test
prompt = "Hello, how are you?"
output = llm.generate([prompt], SamplingParams(max_tokens=50))
print(output[0]['text'])  # should be a coherent answer

# Check that generation stops correctly
print(f"Generated {len(output[0]['token_ids'])} tokens (max=50)")
```

### 6.4 RULER benchmark

```bash
# Run 1 sample as a quick check
python tests/test_ruler.py --model <path> --num-samples 1

# Once that passes, run the full test
python tests/test_ruler.py --model <path> --num-samples 100
```

---

## 7. Troubleshooting quick reference

| Symptom | Likely cause | Fix |
|---------|--------------|-----|
| Garbled or repetitive output | Incorrect RoPE implementation | Check the rotation style (interleaved vs half) and the rotated dimensions (full vs partial) |
| Numerical blow-up (NaN/Inf) | Incorrect weight loading or dtype mismatch | Check the weight mapping and make sure dtypes match |
| Generation never stops | Incorrect EOS token handling | Read the full EOS list from hf_config |
| Poor output quality | Missing LayerNorm or bias | Check config options such as add_qkv_bias |
| Wrong positional encoding | max_position_embeddings read incorrectly | Check the config field names (seq_length, etc.) |

---

## 8. File layout

Files that need to be created or modified when integrating a new model:

```
nanovllm/
├── models/
│   └── <model>.py              # new: model definition
├── layers/
│   └── rotary_embedding.py     # modify: if a special RoPE is needed
├── utils/
│   └── loader.py               # modify: weight loading
├── config.py                   # may modify: new config fields
└── engine/
    ├── llm_engine.py           # may modify: EOS handling
    └── scheduler.py            # may modify: EOS check
tests/
└── test_ruler.py               # modify: chat template
```

---

## Appendix: GLM-4 integration case study

### Problems encountered and how they were solved

1. **Config field differences** → added a getattr fallback chain
2. **Interleaved RoPE** → implemented `apply_rotary_emb_interleaved`
3. **Partial rotation (head_dim//2)** → implemented `GLM4RotaryEmbedding`
4. **Multiple EOS tokens** → changed config/llm_engine/scheduler to support a list
5. **Fused QKV weights** → split in the loader

### Key code locations

- RoPE implementation: `nanovllm/layers/rotary_embedding.py:GLM4RotaryEmbedding`
- Model definition: `nanovllm/models/glm4.py`
- Weight loading: `nanovllm/utils/loader.py:load_glm4_weights`
- EOS handling: `nanovllm/engine/scheduler.py:eos_set`

210
docs/nsys_wrong_event_order_bug.md
Normal file
@@ -0,0 +1,210 @@
# Nsys "Wrong Event Order" Bug: Debugging Notes

## Problem description

When profiling nanovllm's CPU offload mode with `nsys profile`, the `.nsys-rep` file cannot be generated and the import fails with:

```
Importer error status: Importation failed.
Wrong event order has been detected when adding events to the collection:
new event ={ StartNs=21569539222 StopNs=21569672388 ... Type=48 }
last event ={ StartNs=22046804077 StopNs=22046805343 ... Type=48 }
```

## Environment

- **nsys version**: 2023.4.4.54-234433681190v0
- **CUDA**: 12.4
- **Status**: known nsys bug, fixed in 2024.2+

## Debugging process

### Stage 1: Identify the trigger condition

Tested step by step with a bisect script (`tests/test_nsys_bisect.py`):

| Stage | Description | Result |
|-------|-------------|--------|
| 1 | CUDA init | ✅ |
| 2 | Import nanovllm | ✅ |
| 3 | Create LLM (offload) | ✅ |
| 4 | Short prompt generation | ✅ |
| **5** | **Long prompt (~64K) prefill** | ❌ |

**Conclusion**: the problem is in the chunked prefill path for long prompts.

### Stage 2: Locate the component

Commented out code step by step in the `_chunked_prefill_attention` method:

| Component | Location | Result |
|-----------|----------|--------|
| Entire method (return zeros) | `attention.py:167` | ✅ |
| `select_blocks()` | `attention.py:217` | ✅ |
| `offload_prefill_buffer_async()` | `attention.py:241-248` | ✅ |
| `compute_chunked_prefill()` | `attention.py:225-235` | ❌ |

**Conclusion**: the problem is inside `compute_chunked_prefill`.

### Stage 3: Locate the ring buffer pipeline

Narrowed it down further in `full_policy.py`:

| Component | Lines | Result |
|-----------|-------|--------|
| Current chunk attention | 191-198 | ✅ |
| **Historical block loading (ring buffer)** | 133-189 | ❌ |

**Root cause confirmed**: the multi-stream operations of the ring buffer pipeline trigger the nsys bug.

## Root cause

### Code that triggers the bug

```python
# nanovllm/kvcache/sparse/full_policy.py:133-189

# Multi-slot pipeline mode
for block_idx in range(num_blocks):
    current_slot = load_slots[block_idx % num_slots]

    # Wait for the slot's transfer stream to finish
    offload_engine.wait_slot_layer(current_slot)

    # Run attention on compute_stream
    with torch.cuda.stream(compute_stream):
        prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
        prev_o, prev_lse = flash_attn_with_lse(...)
        offload_engine.record_slot_compute_done(current_slot)

    # Asynchronously start loading the next block
    if next_block_idx < num_blocks:
        offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
```

### Stream structure

```
slot_transfer_streams[0] ─┐
slot_transfer_streams[1] ─┼─ 4 transfer streams
slot_transfer_streams[2] ─┤
slot_transfer_streams[3] ─┘
                          │
                          ▼ wait/record synchronization
                          │
compute_stream ───────────┘
```

This complex 4+1 stream synchronization pattern causes the event timestamp ordering algorithm in nsys 2023.4.4 to fail.

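The ordering of loads, waits and compute steps in this pipeline can be traced with a CPU-only simulation. The sketch below is illustrative and makes two assumptions that the excerpt leaves elided: the first `num_slots` blocks are preloaded before the loop, and the next block for a slot is `block_idx + num_slots` (so it reuses the slot that was just consumed). It does not call CUDA or nanovllm.

```python
def ring_buffer_schedule(num_blocks, num_slots):
    """Return the ordered (op, block_idx, slot) steps of the slot pipeline."""
    steps = []
    # Assumed prologue: fill every slot with one of the first blocks.
    for block_idx in range(min(num_slots, num_blocks)):
        steps.append(("load", block_idx, block_idx % num_slots))

    for block_idx in range(num_blocks):
        slot = block_idx % num_slots
        steps.append(("wait", block_idx, slot))     # like wait_slot_layer
        steps.append(("compute", block_idx, slot))  # attention on compute_stream
        # Assumption: refill the slot just consumed with the block
        # that will need it next.
        next_block_idx = block_idx + num_slots
        if next_block_idx < num_blocks:
            steps.append(("load", next_block_idx, slot))
    return steps


for op, block_idx, slot in ring_buffer_schedule(num_blocks=6, num_slots=4):
    print(f"{op:<8} block={block_idx} slot={slot}")
```

In the real code each `load` runs on that slot's transfer stream and each `compute` on `compute_stream`, so every step in this trace corresponds to a cross-stream event record or wait.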
### Why a simple multi-stream test does not reproduce it

We tried to reproduce the problem with simple test code (`tests/test_multistream_nsys.py`):

- 4-8 streams, 2000+ iterations: ✅ profiling succeeded
- 32 threads + multi-stream: ✅ profiling succeeded
- >64k CUDA operations: ✅ profiling succeeded

None of these triggered the bug, because the stream synchronization pattern in the real code is more complex:
1. Cross-stream event wait/record
2. Interaction with FlashAttention kernels
3. A long run (~50 seconds) that accumulates a large number of events

## Solutions

### Option 1: Upgrade nsys (recommended)

```bash
# Download nsys 2024.2 or later
# https://developer.nvidia.com/nsight-systems
```

According to the [NVIDIA forum](https://forums.developer.nvidia.com/t/nsys-profiler-wrong-event-order/264881), this bug is fixed in version 2024.2.

### Option 2: Use the .qdstrm file

Even when the import fails, the `.qdstrm` file is still generated:

```bash
# Generated file
results/nsys/ruler_niah_single_1_sample0_offload_*.qdstrm

# Try opening it directly in the GUI
nsight-sys <file>.qdstrm
```

The GUI may be more tolerant of such errors.

### Option 3: Use the PyTorch Profiler

```python
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    ...  # your code

prof.export_chrome_trace("trace.json")  # view in chrome://tracing
```

### Option 4: Temporarily disable the ring buffer pipeline

In `full_policy.py`, temporarily use the single-slot synchronous mode (for debugging only):

```python
# Force single-slot mode
if len(load_slots) == 1 or True:  # add "or True"
    # Synchronous mode; does not trigger the nsys bug
    ...
```

## Reproduction steps

### Environment setup

```bash
cd /home/zijie/Code/nano-vllm
```

### Run the bisect script

```bash
# Stage 5 triggers the bug
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$PWD:$PYTHONPATH \
nsys profile --trace=cuda,nvtx,osrt --force-overwrite=true \
-o /tmp/bisect python tests/test_nsys_bisect.py --stage 5
```

### Verify the fix

```bash
# Temporarily skip historical block loading in full_policy.py
# Change line 133 to: if False and cpu_block_table:

# Re-run; it should now succeed
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$PWD:$PYTHONPATH \
nsys profile --trace=cuda,nvtx,osrt --force-overwrite=true \
-o /tmp/bisect_fixed python tests/test_nsys_bisect.py --stage 5

# Check that the .nsys-rep file was generated
ls -la /tmp/bisect_fixed.nsys-rep
```

## Related files

| File | Purpose |
|------|---------|
| `tests/test_nsys_bisect.py` | Bisect debugging script |
| `tests/test_multistream_nsys.py` | Simple multi-stream test |
| `scripts/profile_offload.sh` | nsys profile script |
| `nanovllm/layers/attention.py` | Attention layer |
| `nanovllm/kvcache/sparse/full_policy.py` | Ring buffer pipeline |

## References

- [Nsys Profiler- Wrong event order - NVIDIA Forums](https://forums.developer.nvidia.com/t/nsys-profiler-wrong-event-order/264881)
- [Nsight Systems 2025.3 Release Notes](https://docs.nvidia.com/nsight-systems/2025.3/ReleaseNotes/index.html)
- [Nsight Systems User Guide](https://docs.nvidia.com/nsight-systems/UserGuide/index.html)

## Debugging date

2026-01-24

194
docs/observer_architecture.md
Normal file
@@ -0,0 +1,194 @@
# Observer Architecture

nanovllm's Observer architecture collects key metrics during inference. It uses the class-variable pattern to manage global state.

## Architecture overview

```
Observer (base class)
├── InferenceObserver  - inference timing metrics (TTFT, TPOT)
└── MemoryObserver     - memory transfer statistics (H2D, D2H, D2D)
```

## Design principles

### 1. Class-variable pattern

All Observers store their state in class variables rather than instance variables:

```python
class Observer:
    """Observer base class"""
    _enabled: bool = True  # class variable; controls whether recording is on

class InferenceObserver(Observer):
    ttft: int = 0  # class variable, shared globally
    tpot: int = 0
    ttft_start: int = 0
    tpot_start: int = 0
```

**Advantages**:
- No instantiation needed; accessible directly from anywhere
- Avoids passing observer instances across modules
- Well suited to global statistics

### 2. Enable/disable control

Each Observer can be enabled or disabled independently:

```python
# Enable MemoryObserver
MemoryObserver._enabled = True

# Once disabled, the record_* methods record nothing
MemoryObserver._enabled = False
```

### 3. Phase separation

MemoryObserver keeps separate statistics for the prefill and decode phases:

```python
@classmethod
def record_h2d(cls, num_bytes: int, is_prefill: bool = True) -> None:
    if not cls._enabled:
        return
    cls.h2d_bytes += num_bytes
    cls.h2d_count += 1
    if is_prefill:
        cls.prefill_h2d_bytes += num_bytes
    else:
        cls.decode_h2d_bytes += num_bytes
```

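The three principles combine into a small self-contained example. The sketch below is a reduced stand-in for `MemoryObserver` that tracks only H2D counters; the default of `_enabled = False` and the body of `complete_reset()` are assumptions of this sketch, not a copy of the nanovllm source.

```python
class Observer:
    """Base class; state lives in class variables, so no instance is needed."""
    _enabled: bool = True


class MemoryObserver(Observer):
    _enabled: bool = False  # assumed default: off until explicitly enabled
    h2d_bytes: int = 0
    h2d_count: int = 0
    prefill_h2d_bytes: int = 0
    decode_h2d_bytes: int = 0

    @classmethod
    def record_h2d(cls, num_bytes: int, is_prefill: bool = True) -> None:
        if not cls._enabled:
            return
        cls.h2d_bytes += num_bytes
        cls.h2d_count += 1
        if is_prefill:
            cls.prefill_h2d_bytes += num_bytes
        else:
            cls.decode_h2d_bytes += num_bytes

    @classmethod
    def complete_reset(cls) -> None:
        cls.h2d_bytes = 0
        cls.h2d_count = 0
        cls.prefill_h2d_bytes = 0
        cls.decode_h2d_bytes = 0


# While disabled, calls are ignored.
MemoryObserver.record_h2d(1024)
print(MemoryObserver.h2d_bytes)

# Once enabled, transfers are attributed to the matching phase.
MemoryObserver._enabled = True
MemoryObserver.record_h2d(4096, is_prefill=True)
MemoryObserver.record_h2d(512, is_prefill=False)
print(MemoryObserver.h2d_bytes, MemoryObserver.prefill_h2d_bytes,
      MemoryObserver.decode_h2d_bytes)
```

Because the counters are class attributes, every module that imports `MemoryObserver` reads and writes the same totals. This is the source of the pattern's convenience, and it is also why a reset is needed between runs.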
## Observer implementations

### InferenceObserver

**Location**: `nanovllm/utils/observer.py`

**Metrics**:
| Metric | Description | Unit |
|--------|-------------|------|
| `ttft` | Time To First Token | nanoseconds |
| `tpot` | Time Per Output Token | nanoseconds |
| `ttft_start` | Start point of the TTFT timer | nanoseconds |
| `tpot_start` | Start point of the TPOT timer | nanoseconds |

**Where metrics are recorded**:
| Location | Code | Description |
|----------|------|-------------|
| `scheduler.py:add()` | `InferenceObserver.ttft_start = perf_counter_ns()` | Start timing |
| `llm_engine.py:step()` | `InferenceObserver.ttft = ... - ttft_start` | Compute TTFT once prefill completes |
| `llm_engine.py:step()` | `InferenceObserver.tpot = ... - tpot_start` | Compute TPOT during decode |

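The three recording points in the table can be sketched as follows. The field names match the table, but the method names (`mark_request_added`, `mark_prefill_done`, `mark_decode_step_done`) are hypothetical; the table only shows direct assignments in `scheduler.py` and `llm_engine.py`. Restarting the TPOT timer after each step is also an assumption of this sketch.

```python
from time import perf_counter_ns


class InferenceObserver:
    ttft: int = 0        # nanoseconds
    tpot: int = 0        # nanoseconds
    ttft_start: int = 0
    tpot_start: int = 0

    @classmethod
    def mark_request_added(cls) -> None:
        """Corresponds to scheduler.py:add(): start the TTFT timer."""
        cls.ttft_start = perf_counter_ns()

    @classmethod
    def mark_prefill_done(cls) -> None:
        """Corresponds to step() after prefill: record TTFT, start TPOT timer."""
        now = perf_counter_ns()
        cls.ttft = now - cls.ttft_start
        cls.tpot_start = now

    @classmethod
    def mark_decode_step_done(cls) -> None:
        """Corresponds to step() during decode: record the latest TPOT."""
        now = perf_counter_ns()
        cls.tpot = now - cls.tpot_start
        cls.tpot_start = now  # assumed: restart the timer for the next token


InferenceObserver.mark_request_added()
InferenceObserver.mark_prefill_done()
InferenceObserver.mark_decode_step_done()
print(f"TTFT: {InferenceObserver.ttft / 1e6:.3f} ms, "
      f"TPOT: {InferenceObserver.tpot / 1e6:.3f} ms")
```

`perf_counter_ns()` is monotonic and returns integer nanoseconds, which matches the units in the metrics table and avoids float rounding.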
### MemoryObserver

**Location**: `nanovllm/utils/memory_observer.py`

**Metrics**:
| Metric | Description |
|--------|-------------|
| `h2d_bytes` / `h2d_count` | Host to Device transfer volume / count |
| `d2h_bytes` / `d2h_count` | Device to Host transfer volume / count |
| `d2d_bytes` / `d2d_count` | Device to Device copy volume / count |
| `prefill_h2d_bytes` / `prefill_d2h_bytes` | H2D/D2H during prefill |
| `decode_h2d_bytes` / `decode_d2h_bytes` | H2D/D2H during decode |

**Where metrics are recorded** (all in `offload_engine.py`):

| Method | Transfer type | Description |
|--------|---------------|-------------|
| `load_to_slot_layer()` | H2D | Load a block from CPU into a GPU slot |
| `load_block_sample_from_cpu()` | H2D | Sampled load (Quest) |
| `load_block_full_from_cpu()` | H2D | Load a full block |
| `offload_slot_layer_to_cpu()` | D2H | Offload a GPU slot to CPU |
| `offload_prefill_buffer_async()` | D2H | Asynchronous prefill buffer offload |
| `write_to_prefill_buffer()` | D2D | Write to the prefill buffer |
| `write_to_decode_buffer()` | D2D | Write to the decode buffer |

**Where state is reset**:
| Location | Code |
|----------|------|
| `llm_engine.py:generate()` | `MemoryObserver.complete_reset()` |
| `llm_engine.py:generate()` | `InferenceObserver.complete_reset()` |

## Usage examples

### 1. Enable and collect statistics

```python
from nanovllm.utils.memory_observer import MemoryObserver

# Enable statistics
MemoryObserver._enabled = True

# Run inference
outputs = llm.generate(prompts, sampling_params)

# Read the results
print(f"Prefill H2D: {MemoryObserver.prefill_h2d_bytes / 1e9:.2f} GB")
print(f"Decode H2D: {MemoryObserver.decode_h2d_bytes / 1e9:.2f} GB")

# Or use print_summary
MemoryObserver.print_summary()
```

### 2. In bench_offload.py

```python
from nanovllm.utils.memory_observer import MemoryObserver

# Enable
MemoryObserver._enabled = True

# Print after the benchmark finishes
def print_memory_stats():
    fmt = MemoryObserver._fmt_bytes
    print(f"[Memory] Prefill H2D: {fmt(MemoryObserver.prefill_h2d_bytes)}")
    print(f"         Decode H2D: {fmt(MemoryObserver.decode_h2d_bytes)}")
```

### 3. Get structured data

```python
summary = MemoryObserver.get_summary()
# {
#     "total": {"h2d_bytes": ..., "d2h_bytes": ..., "d2d_bytes": ...},
#     "prefill": {"h2d_bytes": ..., "d2h_bytes": ...},
#     "decode": {"h2d_bytes": ..., "d2h_bytes": ...}
# }
```

## Adding a new Observer

1. Inherit from the `Observer` base class
2. Define class variables to hold the statistics
3. Implement `record_*` methods (each must check `_enabled`)
4. Implement the `complete_reset()` method
5. Add `record_*` calls at the relevant code locations
6. Add the reset call in `llm_engine.py:generate()`

```python
from nanovllm.utils.observer import Observer

class MyObserver(Observer):
    _enabled: bool = False
    my_metric: int = 0

    @classmethod
    def record_event(cls, value: int) -> None:
        if not cls._enabled:
            return
        cls.my_metric += value

    @classmethod
    def complete_reset(cls) -> None:
        cls.my_metric = 0
```

## Related documents

- [`memory_communication_benchmark.md`](memory_communication_benchmark.md) - transfer volume benchmark results
- [`architecture_guide.md`](architecture_guide.md) - overall architecture guide

@@ -1,12 +1,86 @@
 # RULER 32K Chunked Offload Accuracy Issue

-**Status**: 🟡 IMPROVED (Last Updated: 2026-01-20)
+**Status**: ✅ **RESOLVED** (Last Updated: 2026-01-21)
 **Branch**: `tzj/minference`
-**Severity**: MEDIUM - 4-slot config improves accuracy but issues remain
+**Severity**: RESOLVED - State leakage fixed

 ---

-## Problem
+## 🎯 Fix complete

+### Root cause
+
+**CPU KV cache state leaking between consecutive requests**
+
+`OffloadEngine.reset()` cleared the GPU buffers but **did not clear the CPU cache**, so KV cache data from the previous request stayed in CPU memory and contaminated subsequent requests.
+
+### Fix implementation (2026-01-21)
+
+#### Fix 1: CPU cache cleanup
+**File**: `nanovllm/kvcache/offload_engine.py`
+
+```python
+def reset(self) -> None:
+    # Clear GPU buffers (existing)
+    self.k_cache_gpu.zero_()
+    self.v_cache_gpu.zero_()
+    self.decode_k_buffer.zero_()
+    self.decode_v_buffer.zero_()
+    self.prefill_k_buffer.zero_()
+    self.prefill_v_buffer.zero_()
+
+    # 🔧 New: clear the CPU cache (the key fix)
+    self.k_cache_cpu.zero_()
+    self.v_cache_cpu.zero_()
+
+    self.pending_events.clear()
+```
+
+#### Fix 2: Decode state tracking cleanup
+**File**: `nanovllm/kvcache/hybrid_manager.py`
+
+```python
+def deallocate(self, seq: Sequence) -> None:
+    # ... release blocks ...
+    seq.num_cached_tokens = 0
+    seq.block_table.clear()
+
+    # 🔧 New: clear decode position tracking
+    self.clear_decode_tracking(seq)
+
+    if self.offload_engine is not None:
+        self.offload_engine.reset()
+```
+
+### Verification results (2026-01-21)
+
+| Task | Before fix | After fix | Improvement |
+|------|------------|-----------|-------------|
+| niah_single_1 (100 samples) | ~80% | **94%** | +14% ✅ |
+| niah_single_1 (50 samples) | - | **100%** | ✅ |
+| niah_multikey_1 (50 samples) | - | **96%** | ✅ |
+| niah_multikey_2 (50 samples) | - | **100%** | ✅ |
+
+### Conclusions
+
+1. **CPU cache leak fixed** - batch test accuracy rose from ~80% to 94%
+2. **The remaining ~6% errors are inherent model limitations** - the failing samples (17, 37, 52, 87, 91, 94) relate to model capability, not state leakage
+3. **The chunked attention algorithm is correct** - niah_single_1 can reach 100% accuracy
+
+### Before/after comparison
+
+| State | Component | Before fix | After fix |
+|-------|-----------|------------|-----------|
+| CPU KV Cache | `k_cache_cpu`, `v_cache_cpu` | ❌ not cleared | ✅ cleared |
+| Decode tracking | `_decode_start_pos`, `_prefill_len` | ❌ not cleared | ✅ cleared |
+
+---
+
+## Historical problem record
+
+The original problem analysis follows, kept for reference.
+
+### Problem (Original)
+
 When running RULER benchmark with 32K context length using the chunked offload mechanism in `tzj/minference` branch, accuracy degradation is observed compared to the `xattn_stride8` baseline.

@@ -565,6 +639,56 @@ def _should_use_chunked_offload(self, seqs, is_prefill):
 ---

+## Multikey task failure analysis (single-sample tests)
+
+### Characteristics of failing samples
+
+In single-sample tests, the multikey failures are **not** state leakage; they are a **model retrieval capability problem**.
+
+#### Error types
+
+| Type | Example | Notes |
+|------|---------|-------|
+| **Wrong key retrieved** | Expected `5833597`, Got `8617381` | Returned the value of another key in the context |
+| **Wrong UUID retrieved** | Expected `c73ed342-...`, Got `1d28b88b-...` | Returned the UUID belonging to the wrong key |
+
+#### multikey_2 failing samples (single-sample tests)
+
+| Sample | Expected | Got | Analysis |
+|--------|----------|-----|----------|
+| 2 | `1535573` | `8651665` | wrong key |
+| 12 | `4641400` | `9390530` | wrong key |
+| 19 | `8591874` | `3853628` | wrong key |
+| 50 | `2318630` | `7780552` | wrong key |
+| 66 | `1926587` | `9249734` | wrong key |
+| 85 | `1253265` | `3263480` | wrong key |
+| 86 | `7772887` | `3762547` | wrong key |
+| 89 | `2266721` | `5873220` | wrong key |
+| 98 | (not recorded) | (not recorded) | - |
+
+#### multikey_3 failing samples (single-sample tests)
+
+| Sample | Expected | Got | Analysis |
+|--------|----------|-----|----------|
+| 11 | `c73ed342-6523-...` | `1d28b88b-b6a8-...` | UUID of the wrong key |
+| 18 | `87b8a762-1d1f-...` | `429a6676-5295-...` | UUID of the wrong key |
+| 23 | `ed344bfe-983f-...` | `aec43163-061a-...` | UUID of the wrong key |
+| 35 | `ac8a317b-a6bb-...` | `d2f22889-5b72-...` | UUID of the wrong key |
+| 41 | `7842feb5-e758-...` | `fc8e724e-418d-...` | UUID of the wrong key |
+| 47 | `7c0f7fd2-237e-...` | `5fb71d15-4675-...` | UUID of the wrong key |
+| 53 | `bccd56fa-8fba-...` | `373cc0cc-6ab7-...` | UUID of the wrong key |
+| 86 | `68c49603-1d17-...` | `aef58e2e-9e99-...` | UUID of the wrong key |
+| 93 | `74651292-5664-...` | `4546dd56-fe88-...` | UUID of the wrong key |
+
+### Key findings
+
+1. **Correct format**: the output of failing samples is correctly formatted (a 7-digit number or a UUID)
+2. **Valid value**: the output is the value of another key-value pair that exists in the context
+3. **Deterministic failure**: the same sample returns the same wrong value across repeated runs
+4. **Model capability limit**: this is the model's ceiling on multi-key retrieval; ~91% accuracy is in line with expectations
+
+---
+
 ## Comparison with Working Baseline

 ### xattn_stride8 (Working)
@@ -573,21 +697,40 @@ def _should_use_chunked_offload(self, seqs, is_prefill):
 - **Error Rate**: ~8% (expected RULER baseline)
 - **Samples**: 100 samples per task

-### Chunked Offload (Broken)
+### Chunked Offload - batch test (Broken)
 - **Branch**: `tzj/minference`
 - **Method**: Full attention with chunked CPU offload
-- **Error Rate**: 20% (120/600)
+- **Error Rate**: 20% (120/600) - **caused by state leakage**
 - **Samples**: 100 samples per task

+### Chunked Offload - single-sample test (Working)
+- **Branch**: `tzj/minference`
+- **Method**: Full attention with chunked CPU offload, re-initializing the LLM for every request
+- **Error Rate**: 0% (niah_single_1), ~9% (multikey tasks)
+- **Samples**: 100 samples per task
+- **Conclusion**: the algorithm is correct; multikey failures are a model capability problem
+
 ---

-## Next Steps
+## Next Steps (Updated)

-1. **Reproduce with 4K context**: Test if issue exists with shorter contexts (fewer chunks)
+### Completed ✅

-2. **Vary chunk size**: Test with chunk_size=2048, 4096 to see if larger chunks help
+1. ~~**Reproduce with 4K context**~~ - no longer needed; the algorithm has been verified correct
+2. ~~**Vary chunk size**~~ - no longer needed; the problem is not the chunk size
+3. ~~**4-slot configuration test**~~ - done; it helps but is not the root cause

-3. **Disable chunked offload**: Compare with layer-wise offload only (no chunking)
+### To do 🔧

+1. **Locate the leaking component**: investigate which state is not reset correctly between consecutive requests
+   - The KV cache manager's `reset()` or `clear()` method
+   - The offload engine's ring buffer slot state
+   - Cross-request isolation of the decode buffer
+   - Internal state of the sparse policy
+
+2. **Implement the state reset fix**: clean up all state correctly after each request completes
+
+3. **Verify the fix**: use batch tests to confirm accuracy recovers to ~95%+
+
 4. **Add tensor checkpoints**: Log intermediate attention outputs at chunk boundaries

338
docs/test_ruler_usage_guide.md
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
# test_ruler.py Usage Guide

A comprehensive test tool for the RULER benchmark, used to evaluate the long-context capability of LLMs.

**Test date**: 2026-02-05
**Test GPU**: RTX 3090 (GPU 4)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Supported Tasks

| Category | Tasks |
|------|------|
| NIAH (Needle-In-A-Haystack) | `niah_single_1/2/3`, `niah_multikey_1/2/3`, `niah_multiquery`, `niah_multivalue` |
| QA (Question Answering) | `qa_1`, `qa_2` |
| Recall | `cwe`, `fwe`, `vt` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Basic Command Format
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=<GPU_ID> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py [OPTIONS]
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Parameters

### Essential Parameters

| Parameter | Default | Description |
|------|--------|------|
| `--model` | `~/models/Llama-3.1-8B-Instruct` | Model path |
| `--data-dir` | `tests/data/ruler_64k` | Data directory |
| `--max-model-len` | 65664 | Maximum context length |
|
||||||
|
|
||||||
|
### Data Selection

| Parameter | Default | Description |
|------|--------|------|
| `--datasets` | all | Comma-separated dataset names |
| `--num-samples` | 0 (all) | Number of samples to test per dataset |
| `--sample-indices` | - | Specific sample indices (e.g. `0,5,10`) |
|
||||||
|
|
||||||
|
### Offload Configuration

| Parameter | Default | Description |
|------|--------|------|
| `--enable-offload` | False | Enable CPU offload mode |
| `--num-gpu-blocks` | 4 | Number of KV cache blocks on the GPU |
| `--block-size` | 4096 | KV cache block size (tokens) |
| `--num-kv-buffers` | 4 | Number of ring buffers |
| `--gpu-utilization` | 0.9 | GPU memory utilization |
|
||||||
|
|
||||||
|
### Sparse Attention Configuration

| Parameter | Default | Description |
|------|--------|------|
| `--sparse-policy` | - | Sparse policy: `FULL`, `QUEST`, `XATTN_BSA` |
| `--sparse-threshold` | 0.9 | XAttn cumulative attention threshold |
| `--sparse-samples` | 128 | XAttn samples per chunk |
| `--sparse-stride` | 8 | XAttn Q/K downsampling stride |
|
||||||
|
|
||||||
|
### Output Control

| Parameter | Description |
|------|------|
| `--quiet` / `-q` | Quiet mode |
| `--json-output` | Output in JSON format |
| `--fresh-llm` | Re-initialize the LLM for every sample |
|
||||||
|
|
||||||
|
### Other

| Parameter | Default | Description |
|------|--------|------|
| `--dtype` | auto | Model data type (`bfloat16`, `float16`) |
| `--use-cuda-graph` | False | Enable CUDA Graph |
| `--max-new-tokens` | 16 | Maximum number of generated tokens |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Verified Command Examples

All of the following commands were tested successfully on an RTX 3090 (24GB).

### 1. Basic Offload Test (32K)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--data-dir tests/data/ruler_32k \
|
||||||
|
--datasets niah_single_1 \
|
||||||
|
--num-samples 1 \
|
||||||
|
--max-model-len 40960 \
|
||||||
|
--enable-offload
|
||||||
|
```
|
||||||
|
|
||||||
|
**Result**: 100% accuracy, ~16s elapsed
|
||||||
|
|
||||||
|
### 2. Offload + XAttention BSA (32K)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--data-dir tests/data/ruler_32k \
|
||||||
|
--datasets niah_single_1 \
|
||||||
|
--num-samples 1 \
|
||||||
|
--max-model-len 40960 \
|
||||||
|
--enable-offload \
|
||||||
|
--sparse-policy XATTN_BSA \
|
||||||
|
--sparse-threshold 0.9
|
||||||
|
```
|
||||||
|
|
||||||
|
**Result**: 100% accuracy, compute density ~50%, ~19s elapsed
|
||||||
|
|
||||||
|
### 3. Offload + XAttention BSA (64K)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--data-dir tests/data/ruler_64k \
|
||||||
|
--datasets niah_single_1 \
|
||||||
|
--num-samples 1 \
|
||||||
|
--max-model-len 72000 \
|
||||||
|
--enable-offload \
|
||||||
|
--sparse-policy XATTN_BSA \
|
||||||
|
--sparse-threshold 0.9
|
||||||
|
```
|
||||||
|
|
||||||
|
**Result**: 100% accuracy, compute density ~37%, ~52s elapsed
|
||||||
|
|
||||||
|
### 4. Multi-Dataset, Multi-Sample Test
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--data-dir tests/data/ruler_32k \
|
||||||
|
--datasets niah_single_1,qa_1 \
|
||||||
|
--num-samples 2 \
|
||||||
|
--max-model-len 40960 \
|
||||||
|
--enable-offload \
|
||||||
|
--sparse-policy XATTN_BSA
|
||||||
|
```
|
||||||
|
|
||||||
|
**Result**: 4/4 (100%), ~71s elapsed
|
||||||
|
|
||||||
|
### 5. Testing Specific Sample Indices
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--data-dir tests/data/ruler_32k \
|
||||||
|
--datasets niah_single_1 \
|
||||||
|
--sample-indices 0,5,10 \
|
||||||
|
--max-model-len 40960 \
|
||||||
|
--enable-offload
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. JSON Output Mode (for scripting)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--data-dir tests/data/ruler_32k \
|
||||||
|
--datasets niah_single_1 \
|
||||||
|
--num-samples 1 \
|
||||||
|
--max-model-len 40960 \
|
||||||
|
--enable-offload \
|
||||||
|
--json-output
|
||||||
|
```
|
||||||
|
|
||||||
|
**Output format**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"total_correct": 1,
|
||||||
|
"total_samples": 1,
|
||||||
|
"overall_accuracy": 1.0,
|
||||||
|
"avg_score": 1.0,
|
||||||
|
"time": 30.44,
|
||||||
|
"tasks": {"niah_single_1": {"correct": 1, "total": 1, "accuracy": 1.0}},
|
||||||
|
"failed_samples": {}
|
||||||
|
}
|
||||||
|
```
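
A script can consume this output with the standard `json` module. The sketch below is only an illustration: it parses the sample report shown above and lists tasks below an accuracy bar. The helper name `failing_tasks` and the `min_accuracy` cutoff are assumptions for this example, not part of `test_ruler.py`.

```python
import json

# Sample report copied from the output format shown above.
raw = """
{
  "total_correct": 1,
  "total_samples": 1,
  "overall_accuracy": 1.0,
  "avg_score": 1.0,
  "time": 30.44,
  "tasks": {"niah_single_1": {"correct": 1, "total": 1, "accuracy": 1.0}},
  "failed_samples": {}
}
"""


def failing_tasks(report, min_accuracy=0.9):
    """Return task names whose accuracy is below min_accuracy (hypothetical cutoff)."""
    return [name for name, task in report["tasks"].items()
            if task["accuracy"] < min_accuracy]


report = json.loads(raw)
print(f"{report['total_correct']}/{report['total_samples']} correct, "
      f"failing tasks: {failing_tasks(report)}")
```

In a CI-style script, a non-empty list from `failing_tasks` could be turned into a non-zero exit code.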
|
||||||
|
|
||||||
|
### 7. Quiet Mode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--data-dir tests/data/ruler_32k \
|
||||||
|
--datasets niah_single_1 \
|
||||||
|
--num-samples 1 \
|
||||||
|
--max-model-len 40960 \
|
||||||
|
--enable-offload \
|
||||||
|
--quiet
|
||||||
|
```
|
||||||
|
|
||||||
|
### 8. Adjusting the Number of GPU Blocks
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--data-dir tests/data/ruler_32k \
|
||||||
|
--datasets niah_single_1 \
|
||||||
|
--num-samples 1 \
|
||||||
|
--max-model-len 40960 \
|
||||||
|
--enable-offload \
|
||||||
|
--num-gpu-blocks 8 \
|
||||||
|
--sparse-policy XATTN_BSA
|
||||||
|
```
|
||||||
|
|
||||||
|
### 9. GLM-4 Model Test
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--model ~/models/GLM-4-9B-Chat-1M \
|
||||||
|
--data-dir tests/data/ruler_32k \
|
||||||
|
--datasets niah_single_1 \
|
||||||
|
--num-samples 1 \
|
||||||
|
--max-model-len 40960 \
|
||||||
|
--enable-offload \
|
||||||
|
--dtype bfloat16
|
||||||
|
```
|
||||||
|
|
||||||
|
**Result**: 100% accuracy, ~17s elapsed
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Data Directory Structure

```
tests/data/
├── ruler_4k/      # 4K context
├── ruler_8k/      # 8K context
├── ruler_16k/     # 16K context
├── ruler_32k/     # 32K context (recommended for testing)
├── ruler_64k/     # 64K context
├── ruler_128k/    # 128K context
├── ruler_256k/    # 256K context
├── ruler_512k/    # 512K context
├── ruler_768k/    # 768K context
└── ruler_1m/      # 1M context
```

Each directory contains 13 task subdirectories, and each task has a `validation.jsonl` file.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## GPU and Mode Selection

| GPU Memory | Recommended Mode | Notes |
|---------|---------|------|
| 24GB (3090/4090) | `--enable-offload` | Offload is required |
| 40GB+ (A100) | Either mode | GPU-only can be tested |

**RTX 3090 limitation**: because of limited GPU memory, the `--enable-offload` flag is required.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## max-model-len Guidelines

| Data Directory | Recommended max-model-len | Notes |
|---------|-------------------|------|
| ruler_4k | 5000 | Leaves room for output |
| ruler_8k | 9000 | |
| ruler_16k | 17000 | |
| ruler_32k | 40960 | |
| ruler_64k | 72000 | |
| ruler_128k | 135000 | |

**Formula**: `max_model_len >= max_input_len + max_new_tokens`
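
As a quick sanity check of the formula, the sketch below tests a candidate `--max-model-len` against an input length. The helper `fits` is hypothetical, and the nominal input lengths (for example 32768 tokens for `ruler_32k`) are assumptions; the real token count depends on the tokenizer and the sample.

```python
def fits(max_model_len, max_input_len, max_new_tokens=16):
    """True if the context window can hold the prompt plus the generated tokens.

    max_new_tokens defaults to 16, the documented default of --max-new-tokens.
    """
    return max_model_len >= max_input_len + max_new_tokens


# Nominal 32K prompt with the recommended setting from the table above.
print(fits(40960, 32768))
# A window exactly equal to the prompt length leaves no room for output.
print(fits(32768, 32768))
```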
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## DensityObserver Output

Enabled automatically when `--sparse-policy XATTN_BSA` is used. Example output:
|
||||||
|
|
||||||
|
```
|
||||||
|
============================================================
|
||||||
|
Density Statistics (XAttention BSA)
|
||||||
|
============================================================
|
||||||
|
[DensityObserver] Mode: offload
|
||||||
|
Compute density: 0.3691 (min: 0.3691 @ layer 0)
|
||||||
|
Comm density: 1.0000 (CPU block granularity)
|
||||||
|
Savings ratio: 0.0% H2D transfer reduction
|
||||||
|
Num layers: 1
|
||||||
|
Layer 0 density: 0.369052
|
||||||
|
```
|
||||||
|
|
||||||
|
| Metric | Description |
|------|------|
| Compute density | Compute density at BSA block (128 tokens) granularity |
| Comm density | Communication density at CPU block (4096 tokens) granularity |
| Savings ratio | Fraction of H2D transfers saved |
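
The two granularities explain why compute density can be low while comm density stays at 1.0. The sketch below is a simplified model, not the `DensityObserver` implementation: it assumes a flat list of per-BSA-block selections, treats a CPU block as needing transfer if any of its BSA blocks is selected, and takes the savings ratio to be `1 - comm density`. All three are assumptions made for illustration.

```python
def densities(selected, bsa_block=128, cpu_block=4096):
    """Return (compute_density, comm_density, savings_ratio) for a flat selection mask.

    selected[i] is True when the i-th BSA block is selected.
    """
    per_cpu = cpu_block // bsa_block  # 32 BSA blocks per CPU block
    compute = sum(selected) / len(selected)
    groups = [selected[i:i + per_cpu] for i in range(0, len(selected), per_cpu)]
    # Assumption: a CPU block is transferred if it contains any selected BSA block.
    comm = sum(any(group) for group in groups) / len(groups)
    return compute, comm, 1.0 - comm


# Every 4th BSA block selected across two CPU blocks:
# compute density is 0.25, yet both CPU blocks must still be transferred.
scattered = [i % 4 == 0 for i in range(64)]
print(densities(scattered))
```

Under this model, transfer savings only appear when the selected BSA blocks cluster inside a subset of the CPU blocks.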
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## FAQ

### 1. OOM Errors

**Cause**: insufficient GPU memory
**Fix**:
- Use `--enable-offload`
- Lower `--gpu-utilization`
- Reduce `--num-gpu-blocks`

### 2. Model Fails to Load

**Cause**: incompatible model configuration
**Fix**:
- Check the `--dtype` parameter (GLM models need `--dtype bfloat16`)
- Confirm the model path is correct

### 3. Abnormal Accuracy

**Cause**: state leakage
**Fix**: use the `--fresh-llm` flag to re-initialize the LLM for every sample
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Related Documents

- [`docs/xattn_density_types.md`](xattn_density_types.md) - Compute vs Comm density explained
- [`docs/xattn_density_alignment_verification.md`](xattn_density_alignment_verification.md) - GPU-only vs Offload alignment verification
- [`docs/ruler_benchmark_results_32k.md`](ruler_benchmark_results_32k.md) - RULER 32K benchmark results
|
||||||
429
docs/xattn_bsa_policy_design.md
Normal file
@@ -0,0 +1,429 @@
|
|||||||
|
# XAttention BSA Policy Design Document

This document describes the design and implementation of `XAttentionBSAPolicy`, a sparse attention policy based on the XAttention algorithm, used for chunked prefill in CPU offload mode.

## Overview

`XAttentionBSAPolicy` implements block-level sparse attention selection based on XAttention. The core idea is:

1. **Estimation phase**: use the XAttention kernels to quickly estimate the importance of each KV block
2. **Selection phase**: select the important blocks using a threshold and majority voting
3. **Computation phase**: load only the selected blocks for the attention computation
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────────┐
|
||||||
|
│ XAttention BSA Policy │
|
||||||
|
├─────────────────────────────────────────────────────────────┤
|
||||||
|
│ select_blocks() │
|
||||||
|
│ ┌─────────────┐ ┌──────────────────┐ ┌──────────────┐ │
|
||||||
|
│ │ Load K │──>│ flat_group_gemm │──>│ softmax_fuse │ │
|
||||||
|
│ │ blocks │ │ _fuse_reshape │ │ _block_sum │ │
|
||||||
|
│ └─────────────┘ └──────────────────┘ └──────────────┘ │
|
||||||
|
│ │ │ │ │
|
||||||
|
│ v v v │
|
||||||
|
│ ┌─────────────┐ ┌──────────────────┐ ┌──────────────┐ │
|
||||||
|
│ │ K: [B,H,L,D]│ │ attn_scores: │ │ block_sums: │ │
|
||||||
|
│ │ │ │ [B,H,Q/s,K/s] │ │ [B,H,Qb,Kb] │ │
|
||||||
|
│ └─────────────┘ └──────────────────┘ └──────────────┘ │
|
||||||
|
│ │ │
|
||||||
|
│ ┌──────────────────────┘ │
|
||||||
|
│ v │
|
||||||
|
│ ┌──────────────┐ │
|
||||||
|
│ │find_blocks │ │
|
||||||
|
│ │_chunked │ │
|
||||||
|
│ └──────────────┘ │
|
||||||
|
│ │ │
|
||||||
|
│ v │
|
||||||
|
│ ┌──────────────┐ │
|
||||||
|
│ │ GQA-aware │ │
|
||||||
|
│ │ aggregation │ │
|
||||||
|
│ │ + majority │ │
|
||||||
|
│ │ voting │ │
|
||||||
|
│ └──────────────┘ │
|
||||||
|
│ │ │
|
||||||
|
│ v │
|
||||||
|
│ selected_block_ids │
|
||||||
|
├─────────────────────────────────────────────────────────────┤
|
||||||
|
│ compute_chunked_prefill() │
|
||||||
|
│ ┌─────────────┐ ┌──────────────────┐ ┌──────────────┐ │
|
||||||
|
│ │ Ring buffer │──>│ flash_attn_ │──>│ merge_ │ │
|
||||||
|
│ │ pipeline │ │ with_lse │ │ attention │ │
|
||||||
|
│ └─────────────┘ └──────────────────┘ └──────────────┘ │
|
||||||
|
└─────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
## File Locations

**Main file**: `nanovllm/kvcache/sparse/xattn_bsa.py`

**XAttention kernels it depends on**: `nanovllm/ops/xattn.py`
- `flat_group_gemm_fuse_reshape`: computes the attention scores after the stride reshape
- `softmax_fuse_block_sum`: applies softmax to the attention scores and sums them per block
- `find_blocks_chunked`: selects blocks based on a threshold
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Core Algorithms

### 1. select_blocks: Block Selection Algorithm
|
||||||
|
|
||||||
|
```python
|
||||||
|
def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]:
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 1: Load K blocks and compute attention scores

For each CPU block, load K onto the GPU and compute the scores with `flat_group_gemm_fuse_reshape`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
for cpu_block_id in available_blocks:
    # Load the K block: [1, block_size, num_kv_heads, head_dim]
    offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
    k_block, _ = offload_engine.get_kv_for_slot(slot)

    # Convert to [batch, heads, k_len, head_dim]
    K_chunk = k_block.transpose(1, 2)

    # GQA: expand the K heads to match the Q heads
    if num_heads != num_kv_heads:
        K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)

    # Compute attention scores
    attn_chunk = flat_group_gemm_fuse_reshape(Q, K_chunk, stride, ...)
    attn_scores_list.append(attn_chunk)

# Concatenate all K chunks: [1, heads, q_reshaped_len, total_k_reshaped_len]
attn_scores = torch.cat(attn_scores_list, dim=-1)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 2: Aggregate to block level

Use `softmax_fuse_block_sum` to aggregate the attention scores to block level:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# reshaped_block_size = block_size / stride = 1024 / 8 = 128
|
||||||
|
block_sums = softmax_fuse_block_sum(
|
||||||
|
attn_scores,
|
||||||
|
    reshaped_block_size,  # maps 1:1 to CPU blocks
|
||||||
|
segment_size,
|
||||||
|
chunk_start=0,
|
||||||
|
chunk_end=q_reshaped_len,
|
||||||
|
real_q_len=q_reshaped_len,
|
||||||
|
scale=scale,
|
||||||
|
is_causal=False,
|
||||||
|
)
|
||||||
|
# block_sums: [batch, heads, q_blocks, k_blocks]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key point**: `reshaped_block_size` must be aligned with the CPU block so that the `k_blocks` dimension of the output maps 1:1 to `available_blocks`.
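
The alignment requirement can be checked with plain integer arithmetic. The sketch below uses the values from the comment in the code above (`block_size = 1024`, `stride = 8`); the helper `num_k_blocks` and the example block ids are hypothetical and only illustrate the bookkeeping.

```python
from typing import Optional


def num_k_blocks(num_cpu_blocks: int, block_size: int = 1024, stride: int = 8,
                 reshaped_block_size: Optional[int] = None) -> int:
    """Number of k_blocks produced when block sums use the given reshaped_block_size."""
    if reshaped_block_size is None:
        reshaped_block_size = block_size // stride  # 1024 / 8 = 128
    # Each CPU block contributes block_size / stride columns after the stride reshape.
    total_k_reshaped_len = num_cpu_blocks * block_size // stride
    return total_k_reshaped_len // reshaped_block_size


available_blocks = [3, 7, 9, 12]  # hypothetical CPU block ids
# Aligned: one k_block per CPU block, so indices map straight back to available_blocks.
print(num_k_blocks(len(available_blocks)))
# Misaligned (half-size blocks): twice as many k_blocks, so the 1:1 mapping is lost.
print(num_k_blocks(len(available_blocks), reshaped_block_size=64))
```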
|
||||||
|
|
||||||
|
#### Step 3: Threshold selection

Use `find_blocks_chunked` to select blocks based on a cumulative attention threshold:
|
||||||
|
|
||||||
|
```python
|
||||||
|
mask = find_blocks_chunked(
|
||||||
|
block_sums,
|
||||||
|
current_index=0,
|
||||||
|
threshold=self.threshold, # e.g., 0.95
|
||||||
|
num_to_choose=None,
|
||||||
|
decoding=False,
|
||||||
|
mode="prefill",
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
# mask: [batch, num_heads, q_blocks, k_blocks] - boolean
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 4: GQA-aware Aggregation + Majority Voting

```python
# GQA: within one KV head group, a block is selected if any Q head selects it
if num_groups > 1:
    mask_gqa = mask.view(batch_size, num_kv_heads, num_groups, q_blocks, k_blocks)
    mask_per_kv_head = mask_gqa.any(dim=2)  # [batch, num_kv_heads, q_blocks, k_blocks]
else:
    # No GQA grouping: every Q head already maps to its own KV head
    mask_per_kv_head = mask

# Majority voting: vote across KV heads and q_blocks
vote_count = mask_per_kv_head[0].float().sum(dim=0).sum(dim=0)  # [k_blocks]
total_votes = num_kv_heads * q_blocks
vote_ratio = vote_count / total_votes

# Select blocks with more than 50% of the votes
vote_threshold = 0.5
block_selected = vote_ratio > vote_threshold
selected_block_ids = [available_blocks[i] for i, sel in enumerate(block_selected.tolist()) if sel]

# Safety measure: always include the first (sink) block and the last block
if available_blocks[0] not in selected_block_ids:
    selected_block_ids.insert(0, available_blocks[0])
if available_blocks[-1] not in selected_block_ids:
    selected_block_ids.append(available_blocks[-1])
```
|
||||||
|
|
||||||
|
**Why majority voting?**

| Aggregation | Problem |
|---------|------|
| `any()` across all heads | Density approaches 100%, losing sparsity |
| `all()` | Too aggressive; important blocks may be lost |
| **Majority voting (>50%)** | Balances sparsity and accuracy |

Experimental results show:
- Per-head density: 20-35%
- After `any()` aggregation: ~100%
- **After majority voting: ~45%**
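
The voting rule can be illustrated without tensors. The sketch below is a pure-Python toy that mirrors the logic of Step 4 on nested lists; the function `majority_vote`, the mask values, and the block ids are made up for the example and are not taken from `xattn_bsa.py`.

```python
def majority_vote(mask, available_blocks, vote_threshold=0.5):
    """Select blocks chosen by more than vote_threshold of (kv_head, q_block) voters.

    mask[h][q][k] is True when KV head h, query block q selects key block k.
    """
    num_kv_heads, q_blocks, k_blocks = len(mask), len(mask[0]), len(mask[0][0])
    total_votes = num_kv_heads * q_blocks
    selected = []
    for k in range(k_blocks):
        votes = sum(mask[h][q][k] for h in range(num_kv_heads) for q in range(q_blocks))
        if votes / total_votes > vote_threshold:
            selected.append(available_blocks[k])
    # Always keep the first (sink) block and the last block.
    if available_blocks[0] not in selected:
        selected.insert(0, available_blocks[0])
    if available_blocks[-1] not in selected:
        selected.append(available_blocks[-1])
    return selected


# Two KV heads, one query block, four key blocks.
toy_mask = [
    [[False, True, True, False]],   # head 0
    [[False, True, False, False]],  # head 1
]
print(majority_vote(toy_mask, [10, 11, 12, 13]))  # [10, 11, 13]
```

Block 12 receives exactly half of the votes and is dropped because the comparison is strict (`>`), while blocks 10 and 13 are added back by the sink/last-block rule.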
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. compute_chunked_prefill: Attention Computation

Reuses the ring buffer pipeline implementation of `FullAttentionPolicy`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def compute_chunked_prefill(self, q, k, v, layer_id, softmax_scale,
|
||||||
|
offload_engine, kvcache_manager,
|
||||||
|
current_chunk_idx, seq, num_tokens,
|
||||||
|
selected_blocks) -> torch.Tensor:
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Computation Flow

1. **Load historical blocks** (using selected_blocks):
|
||||||
|
```python
|
||||||
|
for block_idx in range(num_blocks):
|
||||||
|
# Ring buffer pipeline: load -> wait -> compute -> next
|
||||||
|
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
|
||||||
|
offload_engine.wait_slot_layer(slot)
|
||||||
|
|
||||||
|
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
|
||||||
|
prev_o, prev_lse = flash_attn_with_lse(q, prev_k, prev_v, causal=False)
|
||||||
|
|
||||||
|
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Compute the current chunk** (causal mask):
|
||||||
|
```python
|
||||||
|
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
|
||||||
|
current_o, current_lse = flash_attn_with_lse(q, k_curr, v_curr, causal=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Merge the results**:
|
||||||
|
```python
|
||||||
|
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
|
||||||
|
```
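
Merging partial attention outputs through their log-sum-exp (LSE) values is what makes the block-by-block pipeline exact. The sketch below is a scalar toy version of that idea, assuming one query and one-dimensional keys and values; `attend` and `merge` are illustrative stand-ins and do not reproduce the signatures of `flash_attn_with_lse` or `merge_attention_outputs`.

```python
import math


def attend(q, keys, values):
    """Toy softmax attention for one scalar query: returns (output, log-sum-exp)."""
    scores = [q * k for k in keys]
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    out = sum(math.exp(s - lse) * v for s, v in zip(scores, values))
    return out, lse


def merge(o1, lse1, o2, lse2):
    """Combine two partial attention results computed over disjoint key sets."""
    m = max(lse1, lse2)
    w1, w2 = math.exp(lse1 - m), math.exp(lse2 - m)
    return (w1 * o1 + w2 * o2) / (w1 + w2), m + math.log(w1 + w2)


q = 0.7
keys, values = [0.1, -0.4, 1.2, 0.5], [1.0, 2.0, 3.0, 4.0]

full_o, full_lse = attend(q, keys, values)
o_a, lse_a = attend(q, keys[:2], values[:2])  # "historical" block
o_b, lse_b = attend(q, keys[2:], values[2:])  # "current" chunk
merged_o, merged_lse = merge(o_a, lse_a, o_b, lse_b)

print(math.isclose(full_o, merged_o), math.isclose(full_lse, merged_lse))
```

Because each partial output is weighted by the total softmax mass of its own keys, the merged result matches attention over all keys at once, regardless of how the keys are split into blocks.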
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Parameter Configuration

| Parameter | Default | Description |
|------|--------|------|
| `threshold` | 0.95 | Cumulative attention threshold (tau); higher is more conservative |
| `stride` | 8 | XAttention stride reshape parameter |
| `chunk_size` | 16384 | Processing chunk size during estimation |
| `block_size` | 128 | BSA block size (fixed) |
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Set in the config
config.sparse_policy = SparsePolicyType.XATTN_BSA
config.sparse_threshold = 0.95

# Or via the command line
python tests/test_needle.py \
    --enable-offload \
    --enable-xattn-bsa \
    --sparse-threshold 9  # divided by 10, giving 0.9
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Performance Characteristics

| Property | Description |
|------|------|
| **Prefill support** | ✅ Fully supported |
| **Decode support** | ❌ Not supported (uses FullAttentionPolicy) |
| **Sparsity** | ~45-55% (threshold=0.95, majority voting) |
| **Accuracy** | Passes RULER NIAH at 100% |

### Limitations

1. **No decode support**: XAttention estimation needs a sufficiently long Q sequence, so it does not apply to single-token decode
2. **Estimation overhead**: `select_blocks` has to load all K blocks for estimation
3. **Triton alignment**: Q/K lengths must satisfy the `stride * BLOCK_M/N` alignment requirement
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Comparison with Other Policies

| Policy | select_blocks | Sparsity | Decode Support |
|--------|--------------|--------|-------------|
| FullAttentionPolicy | Returns all blocks | 0% | ✅ |
| QuestPolicy | Based on min/max key | ~50% | ✅ |
| **XAttentionBSAPolicy** | XAttention + majority voting | ~45-55% | ❌ |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Testing and Verification
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Needle test (32K)
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python tests/test_needle.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--enable-offload \
|
||||||
|
--enable-xattn-bsa \
|
||||||
|
--input-len 32768
|
||||||
|
|
||||||
|
# RULER benchmark
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python tests/test_ruler.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--enable-offload \
|
||||||
|
--sparse-policy XATTN_BSA \
|
||||||
|
--sparse-threshold 0.95 \
|
||||||
|
--data-dir tests/data/ruler_niah
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Performance Benchmarks

### 128K Context Comparison (Llama-3.1-8B, A100 80GB)

| Policy | Density | Time | Peak Memory | Accuracy |
|--------|---------|------|---------|--------|
| **Full** | 100% | 120.9s | 16.4GB (stable) | 100% |
| **XAttn BSA** | ~52% | 152.3s | 19.8GB | 100% |
|
||||||
|
|
||||||
|
### Density Trend

| Chunk | Full | XAttn BSA |
|-------|------|-----------|
| 10 | 100% | 90% |
| 30 | 100% | 73% |
| 60 | 100% | 50% |
| 100 | 100% | 50% |
| 126 | 100% | 52% |

**Observation**: the density of XAttn BSA falls as the number of chunks grows and eventually stabilizes at ~50%.
|
||||||
|
|
||||||
|
### Performance Analysis

**Current problem**: although XAttn BSA has a density of only ~52%, it takes longer than Full (152s vs 121s).

**Cause**: `select_blocks` has to load all K blocks to estimate the attention scores, so blocks end up being loaded twice:
1. Estimation phase: load K to compute attention scores
2. Computation phase: load the selected K/V for the actual computation
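
A back-of-the-envelope model shows why ~52% density does not translate into less transfer. The sketch below is a simplification under stated assumptions: it counts only host-to-device volume, normalizes Full attention to 1.0, and treats `estimate_cost` (the fraction of a full K/V block moved per block during estimation) as a free parameter, since whether estimation moves K only or both K and V is not stated here. It ignores kernel time and overlap, so it does not attempt to predict the measured 152s.

```python
def relative_h2d_volume(density, estimate_cost=0.5):
    """H2D volume relative to Full attention (1.0) under a simple two-pass model.

    density:       fraction of blocks loaded in the computation phase
    estimate_cost: fraction of a K/V block transferred per block while estimating
                   (0.5 if only K moves, 1.0 if both K and V move) - an assumption
    """
    return estimate_cost + density


def break_even_density(estimate_cost=0.5):
    """Density below which the two-pass scheme moves less data than Full."""
    return 1.0 - estimate_cost


print(relative_h2d_volume(0.52))        # K-only estimation
print(relative_h2d_volume(0.52, 1.0))   # K and V moved during estimation
print(break_even_density())
```

Under this model the measured ~52% density sits right at or above the break-even point, which is consistent with the optimization directions listed next (shared, sampled, or cached estimation).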
|
||||||
|
|
||||||
|
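**Optimization directions**:
1. Share estimation results across layers (estimate on layer 0, reuse in the other layers)
2. Sampled estimation (estimate using only a subset of the K blocks)
3. Cache estimation results to avoid recomputation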
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Memory Management

### Memory Leak (fixed)

**Problem**: during 128K prefill, GPU memory grew from 16GB to 80GB.

**Root cause**:
|
||||||
|
```python
|
||||||
|
# Problematic code: stored cumulatively but never used
|
||||||
|
self.sparse_metadata[layer_id] = attn_scores
|
||||||
|
```
|
||||||
|
|
||||||
|
Every layer of every chunk stored `attn_scores`, so memory kept growing.

**Fix**:
|
||||||
|
```python
|
||||||
|
# 1. Remove the unused sparse_metadata storage

# 2. Release intermediate variables immediately
del attn_scores_list
del attn_scores, block_sums, mask, mask_per_kv_head, vote_count, vote_ratio, block_selected
|
||||||
|
```
|
||||||
|
|
||||||
|
**Effect of the fix**:

| Version | Memory Growth | Peak |
|------|---------|------|
| Before fix | +64GB | 80GB |
| **After fix** | +4GB | 19.8GB |
|
||||||
|
|
||||||
|
### Memory Monitoring

Use the `gpu-monitor` agent to monitor memory:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start monitoring
# In Claude Code, use the Task tool to launch the gpu-monitor agent

# Or monitor manually
|
||||||
|
watch -n 1 'nvidia-smi --query-gpu=memory.used --format=csv,noheader -i 0'
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Density Statistics API

### Enabling Statistics
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Statistics are updated automatically in select_blocks (layer 0 only)
# Per-chunk density is emitted via logger.debug
|
||||||
|
```
|
||||||
|
|
||||||
|
### Retrieving Statistics
|
||||||
|
|
||||||
|
```python
|
||||||
|
policy = XAttentionBSAPolicy(threshold=0.95)

# After running prefill...

# Get the statistics
stats = policy.get_density_stats()
# {
#     "total_available_blocks": 8001,
#     "total_selected_blocks": 4160,
#     "num_chunks": 126,
#     "overall_density": 0.52
# }

# Print the statistics
policy.print_density_stats()

# Reset the statistics
policy.reset_stats()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Enabling DEBUG Logging
|
||||||
|
|
||||||
|
```python
|
||||||
|
# In test_ruler.py
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"

# Example output:
|
||||||
|
# [XAttn] chunk=30, available=30, selected=22, chunk_density=73.3%
|
||||||
|
```
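
These debug lines are easy to post-process. The sketch below parses the sample line shown above with a regular expression and recomputes the density from the block counts. The pattern is derived only from that one sample, so the real log format may differ; `parse_xattn_line` is a hypothetical helper.

```python
import re

# Pattern inferred from the single sample line above (an assumption).
XATTN_LINE = re.compile(
    r"\[XAttn\] chunk=(\d+), available=(\d+), selected=(\d+), chunk_density=([\d.]+)%"
)


def parse_xattn_line(line):
    """Return chunk statistics from one debug line, or None if it does not match."""
    match = XATTN_LINE.search(line)
    if match is None:
        return None
    chunk, available, selected = (int(g) for g in match.groups()[:3])
    return {
        "chunk": chunk,
        "available": available,
        "selected": selected,
        "density": selected / available,  # recomputed from the block counts
    }


sample = "[XAttn] chunk=30, available=30, selected=22, chunk_density=73.3%"
stats = parse_xattn_line(sample)
print(stats["chunk"], f"{stats['density']:.1%}")
```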
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Known Issues

| Issue | Status | Notes |
|------|------|------|
| Estimation overhead too high | 🟡 To be optimized | select_blocks has to load all K blocks |
| Slower than Full | 🟡 To be optimized | 152s vs 121s in the 128K scenario |
| Small memory growth | 🟢 Acceptable | ~4GB, possibly from the Triton cache |
| No decode support | ✅ By design | Uses FullAttentionPolicy |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Related Documents

- [`docs/xattention_algorithm_guide.md`](xattention_algorithm_guide.md): XAttention algorithm in detail
- [`docs/xattn_kernels_guide.md`](xattn_kernels_guide.md): Triton kernels implementation
- [`docs/sparse_policy_architecture.md`](sparse_policy_architecture.md): SparsePolicy architecture
- [`docs/sparse_policy_implementation_guide.md`](sparse_policy_implementation_guide.md): implementation guide
|
||||||
99
docs/xattn_chunked_prefill.md
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
# XAttention Chunked Prefill

## Overview

`xattn_estimate_chunked` provides chunked prefill support for XAttention. It lets a long sequence be processed in chunks, which suits scenarios where GPU memory is limited or prefill must be interleaved with decode requests.

## Core Design

### Chunked Prefill Mode
|
||||||
|
|
||||||
|
```
|
||||||
|
Full Prefill: Q[0:N] × K[0:N] → Output[0:N]
|
||||||
|
|
||||||
|
Chunked Prefill: Q[0:C] × K[0:C] → Output[0:C]
|
||||||
|
Q[C:2C] × K[0:2C] → Output[C:2C]
|
||||||
|
Q[2C:3C] × K[0:3C] → Output[2C:3C]
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
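Key properties:
- **Q is processed in chunks**: only one Q chunk is processed at a time
- **K/V accumulate**: the K/V cache grows step by step as chunks are processed
- **Position awareness**: the `q_start_pos` parameter carries the position of the current chunk in the original sequence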
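
The chunking pattern can be enumerated with a few lines of Python. The sketch below is illustrative only: `chunk_plan` is a hypothetical helper that lists, for each chunk, the Q range and the length of the accumulated K it attends to, with the Q start playing the role of `q_start_pos`.

```python
def chunk_plan(seq_len, chunk_size):
    """List (q_start, q_end, k_len) per chunk: Q[q_start:q_end] attends to K[0:k_len]."""
    plan = []
    for q_start in range(0, seq_len, chunk_size):
        q_end = min(q_start + chunk_size, seq_len)
        plan.append((q_start, q_end, q_end))  # K has accumulated up to q_end
    return plan


for q_start, q_end, k_len in chunk_plan(seq_len=10, chunk_size=4):
    print(f"Q[{q_start}:{q_end}] x K[0:{k_len}]  (q_start_pos={q_start})")
```

The final chunk may be shorter than `chunk_size`, which is why the last Q range is clipped to the sequence length.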
|
||||||
|
|
||||||
|
## API
|
||||||
|
|
||||||
|
### xattn_estimate_chunked
|
||||||
|
|
||||||
|
```python
|
||||||
|
def xattn_estimate_chunked(
    query_states: torch.Tensor,    # (B, H, q_chunk_len, D) - current Q chunk
    key_states: torch.Tensor,      # (B, H, k_len, D) - the full accumulated K
    q_start_pos: int,              # start position of the current chunk in the original sequence
    block_size: int = 128,         # block size for sparse attention
    stride: int = 8,               # downsampling stride during estimation
    threshold: float = 0.9,        # block selection threshold
    chunk_size: int = 16384,       # Triton kernel alignment size
    use_triton: bool = True,
    causal: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns:
        attn_sums: (B, H, q_blocks, k_blocks) - attention score of each block
        simple_mask: (B, H, q_blocks, k_blocks) - mask of the selected blocks
    """
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage

### External Chunking (recommended for production deployment)

The LLM framework controls how chunks are split:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Inside the attention forward
def forward(self, query, key, value, position_ids, kv_cache, ...):
    q_start_pos = position_ids[0].item()

    # Estimate the sparse pattern
    attn_sum, mask = xattn_estimate_chunked(
        query, kv_cache.key,
        q_start_pos=q_start_pos,
        block_size=128,
        stride=4,
        threshold=0.9,
        chunk_size=4096,  # must match the external chunk size
    )

    # Run sparse attention with the mask
    ...
|
||||||
|
```
|
||||||
|
|
||||||
|
### Consistency Requirements

**Important**: for the chunked and standard versions to agree 100%, you must:

1. Use the **same `chunk_size`** parameter for the standard and chunked versions
2. For example: `xattn_estimate(..., chunk_size=4096)` and `xattn_estimate_chunked(..., chunk_size=4096)`
|
||||||
|
|
||||||
|
## Relationship to the Standard Version

| Function | Purpose |
|------|------|
| `xattn_estimate` | Pattern estimation for full prefill |
| `xattn_estimate_chunked` | Pattern estimation for chunked prefill |

**Consistency guarantee**: when the `chunk_size` parameters match, `xattn_estimate_chunked` produces **exactly the same** mask as `xattn_estimate`.
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_xattn_estimate_chunked.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Verification Results

Tested with real QKV data (sequence lengths 8K-64K):
- Every chunk_size (2048, 4096, 8192) reached a 100% match
|
||||||
142
docs/xattn_density_alignment_verification.md
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
# XAttention Density Alignment Verification

Verifies how well density aligns between GPU-only and Offload modes.

**Test date**: 2026-02-05
**Test model**: Llama-3.1-8B-Instruct
**Test task**: RULER niah_single_1
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Test Configuration

| Parameter | Value |
|------|-----|
| sparse_policy | XATTN_BSA |
| threshold | 0.9 |
| chunk_size | 4096 (aligned) |
| stride | 8 |
| BSA block_size | 128 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Test Results

### 32K Context

| Mode | Layer 0 Density | Overall Density | Accuracy |
|------|-----------------|-----------------|--------|
| GPU-only | 0.502079 | 0.4012 | 100% |
| Offload | 0.498421 | 0.4984 | 100% |
| **Difference** | **0.37%** | - | - |

### 64K Context

| Mode | Layer 0 Density | Overall Density | Accuracy |
|------|-----------------|-----------------|--------|
| GPU-only | 0.369972 | 0.2963 | 100% |
| Offload | 0.369052 | 0.3691 | 100% |
| **Difference** | **0.09%** | - | - |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Key Fixes

### Commit 829b311 - chunk_size alignment + stream synchronization fix

**Problem**: previously the density of GPU-only and Offload modes differed by 10-13%

**Root cause**:
1. GPU-only used `chunk_size=16384` while Offload used `chunk_size=4096`
2. A stream synchronization bug made the K data inconsistent between Pass 1 and Pass 2

**Fix**:
1. Changed the default `XAttentionBSAPolicy.chunk_size` from 16384 to 4096
2. Wrapped all compute kernels in the `compute_stream` context
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Test Commands

### GPU-only Mode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--data-dir tests/data/ruler_32k \
|
||||||
|
--datasets niah_single_1 \
|
||||||
|
--num-samples 1 \
|
||||||
|
--max-model-len 40960 \
|
||||||
|
--sparse-policy XATTN_BSA \
|
||||||
|
--sparse-threshold 0.9
|
||||||
|
```
|
||||||
|
|
||||||
|
### Offload 模式
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_ruler.py \
|
||||||
|
--model ~/models/Llama-3.1-8B-Instruct \
|
||||||
|
--data-dir tests/data/ruler_32k \
|
||||||
|
--datasets niah_single_1 \
|
||||||
|
--num-samples 1 \
|
||||||
|
--max-model-len 40960 \
|
||||||
|
--enable-offload \
|
||||||
|
--sparse-policy XATTN_BSA \
|
||||||
|
--sparse-threshold 0.9
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Detailed logs

### 32K Offload mode: per-chunk density

```
Layer0 chunk: q_len=4096, k_len=4096, density=0.6234
Layer0 chunk: q_len=4096, k_len=8192, density=0.6239
Layer0 chunk: q_len=4096, k_len=12288, density=0.6026
Layer0 chunk: q_len=4096, k_len=16384, density=0.5695
Layer0 chunk: q_len=4096, k_len=20480, density=0.5285
Layer0 chunk: q_len=4096, k_len=24576, density=0.4891
Layer0 chunk: q_len=4096, k_len=28672, density=0.4514
Layer0 chunk: q_len=3813, k_len=32485, density=0.4208
```
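The per-chunk densities in the 32K log can be folded back into the Layer 0 figure from the results table. The following is a minimal sketch, assuming each chunk is weighted by its number of causal BSA block pairs (block size 128, lengths rounded up to whole blocks). The helper names `causal_blocks` and `aggregate` are illustrative and not part of nano-vllm.

```python
import math

BLOCK = 128  # BSA block size from the test configuration


def causal_blocks(q_len, k_len, block=BLOCK):
    """Causal (q_block, k_block) pairs for a Q chunk whose keys end at k_len."""
    q_blocks = math.ceil(q_len / block)
    k_blocks = math.ceil(k_len / block)
    offset = k_blocks - q_blocks  # K blocks that precede this Q chunk
    return sum(offset + i + 1 for i in range(q_blocks))


# (q_len, k_len, density) copied from the 32K Offload log
CHUNKS_32K = [
    (4096, 4096, 0.6234), (4096, 8192, 0.6239), (4096, 12288, 0.6026),
    (4096, 16384, 0.5695), (4096, 20480, 0.5285), (4096, 24576, 0.4891),
    (4096, 28672, 0.4514), (3813, 32485, 0.4208),
]


def aggregate(chunks):
    """Weight each chunk density by its causal block count."""
    weights = [causal_blocks(q, k) for q, k, _ in chunks]
    selected = sum(w * d for w, (_, _, d) in zip(weights, chunks))
    return selected / sum(weights), sum(weights)


density, blocks_per_head = aggregate(CHUNKS_32K)
print(f"{density:.4f}", blocks_per_head)  # 0.4984 32385
```

Under these assumptions the weighted mean reproduces the 0.4984 Layer 0 density reported for Offload at 32K, and 32385 causal block pairs per head times 32 heads gives 1,036,320 block pairs in total.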
|
||||||
|
|
||||||
|
### 64K Offload mode: per-chunk density

```
Layer0 chunk: q_len=4096, k_len=4096, density=0.6234
Layer0 chunk: q_len=4096, k_len=8192, density=0.6239
Layer0 chunk: q_len=4096, k_len=12288, density=0.6026
Layer0 chunk: q_len=4096, k_len=16384, density=0.5681
Layer0 chunk: q_len=4096, k_len=20480, density=0.5255
Layer0 chunk: q_len=4096, k_len=24576, density=0.4859
Layer0 chunk: q_len=4096, k_len=28672, density=0.4485
Layer0 chunk: q_len=4096, k_len=32768, density=0.4161
Layer0 chunk: q_len=4096, k_len=36864, density=0.3892
Layer0 chunk: q_len=4096, k_len=40960, density=0.3658
Layer0 chunk: q_len=4096, k_len=45056, density=0.3464
Layer0 chunk: q_len=4096, k_len=49152, density=0.3303
Layer0 chunk: q_len=4096, k_len=53248, density=0.3170
Layer0 chunk: q_len=4096, k_len=57344, density=0.3068
Layer0 chunk: q_len=4096, k_len=61440, density=0.2988
Layer0 chunk: q_len=3451, k_len=64891, density=0.2947
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Conclusions

1. **Density is aligned**: the gap dropped from 10-13% to <0.5%
2. **Accuracy is consistent**: both modes reach 100% accuracy
3. **Density decreases as context grows**: as expected, longer contexts are sparser
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Related documents

- [`docs/xattn_offload_stream_sync_fix.md`](xattn_offload_stream_sync_fix.md) - details of the stream synchronization fix
- [`docs/xattn_density_types.md`](xattn_density_types.md) - Compute vs Comm density
- [`docs/gpuonly_density_alignment_test.md`](gpuonly_density_alignment_test.md) - earlier alignment test
|
||||||
195
docs/xattn_density_benchmark.md
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
# XAttention Density Benchmark

Density results for XAttention Block Sparse Attention in GPU-only mode.

## Test configuration

| Parameter | Value | Notes |
|------|-----|------|
| Model | Llama-3.1-8B-Instruct | 32 layers, 32 heads, 8 KV heads |
| Block Size | 128 tokens | Fixed requirement of the BSA kernel |
| Threshold | 0.9 / 0.95 | Cumulative attention threshold |
| Stride | 4 / 8 / 16 | Q/K downsampling stride |
| Dataset | RULER niah_single_1 | Sample 0 |
| Mode | GPU-only | No CPU offload |
|
||||||
|
|
||||||
|
## Density definition

```python
import torch

# Density = selected_blocks / total_causal_blocks
# Under causal attention, only blocks in the lower-triangular region are counted
# Overall density = mean over all layers

def compute_density(mask, causal=True):
    """
    mask: [batch, heads, q_blocks, k_blocks] boolean tensor
    """
    batch, heads, q_blocks, k_blocks = mask.shape
    if causal:
        # Boolean lower-triangular mask on the same device as `mask`
        causal_mask = torch.tril(
            torch.ones(q_blocks, k_blocks, dtype=torch.bool, device=mask.device)
        )
        total = causal_mask.sum() * batch * heads
        selected = (mask & causal_mask).sum()
    else:
        total = mask.numel()
        selected = mask.sum()
    return selected / total
```
|
||||||
|
|
||||||
|
## Test results

### threshold=0.9

#### Overall density (mean)

| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.5220 (52.2%) | 0.5292 (52.9%) | 0.5430 (54.3%) |
| **8K** | 0.5152 (51.5%) | 0.5252 (52.5%) | 0.5396 (54.0%) |
| **16K** | 0.4682 (46.8%) | 0.4775 (47.8%) | 0.4888 (48.9%) |
| **32K** | 0.3700 (37.0%) | 0.4012 (40.1%) | 0.4196 (42.0%) |

#### Min density (per layer)

| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.2805 (Layer 3) | 0.3132 (Layer 3) | 0.3376 (Layer 5) |
| **8K** | 0.2886 (Layer 5) | 0.2725 (Layer 5) | 0.2995 (Layer 5) |
| **16K** | 0.2247 (Layer 5) | 0.2349 (Layer 5) | 0.2451 (Layer 5) |
| **32K** | 0.1799 (Layer 5) | 0.1846 (Layer 5) | 0.1964 (Layer 5) |
|
||||||
|
|
||||||
|
### threshold=0.95

#### Overall density (mean)

| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.6561 (65.6%) | 0.6699 (67.0%) | 0.6815 (68.2%) |
| **8K** | 0.6462 (64.6%) | 0.6584 (65.8%) | 0.6732 (67.3%) |
| **16K** | 0.6004 (60.0%) | 0.6114 (61.1%) | 0.6193 (61.9%) |
| **32K** | 0.4894 (48.9%) | 0.5203 (52.0%) | 0.5385 (53.9%) |

#### Min density (per layer)

| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.3972 (Layer 3) | 0.4348 (Layer 5) | 0.4517 (Layer 4) |
| **8K** | 0.4004 (Layer 5) | 0.3906 (Layer 5) | 0.4239 (Layer 5) |
| **16K** | 0.3331 (Layer 5) | 0.3453 (Layer 5) | 0.3589 (Layer 5) |
| **32K** | 0.2656 (Layer 5) | 0.2784 (Layer 5) | 0.2917 (Layer 5) |
|
||||||
|
|
||||||
|
### Threshold comparison (stride=8)

| Context | threshold=0.9 | threshold=0.95 | Difference (absolute) |
|---------|---------------|----------------|------|
| **4K** | 0.5292 (52.9%) | 0.6699 (67.0%) | -14.1% |
| **8K** | 0.5252 (52.5%) | 0.6584 (65.8%) | -13.3% |
| **16K** | 0.4775 (47.8%) | 0.6114 (61.1%) | -13.4% |
| **32K** | 0.4012 (40.1%) | 0.5203 (52.0%) | -11.9% |
|
||||||
|
|
||||||
|
## Key findings

### 1. Context length has the largest effect

Density drops markedly as context length grows (threshold=0.9, stride=8):
- 4K: 52.9% density
- 8K: 52.5% density
- 16K: 47.8% density
- 32K: 40.1% density

**Conclusion**: Longer sequences offer more room for sparsification, so XAttention's advantage is more pronounced on long sequences.

### 2. Threshold has a significant effect

With threshold=0.9, density is about 12-14 percentage points lower than with 0.95:
- 0.9 is more aggressive and selects fewer blocks
- 0.95 is more conservative and keeps more blocks
- Accuracy is unaffected in both cases (every RULER NIAH run passes)
|
||||||
|
|
||||||
|
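### 3. Stride has a smaller effect

At a fixed context length, density varies by about 2-5 percentage points across strides:
- A larger stride gives slightly higher density (coarser sampling leads to more conservative selection)
- stride=4 is the most aggressive, stride=16 the most conservative

### 4. Min density concentrates in the early layers (Layers 3-5)

- In most configurations the min density occurs at Layer 5; the exceptions are some 4K runs, where it falls on Layer 3 or 4
- Layers 3-5 sit near the front of the 32-layer stack, so the sparsest layers are early layers rather than the exact middle of the model
- The first and last layers are relatively dense by comparison, which is consistent with commonly reported Transformer attention patterns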
|
||||||
|
|
||||||
|
### 5. Sparsest configuration

32K + stride=4 + threshold=0.9 reaches the lowest density:
- Overall: **37.0%** (63% of the causal attention blocks are skipped)
- Min: **18.0%** (Layer 5)

### 6. Accuracy is stable

RULER NIAH passes (score=1.0) under every configuration tested, which suggests that:
- threshold=0.9 and 0.95 are both conservative enough to preserve accuracy on this task
- The choice of stride does not change the final result
|
||||||
|
|
||||||
|
## Recommended configurations

| Scenario | threshold | stride | Notes |
|------|-----------|--------|------|
| Accuracy first | 0.95 | 8 | Conservative, density ~52-67% |
| Balanced | 0.9 | 8 | Default, density ~40-53% |
| Performance first | 0.9 | 4 | Aggressive, density ~37-52% |
|
||||||
|
|
||||||
|
## Test command

```bash
# Basic test
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
    python tests/test_ruler.py \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --sample-indices 0 \
    --max-model-len 33792 \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9 \
    --sparse-stride 8 \
    --gpu-utilization 0.85

# Flags
#   --sparse-policy XATTN_BSA    enable XAttention Block Sparse Attention
#   --sparse-threshold 0.9       cumulative attention threshold (0.9-0.99)
#   --sparse-stride 8            Q/K downsampling stride (4/8/16)
```
|
||||||
|
|
||||||
|
## Using DensityObserver

```python
from nanovllm.utils.density_observer import DensityObserver

# Enable and reset
DensityObserver.enable()
DensityObserver.complete_reset()

# ... run inference (compute_prefill records automatically) ...

# Fetch the results
summary = DensityObserver.get_summary()
# {
#     "mode": "gpu_only",
#     "overall_density": 0.40,  # mean over all layers
#     "per_layer_density": {0: 0.55, 1: 0.45, ...},
#     "num_layers": 32
# }

# Get the lowest density
min_layer, min_density = DensityObserver.get_min_density()

# Print a summary
DensityObserver.print_summary()
# [DensityObserver] Mode: gpu_only
# Overall density: 0.4012
# Min density: 0.1846 (layer 5)
# Num layers: 32
```
|
||||||
|
|
||||||
|
## Related files

| File | Description |
|------|------|
| `nanovllm/kvcache/sparse/xattn_bsa.py` | XAttention BSA policy implementation |
| `nanovllm/utils/density_observer.py` | Density statistics observer |
| `nanovllm/ops/xattn.py` | Core `xattn_estimate` algorithm |
| `tests/test_ruler.py` | RULER benchmark test script |
|
||||||
152
docs/xattn_density_types.md
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
# XAttention Density Types: Compute vs Communication

XAttention BSA tracks density at two different granularities, and each reflects a different optimization effect.

## Definitions of the two densities

### 1. Compute density

**Granularity**: BSA block (128 tokens)

**Formula**:
```
compute_density = selected_bsa_blocks / total_causal_bsa_blocks
```

**Meaning**: The fraction of blocks in the causal region for which attention actually has to be computed.

**Effect**: Determines how much attention computation is saved.
|
||||||
|
|
||||||
|
### 2. Communication density

**Granularity**: CPU block (4096 tokens = 32 BSA blocks)

**Formula**:
```
comm_density = selected_cpu_blocks / total_cpu_blocks
```

**Meaning**: The fraction of all blocks that must be transferred from CPU to GPU.

**Effect**: Determines how much H2D transfer volume is saved.
|
||||||
|
|
||||||
|
## Why comm density is usually higher than compute density

### Aggregation effect

A CPU block is 32 times coarser than a BSA block, so CPU block selection aggregates with `any()`:

```python
# BSA mask: [B, H, Q_bsa, K_bsa]
# Reshape to CPU block level
mask_per_cpu = mask.view(B, H, Q_bsa, num_cpu_blocks, bsa_per_cpu)
# Any BSA block selected -> whole CPU block needed
cpu_needed = mask_per_cpu.any(dim=-1).any(dim=2).any(dim=1)
```

The whole CPU block has to be transferred as soon as **any one** of the following holds within it:
- some head selects the block, or
- some Q position selects the block, or
- some BSA sub-block is selected
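The aggregation effect can be reproduced on a toy mask. The sketch below is a pure-Python illustration with nested lists instead of tensors. It uses a hypothetical helper `compute_and_comm_density`, 4 BSA blocks per CPU block instead of 32, and a non-causal total for brevity, so it is not the DensityObserver implementation.

```python
def compute_and_comm_density(mask, bsa_per_cpu):
    """mask[h][q][k] is True when head h selects BSA block (q, k).

    Returns (compute_density, comm_density). The compute total here counts
    every block (non-causal), which is a simplification for illustration.
    """
    num_k = len(mask[0][0])
    num_cpu = num_k // bsa_per_cpu
    selected = sum(cell for head in mask for row in head for cell in row)
    total = sum(len(row) for head in mask for row in head)
    # A CPU block is needed if any head / Q row / sub-block inside it is set
    needed = [
        any(
            row[c * bsa_per_cpu + j]
            for head in mask
            for row in head
            for j in range(bsa_per_cpu)
        )
        for c in range(num_cpu)
    ]
    return selected / total, sum(needed) / num_cpu


# Two heads, one Q block row, eight K blocks: each head picks a single block,
# but the two picks land in different CPU blocks.
head0 = [[k == 0 for k in range(8)]]
head1 = [[k == 5 for k in range(8)]]
compute, comm = compute_and_comm_density([head0, head1], bsa_per_cpu=4)
print(compute, comm)  # 0.125 1.0
```

Only 2 of 16 BSA blocks are selected, yet both CPU blocks must be transferred, which is the same pattern as the 37% compute density against 100% comm density reported in this document.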
|
||||||
|
|
||||||
|
### Examples

| Scenario | Compute Density | Comm Density | Notes |
|------|-----------------|--------------|------|
| 64K context, threshold=0.9 | 37% | 100% | Selected blocks are spread across every CPU block |
| 32K context, threshold=0.9 | 50% | 100% | Same as above |
|
||||||
|
|
||||||
|
## Test results

### Test command

```bash
# Offload mode test
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_64k \
    --datasets niah_single_1 \
    --num-samples 1 \
    --max-model-len 72000 \
    --enable-offload \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9
```

### Sample output

```
[DensityObserver] Mode: offload
Compute density: 0.3691 (min: 0.3691 @ layer 0)
Comm density: 1.0000 (CPU block granularity)
Savings ratio: 0.0% H2D transfer reduction
Num layers: 1
Layer 0 density: 0.369052
```
|
||||||
|
|
||||||
|
## Key findings

### Limits of XAttention's current communication savings

1. **Compute density drops effectively**: ~37% at 64K context (63% less computation)
2. **Comm density does not drop**: 100% (no reduction in transfer volume)

### Root-cause analysis

Characteristics of the attention pattern:
- Different heads attend to different positions
- Different Q positions attend to different K positions
- The sparse selections are spread across the whole sequence

As a result, although each (head, Q, K) combination selects only a few blocks, their union covers every CPU block.

### Potential optimizations

1. **Per-head block selection**: let each head select CPU blocks independently
2. **Block clustering**: group related blocks into the same CPU block
3. **Dynamic block size**: adjust the CPU block size according to the attention pattern
|
||||||
|
|
||||||
|
## DensityObserver API

### Enable and reset

```python
from nanovllm.utils.density_observer import DensityObserver

DensityObserver.enable()
DensityObserver.complete_reset()
DensityObserver.set_mode("offload")  # or "gpu_only"
```

### Recording

```python
# Compute density (recorded automatically in GPU-only mode)
DensityObserver.record(layer_id, mask, causal=True)

# Comm density (recorded inside select_blocks in Offload mode)
DensityObserver.record_comm_density(layer_id, selected_cpu_blocks, total_cpu_blocks)
```

### Fetching results

```python
# Overall density
overall_compute = DensityObserver.get_overall_density()
overall_comm = DensityObserver.get_overall_comm_density()

# Per-layer density
per_layer_compute = DensityObserver.get_per_layer_density()
per_layer_comm = DensityObserver.get_per_layer_comm_density()

# Print a summary
DensityObserver.print_summary()
```
|
||||||
|
|
||||||
|
## Related files

- `nanovllm/utils/density_observer.py`: DensityObserver implementation
- `nanovllm/kvcache/sparse/xattn_bsa.py`: XAttention BSA policy (records comm density in `select_blocks`)
- `tests/test_ruler.py`: RULER benchmark test script
|
||||||
198
docs/xattn_kernels_guide.md
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
# XAttention Kernels Guide

This document explains how XAttention's two core Triton kernels work.

## Overview

XAttention uses strided sampling to quickly estimate the attention distribution, which then drives block selection for sparse attention.

**Data flow**:
```
Q [batch, heads, q_len, head_dim]
K [batch, heads, kv_len, head_dim]
    ↓ flat_group_gemm_fuse_reshape (strided sampling + GEMM)
attn_scores [batch, heads, q_len/stride, kv_len/stride]
    ↓ softmax_fuse_block_sum (softmax + block sum)
block_sums [batch, heads, q_blocks, k_blocks]
    ↓ threshold selection
sparse_mask [batch, heads, q_blocks, k_blocks]
```

**Note**: Q and K may have different lengths (q_len ≠ kv_len), which is common in chunked prefill.
|
||||||
|
|
||||||
|
## Kernel 1: flat_group_gemm_fuse_reshape

### Purpose

Computes the attention scores after the stride reshape. In essence, it computes the **anti-diagonal sum** of every stride×stride block of the original attention matrix.

### Function signature

```python
def flat_group_gemm_fuse_reshape(
    query_states: torch.Tensor,  # [batch, heads, q_len, head_dim]
    key_states: torch.Tensor,    # [batch, heads, kv_len, head_dim]
    stride: int,
    chunk_start: int,
    chunk_end: int,
    is_causal: bool = True,
) -> torch.Tensor:  # [batch, heads, q_len/stride, kv_len/stride]
    ...
```
|
||||||
|
|
||||||
|
### Sampling pattern

```
Q sampling: (stride-1-s)::stride   (reversed)
K sampling: s::stride              (forward)
for each offset s = 0 .. stride-1

Example with stride=4, s=0:
Q sample positions: 3, 7, 11, 15, ...  (start at position 3, step 4)
K sample positions: 0, 4, 8, 12, ...   (start at position 0, step 4)
```
|
||||||
|
|
||||||
|
### Anti-diagonal principle

For each stride×stride block of the original attention matrix:

```
Block for stride=4:
       K[0] K[1] K[2] K[3]
Q[0]    ·    ·    ·    X    ← anti-diagonal
Q[1]    ·    ·    X    ·
Q[2]    ·    X    ·    ·
Q[3]    X    ·    ·    ·
```

**Output value = sum of the anti-diagonal elements**

This is because, for sampling offset `s` within a block:
- Q is sampled from original position `i = stride-1-s`
- K is sampled from original position `j = s`
- so `i + j = stride - 1`, which is exactly the anti-diagonal
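The sampling rule can be checked against a slow reference. The following is a minimal pure-Python sketch for a single head using lists of vectors. It ignores causal masking and any scaling, and `antidiagonal_scores` is an illustrative name rather than a function from nano-vllm.

```python
def antidiagonal_scores(q, k, stride):
    """q: q_len vectors, k: kv_len vectors.

    Returns a [q_len/stride][kv_len/stride] grid where each entry is the sum
    of Q·K dot products along the anti-diagonal of one stride×stride block.
    """
    rows, cols = len(q) // stride, len(k) // stride
    out = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            for s in range(stride):
                q_vec = q[i * stride + (stride - 1 - s)]  # reversed Q offset
                k_vec = k[j * stride + s]                 # forward K offset
                out[i][j] += sum(a * b for a, b in zip(q_vec, k_vec))
    return out


head_dim, stride = 128, 4
# Even positions hold 1 and odd positions hold 2, as in this guide's
# verification example (shorter sequences are used here to keep it fast).
q = [[1.0 + (pos % 2)] * head_dim for pos in range(16)]
k = [[1.0 + (pos % 2)] * head_dim for pos in range(32)]
scores = antidiagonal_scores(q, k, stride)
print(len(scores), len(scores[0]), scores[0][0])  # 4 8 1024.0
```

Each of the four anti-diagonal pairs multiplies a 1 by a 2 across 128 dimensions, giving 4 × 256 = 1024, the expected value stated in the verification example.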
|
||||||
|
|
||||||
|
### Triton constraints

**BLOCK size depends on the GPU**:

| GPU | Memory | BLOCK_M/N | Minimum q_len/kv_len (stride=4) |
|----------|------|-----------|-------------------|
| RTX 3090 | 24GB | 64 | stride × 64 = 256 |
| A100/H100 | ≥40GB | 128 | stride × 128 = 512 |

```python
# Selection logic used in the code
# (props is assumed to be the result of torch.cuda.get_device_properties(device))
if props.total_memory < 30 * 1024**3:  # < 30GB
    BLOCK_M = BLOCK_N = 64
else:
    BLOCK_M = BLOCK_N = 128

assert q_len % (stride * BLOCK_M) == 0
assert kv_len % (stride * BLOCK_N) == 0
```
|
||||||
|
|
||||||
|
### Verification example

```python
# Input: even positions = 1, odd positions = 2
# q_len=512, kv_len=2048, stride=4, head_dim=128

# Anti-diagonal elements (stride=4):
#   Q[odd]*K[even] + Q[even]*K[odd] = 2*1 + 1*2 = 4 (per pair)
#   stride=4 gives 2 such pairs
#   multiplied by head_dim=128
# Expected value: 4 * 2 * 128 = 1024

# Output shape: [1, 1, 128, 512] (512/4=128, 2048/4=512)
```
|
||||||
|
|
||||||
|
## Kernel 2: softmax_fuse_block_sum

### Purpose

Applies softmax to the output of `flat_group_gemm_fuse_reshape`, then sums per block to obtain the total attention weight of each block.

### Parameters

| Parameter | Meaning |
|------|------|
| `attn_weights_slice` | Input attention scores `[batch, heads, q_reshaped, k_reshaped]` |
| `reshaped_block_size` | Block size in reshaped space (= block_size / stride) |
| `segment_size` | Size of the K dimension processed per iteration (tiling) |
| `chunk_start` | Start position of Q (used for the causal mask) |
| `chunk_end` | End position of Q |
| `real_q_len` | Valid Q length (used for the padding mask) |
| `scale` | Scale factor (several factors fused together) |
| `is_causal` | Whether to apply the causal mask |
|
||||||
|
|
||||||
|
### Scale factor

```python
scale = log2(e) / sqrt(head_dim) / stride / norm
#     = 1.4426950408889634 / sqrt(head_dim) / stride / norm
```

| Factor | Value | Role |
|------|-----|------|
| `log2(e)` | 1.4426950408889634 | Triton uses `exp2` instead of `exp`, so the base must be converted |
| `1/sqrt(head_dim)` | 1/√128 (head_dim=128) | Standard attention scaling |
| `1/stride` | 1/4 (stride=4) | Normalization for the strided sampling |
| `1/norm` | varies | Additional normalization factor |

**Why exp2**: Triton's `exp2` is faster than `exp` (it is natively supported by the hardware), so log₂(e) is folded into the scale.
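Folding log₂(e) into the scale relies on the identity exp(x) = 2^(x · log₂(e)). The sketch below checks numerically that a softmax built on base-2 exponentials with the fused scale matches the usual base-e softmax. It is plain Python for one row of scores, and both function names are illustrative rather than taken from the kernel.

```python
import math

LOG2_E = math.log2(math.e)  # approximately 1.4426950408889634


def softmax_via_exp2(scores, head_dim, stride, norm=1.0):
    """Softmax using 2**x with log2(e) folded into the scale."""
    scale = LOG2_E / math.sqrt(head_dim) / stride / norm
    scaled = [s * scale for s in scores]
    m = max(scaled)
    weights = [2.0 ** (x - m) for x in scaled]
    total = sum(weights)
    return [w / total for w in weights]


def softmax_via_exp(scores, head_dim, stride, norm=1.0):
    """Reference softmax using exp(x) and the scale without log2(e)."""
    scale = 1.0 / math.sqrt(head_dim) / stride / norm
    scaled = [s * scale for s in scores]
    m = max(scaled)
    weights = [math.exp(x - m) for x in scaled]
    total = sum(weights)
    return [w / total for w in weights]


scores = [1024.0, 512.0, 0.0, 768.0]
fused = softmax_via_exp2(scores, head_dim=128, stride=4)
reference = softmax_via_exp(scores, head_dim=128, stride=4)
print(all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(fused, reference)))
```

The two results agree to floating-point precision, so the kernel can use the cheaper `exp2` path without changing the softmax output.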
|
||||||
|
|
||||||
|
### Segment size constraint

```python
assert segment_size >= reshaped_block_size
```

Reason: the kernel reshapes using `segment_size // block_size`:

```python
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
```

If `segment_size < block_size`, then `segment_size // block_size = 0`, which yields an invalid dimension.
|
||||||
|
|
||||||
|
### Verification example

```python
# Input: attn_scores [1, 1, 128, 512] (all values identical)
# block_size=128 (applied in the reshaped space in this example)

# After softmax every row is uniform (identical inputs → uniform distribution)
# Contribution of one row to one K block = block_size / kv_reshaped_len = 128/512 = 0.25
# Each Q block has block_size=128 rows
# block_sum = 128 * 0.25 = 32

# Output shape: [1, 1, 1, 4] (128/128=1, 512/128=4)
```
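The expected block sum of 32 can be confirmed with a slow reference. This is a minimal pure-Python sketch for a single head, without the causal mask or the fused scale, and `softmax_block_sum` is an illustrative name rather than the Triton kernel.

```python
import math


def softmax_block_sum(scores, block):
    """scores: [q][k] raw scores -> [q/block][k/block] sums of softmax weights."""
    q_blocks, k_blocks = len(scores) // block, len(scores[0]) // block
    out = [[0.0] * k_blocks for _ in range(q_blocks)]
    for r, row in enumerate(scores):
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        denom = sum(exps)
        for c, e in enumerate(exps):
            out[r // block][c // block] += e / denom
    return out


# 128 rows x 512 columns of identical scores, block size 128
scores = [[0.5] * 512 for _ in range(128)]
sums = softmax_block_sum(scores, block=128)
print(len(sums), len(sums[0]), round(sums[0][0], 6))  # 1 4 32.0
```

Every row spreads its probability mass evenly over four K blocks (0.25 each), and 128 rows per Q block give 128 × 0.25 = 32 per output entry.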
|
||||||
|
|
||||||
|
## Full example

```python
# Parameters
q_len = 512       # Q length
kv_len = 2048     # K/V length (may differ from q_len)
stride = 4
block_size = 128

# Step 1: flat_group_gemm_fuse_reshape
# Input:  Q [1,1,512,128], K [1,1,2048,128]
# Output: attn_scores [1,1,128,512]

# Step 2: softmax_fuse_block_sum
# Input:  attn_scores [1,1,128,512]
# Output: block_sums [1,1,1,4]
#   q_blocks = 128/128 = 1
#   k_blocks = 512/128 = 4
```
|
||||||
|
|
||||||
|
## Test code

See `tests/test_xattn_kernels.py`, which validates both kernels with structured inputs.

## Related documents

- [`docs/xattention_algorithm_guide.md`](xattention_algorithm_guide.md): detailed XAttention algorithm guide
- [`docs/sparse_attention_guide.md`](sparse_attention_guide.md): overview of sparse attention methods
|
||||||
122
docs/xattn_kv_chunking_density_test.md
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
# XAttention KV Chunking Density Verification Test

## Background

This test checks whether the XAttention Triton kernel can only be chunked along the Q axis and not along the KV axis.

**Hypothesis**: `softmax_fuse_block_sum` needs the complete K to compute the correct normalization denominator, so the attention distribution obtained after chunking differs from that of the full sequence.

## Method

1. **GPU-only mode**: call `xattn_estimate` once on the full sequence and record the Layer 0 density
2. **Offload DEBUG mode**: call `xattn_estimate` chunk by chunk, accumulate the selected/total counts, and compute the final density
3. Both use the same `_debug_k_full` buffer to collect the full K cache, so the input data is identical
|
||||||
|
|
||||||
|
### Key code logic

```python
# Pseudocode. Offload DEBUG: accumulate selected/total for each chunk
for each chunk:
    K_full = _debug_k_full[:, :, :total_k_len, :]  # K accumulated so far
    _, mask_chunk = xattn_estimate(Q_chunk, K_full, threshold=threshold, causal=True)

    # Crop to the valid region and build the correct causal mask (accounting for the Q offset)
    q_offset_blocks = k_blocks - q_blocks
    causal_mask = indices <= (q_indices + q_offset_blocks)

    selected += (mask_valid & causal_mask).sum()
    total += causal_mask.sum()

density = selected / total
```
|
||||||
|
|
||||||
|
## Test results

### 64K sequence (niah_single_1, sequence length 64891)

| threshold | GPU-only selected | Offload selected | GPU-only density | Offload density | Difference (selected) |
|-----------|------------------|------------------|------------------|-----------------|-----------------|
| **0.90** | 1,524,617 | 1,330,506 | **0.3700** | **0.3229** | 194,111 (12.7%) |
| **0.95** | 1,955,015 | 1,747,585 | **0.4744** | **0.4241** | 207,430 (10.6%) |
| **1.00** | 4,118,719 | 4,118,896 | **0.9995** | **0.9995** | -177 (~0%) |

- **total**: 4,120,896 (identical in both modes)

### 32K sequence (niah_single_1, sequence length 32485)

| threshold | GPU-only selected | Offload selected | GPU-only density | Offload density | Difference (selected) |
|-----------|------------------|------------------|------------------|-----------------|-----------------|
| **0.90** | 520,314 | 466,937 | **0.5021** | **0.4506** | 53,377 (10.3%) |
| **0.95** | 647,765 | 602,953 | **0.6251** | **0.5818** | 44,812 (6.9%) |
| **1.00** | 1,036,295 | 1,036,264 | **0.9999** | **0.9999** | 31 (~0%) |

- **total**: 1,036,320 (identical in both modes)
|
||||||
|
|
||||||
|
### Summary comparison

| Sequence length | threshold | GPU-only density | Offload density | Density difference (absolute) |
|---------|-----------|------------------|-----------------|--------------|
| 32K | 0.90 | 0.5021 | 0.4506 | 5.2% |
| 64K | 0.90 | 0.3700 | 0.3229 | 4.7% |
| 32K | 0.95 | 0.6251 | 0.5818 | 4.3% |
| 64K | 0.95 | 0.4744 | 0.4241 | 5.0% |
| 32K | 1.00 | 0.9999 | 0.9999 | ~0% |
| 64K | 1.00 | 0.9995 | 0.9995 | ~0% |
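The result tables quote two different gaps: a relative gap in selected blocks and an absolute gap in density. The short sketch below recomputes both from the 64K, threshold=0.90 counts. The helper `density_gap` is illustrative and not part of the test code.

```python
def density_gap(gpu_selected, offload_selected, total):
    """Return (gpu_density, offload_density, absolute_gap, relative_gap)."""
    gpu_density = gpu_selected / total
    offload_density = offload_selected / total
    absolute_gap = gpu_density - offload_density
    relative_gap = (gpu_selected - offload_selected) / gpu_selected
    return gpu_density, offload_density, absolute_gap, relative_gap


# 64K sequence, threshold=0.90, counts taken from the results table
gpu_d, off_d, abs_gap, rel_gap = density_gap(1_524_617, 1_330_506, 4_120_896)
print(f"{gpu_d:.4f} {off_d:.4f} {abs_gap * 100:.1f} {rel_gap * 100:.1f}")
# 0.3700 0.3229 4.7 12.7
```

The 12.7% figure is relative to the GPU-only selected count, while the 4.7% figure is the difference between the two densities in percentage points.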
|
||||||
|
|
||||||
|
## Conclusions

### 1. Softmax normalization itself is correct

With `threshold=1.0` (select all blocks), the densities of GPU-only and Offload modes align almost exactly (difference < 0.01%).

This indicates that:
- `_debug_k_full` collects the full K cache correctly
- When `xattn_estimate` is called per chunk, softmax normalization is computed over the correct K sequence
- The Q offset in the causal mask is handled correctly
|
||||||
|
|
||||||
|
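### 2. The problem lies in how the threshold is applied

With `threshold < 1.0` the difference is significant (roughly 7-13% fewer selected blocks, or 4-5 percentage points of density):

- **GPU-only**: applies the threshold once over the full sequence and selects blocks until cumulative attention >= threshold
- **Offload**: applies the threshold independently for each chunk and accumulates the selected counts

Applying the threshold per chunk means that:
- Some blocks selected in GPU-only mode are not selected in chunked mode because the attention distribution differs
- The accumulated selected count is lower than the one-shot count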
|
||||||
|
|
||||||
|
### 3. KV chunking limitation of the XAttention Triton kernel

**Verified conclusion**: `xattn_estimate` handles KV chunking correctly (softmax normalization is correct), but **threshold-based block selection cannot simply be accumulated**.

To obtain the same block selection in Offload mode as in GPU-only mode:
1. First accumulate the attention scores of all chunks
2. Then apply the threshold once at the end to select blocks

Alternatively, accept the roughly 7-13% gap in selected blocks; its effect on actual inference accuracy still needs to be evaluated.
|
||||||
|
|
||||||
|
## Test commands

```bash
# GPU-only mode
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
    python tests/test_ruler.py --dataset niah_single_1 --sample 0 \
    --sparse-policy xattn_bsa --sparse-threshold 0.9

# Offload mode (64K)
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
    python tests/test_ruler.py --dataset niah_single_1 --sample 0 \
    --sparse-policy xattn_bsa --sparse-threshold 0.9 --enable-offload

# Offload mode (32K)
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
    python tests/test_ruler.py --dataset niah_single_1 --sample 0 \
    --sparse-policy xattn_bsa --sparse-threshold 0.9 --enable-offload \
    --data-dir /home/zijie/Code/nano-vllm/tests/data/ruler_32k --max-model-len 34000
```
|
||||||
|
|
||||||
|
## Related files

- `nanovllm/kvcache/sparse/xattn_bsa.py`: location of the DEBUG code
- `nanovllm/ops/xattn.py`: `xattn_estimate` implementation
- `nanovllm/utils/density_observer.py`: DensityObserver implementation
|
||||||
400
docs/xattn_kv_chunking_kernels.md
Normal file
@@ -0,0 +1,400 @@
|
|||||||
|
# XAttention KV Chunking Kernels

## Overview

This document describes softmax kernels that support chunking along the KV dimension. With CPU offload, they let sparse attention estimation run chunk by chunk along KV without keeping the full raw attention scores on the GPU.

## Background

The original `softmax_fuse_block_sum` kernel needs the complete K sequence to compute the correct softmax normalization denominator:

```
softmax(x_i) = exp(x_i) / Σ_j exp(x_j)
```

With only part of K (a KV chunk), the denominator `Σ_j exp(x_j)` is incomplete and the normalization is wrong.

## Solution: three-stage computation

Splitting the softmax computation into three stages makes KV chunking correct:
|
||||||
|
|
||||||
|
```
┌─────────────────────────────────────────────────────────────────┐
│                      Three-stage pipeline                       │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────┐          │
│  │ KV Chunk 0  │    │ KV Chunk 1  │    │ KV Chunk N  │          │
│  │ attn_scores │    │ attn_scores │    │ attn_scores │          │
│  └──────┬──────┘    └──────┬──────┘    └──────┬──────┘          │
│         │                  │                  │                 │
│         ▼                  ▼                  ▼                 │
│  ┌───────────────────────────────────────────────────┐          │
│  │ Stage 1: softmax_compute_partial_stats            │          │
│  │ Compute (m_partial, l_partial) for each chunk     │          │
│  └───────────────────────────────────────────────────┘          │
│         │                  │                  │                 │
│         ▼                  ▼                  ▼                 │
│    (m_0, l_0)         (m_1, l_1)         (m_N, l_N)             │
│         │                  │                  │                 │
│         └──────────────────┼──────────────────┘                 │
│                            ▼                                    │
│  ┌───────────────────────────────────────────────────┐          │
│  │ Stage 2: merge_softmax_stats                      │          │
│  │ Merge on the host → (m_global, l_global)          │          │
│  └───────────────────────────────────────────────────┘          │
│                            │                                    │
│         ┌──────────────────┼──────────────────┐                 │
│         ▼                  ▼                  ▼                 │
│  ┌───────────────────────────────────────────────────┐          │
│  │ Stage 3: softmax_normalize_and_block_sum          │          │
│  │ Normalize with global stats, compute block sums   │          │
│  └───────────────────────────────────────────────────┘          │
│         │                  │                  │                 │
│         ▼                  ▼                  ▼                 │
│   block_sums_0       block_sums_1       block_sums_N            │
│         │                  │                  │                 │
│         └──────────────────┼──────────────────┘                 │
│                            │                                    │
│                            ▼                                    │
│                 torch.cat → final mask                          │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```
|
||||||
|
|
||||||
|
### Stage 1: `softmax_compute_partial_stats`

Computes the partial statistics of each KV chunk:
- `m_partial`: the maximum within the chunk (per query row)
- `l_partial`: the partial sum within the chunk = Σ exp(x - m_partial)

```python
m_partial, l_partial = softmax_compute_partial_stats(
    attn_weights_kv,      # [batch, heads, q_len, k_chunk_len]
    reshaped_block_size,
    segment_size,
    scale,
    chunk_start=chunk_start,
    kv_offset=kv_offset,  # offset of the KV chunk within the full sequence
    is_causal=True,
)
# Output: m_partial and l_partial have shape [batch, heads, q_len]
```
|
||||||
|
|
||||||
|
### Stage 2: `merge_softmax_stats`

Merges the statistics of all KV chunks on the host:

```python
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
```

Merge formula (online softmax):
```
m_new = max(m_global, m_chunk)
l_new = l_global * exp(m_global - m_new) + l_chunk * exp(m_chunk - m_new)
```
|
||||||
|
|
||||||
|
### Stage 3: `softmax_normalize_and_block_sum`

Normalizes with the global statistics and computes the block sums:

```python
attn_sum_kv = softmax_normalize_and_block_sum(
    attn_weights_kv,      # [batch, heads, q_len, k_chunk_len]
    m_global,             # [batch, heads, q_len]
    l_global,             # [batch, heads, q_len]
    reshaped_block_size,
    segment_size,
    chunk_start=chunk_start,
    real_q_len=real_q_len,
    scale=scale,
    kv_offset=kv_offset,
    is_causal=True,
)
# Output: [batch, heads, q_blocks, k_chunk_blocks]
```
|
||||||
|
|
||||||
|
## Proof of Mathematical Equivalence

Original softmax (full K):

```
softmax(x_i) = exp(x_i - m) / Σ_j exp(x_j - m)
```

Computed per KV chunk:

```
Chunk 0: m_0 = max(x[0:N/2]),  l_0 = Σ exp(x[0:N/2] - m_0)
Chunk 1: m_1 = max(x[N/2:N]),  l_1 = Σ exp(x[N/2:N] - m_1)

Merge:
m_global = max(m_0, m_1)
l_global = l_0 * exp(m_0 - m_global) + l_1 * exp(m_1 - m_global)
         = Σ exp(x[0:N/2] - m_global) + Σ exp(x[N/2:N] - m_global)
         = Σ exp(x[0:N] - m_global)   # equals the global sum

Normalize:
softmax(x_i) = exp(x_i - m_global) / l_global   # identical to the full-K result
```
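The derivation above can be checked numerically. The following is a minimal pure-Python sketch, not the project's kernel code: `partial_stats` and `merge_stats` are hypothetical helper names, and a single query row of scores stands in for the real `[batch, heads, q_len, k_chunk_len]` tensors.

```python
import math

def partial_stats(xs):
    """Return (m, l) for one chunk: m = max(xs), l = sum(exp(x - m))."""
    m = max(xs)
    l = sum(math.exp(x - m) for x in xs)
    return m, l

def merge_stats(stats):
    """Fold a list of (m, l) pairs with the online-softmax merge rule."""
    m_global, l_global = stats[0]
    for m_chunk, l_chunk in stats[1:]:
        m_new = max(m_global, m_chunk)
        l_global = (l_global * math.exp(m_global - m_new)
                    + l_chunk * math.exp(m_chunk - m_new))
        m_global = m_new
    return m_global, l_global

# One query row of raw scores, split into two KV chunks.
scores = [0.5, -1.25, 3.0, 2.0, -0.75, 1.5, 4.25, 0.0]
half = len(scores) // 2

m_global, l_global = merge_stats(
    [partial_stats(scores[:half]), partial_stats(scores[half:])]
)
m_full, l_full = partial_stats(scores)  # reference: stats over the full row

# Softmax normalized with the merged statistics.
chunked_softmax = [math.exp(x - m_global) / l_global for x in scores]
```

The merged pair matches the statistics computed over the whole row up to floating-point rounding, which is why stage 3 can normalize each chunk independently.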
|
||||||
|
|
||||||
|
## Causal Mask Handling

Both kernels handle causal attention:

1. **`softmax_partial_stats_kernel`**: uses the `kv_offset` parameter to locate the current KV chunk within the full sequence, so the causal boundary is computed against global positions.

2. **`softmax_normalize_block_sum_kernel`**: uses `kv_offset` in the same way and outputs 0 for positions past the causal boundary.
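
To make the role of `kv_offset` concrete, here is a small sketch of a causal visibility test in the reshaped index space. It assumes the simplest possible rule (a key is visible when its global position does not exceed the query's); the real kernels work on blocks and segments and may place the boundary differently. `causal_visible` and `visible_columns` are hypothetical names.

```python
def causal_visible(q_row, k_col, chunk_start, kv_offset):
    """Assumed rule: a key is visible iff its global position <= the query's."""
    q_pos = chunk_start + q_row  # global query index (reshaped space)
    k_pos = kv_offset + k_col    # global key index (reshaped space)
    return k_pos <= q_pos

def visible_columns(q_row, k_chunk_len, chunk_start, kv_offset):
    """Local key columns of one KV chunk that a given query row may attend to."""
    return [c for c in range(k_chunk_len)
            if causal_visible(q_row, c, chunk_start, kv_offset)]

# Same local indices, three different chunk placements:
on_diagonal = visible_columns(q_row=1, k_chunk_len=4, chunk_start=4, kv_offset=4)
future_chunk = visible_columns(q_row=0, k_chunk_len=4, chunk_start=0, kv_offset=4)
past_chunk = visible_columns(q_row=0, k_chunk_len=4, chunk_start=8, kv_offset=0)
```

Without `kv_offset`, the three cases would be indistinguishable, because the local row and column indices are the same in each.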
|
||||||
|
|
||||||
|
## Storage Overhead Analysis

### Notation

| Symbol | Meaning | Typical value |
|--------|---------|---------------|
| S | seq_len | 64K |
| B | batch_size | 1 |
| H | num_heads | 32 |
| D | head_dim | 128 |
| T | stride | 4-8 |
| C | chunk_size | 16K |
| n | num_kv_chunks = ceil(S/C) | 4 |
|
||||||
|
|
||||||
|
### Original Approach (no KV chunking)

**Peak memory of attn_weights**:

```
[B, H, S/T, S/T] × 4 bytes = B × H × (S/T)² × 4

Example: S=64K, T=4, B=1, H=32
= 1 × 32 × 16384² × 4 = 32 GB
```
|
||||||
|
|
||||||
|
### Extra Storage Required by KV Chunking

#### 1. Partial Stats (per KV chunk)

```
m_partial: [B, H, C/T] × 4 bytes
l_partial: [B, H, C/T] × 4 bytes

Per chunk = 2 × B × H × (C/T) × 4
          = 2 × 1 × 32 × 4096 × 4 = 1 MB
```

#### 2. Global Stats

```
m_global: [B, H, S/T] × 4 bytes
l_global: [B, H, S/T] × 4 bytes

= 2 × B × H × (S/T) × 4
= 2 × 1 × 32 × 16384 × 4 = 4 MB
```

#### 3. Total Extra Overhead

```
total_extra = n × partial_stats + global_stats
            = 4 × 1 MB + 4 MB = 8 MB
```
|
||||||
|
|
||||||
|
### Storage Overhead vs. seqlen

| seqlen | num_chunks | Original attn_weights | Extra stats | Ratio |
|--------|------------|-----------------------|-------------|-------|
| 16K | 1 | 2 GB | 2 MB | 0.1% |
| 32K | 2 | 8 GB | 4 MB | 0.05% |
| 64K | 4 | 32 GB | 8 MB | 0.025% |
| 128K | 8 | 128 GB | 16 MB | 0.012% |
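
The table can be reproduced from the formulas in the previous subsections. The sketch below assumes the typical values from the notation table (B=1, H=32, T=4, C=16K, fp32 elements); `xattn_storage` is a hypothetical helper, not part of the codebase.

```python
import math

def xattn_storage(seq_len, stride=4, chunk_size=16 * 1024,
                  batch=1, heads=32, bytes_per_elem=4):
    """Return (num_kv_chunks, attn_weights_bytes, extra_stats_bytes)."""
    rows = seq_len // stride                      # S/T
    num_chunks = math.ceil(seq_len / chunk_size)  # n = ceil(S/C)
    attn_weights = batch * heads * rows * rows * bytes_per_elem
    # (m_partial, l_partial) per chunk, each [B, H, C/T]
    partial = 2 * batch * heads * (chunk_size // stride) * bytes_per_elem
    # (m_global, l_global), each [B, H, S/T]
    global_stats = 2 * batch * heads * rows * bytes_per_elem
    return num_chunks, attn_weights, num_chunks * partial + global_stats

for s_k in (16, 32, 64, 128):
    n, full, extra = xattn_storage(s_k * 1024)
    print(f"{s_k}K: chunks={n}, attn_weights={full / 2**30:.0f} GiB, "
          f"stats={extra / 2**20:.0f} MiB, ratio={extra / full:.3%}")
```

The sizes come out as exact powers of two, so the table's GB and MB figures are binary units (GiB and MiB).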
|
||||||
|
|
||||||
|
### Complexity Analysis

| Storage component | Complexity | Notes |
|-------------------|------------|-------|
| Original attn_weights | O(S²) | Quadratic growth |
| Partial/Global stats | O(S) | Linear growth |
| **Relative overhead** | O(1/S) | **Shrinks as seqlen grows** |
|
||||||
|
|
||||||
|
### Peak GPU Memory Reduction

The main benefit of KV chunking is that **peak GPU memory** drops from O(S²) to O(S×C):

```
Original:    O(B × H × (S/T)²)           # full attn_weights
KV chunking: O(B × H × (S/T) × (C/T))    # only one chunk is processed at a time
```

For S=128K, C=16K:

- Original peak: ~128 GB
- KV chunking peak: ~16 GB (an **8×** reduction)
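
The two peak figures follow from the size of a single score tensor. A minimal sketch, assuming B=1, H=32, T=4, and fp32 scores as in the notation table (`peak_attn_bytes` is a hypothetical helper):

```python
def peak_attn_bytes(q_len, kv_len, stride=4, batch=1, heads=32, bytes_per_elem=4):
    """Bytes of one [B, H, q_len/T, kv_len/T] fp32 score tensor."""
    return batch * heads * (q_len // stride) * (kv_len // stride) * bytes_per_elem

S, C = 128 * 1024, 16 * 1024
full_peak = peak_attn_bytes(S, S)     # all keys scored at once
chunked_peak = peak_attn_bytes(S, C)  # one KV chunk resident at a time
reduction = full_peak // chunked_peak
```

The reduction factor equals S/C, so it holds only if the scores of at most one KV chunk are resident at any time.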
|
||||||
|
|
||||||
|
## Support for Different Q/KV Chunk Sizes

The three-stage pipeline lets Q and KV use different chunk sizes:

```python
q_chunk_size = 8192    # Q chunk size
kv_chunk_size = 16384  # KV chunk size

for q_chunk_idx in range(q_chunk_num):
    Q_chunk = Q[:, :, q_start:q_end, :]  # [B, H, q_chunk_size, D]

    for kv_chunk_idx in range(kv_chunk_num):
        K_chunk = K[:, :, kv_start:kv_end, :]  # [B, H, kv_chunk_size, D]
        # ... three-stage processing
```

### Test Results

| Config | seq_len | Q chunks | KV chunks | density | Alignment |
|--------|---------|----------|-----------|---------|-----------|
| Q=16K, KV=16K | 64891 | 4 | 4 | 0.1117 | ✓ 100% |
| Q=8K, KV=16K | 64891 | 8 | 4 | 0.1112 | ✓ 100% |
| Q=16K, KV=8K | 64891 | 4 | 8 | 0.1117 | ✓ 100% |
| Q=8K, KV=8K | 64891 | 8 | 8 | 0.1112 | ✓ 100% |
| Q=4K, KV=16K | 64891 | 16 | 4 | 0.1109 | ✓ 100% |
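
The chunk counts in the table follow from ceiling division of the sequence length by each chunk size. The sketch below assumes the sequence is padded up to a whole number of chunks; `chunk_ranges` is a hypothetical helper, not a function from `nanovllm`.

```python
import math

def chunk_ranges(seq_len, chunk_size):
    """(start, end) token ranges covering seq_len padded to a multiple of chunk_size."""
    num_chunks = math.ceil(seq_len / chunk_size)
    return [(i * chunk_size, (i + 1) * chunk_size) for i in range(num_chunks)]

seq_len = 64891
configs = [(16384, 16384), (8192, 16384), (16384, 8192), (8192, 8192), (4096, 16384)]
for q_size, kv_size in configs:
    q_chunks = chunk_ranges(seq_len, q_size)
    kv_chunks = chunk_ranges(seq_len, kv_size)
    print(f"Q={q_size}, KV={kv_size}: "
          f"{len(q_chunks)} Q chunks x {len(kv_chunks)} KV chunks")
```

Each `(start, end)` pair can serve as the `q_start:q_end` or `kv_start:kv_end` slice bounds in the loop shown earlier, provided the tensors are padded accordingly.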
|
||||||
|
|
||||||
|
## API Reference

### `softmax_compute_partial_stats`

```python
def softmax_compute_partial_stats(
    attn_weights_slice: torch.Tensor,  # [batch, heads, q_len, k_chunk_len]
    reshaped_block_size: int,
    segment_size: int,
    scale: float,
    chunk_start: int = 0,              # start of the Q chunk (reshaped space)
    kv_offset: int = 0,                # offset of the KV chunk (reshaped space)
    is_causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the (m, l) partial stats."""
```

### `merge_softmax_stats`

```python
def merge_softmax_stats(
    m_chunks: list,  # List of [batch, heads, q_len] tensors
    l_chunks: list,  # List of [batch, heads, q_len] tensors
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (m_global, l_global)."""
```

### `softmax_normalize_and_block_sum`

```python
def softmax_normalize_and_block_sum(
    attn_weights_slice: torch.Tensor,  # [batch, heads, q_len, k_chunk_len]
    m_global: torch.Tensor,            # [batch, heads, q_len]
    l_global: torch.Tensor,            # [batch, heads, q_len]
    reshaped_block_size: int,
    segment_size: int,
    chunk_start: int,
    real_q_len: int,
    scale: float,
    kv_offset: int = 0,
    is_causal: bool = False,
) -> torch.Tensor:
    """Return block sums of shape [batch, heads, q_blocks, k_chunk_blocks]."""
```
|
||||||
|
|
||||||
|
## Usage Example

```python
import torch

from nanovllm.ops.xattn import (
    flat_group_gemm_fuse_reshape,
    softmax_compute_partial_stats,
    softmax_normalize_and_block_sum,
    merge_softmax_stats,
    find_blocks_chunked,
)

# For each Q chunk
for q_chunk_idx in range(q_chunk_num):
    Q_chunk = Q_padded[:, :, q_start:q_end, :]

    # Stage 1: compute partial stats for every KV chunk
    m_chunks, l_chunks = [], []
    attn_weights_chunks = []  # raw scores are kept resident for stage 3

    for kv_chunk_idx in range(kv_chunk_num):
        K_chunk = K_padded[:, :, kv_start:kv_end, :]
        kv_offset = kv_chunk_idx * kv_chunk_size // STRIDE

        # Compute raw scores
        attn_weights = flat_group_gemm_fuse_reshape(
            Q_chunk, K_chunk, STRIDE,
            chunk_start=chunk_start,
            chunk_end=chunk_end,
            is_causal=False,  # K is incomplete, so masking is deferred
        )
        attn_weights_chunks.append(attn_weights)

        # Compute partial stats
        m, l = softmax_compute_partial_stats(
            attn_weights, block_size, segment_size, scale,
            chunk_start=chunk_start,
            kv_offset=kv_offset,
            is_causal=True,
        )
        m_chunks.append(m)
        l_chunks.append(l)

    # Stage 2: merge stats
    m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)

    # Stage 3: normalize and compute block sums
    block_sums_list = []
    for kv_chunk_idx, attn_weights in enumerate(attn_weights_chunks):
        kv_offset = kv_chunk_idx * kv_chunk_size // STRIDE
        block_sums = softmax_normalize_and_block_sum(
            attn_weights, m_global, l_global,
            block_size, segment_size, chunk_start, real_q_len, scale,
            kv_offset=kv_offset,
            is_causal=True,
        )
        block_sums_list.append(block_sums)

    # Concatenate and select blocks
    attn_sum = torch.cat(block_sums_list, dim=-1)
    mask = find_blocks_chunked(attn_sum, ...)
```
|
||||||
|
|
||||||
|
## Performance Comparison

| Aspect | Original implementation | KV chunking implementation |
|--------|-------------------------|----------------------------|
| Number of kernels | 1 | 2 (stats + normalize) |
| Raw-score reads | 2 | 2 |
| Extra memory | 0 | O(B × H × S/T × 2) for (m, l) |
| Host-side compute | None | merge stats (lightweight) |
| **Peak GPU memory** | O(S²) | **O(S × C)** |
|
||||||
|
|
||||||
|
## Verification Tests

### Batch Test Results

The test script `tests/test_xattn_kv_chunking_batch.py` verifies consistency across different sequence lengths:

```
| seq_len | stride | threshold | kv_chunks | density_api | density_kv | diff     | mask_diff | status |
|---------|--------|-----------|-----------|-------------|------------|----------|-----------|--------|
| 3688    | 4      | 0.90      | 1         | 0.383405    | 0.383405   | 0.000000 | 0.0000%   | PASS   |
| 7888    | 4      | 0.90      | 1         | 0.290611    | 0.290611   | 0.000000 | 0.0000%   | PASS   |
| 15685   | 4      | 0.90      | 1         | 0.197724    | 0.197724   | 0.000000 | 0.0000%   | PASS   |
| 32485   | 4      | 0.90      | 2         | 0.159023    | 0.159023   | 0.000000 | 0.0000%   | PASS   |
| 64891   | 4      | 0.90      | 4         | 0.111656    | 0.111656   | 0.000000 | 0.0000%   | PASS   |
```

### Key Conclusions

1. **Mathematical equivalence**: density_diff = 0.000000 in every test
2. **Masks fully aligned**: mask_diff = 0.0000% in every test
3. **Arbitrary combinations of Q/KV chunk sizes are supported**

## Related Files

- `nanovllm/ops/xattn.py`: kernel implementation
- `tests/test_xattn_estimate_alignment.py`: single-file verification test
- `tests/test_xattn_kv_chunking_batch.py`: batch verification test
- `docs/xattn_kernels_guide.md`: documentation of the original kernels
|
||||||
154
docs/xattn_memory_benchmark.md
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
# XAttention Memory Benchmark

Memory usage analysis of XAttention in GPU-only mode.

## Test Configuration

### Hardware

- **GPU**: NVIDIA A100 80GB (used for the benchmark)
- **Goal**: verify feasibility on RTX 3090/4090 (24GB)

### Model

- **Model**: Qwen3-0.6B (28 layers, 16 heads, 8 KV heads, head_dim=128)
- **Context Length**: 32K (max_model_len=40960)

### XAttention Configuration

- **Sparse Policy**: XATTN_BSA
- **Threshold**: 0.9
- **Block Size**: 128 tokens (BSA block)
- **Stride**: 8
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Memory Usage Analysis

### Baseline (gpu-utilization=0.9)

| Metric | Value |
|--------|-------|
| KV Cache | 157 blocks × 448 MB = 70.3 GB |
| **Peak memory** | **73,949 MiB (72.2 GB)** |
| GPU utilization | 90.2% |

### Feasibility Test for 24GB of GPU Memory

| gpu-utilization | KV Cache Blocks | KV Cache Size | Peak memory | Result |
|-----------------|-----------------|---------------|-------------|--------|
| 0.25 | 39 blocks | 17.5 GB | **20.6 GB** | ✅ 5/5 PASSED |
| 0.28 | 44 blocks | 19.7 GB | **22.8 GB** | ✅ 5/5 PASSED |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Recommended Configuration for 24GB of GPU Memory

Intended for **RTX 3090 / RTX 4090 (24GB)**:

```bash
CUDA_VISIBLE_DEVICES=0 python tests/test_ruler.py \
    --model ~/models/Qwen3-0.6B \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 5 \
    --max-model-len 40960 \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9 \
    --gpu-utilization 0.28
```

### Parameter Notes

| Parameter | Value | Description |
|-----------|-------|-------------|
| `--gpu-utilization` | 0.28 | Caps GPU memory use at ~23 GB (on the 80 GB A100 used for this benchmark) |
| `--max-model-len` | 40960 | Supports 32K+ context |
| `--sparse-policy` | XATTN_BSA | Enables XAttention sparse attention |
| `--sparse-threshold` | 0.9 | Selects the blocks that cover 90% of the attention mass |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Memory Breakdown

### Qwen3-0.6B @ 32K Context

| Component | Formula | Size |
|-----------|---------|------|
| Model weights | 0.6B × 2 bytes | ~1.2 GB |
| KV Cache (per-token) | 2 × 28 layers × 8 kv_heads × 128 head_dim × 2 bytes | 112 KB |
| KV Cache (per-block) | 112 KB × 4096 tokens | 448 MB |
| KV Cache (44 blocks) | 448 MB × 44 | 19.7 GB |
| XAttention Buffers | GQA + mask + KV chunking | ~0.3 GB |
| Intermediate activations | Allocated at runtime | ~1.5 GB |
| **Total** | | **~22.8 GB** |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Performance Metrics

### RULER niah_single_1 (5 samples)

| Metric | gpu-util=0.25 | gpu-util=0.28 | gpu-util=0.9 |
|--------|---------------|---------------|--------------|
| Accuracy | 100% (5/5) | 100% (5/5) | 100% (5/5) |
| Elapsed time | 11.4s | 11.5s | 11.6s |
| Compute Density | 24.77% | 24.77% | 24.77% |
| Min Layer Density | 4.29% (Layer 5) | 4.29% (Layer 5) | 4.29% (Layer 5) |

**Conclusion**: lowering gpu-utilization affects neither accuracy nor speed; it only limits the maximum sequence length that can be supported.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Estimates for Other Models

### KV Cache Formula

```
KV Cache per-token = 2 × num_layers × num_kv_heads × head_dim × dtype_size
KV Cache per-block = per-token × block_size
```

### Estimates for Common Models (32K context, block_size=4096)

| Model | Layers | KV Heads | Head Dim | Per-Token | 32K Tokens | Fits in 24GB? |
|-------|--------|----------|----------|-----------|------------|---------------|
| Qwen3-0.6B | 28 | 8 | 128 | 112 KB | 3.5 GB | ✅ Yes |
| Qwen3-4B | 36 | 8 | 128 | 144 KB | 4.5 GB | ✅ Yes |
| Llama-3.1-8B | 32 | 8 | 128 | 128 KB | 4.0 GB | ⚠️ Needs offload |
| Qwen2.5-7B | 28 | 4 | 128 | 56 KB | 1.8 GB | ✅ Yes |

Note: the weights of an 8B model take about 16 GB; with the KV cache and runtime allocations on top, the total exceeds 24 GB, so CPU offload is required.
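
The per-token and 32K-token columns can be recomputed directly from the formula above. This is a small sketch that assumes a 2-byte dtype (fp16/bf16) and takes the layer and head counts from the table; `kv_cache_bytes_per_token` is a hypothetical helper name.

```python
def kv_cache_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_size=2):
    """K and V for every layer: 2 x layers x kv_heads x head_dim x dtype_size."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_size

# (layers, kv_heads, head_dim) as listed in the table above
models = {
    "Qwen3-0.6B": (28, 8, 128),
    "Qwen3-4B": (36, 8, 128),
    "Llama-3.1-8B": (32, 8, 128),
    "Qwen2.5-7B": (28, 4, 128),
}

BLOCK_SIZE = 4096
CONTEXT = 32 * 1024

for name, (layers, kv_heads, head_dim) in models.items():
    per_token = kv_cache_bytes_per_token(layers, kv_heads, head_dim)
    per_block = per_token * BLOCK_SIZE
    total = per_token * CONTEXT
    print(f"{name}: {per_token / 1024:.0f} KB/token, "
          f"{per_block / 2**20:.0f} MB/block, {total / 2**30:.2f} GB at 32K")
```

For Qwen3-0.6B this gives the 112 KB per token and 448 MB per block used in the memory breakdown.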
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Usage Recommendations

### RTX 3090/4090 (24GB)

1. **Small models (≤4B)**: GPU-only + XAttention can be used directly

   ```bash
   --gpu-utilization 0.28 --sparse-policy XATTN_BSA
   ```

2. **Large models (7B-8B)**: CPU offload is required

   ```bash
   --enable-offload --num-gpu-blocks 4 --num-cpu-blocks 32
   ```

### A100 (40GB/80GB)

1. **All models**: GPU-only mode can be used

   ```bash
   --gpu-utilization 0.9 --sparse-policy XATTN_BSA
   ```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Related Files

- `tests/test_ruler.py`: RULER test script
- `nanovllm/kvcache/sparse/xattn_bsa.py`: XAttention BSA Policy implementation
- `docs/gpuonly_density_alignment_test.md`: density alignment verification

---

**Date**: 2026-02-02

**Author**: Zijie Tian
|
||||||
184
docs/xattn_offload_profiling_32k.md
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
# XAttention Offload Profiling - 32K Context

Nsys profiling of XAttention vs. Full Attention in Offload mode.

**Test date**: 2026-02-05

**Model**: Llama-3.1-8B-Instruct

**Context**: 32K tokens

**GPU**: A100-80GB (GPU 0)

---

## Test Configuration

| Parameter | Full | XAttention |
|-----------|------|------------|
| Policy | FULL | XATTN_BSA |
| Block size | 4096 | 4096 |
| GPU blocks | 4 | 4 |
| Threshold | - | 0.95 |
| Density | 100% | ~50% |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## XAttention Per-Stage Timing

### NVTX Markers Summary

| Stage | Total (ms) | Calls | Avg (ms) | Description |
|-------|------------|-------|----------|-------------|
| xattn_find_blocks | 1155.1 | 256 | 4.51 | Block selection (threshold-based) |
| xattn_estimate_pass1 | 588.3 | 256 | 2.30 | Pass 1: partial stats |
| xattn_compute_historical | 512.0 | 224 | 2.29 | Attention over historical KV |
| xattn_estimate_pass2 | 501.6 | 256 | 1.96 | Pass 2: block sums |
| xattn_estimate_merge | 197.9 | 256 | 0.77 | Merge of softmax stats |
| xattn_compute_merge | 93.8 | 256 | 0.37 | Merge of attention outputs |
| xattn_compute_current | 59.2 | 256 | 0.23 | Attention over the current chunk |

### Time Breakdown

```
Total XAttention overhead: 3108 ms

Estimate stage: 1288 ms (41.4%)
  - pass1: 588 ms
  - pass2: 502 ms
  - merge: 198 ms

Find blocks: 1155 ms (37.2%)

Compute stage: 665 ms (21.4%)
  - historical: 512 ms
  - merge: 94 ms
  - current: 59 ms
```
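
The grouping above can be recomputed from the NVTX summary table. The sketch below simply aggregates the totals listed there; the grouping of markers into estimate, find-blocks, and compute stages mirrors the breakdown and is otherwise an assumption.

```python
# Total time per NVTX range in milliseconds, copied from the summary table.
nvtx_ms = {
    "xattn_find_blocks": 1155.1,
    "xattn_estimate_pass1": 588.3,
    "xattn_compute_historical": 512.0,
    "xattn_estimate_pass2": 501.6,
    "xattn_estimate_merge": 197.9,
    "xattn_compute_merge": 93.8,
    "xattn_compute_current": 59.2,
}

groups = {
    "estimate": ["xattn_estimate_pass1", "xattn_estimate_pass2", "xattn_estimate_merge"],
    "find_blocks": ["xattn_find_blocks"],
    "compute": ["xattn_compute_historical", "xattn_compute_merge", "xattn_compute_current"],
}

total_ms = sum(nvtx_ms.values())
shares = {
    group: round(100 * sum(nvtx_ms[name] for name in names) / total_ms, 1)
    for group, names in groups.items()
}

for group, pct in shares.items():
    print(f"{group}: {pct}% of {total_ms:.0f} ms")
```

The same dictionary can be refilled from a fresh `nsys stats --report nvtx_pushpop_sum` run to compare configurations.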
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Chunk7 (the Last Chunk) Comparison

### Per-Layer Time

| Policy | Layer 0 | Layer 1 | ... | Layer 31 | Avg |
|--------|---------|---------|-----|----------|-----|
| Full | 36.5 ms | 33.6 ms | ... | 32.7 ms | ~35 ms |
| XAttn | 39.7 ms | 39.3 ms | ... | 38.5 ms | ~38 ms |

### Analysis

Chunk7 is the last ~4K tokens of the sequence (3813 tokens). At this point:

- K length: 32485 tokens
- Density: 42.08%

**Conclusion**: XAttention is about 8% slower than Full on Chunk7, because:

1. The estimate overhead is not offset by the savings from sparse computation
2. A density of 42% is still fairly high, so the sparsity gain is limited
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Full Attention Chunk7 Detailed Data

```
Layer   Time(ms)
L0      36.5
L1      44.3
L2      43.7
L3      38.7
L4      34.2
L5      45.2
...
L31     32.7
Avg     ~35
```

---

## XAttention Chunk7 Detailed Data

```
Layer   Time(ms)
L0      39.7
L1      39.3
L2      37.1
L3      39.1
L4      38.7
L5      39.4
...
L31     38.5
Avg     ~38
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Bottleneck Analysis

### 1. xattn_find_blocks Is Too Expensive

- 4.51 ms per call on average
- 37.2% of the total time
- Cause: threshold-based block selection involves a sort and a cumulative sum
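
To show why threshold-based selection needs a sort and a running sum, here is a minimal pure-Python sketch of the idea for one query block. It is not the actual `find_blocks_chunked` implementation; `select_blocks_by_threshold` is a hypothetical name and the input is a plain list of per-block attention sums.

```python
def select_blocks_by_threshold(block_sums, threshold):
    """Pick the highest-scoring blocks until they cover `threshold` of the total mass."""
    total = sum(block_sums)
    # Sort block indices by descending score (the expensive step).
    order = sorted(range(len(block_sums)), key=lambda i: block_sums[i], reverse=True)
    selected, running = [], 0.0
    for idx in order:
        if running >= threshold * total:
            break
        selected.append(idx)
        running += block_sums[idx]  # cumulative sum over the sorted scores
    return sorted(selected)

block_sums = [8.0, 1.0, 4.0, 2.0, 1.0]
keep_75 = select_blocks_by_threshold(block_sums, 0.75)
keep_90 = select_blocks_by_threshold(block_sums, 0.90)
```

A top-k variant would replace the data-dependent loop with a fixed-size selection, which is why it appears among the optimization ideas, at the cost of no longer guaranteeing a coverage fraction.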
|
||||||
|
|
||||||
|
### 2. Cost of the Two Estimate Passes

- Pass 1 + Pass 2 take 1090 ms in total
- All KV chunks have to be traversed twice
- Possible optimization: a single-pass estimate

### 3. The Compute Stage Is Relatively Efficient

- It accounts for only 21.4% of the time
- This indicates that the BSA sparse computation itself is efficient
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Optimization Suggestions

### Short Term

1. **Reduce the find_blocks overhead**
   - Use top-k instead of a threshold
   - Preallocate the mask buffer to avoid dynamic allocation

2. **Fuse the two estimate passes**
   - Compute the stats and the block sums in a single pass

### Medium Term

1. **Use a smaller block_size in the estimate stage**
   - The current block_size=4096 is a poor fit for estimation
   - See `docs/estimate_block_size_performance.md`

2. **Pipeline estimate and H2D**
   - Overlap the estimate with the H2D transfer of the next chunk

### Long Term

1. **Predictive block selection**
   - Predict the important blocks of the next chunk from historical patterns
   - Reduce the estimate overhead
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Related Files

- `results/nsys/full_offload_32k_blk4096_20260205_023257.nsys-rep`
- `results/nsys/xattn_offload_32k_blk4096_20260205_023435.nsys-rep`

---

## Commands

### Profile Full

```bash
bash scripts/profile_offload.sh --policy full --ctx-len 32k --gpu 0 --model ~/models/Llama-3.1-8B-Instruct
```

### Profile XAttention

```bash
bash scripts/profile_offload.sh --policy xattn --ctx-len 32k --gpu 0 --model ~/models/Llama-3.1-8B-Instruct
```

### Analyze NVTX

```bash
nsys stats --report nvtx_pushpop_sum <file>.nsys-rep
```
|
||||||
307
docs/xattn_offload_stream_sync_fix.md
Normal file
@@ -0,0 +1,307 @@
|
|||||||
|
# XAttention Offload Stream Synchronization Fix

Fixes a CUDA stream synchronization bug in the XAttention BSA Policy under Offload mode.

**Fix date**: 2026-02-05

**Commit**: `829b311`

**Affected files**: `nanovllm/kvcache/sparse/xattn_bsa.py`, `nanovllm/kvcache/offload_engine.py`

---

## Problem Description

### Symptoms

When running the RULER benchmark in Offload mode, Pass 1 and Pass 2 of the XAttention BSA `select_blocks` method read inconsistent K data from the **same CPU block**:

```
Pass 1: K_chunk sum = 745472.00  (correct)
Pass 2: K_chunk sum = 0.00       (wrong: the data had not finished loading)
```

This produces incorrect attention results and lowers RULER accuracy.

### Reproduction Conditions

- Mode: Offload (`--enable-offload`)
- Context: ≥ 32K tokens
- Sparse policy: `--sparse-policy XATTN_BSA`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Root Cause Analysis

### Stream Configuration Recap

nano-vllm's CPU offload uses several CUDA streams to build a pipeline:

| Stream | Purpose |
|--------|---------|
| `slot_transfer_streams[i]` | H2D transfers (CPU → GPU slot) |
| `compute_stream` | Attention computation |
| `prefill_offload_streams[i]` | D2H transfers (GPU → CPU cache) |

### Synchronization Mechanism

`wait_slot_layer(slot)` synchronizes through an event:

```python
def wait_slot_layer(self, slot_idx: int):
    """Make compute_stream wait for H2D transfer completion."""
    self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])
```

### Root Cause of the Bug

In the `select_blocks` method:

1. The H2D transfer runs on `slot_transfer_streams`
2. `wait_slot_layer` makes `compute_stream` wait for the transfer to complete
3. **However**, the subsequent compute kernels run on the **default stream**, not on `compute_stream`

```python
# Buggy code
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)  # only compute_stream waits

# These kernels run on the default stream and never wait for the H2D copy!
k_block = offload_engine.get_k_for_slot(slot)
K_chunk = k_block.transpose(1, 2)
# ... subsequent computation ...
```

### Timeline

```
slot_transfer_stream:  [====H2D====]
compute_stream:                     |wait|
default_stream:        [kernel1][kernel2]   ← never waits!
                           ↑
                     data not ready
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## The Fix

### Core Change

Wrap every compute kernel of the estimate stage in `with torch.cuda.stream(compute_stream):`:

```python
# Fixed code
compute_stream = offload_engine.compute_stream

offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)  # compute_stream waits

# All computation runs on compute_stream
with torch.cuda.stream(compute_stream):
    k_block = offload_engine.get_k_for_slot(slot)
    K_chunk = k_block.transpose(1, 2)
    # ... subsequent computation ...
```

### Locations Fixed

Six places in the `select_blocks` method needed the fix:

| Location | Stage | What is wrapped |
|----------|-------|-----------------|
| Pass 1, historical blocks | `xattn_estimate_pass1` | Processing of historical KV chunks |
| Pass 1, current chunk | `xattn_estimate_pass1` | Processing of K already on the GPU |
| Step 2, merge | `merge_softmax_stats` | Merge of softmax stats |
| Pass 2, historical blocks | `xattn_estimate_pass2` | block_sum with the global stats |
| Pass 2, current chunk | `xattn_estimate_pass2` | block_sum of the current chunk |
| Step 4, block selection | `find_blocks_chunked` | Final block selection |

### Timeline (after the fix)

```
slot_transfer_stream:  [====H2D====]
compute_stream:                     |wait|[kernel1][kernel2]
                                          ↑
                                      data ready
```
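
The two timelines can be reproduced with a toy model of stream ordering. The sketch below does not use CUDA or PyTorch at all: `Stream` is a hypothetical stand-in that only tracks when the next enqueued operation may start, and the durations are made-up numbers chosen for illustration.

```python
class Stream:
    """Toy in-order queue: tracks the earliest start time of the next operation."""

    def __init__(self):
        self.clock = 0.0

    def run(self, duration):
        """Enqueue an operation; return its (start, end) times."""
        start = self.clock
        self.clock = start + duration
        return start, self.clock

    def record_event(self):
        """An event fires once all work queued so far has finished."""
        return self.clock

    def wait_event(self, event_time):
        """Later work on this stream may not start before the event fires."""
        self.clock = max(self.clock, event_time)

def kernel_sees_loaded_data(kernel_stream_name):
    transfer, compute, default = Stream(), Stream(), Stream()
    _, h2d_end = transfer.run(5.0)          # H2D copy on the transfer stream
    slot_ready = transfer.record_event()
    compute.wait_event(slot_ready)          # what wait_slot_layer does
    stream = {"compute": compute, "default": default}[kernel_stream_name]
    kernel_start, _ = stream.run(1.0)       # estimate kernel
    return kernel_start >= h2d_end

buggy = kernel_sees_loaded_data("default")   # kernel launched on the default stream
fixed = kernel_sees_loaded_data("compute")   # kernel launched on compute_stream
```

The wait only constrains the stream it was issued on, so a kernel launched on any other stream is free to start while the copy is still in flight.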
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Code Change Details

### 1. Pass 1, Historical Blocks

```python
# Before (bug)
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
    offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
    offload_engine.wait_slot_layer(slot)

    k_block = offload_engine.get_k_for_slot(slot)  # default stream
    K_chunk = k_block.transpose(1, 2)
    # ... compute ...

# After (fixed)
compute_stream = offload_engine.compute_stream

for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
    offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
    offload_engine.wait_slot_layer(slot)

    with torch.cuda.stream(compute_stream):  # explicitly select the stream
        k_block = offload_engine.get_k_for_slot(slot)
        K_chunk = k_block.transpose(1, 2)
        # ... compute ...
```

### 2. Removal of STRONG SYNC

An unnecessary strong synchronization was removed from `offload_engine.py`:

```python
# Removed from load_to_slot_layer() and load_k_only_to_slot_layer()
# STRONG SYNC: Synchronize all prefill offload streams before H2D
# for offload_stream in self.prefill_offload_streams:
#     offload_stream.synchronize()
```

This ordering is now handled correctly by the event mechanism, so the blocking synchronization is no longer needed.

### 3. Other Cleanup

- Removed DEBUG print statements
- Removed `torch.save()` debug code
- Merged several fallback conditions
- Changed the default `chunk_size` from 16384 to 4096 (to match the offload Q chunk size)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Test Verification

### Test Commands

**GPU 0 - Offload mode test**:

```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 10 \
    --max-model-len 40960 \
    --enable-offload \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9
```

**GPU 1 - GPU-only mode test**:

```bash
CUDA_VISIBLE_DEVICES=1 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Qwen3-0.6B \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 10 \
    --max-model-len 40960 \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9
```

### Test Results

| Mode | Model | Context | Samples | Pass Rate | Density |
|------|-------|---------|---------|-----------|---------|
| Offload | Llama-3.1-8B | 32K | 10/10 | **100%** | 9.53% |
| GPU-only | Qwen3-0.6B | 32K | 10/10 | **100%** | 9.84% |

### Density Alignment Check

| Mode | Layer 0 Density | Difference |
|------|-----------------|------------|
| GPU-only | 9.84% | - |
| Offload | 9.53% | ~3% |

The ~3% (relative) difference is expected, because the two modes accumulate KV differently:

- GPU-only: all KV is processed in one go
- Offload: KV is processed chunk by chunk; each chunk computes its softmax stats independently and the stats are then merged

The two runs in the results table also use different models (Qwen3-0.6B vs. Llama-3.1-8B), so this comparison is indicative rather than a strict alignment check.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Technical Details

### Three-Stage KV Chunking Flow

```
┌─────────────────────────────────────────────────────────────────┐
│ Stage 1: softmax_compute_partial_stats                          │
│   └── each KV chunk computes its partial stats (m_i, l_i)       │
│                                                                 │
│ Stage 2: merge_softmax_stats                                    │
│   └── host merges all chunks into (m_global, l_global)          │
│                                                                 │
│ Stage 3: softmax_normalize_and_block_sum                        │
│   └── normalizes with the global stats and computes block sums  │
└─────────────────────────────────────────────────────────────────┘
```

### Stream Requirements

| Operation | Stream | Reason |
|-----------|--------|--------|
| H2D transfer | `slot_transfer_streams` | Asynchronous transfer; does not block compute |
| D2H transfer | `prefill_offload_streams` | Asynchronous offload; does not block compute |
| Estimate kernels | `compute_stream` | Shared with the attention computation so that ordering is guaranteed |
| Attention kernels | `compute_stream` | Main compute stream |

### Event Synchronization Mechanism

```python
# Record an event once the H2D transfer has been enqueued
self.ring_slot_ready[slot_idx].record(slot_transfer_stream)

# Wait for the H2D transfer before computing
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])

# Record an event after the computation (used by the next H2D round)
self.ring_slot_compute_done[slot_idx].record(compute_stream)
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Related Documents

- [`docs/architecture_guide.md`](architecture_guide.md): stream configuration and ring buffer architecture
- [`docs/xattn_kv_chunking_kernels.md`](xattn_kv_chunking_kernels.md): three-stage softmax kernels
- [`docs/gpuonly_density_alignment_test.md`](gpuonly_density_alignment_test.md): density alignment test
- [`docs/xattn_bsa_policy_design.md`](xattn_bsa_policy_design.md): XAttention BSA Policy design
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Lessons Learned

### 1. Stream Synchronization Bugs Are Hard to See

CUDA stream synchronization bugs are difficult to detect:

- The data may be correct "most of the time", depending on timing
- The error shows up as random or intermittent deviations in the results
- Pinpointing it requires precise debug logging

### 2. Event vs. Synchronize

| Method | Pros | Cons |
|--------|------|------|
| `stream.wait_event()` | Does not block the host; keeps the pipeline running | Orders only the stream it is called on |
| `stream.synchronize()` | Guarantees the stream's queued work has finished | Blocks the host thread until the stream drains, stalling the pipeline |

**Best practice**: use events for precise ordering and avoid blocking with `synchronize()`.

### 3. Debugging Techniques

```python
# Print a tensor sum to check data consistency
print(f"K_chunk sum = {K_chunk.sum().item()}")

# Save intermediate results for offline comparison
# (pass_idx / chunk_idx: `pass` is a Python keyword and cannot be used as a name)
torch.save({'K': K_chunk, 'layer': layer_id}, f'/tmp/debug_{pass_idx}_{chunk_idx}.pt')
```
|
||||||
170
docs/xattn_performance_analysis.md
Normal file
@@ -0,0 +1,170 @@
# XAttention Performance Analysis

This document records performance analysis results for XAttention under different configurations, covering NVTX marker locations, the impact of block size, and performance bottlenecks.

## NVTX Markers

NVTX markers were added to the XAttention code for nsys profiling, which makes it easier to analyze the performance of the estimate and compute stages.

### Marker locations

| Mode | Marker name | File location | Description |
|------|---------|---------|------|
| GPU-only | `xattn_estimate` | `xattn_bsa.py:compute_prefill` | `xattn_estimate` call |
| GPU-only | `xattn_bsa_compute` | `xattn_bsa.py:compute_prefill` | BSA kernel call |
| Offload | `xattn_estimate_gemm` | `xattn_bsa.py:select_blocks` | `flat_group_gemm` loop |
| Offload | `xattn_estimate_softmax` | `xattn_bsa.py:select_blocks` | `softmax_fuse_block_sum` |
| Offload | `xattn_estimate_find_blocks` | `xattn_bsa.py:select_blocks` | `find_blocks_chunked` |
| Offload | `xattn_compute_historical` | `xattn_bsa.py:compute_chunked_prefill` | Attention over historical chunks |
| Offload | `xattn_compute_current` | `xattn_bsa.py:compute_chunked_prefill` | Attention over the current chunk |
| Offload | `xattn_compute_merge` | `xattn_bsa.py:compute_chunked_prefill` | Merge operation |

### Viewing NVTX statistics

```bash
# Generate a profile
bash scripts/profile_offload.sh --policy xattn --ctx-len 64k --block-size 4096 --gpu 0

# View NVTX statistics
nsys stats --report nvtx_pushpop_sum results/nsys/<filename>.nsys-rep
```

## Impact of Block Size in Offload Mode

### Test configuration

- Model: Llama-3.1-8B-Instruct
- Context: 64K tokens
- Mode: xattn + offload
- GPU: A100 40GB

### Performance comparison

| Metric | block_size=4096 | block_size=1024 | Change |
|------|----------------|-----------------|------|
| **Total time** | 27.7s | 55.5s | **2x slower** |
| **Number of chunks** | 16 | 64 | 4x |
| **CPU blocks** | 18 | 71 | ~4x |

### Time distribution by stage

#### block_size=4096

| Stage | Share | Total time | Average time | Calls |
|-----|------|--------|---------|---------|
| **xattn_estimate_find_blocks** | **39.7%** | 18.0s | 37.6ms | 480 |
| xattn_compute_historical | 4.4% | 2.0s | 4.2ms | 480 |
| xattn_estimate_gemm | 3.4% | 1.5s | 3.2ms | 480 |
| xattn_compute_current | 0.2% | 113ms | 0.22ms | 512 |
| xattn_compute_merge | 0.2% | 96ms | 0.19ms | 512 |
| xattn_estimate_softmax | 0.2% | 88ms | 0.18ms | 480 |

#### block_size=1024

| Stage | Share | Total time | Average time | Calls |
|-----|------|--------|---------|---------|
| **xattn_estimate_gemm** | **23.6%** | 22.6s | 11.4ms | 1984 |
| **xattn_compute_historical** | **16.9%** | 16.2s | 8.0ms | 2016 |
| xattn_estimate_find_blocks | 1.4% | 1.3s | 0.66ms | 1984 |
| xattn_compute_current | 0.5% | 433ms | 0.21ms | 2048 |
| xattn_compute_merge | 0.4% | 373ms | 0.18ms | 2048 |
| xattn_estimate_softmax | 0.2% | 222ms | 0.11ms | 1984 |

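### Key findings

1. **Block size has a significant impact on performance**
   - block_size=1024 is about 2x slower than 4096
   - A smaller block size produces more chunks, which increases the number of calls

2. **The bottleneck shifts with block size**
   - **block_size=4096**: the bottleneck is `find_blocks_chunked` (39.7%)
   - **block_size=1024**: the bottleneck moves to `estimate_gemm` (23.6%) and `compute_historical` (16.9%)

3. **Amortization effect**
   - With the large block size a single `find_blocks` call is much slower (37.6ms vs 0.66ms), and its total is also higher (18.0s vs 1.3s) despite fewer calls (480 vs 1984)
   - The end-to-end time is still lower (27.7s vs 55.5s), because the `estimate_gemm` and `compute_historical` totals drop sharply (1.5s vs 22.6s and 2.0s vs 16.2s)

4. **Why find_blocks_chunked is a special case**
   - The function runs its block selection logic mainly on the CPU
   - Its overhead grows significantly as the amount of data per call grows
   - At block_size=4096 it takes about 40% of the time, which makes it the main optimization target
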
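The amortization argument can be checked with simple arithmetic: the total time of a stage is roughly its average time per call multiplied by the number of calls. The sketch below only re-derives totals from the numbers in the two stage tables above; the `STAGES` layout and the helper names `total_seconds` and `stage_totals` are illustrative assumptions, not part of the profiled code.

```python
# Values copied from the NVTX stage tables: (average ms per call, number of calls).
STAGES = {
    4096: {
        "xattn_estimate_find_blocks": (37.6, 480),
        "xattn_estimate_gemm": (3.2, 480),
        "xattn_compute_historical": (4.2, 480),
    },
    1024: {
        "xattn_estimate_find_blocks": (0.66, 1984),
        "xattn_estimate_gemm": (11.4, 1984),
        "xattn_compute_historical": (8.0, 2016),
    },
}


def total_seconds(avg_ms: float, calls: int) -> float:
    """Total stage time in seconds = average per-call time x call count."""
    return avg_ms * calls / 1000.0


def stage_totals(block_size: int) -> dict:
    """Per-stage totals (seconds) for one block size."""
    return {name: total_seconds(*v) for name, v in STAGES[block_size].items()}


if __name__ == "__main__":
    for bs in (4096, 1024):
        totals = stage_totals(bs)
        line = ", ".join(f"{name}={sec:.1f}s" for name, sec in totals.items())
        print(f"block_size={bs}: {line}, sum={sum(totals.values()):.1f}s")
```

For these three stages the totals add up to about 21.6s at block_size=4096 and about 40.1s at block_size=1024, which is consistent with the direction of the end-to-end numbers even though `find_blocks` alone is more expensive at the larger block size.
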
## softmax_fuse_block_sum_kernel Performance Analysis

`softmax_fuse_block_sum_kernel_non_causal` is the core Triton kernel of the XAttention estimation stage.

### Kernel structure

```python
# Pseudocode: shape of the data handled by each thread block
# Workload: [block_size, segment_size]   (attention of a single Q block over all K)

# Pass 1: compute the global softmax parameters (m_i, l_i)
for i in range(num_iters):  # num_iters = k_len / segment_size
    X = load [block_size, segment_size]
    compute max, sum for softmax normalization

# Pass 2: Normalize + Block Sum
for i in range(num_iters):
    X = load [block_size, segment_size]
    X = softmax(X)
    X = reshape(X, [block_size, segment_size/block_size, block_size])
    X = sum(X, axis=2)  # → [block_size, segment_size/block_size]
    X = sum(X, axis=0)  # → [segment_size/block_size]
    store output
```

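The two passes in the pseudocode above can be written as a small pure-Python reference, which is handy for checking a kernel's output on tiny inputs. This is a sketch under assumptions: the name `softmax_block_sum` is hypothetical, plain lists stand in for Triton tiles, and no scaling or causal masking is applied (the real kernel may do more than this).

```python
import math


def softmax_block_sum(scores, block_size, segment_size):
    """Reference for one Q block.

    scores: list of rows (one per query), each of length k_len.
    Returns one value per K block: the softmax mass of that block, summed over rows.
    """
    k_len = len(scores[0])
    assert k_len % segment_size == 0 and segment_size % block_size == 0
    num_iters = k_len // segment_size

    # Pass 1: running max (m) and running sum of exponentials (l) per row.
    m = [-math.inf] * len(scores)
    l = [0.0] * len(scores)
    for it in range(num_iters):
        lo, hi = it * segment_size, (it + 1) * segment_size
        for r, row in enumerate(scores):
            seg = row[lo:hi]
            m_new = max(m[r], max(seg))
            # Rescale the old sum to the new max before adding this segment.
            l[r] = l[r] * math.exp(m[r] - m_new) + sum(math.exp(x - m_new) for x in seg)
            m[r] = m_new

    # Pass 2: normalize each segment and reduce it to per-block sums.
    out = []
    for it in range(num_iters):
        lo = it * segment_size
        for b in range(segment_size // block_size):
            start = lo + b * block_size
            total = 0.0
            for r, row in enumerate(scores):
                total += sum(math.exp(x - m[r]) / l[r] for x in row[start:start + block_size])
            out.append(total)
    return out


if __name__ == "__main__":
    uniform = [[0.0] * 8 for _ in range(2)]  # 2 query rows, 8 keys
    print(softmax_block_sum(uniform, block_size=2, segment_size=4))
```

Because every softmax row sums to 1, the returned block sums always add up to the number of query rows, which gives a cheap sanity check when comparing against a GPU implementation.
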
### Factors that change with block_size

| Factor | Small block_size (64) | Large block_size (256) |
|------|-------------------|---------------------|
| Grid parallelism | High (more blocks) | Low (fewer blocks) |
| Register usage | Low | High (may spill) |
| L2 cache reuse | Poor | Good |
| Output size | Large | Small |

### Typical performance curve

```
Performance
    │
    │          ┌─────┐
    │         /       \
    │        /         \
    │       /           \
    │      /             \
    └─────/───────────────\────────→ block_size
         64    128    256    512

The optimum is usually between 128 and 256
```

## Optimization Suggestions

1. **Prefer block_size=4096**
   - Fewer chunks and lower scheduling overhead
   - Better amortization

2. **Optimize find_blocks_chunked**
   - It is currently the main bottleneck at block_size=4096
   - Consider GPU acceleration or batched processing

3. **Pipeline optimization**
   - Use the multi-slot ring buffer to overlap compute and transfer
   - This is already implemented, but `find_blocks` is a CPU operation and cannot be overlapped

## Test Commands

```bash
# GPU-only mode (requires 40GB+ VRAM)
bash scripts/profile_offload.sh --policy xattn --ctx-len 64k --no-offload --gpu 0

# Offload mode, block_size=4096
bash scripts/profile_offload.sh --policy xattn --ctx-len 64k --block-size 4096 --gpu 0

# Offload mode, block_size=1024
bash scripts/profile_offload.sh --policy xattn --ctx-len 64k --block-size 1024 --gpu 0

# 128K context
bash scripts/profile_offload.sh --policy xattn --ctx-len 128k --block-size 4096 --gpu 0
```
@@ -22,7 +22,7 @@ class Config:
     tensor_parallel_size: int = 1
     enforce_eager: bool = False
     hf_config: AutoConfig | None = None
-    eos: int = -1
+    eos: int | list[int] = -1 # Single EOS token or list of EOS tokens (e.g., GLM-4)
     kvcache_block_size: int = 1024
     num_kvcache_blocks: int = -1
     dtype: str | None = None # "float16", "bfloat16", or None (use model default)
@@ -48,16 +48,20 @@ class Config:
     # XAttention BSA specific parameters
     sparse_block_size: int = 128 # Block size for BSA (tokens per block)
     sparse_samples_per_chunk: int = 128 # Samples per chunk for estimation
-    sparse_threshold: float = 0.9 # Cumulative attention threshold (0-1)
+    sparse_threshold: float = 0.95 # Cumulative attention threshold (tau in XAttention)
     sparse_use_triton: bool = True # Use Triton kernels for estimation
     sparse_stride: int = 8 # Stride for Q/K downsampling
+    sparse_chunk_size: int = 16384 # Triton kernel chunk size for estimation
 
     def __post_init__(self):
         assert os.path.isdir(self.model)
         assert self.kvcache_block_size % 256 == 0
         assert 1 <= self.tensor_parallel_size <= 8
-        self.hf_config = AutoConfig.from_pretrained(self.model)
-        self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
+        self.hf_config = AutoConfig.from_pretrained(self.model, trust_remote_code=True)
+        # Get max position embeddings (GLM-4 uses seq_length instead of max_position_embeddings)
+        max_pos = getattr(self.hf_config, 'max_position_embeddings',
+                          getattr(self.hf_config, 'seq_length', 4096))
+        self.max_model_len = min(self.max_model_len, max_pos)
         assert self.max_num_batched_tokens >= self.max_model_len
 
         # Override torch_dtype if user specified

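The `max_pos` change in the hunk above relies on a nested `getattr` fallback chain. A minimal sketch of that lookup order follows, using `types.SimpleNamespace` objects as stand-ins for Hugging Face configs; the helper name `resolve_max_positions` and the example values are assumptions made for illustration.

```python
from types import SimpleNamespace


def resolve_max_positions(hf_config, default: int = 4096) -> int:
    """Return max_position_embeddings if present, else seq_length, else the default."""
    return getattr(hf_config, "max_position_embeddings",
                   getattr(hf_config, "seq_length", default))


# Hypothetical stand-in configs (not loaded from real checkpoints).
llama_like = SimpleNamespace(max_position_embeddings=131072)
glm_like = SimpleNamespace(seq_length=8192)
bare = SimpleNamespace()

if __name__ == "__main__":
    for name, cfg in [("llama_like", llama_like), ("glm_like", glm_like), ("bare", bare)]:
        print(name, resolve_max_positions(cfg))
```

The inner `getattr` is always evaluated, even when the outer attribute exists; that is harmless here because both lookups are plain attribute reads.
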
@@ -10,7 +10,8 @@ from nanovllm.sampling_params import SamplingParams
 from nanovllm.engine.sequence import Sequence
 from nanovllm.engine.scheduler import Scheduler
 from nanovllm.engine.model_runner import ModelRunner
-from nanovllm.utils.observer import Observer
+from nanovllm.utils.observer import InferenceObserver
+from nanovllm.utils.memory_observer import MemoryObserver
 
 
 class LLMEngine:
@@ -29,7 +30,13 @@ class LLMEngine:
             self.ps.append(process)
             self.events.append(event)
         self.model_runner = ModelRunner(config, 0, self.events)
-        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
-        config.eos = self.tokenizer.eos_token_id
+        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True, trust_remote_code=True)
+        # Get EOS token(s) from config (may be int or list, e.g., GLM-4 uses list)
+        # Prefer hf_config.eos_token_id which contains full list, fallback to tokenizer
+        eos_from_config = getattr(config.hf_config, 'eos_token_id', None)
+        if eos_from_config is not None:
+            config.eos = eos_from_config
+        else:
+            config.eos = self.tokenizer.eos_token_id
         # Set Sequence.block_size to match the KV cache block size
         Sequence.block_size = config.kvcache_block_size
@@ -58,15 +65,18 @@ class LLMEngine:
         print(f"[DEBUG LLMEngine.step] Mode={mode}, active_sequences={len(seqs)}")
 
         if not is_prefill:
-            # The end of the prefill mode. Get TTFT.
-            if Observer.ttft_start != 0:
-                Observer.ttft = perf_counter_ns() - Observer.ttft_start
-                Observer.reset_ttft()
-            # The start of the decode mode. Get TPOT.
-            if Observer.tpot_start != 0:
-                Observer.tpot = perf_counter_ns() - Observer.tpot_start
-            Observer.tpot_start = perf_counter_ns()
+            # Decode mode: calculate TPOT from previous decode step
+            if InferenceObserver.tpot_start != 0:
+                InferenceObserver.tpot = perf_counter_ns() - InferenceObserver.tpot_start
+            InferenceObserver.tpot_start = perf_counter_ns()
         token_ids = self.model_runner.call("run", seqs, is_prefill)
 
+        if is_prefill:
+            # Calculate TTFT after prefill completes (including chunked prefill)
+            if InferenceObserver.ttft_start != 0:
+                InferenceObserver.ttft = perf_counter_ns() - InferenceObserver.ttft_start
+                InferenceObserver.reset_ttft()
         self.scheduler.postprocess(seqs, token_ids)
         outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
 
@@ -91,7 +101,8 @@ class LLMEngine:
         log_level = os.environ.get('NANOVLLM_LOG_LEVEL', 'INFO')
         debug_enabled = log_level.upper() == 'DEBUG'
 
-        Observer.complete_reset()
+        InferenceObserver.complete_reset()
+        MemoryObserver.complete_reset()
         if use_tqdm:
             pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
         if not isinstance(sampling_params, list):
@@ -128,8 +139,8 @@ class LLMEngine:
                 pbar.set_postfix({
                     "Prefill": f"{int(prefill_throughput)}tok/s",
                     "Decode": f"{int(decode_throughput)}tok/s",
-                    "ttft": f"{float(Observer.ttft) / 1e6}ms",
-                    "tpot": f"{float(Observer.tpot) / 1e6}ms",
+                    "ttft": f"{float(InferenceObserver.ttft) / 1e6}ms",
+                    "tpot": f"{float(InferenceObserver.tpot) / 1e6}ms",
                 })
             for seq_id, token_ids in output:
                 outputs[seq_id] = token_ids

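The `LLMEngine.step` hunk above moves the TTFT measurement to after the prefill forward pass and keeps the TPOT measurement before each decode pass. The ordering can be sketched in isolation as follows; `TimingObserver` and `timed_step` are hypothetical stand-ins written for this illustration, and only the attribute names mirror the ones visible in the diff.

```python
from time import perf_counter_ns


class TimingObserver:
    """Hypothetical stand-in for InferenceObserver: class-level timing state in ns."""
    ttft_start = 0
    tpot_start = 0
    ttft = 0
    tpot = 0

    @classmethod
    def reset_ttft(cls):
        cls.ttft_start = 0


def timed_step(is_prefill: bool, run_model) -> None:
    """Mirror the ordering in the hunk: TPOT before the step, TTFT after it."""
    if not is_prefill:
        # Decode: the gap since the previous decode step is one TPOT sample.
        if TimingObserver.tpot_start != 0:
            TimingObserver.tpot = perf_counter_ns() - TimingObserver.tpot_start
        TimingObserver.tpot_start = perf_counter_ns()

    run_model()

    if is_prefill and TimingObserver.ttft_start != 0:
        # Prefill has run, so the first token is available: record TTFT once.
        TimingObserver.ttft = perf_counter_ns() - TimingObserver.ttft_start
        TimingObserver.reset_ttft()


# In the diff the scheduler sets ttft_start when it admits a waiting request.
TimingObserver.ttft_start = perf_counter_ns()
timed_step(True, lambda: None)   # prefill
timed_step(False, lambda: None)  # first decode: only arms tpot_start
timed_step(False, lambda: None)  # second decode: records one TPOT sample
```

Measuring TTFT after the model call means the prefill compute itself is included in the figure, which the earlier placement (at the start of the first decode step) only captured indirectly.
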
@@ -10,6 +10,7 @@ from nanovllm.config import Config
 from nanovllm.engine.sequence import Sequence
 from nanovllm.models import get_model_class
 from nanovllm.layers.sampler import GreedySampler
+from nanovllm.layers.graphed_layers import OffloadGraphManager
 from nanovllm.utils.context import set_context, get_context, reset_context
 from nanovllm.utils.loader import load_model
 from nanovllm.utils.logger import get_logger
@@ -29,6 +30,18 @@ def _find_free_port() -> int:
         return s.getsockname()[1]
 
 
+def get_num_kv_heads(hf_config) -> int:
+    """Get number of KV heads from config (handles GLM-4's multi_query_group_num)."""
+    return getattr(hf_config, 'num_key_value_heads',
+                   getattr(hf_config, 'multi_query_group_num', hf_config.num_attention_heads))
+
+
+def get_head_dim(hf_config) -> int:
+    """Get head dimension from config (handles GLM-4's kv_channels)."""
+    return getattr(hf_config, "head_dim",
+                   getattr(hf_config, "kv_channels", hf_config.hidden_size // hf_config.num_attention_heads))
+
+
 class ModelRunner:
 
     def __init__(self, config: Config, rank: int, event: Event | list[Event]):
@@ -63,6 +76,12 @@ class ModelRunner:
         self.allocate_kv_cache()
         if not self.enforce_eager:
             self.capture_cudagraph()
+
+        # Initialize offload graph manager if CPU offload is enabled
+        self.offload_graph_manager = None
+        if config.enable_cpu_offload and not self.enforce_eager:
+            self.init_offload_graph_manager()
+
         torch.set_default_device("cpu")
         torch.set_default_dtype(default_dtype)
 
@@ -137,8 +156,8 @@ class ModelRunner:
         used = total - free
         peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
         current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
-        num_kv_heads = hf_config.num_key_value_heads // self.world_size
-        head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
+        num_kv_heads = get_num_kv_heads(hf_config) // self.world_size
+        head_dim = get_head_dim(hf_config)
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
 
         # Calculate max GPU blocks based on available memory
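The `block_bytes` expression in the hunk above is easy to sanity-check by hand. The sketch below evaluates it for one example configuration; the helper name `kv_block_bytes` is an assumption, and the model dimensions (32 layers, 8 KV heads, head dimension 128, 2-byte dtype, single GPU) are the commonly published values for Llama-3.1-8B in bfloat16, used here purely as an illustration.

```python
def kv_block_bytes(num_layers: int, block_size: int, num_kv_heads: int,
                   head_dim: int, itemsize: int) -> int:
    """Bytes for one KV cache block: K and V (the factor 2) across all layers."""
    return 2 * num_layers * block_size * num_kv_heads * head_dim * itemsize


if __name__ == "__main__":
    for block_size in (1024, 4096):
        size = kv_block_bytes(num_layers=32, block_size=block_size,
                              num_kv_heads=8, head_dim=128, itemsize=2)
        print(f"block_size={block_size}: {size} bytes = {size / 2**20:.0f} MiB")
```

Under these assumptions one block of 1024 tokens takes 128 MiB and one block of 4096 tokens takes 512 MiB, so the number of blocks that fit in GPU memory shrinks by the same factor of four as the chunk count.
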
@@ -195,19 +214,37 @@ class ModelRunner:
             dtype=hf_config.torch_dtype,
         )
 
-        # Initialize sparse policy if manager has one (CPU offload mode)
+        # Initialize sparse policy if manager has one (works for both CPU offload and GPU-only modes)
         if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
+            # Use CPU blocks for offload mode, GPU blocks for GPU-only mode
+            num_blocks_for_init = config.num_cpu_kvcache_blocks if config.enable_cpu_offload else config.num_kvcache_blocks
             self.kvcache_manager.sparse_policy.initialize(
                 num_layers=hf_config.num_hidden_layers,
                 num_kv_heads=num_kv_heads,
                 head_dim=head_dim,
-                num_cpu_blocks=config.num_cpu_kvcache_blocks,
+                num_cpu_blocks=num_blocks_for_init,
                 dtype=hf_config.torch_dtype,
                 device=torch.device("cuda"),
             )
 
+            # Pre-allocate policy metadata buffers
+            # - Offload mode: allocate chunked prefill buffers (mask, KV chunking stats)
+            # - GPU-only mode: additionally allocate GQA expansion buffers
+            num_heads = hf_config.num_attention_heads // self.world_size
+            self.kvcache_manager.sparse_policy.alloc_policy_metadata(
+                num_heads=num_heads,
+                num_kv_heads=num_kv_heads,
+                head_dim=head_dim,
+                max_seq_len=config.max_model_len,
+                dtype=hf_config.torch_dtype,
+                device=torch.device("cuda"),
+                enable_cpu_offload=config.enable_cpu_offload,
+            )
+
+            # Log policy info (handle both enum and None cases)
+            policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL"
             logger.info(
-                f"Sparse policy initialized: {config.sparse_policy.name} "
+                f"Sparse policy initialized: {policy_name} "
                 f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
             )
 
@@ -368,7 +405,16 @@ class ModelRunner:
         cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
-        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
+        set_context(
+            is_prefill=True,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            slot_mapping=slot_mapping,
+            block_tables=block_tables,
+            kvcache_manager=getattr(self, 'kvcache_manager', None),
+        )
         return input_ids, positions
 
     def prepare_decode(self, seqs: list[Sequence]):
@@ -397,7 +443,13 @@ class ModelRunner:
         context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         # Use GPU physical block tables for attention
         block_tables = self._prepare_gpu_block_tables(gpu_block_tables)
-        set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
+        set_context(
+            is_prefill=False,
+            slot_mapping=slot_mapping,
+            context_lens=context_lens,
+            block_tables=block_tables,
+            kvcache_manager=self.kvcache_manager,
+        )
         return input_ids, positions
 
     def _prepare_gpu_block_tables(self, gpu_block_tables: list[list[int]]):
@@ -536,6 +588,13 @@ class ModelRunner:
                 break
 
             #> Run model forward
-            logits = self.run_model(input_ids, positions, is_prefill=True)
+            # Use graph-optimized forward if available (chunk_size == block_size), otherwise eager mode
+            if (hasattr(self, 'prefill_graph_manager') and
+                self.prefill_graph_manager is not None and
+                self.prefill_graph_manager.captured and
+                input_ids.shape[0] == self.block_size):
+                logits = self.run_prefill_with_offload_graph(input_ids, positions)
+            else:
+                logits = self.run_model(input_ids, positions, is_prefill=True)
             reset_context()
 
@@ -657,6 +716,7 @@ class ModelRunner:
         )
 
         # Run model forward pass
+        # TODO: Phase 5 decode graph needs shape fix, use eager mode for now
         logits = self.run_model(input_ids, positions, is_prefill=False)
         reset_context()
 
@@ -698,7 +758,13 @@ class ModelRunner:
 
         for bs in reversed(self.graph_bs):
             graph = torch.cuda.CUDAGraph()
-            set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
+            set_context(
+                is_prefill=False,
+                slot_mapping=slot_mapping[:bs],
+                context_lens=context_lens[:bs],
+                block_tables=block_tables[:bs],
+                kvcache_manager=self.kvcache_manager,
+            )
             outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
             with torch.cuda.graph(graph, self.graph_pool):
                 outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
@@ -716,3 +782,151 @@ class ModelRunner:
             block_tables=block_tables,
             outputs=outputs,
         )
+
+    @torch.inference_mode()
+    def init_offload_graph_manager(self):
+        """
+        Initialize and capture CUDA Graphs for offload path (Prefill + Decode).
+
+        Phase 5 Design:
+        - Creates N+2 graphs for both Prefill and Decode
+        - Decode graphs: seq_len=1
+        - Prefill graphs: seq_len=chunk_size (block_size)
+
+        Graph structure per mode:
+        - EmbedGraph: embed_tokens
+        - FirstGraph: input_norm → qkv_proj → rotary
+        - InterGraph[i]: o_proj → post_norm → mlp → input_norm → qkv_proj → rotary (N-1 graphs)
+        - LastGraph: o_proj → post_norm → mlp → final_norm
+        """
+        hf_config = self.config.hf_config
+        num_kv_heads = get_num_kv_heads(hf_config) // self.world_size
+        head_dim = get_head_dim(hf_config)
+
+        # Create Decode Graph Manager (seq_len=1)
+        self.decode_graph_manager = OffloadGraphManager(
+            model=self.model,
+            seq_len=1,
+            hidden_size=hf_config.hidden_size,
+            num_heads=hf_config.num_attention_heads // self.world_size,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_dim,
+            dtype=hf_config.torch_dtype,
+        )
+        self.decode_graph_manager.capture_all()
+
+        # Create Prefill Graph Manager (seq_len=chunk_size)
+        chunk_size = self.block_size # chunk_size = block_size = 1024
+        self.prefill_graph_manager = OffloadGraphManager(
+            model=self.model,
+            seq_len=chunk_size,
+            hidden_size=hf_config.hidden_size,
+            num_heads=hf_config.num_attention_heads // self.world_size,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_dim,
+            dtype=hf_config.torch_dtype,
+        )
+        self.prefill_graph_manager.capture_all()
+
+        # Legacy compatibility (for backward compatibility)
+        self.offload_graph_manager = self.decode_graph_manager
+
+        logger.info(
+            f"Offload CUDA Graphs captured: {self.decode_graph_manager.num_graphs} decode graphs + "
+            f"{self.prefill_graph_manager.num_graphs} prefill graphs "
+            f"({self.decode_graph_manager.num_layers} layers)"
+        )
+
+    @torch.inference_mode()
+    def run_model_with_offload_graph(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Run decode with Phase 5 CUDA Graph optimization.
+
+        Graph coverage (~70-80% of computation):
+        - GRAPH_EMBED: embed_tokens
+        - GRAPH_FIRST: input_norm_0 → qkv_proj_0 → rotary_0
+        - GRAPH_INTER_i: o_proj_i → post_norm_i → mlp_i → input_norm_{i+1} → qkv_proj_{i+1} → rotary_{i+1}
+        - GRAPH_LAST: o_proj_{N-1} → post_norm_{N-1} → mlp_{N-1} → final_norm
+
+        EAGER (only attention core with offload):
+        - attn.forward(q, k, v) for each layer
+        """
+        gm = self.decode_graph_manager
+        layers = self.model.model.layers
+        num_layers = len(layers)
+        use_graph = input_ids.shape[0] == 1 # Only use graph for batch=1
+
+        # GRAPH_EMBED: embed_tokens
+        hidden_states = gm.embed_graph(input_ids, use_graph=use_graph)
+
+        # GRAPH_FIRST: input_norm_0 → qkv_proj_0 → rotary_0
+        q, k, v, residual = gm.first_graph(hidden_states, positions, use_graph=use_graph)
+
+        for i in range(num_layers):
+            # EAGER: Attention core only (with offload)
+            # Note: attn.forward already handles store_kvcache internally
+            attn_output = layers[i].self_attn.attn(q, k, v)
+            # attn.forward returns [batch, 1, num_heads, head_dim] for decode
+            # graph expects [seq_len, num_heads, head_dim], so squeeze to [1, heads, dim]
+            if attn_output.dim() == 4:
+                attn_output = attn_output.squeeze(0).squeeze(0).unsqueeze(0)
+
+            if i < num_layers - 1:
+                # GRAPH_INTER_i: o_proj_i → post_norm_i → mlp_i → input_norm_{i+1} → qkv_proj_{i+1} → rotary_{i+1}
+                q, k, v, residual = gm.inter_graphs[i](
+                    attn_output, residual, positions, use_graph=use_graph
+                )
+            else:
+                # GRAPH_LAST: o_proj_{N-1} → post_norm_{N-1} → mlp_{N-1} → final_norm
+                hidden_states = gm.last_graph(attn_output, residual, use_graph=use_graph)
+
+        return self.model.compute_logits(hidden_states)
+
+    @torch.inference_mode()
+    def run_prefill_with_offload_graph(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Run chunked prefill with Phase 5 CUDA Graph optimization.
+
+        Graph coverage (~70-80% of computation):
+        - GRAPH_EMBED: embed_tokens
+        - GRAPH_FIRST: input_norm_0 → qkv_proj_0 → rotary_0
+        - GRAPH_INTER_i: o_proj_i → post_norm_i → mlp_i → input_norm_{i+1} → qkv_proj_{i+1} → rotary_{i+1}
+        - GRAPH_LAST: o_proj_{N-1} → post_norm_{N-1} → mlp_{N-1} → final_norm
+
+        EAGER (only attention core with offload):
+        - attn.forward(q, k, v) for each layer
+        """
+        gm = self.prefill_graph_manager
+        layers = self.model.model.layers
+        num_layers = len(layers)
+        use_graph = input_ids.shape[0] == self.block_size # Only use graph for chunk_size
+
+        # GRAPH_EMBED: embed_tokens
+        hidden_states = gm.embed_graph(input_ids, use_graph=use_graph)
+
+        # GRAPH_FIRST: input_norm_0 → qkv_proj_0 → rotary_0
+        q, k, v, residual = gm.first_graph(hidden_states, positions, use_graph=use_graph)
+
+        for i in range(num_layers):
+            # EAGER: Attention core only (with offload)
+            # Note: attn.forward already handles store_kvcache internally
+            attn_output = layers[i].self_attn.attn(q, k, v)
+
+            if i < num_layers - 1:
+                # GRAPH_INTER_i: o_proj_i → post_norm_i → mlp_i → input_norm_{i+1} → qkv_proj_{i+1} → rotary_{i+1}
+                q, k, v, residual = gm.inter_graphs[i](
+                    attn_output, residual, positions, use_graph=use_graph
+                )
+            else:
+                # GRAPH_LAST: o_proj_{N-1} → post_norm_{N-1} → mlp_{N-1} → final_norm
+                hidden_states = gm.last_graph(attn_output, residual, use_graph=use_graph)
+
+        return self.model.compute_logits(hidden_states)

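The docstrings above describe N+2 captured graphs per mode, with the attention core left in eager mode between them. The interleaving can be sketched without any CUDA dependency; `offload_graph_schedule` and `count_graphs` are hypothetical helpers written for this illustration, and the step labels simply reuse the names from the docstrings.

```python
def offload_graph_schedule(num_layers: int) -> list:
    """Execution order for one forward pass: captured graphs with eager attention in between."""
    steps = ["GRAPH_EMBED", "GRAPH_FIRST"]
    for i in range(num_layers):
        steps.append(f"EAGER_ATTN_{i}")
        if i < num_layers - 1:
            # Fuses the tail of layer i with the head of layer i+1.
            steps.append(f"GRAPH_INTER_{i}")
        else:
            steps.append("GRAPH_LAST")
    return steps


def count_graphs(num_layers: int) -> int:
    """Number of captured graphs in the schedule (embed + first + N-1 inter + last)."""
    return sum(1 for step in offload_graph_schedule(num_layers) if step.startswith("GRAPH_"))


if __name__ == "__main__":
    print(offload_graph_schedule(2))
    print(count_graphs(32))
```

Splitting the graphs at the attention boundary is what lets the offload-aware attention call stay outside the captured regions while the projection, norm, and MLP work around it is replayed from graphs.
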
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from nanovllm.config import Config
|
from nanovllm.config import Config
|
||||||
from nanovllm.engine.sequence import Sequence, SequenceStatus
|
from nanovllm.engine.sequence import Sequence, SequenceStatus
|
||||||
from nanovllm.utils.observer import Observer
|
from nanovllm.utils.observer import InferenceObserver
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nanovllm.kvcache import KVCacheManager
|
from nanovllm.kvcache import KVCacheManager
|
||||||
@@ -15,7 +15,9 @@ class Scheduler:
|
|||||||
def __init__(self, config: Config, kvcache_manager: "KVCacheManager"):
|
def __init__(self, config: Config, kvcache_manager: "KVCacheManager"):
|
||||||
self.max_num_seqs = config.max_num_seqs
|
self.max_num_seqs = config.max_num_seqs
|
||||||
self.max_num_batched_tokens = config.max_num_batched_tokens
|
self.max_num_batched_tokens = config.max_num_batched_tokens
|
||||||
self.eos = config.eos
|
# Convert EOS to set for efficient lookup (supports single int or list)
|
||||||
|
eos = config.eos
|
||||||
|
self.eos_set = set(eos) if isinstance(eos, list) else {eos}
|
||||||
self.kvcache_manager = kvcache_manager
|
self.kvcache_manager = kvcache_manager
|
||||||
self.waiting: deque[Sequence] = deque()
|
self.waiting: deque[Sequence] = deque()
|
||||||
self.running: deque[Sequence] = deque()
|
self.running: deque[Sequence] = deque()
|
||||||
@@ -32,8 +34,8 @@ class Scheduler:
|
|||||||
num_seqs = 0
|
num_seqs = 0
|
||||||
num_batched_tokens = 0
|
num_batched_tokens = 0
|
||||||
while self.waiting and num_seqs < self.max_num_seqs:
|
while self.waiting and num_seqs < self.max_num_seqs:
|
||||||
if Observer.ttft_start == 0:
|
if InferenceObserver.ttft_start == 0:
|
||||||
Observer.ttft_start = perf_counter_ns()
|
InferenceObserver.ttft_start = perf_counter_ns()
|
||||||
seq = self.waiting[0]
|
seq = self.waiting[0]
|
||||||
|
|
||||||
# Check if sequence is too large
|
# Check if sequence is too large
|
||||||
@@ -94,7 +96,7 @@ class Scheduler:
|
|||||||
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
|
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
|
||||||
for seq, token_id in zip(seqs, token_ids):
|
for seq, token_id in zip(seqs, token_ids):
|
||||||
seq.append_token(token_id)
|
seq.append_token(token_id)
|
||||||
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
|
if (not seq.ignore_eos and token_id in self.eos_set) or seq.num_completion_tokens == seq.max_tokens:
|
||||||
seq.status = SequenceStatus.FINISHED
|
seq.status = SequenceStatus.FINISHED
|
||||||
self.kvcache_manager.deallocate(seq)
|
self.kvcache_manager.deallocate(seq)
|
||||||
self.running.remove(seq)
|
self.running.remove(seq)
|
||||||
|
|||||||
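The scheduler change above replaces an equality check against a single EOS id with a set membership test. A minimal standalone sketch of the same idea follows; `normalize_eos` and `should_finish` are hypothetical helper names, and unlike the diff (which checks `isinstance(eos, list)`), this variant accepts any iterable of ints.

```python
from typing import Iterable, Set, Union


def normalize_eos(eos: Union[int, Iterable[int]]) -> Set[int]:
    """Turn a single EOS id or a collection of ids into a set."""
    if isinstance(eos, int):
        return {eos}
    return set(eos)


def should_finish(token_id: int,
                  eos_set: Set[int],
                  ignore_eos: bool,
                  num_completion_tokens: int,
                  max_tokens: int) -> bool:
    """Same stop condition as in postprocess: EOS hit (unless ignored) or length limit."""
    hit_eos = (not ignore_eos) and token_id in eos_set
    return hit_eos or num_completion_tokens == max_tokens
```

Normalizing once in the constructor keeps the per-token check a constant-time lookup regardless of how many stop ids the model config defines.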
@@ -25,7 +25,7 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
|
|||||||
Factory function to create the appropriate KV cache manager.
|
Factory function to create the appropriate KV cache manager.
|
||||||
|
|
||||||
Decision logic:
|
Decision logic:
|
||||||
1. If enable_cpu_offload=False: use GPUOnlyManager
|
1. If enable_cpu_offload=False: use GPUOnlyManager (optionally with sparse policy)
|
||||||
2. If enable_cpu_offload=True but all blocks fit in GPU: use GPUOnlyManager
|
2. If enable_cpu_offload=True but all blocks fit in GPU: use GPUOnlyManager
|
||||||
3. If enable_cpu_offload=True and need CPU blocks: use HybridKVCacheManager
|
3. If enable_cpu_offload=True and need CPU blocks: use HybridKVCacheManager
|
||||||
|
|
||||||
@@ -37,9 +37,44 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
|
|||||||
"""
|
"""
|
||||||
if not getattr(config, 'enable_cpu_offload', False):
|
if not getattr(config, 'enable_cpu_offload', False):
|
||||||
# Default: pure GPU mode
|
# Default: pure GPU mode
|
||||||
|
# Check if sparse policy is requested for GPU-only mode
|
||||||
|
from nanovllm.config import SparsePolicyType
|
||||||
|
sparse_policy_type = getattr(config, 'sparse_policy', None)
|
||||||
|
# Handle None case - use FULL as default
|
||||||
|
if sparse_policy_type is None:
|
||||||
|
sparse_policy_type = SparsePolicyType.FULL
|
||||||
|
|
||||||
|
sparse_policy = None
|
||||||
|
if sparse_policy_type != SparsePolicyType.FULL:
|
||||||
|
# Create sparse policy for GPU-only mode
|
||||||
|
from nanovllm.kvcache.sparse import create_sparse_policy
|
||||||
|
|
||||||
|
policy_kwargs = {}
|
||||||
|
if sparse_policy_type == SparsePolicyType.QUEST:
|
||||||
|
policy_kwargs = {
|
||||||
|
'topk_blocks': getattr(config, 'sparse_topk_blocks', 8),
|
||||||
|
'threshold_blocks': getattr(config, 'sparse_threshold_blocks', 4),
|
||||||
|
}
|
||||||
|
elif sparse_policy_type == SparsePolicyType.XATTN_BSA:
|
||||||
|
policy_kwargs = {
|
||||||
|
'block_size': getattr(config, 'sparse_block_size', 128),
|
||||||
|
'samples_per_chunk': getattr(config, 'sparse_samples_per_chunk', 128),
|
||||||
|
'threshold': getattr(config, 'sparse_threshold', 0.9),
|
||||||
|
'use_triton': getattr(config, 'sparse_use_triton', True),
|
||||||
|
'stride': getattr(config, 'sparse_stride', 8),
|
||||||
|
'chunk_size': getattr(config, 'sparse_chunk_size', 16384),
|
||||||
|
}
|
||||||
|
|
||||||
|
sparse_policy = create_sparse_policy(sparse_policy_type, **policy_kwargs)
|
||||||
|
else:
|
||||||
|
# FULL policy for GPU-only mode - always create for consistent API
|
||||||
|
from nanovllm.kvcache.sparse import FullAttentionPolicy
|
||||||
|
sparse_policy = FullAttentionPolicy()
|
||||||
|
|
||||||
return GPUOnlyManager(
|
return GPUOnlyManager(
|
||||||
num_blocks=config.num_kvcache_blocks,
|
num_blocks=config.num_kvcache_blocks,
|
||||||
block_size=config.kvcache_block_size,
|
block_size=config.kvcache_block_size,
|
||||||
|
sparse_policy=sparse_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
# CPU offload is enabled
|
# CPU offload is enabled
|
||||||
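The GPU-only branch above builds per-policy keyword arguments from optional config attributes, falling back to defaults through `getattr`. One way to express that lookup as a table is sketched below. `PolicyKind` is a hypothetical stand-in for `SparsePolicyType`, and `build_policy_kwargs` is not a function in the repository; the attribute names and defaults are copied from the diff.

```python
from enum import Enum, auto
from types import SimpleNamespace


class PolicyKind(Enum):
    """Hypothetical stand-in for SparsePolicyType."""
    FULL = auto()
    QUEST = auto()
    XATTN_BSA = auto()


# kwarg name -> (config attribute, default), as listed in the diff.
_POLICY_OPTIONS = {
    PolicyKind.QUEST: {
        "topk_blocks": ("sparse_topk_blocks", 8),
        "threshold_blocks": ("sparse_threshold_blocks", 4),
    },
    PolicyKind.XATTN_BSA: {
        "block_size": ("sparse_block_size", 128),
        "samples_per_chunk": ("sparse_samples_per_chunk", 128),
        "threshold": ("sparse_threshold", 0.9),
        "use_triton": ("sparse_use_triton", True),
        "stride": ("sparse_stride", 8),
        "chunk_size": ("sparse_chunk_size", 16384),
    },
}


def build_policy_kwargs(config, kind: PolicyKind) -> dict:
    """Read each option from config if present, otherwise use its default."""
    options = _POLICY_OPTIONS.get(kind, {})
    return {name: getattr(config, attr, default)
            for name, (attr, default) in options.items()}


# Usage: only sparse_topk_blocks is overridden; everything else falls back.
example_config = SimpleNamespace(sparse_topk_blocks=16)
example_kwargs = build_policy_kwargs(example_config, PolicyKind.QUEST)
```

A table like this would also let the GPU-only and CPU-offload branches share one source of defaults, which the diff currently duplicates.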
@@ -79,6 +114,7 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
|
|||||||
'threshold': getattr(config, 'sparse_threshold', 0.9),
|
'threshold': getattr(config, 'sparse_threshold', 0.9),
|
||||||
'use_triton': getattr(config, 'sparse_use_triton', True),
|
'use_triton': getattr(config, 'sparse_use_triton', True),
|
||||||
'stride': getattr(config, 'sparse_stride', 8),
|
'stride': getattr(config, 'sparse_stride', 8),
|
||||||
|
'chunk_size': getattr(config, 'sparse_chunk_size', 16384),
|
||||||
}
|
}
|
||||||
|
|
||||||
sparse_policy = create_sparse_policy(sparse_policy_type, **policy_kwargs)
|
sparse_policy = create_sparse_policy(sparse_policy_type, **policy_kwargs)
|
||||||
|
|||||||
@@ -7,13 +7,16 @@ the KVCacheManager interface.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import List, Tuple, Dict, Optional
|
from typing import List, Tuple, Dict, Optional, TYPE_CHECKING
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from nanovllm.engine.sequence import Sequence
|
from nanovllm.engine.sequence import Sequence
|
||||||
from nanovllm.kvcache.base_manager import KVCacheManager
|
from nanovllm.kvcache.base_manager import KVCacheManager
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from nanovllm.kvcache.sparse.policy import SparsePolicy
|
||||||
|
|
||||||
|
|
||||||
class Block:
|
class Block:
|
||||||
"""Physical block in GPU memory."""
|
"""Physical block in GPU memory."""
|
||||||
@@ -50,17 +53,28 @@ class GPUOnlyManager(KVCacheManager):
|
|||||||
all data stays on GPU at fixed addresses.
|
all data stays on GPU at fixed addresses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_blocks: int, block_size: int):
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_blocks: int,
|
||||||
|
block_size: int,
|
||||||
|
sparse_policy: Optional["SparsePolicy"] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize GPU-only manager.
|
Initialize GPU-only manager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_blocks: Total number of blocks to manage
|
num_blocks: Total number of blocks to manage
|
||||||
block_size: Tokens per block (default 256)
|
block_size: Tokens per block (default 256)
|
||||||
|
sparse_policy: Optional sparse attention policy for GPU-only mode
|
||||||
"""
|
"""
|
||||||
self._block_size = block_size
|
self._block_size = block_size
|
||||||
self._num_blocks = num_blocks
|
self._num_blocks = num_blocks
|
||||||
|
|
||||||
|
# Sparse policy for GPU-only mode (optional)
|
||||||
|
self.sparse_policy = sparse_policy
|
||||||
|
# No offload engine in GPU-only mode
|
||||||
|
self.offload_engine = None
|
||||||
|
|
||||||
# Block metadata
|
# Block metadata
|
||||||
self.blocks: List[Block] = [Block(i) for i in range(num_blocks)]
|
self.blocks: List[Block] = [Block(i) for i in range(num_blocks)]
|
||||||
|
|
||||||
|
|||||||
@@ -231,6 +231,9 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
seq.num_cached_tokens = 0
|
seq.num_cached_tokens = 0
|
||||||
seq.block_table.clear()
|
seq.block_table.clear()
|
||||||
|
|
||||||
|
# Clear decode position tracking for this sequence
|
||||||
|
self.clear_decode_tracking(seq)
|
||||||
|
|
||||||
# Reset OffloadEngine state to prevent request-to-request contamination
|
# Reset OffloadEngine state to prevent request-to-request contamination
|
||||||
# This clears all KV buffers and pending async events
|
# This clears all KV buffers and pending async events
|
||||||
if self.offload_engine is not None:
|
if self.offload_engine is not None:
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ Key design principles for CUDA Graph compatibility:
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda.nvtx
|
import torch.cuda.nvtx
|
||||||
|
import nvtx
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Dict, List, Tuple, Optional
|
from typing import Dict, List, Tuple, Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -16,6 +17,7 @@ from dataclasses import dataclass
|
|||||||
from nanovllm.kvcache.kernels import gathered_copy_kv
|
from nanovllm.kvcache.kernels import gathered_copy_kv
|
||||||
from nanovllm.comm import memcpy_2d_async
|
from nanovllm.comm import memcpy_2d_async
|
||||||
from nanovllm.utils.logger import get_logger
|
from nanovllm.utils.logger import get_logger
|
||||||
|
from nanovllm.utils.memory_observer import MemoryObserver
|
||||||
|
|
||||||
# Import for type hints only (avoid circular import)
|
# Import for type hints only (avoid circular import)
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
@@ -256,6 +258,7 @@ class OffloadEngine:
|
|||||||
- GPU ring buffer slots (k_cache_gpu, v_cache_gpu)
|
- GPU ring buffer slots (k_cache_gpu, v_cache_gpu)
|
||||||
- Per-layer decode buffers (decode_k_buffer, decode_v_buffer)
|
- Per-layer decode buffers (decode_k_buffer, decode_v_buffer)
|
||||||
- Per-layer prefill buffers (prefill_k/v_buffer)
|
- Per-layer prefill buffers (prefill_k/v_buffer)
|
||||||
|
- CPU KV cache (k_cache_cpu, v_cache_cpu)
|
||||||
- All pending async transfer events
|
- All pending async transfer events
|
||||||
"""
|
"""
|
||||||
# Clear GPU ring buffer slots
|
# Clear GPU ring buffer slots
|
||||||
@@ -270,6 +273,11 @@ class OffloadEngine:
|
|||||||
self.prefill_k_buffer.zero_()
|
self.prefill_k_buffer.zero_()
|
||||||
self.prefill_v_buffer.zero_()
|
self.prefill_v_buffer.zero_()
|
||||||
|
|
||||||
|
# Clear CPU cache (critical: prevents cross-request state leakage)
|
||||||
|
# This ensures KV cache from previous requests doesn't contaminate new requests
|
||||||
|
self.k_cache_cpu.zero_()
|
||||||
|
self.v_cache_cpu.zero_()
|
||||||
|
|
||||||
# Clear all pending async transfer events
|
# Clear all pending async transfer events
|
||||||
self.pending_events.clear()
|
self.pending_events.clear()
|
||||||
|
|
||||||
@@ -368,7 +376,10 @@ class OffloadEngine:
|
|||||||
"""
|
"""
|
||||||
self.ring_slot_compute_done[slot_idx].record()
|
self.ring_slot_compute_done[slot_idx].record()
|
||||||
|
|
||||||
def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
|
def load_to_slot_layer(
|
||||||
|
self, slot_idx: int, layer_id: int, cpu_block_id: int, chunk_idx: int = -1,
|
||||||
|
is_prefill: bool = True,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Async load a single CPU block to a ring buffer slot for one layer.
|
Async load a single CPU block to a ring buffer slot for one layer.
|
||||||
|
|
||||||
@@ -383,13 +394,21 @@ class OffloadEngine:
|
|||||||
slot_idx: Target GPU slot index
|
slot_idx: Target GPU slot index
|
||||||
layer_id: Layer index to load (for CPU cache indexing)
|
layer_id: Layer index to load (for CPU cache indexing)
|
||||||
cpu_block_id: Source CPU block ID
|
cpu_block_id: Source CPU block ID
|
||||||
|
chunk_idx: Optional chunk index for NVTX labeling (-1 means not specified)
|
||||||
|
is_prefill: True if in prefill phase, False if in decode phase (for MemoryObserver)
|
||||||
"""
|
"""
|
||||||
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
|
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
|
||||||
|
|
||||||
# Use per-slot stream for parallel transfers across different slots
|
# Use per-slot stream for parallel transfers across different slots
|
||||||
stream = self.slot_transfer_streams[slot_idx]
|
stream = self.slot_transfer_streams[slot_idx]
|
||||||
|
|
||||||
torch.cuda.nvtx.range_push(f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]")
|
# Build NVTX label with optional chunk info
|
||||||
|
if chunk_idx >= 0:
|
||||||
|
nvtx_label = f"H2D: L{layer_id} Chunk{chunk_idx} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
|
||||||
|
else:
|
||||||
|
nvtx_label = f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
|
||||||
|
|
||||||
|
nvtx.push_range(message=nvtx_label, color="blue")
|
||||||
with torch.cuda.stream(stream):
|
with torch.cuda.stream(stream):
|
||||||
# Wait for previous compute on this slot to complete before overwriting
|
# Wait for previous compute on this slot to complete before overwriting
|
||||||
# This prevents data race: transfer must not start until attention finishes reading
|
# This prevents data race: transfer must not start until attention finishes reading
|
||||||
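The NVTX label in this hunk gains an optional `Chunk{chunk_idx}` component, with `-1` meaning "not specified". The string construction can be isolated in a small pure function, sketched here; `h2d_label` and its `k_only` flag are hypothetical, and the output format is copied from the f-strings in the diff.

```python
def h2d_label(layer_id: int,
              cpu_block_id: int,
              slot_idx: int,
              chunk_idx: int = -1,
              k_only: bool = False) -> str:
    """Build the profiler range label for a host-to-device block load."""
    prefix = "H2D K-only" if k_only else "H2D"
    # chunk_idx < 0 means the caller did not supply a chunk index.
    chunk = f" Chunk{chunk_idx}" if chunk_idx >= 0 else ""
    return f"{prefix}: L{layer_id}{chunk} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
```

Keeping the label logic in one place would avoid the duplicated `if chunk_idx >= 0` branches across the K+V and K-only load paths.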
@@ -407,7 +426,66 @@ class OffloadEngine:
|
|||||||
self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
|
self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
|
||||||
)
|
)
|
||||||
self.ring_slot_ready[slot_idx].record(stream)
|
self.ring_slot_ready[slot_idx].record(stream)
|
||||||
torch.cuda.nvtx.range_pop()
|
nvtx.pop_range()
|
||||||
|
|
||||||
|
# Record H2D transfer: K + V = 2 * block_bytes
|
||||||
|
MemoryObserver.record_h2d(2 * self.gpu_block_bytes, is_prefill=is_prefill)
|
||||||
|
|
||||||
|
def load_k_only_to_slot_layer(
|
||||||
|
self, slot_idx: int, layer_id: int, cpu_block_id: int, chunk_idx: int = -1,
|
||||||
|
is_prefill: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Async load only K (not V) from CPU block to GPU slot.
|
||||||
|
|
||||||
|
Used by XAttention estimate phase which only needs K for attention score
|
||||||
|
computation. Saves 50% communication compared to loading K+V.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
slot_idx: Target GPU slot index
|
||||||
|
layer_id: Layer index to load (for CPU cache indexing)
|
||||||
|
cpu_block_id: Source CPU block ID
|
||||||
|
chunk_idx: Optional chunk index for NVTX labeling (-1 means not specified)
|
||||||
|
is_prefill: True if in prefill phase, False if in decode phase
|
||||||
|
"""
|
||||||
|
logger.debug(f"Ring load K-only: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
|
||||||
|
|
||||||
|
stream = self.slot_transfer_streams[slot_idx]
|
||||||
|
|
||||||
|
if chunk_idx >= 0:
|
||||||
|
nvtx_label = f"H2D K-only: L{layer_id} Chunk{chunk_idx} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
|
||||||
|
else:
|
||||||
|
nvtx_label = f"H2D K-only: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
|
||||||
|
|
||||||
|
nvtx.push_range(message=nvtx_label, color="cyan")
|
||||||
|
with torch.cuda.stream(stream):
|
||||||
|
stream.wait_event(self.ring_slot_compute_done[slot_idx])
|
||||||
|
stream.wait_event(self.ring_slot_offload_done[slot_idx])
|
||||||
|
|
||||||
|
# Only copy K, not V
|
||||||
|
self.k_cache_gpu[slot_idx].copy_(
|
||||||
|
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
|
||||||
|
)
|
||||||
|
self.ring_slot_ready[slot_idx].record(stream)
|
||||||
|
nvtx.pop_range()
|
||||||
|
|
||||||
|
# Record H2D transfer: K only = 1 * block_bytes
|
||||||
|
MemoryObserver.record_h2d(self.gpu_block_bytes, is_prefill=is_prefill)
|
||||||
|
|
||||||
|
def get_k_for_slot(self, slot_idx: int) -> Tensor:
|
||||||
|
"""
|
||||||
|
Get only K for a ring buffer slot (no V).
|
||||||
|
|
||||||
|
Used by XAttention estimate phase which only needs K for attention
|
||||||
|
score computation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
slot_idx: GPU slot index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
k_cache, shape: [1, block_size, kv_heads, head_dim]
|
||||||
|
"""
|
||||||
|
return self.k_cache_gpu[slot_idx].unsqueeze(0)
|
||||||
|
|
||||||
def wait_slot_layer(self, slot_idx: int) -> None:
|
def wait_slot_layer(self, slot_idx: int) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -464,7 +542,8 @@ class OffloadEngine:
|
|||||||
else:
|
else:
|
||||||
self.sparse_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
self.sparse_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||||
|
|
||||||
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
|
nvtx_label = f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]"
|
||||||
|
nvtx.push_range(message=nvtx_label, color="green")
|
||||||
with torch.cuda.stream(self.transfer_stream_main):
|
with torch.cuda.stream(self.transfer_stream_main):
|
||||||
# Wait for both compute_stream and default stream
|
# Wait for both compute_stream and default stream
|
||||||
# - compute_stream: for flash attention operations
|
# - compute_stream: for flash attention operations
|
||||||
@@ -480,7 +559,10 @@ class OffloadEngine:
|
|||||||
self.v_cache_gpu[slot_idx], non_blocking=True
|
self.v_cache_gpu[slot_idx], non_blocking=True
|
||||||
)
|
)
|
||||||
self.ring_slot_offload_done[slot_idx].record(self.transfer_stream_main)
|
self.ring_slot_offload_done[slot_idx].record(self.transfer_stream_main)
|
||||||
torch.cuda.nvtx.range_pop()
|
nvtx.pop_range()
|
||||||
|
|
||||||
|
# Record D2H transfer: K + V = 2 * block_bytes
|
||||||
|
MemoryObserver.record_d2h(2 * self.gpu_block_bytes, is_prefill=is_prefill)
|
||||||
|
|
||||||
# ----- KV access methods for ring buffer -----
|
# ----- KV access methods for ring buffer -----
|
||||||
|
|
||||||
@@ -696,6 +778,69 @@ class OffloadEngine:
|
|||||||
v = self.prefill_v_buffer[layer_id, :num_tokens].unsqueeze(0)
|
v = self.prefill_v_buffer[layer_id, :num_tokens].unsqueeze(0)
|
||||||
return k, v
|
return k, v
|
||||||
|
|
||||||
|
def write_to_prefill_buffer(
|
||||||
|
self,
|
||||||
|
layer_id: int,
|
||||||
|
k: Tensor,
|
||||||
|
v: Tensor,
|
||||||
|
chunk_idx: int = -1,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Write KV tensors to prefill buffer (D2D copy within GPU).
|
||||||
|
|
||||||
|
This is called during chunked prefill to store current chunk's KV
|
||||||
|
before computing attention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_id: Layer index
|
||||||
|
k: Key tensor [num_tokens, kv_heads, head_dim]
|
||||||
|
v: Value tensor [num_tokens, kv_heads, head_dim]
|
||||||
|
chunk_idx: Current chunk index for NVTX labeling (-1 = not specified)
|
||||||
|
"""
|
||||||
|
num_tokens = k.shape[0]
|
||||||
|
|
||||||
|
# Build NVTX label
|
||||||
|
if chunk_idx >= 0:
|
||||||
|
nvtx_label = f"D2D: L{layer_id} Chunk{chunk_idx} WritePrefillBuffer"
|
||||||
|
else:
|
||||||
|
nvtx_label = f"D2D: L{layer_id} WritePrefillBuffer"
|
||||||
|
|
||||||
|
torch.cuda.nvtx.range_push(nvtx_label)
|
||||||
|
self.prefill_k_buffer[layer_id, :num_tokens].copy_(k)
|
||||||
|
self.prefill_v_buffer[layer_id, :num_tokens].copy_(v)
|
||||||
|
torch.cuda.nvtx.range_pop()
|
||||||
|
|
||||||
|
# Record D2D transfer: K + V
|
||||||
|
transfer_bytes = 2 * k.numel() * k.element_size()
|
||||||
|
MemoryObserver.record_d2d(transfer_bytes)
|
||||||
|
|
||||||
|
def write_to_decode_buffer(
|
||||||
|
self,
|
||||||
|
layer_id: int,
|
||||||
|
pos_in_block: int,
|
||||||
|
k: Tensor,
|
||||||
|
v: Tensor,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Write KV tensors to decode buffer (D2D copy within GPU).
|
||||||
|
|
||||||
|
This is called during chunked decode to store current decode token's KV.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_id: Layer index
|
||||||
|
pos_in_block: Position within the current block
|
||||||
|
k: Key tensor [kv_heads, head_dim] (single token, squeezed)
|
||||||
|
v: Value tensor [kv_heads, head_dim] (single token, squeezed)
|
||||||
|
"""
|
||||||
|
torch.cuda.nvtx.range_push(f"D2D: L{layer_id} Pos{pos_in_block} WriteDecodeBuffer")
|
||||||
|
self.decode_k_buffer[layer_id, pos_in_block].copy_(k)
|
||||||
|
self.decode_v_buffer[layer_id, pos_in_block].copy_(v)
|
||||||
|
torch.cuda.nvtx.range_pop()
|
||||||
|
|
||||||
|
# Record D2D transfer: K + V (single token)
|
||||||
|
transfer_bytes = 2 * k.numel() * k.element_size()
|
||||||
|
MemoryObserver.record_d2d(transfer_bytes)
|
||||||
|
|
||||||
def offload_prefill_buffer_async(
|
def offload_prefill_buffer_async(
|
||||||
self,
|
self,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
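The transfer accounting added in these hunks charges `2 * block_bytes` for a K+V copy and `1 * block_bytes` for a K-only copy, and computes buffer writes as `2 * k.numel() * k.element_size()`. The arithmetic is sketched below without PyTorch. `TransferCounter` is a hypothetical stand-in (the real `MemoryObserver` is not shown in this diff), and the block shape and fp16 element size are assumed example values.

```python
from dataclasses import dataclass
from math import prod
from typing import Sequence


def tensor_bytes(shape: Sequence[int], element_size: int) -> int:
    """Equivalent of numel() * element_size() for a dense tensor."""
    return prod(shape) * element_size


@dataclass
class TransferCounter:
    """Hypothetical byte counter split by direction and phase."""
    h2d_prefill: int = 0
    h2d_decode: int = 0
    d2h_prefill: int = 0
    d2h_decode: int = 0
    d2d: int = 0

    def record_h2d(self, nbytes: int, is_prefill: bool = True) -> None:
        if is_prefill:
            self.h2d_prefill += nbytes
        else:
            self.h2d_decode += nbytes

    def record_d2h(self, nbytes: int, is_prefill: bool = True) -> None:
        if is_prefill:
            self.d2h_prefill += nbytes
        else:
            self.d2h_decode += nbytes

    def record_d2d(self, nbytes: int) -> None:
        self.d2d += nbytes


# Assumed example: block of 256 tokens, 8 KV heads, head_dim 128, fp16 (2 bytes).
block_bytes = tensor_bytes((256, 8, 128), 2)

counter = TransferCounter()
counter.record_h2d(2 * block_bytes, is_prefill=True)   # K + V load
counter.record_h2d(block_bytes, is_prefill=False)      # K-only load
counter.record_d2d(2 * tensor_bytes((8, 128), 2))      # one decode token, K + V
```

Under these assumed shapes a K-only load moves 512 KiB instead of 1 MiB, which is the 50% saving the `load_k_only_to_slot_layer` docstring refers to.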
@@ -723,7 +868,8 @@ class OffloadEngine:
|
|||||||
# Use per-layer stream for parallel offloads
|
# Use per-layer stream for parallel offloads
|
||||||
stream = self.prefill_offload_streams[layer_id]
|
stream = self.prefill_offload_streams[layer_id]
|
||||||
|
|
||||||
torch.cuda.nvtx.range_push(f"AsyncPrefillOffload: L{layer_id}->CPU[{cpu_block_id}]")
|
nvtx_label = f"D2H: PrefillBuffer L{layer_id}->CPU[{cpu_block_id}]"
|
||||||
|
nvtx.push_range(message=nvtx_label, color="orange")
|
||||||
with torch.cuda.stream(stream):
|
with torch.cuda.stream(stream):
|
||||||
# Wait for compute to finish writing to prefill buffer
|
# Wait for compute to finish writing to prefill buffer
|
||||||
stream.wait_stream(self.compute_stream)
|
stream.wait_stream(self.compute_stream)
|
||||||
@@ -738,7 +884,10 @@ class OffloadEngine:
|
|||||||
|
|
||||||
# Record completion event
|
# Record completion event
|
||||||
self.prefill_offload_events[layer_id].record(stream)
|
self.prefill_offload_events[layer_id].record(stream)
|
||||||
torch.cuda.nvtx.range_pop()
|
nvtx.pop_range()
|
||||||
|
|
||||||
|
# Record D2H transfer: K + V = 2 * block_bytes
|
||||||
|
MemoryObserver.record_d2h(2 * self.gpu_block_bytes, is_prefill=True)
|
||||||
|
|
||||||
def wait_all_prefill_offloads(self) -> None:
|
def wait_all_prefill_offloads(self) -> None:
|
||||||
"""Wait for all prefill buffer offloads to complete."""
|
"""Wait for all prefill buffer offloads to complete."""
|
||||||
@@ -778,6 +927,11 @@ class OffloadEngine:
|
|||||||
v_sample = self.v_cache_cpu[
|
v_sample = self.v_cache_cpu[
|
||||||
layer_id, cpu_block_id, :num_samples
|
layer_id, cpu_block_id, :num_samples
|
||||||
].clone().cuda()
|
].clone().cuda()
|
||||||
|
|
||||||
|
# Record H2D transfer: K + V samples
|
||||||
|
transfer_bytes = 2 * k_sample.numel() * k_sample.element_size()
|
||||||
|
MemoryObserver.record_h2d(transfer_bytes, is_prefill=True)
|
||||||
|
|
||||||
return k_sample, v_sample
|
return k_sample, v_sample
|
||||||
|
|
||||||
def load_block_full_from_cpu(
|
def load_block_full_from_cpu(
|
||||||
@@ -804,4 +958,8 @@ class OffloadEngine:
|
|||||||
v_full = self.v_cache_cpu[
|
v_full = self.v_cache_cpu[
|
||||||
layer_id, cpu_block_id
|
layer_id, cpu_block_id
|
||||||
].clone().cuda()
|
].clone().cuda()
|
||||||
|
|
||||||
|
# Record H2D transfer: K + V full block
|
||||||
|
MemoryObserver.record_h2d(2 * self.gpu_block_bytes, is_prefill=True)
|
||||||
|
|
||||||
return k_full, v_full
|
return k_full, v_full
|
||||||
|
|||||||
@@ -61,6 +61,9 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic
|
|||||||
block_size=kwargs.get("block_size", 128),
|
block_size=kwargs.get("block_size", 128),
|
||||||
samples_per_chunk=kwargs.get("samples_per_chunk", 128),
|
samples_per_chunk=kwargs.get("samples_per_chunk", 128),
|
||||||
threshold=kwargs.get("threshold", 0.9),
|
threshold=kwargs.get("threshold", 0.9),
|
||||||
|
stride=kwargs.get("stride", 8),
|
||||||
|
chunk_size=kwargs.get("chunk_size", 16384),
|
||||||
|
use_triton=kwargs.get("use_triton", True),
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -37,15 +37,116 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
supports_prefill = True
|
supports_prefill = True
|
||||||
supports_decode = True
|
supports_decode = True
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize with statistics tracking."""
|
||||||
|
self._stats_total_blocks = 0
|
||||||
|
self._stats_num_chunks = 0
|
||||||
|
|
||||||
def select_blocks(
|
def select_blocks(
|
||||||
self,
|
self,
|
||||||
available_blocks: List[int],
|
available_blocks: List[int],
|
||||||
offload_engine: "OffloadEngine",
|
offload_engine: "OffloadEngine",
|
||||||
ctx: PolicyContext,
|
ctx: PolicyContext,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""Return all blocks - no sparsity."""
|
"""Return all blocks - no sparsity."""
|
||||||
|
# Update statistics (only for layer 0 to avoid overcounting)
|
||||||
|
if ctx.layer_id == 0 and available_blocks:
|
||||||
|
self._stats_total_blocks += len(available_blocks)
|
||||||
|
self._stats_num_chunks += 1
|
||||||
|
logger.debug(f"[Full] chunk={ctx.query_chunk_idx}, blocks={len(available_blocks)}, density=100.0%")
|
||||||
return available_blocks
|
return available_blocks
|
||||||
|
|
||||||
|
def reset_stats(self) -> None:
|
||||||
|
"""Reset density statistics."""
|
||||||
|
self._stats_total_blocks = 0
|
||||||
|
self._stats_num_chunks = 0
|
||||||
|
|
||||||
|
def get_density_stats(self) -> dict:
|
||||||
|
"""Get density statistics."""
|
||||||
|
return {
|
||||||
|
"total_available_blocks": self._stats_total_blocks,
|
||||||
|
"total_selected_blocks": self._stats_total_blocks, # Full = all selected
|
||||||
|
"num_chunks": self._stats_num_chunks,
|
||||||
|
"overall_density": 1.0, # Always 100%
|
||||||
|
}
|
||||||
|
|
||||||
|
def print_density_stats(self) -> None:
|
||||||
|
"""Print density statistics summary."""
|
||||||
|
stats = self.get_density_stats()
|
||||||
|
logger.info(f"[Full Policy] Density Stats: chunks={stats['num_chunks']}, "
|
||||||
|
f"blocks={stats['total_available_blocks']}, density=100.0%")
|
||||||
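The statistics methods above count blocks only when `ctx.layer_id == 0`, so each chunk contributes once rather than once per layer. A small sketch of that bookkeeping follows, generalized so that selected and available counts may differ. `DensityStats` is a hypothetical class, and returning `1.0` when nothing has been recorded is an assumption made here.

```python
class DensityStats:
    """Accumulate block-selection density, counting each chunk once (layer 0 only)."""

    def __init__(self) -> None:
        self.total_available = 0
        self.total_selected = 0
        self.num_chunks = 0

    def record(self, layer_id: int, num_available: int, num_selected: int) -> None:
        # Every layer sees the same chunk; count it only on layer 0
        # so the totals are not multiplied by the number of layers.
        if layer_id != 0 or num_available == 0:
            return
        self.total_available += num_available
        self.total_selected += num_selected
        self.num_chunks += 1

    def density(self) -> float:
        # Assumption: report 1.0 (no sparsity) when nothing was recorded.
        if self.total_available == 0:
            return 1.0
        return self.total_selected / self.total_available

    def reset(self) -> None:
        self.__init__()
```

For the full-attention policy `num_selected` always equals `num_available`, which is why the diff can hard-code `overall_density` to `1.0`.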
|
|
||||||
|
# =========================================================================
|
||||||
|
# GPU-only methods (non-chunked)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
cu_seqlens_q: torch.Tensor,
|
||||||
|
cu_seqlens_k: torch.Tensor,
|
||||||
|
max_seqlen_q: int,
|
||||||
|
max_seqlen_k: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
layer_id: int,
|
||||||
|
block_tables: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
GPU-only prefill attention using flash_attn_varlen_func.
|
||||||
|
|
||||||
|
This is the simplest implementation - just call flash attention directly.
|
||||||
|
For sparse policies, this method would implement block selection.
|
||||||
|
"""
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
max_seqlen_k=max_seqlen_k,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
block_table=block_tables,
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k_cache: torch.Tensor,
|
||||||
|
v_cache: torch.Tensor,
|
||||||
|
cache_seqlens: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
layer_id: int,
|
||||||
|
block_tables: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
GPU-only decode attention using flash_attn_with_kvcache.
|
||||||
|
|
||||||
|
This is the simplest implementation - just call flash attention directly.
|
||||||
|
For sparse policies, this method would implement block selection.
|
||||||
|
"""
|
||||||
|
from flash_attn import flash_attn_with_kvcache
|
||||||
|
|
||||||
|
# q is [batch, num_heads, head_dim], need to add seq dim
|
||||||
|
return flash_attn_with_kvcache(
|
||||||
|
q.unsqueeze(1), # [batch, 1, heads, dim]
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
cache_seqlens=cache_seqlens,
|
||||||
|
block_table=block_tables,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Chunked offload methods
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
def compute_chunked_prefill(
|
def compute_chunked_prefill(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
@@ -58,16 +159,17 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
current_chunk_idx: int,
|
current_chunk_idx: int,
|
||||||
seq: "Sequence",
|
seq: "Sequence",
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
|
selected_blocks: List[int],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute full attention for chunked prefill.
|
Compute full attention for chunked prefill.
|
||||||
|
|
||||||
This method handles the complete chunked prefill flow:
|
This method handles the chunked prefill computation:
|
||||||
1. Get historical blocks
|
1. Load and compute attention to historical chunks (using selected_blocks)
|
||||||
2. Select blocks via select_blocks
|
2. Compute attention to current chunk
|
||||||
3. Load and compute attention to historical chunks
|
3. Merge all results
|
||||||
4. Compute attention to current chunk
|
|
||||||
5. Merge all results
|
Note: Block selection is done by the caller before invoking this method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
q: Query tensor [seq_len, num_heads, head_dim]
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
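The docstring above describes computing attention against historical chunks and the current chunk separately, then merging the results. The merge relies on each partial result carrying its log-sum-exp (LSE). A pure-Python sketch for a single query vector follows, assuming a softmax scale of 1 and no causal mask; `attn_with_lse` and `merge_outputs` are hypothetical names and are not the repository's `flash_attn_with_lse` or `merge_attention_outputs`.

```python
import math
from typing import List, Tuple

Vector = List[float]


def attn_with_lse(q: Vector, keys: List[Vector], values: List[Vector]) -> Tuple[Vector, float]:
    """Softmax attention of one query over a KV chunk, plus the chunk's log-sum-exp."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)                       # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / z for d in range(dim)]
    return out, m + math.log(z)


def merge_outputs(o1: Vector, lse1: float, o2: Vector, lse2: float) -> Tuple[Vector, float]:
    """Combine two partial attention results computed over disjoint KV chunks."""
    m = max(lse1, lse2)
    w1 = math.exp(lse1 - m)               # proportional to chunk 1's softmax mass
    w2 = math.exp(lse2 - m)               # proportional to chunk 2's softmax mass
    total = w1 + w2
    out = [(w1 * a + w2 * b) / total for a, b in zip(o1, o2)]
    return out, m + math.log(total)


# Two chunks merged should match attention over all keys at once.
q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

full_out, full_lse = attn_with_lse(q, keys, values)
o_a, lse_a = attn_with_lse(q, keys[:2], values[:2])
o_b, lse_b = attn_with_lse(q, keys[2:], values[2:])
merged_out, merged_lse = merge_outputs(o_a, lse_a, o_b, lse_b)
```

Because the merge is associative, an accumulator such as `o_acc`/`lse_acc` can absorb one block at a time in any order.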
@@ -80,37 +182,28 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
current_chunk_idx: Current chunk index
|
current_chunk_idx: Current chunk index
|
||||||
seq: Sequence object
|
seq: Sequence object
|
||||||
num_tokens: Number of tokens in current chunk
|
num_tokens: Number of tokens in current chunk
|
||||||
|
selected_blocks: List of CPU block IDs to process (already filtered)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Attention output [seq_len, num_heads, head_dim]
|
Attention output [seq_len, num_heads, head_dim]
|
||||||
"""
|
"""
|
||||||
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
# Use FlashInfer-based implementations (more optimized)
|
||||||
|
from nanovllm.ops.chunked_attention import (
|
||||||
|
flash_attn_with_lse_flashinfer as flash_attn_with_lse,
|
||||||
|
merge_attention_outputs_flashinfer as merge_attention_outputs,
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
|
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
|
||||||
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
|
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}, "
|
||||||
|
f"selected_blocks={len(selected_blocks)}")
|
||||||
|
|
||||||
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
|
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
|
||||||
o_acc = None
|
o_acc = None
|
||||||
lse_acc = None
|
lse_acc = None
|
||||||
compute_stream = offload_engine.compute_stream
|
compute_stream = offload_engine.compute_stream
|
||||||
|
|
||||||
# Step 1: Get historical blocks
|
# Use the pre-selected blocks directly
|
||||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
cpu_block_table = selected_blocks
|
||||||
|
|
||||||
# Step 2: Apply select_blocks to filter blocks
|
|
||||||
if cpu_block_table:
|
|
||||||
num_chunks = current_chunk_idx + 1
|
|
||||||
policy_ctx = PolicyContext(
|
|
||||||
query_chunk_idx=current_chunk_idx,
|
|
||||||
num_query_chunks=num_chunks,
|
|
||||||
layer_id=layer_id,
|
|
||||||
query=None, # Prefill typically doesn't use query for selection
|
|
||||||
is_prefill=True,
|
|
||||||
block_size=kvcache_manager.block_size,
|
|
||||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
|
||||||
)
|
|
||||||
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
|
|
||||||
logger.debug(f"[DEBUG] select_blocks: output={len(cpu_block_table)} blocks")
|
|
||||||
|
|
||||||
if cpu_block_table:
|
if cpu_block_table:
|
||||||
load_slots = list(range(offload_engine.num_ring_slots))
|
load_slots = list(range(offload_engine.num_ring_slots))
|
||||||
@@ -121,7 +214,8 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
slot = load_slots[0]
|
slot = load_slots[0]
|
||||||
for block_idx in range(num_blocks):
|
for block_idx in range(num_blocks):
|
||||||
cpu_block_id = cpu_block_table[block_idx]
|
cpu_block_id = cpu_block_table[block_idx]
|
||||||
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
|
# cpu_block_id is the chunk index (block N = chunk N)
|
||||||
|
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
|
||||||
offload_engine.wait_slot_layer(slot)
|
offload_engine.wait_slot_layer(slot)
|
||||||
|
|
||||||
with torch.cuda.stream(compute_stream):
|
with torch.cuda.stream(compute_stream):
|
||||||
@@ -141,7 +235,8 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
num_slots = len(load_slots)
|
num_slots = len(load_slots)
|
||||||
num_preload = min(num_slots, num_blocks)
|
num_preload = min(num_slots, num_blocks)
|
||||||
for i in range(num_preload):
|
for i in range(num_preload):
|
||||||
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
|
cpu_block_id = cpu_block_table[i]
|
||||||
|
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
|
||||||
|
|
||||||
for block_idx in range(num_blocks):
|
for block_idx in range(num_blocks):
|
||||||
current_slot = load_slots[block_idx % num_slots]
|
current_slot = load_slots[block_idx % num_slots]
|
||||||
@@ -168,7 +263,7 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
if next_block_idx < num_blocks:
|
if next_block_idx < num_blocks:
|
||||||
next_slot = load_slots[next_block_idx % num_slots]
|
next_slot = load_slots[next_block_idx % num_slots]
|
||||||
next_cpu_block_id = cpu_block_table[next_block_idx]
|
next_cpu_block_id = cpu_block_table[next_block_idx]
|
||||||
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
|
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
|
||||||
|
|
||||||
# Step 4: Compute attention to current chunk (causal mask)
|
# Step 4: Compute attention to current chunk (causal mask)
|
||||||
with torch.cuda.stream(compute_stream):
|
with torch.cuda.stream(compute_stream):
|
||||||
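Review note: the prefill hunks above preload up to `num_slots` blocks and then reuse each slot for the block that sits `num_slots` positions later. The sketch below only models that ordering in plain Python, so the schedule can be checked without a GPU. The function name `ring_buffer_schedule` and the event tuples are hypothetical and not part of nanovllm; the real code issues `load_to_slot_layer` / `wait_slot_layer` calls on CUDA streams instead of recording events.

```python
from typing import List, Tuple

Event = Tuple[str, int, int]  # (kind, slot index, cpu_block_id)


def ring_buffer_schedule(cpu_block_table: List[int], num_slots: int) -> List[Event]:
    """Return the order of load/compute events for a ring-buffer pipeline.

    Hypothetical helper for illustration only: it mirrors the index
    arithmetic in the diff (block_idx % num_slots, block_idx + num_slots).
    """
    if num_slots <= 0:
        raise ValueError("num_slots must be positive")

    events: List[Event] = []
    num_blocks = len(cpu_block_table)

    # Phase 1: pre-load up to num_slots blocks, one per slot.
    for i in range(min(num_slots, num_blocks)):
        events.append(("load", i, cpu_block_table[i]))

    # Phase 2: compute on the current slot, then refill that same slot
    # with the block that is num_slots positions ahead.
    for block_idx in range(num_blocks):
        slot = block_idx % num_slots
        events.append(("compute", slot, cpu_block_table[block_idx]))
        next_block_idx = block_idx + num_slots
        if next_block_idx < num_blocks:
            events.append(("load", slot, cpu_block_table[next_block_idx]))

    return events
```

With two slots and five blocks, every slot is refilled right after its compute step, which is what lets the next transfer overlap with attention on the other slot.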
@@ -200,16 +295,17 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
offload_engine: "OffloadEngine",
|
offload_engine: "OffloadEngine",
|
||||||
kvcache_manager: "KVCacheManager",
|
kvcache_manager: "KVCacheManager",
|
||||||
seq: "Sequence",
|
seq: "Sequence",
|
||||||
|
selected_blocks: List[int],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute full attention for chunked decode.
|
Compute full attention for chunked decode.
|
||||||
|
|
||||||
This method handles the complete chunked decode flow:
|
This method handles the chunked decode computation:
|
||||||
1. Get prefilled CPU blocks
|
1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
|
||||||
2. Apply select_blocks for block filtering
|
2. Read accumulated decode tokens from decode buffer
|
||||||
3. Load blocks via pipeline (ring buffer or cross-layer)
|
3. Merge all results
|
||||||
4. Read accumulated decode tokens from decode buffer
|
|
||||||
5. Merge all results
|
Note: Block selection is done by the caller before invoking this method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
q: Query tensor [batch_size, num_heads, head_dim]
|
q: Query tensor [batch_size, num_heads, head_dim]
|
||||||
@@ -218,49 +314,49 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
offload_engine: OffloadEngine for loading blocks
|
offload_engine: OffloadEngine for loading blocks
|
||||||
kvcache_manager: KVCacheManager for block management
|
kvcache_manager: KVCacheManager for block management
|
||||||
seq: Sequence object
|
seq: Sequence object
|
||||||
|
selected_blocks: List of CPU block IDs to process (already filtered)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Attention output [batch_size, 1, num_heads, head_dim]
|
Attention output [batch_size, 1, num_heads, head_dim]
|
||||||
"""
|
"""
|
||||||
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
# Use FlashInfer-based implementations (more optimized)
|
||||||
|
from nanovllm.ops.chunked_attention import (
|
||||||
|
flash_attn_with_lse_flashinfer as flash_attn_with_lse,
|
||||||
|
merge_attention_outputs_flashinfer as merge_attention_outputs,
|
||||||
|
)
|
||||||
|
|
||||||
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
||||||
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
||||||
|
|
||||||
# Get only PREFILLED CPU blocks (exclude the current decode block)
|
# Use the pre-selected blocks directly
|
||||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
cpu_block_table = selected_blocks
|
||||||
if layer_id == 0:
|
if layer_id == 0:
|
||||||
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
|
logger.debug(f"Decode attention: selected_blocks={len(selected_blocks)}, seq.block_table={list(seq.block_table)}")
|
||||||
if not cpu_block_table:
|
if not cpu_block_table:
|
||||||
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
|
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
|
||||||
|
|
||||||
# Calculate valid tokens in the last CPU block
|
# Calculate valid tokens in the last CPU block
|
||||||
# CRITICAL: Use original prefill length, not current seq length!
|
# CRITICAL: Use original prefill length, not current seq length!
|
||||||
# CPU blocks are fixed after prefill, their content doesn't change during decode.
|
# CPU blocks are fixed after prefill, their content doesn't change during decode.
|
||||||
|
# Note: Fetch all prefilled blocks to check whether selected_blocks ends with the last (possibly partial) one
|
||||||
block_size = kvcache_manager.block_size
|
block_size = kvcache_manager.block_size
|
||||||
num_prefill_blocks = len(cpu_block_table)
|
all_prefilled_blocks = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||||
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
|
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
|
||||||
last_block_valid_tokens = total_prefill_tokens % block_size
|
last_block_valid_tokens = total_prefill_tokens % block_size
|
||||||
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
||||||
last_block_valid_tokens = block_size # Last block was exactly full
|
last_block_valid_tokens = block_size # Last block was exactly full
|
||||||
|
|
||||||
# Apply sparse policy (self) for block filtering
|
# Determine if selected_blocks contains the last prefilled block
|
||||||
policy_ctx = PolicyContext(
|
# If not, all selected blocks are full blocks (use block_size as valid tokens)
|
||||||
query_chunk_idx=0,
|
last_prefilled_block = all_prefilled_blocks[-1] if all_prefilled_blocks else None
|
||||||
num_query_chunks=1,
|
selected_contains_last = (cpu_block_table and cpu_block_table[-1] == last_prefilled_block)
|
||||||
layer_id=layer_id,
|
effective_last_block_tokens = last_block_valid_tokens if selected_contains_last else block_size
|
||||||
query=q_batched,
|
|
||||||
is_prefill=False,
|
|
||||||
block_size=kvcache_manager.block_size,
|
|
||||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
|
||||||
)
|
|
||||||
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
|
|
||||||
|
|
||||||
# Use ring buffer pipeline for loading prefilled blocks
|
# Use ring buffer pipeline for loading prefilled blocks
|
||||||
load_slots = offload_engine.decode_load_slots
|
load_slots = offload_engine.decode_load_slots
|
||||||
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
|
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
|
||||||
q_batched, cpu_block_table, load_slots, offload_engine,
|
q_batched, cpu_block_table, load_slots, offload_engine,
|
||||||
block_size, last_block_valid_tokens, layer_id, softmax_scale
|
block_size, effective_last_block_tokens, layer_id, softmax_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now attend to accumulated decode tokens from per-layer decode buffer
|
# Now attend to accumulated decode tokens from per-layer decode buffer
|
||||||
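Review note: the decode hunk above replaces `last_block_valid_tokens` with `effective_last_block_tokens`, so a partial last block is only accounted for when it was actually selected. A minimal standalone sketch of that arithmetic follows. The two function names are hypothetical, and the sketch assumes, as the diff appears to, that `selected_blocks` keeps the original block order, so the last prefilled block can only show up at the end.

```python
from typing import List


def last_block_valid_tokens(total_prefill_tokens: int, block_size: int) -> int:
    """Number of valid tokens in the final prefilled block."""
    remainder = total_prefill_tokens % block_size
    if remainder == 0 and total_prefill_tokens > 0:
        return block_size  # the last block was exactly full
    return remainder


def effective_last_block_tokens(
    selected_blocks: List[int],
    all_prefilled_blocks: List[int],
    total_prefill_tokens: int,
    block_size: int,
) -> int:
    """Valid-token count to use for the last *selected* block.

    If the selection does not end with the last prefilled block, every
    selected block is a full block, so block_size is returned.
    """
    last_prefilled = all_prefilled_blocks[-1] if all_prefilled_blocks else None
    contains_last = bool(selected_blocks) and selected_blocks[-1] == last_prefilled
    if contains_last:
        return last_block_valid_tokens(total_prefill_tokens, block_size)
    return block_size
```

For example, with `block_size=4` and 10 prefill tokens spread over blocks `[0, 1, 2]`, selecting `[0, 2]` yields 2 valid tokens for the last selected block, while selecting `[0, 1]` yields 4.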
@@ -319,7 +415,11 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
Loads one block at a time, computes attention, and merges results.
|
Loads one block at a time, computes attention, and merges results.
|
||||||
Uses load_to_slot_layer / wait_slot_layer / get_kv_for_slot methods.
|
Uses load_to_slot_layer / wait_slot_layer / get_kv_for_slot methods.
|
||||||
"""
|
"""
|
||||||
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
# Use FlashInfer-based implementations (more optimized)
|
||||||
|
from nanovllm.ops.chunked_attention import (
|
||||||
|
flash_attn_with_lse_flashinfer as flash_attn_with_lse,
|
||||||
|
merge_attention_outputs_flashinfer as merge_attention_outputs,
|
||||||
|
)
|
||||||
|
|
||||||
num_blocks = len(cpu_block_table)
|
num_blocks = len(cpu_block_table)
|
||||||
if num_blocks == 0:
|
if num_blocks == 0:
|
||||||
@@ -335,7 +435,8 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
# Phase 1: Pre-load up to num_slots blocks
|
# Phase 1: Pre-load up to num_slots blocks
|
||||||
num_preload = min(num_slots, num_blocks)
|
num_preload = min(num_slots, num_blocks)
|
||||||
for i in range(num_preload):
|
for i in range(num_preload):
|
||||||
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
|
cpu_block_id = cpu_block_table[i]
|
||||||
|
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id, is_prefill=False)
|
||||||
|
|
||||||
# Phase 2: Process blocks with pipeline
|
# Phase 2: Process blocks with pipeline
|
||||||
for block_idx in range(num_blocks):
|
for block_idx in range(num_blocks):
|
||||||
@@ -368,7 +469,8 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
# Start loading next block (pipeline)
|
# Start loading next block (pipeline)
|
||||||
next_block_idx = block_idx + num_slots
|
next_block_idx = block_idx + num_slots
|
||||||
if next_block_idx < num_blocks:
|
if next_block_idx < num_blocks:
|
||||||
offload_engine.load_to_slot_layer(current_slot, layer_id, cpu_block_table[next_block_idx])
|
next_cpu_block_id = cpu_block_table[next_block_idx]
|
||||||
|
offload_engine.load_to_slot_layer(current_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id, is_prefill=False)
|
||||||
|
|
||||||
# Merge with accumulated
|
# Merge with accumulated
|
||||||
with torch.cuda.stream(compute_stream):
|
with torch.cuda.stream(compute_stream):
|
||||||
|
|||||||
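Review note: the hunks in this file compute attention per block and then merge partial results using their log-sum-exp (LSE) values. The diff does not show the bodies of `flash_attn_with_lse` or `merge_attention_outputs`, so the following is only a reference sketch of the standard LSE merge, written in plain Python for one query vector and one head. The names `attend_block` and `merge_outputs` are hypothetical, and the softmax scale is omitted for brevity.

```python
import math
from typing import List, Sequence, Tuple


def attend_block(
    q: Sequence[float],
    keys: Sequence[Sequence[float]],
    values: Sequence[Sequence[float]],
) -> Tuple[List[float], float]:
    """Softmax attention of one query over one KV block.

    Returns the block-local output and the log-sum-exp of the scores.
    """
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    lse = m + math.log(denom)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(dim)]
    return out, lse


def merge_outputs(
    o1: Sequence[float], lse1: float, o2: Sequence[float], lse2: float
) -> Tuple[List[float], float]:
    """Combine two partial attention results into the result over both blocks."""
    m = max(lse1, lse2)
    lse = m + math.log(math.exp(lse1 - m) + math.exp(lse2 - m))
    w1 = math.exp(lse1 - lse)  # share of the softmax mass held by block 1
    w2 = math.exp(lse2 - lse)  # share of the softmax mass held by block 2
    return [w1 * a + w2 * b for a, b in zip(o1, o2)], lse
```

Merging block results this way gives the same output as attending over the concatenated keys and values, which is why blocks can be streamed through a small number of GPU slots one at a time.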
@@ -108,12 +108,45 @@ class SparsePolicy(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def alloc_policy_metadata(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
enable_cpu_offload: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Pre-allocate GPU buffers for policy computation.
|
||||||
|
|
||||||
|
Called by the framework after KV cache allocation. Implementations should
|
||||||
|
use enable_cpu_offload to decide which buffers to allocate:
|
||||||
|
- Offload mode: allocate chunked prefill buffers (mask, KV chunking stats)
|
||||||
|
- GPU-only mode: additionally allocate GQA expansion buffers
|
||||||
|
|
||||||
|
This is separate from initialize(), which is used for CPU offload metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_heads: Number of query heads
|
||||||
|
num_kv_heads: Number of KV heads (for GQA)
|
||||||
|
head_dim: Dimension per head
|
||||||
|
max_seq_len: Maximum sequence length (for buffer sizing)
|
||||||
|
dtype: Data type (typically float16/bfloat16)
|
||||||
|
device: Target device (cuda)
|
||||||
|
enable_cpu_offload: Whether CPU offload is enabled
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def select_blocks(
|
def select_blocks(
|
||||||
self,
|
self,
|
||||||
available_blocks: List[int],
|
available_blocks: List[int],
|
||||||
offload_engine: "OffloadEngine",
|
offload_engine: "OffloadEngine",
|
||||||
ctx: PolicyContext,
|
ctx: PolicyContext,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Select which KV blocks to load for the current query chunk.
|
Select which KV blocks to load for the current query chunk.
|
||||||
@@ -130,6 +163,8 @@ class SparsePolicy(ABC):
|
|||||||
to load KV to make selection decisions).
|
to load KV to make selection decisions).
|
||||||
ctx: PolicyContext with information about the current query
|
ctx: PolicyContext with information about the current query
|
||||||
chunk, layer, phase (prefill/decode), etc.
|
chunk, layer, phase (prefill/decode), etc.
|
||||||
|
q: Query tensor [seq_len, num_heads, head_dim] for current chunk
|
||||||
|
k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of block IDs to load (must be a subset of available_blocks).
|
List of block IDs to load (must be a subset of available_blocks).
|
||||||
@@ -191,6 +226,87 @@ class SparsePolicy(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# GPU-only methods (non-chunked)
|
||||||
|
# These methods are used when all KV cache is on GPU, no CPU offload needed.
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
cu_seqlens_q: torch.Tensor,
|
||||||
|
cu_seqlens_k: torch.Tensor,
|
||||||
|
max_seqlen_q: int,
|
||||||
|
max_seqlen_k: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
layer_id: int,
|
||||||
|
block_tables: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute GPU-only prefill attention (non-chunked).
|
||||||
|
|
||||||
|
This method is used when all KV cache resides on GPU (no CPU offload).
|
||||||
|
Override this to implement sparse prefill attention for GPU-only mode.
|
||||||
|
Default implementation raises NotImplementedError.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: [total_q, num_heads, head_dim] query tensor (packed variable length)
|
||||||
|
k: [total_kv, num_kv_heads, head_dim] key tensor
|
||||||
|
v: [total_kv, num_kv_heads, head_dim] value tensor
|
||||||
|
cu_seqlens_q: [batch+1] cumulative sequence lengths for queries
|
||||||
|
cu_seqlens_k: [batch+1] cumulative sequence lengths for keys
|
||||||
|
max_seqlen_q: maximum query sequence length
|
||||||
|
max_seqlen_k: maximum key sequence length
|
||||||
|
softmax_scale: softmax scaling factor (1/sqrt(head_dim))
|
||||||
|
layer_id: transformer layer index
|
||||||
|
block_tables: [batch, max_blocks] paged attention block tables (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[total_q, num_heads, head_dim] attention output
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__.__name__} does not implement compute_prefill for GPU-only mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k_cache: torch.Tensor,
|
||||||
|
v_cache: torch.Tensor,
|
||||||
|
cache_seqlens: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
layer_id: int,
|
||||||
|
block_tables: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute GPU-only decode attention (non-chunked).
|
||||||
|
|
||||||
|
This method is used when all KV cache resides on GPU (no CPU offload).
|
||||||
|
Override this to implement sparse decode attention for GPU-only mode.
|
||||||
|
Default implementation raises NotImplementedError.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: [batch, num_heads, head_dim] query tensor (single token per sequence)
|
||||||
|
k_cache: [num_blocks, block_size, num_kv_heads, head_dim] paged key cache
|
||||||
|
v_cache: [num_blocks, block_size, num_kv_heads, head_dim] paged value cache
|
||||||
|
cache_seqlens: [batch] sequence lengths in cache
|
||||||
|
softmax_scale: softmax scaling factor (1/sqrt(head_dim))
|
||||||
|
layer_id: transformer layer index
|
||||||
|
block_tables: [batch, max_blocks] paged attention block tables (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[batch, 1, num_heads, head_dim] attention output
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__.__name__} does not implement compute_decode for GPU-only mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Chunked offload methods (for CPU offload mode)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def compute_chunked_prefill(
|
def compute_chunked_prefill(
|
||||||
self,
|
self,
|
||||||
@@ -204,17 +320,20 @@ class SparsePolicy(ABC):
|
|||||||
current_chunk_idx: int,
|
current_chunk_idx: int,
|
||||||
seq: "Sequence",
|
seq: "Sequence",
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
|
selected_blocks: List[int],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute chunked prefill attention (complete flow).
|
Compute chunked prefill attention (complete flow).
|
||||||
|
|
||||||
This is the main entry point for prefill attention computation.
|
This is the main entry point for prefill attention computation.
|
||||||
It defines the complete prefill flow:
|
It defines the complete prefill flow:
|
||||||
1. Get historical blocks
|
1. Load and compute historical blocks via offload_engine (using selected_blocks)
|
||||||
2. Select blocks (call select_blocks)
|
2. Get current chunk KV from offload_engine, compute attention
|
||||||
3. Load and compute historical blocks via offload_engine
|
3. Merge all results
|
||||||
4. Get current chunk KV from offload_engine, compute attention
|
|
||||||
5. Merge all results
|
Note: Block selection (select_blocks) is called by the caller (attention.py)
|
||||||
|
before invoking this method. The selected_blocks parameter contains the
|
||||||
|
filtered block IDs to process.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
q: [seq_len, num_heads, head_dim] query for current chunk
|
q: [seq_len, num_heads, head_dim] query for current chunk
|
||||||
@@ -227,6 +346,7 @@ class SparsePolicy(ABC):
|
|||||||
current_chunk_idx: current chunk index
|
current_chunk_idx: current chunk index
|
||||||
seq: Sequence object
|
seq: Sequence object
|
||||||
num_tokens: number of tokens in current chunk
|
num_tokens: number of tokens in current chunk
|
||||||
|
selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[seq_len, num_heads, head_dim] final attention output
|
[seq_len, num_heads, head_dim] final attention output
|
||||||
@@ -242,17 +362,20 @@ class SparsePolicy(ABC):
|
|||||||
offload_engine: "OffloadEngine",
|
offload_engine: "OffloadEngine",
|
||||||
kvcache_manager: "KVCacheManager",
|
kvcache_manager: "KVCacheManager",
|
||||||
seq: "Sequence",
|
seq: "Sequence",
|
||||||
|
selected_blocks: List[int],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute chunked decode attention (complete flow).
|
Compute chunked decode attention (complete flow).
|
||||||
|
|
||||||
This is the main entry point for decode attention computation.
|
This is the main entry point for decode attention computation.
|
||||||
It defines the complete decode flow:
|
It defines the complete decode flow:
|
||||||
1. Get prefilled blocks from CPU
|
1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
|
||||||
2. Select blocks (call select_blocks)
|
2. Read accumulated decode tokens from decode buffer
|
||||||
3. Load blocks via pipeline (ring buffer or cross-layer)
|
3. Merge all results
|
||||||
4. Read accumulated decode tokens from decode buffer
|
|
||||||
5. Merge all results
|
Note: Block selection (select_blocks) is called by the caller (attention.py)
|
||||||
|
before invoking this method. The selected_blocks parameter contains the
|
||||||
|
filtered block IDs to process.
|
||||||
|
|
||||||
The decode position information can be computed internally:
|
The decode position information can be computed internally:
|
||||||
- decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
|
- decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
|
||||||
@@ -265,6 +388,7 @@ class SparsePolicy(ABC):
|
|||||||
offload_engine: OffloadEngine for loading blocks
|
offload_engine: OffloadEngine for loading blocks
|
||||||
kvcache_manager: KVCacheManager for block management
|
kvcache_manager: KVCacheManager for block management
|
||||||
seq: Sequence object
|
seq: Sequence object
|
||||||
|
selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[batch_size, 1, num_heads, head_dim] final attention output
|
[batch_size, 1, num_heads, head_dim] final attention output
|
||||||
|
|||||||
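Review note: after this change `select_blocks` receives `q` and `k` explicitly, and the caller passes the result on as `selected_blocks`. The sketch below shows the shape of a policy under that contract, using simplified stand-ins rather than the real classes. `SketchContext`, `SketchPolicy`, and `RecentBlocksPolicy` are hypothetical names, `q` and `k` are typed as `Any` so the example runs without torch, and the only rule taken from the docstring is that the returned IDs must be a subset of `available_blocks`.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List


@dataclass
class SketchContext:
    """Hypothetical stand-in for PolicyContext with only a few fields."""
    layer_id: int
    is_prefill: bool
    block_size: int


class SketchPolicy(ABC):
    """Simplified stand-in for the SparsePolicy selection interface."""

    @abstractmethod
    def select_blocks(
        self,
        available_blocks: List[int],
        offload_engine: Any,
        ctx: SketchContext,
        q: Any,
        k: Any,
    ) -> List[int]:
        """Return the block IDs to load (a subset of available_blocks)."""


class RecentBlocksPolicy(SketchPolicy):
    """Illustrative policy: keep only the most recent max_blocks blocks."""

    def __init__(self, max_blocks: int) -> None:
        if max_blocks < 1:
            raise ValueError("max_blocks must be at least 1")
        self.max_blocks = max_blocks

    def select_blocks(self, available_blocks, offload_engine, ctx, q, k):
        # q and k are unused here; a query-aware policy would score blocks with them.
        return list(available_blocks[-self.max_blocks:])


def is_valid_selection(available_blocks: List[int], selected: List[int]) -> bool:
    """Check the subset contract stated in the select_blocks docstring."""
    return set(selected).issubset(available_blocks)
```

Because selection now happens in the caller, a compute method only has to trust that `selected_blocks` satisfies this subset contract.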
@@ -191,13 +191,26 @@ class QuestPolicy(SparsePolicy):
|
|||||||
def select_blocks(
|
def select_blocks(
|
||||||
self,
|
self,
|
||||||
available_blocks: List[int],
|
available_blocks: List[int],
|
||||||
|
offload_engine: "OffloadEngine",
|
||||||
ctx: PolicyContext,
|
ctx: PolicyContext,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Select Top-K blocks based on query-key similarity bounds.
|
Select Top-K blocks based on query-key similarity bounds.
|
||||||
|
|
||||||
If query is not available (some prefill scenarios), falls back
|
If query is not available (some prefill scenarios), falls back
|
||||||
to loading all blocks.
|
to loading all blocks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
available_blocks: List of CPU block IDs
|
||||||
|
offload_engine: OffloadEngine for loading KV (unused in Quest)
|
||||||
|
ctx: PolicyContext with metadata
|
||||||
|
q: Query tensor [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
|
||||||
|
k: Key tensor [seq_len, num_kv_heads, head_dim] (unused in Quest, uses metadata instead)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Selected block IDs
|
||||||
"""
|
"""
|
||||||
if self.metadata is None:
|
if self.metadata is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@@ -211,7 +224,7 @@ class QuestPolicy(SparsePolicy):
|
|||||||
if n <= self.config.threshold_blocks:
|
if n <= self.config.threshold_blocks:
|
||||||
return available_blocks
|
return available_blocks
|
||||||
|
|
||||||
if ctx.query is None:
|
if q is None:
|
||||||
# No query available - cannot compute scores
|
# No query available - cannot compute scores
|
||||||
return available_blocks
|
return available_blocks
|
||||||
|
|
||||||
@@ -221,11 +234,10 @@ class QuestPolicy(SparsePolicy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Metadata is already on GPU, same device as query
|
# Metadata is already on GPU, same device as query
|
||||||
device = ctx.query.device
|
device = q.device
|
||||||
|
|
||||||
# Compute upper bound scores
|
# Compute upper bound scores
|
||||||
# query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
|
# query shape: [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
|
||||||
q = ctx.query
|
|
||||||
if q.dim() == 4:
|
if q.dim() == 4:
|
||||||
# Prefill: use mean over sequence length
|
# Prefill: use mean over sequence length
|
||||||
q = q.mean(dim=1) # [1, num_heads, head_dim]
|
q = q.mean(dim=1) # [1, num_heads, head_dim]
|
||||||
|
|||||||
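Review note: the QuestPolicy docstring above describes Top-K selection based on query-key similarity bounds, with a fallback to all blocks when the block count is at or below `threshold_blocks` or when no query is available. The scoring code itself is not visible in this diff, so the sketch below is an assumption based on the usual Quest-style bound: per block, keep the element-wise min and max of the keys, and bound the dot product per channel. The function names and the `block_bounds` mapping are hypothetical.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Bounds = Tuple[Sequence[float], Sequence[float]]  # (per-channel key min, key max)


def upper_bound_score(q: Sequence[float], k_min: Sequence[float], k_max: Sequence[float]) -> float:
    """Upper bound on q·k for any key whose channels lie within [k_min, k_max]."""
    return sum(max(qd * lo, qd * hi) for qd, lo, hi in zip(q, k_min, k_max))


def select_top_k(
    available_blocks: List[int],
    q: Optional[Sequence[float]],
    block_bounds: Dict[int, Bounds],
    top_k: int,
    threshold_blocks: int,
) -> List[int]:
    """Keep the top_k blocks by upper-bound score, preserving block order."""
    # Fallbacks described in the docstring: few blocks, or no query available.
    if len(available_blocks) <= threshold_blocks or q is None:
        return list(available_blocks)

    ranked = sorted(
        available_blocks,
        key=lambda b: upper_bound_score(q, *block_bounds[b]),
        reverse=True,
    )
    keep = set(ranked[:top_k])
    return [b for b in available_blocks if b in keep]
```

Returning the kept blocks in their original order matters for callers that treat the last selected block specially, as the decode path in this change set does.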
File diff suppressed because it is too large
@@ -5,6 +5,7 @@ from torch import nn
|
|||||||
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
from nanovllm.utils.context import get_context
|
from nanovllm.utils.context import get_context
|
||||||
|
from nanovllm.kvcache.sparse.policy import PolicyContext
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -103,50 +104,67 @@ class Attention(nn.Module):
|
|||||||
# This enables fully async offloads since each layer has its own buffer.
|
# This enables fully async offloads since each layer has its own buffer.
|
||||||
offload_engine = context.kvcache_manager.offload_engine
|
offload_engine = context.kvcache_manager.offload_engine
|
||||||
compute_stream = offload_engine.compute_stream
|
compute_stream = offload_engine.compute_stream
|
||||||
|
chunk_idx = getattr(context, 'current_chunk_idx', -1)
|
||||||
|
|
||||||
# Wait for default stream to ensure slot_mapping tensor transfer is complete
|
# Wait for default stream to ensure slot_mapping tensor transfer is complete
|
||||||
compute_stream.wait_stream(torch.cuda.default_stream())
|
compute_stream.wait_stream(torch.cuda.default_stream())
|
||||||
|
|
||||||
with torch.cuda.stream(compute_stream):
|
with torch.cuda.stream(compute_stream):
|
||||||
# Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
|
# Write KV to per-layer prefill buffer via offload_engine
|
||||||
# k, v shape: [num_tokens, kv_heads, head_dim]
|
# k, v shape: [num_tokens, kv_heads, head_dim]
|
||||||
num_tokens = k.shape[0]
|
# GPU-to-GPU copy
|
||||||
offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
|
offload_engine.write_to_prefill_buffer(self.layer_id, k, v, chunk_idx=chunk_idx)
|
||||||
offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
|
|
||||||
elif is_chunked_offload:
|
elif is_chunked_offload:
|
||||||
# Chunked decode mode: use compute_stream for store_kvcache
|
# Chunked decode mode: KV is not stored here
|
||||||
# This ensures proper synchronization with per-layer offload
|
# KV will be written to decode buffer in the decode branch below
|
||||||
compute_stream = context.kvcache_manager.offload_engine.compute_stream
|
# No store_kvcache needed - all KV management goes through offload_engine
|
||||||
if k_cache.numel() and v_cache.numel():
|
pass
|
||||||
# CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
|
|
||||||
# slot_mapping is created with non_blocking=True on default stream, but we use it
|
|
||||||
# on compute_stream. Without this sync, index_copy_ can get corrupted indices.
|
|
||||||
compute_stream.wait_stream(torch.cuda.default_stream())
|
|
||||||
with torch.cuda.stream(compute_stream):
|
|
||||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
|
||||||
else:
|
else:
|
||||||
# Normal mode: store on default stream
|
# Normal mode: store on default stream
|
||||||
if k_cache.numel() and v_cache.numel():
|
if k_cache.numel() and v_cache.numel():
|
||||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||||
|
|
||||||
|
# Get sparse_policy from kvcache_manager (required, never None after warmup)
|
||||||
|
# During warmup, kvcache_manager is not yet allocated
|
||||||
|
if context.kvcache_manager is None:
|
||||||
|
# Warmup phase: use flash_attn directly
|
||||||
|
if context.is_prefill:
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||||
|
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||||
|
softmax_scale=self.scale, causal=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return flash_attn_with_kvcache(
|
||||||
|
q.unsqueeze(1), k_cache, v_cache,
|
||||||
|
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||||
|
softmax_scale=self.scale, causal=True,
|
||||||
|
)
|
||||||
|
sparse_policy = context.kvcache_manager.sparse_policy
|
||||||
|
assert sparse_policy is not None, "sparse_policy must not be None"
|
||||||
|
|
||||||
if context.is_prefill:
|
if context.is_prefill:
|
||||||
if context.is_chunked_prefill:
|
if context.is_chunked_prefill:
|
||||||
# Chunked prefill: merge attention from previous KV
|
# Chunked prefill: merge attention from previous KV (CPU offload mode)
|
||||||
o = self._chunked_prefill_attention(q, k, v, context)
|
o = self._chunked_prefill_attention(q, k, v, context)
|
||||||
elif context.block_tables is not None: # prefix cache
|
|
||||||
k, v = k_cache, v_cache
|
|
||||||
o = flash_attn_varlen_func(q, k, v,
|
|
||||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
|
||||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
|
||||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
|
||||||
else:
|
else:
|
||||||
o = flash_attn_varlen_func(q, k, v,
|
# GPU-only mode: use policy for attention
|
||||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
# Use paged attention if block_tables provided, else use k, v directly
|
||||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
if context.block_tables is not None:
|
||||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
k_for_attn, v_for_attn = k_cache, v_cache
|
||||||
|
else:
|
||||||
|
k_for_attn, v_for_attn = k, v
|
||||||
|
o = sparse_policy.compute_prefill(
|
||||||
|
q, k_for_attn, v_for_attn,
|
||||||
|
context.cu_seqlens_q, context.cu_seqlens_k,
|
||||||
|
context.max_seqlen_q, context.max_seqlen_k,
|
||||||
|
self.scale, self.layer_id,
|
||||||
|
context.block_tables,
|
||||||
|
)
|
||||||
else: # decode
|
else: # decode
|
||||||
if context.is_chunked_prefill:
|
if context.is_chunked_prefill:
|
||||||
# Chunked decode: need to load all KV from CPU+GPU
|
# Chunked decode: need to load all KV from CPU+GPU (CPU offload mode)
|
||||||
# Store current decode token to per-layer decode buffer
|
# Store current decode token to per-layer decode buffer
|
||||||
# This is needed because GPU cache has no layer dimension,
|
# This is needed because GPU cache has no layer dimension,
|
||||||
# so all layers would overwrite each other in decode_slot.
|
# so all layers would overwrite each other in decode_slot.
|
||||||
@@ -154,13 +172,15 @@ class Attention(nn.Module):
|
|||||||
offload_engine = kvcache_manager.offload_engine
|
offload_engine = kvcache_manager.offload_engine
|
||||||
pos_in_block = context.decode_pos_in_block
|
pos_in_block = context.decode_pos_in_block
|
||||||
# k, v shape: [1, kv_heads, head_dim]
|
# k, v shape: [1, kv_heads, head_dim]
|
||||||
offload_engine.decode_k_buffer[self.layer_id, pos_in_block].copy_(k.squeeze(0))
|
offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0))
|
||||||
offload_engine.decode_v_buffer[self.layer_id, pos_in_block].copy_(v.squeeze(0))
|
|
||||||
o = self._chunked_decode_attention(q, k, v, context)
|
o = self._chunked_decode_attention(q, k, v, context)
|
||||||
else:
|
else:
|
||||||
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
# GPU-only mode: use policy for attention
|
||||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
o = sparse_policy.compute_decode(
|
||||||
softmax_scale=self.scale, causal=True)
|
q, k_cache, v_cache,
|
||||||
|
context.context_lens, self.scale, self.layer_id,
|
||||||
|
context.block_tables,
|
||||||
|
)
|
||||||
return o
|
return o
|
||||||
|
|
||||||
def _chunked_prefill_attention(
|
def _chunked_prefill_attention(
|
||||||
@@ -197,11 +217,29 @@ class Attention(nn.Module):
|
|||||||
if sparse_policy is None:
|
if sparse_policy is None:
|
||||||
raise RuntimeError("sparse_policy is required for chunked prefill")
|
raise RuntimeError("sparse_policy is required for chunked prefill")
|
||||||
|
|
||||||
|
# Step 1: Get historical CPU blocks
|
||||||
|
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||||
|
|
||||||
|
# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill)
|
||||||
|
# Always call select_blocks even for first chunk (cpu_block_table may be empty)
|
||||||
|
num_chunks = current_chunk_idx + 1
|
||||||
|
policy_ctx = PolicyContext(
|
||||||
|
query_chunk_idx=current_chunk_idx,
|
||||||
|
num_query_chunks=num_chunks,
|
||||||
|
layer_id=self.layer_id,
|
||||||
|
query=q, # Pass query for sparse policies that need it
|
||||||
|
is_prefill=True,
|
||||||
|
block_size=kvcache_manager.block_size,
|
||||||
|
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size if cpu_block_table else 0,
|
||||||
|
)
|
||||||
|
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k)
|
||||||
|
logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
|
||||||
|
|
||||||
# [DEBUG] Verify execution path
|
# [DEBUG] Verify execution path
|
||||||
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
|
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
|
||||||
f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")
|
f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")
|
||||||
|
|
||||||
# Delegate all computation to policy (no flash_attn or merge calls here!)
|
# Delegate computation to policy with pre-selected blocks
|
||||||
final_o = sparse_policy.compute_chunked_prefill(
|
final_o = sparse_policy.compute_chunked_prefill(
|
||||||
q, k, v,
|
q, k, v,
|
||||||
self.layer_id,
|
self.layer_id,
|
||||||
@@ -211,6 +249,7 @@ class Attention(nn.Module):
|
|||||||
current_chunk_idx,
|
current_chunk_idx,
|
||||||
seq,
|
seq,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
|
selected_blocks,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.cuda.nvtx.range_pop() # ChunkedPrefill
|
torch.cuda.nvtx.range_pop() # ChunkedPrefill
|
||||||
@@ -258,14 +297,36 @@ class Attention(nn.Module):
|
|||||||
raise RuntimeError("sparse_policy is required for chunked decode")
|
raise RuntimeError("sparse_policy is required for chunked decode")
|
||||||
|
|
||||||
# Check if policy supports decode phase
|
# Check if policy supports decode phase
|
||||||
|
# If not, fall back to FullAttentionPolicy (e.g., XAttentionBSAPolicy only supports prefill)
|
||||||
if not sparse_policy.supports_decode:
|
if not sparse_policy.supports_decode:
|
||||||
raise RuntimeError(f"{sparse_policy} does not support decode phase")
|
from nanovllm.kvcache.sparse import FullAttentionPolicy
|
||||||
|
sparse_policy = FullAttentionPolicy()
|
||||||
|
logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, "
|
||||||
|
f"falling back to FullAttentionPolicy")
|
||||||
|
|
||||||
|
# Step 1: Get prefilled CPU blocks
|
||||||
|
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||||
|
|
||||||
|
# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_decode)
|
||||||
|
selected_blocks = []
|
||||||
|
if cpu_block_table:
|
||||||
|
policy_ctx = PolicyContext(
|
||||||
|
query_chunk_idx=0,
|
||||||
|
num_query_chunks=1,
|
||||||
|
layer_id=self.layer_id,
|
||||||
|
query=q, # Pass query for sparse policies that need it
|
||||||
|
is_prefill=False,
|
||||||
|
block_size=kvcache_manager.block_size,
|
||||||
|
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||||
|
)
|
||||||
|
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k)
|
||||||
|
logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
|
||||||
|
|
||||||
# [DEBUG] Verify execution path
|
# [DEBUG] Verify execution path
|
||||||
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
|
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
|
||||||
f"policy={sparse_policy}, layer={self.layer_id}")
|
f"policy={sparse_policy}, layer={self.layer_id}")
|
||||||
|
|
||||||
# Delegate all computation to policy (no flash_attn or merge calls here!)
|
# Delegate computation to policy with pre-selected blocks
|
||||||
return sparse_policy.compute_chunked_decode(
|
return sparse_policy.compute_chunked_decode(
|
||||||
q,
|
q,
|
||||||
self.layer_id,
|
self.layer_id,
|
||||||
@@ -273,4 +334,5 @@ class Attention(nn.Module):
|
|||||||
offload_engine,
|
offload_engine,
|
||||||
kvcache_manager,
|
kvcache_manager,
|
||||||
seq,
|
seq,
|
||||||
|
selected_blocks,
|
||||||
)
|
)
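The decode path in this hunk does two things before delegating: it swaps in a full-attention policy when the configured policy has no decode support, and it pre-selects blocks only when some blocks have already been offloaded. A minimal sketch of that control flow follows. The classes here are hypothetical stand-ins (the real `PolicyContext` and policies carry more fields and arguments), and the block size is an arbitrary illustrative value.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class PolicyContext:
    """Hypothetical, trimmed-down stand-in for the PolicyContext in the diff."""
    layer_id: int
    is_prefill: bool
    block_size: int
    total_kv_len: int


class FullPolicy:
    """Stand-in for FullAttentionPolicy: keeps every block."""
    supports_decode = True

    def select_blocks(self, blocks: List[int], ctx: PolicyContext) -> List[int]:
        return list(blocks)


class PrefillOnlyPolicy:
    """Stand-in for a prefill-only sparse policy: keeps every other block."""
    supports_decode = False

    def select_blocks(self, blocks: List[int], ctx: PolicyContext) -> List[int]:
        return blocks[::2]


def select_for_decode(policy, cpu_blocks: List[int], layer_id: int,
                      block_size: int) -> Tuple[object, List[int]]:
    """Mirror the decode-side flow: fall back first, then pre-select blocks."""
    if not policy.supports_decode:
        policy = FullPolicy()
    selected: List[int] = []
    if cpu_blocks:  # selection is skipped when nothing has been offloaded yet
        ctx = PolicyContext(layer_id, False, block_size, len(cpu_blocks) * block_size)
        selected = policy.select_blocks(cpu_blocks, ctx)
    return policy, selected


policy, selected = select_for_decode(PrefillOnlyPolicy(), [3, 7, 9], layer_id=0, block_size=4096)
print(type(policy).__name__, selected)
```

Because the fallback happens before selection, the prefill-only policy's own `select_blocks` is never consulted during decode in this sketch.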
|
||||||
|
|||||||
572
nanovllm/layers/graphed_layers.py
Normal file
@@ -0,0 +1,572 @@
|
|||||||
|
"""
|
||||||
|
CUDA Graph-wrapped layers for offload optimization.
|
||||||
|
|
||||||
|
This module provides Graph-wrapped versions of non-attention layers
|
||||||
|
to reduce kernel launch overhead in the CPU offload path.
|
||||||
|
|
||||||
|
Phase 5 Design:
|
||||||
|
- Supports both Prefill (seq_len=chunk_size) and Decode (seq_len=1)
|
||||||
|
- Extended coverage: embed, input_norm, qkv_proj, rotary, o_proj, post_norm, mlp, final_norm
|
||||||
|
- Only the attention core (attn.forward) remains in eager mode
|
||||||
|
|
||||||
|
Graph Structure (N layers):
|
||||||
|
- EmbedGraph: embed_tokens
|
||||||
|
- FirstGraph: input_norm → qkv_proj → rotary
|
||||||
|
- InterGraph[i]: o_proj → post_norm → mlp → input_norm → qkv_proj → rotary (N-1 graphs)
|
||||||
|
- LastGraph: o_proj → post_norm → mlp → final_norm
|
||||||
|
Total: N+2 graphs
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedGraph(nn.Module):
|
||||||
|
"""
|
||||||
|
Graph wrapper for embedding layer.
|
||||||
|
|
||||||
|
Input: input_ids [seq_len]
|
||||||
|
Output: hidden_states [seq_len, hidden_size]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_tokens: nn.Module,
|
||||||
|
seq_len: int,
|
||||||
|
hidden_size: int,
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_tokens = embed_tokens
|
||||||
|
self.seq_len = seq_len
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
# Graph state
|
||||||
|
self.graph: Optional[torch.cuda.CUDAGraph] = None
|
||||||
|
self.ids_in: Optional[torch.Tensor] = None
|
||||||
|
self.h_out: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
def _compute(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
def capture_graph(self, graph_pool=None):
|
||||||
|
"""Capture CUDA Graph."""
|
||||||
|
# Allocate placeholders outside inference_mode
|
||||||
|
self.ids_in = torch.zeros(self.seq_len, dtype=torch.long, device="cuda")
|
||||||
|
self.h_out = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
# Warmup
|
||||||
|
for _ in range(3):
|
||||||
|
h = self._compute(self.ids_in)
|
||||||
|
self.h_out.copy_(h)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Capture
|
||||||
|
self.graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(self.graph, pool=graph_pool):
|
||||||
|
h = self._compute(self.ids_in)
|
||||||
|
self.h_out.copy_(h)
|
||||||
|
|
||||||
|
return self.graph.pool() if graph_pool is None else graph_pool
|
||||||
|
|
||||||
|
def forward(self, input_ids: torch.Tensor, use_graph: bool = False) -> torch.Tensor:
|
||||||
|
if use_graph and self.graph is not None and input_ids.shape[0] == self.seq_len:
|
||||||
|
self.ids_in.copy_(input_ids)
|
||||||
|
self.graph.replay()
|
||||||
|
return self.h_out.clone()
|
||||||
|
else:
|
||||||
|
return self._compute(input_ids)
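The `forward` above replays the graph only when the input length matches the captured `seq_len`, and it returns a clone of the static output buffer. The pure-Python stand-in below (no CUDA involved, all names hypothetical) illustrates why the copy matters: a replay always writes into the same output buffer, so a caller holding the buffer itself would see its result overwritten by the next call.

```python
from typing import Callable, List


class StaticBufferRunner:
    """Pure-Python stand-in for a captured graph with fixed-size buffers."""

    def __init__(self, fn: Callable[[float], float], seq_len: int):
        self.fn = fn
        self.seq_len = seq_len
        self.x_in: List[float] = [0.0] * seq_len   # plays the role of ids_in
        self.y_out: List[float] = [0.0] * seq_len  # plays the role of h_out

    def _replay(self) -> None:
        # A real replay reruns recorded kernels on the same memory addresses.
        for i in range(self.seq_len):
            self.y_out[i] = self.fn(self.x_in[i])

    def forward(self, x: List[float], use_graph: bool = False) -> List[float]:
        if use_graph and len(x) == self.seq_len:
            self.x_in[:] = x          # like ids_in.copy_(input_ids)
            self._replay()            # like graph.replay()
            return list(self.y_out)   # like h_out.clone()
        return [self.fn(v) for v in x]  # eager path for any other length


runner = StaticBufferRunner(lambda v: 2.0 * v, seq_len=3)
first = runner.forward([1.0, 2.0, 3.0], use_graph=True)
second = runner.forward([10.0, 20.0, 30.0], use_graph=True)
eager = runner.forward([5.0], use_graph=True)  # length mismatch, runs eagerly
print(first, second, eager)
```

Here `first` keeps its values after the second call only because `forward` hands back a copy rather than the shared buffer.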
|
||||||
|
|
||||||
|
|
||||||
|
class FirstGraph(nn.Module):
|
||||||
|
"""
|
||||||
|
Graph wrapper for first layer pre-attention:
|
||||||
|
input_norm → qkv_proj → split → reshape → rotary
|
||||||
|
|
||||||
|
Input: hidden_states [seq_len, hidden_size], positions [seq_len]
|
||||||
|
Output: q [seq_len, num_heads, head_dim], k [seq_len, num_kv_heads, head_dim],
|
||||||
|
v [seq_len, num_kv_heads, head_dim], residual [seq_len, hidden_size]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_norm: nn.Module,
|
||||||
|
qkv_proj: nn.Module,
|
||||||
|
rotary_emb: nn.Module,
|
||||||
|
# Shape parameters
|
||||||
|
seq_len: int,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.input_norm = input_norm
|
||||||
|
self.qkv_proj = qkv_proj
|
||||||
|
self.rotary_emb = rotary_emb
|
||||||
|
|
||||||
|
self.seq_len = seq_len
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_kv_heads = num_kv_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
# Split sizes
|
||||||
|
self.q_size = num_heads * head_dim
|
||||||
|
self.kv_size = num_kv_heads * head_dim
|
||||||
|
|
||||||
|
# Graph state
|
||||||
|
self.graph: Optional[torch.cuda.CUDAGraph] = None
|
||||||
|
|
||||||
|
def _compute(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
First layer computation:
|
||||||
|
1. input_layernorm (residual = hidden_states for first layer)
|
||||||
|
2. QKV projection
|
||||||
|
3. Split and reshape
|
||||||
|
4. Rotary embedding
|
||||||
|
"""
|
||||||
|
# For first layer, residual = hidden_states (before norm)
|
||||||
|
residual = hidden_states.clone()
|
||||||
|
hidden_states = self.input_norm(hidden_states)
|
||||||
|
|
||||||
|
# QKV projection
|
||||||
|
qkv = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
|
||||||
|
# Reshape
|
||||||
|
q = q.view(-1, self.num_heads, self.head_dim)
|
||||||
|
k = k.view(-1, self.num_kv_heads, self.head_dim)
|
||||||
|
v = v.view(-1, self.num_kv_heads, self.head_dim)
|
||||||
|
|
||||||
|
# Rotary embedding
|
||||||
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
|
|
||||||
|
return q, k, v, residual
|
||||||
|
|
||||||
|
def capture_graph(self, graph_pool=None):
|
||||||
|
"""Capture CUDA Graph."""
|
||||||
|
# Allocate placeholders
|
||||||
|
self.h_in = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
|
||||||
|
self.pos_in = torch.zeros(self.seq_len, dtype=torch.long, device="cuda")
|
||||||
|
|
||||||
|
self.q_out = torch.zeros(self.seq_len, self.num_heads, self.head_dim, dtype=self.dtype, device="cuda")
|
||||||
|
self.k_out = torch.zeros(self.seq_len, self.num_kv_heads, self.head_dim, dtype=self.dtype, device="cuda")
|
||||||
|
self.v_out = torch.zeros(self.seq_len, self.num_kv_heads, self.head_dim, dtype=self.dtype, device="cuda")
|
||||||
|
self.r_out = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
# Warmup
|
||||||
|
for _ in range(3):
|
||||||
|
q, k, v, r = self._compute(self.h_in, self.pos_in)
|
||||||
|
self.q_out.copy_(q)
|
||||||
|
self.k_out.copy_(k)
|
||||||
|
self.v_out.copy_(v)
|
||||||
|
self.r_out.copy_(r)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Capture
|
||||||
|
self.graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(self.graph, pool=graph_pool):
|
||||||
|
q, k, v, r = self._compute(self.h_in, self.pos_in)
|
||||||
|
self.q_out.copy_(q)
|
||||||
|
self.k_out.copy_(k)
|
||||||
|
self.v_out.copy_(v)
|
||||||
|
self.r_out.copy_(r)
|
||||||
|
|
||||||
|
return self.graph.pool() if graph_pool is None else graph_pool
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
use_graph: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
if use_graph and self.graph is not None and hidden_states.shape[0] == self.seq_len:
|
||||||
|
self.h_in.copy_(hidden_states)
|
||||||
|
self.pos_in.copy_(positions)
|
||||||
|
self.graph.replay()
|
||||||
|
return self.q_out.clone(), self.k_out.clone(), self.v_out.clone(), self.r_out.clone()
|
||||||
|
else:
|
||||||
|
return self._compute(hidden_states, positions)
|
||||||
|
|
||||||
|
|
||||||
|
class InterGraph(nn.Module):
|
||||||
|
"""
|
||||||
|
Graph wrapper for inter-layer computation:
|
||||||
|
o_proj → post_norm → mlp → input_norm → qkv_proj → rotary
|
||||||
|
|
||||||
|
Merges current layer's post-attention with next layer's pre-attention.
|
||||||
|
|
||||||
|
Input: attn_output [seq_len, num_heads, head_dim], residual [seq_len, hidden_size], positions [seq_len]
|
||||||
|
Output: q [seq_len, num_heads, head_dim], k [seq_len, num_kv_heads, head_dim],
|
||||||
|
v [seq_len, num_kv_heads, head_dim], residual [seq_len, hidden_size]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
# Current layer components
|
||||||
|
o_proj: nn.Module,
|
||||||
|
post_norm: nn.Module,
|
||||||
|
mlp: nn.Module,
|
||||||
|
# Next layer components
|
||||||
|
next_input_norm: nn.Module,
|
||||||
|
next_qkv_proj: nn.Module,
|
||||||
|
next_rotary_emb: nn.Module,
|
||||||
|
# Shape parameters
|
||||||
|
seq_len: int,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# Current layer
|
||||||
|
self.o_proj = o_proj
|
||||||
|
self.post_norm = post_norm
|
||||||
|
self.mlp = mlp
|
||||||
|
|
||||||
|
# Next layer
|
||||||
|
self.next_input_norm = next_input_norm
|
||||||
|
self.next_qkv_proj = next_qkv_proj
|
||||||
|
self.next_rotary_emb = next_rotary_emb
|
||||||
|
|
||||||
|
# Shape params
|
||||||
|
self.seq_len = seq_len
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_kv_heads = num_kv_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
# Split sizes
|
||||||
|
self.q_size = num_heads * head_dim
|
||||||
|
self.kv_size = num_kv_heads * head_dim
|
||||||
|
|
||||||
|
# Graph state
|
||||||
|
self.graph: Optional[torch.cuda.CUDAGraph] = None
|
||||||
|
|
||||||
|
def _compute(
|
||||||
|
self,
|
||||||
|
attn_output: torch.Tensor, # [seq_len, num_heads, head_dim]
|
||||||
|
residual: torch.Tensor, # [seq_len, hidden_size]
|
||||||
|
positions: torch.Tensor, # [seq_len]
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Inter-layer computation:
|
||||||
|
1. O projection (flatten first)
|
||||||
|
2. Post-attention layernorm + residual
|
||||||
|
3. MLP
|
||||||
|
4. Next layer's input layernorm + residual
|
||||||
|
5. QKV projection
|
||||||
|
6. Split and reshape
|
||||||
|
7. Rotary embedding
|
||||||
|
"""
|
||||||
|
# O projection
|
||||||
|
hidden_states = self.o_proj(attn_output.flatten(1, -1))
|
||||||
|
|
||||||
|
# Post-attention of current layer
|
||||||
|
hidden_states, residual = self.post_norm(hidden_states, residual)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
|
||||||
|
# Pre-attention of next layer
|
||||||
|
hidden_states, residual = self.next_input_norm(hidden_states, residual)
|
||||||
|
|
||||||
|
# QKV projection
|
||||||
|
qkv = self.next_qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
|
||||||
|
# Reshape
|
||||||
|
q = q.view(-1, self.num_heads, self.head_dim)
|
||||||
|
k = k.view(-1, self.num_kv_heads, self.head_dim)
|
||||||
|
v = v.view(-1, self.num_kv_heads, self.head_dim)
|
||||||
|
|
||||||
|
# Rotary embedding
|
||||||
|
q, k = self.next_rotary_emb(positions, q, k)
|
||||||
|
|
||||||
|
return q, k, v, residual
|
||||||
|
|
||||||
|
def capture_graph(self, graph_pool=None):
|
||||||
|
"""Capture CUDA Graph."""
|
||||||
|
# Allocate placeholders
|
||||||
|
self.attn_in = torch.zeros(self.seq_len, self.num_heads, self.head_dim, dtype=self.dtype, device="cuda")
|
||||||
|
self.r_in = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
|
||||||
|
self.pos_in = torch.zeros(self.seq_len, dtype=torch.long, device="cuda")
|
||||||
|
|
||||||
|
self.q_out = torch.zeros(self.seq_len, self.num_heads, self.head_dim, dtype=self.dtype, device="cuda")
|
||||||
|
self.k_out = torch.zeros(self.seq_len, self.num_kv_heads, self.head_dim, dtype=self.dtype, device="cuda")
|
||||||
|
self.v_out = torch.zeros(self.seq_len, self.num_kv_heads, self.head_dim, dtype=self.dtype, device="cuda")
|
||||||
|
self.r_out = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
# Warmup
|
||||||
|
for _ in range(3):
|
||||||
|
q, k, v, r = self._compute(self.attn_in, self.r_in, self.pos_in)
|
||||||
|
self.q_out.copy_(q)
|
||||||
|
self.k_out.copy_(k)
|
||||||
|
self.v_out.copy_(v)
|
||||||
|
self.r_out.copy_(r)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Capture
|
||||||
|
self.graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(self.graph, pool=graph_pool):
|
||||||
|
q, k, v, r = self._compute(self.attn_in, self.r_in, self.pos_in)
|
||||||
|
self.q_out.copy_(q)
|
||||||
|
self.k_out.copy_(k)
|
||||||
|
self.v_out.copy_(v)
|
||||||
|
self.r_out.copy_(r)
|
||||||
|
|
||||||
|
return self.graph.pool() if graph_pool is None else graph_pool
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
attn_output: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
use_graph: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
if use_graph and self.graph is not None and attn_output.shape[0] == self.seq_len:
|
||||||
|
self.attn_in.copy_(attn_output)
|
||||||
|
self.r_in.copy_(residual)
|
||||||
|
self.pos_in.copy_(positions)
|
||||||
|
self.graph.replay()
|
||||||
|
return self.q_out.clone(), self.k_out.clone(), self.v_out.clone(), self.r_out.clone()
|
||||||
|
else:
|
||||||
|
return self._compute(attn_output, residual, positions)
|
||||||
|
|
||||||
|
|
||||||
|
class LastGraph(nn.Module):
|
||||||
|
"""
|
||||||
|
Graph wrapper for last layer:
|
||||||
|
o_proj → post_norm → mlp → final_norm
|
||||||
|
|
||||||
|
Input: attn_output [seq_len, num_heads, head_dim], residual [seq_len, hidden_size]
|
||||||
|
Output: hidden_states [seq_len, hidden_size]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
o_proj: nn.Module,
|
||||||
|
post_norm: nn.Module,
|
||||||
|
mlp: nn.Module,
|
||||||
|
final_norm: nn.Module,
|
||||||
|
# Shape parameters
|
||||||
|
seq_len: int,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.o_proj = o_proj
|
||||||
|
self.post_norm = post_norm
|
||||||
|
self.mlp = mlp
|
||||||
|
self.final_norm = final_norm
|
||||||
|
|
||||||
|
self.seq_len = seq_len
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
# Graph state
|
||||||
|
self.graph: Optional[torch.cuda.CUDAGraph] = None
|
||||||
|
|
||||||
|
def _compute(
|
||||||
|
self,
|
||||||
|
attn_output: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Last layer computation:
|
||||||
|
1. O projection
|
||||||
|
2. Post-attention layernorm + residual
|
||||||
|
3. MLP
|
||||||
|
4. Final model norm + residual
|
||||||
|
"""
|
||||||
|
hidden_states = self.o_proj(attn_output.flatten(1, -1))
|
||||||
|
hidden_states, residual = self.post_norm(hidden_states, residual)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states, _ = self.final_norm(hidden_states, residual)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def capture_graph(self, graph_pool=None):
|
||||||
|
"""Capture CUDA Graph."""
|
||||||
|
# Allocate placeholders
|
||||||
|
self.attn_in = torch.zeros(self.seq_len, self.num_heads, self.head_dim, dtype=self.dtype, device="cuda")
|
||||||
|
self.r_in = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
|
||||||
|
self.h_out = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
# Warmup
|
||||||
|
for _ in range(3):
|
||||||
|
h = self._compute(self.attn_in, self.r_in)
|
||||||
|
self.h_out.copy_(h)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Capture
|
||||||
|
self.graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(self.graph, pool=graph_pool):
|
||||||
|
h = self._compute(self.attn_in, self.r_in)
|
||||||
|
self.h_out.copy_(h)
|
||||||
|
|
||||||
|
return self.graph.pool() if graph_pool is None else graph_pool
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
attn_output: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
use_graph: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if use_graph and self.graph is not None and attn_output.shape[0] == self.seq_len:
|
||||||
|
self.attn_in.copy_(attn_output)
|
||||||
|
self.r_in.copy_(residual)
|
||||||
|
self.graph.replay()
|
||||||
|
return self.h_out.clone()
|
||||||
|
else:
|
||||||
|
return self._compute(attn_output, residual)
|
||||||
|
|
||||||
|
|
||||||
|
class OffloadGraphManager:
|
||||||
|
"""
|
||||||
|
Manager for all CUDA Graphs in the offload path.
|
||||||
|
|
||||||
|
Creates and manages N+2 graphs for an N-layer model:
|
||||||
|
- 1 EmbedGraph
|
||||||
|
- 1 FirstGraph
|
||||||
|
- N-1 InterGraphs
|
||||||
|
- 1 LastGraph
|
||||||
|
|
||||||
|
Supports both Prefill and Decode modes via seq_len parameter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
seq_len: int,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize graph manager from model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The CausalLM model (e.g., LlamaForCausalLM)
|
||||||
|
seq_len: Sequence length (1 for decode, chunk_size for prefill)
|
||||||
|
hidden_size: Model hidden dimension
|
||||||
|
num_heads: Number of attention heads
|
||||||
|
num_kv_heads: Number of KV heads
|
||||||
|
head_dim: Head dimension
|
||||||
|
dtype: Data type for tensors
|
||||||
|
"""
|
||||||
|
self.seq_len = seq_len
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_kv_heads = num_kv_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
# Access model layers
|
||||||
|
layers = model.model.layers
|
||||||
|
num_layers = len(layers)
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
# Create EmbedGraph
|
||||||
|
self.embed_graph = EmbedGraph(
|
||||||
|
embed_tokens=model.model.embed_tokens,
|
||||||
|
seq_len=seq_len,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create FirstGraph: input_norm_0 → qkv_proj_0 → rotary_0
|
||||||
|
self.first_graph = FirstGraph(
|
||||||
|
input_norm=layers[0].input_layernorm,
|
||||||
|
qkv_proj=layers[0].self_attn.qkv_proj,
|
||||||
|
rotary_emb=layers[0].self_attn.rotary_emb,
|
||||||
|
seq_len=seq_len,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
head_dim=head_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create InterGraphs: o_proj_i → post_norm_i → mlp_i → input_norm_{i+1} → qkv_proj_{i+1} → rotary_{i+1}
|
||||||
|
self.inter_graphs = nn.ModuleList()
|
||||||
|
for i in range(num_layers - 1):
|
||||||
|
self.inter_graphs.append(InterGraph(
|
||||||
|
o_proj=layers[i].self_attn.o_proj,
|
||||||
|
post_norm=layers[i].post_attention_layernorm,
|
||||||
|
mlp=layers[i].mlp,
|
||||||
|
next_input_norm=layers[i + 1].input_layernorm,
|
||||||
|
next_qkv_proj=layers[i + 1].self_attn.qkv_proj,
|
||||||
|
next_rotary_emb=layers[i + 1].self_attn.rotary_emb,
|
||||||
|
seq_len=seq_len,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
head_dim=head_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
))
|
||||||
|
|
||||||
|
# Create LastGraph: o_proj_{N-1} → post_norm_{N-1} → mlp_{N-1} → final_norm
|
||||||
|
self.last_graph = LastGraph(
|
||||||
|
o_proj=layers[-1].self_attn.o_proj,
|
||||||
|
post_norm=layers[-1].post_attention_layernorm,
|
||||||
|
mlp=layers[-1].mlp,
|
||||||
|
final_norm=model.model.norm,
|
||||||
|
seq_len=seq_len,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
head_dim=head_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.captured = False
|
||||||
|
self.graph_pool = None
|
||||||
|
|
||||||
|
def capture_all(self):
|
||||||
|
"""Capture all graphs, sharing memory pool."""
|
||||||
|
graph_pool = None
|
||||||
|
|
||||||
|
# Capture embed graph
|
||||||
|
graph_pool = self.embed_graph.capture_graph(graph_pool)
|
||||||
|
|
||||||
|
# Capture first graph
|
||||||
|
graph_pool = self.first_graph.capture_graph(graph_pool)
|
||||||
|
|
||||||
|
# Capture inter-layer graphs
|
||||||
|
for inter_graph in self.inter_graphs:
|
||||||
|
graph_pool = inter_graph.capture_graph(graph_pool)
|
||||||
|
|
||||||
|
# Capture last graph
|
||||||
|
graph_pool = self.last_graph.capture_graph(graph_pool)
|
||||||
|
|
||||||
|
self.graph_pool = graph_pool
|
||||||
|
self.captured = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_graphs(self) -> int:
|
||||||
|
"""Total number of graphs: 1 + 1 + (N-1) + 1 = N+2"""
|
||||||
|
return 1 + 1 + len(self.inter_graphs) + 1
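The N+2 count can be checked with a small pure-Python sketch that enumerates the graph stages for an N-layer model. The stage labels and the `execution_order` helper are illustrative assumptions; the interleaving with eager attention is inferred from the graph inputs and outputs described in the docstrings, not taken from a runner shown in this diff.

```python
from typing import List


def graph_stages(num_layers: int) -> List[str]:
    """List one label per graph: embed, first, N-1 inter graphs, last."""
    if num_layers < 1:
        raise ValueError("num_layers must be >= 1")
    stages = ["embed", "first[0]"]
    stages += [f"inter[{i}->{i + 1}]" for i in range(num_layers - 1)]
    stages.append(f"last[{num_layers - 1}]")
    return stages


def execution_order(num_layers: int) -> List[str]:
    """Interleave graph replays with the eager attention call of each layer."""
    stages = graph_stages(num_layers)
    order = [stages[0], stages[1]]
    # Each remaining graph consumes the attention output of one layer.
    for layer, stage in enumerate(stages[2:]):
        order.append(f"attn[{layer}] (eager)")
        order.append(stage)
    return order


for step in execution_order(3):
    print(step)
```

Each inter graph fuses the post-attention half of layer `i` with the pre-attention half of layer `i + 1`, which is why only N-1 of them are needed.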
|
||||||
|
|
||||||
|
|
||||||
|
# Legacy compatibility aliases (for gradual migration)
|
||||||
|
FirstLayerGraph = FirstGraph
|
||||||
|
InterLayerGraph = InterGraph
|
||||||
|
LastLayerGraph = LastGraph
|
||||||
@@ -8,12 +8,43 @@ def apply_rotary_emb(
|
|||||||
cos: torch.Tensor,
|
cos: torch.Tensor,
|
||||||
sin: torch.Tensor,
|
sin: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""Non-interleaved RoPE (used by Llama, Qwen, etc.)"""
|
||||||
x1, x2 = torch.chunk(x.float(), 2, dim=-1)
|
x1, x2 = torch.chunk(x.float(), 2, dim=-1)
|
||||||
y1 = x1 * cos - x2 * sin
|
y1 = x1 * cos - x2 * sin
|
||||||
y2 = x2 * cos + x1 * sin
|
y2 = x2 * cos + x1 * sin
|
||||||
return torch.cat((y1, y2), dim=-1).to(x.dtype)
|
return torch.cat((y1, y2), dim=-1).to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb_interleaved(
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Interleaved RoPE (used by GLM-4, etc.)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: [seq_len, num_heads, head_dim]
|
||||||
|
cos: [seq_len, 1, head_dim // 2]
|
||||||
|
sin: [seq_len, 1, head_dim // 2]
|
||||||
|
|
||||||
|
x is reshaped to [seq_len, num_heads, head_dim // 2, 2] where:
|
||||||
|
- x[..., 0] are even positions
|
||||||
|
- x[..., 1] are odd positions
|
||||||
|
"""
|
||||||
|
rot_dim = x.shape[-1]
|
||||||
|
# x_shaped: [seq_len, num_heads, rot_dim // 2, 2]
|
||||||
|
x_shaped = x.float().reshape(*x.shape[:-1], rot_dim // 2, 2)
|
||||||
|
# x_0, x_1: [seq_len, num_heads, rot_dim // 2]
|
||||||
|
x_0 = x_shaped[..., 0]
|
||||||
|
x_1 = x_shaped[..., 1]
|
||||||
|
# cos/sin: [seq_len, 1, rot_dim // 2] - broadcasts to num_heads
|
||||||
|
x_out = torch.stack([
|
||||||
|
x_0 * cos - x_1 * sin,
|
||||||
|
x_1 * cos + x_0 * sin,
|
||||||
|
], dim=-1)
|
||||||
|
return x_out.flatten(-2).to(x.dtype)
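The two helpers differ only in which elements form a rotation pair. A pure-Python sketch on a single 4-element vector makes the difference visible; the function names and the sample angles are illustrative, and plain lists stand in for tensors.

```python
import math
from typing import List


def rope_half_split(x: List[float], angles: List[float]) -> List[float]:
    """Non-interleaved: element i pairs with element i + d/2."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    y1 = [a * math.cos(t) - b * math.sin(t) for a, b, t in zip(x1, x2, angles)]
    y2 = [b * math.cos(t) + a * math.sin(t) for a, b, t in zip(x1, x2, angles)]
    return y1 + y2


def rope_interleaved(x: List[float], angles: List[float]) -> List[float]:
    """Interleaved: element 2i pairs with element 2i + 1."""
    out: List[float] = []
    for i, t in enumerate(angles):
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * math.cos(t) - b * math.sin(t), b * math.cos(t) + a * math.sin(t)]
    return out


x = [1.0, 0.0, 0.0, 0.0]
angles = [math.pi / 2, 0.0]  # rotate the first pair by 90 degrees, leave the second
print([round(v, 6) for v in rope_half_split(x, angles)])
print([round(v, 6) for v in rope_interleaved(x, angles)])
```

With the same input and angles, the half-split variant moves the leading 1.0 to index 2, while the interleaved variant moves it to index 1.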
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
class RotaryEmbedding(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -140,6 +171,76 @@ class Llama3RotaryEmbedding(nn.Module):
|
|||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
|
||||||
|
class GLM4RotaryEmbedding(nn.Module):
|
||||||
|
"""
|
||||||
|
GLM-4 RoPE with interleaved rotation and partial rotation.
|
||||||
|
|
||||||
|
GLM-4 uses:
|
||||||
|
- Interleaved rotation (pairs adjacent elements, not first/second half)
|
||||||
|
- rope_ratio to scale base: base = 10000 * rope_ratio
|
||||||
|
- Partial rotation: only rotates first rotary_dim elements, rest pass through
|
||||||
|
- rotary_dim = head_dim // 2 (only half of head_dim is rotated)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
head_size: int,
|
||||||
|
rotary_dim: int,
|
||||||
|
max_position_embeddings: int,
|
||||||
|
base: float,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.head_size = head_size
|
||||||
|
self.rotary_dim = rotary_dim # GLM-4: rotary_dim = head_dim // 2
|
||||||
|
# inv_freq shape: [rotary_dim // 2]
|
||||||
|
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
|
||||||
|
t = torch.arange(max_position_embeddings, dtype=torch.float)
|
||||||
|
freqs = torch.einsum("i,j -> ij", t, inv_freq) # [max_pos, rotary_dim // 2]
|
||||||
|
cos = freqs.cos()
|
||||||
|
sin = freqs.sin()
|
||||||
|
# cache shape [max_pos, 1, rotary_dim // 2, 2]
|
||||||
|
cache = torch.stack((cos, sin), dim=-1).unsqueeze_(1)
|
||||||
|
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||||
|
|
||||||
|
@torch.compile
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Apply RoPE to query and key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
positions: [seq_len]
|
||||||
|
query: [seq_len, num_heads, head_dim]
|
||||||
|
key: [seq_len, num_kv_heads, head_dim]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rotated query and key with same shapes as input.
|
||||||
|
"""
|
||||||
|
cache = self.cos_sin_cache[positions] # [seq_len, 1, rotary_dim // 2, 2]
|
||||||
|
cos = cache[..., 0] # [seq_len, 1, rotary_dim // 2]
|
||||||
|
sin = cache[..., 1] # [seq_len, 1, rotary_dim // 2]
|
||||||
|
|
||||||
|
# Split into rotated and pass-through parts
|
||||||
|
q_rot = query[..., :self.rotary_dim]
|
||||||
|
q_pass = query[..., self.rotary_dim:]
|
||||||
|
k_rot = key[..., :self.rotary_dim]
|
||||||
|
k_pass = key[..., self.rotary_dim:]
|
||||||
|
|
||||||
|
# Apply interleaved RoPE to rotated part
|
||||||
|
q_rot = apply_rotary_emb_interleaved(q_rot, cos, sin)
|
||||||
|
k_rot = apply_rotary_emb_interleaved(k_rot, cos, sin)
|
||||||
|
|
||||||
|
# Concatenate rotated and pass-through parts
|
||||||
|
query = torch.cat([q_rot, q_pass], dim=-1)
|
||||||
|
key = torch.cat([k_rot, k_pass], dim=-1)
|
||||||
|
|
||||||
|
return query, key
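Reviewer note: the forward pass above rotates only the first `rotary_dim` entries of each head and passes the rest through. The helper `apply_rotary_emb_interleaved` is not visible in this hunk, so the sketch below assumes it pairs adjacent elements `(x[2i], x[2i+1])`, which is the usual meaning of "interleaved" RoPE. It works on a single head vector in plain Python; the function name is hypothetical.

```python
import math

def rope_interleaved_partial(x, position, rotary_dim, base=10000.0):
    """Rotate the first `rotary_dim` entries of one head vector; pass the rest through.

    Assumption: adjacent elements (x[2i], x[2i+1]) form one rotation pair.
    """
    out = list(x)
    for i in range(rotary_dim // 2):
        # Same frequency schedule as inv_freq in the diff: base ** (-(2i) / rotary_dim)
        theta = position * base ** (-(2 * i) / rotary_dim)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out

# head_dim = 8, rotary_dim = head_dim // 2 = 4: the last four entries are untouched.
head = [1.0, 0.0, 0.5, 0.5, 7.0, 8.0, 9.0, 10.0]
rotated = rope_interleaved_partial(head, position=1, rotary_dim=4)
```

Each pair is rotated by a plain 2D rotation, so its norm is preserved, and position 0 leaves the vector unchanged.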
|
||||||
|
|
||||||
|
|
||||||
# Cache for RoPE instances (keyed by hashable parameters)
|
# Cache for RoPE instances (keyed by hashable parameters)
|
||||||
_rope_cache: dict[tuple, nn.Module] = {}
|
_rope_cache: dict[tuple, nn.Module] = {}
|
||||||
|
|
||||||
@@ -150,10 +251,11 @@ def get_rope(
|
|||||||
max_position: int,
|
max_position: int,
|
||||||
base: float,
|
base: float,
|
||||||
rope_scaling: dict | None = None,
|
rope_scaling: dict | None = None,
|
||||||
|
is_interleaved: bool = False,
|
||||||
):
|
):
|
||||||
# Create hashable cache key
|
# Create hashable cache key
|
||||||
if rope_scaling is None:
|
if rope_scaling is None:
|
||||||
cache_key = (head_size, rotary_dim, max_position, base, None)
|
cache_key = (head_size, rotary_dim, max_position, base, None, is_interleaved)
|
||||||
else:
|
else:
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
||||||
if rope_type == "llama3":
|
if rope_type == "llama3":
|
||||||
@@ -163,14 +265,18 @@ def get_rope(
|
|||||||
rope_scaling["low_freq_factor"],
|
rope_scaling["low_freq_factor"],
|
||||||
rope_scaling["high_freq_factor"],
|
rope_scaling["high_freq_factor"],
|
||||||
rope_scaling["original_max_position_embeddings"],
|
rope_scaling["original_max_position_embeddings"],
|
||||||
|
is_interleaved,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cache_key = (head_size, rotary_dim, max_position, base, rope_type)
|
cache_key = (head_size, rotary_dim, max_position, base, rope_type, is_interleaved)
|
||||||
|
|
||||||
if cache_key in _rope_cache:
|
if cache_key in _rope_cache:
|
||||||
return _rope_cache[cache_key]
|
return _rope_cache[cache_key]
|
||||||
|
|
||||||
if rope_scaling is None:
|
if rope_scaling is None:
|
||||||
|
if is_interleaved:
|
||||||
|
rope = GLM4RotaryEmbedding(head_size, rotary_dim, max_position, base)
|
||||||
|
else:
|
||||||
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base)
|
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base)
|
||||||
else:
|
else:
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
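Reviewer note: this hunk adds `is_interleaved` to every cache key. Without it, an interleaved and a non-interleaved model with otherwise identical parameters would share one cached instance. A minimal sketch of that behaviour follows; the stand-in classes and the function name are hypothetical, not the real modules.

```python
# Stand-ins for RotaryEmbedding and GLM4RotaryEmbedding (hypothetical, no torch needed).
class StandardRope:
    pass

class InterleavedRope:
    pass

_cache: dict[tuple, object] = {}

def get_rope_sketch(head_size, rotary_dim, max_position, base, is_interleaved=False):
    # The flag is part of the key, mirroring the tuples in the diff.
    key = (head_size, rotary_dim, max_position, base, None, is_interleaved)
    if key not in _cache:
        cls = InterleavedRope if is_interleaved else StandardRope
        _cache[key] = cls()
    return _cache[key]

glm_rope = get_rope_sketch(128, 64, 4096, 10000.0, is_interleaved=True)
llama_rope = get_rope_sketch(128, 64, 4096, 10000.0)
```

The two calls differ only in the flag, yet they return distinct objects. Repeating either call returns the cached instance.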
|
||||||
|
|||||||
@@ -3,7 +3,9 @@
|
|||||||
from nanovllm.models.registry import register_model, get_model_class, MODEL_REGISTRY
|
from nanovllm.models.registry import register_model, get_model_class, MODEL_REGISTRY
|
||||||
|
|
||||||
# Import models to trigger registration
|
# Import models to trigger registration
|
||||||
|
from nanovllm.models import qwen2
|
||||||
from nanovllm.models import qwen3
|
from nanovllm.models import qwen3
|
||||||
from nanovllm.models import llama
|
from nanovllm.models import llama
|
||||||
|
from nanovllm.models import glm4
|
||||||
|
|
||||||
__all__ = ["register_model", "get_model_class", "MODEL_REGISTRY"]
|
__all__ = ["register_model", "get_model_class", "MODEL_REGISTRY"]
|
||||||
|
|||||||
235
nanovllm/models/glm4.py
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
"""GLM-4 model implementation for nano-vllm."""
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from nanovllm.layers.activation import SiluAndMul
|
||||||
|
from nanovllm.layers.attention import Attention
|
||||||
|
from nanovllm.layers.layernorm import RMSNorm
|
||||||
|
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
|
||||||
|
from nanovllm.layers.rotary_embedding import get_rope
|
||||||
|
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
|
||||||
|
from nanovllm.models.registry import register_model
|
||||||
|
|
||||||
|
|
||||||
|
class GLM4Attention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
max_position: int = 1048576,
|
||||||
|
head_dim: int = 128,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
rope_scaling: dict | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
tp_size = dist.get_world_size()
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = num_kv_heads
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim ** -0.5
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
self.head_dim,
|
||||||
|
self.total_num_heads,
|
||||||
|
self.total_num_kv_heads,
|
||||||
|
bias=True, # GLM-4 has QKV bias
|
||||||
|
)
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
hidden_size,
|
||||||
|
bias=False, # GLM-4 has no output bias
|
||||||
|
)
|
||||||
|
# GLM-4 only rotates half of head_dim
|
||||||
|
rotary_dim = self.head_dim // 2
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
self.head_dim,
|
||||||
|
rotary_dim=rotary_dim,
|
||||||
|
max_position=max_position,
|
||||||
|
base=rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
is_interleaved=True, # GLM-4 uses interleaved RoPE
|
||||||
|
)
|
||||||
|
self.attn = Attention(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
self.num_kv_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
q = q.view(-1, self.num_heads, self.head_dim)
|
||||||
|
k = k.view(-1, self.num_kv_heads, self.head_dim)
|
||||||
|
v = v.view(-1, self.num_kv_heads, self.head_dim)
|
||||||
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
|
o = self.attn(q, k, v)
|
||||||
|
output = self.o_proj(o.flatten(1, -1))
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class GLM4MLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
[intermediate_size] * 2,
|
||||||
|
bias=False, # GLM-4 has no MLP bias
|
||||||
|
)
|
||||||
|
self.down_proj = RowParallelLinear(
|
||||||
|
intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate_up = self.gate_up_proj(x)
|
||||||
|
x = self.act_fn(gate_up)
|
||||||
|
x = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GLM4DecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config) -> None:
|
||||||
|
super().__init__()
|
||||||
|
# GLM-4 config field mapping
|
||||||
|
hidden_size = config.hidden_size
|
||||||
|
num_heads = config.num_attention_heads
|
||||||
|
num_kv_heads = getattr(config, 'multi_query_group_num', num_heads)
|
||||||
|
head_dim = getattr(config, 'kv_channels', hidden_size // num_heads)
|
||||||
|
max_position = getattr(config, 'seq_length', 1048576)
|
||||||
|
rope_ratio = getattr(config, 'rope_ratio', 1)
|
||||||
|
rope_theta = 10000 * rope_ratio # GLM-4 uses rope_ratio to scale base
|
||||||
|
intermediate_size = getattr(config, 'ffn_hidden_size', getattr(config, 'intermediate_size', None))
|
||||||
|
rms_norm_eps = getattr(config, 'layernorm_epsilon', 1e-5)
|
||||||
|
|
||||||
|
self.self_attn = GLM4Attention(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
max_position=max_position,
|
||||||
|
head_dim=head_dim,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
rope_scaling=getattr(config, "rope_scaling", None),
|
||||||
|
)
|
||||||
|
self.mlp = GLM4MLP(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
intermediate_size=intermediate_size,
|
||||||
|
)
|
||||||
|
self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: torch.Tensor | None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if residual is None:
|
||||||
|
hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
|
||||||
|
else:
|
||||||
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
|
hidden_states = self.self_attn(positions, hidden_states)
|
||||||
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
return hidden_states, residual
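Reviewer note: `GLM4DecoderLayer.__init__` translates GLM-style config field names into the values the attention layer expects, with fallbacks when a field is absent. The sketch below reproduces only that lookup logic on a plain namespace. The helper name and the example config values are illustrative assumptions, not taken from any real checkpoint.

```python
from types import SimpleNamespace

def resolve_glm4_fields(config):
    """Mirror the getattr fallbacks used in GLM4DecoderLayer (hypothetical helper)."""
    num_heads = config.num_attention_heads
    head_dim = getattr(config, "kv_channels", config.hidden_size // num_heads)
    return {
        "num_kv_heads": getattr(config, "multi_query_group_num", num_heads),
        "head_dim": head_dim,
        "rotary_dim": head_dim // 2,  # GLM4Attention rotates half of head_dim
        "max_position": getattr(config, "seq_length", 1048576),
        "rope_theta": 10000 * getattr(config, "rope_ratio", 1),
    }

# Illustrative values only.
cfg = SimpleNamespace(
    hidden_size=4096,
    num_attention_heads=32,
    multi_query_group_num=2,
    kv_channels=128,
    rope_ratio=500,
)
fields = resolve_glm4_fields(cfg)
```

With these inputs `rope_theta` resolves to `10000 * 500`. A config that omits `rope_ratio` falls back to a base of 10000.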
|
||||||
|
|
||||||
|
|
||||||
|
class GLM4Model(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config) -> None:
|
||||||
|
super().__init__()
|
||||||
|
vocab_size = getattr(config, 'padded_vocab_size', config.vocab_size)
|
||||||
|
num_layers = getattr(config, 'num_layers', config.num_hidden_layers)
|
||||||
|
rms_norm_eps = getattr(config, 'layernorm_epsilon', 1e-5)
|
||||||
|
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(vocab_size, config.hidden_size)
|
||||||
|
self.layers = nn.ModuleList([GLM4DecoderLayer(config) for _ in range(num_layers)])
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
residual = None
|
||||||
|
for layer in self.layers:
|
||||||
|
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
@register_model("ChatGLMModel", "ChatGLMForConditionalGeneration")
|
||||||
|
class ChatGLMForCausalLM(nn.Module):
|
||||||
|
"""
|
||||||
|
GLM-4 model for causal language modeling.
|
||||||
|
|
||||||
|
Weight mapping from HuggingFace to nanovllm:
|
||||||
|
- transformer.embedding.word_embeddings → model.embed_tokens
|
||||||
|
- transformer.encoder.layers.X.input_layernorm → model.layers.X.input_layernorm
|
||||||
|
- transformer.encoder.layers.X.self_attention.query_key_value → model.layers.X.self_attn.qkv_proj (split q/k/v)
|
||||||
|
- transformer.encoder.layers.X.self_attention.dense → model.layers.X.self_attn.o_proj
|
||||||
|
- transformer.encoder.layers.X.post_attention_layernorm → model.layers.X.post_attention_layernorm
|
||||||
|
- transformer.encoder.layers.X.mlp.dense_h_to_4h → model.layers.X.mlp.gate_up_proj (split gate/up)
|
||||||
|
- transformer.encoder.layers.X.mlp.dense_4h_to_h → model.layers.X.mlp.down_proj
|
||||||
|
- transformer.encoder.final_layernorm → model.norm
|
||||||
|
- transformer.output_layer → lm_head
|
||||||
|
"""
|
||||||
|
packed_modules_mapping = {
|
||||||
|
# QKV is merged in GLM-4 as query_key_value
|
||||||
|
"query_key_value": ("qkv_proj", None), # Special handling needed
|
||||||
|
# MLP gate and up are merged as dense_h_to_4h
|
||||||
|
"dense_h_to_4h": ("gate_up_proj", None), # Special handling needed
|
||||||
|
}
|
||||||
|
|
||||||
|
# Weight name mapping for loader
|
||||||
|
hf_to_nanovllm_mapping = {
|
||||||
|
"transformer.embedding.word_embeddings": "model.embed_tokens",
|
||||||
|
"transformer.encoder.final_layernorm": "model.norm",
|
||||||
|
"transformer.output_layer": "lm_head",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, config) -> None:
|
||||||
|
super().__init__()
|
||||||
|
vocab_size = getattr(config, 'padded_vocab_size', config.vocab_size)
|
||||||
|
self.config = config
|
||||||
|
self.model = GLM4Model(config)
|
||||||
|
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
|
||||||
|
# GLM-4 does not tie embeddings
|
||||||
|
# if getattr(config, 'tie_word_embeddings', False):
|
||||||
|
# self.lm_head.weight.data = self.model.embed_tokens.weight.data
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.model(input_ids, positions)
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.lm_head(hidden_states)
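Reviewer note: the class docstring lists how HuggingFace parameter names map onto nanovllm names, but the loader that applies the mapping is not part of this file. The following is a sketch of one way to do the renaming with ordered substring rewrites. `RENAMES` and `translate_name` are hypothetical names, and the real loader may work differently, in particular for the fused `query_key_value` and `dense_h_to_4h` tensors that the diff marks as needing special handling.

```python
# Rewrites copied from the weight-mapping table in the ChatGLMForCausalLM docstring.
RENAMES = [
    ("transformer.embedding.word_embeddings", "model.embed_tokens"),
    ("transformer.encoder.final_layernorm", "model.norm"),
    ("transformer.output_layer", "lm_head"),
    ("transformer.encoder.layers", "model.layers"),
    ("self_attention.query_key_value", "self_attn.qkv_proj"),
    ("self_attention.dense", "self_attn.o_proj"),
    ("mlp.dense_h_to_4h", "mlp.gate_up_proj"),
    ("mlp.dense_4h_to_h", "mlp.down_proj"),
]

def translate_name(hf_name: str) -> str:
    """Apply each rewrite in order; names without a match pass through unchanged."""
    for old, new in RENAMES:
        hf_name = hf_name.replace(old, new)
    return hf_name

example = translate_name("transformer.encoder.layers.3.self_attention.query_key_value.weight")
```

Only the name is translated here. Slicing the fused tensors into q/k/v or gate/up shards is a separate step that this sketch does not attempt.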
|
||||||
207
nanovllm/models/qwen2.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.distributed as dist
|
||||||
|
from transformers import Qwen2Config
|
||||||
|
|
||||||
|
from nanovllm.layers.activation import SiluAndMul
|
||||||
|
from nanovllm.layers.attention import Attention
|
||||||
|
from nanovllm.layers.layernorm import RMSNorm
|
||||||
|
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
|
||||||
|
from nanovllm.layers.rotary_embedding import get_rope
|
||||||
|
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
|
||||||
|
from nanovllm.models.registry import register_model
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2Attention(nn.Module):
|
||||||
|
"""Qwen2/2.5 Attention without QK norm (unlike Qwen3)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
max_position: int = 4096 * 32,
|
||||||
|
head_dim: int | None = None,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
rope_scaling: dict | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
tp_size = dist.get_world_size()
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = num_kv_heads
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||||
|
self.head_dim = head_dim or hidden_size // self.total_num_heads
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim ** -0.5
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
self.head_dim,
|
||||||
|
self.total_num_heads,
|
||||||
|
self.total_num_kv_heads,
|
||||||
|
bias=True, # Qwen2/2.5 always uses bias
|
||||||
|
)
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
self.head_dim,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
max_position=max_position,
|
||||||
|
base=rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
)
|
||||||
|
self.attn = Attention(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
self.num_kv_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
q = q.view(-1, self.num_heads, self.head_dim)
|
||||||
|
k = k.view(-1, self.num_kv_heads, self.head_dim)
|
||||||
|
v = v.view(-1, self.num_kv_heads, self.head_dim)
|
||||||
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
|
o = self.attn(q, k, v)
|
||||||
|
output = self.o_proj(o.flatten(1, -1))
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2MLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
hidden_act: str,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
[intermediate_size] * 2,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.down_proj = RowParallelLinear(
|
||||||
|
intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
assert hidden_act == "silu"
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate_up = self.gate_up_proj(x)
|
||||||
|
x = self.act_fn(gate_up)
|
||||||
|
x = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2DecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Qwen2Config,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.self_attn = Qwen2Attention(
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
num_kv_heads=config.num_key_value_heads,
|
||||||
|
max_position=config.max_position_embeddings,
|
||||||
|
head_dim=getattr(config, 'head_dim', None),
|
||||||
|
rope_theta=getattr(config, "rope_theta", 1000000),
|
||||||
|
rope_scaling=getattr(config, "rope_scaling", None),
|
||||||
|
)
|
||||||
|
self.mlp = Qwen2MLP(
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
|
)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: torch.Tensor | None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if residual is None:
|
||||||
|
hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
|
||||||
|
else:
|
||||||
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
|
hidden_states = self.self_attn(positions, hidden_states)
|
||||||
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2Model(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Qwen2Config,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
|
||||||
|
self.layers = nn.ModuleList([Qwen2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
residual = None
|
||||||
|
for layer in self.layers:
|
||||||
|
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
@register_model("Qwen2ForCausalLM")
|
||||||
|
class Qwen2ForCausalLM(nn.Module):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"q_proj": ("qkv_proj", "q"),
|
||||||
|
"k_proj": ("qkv_proj", "k"),
|
||||||
|
"v_proj": ("qkv_proj", "v"),
|
||||||
|
"gate_proj": ("gate_up_proj", 0),
|
||||||
|
"up_proj": ("gate_up_proj", 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Qwen2Config
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.model = Qwen2Model(config)
|
||||||
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
|
if config.tie_word_embeddings:
|
||||||
|
self.lm_head.weight.data = self.model.embed_tokens.weight.data
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.model(input_ids, positions)
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.lm_head(hidden_states)
|
||||||
@@ -187,7 +187,7 @@ class Qwen3Model(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@register_model("Qwen3ForCausalLM", "Qwen2ForCausalLM")
|
@register_model("Qwen3ForCausalLM")
|
||||||
class Qwen3ForCausalLM(nn.Module):
|
class Qwen3ForCausalLM(nn.Module):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"q_proj": ("qkv_proj", "q"),
|
"q_proj": ("qkv_proj", "q"),
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from nanovllm.ops.chunked_attention import (
|
|||||||
|
|
||||||
from nanovllm.ops.xattn import (
|
from nanovllm.ops.xattn import (
|
||||||
xattn_estimate,
|
xattn_estimate,
|
||||||
|
xattn_estimate_chunked,
|
||||||
flat_group_gemm_fuse_reshape,
|
flat_group_gemm_fuse_reshape,
|
||||||
softmax_fuse_block_sum,
|
softmax_fuse_block_sum,
|
||||||
find_blocks_chunked,
|
find_blocks_chunked,
|
||||||
@@ -28,6 +29,7 @@ __all__ = [
|
|||||||
"ChunkedPrefillState",
|
"ChunkedPrefillState",
|
||||||
# xattn
|
# xattn
|
||||||
"xattn_estimate",
|
"xattn_estimate",
|
||||||
|
"xattn_estimate_chunked",
|
||||||
"flat_group_gemm_fuse_reshape",
|
"flat_group_gemm_fuse_reshape",
|
||||||
"softmax_fuse_block_sum",
|
"softmax_fuse_block_sum",
|
||||||
"find_blocks_chunked",
|
"find_blocks_chunked",
|
||||||
|
|||||||
@@ -414,6 +414,90 @@ def merge_attention_outputs(
|
|||||||
return o_merged, lse_merged
|
return o_merged, lse_merged
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# FlashInfer-based implementations (recommended for merge only)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# LSE conversion constants: FlashInfer uses log2, flash_attn uses ln
|
||||||
|
_LOG2_E = 1.4426950408889634 # math.log2(math.e) - ln -> log2
|
||||||
|
_LN_2 = 0.6931471805599453 # math.log(2) - log2 -> ln
|
||||||
|
|
||||||
|
# Check FlashInfer availability (only for merge_state, not attention kernel)
|
||||||
|
try:
|
||||||
|
from flashinfer.cascade import merge_state, merge_state_in_place
|
||||||
|
FLASHINFER_MERGE_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
FLASHINFER_MERGE_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attn_with_lse_flashinfer(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
softmax_scale: Optional[float] = None,
|
||||||
|
causal: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Flash attention that returns output and LSE.
|
||||||
|
|
||||||
|
Uses flash_attn library (FlashInfer attention has JIT compatibility issues).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [batch, seqlen_q, nheads_q, headdim]
|
||||||
|
k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
|
||||||
|
v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
|
||||||
|
softmax_scale: Scaling factor (default: 1/sqrt(headdim))
|
||||||
|
causal: Whether to apply causal masking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
out: Output tensor [batch, seqlen_q, nheads_q, headdim]
|
||||||
|
lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q] (ln format)
|
||||||
|
"""
|
||||||
|
# Use flash_attn directly (FlashInfer attention JIT has CUDA version issues)
|
||||||
|
return flash_attn_with_lse(q, k, v, softmax_scale, causal)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_attention_outputs_flashinfer(
|
||||||
|
o1: torch.Tensor,
|
||||||
|
lse1: torch.Tensor,
|
||||||
|
o2: torch.Tensor,
|
||||||
|
lse2: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Merge two attention outputs using FlashInfer's optimized kernel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
o1: First output [batch, seqlen_q, nheads, headdim]
|
||||||
|
lse1: First LSE [batch, nheads, seqlen_q] (ln format)
|
||||||
|
o2: Second output [batch, seqlen_q, nheads, headdim]
|
||||||
|
lse2: Second LSE [batch, nheads, seqlen_q] (ln format)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
o_merged: Merged output [batch, seqlen_q, nheads, headdim]
|
||||||
|
lse_merged: Merged LSE [batch, nheads, seqlen_q] (ln format)
|
||||||
|
"""
|
||||||
|
if not FLASHINFER_MERGE_AVAILABLE:
|
||||||
|
# Fallback to Triton implementation
|
||||||
|
return merge_attention_outputs(o1, lse1, o2, lse2)
|
||||||
|
|
||||||
|
# Convert to FlashInfer format (assumes batch == 1: squeeze(0) drops the batch axis)
|
||||||
|
# o: [batch, seq, heads, dim] -> [seq, heads, dim]
|
||||||
|
# lse: [batch, heads, seq] -> [seq, heads] (convert ln -> log2)
|
||||||
|
v_a = o1.squeeze(0).contiguous()
|
||||||
|
s_a = (lse1.squeeze(0).transpose(0, 1).contiguous().float() * _LOG2_E)
|
||||||
|
v_b = o2.squeeze(0).contiguous()
|
||||||
|
s_b = (lse2.squeeze(0).transpose(0, 1).contiguous().float() * _LOG2_E)
|
||||||
|
|
||||||
|
# FlashInfer merge
|
||||||
|
v_merged, s_merged = merge_state(v_a, s_a, v_b, s_b)
|
||||||
|
|
||||||
|
# Convert back to flash_attn format
|
||||||
|
o_merged = v_merged.unsqueeze(0) # [1, seq, heads, dim]
|
||||||
|
lse_merged = (s_merged * _LN_2).transpose(0, 1).unsqueeze(0) # [1, heads, seq]
|
||||||
|
|
||||||
|
return o_merged, lse_merged
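Reviewer note: the function above rescales LSE values from natural log to log2 before the merge and back afterwards. The sketch below is plain Python on a single query row and does not call FlashInfer. It shows the standard LSE-weighted merge in both bases and checks that both agree with attention over the full key set. All function names are hypothetical.

```python
import math

LOG2_E = math.log2(math.e)  # ln -> log2
LN_2 = math.log(2.0)        # log2 -> ln

def attend(scores, values):
    """Softmax-weighted average of `values` plus the natural-log LSE of `scores`."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
    return out, m + math.log(total)

def merge_ln(o1, lse1, o2, lse2):
    """Merge two partial outputs whose LSEs are in natural log."""
    m = max(lse1, lse2)
    w1, w2 = math.exp(lse1 - m), math.exp(lse2 - m)
    out = [(a * w1 + b * w2) / (w1 + w2) for a, b in zip(o1, o2)]
    return out, m + math.log(w1 + w2)

def merge_log2(o1, s1, o2, s2):
    """Same merge with LSEs expressed in log2."""
    m = max(s1, s2)
    w1, w2 = 2.0 ** (s1 - m), 2.0 ** (s2 - m)
    out = [(a * w1 + b * w2) / (w1 + w2) for a, b in zip(o1, o2)]
    return out, m + math.log2(w1 + w2)

scores = [0.1, 1.2, -0.3, 2.0]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]
o1, lse1 = attend(scores[:2], values[:2])
o2, lse2 = attend(scores[2:], values[2:])
merged_o, merged_s = merge_log2(o1, lse1 * LOG2_E, o2, lse2 * LOG2_E)
merged_lse = merged_s * LN_2  # back to natural log, as in the diff
```

Multiplying by `LOG2_E` on the way in and by `LN_2` on the way out leaves the merge weights unchanged, because `2 ** (x * log2(e))` equals `exp(x)`.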
|
||||||
|
|
||||||
|
|
||||||
def chunked_attention_varlen(
|
def chunked_attention_varlen(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
|||||||
@@ -218,6 +218,209 @@ def softmax_fuse_block_sum_kernel_non_causal(
|
|||||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# KV Chunking Support Kernels
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def softmax_partial_stats_kernel(
|
||||||
|
In,
|
||||||
|
M_out, # max per row
|
||||||
|
L_out, # sum per row (normalized by M_out)
|
||||||
|
scale,
|
||||||
|
input_stride_0,
|
||||||
|
input_stride_1,
|
||||||
|
input_stride_2,
|
||||||
|
stats_stride_0,
|
||||||
|
stats_stride_1,
|
||||||
|
k_len,
|
||||||
|
chunk_start, # Q start position (for causal)
|
||||||
|
kv_offset, # KV chunk offset (for causal)
|
||||||
|
segment_size: tl.constexpr,
|
||||||
|
block_size: tl.constexpr,
|
||||||
|
is_causal: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Compute partial softmax statistics for a KV chunk.
|
||||||
|
|
||||||
|
For each query row, computes:
|
||||||
|
- m: max value in this chunk
|
||||||
|
- l: sum of exp2(x - m) in this chunk (the kernel uses base-2 exponentials)
|
||||||
|
|
||||||
|
These can be merged across chunks using online softmax formula.
|
||||||
|
|
||||||
|
Input shape: [batch, heads, q_len, k_chunk_len]
|
||||||
|
Output shapes: M[batch, heads, q_len], L[batch, heads, q_len]
|
||||||
|
"""
|
||||||
|
block_id = tl.program_id(0)
|
||||||
|
head_id = tl.program_id(1)
|
||||||
|
batch_id = tl.program_id(2)
|
||||||
|
|
||||||
|
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
|
||||||
|
offs_k = tl.arange(0, segment_size)
|
||||||
|
|
||||||
|
num_iters = k_len // segment_size
|
||||||
|
|
||||||
|
# For causal: compute boundary
|
||||||
|
if is_causal:
|
||||||
|
# causal boundary: Q position where this KV chunk starts to be valid
|
||||||
|
# Q[i] can attend K[j] if i >= j
|
||||||
|
# For KV chunk at kv_offset, Q[i] can attend if i >= kv_offset
|
||||||
|
num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1 - kv_offset) // segment_size
|
||||||
|
num_iters_before_causal = tl.minimum(num_iters_before_causal, num_iters)
|
||||||
|
num_iters_before_causal = tl.maximum(num_iters_before_causal, 0)
|
||||||
|
else:
|
||||||
|
num_iters_before_causal = num_iters
|
||||||
|
|
||||||
|
# Online softmax state
|
||||||
|
m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")
|
||||||
|
l_i = tl.zeros([block_size], dtype=tl.float32)
|
||||||
|
|
||||||
|
# Input pointer
|
||||||
|
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
|
||||||
|
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
|
||||||
|
|
||||||
|
# Compute max and sum (before causal boundary)
|
||||||
|
for iter in range(0, num_iters_before_causal):
|
||||||
|
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||||
|
m_local = tl.max(X, 1)
|
||||||
|
m_new = tl.maximum(m_i, m_local)
|
||||||
|
alpha = tl.math.exp2(m_i - m_new)
|
||||||
|
|
||||||
|
X = X - m_new[:, None]
|
||||||
|
l_local = tl.sum(tl.math.exp2(X), 1)
|
||||||
|
l_i = l_i * alpha + l_local
|
||||||
|
|
||||||
|
m_i = m_new
|
||||||
|
|
||||||
|
# Handle causal boundary
|
||||||
|
if is_causal:
|
||||||
|
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
|
||||||
|
if iter < num_iters:
|
||||||
|
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||||
|
# causal mask: Q[i] >= K[j] + kv_offset
|
||||||
|
mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size + kv_offset)
|
||||||
|
X = tl.where(mask, X, -1.0e6)
|
||||||
|
m_local = tl.max(X, 1)
|
||||||
|
m_new = tl.maximum(m_i, m_local)
|
||||||
|
alpha = tl.math.exp2(m_i - m_new)
|
||||||
|
|
||||||
|
X = X - m_new[:, None]
|
||||||
|
l_local = tl.sum(tl.math.exp2(X), 1)
|
||||||
|
l_i = l_i * alpha + l_local
|
||||||
|
|
||||||
|
m_i = m_new
|
||||||
|
|
||||||
|
# Output pointers
|
||||||
|
m_ptr = M_out + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
|
||||||
|
l_ptr = L_out + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
|
||||||
|
|
||||||
|
offs = tl.arange(0, block_size)
|
||||||
|
tl.store(m_ptr + offs, m_i.to(M_out.type.element_ty))
|
||||||
|
tl.store(l_ptr + offs, l_i.to(L_out.type.element_ty))
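Reviewer note: the kernel's docstring says the per-chunk `(m, l)` pairs can be merged with the online softmax formula, but the merge itself is not in this hunk. The sketch below works on a single row in plain Python, using base-2 exponentials to match the `exp2` calls in the kernel. The function names are hypothetical, and all inputs are assumed finite, so fully masked rows are not handled.

```python
def partial_stats(xs):
    """Max and base-2 exponential sum for one chunk of already-scaled scores."""
    m = max(xs)
    l = sum(2.0 ** (x - m) for x in xs)
    return m, l

def merge_stats(a, b):
    """Combine two (m, l) pairs into the stats of the concatenated chunks."""
    m_a, l_a = a
    m_b, l_b = b
    m = max(m_a, m_b)
    # Rescale each partial sum to the shared maximum before adding.
    l = l_a * 2.0 ** (m_a - m) + l_b * 2.0 ** (m_b - m)
    return m, l

row = [0.5, -1.0, 3.0, 2.5, 0.0, 1.0]
merged = merge_stats(partial_stats(row[:3]), partial_stats(row[3:]))
```

The merged pair matches `partial_stats(row)` up to floating-point rounding. This is why a later normalisation pass can use global stats without having seen every chunk at once.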
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def softmax_normalize_block_sum_kernel(
|
||||||
|
In,
|
||||||
|
Out,
|
||||||
|
M_global, # global max per row
|
||||||
|
L_global, # global sum per row
|
||||||
|
scale,
|
||||||
|
input_stride_0,
|
||||||
|
input_stride_1,
|
||||||
|
input_stride_2,
|
||||||
|
output_stride_0,
|
||||||
|
output_stride_1,
|
||||||
|
output_stride_2,
|
||||||
|
stats_stride_0,
|
||||||
|
stats_stride_1,
|
||||||
|
real_q_len,
|
||||||
|
k_len,
|
||||||
|
chunk_start,
|
||||||
|
kv_offset, # KV chunk offset (for causal)
|
||||||
|
segment_size: tl.constexpr,
|
||||||
|
block_size: tl.constexpr,
|
||||||
|
is_causal: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Normalize with global stats and compute block sums for a KV chunk.
|
||||||
|
|
||||||
|
Uses pre-computed global m and l to correctly normalize softmax
|
||||||
|
across all KV chunks.
|
||||||
|
|
||||||
|
Input shape: [batch, heads, q_len, k_chunk_len]
|
||||||
|
Output shape: [batch, heads, q_blocks, k_chunk_blocks]
|
||||||
|
"""
|
||||||
|
block_id = tl.program_id(0)
|
||||||
|
head_id = tl.program_id(1)
|
||||||
|
batch_id = tl.program_id(2)
|
||||||
|
|
||||||
|
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
|
||||||
|
offs_k = tl.arange(0, segment_size)
|
||||||
|
|
||||||
|
num_iters = k_len // segment_size
|
||||||
|
|
||||||
|
# For causal: compute boundary
|
||||||
|
if is_causal:
|
||||||
|
num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1 - kv_offset) // segment_size
|
||||||
|
num_iters_before_causal = tl.minimum(num_iters_before_causal, num_iters)
|
||||||
|
num_iters_before_causal = tl.maximum(num_iters_before_causal, 0)
|
||||||
|
else:
|
||||||
|
num_iters_before_causal = num_iters
|
||||||
|
|
||||||
|
# Load global stats
|
||||||
|
m_ptr = M_global + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
|
||||||
|
l_ptr = L_global + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
|
||||||
|
|
||||||
|
offs = tl.arange(0, block_size)
|
||||||
|
m_global = tl.load(m_ptr + offs).to(tl.float32)
|
||||||
|
l_global = tl.load(l_ptr + offs).to(tl.float32)
|
||||||
|
# Handle l_global = 0 (when all positions are masked)
|
||||||
|
l_global_safe = tl.where(l_global > 0, l_global, 1.0)
|
||||||
|
l_global_inv = 1.0 / l_global_safe
|
||||||
|
|
||||||
|
# Input pointer
|
||||||
|
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
|
||||||
|
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
|
||||||
|
|
||||||
|
# Output pointer
|
||||||
|
output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
|
||||||
|
output_ptr = output_ptr + tl.arange(0, segment_size // block_size)
|
||||||
|
|
||||||
|
sum_mask = offs_q[:, None] < real_q_len
|
||||||
|
|
||||||
|
# Normalize and compute block sums (before causal boundary)
|
||||||
|
for iter in range(0, num_iters_before_causal):
|
||||||
|
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||||
|
X = tl.exp2(X - m_global[:, None]) * l_global_inv[:, None]
|
||||||
|
X = tl.where(sum_mask, X, 0)
|
||||||
|
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
||||||
|
X = tl.sum(X, 2)
|
||||||
|
X = tl.sum(X, 0)
|
||||||
|
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||||
|
|
||||||
|
# Handle causal boundary
|
||||||
|
if is_causal:
|
||||||
|
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
|
||||||
|
if iter < num_iters:
|
||||||
|
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||||
|
# causal mask: Q[i] >= K[j] + kv_offset
|
||||||
|
causal_mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size + kv_offset)
|
||||||
|
X = tl.where(causal_mask, X, -1.0e6)
|
||||||
|
X = tl.exp2(X - m_global[:, None]) * l_global_inv[:, None]
|
||||||
|
X = tl.where(sum_mask, X, 0)
|
||||||
|
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
||||||
|
X = tl.sum(X, 2)
|
||||||
|
X = tl.sum(X, 0)
|
||||||
|
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||||
|
|
||||||
|
# Zero out future blocks
|
||||||
|
for iter in range(num_iters_before_causal + 1, num_iters):
|
||||||
|
X = tl.zeros([segment_size // block_size], dtype=tl.float32)
|
||||||
|
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
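The causal boundary computed in the kernel above splits the KV segments into three groups: fully visible, one straddling the diagonal, and fully masked. The brute-force check below is a sketch under stated assumptions: `causal_boundary` and `classify_segment` are hypothetical helpers, the sizes are made up, and `chunk_start` and `kv_offset` are taken to be multiples of `segment_size` (with `segment_size` a multiple of `block_size`), which is what makes the segments before the boundary fully visible.

```python
def causal_boundary(chunk_start, block_id, block_size, kv_offset, segment_size, num_iters):
    """Index of the segment that straddles the causal diagonal, clamped to [0, num_iters]."""
    last_q = chunk_start + (block_id + 1) * block_size - 1
    n = (last_q - kv_offset) // segment_size
    return max(min(n, num_iters), 0)


def classify_segment(chunk_start, block_id, block_size, kv_offset, segment_size, it):
    """Brute-force label for one segment: 'full', 'partial' or 'empty' visibility."""
    q_lo = chunk_start + block_id * block_size
    visible = sum(
        1
        for q in range(q_lo, q_lo + block_size)
        for j in range(segment_size)
        if q >= j + it * segment_size + kv_offset  # same test as the kernel's causal mask
    )
    if visible == block_size * segment_size:
        return "full"
    return "empty" if visible == 0 else "partial"


# Hypothetical sizes; chunk_start and kv_offset are multiples of segment_size.
cfg = dict(chunk_start=16, block_size=4, kv_offset=8, segment_size=8)
num_iters = 4
boundaries = []
labels_per_block = []
for block_id in range(4):
    boundaries.append(causal_boundary(block_id=block_id, num_iters=num_iters, **cfg))
    labels_per_block.append(
        [classify_segment(block_id=block_id, it=it, **cfg) for it in range(num_iters)]
    )
```

For these sizes every segment before the boundary is labelled `full`, the boundary segment is `partial`, and everything after it is `empty`, which is why the kernel can skip masking in the first loop and write zeros in the last one.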
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def flat_group_gemm_fuse_reshape_kernel(
|
def flat_group_gemm_fuse_reshape_kernel(
|
||||||
Q, K, Out,
|
Q, K, Out,
|
||||||
@@ -380,6 +583,194 @@ def softmax_fuse_block_sum(
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def softmax_compute_partial_stats(
|
||||||
|
attn_weights_slice: torch.Tensor,
|
||||||
|
reshaped_block_size: int,
|
||||||
|
segment_size: int,
|
||||||
|
scale: float,
|
||||||
|
chunk_start: int = 0,
|
||||||
|
kv_offset: int = 0,
|
||||||
|
is_causal: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Compute partial softmax statistics for a KV chunk.
|
||||||
|
|
||||||
|
This is the first step for KV-chunked softmax computation.
|
||||||
|
For each query row, computes:
|
||||||
|
- m: max of the scaled scores (x * scale) in this chunk
|
||||||
|
- l: sum of exp2(x * scale - m) in this chunk (base 2, matching the kernel)
|
||||||
|
|
||||||
|
These partial stats can be merged across KV chunks using
|
||||||
|
`merge_softmax_stats()`, then used with `softmax_normalize_and_block_sum()`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attn_weights_slice: Raw attention scores [batch, heads, q_len, k_chunk_len]
|
||||||
|
reshaped_block_size: Block size in reshaped space
|
||||||
|
segment_size: Processing segment size
|
||||||
|
scale: Softmax scale factor
|
||||||
|
chunk_start: Q chunk start position (in reshaped space)
|
||||||
|
kv_offset: KV chunk offset (in reshaped space, for causal masking)
|
||||||
|
is_causal: Whether to apply causal masking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (m, l) where:
|
||||||
|
- m: [batch, heads, q_len] max values per row
|
||||||
|
- l: [batch, heads, q_len] partial sums per row
|
||||||
|
"""
|
||||||
|
batch_size, num_heads, q_len, k_len = attn_weights_slice.shape
|
||||||
|
|
||||||
|
assert q_len % reshaped_block_size == 0
|
||||||
|
assert k_len % segment_size == 0
|
||||||
|
assert attn_weights_slice.stride(-1) == 1
|
||||||
|
|
||||||
|
m_out = torch.empty(
|
||||||
|
(batch_size, num_heads, q_len),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=attn_weights_slice.device
|
||||||
|
)
|
||||||
|
l_out = torch.empty(
|
||||||
|
(batch_size, num_heads, q_len),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=attn_weights_slice.device
|
||||||
|
)
|
||||||
|
|
||||||
|
grid = (q_len // reshaped_block_size, num_heads, batch_size)
|
||||||
|
|
||||||
|
softmax_partial_stats_kernel[grid](
|
||||||
|
attn_weights_slice,
|
||||||
|
m_out,
|
||||||
|
l_out,
|
||||||
|
scale,
|
||||||
|
attn_weights_slice.stride(0),
|
||||||
|
attn_weights_slice.stride(1),
|
||||||
|
attn_weights_slice.stride(2),
|
||||||
|
m_out.stride(0),
|
||||||
|
m_out.stride(1),
|
||||||
|
k_len,
|
||||||
|
chunk_start,
|
||||||
|
kv_offset,
|
||||||
|
segment_size,
|
||||||
|
reshaped_block_size,
|
||||||
|
is_causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
return m_out, l_out
|
||||||
|
|
||||||
|
|
||||||
|
def merge_softmax_stats(
|
||||||
|
m_chunks: list,
|
||||||
|
l_chunks: list,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Merge partial softmax statistics from multiple KV chunks.
|
||||||
|
|
||||||
|
Uses the online softmax merging formula:
|
||||||
|
m_new = max(m1, m2)
|
||||||
|
l_new = l1 * exp2(m1 - m_new) + l2 * exp2(m2 - m_new)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
m_chunks: List of max tensors [batch, heads, q_len] from each chunk
|
||||||
|
l_chunks: List of sum tensors [batch, heads, q_len] from each chunk
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (m_global, l_global) with same shape as inputs
|
||||||
|
"""
|
||||||
|
assert len(m_chunks) == len(l_chunks)
|
||||||
|
assert len(m_chunks) > 0
|
||||||
|
|
||||||
|
# Stats are already in log2 scale (the kernels use exp2 with log2(e) folded into `scale`), so merge with powers of 2
|
||||||
|
|
||||||
|
m_global = m_chunks[0].clone()
|
||||||
|
l_global = l_chunks[0].clone()
|
||||||
|
|
||||||
|
for i in range(1, len(m_chunks)):
|
||||||
|
m_chunk = m_chunks[i]
|
||||||
|
l_chunk = l_chunks[i]
|
||||||
|
|
||||||
|
m_new = torch.maximum(m_global, m_chunk)
|
||||||
|
# exp2(m - m_new) = 2^(m - m_new)
|
||||||
|
l_global = l_global * torch.pow(2.0, m_global - m_new) + l_chunk * torch.pow(2.0, m_chunk - m_new)
|
||||||
|
m_global = m_new
|
||||||
|
|
||||||
|
return m_global, l_global
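A small worked example of the merge above, written in plain Python for one query row. It is a sketch only: `partial_stats` and `merge_stats` are hypothetical stand-ins for the tensor versions in the diff, and the scores are made up. It checks that merging per-chunk (m, l) pairs gives the same result as computing the statistics over the whole row at once.

```python
import math


def partial_stats(scores):
    """(m, l) for one chunk: m = max, l = sum of 2**(x - m)."""
    m = max(scores)
    l = sum(2.0 ** (x - m) for x in scores)
    return m, l


def merge_stats(stats):
    """Fold per-chunk (m, l) pairs into one global pair, in base 2."""
    m_global, l_global = stats[0]
    for m_chunk, l_chunk in stats[1:]:
        m_new = max(m_global, m_chunk)
        l_global = (
            l_global * 2.0 ** (m_global - m_new)
            + l_chunk * 2.0 ** (m_chunk - m_new)
        )
        m_global = m_new
    return m_global, l_global


row = [0.5, -1.0, 3.0, 2.0, 4.5, -0.25]        # already-scaled scores (made up)
chunks = [row[0:2], row[2:4], row[4:6]]        # three KV chunks

m_merged, l_merged = merge_stats([partial_stats(c) for c in chunks])
m_direct, l_direct = partial_stats(row)
stats_agree = m_merged == m_direct and math.isclose(l_merged, l_direct, rel_tol=1e-12)
```

Because the merge is associative, the chunks can be folded in any order, which is what allows the partial statistics to be computed independently per KV chunk.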
|
||||||
|
|
||||||
|
|
||||||
|
def softmax_normalize_and_block_sum(
|
||||||
|
attn_weights_slice: torch.Tensor,
|
||||||
|
m_global: torch.Tensor,
|
||||||
|
l_global: torch.Tensor,
|
||||||
|
reshaped_block_size: int,
|
||||||
|
segment_size: int,
|
||||||
|
chunk_start: int,
|
||||||
|
real_q_len: int,
|
||||||
|
scale: float,
|
||||||
|
kv_offset: int = 0,
|
||||||
|
is_causal: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Normalize with global stats and compute block sums for a KV chunk.
|
||||||
|
|
||||||
|
This is the second step for KV-chunked softmax computation.
|
||||||
|
Uses pre-computed global m and l (from `merge_softmax_stats()`)
|
||||||
|
to correctly normalize softmax values and compute block sums.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attn_weights_slice: Raw attention scores [batch, heads, q_len, k_chunk_len]
|
||||||
|
m_global: Global max values [batch, heads, q_len]
|
||||||
|
l_global: Global sum values [batch, heads, q_len]
|
||||||
|
reshaped_block_size: Block size in reshaped space
|
||||||
|
segment_size: Processing segment size
|
||||||
|
chunk_start: Q chunk start position (in reshaped space, for masking)
|
||||||
|
real_q_len: End of the valid (unpadded) Q range in reshaped space, i.e. chunk_start + valid Q length; rows at or beyond it are zeroed
|
||||||
|
scale: Softmax scale factor
|
||||||
|
kv_offset: KV chunk offset (in reshaped space, for causal masking)
|
||||||
|
is_causal: Whether to apply causal masking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Block-level attention sums [batch, heads, q_blocks, k_chunk_blocks]
|
||||||
|
"""
|
||||||
|
batch_size, num_heads, q_len, k_len = attn_weights_slice.shape
|
||||||
|
|
||||||
|
assert q_len % reshaped_block_size == 0
|
||||||
|
assert k_len % segment_size == 0
|
||||||
|
assert segment_size % reshaped_block_size == 0
|
||||||
|
assert attn_weights_slice.stride(-1) == 1
|
||||||
|
|
||||||
|
output = torch.empty(
|
||||||
|
(batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size),
|
||||||
|
dtype=attn_weights_slice.dtype,
|
||||||
|
device=attn_weights_slice.device
|
||||||
|
)
|
||||||
|
|
||||||
|
grid = (q_len // reshaped_block_size, num_heads, batch_size)
|
||||||
|
|
||||||
|
softmax_normalize_block_sum_kernel[grid](
|
||||||
|
attn_weights_slice,
|
||||||
|
output,
|
||||||
|
m_global,
|
||||||
|
l_global,
|
||||||
|
scale,
|
||||||
|
attn_weights_slice.stride(0),
|
||||||
|
attn_weights_slice.stride(1),
|
||||||
|
attn_weights_slice.stride(2),
|
||||||
|
output.stride(0),
|
||||||
|
output.stride(1),
|
||||||
|
output.stride(2),
|
||||||
|
m_global.stride(0),
|
||||||
|
m_global.stride(1),
|
||||||
|
real_q_len,
|
||||||
|
k_len,
|
||||||
|
chunk_start,
|
||||||
|
kv_offset,
|
||||||
|
segment_size,
|
||||||
|
reshaped_block_size,
|
||||||
|
is_causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
def flat_group_gemm_fuse_reshape(
|
def flat_group_gemm_fuse_reshape(
|
||||||
query_states: torch.Tensor,
|
query_states: torch.Tensor,
|
||||||
key_states: torch.Tensor,
|
key_states: torch.Tensor,
|
||||||
@@ -419,7 +810,9 @@ def flat_group_gemm_fuse_reshape(
|
|||||||
assert key_states.shape[1] == num_heads
|
assert key_states.shape[1] == num_heads
|
||||||
assert key_states.shape[3] == head_dim
|
assert key_states.shape[3] == head_dim
|
||||||
|
|
||||||
output = torch.empty(
|
# Use zeros instead of empty to handle causal early-exit in kernel
|
||||||
|
# (some blocks may not be written due to causal mask optimization)
|
||||||
|
output = torch.zeros(
|
||||||
(batch_size, num_heads, q_len // stride, kv_len // stride),
|
(batch_size, num_heads, q_len // stride, kv_len // stride),
|
||||||
dtype=query_states.dtype,
|
dtype=query_states.dtype,
|
||||||
device=query_states.device
|
device=query_states.device
|
||||||
@@ -950,3 +1343,239 @@ def compute_sparsity(mask: torch.Tensor, causal: bool = True) -> float:
|
|||||||
selected_blocks = mask.sum().item()
|
selected_blocks = mask.sum().item()
|
||||||
|
|
||||||
return 1.0 - (selected_blocks / total_blocks)
|
return 1.0 - (selected_blocks / total_blocks)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Chunked Estimation Function (for Chunked Prefill)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def xattn_estimate_chunked(
|
||||||
|
query_states: torch.Tensor,
|
||||||
|
key_states: torch.Tensor,
|
||||||
|
q_start_pos: int,
|
||||||
|
block_size: int = 128,
|
||||||
|
stride: int = 8,
|
||||||
|
norm: float = 1.0,
|
||||||
|
threshold: float = 0.9,
|
||||||
|
chunk_size: int = 16384,
|
||||||
|
use_triton: bool = True,
|
||||||
|
causal: bool = True,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Estimate block importance for XAttention in chunked prefill mode.
|
||||||
|
|
||||||
|
This function is designed for chunked prefill scenarios where:
|
||||||
|
- Q is processed in chunks while K accumulates across chunks
|
||||||
|
- q_start_pos indicates the position of the current Q chunk in the full sequence
|
||||||
|
- K length can be >= Q length (accumulated KV cache)
|
||||||
|
|
||||||
|
Ported from COMPASS project (compass/src/Xattn_chunked.py).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_states: Q tensor [batch, heads, q_chunk_len, head_dim] - current Q chunk
|
||||||
|
key_states: K tensor [batch, heads, k_len, head_dim] - accumulated K (k_len >= q_chunk_len)
|
||||||
|
q_start_pos: Start position of this Q chunk in the full sequence
|
||||||
|
block_size: Block size in tokens (typically 128 for BSA compatibility)
|
||||||
|
stride: Stride for Q/K reshape (typically 8)
|
||||||
|
norm: Normalization factor for attention scores
|
||||||
|
threshold: Cumulative attention threshold (0.0-1.0)
|
||||||
|
chunk_size: Processing chunk size for Triton kernel alignment
|
||||||
|
use_triton: Whether to use Triton kernels (requires SM 80+)
|
||||||
|
causal: Whether to apply causal masking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
attn_sums: Block-level attention scores [batch, heads, q_blocks, k_blocks]
|
||||||
|
simple_masks: Boolean mask for sparse attention [batch, heads, q_blocks, k_blocks]
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> # Chunk 0: Q[0:C] attends to K[0:C]
|
||||||
|
>>> attn_sums, mask = xattn_estimate_chunked(q_chunk0, k_chunk0, q_start_pos=0)
|
||||||
|
>>>
|
||||||
|
>>> # Chunk 1: Q[C:2C] attends to K[0:2C]
|
||||||
|
>>> attn_sums, mask = xattn_estimate_chunked(q_chunk1, k_accum, q_start_pos=C)
|
||||||
|
"""
|
||||||
|
batch_size, num_heads, q_len, head_dim = query_states.shape
|
||||||
|
_, _, k_len, _ = key_states.shape
|
||||||
|
|
||||||
|
# Store original lengths for valid region tracking
|
||||||
|
original_q_len = q_len
|
||||||
|
original_k_len = k_len
|
||||||
|
|
||||||
|
# Validate inputs
|
||||||
|
assert k_len >= q_len, f"K length ({k_len}) must be >= Q length ({q_len})"
|
||||||
|
assert q_start_pos + q_len <= k_len, f"Q end position ({q_start_pos + q_len}) exceeds K length ({k_len})"
|
||||||
|
|
||||||
|
# Calculate block counts
|
||||||
|
q_block_num = (q_len + block_size - 1) // block_size
|
||||||
|
k_block_num = (k_len + block_size - 1) // block_size
|
||||||
|
q_start_block = q_start_pos // block_size
|
||||||
|
|
||||||
|
# Check GPU capability for Triton
|
||||||
|
if use_triton:
|
||||||
|
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||||
|
if props.major < 8:
|
||||||
|
use_triton = False
|
||||||
|
|
||||||
|
# Pad Q and K for alignment
|
||||||
|
if use_triton:
|
||||||
|
# For Triton: pad to chunk_size alignment
|
||||||
|
padded_q_len = ((q_len + chunk_size - 1) // chunk_size) * chunk_size
|
||||||
|
padded_k_len = ((k_len + chunk_size - 1) // chunk_size) * chunk_size
|
||||||
|
else:
|
||||||
|
# For PyTorch fallback: pad to block_size alignment
|
||||||
|
padded_q_len = q_block_num * block_size
|
||||||
|
padded_k_len = k_block_num * block_size
|
||||||
|
|
||||||
|
q_pad = padded_q_len - q_len
|
||||||
|
k_pad = padded_k_len - k_len
|
||||||
|
|
||||||
|
if q_pad > 0:
|
||||||
|
query_states = F.pad(query_states, (0, 0, 0, q_pad), value=0)
|
||||||
|
if k_pad > 0:
|
||||||
|
key_states = F.pad(key_states, (0, 0, 0, k_pad), value=0)
|
||||||
|
|
||||||
|
# Reshape dimensions
|
||||||
|
reshaped_block_size = block_size // stride
|
||||||
|
reshaped_q_len = padded_q_len // stride
|
||||||
|
reshaped_k_len = padded_k_len // stride
|
||||||
|
|
||||||
|
# Calculate valid lengths in reshaped space (for masking padding)
|
||||||
|
valid_q_reshaped = (original_q_len + stride - 1) // stride
|
||||||
|
valid_k_reshaped = (original_k_len + stride - 1) // stride
|
||||||
|
|
||||||
|
if use_triton:
|
||||||
|
# Compute chunk boundaries in reshaped space
|
||||||
|
chunk_start = q_start_block * reshaped_block_size
|
||||||
|
chunk_end = chunk_start + reshaped_q_len # Padded end for computation
|
||||||
|
real_q_len = chunk_start + valid_q_reshaped # Valid end for masking padding
|
||||||
|
|
||||||
|
# Use Triton kernel for efficient computation
|
||||||
|
attn_weights = flat_group_gemm_fuse_reshape(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
stride,
|
||||||
|
chunk_start, # q_start in reshaped space
|
||||||
|
chunk_end, # q_end in reshaped space (padded)
|
||||||
|
is_causal=causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Softmax + block sum
|
||||||
|
# segment_size should match the standard xattn_estimate for consistency
|
||||||
|
attn_sum = softmax_fuse_block_sum(
|
||||||
|
attn_weights,
|
||||||
|
reshaped_block_size,
|
||||||
|
min(4096, reshaped_block_size),
|
||||||
|
chunk_start,
|
||||||
|
chunk_end,
|
||||||
|
real_q_len,
|
||||||
|
1.4426950408889634 / math.sqrt(head_dim) / stride / norm,
|
||||||
|
is_causal=causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract only the valid block region
|
||||||
|
attn_sum = attn_sum[:, :, :q_block_num, :k_block_num]
|
||||||
|
else:
|
||||||
|
# PyTorch fallback implementation
|
||||||
|
# Match Triton kernel exactly for consistency
|
||||||
|
#
|
||||||
|
# Triton uses:
|
||||||
|
# 1. exp2 (base-2 exponential) for softmax
|
||||||
|
# 2. scale factor includes log2(e) = 1.4426950408889634
|
||||||
|
# 3. causal mask: q_pos >= k_pos (for integer positions this is the same as q_pos + 1 > k_pos)
|
||||||
|
# 4. chunk_start for global Q position tracking
|
||||||
|
|
||||||
|
# Reshape K: interleave positions and concatenate head dims
|
||||||
|
reshaped_key = torch.cat(
|
||||||
|
[(key_states[:, :, k::stride, :]) for k in range(stride)], dim=-1
|
||||||
|
) # (B, H, k_len/stride, D*stride)
|
||||||
|
|
||||||
|
# Reshape Q (inverse mode)
|
||||||
|
reshaped_query = torch.cat(
|
||||||
|
[(query_states[:, :, (stride - 1 - q)::stride, :]) for q in range(stride)],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use same scale as Triton: includes log2(e) for exp2 compatibility
|
||||||
|
# Triton: scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
|
||||||
|
scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
|
||||||
|
|
||||||
|
# Convert to float32 for numerical stability (matching Triton)
|
||||||
|
reshaped_query_f32 = reshaped_query.to(torch.float32)
|
||||||
|
reshaped_key_f32 = reshaped_key.to(torch.float32)
|
||||||
|
|
||||||
|
# Compute attention weights: (B, H, q_len/stride, k_len/stride)
|
||||||
|
attn_weights = torch.matmul(
|
||||||
|
reshaped_query_f32, reshaped_key_f32.transpose(2, 3)
|
||||||
|
) * scale
|
||||||
|
|
||||||
|
# Apply causal mask (matching Triton's logic exactly)
|
||||||
|
if causal:
|
||||||
|
# Triton uses: offs_q = chunk_start + block_id * block_size + arange(0, block_size)
|
||||||
|
# chunk_start = q_start_block * reshaped_block_size
|
||||||
|
chunk_start = q_start_block * reshaped_block_size
|
||||||
|
|
||||||
|
# Create position indices in reshaped space
|
||||||
|
q_positions = torch.arange(reshaped_q_len, device=attn_weights.device) + chunk_start
|
||||||
|
k_positions = torch.arange(reshaped_k_len, device=attn_weights.device)
|
||||||
|
|
||||||
|
# Triton causal mask: q_pos >= k_pos
|
||||||
|
causal_mask = q_positions[:, None] >= k_positions[None, :] # (reshaped_q_len, reshaped_k_len)
|
||||||
|
|
||||||
|
# Apply causal mask: set future positions to -1e6 (matching Triton)
|
||||||
|
attn_weights = attn_weights.masked_fill(
|
||||||
|
~causal_mask.unsqueeze(0).unsqueeze(0), -1e6
|
||||||
|
)
|
||||||
|
|
||||||
|
# Softmax using exp2 (matching Triton exactly)
|
||||||
|
# Triton: X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
|
||||||
|
# All computation in float32
|
||||||
|
attn_max = attn_weights.max(dim=-1, keepdim=True).values
|
||||||
|
attn_weights_shifted = attn_weights - attn_max
|
||||||
|
attn_exp2 = torch.exp2(attn_weights_shifted)
|
||||||
|
attn_sum_exp2 = attn_exp2.sum(dim=-1, keepdim=True)
|
||||||
|
attn_weights = attn_exp2 / attn_sum_exp2
|
||||||
|
|
||||||
|
# Mask for valid Q positions (matching Triton's sum_mask)
|
||||||
|
# Triton: sum_mask = offs_q[:, None] < real_q_len
|
||||||
|
# real_q_len = chunk_start + valid_q_reshaped
|
||||||
|
chunk_start = q_start_block * reshaped_block_size
|
||||||
|
real_q_len = chunk_start + valid_q_reshaped
|
||||||
|
q_positions = torch.arange(reshaped_q_len, device=attn_weights.device) + chunk_start
|
||||||
|
valid_q_mask = q_positions < real_q_len # (reshaped_q_len,)
|
||||||
|
|
||||||
|
# Zero out invalid Q positions
|
||||||
|
attn_weights = attn_weights * valid_q_mask.view(1, 1, -1, 1).float()
|
||||||
|
|
||||||
|
# Aggregate to block level (keep in float32)
|
||||||
|
attn_sum = attn_weights.view(
|
||||||
|
batch_size,
|
||||||
|
num_heads,
|
||||||
|
q_block_num,
|
||||||
|
reshaped_block_size,
|
||||||
|
k_block_num,
|
||||||
|
reshaped_block_size,
|
||||||
|
).sum(dim=-1).sum(dim=-2)
|
||||||
|
|
||||||
|
# Convert back to input dtype for consistency
|
||||||
|
attn_sum = attn_sum.to(query_states.dtype)
|
||||||
|
|
||||||
|
# Find blocks that exceed threshold
|
||||||
|
simple_mask = find_blocks_chunked(
|
||||||
|
attn_sum,
|
||||||
|
q_start_block, # offset for causal mask in find_blocks_chunked
|
||||||
|
threshold,
|
||||||
|
None,
|
||||||
|
decoding=False,
|
||||||
|
mode="prefill",
|
||||||
|
causal=causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply causal constraint on block level
|
||||||
|
if causal:
|
||||||
|
# For block-level causal: Q block i can only attend to K blocks j where j <= q_start_block + i
|
||||||
|
for q_blk_idx in range(q_block_num):
|
||||||
|
q_blk_global = q_start_block + q_blk_idx
|
||||||
|
if q_blk_global + 1 < k_block_num:
|
||||||
|
simple_mask[:, :, q_blk_idx, q_blk_global + 1:] = False
|
||||||
|
|
||||||
|
return attn_sum, simple_mask
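The length bookkeeping in `xattn_estimate_chunked` is easier to follow with concrete numbers. The sketch below reproduces only the integer arithmetic of the Triton branch (padding to `chunk_size`); `chunked_layout` is a hypothetical helper and the input lengths are made up, with `q_start_pos` chosen to be block-aligned.

```python
def chunked_layout(q_len, k_len, q_start_pos, block_size=128, stride=8, chunk_size=16384):
    """Integer bookkeeping of the Triton branch, returned as a dict."""
    assert k_len >= q_len and q_start_pos + q_len <= k_len

    def ceil_to(n, multiple):
        return ((n + multiple - 1) // multiple) * multiple

    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size
    q_start_block = q_start_pos // block_size

    padded_q_len = ceil_to(q_len, chunk_size)
    padded_k_len = ceil_to(k_len, chunk_size)

    reshaped_block_size = block_size // stride
    reshaped_q_len = padded_q_len // stride
    valid_q_reshaped = (q_len + stride - 1) // stride

    chunk_start = q_start_block * reshaped_block_size
    return {
        "q_block_num": q_block_num,
        "k_block_num": k_block_num,
        "q_start_block": q_start_block,
        "padded_q_len": padded_q_len,
        "padded_k_len": padded_k_len,
        "reshaped_block_size": reshaped_block_size,
        "reshaped_q_len": reshaped_q_len,
        "chunk_start": chunk_start,
        "chunk_end": chunk_start + reshaped_q_len,       # padded end
        "real_q_len": chunk_start + valid_q_reshaped,    # valid end
    }


# Made-up example: a 1000-token Q chunk starting at token 4096, K holds 5096 tokens.
layout = chunked_layout(q_len=1000, k_len=5096, q_start_pos=4096)
```

Note that `chunk_start` is derived from `q_start_pos // block_size`, so a start position that is not a multiple of `block_size` is rounded down to the enclosing block.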
|
||||||
|
|||||||
327
nanovllm/utils/density_observer.py
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
"""
|
||||||
|
DensityObserver - an observer that collects sparse attention density statistics.
|
||||||
|
|
||||||
|
Two kinds of density are tracked:
|
||||||
|
1. Compute Density: based on the BSA block size (128)
|
||||||
|
- density = selected_bsa_blocks / total_causal_bsa_blocks
|
||||||
|
- GPU-only and Offload modes should agree
|
||||||
|
|
||||||
|
2. Communication Density: based on the CPU block size (e.g. 4096)
|
||||||
|
- comm_density = selected_cpu_blocks / total_cpu_blocks
|
||||||
|
- Offload mode only; because the granularity is coarser, it is always >= compute density
|
||||||
|
|
||||||
|
Where the statistics are recorded:
|
||||||
|
- GPU-only: xattn_bsa.py compute_prefill() - records compute density only
|
||||||
|
- Offload: xattn_bsa.py select_blocks() - records both densities
|
||||||
|
|
||||||
|
Density computation in Offload mode:
|
||||||
|
- Not a simple avg or min
|
||||||
|
- Instead sum(selected) / sum(total), which weights chunks of different sizes correctly
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Dict, Optional, Tuple
|
||||||
|
import torch
|
||||||
|
from nanovllm.utils.observer import Observer
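The module docstring states that Offload-mode density is sum(selected) / sum(total) rather than an average of per-chunk densities. A short numeric illustration follows; the (selected, total) counts are made up for the example and do not come from any real run.

```python
# Hypothetical (selected, total) block counts for three prefill chunks.
# Later chunks see more K blocks, so their totals are larger.
chunks = [(10, 10), (30, 100), (50, 500)]

per_chunk = [selected / total for selected, total in chunks]
avg_density = sum(per_chunk) / len(per_chunk)

weighted_density = sum(s for s, _ in chunks) / sum(t for _, t in chunks)
```

Here `avg_density` is (1.0 + 0.3 + 0.1) / 3, while `weighted_density` is 90 / 610; the plain average overstates density because the small, dense first chunk counts as much as the large, sparse last one.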
|
||||||
|
|
||||||
|
|
||||||
|
class DensityObserver(Observer):
|
||||||
|
"""
|
||||||
|
Sparse Attention Density Observer.
|
||||||
|
|
||||||
|
Records per-layer density, used to verify that GPU-only and Offload modes are consistent.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
DensityObserver.enable()
|
||||||
|
DensityObserver.complete_reset()
|
||||||
|
# ... run inference ...
|
||||||
|
DensityObserver.record(layer_id, mask, causal=True)
|
||||||
|
# Or use accumulation mode (offload):
|
||||||
|
DensityObserver.record_counts(layer_id, selected, total)
|
||||||
|
# ...
|
||||||
|
DensityObserver.print_summary()
|
||||||
|
"""
|
||||||
|
|
||||||
|
_enabled: bool = False # Disabled by default
|
||||||
|
|
||||||
|
# Per-layer compute density records (BSA block granularity)
|
||||||
|
# key: layer_id, value: list of density values (one per prefill chunk)
|
||||||
|
_layer_densities: Dict[int, List[float]] = {}
|
||||||
|
|
||||||
|
# Per-layer communication density records (CPU block granularity, offload mode only)
|
||||||
|
_layer_comm_densities: Dict[int, List[float]] = {}
|
||||||
|
|
||||||
|
# Accumulation mode: record selected/total counts (used in offload mode)
|
||||||
|
# so that density = sum(selected) / sum(total) can be computed correctly once all chunks are done
|
||||||
|
_layer_selected_counts: Dict[int, List[int]] = {}
|
||||||
|
_layer_total_counts: Dict[int, List[int]] = {}
|
||||||
|
|
||||||
|
# Mask shape record (for debugging)
|
||||||
|
_last_q_blocks: int = 0
|
||||||
|
_last_k_blocks: int = 0
|
||||||
|
|
||||||
|
# Mode flag
|
||||||
|
_mode: str = "unknown" # "gpu_only" or "offload"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_mode(cls, mode: str) -> None:
|
||||||
|
"""Set the current mode (gpu_only / offload)."""
|
||||||
|
cls._mode = mode
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def record(
|
||||||
|
cls,
|
||||||
|
layer_id: int,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
causal: bool = True,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Record the density of one layer (for GPU-only mode).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_id: Layer ID
|
||||||
|
mask: [batch, heads, q_blocks, k_blocks] boolean tensor
|
||||||
|
causal: Whether to account for the causal mask (count only the lower triangle)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The density value
|
||||||
|
"""
|
||||||
|
if not cls._enabled:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
density = cls._compute_density(mask, causal)
|
||||||
|
|
||||||
|
# Record
|
||||||
|
if layer_id not in cls._layer_densities:
|
||||||
|
cls._layer_densities[layer_id] = []
|
||||||
|
cls._layer_densities[layer_id].append(density)
|
||||||
|
|
||||||
|
# Record the mask shape
|
||||||
|
cls._last_q_blocks = mask.shape[2]
|
||||||
|
cls._last_k_blocks = mask.shape[3]
|
||||||
|
|
||||||
|
return density
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def record_counts(
|
||||||
|
cls,
|
||||||
|
layer_id: int,
|
||||||
|
selected_blocks: int,
|
||||||
|
total_blocks: int,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Record selected/total block counts for one layer (for offload accumulation mode).
|
||||||
|
|
||||||
|
Accumulating counts instead of computing density directly allows the correct value to be computed after all chunks are processed:
|
||||||
|
overall_density = sum(selected) / sum(total)
|
||||||
|
|
||||||
|
This is more accurate than avg(density) because the Q and K lengths differ across chunks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_id: Layer ID
|
||||||
|
selected_blocks: Number of blocks selected in this chunk
|
||||||
|
total_blocks: Total number of possible blocks in this chunk
|
||||||
|
"""
|
||||||
|
if not cls._enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Initialize the lists
|
||||||
|
if layer_id not in cls._layer_selected_counts:
|
||||||
|
cls._layer_selected_counts[layer_id] = []
|
||||||
|
if layer_id not in cls._layer_total_counts:
|
||||||
|
cls._layer_total_counts[layer_id] = []
|
||||||
|
|
||||||
|
# Accumulate
|
||||||
|
cls._layer_selected_counts[layer_id].append(selected_blocks)
|
||||||
|
cls._layer_total_counts[layer_id].append(total_blocks)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def record_comm_density(
|
||||||
|
cls,
|
||||||
|
layer_id: int,
|
||||||
|
selected_cpu_blocks: int,
|
||||||
|
total_cpu_blocks: int,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Record the communication density of one layer (CPU block granularity).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_id: Layer ID
|
||||||
|
selected_cpu_blocks: Number of CPU blocks selected
|
||||||
|
total_cpu_blocks: Total number of CPU blocks
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The communication density value
|
||||||
|
"""
|
||||||
|
if not cls._enabled:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
if total_cpu_blocks == 0:
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
comm_density = selected_cpu_blocks / total_cpu_blocks
|
||||||
|
|
||||||
|
# Record
|
||||||
|
if layer_id not in cls._layer_comm_densities:
|
||||||
|
cls._layer_comm_densities[layer_id] = []
|
||||||
|
cls._layer_comm_densities[layer_id].append(comm_density)
|
||||||
|
|
||||||
|
return comm_density
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _compute_density(cls, mask: torch.Tensor, causal: bool) -> float:
|
||||||
|
"""Compute the density of a mask."""
|
||||||
|
batch, heads, q_blocks, k_blocks = mask.shape
|
||||||
|
|
||||||
|
if causal:
|
||||||
|
# Count only the lower-triangular region
|
||||||
|
causal_mask = torch.tril(
|
||||||
|
torch.ones(q_blocks, k_blocks, device=mask.device, dtype=torch.bool)
|
||||||
|
)
|
||||||
|
total_blocks = causal_mask.sum().item() * batch * heads
|
||||||
|
selected_blocks = (mask & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||||
|
else:
|
||||||
|
total_blocks = mask.numel()
|
||||||
|
selected_blocks = mask.sum().item()
|
||||||
|
|
||||||
|
if total_blocks == 0:
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
return selected_blocks / total_blocks
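The causal branch of `_compute_density` can be illustrated without tensors. The sketch below handles a single batch/head slice given as a nested list of booleans; `causal_density` is a hypothetical helper, and the `j <= i` test mirrors what `torch.tril` keeps for a `q_blocks x k_blocks` matrix.

```python
def causal_density(mask):
    """Density of one q_blocks x k_blocks boolean mask, lower triangle only."""
    total = 0
    selected = 0
    for i, row in enumerate(mask):
        for j, keep in enumerate(row):
            if j <= i:  # the entries torch.tril would keep
                total += 1
                selected += bool(keep)
    return selected / total if total else 1.0


mask = [
    [True,  True,  True,  True],   # entries above the diagonal are ignored
    [False, True,  False, False],
    [True,  False, True,  False],
    [False, False, True,  True],
]
density = causal_density(mask)
empty_density = causal_density([])
```

In this example 6 of the 10 lower-triangular entries are selected; the three `True` values above the diagonal in the first row do not count.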
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def complete_reset(cls) -> None:
|
||||||
|
"""Reset all statistics."""
|
||||||
|
cls._layer_densities = {}
|
||||||
|
cls._layer_comm_densities = {}
|
||||||
|
cls._layer_selected_counts = {}
|
||||||
|
cls._layer_total_counts = {}
|
||||||
|
cls._last_q_blocks = 0
|
||||||
|
cls._last_k_blocks = 0
|
||||||
|
cls._mode = "unknown"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_per_layer_density(cls) -> Dict[int, float]:
|
||||||
|
"""
|
||||||
|
Get the density of each layer.
|
||||||
|
|
||||||
|
Accumulation mode (offload): density = sum(selected) / sum(total)
|
||||||
|
Direct-record mode (gpu_only): density = avg(density_values)
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
# Prefer accumulation mode (offload)
|
||||||
|
if cls._layer_selected_counts:
|
||||||
|
for layer_id in cls._layer_selected_counts:
|
||||||
|
selected_list = cls._layer_selected_counts.get(layer_id, [])
|
||||||
|
total_list = cls._layer_total_counts.get(layer_id, [])
|
||||||
|
total_selected = sum(selected_list)
|
||||||
|
total_total = sum(total_list)
|
||||||
|
if total_total > 0:
|
||||||
|
result[layer_id] = total_selected / total_total
|
||||||
|
else:
|
||||||
|
# Direct-record mode (gpu_only)
|
||||||
|
for layer_id, densities in cls._layer_densities.items():
|
||||||
|
if densities:
|
||||||
|
result[layer_id] = sum(densities) / len(densities)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_overall_density(cls) -> float:
|
||||||
|
"""
|
||||||
|
Get the overall compute density across all layers.
|
||||||
|
|
||||||
|
Accumulation mode (offload): density = sum(all_selected) / sum(all_total)
|
||||||
|
Direct-record mode (gpu_only): density = avg(all_density_values)
|
||||||
|
|
||||||
|
Note: the overall density is not a simple avg(per_layer_density),
|
||||||
|
but sum(all_selected) / sum(all_total), which handles the weights correctly.
|
||||||
|
"""
|
||||||
|
# Prefer accumulation mode (offload)
|
||||||
|
if cls._layer_selected_counts:
|
||||||
|
total_selected = 0
|
||||||
|
total_total = 0
|
||||||
|
for layer_id in cls._layer_selected_counts:
|
||||||
|
total_selected += sum(cls._layer_selected_counts[layer_id])
|
||||||
|
total_total += sum(cls._layer_total_counts.get(layer_id, []))
|
||||||
|
if total_total > 0:
|
||||||
|
return total_selected / total_total
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Direct-record mode (gpu_only)
|
||||||
|
all_densities = []
|
||||||
|
for densities in cls._layer_densities.values():
|
||||||
|
all_densities.extend(densities)
|
||||||
|
if not all_densities:
|
||||||
|
return 0.0
|
||||||
|
return sum(all_densities) / len(all_densities)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_overall_comm_density(cls) -> float:
|
||||||
|
"""Get the average communication density across all layers."""
|
||||||
|
all_densities = []
|
||||||
|
for densities in cls._layer_comm_densities.values():
|
||||||
|
all_densities.extend(densities)
|
||||||
|
if not all_densities:
|
||||||
|
return 0.0
|
||||||
|
return sum(all_densities) / len(all_densities)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_per_layer_comm_density(cls) -> Dict[int, float]:
|
||||||
|
"""
|
||||||
|
Return the per-layer communication density (CPU-block granularity).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[layer_id, avg_comm_density]
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
for layer_id, densities in cls._layer_comm_densities.items():
|
||||||
|
if densities:
|
||||||
|
result[layer_id] = sum(densities) / len(densities)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_summary(cls) -> dict:
|
||||||
|
"""Return a summary of the statistics."""
|
||||||
|
per_layer = cls.get_per_layer_density()
|
||||||
|
per_layer_comm = cls.get_per_layer_comm_density()
|
||||||
|
return {
|
||||||
|
"mode": cls._mode,
|
||||||
|
"overall_compute_density": cls.get_overall_density(),
|
||||||
|
"overall_comm_density": cls.get_overall_comm_density(),
|
||||||
|
"per_layer_compute_density": per_layer,
|
||||||
|
"per_layer_comm_density": per_layer_comm,
|
||||||
|
"num_layers": len(per_layer),
|
||||||
|
"last_mask_shape": {
|
||||||
|
"q_blocks": cls._last_q_blocks,
|
||||||
|
"k_blocks": cls._last_k_blocks,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_density(cls) -> Tuple[int, float]:
|
||||||
|
"""Return the layer with the lowest density and its value."""
|
||||||
|
per_layer = cls.get_per_layer_density()
|
||||||
|
if not per_layer:
|
||||||
|
return -1, 0.0
|
||||||
|
min_layer = min(per_layer, key=per_layer.get)
|
||||||
|
return min_layer, per_layer[min_layer]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def print_summary(cls) -> None:
|
||||||
|
"""Print a human-readable summary."""
|
||||||
|
per_layer = cls.get_per_layer_density()
|
||||||
|
overall = cls.get_overall_density()
|
||||||
|
min_layer, min_density = cls.get_min_density()
|
||||||
|
overall_comm = cls.get_overall_comm_density()
|
||||||
|
|
||||||
|
print(f"[DensityObserver] Mode: {cls._mode}")
|
||||||
|
print(f" Compute density: {overall:.4f} (min: {min_density:.4f} @ layer {min_layer})")
|
||||||
|
if overall_comm > 0:
|
||||||
|
# Offload mode: show both densities with explanation
|
||||||
|
print(f" Comm density: {overall_comm:.4f} (CPU block granularity)")
|
||||||
|
print(f" Savings ratio: {1 - overall_comm:.1%} H2D transfer reduction")
|
||||||
|
print(f" Num layers: {len(per_layer)}")
|
||||||
|
# Print layer 0's density for comparison
|
||||||
|
if 0 in per_layer:
|
||||||
|
print(f" Layer 0 density: {per_layer[0]:.6f}")
|
||||||
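The docstring of `get_overall_density` notes that, in accumulation mode, the overall density is a ratio of sums rather than a mean of per-layer ratios. The following is a minimal, standalone sketch of that difference. The function names and the sample counts are hypothetical and are not part of `DensityObserver`.

```python
from typing import Dict, List

Counts = Dict[int, List[int]]  # layer_id -> per-call counts (hypothetical layout)


def ratio_of_sums(selected: Counts, total: Counts) -> float:
    """Overall density as sum(all_selected) / sum(all_total)."""
    total_selected = sum(sum(v) for v in selected.values())
    total_total = sum(sum(total.get(layer_id, [])) for layer_id in selected)
    return total_selected / total_total if total_total > 0 else 0.0


def mean_of_ratios(selected: Counts, total: Counts) -> float:
    """Unweighted mean of per-layer densities, shown only for comparison."""
    per_layer = []
    for layer_id, counts in selected.items():
        layer_total = sum(total.get(layer_id, []))
        if layer_total > 0:
            per_layer.append(sum(counts) / layer_total)
    return sum(per_layer) / len(per_layer) if per_layer else 0.0


# Made-up counts: layer 1 sees four times as many entries as layer 0.
selected = {0: [10], 1: [300]}
total = {0: [100], 1: [400]}

weighted = ratio_of_sums(selected, total)      # 310 / 500
unweighted = mean_of_ratios(selected, total)   # (0.10 + 0.75) / 2
print(f"ratio of sums: {weighted:.3f}, mean of ratios: {unweighted:.3f}")
```

The two numbers diverge whenever layers contribute different totals, which is why the ratio of sums is the safer aggregate when per-layer counts are uneven.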
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from glob import glob
|
from glob import glob
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -9,20 +10,146 @@ def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
|
|||||||
param.data.copy_(loaded_weight)
|
param.data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
# GLM-4 weight name mappings
|
||||||
|
GLM4_NAME_MAPPING = {
|
||||||
|
"transformer.embedding.word_embeddings": "model.embed_tokens",
|
||||||
|
"transformer.encoder.final_layernorm": "model.norm",
|
||||||
|
"transformer.output_layer": "lm_head",
|
||||||
|
}
|
||||||
|
|
||||||
|
GLM4_LAYER_MAPPING = {
|
||||||
|
"self_attention.query_key_value": "self_attn.qkv_proj",
|
||||||
|
"self_attention.dense": "self_attn.o_proj",
|
||||||
|
"mlp.dense_h_to_4h": "mlp.gate_up_proj",
|
||||||
|
"mlp.dense_4h_to_h": "mlp.down_proj",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def convert_glm4_weight_name(weight_name: str) -> tuple[str | None, str | None]:
|
||||||
|
"""
|
||||||
|
Convert GLM-4 weight name to nanovllm format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (converted_name, shard_id) where shard_id is used for packed modules
|
||||||
|
Returns (None, None) for weights that should be skipped
|
||||||
|
"""
|
||||||
|
# Skip rotary embedding weights (we use our own RoPE implementation)
|
||||||
|
if "rotary_pos_emb" in weight_name:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# Check direct mappings first
|
||||||
|
for glm_name, nano_name in GLM4_NAME_MAPPING.items():
|
||||||
|
if weight_name.startswith(glm_name):
|
||||||
|
return weight_name.replace(glm_name, nano_name, 1), None
|
||||||
|
|
||||||
|
# Handle layer weights: transformer.encoder.layers.X.xxx
|
||||||
|
layer_match = re.match(r"transformer\.encoder\.layers\.(\d+)\.(.+)", weight_name)
|
||||||
|
if layer_match:
|
||||||
|
layer_idx = layer_match.group(1)
|
||||||
|
remainder = layer_match.group(2)
|
||||||
|
|
||||||
|
# Handle packed modules (QKV and gate_up)
|
||||||
|
for glm_subname, nano_subname in GLM4_LAYER_MAPPING.items():
|
||||||
|
if remainder.startswith(glm_subname):
|
||||||
|
suffix = remainder[len(glm_subname):] # .weight or .bias
|
||||||
|
new_name = f"model.layers.{layer_idx}.{nano_subname}{suffix}"
|
||||||
|
|
||||||
|
# Determine shard_id for packed modules
|
||||||
|
if "qkv_proj" in nano_subname:
|
||||||
|
return new_name, "qkv" # Special marker for GLM4 QKV
|
||||||
|
elif "gate_up_proj" in nano_subname:
|
||||||
|
return new_name, "gate_up" # Special marker for GLM4 gate_up
|
||||||
|
else:
|
||||||
|
return new_name, None
|
||||||
|
|
||||||
|
# Handle non-packed layer weights (layernorms)
|
||||||
|
new_name = f"model.layers.{layer_idx}.{remainder}"
|
||||||
|
return new_name, None
|
||||||
|
|
||||||
|
# No mapping found, return original
|
||||||
|
return weight_name, None
|
||||||
|
|
||||||
|
|
||||||
|
def load_glm4_qkv(param: nn.Parameter, loaded_weight: torch.Tensor, config):
|
||||||
|
"""Load GLM-4 merged QKV weights by splitting into q, k, v."""
|
||||||
|
num_heads = config.num_attention_heads
|
||||||
|
num_kv_heads = getattr(config, 'multi_query_group_num', num_heads)
|
||||||
|
head_dim = getattr(config, 'kv_channels', config.hidden_size // num_heads)
|
||||||
|
|
||||||
|
q_size = num_heads * head_dim
|
||||||
|
kv_size = num_kv_heads * head_dim
|
||||||
|
|
||||||
|
# Split QKV: [q_size + kv_size + kv_size, hidden_size]
|
||||||
|
q, k, v = loaded_weight.split([q_size, kv_size, kv_size], dim=0)
|
||||||
|
|
||||||
|
# Load each part using the weight_loader
|
||||||
|
weight_loader = getattr(param, "weight_loader")
|
||||||
|
weight_loader(param, q, "q")
|
||||||
|
weight_loader(param, k, "k")
|
||||||
|
weight_loader(param, v, "v")
|
||||||
|
|
||||||
|
|
||||||
|
def load_glm4_gate_up(param: nn.Parameter, loaded_weight: torch.Tensor, config):
|
||||||
|
"""Load GLM-4 merged gate_up weights by splitting into gate, up."""
|
||||||
|
ffn_hidden_size = getattr(config, 'ffn_hidden_size', getattr(config, 'intermediate_size', None))
|
||||||
|
|
||||||
|
# Split gate_up: [ffn_hidden_size * 2, hidden_size]
|
||||||
|
gate, up = loaded_weight.split([ffn_hidden_size, ffn_hidden_size], dim=0)
|
||||||
|
|
||||||
|
# Load each part using the weight_loader
|
||||||
|
weight_loader = getattr(param, "weight_loader")
|
||||||
|
weight_loader(param, gate, 0) # gate_proj is shard 0
|
||||||
|
weight_loader(param, up, 1) # up_proj is shard 1
|
||||||
|
|
||||||
|
|
||||||
|
def is_glm4_model(model: nn.Module) -> bool:
|
||||||
|
"""Check if the model is a GLM-4 model."""
|
||||||
|
return model.__class__.__name__ in ("ChatGLMForCausalLM",)
|
||||||
|
|
||||||
|
|
||||||
def load_model(model: nn.Module, path: str):
|
def load_model(model: nn.Module, path: str):
|
||||||
packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
|
packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
|
||||||
|
is_glm4 = is_glm4_model(model)
|
||||||
|
config = getattr(model, "config", None)
|
||||||
|
|
||||||
for file in glob(os.path.join(path, "*.safetensors")):
|
for file in glob(os.path.join(path, "*.safetensors")):
|
||||||
with safe_open(file, "pt", "cpu") as f:
|
with safe_open(file, "pt", "cpu") as f:
|
||||||
for weight_name in f.keys():
|
for weight_name in f.keys():
|
||||||
|
loaded_weight = f.get_tensor(weight_name)
|
||||||
|
|
||||||
|
# GLM-4 specific handling
|
||||||
|
if is_glm4:
|
||||||
|
param_name, shard_id = convert_glm4_weight_name(weight_name)
|
||||||
|
|
||||||
|
# Skip weights that don't need to be loaded
|
||||||
|
if param_name is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if shard_id == "qkv":
|
||||||
|
param = model.get_parameter(param_name)
|
||||||
|
load_glm4_qkv(param, loaded_weight, config)
|
||||||
|
continue
|
||||||
|
elif shard_id == "gate_up":
|
||||||
|
param = model.get_parameter(param_name)
|
||||||
|
load_glm4_gate_up(param, loaded_weight, config)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Regular weight, use converted name
|
||||||
|
param = model.get_parameter(param_name)
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Original loading logic for other models
|
||||||
for k in packed_modules_mapping:
|
for k in packed_modules_mapping:
|
||||||
if k in weight_name:
|
if k in weight_name:
|
||||||
v, shard_id = packed_modules_mapping[k]
|
v, shard_id = packed_modules_mapping[k]
|
||||||
param_name = weight_name.replace(k, v)
|
param_name = weight_name.replace(k, v)
|
||||||
param = model.get_parameter(param_name)
|
param = model.get_parameter(param_name)
|
||||||
weight_loader = getattr(param, "weight_loader")
|
weight_loader = getattr(param, "weight_loader")
|
||||||
weight_loader(param, f.get_tensor(weight_name), shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
param = model.get_parameter(weight_name)
|
param = model.get_parameter(weight_name)
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, f.get_tensor(weight_name))
|
weight_loader(param, loaded_weight)
|
||||||
|
|||||||
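`load_glm4_qkv` splits the fused QKV weight along dim 0 into chunks of `q_size`, `kv_size`, and `kv_size` rows. A small sketch of that row arithmetic follows, without `torch`. The helper name `qkv_row_ranges` and the head counts used below are illustrative assumptions, not values read from any checkpoint.

```python
from typing import Dict, Tuple


def qkv_row_ranges(num_heads: int, num_kv_heads: int, head_dim: int) -> Dict[str, Tuple[int, int]]:
    """Return half-open [start, end) row ranges of q, k, v in a fused QKV weight."""
    q_size = num_heads * head_dim
    kv_size = num_kv_heads * head_dim
    return {
        "q": (0, q_size),
        "k": (q_size, q_size + kv_size),
        "v": (q_size + kv_size, q_size + 2 * kv_size),
    }


# Illustrative grouped-query configuration (assumed values).
ranges = qkv_row_ranges(num_heads=32, num_kv_heads=2, head_dim=128)
for name, (start, end) in ranges.items():
    print(f"{name}: rows {start}..{end} ({end - start} rows)")
```

When there are fewer KV heads than query heads, the k and v chunks are much smaller than q, so the three split sizes cannot be assumed equal.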
133
nanovllm/utils/memory_observer.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""
|
||||||
|
MemoryObserver - observer for memory-transfer statistics.
|
||||||
|
|
||||||
|
Tracks the volume of data transferred between GPU and CPU:
|
||||||
|
- H2D (Host to Device): CPU → GPU
|
||||||
|
- D2H (Device to Host): GPU → CPU
|
||||||
|
- D2D (Device to Device): GPU → GPU (buffer copy)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from nanovllm.utils.observer import Observer
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryObserver(Observer):
|
||||||
|
"""
|
||||||
|
Memory-transfer observer that tracks the volume of data moved between GPU and CPU.
|
||||||
|
|
||||||
|
Transfer types:
|
||||||
|
- H2D (Host to Device): CPU → GPU
|
||||||
|
- D2H (Device to Host): GPU → CPU
|
||||||
|
- D2D (Device to Device): GPU → GPU (buffer copy)
|
||||||
|
|
||||||
|
Recording sites (all in offload_engine.py):
|
||||||
|
- H2D: load_to_slot_layer(), load_block_sample_from_cpu(), load_block_full_from_cpu()
|
||||||
|
- D2H: offload_slot_layer_to_cpu(), offload_prefill_buffer_async()
|
||||||
|
- D2D: write_to_prefill_buffer(), write_to_decode_buffer()
|
||||||
|
- Reset: llm_engine.py:generate() - reset together with InferenceObserver
|
||||||
|
"""
|
||||||
|
|
||||||
|
_enabled: bool = False # disabled by default; must be enabled explicitly
|
||||||
|
|
||||||
|
# H2D statistics
|
||||||
|
h2d_bytes: int = 0
|
||||||
|
h2d_count: int = 0
|
||||||
|
|
||||||
|
# D2H statistics
|
||||||
|
d2h_bytes: int = 0
|
||||||
|
d2h_count: int = 0
|
||||||
|
|
||||||
|
# D2D statistics
|
||||||
|
d2d_bytes: int = 0
|
||||||
|
d2d_count: int = 0
|
||||||
|
|
||||||
|
# Per-phase statistics
|
||||||
|
prefill_h2d_bytes: int = 0
|
||||||
|
prefill_d2h_bytes: int = 0
|
||||||
|
decode_h2d_bytes: int = 0
|
||||||
|
decode_d2h_bytes: int = 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def record_h2d(cls, num_bytes: int, is_prefill: bool = True) -> None:
|
||||||
|
"""Record an H2D transfer."""
|
||||||
|
if not cls._enabled:
|
||||||
|
return
|
||||||
|
cls.h2d_bytes += num_bytes
|
||||||
|
cls.h2d_count += 1
|
||||||
|
if is_prefill:
|
||||||
|
cls.prefill_h2d_bytes += num_bytes
|
||||||
|
else:
|
||||||
|
cls.decode_h2d_bytes += num_bytes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def record_d2h(cls, num_bytes: int, is_prefill: bool = True) -> None:
|
||||||
|
"""Record a D2H transfer."""
|
||||||
|
if not cls._enabled:
|
||||||
|
return
|
||||||
|
cls.d2h_bytes += num_bytes
|
||||||
|
cls.d2h_count += 1
|
||||||
|
if is_prefill:
|
||||||
|
cls.prefill_d2h_bytes += num_bytes
|
||||||
|
else:
|
||||||
|
cls.decode_d2h_bytes += num_bytes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def record_d2d(cls, num_bytes: int) -> None:
|
||||||
|
"""Record a D2D transfer."""
|
||||||
|
if not cls._enabled:
|
||||||
|
return
|
||||||
|
cls.d2d_bytes += num_bytes
|
||||||
|
cls.d2d_count += 1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def complete_reset(cls) -> None:
|
||||||
|
"""Reset all statistics."""
|
||||||
|
cls.h2d_bytes = cls.h2d_count = 0
|
||||||
|
cls.d2h_bytes = cls.d2h_count = 0
|
||||||
|
cls.d2d_bytes = cls.d2d_count = 0
|
||||||
|
cls.prefill_h2d_bytes = cls.prefill_d2h_bytes = 0
|
||||||
|
cls.decode_h2d_bytes = cls.decode_d2h_bytes = 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_summary(cls) -> dict:
|
||||||
|
"""Return a summary of the statistics."""
|
||||||
|
return {
|
||||||
|
"total": {
|
||||||
|
"h2d_bytes": cls.h2d_bytes,
|
||||||
|
"h2d_count": cls.h2d_count,
|
||||||
|
"d2h_bytes": cls.d2h_bytes,
|
||||||
|
"d2h_count": cls.d2h_count,
|
||||||
|
"d2d_bytes": cls.d2d_bytes,
|
||||||
|
"d2d_count": cls.d2d_count,
|
||||||
|
},
|
||||||
|
"prefill": {
|
||||||
|
"h2d_bytes": cls.prefill_h2d_bytes,
|
||||||
|
"d2h_bytes": cls.prefill_d2h_bytes,
|
||||||
|
},
|
||||||
|
"decode": {
|
||||||
|
"h2d_bytes": cls.decode_h2d_bytes,
|
||||||
|
"d2h_bytes": cls.decode_d2h_bytes,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _fmt_bytes(cls, b: int) -> str:
|
||||||
|
"""Format a byte count."""
|
||||||
|
if b >= 1e9:
|
||||||
|
return f"{b/1e9:.2f} GB"
|
||||||
|
if b >= 1e6:
|
||||||
|
return f"{b/1e6:.2f} MB"
|
||||||
|
if b >= 1e3:
|
||||||
|
return f"{b/1e3:.2f} KB"
|
||||||
|
return f"{b} B"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def print_summary(cls) -> None:
|
||||||
|
"""Print a human-readable summary."""
|
||||||
|
fmt = cls._fmt_bytes
|
||||||
|
total = cls.h2d_bytes + cls.d2h_bytes + cls.d2d_bytes
|
||||||
|
print(f"[MemoryObserver] Total: {fmt(total)}")
|
||||||
|
print(f" H2D: {fmt(cls.h2d_bytes)} ({cls.h2d_count} ops)")
|
||||||
|
print(f" D2H: {fmt(cls.d2h_bytes)} ({cls.d2h_count} ops)")
|
||||||
|
print(f" D2D: {fmt(cls.d2d_bytes)} ({cls.d2d_count} ops)")
|
||||||
|
print(f" Prefill - H2D: {fmt(cls.prefill_h2d_bytes)}, D2H: {fmt(cls.prefill_d2h_bytes)}")
|
||||||
|
print(f" Decode - H2D: {fmt(cls.decode_h2d_bytes)}, D2H: {fmt(cls.decode_d2h_bytes)}")
|
||||||
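`MemoryObserver` keeps its counters as class attributes, ignores calls while disabled, and formats totals with decimal units. The sketch below reproduces that pattern in isolation so it can be run without the rest of the package. `TransferCounter` and the module-level `fmt_bytes` are hypothetical stand-ins, not names from the repository.

```python
class TransferCounter:
    """Hypothetical stand-in for a class-level transfer counter."""

    _enabled: bool = False  # disabled by default, like MemoryObserver
    h2d_bytes: int = 0
    h2d_count: int = 0

    @classmethod
    def record_h2d(cls, num_bytes: int) -> None:
        if not cls._enabled:
            return  # calls are ignored while disabled
        cls.h2d_bytes += num_bytes
        cls.h2d_count += 1


def fmt_bytes(b: int) -> str:
    """Format a byte count with decimal (powers of 1000) units."""
    if b >= 1e9:
        return f"{b / 1e9:.2f} GB"
    if b >= 1e6:
        return f"{b / 1e6:.2f} MB"
    if b >= 1e3:
        return f"{b / 1e3:.2f} KB"
    return f"{b} B"


TransferCounter.record_h2d(4096)       # ignored: the counter is disabled
TransferCounter._enabled = True
TransferCounter.record_h2d(1_500_000)
TransferCounter.record_h2d(1_500_000)
print(fmt_bytes(TransferCounter.h2d_bytes), TransferCounter.h2d_count)
```

Because the thresholds are powers of 1000, the reported "GB" and "MB" are decimal units and will read slightly larger than the GiB or MiB figures shown by tools that use powers of 1024.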
@@ -1,17 +1,106 @@
|
|||||||
class Observer():
|
"""
|
||||||
ttft_start = 0
|
Observer base class and the InferenceObserver implementation.
|
||||||
tpot_start = 0
|
|
||||||
|
|
||||||
ttft = 0
|
Observer architecture:
|
||||||
tpot = 0
|
- Observer: base class defining the common interface
|
||||||
|
- InferenceObserver: inference performance metrics (TTFT/TPOT)
|
||||||
|
- MemoryObserver: memory-transfer metrics (defined in memory_observer.py)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Observer:
|
||||||
|
"""
|
||||||
|
Observer base class providing a common enable/disable, reset, and output interface.
|
||||||
|
|
||||||
|
All Observer subclasses should inherit from this class and implement:
|
||||||
|
- complete_reset(): reset all statistics
|
||||||
|
- get_summary(): return a summary dict of the statistics
|
||||||
|
- print_summary(): print a human-readable summary
|
||||||
|
"""
|
||||||
|
|
||||||
|
_enabled: bool = True # enabled by default
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def reset_ttft(cls):
|
def enable(cls) -> None:
|
||||||
|
"""Enable the observer."""
|
||||||
|
cls._enabled = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def disable(cls) -> None:
|
||||||
|
"""Disable the observer."""
|
||||||
|
cls._enabled = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_enabled(cls) -> bool:
|
||||||
|
"""Return whether the observer is enabled."""
|
||||||
|
return cls._enabled
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def complete_reset(cls) -> None:
|
||||||
|
"""Reset all statistics (implemented by subclasses)."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_summary(cls) -> dict:
|
||||||
|
"""Return a summary of the statistics (implemented by subclasses)."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def print_summary(cls) -> None:
|
||||||
|
"""Print a human-readable summary (subclasses may override)."""
|
||||||
|
import json
|
||||||
|
print(json.dumps(cls.get_summary(), indent=2))
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceObserver(Observer):
|
||||||
|
"""
|
||||||
|
Inference performance observer that tracks TTFT and TPOT.
|
||||||
|
|
||||||
|
- TTFT (Time To First Token): latency until the first token is generated
|
||||||
|
- TPOT (Time Per Output Token): average latency per output token
|
||||||
|
|
||||||
|
Recording sites:
|
||||||
|
- TTFT start: scheduler.py:35-36 - when the first sequence is taken from the waiting queue
|
||||||
|
- TTFT end: llm_engine.py:69-72 - after prefill completes (including all chunks of a chunked prefill)
|
||||||
|
- TPOT start: llm_engine.py:65 - at the end of each decode step
|
||||||
|
- TPOT end: llm_engine.py:62-63 - computed at the start of the next decode step (measures the previous decode step)
|
||||||
|
- Reset: llm_engine.py:97 - at the start of generate()
|
||||||
|
|
||||||
|
Note: TPOT needs at least 2 output tokens to be computed (it measures the interval between decode steps).
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Timestamps (nanoseconds)
|
||||||
|
ttft_start: int = 0
|
||||||
|
tpot_start: int = 0
|
||||||
|
|
||||||
|
# Measured results (nanoseconds)
|
||||||
|
ttft: int = 0
|
||||||
|
tpot: int = 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def reset_ttft(cls) -> None:
|
||||||
|
"""Reset the TTFT timer."""
|
||||||
cls.ttft_start = 0
|
cls.ttft_start = 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def complete_reset(cls):
|
def complete_reset(cls) -> None:
|
||||||
|
"""Reset all statistics."""
|
||||||
cls.ttft_start = 0
|
cls.ttft_start = 0
|
||||||
cls.tpot_start = 0
|
cls.tpot_start = 0
|
||||||
cls.ttft = 0
|
cls.ttft = 0
|
||||||
cls.tpot = 0
|
cls.tpot = 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_summary(cls) -> dict:
|
||||||
|
"""Return a summary of the statistics."""
|
||||||
|
return {
|
||||||
|
"ttft_ns": cls.ttft,
|
||||||
|
"ttft_ms": cls.ttft / 1e6,
|
||||||
|
"tpot_ns": cls.tpot,
|
||||||
|
"tpot_ms": cls.tpot / 1e6,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def print_summary(cls) -> None:
|
||||||
|
"""Print the summary."""
|
||||||
|
print(f"[InferenceObserver] TTFT: {cls.ttft / 1e6:.2f}ms, TPOT: {cls.tpot / 1e6:.2f}ms")
|
||||||
|
|||||||
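The `InferenceObserver` docstring says TPOT is measured as the interval between decode steps and therefore needs at least two output tokens. A minimal sketch of that interval logic is given below. `DecodeTimer` is a hypothetical class, and the use of `time.perf_counter_ns()` is an assumption, since the diff only states that timestamps are in nanoseconds.

```python
import time
from typing import Optional


class DecodeTimer:
    """Hypothetical timer that measures the gap between decode steps."""

    def __init__(self) -> None:
        self.tpot_start: int = 0  # 0 means "no decode step seen yet"
        self.tpot: int = 0        # last measured interval, in nanoseconds

    def on_decode_step(self, now_ns: Optional[int] = None) -> None:
        # A timestamp can be injected for testing; otherwise read the clock.
        now = time.perf_counter_ns() if now_ns is None else now_ns
        if self.tpot_start != 0:
            self.tpot = now - self.tpot_start
        self.tpot_start = now


timer = DecodeTimer()
timer.on_decode_step(now_ns=1_000_000)    # first step: nothing to measure yet
first = timer.tpot
timer.on_decode_step(now_ns=26_000_000)   # second step: 25 ms after the first
print(f"TPOT: {timer.tpot / 1e6:.2f}ms")
```

After a single step the interval is still zero, which matches the note that one output token is not enough to produce a TPOT value.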
158
scripts/profile.sh
Executable file
@@ -0,0 +1,158 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Profile bench.py using NVIDIA Nsight Systems (GPU-only mode)
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# bash scripts/profile.sh [options]
|
||||||
|
#
|
||||||
|
# Options:
|
||||||
|
# --max-len LENGTH Max sequence length (default: 32768)
|
||||||
|
# --policy POLICY Sparse policy: full, xattn (default: xattn)
|
||||||
|
# --gpu GPU_ID GPU to use (default: 0)
|
||||||
|
# --gpu-util UTIL GPU memory utilization (default: 0.9)
|
||||||
|
# --input-len LENGTH Input length (default: max-len - 1)
|
||||||
|
# --bench-decode Run decode benchmark instead of prefill
|
||||||
|
#
|
||||||
|
# Output:
|
||||||
|
# results/nsys/bench_<policy>_<max_len>_<mode>_<timestamp>.nsys-rep
|
||||||
|
#
|
||||||
|
# Examples:
|
||||||
|
# bash scripts/profile.sh
|
||||||
|
# bash scripts/profile.sh --max-len 65536 --gpu-util 0.7
|
||||||
|
# bash scripts/profile.sh --policy full --max-len 32768
|
||||||
|
# bash scripts/profile.sh --bench-decode
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Default configuration
|
||||||
|
MAX_LEN="32768"
|
||||||
|
POLICY="xattn"
|
||||||
|
GPU_ID="0"
|
||||||
|
GPU_UTIL="0.9"
|
||||||
|
INPUT_LEN=""
|
||||||
|
BENCH_MODE="prefill"
|
||||||
|
|
||||||
|
# Parse arguments
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case $1 in
|
||||||
|
--max-len)
|
||||||
|
MAX_LEN="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--policy)
|
||||||
|
POLICY="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--gpu)
|
||||||
|
GPU_ID="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--gpu-util)
|
||||||
|
GPU_UTIL="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--input-len)
|
||||||
|
INPUT_LEN="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--bench-decode)
|
||||||
|
BENCH_MODE="decode"
|
||||||
|
shift
|
||||||
|
;;
|
||||||
|
-h|--help)
|
||||||
|
echo "Usage: $0 [options]"
|
||||||
|
echo ""
|
||||||
|
echo "Options:"
|
||||||
|
echo " --max-len LENGTH Max sequence length (default: 32768)"
|
||||||
|
echo " --policy POLICY Sparse policy: full, xattn (default: xattn)"
|
||||||
|
echo " --gpu GPU_ID GPU to use (default: 0)"
|
||||||
|
echo " --gpu-util UTIL GPU memory utilization (default: 0.9)"
|
||||||
|
echo " --input-len LENGTH Input length (default: max-len - 1)"
|
||||||
|
echo " --bench-decode Run decode benchmark instead of prefill"
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option: $1"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
# Path configuration
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
|
||||||
|
OUTPUT_DIR="$PROJECT_ROOT/results/nsys"
|
||||||
|
BENCH_SCRIPT="$PROJECT_ROOT/bench.py"
|
||||||
|
|
||||||
|
# Create output directory if needed
|
||||||
|
mkdir -p "$OUTPUT_DIR"
|
||||||
|
|
||||||
|
# Generate timestamp for unique filename
|
||||||
|
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||||
|
|
||||||
|
# Convert max_len to human-readable format (e.g., 32768 -> 32k)
|
||||||
|
if [ "$MAX_LEN" -ge 1024 ]; then
|
||||||
|
MAX_LEN_SUFFIX="$((MAX_LEN / 1024))k"
|
||||||
|
else
|
||||||
|
MAX_LEN_SUFFIX="${MAX_LEN}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
OUTPUT_FILE="$OUTPUT_DIR/bench_${POLICY}_${MAX_LEN_SUFFIX}_${BENCH_MODE}_${TIMESTAMP}"
|
||||||
|
|
||||||
|
# Build bench.py arguments
|
||||||
|
BENCH_ARGS="--max-len $MAX_LEN --gpu-util $GPU_UTIL"
|
||||||
|
|
||||||
|
if [ -n "$POLICY" ]; then
|
||||||
|
BENCH_ARGS="$BENCH_ARGS --policy $POLICY"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -n "$INPUT_LEN" ]; then
|
||||||
|
BENCH_ARGS="$BENCH_ARGS --input-len $INPUT_LEN"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "$BENCH_MODE" = "decode" ]; then
|
||||||
|
BENCH_ARGS="$BENCH_ARGS --bench-decode"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "============================================================"
|
||||||
|
echo "NVIDIA Nsight Systems Profiling (GPU-only)"
|
||||||
|
echo "============================================================"
|
||||||
|
echo "Bench script: $BENCH_SCRIPT"
|
||||||
|
echo "Policy: $POLICY"
|
||||||
|
echo "Max length: $MAX_LEN"
|
||||||
|
echo "GPU: $GPU_ID"
|
||||||
|
echo "GPU util: $GPU_UTIL"
|
||||||
|
echo "Bench mode: $BENCH_MODE"
|
||||||
|
echo "Output file: $OUTPUT_FILE.nsys-rep"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# nsys profile options:
|
||||||
|
# --trace=cuda,nvtx : Trace CUDA API and NVTX markers
|
||||||
|
# --force-overwrite=true : Overwrite existing output file
|
||||||
|
# --output=<path> : Output file path (without .nsys-rep extension)
|
||||||
|
|
||||||
|
echo "Running nsys profile..."
|
||||||
|
echo "Command: python bench.py $BENCH_ARGS"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=$GPU_ID PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH" \
|
||||||
|
nsys profile \
|
||||||
|
--trace=cuda,nvtx \
|
||||||
|
--force-overwrite=true \
|
||||||
|
--output="$OUTPUT_FILE" \
|
||||||
|
python "$BENCH_SCRIPT" $BENCH_ARGS
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "============================================================"
|
||||||
|
echo "Profiling completed successfully!"
|
||||||
|
echo "============================================================"
|
||||||
|
echo "Output file: $OUTPUT_FILE.nsys-rep"
|
||||||
|
echo ""
|
||||||
|
echo "To view results in GUI:"
|
||||||
|
echo " nsight-sys $OUTPUT_FILE.nsys-rep"
|
||||||
|
echo ""
|
||||||
|
echo "To export statistics:"
|
||||||
|
echo " nsys stats --report cuda_api_sum $OUTPUT_FILE.nsys-rep"
|
||||||
|
echo " nsys stats --report cuda_gpu_kern_sum $OUTPUT_FILE.nsys-rep"
|
||||||
|
echo " nsys stats --report cuda_gpu_mem_size_sum $OUTPUT_FILE.nsys-rep"
|
||||||
|
echo "============================================================"
|
||||||
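`scripts/profile.sh` derives the report name from the policy, a shortened max length, the bench mode, and a timestamp. The Python sketch below mirrors that naming logic so the integer division is easy to inspect. The function names are hypothetical, and the timestamp string is a made-up example.

```python
def max_len_suffix(max_len: int) -> str:
    """Mirror the shell logic: values >= 1024 become '<max_len // 1024>k'."""
    if max_len >= 1024:
        return f"{max_len // 1024}k"
    return str(max_len)


def output_stem(policy: str, max_len: int, mode: str, timestamp: str) -> str:
    """Build the report file stem (without the .nsys-rep extension)."""
    return f"bench_{policy}_{max_len_suffix(max_len)}_{mode}_{timestamp}"


stem = output_stem("xattn", 32768, "prefill", "20250101_120000")
print(stem)
print(max_len_suffix(36000))  # not a multiple of 1024, so the suffix is truncated
```

Since the division truncates, lengths that are not multiples of 1024 map to the next lower "k" value, so two different `--max-len` settings can share a suffix and differ only by timestamp.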
@@ -1,35 +1,171 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# Profile test_attention_offload.py using NVIDIA Nsight Systems
|
# Profile test_ruler.py using NVIDIA Nsight Systems
|
||||||
#
|
#
|
||||||
# Usage:
|
# Usage:
|
||||||
# bash scripts/profile_offload.sh
|
# bash scripts/profile_offload.sh [options]
|
||||||
|
#
|
||||||
|
# Options:
|
||||||
|
# --policy POLICY Sparse policy name (default: full)
|
||||||
|
# --ctx-len LENGTH Context length: 32k, 64k, 128k, 256k, 512k, 1m (default: 64k)
|
||||||
|
# --dataset DATASET Task name (default: niah_single_1)
|
||||||
|
# --sample INDEX Sample index (default: 0)
|
||||||
|
# --gpu GPU_ID GPU to use (default: 0)
|
||||||
|
# --num-gpu-blocks N Number of GPU blocks/slots (default: 4)
|
||||||
|
# --block-size SIZE KV cache block size (default: 4096)
|
||||||
|
# --no-offload Disable CPU offload
|
||||||
#
|
#
|
||||||
# Output:
|
# Output:
|
||||||
# results/nsys/attention_offload_<timestamp>.nsys-rep
|
# results/nsys/<policy>_<gpuonly|offload>_<ctx-len>_blk<size>_<timestamp>.nsys-rep
|
||||||
#
|
#
|
||||||
# View results:
|
# Examples:
|
||||||
# nsight-sys results/nsys/attention_offload_<timestamp>.nsys-rep
|
# bash scripts/profile_offload.sh
|
||||||
|
# bash scripts/profile_offload.sh --policy xattn --ctx-len 128k --no-offload
|
||||||
|
# bash scripts/profile_offload.sh --policy full --ctx-len 32k --num-gpu-blocks 8
|
||||||
|
|
||||||
set -e
|
# Default configuration
|
||||||
|
POLICY="full"
|
||||||
|
CTX_LEN="64k"
|
||||||
|
DATASET="niah_single_1"
|
||||||
|
SAMPLE_INDEX="0"
|
||||||
|
GPU_ID="0"
|
||||||
|
NUM_GPU_BLOCKS="4"
|
||||||
|
BLOCK_SIZE="4096"
|
||||||
|
GPU_UTIL="0.9"
|
||||||
|
ENABLE_OFFLOAD="--enable-offload"
|
||||||
|
MODEL=""
|
||||||
|
DATA_DIR_OVERRIDE=""
|
||||||
|
|
||||||
# Configuration
|
# Parse arguments
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case $1 in
|
||||||
|
--policy)
|
||||||
|
POLICY="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--ctx-len)
|
||||||
|
CTX_LEN="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--dataset)
|
||||||
|
DATASET="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--sample)
|
||||||
|
SAMPLE_INDEX="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--gpu)
|
||||||
|
GPU_ID="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--no-offload)
|
||||||
|
ENABLE_OFFLOAD=""
|
||||||
|
shift
|
||||||
|
;;
|
||||||
|
--num-gpu-blocks)
|
||||||
|
NUM_GPU_BLOCKS="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--gpu-util)
|
||||||
|
GPU_UTIL="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--block-size)
|
||||||
|
BLOCK_SIZE="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--model)
|
||||||
|
MODEL="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--data-dir)
|
||||||
|
DATA_DIR_OVERRIDE="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
-h|--help)
|
||||||
|
echo "Usage: $0 [options]"
|
||||||
|
echo ""
|
||||||
|
echo "Options:"
|
||||||
|
echo " --policy POLICY Sparse policy name (default: full)"
|
||||||
|
echo " --ctx-len LENGTH Context length: 32k, 64k, 128k, 256k, 512k, 1m (default: 64k)"
|
||||||
|
echo " --block-size SIZE KV cache block size (default: 4096)"
|
||||||
|
echo " --dataset DATASET Task name (default: niah_single_1)"
|
||||||
|
echo " --sample INDEX Sample index (default: 0)"
|
||||||
|
echo " --gpu GPU_ID GPU to use (default: 0)"
|
||||||
|
echo " --gpu-util UTIL GPU memory utilization (default: 0.9)"
|
||||||
|
echo " --no-offload Disable CPU offload"
|
||||||
|
echo " --num-gpu-blocks N Number of GPU blocks/slots (default: 4)"
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option: $1"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
# Path configuration
|
||||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
|
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
|
||||||
OUTPUT_DIR="$PROJECT_ROOT/results/nsys"
|
OUTPUT_DIR="$PROJECT_ROOT/results/nsys"
|
||||||
TEST_SCRIPT="$PROJECT_ROOT/tests/test_attention_offload.py"
|
TEST_SCRIPT="$PROJECT_ROOT/tests/test_ruler.py"
|
||||||
|
DATA_DIR="$PROJECT_ROOT/tests/data/ruler_${CTX_LEN}"
|
||||||
|
|
||||||
|
# Set max-model-len based on context length
|
||||||
|
case "$CTX_LEN" in
|
||||||
|
32k)
|
||||||
|
MAX_MODEL_LEN=36000
|
||||||
|
;;
|
||||||
|
64k)
|
||||||
|
MAX_MODEL_LEN=72000
|
||||||
|
;;
|
||||||
|
128k)
|
||||||
|
MAX_MODEL_LEN=144000
|
||||||
|
;;
|
||||||
|
256k)
|
||||||
|
MAX_MODEL_LEN=288000
|
||||||
|
;;
|
||||||
|
512k)
|
||||||
|
MAX_MODEL_LEN=576000
|
||||||
|
;;
|
||||||
|
1m)
|
||||||
|
MAX_MODEL_LEN=1100000
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
MAX_MODEL_LEN=72000
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
# Override DATA_DIR if specified
|
||||||
|
if [ -n "$DATA_DIR_OVERRIDE" ]; then
|
||||||
|
DATA_DIR="$DATA_DIR_OVERRIDE"
|
||||||
|
fi
|
||||||
|
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
mkdir -p "$OUTPUT_DIR"
|
mkdir -p "$OUTPUT_DIR"
|
||||||
|
|
||||||
# Generate timestamp for unique filename
|
# Generate timestamp for unique filename
|
||||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||||
OUTPUT_FILE="$OUTPUT_DIR/attention_offload_$TIMESTAMP"
|
if [ -n "$ENABLE_OFFLOAD" ]; then
|
||||||
|
OFFLOAD_TAG="offload"
|
||||||
|
else
|
||||||
|
OFFLOAD_TAG="gpuonly"
|
||||||
|
fi
|
||||||
|
OUTPUT_FILE="$OUTPUT_DIR/${POLICY}_${OFFLOAD_TAG}_${CTX_LEN}_blk${BLOCK_SIZE}_${TIMESTAMP}"
|
||||||
|
|
||||||
echo "============================================================"
|
echo "============================================================"
|
||||||
echo "NVIDIA Nsight Systems Profiling"
|
echo "NVIDIA Nsight Systems Profiling"
|
||||||
echo "============================================================"
|
echo "============================================================"
|
||||||
echo "Test script: $TEST_SCRIPT"
|
echo "Policy: $POLICY"
|
||||||
|
echo "Offload: $OFFLOAD_TAG"
|
||||||
|
echo "Context: $CTX_LEN"
|
||||||
|
echo "Block Size: $BLOCK_SIZE"
|
||||||
|
echo "Dataset: $DATASET"
|
||||||
|
echo "Sample: $SAMPLE_INDEX"
|
||||||
|
echo "GPU: $GPU_ID"
|
||||||
|
echo "GPU Blocks: $NUM_GPU_BLOCKS"
|
||||||
|
echo "Data Dir: $DATA_DIR"
|
||||||
echo "Output file: $OUTPUT_FILE.nsys-rep"
|
echo "Output file: $OUTPUT_FILE.nsys-rep"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
@@ -43,13 +179,59 @@ echo ""
|
|||||||
echo "Running nsys profile..."
|
echo "Running nsys profile..."
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
|
# Map policy name to internal enum name
|
||||||
|
# User-friendly name -> SparsePolicyType enum name
|
||||||
|
case "$POLICY" in
|
||||||
|
xattn)
|
||||||
|
POLICY_ENUM="XATTN_BSA"
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
POLICY_ENUM="$POLICY"
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
# Build sparse policy argument
|
||||||
|
SPARSE_POLICY_ARG=""
|
||||||
|
if [ -n "$POLICY_ENUM" ] && [ "$POLICY_ENUM" != "full" ]; then
|
||||||
|
SPARSE_POLICY_ARG="--sparse-policy $POLICY_ENUM"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Build model argument
|
||||||
|
MODEL_ARG=""
|
||||||
|
if [ -n "$MODEL" ]; then
|
||||||
|
MODEL_ARG="--model $MODEL"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Run nsys profile and capture exit code
|
||||||
|
CUDA_VISIBLE_DEVICES=$GPU_ID PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH" \
|
||||||
nsys profile \
|
nsys profile \
|
||||||
--trace=cuda,nvtx,osrt,cudnn,cublas \
|
--trace=cuda,nvtx \
|
||||||
--cuda-memory-usage=true \
|
|
||||||
--stats=true \
|
|
||||||
--force-overwrite=true \
|
--force-overwrite=true \
|
||||||
--output="$OUTPUT_FILE" \
|
--output="$OUTPUT_FILE" \
|
||||||
python "$TEST_SCRIPT"
|
python "$TEST_SCRIPT" \
|
||||||
|
--data-dir "$DATA_DIR" \
|
||||||
|
--datasets "$DATASET" \
|
||||||
|
--sample-indices "$SAMPLE_INDEX" \
|
||||||
|
--num-gpu-blocks "$NUM_GPU_BLOCKS" \
|
||||||
|
--block-size "$BLOCK_SIZE" \
|
||||||
|
--max-model-len "$MAX_MODEL_LEN" \
|
||||||
|
--gpu-utilization "$GPU_UTIL" \
|
||||||
|
$ENABLE_OFFLOAD \
|
||||||
|
$SPARSE_POLICY_ARG \
|
||||||
|
$MODEL_ARG \
|
||||||
|
--quiet
|
||||||
|
EXIT_CODE=$?
|
||||||
|
|
||||||
|
# If test failed, delete the output file
|
||||||
|
if [ $EXIT_CODE -ne 0 ]; then
|
||||||
|
echo ""
|
||||||
|
echo "============================================================"
|
||||||
|
echo "Test FAILED! Cleaning up..."
|
||||||
|
echo "============================================================"
|
||||||
|
rm -f "$OUTPUT_FILE.nsys-rep"
|
||||||
|
echo "Deleted: $OUTPUT_FILE.nsys-rep"
|
||||||
|
exit $EXIT_CODE
|
||||||
|
fi
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "============================================================"
|
echo "============================================================"
|
||||||
|
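The optional-argument handling in the profiling script above (alias mapping, then `--sparse-policy` and `--model` only when set) can be sketched in Python as a list builder. This is a minimal illustration and not part of the commit: `POLICY_ALIASES` and `build_optional_args` are hypothetical names, and only the `xattn` alias and the two flags are taken from the script.

```python
from typing import Dict, List

# User-friendly name -> SparsePolicyType enum name.
# Only "xattn" is mapped in the script; other names pass through unchanged.
POLICY_ALIASES: Dict[str, str] = {"xattn": "XATTN_BSA"}


def build_optional_args(policy: str = "", model: str = "") -> List[str]:
    """Return the optional CLI arguments, one list element per token."""
    args: List[str] = []
    policy_enum = POLICY_ALIASES.get(policy, policy)
    # Mirror the shell test: skip the flag when empty or exactly "full".
    if policy_enum and policy_enum != "full":
        args += ["--sparse-policy", policy_enum]
    if model:
        args += ["--model", model]
    return args


if __name__ == "__main__":
    print(build_optional_args("xattn", "/models/example"))
```

Building a list keeps each value as a single argument even when it contains spaces, which the unquoted `$SPARSE_POLICY_ARG` and `$MODEL_ARG` expansions in the shell version do not guarantee.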
|||||||
@@ -1,757 +0,0 @@
|
|||||||
"""
|
|
||||||
Custom Qwen3 implementation using only torch (plus safetensors for weight loading).
|
|
||||||
This file provides a clean reference implementation for understanding the model computation graph.
|
|
||||||
|
|
||||||
Computation Graph:
|
|
||||||
==================
|
|
||||||
|
|
||||||
Input: token_ids [batch, seq_len]
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
┌─────────────┐
|
|
||||||
│ Embedding │ embed_tokens: [vocab_size, hidden_size]
|
|
||||||
└─────────────┘
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
hidden_states [batch, seq_len, hidden_size]
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
┌─────────────────────────────────────────────────────────┐
|
|
||||||
│ Decoder Layer (x N) │
|
|
||||||
│ ┌───────────────────────────────────────────────────┐ │
|
|
||||||
│ │ Self Attention Block │ │
|
|
||||||
│ │ │ │
|
|
||||||
│ │ input_layernorm (RMSNorm) │ │
|
|
||||||
│ │ │ │ │
|
|
||||||
│ │ ▼ │ │
|
|
||||||
│ │ ┌─────────────────────────────────────────────┐ │ │
|
|
||||||
│ │ │ Qwen3Attention │ │ │
|
|
||||||
│ │ │ Q = q_proj(x) → q_norm → reshape │ │ │
|
|
||||||
│ │ │ K = k_proj(x) → k_norm → reshape │ │ │
|
|
||||||
│ │ │ V = v_proj(x) → reshape │ │ │
|
|
||||||
│ │ │ │ │ │ │
|
|
||||||
│ │ │ ▼ │ │ │
|
|
||||||
│ │ │ Q, K = apply_rotary_pos_emb(Q, K, cos, sin)│ │ │
|
|
||||||
│ │ │ │ │ │ │
|
|
||||||
│ │ │ ▼ │ │ │
|
|
||||||
│ │ │ attn_output = attention(Q, K, V) │ │ │
|
|
||||||
│ │ │ │ │ │ │
|
|
||||||
│ │ │ ▼ │ │ │
|
|
||||||
│ │ │ output = o_proj(attn_output) │ │ │
|
|
||||||
│ │ └─────────────────────────────────────────────┘ │ │
|
|
||||||
│ │ │ │ │
|
|
||||||
│ │ ▼ │ │
|
|
||||||
│ │ hidden_states = residual + attn_output │ │
|
|
||||||
│ └───────────────────────────────────────────────────┘ │
|
|
||||||
│ │ │
|
|
||||||
│ ▼ │
|
|
||||||
│ ┌───────────────────────────────────────────────────┐ │
|
|
||||||
│ │ MLP Block │ │
|
|
||||||
│ │ │ │
|
|
||||||
│ │ post_attention_layernorm (RMSNorm) │ │
|
|
||||||
│ │ │ │ │
|
|
||||||
│ │ ▼ │ │
|
|
||||||
│ │ ┌─────────────────────────────────────────────┐ │ │
|
|
||||||
│ │ │ Qwen3MLP │ │ │
|
|
||||||
│ │ │ gate = gate_proj(x) │ │ │
|
|
||||||
│ │ │ up = up_proj(x) │ │ │
|
|
||||||
│ │ │ output = down_proj(silu(gate) * up) │ │ │
|
|
||||||
│ │ └─────────────────────────────────────────────┘ │ │
|
|
||||||
│ │ │ │ │
|
|
||||||
│ │ ▼ │ │
|
|
||||||
│ │ hidden_states = residual + mlp_output │ │
|
|
||||||
│ └───────────────────────────────────────────────────┘ │
|
|
||||||
└─────────────────────────────────────────────────────────┘
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
┌─────────────┐
|
|
||||||
│ norm │ final RMSNorm
|
|
||||||
└─────────────┘
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
┌─────────────┐
|
|
||||||
│ lm_head │ [hidden_size, vocab_size]
|
|
||||||
└─────────────┘
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
logits [batch, seq_len, vocab_size]
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import Optional, Tuple, List
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3RMSNorm(nn.Module):
|
|
||||||
"""RMSNorm implementation."""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
input_dtype = x.dtype
|
|
||||||
x = x.float()
|
|
||||||
variance = x.pow(2).mean(-1, keepdim=True)
|
|
||||||
x = x * torch.rsqrt(variance + self.eps)
|
|
||||||
return self.weight * x.to(input_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3RotaryEmbedding(nn.Module):
|
|
||||||
"""Rotary Position Embedding (RoPE)."""
|
|
||||||
|
|
||||||
def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.base = base
|
|
||||||
|
|
||||||
# Compute inverse frequencies
|
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x: Any tensor on the target dtype; only x.dtype is used (for the returned cos/sin)
|
|
||||||
position_ids: Position indices [batch, seq_len]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
cos, sin: [batch, seq_len, head_dim]
|
|
||||||
"""
|
|
||||||
# inv_freq: [dim/2]
|
|
||||||
# position_ids: [batch, seq_len]
|
|
||||||
inv_freq_expanded = self.inv_freq[None, :, None].float() # [1, dim/2, 1]
|
|
||||||
position_ids_expanded = position_ids[:, None, :].float() # [batch, 1, seq_len]
|
|
||||||
|
|
||||||
# freqs: [batch, dim/2, seq_len]
|
|
||||||
freqs = inv_freq_expanded @ position_ids_expanded
|
|
||||||
# freqs: [batch, seq_len, dim/2]
|
|
||||||
freqs = freqs.transpose(1, 2)
|
|
||||||
|
|
||||||
# Duplicate for full head_dim: [batch, seq_len, dim]
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
|
|
||||||
cos = emb.cos().to(x.dtype)
|
|
||||||
sin = emb.sin().to(x.dtype)
|
|
||||||
|
|
||||||
return cos, sin
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Rotate half the hidden dims of the input."""
|
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
|
||||||
x2 = x[..., x.shape[-1] // 2 :]
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
cos: torch.Tensor,
|
|
||||||
sin: torch.Tensor,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Apply rotary position embeddings to Q and K.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
q: [batch, num_heads, seq_len, head_dim]
|
|
||||||
k: [batch, num_kv_heads, seq_len, head_dim]
|
|
||||||
cos: [batch, seq_len, head_dim]
|
|
||||||
sin: [batch, seq_len, head_dim]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
q_embed, k_embed with same shapes as inputs
|
|
||||||
"""
|
|
||||||
# Unsqueeze for broadcasting: [batch, 1, seq_len, head_dim]
|
|
||||||
cos = cos.unsqueeze(1)
|
|
||||||
sin = sin.unsqueeze(1)
|
|
||||||
|
|
||||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
|
||||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
|
||||||
|
|
||||||
return q_embed, k_embed
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3Attention(nn.Module):
|
|
||||||
"""
|
|
||||||
Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support.
|
|
||||||
|
|
||||||
Data Flow:
|
|
||||||
---------
|
|
||||||
hidden_states [batch, seq_len, hidden_size]
|
|
||||||
│
|
|
||||||
├──► q_proj ──► q_norm ──► reshape ──► Q [batch, num_heads, seq_len, head_dim]
|
|
||||||
├──► k_proj ──► k_norm ──► reshape ──► K [batch, num_kv_heads, seq_len, head_dim]
|
|
||||||
└──► v_proj ──► reshape ──► V [batch, num_kv_heads, seq_len, head_dim]
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
apply_rotary_pos_emb(Q, K)
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
attention(Q, K, V) ──► attn_output [batch, num_heads, seq_len, head_dim]
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
reshape ──► o_proj ──► output [batch, seq_len, hidden_size]
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
hidden_size: int,
|
|
||||||
num_attention_heads: int,
|
|
||||||
num_key_value_heads: int,
|
|
||||||
head_dim: int,
|
|
||||||
max_position_embeddings: int = 32768,
|
|
||||||
rope_theta: float = 10000.0,
|
|
||||||
attention_bias: bool = False,
|
|
||||||
rms_norm_eps: float = 1e-6,
|
|
||||||
layer_idx: int = 0,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.num_heads = num_attention_heads
|
|
||||||
self.num_kv_heads = num_key_value_heads
|
|
||||||
self.head_dim = head_dim
|
|
||||||
self.num_kv_groups = num_attention_heads // num_key_value_heads
|
|
||||||
self.layer_idx = layer_idx
|
|
||||||
|
|
||||||
# Scaling factor
|
|
||||||
self.scaling = head_dim ** -0.5
|
|
||||||
|
|
||||||
# QKV projections
|
|
||||||
self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
|
|
||||||
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
|
|
||||||
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
|
|
||||||
self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)
|
|
||||||
|
|
||||||
# QK normalization (Qwen3 specific)
|
|
||||||
self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
|
|
||||||
self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
|
|
||||||
|
|
||||||
# Rotary embeddings
|
|
||||||
self.rotary_emb = Qwen3RotaryEmbedding(
|
|
||||||
head_dim,
|
|
||||||
max_position_embeddings=max_position_embeddings,
|
|
||||||
base=rope_theta,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
position_ids: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
use_cache: bool = False,
|
|
||||||
output_qkv: bool = False,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hidden_states: [batch, seq_len, hidden_size]
|
|
||||||
position_ids: [batch, seq_len]
|
|
||||||
attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask); accepted but not passed to SDPA in this implementation
|
|
||||||
past_key_value: (k_cache, v_cache) from previous steps
|
|
||||||
use_cache: Whether to return updated cache
|
|
||||||
output_qkv: Whether to output Q, K, V tensors for debugging
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
output: [batch, seq_len, hidden_size]
|
|
||||||
past_key_value: Updated cache (if use_cache=True)
|
|
||||||
qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True)
|
|
||||||
"""
|
|
||||||
batch_size, seq_len, _ = hidden_states.shape
|
|
||||||
|
|
||||||
# === QKV Projections ===
|
|
||||||
q = self.q_proj(hidden_states) # [batch, seq_len, num_heads * head_dim]
|
|
||||||
k = self.k_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
|
|
||||||
v = self.v_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
|
|
||||||
|
|
||||||
# Reshape to [batch, seq_len, num_heads, head_dim]
|
|
||||||
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
|
||||||
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
|
|
||||||
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
|
|
||||||
|
|
||||||
# === QK Normalization (Qwen3 specific) ===
|
|
||||||
q = self.q_norm(q)
|
|
||||||
k = self.k_norm(k)
|
|
||||||
|
|
||||||
# Transpose to [batch, num_heads, seq_len, head_dim]
|
|
||||||
q = q.transpose(1, 2)
|
|
||||||
k = k.transpose(1, 2)
|
|
||||||
v = v.transpose(1, 2)
|
|
||||||
|
|
||||||
# === Rotary Position Embeddings ===
|
|
||||||
cos, sin = self.rotary_emb(v, position_ids)
|
|
||||||
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
|
||||||
|
|
||||||
# === KV Cache Update ===
|
|
||||||
if past_key_value is not None:
|
|
||||||
k_cache, v_cache = past_key_value
|
|
||||||
k = torch.cat([k_cache, k], dim=2)
|
|
||||||
v = torch.cat([v_cache, v], dim=2)
|
|
||||||
|
|
||||||
new_past_key_value = (k, v) if use_cache else None
|
|
||||||
|
|
||||||
# === Grouped Query Attention (expand KV heads if needed) ===
|
|
||||||
if self.num_kv_groups > 1:
|
|
||||||
# Repeat KV for each query group
|
|
||||||
k = k.repeat_interleave(self.num_kv_groups, dim=1)
|
|
||||||
v = v.repeat_interleave(self.num_kv_groups, dim=1)
|
|
||||||
|
|
||||||
# === Attention Computation (using SDPA for memory efficiency) ===
|
|
||||||
# Use PyTorch's scaled_dot_product_attention, which can dispatch to a FlashAttention backend
|
|
||||||
# is_causal requires q_len == kv_len (prefill); single-token decode needs no mask, but a multi-token step with a cache (q_len > 1, kv_len > q_len) runs unmasked
|
|
||||||
q_len, kv_len = q.shape[2], k.shape[2]
|
|
||||||
is_causal = (q_len == kv_len) and (q_len > 1)
|
|
||||||
|
|
||||||
attn_output = F.scaled_dot_product_attention(
|
|
||||||
q, k, v,
|
|
||||||
attn_mask=None,
|
|
||||||
dropout_p=0.0,
|
|
||||||
is_causal=is_causal,
|
|
||||||
scale=self.scaling,
|
|
||||||
) # [batch, num_heads, seq_len, head_dim]
|
|
||||||
|
|
||||||
# === Output Projection ===
|
|
||||||
# Transpose back and reshape
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous() # [batch, seq_len, num_heads, head_dim]
|
|
||||||
attn_output = attn_output.view(batch_size, seq_len, -1) # [batch, seq_len, num_heads * head_dim]
|
|
||||||
output = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
# Optional QKV output for debugging
|
|
||||||
qkv_dict = None
|
|
||||||
if output_qkv:
|
|
||||||
qkv_dict = {
|
|
||||||
"q": q, # [batch, num_heads, seq_len, head_dim] (post-RoPE)
|
|
||||||
"k": k, # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded)
|
|
||||||
"v": v, # [batch, num_heads, kv_seq_len, head_dim] (expanded)
|
|
||||||
}
|
|
||||||
|
|
||||||
return output, new_past_key_value, qkv_dict
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3MLP(nn.Module):
|
|
||||||
"""
|
|
||||||
Qwen3 MLP with SwiGLU activation.
|
|
||||||
|
|
||||||
Data Flow:
|
|
||||||
---------
|
|
||||||
hidden_states [batch, seq_len, hidden_size]
|
|
||||||
│
|
|
||||||
├──► gate_proj ──► gate [batch, seq_len, intermediate_size]
|
|
||||||
│
|
|
||||||
└──► up_proj ──► up [batch, seq_len, intermediate_size]
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
silu(gate) * up
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
down_proj ──► output [batch, seq_len, hidden_size]
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
|
|
||||||
super().__init__()
|
|
||||||
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
|
|
||||||
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
|
|
||||||
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
gate = self.gate_proj(x)
|
|
||||||
up = self.up_proj(x)
|
|
||||||
return self.down_proj(F.silu(gate) * up)
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3DecoderLayer(nn.Module):
|
|
||||||
"""Single Qwen3 Decoder Layer."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
hidden_size: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
num_attention_heads: int,
|
|
||||||
num_key_value_heads: int,
|
|
||||||
head_dim: int,
|
|
||||||
max_position_embeddings: int = 32768,
|
|
||||||
rope_theta: float = 10000.0,
|
|
||||||
rms_norm_eps: float = 1e-6,
|
|
||||||
attention_bias: bool = False,
|
|
||||||
mlp_bias: bool = False,
|
|
||||||
layer_idx: int = 0,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.layer_idx = layer_idx
|
|
||||||
|
|
||||||
# Pre-attention RMSNorm
|
|
||||||
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
|
||||||
|
|
||||||
# Self-attention
|
|
||||||
self.self_attn = Qwen3Attention(
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
num_attention_heads=num_attention_heads,
|
|
||||||
num_key_value_heads=num_key_value_heads,
|
|
||||||
head_dim=head_dim,
|
|
||||||
max_position_embeddings=max_position_embeddings,
|
|
||||||
rope_theta=rope_theta,
|
|
||||||
attention_bias=attention_bias,
|
|
||||||
rms_norm_eps=rms_norm_eps,
|
|
||||||
layer_idx=layer_idx,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Post-attention RMSNorm
|
|
||||||
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
|
||||||
|
|
||||||
# MLP
|
|
||||||
self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
position_ids: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
use_cache: bool = False,
|
|
||||||
output_qkv: bool = False,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hidden_states: [batch, seq_len, hidden_size]
|
|
||||||
position_ids: [batch, seq_len]
|
|
||||||
attention_mask: Causal attention mask
|
|
||||||
past_key_value: KV cache for this layer
|
|
||||||
use_cache: Whether to return updated cache
|
|
||||||
output_qkv: Whether to output Q, K, V for debugging
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
hidden_states: [batch, seq_len, hidden_size]
|
|
||||||
past_key_value: Updated cache
|
|
||||||
qkv_dict: QKV tensors (if output_qkv=True)
|
|
||||||
"""
|
|
||||||
# === Self Attention Block ===
|
|
||||||
residual = hidden_states
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
|
||||||
|
|
||||||
attn_output, new_past_key_value, qkv_dict = self.self_attn(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
position_ids=position_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
past_key_value=past_key_value,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_qkv=output_qkv,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = residual + attn_output
|
|
||||||
|
|
||||||
# === MLP Block ===
|
|
||||||
residual = hidden_states
|
|
||||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
return hidden_states, new_past_key_value, qkv_dict
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3Model(nn.Module):
|
|
||||||
"""Qwen3 Transformer Model (without LM head)."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size: int,
|
|
||||||
hidden_size: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
num_hidden_layers: int,
|
|
||||||
num_attention_heads: int,
|
|
||||||
num_key_value_heads: int,
|
|
||||||
head_dim: int,
|
|
||||||
max_position_embeddings: int = 32768,
|
|
||||||
rope_theta: float = 10000.0,
|
|
||||||
rms_norm_eps: float = 1e-6,
|
|
||||||
attention_bias: bool = False,
|
|
||||||
mlp_bias: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.num_hidden_layers = num_hidden_layers
|
|
||||||
|
|
||||||
# Token embeddings
|
|
||||||
self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
|
|
||||||
|
|
||||||
# Decoder layers
|
|
||||||
self.layers = nn.ModuleList([
|
|
||||||
Qwen3DecoderLayer(
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
intermediate_size=intermediate_size,
|
|
||||||
num_attention_heads=num_attention_heads,
|
|
||||||
num_key_value_heads=num_key_value_heads,
|
|
||||||
head_dim=head_dim,
|
|
||||||
max_position_embeddings=max_position_embeddings,
|
|
||||||
rope_theta=rope_theta,
|
|
||||||
rms_norm_eps=rms_norm_eps,
|
|
||||||
attention_bias=attention_bias,
|
|
||||||
mlp_bias=mlp_bias,
|
|
||||||
layer_idx=i,
|
|
||||||
)
|
|
||||||
for i in range(num_hidden_layers)
|
|
||||||
])
|
|
||||||
|
|
||||||
# Final RMSNorm
|
|
||||||
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
position_ids: Optional[torch.Tensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
|
||||||
use_cache: bool = False,
|
|
||||||
output_qkv_layers: Optional[List[int]] = None,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
input_ids: [batch, seq_len]
|
|
||||||
position_ids: [batch, seq_len]
|
|
||||||
attention_mask: [batch, seq_len] or pre-computed 4D mask (a 2D mask is replaced by a causal mask; padding positions are not masked)
|
|
||||||
past_key_values: List of (k, v) tuples for each layer
|
|
||||||
use_cache: Whether to return new cache
|
|
||||||
output_qkv_layers: List of layer indices to output QKV for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
hidden_states: [batch, seq_len, hidden_size]
|
|
||||||
new_past_key_values: Updated cache
|
|
||||||
qkv_outputs: {layer_idx: qkv_dict}
|
|
||||||
"""
|
|
||||||
batch_size, seq_len = input_ids.shape
|
|
||||||
|
|
||||||
# Embedding
|
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
|
||||||
|
|
||||||
# Position IDs
|
|
||||||
if position_ids is None:
|
|
||||||
past_len = past_key_values[0][0].shape[2] if past_key_values else 0
|
|
||||||
position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device)
|
|
||||||
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
|
|
||||||
|
|
||||||
# Attention mask (create causal mask if not provided)
|
|
||||||
if attention_mask is None or attention_mask.dim() == 2:
|
|
||||||
kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0)
|
|
||||||
causal_mask = torch.triu(
|
|
||||||
torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device),
|
|
||||||
diagonal=kv_seq_len - seq_len + 1,
|
|
||||||
)
|
|
||||||
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, kv_seq_len]
|
|
||||||
|
|
||||||
# Initialize cache list
|
|
||||||
new_past_key_values = [] if use_cache else None
|
|
||||||
qkv_outputs = {} if output_qkv_layers else None
|
|
||||||
|
|
||||||
# Decoder layers
|
|
||||||
for i, layer in enumerate(self.layers):
|
|
||||||
past_kv = past_key_values[i] if past_key_values else None
|
|
||||||
output_qkv = output_qkv_layers is not None and i in output_qkv_layers
|
|
||||||
|
|
||||||
hidden_states, new_kv, qkv_dict = layer(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
position_ids=position_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
past_key_value=past_kv,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_qkv=output_qkv,
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
new_past_key_values.append(new_kv)
|
|
||||||
if qkv_dict is not None:
|
|
||||||
qkv_outputs[i] = qkv_dict
|
|
||||||
|
|
||||||
# Final norm
|
|
||||||
hidden_states = self.norm(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states, new_past_key_values, qkv_outputs
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3ForCausalLM(nn.Module):
|
|
||||||
"""Qwen3 Model with Language Modeling head."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size: int,
|
|
||||||
hidden_size: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
num_hidden_layers: int,
|
|
||||||
num_attention_heads: int,
|
|
||||||
num_key_value_heads: int,
|
|
||||||
head_dim: int,
|
|
||||||
max_position_embeddings: int = 32768,
|
|
||||||
rope_theta: float = 10000.0,
|
|
||||||
rms_norm_eps: float = 1e-6,
|
|
||||||
attention_bias: bool = False,
|
|
||||||
mlp_bias: bool = False,
|
|
||||||
tie_word_embeddings: bool = True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.tie_word_embeddings = tie_word_embeddings
|
|
||||||
|
|
||||||
# Transformer model
|
|
||||||
self.model = Qwen3Model(
|
|
||||||
vocab_size=vocab_size,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
intermediate_size=intermediate_size,
|
|
||||||
num_hidden_layers=num_hidden_layers,
|
|
||||||
num_attention_heads=num_attention_heads,
|
|
||||||
num_key_value_heads=num_key_value_heads,
|
|
||||||
head_dim=head_dim,
|
|
||||||
max_position_embeddings=max_position_embeddings,
|
|
||||||
rope_theta=rope_theta,
|
|
||||||
rms_norm_eps=rms_norm_eps,
|
|
||||||
attention_bias=attention_bias,
|
|
||||||
mlp_bias=mlp_bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
# LM head
|
|
||||||
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
position_ids: Optional[torch.Tensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
|
||||||
use_cache: bool = False,
|
|
||||||
output_qkv_layers: Optional[List[int]] = None,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
input_ids: [batch, seq_len]
|
|
||||||
... (same as Qwen3Model)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
logits: [batch, seq_len, vocab_size]
|
|
||||||
past_key_values: Updated KV cache
|
|
||||||
qkv_outputs: QKV tensors for specified layers
|
|
||||||
"""
|
|
||||||
hidden_states, new_past_key_values, qkv_outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
position_ids=position_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_qkv_layers=output_qkv_layers,
|
|
||||||
)
|
|
||||||
|
|
||||||
logits = self.lm_head(hidden_states)
|
|
||||||
|
|
||||||
return logits, new_past_key_values, qkv_outputs
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM":
|
|
||||||
"""
|
|
||||||
Load weights from a pretrained Qwen3 model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_path: Path to model directory containing config.json and model weights
|
|
||||||
dtype: Data type for model weights
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Initialized Qwen3ForCausalLM model
|
|
||||||
"""
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
|
|
||||||
# Load config
|
|
||||||
config_path = os.path.join(model_path, "config.json")
|
|
||||||
with open(config_path) as f:
|
|
||||||
config = json.load(f)
|
|
||||||
|
|
||||||
# Create model
|
|
||||||
model = cls(
|
|
||||||
vocab_size=config["vocab_size"],
|
|
||||||
hidden_size=config["hidden_size"],
|
|
||||||
intermediate_size=config["intermediate_size"],
|
|
||||||
num_hidden_layers=config["num_hidden_layers"],
|
|
||||||
num_attention_heads=config["num_attention_heads"],
|
|
||||||
num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
|
|
||||||
head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]),
|
|
||||||
max_position_embeddings=config.get("max_position_embeddings", 32768),
|
|
||||||
rope_theta=config.get("rope_theta", 10000.0),
|
|
||||||
rms_norm_eps=config.get("rms_norm_eps", 1e-6),
|
|
||||||
attention_bias=config.get("attention_bias", False),
|
|
||||||
mlp_bias=config.get("mlp_bias", False),
|
|
||||||
tie_word_embeddings=config.get("tie_word_embeddings", True),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load weights
|
|
||||||
weight_files = sorted([
|
|
||||||
f for f in os.listdir(model_path)
|
|
||||||
if f.endswith(".safetensors")
|
|
||||||
])
|
|
||||||
|
|
||||||
state_dict = {}
|
|
||||||
for wf in weight_files:
|
|
||||||
state_dict.update(load_file(os.path.join(model_path, wf)))
|
|
||||||
|
|
||||||
# Load into model
|
|
||||||
model.load_state_dict(state_dict, strict=False)
|
|
||||||
|
|
||||||
# Tie lm_head weights to embed_tokens if configured
|
|
||||||
if model.tie_word_embeddings:
|
|
||||||
model.lm_head.weight = model.model.embed_tokens.weight
|
|
||||||
|
|
||||||
model = model.to(dtype)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
max_new_tokens: int = 32,
|
|
||||||
temperature: float = 1.0,
|
|
||||||
do_sample: bool = True,
|
|
||||||
pad_token_id: Optional[int] = None,
|
|
||||||
eos_token_id: Optional[int] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Simple autoregressive generation."""
|
|
||||||
device = input_ids.device
|
|
||||||
batch_size, seq_len = input_ids.shape
|
|
||||||
past_key_values = None
|
|
||||||
generated = input_ids.clone()
|
|
||||||
|
|
||||||
for _ in range(max_new_tokens):
|
|
||||||
if past_key_values is None:
|
|
||||||
current_input = generated
|
|
||||||
else:
|
|
||||||
current_input = generated[:, -1:]
|
|
||||||
|
|
||||||
logits, past_key_values, _ = self(
|
|
||||||
input_ids=current_input,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
use_cache=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
next_token_logits = logits[:, -1, :]
|
|
||||||
if temperature > 0 and do_sample:
|
|
||||||
next_token_logits = next_token_logits / temperature
|
|
||||||
probs = torch.softmax(next_token_logits, dim=-1)
|
|
||||||
next_token = torch.multinomial(probs, num_samples=1)
|
|
||||||
else:
|
|
||||||
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
generated = torch.cat([generated, next_token], dim=1)
|
|
||||||
|
|
||||||
if eos_token_id is not None and (next_token == eos_token_id).all():
|
|
||||||
break
|
|
||||||
|
|
||||||
return generated
|
|
||||||
|
|
||||||
|
|
||||||
def print_computation_graph():
|
|
||||||
"""Print the computation graph for reference."""
|
|
||||||
print(__doc__)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print_computation_graph()
|
|
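The `rotate_half` / `apply_rotary_pos_emb` pair in the removed file rotates each (x[i], x[i + head_dim/2]) pair by an angle proportional to the token position. A minimal pure-Python sketch of the same arithmetic, using plain lists instead of tensors, illustrates two consequences: vector norms are preserved, and the query-key dot product depends only on the position offset. The helper names `rope_angles`, `apply_rope`, and `dot`, and the sample vectors, are assumptions made for this illustration.

```python
import math
from typing import List


def rope_angles(position: int, head_dim: int, base: float = 10000.0) -> List[float]:
    """Angles position * inv_freq, duplicated to head_dim entries."""
    half = [position / (base ** (2 * i / head_dim)) for i in range(head_dim // 2)]
    return half + half  # mirrors torch.cat((freqs, freqs), dim=-1)


def rotate_half(x: List[float]) -> List[float]:
    """Return (-x2, x1), where x1 and x2 are the two halves of x."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rope(x: List[float], position: int, base: float = 10000.0) -> List[float]:
    """Compute x * cos + rotate_half(x) * sin for a single token vector."""
    angles = rope_angles(position, len(x), base)
    rotated = rotate_half(x)
    return [xi * math.cos(a) + ri * math.sin(a) for xi, ri, a in zip(x, rotated, angles)]


def dot(a: List[float], b: List[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


q = [0.3, -1.2, 0.7, 2.0]
k = [1.5, 0.4, -0.9, 0.1]

# Same offset (2) at different absolute positions gives the same score.
score_near = dot(apply_rope(q, 5), apply_rope(k, 3))
score_far = dot(apply_rope(q, 105), apply_rope(k, 103))

if __name__ == "__main__":
    print(math.isclose(score_near, score_far, abs_tol=1e-9))
```

Because the scores depend only on relative position, keys can be cached after RoPE has been applied, which is what the `torch.cat([k_cache, k], dim=2)` step in `Qwen3Attention.forward` relies on.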
||||||
@@ -1,254 +0,0 @@
|
|||||||
"""
|
|
||||||
Needle-in-a-haystack test for LLM.
|
|
||||||
|
|
||||||
Tests: Long-context retrieval capability with configurable sequence length.
|
|
||||||
|
|
||||||
NOTE: CPU offload mode has a known bug that causes incorrect outputs for
|
|
||||||
sequences longer than ~200 tokens. Use --no-offload for correctness testing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
from nanovllm import LLM, SamplingParams
|
|
||||||
from nanovllm.config import SparsePolicyType
|
|
||||||
from utils import generate_needle_prompt, check_needle_answer
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Main Test
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def run_needle_test(
|
|
||||||
model_path: str,
|
|
||||||
max_model_len: int,
|
|
||||||
input_len: int,
|
|
||||||
num_gpu_blocks: int = 4,
|
|
||||||
block_size: int = 1024,
|
|
||||||
needle_position: float = 0.5,
|
|
||||||
needle_value: str = "7492",
|
|
||||||
max_new_tokens: int = 32,
|
|
||||||
enable_cpu_offload: bool = False,
|
|
||||||
enable_quest: bool = False,
|
|
||||||
enable_xattn_bsa: bool = False,
|
|
||||||
sparse_topk: int = 8,
|
|
||||||
sparse_threshold: int = 4,
|
|
||||||
sparse_samples: int = 128,
|
|
||||||
verbose: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Run a needle-in-a-haystack test.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_path: Path to model
|
|
||||||
max_model_len: Maximum model context length
|
|
||||||
input_len: Target input sequence length
|
|
||||||
num_gpu_blocks: Number of GPU blocks for offload
|
|
||||||
block_size: KV cache block size
|
|
||||||
needle_position: Where to place needle (0.0-1.0)
|
|
||||||
needle_value: The secret value to find
|
|
||||||
max_new_tokens: Maximum tokens to generate
|
|
||||||
enable_cpu_offload: Enable CPU offload mode
|
|
||||||
enable_quest: Enable Quest sparse attention (decode-only Top-K)
|
|
||||||
enable_xattn_bsa: Enable XAttention BSA sparse attention (prefill-only)
|
|
||||||
sparse_topk: Top-K blocks for Quest
|
|
||||||
sparse_threshold: Threshold for sparse selection (Quest/XAttention BSA)
|
|
||||||
sparse_samples: Samples per chunk for XAttention BSA estimation
|
|
||||||
verbose: Print detailed output
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if test passed, False otherwise
|
|
||||||
"""
|
|
||||||
# Determine sparse policy
|
|
||||||
if enable_xattn_bsa:
|
|
||||||
sparse_policy = SparsePolicyType.XATTN_BSA
|
|
||||||
elif enable_quest:
|
|
||||||
sparse_policy = SparsePolicyType.QUEST
|
|
||||||
else:
|
|
||||||
sparse_policy = SparsePolicyType.FULL
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"Needle-in-Haystack Test")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print(f"Model: {model_path}")
|
|
||||||
print(f"Max model len: {max_model_len}")
|
|
||||||
print(f"Input length: {input_len}")
|
|
||||||
print(f"Block size: {block_size}")
|
|
||||||
print(f"Needle position: {needle_position:.0%}")
|
|
||||||
print(f"Needle value: {needle_value}")
|
|
||||||
print(f"CPU offload: {enable_cpu_offload}")
|
|
||||||
if enable_cpu_offload:
|
|
||||||
print(f"Sparse policy: {sparse_policy.name}")
|
|
||||||
if sparse_policy == SparsePolicyType.QUEST:
|
|
||||||
print(f" Quest: topk={sparse_topk}, threshold={sparse_threshold}")
|
|
||||||
elif sparse_policy == SparsePolicyType.XATTN_BSA:
|
|
||||||
print(f" XAttention BSA: threshold={sparse_threshold}, samples={sparse_samples}")
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
# 1. Initialize LLM
|
|
||||||
llm_kwargs = {
|
|
||||||
"enforce_eager": True,
|
|
||||||
"max_model_len": max_model_len,
|
|
||||||
"max_num_batched_tokens": max_model_len,
|
|
||||||
"enable_cpu_offload": enable_cpu_offload,
|
|
||||||
"kvcache_block_size": block_size,
|
|
||||||
}
|
|
||||||
if enable_cpu_offload:
|
|
||||||
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
|
||||||
llm_kwargs["sparse_policy"] = sparse_policy
|
|
||||||
if sparse_policy == SparsePolicyType.QUEST:
|
|
||||||
llm_kwargs["sparse_topk_blocks"] = sparse_topk
|
|
||||||
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
|
|
||||||
elif sparse_policy == SparsePolicyType.XATTN_BSA:
|
|
||||||
llm_kwargs["sparse_threshold"] = float(sparse_threshold) / 10.0 # Convert to 0.0-1.0 range
|
|
||||||
llm_kwargs["sparse_samples_per_chunk"] = sparse_samples
|
|
||||||
|
|
||||||
llm = LLM(model_path, **llm_kwargs)
|
|
||||||
|
|
||||||
# 2. Generate needle prompt
|
|
||||||
prompt, expected = generate_needle_prompt(
|
|
||||||
tokenizer=llm.tokenizer,
|
|
||||||
target_length=input_len,
|
|
||||||
needle_position=needle_position,
|
|
||||||
needle_value=needle_value,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Generate output
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
temperature=0.6, # Moderate temperature
|
|
||||||
max_tokens=max_new_tokens,
|
|
||||||
)
|
|
||||||
outputs = llm.generate([prompt], sampling_params, use_tqdm=True)
|
|
||||||
|
|
||||||
# 4. Check result
|
|
||||||
output_text = outputs[0]["text"]
|
|
||||||
output_token_ids = outputs[0]["token_ids"]
|
|
||||||
passed = check_needle_answer(output_text, expected)
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"Result")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print(f"Expected: {expected}")
|
|
||||||
print(f"Output tokens ({len(output_token_ids)}): {output_token_ids[:20]}")
|
|
||||||
print(f"Output: {output_text[:200]}...")
|
|
||||||
print(f"Status: {'PASSED' if passed else 'FAILED'}")
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
return passed
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# CLI Entry Point
|
|
||||||
# ============================================================
|
|
||||||

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Needle-in-haystack test for long context LLM")
    parser.add_argument(
        "--model", "-m",
        type=str,
        default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
        help="Path to model"
    )
    parser.add_argument(
        "--max-model-len",
        type=int,
        default=128 * 1024,
        help="Maximum model context length"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=8 * 1024,
        help="Target input sequence length"
    )
    parser.add_argument(
        "--num-gpu-blocks",
        type=int,
        default=2,
        help="Number of GPU blocks for CPU offload"
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=1024,
        help="KV cache block size"
    )
    parser.add_argument(
        "--needle-position",
        type=float,
        default=0.5,
        help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
    )
    parser.add_argument(
        "--needle-value",
        type=str,
        default="7492",
        help="The secret value to hide"
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=32,
        help="Maximum tokens to generate"
    )
    parser.add_argument(
        "--enable-offload",
        action="store_true",
        help="Enable CPU offload (has known bug for long sequences)"
    )
    parser.add_argument(
        "--enable-quest",
        action="store_true",
        help="Enable Quest sparse attention (decode-only Top-K selection)"
    )
    parser.add_argument(
        "--enable-xattn-bsa",
        action="store_true",
        help="Enable XAttention BSA sparse attention (prefill-only)"
    )
    parser.add_argument(
        "--sparse-topk",
        type=int,
        default=8,
        help="Top-K blocks for Quest sparse attention"
    )
    parser.add_argument(
        "--sparse-threshold",
        type=int,
        default=4,
        help="Apply sparse only when blocks > threshold (Quest) or attention threshold 0-9 (XAttention BSA)"
    )
    parser.add_argument(
        "--sparse-samples",
        type=int,
        default=128,
        help="Samples per chunk for XAttention BSA estimation"
    )
    args = parser.parse_args()

    passed = run_needle_test(
        model_path=args.model,
        max_model_len=args.max_model_len,
        input_len=args.input_len,
        num_gpu_blocks=args.num_gpu_blocks,
        block_size=args.block_size,
        needle_position=args.needle_position,
        needle_value=args.needle_value,
        max_new_tokens=args.max_new_tokens,
        enable_cpu_offload=args.enable_offload,
        enable_quest=args.enable_quest,
        enable_xattn_bsa=args.enable_xattn_bsa,
        sparse_topk=args.sparse_topk,
        sparse_threshold=args.sparse_threshold,
        sparse_samples=args.sparse_samples,
        verbose=True,
    )

    if passed:
        print("test_needle: PASSED")
    else:
        print("test_needle: FAILED")
        exit(1)
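The flag handling in `run_needle_test` above reduces to two small rules: XAttention BSA takes precedence over Quest when both flags are set, and the integer `--sparse-threshold` is divided by 10 before it is passed on for XAttention BSA. A minimal standalone sketch of those two rules follows. The `Policy` enum is only a stand-in for `nanovllm.config.SparsePolicyType`, and `pick_policy` / `sparse_kwargs` are hypothetical helper names that do not exist in the script.

```python
from enum import Enum, auto


class Policy(Enum):
    """Stand-in for nanovllm.config.SparsePolicyType (assumption)."""
    FULL = auto()
    QUEST = auto()
    XATTN_BSA = auto()


def pick_policy(enable_quest: bool, enable_xattn_bsa: bool) -> Policy:
    # Same precedence as the if/elif/else chain in run_needle_test.
    if enable_xattn_bsa:
        return Policy.XATTN_BSA
    if enable_quest:
        return Policy.QUEST
    return Policy.FULL


def sparse_kwargs(policy: Policy, topk: int = 8, threshold: int = 4, samples: int = 128) -> dict:
    # Mirrors the policy-specific kwargs the script adds when CPU offload is on.
    if policy is Policy.QUEST:
        return {"sparse_topk_blocks": topk, "sparse_threshold_blocks": threshold}
    if policy is Policy.XATTN_BSA:
        # The integer CLI value 0-9 is mapped onto the 0.0-1.0 range.
        return {"sparse_threshold": threshold / 10.0, "sparse_samples_per_chunk": samples}
    return {}
```

Note that the same `--sparse-threshold` value means a block count for Quest and a fraction for XAttention BSA, so a value tuned for one policy should not be reused for the other.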
@@ -1,176 +0,0 @@
"""
Needle-in-a-haystack reference test using pure torch + transformers.

This is a reference implementation for comparison with nanovllm.
Uses standard HuggingFace inference (no custom KV cache, no offload).
"""

import os
import argparse
import torch
from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM
from utils import generate_needle_prompt, check_needle_answer


# ============================================================
# Main Test
# ============================================================

def run_needle_test(
    model_path: str,
    input_len: int,
    needle_position: float = 0.5,
    needle_value: str = "7492",
    max_new_tokens: int = 32,
    dtype: str = "auto",
    verbose: bool = True,
) -> bool:
    """
    Run a needle-in-haystack test using standard transformers inference.

    Args:
        model_path: Path to model
        input_len: Target input sequence length
        needle_position: Where to place needle (0.0-1.0)
        needle_value: The secret value to find
        max_new_tokens: Maximum tokens to generate
        dtype: Model dtype ("auto", "float16", "bfloat16")
        verbose: Print detailed output

    Returns:
        True if test passed, False otherwise
    """
    if verbose:
        print(f"\n{'='*60}")
        print(f"Needle-in-Haystack Reference Test (torch + transformers)")
        print(f"{'='*60}")
        print(f"Model: {model_path}")
        print(f"Input length: {input_len}")
        print(f"Needle position: {needle_position:.0%}")
        print(f"Needle value: {needle_value}")
        print(f"Dtype: {dtype}")
        print(f"{'='*60}\n")

    # 1. Load tokenizer
    print("[1/4] Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # 2. Generate needle prompt
    print("[2/4] Generating needle prompt...")
    prompt, expected = generate_needle_prompt(
        tokenizer=tokenizer,
        target_length=input_len,
        needle_position=needle_position,
        needle_value=needle_value,
    )

    # 3. Load model
    print("[3/4] Loading model...")
    torch_dtype = {
        "auto": torch.float16,  # default to float16 for custom model
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }.get(dtype, torch.float16)

    model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch_dtype)
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    # 4. Generate output
    print("[4/4] Running inference...")
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    print(f" Input shape: {input_ids.shape}")

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.6,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the new tokens
    new_token_ids = output_ids[0, input_ids.shape[1]:]
    output_text = tokenizer.decode(new_token_ids, skip_special_tokens=False)

    # 5. Check result
    passed = check_needle_answer(output_text, expected)

    if verbose:
        print(f"\n{'='*60}")
        print(f"Result")
        print(f"{'='*60}")
        print(f"Expected: {expected}")
        print(f"Output tokens ({len(new_token_ids)}): {new_token_ids[:20].tolist()}")
        print(f"Output: {output_text[:200]}...")
        print(f"Status: {'PASSED' if passed else 'FAILED'}")
        print(f"{'='*60}\n")

    return passed


# ============================================================
# CLI Entry Point
# ============================================================

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Needle-in-haystack reference test (torch + transformers)"
    )
    parser.add_argument(
        "--model", "-m",
        type=str,
        default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
        help="Path to model"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=8 * 1024,
        help="Target input sequence length"
    )
    parser.add_argument(
        "--needle-position",
        type=float,
        default=0.5,
        help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
    )
    parser.add_argument(
        "--needle-value",
        type=str,
        default="7492",
        help="The secret value to hide"
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=32,
        help="Maximum tokens to generate"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "bfloat16"],
        help="Model dtype"
    )
    args = parser.parse_args()

    passed = run_needle_test(
        model_path=args.model,
        input_len=args.input_len,
        needle_position=args.needle_position,
        needle_value=args.needle_value,
        max_new_tokens=args.max_new_tokens,
        dtype=args.dtype,
        verbose=True,
    )

    if passed:
        print("test_needle_ref: PASSED")
    else:
        print("test_needle_ref: FAILED")
        exit(1)
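For decoder-only models, Hugging Face `generate` returns the prompt tokens followed by the newly generated ones, which is why the reference script slices at `input_ids.shape[1]` before decoding. The same slicing on plain Python lists is sketched below; the token ids are made up and `split_new_tokens` is a hypothetical helper used only for illustration.

```python
from typing import List


def split_new_tokens(output_ids: List[int], prompt_len: int) -> List[int]:
    """Return only the tokens generated after the prompt."""
    if prompt_len > len(output_ids):
        raise ValueError("prompt_len exceeds the length of output_ids")
    return output_ids[prompt_len:]


prompt_ids = [101, 7, 8, 9]             # illustrative prompt token ids
full_output = prompt_ids + [42, 43, 2]  # prompt echoed back, then new tokens
new_ids = split_new_tokens(full_output, len(prompt_ids))
print(new_ids)
```

Checking the needle against the sliced tokens matters here: the prompt itself contains the needle value, so decoding the full sequence would let the check pass trivially.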
@@ -1,136 +0,0 @@
"""
Test for QuestPolicy block selection with GQA (Grouped Query Attention).

Demonstrates the key limitation: scores are AVERAGED across heads,
so blocks strongly needed by one head but not others may be dropped.

This is the expected Quest behavior - not a bug.
"""

import torch
from nanovllm.kvcache.sparse import (
    create_sparse_policy,
    SparsePolicyType,
    PolicyContext,
)

# ============================================================
# Test: Per-Head Score Averaging in GQA
# ============================================================

# Determine device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running test on device: {device}")

# Setup: 2 KV heads, 4 query heads (GQA group_size=2)
# topk=2 to make selection competitive

quest = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=2, threshold_blocks=0)
quest.initialize(
    num_layers=1,
    num_kv_heads=2,
    head_dim=4,
    num_cpu_blocks=6,
    dtype=torch.float32,
    device=device,  # Metadata stored on the selected device (GPU if available, else CPU)
)

metadata = quest.metadata

def set_key(block_id, head_id, values):
    """Set both key_min and key_max to same values for deterministic scoring."""
    # Values need to be on the same device as metadata
    tensor = torch.tensor(values, device=device)
    metadata.key_min[block_id, 0, head_id, :] = tensor
    metadata.key_max[block_id, 0, head_id, :] = tensor

# ============================================================
# Design: Different heads want different blocks
# ============================================================
#
# Query = [1,1,1,1] for all heads, so score = sum(key values)
#
# Block | Head 0 | Head 1 | Average | Result
# ------|--------|--------|---------|--------
#   0   |   +4   |   -4   |    0    | Head0 wants, Head1 doesn't → DROPPED
#   1   |   -4   |   +4   |    0    | Head1 wants, Head0 doesn't → DROPPED
#   2   |   +4   |   +4   |   +4    | Both want → SELECTED (rank 1)
#   3   |   +3   |   +3   |   +3    | Both want → SELECTED (rank 2)
#   4   |   +4   |    0   |   +2    | Head0 strongly wants, Head1 neutral → rank 3
#   5   |    0   |   +4   |   +2    | Head1 strongly wants, Head0 neutral → rank 3

# Block 0: Head 0 strongly wants, Head 1 strongly rejects
set_key(0, 0, [1.0, 1.0, 1.0, 1.0])      # head0: +4
set_key(0, 1, [-1.0, -1.0, -1.0, -1.0])  # head1: -4

# Block 1: Head 1 strongly wants, Head 0 strongly rejects
set_key(1, 0, [-1.0, -1.0, -1.0, -1.0])  # head0: -4
set_key(1, 1, [1.0, 1.0, 1.0, 1.0])      # head1: +4

# Block 2: Both heads want equally (highest average)
set_key(2, 0, [1.0, 1.0, 1.0, 1.0])      # head0: +4
set_key(2, 1, [1.0, 1.0, 1.0, 1.0])      # head1: +4

# Block 3: Both heads want moderately
set_key(3, 0, [0.75, 0.75, 0.75, 0.75])  # head0: +3
set_key(3, 1, [0.75, 0.75, 0.75, 0.75])  # head1: +3

# Block 4: Head 0 strongly wants, Head 1 neutral
set_key(4, 0, [1.0, 1.0, 1.0, 1.0])      # head0: +4
set_key(4, 1, [0.0, 0.0, 0.0, 0.0])      # head1: 0

# Block 5: Head 1 strongly wants, Head 0 neutral
set_key(5, 0, [0.0, 0.0, 0.0, 0.0])      # head0: 0
set_key(5, 1, [1.0, 1.0, 1.0, 1.0])      # head1: +4

# ============================================================
# Run selection
# ============================================================

# Query on same device as metadata
query = torch.ones(1, 4, 4, device=device)  # GQA: 4 query heads → 2 KV heads

ctx = PolicyContext(
    query_chunk_idx=0,
    num_query_chunks=1,
    layer_id=0,
    query=query,
    is_prefill=False,
    block_size=1024,
    total_kv_len=6144,
)

available = list(range(6))
selected = quest.select_blocks(available, ctx)

# ============================================================
# Verify: Averaging behavior
# ============================================================

# topk=2, so only blocks 2 (+4 avg) and 3 (+3 avg) should be selected
assert len(selected) == 2, f"Expected 2 blocks, got {len(selected)}"
assert selected == [2, 3], f"Expected [2, 3], got {selected}"

# Key insight: blocks 0 and 1 have score +4 for ONE head,
# but they cancel out due to averaging with the other head's -4
assert 0 not in selected, "Block 0 should NOT be selected (head scores cancel out)"
assert 1 not in selected, "Block 1 should NOT be selected (head scores cancel out)"

# Blocks 4 and 5 have +4 for one head, 0 for other → avg=+2
# But +2 < +3 (block 3), so they don't make the top-2
assert 4 not in selected, "Block 4 avg=+2 < block 3 avg=+3"
assert 5 not in selected, "Block 5 avg=+2 < block 3 avg=+3"

print("✓ Block 2 selected: both heads want it (+4, +4) → avg=+4")
print("✓ Block 3 selected: both heads want it (+3, +3) → avg=+3")
print("✓ Block 0 NOT selected: head0=+4, head1=-4 → avg=0 (cancel out)")
print("✓ Block 1 NOT selected: head0=-4, head1=+4 → avg=0 (cancel out)")
print("✓ Block 4 NOT selected: head0=+4, head1=0 → avg=+2 (lower rank)")
print("✓ Block 5 NOT selected: head0=0, head1=+4 → avg=+2 (lower rank)")

# Verify metadata is on correct device
assert metadata.key_min.device.type == device.type, f"key_min on wrong device: {metadata.key_min.device}"
assert metadata.key_max.device.type == device.type, f"key_max on wrong device: {metadata.key_max.device}"
print(f"✓ Metadata stored on {device.type.upper()}")

print("\ntest_quest_policy: PASSED")
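The score table in the test above can be reproduced without `nanovllm` or a GPU. The sketch below assumes the per-block score is, for each KV head, the sum over dimensions of `max(q * key_min, q * key_max)`, then averaged across heads; that is the Quest-style upper bound the test comments describe, but the actual `QuestPolicy` may differ in details such as how query heads are grouped. `block_scores` and `top_k` are hypothetical names.

```python
def block_scores(query, key_min, key_max):
    """query: list of head_dim floats; key_min/key_max: {block: [vector per KV head]}."""
    scores = {}
    for block, mins in key_min.items():
        maxs = key_max[block]
        per_head = [
            sum(max(q * lo, q * hi) for q, lo, hi in zip(query, head_min, head_max))
            for head_min, head_max in zip(mins, maxs)
        ]
        # Averaging is what lets one head's +4 be cancelled by another head's -4.
        scores[block] = sum(per_head) / len(per_head)
    return scores


def top_k(scores, k):
    # Highest score first; ties broken by block id; result returned in block order.
    ranked = sorted(scores, key=lambda b: (-scores[b], b))
    return sorted(ranked[:k])


# Same key values as the test: key_min == key_max for every block and head.
keys = {
    0: [[1.0] * 4, [-1.0] * 4],
    1: [[-1.0] * 4, [1.0] * 4],
    2: [[1.0] * 4, [1.0] * 4],
    3: [[0.75] * 4, [0.75] * 4],
    4: [[1.0] * 4, [0.0] * 4],
    5: [[0.0] * 4, [1.0] * 4],
}
scores = block_scores([1.0] * 4, keys, keys)
selected = top_k(scores, 2)
```

Replacing the average with a per-head maximum would rank blocks 0, 1, 4 and 5 alongside block 2, which is one way to see that the dropped blocks are a consequence of the aggregation choice rather than of the key statistics.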
@@ -17,6 +17,15 @@ Usage:

     # Test all samples in all datasets
     python tests/test_ruler.py --enable-offload
+
+    # Test specific sample indices (comma-separated)
+    python tests/test_ruler.py --enable-offload --datasets niah_single_1 --sample-indices 28,33,40
+
+    # Single-sample mode: reinitialize LLM for each sample (avoids state leakage)
+    python tests/test_ruler.py --enable-offload --datasets niah_single_1 --fresh-llm
+
+    # JSON output mode for scripting
+    python tests/test_ruler.py --enable-offload --datasets niah_single_1 --json-output
 """

 import os
@@ -32,6 +41,7 @@ from pathlib import Path
 from typing import List, Dict, Tuple, Optional

 from nanovllm import LLM, SamplingParams
+from nanovllm.utils.density_observer import DensityObserver


 # ============================================================
@@ -39,11 +49,67 @@ from nanovllm import LLM, SamplingParams
 # ============================================================

 DEFAULT_DATA_DIR = Path(__file__).parent / "data/ruler_64k"
+
+
+# ============================================================
+# Chat Template Conversion
+# ============================================================
+
+def convert_llama_to_glm4_format(prompt: str) -> str:
+    """
+    Convert Llama 3 chat template format to GLM-4 format.
+
+    Llama 3 format:
+        <|begin_of_text|><|start_header_id|>user<|end_header_id|>
+
+        {user_content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+        {assistant_prefix}
+
+    GLM-4 format:
+        [gMASK]<sop><|user|>
+        {user_content}<|assistant|>
+        {assistant_prefix}
+    """
+    # Split into user content and assistant prefix
+    parts = prompt.split("<|eot_id|><|start_header_id|>assistant<|end_header_id|>")
+
+    # Extract user content (remove Llama header tokens)
+    user_content = parts[0]
+    user_content = user_content.replace("<|begin_of_text|>", "")
+    user_content = user_content.replace("<|start_header_id|>user<|end_header_id|>", "")
+    user_content = user_content.strip()
+
+    # Extract assistant prefix (if exists)
+    assistant_prefix = ""
+    if len(parts) > 1:
+        assistant_prefix = parts[1].replace("<|eot_id|>", "").strip()
+
+    # Apply GLM-4 format
+    glm_prompt = f"[gMASK]<sop><|user|>\n{user_content}<|assistant|>"
+    if assistant_prefix:
+        glm_prompt += f"\n{assistant_prefix}"
+
+    return glm_prompt
+
+
+def is_glm_model(model_path: str) -> bool:
+    """Check if the model is a GLM model based on config."""
+    from transformers import AutoConfig
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    return getattr(config, 'model_type', '') == 'chatglm'
+
+
+def convert_prompt_for_model(prompt: str, model_path: str) -> str:
+    """Convert prompt format based on model type."""
+    if is_glm_model(model_path):
+        return convert_llama_to_glm4_format(prompt)
+    return prompt  # Keep original format for Llama and other models
 DEFAULT_MODEL = os.path.expanduser("~/models/Llama-3.1-8B-Instruct")
 # Note: max_model_len must be > max_input_len to leave room for output tokens
 # 64k benchmark has inputs up to 65536 tokens, so we need 65536 + 128 = 65664
 DEFAULT_MAX_MODEL_LEN = 65664
-DEFAULT_MAX_NEW_TOKENS = 128  # Larger for multi-value tasks
+DEFAULT_MAX_NEW_TOKENS = 16  # Sufficient for NIAH single-value answers

 # Task categories for evaluation
 NIAH_TASKS = ["niah_single_1", "niah_single_2", "niah_single_3",
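The Llama 3 to GLM-4 prompt conversion added in the hunk above is plain string manipulation, so it can be tried without loading a model. The block below is a condensed re-implementation of the same steps under the name `llama3_to_glm4` (a hypothetical name, chosen to keep it distinct from the function in the diff), applied to a made-up prompt.

```python
ASSISTANT_MARK = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
USER_HEADER = "<|start_header_id|>user<|end_header_id|>"


def llama3_to_glm4(prompt: str) -> str:
    # Split the prompt at the point where the assistant turn starts.
    parts = prompt.split(ASSISTANT_MARK)
    # Strip the Llama 3 header tokens from the user turn.
    user = parts[0].replace("<|begin_of_text|>", "").replace(USER_HEADER, "").strip()
    # Anything after the assistant header is a prefix the model should continue.
    prefix = parts[1].replace("<|eot_id|>", "").strip() if len(parts) > 1 else ""
    glm = f"[gMASK]<sop><|user|>\n{user}<|assistant|>"
    return f"{glm}\n{prefix}" if prefix else glm


llama_prompt = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "What is the magic number?"
    "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    "The magic number is"
)
glm_prompt = llama3_to_glm4(llama_prompt)
print(glm_prompt)
```

A prompt without the assistant marker is treated entirely as user content, so the function degrades to wrapping the text in the GLM-4 user and assistant tags.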
@@ -150,17 +216,31 @@ def run_task_test(
     sample_indices: Optional[List[int]] = None,
     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
     verbose: bool = True,
+    llm_factory: Optional[callable] = None,
+    fresh_llm: bool = False,
+    model_path: Optional[str] = None,
 ) -> Dict:
     """
     Run test for a single RULER task.

+    Args:
+        llm: LLM instance (ignored if fresh_llm=True)
+        task_name: Name of the task to test
+        data_dir: Path to data directory
+        sample_indices: Optional list of specific sample indices to test
+        max_new_tokens: Maximum tokens to generate
+        verbose: Print detailed output
+        llm_factory: Callable to create LLM instance (required if fresh_llm=True)
+        fresh_llm: If True, reinitialize LLM for each sample (avoids state leakage)
+
     Returns dict with: task, correct, total, score, results
     """
     data_file = data_dir / task_name / "validation.jsonl"
     samples = load_samples(data_file, sample_indices)

     if verbose:
-        print(f"\n Testing {task_name}: {len(samples)} samples")
+        mode_str = " [fresh-llm mode]" if fresh_llm else ""
+        print(f"\n Testing {task_name}: {len(samples)} samples{mode_str}")

     sampling_params = SamplingParams(
         temperature=0.1,
@@ -171,13 +251,29 @@ def run_task_test(
     total_score = 0.0
     results = []

+    current_llm = llm
+
     for sample in samples:
         idx = sample.get("index", sample["_local_idx"])
         prompt = sample["input"]
+        # Convert prompt format for GLM models
+        if model_path:
+            prompt = convert_prompt_for_model(prompt, model_path)
         expected = sample["outputs"]

+        # Fresh LLM mode: reinitialize for each sample
+        if fresh_llm:
+            if llm_factory is None:
+                raise ValueError("llm_factory required when fresh_llm=True")
+            # Cleanup previous LLM
+            if current_llm is not None:
+                del current_llm
+                gc.collect()
+                torch.cuda.empty_cache()
+            current_llm = llm_factory()
+
         # Generate
-        outputs = llm.generate([prompt], sampling_params, use_tqdm=False)
+        outputs = current_llm.generate([prompt], sampling_params, use_tqdm=False)
         output_text = outputs[0]["text"]

         # Evaluate
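The fresh-LLM mode in the hunk above is a factory-per-sample pattern: drop the previous instance, collect garbage, then build a new one. A minimal sketch with a dummy class in place of `nanovllm.LLM` follows; `DummyLLM` and `run_samples` are hypothetical, and the `torch.cuda.empty_cache()` call from the diff is left out so the example runs without a GPU.

```python
import gc


class DummyLLM:
    """Stand-in for nanovllm.LLM that counts how many instances were built."""
    instances = 0

    def __init__(self):
        DummyLLM.instances += 1
        self.calls = 0

    def generate(self, prompts):
        self.calls += 1
        return [{"text": p.upper()} for p in prompts]


def run_samples(prompts, llm=None, llm_factory=None, fresh_llm=False):
    if fresh_llm and llm_factory is None:
        raise ValueError("llm_factory required when fresh_llm=True")
    current = llm
    outputs = []
    for prompt in prompts:
        if fresh_llm:
            current = None   # drop this function's reference to the old instance
            gc.collect()     # the diff additionally calls torch.cuda.empty_cache()
            current = llm_factory()
        outputs.append(current.generate([prompt])[0]["text"])
    return outputs


texts = run_samples(["a", "b", "c"], llm_factory=DummyLLM, fresh_llm=True)
built = DummyLLM.instances
```

One caveat that applies to the diff as well: `del current_llm` only removes the local name. If the caller still holds the original `llm` object, that first instance stays alive until the caller releases it, so its memory is not freed by the cleanup inside the loop.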
@@ -200,6 +296,12 @@ def run_task_test(
             out_preview = output_text[:50].replace('\n', ' ')
             print(f" [{idx:3d}] {status} (score={score:.2f}) exp={exp_preview}... | out={out_preview}...")

+    # Cleanup last LLM instance in fresh mode
+    if fresh_llm and current_llm is not None:
+        del current_llm
+        gc.collect()
+        torch.cuda.empty_cache()
+
     avg_score = total_score / len(samples) if samples else 0.0

     return {
@@ -217,19 +319,24 @@ def run_ruler_benchmark(
     data_dir: Path,
     datasets: Optional[List[str]] = None,
     num_samples: Optional[int] = None,
+    sample_indices: Optional[List[int]] = None,
     max_model_len: int = DEFAULT_MAX_MODEL_LEN,
     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
     enable_cpu_offload: bool = False,
     num_gpu_blocks: int = 4,
-    block_size: int = 1024,
+    block_size: int = 4096,
     num_kv_buffers: int = 4,
     gpu_utilization: float = 0.9,
     enforce_eager: bool = True,
     verbose: bool = True,
+    fresh_llm: bool = False,
+    json_output: bool = False,
     sparse_policy: Optional[str] = None,
     sparse_threshold: float = 0.9,
     sparse_samples: int = 128,
     sparse_block_size: int = 128,
+    sparse_stride: int = 8,
+    dtype: Optional[str] = None,
 ) -> Dict:
     """
     Run RULER benchmark on multiple tasks.
@@ -239,7 +346,9 @@ def run_ruler_benchmark(
         data_dir: Directory containing task subdirectories
         datasets: List of task names to test (None = all)
         num_samples: Number of samples per task (None = all)
-        ...other LLM config params...
+        sample_indices: Specific sample indices to test (overrides num_samples)
+        fresh_llm: If True, reinitialize LLM for each sample (avoids state leakage)
+        json_output: If True, output JSON results at the end
         sparse_policy: Sparse attention policy (FULL, QUEST, MINFERENCE, XATTN)

     Returns:
@@ -251,21 +360,39 @@ def run_ruler_benchmark(
     else:
         tasks = datasets

-    # Sample indices
-    sample_indices = list(range(num_samples)) if num_samples else None
-    print(f"\n{'='*60}")
-    print(f"RULER Benchmark")
-    print(f"{'='*60}")
-    print(f"Model: {model_path}")
-    print(f"Data dir: {data_dir}")
-    print(f"Tasks: {len(tasks)}")
-    print(f"Samples per task: {num_samples if num_samples else 'all'}")
-    print(f"CPU offload: {enable_cpu_offload}")
-    print(f"{'='*60}")
+    # Sample indices: explicit list takes precedence over num_samples
+    if sample_indices is not None:
+        indices = sample_indices
+    elif num_samples:
+        indices = list(range(num_samples))
+    else:
+        indices = None
+
+    samples_desc = str(sample_indices) if sample_indices else (str(num_samples) if num_samples else 'all')
+
+    if not json_output:
+        print(f"\n{'='*60}")
+        print(f"RULER Benchmark")
+        print(f"{'='*60}")
+        print(f"Model: {model_path}")
+        print(f"Data dir: {data_dir}")
+        print(f"Tasks: {len(tasks)}")
+        print(f"Samples: {samples_desc}")
+        print(f"CPU offload: {enable_cpu_offload}")
+        print(f"Fresh LLM mode: {fresh_llm}")
+        print(f"{'='*60}")

-    # Initialize LLM
-    print("\nInitializing LLM...")
+    # Enable DensityObserver for XAttention BSA
+    if sparse_policy and sparse_policy.upper() == "XATTN_BSA":
+        DensityObserver.enable()
+        DensityObserver.complete_reset()
+        # Set mode for correct density interpretation
+        DensityObserver.set_mode("offload" if enable_cpu_offload else "gpu_only")
+        if not json_output:
+            mode_str = "offload" if enable_cpu_offload else "gpu_only"
+            print(f"[DensityObserver] Enabled for XAttention BSA (mode: {mode_str})")

+    # LLM initialization kwargs
     llm_kwargs = {
         "max_model_len": max_model_len,
         "max_num_batched_tokens": max_model_len,
@@ -274,6 +401,8 @@ def run_ruler_benchmark(
|
|||||||
"kvcache_block_size": block_size,
|
"kvcache_block_size": block_size,
|
||||||
"enable_cpu_offload": enable_cpu_offload,
|
"enable_cpu_offload": enable_cpu_offload,
|
||||||
}
|
}
|
||||||
|
if dtype:
|
||||||
|
llm_kwargs["dtype"] = dtype
|
||||||
if enable_cpu_offload:
|
if enable_cpu_offload:
|
||||||
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
||||||
llm_kwargs["num_kv_buffers"] = num_kv_buffers
|
llm_kwargs["num_kv_buffers"] = num_kv_buffers
|
||||||
@@ -285,8 +414,18 @@ def run_ruler_benchmark(
|
|||||||
if sparse_policy_type == SparsePolicyType.XATTN_BSA:
|
if sparse_policy_type == SparsePolicyType.XATTN_BSA:
|
||||||
llm_kwargs["sparse_threshold"] = sparse_threshold
|
llm_kwargs["sparse_threshold"] = sparse_threshold
|
||||||
llm_kwargs["sparse_samples_per_chunk"] = sparse_samples
|
llm_kwargs["sparse_samples_per_chunk"] = sparse_samples
|
||||||
|
llm_kwargs["sparse_stride"] = sparse_stride
|
||||||
|
|
||||||
llm = LLM(model_path, **llm_kwargs)
|
# Factory function for fresh_llm mode
|
||||||
|
def create_llm():
|
||||||
|
return LLM(model_path, **llm_kwargs)
|
||||||
|
|
||||||
|
# Initialize LLM (only once if not fresh_llm mode)
|
||||||
|
llm = None
|
||||||
|
if not fresh_llm:
|
||||||
|
if not json_output:
|
||||||
|
print("\nInitializing LLM...")
|
||||||
|
llm = create_llm()
|
||||||
|
|
||||||
# Run tests
|
# Run tests
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -297,19 +436,23 @@ def run_ruler_benchmark(
|
|||||||
llm=llm,
|
llm=llm,
|
||||||
task_name=task_name,
|
task_name=task_name,
|
||||||
data_dir=data_dir,
|
data_dir=data_dir,
|
||||||
sample_indices=sample_indices,
|
sample_indices=indices,
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
verbose=verbose,
|
verbose=verbose and not json_output,
|
||||||
|
llm_factory=create_llm,
|
||||||
|
fresh_llm=fresh_llm,
|
||||||
|
model_path=model_path,
|
||||||
)
|
)
|
||||||
task_results.append(result)
|
task_results.append(result)
|
||||||
|
|
||||||
if verbose:
|
if verbose and not json_output:
|
||||||
print(f" -> {task_name}: {result['correct']}/{result['total']} "
|
print(f" -> {task_name}: {result['correct']}/{result['total']} "
|
||||||
f"({result['accuracy']*100:.1f}%) avg_score={result['avg_score']:.3f}")
|
f"({result['accuracy']*100:.1f}%) avg_score={result['avg_score']:.3f}")
|
||||||
|
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup (only if not fresh_llm mode, since fresh mode cleans up itself)
|
||||||
|
if llm is not None:
|
||||||
del llm
|
del llm
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@@ -320,7 +463,15 @@ def run_ruler_benchmark(
|
|||||||
overall_accuracy = total_correct / total_samples if total_samples > 0 else 0.0
|
overall_accuracy = total_correct / total_samples if total_samples > 0 else 0.0
|
||||||
avg_score = sum(r["avg_score"] for r in task_results) / len(task_results) if task_results else 0.0
|
avg_score = sum(r["avg_score"] for r in task_results) / len(task_results) if task_results else 0.0
|
||||||
|
|
||||||
|
# Collect failed samples
|
||||||
|
failed_samples = {}
|
||||||
|
for r in task_results:
|
||||||
|
failed = [res["index"] for res in r["results"] if not res["passed"]]
|
||||||
|
if failed:
|
||||||
|
failed_samples[r["task"]] = failed
|
||||||
|
|
||||||
# Print summary
|
# Print summary
|
||||||
|
if not json_output:
|
||||||
print(f"\n{'='*60}")
|
print(f"\n{'='*60}")
|
||||||
print(f"RULER Benchmark Results")
|
print(f"RULER Benchmark Results")
|
||||||
print(f"{'='*60}")
|
print(f"{'='*60}")
|
||||||
@@ -331,17 +482,42 @@ def run_ruler_benchmark(
|
|||||||
print(f"{'-'*54}")
|
print(f"{'-'*54}")
|
||||||
print(f"{'TOTAL':<20} {total_correct}/{total_samples:<7} {overall_accuracy*100:>6.1f}% {avg_score:.3f}")
|
print(f"{'TOTAL':<20} {total_correct}/{total_samples:<7} {overall_accuracy*100:>6.1f}% {avg_score:.3f}")
|
||||||
print(f"\nTime: {total_time:.1f}s")
|
print(f"\nTime: {total_time:.1f}s")
|
||||||
|
|
||||||
|
# Print DensityObserver summary if enabled
|
||||||
|
if sparse_policy and sparse_policy.upper() == "XATTN_BSA" and DensityObserver.is_enabled():
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print("Density Statistics (XAttention BSA)")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
DensityObserver.print_summary()
|
||||||
|
|
||||||
print(f"{'='*60}\n")
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
return {
|
results = {
|
||||||
"total_correct": total_correct,
|
"total_correct": total_correct,
|
||||||
"total_samples": total_samples,
|
"total_samples": total_samples,
|
||||||
"overall_accuracy": overall_accuracy,
|
"overall_accuracy": overall_accuracy,
|
||||||
"avg_score": avg_score,
|
"avg_score": avg_score,
|
||||||
"time": total_time,
|
"time": total_time,
|
||||||
"task_results": task_results,
|
"task_results": task_results,
|
||||||
|
"failed_samples": failed_samples,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# JSON output
|
||||||
|
if json_output:
|
||||||
|
json_results = {
|
||||||
|
"total_correct": total_correct,
|
||||||
|
"total_samples": total_samples,
|
||||||
|
"overall_accuracy": overall_accuracy,
|
||||||
|
"avg_score": avg_score,
|
||||||
|
"time": total_time,
|
||||||
|
"tasks": {r["task"]: {"correct": r["correct"], "total": r["total"], "accuracy": r["accuracy"]}
|
||||||
|
for r in task_results},
|
||||||
|
"failed_samples": failed_samples,
|
||||||
|
}
|
||||||
|
print(json.dumps(json_results, indent=2))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# CLI Entry Point
|
# CLI Entry Point
|
||||||
@@ -361,6 +537,8 @@ if __name__ == "__main__":
|
|||||||
help="Comma-separated list of datasets to test (default: all)")
|
help="Comma-separated list of datasets to test (default: all)")
|
||||||
parser.add_argument("--num-samples", type=int, default=0,
|
parser.add_argument("--num-samples", type=int, default=0,
|
||||||
help="Number of samples per dataset (default: 0 = all)")
|
help="Number of samples per dataset (default: 0 = all)")
|
||||||
|
parser.add_argument("--sample-indices", type=str, default="",
|
||||||
|
help="Comma-separated specific sample indices (e.g., 28,33,40)")
|
||||||
parser.add_argument("--max-model-len", type=int, default=DEFAULT_MAX_MODEL_LEN,
|
parser.add_argument("--max-model-len", type=int, default=DEFAULT_MAX_MODEL_LEN,
|
||||||
help=f"Maximum model context length (default: {DEFAULT_MAX_MODEL_LEN})")
|
help=f"Maximum model context length (default: {DEFAULT_MAX_MODEL_LEN})")
|
||||||
parser.add_argument("--max-new-tokens", type=int, default=DEFAULT_MAX_NEW_TOKENS,
|
parser.add_argument("--max-new-tokens", type=int, default=DEFAULT_MAX_NEW_TOKENS,
|
||||||
@@ -369,8 +547,8 @@ if __name__ == "__main__":
|
|||||||
help="Enable CPU offload mode")
|
help="Enable CPU offload mode")
|
||||||
parser.add_argument("--num-gpu-blocks", type=int, default=4,
|
parser.add_argument("--num-gpu-blocks", type=int, default=4,
|
||||||
help="Number of GPU blocks for CPU offload (default: 4)")
|
help="Number of GPU blocks for CPU offload (default: 4)")
|
||||||
parser.add_argument("--block-size", type=int, default=1024,
|
parser.add_argument("--block-size", type=int, default=4096,
|
||||||
help="KV cache block size (default: 1024)")
|
help="KV cache block size (default: 4096)")
|
||||||
parser.add_argument("--num-kv-buffers", type=int, default=4,
|
parser.add_argument("--num-kv-buffers", type=int, default=4,
|
||||||
help="Number of KV buffers for ring buffer (default: 4)")
|
help="Number of KV buffers for ring buffer (default: 4)")
|
||||||
parser.add_argument("--gpu-utilization", type=float, default=0.9,
|
parser.add_argument("--gpu-utilization", type=float, default=0.9,
|
||||||
@@ -379,6 +557,10 @@ if __name__ == "__main__":
|
|||||||
help="Enable CUDA graph")
|
help="Enable CUDA graph")
|
||||||
parser.add_argument("--quiet", "-q", action="store_true",
|
parser.add_argument("--quiet", "-q", action="store_true",
|
||||||
help="Quiet mode")
|
help="Quiet mode")
|
||||||
|
parser.add_argument("--fresh-llm", action="store_true",
|
||||||
|
help="Reinitialize LLM for each sample (avoids state leakage)")
|
||||||
|
parser.add_argument("--json-output", action="store_true",
|
||||||
|
help="Output results in JSON format")
|
||||||
parser.add_argument("--sparse-policy", type=str, default="",
|
parser.add_argument("--sparse-policy", type=str, default="",
|
||||||
help="Sparse attention policy (FULL, QUEST, XATTN_BSA)")
|
help="Sparse attention policy (FULL, QUEST, XATTN_BSA)")
|
||||||
# XAttention BSA specific parameters
|
# XAttention BSA specific parameters
|
||||||
@@ -388,6 +570,10 @@ if __name__ == "__main__":
|
|||||||
help="XAttention BSA: samples per chunk for estimation")
|
help="XAttention BSA: samples per chunk for estimation")
|
||||||
parser.add_argument("--sparse-block-size", type=int, default=128,
|
parser.add_argument("--sparse-block-size", type=int, default=128,
|
||||||
help="XAttention BSA: block size for estimation")
|
help="XAttention BSA: block size for estimation")
|
||||||
|
parser.add_argument("--sparse-stride", type=int, default=8,
|
||||||
|
help="XAttention BSA: stride for Q/K downsampling")
|
||||||
|
parser.add_argument("--dtype", type=str, default=None,
|
||||||
|
help="Model dtype (bfloat16, float16). Required for models with float32 default.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -395,6 +581,11 @@ if __name__ == "__main__":
|
|||||||
datasets = args.datasets.split(",") if args.datasets else None
|
datasets = args.datasets.split(",") if args.datasets else None
|
||||||
num_samples = args.num_samples if args.num_samples > 0 else None
|
num_samples = args.num_samples if args.num_samples > 0 else None
|
||||||
|
|
||||||
|
# Parse sample indices (takes precedence over num_samples)
|
||||||
|
sample_indices = None
|
||||||
|
if args.sample_indices:
|
||||||
|
sample_indices = [int(x.strip()) for x in args.sample_indices.split(",")]
|
||||||
|
|
||||||
# Parse sparse policy
|
# Parse sparse policy
|
||||||
sparse_policy_str = args.sparse_policy.upper() if args.sparse_policy else None
|
sparse_policy_str = args.sparse_policy.upper() if args.sparse_policy else None
|
||||||
|
|
||||||
@@ -403,6 +594,7 @@ if __name__ == "__main__":
|
|||||||
data_dir=Path(args.data_dir),
|
data_dir=Path(args.data_dir),
|
||||||
datasets=datasets,
|
datasets=datasets,
|
||||||
num_samples=num_samples,
|
num_samples=num_samples,
|
||||||
|
sample_indices=sample_indices,
|
||||||
max_model_len=args.max_model_len,
|
max_model_len=args.max_model_len,
|
||||||
max_new_tokens=args.max_new_tokens,
|
max_new_tokens=args.max_new_tokens,
|
||||||
enable_cpu_offload=args.enable_offload,
|
enable_cpu_offload=args.enable_offload,
|
||||||
@@ -412,13 +604,18 @@ if __name__ == "__main__":
|
|||||||
gpu_utilization=args.gpu_utilization,
|
gpu_utilization=args.gpu_utilization,
|
||||||
enforce_eager=not args.use_cuda_graph,
|
enforce_eager=not args.use_cuda_graph,
|
||||||
verbose=not args.quiet,
|
verbose=not args.quiet,
|
||||||
|
fresh_llm=args.fresh_llm,
|
||||||
|
json_output=args.json_output,
|
||||||
sparse_policy=sparse_policy_str,
|
sparse_policy=sparse_policy_str,
|
||||||
sparse_threshold=args.sparse_threshold,
|
sparse_threshold=args.sparse_threshold,
|
||||||
sparse_samples=args.sparse_samples,
|
sparse_samples=args.sparse_samples,
|
||||||
sparse_block_size=args.sparse_block_size,
|
sparse_block_size=args.sparse_block_size,
|
||||||
|
sparse_stride=args.sparse_stride,
|
||||||
|
dtype=args.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Exit code
|
# Exit code (skip for json output mode)
|
||||||
|
if not args.json_output:
|
||||||
if results["overall_accuracy"] >= 0.5:
|
if results["overall_accuracy"] >= 0.5:
|
||||||
print("test_ruler: PASSED")
|
print("test_ruler: PASSED")
|
||||||
else:
|
else:
|
||||||
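Review note: the index-selection rule introduced in this file (an explicit `sample_indices` list wins over `num_samples`, and when neither is given all samples are used) can be sketched in isolation. The helper names `parse_sample_indices` and `resolve_indices` are hypothetical and do not appear in the diff. Skipping empty tokens is also an assumption, since the CLI code in the diff calls `int()` on every comma-separated piece.

```python
from typing import List, Optional


def parse_sample_indices(raw: str) -> Optional[List[int]]:
    """Parse a string such as "28,33,40" into a list of ints.

    Returns None for an empty string. Empty tokens (for example a trailing
    comma) are skipped, which is an assumption beyond the code in the diff.
    """
    tokens = [tok.strip() for tok in raw.split(",")]
    indices = [int(tok) for tok in tokens if tok]
    return indices or None


def resolve_indices(
    sample_indices: Optional[List[int]],
    num_samples: Optional[int],
) -> Optional[List[int]]:
    """Explicit indices take precedence over num_samples; None means all."""
    if sample_indices is not None:
        return list(sample_indices)
    if num_samples:
        return list(range(num_samples))
    return None


if __name__ == "__main__":
    print(resolve_indices(parse_sample_indices("28,33,40"), 5))
    print(resolve_indices(parse_sample_indices(""), 3))
    print(resolve_indices(None, None))
```

Keeping the resolved list in a separate value, as the diff does with `indices`, leaves the original `sample_indices` argument available for the `samples_desc` summary line.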
|
|||||||
@@ -1,199 +0,0 @@
|
|||||||
"""
|
|
||||||
Sequential inference test for LLM.
|
|
||||||
|
|
||||||
Tests: After completing one prompt, the system can correctly handle
|
|
||||||
a second prompt with a clean state (first prompt's KV cache deallocated).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
from nanovllm import LLM, SamplingParams
|
|
||||||
from utils import generate_needle_prompt, check_needle_answer
|
|
||||||
|
|
||||||
|
|
||||||
def run_sequential_test(
|
|
||||||
model_path: str,
|
|
||||||
max_model_len: int,
|
|
||||||
input_len: int,
|
|
||||||
num_gpu_blocks: int = 4,
|
|
||||||
block_size: int = 1024,
|
|
||||||
enable_cpu_offload: bool = False,
|
|
||||||
verbose: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Run sequential inference test with two different prompts.
|
|
||||||
|
|
||||||
Each prompt has a different needle value. Both must be retrieved correctly.
|
|
||||||
"""
|
|
||||||
if verbose:
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"Sequential Inference Test")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print(f"Model: {model_path}")
|
|
||||||
print(f"Max model len: {max_model_len}")
|
|
||||||
print(f"Input length: {input_len}")
|
|
||||||
print(f"Block size: {block_size}")
|
|
||||||
print(f"CPU offload: {enable_cpu_offload}")
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
# Initialize LLM once
|
|
||||||
llm_kwargs = {
|
|
||||||
"enforce_eager": True,
|
|
||||||
"max_model_len": max_model_len,
|
|
||||||
"max_num_batched_tokens": max_model_len,
|
|
||||||
"enable_cpu_offload": enable_cpu_offload,
|
|
||||||
"kvcache_block_size": block_size,
|
|
||||||
}
|
|
||||||
if enable_cpu_offload:
|
|
||||||
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
|
||||||
|
|
||||||
llm = LLM(model_path, **llm_kwargs)
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
temperature=0.6,
|
|
||||||
max_tokens=32,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Test 1: First prompt with needle value "1234"
|
|
||||||
# ============================================================
|
|
||||||
needle_value_1 = "1234"
|
|
||||||
if verbose:
|
|
||||||
print(f"\n[Test 1] Generating prompt with needle value: {needle_value_1}")
|
|
||||||
|
|
||||||
prompt_1, expected_1 = generate_needle_prompt(
|
|
||||||
tokenizer=llm.tokenizer,
|
|
||||||
target_length=input_len,
|
|
||||||
needle_position=0.5,
|
|
||||||
needle_value=needle_value_1,
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs_1 = llm.generate([prompt_1], sampling_params, use_tqdm=True)
|
|
||||||
output_text_1 = outputs_1[0]["text"]
|
|
||||||
passed_1 = check_needle_answer(output_text_1, expected_1)
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f" Expected: {expected_1}")
|
|
||||||
print(f" Output: {output_text_1[:100]}...")
|
|
||||||
print(f" Status: {'PASSED' if passed_1 else 'FAILED'}")
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Test 2: Second prompt with needle value "5678"
|
|
||||||
# ============================================================
|
|
||||||
needle_value_2 = "5678"
|
|
||||||
if verbose:
|
|
||||||
print(f"\n[Test 2] Generating prompt with needle value: {needle_value_2}")
|
|
||||||
|
|
||||||
prompt_2, expected_2 = generate_needle_prompt(
|
|
||||||
tokenizer=llm.tokenizer,
|
|
||||||
target_length=input_len,
|
|
||||||
needle_position=0.5,
|
|
||||||
needle_value=needle_value_2,
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs_2 = llm.generate([prompt_2], sampling_params, use_tqdm=True)
|
|
||||||
output_text_2 = outputs_2[0]["text"]
|
|
||||||
passed_2 = check_needle_answer(output_text_2, expected_2)
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f" Expected: {expected_2}")
|
|
||||||
print(f" Output: {output_text_2[:100]}...")
|
|
||||||
print(f" Status: {'PASSED' if passed_2 else 'FAILED'}")
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Test 3: Third prompt with a new needle value ("9999") to ensure no cross-contamination
|
|
||||||
# ============================================================
|
|
||||||
needle_value_3 = "9999"
|
|
||||||
if verbose:
|
|
||||||
print(f"\n[Test 3] Generating prompt with needle value: {needle_value_3}")
|
|
||||||
|
|
||||||
prompt_3, expected_3 = generate_needle_prompt(
|
|
||||||
tokenizer=llm.tokenizer,
|
|
||||||
target_length=input_len,
|
|
||||||
needle_position=0.5,
|
|
||||||
needle_value=needle_value_3,
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs_3 = llm.generate([prompt_3], sampling_params, use_tqdm=True)
|
|
||||||
output_text_3 = outputs_3[0]["text"]
|
|
||||||
passed_3 = check_needle_answer(output_text_3, expected_3)
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f" Expected: {expected_3}")
|
|
||||||
print(f" Output: {output_text_3[:100]}...")
|
|
||||||
print(f" Status: {'PASSED' if passed_3 else 'FAILED'}")
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Summary
|
|
||||||
# ============================================================
|
|
||||||
all_passed = passed_1 and passed_2 and passed_3
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"Summary")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print(f"Test 1 (needle={needle_value_1}): {'PASSED' if passed_1 else 'FAILED'}")
|
|
||||||
print(f"Test 2 (needle={needle_value_2}): {'PASSED' if passed_2 else 'FAILED'}")
|
|
||||||
print(f"Test 3 (needle={needle_value_3}): {'PASSED' if passed_3 else 'FAILED'}")
|
|
||||||
print(f"Overall: {'PASSED' if all_passed else 'FAILED'}")
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
return all_passed
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Sequential inference test")
|
|
||||||
parser.add_argument(
|
|
||||||
"--model", "-m",
|
|
||||||
type=str,
|
|
||||||
default=os.path.expanduser("~/models/Qwen3-0.6B/"),
|
|
||||||
help="Path to model"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-model-len",
|
|
||||||
type=int,
|
|
||||||
default=36 * 1024,
|
|
||||||
help="Maximum model context length"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--input-len",
|
|
||||||
type=int,
|
|
||||||
default=8 * 1024,
|
|
||||||
help="Target input sequence length"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-gpu-blocks",
|
|
||||||
type=int,
|
|
||||||
default=2,
|
|
||||||
help="Number of GPU blocks for CPU offload"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--block-size",
|
|
||||||
type=int,
|
|
||||||
default=1024,
|
|
||||||
help="KV cache block size"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--enable-offload",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable CPU offload"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
passed = run_sequential_test(
|
|
||||||
model_path=args.model,
|
|
||||||
max_model_len=args.max_model_len,
|
|
||||||
input_len=args.input_len,
|
|
||||||
num_gpu_blocks=args.num_gpu_blocks,
|
|
||||||
block_size=args.block_size,
|
|
||||||
enable_cpu_offload=args.enable_offload,
|
|
||||||
verbose=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if passed:
|
|
||||||
print("test_sequential: PASSED")
|
|
||||||
else:
|
|
||||||
print("test_sequential: FAILED")
|
|
||||||
exit(1)
|
|
||||||
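Review note: the alignment test added in tests/test_xattn_estimate_alignment.py relies on per-chunk softmax statistics being mergeable without loss. Below is a minimal pure-Python sketch of that idea, assuming each chunk reports its maximum `m` and the sum `l` of `exp(score - m)`. The function names are hypothetical stand-ins, not the nanovllm kernels, and the sketch uses the natural exponential, whereas the log2(e) factor in the test's `scale` suggests the kernels work in base 2.

```python
import math
from typing import List, Sequence, Tuple


def partial_stats(scores: Sequence[float]) -> Tuple[float, float]:
    """Stage 1: per-chunk max m and l = sum(exp(s - m)). Chunk must be non-empty."""
    m = max(scores)
    l = sum(math.exp(s - m) for s in scores)
    return m, l


def merge_stats(stats: Sequence[Tuple[float, float]]) -> Tuple[float, float]:
    """Stage 2: merge per-chunk (m, l) pairs into global (m, l)."""
    m_global = max(m for m, _ in stats)
    # Rescale each partial sum from its local max to the global max.
    l_global = sum(l * math.exp(m - m_global) for m, l in stats)
    return m_global, l_global


def chunked_softmax(scores: Sequence[float], chunk_size: int) -> List[float]:
    """Stage 3: normalize every score with the merged global stats."""
    chunks = [scores[i:i + chunk_size] for i in range(0, len(scores), chunk_size)]
    m, l = merge_stats([partial_stats(c) for c in chunks])
    return [math.exp(s - m) / l for s in scores]


if __name__ == "__main__":
    probs = chunked_softmax([0.5, -1.0, 3.0, 2.0, 0.0], chunk_size=2)
    print(probs, sum(probs))
```

Because the merged `(m, l)` equals what a single pass over all scores would produce, the chunked result should match a direct softmax up to floating-point rounding, which is the property the test checks at the block-mask level.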
365
tests/test_xattn_estimate_alignment.py
Normal file
@@ -0,0 +1,365 @@
|
|||||||
|
"""
|
||||||
|
Test: verify that xattn_estimate is consistent with the KV chunking kernels
|
||||||
|
|
||||||
|
Using real KV cache data, compare:
|
||||||
|
1. xattn_estimate (high-level API)
|
||||||
|
2. Three-stage KV chunking (softmax_compute_partial_stats + merge + normalize)
|
||||||
|
|
||||||
|
Three-stage KV chunking flow:
|
||||||
|
1. softmax_compute_partial_stats: compute (m, l) for each KV chunk
|
||||||
|
2. merge_softmax_stats: merge the stats of all chunks on the host
|
||||||
|
3. softmax_normalize_and_block_sum: normalize using the global stats
|
||||||
|
|
||||||
|
Two data formats are supported:
|
||||||
|
1. Saved in offload mode: {"query", "key", "stride", "threshold", "density", "layer_id"}
|
||||||
|
2. Saved in GPU-only mode: {"Q", "K", "chunk_size", "block_size", "stride", "threshold", "mask", "attn_sums", ...}
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Use offload-mode data
|
||||||
|
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_xattn_estimate_alignment.py
|
||||||
|
|
||||||
|
# Use GPU-only-mode data
|
||||||
|
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_xattn_estimate_alignment.py --gpuonly
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
from nanovllm.ops.xattn import (
|
||||||
|
xattn_estimate,
|
||||||
|
flat_group_gemm_fuse_reshape,
|
||||||
|
softmax_compute_partial_stats,
|
||||||
|
softmax_normalize_and_block_sum,
|
||||||
|
merge_softmax_stats,
|
||||||
|
find_blocks_chunked,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Command-line arguments
|
||||||
|
# ============================================================
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--gpuonly", action="store_true", help="Use data saved in GPU-only mode")
|
||||||
|
parser.add_argument("--data-file", type=str, default=None, help="Path to the data file")
|
||||||
|
parser.add_argument("--chunk-size", type=int, default=None, help="Override CHUNK_SIZE (for testing different chunk sizes)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Parameter configuration
|
||||||
|
# ============================================================
|
||||||
|
if args.gpuonly:
|
||||||
|
DATA_FILE = args.data_file or "/home/zijie/Code/nano-vllm/results/mask_alignment/gpuonly_layer0.pt"
|
||||||
|
else:
|
||||||
|
DATA_FILE = args.data_file or "/home/zijie/Code/nano-vllm/results/kvcache/qkv_32485.pt"
|
||||||
|
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Step 1: Load real data
|
||||||
|
# ============================================================
|
||||||
|
print("=" * 60)
|
||||||
|
print("Step 1: Load real KV cache data")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
data = torch.load(DATA_FILE, map_location="cpu")
|
||||||
|
|
||||||
|
# Detect the data format and load
|
||||||
|
if "Q" in data:
|
||||||
|
# Format saved in GPU-only mode
|
||||||
|
print(f"[INFO] Detected GPU-only mode data format")
|
||||||
|
Q = data["Q"].to(device)
|
||||||
|
K = data["K"].to(device)
|
||||||
|
BSA_BLOCK_SIZE = data.get("block_size", 128)
|
||||||
|
CHUNK_SIZE = data.get("chunk_size", 4096)
|
||||||
|
STRIDE = data.get("stride", 8)
|
||||||
|
THRESHOLD = data.get("threshold", 0.9)
|
||||||
|
if isinstance(THRESHOLD, torch.Tensor):
|
||||||
|
THRESHOLD = THRESHOLD.item()
|
||||||
|
# GPU-only mode also saves mask and attn_sums, which can be used for verification
|
||||||
|
saved_mask = data.get("mask", None)
|
||||||
|
saved_attn_sums = data.get("attn_sums", None)
|
||||||
|
saved_density = None # GPU-only mode does not save density
|
||||||
|
layer_id = 0 # GPU-only mode saves only layer 0
|
||||||
|
else:
|
||||||
|
# Format saved in offload mode
|
||||||
|
print(f"[INFO] Detected offload mode data format")
|
||||||
|
Q = data["query"].to(device)
|
||||||
|
K = data["key"].to(device)
|
||||||
|
BSA_BLOCK_SIZE = 128
|
||||||
|
CHUNK_SIZE = 4096
|
||||||
|
STRIDE = data["stride"]
|
||||||
|
THRESHOLD = data["threshold"]
|
||||||
|
if isinstance(THRESHOLD, torch.Tensor):
|
||||||
|
THRESHOLD = THRESHOLD[0].item()
|
||||||
|
saved_mask = None
|
||||||
|
saved_attn_sums = None
|
||||||
|
saved_density = data.get("density", None)
|
||||||
|
layer_id = data.get("layer_id", 0)
|
||||||
|
|
||||||
|
batch_size, num_heads, seq_len, head_dim = Q.shape
|
||||||
|
|
||||||
|
# Command-line override of CHUNK_SIZE
|
||||||
|
if args.chunk_size is not None:
|
||||||
|
CHUNK_SIZE = args.chunk_size
|
||||||
|
print(f"[INFO] Using CHUNK_SIZE={CHUNK_SIZE} from the command line")
|
||||||
|
|
||||||
|
print(f"Q shape: {Q.shape}")
|
||||||
|
print(f"K shape: {K.shape}")
|
||||||
|
if saved_density is not None:
|
||||||
|
print(f"Data layer_id: {layer_id}, saved density: {saved_density:.4f}")
|
||||||
|
else:
|
||||||
|
print(f"Data layer_id: {layer_id}")
|
||||||
|
print(f"Parameters: STRIDE={STRIDE}, THRESHOLD={THRESHOLD}, CHUNK_SIZE={CHUNK_SIZE}, BSA_BLOCK_SIZE={BSA_BLOCK_SIZE}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Step 2: Use the xattn_estimate high-level API
|
||||||
|
# ============================================================
|
||||||
|
print("=" * 60)
|
||||||
|
print("Step 2: Call xattn_estimate (high-level API)")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
attn_sums_api, mask_api = xattn_estimate(
|
||||||
|
Q, K,
|
||||||
|
block_size=BSA_BLOCK_SIZE,
|
||||||
|
stride=STRIDE,
|
||||||
|
threshold=THRESHOLD,
|
||||||
|
chunk_size=CHUNK_SIZE,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Crop to the valid region
|
||||||
|
q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
|
||||||
|
k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
|
||||||
|
mask_api_valid = mask_api[:, :, :q_blocks, :k_blocks]
|
||||||
|
|
||||||
|
# Compute density (causal)
|
||||||
|
causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, device=device, dtype=torch.bool))
|
||||||
|
total_api = causal_mask.sum().item() * batch_size * num_heads
|
||||||
|
selected_api = (mask_api_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||||
|
density_api = selected_api / total_api
|
||||||
|
|
||||||
|
print(f"mask_api shape (padded): {mask_api.shape}")
|
||||||
|
print(f"mask_api_valid shape: {mask_api_valid.shape}")
|
||||||
|
print(f"[xattn_estimate] density: {density_api:.6f} (selected={selected_api}, total={total_api})")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Step 3: Three-stage KV chunking
|
||||||
|
# ============================================================
|
||||||
|
print("=" * 60)
|
||||||
|
print("Step 3: Three-stage KV chunking")
|
||||||
|
print("=" * 60)
|
||||||
|
print(" 1) Compute partial stats for each KV chunk")
|
||||||
|
print(" 2) Merge stats on the host")
|
||||||
|
print(" 3) Normalize with the global stats and compute block sums")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Compute padding parameters
|
||||||
|
k_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
|
||||||
|
q_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
|
||||||
|
q_chunk_num = (seq_len + q_num_to_pad) // CHUNK_SIZE
|
||||||
|
kv_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE
|
||||||
|
|
||||||
|
k_block_num = (seq_len + k_num_to_pad) // BSA_BLOCK_SIZE
|
||||||
|
q_block_num = (seq_len + q_num_to_pad) // BSA_BLOCK_SIZE
|
||||||
|
|
||||||
|
reshaped_chunk_size = CHUNK_SIZE // STRIDE
|
||||||
|
reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
|
||||||
|
k_reshaped_seq_len = (seq_len + k_num_to_pad) // STRIDE
|
||||||
|
k_reshaped_num_to_pad = k_num_to_pad // STRIDE
|
||||||
|
num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size
|
||||||
|
kv_reshaped_chunk_size = CHUNK_SIZE // STRIDE
|
||||||
|
|
||||||
|
print(f"seq_len: {seq_len}, q_chunk_num: {q_chunk_num}, kv_chunk_num: {kv_chunk_num}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Padding
|
||||||
|
if k_num_to_pad > 0:
|
||||||
|
K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0)
|
||||||
|
else:
|
||||||
|
K_padded = K
|
||||||
|
|
||||||
|
if q_num_to_pad > 0:
|
||||||
|
Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, q_num_to_pad), value=0)
|
||||||
|
else:
|
||||||
|
Q_padded = Q
|
||||||
|
|
||||||
|
# Softmax scale
|
||||||
|
norm = 1.0
|
||||||
|
scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm
|
||||||
|
|
||||||
|
simple_mask_list = []
|
||||||
|
|
||||||
|
for q_chunk_idx in range(q_chunk_num):
|
||||||
|
q_start = q_chunk_idx * reshaped_chunk_size * STRIDE
|
||||||
|
q_end = q_start + reshaped_chunk_size * STRIDE
|
||||||
|
Q_chunk = Q_padded[:, :, q_start:q_end, :]
|
||||||
|
|
||||||
|
chunk_start = (k_block_num - q_block_num) * reshaped_block_size + q_chunk_idx * reshaped_chunk_size
|
||||||
|
chunk_end = chunk_start + reshaped_chunk_size
|
||||||
|
|
||||||
|
# Stage 1: compute partial stats and raw scores for each KV chunk
|
||||||
|
m_chunks = []
|
||||||
|
l_chunks = []
|
||||||
|
attn_weights_chunks = []
|
||||||
|
|
||||||
|
for kv_chunk_idx in range(kv_chunk_num):
|
||||||
|
kv_start = kv_chunk_idx * CHUNK_SIZE
|
||||||
|
kv_end = kv_start + CHUNK_SIZE
|
||||||
|
K_chunk = K_padded[:, :, kv_start:kv_end, :]
|
||||||
|
|
||||||
|
# KV offset in reshaped space
|
||||||
|
kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
|
||||||
|
|
||||||
|
# Compute raw attention scores
|
||||||
|
attn_weights_kv = flat_group_gemm_fuse_reshape(
|
||||||
|
Q_chunk, K_chunk, STRIDE,
|
||||||
|
chunk_start=chunk_start,
|
||||||
|
chunk_end=chunk_end,
|
||||||
|
is_causal=False, # K is incomplete here, so causal masking cannot be applied at this point
|
||||||
|
)
|
||||||
|
attn_weights_chunks.append(attn_weights_kv)
|
||||||
|
|
||||||
|
# Compute partial stats (with causal mask)
|
||||||
|
m_partial, l_partial = softmax_compute_partial_stats(
|
||||||
|
attn_weights_kv,
|
||||||
|
reshaped_block_size,
|
||||||
|
min(4096, reshaped_block_size),
|
||||||
|
scale,
|
||||||
|
chunk_start=chunk_start,
|
||||||
|
kv_offset=kv_offset_reshaped,
|
||||||
|
is_causal=True,
|
||||||
|
)
|
||||||
|
m_chunks.append(m_partial)
|
||||||
|
l_chunks.append(l_partial)
|
||||||
|
|
||||||
|
# Stage 2: merge stats on the host
|
||||||
|
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
|
||||||
|
|
||||||
|
# Stage 3: normalize with the global stats and compute block sums
|
||||||
|
attn_sum_per_kv = []
|
||||||
|
for kv_chunk_idx, attn_weights_kv in enumerate(attn_weights_chunks):
|
||||||
|
kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
|
||||||
|
attn_sum_kv = softmax_normalize_and_block_sum(
|
||||||
|
attn_weights_kv,
|
||||||
|
m_global,
|
||||||
|
l_global,
|
||||||
|
reshaped_block_size,
|
||||||
|
min(4096, reshaped_block_size),
|
||||||
|
chunk_start=chunk_start,
|
||||||
|
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
|
||||||
|
scale=scale,
|
||||||
|
kv_offset=kv_offset_reshaped,
|
||||||
|
is_causal=True,
|
||||||
|
)
|
||||||
|
attn_sum_per_kv.append(attn_sum_kv)
|
||||||
|
|
||||||
|
# Concatenate the block sums from each KV chunk
|
||||||
|
attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
|
||||||
|
|
||||||
|
# Select blocks
|
||||||
|
simple_mask = find_blocks_chunked(
|
||||||
|
attn_sum_concat,
|
||||||
|
current_index=k_block_num - q_block_num + q_chunk_idx * num_blocks_per_chunk,
|
||||||
|
threshold=THRESHOLD,
|
||||||
|
num_to_choose=None,
|
||||||
|
decoding=False,
|
||||||
|
mode="prefill",
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
simple_mask_list.append(simple_mask)
|
||||||
|
|
||||||
|
print(f" Q chunk {q_chunk_idx}: merged {kv_chunk_num} KV chunks, attn_sum shape={attn_sum_concat.shape}")
|
||||||
|
|
||||||
|
mask_kv_chunking = torch.cat(simple_mask_list, dim=2)
|
||||||
|
|
||||||
|
# Apply the same causal-mask post-processing as xattn_estimate (xattn.py, lines 1300-1306)
|
||||||
|
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:] = torch.where(
|
||||||
|
torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=device), diagonal=0),
|
||||||
|
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:],
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mask_kv_chunking_valid = mask_kv_chunking[:, :, :q_blocks, :k_blocks]
|
||||||
|
selected_kv = (mask_kv_chunking_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||||
|
density_kv = selected_kv / total_api
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(f"[KV chunking] density: {density_kv:.6f} (selected={selected_kv}, total={total_api})")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Step 4: compare results
|
||||||
|
# ============================================================
|
||||||
|
print("=" * 60)
|
||||||
|
print("Step 4: compare results")
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
|
||||||
|
mask_total = mask_api_valid.numel()
|
||||||
|
mask_diff = (mask_api_valid != mask_kv_chunking_valid).sum().item()
|
||||||
|
|
||||||
|
print("| Method | density | Diff vs API | Mask diff |")
|
||||||
|
print("|------|---------|-------------|-----------|")
|
||||||
|
print(f"| xattn_estimate API | {density_api:.6f} | - | - |")
|
||||||
|
print(f"| KV chunking | {density_kv:.6f} | {abs(density_api - density_kv):.6f} | {100*mask_diff/mask_total:.4f}% |")
|
||||||
|
print()
|
||||||
|
|
||||||
|
passed = abs(density_api - density_kv) < 1e-6 and mask_diff / mask_total < 0.001
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Step 5: compare against the data saved by the GPU-only run (if any)
|
||||||
|
# ============================================================
|
||||||
|
if saved_mask is not None or saved_attn_sums is not None:
|
||||||
|
print("=" * 60)
|
||||||
|
print("Step 5: compare against the data saved by the GPU-only run")
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
|
||||||
|
if saved_mask is not None:
|
||||||
|
saved_mask_gpu = saved_mask.to(device)
|
||||||
|
# Compare masks
|
||||||
|
mask_saved_diff = (mask_api_valid != saved_mask_gpu).sum().item()
|
||||||
|
mask_saved_total = saved_mask_gpu.numel()
|
||||||
|
print(f"| xattn_estimate vs GPU-only saved mask | differing blocks: {mask_saved_diff} / {mask_saved_total} ({100*mask_saved_diff/mask_saved_total:.4f}%) |")
|
||||||
|
|
||||||
|
if mask_saved_diff == 0:
|
||||||
|
print("✅ mask matches the GPU-only saved mask exactly")
|
||||||
|
else:
|
||||||
|
print("❌ mask differs from the GPU-only saved mask")
|
||||||
|
passed = False
|
||||||
|
|
||||||
|
if saved_attn_sums is not None:
|
||||||
|
saved_attn_sums_gpu = saved_attn_sums.to(device)
|
||||||
|
# attn_sums has to be obtained from xattn_estimate
|
||||||
|
# Call it once more to get attn_sums
|
||||||
|
attn_sums_check, _ = xattn_estimate(
|
||||||
|
Q, K,
|
||||||
|
block_size=BSA_BLOCK_SIZE,
|
||||||
|
stride=STRIDE,
|
||||||
|
threshold=THRESHOLD,
|
||||||
|
chunk_size=CHUNK_SIZE,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
attn_sums_check_valid = attn_sums_check[:, :, :q_blocks, :k_blocks]
|
||||||
|
|
||||||
|
max_diff = (attn_sums_check_valid - saved_attn_sums_gpu).abs().max().item()
|
||||||
|
mean_diff = (attn_sums_check_valid - saved_attn_sums_gpu).abs().mean().item()
|
||||||
|
print(f"| xattn_estimate vs GPU-only saved attn_sums | max diff: {max_diff:.6e}, mean diff: {mean_diff:.6e} |")
|
||||||
|
|
||||||
|
if max_diff < 1e-5:
|
||||||
|
print("✅ attn_sums matches the GPU-only saved data")
|
||||||
|
else:
|
||||||
|
print("❌ attn_sums differs from the GPU-only saved data")
|
||||||
|
passed = False
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
print("test_xattn_estimate_alignment: PASSED")
|
||||||
|
else:
|
||||||
|
print("test_xattn_estimate_alignment: FAILED")
|
||||||
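The three phases in the test above (per-chunk partial stats, a host-side merge, then normalization with the global stats) rely on the usual identity for combining softmax statistics across chunks. The sketch below illustrates that identity on a single row of scores in plain Python. It is not the implementation of `softmax_compute_partial_stats` or `merge_softmax_stats` from this branch: the helper names `partial_stats`, `merge_stats` and `normalize` are hypothetical, and scaling, causal masking and block sums are left out.

```python
import math

def partial_stats(scores):
    """Per-chunk stats: m = max score, l = sum(exp(s - m)) over the chunk."""
    m = max(scores)
    l = sum(math.exp(s - m) for s in scores)
    return m, l

def merge_stats(m_chunks, l_chunks):
    """Combine per-chunk (m, l) pairs into global softmax statistics."""
    m_global = max(m_chunks)
    # Rescale each partial sum from its local max to the global max.
    l_global = sum(l * math.exp(m - m_global) for m, l in zip(m_chunks, l_chunks))
    return m_global, l_global

def normalize(scores, m_global, l_global):
    """Softmax probabilities for one chunk, using the global stats."""
    return [math.exp(s - m_global) / l_global for s in scores]

# One row of raw scores, split into three KV chunks (illustrative values).
row = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25]
chunks = [row[0:2], row[2:4], row[4:6]]

# Phase 1: partial stats per chunk.
stats = [partial_stats(c) for c in chunks]
# Phase 2: merge on the host.
m_global, l_global = merge_stats([m for m, _ in stats], [l for _, l in stats])
# Phase 3: normalize every chunk with the global stats.
chunked = [p for c in chunks for p in normalize(c, m_global, l_global)]

# Reference: softmax over the whole row at once.
m_ref, l_ref = partial_stats(row)
reference = normalize(row, m_ref, l_ref)
```

Here `chunked` and `reference` agree up to floating-point rounding, which is the property the density and mask comparison in Step 4 depends on. A real kernel would also need to handle chunks that are fully masked by the causal mask, which this sketch does not cover.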