perf: pre-allocate GQA buffers in XAttention policy

Add alloc_policy_metadata() method to SparsePolicy base class for
pre-allocating GPU buffers during initialization. This avoids
dynamic memory allocation during forward pass.

Changes:
- Add alloc_policy_metadata() to SparsePolicy base class
- Implement GQA buffer pre-allocation in XAttentionBSAPolicy
- Call alloc_policy_metadata() in model_runner for GPU-only mode
- Modify compute_prefill() to reuse pre-allocated buffers
- Add --gpu-util parameter to bench.py

Memory savings:
- Previously: K and V each dynamically allocated an expanded GQA copy every forward pass (~2 GB transient at 64K context)
- Now: one pre-allocated K/V buffer pair (~1 GB at 64K context), reused across forward passes

Tested:
- GPU-only 32K: 5602 tok/s (512MB pre-allocated)
- GPU-only 64K: 4821 tok/s (1GB pre-allocated, gpu_util=0.7)
- Offload Full: PASSED (no changes to offload path)
- Offload XAttention: PASSED (uses compute_chunked_prefill)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-27 05:49:23 +08:00
parent 076656c9c2
commit a504bd873d
4 changed files with 116 additions and 7 deletions

View File

@@ -56,6 +56,8 @@ def main():
help="Sparse policy: full (FullAttention), xattn (XAttention+BSA)")
parser.add_argument("--enable-policy", action="store_true",
help="Enable sparse policy routing (FullAttentionPolicy by default)")
parser.add_argument("--gpu-util", type=float, default=0.9,
help="GPU memory utilization (default: 0.9)")
args = parser.parse_args()
path = os.path.expanduser(args.model)
@@ -78,6 +80,7 @@ def main():
max_model_len=max_len,
max_num_batched_tokens=max_len,
sparse_policy=sparse_policy,
gpu_memory_utilization=args.gpu_util,
)
# Warmup

View File

