Compare commits
3 Commits
11a867f6fb
...
d35dd76e09
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d35dd76e09 | ||
|
|
2b61c5ab57 | ||
|
|
a709551072 |
@@ -1,314 +0,0 @@
|
||||
"""
|
||||
Benchmark: block_size impact on XAttention estimate phase performance.
|
||||
|
||||
This script tests how different block_size values affect the performance of:
|
||||
1. flat_group_gemm_fuse_reshape (estimate GEMM)
|
||||
2. softmax_fuse_block_sum (estimate softmax + block aggregation)
|
||||
|
||||
Key insight: The current select_blocks uses global kvcache_block_size for estimation,
|
||||
which may not be optimal for the Triton kernels.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import math
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
|
||||
|
||||
# ============================================================
|
||||
# Configuration
|
||||
# ============================================================
|
||||
|
||||
# Test configurations
|
||||
BLOCK_SIZES = [64, 128, 256, 512] # BSA optimal is 128
|
||||
STRIDE = 8
|
||||
NUM_WARMUP = 3
|
||||
NUM_RUNS = 10
|
||||
|
||||
# Model dimensions (Llama-3.1-8B-Instruct)
|
||||
NUM_HEADS = 32
|
||||
NUM_KV_HEADS = 8
|
||||
HEAD_DIM = 128
|
||||
|
||||
# Context lengths to test
|
||||
CONTEXT_LENGTHS = [16384, 32768, 65536] # 16K, 32K, 64K
|
||||
|
||||
# ============================================================
|
||||
# Benchmark Functions
|
||||
# ============================================================
|
||||
|
||||
def benchmark_flat_group_gemm(Q, K, stride, block_size, num_warmup=3, num_runs=10):
    """
    Benchmark the flat_group_gemm_fuse_reshape (estimate GEMM) kernel.

    Args:
        Q: [batch, heads, q_len, head_dim] query tensor (CUDA)
        K: [batch, heads, k_len, head_dim] key tensor (CUDA)
        stride: Stride for the fused reshape
        block_size: Block size under test. NOTE: the GEMM kernel itself takes
            no block_size argument; the parameter is kept for interface
            symmetry with the softmax benchmark (see benchmark_softmax_fuse_block_sum).
        num_warmup: Untimed warmup iterations (trigger Triton compile/caching)
        num_runs: Timed iterations to average over

    Returns:
        (avg_time_ms, output_tensor): average wall-clock time per run in
        milliseconds, and the kernel output from the last timed run (fed to
        the softmax benchmark by the caller).
    """
    q_len = Q.shape[2]
    # The kernel works in "reshaped" space: query length divided by stride.
    reshaped_q_len = q_len // stride

    # Warmup — results discarded.
    for _ in range(num_warmup):
        _ = flat_group_gemm_fuse_reshape(
            Q, K, stride,
            chunk_start=0,
            chunk_end=reshaped_q_len,
            is_causal=False,
        )
    torch.cuda.synchronize()

    # Timed runs. A single synchronize after the loop is enough because we
    # only report the average over num_runs, not per-run times.
    start = time.perf_counter()
    for _ in range(num_runs):
        output = flat_group_gemm_fuse_reshape(
            Q, K, stride,
            chunk_start=0,
            chunk_end=reshaped_q_len,
            is_causal=False,
        )
    torch.cuda.synchronize()
    end = time.perf_counter()

    avg_time_ms = (end - start) / num_runs * 1000
    return avg_time_ms, output
|
||||
|
||||
|
||||
def benchmark_softmax_fuse_block_sum(attn_weights, reshaped_block_size, num_warmup=3, num_runs=10):
    """
    Benchmark the softmax_fuse_block_sum (estimate softmax + block aggregation) kernel.

    Reads module-level HEAD_DIM and STRIDE to derive the softmax scale factor.

    Args:
        attn_weights: [batch, heads, q_len, k_len] attention weights
        reshaped_block_size: Block size in reshaped space (block_size // stride)
        num_warmup: Untimed warmup iterations
        num_runs: Timed iterations to average over

    Returns:
        avg_time_ms: Average wall-clock time per run in milliseconds
    """
    batch_size, num_heads, q_len, k_len = attn_weights.shape
    head_dim = HEAD_DIM
    stride = STRIDE
    norm = 1.0

    # segment_size must divide k_len and be >= reshaped_block_size
    # NOTE(review): min() caps segment_size AT reshaped_block_size, i.e. it
    # guarantees segment_size <= reshaped_block_size — the opposite of the
    # ">=" stated above. Confirm the kernel's actual requirement; one of the
    # two (comment or code) is wrong.
    segment_size = min(4096, reshaped_block_size)

    # Ensure k_len is divisible by segment_size
    if k_len % segment_size != 0:
        # Pad k_len (zero-pad the last dimension up to the next multiple)
        pad_size = segment_size - (k_len % segment_size)
        attn_weights = torch.nn.functional.pad(attn_weights, (0, pad_size), value=0)
        k_len = attn_weights.shape[3]

    # Scale factor. 1.4426950408889634 == log2(e); presumably the kernel
    # computes exp2(scale * x) rather than exp — TODO confirm against the
    # kernel source.
    scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm

    # Warmup — results discarded.
    for _ in range(num_warmup):
        _ = softmax_fuse_block_sum(
            attn_weights,
            reshaped_block_size,
            segment_size,
            chunk_start=0,
            chunk_end=q_len,
            real_q_len=q_len,
            scale=scale,
            is_causal=False,
        )
    torch.cuda.synchronize()

    # Benchmark: time num_runs kernel launches, synchronize once at the end.
    start = time.perf_counter()
    for _ in range(num_runs):
        output = softmax_fuse_block_sum(
            attn_weights,
            reshaped_block_size,
            segment_size,
            chunk_start=0,
            chunk_end=q_len,
            real_q_len=q_len,
            scale=scale,
            is_causal=False,
        )
    torch.cuda.synchronize()
    end = time.perf_counter()

    avg_time_ms = (end - start) / num_runs * 1000
    return avg_time_ms
|
||||
|
||||
|
||||
def run_estimate_benchmark(q_len, k_len, block_size, stride=STRIDE):
    """
    Run the full estimate-phase benchmark (GEMM + softmax/block-sum) for one
    configuration.

    Args:
        q_len: Query length
        k_len: Key length (usually equals q_len in the current-chunk scenario)
        block_size: Block size to test
        stride: Stride for reshape

    Returns:
        dict with the configuration and per-stage timings in milliseconds
    """
    # Random inputs shaped [batch, heads, seq_len, head_dim].
    query = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device="cuda")
    key = torch.randn(1, NUM_HEADS, k_len, HEAD_DIM, dtype=torch.bfloat16, device="cuda")

    blk_reshaped = block_size // stride

    # Time the estimate GEMM; its output feeds the softmax benchmark.
    gemm_ms, weights = benchmark_flat_group_gemm(
        query, key, stride, block_size,
        num_warmup=NUM_WARMUP, num_runs=NUM_RUNS,
    )

    # Time softmax + block aggregation on the GEMM output.
    softmax_ms = benchmark_softmax_fuse_block_sum(
        weights, blk_reshaped,
        num_warmup=NUM_WARMUP, num_runs=NUM_RUNS,
    )

    # Drop the large intermediates before the next configuration allocates.
    del query, key, weights
    torch.cuda.empty_cache()

    return {
        "q_len": q_len,
        "k_len": k_len,
        "block_size": block_size,
        "reshaped_block_size": blk_reshaped,
        "gemm_time_ms": gemm_ms,
        "softmax_time_ms": softmax_ms,
        "total_time_ms": gemm_ms + softmax_ms,
    }
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Benchmark
|
||||
# ============================================================
|
||||
|
||||
def main():
    """Parse CLI args, sweep block sizes per context length, print summaries."""
    import argparse
    parser = argparse.ArgumentParser(description="Benchmark block_size impact on estimate phase")
    parser.add_argument("--gpu", type=int, default=0, help="GPU to use")
    parser.add_argument("--ctx-len", type=int, default=None,
                        help="Single context length to test (default: test multiple)")
    args = parser.parse_args()

    # Set GPU
    torch.cuda.set_device(args.gpu)
    device_name = torch.cuda.get_device_name(args.gpu)
    print(f"Using GPU {args.gpu}: {device_name}")
    print()

    # Determine context lengths to test.
    # NOTE: truthiness check, so --ctx-len 0 falls through to the default
    # list — acceptable since 0 is not a valid context length.
    if args.ctx_len:
        context_lengths = [args.ctx_len]
    else:
        context_lengths = CONTEXT_LENGTHS

    print("=" * 80)
    print("Benchmark: block_size impact on XAttention estimate phase")
    print("=" * 80)
    print(f"Configuration:")
    print(f" NUM_HEADS: {NUM_HEADS}")
    print(f" NUM_KV_HEADS: {NUM_KV_HEADS}")
    print(f" HEAD_DIM: {HEAD_DIM}")
    print(f" STRIDE: {STRIDE}")
    print(f" BLOCK_SIZES: {BLOCK_SIZES}")
    print(f" NUM_WARMUP: {NUM_WARMUP}")
    print(f" NUM_RUNS: {NUM_RUNS}")
    print()

    all_results = []

    for ctx_len in context_lengths:
        print(f"\n{'='*80}")
        print(f"Context Length: {ctx_len // 1024}K ({ctx_len} tokens)")
        print(f"{'='*80}")

        # Pad to alignment
        alignment = STRIDE * 128  # Triton BLOCK_M requirement
        padded_len = ((ctx_len + alignment - 1) // alignment) * alignment
        print(f"Padded to: {padded_len} tokens (alignment={alignment})")
        print()

        results = []
        for block_size in BLOCK_SIZES:
            print(f"Testing block_size={block_size} (reshaped={block_size // STRIDE})...", end=" ")
            # Per-config try/except so one failing block_size (e.g. an
            # unsupported Triton configuration) doesn't abort the sweep.
            try:
                result = run_estimate_benchmark(padded_len, padded_len, block_size)
                results.append(result)
                print(f"GEMM={result['gemm_time_ms']:.2f}ms, "
                      f"Softmax={result['softmax_time_ms']:.2f}ms, "
                      f"Total={result['total_time_ms']:.2f}ms")
            except Exception as e:
                print(f"ERROR: {e}")
                import traceback
                traceback.print_exc()

        if results:
            all_results.extend(results)

            # Print summary table for this context length
            print(f"\n--- Summary for {ctx_len // 1024}K context ---")
            print(f"{'block_size':>12} {'reshaped':>10} {'GEMM (ms)':>12} {'Softmax (ms)':>14} {'Total (ms)':>12} {'Speedup':>10}")
            print("-" * 74)

            # Speedup is relative to the FIRST tested block size (BLOCK_SIZES[0]).
            baseline_total = results[0]["total_time_ms"]
            for r in results:
                speedup = baseline_total / r["total_time_ms"]
                print(f"{r['block_size']:>12} {r['reshaped_block_size']:>10} "
                      f"{r['gemm_time_ms']:>12.2f} {r['softmax_time_ms']:>14.2f} "
                      f"{r['total_time_ms']:>12.2f} {speedup:>9.2f}x")

    # Final summary across all context lengths
    if len(context_lengths) > 1:
        print(f"\n{'='*80}")
        print("OVERALL SUMMARY")
        print(f"{'='*80}")
        print(f"{'ctx_len':>10} {'block_size':>12} {'GEMM (ms)':>12} {'Softmax (ms)':>14} {'Total (ms)':>12}")
        print("-" * 64)
        for r in all_results:
            print(f"{r['q_len']//1024:>9}K {r['block_size']:>12} "
                  f"{r['gemm_time_ms']:>12.2f} {r['softmax_time_ms']:>14.2f} "
                  f"{r['total_time_ms']:>12.2f}")

    # Find optimal block_size for softmax
    # NOTE(review): placement of this section relative to the multi-context
    # guard above was reconstructed from a mangled source — confirm intent.
    print(f"\n{'='*80}")
    print("ANALYSIS: Optimal block_size for softmax_fuse_block_sum")
    print(f"{'='*80}")

    for ctx_len in context_lengths:
        # Recompute the same padded length used during the sweep to match
        # stored results back to their requested context length.
        ctx_results = [r for r in all_results if r["q_len"] == ((ctx_len + STRIDE * 128 - 1) // (STRIDE * 128)) * (STRIDE * 128)]
        if ctx_results:
            best = min(ctx_results, key=lambda x: x["softmax_time_ms"])
            worst = max(ctx_results, key=lambda x: x["softmax_time_ms"])
            improvement = worst["softmax_time_ms"] / best["softmax_time_ms"]
            print(f"Context {ctx_len // 1024}K:")
            print(f" Best: block_size={best['block_size']} ({best['softmax_time_ms']:.2f}ms)")
            print(f" Worst: block_size={worst['block_size']} ({worst['softmax_time_ms']:.2f}ms)")
            print(f" Potential improvement: {improvement:.2f}x")

    print("\nbench_estimate_block_size: DONE")


if __name__ == "__main__":
    main()
|
||||
@@ -1,757 +0,0 @@
|
||||
"""
|
||||
Custom Qwen3 implementation using only torch and transformers.
|
||||
This file provides a clean reference implementation for understanding the model computation graph.
|
||||
|
||||
Computation Graph:
|
||||
==================
|
||||
|
||||
Input: token_ids [batch, seq_len]
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ Embedding │ embed_tokens: [vocab_size, hidden_size]
|
||||
└─────────────┘
|
||||
│
|
||||
▼
|
||||
hidden_states [batch, seq_len, hidden_size]
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ Decoder Layer (x N) │
|
||||
│ ┌───────────────────────────────────────────────────┐ │
|
||||
│ │ Self Attention Block │ │
|
||||
│ │ │ │
|
||||
│ │ input_layernorm (RMSNorm) │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ ┌─────────────────────────────────────────────┐ │ │
|
||||
│ │ │ Qwen3Attention │ │ │
|
||||
│ │ │ Q = q_proj(x) → q_norm → reshape │ │ │
|
||||
│ │ │ K = k_proj(x) → k_norm → reshape │ │ │
|
||||
│ │ │ V = v_proj(x) → reshape │ │ │
|
||||
│ │ │ │ │ │ │
|
||||
│ │ │ ▼ │ │ │
|
||||
│ │ │ Q, K = apply_rotary_pos_emb(Q, K, cos, sin)│ │ │
|
||||
│ │ │ │ │ │ │
|
||||
│ │ │ ▼ │ │ │
|
||||
│ │ │ attn_output = attention(Q, K, V) │ │ │
|
||||
│ │ │ │ │ │ │
|
||||
│ │ │ ▼ │ │ │
|
||||
│ │ │ output = o_proj(attn_output) │ │ │
|
||||
│ │ └─────────────────────────────────────────────┘ │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ hidden_states = residual + attn_output │ │
|
||||
│ └───────────────────────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌───────────────────────────────────────────────────┐ │
|
||||
│ │ MLP Block │ │
|
||||
│ │ │ │
|
||||
│ │ post_attention_layernorm (RMSNorm) │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ ┌─────────────────────────────────────────────┐ │ │
|
||||
│ │ │ Qwen3MLP │ │ │
|
||||
│ │ │ gate = gate_proj(x) │ │ │
|
||||
│ │ │ up = up_proj(x) │ │ │
|
||||
│ │ │ output = down_proj(silu(gate) * up) │ │ │
|
||||
│ │ └─────────────────────────────────────────────┘ │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ hidden_states = residual + mlp_output │ │
|
||||
│ └───────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ norm │ final RMSNorm
|
||||
└─────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ lm_head │ [hidden_size, vocab_size]
|
||||
└─────────────┘
|
||||
│
|
||||
▼
|
||||
logits [batch, seq_len, vocab_size]
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple, List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Qwen3RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction), as used by Qwen3."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        # Learnable per-channel gain, initialized to identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize x by its RMS over the last dim; compute in fp32, return in input dtype."""
        orig_dtype = x.dtype
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        normalized = (x32 * inv_rms).to(orig_dtype)
        return self.weight * normalized
