⚡️ perf: optimize XAttention estimate phase with K-only loading
Add load_k_only_to_slot_layer() to OffloadEngine for the estimate phase:

- Only load K (not K+V) during block selection in select_blocks()
- Reduces H2D transfer by 50% in the estimate phase
- 64K context: XAttn/Full ratio drops from 1.48x to 0.99x
- 32K context: XAttn/Full ratio drops from 1.67x to 1.20x

The estimate phase uses flat_group_gemm_fuse_reshape(Q, K), which only requires K for attention score computation; V is unused.

Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
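The core observation — block selection only needs Q·K logits, so V never has to leave host memory during estimation — can be illustrated with a toy NumPy sketch. All names here (`select_blocks`, shapes, the max-logit scoring rule) are hypothetical stand-ins, not the repository's actual implementation:

```python
import numpy as np

def select_blocks(q, k_blocks, top_k):
    """Score each CPU-resident K block against the query and keep the
    top_k highest-scoring blocks. Only K participates; V stays offloaded."""
    scores = []
    for k in k_blocks:
        # [q_len, d] @ [d, block_len] -> attention logits for this block
        logits = q @ k.T
        # Toy scoring rule: a block's score is its largest logit
        scores.append(logits.max())
    order = np.argsort(scores)[::-1][:top_k]
    return sorted(order.tolist())

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 64))                      # query tokens
k_blocks = [rng.standard_normal((16, 64)) for _ in range(8)]  # 8 K blocks
selected = select_blocks(q, k_blocks, top_k=3)
```

Only the blocks in `selected` would then need their V tensors transferred for the actual sparse attention pass.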
@@ -458,12 +458,13 @@ class XAttentionBSAPolicy(SparsePolicy):
         with nvtx.range("xattn_estimate_gemm"):
             for cpu_block_id in available_blocks:
-                # Load K block from CPU to GPU (cpu_block_id is chunk index)
-                offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
+                # Load only K from CPU to GPU (V not needed for estimate)
+                # This saves 50% communication in the estimate phase
+                offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
                 offload_engine.wait_slot_layer(slot)

-                # Get KV: [1, block_size, num_kv_heads, head_dim]
-                k_block, _ = offload_engine.get_kv_for_slot(slot)
+                # Get K only: [1, block_size, num_kv_heads, head_dim]
+                k_block = offload_engine.get_k_for_slot(slot)

                 # Convert K to [batch, heads, k_len, head_dim]
                 # k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim]
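The claimed 50% H2D saving follows directly from K and V blocks being the same size. A minimal host-only sketch (a hypothetical `ToyOffloadEngine`; the real OffloadEngine's internals are not shown in this commit) makes the accounting concrete:

```python
import numpy as np

class ToyOffloadEngine:
    """Toy model of CPU-offloaded KV blocks with one GPU staging slot.
    Byte counters stand in for actual host-to-device transfers."""

    def __init__(self, num_blocks, block_size, heads, dim):
        shape = (num_blocks, block_size, heads, dim)
        self.cpu_k = np.zeros(shape, dtype=np.float16)
        self.cpu_v = np.zeros(shape, dtype=np.float16)
        self.slot_k = np.zeros(shape[1:], dtype=np.float16)
        self.slot_v = np.zeros(shape[1:], dtype=np.float16)
        self.bytes_transferred = 0

    def load_to_slot(self, block_id):
        # Full path: both K and V are copied to the slot
        self.slot_k[...] = self.cpu_k[block_id]
        self.slot_v[...] = self.cpu_v[block_id]
        self.bytes_transferred += self.slot_k.nbytes + self.slot_v.nbytes

    def load_k_only_to_slot(self, block_id):
        # Estimate path: V is skipped, halving the transfer volume
        self.slot_k[...] = self.cpu_k[block_id]
        self.bytes_transferred += self.slot_k.nbytes

engine = ToyOffloadEngine(num_blocks=8, block_size=16, heads=4, dim=64)
engine.load_to_slot(0)
full_bytes = engine.bytes_transferred
engine.bytes_transferred = 0
engine.load_k_only_to_slot(0)
k_only_bytes = engine.bytes_transferred
```

Since K and V blocks have identical shape and dtype, `k_only_bytes` is exactly half of `full_bytes`, matching the 50% figure in the commit message.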