[claudesquad] update from 'perf_opt-1' on 07 Jan 26 05:58 CST
@@ -46,6 +46,17 @@ python bench_offload.py
 
 ## Local Package Installation for Multi-Instance
 
+**CRITICAL**: After ANY code modification in the `nanovllm/` directory, you MUST reinstall the package before running tests or benchmarks:
+
+```bash
+pip install -e . --prefix=./.local --no-deps
+```
+
+Then run with PYTHONPATH:
+```bash
+PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH python <script.py>
+```
+
 **IMPORTANT**: When running multiple Claude instances on different worktrees, do NOT use `pip install -e .` globally as it will affect other instances. Instead, use local installation:
 
 1. **Install to worktree-local directory**:
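The CLAUDE.md addition assumes that the worktree-local copy under `./.local` is the one Python actually picks up once `PYTHONPATH` is set. A tiny diagnostic like the following (hypothetical, not part of the commit) makes that visible when several instances share a machine:

```python
# Hypothetical helper, not part of the commit: show which nanovllm copy this
# interpreter would import, so a global install shadowing the ./.local one is obvious.
import importlib.util
import sys

spec = importlib.util.find_spec("nanovllm")
print("nanovllm resolves to:", spec.origin if spec else "<not importable>")
print("local site-packages on sys.path:",
      [p for p in sys.path if ".local" in p] or "<none; is PYTHONPATH set?>")
```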
@@ -590,14 +590,15 @@ class ModelRunner:
 
     def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
         """
-        Run decode with ring buffer (CPU is primary storage).
+        Run decode with cross-layer pipeline (CPU is primary storage).
 
         All KV is on CPU. Uses decode_slot (slot[0]) to write new KV.
-        Other slots (slots[1:]) are used to load previous KV chunks via pipeline.
-        New token's KV is written to decode_slot then offloaded to CPU only when block is full.
+        Optimized with cross-layer pipeline: Layer N's data is loaded while
+        Layer N-1 computes, achieving transfer/compute overlap.
 
         Key: decode_slot is dedicated to writing new KV, never used for loading.
-        Optimization: Batch offloads - only offload when block is full, attend to all accumulated tokens.
+        Optimization: Cross-layer pipeline reduces effective latency by overlapping
+        H2D transfers with attention computation across layers.
         """
         assert len(seqs) == 1, "Ring buffer decode only supports single sequence"
         seq = seqs[0]
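The reworded docstring is the heart of the change: while layer N-1 runs attention, layer N's KV blocks are already being copied host-to-device. The sketch below is a toy model of that double-buffered overlap using plain PyTorch streams and events; all names, shapes, and the `compute` callback are invented for illustration, and this is not the project's implementation.

```python
import torch

# Toy sketch of the cross-layer overlap described above: while layer N's attention
# runs on the main stream, layer N+1's KV is copied H2D on a separate copy stream,
# alternating between two GPU buffers.
def decode_step_overlap(cpu_kv, compute):
    """cpu_kv: [num_layers, tokens, dim] pinned CPU tensor; compute: fn(gpu_kv) -> tensor."""
    num_layers = cpu_kv.shape[0]
    copy_stream = torch.cuda.Stream()
    main = torch.cuda.current_stream()
    bufs = [torch.empty(cpu_kv.shape[1:], device="cuda", dtype=cpu_kv.dtype) for _ in range(2)]
    ready = [torch.cuda.Event(), torch.cuda.Event()]   # H2D copy into buffer i finished
    freed = [torch.cuda.Event(), torch.cuda.Event()]   # compute reading buffer i finished

    with torch.cuda.stream(copy_stream):               # preload layer 0 into buffer 0
        bufs[0].copy_(cpu_kv[0], non_blocking=True)
        ready[0].record()

    outputs = []
    for layer in range(num_layers):
        cur = layer % 2
        main.wait_event(ready[cur])                    # wait only for this layer's data
        if layer + 1 < num_layers:                     # prefetch next layer into the other buffer
            nxt = 1 - cur
            with torch.cuda.stream(copy_stream):
                copy_stream.wait_event(freed[nxt])     # don't overwrite a buffer still being read
                bufs[nxt].copy_(cpu_kv[layer + 1], non_blocking=True)
                ready[nxt].record()
        outputs.append(compute(bufs[cur]))             # overlaps with the prefetch above
        freed[cur].record(main)
    return outputs

if torch.cuda.is_available():
    kv = torch.randn(4, 256, 128).pin_memory()
    outs = decode_step_overlap(kv, lambda x: x.float().sum())
    torch.cuda.synchronize()
    print(f"processed {len(outs)} layers")
```

The detail worth noting, mirrored from the diff, is that the compute stream waits only on the event for its own buffer, so each layer pays for at most one outstanding transfer rather than all of them.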
@@ -618,6 +619,12 @@ class ModelRunner:
         # Get decode start position for accumulated token tracking
         decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)
 
+        # Get prefilled CPU blocks for pipeline initialization
+        cpu_block_table = self.kvcache_manager.get_prefilled_cpu_blocks(seq)
+
+        # Start cross-layer pipeline (preloads Layer 0's data)
+        offload_engine.start_decode_pipeline(cpu_block_table)
+
         # Set up context for chunked decode
         set_context(
             is_prefill=False,
@@ -634,6 +641,9 @@ class ModelRunner:
         logits = self.run_model(input_ids, positions, is_prefill=False)
         reset_context()
 
+        # End cross-layer pipeline
+        offload_engine.end_decode_pipeline()
+
         # Only offload when block is full (pos_in_block == block_size - 1)
         # This avoids unnecessary offloading on every decode step
         if pos_in_block == self.block_size - 1:
@@ -142,6 +142,40 @@ class OffloadEngine:
         decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
         logger.info(f"  Per-layer decode buffer: {decode_buf_mb:.1f} MB")
 
+        # ========== Cross-layer pipeline buffers for decode ==========
+        # Double-buffered layer cache for pipelined decode:
+        #   - Buffer A: Current layer's prefilled KV being computed
+        #   - Buffer B: Next layer's prefilled KV being loaded
+        # Shape: [max_prefill_blocks, block_size, kv_heads, head_dim]
+        # Memory: 2 * max_prefill_blocks * block_size * kv_heads * head_dim * dtype_size
+        max_prefill_blocks = num_cpu_blocks  # Can hold all prefill blocks
+        self.layer_k_buffer_a = torch.zeros(
+            max_prefill_blocks, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        self.layer_v_buffer_a = torch.zeros(
+            max_prefill_blocks, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        self.layer_k_buffer_b = torch.zeros(
+            max_prefill_blocks, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        self.layer_v_buffer_b = torch.zeros(
+            max_prefill_blocks, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        layer_buf_mb = 4 * max_prefill_blocks * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
+        logger.info(f"  Cross-layer pipeline buffers: {layer_buf_mb:.1f} MB ({max_prefill_blocks} blocks × 2)")
+
+        # Pipeline state tracking
+        self._pipeline_active = False
+        self._pipeline_current_buffer = 0  # 0 = buffer A, 1 = buffer B
+        self._pipeline_next_layer_event = torch.cuda.Event()
+        self._pipeline_cpu_blocks: list = []  # CPU block IDs to load
+        self._pipeline_num_blocks = 0
+        self._pipeline_layer_stream = torch.cuda.Stream()  # Dedicated stream for layer loading
+
         # ========== Fixed-address CPU KV cache (pinned memory) ==========
         self.k_cache_cpu = torch.zeros(
             num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
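The four `layer_*_buffer_*` tensors hold every prefilled block for one layer, doubled for K/V and again for the A/B buffers. A quick sizing check that mirrors the `layer_buf_mb` formula, using hypothetical configuration values rather than anything from this repository:

```python
# Rough sizing of the four pipeline buffers allocated above, with made-up config values.
num_cpu_blocks = 256      # max_prefill_blocks = num_cpu_blocks in the diff
block_size = 16
num_kv_heads = 8
head_dim = 128
dtype_size = 2            # fp16 / bf16

per_buffer = num_cpu_blocks * block_size * num_kv_heads * head_dim * dtype_size
total_mb = 4 * per_buffer / (1024 * 1024)   # K and V, double-buffered = 4 tensors
print(f"cross-layer pipeline buffers: {total_mb:.1f} MB")   # 32.0 MB for these numbers
```

Because `max_prefill_blocks = num_cpu_blocks`, this reserves GPU memory proportional to the whole CPU block pool; capping it at the longest expected prefill would be the obvious knob if that reservation turns out too large.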
@@ -1064,3 +1098,119 @@ class OffloadEngine:
             if e.__class__.__name__ == 'BdbQuit':
                 raise
             logger.warning(f"Debug hook error: {e}")
+
+    # ========== Cross-layer Pipeline Methods for Decode ==========
+
+    def start_decode_pipeline(self, cpu_block_ids: List[int]) -> None:
+        """
+        Start cross-layer pipeline for decode.
+
+        Called at the beginning of a decode step to initialize the pipeline.
+        Preloads Layer 0's data into buffer A.
+
+        Args:
+            cpu_block_ids: List of CPU block IDs for prefilled blocks
+        """
+        if not cpu_block_ids:
+            self._pipeline_active = False
+            return
+
+        self._pipeline_active = True
+        self._pipeline_cpu_blocks = cpu_block_ids
+        self._pipeline_num_blocks = len(cpu_block_ids)
+        self._pipeline_current_buffer = 0
+
+        # Preload Layer 0 into buffer A
+        self._load_layer_to_buffer(0, 0)  # layer_id=0, buffer_idx=0 (A)
+
+    def get_decode_layer_kv(self, layer_id: int, num_blocks: int) -> Tuple[Tensor, Tensor]:
+        """
+        Get KV cache for a layer during decode.
+
+        If pipeline is active, returns data from the current buffer.
+        Also triggers preloading of the next layer (if not last layer).
+
+        Args:
+            layer_id: Current layer ID
+            num_blocks: Number of blocks to return
+
+        Returns:
+            (k_cache, v_cache) tensors, shape: [num_blocks, block_size, kv_heads, head_dim]
+        """
+        if not self._pipeline_active:
+            raise RuntimeError("Decode pipeline not active. Call start_decode_pipeline first.")
+
+        # Wait for current layer's data to be ready
+        self.compute_stream.wait_event(self._pipeline_next_layer_event)
+
+        # Get current buffer
+        if self._pipeline_current_buffer == 0:
+            k = self.layer_k_buffer_a[:num_blocks]
+            v = self.layer_v_buffer_a[:num_blocks]
+        else:
+            k = self.layer_k_buffer_b[:num_blocks]
+            v = self.layer_v_buffer_b[:num_blocks]
+
+        # Trigger preloading of next layer (if not last layer)
+        next_layer_id = layer_id + 1
+        if next_layer_id < self.num_layers:
+            # Use the other buffer for next layer
+            next_buffer_idx = 1 - self._pipeline_current_buffer
+            self._load_layer_to_buffer(next_layer_id, next_buffer_idx)
+            # Switch to next buffer for next layer
+            self._pipeline_current_buffer = next_buffer_idx
+
+        return k, v
+
+    def _load_layer_to_buffer(self, layer_id: int, buffer_idx: int) -> None:
+        """
+        Async load a layer's prefilled blocks to the specified buffer.
+
+        Uses sgDMA for efficient strided transfer from CPU cache.
+
+        Args:
+            layer_id: Layer index to load
+            buffer_idx: 0 for buffer A, 1 for buffer B
+        """
+        num_blocks = self._pipeline_num_blocks
+        cpu_block_ids = self._pipeline_cpu_blocks
+
+        # Select target buffer
+        if buffer_idx == 0:
+            k_buffer = self.layer_k_buffer_a
+            v_buffer = self.layer_v_buffer_a
+        else:
+            k_buffer = self.layer_k_buffer_b
+            v_buffer = self.layer_v_buffer_b
+
+        # Load all blocks for this layer using dedicated stream
+        with torch.cuda.stream(self._pipeline_layer_stream):
+            for i, cpu_block_id in enumerate(cpu_block_ids):
+                # Copy from CPU cache (has layer dimension) to GPU buffer
+                k_buffer[i].copy_(
+                    self.k_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+                v_buffer[i].copy_(
+                    self.v_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+            # Record event when all transfers complete
+            self._pipeline_next_layer_event.record(self._pipeline_layer_stream)
+
+    def end_decode_pipeline(self) -> None:
+        """
+        End the cross-layer pipeline.
+
+        Called at the end of a decode step to clean up pipeline state.
+        """
+        if self._pipeline_active:
+            # Ensure all transfers complete before ending
+            self._pipeline_layer_stream.synchronize()
+        self._pipeline_active = False
+        self._pipeline_cpu_blocks = []
+        self._pipeline_num_blocks = 0
+
+    def is_pipeline_active(self) -> bool:
+        """Check if decode pipeline is currently active."""
+        return self._pipeline_active
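Taken together, the three public methods form a small per-decode-step state machine: `start_decode_pipeline` preloads layer 0 into buffer A, each `get_decode_layer_kv` call returns the current buffer and kicks off the load of the next layer into the other buffer, and `end_decode_pipeline` drains the copy stream. The CPU-only mock below (illustrative only, no CUDA, no real buffers, made-up class name) makes the A/B alternation explicit:

```python
# Hedged sketch: a CPU-only mock of the buffer-alternation implied by
# start_decode_pipeline / get_decode_layer_kv / end_decode_pipeline above.
class MockPipeline:
    def __init__(self, num_layers):
        self.num_layers = num_layers
        self.active = False
        self.current = 0          # 0 = buffer A, 1 = buffer B
        self.loads = []           # (layer_id, buffer) pairs, in issue order

    def start(self, cpu_block_ids):
        self.active = bool(cpu_block_ids)
        if self.active:
            self.current = 0
            self.loads = [(0, 0)]             # preload layer 0 into buffer A

    def get_layer(self, layer_id):
        assert self.active
        used = self.current
        if layer_id + 1 < self.num_layers:    # prefetch next layer into the other buffer
            nxt = 1 - self.current
            self.loads.append((layer_id + 1, nxt))
            self.current = nxt
        return used

    def end(self):
        self.active = False


pipe = MockPipeline(num_layers=4)
pipe.start(cpu_block_ids=[3, 7, 9])
schedule = [(layer, pipe.get_layer(layer)) for layer in range(4)]
pipe.end()
print("compute reads:", schedule)   # [(0, 0), (1, 1), (2, 0), (3, 1)]
print("H2D loads:   ", pipe.loads)  # [(0, 0), (1, 1), (2, 0), (3, 1)]
```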
@@ -479,17 +479,15 @@ class Attention(nn.Module):
         context,
     ) -> torch.Tensor:
         """
-        Compute decode attention using ring buffer pipeline (same as prefill).
+        Compute decode attention using cross-layer pipeline.
 
-        Uses the same loading mechanism as _chunked_prefill_attention:
-        - Load one block at a time from CPU to GPU slot
-        - Compute attention for each block
-        - Merge results using online softmax
-        - Finally merge with decode buffer (accumulated decode tokens)
+        Optimization: Uses double-buffered layer cache to overlap H2D transfer
+        with computation across layers:
+        - Layer N computes while Layer N+1's data is being loaded
+        - Each layer only waits for its own data, not all layers' data
 
-        This approach is simpler and proven correct (prefill tests pass).
-        The only difference from prefill is the additional decode buffer
-        that stores new tokens generated during decode.
+        This reduces effective latency from O(num_layers * transfer_time) to
+        O(transfer_time + num_layers * compute_time) when transfer < compute.
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
 
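The docstring's latency claim can be sanity-checked with a back-of-the-envelope model; the per-layer times below are invented for illustration, not measurements from this project:

```python
# Illustration of the latency claim in the docstring above, with made-up per-layer times.
num_layers = 32
t_transfer = 0.30   # ms per layer: H2D copy of that layer's prefilled KV
t_compute = 0.50    # ms per layer: attention over those blocks

serial = num_layers * (t_transfer + t_compute)   # load-then-compute for every layer
pipelined = t_transfer + num_layers * t_compute  # copies hidden behind compute
print(f"serial:    {serial:.1f} ms")             # 25.6 ms
print(f"pipelined: {pipelined:.1f} ms")          # 16.3 ms
```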
@@ -533,9 +531,16 @@ class Attention(nn.Module):
         )
 
         offload_engine = kvcache_manager.offload_engine
-        load_slots = offload_engine.decode_load_slots  # Available slots for loading
 
-        # Use ring buffer pipeline (same as prefill) to load prefilled blocks
-        o_acc, lse_acc = self._decode_ring_buffer_pipeline(
-            q_batched, cpu_block_table, load_slots, offload_engine,
-            block_size, last_block_valid_tokens
+        # Use cross-layer pipeline if active (initialized in model_runner)
+        if offload_engine.is_pipeline_active():
+            o_acc, lse_acc = self._decode_with_layer_pipeline(
+                q_batched, cpu_block_table, offload_engine,
+                block_size, last_block_valid_tokens
+            )
+        else:
+            # Fallback to original ring buffer pipeline
+            load_slots = offload_engine.decode_load_slots
+            o_acc, lse_acc = self._decode_ring_buffer_pipeline(
+                q_batched, cpu_block_table, load_slots, offload_engine,
+                block_size, last_block_valid_tokens
@@ -652,3 +657,62 @@ class Attention(nn.Module):
             o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
 
         return o_acc, lse_acc
+
+    def _decode_with_layer_pipeline(
+        self,
+        q_batched: torch.Tensor,
+        cpu_block_table: list,
+        offload_engine,
+        block_size: int,
+        last_block_valid_tokens: int,
+    ):
+        """
+        Decode using cross-layer pipeline for optimized H2D transfer.
+
+        This method uses pre-loaded layer buffers instead of loading
+        blocks one by one. The pipeline loads the next layer's data
+        while the current layer computes, achieving transfer/compute overlap.
+
+        The key insight is that each layer needs the SAME blocks but from
+        different layers of CPU cache. By double-buffering and pipelining
+        across layers, we reduce total latency.
+        """
+        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+
+        num_blocks = len(cpu_block_table)
+        if num_blocks == 0:
+            return None, None
+
+        compute_stream = offload_engine.compute_stream
+
+        # Get KV from pre-loaded layer buffer (triggers next layer loading)
+        prev_k, prev_v = offload_engine.get_decode_layer_kv(self.layer_id, num_blocks)
+
+        # prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
+        # Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
+        total_tokens = num_blocks * block_size
+
+        # Handle partial last block
+        if last_block_valid_tokens < block_size:
+            # Only use valid tokens from last block
+            actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
+            # Flatten and truncate
+            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
+            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
+        else:
+            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
+            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
+
+        # Add batch dimension: [1, total_tokens, kv_heads, head_dim]
+        prev_k_batched = prev_k_flat.unsqueeze(0)
+        prev_v_batched = prev_v_flat.unsqueeze(0)
+
+        # Compute attention on all prefilled blocks at once
+        with torch.cuda.stream(compute_stream):
+            o_acc, lse_acc = flash_attn_with_lse(
+                q_batched, prev_k_batched, prev_v_batched,
+                softmax_scale=self.scale,
+                causal=False,
+            )
+
+        return o_acc, lse_acc
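The partial-last-block handling in `_decode_with_layer_pipeline` reduces to flattening the `[num_blocks, block_size, kv_heads, head_dim]` buffer, truncating the padded tail of the last block, and adding a batch dimension. A standalone shape check with arbitrary example sizes:

```python
import torch

# Shape-only sketch of the partial-last-block handling above; dimensions are arbitrary.
num_blocks, block_size, kv_heads, head_dim = 3, 16, 8, 64
last_block_valid_tokens = 5

prev_k = torch.randn(num_blocks, block_size, kv_heads, head_dim)
actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens   # 37 valid tokens

prev_k_flat = prev_k.reshape(-1, kv_heads, head_dim)[:actual_tokens]      # drop padded tail
prev_k_batched = prev_k_flat.unsqueeze(0)        # [1, actual_tokens, kv_heads, head_dim]
print(prev_k_batched.shape)                      # torch.Size([1, 37, 8, 64])
```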