diff --git a/docs/chunked_prefill_integration_plan.md b/docs/chunked_prefill_integration_plan.md new file mode 100644 index 0000000..04d9b7d --- /dev/null +++ b/docs/chunked_prefill_integration_plan.md @@ -0,0 +1,354 @@ +# Chunked Prefill 集成计划 + +**目标**: 将 tzj/minference 分支的 chunked prefill 机制移植到 tzj/vs_offload 分支 + +**创建日期**: 2026-01-18 +**基础分支**: `tzj/vs_offload` +**源分支**: `tzj/minference` + +--- + +## 目标 + +在 tzj/vs_offload 分支上实现 chunked prefill + layerwise offload 机制,支持在 24GB RTX 3090 上运行任意长度的推理(4M, 8M, 16M+ tokens)。 + +--- + +## 核心问题 + +### tzj/vs_offload 分支的局限性 + +当前 tzj/vs_offload 分支的 GPU ring buffer 按 `max_seq_len` 分配,导致 GPU 内存随序列长度线性增长: + +```python +# 当前设计 +self.layer_k_cache = torch.zeros( + num_kv_buffers, # e.g., 4 + max_seq_len, # e.g., 131072 tokens + kv_heads, + head_dim, + dtype=dtype, device="cuda" +) +``` + +**问题**: +- GPU 内存需求 ~ `max_seq_len × 4 × 8 × 128 × 2 bytes` +- 对于超长序列不可行: + - 4M tokens → ~64 GB GPU 内存 ❌ + - 8M tokens → ~128 GB GPU 内存 ❌ + +### 解决方案:Block-Based 设计 + +tzj/minference 分支采用 block-based 设计,GPU 内存固定: + +```python +# Block-based 设计 +self.k_cache_gpu = torch.zeros( + num_gpu_blocks, # e.g., 2 + block_size, # e.g., 1024 tokens (固定!) + kv_heads, + head_dim, + dtype=dtype, device="cuda" +) +# GPU 内存: ~4 MB (固定,不随序列长度增长) +``` + +**优势**: +- GPU 内存固定(~1.6 GB),不随序列长度增长 +- 24GB RTX 3090 可运行 4M+ tokens +- 通过 chunked prefill 分块处理超长序列 + +--- + +## 内存布局对比 + +| 组件 | tzj/vs_offload | tzj/minference | 说明 | +|------|---------------|----------------|------| +| **GPU Ring Buffer** | `[num_kv_buffers, max_seq_len, ...]` | `[num_gpu_blocks, block_size, ...]` | minference 无 layer 维度 | +| **GPU 内存** | ~2.15 GB (128K) → ~64 GB (4M) | ~4 MB (固定) | minference 节省显著 | +| **Prefill Buffer** | ❌ 无 | ✅ `[num_layers, block_size, ...]` | minference 独有 | +| **Pipeline Buffers** | ❌ 无 | ✅ 双缓冲区 `[blocks, block_size, ...]` | minference 独有 | +| **CPU Cache** | `[num_layers, num_cpu_blocks, block_size, ...]` | 相同 | **一致** | + +### 序列长度支持对比 + +| 序列长度 | vs_offload GPU 内存 | minference GPU 内存 | RTX 3090 (24GB) | +|----------|-------------------|---------------------|-----------------| +| 128K tokens | ~2.15 GB | ~4 MB | ✅ 两者均可 | +| 1M tokens | ~16 GB | ~4 MB | ✅ 两者均可 | +| **4M tokens** | **~64 GB** ❌ | **~4 MB** ✅ | **仅 minference 可行** | +| **8M tokens** | **~128 GB** ❌ | **~4 MB** ✅ | **仅 minference 可行** | +| **16M+ tokens** | **~256 GB+** ❌ | **~4 MB** ✅ | **仅 minference 可行** | + +--- + +## 关键设计原则 + +1. **Block-Based 设计**:按 `block_size` (1024 tokens) 组织,支持 chunked prefill +2. **GPU 内存固定**:不随序列长度增长,是 constant factor +3. **CPU 内存线性缩放**:`num_cpu_blocks = ceil(seq_len / block_size)` +4. **Unified Ring Buffer**:无 layer 维度,所有层共享 slots +5. **完全并行 offload**:per-layer buffer 最大化 PCIe 带宽 + +--- + +## 统一内存布局设计 + +### GPU Memory Layout + +```python +class OffloadEngine: + # 1. Unified Ring Buffer - Block-based,无 layer 维度 + self.k_cache_gpu = torch.zeros( + num_gpu_blocks, # e.g., 2 + block_size, # e.g., 1024 + kv_heads, + head_dim, + dtype=dtype, device="cuda" + ) # ~4 MB (固定) + + # 2. Per-layer Prefill Buffer - 完全并行 offload + self.prefill_k_buffer = torch.zeros( + num_layers, block_size, kv_heads, head_dim, + dtype=dtype, device="cuda" + ) # ~58 MB (固定) + + # 3. Cross-layer Pipeline Buffers - Double-buffering + self.layer_k_buffer_a = torch.zeros( + max_prefill_blocks, block_size, kv_heads, head_dim, + dtype=dtype, device="cuda" + ) # ~512 MB (固定) + self.layer_k_buffer_b = torch.zeros(...) # ~512 MB (固定) + + # 4. Per-layer Decode Buffer + self.decode_k_buffer = torch.zeros( + num_layers, block_size, kv_heads, head_dim, + dtype=dtype, device="cuda" + ) # ~58 MB (固定) + + # GPU 总计:~1.6 GB (固定,不随序列长度增长) +``` + +### CPU Memory Layout + +```python + # CPU Cache - 有 block 维度 + self.k_cache_cpu = torch.zeros( + num_layers, + num_cpu_blocks, # 随序列长度缩放 + block_size, + kv_heads, + head_dim, + dtype=dtype, device="cpu", pin_memory=True + ) + # 128K tokens: ~2.9 GB + # 1M tokens: ~5.8 GB + # 4M tokens: ~23.3 GB +``` + +--- + +## Chunked Prefill 流程 + +### Prefill 阶段 + +``` +For each chunk: +├── 1. Prepare chunk input (block_size tokens) +├── 2. Get ring buffer slot: slot = chunk_idx % num_gpu_blocks +├── 3. Load previous KV chunks to ring slots[1..N-1] +├── 4. Model Forward (all layers) +│ For each layer: +│ ├── Load previous KV from ring slots +│ ├── Compute attention (current chunk + previous) +│ ├── Write KV to prefill_buffer[layer_id] ← Per-layer! +│ └── Async offload to CPU (parallel across layers) +├── 5. Merge attention outputs (LSE) +└── 6. Record compute done for slot + +Key: Per-layer prefill buffer → Layer 0 offload || Layer 1 compute || Layer 2 load ... +``` + +### Decode 阶段 + +``` +├── 1. Setup pipeline: preload Layer 0 to buffer_a +├── 2. For each layer: +│ ├── Get KV from pipeline buffer (a or b) +│ ├── Trigger preload of next layer to other buffer +│ ├── Compute attention +│ └── Store to decode buffer +└── 3. End pipeline + +Key: Double-buffering → Layer N compute || Layer N+1 load +``` + +--- + +## 合并策略 + +### 基础分支选择:tzj/vs_offload + +**原因**: +1. 更完善的文档系统 +2. 更完整的 sparse attention 实现(QUEST, XAttention 等) +3. 更清晰的代码组织和注释 +4. 更活跃的开发维护 + +### 移植策略 + +**从 tzj/minference 移植**: +1. GPU cache 内存布局(无 layer 维度,block-based) +2. Per-layer prefill buffer +3. Cross-layer pipeline buffers +4. Chunked prefill 流程 +5. LSE 在线合并机制 + +**保留 tzj/vs_offload 优势**: +1. 文档系统 +2. Sparse policy 架构 +3. 代码组织和注释 + +--- + +## Sparse Policy 策略 + +**策略**:保留架构,现阶段仅实现 FULL + +- **保留** sparse policy 的架构设计和接口 +- **预留** 扩展接口给未来的 QUEST 等其他策略 +- **现阶段仅实现** FULL 策略,确保正确性和稳定性 + +### 实现 + +```python +class SparsePolicy(ABC): + @property + def supports_prefill(self) -> bool: + return False + + @property + def supports_decode(self) -> bool: + return True + + def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens): + """预留给未来策略(如 QUEST 收集元数据)""" + pass + + def select_blocks(self, available_blocks, context) -> List[int]: + """FULL: 返回所有可用块""" + return available_blocks + +class FullAttentionPolicy(SparsePolicy): + @property + def supports_prefill(self) -> bool: + return True + + @property + def supports_decode(self) -> bool: + return True +``` + +--- + +## 关键 API + +### Ring Buffer 管理 + +```python +# Prefill 阶段 +get_write_slot_for_prefill(chunk_idx) -> slot_idx +get_load_slots_for_prefill(write_slot_idx) -> [slot_ids] + +# Decode 阶段 +get_load_slots_for_decode() -> [slot_ids] (excludes decode_slot) +``` + +### Per-layer 操作 + +```python +# 加载 +load_to_slot_layer(slot_idx, layer_id, cpu_block_id) +wait_slot_layer(slot_idx) + +# Prefill buffer +get_prefill_buffer(layer_id) -> (k, v) +offload_prefill_buffer_async(layer_id, cpu_block_id, num_tokens) +wait_prefill_offload(layer_id) + +# Pipeline +start_decode_pipeline(cpu_block_ids) +get_decode_layer_kv(layer_id, num_blocks) -> (k, v) +end_decode_pipeline() +``` + +--- + +## 实施阶段 + +### Phase 1: 内存布局重构 +- 修改 GPU cache 为 unified ring buffer +- 添加 per-layer prefill buffer +- 添加 cross-layer pipeline buffers + +### Phase 2: API 实现 +- 实现 ring buffer slot 管理 API +- 实现 per-layer prefill offload API +- 实现 cross-layer pipeline API + +### Phase 3: 集成到 Attention Layer +- 修改 attention forward 流程 +- 集成 per-layer prefill buffer +- 集成 cross-layer pipeline + +### Phase 4: 集成到 Model Runner +- 实现 chunked prefill 流程 +- 集成 LSE 合并 +- 优化流水线 + +### Phase 5: Sparse Policy 集成(FULL) +- 设计统一的策略接口 +- 实现 FullAttentionPolicy +- 预留 QUEST 等未来策略的扩展接口 + +--- + +## 关键决策 + +1. **Block-Based 设计优先**:支持任意长度推理的核心 +2. **采用 tzj/minference 的内存布局**:GPU cache 无 layer 维度 + block-based +3. **以 tzj/vs_offload 为基础分支**:更好的文档和代码组织 +4. **分阶段合并策略**:降低复杂度,便于验证 +5. **Sparse Policy - FULL 优先**:保留架构,现阶段仅实现 FULL + +--- + +## 预期结果 + +### 内存使用(28层模型,block_size=1024) + +| 组件 | 内存 | +|------|------| +| GPU Unified Ring Buffer | ~4 MB | +| GPU Per-layer Prefill Buffer | ~58 MB | +| GPU Pipeline Buffers (×2) | ~1 GB | +| GPU Decode Buffer | ~58 MB | +| **GPU 总计** | **~1.6 GB (固定)** | +| CPU Cache (4M tokens) | ~23.3 GB | +| **总计 (4M tokens)** | **~24.9 GB** ✅ 适配 24GB RTX 3090 | + +### 性能支持 + +- ✅ 支持 4M, 8M, 16M+ tokens 的推理 +- ✅ GPU 内存固定,不随序列长度增长 +- ✅ 完全并行的 layerwise offload +- ✅ Cross-layer 流水线优化 + +--- + +## 参考 + +- **OffloadEngine**: `nanovllm/kvcache/offload_engine.py` +- **Attention Layer**: `nanovllm/layers/attention.py` +- **Model Runner**: `nanovllm/engine/model_runner.py` +- **Sparse Policy**: `nanovllm/kvcache/sparse/policy.py`