diff --git a/tests/test_flashinfer_merge.py b/tests/test_flashinfer_merge.py
new file mode 100644
index 0000000..7aa57a6
--- /dev/null
+++ b/tests/test_flashinfer_merge.py
@@ -0,0 +1,104 @@
+"""
+Test FlashInfer chunked attention with CPU offload.
+
+Uses single_prefill_with_kv_cache + merge_state for chunked KV processing.
+"""
+
+import torch
+import flashinfer
+
+
+# ============================================================
+# Core Functions
+# ============================================================
+
+def chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk_size, kv_chunk_size):
+    """
+    Chunked causal attention with KV on CPU.
+
+    q: [seq_q, num_heads, head_dim] on GPU
+    k_cpu, v_cpu: [seq_kv, num_kv_heads, head_dim] on CPU
+    """
+    seq_q = q.shape[0]
+    seq_kv = k_cpu.shape[0]
+    final_outputs = []
+
+    for q_start in range(0, seq_q, q_chunk_size):
+        q_end = min(q_start + q_chunk_size, seq_q)
+        q_chunk = q[q_start:q_end]
+
+        merged_output = None
+        merged_lse = None
+
+        for kv_start in range(0, seq_kv, kv_chunk_size):
+            kv_end = min(kv_start + kv_chunk_size, seq_kv)
+
+            if kv_start >= q_end:
+                continue
+
+            k_chunk = k_cpu[kv_start:kv_end].to(q.device, non_blocking=True)
+            v_chunk = v_cpu[kv_start:kv_end].to(q.device, non_blocking=True)
+
+            causal = not (kv_end <= q_start)
+            partial_out, partial_lse = flashinfer.single_prefill_with_kv_cache(
+                q_chunk, k_chunk, v_chunk,
+                causal=causal,
+                return_lse=True,
+            )
+
+            if merged_output is None:
+                merged_output, merged_lse = partial_out, partial_lse
+            else:
+                merged_output, merged_lse = flashinfer.merge_state(
+                    merged_output, merged_lse,
+                    partial_out, partial_lse,
+                )
+
+        final_outputs.append(merged_output)
+
+    return torch.cat(final_outputs, dim=0)
+
+
+# ============================================================
+# Main Test Script
+# ============================================================
+
+print("=" * 60)
+print("Testing FlashInfer chunked attention with CPU offload")
+print("=" * 60)
+
+num_heads = 32
+num_kv_heads = 8
+head_dim = 128
+
+test_configs = [
+    (32768, 8192, 8192),  # 32K tokens
+    (65536, 8192, 8192),  # 64K tokens
+    (131072, 16384, 16384),  # 128K tokens
+    # (262144, 16384, 16384),  # 256K tokens (slow)
+    # (524288, 16384, 16384),  # 512K tokens (slow)
+]
+
+for seq_len, q_chunk, kv_chunk in test_configs:
+    torch.manual_seed(42)
+
+    q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device='cuda')
+    k_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')
+    v_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')
+
+    # Chunked result
+    chunked_out = chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk, kv_chunk)
+
+    # Reference
+    k_gpu = k_cpu.to('cuda')
+    v_gpu = v_cpu.to('cuda')
+    ref_out = flashinfer.single_prefill_with_kv_cache(q, k_gpu, v_gpu, causal=True)
+
+    max_diff = (ref_out - chunked_out).abs().max().item()
+    mean_diff = (ref_out - chunked_out).abs().mean().item()
+
+    num_chunks = (seq_len + q_chunk - 1) // q_chunk
+    assert max_diff < 1e-2, f"FAILED: max_diff={max_diff:.6f}"
+    print(f"seq={seq_len//1024}K, chunks={num_chunks}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+print("\ntest_flashinfer_merge: PASSED")
diff --git a/tests/test_pinned_memory_slice.py b/tests/test_pinned_memory_slice.py
deleted file mode 100644
index d948008..0000000
--- a/tests/test_pinned_memory_slice.py
+++ /dev/null
@@ -1,70 +0,0 @@
-"""
-Test if slicing maintains pinned memory property.
-""" - -import torch - -print("=" * 60) -print("Test: Pinned Memory Property with Slicing") -print("=" * 60) - -# Create a pinned tensor with shape similar to k_cache_cpu -# [num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim] -tensor = torch.zeros(8, 16, 1024, 8, 64, dtype=torch.float16, device="cpu", pin_memory=True) - -print(f"\n1. Original tensor:") -print(f" - Shape: {tensor.shape}") -print(f" - is_pinned(): {tensor.is_pinned()}") -print(f" - is_contiguous(): {tensor.is_contiguous()}") - -# Test slicing operation (what we do in offload_slot_to_cpu) -slice_view = tensor[:, 0] # Same as k_cache_cpu[:, cpu_block_id] - -print(f"\n2. Sliced tensor [:, 0]:") -print(f" - Shape: {slice_view.shape}") -print(f" - is_pinned(): {slice_view.is_pinned()}") -print(f" - is_contiguous(): {slice_view.is_contiguous()}") - -# Test if contiguous() helps -contiguous_slice = tensor[:, 0].contiguous() - -print(f"\n3. Contiguous slice [:, 0].contiguous():") -print(f" - Shape: {contiguous_slice.shape}") -print(f" - is_pinned(): {contiguous_slice.is_pinned()}") -print(f" - is_contiguous(): {contiguous_slice.is_contiguous()}") - -# Test copy behavior -gpu_tensor = torch.zeros(8, 4, 1024, 8, 64, dtype=torch.float16, device="cuda") -gpu_slice = gpu_tensor[:, 0] - -print(f"\n4. GPU tensor slice:") -print(f" - Shape: {gpu_slice.shape}") -print(f" - is_contiguous(): {gpu_slice.is_contiguous()}") - -# Simulate the problematic copy operation -print(f"\n5. Testing copy operations:") - -# Method 1: Direct slice copy (current approach - SLOW) -slice_dst = tensor[:, 1] -print(f" Method 1 (slice view): dst.is_pinned()={slice_dst.is_pinned()}") - -# Method 2: Use contiguous destination -contiguous_dst = tensor[:, 2].contiguous() -print(f" Method 2 (contiguous): dst.is_pinned()={contiguous_dst.is_pinned()}") - -print("\n" + "=" * 60) -print("Conclusion:") -print("=" * 60) - -if not slice_view.is_pinned(): - print("❌ Slicing LOSES pinned memory property!") - print(" This causes Device-to-Pageable transfers (SLOW)") -else: - print("✓ Slicing maintains pinned memory property") - -if contiguous_slice.is_pinned(): - print("✓ .contiguous() maintains pinned memory property") -else: - print("❌ .contiguous() also loses pinned memory property") - -print("\n" + "=" * 60) diff --git a/tests/test_pinned_transfer.py b/tests/test_pinned_transfer.py deleted file mode 100644 index 937d423..0000000 --- a/tests/test_pinned_transfer.py +++ /dev/null @@ -1,124 +0,0 @@ -""" -Test D2H transfer performance with pinned vs non-contiguous memory. 
-""" - -import torch -import time - -print("=" * 60) -print("Test: D2H Transfer Performance (for nsys profiling)") -print("=" * 60) - -# Setup -num_layers = 8 -num_blocks = 16 -block_size = 1024 -num_kv_heads = 8 -head_dim = 64 - -# Allocate CPU cache (pinned) -k_cache_cpu = torch.zeros( - num_layers, num_blocks, block_size, num_kv_heads, head_dim, - dtype=torch.float16, device="cpu", pin_memory=True -) - -# Allocate GPU cache -k_cache_gpu = torch.randn( - num_layers, 4, block_size, num_kv_heads, head_dim, - dtype=torch.float16, device="cuda" -) - -# Warmup -print("\nWarmup...") -for _ in range(10): - k_cache_cpu[:, 0].copy_(k_cache_gpu[:, 0], non_blocking=True) - torch.cuda.synchronize() - -print(f"\nTensor info:") -print(f" k_cache_cpu.is_pinned(): {k_cache_cpu.is_pinned()}") -print(f" k_cache_cpu.is_contiguous(): {k_cache_cpu.is_contiguous()}") -print(f" k_cache_cpu[:, 0].is_pinned(): {k_cache_cpu[:, 0].is_pinned()}") -print(f" k_cache_cpu[:, 0].is_contiguous(): {k_cache_cpu[:, 0].is_contiguous()}") - -# Test 1: Non-contiguous slice (current approach) -print(f"\n" + "=" * 60) -print("Test 1: Non-contiguous slice copy (current approach)") -print("=" * 60) - -NUM_ITERS = 50 # Reduced for profiling - -torch.cuda.nvtx.range_push("Test1_NonContiguous") -times = [] -for i in range(NUM_ITERS): - torch.cuda.nvtx.range_push(f"D2H_NonContig_{i}") - start = time.perf_counter() - k_cache_cpu[:, i % num_blocks].copy_(k_cache_gpu[:, 0], non_blocking=True) - torch.cuda.synchronize() - times.append(time.perf_counter() - start) - torch.cuda.nvtx.range_pop() -torch.cuda.nvtx.range_pop() - -avg_time = sum(times) / len(times) -print(f"Average time: {avg_time * 1000:.3f} ms") -print(f"Bandwidth: {k_cache_gpu[:, 0].numel() * 2 / avg_time / 1e9:.2f} GB/s") - -# Test 2: Transpose to make dimension contiguous -print(f"\n" + "=" * 60) -print("Test 2: Transpose to contiguous dimension") -print("=" * 60) - -# Reshape to [num_blocks, num_layers, block_size, num_kv_heads, head_dim] -k_cache_cpu_transposed = torch.zeros( - num_blocks, num_layers, block_size, num_kv_heads, head_dim, - dtype=torch.float16, device="cpu", pin_memory=True -) - -print(f" k_cache_cpu_transposed[0].is_pinned(): {k_cache_cpu_transposed[0].is_pinned()}") -print(f" k_cache_cpu_transposed[0].is_contiguous(): {k_cache_cpu_transposed[0].is_contiguous()}") - -torch.cuda.nvtx.range_push("Test2_Contiguous") -times = [] -for i in range(NUM_ITERS): - torch.cuda.nvtx.range_push(f"D2H_Contig_{i}") - start = time.perf_counter() - k_cache_cpu_transposed[i % num_blocks].copy_(k_cache_gpu[:, 0], non_blocking=True) - torch.cuda.synchronize() - times.append(time.perf_counter() - start) - torch.cuda.nvtx.range_pop() -torch.cuda.nvtx.range_pop() - -avg_time = sum(times) / len(times) -print(f"Average time: {avg_time * 1000:.3f} ms") -print(f"Bandwidth: {k_cache_gpu[:, 0].numel() * 2 / avg_time / 1e9:.2f} GB/s") - -# Test 3: Fully contiguous buffer -print(f"\n" + "=" * 60) -print("Test 3: Fully contiguous buffer") -print("=" * 60) - -k_cache_cpu_flat = torch.zeros( - num_layers * block_size * num_kv_heads * head_dim, - dtype=torch.float16, device="cpu", pin_memory=True -) - -print(f" k_cache_cpu_flat.is_pinned(): {k_cache_cpu_flat.is_pinned()}") -print(f" k_cache_cpu_flat.is_contiguous(): {k_cache_cpu_flat.is_contiguous()}") - -torch.cuda.nvtx.range_push("Test3_FlatContiguous") -times = [] -for i in range(NUM_ITERS): - torch.cuda.nvtx.range_push(f"D2H_Flat_{i}") - start = time.perf_counter() - k_cache_cpu_flat.copy_(k_cache_gpu[:, 0].flatten(), non_blocking=True) 
-    torch.cuda.synchronize()
-    times.append(time.perf_counter() - start)
-    torch.cuda.nvtx.range_pop()
-torch.cuda.nvtx.range_pop()
-
-avg_time = sum(times) / len(times)
-print(f"Average time: {avg_time * 1000:.3f} ms")
-print(f"Bandwidth: {k_cache_cpu_flat.numel() * 2 / avg_time / 1e9:.2f} GB/s")
-
-print("\n" + "=" * 60)
-print("test_pinned_transfer: PASSED")
-print("=" * 60)
diff --git a/tests/test_sim.py b/tests/test_sim.py
deleted file mode 100644
index a113ddc..0000000
--- a/tests/test_sim.py
+++ /dev/null
@@ -1,286 +0,0 @@
-"""
-Chunked Prefill + KV Cache Offload Simulation v2
-
-Improvements:
-1. Simplified log output
-2. Added reduce time
-3. Compute must wait for the KV load to complete
-"""
-
-import threading
-import time
-from dataclasses import dataclass
-from typing import Optional
-from concurrent.futures import ThreadPoolExecutor, Future
-
-# ============== Configuration ==============
-NUM_CHUNKS = 8
-GPU_SLOTS = 4
-
-# Simulated times (seconds)
-TIME_COMPUTE_BLOCK = 0.10  # compute one attention block
-TIME_REDUCE = 0.03  # one reduce of two partial results
-TIME_TRANSFER = 0.08  # transfer one KV cache chunk
-TIME_PROJ = 0.02  # projection that generates the KV
-
-# ============== Global time base ==============
-START_TIME = None
-
-def now() -> float:
-    """Return the time relative to start (ms)."""
-    return (time.time() - START_TIME) * 1000
-
-def log_compute(msg: str):
-    """Compute-queue log (no indent)."""
-    print(f"[{now():7.1f}ms] [COMPUTE] {msg}")
-
-def log_transfer(msg: str):
-    """Transfer-queue log (indented)."""
-    print(f"[{now():7.1f}ms] [TRANSFER] {msg}")
-
-def log_info(msg: str):
-    """General info."""
-    print(f"[{now():7.1f}ms] {msg}")
-
-# ============== GPU slot management ==============
-class GPUSlots:
-    def __init__(self, num_slots: int):
-        self.slots = [None] * num_slots  # slot_id -> kv_idx
-        self.kv_to_slot = {}  # kv_idx -> slot_id
-        self.lock = threading.Lock()
-        # KV ready events: kv_idx -> Event
-        self.kv_ready = {}
-
-    def alloc(self, kv_idx: int) -> int:
-        with self.lock:
-            for sid, val in enumerate(self.slots):
-                if val is None:
-                    self.slots[sid] = kv_idx
-                    self.kv_to_slot[kv_idx] = sid
-                    # create the ready event
-                    if kv_idx not in self.kv_ready:
-                        self.kv_ready[kv_idx] = threading.Event()
-                    return sid
-            raise RuntimeError(f"No free slot for KV{kv_idx}")
-
-    def free(self, slot_id: int):
-        with self.lock:
-            kv_idx = self.slots[slot_id]
-            if kv_idx is not None:
-                del self.kv_to_slot[kv_idx]
-                # clear the event
-                if kv_idx in self.kv_ready:
-                    del self.kv_ready[kv_idx]
-            self.slots[slot_id] = None
-
-    def free_kv(self, kv_idx: int):
-        with self.lock:
-            if kv_idx in self.kv_to_slot:
-                sid = self.kv_to_slot[kv_idx]
-                self.slots[sid] = None
-                del self.kv_to_slot[kv_idx]
-            if kv_idx in self.kv_ready:
-                del self.kv_ready[kv_idx]
-
-    def mark_ready(self, kv_idx: int):
-        """Mark a KV as ready (load or projection finished)."""
-        with self.lock:
-            if kv_idx in self.kv_ready:
-                self.kv_ready[kv_idx].set()
-
-    def wait_ready(self, kv_idx: int):
-        """Wait until a KV is ready."""
-        with self.lock:
-            event = self.kv_ready.get(kv_idx)
-        if event:
-            event.wait()
-
-    def has_kv(self, kv_idx: int) -> bool:
-        with self.lock:
-            return kv_idx in self.kv_to_slot
-
-    def state(self) -> str:
-        with self.lock:
-            return "[" + "][".join(
-                f"KV{v}" if v is not None else "----"
-                for v in self.slots
-            ) + "]"
-
-# ============== Operation execution ==============
-class Executor:
-    def __init__(self, gpu: GPUSlots):
-        self.gpu = gpu
-        self.compute_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Compute")
-        self.transfer_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Transfer")
-
-    def proj_kv(self, q_idx: int) -> Future:
-        """Projection that generates the KV; returns a Future."""
-        def task():
-            log_compute(f"PROJ Q{q_idx}->KV{q_idx} START")
-            time.sleep(TIME_PROJ)
-            slot_id = self.gpu.alloc(q_idx)
-            self.gpu.mark_ready(q_idx)  # proj done, KV is immediately available
-            log_compute(f"PROJ Q{q_idx}->KV{q_idx} END slot={slot_id} | {self.gpu.state()}")
-            return slot_id
-        return self.compute_pool.submit(task)
-
-    def compute_attn(self, q_idx: int, kv_indices: list) -> Future:
-        """Compute an attention block; waits until all required KV are ready."""
-        def task():
-            # wait for all required KV to become ready
-            for kv_idx in kv_indices:
-                self.gpu.wait_ready(kv_idx)
-
-            kv_str = ",".join(map(str, kv_indices))
-            log_compute(f"ATTN Q{q_idx}*KV[{kv_str}] START")
-            time.sleep(TIME_COMPUTE_BLOCK * len(kv_indices))
-            log_compute(f"ATTN Q{q_idx}*KV[{kv_str}] END")
-            return (q_idx, kv_indices)
-        return self.compute_pool.submit(task)
-
-    def reduce(self, q_idx: int, num_partials: int) -> Future:
-        """Online-softmax reduce over multiple partial results."""
-        def task():
-            if num_partials <= 1:
-                return
-            # n partials need n-1 pairwise reduces
-            num_reduces = num_partials - 1
-            log_compute(f"REDUCE Q{q_idx} ({num_partials} partials) START")
-            time.sleep(TIME_REDUCE * num_reduces)
-            log_compute(f"REDUCE Q{q_idx} END")
-        return self.compute_pool.submit(task)
-
-    def load_kv(self, kv_idx: int) -> Future:
-        """Load a KV from CPU to GPU."""
-        def task():
-            if self.gpu.has_kv(kv_idx):
-                log_transfer(f"LOAD KV{kv_idx} SKIP (already on GPU)")
-                return kv_idx
-
-            slot_id = self.gpu.alloc(kv_idx)
-            log_transfer(f"LOAD KV{kv_idx} START -> slot{slot_id}")
-            time.sleep(TIME_TRANSFER)
-            self.gpu.mark_ready(kv_idx)  # load finished, mark ready
-            log_transfer(f"LOAD KV{kv_idx} END | {self.gpu.state()}")
-            return kv_idx
-        return self.transfer_pool.submit(task)
-
-    def offload_kv(self, kv_idx: int) -> Future:
-        """Offload a KV from GPU to CPU."""
-        def task():
-            log_transfer(f"OFFLOAD KV{kv_idx} START")
-            time.sleep(TIME_TRANSFER)
-            self.gpu.free_kv(kv_idx)
-            log_transfer(f"OFFLOAD KV{kv_idx} END | {self.gpu.state()}")
-            return kv_idx
-        return self.transfer_pool.submit(task)
-
-    def shutdown(self):
-        self.compute_pool.shutdown(wait=True)
-        self.transfer_pool.shutdown(wait=True)
-
-# ============== Scheduler ==============
-def schedule_query(exe: Executor, q_idx: int):
-    """Schedule the processing of a single query."""
-    print(f"\n{'='*50}")
-    log_info(f"===== Query {q_idx} START =====")
-
-    hist_kv = list(range(q_idx))  # historical KV: 0 ~ q_idx-1
-    num_partials = 0
-
-    # Phase 1: projection generates the current KV
-    proj_fut = exe.proj_kv(q_idx)
-    proj_fut.result()  # wait for completion
-
-    # Phase 2: diagonal-block compute + prefetch of historical KV in parallel
-    # start the diagonal-block compute
-    diag_fut = exe.compute_attn(q_idx, [q_idx])
-    num_partials += 1
-
-    # prefetch historical KV at the same time (at most 3 slots available)
-    prefetch_slots = min(len(hist_kv), GPU_SLOTS - 1)
-    prefetch_kv = hist_kv[:prefetch_slots]
-    prefetch_futs = [exe.load_kv(kv) for kv in prefetch_kv]
-
-    # wait for the diagonal block to finish
-    diag_fut.result()
-
-    # Phase 3: offload the current KV to free its slot
-    offload_fut = exe.offload_kv(q_idx)
-
-    # wait for the prefetch to finish, then compute this batch of historical KV
-    for f in prefetch_futs:
-        f.result()
-
-    if prefetch_kv:
-        hist_fut = exe.compute_attn(q_idx, prefetch_kv)
-        num_partials += 1
-    else:
-        hist_fut = None
-
-    # wait for the offload to finish
-    offload_fut.result()
-
-    # Phase 4: process the remaining historical KV
-    remaining_kv = hist_kv[prefetch_slots:]
-    computed_kv = prefetch_kv.copy()
-
-    while remaining_kv:
-        # wait for the previous batch to finish computing
-        if hist_fut:
-            hist_fut.result()
-
-        # free the KV that have already been computed
-        for kv in computed_kv:
-            exe.gpu.free_kv(kv)
-
-        # load the next batch
-        batch_size = min(len(remaining_kv), GPU_SLOTS)
-        batch_kv = remaining_kv[:batch_size]
-        remaining_kv = remaining_kv[batch_size:]
-
-        load_futs = [exe.load_kv(kv) for kv in batch_kv]
-        for f in load_futs:
-            f.result()
-
-        # compute this batch
-        hist_fut = exe.compute_attn(q_idx, batch_kv)
-        num_partials += 1
-        computed_kv = batch_kv
-
-    # wait for the last batch to finish computing
-    if hist_fut:
-        hist_fut.result()
-
-    # clean up the GPU
-    for kv in computed_kv:
-        exe.gpu.free_kv(kv)
-
-    # Phase 5: reduce all partial results
-    reduce_fut = exe.reduce(q_idx, num_partials)
-    reduce_fut.result()
-
-    log_info(f"===== Query {q_idx} END =====")
-
-def main():
-    global START_TIME
-    START_TIME = time.time()
-
-    print("Chunked Prefill + KV Cache Offload Simulation v2")
-    print(f"Config: {NUM_CHUNKS} chunks, {GPU_SLOTS} GPU slots")
-    print(f"Time: compute={TIME_COMPUTE_BLOCK}s, transfer={TIME_TRANSFER}s, reduce={TIME_REDUCE}s")
-
-    gpu = GPUSlots(GPU_SLOTS)
-    exe = Executor(gpu)
-
-    try:
-        for q_idx in range(NUM_CHUNKS):
-            schedule_query(exe, q_idx)
-
-        print(f"\n{'='*50}")
-        log_info(f"ALL DONE! Total: {now():.1f}ms")
-    finally:
-        exe.shutdown()
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
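
Note: the new test relies on flashinfer.merge_state to combine partial attention outputs through their log-sum-exp (LSE) values. A minimal PyTorch sketch of the same merge can serve as an independent cross-check of the chunked path; the helper name lse_merge is mine (not a FlashInfer API), and it assumes natural-log LSE with shapes o: [seq, num_heads, head_dim] and lse: [seq, num_heads], which should be checked against FlashInfer's convention before comparing numbers.

    import torch

    def lse_merge(o_a, lse_a, o_b, lse_b):
        # o_*:   [seq, num_heads, head_dim] partial attention outputs
        # lse_*: [seq, num_heads]           log-sum-exp of the softmax denominators
        lse = torch.logaddexp(lse_a, lse_b)         # combined denominator in log space
        w_a = torch.exp(lse_a - lse).unsqueeze(-1)  # weight of the first partial
        w_b = torch.exp(lse_b - lse).unsqueeze(-1)  # weight of the second partial
        return w_a * o_a + w_b * o_b, lse

Because this merge is associative, folding the per-chunk results pairwise in any order should agree with the single-pass reference up to floating-point error, which is what the chunked loop above exploits.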