[claudesquad] update from 'layer-prefill-1' on 08 Jan 26 03:36 CST
This commit is contained in:
178
bench.py
178
bench.py
@@ -2,6 +2,7 @@ import os
|
||||
import time
|
||||
from random import randint, seed
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from nanovllm.config import SparsePolicyType
|
||||
|
||||
|
||||
def bench_decode(llm, num_seqs, input_len, output_len):
|
||||
@@ -23,8 +24,8 @@ def bench_decode(llm, num_seqs, input_len, output_len):
|
||||
print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)")
|
||||
|
||||
|
||||
def bench_prefill(llm, num_seqs, input_len):
|
||||
"""Benchmark prefill performance"""
|
||||
def bench_prefill(llm, num_seqs, input_len, label=""):
|
||||
"""Benchmark prefill performance. Returns throughput."""
|
||||
seed(0)
|
||||
# Fixed length input, minimal output to focus on prefill
|
||||
prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
|
||||
@@ -35,7 +36,28 @@ def bench_prefill(llm, num_seqs, input_len):
|
||||
t = time.time() - t
|
||||
total_input_tokens = num_seqs * input_len
|
||||
throughput = total_input_tokens / t
|
||||
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
||||
label_str = f" ({label})" if label else ""
|
||||
print(f"[Prefill{label_str}] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
||||
return throughput
|
||||
|
||||
|
||||
def create_llm(path, max_len, enable_minference=False, minference_budget=0.3,
|
||||
minference_vertical=1000, minference_slash=6096,
|
||||
gpu_utilization=0.8):
|
||||
"""Create LLM with specified configuration."""
|
||||
kwargs = {
|
||||
"enforce_eager": True, # MInference uses Triton, not compatible with CUDA graphs
|
||||
"max_model_len": max_len,
|
||||
"max_num_batched_tokens": max_len,
|
||||
"gpu_memory_utilization": gpu_utilization,
|
||||
}
|
||||
if enable_minference:
|
||||
kwargs["sparse_policy"] = SparsePolicyType.MINFERENCE
|
||||
kwargs["minference_adaptive_budget"] = minference_budget
|
||||
kwargs["minference_vertical_size"] = minference_vertical
|
||||
kwargs["minference_slash_size"] = minference_slash
|
||||
|
||||
return LLM(path, **kwargs)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -46,24 +68,17 @@ def main():
|
||||
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
|
||||
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
|
||||
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
|
||||
parser.add_argument("--enable-minference", action="store_true", help="Enable MInference sparse prefill")
|
||||
parser.add_argument("--minference-budget", type=float, default=0.3, help="MInference adaptive budget (default: 0.3, use 0 for fixed mode)")
|
||||
parser.add_argument("--minference-vertical", type=int, default=1000, help="Fixed vertical_size (only used when budget=0)")
|
||||
parser.add_argument("--minference-slash", type=int, default=6096, help="Fixed slash_size (only used when budget=0)")
|
||||
parser.add_argument("--gpu-utilization", type=float, default=0.9, help="GPU memory utilization (default: 0.9)")
|
||||
parser.add_argument("--compare", action="store_true", help="Compare baseline vs MInference (runs both)")
|
||||
args = parser.parse_args()
|
||||
|
||||
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
||||
max_len = args.max_len
|
||||
|
||||
print(f"\n[nanovllm GPU] max_len={max_len}")
|
||||
|
||||
llm = LLM(
|
||||
path,
|
||||
enforce_eager=False,
|
||||
max_model_len=max_len,
|
||||
max_num_batched_tokens=max_len,
|
||||
)
|
||||
|
||||
# Warmup
|
||||
print("\nWarming up...")
|
||||
llm.generate(["Benchmark warmup: "], SamplingParams(max_tokens=10))
|
||||
|
||||
# Default input lengths
|
||||
prefill_input_len = args.input_len if args.input_len else max_len - 1
|
||||
decode_input_len = args.input_len if args.input_len else max_len - args.output_len
|
||||
@@ -72,17 +87,128 @@ def main():
|
||||
run_prefill = not args.bench_decode or args.bench_all
|
||||
run_decode = args.bench_decode or args.bench_all
|
||||
|
||||
if run_prefill:
|
||||
print("\n" + "=" * 60)
|
||||
print("Prefill Benchmark (nanovllm GPU)")
|
||||
print("=" * 60)
|
||||
bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
|
||||
# Convert budget=0 to None for fixed mode
|
||||
minference_budget = args.minference_budget if args.minference_budget > 0 else None
|
||||
|
||||
if run_decode:
|
||||
print("\n" + "=" * 60)
|
||||
print("Decode Benchmark (nanovllm GPU)")
|
||||
print("=" * 60)
|
||||
bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
|
||||
if args.compare:
|
||||
# Compare baseline vs MInference using subprocesses to avoid NCCL issues
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Baseline vs MInference Comparison")
|
||||
print(f"Input length: {prefill_input_len} tokens")
|
||||
if minference_budget is not None:
|
||||
print(f"MInference mode: adaptive (budget={minference_budget}, {minference_budget*100:.0f}% compute)")
|
||||
else:
|
||||
print(f"MInference mode: fixed (vertical={args.minference_vertical}, slash={args.minference_slash})")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Get PYTHONPATH for subprocess
|
||||
pythonpath = os.environ.get("PYTHONPATH", "")
|
||||
|
||||
# Run baseline in subprocess
|
||||
print(f"\n[1/2] Running baseline (FULL attention)...")
|
||||
cmd_baseline = [
|
||||
sys.executable, __file__,
|
||||
"--input-len", str(prefill_input_len),
|
||||
"--max-len", str(max_len),
|
||||
"--gpu-utilization", str(args.gpu_utilization),
|
||||
]
|
||||
env = os.environ.copy()
|
||||
result = subprocess.run(cmd_baseline, capture_output=True, text=True, env=env)
|
||||
print(result.stdout)
|
||||
if result.returncode != 0:
|
||||
print(f"Error: {result.stderr}")
|
||||
return
|
||||
|
||||
# Parse baseline throughput
|
||||
baseline_throughput = None
|
||||
for line in result.stdout.split('\n'):
|
||||
if "Throughput:" in line and "tok/s" in line:
|
||||
# Extract throughput value
|
||||
import re
|
||||
match = re.search(r'Throughput:\s*([\d.]+)tok/s', line)
|
||||
if match:
|
||||
baseline_throughput = float(match.group(1))
|
||||
|
||||
# Run MInference in subprocess
|
||||
if minference_budget is not None:
|
||||
print(f"\n[2/2] Running MInference (budget={minference_budget})...")
|
||||
else:
|
||||
print(f"\n[2/2] Running MInference (vertical={args.minference_vertical}, slash={args.minference_slash})...")
|
||||
cmd_minference = [
|
||||
sys.executable, __file__,
|
||||
"--input-len", str(prefill_input_len),
|
||||
"--max-len", str(max_len),
|
||||
"--gpu-utilization", str(args.gpu_utilization),
|
||||
"--enable-minference",
|
||||
"--minference-budget", str(args.minference_budget),
|
||||
"--minference-vertical", str(args.minference_vertical),
|
||||
"--minference-slash", str(args.minference_slash),
|
||||
]
|
||||
result = subprocess.run(cmd_minference, capture_output=True, text=True, env=env)
|
||||
print(result.stdout)
|
||||
if result.returncode != 0:
|
||||
print(f"Error: {result.stderr}")
|
||||
return
|
||||
|
||||
# Parse MInference throughput
|
||||
minference_throughput = None
|
||||
for line in result.stdout.split('\n'):
|
||||
if "Throughput:" in line and "tok/s" in line:
|
||||
import re
|
||||
match = re.search(r'Throughput:\s*([\d.]+)tok/s', line)
|
||||
if match:
|
||||
minference_throughput = float(match.group(1))
|
||||
|
||||
# Comparison
|
||||
if baseline_throughput and minference_throughput:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Results Summary")
|
||||
print(f"{'='*60}")
|
||||
print(f"Baseline: {baseline_throughput:,.0f} tok/s")
|
||||
print(f"MInference: {minference_throughput:,.0f} tok/s")
|
||||
speedup = minference_throughput / baseline_throughput
|
||||
if speedup >= 1.0:
|
||||
print(f"Speedup: {speedup:.2f}x faster")
|
||||
else:
|
||||
print(f"Slowdown: {1/speedup:.2f}x slower")
|
||||
print(f"{'='*60}")
|
||||
else:
|
||||
print("Failed to parse throughput values")
|
||||
|
||||
else:
|
||||
# Single run mode
|
||||
mode = "MInference" if args.enable_minference else "GPU"
|
||||
print(f"\n[nanovllm {mode}] max_len={max_len}")
|
||||
if args.enable_minference:
|
||||
if minference_budget is not None:
|
||||
print(f"MInference mode: adaptive (budget={minference_budget})")
|
||||
else:
|
||||
print(f"MInference mode: fixed (vertical={args.minference_vertical}, slash={args.minference_slash})")
|
||||
|
||||
llm = create_llm(path, max_len, enable_minference=args.enable_minference,
|
||||
minference_budget=minference_budget,
|
||||
minference_vertical=args.minference_vertical,
|
||||
minference_slash=args.minference_slash,
|
||||
gpu_utilization=args.gpu_utilization)
|
||||
|
||||
# Warmup
|
||||
print("\nWarming up...")
|
||||
llm.generate(["Benchmark warmup: "], SamplingParams(max_tokens=10))
|
||||
|
||||
if run_prefill:
|
||||
print("\n" + "=" * 60)
|
||||
print(f"Prefill Benchmark (nanovllm {mode})")
|
||||
print("=" * 60)
|
||||
bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
|
||||
|
||||
if run_decode:
|
||||
print("\n" + "=" * 60)
|
||||
print(f"Decode Benchmark (nanovllm {mode})")
|
||||
print("=" * 60)
|
||||
bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -9,6 +9,7 @@ class SparsePolicyType(Enum):
|
||||
"""Sparse attention policy types."""
|
||||
FULL = auto() # No sparse attention (load all blocks)
|
||||
QUEST = auto() # Query-aware Top-K block selection (decode only)
|
||||
MINFERENCE = auto() # MInference vertical + slash sparse prefill (GPU-only)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -39,10 +40,18 @@ class Config:
|
||||
# Sparse attention configuration
|
||||
# Quest: decode-only sparse attention with Top-K block selection
|
||||
# FULL: no sparse attention (load all blocks)
|
||||
# MINFERENCE: MInference vertical + slash sparse prefill (GPU-only)
|
||||
sparse_policy: SparsePolicyType = SparsePolicyType.FULL
|
||||
sparse_topk_blocks: int = 8 # Top-K blocks for Quest
|
||||
sparse_threshold_blocks: int = 4 # Apply sparse only when blocks > threshold
|
||||
|
||||
# MInference configuration (used when sparse_policy == MINFERENCE)
|
||||
minference_adaptive_budget: float = 0.3 # Budget as fraction of seq_len (None to use fixed sizes)
|
||||
minference_vertical_size: int = 1000 # Fixed vertical size (if adaptive_budget is None)
|
||||
minference_slash_size: int = 6096 # Fixed slash size (if adaptive_budget is None)
|
||||
minference_num_sink_tokens: int = 30 # Sink tokens to always keep
|
||||
minference_num_recent_diags: int = 100 # Recent diagonals to always keep
|
||||
|
||||
def __post_init__(self):
|
||||
assert os.path.isdir(self.model)
|
||||
assert self.kvcache_block_size % 256 == 0
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch.distributed as dist
|
||||
from multiprocessing.synchronize import Event
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
|
||||
from nanovllm.config import Config
|
||||
from nanovllm.config import Config, SparsePolicyType
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
from nanovllm.models.qwen3 import Qwen3ForCausalLM
|
||||
from nanovllm.layers.sampler import GreedySampler
|
||||
@@ -35,7 +35,10 @@ class ModelRunner:
|
||||
self.model = Qwen3ForCausalLM(hf_config)
|
||||
load_model(self.model, config.model)
|
||||
self.sampler = GreedySampler()
|
||||
|
||||
|
||||
# Initialize sparse_prefill_policy before warmup (will be configured in allocate_kv_cache)
|
||||
self.sparse_prefill_policy = None
|
||||
|
||||
#> Disable warmup for debugging
|
||||
self.warmup_model()
|
||||
|
||||
@@ -148,6 +151,24 @@ class ModelRunner:
|
||||
# Create KV cache manager using factory
|
||||
self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
|
||||
|
||||
# Create sparse prefill policy for GPU-only path
|
||||
# This is separate from CPU offload sparse policy (which uses select_blocks)
|
||||
self.sparse_prefill_policy = None
|
||||
if not config.enable_cpu_offload and config.sparse_policy != SparsePolicyType.FULL:
|
||||
from nanovllm.kvcache.sparse import create_sparse_policy
|
||||
policy = create_sparse_policy(
|
||||
config.sparse_policy,
|
||||
vertical_size=config.minference_vertical_size,
|
||||
slash_size=config.minference_slash_size,
|
||||
adaptive_budget=config.minference_adaptive_budget,
|
||||
num_sink_tokens=config.minference_num_sink_tokens,
|
||||
num_recent_diags=config.minference_num_recent_diags,
|
||||
)
|
||||
# Only use if policy supports sparse prefill
|
||||
if policy.supports_prefill:
|
||||
self.sparse_prefill_policy = policy
|
||||
logger.info(f"Sparse prefill policy enabled: {self.sparse_prefill_policy}")
|
||||
|
||||
# Allocate cache through manager
|
||||
self.kvcache_manager.allocate_cache(
|
||||
num_layers=hf_config.num_hidden_layers,
|
||||
@@ -329,7 +350,10 @@ class ModelRunner:
|
||||
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
|
||||
|
||||
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
slot_mapping, None, block_tables,
|
||||
sparse_prefill_policy=self.sparse_prefill_policy)
|
||||
return input_ids, positions
|
||||
|
||||
def prepare_decode(self, seqs: list[Sequence]):
|
||||
|
||||
@@ -23,6 +23,7 @@ from nanovllm.config import SparsePolicyType
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
||||
from nanovllm.kvcache.sparse.minference import MInferencePolicy
|
||||
|
||||
|
||||
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
|
||||
@@ -55,6 +56,15 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic
|
||||
)
|
||||
return QuestPolicy(config)
|
||||
|
||||
elif policy_type == SparsePolicyType.MINFERENCE:
|
||||
return MInferencePolicy(
|
||||
vertical_size=kwargs.get("vertical_size", 1000),
|
||||
slash_size=kwargs.get("slash_size", 6096),
|
||||
adaptive_budget=kwargs.get("adaptive_budget", 0.3),
|
||||
num_sink_tokens=kwargs.get("num_sink_tokens", 30),
|
||||
num_recent_diags=kwargs.get("num_recent_diags", 100),
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown policy type: {policy_type}")
|
||||
|
||||
@@ -67,5 +77,6 @@ __all__ = [
|
||||
"QuestPolicy",
|
||||
"QuestConfig",
|
||||
"BlockMetadataManager",
|
||||
"MInferencePolicy",
|
||||
"create_sparse_policy",
|
||||
]
|
||||
|
||||
353
nanovllm/kvcache/sparse/minference.py
Normal file
353
nanovllm/kvcache/sparse/minference.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
MInference sparse attention policy.
|
||||
|
||||
Implements vertical + slash sparse pattern estimation using the last 64 query tokens.
|
||||
Reference: MInference paper (https://arxiv.org/abs/2407.02490)
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Tuple, Optional
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||
|
||||
|
||||
class MInferencePolicy(SparsePolicy):
|
||||
"""
|
||||
MInference sparse prefill policy using vertical + slash pattern.
|
||||
|
||||
This policy estimates sparse attention patterns by analyzing attention
|
||||
scores from the last 64 query tokens, then selects:
|
||||
- Vertical: Key positions that are important across all queries
|
||||
- Slash: Diagonal bands (local context)
|
||||
|
||||
The estimated pattern is then used to compute sparse attention.
|
||||
|
||||
Note: This policy is designed for GPU-only prefill. For CPU offload,
|
||||
the pattern estimation and sparse attention will be handled differently.
|
||||
"""
|
||||
|
||||
supports_prefill = True
|
||||
supports_decode = False # MInference is prefill-only sparse strategy
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vertical_size: int = 1000,
|
||||
slash_size: int = 6096,
|
||||
adaptive_budget: Optional[float] = 0.3,
|
||||
num_sink_tokens: int = 30,
|
||||
num_recent_diags: int = 100,
|
||||
):
|
||||
"""
|
||||
Initialize MInference policy.
|
||||
|
||||
Args:
|
||||
vertical_size: Number of vertical (column) positions to keep
|
||||
slash_size: Number of diagonal bands to keep
|
||||
adaptive_budget: If set, compute budget as fraction of seq_len
|
||||
(overrides vertical_size and slash_size)
|
||||
num_sink_tokens: Number of initial sink tokens to always keep
|
||||
num_recent_diags: Number of recent diagonals to always keep
|
||||
"""
|
||||
self.vertical_size = vertical_size
|
||||
self.slash_size = slash_size
|
||||
self.adaptive_budget = adaptive_budget
|
||||
self.num_sink_tokens = num_sink_tokens
|
||||
self.num_recent_diags = num_recent_diags
|
||||
|
||||
# Cache for last-q causal mask
|
||||
self._last_q_mask_cache: dict = {}
|
||||
|
||||
def _get_causal_mask(self, last_q: int, seq_len: int, device: torch.device) -> torch.Tensor:
|
||||
"""Get causal mask for last-q attention."""
|
||||
cache_key = (last_q, seq_len, device)
|
||||
if cache_key not in self._last_q_mask_cache:
|
||||
# Create mask where last_q queries can attend to all previous positions
|
||||
# Shape: [last_q, seq_len]
|
||||
mask = torch.ones(last_q, seq_len, device=device, dtype=torch.bool)
|
||||
# Apply causal constraint for the last last_q positions
|
||||
# Query i (from last_q) can only attend to positions <= (seq_len - last_q + i)
|
||||
for i in range(last_q):
|
||||
mask[i, seq_len - last_q + i + 1:] = False
|
||||
self._last_q_mask_cache[cache_key] = mask
|
||||
return self._last_q_mask_cache[cache_key]
|
||||
|
||||
def estimate_pattern(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
layer_id: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Estimate vertical + slash sparse pattern using last 64 query tokens.
|
||||
Memory-optimized for long sequences (64K+).
|
||||
|
||||
Args:
|
||||
q: Query tensor [seq_len, num_heads, head_dim]
|
||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||
layer_id: Current layer index (for potential layer-specific patterns)
|
||||
|
||||
Returns:
|
||||
Tuple of (vertical_indices, slash_indices):
|
||||
- vertical_indices: [num_heads, vertical_size] - important K positions
|
||||
- slash_indices: [num_heads, slash_size] - diagonal offsets
|
||||
"""
|
||||
seq_len = q.shape[0]
|
||||
num_heads = q.shape[1]
|
||||
head_dim = q.shape[2]
|
||||
num_kv_heads = k.shape[1]
|
||||
|
||||
# Adaptive budget
|
||||
if self.adaptive_budget is not None:
|
||||
budget = int(seq_len * self.adaptive_budget)
|
||||
vertical_size = max(self.num_sink_tokens + 1, int(budget * 0.2))
|
||||
slash_size = max(self.num_recent_diags + 1, int(budget * 0.8))
|
||||
else:
|
||||
vertical_size = self.vertical_size
|
||||
slash_size = self.slash_size
|
||||
|
||||
# Use last 64 Q tokens for estimation
|
||||
last_q = min(64, seq_len)
|
||||
q_last = q[-last_q:] # [last_q, heads, dim] - this is a view, not a copy
|
||||
|
||||
# Handle GQA: if num_kv_heads < num_heads, we need to expand K
|
||||
if num_kv_heads < num_heads:
|
||||
num_groups = num_heads // num_kv_heads
|
||||
k_work = k.repeat_interleave(num_groups, dim=1)
|
||||
else:
|
||||
k_work = k
|
||||
|
||||
# Compute attention scores: [heads, last_q, seq_len]
|
||||
scale = 1.0 / math.sqrt(head_dim)
|
||||
qk = torch.einsum('qhd,khd->hqk', q_last, k_work) * scale
|
||||
|
||||
# Free k_work if it was a copy
|
||||
if num_kv_heads < num_heads:
|
||||
del k_work
|
||||
|
||||
# Apply causal mask for last positions (in-place)
|
||||
causal_mask = self._get_causal_mask(last_q, seq_len, q.device)
|
||||
qk.masked_fill_(~causal_mask.unsqueeze(0), float('-inf'))
|
||||
|
||||
# Softmax (in-place where possible)
|
||||
qk = F.softmax(qk, dim=-1, dtype=torch.float32)
|
||||
|
||||
# === Vertical pattern ===
|
||||
# Sum across query dimension -> importance of each K position
|
||||
vertical_scores = qk.sum(dim=1) # [heads, seq_len]
|
||||
|
||||
# Force keep first num_sink_tokens (attention sinks) - in-place
|
||||
vertical_scores[:, :self.num_sink_tokens] = float('inf')
|
||||
|
||||
# Select top-k
|
||||
actual_vertical = min(vertical_size, seq_len)
|
||||
vertical_indices = vertical_scores.topk(actual_vertical, dim=-1).indices
|
||||
vertical_indices = vertical_indices.sort(dim=-1).values
|
||||
del vertical_scores
|
||||
|
||||
# === Slash pattern ===
|
||||
# Create diagonal index matrix: [last_q, seq_len] with int32 to save memory
|
||||
q_indices = torch.arange(last_q, device=q.device, dtype=torch.int32).unsqueeze(1)
|
||||
k_indices = torch.arange(seq_len, device=q.device, dtype=torch.int32).unsqueeze(0)
|
||||
diag_indices = (seq_len - last_q + q_indices) - k_indices # [last_q, seq_len]
|
||||
del q_indices
|
||||
|
||||
# Create causal mask for slash computation
|
||||
q_pos = seq_len - last_q + torch.arange(last_q, device=q.device, dtype=torch.int32).unsqueeze(1)
|
||||
slash_causal_mask = k_indices <= q_pos
|
||||
del q_pos, k_indices
|
||||
|
||||
# Clamp diagonal indices to valid range
|
||||
diag_indices = diag_indices.clamp(0, seq_len - 1)
|
||||
|
||||
# Apply causal mask to qk (in-place) for slash computation
|
||||
qk[:, ~slash_causal_mask] = 0
|
||||
del slash_causal_mask
|
||||
|
||||
# Accumulate scores per diagonal - process in batches to save memory
|
||||
slash_scores = torch.zeros(num_heads, seq_len, device=q.device, dtype=torch.float32)
|
||||
|
||||
# Process heads in chunks to reduce peak memory for diag_indices_expanded
|
||||
chunk_size = min(8, num_heads) # Process 8 heads at a time
|
||||
for h_start in range(0, num_heads, chunk_size):
|
||||
h_end = min(h_start + chunk_size, num_heads)
|
||||
n_heads_chunk = h_end - h_start
|
||||
|
||||
# Expand diag_indices only for this chunk
|
||||
diag_chunk = diag_indices.unsqueeze(0).expand(n_heads_chunk, -1, -1).long()
|
||||
qk_chunk = qk[h_start:h_end]
|
||||
|
||||
slash_scores[h_start:h_end].scatter_add_(
|
||||
1,
|
||||
diag_chunk.reshape(n_heads_chunk, -1),
|
||||
qk_chunk.reshape(n_heads_chunk, -1)
|
||||
)
|
||||
del diag_chunk, qk_chunk
|
||||
|
||||
del diag_indices, qk
|
||||
|
||||
# Force keep first num_recent_diags (in-place)
|
||||
slash_scores[:, :self.num_recent_diags] = float('inf')
|
||||
|
||||
# Select top-k diagonal indices
|
||||
actual_slash = min(slash_size, seq_len)
|
||||
slash_indices = slash_scores.topk(actual_slash, dim=-1).indices
|
||||
slash_indices = slash_indices.sort(dim=-1).values
|
||||
del slash_scores
|
||||
|
||||
return vertical_indices, slash_indices
|
||||
|
||||
def select_blocks(
|
||||
self,
|
||||
available_blocks: List[int],
|
||||
ctx: PolicyContext,
|
||||
) -> List[int]:
|
||||
"""
|
||||
Select blocks for chunked CPU offload mode.
|
||||
|
||||
For MInference in GPU-only mode, this method is not used.
|
||||
In CPU offload mode, it would select blocks based on the sparse pattern.
|
||||
|
||||
For now, return all blocks (full attention fallback).
|
||||
"""
|
||||
# MInference pattern is computed in attention.forward()
|
||||
# For CPU offload integration (Phase B), this would use the pattern
|
||||
return available_blocks
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset policy state."""
|
||||
self._last_q_mask_cache.clear()
|
||||
|
||||
def sparse_prefill_attention(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer_id: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute MInference sparse attention for prefill.
|
||||
|
||||
Uses vertical + slash pattern to compute sparse attention efficiently.
|
||||
Memory-optimized to handle long sequences (64K+) by freeing intermediate tensors.
|
||||
|
||||
Args:
|
||||
q: Query tensor [seq_len, num_heads, head_dim]
|
||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||
layer_id: Current transformer layer index
|
||||
|
||||
Returns:
|
||||
Attention output [seq_len, num_heads, head_dim]
|
||||
"""
|
||||
from minference.ops.pit_sparse_flash_attention_v2 import _triton_mixed_sparse_attention
|
||||
from minference.cuda import convert_vertical_slash_indexes
|
||||
|
||||
seq_len = q.shape[0]
|
||||
num_heads = q.shape[1]
|
||||
head_dim = q.shape[2]
|
||||
num_kv_heads = k.shape[1]
|
||||
|
||||
# Estimate sparse pattern (uses temporary memory for qk scores)
|
||||
vertical_indices, slash_indices = self.estimate_pattern(q, k, layer_id)
|
||||
# Free any cached memory from pattern estimation
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Triton sparse attention kernel parameters
|
||||
block_size_M = 64
|
||||
block_size_N = 64
|
||||
|
||||
# Calculate padding
|
||||
pad = (block_size_M - seq_len) & (block_size_M - 1)
|
||||
need_head_pad = head_dim not in [16, 32, 64, 128, 256, 512]
|
||||
head_pad = (2 ** math.ceil(math.log2(head_dim)) - head_dim) if need_head_pad else 0
|
||||
|
||||
# Handle GQA: expand K/V to match query heads
|
||||
# Do this BEFORE creating batched tensors to avoid double copies
|
||||
if num_kv_heads < num_heads:
|
||||
num_groups = num_heads // num_kv_heads
|
||||
# Use repeat_interleave for memory-efficient expansion
|
||||
k_work = k.repeat_interleave(num_groups, dim=1)
|
||||
v_work = v.repeat_interleave(num_groups, dim=1)
|
||||
else:
|
||||
k_work = k
|
||||
v_work = v
|
||||
|
||||
# Transform Q to [batch, heads, seq, dim] format with padding in one step
|
||||
# This avoids creating intermediate copies
|
||||
if pad > 0 or head_pad > 0:
|
||||
q_batched = torch.nn.functional.pad(
|
||||
q.unsqueeze(0).transpose(1, 2),
|
||||
[0, head_pad, 0, pad, 0, 0, 0, 0]
|
||||
).contiguous()
|
||||
else:
|
||||
q_batched = q.unsqueeze(0).transpose(1, 2).contiguous()
|
||||
|
||||
# Transform K to batched format
|
||||
if pad > 0 or head_pad > 0:
|
||||
k_batched = torch.nn.functional.pad(
|
||||
k_work.unsqueeze(0).transpose(1, 2),
|
||||
[0, head_pad, 0, pad, 0, 0, 0, 0]
|
||||
).contiguous()
|
||||
else:
|
||||
k_batched = k_work.unsqueeze(0).transpose(1, 2).contiguous()
|
||||
|
||||
# Free k_work if it was a copy (GQA case)
|
||||
if num_kv_heads < num_heads:
|
||||
del k_work
|
||||
|
||||
# Transform V to batched format
|
||||
if pad > 0 or head_pad > 0:
|
||||
v_batched = torch.nn.functional.pad(
|
||||
v_work.unsqueeze(0).transpose(1, 2),
|
||||
[0, head_pad, 0, pad, 0, 0, 0, 0]
|
||||
).contiguous()
|
||||
else:
|
||||
v_batched = v_work.unsqueeze(0).transpose(1, 2).contiguous()
|
||||
|
||||
# Free v_work if it was a copy (GQA case)
|
||||
if num_kv_heads < num_heads:
|
||||
del v_work
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Prepare indices for Triton kernel
|
||||
v_idx = vertical_indices.to(torch.int32).reshape((1, num_heads, -1))
|
||||
v_idx = v_idx.sort(dim=-1, descending=False)[0].contiguous()
|
||||
del vertical_indices
|
||||
|
||||
s_idx = slash_indices.to(torch.int32).reshape((1, num_heads, -1))
|
||||
s_idx = s_idx.sort(dim=-1, descending=True)[0].contiguous()
|
||||
del slash_indices
|
||||
|
||||
seqlens = torch.tensor([seq_len], dtype=torch.int32, device=q.device)
|
||||
sm_scale = head_dim ** -0.5
|
||||
|
||||
# Convert vertical+slash indices to block sparse format
|
||||
block_count, block_offset, column_count, column_index = convert_vertical_slash_indexes(
|
||||
seqlens, v_idx, s_idx, seq_len, block_size_M, block_size_N,
|
||||
)
|
||||
del v_idx, s_idx
|
||||
|
||||
# Call Triton mixed sparse attention kernel
|
||||
o = _triton_mixed_sparse_attention(
|
||||
q_batched, k_batched, v_batched, seqlens,
|
||||
block_count, block_offset, column_count, column_index,
|
||||
sm_scale, block_size_M, block_size_N,
|
||||
)
|
||||
|
||||
# Free input tensors immediately after kernel call
|
||||
del q_batched, k_batched, v_batched
|
||||
del block_count, block_offset, column_count, column_index
|
||||
|
||||
# Remove padding and convert back to [seq_len, num_heads, head_dim]
|
||||
o = o[..., :seq_len, :head_dim]
|
||||
o = o.transpose(1, 2).squeeze(0).contiguous()
|
||||
|
||||
return o
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"MInferencePolicy("
|
||||
f"adaptive_budget={self.adaptive_budget}, "
|
||||
f"vertical_size={self.vertical_size}, "
|
||||
f"slash_size={self.slash_size})")
|
||||
@@ -183,5 +183,32 @@ class SparsePolicy(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def sparse_prefill_attention(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer_id: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute sparse attention for prefill phase.
|
||||
|
||||
This method is called when supports_prefill=True and the policy
|
||||
is used for GPU-only sparse prefill (no CPU offload).
|
||||
|
||||
Args:
|
||||
q: Query tensor [seq_len, num_heads, head_dim]
|
||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||
layer_id: Current transformer layer index
|
||||
|
||||
Returns:
|
||||
Attention output [seq_len, num_heads, head_dim]
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement sparse_prefill_attention. "
|
||||
"Set supports_prefill=False or implement this method."
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}()"
|
||||
|
||||
@@ -140,6 +140,11 @@ class Attention(nn.Module):
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||
elif context.sparse_prefill_policy is not None:
|
||||
# Sparse prefill (GPU-only) - delegate to policy
|
||||
o = context.sparse_prefill_policy.sparse_prefill_attention(
|
||||
q, k, v, self.layer_id
|
||||
)
|
||||
else:
|
||||
o = flash_attn_varlen_func(q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
|
||||
@@ -35,6 +35,10 @@ class Context:
|
||||
# Current chunk index for ring buffer pipeline (prefill only)
|
||||
current_chunk_idx: int = 0
|
||||
|
||||
# Sparse prefill attention support (GPU-only path)
|
||||
# When set, uses policy.sparse_prefill_attention() instead of FlashAttention
|
||||
sparse_prefill_policy: Any = None # SparsePolicy instance with supports_prefill=True
|
||||
|
||||
|
||||
_CONTEXT = Context()
|
||||
|
||||
@@ -60,6 +64,7 @@ def set_context(
|
||||
decode_pos_in_block=0,
|
||||
decode_start_pos_in_block=0,
|
||||
current_chunk_idx=0,
|
||||
sparse_prefill_policy=None,
|
||||
):
|
||||
global _CONTEXT
|
||||
_CONTEXT = Context(
|
||||
@@ -79,6 +84,7 @@ def set_context(
|
||||
decode_pos_in_block=decode_pos_in_block,
|
||||
decode_start_pos_in_block=decode_start_pos_in_block,
|
||||
current_chunk_idx=current_chunk_idx,
|
||||
sparse_prefill_policy=sparse_prefill_policy,
|
||||
)
|
||||
|
||||
|
||||
|
||||
163
tests/test_minference_gpu.py
Normal file
163
tests/test_minference_gpu.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
Needle-in-haystack test with MInference sparse attention.
|
||||
|
||||
Tests: MInference sparse prefill on GPU-only path (no CPU offload).
|
||||
This validates that MInference's vertical + slash sparse pattern can
|
||||
correctly retrieve information from long context.
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
|
||||
|
||||
import argparse
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from nanovllm.config import SparsePolicyType
|
||||
from utils import generate_needle_prompt, check_needle_answer
|
||||
|
||||
|
||||
def run_minference_test(
|
||||
model_path: str,
|
||||
max_model_len: int = 16384,
|
||||
input_len: int = 8192,
|
||||
needle_position: float = 0.5,
|
||||
needle_value: str = "7492",
|
||||
adaptive_budget: float = 0.3,
|
||||
max_new_tokens: int = 32,
|
||||
verbose: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
Run needle test with MInference sparse prefill attention.
|
||||
|
||||
Args:
|
||||
model_path: Path to model
|
||||
max_model_len: Maximum model context length
|
||||
input_len: Target input sequence length
|
||||
needle_position: Where to place needle (0.0-1.0)
|
||||
needle_value: The secret value to find
|
||||
adaptive_budget: MInference budget as fraction of seq_len
|
||||
max_new_tokens: Maximum tokens to generate
|
||||
verbose: Print detailed output
|
||||
|
||||
Returns:
|
||||
True if test passed, False otherwise
|
||||
"""
|
||||
if verbose:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"MInference Sparse Prefill Test (GPU-only)")
|
||||
print(f"{'='*60}")
|
||||
print(f"Model: {model_path}")
|
||||
print(f"Max model len: {max_model_len}")
|
||||
print(f"Input length: {input_len}")
|
||||
print(f"Needle position: {needle_position:.0%}")
|
||||
print(f"Needle value: {needle_value}")
|
||||
print(f"Adaptive budget: {adaptive_budget}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Initialize LLM with MInference sparse attention
|
||||
llm = LLM(
|
||||
model_path,
|
||||
enforce_eager=True,
|
||||
max_model_len=max_model_len,
|
||||
max_num_batched_tokens=max_model_len,
|
||||
enable_cpu_offload=False, # GPU-only
|
||||
sparse_policy=SparsePolicyType.MINFERENCE,
|
||||
minference_adaptive_budget=adaptive_budget,
|
||||
)
|
||||
|
||||
# Generate needle prompt
|
||||
prompt, expected = generate_needle_prompt(
|
||||
tokenizer=llm.tokenizer,
|
||||
target_length=input_len,
|
||||
needle_position=needle_position,
|
||||
needle_value=needle_value,
|
||||
)
|
||||
|
||||
# Generate output
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.6,
|
||||
max_tokens=max_new_tokens,
|
||||
)
|
||||
outputs = llm.generate([prompt], sampling_params, use_tqdm=True)
|
||||
|
||||
# Check result
|
||||
output_text = outputs[0]["text"]
|
||||
output_token_ids = outputs[0]["token_ids"]
|
||||
passed = check_needle_answer(output_text, expected)
|
||||
|
||||
if verbose:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Result")
|
||||
print(f"{'='*60}")
|
||||
print(f"Expected: {expected}")
|
||||
print(f"Output tokens ({len(output_token_ids)}): {output_token_ids[:20]}")
|
||||
print(f"Output: {output_text[:200]}...")
|
||||
print(f"Status: {'PASSED' if passed else 'FAILED'}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
return passed
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Needle-in-haystack test with MInference sparse prefill"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", "-m",
|
||||
type=str,
|
||||
default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
|
||||
help="Path to model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-model-len",
|
||||
type=int,
|
||||
default=16 * 1024,
|
||||
help="Maximum model context length"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-len",
|
||||
type=int,
|
||||
default=8 * 1024,
|
||||
help="Target input sequence length"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--needle-position",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--needle-value",
|
||||
type=str,
|
||||
default="7492",
|
||||
help="The secret value to hide"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adaptive-budget",
|
||||
type=float,
|
||||
default=0.3,
|
||||
help="MInference adaptive budget (fraction of seq_len)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-new-tokens",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Maximum tokens to generate"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
passed = run_minference_test(
|
||||
model_path=args.model,
|
||||
max_model_len=args.max_model_len,
|
||||
input_len=args.input_len,
|
||||
needle_position=args.needle_position,
|
||||
needle_value=args.needle_value,
|
||||
adaptive_budget=args.adaptive_budget,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
if passed:
|
||||
print("test_minference_gpu: PASSED")
|
||||
else:
|
||||
print("test_minference_gpu: FAILED")
|
||||
exit(1)
|
||||
@@ -31,8 +31,13 @@ def run_needle_test(
|
||||
max_new_tokens: int = 32,
|
||||
enable_cpu_offload: bool = False,
|
||||
enable_quest: bool = False,
|
||||
enable_minference: bool = False,
|
||||
sparse_topk: int = 8,
|
||||
sparse_threshold: int = 4,
|
||||
minference_budget: float = 0.3,
|
||||
minference_vertical: int = 1000,
|
||||
minference_slash: int = 6096,
|
||||
gpu_utilization: float = 0.9,
|
||||
verbose: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
@@ -49,14 +54,25 @@ def run_needle_test(
|
||||
max_new_tokens: Maximum tokens to generate
|
||||
enable_cpu_offload: Enable CPU offload mode
|
||||
enable_quest: Enable Quest sparse attention (decode-only Top-K)
|
||||
enable_minference: Enable MInference sparse prefill (GPU-only)
|
||||
sparse_topk: Top-K blocks for Quest
|
||||
sparse_threshold: Apply sparse only when blocks > threshold
|
||||
minference_budget: MInference adaptive budget (fraction of seq_len, None=fixed mode)
|
||||
minference_vertical: Fixed vertical_size (only used when budget=None)
|
||||
minference_slash: Fixed slash_size (only used when budget=None)
|
||||
gpu_utilization: GPU memory utilization fraction
|
||||
verbose: Print detailed output
|
||||
|
||||
Returns:
|
||||
True if test passed, False otherwise
|
||||
"""
|
||||
sparse_policy = SparsePolicyType.QUEST if enable_quest else SparsePolicyType.FULL
|
||||
# Determine sparse policy
|
||||
if enable_minference:
|
||||
sparse_policy = SparsePolicyType.MINFERENCE
|
||||
elif enable_quest:
|
||||
sparse_policy = SparsePolicyType.QUEST
|
||||
else:
|
||||
sparse_policy = SparsePolicyType.FULL
|
||||
|
||||
if verbose:
|
||||
print(f"\n{'='*60}")
|
||||
@@ -69,8 +85,14 @@ def run_needle_test(
|
||||
print(f"Needle position: {needle_position:.0%}")
|
||||
print(f"Needle value: {needle_value}")
|
||||
print(f"CPU offload: {enable_cpu_offload}")
|
||||
if enable_cpu_offload:
|
||||
print(f"Sparse policy: {sparse_policy.name} (topk={sparse_topk}, threshold={sparse_threshold})")
|
||||
print(f"Sparse policy: {sparse_policy.name}")
|
||||
if enable_cpu_offload and enable_quest:
|
||||
print(f" Quest: topk={sparse_topk}, threshold={sparse_threshold}")
|
||||
if enable_minference:
|
||||
if minference_budget is not None:
|
||||
print(f" MInference: adaptive (budget={minference_budget})")
|
||||
else:
|
||||
print(f" MInference: fixed (vertical={minference_vertical}, slash={minference_slash})")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# 1. Initialize LLM
|
||||
@@ -80,12 +102,19 @@ def run_needle_test(
|
||||
"max_num_batched_tokens": max_model_len,
|
||||
"enable_cpu_offload": enable_cpu_offload,
|
||||
"kvcache_block_size": block_size,
|
||||
"gpu_memory_utilization": gpu_utilization,
|
||||
}
|
||||
if enable_cpu_offload:
|
||||
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
||||
llm_kwargs["sparse_policy"] = sparse_policy
|
||||
llm_kwargs["sparse_topk_blocks"] = sparse_topk
|
||||
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
|
||||
elif enable_minference:
|
||||
# MInference is GPU-only sparse prefill
|
||||
llm_kwargs["sparse_policy"] = sparse_policy
|
||||
llm_kwargs["minference_adaptive_budget"] = minference_budget
|
||||
llm_kwargs["minference_vertical_size"] = minference_vertical
|
||||
llm_kwargs["minference_slash_size"] = minference_slash
|
||||
|
||||
llm = LLM(model_path, **llm_kwargs)
|
||||
|
||||
@@ -186,6 +215,11 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Enable Quest sparse attention (decode-only Top-K selection)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-minference",
|
||||
action="store_true",
|
||||
help="Enable MInference sparse prefill (GPU-only, vertical+slash pattern)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sparse-topk",
|
||||
type=int,
|
||||
@@ -198,8 +232,35 @@ if __name__ == "__main__":
|
||||
default=4,
|
||||
help="Apply sparse only when blocks > threshold"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--minference-budget",
|
||||
type=float,
|
||||
default=0.3,
|
||||
help="MInference adaptive budget (fraction of seq_len, 0.3=30%% compute, 0=fixed mode)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--minference-vertical",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Fixed vertical_size (only used when budget=0)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--minference-slash",
|
||||
type=int,
|
||||
default=6096,
|
||||
help="Fixed slash_size (only used when budget=0)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-utilization",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="GPU memory utilization (default: 0.9)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert budget=0 to None for fixed mode
|
||||
minference_budget = args.minference_budget if args.minference_budget > 0 else None
|
||||
|
||||
passed = run_needle_test(
|
||||
model_path=args.model,
|
||||
max_model_len=args.max_model_len,
|
||||
@@ -211,8 +272,13 @@ if __name__ == "__main__":
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
enable_cpu_offload=args.enable_offload,
|
||||
enable_quest=args.enable_quest,
|
||||
enable_minference=args.enable_minference,
|
||||
sparse_topk=args.sparse_topk,
|
||||
sparse_threshold=args.sparse_threshold,
|
||||
minference_budget=minference_budget,
|
||||
minference_vertical=args.minference_vertical,
|
||||
minference_slash=args.minference_slash,
|
||||
gpu_utilization=args.gpu_utilization,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user