572 lines
25 KiB
Python
572 lines
25 KiB
Python
import logging
|
|
import torch
|
|
import torch.cuda.nvtx
|
|
from torch import nn
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
|
from nanovllm.utils.context import get_context
|
|
from nanovllm.kvcache.sparse.policy import PolicyContext
|
|
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Triton kernel that copies one token's flattened key/value row into the KV cache.
# Program `idx` reads slot_mapping[idx]; a slot of -1 means "skip this token".
# Otherwise D contiguous elements starting at idx * key_stride (resp.
# idx * value_stride) are copied to offset slot * D of k_cache (resp. v_cache),
# i.e. both caches are addressed as rows of width D.
# NOTE(review): loads/stores are unmasked and tl.arange(0, D) normally requires a
# power-of-two extent, so D (= num_heads * head_dim from store_kvcache) presumably
# must be a power of two — confirm for every supported model config.
@triton.jit
def store_kvcache_kernel(
    key_ptr,  # base pointer of the key tensor
    key_stride,  # element offset between consecutive key rows (key.stride(0) at the call site)
    value_ptr,  # base pointer of the value tensor
    value_stride,  # element offset between consecutive value rows
    k_cache_ptr,  # base pointer of the key cache
    v_cache_ptr,  # base pointer of the value cache
    slot_mapping_ptr,  # per-token destination slot; -1 = do not store
    D: tl.constexpr,  # elements per row (compile-time constant)
):
    # One program per token row.
    idx = tl.program_id(0)
    slot = tl.load(slot_mapping_ptr + idx)
    # Tokens without a cache slot are skipped entirely.
    if slot == -1: return
    key_offsets = idx * key_stride + tl.arange(0, D)
    value_offsets = idx * value_stride + tl.arange(0, D)
    key = tl.load(key_ptr + key_offsets)
    value = tl.load(value_ptr + value_offsets)
    # Destination row `slot` in the (row-width D) cache.
    cache_offsets = slot * D + tl.arange(0, D)
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)
|
|
|
|
|
|
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
    """Scatter per-token key/value rows into the KV cache with a Triton kernel.

    ``key``/``value`` are unpacked as (N, num_heads, head_dim). Token ``i`` is
    written to cache slot ``slot_mapping[i]``; the kernel skips slots equal to -1.

    The assertions require each token's heads to be packed back-to-back with a
    contiguous innermost dim, so the kernel can copy one flat row of
    ``num_heads * head_dim`` elements. They also require the caches' dim-1 stride
    to equal that row width.
    """
    n_tokens, num_heads, head_dim = key.shape
    row_width = num_heads * head_dim

    # Layout checks for the inputs: contiguous last dim, heads packed densely.
    for tensor in (key, value):
        assert tensor.stride(-1) == 1 and tensor.stride(1) == head_dim
    # The kernel addresses the caches as rows of exactly `row_width` elements.
    for cache in (k_cache, v_cache):
        assert cache.stride(1) == row_width
    # Exactly one destination slot per token.
    assert slot_mapping.numel() == n_tokens

    # One Triton program per token.
    launch_grid = (n_tokens,)
    store_kvcache_kernel[launch_grid](
        key, key.stride(0),
        value, value.stride(0),
        k_cache, v_cache,
        slot_mapping,
        row_width,
    )
