[refactor] Refactor the current GPU and CPU block allocation strategy.

Zijie Tian
2025-12-10 21:23:31 +08:00
parent 0a247ccb1b
commit 190df5f70d
7 changed files with 906 additions and 162 deletions


@@ -97,51 +97,89 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute attention with Ping-Pong dual buffering for chunked prefill.
For chunked prefill:
1. Load previous KV from CPU using Ping-Pong (if any previous chunks)
2. Compute attention against previous KV chunks (no causal mask)
3. Compute attention against current chunk's KV (causal)
4. Merge all results using online softmax
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
# q, k, v shape: [total_tokens, num_heads, head_dim]
total_tokens = q.shape[0]
# Reshape for flash attention: [batch, seq, heads, dim]
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
k_batched = k.unsqueeze(0)
v_batched = v.unsqueeze(0)
o_acc = None
lse_acc = None
# Load previous KV from CPU using Ping-Pong
# Note: context.offload_engine is actually HybridKVCacheManager
kvcache_manager = context.offload_engine
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
# Get prefilled CPU blocks (blocks already written in previous chunks)
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
if cpu_block_table:
offload_engine = kvcache_manager.offload_engine
ping_size = offload_engine.ping_size
num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
current_buffer = "ping"
# Prefetch first chunk to Ping buffer
first_chunk_end = min(ping_size, len(cpu_block_table))
first_chunk_ids = cpu_block_table[:first_chunk_end]
offload_engine.load_to_ping(first_chunk_ids)
for chunk_idx in range(num_chunks):
start = chunk_idx * ping_size
end = min(start + ping_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
# Prefetch next chunk to OTHER buffer
if chunk_idx + 1 < num_chunks:
next_start = end
next_end = min(next_start + ping_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if current_buffer == "ping":
offload_engine.load_to_pong(next_chunk_ids)
else:
offload_engine.load_to_ping(next_chunk_ids)
# Wait for current buffer and get KV
if current_buffer == "ping":
offload_engine.wait_ping()
prev_k, prev_v = offload_engine.get_kv_for_ping_slots(
self.layer_id, num_blocks_in_chunk
)
else:
offload_engine.wait_pong()
prev_k, prev_v = offload_engine.get_kv_for_pong_slots(
self.layer_id, num_blocks_in_chunk
)
# Compute attention against this chunk (no causal mask)
prev_o, prev_lse = flash_attn_with_lse(
q_batched,
prev_k,
prev_v,
softmax_scale=self.scale,
causal=False,
)
# Merge with accumulated
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Switch buffer
current_buffer = "pong" if current_buffer == "ping" else "ping"
# Compute attention against current chunk's KV (with causal mask)
current_o, current_lse = flash_attn_with_lse(
@@ -149,17 +187,14 @@ class Attention(nn.Module):
k_batched,
v_batched,
softmax_scale=self.scale,
causal=True,
)
# Merge with accumulated
if o_acc is None:
final_o = current_o
else:
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
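
For reference, the merge_attention_outputs step above is a standard log-sum-exp (online softmax) merge of two partial attention results computed over disjoint KV sets. A minimal PyTorch sketch of that merge, assuming flash-attn's usual layouts (output [batch, seq, heads, dim], LSE [batch, heads, seq]); the real helper in nanovllm.kvcache.chunked_attention may differ in details:

import torch

def merge_attention_outputs_ref(o1, lse1, o2, lse2):
    # o*: [batch, seq, heads, dim]; lse*: [batch, heads, seq] (assumed flash-attn convention)
    # Each (o_i, lse_i) is softmax attention over one KV subset; re-weighting by
    # exp(lse_i) and renormalizing gives exact attention over the union of subsets.
    l1 = lse1.transpose(1, 2).unsqueeze(-1)  # -> [batch, seq, heads, 1]
    l2 = lse2.transpose(1, 2).unsqueeze(-1)
    m = torch.maximum(l1, l2)                # subtract the max for numerical stability
    w1 = torch.exp(l1 - m)
    w2 = torch.exp(l2 - m)
    o = (o1 * w1 + o2 * w2) / (w1 + w2)
    lse = (m + torch.log(w1 + w2)).squeeze(-1).transpose(1, 2)  # back to [batch, heads, seq]
    return o, lse
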
@@ -172,12 +207,13 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute decode attention with Ping-Pong dual buffering.
All KV is stored on CPU. Uses Ping-Pong buffers on GPU:
1. Load first chunk to Ping buffer
2. While computing on the current buffer, prefetch the next chunk into the other buffer
3. Alternate between Ping and Pong buffers
4. Merge attention outputs using online softmax (LSE)
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
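
The flash_attn_with_lse helper imported here must return both the attention output and the per-query log-sum-exp so that per-chunk results can be merged afterwards. A naive pure-PyTorch reference for that contract (illustration only; the actual helper presumably wraps flash-attn kernels):

import math
import torch

def flash_attn_with_lse_ref(q, k, v, softmax_scale=None, causal=False):
    # q: [batch, q_len, heads, dim]; k, v: [batch, kv_len, heads, dim]
    # Returns out [batch, q_len, heads, dim] and lse [batch, heads, q_len].
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * scale
    if causal:
        q_len, kv_len = q.shape[1], k.shape[1]
        # Query i sits at absolute position kv_len - q_len + i; mask out keys beyond it.
        mask = torch.ones(q_len, kv_len, dtype=torch.bool, device=q.device).triu(kv_len - q_len + 1)
        scores = scores.masked_fill(mask, float("-inf"))
    lse = torch.logsumexp(scores, dim=-1)                       # [batch, heads, q_len]
    out = torch.einsum("bhqk,bkhd->bqhd", torch.softmax(scores, dim=-1), v.float())
    return out.to(q.dtype), lse

Its outputs plug directly into the LSE merge sketched after the prefill path above.
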
@@ -185,62 +221,73 @@ class Attention(nn.Module):
# Need: [batch, seqlen, heads, dim]
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
# Note: context.offload_engine is actually HybridKVCacheManager
kvcache_manager = context.offload_engine
seq = context.chunked_seq
# Get all CPU blocks for this sequence
cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
# Get the actual offload_engine for Ping-Pong operations
offload_engine = kvcache_manager.offload_engine
# Calculate chunk info
ping_size = offload_engine.ping_size
num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
o_acc = None
lse_acc = None
current_buffer = "ping"
# Prefetch first chunk to Ping buffer (loads all layers at once)
first_chunk_end = min(ping_size, len(cpu_block_table))
first_chunk_ids = cpu_block_table[:first_chunk_end]
offload_engine.load_to_ping(first_chunk_ids)
for chunk_idx in range(num_chunks):
start = chunk_idx * ping_size
end = min(start + ping_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
# Prefetch next chunk to OTHER buffer (overlapped with current computation)
if chunk_idx + 1 < num_chunks:
next_start = end
next_end = min(next_start + ping_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if current_buffer == "ping":
offload_engine.load_to_pong(next_chunk_ids)
else:
offload_engine.load_to_ping(next_chunk_ids)
# Wait for current buffer to be ready and get KV
if current_buffer == "ping":
offload_engine.wait_ping()
k_chunk, v_chunk = offload_engine.get_kv_for_ping_slots(
self.layer_id, num_blocks_in_chunk
)
else:
offload_engine.wait_pong()
k_chunk, v_chunk = offload_engine.get_kv_for_pong_slots(
self.layer_id, num_blocks_in_chunk
)
# Compute attention for this chunk
o_chunk, lse_chunk = flash_attn_with_lse(
q_batched, k_chunk, v_chunk,
softmax_scale=self.scale,
causal=False,
)
# Merge with accumulated
if o_acc is None:
o_acc, lse_acc = o_chunk, lse_chunk
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
# Switch buffer for next iteration
current_buffer = "pong" if current_buffer == "ping" else "ping"
if o_acc is None:
raise RuntimeError("Chunked decode attention failed: no KV available")