️ perf: optimize XAttention estimate phase with K-only loading

Add load_k_only_to_slot_layer() to OffloadEngine for estimate phase:
- Only load K (not K+V) during block selection in select_blocks()
- Reduces H2D transfer by 50% in estimate phase
- 64K context: XAttn/Full ratio drops from 1.48x to 0.99x
- 32K context: XAttn/Full ratio drops from 1.67x to 1.20x

The estimate phase uses flat_group_gemm_fuse_reshape(Q, K) which
only requires K for attention score computation. V is unused.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-28 06:24:20 +08:00
parent a832d127b6
commit 3da9b8aef2
3 changed files with 102 additions and 7 deletions

View File

@@ -431,6 +431,62 @@ class OffloadEngine:
# Record H2D transfer: K + V = 2 * block_bytes
MemoryObserver.record_h2d(2 * self.gpu_block_bytes, is_prefill=is_prefill)
def load_k_only_to_slot_layer(
self, slot_idx: int, layer_id: int, cpu_block_id: int, chunk_idx: int = -1,
is_prefill: bool = True,
) -> None:
"""
Async load only K (not V) from CPU block to GPU slot.
Used by XAttention estimate phase which only needs K for attention score
computation. Saves 50% communication compared to loading K+V.
Args:
slot_idx: Target GPU slot index
layer_id: Layer index to load (for CPU cache indexing)
cpu_block_id: Source CPU block ID
chunk_idx: Optional chunk index for NVTX labeling (-1 means not specified)
is_prefill: True if in prefill phase, False if in decode phase
"""
logger.debug(f"Ring load K-only: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
stream = self.slot_transfer_streams[slot_idx]
if chunk_idx >= 0:
nvtx_label = f"H2D K-only: L{layer_id} Chunk{chunk_idx} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
else:
nvtx_label = f"H2D K-only: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
nvtx.push_range(message=nvtx_label, color="cyan")
with torch.cuda.stream(stream):
stream.wait_event(self.ring_slot_compute_done[slot_idx])
stream.wait_event(self.ring_slot_offload_done[slot_idx])
# Only copy K, not V
self.k_cache_gpu[slot_idx].copy_(
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.ring_slot_ready[slot_idx].record(stream)
nvtx.pop_range()
# Record H2D transfer: K only = 1 * block_bytes
MemoryObserver.record_h2d(self.gpu_block_bytes, is_prefill=is_prefill)
def get_k_for_slot(self, slot_idx: int) -> Tensor:
"""
Get only K for a ring buffer slot (no V).
Used by XAttention estimate phase which only needs K for attention
score computation.
Args:
slot_idx: GPU slot index
Returns:
k_cache, shape: [1, block_size, kv_heads, head_dim]
"""
return self.k_cache_gpu[slot_idx].unsqueeze(0)
def wait_slot_layer(self, slot_idx: int) -> None:
"""
Wait for a slot's loading to complete.