From 69b779e252846780b9bb4e44257ac3661e76bba7 Mon Sep 17 00:00:00 2001
From: Zijie Tian
Date: Thu, 22 Jan 2026 06:04:36 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9D=20docs:=20add=20layer=20offload=20?=
 =?UTF-8?q?planning=20notes=20and=20task=20plan?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add planning documents for the layer-wise offload implementation:
- notes.md: Implementation notes and findings
- task_plan.md: Detailed task breakdown and progress tracking

Co-Authored-By: Claude Opus 4.5
---
 notes.md     | 130 ++++++++++++
 task_plan.md | 549 +++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 679 insertions(+)
 create mode 100644 notes.md
 create mode 100644 task_plan.md

diff --git a/notes.md b/notes.md
new file mode 100644
index 0000000..284c5cc
--- /dev/null
+++ b/notes.md
@@ -0,0 +1,130 @@
+# Notes: SparsePolicy Refactoring Research
+
+## Sources
+
+### Source 1: tzj/minference branch - policy.py
+- Path: `nanovllm/kvcache/sparse/policy.py`
+- Key design points:
+  - The `PolicyContext` dataclass carries query_chunk_idx, num_query_chunks, layer_id, query, is_prefill, etc.
+  - `select_blocks()` takes an offload_engine parameter
+  - `compute_chunked_prefill()` and `compute_chunked_decode()` implement the complete attention flow
+  - The `on_prefill_offload()` / `on_decode_offload()` hooks collect metadata
+
+### Source 2: tzj/minference branch - full_policy.py
+- Path: `nanovllm/kvcache/sparse/full_policy.py`
+- Key implementation points:
+  - `compute_chunked_prefill()` loads blocks internally through a ring-buffer pipeline
+  - Uses `flash_attn_with_lse` and `merge_attention_outputs` to merge attention outputs across chunks
+  - `compute_chunked_decode()` handles the prefilled blocks plus the decode buffer
+
+### Source 3: tzj/layer-offload branch - model_runner.py
+- Path: `nanovllm/engine/model_runner.py`
+- Key design points:
+  - `run_layerwise_offload_prefill()` runs layer by layer, computing full attention for each layer
+  - `sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)` is the simple interface
+  - The FULL policy takes the else branch via `if sparse_prefill_policy is None`
+
+### Source 4: tzj/layer-offload branch - xattn.py
+- 
Path: `nanovllm/kvcache/sparse/xattn.py`
+- Key implementation points:
+  - `sparse_prefill_attention()` uses FlashAttention directly (a limitation of the chunked prefill architecture)
+  - The Triton kernels are kept for a future GPU-only mode
+
+## Synthesized Findings
+
+### Architecture Differences at a Glance
+
+| Aspect | Chunked Offload | Layerwise Offload |
+|------|-----------------|-------------------|
+| **Prefill flow** | chunk by chunk, across layers | layer by layer, full sequence |
+| **KV storage** | each chunk offloaded immediately | each layer offloaded after compute |
+| **Attention compute** | multiple passes + merge | one complete pass |
+| **Block loading** | history loaded from CPU | not needed, already on GPU |
+| **Policy responsibility** | full attention flow | attention kernel selection only |
+
+### What Layerwise Offload Simplifies
+
+1. **No block selection**: the whole layer's KV is on the GPU, so there is nothing to select
+2. **No offload_engine parameter**: the policy is not responsible for loading KV
+3. **No merge_attention_outputs**: attention is computed in a single pass
+4. **No offload hooks**: offloading is handled centrally in model_runner
+
+### Design Recommendations
+
+1. **Keep the interface small**: only `compute_prefill_attention()` and `compute_decode_attention()` are needed
+2. **FULL implements the methods too**: drop the `is None` check; every policy is called uniformly
+3. **Remove unnecessary parameters**: offload_engine, kvcache_manager, seq, etc. are not needed
+4. **Unify naming**: use `compute_*_attention` rather than `sparse_prefill_attention`
+
+## Code Examples
+
+### Current call site (model_runner.py:876-891)
+
+```python
+# Sparse or Full attention
+if self.sparse_prefill_policy is not None:
+    # MInference or other sparse prefill policy
+    attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
+        q, k, v, layer_id
+    )
+else:
+    # Full attention using FlashAttention
+    attn_output = flash_attn_varlen_func(
+        q, k, v, ...
+    )
+```
+
+### Proposed new call site
+
+```python
+# Every policy is called through the same interface
+attn_output = self.attention_policy.compute_prefill_attention(
+    q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
+)
+```
+
+## Questions Resolved
+
+- Q: Is PolicyContext still needed?
+- A: It can be simplified, because layerwise mode needs no chunk information
+
+- Q: How is the decode phase handled?
+- A: **Decode needs no policy!** The current `run_layerwise_offload_decode()` uses the standard `layer(positions, hidden_states, residual)` call, which goes through the Attention.forward() path
+
+- Q: Why doesn't decode need sparsity?
+- A: Because decode has only 1 query token per step, there is nothing to sparsify. KV is loaded from the ring buffer and used directly with flash_attn_with_kvcache
+
+## Key Insight
+
+**The policy design for Layerwise Offload should focus on prefill only**:
+
+```
+Prefill: needs a Policy
+- Attention is computed once over the whole sequence
+- Sparse attention methods apply (e.g., MInference's vertical+slash pattern)
+- The policy receives q, k, v, layer_id, softmax_scale
+
+Decode: needs no Policy
+- Only 1 query token per step
+- KV is loaded from the ring buffer
+- Standard flash_attn_with_kvcache is used
+```
+
+## Interface Comparison Summary
+
+| Aspect | tzj/minference | tzj/layer-offload (new design) |
+|------|----------------|---------------------------|
+| Class name | SparsePolicy | AttentionPolicy |
+| Prefill method | compute_chunked_prefill() | compute_attention() |
+| Decode method | compute_chunked_decode() | not needed (standard path) |
+| Needs offload_engine | yes | no |
+| Needs kvcache_manager | yes | no |
+| Needs seq | yes | no |
+| Supports FULL | yes | yes |
+
+## Migration Path
+
+1. Keep `SparsePolicy` as an alias of `AttentionPolicy`
+2. Keep `PolicyContext` for future extensions
+3. Keep the `select_blocks()` method signature (even though it is unused)
+4. Remove the `requires_block_selection` attribute (not needed)
diff --git a/task_plan.md b/task_plan.md
new file mode 100644
index 0000000..220fc1a
--- /dev/null
+++ b/task_plan.md
@@ -0,0 +1,549 @@
+# Task Plan: Refactor SparsePolicy for Layerwise Offload
+
+## Goal
+Refactor the SparsePolicy interface following the design pattern of the tzj/minference branch, so that every kind of attention can be abstracted as a policy and written to one uniform convention, adapted to the current layerwise offload architecture (the whole layer's KV lives on the GPU).
+
+## Background
+
+### The Two Offload Architectures Compared
+
+| Feature | tzj/minference (Chunked Offload) | tzj/layer-offload (Layerwise Offload) |
+|------|----------------------------------|---------------------------------------|
+| Granularity | one chunk at a time (block_size tokens) | one whole layer at a time (all tokens) |
+| KV location | historical chunks on CPU, must be loaded | whole layer's KV on GPU |
+| Policy entry points | `compute_chunked_prefill()/decode()` | `compute_prefill()/decode()` |
+| Needs offload_engine | yes (loads blocks) | no (KV already on GPU) |
+| Mask computation | `select_blocks()` returns block IDs | `estimate()` returns a sparse mask |
+
+### The tzj/minference Policy Interface
+
+```python
+class SparsePolicy(ABC):
+    supports_prefill: bool
+    supports_decode: 
bool
+
+    @abstractmethod
+    def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]
+
+    @abstractmethod
+    def compute_chunked_prefill(self, q, k, v, layer_id, ..., offload_engine, ...) -> Tensor
+
+    @abstractmethod
+    def compute_chunked_decode(self, q, layer_id, ..., offload_engine, ...) -> Tensor
+```
+
+### This Branch's Policy Interface (before the refactor)
+
+```python
+class SparsePolicy(ABC):
+    supports_prefill: bool
+    supports_decode: bool
+
+    @abstractmethod
+    def select_blocks(self, available_blocks, ctx) -> List[int]
+
+    def sparse_prefill_attention(self, q, k, v, layer_id) -> Tensor
+```
+
+## Phases
+
+- [x] Phase 1: Analyze the differences and design the new interface
+- [x] **Phase 0: Create the nanovllm.ops module** ✅ tests passing
+- [ ] Phase 2: Refactor the AttentionPolicy base class
+- [ ] Phase 3: Refactor FullAttentionPolicy
+- [ ] Phase 4: Refactor XAttentionPolicy (including the estimate method)
+- [ ] Phase 5: Update the model_runner call sites
+- [ ] Phase 6: Test and validate
+
+---
+
+## Phase 0: Create the nanovllm.ops Module
+
+### Goal
+Extract the ops module from the tzj/minference branch to provide the low-level operators behind XAttention estimate.
+
+### Steps
+
+1. **Create the directory structure**
+   ```
+   nanovllm/ops/
+   ├── __init__.py
+   ├── xattn.py              # xattn_estimate, xattn_estimate_chunked, Triton kernels
+   └── chunked_attention.py  # flash_attn_with_lse, merge_attention_outputs (reserve)
+   ```
+
+2. **Extract the files from tzj/minference**
+   ```bash
+   git show tzj/minference:nanovllm/ops/__init__.py > nanovllm/ops/__init__.py
+   git show tzj/minference:nanovllm/ops/xattn.py > nanovllm/ops/xattn.py
+   git show tzj/minference:nanovllm/ops/chunked_attention.py > nanovllm/ops/chunked_attention.py
+   ```
+
+3. **Cherry-pick the test file**
+   ```bash
+   git show tzj/minference:tests/test_xattn_estimate_chunked.py > tests/test_xattn_estimate_chunked.py
+   ```
+
+4. 
**Run the test to verify**
+   ```bash
+   CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/Worktree/nano-vllm:$PYTHONPATH \
+     python tests/test_xattn_estimate_chunked.py
+   ```
+
+### Contents of nanovllm/ops
+
+| File | Core function | Purpose |
+|------|----------|------|
+| `xattn.py` | `xattn_estimate()` | Standard XAttention estimation |
+| `xattn.py` | `xattn_estimate_chunked()` | Chunked prefill variant |
+| `xattn.py` | `flat_group_gemm_fuse_reshape()` | Triton kernel: fused reshape + GEMM |
+| `xattn.py` | `softmax_fuse_block_sum()` | Triton kernel: softmax + block sum |
+| `xattn.py` | `find_blocks_chunked()` | Block selection based on a threshold |
+| `chunked_attention.py` | `flash_attn_with_lse()` | Flash attention with LSE output |
+| `chunked_attention.py` | `merge_attention_outputs()` | Merge multiple attention chunks |
+
+### Relationship to the Policy
+
+```
+XAttentionPolicy.estimate()
+  └── calls nanovllm.ops.xattn.xattn_estimate()
+        ├── flat_group_gemm_fuse_reshape()  (Triton)
+        ├── softmax_fuse_block_sum()        (Triton)
+        └── find_blocks_chunked()
+```
+
+---
+
+## Key Questions
+
+1. **What does `select_blocks` become?**
+   - Renamed to `estimate()`: it computes the sparse mask
+   - For XAttention, this corresponds to COMPASS's `xattn_estimate()` function
+   - FullAttentionPolicy's `estimate()` returns None (meaning full attention)
+
+2. **How should the policy interface be shaped?**
+   - Prefill: `compute_prefill(q, k, v, layer_id, softmax_scale)`
+   - Decode: `compute_decode(q, k, v, layer_id, softmax_scale)`
+   - Estimate: `estimate(q, k, layer_id)` - computes the sparse mask
+
+3. 
**How is the FULL policy handled?**
+   - FULL also implements `compute_prefill/decode`, using FlashAttention
+   - Its `estimate()` returns None (meaning no sparsification)
+
+## Proposed New Interface
+
+```python
+from abc import ABC, abstractmethod
+from typing import Optional
+import torch
+
+
+class AttentionPolicy(ABC):
+    """Attention policy for layerwise offload mode.
+
+    All attention computation goes through a policy, Full as well as Sparse.
+    Both the prefill and decode phases are supported.
+    """
+
+    supports_prefill: bool = True
+    supports_decode: bool = True
+
+    def estimate(
+        self,
+        q: torch.Tensor,  # [seq_len, num_heads, head_dim]
+        k: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
+        layer_id: int,
+    ) -> Optional[torch.Tensor]:
+        """
+        Estimate the sparse attention mask.
+
+        Sparse policies (e.g., XAttention) compute which blocks to attend to.
+        The full policy returns None, meaning complete attention.
+
+        Corresponds to COMPASS's xattn_estimate() function.
+
+        Args:
+            q: Query tensor [seq_len, num_heads, head_dim]
+            k: Key tensor [seq_len, num_kv_heads, head_dim]
+            layer_id: Transformer layer index
+
+        Returns:
+            sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask, or None
+        """
+        return None  # default to full attention
+
+    @abstractmethod
+    def compute_prefill(
+        self,
+        q: torch.Tensor,  # [seq_len, num_heads, head_dim]
+        k: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
+        v: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
+        layer_id: int,
+        softmax_scale: float,
+    ) -> torch.Tensor:
+        """
+        Compute prefill attention.
+
+        The whole layer's KV is on the GPU, so attention is computed in one
+        pass. A policy may first call estimate() to obtain a sparse mask and
+        then apply block sparse attention.
+
+        Args:
+            q: Query tensor [seq_len, num_heads, head_dim]
+            k: Key tensor [seq_len, num_kv_heads, head_dim]
+            v: Value tensor [seq_len, num_kv_heads, head_dim]
+            layer_id: Transformer layer index
+            softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
+
+        Returns:
+            Attention output [seq_len, num_heads, head_dim]
+        """
+        pass
+
+    def compute_decode(
+        self,
+        q: torch.Tensor,  # [1, num_heads, head_dim]
+        k: torch.Tensor,  # [context_len+1, num_kv_heads, head_dim]
+        v: torch.Tensor,  # [context_len+1, 
num_kv_heads, head_dim]
+        layer_id: int,
+        softmax_scale: float,
+    ) -> torch.Tensor:
+        """
+        Compute decode attention.
+
+        KV comes from the ring buffer and contains the prefill tokens plus
+        the tokens decoded so far.
+
+        Args:
+            q: Query tensor [1, num_heads, head_dim]
+            k: Key tensor [context_len+1, num_kv_heads, head_dim]
+            v: Value tensor [context_len+1, num_kv_heads, head_dim]
+            layer_id: Transformer layer index
+            softmax_scale: Softmax scaling factor
+
+        Returns:
+            Attention output [1, num_heads, head_dim]
+        """
+        # Default implementation: FlashAttention
+        from flash_attn.flash_attn_interface import flash_attn_varlen_func
+
+        context_len = k.shape[0]
+        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
+        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
+
+        return flash_attn_varlen_func(
+            q, k, v,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=1,
+            max_seqlen_k=context_len,
+            softmax_scale=softmax_scale,
+            causal=False,
+        )
+
+    def reset(self) -> None:
+        """Reset policy state between sequences."""
+        pass
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
+
+# Keep the old name as an alias
+SparsePolicy = AttentionPolicy
+```
+
+## Implementation Plan
+
+### Phase 2: Refactor policy.py
+
+```python
+# nanovllm/kvcache/sparse/policy.py
+
+from abc import ABC, abstractmethod
+from typing import Optional
+import torch
+
+
+class AttentionPolicy(ABC):
+    """Base class for attention policies in layerwise offload mode."""
+
+    supports_prefill: bool = True
+    supports_decode: bool = True
+
+    def estimate(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        layer_id: int,
+    ) -> Optional[torch.Tensor]:
+        """
+        Estimate sparse attention mask.
+
+        For sparse policies (e.g., XAttention), computes block-level importance.
+        For full policy, returns None.
+
+        Corresponds to xattn_estimate() in COMPASS.
+
+        Returns:
+            sparse_mask: [num_heads, q_blocks, k_blocks] or None
+        """
+        return None
+
+    @abstractmethod
+    def compute_prefill(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer_id: int,
+        softmax_scale: float,
+    ) -> torch.Tensor:
+        """Compute prefill attention."""
+        pass
+
+    def compute_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer_id: int,
+        softmax_scale: float,
+    ) -> torch.Tensor:
+        """Compute decode attention (default: FlashAttention)."""
+        from flash_attn.flash_attn_interface import flash_attn_varlen_func
+
+        context_len = k.shape[0]
+        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
+        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
+
+        return flash_attn_varlen_func(
+            q, k, v,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=1,
+            max_seqlen_k=context_len,
+            softmax_scale=softmax_scale,
+            causal=False,
+        )
+
+    def reset(self) -> None:
+        pass
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
+
+# Backward compatibility alias
+SparsePolicy = AttentionPolicy
+```
+
+### Phase 3: Refactor FullAttentionPolicy
+
+```python
+# nanovllm/kvcache/sparse/full_policy.py
+
+import torch
+from .policy import AttentionPolicy
+
+
+class FullAttentionPolicy(AttentionPolicy):
+    """Full attention using FlashAttention (no sparsity)."""
+
+    supports_prefill = True
+    supports_decode = True
+
+    def estimate(self, q, k, layer_id):
+        """Full attention - no sparse mask needed."""
+        return None
+
+    def compute_prefill(self, q, k, v, layer_id, softmax_scale):
+        from flash_attn.flash_attn_interface import flash_attn_varlen_func
+
+        seq_len = q.shape[0]
+        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
+
+        return flash_attn_varlen_func(
+            q, k, v,
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=seq_len,
+            max_seqlen_k=seq_len,
+            softmax_scale=softmax_scale,
+            causal=True,
+        )
+
+    
def __repr__(self):
+        return "FullAttentionPolicy()"
+```
+
+### Phase 4: Refactor XAttentionPolicy
+
+```python
+# nanovllm/kvcache/sparse/xattn.py
+
+import torch
+from typing import Optional
+from .policy import AttentionPolicy
+
+
+class XAttentionPolicy(AttentionPolicy):
+    """
+    XAttention sparse prefill policy.
+
+    Uses chunked estimation to compute a sparse attention mask,
+    then applies block sparse attention.
+    """
+
+    supports_prefill = True
+    supports_decode = True
+
+    def __init__(
+        self,
+        stride: int = 8,
+        threshold: float = 0.9,
+        block_size: int = 128,
+        chunk_size: int = 16384,
+        use_triton: bool = True,
+    ):
+        self.stride = stride
+        self.threshold = threshold
+        self.block_size = block_size
+        self.chunk_size = chunk_size
+        self.use_triton = use_triton
+
+    def estimate(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        layer_id: int,
+    ) -> Optional[torch.Tensor]:
+        """
+        XAttention estimation (xattn_estimate).
+
+        Uses chunked GEMM + softmax to estimate block-level importance,
+        then selects important blocks based on a threshold.
+
+        Corresponds to COMPASS's xattn_estimate() function:
+        1. Pad inputs to chunk_size multiples
+        2. Reshape with stride
+        3. Compute QK^T in chunks (Triton)
+        4. Block-wise softmax + aggregation
+        5. Threshold-based selection
+
+        Args:
+            q: [seq_len, num_heads, head_dim]
+            k: [seq_len, num_kv_heads, head_dim]
+            layer_id: transformer layer index
+
+        Returns:
+            sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask
+                         or None (fall back to full attention)
+        """
+        # TODO: implement the real xattn_estimate
+        # For now return None, i.e. use full attention
+        return None
+
+    def compute_prefill(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer_id: int,
+        softmax_scale: float,
+    ) -> torch.Tensor:
+        """
+        Compute XAttention sparse prefill.
+
+        Flow:
+        1. Call estimate() to get the sparse mask
+        2. If the mask is None, use full attention
+        3. 
Otherwise, apply block sparse attention with the mask
+        """
+        # Step 1: Estimate sparse mask
+        sparse_mask = self.estimate(q, k, layer_id)
+
+        # Step 2: Compute attention
+        if sparse_mask is None:
+            # Fallback to full attention
+            from flash_attn.flash_attn_interface import flash_attn_varlen_func
+
+            seq_len = q.shape[0]
+            cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
+
+            return flash_attn_varlen_func(
+                q, k, v,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=seq_len,
+                max_seqlen_k=seq_len,
+                softmax_scale=softmax_scale,
+                causal=True,
+            )
+        else:
+            # Apply block sparse attention with the mask,
+            # e.g. block_sparse_attn_func(q, k, v, sparse_mask, block_size)
+            raise NotImplementedError("Block sparse attention not yet implemented")
+
+    def __repr__(self):
+        return (f"XAttentionPolicy("
+                f"stride={self.stride}, "
+                f"threshold={self.threshold}, "
+                f"block_size={self.block_size})")
+```
+
+### Phase 5: Update model_runner.py
+
+```python
+# model_runner.py - allocate_kv_cache()
+
+# Always create a policy (including FULL)
+from nanovllm.kvcache.sparse import create_attention_policy
+self.attention_policy = create_attention_policy(config.attention_policy, **policy_kwargs)
+logger.info(f"Attention policy: {self.attention_policy}")
+
+# run_layerwise_offload_prefill() and run_gpu_only_prefill()
+
+# Old code:
+if self.sparse_prefill_policy is not None:
+    attn_output = self.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
+else:
+    attn_output = flash_attn_varlen_func(...)
+
+# New code:
+attn_output = self.attention_policy.compute_prefill(
+    q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
+)
+```
+
+## Method Mapping
+
+| Old method | New method | Notes |
+|--------|--------|------|
+| `select_blocks()` | `estimate()` | Computes the sparse mask (maps to xattn_estimate) |
+| `sparse_prefill_attention()` | `compute_prefill()` | Prefill attention |
+| (none) | `compute_decode()` | Decode attention (default implementation) |
+| `on_prefill_offload()` | (removed) | Offloading is handled in model_runner |
+
+## Files to Modify
+
+| File | Changes |
+|------|---------|
+| `nanovllm/kvcache/sparse/policy.py` | New interface: estimate, compute_prefill, compute_decode |
+| `nanovllm/kvcache/sparse/full_policy.py` | Implement compute_prefill(); estimate() returns None |
+| `nanovllm/kvcache/sparse/xattn.py` | estimate() maps to xattn_estimate; compute_prefill() |
+| `nanovllm/kvcache/sparse/__init__.py` | Update the factory function |
+| `nanovllm/engine/model_runner.py` | Call attention_policy.compute_prefill() uniformly |
+| `nanovllm/config.py` | Optional: rename the config key |
+
+## Decisions Made
+
+1. **Method naming**: `compute_prefill` / `compute_decode` mirror the naming style of the chunked versions
+2. **estimate method**: replaces `select_blocks` and returns a sparse mask instead of block IDs
+3. **XAttention**: `estimate()` maps to COMPASS's `xattn_estimate()`
+4. **Full policy**: `estimate()` returns None, meaning complete attention
+5. **Default decode**: the base class supplies a default FlashAttention implementation
+
+## Errors Encountered
+- (none)
+
+## Status
+**Currently in Phase 1** - Analysis and interface design are complete; awaiting user confirmation before starting Phase 2
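Appendix (reviewer sketch, not part of the patch): the block sparse step that Phase 4 leaves as `NotImplementedError` can be validated against a dense, unfused reference before any fused kernel lands. The sketch below is a minimal illustration under stated assumptions — the name `block_sparse_attention_ref`, the single-KV-head simplification (num_kv_heads == num_heads), and the requirement that diagonal blocks are always selected are assumptions of this sketch, not decisions from the plan:

```python
import math

import torch


def block_sparse_attention_ref(q, k, v, sparse_mask, block_size, softmax_scale=None):
    """Dense reference for block sparse causal attention.

    Shapes follow the plan: q/k/v are [seq_len, num_heads, head_dim]
    (num_kv_heads == num_heads for simplicity); sparse_mask is
    [num_heads, q_blocks, k_blocks] with True meaning "attend to this block".
    Assumes the estimator always selects the diagonal blocks, so no query row
    ends up with every key masked out.
    """
    seq_len, num_heads, head_dim = q.shape
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)

    # Raw attention scores: [num_heads, seq_len, seq_len]
    scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * softmax_scale

    # Expand the block-level mask to token granularity, then crop any padding
    token_mask = sparse_mask.repeat_interleave(block_size, dim=1)
    token_mask = token_mask.repeat_interleave(block_size, dim=2)
    token_mask = token_mask[:, :seq_len, :seq_len]

    # Combine with the causal mask and mask out disallowed positions
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~(token_mask & causal), float("-inf"))

    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, v.float()).to(q.dtype)
```

With an all-True mask this reduces to plain causal attention, which gives FullAttentionPolicy a cheap correctness oracle for comparing any future fused `block_sparse_attn_func` against the dense path.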