nano-vllm/task_plan.md
Zijie Tian 69b779e252 📝 docs: add layer offload planning notes and task plan
Add planning documents for layer-wise offload implementation:
- notes.md: Implementation notes and findings
- task_plan.md: Detailed task breakdown and progress tracking

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 06:04:36 +08:00


Task Plan: Refactor SparsePolicy for Layerwise Offload

Goal

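Refactor the SparsePolicy interface, following the design pattern of the tzj/minference branch, so that every attention variant can be expressed as a policy and written against one uniform interface. The design must fit the layerwise offload architecture, in which the entire layer's KV is resident on the GPU.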

Background

Comparison of the two offload architectures

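| Aspect | tzj/minference (Chunked Offload) | tzj/layer-offload (Layerwise Offload) |
|---|---|---|
| Processing granularity | One chunk (block_size tokens) at a time | One whole layer (all tokens) at a time |
| KV location | Historical chunks live on CPU and must be loaded | The whole layer's KV is on GPU |
| Policy entry points | compute_chunked_prefill() / compute_chunked_decode() | compute_prefill() / compute_decode() |
| Needs offload_engine | Yes (to load blocks) | No (KV is already on GPU) |
| Mask computation | select_blocks() returns block IDs | estimate() returns a sparse mask |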
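The last row of the table is the semantic change the refactor hinges on: select_blocks() yields block IDs, whereas estimate() yields a boolean mask. A minimal NumPy sketch of converting between the two representations follows. It assumes a single head and a [q_blocks, k_blocks] mask (the planned mask has a leading num_heads dimension, so this would be applied per head), and both helper names are hypothetical.

```python
import numpy as np


def mask_to_block_ids(mask: np.ndarray) -> list[list[int]]:
    """[q_blocks, k_blocks] bool mask -> per-query-block lists of selected key-block IDs."""
    return [np.flatnonzero(row).tolist() for row in mask]


def block_ids_to_mask(block_ids: list[list[int]], num_k_blocks: int) -> np.ndarray:
    """Inverse: per-query-block ID lists -> [q_blocks, k_blocks] bool mask."""
    mask = np.zeros((len(block_ids), num_k_blocks), dtype=bool)
    for q_blk, ids in enumerate(block_ids):
        # An explicit integer dtype keeps empty ID lists valid as an index.
        mask[q_blk, np.asarray(ids, dtype=np.intp)] = True
    return mask
```

For instance, the mask [[True, False], [True, True]] corresponds to the ID lists [[0], [0, 1]]. The mask form is denser to store but maps directly onto block sparse attention kernels, which is why estimate() returns it.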

Policy interface on tzj/minference

class SparsePolicy(ABC):
    supports_prefill: bool
    supports_decode: bool

    @abstractmethod
    def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]

    @abstractmethod
    def compute_chunked_prefill(self, q, k, v, layer_id, ..., offload_engine, ...) -> Tensor

    @abstractmethod
    def compute_chunked_decode(self, q, layer_id, ..., offload_engine, ...) -> Tensor

Policy interface on the current branch (before the refactor)

class SparsePolicy(ABC):
    supports_prefill: bool
    supports_decode: bool

    @abstractmethod
    def select_blocks(self, available_blocks, ctx) -> List[int]

    def sparse_prefill_attention(self, q, k, v, layer_id) -> Tensor

Phases

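  • Phase 0: Create the nanovllm.ops module (tests passing)
  • Phase 1: Analyze the differences and design the new interface
  • Phase 2: Refactor the AttentionPolicy base class
  • Phase 3: Refactor FullAttentionPolicy
  • Phase 4: Refactor XAttentionPolicy (including the estimate method)
  • Phase 5: Update how model_runner invokes the policy
  • Phase 6: Test and validate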

Phase 0: Create the nanovllm.ops module

Goal

Extract the ops module from the tzj/minference branch to provide the low-level operators that XAttention estimation depends on.

Steps

  1. Create the directory structure

    nanovllm/ops/
    ├── __init__.py
    ├── xattn.py           # xattn_estimate, xattn_estimate_chunked, Triton kernels
    └── chunked_attention.py  # flash_attn_with_lse, merge_attention_outputs (kept as backup)
    
  2. Extract the files from tzj/minference

    git show tzj/minference:nanovllm/ops/__init__.py > nanovllm/ops/__init__.py
    git show tzj/minference:nanovllm/ops/xattn.py > nanovllm/ops/xattn.py
    git show tzj/minference:nanovllm/ops/chunked_attention.py > nanovllm/ops/chunked_attention.py
    
  3. Cherry-pick the test file

    git show tzj/minference:tests/test_xattn_estimate_chunked.py > tests/test_xattn_estimate_chunked.py
    
  4. Run the test to verify

    CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/Worktree/nano-vllm:$PYTHONPATH \
        python tests/test_xattn_estimate_chunked.py
    

Contents of the nanovllm/ops module

| File | Core function | Purpose |
|---|---|---|
| xattn.py | xattn_estimate() | Standard XAttention estimation |
| xattn.py | xattn_estimate_chunked() | Chunked-prefill variant |
| xattn.py | flat_group_gemm_fuse_reshape() | Triton kernel: fused reshape + GEMM |
| xattn.py | softmax_fuse_block_sum() | Triton kernel: fused softmax + block sum |
| xattn.py | find_blocks_chunked() | Threshold-based block selection |
| chunked_attention.py | flash_attn_with_lse() | FlashAttention that also returns the LSE |
| chunked_attention.py | merge_attention_outputs() | Merges the outputs of multiple attention chunks |
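merge_attention_outputs() is kept as a backup for combining attention computed over separate key chunks. The identity such a merge relies on can be sketched with a dense NumPy reference: each partial result carries its log-sum-exp (LSE), and the partial outputs are reweighted by their share of the total softmax mass. This assumes a single head and no masking; the helper names are hypothetical and this is not the code in chunked_attention.py.

```python
import numpy as np


def attention_with_lse(q, k, v, scale):
    """Dense single-head attention, no mask. Returns (out [Lq, d], lse [Lq])."""
    s = (q @ k.T) * scale                       # [Lq, Lk] scaled scores
    m = s.max(axis=-1, keepdims=True)           # row max for numerical stability
    p = np.exp(s - m)
    denom = p.sum(axis=-1, keepdims=True)
    out = (p @ v) / denom
    lse = (m + np.log(denom)).squeeze(-1)       # log(sum(exp(s))) per query row
    return out, lse


def merge_two(out1, lse1, out2, lse2):
    """Combine partial outputs that were computed over disjoint key sets."""
    lse = np.logaddexp(lse1, lse2)              # LSE over the union of both key sets
    w1 = np.exp(lse1 - lse)[:, None]            # fraction of softmax mass in part 1
    w2 = np.exp(lse2 - lse)[:, None]            # fraction of softmax mass in part 2
    return w1 * out1 + w2 * out2, lse
```

For example, computing attention over the first and second halves of the keys separately and merging the two results reproduces attention over all keys, up to floating-point rounding. Working in log space avoids overflow when the raw exponent sums are large.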

Relationship to the policy

XAttentionPolicy.estimate()
    └── calls nanovllm.ops.xattn.xattn_estimate()
            ├── flat_group_gemm_fuse_reshape() (Triton)
            ├── softmax_fuse_block_sum() (Triton)
            └── find_blocks_chunked()
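For reference, the chain above can be approximated by a dense, single-head NumPy sketch that follows the same stages: a strided reshape whose dot products sum one antidiagonal of each stride x stride tile of QK^T, a causal softmax, a block-level sum, and threshold-based selection. It is only meant as a readable oracle for checking a Triton path. Padding, GQA, chunking, and the exact scaling and forced-on blocks of the real xattn_estimate() / find_blocks_chunked() are not reproduced, and the function name and the small defaults are illustrative assumptions.

```python
import numpy as np


def xattn_estimate_reference(q, k, block_size=8, stride=4, threshold=0.9):
    """Hypothetical single-head block-mask estimate in the spirit of XAttention.

    q, k: [seq_len, head_dim]. Returns a [q_blocks, k_blocks] bool mask.
    """
    seq_len, head_dim = q.shape
    if seq_len % block_size or block_size % stride:
        raise ValueError("need seq_len % block_size == 0 and block_size % stride == 0")
    # Row i of q_r / k_r packs `stride` tokens along the feature axis. Queries are
    # taken in reverse order, so q_r[i] @ k_r[j] = sum_t q[i*S+S-1-t] . k[j*S+t],
    # i.e. one antidiagonal of the S x S tile of QK^T at (i*S, j*S).
    q_r = np.concatenate([q[stride - 1 - t::stride] for t in range(stride)], axis=1)
    k_r = np.concatenate([k[t::stride] for t in range(stride)], axis=1)
    n = seq_len // stride
    scores = (q_r @ k_r.T) / np.sqrt(head_dim) / stride
    scores[np.triu(np.ones((n, n), dtype=bool), k=1)] = -np.inf  # causal at tile level
    probs = np.exp(scores - scores.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    # Sum the probability mass of each (q_block, k_block) tile.
    r, nb = block_size // stride, seq_len // block_size
    block_mass = probs.reshape(nb, r, nb, r).sum(axis=(1, 3))
    # Per query block, keep the highest-mass key blocks until `threshold` of the
    # row's mass is covered; the diagonal block is always kept.
    mask = np.zeros((nb, nb), dtype=bool)
    for i in range(nb):
        order = np.argsort(-block_mass[i], kind="stable")
        cum = np.cumsum(block_mass[i][order])
        keep = int(np.searchsorted(cum, threshold * cum[-1])) + 1
        mask[i, order[:keep]] = True
        mask[i, i] = True
    return mask
```

Lowering threshold can only drop key blocks from a row, because each row keeps a prefix of the same mass-sorted order, so the threshold acts as a monotone sparsity knob. Future blocks carry zero mass and therefore sort last, which keeps the mask causal.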

Key Questions

  1. What should select_blocks become?

    • Rename it to estimate(), which computes the sparse mask
    • For XAttention, it corresponds to the xattn_estimate() function in COMPASS
    • FullAttentionPolicy.estimate() returns None, meaning full attention
  2. How should the policy interface be designed?

    • Prefill: compute_prefill(q, k, v, layer_id, softmax_scale)
    • Decode: compute_decode(q, k, v, layer_id, softmax_scale)
    • Estimate: estimate(q, k, layer_id), which computes the sparse mask
  3. How is the FULL policy handled?

    • FULL also implements compute_prefill/decode, using FlashAttention
    • estimate() returns None, meaning no sparsification is applied

Proposed New Interface

from abc import ABC, abstractmethod
from typing import Optional
import torch


class AttentionPolicy(ABC):
    """Attention policy for layerwise offload mode.

    All attention computation, both full and sparse, goes through a policy.
    Both the prefill and decode phases are supported.
    """

    supports_prefill: bool = True
    supports_decode: bool = True

    def estimate(
        self,
        q: torch.Tensor,      # [seq_len, num_heads, head_dim]
        k: torch.Tensor,      # [seq_len, num_kv_heads, head_dim]
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Estimate the sparse attention mask.

        For sparse policies (e.g. XAttention), computes which blocks need to be attended.
        For the full policy, returns None to indicate full attention.

        Corresponds to the xattn_estimate() function in COMPASS.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask, or None
        """
        return None  # Default: full attention

    @abstractmethod
    def compute_prefill(
        self,
        q: torch.Tensor,      # [seq_len, num_heads, head_dim]
        k: torch.Tensor,      # [seq_len, num_kv_heads, head_dim]
        v: torch.Tensor,      # [seq_len, num_kv_heads, head_dim]
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute prefill attention.

        The whole layer's KV is on the GPU, so attention is computed in a single pass.
        Implementations may first call estimate() to obtain a sparse mask and then apply block sparse attention.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor (1/sqrt(head_dim))

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        pass

    def compute_decode(
        self,
        q: torch.Tensor,      # [1, num_heads, head_dim]
        k: torch.Tensor,      # [context_len+1, num_kv_heads, head_dim]
        v: torch.Tensor,      # [context_len+1, num_kv_heads, head_dim]
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute decode attention.

        KV is supplied from the ring buffer and contains the prefill tokens plus all tokens decoded so far.

        Args:
            q: Query tensor [1, num_heads, head_dim]
            k: Key tensor [context_len+1, num_kv_heads, head_dim]
            v: Value tensor [context_len+1, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [1, num_heads, head_dim]
        """
        # Default implementation: FlashAttention
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        context_len = k.shape[0]
        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=1,
            max_seqlen_k=context_len,
            softmax_scale=softmax_scale,
            causal=False,
        )

    def reset(self) -> None:
        """Reset policy state between sequences."""
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


# Keep the old name as an alias
SparsePolicy = AttentionPolicy
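Because compute_prefill() and compute_decode() share tensor layouts but differ in masking, a small dense reference helps check any policy against the FlashAttention defaults. This NumPy sketch is an assumption-laden stand-in, not part of the planned code: it assumes num_heads is a multiple of num_kv_heads (GQA, with query head h reading KV head h // group, the convention FlashAttention documents) and bottom-right-aligned causal masking. It also shows why the default decode path can pass causal=False: the single decode query is the newest token, so every cached key is already in its past.

```python
import numpy as np


def reference_attention(q, k, v, softmax_scale, causal):
    """Dense reference for the compute_prefill / compute_decode tensor layouts.

    q: [q_len, num_heads, head_dim]; k, v: [kv_len, num_kv_heads, head_dim].
    Returns [q_len, num_heads, head_dim].
    """
    q_len, num_heads, _ = q.shape
    kv_len, num_kv_heads, _ = k.shape
    group = num_heads // num_kv_heads            # query heads per KV head (GQA)
    k = np.repeat(k, group, axis=1)              # head h now reads KV head h // group
    v = np.repeat(v, group, axis=1)
    scores = np.einsum("qhd,khd->hqk", q, k) * softmax_scale
    if causal:
        # Bottom-right alignment: query i may see keys j <= i + (kv_len - q_len).
        offset = kv_len - q_len
        future = np.arange(kv_len)[None, :] > np.arange(q_len)[:, None] + offset
        scores[:, future] = -np.inf
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.einsum("hqk,khd->qhd", probs, v)
```

A useful invariant: the last row of a causal prefill over a sequence equals a non-causal decode of that last query against the same keys. With bottom-right alignment, a one-token query is unaffected by causal=True, so either flag gives the same decode result.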

Implementation Plan

Phase 2: Refactor policy.py

# nanovllm/kvcache/sparse/policy.py

from abc import ABC, abstractmethod
from typing import Optional
import torch


class AttentionPolicy(ABC):
    """Base class for attention policies in layerwise offload mode."""

    supports_prefill: bool = True
    supports_decode: bool = True

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Estimate sparse attention mask.

        For sparse policies (e.g., XAttention), computes block-level importance.
        For full policy, returns None.

        Corresponds to xattn_estimate() in COMPASS.

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] or None
        """
        return None

    @abstractmethod
    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """Compute prefill attention."""
        pass

    def compute_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """Compute decode attention (default: FlashAttention)."""
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        context_len = k.shape[0]
        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=1,
            max_seqlen_k=context_len,
            softmax_scale=softmax_scale,
            causal=False,
        )

    def reset(self) -> None:
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


# Backward compatibility alias
SparsePolicy = AttentionPolicy

Phase 3: Refactor FullAttentionPolicy

# nanovllm/kvcache/sparse/full_policy.py

import torch
from .policy import AttentionPolicy


class FullAttentionPolicy(AttentionPolicy):
    """Full attention using FlashAttention (no sparsity)."""

    supports_prefill = True
    supports_decode = True

    def estimate(self, q, k, layer_id):
        """Full attention - no sparse mask needed."""
        return None

    def compute_prefill(self, q, k, v, layer_id, softmax_scale):
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        seq_len = q.shape[0]
        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=seq_len,
            max_seqlen_k=seq_len,
            softmax_scale=softmax_scale,
            causal=True,
        )

    def __repr__(self):
        return "FullAttentionPolicy()"

Phase 4: Refactor XAttentionPolicy

# nanovllm/kvcache/sparse/xattn.py

import torch
from typing import Optional
from .policy import AttentionPolicy


class XAttentionPolicy(AttentionPolicy):
    """
    XAttention sparse prefill policy.

    Uses chunked estimation to compute sparse attention mask,
    then applies block sparse attention.
    """

    supports_prefill = True
    supports_decode = True

    def __init__(
        self,
        stride: int = 8,
        threshold: float = 0.9,
        block_size: int = 128,
        chunk_size: int = 16384,
        use_triton: bool = True,
    ):
        self.stride = stride
        self.threshold = threshold
        self.block_size = block_size
        self.chunk_size = chunk_size
        self.use_triton = use_triton

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        XAttention estimation (xattn_estimate).

        Uses chunked GEMM + softmax to estimate block-level importance,
        then selects important blocks based on threshold.

        Corresponds to the xattn_estimate() function in COMPASS:
        1. Pad inputs to chunk_size multiples
        2. Reshape with stride
        3. Compute QK^T in chunks (Triton)
        4. Block-wise softmax + aggregation
        5. Threshold-based selection

        Args:
            q: [seq_len, num_heads, head_dim]
            k: [seq_len, num_kv_heads, head_dim]
            layer_id: transformer layer index

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask
                        or None (fallback to full attention)
        """
        # TODO: implement the real xattn_estimate
        # For now, return None so that full attention is used
        return None

    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute XAttention sparse prefill.

        Flow:
        1. Call estimate() to get sparse mask
        2. If mask is None, use full attention
        3. Otherwise, apply block sparse attention with mask
        """
        # Step 1: Estimate sparse mask
        sparse_mask = self.estimate(q, k, layer_id)

        # Step 2: Compute attention
        if sparse_mask is None:
            # Fallback to full attention
            from flash_attn.flash_attn_interface import flash_attn_varlen_func

            seq_len = q.shape[0]
            cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

            return flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=seq_len,
                max_seqlen_k=seq_len,
                softmax_scale=softmax_scale,
                causal=True,
            )
        else:
            # Apply block sparse attention with mask
            # Use block_sparse_attn_func(q, k, v, sparse_mask, block_size)
            raise NotImplementedError("Block sparse attention not yet implemented")

    def __repr__(self):
        return (f"XAttentionPolicy("
                f"stride={self.stride}, "
                f"threshold={self.threshold}, "
                f"block_size={self.block_size})")
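Until a block sparse kernel is wired in, the else branch of compute_prefill() could be validated against a dense reference that applies the mask by setting unselected scores to -inf. This NumPy sketch assumes a single head, a [q_blocks, k_blocks] mask for that head, and that every diagonal block is selected (otherwise a query row would have no visible keys and the softmax would be undefined). The function name is hypothetical, and it reproduces only the semantics of block sparse attention, not its speed.

```python
import numpy as np


def block_sparse_attention_reference(q, k, v, block_mask, block_size, softmax_scale):
    """Dense, masked reference for block sparse causal prefill (single head).

    q, k, v: [seq_len, head_dim]; block_mask: [q_blocks, k_blocks] bool.
    """
    seq_len = q.shape[0]
    blk = np.arange(seq_len) // block_size                  # block index of each token
    allowed = block_mask[blk[:, None], blk[None, :]]        # expand mask to token level
    allowed &= np.tril(np.ones((seq_len, seq_len), dtype=bool))  # keep causality
    scores = np.where(allowed, (q @ k.T) * softmax_scale, -np.inf)
    probs = np.exp(scores - scores.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    return probs @ v
```

With a fully populated lower-triangular block mask this reduces to ordinary causal attention, which makes it a convenient oracle when the mask produced by estimate() is still None in most cases.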

Phase 5: Update model_runner.py

# model_runner.py - allocate_kv_cache()

# Always create a policy, including for FULL
from nanovllm.kvcache.sparse import create_attention_policy
self.attention_policy = create_attention_policy(config.attention_policy, **policy_kwargs)
logger.info(f"Attention policy: {self.attention_policy}")

# run_layerwise_offload_prefill() and run_gpu_only_prefill()

# Old code:
if self.sparse_prefill_policy is not None:
    attn_output = self.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
else:
    attn_output = flash_attn_varlen_func(...)

# New code:
attn_output = self.attention_policy.compute_prefill(
    q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
)
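The create_attention_policy() factory imported above is not defined in this plan. A minimal sketch of what it might look like follows, assuming the config value is a string such as "FULL" or "XATTN" and using stand-in policy classes. The registry, the accepted names, and the error behavior are all hypothetical.

```python
from abc import ABC, abstractmethod


class AttentionPolicy(ABC):  # stand-in for nanovllm/kvcache/sparse/policy.py
    @abstractmethod
    def compute_prefill(self, q, k, v, layer_id, softmax_scale): ...


class FullAttentionPolicy(AttentionPolicy):  # stand-in; real one calls FlashAttention
    def compute_prefill(self, q, k, v, layer_id, softmax_scale): ...


class XAttentionPolicy(AttentionPolicy):  # stand-in with a subset of the Phase 4 kwargs
    def __init__(self, stride=8, threshold=0.9, block_size=128):
        self.stride, self.threshold, self.block_size = stride, threshold, block_size

    def compute_prefill(self, q, k, v, layer_id, softmax_scale): ...


# Hypothetical registry keyed by the (upper-cased) config value.
_POLICIES = {"FULL": FullAttentionPolicy, "XATTN": XAttentionPolicy}


def create_attention_policy(name: str, **policy_kwargs) -> AttentionPolicy:
    """Always return a policy object, including for FULL, so callers never check for None."""
    try:
        cls = _POLICIES[name.upper()]
    except KeyError:
        raise ValueError(
            f"Unknown attention policy {name!r}; expected one of {sorted(_POLICIES)}"
        ) from None
    return cls(**policy_kwargs)
```

Returning a FullAttentionPolicy instead of None for FULL is what lets the new call site drop the `if self.sparse_prefill_policy is not None` branch shown in the old code.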

Method Mapping

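| Old method | New method | Notes |
|---|---|---|
| select_blocks() | estimate() | Computes the sparse mask (corresponds to xattn_estimate) |
| sparse_prefill_attention() | compute_prefill() | Prefill attention |
| (none) | compute_decode() | Decode attention (default implementation provided) |
| on_prefill_offload() | (removed) | Offload is handled in model_runner |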

Files to Modify

| File | Changes |
|---|---|
| nanovllm/kvcache/sparse/policy.py | New interface: estimate, compute_prefill, compute_decode |
| nanovllm/kvcache/sparse/full_policy.py | Implement compute_prefill(); estimate() returns None |
| nanovllm/kvcache/sparse/xattn.py | estimate() (corresponds to xattn_estimate), compute_prefill() |
| nanovllm/kvcache/sparse/__init__.py | Update the factory function |
| nanovllm/engine/model_runner.py | Route all prefill calls through attention_policy.compute_prefill() |
| nanovllm/config.py | Optional: rename configuration options |

Decisions Made

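  1. Method naming: compute_prefill / compute_decode, matching the naming style of the chunked versions
  2. estimate method: replaces select_blocks and returns a sparse mask instead of block IDs
  3. XAttention: estimate() corresponds to xattn_estimate() in COMPASS
  4. Full policy: estimate() returns None, meaning full attention is used
  5. Default decode implementation: the base class provides a default FlashAttention implementation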

Errors Encountered

  • (none)

Status

Currently in Phase 1: analysis and interface design are complete; waiting for user confirmation before starting Phase 2.