From 61edb8a34456b6a59da3e2a538454d85c6043e34 Mon Sep 17 00:00:00 2001
From: Zijie Tian
Date: Fri, 12 Dec 2025 02:27:40 +0800
Subject: [PATCH] [feat] Finished offload. Still need to optimize performance.

---
 bench_offload.py                |  10 +--
 nanovllm/layers/attention.py    | 108 +++++++++++++++++++-------------
 tests/test_chunked_attention.py |   2 +-
 3 files changed, 72 insertions(+), 48 deletions(-)

diff --git a/bench_offload.py b/bench_offload.py
index 33d46cf..e4a1771 100644
--- a/bench_offload.py
+++ b/bench_offload.py
@@ -4,10 +4,10 @@ from random import randint, seed
 from nanovllm import LLM, SamplingParams
 
 
-def bench_decode(llm, num_seqs, max_input_len, max_output_len):
+def bench_decode(llm, num_seqs, input_len, max_output_len):
     """Benchmark decode performance (original test)"""
     seed(0)
-    prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
+    prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, input_len))] for _ in range(num_seqs)]
     sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)]
 
     t = time.time()
@@ -54,13 +54,13 @@ def main():
     # bench_prefill(llm, num_seqs=1, input_len=1024)
     # bench_prefill(llm, num_seqs=1, input_len=2048)
     # bench_prefill(llm, num_seqs=1, input_len=4096)
-    bench_prefill(llm, num_seqs=1, input_len=8192)
+    bench_prefill(llm, num_seqs=1, input_len=64 * 1024)
 
     print("=" * 60)
     print("Decode Benchmark (CPU Offload)")
     print("=" * 60)
-    bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=128)
-    # bench_decode(llm, num_seqs=1, max_input_len=2048, max_output_len=128)
+    bench_decode(llm, num_seqs=1, input_len=64 * 1024, max_output_len=128)
+    # bench_decode(llm, num_seqs=1, input_len=2048, max_output_len=128)
 
 
 if __name__ == "__main__":
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index 9ba1b3d..1df50b9 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -133,23 +133,22 @@ class Attention(nn.Module):
 
         if cpu_block_table:
             offload_engine = kvcache_manager.offload_engine
-            # Use Prefetch region to load previous KV (won't conflict with current Compute region)
-            prefetch_size = offload_engine.num_prefetch_blocks
-            num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
+            # For prefill: ONLY use Prefetch region to avoid conflict with
+            # current chunk's KV being written to Compute region slots
+            # Use synchronous per-layer loading (async would conflict with writes)
+            chunk_size = offload_engine.num_prefetch_blocks
+            num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
 
             for chunk_idx in range(num_chunks):
-                start = chunk_idx * prefetch_size
-                end = min(start + prefetch_size, len(cpu_block_table))
+                start = chunk_idx * chunk_size
+                end = min(start + chunk_size, len(cpu_block_table))
                 num_blocks_in_chunk = end - start
                 chunk_ids = cpu_block_table[start:end]
 
-                # Load this chunk to Prefetch region (per-layer loading)
-                # Each layer loads only its own KV, avoiding the bug where layer 0
-                # loads all layers and overwrites data before other layers can read it
+                # Load to Prefetch region (per-layer, sync)
                 offload_engine.load_to_prefetch_layer(self.layer_id, chunk_ids)
-
-                # Wait for this layer's Prefetch region and get KV
                 offload_engine.wait_prefetch_layer(self.layer_id)
+
                 prev_k, prev_v = offload_engine.get_kv_for_prefetch(
                     self.layer_id, num_blocks_in_chunk
                 )
@@ -195,24 +194,27 @@
         context,
     ) -> torch.Tensor:
         """
-        Compute decode attention with three-region GPU buffer.
+        Compute decode attention with async double-buffering using Compute and Prefetch regions.
 
-        All KV is stored on CPU. Uses Compute region buffer on GPU:
-        1. Load chunk to Compute region
-        2. Compute attention
-        3. Repeat for all chunks
-        4. Finally, attend to Decode region (slot 0) which contains the new token's KV
-        5. Merge all attention outputs using online softmax (LSE)
+        Pipeline design:
+        - Compute region: holds current chunk being computed
+        - Prefetch region: async loads next chunk while current is computing
+        - After computation, swap roles of the two regions
 
-        Key: new token's KV is in Decode region (slot 0), won't be overwritten by Compute region loading.
+        Timeline:
+        ┌─────────────┐  ┌─────────────┐  ┌─────────────┐
+        │Load C0→Comp │  │Load C1→Pref │  │Load C2→Comp │  ...
+        └─────────────┘  └─────────────┘  └─────────────┘
+               ↘                ↘                ↘
+        ┌─────────────┐  ┌─────────────┐  ┌─────────────┐
+        │ Compute C0  │  │ Compute C1  │  │ Compute C2  │
+        └─────────────┘  └─────────────┘  └─────────────┘
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
 
         # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
-        # Need: [batch, seqlen, heads, dim]
         q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]
 
-        # Note: context.offload_engine is actually HybridKVCacheManager
         kvcache_manager = context.offload_engine
         seq = context.chunked_seq
 
@@ -223,32 +225,56 @@
         if not cpu_block_table:
             raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
 
-        # Get the actual offload_engine for three-region operations
         offload_engine = kvcache_manager.offload_engine
 
-        # Calculate chunk info using Compute region
-        compute_size = offload_engine.num_compute_blocks
-        num_chunks = (len(cpu_block_table) + compute_size - 1) // compute_size
+        # Use prefetch_size as chunk size for double buffering
+        # This ensures both Compute and Prefetch regions can hold a full chunk
+        chunk_size = offload_engine.num_prefetch_blocks
+        num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
 
         o_acc = None
         lse_acc = None
 
+        # Double buffering state: True = use Compute region, False = use Prefetch region
+        use_compute = True
+
+        # Pre-load first chunk to Compute region (async)
+        first_chunk_ids = cpu_block_table[:min(chunk_size, len(cpu_block_table))]
+        offload_engine.load_to_compute_layer(self.layer_id, first_chunk_ids)
+
         for chunk_idx in range(num_chunks):
-            start = chunk_idx * compute_size
-            end = min(start + compute_size, len(cpu_block_table))
+            start = chunk_idx * chunk_size
+            end = min(start + chunk_size, len(cpu_block_table))
             num_blocks_in_chunk = end - start
-            chunk_ids = cpu_block_table[start:end]
 
-            # Load this chunk to Compute region (per-layer loading)
-            # Each layer loads only its own KV, avoiding the bug where layer 0
-            # loads all layers and overwrites data before other layers can read it
-            offload_engine.load_to_compute_layer(self.layer_id, chunk_ids)
+            # Wait for current buffer to be ready
+            if use_compute:
+                offload_engine.wait_compute_layer(self.layer_id)
+            else:
+                offload_engine.wait_prefetch_layer(self.layer_id)
 
-            # Wait for this layer's Compute region to be ready and get KV
-            offload_engine.wait_compute_layer(self.layer_id)
-            k_chunk, v_chunk = offload_engine.get_kv_for_compute(
-                self.layer_id, num_blocks_in_chunk
-            )
+            # Trigger async prefetch of next chunk to the OTHER buffer
+            # This overlaps transfer with current chunk's computation
+            if chunk_idx + 1 < num_chunks:
+                next_start = end
+                next_end = min(next_start + chunk_size, len(cpu_block_table))
+                next_chunk_ids = cpu_block_table[next_start:next_end]
+                if use_compute:
+                    # Current in Compute, prefetch next to Prefetch region
+                    offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
+                else:
+                    # Current in Prefetch, prefetch next to Compute region
+                    offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
+
+            # Get KV from current buffer
+            if use_compute:
+                k_chunk, v_chunk = offload_engine.get_kv_for_compute(
+                    self.layer_id, num_blocks_in_chunk
+                )
+            else:
+                k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
+                    self.layer_id, num_blocks_in_chunk
+                )
 
             # Compute attention for this chunk
             o_chunk, lse_chunk = flash_attn_with_lse(
@@ -263,18 +289,18 @@
             else:
                 o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
 
+            # Swap buffers for next iteration
+            use_compute = not use_compute
+
         # Now attend to Decode region (contains accumulated decode tokens)
-        # When batching offloads, decode slot accumulates multiple tokens
-        # from decode_start_pos_in_block to decode_pos_in_block (inclusive)
         pos_in_block = context.decode_pos_in_block
         start_pos = context.decode_start_pos_in_block
         num_accumulated = pos_in_block - start_pos + 1
 
         if num_accumulated > 0:
-            # Get accumulated KV in decode slot [start_pos : pos_in_block+1]
             decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
             decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
-            decode_k = decode_k.unsqueeze(0)  # [1, num_tokens, heads, dim]
+            decode_k = decode_k.unsqueeze(0)
             decode_v = decode_v.unsqueeze(0)
 
             decode_o, decode_lse = flash_attn_with_lse(
@@ -283,7 +309,6 @@
                 causal=False,
             )
 
-            # Merge with accumulated
             if o_acc is None:
                 o_acc = decode_o
             else:
@@ -292,5 +317,4 @@
         if o_acc is None:
             raise RuntimeError("Chunked decode attention failed: no KV available")
 
-        # Output shape: [batch, 1, heads, dim] (same as normal decode)
         return o_acc
diff --git a/tests/test_chunked_attention.py b/tests/test_chunked_attention.py
index ccd6685..b2be4ff 100644
--- a/tests/test_chunked_attention.py
+++ b/tests/test_chunked_attention.py
@@ -61,7 +61,7 @@ Attention mechanisms allow models to focus on relevant parts of the input.
         fact_idx += 1
 
     # Add the question at the end
-    prompt_parts.append("\n\nQuestion: Based on the information above, what is the capital of France and when was the Eiffel Tower built? Please answer briefly.\n\nAnswer:")
+    prompt_parts.append("\n\nQuestion: Based on the information above, what is the speed of light?\n\nAnswer:")
 
     return "".join(prompt_parts)
 
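
Note on the chunked-attention math (not part of the patch): the decode path above relies on flash_attn_with_lse and merge_attention_outputs from nanovllm.kvcache.chunked_attention, which are not included in this diff. The following is a minimal, self-contained sketch of the log-sum-exp merge those helpers are assumed to perform; the function names and tensor layouts here are illustrative assumptions, not the project's API. It demonstrates that attending to disjoint KV chunks and combining the partial outputs with their LSE statistics reproduces full attention exactly, which is why the loop over CPU-resident blocks does not change the result.

import torch


def naive_attn_with_lse(q, k, v, scale):
    # Reference attention that also returns the log-sum-exp of the scores.
    # q: [batch, q_len, heads, dim], k/v: [batch, kv_len, heads, dim]
    scores = torch.einsum("bshd,bthd->bhst", q, k) * scale      # [b, h, q_len, kv_len]
    lse = torch.logsumexp(scores, dim=-1)                       # [b, h, q_len]
    out = torch.einsum("bhst,bthd->bshd", scores.softmax(-1), v)
    return out, lse


def merge_partials(o1, lse1, o2, lse2):
    # Merge two partial attention results computed over disjoint key chunks.
    # exp(lse_i - lse) renormalizes each chunk's softmax to the union of keys.
    lse = torch.logaddexp(lse1, lse2)                           # [b, h, q_len]
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)    # [b, q_len, h, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse


if __name__ == "__main__":
    torch.manual_seed(0)
    b, q_len, kv_len, heads, dim = 1, 1, 64, 4, 16
    q = torch.randn(b, q_len, heads, dim)
    k = torch.randn(b, kv_len, heads, dim)
    v = torch.randn(b, kv_len, heads, dim)
    scale = dim ** -0.5

    full, _ = naive_attn_with_lse(q, k, v, scale)
    o1, l1 = naive_attn_with_lse(q, k[:, :32], v[:, :32], scale)
    o2, l2 = naive_attn_with_lse(q, k[:, 32:], v[:, 32:], scale)
    merged, _ = merge_partials(o1, l1, o2, l2)
    assert torch.allclose(merged, full, atol=1e-5)

Because the merge is exact, the chunk size only affects transfer/compute overlap and GPU buffer footprint, never the attention output.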
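Note on the transfer/compute overlap the new decode loop aims for (again, not part of the patch): the load_to_*/wait_* calls belong to the offload engine and their internals are outside this diff. Below is a generic two-buffer sketch of the same idea using plain CUDA streams and events; the buffer layout, the double_buffered_consume name, and the process_chunk callback are assumptions for illustration. Chunk i+1 is copied host-to-device on a side stream while chunk i is consumed on the compute stream, and a per-slot "consumed" event prevents the next copy from clobbering data still being read. For the copy to actually overlap, the CPU chunks must be in pinned memory.

import torch


def double_buffered_consume(cpu_chunks, process_chunk):
    """Stream same-shaped pinned CPU tensors through two GPU buffers."""
    compute = torch.cuda.current_stream()
    copy_stream = torch.cuda.Stream()
    bufs = [torch.empty_like(cpu_chunks[0], device="cuda") for _ in range(2)]
    copied = [torch.cuda.Event() for _ in range(2)]    # H2D copy finished
    consumed = [torch.cuda.Event() for _ in range(2)]  # compute finished, slot reusable
    for ev in consumed:
        ev.record(compute)  # both slots start out free

    def issue_copy(i, slot):
        with torch.cuda.stream(copy_stream):
            copy_stream.wait_event(consumed[slot])          # don't overwrite in-flight data
            bufs[slot].copy_(cpu_chunks[i], non_blocking=True)
            copied[slot].record(copy_stream)

    issue_copy(0, 0)
    for i in range(len(cpu_chunks)):
        slot = i % 2
        if i + 1 < len(cpu_chunks):
            issue_copy(i + 1, (i + 1) % 2)                  # prefetch next chunk
        compute.wait_event(copied[slot])                    # wait only for this slot's copy
        process_chunk(bufs[slot])                           # e.g. attention over this chunk
        consumed[slot].record(compute)                      # release the slot

In the patch itself the two slots correspond to the Compute and Prefetch regions, with the extra wrinkle that every layer issues and waits on its own per-layer loads; the remaining performance work mentioned in the subject line is essentially about keeping that copy stream busy.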