⚡️ perf: optimize XAttention estimate with hierarchical block sum
Replace slow softmax_fuse_block_sum (block_size=4096) with optimized hierarchical approach (estimate_block_size=1024): - Add estimate_block_size parameter to XAttentionBSAPolicy (default 1024) - Rewrite select_blocks to use hierarchical aggregation: 1. Fine-grained softmax with small block size (15x faster kernel) 2. Aggregate to CPU block level via reshape + sum 3. Score + threshold selection (replaces mask + voting) Performance improvement (CPU Offload mode): - softmax_fuse_block_sum: 48% → 1% of total time (44x faster) - 128K: XAttention now +2.4% faster than Full (was -59%) - 64K: -3.8% (was -21%) - 32K: -6.0% (was -14%) Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
@@ -95,6 +95,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
block_size: int = 128,
|
||||
samples_per_chunk: int = 128,
|
||||
use_triton: bool = True,
|
||||
estimate_block_size: int = 1024, # Optimized block size for softmax_fuse_block_sum
|
||||
):
|
||||
"""
|
||||
Initialize XAttention BSA policy.
|
||||
@@ -107,11 +108,15 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
block_size: BSA block size (must be 128)
|
||||
samples_per_chunk: Samples per chunk for estimation (unused)
|
||||
use_triton: Whether to use Triton kernels
|
||||
estimate_block_size: Block size for softmax_fuse_block_sum in select_blocks.
|
||||
Default 1024 is optimal (15x faster than 4096).
|
||||
Must be a factor of cpu_block_size (e.g., 4096/1024=4).
|
||||
"""
|
||||
self.threshold = threshold
|
||||
self.stride = stride
|
||||
self.chunk_size = chunk_size
|
||||
self.use_triton = use_triton
|
||||
self.estimate_block_size = estimate_block_size
|
||||
self._num_heads = None # Set during first forward
|
||||
|
||||
# Sparse metadata: stores attention scores per layer
|
||||
@@ -508,17 +513,28 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
# Free intermediate list immediately
|
||||
del attn_scores_list
|
||||
|
||||
# Step 2: Apply softmax_fuse_block_sum to get block-level attention
|
||||
# block_size = reshaped_block_size so each CPU block maps to exactly 1 output block
|
||||
# This ensures block_sums.shape[-1] == num_available_blocks (1:1 mapping)
|
||||
# Step 2: Apply softmax_fuse_block_sum with hierarchical aggregation
|
||||
# Use smaller estimate_block_size (1024) for 15x faster softmax kernel,
|
||||
# then aggregate to CPU block level (4096).
|
||||
#
|
||||
# Hierarchical approach:
|
||||
# 1. softmax_fuse_block_sum with estimate_block_size (1024) -> fine-grained scores
|
||||
# 2. Aggregate: reshape + sum -> CPU block level scores
|
||||
# 3. Select blocks based on score + threshold (NOT mask + voting)
|
||||
cpu_block_size = block_size # e.g., 4096
|
||||
estimate_bs = self.estimate_block_size # e.g., 1024 (15x faster)
|
||||
ratio = cpu_block_size // estimate_bs # e.g., 4
|
||||
|
||||
# Use estimate_block_size for softmax kernel (optimized)
|
||||
reshaped_est_bs = estimate_bs // self.stride # e.g., 1024/8 = 128
|
||||
norm = 1.0 # Normalization factor
|
||||
scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm # log2(e) with scaling
|
||||
segment_size = min(4096, reshaped_block_size)
|
||||
segment_size = min(4096, reshaped_est_bs)
|
||||
|
||||
with nvtx.range("xattn_estimate_softmax"):
|
||||
block_sums = softmax_fuse_block_sum(
|
||||
block_sums_fine = softmax_fuse_block_sum(
|
||||
attn_scores,
|
||||
reshaped_block_size, # Use CPU block size in reshaped space (1024/8=128)
|
||||
reshaped_est_bs, # Use optimized estimate block size (128 vs 512)
|
||||
segment_size,
|
||||
chunk_start=0,
|
||||
chunk_end=q_reshaped_len,
|
||||
@@ -526,54 +542,55 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
scale=scale,
|
||||
is_causal=False, # Historical blocks are all before current chunk
|
||||
)
|
||||
# block_sums shape: [batch, heads, q_blocks, k_blocks]
|
||||
# where k_blocks == len(available_blocks) (1:1 mapping with CPU blocks)
|
||||
# block_sums_fine shape: [batch, heads, q_est_blocks, k_est_blocks]
|
||||
# where k_est_blocks = len(available_blocks) * ratio
|
||||
|
||||
# Step 3: Use find_blocks_chunked to get selection mask
|
||||
# current_index = 0 since we're looking at historical blocks only
|
||||
with nvtx.range("xattn_estimate_find_blocks"):
|
||||
mask = find_blocks_chunked(
|
||||
block_sums,
|
||||
current_index=0,
|
||||
threshold=self.threshold,
|
||||
num_to_choose=None,
|
||||
decoding=False,
|
||||
mode="prefill",
|
||||
causal=False, # Historical blocks don't need causal mask
|
||||
)
|
||||
# mask shape: [batch, num_heads, q_blocks, k_blocks] - boolean
|
||||
# where k_blocks == len(available_blocks)
|
||||
# Step 3: Aggregate to CPU block level (hierarchical sum)
|
||||
# This is mathematically equivalent to direct computation but much faster
|
||||
batch_size_bs, num_heads_bs, q_est_blocks, k_est_blocks = block_sums_fine.shape
|
||||
num_cpu_blocks = len(available_blocks)
|
||||
|
||||
# GQA-aware aggregation:
|
||||
# For GQA, multiple Q heads share one KV head. We need to select a block
|
||||
# if ANY Q head within the same KV head group selects it.
|
||||
# mask: [batch, num_heads, q_blocks, k_blocks]
|
||||
# Reshape to [batch, num_kv_heads, num_groups, q_blocks, k_blocks]
|
||||
batch_size, num_q_heads, q_blocks, k_blocks = mask.shape
|
||||
# num_kv_heads was set in the K loading loop above (line ~199)
|
||||
# num_groups = num_heads // num_kv_heads (for GQA)
|
||||
num_groups = num_heads // num_kv_heads if num_heads != num_kv_heads else 1
|
||||
with nvtx.range("xattn_estimate_aggregate"):
|
||||
# Reshape: [batch, heads, q_est, k_est] -> [batch, heads, q_est, num_cpu, ratio]
|
||||
block_sums_coarse = block_sums_fine.view(
|
||||
batch_size_bs, num_heads_bs, q_est_blocks, num_cpu_blocks, ratio
|
||||
).sum(dim=-1) # [batch, heads, q_est_blocks, num_cpu_blocks]
|
||||
|
||||
if num_groups > 1:
|
||||
# Reshape: [batch, num_kv_heads, num_groups, q_blocks, k_blocks]
|
||||
mask_gqa = mask.view(batch_size, num_kv_heads, num_groups, q_blocks, k_blocks)
|
||||
# Aggregate within each KV head group: any Q head selects -> KV head selects
|
||||
mask_per_kv_head = mask_gqa.any(dim=2) # [batch, num_kv_heads, q_blocks, k_blocks]
|
||||
else:
|
||||
mask_per_kv_head = mask # [batch, num_heads, q_blocks, k_blocks]
|
||||
# Sum over Q dimension to get total attention from Q chunk to each K block
|
||||
cpu_block_scores = block_sums_coarse.sum(dim=2) # [batch, heads, num_cpu_blocks]
|
||||
|
||||
# Aggregate across KV heads and q_blocks using majority voting
|
||||
# Instead of any(), use voting: select if >50% of kv_heads select it
|
||||
# mask_per_kv_head: [batch, num_kv_heads, q_blocks, k_blocks]
|
||||
# Sum across kv_heads and q_blocks to get vote count per k_block
|
||||
vote_count = mask_per_kv_head[0].float().sum(dim=0).sum(dim=0) # [k_blocks]
|
||||
total_votes = num_kv_heads * q_blocks
|
||||
vote_ratio = vote_count / total_votes
|
||||
# Step 4: Select blocks using score + threshold (replaces mask + majority voting)
|
||||
# This is simpler and more direct than the original mask-based approach
|
||||
with nvtx.range("xattn_estimate_select"):
|
||||
# Average scores across heads (GQA-aware: all heads contribute equally)
|
||||
scores_per_block = cpu_block_scores.mean(dim=(0, 1)) # [num_cpu_blocks]
|
||||
|
||||
# Select blocks with >50% votes (majority voting)
|
||||
vote_threshold = 0.5
|
||||
block_selected = vote_ratio > vote_threshold
|
||||
selected_block_ids = [available_blocks[i] for i, sel in enumerate(block_selected.tolist()) if sel]
|
||||
# Normalize to get attention distribution
|
||||
total_score = scores_per_block.sum()
|
||||
if total_score > 0:
|
||||
score_ratio = scores_per_block / total_score
|
||||
else:
|
||||
# Edge case: all zeros, select all blocks
|
||||
selected_block_ids = list(available_blocks)
|
||||
if layer_id == 0 and available_blocks:
|
||||
self._stats_total_available_blocks += len(available_blocks)
|
||||
self._stats_total_selected_blocks += len(selected_block_ids)
|
||||
self._stats_num_chunks += 1
|
||||
return selected_block_ids
|
||||
|
||||
# Sort by score (descending) and select until threshold is reached
|
||||
sorted_indices = torch.argsort(score_ratio, descending=True)
|
||||
cumsum = 0.0
|
||||
selected_indices = set()
|
||||
|
||||
for idx in sorted_indices.tolist():
|
||||
selected_indices.add(idx)
|
||||
cumsum += score_ratio[idx].item()
|
||||
if cumsum >= self.threshold:
|
||||
break
|
||||
|
||||
# Map indices back to block IDs
|
||||
selected_block_ids = [available_blocks[i] for i in sorted(selected_indices)]
|
||||
|
||||
# Always include first block (sink) and last block for safety
|
||||
if available_blocks and available_blocks[0] not in selected_block_ids:
|
||||
@@ -593,7 +610,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}")
|
||||
|
||||
# Free intermediate tensors to prevent memory leak
|
||||
del attn_scores, block_sums, mask, mask_per_kv_head, vote_count, vote_ratio, block_selected
|
||||
del attn_scores, block_sums_fine, block_sums_coarse, cpu_block_scores, scores_per_block
|
||||
|
||||
return selected_block_ids
|
||||
|
||||
|
||||
Reference in New Issue
Block a user