🗑️ chore: clean up tests directory to essential files only

Keep only core test files:
- test_ruler.py - main RULER benchmark
- test_xattn_estimate_alignment.py - XAttn kernel validation
- utils.py - shared utilities

Remove 8 files (recoverable from git history):
- bench_estimate_block_size.py
- modeling_qwen3.py
- test_chunk_attention_graph_reuse.py
- test_cudagraph_memory.py
- test_gpuonly_density_alignment.py
- test_hierarchical_estimate.py
- test_quest_policy.py
- test_sequential.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Zijie Tian
2026-02-05 03:13:50 +08:00
parent 2b61c5ab57
commit d35dd76e09
8 changed files with 0 additions and 2510 deletions

tests/bench_estimate_block_size.py

@@ -1,314 +0,0 @@
"""
Benchmark: block_size impact on XAttention estimate phase performance.
This script tests how different block_size values affect the performance of:
1. flat_group_gemm_fuse_reshape (estimate GEMM)
2. softmax_fuse_block_sum (estimate softmax + block aggregation)
Key insight: the current select_blocks uses the global kvcache_block_size for estimation,
which may not be optimal for the Triton kernels.
"""
import sys
import os
import torch
import time
import math
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
# ============================================================
# Configuration
# ============================================================
# Test configurations
BLOCK_SIZES = [64, 128, 256, 512] # BSA optimal is 128
STRIDE = 8
NUM_WARMUP = 3
NUM_RUNS = 10
# Model dimensions (Llama-3.1-8B-Instruct)
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
# Context lengths to test
CONTEXT_LENGTHS = [16384, 32768, 65536] # 16K, 32K, 64K
# ============================================================
# Benchmark Functions
# ============================================================
def benchmark_flat_group_gemm(Q, K, stride, block_size, num_warmup=3, num_runs=10):
"""
Benchmark flat_group_gemm_fuse_reshape kernel.
Args:
Q: [batch, heads, q_len, head_dim]
K: [batch, heads, k_len, head_dim]
stride: Stride for reshape
block_size: Block size (affects alignment requirements)
Returns:
(avg_time_ms, output_tensor)
"""
q_len = Q.shape[2]
k_len = K.shape[2]
# Compute reshaped dimensions
reshaped_q_len = q_len // stride
reshaped_k_len = k_len // stride
reshaped_block_size = block_size // stride
# Warmup
for _ in range(num_warmup):
_ = flat_group_gemm_fuse_reshape(
Q, K, stride,
chunk_start=0,
chunk_end=reshaped_q_len,
is_causal=False,
)
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_runs):
output = flat_group_gemm_fuse_reshape(
Q, K, stride,
chunk_start=0,
chunk_end=reshaped_q_len,
is_causal=False,
)
torch.cuda.synchronize()
end = time.perf_counter()
avg_time_ms = (end - start) / num_runs * 1000
return avg_time_ms, output
def benchmark_softmax_fuse_block_sum(attn_weights, reshaped_block_size, num_warmup=3, num_runs=10):
"""
Benchmark softmax_fuse_block_sum kernel.
Args:
attn_weights: [batch, heads, q_len, k_len] attention weights
reshaped_block_size: Block size in reshaped space
Returns:
avg_time_ms
"""
batch_size, num_heads, q_len, k_len = attn_weights.shape
head_dim = HEAD_DIM
stride = STRIDE
norm = 1.0
# segment_size must divide k_len; use reshaped_block_size, capped at 4096 (padding below enforces divisibility)
segment_size = min(4096, reshaped_block_size)
# Ensure k_len is divisible by segment_size
if k_len % segment_size != 0:
# Pad k_len
pad_size = segment_size - (k_len % segment_size)
attn_weights = torch.nn.functional.pad(attn_weights, (0, pad_size), value=0)
k_len = attn_weights.shape[3]
# Scale factor
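# 1.4426950408889634 = log2(e); presumably the softmax kernel works in base-2 exponentials,
# and the extra /stride compensates for the stride-wise reshape of Q and K.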
scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
# Warmup
for _ in range(num_warmup):
_ = softmax_fuse_block_sum(
attn_weights,
reshaped_block_size,
segment_size,
chunk_start=0,
chunk_end=q_len,
real_q_len=q_len,
scale=scale,
is_causal=False,
)
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_runs):
output = softmax_fuse_block_sum(
attn_weights,
reshaped_block_size,
segment_size,
chunk_start=0,
chunk_end=q_len,
real_q_len=q_len,
scale=scale,
is_causal=False,
)
torch.cuda.synchronize()
end = time.perf_counter()
avg_time_ms = (end - start) / num_runs * 1000
return avg_time_ms
def run_estimate_benchmark(q_len, k_len, block_size, stride=STRIDE):
"""
Run full estimate benchmark for given configuration.
Args:
q_len: Query length
k_len: Key length (usually same as q_len for current chunk scenario)
block_size: Block size to test
stride: Stride for reshape
Returns:
dict with timing results
"""
# Create random Q and K tensors
# Shape: [batch, heads, seq_len, head_dim]
Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device="cuda")
K = torch.randn(1, NUM_HEADS, k_len, HEAD_DIM, dtype=torch.bfloat16, device="cuda")
reshaped_block_size = block_size // stride
reshaped_q_len = q_len // stride
reshaped_k_len = k_len // stride
# Benchmark GEMM
gemm_time, attn_weights = benchmark_flat_group_gemm(
Q, K, stride, block_size,
num_warmup=NUM_WARMUP, num_runs=NUM_RUNS
)
# Benchmark softmax + block sum
softmax_time = benchmark_softmax_fuse_block_sum(
attn_weights, reshaped_block_size,
num_warmup=NUM_WARMUP, num_runs=NUM_RUNS
)
# Clean up
del Q, K, attn_weights
torch.cuda.empty_cache()
return {
"q_len": q_len,
"k_len": k_len,
"block_size": block_size,
"reshaped_block_size": reshaped_block_size,
"gemm_time_ms": gemm_time,
"softmax_time_ms": softmax_time,
"total_time_ms": gemm_time + softmax_time,
}
# ============================================================
# Main Benchmark
# ============================================================
def main():
import argparse
parser = argparse.ArgumentParser(description="Benchmark block_size impact on estimate phase")
parser.add_argument("--gpu", type=int, default=0, help="GPU to use")
parser.add_argument("--ctx-len", type=int, default=None,
help="Single context length to test (default: test multiple)")
args = parser.parse_args()
# Set GPU
torch.cuda.set_device(args.gpu)
device_name = torch.cuda.get_device_name(args.gpu)
print(f"Using GPU {args.gpu}: {device_name}")
print()
# Determine context lengths to test
if args.ctx_len:
context_lengths = [args.ctx_len]
else:
context_lengths = CONTEXT_LENGTHS
print("=" * 80)
print("Benchmark: block_size impact on XAttention estimate phase")
print("=" * 80)
print(f"Configuration:")
print(f" NUM_HEADS: {NUM_HEADS}")
print(f" NUM_KV_HEADS: {NUM_KV_HEADS}")
print(f" HEAD_DIM: {HEAD_DIM}")
print(f" STRIDE: {STRIDE}")
print(f" BLOCK_SIZES: {BLOCK_SIZES}")
print(f" NUM_WARMUP: {NUM_WARMUP}")
print(f" NUM_RUNS: {NUM_RUNS}")
print()
all_results = []
for ctx_len in context_lengths:
print(f"\n{'='*80}")
print(f"Context Length: {ctx_len // 1024}K ({ctx_len} tokens)")
print(f"{'='*80}")
# Pad to alignment
alignment = STRIDE * 128 # Triton BLOCK_M requirement
padded_len = ((ctx_len + alignment - 1) // alignment) * alignment
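# For the default context lengths (all multiples of 1024 = STRIDE * 128), padded_len equals ctx_len.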
print(f"Padded to: {padded_len} tokens (alignment={alignment})")
print()
results = []
for block_size in BLOCK_SIZES:
print(f"Testing block_size={block_size} (reshaped={block_size // STRIDE})...", end=" ")
try:
result = run_estimate_benchmark(padded_len, padded_len, block_size)
results.append(result)
print(f"GEMM={result['gemm_time_ms']:.2f}ms, "
f"Softmax={result['softmax_time_ms']:.2f}ms, "
f"Total={result['total_time_ms']:.2f}ms")
except Exception as e:
print(f"ERROR: {e}")
import traceback
traceback.print_exc()
if results:
all_results.extend(results)
# Print summary table for this context length
print(f"\n--- Summary for {ctx_len // 1024}K context ---")
print(f"{'block_size':>12} {'reshaped':>10} {'GEMM (ms)':>12} {'Softmax (ms)':>14} {'Total (ms)':>12} {'Speedup':>10}")
print("-" * 74)
baseline_total = results[0]["total_time_ms"]
for r in results:
speedup = baseline_total / r["total_time_ms"]
print(f"{r['block_size']:>12} {r['reshaped_block_size']:>10} "
f"{r['gemm_time_ms']:>12.2f} {r['softmax_time_ms']:>14.2f} "
f"{r['total_time_ms']:>12.2f} {speedup:>9.2f}x")
# Final summary across all context lengths
if len(context_lengths) > 1:
print(f"\n{'='*80}")
print("OVERALL SUMMARY")
print(f"{'='*80}")
print(f"{'ctx_len':>10} {'block_size':>12} {'GEMM (ms)':>12} {'Softmax (ms)':>14} {'Total (ms)':>12}")
print("-" * 64)
for r in all_results:
print(f"{r['q_len']//1024:>9}K {r['block_size']:>12} "
f"{r['gemm_time_ms']:>12.2f} {r['softmax_time_ms']:>14.2f} "
f"{r['total_time_ms']:>12.2f}")
# Find optimal block_size for softmax
print(f"\n{'='*80}")
print("ANALYSIS: Optimal block_size for softmax_fuse_block_sum")
print(f"{'='*80}")
for ctx_len in context_lengths:
ctx_results = [r for r in all_results if r["q_len"] == ((ctx_len + STRIDE * 128 - 1) // (STRIDE * 128)) * (STRIDE * 128)]
if ctx_results:
best = min(ctx_results, key=lambda x: x["softmax_time_ms"])
worst = max(ctx_results, key=lambda x: x["softmax_time_ms"])
improvement = worst["softmax_time_ms"] / best["softmax_time_ms"]
print(f"Context {ctx_len // 1024}K:")
print(f" Best: block_size={best['block_size']} ({best['softmax_time_ms']:.2f}ms)")
print(f" Worst: block_size={worst['block_size']} ({worst['softmax_time_ms']:.2f}ms)")
print(f" Potential improvement: {improvement:.2f}x")
print("\nbench_estimate_block_size: DONE")
if __name__ == "__main__":
main()

tests/modeling_qwen3.py

@@ -1,757 +0,0 @@
"""
Custom Qwen3 implementation using only torch and transformers.
This file provides a clean reference implementation for understanding the model computation graph.
Computation Graph:
==================
Input: token_ids [batch, seq_len]
┌─────────────┐
│ Embedding │ embed_tokens: [vocab_size, hidden_size]
└─────────────┘
hidden_states [batch, seq_len, hidden_size]
┌─────────────────────────────────────────────────────────┐
│ Decoder Layer (x N) │
│ ┌───────────────────────────────────────────────────┐ │
│ │ Self Attention Block │ │
│ │ │ │
│ │ input_layernorm (RMSNorm) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ Qwen3Attention │ │ │
│ │ │ Q = q_proj(x) → q_norm → reshape │ │ │
│ │ │ K = k_proj(x) → k_norm → reshape │ │ │
│ │ │ V = v_proj(x) → reshape │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ Q, K = apply_rotary_pos_emb(Q, K, cos, sin)│ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ attn_output = attention(Q, K, V) │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ output = o_proj(attn_output) │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ hidden_states = residual + attn_output │ │
│ └───────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────┐ │
│ │ MLP Block │ │
│ │ │ │
│ │ post_attention_layernorm (RMSNorm) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ Qwen3MLP │ │ │
│ │ │ gate = gate_proj(x) │ │ │
│ │ │ up = up_proj(x) │ │ │
│ │ │ output = down_proj(silu(gate) * up) │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ hidden_states = residual + mlp_output │ │
│ └───────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘
┌─────────────┐
│ norm │ final RMSNorm
└─────────────┘
┌─────────────┐
│ lm_head │ [hidden_size, vocab_size]
└─────────────┘
logits [batch, seq_len, vocab_size]
"""
import math
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
class Qwen3RMSNorm(nn.Module):
"""RMSNorm implementation."""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
input_dtype = x.dtype
x = x.float()
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
return self.weight * x.to(input_dtype)
class Qwen3RotaryEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE)."""
def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Compute inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Input tensor [batch, seq_len, num_heads, head_dim] or similar
position_ids: Position indices [batch, seq_len]
Returns:
cos, sin: [batch, seq_len, head_dim]
"""
# inv_freq: [dim/2]
# position_ids: [batch, seq_len]
inv_freq_expanded = self.inv_freq[None, :, None].float() # [1, dim/2, 1]
position_ids_expanded = position_ids[:, None, :].float() # [batch, 1, seq_len]
# freqs: [batch, dim/2, seq_len]
freqs = inv_freq_expanded @ position_ids_expanded
# freqs: [batch, seq_len, dim/2]
freqs = freqs.transpose(1, 2)
# Duplicate for full head_dim: [batch, seq_len, dim]
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(x.dtype)
sin = emb.sin().to(x.dtype)
return cos, sin
def rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotate half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary position embeddings to Q and K.
Args:
q: [batch, num_heads, seq_len, head_dim]
k: [batch, num_kv_heads, seq_len, head_dim]
cos: [batch, seq_len, head_dim]
sin: [batch, seq_len, head_dim]
Returns:
q_embed, k_embed with same shapes as inputs
"""
# Unsqueeze for broadcasting: [batch, 1, seq_len, head_dim]
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class Qwen3Attention(nn.Module):
"""
Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support.
Data Flow:
---------
hidden_states [batch, seq_len, hidden_size]
├──► q_proj ──► q_norm ──► reshape ──► Q [batch, num_heads, seq_len, head_dim]
├──► k_proj ──► k_norm ──► reshape ──► K [batch, num_kv_heads, seq_len, head_dim]
└──► v_proj ──► reshape ──► V [batch, num_kv_heads, seq_len, head_dim]
apply_rotary_pos_emb(Q, K)
attention(Q, K, V) ──► attn_output [batch, num_heads, seq_len, head_dim]
reshape ──► o_proj ──► output [batch, seq_len, hidden_size]
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
attention_bias: bool = False,
rms_norm_eps: float = 1e-6,
layer_idx: int = 0,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_attention_heads
self.num_kv_heads = num_key_value_heads
self.head_dim = head_dim
self.num_kv_groups = num_attention_heads // num_key_value_heads
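# e.g. with 32 attention heads and 8 KV heads, each KV head serves num_kv_groups = 4 query heads (GQA)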
self.layer_idx = layer_idx
# Scaling factor
self.scaling = head_dim ** -0.5
# QKV projections
self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)
# QK normalization (Qwen3 specific)
self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
# Rotary embeddings
self.rotary_emb = Qwen3RotaryEmbedding(
head_dim,
max_position_embeddings=max_position_embeddings,
base=rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_qkv: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
"""
Args:
hidden_states: [batch, seq_len, hidden_size]
position_ids: [batch, seq_len]
attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask)
past_key_value: (k_cache, v_cache) from previous steps
use_cache: Whether to return updated cache
output_qkv: Whether to output Q, K, V tensors for debugging
Returns:
output: [batch, seq_len, hidden_size]
past_key_value: Updated cache (if use_cache=True)
qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True)
"""
batch_size, seq_len, _ = hidden_states.shape
# === QKV Projections ===
q = self.q_proj(hidden_states) # [batch, seq_len, num_heads * head_dim]
k = self.k_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
v = self.v_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
# Reshape to [batch, seq_len, num_heads, head_dim]
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
# === QK Normalization (Qwen3 specific) ===
q = self.q_norm(q)
k = self.k_norm(k)
# Transpose to [batch, num_heads, seq_len, head_dim]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# === Rotary Position Embeddings ===
cos, sin = self.rotary_emb(v, position_ids)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
# === KV Cache Update ===
if past_key_value is not None:
k_cache, v_cache = past_key_value
k = torch.cat([k_cache, k], dim=2)
v = torch.cat([v_cache, v], dim=2)
new_past_key_value = (k, v) if use_cache else None
# === Grouped Query Attention (expand KV heads if needed) ===
if self.num_kv_groups > 1:
# Repeat KV for each query group
k = k.repeat_interleave(self.num_kv_groups, dim=1)
v = v.repeat_interleave(self.num_kv_groups, dim=1)
# === Attention Computation (using SDPA for memory efficiency) ===
# Use PyTorch's scaled_dot_product_attention which can use FlashAttention backend
# is_causal only works when q_len == kv_len (prefill), not during decode
q_len, kv_len = q.shape[2], k.shape[2]
is_causal = (q_len == kv_len) and (q_len > 1)
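# During decode (q_len == 1 with a populated cache) the single query may attend to every cached
# position, so no causal mask is required.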
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=0.0,
is_causal=is_causal,
scale=self.scaling,
) # [batch, num_heads, seq_len, head_dim]
# === Output Projection ===
# Transpose back and reshape
attn_output = attn_output.transpose(1, 2).contiguous() # [batch, seq_len, num_heads, head_dim]
attn_output = attn_output.view(batch_size, seq_len, -1) # [batch, seq_len, hidden_size]
output = self.o_proj(attn_output)
# Optional QKV output for debugging
qkv_dict = None
if output_qkv:
qkv_dict = {
"q": q, # [batch, num_heads, seq_len, head_dim] (post-RoPE)
"k": k, # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded)
"v": v, # [batch, num_heads, kv_seq_len, head_dim] (expanded)
}
return output, new_past_key_value, qkv_dict
class Qwen3MLP(nn.Module):
"""
Qwen3 MLP with SwiGLU activation.
Data Flow:
---------
hidden_states [batch, seq_len, hidden_size]
├──► gate_proj ──► gate [batch, seq_len, intermediate_size]
└──► up_proj ──► up [batch, seq_len, intermediate_size]
silu(gate) * up
down_proj ──► output [batch, seq_len, hidden_size]
"""
def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = self.gate_proj(x)
up = self.up_proj(x)
return self.down_proj(F.silu(gate) * up)
class Qwen3DecoderLayer(nn.Module):
"""Single Qwen3 Decoder Layer."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
layer_idx: int = 0,
):
super().__init__()
self.layer_idx = layer_idx
# Pre-attention LayerNorm
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# Self-attention
self.self_attn = Qwen3Attention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
attention_bias=attention_bias,
rms_norm_eps=rms_norm_eps,
layer_idx=layer_idx,
)
# Post-attention LayerNorm
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# MLP
self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_qkv: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
"""
Args:
hidden_states: [batch, seq_len, hidden_size]
position_ids: [batch, seq_len]
attention_mask: Causal attention mask
past_key_value: KV cache for this layer
use_cache: Whether to return updated cache
output_qkv: Whether to output Q, K, V for debugging
Returns:
hidden_states: [batch, seq_len, hidden_size]
past_key_value: Updated cache
qkv_dict: QKV tensors (if output_qkv=True)
"""
# === Self Attention Block ===
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_output, new_past_key_value, qkv_dict = self.self_attn(
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_qkv=output_qkv,
)
hidden_states = residual + attn_output
# === MLP Block ===
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, new_past_key_value, qkv_dict
class Qwen3Model(nn.Module):
"""Qwen3 Transformer Model (without LM head)."""
def __init__(
self,
vocab_size: int,
hidden_size: int,
intermediate_size: int,
num_hidden_layers: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
):
super().__init__()
self.vocab_size = vocab_size
self.num_hidden_layers = num_hidden_layers
# Token embeddings
self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
# Decoder layers
self.layers = nn.ModuleList([
Qwen3DecoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
mlp_bias=mlp_bias,
layer_idx=i,
)
for i in range(num_hidden_layers)
])
# Final LayerNorm
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
output_qkv_layers: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
"""
Args:
input_ids: [batch, seq_len]
position_ids: [batch, seq_len]
attention_mask: [batch, seq_len] or pre-computed 4D mask
past_key_values: List of (k, v) tuples for each layer
use_cache: Whether to return new cache
output_qkv_layers: List of layer indices to output QKV for
Returns:
hidden_states: [batch, seq_len, hidden_size]
new_past_key_values: Updated cache
qkv_outputs: {layer_idx: qkv_dict}
"""
batch_size, seq_len = input_ids.shape
# Embedding
hidden_states = self.embed_tokens(input_ids)
# Position IDs
if position_ids is None:
past_len = past_key_values[0][0].shape[2] if past_key_values else 0
position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Attention mask (create causal mask if not provided)
if attention_mask is None or attention_mask.dim() == 2:
kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0)
causal_mask = torch.triu(
torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device),
diagonal=kv_seq_len - seq_len + 1,
)
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, kv_seq_len]
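# NOTE: Qwen3Attention above passes attn_mask=None to SDPA and relies on is_causal,
# so this mask is built only for completeness and is not consumed downstream.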
# Initialize cache list
new_past_key_values = [] if use_cache else None
qkv_outputs = {} if output_qkv_layers else None
# Decoder layers
for i, layer in enumerate(self.layers):
past_kv = past_key_values[i] if past_key_values else None
output_qkv = output_qkv_layers is not None and i in output_qkv_layers
hidden_states, new_kv, qkv_dict = layer(
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_value=past_kv,
use_cache=use_cache,
output_qkv=output_qkv,
)
if use_cache:
new_past_key_values.append(new_kv)
if qkv_dict is not None:
qkv_outputs[i] = qkv_dict
# Final norm
hidden_states = self.norm(hidden_states)
return hidden_states, new_past_key_values, qkv_outputs
class Qwen3ForCausalLM(nn.Module):
"""Qwen3 Model with Language Modeling head."""
def __init__(
self,
vocab_size: int,
hidden_size: int,
intermediate_size: int,
num_hidden_layers: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
tie_word_embeddings: bool = True,
):
super().__init__()
self.vocab_size = vocab_size
self.tie_word_embeddings = tie_word_embeddings
# Transformer model
self.model = Qwen3Model(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
mlp_bias=mlp_bias,
)
# LM head
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
output_qkv_layers: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
"""
Args:
input_ids: [batch, seq_len]
... (same as Qwen3Model)
Returns:
logits: [batch, seq_len, vocab_size]
past_key_values: Updated KV cache
qkv_outputs: QKV tensors for specified layers
"""
hidden_states, new_past_key_values, qkv_outputs = self.model(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_qkv_layers=output_qkv_layers,
)
logits = self.lm_head(hidden_states)
return logits, new_past_key_values, qkv_outputs
@classmethod
def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM":
"""
Load weights from a pretrained Qwen3 model.
Args:
model_path: Path to model directory containing config.json and model weights
dtype: Data type for model weights
Returns:
Initialized Qwen3ForCausalLM model
"""
import json
import os
from safetensors.torch import load_file
# Load config
config_path = os.path.join(model_path, "config.json")
with open(config_path) as f:
config = json.load(f)
# Create model
model = cls(
vocab_size=config["vocab_size"],
hidden_size=config["hidden_size"],
intermediate_size=config["intermediate_size"],
num_hidden_layers=config["num_hidden_layers"],
num_attention_heads=config["num_attention_heads"],
num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]),
max_position_embeddings=config.get("max_position_embeddings", 32768),
rope_theta=config.get("rope_theta", 10000.0),
rms_norm_eps=config.get("rms_norm_eps", 1e-6),
attention_bias=config.get("attention_bias", False),
mlp_bias=config.get("mlp_bias", False),
tie_word_embeddings=config.get("tie_word_embeddings", True),
)
# Load weights
weight_files = sorted([
f for f in os.listdir(model_path)
if f.endswith(".safetensors")
])
state_dict = {}
for wf in weight_files:
state_dict.update(load_file(os.path.join(model_path, wf)))
# Load into model
model.load_state_dict(state_dict, strict=False)
# Tie lm_head weights to embed_tokens if configured
if model.tie_word_embeddings:
model.lm_head.weight = model.model.embed_tokens.weight
model = model.to(dtype)
return model
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 32,
temperature: float = 1.0,
do_sample: bool = True,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> torch.Tensor:
"""Simple autoregressive generation."""
device = input_ids.device
batch_size, seq_len = input_ids.shape
past_key_values = None
generated = input_ids.clone()
for _ in range(max_new_tokens):
if past_key_values is None:
current_input = generated
else:
current_input = generated[:, -1:]
logits, past_key_values, _ = self(
input_ids=current_input,
past_key_values=past_key_values,
use_cache=True,
)
next_token_logits = logits[:, -1, :]
if temperature > 0 and do_sample:
next_token_logits = next_token_logits / temperature
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
generated = torch.cat([generated, next_token], dim=1)
if eos_token_id is not None and (next_token == eos_token_id).all():
break
return generated
def print_computation_graph():
"""Print the computation graph for reference."""
print(__doc__)
if __name__ == "__main__":
print_computation_graph()

tests/test_chunk_attention_graph_reuse.py

@@ -1,156 +0,0 @@
#!/usr/bin/env python3
"""
Test: Reuse a single CUDA Graph across all layers and all chunk pairs.
Key insight: all LLM layers share an identical computation structure.
We only need 2 graphs (causal + non-causal), reused for all (layer, Q_i, K_j) combinations.
Usage:
CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph_reuse.py
"""
from dataclasses import dataclass
import torch
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@dataclass
class ReusableChunkGraph:
"""A single graph that can be reused with copy_() updates."""
graph: torch.cuda.CUDAGraph
static_q: torch.Tensor
static_k: torch.Tensor
static_v: torch.Tensor
static_output: torch.Tensor
static_lse: torch.Tensor
def capture_reusable_graph(
chunk_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
scale: float,
device: torch.device,
dtype: torch.dtype,
causal: bool,
) -> ReusableChunkGraph:
"""Capture ONE graph to be reused for all chunk pairs."""
static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device)
static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
static_q.normal_()
static_k.normal_()
static_v.normal_()
# Warmup
with torch.inference_mode():
for _ in range(3):
_ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
torch.cuda.synchronize()
# Capture
graph = torch.cuda.CUDAGraph()
with torch.inference_mode():
with torch.cuda.graph(graph):
static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
torch.cuda.synchronize()
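# static_output / static_lse are the graph's output buffers captured above; every replay()
# rewrites them in place, which is why callers clone() them before merging.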
return ReusableChunkGraph(
graph=graph,
static_q=static_q,
static_k=static_k,
static_v=static_v,
static_output=static_output,
static_lse=static_lse,
)
def replay_with_copy(graph: ReusableChunkGraph, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Replay graph after updating static tensors with copy_()."""
graph.static_q.copy_(q)
graph.static_k.copy_(k)
graph.static_v.copy_(v)
graph.graph.replay()
return graph.static_output.clone(), graph.static_lse.clone()
def main():
device = torch.device("cuda")
dtype = torch.bfloat16
chunk_size = 64
num_chunks = 4
num_layers = 3 # Simulate multiple layers
num_heads = 8
num_kv_heads = 8
head_dim = 64
scale = 1.0 / (head_dim ** 0.5)
seq_len = chunk_size * num_chunks
print(f"Device: {torch.cuda.get_device_name()}")
print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}, Num layers: {num_layers}")
print(f"Only 2 graphs (causal + non-causal) for ALL layer × chunk combinations")
# Capture only 2 graphs
graph_causal = capture_reusable_graph(
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=True
)
graph_non_causal = capture_reusable_graph(
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=False
)
print("2 graphs captured (causal + non-causal)")
all_pass = True
for layer_id in range(num_layers):
# Different Q/K/V for each layer (simulating different layer outputs)
full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device)
full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
# Reference: full causal attention
with torch.inference_mode():
full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True)
# Chunked with graph reuse
chunked_output = torch.zeros_like(full_output)
for q_idx in range(num_chunks):
q_chunk = full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size]
acc_out, acc_lse = None, None
for k_idx in range(q_idx + 1):
k_chunk = full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
v_chunk = full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
# Reuse graph with copy_()
graph = graph_causal if k_idx == q_idx else graph_non_causal
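# Only the diagonal chunk pair (k_idx == q_idx) needs causal masking; strictly earlier
# K/V chunks are fully visible to this Q chunk.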
out, lse = replay_with_copy(graph, q_chunk, k_chunk, v_chunk)
if acc_out is None:
acc_out, acc_lse = out, lse
else:
with torch.inference_mode():
acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse)
chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out
torch.cuda.synchronize()
# Compare
max_diff = (full_output - chunked_output).abs().max().item()
status = "" if max_diff < 1e-2 else ""
print(f"Layer {layer_id}: max_diff={max_diff:.2e} {status}")
if max_diff >= 1e-2:
all_pass = False
print("✅ PASSED - Single graph reuse across layers works!" if all_pass else "❌ FAILED")
if __name__ == "__main__":
main()

tests/test_cudagraph_memory.py

@@ -1,357 +0,0 @@
#!/usr/bin/env python3
"""
CUDA Graph Memory Analysis Test
This script analyzes the memory overhead of CUDA Graphs at each stage:
1. Model loading
2. StaticCache allocation
3. Warmup runs
4. Graph capture
5. Graph replay
Usage:
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --model ~/models/Qwen3-0.6B
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --max-cache-len 2048
"""
import argparse
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache
def get_memory_mb():
"""Get current allocated memory in MB."""
return torch.cuda.memory_allocated() / 1024**2
def get_memory_gb():
"""Get current allocated memory in GB."""
return torch.cuda.memory_allocated() / 1024**3
def get_peak_memory_gb():
"""Get peak allocated memory in GB."""
return torch.cuda.max_memory_allocated() / 1024**3
def print_separator(title=None):
"""Print a separator line."""
if title:
print(f"\n{'=' * 70}")
print(f" {title}")
print(f"{'=' * 70}")
else:
print("-" * 70)
def test_memory_stages(model_path: str, max_cache_len: int, batch_size: int = 1):
"""
Test memory usage at each stage of CUDA Graph setup.
Args:
model_path: Path to the model
max_cache_len: Maximum cache length for StaticCache
batch_size: Batch size for inference
"""
print_separator("CUDA Graph Memory Analysis")
print(f"Model: {model_path}")
print(f"Max cache length: {max_cache_len}")
print(f"Batch size: {batch_size}")
results = {}
# Stage 0: Initial
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
results["initial"] = get_memory_mb()
# Stage 1: Load model
print_separator("Stage 1: Model Loading")
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="cuda",
trust_remote_code=True,
)
model.eval()
results["after_model"] = get_memory_mb()
model_size = results["after_model"] - results["initial"]
print(f" Memory: {results['after_model']:.0f} MB")
print(f" Model size: {model_size:.0f} MB ({model_size/1024:.2f} GB)")
config = model.config
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
# Stage 2: Allocate StaticCache
print_separator("Stage 2: StaticCache Allocation")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
static_cache = StaticCache(
config=config,
max_batch_size=batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=dtype,
)
results["after_cache"] = get_memory_mb()
cache_size = results["after_cache"] - before
print(f" Memory: {results['after_cache']:.0f} MB")
print(f" StaticCache size: {cache_size:.0f} MB")
# Calculate theoretical cache size
num_layers = config.num_hidden_layers
num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
head_dim = config.hidden_size // config.num_attention_heads
dtype_size = 2 # bfloat16
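# KV cache bytes = layers x 2 (K and V) x batch x kv_heads x cache_len x head_dim x dtype_size; divided by 1024^2 for MB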
theoretical_cache = (
num_layers * 2 * batch_size * num_kv_heads * max_cache_len * head_dim * dtype_size
) / (1024**2)
print(f" Theoretical: {theoretical_cache:.0f} MB")
print(f" Overhead: {cache_size - theoretical_cache:.0f} MB ({(cache_size/theoretical_cache - 1)*100:.1f}%)")
# Stage 3: Prepare static tensors
print_separator("Stage 3: Static Tensor Allocation")
before = get_memory_mb()
static_input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
static_position_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
static_cache_position = torch.tensor([0], dtype=torch.long, device=device)
results["after_tensors"] = get_memory_mb()
tensor_size = results["after_tensors"] - before
print(f" Memory: {results['after_tensors']:.0f} MB")
print(f" Static tensors: {tensor_size:.2f} MB (negligible)")
# Stage 4: Warmup runs
print_separator("Stage 4: Warmup Runs (3 iterations)")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
with torch.inference_mode():
for i in range(3):
_ = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
torch.cuda.synchronize()
results["after_warmup"] = get_memory_mb()
results["warmup_peak"] = get_peak_memory_gb() * 1024
warmup_size = results["after_warmup"] - before
print(f" Memory: {results['after_warmup']:.0f} MB")
print(f" Peak: {results['warmup_peak']:.0f} MB")
print(f" Warmup overhead: {warmup_size:.0f} MB")
# Stage 5: CUDA Graph capture
print_separator("Stage 5: CUDA Graph Capture")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
graph = torch.cuda.CUDAGraph()
with torch.inference_mode():
with torch.cuda.graph(graph):
outputs = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
static_logits = outputs.logits
torch.cuda.synchronize()
results["after_capture"] = get_memory_mb()
results["capture_peak"] = get_peak_memory_gb() * 1024
capture_size = results["after_capture"] - before
print(f" Memory: {results['after_capture']:.0f} MB")
print(f" Peak: {results['capture_peak']:.0f} MB")
print(f" Graph capture overhead: {capture_size:.0f} MB")
# Stage 6: Graph replay
print_separator("Stage 6: Graph Replay (10 iterations)")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
with torch.inference_mode():
for _ in range(10):
static_input_ids.fill_(1)
static_cache_position.fill_(0)
graph.replay()
torch.cuda.synchronize()
results["after_replay"] = get_memory_mb()
results["replay_peak"] = get_peak_memory_gb() * 1024
replay_change = results["after_replay"] - before
print(f" Memory: {results['after_replay']:.0f} MB")
print(f" Peak: {results['replay_peak']:.0f} MB")
print(f" Replay memory change: {replay_change:.0f} MB (should be ~0)")
# Summary
print_separator("SUMMARY")
total_overhead = results["after_capture"] - results["after_model"]
print(f"{'Stage':<25} {'Memory (MB)':>12} {'Delta (MB)':>12}")
print("-" * 50)
print(f"{'Model loaded':<25} {results['after_model']:>12.0f} {model_size:>+12.0f}")
print(f"{'StaticCache allocated':<25} {results['after_cache']:>12.0f} {cache_size:>+12.0f}")
print(f"{'After warmup':<25} {results['after_warmup']:>12.0f} {warmup_size:>+12.0f}")
print(f"{'After graph capture':<25} {results['after_capture']:>12.0f} {capture_size:>+12.0f}")
print(f"{'After graph replay':<25} {results['after_replay']:>12.0f} {replay_change:>+12.0f}")
print("-" * 50)
print(f"{'Total (excl. model)':<25} {'':<12} {total_overhead:>+12.0f}")
print_separator("KEY FINDINGS")
print(f" 1. Model size: {model_size/1024:.2f} GB")
print(f" 2. StaticCache: {cache_size:.0f} MB (main overhead, scales with cache_len)")
print(f" 3. Graph capture: {capture_size:.0f} MB (small, stores kernel sequence)")
print(f" 4. Graph replay: {replay_change:.0f} MB (zero allocation, reuses memory)")
print(f" 5. Total CUDA Graph overhead: {total_overhead:.0f} MB")
return results
def test_cache_length_scaling(model_path: str, cache_lengths: list):
"""
Test how memory scales with different cache lengths.
Args:
model_path: Path to the model
cache_lengths: List of cache lengths to test
"""
print_separator("Cache Length Scaling Test")
print(f"Model: {model_path}")
print(f"Cache lengths: {cache_lengths}")
# Load model once
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="cuda",
trust_remote_code=True,
)
model.eval()
config = model.config
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
model_mem = get_memory_mb()
results = []
for cache_len in cache_lengths:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Create cache and capture graph
static_cache = StaticCache(
config=config,
max_batch_size=1,
max_cache_len=cache_len,
device=device,
dtype=dtype,
)
static_input_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
static_position_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
static_cache_position = torch.tensor([0], dtype=torch.long, device=device)
with torch.inference_mode():
# Warmup
for _ in range(3):
_ = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
torch.cuda.synchronize()
# Capture
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
outputs = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
torch.cuda.synchronize()
total_mem = get_memory_mb()
overhead = total_mem - model_mem
results.append((cache_len, total_mem, overhead))
del static_cache, graph
torch.cuda.empty_cache()
# Print results
print()
print(f"{'Cache Length':>12} | {'Total (MB)':>12} | {'Overhead (MB)':>14} | {'Per 1K tokens':>14}")
print("-" * 60)
for cache_len, total, overhead in results:
per_1k = overhead / (cache_len / 1000)
print(f"{cache_len:>12} | {total:>12.0f} | {overhead:>14.0f} | {per_1k:>14.1f}")
return results
def main():
parser = argparse.ArgumentParser(description="CUDA Graph Memory Analysis")
parser.add_argument(
"--model",
type=str,
default="~/models/Qwen3-4B-Instruct-2507",
help="Model path",
)
parser.add_argument(
"--max-cache-len",
type=int,
default=1024,
help="Maximum cache length",
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="Batch size",
)
parser.add_argument(
"--test-scaling",
action="store_true",
help="Test cache length scaling",
)
args = parser.parse_args()
model_path = os.path.expanduser(args.model)
if not torch.cuda.is_available():
print("CUDA is not available!")
return
print(f"Device: cuda:{torch.cuda.current_device()}")
print(f"GPU: {torch.cuda.get_device_name()}")
if args.test_scaling:
cache_lengths = [256, 512, 1024, 2048, 4096]
test_cache_length_scaling(model_path, cache_lengths)
else:
test_memory_stages(model_path, args.max_cache_len, args.batch_size)
print("\ntest_cudagraph_memory: PASSED")
if __name__ == "__main__":
main()

tests/test_gpuonly_density_alignment.py

@@ -1,149 +0,0 @@
"""
Test: GPU-only density alignment verification
Verify that the density computed by the GPU-only path in xattn_bsa.py matches an independent call to xattn_estimate.
Procedure:
1. Run GPU-only inference and save Q, K, mask, attn_sums
2. Load the saved data and call xattn_estimate independently
3. Compare the resulting mask and density
"""
import torch
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
from nanovllm.ops.xattn import xattn_estimate
# ============================================================
# Configuration
# ============================================================
DATA_PATH = "/home/zijie/Code/nano-vllm/results/mask_alignment/gpuonly_layer0.pt"
# ============================================================
# Load saved data
# ============================================================
print(f"Loading data from {DATA_PATH}")
data = torch.load(DATA_PATH, weights_only=False)
Q = data["Q"].cuda() # [1, num_heads, q_len, head_dim]
K = data["K"].cuda() # [1, num_heads, k_len, head_dim]
chunk_size = data["chunk_size"]
block_size = data["block_size"]
stride = data["stride"]
threshold = data["threshold"]
mask_saved = data["mask"] # [1, num_heads, valid_q_blocks, valid_k_blocks]
attn_sums_saved = data["attn_sums"]
q_len = data["q_len"]
k_len = data["k_len"]
valid_q_blocks = data["valid_q_blocks"]
valid_k_blocks = data["valid_k_blocks"]
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"q_len: {q_len}, k_len: {k_len}")
print(f"chunk_size: {chunk_size}, block_size: {block_size}, stride: {stride}, threshold: {threshold}")
print(f"valid_q_blocks: {valid_q_blocks}, valid_k_blocks: {valid_k_blocks}")
print(f"mask_saved shape: {mask_saved.shape}")
# ============================================================
# Call xattn_estimate independently
# ============================================================
print("\nCalling xattn_estimate independently...")
attn_sums_ext, mask_ext = xattn_estimate(
Q, K,
chunk_size=chunk_size,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
causal=True,
)
# Trim to valid blocks
mask_ext_valid = mask_ext[:, :, :valid_q_blocks, :valid_k_blocks]
attn_sums_ext_valid = attn_sums_ext[:, :, :valid_q_blocks, :valid_k_blocks]
print(f"mask_ext shape: {mask_ext.shape}")
print(f"mask_ext_valid shape: {mask_ext_valid.shape}")
# ============================================================
# Compare attn_sums
# ============================================================
print("\n" + "=" * 60)
print("Comparing attn_sums")
print("=" * 60)
attn_sums_saved_gpu = attn_sums_saved.cuda()
attn_diff = (attn_sums_ext_valid - attn_sums_saved_gpu).abs()
print(f"attn_sums max diff: {attn_diff.max().item():.6e}")
print(f"attn_sums mean diff: {attn_diff.mean().item():.6e}")
# Check if attn_sums match
attn_match = attn_diff.max().item() < 1e-4
print(f"attn_sums match: {attn_match}")
# ============================================================
# Compare mask
# ============================================================
print("\n" + "=" * 60)
print("Comparing mask")
print("=" * 60)
mask_saved_gpu = mask_saved.cuda()
mask_match = (mask_ext_valid == mask_saved_gpu).all().item()
print(f"mask exact match: {mask_match}")
if not mask_match:
diff_count = (mask_ext_valid != mask_saved_gpu).sum().item()
total_count = mask_ext_valid.numel()
print(f"mask diff count: {diff_count} / {total_count} ({diff_count/total_count*100:.2f}%)")
# ============================================================
# Compute density
# ============================================================
print("\n" + "=" * 60)
print("Comparing density")
print("=" * 60)
# Compute causal mask
q_offset_blocks = valid_k_blocks - valid_q_blocks
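# The query chunk occupies the last valid_q_blocks of the key range, so K block j is
# causally visible to Q block i when j <= i + q_offset_blocks.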
indices = torch.arange(valid_k_blocks, device=mask_ext_valid.device).unsqueeze(0)
q_indices = torch.arange(valid_q_blocks, device=mask_ext_valid.device).unsqueeze(1)
causal_mask = indices <= (q_indices + q_offset_blocks)
# Density from saved mask
total_saved = causal_mask.sum().item() * mask_saved_gpu.shape[0] * mask_saved_gpu.shape[1]
selected_saved = (mask_saved_gpu & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_saved = selected_saved / total_saved
# Density from external xattn_estimate
total_ext = causal_mask.sum().item() * mask_ext_valid.shape[0] * mask_ext_valid.shape[1]
selected_ext = (mask_ext_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_ext = selected_ext / total_ext
print(f"Saved density: {density_saved:.6f} (selected={selected_saved}, total={total_saved})")
print(f"External density: {density_ext:.6f} (selected={selected_ext}, total={total_ext})")
print(f"Density diff: {abs(density_saved - density_ext):.6f}")
# ============================================================
# Conclusion
# ============================================================
print("\n" + "=" * 60)
print("RESULT")
print("=" * 60)
if attn_match and mask_match:
print("✅ PASSED: GPU-only density matches external xattn_estimate")
else:
print("❌ FAILED: Mismatch detected")
if not attn_match:
print(" - attn_sums mismatch")
if not mask_match:
print(" - mask mismatch")

tests/test_hierarchical_estimate.py

@@ -1,442 +0,0 @@
"""
Test: Hierarchical Block Sum Estimation for XAttention
Verify that hierarchical estimation (small estimate_block_size + aggregation)
produces equivalent results to direct estimation (large block_size), while
being significantly faster.
Key changes validated:
1. Hierarchical block sum: estimate_block_size=1024 → aggregate to cpu_block_size=4096
2. Selection strategy: score + threshold (NOT mask + majority voting)
This test uses pure torch + xattn kernels, independent of nanovllm framework.
"""
import sys
import os
import torch
import math
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
# ============================================================
# Configuration
# ============================================================
# Model dimensions (Llama-3.1-8B-Instruct style)
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
STRIDE = 8
# Block sizes
CPU_BLOCK_SIZE = 4096 # External CPU block size (fixed, for overlap)
ESTIMATE_BLOCK_SIZE = 1024 # Internal estimate block size (optimized)
# Selection parameters
THRESHOLD = 0.95 # Cumulative attention threshold
# ============================================================
# Hierarchical Estimation Implementation
# ============================================================
def compute_attention_scores(Q, K_blocks, stride):
"""
Compute attention scores for Q against multiple K blocks.
Args:
Q: [1, num_heads, q_len, head_dim]
K_blocks: List of K tensors, each [1, num_heads, block_size, head_dim]
stride: Stride for reshape
Returns:
attn_scores: [1, num_heads, q_reshaped, total_k_reshaped]
"""
q_len = Q.shape[2]
q_reshaped = q_len // stride
attn_chunks = []
for K_block in K_blocks:
# flat_group_gemm_fuse_reshape
attn_chunk = flat_group_gemm_fuse_reshape(
Q, K_block, stride,
chunk_start=0,
chunk_end=q_reshaped,
is_causal=False,
)
attn_chunks.append(attn_chunk)
# Concatenate along K dimension
attn_scores = torch.cat(attn_chunks, dim=-1)
return attn_scores
def hierarchical_block_sum(
attn_scores,
estimate_block_size,
cpu_block_size,
stride,
head_dim,
):
"""
Compute hierarchical block sums: fine-grained → aggregated to CPU block level.
Args:
attn_scores: [batch, heads, q_reshaped, k_reshaped]
estimate_block_size: Small block size for efficient softmax (e.g., 1024)
cpu_block_size: External CPU block size (e.g., 4096)
stride: Stride used in reshape
head_dim: Head dimension for scale computation
Returns:
cpu_block_scores: [batch, heads, num_cpu_blocks] - attention score per CPU block
block_sums_fine: [batch, heads, q_est_blocks, k_est_blocks] - fine-grained block sums
"""
batch_size, num_heads, q_reshaped, k_reshaped = attn_scores.shape
# Compute reshaped block sizes
reshaped_est_bs = estimate_block_size // stride # 1024/8 = 128
reshaped_cpu_bs = cpu_block_size // stride # 4096/8 = 512
# Scale factor
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
# Segment size for softmax kernel
segment_size = min(4096, reshaped_est_bs)
# Step 1: Fine-grained softmax + block sum
block_sums_fine = softmax_fuse_block_sum(
attn_scores,
reshaped_est_bs,
segment_size,
chunk_start=0,
chunk_end=q_reshaped,
real_q_len=q_reshaped,
scale=scale,
is_causal=False,
)
# block_sums_fine: [batch, heads, q_est_blocks, k_est_blocks]
q_est_blocks = block_sums_fine.shape[2]
k_est_blocks = block_sums_fine.shape[3]
# Step 2: Aggregate to CPU block level
# ratio = cpu_block_size / estimate_block_size = 4
ratio = cpu_block_size // estimate_block_size
num_cpu_blocks = k_est_blocks // ratio
# Reshape and sum along K dimension
# [batch, heads, q_est, k_est] → [batch, heads, q_est, num_cpu, ratio]
block_sums_coarse = block_sums_fine.view(
batch_size, num_heads, q_est_blocks, num_cpu_blocks, ratio
).sum(dim=-1) # [batch, heads, q_est_blocks, num_cpu_blocks]
# Step 3: Sum over Q dimension (total attention from Q chunk to each K block)
cpu_block_scores = block_sums_coarse.sum(dim=2) # [batch, heads, num_cpu_blocks]
return cpu_block_scores, block_sums_fine
def direct_block_sum(
attn_scores,
cpu_block_size,
stride,
head_dim,
):
"""
Compute block sums directly with CPU block size (baseline for comparison).
Args:
attn_scores: [batch, heads, q_reshaped, k_reshaped]
cpu_block_size: Block size (e.g., 4096)
stride: Stride used in reshape
head_dim: Head dimension for scale computation
Returns:
cpu_block_scores: [batch, heads, num_cpu_blocks]
"""
batch_size, num_heads, q_reshaped, k_reshaped = attn_scores.shape
reshaped_cpu_bs = cpu_block_size // stride # 4096/8 = 512
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
segment_size = min(4096, reshaped_cpu_bs)
block_sums = softmax_fuse_block_sum(
attn_scores,
reshaped_cpu_bs,
segment_size,
chunk_start=0,
chunk_end=q_reshaped,
real_q_len=q_reshaped,
scale=scale,
is_causal=False,
)
# block_sums: [batch, heads, q_cpu_blocks, k_cpu_blocks]
# Sum over Q dimension
cpu_block_scores = block_sums.sum(dim=2) # [batch, heads, num_cpu_blocks]
return cpu_block_scores
def select_blocks_by_score(
cpu_block_scores,
threshold=0.95,
always_include_first=True,
always_include_last=True,
):
"""
Select CPU blocks based on score + threshold.
⚠️ IMPORTANT: This replaces the original mask + majority voting strategy.
This change should be documented in the final implementation.
Args:
cpu_block_scores: [batch, heads, num_cpu_blocks]
threshold: Cumulative attention threshold (e.g., 0.95)
always_include_first: Always include first block (sink)
always_include_last: Always include last block (safety)
Returns:
selected_block_ids: List of selected block indices
density: Fraction of blocks selected
"""
# Average scores across heads
scores_per_block = cpu_block_scores.mean(dim=(0, 1)) # [num_cpu_blocks]
num_blocks = scores_per_block.shape[0]
# Normalize to get attention distribution
total_score = scores_per_block.sum()
score_ratio = scores_per_block / total_score
# Sort by score (descending)
sorted_indices = torch.argsort(score_ratio, descending=True)
# Select blocks until cumulative threshold is reached
cumsum = 0.0
selected = set()
for idx in sorted_indices.tolist():
selected.add(idx)
cumsum += score_ratio[idx].item()
if cumsum >= threshold:
break
# Always include first and last blocks
if always_include_first:
selected.add(0)
if always_include_last:
selected.add(num_blocks - 1)
selected_block_ids = sorted(list(selected))
density = len(selected_block_ids) / num_blocks
return selected_block_ids, density
# ============================================================
# Test Cases
# ============================================================
def test_equivalence():
"""
Test that hierarchical estimation produces equivalent scores to direct estimation.
"""
print("=" * 60)
print("Test 1: Hierarchical vs Direct - Equivalence")
print("=" * 60)
# Create random Q and multiple K blocks
q_len = CPU_BLOCK_SIZE # 4096
num_k_blocks = 4
# Q: [1, num_heads, q_len, head_dim]
Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
# K blocks: each [1, num_heads, cpu_block_size, head_dim]
K_blocks = [
torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
for _ in range(num_k_blocks)
]
# Compute attention scores
attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
print(f"attn_scores shape: {attn_scores.shape}")
# Method 1: Hierarchical (fast)
scores_hier, _ = hierarchical_block_sum(
attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
)
print(f"scores_hier shape: {scores_hier.shape}")
# Method 2: Direct (slow)
scores_direct = direct_block_sum(
attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
)
print(f"scores_direct shape: {scores_direct.shape}")
# Compare
diff = (scores_hier - scores_direct).abs().max().item()
print(f"\nMax difference: {diff:.6f}")
# Per-block comparison
print("\nPer-block scores comparison:")
for i in range(num_k_blocks):
h_val = scores_hier[0, 0, i].item()
d_val = scores_direct[0, 0, i].item()
print(f" Block {i}: hierarchical={h_val:.4f}, direct={d_val:.4f}, diff={abs(h_val-d_val):.6f}")
passed = diff < 0.01
print(f"\nTest 1: {'PASSED' if passed else 'FAILED'}")
return passed
def test_selection():
"""
Test the score + threshold selection strategy.
"""
print("\n" + "=" * 60)
print("Test 2: Score + Threshold Selection")
print("=" * 60)
# Create Q and K blocks with varying importance
q_len = CPU_BLOCK_SIZE
num_k_blocks = 8
Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
# Create K blocks - make some more important than others
K_blocks = []
for i in range(num_k_blocks):
        # Blocks 0, 3, and 4 are made more important by scaling up their key values
importance = 2.0 if i in [0, 3, 4] else 1.0
K = torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
K = K * importance
K_blocks.append(K)
# Compute scores
attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
scores, _ = hierarchical_block_sum(
attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
)
# Print scores per block
print("\nCPU block scores (head 0):")
for i in range(num_k_blocks):
print(f" Block {i}: {scores[0, 0, i].item():.4f}")
# Select blocks with different thresholds
for thresh in [0.9, 0.95, 0.99]:
selected, density = select_blocks_by_score(scores, threshold=thresh)
print(f"\nThreshold {thresh}: selected {len(selected)}/{num_k_blocks} blocks ({density:.1%})")
print(f" Selected: {selected}")
print("\nTest 2: PASSED (visual inspection)")
return True
def test_performance():
"""
Benchmark hierarchical vs direct estimation performance.
"""
print("\n" + "=" * 60)
print("Test 3: Performance Benchmark")
print("=" * 60)
import time
NUM_WARMUP = 3
NUM_RUNS = 10
# Larger test case
q_len = CPU_BLOCK_SIZE
num_k_blocks = 16 # 64K context
Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
K_blocks = [
torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
for _ in range(num_k_blocks)
]
# Compute attention scores (shared)
attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
print(f"attn_scores shape: {attn_scores.shape}")
print(f"Context: {num_k_blocks * CPU_BLOCK_SIZE // 1024}K tokens")
# Warmup and benchmark hierarchical
for _ in range(NUM_WARMUP):
_ = hierarchical_block_sum(attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(NUM_RUNS):
_ = hierarchical_block_sum(attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
hier_time = (time.perf_counter() - start) / NUM_RUNS * 1000
# Warmup and benchmark direct
for _ in range(NUM_WARMUP):
_ = direct_block_sum(attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(NUM_RUNS):
_ = direct_block_sum(attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
direct_time = (time.perf_counter() - start) / NUM_RUNS * 1000
speedup = direct_time / hier_time
print(f"\nResults:")
print(f" Hierarchical (bs=1024): {hier_time:.2f} ms")
print(f" Direct (bs=4096): {direct_time:.2f} ms")
print(f" Speedup: {speedup:.2f}x")
passed = speedup > 5.0 # Expect at least 5x speedup
print(f"\nTest 3: {'PASSED' if passed else 'FAILED'} (speedup > 5x expected)")
return passed
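
# ------------------------------------------------------------
# Sketch of the timing pattern used in test_performance, factored into a
# reusable helper. `_cuda_timeit` is a hypothetical name introduced here
# for illustration, not an existing project utility; the tests above keep
# the pattern inline.
def _cuda_timeit(fn, num_warmup=3, num_runs=10):
    """Return the mean wall-clock time of fn() in milliseconds."""
    import time as _time
    for _ in range(num_warmup):
        fn()  # warm up: trigger kernel compilation/autotuning outside the timed region
    torch.cuda.synchronize()
    start = _time.perf_counter()
    for _ in range(num_runs):
        fn()
    torch.cuda.synchronize()  # drain queued CUDA work before stopping the clock
    return (_time.perf_counter() - start) / num_runs * 1000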
# ============================================================
# Main
# ============================================================
if __name__ == "__main__":
print("=" * 60)
print("Hierarchical Block Sum Estimation Test")
print("=" * 60)
print(f"\nConfiguration:")
print(f" NUM_HEADS: {NUM_HEADS}")
print(f" NUM_KV_HEADS: {NUM_KV_HEADS}")
print(f" HEAD_DIM: {HEAD_DIM}")
print(f" STRIDE: {STRIDE}")
print(f" CPU_BLOCK_SIZE: {CPU_BLOCK_SIZE}")
print(f" ESTIMATE_BLOCK_SIZE: {ESTIMATE_BLOCK_SIZE}")
print(f" THRESHOLD: {THRESHOLD}")
print()
results = []
results.append(("Equivalence", test_equivalence()))
results.append(("Selection", test_selection()))
results.append(("Performance", test_performance()))
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for name, passed in results:
status = "PASSED" if passed else "FAILED"
print(f" {name}: {status}")
all_passed = all(p for _, p in results)
print("=" * 60)
if all_passed:
print("test_hierarchical_estimate: ALL PASSED")
sys.exit(0)
else:
print("test_hierarchical_estimate: SOME FAILED")
sys.exit(1)


@@ -1,136 +0,0 @@
"""
Test for QuestPolicy block selection with GQA (Grouped Query Attention).
Demonstrates the key limitation: scores are AVERAGED across heads,
so blocks strongly needed by one head but not others may be dropped.
This is the expected Quest behavior - not a bug.
"""
import torch
from nanovllm.kvcache.sparse import (
create_sparse_policy,
SparsePolicyType,
PolicyContext,
)
# ============================================================
# Test: Per-Head Score Averaging in GQA
# ============================================================
# Determine device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running test on device: {device}")
# Setup: 2 KV heads, 4 query heads (GQA group_size=2)
# topk=2 to make selection competitive
quest = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=2, threshold_blocks=0)
quest.initialize(
num_layers=1,
num_kv_heads=2,
head_dim=4,
num_cpu_blocks=6,
dtype=torch.float32,
    device=device,  # Metadata stored on the selected device (GPU when available, else CPU)
)
metadata = quest.metadata
def set_key(block_id, head_id, values):
"""Set both key_min and key_max to same values for deterministic scoring."""
# Values need to be on the same device as metadata
tensor = torch.tensor(values, device=device)
metadata.key_min[block_id, 0, head_id, :] = tensor
metadata.key_max[block_id, 0, head_id, :] = tensor
# ============================================================
# Design: Different heads want different blocks
# ============================================================
#
# Query = [1,1,1,1] for all heads, so score = sum(key values)
#
# Block | Head 0 | Head 1 | Average | Result
# ------|--------|--------|---------|--------
# 0 | +4 | -4 | 0 | Head0 wants, Head1 doesn't → DROPPED
# 1 | -4 | +4 | 0 | Head1 wants, Head0 doesn't → DROPPED
# 2 | +4 | +4 | +4 | Both want → SELECTED (rank 1)
# 3 | +3 | +3 | +3 | Both want → SELECTED (rank 2)
# 4 | +4 | 0 | +2 | Head0 strongly wants, Head1 neutral → rank 3
# 5 | 0 | +4 | +2 | Head1 strongly wants, Head0 neutral → rank 3
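#
# Quick arithmetic check of the table above: with query = [1,1,1,1] and
# key_min == key_max, each head's score is the sum of its key values, so
# the averaged score is the mean of the two per-head sums (the values
# below restate the table; they are not read from the policy).
_per_head_sums = {
    0: (+4.0, -4.0),
    1: (-4.0, +4.0),
    2: (+4.0, +4.0),
    3: (+3.0, +3.0),
    4: (+4.0, 0.0),
    5: (0.0, +4.0),
}
_avg = {b: (h0 + h1) / 2 for b, (h0, h1) in _per_head_sums.items()}
_expected_top2 = sorted(sorted(_avg, key=_avg.get, reverse=True)[:2])
assert _expected_top2 == [2, 3]  # should match the `selected` assertion below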
# Block 0: Head 0 strongly wants, Head 1 strongly rejects
set_key(0, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
set_key(0, 1, [-1.0, -1.0, -1.0, -1.0]) # head1: -4
# Block 1: Head 1 strongly wants, Head 0 strongly rejects
set_key(1, 0, [-1.0, -1.0, -1.0, -1.0]) # head0: -4
set_key(1, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
# Block 2: Both heads want equally (highest average)
set_key(2, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
set_key(2, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
# Block 3: Both heads want moderately
set_key(3, 0, [0.75, 0.75, 0.75, 0.75]) # head0: +3
set_key(3, 1, [0.75, 0.75, 0.75, 0.75]) # head1: +3
# Block 4: Head 0 strongly wants, Head 1 neutral
set_key(4, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
set_key(4, 1, [0.0, 0.0, 0.0, 0.0]) # head1: 0
# Block 5: Head 1 strongly wants, Head 0 neutral
set_key(5, 0, [0.0, 0.0, 0.0, 0.0]) # head0: 0
set_key(5, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
# ============================================================
# Run selection
# ============================================================
# Query on same device as metadata
query = torch.ones(1, 4, 4, device=device) # GQA: 4 query heads → 2 KV heads
ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
layer_id=0,
query=query,
is_prefill=False,
block_size=1024,
total_kv_len=6144,
)
available = list(range(6))
selected = quest.select_blocks(available, ctx)
# ============================================================
# Verify: Averaging behavior
# ============================================================
# topk=2, so only blocks 2 (+4 avg) and 3 (+3 avg) should be selected
assert len(selected) == 2, f"Expected 2 blocks, got {len(selected)}"
assert selected == [2, 3], f"Expected [2, 3], got {selected}"
# Key insight: blocks 0 and 1 have score +4 for ONE head,
# but they cancel out due to averaging with the other head's -4
assert 0 not in selected, "Block 0 should NOT be selected (head scores cancel out)"
assert 1 not in selected, "Block 1 should NOT be selected (head scores cancel out)"
# Blocks 4 and 5 have +4 for one head, 0 for other → avg=+2
# But +2 < +3 (block 3), so they don't make the top-2
assert 4 not in selected, "Block 4 avg=+2 < block 3 avg=+3"
assert 5 not in selected, "Block 5 avg=+2 < block 3 avg=+3"
print("✓ Block 2 selected: both heads want it (+4, +4) → avg=+4")
print("✓ Block 3 selected: both heads want it (+3, +3) → avg=+3")
print("✓ Block 0 NOT selected: head0=+4, head1=-4 → avg=0 (cancel out)")
print("✓ Block 1 NOT selected: head0=-4, head1=+4 → avg=0 (cancel out)")
print("✓ Block 4 NOT selected: head0=+4, head1=0 → avg=+2 (lower rank)")
print("✓ Block 5 NOT selected: head0=0, head1=+4 → avg=+2 (lower rank)")
# Verify metadata is on correct device
assert metadata.key_min.device.type == device.type, f"key_min on wrong device: {metadata.key_min.device}"
assert metadata.key_max.device.type == device.type, f"key_max on wrong device: {metadata.key_max.device}"
print(f"✓ Metadata stored on {device.type.upper()}")
print("\ntest_quest_policy: PASSED")


@@ -1,199 +0,0 @@
"""
Sequential inference test for LLM.
Tests: After completing one prompt, the system can correctly handle
a second prompt with a clean state (first prompt's KV cache deallocated).
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import argparse
from nanovllm import LLM, SamplingParams
from utils import generate_needle_prompt, check_needle_answer
def run_sequential_test(
model_path: str,
max_model_len: int,
input_len: int,
num_gpu_blocks: int = 4,
block_size: int = 1024,
enable_cpu_offload: bool = False,
verbose: bool = True,
) -> bool:
"""
Run sequential inference test with two different prompts.
Each prompt has a different needle value. Both must be retrieved correctly.
"""
if verbose:
print(f"\n{'='*60}")
print(f"Sequential Inference Test")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Max model len: {max_model_len}")
print(f"Input length: {input_len}")
print(f"Block size: {block_size}")
print(f"CPU offload: {enable_cpu_offload}")
print(f"{'='*60}\n")
# Initialize LLM once
llm_kwargs = {
"enforce_eager": True,
"max_model_len": max_model_len,
"max_num_batched_tokens": max_model_len,
"enable_cpu_offload": enable_cpu_offload,
"kvcache_block_size": block_size,
}
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm = LLM(model_path, **llm_kwargs)
sampling_params = SamplingParams(
temperature=0.6,
max_tokens=32,
)
# ============================================================
# Test 1: First prompt with needle value "1234"
# ============================================================
needle_value_1 = "1234"
if verbose:
print(f"\n[Test 1] Generating prompt with needle value: {needle_value_1}")
prompt_1, expected_1 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_1,
)
outputs_1 = llm.generate([prompt_1], sampling_params, use_tqdm=True)
output_text_1 = outputs_1[0]["text"]
passed_1 = check_needle_answer(output_text_1, expected_1)
if verbose:
print(f" Expected: {expected_1}")
print(f" Output: {output_text_1[:100]}...")
print(f" Status: {'PASSED' if passed_1 else 'FAILED'}")
# ============================================================
# Test 2: Second prompt with needle value "5678"
# ============================================================
needle_value_2 = "5678"
if verbose:
print(f"\n[Test 2] Generating prompt with needle value: {needle_value_2}")
prompt_2, expected_2 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_2,
)
outputs_2 = llm.generate([prompt_2], sampling_params, use_tqdm=True)
output_text_2 = outputs_2[0]["text"]
passed_2 = check_needle_answer(output_text_2, expected_2)
if verbose:
print(f" Expected: {expected_2}")
print(f" Output: {output_text_2[:100]}...")
print(f" Status: {'PASSED' if passed_2 else 'FAILED'}")
# ============================================================
    # Test 3: Third prompt with a fresh needle value to confirm no cross-contamination from earlier runs
# ============================================================
needle_value_3 = "9999"
if verbose:
print(f"\n[Test 3] Generating prompt with needle value: {needle_value_3}")
prompt_3, expected_3 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_3,
)
outputs_3 = llm.generate([prompt_3], sampling_params, use_tqdm=True)
output_text_3 = outputs_3[0]["text"]
passed_3 = check_needle_answer(output_text_3, expected_3)
if verbose:
print(f" Expected: {expected_3}")
print(f" Output: {output_text_3[:100]}...")
print(f" Status: {'PASSED' if passed_3 else 'FAILED'}")
# ============================================================
# Summary
# ============================================================
all_passed = passed_1 and passed_2 and passed_3
if verbose:
print(f"\n{'='*60}")
print(f"Summary")
print(f"{'='*60}")
print(f"Test 1 (needle={needle_value_1}): {'PASSED' if passed_1 else 'FAILED'}")
print(f"Test 2 (needle={needle_value_2}): {'PASSED' if passed_2 else 'FAILED'}")
print(f"Test 3 (needle={needle_value_3}): {'PASSED' if passed_3 else 'FAILED'}")
print(f"Overall: {'PASSED' if all_passed else 'FAILED'}")
print(f"{'='*60}\n")
return all_passed
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Sequential inference test")
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-0.6B/"),
help="Path to model"
)
parser.add_argument(
"--max-model-len",
type=int,
default=36 * 1024,
help="Maximum model context length"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--num-gpu-blocks",
type=int,
default=2,
help="Number of GPU blocks for CPU offload"
)
parser.add_argument(
"--block-size",
type=int,
default=1024,
help="KV cache block size"
)
parser.add_argument(
"--enable-offload",
action="store_true",
help="Enable CPU offload"
)
args = parser.parse_args()
passed = run_sequential_test(
model_path=args.model,
max_model_len=args.max_model_len,
input_len=args.input_len,
num_gpu_blocks=args.num_gpu_blocks,
block_size=args.block_size,
enable_cpu_offload=args.enable_offload,
verbose=True,
)
if passed:
print("test_sequential: PASSED")
else:
print("test_sequential: FAILED")
        raise SystemExit(1)