perf: optimize XAttention estimate phase with K-only loading

Add load_k_only_to_slot_layer() to OffloadEngine for estimate phase:
- Only load K (not K+V) during block selection in select_blocks()
- Reduces H2D transfer by 50% in estimate phase
- 64K context: XAttn/Full ratio drops from 1.48x to 0.99x
- 32K context: XAttn/Full ratio drops from 1.67x to 1.20x

The estimate phase calls flat_group_gemm_fuse_reshape(Q, K), which
needs only K to compute attention scores; V is never read there.
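Since block selection scores candidate blocks against Q alone, the estimate phase can be sketched roughly as follows. This is a toy NumPy stand-in, not the repository's real kernel: the function name, shapes, and peak-logit scoring heuristic are illustrative assumptions.

```python
import numpy as np

def select_blocks_estimate(q, k_blocks, top_k=2):
    """Toy block-selection pass that ranks KV blocks using K alone.

    Mirrors the idea behind the estimate phase: each block's score is a
    function of Q @ K^T only, so V can stay in host memory. Names,
    shapes, and the scoring heuristic are illustrative, not the real API.
    """
    # q: [q_len, head_dim]; each k: [block_size, head_dim]
    scores = [(q @ k.T).max() for k in k_blocks]  # peak logit as a crude block score
    order = np.argsort(scores)[::-1]              # highest-scoring blocks first
    return sorted(int(i) for i in order[:top_k])
```

Because V never appears in the scoring path, skipping its transfer changes nothing about which blocks get selected.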

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
Author: Zijie Tian
Date: 2026-01-28 06:24:20 +08:00
parent a832d127b6
commit 3da9b8aef2
3 changed files with 102 additions and 7 deletions


@@ -458,12 +458,13 @@ class XAttentionBSAPolicy(SparsePolicy):
         with nvtx.range("xattn_estimate_gemm"):
             for cpu_block_id in available_blocks:
-                # Load K block from CPU to GPU (cpu_block_id is chunk index)
-                offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
+                # Load only K from CPU to GPU (V not needed for estimate)
+                # This saves 50% communication in the estimate phase
+                offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
                 offload_engine.wait_slot_layer(slot)
-                # Get KV: [1, block_size, num_kv_heads, head_dim]
-                k_block, _ = offload_engine.get_kv_for_slot(slot)
+                # Get K only: [1, block_size, num_kv_heads, head_dim]
+                k_block = offload_engine.get_k_for_slot(slot)
                 # Convert K to [batch, heads, k_len, head_dim]
                 # k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim]
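To see where the 50% figure comes from, here is a minimal transfer-accounting model of the two load paths. It is a pure-Python stand-in, assuming K and V blocks of equal size; the real OffloadEngine copies pinned host tensors into GPU slots, and everything here beyond the method names visible in the diff is a hypothetical simplification.

```python
import numpy as np

class ToyOffloadEngine:
    """Stand-in for OffloadEngine that only counts bytes 'transferred'.

    The real engine copies pinned host tensors to GPU slots; here a slot
    is a dict entry and the copy is np.copy, so only the H2D byte
    accounting matters. Names and structure are illustrative assumptions.
    """
    def __init__(self, k_cache, v_cache):
        self.k_cache, self.v_cache = k_cache, v_cache  # host-side block lists
        self.slots = {}
        self.h2d_bytes = 0

    def load_to_slot(self, slot, block_id):
        # Full K+V load, as the estimate phase did before this commit.
        k, v = self.k_cache[block_id], self.v_cache[block_id]
        self.slots[slot] = (k.copy(), v.copy())
        self.h2d_bytes += k.nbytes + v.nbytes

    def load_k_only_to_slot(self, slot, block_id):
        # K-only load: V never crosses the bus during block selection.
        k = self.k_cache[block_id]
        self.slots[slot] = (k.copy(), None)
        self.h2d_bytes += k.nbytes
```

With equally sized K and V blocks, the K-only path moves exactly half the bytes per block, which is the 50% H2D saving the commit message cites for the estimate phase.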