|
||||
|
||||
|
||||
class Qwen3RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE): precomputes inverse frequencies once
    and produces per-position cos/sin tables on demand."""

    def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # inv_freq[j] = base^(-2j/dim) for j in [0, dim/2)
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Any tensor — used only to pick the output dtype for cos/sin.
            position_ids: Position indices [batch, seq_len]

        Returns:
            cos, sin: [batch, seq_len, head_dim]
        """
        # Outer product of inverse frequencies and positions:
        # [1, dim/2, 1] @ [batch, 1, seq_len] -> [batch, dim/2, seq_len]
        angles = self.inv_freq[None, :, None].float() @ position_ids[:, None, :].float()
        # Move seq_len before the frequency axis: [batch, seq_len, dim/2]
        angles = angles.transpose(1, 2)

        # Duplicate the angles so both halves of head_dim share the same
        # frequency: [batch, seq_len, dim]
        full_angles = torch.cat((angles, angles), dim=-1)

        return full_angles.cos().to(x.dtype), full_angles.sin().to(x.dtype)
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim in half and return [-second_half, first_half]."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embeddings to query and key tensors.

    Args:
        q: [batch, num_heads, seq_len, head_dim]
        k: [batch, num_kv_heads, seq_len, head_dim]
        cos: [batch, seq_len, head_dim]
        sin: [batch, seq_len, head_dim]

    Returns:
        (q_embed, k_embed) with the same shapes as the inputs.
    """
    # Insert a head axis so cos/sin broadcast over every head:
    # [batch, 1, seq_len, head_dim]
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # Standard RoPE rotation in the complex plane, expressed on real halves.
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
|
||||
|
||||
|
||||
class Qwen3Attention(nn.Module):
    """
    Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support.

    Data Flow:
    ---------
    hidden_states [batch, seq_len, hidden_size]
        ├──► q_proj ──► q_norm ──► reshape ──► Q [batch, num_heads, seq_len, head_dim]
        ├──► k_proj ──► k_norm ──► reshape ──► K [batch, num_kv_heads, seq_len, head_dim]
        └──► v_proj ──► reshape ──► V [batch, num_kv_heads, seq_len, head_dim]
    Q, K ◄── apply_rotary_pos_emb(Q, K)
    attn_output = attention(Q, K, V)   [batch, num_heads, seq_len, head_dim]
    output = o_proj(reshape(attn_output))  [batch, seq_len, hidden_size]
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        max_position_embeddings: int = 32768,
        rope_theta: float = 10000.0,
        attention_bias: bool = False,
        rms_norm_eps: float = 1e-6,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.num_kv_heads = num_key_value_heads
        self.head_dim = head_dim
        # Number of query heads that share each KV head (GQA group size).
        self.num_kv_groups = num_attention_heads // num_key_value_heads
        self.layer_idx = layer_idx

        # Scaling factor passed to SDPA: 1/sqrt(head_dim)
        self.scaling = head_dim ** -0.5

        # QKV projections
        self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)

        # QK normalization (Qwen3 specific): per-head RMSNorm over head_dim
        self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)

        # Rotary embeddings
        self.rotary_emb = Qwen3RotaryEmbedding(
            head_dim,
            max_position_embeddings=max_position_embeddings,
            base=rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_qkv: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            position_ids: [batch, seq_len]
            attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask).
                NOTE(review): this argument is accepted but never forwarded
                to SDPA below (attn_mask=None); masking relies solely on the
                is_causal flag. Confirm this is intentional.
            past_key_value: (k_cache, v_cache) from previous steps
            use_cache: Whether to return updated cache
            output_qkv: Whether to output Q, K, V tensors for debugging

        Returns:
            output: [batch, seq_len, hidden_size]
            past_key_value: Updated cache (if use_cache=True)
            qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True)
        """
        batch_size, seq_len, _ = hidden_states.shape

        # === QKV Projections ===
        q = self.q_proj(hidden_states)  # [batch, seq_len, num_heads * head_dim]
        k = self.k_proj(hidden_states)  # [batch, seq_len, num_kv_heads * head_dim]
        v = self.v_proj(hidden_states)  # [batch, seq_len, num_kv_heads * head_dim]

        # Reshape to [batch, seq_len, num_heads, head_dim]
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)

        # === QK Normalization (Qwen3 specific) ===
        # Applied per head over head_dim, before RoPE.
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Transpose to [batch, num_heads, seq_len, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # === Rotary Position Embeddings ===
        # v is passed only to supply the target dtype for the cos/sin tables.
        cos, sin = self.rotary_emb(v, position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # === KV Cache Update ===
        # Prepend cached keys/values along the sequence axis (dim=2).
        if past_key_value is not None:
            k_cache, v_cache = past_key_value
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)

        new_past_key_value = (k, v) if use_cache else None

        # === Grouped Query Attention (expand KV heads if needed) ===
        if self.num_kv_groups > 1:
            # Repeat KV for each query group
            k = k.repeat_interleave(self.num_kv_groups, dim=1)
            v = v.repeat_interleave(self.num_kv_groups, dim=1)

        # === Attention Computation (using SDPA for memory efficiency) ===
        # Use PyTorch's scaled_dot_product_attention which can use FlashAttention backend
        # is_causal only works when q_len == kv_len (prefill), not during decode
        # (a single decode query may attend to the full cache, so no mask is needed).
        q_len, kv_len = q.shape[2], k.shape[2]
        is_causal = (q_len == kv_len) and (q_len > 1)

        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=is_causal,
            scale=self.scaling,
        )  # [batch, num_heads, seq_len, head_dim]

        # === Output Projection ===
        # Transpose back and reshape
        attn_output = attn_output.transpose(1, 2).contiguous()  # [batch, seq_len, num_heads, head_dim]
        attn_output = attn_output.view(batch_size, seq_len, -1)  # [batch, seq_len, hidden_size]
        output = self.o_proj(attn_output)

        # Optional QKV output for debugging
        qkv_dict = None
        if output_qkv:
            qkv_dict = {
                "q": q,  # [batch, num_heads, seq_len, head_dim] (post-RoPE)
                "k": k,  # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded)
                "v": v,  # [batch, num_heads, kv_seq_len, head_dim] (expanded)
            }

        return output, new_past_key_value, qkv_dict
|
||||
|
||||
|
||||
class Qwen3MLP(nn.Module):
    """
    Qwen3 feed-forward block with SwiGLU activation:

        output = down_proj(silu(gate_proj(x)) * up_proj(x))

    gate_proj/up_proj map hidden_size -> intermediate_size; down_proj maps back.
    """

    def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """[batch, seq_len, hidden_size] -> [batch, seq_len, hidden_size]."""
        activated = F.silu(self.gate_proj(x))
        return self.down_proj(activated * self.up_proj(x))
|
||||
|
||||
|
||||
class Qwen3DecoderLayer(nn.Module):
    """Single Qwen3 Decoder Layer: pre-norm self-attention followed by a
    pre-norm MLP, each wrapped in a residual connection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        max_position_embeddings: int = 32768,
        rope_theta: float = 10000.0,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        mlp_bias: bool = False,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.layer_idx = layer_idx

        # Pre-attention LayerNorm
        self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

        # Self-attention
        self.self_attn = Qwen3Attention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            head_dim=head_dim,
            max_position_embeddings=max_position_embeddings,
            rope_theta=rope_theta,
            attention_bias=attention_bias,
            rms_norm_eps=rms_norm_eps,
            layer_idx=layer_idx,
        )

        # Post-attention LayerNorm
        self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

        # MLP
        self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_qkv: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            position_ids: [batch, seq_len]
            attention_mask: Causal attention mask (passed through to self_attn)
            past_key_value: KV cache for this layer
            use_cache: Whether to return updated cache
            output_qkv: Whether to output Q, K, V for debugging

        Returns:
            hidden_states: [batch, seq_len, hidden_size]
            past_key_value: Updated cache
            qkv_dict: QKV tensors (if output_qkv=True)
        """
        # === Self Attention Block ===
        # Pre-norm residual: x + Attn(RMSNorm(x))
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attn_output, new_past_key_value, qkv_dict = self.self_attn(
            hidden_states=hidden_states,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_qkv=output_qkv,
        )

        hidden_states = residual + attn_output

        # === MLP Block ===
        # Pre-norm residual: x + MLP(RMSNorm(x))
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, new_past_key_value, qkv_dict
|
||||
|
||||
|
||||
class Qwen3Model(nn.Module):
    """Qwen3 Transformer Model (without LM head): embedding, N decoder
    layers, and a final RMSNorm."""

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        intermediate_size: int,
        num_hidden_layers: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        max_position_embeddings: int = 32768,
        rope_theta: float = 10000.0,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        mlp_bias: bool = False,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_hidden_layers = num_hidden_layers

        # Token embeddings
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

        # Decoder layers
        self.layers = nn.ModuleList([
            Qwen3DecoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                head_dim=head_dim,
                max_position_embeddings=max_position_embeddings,
                rope_theta=rope_theta,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                mlp_bias=mlp_bias,
                layer_idx=i,
            )
            for i in range(num_hidden_layers)
        ])

        # Final LayerNorm
        self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
        output_qkv_layers: Optional[List[int]] = None,
    ) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
        """
        Args:
            input_ids: [batch, seq_len]
            position_ids: [batch, seq_len]; derived from cache length if omitted
            attention_mask: [batch, seq_len] or pre-computed 4D mask
            past_key_values: List of (k, v) tuples for each layer
            use_cache: Whether to return new cache
            output_qkv_layers: List of layer indices to output QKV for

        Returns:
            hidden_states: [batch, seq_len, hidden_size]
            new_past_key_values: Updated cache
            qkv_outputs: {layer_idx: qkv_dict}
        """
        batch_size, seq_len = input_ids.shape

        # Embedding
        hidden_states = self.embed_tokens(input_ids)

        # Position IDs: continue from the cached sequence length so RoPE
        # angles stay consistent across decode steps.
        if position_ids is None:
            past_len = past_key_values[0][0].shape[2] if past_key_values else 0
            position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Attention mask (create causal mask if not provided).
        # NOTE(review): the constructed mask is passed down to the layers,
        # but Qwen3Attention calls SDPA with attn_mask=None and relies on
        # is_causal instead — so this mask appears to be dead; confirm.
        if attention_mask is None or attention_mask.dim() == 2:
            kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0)
            causal_mask = torch.triu(
                torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device),
                diagonal=kv_seq_len - seq_len + 1,
            )
            attention_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, kv_seq_len]

        # Initialize cache list
        new_past_key_values = [] if use_cache else None
        qkv_outputs = {} if output_qkv_layers else None

        # Decoder layers
        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values else None
            output_qkv = output_qkv_layers is not None and i in output_qkv_layers

            hidden_states, new_kv, qkv_dict = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                attention_mask=attention_mask,
                past_key_value=past_kv,
                use_cache=use_cache,
                output_qkv=output_qkv,
            )

            if use_cache:
                new_past_key_values.append(new_kv)
            if qkv_dict is not None:
                qkv_outputs[i] = qkv_dict

        # Final norm
        hidden_states = self.norm(hidden_states)

        return hidden_states, new_past_key_values, qkv_outputs
|
||||
|
||||
|
||||
class Qwen3ForCausalLM(nn.Module):
|
||||
"""Qwen3 Model with Language Modeling head."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_hidden_layers: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
max_position_embeddings: int = 32768,
|
||||
rope_theta: float = 10000.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
mlp_bias: bool = False,
|
||||
tie_word_embeddings: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
|
||||
# Transformer model
|
||||
self.model = Qwen3Model(
|
||||
vocab_size=vocab_size,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
mlp_bias=mlp_bias,
|
||||
)
|
||||
|
||||
# LM head
|
||||
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||
use_cache: bool = False,
|
||||
output_qkv_layers: Optional[List[int]] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
|
||||
"""
|
||||
Args:
|
||||
input_ids: [batch, seq_len]
|
||||
... (same as Qwen3Model)
|
||||
|
||||
Returns:
|
||||
logits: [batch, seq_len, vocab_size]
|
||||
past_key_values: Updated KV cache
|
||||
qkv_outputs: QKV tensors for specified layers
|
||||
"""
|
||||
hidden_states, new_past_key_values, qkv_outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_qkv_layers=output_qkv_layers,
|
||||
)
|
||||
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
return logits, new_past_key_values, qkv_outputs
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM":
|
||||
"""
|
||||
Load weights from a pretrained Qwen3 model.
|
||||
|
||||
Args:
|
||||
model_path: Path to model directory containing config.json and model weights
|
||||
dtype: Data type for model weights
|
||||
|
||||
Returns:
|
||||
Initialized Qwen3ForCausalLM model
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from safetensors.torch import load_file
|
||||
|
||||
# Load config
|
||||
config_path = os.path.join(model_path, "config.json")
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Create model
|
||||
model = cls(
|
||||
vocab_size=config["vocab_size"],
|
||||
hidden_size=config["hidden_size"],
|
||||
intermediate_size=config["intermediate_size"],
|
||||
num_hidden_layers=config["num_hidden_layers"],
|
||||
num_attention_heads=config["num_attention_heads"],
|
||||
num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
|
||||
head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]),
|
||||
max_position_embeddings=config.get("max_position_embeddings", 32768),
|
||||
rope_theta=config.get("rope_theta", 10000.0),
|
||||
rms_norm_eps=config.get("rms_norm_eps", 1e-6),
|
||||
attention_bias=config.get("attention_bias", False),
|
||||
mlp_bias=config.get("mlp_bias", False),
|
||||
tie_word_embeddings=config.get("tie_word_embeddings", True),
|
||||
)
|
||||
|
||||
# Load weights
|
||||
weight_files = sorted([
|
||||
f for f in os.listdir(model_path)
|
||||
if f.endswith(".safetensors")
|
||||
])
|
||||
|
||||
state_dict = {}
|
||||
for wf in weight_files:
|
||||
state_dict.update(load_file(os.path.join(model_path, wf)))
|
||||
|
||||
# Load into model
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Tie lm_head weights to embed_tokens if configured
|
||||
if model.tie_word_embeddings:
|
||||
model.lm_head.weight = model.model.embed_tokens.weight
|
||||
|
||||
model = model.to(dtype)
|
||||
|
||||
return model
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
max_new_tokens: int = 32,
|
||||
temperature: float = 1.0,
|
||||
do_sample: bool = True,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Simple autoregressive generation."""
|
||||
device = input_ids.device
|
||||
batch_size, seq_len = input_ids.shape
|
||||
past_key_values = None
|
||||
generated = input_ids.clone()
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
if past_key_values is None:
|
||||
current_input = generated
|
||||
else:
|
||||
current_input = generated[:, -1:]
|
||||
|
||||
logits, past_key_values, _ = self(
|
||||
input_ids=current_input,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
next_token_logits = logits[:, -1, :]
|
||||
if temperature > 0 and do_sample:
|
||||
next_token_logits = next_token_logits / temperature
|
||||
probs = torch.softmax(next_token_logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
|
||||
|
||||
generated = torch.cat([generated, next_token], dim=1)
|
||||
|
||||
if eos_token_id is not None and (next_token == eos_token_id).all():
|
||||
break
|
||||
|
||||
return generated
|
||||
|
||||
|
||||
def print_computation_graph():
|
||||
"""Print the computation graph for reference."""
|
||||
print(__doc__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print_computation_graph()
|
||||
@@ -1,151 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test: Pre-allocated chunk pair graphs for block sparse attention.
|
||||
|
||||
Each (Q_chunk, K_chunk) pair has its own captured CUDA graph.
|
||||
Zero copy_() during replay - all data pre-filled.
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph.py
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkAttentionGraph:
|
||||
"""Container for a captured chunk attention graph."""
|
||||
graph: torch.cuda.CUDAGraph
|
||||
static_q: torch.Tensor
|
||||
static_k: torch.Tensor
|
||||
static_v: torch.Tensor
|
||||
static_output: torch.Tensor
|
||||
static_lse: torch.Tensor
|
||||
causal: bool
|
||||
|
||||
|
||||
def capture_chunk_attention_graph(
|
||||
chunk_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
scale: float,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
causal: bool = False,
|
||||
) -> ChunkAttentionGraph:
|
||||
"""Capture a CUDA graph for single chunk attention."""
|
||||
static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device)
|
||||
static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
|
||||
static_q.normal_()
|
||||
static_k.normal_()
|
||||
static_v.normal_()
|
||||
|
||||
# Warmup
|
||||
with torch.inference_mode():
|
||||
for _ in range(3):
|
||||
_ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.inference_mode():
|
||||
with torch.cuda.graph(graph):
|
||||
static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return ChunkAttentionGraph(
|
||||
graph=graph,
|
||||
static_q=static_q,
|
||||
static_k=static_k,
|
||||
static_v=static_v,
|
||||
static_output=static_output,
|
||||
static_lse=static_lse,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
chunk_size = 64
|
||||
num_chunks = 4
|
||||
num_heads = 8
|
||||
num_kv_heads = 8
|
||||
head_dim = 64
|
||||
scale = 1.0 / (head_dim ** 0.5)
|
||||
seq_len = chunk_size * num_chunks
|
||||
|
||||
print(f"Device: {torch.cuda.get_device_name()}")
|
||||
print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}")
|
||||
print(f"Total graphs: {num_chunks * (num_chunks + 1) // 2}")
|
||||
|
||||
# Test data
|
||||
full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device)
|
||||
full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
|
||||
# Reference
|
||||
with torch.inference_mode():
|
||||
full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True)
|
||||
|
||||
# Capture all graphs
|
||||
graphs: List[List[Optional[ChunkAttentionGraph]]] = [[None] * num_chunks for _ in range(num_chunks)]
|
||||
for q_idx in range(num_chunks):
|
||||
for k_idx in range(q_idx + 1):
|
||||
graphs[q_idx][k_idx] = capture_chunk_attention_graph(
|
||||
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype,
|
||||
causal=(k_idx == q_idx)
|
||||
)
|
||||
print("All graphs captured")
|
||||
|
||||
# Pre-fill static tensors
|
||||
for q_idx in range(num_chunks):
|
||||
for k_idx in range(q_idx + 1):
|
||||
g = graphs[q_idx][k_idx]
|
||||
g.static_q.copy_(full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size])
|
||||
g.static_k.copy_(full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size])
|
||||
g.static_v.copy_(full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size])
|
||||
print("Static tensors pre-filled")
|
||||
|
||||
# Replay and merge
|
||||
chunked_output = torch.zeros_like(full_output)
|
||||
for q_idx in range(num_chunks):
|
||||
acc_out, acc_lse = None, None
|
||||
for k_idx in range(q_idx + 1):
|
||||
g = graphs[q_idx][k_idx]
|
||||
g.graph.replay()
|
||||
out, lse = g.static_output.clone(), g.static_lse.clone()
|
||||
if acc_out is None:
|
||||
acc_out, acc_lse = out, lse
|
||||
else:
|
||||
with torch.inference_mode():
|
||||
acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse)
|
||||
chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Compare
|
||||
all_pass = True
|
||||
for q_idx in range(num_chunks):
|
||||
s, e = q_idx * chunk_size, (q_idx + 1) * chunk_size
|
||||
diff = (full_output[:, s:e] - chunked_output[:, s:e]).abs().max().item()
|
||||
status = "✅" if diff < 1e-2 else "❌"
|
||||
print(f"Q[{q_idx}]: max_diff={diff:.2e} {status}")
|
||||
if diff >= 1e-2:
|
||||
all_pass = False
|
||||
|
||||
print("✅ PASSED" if all_pass else "❌ FAILED")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,156 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test: Reuse a single CUDA Graph across all layers and all chunk pairs.
|
||||
|
||||
Key insight: LLM layers have identical computation structure.
|
||||
We only need 2 graphs (causal + non-causal), reused for all (layer, Q_i, K_j) combinations.
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph_reuse.py
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReusableChunkGraph:
|
||||
"""A single graph that can be reused with copy_() updates."""
|
||||
graph: torch.cuda.CUDAGraph
|
||||
static_q: torch.Tensor
|
||||
static_k: torch.Tensor
|
||||
static_v: torch.Tensor
|
||||
static_output: torch.Tensor
|
||||
static_lse: torch.Tensor
|
||||
|
||||
|
||||
def capture_reusable_graph(
|
||||
chunk_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
scale: float,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
causal: bool,
|
||||
) -> ReusableChunkGraph:
|
||||
"""Capture ONE graph to be reused for all chunk pairs."""
|
||||
static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device)
|
||||
static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
|
||||
static_q.normal_()
|
||||
static_k.normal_()
|
||||
static_v.normal_()
|
||||
|
||||
# Warmup
|
||||
with torch.inference_mode():
|
||||
for _ in range(3):
|
||||
_ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.inference_mode():
|
||||
with torch.cuda.graph(graph):
|
||||
static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return ReusableChunkGraph(
|
||||
graph=graph,
|
||||
static_q=static_q,
|
||||
static_k=static_k,
|
||||
static_v=static_v,
|
||||
static_output=static_output,
|
||||
static_lse=static_lse,
|
||||
)
|
||||
|
||||
|
||||
def replay_with_copy(graph: ReusableChunkGraph, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
"""Replay graph after updating static tensors with copy_()."""
|
||||
graph.static_q.copy_(q)
|
||||
graph.static_k.copy_(k)
|
||||
graph.static_v.copy_(v)
|
||||
graph.graph.replay()
|
||||
return graph.static_output.clone(), graph.static_lse.clone()
|
||||
|
||||
|
||||
def main():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
chunk_size = 64
|
||||
num_chunks = 4
|
||||
num_layers = 3 # Simulate multiple layers
|
||||
num_heads = 8
|
||||
num_kv_heads = 8
|
||||
head_dim = 64
|
||||
scale = 1.0 / (head_dim ** 0.5)
|
||||
seq_len = chunk_size * num_chunks
|
||||
|
||||
print(f"Device: {torch.cuda.get_device_name()}")
|
||||
print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}, Num layers: {num_layers}")
|
||||
print(f"Only 2 graphs (causal + non-causal) for ALL layer × chunk combinations")
|
||||
|
||||
# Capture only 2 graphs
|
||||
graph_causal = capture_reusable_graph(
|
||||
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=True
|
||||
)
|
||||
graph_non_causal = capture_reusable_graph(
|
||||
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=False
|
||||
)
|
||||
print("2 graphs captured (causal + non-causal)")
|
||||
|
||||
all_pass = True
|
||||
|
||||
for layer_id in range(num_layers):
|
||||
# Different Q/K/V for each layer (simulating different layer outputs)
|
||||
full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device)
|
||||
full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
|
||||
# Reference: full causal attention
|
||||
with torch.inference_mode():
|
||||
full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True)
|
||||
|
||||
# Chunked with graph reuse
|
||||
chunked_output = torch.zeros_like(full_output)
|
||||
|
||||
for q_idx in range(num_chunks):
|
||||
q_chunk = full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size]
|
||||
acc_out, acc_lse = None, None
|
||||
|
||||
for k_idx in range(q_idx + 1):
|
||||
k_chunk = full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
|
||||
v_chunk = full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
|
||||
|
||||
# Reuse graph with copy_()
|
||||
graph = graph_causal if k_idx == q_idx else graph_non_causal
|
||||
out, lse = replay_with_copy(graph, q_chunk, k_chunk, v_chunk)
|
||||
|
||||
if acc_out is None:
|
||||
acc_out, acc_lse = out, lse
|
||||
else:
|
||||
with torch.inference_mode():
|
||||
acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse)
|
||||
|
||||
chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Compare
|
||||
max_diff = (full_output - chunked_output).abs().max().item()
|
||||
status = "✅" if max_diff < 1e-2 else "❌"
|
||||
print(f"Layer {layer_id}: max_diff={max_diff:.2e} {status}")
|
||||
if max_diff >= 1e-2:
|
||||
all_pass = False
|
||||
|
||||
print("✅ PASSED - Single graph reuse across layers works!" if all_pass else "❌ FAILED")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,357 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CUDA Graph Memory Analysis Test
|
||||
|
||||
This script analyzes the memory overhead of CUDA Graph at each stage:
|
||||
1. Model loading
|
||||
2. StaticCache allocation
|
||||
3. Warmup runs
|
||||
4. Graph capture
|
||||
5. Graph replay
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py
|
||||
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --model ~/models/Qwen3-0.6B
|
||||
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --max-cache-len 2048
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.cache_utils import StaticCache
|
||||
|
||||
|
||||
def get_memory_mb():
|
||||
"""Get current allocated memory in MB."""
|
||||
return torch.cuda.memory_allocated() / 1024**2
|
||||
|
||||
|
||||
def get_memory_gb():
|
||||
"""Get current allocated memory in GB."""
|
||||
return torch.cuda.memory_allocated() / 1024**3
|
||||
|
||||
|
||||
def get_peak_memory_gb():
|
||||
"""Get peak allocated memory in GB."""
|
||||
return torch.cuda.max_memory_allocated() / 1024**3
|
||||
|
||||
|
||||
def print_separator(title=None):
|
||||
"""Print a separator line."""
|
||||
if title:
|
||||
print(f"\n{'=' * 70}")
|
||||
print(f" {title}")
|
||||
print(f"{'=' * 70}")
|
||||
else:
|
||||
print("-" * 70)
|
||||
|
||||
|
||||
def test_memory_stages(model_path: str, max_cache_len: int, batch_size: int = 1):
|
||||
"""
|
||||
Test memory usage at each stage of CUDA Graph setup.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model
|
||||
max_cache_len: Maximum cache length for StaticCache
|
||||
batch_size: Batch size for inference
|
||||
"""
|
||||
print_separator("CUDA Graph Memory Analysis")
|
||||
print(f"Model: {model_path}")
|
||||
print(f"Max cache length: {max_cache_len}")
|
||||
print(f"Batch size: {batch_size}")
|
||||
|
||||
results = {}
|
||||
|
||||
# Stage 0: Initial
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
results["initial"] = get_memory_mb()
|
||||
|
||||
# Stage 1: Load model
|
||||
print_separator("Stage 1: Model Loading")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
model.eval()
|
||||
|
||||
results["after_model"] = get_memory_mb()
|
||||
model_size = results["after_model"] - results["initial"]
|
||||
print(f" Memory: {results['after_model']:.0f} MB")
|
||||
print(f" Model size: {model_size:.0f} MB ({model_size/1024:.2f} GB)")
|
||||
|
||||
config = model.config
|
||||
device = next(model.parameters()).device
|
||||
dtype = next(model.parameters()).dtype
|
||||
|
||||
# Stage 2: Allocate StaticCache
|
||||
print_separator("Stage 2: StaticCache Allocation")
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
before = get_memory_mb()
|
||||
|
||||
static_cache = StaticCache(
|
||||
config=config,
|
||||
max_batch_size=batch_size,
|
||||
max_cache_len=max_cache_len,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
results["after_cache"] = get_memory_mb()
|
||||
cache_size = results["after_cache"] - before
|
||||
print(f" Memory: {results['after_cache']:.0f} MB")
|
||||
print(f" StaticCache size: {cache_size:.0f} MB")
|
||||
|
||||
# Calculate theoretical cache size
|
||||
num_layers = config.num_hidden_layers
|
||||
num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
dtype_size = 2 # bfloat16
|
||||
|
||||
theoretical_cache = (
|
||||
num_layers * 2 * batch_size * num_kv_heads * max_cache_len * head_dim * dtype_size
|
||||
) / (1024**2)
|
||||
print(f" Theoretical: {theoretical_cache:.0f} MB")
|
||||
print(f" Overhead: {cache_size - theoretical_cache:.0f} MB ({(cache_size/theoretical_cache - 1)*100:.1f}%)")
|
||||
|
||||
# Stage 3: Prepare static tensors
|
||||
print_separator("Stage 3: Static Tensor Allocation")
|
||||
before = get_memory_mb()
|
||||
|
||||
static_input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
|
||||
static_position_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
|
||||
static_cache_position = torch.tensor([0], dtype=torch.long, device=device)
|
||||
|
||||
results["after_tensors"] = get_memory_mb()
|
||||
tensor_size = results["after_tensors"] - before
|
||||
print(f" Memory: {results['after_tensors']:.0f} MB")
|
||||
print(f" Static tensors: {tensor_size:.2f} MB (negligible)")
|
||||
|
||||
# Stage 4: Warmup runs
|
||||
print_separator("Stage 4: Warmup Runs (3 iterations)")
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
before = get_memory_mb()
|
||||
|
||||
with torch.inference_mode():
|
||||
for i in range(3):
|
||||
_ = model(
|
||||
input_ids=static_input_ids,
|
||||
position_ids=static_position_ids,
|
||||
past_key_values=static_cache,
|
||||
cache_position=static_cache_position,
|
||||
use_cache=True,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
results["after_warmup"] = get_memory_mb()
|
||||
results["warmup_peak"] = get_peak_memory_gb() * 1024
|
||||
warmup_size = results["after_warmup"] - before
|
||||
print(f" Memory: {results['after_warmup']:.0f} MB")
|
||||
print(f" Peak: {results['warmup_peak']:.0f} MB")
|
||||
print(f" Warmup overhead: {warmup_size:.0f} MB")
|
||||
|
||||
# Stage 5: CUDA Graph capture
|
||||
print_separator("Stage 5: CUDA Graph Capture")
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
before = get_memory_mb()
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.inference_mode():
|
||||
with torch.cuda.graph(graph):
|
||||
outputs = model(
|
||||
input_ids=static_input_ids,
|
||||
position_ids=static_position_ids,
|
||||
past_key_values=static_cache,
|
||||
cache_position=static_cache_position,
|
||||
use_cache=True,
|
||||
)
|
||||
static_logits = outputs.logits
|
||||
torch.cuda.synchronize()
|
||||
|
||||
results["after_capture"] = get_memory_mb()
|
||||
results["capture_peak"] = get_peak_memory_gb() * 1024
|
||||
capture_size = results["after_capture"] - before
|
||||
print(f" Memory: {results['after_capture']:.0f} MB")
|
||||
print(f" Peak: {results['capture_peak']:.0f} MB")
|
||||
print(f" Graph capture overhead: {capture_size:.0f} MB")
|
||||
|
||||
# Stage 6: Graph replay
|
||||
print_separator("Stage 6: Graph Replay (10 iterations)")
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
before = get_memory_mb()
|
||||
|
||||
with torch.inference_mode():
|
||||
for _ in range(10):
|
||||
static_input_ids.fill_(1)
|
||||
static_cache_position.fill_(0)
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
results["after_replay"] = get_memory_mb()
|
||||
results["replay_peak"] = get_peak_memory_gb() * 1024
|
||||
replay_change = results["after_replay"] - before
|
||||
print(f" Memory: {results['after_replay']:.0f} MB")
|
||||
print(f" Peak: {results['replay_peak']:.0f} MB")
|
||||
print(f" Replay memory change: {replay_change:.0f} MB (should be ~0)")
|
||||
|
||||
# Summary
|
||||
print_separator("SUMMARY")
|
||||
total_overhead = results["after_capture"] - results["after_model"]
|
||||
|
||||
print(f"{'Stage':<25} {'Memory (MB)':>12} {'Delta (MB)':>12}")
|
||||
print("-" * 50)
|
||||
print(f"{'Model loaded':<25} {results['after_model']:>12.0f} {model_size:>+12.0f}")
|
||||
print(f"{'StaticCache allocated':<25} {results['after_cache']:>12.0f} {cache_size:>+12.0f}")
|
||||
print(f"{'After warmup':<25} {results['after_warmup']:>12.0f} {warmup_size:>+12.0f}")
|
||||
print(f"{'After graph capture':<25} {results['after_capture']:>12.0f} {capture_size:>+12.0f}")
|
||||
print(f"{'After graph replay':<25} {results['after_replay']:>12.0f} {replay_change:>+12.0f}")
|
||||
print("-" * 50)
|
||||
print(f"{'Total (excl. model)':<25} {'':<12} {total_overhead:>+12.0f}")
|
||||
|
||||
print_separator("KEY FINDINGS")
|
||||
print(f" 1. Model size: {model_size/1024:.2f} GB")
|
||||
print(f" 2. StaticCache: {cache_size:.0f} MB (main overhead, scales with cache_len)")
|
||||
print(f" 3. Graph capture: {capture_size:.0f} MB (small, stores kernel sequence)")
|
||||
print(f" 4. Graph replay: {replay_change:.0f} MB (zero allocation, reuses memory)")
|
||||
print(f" 5. Total CUDA Graph overhead: {total_overhead:.0f} MB")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_cache_length_scaling(model_path: str, cache_lengths: list):
|
||||
"""
|
||||
Test how memory scales with different cache lengths.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model
|
||||
cache_lengths: List of cache lengths to test
|
||||
"""
|
||||
print_separator("Cache Length Scaling Test")
|
||||
print(f"Model: {model_path}")
|
||||
print(f"Cache lengths: {cache_lengths}")
|
||||
|
||||
# Load model once
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
model.eval()
|
||||
|
||||
config = model.config
|
||||
device = next(model.parameters()).device
|
||||
dtype = next(model.parameters()).dtype
|
||||
|
||||
model_mem = get_memory_mb()
|
||||
|
||||
results = []
|
||||
for cache_len in cache_lengths:
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Create cache and capture graph
|
||||
static_cache = StaticCache(
|
||||
config=config,
|
||||
max_batch_size=1,
|
||||
max_cache_len=cache_len,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
static_input_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
|
||||
static_position_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
|
||||
static_cache_position = torch.tensor([0], dtype=torch.long, device=device)
|
||||
|
||||
with torch.inference_mode():
|
||||
# Warmup
|
||||
for _ in range(3):
|
||||
_ = model(
|
||||
input_ids=static_input_ids,
|
||||
position_ids=static_position_ids,
|
||||
past_key_values=static_cache,
|
||||
cache_position=static_cache_position,
|
||||
use_cache=True,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
outputs = model(
|
||||
input_ids=static_input_ids,
|
||||
position_ids=static_position_ids,
|
||||
past_key_values=static_cache,
|
||||
cache_position=static_cache_position,
|
||||
use_cache=True,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
total_mem = get_memory_mb()
|
||||
overhead = total_mem - model_mem
|
||||
results.append((cache_len, total_mem, overhead))
|
||||
|
||||
del static_cache, graph
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Print results
|
||||
print()
|
||||
print(f"{'Cache Length':>12} | {'Total (MB)':>12} | {'Overhead (MB)':>14} | {'Per 1K tokens':>14}")
|
||||
print("-" * 60)
|
||||
for cache_len, total, overhead in results:
|
||||
per_1k = overhead / (cache_len / 1000)
|
||||
print(f"{cache_len:>12} | {total:>12.0f} | {overhead:>14.0f} | {per_1k:>14.1f}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="CUDA Graph Memory Analysis")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="~/models/Qwen3-4B-Instruct-2507",
|
||||
help="Model path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-cache-len",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Maximum cache length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-scaling",
|
||||
action="store_true",
|
||||
help="Test cache length scaling",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = os.path.expanduser(args.model)
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
print("CUDA is not available!")
|
||||
return
|
||||
|
||||
print(f"Device: cuda:{torch.cuda.current_device()}")
|
||||
print(f"GPU: {torch.cuda.get_device_name()}")
|
||||
|
||||
if args.test_scaling:
|
||||
cache_lengths = [256, 512, 1024, 2048, 4096]
|
||||
test_cache_length_scaling(model_path, cache_lengths)
|
||||
else:
|
||||
test_memory_stages(model_path, args.max_cache_len, args.batch_size)
|
||||
|
||||
print("\ntest_cudagraph_memory: PASSED")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,149 +0,0 @@
|
||||
"""
|
||||
Test: GPU-only density alignment verification
|
||||
|
||||
验证 xattn_bsa.py 中 GPU-only 路径的 density 计算是否与独立调用 xattn_estimate 一致。
|
||||
|
||||
流程:
|
||||
1. 运行 GPU-only 推理,保存 Q, K, mask, attn_sums
|
||||
2. 加载保存的数据,独立调用 xattn_estimate
|
||||
3. 比较两者的 mask 和 density
|
||||
"""
|
||||
import torch
|
||||
import sys
|
||||
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
|
||||
|
||||
from nanovllm.ops.xattn import xattn_estimate
|
||||
|
||||
# ============================================================
|
||||
# 参数配置
|
||||
# ============================================================
|
||||
|
||||
DATA_PATH = "/home/zijie/Code/nano-vllm/results/mask_alignment/gpuonly_layer0.pt"
|
||||
|
||||
# ============================================================
|
||||
# 加载保存的数据
|
||||
# ============================================================
|
||||
|
||||
print(f"Loading data from {DATA_PATH}")
|
||||
data = torch.load(DATA_PATH, weights_only=False)
|
||||
|
||||
Q = data["Q"].cuda() # [1, num_heads, q_len, head_dim]
|
||||
K = data["K"].cuda() # [1, num_heads, k_len, head_dim]
|
||||
chunk_size = data["chunk_size"]
|
||||
block_size = data["block_size"]
|
||||
stride = data["stride"]
|
||||
threshold = data["threshold"]
|
||||
mask_saved = data["mask"] # [1, num_heads, valid_q_blocks, valid_k_blocks]
|
||||
attn_sums_saved = data["attn_sums"]
|
||||
q_len = data["q_len"]
|
||||
k_len = data["k_len"]
|
||||
valid_q_blocks = data["valid_q_blocks"]
|
||||
valid_k_blocks = data["valid_k_blocks"]
|
||||
|
||||
print(f"Q shape: {Q.shape}")
|
||||
print(f"K shape: {K.shape}")
|
||||
print(f"q_len: {q_len}, k_len: {k_len}")
|
||||
print(f"chunk_size: {chunk_size}, block_size: {block_size}, stride: {stride}, threshold: {threshold}")
|
||||
print(f"valid_q_blocks: {valid_q_blocks}, valid_k_blocks: {valid_k_blocks}")
|
||||
print(f"mask_saved shape: {mask_saved.shape}")
|
||||
|
||||
# ============================================================
|
||||
# 独立调用 xattn_estimate
|
||||
# ============================================================
|
||||
|
||||
print("\nCalling xattn_estimate independently...")
|
||||
attn_sums_ext, mask_ext = xattn_estimate(
|
||||
Q, K,
|
||||
chunk_size=chunk_size,
|
||||
block_size=block_size,
|
||||
stride=stride,
|
||||
threshold=threshold,
|
||||
use_triton=True,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
# Trim to valid blocks
|
||||
mask_ext_valid = mask_ext[:, :, :valid_q_blocks, :valid_k_blocks]
|
||||
attn_sums_ext_valid = attn_sums_ext[:, :, :valid_q_blocks, :valid_k_blocks]
|
||||
|
||||
print(f"mask_ext shape: {mask_ext.shape}")
|
||||
print(f"mask_ext_valid shape: {mask_ext_valid.shape}")
|
||||
|
||||
# ============================================================
|
||||
# 比较 attn_sums
|
||||
# ============================================================
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Comparing attn_sums")
|
||||
print("=" * 60)
|
||||
|
||||
attn_sums_saved_gpu = attn_sums_saved.cuda()
|
||||
attn_diff = (attn_sums_ext_valid - attn_sums_saved_gpu).abs()
|
||||
print(f"attn_sums max diff: {attn_diff.max().item():.6e}")
|
||||
print(f"attn_sums mean diff: {attn_diff.mean().item():.6e}")
|
||||
|
||||
# Check if attn_sums match
|
||||
attn_match = attn_diff.max().item() < 1e-4
|
||||
print(f"attn_sums match: {attn_match}")
|
||||
|
||||
# ============================================================
|
||||
# 比较 mask
|
||||
# ============================================================
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Comparing mask")
|
||||
print("=" * 60)
|
||||
|
||||
mask_saved_gpu = mask_saved.cuda()
|
||||
mask_match = (mask_ext_valid == mask_saved_gpu).all().item()
|
||||
print(f"mask exact match: {mask_match}")
|
||||
|
||||
if not mask_match:
|
||||
diff_count = (mask_ext_valid != mask_saved_gpu).sum().item()
|
||||
total_count = mask_ext_valid.numel()
|
||||
print(f"mask diff count: {diff_count} / {total_count} ({diff_count/total_count*100:.2f}%)")
|
||||
|
||||
# ============================================================
|
||||
# 计算 density
|
||||
# ============================================================
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Comparing density")
|
||||
print("=" * 60)
|
||||
|
||||
# 计算 causal mask
|
||||
q_offset_blocks = valid_k_blocks - valid_q_blocks
|
||||
indices = torch.arange(valid_k_blocks, device=mask_ext_valid.device).unsqueeze(0)
|
||||
q_indices = torch.arange(valid_q_blocks, device=mask_ext_valid.device).unsqueeze(1)
|
||||
causal_mask = indices <= (q_indices + q_offset_blocks)
|
||||
|
||||
# Density from saved mask
|
||||
total_saved = causal_mask.sum().item() * mask_saved_gpu.shape[0] * mask_saved_gpu.shape[1]
|
||||
selected_saved = (mask_saved_gpu & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||
density_saved = selected_saved / total_saved
|
||||
|
||||
# Density from external xattn_estimate
|
||||
total_ext = causal_mask.sum().item() * mask_ext_valid.shape[0] * mask_ext_valid.shape[1]
|
||||
selected_ext = (mask_ext_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||
density_ext = selected_ext / total_ext
|
||||
|
||||
print(f"Saved density: {density_saved:.6f} (selected={selected_saved}, total={total_saved})")
|
||||
print(f"External density: {density_ext:.6f} (selected={selected_ext}, total={total_ext})")
|
||||
print(f"Density diff: {abs(density_saved - density_ext):.6f}")
|
||||
|
||||
# ============================================================
|
||||
# 结论
|
||||
# ============================================================
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("RESULT")
|
||||
print("=" * 60)
|
||||
|
||||
if attn_match and mask_match:
|
||||
print("✅ PASSED: GPU-only density matches external xattn_estimate")
|
||||
else:
|
||||
print("❌ FAILED: Mismatch detected")
|
||||
if not attn_match:
|
||||
print(" - attn_sums mismatch")
|
||||
if not mask_match:
|
||||
print(" - mask mismatch")
|
||||
@@ -1,442 +0,0 @@
|
||||
"""
|
||||
Test: Hierarchical Block Sum Estimation for XAttention
|
||||
|
||||
Verify that hierarchical estimation (small estimate_block_size + aggregation)
|
||||
produces equivalent results to direct estimation (large block_size), while
|
||||
being significantly faster.
|
||||
|
||||
Key changes validated:
|
||||
1. Hierarchical block sum: estimate_block_size=1024 → aggregate to cpu_block_size=4096
|
||||
2. Selection strategy: score + threshold (NOT mask + majority voting)
|
||||
|
||||
This test uses pure torch + xattn kernels, independent of nanovllm framework.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import torch
|
||||
import math
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
|
||||
|
||||
# ============================================================
|
||||
# Configuration
|
||||
# ============================================================
|
||||
|
||||
# Model dimensions (Llama-3.1-8B-Instruct style)
|
||||
NUM_HEADS = 32
|
||||
NUM_KV_HEADS = 8
|
||||
HEAD_DIM = 128
|
||||
STRIDE = 8
|
||||
|
||||
# Block sizes
|
||||
CPU_BLOCK_SIZE = 4096 # External CPU block size (fixed, for overlap)
|
||||
ESTIMATE_BLOCK_SIZE = 1024 # Internal estimate block size (optimized)
|
||||
|
||||
# Selection parameters
|
||||
THRESHOLD = 0.95 # Cumulative attention threshold
|
||||
|
||||
# ============================================================
|
||||
# Hierarchical Estimation Implementation
|
||||
# ============================================================
|
||||
|
||||
def compute_attention_scores(Q, K_blocks, stride):
    """
    Score Q against a sequence of K blocks and concatenate the results.

    Args:
        Q: [1, num_heads, q_len, head_dim] query tensor.
        K_blocks: Iterable of key tensors, each [1, num_heads, block_size, head_dim].
        stride: Reshape stride passed through to the fused GEMM kernel.

    Returns:
        attn_scores: [1, num_heads, q_len // stride, total_k_reshaped] raw
        (pre-softmax) attention scores covering all K blocks.
    """
    num_q_rows = Q.shape[2] // stride

    # One fused-GEMM call per K block; each yields a score chunk for that block.
    per_block_scores = [
        flat_group_gemm_fuse_reshape(
            Q, kb, stride,
            chunk_start=0,
            chunk_end=num_q_rows,
            is_causal=False,
        )
        for kb in K_blocks
    ]

    # Stitch the chunks back together along the key dimension.
    return torch.cat(per_block_scores, dim=-1)
|
||||
|
||||
|
||||
def hierarchical_block_sum(
    attn_scores,
    estimate_block_size,
    cpu_block_size,
    stride,
    head_dim,
):
    """
    Compute hierarchical block sums: fine-grained softmax sums aggregated up
    to CPU-block granularity.

    Args:
        attn_scores: [batch, heads, q_reshaped, k_reshaped] raw attention scores.
        estimate_block_size: Small block size for efficient softmax (e.g. 1024).
        cpu_block_size: External CPU block size (e.g. 4096); assumed to be a
            multiple of estimate_block_size.
        stride: Stride used in the reshape that produced attn_scores.
        head_dim: Head dimension, used to derive the softmax scale.

    Returns:
        cpu_block_scores: [batch, heads, num_cpu_blocks] attention mass per CPU block.
        block_sums_fine: [batch, heads, q_est_blocks, k_est_blocks] fine-grained sums.
    """
    batch_size, num_heads, q_reshaped, _ = attn_scores.shape

    # Fine-grained block size in reshaped (stride-compressed) units, e.g. 1024/8 = 128.
    # (Removed: unused reshaped_cpu_bs local and the no-op `norm = 1.0` divisor.)
    reshaped_est_bs = estimate_block_size // stride

    # log2(e) / sqrt(head_dim) / stride — scale for the exp2-based softmax kernel.
    scale = 1.4426950408889634 / math.sqrt(head_dim) / stride

    # Segment size for the softmax kernel (capped at 4096).
    segment_size = min(4096, reshaped_est_bs)

    # Step 1: Fine-grained softmax + block sum.
    block_sums_fine = softmax_fuse_block_sum(
        attn_scores,
        reshaped_est_bs,
        segment_size,
        chunk_start=0,
        chunk_end=q_reshaped,
        real_q_len=q_reshaped,
        scale=scale,
        is_causal=False,
    )
    # block_sums_fine: [batch, heads, q_est_blocks, k_est_blocks]

    q_est_blocks = block_sums_fine.shape[2]
    k_est_blocks = block_sums_fine.shape[3]

    # Step 2: Aggregate to CPU block level.
    # ratio = cpu_block_size / estimate_block_size (e.g. 4).
    ratio = cpu_block_size // estimate_block_size
    num_cpu_blocks = k_est_blocks // ratio

    # Reshape and sum along the K dimension:
    # [batch, heads, q_est, k_est] -> [batch, heads, q_est, num_cpu, ratio] -> sum(ratio)
    block_sums_coarse = block_sums_fine.view(
        batch_size, num_heads, q_est_blocks, num_cpu_blocks, ratio
    ).sum(dim=-1)  # [batch, heads, q_est_blocks, num_cpu_blocks]

    # Step 3: Sum over the Q dimension (total attention from Q chunk to each K block).
    cpu_block_scores = block_sums_coarse.sum(dim=2)  # [batch, heads, num_cpu_blocks]

    return cpu_block_scores, block_sums_fine
|
||||
|
||||
|
||||
def direct_block_sum(
    attn_scores,
    cpu_block_size,
    stride,
    head_dim,
):
    """
    Compute block sums directly at CPU block granularity (baseline for comparison).

    Args:
        attn_scores: [batch, heads, q_reshaped, k_reshaped] raw attention scores.
        cpu_block_size: Block size in tokens (e.g. 4096).
        stride: Stride used in the reshape that produced attn_scores.
        head_dim: Head dimension, used to derive the softmax scale.

    Returns:
        cpu_block_scores: [batch, heads, num_cpu_blocks] attention mass per CPU block.
    """
    # Only the Q extent is needed; the other shape fields were unused dead locals.
    q_reshaped = attn_scores.shape[2]

    # CPU block size in reshaped (stride-compressed) units, e.g. 4096/8 = 512.
    reshaped_cpu_bs = cpu_block_size // stride

    # log2(e) / sqrt(head_dim) / stride — scale for the exp2-based softmax kernel.
    # (Removed the no-op `norm = 1.0` divisor.)
    scale = 1.4426950408889634 / math.sqrt(head_dim) / stride
    segment_size = min(4096, reshaped_cpu_bs)

    block_sums = softmax_fuse_block_sum(
        attn_scores,
        reshaped_cpu_bs,
        segment_size,
        chunk_start=0,
        chunk_end=q_reshaped,
        real_q_len=q_reshaped,
        scale=scale,
        is_causal=False,
    )
    # block_sums: [batch, heads, q_cpu_blocks, k_cpu_blocks]

    # Sum over the Q dimension.
    cpu_block_scores = block_sums.sum(dim=2)  # [batch, heads, num_cpu_blocks]

    return cpu_block_scores
|
||||
|
||||
|
||||
def select_blocks_by_score(
    cpu_block_scores,
    threshold=0.95,
    always_include_first=True,
    always_include_last=True,
):
    """
    Choose CPU blocks whose cumulative attention mass reaches `threshold`.

    ⚠️ IMPORTANT: This replaces the original mask + majority voting strategy.
    This change should be documented in the final implementation.

    Args:
        cpu_block_scores: [batch, heads, num_cpu_blocks] attention score per block.
        threshold: Cumulative attention mass to cover (e.g. 0.95).
        always_include_first: Pin the first block (attention sink).
        always_include_last: Pin the last block (safety).

    Returns:
        selected_block_ids: Sorted list of selected block indices.
        density: Fraction of blocks selected.
    """
    # Collapse batch and head dimensions into one score per block.
    per_block = cpu_block_scores.mean(dim=(0, 1))  # [num_cpu_blocks]
    n_blocks = per_block.shape[0]

    # Normalize into an attention distribution over blocks.
    ratios = per_block / per_block.sum()

    # Greedily take highest-scoring blocks until the mass threshold is met.
    chosen = set()
    running = 0.0
    for block_id in torch.argsort(ratios, descending=True).tolist():
        chosen.add(block_id)
        running += ratios[block_id].item()
        if running >= threshold:
            break

    # Sink and safety blocks are pinned regardless of score.
    if always_include_first:
        chosen.add(0)
    if always_include_last:
        chosen.add(n_blocks - 1)

    selected_block_ids = sorted(chosen)
    return selected_block_ids, len(selected_block_ids) / n_blocks
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Test Cases
|
||||
# ============================================================
|
||||
|
||||
def test_equivalence():
    """
    Check that hierarchical estimation matches direct estimation.

    Both paths consume the same precomputed score matrix; the per-CPU-block
    sums must agree within a small absolute tolerance.
    """
    print("=" * 60)
    print("Test 1: Hierarchical vs Direct - Equivalence")
    print("=" * 60)

    # One Q chunk of CPU_BLOCK_SIZE tokens against four K blocks.
    q_len = CPU_BLOCK_SIZE
    num_k_blocks = 4

    Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
    K_blocks = [
        torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
        for _ in range(num_k_blocks)
    ]

    attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
    print(f"attn_scores shape: {attn_scores.shape}")

    # Fast path: small estimate blocks aggregated up.
    scores_hier, _ = hierarchical_block_sum(
        attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
    )
    print(f"scores_hier shape: {scores_hier.shape}")

    # Baseline: direct softmax at CPU block granularity.
    scores_direct = direct_block_sum(
        attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
    )
    print(f"scores_direct shape: {scores_direct.shape}")

    diff = (scores_hier - scores_direct).abs().max().item()
    print(f"\nMax difference: {diff:.6f}")

    print("\nPer-block scores comparison:")
    hier_row, direct_row = scores_hier[0, 0], scores_direct[0, 0]
    for i in range(num_k_blocks):
        h, d = hier_row[i].item(), direct_row[i].item()
        print(f"  Block {i}: hierarchical={h:.4f}, direct={d:.4f}, diff={abs(h-d):.6f}")

    passed = diff < 0.01
    print(f"\nTest 1: {'PASSED' if passed else 'FAILED'}")
    return passed
|
||||
|
||||
|
||||
def test_selection():
    """
    Exercise the score + threshold selection strategy at several thresholds.

    K blocks 0, 3 and 4 are scaled up so they should attract more attention
    mass; results are printed for visual inspection.
    """
    print("\n" + "=" * 60)
    print("Test 2: Score + Threshold Selection")
    print("=" * 60)

    q_len = CPU_BLOCK_SIZE
    num_k_blocks = 8

    Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')

    # Scale selected blocks by 2x to make them "important".
    K_blocks = [
        torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
        * (2.0 if i in [0, 3, 4] else 1.0)
        for i in range(num_k_blocks)
    ]

    attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
    scores, _ = hierarchical_block_sum(
        attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
    )

    print("\nCPU block scores (head 0):")
    for i in range(num_k_blocks):
        print(f"  Block {i}: {scores[0, 0, i].item():.4f}")

    # Selection should grow monotonically with the threshold.
    for thresh in [0.9, 0.95, 0.99]:
        selected, density = select_blocks_by_score(scores, threshold=thresh)
        print(f"\nThreshold {thresh}: selected {len(selected)}/{num_k_blocks} blocks ({density:.1%})")
        print(f"  Selected: {selected}")

    print("\nTest 2: PASSED (visual inspection)")
    return True
|
||||
|
||||
|
||||
def test_performance():
    """
    Benchmark hierarchical vs direct estimation on a 64K-token score matrix.

    Passes when the hierarchical path is more than 5x faster than the
    direct (large-block) path.
    """
    print("\n" + "=" * 60)
    print("Test 3: Performance Benchmark")
    print("=" * 60)

    import time

    NUM_WARMUP = 3
    NUM_RUNS = 10

    # 16 CPU blocks x 4096 tokens = 64K context.
    q_len = CPU_BLOCK_SIZE
    num_k_blocks = 16

    Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
    K_blocks = [
        torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
        for _ in range(num_k_blocks)
    ]

    # Both methods consume the same precomputed score matrix.
    attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
    print(f"attn_scores shape: {attn_scores.shape}")
    print(f"Context: {num_k_blocks * CPU_BLOCK_SIZE // 1024}K tokens")

    def _bench(fn):
        """Warm up, then return mean wall time per call in milliseconds."""
        for _ in range(NUM_WARMUP):
            fn()
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(NUM_RUNS):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - t0) / NUM_RUNS * 1000

    hier_time = _bench(lambda: hierarchical_block_sum(
        attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM))
    direct_time = _bench(lambda: direct_block_sum(
        attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM))

    speedup = direct_time / hier_time

    print(f"\nResults:")
    print(f"  Hierarchical (bs=1024): {hier_time:.2f} ms")
    print(f"  Direct (bs=4096):       {direct_time:.2f} ms")
    print(f"  Speedup: {speedup:.2f}x")

    passed = speedup > 5.0  # Expect at least 5x speedup
    print(f"\nTest 3: {'PASSED' if passed else 'FAILED'} (speedup > 5x expected)")
    return passed
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
    print("=" * 60)
    print("Hierarchical Block Sum Estimation Test")
    print("=" * 60)
    # Echo the module-level configuration so runs are self-describing.
    print(f"\nConfiguration:")
    print(f"  NUM_HEADS: {NUM_HEADS}")
    print(f"  NUM_KV_HEADS: {NUM_KV_HEADS}")
    print(f"  HEAD_DIM: {HEAD_DIM}")
    print(f"  STRIDE: {STRIDE}")
    print(f"  CPU_BLOCK_SIZE: {CPU_BLOCK_SIZE}")
    print(f"  ESTIMATE_BLOCK_SIZE: {ESTIMATE_BLOCK_SIZE}")
    print(f"  THRESHOLD: {THRESHOLD}")
    print()

    # The list literal evaluates left-to-right, so tests run in order.
    results = [
        ("Equivalence", test_equivalence()),
        ("Selection", test_selection()),
        ("Performance", test_performance()),
    ]

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for name, ok in results:
        print(f"  {name}: {'PASSED' if ok else 'FAILED'}")

    print("=" * 60)
    if all(ok for _, ok in results):
        print("test_hierarchical_estimate: ALL PASSED")
        sys.exit(0)
    else:
        print("test_hierarchical_estimate: SOME FAILED")
        sys.exit(1)
|
||||
@@ -1,254 +0,0 @@
|
||||
"""
|
||||
Needle-in-a-haystack test for LLM.
|
||||
|
||||
Tests: Long context retrieval capability with configurable sequence length.
|
||||
|
||||
NOTE: CPU offload mode has a known bug that causes incorrect outputs for
|
||||
sequences longer than ~200 tokens. Use --no-offload for correctness testing.
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
|
||||
|
||||
import argparse
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from nanovllm.config import SparsePolicyType
|
||||
from utils import generate_needle_prompt, check_needle_answer
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Test
|
||||
# ============================================================
|
||||
|
||||
def run_needle_test(
    model_path: str,
    max_model_len: int,
    input_len: int,
    num_gpu_blocks: int = 4,
    block_size: int = 1024,
    needle_position: float = 0.5,
    needle_value: str = "7492",
    max_new_tokens: int = 32,
    enable_cpu_offload: bool = False,
    enable_quest: bool = False,
    enable_xattn_bsa: bool = False,
    sparse_topk: int = 8,
    sparse_threshold: int = 4,
    sparse_samples: int = 128,
    verbose: bool = True,
) -> bool:
    """
    Run one needle-in-haystack retrieval test through nanovllm.

    Args:
        model_path: Path to the model weights/tokenizer.
        max_model_len: Maximum model context length.
        input_len: Target input sequence length in tokens.
        num_gpu_blocks: GPU block budget when CPU offload is enabled.
        block_size: KV cache block size.
        needle_position: Relative needle position in the haystack (0.0-1.0).
        needle_value: Secret value the model must retrieve.
        max_new_tokens: Generation budget.
        enable_cpu_offload: Run with the CPU-offload KV cache.
        enable_quest: Use Quest sparse attention (decode-only Top-K).
        enable_xattn_bsa: Use XAttention BSA sparse attention (prefill-only).
        sparse_topk: Top-K blocks for Quest.
        sparse_threshold: Quest block threshold, or XAttention BSA threshold
            on a 0-9 integer scale (converted to 0.0-1.0 below).
        sparse_samples: Samples per chunk for XAttention BSA estimation.
        verbose: Print configuration and result banners.

    Returns:
        True if the generated output contains the needle value.
    """
    # XAttention BSA takes precedence over Quest; otherwise full attention.
    sparse_policy = (
        SparsePolicyType.XATTN_BSA if enable_xattn_bsa
        else SparsePolicyType.QUEST if enable_quest
        else SparsePolicyType.FULL
    )

    if verbose:
        print(f"\n{'='*60}")
        print(f"Needle-in-Haystack Test")
        print(f"{'='*60}")
        print(f"Model: {model_path}")
        print(f"Max model len: {max_model_len}")
        print(f"Input length: {input_len}")
        print(f"Block size: {block_size}")
        print(f"Needle position: {needle_position:.0%}")
        print(f"Needle value: {needle_value}")
        print(f"CPU offload: {enable_cpu_offload}")
        if enable_cpu_offload:
            print(f"Sparse policy: {sparse_policy.name}")
            if sparse_policy == SparsePolicyType.QUEST:
                print(f"  Quest: topk={sparse_topk}, threshold={sparse_threshold}")
            elif sparse_policy == SparsePolicyType.XATTN_BSA:
                print(f"  XAttention BSA: threshold={sparse_threshold}, samples={sparse_samples}")
        print(f"{'='*60}\n")

    # 1. Build the engine. Offload-only options are attached conditionally so
    #    the non-offload path stays a plain LLM construction.
    engine_args = {
        "enforce_eager": True,
        "max_model_len": max_model_len,
        "max_num_batched_tokens": max_model_len,
        "enable_cpu_offload": enable_cpu_offload,
        "kvcache_block_size": block_size,
    }
    if enable_cpu_offload:
        engine_args["num_gpu_blocks"] = num_gpu_blocks
        engine_args["sparse_policy"] = sparse_policy
        if sparse_policy == SparsePolicyType.QUEST:
            engine_args["sparse_topk_blocks"] = sparse_topk
            engine_args["sparse_threshold_blocks"] = sparse_threshold
        elif sparse_policy == SparsePolicyType.XATTN_BSA:
            engine_args["sparse_threshold"] = float(sparse_threshold) / 10.0  # Convert to 0.0-1.0 range
            engine_args["sparse_samples_per_chunk"] = sparse_samples

    llm = LLM(model_path, **engine_args)

    # 2. Build the haystack prompt with the needle embedded at the target position.
    prompt, expected = generate_needle_prompt(
        tokenizer=llm.tokenizer,
        target_length=input_len,
        needle_position=needle_position,
        needle_value=needle_value,
    )

    # 3. Generate.
    sampling = SamplingParams(
        temperature=0.6,  # Moderate temperature
        max_tokens=max_new_tokens,
    )
    outputs = llm.generate([prompt], sampling, use_tqdm=True)

    # 4. Grade the answer.
    result = outputs[0]
    output_text = result["text"]
    output_token_ids = result["token_ids"]
    passed = check_needle_answer(output_text, expected)

    if verbose:
        print(f"\n{'='*60}")
        print(f"Result")
        print(f"{'='*60}")
        print(f"Expected: {expected}")
        print(f"Output tokens ({len(output_token_ids)}): {output_token_ids[:20]}")
        print(f"Output: {output_text[:200]}...")
        print(f"Status: {'PASSED' if passed else 'FAILED'}")
        print(f"{'='*60}\n")

    return passed
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CLI Entry Point
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Needle-in-haystack test for long context LLM")

    # (flags, options) pairs — a single table keeps defaults/help easy to scan.
    arg_table = [
        (("--model", "-m"),
         dict(type=str,
              default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
              help="Path to model")),
        (("--max-model-len",),
         dict(type=int, default=128 * 1024, help="Maximum model context length")),
        (("--input-len",),
         dict(type=int, default=8 * 1024, help="Target input sequence length")),
        (("--num-gpu-blocks",),
         dict(type=int, default=2, help="Number of GPU blocks for CPU offload")),
        (("--block-size",),
         dict(type=int, default=1024, help="KV cache block size")),
        (("--needle-position",),
         dict(type=float, default=0.5,
              help="Needle position (0.0=start, 0.5=middle, 1.0=end)")),
        (("--needle-value",),
         dict(type=str, default="7492", help="The secret value to hide")),
        (("--max-new-tokens",),
         dict(type=int, default=32, help="Maximum tokens to generate")),
        (("--enable-offload",),
         dict(action="store_true",
              help="Enable CPU offload (has known bug for long sequences)")),
        (("--enable-quest",),
         dict(action="store_true",
              help="Enable Quest sparse attention (decode-only Top-K selection)")),
        (("--enable-xattn-bsa",),
         dict(action="store_true",
              help="Enable XAttention BSA sparse attention (prefill-only)")),
        (("--sparse-topk",),
         dict(type=int, default=8, help="Top-K blocks for Quest sparse attention")),
        (("--sparse-threshold",),
         dict(type=int, default=4,
              help="Apply sparse only when blocks > threshold (Quest) or attention threshold 0-9 (XAttention BSA)")),
        (("--sparse-samples",),
         dict(type=int, default=128,
              help="Samples per chunk for XAttention BSA estimation")),
    ]
    for flags, opts in arg_table:
        parser.add_argument(*flags, **opts)
    args = parser.parse_args()

    passed = run_needle_test(
        model_path=args.model,
        max_model_len=args.max_model_len,
        input_len=args.input_len,
        num_gpu_blocks=args.num_gpu_blocks,
        block_size=args.block_size,
        needle_position=args.needle_position,
        needle_value=args.needle_value,
        max_new_tokens=args.max_new_tokens,
        enable_cpu_offload=args.enable_offload,
        enable_quest=args.enable_quest,
        enable_xattn_bsa=args.enable_xattn_bsa,
        sparse_topk=args.sparse_topk,
        sparse_threshold=args.sparse_threshold,
        sparse_samples=args.sparse_samples,
        verbose=True,
    )

    if passed:
        print("test_needle: PASSED")
    else:
        print("test_needle: FAILED")
        exit(1)
|
||||
@@ -1,176 +0,0 @@
|
||||
"""
|
||||
Needle-in-a-haystack reference test using pure torch + transformers.
|
||||
|
||||
This is a reference implementation for comparison with nanovllm.
|
||||
Uses standard HuggingFace inference (no custom KV cache, no offload).
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from modeling_qwen3 import Qwen3ForCausalLM
|
||||
from utils import generate_needle_prompt, check_needle_answer
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Test
|
||||
# ============================================================
|
||||
|
||||
def run_needle_test(
    model_path: str,
    input_len: int,
    needle_position: float = 0.5,
    needle_value: str = "7492",
    max_new_tokens: int = 32,
    dtype: str = "auto",
    verbose: bool = True,
) -> bool:
    """
    Reference needle-in-haystack run using plain transformers-style inference.

    Args:
        model_path: Path to the model weights/tokenizer.
        input_len: Target input sequence length in tokens.
        needle_position: Relative needle position in the haystack (0.0-1.0).
        needle_value: Secret value the model must retrieve.
        max_new_tokens: Generation budget.
        dtype: Model dtype name ("auto", "float16", "bfloat16").
        verbose: Print configuration and result banners.

    Returns:
        True if the generated output contains the needle value.
    """
    if verbose:
        print(f"\n{'='*60}")
        print(f"Needle-in-Haystack Reference Test (torch + transformers)")
        print(f"{'='*60}")
        print(f"Model: {model_path}")
        print(f"Input length: {input_len}")
        print(f"Needle position: {needle_position:.0%}")
        print(f"Needle value: {needle_value}")
        print(f"Dtype: {dtype}")
        print(f"{'='*60}\n")

    # 1. Tokenizer.
    print("[1/4] Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # 2. Prompt with embedded needle.
    print("[2/4] Generating needle prompt...")
    prompt, expected = generate_needle_prompt(
        tokenizer=tokenizer,
        target_length=input_len,
        needle_position=needle_position,
        needle_value=needle_value,
    )

    # 3. Model. "auto" (and anything unrecognized) maps to float16, matching
    #    the custom model's default.
    print("[3/4] Loading model...")
    dtype_by_name = {"float16": torch.float16, "bfloat16": torch.bfloat16}
    torch_dtype = dtype_by_name.get(dtype, torch.float16)

    model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch_dtype)
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    # 4. Inference.
    print("[4/4] Running inference...")
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    print(f"  Input shape: {input_ids.shape}")

    with torch.no_grad():
        generated = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.6,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Keep only the newly generated suffix.
    new_token_ids = generated[0, input_ids.shape[1]:]
    output_text = tokenizer.decode(new_token_ids, skip_special_tokens=False)

    # 5. Grade the answer.
    passed = check_needle_answer(output_text, expected)

    if verbose:
        print(f"\n{'='*60}")
        print(f"Result")
        print(f"{'='*60}")
        print(f"Expected: {expected}")
        print(f"Output tokens ({len(new_token_ids)}): {new_token_ids[:20].tolist()}")
        print(f"Output: {output_text[:200]}...")
        print(f"Status: {'PASSED' if passed else 'FAILED'}")
        print(f"{'='*60}\n")

    return passed
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CLI Entry Point
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Needle-in-haystack reference test (torch + transformers)"
    )

    # (flags, options) pairs — a single table keeps defaults/help easy to scan.
    arg_table = [
        (("--model", "-m"),
         dict(type=str,
              default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
              help="Path to model")),
        (("--input-len",),
         dict(type=int, default=8 * 1024, help="Target input sequence length")),
        (("--needle-position",),
         dict(type=float, default=0.5,
              help="Needle position (0.0=start, 0.5=middle, 1.0=end)")),
        (("--needle-value",),
         dict(type=str, default="7492", help="The secret value to hide")),
        (("--max-new-tokens",),
         dict(type=int, default=32, help="Maximum tokens to generate")),
        (("--dtype",),
         dict(type=str, default="auto",
              choices=["auto", "float16", "bfloat16"],
              help="Model dtype")),
    ]
    for flags, opts in arg_table:
        parser.add_argument(*flags, **opts)
    args = parser.parse_args()

    passed = run_needle_test(
        model_path=args.model,
        input_len=args.input_len,
        needle_position=args.needle_position,
        needle_value=args.needle_value,
        max_new_tokens=args.max_new_tokens,
        dtype=args.dtype,
        verbose=True,
    )

    if passed:
        print("test_needle_ref: PASSED")
    else:
        print("test_needle_ref: FAILED")
        exit(1)
|
||||
@@ -1,136 +0,0 @@
|
||||
"""
|
||||
Test for QuestPolicy block selection with GQA (Grouped Query Attention).
|
||||
|
||||
Demonstrates the key limitation: scores are AVERAGED across heads,
|
||||
so blocks strongly needed by one head but not others may be dropped.
|
||||
|
||||
This is the expected Quest behavior - not a bug.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from nanovllm.kvcache.sparse import (
|
||||
create_sparse_policy,
|
||||
SparsePolicyType,
|
||||
PolicyContext,
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# Test: Per-Head Score Averaging in GQA
|
||||
# ============================================================
|
||||
|
||||
# Determine device (GPU if available, else CPU)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Running test on device: {device}")
|
||||
|
||||
# Setup: 2 KV heads, 4 query heads (GQA group_size=2)
|
||||
# topk=2 to make selection competitive
|
||||
|
||||
quest = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=2, threshold_blocks=0)
|
||||
quest.initialize(
|
||||
num_layers=1,
|
||||
num_kv_heads=2,
|
||||
head_dim=4,
|
||||
num_cpu_blocks=6,
|
||||
dtype=torch.float32,
|
||||
device=device, # Metadata stored on GPU
|
||||
)
|
||||
|
||||
metadata = quest.metadata
|
||||
|
||||
def set_key(block_id, head_id, values):
    """Write identical key_min/key_max for one (block, head) so scoring is deterministic."""
    # Build the tensor on the metadata's device before the in-place writes.
    vals = torch.tensor(values, device=device)
    for store in (metadata.key_min, metadata.key_max):
        store[block_id, 0, head_id, :] = vals
|
||||
|
||||
# ============================================================
# Design: Different heads want different blocks
# ============================================================
#
# Query = [1,1,1,1] for all heads, so each head's score = sum(key values).
#
# Block | Head 0 | Head 1 | Average | Result
# ------|--------|--------|---------|--------
# 0     |   +4   |   -4   |    0    | Head0 wants, Head1 doesn't → DROPPED
# 1     |   -4   |   +4   |    0    | Head1 wants, Head0 doesn't → DROPPED
# 2     |   +4   |   +4   |   +4    | Both want → SELECTED (rank 1)
# 3     |   +3   |   +3   |   +3    | Both want → SELECTED (rank 2)
# 4     |   +4   |    0   |   +2    | Head0 strongly wants, Head1 neutral → rank 3
# 5     |    0   |   +4   |   +2    | Head1 strongly wants, Head0 neutral → rank 3

# Per-block fill values (head0, head1); with head_dim=4 the per-head score is 4*fill.
block_fills = {
    0: (1.0, -1.0),
    1: (-1.0, 1.0),
    2: (1.0, 1.0),
    3: (0.75, 0.75),
    4: (1.0, 0.0),
    5: (0.0, 1.0),
}
for block_id, fills in block_fills.items():
    for head_id, fill in enumerate(fills):
        set_key(block_id, head_id, [fill] * 4)

# ============================================================
# Run selection
# ============================================================

# All-ones query on the metadata's device: 4 query heads mapping onto 2 KV heads.
query = torch.ones(1, 4, 4, device=device)  # GQA: 4 query heads → 2 KV heads

ctx = PolicyContext(
    query_chunk_idx=0,
    num_query_chunks=1,
    layer_id=0,
    query=query,
    is_prefill=False,
    block_size=1024,
    total_kv_len=6144,
)

selected = quest.select_blocks(list(range(6)), ctx)

# ============================================================
# Verify: Averaging behavior
# ============================================================

# topk=2 → only the two best head-averaged scores survive: blocks 2 (+4) and 3 (+3).
assert len(selected) == 2, f"Expected 2 blocks, got {len(selected)}"
assert selected == [2, 3], f"Expected [2, 3], got {selected}"

# Blocks 0/1 score +4 for one head, but -4 for the other cancels the average to 0.
for blk in (0, 1):
    assert blk not in selected, f"Block {blk} should NOT be selected (head scores cancel out)"

# Blocks 4/5 average +2, which loses to block 3's +3 for the last top-2 slot.
for blk in (4, 5):
    assert blk not in selected, f"Block {blk} avg=+2 < block 3 avg=+3"

print("✓ Block 2 selected: both heads want it (+4, +4) → avg=+4")
print("✓ Block 3 selected: both heads want it (+3, +3) → avg=+3")
print("✓ Block 0 NOT selected: head0=+4, head1=-4 → avg=0 (cancel out)")
print("✓ Block 1 NOT selected: head0=-4, head1=+4 → avg=0 (cancel out)")
print("✓ Block 4 NOT selected: head0=+4, head1=0 → avg=+2 (lower rank)")
print("✓ Block 5 NOT selected: head0=0, head1=+4 → avg=+2 (lower rank)")

# Sanity-check that the metadata tensors actually live on the requested device.
for name, tensor in (("key_min", metadata.key_min), ("key_max", metadata.key_max)):
    assert tensor.device.type == device.type, f"{name} on wrong device: {tensor.device}"
print(f"✓ Metadata stored on {device.type.upper()}")

print("\ntest_quest_policy: PASSED")
|
||||
@@ -1,199 +0,0 @@
|
||||
"""
|
||||
Sequential inference test for LLM.
|
||||
|
||||
Tests: After completing one prompt, the system can correctly handle
|
||||
a second prompt with a clean state (first prompt's KV cache deallocated).
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
|
||||
|
||||
import argparse
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from utils import generate_needle_prompt, check_needle_answer
|
||||
|
||||
|
||||
def _run_needle_case(llm, sampling_params, input_len, needle_value, test_num, verbose):
    """Run one needle-in-a-haystack round through an already-initialized LLM.

    Generates a prompt embedding *needle_value*, runs inference, and checks
    that the needle is retrieved. Returns True on success.
    """
    if verbose:
        print(f"\n[Test {test_num}] Generating prompt with needle value: {needle_value}")

    prompt, expected = generate_needle_prompt(
        tokenizer=llm.tokenizer,
        target_length=input_len,
        needle_position=0.5,
        needle_value=needle_value,
    )

    outputs = llm.generate([prompt], sampling_params, use_tqdm=True)
    output_text = outputs[0]["text"]
    passed = check_needle_answer(output_text, expected)

    if verbose:
        print(f"  Expected: {expected}")
        print(f"  Output: {output_text[:100]}...")
        print(f"  Status: {'PASSED' if passed else 'FAILED'}")
    return passed


def run_sequential_test(
    model_path: str,
    max_model_len: int,
    input_len: int,
    num_gpu_blocks: int = 4,
    block_size: int = 1024,
    enable_cpu_offload: bool = False,
    verbose: bool = True,
) -> bool:
    """
    Run sequential inference test with three different prompts.

    Each prompt hides a different needle value; all must be retrieved
    correctly, which demonstrates that a finished prompt's KV cache is
    deallocated and does not contaminate the next prompt.

    Args:
        model_path: Path to the model checkpoint directory.
        max_model_len: Maximum context length for the LLM.
        input_len: Target token length of each generated prompt.
        num_gpu_blocks: GPU block budget (only used with CPU offload).
        block_size: KV cache block size in tokens.
        enable_cpu_offload: Whether to offload KV cache to CPU.
        verbose: Print progress and a summary table.

    Returns:
        True iff every needle round passed.
    """
    if verbose:
        print(f"\n{'='*60}")
        print(f"Sequential Inference Test")
        print(f"{'='*60}")
        print(f"Model: {model_path}")
        print(f"Max model len: {max_model_len}")
        print(f"Input length: {input_len}")
        print(f"Block size: {block_size}")
        print(f"CPU offload: {enable_cpu_offload}")
        print(f"{'='*60}\n")

    # Initialize the LLM once; all rounds share it so leftover state would show up.
    llm_kwargs = {
        "enforce_eager": True,
        "max_model_len": max_model_len,
        "max_num_batched_tokens": max_model_len,
        "enable_cpu_offload": enable_cpu_offload,
        "kvcache_block_size": block_size,
    }
    if enable_cpu_offload:
        # num_gpu_blocks is only meaningful when offloading is on.
        llm_kwargs["num_gpu_blocks"] = num_gpu_blocks

    llm = LLM(model_path, **llm_kwargs)

    sampling_params = SamplingParams(
        temperature=0.6,
        max_tokens=32,
    )

    # Three rounds, each with a distinct needle value, run back-to-back.
    needle_values = ("1234", "5678", "9999")
    results = [
        _run_needle_case(llm, sampling_params, input_len, value, num, verbose)
        for num, value in enumerate(needle_values, start=1)
    ]
    all_passed = all(results)

    if verbose:
        print(f"\n{'='*60}")
        print(f"Summary")
        print(f"{'='*60}")
        for num, (value, passed) in enumerate(zip(needle_values, results), start=1):
            print(f"Test {num} (needle={value}): {'PASSED' if passed else 'FAILED'}")
        print(f"Overall: {'PASSED' if all_passed else 'FAILED'}")
        print(f"{'='*60}\n")

    return all_passed
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: parse options, run the sequential test, exit non-zero on failure.
    parser = argparse.ArgumentParser(description="Sequential inference test")
    parser.add_argument(
        "--model", "-m",
        type=str,
        default=os.path.expanduser("~/models/Qwen3-0.6B/"),
        help="Path to model",
    )
    parser.add_argument("--max-model-len", type=int, default=36 * 1024,
                        help="Maximum model context length")
    parser.add_argument("--input-len", type=int, default=8 * 1024,
                        help="Target input sequence length")
    parser.add_argument("--num-gpu-blocks", type=int, default=2,
                        help="Number of GPU blocks for CPU offload")
    parser.add_argument("--block-size", type=int, default=1024,
                        help="KV cache block size")
    parser.add_argument("--enable-offload", action="store_true",
                        help="Enable CPU offload")
    args = parser.parse_args()

    ok = run_sequential_test(
        model_path=args.model,
        max_model_len=args.max_model_len,
        input_len=args.input_len,
        num_gpu_blocks=args.num_gpu_blocks,
        block_size=args.block_size,
        enable_cpu_offload=args.enable_offload,
        verbose=True,
    )

    if ok:
        print("test_sequential: PASSED")
    else:
        print("test_sequential: FAILED")
        exit(1)
|
||||
@@ -1,334 +0,0 @@
|
||||
"""
|
||||
Test XAttention + BSA with RULER benchmark data.
|
||||
|
||||
Tests XAttention sparse attention correctness using RULER NIAH task.
|
||||
|
||||
Attention methods:
|
||||
- Prefill: XAttention + BSA (sparse) or FlashAttention (dense)
|
||||
- Decode: FlashAttention (always, since q_len=1)
|
||||
|
||||
Usage (in compass conda env with BSA available):
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
|
||||
python tests/test_xattn_bsa.py --model ~/models/Llama-3.1-8B-Instruct
|
||||
|
||||
# Test with XAttention + BSA for prefill (default)
|
||||
python tests/test_xattn_bsa.py --prefill-method xattn
|
||||
|
||||
# Test with FlashAttention for prefill (baseline)
|
||||
python tests/test_xattn_bsa.py --prefill-method flash
|
||||
|
||||
# Test specific sample(s)
|
||||
python tests/test_xattn_bsa.py --sample-id 0
|
||||
python tests/test_xattn_bsa.py --sample-ids 0,1,2
|
||||
|
||||
Note: Compatible with transformers 4.53+ (handles both old `past_key_value`
|
||||
and new `past_key_values` API).
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.cache_utils import DynamicCache
|
||||
|
||||
from nanovllm.ops.xattn import xattn_estimate
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
|
||||
# ============================================================
|
||||
# XAttention + BSA Functions
|
||||
# ============================================================
|
||||
|
||||
def expand_kv_for_gqa(key_states, value_states, num_heads):
    """Repeat KV heads so grouped-query attention sees one KV head per query head."""
    num_kv_heads = key_states.shape[1]
    if num_heads == num_kv_heads:
        # Plain MHA: nothing to expand.
        return key_states, value_states
    group_size = num_heads // num_kv_heads
    expanded_k = key_states.repeat_interleave(group_size, dim=1)
    expanded_v = value_states.repeat_interleave(group_size, dim=1)
    return expanded_k, expanded_v
|
||||
|
||||
|
||||
def flash_attention_forward(query_states, key_states, value_states, is_causal=True):
    """Dense attention via flash_attn; tensors arrive as [batch, heads, seq, dim]."""
    from flash_attn import flash_attn_func

    # flash_attn wants [batch, seq, heads, dim]: swap heads/seq going in,
    # and swap back on the result.
    out = flash_attn_func(
        query_states.transpose(1, 2),
        key_states.transpose(1, 2),
        value_states.transpose(1, 2),
        causal=is_causal,
    )
    return out.transpose(1, 2)
|
||||
|
||||
|
||||
def xattn_bsa_forward(query_states, key_states, value_states, threshold=0.9):
    """XAttention block-mask estimation followed by BSA sparse attention.

    Args:
        query_states: [batch, num_heads, q_len, head_dim] query tensor.
        key_states / value_states: [batch, num_heads, k_len, head_dim];
            already GQA-expanded by the caller (same head count as Q).
        threshold: XAttention cumulative-score threshold for keeping blocks.

    Returns:
        Attention output shaped [batch, num_heads, q_len, head_dim].

    Note:
        The flat reshape below assumes batch == 1 (single-sequence varlen
        layout for BSA) — TODO confirm with callers.
    """
    from block_sparse_attn import block_sparse_attn_func

    batch_size, num_heads, q_len, head_dim = query_states.shape
    k_len = key_states.shape[2]

    # Estimate a per-head boolean block mask (128-token blocks) from Q/K.
    _, mask = xattn_estimate(
        query_states, key_states,
        chunk_size=16384, block_size=128, threshold=threshold,
        use_triton=True, causal=True,
    )

    q_block_num = (q_len + 127) // 128
    k_block_num = (k_len + 127) // 128

    # BSA consumes varlen-style [total_tokens, heads, dim] inputs.
    q = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim)
    k = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
    v = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)

    # (Removed a leftover `__import__('pdb').set_trace()` debugger breakpoint
    # that halted every run here.)

    output = block_sparse_attn_func(
        q, k, v,
        torch.tensor([0, q_len], dtype=torch.int32, device=q.device),   # cu_seqlens_q
        torch.tensor([0, k_len], dtype=torch.int32, device=k.device),   # cu_seqlens_k
        torch.ones(num_heads, dtype=torch.int32, device=q.device),      # per-head sparsity flags
        None,
        mask[:, :, :q_block_num, :k_block_num].contiguous(),
        q_len, k_len,
        p_dropout=0.0, deterministic=True, is_causal=True,
    )
    return output.reshape(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
|
||||
|
||||
|
||||
DEBUG = False # Set to True to enable debugging
|
||||
|
||||
def create_patched_forward(prefill_method="xattn", threshold=0.9):
    """Create patched forward with configurable prefill method.

    Args:
        prefill_method: "xattn" for XAttention + BSA (sparse), "flash" for FlashAttention (dense)
        threshold: XAttention threshold for block selection (only used when prefill_method="xattn")

    Returns:
        A function suitable for binding as a LlamaAttention.forward replacement;
        it returns an (attn_output, None) tuple.

    Note:
        - Prefill (q_len > 1): Uses specified prefill_method
        - Decode (q_len = 1): Always uses FlashAttention (no sparse needed for single query)
    """
    # Mutable cell so the inner closure can count calls; each call to
    # create_patched_forward() produces an independent counter.
    call_count = [0]  # Mutable to track calls across layers

    def patched_forward(
        self,
        hidden_states,
        position_embeddings=None,
        attention_mask=None,  # accepted for signature compatibility; not used below
        past_key_value=None,  # Old API (transformers < 4.57)
        past_key_values=None,  # New API (transformers >= 4.57)
        cache_position=None,
        **kwargs
    ):
        # Handle both old and new transformers API
        kv_cache = past_key_values if past_key_values is not None else past_key_value

        bsz, q_len, _ = hidden_states.size()
        num_heads = self.config.num_attention_heads
        num_kv_heads = self.config.num_key_value_heads
        head_dim = self.head_dim

        # Compute Q, K, V projections ([bsz, heads, q_len, head_dim] after transpose)
        query_states = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)

        # Apply rotary position embedding
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Handle KV cache: update() appends the new K/V and returns the full history
        if kv_cache is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = kv_cache.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # Expand KV for GQA so K/V head count matches the query head count
        key_states_exp, value_states_exp = expand_kv_for_gqa(key_states, value_states, num_heads)

        # Debug output (gated on module-level DEBUG; only layer 0, first 5 calls)
        if DEBUG and self.layer_idx == 0:
            call_count[0] += 1
            if call_count[0] <= 5:
                phase = "prefill" if q_len > 1 else "decode"
                print(f"\n[DEBUG] Layer {self.layer_idx}, call {call_count[0]} ({phase}): q_len={q_len}, k_len={key_states_exp.shape[2]}")
                print(f"  kv_cache is None: {kv_cache is None}")

        # Choose attention method:
        # - Prefill (q_len > 1): Use prefill_method (xattn or flash)
        # - Decode (q_len = 1): Always use FlashAttention
        is_prefill = q_len > 1

        if is_prefill and prefill_method == "xattn":
            # Prefill with XAttention + BSA (sparse)
            attn_output = xattn_bsa_forward(query_states, key_states_exp, value_states_exp, threshold)
        else:
            # Prefill with FlashAttention (dense) OR Decode (always FlashAttention)
            # Note: For decode (q_len=1), causal=False since single query attends to all KV
            attn_output = flash_attention_forward(query_states, key_states_exp, value_states_exp, is_causal=is_prefill)

        attn_output = self.o_proj(attn_output.transpose(1, 2).reshape(bsz, q_len, -1))
        return attn_output, None

    return patched_forward
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Data & Evaluation
|
||||
# ============================================================
|
||||
|
||||
def load_samples(filepath, indices=None):
    """Read JSONL samples from *filepath*, optionally keeping only given line indices.

    Each returned sample gains an `_idx` field recording its original line number.
    """
    keep_all = indices is None
    samples = []
    with open(filepath) as handle:
        for line_no, raw in enumerate(handle):
            if not (keep_all or line_no in indices):
                continue
            record = json.loads(raw)
            record["_idx"] = line_no
            samples.append(record)
    return samples
|
||||
|
||||
|
||||
def string_match_all(output_text, expected_list):
    """RULER metric: fraction of expected strings found in the output (case-insensitive)."""
    if not expected_list:
        # Nothing to find → vacuously perfect.
        return 1.0
    haystack = output_text.lower().replace('\n', ' ')
    hits = sum(1.0 for expected in expected_list if expected.strip().lower() in haystack)
    return hits / len(expected_list)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Test
|
||||
# ============================================================
|
||||
|
||||
def test_with_ruler_data(model_path, data_file, sample_ids, prefill_method="xattn", threshold=0.9, max_new_tokens=50):
    """Test attention methods using RULER data.

    Loads the model with eager attention, monkey-patches every layer's
    attention forward, then greedily generates for each sample and scores
    the output with string_match_all.

    Args:
        model_path: HuggingFace model path/identifier.
        data_file: JSONL file of RULER samples ("input", "outputs" fields).
        sample_ids: Line indices of samples to run (empty/None → all lines).
        prefill_method: "xattn" for XAttention + BSA, "flash" for FlashAttention
        threshold: XAttention block-selection threshold.
        max_new_tokens: Generation budget per sample.

    Returns:
        True iff the average score over the run samples is >= 0.5.
    """
    prefill_desc = "XAttention + BSA (sparse)" if prefill_method == "xattn" else "FlashAttention (dense)"

    print("=" * 60)
    print("RULER NIAH Attention Test")
    print("=" * 60)
    print(f"Data: {data_file}")
    print(f"Samples: {sample_ids}")
    print(f"Prefill method: {prefill_desc}")
    print(f"Decode method: FlashAttention (always)")
    if prefill_method == "xattn":
        print(f"XAttention threshold: {threshold}")

    samples = load_samples(Path(data_file), set(sample_ids) if sample_ids else None)
    if not samples:
        print("No samples found!")
        return False
    print(f"Loaded {len(samples)} samples")

    # Load model (fp16 on CUDA; eager attention so the forward patch takes effect)
    print(f"\nLoading model: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map="cuda",
        attn_implementation="eager",  # Will be patched
    )
    model.eval()

    # Patch all layers: bind the patched forward as an instance method via __get__
    print(f"Patching attention layers...")
    print(f"  - Prefill: {prefill_desc}")
    print(f"  - Decode: FlashAttention")
    for idx, layer in enumerate(model.model.layers):
        layer.self_attn.layer_idx = idx  # Ensure layer_idx is set
        layer.self_attn.forward = create_patched_forward(prefill_method, threshold).__get__(
            layer.self_attn, type(layer.self_attn)
        )

    total_score = 0.0
    results = []

    for sample in samples:
        idx = sample["_idx"]
        prompt = sample["input"]
        expected = sample["outputs"]

        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        num_tokens = inputs["input_ids"].shape[1]
        print(f"\n--- Sample {idx} ({num_tokens} tokens) ---")
        print(f"Expected: {expected}")

        # Greedy decoding for reproducibility
        with torch.no_grad():
            output = model.generate(
                inputs["input_ids"],
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tail, not the prompt
        output_text = tokenizer.decode(output[0][num_tokens:], skip_special_tokens=True)
        score = string_match_all(output_text, expected)
        total_score += score

        status = "✓ PASS" if score >= 0.5 else "✗ FAIL"
        print(f"Output: '{output_text[:100]}...'")
        print(f"Result: {status} (score={score:.2f})")
        results.append({"idx": idx, "score": score, "passed": score >= 0.5})

    avg_score = total_score / len(samples)
    passed = sum(1 for r in results if r["passed"])

    print(f"\n{'='*60}")
    print(f"Results: {passed}/{len(samples)} passed, avg_score={avg_score:.3f}")
    print(f"{'='*60}")

    return avg_score >= 0.5
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse args, resolve options, and run the RULER test."""
    parser = argparse.ArgumentParser(
        description="Test XAttention + BSA vs FlashAttention for prefill using RULER NIAH benchmark"
    )
    parser.add_argument("--model", default="~/models/Llama-3.1-8B-Instruct")
    parser.add_argument("--data-file", default="tests/data/ruler_32k/niah_single_1/validation.jsonl")
    parser.add_argument("--sample-id", type=int, default=None, help="Test single sample by index")
    parser.add_argument("--sample-ids", type=str, default="", help="Test multiple samples (comma-separated)")
    parser.add_argument("--prefill-method", choices=["xattn", "flash"], default="xattn",
                        help="Prefill attention method: xattn (XAttention+BSA sparse) or flash (FlashAttention dense)")
    parser.add_argument("--threshold", type=float, default=0.9, help="XAttention threshold (only for --prefill-method xattn)")
    parser.add_argument("--max-new-tokens", type=int, default=50)
    # Keep old option for backwards compatibility
    parser.add_argument("--no-xattn", action="store_true", help="[Deprecated] Use --prefill-method flash instead")
    args = parser.parse_args()

    # Expand "~" to the *current* user's home directory instead of the previous
    # hard-coded "/home/zijie" (which also wrongly replaced every "~" in the path).
    model_path = str(Path(args.model).expanduser())

    # Handle deprecated --no-xattn option
    prefill_method = args.prefill_method
    if args.no_xattn:
        prefill_method = "flash"
        print("Warning: --no-xattn is deprecated, use --prefill-method flash instead")

    if args.sample_id is not None:
        sample_ids = [args.sample_id]
    elif args.sample_ids:
        sample_ids = [int(x) for x in args.sample_ids.split(",")]
    else:
        sample_ids = [0]  # Default: only the first sample

    # Fail fast if the sparse path was requested but BSA is not installed.
    if prefill_method == "xattn":
        try:
            from block_sparse_attn import block_sparse_attn_func
            print("✓ BSA (Block Sparse Attention) available")
        except ImportError:
            print("✗ BSA not found. Install block_sparse_attn or use --prefill-method flash")
            sys.exit(1)

    if test_with_ruler_data(model_path, args.data_file, sample_ids, prefill_method, args.threshold, args.max_new_tokens):
        print("\ntest_xattn_bsa: PASSED")
    else:
        print("\ntest_xattn_bsa: FAILED")
        sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,259 +0,0 @@
|
||||
"""
|
||||
Test: Compare xattn_estimate vs xattn_estimate_chunked
|
||||
Verify that chunked estimation with EXTERNAL chunking produces the same mask as standard estimation.
|
||||
|
||||
Uses real QKV data captured from model inference.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
|
||||
|
||||
# ============================================================
|
||||
# Configuration
|
||||
# ============================================================
|
||||
|
||||
BLOCK_SIZE = 64
|
||||
STRIDE = 4
|
||||
THRESHOLD = 0.9
|
||||
CHUNK_SIZE = 4096
|
||||
|
||||
# Default QKV data directory (relative to project root)
|
||||
DEFAULT_QKV_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "results", "kvcache")
|
||||
|
||||
# ============================================================
|
||||
# Utility Functions
|
||||
# ============================================================
|
||||
|
||||
def load_qkv(path):
    """Load a saved QKV snapshot from disk and echo its basic stats."""
    # weights_only=False: the snapshot is a plain dict with scalars, not just tensors.
    payload = torch.load(path, map_location="cpu", weights_only=False)
    print(f"Loaded: {path}")
    print(f"  Query shape: {payload['query'].shape}")
    print(f"  Key shape: {payload['key'].shape}")
    print(f"  Layer: {payload['layer_id']}, Density: {payload['density']:.2%}")
    return payload
|
||||
|
||||
|
||||
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
    """Report agreement between two boolean masks; True iff they match exactly."""
    if mask1.shape != mask2.shape:
        print(f"Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
        return False

    mismatch = mask1 != mask2
    diff = mismatch.sum().item()
    total = mask1.numel()
    print(f"  Match rate: {(total - diff) / total * 100:.4f}% ({total - diff}/{total})")

    if diff > 0:
        # Show a handful of disagreeing coordinates to aid debugging.
        where = torch.where(mismatch)
        print(f"  First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in where]))}")

    return diff == 0
|
||||
|
||||
|
||||
def run_chunked_externally(query, key, q_start_pos, block_size, stride, threshold, chunk_size):
    """
    Run xattn_estimate_chunked with EXTERNAL chunking.
    This simulates how chunked prefill should be used in practice.

    Q is split into chunk_size-token chunks; each chunk sees the causal K
    prefix up to its own end, and the per-chunk block masks are stitched
    back into one [batch, heads, q_blocks, k_blocks] pair of outputs.

    Returns:
        (combined_attn_sum, combined_mask) covering the full Q/K extent.
    """
    batch_size, num_heads, q_len, head_dim = query.shape  # head_dim unused; kept for shape clarity
    _, _, k_len, _ = key.shape

    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size

    # If Q fits in one chunk, call directly (no stitching needed)
    if q_len <= chunk_size:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return xattn_estimate_chunked(
                query, key,
                q_start_pos=q_start_pos,
                block_size=block_size,
                stride=stride,
                threshold=threshold,
                use_triton=True,
                chunk_size=chunk_size,
            )

    # External chunking: split Q and call for each chunk
    num_q_chunks = (q_len + chunk_size - 1) // chunk_size
    print(f"  External chunking: {num_q_chunks} chunks")

    # Pre-allocate the full-size outputs; each chunk fills its own row band.
    combined_attn_sum = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=query.dtype, device=query.device
    )
    combined_mask = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=torch.bool, device=query.device
    )

    q_block_offset = 0
    for q_chunk_idx in range(num_q_chunks):
        q_chunk_start = q_chunk_idx * chunk_size
        q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)

        q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]

        # For causal attention, K accumulates up to current Q position
        k_end = q_start_pos + q_chunk_end
        k_chunk = key[:, :, :k_end, :]

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
                q_chunk, k_chunk,
                q_start_pos=q_start_pos + q_chunk_start,
                block_size=block_size,
                stride=stride,
                threshold=threshold,
                use_triton=True,
                chunk_size=chunk_size,
            )

        # Place chunk results into combined output (rows advance by the
        # chunk's q-block count; columns cover the causal K prefix only)
        chunk_q_blocks = mask_chunk.shape[2]
        chunk_k_blocks = mask_chunk.shape[3]
        combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
        combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
        q_block_offset += chunk_q_blocks

    return combined_attn_sum, combined_mask
|
||||
|
||||
|
||||
def test_single_qkv(qkv_path):
    """Test a single QKV file.

    Runs the standard and the externally-chunked estimators on the same
    captured Q/K tensors and verifies that their block masks agree.
    Returns True iff the masks match exactly (attn-sum diff is only reported).
    """
    data = load_qkv(qkv_path)
    # Move to GPU and cast; the kernels under test run in bf16 on CUDA.
    query = data["query"].cuda().to(torch.bfloat16)
    key = data["key"].cuda().to(torch.bfloat16)

    seq_len = query.shape[2]
    print(f"\nTesting with seq_len={seq_len}")
    print("=" * 60)

    # Run standard xattn_estimate (reference path)
    print("[1] Running standard xattn_estimate...")
    try:
        attn_sum_std, mask_std = xattn_estimate(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
            use_triton=True,
        )
        print(f"  mask shape: {mask_std.shape}, density: {mask_std.float().mean().item():.4f}")
    except Exception as e:
        print(f"  ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False

    # Run chunked xattn_estimate with EXTERNAL chunking (path under test)
    print("[2] Running chunked xattn_estimate (external chunking)...")
    try:
        attn_sum_chunked, mask_chunked = run_chunked_externally(
            query, key,
            q_start_pos=0,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
        )
        print(f"  mask shape: {mask_chunked.shape}, density: {mask_chunked.float().mean().item():.4f}")
    except Exception as e:
        print(f"  ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False

    # Compare results over the region both outputs cover
    print("[3] Comparing results...")
    chunked_q_blocks = mask_chunked.shape[2]
    chunked_k_blocks = mask_chunked.shape[3]

    # Extract comparable region from standard mask
    mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]

    # Compare masks (exact boolean agreement)
    masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")

    # Compare attn_sums (informational only; does not affect pass/fail)
    attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
    if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
        attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
        print(f"  Attn sum max diff: {attn_diff:.6f}")
    else:
        print(f"  Attn sum shape mismatch")

    # Clean up GPU memory before the next (possibly larger) QKV file
    del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
    torch.cuda.empty_cache()

    return masks_match
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Test
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Test xattn_estimate vs xattn_estimate_chunked")
|
||||
parser.add_argument("--qkv-dir", type=str, default=DEFAULT_QKV_DIR,
|
||||
help="Directory containing QKV files")
|
||||
args = parser.parse_args()
|
||||
|
||||
# QKV files to test
|
||||
qkv_files = [
|
||||
os.path.join(args.qkv_dir, "qkv_3688.pt"), # ~4K
|
||||
os.path.join(args.qkv_dir, "qkv_7888.pt"), # ~8K
|
||||
os.path.join(args.qkv_dir, "qkv_15685.pt"), # ~16K
|
||||
os.path.join(args.qkv_dir, "qkv_32485.pt"), # ~32K
|
||||
os.path.join(args.qkv_dir, "qkv_64891.pt"), # ~64K
|
||||
]
|
||||
|
||||
available_files = [p for p in qkv_files if os.path.exists(p)]
|
||||
|
||||
if not available_files:
|
||||
print(f"No QKV file found in {args.qkv_dir}.")
|
||||
print(f"Expected files: qkv_3688.pt, qkv_7888.pt, qkv_15685.pt, qkv_32485.pt, qkv_64891.pt")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Found {len(available_files)} QKV files to test")
|
||||
print(f"Testing EXTERNAL chunking (chunk_size={CHUNK_SIZE})")
|
||||
print(f"Using Triton kernels")
|
||||
|
||||
all_passed = True
|
||||
results = []
|
||||
|
||||
for qkv_path in available_files:
|
||||
passed = test_single_qkv(qkv_path)
|
||||
seq_len = int(os.path.basename(qkv_path).replace("qkv_", "").replace(".pt", ""))
|
||||
results.append((seq_len, passed))
|
||||
if not passed:
|
||||
all_passed = False
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("SUMMARY")
|
||||
print("=" * 60)
|
||||
for seq_len, passed in results:
|
||||
status = "PASSED" if passed else "FAILED"
|
||||
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
|
||||
print(f" seq_len={seq_len} ({chunks} chunk{'s' if chunks > 1 else ''}): {status}")
|
||||
|
||||
print("=" * 60)
|
||||
if all_passed:
|
||||
print("test_xattn_chunked: PASSED")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("test_xattn_chunked: FAILED")
|
||||
sys.exit(1)
|
||||
@@ -1,244 +0,0 @@
|
||||
"""
|
||||
Test: Compare xattn_estimate vs xattn_estimate_chunked
|
||||
|
||||
Verify that chunked estimation with EXTERNAL chunking produces the same mask
|
||||
as standard estimation. This ensures the chunked version can be used in
|
||||
chunked prefill scenarios without accuracy loss.
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||
python tests/test_xattn_estimate_chunked.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
import torch
|
||||
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
|
||||
|
||||
# ============================================================
|
||||
# Configuration
|
||||
# ============================================================
|
||||
|
||||
# Configuration for xattn_estimate_chunked consistency test.
|
||||
# Key requirements for 100% match:
|
||||
# 1. Use matching chunk_size for both standard and chunked versions
|
||||
# 2. Use same random seed for reproducibility
|
||||
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
|
||||
# floating point precision in cumulative sum calculations.
|
||||
BLOCK_SIZE = 64
|
||||
STRIDE = 4
|
||||
THRESHOLD = 0.9
|
||||
CHUNK_SIZE = 4096 # External chunking size
|
||||
|
||||
# Test sequence lengths
|
||||
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]
|
||||
|
||||
# ============================================================
|
||||
# Utility Functions
|
||||
# ============================================================
|
||||
|
||||
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
|
||||
"""Compare two masks and report differences."""
|
||||
if mask1.shape != mask2.shape:
|
||||
print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
|
||||
return False
|
||||
|
||||
diff = (mask1 != mask2).sum().item()
|
||||
total = mask1.numel()
|
||||
match_rate = (total - diff) / total * 100
|
||||
|
||||
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
|
||||
|
||||
if diff > 0:
|
||||
diff_indices = torch.where(mask1 != mask2)
|
||||
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
|
||||
|
||||
return diff == 0
|
||||
|
||||
|
||||
def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
|
||||
"""
|
||||
Run xattn_estimate_chunked with EXTERNAL chunking.
|
||||
This simulates how chunked prefill should be used in practice.
|
||||
"""
|
||||
batch_size, num_heads, q_len, head_dim = query.shape
|
||||
_, _, k_len, _ = key.shape
|
||||
|
||||
q_block_num = (q_len + block_size - 1) // block_size
|
||||
k_block_num = (k_len + block_size - 1) // block_size
|
||||
|
||||
# If Q fits in one chunk, call directly
|
||||
if q_len <= chunk_size:
|
||||
return xattn_estimate_chunked(
|
||||
query, key,
|
||||
q_start_pos=0,
|
||||
block_size=block_size,
|
||||
stride=stride,
|
||||
threshold=threshold,
|
||||
use_triton=True,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
# External chunking: split Q and call for each chunk
|
||||
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
|
||||
print(f" External chunking: {num_q_chunks} chunks")
|
||||
|
||||
combined_attn_sum = torch.zeros(
|
||||
batch_size, num_heads, q_block_num, k_block_num,
|
||||
dtype=query.dtype, device=query.device
|
||||
)
|
||||
combined_mask = torch.zeros(
|
||||
batch_size, num_heads, q_block_num, k_block_num,
|
||||
dtype=torch.bool, device=query.device
|
||||
)
|
||||
|
||||
q_block_offset = 0
|
||||
for q_chunk_idx in range(num_q_chunks):
|
||||
q_chunk_start = q_chunk_idx * chunk_size
|
||||
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
|
||||
|
||||
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
|
||||
|
||||
# For causal attention, K accumulates up to current Q position
|
||||
# q_start_pos=0 means Q starts at position 0 in the full sequence
|
||||
# K is [0, q_chunk_end) for causal attention
|
||||
k_end = q_chunk_end
|
||||
k_chunk = key[:, :, :k_end, :]
|
||||
|
||||
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
|
||||
q_chunk, k_chunk,
|
||||
q_start_pos=q_chunk_start,
|
||||
block_size=block_size,
|
||||
stride=stride,
|
||||
threshold=threshold,
|
||||
use_triton=True,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
# Place chunk results into combined output
|
||||
chunk_q_blocks = mask_chunk.shape[2]
|
||||
chunk_k_blocks = mask_chunk.shape[3]
|
||||
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
|
||||
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
|
||||
q_block_offset += chunk_q_blocks
|
||||
|
||||
return combined_attn_sum, combined_mask
|
||||
|
||||
|
||||
def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
|
||||
"""Test a single sequence length."""
|
||||
print(f"\nTesting seq_len={seq_len}")
|
||||
print("=" * 60)
|
||||
|
||||
# Generate random Q/K
|
||||
query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
|
||||
key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Run standard xattn_estimate
|
||||
print("[1] Running standard xattn_estimate...")
|
||||
try:
|
||||
attn_sum_std, mask_std = xattn_estimate(
|
||||
query, key,
|
||||
block_size=BLOCK_SIZE,
|
||||
stride=STRIDE,
|
||||
threshold=THRESHOLD,
|
||||
chunk_size=CHUNK_SIZE,
|
||||
use_triton=True,
|
||||
causal=True,
|
||||
)
|
||||
density_std = mask_std.float().mean().item()
|
||||
print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}")
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# Run chunked xattn_estimate with EXTERNAL chunking
|
||||
print("[2] Running chunked xattn_estimate (external chunking)...")
|
||||
try:
|
||||
attn_sum_chunked, mask_chunked = run_chunked_externally(
|
||||
query, key,
|
||||
block_size=BLOCK_SIZE,
|
||||
stride=STRIDE,
|
||||
threshold=THRESHOLD,
|
||||
chunk_size=CHUNK_SIZE,
|
||||
)
|
||||
density_chunked = mask_chunked.float().mean().item()
|
||||
print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# Compare results
|
||||
print("[3] Comparing results...")
|
||||
chunked_q_blocks = mask_chunked.shape[2]
|
||||
chunked_k_blocks = mask_chunked.shape[3]
|
||||
|
||||
# Extract comparable region from standard mask
|
||||
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
|
||||
|
||||
# Compare masks
|
||||
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
|
||||
|
||||
# Compare attn_sums
|
||||
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
|
||||
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
|
||||
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
|
||||
print(f" Attn sum max diff: {attn_diff:.6f}")
|
||||
else:
|
||||
print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")
|
||||
|
||||
# Clean up GPU memory
|
||||
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return masks_match
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Test
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("XAttention Chunked vs Standard Test")
|
||||
print("=" * 60)
|
||||
print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
|
||||
print(f"External chunk_size={CHUNK_SIZE}")
|
||||
print()
|
||||
|
||||
# Check CUDA availability
|
||||
if not torch.cuda.is_available():
|
||||
print("CUDA not available!")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
||||
print("✓ xattn_estimate imported")
|
||||
print("✓ xattn_estimate_chunked imported")
|
||||
|
||||
# Run tests
|
||||
all_passed = True
|
||||
results = []
|
||||
|
||||
for seq_len in TEST_SEQ_LENS:
|
||||
passed = test_single_seq_len(seq_len)
|
||||
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
|
||||
results.append((seq_len, chunks, passed))
|
||||
if not passed:
|
||||
all_passed = False
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("SUMMARY")
|
||||
print("=" * 60)
|
||||
for seq_len, chunks, passed in results:
|
||||
status = "PASSED" if passed else "FAILED"
|
||||
print(f" seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")
|
||||
|
||||
print("=" * 60)
|
||||
if all_passed:
|
||||
print("ALL TESTS PASSED!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("SOME TESTS FAILED!")
|
||||
sys.exit(1)
|
||||
@@ -1,132 +0,0 @@
|
||||
"""
|
||||
Test: XAttention Triton kernels
|
||||
|
||||
演示 XAttention 的两个核心 Triton kernel:
|
||||
1. flat_group_gemm_fuse_reshape: 计算 stride reshape 后的 attention scores (反对角线求和)
|
||||
2. softmax_fuse_block_sum: 对 attention scores 做 softmax 后按 block 求和
|
||||
|
||||
数据流:
|
||||
Q [batch, heads, q_len, head_dim]
|
||||
K [batch, heads, kv_len, head_dim]
|
||||
↓ flat_group_gemm_fuse_reshape
|
||||
attn_scores [batch, heads, q_len/stride, kv_len/stride]
|
||||
↓ softmax_fuse_block_sum
|
||||
block_sums [batch, heads, q_blocks, k_blocks]
|
||||
"""
|
||||
import torch
|
||||
import sys
|
||||
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
|
||||
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
|
||||
|
||||
# ============================================================
|
||||
# 参数配置
|
||||
# ============================================================
|
||||
|
||||
# Triton 约束: q_len >= stride * BLOCK_M, kv_len >= stride * BLOCK_N
|
||||
# A100: BLOCK_M = BLOCK_N = 128, 所以 min = 4 * 128 = 512
|
||||
# RTX 3090: BLOCK_M = BLOCK_N = 64, 所以 min = 4 * 64 = 256
|
||||
q_len = 512
|
||||
kv_len = 2048
|
||||
head_dim = 128
|
||||
stride = 4
|
||||
block_size = 128 # softmax block size (in reshaped space)
|
||||
segment_size = 128 # Triton kernel 要求 segment_size >= block_size
|
||||
|
||||
# ============================================================
|
||||
# 构造输入: 偶数位置=1, 奇数位置=2
|
||||
# ============================================================
|
||||
|
||||
Q = torch.zeros(1, 1, q_len, head_dim, dtype=torch.bfloat16).cuda()
|
||||
K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda()
|
||||
|
||||
for i in range(q_len):
|
||||
if i % 2 == 0:
|
||||
Q[0, 0, i, :] = 1 * (i // stride + 1)
|
||||
else:
|
||||
Q[0, 0, i, :] = 2 * (i // stride + 1)
|
||||
|
||||
for i in range(kv_len):
|
||||
if i % 2 == 0:
|
||||
K[0, 0, i, :] = 1
|
||||
else:
|
||||
K[0, 0, i, :] = 2
|
||||
|
||||
# ============================================================
|
||||
# Step 1: flat_group_gemm_fuse_reshape (chunked along K)
|
||||
# ============================================================
|
||||
|
||||
q_reshaped_len = q_len // stride # 128
|
||||
kv_reshaped_len = kv_len // stride # 512
|
||||
|
||||
# 将 K 沿着长度维度分成多个 chunk
|
||||
k_chunk_size = 512 # 每个 chunk 512 tokens
|
||||
num_k_chunks = kv_len // k_chunk_size # 4 chunks
|
||||
|
||||
attn_scores_list = []
|
||||
for k_chunk_idx in range(num_k_chunks):
|
||||
k_start = k_chunk_idx * k_chunk_size
|
||||
k_end = k_start + k_chunk_size
|
||||
K_chunk = K[:, :, k_start:k_end, :] # [1, 1, k_chunk_size, head_dim]
|
||||
|
||||
# 对每个 K chunk 调用 flat_group_gemm_fuse_reshape
|
||||
# 输出: [batch, heads, q_len/stride, k_chunk_size/stride]
|
||||
attn_chunk = flat_group_gemm_fuse_reshape(
|
||||
Q, K_chunk, stride,
|
||||
chunk_start=0,
|
||||
chunk_end=q_reshaped_len,
|
||||
is_causal=True
|
||||
)
|
||||
|
||||
__import__('pdb').set_trace()
|
||||
|
||||
attn_scores_list.append(attn_chunk)
|
||||
|
||||
# 拼接所有 K chunks 的结果
|
||||
# 每个 chunk: [1, 1, q_reshaped_len, k_chunk_size/stride]
|
||||
# 拼接后: [1, 1, q_reshaped_len, kv_reshaped_len]
|
||||
attn_scores = torch.cat(attn_scores_list, dim=-1)
|
||||
|
||||
# 验证 shape: [batch, heads, q_len/stride, kv_len/stride]
|
||||
assert attn_scores.shape == (1, 1, q_reshaped_len, kv_reshaped_len), \
|
||||
f"shape mismatch: {attn_scores.shape} != (1, 1, {q_reshaped_len}, {kv_reshaped_len})"
|
||||
|
||||
# 验证: 反对角线求和
|
||||
# 每个 stride x stride 块的反对角线: Q[奇]*K[偶] + Q[偶]*K[奇] = 2*1 + 1*2 = 4
|
||||
# 反对角线有 stride/2 对,再乘以 head_dim
|
||||
expected_gemm = (2*1 + 1*2) * (stride // 2) * head_dim
|
||||
actual_gemm = attn_scores[0, 0, 0, 0].item()
|
||||
assert actual_gemm == expected_gemm, f"flat_group_gemm: {actual_gemm} != {expected_gemm}"
|
||||
|
||||
# ============================================================
|
||||
# Step 2: softmax_fuse_block_sum
|
||||
# ============================================================
|
||||
|
||||
scale = 1.4426950408889634 # log2(e) for exp2
|
||||
|
||||
block_sums = softmax_fuse_block_sum(
|
||||
attn_scores,
|
||||
block_size,
|
||||
segment_size,
|
||||
chunk_start=0,
|
||||
chunk_end=q_reshaped_len,
|
||||
real_q_len=q_reshaped_len,
|
||||
scale=scale,
|
||||
is_causal=False
|
||||
)
|
||||
|
||||
# 验证 shape: [batch, heads, q_blocks, k_blocks]
|
||||
q_blocks = q_reshaped_len // block_size # 128 / 128 = 1
|
||||
k_blocks = kv_reshaped_len // block_size # 512 / 128 = 4
|
||||
assert block_sums.shape == (1, 1, q_blocks, k_blocks), \
|
||||
f"shape mismatch: {block_sums.shape} != (1, 1, {q_blocks}, {k_blocks})"
|
||||
|
||||
# 验证: 每个 block 的 softmax 结果求和
|
||||
# 所有 attn_scores 相同 → softmax 均匀分布
|
||||
# 每行对一个 K block 的贡献 = block_size / kv_reshaped_len
|
||||
# 每个 Q block 有 block_size 行
|
||||
# block_sum = block_size * (block_size / kv_reshaped_len)
|
||||
expected_sum = block_size * block_size / kv_reshaped_len
|
||||
actual_sum = block_sums[0, 0, 0, 0].item()
|
||||
assert actual_sum == expected_sum, f"softmax_fuse_block_sum: {actual_sum} != {expected_sum}"
|
||||
|
||||
print("test_xattn_kernels: PASSED")
|
||||
@@ -1,246 +0,0 @@
|
||||
"""
|
||||
Test: 批量验证 xattn_estimate 与 KV chunking kernels 的一致性
|
||||
|
||||
测试 results/kvcache 下所有保存的 QKV 数据
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||
python tests/test_xattn_kv_chunking_batch.py
|
||||
"""
|
||||
import sys
|
||||
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
|
||||
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
import math
|
||||
from nanovllm.ops.xattn import (
|
||||
xattn_estimate,
|
||||
flat_group_gemm_fuse_reshape,
|
||||
softmax_compute_partial_stats,
|
||||
softmax_normalize_and_block_sum,
|
||||
merge_softmax_stats,
|
||||
find_blocks_chunked,
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# 参数配置
|
||||
# ============================================================
|
||||
DATA_DIR = "/home/zijie/Code/nano-vllm/results/kvcache"
|
||||
BSA_BLOCK_SIZE = 128
|
||||
CHUNK_SIZE = 16384
|
||||
|
||||
device = "cuda"
|
||||
|
||||
|
||||
def test_single_file(data_file: str) -> dict:
|
||||
"""测试单个 kvcache 文件"""
|
||||
data = torch.load(data_file, map_location="cpu")
|
||||
Q = data["query"].to(device)
|
||||
K = data["key"].to(device)
|
||||
|
||||
batch_size, num_heads, seq_len, head_dim = Q.shape
|
||||
STRIDE = data["stride"]
|
||||
THRESHOLD = data["threshold"][0].item() if isinstance(data["threshold"], torch.Tensor) else data["threshold"]
|
||||
|
||||
# ========== xattn_estimate API ==========
|
||||
attn_sums_api, mask_api = xattn_estimate(
|
||||
Q, K,
|
||||
block_size=BSA_BLOCK_SIZE,
|
||||
stride=STRIDE,
|
||||
threshold=THRESHOLD,
|
||||
chunk_size=CHUNK_SIZE,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
|
||||
k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
|
||||
mask_api_valid = mask_api[:, :, :q_blocks, :k_blocks]
|
||||
|
||||
causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, device=device, dtype=torch.bool))
|
||||
total_api = causal_mask.sum().item() * batch_size * num_heads
|
||||
selected_api = (mask_api_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||
density_api = selected_api / total_api
|
||||
|
||||
# ========== 三阶段 KV Chunking ==========
|
||||
k_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
|
||||
q_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
|
||||
q_chunk_num = (seq_len + q_num_to_pad) // CHUNK_SIZE
|
||||
kv_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE
|
||||
|
||||
k_block_num = (seq_len + k_num_to_pad) // BSA_BLOCK_SIZE
|
||||
q_block_num = (seq_len + q_num_to_pad) // BSA_BLOCK_SIZE
|
||||
|
||||
reshaped_chunk_size = CHUNK_SIZE // STRIDE
|
||||
reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
|
||||
k_reshaped_seq_len = (seq_len + k_num_to_pad) // STRIDE
|
||||
k_reshaped_num_to_pad = k_num_to_pad // STRIDE
|
||||
num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size
|
||||
kv_reshaped_chunk_size = CHUNK_SIZE // STRIDE
|
||||
|
||||
if k_num_to_pad > 0:
|
||||
K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0)
|
||||
else:
|
||||
K_padded = K
|
||||
|
||||
if q_num_to_pad > 0:
|
||||
Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, q_num_to_pad), value=0)
|
||||
else:
|
||||
Q_padded = Q
|
||||
|
||||
norm = 1.0
|
||||
scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm
|
||||
|
||||
simple_mask_list = []
|
||||
|
||||
for q_chunk_idx in range(q_chunk_num):
|
||||
q_start = q_chunk_idx * reshaped_chunk_size * STRIDE
|
||||
q_end = q_start + reshaped_chunk_size * STRIDE
|
||||
Q_chunk = Q_padded[:, :, q_start:q_end, :]
|
||||
|
||||
chunk_start = (k_block_num - q_block_num) * reshaped_block_size + q_chunk_idx * reshaped_chunk_size
|
||||
chunk_end = chunk_start + reshaped_chunk_size
|
||||
|
||||
m_chunks = []
|
||||
l_chunks = []
|
||||
attn_weights_chunks = []
|
||||
|
||||
for kv_chunk_idx in range(kv_chunk_num):
|
||||
kv_start = kv_chunk_idx * CHUNK_SIZE
|
||||
kv_end = kv_start + CHUNK_SIZE
|
||||
K_chunk = K_padded[:, :, kv_start:kv_end, :]
|
||||
kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
|
||||
|
||||
attn_weights_kv = flat_group_gemm_fuse_reshape(
|
||||
Q_chunk, K_chunk, STRIDE,
|
||||
chunk_start=chunk_start,
|
||||
chunk_end=chunk_end,
|
||||
is_causal=False,
|
||||
)
|
||||
attn_weights_chunks.append(attn_weights_kv)
|
||||
|
||||
m_partial, l_partial = softmax_compute_partial_stats(
|
||||
attn_weights_kv,
|
||||
reshaped_block_size,
|
||||
min(4096, reshaped_block_size),
|
||||
scale,
|
||||
chunk_start=chunk_start,
|
||||
kv_offset=kv_offset_reshaped,
|
||||
is_causal=True,
|
||||
)
|
||||
m_chunks.append(m_partial)
|
||||
l_chunks.append(l_partial)
|
||||
|
||||
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
|
||||
|
||||
attn_sum_per_kv = []
|
||||
for kv_chunk_idx, attn_weights_kv in enumerate(attn_weights_chunks):
|
||||
kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
|
||||
attn_sum_kv = softmax_normalize_and_block_sum(
|
||||
attn_weights_kv,
|
||||
m_global,
|
||||
l_global,
|
||||
reshaped_block_size,
|
||||
min(4096, reshaped_block_size),
|
||||
chunk_start=chunk_start,
|
||||
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
|
||||
scale=scale,
|
||||
kv_offset=kv_offset_reshaped,
|
||||
is_causal=True,
|
||||
)
|
||||
attn_sum_per_kv.append(attn_sum_kv)
|
||||
|
||||
attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
|
||||
|
||||
simple_mask = find_blocks_chunked(
|
||||
attn_sum_concat,
|
||||
current_index=k_block_num - q_block_num + q_chunk_idx * num_blocks_per_chunk,
|
||||
threshold=THRESHOLD,
|
||||
num_to_choose=None,
|
||||
decoding=False,
|
||||
mode="prefill",
|
||||
causal=True,
|
||||
)
|
||||
simple_mask_list.append(simple_mask)
|
||||
|
||||
mask_kv_chunking = torch.cat(simple_mask_list, dim=2)
|
||||
|
||||
# 应用与 xattn_estimate 相同的 causal mask 后处理 (xattn.py 第 1300-1306 行)
|
||||
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:] = torch.where(
|
||||
torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=device), diagonal=0),
|
||||
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:],
|
||||
False,
|
||||
)
|
||||
|
||||
mask_kv_chunking_valid = mask_kv_chunking[:, :, :q_blocks, :k_blocks]
|
||||
selected_kv = (mask_kv_chunking_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||
density_kv = selected_kv / total_api
|
||||
|
||||
mask_total = mask_api_valid.numel()
|
||||
mask_diff = (mask_api_valid != mask_kv_chunking_valid).sum().item()
|
||||
mask_diff_pct = 100 * mask_diff / mask_total
|
||||
|
||||
return {
|
||||
"seq_len": seq_len,
|
||||
"stride": STRIDE,
|
||||
"threshold": THRESHOLD,
|
||||
"kv_chunks": kv_chunk_num,
|
||||
"density_api": density_api,
|
||||
"density_kv": density_kv,
|
||||
"density_diff": abs(density_api - density_kv),
|
||||
"mask_diff_pct": mask_diff_pct,
|
||||
"passed": abs(density_api - density_kv) < 1e-6 and mask_diff_pct < 0.01,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
files = sorted(glob.glob(os.path.join(DATA_DIR, "qkv_*.pt")))
|
||||
|
||||
print("=" * 80)
|
||||
print("XAttention KV Chunking Alignment Test")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
results = []
|
||||
for f in files:
|
||||
fname = os.path.basename(f)
|
||||
print(f"Testing {fname}...", end=" ", flush=True)
|
||||
try:
|
||||
r = test_single_file(f)
|
||||
results.append(r)
|
||||
status = "✓ PASS" if r["passed"] else "✗ FAIL"
|
||||
print(f"{status} (seq_len={r['seq_len']}, kv_chunks={r['kv_chunks']})")
|
||||
except Exception as e:
|
||||
print(f"✗ ERROR: {e}")
|
||||
results.append({"file": fname, "error": str(e)})
|
||||
|
||||
print()
|
||||
print("=" * 80)
|
||||
print("Results Summary")
|
||||
print("=" * 80)
|
||||
print()
|
||||
print("| seq_len | stride | threshold | kv_chunks | density_api | density_kv | diff | mask_diff | status |")
|
||||
print("|---------|--------|-----------|-----------|-------------|------------|------|-----------|--------|")
|
||||
|
||||
all_passed = True
|
||||
for r in results:
|
||||
if "error" in r:
|
||||
print(f"| ERROR | - | - | - | - | - | - | - | {r['error'][:20]} |")
|
||||
all_passed = False
|
||||
else:
|
||||
status = "PASS" if r["passed"] else "FAIL"
|
||||
if not r["passed"]:
|
||||
all_passed = False
|
||||
print(f"| {r['seq_len']:>7} | {r['stride']:>6} | {r['threshold']:.2f} | {r['kv_chunks']:>9} | "
|
||||
f"{r['density_api']:.6f} | {r['density_kv']:.6f} | {r['density_diff']:.6f} | "
|
||||
f"{r['mask_diff_pct']:.4f}% | {status} |")
|
||||
|
||||
print()
|
||||
if all_passed:
|
||||
print("test_xattn_kv_chunking_batch: ALL PASSED")
|
||||
else:
|
||||
print("test_xattn_kv_chunking_batch: SOME FAILED")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user