diff --git a/bench.py b/bench.py index 05b8b47..899cc0a 100644 --- a/bench.py +++ b/bench.py @@ -2,6 +2,7 @@ import os import time from random import randint, seed from nanovllm import LLM, SamplingParams +from nanovllm.config import SparsePolicyType def bench_decode(llm, num_seqs, input_len, output_len): @@ -23,8 +24,8 @@ def bench_decode(llm, num_seqs, input_len, output_len): print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)") -def bench_prefill(llm, num_seqs, input_len): - """Benchmark prefill performance""" +def bench_prefill(llm, num_seqs, input_len, label=""): + """Benchmark prefill performance. Returns throughput.""" seed(0) # Fixed length input, minimal output to focus on prefill prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)] @@ -35,7 +36,28 @@ def bench_prefill(llm, num_seqs, input_len): t = time.time() - t total_input_tokens = num_seqs * input_len throughput = total_input_tokens / t - print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s") + label_str = f" ({label})" if label else "" + print(f"[Prefill{label_str}] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s") + return throughput + + +def create_llm(path, max_len, enable_minference=False, minference_budget=0.3, + minference_vertical=1000, minference_slash=6096, + gpu_utilization=0.8): + """Create LLM with specified configuration.""" + kwargs = { + "enforce_eager": True, # MInference uses Triton, not compatible with CUDA graphs + "max_model_len": max_len, + "max_num_batched_tokens": max_len, + "gpu_memory_utilization": gpu_utilization, + } + if enable_minference: + kwargs["sparse_policy"] = SparsePolicyType.MINFERENCE + kwargs["minference_adaptive_budget"] = minference_budget + kwargs["minference_vertical_size"] = minference_vertical + kwargs["minference_slash_size"] = minference_slash + + return LLM(path, **kwargs) def main(): @@ -46,24 +68,17 @@ def main(): parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)") parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)") parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks") + parser.add_argument("--enable-minference", action="store_true", help="Enable MInference sparse prefill") + parser.add_argument("--minference-budget", type=float, default=0.3, help="MInference adaptive budget (default: 0.3, use 0 for fixed mode)") + parser.add_argument("--minference-vertical", type=int, default=1000, help="Fixed vertical_size (only used when budget=0)") + parser.add_argument("--minference-slash", type=int, default=6096, help="Fixed slash_size (only used when budget=0)") + parser.add_argument("--gpu-utilization", type=float, default=0.9, help="GPU memory utilization (default: 0.9)") + parser.add_argument("--compare", action="store_true", help="Compare baseline vs MInference (runs both)") args = parser.parse_args() path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") max_len = args.max_len - print(f"\n[nanovllm GPU] max_len={max_len}") - - llm = LLM( - path, - enforce_eager=False, - max_model_len=max_len, - max_num_batched_tokens=max_len, - ) - - # Warmup - print("\nWarming up...") - llm.generate(["Benchmark warmup: "], SamplingParams(max_tokens=10)) - # Default input lengths prefill_input_len = args.input_len if args.input_len else max_len - 1 decode_input_len = args.input_len if args.input_len else max_len - args.output_len @@ -72,17 +87,128 @@ def main(): run_prefill = not args.bench_decode or args.bench_all run_decode = args.bench_decode or args.bench_all - if run_prefill: - print("\n" + "=" * 60) - print("Prefill Benchmark (nanovllm GPU)") - print("=" * 60) - bench_prefill(llm, num_seqs=1, input_len=prefill_input_len) + # Convert budget=0 to None for fixed mode + minference_budget = args.minference_budget if args.minference_budget > 0 else None - if run_decode: - print("\n" + "=" * 60) - print("Decode Benchmark (nanovllm GPU)") - print("=" * 60) - bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len) + if args.compare: + # Compare baseline vs MInference using subprocesses to avoid NCCL issues + import subprocess + import sys + + print(f"\n{'='*60}") + print(f"Baseline vs MInference Comparison") + print(f"Input length: {prefill_input_len} tokens") + if minference_budget is not None: + print(f"MInference mode: adaptive (budget={minference_budget}, {minference_budget*100:.0f}% compute)") + else: + print(f"MInference mode: fixed (vertical={args.minference_vertical}, slash={args.minference_slash})") + print(f"{'='*60}") + + # Get PYTHONPATH for subprocess + pythonpath = os.environ.get("PYTHONPATH", "") + + # Run baseline in subprocess + print(f"\n[1/2] Running baseline (FULL attention)...") + cmd_baseline = [ + sys.executable, __file__, + "--input-len", str(prefill_input_len), + "--max-len", str(max_len), + "--gpu-utilization", str(args.gpu_utilization), + ] + env = os.environ.copy() + result = subprocess.run(cmd_baseline, capture_output=True, text=True, env=env) + print(result.stdout) + if result.returncode != 0: + print(f"Error: {result.stderr}") + return + + # Parse baseline throughput + baseline_throughput = None + for line in result.stdout.split('\n'): + if "Throughput:" in line and "tok/s" in line: + # Extract throughput value + import re + match = re.search(r'Throughput:\s*([\d.]+)tok/s', line) + if match: + baseline_throughput = float(match.group(1)) + + # Run MInference in subprocess + if minference_budget is not None: + print(f"\n[2/2] Running MInference (budget={minference_budget})...") + else: + print(f"\n[2/2] Running MInference (vertical={args.minference_vertical}, slash={args.minference_slash})...") + cmd_minference = [ + sys.executable, __file__, + "--input-len", str(prefill_input_len), + "--max-len", str(max_len), + "--gpu-utilization", str(args.gpu_utilization), + "--enable-minference", + "--minference-budget", str(args.minference_budget), + "--minference-vertical", str(args.minference_vertical), + "--minference-slash", str(args.minference_slash), + ] + result = subprocess.run(cmd_minference, capture_output=True, text=True, env=env) + print(result.stdout) + if result.returncode != 0: + print(f"Error: {result.stderr}") + return + + # Parse MInference throughput + minference_throughput = None + for line in result.stdout.split('\n'): + if "Throughput:" in line and "tok/s" in line: + import re + match = re.search(r'Throughput:\s*([\d.]+)tok/s', line) + if match: + minference_throughput = float(match.group(1)) + + # Comparison + if baseline_throughput and minference_throughput: + print(f"\n{'='*60}") + print(f"Results Summary") + print(f"{'='*60}") + print(f"Baseline: {baseline_throughput:,.0f} tok/s") + print(f"MInference: {minference_throughput:,.0f} tok/s") + speedup = minference_throughput / baseline_throughput + if speedup >= 1.0: + print(f"Speedup: {speedup:.2f}x faster") + else: + print(f"Slowdown: {1/speedup:.2f}x slower") + print(f"{'='*60}") + else: + print("Failed to parse throughput values") + + else: + # Single run mode + mode = "MInference" if args.enable_minference else "GPU" + print(f"\n[nanovllm {mode}] max_len={max_len}") + if args.enable_minference: + if minference_budget is not None: + print(f"MInference mode: adaptive (budget={minference_budget})") + else: + print(f"MInference mode: fixed (vertical={args.minference_vertical}, slash={args.minference_slash})") + + llm = create_llm(path, max_len, enable_minference=args.enable_minference, + minference_budget=minference_budget, + minference_vertical=args.minference_vertical, + minference_slash=args.minference_slash, + gpu_utilization=args.gpu_utilization) + + # Warmup + print("\nWarming up...") + llm.generate(["Benchmark warmup: "], SamplingParams(max_tokens=10)) + + if run_prefill: + print("\n" + "=" * 60) + print(f"Prefill Benchmark (nanovllm {mode})") + print("=" * 60) + bench_prefill(llm, num_seqs=1, input_len=prefill_input_len) + + if run_decode: + print("\n" + "=" * 60) + print(f"Decode Benchmark (nanovllm {mode})") + print("=" * 60) + bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len) if __name__ == "__main__": diff --git a/nanovllm/config.py b/nanovllm/config.py index 2be7b8d..51298db 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -9,6 +9,7 @@ class SparsePolicyType(Enum): """Sparse attention policy types.""" FULL = auto() # No sparse attention (load all blocks) QUEST = auto() # Query-aware Top-K block selection (decode only) + MINFERENCE = auto() # MInference vertical + slash sparse prefill (GPU-only) @dataclass @@ -39,10 +40,18 @@ class Config: # Sparse attention configuration # Quest: decode-only sparse attention with Top-K block selection # FULL: no sparse attention (load all blocks) + # MINFERENCE: MInference vertical + slash sparse prefill (GPU-only) sparse_policy: SparsePolicyType = SparsePolicyType.FULL sparse_topk_blocks: int = 8 # Top-K blocks for Quest sparse_threshold_blocks: int = 4 # Apply sparse only when blocks > threshold + # MInference configuration (used when sparse_policy == MINFERENCE) + minference_adaptive_budget: float = 0.3 # Budget as fraction of seq_len (None to use fixed sizes) + minference_vertical_size: int = 1000 # Fixed vertical size (if adaptive_budget is None) + minference_slash_size: int = 6096 # Fixed slash size (if adaptive_budget is None) + minference_num_sink_tokens: int = 30 # Sink tokens to always keep + minference_num_recent_diags: int = 100 # Recent diagonals to always keep + def __post_init__(self): assert os.path.isdir(self.model) assert self.kvcache_block_size % 256 == 0 diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index 05d77de..1e9eccd 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -4,7 +4,7 @@ import torch.distributed as dist from multiprocessing.synchronize import Event from multiprocessing.shared_memory import SharedMemory -from nanovllm.config import Config +from nanovllm.config import Config, SparsePolicyType from nanovllm.engine.sequence import Sequence from nanovllm.models.qwen3 import Qwen3ForCausalLM from nanovllm.layers.sampler import GreedySampler @@ -35,7 +35,10 @@ class ModelRunner: self.model = Qwen3ForCausalLM(hf_config) load_model(self.model, config.model) self.sampler = GreedySampler() - + + # Initialize sparse_prefill_policy before warmup (will be configured in allocate_kv_cache) + self.sparse_prefill_policy = None + #> Disable warmup for debugging self.warmup_model() @@ -148,6 +151,24 @@ class ModelRunner: # Create KV cache manager using factory self.kvcache_manager: KVCacheManager = create_kvcache_manager(config) + # Create sparse prefill policy for GPU-only path + # This is separate from CPU offload sparse policy (which uses select_blocks) + self.sparse_prefill_policy = None + if not config.enable_cpu_offload and config.sparse_policy != SparsePolicyType.FULL: + from nanovllm.kvcache.sparse import create_sparse_policy + policy = create_sparse_policy( + config.sparse_policy, + vertical_size=config.minference_vertical_size, + slash_size=config.minference_slash_size, + adaptive_budget=config.minference_adaptive_budget, + num_sink_tokens=config.minference_num_sink_tokens, + num_recent_diags=config.minference_num_recent_diags, + ) + # Only use if policy supports sparse prefill + if policy.supports_prefill: + self.sparse_prefill_policy = policy + logger.info(f"Sparse prefill policy enabled: {self.sparse_prefill_policy}") + # Allocate cache through manager self.kvcache_manager.allocate_cache( num_layers=hf_config.num_hidden_layers, @@ -329,7 +350,10 @@ class ModelRunner: cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) - set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables) + + set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + slot_mapping, None, block_tables, + sparse_prefill_policy=self.sparse_prefill_policy) return input_ids, positions def prepare_decode(self, seqs: list[Sequence]): diff --git a/nanovllm/kvcache/sparse/__init__.py b/nanovllm/kvcache/sparse/__init__.py index ae8e922..756a1ef 100644 --- a/nanovllm/kvcache/sparse/__init__.py +++ b/nanovllm/kvcache/sparse/__init__.py @@ -23,6 +23,7 @@ from nanovllm.config import SparsePolicyType from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager +from nanovllm.kvcache.sparse.minference import MInferencePolicy def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy: @@ -55,6 +56,15 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic ) return QuestPolicy(config) + elif policy_type == SparsePolicyType.MINFERENCE: + return MInferencePolicy( + vertical_size=kwargs.get("vertical_size", 1000), + slash_size=kwargs.get("slash_size", 6096), + adaptive_budget=kwargs.get("adaptive_budget", 0.3), + num_sink_tokens=kwargs.get("num_sink_tokens", 30), + num_recent_diags=kwargs.get("num_recent_diags", 100), + ) + else: raise ValueError(f"Unknown policy type: {policy_type}") @@ -67,5 +77,6 @@ __all__ = [ "QuestPolicy", "QuestConfig", "BlockMetadataManager", + "MInferencePolicy", "create_sparse_policy", ] diff --git a/nanovllm/kvcache/sparse/minference.py b/nanovllm/kvcache/sparse/minference.py new file mode 100644 index 0000000..6b861aa --- /dev/null +++ b/nanovllm/kvcache/sparse/minference.py @@ -0,0 +1,353 @@ +""" +MInference sparse attention policy. + +Implements vertical + slash sparse pattern estimation using the last 64 query tokens. +Reference: MInference paper (https://arxiv.org/abs/2407.02490) +""" + +import math +from typing import List, Tuple, Optional +import torch +import torch.nn.functional as F + +from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext + + +class MInferencePolicy(SparsePolicy): + """ + MInference sparse prefill policy using vertical + slash pattern. + + This policy estimates sparse attention patterns by analyzing attention + scores from the last 64 query tokens, then selects: + - Vertical: Key positions that are important across all queries + - Slash: Diagonal bands (local context) + + The estimated pattern is then used to compute sparse attention. + + Note: This policy is designed for GPU-only prefill. For CPU offload, + the pattern estimation and sparse attention will be handled differently. + """ + + supports_prefill = True + supports_decode = False # MInference is prefill-only sparse strategy + + def __init__( + self, + vertical_size: int = 1000, + slash_size: int = 6096, + adaptive_budget: Optional[float] = 0.3, + num_sink_tokens: int = 30, + num_recent_diags: int = 100, + ): + """ + Initialize MInference policy. + + Args: + vertical_size: Number of vertical (column) positions to keep + slash_size: Number of diagonal bands to keep + adaptive_budget: If set, compute budget as fraction of seq_len + (overrides vertical_size and slash_size) + num_sink_tokens: Number of initial sink tokens to always keep + num_recent_diags: Number of recent diagonals to always keep + """ + self.vertical_size = vertical_size + self.slash_size = slash_size + self.adaptive_budget = adaptive_budget + self.num_sink_tokens = num_sink_tokens + self.num_recent_diags = num_recent_diags + + # Cache for last-q causal mask + self._last_q_mask_cache: dict = {} + + def _get_causal_mask(self, last_q: int, seq_len: int, device: torch.device) -> torch.Tensor: + """Get causal mask for last-q attention.""" + cache_key = (last_q, seq_len, device) + if cache_key not in self._last_q_mask_cache: + # Create mask where last_q queries can attend to all previous positions + # Shape: [last_q, seq_len] + mask = torch.ones(last_q, seq_len, device=device, dtype=torch.bool) + # Apply causal constraint for the last last_q positions + # Query i (from last_q) can only attend to positions <= (seq_len - last_q + i) + for i in range(last_q): + mask[i, seq_len - last_q + i + 1:] = False + self._last_q_mask_cache[cache_key] = mask + return self._last_q_mask_cache[cache_key] + + def estimate_pattern( + self, + q: torch.Tensor, + k: torch.Tensor, + layer_id: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Estimate vertical + slash sparse pattern using last 64 query tokens. + Memory-optimized for long sequences (64K+). + + Args: + q: Query tensor [seq_len, num_heads, head_dim] + k: Key tensor [seq_len, num_kv_heads, head_dim] + layer_id: Current layer index (for potential layer-specific patterns) + + Returns: + Tuple of (vertical_indices, slash_indices): + - vertical_indices: [num_heads, vertical_size] - important K positions + - slash_indices: [num_heads, slash_size] - diagonal offsets + """ + seq_len = q.shape[0] + num_heads = q.shape[1] + head_dim = q.shape[2] + num_kv_heads = k.shape[1] + + # Adaptive budget + if self.adaptive_budget is not None: + budget = int(seq_len * self.adaptive_budget) + vertical_size = max(self.num_sink_tokens + 1, int(budget * 0.2)) + slash_size = max(self.num_recent_diags + 1, int(budget * 0.8)) + else: + vertical_size = self.vertical_size + slash_size = self.slash_size + + # Use last 64 Q tokens for estimation + last_q = min(64, seq_len) + q_last = q[-last_q:] # [last_q, heads, dim] - this is a view, not a copy + + # Handle GQA: if num_kv_heads < num_heads, we need to expand K + if num_kv_heads < num_heads: + num_groups = num_heads // num_kv_heads + k_work = k.repeat_interleave(num_groups, dim=1) + else: + k_work = k + + # Compute attention scores: [heads, last_q, seq_len] + scale = 1.0 / math.sqrt(head_dim) + qk = torch.einsum('qhd,khd->hqk', q_last, k_work) * scale + + # Free k_work if it was a copy + if num_kv_heads < num_heads: + del k_work + + # Apply causal mask for last positions (in-place) + causal_mask = self._get_causal_mask(last_q, seq_len, q.device) + qk.masked_fill_(~causal_mask.unsqueeze(0), float('-inf')) + + # Softmax (in-place where possible) + qk = F.softmax(qk, dim=-1, dtype=torch.float32) + + # === Vertical pattern === + # Sum across query dimension -> importance of each K position + vertical_scores = qk.sum(dim=1) # [heads, seq_len] + + # Force keep first num_sink_tokens (attention sinks) - in-place + vertical_scores[:, :self.num_sink_tokens] = float('inf') + + # Select top-k + actual_vertical = min(vertical_size, seq_len) + vertical_indices = vertical_scores.topk(actual_vertical, dim=-1).indices + vertical_indices = vertical_indices.sort(dim=-1).values + del vertical_scores + + # === Slash pattern === + # Create diagonal index matrix: [last_q, seq_len] with int32 to save memory + q_indices = torch.arange(last_q, device=q.device, dtype=torch.int32).unsqueeze(1) + k_indices = torch.arange(seq_len, device=q.device, dtype=torch.int32).unsqueeze(0) + diag_indices = (seq_len - last_q + q_indices) - k_indices # [last_q, seq_len] + del q_indices + + # Create causal mask for slash computation + q_pos = seq_len - last_q + torch.arange(last_q, device=q.device, dtype=torch.int32).unsqueeze(1) + slash_causal_mask = k_indices <= q_pos + del q_pos, k_indices + + # Clamp diagonal indices to valid range + diag_indices = diag_indices.clamp(0, seq_len - 1) + + # Apply causal mask to qk (in-place) for slash computation + qk[:, ~slash_causal_mask] = 0 + del slash_causal_mask + + # Accumulate scores per diagonal - process in batches to save memory + slash_scores = torch.zeros(num_heads, seq_len, device=q.device, dtype=torch.float32) + + # Process heads in chunks to reduce peak memory for diag_indices_expanded + chunk_size = min(8, num_heads) # Process 8 heads at a time + for h_start in range(0, num_heads, chunk_size): + h_end = min(h_start + chunk_size, num_heads) + n_heads_chunk = h_end - h_start + + # Expand diag_indices only for this chunk + diag_chunk = diag_indices.unsqueeze(0).expand(n_heads_chunk, -1, -1).long() + qk_chunk = qk[h_start:h_end] + + slash_scores[h_start:h_end].scatter_add_( + 1, + diag_chunk.reshape(n_heads_chunk, -1), + qk_chunk.reshape(n_heads_chunk, -1) + ) + del diag_chunk, qk_chunk + + del diag_indices, qk + + # Force keep first num_recent_diags (in-place) + slash_scores[:, :self.num_recent_diags] = float('inf') + + # Select top-k diagonal indices + actual_slash = min(slash_size, seq_len) + slash_indices = slash_scores.topk(actual_slash, dim=-1).indices + slash_indices = slash_indices.sort(dim=-1).values + del slash_scores + + return vertical_indices, slash_indices + + def select_blocks( + self, + available_blocks: List[int], + ctx: PolicyContext, + ) -> List[int]: + """ + Select blocks for chunked CPU offload mode. + + For MInference in GPU-only mode, this method is not used. + In CPU offload mode, it would select blocks based on the sparse pattern. + + For now, return all blocks (full attention fallback). + """ + # MInference pattern is computed in attention.forward() + # For CPU offload integration (Phase B), this would use the pattern + return available_blocks + + def reset(self) -> None: + """Reset policy state.""" + self._last_q_mask_cache.clear() + + def sparse_prefill_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_id: int, + ) -> torch.Tensor: + """ + Compute MInference sparse attention for prefill. + + Uses vertical + slash pattern to compute sparse attention efficiently. + Memory-optimized to handle long sequences (64K+) by freeing intermediate tensors. + + Args: + q: Query tensor [seq_len, num_heads, head_dim] + k: Key tensor [seq_len, num_kv_heads, head_dim] + v: Value tensor [seq_len, num_kv_heads, head_dim] + layer_id: Current transformer layer index + + Returns: + Attention output [seq_len, num_heads, head_dim] + """ + from minference.ops.pit_sparse_flash_attention_v2 import _triton_mixed_sparse_attention + from minference.cuda import convert_vertical_slash_indexes + + seq_len = q.shape[0] + num_heads = q.shape[1] + head_dim = q.shape[2] + num_kv_heads = k.shape[1] + + # Estimate sparse pattern (uses temporary memory for qk scores) + vertical_indices, slash_indices = self.estimate_pattern(q, k, layer_id) + # Free any cached memory from pattern estimation + torch.cuda.empty_cache() + + # Triton sparse attention kernel parameters + block_size_M = 64 + block_size_N = 64 + + # Calculate padding + pad = (block_size_M - seq_len) & (block_size_M - 1) + need_head_pad = head_dim not in [16, 32, 64, 128, 256, 512] + head_pad = (2 ** math.ceil(math.log2(head_dim)) - head_dim) if need_head_pad else 0 + + # Handle GQA: expand K/V to match query heads + # Do this BEFORE creating batched tensors to avoid double copies + if num_kv_heads < num_heads: + num_groups = num_heads // num_kv_heads + # Use repeat_interleave for memory-efficient expansion + k_work = k.repeat_interleave(num_groups, dim=1) + v_work = v.repeat_interleave(num_groups, dim=1) + else: + k_work = k + v_work = v + + # Transform Q to [batch, heads, seq, dim] format with padding in one step + # This avoids creating intermediate copies + if pad > 0 or head_pad > 0: + q_batched = torch.nn.functional.pad( + q.unsqueeze(0).transpose(1, 2), + [0, head_pad, 0, pad, 0, 0, 0, 0] + ).contiguous() + else: + q_batched = q.unsqueeze(0).transpose(1, 2).contiguous() + + # Transform K to batched format + if pad > 0 or head_pad > 0: + k_batched = torch.nn.functional.pad( + k_work.unsqueeze(0).transpose(1, 2), + [0, head_pad, 0, pad, 0, 0, 0, 0] + ).contiguous() + else: + k_batched = k_work.unsqueeze(0).transpose(1, 2).contiguous() + + # Free k_work if it was a copy (GQA case) + if num_kv_heads < num_heads: + del k_work + + # Transform V to batched format + if pad > 0 or head_pad > 0: + v_batched = torch.nn.functional.pad( + v_work.unsqueeze(0).transpose(1, 2), + [0, head_pad, 0, pad, 0, 0, 0, 0] + ).contiguous() + else: + v_batched = v_work.unsqueeze(0).transpose(1, 2).contiguous() + + # Free v_work if it was a copy (GQA case) + if num_kv_heads < num_heads: + del v_work + torch.cuda.empty_cache() + + # Prepare indices for Triton kernel + v_idx = vertical_indices.to(torch.int32).reshape((1, num_heads, -1)) + v_idx = v_idx.sort(dim=-1, descending=False)[0].contiguous() + del vertical_indices + + s_idx = slash_indices.to(torch.int32).reshape((1, num_heads, -1)) + s_idx = s_idx.sort(dim=-1, descending=True)[0].contiguous() + del slash_indices + + seqlens = torch.tensor([seq_len], dtype=torch.int32, device=q.device) + sm_scale = head_dim ** -0.5 + + # Convert vertical+slash indices to block sparse format + block_count, block_offset, column_count, column_index = convert_vertical_slash_indexes( + seqlens, v_idx, s_idx, seq_len, block_size_M, block_size_N, + ) + del v_idx, s_idx + + # Call Triton mixed sparse attention kernel + o = _triton_mixed_sparse_attention( + q_batched, k_batched, v_batched, seqlens, + block_count, block_offset, column_count, column_index, + sm_scale, block_size_M, block_size_N, + ) + + # Free input tensors immediately after kernel call + del q_batched, k_batched, v_batched + del block_count, block_offset, column_count, column_index + + # Remove padding and convert back to [seq_len, num_heads, head_dim] + o = o[..., :seq_len, :head_dim] + o = o.transpose(1, 2).squeeze(0).contiguous() + + return o + + def __repr__(self) -> str: + return (f"MInferencePolicy(" + f"adaptive_budget={self.adaptive_budget}, " + f"vertical_size={self.vertical_size}, " + f"slash_size={self.slash_size})") diff --git a/nanovllm/kvcache/sparse/policy.py b/nanovllm/kvcache/sparse/policy.py index 2813745..b1b234f 100644 --- a/nanovllm/kvcache/sparse/policy.py +++ b/nanovllm/kvcache/sparse/policy.py @@ -183,5 +183,32 @@ class SparsePolicy(ABC): """ pass + def sparse_prefill_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_id: int, + ) -> torch.Tensor: + """ + Compute sparse attention for prefill phase. + + This method is called when supports_prefill=True and the policy + is used for GPU-only sparse prefill (no CPU offload). + + Args: + q: Query tensor [seq_len, num_heads, head_dim] + k: Key tensor [seq_len, num_kv_heads, head_dim] + v: Value tensor [seq_len, num_kv_heads, head_dim] + layer_id: Current transformer layer index + + Returns: + Attention output [seq_len, num_heads, head_dim] + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not implement sparse_prefill_attention. " + "Set supports_prefill=False or implement this method." + ) + def __repr__(self) -> str: return f"{self.__class__.__name__}()" diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 028626c..eef2a58 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -140,6 +140,11 @@ class Attention(nn.Module): max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q, max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k, softmax_scale=self.scale, causal=True, block_table=context.block_tables) + elif context.sparse_prefill_policy is not None: + # Sparse prefill (GPU-only) - delegate to policy + o = context.sparse_prefill_policy.sparse_prefill_attention( + q, k, v, self.layer_id + ) else: o = flash_attn_varlen_func(q, k, v, max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q, diff --git a/nanovllm/utils/context.py b/nanovllm/utils/context.py index 23a1483..7828120 100644 --- a/nanovllm/utils/context.py +++ b/nanovllm/utils/context.py @@ -35,6 +35,10 @@ class Context: # Current chunk index for ring buffer pipeline (prefill only) current_chunk_idx: int = 0 + # Sparse prefill attention support (GPU-only path) + # When set, uses policy.sparse_prefill_attention() instead of FlashAttention + sparse_prefill_policy: Any = None # SparsePolicy instance with supports_prefill=True + _CONTEXT = Context() @@ -60,6 +64,7 @@ def set_context( decode_pos_in_block=0, decode_start_pos_in_block=0, current_chunk_idx=0, + sparse_prefill_policy=None, ): global _CONTEXT _CONTEXT = Context( @@ -79,6 +84,7 @@ def set_context( decode_pos_in_block=decode_pos_in_block, decode_start_pos_in_block=decode_start_pos_in_block, current_chunk_idx=current_chunk_idx, + sparse_prefill_policy=sparse_prefill_policy, ) diff --git a/tests/test_minference_gpu.py b/tests/test_minference_gpu.py new file mode 100644 index 0000000..72d87db --- /dev/null +++ b/tests/test_minference_gpu.py @@ -0,0 +1,163 @@ +""" +Needle-in-haystack test with MInference sparse attention. + +Tests: MInference sparse prefill on GPU-only path (no CPU offload). +This validates that MInference's vertical + slash sparse pattern can +correctly retrieve information from long context. +""" + +import os +os.environ["NANOVLLM_LOG_LEVEL"] = "INFO" + +import argparse +from nanovllm import LLM, SamplingParams +from nanovllm.config import SparsePolicyType +from utils import generate_needle_prompt, check_needle_answer + + +def run_minference_test( + model_path: str, + max_model_len: int = 16384, + input_len: int = 8192, + needle_position: float = 0.5, + needle_value: str = "7492", + adaptive_budget: float = 0.3, + max_new_tokens: int = 32, + verbose: bool = True, +) -> bool: + """ + Run needle test with MInference sparse prefill attention. + + Args: + model_path: Path to model + max_model_len: Maximum model context length + input_len: Target input sequence length + needle_position: Where to place needle (0.0-1.0) + needle_value: The secret value to find + adaptive_budget: MInference budget as fraction of seq_len + max_new_tokens: Maximum tokens to generate + verbose: Print detailed output + + Returns: + True if test passed, False otherwise + """ + if verbose: + print(f"\n{'='*60}") + print(f"MInference Sparse Prefill Test (GPU-only)") + print(f"{'='*60}") + print(f"Model: {model_path}") + print(f"Max model len: {max_model_len}") + print(f"Input length: {input_len}") + print(f"Needle position: {needle_position:.0%}") + print(f"Needle value: {needle_value}") + print(f"Adaptive budget: {adaptive_budget}") + print(f"{'='*60}\n") + + # Initialize LLM with MInference sparse attention + llm = LLM( + model_path, + enforce_eager=True, + max_model_len=max_model_len, + max_num_batched_tokens=max_model_len, + enable_cpu_offload=False, # GPU-only + sparse_policy=SparsePolicyType.MINFERENCE, + minference_adaptive_budget=adaptive_budget, + ) + + # Generate needle prompt + prompt, expected = generate_needle_prompt( + tokenizer=llm.tokenizer, + target_length=input_len, + needle_position=needle_position, + needle_value=needle_value, + ) + + # Generate output + sampling_params = SamplingParams( + temperature=0.6, + max_tokens=max_new_tokens, + ) + outputs = llm.generate([prompt], sampling_params, use_tqdm=True) + + # Check result + output_text = outputs[0]["text"] + output_token_ids = outputs[0]["token_ids"] + passed = check_needle_answer(output_text, expected) + + if verbose: + print(f"\n{'='*60}") + print(f"Result") + print(f"{'='*60}") + print(f"Expected: {expected}") + print(f"Output tokens ({len(output_token_ids)}): {output_token_ids[:20]}") + print(f"Output: {output_text[:200]}...") + print(f"Status: {'PASSED' if passed else 'FAILED'}") + print(f"{'='*60}\n") + + return passed + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Needle-in-haystack test with MInference sparse prefill" + ) + parser.add_argument( + "--model", "-m", + type=str, + default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"), + help="Path to model" + ) + parser.add_argument( + "--max-model-len", + type=int, + default=16 * 1024, + help="Maximum model context length" + ) + parser.add_argument( + "--input-len", + type=int, + default=8 * 1024, + help="Target input sequence length" + ) + parser.add_argument( + "--needle-position", + type=float, + default=0.5, + help="Needle position (0.0=start, 0.5=middle, 1.0=end)" + ) + parser.add_argument( + "--needle-value", + type=str, + default="7492", + help="The secret value to hide" + ) + parser.add_argument( + "--adaptive-budget", + type=float, + default=0.3, + help="MInference adaptive budget (fraction of seq_len)" + ) + parser.add_argument( + "--max-new-tokens", + type=int, + default=32, + help="Maximum tokens to generate" + ) + args = parser.parse_args() + + passed = run_minference_test( + model_path=args.model, + max_model_len=args.max_model_len, + input_len=args.input_len, + needle_position=args.needle_position, + needle_value=args.needle_value, + adaptive_budget=args.adaptive_budget, + max_new_tokens=args.max_new_tokens, + verbose=True, + ) + + if passed: + print("test_minference_gpu: PASSED") + else: + print("test_minference_gpu: FAILED") + exit(1) diff --git a/tests/test_needle.py b/tests/test_needle.py index 7792ddc..fb228b6 100644 --- a/tests/test_needle.py +++ b/tests/test_needle.py @@ -31,8 +31,13 @@ def run_needle_test( max_new_tokens: int = 32, enable_cpu_offload: bool = False, enable_quest: bool = False, + enable_minference: bool = False, sparse_topk: int = 8, sparse_threshold: int = 4, + minference_budget: float = 0.3, + minference_vertical: int = 1000, + minference_slash: int = 6096, + gpu_utilization: float = 0.9, verbose: bool = True, ) -> bool: """ @@ -49,14 +54,25 @@ def run_needle_test( max_new_tokens: Maximum tokens to generate enable_cpu_offload: Enable CPU offload mode enable_quest: Enable Quest sparse attention (decode-only Top-K) + enable_minference: Enable MInference sparse prefill (GPU-only) sparse_topk: Top-K blocks for Quest sparse_threshold: Apply sparse only when blocks > threshold + minference_budget: MInference adaptive budget (fraction of seq_len, None=fixed mode) + minference_vertical: Fixed vertical_size (only used when budget=None) + minference_slash: Fixed slash_size (only used when budget=None) + gpu_utilization: GPU memory utilization fraction verbose: Print detailed output Returns: True if test passed, False otherwise """ - sparse_policy = SparsePolicyType.QUEST if enable_quest else SparsePolicyType.FULL + # Determine sparse policy + if enable_minference: + sparse_policy = SparsePolicyType.MINFERENCE + elif enable_quest: + sparse_policy = SparsePolicyType.QUEST + else: + sparse_policy = SparsePolicyType.FULL if verbose: print(f"\n{'='*60}") @@ -69,8 +85,14 @@ def run_needle_test( print(f"Needle position: {needle_position:.0%}") print(f"Needle value: {needle_value}") print(f"CPU offload: {enable_cpu_offload}") - if enable_cpu_offload: - print(f"Sparse policy: {sparse_policy.name} (topk={sparse_topk}, threshold={sparse_threshold})") + print(f"Sparse policy: {sparse_policy.name}") + if enable_cpu_offload and enable_quest: + print(f" Quest: topk={sparse_topk}, threshold={sparse_threshold}") + if enable_minference: + if minference_budget is not None: + print(f" MInference: adaptive (budget={minference_budget})") + else: + print(f" MInference: fixed (vertical={minference_vertical}, slash={minference_slash})") print(f"{'='*60}\n") # 1. Initialize LLM @@ -80,12 +102,19 @@ def run_needle_test( "max_num_batched_tokens": max_model_len, "enable_cpu_offload": enable_cpu_offload, "kvcache_block_size": block_size, + "gpu_memory_utilization": gpu_utilization, } if enable_cpu_offload: llm_kwargs["num_gpu_blocks"] = num_gpu_blocks llm_kwargs["sparse_policy"] = sparse_policy llm_kwargs["sparse_topk_blocks"] = sparse_topk llm_kwargs["sparse_threshold_blocks"] = sparse_threshold + elif enable_minference: + # MInference is GPU-only sparse prefill + llm_kwargs["sparse_policy"] = sparse_policy + llm_kwargs["minference_adaptive_budget"] = minference_budget + llm_kwargs["minference_vertical_size"] = minference_vertical + llm_kwargs["minference_slash_size"] = minference_slash llm = LLM(model_path, **llm_kwargs) @@ -186,6 +215,11 @@ if __name__ == "__main__": action="store_true", help="Enable Quest sparse attention (decode-only Top-K selection)" ) + parser.add_argument( + "--enable-minference", + action="store_true", + help="Enable MInference sparse prefill (GPU-only, vertical+slash pattern)" + ) parser.add_argument( "--sparse-topk", type=int, @@ -198,8 +232,35 @@ if __name__ == "__main__": default=4, help="Apply sparse only when blocks > threshold" ) + parser.add_argument( + "--minference-budget", + type=float, + default=0.3, + help="MInference adaptive budget (fraction of seq_len, 0.3=30%% compute, 0=fixed mode)" + ) + parser.add_argument( + "--minference-vertical", + type=int, + default=1000, + help="Fixed vertical_size (only used when budget=0)" + ) + parser.add_argument( + "--minference-slash", + type=int, + default=6096, + help="Fixed slash_size (only used when budget=0)" + ) + parser.add_argument( + "--gpu-utilization", + type=float, + default=0.9, + help="GPU memory utilization (default: 0.9)" + ) args = parser.parse_args() + # Convert budget=0 to None for fixed mode + minference_budget = args.minference_budget if args.minference_budget > 0 else None + passed = run_needle_test( model_path=args.model, max_model_len=args.max_model_len, @@ -211,8 +272,13 @@ if __name__ == "__main__": max_new_tokens=args.max_new_tokens, enable_cpu_offload=args.enable_offload, enable_quest=args.enable_quest, + enable_minference=args.enable_minference, sparse_topk=args.sparse_topk, sparse_threshold=args.sparse_threshold, + minference_budget=minference_budget, + minference_vertical=args.minference_vertical, + minference_slash=args.minference_slash, + gpu_utilization=args.gpu_utilization, verbose=True, )