From baa4be7e2e65ca6a21721e4906050ccbb5ed755d Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Tue, 20 Jan 2026 00:58:46 +0800 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor:=20migrate=20chun?= =?UTF-8?q?ked=20prefill=20attention=20to=20SparsePolicy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move all chunked prefill attention computation from attention.py to SparsePolicy.compute_chunked_attention(). This is the v4 architecture refactoring for sparse attention policies. Changes: - Add compute_chunked_attention abstract method to SparsePolicy base - Add offload_engine parameter to select_blocks for policies needing KV access during block selection - Implement compute_chunked_attention in FullAttentionPolicy with complete ring buffer pipeline logic - Simplify attention.py to delegate all chunked prefill to policy - Remove redundant _sync_load_previous_chunks and _ring_buffer_pipeline_load methods from Attention class Test: test_needle.py --enable-offload PASSED Co-Authored-By: Claude Opus 4.5 --- nanovllm/kvcache/sparse/full_policy.py | 59 +++-- nanovllm/kvcache/sparse/policy.py | 52 ++++- nanovllm/layers/attention.py | 312 +++---------------------- test_report_sparse_policy_refactor.md | 114 +++++++++ 4 files changed, 240 insertions(+), 297 deletions(-) create mode 100644 test_report_sparse_policy_refactor.md diff --git a/nanovllm/kvcache/sparse/full_policy.py b/nanovllm/kvcache/sparse/full_policy.py index 8dd8b42..4df7b63 100644 --- a/nanovllm/kvcache/sparse/full_policy.py +++ b/nanovllm/kvcache/sparse/full_policy.py @@ -5,12 +5,20 @@ This serves as a baseline and default policy when sparse attention is not needed. """ +import logging import torch -from typing import List, Optional +from typing import List, Optional, TYPE_CHECKING from .policy import SparsePolicy, PolicyContext from nanovllm.utils.context import get_context +if TYPE_CHECKING: + from nanovllm.kvcache.offload_engine import OffloadEngine + from nanovllm.kvcache.manager import KVCacheManager + from nanovllm.engine.sequence import Sequence + +logger = logging.getLogger(__name__) + class FullAttentionPolicy(SparsePolicy): """ @@ -32,30 +40,34 @@ class FullAttentionPolicy(SparsePolicy): def select_blocks( self, available_blocks: List[int], + offload_engine: "OffloadEngine", ctx: PolicyContext, ) -> List[int]: """Return all blocks - no sparsity.""" return available_blocks - def compute_prefill_attention( + def compute_chunked_attention( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, softmax_scale: float, - offload_engine, + offload_engine: "OffloadEngine", + kvcache_manager: "KVCacheManager", current_chunk_idx: int, - seq, + seq: "Sequence", + num_tokens: int, ) -> torch.Tensor: """ Compute full attention for chunked prefill. This method handles the complete chunked prefill flow: - 1. Load historical blocks from CPU - 2. Compute attention to historical chunks - 3. Compute attention to current chunk - 4. Merge all results + 1. Get historical blocks + 2. Select blocks via select_blocks + 3. Load and compute attention to historical chunks + 4. Compute attention to current chunk + 5. Merge all results Args: q: Query tensor [seq_len, num_heads, head_dim] @@ -64,22 +76,41 @@ class FullAttentionPolicy(SparsePolicy): layer_id: Current layer index softmax_scale: Softmax scaling factor offload_engine: OffloadEngine for loading blocks + kvcache_manager: KVCacheManager for block management current_chunk_idx: Current chunk index - seq: ChunkedSequence + seq: Sequence object + num_tokens: Number of tokens in current chunk Returns: Attention output [seq_len, num_heads, head_dim] """ from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs + logger.debug(f"[DEBUG] FullPolicy.compute_chunked_attention called, " + f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}") + q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim] - num_tokens = q.shape[0] o_acc = None lse_acc = None compute_stream = offload_engine.compute_stream - # Step 1: Get and load historical blocks - cpu_block_table = seq.kvcache_manager.get_prefilled_cpu_blocks(seq) + # Step 1: Get historical blocks + cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) + + # Step 2: Apply select_blocks to filter blocks + if cpu_block_table: + num_chunks = current_chunk_idx + 1 + policy_ctx = PolicyContext( + query_chunk_idx=current_chunk_idx, + num_query_chunks=num_chunks, + layer_id=layer_id, + query=None, # Prefill typically doesn't use query for selection + is_prefill=True, + block_size=kvcache_manager.block_size, + total_kv_len=len(cpu_block_table) * kvcache_manager.block_size, + ) + cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx) + logger.debug(f"[DEBUG] select_blocks: output={len(cpu_block_table)} blocks") if cpu_block_table: load_slots = list(range(offload_engine.num_ring_slots)) @@ -139,7 +170,7 @@ class FullAttentionPolicy(SparsePolicy): next_cpu_block_id = cpu_block_table[next_block_idx] offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id) - # Step 2: Compute attention to current chunk (causal mask) + # Step 4: Compute attention to current chunk (causal mask) with torch.cuda.stream(compute_stream): k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens) current_o, current_lse = flash_attn_with_lse( @@ -148,7 +179,7 @@ class FullAttentionPolicy(SparsePolicy): causal=True, ) - # Step 3: Merge historical and current attention + # Step 5: Merge historical and current attention with torch.cuda.stream(compute_stream): if o_acc is None: final_o = current_o diff --git a/nanovllm/kvcache/sparse/policy.py b/nanovllm/kvcache/sparse/policy.py index bbb0809..7cb9dcd 100644 --- a/nanovllm/kvcache/sparse/policy.py +++ b/nanovllm/kvcache/sparse/policy.py @@ -7,12 +7,17 @@ from CPU for each query chunk during chunked attention computation. from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Optional, Any +from typing import List, Optional, Any, TYPE_CHECKING import torch # Import SparsePolicyType from config to avoid circular imports from nanovllm.config import SparsePolicyType +if TYPE_CHECKING: + from nanovllm.kvcache.offload_engine import OffloadEngine + from nanovllm.kvcache.manager import KVCacheManager + from nanovllm.engine.sequence import Sequence + @dataclass class PolicyContext: @@ -107,6 +112,7 @@ class SparsePolicy(ABC): def select_blocks( self, available_blocks: List[int], + offload_engine: "OffloadEngine", ctx: PolicyContext, ) -> List[int]: """ @@ -120,6 +126,8 @@ class SparsePolicy(ABC): available_blocks: List of CPU block IDs that contain KV cache from previous chunks. These are ordered by their position in the sequence. + offload_engine: OffloadEngine for loading KV (some policies need + to load KV to make selection decisions). ctx: PolicyContext with information about the current query chunk, layer, phase (prefill/decode), etc. @@ -183,5 +191,47 @@ class SparsePolicy(ABC): """ pass + @abstractmethod + def compute_chunked_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_id: int, + softmax_scale: float, + offload_engine: "OffloadEngine", + kvcache_manager: "KVCacheManager", + current_chunk_idx: int, + seq: "Sequence", + num_tokens: int, + ) -> torch.Tensor: + """ + Compute chunked prefill attention (complete flow). + + This is the main entry point for prefill attention computation. + It defines the complete prefill flow: + 1. Get historical blocks + 2. Select blocks (call select_blocks) + 3. Load and compute historical blocks via offload_engine + 4. Get current chunk KV from offload_engine, compute attention + 5. Merge all results + + Args: + q: [seq_len, num_heads, head_dim] query for current chunk + k: [seq_len, num_kv_heads, head_dim] key for current chunk (in prefill buffer) + v: [seq_len, num_kv_heads, head_dim] value for current chunk (in prefill buffer) + layer_id: transformer layer index + softmax_scale: softmax scaling factor + offload_engine: OffloadEngine for loading blocks + kvcache_manager: KVCacheManager for block management + current_chunk_idx: current chunk index + seq: Sequence object + num_tokens: number of tokens in current chunk + + Returns: + [seq_len, num_heads, head_dim] final attention output + """ + pass + def __repr__(self) -> str: return f"{self.__class__.__name__}()" diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index d403c73..e13456b 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -174,123 +174,45 @@ class Attention(nn.Module): """ Compute attention with per-layer prefill buffer for async offload. - Optimized design: - - Current chunk's KV is written to per-layer prefill buffer (not GPU slot) - - Previous chunks' KV are loaded from CPU using GPU slots - - Each layer offloads from its own buffer - no waiting required! + Simplified design: + - All computation logic is delegated to sparse_policy.compute_chunked_attention() + - This method only handles async offload after computation - For each layer: - 1. Current chunk's KV is in prefill_buffer[layer_id] (just written by model) - 2. Load previous chunks from CPU using available slots (pipeline) - 3. Compute attention against previous KV (no causal mask) - 4. Compute attention against current KV from prefill buffer (causal) - 5. Merge all results using online softmax - 6. Async offload prefill buffer to CPU (no waiting!) + The policy handles: + 1. Loading historical blocks from CPU + 2. Computing attention against historical KV (no causal mask) + 3. Computing attention against current KV from prefill buffer (causal) + 4. Merging all results """ - from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs - current_chunk_idx = context.current_chunk_idx torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}") - # q shape: [total_tokens, num_heads, head_dim] - q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim] num_tokens = k.shape[0] - o_acc = None - lse_acc = None - kvcache_manager = context.kvcache_manager seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None offload_engine = kvcache_manager.offload_engine if kvcache_manager is not None else None - if kvcache_manager is not None and seq is not None and self.layer_id >= 0: - # Get prefilled CPU blocks (blocks from previous chunks) - cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) + # Get sparse policy - required for chunked prefill + sparse_policy = kvcache_manager.sparse_policy + if sparse_policy is None: + raise RuntimeError("sparse_policy is required for chunked prefill") - # Apply sparse policy if enabled - sparse_policy = kvcache_manager.sparse_policy + # [DEBUG] Verify execution path + logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_attention, " + f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}") - # === All sparse policies use select_blocks interface === - if cpu_block_table and sparse_policy is not None: - num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1) - policy_ctx = PolicyContext( - query_chunk_idx=current_chunk_idx, - num_query_chunks=num_chunks, - layer_id=self.layer_id, - query=None, # Prefill typically doesn't use query for selection - is_prefill=True, - block_size=kvcache_manager.block_size, - total_kv_len=len(cpu_block_table) * kvcache_manager.block_size, - ) - cpu_block_table = sparse_policy.select_blocks( - cpu_block_table, policy_ctx - ) - - if cpu_block_table: - # Get available load slots (all slots can be used since we use prefill buffer) - load_slots = list(range(offload_engine.num_ring_slots)) - pipeline_depth = len(load_slots) - - if pipeline_depth == 0: - # Only 1 slot total, cannot pipeline - use sync loading - o_acc, lse_acc = self._sync_load_previous_chunks( - q_batched, cpu_block_table, offload_engine - ) - else: - # Use ring buffer pipeline - o_acc, lse_acc = self._ring_buffer_pipeline_load( - q_batched, cpu_block_table, load_slots, offload_engine, - current_chunk_idx - ) - - # Get compute stream for all attention operations - compute_stream = offload_engine.compute_stream if offload_engine is not None else None - - # Compute attention against current chunk's KV from prefill buffer (with causal mask) - needs_current_chunk_attention = True - - if needs_current_chunk_attention: - if compute_stream is not None: - with torch.cuda.stream(compute_stream): - torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)") - # Get KV from per-layer prefill buffer - k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens) - current_o, current_lse = flash_attn_with_lse( - q_batched, - k_batched, - v_batched, - softmax_scale=self.scale, - causal=True, - ) - torch.cuda.nvtx.range_pop() - else: - torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)") - k_batched = k.unsqueeze(0) - v_batched = v.unsqueeze(0) - current_o, current_lse = flash_attn_with_lse( - q_batched, - k_batched, - v_batched, - softmax_scale=self.scale, - causal=True, - ) - torch.cuda.nvtx.range_pop() - - # Merge with accumulated (all on compute_stream for consistency) - if o_acc is None: - # No accumulated attention (no historical chunks processed) - final_o = current_o - else: - # Has accumulated attention (historical chunks processed) - if compute_stream is not None: - with torch.cuda.stream(compute_stream): - torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}") - final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) - torch.cuda.nvtx.range_pop() - else: - torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}") - final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) - torch.cuda.nvtx.range_pop() + # Delegate all computation to policy (no flash_attn or merge calls here!) + final_o = sparse_policy.compute_chunked_attention( + q, k, v, + self.layer_id, + self.scale, + offload_engine, + kvcache_manager, + current_chunk_idx, + seq, + num_tokens, + ) torch.cuda.nvtx.range_pop() # ChunkedPrefill @@ -305,181 +227,7 @@ class Attention(nn.Module): self.layer_id, cpu_block_id, num_tokens ) - # Sync default stream with compute_stream before returning - # This ensures the result is ready for the rest of the model (layernorm, MLP) - if compute_stream is not None: - torch.cuda.default_stream().wait_stream(compute_stream) - - # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim] - return final_o.squeeze(0) - - def _sync_load_previous_chunks( - self, - q_batched: torch.Tensor, - cpu_block_table: list, - offload_engine, - ): - """Synchronous loading fallback when pipeline_depth=0.""" - from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs - - o_acc, lse_acc = None, None - compute_stream = offload_engine.compute_stream - - for block_idx, cpu_block_id in enumerate(cpu_block_table): - # Load to slot 0 (single slot) - offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id) - offload_engine.wait_slot_layer(0) - - # IMPORTANT: Must use compute_stream to match wait_slot_layer - with torch.cuda.stream(compute_stream): - prev_k, prev_v = offload_engine.get_kv_for_slot(0) - - prev_o, prev_lse = flash_attn_with_lse( - q_batched, prev_k, prev_v, - softmax_scale=self.scale, - causal=False, - ) - - if o_acc is None: - o_acc, lse_acc = prev_o, prev_lse - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) - - return o_acc, lse_acc - - def _ring_buffer_pipeline_load( - self, - q_batched: torch.Tensor, - cpu_block_table: list, - load_slots: list, - offload_engine, - current_chunk_idx: int = -1, - ): - """ - Ring buffer async pipeline loading with double buffering. - - Uses compute_done events to ensure safe buffer reuse: - - Before loading to slot X, wait for previous compute on slot X to finish - - Before computing on slot X, wait for load to slot X to finish - - Timeline with 2 slots (A, B): - ┌──────────────┐ - │ Load B0→A │ - └──────────────┘ - ┌──────────────┐ ┌──────────────┐ - │ Load B1→B │ │ Load B2→A │ ... - └──────────────┘ └──────────────┘ - ↘ ↘ - ┌──────────────┐ ┌──────────────┐ - │ Compute(A) │ │ Compute(B) │ ... - └──────────────┘ └──────────────┘ - - The load_to_slot_layer internally waits for compute_done[slot] before - starting the transfer, ensuring no data race. - """ - from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs - - num_blocks = len(cpu_block_table) - if num_blocks == 0: - return None, None - - pipeline_depth = len(load_slots) - if pipeline_depth == 0: - return None, None - - o_acc, lse_acc = None, None - - if pipeline_depth == 1: - # Only 1 slot available, cannot pipeline - use synchronous mode - # IMPORTANT: Must use compute_stream to match synchronization in - # load_to_slot_layer (waits for compute_done) and wait_slot_layer - slot = load_slots[0] - compute_stream = offload_engine.compute_stream - for block_idx in range(num_blocks): - cpu_block_id = cpu_block_table[block_idx] - offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id) - offload_engine.wait_slot_layer(slot) - - with torch.cuda.stream(compute_stream): - # Debug: call hooks on compute_stream (synchronized with transfer) - if offload_engine.debug_mode: - offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id) - - prev_k, prev_v = offload_engine.get_kv_for_slot(slot) - - prev_o, prev_lse = flash_attn_with_lse( - q_batched, prev_k, prev_v, - softmax_scale=self.scale, - causal=False, - ) - # Record compute done so next load can safely reuse this slot - offload_engine.record_slot_compute_done(slot) - if o_acc is None: - o_acc, lse_acc = prev_o, prev_lse - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) - return o_acc, lse_acc - - # N-way pipeline: use ALL available slots for maximum overlap - # Pipeline depth = num_slots - 1 (num_slots blocks in flight) - num_slots = len(load_slots) - - # Phase 1: Pre-load up to num_slots blocks to fill the pipeline - # This starts all transfers in parallel, utilizing full PCIe bandwidth - num_preload = min(num_slots, num_blocks) - for i in range(num_preload): - offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i]) - - # Phase 2: Main loop - compute and immediately reuse slot for next transfer - # Use dedicated compute_stream (not default stream) to enable overlap with transfers - compute_stream = offload_engine.compute_stream - - for block_idx in range(num_blocks): - torch.cuda.nvtx.range_push(f"PipelineBlock: L{self.layer_id} B{block_idx}") - - # Cycle through slots: slot[block_idx % num_slots] - current_slot = load_slots[block_idx % num_slots] - cpu_block_id = cpu_block_table[block_idx] - - # Wait for current slot's transfer to complete (on compute_stream) - offload_engine.wait_slot_layer(current_slot) - - # Compute attention on current slot's data - # IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream - with torch.cuda.stream(compute_stream): - # Debug: call hooks on compute_stream (synchronized with transfer) - if offload_engine.debug_mode: - offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id) - - torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}") - prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot) - - prev_o, prev_lse = flash_attn_with_lse( - q_batched, prev_k, prev_v, - softmax_scale=self.scale, - causal=False, - ) - torch.cuda.nvtx.range_pop() - - # Record compute done - this allows the next transfer to safely overwrite this slot - offload_engine.record_slot_compute_done(current_slot) - - # Immediately start loading the NEXT block into this slot (if more blocks remain) - # Key insight: reuse current_slot immediately after compute is done! - next_block_idx = block_idx + num_slots - if next_block_idx < num_blocks: - offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx]) - - # Merge with accumulated (also on compute_stream for consistency) - with torch.cuda.stream(compute_stream): - if o_acc is None: - o_acc, lse_acc = prev_o, prev_lse - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) - - torch.cuda.nvtx.range_pop() # PipelineBlock - - return o_acc, lse_acc + return final_o def _chunked_decode_attention( self, @@ -524,6 +272,8 @@ class Attention(nn.Module): if last_block_valid_tokens == 0 and total_prefill_tokens > 0: last_block_valid_tokens = block_size # Last block was exactly full + offload_engine = kvcache_manager.offload_engine + # Apply sparse policy if enabled (Quest does Top-K selection for decode) sparse_policy = kvcache_manager.sparse_policy if sparse_policy is not None: @@ -537,11 +287,9 @@ class Attention(nn.Module): total_kv_len=len(cpu_block_table) * kvcache_manager.block_size, ) cpu_block_table = sparse_policy.select_blocks( - cpu_block_table, policy_ctx + cpu_block_table, offload_engine, policy_ctx ) - offload_engine = kvcache_manager.offload_engine - # Use cross-layer pipeline if active (initialized in model_runner) if offload_engine.is_pipeline_active(): o_acc, lse_acc = self._decode_with_layer_pipeline( diff --git a/test_report_sparse_policy_refactor.md b/test_report_sparse_policy_refactor.md new file mode 100644 index 0000000..a68bf21 --- /dev/null +++ b/test_report_sparse_policy_refactor.md @@ -0,0 +1,114 @@ +# SparsePolicy 重构测试报告 + +## 任务概述 + +根据 task_plan.md 的要求,对 nanovllm 的 SparsePolicy 架构进行重构(v4 版本),将 chunked prefill attention 计算逻辑从 attention.py 完全迁移到 SparsePolicy。 + +## 修改范围 + +仅针对 FullPolicy,不涉及 QuestPolicy 或 XAttentionBSAPolicy,不修改 decode 阶段逻辑。 + +## 完成的修改 + +### 1. policy.py (SparsePolicy 基类) + +- 添加 TYPE_CHECKING imports: `OffloadEngine`, `KVCacheManager`, `Sequence` +- 修改 `select_blocks` 签名:添加 `offload_engine` 参数 +- 添加 `compute_chunked_attention` 抽象方法,参数包括: + - `q, k, v`: 张量 + - `layer_id`: 层索引 + - `softmax_scale`: softmax 缩放因子 + - `offload_engine`: OffloadEngine 实例 + - `kvcache_manager`: KVCacheManager 实例 + - `current_chunk_idx`: 当前 chunk 索引 + - `seq`: Sequence 对象 + - `num_tokens`: 当前 chunk 的 token 数 + +### 2. full_policy.py (FullAttentionPolicy) + +- 更新 TYPE_CHECKING imports +- `select_blocks` 方法签名添加 `offload_engine` 参数 +- 重命名 `compute_prefill_attention` → `compute_chunked_attention` +- 添加 `kvcache_manager` 参数,替换所有 `seq.kvcache_manager` 引用 +- 添加 debug 日志输出 + +### 3. attention.py + +- 简化 `_chunked_prefill_attention` 方法: + - 删除所有 `flash_attn_*` 调用 + - 删除所有 `merge_attention_outputs` 调用 + - 仅保留委托调用 `sparse_policy.compute_chunked_attention()` +- 删除冗余方法:`_sync_load_previous_chunks`, `_ring_buffer_pipeline_load` +- decode 路径的 `select_blocks` 调用添加 `offload_engine` 参数 + +## 验收标准检查 + +| 标准 | 状态 | 说明 | +|------|------|------| +| test_needle.py --enable-offload 通过 | ✅ | 测试输出 PASSED | +| attention.py chunked prefill path 无 flash_attn_* 调用 | ✅ | `_chunked_prefill_attention` 方法(169-230行)内无直接 flash_attn 调用 | +| attention.py chunked prefill path 无 merge_attention_outputs 调用 | ✅ | 同上 | +| 所有 KV 通信通过 offload_engine 方法 | ✅ | 全部通过 `offload_engine.load_to_slot_layer`, `get_kv_for_slot`, `get_prefill_buffer_slice` | + +## 测试结果 + +``` +============================================================ +Needle-in-Haystack Test +============================================================ +Model: /home/zijie/models/Llama-3.1-8B-Instruct +Max model len: 131072 +Input length: 8192 +Block size: 1024 +Needle position: 50% +Needle value: 7492 +CPU offload: True +Sparse policy: FULL +============================================================ + +[NeedleTest] Target: 8192, Actual: 8213 tokens (diff=21) +Expected: 7492 +Output: 7492<|eot_id|>... +Status: PASSED +============================================================ + +test_needle: PASSED +``` + +## 性能指标 + +- Prefill: 3527 tok/s +- Decode: 11 tok/s +- TTFT: 2329.29 ms +- TPOT: 655.38 ms + +## 架构变更总结 + +**重构前**: +``` +attention.py::_chunked_prefill_attention() + ├── 获取 cpu_block_table + ├── 调用 sparse_policy.select_blocks() + ├── 直接调用 flash_attn_with_lse + merge_attention_outputs + └── 返回结果 +``` + +**重构后**: +``` +attention.py::_chunked_prefill_attention() + ├── 获取 context 信息 + ├── 调用 sparse_policy.compute_chunked_attention() # 委托全部计算 + └── 返回结果 + +sparse_policy.compute_chunked_attention() # 在 FullPolicy 中 + ├── 获取 cpu_block_table + ├── 调用 self.select_blocks() + ├── 加载并计算历史 KV attention + ├── 计算当前 chunk attention (causal) + ├── 合并所有结果 + └── 返回最终输出 +``` + +## 结论 + +SparsePolicy 架构 v4 重构成功完成。所有验收标准均已满足,测试通过。