Merge branch 'zijie/layer-prefill-1' into tzj/vs_offload

Adds MInference sparse attention support:
- New MInference sparse policy implementation
- A-shape, vertical-slash, and block-sparse patterns
- Updated bench.py with sparse attention options
- test_minference_gpu.py validation test

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-08 03:40:53 +08:00
10 changed files with 822 additions and 32 deletions

178
bench.py
View File

@@ -2,6 +2,7 @@ import os
import time import time
from random import randint, seed from random import randint, seed
from nanovllm import LLM, SamplingParams from nanovllm import LLM, SamplingParams
from nanovllm.config import SparsePolicyType
def bench_decode(llm, num_seqs, input_len, output_len): def bench_decode(llm, num_seqs, input_len, output_len):
@@ -23,8 +24,8 @@ def bench_decode(llm, num_seqs, input_len, output_len):
print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)") print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)")
def bench_prefill(llm, num_seqs, input_len): def bench_prefill(llm, num_seqs, input_len, label=""):
"""Benchmark prefill performance""" """Benchmark prefill performance. Returns throughput."""
seed(0) seed(0)
# Fixed length input, minimal output to focus on prefill # Fixed length input, minimal output to focus on prefill
prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)] prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
@@ -35,7 +36,28 @@ def bench_prefill(llm, num_seqs, input_len):
t = time.time() - t t = time.time() - t
total_input_tokens = num_seqs * input_len total_input_tokens = num_seqs * input_len
throughput = total_input_tokens / t throughput = total_input_tokens / t
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s") label_str = f" ({label})" if label else ""
print(f"[Prefill{label_str}] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
return throughput
def create_llm(path, max_len, enable_minference=False, minference_budget=0.3,
minference_vertical=1000, minference_slash=6096,
gpu_utilization=0.8):
"""Create LLM with specified configuration."""
kwargs = {
"enforce_eager": True, # MInference uses Triton, not compatible with CUDA graphs
"max_model_len": max_len,
"max_num_batched_tokens": max_len,
"gpu_memory_utilization": gpu_utilization,
}
if enable_minference:
kwargs["sparse_policy"] = SparsePolicyType.MINFERENCE
kwargs["minference_adaptive_budget"] = minference_budget
kwargs["minference_vertical_size"] = minference_vertical
kwargs["minference_slash_size"] = minference_slash
return LLM(path, **kwargs)
def main(): def main():
@@ -46,24 +68,17 @@ def main():
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)") parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)") parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks") parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
parser.add_argument("--enable-minference", action="store_true", help="Enable MInference sparse prefill")
parser.add_argument("--minference-budget", type=float, default=0.3, help="MInference adaptive budget (default: 0.3, use 0 for fixed mode)")
parser.add_argument("--minference-vertical", type=int, default=1000, help="Fixed vertical_size (only used when budget=0)")
parser.add_argument("--minference-slash", type=int, default=6096, help="Fixed slash_size (only used when budget=0)")
parser.add_argument("--gpu-utilization", type=float, default=0.9, help="GPU memory utilization (default: 0.9)")
parser.add_argument("--compare", action="store_true", help="Compare baseline vs MInference (runs both)")
args = parser.parse_args() args = parser.parse_args()
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
max_len = args.max_len max_len = args.max_len
print(f"\n[nanovllm GPU] max_len={max_len}")
llm = LLM(
path,
enforce_eager=False,
max_model_len=max_len,
max_num_batched_tokens=max_len,
)
# Warmup
print("\nWarming up...")
llm.generate(["Benchmark warmup: "], SamplingParams(max_tokens=10))
# Default input lengths # Default input lengths
prefill_input_len = args.input_len if args.input_len else max_len - 1 prefill_input_len = args.input_len if args.input_len else max_len - 1
decode_input_len = args.input_len if args.input_len else max_len - args.output_len decode_input_len = args.input_len if args.input_len else max_len - args.output_len
@@ -72,17 +87,128 @@ def main():
run_prefill = not args.bench_decode or args.bench_all run_prefill = not args.bench_decode or args.bench_all
run_decode = args.bench_decode or args.bench_all run_decode = args.bench_decode or args.bench_all
if run_prefill: # Convert budget=0 to None for fixed mode
print("\n" + "=" * 60) minference_budget = args.minference_budget if args.minference_budget > 0 else None
print("Prefill Benchmark (nanovllm GPU)")
print("=" * 60)
bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
if run_decode: if args.compare:
print("\n" + "=" * 60) # Compare baseline vs MInference using subprocesses to avoid NCCL issues
print("Decode Benchmark (nanovllm GPU)") import subprocess
print("=" * 60) import sys
bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
print(f"\n{'='*60}")
print(f"Baseline vs MInference Comparison")
print(f"Input length: {prefill_input_len} tokens")
if minference_budget is not None:
print(f"MInference mode: adaptive (budget={minference_budget}, {minference_budget*100:.0f}% compute)")
else:
print(f"MInference mode: fixed (vertical={args.minference_vertical}, slash={args.minference_slash})")
print(f"{'='*60}")
# Get PYTHONPATH for subprocess
pythonpath = os.environ.get("PYTHONPATH", "")
# Run baseline in subprocess
print(f"\n[1/2] Running baseline (FULL attention)...")
cmd_baseline = [
sys.executable, __file__,
"--input-len", str(prefill_input_len),
"--max-len", str(max_len),
"--gpu-utilization", str(args.gpu_utilization),
]
env = os.environ.copy()
result = subprocess.run(cmd_baseline, capture_output=True, text=True, env=env)
print(result.stdout)
if result.returncode != 0:
print(f"Error: {result.stderr}")
return
# Parse baseline throughput
baseline_throughput = None
for line in result.stdout.split('\n'):
if "Throughput:" in line and "tok/s" in line:
# Extract throughput value
import re
match = re.search(r'Throughput:\s*([\d.]+)tok/s', line)
if match:
baseline_throughput = float(match.group(1))
# Run MInference in subprocess
if minference_budget is not None:
print(f"\n[2/2] Running MInference (budget={minference_budget})...")
else:
print(f"\n[2/2] Running MInference (vertical={args.minference_vertical}, slash={args.minference_slash})...")
cmd_minference = [
sys.executable, __file__,
"--input-len", str(prefill_input_len),
"--max-len", str(max_len),
"--gpu-utilization", str(args.gpu_utilization),
"--enable-minference",
"--minference-budget", str(args.minference_budget),
"--minference-vertical", str(args.minference_vertical),
"--minference-slash", str(args.minference_slash),
]
result = subprocess.run(cmd_minference, capture_output=True, text=True, env=env)
print(result.stdout)
if result.returncode != 0:
print(f"Error: {result.stderr}")
return
# Parse MInference throughput
minference_throughput = None
for line in result.stdout.split('\n'):
if "Throughput:" in line and "tok/s" in line:
import re
match = re.search(r'Throughput:\s*([\d.]+)tok/s', line)
if match:
minference_throughput = float(match.group(1))
# Comparison
if baseline_throughput and minference_throughput:
print(f"\n{'='*60}")
print(f"Results Summary")
print(f"{'='*60}")
print(f"Baseline: {baseline_throughput:,.0f} tok/s")
print(f"MInference: {minference_throughput:,.0f} tok/s")
speedup = minference_throughput / baseline_throughput
if speedup >= 1.0:
print(f"Speedup: {speedup:.2f}x faster")
else:
print(f"Slowdown: {1/speedup:.2f}x slower")
print(f"{'='*60}")
else:
print("Failed to parse throughput values")
else:
# Single run mode
mode = "MInference" if args.enable_minference else "GPU"
print(f"\n[nanovllm {mode}] max_len={max_len}")
if args.enable_minference:
if minference_budget is not None:
print(f"MInference mode: adaptive (budget={minference_budget})")
else:
print(f"MInference mode: fixed (vertical={args.minference_vertical}, slash={args.minference_slash})")
llm = create_llm(path, max_len, enable_minference=args.enable_minference,
minference_budget=minference_budget,
minference_vertical=args.minference_vertical,
minference_slash=args.minference_slash,
gpu_utilization=args.gpu_utilization)
# Warmup
print("\nWarming up...")
llm.generate(["Benchmark warmup: "], SamplingParams(max_tokens=10))
if run_prefill:
print("\n" + "=" * 60)
print(f"Prefill Benchmark (nanovllm {mode})")
print("=" * 60)
bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
if run_decode:
print("\n" + "=" * 60)
print(f"Decode Benchmark (nanovllm {mode})")
print("=" * 60)
bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -9,6 +9,7 @@ class SparsePolicyType(Enum):
"""Sparse attention policy types.""" """Sparse attention policy types."""
FULL = auto() # No sparse attention (load all blocks) FULL = auto() # No sparse attention (load all blocks)
QUEST = auto() # Query-aware Top-K block selection (decode only) QUEST = auto() # Query-aware Top-K block selection (decode only)
MINFERENCE = auto() # MInference vertical + slash sparse prefill (GPU-only)
@dataclass @dataclass
@@ -39,10 +40,18 @@ class Config:
# Sparse attention configuration # Sparse attention configuration
# Quest: decode-only sparse attention with Top-K block selection # Quest: decode-only sparse attention with Top-K block selection
# FULL: no sparse attention (load all blocks) # FULL: no sparse attention (load all blocks)
# MINFERENCE: MInference vertical + slash sparse prefill (GPU-only)
sparse_policy: SparsePolicyType = SparsePolicyType.FULL sparse_policy: SparsePolicyType = SparsePolicyType.FULL
sparse_topk_blocks: int = 8 # Top-K blocks for Quest sparse_topk_blocks: int = 8 # Top-K blocks for Quest
sparse_threshold_blocks: int = 4 # Apply sparse only when blocks > threshold sparse_threshold_blocks: int = 4 # Apply sparse only when blocks > threshold
# MInference configuration (used when sparse_policy == MINFERENCE)
minference_adaptive_budget: float = 0.3 # Budget as fraction of seq_len (None to use fixed sizes)
minference_vertical_size: int = 1000 # Fixed vertical size (if adaptive_budget is None)
minference_slash_size: int = 6096 # Fixed slash size (if adaptive_budget is None)
minference_num_sink_tokens: int = 30 # Sink tokens to always keep
minference_num_recent_diags: int = 100 # Recent diagonals to always keep
def __post_init__(self): def __post_init__(self):
assert os.path.isdir(self.model) assert os.path.isdir(self.model)
assert self.kvcache_block_size % 256 == 0 assert self.kvcache_block_size % 256 == 0

