Compare commits

3 Commits

Author SHA1 Message Date
Zijie Tian
d35dd76e09 🗑️ chore: clean up tests directory to essential files only
Keep only core test files:
- test_ruler.py - main RULER benchmark
- test_xattn_estimate_alignment.py - XAttn kernel validation
- utils.py - shared utilities

Remove 8 files (recoverable from git history):
- bench_estimate_block_size.py
- modeling_qwen3.py
- test_chunk_attention_graph_reuse.py
- test_cudagraph_memory.py
- test_gpuonly_density_alignment.py
- test_hierarchical_estimate.py
- test_quest_policy.py
- test_sequential.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 03:13:50 +08:00
Zijie Tian
2b61c5ab57 🗑️ chore: remove test_needle* files
Remove needle tests (validation now covered by test_ruler.py):
- test_needle.py - basic needle-in-haystack test
- test_needle_ref.py - HuggingFace reference implementation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 03:11:28 +08:00
Zijie Tian
a709551072 🗑️ chore: remove redundant XAttention test files
Remove 6 obsolete test files:
- test_xattn_bsa.py - XAttn+BSA integration (covered by test_ruler)
- test_xattn_chunked.py - duplicate of test_xattn_estimate_chunked
- test_xattn_estimate_chunked.py - chunked prefill validation
- test_xattn_kernels.py - Triton kernel unit tests
- test_xattn_kv_chunking_batch.py - batch KV chunking validation
- test_chunk_attention_graph.py - superseded by graph_reuse version

Retained: test_xattn_estimate_alignment.py (critical kernel validation)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 03:11:21 +08:00
16 changed files with 0 additions and 4306 deletions

View File

