[WIP] Needs refactoring: replace sparse_prefill_policy with a unified attention_policy.

This commit is contained in:
Zijie Tian
2026-01-22 22:20:34 +08:00
parent 69b779e252
commit 5fb0f67295
11 changed files with 514 additions and 548 deletions

View File

@@ -57,8 +57,8 @@ class ModelRunner:
load_model(self.model, config.model)
self.sampler = GreedySampler()
# Initialize sparse_prefill_policy before warmup (will be configured in allocate_kv_cache)
self.sparse_prefill_policy = None
# Initialize attention_policy before warmup (will be configured in allocate_kv_cache)
self.attention_policy = None
#> Disable warmup for debugging
self.warmup_model()
@@ -178,38 +178,35 @@ class ModelRunner:
# Create KV cache manager using factory
self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
# Create sparse prefill policy
# This is used for both GPU-only and CPU offload modes when policy supports prefill
self.sparse_prefill_policy = None
if config.sparse_policy != SparsePolicyType.FULL:
from nanovllm.kvcache.sparse import create_sparse_policy
# Create attention policy (always, including FULL)
# In layerwise offload mode, all attention goes through the policy
from nanovllm.kvcache.sparse import create_attention_policy
# Get policy-specific parameters based on type
if config.sparse_policy == SparsePolicyType.XATTN:
policy_kwargs = {
"stride": config.xattn_stride,
"threshold": config.xattn_threshold,
"chunk_size": config.xattn_chunk_size,
"use_triton": config.xattn_use_triton,
"keep_sink": config.xattn_keep_sink,
"keep_recent": config.xattn_keep_recent,
"norm": config.xattn_norm,
}
else: # MINFERENCE or others
policy_kwargs = {
"vertical_size": config.minference_vertical_size,
"slash_size": config.minference_slash_size,
"adaptive_budget": config.minference_adaptive_budget,
"num_sink_tokens": config.minference_num_sink_tokens,
"num_recent_diags": config.minference_num_recent_diags,
}
# Get policy-specific parameters based on type
if config.sparse_policy == SparsePolicyType.XATTN:
policy_kwargs = {
"stride": config.xattn_stride,
"threshold": config.xattn_threshold,
"chunk_size": config.xattn_chunk_size,
"use_triton": config.xattn_use_triton,
"keep_sink": config.xattn_keep_sink,
"keep_recent": config.xattn_keep_recent,
"norm": config.xattn_norm,
"use_bsa": config.xattn_use_bsa,
}
elif config.sparse_policy == SparsePolicyType.MINFERENCE:
policy_kwargs = {
"vertical_size": config.minference_vertical_size,
"slash_size": config.minference_slash_size,
"adaptive_budget": config.minference_adaptive_budget,
"num_sink_tokens": config.minference_num_sink_tokens,
"num_recent_diags": config.minference_num_recent_diags,
}
else: # FULL or QUEST
policy_kwargs = {}
policy = create_sparse_policy(config.sparse_policy, **policy_kwargs)
# Only use if policy supports sparse prefill
if policy.supports_prefill:
self.sparse_prefill_policy = policy
logger.info(f"Sparse prefill policy enabled: {self.sparse_prefill_policy}")
self.attention_policy = create_attention_policy(config.sparse_policy, **policy_kwargs)
logger.info(f"Attention policy: {self.attention_policy}")
# Allocate cache through manager
self.kvcache_manager.allocate_cache(
@@ -395,7 +392,7 @@ class ModelRunner:
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
slot_mapping, None, block_tables,
sparse_prefill_policy=self.sparse_prefill_policy)
attention_policy=self.attention_policy)
return input_ids, positions
def prepare_decode(self, seqs: list[Sequence]):
@@ -592,21 +589,11 @@ class ModelRunner:
# RoPE
q, k = layer.self_attn.rotary_emb(positions, q, k)
# Sparse or Full attention (uses k, v directly - before store!)
if self.sparse_prefill_policy is not None:
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
q, k, v, layer_id
)
else:
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
# Compute attention using policy (uses k, v directly - before store!)
attn_output = self.attention_policy.compute_prefill(
q, k, v, layer_id,
softmax_scale=layer.self_attn.attn.scale,
)
# O projection
attn_output = attn_output.view(total_tokens, -1)
@@ -872,23 +859,11 @@ class ModelRunner:
# RoPE
q, k = layer.self_attn.rotary_emb(positions, q, k)
# Sparse or Full attention
if self.sparse_prefill_policy is not None:
# MInference or other sparse prefill policy
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
q, k, v, layer_id
)
else:
# Full attention using FlashAttention
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
# Compute attention using policy
attn_output = self.attention_policy.compute_prefill(
q, k, v, layer_id,
softmax_scale=layer.self_attn.attn.scale,
)
# O projection
attn_output = attn_output.view(total_tokens, -1)