[claudesquad] update from 'perf_opt-1' on 07 Jan 26 05:58 CST

Zijie Tian
2026-01-07 05:58:23 +08:00
parent aa953ecb59
commit ccf27d3a74
4 changed files with 255 additions and 20 deletions


@@ -479,17 +479,15 @@ class Attention(nn.Module):
         context,
     ) -> torch.Tensor:
         """
-        Compute decode attention using ring buffer pipeline (same as prefill).
+        Compute decode attention using cross-layer pipeline.
-        Uses the same loading mechanism as _chunked_prefill_attention:
-        - Load one block at a time from CPU to GPU slot
-        - Compute attention for each block
-        - Merge results using online softmax
-        - Finally merge with decode buffer (accumulated decode tokens)
+        Optimization: Uses double-buffered layer cache to overlap H2D transfer
+        with computation across layers:
+        - Layer N computes while Layer N+1's data is being loaded
+        - Each layer only waits for its own data, not all layers' data
-        This approach is simpler and proven correct (prefill tests pass).
-        The only difference from prefill is the additional decode buffer
-        that stores new tokens generated during decode.
+        This reduces effective latency from roughly O(num_layers * (transfer_time + compute_time))
+        to O(transfer_time + num_layers * compute_time) when transfer_time < compute_time.
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
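For context, the "merge results using online softmax" step that both the old and new docstrings refer to combines two partial attention outputs via their log-sum-exp (LSE) values. Below is a minimal sketch of that merge, assuming flash-attention-style shapes (output [batch, tokens, heads, head_dim], LSE [batch, heads, tokens]); the function name merge_with_lse is illustrative and is not the repo's merge_attention_outputs.

import torch

def merge_with_lse(o_a, lse_a, o_b, lse_b):
    # Merge two partial attention results computed over disjoint KV ranges.
    # o_*: [batch, tokens, heads, head_dim]; lse_*: [batch, heads, tokens]
    lse_max = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse_max)
    w_b = torch.exp(lse_b - lse_max)
    denom = w_a + w_b
    lse_out = lse_max + torch.log(denom)            # combined log-sum-exp
    # Normalize weights and broadcast over head_dim: [batch, tokens, heads, 1]
    w_a = (w_a / denom).transpose(1, 2).unsqueeze(-1)
    w_b = (w_b / denom).transpose(1, 2).unsqueeze(-1)
    return w_a * o_a + w_b * o_b, lse_out

This merge is associative, which is what lets the decode path fold in one block (or one pre-loaded buffer) at a time and still match full attention over the entire KV history.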
@@ -533,13 +531,20 @@ class Attention(nn.Module):
         )
         offload_engine = kvcache_manager.offload_engine
-        load_slots = offload_engine.decode_load_slots  # Available slots for loading
-        # Use ring buffer pipeline (same as prefill) to load prefilled blocks
-        o_acc, lse_acc = self._decode_ring_buffer_pipeline(
-            q_batched, cpu_block_table, load_slots, offload_engine,
-            block_size, last_block_valid_tokens
-        )
+        # Use cross-layer pipeline if active (initialized in model_runner)
+        if offload_engine.is_pipeline_active():
+            o_acc, lse_acc = self._decode_with_layer_pipeline(
+                q_batched, cpu_block_table, offload_engine,
+                block_size, last_block_valid_tokens
+            )
+        else:
+            # Fallback to original ring buffer pipeline
+            load_slots = offload_engine.decode_load_slots
+            o_acc, lse_acc = self._decode_ring_buffer_pipeline(
+                q_batched, cpu_block_table, load_slots, offload_engine,
+                block_size, last_block_valid_tokens
+            )
         # Now attend to accumulated decode tokens from per-layer decode buffer
         pos_in_block = context.decode_pos_in_block
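The is_pipeline_active() branch above depends on the double-buffered, cross-layer prefetch the new docstring describes. A sketch of that pattern under assumed names (LayerPipelineSketch, prefetch, get, and release are hypothetical, not the OffloadEngine API): each layer's compute stream waits only on its own copy event while the next layer's H2D copy proceeds on a side stream.

import torch

class LayerPipelineSketch:
    """Double-buffered H2D prefetch: layer i's KV lands in buffers[i % 2]."""

    def __init__(self, num_layers, buf_shape, device="cuda"):
        self.num_layers = num_layers
        self.buffers = [torch.empty(buf_shape, device=device) for _ in range(2)]
        self.loaded = [torch.cuda.Event() for _ in range(num_layers)]  # copy finished
        self.freed = [torch.cuda.Event() for _ in range(2)]            # buffer reusable
        self.h2d_stream = torch.cuda.Stream()

    def prefetch(self, layer_id, cpu_kv):
        # cpu_kv should be pinned host memory so the copy is truly asynchronous.
        with torch.cuda.stream(self.h2d_stream):
            self.h2d_stream.wait_event(self.freed[layer_id % 2])  # don't clobber a buffer in use
            self.buffers[layer_id % 2].copy_(cpu_kv, non_blocking=True)
            self.loaded[layer_id].record(self.h2d_stream)

    def get(self, layer_id, next_cpu_kv=None):
        # Wait only for this layer's own copy, then immediately start the next one.
        torch.cuda.current_stream().wait_event(self.loaded[layer_id])
        if next_cpu_kv is not None and layer_id + 1 < self.num_layers:
            self.prefetch(layer_id + 1, next_cpu_kv)
        return self.buffers[layer_id % 2]

    def release(self, layer_id):
        # Record after this layer's attention kernels are enqueued, so the H2D
        # stream knows when buffers[layer_id % 2] may be overwritten.
        self.freed[layer_id % 2].record(torch.cuda.current_stream())

release() is the subtle part: with only two staging buffers, the copy for layer N+2 must not start until layer N's kernels are done with the buffer it reuses.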
@@ -652,3 +657,62 @@ class Attention(nn.Module):
         o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
         return o_acc, lse_acc
+    def _decode_with_layer_pipeline(
+        self,
+        q_batched: torch.Tensor,
+        cpu_block_table: list,
+        offload_engine,
+        block_size: int,
+        last_block_valid_tokens: int,
+    ):
+        """
+        Decode using the cross-layer pipeline for optimized H2D transfer.
+        This method uses pre-loaded layer buffers instead of loading
+        blocks one by one. The pipeline loads the next layer's data
+        while the current layer computes, achieving transfer/compute overlap.
+        The key insight is that every layer needs the SAME block indices,
+        just from its own layer's slice of the CPU cache. Double-buffering
+        and pipelining across layers therefore reduces total latency.
+        """
+        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+        num_blocks = len(cpu_block_table)
+        if num_blocks == 0:
+            return None, None
+        compute_stream = offload_engine.compute_stream
+        # Get KV from the pre-loaded layer buffer (triggers loading of the next layer)
+        prev_k, prev_v = offload_engine.get_decode_layer_kv(self.layer_id, num_blocks)
+        # prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
+        # Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
+        total_tokens = num_blocks * block_size
+        # Handle a partial last block
+        if last_block_valid_tokens < block_size:
+            # Only use the valid tokens from the last block
+            actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
+            # Flatten and truncate
+            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
+            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
+        else:
+            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
+            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
+        # Add batch dimension: [1, total_tokens, kv_heads, head_dim]
+        prev_k_batched = prev_k_flat.unsqueeze(0)
+        prev_v_batched = prev_v_flat.unsqueeze(0)
+        # Compute attention over all prefilled blocks at once
+        with torch.cuda.stream(compute_stream):
+            o_acc, lse_acc = flash_attn_with_lse(
+                q_batched, prev_k_batched, prev_v_batched,
+                softmax_scale=self.scale,
+                causal=False,
+            )
+        return o_acc, lse_acc
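The latency claim in the docstrings can be sanity-checked with back-of-the-envelope numbers; the per-layer timings below are assumed for illustration, not measurements from this change.

num_layers = 32
t_load = 0.8     # ms to copy one layer's prefilled KV blocks from CPU to GPU (assumed)
t_compute = 1.2  # ms to run that layer's decode attention (assumed)

sequential = num_layers * (t_load + t_compute)  # each layer loads, then computes
pipelined = t_load + num_layers * t_compute     # copies hidden behind compute when t_load < t_compute

print(f"sequential: {sequential:.1f} ms, pipelined: {pipelined:.1f} ms")
# sequential: 64.0 ms, pipelined: 39.2 ms

When t_load exceeds t_compute the pipeline instead degenerates to roughly num_layers * t_load, so the win depends on the per-layer transfer staying cheaper than the attention kernel.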