[WIP] Before switching to FlashInfer.
tests/test_flashinfer_merge.py (new file, 104 lines)
@@ -0,0 +1,104 @@
"""
Test FlashInfer chunked attention with CPU offload.

Uses single_prefill_with_kv_cache + merge_state for chunked KV processing.
"""

import torch
import flashinfer


# ============================================================
# Core Functions
# ============================================================

def chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk_size, kv_chunk_size):
    """
    Chunked causal attention with KV on CPU.

    q: [seq_q, num_heads, head_dim] on GPU
    k_cpu, v_cpu: [seq_kv, num_kv_heads, head_dim] on CPU
    """
    seq_q = q.shape[0]
    seq_kv = k_cpu.shape[0]
    final_outputs = []

    for q_start in range(0, seq_q, q_chunk_size):
        q_end = min(q_start + q_chunk_size, seq_q)
        q_chunk = q[q_start:q_end]

        merged_output = None
        merged_lse = None

        for kv_start in range(0, seq_kv, kv_chunk_size):
            kv_end = min(kv_start + kv_chunk_size, seq_kv)

            if kv_start >= q_end:
                continue

            k_chunk = k_cpu[kv_start:kv_end].to(q.device, non_blocking=True)
            v_chunk = v_cpu[kv_start:kv_end].to(q.device, non_blocking=True)

            causal = not (kv_end <= q_start)
            partial_out, partial_lse = flashinfer.single_prefill_with_kv_cache(
                q_chunk, k_chunk, v_chunk,
                causal=causal,
                return_lse=True,
            )

            if merged_output is None:
                merged_output, merged_lse = partial_out, partial_lse
            else:
                merged_output, merged_lse = flashinfer.merge_state(
                    merged_output, merged_lse,
                    partial_out, partial_lse,
                )

        final_outputs.append(merged_output)

    return torch.cat(final_outputs, dim=0)


# ============================================================
# Main Test Script
# ============================================================

print("=" * 60)
print("Testing FlashInfer chunked attention with CPU offload")
print("=" * 60)

num_heads = 32
num_kv_heads = 8
head_dim = 128

test_configs = [
    (32768, 8192, 8192),     # 32K tokens
    (65536, 8192, 8192),     # 64K tokens
    (131072, 16384, 16384),  # 128K tokens
    # (262144, 16384, 16384),  # 256K tokens (slow)
    # (524288, 16384, 16384),  # 512K tokens (slow)
]

for seq_len, q_chunk, kv_chunk in test_configs:
    torch.manual_seed(42)

    q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device='cuda')
    k_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')
    v_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')

    # Chunked result
    chunked_out = chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk, kv_chunk)

    # Reference
    k_gpu = k_cpu.to('cuda')
    v_gpu = v_cpu.to('cuda')
    ref_out = flashinfer.single_prefill_with_kv_cache(q, k_gpu, v_gpu, causal=True)

    max_diff = (ref_out - chunked_out).abs().max().item()
    mean_diff = (ref_out - chunked_out).abs().mean().item()

    num_chunks = (seq_len + q_chunk - 1) // q_chunk
    assert max_diff < 1e-2, f"FAILED: max_diff={max_diff:.6f}"
    print(f"seq={seq_len//1024}K, chunks={num_chunks}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

print("\ntest_flashinfer_merge: PASSED")
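For reference, the combination that merge_state performs on two partial attention results can be written out directly: each partial output is weighted by the exponential of its log-sum-exp and renormalized. The sketch below is a minimal, CPU-checkable version of that math, assuming the LSE returned with return_lse=True is the natural-log softmax normalizer per (query, head); it is not FlashInfer's implementation, only a way to sanity-check merged results.

import torch

def merge_state_reference(out_a, lse_a, out_b, lse_b):
    # out_*: [seq, num_heads, head_dim], lse_*: [seq, num_heads]
    max_lse = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - max_lse)  # stabilized weights
    w_b = torch.exp(lse_b - max_lse)
    denom = (w_a + w_b).unsqueeze(-1)
    merged = (out_a.float() * w_a.unsqueeze(-1) + out_b.float() * w_b.unsqueeze(-1)) / denom
    merged_lse = max_lse + torch.log(w_a + w_b)
    return merged.to(out_a.dtype), merged_lse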
@@ -1,70 +0,0 @@
"""
Test if slicing maintains pinned memory property.
"""

import torch

print("=" * 60)
print("Test: Pinned Memory Property with Slicing")
print("=" * 60)

# Create a pinned tensor with shape similar to k_cache_cpu
# [num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim]
tensor = torch.zeros(8, 16, 1024, 8, 64, dtype=torch.float16, device="cpu", pin_memory=True)

print(f"\n1. Original tensor:")
print(f" - Shape: {tensor.shape}")
print(f" - is_pinned(): {tensor.is_pinned()}")
print(f" - is_contiguous(): {tensor.is_contiguous()}")

# Test slicing operation (what we do in offload_slot_to_cpu)
slice_view = tensor[:, 0]  # Same as k_cache_cpu[:, cpu_block_id]

print(f"\n2. Sliced tensor [:, 0]:")
print(f" - Shape: {slice_view.shape}")
print(f" - is_pinned(): {slice_view.is_pinned()}")
print(f" - is_contiguous(): {slice_view.is_contiguous()}")

# Test if contiguous() helps
contiguous_slice = tensor[:, 0].contiguous()

print(f"\n3. Contiguous slice [:, 0].contiguous():")
print(f" - Shape: {contiguous_slice.shape}")
print(f" - is_pinned(): {contiguous_slice.is_pinned()}")
print(f" - is_contiguous(): {contiguous_slice.is_contiguous()}")

# Test copy behavior
gpu_tensor = torch.zeros(8, 4, 1024, 8, 64, dtype=torch.float16, device="cuda")
gpu_slice = gpu_tensor[:, 0]

print(f"\n4. GPU tensor slice:")
print(f" - Shape: {gpu_slice.shape}")
print(f" - is_contiguous(): {gpu_slice.is_contiguous()}")

# Simulate the problematic copy operation
print(f"\n5. Testing copy operations:")

# Method 1: Direct slice copy (current approach - SLOW)
slice_dst = tensor[:, 1]
print(f" Method 1 (slice view): dst.is_pinned()={slice_dst.is_pinned()}")

# Method 2: Use contiguous destination
contiguous_dst = tensor[:, 2].contiguous()
print(f" Method 2 (contiguous): dst.is_pinned()={contiguous_dst.is_pinned()}")

print("\n" + "=" * 60)
print("Conclusion:")
print("=" * 60)

if not slice_view.is_pinned():
    print("❌ Slicing LOSES pinned memory property!")
    print("   This causes Device-to-Pageable transfers (SLOW)")
else:
    print("✓ Slicing maintains pinned memory property")

if contiguous_slice.is_pinned():
    print("✓ .contiguous() maintains pinned memory property")
else:
    print("❌ .contiguous() also loses pinned memory property")

print("\n" + "=" * 60)
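If the sliced destination does lose the pinned property (or remains non-contiguous), one common workaround is to stage the device-to-host copy through a small contiguous pinned buffer and finish with a host-side strided copy. A minimal sketch under that assumption, with hypothetical names matching the cache layout above:

import torch

# contiguous pinned staging buffer shaped like one block across all layers
staging = torch.empty(8, 1024, 8, 64, dtype=torch.float16, pin_memory=True)

def offload_block(k_cache_gpu, k_cache_cpu, gpu_slot, cpu_block_id, stream):
    with torch.cuda.stream(stream):
        # pinned + contiguous destination -> genuinely asynchronous DtoH copy
        staging.copy_(k_cache_gpu[:, gpu_slot], non_blocking=True)
    stream.synchronize()
    # plain CPU copy into the strided cache view; no PCIe traffic involved here
    k_cache_cpu[:, cpu_block_id].copy_(staging)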
@@ -1,124 +0,0 @@
"""
Test D2H transfer performance with pinned vs non-contiguous memory.
"""

import torch
import time

print("=" * 60)
print("Test: D2H Transfer Performance (for nsys profiling)")
print("=" * 60)

# Setup
num_layers = 8
num_blocks = 16
block_size = 1024
num_kv_heads = 8
head_dim = 64

# Allocate CPU cache (pinned)
k_cache_cpu = torch.zeros(
    num_layers, num_blocks, block_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cpu", pin_memory=True
)

# Allocate GPU cache
k_cache_gpu = torch.randn(
    num_layers, 4, block_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cuda"
)

# Warmup
print("\nWarmup...")
for _ in range(10):
    k_cache_cpu[:, 0].copy_(k_cache_gpu[:, 0], non_blocking=True)
torch.cuda.synchronize()

print(f"\nTensor info:")
print(f" k_cache_cpu.is_pinned(): {k_cache_cpu.is_pinned()}")
print(f" k_cache_cpu.is_contiguous(): {k_cache_cpu.is_contiguous()}")
print(f" k_cache_cpu[:, 0].is_pinned(): {k_cache_cpu[:, 0].is_pinned()}")
print(f" k_cache_cpu[:, 0].is_contiguous(): {k_cache_cpu[:, 0].is_contiguous()}")

# Test 1: Non-contiguous slice (current approach)
print(f"\n" + "=" * 60)
print("Test 1: Non-contiguous slice copy (current approach)")
print("=" * 60)

NUM_ITERS = 50  # Reduced for profiling

torch.cuda.nvtx.range_push("Test1_NonContiguous")
times = []
for i in range(NUM_ITERS):
    torch.cuda.nvtx.range_push(f"D2H_NonContig_{i}")
    start = time.perf_counter()
    k_cache_cpu[:, i % num_blocks].copy_(k_cache_gpu[:, 0], non_blocking=True)
    torch.cuda.synchronize()
    times.append(time.perf_counter() - start)
    torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop()

avg_time = sum(times) / len(times)
print(f"Average time: {avg_time * 1000:.3f} ms")
print(f"Bandwidth: {k_cache_gpu[:, 0].numel() * 2 / avg_time / 1e9:.2f} GB/s")

# Test 2: Transpose to make dimension contiguous
print(f"\n" + "=" * 60)
print("Test 2: Transpose to contiguous dimension")
print("=" * 60)

# Reshape to [num_blocks, num_layers, block_size, num_kv_heads, head_dim]
k_cache_cpu_transposed = torch.zeros(
    num_blocks, num_layers, block_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cpu", pin_memory=True
)

print(f" k_cache_cpu_transposed[0].is_pinned(): {k_cache_cpu_transposed[0].is_pinned()}")
print(f" k_cache_cpu_transposed[0].is_contiguous(): {k_cache_cpu_transposed[0].is_contiguous()}")

torch.cuda.nvtx.range_push("Test2_Contiguous")
times = []
for i in range(NUM_ITERS):
    torch.cuda.nvtx.range_push(f"D2H_Contig_{i}")
    start = time.perf_counter()
    k_cache_cpu_transposed[i % num_blocks].copy_(k_cache_gpu[:, 0], non_blocking=True)
    torch.cuda.synchronize()
    times.append(time.perf_counter() - start)
    torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop()

avg_time = sum(times) / len(times)
print(f"Average time: {avg_time * 1000:.3f} ms")
print(f"Bandwidth: {k_cache_gpu[:, 0].numel() * 2 / avg_time / 1e9:.2f} GB/s")

# Test 3: Fully contiguous buffer
print(f"\n" + "=" * 60)
print("Test 3: Fully contiguous buffer")
print("=" * 60)

k_cache_cpu_flat = torch.zeros(
    num_layers * block_size * num_kv_heads * head_dim,
    dtype=torch.float16, device="cpu", pin_memory=True
)

print(f" k_cache_cpu_flat.is_pinned(): {k_cache_cpu_flat.is_pinned()}")
print(f" k_cache_cpu_flat.is_contiguous(): {k_cache_cpu_flat.is_contiguous()}")

torch.cuda.nvtx.range_push("Test3_FlatContiguous")
times = []
for i in range(NUM_ITERS):
    torch.cuda.nvtx.range_push(f"D2H_Flat_{i}")
    start = time.perf_counter()
    k_cache_cpu_flat.copy_(k_cache_gpu[:, 0].flatten(), non_blocking=True)
    torch.cuda.synchronize()
    times.append(time.perf_counter() - start)
    torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop()

avg_time = sum(times) / len(times)
print(f"Average time: {avg_time * 1000:.3f} ms")
print(f"Bandwidth: {k_cache_cpu_flat.numel() * 2 / avg_time / 1e9:.2f} GB/s")

print("\n" + "=" * 60)
print("test_pinned_transfer: PASSED")
print("=" * 60)
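As an optional cross-check not in the original script, the same copies can be timed with CUDA events, which keeps the measurement on the GPU stream instead of mixing in host-side launch overhead; the helper name below is illustrative, and a pinned, contiguous destination is assumed so the copy is actually asynchronous.

import torch

def time_d2h(dst_cpu, src_gpu, iters=50):
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_evt.record()
    for _ in range(iters):
        dst_cpu.copy_(src_gpu, non_blocking=True)
    end_evt.record()
    torch.cuda.synchronize()
    avg_ms = start_evt.elapsed_time(end_evt) / iters  # elapsed_time returns milliseconds
    gbps = src_gpu.numel() * src_gpu.element_size() / (avg_ms * 1e-3) / 1e9
    return avg_ms, gbps

# e.g. time_d2h(k_cache_cpu_flat, k_cache_gpu[:, 0].flatten().contiguous())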
@@ -1,286 +0,0 @@
"""
Chunked Prefill + KV Cache Offload Simulation v2

Improvements:
1. Simplified log output
2. Added reduce time
3. Compute must wait for the KV load to complete
"""

import threading
import time
from dataclasses import dataclass
from typing import Optional
from concurrent.futures import ThreadPoolExecutor, Future

# ============== Configuration ==============
NUM_CHUNKS = 8
GPU_SLOTS = 4

# Simulated durations (seconds)
TIME_COMPUTE_BLOCK = 0.10  # compute one attention block
TIME_REDUCE = 0.03         # one pairwise reduce of two partial results
TIME_TRANSFER = 0.08       # transfer one KV cache
TIME_PROJ = 0.02           # projection that produces the KV

# ============== Global time base ==============
START_TIME = None

def now() -> float:
    """Return the time since start (ms)."""
    return (time.time() - START_TIME) * 1000

def log_compute(msg: str):
    """Compute-queue log (no indent)."""
    print(f"[{now():7.1f}ms] [COMPUTE] {msg}")

def log_transfer(msg: str):
    """Transfer-queue log (indented)."""
    print(f"[{now():7.1f}ms] [TRANSFER] {msg}")

def log_info(msg: str):
    """General info."""
    print(f"[{now():7.1f}ms] {msg}")

# ============== GPU slot management ==============
class GPUSlots:
    def __init__(self, num_slots: int):
        self.slots = [None] * num_slots  # slot_id -> kv_idx
        self.kv_to_slot = {}             # kv_idx -> slot_id
        self.lock = threading.Lock()
        # KV ready events: kv_idx -> Event
        self.kv_ready = {}

    def alloc(self, kv_idx: int) -> int:
        with self.lock:
            for sid, val in enumerate(self.slots):
                if val is None:
                    self.slots[sid] = kv_idx
                    self.kv_to_slot[kv_idx] = sid
                    # Create the ready event
                    if kv_idx not in self.kv_ready:
                        self.kv_ready[kv_idx] = threading.Event()
                    return sid
            raise RuntimeError(f"No free slot for KV{kv_idx}")

    def free(self, slot_id: int):
        with self.lock:
            kv_idx = self.slots[slot_id]
            if kv_idx is not None:
                del self.kv_to_slot[kv_idx]
                # Clear the event
                if kv_idx in self.kv_ready:
                    del self.kv_ready[kv_idx]
            self.slots[slot_id] = None

    def free_kv(self, kv_idx: int):
        with self.lock:
            if kv_idx in self.kv_to_slot:
                sid = self.kv_to_slot[kv_idx]
                self.slots[sid] = None
                del self.kv_to_slot[kv_idx]
                if kv_idx in self.kv_ready:
                    del self.kv_ready[kv_idx]

    def mark_ready(self, kv_idx: int):
        """Mark a KV as ready (load or projection finished)."""
        with self.lock:
            if kv_idx in self.kv_ready:
                self.kv_ready[kv_idx].set()

    def wait_ready(self, kv_idx: int):
        """Wait until a KV is ready."""
        with self.lock:
            event = self.kv_ready.get(kv_idx)
        # Wait outside the lock so mark_ready() can acquire it from the transfer thread.
        if event:
            event.wait()

    def has_kv(self, kv_idx: int) -> bool:
        with self.lock:
            return kv_idx in self.kv_to_slot

    def state(self) -> str:
        with self.lock:
            return "[" + "][".join(
                f"KV{v}" if v is not None else "----"
                for v in self.slots
            ) + "]"

# ============== Operation execution ==============
class Executor:
    def __init__(self, gpu: GPUSlots):
        self.gpu = gpu
        self.compute_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Compute")
        self.transfer_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Transfer")

    def proj_kv(self, q_idx: int) -> Future:
        """Run the projection that produces the KV; returns a Future."""
        def task():
            log_compute(f"PROJ Q{q_idx}->KV{q_idx} START")
            time.sleep(TIME_PROJ)
            slot_id = self.gpu.alloc(q_idx)
            self.gpu.mark_ready(q_idx)  # projection done, the KV is immediately usable
            log_compute(f"PROJ Q{q_idx}->KV{q_idx} END slot={slot_id} | {self.gpu.state()}")
            return slot_id
        return self.compute_pool.submit(task)

    def compute_attn(self, q_idx: int, kv_indices: list) -> Future:
        """Compute an attention block; waits for all required KVs to be ready."""
        def task():
            # Wait for every required KV to become ready
            for kv_idx in kv_indices:
                self.gpu.wait_ready(kv_idx)

            kv_str = ",".join(map(str, kv_indices))
            log_compute(f"ATTN Q{q_idx}*KV[{kv_str}] START")
            time.sleep(TIME_COMPUTE_BLOCK * len(kv_indices))
            log_compute(f"ATTN Q{q_idx}*KV[{kv_str}] END")
            return (q_idx, kv_indices)
        return self.compute_pool.submit(task)

    def reduce(self, q_idx: int, num_partials: int) -> Future:
        """Online-softmax reduce over multiple partial results."""
        def task():
            if num_partials <= 1:
                return
            # n partials need n-1 pairwise reduces
            num_reduces = num_partials - 1
            log_compute(f"REDUCE Q{q_idx} ({num_partials} partials) START")
            time.sleep(TIME_REDUCE * num_reduces)
            log_compute(f"REDUCE Q{q_idx} END")
        return self.compute_pool.submit(task)

    def load_kv(self, kv_idx: int) -> Future:
        """Load a KV from CPU to GPU."""
        def task():
            if self.gpu.has_kv(kv_idx):
                log_transfer(f"LOAD KV{kv_idx} SKIP (already on GPU)")
                return kv_idx

            slot_id = self.gpu.alloc(kv_idx)
            log_transfer(f"LOAD KV{kv_idx} START -> slot{slot_id}")
            time.sleep(TIME_TRANSFER)
            self.gpu.mark_ready(kv_idx)  # load finished, mark ready
            log_transfer(f"LOAD KV{kv_idx} END | {self.gpu.state()}")
            return kv_idx
        return self.transfer_pool.submit(task)

    def offload_kv(self, kv_idx: int) -> Future:
        """Offload a KV from GPU to CPU."""
        def task():
            log_transfer(f"OFFLOAD KV{kv_idx} START")
            time.sleep(TIME_TRANSFER)
            self.gpu.free_kv(kv_idx)
            log_transfer(f"OFFLOAD KV{kv_idx} END | {self.gpu.state()}")
            return kv_idx
        return self.transfer_pool.submit(task)

    def shutdown(self):
        self.compute_pool.shutdown(wait=True)
        self.transfer_pool.shutdown(wait=True)

# ============== Scheduler ==============
def schedule_query(exe: Executor, q_idx: int):
    """Schedule the processing of a single query."""
    print(f"\n{'='*50}")
    log_info(f"===== Query {q_idx} START =====")

    hist_kv = list(range(q_idx))  # historical KVs: 0 ~ q_idx-1
    num_partials = 0

    # Phase 1: projection produces the current KV
    proj_fut = exe.proj_kv(q_idx)
    proj_fut.result()  # wait for completion

    # Phase 2: diagonal-block compute + prefetch historical KVs in parallel
    # Launch the diagonal-block computation
    diag_fut = exe.compute_attn(q_idx, [q_idx])
    num_partials += 1

    # Prefetch historical KVs in parallel (at most GPU_SLOTS - 1 = 3 slots are free)
    prefetch_slots = min(len(hist_kv), GPU_SLOTS - 1)
    prefetch_kv = hist_kv[:prefetch_slots]
    prefetch_futs = [exe.load_kv(kv) for kv in prefetch_kv]

    # Wait for the diagonal block to finish
    diag_fut.result()

    # Phase 3: offload the current KV to free a slot
    offload_fut = exe.offload_kv(q_idx)

    # Wait for the prefetch to finish, then compute this batch of historical KVs
    for f in prefetch_futs:
        f.result()

    if prefetch_kv:
        hist_fut = exe.compute_attn(q_idx, prefetch_kv)
        num_partials += 1
    else:
        hist_fut = None

    # Wait for the offload to finish
    offload_fut.result()

    # Phase 4: process the remaining historical KVs
    remaining_kv = hist_kv[prefetch_slots:]
    computed_kv = prefetch_kv.copy()

    while remaining_kv:
        # Wait for the previous batch to finish
        if hist_fut:
            hist_fut.result()

        # Free the KVs that have been computed
        for kv in computed_kv:
            exe.gpu.free_kv(kv)

        # Load the next batch
        batch_size = min(len(remaining_kv), GPU_SLOTS)
        batch_kv = remaining_kv[:batch_size]
        remaining_kv = remaining_kv[batch_size:]

        load_futs = [exe.load_kv(kv) for kv in batch_kv]
        for f in load_futs:
            f.result()

        # Compute this batch
        hist_fut = exe.compute_attn(q_idx, batch_kv)
        num_partials += 1
        computed_kv = batch_kv

    # Wait for the last batch to finish
    if hist_fut:
        hist_fut.result()

    # Clean up the GPU
    for kv in computed_kv:
        exe.gpu.free_kv(kv)

    # Phase 5: reduce all partial results
    reduce_fut = exe.reduce(q_idx, num_partials)
    reduce_fut.result()

    log_info(f"===== Query {q_idx} END =====")

def main():
    global START_TIME
    START_TIME = time.time()

    print("Chunked Prefill + KV Cache Offload Simulation v2")
    print(f"Config: {NUM_CHUNKS} chunks, {GPU_SLOTS} GPU slots")
    print(f"Time: compute={TIME_COMPUTE_BLOCK}s, transfer={TIME_TRANSFER}s, reduce={TIME_REDUCE}s")

    gpu = GPUSlots(GPU_SLOTS)
    exe = Executor(gpu)

    try:
        for q_idx in range(NUM_CHUNKS):
            schedule_query(exe, q_idx)

        print(f"\n{'='*50}")
        log_info(f"ALL DONE! Total: {now():.1f}ms")
    finally:
        exe.shutdown()

if __name__ == "__main__":
    main()
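The simulation only reports measured wall-clock spans. As a rough sanity check (my addition, derived purely from the constants above), each query's span cannot be shorter than the total serial work queued on either single-worker pool, ignoring the reduce step, so a quick per-query lower bound is:

def query_lower_bound_ms(q_idx: int) -> float:
    """Lower bound for one query: the larger of the serial compute work
    (projection + q_idx+1 attention blocks) and the serial transfer work
    (q_idx history loads + one offload of the current KV)."""
    compute_s = TIME_PROJ + (q_idx + 1) * TIME_COMPUTE_BLOCK
    transfer_s = (q_idx + 1) * TIME_TRANSFER
    return max(compute_s, transfer_s) * 1000

# Example: query 7 needs at least max(0.02 + 8*0.10, 8*0.08) * 1000 = 820 ms.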