# Task Plan: Refactor SparsePolicy for Layerwise Offload

## Goal

Refactor the SparsePolicy interface following the design pattern of the tzj/minference branch, so that every attention variant can be abstracted as a policy and written against a single, uniform spec. Adapt it to the current layerwise offload architecture, where each layer's entire KV resides on GPU.

## Background

### Comparing the Two Offload Architectures

| Feature | tzj/minference (Chunked Offload) | tzj/layer-offload (Layerwise Offload) |
|------|----------------------------------|---------------------------------------|
| Processing granularity | One chunk at a time (block_size tokens) | One whole layer at a time (all tokens) |
| KV location | Historical chunks on CPU, must be loaded | The whole layer's KV is on GPU |
| Policy entry points | `compute_chunked_prefill()/decode()` | `compute_prefill()/decode()` |
| Needs offload_engine | Yes (to load blocks) | No (KV already on GPU) |
| Mask computation | `select_blocks()` returns block IDs | `estimate()` returns a sparse mask |

### Policy Interface on tzj/minference

```python
class SparsePolicy(ABC):
    supports_prefill: bool
    supports_decode: bool

    @abstractmethod
    def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]

    @abstractmethod
    def compute_chunked_prefill(self, q, k, v, layer_id, ..., offload_engine, ...) -> Tensor

    @abstractmethod
    def compute_chunked_decode(self, q, layer_id, ..., offload_engine, ...) -> Tensor
```

### Policy Interface on the Current Branch (before refactor)

```python
class SparsePolicy(ABC):
    supports_prefill: bool
    supports_decode: bool

    @abstractmethod
    def select_blocks(self, available_blocks, ctx) -> List[int]

    def sparse_prefill_attention(self, q, k, v, layer_id) -> Tensor
```

## Phases

- [x] Phase 1: Analyze the differences and design the new interface
- [x] **Phase 0: Create the nanovllm.ops module** ✅ tests passing
- [ ] Phase 2: Refactor the AttentionPolicy base class
- [ ] Phase 3: Refactor FullAttentionPolicy
- [ ] Phase 4: Refactor XAttentionPolicy (including the estimate method)
- [ ] Phase 5: Update the model_runner call sites
- [ ] Phase 6: Test and verify

---

## Phase 0: Create the nanovllm.ops Module

### Goal

Extract the ops module from the tzj/minference branch to provide the low-level operators that XAttention estimate depends on.

### Steps

1. **Create the directory structure**

   ```
   nanovllm/ops/
   ├── __init__.py
   ├── xattn.py              # xattn_estimate, xattn_estimate_chunked, Triton kernels
   └── chunked_attention.py  # flash_attn_with_lse, merge_attention_outputs (backup)
   ```

2. **Extract the files from tzj/minference**

   ```bash
   git show tzj/minference:nanovllm/ops/__init__.py > nanovllm/ops/__init__.py
   git show tzj/minference:nanovllm/ops/xattn.py > nanovllm/ops/xattn.py
   git show tzj/minference:nanovllm/ops/chunked_attention.py > nanovllm/ops/chunked_attention.py
   ```

3. **Extract the test file**

   ```bash
   git show tzj/minference:tests/test_xattn_estimate_chunked.py > tests/test_xattn_estimate_chunked.py
   ```

4. **Run the test to verify**

   ```bash
   CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/Worktree/nano-vllm:$PYTHONPATH \
       python tests/test_xattn_estimate_chunked.py
   ```

### Contents of the nanovllm/ops Module

| File | Core function | Purpose |
|------|----------|------|
| `xattn.py` | `xattn_estimate()` | Standard XAttention estimation |
| `xattn.py` | `xattn_estimate_chunked()` | Chunked prefill variant |
| `xattn.py` | `flat_group_gemm_fuse_reshape()` | Triton kernel: fused reshape + GEMM |
| `xattn.py` | `softmax_fuse_block_sum()` | Triton kernel: softmax + block sum |
| `xattn.py` | `find_blocks_chunked()` | Block selection based on threshold |
| `chunked_attention.py` | `flash_attn_with_lse()` | Flash attention with LSE output |
| `chunked_attention.py` | `merge_attention_outputs()` | Merge multiple attention chunks |

### Relationship to the Policy

```
XAttentionPolicy.estimate()
└── calls nanovllm.ops.xattn.xattn_estimate()
    ├── flat_group_gemm_fuse_reshape()  (Triton)
    ├── softmax_fuse_block_sum()        (Triton)
    └── find_blocks_chunked()
```

---

## Key Questions

1. **What does `select_blocks` become?**
   - Renamed to `estimate()`: computes the sparse mask
   - For XAttention it corresponds to COMPASS's `xattn_estimate()` function
   - `FullAttentionPolicy.estimate()` returns None (meaning full attention)

2. **How should the policy interface be designed?**
   - Prefill: `compute_prefill(q, k, v, layer_id, softmax_scale)`
   - Decode: `compute_decode(q, k, v, layer_id, softmax_scale)`
   - Estimate: `estimate(q, k, layer_id)` - computes the sparse mask

3. **How is the FULL policy handled?**
   - FULL also implements `compute_prefill/decode`, using FlashAttention
   - Its `estimate()` returns None (meaning no sparsification)

## Proposed New Interface

```python
from abc import ABC, abstractmethod
from typing import Optional

import torch


class AttentionPolicy(ABC):
    """Attention policy for layerwise offload mode.

    All attention computation goes through a policy, including both Full
    and Sparse. Both the prefill and decode phases are supported.
    """

    supports_prefill: bool = True
    supports_decode: bool = True

    def estimate(
        self,
        q: torch.Tensor,  # [seq_len, num_heads, head_dim]
        k: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """Estimate the sparse attention mask.

        For a sparse policy (e.g. XAttention), computes which blocks need to
        be attended. For the full policy, returns None, meaning full
        attention. Corresponds to COMPASS's xattn_estimate() function.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask, or None
        """
        return None  # Default: full attention

    @abstractmethod
    def compute_prefill(
        self,
        q: torch.Tensor,  # [seq_len, num_heads, head_dim]
        k: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
        v: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """Compute prefill attention.

        The whole layer's KV is on GPU, so attention is computed in one shot.
        A policy may first call estimate() to obtain a sparse mask, then
        apply block sparse attention.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor (1/sqrt(head_dim))

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        pass

    def compute_decode(
        self,
        q: torch.Tensor,  # [1, num_heads, head_dim]
        k: torch.Tensor,  # [context_len+1, num_kv_heads, head_dim]
        v: torch.Tensor,  # [context_len+1, num_kv_heads, head_dim]
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """Compute decode attention.

        KV is served from the ring buffer and contains the prefill tokens
        plus the tokens decoded so far.

        Args:
            q: Query tensor [1, num_heads, head_dim]
            k: Key tensor [context_len+1, num_kv_heads, head_dim]
            v: Value tensor [context_len+1, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [1, num_heads, head_dim]
        """
        # Default implementation: FlashAttention
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        context_len = k.shape[0]
        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=1,
            max_seqlen_k=context_len,
            softmax_scale=softmax_scale,
            causal=False,
        )

    def reset(self) -> None:
        """Reset policy state between sequences."""
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


# Keep the old name as an alias
SparsePolicy = AttentionPolicy
```

## Implementation Plan

### Phase 2: Refactor policy.py

```python
# nanovllm/kvcache/sparse/policy.py
from abc import ABC, abstractmethod
from typing import Optional

import torch


class AttentionPolicy(ABC):
    """Base class for attention policies in layerwise offload mode."""

    supports_prefill: bool = True
    supports_decode: bool = True

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """Estimate sparse attention mask.

        For sparse policies (e.g., XAttention), computes block-level importance.
        For the full policy, returns None.
        Corresponds to xattn_estimate() in COMPASS.
        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] or None
        """
        return None

    @abstractmethod
    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """Compute prefill attention."""
        pass

    def compute_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """Compute decode attention (default: FlashAttention)."""
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        context_len = k.shape[0]
        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=1,
            max_seqlen_k=context_len,
            softmax_scale=softmax_scale,
            causal=False,
        )

    def reset(self) -> None:
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


# Backward compatibility alias
SparsePolicy = AttentionPolicy
```

### Phase 3: Refactor FullAttentionPolicy

```python
# nanovllm/kvcache/sparse/full_policy.py
import torch

from .policy import AttentionPolicy


class FullAttentionPolicy(AttentionPolicy):
    """Full attention using FlashAttention (no sparsity)."""

    supports_prefill = True
    supports_decode = True

    def estimate(self, q, k, layer_id):
        """Full attention - no sparse mask needed."""
        return None

    def compute_prefill(self, q, k, v, layer_id, softmax_scale):
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        seq_len = q.shape[0]
        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=seq_len,
            max_seqlen_k=seq_len,
            softmax_scale=softmax_scale,
            causal=True,
        )

    def __repr__(self):
        return "FullAttentionPolicy()"
```

### Phase 4: Refactor XAttentionPolicy

```python
# nanovllm/kvcache/sparse/xattn.py
import torch
from typing import Optional
from .policy import AttentionPolicy


class XAttentionPolicy(AttentionPolicy):
    """XAttention sparse prefill policy.

    Uses chunked estimation to compute a sparse attention mask, then applies
    block sparse attention.
    """

    supports_prefill = True
    supports_decode = True

    def __init__(
        self,
        stride: int = 8,
        threshold: float = 0.9,
        block_size: int = 128,
        chunk_size: int = 16384,
        use_triton: bool = True,
    ):
        self.stride = stride
        self.threshold = threshold
        self.block_size = block_size
        self.chunk_size = chunk_size
        self.use_triton = use_triton

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """XAttention estimation (xattn_estimate).

        Uses chunked GEMM + softmax to estimate block-level importance, then
        selects the important blocks by threshold.

        Corresponds to COMPASS's xattn_estimate() function:
        1. Pad inputs to chunk_size multiples
        2. Reshape with stride
        3. Compute QK^T in chunks (Triton)
        4. Block-wise softmax + aggregation
        5. Threshold-based selection

        Args:
            q: [seq_len, num_heads, head_dim]
            k: [seq_len, num_kv_heads, head_dim]
            layer_id: transformer layer index

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask,
            or None (fall back to full attention)
        """
        # TODO: implement the real xattn_estimate.
        # For now return None, which means full attention.
        return None

    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """Compute XAttention sparse prefill.

        Flow:
        1. Call estimate() to get the sparse mask
        2. If the mask is None, use full attention
        3. Otherwise, apply block sparse attention with the mask
        """
        # Step 1: Estimate the sparse mask
        sparse_mask = self.estimate(q, k, layer_id)

        # Step 2: Compute attention
        if sparse_mask is None:
            # Fall back to full attention
            from flash_attn.flash_attn_interface import flash_attn_varlen_func

            seq_len = q.shape[0]
            cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

            return flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=seq_len,
                max_seqlen_k=seq_len,
                softmax_scale=softmax_scale,
                causal=True,
            )
        else:
            # Apply block sparse attention with the mask,
            # e.g. block_sparse_attn_func(q, k, v, sparse_mask, block_size)
            raise NotImplementedError("Block sparse attention not yet implemented")

    def __repr__(self):
        return (f"XAttentionPolicy("
                f"stride={self.stride}, "
                f"threshold={self.threshold}, "
                f"block_size={self.block_size})")
```

### Phase 5: Update model_runner.py

```python
# model_runner.py - allocate_kv_cache()

# Always create a policy now (including FULL)
from nanovllm.kvcache.sparse import create_attention_policy

self.attention_policy = create_attention_policy(config.attention_policy, **policy_kwargs)
logger.info(f"Attention policy: {self.attention_policy}")


# run_layerwise_offload_prefill() and run_gpu_only_prefill()
# Old code:
if self.sparse_prefill_policy is not None:
    attn_output = self.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
else:
    attn_output = flash_attn_varlen_func(...)
# New code:
attn_output = self.attention_policy.compute_prefill(
    q, k, v, layer_id,
    softmax_scale=layer.self_attn.attn.scale,
)
```

## Method Mapping

| Old method | New method | Notes |
|--------|--------|------|
| `select_blocks()` | `estimate()` | Computes the sparse mask (maps to xattn_estimate) |
| `sparse_prefill_attention()` | `compute_prefill()` | Prefill attention |
| (none) | `compute_decode()` | Decode attention (default implementation) |
| `on_prefill_offload()` | (removed) | Offload is handled in model_runner |

## Files to Modify

| File | Changes |
|------|---------|
| `nanovllm/kvcache/sparse/policy.py` | New interface: estimate, compute_prefill, compute_decode |
| `nanovllm/kvcache/sparse/full_policy.py` | Implement compute_prefill(); estimate() returns None |
| `nanovllm/kvcache/sparse/xattn.py` | estimate() maps to xattn_estimate; compute_prefill() |
| `nanovllm/kvcache/sparse/__init__.py` | Update the factory function |
| `nanovllm/engine/model_runner.py` | Uniformly call attention_policy.compute_prefill() |
| `nanovllm/config.py` | Optional: rename the config option |

## Decisions Made

1. **Method naming**: `compute_prefill` / `compute_decode`, matching the naming style of the chunked versions
2. **estimate method**: replaces `select_blocks`; returns a sparse mask instead of block IDs
3. **XAttention**: `estimate()` corresponds to COMPASS's `xattn_estimate()`
4. **Full policy**: `estimate()` returns None, meaning full attention
5. **Default decode**: the base class provides a default FlashAttention implementation

## Errors Encountered

- (none)

## Status

**Currently in Phase 1** - analysis and interface design complete; awaiting user confirmation before moving to Phase 2
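
## Appendix: Factory Function Sketch

The plan updates the factory function in `nanovllm/kvcache/sparse/__init__.py` but does not spell it out; below is a minimal, hypothetical sketch of what `create_attention_policy(name, **policy_kwargs)` could look like. The registry keys (`"full"`, `"xattn"`) and the lightweight stand-in policy classes are assumptions made so the sketch is self-contained; the real classes come from `full_policy.py` and `xattn.py`.

```python
# Hypothetical factory sketch for nanovllm/kvcache/sparse/__init__.py.
# Stand-in policy classes keep this runnable without torch/flash_attn;
# the registry keys "full" and "xattn" are assumptions, not confirmed names.

class FullAttentionPolicy:
    def __repr__(self):
        return "FullAttentionPolicy()"


class XAttentionPolicy:
    def __init__(self, stride=8, threshold=0.9, block_size=128):
        self.stride = stride
        self.threshold = threshold
        self.block_size = block_size

    def __repr__(self):
        return (f"XAttentionPolicy(stride={self.stride}, "
                f"threshold={self.threshold}, block_size={self.block_size})")


# Map config names to policy classes; extend here when adding new policies.
_POLICY_REGISTRY = {
    "full": FullAttentionPolicy,
    "xattn": XAttentionPolicy,
}


def create_attention_policy(name, **policy_kwargs):
    """Instantiate an attention policy by name, forwarding extra kwargs."""
    try:
        cls = _POLICY_REGISTRY[name.lower()]
    except KeyError:
        raise ValueError(f"unknown attention policy: {name!r}") from None
    return cls(**policy_kwargs)
```

This keeps `model_runner.allocate_kv_cache()` free of per-policy branching: it always calls the factory, and the FULL case is just another registry entry rather than a `None` policy.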