View File

@@ -4,7 +4,7 @@ import torch.distributed as dist
from multiprocessing.synchronize import Event from multiprocessing.synchronize import Event
from multiprocessing.shared_memory import SharedMemory from multiprocessing.shared_memory import SharedMemory
from nanovllm.config import Config from nanovllm.config import Config, SparsePolicyType
from nanovllm.engine.sequence import Sequence from nanovllm.engine.sequence import Sequence
from nanovllm.models.qwen3 import Qwen3ForCausalLM from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.layers.sampler import GreedySampler from nanovllm.layers.sampler import GreedySampler
@@ -35,7 +35,10 @@ class ModelRunner:
self.model = Qwen3ForCausalLM(hf_config) self.model = Qwen3ForCausalLM(hf_config)
load_model(self.model, config.model) load_model(self.model, config.model)
self.sampler = GreedySampler() self.sampler = GreedySampler()
# Initialize sparse_prefill_policy before warmup (will be configured in allocate_kv_cache)
self.sparse_prefill_policy = None
#> Disable warmup for debugging #> Disable warmup for debugging
self.warmup_model() self.warmup_model()
@@ -148,6 +151,24 @@ class ModelRunner:
# Create KV cache manager using factory # Create KV cache manager using factory
self.kvcache_manager: KVCacheManager = create_kvcache_manager(config) self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
# Create sparse prefill policy for GPU-only path
# This is separate from CPU offload sparse policy (which uses select_blocks)
self.sparse_prefill_policy = None
if not config.enable_cpu_offload and config.sparse_policy != SparsePolicyType.FULL:
from nanovllm.kvcache.sparse import create_sparse_policy
policy = create_sparse_policy(
config.sparse_policy,
vertical_size=config.minference_vertical_size,
slash_size=config.minference_slash_size,
adaptive_budget=config.minference_adaptive_budget,
num_sink_tokens=config.minference_num_sink_tokens,
num_recent_diags=config.minference_num_recent_diags,
)
# Only use if policy supports sparse prefill
if policy.supports_prefill:
self.sparse_prefill_policy = policy
logger.info(f"Sparse prefill policy enabled: {self.sparse_prefill_policy}")
# Allocate cache through manager # Allocate cache through manager
self.kvcache_manager.allocate_cache( self.kvcache_manager.allocate_cache(
num_layers=hf_config.num_hidden_layers, num_layers=hf_config.num_hidden_layers,
@@ -329,7 +350,10 @@ class ModelRunner:
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
slot_mapping, None, block_tables,
sparse_prefill_policy=self.sparse_prefill_policy)
return input_ids, positions return input_ids, positions
def prepare_decode(self, seqs: list[Sequence]): def prepare_decode(self, seqs: list[Sequence]):

View File

@@ -23,6 +23,7 @@ from nanovllm.config import SparsePolicyType
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.minference import MInferencePolicy
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy: def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
@@ -55,6 +56,15 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic
) )
return QuestPolicy(config) return QuestPolicy(config)
elif policy_type == SparsePolicyType.MINFERENCE:
return MInferencePolicy(
vertical_size=kwargs.get("vertical_size", 1000),
slash_size=kwargs.get("slash_size", 6096),
adaptive_budget=kwargs.get("adaptive_budget", 0.3),
num_sink_tokens=kwargs.get("num_sink_tokens", 30),
num_recent_diags=kwargs.get("num_recent_diags", 100),
)
else: else:
raise ValueError(f"Unknown policy type: {policy_type}") raise ValueError(f"Unknown policy type: {policy_type}")
@@ -67,5 +77,6 @@ __all__ = [
"QuestPolicy", "QuestPolicy",
"QuestConfig", "QuestConfig",
"BlockMetadataManager", "BlockMetadataManager",
"MInferencePolicy",
"create_sparse_policy", "create_sparse_policy",
] ]