@@ -1,314 +0,0 @@
"""
Benchmark: block_size impact on XAttention estimate phase performance.
This script tests how different block_size values affect the performance of:
1. flat_group_gemm_fuse_reshape (estimate GEMM)
2. softmax_fuse_block_sum (estimate softmax + block aggregation)
Key insight: The current select_blocks uses global kvcache_block_size for estimation,
which may not be optimal for the Triton kernels.
"""
import sys
import os
import torch
import time
import math
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
# ============================================================
# Configuration
# ============================================================
# Test configurations
BLOCK_SIZES = [64, 128, 256, 512] # BSA optimal is 128
STRIDE = 8
NUM_WARMUP = 3
NUM_RUNS = 10
# Model dimensions (Llama-3.1-8B-Instruct)
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
# Context lengths to test
CONTEXT_LENGTHS = [16384, 32768, 65536] # 16K, 32K, 64K
# ============================================================
# Benchmark Functions
# ============================================================
def benchmark_flat_group_gemm(Q, K, stride, block_size, num_warmup=3, num_runs=10):
"""
Benchmark flat_group_gemm_fuse_reshape kernel.
Args:
Q: [batch, heads, q_len, head_dim]
K: [batch, heads, k_len, head_dim]
stride: Stride for reshape
block_size: Block size (affects alignment requirements)
Returns:
(avg_time_ms, output_tensor)
"""
q_len = Q.shape[2]
k_len = K.shape[2]
# Compute reshaped dimensions
reshaped_q_len = q_len // stride
reshaped_k_len = k_len // stride
reshaped_block_size = block_size // stride
# Warmup
for _ in range(num_warmup):
_ = flat_group_gemm_fuse_reshape(
Q, K, stride,
chunk_start=0,
chunk_end=reshaped_q_len,
is_causal=False,
)
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_runs):
output = flat_group_gemm_fuse_reshape(
Q, K, stride,
chunk_start=0,
chunk_end=reshaped_q_len,
is_causal=False,
)
torch.cuda.synchronize()
end = time.perf_counter()
avg_time_ms = (end - start) / num_runs * 1000
return avg_time_ms, output
def benchmark_softmax_fuse_block_sum(attn_weights, reshaped_block_size, num_warmup=3, num_runs=10):
"""
Benchmark softmax_fuse_block_sum kernel.
Args:
attn_weights: [batch, heads, q_len, k_len] attention weights
reshaped_block_size: Block size in reshaped space
Returns:
avg_time_ms
"""
batch_size, num_heads, q_len, k_len = attn_weights.shape
head_dim = HEAD_DIM
stride = STRIDE
norm = 1.0
# segment_size must divide k_len and be >= reshaped_block_size
segment_size = min(4096, reshaped_block_size)
# Ensure k_len is divisible by segment_size
if k_len % segment_size != 0:
# Pad k_len
pad_size = segment_size - (k_len % segment_size)
attn_weights = torch.nn.functional.pad(attn_weights, (0, pad_size), value=0)
k_len = attn_weights.shape[3]
# Scale factor
scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
# Warmup
for _ in range(num_warmup):
_ = softmax_fuse_block_sum(
attn_weights,
reshaped_block_size,
segment_size,
chunk_start=0,
chunk_end=q_len,
real_q_len=q_len,
scale=scale,
is_causal=False,
)
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_runs):
output = softmax_fuse_block_sum(
attn_weights,
reshaped_block_size,
segment_size,
chunk_start=0,
chunk_end=q_len,
real_q_len=q_len,
scale=scale,
is_causal=False,
)
torch.cuda.synchronize()
end = time.perf_counter()
avg_time_ms = (end - start) / num_runs * 1000
return avg_time_ms
def run_estimate_benchmark(q_len, k_len, block_size, stride=STRIDE):
"""
Run full estimate benchmark for given configuration.
Args:
q_len: Query length
k_len: Key length (usually same as q_len for current chunk scenario)
block_size: Block size to test
stride: Stride for reshape
Returns:
dict with timing results
"""
# Create random Q and K tensors
# Shape: [batch, heads, seq_len, head_dim]
Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device="cuda")
K = torch.randn(1, NUM_HEADS, k_len, HEAD_DIM, dtype=torch.bfloat16, device="cuda")
reshaped_block_size = block_size // stride
reshaped_q_len = q_len // stride
reshaped_k_len = k_len // stride
# Benchmark GEMM
gemm_time, attn_weights = benchmark_flat_group_gemm(
Q, K, stride, block_size,
num_warmup=NUM_WARMUP, num_runs=NUM_RUNS
)
# Benchmark softmax + block sum
softmax_time = benchmark_softmax_fuse_block_sum(
attn_weights, reshaped_block_size,
num_warmup=NUM_WARMUP, num_runs=NUM_RUNS
)
# Clean up
del Q, K, attn_weights
torch.cuda.empty_cache()
return {
"q_len": q_len,
"k_len": k_len,
"block_size": block_size,
"reshaped_block_size": reshaped_block_size,
"gemm_time_ms": gemm_time,
"softmax_time_ms": softmax_time,
"total_time_ms": gemm_time + softmax_time,
}
# ============================================================
# Main Benchmark
# ============================================================
def main():
import argparse
parser = argparse.ArgumentParser(description="Benchmark block_size impact on estimate phase")
parser.add_argument("--gpu", type=int, default=0, help="GPU to use")
parser.add_argument("--ctx-len", type=int, default=None,
help="Single context length to test (default: test multiple)")
args = parser.parse_args()
# Set GPU
torch.cuda.set_device(args.gpu)
device_name = torch.cuda.get_device_name(args.gpu)
print(f"Using GPU {args.gpu}: {device_name}")
print()
# Determine context lengths to test
if args.ctx_len:
context_lengths = [args.ctx_len]
else:
context_lengths = CONTEXT_LENGTHS
print("=" * 80)
print("Benchmark: block_size impact on XAttention estimate phase")
print("=" * 80)
print(f"Configuration:")
print(f" NUM_HEADS: {NUM_HEADS}")
print(f" NUM_KV_HEADS: {NUM_KV_HEADS}")
print(f" HEAD_DIM: {HEAD_DIM}")
print(f" STRIDE: {STRIDE}")
print(f" BLOCK_SIZES: {BLOCK_SIZES}")
print(f" NUM_WARMUP: {NUM_WARMUP}")
print(f" NUM_RUNS: {NUM_RUNS}")
print()
all_results = []
for ctx_len in context_lengths:
print(f"\n{'='*80}")
print(f"Context Length: {ctx_len // 1024}K ({ctx_len} tokens)")
print(f"{'='*80}")
# Pad to alignment
alignment = STRIDE * 128 # Triton BLOCK_M requirement
padded_len = ((ctx_len + alignment - 1) // alignment) * alignment
print(f"Padded to: {padded_len} tokens (alignment={alignment})")
print()
results = []
for block_size in BLOCK_SIZES:
print(f"Testing block_size={block_size} (reshaped={block_size // STRIDE})...", end=" ")
try:
result = run_estimate_benchmark(padded_len, padded_len, block_size)
results.append(result)
print(f"GEMM={result['gemm_time_ms']:.2f}ms, "
f"Softmax={result['softmax_time_ms']:.2f}ms, "
f"Total={result['total_time_ms']:.2f}ms")
except Exception as e:
print(f"ERROR: {e}")
import traceback
traceback.print_exc()
if results:
all_results.extend(results)
# Print summary table for this context length
print(f"\n--- Summary for {ctx_len // 1024}K context ---")
print(f"{'block_size':>12} {'reshaped':>10} {'GEMM (ms)':>12} {'Softmax (ms)':>14} {'Total (ms)':>12} {'Speedup':>10}")
print("-" * 74)
baseline_total = results[0]["total_time_ms"]
for r in results:
speedup = baseline_total / r["total_time_ms"]
print(f"{r['block_size']:>12} {r['reshaped_block_size']:>10} "
f"{r['gemm_time_ms']:>12.2f} {r['softmax_time_ms']:>14.2f} "
f"{r['total_time_ms']:>12.2f} {speedup:>9.2f}x")
# Final summary across all context lengths
if len(context_lengths) > 1:
print(f"\n{'='*80}")
print("OVERALL SUMMARY")
print(f"{'='*80}")
print(f"{'ctx_len':>10} {'block_size':>12} {'GEMM (ms)':>12} {'Softmax (ms)':>14} {'Total (ms)':>12}")
print("-" * 64)
for r in all_results:
print(f"{r['q_len']//1024:>9}K {r['block_size']:>12} "
f"{r['gemm_time_ms']:>12.2f} {r['softmax_time_ms']:>14.2f} "
f"{r['total_time_ms']:>12.2f}")
# Find optimal block_size for softmax
print(f"\n{'='*80}")
print("ANALYSIS: Optimal block_size for softmax_fuse_block_sum")
print(f"{'='*80}")
for ctx_len in context_lengths:
ctx_results = [r for r in all_results if r["q_len"] == ((ctx_len + STRIDE * 128 - 1) // (STRIDE * 128)) * (STRIDE * 128)]
if ctx_results:
best = min(ctx_results, key=lambda x: x["softmax_time_ms"])
worst = max(ctx_results, key=lambda x: x["softmax_time_ms"])
improvement = worst["softmax_time_ms"] / best["softmax_time_ms"]
print(f"Context {ctx_len // 1024}K:")
print(f" Best: block_size={best['block_size']} ({best['softmax_time_ms']:.2f}ms)")
print(f" Worst: block_size={worst['block_size']} ({worst['softmax_time_ms']:.2f}ms)")
print(f" Potential improvement: {improvement:.2f}x")
print("\nbench_estimate_block_size: DONE")
if __name__ == "__main__":
main()

View File

@@ -1,757 +0,0 @@
"""
Custom Qwen3 implementation using only torch and transformers.
This file provides a clean reference implementation for understanding the model computation graph.
Computation Graph:
==================
Input: token_ids [batch, seq_len]
┌─────────────┐
│ Embedding │ embed_tokens: [vocab_size, hidden_size]
└─────────────┘
hidden_states [batch, seq_len, hidden_size]
┌─────────────────────────────────────────────────────────┐
│ Decoder Layer (x N) │
│ ┌───────────────────────────────────────────────────┐ │
│ │ Self Attention Block │ │
│ │ │ │
│ │ input_layernorm (RMSNorm) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ Qwen3Attention │ │ │
│ │ │ Q = q_proj(x) → q_norm → reshape │ │ │
│ │ │ K = k_proj(x) → k_norm → reshape │ │ │
│ │ │ V = v_proj(x) → reshape │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ Q, K = apply_rotary_pos_emb(Q, K, cos, sin)│ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ attn_output = attention(Q, K, V) │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ output = o_proj(attn_output) │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ hidden_states = residual + attn_output │ │
│ └───────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────┐ │
│ │ MLP Block │ │
│ │ │ │
│ │ post_attention_layernorm (RMSNorm) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ Qwen3MLP │ │ │
│ │ │ gate = gate_proj(x) │ │ │
│ │ │ up = up_proj(x) │ │ │
│ │ │ output = down_proj(silu(gate) * up) │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ hidden_states = residual + mlp_output │ │
│ └───────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘
┌─────────────┐
│ norm │ final RMSNorm
└─────────────┘
┌─────────────┐
│ lm_head │ [hidden_size, vocab_size]
└─────────────┘
logits [batch, seq_len, vocab_size]
"""
import math
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
class Qwen3RMSNorm(nn.Module):
"""RMSNorm implementation."""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
input_dtype = x.dtype
x = x.float()
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
return self.weight * x.to(input_dtype)
class Qwen3RotaryEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE)."""
def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Compute inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Input tensor [batch, seq_len, num_heads, head_dim] or similar
position_ids: Position indices [batch, seq_len]
Returns:
cos, sin: [batch, seq_len, head_dim]
"""
# inv_freq: [dim/2]
# position_ids: [batch, seq_len]
inv_freq_expanded = self.inv_freq[None, :, None].float() # [1, dim/2, 1]
position_ids_expanded = position_ids[:, None, :].float() # [batch, 1, seq_len]
# freqs: [batch, dim/2, seq_len]
freqs = inv_freq_expanded @ position_ids_expanded
# freqs: [batch, seq_len, dim/2]
freqs = freqs.transpose(1, 2)
# Duplicate for full head_dim: [batch, seq_len, dim]
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(x.dtype)
sin = emb.sin().to(x.dtype)
return cos, sin
def rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotate half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary position embeddings to Q and K.
Args:
q: [batch, num_heads, seq_len, head_dim]
k: [batch, num_kv_heads, seq_len, head_dim]
cos: [batch, seq_len, head_dim]
sin: [batch, seq_len, head_dim]
Returns:
q_embed, k_embed with same shapes as inputs
"""
# Unsqueeze for broadcasting: [batch, 1, seq_len, head_dim]
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class Qwen3Attention(nn.Module):
"""
Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support.
Data Flow:
---------
hidden_states [batch, seq_len, hidden_size]
├──► q_proj ──► q_norm ──► reshape ──► Q [batch, num_heads, seq_len, head_dim]
├──► k_proj ──► k_norm ──► reshape ──► K [batch, num_kv_heads, seq_len, head_dim]
└──► v_proj ──► reshape ──► V [batch, num_kv_heads, seq_len, head_dim]
apply_rotary_pos_emb(Q, K)
attention(Q, K, V) ──► attn_output [batch, num_heads, seq_len, head_dim]
reshape ──► o_proj ──► output [batch, seq_len, hidden_size]
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
attention_bias: bool = False,
rms_norm_eps: float = 1e-6,
layer_idx: int = 0,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_attention_heads
self.num_kv_heads = num_key_value_heads
self.head_dim = head_dim
self.num_kv_groups = num_attention_heads // num_key_value_heads
self.layer_idx = layer_idx
# Scaling factor
self.scaling = head_dim ** -0.5
# QKV projections
self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)
# QK normalization (Qwen3 specific)
self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
# Rotary embeddings
self.rotary_emb = Qwen3RotaryEmbedding(
head_dim,
max_position_embeddings=max_position_embeddings,
base=rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_qkv: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
"""
Args:
hidden_states: [batch, seq_len, hidden_size]
position_ids: [batch, seq_len]
attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask)
past_key_value: (k_cache, v_cache) from previous steps
use_cache: Whether to return updated cache
output_qkv: Whether to output Q, K, V tensors for debugging
Returns:
output: [batch, seq_len, hidden_size]
past_key_value: Updated cache (if use_cache=True)
qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True)
"""
batch_size, seq_len, _ = hidden_states.shape
# === QKV Projections ===
q = self.q_proj(hidden_states) # [batch, seq_len, num_heads * head_dim]
k = self.k_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
v = self.v_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
# Reshape to [batch, seq_len, num_heads, head_dim]
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
# === QK Normalization (Qwen3 specific) ===
q = self.q_norm(q)
k = self.k_norm(k)
# Transpose to [batch, num_heads, seq_len, head_dim]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# === Rotary Position Embeddings ===
cos, sin = self.rotary_emb(v, position_ids)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
# === KV Cache Update ===
if past_key_value is not None:
k_cache, v_cache = past_key_value
k = torch.cat([k_cache, k], dim=2)
v = torch.cat([v_cache, v], dim=2)
new_past_key_value = (k, v) if use_cache else None
# === Grouped Query Attention (expand KV heads if needed) ===
if self.num_kv_groups > 1:
# Repeat KV for each query group
k = k.repeat_interleave(self.num_kv_groups, dim=1)
v = v.repeat_interleave(self.num_kv_groups, dim=1)
# === Attention Computation (using SDPA for memory efficiency) ===
# Use PyTorch's scaled_dot_product_attention which can use FlashAttention backend
# is_causal only works when q_len == kv_len (prefill), not during decode
q_len, kv_len = q.shape[2], k.shape[2]
is_causal = (q_len == kv_len) and (q_len > 1)
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=0.0,
is_causal=is_causal,
scale=self.scaling,
) # [batch, num_heads, seq_len, head_dim]
# === Output Projection ===
# Transpose back and reshape
attn_output = attn_output.transpose(1, 2).contiguous() # [batch, seq_len, num_heads, head_dim]
attn_output = attn_output.view(batch_size, seq_len, -1) # [batch, seq_len, hidden_size]
output = self.o_proj(attn_output)
# Optional QKV output for debugging
qkv_dict = None
if output_qkv:
qkv_dict = {
"q": q, # [batch, num_heads, seq_len, head_dim] (post-RoPE)
"k": k, # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded)
"v": v, # [batch, num_heads, kv_seq_len, head_dim] (expanded)
}
return output, new_past_key_value, qkv_dict
class Qwen3MLP(nn.Module):
"""
Qwen3 MLP with SwiGLU activation.
Data Flow:
---------
hidden_states [batch, seq_len, hidden_size]
├──► gate_proj ──► gate [batch, seq_len, intermediate_size]
└──► up_proj ──► up [batch, seq_len, intermediate_size]
silu(gate) * up
down_proj ──► output [batch, seq_len, hidden_size]
"""
def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = self.gate_proj(x)
up = self.up_proj(x)
return self.down_proj(F.silu(gate) * up)
class Qwen3DecoderLayer(nn.Module):
"""Single Qwen3 Decoder Layer."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
layer_idx: int = 0,
):
super().__init__()
self.layer_idx = layer_idx
# Pre-attention LayerNorm
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# Self-attention
self.self_attn = Qwen3Attention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
attention_bias=attention_bias,
rms_norm_eps=rms_norm_eps,
layer_idx=layer_idx,
)
# Post-attention LayerNorm
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# MLP
self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_qkv: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
"""
Args:
hidden_states: [batch, seq_len, hidden_size]
position_ids: [batch, seq_len]
attention_mask: Causal attention mask
past_key_value: KV cache for this layer
use_cache: Whether to return updated cache
output_qkv: Whether to output Q, K, V for debugging
Returns:
hidden_states: [batch, seq_len, hidden_size]
past_key_value: Updated cache
qkv_dict: QKV tensors (if output_qkv=True)
"""
# === Self Attention Block ===
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_output, new_past_key_value, qkv_dict = self.self_attn(
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_qkv=output_qkv,
)
hidden_states = residual + attn_output
# === MLP Block ===
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, new_past_key_value, qkv_dict
class Qwen3Model(nn.Module):
"""Qwen3 Transformer Model (without LM head)."""
def __init__(
self,
vocab_size: int,
hidden_size: int,
intermediate_size: int,
num_hidden_layers: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
):
super().__init__()
self.vocab_size = vocab_size
self.num_hidden_layers = num_hidden_layers
# Token embeddings
self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
# Decoder layers
self.layers = nn.ModuleList([
Qwen3DecoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
mlp_bias=mlp_bias,
layer_idx=i,
)
for i in range(num_hidden_layers)
])
# Final LayerNorm
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
output_qkv_layers: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
"""
Args:
input_ids: [batch, seq_len]
position_ids: [batch, seq_len]
attention_mask: [batch, seq_len] or pre-computed 4D mask
past_key_values: List of (k, v) tuples for each layer
use_cache: Whether to return new cache
output_qkv_layers: List of layer indices to output QKV for
Returns:
hidden_states: [batch, seq_len, hidden_size]
new_past_key_values: Updated cache
qkv_outputs: {layer_idx: qkv_dict}
"""
batch_size, seq_len = input_ids.shape
# Embedding
hidden_states = self.embed_tokens(input_ids)
# Position IDs
if position_ids is None:
past_len = past_key_values[0][0].shape[2] if past_key_values else 0
position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Attention mask (create causal mask if not provided)
if attention_mask is None or attention_mask.dim() == 2:
kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0)
causal_mask = torch.triu(
torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device),
diagonal=kv_seq_len - seq_len + 1,
)
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, kv_seq_len]
# Initialize cache list
new_past_key_values = [] if use_cache else None
qkv_outputs = {} if output_qkv_layers else None
# Decoder layers
for i, layer in enumerate(self.layers):
past_kv = past_key_values[i] if past_key_values else None
output_qkv = output_qkv_layers is not None and i in output_qkv_layers
hidden_states, new_kv, qkv_dict = layer(
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_value=past_kv,
use_cache=use_cache,
output_qkv=output_qkv,
)
if use_cache:
new_past_key_values.append(new_kv)
if qkv_dict is not None:
qkv_outputs[i] = qkv_dict
# Final norm
hidden_states = self.norm(hidden_states)
return hidden_states, new_past_key_values, qkv_outputs
class Qwen3ForCausalLM(nn.Module):
"""Qwen3 Model with Language Modeling head."""
def __init__(
self,
vocab_size: int,
hidden_size: int,
intermediate_size: int,
num_hidden_layers: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
tie_word_embeddings: bool = True,
):
super().__init__()
self.vocab_size = vocab_size
self.tie_word_embeddings = tie_word_embeddings
# Transformer model
self.model = Qwen3Model(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
mlp_bias=mlp_bias,
)
# LM head
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
output_qkv_layers: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
"""
Args:
input_ids: [batch, seq_len]
... (same as Qwen3Model)
Returns:
logits: [batch, seq_len, vocab_size]
past_key_values: Updated KV cache
qkv_outputs: QKV tensors for specified layers
"""
hidden_states, new_past_key_values, qkv_outputs = self.model(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_qkv_layers=output_qkv_layers,
)
logits = self.lm_head(hidden_states)
return logits, new_past_key_values, qkv_outputs
@classmethod
def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM":
"""
Load weights from a pretrained Qwen3 model.
Args:
model_path: Path to model directory containing config.json and model weights
dtype: Data type for model weights
Returns:
Initialized Qwen3ForCausalLM model
"""
import json
import os
from safetensors.torch import load_file
# Load config
config_path = os.path.join(model_path, "config.json")
with open(config_path) as f:
config = json.load(f)
# Create model
model = cls(
vocab_size=config["vocab_size"],
hidden_size=config["hidden_size"],
intermediate_size=config["intermediate_size"],
num_hidden_layers=config["num_hidden_layers"],
num_attention_heads=config["num_attention_heads"],
num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]),
max_position_embeddings=config.get("max_position_embeddings", 32768),
rope_theta=config.get("rope_theta", 10000.0),
rms_norm_eps=config.get("rms_norm_eps", 1e-6),
attention_bias=config.get("attention_bias", False),
mlp_bias=config.get("mlp_bias", False),
tie_word_embeddings=config.get("tie_word_embeddings", True),
)
# Load weights
weight_files = sorted([
f for f in os.listdir(model_path)
if f.endswith(".safetensors")
])
state_dict = {}
for wf in weight_files:
state_dict.update(load_file(os.path.join(model_path, wf)))
# Load into model
model.load_state_dict(state_dict, strict=False)
# Tie lm_head weights to embed_tokens if configured
if model.tie_word_embeddings:
model.lm_head.weight = model.model.embed_tokens.weight
model = model.to(dtype)
return model
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 32,
temperature: float = 1.0,
do_sample: bool = True,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> torch.Tensor:
"""Simple autoregressive generation."""
device = input_ids.device
batch_size, seq_len = input_ids.shape
past_key_values = None
generated = input_ids.clone()
for _ in range(max_new_tokens):
if past_key_values is None:
current_input = generated
else:
current_input = generated[:, -1:]
logits, past_key_values, _ = self(
input_ids=current_input,
past_key_values=past_key_values,
use_cache=True,
)
next_token_logits = logits[:, -1, :]
if temperature > 0 and do_sample:
next_token_logits = next_token_logits / temperature
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
generated = torch.cat([generated, next_token], dim=1)
if eos_token_id is not None and (next_token == eos_token_id).all():
break
return generated
def print_computation_graph():
"""Print the computation graph for reference."""
print(__doc__)
if __name__ == "__main__":
print_computation_graph()

View File

@@ -1,151 +0,0 @@
#!/usr/bin/env python3
"""
Test: Pre-allocated chunk pair graphs for block sparse attention.
Each (Q_chunk, K_chunk) pair has its own captured CUDA graph.
Zero copy_() during replay - all data pre-filled.
Usage:
CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph.py
"""
from dataclasses import dataclass
from typing import List, Optional
import torch
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@dataclass
class ChunkAttentionGraph:
"""Container for a captured chunk attention graph."""
graph: torch.cuda.CUDAGraph
static_q: torch.Tensor
static_k: torch.Tensor
static_v: torch.Tensor
static_output: torch.Tensor
static_lse: torch.Tensor
causal: bool
def capture_chunk_attention_graph(
chunk_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
scale: float,
device: torch.device,
dtype: torch.dtype,
causal: bool = False,
) -> ChunkAttentionGraph:
"""Capture a CUDA graph for single chunk attention."""
static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device)
static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
static_q.normal_()
static_k.normal_()
static_v.normal_()
# Warmup
with torch.inference_mode():
for _ in range(3):
_ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
torch.cuda.synchronize()
# Capture
graph = torch.cuda.CUDAGraph()
with torch.inference_mode():
with torch.cuda.graph(graph):
static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
torch.cuda.synchronize()
return ChunkAttentionGraph(
graph=graph,
static_q=static_q,
static_k=static_k,
static_v=static_v,
static_output=static_output,
static_lse=static_lse,
causal=causal,
)
def main():
device = torch.device("cuda")
dtype = torch.bfloat16
chunk_size = 64
num_chunks = 4
num_heads = 8
num_kv_heads = 8
head_dim = 64
scale = 1.0 / (head_dim ** 0.5)
seq_len = chunk_size * num_chunks
print(f"Device: {torch.cuda.get_device_name()}")
print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}")
print(f"Total graphs: {num_chunks * (num_chunks + 1) // 2}")
# Test data
full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device)
full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
# Reference
with torch.inference_mode():
full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True)
# Capture all graphs
graphs: List[List[Optional[ChunkAttentionGraph]]] = [[None] * num_chunks for _ in range(num_chunks)]
for q_idx in range(num_chunks):
for k_idx in range(q_idx + 1):
graphs[q_idx][k_idx] = capture_chunk_attention_graph(
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype,
causal=(k_idx == q_idx)
)
print("All graphs captured")
# Pre-fill static tensors
for q_idx in range(num_chunks):
for k_idx in range(q_idx + 1):
g = graphs[q_idx][k_idx]
g.static_q.copy_(full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size])
g.static_k.copy_(full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size])
g.static_v.copy_(full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size])
print("Static tensors pre-filled")
# Replay and merge
chunked_output = torch.zeros_like(full_output)
for q_idx in range(num_chunks):
acc_out, acc_lse = None, None
for k_idx in range(q_idx + 1):
g = graphs[q_idx][k_idx]
g.graph.replay()
out, lse = g.static_output.clone(), g.static_lse.clone()
if acc_out is None:
acc_out, acc_lse = out, lse
else:
with torch.inference_mode():
acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse)
chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out
torch.cuda.synchronize()
# Compare
all_pass = True
for q_idx in range(num_chunks):
s, e = q_idx * chunk_size, (q_idx + 1) * chunk_size
diff = (full_output[:, s:e] - chunked_output[:, s:e]).abs().max().item()
status = "" if diff < 1e-2 else ""
print(f"Q[{q_idx}]: max_diff={diff:.2e} {status}")
if diff >= 1e-2:
all_pass = False
print("✅ PASSED" if all_pass else "❌ FAILED")
if __name__ == "__main__":
main()

View File

@@ -1,156 +0,0 @@
#!/usr/bin/env python3
"""
Test: Reuse a single CUDA Graph across all layers and all chunk pairs.
Key insight: LLM layers have identical computation structure.
We only need 2 graphs (causal + non-causal), reused for all (layer, Q_i, K_j) combinations.
Usage:
CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph_reuse.py
"""
from dataclasses import dataclass
import torch
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@dataclass
class ReusableChunkGraph:
"""A single graph that can be reused with copy_() updates."""
graph: torch.cuda.CUDAGraph
static_q: torch.Tensor
static_k: torch.Tensor
static_v: torch.Tensor
static_output: torch.Tensor
static_lse: torch.Tensor
def capture_reusable_graph(
chunk_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
scale: float,
device: torch.device,
dtype: torch.dtype,
causal: bool,
) -> ReusableChunkGraph:
"""Capture ONE graph to be reused for all chunk pairs."""
static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device)
static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
static_q.normal_()
static_k.normal_()
static_v.normal_()
# Warmup
with torch.inference_mode():
for _ in range(3):
_ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
torch.cuda.synchronize()
# Capture
graph = torch.cuda.CUDAGraph()
with torch.inference_mode():
with torch.cuda.graph(graph):
static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
torch.cuda.synchronize()
return ReusableChunkGraph(
graph=graph,
static_q=static_q,
static_k=static_k,
static_v=static_v,
static_output=static_output,
static_lse=static_lse,
)
def replay_with_copy(graph: ReusableChunkGraph, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Replay graph after updating static tensors with copy_()."""
graph.static_q.copy_(q)
graph.static_k.copy_(k)
graph.static_v.copy_(v)
graph.graph.replay()
return graph.static_output.clone(), graph.static_lse.clone()
def main():
device = torch.device("cuda")
dtype = torch.bfloat16
chunk_size = 64
num_chunks = 4
num_layers = 3 # Simulate multiple layers
num_heads = 8
num_kv_heads = 8
head_dim = 64
scale = 1.0 / (head_dim ** 0.5)
seq_len = chunk_size * num_chunks
print(f"Device: {torch.cuda.get_device_name()}")
print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}, Num layers: {num_layers}")
print(f"Only 2 graphs (causal + non-causal) for ALL layer × chunk combinations")
# Capture only 2 graphs
graph_causal = capture_reusable_graph(
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=True
)
graph_non_causal = capture_reusable_graph(
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=False
)
print("2 graphs captured (causal + non-causal)")
all_pass = True
for layer_id in range(num_layers):
# Different Q/K/V for each layer (simulating different layer outputs)
full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device)
full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
# Reference: full causal attention
with torch.inference_mode():
full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True)
# Chunked with graph reuse
chunked_output = torch.zeros_like(full_output)
for q_idx in range(num_chunks):
q_chunk = full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size]
acc_out, acc_lse = None, None
for k_idx in range(q_idx + 1):
k_chunk = full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
v_chunk = full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
# Reuse graph with copy_()
graph = graph_causal if k_idx == q_idx else graph_non_causal
out, lse = replay_with_copy(graph, q_chunk, k_chunk, v_chunk)
if acc_out is None:
acc_out, acc_lse = out, lse
else:
with torch.inference_mode():
acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse)
chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out
torch.cuda.synchronize()
# Compare
max_diff = (full_output - chunked_output).abs().max().item()
status = "" if max_diff < 1e-2 else ""
print(f"Layer {layer_id}: max_diff={max_diff:.2e} {status}")
if max_diff >= 1e-2:
all_pass = False
print("✅ PASSED - Single graph reuse across layers works!" if all_pass else "❌ FAILED")
if __name__ == "__main__":
main()

