[claudesquad] update from 'perf_opt-1' on 07 Jan 26 05:58 CST
This commit is contained in:
@@ -479,17 +479,15 @@ class Attention(nn.Module):
|
||||
context,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute decode attention using ring buffer pipeline (same as prefill).
|
||||
Compute decode attention using cross-layer pipeline.
|
||||
|
||||
Uses the same loading mechanism as _chunked_prefill_attention:
|
||||
- Load one block at a time from CPU to GPU slot
|
||||
- Compute attention for each block
|
||||
- Merge results using online softmax
|
||||
- Finally merge with decode buffer (accumulated decode tokens)
|
||||
Optimization: Uses double-buffered layer cache to overlap H2D transfer
|
||||
with computation across layers:
|
||||
- Layer N computes while Layer N+1's data is being loaded
|
||||
- Each layer only waits for its own data, not all layers' data
|
||||
|
||||
This approach is simpler and proven correct (prefill tests pass).
|
||||
The only difference from prefill is the additional decode buffer
|
||||
that stores new tokens generated during decode.
|
||||
This reduces effective latency from O(num_layers * transfer_time) to
|
||||
O(transfer_time + num_layers * compute_time) when transfer < compute.
|
||||
"""
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
@@ -533,13 +531,20 @@ class Attention(nn.Module):
|
||||
)
|
||||
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
load_slots = offload_engine.decode_load_slots # Available slots for loading
|
||||
|
||||
# Use ring buffer pipeline (same as prefill) to load prefilled blocks
|
||||
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
|
||||
q_batched, cpu_block_table, load_slots, offload_engine,
|
||||
block_size, last_block_valid_tokens
|
||||
)
|
||||
# Use cross-layer pipeline if active (initialized in model_runner)
|
||||
if offload_engine.is_pipeline_active():
|
||||
o_acc, lse_acc = self._decode_with_layer_pipeline(
|
||||
q_batched, cpu_block_table, offload_engine,
|
||||
block_size, last_block_valid_tokens
|
||||
)
|
||||
else:
|
||||
# Fallback to original ring buffer pipeline
|
||||
load_slots = offload_engine.decode_load_slots
|
||||
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
|
||||
q_batched, cpu_block_table, load_slots, offload_engine,
|
||||
block_size, last_block_valid_tokens
|
||||
)
|
||||
|
||||
# Now attend to accumulated decode tokens from per-layer decode buffer
|
||||
pos_in_block = context.decode_pos_in_block
|
||||
@@ -652,3 +657,62 @@ class Attention(nn.Module):
|
||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
||||
|
||||
return o_acc, lse_acc
|
||||
|
||||
def _decode_with_layer_pipeline(
|
||||
self,
|
||||
q_batched: torch.Tensor,
|
||||
cpu_block_table: list,
|
||||
offload_engine,
|
||||
block_size: int,
|
||||
last_block_valid_tokens: int,
|
||||
):
|
||||
"""
|
||||
Decode using cross-layer pipeline for optimized H2D transfer.
|
||||
|
||||
This method uses pre-loaded layer buffers instead of loading
|
||||
blocks one by one. The pipeline loads the next layer's data
|
||||
while the current layer computes, achieving transfer/compute overlap.
|
||||
|
||||
The key insight is that each layer needs the SAME blocks but from
|
||||
different layers of CPU cache. By double-buffering and pipelining
|
||||
across layers, we reduce total latency.
|
||||
"""
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
num_blocks = len(cpu_block_table)
|
||||
if num_blocks == 0:
|
||||
return None, None
|
||||
|
||||
compute_stream = offload_engine.compute_stream
|
||||
|
||||
# Get KV from pre-loaded layer buffer (triggers next layer loading)
|
||||
prev_k, prev_v = offload_engine.get_decode_layer_kv(self.layer_id, num_blocks)
|
||||
|
||||
# prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
|
||||
# Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
|
||||
total_tokens = num_blocks * block_size
|
||||
|
||||
# Handle partial last block
|
||||
if last_block_valid_tokens < block_size:
|
||||
# Only use valid tokens from last block
|
||||
actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
|
||||
# Flatten and truncate
|
||||
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
|
||||
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
|
||||
else:
|
||||
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
|
||||
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
|
||||
|
||||
# Add batch dimension: [1, total_tokens, kv_heads, head_dim]
|
||||
prev_k_batched = prev_k_flat.unsqueeze(0)
|
||||
prev_v_batched = prev_v_flat.unsqueeze(0)
|
||||
|
||||
# Compute attention on all prefilled blocks at once
|
||||
with torch.cuda.stream(compute_stream):
|
||||
o_acc, lse_acc = flash_attn_with_lse(
|
||||
q_batched, prev_k_batched, prev_v_batched,
|
||||
softmax_scale=self.scale,
|
||||
causal=False,
|
||||
)
|
||||
|
||||
return o_acc, lse_acc
|
||||
|
||||
Reference in New Issue
Block a user