📈 feat: add NVTX markers to XAttention for profiling

Add NVTX range markers around the individual XAttention phases so each can be timed separately:
- GPU-only path: xattn_estimate, xattn_bsa_compute
- CPU-offload path: xattn_estimate_gemm, xattn_estimate_softmax,
  xattn_estimate_find_blocks, xattn_compute_historical,
  xattn_compute_current, xattn_compute_merge

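These markers show up as named ranges in an nsys (NVIDIA Nsight Systems)
timeline, making it possible to tell whether the estimate phase or the
compute phase is the performance bottleneck. Note that xattn_estimate_gemm
wraps the whole per-block loop, so it also covers loading each K block
into its slot, not only the GEMM itself.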

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-28 00:57:11 +08:00
parent b760de84c5
commit 7b5d3b34eb

View File

@@ -13,6 +13,7 @@ Note: Decode phase is not supported - use FullAttentionPolicy for decode.
import logging import logging
import torch import torch
import torch.cuda.nvtx as nvtx
from typing import List, Tuple, TYPE_CHECKING from typing import List, Tuple, TYPE_CHECKING
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
@@ -304,14 +305,15 @@ class XAttentionBSAPolicy(SparsePolicy):
K_exp, V_exp = K, V K_exp, V_exp = K, V
# Estimate block importance and get sparse mask # Estimate block importance and get sparse mask
_, mask = xattn_estimate( with nvtx.range("xattn_estimate"):
Q, K_exp, _, mask = xattn_estimate(
chunk_size=self.chunk_size, Q, K_exp,
block_size=self.BSA_BLOCK_SIZE, chunk_size=self.chunk_size,
threshold=self.threshold, block_size=self.BSA_BLOCK_SIZE,
use_triton=self.use_triton, threshold=self.threshold,
causal=True, use_triton=self.use_triton,
) causal=True,
)
# Compute block counts # Compute block counts
q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
@@ -339,18 +341,19 @@ class XAttentionBSAPolicy(SparsePolicy):
mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous() mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous()
# Compute sparse attention using BSA # Compute sparse attention using BSA
output = block_sparse_attn_func( with nvtx.range("xattn_bsa_compute"):
q_bsa, k_bsa, v_bsa, output = block_sparse_attn_func(
cu_seqlens_q_bsa, q_bsa, k_bsa, v_bsa,
cu_seqlens_k_bsa, cu_seqlens_q_bsa,
head_groups, cu_seqlens_k_bsa,
None, # key_padding_mask head_groups,
mask_trimmed, None, # key_padding_mask
q_len, k_len, mask_trimmed,
p_dropout=0.0, q_len, k_len,
deterministic=True, p_dropout=0.0,
is_causal=True, deterministic=True,
) is_causal=True,
)
# Update statistics (layer 0 only to avoid overcounting) # Update statistics (layer 0 only to avoid overcounting)
if layer_id == 0: if layer_id == 0:
@@ -453,45 +456,46 @@ class XAttentionBSAPolicy(SparsePolicy):
block_size = ctx.block_size # tokens per CPU block (e.g., 1024) block_size = ctx.block_size # tokens per CPU block (e.g., 1024)
reshaped_block_size = block_size // self.stride # e.g., 1024/8 = 128 reshaped_block_size = block_size // self.stride # e.g., 1024/8 = 128
for cpu_block_id in available_blocks: with nvtx.range("xattn_estimate_gemm"):
# Load K block from CPU to GPU (cpu_block_id is chunk index) for cpu_block_id in available_blocks:
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id) # Load K block from CPU to GPU (cpu_block_id is chunk index)
offload_engine.wait_slot_layer(slot) offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
offload_engine.wait_slot_layer(slot)
# Get KV: [1, block_size, num_kv_heads, head_dim] # Get KV: [1, block_size, num_kv_heads, head_dim]
k_block, _ = offload_engine.get_kv_for_slot(slot) k_block, _ = offload_engine.get_kv_for_slot(slot)
# Convert K to [batch, heads, k_len, head_dim] # Convert K to [batch, heads, k_len, head_dim]
# k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim] # k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim]
K_chunk = k_block.transpose(1, 2) K_chunk = k_block.transpose(1, 2)
# Handle GQA: expand K heads to match Q heads # Handle GQA: expand K heads to match Q heads
num_kv_heads = K_chunk.shape[1] num_kv_heads = K_chunk.shape[1]
if num_heads != num_kv_heads: if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads num_groups = num_heads // num_kv_heads
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1) K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
# Pad K if necessary (k_len must be divisible by stride * BLOCK_N) # Pad K if necessary (k_len must be divisible by stride * BLOCK_N)
k_len = K_chunk.shape[2] k_len = K_chunk.shape[2]
BLOCK_N = 128 BLOCK_N = 128
k_alignment = self.stride * BLOCK_N k_alignment = self.stride * BLOCK_N
if k_len < k_alignment: if k_len < k_alignment:
# K too short, pad it # K too short, pad it
pad_size = k_alignment - k_len pad_size = k_alignment - k_len
K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0) K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0)
# Compute attention scores using flat_group_gemm_fuse_reshape # Compute attention scores using flat_group_gemm_fuse_reshape
# Output: [batch, heads, q_len/stride, k_len/stride] # Output: [batch, heads, q_len/stride, k_len/stride]
attn_chunk = flat_group_gemm_fuse_reshape( attn_chunk = flat_group_gemm_fuse_reshape(
Q, K_chunk, self.stride, Q, K_chunk, self.stride,
chunk_start=0, chunk_start=0,
chunk_end=q_reshaped_len, chunk_end=q_reshaped_len,
is_causal=False is_causal=False
) )
attn_scores_list.append(attn_chunk) attn_scores_list.append(attn_chunk)
# Mark slot as done for reuse # Mark slot as done for reuse
offload_engine.record_slot_compute_done(slot) offload_engine.record_slot_compute_done(slot)
# Concatenate all attention scores along K dimension # Concatenate all attention scores along K dimension
# Each chunk: [1, heads, q_reshaped_len, block_reshaped_len] # Each chunk: [1, heads, q_reshaped_len, block_reshaped_len]
@@ -510,30 +514,32 @@ class XAttentionBSAPolicy(SparsePolicy):
scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm # log2(e) with scaling scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm # log2(e) with scaling
segment_size = min(4096, reshaped_block_size) segment_size = min(4096, reshaped_block_size)
block_sums = softmax_fuse_block_sum( with nvtx.range("xattn_estimate_softmax"):
attn_scores, block_sums = softmax_fuse_block_sum(
reshaped_block_size, # Use CPU block size in reshaped space (1024/8=128) attn_scores,
segment_size, reshaped_block_size, # Use CPU block size in reshaped space (1024/8=128)
chunk_start=0, segment_size,
chunk_end=q_reshaped_len, chunk_start=0,
real_q_len=q_reshaped_len, chunk_end=q_reshaped_len,
scale=scale, real_q_len=q_reshaped_len,
is_causal=False, # Historical blocks are all before current chunk scale=scale,
) is_causal=False, # Historical blocks are all before current chunk
)
# block_sums shape: [batch, heads, q_blocks, k_blocks] # block_sums shape: [batch, heads, q_blocks, k_blocks]
# where k_blocks == len(available_blocks) (1:1 mapping with CPU blocks) # where k_blocks == len(available_blocks) (1:1 mapping with CPU blocks)
# Step 3: Use find_blocks_chunked to get selection mask # Step 3: Use find_blocks_chunked to get selection mask
# current_index = 0 since we're looking at historical blocks only # current_index = 0 since we're looking at historical blocks only
mask = find_blocks_chunked( with nvtx.range("xattn_estimate_find_blocks"):
block_sums, mask = find_blocks_chunked(
current_index=0, block_sums,
threshold=self.threshold, current_index=0,
num_to_choose=None, threshold=self.threshold,
decoding=False, num_to_choose=None,
mode="prefill", decoding=False,
causal=False, # Historical blocks don't need causal mask mode="prefill",
) causal=False, # Historical blocks don't need causal mask
)
# mask shape: [batch, num_heads, q_blocks, k_blocks] - boolean # mask shape: [batch, num_heads, q_blocks, k_blocks] - boolean
# where k_blocks == len(available_blocks) # where k_blocks == len(available_blocks)
@@ -639,78 +645,81 @@ class XAttentionBSAPolicy(SparsePolicy):
cpu_block_table = selected_blocks cpu_block_table = selected_blocks
if cpu_block_table: if cpu_block_table:
load_slots = list(range(offload_engine.num_ring_slots)) with nvtx.range("xattn_compute_historical"):
num_blocks = len(cpu_block_table) load_slots = list(range(offload_engine.num_ring_slots))
num_blocks = len(cpu_block_table)
if len(load_slots) == 1: if len(load_slots) == 1:
# Only 1 slot - use synchronous mode # Only 1 slot - use synchronous mode
slot = load_slots[0] slot = load_slots[0]
for block_idx in range(num_blocks): for block_idx in range(num_blocks):
cpu_block_id = cpu_block_table[block_idx] cpu_block_id = cpu_block_table[block_idx]
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id) offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
offload_engine.wait_slot_layer(slot) offload_engine.wait_slot_layer(slot)
with torch.cuda.stream(compute_stream): with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(slot) prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
prev_o, prev_lse = flash_attn_with_lse( prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v, q_batched, prev_k, prev_v,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=False, causal=False,
) )
if o_acc is None: if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse o_acc, lse_acc = prev_o, prev_lse
else: else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
offload_engine.record_slot_compute_done(slot) offload_engine.record_slot_compute_done(slot)
else: else:
# Multiple slots - use pipeline # Multiple slots - use pipeline
num_slots = len(load_slots) num_slots = len(load_slots)
num_preload = min(num_slots, num_blocks) num_preload = min(num_slots, num_blocks)
for i in range(num_preload): for i in range(num_preload):
cpu_block_id = cpu_block_table[i] cpu_block_id = cpu_block_table[i]
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id) offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
for block_idx in range(num_blocks): for block_idx in range(num_blocks):
current_slot = load_slots[block_idx % num_slots] current_slot = load_slots[block_idx % num_slots]
offload_engine.wait_slot_layer(current_slot) offload_engine.wait_slot_layer(current_slot)
with torch.cuda.stream(compute_stream): with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot) prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
prev_o, prev_lse = flash_attn_with_lse( prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v, q_batched, prev_k, prev_v,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=False, causal=False,
) )
offload_engine.record_slot_compute_done(current_slot) offload_engine.record_slot_compute_done(current_slot)
if o_acc is None: if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse o_acc, lse_acc = prev_o, prev_lse
else: else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Issue next transfer # Issue next transfer
next_block_idx = block_idx + num_slots next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks: if next_block_idx < num_blocks:
next_slot = load_slots[next_block_idx % num_slots] next_slot = load_slots[next_block_idx % num_slots]
next_cpu_block_id = cpu_block_table[next_block_idx] next_cpu_block_id = cpu_block_table[next_block_idx]
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id) offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
# Compute attention to current chunk (causal mask) # Compute attention to current chunk (causal mask)
with torch.cuda.stream(compute_stream): with nvtx.range("xattn_compute_current"):
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens) with torch.cuda.stream(compute_stream):
current_o, current_lse = flash_attn_with_lse( k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
q_batched, k_curr, v_curr, current_o, current_lse = flash_attn_with_lse(
softmax_scale=softmax_scale, q_batched, k_curr, v_curr,
causal=True, softmax_scale=softmax_scale,
) causal=True,
)
# Merge historical and current attention # Merge historical and current attention
with torch.cuda.stream(compute_stream): with nvtx.range("xattn_compute_merge"):
if o_acc is None: with torch.cuda.stream(compute_stream):
final_o = current_o if o_acc is None:
else: final_o = current_o
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) else:
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
# Sync default stream with compute_stream before returning # Sync default stream with compute_stream before returning
torch.cuda.default_stream().wait_stream(compute_stream) torch.cuda.default_stream().wait_stream(compute_stream)