View File

@@ -1,357 +0,0 @@
#!/usr/bin/env python3
"""
CUDA Graph Memory Analysis Test
This script analyzes the memory overhead of CUDA Graph at each stage:
1. Model loading
2. StaticCache allocation
3. Warmup runs
4. Graph capture
5. Graph replay
Usage:
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --model ~/models/Qwen3-0.6B
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --max-cache-len 2048
"""
import argparse
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache
def get_memory_mb():
"""Get current allocated memory in MB."""
return torch.cuda.memory_allocated() / 1024**2
def get_memory_gb():
"""Get current allocated memory in GB."""
return torch.cuda.memory_allocated() / 1024**3
def get_peak_memory_gb():
"""Get peak allocated memory in GB."""
return torch.cuda.max_memory_allocated() / 1024**3
def print_separator(title=None):
"""Print a separator line."""
if title:
print(f"\n{'=' * 70}")
print(f" {title}")
print(f"{'=' * 70}")
else:
print("-" * 70)
def test_memory_stages(model_path: str, max_cache_len: int, batch_size: int = 1):
"""
Test memory usage at each stage of CUDA Graph setup.
Args:
model_path: Path to the model
max_cache_len: Maximum cache length for StaticCache
batch_size: Batch size for inference
"""
print_separator("CUDA Graph Memory Analysis")
print(f"Model: {model_path}")
print(f"Max cache length: {max_cache_len}")
print(f"Batch size: {batch_size}")
results = {}
# Stage 0: Initial
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
results["initial"] = get_memory_mb()
# Stage 1: Load model
print_separator("Stage 1: Model Loading")
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="cuda",
trust_remote_code=True,
)
model.eval()
results["after_model"] = get_memory_mb()
model_size = results["after_model"] - results["initial"]
print(f" Memory: {results['after_model']:.0f} MB")
print(f" Model size: {model_size:.0f} MB ({model_size/1024:.2f} GB)")
config = model.config
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
# Stage 2: Allocate StaticCache
print_separator("Stage 2: StaticCache Allocation")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
static_cache = StaticCache(
config=config,
max_batch_size=batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=dtype,
)
results["after_cache"] = get_memory_mb()
cache_size = results["after_cache"] - before
print(f" Memory: {results['after_cache']:.0f} MB")
print(f" StaticCache size: {cache_size:.0f} MB")
# Calculate theoretical cache size
num_layers = config.num_hidden_layers
num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
head_dim = config.hidden_size // config.num_attention_heads
dtype_size = 2 # bfloat16
theoretical_cache = (
num_layers * 2 * batch_size * num_kv_heads * max_cache_len * head_dim * dtype_size
) / (1024**2)
print(f" Theoretical: {theoretical_cache:.0f} MB")
print(f" Overhead: {cache_size - theoretical_cache:.0f} MB ({(cache_size/theoretical_cache - 1)*100:.1f}%)")
# Stage 3: Prepare static tensors
print_separator("Stage 3: Static Tensor Allocation")
before = get_memory_mb()
static_input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
static_position_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
static_cache_position = torch.tensor([0], dtype=torch.long, device=device)
results["after_tensors"] = get_memory_mb()
tensor_size = results["after_tensors"] - before
print(f" Memory: {results['after_tensors']:.0f} MB")
print(f" Static tensors: {tensor_size:.2f} MB (negligible)")
# Stage 4: Warmup runs
print_separator("Stage 4: Warmup Runs (3 iterations)")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
with torch.inference_mode():
for i in range(3):
_ = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
torch.cuda.synchronize()
results["after_warmup"] = get_memory_mb()
results["warmup_peak"] = get_peak_memory_gb() * 1024
warmup_size = results["after_warmup"] - before
print(f" Memory: {results['after_warmup']:.0f} MB")
print(f" Peak: {results['warmup_peak']:.0f} MB")
print(f" Warmup overhead: {warmup_size:.0f} MB")
# Stage 5: CUDA Graph capture
print_separator("Stage 5: CUDA Graph Capture")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
graph = torch.cuda.CUDAGraph()
with torch.inference_mode():
with torch.cuda.graph(graph):
outputs = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
static_logits = outputs.logits
torch.cuda.synchronize()
results["after_capture"] = get_memory_mb()
results["capture_peak"] = get_peak_memory_gb() * 1024
capture_size = results["after_capture"] - before
print(f" Memory: {results['after_capture']:.0f} MB")
print(f" Peak: {results['capture_peak']:.0f} MB")
print(f" Graph capture overhead: {capture_size:.0f} MB")
# Stage 6: Graph replay
print_separator("Stage 6: Graph Replay (10 iterations)")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
with torch.inference_mode():
for _ in range(10):
static_input_ids.fill_(1)
static_cache_position.fill_(0)
graph.replay()
torch.cuda.synchronize()
results["after_replay"] = get_memory_mb()
results["replay_peak"] = get_peak_memory_gb() * 1024
replay_change = results["after_replay"] - before
print(f" Memory: {results['after_replay']:.0f} MB")
print(f" Peak: {results['replay_peak']:.0f} MB")
print(f" Replay memory change: {replay_change:.0f} MB (should be ~0)")
# Summary
print_separator("SUMMARY")
total_overhead = results["after_capture"] - results["after_model"]
print(f"{'Stage':<25} {'Memory (MB)':>12} {'Delta (MB)':>12}")
print("-" * 50)
print(f"{'Model loaded':<25} {results['after_model']:>12.0f} {model_size:>+12.0f}")
print(f"{'StaticCache allocated':<25} {results['after_cache']:>12.0f} {cache_size:>+12.0f}")
print(f"{'After warmup':<25} {results['after_warmup']:>12.0f} {warmup_size:>+12.0f}")
print(f"{'After graph capture':<25} {results['after_capture']:>12.0f} {capture_size:>+12.0f}")
print(f"{'After graph replay':<25} {results['after_replay']:>12.0f} {replay_change:>+12.0f}")
print("-" * 50)
print(f"{'Total (excl. model)':<25} {'':<12} {total_overhead:>+12.0f}")
print_separator("KEY FINDINGS")
print(f" 1. Model size: {model_size/1024:.2f} GB")
print(f" 2. StaticCache: {cache_size:.0f} MB (main overhead, scales with cache_len)")
print(f" 3. Graph capture: {capture_size:.0f} MB (small, stores kernel sequence)")
print(f" 4. Graph replay: {replay_change:.0f} MB (zero allocation, reuses memory)")
print(f" 5. Total CUDA Graph overhead: {total_overhead:.0f} MB")
return results
def test_cache_length_scaling(model_path: str, cache_lengths: list):
"""
Test how memory scales with different cache lengths.
Args:
model_path: Path to the model
cache_lengths: List of cache lengths to test
"""
print_separator("Cache Length Scaling Test")
print(f"Model: {model_path}")
print(f"Cache lengths: {cache_lengths}")
# Load model once
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="cuda",
trust_remote_code=True,
)
model.eval()
config = model.config
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
model_mem = get_memory_mb()
results = []
for cache_len in cache_lengths:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Create cache and capture graph
static_cache = StaticCache(
config=config,
max_batch_size=1,
max_cache_len=cache_len,
device=device,
dtype=dtype,
)
static_input_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
static_position_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
static_cache_position = torch.tensor([0], dtype=torch.long, device=device)
with torch.inference_mode():
# Warmup
for _ in range(3):
_ = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
torch.cuda.synchronize()
# Capture
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
outputs = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
torch.cuda.synchronize()
total_mem = get_memory_mb()
overhead = total_mem - model_mem
results.append((cache_len, total_mem, overhead))
del static_cache, graph
torch.cuda.empty_cache()
# Print results
print()
print(f"{'Cache Length':>12} | {'Total (MB)':>12} | {'Overhead (MB)':>14} | {'Per 1K tokens':>14}")
print("-" * 60)
for cache_len, total, overhead in results:
per_1k = overhead / (cache_len / 1000)
print(f"{cache_len:>12} | {total:>12.0f} | {overhead:>14.0f} | {per_1k:>14.1f}")
return results
def main():
parser = argparse.ArgumentParser(description="CUDA Graph Memory Analysis")
parser.add_argument(
"--model",
type=str,
default="~/models/Qwen3-4B-Instruct-2507",
help="Model path",
)
parser.add_argument(
"--max-cache-len",
type=int,
default=1024,
help="Maximum cache length",
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="Batch size",
)
parser.add_argument(
"--test-scaling",
action="store_true",
help="Test cache length scaling",
)
args = parser.parse_args()
model_path = os.path.expanduser(args.model)
if not torch.cuda.is_available():
print("CUDA is not available!")
return
print(f"Device: cuda:{torch.cuda.current_device()}")
print(f"GPU: {torch.cuda.get_device_name()}")
if args.test_scaling:
cache_lengths = [256, 512, 1024, 2048, 4096]
test_cache_length_scaling(model_path, cache_lengths)
else:
test_memory_stages(model_path, args.max_cache_len, args.batch_size)
print("\ntest_cudagraph_memory: PASSED")
if __name__ == "__main__":
main()

View File

