# nano-vllm/tests/test_sim.py
# Snapshot: 2025-12-24 18:22:26 +08:00 (286 lines, 9.1 KiB, Python)
# NOTE: the source viewer warned that this file contains ambiguous Unicode
# characters (characters that might be confused with other characters);
# presumably these come from the Chinese text in the original comments.
"""
Chunked Prefill + KV Cache Offload Simulation v2

Simulates chunked prefill in which query chunk ``q`` attends to its own KV
block and to every earlier KV block (0 .. q-1), while KV blocks are moved
between CPU and a small, fixed number of GPU slots. All work is simulated with
``time.sleep``; compute and transfer operations run on two separate
single-worker thread pools so that they can overlap.

Changes in v2:
1. Simplified log output.
2. Added time for reducing partial attention results.
3. Attention compute must wait for the required KV loads to finish.
"""
import threading
import time
from dataclasses import dataclass
from typing import Optional
from concurrent.futures import ThreadPoolExecutor, Future
# ============== Configuration ==============
NUM_CHUNKS = 8  # number of query chunks; chunk q produces KV block q
GPU_SLOTS = 4  # number of KV blocks that fit on the GPU at the same time

# Simulated durations, in seconds
TIME_PROJ = 0.02  # projection producing one KV block
TIME_COMPUTE_BLOCK = 0.10  # attention of one query chunk against one KV block
TIME_REDUCE = 0.03  # one pairwise reduce of two partial results
TIME_TRANSFER = 0.08  # moving one KV block between CPU and GPU (load or offload)
# ============== Global time base ==============
# Wall-clock time (time.time()) at which the simulation started; main() sets it.
START_TIME = None


def now() -> float:
    """Return the time elapsed since ``START_TIME`` in milliseconds.

    If ``START_TIME`` has not been set yet (e.g. a log helper is called before
    ``main()``), it is initialised to the current time first. The call then
    returns ~0 instead of failing with ``TypeError`` on ``float - None``.
    """
    global START_TIME
    if START_TIME is None:
        START_TIME = time.time()
    return (time.time() - START_TIME) * 1000


def _stamp() -> str:
    """Format the current elapsed time as a fixed-width ``[   12.3ms]`` prefix."""
    return f"[{now():7.1f}ms]"


def log_compute(msg: str):
    """Print ``msg`` tagged ``[COMPUTE]`` (compute-queue events)."""
    print(f"{_stamp()} [COMPUTE] {msg}")


def log_transfer(msg: str):
    """Print ``msg`` tagged ``[TRANSFER]`` (transfer-queue events).

    Lines are distinguished from compute-queue lines by the tag only; the
    output is not indented.
    """
    print(f"{_stamp()} [TRANSFER] {msg}")


def log_info(msg: str):
    """Print ``msg`` with only the elapsed-time prefix (general progress)."""
    print(f"{_stamp()} {msg}")
# ============== GPU slot management ==============
class GPUSlots:
    """Fixed-size table of GPU KV-cache slots with a readiness event per resident KV.

    All bookkeeping is done under ``self.lock``, so the compute and transfer
    worker threads can share one instance.
    """

    def __init__(self, num_slots: int):
        self.slots = [None] * num_slots  # slot_id -> kv_idx (None = free)
        self.kv_to_slot = {}  # kv_idx -> slot_id
        self.lock = threading.Lock()
        # kv_idx -> threading.Event, set once the KV's data is usable
        # (load finished or projection finished).
        self.kv_ready = {}

    def _release_locked(self, slot_id: int) -> None:
        """Free ``slot_id`` and drop its KV's mapping and ready event.

        Caller must already hold ``self.lock``. Freeing an empty slot is a no-op.
        """
        kv_idx = self.slots[slot_id]
        if kv_idx is None:
            return
        self.slots[slot_id] = None
        self.kv_to_slot.pop(kv_idx, None)
        self.kv_ready.pop(kv_idx, None)

    def alloc(self, kv_idx: int) -> int:
        """Place ``kv_idx`` in the lowest free slot and return that slot id.

        If ``kv_idx`` is already resident, its existing slot id is returned.
        Claiming a second slot would leak it, because ``kv_to_slot`` can only
        remember one slot per KV and ``free_kv`` frees only that one.

        Raises:
            RuntimeError: if ``kv_idx`` is not resident and every slot is occupied.
        """
        with self.lock:
            existing = self.kv_to_slot.get(kv_idx)
            if existing is not None:
                return existing
            for sid, val in enumerate(self.slots):
                if val is None:
                    self.slots[sid] = kv_idx
                    self.kv_to_slot[kv_idx] = sid
                    # Fresh, unset event; it is set later by mark_ready(kv_idx).
                    if kv_idx not in self.kv_ready:
                        self.kv_ready[kv_idx] = threading.Event()
                    return sid
            raise RuntimeError(f"No free slot for KV{kv_idx}")

    def free(self, slot_id: int):
        """Free slot ``slot_id`` (and forget its KV); no-op if it is already empty."""
        with self.lock:
            self._release_locked(slot_id)

    def free_kv(self, kv_idx: int):
        """Free whichever slot holds ``kv_idx``; no-op if it is not resident."""
        with self.lock:
            sid = self.kv_to_slot.get(kv_idx)
            if sid is not None:
                self._release_locked(sid)

    def mark_ready(self, kv_idx: int):
        """Mark ``kv_idx`` as ready (load or projection finished).

        No-op if ``kv_idx`` is not resident (no event registered).
        """
        with self.lock:
            event = self.kv_ready.get(kv_idx)
            if event is not None:
                event.set()

    def wait_ready(self, kv_idx: int):
        """Block until ``kv_idx`` has been marked ready.

        Returns immediately if ``kv_idx`` has no ready event, i.e. it was never
        allocated or has already been freed. Callers must therefore make sure
        the KV is allocated before waiting. The wait happens outside the lock,
        so ``mark_ready`` can still acquire it.
        """
        with self.lock:
            event = self.kv_ready.get(kv_idx)
        if event is not None:
            event.wait()

    def has_kv(self, kv_idx: int) -> bool:
        """Return True if ``kv_idx`` currently occupies a slot."""
        with self.lock:
            return kv_idx in self.kv_to_slot

    def state(self) -> str:
        """Render the slot table, e.g. ``[KV0][----][KV2][----]``."""
        with self.lock:
            return "[" + "][".join(
                f"KV{v}" if v is not None else "----"
                for v in self.slots
            ) + "]"
