📈 feat: add NVTX markers to XAttention for profiling

Add NVTX range markers to track XAttention performance:
- GPU-only: xattn_estimate, xattn_bsa_compute
- Offload: xattn_estimate_gemm, xattn_estimate_softmax,
  xattn_estimate_find_blocks, xattn_compute_historical,
  xattn_compute_current, xattn_compute_merge

These markers enable detailed nsys profiling to identify
performance bottlenecks in estimate vs compute phases.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-28 00:57:11 +08:00
parent b760de84c5
commit 7b5d3b34eb

View File

@@ -13,6 +13,7 @@ Note: Decode phase is not supported - use FullAttentionPolicy for decode.
import logging
import torch
import torch.cuda.nvtx as nvtx
from typing import List, Tuple, TYPE_CHECKING
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
@@ -304,14 +305,15 @@ class XAttentionBSAPolicy(SparsePolicy):
K_exp, V_exp = K, V
# Estimate block importance and get sparse mask
_, mask = xattn_estimate(
Q, K_exp,
chunk_size=self.chunk_size,
block_size=self.BSA_BLOCK_SIZE,
threshold=self.threshold,
use_triton=self.use_triton,
causal=True,
)
with nvtx.range("xattn_estimate"):
_, mask = xattn_estimate(
Q, K_exp,
chunk_size=self.chunk_size,
block_size=self.BSA_BLOCK_SIZE,
threshold=self.threshold,
use_triton=self.use_triton,
causal=True,
)
# Compute block counts
q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
@@ -339,18 +341,19 @@ class XAttentionBSAPolicy(SparsePolicy):
mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous()
# Compute sparse attention using BSA
output = block_sparse_attn_func(
q_bsa, k_bsa, v_bsa,
cu_seqlens_q_bsa,
cu_seqlens_k_bsa,
head_groups,
None, # key_padding_mask
mask_trimmed,
q_len, k_len,
p_dropout=0.0,
deterministic=True,
is_causal=True,
)
with nvtx.range("xattn_bsa_compute"):
output = block_sparse_attn_func(
q_bsa, k_bsa, v_bsa,
cu_seqlens_q_bsa,
cu_seqlens_k_bsa,
head_groups,
None, # key_padding_mask
mask_trimmed,
q_len, k_len,
p_dropout=0.0,
deterministic=True,
is_causal=True,
)
# Update statistics (layer 0 only to avoid overcounting)
if layer_id == 0:
@@ -453,45 +456,46 @@ class XAttentionBSAPolicy(SparsePolicy):
block_size = ctx.block_size # tokens per CPU block (e.g., 1024)
reshaped_block_size = block_size // self.stride # e.g., 1024/8 = 128
for cpu_block_id in available_blocks:
# Load K block from CPU to GPU (cpu_block_id is chunk index)
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
offload_engine.wait_slot_layer(slot)
with nvtx.range("xattn_estimate_gemm"):
for cpu_block_id in available_blocks:
# Load K block from CPU to GPU (cpu_block_id is chunk index)
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
offload_engine.wait_slot_layer(slot)
# Get KV: [1, block_size, num_kv_heads, head_dim]
k_block, _ = offload_engine.get_kv_for_slot(slot)
# Get KV: [1, block_size, num_kv_heads, head_dim]
k_block, _ = offload_engine.get_kv_for_slot(slot)
# Convert K to [batch, heads, k_len, head_dim]
# k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim]
K_chunk = k_block.transpose(1, 2)
# Convert K to [batch, heads, k_len, head_dim]
# k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim]
K_chunk = k_block.transpose(1, 2)
# Handle GQA: expand K heads to match Q heads
num_kv_heads = K_chunk.shape[1]
if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
# Handle GQA: expand K heads to match Q heads
num_kv_heads = K_chunk.shape[1]
if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
# Pad K if necessary (k_len must be divisible by stride * BLOCK_N)
k_len = K_chunk.shape[2]
BLOCK_N = 128
k_alignment = self.stride * BLOCK_N
if k_len < k_alignment:
# K too short, pad it
pad_size = k_alignment - k_len
K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0)
# Pad K if necessary (k_len must be divisible by stride * BLOCK_N)
k_len = K_chunk.shape[2]
BLOCK_N = 128
k_alignment = self.stride * BLOCK_N
if k_len < k_alignment:
# K too short, pad it
pad_size = k_alignment - k_len
K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0)
# Compute attention scores using flat_group_gemm_fuse_reshape
# Output: [batch, heads, q_len/stride, k_len/stride]
attn_chunk = flat_group_gemm_fuse_reshape(
Q, K_chunk, self.stride,
chunk_start=0,
chunk_end=q_reshaped_len,
is_causal=False
)
attn_scores_list.append(attn_chunk)
# Compute attention scores using flat_group_gemm_fuse_reshape
# Output: [batch, heads, q_len/stride, k_len/stride]
attn_chunk = flat_group_gemm_fuse_reshape(
Q, K_chunk, self.stride,
chunk_start=0,
chunk_end=q_reshaped_len,
is_causal=False
)
attn_scores_list.append(attn_chunk)
# Mark slot as done for reuse
offload_engine.record_slot_compute_done(slot)
# Mark slot as done for reuse
offload_engine.record_slot_compute_done(slot)
# Concatenate all attention scores along K dimension
# Each chunk: [1, heads, q_reshaped_len, block_reshaped_len]
@@ -510,30 +514,32 @@ class XAttentionBSAPolicy(SparsePolicy):
scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm # log2(e) with scaling
segment_size = min(4096, reshaped_block_size)
block_sums = softmax_fuse_block_sum(
attn_scores,
reshaped_block_size, # Use CPU block size in reshaped space (1024/8=128)
segment_size,
chunk_start=0,
chunk_end=q_reshaped_len,
real_q_len=q_reshaped_len,
scale=scale,
is_causal=False, # Historical blocks are all before current chunk
)
with nvtx.range("xattn_estimate_softmax"):
block_sums = softmax_fuse_block_sum(
attn_scores,
reshaped_block_size, # Use CPU block size in reshaped space (1024/8=128)
segment_size,
chunk_start=0,
chunk_end=q_reshaped_len,
real_q_len=q_reshaped_len,
scale=scale,
is_causal=False, # Historical blocks are all before current chunk
)
# block_sums shape: [batch, heads, q_blocks, k_blocks]
# where k_blocks == len(available_blocks) (1:1 mapping with CPU blocks)
# Step 3: Use find_blocks_chunked to get selection mask
# current_index = 0 since we're looking at historical blocks only
mask = find_blocks_chunked(
block_sums,
current_index=0,
threshold=self.threshold,
num_to_choose=None,
decoding=False,
mode="prefill",
causal=False, # Historical blocks don't need causal mask
)
with nvtx.range("xattn_estimate_find_blocks"):
mask = find_blocks_chunked(
block_sums,
current_index=0,
threshold=self.threshold,
num_to_choose=None,
decoding=False,
mode="prefill",
causal=False, # Historical blocks don't need causal mask
)
# mask shape: [batch, num_heads, q_blocks, k_blocks] - boolean
# where k_blocks == len(available_blocks)
@@ -639,78 +645,81 @@ class XAttentionBSAPolicy(SparsePolicy):
cpu_block_table = selected_blocks
if cpu_block_table:
load_slots = list(range(offload_engine.num_ring_slots))
num_blocks = len(cpu_block_table)
with nvtx.range("xattn_compute_historical"):
load_slots = list(range(offload_engine.num_ring_slots))
num_blocks = len(cpu_block_table)
if len(load_slots) == 1:
# Only 1 slot - use synchronous mode
slot = load_slots[0]
for block_idx in range(num_blocks):
cpu_block_id = cpu_block_table[block_idx]
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
offload_engine.wait_slot_layer(slot)
if len(load_slots) == 1:
# Only 1 slot - use synchronous mode
slot = load_slots[0]
for block_idx in range(num_blocks):
cpu_block_id = cpu_block_table[block_idx]
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
offload_engine.wait_slot_layer(slot)
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=softmax_scale,
causal=False,
)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
offload_engine.record_slot_compute_done(slot)
else:
# Multiple slots - use pipeline
num_slots = len(load_slots)
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
cpu_block_id = cpu_block_table[i]
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=softmax_scale,
causal=False,
)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
offload_engine.record_slot_compute_done(slot)
else:
# Multiple slots - use pipeline
num_slots = len(load_slots)
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
cpu_block_id = cpu_block_table[i]
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
for block_idx in range(num_blocks):
current_slot = load_slots[block_idx % num_slots]
for block_idx in range(num_blocks):
current_slot = load_slots[block_idx % num_slots]
offload_engine.wait_slot_layer(current_slot)
offload_engine.wait_slot_layer(current_slot)
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=softmax_scale,
causal=False,
)
offload_engine.record_slot_compute_done(current_slot)
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=softmax_scale,
causal=False,
)
offload_engine.record_slot_compute_done(current_slot)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Issue next transfer
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
next_slot = load_slots[next_block_idx % num_slots]
next_cpu_block_id = cpu_block_table[next_block_idx]
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
# Issue next transfer
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
next_slot = load_slots[next_block_idx % num_slots]
next_cpu_block_id = cpu_block_table[next_block_idx]
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
# Compute attention to current chunk (causal mask)
with torch.cuda.stream(compute_stream):
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
current_o, current_lse = flash_attn_with_lse(
q_batched, k_curr, v_curr,
softmax_scale=softmax_scale,
causal=True,
)
with nvtx.range("xattn_compute_current"):
with torch.cuda.stream(compute_stream):
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
current_o, current_lse = flash_attn_with_lse(
q_batched, k_curr, v_curr,
softmax_scale=softmax_scale,
causal=True,
)
# Merge historical and current attention
with torch.cuda.stream(compute_stream):
if o_acc is None:
final_o = current_o
else:
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
with nvtx.range("xattn_compute_merge"):
with torch.cuda.stream(compute_stream):
if o_acc is None:
final_o = current_o
else:
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
# Sync default stream with compute_stream before returning
torch.cuda.default_stream().wait_stream(compute_stream)