[feat] Finished offload. Still need to optimize performance.

Zijie Tian
2025-12-12 02:27:40 +08:00
parent 9b8165af5a
commit 61edb8a344
3 changed files with 72 additions and 48 deletions


@@ -4,10 +4,10 @@ from random import randint, seed
 from nanovllm import LLM, SamplingParams
 
 
-def bench_decode(llm, num_seqs, max_input_len, max_output_len):
+def bench_decode(llm, num_seqs, input_len, max_output_len):
     """Benchmark decode performance (original test)"""
     seed(0)
-    prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
+    prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, input_len))] for _ in range(num_seqs)]
     sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)]
     t = time.time()
@@ -54,13 +54,13 @@ def main():
     # bench_prefill(llm, num_seqs=1, input_len=1024)
     # bench_prefill(llm, num_seqs=1, input_len=2048)
     # bench_prefill(llm, num_seqs=1, input_len=4096)
-    bench_prefill(llm, num_seqs=1, input_len=8192)
+    bench_prefill(llm, num_seqs=1, input_len=64 * 1024)
 
     print("=" * 60)
     print("Decode Benchmark (CPU Offload)")
     print("=" * 60)
-    bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=128)
-    # bench_decode(llm, num_seqs=1, max_input_len=2048, max_output_len=128)
+    bench_decode(llm, num_seqs=1, input_len=64 * 1024, max_output_len=128)
+    # bench_decode(llm, num_seqs=1, input_len=2048, max_output_len=128)
 
 
 if __name__ == "__main__":


@@ -133,23 +133,22 @@ class Attention(nn.Module):
         if cpu_block_table:
             offload_engine = kvcache_manager.offload_engine
-            # Use Prefetch region to load previous KV (won't conflict with current Compute region)
-            prefetch_size = offload_engine.num_prefetch_blocks
-            num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
+            # For prefill: ONLY use Prefetch region to avoid conflict with
+            # current chunk's KV being written to Compute region slots
+            # Use synchronous per-layer loading (async would conflict with writes)
+            chunk_size = offload_engine.num_prefetch_blocks
+            num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
 
             for chunk_idx in range(num_chunks):
-                start = chunk_idx * prefetch_size
-                end = min(start + prefetch_size, len(cpu_block_table))
+                start = chunk_idx * chunk_size
+                end = min(start + chunk_size, len(cpu_block_table))
                 num_blocks_in_chunk = end - start
                 chunk_ids = cpu_block_table[start:end]
 
-                # Load this chunk to Prefetch region (per-layer loading)
-                # Each layer loads only its own KV, avoiding the bug where layer 0
-                # loads all layers and overwrites data before other layers can read it
+                # Load to Prefetch region (per-layer, sync)
                 offload_engine.load_to_prefetch_layer(self.layer_id, chunk_ids)
-                # Wait for this layer's Prefetch region and get KV
                 offload_engine.wait_prefetch_layer(self.layer_id)
                 prev_k, prev_v = offload_engine.get_kv_for_prefetch(
                     self.layer_id, num_blocks_in_chunk
                 )
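Aside: this loop (and the decode path further down) splits the sequence's CPU block table into Prefetch-region-sized chunks via ceiling division. A standalone illustration of just that chunking, with made-up sizes:

# Illustration only: 10 offloaded blocks, hypothetical chunk size of 4 GPU slots.
cpu_block_table = list(range(10))
chunk_size = 4
num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size  # ceil(10 / 4) == 3
for chunk_idx in range(num_chunks):
    start = chunk_idx * chunk_size
    end = min(start + chunk_size, len(cpu_block_table))
    print(cpu_block_table[start:end])  # [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]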
@@ -195,24 +194,27 @@ class Attention(nn.Module):
         context,
     ) -> torch.Tensor:
         """
-        Compute decode attention with three-region GPU buffer.
-
-        All KV is stored on CPU. Uses Compute region buffer on GPU:
-        1. Load chunk to Compute region
-        2. Compute attention
-        3. Repeat for all chunks
-        4. Finally, attend to Decode region (slot 0) which contains the new token's KV
-        5. Merge all attention outputs using online softmax (LSE)
-
-        Key: new token's KV is in Decode region (slot 0), won't be overwritten by Compute region loading.
+        Compute decode attention with async double-buffering using Compute and Prefetch regions.
+
+        Pipeline design:
+        - Compute region: holds current chunk being computed
+        - Prefetch region: async loads next chunk while current is computing
+        - After computation, swap roles of the two regions
+
+        Timeline:
+        ┌─────────────┐     ┌─────────────┐     ┌─────────────┐
+        │Load C0→Comp │     │Load C1→Pref │     │Load C2→Comp │ ...
+        └─────────────┘     └─────────────┘     └─────────────┘
+               ↘                   ↘                   ↘
+            ┌─────────────┐     ┌─────────────┐     ┌─────────────┐
+            │ Compute C0  │     │ Compute C1  │     │ Compute C2  │
+            └─────────────┘     └─────────────┘     └─────────────┘
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
 
         # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
-        # Need: [batch, seqlen, heads, dim]
         q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]
 
-        # Note: context.offload_engine is actually HybridKVCacheManager
         kvcache_manager = context.offload_engine
         seq = context.chunked_seq
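Aside: the timeline in the new docstring is the standard copy/compute overlap (ping-pong) pattern. The sketch below shows that pattern with plain CUDA streams and events; it is not the nanovllm offload engine's API, and the function name, buffer count, and workload (a sum) are illustrative only.

import torch

def double_buffered_sum(cpu_chunks):
    """Sum pinned CPU chunks on the GPU, overlapping H2D copies with compute."""
    copy_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.current_stream()
    # Two GPU buffers play the roles of the Compute / Prefetch regions.
    bufs = [torch.empty_like(cpu_chunks[0], device="cuda") for _ in range(2)]
    ready = [torch.cuda.Event() for _ in range(2)]  # copy into bufs[i] finished
    freed = [torch.cuda.Event() for _ in range(2)]  # compute reading bufs[i] finished

    # Pre-load chunk 0 into buffer 0 (async), like the pre-load before the loop above.
    with torch.cuda.stream(copy_stream):
        bufs[0].copy_(cpu_chunks[0], non_blocking=True)
        ready[0].record(copy_stream)

    total = torch.zeros((), device="cuda")
    for i in range(len(cpu_chunks)):
        cur, nxt = i % 2, (i + 1) % 2
        compute_stream.wait_event(ready[cur])        # wait until bufs[cur] is loaded

        if i + 1 < len(cpu_chunks):                  # prefetch next chunk to the other buffer
            copy_stream.wait_event(freed[nxt])       # don't overwrite a buffer still being read
            with torch.cuda.stream(copy_stream):
                bufs[nxt].copy_(cpu_chunks[i + 1], non_blocking=True)
                ready[nxt].record(copy_stream)

        total += bufs[cur].sum()                     # compute overlaps the in-flight copy
        freed[cur].record(compute_stream)            # mark bufs[cur] reusable
    return total.item()

chunks = [torch.randn(1024, 1024).pin_memory() for _ in range(4)]
print(double_buffered_sum(chunks))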
@@ -223,32 +225,56 @@ class Attention(nn.Module):
         if not cpu_block_table:
             raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
 
-        # Get the actual offload_engine for three-region operations
         offload_engine = kvcache_manager.offload_engine
 
-        # Calculate chunk info using Compute region
-        compute_size = offload_engine.num_compute_blocks
-        num_chunks = (len(cpu_block_table) + compute_size - 1) // compute_size
+        # Use prefetch_size as chunk size for double buffering
+        # This ensures both Compute and Prefetch regions can hold a full chunk
+        chunk_size = offload_engine.num_prefetch_blocks
+        num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
 
         o_acc = None
         lse_acc = None
 
+        # Double buffering state: True = use Compute region, False = use Prefetch region
+        use_compute = True
+
+        # Pre-load first chunk to Compute region (async)
+        first_chunk_ids = cpu_block_table[:min(chunk_size, len(cpu_block_table))]
+        offload_engine.load_to_compute_layer(self.layer_id, first_chunk_ids)
+
         for chunk_idx in range(num_chunks):
-            start = chunk_idx * compute_size
-            end = min(start + compute_size, len(cpu_block_table))
+            start = chunk_idx * chunk_size
+            end = min(start + chunk_size, len(cpu_block_table))
             num_blocks_in_chunk = end - start
-            chunk_ids = cpu_block_table[start:end]
-            # Load this chunk to Compute region (per-layer loading)
-            # Each layer loads only its own KV, avoiding the bug where layer 0
-            # loads all layers and overwrites data before other layers can read it
-            offload_engine.load_to_compute_layer(self.layer_id, chunk_ids)
-            # Wait for this layer's Compute region to be ready and get KV
-            offload_engine.wait_compute_layer(self.layer_id)
-            k_chunk, v_chunk = offload_engine.get_kv_for_compute(
-                self.layer_id, num_blocks_in_chunk
-            )
+
+            # Wait for current buffer to be ready
+            if use_compute:
+                offload_engine.wait_compute_layer(self.layer_id)
+            else:
+                offload_engine.wait_prefetch_layer(self.layer_id)
+
+            # Trigger async prefetch of next chunk to the OTHER buffer
+            # This overlaps transfer with current chunk's computation
+            if chunk_idx + 1 < num_chunks:
+                next_start = end
+                next_end = min(next_start + chunk_size, len(cpu_block_table))
+                next_chunk_ids = cpu_block_table[next_start:next_end]
+                if use_compute:
+                    # Current in Compute, prefetch next to Prefetch region
+                    offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
+                else:
+                    # Current in Prefetch, prefetch next to Compute region
+                    offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
+
+            # Get KV from current buffer
+            if use_compute:
+                k_chunk, v_chunk = offload_engine.get_kv_for_compute(
+                    self.layer_id, num_blocks_in_chunk
+                )
+            else:
+                k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
+                    self.layer_id, num_blocks_in_chunk
+                )
 
             # Compute attention for this chunk
             o_chunk, lse_chunk = flash_attn_with_lse(
@@ -263,18 +289,18 @@ class Attention(nn.Module):
             else:
                 o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
 
+            # Swap buffers for next iteration
+            use_compute = not use_compute
+
         # Now attend to Decode region (contains accumulated decode tokens)
-        # When batching offloads, decode slot accumulates multiple tokens
-        # from decode_start_pos_in_block to decode_pos_in_block (inclusive)
         pos_in_block = context.decode_pos_in_block
         start_pos = context.decode_start_pos_in_block
         num_accumulated = pos_in_block - start_pos + 1
 
         if num_accumulated > 0:
-            # Get accumulated KV in decode slot [start_pos : pos_in_block+1]
             decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
             decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
-            decode_k = decode_k.unsqueeze(0)  # [1, num_tokens, heads, dim]
+            decode_k = decode_k.unsqueeze(0)
             decode_v = decode_v.unsqueeze(0)
 
             decode_o, decode_lse = flash_attn_with_lse(
@@ -283,7 +309,6 @@ class Attention(nn.Module):
                 causal=False,
             )
 
-            # Merge with accumulated
             if o_acc is None:
                 o_acc = decode_o
             else:
@@ -292,5 +317,4 @@ class Attention(nn.Module):
         if o_acc is None:
             raise RuntimeError("Chunked decode attention failed: no KV available")
 
-        # Output shape: [batch, 1, heads, dim] (same as normal decode)
         return o_acc
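Aside: both paths combine per-chunk partial results with merge_attention_outputs, i.e. an online-softmax (log-sum-exp) merge. The sketch below re-derives that merge in plain PyTorch and checks it against full attention; the shapes and helper names are assumptions for illustration, not the actual nanovllm.kvcache.chunked_attention implementation.

import torch

def attn_with_lse(q, k, v):
    # q: [batch, q_len, heads, dim]; k, v: [batch, kv_len, heads, dim]
    scores = torch.einsum("bqhd,bkhd->bqhk", q, k) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1)                    # [batch, q_len, heads]
    out = torch.einsum("bqhk,bkhd->bqhd", torch.softmax(scores, dim=-1), v)
    return out, lse

def merge_attention(o1, lse1, o2, lse2):
    # Merge partial outputs computed over two disjoint KV chunks.
    lse = torch.logaddexp(lse1, lse2)                        # combined normalizer
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)                 # weight of chunk 1
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)                 # weight of chunk 2
    return w1 * o1 + w2 * o2, lse

# Merging two 128-token chunks matches attention over the full 256 tokens.
q = torch.randn(1, 1, 4, 64)          # one decode query token
k = torch.randn(1, 256, 4, 64)
v = torch.randn(1, 256, 4, 64)
o_full, _ = attn_with_lse(q, k, v)
o1, l1 = attn_with_lse(q, k[:, :128], v[:, :128])
o2, l2 = attn_with_lse(q, k[:, 128:], v[:, 128:])
o_merged, _ = merge_attention(o1, l1, o2, l2)
assert torch.allclose(o_full, o_merged, atol=1e-5)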


@@ -61,7 +61,7 @@ Attention mechanisms allow models to focus on relevant parts of the input.
         fact_idx += 1
 
     # Add the question at the end
-    prompt_parts.append("\n\nQuestion: Based on the information above, what is the capital of France and when was the Eiffel Tower built? Please answer briefly.\n\nAnswer:")
+    prompt_parts.append("\n\nQuestion: Based on the information above, what is the speed of light?\n\nAnswer:")
 
     return "".join(prompt_parts)