[refactor] Refactor the current GPU and CPU block allocation strategy.
@@ -97,51 +97,89 @@ class Attention(nn.Module):
         context,
     ) -> torch.Tensor:
         """
-        Compute attention with chunked KV from CPU cache.
+        Compute attention with Ping-Pong dual buffer for chunked prefill.
 
         For chunked prefill:
-        1. Load previous KV from CPU for this layer
-        2. Compute attention against previous KV (no causal mask)
+        1. Load previous KV from CPU using Ping-Pong (if any previous chunks)
+        2. Compute attention against previous KV chunks (no causal mask)
         3. Compute attention against current chunk's KV (causal)
-        4. Merge results using online softmax
+        4. Merge all results using online softmax
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
 
         # q, k, v shape: [total_tokens, num_heads, head_dim]
         total_tokens = q.shape[0]
 
         # Reshape for flash attention: [batch, seq, heads, dim]
         q_batched = q.unsqueeze(0)  # [1, total_tokens, heads, dim]
         k_batched = k.unsqueeze(0)
         v_batched = v.unsqueeze(0)
 
-        accumulated_o = None
-        accumulated_lse = None
+        o_acc = None
+        lse_acc = None
 
-        # Load previous KV from CPU for this layer
-        if context.offload_engine is not None and self.layer_id >= 0:
-            # Get the kvcache_manager from context
-            kvcache_manager = context.offload_engine
-
-            # For each sequence in the chunk, load previous KV
-            # Currently assuming single sequence
-            if hasattr(context, 'chunked_seq') and context.chunked_seq is not None:
-                prev_k, prev_v = kvcache_manager.load_prev_kv_for_layer(
-                    context.chunked_seq,
-                    self.layer_id,
-                )
-
-                if prev_k is not None and prev_v is not None:
-                    # Compute attention against previous KV (no causal mask)
+        # Load previous KV from CPU using Ping-Pong
+        # Note: context.offload_engine is actually HybridKVCacheManager
+        kvcache_manager = context.offload_engine
+        seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
+
+        if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
+            # Get prefilled CPU blocks (blocks already written in previous chunks)
+            cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
+
+            if cpu_block_table:
+                offload_engine = kvcache_manager.offload_engine
+                ping_size = offload_engine.ping_size
+                num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
+                current_buffer = "ping"
+
+                # Prefetch first chunk to Ping buffer
+                first_chunk_end = min(ping_size, len(cpu_block_table))
+                first_chunk_ids = cpu_block_table[:first_chunk_end]
+                offload_engine.load_to_ping(first_chunk_ids)
+
+                for chunk_idx in range(num_chunks):
+                    start = chunk_idx * ping_size
+                    end = min(start + ping_size, len(cpu_block_table))
+                    num_blocks_in_chunk = end - start
+
+                    # Prefetch next chunk to OTHER buffer
+                    if chunk_idx + 1 < num_chunks:
+                        next_start = end
+                        next_end = min(next_start + ping_size, len(cpu_block_table))
+                        next_chunk_ids = cpu_block_table[next_start:next_end]
+                        if current_buffer == "ping":
+                            offload_engine.load_to_pong(next_chunk_ids)
+                        else:
+                            offload_engine.load_to_ping(next_chunk_ids)
+
+                    # Wait for current buffer and get KV
+                    if current_buffer == "ping":
+                        offload_engine.wait_ping()
+                        prev_k, prev_v = offload_engine.get_kv_for_ping_slots(
+                            self.layer_id, num_blocks_in_chunk
+                        )
+                    else:
+                        offload_engine.wait_pong()
+                        prev_k, prev_v = offload_engine.get_kv_for_pong_slots(
+                            self.layer_id, num_blocks_in_chunk
+                        )
+
+                    # Compute attention against this chunk (no causal mask)
                     prev_o, prev_lse = flash_attn_with_lse(
                         q_batched,
                         prev_k,
                         prev_v,
                         softmax_scale=self.scale,
-                        causal=False,  # No causal mask for previous context
+                        causal=False,
                     )
-                    accumulated_o = prev_o
-                    accumulated_lse = prev_lse
+
+                    # Merge with accumulated
+                    if o_acc is None:
+                        o_acc, lse_acc = prev_o, prev_lse
+                    else:
+                        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
+
+                    # Switch buffer
+                    current_buffer = "pong" if current_buffer == "ping" else "ping"
 
         # Compute attention against current chunk's KV (with causal mask)
         current_o, current_lse = flash_attn_with_lse(
@@ -149,17 +187,14 @@ class Attention(nn.Module):
             k_batched,
             v_batched,
             softmax_scale=self.scale,
-            causal=True,  # Causal mask for current chunk
+            causal=True,
         )
 
         # Merge with accumulated
-        if accumulated_o is None:
+        if o_acc is None:
             final_o = current_o
         else:
-            final_o, _ = merge_attention_outputs(
-                accumulated_o, accumulated_lse,
-                current_o, current_lse,
-            )
+            final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
 
         # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
         return final_o.squeeze(0)
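Note: the control flow added above is a standard double-buffered (Ping-Pong) prefetch loop. Below is a minimal standalone sketch of the pattern, where load_async, wait, attend, and merge are hypothetical stand-ins for load_to_ping/load_to_pong, wait_ping/wait_pong, flash_attn_with_lse, and merge_attention_outputs; it is an illustration of the technique, not the project's API.

def ping_pong_attention(cpu_block_table, ping_size, load_async, wait, attend, merge):
    # Split the CPU block table into fixed-size chunks.
    chunks = [cpu_block_table[i:i + ping_size]
              for i in range(0, len(cpu_block_table), ping_size)]
    if not chunks:
        return None

    buffers = ("ping", "pong")
    load_async(buffers[0], chunks[0])  # prefetch the first chunk into the Ping buffer

    acc = None
    for i, chunk in enumerate(chunks):
        cur = buffers[i % 2]
        # Overlap: start copying the next chunk into the other buffer.
        if i + 1 < len(chunks):
            load_async(buffers[(i + 1) % 2], chunks[i + 1])
        wait(cur)                           # block until this buffer's H2D copy is done
        o, lse = attend(cur, len(chunk))    # attention over this chunk (non-causal)
        acc = (o, lse) if acc is None else merge(*acc, o, lse)
    return acc

While the GPU computes attention on the current buffer, the copy engine fills the other one, so the host-to-device transfer is hidden behind compute.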
@@ -172,12 +207,13 @@ class Attention(nn.Module):
         context,
     ) -> torch.Tensor:
         """
-        Compute decode attention with KV spread across CPU and GPU.
+        Compute decode attention with Ping-Pong dual buffer.
 
-        Uses chunked attention similar to chunked prefill:
-        1. Process blocks on GPU first (if any)
-        2. Load CPU blocks in chunks to GPU slots (per-layer)
-        3. Compute attention for each chunk, merge with online softmax
+        All KV is stored on CPU. Uses Ping-Pong buffers on GPU:
+        1. Load first chunk to Ping buffer
+        2. While computing on current buffer, prefetch next chunk to other buffer
+        3. Alternate between Ping and Pong buffers
+        4. Merge attention outputs using online softmax (LSE)
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
 
@@ -185,62 +221,73 @@ class Attention(nn.Module):
         # Need: [batch, seqlen, heads, dim]
         q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]
 
         # Note: context.offload_engine is actually HybridKVCacheManager
         kvcache_manager = context.offload_engine
         seq = context.chunked_seq
 
+        # Get all CPU blocks for this sequence
+        cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
+        if not cpu_block_table:
+            raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
+
+        # Get the actual offload_engine for Ping-Pong operations
+        offload_engine = kvcache_manager.offload_engine
+
+        # Calculate chunk info
+        ping_size = offload_engine.ping_size
+        num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
+
         o_acc = None
         lse_acc = None
+        current_buffer = "ping"
 
-        # Step 1: Process blocks already on GPU (if any)
-        gpu_slots, _ = kvcache_manager.get_gpu_blocks_for_decode(seq)
-        if gpu_slots:
-            k_gpu, v_gpu = kvcache_manager.get_kv_for_gpu_slots(self.layer_id, gpu_slots)
-            o_gpu, lse_gpu = flash_attn_with_lse(
-                q_batched, k_gpu, v_gpu,
-                softmax_scale=self.scale,
-                causal=False,
-            )
-            o_acc, lse_acc = o_gpu, lse_gpu
-
-        # Step 2: Process CPU blocks in chunks
-        # Get chunk info from kvcache_manager
-        cpu_block_ids, cpu_logical_ids, num_chunks = kvcache_manager.get_decode_chunk_info(seq)
-
-        if num_chunks > 0:
-            # Use num_gpu_slots - 1 to avoid the reserved slot (used for write block)
-            chunk_size = kvcache_manager.num_gpu_slots - 1
-
-            for chunk_idx in range(num_chunks):
-                start = chunk_idx * chunk_size
-                end = min(start + chunk_size, len(cpu_block_ids))
-                chunk_cpu_ids = cpu_block_ids[start:end]
-
-                # Load this chunk to GPU slots 0, 1, 2, ... for THIS LAYER
-                # (slot num_gpu_slots-1 is reserved for write block)
-                gpu_slots_for_chunk = list(range(len(chunk_cpu_ids)))
-                kvcache_manager.offload_engine.load_cpu_blocks_to_gpu_slots(
-                    self.layer_id,
-                    chunk_cpu_ids,
-                    gpu_slots_for_chunk,
-                )
-
-                # Get KV for this chunk
-                k_chunk, v_chunk = kvcache_manager.get_kv_for_gpu_slots(
-                    self.layer_id, gpu_slots_for_chunk
-                )
-
-                # Compute attention for this chunk
-                o_chunk, lse_chunk = flash_attn_with_lse(
-                    q_batched, k_chunk, v_chunk,
-                    softmax_scale=self.scale,
-                    causal=False,
-                )
-
-                # Merge with accumulated
-                if o_acc is None:
-                    o_acc, lse_acc = o_chunk, lse_chunk
-                else:
-                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
+        # Prefetch first chunk to Ping buffer (loads all layers at once)
+        first_chunk_end = min(ping_size, len(cpu_block_table))
+        first_chunk_ids = cpu_block_table[:first_chunk_end]
+        offload_engine.load_to_ping(first_chunk_ids)
+
+        for chunk_idx in range(num_chunks):
+            start = chunk_idx * ping_size
+            end = min(start + ping_size, len(cpu_block_table))
+            num_blocks_in_chunk = end - start
+
+            # Prefetch next chunk to OTHER buffer (overlapped with current computation)
+            if chunk_idx + 1 < num_chunks:
+                next_start = end
+                next_end = min(next_start + ping_size, len(cpu_block_table))
+                next_chunk_ids = cpu_block_table[next_start:next_end]
+                if current_buffer == "ping":
+                    offload_engine.load_to_pong(next_chunk_ids)
+                else:
+                    offload_engine.load_to_ping(next_chunk_ids)
+
+            # Wait for current buffer to be ready and get KV
+            if current_buffer == "ping":
+                offload_engine.wait_ping()
+                k_chunk, v_chunk = offload_engine.get_kv_for_ping_slots(
+                    self.layer_id, num_blocks_in_chunk
+                )
+            else:
+                offload_engine.wait_pong()
+                k_chunk, v_chunk = offload_engine.get_kv_for_pong_slots(
+                    self.layer_id, num_blocks_in_chunk
+                )
+
+            # Compute attention for this chunk
+            o_chunk, lse_chunk = flash_attn_with_lse(
+                q_batched, k_chunk, v_chunk,
+                softmax_scale=self.scale,
+                causal=False,
+            )
+
+            # Merge with accumulated
+            if o_acc is None:
+                o_acc, lse_acc = o_chunk, lse_chunk
+            else:
+                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
+
+            # Switch buffer for next iteration
+            current_buffer = "pong" if current_buffer == "ping" else "ping"
 
         if o_acc is None:
             raise RuntimeError("Chunked decode attention failed: no KV available")
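Note: both paths rely on merge_attention_outputs to combine partial attention results via their log-sum-exp (LSE) statistics. Below is a minimal sketch of the standard online-softmax merge it is assumed to perform; the [batch, heads, seq] LSE layout mirrors flash-attn's convention and is an assumption here, not taken from this repository.

import torch

def merge_attention_outputs_sketch(o1, lse1, o2, lse2):
    # o*:   [batch, seq, heads, dim] partial attention outputs (assumed layout)
    # lse*: [batch, heads, seq] log-sum-exp of the attention logits (assumed layout)
    lse = torch.logaddexp(lse1, lse2)                          # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)   # [batch, seq, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse

Because this merge is associative, chunks can be folded in any order, which is what lets each Ping/Pong chunk be merged into (o_acc, lse_acc) as soon as its attention output is ready.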