perf: pre-allocate GQA buffers in XAttention policy

Add alloc_policy_metadata() method to SparsePolicy base class for
pre-allocating GPU buffers during initialization. This avoids
dynamic memory allocation during forward pass.

Changes:
- Add alloc_policy_metadata() to SparsePolicy base class
- Implement GQA buffer pre-allocation in XAttentionBSAPolicy
- Call alloc_policy_metadata() in model_runner for GPU-only mode
- Modify compute_prefill() to reuse pre-allocated buffers
- Add --gpu-util parameter to bench.py

Memory savings:
- Previously: 2x GQA expansion (~2GB for 64K)
- Now: 1x pre-allocated buffer (~1GB for 64K, reused)

Tested:
- GPU-only 32K: 5602 tok/s (512MB pre-allocated)
- GPU-only 64K: 4821 tok/s (1GB pre-allocated, gpu_util=0.7)
- Offload Full: PASSED (no changes to offload path)
- Offload XAttention: PASSED (uses compute_chunked_prefill)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-27 05:49:23 +08:00
parent 076656c9c2
commit a504bd873d
4 changed files with 116 additions and 7 deletions

View File

@@ -108,6 +108,34 @@ class SparsePolicy(ABC):
"""
pass
def alloc_policy_metadata(
self,
num_heads: int,
num_kv_heads: int,
head_dim: int,
max_seq_len: int,
dtype: torch.dtype,
device: torch.device,
) -> None:
"""
Pre-allocate GPU buffers for policy computation.
Called by the framework after KV cache allocation, but ONLY for GPU-only
mode (not CPU offload mode). Override this to pre-allocate buffers that
would otherwise be dynamically allocated during forward pass.
This is separate from initialize() which is used for CPU offload metadata.
Args:
num_heads: Number of query heads
num_kv_heads: Number of KV heads (for GQA)
head_dim: Dimension per head
max_seq_len: Maximum sequence length (for buffer sizing)
dtype: Data type (typically float16/bfloat16)
device: Target device (cuda)
"""
pass
@abstractmethod
def select_blocks(
self,