[feat] Finished offload. Still need optimize performance.
This commit is contained in:
@@ -4,10 +4,10 @@ from random import randint, seed
|
|||||||
from nanovllm import LLM, SamplingParams
|
from nanovllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
|
||||||
def bench_decode(llm, num_seqs, max_input_len, max_output_len):
|
def bench_decode(llm, num_seqs, input_len, max_output_len):
|
||||||
"""Benchmark decode performance (original test)"""
|
"""Benchmark decode performance (original test)"""
|
||||||
seed(0)
|
seed(0)
|
||||||
prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
|
prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, input_len))] for _ in range(num_seqs)]
|
||||||
sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)]
|
sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)]
|
||||||
|
|
||||||
t = time.time()
|
t = time.time()
|
||||||
@@ -54,13 +54,13 @@ def main():
|
|||||||
# bench_prefill(llm, num_seqs=1, input_len=1024)
|
# bench_prefill(llm, num_seqs=1, input_len=1024)
|
||||||
# bench_prefill(llm, num_seqs=1, input_len=2048)
|
# bench_prefill(llm, num_seqs=1, input_len=2048)
|
||||||
# bench_prefill(llm, num_seqs=1, input_len=4096)
|
# bench_prefill(llm, num_seqs=1, input_len=4096)
|
||||||
bench_prefill(llm, num_seqs=1, input_len=8192)
|
bench_prefill(llm, num_seqs=1, input_len=64 * 1024)
|
||||||
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("Decode Benchmark (CPU Offload)")
|
print("Decode Benchmark (CPU Offload)")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=128)
|
bench_decode(llm, num_seqs=1, input_len=64 * 1024, max_output_len=128)
|
||||||
# bench_decode(llm, num_seqs=1, max_input_len=2048, max_output_len=128)
|
# bench_decode(llm, num_seqs=1, input_len=2048, max_output_len=128)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -133,23 +133,22 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
if cpu_block_table:
|
if cpu_block_table:
|
||||||
offload_engine = kvcache_manager.offload_engine
|
offload_engine = kvcache_manager.offload_engine
|
||||||
# Use Prefetch region to load previous KV (won't conflict with current Compute region)
|
# For prefill: ONLY use Prefetch region to avoid conflict with
|
||||||
prefetch_size = offload_engine.num_prefetch_blocks
|
# current chunk's KV being written to Compute region slots
|
||||||
num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
|
# Use synchronous per-layer loading (async would conflict with writes)
|
||||||
|
chunk_size = offload_engine.num_prefetch_blocks
|
||||||
|
num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
|
||||||
|
|
||||||
for chunk_idx in range(num_chunks):
|
for chunk_idx in range(num_chunks):
|
||||||
start = chunk_idx * prefetch_size
|
start = chunk_idx * chunk_size
|
||||||
end = min(start + prefetch_size, len(cpu_block_table))
|
end = min(start + chunk_size, len(cpu_block_table))
|
||||||
num_blocks_in_chunk = end - start
|
num_blocks_in_chunk = end - start
|
||||||
chunk_ids = cpu_block_table[start:end]
|
chunk_ids = cpu_block_table[start:end]
|
||||||
|
|
||||||
# Load this chunk to Prefetch region (per-layer loading)
|
# Load to Prefetch region (per-layer, sync)
|
||||||
# Each layer loads only its own KV, avoiding the bug where layer 0
|
|
||||||
# loads all layers and overwrites data before other layers can read it
|
|
||||||
offload_engine.load_to_prefetch_layer(self.layer_id, chunk_ids)
|
offload_engine.load_to_prefetch_layer(self.layer_id, chunk_ids)
|
||||||
|
|
||||||
# Wait for this layer's Prefetch region and get KV
|
|
||||||
offload_engine.wait_prefetch_layer(self.layer_id)
|
offload_engine.wait_prefetch_layer(self.layer_id)
|
||||||
|
|
||||||
prev_k, prev_v = offload_engine.get_kv_for_prefetch(
|
prev_k, prev_v = offload_engine.get_kv_for_prefetch(
|
||||||
self.layer_id, num_blocks_in_chunk
|
self.layer_id, num_blocks_in_chunk
|
||||||
)
|
)
|
||||||
@@ -195,24 +194,27 @@ class Attention(nn.Module):
|
|||||||
context,
|
context,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute decode attention with three-region GPU buffer.
|
Compute decode attention with async double-buffering using Compute and Prefetch regions.
|
||||||
|
|
||||||
All KV is stored on CPU. Uses Compute region buffer on GPU:
|
Pipeline design:
|
||||||
1. Load chunk to Compute region
|
- Compute region: holds current chunk being computed
|
||||||
2. Compute attention
|
- Prefetch region: async loads next chunk while current is computing
|
||||||
3. Repeat for all chunks
|
- After computation, swap roles of the two regions
|
||||||
4. Finally, attend to Decode region (slot 0) which contains the new token's KV
|
|
||||||
5. Merge all attention outputs using online softmax (LSE)
|
|
||||||
|
|
||||||
Key: new token's KV is in Decode region (slot 0), won't be overwritten by Compute region loading.
|
Timeline:
|
||||||
|
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
|
||||||
|
│Load C0→Comp │ │Load C1→Pref │ │Load C2→Comp │ ...
|
||||||
|
└─────────────┘ └─────────────┘ └─────────────┘
|
||||||
|
↘ ↘ ↘
|
||||||
|
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
|
||||||
|
│ Compute C0 │ │ Compute C1 │ │ Compute C2 │
|
||||||
|
└─────────────┘ └─────────────┘ └─────────────┘
|
||||||
"""
|
"""
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
||||||
# Need: [batch, seqlen, heads, dim]
|
|
||||||
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
||||||
|
|
||||||
# Note: context.offload_engine is actually HybridKVCacheManager
|
|
||||||
kvcache_manager = context.offload_engine
|
kvcache_manager = context.offload_engine
|
||||||
seq = context.chunked_seq
|
seq = context.chunked_seq
|
||||||
|
|
||||||
@@ -223,32 +225,56 @@ class Attention(nn.Module):
|
|||||||
if not cpu_block_table:
|
if not cpu_block_table:
|
||||||
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
|
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
|
||||||
|
|
||||||
# Get the actual offload_engine for three-region operations
|
|
||||||
offload_engine = kvcache_manager.offload_engine
|
offload_engine = kvcache_manager.offload_engine
|
||||||
|
|
||||||
# Calculate chunk info using Compute region
|
# Use prefetch_size as chunk size for double buffering
|
||||||
compute_size = offload_engine.num_compute_blocks
|
# This ensures both Compute and Prefetch regions can hold a full chunk
|
||||||
num_chunks = (len(cpu_block_table) + compute_size - 1) // compute_size
|
chunk_size = offload_engine.num_prefetch_blocks
|
||||||
|
num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
|
||||||
|
|
||||||
o_acc = None
|
o_acc = None
|
||||||
lse_acc = None
|
lse_acc = None
|
||||||
|
|
||||||
|
# Double buffering state: True = use Compute region, False = use Prefetch region
|
||||||
|
use_compute = True
|
||||||
|
|
||||||
|
# Pre-load first chunk to Compute region (async)
|
||||||
|
first_chunk_ids = cpu_block_table[:min(chunk_size, len(cpu_block_table))]
|
||||||
|
offload_engine.load_to_compute_layer(self.layer_id, first_chunk_ids)
|
||||||
|
|
||||||
for chunk_idx in range(num_chunks):
|
for chunk_idx in range(num_chunks):
|
||||||
start = chunk_idx * compute_size
|
start = chunk_idx * chunk_size
|
||||||
end = min(start + compute_size, len(cpu_block_table))
|
end = min(start + chunk_size, len(cpu_block_table))
|
||||||
num_blocks_in_chunk = end - start
|
num_blocks_in_chunk = end - start
|
||||||
chunk_ids = cpu_block_table[start:end]
|
|
||||||
|
|
||||||
# Load this chunk to Compute region (per-layer loading)
|
# Wait for current buffer to be ready
|
||||||
# Each layer loads only its own KV, avoiding the bug where layer 0
|
if use_compute:
|
||||||
# loads all layers and overwrites data before other layers can read it
|
|
||||||
offload_engine.load_to_compute_layer(self.layer_id, chunk_ids)
|
|
||||||
|
|
||||||
# Wait for this layer's Compute region to be ready and get KV
|
|
||||||
offload_engine.wait_compute_layer(self.layer_id)
|
offload_engine.wait_compute_layer(self.layer_id)
|
||||||
|
else:
|
||||||
|
offload_engine.wait_prefetch_layer(self.layer_id)
|
||||||
|
|
||||||
|
# Trigger async prefetch of next chunk to the OTHER buffer
|
||||||
|
# This overlaps transfer with current chunk's computation
|
||||||
|
if chunk_idx + 1 < num_chunks:
|
||||||
|
next_start = end
|
||||||
|
next_end = min(next_start + chunk_size, len(cpu_block_table))
|
||||||
|
next_chunk_ids = cpu_block_table[next_start:next_end]
|
||||||
|
if use_compute:
|
||||||
|
# Current in Compute, prefetch next to Prefetch region
|
||||||
|
offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
|
||||||
|
else:
|
||||||
|
# Current in Prefetch, prefetch next to Compute region
|
||||||
|
offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
|
||||||
|
|
||||||
|
# Get KV from current buffer
|
||||||
|
if use_compute:
|
||||||
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
|
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
|
||||||
self.layer_id, num_blocks_in_chunk
|
self.layer_id, num_blocks_in_chunk
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
|
||||||
|
self.layer_id, num_blocks_in_chunk
|
||||||
|
)
|
||||||
|
|
||||||
# Compute attention for this chunk
|
# Compute attention for this chunk
|
||||||
o_chunk, lse_chunk = flash_attn_with_lse(
|
o_chunk, lse_chunk = flash_attn_with_lse(
|
||||||
@@ -263,18 +289,18 @@ class Attention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
|
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
|
||||||
|
|
||||||
|
# Swap buffers for next iteration
|
||||||
|
use_compute = not use_compute
|
||||||
|
|
||||||
# Now attend to Decode region (contains accumulated decode tokens)
|
# Now attend to Decode region (contains accumulated decode tokens)
|
||||||
# When batching offloads, decode slot accumulates multiple tokens
|
|
||||||
# from decode_start_pos_in_block to decode_pos_in_block (inclusive)
|
|
||||||
pos_in_block = context.decode_pos_in_block
|
pos_in_block = context.decode_pos_in_block
|
||||||
start_pos = context.decode_start_pos_in_block
|
start_pos = context.decode_start_pos_in_block
|
||||||
num_accumulated = pos_in_block - start_pos + 1
|
num_accumulated = pos_in_block - start_pos + 1
|
||||||
|
|
||||||
if num_accumulated > 0:
|
if num_accumulated > 0:
|
||||||
# Get accumulated KV in decode slot [start_pos : pos_in_block+1]
|
|
||||||
decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
|
decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
|
||||||
decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
|
decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
|
||||||
decode_k = decode_k.unsqueeze(0) # [1, num_tokens, heads, dim]
|
decode_k = decode_k.unsqueeze(0)
|
||||||
decode_v = decode_v.unsqueeze(0)
|
decode_v = decode_v.unsqueeze(0)
|
||||||
|
|
||||||
decode_o, decode_lse = flash_attn_with_lse(
|
decode_o, decode_lse = flash_attn_with_lse(
|
||||||
@@ -283,7 +309,6 @@ class Attention(nn.Module):
|
|||||||
causal=False,
|
causal=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Merge with accumulated
|
|
||||||
if o_acc is None:
|
if o_acc is None:
|
||||||
o_acc = decode_o
|
o_acc = decode_o
|
||||||
else:
|
else:
|
||||||
@@ -292,5 +317,4 @@ class Attention(nn.Module):
|
|||||||
if o_acc is None:
|
if o_acc is None:
|
||||||
raise RuntimeError("Chunked decode attention failed: no KV available")
|
raise RuntimeError("Chunked decode attention failed: no KV available")
|
||||||
|
|
||||||
# Output shape: [batch, 1, heads, dim] (same as normal decode)
|
|
||||||
return o_acc
|
return o_acc
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ Attention mechanisms allow models to focus on relevant parts of the input.
|
|||||||
fact_idx += 1
|
fact_idx += 1
|
||||||
|
|
||||||
# Add the question at the end
|
# Add the question at the end
|
||||||
prompt_parts.append("\n\nQuestion: Based on the information above, what is the capital of France and when was the Eiffel Tower built? Please answer briefly.\n\nAnswer:")
|
prompt_parts.append("\n\nQuestion: Based on the information above, what is the speed of light?\n\nAnswer:")
|
||||||
|
|
||||||
return "".join(prompt_parts)
|
return "".join(prompt_parts)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user