♻️ refactor: migrate chunked prefill attention to SparsePolicy

Move all chunked prefill attention computation from attention.py to
SparsePolicy.compute_chunked_attention(). This is the v4 architecture
refactoring for sparse attention policies.

Changes:
- Add compute_chunked_attention abstract method to SparsePolicy base
- Add offload_engine parameter to select_blocks for policies needing
  KV access during block selection (see the interface sketch below)
- Implement compute_chunked_attention in FullAttentionPolicy with
  complete ring buffer pipeline logic
- Simplify attention.py to delegate all chunked prefill computation to the policy
- Remove redundant _sync_load_previous_chunks and
  _ring_buffer_pipeline_load methods from Attention class
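
A minimal sketch of the v4 interface these bullets describe; the class and
method names and the argument orders come from the diff below, while type
hints, bodies, and docstrings are assumptions:

    from abc import ABC, abstractmethod

    class SparsePolicy(ABC):
        @abstractmethod
        def select_blocks(self, cpu_block_table, offload_engine, policy_ctx):
            # Return the subset of CPU block ids to attend to. offload_engine
            # is the new parameter: it gives policies KV access during selection.
            ...

        @abstractmethod
        def compute_chunked_attention(self, q, k, v, layer_id, scale,
                                      offload_engine, kvcache_manager,
                                      current_chunk_idx, seq, num_tokens):
            # Own the whole chunked-prefill pipeline: load historical KV,
            # attend (historical non-causal, current chunk causal), merge,
            # and return the final output for this chunk.
            ...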

Test: test_needle.py --enable-offload PASSED

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Author: Zijie Tian
Date:   2026-01-20 00:58:46 +08:00
parent 6783a45e6f
commit baa4be7e2e
4 changed files with 240 additions and 297 deletions


@@ -174,123 +174,45 @@ class Attention(nn.Module):
"""
Compute attention with per-layer prefill buffer for async offload.
- Optimized design:
- - Current chunk's KV is written to per-layer prefill buffer (not GPU slot)
- - Previous chunks' KV are loaded from CPU using GPU slots
- - Each layer offloads from its own buffer - no waiting required!
+ Simplified design:
+ - All computation logic is delegated to sparse_policy.compute_chunked_attention()
+ - This method only handles async offload after computation
- For each layer:
- 1. Current chunk's KV is in prefill_buffer[layer_id] (just written by model)
- 2. Load previous chunks from CPU using available slots (pipeline)
- 3. Compute attention against previous KV (no causal mask)
- 4. Compute attention against current KV from prefill buffer (causal)
- 5. Merge all results using online softmax
- 6. Async offload prefill buffer to CPU (no waiting!)
+ The policy handles:
+ 1. Loading historical blocks from CPU
+ 2. Computing attention against historical KV (no causal mask)
+ 3. Computing attention against current KV from prefill buffer (causal)
+ 4. Merging all results
"""
- from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
current_chunk_idx = context.current_chunk_idx
torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
- # q shape: [total_tokens, num_heads, head_dim]
- q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
num_tokens = k.shape[0]
- o_acc = None
- lse_acc = None
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
offload_engine = kvcache_manager.offload_engine if kvcache_manager is not None else None
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
- # Get prefilled CPU blocks (blocks from previous chunks)
- cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
+ # Get sparse policy - required for chunked prefill
+ sparse_policy = kvcache_manager.sparse_policy
+ if sparse_policy is None:
+     raise RuntimeError("sparse_policy is required for chunked prefill")
- # Apply sparse policy if enabled
- sparse_policy = kvcache_manager.sparse_policy
+ # [DEBUG] Verify execution path
+ logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_attention, "
+     f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")
- # === All sparse policies use select_blocks interface ===
- if cpu_block_table and sparse_policy is not None:
- num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
- policy_ctx = PolicyContext(
- query_chunk_idx=current_chunk_idx,
- num_query_chunks=num_chunks,
- layer_id=self.layer_id,
- query=None, # Prefill typically doesn't use query for selection
- is_prefill=True,
- block_size=kvcache_manager.block_size,
- total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
- )
- cpu_block_table = sparse_policy.select_blocks(
- cpu_block_table, policy_ctx
- )
- if cpu_block_table:
- # Get available load slots (all slots can be used since we use prefill buffer)
- load_slots = list(range(offload_engine.num_ring_slots))
- pipeline_depth = len(load_slots)
- if pipeline_depth == 0:
- # Only 1 slot total, cannot pipeline - use sync loading
- o_acc, lse_acc = self._sync_load_previous_chunks(
- q_batched, cpu_block_table, offload_engine
- )
- else:
- # Use ring buffer pipeline
- o_acc, lse_acc = self._ring_buffer_pipeline_load(
- q_batched, cpu_block_table, load_slots, offload_engine,
- current_chunk_idx
- )
- # Get compute stream for all attention operations
- compute_stream = offload_engine.compute_stream if offload_engine is not None else None
- # Compute attention against current chunk's KV from prefill buffer (with causal mask)
- needs_current_chunk_attention = True
- if needs_current_chunk_attention:
- if compute_stream is not None:
- with torch.cuda.stream(compute_stream):
- torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
- # Get KV from per-layer prefill buffer
- k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens)
- current_o, current_lse = flash_attn_with_lse(
- q_batched,
- k_batched,
- v_batched,
- softmax_scale=self.scale,
- causal=True,
- )
- torch.cuda.nvtx.range_pop()
- else:
- torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
- k_batched = k.unsqueeze(0)
- v_batched = v.unsqueeze(0)
- current_o, current_lse = flash_attn_with_lse(
- q_batched,
- k_batched,
- v_batched,
- softmax_scale=self.scale,
- causal=True,
- )
- torch.cuda.nvtx.range_pop()
- # Merge with accumulated (all on compute_stream for consistency)
- if o_acc is None:
- # No accumulated attention (no historical chunks processed)
- final_o = current_o
- else:
- # Has accumulated attention (historical chunks processed)
- if compute_stream is not None:
- with torch.cuda.stream(compute_stream):
- torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
- final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
- torch.cuda.nvtx.range_pop()
- else:
- torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
- final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
- torch.cuda.nvtx.range_pop()
+ # Delegate all computation to policy (no flash_attn or merge calls here!)
+ final_o = sparse_policy.compute_chunked_attention(
+ q, k, v,
+ self.layer_id,
+ self.scale,
+ offload_engine,
+ kvcache_manager,
+ current_chunk_idx,
+ seq,
+ num_tokens,
+ )
torch.cuda.nvtx.range_pop() # ChunkedPrefill
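
For reference, the flash_attn_with_lse / merge_attention_outputs pair that the
policy now calls internally performs the standard log-sum-exp merge of partial
attention results. A minimal sketch, assuming flash-attn's usual layouts
(o: [batch, tokens, heads, dim], lse: [batch, heads, tokens]) rather than the
repo's exact implementation:

    import torch

    def merge_attention_outputs(o1, lse1, o2, lse2):
        # Total softmax normalizer across both KV segments.
        lse = torch.logaddexp(lse1, lse2)
        # Rescale each partial output by its share of the softmax mass.
        w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # [b, tokens, heads, 1]
        w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
        return w1 * o1 + w2 * o2, lse
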
@@ -305,181 +227,7 @@ class Attention(nn.Module):
self.layer_id, cpu_block_id, num_tokens
)
- # Sync default stream with compute_stream before returning
- # This ensures the result is ready for the rest of the model (layernorm, MLP)
- if compute_stream is not None:
- torch.cuda.default_stream().wait_stream(compute_stream)
- # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
- return final_o.squeeze(0)
- def _sync_load_previous_chunks(
- self,
- q_batched: torch.Tensor,
- cpu_block_table: list,
- offload_engine,
- ):
- """Synchronous loading fallback when pipeline_depth=0."""
- from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
- o_acc, lse_acc = None, None
- compute_stream = offload_engine.compute_stream
- for block_idx, cpu_block_id in enumerate(cpu_block_table):
- # Load to slot 0 (single slot)
- offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
- offload_engine.wait_slot_layer(0)
- # IMPORTANT: Must use compute_stream to match wait_slot_layer
- with torch.cuda.stream(compute_stream):
- prev_k, prev_v = offload_engine.get_kv_for_slot(0)
- prev_o, prev_lse = flash_attn_with_lse(
- q_batched, prev_k, prev_v,
- softmax_scale=self.scale,
- causal=False,
- )
- if o_acc is None:
- o_acc, lse_acc = prev_o, prev_lse
- else:
- o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
- return o_acc, lse_acc
- def _ring_buffer_pipeline_load(
- self,
- q_batched: torch.Tensor,
- cpu_block_table: list,
- load_slots: list,
- offload_engine,
- current_chunk_idx: int = -1,
- ):
- """
- Ring buffer async pipeline loading with double buffering.
- Uses compute_done events to ensure safe buffer reuse:
- - Before loading to slot X, wait for previous compute on slot X to finish
- - Before computing on slot X, wait for load to slot X to finish
- Timeline with 2 slots (A, B):
- ┌──────────────┐
- │ Load B0→A │
- └──────────────┘
- ┌──────────────┐ ┌──────────────┐
- │ Load B1→B │ │ Load B2→A │ ...
- └──────────────┘ └──────────────┘
- ↘ ↘
- ┌──────────────┐ ┌──────────────┐
- │ Compute(A) │ │ Compute(B) │ ...
- └──────────────┘ └──────────────┘
- The load_to_slot_layer internally waits for compute_done[slot] before
- starting the transfer, ensuring no data race.
- """
- from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
- num_blocks = len(cpu_block_table)
- if num_blocks == 0:
- return None, None
- pipeline_depth = len(load_slots)
- if pipeline_depth == 0:
- return None, None
- o_acc, lse_acc = None, None
- if pipeline_depth == 1:
- # Only 1 slot available, cannot pipeline - use synchronous mode
- # IMPORTANT: Must use compute_stream to match synchronization in
- # load_to_slot_layer (waits for compute_done) and wait_slot_layer
- slot = load_slots[0]
- compute_stream = offload_engine.compute_stream
- for block_idx in range(num_blocks):
- cpu_block_id = cpu_block_table[block_idx]
- offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
- offload_engine.wait_slot_layer(slot)
- with torch.cuda.stream(compute_stream):
- # Debug: call hooks on compute_stream (synchronized with transfer)
- if offload_engine.debug_mode:
- offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
- prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
- prev_o, prev_lse = flash_attn_with_lse(
- q_batched, prev_k, prev_v,
- softmax_scale=self.scale,
- causal=False,
- )
- # Record compute done so next load can safely reuse this slot
- offload_engine.record_slot_compute_done(slot)
- if o_acc is None:
- o_acc, lse_acc = prev_o, prev_lse
- else:
- o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
- return o_acc, lse_acc
- # N-way pipeline: use ALL available slots for maximum overlap
- # Pipeline depth = num_slots - 1 (num_slots blocks in flight)
- num_slots = len(load_slots)
- # Phase 1: Pre-load up to num_slots blocks to fill the pipeline
- # This starts all transfers in parallel, utilizing full PCIe bandwidth
- num_preload = min(num_slots, num_blocks)
- for i in range(num_preload):
- offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])
- # Phase 2: Main loop - compute and immediately reuse slot for next transfer
- # Use dedicated compute_stream (not default stream) to enable overlap with transfers
- compute_stream = offload_engine.compute_stream
- for block_idx in range(num_blocks):
- torch.cuda.nvtx.range_push(f"PipelineBlock: L{self.layer_id} B{block_idx}")
- # Cycle through slots: slot[block_idx % num_slots]
- current_slot = load_slots[block_idx % num_slots]
- cpu_block_id = cpu_block_table[block_idx]
- # Wait for current slot's transfer to complete (on compute_stream)
- offload_engine.wait_slot_layer(current_slot)
- # Compute attention on current slot's data
- # IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
- with torch.cuda.stream(compute_stream):
- # Debug: call hooks on compute_stream (synchronized with transfer)
- if offload_engine.debug_mode:
- offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id)
- torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
- prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
- prev_o, prev_lse = flash_attn_with_lse(
- q_batched, prev_k, prev_v,
- softmax_scale=self.scale,
- causal=False,
- )
- torch.cuda.nvtx.range_pop()
- # Record compute done - this allows the next transfer to safely overwrite this slot
- offload_engine.record_slot_compute_done(current_slot)
- # Immediately start loading the NEXT block into this slot (if more blocks remain)
- # Key insight: reuse current_slot immediately after compute is done!
- next_block_idx = block_idx + num_slots
- if next_block_idx < num_blocks:
- offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])
- # Merge with accumulated (also on compute_stream for consistency)
- with torch.cuda.stream(compute_stream):
- if o_acc is None:
- o_acc, lse_acc = prev_o, prev_lse
- else:
- o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
- torch.cuda.nvtx.range_pop() # PipelineBlock
- return o_acc, lse_acc
+ return final_o
def _chunked_decode_attention(
self,
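
The _ring_buffer_pipeline_load logic removed above is what
FullAttentionPolicy.compute_chunked_attention now carries. Condensed to its
core, the N-way slot-cycling pattern looks like this (a sketch reusing the
engine calls visible in the removed code; NVTX ranges, debug hooks, and the
single-slot fallback are omitted):

    import torch
    from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

    def pipeline_attend(q_batched, blocks, slots, engine, layer_id, scale):
        o_acc, lse_acc = None, None
        # Phase 1: start one transfer per slot to fill the pipeline.
        for i in range(min(len(slots), len(blocks))):
            engine.load_to_slot_layer(slots[i], layer_id, blocks[i])
        # Phase 2: wait, compute, then immediately reuse the slot.
        for i, block_id in enumerate(blocks):
            slot = slots[i % len(slots)]
            engine.wait_slot_layer(slot)  # H2D transfer for this slot finished
            with torch.cuda.stream(engine.compute_stream):
                k, v = engine.get_kv_for_slot(slot)
                o, lse = flash_attn_with_lse(q_batched, k, v,
                                             softmax_scale=scale, causal=False)
            engine.record_slot_compute_done(slot)  # slot now safe to overwrite
            if i + len(slots) < len(blocks):       # refill this slot early
                engine.load_to_slot_layer(slot, layer_id, blocks[i + len(slots)])
            with torch.cuda.stream(engine.compute_stream):
                o_acc, lse_acc = (o, lse) if o_acc is None else \
                    merge_attention_outputs(o_acc, lse_acc, o, lse)
        return o_acc, lse_acc
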
@@ -524,6 +272,8 @@ class Attention(nn.Module):
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
last_block_valid_tokens = block_size # Last block was exactly full
+ offload_engine = kvcache_manager.offload_engine
# Apply sparse policy if enabled (Quest does Top-K selection for decode)
sparse_policy = kvcache_manager.sparse_policy
if sparse_policy is not None:
@@ -537,11 +287,9 @@ class Attention(nn.Module):
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = sparse_policy.select_blocks(
- cpu_block_table, policy_ctx
+ cpu_block_table, offload_engine, policy_ctx
)
- offload_engine = kvcache_manager.offload_engine
# Use cross-layer pipeline if active (initialized in model_runner)
if offload_engine.is_pipeline_active():
o_acc, lse_acc = self._decode_with_layer_pipeline(
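
Why select_blocks now takes offload_engine: selection-time policies (the
comment above cites Quest's Top-K decode selection) may need to inspect KV
while scoring blocks. A purely hypothetical illustration building on the
interface sketch in the commit message; get_block_minmax and the scoring
details are assumptions, not this repo's API:

    import torch

    class QuestLikePolicy(SparsePolicy):
        # compute_chunked_attention omitted for brevity.
        def select_blocks(self, cpu_block_table, offload_engine, policy_ctx):
            if policy_ctx.is_prefill or policy_ctx.query is None:
                return cpu_block_table  # no query-based pruning during prefill
            q = policy_ctx.query.mean(dim=0)  # [heads, dim], crude pooling
            scores = []
            for block_id in cpu_block_table:
                # Assumed helper: channelwise min/max of the block's keys.
                k_min, k_max = offload_engine.get_block_minmax(block_id)
                # Quest-style upper bound on q . k for any key in the block.
                scores.append(torch.maximum(q * k_min, q * k_max).sum(-1).amax())
            keep = max(1, len(cpu_block_table) // 4)  # arbitrary Top-K budget
            top = torch.stack(scores).topk(keep).indices.tolist()
            return [cpu_block_table[i] for i in sorted(top)]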