♻️ refactor: remove cross-layer pipeline and rename compute_chunked_attention to compute_chunked_prefill

- Remove cross-layer pipeline from OffloadEngine (saves ~1GB GPU memory for long sequences; rough sizing sketched below)
  - Delete layer_k/v_buffer_a/b double buffers
  - Remove start_decode_pipeline, get_decode_layer_kv, end_decode_pipeline methods
  - Remove pipeline state tracking variables
- Simplify decode to use ring buffer pipeline only (more efficient for long sequences)
- Rename compute_chunked_attention → compute_chunked_prefill for clarity
- Add mandatory needle test requirements: --enable-offload --input-len 32768
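
Rough sizing behind the "~1GB" figure in the first bullet, using the memory formula from the removed allocation code. The concrete dimensions below are illustrative assumptions, not the project's actual config:

    # Hypothetical config: block_size=32, 8 KV heads, head_dim=128, fp16,
    # CPU cache sized for ~131K prefilled tokens (4096 blocks).
    block_size, num_kv_heads, head_dim, dtype_size = 32, 8, 128, 2
    max_prefill_blocks = 4096
    # Same formula as the removed buffers: 4 tensors (K and V, double-buffered)
    layer_buf_mb = 4 * max_prefill_blocks * block_size * num_kv_heads * head_dim * dtype_size / (1024 * 1024)
    print(f"freed: {layer_buf_mb:.0f} MB")  # -> freed: 1024 MB, i.e. ~1GB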

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Author: Zijie Tian
Date:   2026-01-20 02:10:40 +08:00
parent  6080bf7554
commit  fa7601f4b8

9 changed files with 67 additions and 299 deletions


@@ -644,12 +644,6 @@ class ModelRunner:
         # Get decode start position for accumulated token tracking
         decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)
-        # Get prefilled CPU blocks for pipeline initialization
-        cpu_block_table = self.kvcache_manager.get_prefilled_cpu_blocks(seq)
-        # Start cross-layer pipeline (preloads Layer 0's data)
-        offload_engine.start_decode_pipeline(cpu_block_table)
         # Set up context for chunked decode
         set_context(
             is_prefill=False,
@@ -666,9 +660,6 @@
         logits = self.run_model(input_ids, positions, is_prefill=False)
         reset_context()
-        # End cross-layer pipeline
-        offload_engine.end_decode_pipeline()
         # Only offload when block is full (pos_in_block == block_size - 1)
         # This avoids unnecessary offloading on every decode step
         if pos_in_block == self.block_size - 1:


@@ -141,40 +141,6 @@ class OffloadEngine:
         decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
         logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB")
-        # ========== Cross-layer pipeline buffers for decode ==========
-        # Double-buffered layer cache for pipelined decode:
-        # - Buffer A: Current layer's prefilled KV being computed
-        # - Buffer B: Next layer's prefilled KV being loaded
-        # Shape: [max_prefill_blocks, block_size, kv_heads, head_dim]
-        # Memory: 2 * max_prefill_blocks * block_size * kv_heads * head_dim * dtype_size
-        max_prefill_blocks = num_cpu_blocks  # Can hold all prefill blocks
-        self.layer_k_buffer_a = torch.zeros(
-            max_prefill_blocks, block_size, num_kv_heads, head_dim,
-            dtype=dtype, device="cuda"
-        )
-        self.layer_v_buffer_a = torch.zeros(
-            max_prefill_blocks, block_size, num_kv_heads, head_dim,
-            dtype=dtype, device="cuda"
-        )
-        self.layer_k_buffer_b = torch.zeros(
-            max_prefill_blocks, block_size, num_kv_heads, head_dim,
-            dtype=dtype, device="cuda"
-        )
-        self.layer_v_buffer_b = torch.zeros(
-            max_prefill_blocks, block_size, num_kv_heads, head_dim,
-            dtype=dtype, device="cuda"
-        )
-        layer_buf_mb = 4 * max_prefill_blocks * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
-        logger.info(f" Cross-layer pipeline buffers: {layer_buf_mb:.1f} MB ({max_prefill_blocks} blocks × 2)")
-        # Pipeline state tracking
-        self._pipeline_active = False
-        self._pipeline_current_buffer = 0  # 0 = buffer A, 1 = buffer B
-        self._pipeline_next_layer_event = torch.cuda.Event()
-        self._pipeline_cpu_blocks: list = []  # CPU block IDs to load
-        self._pipeline_num_blocks = 0
-        self._pipeline_layer_stream = torch.cuda.Stream()  # Dedicated stream for layer loading
         # ========== Per-layer prefill buffer for async offload ==========
         # During chunked prefill, all layers share the same GPU slot. This means
         # each layer must wait for offload to complete before the next layer can
@@ -666,122 +632,6 @@ class OffloadEngine:
                 raise
             logger.warning(f"Debug hook error: {e}")
-    # ========== Cross-layer Pipeline Methods for Decode ==========
-    def start_decode_pipeline(self, cpu_block_ids: List[int]) -> None:
-        """
-        Start cross-layer pipeline for decode.
-        Called at the beginning of a decode step to initialize the pipeline.
-        Preloads Layer 0's data into buffer A.
-        Args:
-            cpu_block_ids: List of CPU block IDs for prefilled blocks
-        """
-        if not cpu_block_ids:
-            self._pipeline_active = False
-            return
-        self._pipeline_active = True
-        self._pipeline_cpu_blocks = cpu_block_ids
-        self._pipeline_num_blocks = len(cpu_block_ids)
-        self._pipeline_current_buffer = 0
-        # Preload Layer 0 into buffer A
-        self._load_layer_to_buffer(0, 0)  # layer_id=0, buffer_idx=0 (A)
-    def get_decode_layer_kv(self, layer_id: int, num_blocks: int) -> Tuple[Tensor, Tensor]:
-        """
-        Get KV cache for a layer during decode.
-        If pipeline is active, returns data from the current buffer.
-        Also triggers preloading of the next layer (if not last layer).
-        Args:
-            layer_id: Current layer ID
-            num_blocks: Number of blocks to return
-        Returns:
-            (k_cache, v_cache) tensors, shape: [num_blocks, block_size, kv_heads, head_dim]
-        """
-        if not self._pipeline_active:
-            raise RuntimeError("Decode pipeline not active. Call start_decode_pipeline first.")
-        # Wait for current layer's data to be ready
-        self.compute_stream.wait_event(self._pipeline_next_layer_event)
-        # Get current buffer
-        if self._pipeline_current_buffer == 0:
-            k = self.layer_k_buffer_a[:num_blocks]
-            v = self.layer_v_buffer_a[:num_blocks]
-        else:
-            k = self.layer_k_buffer_b[:num_blocks]
-            v = self.layer_v_buffer_b[:num_blocks]
-        # Trigger preloading of next layer (if not last layer)
-        next_layer_id = layer_id + 1
-        if next_layer_id < self.num_layers:
-            # Use the other buffer for next layer
-            next_buffer_idx = 1 - self._pipeline_current_buffer
-            self._load_layer_to_buffer(next_layer_id, next_buffer_idx)
-            # Switch to next buffer for next layer
-            self._pipeline_current_buffer = next_buffer_idx
-        return k, v
-    def _load_layer_to_buffer(self, layer_id: int, buffer_idx: int) -> None:
-        """
-        Async load a layer's prefilled blocks to the specified buffer.
-        Uses sgDMA for efficient strided transfer from CPU cache.
-        Args:
-            layer_id: Layer index to load
-            buffer_idx: 0 for buffer A, 1 for buffer B
-        """
-        num_blocks = self._pipeline_num_blocks
-        cpu_block_ids = self._pipeline_cpu_blocks
-        # Select target buffer
-        if buffer_idx == 0:
-            k_buffer = self.layer_k_buffer_a
-            v_buffer = self.layer_v_buffer_a
-        else:
-            k_buffer = self.layer_k_buffer_b
-            v_buffer = self.layer_v_buffer_b
-        # Load all blocks for this layer using dedicated stream
-        with torch.cuda.stream(self._pipeline_layer_stream):
-            for i, cpu_block_id in enumerate(cpu_block_ids):
-                # Copy from CPU cache (has layer dimension) to GPU buffer
-                k_buffer[i].copy_(
-                    self.k_cache_cpu[layer_id, cpu_block_id],
-                    non_blocking=True
-                )
-                v_buffer[i].copy_(
-                    self.v_cache_cpu[layer_id, cpu_block_id],
-                    non_blocking=True
-                )
-        # Record event when all transfers complete
-        self._pipeline_next_layer_event.record(self._pipeline_layer_stream)
-    def end_decode_pipeline(self) -> None:
-        """
-        End the cross-layer pipeline.
-        Called at the end of a decode step to clean up pipeline state.
-        """
-        if self._pipeline_active:
-            # Ensure all transfers complete before ending
-            self._pipeline_layer_stream.synchronize()
-            self._pipeline_active = False
-            self._pipeline_cpu_blocks = []
-            self._pipeline_num_blocks = 0
-    def is_pipeline_active(self) -> bool:
-        """Check if decode pipeline is currently active."""
-        return self._pipeline_active
     # ========== Per-layer Prefill Buffer Methods ==========
     # These methods enable async offload during chunked prefill by using
     # per-layer buffers instead of shared GPU slots.


@@ -46,7 +46,7 @@ class FullAttentionPolicy(SparsePolicy):
         """Return all blocks - no sparsity."""
         return available_blocks
-    def compute_chunked_attention(
+    def compute_chunked_prefill(
         self,
         q: torch.Tensor,
         k: torch.Tensor,
@@ -86,7 +86,7 @@
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
-        logger.debug(f"[DEBUG] FullPolicy.compute_chunked_attention called, "
+        logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
                      f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
         q_batched = q.unsqueeze(0)  # [1, seq_len, num_heads, head_dim]
@@ -256,19 +256,12 @@
         )
         cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
-        # Use cross-layer pipeline if active (initialized in model_runner)
-        if offload_engine.is_pipeline_active():
-            o_acc, lse_acc = self._decode_with_layer_pipeline(
-                q_batched, cpu_block_table, offload_engine,
-                block_size, last_block_valid_tokens, layer_id, softmax_scale
-            )
-        else:
-            # Fallback to original ring buffer pipeline
-            load_slots = offload_engine.decode_load_slots
-            o_acc, lse_acc = self._decode_ring_buffer_pipeline(
-                q_batched, cpu_block_table, load_slots, offload_engine,
-                block_size, last_block_valid_tokens, layer_id, softmax_scale
-            )
+        # Use ring buffer pipeline for loading prefilled blocks
+        load_slots = offload_engine.decode_load_slots
+        o_acc, lse_acc = self._decode_ring_buffer_pipeline(
+            q_batched, cpu_block_table, load_slots, offload_engine,
+            block_size, last_block_valid_tokens, layer_id, softmax_scale
+        )
         # Now attend to accumulated decode tokens from per-layer decode buffer
         # Compute decode position information internally
@@ -386,62 +379,5 @@
         return o_acc, lse_acc
-    def _decode_with_layer_pipeline(
-        self,
-        q_batched: torch.Tensor,
-        cpu_block_table: list,
-        offload_engine: "OffloadEngine",
-        block_size: int,
-        last_block_valid_tokens: int,
-        layer_id: int,
-        softmax_scale: float,
-    ):
-        """
-        Decode using cross-layer pipeline for optimized H2D transfer.
-        Uses pre-loaded layer buffers instead of loading blocks one by one.
-        The pipeline loads the next layer's data while the current layer
-        computes, achieving transfer/compute overlap.
-        """
-        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
-        num_blocks = len(cpu_block_table)
-        if num_blocks == 0:
-            return None, None
-        compute_stream = offload_engine.compute_stream
-        # Get KV from pre-loaded layer buffer (triggers next layer loading)
-        prev_k, prev_v = offload_engine.get_decode_layer_kv(layer_id, num_blocks)
-        # prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
-        # Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
-        total_tokens = num_blocks * block_size
-        # Handle partial last block
-        if last_block_valid_tokens < block_size:
-            # Only use valid tokens from last block
-            actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
-            # Flatten and truncate
-            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
-            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
-        else:
-            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
-            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
-        # Add batch dimension: [1, total_tokens, kv_heads, head_dim]
-        prev_k_batched = prev_k_flat.unsqueeze(0)
-        prev_v_batched = prev_v_flat.unsqueeze(0)
-        # Compute attention on all prefilled blocks at once
-        with torch.cuda.stream(compute_stream):
-            o_acc, lse_acc = flash_attn_with_lse(
-                q_batched, prev_k_batched, prev_v_batched,
-                softmax_scale=softmax_scale,
-                causal=False,
-            )
-        return o_acc, lse_acc
     def __repr__(self) -> str:
         return "FullAttentionPolicy()"


@@ -192,7 +192,7 @@ class SparsePolicy(ABC):
         pass
     @abstractmethod
-    def compute_chunked_attention(
+    def compute_chunked_prefill(
         self,
         q: torch.Tensor,
         k: torch.Tensor,


@@ -174,7 +174,7 @@ class Attention(nn.Module):
         Compute attention with per-layer prefill buffer for async offload.
         Simplified design:
-        - All computation logic is delegated to sparse_policy.compute_chunked_attention()
+        - All computation logic is delegated to sparse_policy.compute_chunked_prefill()
         - This method only handles async offload after computation
         The policy handles:
@@ -198,11 +198,11 @@
             raise RuntimeError("sparse_policy is required for chunked prefill")
         # [DEBUG] Verify execution path
-        logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_attention, "
+        logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
                      f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")
         # Delegate all computation to policy (no flash_attn or merge calls here!)
-        final_o = sparse_policy.compute_chunked_attention(
+        final_o = sparse_policy.compute_chunked_prefill(
             q, k, v,
             self.layer_id,
             self.scale,