📈 feat: add NVTX markers to XAttention for profiling
Add NVTX range markers to track XAttention performance:

- GPU-only path: xattn_estimate, xattn_bsa_compute
- Offload path: xattn_estimate_gemm, xattn_estimate_softmax, xattn_estimate_find_blocks, xattn_compute_historical, xattn_compute_current, xattn_compute_merge

These markers enable detailed nsys profiling to identify performance bottlenecks in the estimate vs. compute phases.

Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
@@ -13,6 +13,7 @@ Note: Decode phase is not supported - use FullAttentionPolicy for decode.
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import torch.cuda.nvtx as nvtx
|
||||
from typing import List, Tuple, TYPE_CHECKING
|
||||
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||
@@ -304,6 +305,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
K_exp, V_exp = K, V
|
||||
|
||||
# Estimate block importance and get sparse mask
|
||||
with nvtx.range("xattn_estimate"):
|
||||
_, mask = xattn_estimate(
|
||||
Q, K_exp,
|
||||
chunk_size=self.chunk_size,
|
||||
@@ -339,6 +341,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous()
|
||||
|
||||
# Compute sparse attention using BSA
|
||||
with nvtx.range("xattn_bsa_compute"):
|
||||
output = block_sparse_attn_func(
|
||||
q_bsa, k_bsa, v_bsa,
|
||||
cu_seqlens_q_bsa,
|
||||
@@ -453,6 +456,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
block_size = ctx.block_size # tokens per CPU block (e.g., 1024)
|
||||
reshaped_block_size = block_size // self.stride # e.g., 1024/8 = 128
|
||||
|
||||
with nvtx.range("xattn_estimate_gemm"):
|
||||
for cpu_block_id in available_blocks:
|
||||
# Load K block from CPU to GPU (cpu_block_id is chunk index)
|
||||
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
|
||||
@@ -510,6 +514,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm # log2(e) with scaling
|
||||
segment_size = min(4096, reshaped_block_size)
|
||||
|
||||
with nvtx.range("xattn_estimate_softmax"):
|
||||
block_sums = softmax_fuse_block_sum(
|
||||
attn_scores,
|
||||
reshaped_block_size, # Use CPU block size in reshaped space (1024/8=128)
|
||||
@@ -525,6 +530,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
|
||||
# Step 3: Use find_blocks_chunked to get selection mask
|
||||
# current_index = 0 since we're looking at historical blocks only
|
||||
with nvtx.range("xattn_estimate_find_blocks"):
|
||||
mask = find_blocks_chunked(
|
||||
block_sums,
|
||||
current_index=0,
|
||||
@@ -639,6 +645,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
cpu_block_table = selected_blocks
|
||||
|
||||
if cpu_block_table:
|
||||
with nvtx.range("xattn_compute_historical"):
|
||||
load_slots = list(range(offload_engine.num_ring_slots))
|
||||
num_blocks = len(cpu_block_table)
|
||||
|
||||
@@ -697,6 +704,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
|
||||
|
||||
# Compute attention to current chunk (causal mask)
|
||||
with nvtx.range("xattn_compute_current"):
|
||||
with torch.cuda.stream(compute_stream):
|
||||
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
|
||||
current_o, current_lse = flash_attn_with_lse(
|
||||
@@ -706,6 +714,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
)
|
||||
|
||||
# Merge historical and current attention
|
||||
with nvtx.range("xattn_compute_merge"):
|
||||
with torch.cuda.stream(compute_stream):
|
||||
if o_acc is None:
|
||||
final_o = current_o
|
||||
|
||||
Reference in New Issue
Block a user