♻️ refactor: migrate chunked prefill attention to SparsePolicy

Move all chunked prefill attention computation from attention.py to
SparsePolicy.compute_chunked_attention(). This is the v4 architecture
refactoring for sparse attention policies.

Changes:
- Add compute_chunked_attention abstract method to the SparsePolicy base class
- Add offload_engine parameter to select_blocks for policies needing
  KV access during block selection
- Implement compute_chunked_attention in FullAttentionPolicy with
  complete ring buffer pipeline logic
- Simplify attention.py to delegate all chunked prefill to policy
- Remove redundant _sync_load_previous_chunks and
  _ring_buffer_pipeline_load methods from Attention class

Test: test_needle.py --enable-offload PASSED

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-20 00:58:46 +08:00
parent 6783a45e6f
commit baa4be7e2e
4 changed files with 240 additions and 297 deletions

View File

@@ -5,12 +5,20 @@ This serves as a baseline and default policy when sparse
attention is not needed.
"""
import logging
import torch
from typing import List, Optional
from typing import List, Optional, TYPE_CHECKING
from .policy import SparsePolicy, PolicyContext
from nanovllm.utils.context import get_context
if TYPE_CHECKING:
from nanovllm.kvcache.offload_engine import OffloadEngine
from nanovllm.kvcache.manager import KVCacheManager
from nanovllm.engine.sequence import Sequence
logger = logging.getLogger(__name__)
class FullAttentionPolicy(SparsePolicy):
"""
@@ -32,30 +40,34 @@ class FullAttentionPolicy(SparsePolicy):
def select_blocks(
    self,
    available_blocks: List[int],
    offload_engine: "OffloadEngine",
    ctx: PolicyContext,
) -> List[int]:
    """Return all blocks - no sparsity.

    The full-attention policy performs no filtering: every candidate
    block is kept.

    Args:
        available_blocks: Candidate block IDs to choose from.
        offload_engine: Not used by this policy; accepted so the
            signature matches policies that need it during selection.
        ctx: Not used by this policy.

    Returns:
        The same ``available_blocks`` list object that was passed in
        (not a copy).
    """
    # No copy is made: the caller's list is handed back as-is.
    return available_blocks
def compute_prefill_attention(
def compute_chunked_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
offload_engine,
offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager",
current_chunk_idx: int,
seq,
seq: "Sequence",
num_tokens: int,
) -> torch.Tensor:
"""
Compute full attention for chunked prefill.
This method handles the complete chunked prefill flow:
1. Load historical blocks from CPU
2. Compute attention to historical chunks
3. Compute attention to current chunk
4. Merge all results
1. Get historical blocks
2. Select blocks via select_blocks
3. Load and compute attention to historical chunks
4. Compute attention to current chunk
5. Merge all results
Args:
q: Query tensor [seq_len, num_heads, head_dim]
@@ -64,22 +76,41 @@ class FullAttentionPolicy(SparsePolicy):
layer_id: Current layer index
softmax_scale: Softmax scaling factor
offload_engine: OffloadEngine for loading blocks
kvcache_manager: KVCacheManager for block management
current_chunk_idx: Current chunk index
seq: ChunkedSequence
seq: Sequence object
num_tokens: Number of tokens in current chunk
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_attention called, "
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
num_tokens = q.shape[0]
o_acc = None
lse_acc = None
compute_stream = offload_engine.compute_stream
# Step 1: Get and load historical blocks
cpu_block_table = seq.kvcache_manager.get_prefilled_cpu_blocks(seq)
# Step 1: Get historical blocks
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Step 2: Apply select_blocks to filter blocks
if cpu_block_table:
num_chunks = current_chunk_idx + 1
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
num_query_chunks=num_chunks,
layer_id=layer_id,
query=None, # Prefill typically doesn't use query for selection
is_prefill=True,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
logger.debug(f"[DEBUG] select_blocks: output={len(cpu_block_table)} blocks")
if cpu_block_table:
load_slots = list(range(offload_engine.num_ring_slots))
@@ -139,7 +170,7 @@ class FullAttentionPolicy(SparsePolicy):
next_cpu_block_id = cpu_block_table[next_block_idx]
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
# Step 2: Compute attention to current chunk (causal mask)
# Step 4: Compute attention to current chunk (causal mask)
with torch.cuda.stream(compute_stream):
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
current_o, current_lse = flash_attn_with_lse(
@@ -148,7 +179,7 @@ class FullAttentionPolicy(SparsePolicy):
causal=True,
)
# Step 3: Merge historical and current attention
# Step 5: Merge historical and current attention
with torch.cuda.stream(compute_stream):
if o_acc is None:
final_o = current_o