⚡ perf: pre-allocate GQA buffers in XAttention policy
Add alloc_policy_metadata() method to SparsePolicy base class for
pre-allocating GPU buffers during initialization. This avoids dynamic
memory allocation during the forward pass.

Changes:
- Add alloc_policy_metadata() to SparsePolicy base class
- Implement GQA buffer pre-allocation in XAttentionBSAPolicy
- Call alloc_policy_metadata() in model_runner for GPU-only mode
- Modify compute_prefill() to reuse pre-allocated buffers
- Add --gpu-util parameter to bench.py

Memory savings:
- Previously: 2x GQA expansion (~2GB for 64K)
- Now: 1x pre-allocated buffer (~1GB for 64K, reused)

Tested:
- GPU-only 32K: 5602 tok/s (512MB pre-allocated)
- GPU-only 64K: 4821 tok/s (1GB pre-allocated, gpu_util=0.7)
- Offload Full: PASSED (no changes to offload path)
- Offload XAttention: PASSED (uses compute_chunked_prefill)

Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
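The memory figures follow from the size of the GQA head expansion. A minimal back-of-envelope sketch, assuming 32 query heads, head_dim 128, and fp16 (the commit does not state the model dimensions; these are chosen as a typical GQA config that reproduces the reported totals). Under the old path, two such expansions could be live at once (~2GB); the new path keeps a single reused buffer (~1GB):

```python
# Back-of-envelope check of the commit's memory figures. The head count and
# head dim below are ASSUMPTIONS (not stated in the commit), picked as a
# typical GQA configuration that matches the reported numbers.
num_heads, head_dim, bytes_per_el = 32, 128, 2  # fp16/bf16

def gqa_expansion_bytes(seq_len: int) -> int:
    # K and V each expanded from num_kv_heads up to num_heads query heads
    return 2 * seq_len * num_heads * head_dim * bytes_per_el

print(gqa_expansion_bytes(64 * 1024) / 2**30)  # 1.0 GiB -> one reusable buffer
print(gqa_expansion_bytes(32 * 1024) / 2**20)  # 512.0 MiB -> matches the 32K run
```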
@@ -108,6 +108,34 @@ class SparsePolicy(ABC):
         """
         pass
 
+    def alloc_policy_metadata(
+        self,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        max_seq_len: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> None:
+        """
+        Pre-allocate GPU buffers for policy computation.
+
+        Called by the framework after KV cache allocation, but ONLY for GPU-only
+        mode (not CPU offload mode). Override this to pre-allocate buffers that
+        would otherwise be dynamically allocated during the forward pass.
+
+        This is separate from initialize(), which is used for CPU offload metadata.
+
+        Args:
+            num_heads: Number of query heads
+            num_kv_heads: Number of KV heads (for GQA)
+            head_dim: Dimension per head
+            max_seq_len: Maximum sequence length (for buffer sizing)
+            dtype: Data type (typically float16/bfloat16)
+            device: Target device (cuda)
+        """
+        pass
+
     @abstractmethod
     def select_blocks(
         self,
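A minimal sketch of how a policy could override the new hook, assuming the SparsePolicy base class from the diff above. The buffer name `_k_expanded` and the `compute_prefill()` signature are illustrative assumptions, not the commit's actual XAttentionBSAPolicy code:

```python
import torch

class XAttentionBSAPolicy(SparsePolicy):
    # select_blocks() and the rest of the policy are omitted from this sketch.

    def alloc_policy_metadata(
        self,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        max_seq_len: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        # Allocate the worst-case GQA expansion buffer once, at init time,
        # instead of materializing a fresh expansion on every prefill.
        self._gqa_group = num_heads // num_kv_heads
        self._k_expanded = torch.empty(
            max_seq_len, num_heads, head_dim, dtype=dtype, device=device
        )

    def compute_prefill(self, k: torch.Tensor) -> torch.Tensor:
        # k: [seq_len, num_kv_heads, head_dim]. Broadcast-copy into a view of
        # the pre-allocated buffer, so no new GPU memory is allocated here.
        seq_len, num_kv_heads, head_dim = k.shape
        out = self._k_expanded[:seq_len].view(
            seq_len, num_kv_heads, self._gqa_group, head_dim
        )
        out.copy_(k.unsqueeze(2))  # [S, KV, 1, D] broadcasts over the group dim
        return out.view(seq_len, -1, head_dim)
```

In this sketch, model_runner would call `policy.alloc_policy_metadata(...)` once after sizing the KV cache in GPU-only mode, and skip the call entirely under CPU offload, matching the docstring's contract.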