diff --git a/tests/bench_estimate_block_size.py b/tests/bench_estimate_block_size.py deleted file mode 100644 index e77fb33..0000000 --- a/tests/bench_estimate_block_size.py +++ /dev/null @@ -1,314 +0,0 @@ -""" -Benchmark: block_size impact on XAttention estimate phase performance. - -This script tests how different block_size values affect the performance of: -1. flat_group_gemm_fuse_reshape (estimate GEMM) -2. softmax_fuse_block_sum (estimate softmax + block aggregation) - -Key insight: The current select_blocks uses global kvcache_block_size for estimation, -which may not be optimal for the Triton kernels. -""" - -import sys -import os -import torch -import time -import math - -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum - -# ============================================================ -# Configuration -# ============================================================ - -# Test configurations -BLOCK_SIZES = [64, 128, 256, 512] # BSA optimal is 128 -STRIDE = 8 -NUM_WARMUP = 3 -NUM_RUNS = 10 - -# Model dimensions (Llama-3.1-8B-Instruct) -NUM_HEADS = 32 -NUM_KV_HEADS = 8 -HEAD_DIM = 128 - -# Context lengths to test -CONTEXT_LENGTHS = [16384, 32768, 65536] # 16K, 32K, 64K - -# ============================================================ -# Benchmark Functions -# ============================================================ - -def benchmark_flat_group_gemm(Q, K, stride, block_size, num_warmup=3, num_runs=10): - """ - Benchmark flat_group_gemm_fuse_reshape kernel. - - Args: - Q: [batch, heads, q_len, head_dim] - K: [batch, heads, k_len, head_dim] - stride: Stride for reshape - block_size: Block size (affects alignment requirements) - - Returns: - (avg_time_ms, output_tensor) - """ - q_len = Q.shape[2] - k_len = K.shape[2] - - # Compute reshaped dimensions - reshaped_q_len = q_len // stride - reshaped_k_len = k_len // stride - reshaped_block_size = block_size // stride - - # Warmup - for _ in range(num_warmup): - _ = flat_group_gemm_fuse_reshape( - Q, K, stride, - chunk_start=0, - chunk_end=reshaped_q_len, - is_causal=False, - ) - torch.cuda.synchronize() - - # Benchmark - start = time.perf_counter() - for _ in range(num_runs): - output = flat_group_gemm_fuse_reshape( - Q, K, stride, - chunk_start=0, - chunk_end=reshaped_q_len, - is_causal=False, - ) - torch.cuda.synchronize() - end = time.perf_counter() - - avg_time_ms = (end - start) / num_runs * 1000 - return avg_time_ms, output - - -def benchmark_softmax_fuse_block_sum(attn_weights, reshaped_block_size, num_warmup=3, num_runs=10): - """ - Benchmark softmax_fuse_block_sum kernel. - - Args: - attn_weights: [batch, heads, q_len, k_len] attention weights - reshaped_block_size: Block size in reshaped space - - Returns: - avg_time_ms - """ - batch_size, num_heads, q_len, k_len = attn_weights.shape - head_dim = HEAD_DIM - stride = STRIDE - norm = 1.0 - - # segment_size must divide k_len and be >= reshaped_block_size - segment_size = min(4096, reshaped_block_size) - - # Ensure k_len is divisible by segment_size - if k_len % segment_size != 0: - # Pad k_len - pad_size = segment_size - (k_len % segment_size) - attn_weights = torch.nn.functional.pad(attn_weights, (0, pad_size), value=0) - k_len = attn_weights.shape[3] - - # Scale factor - scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm - - # Warmup - for _ in range(num_warmup): - _ = softmax_fuse_block_sum( - attn_weights, - reshaped_block_size, - segment_size, - chunk_start=0, - chunk_end=q_len, - real_q_len=q_len, - scale=scale, - is_causal=False, - ) - torch.cuda.synchronize() - - # Benchmark - start = time.perf_counter() - for _ in range(num_runs): - output = softmax_fuse_block_sum( - attn_weights, - reshaped_block_size, - segment_size, - chunk_start=0, - chunk_end=q_len, - real_q_len=q_len, - scale=scale, - is_causal=False, - ) - torch.cuda.synchronize() - end = time.perf_counter() - - avg_time_ms = (end - start) / num_runs * 1000 - return avg_time_ms - - -def run_estimate_benchmark(q_len, k_len, block_size, stride=STRIDE): - """ - Run full estimate benchmark for given configuration. - - Args: - q_len: Query length - k_len: Key length (usually same as q_len for current chunk scenario) - block_size: Block size to test - stride: Stride for reshape - - Returns: - dict with timing results - """ - # Create random Q and K tensors - # Shape: [batch, heads, seq_len, head_dim] - Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device="cuda") - K = torch.randn(1, NUM_HEADS, k_len, HEAD_DIM, dtype=torch.bfloat16, device="cuda") - - reshaped_block_size = block_size // stride - reshaped_q_len = q_len // stride - reshaped_k_len = k_len // stride - - # Benchmark GEMM - gemm_time, attn_weights = benchmark_flat_group_gemm( - Q, K, stride, block_size, - num_warmup=NUM_WARMUP, num_runs=NUM_RUNS - ) - - # Benchmark softmax + block sum - softmax_time = benchmark_softmax_fuse_block_sum( - attn_weights, reshaped_block_size, - num_warmup=NUM_WARMUP, num_runs=NUM_RUNS - ) - - # Clean up - del Q, K, attn_weights - torch.cuda.empty_cache() - - return { - "q_len": q_len, - "k_len": k_len, - "block_size": block_size, - "reshaped_block_size": reshaped_block_size, - "gemm_time_ms": gemm_time, - "softmax_time_ms": softmax_time, - "total_time_ms": gemm_time + softmax_time, - } - - -# ============================================================ -# Main Benchmark -# ============================================================ - -def main(): - import argparse - parser = argparse.ArgumentParser(description="Benchmark block_size impact on estimate phase") - parser.add_argument("--gpu", type=int, default=0, help="GPU to use") - parser.add_argument("--ctx-len", type=int, default=None, - help="Single context length to test (default: test multiple)") - args = parser.parse_args() - - # Set GPU - torch.cuda.set_device(args.gpu) - device_name = torch.cuda.get_device_name(args.gpu) - print(f"Using GPU {args.gpu}: {device_name}") - print() - - # Determine context lengths to test - if args.ctx_len: - context_lengths = [args.ctx_len] - else: - context_lengths = CONTEXT_LENGTHS - - print("=" * 80) - print("Benchmark: block_size impact on XAttention estimate phase") - print("=" * 80) - print(f"Configuration:") - print(f" NUM_HEADS: {NUM_HEADS}") - print(f" NUM_KV_HEADS: {NUM_KV_HEADS}") - print(f" HEAD_DIM: {HEAD_DIM}") - print(f" STRIDE: {STRIDE}") - print(f" BLOCK_SIZES: {BLOCK_SIZES}") - print(f" NUM_WARMUP: {NUM_WARMUP}") - print(f" NUM_RUNS: {NUM_RUNS}") - print() - - all_results = [] - - for ctx_len in context_lengths: - print(f"\n{'='*80}") - print(f"Context Length: {ctx_len // 1024}K ({ctx_len} tokens)") - print(f"{'='*80}") - - # Pad to alignment - alignment = STRIDE * 128 # Triton BLOCK_M requirement - padded_len = ((ctx_len + alignment - 1) // alignment) * alignment - print(f"Padded to: {padded_len} tokens (alignment={alignment})") - print() - - results = [] - for block_size in BLOCK_SIZES: - print(f"Testing block_size={block_size} (reshaped={block_size // STRIDE})...", end=" ") - try: - result = run_estimate_benchmark(padded_len, padded_len, block_size) - results.append(result) - print(f"GEMM={result['gemm_time_ms']:.2f}ms, " - f"Softmax={result['softmax_time_ms']:.2f}ms, " - f"Total={result['total_time_ms']:.2f}ms") - except Exception as e: - print(f"ERROR: {e}") - import traceback - traceback.print_exc() - - if results: - all_results.extend(results) - - # Print summary table for this context length - print(f"\n--- Summary for {ctx_len // 1024}K context ---") - print(f"{'block_size':>12} {'reshaped':>10} {'GEMM (ms)':>12} {'Softmax (ms)':>14} {'Total (ms)':>12} {'Speedup':>10}") - print("-" * 74) - - baseline_total = results[0]["total_time_ms"] - for r in results: - speedup = baseline_total / r["total_time_ms"] - print(f"{r['block_size']:>12} {r['reshaped_block_size']:>10} " - f"{r['gemm_time_ms']:>12.2f} {r['softmax_time_ms']:>14.2f} " - f"{r['total_time_ms']:>12.2f} {speedup:>9.2f}x") - - # Final summary across all context lengths - if len(context_lengths) > 1: - print(f"\n{'='*80}") - print("OVERALL SUMMARY") - print(f"{'='*80}") - print(f"{'ctx_len':>10} {'block_size':>12} {'GEMM (ms)':>12} {'Softmax (ms)':>14} {'Total (ms)':>12}") - print("-" * 64) - for r in all_results: - print(f"{r['q_len']//1024:>9}K {r['block_size']:>12} " - f"{r['gemm_time_ms']:>12.2f} {r['softmax_time_ms']:>14.2f} " - f"{r['total_time_ms']:>12.2f}") - - # Find optimal block_size for softmax - print(f"\n{'='*80}") - print("ANALYSIS: Optimal block_size for softmax_fuse_block_sum") - print(f"{'='*80}") - - for ctx_len in context_lengths: - ctx_results = [r for r in all_results if r["q_len"] == ((ctx_len + STRIDE * 128 - 1) // (STRIDE * 128)) * (STRIDE * 128)] - if ctx_results: - best = min(ctx_results, key=lambda x: x["softmax_time_ms"]) - worst = max(ctx_results, key=lambda x: x["softmax_time_ms"]) - improvement = worst["softmax_time_ms"] / best["softmax_time_ms"] - print(f"Context {ctx_len // 1024}K:") - print(f" Best: block_size={best['block_size']} ({best['softmax_time_ms']:.2f}ms)") - print(f" Worst: block_size={worst['block_size']} ({worst['softmax_time_ms']:.2f}ms)") - print(f" Potential improvement: {improvement:.2f}x") - - print("\nbench_estimate_block_size: DONE") - - -if __name__ == "__main__": - main() diff --git a/tests/modeling_qwen3.py b/tests/modeling_qwen3.py deleted file mode 100644 index 68e1bb3..0000000 --- a/tests/modeling_qwen3.py +++ /dev/null @@ -1,757 +0,0 @@ -""" -Custom Qwen3 implementation using only torch and transformers. -This file provides a clean reference implementation for understanding the model computation graph. - -Computation Graph: -================== - -Input: token_ids [batch, seq_len] - │ - ▼ - ┌─────────────┐ - │ Embedding │ embed_tokens: [vocab_size, hidden_size] - └─────────────┘ - │ - ▼ - hidden_states [batch, seq_len, hidden_size] - │ - ▼ - ┌─────────────────────────────────────────────────────────┐ - │ Decoder Layer (x N) │ - │ ┌───────────────────────────────────────────────────┐ │ - │ │ Self Attention Block │ │ - │ │ │ │ - │ │ input_layernorm (RMSNorm) │ │ - │ │ │ │ │ - │ │ ▼ │ │ - │ │ ┌─────────────────────────────────────────────┐ │ │ - │ │ │ Qwen3Attention │ │ │ - │ │ │ Q = q_proj(x) → q_norm → reshape │ │ │ - │ │ │ K = k_proj(x) → k_norm → reshape │ │ │ - │ │ │ V = v_proj(x) → reshape │ │ │ - │ │ │ │ │ │ │ - │ │ │ ▼ │ │ │ - │ │ │ Q, K = apply_rotary_pos_emb(Q, K, cos, sin)│ │ │ - │ │ │ │ │ │ │ - │ │ │ ▼ │ │ │ - │ │ │ attn_output = attention(Q, K, V) │ │ │ - │ │ │ │ │ │ │ - │ │ │ ▼ │ │ │ - │ │ │ output = o_proj(attn_output) │ │ │ - │ │ └─────────────────────────────────────────────┘ │ │ - │ │ │ │ │ - │ │ ▼ │ │ - │ │ hidden_states = residual + attn_output │ │ - │ └───────────────────────────────────────────────────┘ │ - │ │ │ - │ ▼ │ - │ ┌───────────────────────────────────────────────────┐ │ - │ │ MLP Block │ │ - │ │ │ │ - │ │ post_attention_layernorm (RMSNorm) │ │ - │ │ │ │ │ - │ │ ▼ │ │ - │ │ ┌─────────────────────────────────────────────┐ │ │ - │ │ │ Qwen3MLP │ │ │ - │ │ │ gate = gate_proj(x) │ │ │ - │ │ │ up = up_proj(x) │ │ │ - │ │ │ output = down_proj(silu(gate) * up) │ │ │ - │ │ └─────────────────────────────────────────────┘ │ │ - │ │ │ │ │ - │ │ ▼ │ │ - │ │ hidden_states = residual + mlp_output │ │ - │ └───────────────────────────────────────────────────┘ │ - └─────────────────────────────────────────────────────────┘ - │ - ▼ - ┌─────────────┐ - │ norm │ final RMSNorm - └─────────────┘ - │ - ▼ - ┌─────────────┐ - │ lm_head │ [hidden_size, vocab_size] - └─────────────┘ - │ - ▼ - logits [batch, seq_len, vocab_size] -""" - -import math -from typing import Optional, Tuple, List -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Qwen3RMSNorm(nn.Module): - """RMSNorm implementation.""" - - def __init__(self, hidden_size: int, eps: float = 1e-6): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.eps = eps - - def forward(self, x: torch.Tensor) -> torch.Tensor: - input_dtype = x.dtype - x = x.float() - variance = x.pow(2).mean(-1, keepdim=True) - x = x * torch.rsqrt(variance + self.eps) - return self.weight * x.to(input_dtype) - - -class Qwen3RotaryEmbedding(nn.Module): - """Rotary Position Embedding (RoPE).""" - - def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0): - super().__init__() - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - - # Compute inverse frequencies - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - @torch.no_grad() - def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: Input tensor [batch, seq_len, num_heads, head_dim] or similar - position_ids: Position indices [batch, seq_len] - - Returns: - cos, sin: [batch, seq_len, head_dim] - """ - # inv_freq: [dim/2] - # position_ids: [batch, seq_len] - inv_freq_expanded = self.inv_freq[None, :, None].float() # [1, dim/2, 1] - position_ids_expanded = position_ids[:, None, :].float() # [batch, 1, seq_len] - - # freqs: [batch, dim/2, seq_len] - freqs = inv_freq_expanded @ position_ids_expanded - # freqs: [batch, seq_len, dim/2] - freqs = freqs.transpose(1, 2) - - # Duplicate for full head_dim: [batch, seq_len, dim] - emb = torch.cat((freqs, freqs), dim=-1) - - cos = emb.cos().to(x.dtype) - sin = emb.sin().to(x.dtype) - - return cos, sin - - -def rotate_half(x: torch.Tensor) -> torch.Tensor: - """Rotate half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb( - q: torch.Tensor, - k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary position embeddings to Q and K. - - Args: - q: [batch, num_heads, seq_len, head_dim] - k: [batch, num_kv_heads, seq_len, head_dim] - cos: [batch, seq_len, head_dim] - sin: [batch, seq_len, head_dim] - - Returns: - q_embed, k_embed with same shapes as inputs - """ - # Unsqueeze for broadcasting: [batch, 1, seq_len, head_dim] - cos = cos.unsqueeze(1) - sin = sin.unsqueeze(1) - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - - return q_embed, k_embed - - -class Qwen3Attention(nn.Module): - """ - Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support. - - Data Flow: - --------- - hidden_states [batch, seq_len, hidden_size] - │ - ├──► q_proj ──► q_norm ──► reshape ──► Q [batch, num_heads, seq_len, head_dim] - ├──► k_proj ──► k_norm ──► reshape ──► K [batch, num_kv_heads, seq_len, head_dim] - └──► v_proj ──► reshape ──► V [batch, num_kv_heads, seq_len, head_dim] - │ - ▼ - apply_rotary_pos_emb(Q, K) - │ - ▼ - attention(Q, K, V) ──► attn_output [batch, num_heads, seq_len, head_dim] - │ - ▼ - reshape ──► o_proj ──► output [batch, seq_len, hidden_size] - """ - - def __init__( - self, - hidden_size: int, - num_attention_heads: int, - num_key_value_heads: int, - head_dim: int, - max_position_embeddings: int = 32768, - rope_theta: float = 10000.0, - attention_bias: bool = False, - rms_norm_eps: float = 1e-6, - layer_idx: int = 0, - ): - super().__init__() - self.hidden_size = hidden_size - self.num_heads = num_attention_heads - self.num_kv_heads = num_key_value_heads - self.head_dim = head_dim - self.num_kv_groups = num_attention_heads // num_key_value_heads - self.layer_idx = layer_idx - - # Scaling factor - self.scaling = head_dim ** -0.5 - - # QKV projections - self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias) - self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias) - self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias) - self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias) - - # QK normalization (Qwen3 specific) - self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps) - self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps) - - # Rotary embeddings - self.rotary_emb = Qwen3RotaryEmbedding( - head_dim, - max_position_embeddings=max_position_embeddings, - base=rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - position_ids: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - output_qkv: bool = False, - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]: - """ - Args: - hidden_states: [batch, seq_len, hidden_size] - position_ids: [batch, seq_len] - attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask) - past_key_value: (k_cache, v_cache) from previous steps - use_cache: Whether to return updated cache - output_qkv: Whether to output Q, K, V tensors for debugging - - Returns: - output: [batch, seq_len, hidden_size] - past_key_value: Updated cache (if use_cache=True) - qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True) - """ - batch_size, seq_len, _ = hidden_states.shape - - # === QKV Projections === - q = self.q_proj(hidden_states) # [batch, seq_len, num_heads * head_dim] - k = self.k_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim] - v = self.v_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim] - - # Reshape to [batch, seq_len, num_heads, head_dim] - q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) - k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) - v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) - - # === QK Normalization (Qwen3 specific) === - q = self.q_norm(q) - k = self.k_norm(k) - - # Transpose to [batch, num_heads, seq_len, head_dim] - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - # === Rotary Position Embeddings === - cos, sin = self.rotary_emb(v, position_ids) - q, k = apply_rotary_pos_emb(q, k, cos, sin) - - # === KV Cache Update === - if past_key_value is not None: - k_cache, v_cache = past_key_value - k = torch.cat([k_cache, k], dim=2) - v = torch.cat([v_cache, v], dim=2) - - new_past_key_value = (k, v) if use_cache else None - - # === Grouped Query Attention (expand KV heads if needed) === - if self.num_kv_groups > 1: - # Repeat KV for each query group - k = k.repeat_interleave(self.num_kv_groups, dim=1) - v = v.repeat_interleave(self.num_kv_groups, dim=1) - - # === Attention Computation (using SDPA for memory efficiency) === - # Use PyTorch's scaled_dot_product_attention which can use FlashAttention backend - # is_causal only works when q_len == kv_len (prefill), not during decode - q_len, kv_len = q.shape[2], k.shape[2] - is_causal = (q_len == kv_len) and (q_len > 1) - - attn_output = F.scaled_dot_product_attention( - q, k, v, - attn_mask=None, - dropout_p=0.0, - is_causal=is_causal, - scale=self.scaling, - ) # [batch, num_heads, seq_len, head_dim] - - # === Output Projection === - # Transpose back and reshape - attn_output = attn_output.transpose(1, 2).contiguous() # [batch, seq_len, num_heads, head_dim] - attn_output = attn_output.view(batch_size, seq_len, -1) # [batch, seq_len, hidden_size] - output = self.o_proj(attn_output) - - # Optional QKV output for debugging - qkv_dict = None - if output_qkv: - qkv_dict = { - "q": q, # [batch, num_heads, seq_len, head_dim] (post-RoPE) - "k": k, # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded) - "v": v, # [batch, num_heads, kv_seq_len, head_dim] (expanded) - } - - return output, new_past_key_value, qkv_dict - - -class Qwen3MLP(nn.Module): - """ - Qwen3 MLP with SwiGLU activation. - - Data Flow: - --------- - hidden_states [batch, seq_len, hidden_size] - │ - ├──► gate_proj ──► gate [batch, seq_len, intermediate_size] - │ - └──► up_proj ──► up [batch, seq_len, intermediate_size] - │ - ▼ - silu(gate) * up - │ - ▼ - down_proj ──► output [batch, seq_len, hidden_size] - """ - - def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False): - super().__init__() - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias) - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - gate = self.gate_proj(x) - up = self.up_proj(x) - return self.down_proj(F.silu(gate) * up) - - -class Qwen3DecoderLayer(nn.Module): - """Single Qwen3 Decoder Layer.""" - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - num_attention_heads: int, - num_key_value_heads: int, - head_dim: int, - max_position_embeddings: int = 32768, - rope_theta: float = 10000.0, - rms_norm_eps: float = 1e-6, - attention_bias: bool = False, - mlp_bias: bool = False, - layer_idx: int = 0, - ): - super().__init__() - self.layer_idx = layer_idx - - # Pre-attention LayerNorm - self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps) - - # Self-attention - self.self_attn = Qwen3Attention( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - head_dim=head_dim, - max_position_embeddings=max_position_embeddings, - rope_theta=rope_theta, - attention_bias=attention_bias, - rms_norm_eps=rms_norm_eps, - layer_idx=layer_idx, - ) - - # Post-attention LayerNorm - self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps) - - # MLP - self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias) - - def forward( - self, - hidden_states: torch.Tensor, - position_ids: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - output_qkv: bool = False, - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]: - """ - Args: - hidden_states: [batch, seq_len, hidden_size] - position_ids: [batch, seq_len] - attention_mask: Causal attention mask - past_key_value: KV cache for this layer - use_cache: Whether to return updated cache - output_qkv: Whether to output Q, K, V for debugging - - Returns: - hidden_states: [batch, seq_len, hidden_size] - past_key_value: Updated cache - qkv_dict: QKV tensors (if output_qkv=True) - """ - # === Self Attention Block === - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - attn_output, new_past_key_value, qkv_dict = self.self_attn( - hidden_states=hidden_states, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_qkv=output_qkv, - ) - - hidden_states = residual + attn_output - - # === MLP Block === - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states, new_past_key_value, qkv_dict - - -class Qwen3Model(nn.Module): - """Qwen3 Transformer Model (without LM head).""" - - def __init__( - self, - vocab_size: int, - hidden_size: int, - intermediate_size: int, - num_hidden_layers: int, - num_attention_heads: int, - num_key_value_heads: int, - head_dim: int, - max_position_embeddings: int = 32768, - rope_theta: float = 10000.0, - rms_norm_eps: float = 1e-6, - attention_bias: bool = False, - mlp_bias: bool = False, - ): - super().__init__() - self.vocab_size = vocab_size - self.num_hidden_layers = num_hidden_layers - - # Token embeddings - self.embed_tokens = nn.Embedding(vocab_size, hidden_size) - - # Decoder layers - self.layers = nn.ModuleList([ - Qwen3DecoderLayer( - hidden_size=hidden_size, - intermediate_size=intermediate_size, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - head_dim=head_dim, - max_position_embeddings=max_position_embeddings, - rope_theta=rope_theta, - rms_norm_eps=rms_norm_eps, - attention_bias=attention_bias, - mlp_bias=mlp_bias, - layer_idx=i, - ) - for i in range(num_hidden_layers) - ]) - - # Final LayerNorm - self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, - use_cache: bool = False, - output_qkv_layers: Optional[List[int]] = None, - ) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]: - """ - Args: - input_ids: [batch, seq_len] - position_ids: [batch, seq_len] - attention_mask: [batch, seq_len] or pre-computed 4D mask - past_key_values: List of (k, v) tuples for each layer - use_cache: Whether to return new cache - output_qkv_layers: List of layer indices to output QKV for - - Returns: - hidden_states: [batch, seq_len, hidden_size] - new_past_key_values: Updated cache - qkv_outputs: {layer_idx: qkv_dict} - """ - batch_size, seq_len = input_ids.shape - - # Embedding - hidden_states = self.embed_tokens(input_ids) - - # Position IDs - if position_ids is None: - past_len = past_key_values[0][0].shape[2] if past_key_values else 0 - position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) - - # Attention mask (create causal mask if not provided) - if attention_mask is None or attention_mask.dim() == 2: - kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0) - causal_mask = torch.triu( - torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device), - diagonal=kv_seq_len - seq_len + 1, - ) - attention_mask = causal_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, kv_seq_len] - - # Initialize cache list - new_past_key_values = [] if use_cache else None - qkv_outputs = {} if output_qkv_layers else None - - # Decoder layers - for i, layer in enumerate(self.layers): - past_kv = past_key_values[i] if past_key_values else None - output_qkv = output_qkv_layers is not None and i in output_qkv_layers - - hidden_states, new_kv, qkv_dict = layer( - hidden_states=hidden_states, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_value=past_kv, - use_cache=use_cache, - output_qkv=output_qkv, - ) - - if use_cache: - new_past_key_values.append(new_kv) - if qkv_dict is not None: - qkv_outputs[i] = qkv_dict - - # Final norm - hidden_states = self.norm(hidden_states) - - return hidden_states, new_past_key_values, qkv_outputs - - -class Qwen3ForCausalLM(nn.Module): - """Qwen3 Model with Language Modeling head.""" - - def __init__( - self, - vocab_size: int, - hidden_size: int, - intermediate_size: int, - num_hidden_layers: int, - num_attention_heads: int, - num_key_value_heads: int, - head_dim: int, - max_position_embeddings: int = 32768, - rope_theta: float = 10000.0, - rms_norm_eps: float = 1e-6, - attention_bias: bool = False, - mlp_bias: bool = False, - tie_word_embeddings: bool = True, - ): - super().__init__() - self.vocab_size = vocab_size - self.tie_word_embeddings = tie_word_embeddings - - # Transformer model - self.model = Qwen3Model( - vocab_size=vocab_size, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - num_hidden_layers=num_hidden_layers, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - head_dim=head_dim, - max_position_embeddings=max_position_embeddings, - rope_theta=rope_theta, - rms_norm_eps=rms_norm_eps, - attention_bias=attention_bias, - mlp_bias=mlp_bias, - ) - - # LM head - self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False) - - def forward( - self, - input_ids: torch.Tensor, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, - use_cache: bool = False, - output_qkv_layers: Optional[List[int]] = None, - ) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]: - """ - Args: - input_ids: [batch, seq_len] - ... (same as Qwen3Model) - - Returns: - logits: [batch, seq_len, vocab_size] - past_key_values: Updated KV cache - qkv_outputs: QKV tensors for specified layers - """ - hidden_states, new_past_key_values, qkv_outputs = self.model( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_qkv_layers=output_qkv_layers, - ) - - logits = self.lm_head(hidden_states) - - return logits, new_past_key_values, qkv_outputs - - @classmethod - def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM": - """ - Load weights from a pretrained Qwen3 model. - - Args: - model_path: Path to model directory containing config.json and model weights - dtype: Data type for model weights - - Returns: - Initialized Qwen3ForCausalLM model - """ - import json - import os - from safetensors.torch import load_file - - # Load config - config_path = os.path.join(model_path, "config.json") - with open(config_path) as f: - config = json.load(f) - - # Create model - model = cls( - vocab_size=config["vocab_size"], - hidden_size=config["hidden_size"], - intermediate_size=config["intermediate_size"], - num_hidden_layers=config["num_hidden_layers"], - num_attention_heads=config["num_attention_heads"], - num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]), - head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]), - max_position_embeddings=config.get("max_position_embeddings", 32768), - rope_theta=config.get("rope_theta", 10000.0), - rms_norm_eps=config.get("rms_norm_eps", 1e-6), - attention_bias=config.get("attention_bias", False), - mlp_bias=config.get("mlp_bias", False), - tie_word_embeddings=config.get("tie_word_embeddings", True), - ) - - # Load weights - weight_files = sorted([ - f for f in os.listdir(model_path) - if f.endswith(".safetensors") - ]) - - state_dict = {} - for wf in weight_files: - state_dict.update(load_file(os.path.join(model_path, wf))) - - # Load into model - model.load_state_dict(state_dict, strict=False) - - # Tie lm_head weights to embed_tokens if configured - if model.tie_word_embeddings: - model.lm_head.weight = model.model.embed_tokens.weight - - model = model.to(dtype) - - return model - - @torch.no_grad() - def generate( - self, - input_ids: torch.Tensor, - max_new_tokens: int = 32, - temperature: float = 1.0, - do_sample: bool = True, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - ) -> torch.Tensor: - """Simple autoregressive generation.""" - device = input_ids.device - batch_size, seq_len = input_ids.shape - past_key_values = None - generated = input_ids.clone() - - for _ in range(max_new_tokens): - if past_key_values is None: - current_input = generated - else: - current_input = generated[:, -1:] - - logits, past_key_values, _ = self( - input_ids=current_input, - past_key_values=past_key_values, - use_cache=True, - ) - - next_token_logits = logits[:, -1, :] - if temperature > 0 and do_sample: - next_token_logits = next_token_logits / temperature - probs = torch.softmax(next_token_logits, dim=-1) - next_token = torch.multinomial(probs, num_samples=1) - else: - next_token = next_token_logits.argmax(dim=-1, keepdim=True) - - generated = torch.cat([generated, next_token], dim=1) - - if eos_token_id is not None and (next_token == eos_token_id).all(): - break - - return generated - - -def print_computation_graph(): - """Print the computation graph for reference.""" - print(__doc__) - - -if __name__ == "__main__": - print_computation_graph() diff --git a/tests/test_chunk_attention_graph_reuse.py b/tests/test_chunk_attention_graph_reuse.py deleted file mode 100644 index a2afb29..0000000 --- a/tests/test_chunk_attention_graph_reuse.py +++ /dev/null @@ -1,156 +0,0 @@ -#!/usr/bin/env python3 -""" -Test: Reuse a single CUDA Graph across all layers and all chunk pairs. - -Key insight: LLM layers have identical computation structure. -We only need 2 graphs (causal + non-causal), reused for all (layer, Q_i, K_j) combinations. - -Usage: - CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph_reuse.py -""" - -from dataclasses import dataclass - -import torch - -from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs - - -@dataclass -class ReusableChunkGraph: - """A single graph that can be reused with copy_() updates.""" - graph: torch.cuda.CUDAGraph - static_q: torch.Tensor - static_k: torch.Tensor - static_v: torch.Tensor - static_output: torch.Tensor - static_lse: torch.Tensor - - -def capture_reusable_graph( - chunk_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - scale: float, - device: torch.device, - dtype: torch.dtype, - causal: bool, -) -> ReusableChunkGraph: - """Capture ONE graph to be reused for all chunk pairs.""" - static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device) - static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device) - static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device) - - static_q.normal_() - static_k.normal_() - static_v.normal_() - - # Warmup - with torch.inference_mode(): - for _ in range(3): - _ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal) - torch.cuda.synchronize() - - # Capture - graph = torch.cuda.CUDAGraph() - with torch.inference_mode(): - with torch.cuda.graph(graph): - static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal) - - torch.cuda.synchronize() - - return ReusableChunkGraph( - graph=graph, - static_q=static_q, - static_k=static_k, - static_v=static_v, - static_output=static_output, - static_lse=static_lse, - ) - - -def replay_with_copy(graph: ReusableChunkGraph, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): - """Replay graph after updating static tensors with copy_().""" - graph.static_q.copy_(q) - graph.static_k.copy_(k) - graph.static_v.copy_(v) - graph.graph.replay() - return graph.static_output.clone(), graph.static_lse.clone() - - -def main(): - device = torch.device("cuda") - dtype = torch.bfloat16 - - chunk_size = 64 - num_chunks = 4 - num_layers = 3 # Simulate multiple layers - num_heads = 8 - num_kv_heads = 8 - head_dim = 64 - scale = 1.0 / (head_dim ** 0.5) - seq_len = chunk_size * num_chunks - - print(f"Device: {torch.cuda.get_device_name()}") - print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}, Num layers: {num_layers}") - print(f"Only 2 graphs (causal + non-causal) for ALL layer × chunk combinations") - - # Capture only 2 graphs - graph_causal = capture_reusable_graph( - chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=True - ) - graph_non_causal = capture_reusable_graph( - chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=False - ) - print("2 graphs captured (causal + non-causal)") - - all_pass = True - - for layer_id in range(num_layers): - # Different Q/K/V for each layer (simulating different layer outputs) - full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device) - full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device) - full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device) - - # Reference: full causal attention - with torch.inference_mode(): - full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True) - - # Chunked with graph reuse - chunked_output = torch.zeros_like(full_output) - - for q_idx in range(num_chunks): - q_chunk = full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size] - acc_out, acc_lse = None, None - - for k_idx in range(q_idx + 1): - k_chunk = full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size] - v_chunk = full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size] - - # Reuse graph with copy_() - graph = graph_causal if k_idx == q_idx else graph_non_causal - out, lse = replay_with_copy(graph, q_chunk, k_chunk, v_chunk) - - if acc_out is None: - acc_out, acc_lse = out, lse - else: - with torch.inference_mode(): - acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse) - - chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out - - torch.cuda.synchronize() - - # Compare - max_diff = (full_output - chunked_output).abs().max().item() - status = "✅" if max_diff < 1e-2 else "❌" - print(f"Layer {layer_id}: max_diff={max_diff:.2e} {status}") - if max_diff >= 1e-2: - all_pass = False - - print("✅ PASSED - Single graph reuse across layers works!" if all_pass else "❌ FAILED") - - -if __name__ == "__main__": - main() diff --git a/tests/test_cudagraph_memory.py b/tests/test_cudagraph_memory.py deleted file mode 100644 index fdf8d1e..0000000 --- a/tests/test_cudagraph_memory.py +++ /dev/null @@ -1,357 +0,0 @@ -#!/usr/bin/env python3 -""" -CUDA Graph Memory Analysis Test - -This script analyzes the memory overhead of CUDA Graph at each stage: -1. Model loading -2. StaticCache allocation -3. Warmup runs -4. Graph capture -5. Graph replay - -Usage: - CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py - CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --model ~/models/Qwen3-0.6B - CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --max-cache-len 2048 -""" - -import argparse -import os -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.cache_utils import StaticCache - - -def get_memory_mb(): - """Get current allocated memory in MB.""" - return torch.cuda.memory_allocated() / 1024**2 - - -def get_memory_gb(): - """Get current allocated memory in GB.""" - return torch.cuda.memory_allocated() / 1024**3 - - -def get_peak_memory_gb(): - """Get peak allocated memory in GB.""" - return torch.cuda.max_memory_allocated() / 1024**3 - - -def print_separator(title=None): - """Print a separator line.""" - if title: - print(f"\n{'=' * 70}") - print(f" {title}") - print(f"{'=' * 70}") - else: - print("-" * 70) - - -def test_memory_stages(model_path: str, max_cache_len: int, batch_size: int = 1): - """ - Test memory usage at each stage of CUDA Graph setup. - - Args: - model_path: Path to the model - max_cache_len: Maximum cache length for StaticCache - batch_size: Batch size for inference - """ - print_separator("CUDA Graph Memory Analysis") - print(f"Model: {model_path}") - print(f"Max cache length: {max_cache_len}") - print(f"Batch size: {batch_size}") - - results = {} - - # Stage 0: Initial - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - results["initial"] = get_memory_mb() - - # Stage 1: Load model - print_separator("Stage 1: Model Loading") - model = AutoModelForCausalLM.from_pretrained( - model_path, - torch_dtype=torch.bfloat16, - device_map="cuda", - trust_remote_code=True, - ) - model.eval() - - results["after_model"] = get_memory_mb() - model_size = results["after_model"] - results["initial"] - print(f" Memory: {results['after_model']:.0f} MB") - print(f" Model size: {model_size:.0f} MB ({model_size/1024:.2f} GB)") - - config = model.config - device = next(model.parameters()).device - dtype = next(model.parameters()).dtype - - # Stage 2: Allocate StaticCache - print_separator("Stage 2: StaticCache Allocation") - torch.cuda.reset_peak_memory_stats() - before = get_memory_mb() - - static_cache = StaticCache( - config=config, - max_batch_size=batch_size, - max_cache_len=max_cache_len, - device=device, - dtype=dtype, - ) - - results["after_cache"] = get_memory_mb() - cache_size = results["after_cache"] - before - print(f" Memory: {results['after_cache']:.0f} MB") - print(f" StaticCache size: {cache_size:.0f} MB") - - # Calculate theoretical cache size - num_layers = config.num_hidden_layers - num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = config.hidden_size // config.num_attention_heads - dtype_size = 2 # bfloat16 - - theoretical_cache = ( - num_layers * 2 * batch_size * num_kv_heads * max_cache_len * head_dim * dtype_size - ) / (1024**2) - print(f" Theoretical: {theoretical_cache:.0f} MB") - print(f" Overhead: {cache_size - theoretical_cache:.0f} MB ({(cache_size/theoretical_cache - 1)*100:.1f}%)") - - # Stage 3: Prepare static tensors - print_separator("Stage 3: Static Tensor Allocation") - before = get_memory_mb() - - static_input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device) - static_position_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device) - static_cache_position = torch.tensor([0], dtype=torch.long, device=device) - - results["after_tensors"] = get_memory_mb() - tensor_size = results["after_tensors"] - before - print(f" Memory: {results['after_tensors']:.0f} MB") - print(f" Static tensors: {tensor_size:.2f} MB (negligible)") - - # Stage 4: Warmup runs - print_separator("Stage 4: Warmup Runs (3 iterations)") - torch.cuda.reset_peak_memory_stats() - before = get_memory_mb() - - with torch.inference_mode(): - for i in range(3): - _ = model( - input_ids=static_input_ids, - position_ids=static_position_ids, - past_key_values=static_cache, - cache_position=static_cache_position, - use_cache=True, - ) - torch.cuda.synchronize() - - results["after_warmup"] = get_memory_mb() - results["warmup_peak"] = get_peak_memory_gb() * 1024 - warmup_size = results["after_warmup"] - before - print(f" Memory: {results['after_warmup']:.0f} MB") - print(f" Peak: {results['warmup_peak']:.0f} MB") - print(f" Warmup overhead: {warmup_size:.0f} MB") - - # Stage 5: CUDA Graph capture - print_separator("Stage 5: CUDA Graph Capture") - torch.cuda.reset_peak_memory_stats() - before = get_memory_mb() - - graph = torch.cuda.CUDAGraph() - with torch.inference_mode(): - with torch.cuda.graph(graph): - outputs = model( - input_ids=static_input_ids, - position_ids=static_position_ids, - past_key_values=static_cache, - cache_position=static_cache_position, - use_cache=True, - ) - static_logits = outputs.logits - torch.cuda.synchronize() - - results["after_capture"] = get_memory_mb() - results["capture_peak"] = get_peak_memory_gb() * 1024 - capture_size = results["after_capture"] - before - print(f" Memory: {results['after_capture']:.0f} MB") - print(f" Peak: {results['capture_peak']:.0f} MB") - print(f" Graph capture overhead: {capture_size:.0f} MB") - - # Stage 6: Graph replay - print_separator("Stage 6: Graph Replay (10 iterations)") - torch.cuda.reset_peak_memory_stats() - before = get_memory_mb() - - with torch.inference_mode(): - for _ in range(10): - static_input_ids.fill_(1) - static_cache_position.fill_(0) - graph.replay() - torch.cuda.synchronize() - - results["after_replay"] = get_memory_mb() - results["replay_peak"] = get_peak_memory_gb() * 1024 - replay_change = results["after_replay"] - before - print(f" Memory: {results['after_replay']:.0f} MB") - print(f" Peak: {results['replay_peak']:.0f} MB") - print(f" Replay memory change: {replay_change:.0f} MB (should be ~0)") - - # Summary - print_separator("SUMMARY") - total_overhead = results["after_capture"] - results["after_model"] - - print(f"{'Stage':<25} {'Memory (MB)':>12} {'Delta (MB)':>12}") - print("-" * 50) - print(f"{'Model loaded':<25} {results['after_model']:>12.0f} {model_size:>+12.0f}") - print(f"{'StaticCache allocated':<25} {results['after_cache']:>12.0f} {cache_size:>+12.0f}") - print(f"{'After warmup':<25} {results['after_warmup']:>12.0f} {warmup_size:>+12.0f}") - print(f"{'After graph capture':<25} {results['after_capture']:>12.0f} {capture_size:>+12.0f}") - print(f"{'After graph replay':<25} {results['after_replay']:>12.0f} {replay_change:>+12.0f}") - print("-" * 50) - print(f"{'Total (excl. model)':<25} {'':<12} {total_overhead:>+12.0f}") - - print_separator("KEY FINDINGS") - print(f" 1. Model size: {model_size/1024:.2f} GB") - print(f" 2. StaticCache: {cache_size:.0f} MB (main overhead, scales with cache_len)") - print(f" 3. Graph capture: {capture_size:.0f} MB (small, stores kernel sequence)") - print(f" 4. Graph replay: {replay_change:.0f} MB (zero allocation, reuses memory)") - print(f" 5. Total CUDA Graph overhead: {total_overhead:.0f} MB") - - return results - - -def test_cache_length_scaling(model_path: str, cache_lengths: list): - """ - Test how memory scales with different cache lengths. - - Args: - model_path: Path to the model - cache_lengths: List of cache lengths to test - """ - print_separator("Cache Length Scaling Test") - print(f"Model: {model_path}") - print(f"Cache lengths: {cache_lengths}") - - # Load model once - model = AutoModelForCausalLM.from_pretrained( - model_path, - torch_dtype=torch.bfloat16, - device_map="cuda", - trust_remote_code=True, - ) - model.eval() - - config = model.config - device = next(model.parameters()).device - dtype = next(model.parameters()).dtype - - model_mem = get_memory_mb() - - results = [] - for cache_len in cache_lengths: - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - - # Create cache and capture graph - static_cache = StaticCache( - config=config, - max_batch_size=1, - max_cache_len=cache_len, - device=device, - dtype=dtype, - ) - - static_input_ids = torch.zeros(1, 1, dtype=torch.long, device=device) - static_position_ids = torch.zeros(1, 1, dtype=torch.long, device=device) - static_cache_position = torch.tensor([0], dtype=torch.long, device=device) - - with torch.inference_mode(): - # Warmup - for _ in range(3): - _ = model( - input_ids=static_input_ids, - position_ids=static_position_ids, - past_key_values=static_cache, - cache_position=static_cache_position, - use_cache=True, - ) - torch.cuda.synchronize() - - # Capture - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): - outputs = model( - input_ids=static_input_ids, - position_ids=static_position_ids, - past_key_values=static_cache, - cache_position=static_cache_position, - use_cache=True, - ) - torch.cuda.synchronize() - - total_mem = get_memory_mb() - overhead = total_mem - model_mem - results.append((cache_len, total_mem, overhead)) - - del static_cache, graph - torch.cuda.empty_cache() - - # Print results - print() - print(f"{'Cache Length':>12} | {'Total (MB)':>12} | {'Overhead (MB)':>14} | {'Per 1K tokens':>14}") - print("-" * 60) - for cache_len, total, overhead in results: - per_1k = overhead / (cache_len / 1000) - print(f"{cache_len:>12} | {total:>12.0f} | {overhead:>14.0f} | {per_1k:>14.1f}") - - return results - - -def main(): - parser = argparse.ArgumentParser(description="CUDA Graph Memory Analysis") - parser.add_argument( - "--model", - type=str, - default="~/models/Qwen3-4B-Instruct-2507", - help="Model path", - ) - parser.add_argument( - "--max-cache-len", - type=int, - default=1024, - help="Maximum cache length", - ) - parser.add_argument( - "--batch-size", - type=int, - default=1, - help="Batch size", - ) - parser.add_argument( - "--test-scaling", - action="store_true", - help="Test cache length scaling", - ) - args = parser.parse_args() - - model_path = os.path.expanduser(args.model) - - if not torch.cuda.is_available(): - print("CUDA is not available!") - return - - print(f"Device: cuda:{torch.cuda.current_device()}") - print(f"GPU: {torch.cuda.get_device_name()}") - - if args.test_scaling: - cache_lengths = [256, 512, 1024, 2048, 4096] - test_cache_length_scaling(model_path, cache_lengths) - else: - test_memory_stages(model_path, args.max_cache_len, args.batch_size) - - print("\ntest_cudagraph_memory: PASSED") - - -if __name__ == "__main__": - main() diff --git a/tests/test_gpuonly_density_alignment.py b/tests/test_gpuonly_density_alignment.py deleted file mode 100644 index d0201fe..0000000 --- a/tests/test_gpuonly_density_alignment.py +++ /dev/null @@ -1,149 +0,0 @@ -""" -Test: GPU-only density alignment verification - -验证 xattn_bsa.py 中 GPU-only 路径的 density 计算是否与独立调用 xattn_estimate 一致。 - -流程: -1. 运行 GPU-only 推理,保存 Q, K, mask, attn_sums -2. 加载保存的数据,独立调用 xattn_estimate -3. 比较两者的 mask 和 density -""" -import torch -import sys -sys.path.insert(0, "/home/zijie/Code/nano-vllm") - -from nanovllm.ops.xattn import xattn_estimate - -# ============================================================ -# 参数配置 -# ============================================================ - -DATA_PATH = "/home/zijie/Code/nano-vllm/results/mask_alignment/gpuonly_layer0.pt" - -# ============================================================ -# 加载保存的数据 -# ============================================================ - -print(f"Loading data from {DATA_PATH}") -data = torch.load(DATA_PATH, weights_only=False) - -Q = data["Q"].cuda() # [1, num_heads, q_len, head_dim] -K = data["K"].cuda() # [1, num_heads, k_len, head_dim] -chunk_size = data["chunk_size"] -block_size = data["block_size"] -stride = data["stride"] -threshold = data["threshold"] -mask_saved = data["mask"] # [1, num_heads, valid_q_blocks, valid_k_blocks] -attn_sums_saved = data["attn_sums"] -q_len = data["q_len"] -k_len = data["k_len"] -valid_q_blocks = data["valid_q_blocks"] -valid_k_blocks = data["valid_k_blocks"] - -print(f"Q shape: {Q.shape}") -print(f"K shape: {K.shape}") -print(f"q_len: {q_len}, k_len: {k_len}") -print(f"chunk_size: {chunk_size}, block_size: {block_size}, stride: {stride}, threshold: {threshold}") -print(f"valid_q_blocks: {valid_q_blocks}, valid_k_blocks: {valid_k_blocks}") -print(f"mask_saved shape: {mask_saved.shape}") - -# ============================================================ -# 独立调用 xattn_estimate -# ============================================================ - -print("\nCalling xattn_estimate independently...") -attn_sums_ext, mask_ext = xattn_estimate( - Q, K, - chunk_size=chunk_size, - block_size=block_size, - stride=stride, - threshold=threshold, - use_triton=True, - causal=True, -) - -# Trim to valid blocks -mask_ext_valid = mask_ext[:, :, :valid_q_blocks, :valid_k_blocks] -attn_sums_ext_valid = attn_sums_ext[:, :, :valid_q_blocks, :valid_k_blocks] - -print(f"mask_ext shape: {mask_ext.shape}") -print(f"mask_ext_valid shape: {mask_ext_valid.shape}") - -# ============================================================ -# 比较 attn_sums -# ============================================================ - -print("\n" + "=" * 60) -print("Comparing attn_sums") -print("=" * 60) - -attn_sums_saved_gpu = attn_sums_saved.cuda() -attn_diff = (attn_sums_ext_valid - attn_sums_saved_gpu).abs() -print(f"attn_sums max diff: {attn_diff.max().item():.6e}") -print(f"attn_sums mean diff: {attn_diff.mean().item():.6e}") - -# Check if attn_sums match -attn_match = attn_diff.max().item() < 1e-4 -print(f"attn_sums match: {attn_match}") - -# ============================================================ -# 比较 mask -# ============================================================ - -print("\n" + "=" * 60) -print("Comparing mask") -print("=" * 60) - -mask_saved_gpu = mask_saved.cuda() -mask_match = (mask_ext_valid == mask_saved_gpu).all().item() -print(f"mask exact match: {mask_match}") - -if not mask_match: - diff_count = (mask_ext_valid != mask_saved_gpu).sum().item() - total_count = mask_ext_valid.numel() - print(f"mask diff count: {diff_count} / {total_count} ({diff_count/total_count*100:.2f}%)") - -# ============================================================ -# 计算 density -# ============================================================ - -print("\n" + "=" * 60) -print("Comparing density") -print("=" * 60) - -# 计算 causal mask -q_offset_blocks = valid_k_blocks - valid_q_blocks -indices = torch.arange(valid_k_blocks, device=mask_ext_valid.device).unsqueeze(0) -q_indices = torch.arange(valid_q_blocks, device=mask_ext_valid.device).unsqueeze(1) -causal_mask = indices <= (q_indices + q_offset_blocks) - -# Density from saved mask -total_saved = causal_mask.sum().item() * mask_saved_gpu.shape[0] * mask_saved_gpu.shape[1] -selected_saved = (mask_saved_gpu & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item() -density_saved = selected_saved / total_saved - -# Density from external xattn_estimate -total_ext = causal_mask.sum().item() * mask_ext_valid.shape[0] * mask_ext_valid.shape[1] -selected_ext = (mask_ext_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item() -density_ext = selected_ext / total_ext - -print(f"Saved density: {density_saved:.6f} (selected={selected_saved}, total={total_saved})") -print(f"External density: {density_ext:.6f} (selected={selected_ext}, total={total_ext})") -print(f"Density diff: {abs(density_saved - density_ext):.6f}") - -# ============================================================ -# 结论 -# ============================================================ - -print("\n" + "=" * 60) -print("RESULT") -print("=" * 60) - -if attn_match and mask_match: - print("✅ PASSED: GPU-only density matches external xattn_estimate") -else: - print("❌ FAILED: Mismatch detected") - if not attn_match: - print(" - attn_sums mismatch") - if not mask_match: - print(" - mask mismatch") diff --git a/tests/test_hierarchical_estimate.py b/tests/test_hierarchical_estimate.py deleted file mode 100644 index 1edd4f8..0000000 --- a/tests/test_hierarchical_estimate.py +++ /dev/null @@ -1,442 +0,0 @@ -""" -Test: Hierarchical Block Sum Estimation for XAttention - -Verify that hierarchical estimation (small estimate_block_size + aggregation) -produces equivalent results to direct estimation (large block_size), while -being significantly faster. - -Key changes validated: -1. Hierarchical block sum: estimate_block_size=1024 → aggregate to cpu_block_size=4096 -2. Selection strategy: score + threshold (NOT mask + majority voting) - -This test uses pure torch + xattn kernels, independent of nanovllm framework. -""" - -import sys -import os -import torch -import math - -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum - -# ============================================================ -# Configuration -# ============================================================ - -# Model dimensions (Llama-3.1-8B-Instruct style) -NUM_HEADS = 32 -NUM_KV_HEADS = 8 -HEAD_DIM = 128 -STRIDE = 8 - -# Block sizes -CPU_BLOCK_SIZE = 4096 # External CPU block size (fixed, for overlap) -ESTIMATE_BLOCK_SIZE = 1024 # Internal estimate block size (optimized) - -# Selection parameters -THRESHOLD = 0.95 # Cumulative attention threshold - -# ============================================================ -# Hierarchical Estimation Implementation -# ============================================================ - -def compute_attention_scores(Q, K_blocks, stride): - """ - Compute attention scores for Q against multiple K blocks. - - Args: - Q: [1, num_heads, q_len, head_dim] - K_blocks: List of K tensors, each [1, num_heads, block_size, head_dim] - stride: Stride for reshape - - Returns: - attn_scores: [1, num_heads, q_reshaped, total_k_reshaped] - """ - q_len = Q.shape[2] - q_reshaped = q_len // stride - - attn_chunks = [] - for K_block in K_blocks: - # flat_group_gemm_fuse_reshape - attn_chunk = flat_group_gemm_fuse_reshape( - Q, K_block, stride, - chunk_start=0, - chunk_end=q_reshaped, - is_causal=False, - ) - attn_chunks.append(attn_chunk) - - # Concatenate along K dimension - attn_scores = torch.cat(attn_chunks, dim=-1) - return attn_scores - - -def hierarchical_block_sum( - attn_scores, - estimate_block_size, - cpu_block_size, - stride, - head_dim, -): - """ - Compute hierarchical block sums: fine-grained → aggregated to CPU block level. - - Args: - attn_scores: [batch, heads, q_reshaped, k_reshaped] - estimate_block_size: Small block size for efficient softmax (e.g., 1024) - cpu_block_size: External CPU block size (e.g., 4096) - stride: Stride used in reshape - head_dim: Head dimension for scale computation - - Returns: - cpu_block_scores: [batch, heads, num_cpu_blocks] - attention score per CPU block - """ - batch_size, num_heads, q_reshaped, k_reshaped = attn_scores.shape - - # Compute reshaped block sizes - reshaped_est_bs = estimate_block_size // stride # 1024/8 = 128 - reshaped_cpu_bs = cpu_block_size // stride # 4096/8 = 512 - - # Scale factor - norm = 1.0 - scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm - - # Segment size for softmax kernel - segment_size = min(4096, reshaped_est_bs) - - # Step 1: Fine-grained softmax + block sum - block_sums_fine = softmax_fuse_block_sum( - attn_scores, - reshaped_est_bs, - segment_size, - chunk_start=0, - chunk_end=q_reshaped, - real_q_len=q_reshaped, - scale=scale, - is_causal=False, - ) - # block_sums_fine: [batch, heads, q_est_blocks, k_est_blocks] - - q_est_blocks = block_sums_fine.shape[2] - k_est_blocks = block_sums_fine.shape[3] - - # Step 2: Aggregate to CPU block level - # ratio = cpu_block_size / estimate_block_size = 4 - ratio = cpu_block_size // estimate_block_size - num_cpu_blocks = k_est_blocks // ratio - - # Reshape and sum along K dimension - # [batch, heads, q_est, k_est] → [batch, heads, q_est, num_cpu, ratio] - block_sums_coarse = block_sums_fine.view( - batch_size, num_heads, q_est_blocks, num_cpu_blocks, ratio - ).sum(dim=-1) # [batch, heads, q_est_blocks, num_cpu_blocks] - - # Step 3: Sum over Q dimension (total attention from Q chunk to each K block) - cpu_block_scores = block_sums_coarse.sum(dim=2) # [batch, heads, num_cpu_blocks] - - return cpu_block_scores, block_sums_fine - - -def direct_block_sum( - attn_scores, - cpu_block_size, - stride, - head_dim, -): - """ - Compute block sums directly with CPU block size (baseline for comparison). - - Args: - attn_scores: [batch, heads, q_reshaped, k_reshaped] - cpu_block_size: Block size (e.g., 4096) - stride: Stride used in reshape - head_dim: Head dimension for scale computation - - Returns: - cpu_block_scores: [batch, heads, num_cpu_blocks] - """ - batch_size, num_heads, q_reshaped, k_reshaped = attn_scores.shape - - reshaped_cpu_bs = cpu_block_size // stride # 4096/8 = 512 - - norm = 1.0 - scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm - segment_size = min(4096, reshaped_cpu_bs) - - block_sums = softmax_fuse_block_sum( - attn_scores, - reshaped_cpu_bs, - segment_size, - chunk_start=0, - chunk_end=q_reshaped, - real_q_len=q_reshaped, - scale=scale, - is_causal=False, - ) - # block_sums: [batch, heads, q_cpu_blocks, k_cpu_blocks] - - # Sum over Q dimension - cpu_block_scores = block_sums.sum(dim=2) # [batch, heads, num_cpu_blocks] - - return cpu_block_scores - - -def select_blocks_by_score( - cpu_block_scores, - threshold=0.95, - always_include_first=True, - always_include_last=True, -): - """ - Select CPU blocks based on score + threshold. - - ⚠️ IMPORTANT: This replaces the original mask + majority voting strategy. - This change should be documented in the final implementation. - - Args: - cpu_block_scores: [batch, heads, num_cpu_blocks] - threshold: Cumulative attention threshold (e.g., 0.95) - always_include_first: Always include first block (sink) - always_include_last: Always include last block (safety) - - Returns: - selected_block_ids: List of selected block indices - density: Fraction of blocks selected - """ - # Average scores across heads - scores_per_block = cpu_block_scores.mean(dim=(0, 1)) # [num_cpu_blocks] - num_blocks = scores_per_block.shape[0] - - # Normalize to get attention distribution - total_score = scores_per_block.sum() - score_ratio = scores_per_block / total_score - - # Sort by score (descending) - sorted_indices = torch.argsort(score_ratio, descending=True) - - # Select blocks until cumulative threshold is reached - cumsum = 0.0 - selected = set() - - for idx in sorted_indices.tolist(): - selected.add(idx) - cumsum += score_ratio[idx].item() - if cumsum >= threshold: - break - - # Always include first and last blocks - if always_include_first: - selected.add(0) - if always_include_last: - selected.add(num_blocks - 1) - - selected_block_ids = sorted(list(selected)) - density = len(selected_block_ids) / num_blocks - - return selected_block_ids, density - - -# ============================================================ -# Test Cases -# ============================================================ - -def test_equivalence(): - """ - Test that hierarchical estimation produces equivalent scores to direct estimation. - """ - print("=" * 60) - print("Test 1: Hierarchical vs Direct - Equivalence") - print("=" * 60) - - # Create random Q and multiple K blocks - q_len = CPU_BLOCK_SIZE # 4096 - num_k_blocks = 4 - - # Q: [1, num_heads, q_len, head_dim] - Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda') - - # K blocks: each [1, num_heads, cpu_block_size, head_dim] - K_blocks = [ - torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda') - for _ in range(num_k_blocks) - ] - - # Compute attention scores - attn_scores = compute_attention_scores(Q, K_blocks, STRIDE) - print(f"attn_scores shape: {attn_scores.shape}") - - # Method 1: Hierarchical (fast) - scores_hier, _ = hierarchical_block_sum( - attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM - ) - print(f"scores_hier shape: {scores_hier.shape}") - - # Method 2: Direct (slow) - scores_direct = direct_block_sum( - attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM - ) - print(f"scores_direct shape: {scores_direct.shape}") - - # Compare - diff = (scores_hier - scores_direct).abs().max().item() - print(f"\nMax difference: {diff:.6f}") - - # Per-block comparison - print("\nPer-block scores comparison:") - for i in range(num_k_blocks): - h_val = scores_hier[0, 0, i].item() - d_val = scores_direct[0, 0, i].item() - print(f" Block {i}: hierarchical={h_val:.4f}, direct={d_val:.4f}, diff={abs(h_val-d_val):.6f}") - - passed = diff < 0.01 - print(f"\nTest 1: {'PASSED' if passed else 'FAILED'}") - return passed - - -def test_selection(): - """ - Test the score + threshold selection strategy. - """ - print("\n" + "=" * 60) - print("Test 2: Score + Threshold Selection") - print("=" * 60) - - # Create Q and K blocks with varying importance - q_len = CPU_BLOCK_SIZE - num_k_blocks = 8 - - Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda') - - # Create K blocks - make some more important than others - K_blocks = [] - for i in range(num_k_blocks): - # First and middle blocks are more important (higher values) - importance = 2.0 if i in [0, 3, 4] else 1.0 - K = torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda') - K = K * importance - K_blocks.append(K) - - # Compute scores - attn_scores = compute_attention_scores(Q, K_blocks, STRIDE) - scores, _ = hierarchical_block_sum( - attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM - ) - - # Print scores per block - print("\nCPU block scores (head 0):") - for i in range(num_k_blocks): - print(f" Block {i}: {scores[0, 0, i].item():.4f}") - - # Select blocks with different thresholds - for thresh in [0.9, 0.95, 0.99]: - selected, density = select_blocks_by_score(scores, threshold=thresh) - print(f"\nThreshold {thresh}: selected {len(selected)}/{num_k_blocks} blocks ({density:.1%})") - print(f" Selected: {selected}") - - print("\nTest 2: PASSED (visual inspection)") - return True - - -def test_performance(): - """ - Benchmark hierarchical vs direct estimation performance. - """ - print("\n" + "=" * 60) - print("Test 3: Performance Benchmark") - print("=" * 60) - - import time - - NUM_WARMUP = 3 - NUM_RUNS = 10 - - # Larger test case - q_len = CPU_BLOCK_SIZE - num_k_blocks = 16 # 64K context - - Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda') - K_blocks = [ - torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda') - for _ in range(num_k_blocks) - ] - - # Compute attention scores (shared) - attn_scores = compute_attention_scores(Q, K_blocks, STRIDE) - print(f"attn_scores shape: {attn_scores.shape}") - print(f"Context: {num_k_blocks * CPU_BLOCK_SIZE // 1024}K tokens") - - # Warmup and benchmark hierarchical - for _ in range(NUM_WARMUP): - _ = hierarchical_block_sum(attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM) - torch.cuda.synchronize() - - start = time.perf_counter() - for _ in range(NUM_RUNS): - _ = hierarchical_block_sum(attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM) - torch.cuda.synchronize() - hier_time = (time.perf_counter() - start) / NUM_RUNS * 1000 - - # Warmup and benchmark direct - for _ in range(NUM_WARMUP): - _ = direct_block_sum(attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM) - torch.cuda.synchronize() - - start = time.perf_counter() - for _ in range(NUM_RUNS): - _ = direct_block_sum(attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM) - torch.cuda.synchronize() - direct_time = (time.perf_counter() - start) / NUM_RUNS * 1000 - - speedup = direct_time / hier_time - - print(f"\nResults:") - print(f" Hierarchical (bs=1024): {hier_time:.2f} ms") - print(f" Direct (bs=4096): {direct_time:.2f} ms") - print(f" Speedup: {speedup:.2f}x") - - passed = speedup > 5.0 # Expect at least 5x speedup - print(f"\nTest 3: {'PASSED' if passed else 'FAILED'} (speedup > 5x expected)") - return passed - - -# ============================================================ -# Main -# ============================================================ - -if __name__ == "__main__": - print("=" * 60) - print("Hierarchical Block Sum Estimation Test") - print("=" * 60) - print(f"\nConfiguration:") - print(f" NUM_HEADS: {NUM_HEADS}") - print(f" NUM_KV_HEADS: {NUM_KV_HEADS}") - print(f" HEAD_DIM: {HEAD_DIM}") - print(f" STRIDE: {STRIDE}") - print(f" CPU_BLOCK_SIZE: {CPU_BLOCK_SIZE}") - print(f" ESTIMATE_BLOCK_SIZE: {ESTIMATE_BLOCK_SIZE}") - print(f" THRESHOLD: {THRESHOLD}") - print() - - results = [] - - results.append(("Equivalence", test_equivalence())) - results.append(("Selection", test_selection())) - results.append(("Performance", test_performance())) - - print("\n" + "=" * 60) - print("SUMMARY") - print("=" * 60) - for name, passed in results: - status = "PASSED" if passed else "FAILED" - print(f" {name}: {status}") - - all_passed = all(p for _, p in results) - print("=" * 60) - if all_passed: - print("test_hierarchical_estimate: ALL PASSED") - sys.exit(0) - else: - print("test_hierarchical_estimate: SOME FAILED") - sys.exit(1) diff --git a/tests/test_quest_policy.py b/tests/test_quest_policy.py deleted file mode 100644 index 14a893f..0000000 --- a/tests/test_quest_policy.py +++ /dev/null @@ -1,136 +0,0 @@ -""" -Test for QuestPolicy block selection with GQA (Grouped Query Attention). - -Demonstrates the key limitation: scores are AVERAGED across heads, -so blocks strongly needed by one head but not others may be dropped. - -This is the expected Quest behavior - not a bug. -""" - -import torch -from nanovllm.kvcache.sparse import ( - create_sparse_policy, - SparsePolicyType, - PolicyContext, -) - -# ============================================================ -# Test: Per-Head Score Averaging in GQA -# ============================================================ - -# Determine device (GPU if available, else CPU) -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -print(f"Running test on device: {device}") - -# Setup: 2 KV heads, 4 query heads (GQA group_size=2) -# topk=2 to make selection competitive - -quest = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=2, threshold_blocks=0) -quest.initialize( - num_layers=1, - num_kv_heads=2, - head_dim=4, - num_cpu_blocks=6, - dtype=torch.float32, - device=device, # Metadata stored on GPU -) - -metadata = quest.metadata - -def set_key(block_id, head_id, values): - """Set both key_min and key_max to same values for deterministic scoring.""" - # Values need to be on the same device as metadata - tensor = torch.tensor(values, device=device) - metadata.key_min[block_id, 0, head_id, :] = tensor - metadata.key_max[block_id, 0, head_id, :] = tensor - -# ============================================================ -# Design: Different heads want different blocks -# ============================================================ -# -# Query = [1,1,1,1] for all heads, so score = sum(key values) -# -# Block | Head 0 | Head 1 | Average | Result -# ------|--------|--------|---------|-------- -# 0 | +4 | -4 | 0 | Head0 wants, Head1 doesn't → DROPPED -# 1 | -4 | +4 | 0 | Head1 wants, Head0 doesn't → DROPPED -# 2 | +4 | +4 | +4 | Both want → SELECTED (rank 1) -# 3 | +3 | +3 | +3 | Both want → SELECTED (rank 2) -# 4 | +4 | 0 | +2 | Head0 strongly wants, Head1 neutral → rank 3 -# 5 | 0 | +4 | +2 | Head1 strongly wants, Head0 neutral → rank 3 - -# Block 0: Head 0 strongly wants, Head 1 strongly rejects -set_key(0, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4 -set_key(0, 1, [-1.0, -1.0, -1.0, -1.0]) # head1: -4 - -# Block 1: Head 1 strongly wants, Head 0 strongly rejects -set_key(1, 0, [-1.0, -1.0, -1.0, -1.0]) # head0: -4 -set_key(1, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4 - -# Block 2: Both heads want equally (highest average) -set_key(2, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4 -set_key(2, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4 - -# Block 3: Both heads want moderately -set_key(3, 0, [0.75, 0.75, 0.75, 0.75]) # head0: +3 -set_key(3, 1, [0.75, 0.75, 0.75, 0.75]) # head1: +3 - -# Block 4: Head 0 strongly wants, Head 1 neutral -set_key(4, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4 -set_key(4, 1, [0.0, 0.0, 0.0, 0.0]) # head1: 0 - -# Block 5: Head 1 strongly wants, Head 0 neutral -set_key(5, 0, [0.0, 0.0, 0.0, 0.0]) # head0: 0 -set_key(5, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4 - -# ============================================================ -# Run selection -# ============================================================ - -# Query on same device as metadata -query = torch.ones(1, 4, 4, device=device) # GQA: 4 query heads → 2 KV heads - -ctx = PolicyContext( - query_chunk_idx=0, - num_query_chunks=1, - layer_id=0, - query=query, - is_prefill=False, - block_size=1024, - total_kv_len=6144, -) - -available = list(range(6)) -selected = quest.select_blocks(available, ctx) - -# ============================================================ -# Verify: Averaging behavior -# ============================================================ - -# topk=2, so only blocks 2 (+4 avg) and 3 (+3 avg) should be selected -assert len(selected) == 2, f"Expected 2 blocks, got {len(selected)}" -assert selected == [2, 3], f"Expected [2, 3], got {selected}" - -# Key insight: blocks 0 and 1 have score +4 for ONE head, -# but they cancel out due to averaging with the other head's -4 -assert 0 not in selected, "Block 0 should NOT be selected (head scores cancel out)" -assert 1 not in selected, "Block 1 should NOT be selected (head scores cancel out)" - -# Blocks 4 and 5 have +4 for one head, 0 for other → avg=+2 -# But +2 < +3 (block 3), so they don't make the top-2 -assert 4 not in selected, "Block 4 avg=+2 < block 3 avg=+3" -assert 5 not in selected, "Block 5 avg=+2 < block 3 avg=+3" - -print("✓ Block 2 selected: both heads want it (+4, +4) → avg=+4") -print("✓ Block 3 selected: both heads want it (+3, +3) → avg=+3") -print("✓ Block 0 NOT selected: head0=+4, head1=-4 → avg=0 (cancel out)") -print("✓ Block 1 NOT selected: head0=-4, head1=+4 → avg=0 (cancel out)") -print("✓ Block 4 NOT selected: head0=+4, head1=0 → avg=+2 (lower rank)") -print("✓ Block 5 NOT selected: head0=0, head1=+4 → avg=+2 (lower rank)") - -# Verify metadata is on correct device -assert metadata.key_min.device.type == device.type, f"key_min on wrong device: {metadata.key_min.device}" -assert metadata.key_max.device.type == device.type, f"key_max on wrong device: {metadata.key_max.device}" -print(f"✓ Metadata stored on {device.type.upper()}") - -print("\ntest_quest_policy: PASSED") diff --git a/tests/test_sequential.py b/tests/test_sequential.py deleted file mode 100644 index 67f1a1f..0000000 --- a/tests/test_sequential.py +++ /dev/null @@ -1,199 +0,0 @@ -""" -Sequential inference test for LLM. - -Tests: After completing one prompt, the system can correctly handle -a second prompt with a clean state (first prompt's KV cache deallocated). -""" - -import os -os.environ["NANOVLLM_LOG_LEVEL"] = "INFO" - -import argparse -from nanovllm import LLM, SamplingParams -from utils import generate_needle_prompt, check_needle_answer - - -def run_sequential_test( - model_path: str, - max_model_len: int, - input_len: int, - num_gpu_blocks: int = 4, - block_size: int = 1024, - enable_cpu_offload: bool = False, - verbose: bool = True, -) -> bool: - """ - Run sequential inference test with two different prompts. - - Each prompt has a different needle value. Both must be retrieved correctly. - """ - if verbose: - print(f"\n{'='*60}") - print(f"Sequential Inference Test") - print(f"{'='*60}") - print(f"Model: {model_path}") - print(f"Max model len: {max_model_len}") - print(f"Input length: {input_len}") - print(f"Block size: {block_size}") - print(f"CPU offload: {enable_cpu_offload}") - print(f"{'='*60}\n") - - # Initialize LLM once - llm_kwargs = { - "enforce_eager": True, - "max_model_len": max_model_len, - "max_num_batched_tokens": max_model_len, - "enable_cpu_offload": enable_cpu_offload, - "kvcache_block_size": block_size, - } - if enable_cpu_offload: - llm_kwargs["num_gpu_blocks"] = num_gpu_blocks - - llm = LLM(model_path, **llm_kwargs) - - sampling_params = SamplingParams( - temperature=0.6, - max_tokens=32, - ) - - # ============================================================ - # Test 1: First prompt with needle value "1234" - # ============================================================ - needle_value_1 = "1234" - if verbose: - print(f"\n[Test 1] Generating prompt with needle value: {needle_value_1}") - - prompt_1, expected_1 = generate_needle_prompt( - tokenizer=llm.tokenizer, - target_length=input_len, - needle_position=0.5, - needle_value=needle_value_1, - ) - - outputs_1 = llm.generate([prompt_1], sampling_params, use_tqdm=True) - output_text_1 = outputs_1[0]["text"] - passed_1 = check_needle_answer(output_text_1, expected_1) - - if verbose: - print(f" Expected: {expected_1}") - print(f" Output: {output_text_1[:100]}...") - print(f" Status: {'PASSED' if passed_1 else 'FAILED'}") - - # ============================================================ - # Test 2: Second prompt with needle value "5678" - # ============================================================ - needle_value_2 = "5678" - if verbose: - print(f"\n[Test 2] Generating prompt with needle value: {needle_value_2}") - - prompt_2, expected_2 = generate_needle_prompt( - tokenizer=llm.tokenizer, - target_length=input_len, - needle_position=0.5, - needle_value=needle_value_2, - ) - - outputs_2 = llm.generate([prompt_2], sampling_params, use_tqdm=True) - output_text_2 = outputs_2[0]["text"] - passed_2 = check_needle_answer(output_text_2, expected_2) - - if verbose: - print(f" Expected: {expected_2}") - print(f" Output: {output_text_2[:100]}...") - print(f" Status: {'PASSED' if passed_2 else 'FAILED'}") - - # ============================================================ - # Test 3: Third prompt - repeat first needle to ensure no cross-contamination - # ============================================================ - needle_value_3 = "9999" - if verbose: - print(f"\n[Test 3] Generating prompt with needle value: {needle_value_3}") - - prompt_3, expected_3 = generate_needle_prompt( - tokenizer=llm.tokenizer, - target_length=input_len, - needle_position=0.5, - needle_value=needle_value_3, - ) - - outputs_3 = llm.generate([prompt_3], sampling_params, use_tqdm=True) - output_text_3 = outputs_3[0]["text"] - passed_3 = check_needle_answer(output_text_3, expected_3) - - if verbose: - print(f" Expected: {expected_3}") - print(f" Output: {output_text_3[:100]}...") - print(f" Status: {'PASSED' if passed_3 else 'FAILED'}") - - # ============================================================ - # Summary - # ============================================================ - all_passed = passed_1 and passed_2 and passed_3 - - if verbose: - print(f"\n{'='*60}") - print(f"Summary") - print(f"{'='*60}") - print(f"Test 1 (needle={needle_value_1}): {'PASSED' if passed_1 else 'FAILED'}") - print(f"Test 2 (needle={needle_value_2}): {'PASSED' if passed_2 else 'FAILED'}") - print(f"Test 3 (needle={needle_value_3}): {'PASSED' if passed_3 else 'FAILED'}") - print(f"Overall: {'PASSED' if all_passed else 'FAILED'}") - print(f"{'='*60}\n") - - return all_passed - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Sequential inference test") - parser.add_argument( - "--model", "-m", - type=str, - default=os.path.expanduser("~/models/Qwen3-0.6B/"), - help="Path to model" - ) - parser.add_argument( - "--max-model-len", - type=int, - default=36 * 1024, - help="Maximum model context length" - ) - parser.add_argument( - "--input-len", - type=int, - default=8 * 1024, - help="Target input sequence length" - ) - parser.add_argument( - "--num-gpu-blocks", - type=int, - default=2, - help="Number of GPU blocks for CPU offload" - ) - parser.add_argument( - "--block-size", - type=int, - default=1024, - help="KV cache block size" - ) - parser.add_argument( - "--enable-offload", - action="store_true", - help="Enable CPU offload" - ) - args = parser.parse_args() - - passed = run_sequential_test( - model_path=args.model, - max_model_len=args.max_model_len, - input_len=args.input_len, - num_gpu_blocks=args.num_gpu_blocks, - block_size=args.block_size, - enable_cpu_offload=args.enable_offload, - verbose=True, - ) - - if passed: - print("test_sequential: PASSED") - else: - print("test_sequential: FAILED") - exit(1)