nano-vllm/tests/test_flashinfer_merge.py
"""
Test FlashInfer chunked attention with CPU offload.
Uses single_prefill_with_kv_cache + merge_state for chunked KV processing.
"""
import torch
import flashinfer
# ============================================================
# Core Functions
# ============================================================
def chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk_size, kv_chunk_size):
    """
    Chunked causal attention with KV on CPU.
    q: [seq_q, num_heads, head_dim] on GPU
    k_cpu, v_cpu: [seq_kv, num_kv_heads, head_dim] on CPU
    """
    seq_q = q.shape[0]
    seq_kv = k_cpu.shape[0]
    final_outputs = []
    for q_start in range(0, seq_q, q_chunk_size):
        q_end = min(q_start + q_chunk_size, seq_q)
        q_chunk = q[q_start:q_end]
        merged_output = None
        merged_lse = None
        for kv_start in range(0, seq_kv, kv_chunk_size):
            kv_end = min(kv_start + kv_chunk_size, seq_kv)
            # KV chunks entirely in the future of this query chunk are fully masked out.
            if kv_start >= q_end:
                continue
            # Stream this KV chunk from CPU to GPU.
            k_chunk = k_cpu[kv_start:kv_end].to(q.device, non_blocking=True)
            v_chunk = v_cpu[kv_start:kv_end].to(q.device, non_blocking=True)
            # Chunks entirely in the past attend without a mask; only chunks that
            # overlap the query chunk's token range need the causal mask.
            causal = not (kv_end <= q_start)
            partial_out, partial_lse = flashinfer.single_prefill_with_kv_cache(
                q_chunk, k_chunk, v_chunk,
                causal=causal,
                return_lse=True,
            )
            # Merge this partial result into the running attention state.
            if merged_output is None:
                merged_output, merged_lse = partial_out, partial_lse
            else:
                merged_output, merged_lse = flashinfer.merge_state(
                    merged_output, merged_lse,
                    partial_out, partial_lse,
                )
        final_outputs.append(merged_output)
    return torch.cat(final_outputs, dim=0)
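# The merge performed by flashinfer.merge_state is the standard online-softmax
# combination of two partial attention states. A pure-PyTorch sketch of that
# combination is given below for reference only: it is not called by the test,
# and it assumes a natural-log LSE convention (FlashInfer's exact convention
# may differ), so treat it as illustrative rather than bit-exact.
def merge_state_reference(out_a, lse_a, out_b, lse_b):
    """Combine two partial outputs [seq, heads, dim] with LSEs [seq, heads]."""
    lse_max = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse_max)   # proportional to each chunk's softmax denominator
    w_b = torch.exp(lse_b - lse_max)
    denom = w_a + w_b
    merged = (out_a * w_a.unsqueeze(-1) + out_b * w_b.unsqueeze(-1)) / denom.unsqueeze(-1)
    merged_lse = lse_max + torch.log(denom)
    return merged, merged_lse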
# ============================================================
# Main Test Script
# ============================================================
print("=" * 60)
print("Testing FlashInfer chunked attention with CPU offload")
print("=" * 60)
num_heads = 32
num_kv_heads = 8
head_dim = 128
test_configs = [
    (32768, 8192, 8192),     # 32K tokens
    (65536, 8192, 8192),     # 64K tokens
    (131072, 16384, 16384),  # 128K tokens
    # (262144, 16384, 16384),  # 256K tokens (slow)
    # (524288, 16384, 16384),  # 512K tokens (slow)
]
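# Rough CPU KV footprint per config (fp16, 8 KV heads, head_dim 128):
# 8 * 128 * 2 bytes = 2 KiB per token per tensor, i.e. 4 KiB/token for K + V,
# so the 128K-token config streams ~512 MiB of KV from CPU to GPU in chunks.
# Note (assumption, not exercised by this test): non_blocking=True only
# overlaps the copy with compute when the CPU tensors are pinned
# (k_cpu.pin_memory()); with pageable memory the transfer is effectively
# synchronous.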
for seq_len, q_chunk, kv_chunk in test_configs:
    torch.manual_seed(42)
    q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device='cuda')
    k_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')
    v_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')
    # Chunked result
    chunked_out = chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk, kv_chunk)
    # Reference: full causal attention with all of KV on the GPU
    k_gpu = k_cpu.to('cuda')
    v_gpu = v_cpu.to('cuda')
    ref_out = flashinfer.single_prefill_with_kv_cache(q, k_gpu, v_gpu, causal=True)
    max_diff = (ref_out - chunked_out).abs().max().item()
    mean_diff = (ref_out - chunked_out).abs().mean().item()
    num_chunks = (seq_len + q_chunk - 1) // q_chunk
    assert max_diff < 1e-2, f"FAILED: max_diff={max_diff:.6f}"
    print(f"seq={seq_len//1024}K, chunks={num_chunks}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
print("\ntest_flashinfer_merge: PASSED")