🐛 fix: skip GQA buffer allocation in XAttention offload mode
In offload mode, the GQA expansion buffers (_k_expanded, _v_expanded) are not
needed: compute_chunked_prefill() handles GQA inline. Previously these buffers
were always allocated based on max_model_len, causing OOM on 24GB GPUs
(e.g., RTX 3090) when max_model_len=1M (a 16GB allocation).

Changes:
- Add an enable_cpu_offload parameter to alloc_policy_metadata() in the base class
- Skip GQA buffer allocation in XAttentionBSAPolicy when enable_cpu_offload=True
- Pass enable_cpu_offload from model_runner to the policy

Memory savings: ~16GB for a 1M-token sequence, ~1.1GB for 72K.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
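The ~16GB figure follows directly from the expanded-KV buffer shape. A back-of-the-envelope check (the head count and head_dim are assumptions, e.g. 32 heads × 128 dim = 4096 hidden width in fp16; the actual model config may differ, but these values reproduce both numbers in the commit message):

```python
def gqa_expansion_bytes(max_seq_len, num_heads=32, head_dim=128, dtype_bytes=2):
    """Combined size of _k_expanded + _v_expanded, each of shape
    [max_seq_len, num_heads, head_dim] in fp16 (assumed dims, not the real config)."""
    per_buffer = max_seq_len * num_heads * head_dim * dtype_bytes
    return 2 * per_buffer  # one buffer for K, one for V

GiB = 1024 ** 3
print(gqa_expansion_bytes(1 << 20) / GiB)   # 1M tokens -> 16.0, the ~16GB figure
print(gqa_expansion_bytes(72 * 1024) / GiB) # 72K tokens -> ~1.1, the ~1.1GB figure
```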
@@ -116,13 +116,15 @@ class SparsePolicy(ABC):
         max_seq_len: int,
         dtype: torch.dtype,
         device: torch.device,
+        enable_cpu_offload: bool = False,
     ) -> None:
         """
         Pre-allocate GPU buffers for policy computation.

-        Called by the framework after KV cache allocation, but ONLY for GPU-only
-        mode (not CPU offload mode). Override this to pre-allocate buffers that
-        would otherwise be dynamically allocated during forward pass.
+        Called by the framework after KV cache allocation. Implementations should
+        use enable_cpu_offload to decide which buffers to allocate:
+        - Offload mode: allocate chunked prefill buffers (mask, KV chunking stats)
+        - GPU-only mode: additionally allocate GQA expansion buffers

         This is separate from initialize() which is used for CPU offload metadata.

@@ -133,6 +135,7 @@ class SparsePolicy(ABC):
             max_seq_len: Maximum sequence length (for buffer sizing)
             dtype: Data type (typically float16/bfloat16)
             device: Target device (cuda)
+            enable_cpu_offload: Whether CPU offload is enabled
         """
         pass

@@ -175,6 +175,7 @@ class XAttentionBSAPolicy(SparsePolicy):
         max_seq_len: int,
         dtype: torch.dtype,
         device: torch.device,
+        enable_cpu_offload: bool = False,
     ) -> None:
         """
         Pre-allocate GQA expansion buffers for GPU-only mode.
@@ -235,7 +236,14 @@ class XAttentionBSAPolicy(SparsePolicy):
                     f"m/l shape={m_partial_shape} ({m_l_memory_mb:.1f} MB), "
                     f"block_sums shape={block_sums_shape} ({block_sums_memory_mb:.1f} MB)")

-        # Only allocate GQA expansion buffers if GQA (num_heads != num_kv_heads)
+        # Skip GQA buffers in offload mode
+        # Chunked prefill uses compute_chunked_prefill() which handles GQA inline
+        if enable_cpu_offload:
+            logger.info("[XAttn] Offload mode: skipping GQA expansion buffers (saves ~16GB for 1M seq)")
+            return
+
+        # GPU-only mode: pre-allocate GQA buffers for compute_prefill()
+        # Only allocate if GQA (num_heads != num_kv_heads)
         if num_heads == num_kv_heads:
            logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
             return
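Because the new parameter defaults to False, existing GPU-only callers keep working unchanged; only model_runner needs to forward its offload flag. A minimal torch-free stand-in (simplified names, not the actual class) of the allocate-or-skip branching this commit introduces:

```python
def alloc_gqa_buffers(num_heads, num_kv_heads, max_seq_len,
                      enable_cpu_offload=False):
    """Simplified stand-in for XAttentionBSAPolicy.alloc_policy_metadata():
    returns the expanded-KV buffer shape, or None when allocation is skipped."""
    if enable_cpu_offload:
        # offload path: compute_chunked_prefill() expands GQA inline per chunk
        return None
    if num_heads == num_kv_heads:
        # no GQA: query heads already map one-to-one onto KV heads
        return None
    # GPU-only GQA path: this is where _k_expanded/_v_expanded would be allocated
    return (max_seq_len, num_heads)

# Offload mode skips allocation even for GQA models:
assert alloc_gqa_buffers(32, 8, 1 << 20, enable_cpu_offload=True) is None
# MHA (num_heads == num_kv_heads) never needed the buffers:
assert alloc_gqa_buffers(32, 32, 1 << 20) is None
# GPU-only GQA still allocates:
assert alloc_gqa_buffers(32, 8, 1 << 20) == (1 << 20, 32)
```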