[WIP] remove num_prefetch_blocks variable.
@@ -237,7 +237,6 @@ Warmup uses a reasonable sequence length (`block_size * 2`) instead of `max_mode
 | `max_num_seqs` | 512 | Max concurrent sequences |
 | `gpu_memory_utilization` | 0.9 | GPU memory fraction for KV cache |
 | `enforce_eager` | False | Disable CUDA graphs if True |
-| `num_prefetch_blocks` | 2 | Ring buffer pipeline depth (deprecated, uses num_gpu_blocks) |
 
 ## Benchmarking
 
@@ -109,7 +109,6 @@ def main():
         max_num_batched_tokens=max_len,
         enable_cpu_offload=True,
         num_gpu_blocks=8,  # Small GPU buffer for offload testing
-        num_prefetch_blocks=4,
     )
 
     if not args.no_sparse:
@@ -22,7 +22,6 @@ class Config:
     offload_policy: str = "lru"  # "lru", "fifo", or full class path
     num_transfer_streams: int = 4  # Number of CUDA streams for async transfers
     num_gpu_blocks: int = -1  # User-specified GPU blocks count, -1 = auto (use max available)
-    num_prefetch_blocks: int = 2  # Number of prefetch blocks for three-region GPU buffer design
 
     # Computed fields for offload (set in __post_init__ or by ModelRunner)
     num_gpu_kvcache_blocks: int = -1
@@ -58,14 +58,12 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
     from nanovllm.kvcache.policies import get_policy
 
     policy = get_policy(getattr(config, 'offload_policy', 'lru'))
-    num_prefetch_blocks = getattr(config, 'num_prefetch_blocks', 2)
 
     return HybridKVCacheManager(
         num_gpu_slots=num_gpu_blocks,
         num_cpu_blocks=num_cpu_blocks,
         block_size=config.kvcache_block_size,
         policy=policy,
-        num_prefetch_blocks=num_prefetch_blocks,
     )
 
 
@@ -86,7 +86,6 @@ class HybridKVCacheManager(KVCacheManager):
         num_cpu_blocks: int,
         block_size: int,
         policy: Optional[EvictionPolicy] = None,
-        num_prefetch_blocks: int = 2,
     ):
         """
         Initialize hybrid manager with CPU-primary ring buffer design.
@@ -99,13 +98,11 @@ class HybridKVCacheManager(KVCacheManager):
             num_cpu_blocks: Number of CPU pool blocks (primary storage)
             block_size: Tokens per block
             policy: Eviction policy (default: LRU, used for prefix cache management)
-            num_prefetch_blocks: Number of blocks for ring buffer pipeline (deprecated, ring_slots = num_gpu_slots)
         """
         self._block_size = block_size
         self.num_gpu_slots = num_gpu_slots
         self.num_cpu_blocks = num_cpu_blocks
         self.total_blocks = num_gpu_slots + num_cpu_blocks
-        self.num_prefetch_blocks = num_prefetch_blocks  # Ring buffer design parameter (deprecated)
 
         # Eviction policy
         self.policy = policy or LRUPolicy()
@@ -170,7 +167,6 @@ class HybridKVCacheManager(KVCacheManager):
             num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             dtype=dtype,
-            num_prefetch_blocks=self.num_prefetch_blocks,
         )
 
     def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
@@ -53,7 +53,6 @@ class OffloadEngine:
         head_dim: int,
         dtype: torch.dtype = torch.float16,
         num_streams: int = 4,
-        num_prefetch_blocks: int = 2,
     ):
         self.num_layers = num_layers
         self.num_gpu_blocks = num_gpu_blocks
@@ -82,8 +81,6 @@ class OffloadEngine:
         self.decode_load_slots = list(range(1, num_gpu_blocks))
         self.num_decode_load_slots = len(self.decode_load_slots)
 
-        # Keep num_prefetch_blocks for compatibility (used as chunk size for loading)
-        self.num_prefetch_blocks = num_prefetch_blocks
         self.num_gpu_slots = num_gpu_blocks  # alias
 
         logger.info(f"Unified Ring Buffer: {self.num_ring_slots} slots total")
@@ -378,9 +378,9 @@ class Attention(nn.Module):
 
         offload_engine = kvcache_manager.offload_engine
 
-        # Use prefetch_size as chunk size for double buffering
-        # This ensures both Compute and Prefetch regions can hold a full chunk
-        chunk_size = offload_engine.num_prefetch_blocks
+        # Chunk size = capacity of each double buffer region (compute/prefetch)
+        # Each region uses half of decode_load_slots
+        chunk_size = max(1, len(offload_engine.decode_load_slots) // 2)
         num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
 
         o_acc = None
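Note on the chunk-size change above: the chunk size is now derived from the ring-buffer layout itself, splitting decode_load_slots into a compute half and a prefetch half. A minimal sketch of the chunked loop this value feeds, assuming hypothetical helpers load_chunk / attend_chunk / merge (illustrative stand-ins, not the repo's actual API):

# Sketch only: chunked attention over CPU-resident KV blocks. `load_chunk`,
# `attend_chunk` and `merge` are hypothetical stand-ins for the engine's copy,
# flash_attn_with_lse and merge_attention_outputs calls.
def chunked_decode_attention(q, cpu_block_table, offload_engine,
                             load_chunk, attend_chunk, merge):
    # Half of the decode load slots per region, so one chunk can be computed
    # while the next one is prefetched into the other half.
    chunk_size = max(1, len(offload_engine.decode_load_slots) // 2)
    num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size  # ceil division

    o_acc = lse_acc = None
    for i in range(num_chunks):
        blocks = cpu_block_table[i * chunk_size:(i + 1) * chunk_size]
        k, v = load_chunk(blocks)          # copy this chunk's KV into GPU slots
        o, lse = attend_chunk(q, k, v)     # partial attention over the chunk
        if o_acc is None:
            o_acc, lse_acc = o, lse
        else:
            o_acc, lse_acc = merge(o_acc, lse_acc, o, lse)
    return o_acc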
tests/test_chunked_attention.py (new file, 169 lines)
@@ -0,0 +1,169 @@
+"""
+Test script for chunked attention correctness.
+
+Validates that chunked prefill using flash_attn_with_lse + merge_attention_outputs
+produces the same result as full flash_attn_varlen_func.
+
+Scenario: Simulating chunked prefill where we process query chunk by chunk.
+For each query chunk i:
+- KV contains all tokens from chunk 0 to chunk i
+- Previous KV chunks (0 to i-1): full attention (no causal mask)
+- Current KV chunk (i): causal attention (diagonal block)
+"""
+
+import torch
+from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_func
+from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+
+# ============================================================
+# Utility Functions
+# ============================================================
+
+def compute_chunked_prefill_for_chunk(
+    q_chunk: torch.Tensor,
+    kv_chunks: list,
+    current_chunk_idx: int,
+) -> torch.Tensor:
+    """
+    Compute attention for a single query chunk against all KV chunks up to current.
+
+    This simulates chunked prefill for query chunk `current_chunk_idx`:
+    - KV chunks 0 to current_chunk_idx-1: full attention (all previous tokens visible)
+    - KV chunk current_chunk_idx: causal attention (diagonal block)
+
+    Args:
+        q_chunk: [batch, chunk_size, nheads, headdim] - current query chunk
+        kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim]
+        current_chunk_idx: Index of the current chunk being processed
+
+    Returns:
+        out: [batch, chunk_size, nheads, headdim]
+    """
+    accumulated_o = None
+    accumulated_lse = None
+
+    for i in range(current_chunk_idx + 1):
+        k_chunk, v_chunk = kv_chunks[i]
+
+        # Previous chunks: no causal mask (all tokens visible)
+        # Current chunk (diagonal): causal mask
+        is_diagonal = (i == current_chunk_idx)
+
+        chunk_o, chunk_lse = flash_attn_with_lse(
+            q_chunk, k_chunk, v_chunk, causal=is_diagonal
+        )
+
+        if accumulated_o is None:
+            accumulated_o = chunk_o
+            accumulated_lse = chunk_lse
+        else:
+            accumulated_o, accumulated_lse = merge_attention_outputs(
+                accumulated_o, accumulated_lse,
+                chunk_o, chunk_lse
+            )
+
+    return accumulated_o
+
+
+def compute_reference_causal(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Compute reference causal attention using flash_attn_func.
+
+    Args:
+        q, k, v: [batch, seqlen, nheads, headdim]
+
+    Returns:
+        out: [batch, seqlen, nheads, headdim]
+    """
+    return flash_attn_func(q, k, v, causal=True)
+
+
+# ============================================================
+# Main Test Script
+# ============================================================
+
+torch.manual_seed(42)
+
+# Test configurations: (batch, num_chunks, chunk_size, nheads, headdim)
+TEST_CASES = [
+    (1, 4, 256, 8, 128),
+    (1, 4, 512, 8, 128),
+    (1, 8, 512, 8, 128),
+    (1, 4, 1024, 8, 128),
+    (1, 4, 1024, 32, 128),  # More heads
+    (1, 8, 256, 8, 64),  # Smaller head dim
+]
+
+DTYPES = [torch.float16, torch.bfloat16]
+
+print("=" * 80)
+print("Test: Chunked Prefill Attention vs Reference (flash_attn_func causal)")
+print("=" * 80)
+print("Simulating chunked prefill: Q chunk attends to all KV chunks up to current")
+print("  - Previous KV chunks: full attention (no causal mask)")
+print("  - Current KV chunk (diagonal): causal attention")
+print()
+
+all_passed = True
+
+for dtype in DTYPES:
+    print(f"--- dtype: {dtype} ---")
+
+    for batch, num_chunks, chunk_size, nheads, headdim in TEST_CASES:
+        seqlen = num_chunks * chunk_size
+
+        # Generate full Q, K, V
+        q_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
+        k_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
+        v_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
+
+        # Reference: full causal attention
+        out_ref = compute_reference_causal(q_full, k_full, v_full)
+
+        # Split into chunks
+        q_chunks = [q_full[:, i*chunk_size:(i+1)*chunk_size] for i in range(num_chunks)]
+        kv_chunks = [
+            (k_full[:, i*chunk_size:(i+1)*chunk_size],
+             v_full[:, i*chunk_size:(i+1)*chunk_size])
+            for i in range(num_chunks)
+        ]
+
+        # Compute chunked prefill for each query chunk
+        out_chunks = []
+        for chunk_idx in range(num_chunks):
+            chunk_out = compute_chunked_prefill_for_chunk(
+                q_chunks[chunk_idx],
+                kv_chunks,
+                chunk_idx,
+            )
+            out_chunks.append(chunk_out)
+
+        # Concatenate chunked outputs
+        out_chunked = torch.cat(out_chunks, dim=1)
+
+        # Compare
+        diff = (out_ref - out_chunked).abs()
+        max_diff = diff.max().item()
+        mean_diff = diff.mean().item()
+
+        # Tolerance: fp16/bf16 have limited precision
+        tol = 1e-2
+        passed = max_diff < tol
+        all_passed = all_passed and passed
+
+        status = "PASS" if passed else "FAIL"
+        print(
+            f"[{status}] seqlen={seqlen:5d} chunks={num_chunks} "
+            f"chunk_size={chunk_size:4d} heads={nheads:2d} dim={headdim:3d} "
+            f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}"
+        )
+
+    print()
+
+print("=" * 80)
+assert all_passed, "Some tests failed!"
+print("test_chunked_attention: PASSED")
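For readers of the test above: merge_attention_outputs combines two partial attention results computed over disjoint KV chunks using their log-sum-exp statistics. A reference-only sketch of that merge (the standard online-softmax identity, not necessarily nanovllm's exact implementation, and assuming lse has shape [batch, seqlen, nheads]):

import torch

# Reference-only sketch: merge partial outputs (o1, lse1) and (o2, lse2) from
# disjoint KV chunks. o*: [batch, seqlen, nheads, headdim]; lse*: [batch, seqlen, nheads]
# (the shape is an assumption; the repo's layout may differ).
def merge_partials(o1, lse1, o2, lse2):
    lse = torch.logaddexp(lse1, lse2)            # combined normalizer
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)     # weight of the first partial
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)     # weight of the second partial
    return w1 * o1 + w2 * o2, lse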
@@ -5,17 +5,20 @@ Demonstrates: LLM initialization, prefill execution with CPU offload enabled.
 """
 
 import os
+os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
 
 from random import randint, seed
 from nanovllm import LLM, SamplingParams
 
 
 # ============================================================
 # Configuration
 # ============================================================
 
 MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
-MAX_MODEL_LEN = 8192
-NUM_GPU_BLOCKS = 4
-INPUT_LEN = 4096
+MAX_MODEL_LEN = 32 * 1024
+NUM_GPU_BLOCKS = 2
+INPUT_LEN = 16 * 1024
 
 # ============================================================
 # Main Test Script
@@ -28,6 +31,7 @@ llm = LLM(
     max_model_len=MAX_MODEL_LEN,
     max_num_batched_tokens=MAX_MODEL_LEN,
     enable_cpu_offload=True,
+    kvcache_block_size=1024,
    num_gpu_blocks=NUM_GPU_BLOCKS,
 )
 
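With the values set above, the prompt alone spans INPUT_LEN / kvcache_block_size = 16 * 1024 / 1024 = 16 KV blocks while only NUM_GPU_BLOCKS = 2 slots live on the GPU, so almost every chunk of the prefill has to stream its KV blocks in from CPU memory; presumably that is the point of shrinking the GPU buffer in this example.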
tests/test_sim.py (new file, 286 lines)
@@ -0,0 +1,286 @@
+"""
+Chunked Prefill + KV Cache Offload Simulation v2
+
+Improvements:
+1. Simplified log output
+2. Added reduce time
+3. Compute must wait for the KV load to finish
+"""
+
+import threading
+import time
+from dataclasses import dataclass
+from typing import Optional
+from concurrent.futures import ThreadPoolExecutor, Future
+
+# ============== Configuration ==============
+NUM_CHUNKS = 8
+GPU_SLOTS = 4
+
+# Simulated times (seconds)
+TIME_COMPUTE_BLOCK = 0.10  # compute one attention block
+TIME_REDUCE = 0.03  # one pairwise reduce of two partial results
+TIME_TRANSFER = 0.08  # transfer one KV cache
+TIME_PROJ = 0.02  # projection that produces the KV
+
+# ============== Global time base ==============
+START_TIME = None
+
+def now() -> float:
+    """Return time relative to start (ms)."""
+    return (time.time() - START_TIME) * 1000
+
+def log_compute(msg: str):
+    """Compute-queue log (no indent)."""
+    print(f"[{now():7.1f}ms] [COMPUTE] {msg}")
+
+def log_transfer(msg: str):
+    """Transfer-queue log (indented)."""
+    print(f"[{now():7.1f}ms] [TRANSFER] {msg}")
+
+def log_info(msg: str):
+    """General info."""
+    print(f"[{now():7.1f}ms] {msg}")
+
+# ============== GPU slot management ==============
+class GPUSlots:
+    def __init__(self, num_slots: int):
+        self.slots = [None] * num_slots  # slot_id -> kv_idx
+        self.kv_to_slot = {}  # kv_idx -> slot_id
+        self.lock = threading.Lock()
+        # KV ready events: kv_idx -> Event
+        self.kv_ready = {}
+
+    def alloc(self, kv_idx: int) -> int:
+        with self.lock:
+            for sid, val in enumerate(self.slots):
+                if val is None:
+                    self.slots[sid] = kv_idx
+                    self.kv_to_slot[kv_idx] = sid
+                    # Create the ready event
+                    if kv_idx not in self.kv_ready:
+                        self.kv_ready[kv_idx] = threading.Event()
+                    return sid
+            raise RuntimeError(f"No free slot for KV{kv_idx}")
+
+    def free(self, slot_id: int):
+        with self.lock:
+            kv_idx = self.slots[slot_id]
+            if kv_idx is not None:
+                del self.kv_to_slot[kv_idx]
+                # Clear the event
+                if kv_idx in self.kv_ready:
+                    del self.kv_ready[kv_idx]
+            self.slots[slot_id] = None
+
+    def free_kv(self, kv_idx: int):
+        with self.lock:
+            if kv_idx in self.kv_to_slot:
+                sid = self.kv_to_slot[kv_idx]
+                self.slots[sid] = None
+                del self.kv_to_slot[kv_idx]
+            if kv_idx in self.kv_ready:
+                del self.kv_ready[kv_idx]
+
+    def mark_ready(self, kv_idx: int):
+        """Mark the KV as ready (load finished or projection finished)."""
+        with self.lock:
+            if kv_idx in self.kv_ready:
+                self.kv_ready[kv_idx].set()
+
+    def wait_ready(self, kv_idx: int):
+        """Wait until the KV is ready."""
+        with self.lock:
+            event = self.kv_ready.get(kv_idx)
+        if event:
+            event.wait()
+
+    def has_kv(self, kv_idx: int) -> bool:
+        with self.lock:
+            return kv_idx in self.kv_to_slot
+
+    def state(self) -> str:
+        with self.lock:
+            return "[" + "][".join(
+                f"KV{v}" if v is not None else "----"
+                for v in self.slots
+            ) + "]"
+
+# ============== Operation execution ==============
+class Executor:
+    def __init__(self, gpu: GPUSlots):
+        self.gpu = gpu
+        self.compute_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Compute")
+        self.transfer_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Transfer")
+
+    def proj_kv(self, q_idx: int) -> Future:
+        """Projection produces the KV; returns a Future."""
+        def task():
+            log_compute(f"PROJ Q{q_idx}->KV{q_idx} START")
+            time.sleep(TIME_PROJ)
+            slot_id = self.gpu.alloc(q_idx)
+            self.gpu.mark_ready(q_idx)  # projection done, KV immediately usable
+            log_compute(f"PROJ Q{q_idx}->KV{q_idx} END slot={slot_id} | {self.gpu.state()}")
+            return slot_id
+        return self.compute_pool.submit(task)
+
+    def compute_attn(self, q_idx: int, kv_indices: list) -> Future:
+        """Compute an attention block; waits for all required KVs to be ready."""
+        def task():
+            # Wait for every KV this block needs
+            for kv_idx in kv_indices:
+                self.gpu.wait_ready(kv_idx)
+
+            kv_str = ",".join(map(str, kv_indices))
+            log_compute(f"ATTN Q{q_idx}*KV[{kv_str}] START")
+            time.sleep(TIME_COMPUTE_BLOCK * len(kv_indices))
+            log_compute(f"ATTN Q{q_idx}*KV[{kv_str}] END")
+            return (q_idx, kv_indices)
+        return self.compute_pool.submit(task)
+
+    def reduce(self, q_idx: int, num_partials: int) -> Future:
+        """Online-softmax reduce of multiple partial results."""
+        def task():
+            if num_partials <= 1:
+                return
+            # n partials need n-1 pairwise reduces
+            num_reduces = num_partials - 1
+            log_compute(f"REDUCE Q{q_idx} ({num_partials} partials) START")
+            time.sleep(TIME_REDUCE * num_reduces)
+            log_compute(f"REDUCE Q{q_idx} END")
+        return self.compute_pool.submit(task)
+
+    def load_kv(self, kv_idx: int) -> Future:
+        """Load a KV from CPU to GPU."""
+        def task():
+            if self.gpu.has_kv(kv_idx):
+                log_transfer(f"LOAD KV{kv_idx} SKIP (already on GPU)")
+                return kv_idx
+
+            slot_id = self.gpu.alloc(kv_idx)
+            log_transfer(f"LOAD KV{kv_idx} START -> slot{slot_id}")
+            time.sleep(TIME_TRANSFER)
+            self.gpu.mark_ready(kv_idx)  # load done, mark ready
+            log_transfer(f"LOAD KV{kv_idx} END | {self.gpu.state()}")
+            return kv_idx
+        return self.transfer_pool.submit(task)
+
+    def offload_kv(self, kv_idx: int) -> Future:
+        """Offload a KV from GPU to CPU."""
+        def task():
+            log_transfer(f"OFFLOAD KV{kv_idx} START")
+            time.sleep(TIME_TRANSFER)
+            self.gpu.free_kv(kv_idx)
+            log_transfer(f"OFFLOAD KV{kv_idx} END | {self.gpu.state()}")
+            return kv_idx
+        return self.transfer_pool.submit(task)
+
+    def shutdown(self):
+        self.compute_pool.shutdown(wait=True)
+        self.transfer_pool.shutdown(wait=True)
+
+# ============== Scheduler ==============
+def schedule_query(exe: Executor, q_idx: int):
+    """Schedule the processing of a single query."""
+    print(f"\n{'='*50}")
+    log_info(f"===== Query {q_idx} START =====")
+
+    hist_kv = list(range(q_idx))  # historical KV: 0 ~ q_idx-1
+    num_partials = 0
+
+    # Phase 1: projection produces the current KV
+    proj_fut = exe.proj_kv(q_idx)
+    proj_fut.result()  # wait for completion
+
+    # Phase 2: diagonal-block compute + prefetch historical KV in parallel
+    # Start the diagonal-block compute
+    diag_fut = exe.compute_attn(q_idx, [q_idx])
+    num_partials += 1
+
+    # Prefetch historical KV at the same time (at most 3 slots available)
+    prefetch_slots = min(len(hist_kv), GPU_SLOTS - 1)
+    prefetch_kv = hist_kv[:prefetch_slots]
+    prefetch_futs = [exe.load_kv(kv) for kv in prefetch_kv]
+
+    # Wait for the diagonal block to finish
+    diag_fut.result()
+
+    # Phase 3: offload the current KV to free its slot
+    offload_fut = exe.offload_kv(q_idx)
+
+    # Wait for the prefetch, then compute this batch of historical KV
+    for f in prefetch_futs:
+        f.result()
+
+    if prefetch_kv:
+        hist_fut = exe.compute_attn(q_idx, prefetch_kv)
+        num_partials += 1
+    else:
+        hist_fut = None
+
+    # Wait for the offload to finish
+    offload_fut.result()
+
+    # Phase 4: process the remaining historical KV
+    remaining_kv = hist_kv[prefetch_slots:]
+    computed_kv = prefetch_kv.copy()
+
+    while remaining_kv:
+        # Wait for the previous batch to finish
+        if hist_fut:
+            hist_fut.result()
+
+        # Free the KVs that have been computed
+        for kv in computed_kv:
+            exe.gpu.free_kv(kv)
+
+        # Load the next batch
+        batch_size = min(len(remaining_kv), GPU_SLOTS)
+        batch_kv = remaining_kv[:batch_size]
+        remaining_kv = remaining_kv[batch_size:]
+
+        load_futs = [exe.load_kv(kv) for kv in batch_kv]
+        for f in load_futs:
+            f.result()
+
+        # Compute this batch
+        hist_fut = exe.compute_attn(q_idx, batch_kv)
+        num_partials += 1
+        computed_kv = batch_kv
+
+    # Wait for the last batch to finish
+    if hist_fut:
+        hist_fut.result()
+
+    # Clean up the GPU
+    for kv in computed_kv:
+        exe.gpu.free_kv(kv)
+
+    # Phase 5: reduce all partial results
+    reduce_fut = exe.reduce(q_idx, num_partials)
+    reduce_fut.result()
+
+    log_info(f"===== Query {q_idx} END =====")
+
+def main():
+    global START_TIME
+    START_TIME = time.time()
+
+    print("Chunked Prefill + KV Cache Offload Simulation v2")
+    print(f"Config: {NUM_CHUNKS} chunks, {GPU_SLOTS} GPU slots")
+    print(f"Time: compute={TIME_COMPUTE_BLOCK}s, transfer={TIME_TRANSFER}s, reduce={TIME_REDUCE}s")
+
+    gpu = GPUSlots(GPU_SLOTS)
+    exe = Executor(gpu)
+
+    try:
+        for q_idx in range(NUM_CHUNKS):
+            schedule_query(exe, q_idx)
+
+        print(f"\n{'='*50}")
+        log_info(f"ALL DONE! Total: {now():.1f}ms")
+    finally:
+        exe.shutdown()
+
+if __name__ == "__main__":
+    main()