♻️ refactor: migrate chunked prefill attention to SparsePolicy

Move all chunked prefill attention computation from attention.py to
SparsePolicy.compute_chunked_attention(). This is the v4 architecture
refactoring for sparse attention policies.

Changes:
- Add compute_chunked_attention abstract method to SparsePolicy base
- Add offload_engine parameter to select_blocks for policies needing
  KV access during block selection
- Implement compute_chunked_attention in FullAttentionPolicy with
  complete ring buffer pipeline logic
- Simplify attention.py to delegate all chunked prefill to the policy
  (see the sketch below)
- Remove redundant _sync_load_previous_chunks and
  _ring_buffer_pipeline_load methods from Attention class
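
For reference, the delegated call in attention.py now reduces to a single policy invocation. The sketch below is illustrative only: attention.py itself is not among the hunks shown, so the attribute names on the Attention class and on the context object (self.policy, self.layer_id, self.scale, context.offload_engine, context.kvcache_manager, context.chunk_idx, context.seq) are assumptions; only the compute_chunked_attention signature matches this commit.

from nanovllm.utils.context import get_context

# Hypothetical sketch of the simplified chunked-prefill path in Attention.
def _chunked_prefill_forward(self, q, k, v):
    context = get_context()
    return self.policy.compute_chunked_attention(
        q, k, v,
        layer_id=self.layer_id,
        softmax_scale=self.scale,
        offload_engine=context.offload_engine,
        kvcache_manager=context.kvcache_manager,
        current_chunk_idx=context.chunk_idx,
        seq=context.seq,
        num_tokens=q.shape[0],
    )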

Test: test_needle.py --enable-offload PASSED

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
author Zijie Tian
date   2026-01-20 00:58:46 +08:00
parent 6783a45e6f
commit baa4be7e2e
4 changed files with 240 additions and 297 deletions


@@ -5,12 +5,20 @@ This serves as a baseline and default policy when sparse
attention is not needed.
"""
import logging
import torch
- from typing import List, Optional
+ from typing import List, Optional, TYPE_CHECKING
from .policy import SparsePolicy, PolicyContext
from nanovllm.utils.context import get_context
if TYPE_CHECKING:
from nanovllm.kvcache.offload_engine import OffloadEngine
from nanovllm.kvcache.manager import KVCacheManager
from nanovllm.engine.sequence import Sequence
logger = logging.getLogger(__name__)
class FullAttentionPolicy(SparsePolicy):
"""
@@ -32,30 +40,34 @@ class FullAttentionPolicy(SparsePolicy):
def select_blocks(
self,
available_blocks: List[int],
offload_engine: "OffloadEngine",
ctx: PolicyContext,
) -> List[int]:
"""Return all blocks - no sparsity."""
return available_blocks
- def compute_prefill_attention(
+ def compute_chunked_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
- offload_engine,
+ offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager",
current_chunk_idx: int,
- seq,
+ seq: "Sequence",
num_tokens: int,
) -> torch.Tensor:
"""
Compute full attention for chunked prefill.
This method handles the complete chunked prefill flow:
- 1. Load historical blocks from CPU
- 2. Compute attention to historical chunks
- 3. Compute attention to current chunk
- 4. Merge all results
+ 1. Get historical blocks
+ 2. Select blocks via select_blocks
+ 3. Load and compute attention to historical chunks
+ 4. Compute attention to current chunk
+ 5. Merge all results
Args:
q: Query tensor [seq_len, num_heads, head_dim]
@@ -64,22 +76,41 @@ class FullAttentionPolicy(SparsePolicy):
layer_id: Current layer index
softmax_scale: Softmax scaling factor
offload_engine: OffloadEngine for loading blocks
kvcache_manager: KVCacheManager for block management
current_chunk_idx: Current chunk index
- seq: ChunkedSequence
+ seq: Sequence object
num_tokens: Number of tokens in current chunk
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_attention called, "
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
num_tokens = q.shape[0]
o_acc = None
lse_acc = None
compute_stream = offload_engine.compute_stream
- # Step 1: Get and load historical blocks
- cpu_block_table = seq.kvcache_manager.get_prefilled_cpu_blocks(seq)
+ # Step 1: Get historical blocks
+ cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Step 2: Apply select_blocks to filter blocks
if cpu_block_table:
num_chunks = current_chunk_idx + 1
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
num_query_chunks=num_chunks,
layer_id=layer_id,
query=None, # Prefill typically doesn't use query for selection
is_prefill=True,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
logger.debug(f"[DEBUG] select_blocks: output={len(cpu_block_table)} blocks")
if cpu_block_table:
load_slots = list(range(offload_engine.num_ring_slots))
@@ -139,7 +170,7 @@ class FullAttentionPolicy(SparsePolicy):
next_cpu_block_id = cpu_block_table[next_block_idx]
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
- # Step 2: Compute attention to current chunk (causal mask)
+ # Step 4: Compute attention to current chunk (causal mask)
with torch.cuda.stream(compute_stream):
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
current_o, current_lse = flash_attn_with_lse(
@@ -148,7 +179,7 @@ class FullAttentionPolicy(SparsePolicy):
causal=True,
)
- # Step 3: Merge historical and current attention
+ # Step 5: Merge historical and current attention
with torch.cuda.stream(compute_stream):
if o_acc is None:
final_o = current_o
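
Step 5 works because each partial attention result carries its log-sum-exp (LSE): partial outputs computed over disjoint KV ranges can be merged exactly, one ring-buffer slot at a time. A minimal sketch of the merge rule follows; the [num_heads, seq_len, head_dim] layout and the two-way helper name merge_two are assumptions for illustration, not the actual merge_attention_outputs implementation.

import torch

def merge_two(o_a, lse_a, o_b, lse_b):
    # Merge two partial attention outputs computed over disjoint KV ranges.
    # o_*:   [num_heads, seq_len, head_dim] partial outputs (layout assumed)
    # lse_*: [num_heads, seq_len] log-sum-exp of the attention logits
    lse = torch.logaddexp(lse_a, lse_b)          # normalizer over the combined KV
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)   # weight of the first partial result
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)   # weight of the second partial result
    return w_a * o_a + w_b * o_b, lse

Because this rule is associative, the loop can fold each historical chunk into (o_acc, lse_acc) as it is loaded, and finally fold in the current chunk's causal output.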


@@ -7,12 +7,17 @@ from CPU for each query chunk during chunked attention computation.
from abc import ABC, abstractmethod
from dataclasses import dataclass
- from typing import List, Optional, Any
+ from typing import List, Optional, Any, TYPE_CHECKING
import torch
# Import SparsePolicyType from config to avoid circular imports
from nanovllm.config import SparsePolicyType
if TYPE_CHECKING:
from nanovllm.kvcache.offload_engine import OffloadEngine
from nanovllm.kvcache.manager import KVCacheManager
from nanovllm.engine.sequence import Sequence
@dataclass
class PolicyContext:
@@ -107,6 +112,7 @@ class SparsePolicy(ABC):
def select_blocks(
self,
available_blocks: List[int],
offload_engine: "OffloadEngine",
ctx: PolicyContext,
) -> List[int]:
"""
@@ -120,6 +126,8 @@ class SparsePolicy(ABC):
available_blocks: List of CPU block IDs that contain KV cache
from previous chunks. These are ordered by
their position in the sequence.
offload_engine: OffloadEngine for loading KV (some policies need
to load KV to make selection decisions).
ctx: PolicyContext with information about the current query
chunk, layer, phase (prefill/decode), etc.
@@ -183,5 +191,47 @@ class SparsePolicy(ABC):
"""
pass
@abstractmethod
def compute_chunked_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager",
current_chunk_idx: int,
seq: "Sequence",
num_tokens: int,
) -> torch.Tensor:
"""
Compute chunked prefill attention (complete flow).
This is the main entry point for prefill attention computation.
It defines the complete prefill flow:
1. Get historical blocks
2. Select blocks (call select_blocks)
3. Load and compute historical blocks via offload_engine
4. Get current chunk KV from offload_engine, compute attention
5. Merge all results
Args:
q: [seq_len, num_heads, head_dim] query for current chunk
k: [seq_len, num_kv_heads, head_dim] key for current chunk (in prefill buffer)
v: [seq_len, num_kv_heads, head_dim] value for current chunk (in prefill buffer)
layer_id: transformer layer index
softmax_scale: softmax scaling factor
offload_engine: OffloadEngine for loading blocks
kvcache_manager: KVCacheManager for block management
current_chunk_idx: current chunk index
seq: Sequence object
num_tokens: number of tokens in current chunk
Returns:
[seq_len, num_heads, head_dim] final attention output
"""
pass
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
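
To show how a policy other than FullAttentionPolicy would plug into this contract, here is a hypothetical subclass that keeps only the most recent historical blocks and inherits compute_chunked_attention unchanged. It is not part of this commit, and the import paths are assumptions (the diff only shows relative imports such as "from .policy import SparsePolicy, PolicyContext").

from typing import List

# Assumed module paths; adjust to wherever these classes live in the repo.
from nanovllm.kvcache.policies.policy import PolicyContext
from nanovllm.kvcache.policies.full_attention import FullAttentionPolicy

class RecentBlocksPolicy(FullAttentionPolicy):
    """Hypothetical example: attend only to the last max_blocks historical blocks."""

    def __init__(self, max_blocks: int = 8):
        self.max_blocks = max_blocks

    def select_blocks(
        self,
        available_blocks: List[int],
        offload_engine: "OffloadEngine",
        ctx: PolicyContext,
    ) -> List[int]:
        # offload_engine is unused here; policies that score blocks by their KV
        # contents would use it to load blocks before deciding which to keep.
        return available_blocks[-self.max_blocks:]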