From 7b5d3b34eb7af013eeb9ed0ce4401630122b9f23 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Wed, 28 Jan 2026 00:57:11 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=88=20feat:=20add=20NVTX=20markers=20t?= =?UTF-8?q?o=20XAttention=20for=20profiling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add NVTX range markers to track XAttention performance: - GPU-only: xattn_estimate, xattn_bsa_compute - Offload: xattn_estimate_gemm, xattn_estimate_softmax, xattn_estimate_find_blocks, xattn_compute_historical, xattn_compute_current, xattn_compute_merge These markers enable detailed nsys profiling to identify performance bottlenecks in estimate vs compute phases. Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude Co-Authored-By: Happy --- nanovllm/kvcache/sparse/xattn_bsa.py | 275 ++++++++++++++------------- 1 file changed, 142 insertions(+), 133 deletions(-) diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py index 6ec026c..50c87a5 100644 --- a/nanovllm/kvcache/sparse/xattn_bsa.py +++ b/nanovllm/kvcache/sparse/xattn_bsa.py @@ -13,6 +13,7 @@ Note: Decode phase is not supported - use FullAttentionPolicy for decode. import logging import torch +import torch.cuda.nvtx as nvtx from typing import List, Tuple, TYPE_CHECKING from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext @@ -304,14 +305,15 @@ class XAttentionBSAPolicy(SparsePolicy): K_exp, V_exp = K, V # Estimate block importance and get sparse mask - _, mask = xattn_estimate( - Q, K_exp, - chunk_size=self.chunk_size, - block_size=self.BSA_BLOCK_SIZE, - threshold=self.threshold, - use_triton=self.use_triton, - causal=True, - ) + with nvtx.range("xattn_estimate"): + _, mask = xattn_estimate( + Q, K_exp, + chunk_size=self.chunk_size, + block_size=self.BSA_BLOCK_SIZE, + threshold=self.threshold, + use_triton=self.use_triton, + causal=True, + ) # Compute block counts q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE @@ -339,18 +341,19 @@ class XAttentionBSAPolicy(SparsePolicy): mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous() # Compute sparse attention using BSA - output = block_sparse_attn_func( - q_bsa, k_bsa, v_bsa, - cu_seqlens_q_bsa, - cu_seqlens_k_bsa, - head_groups, - None, # key_padding_mask - mask_trimmed, - q_len, k_len, - p_dropout=0.0, - deterministic=True, - is_causal=True, - ) + with nvtx.range("xattn_bsa_compute"): + output = block_sparse_attn_func( + q_bsa, k_bsa, v_bsa, + cu_seqlens_q_bsa, + cu_seqlens_k_bsa, + head_groups, + None, # key_padding_mask + mask_trimmed, + q_len, k_len, + p_dropout=0.0, + deterministic=True, + is_causal=True, + ) # Update statistics (layer 0 only to avoid overcounting) if layer_id == 0: @@ -453,45 +456,46 @@ class XAttentionBSAPolicy(SparsePolicy): block_size = ctx.block_size # tokens per CPU block (e.g., 1024) reshaped_block_size = block_size // self.stride # e.g., 1024/8 = 128 - for cpu_block_id in available_blocks: - # Load K block from CPU to GPU (cpu_block_id is chunk index) - offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id) - offload_engine.wait_slot_layer(slot) + with nvtx.range("xattn_estimate_gemm"): + for cpu_block_id in available_blocks: + # Load K block from CPU to GPU (cpu_block_id is chunk index) + offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id) + offload_engine.wait_slot_layer(slot) - # Get KV: [1, block_size, num_kv_heads, head_dim] - k_block, _ = offload_engine.get_kv_for_slot(slot) + # Get KV: [1, block_size, num_kv_heads, head_dim] + k_block, _ = offload_engine.get_kv_for_slot(slot) - # Convert K to [batch, heads, k_len, head_dim] - # k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim] - K_chunk = k_block.transpose(1, 2) + # Convert K to [batch, heads, k_len, head_dim] + # k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim] + K_chunk = k_block.transpose(1, 2) - # Handle GQA: expand K heads to match Q heads - num_kv_heads = K_chunk.shape[1] - if num_heads != num_kv_heads: - num_groups = num_heads // num_kv_heads - K_chunk = K_chunk.repeat_interleave(num_groups, dim=1) + # Handle GQA: expand K heads to match Q heads + num_kv_heads = K_chunk.shape[1] + if num_heads != num_kv_heads: + num_groups = num_heads // num_kv_heads + K_chunk = K_chunk.repeat_interleave(num_groups, dim=1) - # Pad K if necessary (k_len must be divisible by stride * BLOCK_N) - k_len = K_chunk.shape[2] - BLOCK_N = 128 - k_alignment = self.stride * BLOCK_N - if k_len < k_alignment: - # K too short, pad it - pad_size = k_alignment - k_len - K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0) + # Pad K if necessary (k_len must be divisible by stride * BLOCK_N) + k_len = K_chunk.shape[2] + BLOCK_N = 128 + k_alignment = self.stride * BLOCK_N + if k_len < k_alignment: + # K too short, pad it + pad_size = k_alignment - k_len + K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0) - # Compute attention scores using flat_group_gemm_fuse_reshape - # Output: [batch, heads, q_len/stride, k_len/stride] - attn_chunk = flat_group_gemm_fuse_reshape( - Q, K_chunk, self.stride, - chunk_start=0, - chunk_end=q_reshaped_len, - is_causal=False - ) - attn_scores_list.append(attn_chunk) + # Compute attention scores using flat_group_gemm_fuse_reshape + # Output: [batch, heads, q_len/stride, k_len/stride] + attn_chunk = flat_group_gemm_fuse_reshape( + Q, K_chunk, self.stride, + chunk_start=0, + chunk_end=q_reshaped_len, + is_causal=False + ) + attn_scores_list.append(attn_chunk) - # Mark slot as done for reuse - offload_engine.record_slot_compute_done(slot) + # Mark slot as done for reuse + offload_engine.record_slot_compute_done(slot) # Concatenate all attention scores along K dimension # Each chunk: [1, heads, q_reshaped_len, block_reshaped_len] @@ -510,30 +514,32 @@ class XAttentionBSAPolicy(SparsePolicy): scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm # log2(e) with scaling segment_size = min(4096, reshaped_block_size) - block_sums = softmax_fuse_block_sum( - attn_scores, - reshaped_block_size, # Use CPU block size in reshaped space (1024/8=128) - segment_size, - chunk_start=0, - chunk_end=q_reshaped_len, - real_q_len=q_reshaped_len, - scale=scale, - is_causal=False, # Historical blocks are all before current chunk - ) + with nvtx.range("xattn_estimate_softmax"): + block_sums = softmax_fuse_block_sum( + attn_scores, + reshaped_block_size, # Use CPU block size in reshaped space (1024/8=128) + segment_size, + chunk_start=0, + chunk_end=q_reshaped_len, + real_q_len=q_reshaped_len, + scale=scale, + is_causal=False, # Historical blocks are all before current chunk + ) # block_sums shape: [batch, heads, q_blocks, k_blocks] # where k_blocks == len(available_blocks) (1:1 mapping with CPU blocks) # Step 3: Use find_blocks_chunked to get selection mask # current_index = 0 since we're looking at historical blocks only - mask = find_blocks_chunked( - block_sums, - current_index=0, - threshold=self.threshold, - num_to_choose=None, - decoding=False, - mode="prefill", - causal=False, # Historical blocks don't need causal mask - ) + with nvtx.range("xattn_estimate_find_blocks"): + mask = find_blocks_chunked( + block_sums, + current_index=0, + threshold=self.threshold, + num_to_choose=None, + decoding=False, + mode="prefill", + causal=False, # Historical blocks don't need causal mask + ) # mask shape: [batch, num_heads, q_blocks, k_blocks] - boolean # where k_blocks == len(available_blocks) @@ -639,78 +645,81 @@ class XAttentionBSAPolicy(SparsePolicy): cpu_block_table = selected_blocks if cpu_block_table: - load_slots = list(range(offload_engine.num_ring_slots)) - num_blocks = len(cpu_block_table) + with nvtx.range("xattn_compute_historical"): + load_slots = list(range(offload_engine.num_ring_slots)) + num_blocks = len(cpu_block_table) - if len(load_slots) == 1: - # Only 1 slot - use synchronous mode - slot = load_slots[0] - for block_idx in range(num_blocks): - cpu_block_id = cpu_block_table[block_idx] - offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id) - offload_engine.wait_slot_layer(slot) + if len(load_slots) == 1: + # Only 1 slot - use synchronous mode + slot = load_slots[0] + for block_idx in range(num_blocks): + cpu_block_id = cpu_block_table[block_idx] + offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id) + offload_engine.wait_slot_layer(slot) - with torch.cuda.stream(compute_stream): - prev_k, prev_v = offload_engine.get_kv_for_slot(slot) - prev_o, prev_lse = flash_attn_with_lse( - q_batched, prev_k, prev_v, - softmax_scale=softmax_scale, - causal=False, - ) - if o_acc is None: - o_acc, lse_acc = prev_o, prev_lse - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) - offload_engine.record_slot_compute_done(slot) - else: - # Multiple slots - use pipeline - num_slots = len(load_slots) - num_preload = min(num_slots, num_blocks) - for i in range(num_preload): - cpu_block_id = cpu_block_table[i] - offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id) + with torch.cuda.stream(compute_stream): + prev_k, prev_v = offload_engine.get_kv_for_slot(slot) + prev_o, prev_lse = flash_attn_with_lse( + q_batched, prev_k, prev_v, + softmax_scale=softmax_scale, + causal=False, + ) + if o_acc is None: + o_acc, lse_acc = prev_o, prev_lse + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) + offload_engine.record_slot_compute_done(slot) + else: + # Multiple slots - use pipeline + num_slots = len(load_slots) + num_preload = min(num_slots, num_blocks) + for i in range(num_preload): + cpu_block_id = cpu_block_table[i] + offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id) - for block_idx in range(num_blocks): - current_slot = load_slots[block_idx % num_slots] + for block_idx in range(num_blocks): + current_slot = load_slots[block_idx % num_slots] - offload_engine.wait_slot_layer(current_slot) + offload_engine.wait_slot_layer(current_slot) - with torch.cuda.stream(compute_stream): - prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot) - prev_o, prev_lse = flash_attn_with_lse( - q_batched, prev_k, prev_v, - softmax_scale=softmax_scale, - causal=False, - ) - offload_engine.record_slot_compute_done(current_slot) + with torch.cuda.stream(compute_stream): + prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot) + prev_o, prev_lse = flash_attn_with_lse( + q_batched, prev_k, prev_v, + softmax_scale=softmax_scale, + causal=False, + ) + offload_engine.record_slot_compute_done(current_slot) - if o_acc is None: - o_acc, lse_acc = prev_o, prev_lse - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) + if o_acc is None: + o_acc, lse_acc = prev_o, prev_lse + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) - # Issue next transfer - next_block_idx = block_idx + num_slots - if next_block_idx < num_blocks: - next_slot = load_slots[next_block_idx % num_slots] - next_cpu_block_id = cpu_block_table[next_block_idx] - offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id) + # Issue next transfer + next_block_idx = block_idx + num_slots + if next_block_idx < num_blocks: + next_slot = load_slots[next_block_idx % num_slots] + next_cpu_block_id = cpu_block_table[next_block_idx] + offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id) # Compute attention to current chunk (causal mask) - with torch.cuda.stream(compute_stream): - k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens) - current_o, current_lse = flash_attn_with_lse( - q_batched, k_curr, v_curr, - softmax_scale=softmax_scale, - causal=True, - ) + with nvtx.range("xattn_compute_current"): + with torch.cuda.stream(compute_stream): + k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens) + current_o, current_lse = flash_attn_with_lse( + q_batched, k_curr, v_curr, + softmax_scale=softmax_scale, + causal=True, + ) # Merge historical and current attention - with torch.cuda.stream(compute_stream): - if o_acc is None: - final_o = current_o - else: - final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) + with nvtx.range("xattn_compute_merge"): + with torch.cuda.stream(compute_stream): + if o_acc is None: + final_o = current_o + else: + final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) # Sync default stream with compute_stream before returning torch.cuda.default_stream().wait_stream(compute_stream)