"""
|
|
Test FlashInfer chunked attention with CPU offload.
|
|
|
|
Uses single_prefill_with_kv_cache + merge_state for chunked KV processing.
|
|
"""
|
|
|
|
import torch
|
|
import flashinfer
|
|
|
|
|
|

# ============================================================
# Core Functions
# ============================================================

def chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk_size, kv_chunk_size):
    """
    Chunked causal attention with KV on CPU.

    q: [seq_q, num_heads, head_dim] on GPU
    k_cpu, v_cpu: [seq_kv, num_kv_heads, head_dim] on CPU
    """
    seq_q = q.shape[0]
    seq_kv = k_cpu.shape[0]
    final_outputs = []

    for q_start in range(0, seq_q, q_chunk_size):
        q_end = min(q_start + q_chunk_size, seq_q)
        q_chunk = q[q_start:q_end]

        merged_output = None
        merged_lse = None

        for kv_start in range(0, seq_kv, kv_chunk_size):
            kv_end = min(kv_start + kv_chunk_size, seq_kv)

            # Skip KV chunks that lie entirely in this query chunk's causal future.
            if kv_start >= q_end:
                continue

            # Stage this KV chunk on the GPU.
            k_chunk = k_cpu[kv_start:kv_end].to(q.device, non_blocking=True)
            v_chunk = v_cpu[kv_start:kv_end].to(q.device, non_blocking=True)

            # KV chunks entirely in the past need no mask; only the chunk that
            # overlaps the diagonal is masked causally. The query and KV chunk
            # boundaries are aligned in the configs below, so the end-aligned
            # causal mask matches the global causal mask.
            causal = not (kv_end <= q_start)
            partial_out, partial_lse = flashinfer.single_prefill_with_kv_cache(
                q_chunk, k_chunk, v_chunk,
                causal=causal,
                return_lse=True,
            )

            # Fold this partial result into the running softmax state.
            if merged_output is None:
                merged_output, merged_lse = partial_out, partial_lse
            else:
                merged_output, merged_lse = flashinfer.merge_state(
                    merged_output, merged_lse,
                    partial_out, partial_lse,
                )

        final_outputs.append(merged_output)

    return torch.cat(final_outputs, dim=0)
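

# A minimal plain-PyTorch sketch of the merge step above, for reference only: it is
# not FlashInfer API and is not used by the test. It assumes `lse` is the natural-log
# softmax normalizer with shape [seq, num_heads] and `out` has shape
# [seq, num_heads, head_dim]; the name merge_states_reference is illustrative.
def merge_states_reference(out_a, lse_a, out_b, lse_b):
    # Combined normalizer of the two partial softmaxes.
    lse_merged = torch.logaddexp(lse_a, lse_b)
    # Re-weight each partial output by its share of the combined normalizer.
    w_a = torch.exp(lse_a - lse_merged).unsqueeze(-1)
    w_b = torch.exp(lse_b - lse_merged).unsqueeze(-1)
    return out_a * w_a + out_b * w_b, lse_merged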


# ============================================================
# Main Test Script
# ============================================================

print("=" * 60)
|
|
print("Testing FlashInfer chunked attention with CPU offload")
|
|
print("=" * 60)
|
|
|
|
num_heads = 32
|
|
num_kv_heads = 8
|
|
head_dim = 128
|
|
|
|
test_configs = [
|
|
(32768, 8192, 8192), # 32K tokens
|
|
(65536, 8192, 8192), # 64K tokens
|
|
(131072, 16384, 16384), # 128K tokens
|
|
# (262144, 16384, 16384), # 256K tokens (slow)
|
|
# (524288, 16384, 16384), # 512K tokens (slow)
|
|
]
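
# Rough footprint for orientation: at fp16 with num_heads=32 and head_dim=128, the
# 131072-token q tensor is 131072 * 32 * 128 * 2 B = 1 GiB on the GPU, while each
# staged 16384-token K or V chunk with num_kv_heads=8 is 16384 * 8 * 128 * 2 B = 32 MiB,
# so the full KV cache never has to be GPU-resident.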

for seq_len, q_chunk, kv_chunk in test_configs:
    torch.manual_seed(42)

    q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device='cuda')
    k_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')
    v_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')

    # Chunked result with KV staged from CPU
    chunked_out = chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk, kv_chunk)

    # Reference: one full-attention call with KV resident on GPU
    k_gpu = k_cpu.to('cuda')
    v_gpu = v_cpu.to('cuda')
    ref_out = flashinfer.single_prefill_with_kv_cache(q, k_gpu, v_gpu, causal=True)

    max_diff = (ref_out - chunked_out).abs().max().item()
    mean_diff = (ref_out - chunked_out).abs().mean().item()

    num_chunks = (seq_len + q_chunk - 1) // q_chunk
    assert max_diff < 1e-2, f"FAILED: max_diff={max_diff:.6f}"
    print(f"seq={seq_len//1024}K, chunks={num_chunks}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

print("\ntest_flashinfer_merge: PASSED")