- Update task_plan.md with 6-phase segmented graph implementation plan - Add findings.md documenting 7 key discoveries about current implementation - Add progress.md for tracking implementation progress - Add test_chunk_attention_graph_reuse.py validating 2-graph reuse strategy Key architecture decision: Split transformer layer into 3 segments: - PRE-ATTENTION GRAPH: norm → qkv_proj → rotary (1 graph, reused) - CHUNKED ATTENTION: H2D (eager) + flash_attn (2 graphs) + merge (eager) - POST-ATTENTION GRAPH: o_proj → norm → FFN (1 graph, reused) Total: 4 graphs serving all layers via copy_() tensor updates. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
157 lines
5.1 KiB
Python
#!/usr/bin/env python3
"""
Test: Reuse a single CUDA Graph across all layers and all chunk pairs.

Key insight: LLM layers have identical computation structure.
We only need 2 graphs (causal + non-causal), reused for all (layer, Q_i, K_j) combinations.

Usage:
    CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph_reuse.py
"""
|
||
from dataclasses import dataclass

import torch

from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
||
@dataclass
class ReusableChunkGraph:
    """A single graph that can be reused with copy_() updates."""

    # Captured CUDA graph; replay() re-runs the recorded kernels against
    # the same tensor storage every time.
    graph: torch.cuda.CUDAGraph
    # Static input buffers: callers copy_() fresh chunk data into these
    # before replaying, because the graph is bound to their addresses.
    static_q: torch.Tensor
    static_k: torch.Tensor
    static_v: torch.Tensor
    # Static outputs written by replay(); clone() before the next replay
    # overwrites their contents.
    static_output: torch.Tensor
    static_lse: torch.Tensor
def capture_reusable_graph(
    chunk_size: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    scale: float,
    device: torch.device,
    dtype: torch.dtype,
    causal: bool,
) -> ReusableChunkGraph:
    """Capture ONE graph to be reused for all chunk pairs.

    The returned graph is permanently bound to the static tensors
    allocated here: callers overwrite them in place and replay, rather
    than launching a fresh kernel per (layer, chunk) combination.
    """

    def _static_buffer(heads: int) -> torch.Tensor:
        # The storage address must stay fixed for the graph's lifetime.
        return torch.zeros(1, chunk_size, heads, head_dim, dtype=dtype, device=device)

    static_q = _static_buffer(num_heads)
    static_k = _static_buffer(num_kv_heads)
    static_v = _static_buffer(num_kv_heads)
    for buf in (static_q, static_k, static_v):
        buf.normal_()  # realistic data for warmup/capture

    # Warm up so any lazy init/autotuning happens outside the capture.
    with torch.inference_mode():
        for _ in range(3):
            flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
    torch.cuda.synchronize()

    # Record the attention kernel into a replayable CUDA graph; the
    # outputs produced during capture become the static output buffers.
    graph = torch.cuda.CUDAGraph()
    with torch.inference_mode(), torch.cuda.graph(graph):
        static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
    torch.cuda.synchronize()

    return ReusableChunkGraph(
        graph=graph,
        static_q=static_q,
        static_k=static_k,
        static_v=static_v,
        static_output=static_output,
        static_lse=static_lse,
    )
def replay_with_copy(graph: ReusableChunkGraph, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Replay graph after updating static tensors with copy_()."""
    # The graph reads fixed addresses, so new data must be copied in place.
    for dst, src in ((graph.static_q, q), (graph.static_k, k), (graph.static_v, v)):
        dst.copy_(src)
    graph.graph.replay()
    # Clone so the caller's results survive the next replay's overwrite.
    return graph.static_output.clone(), graph.static_lse.clone()
||
def main():
    """Validate that 2 captured graphs reproduce full causal attention.

    For several simulated layers, compares chunked attention (graph
    replay + LSE-based merging) against one full flash-attention call.
    Exits with status 1 if any layer exceeds the tolerance, so CI can
    detect failure (previously the script always exited 0).
    """
    device = torch.device("cuda")
    dtype = torch.bfloat16

    chunk_size = 64
    num_chunks = 4
    num_layers = 3  # Simulate multiple layers
    num_heads = 8
    num_kv_heads = 8
    head_dim = 64
    scale = 1.0 / (head_dim ** 0.5)
    seq_len = chunk_size * num_chunks
    # Single pass/fail threshold (bf16 accumulation noise across merges);
    # previously duplicated as a magic 1e-2 in two places.
    tol = 1e-2

    print(f"Device: {torch.cuda.get_device_name()}")
    print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}, Num layers: {num_layers}")
    print(f"Only 2 graphs (causal + non-causal) for ALL layer × chunk combinations")

    # Capture only 2 graphs
    graph_causal = capture_reusable_graph(
        chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=True
    )
    graph_non_causal = capture_reusable_graph(
        chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=False
    )
    print("2 graphs captured (causal + non-causal)")

    all_pass = True

    for layer_id in range(num_layers):
        # Different Q/K/V for each layer (simulating different layer outputs)
        full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device)
        full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
        full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)

        # Reference: full causal attention
        with torch.inference_mode():
            full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True)

        # Chunked with graph reuse
        chunked_output = torch.zeros_like(full_output)

        for q_idx in range(num_chunks):
            q_chunk = full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size]
            acc_out, acc_lse = None, None

            # Causal masking at chunk granularity: Q_i attends to K_0..K_i.
            for k_idx in range(q_idx + 1):
                k_chunk = full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
                v_chunk = full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size]

                # Reuse graph with copy_(); only the diagonal pair needs
                # the causal variant.
                graph = graph_causal if k_idx == q_idx else graph_non_causal
                out, lse = replay_with_copy(graph, q_chunk, k_chunk, v_chunk)

                if acc_out is None:
                    acc_out, acc_lse = out, lse
                else:
                    with torch.inference_mode():
                        acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse)

            chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out

        torch.cuda.synchronize()

        # Compare
        max_diff = (full_output - chunked_output).abs().max().item()
        status = "✅" if max_diff < tol else "❌"
        print(f"Layer {layer_id}: max_diff={max_diff:.2e} {status}")
        if max_diff >= tol:
            all_pass = False

    print("✅ PASSED - Single graph reuse across layers works!" if all_pass else "❌ FAILED")
    # Propagate failure to the shell/CI instead of always exiting 0.
    if not all_pass:
        raise SystemExit(1)
||
|
||
if __name__ == "__main__":
    # Requires a CUDA device (main() uses torch.device("cuda")).
    main()