@@ -1,149 +0,0 @@
"""
Test: GPU-only density alignment verification
验证 xattn_bsa.py 中 GPU-only 路径的 density 计算是否与独立调用 xattn_estimate 一致。
流程:
1. 运行 GPU-only 推理,保存 Q, K, mask, attn_sums
2. 加载保存的数据,独立调用 xattn_estimate
3. 比较两者的 mask 和 density
"""
import torch
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
from nanovllm.ops.xattn import xattn_estimate
# ============================================================
# 参数配置
# ============================================================
DATA_PATH = "/home/zijie/Code/nano-vllm/results/mask_alignment/gpuonly_layer0.pt"
# ============================================================
# 加载保存的数据
# ============================================================
print(f"Loading data from {DATA_PATH}")
data = torch.load(DATA_PATH, weights_only=False)
Q = data["Q"].cuda() # [1, num_heads, q_len, head_dim]
K = data["K"].cuda() # [1, num_heads, k_len, head_dim]
chunk_size = data["chunk_size"]
block_size = data["block_size"]
stride = data["stride"]
threshold = data["threshold"]
mask_saved = data["mask"] # [1, num_heads, valid_q_blocks, valid_k_blocks]
attn_sums_saved = data["attn_sums"]
q_len = data["q_len"]
k_len = data["k_len"]
valid_q_blocks = data["valid_q_blocks"]
valid_k_blocks = data["valid_k_blocks"]
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"q_len: {q_len}, k_len: {k_len}")
print(f"chunk_size: {chunk_size}, block_size: {block_size}, stride: {stride}, threshold: {threshold}")
print(f"valid_q_blocks: {valid_q_blocks}, valid_k_blocks: {valid_k_blocks}")
print(f"mask_saved shape: {mask_saved.shape}")
# ============================================================
# 独立调用 xattn_estimate
# ============================================================
print("\nCalling xattn_estimate independently...")
attn_sums_ext, mask_ext = xattn_estimate(
Q, K,
chunk_size=chunk_size,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
causal=True,
)
# Trim to valid blocks
mask_ext_valid = mask_ext[:, :, :valid_q_blocks, :valid_k_blocks]
attn_sums_ext_valid = attn_sums_ext[:, :, :valid_q_blocks, :valid_k_blocks]
print(f"mask_ext shape: {mask_ext.shape}")
print(f"mask_ext_valid shape: {mask_ext_valid.shape}")
# ============================================================
# 比较 attn_sums
# ============================================================
print("\n" + "=" * 60)
print("Comparing attn_sums")
print("=" * 60)
attn_sums_saved_gpu = attn_sums_saved.cuda()
attn_diff = (attn_sums_ext_valid - attn_sums_saved_gpu).abs()
print(f"attn_sums max diff: {attn_diff.max().item():.6e}")
print(f"attn_sums mean diff: {attn_diff.mean().item():.6e}")
# Check if attn_sums match
attn_match = attn_diff.max().item() < 1e-4
print(f"attn_sums match: {attn_match}")
# ============================================================
# 比较 mask
# ============================================================
print("\n" + "=" * 60)
print("Comparing mask")
print("=" * 60)
mask_saved_gpu = mask_saved.cuda()
mask_match = (mask_ext_valid == mask_saved_gpu).all().item()
print(f"mask exact match: {mask_match}")
if not mask_match:
diff_count = (mask_ext_valid != mask_saved_gpu).sum().item()
total_count = mask_ext_valid.numel()
print(f"mask diff count: {diff_count} / {total_count} ({diff_count/total_count*100:.2f}%)")
# ============================================================
# 计算 density
# ============================================================
print("\n" + "=" * 60)
print("Comparing density")
print("=" * 60)
# 计算 causal mask
q_offset_blocks = valid_k_blocks - valid_q_blocks
indices = torch.arange(valid_k_blocks, device=mask_ext_valid.device).unsqueeze(0)
q_indices = torch.arange(valid_q_blocks, device=mask_ext_valid.device).unsqueeze(1)
causal_mask = indices <= (q_indices + q_offset_blocks)
# Density from saved mask
total_saved = causal_mask.sum().item() * mask_saved_gpu.shape[0] * mask_saved_gpu.shape[1]
selected_saved = (mask_saved_gpu & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_saved = selected_saved / total_saved
# Density from external xattn_estimate
total_ext = causal_mask.sum().item() * mask_ext_valid.shape[0] * mask_ext_valid.shape[1]
selected_ext = (mask_ext_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_ext = selected_ext / total_ext
print(f"Saved density: {density_saved:.6f} (selected={selected_saved}, total={total_saved})")
print(f"External density: {density_ext:.6f} (selected={selected_ext}, total={total_ext})")
print(f"Density diff: {abs(density_saved - density_ext):.6f}")
# ============================================================
# 结论
# ============================================================
print("\n" + "=" * 60)
print("RESULT")
print("=" * 60)
if attn_match and mask_match:
print("✅ PASSED: GPU-only density matches external xattn_estimate")
else:
print("❌ FAILED: Mismatch detected")
if not attn_match:
print(" - attn_sums mismatch")
if not mask_match:
print(" - mask mismatch")

View File

@@ -1,442 +0,0 @@
"""
Test: Hierarchical Block Sum Estimation for XAttention
Verify that hierarchical estimation (small estimate_block_size + aggregation)
produces equivalent results to direct estimation (large block_size), while
being significantly faster.
Key changes validated:
1. Hierarchical block sum: estimate_block_size=1024 → aggregate to cpu_block_size=4096
2. Selection strategy: score + threshold (NOT mask + majority voting)
This test uses pure torch + xattn kernels, independent of nanovllm framework.
"""
import sys
import os
import torch
import math
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
# ============================================================
# Configuration
# ============================================================
# Model dimensions (Llama-3.1-8B-Instruct style)
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
STRIDE = 8
# Block sizes
CPU_BLOCK_SIZE = 4096 # External CPU block size (fixed, for overlap)
ESTIMATE_BLOCK_SIZE = 1024 # Internal estimate block size (optimized)
# Selection parameters
THRESHOLD = 0.95 # Cumulative attention threshold
# ============================================================
# Hierarchical Estimation Implementation
# ============================================================
def compute_attention_scores(Q, K_blocks, stride):
"""
Compute attention scores for Q against multiple K blocks.
Args:
Q: [1, num_heads, q_len, head_dim]
K_blocks: List of K tensors, each [1, num_heads, block_size, head_dim]
stride: Stride for reshape
Returns:
attn_scores: [1, num_heads, q_reshaped, total_k_reshaped]
"""
q_len = Q.shape[2]
q_reshaped = q_len // stride
attn_chunks = []
for K_block in K_blocks:
# flat_group_gemm_fuse_reshape
attn_chunk = flat_group_gemm_fuse_reshape(
Q, K_block, stride,
chunk_start=0,
chunk_end=q_reshaped,
is_causal=False,
)
attn_chunks.append(attn_chunk)
# Concatenate along K dimension
attn_scores = torch.cat(attn_chunks, dim=-1)
return attn_scores
def hierarchical_block_sum(
attn_scores,
estimate_block_size,
cpu_block_size,
stride,
head_dim,
):
"""
Compute hierarchical block sums: fine-grained → aggregated to CPU block level.
Args:
attn_scores: [batch, heads, q_reshaped, k_reshaped]
estimate_block_size: Small block size for efficient softmax (e.g., 1024)
cpu_block_size: External CPU block size (e.g., 4096)
stride: Stride used in reshape
head_dim: Head dimension for scale computation
Returns:
cpu_block_scores: [batch, heads, num_cpu_blocks] - attention score per CPU block
"""
batch_size, num_heads, q_reshaped, k_reshaped = attn_scores.shape
# Compute reshaped block sizes
reshaped_est_bs = estimate_block_size // stride # 1024/8 = 128
reshaped_cpu_bs = cpu_block_size // stride # 4096/8 = 512
# Scale factor
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
# Segment size for softmax kernel
segment_size = min(4096, reshaped_est_bs)
# Step 1: Fine-grained softmax + block sum
block_sums_fine = softmax_fuse_block_sum(
attn_scores,
reshaped_est_bs,
segment_size,
chunk_start=0,
chunk_end=q_reshaped,
real_q_len=q_reshaped,
scale=scale,
is_causal=False,
)
# block_sums_fine: [batch, heads, q_est_blocks, k_est_blocks]
q_est_blocks = block_sums_fine.shape[2]
k_est_blocks = block_sums_fine.shape[3]
# Step 2: Aggregate to CPU block level
# ratio = cpu_block_size / estimate_block_size = 4
ratio = cpu_block_size // estimate_block_size
num_cpu_blocks = k_est_blocks // ratio
# Reshape and sum along K dimension
# [batch, heads, q_est, k_est] → [batch, heads, q_est, num_cpu, ratio]
block_sums_coarse = block_sums_fine.view(
batch_size, num_heads, q_est_blocks, num_cpu_blocks, ratio
).sum(dim=-1) # [batch, heads, q_est_blocks, num_cpu_blocks]
# Step 3: Sum over Q dimension (total attention from Q chunk to each K block)
cpu_block_scores = block_sums_coarse.sum(dim=2) # [batch, heads, num_cpu_blocks]
return cpu_block_scores, block_sums_fine
def direct_block_sum(
attn_scores,
cpu_block_size,
stride,
head_dim,
):
"""
Compute block sums directly with CPU block size (baseline for comparison).
Args:
attn_scores: [batch, heads, q_reshaped, k_reshaped]
cpu_block_size: Block size (e.g., 4096)
stride: Stride used in reshape
head_dim: Head dimension for scale computation
Returns:
cpu_block_scores: [batch, heads, num_cpu_blocks]
"""
batch_size, num_heads, q_reshaped, k_reshaped = attn_scores.shape
reshaped_cpu_bs = cpu_block_size // stride # 4096/8 = 512
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
segment_size = min(4096, reshaped_cpu_bs)
block_sums = softmax_fuse_block_sum(
attn_scores,
reshaped_cpu_bs,
segment_size,
chunk_start=0,
chunk_end=q_reshaped,
real_q_len=q_reshaped,
scale=scale,
is_causal=False,
)
# block_sums: [batch, heads, q_cpu_blocks, k_cpu_blocks]
# Sum over Q dimension
cpu_block_scores = block_sums.sum(dim=2) # [batch, heads, num_cpu_blocks]
return cpu_block_scores
def select_blocks_by_score(
cpu_block_scores,
threshold=0.95,
always_include_first=True,
always_include_last=True,
):
"""
Select CPU blocks based on score + threshold.
⚠️ IMPORTANT: This replaces the original mask + majority voting strategy.
This change should be documented in the final implementation.
Args:
cpu_block_scores: [batch, heads, num_cpu_blocks]
threshold: Cumulative attention threshold (e.g., 0.95)
always_include_first: Always include first block (sink)
always_include_last: Always include last block (safety)
Returns:
selected_block_ids: List of selected block indices
density: Fraction of blocks selected
"""
# Average scores across heads
scores_per_block = cpu_block_scores.mean(dim=(0, 1)) # [num_cpu_blocks]
num_blocks = scores_per_block.shape[0]
# Normalize to get attention distribution
total_score = scores_per_block.sum()
score_ratio = scores_per_block / total_score
# Sort by score (descending)
sorted_indices = torch.argsort(score_ratio, descending=True)
# Select blocks until cumulative threshold is reached
cumsum = 0.0
selected = set()
for idx in sorted_indices.tolist():
selected.add(idx)
cumsum += score_ratio[idx].item()
if cumsum >= threshold:
break
# Always include first and last blocks
if always_include_first:
selected.add(0)
if always_include_last:
selected.add(num_blocks - 1)
selected_block_ids = sorted(list(selected))
density = len(selected_block_ids) / num_blocks
return selected_block_ids, density
# ============================================================
# Test Cases
# ============================================================
def test_equivalence():
"""
Test that hierarchical estimation produces equivalent scores to direct estimation.
"""
print("=" * 60)
print("Test 1: Hierarchical vs Direct - Equivalence")
print("=" * 60)
# Create random Q and multiple K blocks
q_len = CPU_BLOCK_SIZE # 4096
num_k_blocks = 4
# Q: [1, num_heads, q_len, head_dim]
Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
# K blocks: each [1, num_heads, cpu_block_size, head_dim]
K_blocks = [
torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
for _ in range(num_k_blocks)
]
# Compute attention scores
attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
print(f"attn_scores shape: {attn_scores.shape}")
# Method 1: Hierarchical (fast)
scores_hier, _ = hierarchical_block_sum(
attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
)
print(f"scores_hier shape: {scores_hier.shape}")
# Method 2: Direct (slow)
scores_direct = direct_block_sum(
attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
)
print(f"scores_direct shape: {scores_direct.shape}")
# Compare
diff = (scores_hier - scores_direct).abs().max().item()
print(f"\nMax difference: {diff:.6f}")
# Per-block comparison
print("\nPer-block scores comparison:")
for i in range(num_k_blocks):
h_val = scores_hier[0, 0, i].item()
d_val = scores_direct[0, 0, i].item()
print(f" Block {i}: hierarchical={h_val:.4f}, direct={d_val:.4f}, diff={abs(h_val-d_val):.6f}")
passed = diff < 0.01
print(f"\nTest 1: {'PASSED' if passed else 'FAILED'}")
return passed
def test_selection():
"""
Test the score + threshold selection strategy.
"""
print("\n" + "=" * 60)
print("Test 2: Score + Threshold Selection")
print("=" * 60)
# Create Q and K blocks with varying importance
q_len = CPU_BLOCK_SIZE
num_k_blocks = 8
Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
# Create K blocks - make some more important than others
K_blocks = []
for i in range(num_k_blocks):
# First and middle blocks are more important (higher values)
importance = 2.0 if i in [0, 3, 4] else 1.0
K = torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
K = K * importance
K_blocks.append(K)
# Compute scores
attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
scores, _ = hierarchical_block_sum(
attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
)
# Print scores per block
print("\nCPU block scores (head 0):")
for i in range(num_k_blocks):
print(f" Block {i}: {scores[0, 0, i].item():.4f}")
# Select blocks with different thresholds
for thresh in [0.9, 0.95, 0.99]:
selected, density = select_blocks_by_score(scores, threshold=thresh)
print(f"\nThreshold {thresh}: selected {len(selected)}/{num_k_blocks} blocks ({density:.1%})")
print(f" Selected: {selected}")
print("\nTest 2: PASSED (visual inspection)")
return True
def test_performance():
"""
Benchmark hierarchical vs direct estimation performance.
"""
print("\n" + "=" * 60)
print("Test 3: Performance Benchmark")
print("=" * 60)
import time
NUM_WARMUP = 3
NUM_RUNS = 10
# Larger test case
q_len = CPU_BLOCK_SIZE
num_k_blocks = 16 # 64K context
Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
K_blocks = [
torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
for _ in range(num_k_blocks)
]
# Compute attention scores (shared)
attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
print(f"attn_scores shape: {attn_scores.shape}")
print(f"Context: {num_k_blocks * CPU_BLOCK_SIZE // 1024}K tokens")
# Warmup and benchmark hierarchical
for _ in range(NUM_WARMUP):
_ = hierarchical_block_sum(attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(NUM_RUNS):
_ = hierarchical_block_sum(attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
hier_time = (time.perf_counter() - start) / NUM_RUNS * 1000
# Warmup and benchmark direct
for _ in range(NUM_WARMUP):
_ = direct_block_sum(attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(NUM_RUNS):
_ = direct_block_sum(attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
direct_time = (time.perf_counter() - start) / NUM_RUNS * 1000
speedup = direct_time / hier_time
print(f"\nResults:")
print(f" Hierarchical (bs=1024): {hier_time:.2f} ms")
print(f" Direct (bs=4096): {direct_time:.2f} ms")
print(f" Speedup: {speedup:.2f}x")
passed = speedup > 5.0 # Expect at least 5x speedup
print(f"\nTest 3: {'PASSED' if passed else 'FAILED'} (speedup > 5x expected)")
return passed
# ============================================================
# Main
# ============================================================
if __name__ == "__main__":
print("=" * 60)
print("Hierarchical Block Sum Estimation Test")
print("=" * 60)
print(f"\nConfiguration:")
print(f" NUM_HEADS: {NUM_HEADS}")
print(f" NUM_KV_HEADS: {NUM_KV_HEADS}")
print(f" HEAD_DIM: {HEAD_DIM}")
print(f" STRIDE: {STRIDE}")
print(f" CPU_BLOCK_SIZE: {CPU_BLOCK_SIZE}")
print(f" ESTIMATE_BLOCK_SIZE: {ESTIMATE_BLOCK_SIZE}")
print(f" THRESHOLD: {THRESHOLD}")
print()
results = []
results.append(("Equivalence", test_equivalence()))
results.append(("Selection", test_selection()))
results.append(("Performance", test_performance()))
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for name, passed in results:
status = "PASSED" if passed else "FAILED"
print(f" {name}: {status}")
all_passed = all(p for _, p in results)
print("=" * 60)
if all_passed:
print("test_hierarchical_estimate: ALL PASSED")
sys.exit(0)
else:
print("test_hierarchical_estimate: SOME FAILED")
sys.exit(1)

View File

@@ -1,254 +0,0 @@
"""
Needle-in-a-haystack test for LLM.
Tests: Long context retrieval capability with configurable sequence length.
NOTE: CPU offload mode has a known bug that causes incorrect outputs for
sequences longer than ~200 tokens. Use --no-offload for correctness testing.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import argparse
from nanovllm import LLM, SamplingParams
from nanovllm.config import SparsePolicyType
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
# Main Test
# ============================================================
def run_needle_test(
model_path: str,
max_model_len: int,
input_len: int,
num_gpu_blocks: int = 4,
block_size: int = 1024,
needle_position: float = 0.5,
needle_value: str = "7492",
max_new_tokens: int = 32,
enable_cpu_offload: bool = False,
enable_quest: bool = False,
enable_xattn_bsa: bool = False,
sparse_topk: int = 8,
sparse_threshold: int = 4,
sparse_samples: int = 128,
verbose: bool = True,
) -> bool:
"""
Run a needle-in-haystack test.
Args:
model_path: Path to model
max_model_len: Maximum model context length
input_len: Target input sequence length
num_gpu_blocks: Number of GPU blocks for offload
block_size: KV cache block size
needle_position: Where to place needle (0.0-1.0)
needle_value: The secret value to find
max_new_tokens: Maximum tokens to generate
enable_cpu_offload: Enable CPU offload mode
enable_quest: Enable Quest sparse attention (decode-only Top-K)
enable_xattn_bsa: Enable XAttention BSA sparse attention (prefill-only)
sparse_topk: Top-K blocks for Quest
sparse_threshold: Threshold for sparse selection (Quest/XAttention BSA)
sparse_samples: Samples per chunk for XAttention BSA estimation
verbose: Print detailed output
Returns:
True if test passed, False otherwise
"""
# Determine sparse policy
if enable_xattn_bsa:
sparse_policy = SparsePolicyType.XATTN_BSA
elif enable_quest:
sparse_policy = SparsePolicyType.QUEST
else:
sparse_policy = SparsePolicyType.FULL
if verbose:
print(f"\n{'='*60}")
print(f"Needle-in-Haystack Test")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Max model len: {max_model_len}")
print(f"Input length: {input_len}")
print(f"Block size: {block_size}")
print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}")
print(f"CPU offload: {enable_cpu_offload}")
if enable_cpu_offload:
print(f"Sparse policy: {sparse_policy.name}")
if sparse_policy == SparsePolicyType.QUEST:
print(f" Quest: topk={sparse_topk}, threshold={sparse_threshold}")
elif sparse_policy == SparsePolicyType.XATTN_BSA:
print(f" XAttention BSA: threshold={sparse_threshold}, samples={sparse_samples}")
print(f"{'='*60}\n")
# 1. Initialize LLM
llm_kwargs = {
"enforce_eager": True,
"max_model_len": max_model_len,
"max_num_batched_tokens": max_model_len,
"enable_cpu_offload": enable_cpu_offload,
"kvcache_block_size": block_size,
}
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm_kwargs["sparse_policy"] = sparse_policy
if sparse_policy == SparsePolicyType.QUEST:
llm_kwargs["sparse_topk_blocks"] = sparse_topk
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
elif sparse_policy == SparsePolicyType.XATTN_BSA:
llm_kwargs["sparse_threshold"] = float(sparse_threshold) / 10.0 # Convert to 0.0-1.0 range
llm_kwargs["sparse_samples_per_chunk"] = sparse_samples
llm = LLM(model_path, **llm_kwargs)
# 2. Generate needle prompt
prompt, expected = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=needle_position,
needle_value=needle_value,
)
# 3. Generate output
sampling_params = SamplingParams(
temperature=0.6, # Moderate temperature
max_tokens=max_new_tokens,
)
outputs = llm.generate([prompt], sampling_params, use_tqdm=True)
# 4. Check result
output_text = outputs[0]["text"]
output_token_ids = outputs[0]["token_ids"]
passed = check_needle_answer(output_text, expected)
if verbose:
print(f"\n{'='*60}")
print(f"Result")
print(f"{'='*60}")
print(f"Expected: {expected}")
print(f"Output tokens ({len(output_token_ids)}): {output_token_ids[:20]}")
print(f"Output: {output_text[:200]}...")
print(f"Status: {'PASSED' if passed else 'FAILED'}")
print(f"{'='*60}\n")
return passed
# ============================================================
# CLI Entry Point
# ============================================================
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Needle-in-haystack test for long context LLM")
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
help="Path to model"
)
parser.add_argument(
"--max-model-len",
type=int,
default=128 * 1024,
help="Maximum model context length"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--num-gpu-blocks",
type=int,
default=2,
help="Number of GPU blocks for CPU offload"
)
parser.add_argument(
"--block-size",
type=int,
default=1024,
help="KV cache block size"
)
parser.add_argument(
"--needle-position",
type=float,
default=0.5,
help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
)
parser.add_argument(
"--needle-value",
type=str,
default="7492",
help="The secret value to hide"
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=32,
help="Maximum tokens to generate"
)
parser.add_argument(
"--enable-offload",
action="store_true",
help="Enable CPU offload (has known bug for long sequences)"
)
parser.add_argument(
"--enable-quest",
action="store_true",
help="Enable Quest sparse attention (decode-only Top-K selection)"
)
parser.add_argument(
"--enable-xattn-bsa",
action="store_true",
help="Enable XAttention BSA sparse attention (prefill-only)"
)
parser.add_argument(
"--sparse-topk",
type=int,
default=8,
help="Top-K blocks for Quest sparse attention"
)
parser.add_argument(
"--sparse-threshold",
type=int,
default=4,
help="Apply sparse only when blocks > threshold (Quest) or attention threshold 0-9 (XAttention BSA)"
)
parser.add_argument(
"--sparse-samples",
type=int,
default=128,
help="Samples per chunk for XAttention BSA estimation"
)
args = parser.parse_args()
passed = run_needle_test(
model_path=args.model,
max_model_len=args.max_model_len,
input_len=args.input_len,
num_gpu_blocks=args.num_gpu_blocks,
block_size=args.block_size,
needle_position=args.needle_position,
needle_value=args.needle_value,
max_new_tokens=args.max_new_tokens,
enable_cpu_offload=args.enable_offload,
enable_quest=args.enable_quest,
enable_xattn_bsa=args.enable_xattn_bsa,
sparse_topk=args.sparse_topk,
sparse_threshold=args.sparse_threshold,
sparse_samples=args.sparse_samples,
verbose=True,
)
if passed:
print("test_needle: PASSED")
else:
print("test_needle: FAILED")
exit(1)

View File

@@ -1,176 +0,0 @@
"""
Needle-in-a-haystack reference test using pure torch + transformers.
This is a reference implementation for comparison with nanovllm.
Uses standard HuggingFace inference (no custom KV cache, no offload).
"""
import os
import argparse
import torch
from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
# Main Test
# ============================================================
def run_needle_test(
model_path: str,
input_len: int,
needle_position: float = 0.5,
needle_value: str = "7492",
max_new_tokens: int = 32,
dtype: str = "auto",
verbose: bool = True,
) -> bool:
"""
Run a needle-in-haystack test using standard transformers inference.
Args:
model_path: Path to model
input_len: Target input sequence length
needle_position: Where to place needle (0.0-1.0)
needle_value: The secret value to find
max_new_tokens: Maximum tokens to generate
dtype: Model dtype ("auto", "float16", "bfloat16")
verbose: Print detailed output
Returns:
True if test passed, False otherwise
"""
if verbose:
print(f"\n{'='*60}")
print(f"Needle-in-Haystack Reference Test (torch + transformers)")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Input length: {input_len}")
print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}")
print(f"Dtype: {dtype}")
print(f"{'='*60}\n")
# 1. Load tokenizer
print("[1/4] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# 2. Generate needle prompt
print("[2/4] Generating needle prompt...")
prompt, expected = generate_needle_prompt(
tokenizer=tokenizer,
target_length=input_len,
needle_position=needle_position,
needle_value=needle_value,
)
# 3. Load model
print("[3/4] Loading model...")
torch_dtype = {
"auto": torch.float16, # default to float16 for custom model
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}.get(dtype, torch.float16)
model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch_dtype)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
# 4. Generate output
print("[4/4] Running inference...")
device = next(model.parameters()).device
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
print(f" Input shape: {input_ids.shape}")
with torch.no_grad():
output_ids = model.generate(
input_ids,
max_new_tokens=max_new_tokens,
temperature=0.6,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
# Decode only the new tokens
new_token_ids = output_ids[0, input_ids.shape[1]:]
output_text = tokenizer.decode(new_token_ids, skip_special_tokens=False)
# 5. Check result
passed = check_needle_answer(output_text, expected)
if verbose:
print(f"\n{'='*60}")
print(f"Result")
print(f"{'='*60}")
print(f"Expected: {expected}")
print(f"Output tokens ({len(new_token_ids)}): {new_token_ids[:20].tolist()}")
print(f"Output: {output_text[:200]}...")
print(f"Status: {'PASSED' if passed else 'FAILED'}")
print(f"{'='*60}\n")
return passed
# ============================================================
# CLI Entry Point
# ============================================================
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Needle-in-haystack reference test (torch + transformers)"
)
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
help="Path to model"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--needle-position",
type=float,
default=0.5,
help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
)
parser.add_argument(
"--needle-value",
type=str,
default="7492",
help="The secret value to hide"
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=32,
help="Maximum tokens to generate"
)
parser.add_argument(
"--dtype",
type=str,
default="auto",
choices=["auto", "float16", "bfloat16"],
help="Model dtype"
)
args = parser.parse_args()
passed = run_needle_test(
model_path=args.model,
input_len=args.input_len,
needle_position=args.needle_position,
needle_value=args.needle_value,
max_new_tokens=args.max_new_tokens,
dtype=args.dtype,
verbose=True,
)
if passed:
print("test_needle_ref: PASSED")
else:
print("test_needle_ref: FAILED")
exit(1)

View File

@@ -1,136 +0,0 @@
"""
Test for QuestPolicy block selection with GQA (Grouped Query Attention).
Demonstrates the key limitation: scores are AVERAGED across heads,
so blocks strongly needed by one head but not others may be dropped.
This is the expected Quest behavior - not a bug.
"""
import torch
from nanovllm.kvcache.sparse import (
create_sparse_policy,
SparsePolicyType,
PolicyContext,
)
# ============================================================
# Test: Per-Head Score Averaging in GQA
# ============================================================
# Determine device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running test on device: {device}")
# Setup: 2 KV heads, 4 query heads (GQA group_size=2)
# topk=2 to make selection competitive
quest = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=2, threshold_blocks=0)
quest.initialize(
num_layers=1,
num_kv_heads=2,
head_dim=4,
num_cpu_blocks=6,
dtype=torch.float32,
device=device, # Metadata stored on GPU
)
metadata = quest.metadata
def set_key(block_id, head_id, values):
"""Set both key_min and key_max to same values for deterministic scoring."""
# Values need to be on the same device as metadata
tensor = torch.tensor(values, device=device)
metadata.key_min[block_id, 0, head_id, :] = tensor
metadata.key_max[block_id, 0, head_id, :] = tensor
# ============================================================
# Design: Different heads want different blocks
# ============================================================
#
# Query = [1,1,1,1] for all heads, so score = sum(key values)
#
# Block | Head 0 | Head 1 | Average | Result
# ------|--------|--------|---------|--------
# 0 | +4 | -4 | 0 | Head0 wants, Head1 doesn't → DROPPED
# 1 | -4 | +4 | 0 | Head1 wants, Head0 doesn't → DROPPED
# 2 | +4 | +4 | +4 | Both want → SELECTED (rank 1)
# 3 | +3 | +3 | +3 | Both want → SELECTED (rank 2)
# 4 | +4 | 0 | +2 | Head0 strongly wants, Head1 neutral → rank 3
# 5 | 0 | +4 | +2 | Head1 strongly wants, Head0 neutral → rank 3
# Block 0: Head 0 strongly wants, Head 1 strongly rejects
set_key(0, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
set_key(0, 1, [-1.0, -1.0, -1.0, -1.0]) # head1: -4
# Block 1: Head 1 strongly wants, Head 0 strongly rejects
set_key(1, 0, [-1.0, -1.0, -1.0, -1.0]) # head0: -4
set_key(1, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
# Block 2: Both heads want equally (highest average)
set_key(2, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
set_key(2, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
# Block 3: Both heads want moderately
set_key(3, 0, [0.75, 0.75, 0.75, 0.75]) # head0: +3
set_key(3, 1, [0.75, 0.75, 0.75, 0.75]) # head1: +3
# Block 4: Head 0 strongly wants, Head 1 neutral
set_key(4, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
set_key(4, 1, [0.0, 0.0, 0.0, 0.0]) # head1: 0
# Block 5: Head 1 strongly wants, Head 0 neutral
set_key(5, 0, [0.0, 0.0, 0.0, 0.0]) # head0: 0
set_key(5, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
# ============================================================
# Run selection
# ============================================================
# Query on same device as metadata
query = torch.ones(1, 4, 4, device=device) # GQA: 4 query heads → 2 KV heads
ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
layer_id=0,
query=query,
is_prefill=False,
block_size=1024,
total_kv_len=6144,
)
available = list(range(6))
selected = quest.select_blocks(available, ctx)
# ============================================================
# Verify: Averaging behavior
# ============================================================
# topk=2, so only blocks 2 (+4 avg) and 3 (+3 avg) should be selected
assert len(selected) == 2, f"Expected 2 blocks, got {len(selected)}"
assert selected == [2, 3], f"Expected [2, 3], got {selected}"
# Key insight: blocks 0 and 1 have score +4 for ONE head,
# but they cancel out due to averaging with the other head's -4
assert 0 not in selected, "Block 0 should NOT be selected (head scores cancel out)"
assert 1 not in selected, "Block 1 should NOT be selected (head scores cancel out)"
# Blocks 4 and 5 have +4 for one head, 0 for other → avg=+2
# But +2 < +3 (block 3), so they don't make the top-2
assert 4 not in selected, "Block 4 avg=+2 < block 3 avg=+3"
assert 5 not in selected, "Block 5 avg=+2 < block 3 avg=+3"
print("✓ Block 2 selected: both heads want it (+4, +4) → avg=+4")
print("✓ Block 3 selected: both heads want it (+3, +3) → avg=+3")
print("✓ Block 0 NOT selected: head0=+4, head1=-4 → avg=0 (cancel out)")
print("✓ Block 1 NOT selected: head0=-4, head1=+4 → avg=0 (cancel out)")
print("✓ Block 4 NOT selected: head0=+4, head1=0 → avg=+2 (lower rank)")
print("✓ Block 5 NOT selected: head0=0, head1=+4 → avg=+2 (lower rank)")
# Verify metadata is on correct device
assert metadata.key_min.device.type == device.type, f"key_min on wrong device: {metadata.key_min.device}"
assert metadata.key_max.device.type == device.type, f"key_max on wrong device: {metadata.key_max.device}"
print(f"✓ Metadata stored on {device.type.upper()}")
print("\ntest_quest_policy: PASSED")

View File

@@ -1,199 +0,0 @@
"""
Sequential inference test for LLM.
Tests: After completing one prompt, the system can correctly handle
a second prompt with a clean state (first prompt's KV cache deallocated).
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import argparse
from nanovllm import LLM, SamplingParams
from utils import generate_needle_prompt, check_needle_answer
def run_sequential_test(
model_path: str,
max_model_len: int,
input_len: int,
num_gpu_blocks: int = 4,
block_size: int = 1024,
enable_cpu_offload: bool = False,
verbose: bool = True,
) -> bool:
"""
Run sequential inference test with two different prompts.
Each prompt has a different needle value. Both must be retrieved correctly.
"""
if verbose:
print(f"\n{'='*60}")
print(f"Sequential Inference Test")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Max model len: {max_model_len}")
print(f"Input length: {input_len}")
print(f"Block size: {block_size}")
print(f"CPU offload: {enable_cpu_offload}")
print(f"{'='*60}\n")
# Initialize LLM once
llm_kwargs = {
"enforce_eager": True,
"max_model_len": max_model_len,
"max_num_batched_tokens": max_model_len,
"enable_cpu_offload": enable_cpu_offload,
"kvcache_block_size": block_size,
}
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm = LLM(model_path, **llm_kwargs)
sampling_params = SamplingParams(
temperature=0.6,
max_tokens=32,
)
# ============================================================
# Test 1: First prompt with needle value "1234"
# ============================================================
needle_value_1 = "1234"
if verbose:
print(f"\n[Test 1] Generating prompt with needle value: {needle_value_1}")
prompt_1, expected_1 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_1,
)
outputs_1 = llm.generate([prompt_1], sampling_params, use_tqdm=True)
output_text_1 = outputs_1[0]["text"]
passed_1 = check_needle_answer(output_text_1, expected_1)
if verbose:
print(f" Expected: {expected_1}")
print(f" Output: {output_text_1[:100]}...")
print(f" Status: {'PASSED' if passed_1 else 'FAILED'}")
# ============================================================
# Test 2: Second prompt with needle value "5678"
# ============================================================
needle_value_2 = "5678"
if verbose:
print(f"\n[Test 2] Generating prompt with needle value: {needle_value_2}")
prompt_2, expected_2 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_2,
)
outputs_2 = llm.generate([prompt_2], sampling_params, use_tqdm=True)
output_text_2 = outputs_2[0]["text"]
passed_2 = check_needle_answer(output_text_2, expected_2)
if verbose:
print(f" Expected: {expected_2}")
print(f" Output: {output_text_2[:100]}...")
print(f" Status: {'PASSED' if passed_2 else 'FAILED'}")
# ============================================================
# Test 3: Third prompt - repeat first needle to ensure no cross-contamination
# ============================================================
needle_value_3 = "9999"
if verbose:
print(f"\n[Test 3] Generating prompt with needle value: {needle_value_3}")
prompt_3, expected_3 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_3,
)
outputs_3 = llm.generate([prompt_3], sampling_params, use_tqdm=True)
output_text_3 = outputs_3[0]["text"]
passed_3 = check_needle_answer(output_text_3, expected_3)
if verbose:
print(f" Expected: {expected_3}")
print(f" Output: {output_text_3[:100]}...")
print(f" Status: {'PASSED' if passed_3 else 'FAILED'}")
# ============================================================
# Summary
# ============================================================
all_passed = passed_1 and passed_2 and passed_3
if verbose:
print(f"\n{'='*60}")
print(f"Summary")
print(f"{'='*60}")
print(f"Test 1 (needle={needle_value_1}): {'PASSED' if passed_1 else 'FAILED'}")
print(f"Test 2 (needle={needle_value_2}): {'PASSED' if passed_2 else 'FAILED'}")
print(f"Test 3 (needle={needle_value_3}): {'PASSED' if passed_3 else 'FAILED'}")
print(f"Overall: {'PASSED' if all_passed else 'FAILED'}")
print(f"{'='*60}\n")
return all_passed
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Sequential inference test")
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-0.6B/"),
help="Path to model"
)
parser.add_argument(
"--max-model-len",
type=int,
default=36 * 1024,
help="Maximum model context length"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--num-gpu-blocks",
type=int,
default=2,
help="Number of GPU blocks for CPU offload"
)
parser.add_argument(
"--block-size",
type=int,
default=1024,
help="KV cache block size"
)
parser.add_argument(
"--enable-offload",
action="store_true",
help="Enable CPU offload"
)
args = parser.parse_args()
passed = run_sequential_test(
model_path=args.model,
max_model_len=args.max_model_len,
input_len=args.input_len,
num_gpu_blocks=args.num_gpu_blocks,
block_size=args.block_size,
enable_cpu_offload=args.enable_offload,
verbose=True,
)
if passed:
print("test_sequential: PASSED")
else:
print("test_sequential: FAILED")
exit(1)

View File

@@ -1,334 +0,0 @@
"""
Test XAttention + BSA with RULER benchmark data.
Tests XAttention sparse attention correctness using RULER NIAH task.
Attention methods:
- Prefill: XAttention + BSA (sparse) or FlashAttention (dense)
- Decode: FlashAttention (always, since q_len=1)
Usage (in compass conda env with BSA available):
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_xattn_bsa.py --model ~/models/Llama-3.1-8B-Instruct
# Test with XAttention + BSA for prefill (default)
python tests/test_xattn_bsa.py --prefill-method xattn
# Test with FlashAttention for prefill (baseline)
python tests/test_xattn_bsa.py --prefill-method flash
# Test specific sample(s)
python tests/test_xattn_bsa.py --sample-id 0
python tests/test_xattn_bsa.py --sample-ids 0,1,2
Note: Compatible with transformers 4.53+ (handles both old `past_key_value`
and new `past_key_values` API).
"""
import argparse
import json
import sys
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
from nanovllm.ops.xattn import xattn_estimate
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
# ============================================================
# XAttention + BSA Functions
# ============================================================
def expand_kv_for_gqa(key_states, value_states, num_heads):
"""Expand KV for Grouped Query Attention."""
num_kv_heads = key_states.shape[1]
if num_heads == num_kv_heads:
return key_states, value_states
num_groups = num_heads // num_kv_heads
return key_states.repeat_interleave(num_groups, dim=1), value_states.repeat_interleave(num_groups, dim=1)
def flash_attention_forward(query_states, key_states, value_states, is_causal=True):
"""Standard FlashAttention."""
from flash_attn import flash_attn_func
q = query_states.transpose(1, 2)
k = key_states.transpose(1, 2)
v = value_states.transpose(1, 2)
return flash_attn_func(q, k, v, causal=is_causal).transpose(1, 2)
def xattn_bsa_forward(query_states, key_states, value_states, threshold=0.9):
"""XAttention + BSA sparse attention."""
from block_sparse_attn import block_sparse_attn_func
batch_size, num_heads, q_len, head_dim = query_states.shape
k_len = key_states.shape[2]
_, mask = xattn_estimate(
query_states, key_states,
chunk_size=16384, block_size=128, threshold=threshold,
use_triton=True, causal=True,
)
q_block_num = (q_len + 127) // 128
k_block_num = (k_len + 127) // 128
q = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim)
k = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
v = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
__import__('pdb').set_trace()
output = block_sparse_attn_func(
q, k, v,
torch.tensor([0, q_len], dtype=torch.int32, device=q.device),
torch.tensor([0, k_len], dtype=torch.int32, device=k.device),
torch.ones(num_heads, dtype=torch.int32, device=q.device),
None,
mask[:, :, :q_block_num, :k_block_num].contiguous(),
q_len, k_len,
p_dropout=0.0, deterministic=True, is_causal=True,
)
return output.reshape(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
DEBUG = False # Set to True to enable debugging
def create_patched_forward(prefill_method="xattn", threshold=0.9):
"""Create patched forward with configurable prefill method.
Args:
prefill_method: "xattn" for XAttention + BSA (sparse), "flash" for FlashAttention (dense)
threshold: XAttention threshold for block selection (only used when prefill_method="xattn")
Note:
- Prefill (q_len > 1): Uses specified prefill_method
- Decode (q_len = 1): Always uses FlashAttention (no sparse needed for single query)
"""
call_count = [0] # Mutable to track calls across layers
def patched_forward(
self,
hidden_states,
position_embeddings=None,
attention_mask=None,
past_key_value=None, # Old API (transformers < 4.57)
past_key_values=None, # New API (transformers >= 4.57)
cache_position=None,
**kwargs
):
# Handle both old and new transformers API
kv_cache = past_key_values if past_key_values is not None else past_key_value
bsz, q_len, _ = hidden_states.size()
num_heads = self.config.num_attention_heads
num_kv_heads = self.config.num_key_value_heads
head_dim = self.head_dim
# Compute Q, K, V projections
query_states = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
# Apply rotary position embedding
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# Handle KV cache
if kv_cache is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = kv_cache.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# Expand KV for GQA
key_states_exp, value_states_exp = expand_kv_for_gqa(key_states, value_states, num_heads)
# Debug output
if DEBUG and self.layer_idx == 0:
call_count[0] += 1
if call_count[0] <= 5:
phase = "prefill" if q_len > 1 else "decode"
print(f"\n[DEBUG] Layer {self.layer_idx}, call {call_count[0]} ({phase}): q_len={q_len}, k_len={key_states_exp.shape[2]}")
print(f" kv_cache is None: {kv_cache is None}")
# Choose attention method:
# - Prefill (q_len > 1): Use prefill_method (xattn or flash)
# - Decode (q_len = 1): Always use FlashAttention
is_prefill = q_len > 1
if is_prefill and prefill_method == "xattn":
# Prefill with XAttention + BSA (sparse)
attn_output = xattn_bsa_forward(query_states, key_states_exp, value_states_exp, threshold)
else:
# Prefill with FlashAttention (dense) OR Decode (always FlashAttention)
# Note: For decode (q_len=1), causal=False since single query attends to all KV
attn_output = flash_attention_forward(query_states, key_states_exp, value_states_exp, is_causal=is_prefill)
attn_output = self.o_proj(attn_output.transpose(1, 2).reshape(bsz, q_len, -1))
return attn_output, None
return patched_forward
# ============================================================
# Data & Evaluation
# ============================================================
def load_samples(filepath, indices=None):
"""Load samples from JSONL file."""
samples = []
with open(filepath) as f:
for i, line in enumerate(f):
if indices is None or i in indices:
sample = json.loads(line)
sample["_idx"] = i
samples.append(sample)
return samples
def string_match_all(output_text, expected_list):
"""RULER metric: fraction of expected values found in output."""
output_lower = output_text.lower().replace('\n', ' ')
if not expected_list:
return 1.0
return sum(1.0 if exp.strip().lower() in output_lower else 0.0 for exp in expected_list) / len(expected_list)
# ============================================================
# Test
# ============================================================
def test_with_ruler_data(model_path, data_file, sample_ids, prefill_method="xattn", threshold=0.9, max_new_tokens=50):
"""Test attention methods using RULER data.
Args:
prefill_method: "xattn" for XAttention + BSA, "flash" for FlashAttention
"""
prefill_desc = "XAttention + BSA (sparse)" if prefill_method == "xattn" else "FlashAttention (dense)"
print("=" * 60)
print("RULER NIAH Attention Test")
print("=" * 60)
print(f"Data: {data_file}")
print(f"Samples: {sample_ids}")
print(f"Prefill method: {prefill_desc}")
print(f"Decode method: FlashAttention (always)")
if prefill_method == "xattn":
print(f"XAttention threshold: {threshold}")
samples = load_samples(Path(data_file), set(sample_ids) if sample_ids else None)
if not samples:
print("No samples found!")
return False
print(f"Loaded {len(samples)} samples")
# Load model
print(f"\nLoading model: {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=torch.float16, device_map="cuda",
attn_implementation="eager", # Will be patched
)
model.eval()
# Patch all layers
print(f"Patching attention layers...")
print(f" - Prefill: {prefill_desc}")
print(f" - Decode: FlashAttention")
for idx, layer in enumerate(model.model.layers):
layer.self_attn.layer_idx = idx # Ensure layer_idx is set
layer.self_attn.forward = create_patched_forward(prefill_method, threshold).__get__(
layer.self_attn, type(layer.self_attn)
)
total_score = 0.0
results = []
for sample in samples:
idx = sample["_idx"]
prompt = sample["input"]
expected = sample["outputs"]
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
num_tokens = inputs["input_ids"].shape[1]
print(f"\n--- Sample {idx} ({num_tokens} tokens) ---")
print(f"Expected: {expected}")
with torch.no_grad():
output = model.generate(
inputs["input_ids"],
max_new_tokens=max_new_tokens,
do_sample=False,
pad_token_id=tokenizer.eos_token_id,
)
output_text = tokenizer.decode(output[0][num_tokens:], skip_special_tokens=True)
score = string_match_all(output_text, expected)
total_score += score
status = "✓ PASS" if score >= 0.5 else "✗ FAIL"
print(f"Output: '{output_text[:100]}...'")
print(f"Result: {status} (score={score:.2f})")
results.append({"idx": idx, "score": score, "passed": score >= 0.5})
avg_score = total_score / len(samples)
passed = sum(1 for r in results if r["passed"])
print(f"\n{'='*60}")
print(f"Results: {passed}/{len(samples)} passed, avg_score={avg_score:.3f}")
print(f"{'='*60}")
return avg_score >= 0.5
def main():
parser = argparse.ArgumentParser(
description="Test XAttention + BSA vs FlashAttention for prefill using RULER NIAH benchmark"
)
parser.add_argument("--model", default="~/models/Llama-3.1-8B-Instruct")
parser.add_argument("--data-file", default="tests/data/ruler_32k/niah_single_1/validation.jsonl")
parser.add_argument("--sample-id", type=int, default=None, help="Test single sample by index")
parser.add_argument("--sample-ids", type=str, default="", help="Test multiple samples (comma-separated)")
parser.add_argument("--prefill-method", choices=["xattn", "flash"], default="xattn",
help="Prefill attention method: xattn (XAttention+BSA sparse) or flash (FlashAttention dense)")
parser.add_argument("--threshold", type=float, default=0.9, help="XAttention threshold (only for --prefill-method xattn)")
parser.add_argument("--max-new-tokens", type=int, default=50)
# Keep old option for backwards compatibility
parser.add_argument("--no-xattn", action="store_true", help="[Deprecated] Use --prefill-method flash instead")
args = parser.parse_args()
model_path = args.model.replace("~", "/home/zijie")
# Handle deprecated --no-xattn option
prefill_method = args.prefill_method
if args.no_xattn:
prefill_method = "flash"
print("Warning: --no-xattn is deprecated, use --prefill-method flash instead")
if args.sample_id is not None:
sample_ids = [args.sample_id]
elif args.sample_ids:
sample_ids = [int(x) for x in args.sample_ids.split(",")]
else:
sample_ids = [0]
# Check BSA availability if using xattn
if prefill_method == "xattn":
try:
from block_sparse_attn import block_sparse_attn_func
print("✓ BSA (Block Sparse Attention) available")
except ImportError:
print("✗ BSA not found. Install block_sparse_attn or use --prefill-method flash")
sys.exit(1)
if test_with_ruler_data(model_path, args.data_file, sample_ids, prefill_method, args.threshold, args.max_new_tokens):
print("\ntest_xattn_bsa: PASSED")
else:
print("\ntest_xattn_bsa: FAILED")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,259 +0,0 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked
Verify that chunked estimation with EXTERNAL chunking produces the same mask as standard estimation.
Uses real QKV data captured from model inference.
"""
import sys
import os
import torch
import warnings
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
# ============================================================
# Configuration
# ============================================================
BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096
# Default QKV data directory (relative to project root)
DEFAULT_QKV_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "results", "kvcache")
# ============================================================
# Utility Functions
# ============================================================
def load_qkv(path):
"""Load saved QKV data."""
data = torch.load(path, map_location="cpu", weights_only=False)
print(f"Loaded: {path}")
print(f" Query shape: {data['query'].shape}")
print(f" Key shape: {data['key'].shape}")
print(f" Layer: {data['layer_id']}, Density: {data['density']:.2%}")
return data
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
"""Compare two masks and report differences."""
if mask1.shape != mask2.shape:
print(f"Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
return False
diff = (mask1 != mask2).sum().item()
total = mask1.numel()
match_rate = (total - diff) / total * 100
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
if diff > 0:
diff_indices = torch.where(mask1 != mask2)
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
return diff == 0
def run_chunked_externally(query, key, q_start_pos, block_size, stride, threshold, chunk_size):
"""
Run xattn_estimate_chunked with EXTERNAL chunking.
This simulates how chunked prefill should be used in practice.
"""
batch_size, num_heads, q_len, head_dim = query.shape
_, _, k_len, _ = key.shape
q_block_num = (q_len + block_size - 1) // block_size
k_block_num = (k_len + block_size - 1) // block_size
# If Q fits in one chunk, call directly
if q_len <= chunk_size:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return xattn_estimate_chunked(
query, key,
q_start_pos=q_start_pos,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# External chunking: split Q and call for each chunk
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
print(f" External chunking: {num_q_chunks} chunks")
combined_attn_sum = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=query.dtype, device=query.device
)
combined_mask = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=torch.bool, device=query.device
)
q_block_offset = 0
for q_chunk_idx in range(num_q_chunks):
q_chunk_start = q_chunk_idx * chunk_size
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
# For causal attention, K accumulates up to current Q position
k_end = q_start_pos + q_chunk_end
k_chunk = key[:, :, :k_end, :]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
q_chunk, k_chunk,
q_start_pos=q_start_pos + q_chunk_start,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# Place chunk results into combined output
chunk_q_blocks = mask_chunk.shape[2]
chunk_k_blocks = mask_chunk.shape[3]
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
q_block_offset += chunk_q_blocks
return combined_attn_sum, combined_mask
def test_single_qkv(qkv_path):
"""Test a single QKV file."""
data = load_qkv(qkv_path)
query = data["query"].cuda().to(torch.bfloat16)
key = data["key"].cuda().to(torch.bfloat16)
seq_len = query.shape[2]
print(f"\nTesting with seq_len={seq_len}")
print("=" * 60)
# Run standard xattn_estimate
print("[1] Running standard xattn_estimate...")
try:
attn_sum_std, mask_std = xattn_estimate(
query, key,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
use_triton=True,
)
print(f" mask shape: {mask_std.shape}, density: {mask_std.float().mean().item():.4f}")
except Exception as e:
print(f" ERROR: {e}")
import traceback
traceback.print_exc()
return False
# Run chunked xattn_estimate with EXTERNAL chunking
print("[2] Running chunked xattn_estimate (external chunking)...")
try:
attn_sum_chunked, mask_chunked = run_chunked_externally(
query, key,
q_start_pos=0,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
)
print(f" mask shape: {mask_chunked.shape}, density: {mask_chunked.float().mean().item():.4f}")
except Exception as e:
print(f" ERROR: {e}")
import traceback
traceback.print_exc()
return False
# Compare results
print("[3] Comparing results...")
chunked_q_blocks = mask_chunked.shape[2]
chunked_k_blocks = mask_chunked.shape[3]
# Extract comparable region from standard mask
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
# Compare masks
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
# Compare attn_sums
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
print(f" Attn sum max diff: {attn_diff:.6f}")
else:
print(f" Attn sum shape mismatch")
# Clean up GPU memory
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
torch.cuda.empty_cache()
return masks_match
# ============================================================
# Main Test
# ============================================================
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Test xattn_estimate vs xattn_estimate_chunked")
parser.add_argument("--qkv-dir", type=str, default=DEFAULT_QKV_DIR,
help="Directory containing QKV files")
args = parser.parse_args()
# QKV files to test
qkv_files = [
os.path.join(args.qkv_dir, "qkv_3688.pt"), # ~4K
os.path.join(args.qkv_dir, "qkv_7888.pt"), # ~8K
os.path.join(args.qkv_dir, "qkv_15685.pt"), # ~16K
os.path.join(args.qkv_dir, "qkv_32485.pt"), # ~32K
os.path.join(args.qkv_dir, "qkv_64891.pt"), # ~64K
]
available_files = [p for p in qkv_files if os.path.exists(p)]
if not available_files:
print(f"No QKV file found in {args.qkv_dir}.")
print(f"Expected files: qkv_3688.pt, qkv_7888.pt, qkv_15685.pt, qkv_32485.pt, qkv_64891.pt")
sys.exit(1)
print(f"Found {len(available_files)} QKV files to test")
print(f"Testing EXTERNAL chunking (chunk_size={CHUNK_SIZE})")
print(f"Using Triton kernels")
all_passed = True
results = []
for qkv_path in available_files:
passed = test_single_qkv(qkv_path)
seq_len = int(os.path.basename(qkv_path).replace("qkv_", "").replace(".pt", ""))
results.append((seq_len, passed))
if not passed:
all_passed = False
# Summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for seq_len, passed in results:
status = "PASSED" if passed else "FAILED"
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
print(f" seq_len={seq_len} ({chunks} chunk{'s' if chunks > 1 else ''}): {status}")
print("=" * 60)
if all_passed:
print("test_xattn_chunked: PASSED")
sys.exit(0)
else:
print("test_xattn_chunked: FAILED")
sys.exit(1)

View File

@@ -1,244 +0,0 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked
Verify that chunked estimation with EXTERNAL chunking produces the same mask
as standard estimation. This ensures the chunked version can be used in
chunked prefill scenarios without accuracy loss.
Usage:
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_chunked.py
"""
import sys
import traceback
import torch
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
# ============================================================
# Configuration
# ============================================================
# Configuration for xattn_estimate_chunked consistency test.
# Key requirements for 100% match:
# 1. Use matching chunk_size for both standard and chunked versions
# 2. Use same random seed for reproducibility
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
# floating point precision in cumulative sum calculations.
BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096 # External chunking size
# Test sequence lengths
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]
# ============================================================
# Utility Functions
# ============================================================
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
"""Compare two masks and report differences."""
if mask1.shape != mask2.shape:
print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
return False
diff = (mask1 != mask2).sum().item()
total = mask1.numel()
match_rate = (total - diff) / total * 100
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
if diff > 0:
diff_indices = torch.where(mask1 != mask2)
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
return diff == 0
def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
"""
Run xattn_estimate_chunked with EXTERNAL chunking.
This simulates how chunked prefill should be used in practice.
"""
batch_size, num_heads, q_len, head_dim = query.shape
_, _, k_len, _ = key.shape
q_block_num = (q_len + block_size - 1) // block_size
k_block_num = (k_len + block_size - 1) // block_size
# If Q fits in one chunk, call directly
if q_len <= chunk_size:
return xattn_estimate_chunked(
query, key,
q_start_pos=0,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# External chunking: split Q and call for each chunk
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
print(f" External chunking: {num_q_chunks} chunks")
combined_attn_sum = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=query.dtype, device=query.device
)
combined_mask = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=torch.bool, device=query.device
)
q_block_offset = 0
for q_chunk_idx in range(num_q_chunks):
q_chunk_start = q_chunk_idx * chunk_size
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
# For causal attention, K accumulates up to current Q position
# q_start_pos=0 means Q starts at position 0 in the full sequence
# K is [0, q_chunk_end) for causal attention
k_end = q_chunk_end
k_chunk = key[:, :, :k_end, :]
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
q_chunk, k_chunk,
q_start_pos=q_chunk_start,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# Place chunk results into combined output
chunk_q_blocks = mask_chunk.shape[2]
chunk_k_blocks = mask_chunk.shape[3]
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
q_block_offset += chunk_q_blocks
return combined_attn_sum, combined_mask
def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
"""Test a single sequence length."""
print(f"\nTesting seq_len={seq_len}")
print("=" * 60)
# Generate random Q/K
query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
# Run standard xattn_estimate
print("[1] Running standard xattn_estimate...")
try:
attn_sum_std, mask_std = xattn_estimate(
query, key,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
use_triton=True,
causal=True,
)
density_std = mask_std.float().mean().item()
print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}")
except Exception as e:
print(f" ERROR: {e}")
traceback.print_exc()
return False
# Run chunked xattn_estimate with EXTERNAL chunking
print("[2] Running chunked xattn_estimate (external chunking)...")
try:
attn_sum_chunked, mask_chunked = run_chunked_externally(
query, key,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
)
density_chunked = mask_chunked.float().mean().item()
print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
except Exception as e:
print(f" ERROR: {e}")
traceback.print_exc()
return False
# Compare results
print("[3] Comparing results...")
chunked_q_blocks = mask_chunked.shape[2]
chunked_k_blocks = mask_chunked.shape[3]
# Extract comparable region from standard mask
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
# Compare masks
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
# Compare attn_sums
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
print(f" Attn sum max diff: {attn_diff:.6f}")
else:
print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")
# Clean up GPU memory
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
torch.cuda.empty_cache()
return masks_match
# ============================================================
# Main Test
# ============================================================
if __name__ == "__main__":
print("XAttention Chunked vs Standard Test")
print("=" * 60)
print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
print(f"External chunk_size={CHUNK_SIZE}")
print()
# Check CUDA availability
if not torch.cuda.is_available():
print("CUDA not available!")
sys.exit(1)
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
print("✓ xattn_estimate imported")
print("✓ xattn_estimate_chunked imported")
# Run tests
all_passed = True
results = []
for seq_len in TEST_SEQ_LENS:
passed = test_single_seq_len(seq_len)
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
results.append((seq_len, chunks, passed))
if not passed:
all_passed = False
# Summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for seq_len, chunks, passed in results:
status = "PASSED" if passed else "FAILED"
print(f" seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")
print("=" * 60)
if all_passed:
print("ALL TESTS PASSED!")
sys.exit(0)
else:
print("SOME TESTS FAILED!")
sys.exit(1)

