diff --git a/bench.py b/bench.py index 7fa0d69..8717ef1 100644 --- a/bench.py +++ b/bench.py @@ -56,6 +56,8 @@ def main(): help="Sparse policy: full (FullAttention), xattn (XAttention+BSA)") parser.add_argument("--enable-policy", action="store_true", help="Enable sparse policy routing (FullAttentionPolicy by default)") + parser.add_argument("--gpu-util", type=float, default=0.9, + help="GPU memory utilization (default: 0.9)") args = parser.parse_args() path = os.path.expanduser(args.model) @@ -78,6 +80,7 @@ def main(): max_model_len=max_len, max_num_batched_tokens=max_len, sparse_policy=sparse_policy, + gpu_memory_utilization=args.gpu_util, ) # Warmup diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index aff540d..e4ad65c 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -208,6 +208,19 @@ class ModelRunner: device=torch.device("cuda"), ) + # GPU-only mode: pre-allocate policy metadata buffers + # This avoids dynamic GPU memory allocation during forward pass + if not config.enable_cpu_offload: + num_heads = hf_config.num_attention_heads // self.world_size + self.kvcache_manager.sparse_policy.alloc_policy_metadata( + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + max_seq_len=config.max_model_len, + dtype=hf_config.torch_dtype, + device=torch.device("cuda"), + ) + # Log policy info (handle both enum and None cases) policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL" logger.info( diff --git a/nanovllm/kvcache/sparse/policy.py b/nanovllm/kvcache/sparse/policy.py index f6b71a0..d1c3e33 100644 --- a/nanovllm/kvcache/sparse/policy.py +++ b/nanovllm/kvcache/sparse/policy.py @@ -108,6 +108,34 @@ class SparsePolicy(ABC): """ pass + def alloc_policy_metadata( + self, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_seq_len: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + """ + Pre-allocate GPU buffers for policy computation. + + Called by the framework after KV cache allocation, but ONLY for GPU-only + mode (not CPU offload mode). Override this to pre-allocate buffers that + would otherwise be dynamically allocated during forward pass. + + This is separate from initialize() which is used for CPU offload metadata. + + Args: + num_heads: Number of query heads + num_kv_heads: Number of KV heads (for GQA) + head_dim: Dimension per head + max_seq_len: Maximum sequence length (for buffer sizing) + dtype: Data type (typically float16/bfloat16) + device: Target device (cuda) + """ + pass + @abstractmethod def select_blocks( self, diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py index 277f6d6..6ec026c 100644 --- a/nanovllm/kvcache/sparse/xattn_bsa.py +++ b/nanovllm/kvcache/sparse/xattn_bsa.py @@ -122,6 +122,54 @@ class XAttentionBSAPolicy(SparsePolicy): self._stats_total_selected_blocks = 0 self._stats_num_chunks = 0 + # Pre-allocated GQA expansion buffers (GPU-only mode) + # Set by alloc_policy_metadata(), None if not pre-allocated + self._k_expanded: torch.Tensor | None = None + self._v_expanded: torch.Tensor | None = None + self._max_seq_len: int = 0 + + def alloc_policy_metadata( + self, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_seq_len: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + """ + Pre-allocate GQA expansion buffers for GPU-only mode. + + These buffers are used by compute_prefill() to avoid dynamic allocation + during forward pass. The buffers are sized for max_seq_len and sliced + to actual seq_len during use. + + Memory usage: 2 * num_heads * max_seq_len * head_dim * dtype_size + For 64K seq, 32 heads, 128 dim, fp16: 2 * 32 * 65536 * 128 * 2 = 1 GB + + Args: + num_heads: Number of query heads + num_kv_heads: Number of KV heads (for GQA) + head_dim: Dimension per head + max_seq_len: Maximum sequence length + dtype: Data type + device: Target device + """ + # Only allocate if GQA (num_heads != num_kv_heads) + if num_heads == num_kv_heads: + logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})") + return + + # Shape: [1, num_heads, max_seq_len, head_dim] for xattn_estimate format + # Also used for BSA which expects [seq_len, num_heads, head_dim] + shape = (1, num_heads, max_seq_len, head_dim) + self._k_expanded = torch.empty(shape, dtype=dtype, device=device) + self._v_expanded = torch.empty(shape, dtype=dtype, device=device) + self._max_seq_len = max_seq_len + + memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024) + logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB") + # ========================================================================= # GPU-only methods (non-chunked) # ========================================================================= @@ -234,8 +282,26 @@ class XAttentionBSAPolicy(SparsePolicy): K = k.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, k_len, head_dim] V = v.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, k_len, head_dim] - # Expand KV for GQA (xattn_estimate requires matching heads) - K_exp, V_exp = expand_kv_for_gqa(K, V, num_heads) + # Expand KV for GQA - use pre-allocated buffers if available + if num_heads != num_kv_heads: + num_groups = num_heads // num_kv_heads + if self._k_expanded is not None and k_len <= self._max_seq_len: + # Use pre-allocated buffers with in-place expansion + K_exp = self._k_expanded[:, :, :k_len, :] + V_exp = self._v_expanded[:, :, :k_len, :] + # In-place GQA expansion: [1, num_kv_heads, k_len, head_dim] -> [1, num_heads, k_len, head_dim] + # Reshape K to [1, num_kv_heads, 1, k_len, head_dim] and broadcast to [1, num_kv_heads, num_groups, k_len, head_dim] + K_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_( + K.unsqueeze(2).expand(-1, -1, num_groups, -1, -1) + ) + V_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_( + V.unsqueeze(2).expand(-1, -1, num_groups, -1, -1) + ) + else: + # Fallback: dynamic allocation (when buffers not pre-allocated or seq too long) + K_exp, V_exp = expand_kv_for_gqa(K, V, num_heads) + else: + K_exp, V_exp = K, V # Estimate block importance and get sparse mask _, mask = xattn_estimate( @@ -255,12 +321,11 @@ class XAttentionBSAPolicy(SparsePolicy): # q, k, v need to be [seq_len, num_heads, head_dim] q_bsa = q # Already [q_len, num_heads, head_dim] - # For GQA with BSA, we need to expand k, v to match num_heads - # k, v: [k_len, num_kv_heads, head_dim] -> [k_len, num_heads, head_dim] + # For GQA with BSA, reuse the expanded K_exp, V_exp (convert to BSA format) + # K_exp: [1, num_heads, k_len, head_dim] -> [k_len, num_heads, head_dim] if num_heads != num_kv_heads: - num_groups = num_heads // num_kv_heads - k_bsa = k.repeat_interleave(num_groups, dim=1) - v_bsa = v.repeat_interleave(num_groups, dim=1) + k_bsa = K_exp.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim] + v_bsa = V_exp.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim] else: k_bsa = k v_bsa = v