[WIP] needs refactor.

This commit is contained in:
Zijie Tian
2026-01-22 22:20:34 +08:00
parent 69b779e252
commit 5fb0f67295
11 changed files with 514 additions and 548 deletions


@@ -10,10 +10,10 @@ from typing import List, Tuple, Optional
 import torch
 import torch.nn.functional as F
-from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
+from nanovllm.kvcache.sparse.policy import AttentionPolicy, PolicyContext


-class MInferencePolicy(SparsePolicy):
+class MInferencePolicy(AttentionPolicy):
     """
     MInference sparse prefill policy using vertical + slash pattern.
@@ -347,6 +347,33 @@ class MInferencePolicy(SparsePolicy):
         return o

+    def compute_prefill(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer_id: int,
+        softmax_scale: float,
+    ) -> torch.Tensor:
+        """
+        Compute MInference sparse prefill attention.
+
+        This is the new unified interface for attention policies.
+        Delegates to sparse_prefill_attention (ignores softmax_scale, as
+        MInference computes it internally from head_dim).
+
+        Args:
+            q: Query tensor [seq_len, num_heads, head_dim]
+            k: Key tensor [seq_len, num_kv_heads, head_dim]
+            v: Value tensor [seq_len, num_kv_heads, head_dim]
+            layer_id: Transformer layer index
+            softmax_scale: Softmax scaling factor (unused, computed internally)
+
+        Returns:
+            Attention output [seq_len, num_heads, head_dim]
+        """
+        return self.sparse_prefill_attention(q, k, v, layer_id)
+
     def __repr__(self) -> str:
         return (f"MInferencePolicy("
                 f"adaptive_budget={self.adaptive_budget}, "