View File

@@ -1,132 +0,0 @@
"""
Test: XAttention Triton kernels
演示 XAttention 的两个核心 Triton kernel:
1. flat_group_gemm_fuse_reshape: 计算 stride reshape 后的 attention scores (反对角线求和)
2. softmax_fuse_block_sum: 对 attention scores 做 softmax 后按 block 求和
数据流:
Q [batch, heads, q_len, head_dim]
K [batch, heads, kv_len, head_dim]
↓ flat_group_gemm_fuse_reshape
attn_scores [batch, heads, q_len/stride, kv_len/stride]
↓ softmax_fuse_block_sum
block_sums [batch, heads, q_blocks, k_blocks]
"""
import torch
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
# ============================================================
# 参数配置
# ============================================================
# Triton 约束: q_len >= stride * BLOCK_M, kv_len >= stride * BLOCK_N
# A100: BLOCK_M = BLOCK_N = 128, 所以 min = 4 * 128 = 512
# RTX 3090: BLOCK_M = BLOCK_N = 64, 所以 min = 4 * 64 = 256
q_len = 512
kv_len = 2048
head_dim = 128
stride = 4
block_size = 128 # softmax block size (in reshaped space)
segment_size = 128 # Triton kernel 要求 segment_size >= block_size
# ============================================================
# 构造输入: 偶数位置=1, 奇数位置=2
# ============================================================
Q = torch.zeros(1, 1, q_len, head_dim, dtype=torch.bfloat16).cuda()
K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda()
for i in range(q_len):
if i % 2 == 0:
Q[0, 0, i, :] = 1 * (i // stride + 1)
else:
Q[0, 0, i, :] = 2 * (i // stride + 1)
for i in range(kv_len):
if i % 2 == 0:
K[0, 0, i, :] = 1
else:
K[0, 0, i, :] = 2
# ============================================================
# Step 1: flat_group_gemm_fuse_reshape (chunked along K)
# ============================================================
q_reshaped_len = q_len // stride # 128
kv_reshaped_len = kv_len // stride # 512
# 将 K 沿着长度维度分成多个 chunk
k_chunk_size = 512 # 每个 chunk 512 tokens
num_k_chunks = kv_len // k_chunk_size # 4 chunks
attn_scores_list = []
for k_chunk_idx in range(num_k_chunks):
k_start = k_chunk_idx * k_chunk_size
k_end = k_start + k_chunk_size
K_chunk = K[:, :, k_start:k_end, :] # [1, 1, k_chunk_size, head_dim]
# 对每个 K chunk 调用 flat_group_gemm_fuse_reshape
# 输出: [batch, heads, q_len/stride, k_chunk_size/stride]
attn_chunk = flat_group_gemm_fuse_reshape(
Q, K_chunk, stride,
chunk_start=0,
chunk_end=q_reshaped_len,
is_causal=True
)
__import__('pdb').set_trace()
attn_scores_list.append(attn_chunk)
# 拼接所有 K chunks 的结果
# 每个 chunk: [1, 1, q_reshaped_len, k_chunk_size/stride]
# 拼接后: [1, 1, q_reshaped_len, kv_reshaped_len]
attn_scores = torch.cat(attn_scores_list, dim=-1)
# 验证 shape: [batch, heads, q_len/stride, kv_len/stride]
assert attn_scores.shape == (1, 1, q_reshaped_len, kv_reshaped_len), \
f"shape mismatch: {attn_scores.shape} != (1, 1, {q_reshaped_len}, {kv_reshaped_len})"
# 验证: 反对角线求和
# 每个 stride x stride 块的反对角线: Q[奇]*K[偶] + Q[偶]*K[奇] = 2*1 + 1*2 = 4
# 反对角线有 stride/2 对,再乘以 head_dim
expected_gemm = (2*1 + 1*2) * (stride // 2) * head_dim
actual_gemm = attn_scores[0, 0, 0, 0].item()
assert actual_gemm == expected_gemm, f"flat_group_gemm: {actual_gemm} != {expected_gemm}"
# ============================================================
# Step 2: softmax_fuse_block_sum
# ============================================================
scale = 1.4426950408889634 # log2(e) for exp2
block_sums = softmax_fuse_block_sum(
attn_scores,
block_size,
segment_size,
chunk_start=0,
chunk_end=q_reshaped_len,
real_q_len=q_reshaped_len,
scale=scale,
is_causal=False
)
# 验证 shape: [batch, heads, q_blocks, k_blocks]
q_blocks = q_reshaped_len // block_size # 128 / 128 = 1
k_blocks = kv_reshaped_len // block_size # 512 / 128 = 4
assert block_sums.shape == (1, 1, q_blocks, k_blocks), \
f"shape mismatch: {block_sums.shape} != (1, 1, {q_blocks}, {k_blocks})"
# 验证: 每个 block 的 softmax 结果求和
# 所有 attn_scores 相同 → softmax 均匀分布
# 每行对一个 K block 的贡献 = block_size / kv_reshaped_len
# 每个 Q block 有 block_size 行
# block_sum = block_size * (block_size / kv_reshaped_len)
expected_sum = block_size * block_size / kv_reshaped_len
actual_sum = block_sums[0, 0, 0, 0].item()
assert actual_sum == expected_sum, f"softmax_fuse_block_sum: {actual_sum} != {expected_sum}"
print("test_xattn_kernels: PASSED")

View File

@@ -1,246 +0,0 @@
"""
Test: 批量验证 xattn_estimate 与 KV chunking kernels 的一致性
测试 results/kvcache 下所有保存的 QKV 数据
Usage:
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_xattn_kv_chunking_batch.py
"""
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
import os
import glob
import torch
import math
from nanovllm.ops.xattn import (
xattn_estimate,
flat_group_gemm_fuse_reshape,
softmax_compute_partial_stats,
softmax_normalize_and_block_sum,
merge_softmax_stats,
find_blocks_chunked,
)
# ============================================================
# 参数配置
# ============================================================
DATA_DIR = "/home/zijie/Code/nano-vllm/results/kvcache"
BSA_BLOCK_SIZE = 128
CHUNK_SIZE = 16384
device = "cuda"
def test_single_file(data_file: str) -> dict:
"""测试单个 kvcache 文件"""
data = torch.load(data_file, map_location="cpu")
Q = data["query"].to(device)
K = data["key"].to(device)
batch_size, num_heads, seq_len, head_dim = Q.shape
STRIDE = data["stride"]
THRESHOLD = data["threshold"][0].item() if isinstance(data["threshold"], torch.Tensor) else data["threshold"]
# ========== xattn_estimate API ==========
attn_sums_api, mask_api = xattn_estimate(
Q, K,
block_size=BSA_BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
causal=True,
)
q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
mask_api_valid = mask_api[:, :, :q_blocks, :k_blocks]
causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, device=device, dtype=torch.bool))
total_api = causal_mask.sum().item() * batch_size * num_heads
selected_api = (mask_api_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_api = selected_api / total_api
# ========== 三阶段 KV Chunking ==========
k_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
q_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
q_chunk_num = (seq_len + q_num_to_pad) // CHUNK_SIZE
kv_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE
k_block_num = (seq_len + k_num_to_pad) // BSA_BLOCK_SIZE
q_block_num = (seq_len + q_num_to_pad) // BSA_BLOCK_SIZE
reshaped_chunk_size = CHUNK_SIZE // STRIDE
reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
k_reshaped_seq_len = (seq_len + k_num_to_pad) // STRIDE
k_reshaped_num_to_pad = k_num_to_pad // STRIDE
num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size
kv_reshaped_chunk_size = CHUNK_SIZE // STRIDE
if k_num_to_pad > 0:
K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0)
else:
K_padded = K
if q_num_to_pad > 0:
Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, q_num_to_pad), value=0)
else:
Q_padded = Q
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm
simple_mask_list = []
for q_chunk_idx in range(q_chunk_num):
q_start = q_chunk_idx * reshaped_chunk_size * STRIDE
q_end = q_start + reshaped_chunk_size * STRIDE
Q_chunk = Q_padded[:, :, q_start:q_end, :]
chunk_start = (k_block_num - q_block_num) * reshaped_block_size + q_chunk_idx * reshaped_chunk_size
chunk_end = chunk_start + reshaped_chunk_size
m_chunks = []
l_chunks = []
attn_weights_chunks = []
for kv_chunk_idx in range(kv_chunk_num):
kv_start = kv_chunk_idx * CHUNK_SIZE
kv_end = kv_start + CHUNK_SIZE
K_chunk = K_padded[:, :, kv_start:kv_end, :]
kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
attn_weights_kv = flat_group_gemm_fuse_reshape(
Q_chunk, K_chunk, STRIDE,
chunk_start=chunk_start,
chunk_end=chunk_end,
is_causal=False,
)
attn_weights_chunks.append(attn_weights_kv)
m_partial, l_partial = softmax_compute_partial_stats(
attn_weights_kv,
reshaped_block_size,
min(4096, reshaped_block_size),
scale,
chunk_start=chunk_start,
kv_offset=kv_offset_reshaped,
is_causal=True,
)
m_chunks.append(m_partial)
l_chunks.append(l_partial)
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
attn_sum_per_kv = []
for kv_chunk_idx, attn_weights_kv in enumerate(attn_weights_chunks):
kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
attn_sum_kv = softmax_normalize_and_block_sum(
attn_weights_kv,
m_global,
l_global,
reshaped_block_size,
min(4096, reshaped_block_size),
chunk_start=chunk_start,
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
scale=scale,
kv_offset=kv_offset_reshaped,
is_causal=True,
)
attn_sum_per_kv.append(attn_sum_kv)
attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
simple_mask = find_blocks_chunked(
attn_sum_concat,
current_index=k_block_num - q_block_num + q_chunk_idx * num_blocks_per_chunk,
threshold=THRESHOLD,
num_to_choose=None,
decoding=False,
mode="prefill",
causal=True,
)
simple_mask_list.append(simple_mask)
mask_kv_chunking = torch.cat(simple_mask_list, dim=2)
# 应用与 xattn_estimate 相同的 causal mask 后处理 (xattn.py 第 1300-1306 行)
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:] = torch.where(
torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=device), diagonal=0),
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:],
False,
)
mask_kv_chunking_valid = mask_kv_chunking[:, :, :q_blocks, :k_blocks]
selected_kv = (mask_kv_chunking_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_kv = selected_kv / total_api
mask_total = mask_api_valid.numel()
mask_diff = (mask_api_valid != mask_kv_chunking_valid).sum().item()
mask_diff_pct = 100 * mask_diff / mask_total
return {
"seq_len": seq_len,
"stride": STRIDE,
"threshold": THRESHOLD,
"kv_chunks": kv_chunk_num,
"density_api": density_api,
"density_kv": density_kv,
"density_diff": abs(density_api - density_kv),
"mask_diff_pct": mask_diff_pct,
"passed": abs(density_api - density_kv) < 1e-6 and mask_diff_pct < 0.01,
}
def main():
files = sorted(glob.glob(os.path.join(DATA_DIR, "qkv_*.pt")))
print("=" * 80)
print("XAttention KV Chunking Alignment Test")
print("=" * 80)
print()
results = []
for f in files:
fname = os.path.basename(f)
print(f"Testing {fname}...", end=" ", flush=True)
try:
r = test_single_file(f)
results.append(r)
status = "✓ PASS" if r["passed"] else "✗ FAIL"
print(f"{status} (seq_len={r['seq_len']}, kv_chunks={r['kv_chunks']})")
except Exception as e:
print(f"✗ ERROR: {e}")
results.append({"file": fname, "error": str(e)})
print()
print("=" * 80)
print("Results Summary")
print("=" * 80)
print()
print("| seq_len | stride | threshold | kv_chunks | density_api | density_kv | diff | mask_diff | status |")
print("|---------|--------|-----------|-----------|-------------|------------|------|-----------|--------|")
all_passed = True
for r in results:
if "error" in r:
print(f"| ERROR | - | - | - | - | - | - | - | {r['error'][:20]} |")
all_passed = False
else:
status = "PASS" if r["passed"] else "FAIL"
if not r["passed"]:
all_passed = False
print(f"| {r['seq_len']:>7} | {r['stride']:>6} | {r['threshold']:.2f} | {r['kv_chunks']:>9} | "
f"{r['density_api']:.6f} | {r['density_kv']:.6f} | {r['density_diff']:.6f} | "
f"{r['mask_diff_pct']:.4f}% | {status} |")
print()
if all_passed:
print("test_xattn_kv_chunking_batch: ALL PASSED")
else:
print("test_xattn_kv_chunking_batch: SOME FAILED")
if __name__ == "__main__":
main()