🐛 fix: skip GQA buffer allocation in XAttention offload mode

In offload mode, the GQA expansion buffers (_k_expanded, _v_expanded) are not
needed, since compute_chunked_prefill() handles GQA inline. Previously, these
buffers were always allocated based on max_model_len, causing OOM on 24GB GPUs
(e.g., RTX 3090) when max_model_len=1M, where the buffers alone take ~16GB.

Changes:
- Add an enable_cpu_offload parameter to alloc_policy_metadata() in the base class
- Skip GQA buffer allocation in XAttentionBSAPolicy when enable_cpu_offload=True
- Pass enable_cpu_offload from the model_runner to the policy

Memory savings: ~16GB for 1M seq, ~1.1GB for 72K seq
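The quoted savings can be sanity-checked with simple arithmetic. The commit does not state the model's head configuration, so the numbers below assume a hypothetical config (32 query heads, head_dim 128, fp16) that reproduces both figures:

```python
def gqa_buffer_gib(seq_len, num_heads=32, head_dim=128, dtype_bytes=2):
    """Size in GiB of the two GQA expansion buffers for a given sequence length.

    Assumes _k_expanded and _v_expanded each hold [seq_len, num_heads, head_dim]
    in fp16 (2 bytes); head counts are illustrative, not from the commit.
    """
    per_buffer = seq_len * num_heads * head_dim * dtype_bytes
    return 2 * per_buffer / (1 << 30)  # two buffers: K and V

print(gqa_buffer_gib(1 << 20))    # 1M tokens  -> 16.0 GiB
print(gqa_buffer_gib(72 * 1024))  # 72K tokens -> 1.125 GiB (~1.1GB)
```

Under these assumptions, a 1M-token buffer pair alone exceeds half of a 24GB card before weights and KV cache are accounted for, which matches the OOM described above.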

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Author: Zijie Tian
Date: 2026-02-05 02:57:18 +08:00
Commit: 11a867f6fb (parent af4da454ba)
4 changed files with 59 additions and 7 deletions

@@ -175,6 +175,7 @@ class XAttentionBSAPolicy(SparsePolicy):
         max_seq_len: int,
         dtype: torch.dtype,
         device: torch.device,
+        enable_cpu_offload: bool = False,
     ) -> None:
         """
         Pre-allocate GQA expansion buffers for GPU-only mode.
@@ -235,7 +236,14 @@ class XAttentionBSAPolicy(SparsePolicy):
                 f"m/l shape={m_partial_shape} ({m_l_memory_mb:.1f} MB), "
                 f"block_sums shape={block_sums_shape} ({block_sums_memory_mb:.1f} MB)")
         # Only allocate GQA expansion buffers if GQA (num_heads != num_kv_heads)
+        # Skip GQA buffers in offload mode
+        # Chunked prefill uses compute_chunked_prefill() which handles GQA inline
+        if enable_cpu_offload:
+            logger.info("[XAttn] Offload mode: skipping GQA expansion buffers (saves ~16GB for 1M seq)")
+            return
+        # GPU-only mode: pre-allocate GQA buffers for compute_prefill()
+        # Only allocate if GQA (num_heads != num_kv_heads)
         if num_heads == num_kv_heads:
             logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
             return
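Stripped of the diff context, the control flow this commit adds is a pair of early-return guards. A standalone sketch (simplified signature and return value; the real method also takes max_seq_len, dtype, and device, and allocates actual tensors):

```python
import logging

logger = logging.getLogger("xattn")

def alloc_policy_metadata(num_heads, num_kv_heads, enable_cpu_offload=False):
    """Sketch of the guard order: offload check first, then the GQA check."""
    if enable_cpu_offload:
        # compute_chunked_prefill() expands GQA heads inline, chunk by chunk,
        # so no persistent expansion buffers are needed in offload mode.
        logger.info("[XAttn] Offload mode: skipping GQA expansion buffers")
        return None
    if num_heads == num_kv_heads:
        # MHA case: no expansion needed in any mode.
        logger.info("[XAttn] No GQA expansion needed")
        return None
    # GPU-only GQA path: this is where _k_expanded/_v_expanded would be
    # allocated (stand-in dict here instead of real tensors).
    return {"k_expanded": True, "v_expanded": True}
```

The ordering matters: the offload check comes before the GQA check, so offload mode never allocates the buffers regardless of head configuration.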