# ============== Operation execution ==============
class Executor:
    """Submit simulated operations to a compute queue and a transfer queue.

    Each queue is a ``ThreadPoolExecutor`` with a single worker. Operations on
    the same queue therefore run one after another in submission order, while
    compute and transfer operations can overlap with each other. Every
    operation method returns the ``Future`` of the submitted work.
    """

    def __init__(self, gpu: GPUSlots):
        self.gpu = gpu
        self.compute_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Compute")
        self.transfer_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Transfer")

    # ---------- compute queue ----------
    def proj_kv(self, q_idx: int) -> Future:
        """Simulate projecting chunk ``q_idx`` into KV ``q_idx``.

        The new KV takes a free GPU slot and is marked ready as soon as the
        projection ends (no transfer is involved). Resolves to the slot id.
        """
        def run_projection():
            log_compute(f"PROJ Q{q_idx}->KV{q_idx} START")
            time.sleep(TIME_PROJ)
            slot = self.gpu.alloc(q_idx)
            self.gpu.mark_ready(q_idx)
            log_compute(f"PROJ Q{q_idx}->KV{q_idx} END slot={slot} | {self.gpu.state()}")
            return slot

        return self.compute_pool.submit(run_projection)

    def compute_attn(self, q_idx: int, kv_indices: list) -> Future:
        """Simulate attention of chunk ``q_idx`` against every KV in ``kv_indices``.

        The compute worker first waits until each KV is marked ready. Because
        the compute queue has one worker, this wait also delays any compute work
        queued after it. The simulated cost is one ``TIME_COMPUTE_BLOCK`` per
        KV. Resolves to ``(q_idx, kv_indices)``.
        """
        def run_attention():
            for kv in kv_indices:
                self.gpu.wait_ready(kv)
            label = f"ATTN Q{q_idx}*KV[{','.join(str(k) for k in kv_indices)}]"
            log_compute(f"{label} START")
            time.sleep(TIME_COMPUTE_BLOCK * len(kv_indices))
            log_compute(f"{label} END")
            return (q_idx, kv_indices)

        return self.compute_pool.submit(run_attention)

    def reduce(self, q_idx: int, num_partials: int) -> Future:
        """Simulate the online-softmax merge of ``num_partials`` partial results.

        Pairwise merging of n partials takes n - 1 reduce steps; with at most
        one partial there is nothing to merge. Resolves to None.
        """
        def run_reduce():
            if num_partials > 1:
                log_compute(f"REDUCE Q{q_idx} ({num_partials} partials) START")
                time.sleep(TIME_REDUCE * (num_partials - 1))
                log_compute(f"REDUCE Q{q_idx} END")

        return self.compute_pool.submit(run_reduce)

    # ---------- transfer queue ----------
    def load_kv(self, kv_idx: int) -> Future:
        """Simulate loading KV ``kv_idx`` from CPU into a GPU slot.

        Skips the transfer if the KV is already resident. Otherwise the slot is
        claimed before the transfer starts and the KV is marked ready when it
        ends. Resolves to ``kv_idx``.
        """
        def run_load():
            if not self.gpu.has_kv(kv_idx):
                slot = self.gpu.alloc(kv_idx)
                log_transfer(f"LOAD KV{kv_idx} START -> slot{slot}")
                time.sleep(TIME_TRANSFER)
                self.gpu.mark_ready(kv_idx)
                log_transfer(f"LOAD KV{kv_idx} END | {self.gpu.state()}")
            else:
                log_transfer(f"LOAD KV{kv_idx} SKIP (already on GPU)")
            return kv_idx

        return self.transfer_pool.submit(run_load)

    def offload_kv(self, kv_idx: int) -> Future:
        """Simulate offloading KV ``kv_idx`` from GPU to CPU.

        The slot is released only after the transfer completes. Resolves to ``kv_idx``.
        """
        def run_offload():
            log_transfer(f"OFFLOAD KV{kv_idx} START")
            time.sleep(TIME_TRANSFER)
            self.gpu.free_kv(kv_idx)
            log_transfer(f"OFFLOAD KV{kv_idx} END | {self.gpu.state()}")
            return kv_idx

        return self.transfer_pool.submit(run_offload)

    def shutdown(self):
        """Shut down both queues, waiting for already-submitted work to finish."""
        for pool in (self.compute_pool, self.transfer_pool):
            pool.shutdown(wait=True)
# ============== Scheduler ==============
def _load_and_wait(exe: Executor, kv_indices: list) -> None:
    """Queue a load for every KV in ``kv_indices``, then block until all have finished."""
    futures = [exe.load_kv(kv) for kv in kv_indices]  # submit all before waiting on any
    for fut in futures:
        fut.result()


def _release(exe: Executor, kv_indices: list) -> None:
    """Free the GPU slots of ``kv_indices`` directly, without an offload transfer.

    Used for history KVs that were loaded from CPU only to be read, so no
    write-back transfer is simulated for them.
    """
    for kv in kv_indices:
        exe.gpu.free_kv(kv)


def schedule_query(exe: Executor, q_idx: int):
    """Schedule all work for query chunk ``q_idx`` and block until it is done.

    Phases:
      1. Project chunk ``q_idx`` into KV ``q_idx`` and wait for it.
      2. Compute the diagonal block (Q ``q_idx`` x KV ``q_idx``) while
         prefetching up to ``GPU_SLOTS - 1`` history KVs (KV ``q_idx`` still
         holds one slot).
      3. After the diagonal block, offload KV ``q_idx`` to free its slot. Once
         the prefetched KVs are loaded, compute attention against them.
      4. Stream the remaining history KVs in batches of up to ``GPU_SLOTS``.
         Each batch's attention finishes before the next batch is loaded.
      5. Reduce all partial attention results.
    """
    print(f"\n{'='*50}")
    log_info(f"===== Query {q_idx} START =====")
    hist_kv = list(range(q_idx))  # history KVs: 0 .. q_idx-1
    num_partials = 0

    # Phase 1: projection of the current chunk.
    exe.proj_kv(q_idx).result()

    # Phase 2: diagonal block, overlapped with prefetching history KVs.
    diag_fut = exe.compute_attn(q_idx, [q_idx])
    num_partials += 1
    prefetch_slots = min(len(hist_kv), GPU_SLOTS - 1)
    prefetch_kv = hist_kv[:prefetch_slots]
    prefetch_futs = [exe.load_kv(kv) for kv in prefetch_kv]
    diag_fut.result()

    # Phase 3: KV q_idx is no longer needed on the GPU. Its offload is queued on
    # the transfer queue behind the prefetch loads submitted above.
    offload_fut = exe.offload_kv(q_idx)
    for fut in prefetch_futs:
        fut.result()
    hist_fut = None
    if prefetch_kv:
        hist_fut = exe.compute_attn(q_idx, prefetch_kv)
        num_partials += 1
    offload_fut.result()

    # Phase 4: remaining history KVs, one batch of up to GPU_SLOTS at a time.
    remaining_kv = hist_kv[prefetch_slots:]
    computed_kv = prefetch_kv.copy()
    while remaining_kv:
        # The previous batch must finish computing before its slots are reused.
        if hist_fut is not None:
            hist_fut.result()
        _release(exe, computed_kv)

        batch_size = min(len(remaining_kv), GPU_SLOTS)
        batch_kv, remaining_kv = remaining_kv[:batch_size], remaining_kv[batch_size:]
        _load_and_wait(exe, batch_kv)

        hist_fut = exe.compute_attn(q_idx, batch_kv)
        num_partials += 1
        computed_kv = batch_kv

    # Drain the last batch and clear the GPU slots it used.
    if hist_fut is not None:
        hist_fut.result()
    _release(exe, computed_kv)

    # Phase 5: merge all partial results.
    exe.reduce(q_idx, num_partials).result()
    log_info(f"===== Query {q_idx} END =====")
def main():
    """Run the simulation for query chunks 0 .. NUM_CHUNKS-1 and print the total time.

    Resets the global time base, prints the configuration, and always shuts
    down the executor's worker pools, even if a query raises.
    """
    global START_TIME
    START_TIME = time.time()

    banner = (
        "Chunked Prefill + KV Cache Offload Simulation v2",
        f"Config: {NUM_CHUNKS} chunks, {GPU_SLOTS} GPU slots",
        f"Time: compute={TIME_COMPUTE_BLOCK}s, transfer={TIME_TRANSFER}s, reduce={TIME_REDUCE}s",
    )
    print(*banner, sep="\n")

    exe = Executor(GPUSlots(GPU_SLOTS))
    try:
        for q_idx in range(NUM_CHUNKS):
            schedule_query(exe, q_idx)
        print(f"\n{'='*50}")
        log_info(f"ALL DONE! Total: {now():.1f}ms")
    finally:
        exe.shutdown()


if __name__ == "__main__":
    main()