[WIP] Before refactor the compute)_chunked_prefill.

This commit is contained in:
Zijie Tian
2026-01-23 03:36:12 +08:00
parent edc006463b
commit ca32ea6f93
7 changed files with 914 additions and 114 deletions

334
tests/test_xattn_bsa.py Normal file
View File

@@ -0,0 +1,334 @@
"""
Test XAttention + BSA with RULER benchmark data.
Tests XAttention sparse attention correctness using RULER NIAH task.
Attention methods:
- Prefill: XAttention + BSA (sparse) or FlashAttention (dense)
- Decode: FlashAttention (always, since q_len=1)
Usage (in compass conda env with BSA available):
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_xattn_bsa.py --model ~/models/Llama-3.1-8B-Instruct
# Test with XAttention + BSA for prefill (default)
python tests/test_xattn_bsa.py --prefill-method xattn
# Test with FlashAttention for prefill (baseline)
python tests/test_xattn_bsa.py --prefill-method flash
# Test specific sample(s)
python tests/test_xattn_bsa.py --sample-id 0
python tests/test_xattn_bsa.py --sample-ids 0,1,2
Note: Compatible with transformers 4.53+ (handles both old `past_key_value`
and new `past_key_values` API).
"""
import argparse
import json
import sys
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
from nanovllm.ops.xattn import xattn_estimate
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
# ============================================================
# XAttention + BSA Functions
# ============================================================
def expand_kv_for_gqa(key_states, value_states, num_heads):
"""Expand KV for Grouped Query Attention."""
num_kv_heads = key_states.shape[1]
if num_heads == num_kv_heads:
return key_states, value_states
num_groups = num_heads // num_kv_heads
return key_states.repeat_interleave(num_groups, dim=1), value_states.repeat_interleave(num_groups, dim=1)
def flash_attention_forward(query_states, key_states, value_states, is_causal=True):
"""Standard FlashAttention."""
from flash_attn import flash_attn_func
q = query_states.transpose(1, 2)
k = key_states.transpose(1, 2)
v = value_states.transpose(1, 2)
return flash_attn_func(q, k, v, causal=is_causal).transpose(1, 2)
def xattn_bsa_forward(query_states, key_states, value_states, threshold=0.9):
"""XAttention + BSA sparse attention."""
from block_sparse_attn import block_sparse_attn_func
batch_size, num_heads, q_len, head_dim = query_states.shape
k_len = key_states.shape[2]
_, mask = xattn_estimate(
query_states, key_states,
chunk_size=16384, block_size=128, threshold=threshold,
use_triton=True, causal=True,
)
q_block_num = (q_len + 127) // 128
k_block_num = (k_len + 127) // 128
q = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim)
k = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
v = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
__import__('pdb').set_trace()
output = block_sparse_attn_func(
q, k, v,
torch.tensor([0, q_len], dtype=torch.int32, device=q.device),
torch.tensor([0, k_len], dtype=torch.int32, device=k.device),
torch.ones(num_heads, dtype=torch.int32, device=q.device),
None,
mask[:, :, :q_block_num, :k_block_num].contiguous(),
q_len, k_len,
p_dropout=0.0, deterministic=True, is_causal=True,
)
return output.reshape(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
DEBUG = False # Set to True to enable debugging
def create_patched_forward(prefill_method="xattn", threshold=0.9):
"""Create patched forward with configurable prefill method.
Args:
prefill_method: "xattn" for XAttention + BSA (sparse), "flash" for FlashAttention (dense)
threshold: XAttention threshold for block selection (only used when prefill_method="xattn")
Note:
- Prefill (q_len > 1): Uses specified prefill_method
- Decode (q_len = 1): Always uses FlashAttention (no sparse needed for single query)
"""
call_count = [0] # Mutable to track calls across layers
def patched_forward(
self,
hidden_states,
position_embeddings=None,
attention_mask=None,
past_key_value=None, # Old API (transformers < 4.57)
past_key_values=None, # New API (transformers >= 4.57)
cache_position=None,
**kwargs
):
# Handle both old and new transformers API
kv_cache = past_key_values if past_key_values is not None else past_key_value
bsz, q_len, _ = hidden_states.size()
num_heads = self.config.num_attention_heads
num_kv_heads = self.config.num_key_value_heads
head_dim = self.head_dim
# Compute Q, K, V projections
query_states = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
# Apply rotary position embedding
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# Handle KV cache
if kv_cache is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = kv_cache.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# Expand KV for GQA
key_states_exp, value_states_exp = expand_kv_for_gqa(key_states, value_states, num_heads)
# Debug output
if DEBUG and self.layer_idx == 0:
call_count[0] += 1
if call_count[0] <= 5:
phase = "prefill" if q_len > 1 else "decode"
print(f"\n[DEBUG] Layer {self.layer_idx}, call {call_count[0]} ({phase}): q_len={q_len}, k_len={key_states_exp.shape[2]}")
print(f" kv_cache is None: {kv_cache is None}")
# Choose attention method:
# - Prefill (q_len > 1): Use prefill_method (xattn or flash)
# - Decode (q_len = 1): Always use FlashAttention
is_prefill = q_len > 1
if is_prefill and prefill_method == "xattn":
# Prefill with XAttention + BSA (sparse)
attn_output = xattn_bsa_forward(query_states, key_states_exp, value_states_exp, threshold)
else:
# Prefill with FlashAttention (dense) OR Decode (always FlashAttention)
# Note: For decode (q_len=1), causal=False since single query attends to all KV
attn_output = flash_attention_forward(query_states, key_states_exp, value_states_exp, is_causal=is_prefill)
attn_output = self.o_proj(attn_output.transpose(1, 2).reshape(bsz, q_len, -1))
return attn_output, None
return patched_forward
# ============================================================
# Data & Evaluation
# ============================================================
def load_samples(filepath, indices=None):
"""Load samples from JSONL file."""
samples = []
with open(filepath) as f:
for i, line in enumerate(f):
if indices is None or i in indices:
sample = json.loads(line)
sample["_idx"] = i
samples.append(sample)
return samples
def string_match_all(output_text, expected_list):
"""RULER metric: fraction of expected values found in output."""
output_lower = output_text.lower().replace('\n', ' ')
if not expected_list:
return 1.0
return sum(1.0 if exp.strip().lower() in output_lower else 0.0 for exp in expected_list) / len(expected_list)
# ============================================================
# Test
# ============================================================
def test_with_ruler_data(model_path, data_file, sample_ids, prefill_method="xattn", threshold=0.9, max_new_tokens=50):
"""Test attention methods using RULER data.
Args:
prefill_method: "xattn" for XAttention + BSA, "flash" for FlashAttention
"""
prefill_desc = "XAttention + BSA (sparse)" if prefill_method == "xattn" else "FlashAttention (dense)"
print("=" * 60)
print("RULER NIAH Attention Test")
print("=" * 60)
print(f"Data: {data_file}")
print(f"Samples: {sample_ids}")
print(f"Prefill method: {prefill_desc}")
print(f"Decode method: FlashAttention (always)")
if prefill_method == "xattn":
print(f"XAttention threshold: {threshold}")
samples = load_samples(Path(data_file), set(sample_ids) if sample_ids else None)
if not samples:
print("No samples found!")
return False
print(f"Loaded {len(samples)} samples")
# Load model
print(f"\nLoading model: {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=torch.float16, device_map="cuda",
attn_implementation="eager", # Will be patched
)
model.eval()
# Patch all layers
print(f"Patching attention layers...")
print(f" - Prefill: {prefill_desc}")
print(f" - Decode: FlashAttention")
for idx, layer in enumerate(model.model.layers):
layer.self_attn.layer_idx = idx # Ensure layer_idx is set
layer.self_attn.forward = create_patched_forward(prefill_method, threshold).__get__(
layer.self_attn, type(layer.self_attn)
)
total_score = 0.0
results = []
for sample in samples:
idx = sample["_idx"]
prompt = sample["input"]
expected = sample["outputs"]
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
num_tokens = inputs["input_ids"].shape[1]
print(f"\n--- Sample {idx} ({num_tokens} tokens) ---")
print(f"Expected: {expected}")
with torch.no_grad():
output = model.generate(
inputs["input_ids"],
max_new_tokens=max_new_tokens,
do_sample=False,
pad_token_id=tokenizer.eos_token_id,
)
output_text = tokenizer.decode(output[0][num_tokens:], skip_special_tokens=True)
score = string_match_all(output_text, expected)
total_score += score
status = "✓ PASS" if score >= 0.5 else "✗ FAIL"
print(f"Output: '{output_text[:100]}...'")
print(f"Result: {status} (score={score:.2f})")
results.append({"idx": idx, "score": score, "passed": score >= 0.5})
avg_score = total_score / len(samples)
passed = sum(1 for r in results if r["passed"])
print(f"\n{'='*60}")
print(f"Results: {passed}/{len(samples)} passed, avg_score={avg_score:.3f}")
print(f"{'='*60}")
return avg_score >= 0.5
def main():
parser = argparse.ArgumentParser(
description="Test XAttention + BSA vs FlashAttention for prefill using RULER NIAH benchmark"
)
parser.add_argument("--model", default="~/models/Llama-3.1-8B-Instruct")
parser.add_argument("--data-file", default="tests/data/ruler_32k/niah_single_1/validation.jsonl")
parser.add_argument("--sample-id", type=int, default=None, help="Test single sample by index")
parser.add_argument("--sample-ids", type=str, default="", help="Test multiple samples (comma-separated)")
parser.add_argument("--prefill-method", choices=["xattn", "flash"], default="xattn",
help="Prefill attention method: xattn (XAttention+BSA sparse) or flash (FlashAttention dense)")
parser.add_argument("--threshold", type=float, default=0.9, help="XAttention threshold (only for --prefill-method xattn)")
parser.add_argument("--max-new-tokens", type=int, default=50)
# Keep old option for backwards compatibility
parser.add_argument("--no-xattn", action="store_true", help="[Deprecated] Use --prefill-method flash instead")
args = parser.parse_args()
model_path = args.model.replace("~", "/home/zijie")
# Handle deprecated --no-xattn option
prefill_method = args.prefill_method
if args.no_xattn:
prefill_method = "flash"
print("Warning: --no-xattn is deprecated, use --prefill-method flash instead")
if args.sample_id is not None:
sample_ids = [args.sample_id]
elif args.sample_ids:
sample_ids = [int(x) for x in args.sample_ids.split(",")]
else:
sample_ids = [0]
# Check BSA availability if using xattn
if prefill_method == "xattn":
try:
from block_sparse_attn import block_sparse_attn_func
print("✓ BSA (Block Sparse Attention) available")
except ImportError:
print("✗ BSA not found. Install block_sparse_attn or use --prefill-method flash")
sys.exit(1)
if test_with_ruler_data(model_path, args.data_file, sample_ids, prefill_method, args.threshold, args.max_new_tokens):
print("\ntest_xattn_bsa: PASSED")
else:
print("\ntest_xattn_bsa: FAILED")
sys.exit(1)
if __name__ == "__main__":
main()

259
tests/test_xattn_chunked.py Normal file
View File

@@ -0,0 +1,259 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked
Verify that chunked estimation with EXTERNAL chunking produces the same mask as standard estimation.
Uses real QKV data captured from model inference.
"""
import sys
import os
import torch
import warnings
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
# ============================================================
# Configuration
# ============================================================
BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096
# Default QKV data directory (relative to project root)
DEFAULT_QKV_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "results", "kvcache")
# ============================================================
# Utility Functions
# ============================================================
def load_qkv(path):
"""Load saved QKV data."""
data = torch.load(path, map_location="cpu", weights_only=False)
print(f"Loaded: {path}")
print(f" Query shape: {data['query'].shape}")
print(f" Key shape: {data['key'].shape}")
print(f" Layer: {data['layer_id']}, Density: {data['density']:.2%}")
return data
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
"""Compare two masks and report differences."""
if mask1.shape != mask2.shape:
print(f"Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
return False
diff = (mask1 != mask2).sum().item()
total = mask1.numel()
match_rate = (total - diff) / total * 100
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
if diff > 0:
diff_indices = torch.where(mask1 != mask2)
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
return diff == 0
def run_chunked_externally(query, key, q_start_pos, block_size, stride, threshold, chunk_size):
"""
Run xattn_estimate_chunked with EXTERNAL chunking.
This simulates how chunked prefill should be used in practice.
"""
batch_size, num_heads, q_len, head_dim = query.shape
_, _, k_len, _ = key.shape
q_block_num = (q_len + block_size - 1) // block_size
k_block_num = (k_len + block_size - 1) // block_size
# If Q fits in one chunk, call directly
if q_len <= chunk_size:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return xattn_estimate_chunked(
query, key,
q_start_pos=q_start_pos,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# External chunking: split Q and call for each chunk
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
print(f" External chunking: {num_q_chunks} chunks")
combined_attn_sum = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=query.dtype, device=query.device
)
combined_mask = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=torch.bool, device=query.device
)
q_block_offset = 0
for q_chunk_idx in range(num_q_chunks):
q_chunk_start = q_chunk_idx * chunk_size
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
# For causal attention, K accumulates up to current Q position
k_end = q_start_pos + q_chunk_end
k_chunk = key[:, :, :k_end, :]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
q_chunk, k_chunk,
q_start_pos=q_start_pos + q_chunk_start,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# Place chunk results into combined output
chunk_q_blocks = mask_chunk.shape[2]
chunk_k_blocks = mask_chunk.shape[3]
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
q_block_offset += chunk_q_blocks
return combined_attn_sum, combined_mask
def test_single_qkv(qkv_path):
"""Test a single QKV file."""
data = load_qkv(qkv_path)
query = data["query"].cuda().to(torch.bfloat16)
key = data["key"].cuda().to(torch.bfloat16)
seq_len = query.shape[2]
print(f"\nTesting with seq_len={seq_len}")
print("=" * 60)
# Run standard xattn_estimate
print("[1] Running standard xattn_estimate...")
try:
attn_sum_std, mask_std = xattn_estimate(
query, key,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
use_triton=True,
)
print(f" mask shape: {mask_std.shape}, density: {mask_std.float().mean().item():.4f}")
except Exception as e:
print(f" ERROR: {e}")
import traceback
traceback.print_exc()
return False
# Run chunked xattn_estimate with EXTERNAL chunking
print("[2] Running chunked xattn_estimate (external chunking)...")
try:
attn_sum_chunked, mask_chunked = run_chunked_externally(
query, key,
q_start_pos=0,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
)
print(f" mask shape: {mask_chunked.shape}, density: {mask_chunked.float().mean().item():.4f}")
except Exception as e:
print(f" ERROR: {e}")
import traceback
traceback.print_exc()
return False
# Compare results
print("[3] Comparing results...")
chunked_q_blocks = mask_chunked.shape[2]
chunked_k_blocks = mask_chunked.shape[3]
# Extract comparable region from standard mask
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
# Compare masks
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
# Compare attn_sums
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
print(f" Attn sum max diff: {attn_diff:.6f}")
else:
print(f" Attn sum shape mismatch")
# Clean up GPU memory
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
torch.cuda.empty_cache()
return masks_match
# ============================================================
# Main Test
# ============================================================
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Test xattn_estimate vs xattn_estimate_chunked")
parser.add_argument("--qkv-dir", type=str, default=DEFAULT_QKV_DIR,
help="Directory containing QKV files")
args = parser.parse_args()
# QKV files to test
qkv_files = [
os.path.join(args.qkv_dir, "qkv_3688.pt"), # ~4K
os.path.join(args.qkv_dir, "qkv_7888.pt"), # ~8K
os.path.join(args.qkv_dir, "qkv_15685.pt"), # ~16K
os.path.join(args.qkv_dir, "qkv_32485.pt"), # ~32K
os.path.join(args.qkv_dir, "qkv_64891.pt"), # ~64K
]
available_files = [p for p in qkv_files if os.path.exists(p)]
if not available_files:
print(f"No QKV file found in {args.qkv_dir}.")
print(f"Expected files: qkv_3688.pt, qkv_7888.pt, qkv_15685.pt, qkv_32485.pt, qkv_64891.pt")
sys.exit(1)
print(f"Found {len(available_files)} QKV files to test")
print(f"Testing EXTERNAL chunking (chunk_size={CHUNK_SIZE})")
print(f"Using Triton kernels")
all_passed = True
results = []
for qkv_path in available_files:
passed = test_single_qkv(qkv_path)
seq_len = int(os.path.basename(qkv_path).replace("qkv_", "").replace(".pt", ""))
results.append((seq_len, passed))
if not passed:
all_passed = False
# Summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for seq_len, passed in results:
status = "PASSED" if passed else "FAILED"
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
print(f" seq_len={seq_len} ({chunks} chunk{'s' if chunks > 1 else ''}): {status}")
print("=" * 60)
if all_passed:
print("test_xattn_chunked: PASSED")
sys.exit(0)
else:
print("test_xattn_chunked: FAILED")
sys.exit(1)

View File

@@ -6,9 +6,10 @@ Test: XAttention Triton kernels
2. softmax_fuse_block_sum: 对 attention scores 做 softmax 后按 block 求和
数据流:
Q, K [batch, heads, seq_len, head_dim]
Q [batch, heads, q_len, head_dim]
K [batch, heads, kv_len, head_dim]
↓ flat_group_gemm_fuse_reshape
attn_scores [batch, heads, seq_len/stride, seq_len/stride]
attn_scores [batch, heads, q_len/stride, kv_len/stride]
↓ softmax_fuse_block_sum
block_sums [batch, heads, q_blocks, k_blocks]
"""
@@ -21,7 +22,11 @@ from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_
# 参数配置
# ============================================================
seq_len = 512 # Triton 要求 seq_len >= stride * BLOCK_M = 4 * 128 = 512
# Triton 约束: q_len >= stride * BLOCK_M, kv_len >= stride * BLOCK_N
# A100: BLOCK_M = BLOCK_N = 128, 所以 min = 4 * 128 = 512
# RTX 3090: BLOCK_M = BLOCK_N = 64, 所以 min = 4 * 64 = 256
q_len = 512
kv_len = 2048
head_dim = 128
stride = 4
block_size = 128 # softmax block size (in reshaped space)
@@ -31,26 +36,56 @@ segment_size = 128 # Triton kernel 要求 segment_size >= block_size
# 构造输入: 偶数位置=1, 奇数位置=2
# ============================================================
Q = torch.zeros(1, 1, seq_len, head_dim, dtype=torch.bfloat16).cuda()
K = torch.zeros(1, 1, seq_len, head_dim, dtype=torch.bfloat16).cuda()
for i in range(seq_len):
Q = torch.zeros(1, 1, q_len, head_dim, dtype=torch.bfloat16).cuda()
K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda()
for i in range(q_len):
if i % 2 == 0:
Q[0, 0, i, :] = 1
K[0, 0, i, :] = 1
else:
Q[0, 0, i, :] = 2
for i in range(kv_len):
if i % 2 == 0:
K[0, 0, i, :] = 1
else:
K[0, 0, i, :] = 2
# ============================================================
# Step 1: flat_group_gemm_fuse_reshape
# Step 1: flat_group_gemm_fuse_reshape (chunked along K)
# ============================================================
attn_scores = flat_group_gemm_fuse_reshape(
Q, K, stride,
chunk_start=0,
chunk_end=seq_len // stride,
is_causal=False
)
q_reshaped_len = q_len // stride # 128
kv_reshaped_len = kv_len // stride # 512
# 将 K 沿着长度维度分成多个 chunk
k_chunk_size = 512 # 每个 chunk 512 tokens
num_k_chunks = kv_len // k_chunk_size # 4 chunks
attn_scores_list = []
for k_chunk_idx in range(num_k_chunks):
k_start = k_chunk_idx * k_chunk_size
k_end = k_start + k_chunk_size
K_chunk = K[:, :, k_start:k_end, :] # [1, 1, k_chunk_size, head_dim]
# 对每个 K chunk 调用 flat_group_gemm_fuse_reshape
# 输出: [batch, heads, q_len/stride, k_chunk_size/stride]
attn_chunk = flat_group_gemm_fuse_reshape(
Q, K_chunk, stride,
chunk_start=0,
chunk_end=q_reshaped_len,
is_causal=False
)
attn_scores_list.append(attn_chunk)
# 拼接所有 K chunks 的结果
# 每个 chunk: [1, 1, q_reshaped_len, k_chunk_size/stride]
# 拼接后: [1, 1, q_reshaped_len, kv_reshaped_len]
attn_scores = torch.cat(attn_scores_list, dim=-1)
# 验证 shape: [batch, heads, q_len/stride, kv_len/stride]
assert attn_scores.shape == (1, 1, q_reshaped_len, kv_reshaped_len), \
f"shape mismatch: {attn_scores.shape} != (1, 1, {q_reshaped_len}, {kv_reshaped_len})"
# 验证: 反对角线求和
# 每个 stride x stride 块的反对角线: Q[奇]*K[偶] + Q[偶]*K[奇] = 2*1 + 1*2 = 4
@@ -63,7 +98,6 @@ assert actual_gemm == expected_gemm, f"flat_group_gemm: {actual_gemm} != {expect
# Step 2: softmax_fuse_block_sum
# ============================================================
reshaped_len = seq_len // stride
scale = 1.4426950408889634 # log2(e) for exp2
block_sums = softmax_fuse_block_sum(
@@ -71,15 +105,24 @@ block_sums = softmax_fuse_block_sum(
block_size,
segment_size,
chunk_start=0,
chunk_end=reshaped_len,
real_q_len=reshaped_len,
chunk_end=q_reshaped_len,
real_q_len=q_reshaped_len,
scale=scale,
is_causal=False
)
# 验证 shape: [batch, heads, q_blocks, k_blocks]
q_blocks = q_reshaped_len // block_size # 128 / 128 = 1
k_blocks = kv_reshaped_len // block_size # 512 / 128 = 4
assert block_sums.shape == (1, 1, q_blocks, k_blocks), \
f"shape mismatch: {block_sums.shape} != (1, 1, {q_blocks}, {k_blocks})"
# 验证: 每个 block 的 softmax 结果求和
# 所有 attn_scores 相同 → softmax 均匀分布 → block_sum = block_size^2 / reshaped_len
expected_sum = block_size * block_size / reshaped_len
# 所有 attn_scores 相同 → softmax 均匀分布
# 每行对一个 K block 的贡献 = block_size / kv_reshaped_len
# 每个 Q block 有 block_size 行
# block_sum = block_size * (block_size / kv_reshaped_len)
expected_sum = block_size * block_size / kv_reshaped_len
actual_sum = block_sums[0, 0, 0, 0].item()
assert actual_sum == expected_sum, f"softmax_fuse_block_sum: {actual_sum} != {expected_sum}"