[WIP] changed to layerwise offload.

This commit is contained in:
Zijie Tian
2026-01-08 00:28:27 +08:00
parent 6575099a06
commit ecd9ae0271

View File

@@ -398,16 +398,16 @@ class ModelRunner:
return self.model.compute_logits(graph_vars["outputs"][:bs])
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
#> Check if Chunked Offload mode should be used (all blocks on CPU)
if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'get_all_cpu_blocks'):
use_chunked_offload = self._should_use_chunked_offload(seqs, is_prefill)
if use_chunked_offload:
#> Check if Layer-wise Offload mode should be used (CPU offload enabled)
if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'offload_engine'):
use_layerwise_offload = self._should_use_layerwise_offload(seqs, is_prefill)
if use_layerwise_offload:
if is_prefill:
return self.run_chunked_offload_prefill(seqs)
return self.run_layerwise_offload_prefill(seqs)
else:
return self.run_chunked_offload_decode(seqs)
return self.run_layerwise_offload_decode(seqs)
#> Following Code will not use Chunked Offload mode
#> Following Code will not use Layer-wise Offload mode
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
logits = self.run_model(input_ids, positions, is_prefill)
@@ -415,236 +415,378 @@ class ModelRunner:
reset_context()
return token_ids
def _should_use_chunked_offload(self, seqs: list[Sequence], is_prefill: bool) -> bool:
def _should_use_layerwise_offload(self, seqs: list[Sequence], is_prefill: bool) -> bool:
"""
Check if three-region mode should be used.
Check if layer-wise offload mode should be used.
Use three-region when:
- CPU offload is enabled
- There are blocks on CPU (either allocated there or offloaded)
- Sequence exceeds GPU Compute region capacity
Use layer-wise offload when:
- CPU offload is enabled (offload_engine exists)
- Sequence has blocks allocated (not warmup)
"""
if not hasattr(self.kvcache_manager, 'offload_engine'):
return False
for seq in seqs:
if not seq.block_table:
continue # Skip warmup sequences
# Check if any blocks are on CPU
cpu_blocks, _ = self.kvcache_manager.get_all_cpu_blocks(seq)
if cpu_blocks:
# Has CPU blocks - use three-region
return True
# Check if sequence needs more blocks than GPU Compute region can hold
compute_size = self.kvcache_manager.offload_engine.num_compute_blocks
if seq.num_blocks > compute_size:
# Needs chunked processing
if seq.block_table:
# Has blocks - use layer-wise offload
return True
return False
def run_chunked_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill with unified ring buffer (CPU is primary storage).
# ========== Layer-wise Offload Methods ==========
Flow:
1. All blocks are allocated to CPU (primary storage)
2. Each chunk writes KV to ring buffer slot[chunk_idx % N]
3. After each chunk, offload from ring buffer slot to CPU
4. All N-1 other slots are used to load previous chunks for attention
@torch.inference_mode()
def run_layerwise_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
assert len(seqs) == 1, "Ring buffer prefill only supports single sequence"
Run prefill with layer-wise processing and CPU offload.
Key design:
- Process one layer at a time (not one chunk at a time)
- Each layer: full forward pass → offload KV to CPU
- Full KV stays on GPU during each layer's computation
- After layer completes, KV is offloaded to CPU
This enables future sparse attention methods (like MInference)
that need full KV context per layer for pattern estimation.
"""
assert len(seqs) == 1, "Layer-wise offload only supports single sequence"
seq = seqs[0]
offload_engine = self.kvcache_manager.offload_engine
# Each chunk uses 1 ring buffer slot = 1 block
tokens_per_chunk = self.block_size
num_layers = len(self.model.model.layers)
total_tokens = len(seq)
num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk
logger.debug(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, "
f"total_chunks={num_chunks}")
chunk_idx = 0
logits = None
processed_tokens = 0
logger.debug(f"[Layer-wise Prefill] Starting: {total_tokens} tokens, {num_layers} layers")
# Get CPU block table for offload targets
# Get CPU block IDs for offload targets
cpu_block_ids, logical_ids = self.kvcache_manager.get_all_cpu_blocks(seq)
while processed_tokens < total_tokens:
chunk_start = processed_tokens
chunk_end = min(processed_tokens + tokens_per_chunk, total_tokens)
# Get ring buffer slot for this chunk
write_slot = offload_engine.get_write_slot_for_prefill(chunk_idx)
# CPU block index for this chunk
block_idx = chunk_idx
logger.debug(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
f"write_slot={write_slot}")
# Prepare inputs
input_ids, positions = self._prepare_chunked_offload_chunk(
seq, chunk_start, chunk_end, write_slot, block_idx, chunk_idx
input_ids = torch.tensor(seq[:], dtype=torch.int64, device="cuda")
positions = torch.arange(total_tokens, dtype=torch.int64, device="cuda")
# Step 1: Embedding
hidden_states = self.model.model.embed_tokens(input_ids)
residual = None
# Step 2: Layer-by-layer processing
for layer_id in range(num_layers):
layer = self.model.model.layers[layer_id]
# 2a. Input LayerNorm
if residual is None:
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
else:
hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
# 2b. Self-attention (full sequence)
# QKV projection
qkv = layer.self_attn.qkv_proj(hidden_ln)
q, k, v = qkv.split([
layer.self_attn.q_size,
layer.self_attn.kv_size,
layer.self_attn.kv_size
], dim=-1)
q = q.view(total_tokens, layer.self_attn.num_heads, layer.self_attn.head_dim)
k = k.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
v = v.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# Q/K norms (Qwen3 specific)
if not layer.self_attn.qkv_bias:
num_tokens = q.shape[0]
q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim))
q = q.view(num_tokens, layer.self_attn.num_heads, layer.self_attn.head_dim)
k = layer.self_attn.k_norm(k.reshape(-1, layer.self_attn.head_dim))
k = k.view(num_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# RoPE
q, k = layer.self_attn.rotary_emb(positions, q, k)
# Full attention using FlashAttention
from flash_attn.flash_attn_interface import flash_attn_varlen_func
cu_seqlens = torch.tensor([0, total_tokens], dtype=torch.int32, device="cuda")
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
if input_ids.numel() == 0:
break
# O projection
attn_output = attn_output.view(total_tokens, -1)
hidden_states = layer.self_attn.o_proj(attn_output)
#> Run model forward
logits = self.run_model(input_ids, positions, is_prefill=True)
reset_context()
# 2c. Post-attention LayerNorm + MLP
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
hidden_states = layer.mlp(hidden_states)
# Mark block as prefilled
if block_idx < len(seq.block_table):
logical_id = seq.block_table[block_idx]
# 2d. Offload KV to CPU (synchronous for correctness)
# Use synchronous copy to ensure data is fully copied before moving to next layer
self._offload_layer_kv_to_cpu_sync(layer_id, k, v, cpu_block_ids, total_tokens)
# Mark all blocks as prefilled
for logical_id in logical_ids:
self.kvcache_manager.prefilled_blocks.add(logical_id)
# NOTE: Per-layer async offloading is now done in attention.forward
# Each layer offloads from its own prefill buffer - no waiting required!
# The sparse policy hook is called in offload_prefill_buffer_async.
# Sync offload completes within loop, no explicit wait needed
processed_tokens = chunk_end
chunk_idx += 1
# Step 3: Final norm
hidden_states, _ = self.model.model.norm(hidden_states, residual)
# Wait for all async prefill offloads to complete
offload_engine.wait_all_prefill_offloads()
# Step 4: Compute logits for last token
logits = self.model.compute_logits(hidden_states[-1:])
logger.debug(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks")
# Sample from last logits
# For chunked prefill, ParallelLMHead automatically selects last position's logits
# Step 5: Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
if logits is not None:
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
else:
token_ids = [0] if self.rank == 0 else None
logger.debug(f"[Layer-wise Prefill] Complete: {num_layers} layers processed")
return token_ids
def _prepare_chunked_offload_chunk(
def _offload_layer_kv_to_cpu(
self,
seq: Sequence,
chunk_start: int,
chunk_end: int,
write_slot: int,
block_idx: int,
chunk_idx: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare inputs for a chunked offload prefill chunk (ring buffer design)."""
# Input tokens for this chunk
input_ids = seq[chunk_start:chunk_end]
positions = list(range(chunk_start, chunk_end))
layer_id: int,
k: torch.Tensor,
v: torch.Tensor,
cpu_block_ids: list[int],
total_tokens: int,
):
"""
Offload a layer's KV cache to CPU in blocks (async version).
# Create slot mapping pointing to the single write_slot
slot_mapping = []
for pos in range(chunk_start, chunk_end):
pos_in_block = pos % self.block_size
slot = write_slot * self.block_size + pos_in_block
slot_mapping.append(slot)
Args:
layer_id: Layer index
k: Key tensor [seq_len, kv_heads, head_dim]
v: Value tensor [seq_len, kv_heads, head_dim]
cpu_block_ids: List of CPU block IDs to offload to
total_tokens: Total number of tokens
"""
offload_engine = self.kvcache_manager.offload_engine
block_size = offload_engine.block_size
stream = offload_engine.prefill_offload_streams[layer_id]
# Convert to tensors
num_tokens = chunk_end - chunk_start
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
with torch.cuda.stream(stream):
for i, cpu_block_id in enumerate(cpu_block_ids):
start = i * block_size
end = min(start + block_size, total_tokens)
actual_size = end - start
# Set up context for chunked prefill
seqlen = num_tokens
cu_seqlens_q = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
set_context(
is_prefill=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=seqlen,
max_seqlen_k=seqlen,
slot_mapping=slot_mapping,
is_chunked_prefill=True,
kvcache_manager=self.kvcache_manager,
chunked_seq=seq,
current_chunk_idx=chunk_idx, # Pass chunk index for ring buffer pipeline
# Copy K and V to CPU cache
offload_engine.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(
k[start:end], non_blocking=True
)
offload_engine.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(
v[start:end], non_blocking=True
)
return input_ids, positions
# Record completion event
offload_engine.prefill_offload_events[layer_id].record(stream)
def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
def _offload_layer_kv_to_cpu_sync(
self,
layer_id: int,
k: torch.Tensor,
v: torch.Tensor,
cpu_block_ids: list[int],
total_tokens: int,
):
"""
Run decode with cross-layer pipeline (CPU is primary storage).
Offload a layer's KV cache to CPU in blocks (synchronous version).
All KV is on CPU. Uses decode_slot (slot[0]) to write new KV.
Optimized with cross-layer pipeline: Layer N's data is loaded while
Layer N-1 computes, achieving transfer/compute overlap.
Key: decode_slot is dedicated to writing new KV, never used for loading.
Optimization: Cross-layer pipeline reduces effective latency by overlapping
H2D transfers with attention computation across layers.
This version uses synchronous copy to ensure correctness.
It's slower than async but guarantees data integrity.
"""
assert len(seqs) == 1, "Ring buffer decode only supports single sequence"
offload_engine = self.kvcache_manager.offload_engine
block_size = offload_engine.block_size
for i, cpu_block_id in enumerate(cpu_block_ids):
start = i * block_size
end = min(start + block_size, total_tokens)
actual_size = end - start
# Synchronous copy to CPU
offload_engine.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
offload_engine.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])
@torch.inference_mode()
def run_layerwise_offload_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with layer-wise KV loading from CPU.
Key design:
- For each layer: load all prefilled KV from CPU
- Compute attention with loaded KV + new token's KV
- Store new token's KV for offload when block is full
"""
assert len(seqs) == 1, "Layer-wise offload only supports single sequence"
seq = seqs[0]
offload_engine = self.kvcache_manager.offload_engine
num_layers = len(self.model.model.layers)
# Prepare inputs
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, device="cuda")
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, device="cuda")
# Use Decode region (slot 0) to write new KV
decode_slot = offload_engine.decode_slot # = 0
pos_in_block = (len(seq) - 1) % self.block_size
slot = decode_slot * self.block_size + pos_in_block
slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# Get decode start position for accumulated token tracking
decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)
# Get prefilled CPU blocks for pipeline initialization
# Get prefilled CPU blocks
cpu_block_table = self.kvcache_manager.get_prefilled_cpu_blocks(seq)
num_prefill_blocks = len(cpu_block_table)
total_prefill_tokens = self.kvcache_manager.get_prefill_len(seq)
# Start cross-layer pipeline (preloads Layer 0's data)
offload_engine.start_decode_pipeline(cpu_block_table)
# Calculate valid tokens in last prefill block
last_block_valid_tokens = total_prefill_tokens % self.block_size
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
last_block_valid_tokens = self.block_size
# Set up context for chunked decode
set_context(
is_prefill=False,
slot_mapping=slot_mapping,
context_lens=context_len,
is_chunked_prefill=True, # Use chunked attention path
kvcache_manager=self.kvcache_manager,
chunked_seq=seq,
decode_pos_in_block=pos_in_block,
decode_start_pos_in_block=decode_start_pos,
# Current decode position info
pos_in_block = (len(seq) - 1) % self.block_size
decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)
num_decode_tokens = pos_in_block - decode_start_pos + 1
# Step 1: Embedding
hidden_states = self.model.model.embed_tokens(input_ids)
residual = None
# Allocate buffers for new decode token's KV (per layer)
# These will be accumulated and offloaded when block is full
decode_k_cache = []
decode_v_cache = []
# Step 2: Layer-by-layer processing
for layer_id in range(num_layers):
layer = self.model.model.layers[layer_id]
# 2a. Input LayerNorm
if residual is None:
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
else:
hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
# 2b. QKV projection for new token
qkv = layer.self_attn.qkv_proj(hidden_ln)
q, k_new, v_new = qkv.split([
layer.self_attn.q_size,
layer.self_attn.kv_size,
layer.self_attn.kv_size
], dim=-1)
q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim)
k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
v_new = v_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# Q/K norms
if not layer.self_attn.qkv_bias:
q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim))
q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim)
k_new = layer.self_attn.k_norm(k_new.reshape(-1, layer.self_attn.head_dim))
k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# RoPE
q, k_new = layer.self_attn.rotary_emb(positions, q, k_new)
# Store new KV for later offload
decode_k_cache.append(k_new.clone())
decode_v_cache.append(v_new.clone())
# 2c. Load prefilled KV from CPU
k_prefill_list = []
v_prefill_list = []
for block_idx, cpu_block_id in enumerate(cpu_block_table):
# Determine valid tokens in this block
if block_idx == num_prefill_blocks - 1:
valid_tokens = last_block_valid_tokens
else:
valid_tokens = self.block_size
k_block = offload_engine.k_cache_cpu[layer_id, cpu_block_id, :valid_tokens].to("cuda", non_blocking=True)
v_block = offload_engine.v_cache_cpu[layer_id, cpu_block_id, :valid_tokens].to("cuda", non_blocking=True)
k_prefill_list.append(k_block)
v_prefill_list.append(v_block)
# Concatenate prefilled KV
if k_prefill_list:
k_prefill = torch.cat(k_prefill_list, dim=0) # [prefill_tokens, kv_heads, head_dim]
v_prefill = torch.cat(v_prefill_list, dim=0)
else:
k_prefill = torch.empty(0, layer.self_attn.num_kv_heads, layer.self_attn.head_dim, device="cuda")
v_prefill = torch.empty(0, layer.self_attn.num_kv_heads, layer.self_attn.head_dim, device="cuda")
# 2d. Get accumulated decode KV from decode buffer (if any previous decode tokens)
if num_decode_tokens > 1:
# Load previous decode tokens for this layer from decode buffer
k_decode_prev = offload_engine.decode_k_buffer[layer_id, decode_start_pos:pos_in_block]
v_decode_prev = offload_engine.decode_v_buffer[layer_id, decode_start_pos:pos_in_block]
k_full = torch.cat([k_prefill, k_decode_prev, k_new], dim=0)
v_full = torch.cat([v_prefill, v_decode_prev, v_new], dim=0)
else:
k_full = torch.cat([k_prefill, k_new], dim=0)
v_full = torch.cat([v_prefill, v_new], dim=0)
# Store new KV to decode buffer for future decode steps
offload_engine.decode_k_buffer[layer_id, pos_in_block].copy_(k_new.squeeze(0))
offload_engine.decode_v_buffer[layer_id, pos_in_block].copy_(v_new.squeeze(0))
# 2e. Compute attention
# For decode: query is at the last position, should attend to ALL previous keys
# Use causal=False because the single query token is conceptually at position N
# and should attend to all K tokens at positions 0 to N-1
from flash_attn.flash_attn_interface import flash_attn_varlen_func
total_kv_tokens = k_full.shape[0]
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, total_kv_tokens], dtype=torch.int32, device="cuda")
attn_output = flash_attn_varlen_func(
q, k_full, v_full,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_k=total_kv_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=False,
)
# Run model forward pass
logits = self.run_model(input_ids, positions, is_prefill=False)
reset_context()
# O projection
attn_output = attn_output.view(1, -1)
hidden_states = layer.self_attn.o_proj(attn_output)
# End cross-layer pipeline
offload_engine.end_decode_pipeline()
# 2f. Post-attention LayerNorm + MLP
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
hidden_states = layer.mlp(hidden_states)
# Only offload when block is full (pos_in_block == block_size - 1)
# This avoids unnecessary offloading on every decode step
# Step 3: Final norm
hidden_states, _ = self.model.model.norm(hidden_states, residual)
# Step 4: Compute logits
logits = self.model.compute_logits(hidden_states)
# Step 5: Handle block-full offload
if pos_in_block == self.block_size - 1:
# Block is full, offload decode buffer to CPU
last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
if last_cpu_block >= 0:
# TODO: In new GPU cache architecture (no layer dimension),
# decode offload should be done per-layer in attention.forward.
# For now, offload all layers sequentially.
for layer_id in range(offload_engine.num_layers):
offload_engine.offload_decode_slot_layer(layer_id, last_cpu_block)
offload_engine.wait_all_offload_done()
# Reset decode start position for next block
for layer_id in range(num_layers):
offload_engine.k_cache_cpu[layer_id, last_cpu_block].copy_(
offload_engine.decode_k_buffer[layer_id], non_blocking=True
)
offload_engine.v_cache_cpu[layer_id, last_cpu_block].copy_(
offload_engine.decode_v_buffer[layer_id], non_blocking=True
)
torch.cuda.synchronize()
# Mark as prefilled for future decode steps
logical_id = seq.block_table[-1]
self.kvcache_manager.prefilled_blocks.add(logical_id)
# Reset decode start position
self.kvcache_manager.reset_decode_start_pos(seq)
# Sample
# Step 6: Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None