|
|
|
|
|
|
class Attention(nn.Module):
    """Attention layer with paged KV-cache writes and optional CPU-offloaded chunked attention.

    ``forward`` dispatches on the global context returned by ``get_context()``:

    * regular prefill: ``flash_attn_varlen_func`` over the packed batch;
    * prefix-cached prefill: ``flash_attn_varlen_func`` against the paged cache;
    * regular decode: ``flash_attn_with_kvcache``;
    * chunked prefill / chunked decode: KV of earlier chunks lives in CPU blocks
      and is streamed back through GPU slots owned by the offload engine, with
      partial results merged via (output, LSE) online-softmax merging.

    Attributes:
        num_heads: number of query heads.
        head_dim: per-head dimension.
        scale: softmax scale passed to every attention kernel.
        num_kv_heads: number of key/value heads.
        k_cache, v_cache: GPU KV cache tensors. They start empty and are
            presumably assigned by the model runner (confirm). While empty,
            forward() skips the cache write.
        layer_id: layer index set by the model runner after model creation.
            -1 means "unset" and disables loading of previous chunks in
            chunked prefill.
    """

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        # Empty placeholders until real cache tensors are assigned.
        self.k_cache = self.v_cache = torch.tensor([])
        # Layer ID set by model_runner after model creation
        self.layer_id: int = -1

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Write this step's K/V into the cache (if allocated) and run attention.

        Args:
            q, k, v: query/key/value tensors for the tokens of this step.

        Returns:
            The attention output of whichever path the context selects. The
            decode paths pass ``q.unsqueeze(1)`` to the kernel, so their output
            carries that extra length-1 dimension (presumably
            [batch, 1, heads, head_dim] — confirm against the caller).
        """
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        if k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)

        if context.is_prefill:
            if context.is_chunked_prefill:
                # Chunked prefill: merge attention from previous KV
                o = self._chunked_prefill_attention(q, k, v, context)
            elif context.block_tables is not None:  # prefix cache
                # Attend over the paged cache (store_kvcache above has just
                # written this step's K/V into it) instead of the raw k/v.
                k, v = k_cache, v_cache
                o = flash_attn_varlen_func(q, k, v,
                                           max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                           max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                           softmax_scale=self.scale, causal=True, block_table=context.block_tables)
            else:
                o = flash_attn_varlen_func(q, k, v,
                                           max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                           max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                           softmax_scale=self.scale, causal=True, block_table=context.block_tables)
        else:  # decode
            if context.is_chunked_prefill:
                # Chunked decode: need to load all KV from CPU+GPU
                # Store current decode token to per-layer decode buffer
                # This is needed because GPU cache has no layer dimension,
                # so all layers would overwrite each other in decode_slot.
                kvcache_manager = context.kvcache_manager
                offload_engine = kvcache_manager.offload_engine
                pos_in_block = context.decode_pos_in_block
                # k, v shape: [1, kv_heads, head_dim]
                offload_engine.decode_k_buffer[self.layer_id, pos_in_block].copy_(k.squeeze(0))
                offload_engine.decode_v_buffer[self.layer_id, pos_in_block].copy_(v.squeeze(0))
                o = self._chunked_decode_attention(q, k, v, context)
            else:
                o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                            cache_seqlens=context.context_lens, block_table=context.block_tables,
                                            softmax_scale=self.scale, causal=True)
        return o

    def _chunked_prefill_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute attention with unified ring buffer for chunked prefill.

        Ring buffer design:
        - Current chunk's KV is written to ring_slot[chunk_idx % N]
        - Previous chunks' KV are loaded from CPU using N-1 available slots
        - Pipeline: pre-fill slots, then process with overlapped load/compute

        For each layer:
        1. Current chunk's KV is in k_batched, v_batched (just written by model)
        2. Load previous chunks from CPU using available slots (pipeline)
        3. Compute attention against previous KV (no causal mask)
        4. Compute attention against current KV (causal)
        5. Merge all results using online softmax
        6. Offload this layer's write slot to the chunk's CPU block

        Args:
            q, k, v: [total_tokens, heads, head_dim] tensors for the current chunk.
            context: the global context. It must expose ``current_chunk_idx`` and
                ``kvcache_manager``, and optionally ``chunked_seq`` / ``num_chunks``.

        Returns:
            [total_tokens, heads, head_dim] attention output for the chunk.
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        current_chunk_idx = context.current_chunk_idx
        torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")

        # q, k, v shape: [total_tokens, num_heads, head_dim]
        # Reshape for flash attention: [batch, seq, heads, dim]
        q_batched = q.unsqueeze(0)  # [1, total_tokens, heads, dim]
        k_batched = k.unsqueeze(0)
        v_batched = v.unsqueeze(0)

        o_acc = None
        lse_acc = None

        kvcache_manager = context.kvcache_manager
        seq = getattr(context, 'chunked_seq', None)

        if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
            # Get prefilled CPU blocks (blocks from previous chunks)
            cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

            # Apply sparse policy if enabled
            if cpu_block_table and kvcache_manager.sparse_policy is not None:
                num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
                policy_ctx = PolicyContext(
                    query_chunk_idx=current_chunk_idx,
                    num_query_chunks=num_chunks,
                    layer_id=self.layer_id,
                    query=None,  # Prefill typically doesn't use query for selection
                    is_prefill=True,
                    block_size=kvcache_manager.block_size,
                    total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
                )
                cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
                    cpu_block_table, policy_ctx
                )

            if cpu_block_table:
                offload_engine = kvcache_manager.offload_engine

                # Previous-chunk attention runs on compute_stream and reads
                # q_batched, which was produced on the default stream. CUDA does
                # not order non-default streams implicitly, so add the dependency.
                offload_engine.compute_stream.wait_stream(torch.cuda.default_stream())

                # Get write slot for current chunk and available load slots
                write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
                load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
                pipeline_depth = len(load_slots)

                if pipeline_depth == 0:
                    # Only 1 slot total, cannot pipeline - use sync loading.
                    # NOTE(review): the sync path loads into slot 0. If slot 0 is
                    # also `write_slot`, previous blocks overwrite this chunk's KV
                    # before the per-layer offload below reads it — confirm.
                    o_acc, lse_acc = self._sync_load_previous_chunks(
                        q_batched, cpu_block_table, offload_engine
                    )
                else:
                    # Use ring buffer pipeline
                    o_acc, lse_acc = self._ring_buffer_pipeline_load(
                        q_batched, cpu_block_table, load_slots, offload_engine,
                        current_chunk_idx
                    )

        # Compute attention against current chunk's KV (with causal mask)
        torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
        current_o, current_lse = flash_attn_with_lse(
            q_batched,
            k_batched,
            v_batched,
            softmax_scale=self.scale,
            causal=True,
        )
        torch.cuda.nvtx.range_pop()

        # Merge with accumulated
        if o_acc is None:
            final_o = current_o
        else:
            # IMPORTANT: o_acc was computed on compute_stream. We need to sync before
            # reading it on the default stream for the merge operation.
            if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
                offload_engine = kvcache_manager.offload_engine
                torch.cuda.default_stream().wait_stream(offload_engine.compute_stream)

            torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
            final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
            torch.cuda.nvtx.range_pop()

        torch.cuda.nvtx.range_pop()  # ChunkedPrefill

        # Per-layer offload: In new GPU cache architecture (no layer dimension),
        # each layer must offload its KV to CPU before next layer overwrites the GPU slot.
        if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
            offload_engine = kvcache_manager.offload_engine
            write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
            if seq is not None:
                cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
                if current_chunk_idx < len(cpu_block_ids):
                    cpu_block_id = cpu_block_ids[current_chunk_idx]
                    offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)

        # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
        return final_o.squeeze(0)

    def _sync_load_previous_chunks(
        self,
        q_batched: torch.Tensor,
        cpu_block_table: list,
        offload_engine,
    ):
        """Synchronous loading fallback when pipeline_depth=0.

        Loads each CPU block into slot 0 one at a time and attends to it
        (no causal mask) on compute_stream. Partial results are merged with
        online softmax.

        Returns:
            (o_acc, lse_acc) over all blocks, or (None, None) if
            ``cpu_block_table`` is empty.
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        o_acc, lse_acc = None, None
        compute_stream = offload_engine.compute_stream

        for cpu_block_id in cpu_block_table:
            # Load to slot 0 (single slot)
            offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
            offload_engine.wait_slot_layer(0)

            # IMPORTANT: Must use compute_stream to match wait_slot_layer
            with torch.cuda.stream(compute_stream):
                prev_k, prev_v = offload_engine.get_kv_for_slot(0)

                prev_o, prev_lse = flash_attn_with_lse(
                    q_batched, prev_k, prev_v,
                    softmax_scale=self.scale,
                    causal=False,
                )
                # Record compute done so the next load into slot 0 waits for
                # this attention to finish reading the slot (same contract as
                # the pipelined paths).
                offload_engine.record_slot_compute_done(0)

                # Merge on compute_stream, where prev_o/prev_lse were produced.
                if o_acc is None:
                    o_acc, lse_acc = prev_o, prev_lse
                else:
                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

        return o_acc, lse_acc

    def _ring_buffer_pipeline_load(
        self,
        q_batched: torch.Tensor,
        cpu_block_table: list,
        load_slots: list,
        offload_engine,
        current_chunk_idx: int = -1,
    ):
        """
        Ring buffer async pipeline loading with double buffering.

        Uses compute_done events to ensure safe buffer reuse:
        - Before loading to slot X, wait for previous compute on slot X to finish
        - Before computing on slot X, wait for load to slot X to finish

        Timeline with 2 slots (A, B):
        ┌──────────────┐
        │  Load B0→A   │
        └──────────────┘
                        ┌──────────────┐      ┌──────────────┐
                        │  Load B1→B   │      │  Load B2→A   │  ...
                        └──────────────┘      └──────────────┘
                              ↘                     ↘
                        ┌──────────────┐      ┌──────────────┐
                        │ Compute(A)   │      │ Compute(B)   │  ...
                        └──────────────┘      └──────────────┘

        The load_to_slot_layer internally waits for compute_done[slot] before
        starting the transfer, ensuring no data race.

        Args:
            q_batched: [1, total_tokens, heads, dim] query for the current chunk.
            cpu_block_table: CPU block ids of previous chunks to attend to.
            load_slots: GPU slots available for loading (excludes the write slot).
            offload_engine: engine providing transfers, slots and compute_stream.
            current_chunk_idx: unused here; kept for interface compatibility.

        Returns:
            (o_acc, lse_acc) over all blocks, or (None, None) when there are no
            blocks or no load slots.
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        num_blocks = len(cpu_block_table)
        if num_blocks == 0:
            return None, None

        pipeline_depth = len(load_slots)
        if pipeline_depth == 0:
            return None, None

        o_acc, lse_acc = None, None

        if pipeline_depth == 1:
            # Only 1 slot available, cannot pipeline - use synchronous mode
            # IMPORTANT: Must use compute_stream to match synchronization in
            # load_to_slot_layer (waits for compute_done) and wait_slot_layer
            slot = load_slots[0]
            compute_stream = offload_engine.compute_stream
            for cpu_block_id in cpu_block_table:
                offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
                offload_engine.wait_slot_layer(slot)

                with torch.cuda.stream(compute_stream):
                    # Debug: call hooks on compute_stream (synchronized with transfer)
                    if offload_engine.debug_mode:
                        offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)

                    prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
                    prev_o, prev_lse = flash_attn_with_lse(
                        q_batched, prev_k, prev_v,
                        softmax_scale=self.scale,
                        causal=False,
                    )
                    # Record compute done so next load can safely reuse this slot
                    offload_engine.record_slot_compute_done(slot)
                    if o_acc is None:
                        o_acc, lse_acc = prev_o, prev_lse
                    else:
                        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
            return o_acc, lse_acc

        # N-way pipeline: use ALL available slots for maximum overlap.
        # Up to num_slots blocks are in flight at once.
        num_slots = len(load_slots)

        # Phase 1: Pre-load up to num_slots blocks to fill the pipeline
        # This starts all transfers in parallel, utilizing full PCIe bandwidth
        num_preload = min(num_slots, num_blocks)
        for i in range(num_preload):
            offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])

        # Phase 2: Main loop - compute and immediately reuse slot for next transfer
        # Use dedicated compute_stream (not default stream) to enable overlap with transfers
        compute_stream = offload_engine.compute_stream

        for block_idx, cpu_block_id in enumerate(cpu_block_table):
            torch.cuda.nvtx.range_push(f"PipelineBlock: L{self.layer_id} B{block_idx}")

            # Cycle through slots: slot[block_idx % num_slots]
            current_slot = load_slots[block_idx % num_slots]

            # Wait for current slot's transfer to complete (on compute_stream)
            offload_engine.wait_slot_layer(current_slot)

            # Compute attention on current slot's data
            # IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
            with torch.cuda.stream(compute_stream):
                # Debug: call hooks on compute_stream (synchronized with transfer)
                if offload_engine.debug_mode:
                    offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id)

                torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
                prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
                prev_o, prev_lse = flash_attn_with_lse(
                    q_batched, prev_k, prev_v,
                    softmax_scale=self.scale,
                    causal=False,
                )
                torch.cuda.nvtx.range_pop()

                # Record compute done - this allows the next transfer to safely overwrite this slot
                offload_engine.record_slot_compute_done(current_slot)

            # Immediately start loading the NEXT block into this slot (if more blocks remain)
            # Key insight: reuse current_slot immediately after compute is done!
            next_block_idx = block_idx + num_slots
            if next_block_idx < num_blocks:
                offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])

            # Merge with accumulated (also on compute_stream for consistency)
            with torch.cuda.stream(compute_stream):
                if o_acc is None:
                    o_acc, lse_acc = prev_o, prev_lse
                else:
                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

            torch.cuda.nvtx.range_pop()  # PipelineBlock

        return o_acc, lse_acc

    def _chunked_decode_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute decode attention using ring buffer pipeline (same as prefill).

        Uses the same loading mechanism as _chunked_prefill_attention:
        - Load one block at a time from CPU to GPU slot
        - Compute attention for each block
        - Merge results using online softmax
        - Finally merge with decode buffer (accumulated decode tokens)

        The only difference from prefill is the additional per-layer decode
        buffer, which stores new tokens generated during decode.

        Args:
            q: decode query; it is unsqueezed at dim 1 before use.
            k, v: unused here. forward() has already copied them into the
                per-layer decode buffer.
            context: global context exposing ``kvcache_manager``, ``chunked_seq``,
                ``decode_pos_in_block`` and ``decode_start_pos_in_block``.

        Returns:
            The merged attention output.

        Raises:
            RuntimeError: if there are no prefilled CPU blocks, or no KV at all.
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
        q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]

        kvcache_manager = context.kvcache_manager
        seq = context.chunked_seq

        # Get only PREFILLED CPU blocks (exclude the current decode block)
        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
        if self.layer_id == 0 and logger.isEnabledFor(logging.DEBUG):
            # Guarded so list(seq.block_table) is not built on every decode step.
            logger.debug(
                "Decode attention: cpu_block_table=%s, seq.block_table=%s",
                cpu_block_table, list(seq.block_table),
            )
        if not cpu_block_table:
            raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")

        # Note: For chunked prefill, each block is exactly block_size tokens.
        # The cpu_block_table only contains full prefill blocks, so the last
        # block is full as well.
        block_size = kvcache_manager.block_size
        last_block_valid_tokens = block_size

        # Apply sparse policy if enabled
        if kvcache_manager.sparse_policy is not None:
            policy_ctx = PolicyContext(
                query_chunk_idx=0,
                num_query_chunks=1,
                layer_id=self.layer_id,
                query=q_batched,
                is_prefill=False,
                block_size=kvcache_manager.block_size,
                total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
            )
            cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
                cpu_block_table, policy_ctx
            )

        offload_engine = kvcache_manager.offload_engine
        load_slots = offload_engine.decode_load_slots  # Available slots for loading
        compute_stream = offload_engine.compute_stream

        # The pipeline below reads q_batched on compute_stream. q (and the decode
        # buffer written in forward()) come from the default stream, so order
        # compute_stream after it before any compute is issued.
        compute_stream.wait_stream(torch.cuda.default_stream())

        # Use ring buffer pipeline (same as prefill) to load prefilled blocks
        o_acc, lse_acc = self._decode_ring_buffer_pipeline(
            q_batched, cpu_block_table, load_slots, offload_engine,
            block_size, last_block_valid_tokens
        )

        # Now attend to accumulated decode tokens from per-layer decode buffer
        pos_in_block = context.decode_pos_in_block
        start_pos = context.decode_start_pos_in_block
        num_accumulated = pos_in_block - start_pos + 1

        # Sync compute_stream with default stream before reading decode_buffer
        # (kept in addition to the wait above, preserving the original ordering).
        compute_stream.wait_stream(torch.cuda.default_stream())

        with torch.cuda.stream(compute_stream):
            if num_accumulated > 0:
                # Read from per-layer decode buffer
                decode_k = offload_engine.decode_k_buffer[self.layer_id, start_pos:pos_in_block + 1]
                decode_v = offload_engine.decode_v_buffer[self.layer_id, start_pos:pos_in_block + 1]
                decode_k = decode_k.unsqueeze(0)
                decode_v = decode_v.unsqueeze(0)

                decode_o, decode_lse = flash_attn_with_lse(
                    q_batched, decode_k, decode_v,
                    softmax_scale=self.scale,
                    causal=False,
                )

                if o_acc is None:
                    o_acc = decode_o
                else:
                    o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)

        if o_acc is None:
            raise RuntimeError("Chunked decode attention failed: no KV available")

        # Sync back to default stream before returning
        torch.cuda.default_stream().wait_stream(compute_stream)

        return o_acc

    def _decode_ring_buffer_pipeline(
        self,
        q_batched: torch.Tensor,
        cpu_block_table: list,
        load_slots: list,
        offload_engine,
        block_size: int,
        last_block_valid_tokens: int,
    ):
        """
        Ring-buffer pipeline that streams prefilled CPU blocks for decode.

        Loads blocks into ``load_slots`` round-robin, computes attention for each
        block (no causal mask) on compute_stream, and merges the results. It uses
        the same load_to_slot_layer / wait_slot_layer / get_kv_for_slot /
        record_slot_compute_done contract as the prefill pipeline. If
        ``last_block_valid_tokens < block_size``, the last block's KV is truncated
        along dim 1.

        Returns:
            (o_acc, lse_acc) over all blocks, or (None, None) when there are no
            blocks or no load slots.
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        num_blocks = len(cpu_block_table)
        if num_blocks == 0:
            return None, None

        if not load_slots:
            return None, None

        o_acc, lse_acc = None, None
        num_slots = len(load_slots)
        compute_stream = offload_engine.compute_stream

        # Phase 1: Pre-load up to num_slots blocks
        num_preload = min(num_slots, num_blocks)
        for i in range(num_preload):
            offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])

        # Phase 2: Process blocks with pipeline
        for block_idx in range(num_blocks):
            current_slot = load_slots[block_idx % num_slots]

            # Wait for current slot's transfer to complete
            offload_engine.wait_slot_layer(current_slot)

            with torch.cuda.stream(compute_stream):
                # Get KV from slot
                prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)

                # Handle partial last block
                is_last_block = (block_idx == num_blocks - 1)
                if is_last_block and last_block_valid_tokens < block_size:
                    prev_k = prev_k[:, :last_block_valid_tokens, :, :]
                    prev_v = prev_v[:, :last_block_valid_tokens, :, :]

                # Compute attention
                prev_o, prev_lse = flash_attn_with_lse(
                    q_batched, prev_k, prev_v,
                    softmax_scale=self.scale,
                    causal=False,
                )

                # Record compute done for slot reuse
                offload_engine.record_slot_compute_done(current_slot)

            # Start loading next block (pipeline)
            next_block_idx = block_idx + num_slots
            if next_block_idx < num_blocks:
                offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])

            # Merge with accumulated
            with torch.cuda.stream(compute_stream):
                if o_acc is None:
                    o_acc, lse_acc = prev_o, prev_lse
                else:
                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

        return o_acc, lse_acc
|