View File

@@ -0,0 +1,353 @@
"""
MInference sparse attention policy.
Implements vertical + slash sparse pattern estimation using the last 64 query tokens.
Reference: MInference paper (https://arxiv.org/abs/2407.02490)
"""
import math
from typing import List, Tuple, Optional
import torch
import torch.nn.functional as F
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
class MInferencePolicy(SparsePolicy):
"""
MInference sparse prefill policy using vertical + slash pattern.
This policy estimates sparse attention patterns by analyzing attention
scores from the last 64 query tokens, then selects:
- Vertical: Key positions that are important across all queries
- Slash: Diagonal bands (local context)
The estimated pattern is then used to compute sparse attention.
Note: This policy is designed for GPU-only prefill. For CPU offload,
the pattern estimation and sparse attention will be handled differently.
"""
supports_prefill = True
supports_decode = False # MInference is prefill-only sparse strategy
def __init__(
self,
vertical_size: int = 1000,
slash_size: int = 6096,
adaptive_budget: Optional[float] = 0.3,
num_sink_tokens: int = 30,
num_recent_diags: int = 100,
):
"""
Initialize MInference policy.
Args:
vertical_size: Number of vertical (column) positions to keep
slash_size: Number of diagonal bands to keep
adaptive_budget: If set, compute budget as fraction of seq_len
(overrides vertical_size and slash_size)
num_sink_tokens: Number of initial sink tokens to always keep
num_recent_diags: Number of recent diagonals to always keep
"""
self.vertical_size = vertical_size
self.slash_size = slash_size
self.adaptive_budget = adaptive_budget
self.num_sink_tokens = num_sink_tokens
self.num_recent_diags = num_recent_diags
# Cache for last-q causal mask
self._last_q_mask_cache: dict = {}
def _get_causal_mask(self, last_q: int, seq_len: int, device: torch.device) -> torch.Tensor:
"""Get causal mask for last-q attention."""
cache_key = (last_q, seq_len, device)
if cache_key not in self._last_q_mask_cache:
# Create mask where last_q queries can attend to all previous positions
# Shape: [last_q, seq_len]
mask = torch.ones(last_q, seq_len, device=device, dtype=torch.bool)
# Apply causal constraint for the last last_q positions
# Query i (from last_q) can only attend to positions <= (seq_len - last_q + i)
for i in range(last_q):
mask[i, seq_len - last_q + i + 1:] = False
self._last_q_mask_cache[cache_key] = mask
return self._last_q_mask_cache[cache_key]
def estimate_pattern(
self,
q: torch.Tensor,
k: torch.Tensor,
layer_id: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Estimate vertical + slash sparse pattern using last 64 query tokens.
Memory-optimized for long sequences (64K+).
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
layer_id: Current layer index (for potential layer-specific patterns)
Returns:
Tuple of (vertical_indices, slash_indices):
- vertical_indices: [num_heads, vertical_size] - important K positions
- slash_indices: [num_heads, slash_size] - diagonal offsets
"""
seq_len = q.shape[0]
num_heads = q.shape[1]
head_dim = q.shape[2]
num_kv_heads = k.shape[1]
# Adaptive budget
if self.adaptive_budget is not None:
budget = int(seq_len * self.adaptive_budget)
vertical_size = max(self.num_sink_tokens + 1, int(budget * 0.2))
slash_size = max(self.num_recent_diags + 1, int(budget * 0.8))
else:
vertical_size = self.vertical_size
slash_size = self.slash_size
# Use last 64 Q tokens for estimation
last_q = min(64, seq_len)
q_last = q[-last_q:] # [last_q, heads, dim] - this is a view, not a copy
# Handle GQA: if num_kv_heads < num_heads, we need to expand K
if num_kv_heads < num_heads:
num_groups = num_heads // num_kv_heads
k_work = k.repeat_interleave(num_groups, dim=1)
else:
k_work = k
# Compute attention scores: [heads, last_q, seq_len]
scale = 1.0 / math.sqrt(head_dim)
qk = torch.einsum('qhd,khd->hqk', q_last, k_work) * scale
# Free k_work if it was a copy
if num_kv_heads < num_heads:
del k_work
# Apply causal mask for last positions (in-place)
causal_mask = self._get_causal_mask(last_q, seq_len, q.device)
qk.masked_fill_(~causal_mask.unsqueeze(0), float('-inf'))
# Softmax (in-place where possible)
qk = F.softmax(qk, dim=-1, dtype=torch.float32)
# === Vertical pattern ===
# Sum across query dimension -> importance of each K position
vertical_scores = qk.sum(dim=1) # [heads, seq_len]
# Force keep first num_sink_tokens (attention sinks) - in-place
vertical_scores[:, :self.num_sink_tokens] = float('inf')
# Select top-k
actual_vertical = min(vertical_size, seq_len)
vertical_indices = vertical_scores.topk(actual_vertical, dim=-1).indices
vertical_indices = vertical_indices.sort(dim=-1).values
del vertical_scores
# === Slash pattern ===
# Create diagonal index matrix: [last_q, seq_len] with int32 to save memory
q_indices = torch.arange(last_q, device=q.device, dtype=torch.int32).unsqueeze(1)
k_indices = torch.arange(seq_len, device=q.device, dtype=torch.int32).unsqueeze(0)
diag_indices = (seq_len - last_q + q_indices) - k_indices # [last_q, seq_len]
del q_indices
# Create causal mask for slash computation
q_pos = seq_len - last_q + torch.arange(last_q, device=q.device, dtype=torch.int32).unsqueeze(1)
slash_causal_mask = k_indices <= q_pos
del q_pos, k_indices
# Clamp diagonal indices to valid range
diag_indices = diag_indices.clamp(0, seq_len - 1)
# Apply causal mask to qk (in-place) for slash computation
qk[:, ~slash_causal_mask] = 0
del slash_causal_mask
# Accumulate scores per diagonal - process in batches to save memory
slash_scores = torch.zeros(num_heads, seq_len, device=q.device, dtype=torch.float32)
# Process heads in chunks to reduce peak memory for diag_indices_expanded
chunk_size = min(8, num_heads) # Process 8 heads at a time
for h_start in range(0, num_heads, chunk_size):
h_end = min(h_start + chunk_size, num_heads)
n_heads_chunk = h_end - h_start
# Expand diag_indices only for this chunk
diag_chunk = diag_indices.unsqueeze(0).expand(n_heads_chunk, -1, -1).long()
qk_chunk = qk[h_start:h_end]
slash_scores[h_start:h_end].scatter_add_(
1,
diag_chunk.reshape(n_heads_chunk, -1),
qk_chunk.reshape(n_heads_chunk, -1)
)
del diag_chunk, qk_chunk
del diag_indices, qk
# Force keep first num_recent_diags (in-place)
slash_scores[:, :self.num_recent_diags] = float('inf')
# Select top-k diagonal indices
actual_slash = min(slash_size, seq_len)
slash_indices = slash_scores.topk(actual_slash, dim=-1).indices
slash_indices = slash_indices.sort(dim=-1).values
del slash_scores
return vertical_indices, slash_indices
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""
Select blocks for chunked CPU offload mode.
For MInference in GPU-only mode, this method is not used.
In CPU offload mode, it would select blocks based on the sparse pattern.
For now, return all blocks (full attention fallback).
"""
# MInference pattern is computed in attention.forward()
# For CPU offload integration (Phase B), this would use the pattern
return available_blocks
def reset(self) -> None:
"""Reset policy state."""
self._last_q_mask_cache.clear()
def sparse_prefill_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
) -> torch.Tensor:
"""
Compute MInference sparse attention for prefill.
Uses vertical + slash pattern to compute sparse attention efficiently.
Memory-optimized to handle long sequences (64K+) by freeing intermediate tensors.
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Current transformer layer index
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
from minference.ops.pit_sparse_flash_attention_v2 import _triton_mixed_sparse_attention
from minference.cuda import convert_vertical_slash_indexes
seq_len = q.shape[0]
num_heads = q.shape[1]
head_dim = q.shape[2]
num_kv_heads = k.shape[1]
# Estimate sparse pattern (uses temporary memory for qk scores)
vertical_indices, slash_indices = self.estimate_pattern(q, k, layer_id)
# Free any cached memory from pattern estimation
torch.cuda.empty_cache()
# Triton sparse attention kernel parameters
block_size_M = 64
block_size_N = 64
# Calculate padding
pad = (block_size_M - seq_len) & (block_size_M - 1)
need_head_pad = head_dim not in [16, 32, 64, 128, 256, 512]
head_pad = (2 ** math.ceil(math.log2(head_dim)) - head_dim) if need_head_pad else 0
# Handle GQA: expand K/V to match query heads
# Do this BEFORE creating batched tensors to avoid double copies
if num_kv_heads < num_heads:
num_groups = num_heads // num_kv_heads
# Use repeat_interleave for memory-efficient expansion
k_work = k.repeat_interleave(num_groups, dim=1)
v_work = v.repeat_interleave(num_groups, dim=1)
else:
k_work = k
v_work = v
# Transform Q to [batch, heads, seq, dim] format with padding in one step
# This avoids creating intermediate copies
if pad > 0 or head_pad > 0:
q_batched = torch.nn.functional.pad(
q.unsqueeze(0).transpose(1, 2),
[0, head_pad, 0, pad, 0, 0, 0, 0]
).contiguous()
else:
q_batched = q.unsqueeze(0).transpose(1, 2).contiguous()
# Transform K to batched format
if pad > 0 or head_pad > 0:
k_batched = torch.nn.functional.pad(
k_work.unsqueeze(0).transpose(1, 2),
[0, head_pad, 0, pad, 0, 0, 0, 0]
).contiguous()
else:
k_batched = k_work.unsqueeze(0).transpose(1, 2).contiguous()
# Free k_work if it was a copy (GQA case)
if num_kv_heads < num_heads:
del k_work
# Transform V to batched format
if pad > 0 or head_pad > 0:
v_batched = torch.nn.functional.pad(
v_work.unsqueeze(0).transpose(1, 2),
[0, head_pad, 0, pad, 0, 0, 0, 0]
).contiguous()
else:
v_batched = v_work.unsqueeze(0).transpose(1, 2).contiguous()
# Free v_work if it was a copy (GQA case)
if num_kv_heads < num_heads:
del v_work
torch.cuda.empty_cache()
# Prepare indices for Triton kernel
v_idx = vertical_indices.to(torch.int32).reshape((1, num_heads, -1))
v_idx = v_idx.sort(dim=-1, descending=False)[0].contiguous()
del vertical_indices
s_idx = slash_indices.to(torch.int32).reshape((1, num_heads, -1))
s_idx = s_idx.sort(dim=-1, descending=True)[0].contiguous()
del slash_indices
seqlens = torch.tensor([seq_len], dtype=torch.int32, device=q.device)
sm_scale = head_dim ** -0.5
# Convert vertical+slash indices to block sparse format
block_count, block_offset, column_count, column_index = convert_vertical_slash_indexes(
seqlens, v_idx, s_idx, seq_len, block_size_M, block_size_N,
)
del v_idx, s_idx
# Call Triton mixed sparse attention kernel
o = _triton_mixed_sparse_attention(
q_batched, k_batched, v_batched, seqlens,
block_count, block_offset, column_count, column_index,
sm_scale, block_size_M, block_size_N,
)
# Free input tensors immediately after kernel call
del q_batched, k_batched, v_batched
del block_count, block_offset, column_count, column_index
# Remove padding and convert back to [seq_len, num_heads, head_dim]
o = o[..., :seq_len, :head_dim]
o = o.transpose(1, 2).squeeze(0).contiguous()
return o
def __repr__(self) -> str:
return (f"MInferencePolicy("
f"adaptive_budget={self.adaptive_budget}, "
f"vertical_size={self.vertical_size}, "
f"slash_size={self.slash_size})")

View File

@@ -183,5 +183,32 @@ class SparsePolicy(ABC):
""" """
pass pass
def sparse_prefill_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
) -> torch.Tensor:
"""
Compute sparse attention for prefill phase.
This method is called when supports_prefill=True and the policy
is used for GPU-only sparse prefill (no CPU offload).
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Current transformer layer index
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement sparse_prefill_attention. "
"Set supports_prefill=False or implement this method."
)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.__class__.__name__}()" return f"{self.__class__.__name__}()"

