[feat] Finished offload. Still need to optimize performance.

This commit is contained in:
Zijie Tian
2025-12-12 02:27:40 +08:00
parent 9b8165af5a
commit 61edb8a344
3 changed files with 72 additions and 48 deletions

View File

@@ -4,10 +4,10 @@ from random import randint, seed
from nanovllm import LLM, SamplingParams
def bench_decode(llm, num_seqs, max_input_len, max_output_len):
def bench_decode(llm, num_seqs, input_len, max_output_len):
"""Benchmark decode performance (original test)"""
seed(0)
prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, input_len))] for _ in range(num_seqs)]
sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)]
t = time.time()
@@ -54,13 +54,13 @@ def main():
# bench_prefill(llm, num_seqs=1, input_len=1024)
# bench_prefill(llm, num_seqs=1, input_len=2048)
# bench_prefill(llm, num_seqs=1, input_len=4096)
bench_prefill(llm, num_seqs=1, input_len=8192)
bench_prefill(llm, num_seqs=1, input_len=64 * 1024)
print("=" * 60)
print("Decode Benchmark (CPU Offload)")
print("=" * 60)
bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=128)
# bench_decode(llm, num_seqs=1, max_input_len=2048, max_output_len=128)
bench_decode(llm, num_seqs=1, input_len=64 * 1024, max_output_len=128)
# bench_decode(llm, num_seqs=1, input_len=2048, max_output_len=128)
if __name__ == "__main__":

View File

@@ -133,23 +133,22 @@ class Attention(nn.Module):
if cpu_block_table:
offload_engine = kvcache_manager.offload_engine
# Use Prefetch region to load previous KV (won't conflict with current Compute region)
prefetch_size = offload_engine.num_prefetch_blocks
num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
# For prefill: ONLY use Prefetch region to avoid conflict with
# current chunk's KV being written to Compute region slots
# Use synchronous per-layer loading (async would conflict with writes)
chunk_size = offload_engine.num_prefetch_blocks
num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
for chunk_idx in range(num_chunks):
start = chunk_idx * prefetch_size
end = min(start + prefetch_size, len(cpu_block_table))
start = chunk_idx * chunk_size
end = min(start + chunk_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
chunk_ids = cpu_block_table[start:end]
# Load this chunk to Prefetch region (per-layer loading)
# Each layer loads only its own KV, avoiding the bug where layer 0
# loads all layers and overwrites data before other layers can read it
# Load to Prefetch region (per-layer, sync)
offload_engine.load_to_prefetch_layer(self.layer_id, chunk_ids)
# Wait for this layer's Prefetch region and get KV
offload_engine.wait_prefetch_layer(self.layer_id)
prev_k, prev_v = offload_engine.get_kv_for_prefetch(
self.layer_id, num_blocks_in_chunk
)
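The prev_k/prev_v chunks loaded here (and the per-chunk outputs in the decode path below) are combined with the flash_attn_with_lse and merge_attention_outputs helpers imported further down; their bodies are not part of this diff. For reference, a minimal sketch of the standard log-sum-exp merge such a helper performs is shown below. The tensor shapes follow flash-attn's usual convention (output [batch, seqlen, heads, head_dim], LSE [batch, heads, seqlen]) and are an assumption, not code from the repository.

import torch

def merge_attention_outputs_sketch(o1, lse1, o2, lse2):
    # o1/o2: [batch, seqlen_q, heads, head_dim] partial attention outputs (assumed shapes)
    # lse1/lse2: [batch, heads, seqlen_q] log-sum-exp of each partial softmax denominator
    # Reweight each partial output by its share of the combined softmax denominator.
    lse_max = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - lse_max)
    w2 = torch.exp(lse2 - lse_max)
    denom = w1 + w2
    # Broadcast weights over head_dim: [batch, heads, seqlen] -> [batch, seqlen, heads, 1]
    o = (w1 / denom).transpose(1, 2).unsqueeze(-1) * o1 \
        + (w2 / denom).transpose(1, 2).unsqueeze(-1) * o2
    lse = lse_max + torch.log(denom)
    return o, lse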
@@ -195,24 +194,27 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute decode attention with three-region GPU buffer.
Compute decode attention with async double-buffering using Compute and Prefetch regions.
All KV is stored on CPU. Uses Compute region buffer on GPU:
1. Load chunk to Compute region
2. Compute attention
3. Repeat for all chunks
4. Finally, attend to Decode region (slot 0) which contains the new token's KV
5. Merge all attention outputs using online softmax (LSE)
Pipeline design:
- Compute region: holds current chunk being computed
- Prefetch region: async loads next chunk while current is computing
- After computation, swap roles of the two regions
Key: new token's KV is in Decode region (slot 0), won't be overwritten by Compute region loading.
Timeline:
┌─────────────┐     ┌─────────────┐     ┌─────────────┐
│Load C0→Comp │     │Load C1→Pref │     │Load C2→Comp │   ...
└─────────────┘     └─────────────┘     └─────────────┘
         ↘                   ↘                   ↘
      ┌─────────────┐     ┌─────────────┐     ┌─────────────┐
      │ Compute C0  │     │ Compute C1  │     │ Compute C2  │
      └─────────────┘     └─────────────┘     └─────────────┘
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
# Need: [batch, seqlen, heads, dim]
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
# Note: context.offload_engine is actually HybridKVCacheManager
kvcache_manager = context.offload_engine
seq = context.chunked_seq
@@ -223,32 +225,56 @@ class Attention(nn.Module):
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
# Get the actual offload_engine for three-region operations
offload_engine = kvcache_manager.offload_engine
# Calculate chunk info using Compute region
compute_size = offload_engine.num_compute_blocks
num_chunks = (len(cpu_block_table) + compute_size - 1) // compute_size
# Use prefetch_size as chunk size for double buffering
# This ensures both Compute and Prefetch regions can hold a full chunk
chunk_size = offload_engine.num_prefetch_blocks
num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
o_acc = None
lse_acc = None
# Double buffering state: True = use Compute region, False = use Prefetch region
use_compute = True
# Pre-load first chunk to Compute region (async)
first_chunk_ids = cpu_block_table[:min(chunk_size, len(cpu_block_table))]
offload_engine.load_to_compute_layer(self.layer_id, first_chunk_ids)
for chunk_idx in range(num_chunks):
start = chunk_idx * compute_size
end = min(start + compute_size, len(cpu_block_table))
start = chunk_idx * chunk_size
end = min(start + chunk_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
chunk_ids = cpu_block_table[start:end]
# Load this chunk to Compute region (per-layer loading)
# Each layer loads only its own KV, avoiding the bug where layer 0
# loads all layers and overwrites data before other layers can read it
offload_engine.load_to_compute_layer(self.layer_id, chunk_ids)
# Wait for this layer's Compute region to be ready and get KV
# Wait for current buffer to be ready
if use_compute:
offload_engine.wait_compute_layer(self.layer_id)
else:
offload_engine.wait_prefetch_layer(self.layer_id)
# Trigger async prefetch of next chunk to the OTHER buffer
# This overlaps transfer with current chunk's computation
if chunk_idx + 1 < num_chunks:
next_start = end
next_end = min(next_start + chunk_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if use_compute:
# Current in Compute, prefetch next to Prefetch region
offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
else:
# Current in Prefetch, prefetch next to Compute region
offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
# Get KV from current buffer
if use_compute:
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
self.layer_id, num_blocks_in_chunk
)
else:
k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
self.layer_id, num_blocks_in_chunk
)
# Compute attention for this chunk
o_chunk, lse_chunk = flash_attn_with_lse(
@@ -263,18 +289,18 @@ class Attention(nn.Module):
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
# Swap buffers for next iteration
use_compute = not use_compute
# Now attend to Decode region (contains accumulated decode tokens)
# When batching offloads, decode slot accumulates multiple tokens
# from decode_start_pos_in_block to decode_pos_in_block (inclusive)
pos_in_block = context.decode_pos_in_block
start_pos = context.decode_start_pos_in_block
num_accumulated = pos_in_block - start_pos + 1
if num_accumulated > 0:
# Get accumulated KV in decode slot [start_pos : pos_in_block+1]
decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_k = decode_k.unsqueeze(0) # [1, num_tokens, heads, dim]
decode_k = decode_k.unsqueeze(0)
decode_v = decode_v.unsqueeze(0)
decode_o, decode_lse = flash_attn_with_lse(
@@ -283,7 +309,6 @@ class Attention(nn.Module):
causal=False,
)
# Merge with accumulated
if o_acc is None:
o_acc = decode_o
else:
@@ -292,5 +317,4 @@ class Attention(nn.Module):
if o_acc is None:
raise RuntimeError("Chunked decode attention failed: no KV available")
# Output shape: [batch, 1, heads, dim] (same as normal decode)
return o_acc
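Because removed and added lines are interleaved above, the new decode path is easier to follow as one consolidated listing. The sketch below is a distillation, not the committed code: the offload-engine methods and attention helpers are the ones called in the diff, but their exact signatures (and the causal=False flag on the per-chunk call) are assumptions, and the final Decode-region (slot 0) attention is only noted in a comment.

def chunked_decode_attention_sketch(offload_engine, layer_id, cpu_block_table, q_batched,
                                    flash_attn_with_lse, merge_attention_outputs):
    """Double-buffered chunked decode attention, distilled from this diff."""
    chunk_size = offload_engine.num_prefetch_blocks
    num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
    o_acc = lse_acc = None
    use_compute = True  # which region holds the chunk consumed next

    # Warm up the pipeline: start the first chunk's CPU->GPU copy into the Compute region.
    offload_engine.load_to_compute_layer(layer_id, cpu_block_table[:chunk_size])

    for chunk_idx in range(num_chunks):
        start = chunk_idx * chunk_size
        end = min(start + chunk_size, len(cpu_block_table))

        # Block until the region holding the current chunk is ready.
        if use_compute:
            offload_engine.wait_compute_layer(layer_id)
        else:
            offload_engine.wait_prefetch_layer(layer_id)

        # Overlap transfer with compute: async-load the next chunk into the other region.
        if chunk_idx + 1 < num_chunks:
            next_ids = cpu_block_table[end:end + chunk_size]
            if use_compute:
                offload_engine.load_to_prefetch_layer(layer_id, next_ids)
            else:
                offload_engine.load_to_compute_layer(layer_id, next_ids)

        # Attend over the current chunk and fold it into the running result via LSE merge.
        if use_compute:
            k_chunk, v_chunk = offload_engine.get_kv_for_compute(layer_id, end - start)
        else:
            k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(layer_id, end - start)
        o_chunk, lse_chunk = flash_attn_with_lse(q_batched, k_chunk, v_chunk, causal=False)
        if o_acc is None:
            o_acc, lse_acc = o_chunk, lse_chunk
        else:
            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)

        use_compute = not use_compute  # ping-pong: swap region roles for the next chunk

    # The accumulated new-token KV in the Decode region (slot 0) is attended to and
    # merged after this loop, exactly as in the diff above.
    return o_acc, lse_acc

The warm-up load plus the "load next chunk into the other region" step keeps one transfer in flight per layer, so CPU-to-GPU copies overlap with attention compute instead of serializing with it.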

View File

@@ -61,7 +61,7 @@ Attention mechanisms allow models to focus on relevant parts of the input.
fact_idx += 1
# Add the question at the end
prompt_parts.append("\n\nQuestion: Based on the information above, what is the capital of France and when was the Eiffel Tower built? Please answer briefly.\n\nAnswer:")
prompt_parts.append("\n\nQuestion: Based on the information above, what is the speed of light?\n\nAnswer:")
return "".join(prompt_parts)