🐛 fix: skip GQA buffer allocation in XAttention offload mode

In offload mode, GQA expansion buffers (_k_expanded, _v_expanded) are not
needed since compute_chunked_prefill() handles GQA inline. Previously,
these buffers were always allocated based on max_model_len, causing OOM
on 24GB GPUs (e.g., RTX 3090) when max_model_len=1M (16GB buffer).

Changes:
- Add enable_cpu_offload parameter to alloc_policy_metadata() in base class
- Skip GQA buffer allocation when enable_cpu_offload=True in XAttentionBSAPolicy
- Pass enable_cpu_offload from model_runner to policy

Memory savings: ~16GB for 1M seq, ~1.1GB for 72K seq
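The quoted savings can be sanity-checked with back-of-envelope arithmetic. The sketch below assumes a Llama-style attention config (32 query heads, head_dim=128, fp16 at 2 bytes/element) and K/V expansion buffers shaped [max_seq_len, num_heads, head_dim]; these shapes and constants are assumptions for illustration, not taken from the commit.

```python
# Rough size of the two GQA expansion buffers (_k_expanded, _v_expanded).
# Assumed config: 32 query heads, head_dim=128, fp16 (2 bytes/element);
# not taken from the commit, purely illustrative.
def gqa_expansion_bytes(max_seq_len, num_heads=32, head_dim=128, itemsize=2):
    # Two buffers (K and V), each [max_seq_len, num_heads, head_dim]
    return 2 * max_seq_len * num_heads * head_dim * itemsize

print(gqa_expansion_bytes(1_000_000) / 1e9)  # ~16.4 GB for a 1M context
print(gqa_expansion_bytes(72_000) / 1e9)     # ~1.2 GB for a 72K context
```

Under these assumptions the numbers line up with the commit message: roughly 16 GB at max_model_len=1M, and about 1.1-1.2 GB at 72K.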

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Author: Zijie Tian
Date:   2026-02-05 02:57:18 +08:00
parent  af4da454ba
commit  11a867f6fb
4 changed files with 59 additions and 7 deletions


@@ -227,9 +227,9 @@ class ModelRunner:
             device=torch.device("cuda"),
         )
-        # GPU-only mode: pre-allocate policy metadata buffers
-        # This avoids dynamic GPU memory allocation during forward pass
-        # if not config.enable_cpu_offload:
+        # Pre-allocate policy metadata buffers
+        # - Offload mode: allocate chunked prefill buffers (mask, KV chunking stats)
+        # - GPU-only mode: additionally allocate GQA expansion buffers
         num_heads = hf_config.num_attention_heads // self.world_size
         self.kvcache_manager.sparse_policy.alloc_policy_metadata(
             num_heads=num_heads,
@@ -238,6 +238,7 @@ class ModelRunner:
             max_seq_len=config.max_model_len,
             dtype=hf_config.torch_dtype,
             device=torch.device("cuda"),
+            enable_cpu_offload=config.enable_cpu_offload,
         )
         # Log policy info (handle both enum and None cases)
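The policy-side half of the change (not shown in this hunk) can be sketched as follows. The class and attribute names (XAttentionBSAPolicy, alloc_policy_metadata, _k_expanded, _v_expanded) come from the commit message; the actual buffer allocation is stubbed with plain tuples so the sketch runs without torch, and the extra parameters are illustrative guesses, not the real signature.

```python
# Minimal sketch of the XAttentionBSAPolicy change, assuming a simplified
# signature. Real buffers would be torch tensors; tuples stand in here so
# the example is self-contained.
class XAttentionBSAPolicy:
    def __init__(self):
        self._k_expanded = None   # GQA expansion buffers, GPU-only mode
        self._v_expanded = None
        self._prefill_mask = None # chunked-prefill metadata, both modes

    def alloc_policy_metadata(self, num_heads, num_kv_heads, max_seq_len,
                              enable_cpu_offload=False):
        # Chunked-prefill metadata is needed regardless of mode.
        self._prefill_mask = ("mask", num_heads, max_seq_len)
        if enable_cpu_offload:
            # Offload mode: compute_chunked_prefill() expands GQA inline,
            # so the large _k_expanded/_v_expanded buffers are skipped.
            return
        # GPU-only mode: pre-expand num_kv_heads -> num_heads for the
        # full context length (this is the ~16GB allocation at 1M tokens).
        self._k_expanded = ("buffer", num_heads, max_seq_len)
        self._v_expanded = ("buffer", num_heads, max_seq_len)


# Usage mirroring the model_runner call site above:
policy = XAttentionBSAPolicy()
policy.alloc_policy_metadata(num_heads=32, num_kv_heads=8,
                             max_seq_len=1_000_000,
                             enable_cpu_offload=True)
```

Defaulting enable_cpu_offload to False keeps the base-class signature backward compatible: existing GPU-only callers that never pass the flag still get the expansion buffers.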