@@ -208,6 +208,19 @@ class ModelRunner:
device=torch.device("cuda"),
)
# GPU-only mode: pre-allocate policy metadata buffers
# This avoids dynamic GPU memory allocation during forward pass
if not config.enable_cpu_offload:
num_heads = hf_config.num_attention_heads // self.world_size
self.kvcache_manager.sparse_policy.alloc_policy_metadata(
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
max_seq_len=config.max_model_len,
dtype=hf_config.torch_dtype,
device=torch.device("cuda"),
)
# Log policy info (handle both enum and None cases)
policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL"
logger.info(

View File

@@ -108,6 +108,34 @@ class SparsePolicy(ABC):
"""
pass
def alloc_policy_metadata(
    self,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    max_seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
) -> None:
    """Hook for pre-allocating GPU buffers used during policy computation.

    The framework invokes this once after the KV cache has been allocated,
    and only in GPU-only mode (never in CPU-offload mode). Subclasses that
    would otherwise allocate scratch memory during every forward pass
    should override this and allocate those buffers up front here.

    This is distinct from initialize(), which handles CPU-offload metadata.

    Args:
        num_heads: Number of query heads.
        num_kv_heads: Number of KV heads (for GQA).
        head_dim: Embedding dimension per head.
        max_seq_len: Maximum sequence length, used to size buffers.
        dtype: Buffer data type (typically float16/bfloat16).
        device: Target device (cuda).
    """
    # Default is a no-op: policies without scratch buffers need nothing.
@abstractmethod
def select_blocks(
self,

View File

@@ -122,6 +122,54 @@ class XAttentionBSAPolicy(SparsePolicy):
self._stats_total_selected_blocks = 0
self._stats_num_chunks = 0
# Pre-allocated GQA expansion buffers (GPU-only mode)
# Set by alloc_policy_metadata(), None if not pre-allocated
self._k_expanded: torch.Tensor | None = None
self._v_expanded: torch.Tensor | None = None
self._max_seq_len: int = 0
def alloc_policy_metadata(
    self,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    max_seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
) -> None:
    """
    Pre-allocate GQA expansion buffers for GPU-only mode.

    These buffers are reused by compute_prefill() to avoid dynamically
    allocating the expanded K/V tensors during the forward pass. They are
    sized for max_seq_len and sliced down to the actual seq_len at use.

    Memory usage: 2 * num_heads * max_seq_len * head_dim * dtype_size
    For 64K seq, 32 heads, 128 dim, fp16: 2 * 32 * 65536 * 128 * 2 = 1 GB

    Args:
        num_heads: Number of query heads
        num_kv_heads: Number of KV heads (for GQA)
        head_dim: Dimension per head
        max_seq_len: Maximum sequence length
        dtype: Data type
        device: Target device

    Raises:
        ValueError: If num_heads is not a positive multiple of num_kv_heads.
    """
    # MHA case: K/V already match the query-head count, nothing to expand.
    if num_heads == num_kv_heads:
        logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
        return
    # Validate the GQA configuration before committing ~GBs of GPU memory;
    # compute_prefill() later relies on num_heads // num_kv_heads dividing evenly.
    if num_kv_heads <= 0 or num_heads % num_kv_heads != 0:
        raise ValueError(
            f"num_heads ({num_heads}) must be a positive multiple of "
            f"num_kv_heads ({num_kv_heads}) for GQA expansion"
        )
    # Shape: [1, num_heads, max_seq_len, head_dim] for xattn_estimate format
    # Also used for BSA which expects [seq_len, num_heads, head_dim]
    shape = (1, num_heads, max_seq_len, head_dim)
    self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
    self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
    self._max_seq_len = max_seq_len
    # Two buffers (K and V), each num_heads * max_seq_len * head_dim elements.
    memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024)
    logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB")
# =========================================================================
# GPU-only methods (non-chunked)
# =========================================================================
@@ -234,8 +282,26 @@ class XAttentionBSAPolicy(SparsePolicy):
K = k.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, k_len, head_dim]
V = v.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, k_len, head_dim]
# Expand KV for GQA (xattn_estimate requires matching heads)
K_exp, V_exp = expand_kv_for_gqa(K, V, num_heads)
# Expand KV for GQA - use pre-allocated buffers if available
if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads
if self._k_expanded is not None and k_len <= self._max_seq_len:
# Use pre-allocated buffers with in-place expansion
K_exp = self._k_expanded[:, :, :k_len, :]
V_exp = self._v_expanded[:, :, :k_len, :]
# In-place GQA expansion: [1, num_kv_heads, k_len, head_dim] -> [1, num_heads, k_len, head_dim]
# Reshape K to [1, num_kv_heads, 1, k_len, head_dim] and broadcast to [1, num_kv_heads, num_groups, k_len, head_dim]
K_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_(
K.unsqueeze(2).expand(-1, -1, num_groups, -1, -1)
)
V_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_(
V.unsqueeze(2).expand(-1, -1, num_groups, -1, -1)
)
else:
# Fallback: dynamic allocation (when buffers not pre-allocated or seq too long)
K_exp, V_exp = expand_kv_for_gqa(K, V, num_heads)
else:
K_exp, V_exp = K, V
# Estimate block importance and get sparse mask
_, mask = xattn_estimate(
@@ -255,12 +321,11 @@ class XAttentionBSAPolicy(SparsePolicy):
# q, k, v need to be [seq_len, num_heads, head_dim]
q_bsa = q # Already [q_len, num_heads, head_dim]
# For GQA with BSA, we need to expand k, v to match num_heads
# k, v: [k_len, num_kv_heads, head_dim] -> [k_len, num_heads, head_dim]
# For GQA with BSA, reuse the expanded K_exp, V_exp (convert to BSA format)
# K_exp: [1, num_heads, k_len, head_dim] -> [k_len, num_heads, head_dim]
if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads
k_bsa = k.repeat_interleave(num_groups, dim=1)
v_bsa = v.repeat_interleave(num_groups, dim=1)
k_bsa = K_exp.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
v_bsa = V_exp.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
else:
k_bsa = k
v_bsa = v