[claudesquad] update from 'perf_opt-1' on 07 Jan 26 05:58 CST

This commit is contained in:
Zijie Tian
2026-01-07 05:58:23 +08:00
parent aa953ecb59
commit ccf27d3a74
4 changed files with 255 additions and 20 deletions


@@ -46,6 +46,17 @@ python bench_offload.py
## Local Package Installation for Multi-Instance
**CRITICAL**: After ANY code modification in the `nanovllm/` directory, you MUST reinstall the package before running tests or benchmarks:
```bash
pip install -e . --prefix=./.local --no-deps
```
Then run with PYTHONPATH:
```bash
PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH python <script.py>
```
**IMPORTANT**: When running multiple Claude instances on different worktrees, do NOT use `pip install -e .` globally as it will affect other instances. Instead, use local installation:
1. **Install to worktree-local directory**:


@@ -590,14 +590,15 @@ class ModelRunner:
def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
    """
    Run decode with cross-layer pipeline (CPU is primary storage).

    All KV is on CPU. Uses decode_slot (slot[0]) to write new KV.
    Optimized with cross-layer pipeline: Layer N's data is loaded while
    Layer N-1 computes, achieving transfer/compute overlap.

    Key: decode_slot is dedicated to writing new KV, never used for loading.

    Optimization: Cross-layer pipeline reduces effective latency by overlapping
    H2D transfers with attention computation across layers.
    """
    assert len(seqs) == 1, "Ring buffer decode only supports single sequence"
    seq = seqs[0]
@@ -618,6 +619,12 @@ class ModelRunner:
    # Get decode start position for accumulated token tracking
    decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)

    # Get prefilled CPU blocks for pipeline initialization
    cpu_block_table = self.kvcache_manager.get_prefilled_cpu_blocks(seq)
    # Start cross-layer pipeline (preloads Layer 0's data)
    offload_engine.start_decode_pipeline(cpu_block_table)

    # Set up context for chunked decode
    set_context(
        is_prefill=False,
@@ -634,6 +641,9 @@ class ModelRunner:
    logits = self.run_model(input_ids, positions, is_prefill=False)
    reset_context()

    # End cross-layer pipeline
    offload_engine.end_decode_pipeline()

    # Only offload when block is full (pos_in_block == block_size - 1)
    # This avoids unnecessary offloading on every decode step
    if pos_in_block == self.block_size - 1:
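
The lifecycle added here (start the pipeline, run the forward pass, end the pipeline, offload only on a full block) is easiest to see in isolation. The sketch below uses a hypothetical `_RecordingEngine` stand-in rather than the real `OffloadEngine`, purely to make the call ordering visible:

```python
class _RecordingEngine:
    """Hypothetical stand-in that records the pipeline lifecycle calls."""
    def __init__(self, num_layers):
        self.num_layers = num_layers
        self.calls = []

    def start_decode_pipeline(self, cpu_blocks):
        self.calls.append(f"start_decode_pipeline({len(cpu_blocks)} blocks)")

    def get_decode_layer_kv(self, layer_id, num_blocks):
        self.calls.append(f"get_decode_layer_kv(layer={layer_id})")

    def end_decode_pipeline(self):
        self.calls.append("end_decode_pipeline()")


def decode_step(engine, cpu_blocks, pos_in_block, block_size):
    engine.start_decode_pipeline(cpu_blocks)            # preload Layer 0 before the forward pass
    for layer_id in range(engine.num_layers):           # forward pass: each layer fetches its KV,
        engine.get_decode_layer_kv(layer_id, len(cpu_blocks))  # which also preloads the next layer
    engine.end_decode_pipeline()                        # tear down after the forward pass
    if pos_in_block == block_size - 1:                  # offload new KV only when the block is full
        engine.calls.append("offload decode block")
    return engine.calls


print(decode_step(_RecordingEngine(num_layers=4), cpu_blocks=[0, 1, 2], pos_in_block=15, block_size=16))
```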


@@ -142,6 +142,40 @@ class OffloadEngine:
decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB")
# ========== Cross-layer pipeline buffers for decode ==========
# Double-buffered layer cache for pipelined decode:
# - Buffer A: Current layer's prefilled KV being computed
# - Buffer B: Next layer's prefilled KV being loaded
# Shape: [max_prefill_blocks, block_size, kv_heads, head_dim]
# Memory: 2 * max_prefill_blocks * block_size * kv_heads * head_dim * dtype_size
max_prefill_blocks = num_cpu_blocks # Can hold all prefill blocks
self.layer_k_buffer_a = torch.zeros(
max_prefill_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.layer_v_buffer_a = torch.zeros(
max_prefill_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.layer_k_buffer_b = torch.zeros(
max_prefill_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.layer_v_buffer_b = torch.zeros(
max_prefill_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
layer_buf_mb = 4 * max_prefill_blocks * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f" Cross-layer pipeline buffers: {layer_buf_mb:.1f} MB ({max_prefill_blocks} blocks × 2)")
# Pipeline state tracking
self._pipeline_active = False
self._pipeline_current_buffer = 0 # 0 = buffer A, 1 = buffer B
self._pipeline_next_layer_event = torch.cuda.Event()
self._pipeline_cpu_blocks: list = [] # CPU block IDs to load
self._pipeline_num_blocks = 0
self._pipeline_layer_stream = torch.cuda.Stream() # Dedicated stream for layer loading
# ========== Fixed-address CPU KV cache (pinned memory) ==========
self.k_cache_cpu = torch.zeros(
    num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
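
For scale, the `layer_buf_mb` figure logged above follows directly from the buffer shapes. A quick back-of-envelope check, using assumed dimensions (not nanovllm defaults):

```python
import torch

# Worked example of the "Cross-layer pipeline buffers" size log above.
# All dimensions below are illustrative assumptions.
num_cpu_blocks = 256          # max_prefill_blocks = num_cpu_blocks
block_size = 16
num_kv_heads = 8
head_dim = 128
dtype = torch.float16

# Four buffers: K and V for buffer A, plus K and V for buffer B.
layer_buf_bytes = 4 * num_cpu_blocks * block_size * num_kv_heads * head_dim * dtype.itemsize
print(f"Cross-layer pipeline buffers: {layer_buf_bytes / (1024 * 1024):.1f} MB")
# 4 * 256 * 16 * 8 * 128 * 2 bytes = 32.0 MB
```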
@@ -1063,4 +1097,120 @@ class OffloadEngine:
# Allow pdb quit to propagate
if e.__class__.__name__ == 'BdbQuit':
    raise
logger.warning(f"Debug hook error: {e}")
# ========== Cross-layer Pipeline Methods for Decode ==========
def start_decode_pipeline(self, cpu_block_ids: List[int]) -> None:
"""
Start cross-layer pipeline for decode.
Called at the beginning of a decode step to initialize the pipeline.
Preloads Layer 0's data into buffer A.
Args:
cpu_block_ids: List of CPU block IDs for prefilled blocks
"""
if not cpu_block_ids:
self._pipeline_active = False
return
self._pipeline_active = True
self._pipeline_cpu_blocks = cpu_block_ids
self._pipeline_num_blocks = len(cpu_block_ids)
self._pipeline_current_buffer = 0
# Preload Layer 0 into buffer A
self._load_layer_to_buffer(0, 0) # layer_id=0, buffer_idx=0 (A)
def get_decode_layer_kv(self, layer_id: int, num_blocks: int) -> Tuple[Tensor, Tensor]:
"""
Get KV cache for a layer during decode.
If pipeline is active, returns data from the current buffer.
Also triggers preloading of the next layer (if not last layer).
Args:
layer_id: Current layer ID
num_blocks: Number of blocks to return
Returns:
(k_cache, v_cache) tensors, shape: [num_blocks, block_size, kv_heads, head_dim]
"""
if not self._pipeline_active:
raise RuntimeError("Decode pipeline not active. Call start_decode_pipeline first.")
# Wait for current layer's data to be ready
self.compute_stream.wait_event(self._pipeline_next_layer_event)
# Get current buffer
if self._pipeline_current_buffer == 0:
k = self.layer_k_buffer_a[:num_blocks]
v = self.layer_v_buffer_a[:num_blocks]
else:
k = self.layer_k_buffer_b[:num_blocks]
v = self.layer_v_buffer_b[:num_blocks]
# Trigger preloading of next layer (if not last layer)
next_layer_id = layer_id + 1
if next_layer_id < self.num_layers:
# Use the other buffer for next layer
next_buffer_idx = 1 - self._pipeline_current_buffer
self._load_layer_to_buffer(next_layer_id, next_buffer_idx)
# Switch to next buffer for next layer
self._pipeline_current_buffer = next_buffer_idx
return k, v
def _load_layer_to_buffer(self, layer_id: int, buffer_idx: int) -> None:
"""
Async load a layer's prefilled blocks to the specified buffer.
Uses sgDMA for efficient strided transfer from CPU cache.
Args:
layer_id: Layer index to load
buffer_idx: 0 for buffer A, 1 for buffer B
"""
num_blocks = self._pipeline_num_blocks
cpu_block_ids = self._pipeline_cpu_blocks
# Select target buffer
if buffer_idx == 0:
k_buffer = self.layer_k_buffer_a
v_buffer = self.layer_v_buffer_a
else:
k_buffer = self.layer_k_buffer_b
v_buffer = self.layer_v_buffer_b
# Load all blocks for this layer using dedicated stream
with torch.cuda.stream(self._pipeline_layer_stream):
for i, cpu_block_id in enumerate(cpu_block_ids):
# Copy from CPU cache (has layer dimension) to GPU buffer
k_buffer[i].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
v_buffer[i].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
# Record event when all transfers complete
self._pipeline_next_layer_event.record(self._pipeline_layer_stream)
def end_decode_pipeline(self) -> None:
"""
End the cross-layer pipeline.
Called at the end of a decode step to clean up pipeline state.
"""
if self._pipeline_active:
# Ensure all transfers complete before ending
self._pipeline_layer_stream.synchronize()
self._pipeline_active = False
self._pipeline_cpu_blocks = []
self._pipeline_num_blocks = 0
def is_pipeline_active(self) -> bool:
"""Check if decode pipeline is currently active."""
return self._pipeline_active
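
The mechanism behind `start_decode_pipeline` / `get_decode_layer_kv` is a two-buffer ping-pong between a copy stream and the compute stream. Below is a self-contained sketch of that pattern under assumed shapes, as an illustration rather than the code above; it also records an explicit `consumed` event so the copy stream never overwrites a buffer the compute stream is still reading:

```python
import torch

def pipelined_layers(cpu_layers, compute_fn):
    """Consume per-layer pinned CPU tensors, overlapping H2D copies with compute."""
    assert torch.cuda.is_available(), "illustration requires a GPU"
    copy_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.current_stream()
    bufs = [torch.empty_like(cpu_layers[0], device="cuda") for _ in range(2)]
    ready = [torch.cuda.Event(), torch.cuda.Event()]      # H2D copy for this buffer finished
    consumed = [torch.cuda.Event(), torch.cuda.Event()]   # compute finished reading this buffer
    for ev in consumed:
        ev.record(compute_stream)                         # both buffers start out free

    def load(layer_id, buf_idx):
        with torch.cuda.stream(copy_stream):
            copy_stream.wait_event(consumed[buf_idx])     # don't overwrite a buffer still in use
            bufs[buf_idx].copy_(cpu_layers[layer_id], non_blocking=True)
            ready[buf_idx].record(copy_stream)

    load(0, 0)                                            # analogue of start_decode_pipeline
    outputs = []
    for layer_id in range(len(cpu_layers)):
        buf_idx = layer_id % 2
        compute_stream.wait_event(ready[buf_idx])         # wait only for this layer's copy
        if layer_id + 1 < len(cpu_layers):
            load(layer_id + 1, 1 - buf_idx)               # overlap next layer's H2D with compute
        outputs.append(compute_fn(bufs[buf_idx]))
        consumed[buf_idx].record(compute_stream)
    copy_stream.synchronize()                             # analogue of end_decode_pipeline
    return outputs


if __name__ == "__main__":
    layers = [torch.randn(4, 16, 8, 128).pin_memory() for _ in range(8)]
    outs = pipelined_layers(layers, lambda kv: kv.float().sum())
    print(len(outs))
```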


@@ -479,17 +479,15 @@ class Attention(nn.Module):
    context,
) -> torch.Tensor:
    """
    Compute decode attention using cross-layer pipeline.

    Optimization: Uses double-buffered layer cache to overlap H2D transfer
    with computation across layers:
    - Layer N computes while Layer N+1's data is being loaded
    - Each layer only waits for its own data, not all layers' data

    This reduces effective latency from O(num_layers * transfer_time) to
    O(transfer_time + num_layers * compute_time) when transfer < compute.
    """
    from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
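
A rough numeric reading of the latency claim in the docstring, with assumed per-layer times (not measured values):

```python
# Back-of-envelope check of the pipelining benefit described above.
num_layers = 32
transfer_ms = 0.20   # per-layer H2D time for the prefilled KV blocks (assumed)
compute_ms = 0.35    # per-layer decode attention time (assumed)

serial = num_layers * (transfer_ms + compute_ms)       # load, then compute, layer by layer
pipelined = transfer_ms + num_layers * compute_ms      # only the first load is exposed
print(f"serial:    {serial:.1f} ms per decode step")   # 17.6 ms
print(f"pipelined: {pipelined:.1f} ms per decode step")  # 11.4 ms
```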
@@ -533,13 +531,20 @@ class Attention(nn.Module):
)
offload_engine = kvcache_manager.offload_engine

# Use cross-layer pipeline if active (initialized in model_runner)
if offload_engine.is_pipeline_active():
    o_acc, lse_acc = self._decode_with_layer_pipeline(
        q_batched, cpu_block_table, offload_engine,
        block_size, last_block_valid_tokens
    )
else:
    # Fallback to original ring buffer pipeline
    load_slots = offload_engine.decode_load_slots
    o_acc, lse_acc = self._decode_ring_buffer_pipeline(
        q_batched, cpu_block_table, load_slots, offload_engine,
        block_size, last_block_valid_tokens
    )

# Now attend to accumulated decode tokens from per-layer decode buffer
pos_in_block = context.decode_pos_in_block
@@ -652,3 +657,62 @@ class Attention(nn.Module):
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
def _decode_with_layer_pipeline(
self,
q_batched: torch.Tensor,
cpu_block_table: list,
offload_engine,
block_size: int,
last_block_valid_tokens: int,
):
"""
Decode using cross-layer pipeline for optimized H2D transfer.
This method uses pre-loaded layer buffers instead of loading
blocks one by one. The pipeline loads the next layer's data
while the current layer computes, achieving transfer/compute overlap.
The key insight is that each layer needs the SAME blocks but from
different layers of CPU cache. By double-buffering and pipelining
across layers, we reduce total latency.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
num_blocks = len(cpu_block_table)
if num_blocks == 0:
return None, None
compute_stream = offload_engine.compute_stream
# Get KV from pre-loaded layer buffer (triggers next layer loading)
prev_k, prev_v = offload_engine.get_decode_layer_kv(self.layer_id, num_blocks)
# prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
# Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
total_tokens = num_blocks * block_size
# Handle partial last block
if last_block_valid_tokens < block_size:
# Only use valid tokens from last block
actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
# Flatten and truncate
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
else:
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
# Add batch dimension: [1, total_tokens, kv_heads, head_dim]
prev_k_batched = prev_k_flat.unsqueeze(0)
prev_v_batched = prev_v_flat.unsqueeze(0)
# Compute attention on all prefilled blocks at once
with torch.cuda.stream(compute_stream):
o_acc, lse_acc = flash_attn_with_lse(
q_batched, prev_k_batched, prev_v_batched,
softmax_scale=self.scale,
causal=False,
)
return o_acc, lse_acc
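
The partial output computed here is later combined with the decode-buffer attention via `merge_attention_outputs`. The log-sum-exp merge rule behind that step can be sanity-checked in plain PyTorch; the shapes and the `attn_with_lse` / `merge` helpers below are illustrative assumptions, not the nanovllm implementation:

```python
import torch

# Check that merging two partial attention outputs via their LSEs
# reproduces attention over the full key/value set.
def attn_with_lse(q, k, v, scale):
    # q: [B, S, H, D], k/v: [B, T, H, D]
    scores = torch.einsum("bshd,bthd->bhst", q, k) * scale
    lse = torch.logsumexp(scores, dim=-1)                      # [B, H, S]
    o = torch.einsum("bhst,bthd->bshd", torch.softmax(scores, dim=-1), v)
    return o, lse


def merge(o1, lse1, o2, lse2):
    lse = torch.logaddexp(lse1, lse2)                          # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)   # [B, S, H, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse


if __name__ == "__main__":
    B, S, T, H, D = 1, 1, 32, 4, 16
    scale = D ** -0.5
    q = torch.randn(B, S, H, D)
    k, v = torch.randn(B, T, H, D), torch.randn(B, T, H, D)
    full, _ = attn_with_lse(q, k, v, scale)                    # attention over all T tokens
    o1, lse1 = attn_with_lse(q, k[:, :20], v[:, :20], scale)   # prefilled-block portion
    o2, lse2 = attn_with_lse(q, k[:, 20:], v[:, 20:], scale)   # decode-buffer portion
    merged, _ = merge(o1, lse1, o2, lse2)
    print(torch.allclose(merged, full, atol=1e-5))             # True
```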