[claudesquad] update from 'perf_opt-1' on 07 Jan 26 05:58 CST

This commit is contained in:
Zijie Tian
2026-01-07 05:58:23 +08:00
parent aa953ecb59
commit ccf27d3a74
4 changed files with 255 additions and 20 deletions

View File

@@ -46,6 +46,17 @@ python bench_offload.py
## Local Package Installation for Multi-Instance
**CRITICAL**: After ANY code modification in the `nanovllm/` directory, you MUST reinstall the package before running tests or benchmarks:
```bash
pip install -e . --prefix=./.local --no-deps
```
Then run with PYTHONPATH:
```bash
PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH python <script.py>
```
**IMPORTANT**: When running multiple Claude instances on different worktrees, do NOT use `pip install -e .` globally as it will affect other instances. Instead, use local installation:
1. **Install to worktree-local directory**:

View File

@@ -590,14 +590,15 @@ class ModelRunner:
def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with ring buffer (CPU is primary storage).
Run decode with cross-layer pipeline (CPU is primary storage).
All KV is on CPU. Uses decode_slot (slot[0]) to write new KV.
Other slots (slots[1:]) are used to load previous KV chunks via pipeline.
New token's KV is written to decode_slot then offloaded to CPU only when block is full.
Optimized with cross-layer pipeline: Layer N's data is loaded while
Layer N-1 computes, achieving transfer/compute overlap.
Key: decode_slot is dedicated to writing new KV, never used for loading.
Optimization: Batch offloads - only offload when block is full, attend to all accumulated tokens.
Optimization: Cross-layer pipeline reduces effective latency by overlapping
H2D transfers with attention computation across layers.
"""
assert len(seqs) == 1, "Ring buffer decode only supports single sequence"
seq = seqs[0]
@@ -618,6 +619,12 @@ class ModelRunner:
# Get decode start position for accumulated token tracking
decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)
# Get prefilled CPU blocks for pipeline initialization
cpu_block_table = self.kvcache_manager.get_prefilled_cpu_blocks(seq)
# Start cross-layer pipeline (preloads Layer 0's data)
offload_engine.start_decode_pipeline(cpu_block_table)
# Set up context for chunked decode
set_context(
is_prefill=False,
@@ -634,6 +641,9 @@ class ModelRunner:
logits = self.run_model(input_ids, positions, is_prefill=False)
reset_context()
# End cross-layer pipeline
offload_engine.end_decode_pipeline()
# Only offload when block is full (pos_in_block == block_size - 1)
# This avoids unnecessary offloading on every decode step
if pos_in_block == self.block_size - 1:

View File

@@ -142,6 +142,40 @@ class OffloadEngine:
decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB")
# ========== Cross-layer pipeline buffers for decode ==========
# Double-buffered layer cache for pipelined decode:
# - Buffer A: Current layer's prefilled KV being computed
# - Buffer B: Next layer's prefilled KV being loaded
# Shape: [max_prefill_blocks, block_size, kv_heads, head_dim]
# Memory: 2 * max_prefill_blocks * block_size * kv_heads * head_dim * dtype_size
max_prefill_blocks = num_cpu_blocks # Can hold all prefill blocks
self.layer_k_buffer_a = torch.zeros(
max_prefill_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.layer_v_buffer_a = torch.zeros(
max_prefill_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.layer_k_buffer_b = torch.zeros(
max_prefill_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.layer_v_buffer_b = torch.zeros(
max_prefill_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
layer_buf_mb = 4 * max_prefill_blocks * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f" Cross-layer pipeline buffers: {layer_buf_mb:.1f} MB ({max_prefill_blocks} blocks × 2)")
# Pipeline state tracking
self._pipeline_active = False
self._pipeline_current_buffer = 0 # 0 = buffer A, 1 = buffer B
self._pipeline_next_layer_event = torch.cuda.Event()
self._pipeline_cpu_blocks: list = [] # CPU block IDs to load
self._pipeline_num_blocks = 0
self._pipeline_layer_stream = torch.cuda.Stream() # Dedicated stream for layer loading
# ========== Fixed-address CPU KV cache (pinned memory) ==========
self.k_cache_cpu = torch.zeros(
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
@@ -1063,4 +1097,120 @@ class OffloadEngine:
# Allow pdb quit to propagate
if e.__class__.__name__ == 'BdbQuit':
raise
logger.warning(f"Debug hook error: {e}")
logger.warning(f"Debug hook error: {e}")
# ========== Cross-layer Pipeline Methods for Decode ==========
def start_decode_pipeline(self, cpu_block_ids: List[int]) -> None:
"""
Start cross-layer pipeline for decode.
Called at the beginning of a decode step to initialize the pipeline.
Preloads Layer 0's data into buffer A.
Args:
cpu_block_ids: List of CPU block IDs for prefilled blocks
"""
if not cpu_block_ids:
self._pipeline_active = False
return
self._pipeline_active = True
self._pipeline_cpu_blocks = cpu_block_ids
self._pipeline_num_blocks = len(cpu_block_ids)
self._pipeline_current_buffer = 0
# Preload Layer 0 into buffer A
self._load_layer_to_buffer(0, 0) # layer_id=0, buffer_idx=0 (A)
def get_decode_layer_kv(self, layer_id: int, num_blocks: int) -> Tuple[Tensor, Tensor]:
"""
Get KV cache for a layer during decode.
If pipeline is active, returns data from the current buffer.
Also triggers preloading of the next layer (if not last layer).
Args:
layer_id: Current layer ID
num_blocks: Number of blocks to return
Returns:
(k_cache, v_cache) tensors, shape: [num_blocks, block_size, kv_heads, head_dim]
"""
if not self._pipeline_active:
raise RuntimeError("Decode pipeline not active. Call start_decode_pipeline first.")
# Wait for current layer's data to be ready
self.compute_stream.wait_event(self._pipeline_next_layer_event)
# Get current buffer
if self._pipeline_current_buffer == 0:
k = self.layer_k_buffer_a[:num_blocks]
v = self.layer_v_buffer_a[:num_blocks]
else:
k = self.layer_k_buffer_b[:num_blocks]
v = self.layer_v_buffer_b[:num_blocks]
# Trigger preloading of next layer (if not last layer)
next_layer_id = layer_id + 1
if next_layer_id < self.num_layers:
# Use the other buffer for next layer
next_buffer_idx = 1 - self._pipeline_current_buffer
self._load_layer_to_buffer(next_layer_id, next_buffer_idx)
# Switch to next buffer for next layer
self._pipeline_current_buffer = next_buffer_idx
return k, v
def _load_layer_to_buffer(self, layer_id: int, buffer_idx: int) -> None:
"""
Async load a layer's prefilled blocks to the specified buffer.
Uses sgDMA for efficient strided transfer from CPU cache.
Args:
layer_id: Layer index to load
buffer_idx: 0 for buffer A, 1 for buffer B
"""
num_blocks = self._pipeline_num_blocks
cpu_block_ids = self._pipeline_cpu_blocks
# Select target buffer
if buffer_idx == 0:
k_buffer = self.layer_k_buffer_a
v_buffer = self.layer_v_buffer_a
else:
k_buffer = self.layer_k_buffer_b
v_buffer = self.layer_v_buffer_b
# Load all blocks for this layer using dedicated stream
with torch.cuda.stream(self._pipeline_layer_stream):
for i, cpu_block_id in enumerate(cpu_block_ids):
# Copy from CPU cache (has layer dimension) to GPU buffer
k_buffer[i].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
v_buffer[i].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
# Record event when all transfers complete
self._pipeline_next_layer_event.record(self._pipeline_layer_stream)
def end_decode_pipeline(self) -> None:
"""
End the cross-layer pipeline.
Called at the end of a decode step to clean up pipeline state.
"""
if self._pipeline_active:
# Ensure all transfers complete before ending
self._pipeline_layer_stream.synchronize()
self._pipeline_active = False
self._pipeline_cpu_blocks = []
self._pipeline_num_blocks = 0
def is_pipeline_active(self) -> bool:
"""Check if decode pipeline is currently active."""
return self._pipeline_active

View File

@@ -479,17 +479,15 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute decode attention using ring buffer pipeline (same as prefill).
Compute decode attention using cross-layer pipeline.
Uses the same loading mechanism as _chunked_prefill_attention:
- Load one block at a time from CPU to GPU slot
- Compute attention for each block
- Merge results using online softmax
- Finally merge with decode buffer (accumulated decode tokens)
Optimization: Uses double-buffered layer cache to overlap H2D transfer
with computation across layers:
- Layer N computes while Layer N+1's data is being loaded
- Each layer only waits for its own data, not all layers' data
This approach is simpler and proven correct (prefill tests pass).
The only difference from prefill is the additional decode buffer
that stores new tokens generated during decode.
This reduces effective latency from O(num_layers * transfer_time) to
O(transfer_time + num_layers * compute_time) when transfer < compute.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@@ -533,13 +531,20 @@ class Attention(nn.Module):
)
offload_engine = kvcache_manager.offload_engine
load_slots = offload_engine.decode_load_slots # Available slots for loading
# Use ring buffer pipeline (same as prefill) to load prefilled blocks
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
q_batched, cpu_block_table, load_slots, offload_engine,
block_size, last_block_valid_tokens
)
# Use cross-layer pipeline if active (initialized in model_runner)
if offload_engine.is_pipeline_active():
o_acc, lse_acc = self._decode_with_layer_pipeline(
q_batched, cpu_block_table, offload_engine,
block_size, last_block_valid_tokens
)
else:
# Fallback to original ring buffer pipeline
load_slots = offload_engine.decode_load_slots
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
q_batched, cpu_block_table, load_slots, offload_engine,
block_size, last_block_valid_tokens
)
# Now attend to accumulated decode tokens from per-layer decode buffer
pos_in_block = context.decode_pos_in_block
@@ -652,3 +657,62 @@ class Attention(nn.Module):
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
def _decode_with_layer_pipeline(
self,
q_batched: torch.Tensor,
cpu_block_table: list,
offload_engine,
block_size: int,
last_block_valid_tokens: int,
):
"""
Decode using cross-layer pipeline for optimized H2D transfer.
This method uses pre-loaded layer buffers instead of loading
blocks one by one. The pipeline loads the next layer's data
while the current layer computes, achieving transfer/compute overlap.
The key insight is that each layer needs the SAME blocks but from
different layers of CPU cache. By double-buffering and pipelining
across layers, we reduce total latency.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
num_blocks = len(cpu_block_table)
if num_blocks == 0:
return None, None
compute_stream = offload_engine.compute_stream
# Get KV from pre-loaded layer buffer (triggers next layer loading)
prev_k, prev_v = offload_engine.get_decode_layer_kv(self.layer_id, num_blocks)
# prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
# Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
total_tokens = num_blocks * block_size
# Handle partial last block
if last_block_valid_tokens < block_size:
# Only use valid tokens from last block
actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
# Flatten and truncate
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
else:
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
# Add batch dimension: [1, total_tokens, kv_heads, head_dim]
prev_k_batched = prev_k_flat.unsqueeze(0)
prev_v_batched = prev_v_flat.unsqueeze(0)
# Compute attention on all prefilled blocks at once
with torch.cuda.stream(compute_stream):
o_acc, lse_acc = flash_attn_with_lse(
q_batched, prev_k_batched, prev_v_batched,
softmax_scale=self.scale,
causal=False,
)
return o_acc, lse_acc