🐛 fix: skip GQA buffer allocation in XAttention offload mode

In offload mode, the GQA expansion buffers (_k_expanded, _v_expanded) are not needed, since compute_chunked_prefill() handles GQA inline. Previously these buffers were always allocated based on max_model_len, causing OOM on 24GB GPUs (e.g., RTX 3090) when max_model_len=1M (a 16GB buffer).

Changes:
- Add an enable_cpu_offload parameter to alloc_policy_metadata() in the base class
- Skip GQA buffer allocation in XAttentionBSAPolicy when enable_cpu_offload=True
- Pass enable_cpu_offload from the model runner to the policy

Memory savings: ~16GB for a 1M seq, ~1.1GB for a 72K seq

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
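The quoted savings can be sanity-checked with a little arithmetic. The head counts below (32 query heads, head_dim 128, bf16 at 2 bytes/element) are illustrative assumptions, not values taken from the actual model config:

```python
def gqa_buffer_bytes(max_seq_len, num_heads=32, head_dim=128, dtype_size=2):
    """Approximate size of the GQA expansion buffers for one layer's policy.

    Two buffers (_k_expanded and _v_expanded), each holding K/V expanded
    from num_kv_heads up to the full num_heads. Defaults are assumptions.
    """
    return 2 * max_seq_len * num_heads * head_dim * dtype_size

print(gqa_buffer_bytes(1_000_000) / 1e9)  # ~16.4 GB for a 1M-token window
print(gqa_buffer_bytes(72_000) / 1e9)     # ~1.2 GB for a 72K-token window
```

Under these assumptions the numbers line up with the ~16GB and ~1.1GB figures in the message.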
@@ -227,9 +227,9 @@ class ModelRunner:
             device=torch.device("cuda"),
         )

-        # GPU-only mode: pre-allocate policy metadata buffers
-        # This avoids dynamic GPU memory allocation during forward pass
-        # if not config.enable_cpu_offload:
+        # Pre-allocate policy metadata buffers
+        # - Offload mode: allocate chunked prefill buffers (mask, KV chunking stats)
+        # - GPU-only mode: additionally allocate GQA expansion buffers
         num_heads = hf_config.num_attention_heads // self.world_size
         self.kvcache_manager.sparse_policy.alloc_policy_metadata(
             num_heads=num_heads,
@@ -238,6 +238,7 @@ class ModelRunner:
             max_seq_len=config.max_model_len,
             dtype=hf_config.torch_dtype,
             device=torch.device("cuda"),
+            enable_cpu_offload=config.enable_cpu_offload,
         )

         # Log policy info (handle both enum and None cases)
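On the policy side, the gating described by the commit might look like the sketch below. The class and method names are taken from the diff, but the body, the head_dim value, and the buffer shapes are assumptions for illustration only:

```python
import torch


class XAttentionBSAPolicy:
    """Sketch of the gated allocation; not the actual implementation."""

    def alloc_policy_metadata(self, num_heads, max_seq_len, dtype, device,
                              enable_cpu_offload=False):
        # Chunked-prefill buffers (mask, KV chunking stats) would be
        # allocated here in both modes; elided in this sketch.
        if enable_cpu_offload:
            # Offload mode: compute_chunked_prefill() handles GQA inline,
            # so the expanded K/V buffers are never read. Skip them to
            # save ~16GB at max_seq_len=1M.
            self._k_expanded = None
            self._v_expanded = None
            return
        # GPU-only mode: pre-expand K/V from num_kv_heads to num_heads.
        head_dim = 128  # assumption; would come from the model config
        shape = (num_heads, max_seq_len, head_dim)
        self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
        self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
```

This matches the call-site change in the diff, where the model runner forwards config.enable_cpu_offload into alloc_policy_metadata().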