⚡️ perf: optimize XAttention estimate phase with K-only loading
Add load_k_only_to_slot_layer() to OffloadEngine for the estimate phase:

- Only load K (not K+V) during block selection in select_blocks()
- Reduces H2D transfer by 50% in the estimate phase
- 64K context: XAttn/Full ratio drops from 1.48x to 0.99x
- 32K context: XAttn/Full ratio drops from 1.67x to 1.20x

The estimate phase uses flat_group_gemm_fuse_reshape(Q, K), which only
requires K for attention score computation; V is unused.

Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
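The savings claimed above follow from a simple cost model: the estimate phase touches every history block once, the compute phase re-loads the selected fraction with K+V, and loading K only halves the estimate-phase cost. A back-of-the-envelope sketch (hypothetical helper, not part of the codebase; `selection_density` stands for the 70-80% selected-block fraction mentioned in the report):

```python
def prefill_h2d_ratio(selection_density: float, k_only_estimate: bool) -> float:
    """XAttn/Full prefill H2D ratio, in units of one K+V block transfer.

    Full policy loads each block's K+V exactly once, so its cost is 1.0
    per block. XAttention pays an estimate pass over all blocks plus a
    compute pass over the selected fraction.
    """
    full = 1.0
    # Estimate phase: K alone is half the bytes of a K+V block.
    estimate = 0.5 if k_only_estimate else 1.0
    # Compute phase: only the selected blocks are loaded (with K+V).
    compute = selection_density
    return (estimate + compute) / full

# Before: 1 + density (matches the report's "theoretical ratio").
before = prefill_h2d_ratio(0.75, k_only_estimate=False)  # 1.75
# After: 0.5 + density, close to the measured 1.20x at 32K.
after = prefill_h2d_ratio(0.75, k_only_estimate=True)    # 1.25
```

The measured 64K ratio (0.99x) lands below even this model, presumably because the selection density drops as context grows.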
@@ -34,17 +34,17 @@ GPU-CPU communication volume test results, comparing Full Policy and the XAttention BSA Policy.
 | Decode H2D (32 tokens) | 262.13 GB | 262.13 GB | 1.00x |
 | TTFT | 27081 ms | 33634 ms | 1.24x |
 
-## Communication Ratio Comparison
+## Communication Ratio Comparison (before K-only optimization)
 
 | Context length | XAttn/Full Prefill H2D ratio |
 |------------|----------------------------|
 | 32K | 1.67x |
 | 64K | 1.48x |
 
-### Analysis
+### Analysis (before optimization)
 
 1. **Why XAttention's communication volume increases**:
-   - Estimate phase: loads **100%** of history blocks (for attention score estimation)
+   - Estimate phase: loads **K+V** for **100%** of history blocks (for attention score estimation)
    - Compute phase: loads the **selected** blocks (about 70-80%)
    - Theoretical ratio: `1 + selection_density`
@@ -57,6 +57,44 @@ GPU-CPU communication volume test results, comparing Full Policy and the XAttention BSA Policy.
 - XAttention only supports the prefill phase
 - The decode phase falls back to Full Policy
 
+---
+
+## K-only Optimization (2026-01-28)
+
+### Rationale
+
+The estimate phase of XAttention's `select_blocks` only needs K to compute attention scores:
+
+```python
+# flat_group_gemm_fuse_reshape uses only Q and K
+attn_scores = flat_group_gemm_fuse_reshape(Q, K_chunk, stride, ...)
+```
+
+V is never used in the estimate phase, but the previous code loaded both K and V, wasting 50% of the communication volume.
+
+### Implementation
+
+1. **New method**: `OffloadEngine.load_k_only_to_slot_layer()` - loads K only
+2. **Modified select_blocks**: switched to the K-only loading method
+
+### Test Results After Optimization
+
+| Context | Full Policy | XAttn (before) | XAttn (after) | Savings |
+|--------|-------------|---------------|---------------|---------|
+| 32K | 66.57 GB | 111.12 GB | **79.76 GB** | **28.2%** |
+| 64K | 262.13 GB | 386.62 GB | **258.78 GB** | **33.1%** |
+
+### XAttn/Full Ratio Change
+
+| Context | Ratio before | Ratio after |
+|--------|-----------|-----------|
+| 32K | 1.67x | **1.20x** |
+| 64K | 1.48x | **0.99x** |
+
+### Conclusion
+
+After the optimization, XAttention's communication volume at 64K context is essentially on par with Full Policy (0.99x),
+and 32K drops from 1.67x to 1.20x. This shows that the K-only optimization of the estimate phase is highly effective.
 
 ## Test Commands
 
 ```bash
@@ -431,6 +431,62 @@ class OffloadEngine:
         # Record H2D transfer: K + V = 2 * block_bytes
         MemoryObserver.record_h2d(2 * self.gpu_block_bytes, is_prefill=is_prefill)
 
+    def load_k_only_to_slot_layer(
+        self, slot_idx: int, layer_id: int, cpu_block_id: int, chunk_idx: int = -1,
+        is_prefill: bool = True,
+    ) -> None:
+        """
+        Async load only K (not V) from a CPU block to a GPU slot.
+
+        Used by the XAttention estimate phase, which only needs K for attention
+        score computation. Saves 50% communication compared to loading K+V.
+
+        Args:
+            slot_idx: Target GPU slot index
+            layer_id: Layer index to load (for CPU cache indexing)
+            cpu_block_id: Source CPU block ID
+            chunk_idx: Optional chunk index for NVTX labeling (-1 means not specified)
+            is_prefill: True if in prefill phase, False if in decode phase
+        """
+        logger.debug(f"Ring load K-only: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
+
+        stream = self.slot_transfer_streams[slot_idx]
+
+        if chunk_idx >= 0:
+            nvtx_label = f"H2D K-only: L{layer_id} Chunk{chunk_idx} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
+        else:
+            nvtx_label = f"H2D K-only: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
+
+        nvtx.push_range(message=nvtx_label, color="cyan")
+        with torch.cuda.stream(stream):
+            stream.wait_event(self.ring_slot_compute_done[slot_idx])
+            stream.wait_event(self.ring_slot_offload_done[slot_idx])
+
+            # Only copy K, not V
+            self.k_cache_gpu[slot_idx].copy_(
+                self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
+            )
+            self.ring_slot_ready[slot_idx].record(stream)
+        nvtx.pop_range()
+
+        # Record H2D transfer: K only = 1 * block_bytes
+        MemoryObserver.record_h2d(self.gpu_block_bytes, is_prefill=is_prefill)
+
+    def get_k_for_slot(self, slot_idx: int) -> Tensor:
+        """
+        Get only K for a ring buffer slot (no V).
+
+        Used by the XAttention estimate phase, which only needs K for attention
+        score computation.
+
+        Args:
+            slot_idx: GPU slot index
+
+        Returns:
+            k_cache, shape: [1, block_size, kv_heads, head_dim]
+        """
+        return self.k_cache_gpu[slot_idx].unsqueeze(0)
+
     def wait_slot_layer(self, slot_idx: int) -> None:
         """
         Wait for a slot's loading to complete.
@@ -458,12 +458,13 @@ class XAttentionBSAPolicy(SparsePolicy):
 
         with nvtx.range("xattn_estimate_gemm"):
             for cpu_block_id in available_blocks:
-                # Load K block from CPU to GPU (cpu_block_id is chunk index)
-                offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
+                # Load only K from CPU to GPU (V not needed for estimate)
+                # This saves 50% communication in the estimate phase
+                offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
                 offload_engine.wait_slot_layer(slot)
 
-                # Get KV: [1, block_size, num_kv_heads, head_dim]
-                k_block, _ = offload_engine.get_kv_for_slot(slot)
+                # Get K only: [1, block_size, num_kv_heads, head_dim]
+                k_block = offload_engine.get_k_for_slot(slot)
 
                 # Convert K to [batch, heads, k_len, head_dim]
                 # k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim]