⚡ perf: pre-allocate GQA buffers in XAttention policy
Add alloc_policy_metadata() method to SparsePolicy base class for pre-allocating GPU buffers during initialization. This avoids dynamic memory allocation during the forward pass.

Changes:
- Add alloc_policy_metadata() to SparsePolicy base class
- Implement GQA buffer pre-allocation in XAttentionBSAPolicy
- Call alloc_policy_metadata() in model_runner for GPU-only mode
- Modify compute_prefill() to reuse pre-allocated buffers
- Add --gpu-util parameter to bench.py

Memory savings:
- Previously: 2x GQA expansion (~2GB for 64K)
- Now: 1x pre-allocated buffer (~1GB for 64K, reused)

Tested:
- GPU-only 32K: 5602 tok/s (512MB pre-allocated)
- GPU-only 64K: 4821 tok/s (1GB pre-allocated, gpu_util=0.7)
- Offload Full: PASSED (no changes to offload path)
- Offload XAttention: PASSED (uses compute_chunked_prefill)

Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
3
bench.py
3
bench.py
@@ -56,6 +56,8 @@ def main():
|
||||
help="Sparse policy: full (FullAttention), xattn (XAttention+BSA)")
|
||||
parser.add_argument("--enable-policy", action="store_true",
|
||||
help="Enable sparse policy routing (FullAttentionPolicy by default)")
|
||||
parser.add_argument("--gpu-util", type=float, default=0.9,
|
||||
help="GPU memory utilization (default: 0.9)")
|
||||
args = parser.parse_args()
|
||||
|
||||
path = os.path.expanduser(args.model)
|
||||
@@ -78,6 +80,7 @@ def main():
|
||||
max_model_len=max_len,
|
||||
max_num_batched_tokens=max_len,
|
||||
sparse_policy=sparse_policy,
|
||||
gpu_memory_utilization=args.gpu_util,
|
||||
)
|
||||
|
||||
# Warmup
|
||||
|
||||
@@ -208,6 +208,19 @@ class ModelRunner:
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
# GPU-only mode: pre-allocate policy metadata buffers
|
||||
# This avoids dynamic GPU memory allocation during forward pass
|
||||
if not config.enable_cpu_offload:
|
||||
num_heads = hf_config.num_attention_heads // self.world_size
|
||||
self.kvcache_manager.sparse_policy.alloc_policy_metadata(
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
max_seq_len=config.max_model_len,
|
||||
dtype=hf_config.torch_dtype,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
# Log policy info (handle both enum and None cases)
|
||||
policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL"
|
||||
logger.info(
|
||||
|
||||
@@ -108,6 +108,34 @@ class SparsePolicy(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def alloc_policy_metadata(
    self,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    max_seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
) -> None:
    """
    Optional hook: reserve GPU-side buffers before serving starts.

    The framework invokes this once, right after the KV cache has been
    allocated, and only in GPU-only mode (never in CPU-offload mode,
    which goes through initialize() instead). Subclasses that would
    otherwise allocate scratch tensors on every forward pass can
    override this to grab those buffers up front; the base
    implementation deliberately does nothing.

    Args:
        num_heads: Number of query heads.
        num_kv_heads: Number of KV heads (differs from num_heads under GQA).
        head_dim: Per-head hidden dimension.
        max_seq_len: Upper bound on sequence length, used to size buffers.
        dtype: Element type of the buffers (typically float16/bfloat16).
        device: Device to place the buffers on (cuda).
    """
    return None
|
||||
|
||||
@abstractmethod
|
||||
def select_blocks(
|
||||
self,
|
||||
|
||||
@@ -122,6 +122,54 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
self._stats_total_selected_blocks = 0
|
||||
self._stats_num_chunks = 0
|
||||
|
||||
# Pre-allocated GQA expansion buffers (GPU-only mode)
|
||||
# Set by alloc_policy_metadata(), None if not pre-allocated
|
||||
self._k_expanded: torch.Tensor | None = None
|
||||
self._v_expanded: torch.Tensor | None = None
|
||||
self._max_seq_len: int = 0
|
||||
|
||||
def alloc_policy_metadata(
    self,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    max_seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
) -> None:
    """
    Reserve the GQA key/value expansion buffers for the GPU-only path.

    compute_prefill() later slices these buffers down to the live
    sequence length rather than allocating fresh tensors on every
    forward pass. The pair of buffers costs
    2 * num_heads * max_seq_len * head_dim * dtype_size bytes —
    e.g. 64K tokens, 32 heads, dim 128, fp16: 2 * 32 * 65536 * 128 * 2 = 1 GB.

    Args:
        num_heads: Number of query heads.
        num_kv_heads: Number of KV heads (for GQA).
        head_dim: Per-head hidden dimension.
        max_seq_len: Maximum sequence length the buffers must cover.
        dtype: Element type of the buffers.
        device: Device on which to allocate.
    """
    # MHA case: query and KV head counts already match, so there is
    # nothing to expand and no buffer to reserve.
    if num_heads == num_kv_heads:
        logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
        return

    # Layout [1, num_heads, max_seq_len, head_dim] is what xattn_estimate
    # consumes directly; the BSA path re-views the same storage as
    # [seq_len, num_heads, head_dim].
    shape = (1, num_heads, max_seq_len, head_dim)
    memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024)

    self._k_expanded, self._v_expanded = (
        torch.empty(shape, dtype=dtype, device=device) for _ in range(2)
    )
    self._max_seq_len = max_seq_len

    logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB")
|
||||
|
||||
# =========================================================================
|
||||
# GPU-only methods (non-chunked)
|
||||
# =========================================================================
|
||||
@@ -234,8 +282,26 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
K = k.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, k_len, head_dim]
|
||||
V = v.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, k_len, head_dim]
|
||||
|
||||
# Expand KV for GQA (xattn_estimate requires matching heads)
|
||||
K_exp, V_exp = expand_kv_for_gqa(K, V, num_heads)
|
||||
# Expand KV for GQA - use pre-allocated buffers if available
|
||||
if num_heads != num_kv_heads:
|
||||
num_groups = num_heads // num_kv_heads
|
||||
if self._k_expanded is not None and k_len <= self._max_seq_len:
|
||||
# Use pre-allocated buffers with in-place expansion
|
||||
K_exp = self._k_expanded[:, :, :k_len, :]
|
||||
V_exp = self._v_expanded[:, :, :k_len, :]
|
||||
# In-place GQA expansion: [1, num_kv_heads, k_len, head_dim] -> [1, num_heads, k_len, head_dim]
|
||||
# Reshape K to [1, num_kv_heads, 1, k_len, head_dim] and broadcast to [1, num_kv_heads, num_groups, k_len, head_dim]
|
||||
K_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_(
|
||||
K.unsqueeze(2).expand(-1, -1, num_groups, -1, -1)
|
||||
)
|
||||
V_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_(
|
||||
V.unsqueeze(2).expand(-1, -1, num_groups, -1, -1)
|
||||
)
|
||||
else:
|
||||
# Fallback: dynamic allocation (when buffers not pre-allocated or seq too long)
|
||||
K_exp, V_exp = expand_kv_for_gqa(K, V, num_heads)
|
||||
else:
|
||||
K_exp, V_exp = K, V
|
||||
|
||||
# Estimate block importance and get sparse mask
|
||||
_, mask = xattn_estimate(
|
||||
@@ -255,12 +321,11 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
# q, k, v need to be [seq_len, num_heads, head_dim]
|
||||
q_bsa = q # Already [q_len, num_heads, head_dim]
|
||||
|
||||
# For GQA with BSA, we need to expand k, v to match num_heads
|
||||
# k, v: [k_len, num_kv_heads, head_dim] -> [k_len, num_heads, head_dim]
|
||||
# For GQA with BSA, reuse the expanded K_exp, V_exp (convert to BSA format)
|
||||
# K_exp: [1, num_heads, k_len, head_dim] -> [k_len, num_heads, head_dim]
|
||||
if num_heads != num_kv_heads:
|
||||
num_groups = num_heads // num_kv_heads
|
||||
k_bsa = k.repeat_interleave(num_groups, dim=1)
|
||||
v_bsa = v.repeat_interleave(num_groups, dim=1)
|
||||
k_bsa = K_exp.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
||||
v_bsa = V_exp.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
||||
else:
|
||||
k_bsa = k
|
||||
v_bsa = v
|
||||
|
||||
Reference in New Issue
Block a user