♻️ refactor: migrate chunked decode attention to SparsePolicy
Move decode attention computation from attention.py to SparsePolicy: - Add compute_chunked_decode abstract method to SparsePolicy base class - Implement compute_chunked_decode in FullAttentionPolicy with: - Ring buffer pipeline (_decode_ring_buffer_pipeline) - Cross-layer pipeline (_decode_with_layer_pipeline) - Decode buffer handling - Simplify _chunked_decode_attention to only validate and delegate - Remove _decode_ring_buffer_pipeline and _decode_with_layer_pipeline from attention.py - Add supports_decode check for policy validation This completes the SparsePolicy v5 refactoring where both prefill and decode paths now delegate all computation to the sparse policy. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -192,5 +192,256 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
# Remove batch dimension: [1, seq_len, num_heads, head_dim] -> [seq_len, num_heads, head_dim]
|
# Remove batch dimension: [1, seq_len, num_heads, head_dim] -> [seq_len, num_heads, head_dim]
|
||||||
return final_o.squeeze(0)
|
return final_o.squeeze(0)
|
||||||
|
|
||||||
|
def compute_chunked_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
offload_engine: "OffloadEngine",
|
||||||
|
kvcache_manager: "KVCacheManager",
|
||||||
|
seq: "Sequence",
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute full attention for chunked decode.
|
||||||
|
|
||||||
|
This method handles the complete chunked decode flow:
|
||||||
|
1. Get prefilled CPU blocks
|
||||||
|
2. Apply select_blocks for block filtering
|
||||||
|
3. Load blocks via pipeline (ring buffer or cross-layer)
|
||||||
|
4. Read accumulated decode tokens from decode buffer
|
||||||
|
5. Merge all results
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [batch_size, num_heads, head_dim]
|
||||||
|
layer_id: Current layer index
|
||||||
|
softmax_scale: Softmax scaling factor
|
||||||
|
offload_engine: OffloadEngine for loading blocks
|
||||||
|
kvcache_manager: KVCacheManager for block management
|
||||||
|
seq: Sequence object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attention output [batch_size, 1, num_heads, head_dim]
|
||||||
|
"""
|
||||||
|
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
|
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
||||||
|
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
||||||
|
|
||||||
|
# Get only PREFILLED CPU blocks (exclude the current decode block)
|
||||||
|
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||||
|
if layer_id == 0:
|
||||||
|
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
|
||||||
|
if not cpu_block_table:
|
||||||
|
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
|
||||||
|
|
||||||
|
# Calculate valid tokens in the last CPU block
|
||||||
|
# CRITICAL: Use original prefill length, not current seq length!
|
||||||
|
# CPU blocks are fixed after prefill, their content doesn't change during decode.
|
||||||
|
block_size = kvcache_manager.block_size
|
||||||
|
num_prefill_blocks = len(cpu_block_table)
|
||||||
|
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
|
||||||
|
last_block_valid_tokens = total_prefill_tokens % block_size
|
||||||
|
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
||||||
|
last_block_valid_tokens = block_size # Last block was exactly full
|
||||||
|
|
||||||
|
# Apply sparse policy (self) for block filtering
|
||||||
|
policy_ctx = PolicyContext(
|
||||||
|
query_chunk_idx=0,
|
||||||
|
num_query_chunks=1,
|
||||||
|
layer_id=layer_id,
|
||||||
|
query=q_batched,
|
||||||
|
is_prefill=False,
|
||||||
|
block_size=kvcache_manager.block_size,
|
||||||
|
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||||
|
)
|
||||||
|
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
|
||||||
|
|
||||||
|
# Use cross-layer pipeline if active (initialized in model_runner)
|
||||||
|
if offload_engine.is_pipeline_active():
|
||||||
|
o_acc, lse_acc = self._decode_with_layer_pipeline(
|
||||||
|
q_batched, cpu_block_table, offload_engine,
|
||||||
|
block_size, last_block_valid_tokens, layer_id, softmax_scale
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback to original ring buffer pipeline
|
||||||
|
load_slots = offload_engine.decode_load_slots
|
||||||
|
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
|
||||||
|
q_batched, cpu_block_table, load_slots, offload_engine,
|
||||||
|
block_size, last_block_valid_tokens, layer_id, softmax_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now attend to accumulated decode tokens from per-layer decode buffer
|
||||||
|
# Compute decode position information internally
|
||||||
|
seq_len = len(seq)
|
||||||
|
decode_pos_in_block = (seq_len - 1) % block_size
|
||||||
|
decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
|
||||||
|
decode_start_pos_in_block = decode_start_pos % block_size
|
||||||
|
num_accumulated = decode_pos_in_block - decode_start_pos_in_block + 1
|
||||||
|
|
||||||
|
# Sync compute_stream with default stream before reading decode_buffer
|
||||||
|
compute_stream = offload_engine.compute_stream
|
||||||
|
compute_stream.wait_stream(torch.cuda.default_stream())
|
||||||
|
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
if num_accumulated > 0:
|
||||||
|
# Read from per-layer decode buffer
|
||||||
|
decode_k = offload_engine.decode_k_buffer[layer_id, decode_start_pos_in_block:decode_pos_in_block+1]
|
||||||
|
decode_v = offload_engine.decode_v_buffer[layer_id, decode_start_pos_in_block:decode_pos_in_block+1]
|
||||||
|
decode_k = decode_k.unsqueeze(0)
|
||||||
|
decode_v = decode_v.unsqueeze(0)
|
||||||
|
|
||||||
|
decode_o, decode_lse = flash_attn_with_lse(
|
||||||
|
q_batched, decode_k, decode_v,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if o_acc is None:
|
||||||
|
o_acc = decode_o
|
||||||
|
else:
|
||||||
|
o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
|
||||||
|
|
||||||
|
if o_acc is None:
|
||||||
|
raise RuntimeError("Chunked decode attention failed: no KV available")
|
||||||
|
|
||||||
|
# Sync back to default stream before returning
|
||||||
|
torch.cuda.default_stream().wait_stream(compute_stream)
|
||||||
|
|
||||||
|
return o_acc
|
||||||
|
|
||||||
|
def _decode_ring_buffer_pipeline(
|
||||||
|
self,
|
||||||
|
q_batched: torch.Tensor,
|
||||||
|
cpu_block_table: list,
|
||||||
|
load_slots: list,
|
||||||
|
offload_engine: "OffloadEngine",
|
||||||
|
block_size: int,
|
||||||
|
last_block_valid_tokens: int,
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Ring buffer pipeline for decode prefill loading.
|
||||||
|
|
||||||
|
Loads one block at a time, computes attention, and merges results.
|
||||||
|
Uses load_to_slot_layer / wait_slot_layer / get_kv_for_slot methods.
|
||||||
|
"""
|
||||||
|
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
|
num_blocks = len(cpu_block_table)
|
||||||
|
if num_blocks == 0:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
if not load_slots:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
o_acc, lse_acc = None, None
|
||||||
|
num_slots = len(load_slots)
|
||||||
|
compute_stream = offload_engine.compute_stream
|
||||||
|
|
||||||
|
# Phase 1: Pre-load up to num_slots blocks
|
||||||
|
num_preload = min(num_slots, num_blocks)
|
||||||
|
for i in range(num_preload):
|
||||||
|
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
|
||||||
|
|
||||||
|
# Phase 2: Process blocks with pipeline
|
||||||
|
for block_idx in range(num_blocks):
|
||||||
|
current_slot = load_slots[block_idx % num_slots]
|
||||||
|
cpu_block_id = cpu_block_table[block_idx]
|
||||||
|
|
||||||
|
# Wait for current slot's transfer to complete
|
||||||
|
offload_engine.wait_slot_layer(current_slot)
|
||||||
|
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
# Get KV from slot
|
||||||
|
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
|
||||||
|
|
||||||
|
# Handle partial last block
|
||||||
|
is_last_block = (block_idx == num_blocks - 1)
|
||||||
|
if is_last_block and last_block_valid_tokens < block_size:
|
||||||
|
prev_k = prev_k[:, :last_block_valid_tokens, :, :]
|
||||||
|
prev_v = prev_v[:, :last_block_valid_tokens, :, :]
|
||||||
|
|
||||||
|
# Compute attention
|
||||||
|
prev_o, prev_lse = flash_attn_with_lse(
|
||||||
|
q_batched, prev_k, prev_v,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record compute done for slot reuse
|
||||||
|
offload_engine.record_slot_compute_done(current_slot)
|
||||||
|
|
||||||
|
# Start loading next block (pipeline)
|
||||||
|
next_block_idx = block_idx + num_slots
|
||||||
|
if next_block_idx < num_blocks:
|
||||||
|
offload_engine.load_to_slot_layer(current_slot, layer_id, cpu_block_table[next_block_idx])
|
||||||
|
|
||||||
|
# Merge with accumulated
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
if o_acc is None:
|
||||||
|
o_acc, lse_acc = prev_o, prev_lse
|
||||||
|
else:
|
||||||
|
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
||||||
|
|
||||||
|
return o_acc, lse_acc
|
||||||
|
|
||||||
|
def _decode_with_layer_pipeline(
|
||||||
|
self,
|
||||||
|
q_batched: torch.Tensor,
|
||||||
|
cpu_block_table: list,
|
||||||
|
offload_engine: "OffloadEngine",
|
||||||
|
block_size: int,
|
||||||
|
last_block_valid_tokens: int,
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Decode using cross-layer pipeline for optimized H2D transfer.
|
||||||
|
|
||||||
|
Uses pre-loaded layer buffers instead of loading blocks one by one.
|
||||||
|
The pipeline loads the next layer's data while the current layer
|
||||||
|
computes, achieving transfer/compute overlap.
|
||||||
|
"""
|
||||||
|
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
|
num_blocks = len(cpu_block_table)
|
||||||
|
if num_blocks == 0:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
compute_stream = offload_engine.compute_stream
|
||||||
|
|
||||||
|
# Get KV from pre-loaded layer buffer (triggers next layer loading)
|
||||||
|
prev_k, prev_v = offload_engine.get_decode_layer_kv(layer_id, num_blocks)
|
||||||
|
|
||||||
|
# prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
|
||||||
|
# Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
|
||||||
|
total_tokens = num_blocks * block_size
|
||||||
|
|
||||||
|
# Handle partial last block
|
||||||
|
if last_block_valid_tokens < block_size:
|
||||||
|
# Only use valid tokens from last block
|
||||||
|
actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
|
||||||
|
# Flatten and truncate
|
||||||
|
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
|
||||||
|
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
|
||||||
|
else:
|
||||||
|
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
|
||||||
|
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
|
||||||
|
|
||||||
|
# Add batch dimension: [1, total_tokens, kv_heads, head_dim]
|
||||||
|
prev_k_batched = prev_k_flat.unsqueeze(0)
|
||||||
|
prev_v_batched = prev_v_flat.unsqueeze(0)
|
||||||
|
|
||||||
|
# Compute attention on all prefilled blocks at once
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
o_acc, lse_acc = flash_attn_with_lse(
|
||||||
|
q_batched, prev_k_batched, prev_v_batched,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return o_acc, lse_acc
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return "FullAttentionPolicy()"
|
return "FullAttentionPolicy()"
|
||||||
|
|||||||
@@ -233,5 +233,43 @@ class SparsePolicy(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_chunked_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
offload_engine: "OffloadEngine",
|
||||||
|
kvcache_manager: "KVCacheManager",
|
||||||
|
seq: "Sequence",
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute chunked decode attention (complete flow).
|
||||||
|
|
||||||
|
This is the main entry point for decode attention computation.
|
||||||
|
It defines the complete decode flow:
|
||||||
|
1. Get prefilled blocks from CPU
|
||||||
|
2. Select blocks (call select_blocks)
|
||||||
|
3. Load blocks via pipeline (ring buffer or cross-layer)
|
||||||
|
4. Read accumulated decode tokens from decode buffer
|
||||||
|
5. Merge all results
|
||||||
|
|
||||||
|
The decode position information can be computed internally:
|
||||||
|
- decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
|
||||||
|
- decode_pos_in_block = (len(seq) - 1) % kvcache_manager.block_size
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: [batch_size, num_heads, head_dim] query for decode token
|
||||||
|
layer_id: transformer layer index
|
||||||
|
softmax_scale: softmax scaling factor
|
||||||
|
offload_engine: OffloadEngine for loading blocks
|
||||||
|
kvcache_manager: KVCacheManager for block management
|
||||||
|
seq: Sequence object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[batch_size, 1, num_heads, head_dim] final attention output
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"{self.__class__.__name__}()"
|
return f"{self.__class__.__name__}()"
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from torch import nn
|
|||||||
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
from nanovllm.utils.context import get_context
|
from nanovllm.utils.context import get_context
|
||||||
from nanovllm.kvcache.sparse.policy import PolicyContext
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -237,240 +236,41 @@ class Attention(nn.Module):
|
|||||||
context,
|
context,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute decode attention using cross-layer pipeline.
|
Compute decode attention by delegating to sparse policy.
|
||||||
|
|
||||||
Optimization: Uses double-buffered layer cache to overlap H2D transfer
|
Simplified design:
|
||||||
with computation across layers:
|
- All computation logic is delegated to sparse_policy.compute_chunked_decode()
|
||||||
- Layer N computes while Layer N+1's data is being loaded
|
- This method only validates the policy and delegates
|
||||||
- Each layer only waits for its own data, not all layers' data
|
|
||||||
|
|
||||||
This reduces effective latency from O(num_layers * transfer_time) to
|
The policy handles:
|
||||||
O(transfer_time + num_layers * compute_time) when transfer < compute.
|
1. Loading prefilled blocks from CPU via pipeline
|
||||||
|
2. Computing attention against prefilled KV
|
||||||
|
3. Reading accumulated decode tokens from decode buffer
|
||||||
|
4. Merging all results
|
||||||
"""
|
"""
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
|
||||||
|
|
||||||
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
|
||||||
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
|
||||||
|
|
||||||
kvcache_manager = context.kvcache_manager
|
kvcache_manager = context.kvcache_manager
|
||||||
seq = context.chunked_seq
|
seq = context.chunked_seq
|
||||||
|
|
||||||
# Get only PREFILLED CPU blocks (exclude the current decode block)
|
|
||||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
|
||||||
if self.layer_id == 0:
|
|
||||||
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
|
|
||||||
if not cpu_block_table:
|
|
||||||
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
|
|
||||||
|
|
||||||
# Calculate valid tokens in the last CPU block
|
|
||||||
# CRITICAL: Use original prefill length, not current seq length!
|
|
||||||
# CPU blocks are fixed after prefill, their content doesn't change during decode.
|
|
||||||
block_size = kvcache_manager.block_size
|
|
||||||
num_prefill_blocks = len(cpu_block_table)
|
|
||||||
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
|
|
||||||
last_block_valid_tokens = total_prefill_tokens % block_size
|
|
||||||
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
|
||||||
last_block_valid_tokens = block_size # Last block was exactly full
|
|
||||||
|
|
||||||
offload_engine = kvcache_manager.offload_engine
|
offload_engine = kvcache_manager.offload_engine
|
||||||
|
|
||||||
# Apply sparse policy if enabled (Quest does Top-K selection for decode)
|
# Get sparse policy - required for chunked decode
|
||||||
sparse_policy = kvcache_manager.sparse_policy
|
sparse_policy = kvcache_manager.sparse_policy
|
||||||
if sparse_policy is not None:
|
if sparse_policy is None:
|
||||||
policy_ctx = PolicyContext(
|
raise RuntimeError("sparse_policy is required for chunked decode")
|
||||||
query_chunk_idx=0,
|
|
||||||
num_query_chunks=1,
|
|
||||||
layer_id=self.layer_id,
|
|
||||||
query=q_batched,
|
|
||||||
is_prefill=False,
|
|
||||||
block_size=kvcache_manager.block_size,
|
|
||||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
|
||||||
)
|
|
||||||
cpu_block_table = sparse_policy.select_blocks(
|
|
||||||
cpu_block_table, offload_engine, policy_ctx
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use cross-layer pipeline if active (initialized in model_runner)
|
# Check if policy supports decode phase
|
||||||
if offload_engine.is_pipeline_active():
|
if not sparse_policy.supports_decode:
|
||||||
o_acc, lse_acc = self._decode_with_layer_pipeline(
|
raise RuntimeError(f"{sparse_policy} does not support decode phase")
|
||||||
q_batched, cpu_block_table, offload_engine,
|
|
||||||
block_size, last_block_valid_tokens
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Fallback to original ring buffer pipeline
|
|
||||||
load_slots = offload_engine.decode_load_slots
|
|
||||||
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
|
|
||||||
q_batched, cpu_block_table, load_slots, offload_engine,
|
|
||||||
block_size, last_block_valid_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now attend to accumulated decode tokens from per-layer decode buffer
|
# [DEBUG] Verify execution path
|
||||||
pos_in_block = context.decode_pos_in_block
|
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
|
||||||
start_pos = context.decode_start_pos_in_block
|
f"policy={sparse_policy}, layer={self.layer_id}")
|
||||||
num_accumulated = pos_in_block - start_pos + 1
|
|
||||||
|
|
||||||
# Sync compute_stream with default stream before reading decode_buffer
|
# Delegate all computation to policy (no flash_attn or merge calls here!)
|
||||||
compute_stream = offload_engine.compute_stream
|
return sparse_policy.compute_chunked_decode(
|
||||||
compute_stream.wait_stream(torch.cuda.default_stream())
|
q,
|
||||||
|
self.layer_id,
|
||||||
with torch.cuda.stream(compute_stream):
|
self.scale,
|
||||||
if num_accumulated > 0:
|
offload_engine,
|
||||||
# Read from per-layer decode buffer
|
kvcache_manager,
|
||||||
decode_k = offload_engine.decode_k_buffer[self.layer_id, start_pos:pos_in_block+1]
|
seq,
|
||||||
decode_v = offload_engine.decode_v_buffer[self.layer_id, start_pos:pos_in_block+1]
|
)
|
||||||
decode_k = decode_k.unsqueeze(0)
|
|
||||||
decode_v = decode_v.unsqueeze(0)
|
|
||||||
|
|
||||||
decode_o, decode_lse = flash_attn_with_lse(
|
|
||||||
q_batched, decode_k, decode_v,
|
|
||||||
softmax_scale=self.scale,
|
|
||||||
causal=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if o_acc is None:
|
|
||||||
o_acc = decode_o
|
|
||||||
else:
|
|
||||||
o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
|
|
||||||
|
|
||||||
if o_acc is None:
|
|
||||||
raise RuntimeError("Chunked decode attention failed: no KV available")
|
|
||||||
|
|
||||||
# Sync back to default stream before returning
|
|
||||||
torch.cuda.default_stream().wait_stream(compute_stream)
|
|
||||||
|
|
||||||
return o_acc
|
|
||||||
|
|
||||||
def _decode_ring_buffer_pipeline(
|
|
||||||
self,
|
|
||||||
q_batched: torch.Tensor,
|
|
||||||
cpu_block_table: list,
|
|
||||||
load_slots: list,
|
|
||||||
offload_engine,
|
|
||||||
block_size: int,
|
|
||||||
last_block_valid_tokens: int,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Ring buffer pipeline for decode prefill loading (same mechanism as prefill).
|
|
||||||
|
|
||||||
Loads one block at a time, computes attention, and merges results.
|
|
||||||
Uses the same load_to_slot_layer / wait_slot_layer / get_kv_for_slot
|
|
||||||
methods as prefill for proven correctness.
|
|
||||||
"""
|
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
|
||||||
|
|
||||||
num_blocks = len(cpu_block_table)
|
|
||||||
if num_blocks == 0:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
if not load_slots:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
o_acc, lse_acc = None, None
|
|
||||||
num_slots = len(load_slots)
|
|
||||||
compute_stream = offload_engine.compute_stream
|
|
||||||
|
|
||||||
# Phase 1: Pre-load up to num_slots blocks
|
|
||||||
num_preload = min(num_slots, num_blocks)
|
|
||||||
for i in range(num_preload):
|
|
||||||
offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])
|
|
||||||
|
|
||||||
# Phase 2: Process blocks with pipeline
|
|
||||||
for block_idx in range(num_blocks):
|
|
||||||
current_slot = load_slots[block_idx % num_slots]
|
|
||||||
cpu_block_id = cpu_block_table[block_idx]
|
|
||||||
|
|
||||||
# Wait for current slot's transfer to complete
|
|
||||||
offload_engine.wait_slot_layer(current_slot)
|
|
||||||
|
|
||||||
with torch.cuda.stream(compute_stream):
|
|
||||||
# Get KV from slot
|
|
||||||
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
|
|
||||||
|
|
||||||
# Handle partial last block
|
|
||||||
is_last_block = (block_idx == num_blocks - 1)
|
|
||||||
if is_last_block and last_block_valid_tokens < block_size:
|
|
||||||
prev_k = prev_k[:, :last_block_valid_tokens, :, :]
|
|
||||||
prev_v = prev_v[:, :last_block_valid_tokens, :, :]
|
|
||||||
|
|
||||||
# Compute attention
|
|
||||||
prev_o, prev_lse = flash_attn_with_lse(
|
|
||||||
q_batched, prev_k, prev_v,
|
|
||||||
softmax_scale=self.scale,
|
|
||||||
causal=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Record compute done for slot reuse
|
|
||||||
offload_engine.record_slot_compute_done(current_slot)
|
|
||||||
|
|
||||||
# Start loading next block (pipeline)
|
|
||||||
next_block_idx = block_idx + num_slots
|
|
||||||
if next_block_idx < num_blocks:
|
|
||||||
offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])
|
|
||||||
|
|
||||||
# Merge with accumulated
|
|
||||||
with torch.cuda.stream(compute_stream):
|
|
||||||
if o_acc is None:
|
|
||||||
o_acc, lse_acc = prev_o, prev_lse
|
|
||||||
else:
|
|
||||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
|
||||||
|
|
||||||
return o_acc, lse_acc
|
|
||||||
|
|
||||||
def _decode_with_layer_pipeline(
|
|
||||||
self,
|
|
||||||
q_batched: torch.Tensor,
|
|
||||||
cpu_block_table: list,
|
|
||||||
offload_engine,
|
|
||||||
block_size: int,
|
|
||||||
last_block_valid_tokens: int,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Decode using cross-layer pipeline for optimized H2D transfer.
|
|
||||||
|
|
||||||
This method uses pre-loaded layer buffers instead of loading
|
|
||||||
blocks one by one. The pipeline loads the next layer's data
|
|
||||||
while the current layer computes, achieving transfer/compute overlap.
|
|
||||||
|
|
||||||
The key insight is that each layer needs the SAME blocks but from
|
|
||||||
different layers of CPU cache. By double-buffering and pipelining
|
|
||||||
across layers, we reduce total latency.
|
|
||||||
"""
|
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
|
||||||
|
|
||||||
num_blocks = len(cpu_block_table)
|
|
||||||
if num_blocks == 0:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
compute_stream = offload_engine.compute_stream
|
|
||||||
|
|
||||||
# Get KV from pre-loaded layer buffer (triggers next layer loading)
|
|
||||||
prev_k, prev_v = offload_engine.get_decode_layer_kv(self.layer_id, num_blocks)
|
|
||||||
|
|
||||||
# prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
|
|
||||||
# Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
|
|
||||||
total_tokens = num_blocks * block_size
|
|
||||||
|
|
||||||
# Handle partial last block
|
|
||||||
if last_block_valid_tokens < block_size:
|
|
||||||
# Only use valid tokens from last block
|
|
||||||
actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
|
|
||||||
# Flatten and truncate
|
|
||||||
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
|
|
||||||
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
|
|
||||||
else:
|
|
||||||
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
|
|
||||||
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
|
|
||||||
|
|
||||||
# Add batch dimension: [1, total_tokens, kv_heads, head_dim]
|
|
||||||
prev_k_batched = prev_k_flat.unsqueeze(0)
|
|
||||||
prev_v_batched = prev_v_flat.unsqueeze(0)
|
|
||||||
|
|
||||||
# Compute attention on all prefilled blocks at once
|
|
||||||
with torch.cuda.stream(compute_stream):
|
|
||||||
o_acc, lse_acc = flash_attn_with_lse(
|
|
||||||
q_batched, prev_k_batched, prev_v_batched,
|
|
||||||
softmax_scale=self.scale,
|
|
||||||
causal=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return o_acc, lse_acc
|
|
||||||
|
|||||||
Reference in New Issue
Block a user