♻️ refactor: migrate chunked prefill attention to SparsePolicy
Move all chunked prefill attention computation from attention.py to SparsePolicy.compute_chunked_attention(). This is the v4 architecture refactoring for sparse attention policies. Changes: - Add compute_chunked_attention abstract method to SparsePolicy base - Add offload_engine parameter to select_blocks for policies needing KV access during block selection - Implement compute_chunked_attention in FullAttentionPolicy with complete ring buffer pipeline logic - Simplify attention.py to delegate all chunked prefill to policy - Remove redundant _sync_load_previous_chunks and _ring_buffer_pipeline_load methods from Attention class Test: test_needle.py --enable-offload PASSED Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -5,12 +5,20 @@ This serves as a baseline and default policy when sparse
|
|||||||
attention is not needed.
|
attention is not needed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import torch
|
import torch
|
||||||
from typing import List, Optional
|
from typing import List, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
from .policy import SparsePolicy, PolicyContext
|
from .policy import SparsePolicy, PolicyContext
|
||||||
from nanovllm.utils.context import get_context
|
from nanovllm.utils.context import get_context
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from nanovllm.kvcache.offload_engine import OffloadEngine
|
||||||
|
from nanovllm.kvcache.manager import KVCacheManager
|
||||||
|
from nanovllm.engine.sequence import Sequence
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FullAttentionPolicy(SparsePolicy):
|
class FullAttentionPolicy(SparsePolicy):
|
||||||
"""
|
"""
|
||||||
@@ -32,30 +40,34 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
def select_blocks(
|
def select_blocks(
|
||||||
self,
|
self,
|
||||||
available_blocks: List[int],
|
available_blocks: List[int],
|
||||||
|
offload_engine: "OffloadEngine",
|
||||||
ctx: PolicyContext,
|
ctx: PolicyContext,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""Return all blocks - no sparsity."""
|
"""Return all blocks - no sparsity."""
|
||||||
return available_blocks
|
return available_blocks
|
||||||
|
|
||||||
def compute_prefill_attention(
|
def compute_chunked_attention(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
softmax_scale: float,
|
softmax_scale: float,
|
||||||
offload_engine,
|
offload_engine: "OffloadEngine",
|
||||||
|
kvcache_manager: "KVCacheManager",
|
||||||
current_chunk_idx: int,
|
current_chunk_idx: int,
|
||||||
seq,
|
seq: "Sequence",
|
||||||
|
num_tokens: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute full attention for chunked prefill.
|
Compute full attention for chunked prefill.
|
||||||
|
|
||||||
This method handles the complete chunked prefill flow:
|
This method handles the complete chunked prefill flow:
|
||||||
1. Load historical blocks from CPU
|
1. Get historical blocks
|
||||||
2. Compute attention to historical chunks
|
2. Select blocks via select_blocks
|
||||||
3. Compute attention to current chunk
|
3. Load and compute attention to historical chunks
|
||||||
4. Merge all results
|
4. Compute attention to current chunk
|
||||||
|
5. Merge all results
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
q: Query tensor [seq_len, num_heads, head_dim]
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
@@ -64,22 +76,41 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
layer_id: Current layer index
|
layer_id: Current layer index
|
||||||
softmax_scale: Softmax scaling factor
|
softmax_scale: Softmax scaling factor
|
||||||
offload_engine: OffloadEngine for loading blocks
|
offload_engine: OffloadEngine for loading blocks
|
||||||
|
kvcache_manager: KVCacheManager for block management
|
||||||
current_chunk_idx: Current chunk index
|
current_chunk_idx: Current chunk index
|
||||||
seq: ChunkedSequence
|
seq: Sequence object
|
||||||
|
num_tokens: Number of tokens in current chunk
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Attention output [seq_len, num_heads, head_dim]
|
Attention output [seq_len, num_heads, head_dim]
|
||||||
"""
|
"""
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
|
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_attention called, "
|
||||||
|
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
|
||||||
|
|
||||||
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
|
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
|
||||||
num_tokens = q.shape[0]
|
|
||||||
o_acc = None
|
o_acc = None
|
||||||
lse_acc = None
|
lse_acc = None
|
||||||
compute_stream = offload_engine.compute_stream
|
compute_stream = offload_engine.compute_stream
|
||||||
|
|
||||||
# Step 1: Get and load historical blocks
|
# Step 1: Get historical blocks
|
||||||
cpu_block_table = seq.kvcache_manager.get_prefilled_cpu_blocks(seq)
|
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||||
|
|
||||||
|
# Step 2: Apply select_blocks to filter blocks
|
||||||
|
if cpu_block_table:
|
||||||
|
num_chunks = current_chunk_idx + 1
|
||||||
|
policy_ctx = PolicyContext(
|
||||||
|
query_chunk_idx=current_chunk_idx,
|
||||||
|
num_query_chunks=num_chunks,
|
||||||
|
layer_id=layer_id,
|
||||||
|
query=None, # Prefill typically doesn't use query for selection
|
||||||
|
is_prefill=True,
|
||||||
|
block_size=kvcache_manager.block_size,
|
||||||
|
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||||
|
)
|
||||||
|
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
|
||||||
|
logger.debug(f"[DEBUG] select_blocks: output={len(cpu_block_table)} blocks")
|
||||||
|
|
||||||
if cpu_block_table:
|
if cpu_block_table:
|
||||||
load_slots = list(range(offload_engine.num_ring_slots))
|
load_slots = list(range(offload_engine.num_ring_slots))
|
||||||
@@ -139,7 +170,7 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
next_cpu_block_id = cpu_block_table[next_block_idx]
|
next_cpu_block_id = cpu_block_table[next_block_idx]
|
||||||
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
|
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
|
||||||
|
|
||||||
# Step 2: Compute attention to current chunk (causal mask)
|
# Step 4: Compute attention to current chunk (causal mask)
|
||||||
with torch.cuda.stream(compute_stream):
|
with torch.cuda.stream(compute_stream):
|
||||||
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
|
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
|
||||||
current_o, current_lse = flash_attn_with_lse(
|
current_o, current_lse = flash_attn_with_lse(
|
||||||
@@ -148,7 +179,7 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 3: Merge historical and current attention
|
# Step 5: Merge historical and current attention
|
||||||
with torch.cuda.stream(compute_stream):
|
with torch.cuda.stream(compute_stream):
|
||||||
if o_acc is None:
|
if o_acc is None:
|
||||||
final_o = current_o
|
final_o = current_o
|
||||||
|
|||||||
@@ -7,12 +7,17 @@ from CPU for each query chunk during chunked attention computation.
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Any
|
from typing import List, Optional, Any, TYPE_CHECKING
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Import SparsePolicyType from config to avoid circular imports
|
# Import SparsePolicyType from config to avoid circular imports
|
||||||
from nanovllm.config import SparsePolicyType
|
from nanovllm.config import SparsePolicyType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from nanovllm.kvcache.offload_engine import OffloadEngine
|
||||||
|
from nanovllm.kvcache.manager import KVCacheManager
|
||||||
|
from nanovllm.engine.sequence import Sequence
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PolicyContext:
|
class PolicyContext:
|
||||||
@@ -107,6 +112,7 @@ class SparsePolicy(ABC):
|
|||||||
def select_blocks(
|
def select_blocks(
|
||||||
self,
|
self,
|
||||||
available_blocks: List[int],
|
available_blocks: List[int],
|
||||||
|
offload_engine: "OffloadEngine",
|
||||||
ctx: PolicyContext,
|
ctx: PolicyContext,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""
|
"""
|
||||||
@@ -120,6 +126,8 @@ class SparsePolicy(ABC):
|
|||||||
available_blocks: List of CPU block IDs that contain KV cache
|
available_blocks: List of CPU block IDs that contain KV cache
|
||||||
from previous chunks. These are ordered by
|
from previous chunks. These are ordered by
|
||||||
their position in the sequence.
|
their position in the sequence.
|
||||||
|
offload_engine: OffloadEngine for loading KV (some policies need
|
||||||
|
to load KV to make selection decisions).
|
||||||
ctx: PolicyContext with information about the current query
|
ctx: PolicyContext with information about the current query
|
||||||
chunk, layer, phase (prefill/decode), etc.
|
chunk, layer, phase (prefill/decode), etc.
|
||||||
|
|
||||||
@@ -183,5 +191,47 @@ class SparsePolicy(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_chunked_attention(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
offload_engine: "OffloadEngine",
|
||||||
|
kvcache_manager: "KVCacheManager",
|
||||||
|
current_chunk_idx: int,
|
||||||
|
seq: "Sequence",
|
||||||
|
num_tokens: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute chunked prefill attention (complete flow).
|
||||||
|
|
||||||
|
This is the main entry point for prefill attention computation.
|
||||||
|
It defines the complete prefill flow:
|
||||||
|
1. Get historical blocks
|
||||||
|
2. Select blocks (call select_blocks)
|
||||||
|
3. Load and compute historical blocks via offload_engine
|
||||||
|
4. Get current chunk KV from offload_engine, compute attention
|
||||||
|
5. Merge all results
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: [seq_len, num_heads, head_dim] query for current chunk
|
||||||
|
k: [seq_len, num_kv_heads, head_dim] key for current chunk (in prefill buffer)
|
||||||
|
v: [seq_len, num_kv_heads, head_dim] value for current chunk (in prefill buffer)
|
||||||
|
layer_id: transformer layer index
|
||||||
|
softmax_scale: softmax scaling factor
|
||||||
|
offload_engine: OffloadEngine for loading blocks
|
||||||
|
kvcache_manager: KVCacheManager for block management
|
||||||
|
current_chunk_idx: current chunk index
|
||||||
|
seq: Sequence object
|
||||||
|
num_tokens: number of tokens in current chunk
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[seq_len, num_heads, head_dim] final attention output
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"{self.__class__.__name__}()"
|
return f"{self.__class__.__name__}()"
|
||||||
|
|||||||
@@ -174,123 +174,45 @@ class Attention(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Compute attention with per-layer prefill buffer for async offload.
|
Compute attention with per-layer prefill buffer for async offload.
|
||||||
|
|
||||||
Optimized design:
|
Simplified design:
|
||||||
- Current chunk's KV is written to per-layer prefill buffer (not GPU slot)
|
- All computation logic is delegated to sparse_policy.compute_chunked_attention()
|
||||||
- Previous chunks' KV are loaded from CPU using GPU slots
|
- This method only handles async offload after computation
|
||||||
- Each layer offloads from its own buffer - no waiting required!
|
|
||||||
|
|
||||||
For each layer:
|
The policy handles:
|
||||||
1. Current chunk's KV is in prefill_buffer[layer_id] (just written by model)
|
1. Loading historical blocks from CPU
|
||||||
2. Load previous chunks from CPU using available slots (pipeline)
|
2. Computing attention against historical KV (no causal mask)
|
||||||
3. Compute attention against previous KV (no causal mask)
|
3. Computing attention against current KV from prefill buffer (causal)
|
||||||
4. Compute attention against current KV from prefill buffer (causal)
|
4. Merging all results
|
||||||
5. Merge all results using online softmax
|
|
||||||
6. Async offload prefill buffer to CPU (no waiting!)
|
|
||||||
"""
|
"""
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
|
||||||
|
|
||||||
current_chunk_idx = context.current_chunk_idx
|
current_chunk_idx = context.current_chunk_idx
|
||||||
torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
|
torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
|
||||||
|
|
||||||
# q shape: [total_tokens, num_heads, head_dim]
|
|
||||||
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
|
|
||||||
num_tokens = k.shape[0]
|
num_tokens = k.shape[0]
|
||||||
|
|
||||||
o_acc = None
|
|
||||||
lse_acc = None
|
|
||||||
|
|
||||||
kvcache_manager = context.kvcache_manager
|
kvcache_manager = context.kvcache_manager
|
||||||
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
|
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
|
||||||
offload_engine = kvcache_manager.offload_engine if kvcache_manager is not None else None
|
offload_engine = kvcache_manager.offload_engine if kvcache_manager is not None else None
|
||||||
|
|
||||||
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
|
# Get sparse policy - required for chunked prefill
|
||||||
# Get prefilled CPU blocks (blocks from previous chunks)
|
sparse_policy = kvcache_manager.sparse_policy
|
||||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
if sparse_policy is None:
|
||||||
|
raise RuntimeError("sparse_policy is required for chunked prefill")
|
||||||
|
|
||||||
# Apply sparse policy if enabled
|
# [DEBUG] Verify execution path
|
||||||
sparse_policy = kvcache_manager.sparse_policy
|
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_attention, "
|
||||||
|
f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")
|
||||||
|
|
||||||
# === All sparse policies use select_blocks interface ===
|
# Delegate all computation to policy (no flash_attn or merge calls here!)
|
||||||
if cpu_block_table and sparse_policy is not None:
|
final_o = sparse_policy.compute_chunked_attention(
|
||||||
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
|
q, k, v,
|
||||||
policy_ctx = PolicyContext(
|
self.layer_id,
|
||||||
query_chunk_idx=current_chunk_idx,
|
self.scale,
|
||||||
num_query_chunks=num_chunks,
|
offload_engine,
|
||||||
layer_id=self.layer_id,
|
kvcache_manager,
|
||||||
query=None, # Prefill typically doesn't use query for selection
|
current_chunk_idx,
|
||||||
is_prefill=True,
|
seq,
|
||||||
block_size=kvcache_manager.block_size,
|
num_tokens,
|
||||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
)
|
||||||
)
|
|
||||||
cpu_block_table = sparse_policy.select_blocks(
|
|
||||||
cpu_block_table, policy_ctx
|
|
||||||
)
|
|
||||||
|
|
||||||
if cpu_block_table:
|
|
||||||
# Get available load slots (all slots can be used since we use prefill buffer)
|
|
||||||
load_slots = list(range(offload_engine.num_ring_slots))
|
|
||||||
pipeline_depth = len(load_slots)
|
|
||||||
|
|
||||||
if pipeline_depth == 0:
|
|
||||||
# Only 1 slot total, cannot pipeline - use sync loading
|
|
||||||
o_acc, lse_acc = self._sync_load_previous_chunks(
|
|
||||||
q_batched, cpu_block_table, offload_engine
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Use ring buffer pipeline
|
|
||||||
o_acc, lse_acc = self._ring_buffer_pipeline_load(
|
|
||||||
q_batched, cpu_block_table, load_slots, offload_engine,
|
|
||||||
current_chunk_idx
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get compute stream for all attention operations
|
|
||||||
compute_stream = offload_engine.compute_stream if offload_engine is not None else None
|
|
||||||
|
|
||||||
# Compute attention against current chunk's KV from prefill buffer (with causal mask)
|
|
||||||
needs_current_chunk_attention = True
|
|
||||||
|
|
||||||
if needs_current_chunk_attention:
|
|
||||||
if compute_stream is not None:
|
|
||||||
with torch.cuda.stream(compute_stream):
|
|
||||||
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
|
||||||
# Get KV from per-layer prefill buffer
|
|
||||||
k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens)
|
|
||||||
current_o, current_lse = flash_attn_with_lse(
|
|
||||||
q_batched,
|
|
||||||
k_batched,
|
|
||||||
v_batched,
|
|
||||||
softmax_scale=self.scale,
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
torch.cuda.nvtx.range_pop()
|
|
||||||
else:
|
|
||||||
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
|
||||||
k_batched = k.unsqueeze(0)
|
|
||||||
v_batched = v.unsqueeze(0)
|
|
||||||
current_o, current_lse = flash_attn_with_lse(
|
|
||||||
q_batched,
|
|
||||||
k_batched,
|
|
||||||
v_batched,
|
|
||||||
softmax_scale=self.scale,
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
torch.cuda.nvtx.range_pop()
|
|
||||||
|
|
||||||
# Merge with accumulated (all on compute_stream for consistency)
|
|
||||||
if o_acc is None:
|
|
||||||
# No accumulated attention (no historical chunks processed)
|
|
||||||
final_o = current_o
|
|
||||||
else:
|
|
||||||
# Has accumulated attention (historical chunks processed)
|
|
||||||
if compute_stream is not None:
|
|
||||||
with torch.cuda.stream(compute_stream):
|
|
||||||
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
|
|
||||||
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
|
|
||||||
torch.cuda.nvtx.range_pop()
|
|
||||||
else:
|
|
||||||
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
|
|
||||||
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
|
|
||||||
torch.cuda.nvtx.range_pop()
|
|
||||||
|
|
||||||
torch.cuda.nvtx.range_pop() # ChunkedPrefill
|
torch.cuda.nvtx.range_pop() # ChunkedPrefill
|
||||||
|
|
||||||
@@ -305,181 +227,7 @@ class Attention(nn.Module):
|
|||||||
self.layer_id, cpu_block_id, num_tokens
|
self.layer_id, cpu_block_id, num_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sync default stream with compute_stream before returning
|
return final_o
|
||||||
# This ensures the result is ready for the rest of the model (layernorm, MLP)
|
|
||||||
if compute_stream is not None:
|
|
||||||
torch.cuda.default_stream().wait_stream(compute_stream)
|
|
||||||
|
|
||||||
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
|
|
||||||
return final_o.squeeze(0)
|
|
||||||
|
|
||||||
def _sync_load_previous_chunks(
|
|
||||||
self,
|
|
||||||
q_batched: torch.Tensor,
|
|
||||||
cpu_block_table: list,
|
|
||||||
offload_engine,
|
|
||||||
):
|
|
||||||
"""Synchronous loading fallback when pipeline_depth=0."""
|
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
|
||||||
|
|
||||||
o_acc, lse_acc = None, None
|
|
||||||
compute_stream = offload_engine.compute_stream
|
|
||||||
|
|
||||||
for block_idx, cpu_block_id in enumerate(cpu_block_table):
|
|
||||||
# Load to slot 0 (single slot)
|
|
||||||
offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
|
|
||||||
offload_engine.wait_slot_layer(0)
|
|
||||||
|
|
||||||
# IMPORTANT: Must use compute_stream to match wait_slot_layer
|
|
||||||
with torch.cuda.stream(compute_stream):
|
|
||||||
prev_k, prev_v = offload_engine.get_kv_for_slot(0)
|
|
||||||
|
|
||||||
prev_o, prev_lse = flash_attn_with_lse(
|
|
||||||
q_batched, prev_k, prev_v,
|
|
||||||
softmax_scale=self.scale,
|
|
||||||
causal=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if o_acc is None:
|
|
||||||
o_acc, lse_acc = prev_o, prev_lse
|
|
||||||
else:
|
|
||||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
|
||||||
|
|
||||||
return o_acc, lse_acc
|
|
||||||
|
|
||||||
def _ring_buffer_pipeline_load(
|
|
||||||
self,
|
|
||||||
q_batched: torch.Tensor,
|
|
||||||
cpu_block_table: list,
|
|
||||||
load_slots: list,
|
|
||||||
offload_engine,
|
|
||||||
current_chunk_idx: int = -1,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Ring buffer async pipeline loading with double buffering.
|
|
||||||
|
|
||||||
Uses compute_done events to ensure safe buffer reuse:
|
|
||||||
- Before loading to slot X, wait for previous compute on slot X to finish
|
|
||||||
- Before computing on slot X, wait for load to slot X to finish
|
|
||||||
|
|
||||||
Timeline with 2 slots (A, B):
|
|
||||||
┌──────────────┐
|
|
||||||
│ Load B0→A │
|
|
||||||
└──────────────┘
|
|
||||||
┌──────────────┐ ┌──────────────┐
|
|
||||||
│ Load B1→B │ │ Load B2→A │ ...
|
|
||||||
└──────────────┘ └──────────────┘
|
|
||||||
↘ ↘
|
|
||||||
┌──────────────┐ ┌──────────────┐
|
|
||||||
│ Compute(A) │ │ Compute(B) │ ...
|
|
||||||
└──────────────┘ └──────────────┘
|
|
||||||
|
|
||||||
The load_to_slot_layer internally waits for compute_done[slot] before
|
|
||||||
starting the transfer, ensuring no data race.
|
|
||||||
"""
|
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
|
||||||
|
|
||||||
num_blocks = len(cpu_block_table)
|
|
||||||
if num_blocks == 0:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
pipeline_depth = len(load_slots)
|
|
||||||
if pipeline_depth == 0:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
o_acc, lse_acc = None, None
|
|
||||||
|
|
||||||
if pipeline_depth == 1:
|
|
||||||
# Only 1 slot available, cannot pipeline - use synchronous mode
|
|
||||||
# IMPORTANT: Must use compute_stream to match synchronization in
|
|
||||||
# load_to_slot_layer (waits for compute_done) and wait_slot_layer
|
|
||||||
slot = load_slots[0]
|
|
||||||
compute_stream = offload_engine.compute_stream
|
|
||||||
for block_idx in range(num_blocks):
|
|
||||||
cpu_block_id = cpu_block_table[block_idx]
|
|
||||||
offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
|
|
||||||
offload_engine.wait_slot_layer(slot)
|
|
||||||
|
|
||||||
with torch.cuda.stream(compute_stream):
|
|
||||||
# Debug: call hooks on compute_stream (synchronized with transfer)
|
|
||||||
if offload_engine.debug_mode:
|
|
||||||
offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
|
|
||||||
|
|
||||||
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
|
|
||||||
|
|
||||||
prev_o, prev_lse = flash_attn_with_lse(
|
|
||||||
q_batched, prev_k, prev_v,
|
|
||||||
softmax_scale=self.scale,
|
|
||||||
causal=False,
|
|
||||||
)
|
|
||||||
# Record compute done so next load can safely reuse this slot
|
|
||||||
offload_engine.record_slot_compute_done(slot)
|
|
||||||
if o_acc is None:
|
|
||||||
o_acc, lse_acc = prev_o, prev_lse
|
|
||||||
else:
|
|
||||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
|
||||||
return o_acc, lse_acc
|
|
||||||
|
|
||||||
# N-way pipeline: use ALL available slots for maximum overlap
|
|
||||||
# Pipeline depth = num_slots - 1 (num_slots blocks in flight)
|
|
||||||
num_slots = len(load_slots)
|
|
||||||
|
|
||||||
# Phase 1: Pre-load up to num_slots blocks to fill the pipeline
|
|
||||||
# This starts all transfers in parallel, utilizing full PCIe bandwidth
|
|
||||||
num_preload = min(num_slots, num_blocks)
|
|
||||||
for i in range(num_preload):
|
|
||||||
offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])
|
|
||||||
|
|
||||||
# Phase 2: Main loop - compute and immediately reuse slot for next transfer
|
|
||||||
# Use dedicated compute_stream (not default stream) to enable overlap with transfers
|
|
||||||
compute_stream = offload_engine.compute_stream
|
|
||||||
|
|
||||||
for block_idx in range(num_blocks):
|
|
||||||
torch.cuda.nvtx.range_push(f"PipelineBlock: L{self.layer_id} B{block_idx}")
|
|
||||||
|
|
||||||
# Cycle through slots: slot[block_idx % num_slots]
|
|
||||||
current_slot = load_slots[block_idx % num_slots]
|
|
||||||
cpu_block_id = cpu_block_table[block_idx]
|
|
||||||
|
|
||||||
# Wait for current slot's transfer to complete (on compute_stream)
|
|
||||||
offload_engine.wait_slot_layer(current_slot)
|
|
||||||
|
|
||||||
# Compute attention on current slot's data
|
|
||||||
# IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
|
|
||||||
with torch.cuda.stream(compute_stream):
|
|
||||||
# Debug: call hooks on compute_stream (synchronized with transfer)
|
|
||||||
if offload_engine.debug_mode:
|
|
||||||
offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id)
|
|
||||||
|
|
||||||
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
|
|
||||||
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
|
|
||||||
|
|
||||||
prev_o, prev_lse = flash_attn_with_lse(
|
|
||||||
q_batched, prev_k, prev_v,
|
|
||||||
softmax_scale=self.scale,
|
|
||||||
causal=False,
|
|
||||||
)
|
|
||||||
torch.cuda.nvtx.range_pop()
|
|
||||||
|
|
||||||
# Record compute done - this allows the next transfer to safely overwrite this slot
|
|
||||||
offload_engine.record_slot_compute_done(current_slot)
|
|
||||||
|
|
||||||
# Immediately start loading the NEXT block into this slot (if more blocks remain)
|
|
||||||
# Key insight: reuse current_slot immediately after compute is done!
|
|
||||||
next_block_idx = block_idx + num_slots
|
|
||||||
if next_block_idx < num_blocks:
|
|
||||||
offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])
|
|
||||||
|
|
||||||
# Merge with accumulated (also on compute_stream for consistency)
|
|
||||||
with torch.cuda.stream(compute_stream):
|
|
||||||
if o_acc is None:
|
|
||||||
o_acc, lse_acc = prev_o, prev_lse
|
|
||||||
else:
|
|
||||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
|
||||||
|
|
||||||
torch.cuda.nvtx.range_pop() # PipelineBlock
|
|
||||||
|
|
||||||
return o_acc, lse_acc
|
|
||||||
|
|
||||||
def _chunked_decode_attention(
|
def _chunked_decode_attention(
|
||||||
self,
|
self,
|
||||||
@@ -524,6 +272,8 @@ class Attention(nn.Module):
|
|||||||
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
||||||
last_block_valid_tokens = block_size # Last block was exactly full
|
last_block_valid_tokens = block_size # Last block was exactly full
|
||||||
|
|
||||||
|
offload_engine = kvcache_manager.offload_engine
|
||||||
|
|
||||||
# Apply sparse policy if enabled (Quest does Top-K selection for decode)
|
# Apply sparse policy if enabled (Quest does Top-K selection for decode)
|
||||||
sparse_policy = kvcache_manager.sparse_policy
|
sparse_policy = kvcache_manager.sparse_policy
|
||||||
if sparse_policy is not None:
|
if sparse_policy is not None:
|
||||||
@@ -537,11 +287,9 @@ class Attention(nn.Module):
|
|||||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||||
)
|
)
|
||||||
cpu_block_table = sparse_policy.select_blocks(
|
cpu_block_table = sparse_policy.select_blocks(
|
||||||
cpu_block_table, policy_ctx
|
cpu_block_table, offload_engine, policy_ctx
|
||||||
)
|
)
|
||||||
|
|
||||||
offload_engine = kvcache_manager.offload_engine
|
|
||||||
|
|
||||||
# Use cross-layer pipeline if active (initialized in model_runner)
|
# Use cross-layer pipeline if active (initialized in model_runner)
|
||||||
if offload_engine.is_pipeline_active():
|
if offload_engine.is_pipeline_active():
|
||||||
o_acc, lse_acc = self._decode_with_layer_pipeline(
|
o_acc, lse_acc = self._decode_with_layer_pipeline(
|
||||||
|
|||||||
114
test_report_sparse_policy_refactor.md
Normal file
114
test_report_sparse_policy_refactor.md
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
# SparsePolicy 重构测试报告
|
||||||
|
|
||||||
|
## 任务概述
|
||||||
|
|
||||||
|
根据 task_plan.md 的要求,对 nanovllm 的 SparsePolicy 架构进行重构(v4 版本),将 chunked prefill attention 计算逻辑从 attention.py 完全迁移到 SparsePolicy。
|
||||||
|
|
||||||
|
## 修改范围
|
||||||
|
|
||||||
|
仅针对 FullPolicy,不涉及 QuestPolicy 或 XAttentionBSAPolicy,不修改 decode 阶段逻辑。
|
||||||
|
|
||||||
|
## 完成的修改
|
||||||
|
|
||||||
|
### 1. policy.py (SparsePolicy 基类)
|
||||||
|
|
||||||
|
- 添加 TYPE_CHECKING imports: `OffloadEngine`, `KVCacheManager`, `Sequence`
|
||||||
|
- 修改 `select_blocks` 签名:添加 `offload_engine` 参数
|
||||||
|
- 添加 `compute_chunked_attention` 抽象方法,参数包括:
|
||||||
|
- `q, k, v`: 张量
|
||||||
|
- `layer_id`: 层索引
|
||||||
|
- `softmax_scale`: softmax 缩放因子
|
||||||
|
- `offload_engine`: OffloadEngine 实例
|
||||||
|
- `kvcache_manager`: KVCacheManager 实例
|
||||||
|
- `current_chunk_idx`: 当前 chunk 索引
|
||||||
|
- `seq`: Sequence 对象
|
||||||
|
- `num_tokens`: 当前 chunk 的 token 数
|
||||||
|
|
||||||
|
### 2. full_policy.py (FullAttentionPolicy)
|
||||||
|
|
||||||
|
- 更新 TYPE_CHECKING imports
|
||||||
|
- `select_blocks` 方法签名添加 `offload_engine` 参数
|
||||||
|
- 重命名 `compute_prefill_attention` → `compute_chunked_attention`
|
||||||
|
- 添加 `kvcache_manager` 参数,替换所有 `seq.kvcache_manager` 引用
|
||||||
|
- 添加 debug 日志输出
|
||||||
|
|
||||||
|
### 3. attention.py
|
||||||
|
|
||||||
|
- 简化 `_chunked_prefill_attention` 方法:
|
||||||
|
- 删除所有 `flash_attn_*` 调用
|
||||||
|
- 删除所有 `merge_attention_outputs` 调用
|
||||||
|
- 仅保留委托调用 `sparse_policy.compute_chunked_attention()`
|
||||||
|
- 删除冗余方法:`_sync_load_previous_chunks`, `_ring_buffer_pipeline_load`
|
||||||
|
- decode 路径的 `select_blocks` 调用添加 `offload_engine` 参数
|
||||||
|
|
||||||
|
## 验收标准检查
|
||||||
|
|
||||||
|
| 标准 | 状态 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| test_needle.py --enable-offload 通过 | ✅ | 测试输出 PASSED |
|
||||||
|
| attention.py chunked prefill path 无 flash_attn_* 调用 | ✅ | `_chunked_prefill_attention` 方法(169-230行)内无直接 flash_attn 调用 |
|
||||||
|
| attention.py chunked prefill path 无 merge_attention_outputs 调用 | ✅ | 同上 |
|
||||||
|
| 所有 KV 通信通过 offload_engine 方法 | ✅ | 全部通过 `offload_engine.load_to_slot_layer`, `get_kv_for_slot`, `get_prefill_buffer_slice` |
|
||||||
|
|
||||||
|
## 测试结果
|
||||||
|
|
||||||
|
```
|
||||||
|
============================================================
|
||||||
|
Needle-in-Haystack Test
|
||||||
|
============================================================
|
||||||
|
Model: /home/zijie/models/Llama-3.1-8B-Instruct
|
||||||
|
Max model len: 131072
|
||||||
|
Input length: 8192
|
||||||
|
Block size: 1024
|
||||||
|
Needle position: 50%
|
||||||
|
Needle value: 7492
|
||||||
|
CPU offload: True
|
||||||
|
Sparse policy: FULL
|
||||||
|
============================================================
|
||||||
|
|
||||||
|
[NeedleTest] Target: 8192, Actual: 8213 tokens (diff=21)
|
||||||
|
Expected: 7492
|
||||||
|
Output: 7492<|eot_id|>...
|
||||||
|
Status: PASSED
|
||||||
|
============================================================
|
||||||
|
|
||||||
|
test_needle: PASSED
|
||||||
|
```
|
||||||
|
|
||||||
|
## 性能指标
|
||||||
|
|
||||||
|
- Prefill: 3527 tok/s
|
||||||
|
- Decode: 11 tok/s
|
||||||
|
- TTFT: 2329.29 ms
|
||||||
|
- TPOT: 655.38 ms
|
||||||
|
|
||||||
|
## 架构变更总结
|
||||||
|
|
||||||
|
**重构前**:
|
||||||
|
```
|
||||||
|
attention.py::_chunked_prefill_attention()
|
||||||
|
├── 获取 cpu_block_table
|
||||||
|
├── 调用 sparse_policy.select_blocks()
|
||||||
|
├── 直接调用 flash_attn_with_lse + merge_attention_outputs
|
||||||
|
└── 返回结果
|
||||||
|
```
|
||||||
|
|
||||||
|
**重构后**:
|
||||||
|
```
|
||||||
|
attention.py::_chunked_prefill_attention()
|
||||||
|
├── 获取 context 信息
|
||||||
|
├── 调用 sparse_policy.compute_chunked_attention() # 委托全部计算
|
||||||
|
└── 返回结果
|
||||||
|
|
||||||
|
sparse_policy.compute_chunked_attention() # 在 FullPolicy 中
|
||||||
|
├── 获取 cpu_block_table
|
||||||
|
├── 调用 self.select_blocks()
|
||||||
|
├── 加载并计算历史 KV attention
|
||||||
|
├── 计算当前 chunk attention (causal)
|
||||||
|
├── 合并所有结果
|
||||||
|
└── 返回最终输出
|
||||||
|
```
|
||||||
|
|
||||||
|
## 结论
|
||||||
|
|
||||||
|
SparsePolicy 架构 v4 重构成功完成。所有验收标准均已满足,测试通过。
|
||||||
Reference in New Issue
Block a user