[WIP] need refactor.
This commit is contained in:
@@ -57,8 +57,8 @@ class ModelRunner:
|
||||
load_model(self.model, config.model)
|
||||
self.sampler = GreedySampler()
|
||||
|
||||
# Initialize sparse_prefill_policy before warmup (will be configured in allocate_kv_cache)
|
||||
self.sparse_prefill_policy = None
|
||||
# Initialize attention_policy before warmup (will be configured in allocate_kv_cache)
|
||||
self.attention_policy = None
|
||||
|
||||
#> Disable warmup for debugging
|
||||
self.warmup_model()
|
||||
@@ -178,38 +178,35 @@ class ModelRunner:
|
||||
# Create KV cache manager using factory
|
||||
self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
|
||||
|
||||
# Create sparse prefill policy
|
||||
# This is used for both GPU-only and CPU offload modes when policy supports prefill
|
||||
self.sparse_prefill_policy = None
|
||||
if config.sparse_policy != SparsePolicyType.FULL:
|
||||
from nanovllm.kvcache.sparse import create_sparse_policy
|
||||
# Create attention policy (always, including FULL)
|
||||
# In layerwise offload mode, all attention goes through the policy
|
||||
from nanovllm.kvcache.sparse import create_attention_policy
|
||||
|
||||
# Get policy-specific parameters based on type
|
||||
if config.sparse_policy == SparsePolicyType.XATTN:
|
||||
policy_kwargs = {
|
||||
"stride": config.xattn_stride,
|
||||
"threshold": config.xattn_threshold,
|
||||
"chunk_size": config.xattn_chunk_size,
|
||||
"use_triton": config.xattn_use_triton,
|
||||
"keep_sink": config.xattn_keep_sink,
|
||||
"keep_recent": config.xattn_keep_recent,
|
||||
"norm": config.xattn_norm,
|
||||
}
|
||||
else: # MINFERENCE or others
|
||||
policy_kwargs = {
|
||||
"vertical_size": config.minference_vertical_size,
|
||||
"slash_size": config.minference_slash_size,
|
||||
"adaptive_budget": config.minference_adaptive_budget,
|
||||
"num_sink_tokens": config.minference_num_sink_tokens,
|
||||
"num_recent_diags": config.minference_num_recent_diags,
|
||||
}
|
||||
# Get policy-specific parameters based on type
|
||||
if config.sparse_policy == SparsePolicyType.XATTN:
|
||||
policy_kwargs = {
|
||||
"stride": config.xattn_stride,
|
||||
"threshold": config.xattn_threshold,
|
||||
"chunk_size": config.xattn_chunk_size,
|
||||
"use_triton": config.xattn_use_triton,
|
||||
"keep_sink": config.xattn_keep_sink,
|
||||
"keep_recent": config.xattn_keep_recent,
|
||||
"norm": config.xattn_norm,
|
||||
"use_bsa": config.xattn_use_bsa,
|
||||
}
|
||||
elif config.sparse_policy == SparsePolicyType.MINFERENCE:
|
||||
policy_kwargs = {
|
||||
"vertical_size": config.minference_vertical_size,
|
||||
"slash_size": config.minference_slash_size,
|
||||
"adaptive_budget": config.minference_adaptive_budget,
|
||||
"num_sink_tokens": config.minference_num_sink_tokens,
|
||||
"num_recent_diags": config.minference_num_recent_diags,
|
||||
}
|
||||
else: # FULL or QUEST
|
||||
policy_kwargs = {}
|
||||
|
||||
policy = create_sparse_policy(config.sparse_policy, **policy_kwargs)
|
||||
|
||||
# Only use if policy supports sparse prefill
|
||||
if policy.supports_prefill:
|
||||
self.sparse_prefill_policy = policy
|
||||
logger.info(f"Sparse prefill policy enabled: {self.sparse_prefill_policy}")
|
||||
self.attention_policy = create_attention_policy(config.sparse_policy, **policy_kwargs)
|
||||
logger.info(f"Attention policy: {self.attention_policy}")
|
||||
|
||||
# Allocate cache through manager
|
||||
self.kvcache_manager.allocate_cache(
|
||||
@@ -395,7 +392,7 @@ class ModelRunner:
|
||||
|
||||
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
slot_mapping, None, block_tables,
|
||||
sparse_prefill_policy=self.sparse_prefill_policy)
|
||||
attention_policy=self.attention_policy)
|
||||
return input_ids, positions
|
||||
|
||||
def prepare_decode(self, seqs: list[Sequence]):
|
||||
@@ -592,21 +589,11 @@ class ModelRunner:
|
||||
# RoPE
|
||||
q, k = layer.self_attn.rotary_emb(positions, q, k)
|
||||
|
||||
# Sparse or Full attention (uses k, v directly - before store!)
|
||||
if self.sparse_prefill_policy is not None:
|
||||
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
|
||||
q, k, v, layer_id
|
||||
)
|
||||
else:
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=total_tokens,
|
||||
max_seqlen_k=total_tokens,
|
||||
softmax_scale=layer.self_attn.attn.scale,
|
||||
causal=True,
|
||||
)
|
||||
# Compute attention using policy (uses k, v directly - before store!)
|
||||
attn_output = self.attention_policy.compute_prefill(
|
||||
q, k, v, layer_id,
|
||||
softmax_scale=layer.self_attn.attn.scale,
|
||||
)
|
||||
|
||||
# O projection
|
||||
attn_output = attn_output.view(total_tokens, -1)
|
||||
@@ -872,23 +859,11 @@ class ModelRunner:
|
||||
# RoPE
|
||||
q, k = layer.self_attn.rotary_emb(positions, q, k)
|
||||
|
||||
# Sparse or Full attention
|
||||
if self.sparse_prefill_policy is not None:
|
||||
# MInference or other sparse prefill policy
|
||||
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
|
||||
q, k, v, layer_id
|
||||
)
|
||||
else:
|
||||
# Full attention using FlashAttention
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=total_tokens,
|
||||
max_seqlen_k=total_tokens,
|
||||
softmax_scale=layer.self_attn.attn.scale,
|
||||
causal=True,
|
||||
)
|
||||
# Compute attention using policy
|
||||
attn_output = self.attention_policy.compute_prefill(
|
||||
q, k, v, layer_id,
|
||||
softmax_scale=layer.self_attn.attn.scale,
|
||||
)
|
||||
|
||||
# O projection
|
||||
attn_output = attn_output.view(total_tokens, -1)
|
||||
|
||||
Reference in New Issue
Block a user