View File

@@ -140,6 +140,11 @@ class Attention(nn.Module):
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q, max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k, max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables) softmax_scale=self.scale, causal=True, block_table=context.block_tables)
elif context.sparse_prefill_policy is not None:
# Sparse prefill (GPU-only) - delegate to policy
o = context.sparse_prefill_policy.sparse_prefill_attention(
q, k, v, self.layer_id
)
else: else:
o = flash_attn_varlen_func(q, k, v, o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q, max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,

View File

@@ -35,6 +35,10 @@ class Context:
# Current chunk index for ring buffer pipeline (prefill only) # Current chunk index for ring buffer pipeline (prefill only)
current_chunk_idx: int = 0 current_chunk_idx: int = 0
# Sparse prefill attention support (GPU-only path)
# When set, uses policy.sparse_prefill_attention() instead of FlashAttention
sparse_prefill_policy: Any = None # SparsePolicy instance with supports_prefill=True
_CONTEXT = Context() _CONTEXT = Context()
@@ -60,6 +64,7 @@ def set_context(
decode_pos_in_block=0, decode_pos_in_block=0,
decode_start_pos_in_block=0, decode_start_pos_in_block=0,
current_chunk_idx=0, current_chunk_idx=0,
sparse_prefill_policy=None,
): ):
global _CONTEXT global _CONTEXT
_CONTEXT = Context( _CONTEXT = Context(
@@ -79,6 +84,7 @@ def set_context(
decode_pos_in_block=decode_pos_in_block, decode_pos_in_block=decode_pos_in_block,
decode_start_pos_in_block=decode_start_pos_in_block, decode_start_pos_in_block=decode_start_pos_in_block,
current_chunk_idx=current_chunk_idx, current_chunk_idx=current_chunk_idx,
sparse_prefill_policy=sparse_prefill_policy,
) )

View File

@@ -0,0 +1,163 @@
"""
Needle-in-haystack test with MInference sparse attention.
Tests: MInference sparse prefill on GPU-only path (no CPU offload).
This validates that MInference's vertical + slash sparse pattern can
correctly retrieve information from long context.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import argparse
from nanovllm import LLM, SamplingParams
from nanovllm.config import SparsePolicyType
from utils import generate_needle_prompt, check_needle_answer
def run_minference_test(
model_path: str,
max_model_len: int = 16384,
input_len: int = 8192,
needle_position: float = 0.5,
needle_value: str = "7492",
adaptive_budget: float = 0.3,
max_new_tokens: int = 32,
verbose: bool = True,
) -> bool:
"""
Run needle test with MInference sparse prefill attention.
Args:
model_path: Path to model
max_model_len: Maximum model context length
input_len: Target input sequence length
needle_position: Where to place needle (0.0-1.0)
needle_value: The secret value to find
adaptive_budget: MInference budget as fraction of seq_len
max_new_tokens: Maximum tokens to generate
verbose: Print detailed output
Returns:
True if test passed, False otherwise
"""
if verbose:
print(f"\n{'='*60}")
print(f"MInference Sparse Prefill Test (GPU-only)")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Max model len: {max_model_len}")
print(f"Input length: {input_len}")
print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}")
print(f"Adaptive budget: {adaptive_budget}")
print(f"{'='*60}\n")
# Initialize LLM with MInference sparse attention
llm = LLM(
model_path,
enforce_eager=True,
max_model_len=max_model_len,
max_num_batched_tokens=max_model_len,
enable_cpu_offload=False, # GPU-only
sparse_policy=SparsePolicyType.MINFERENCE,
minference_adaptive_budget=adaptive_budget,
)
# Generate needle prompt
prompt, expected = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=needle_position,
needle_value=needle_value,
)
# Generate output
sampling_params = SamplingParams(
temperature=0.6,
max_tokens=max_new_tokens,
)
outputs = llm.generate([prompt], sampling_params, use_tqdm=True)
# Check result
output_text = outputs[0]["text"]
output_token_ids = outputs[0]["token_ids"]
passed = check_needle_answer(output_text, expected)
if verbose:
print(f"\n{'='*60}")
print(f"Result")
print(f"{'='*60}")
print(f"Expected: {expected}")
print(f"Output tokens ({len(output_token_ids)}): {output_token_ids[:20]}")
print(f"Output: {output_text[:200]}...")
print(f"Status: {'PASSED' if passed else 'FAILED'}")
print(f"{'='*60}\n")
return passed
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Needle-in-haystack test with MInference sparse prefill"
)
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
help="Path to model"
)
parser.add_argument(
"--max-model-len",
type=int,
default=16 * 1024,
help="Maximum model context length"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--needle-position",
type=float,
default=0.5,
help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
)
parser.add_argument(
"--needle-value",
type=str,
default="7492",
help="The secret value to hide"
)
parser.add_argument(
"--adaptive-budget",
type=float,
default=0.3,
help="MInference adaptive budget (fraction of seq_len)"
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=32,
help="Maximum tokens to generate"
)
args = parser.parse_args()
passed = run_minference_test(
model_path=args.model,
max_model_len=args.max_model_len,
input_len=args.input_len,
needle_position=args.needle_position,
needle_value=args.needle_value,
adaptive_budget=args.adaptive_budget,
max_new_tokens=args.max_new_tokens,
verbose=True,
)
if passed:
print("test_minference_gpu: PASSED")
else:
print("test_minference_gpu: FAILED")
exit(1)

View File

@@ -31,8 +31,13 @@ def run_needle_test(
max_new_tokens: int = 32, max_new_tokens: int = 32,
enable_cpu_offload: bool = False, enable_cpu_offload: bool = False,
enable_quest: bool = False, enable_quest: bool = False,
enable_minference: bool = False,
sparse_topk: int = 8, sparse_topk: int = 8,
sparse_threshold: int = 4, sparse_threshold: int = 4,
minference_budget: float = 0.3,
minference_vertical: int = 1000,
minference_slash: int = 6096,
gpu_utilization: float = 0.9,
verbose: bool = True, verbose: bool = True,
) -> bool: ) -> bool:
""" """
@@ -49,14 +54,25 @@ def run_needle_test(
max_new_tokens: Maximum tokens to generate max_new_tokens: Maximum tokens to generate
enable_cpu_offload: Enable CPU offload mode enable_cpu_offload: Enable CPU offload mode
enable_quest: Enable Quest sparse attention (decode-only Top-K) enable_quest: Enable Quest sparse attention (decode-only Top-K)
enable_minference: Enable MInference sparse prefill (GPU-only)
sparse_topk: Top-K blocks for Quest sparse_topk: Top-K blocks for Quest
sparse_threshold: Apply sparse only when blocks > threshold sparse_threshold: Apply sparse only when blocks > threshold
minference_budget: MInference adaptive budget (fraction of seq_len, None=fixed mode)
minference_vertical: Fixed vertical_size (only used when budget=None)
minference_slash: Fixed slash_size (only used when budget=None)
gpu_utilization: GPU memory utilization fraction
verbose: Print detailed output verbose: Print detailed output
Returns: Returns:
True if test passed, False otherwise True if test passed, False otherwise
""" """
sparse_policy = SparsePolicyType.QUEST if enable_quest else SparsePolicyType.FULL # Determine sparse policy
if enable_minference:
sparse_policy = SparsePolicyType.MINFERENCE
elif enable_quest:
sparse_policy = SparsePolicyType.QUEST
else:
sparse_policy = SparsePolicyType.FULL
if verbose: if verbose:
print(f"\n{'='*60}") print(f"\n{'='*60}")
@@ -69,8 +85,14 @@ def run_needle_test(
print(f"Needle position: {needle_position:.0%}") print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}") print(f"Needle value: {needle_value}")
print(f"CPU offload: {enable_cpu_offload}") print(f"CPU offload: {enable_cpu_offload}")
if enable_cpu_offload: print(f"Sparse policy: {sparse_policy.name}")
print(f"Sparse policy: {sparse_policy.name} (topk={sparse_topk}, threshold={sparse_threshold})") if enable_cpu_offload and enable_quest:
print(f" Quest: topk={sparse_topk}, threshold={sparse_threshold}")
if enable_minference:
if minference_budget is not None:
print(f" MInference: adaptive (budget={minference_budget})")
else:
print(f" MInference: fixed (vertical={minference_vertical}, slash={minference_slash})")
print(f"{'='*60}\n") print(f"{'='*60}\n")
# 1. Initialize LLM # 1. Initialize LLM
@@ -80,12 +102,19 @@ def run_needle_test(
"max_num_batched_tokens": max_model_len, "max_num_batched_tokens": max_model_len,
"enable_cpu_offload": enable_cpu_offload, "enable_cpu_offload": enable_cpu_offload,
"kvcache_block_size": block_size, "kvcache_block_size": block_size,
"gpu_memory_utilization": gpu_utilization,
} }
if enable_cpu_offload: if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm_kwargs["sparse_policy"] = sparse_policy llm_kwargs["sparse_policy"] = sparse_policy
llm_kwargs["sparse_topk_blocks"] = sparse_topk llm_kwargs["sparse_topk_blocks"] = sparse_topk
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
elif enable_minference:
# MInference is GPU-only sparse prefill
llm_kwargs["sparse_policy"] = sparse_policy
llm_kwargs["minference_adaptive_budget"] = minference_budget
llm_kwargs["minference_vertical_size"] = minference_vertical
llm_kwargs["minference_slash_size"] = minference_slash
llm = LLM(model_path, **llm_kwargs) llm = LLM(model_path, **llm_kwargs)
@@ -186,6 +215,11 @@ if __name__ == "__main__":
action="store_true", action="store_true",
help="Enable Quest sparse attention (decode-only Top-K selection)" help="Enable Quest sparse attention (decode-only Top-K selection)"
) )
parser.add_argument(
"--enable-minference",
action="store_true",
help="Enable MInference sparse prefill (GPU-only, vertical+slash pattern)"
)
parser.add_argument( parser.add_argument(
"--sparse-topk", "--sparse-topk",
type=int, type=int,
@@ -198,8 +232,35 @@ if __name__ == "__main__":
default=4, default=4,
help="Apply sparse only when blocks > threshold" help="Apply sparse only when blocks > threshold"
) )
parser.add_argument(
"--minference-budget",
type=float,
default=0.3,
help="MInference adaptive budget (fraction of seq_len, 0.3=30%% compute, 0=fixed mode)"
)
parser.add_argument(
"--minference-vertical",
type=int,
default=1000,
help="Fixed vertical_size (only used when budget=0)"
)
parser.add_argument(
"--minference-slash",
type=int,
default=6096,
help="Fixed slash_size (only used when budget=0)"
)
parser.add_argument(
"--gpu-utilization",
type=float,
default=0.9,
help="GPU memory utilization (default: 0.9)"
)
args = parser.parse_args() args = parser.parse_args()
# Convert budget=0 to None for fixed mode
minference_budget = args.minference_budget if args.minference_budget > 0 else None
passed = run_needle_test( passed = run_needle_test(
model_path=args.model, model_path=args.model,
max_model_len=args.max_model_len, max_model_len=args.max_model_len,
@@ -211,8 +272,13 @@ if __name__ == "__main__":
max_new_tokens=args.max_new_tokens, max_new_tokens=args.max_new_tokens,
enable_cpu_offload=args.enable_offload, enable_cpu_offload=args.enable_offload,
enable_quest=args.enable_quest, enable_quest=args.enable_quest,
enable_minference=args.enable_minference,
sparse_topk=args.sparse_topk, sparse_topk=args.sparse_topk,
sparse_threshold=args.sparse_threshold, sparse_threshold=args.sparse_threshold,
minference_budget=minference_budget,
minference_vertical=args.minference_vertical,
minference_slash=args.minference_slash,
gpu_utilization=args.gpu_utilization,
verbose=True, verbose=True,
) )