📈 feat: add NVTX markers to XAttention for profiling

Add NVTX range markers to track XAttention performance:
- GPU-only: xattn_estimate, xattn_bsa_compute
- Offload: xattn_estimate_gemm, xattn_estimate_softmax,
  xattn_estimate_find_blocks, xattn_compute_historical,
  xattn_compute_current, xattn_compute_merge

These markers enable detailed nsys profiling to identify
performance bottlenecks in estimate vs compute phases.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-28 00:57:11 +08:00
parent b760de84c5
commit 7b5d3b34eb

View File

@@ -13,6 +13,7 @@ Note: Decode phase is not supported - use FullAttentionPolicy for decode.
import logging
import torch
import torch.cuda.nvtx as nvtx
from typing import List, Tuple, TYPE_CHECKING
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
@@ -304,6 +305,7 @@ class XAttentionBSAPolicy(SparsePolicy):
K_exp, V_exp = K, V
# Estimate block importance and get sparse mask
with nvtx.range("xattn_estimate"):
_, mask = xattn_estimate(
Q, K_exp,
chunk_size=self.chunk_size,
@@ -339,6 +341,7 @@ class XAttentionBSAPolicy(SparsePolicy):
mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous()
# Compute sparse attention using BSA
with nvtx.range("xattn_bsa_compute"):
output = block_sparse_attn_func(
q_bsa, k_bsa, v_bsa,
cu_seqlens_q_bsa,
@@ -453,6 +456,7 @@ class XAttentionBSAPolicy(SparsePolicy):
block_size = ctx.block_size  # tokens per CPU block (e.g., 1024)
reshaped_block_size = block_size // self.stride  # e.g., 1024/8 = 128
with nvtx.range("xattn_estimate_gemm"):
for cpu_block_id in available_blocks:
# Load K block from CPU to GPU (cpu_block_id is chunk index)
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
@@ -510,6 +514,7 @@ class XAttentionBSAPolicy(SparsePolicy):
scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm  # log2(e) with scaling
segment_size = min(4096, reshaped_block_size)
with nvtx.range("xattn_estimate_softmax"):
block_sums = softmax_fuse_block_sum(
attn_scores,
reshaped_block_size,  # Use CPU block size in reshaped space (1024/8=128)
@@ -525,6 +530,7 @@ class XAttentionBSAPolicy(SparsePolicy):
# Step 3: Use find_blocks_chunked to get selection mask
# current_index = 0 since we're looking at historical blocks only
with nvtx.range("xattn_estimate_find_blocks"):
mask = find_blocks_chunked(
block_sums,
current_index=0,
@@ -639,6 +645,7 @@ class XAttentionBSAPolicy(SparsePolicy):
cpu_block_table = selected_blocks
if cpu_block_table:
with nvtx.range("xattn_compute_historical"):
load_slots = list(range(offload_engine.num_ring_slots))
num_blocks = len(cpu_block_table)
@@ -697,6 +704,7 @@ class XAttentionBSAPolicy(SparsePolicy):
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
# Compute attention to current chunk (causal mask)
with nvtx.range("xattn_compute_current"):
with torch.cuda.stream(compute_stream):
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
current_o, current_lse = flash_attn_with_lse(
@@ -706,6 +714,7 @@ class XAttentionBSAPolicy(SparsePolicy):
)
# Merge historical and current attention
with nvtx.range("xattn_compute_merge"):
with torch.cuda.stream(compute_stream):
if o_acc is None:
final_o = current_o