[claudesquad] update from 'layer-prefill-1' on 08 Jan 26 03:36 CST

This commit is contained in:
Zijie Tian
2026-01-08 03:36:39 +08:00
parent 6575099a06
commit d8a87da1c3
10 changed files with 822 additions and 32 deletions

View File

@@ -4,7 +4,7 @@ import torch.distributed as dist
from multiprocessing.synchronize import Event
from multiprocessing.shared_memory import SharedMemory
from nanovllm.config import Config
from nanovllm.config import Config, SparsePolicyType
from nanovllm.engine.sequence import Sequence
from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.layers.sampler import GreedySampler
@@ -35,7 +35,10 @@ class ModelRunner:
self.model = Qwen3ForCausalLM(hf_config)
load_model(self.model, config.model)
self.sampler = GreedySampler()
# Initialize sparse_prefill_policy before warmup (will be configured in allocate_kv_cache)
self.sparse_prefill_policy = None
#> NOTE: warmup is currently ENABLED — comment out the call below to disable it for debugging
self.warmup_model()
@@ -148,6 +151,24 @@ class ModelRunner:
# Create KV cache manager using factory
self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
# Create sparse prefill policy for GPU-only path
# This is separate from CPU offload sparse policy (which uses select_blocks)
self.sparse_prefill_policy = None
if not config.enable_cpu_offload and config.sparse_policy != SparsePolicyType.FULL:
from nanovllm.kvcache.sparse import create_sparse_policy
policy = create_sparse_policy(
config.sparse_policy,
vertical_size=config.minference_vertical_size,
slash_size=config.minference_slash_size,
adaptive_budget=config.minference_adaptive_budget,
num_sink_tokens=config.minference_num_sink_tokens,
num_recent_diags=config.minference_num_recent_diags,
)
# Only use if policy supports sparse prefill
if policy.supports_prefill:
self.sparse_prefill_policy = policy
logger.info(f"Sparse prefill policy enabled: {self.sparse_prefill_policy}")
# Allocate cache through manager
self.kvcache_manager.allocate_cache(
num_layers=hf_config.num_hidden_layers,
@@ -329,7 +350,10 @@ class ModelRunner:
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
slot_mapping, None, block_tables,
sparse_prefill_policy=self.sparse_prefill_policy)
return input_ids, positions
def prepare_decode(self, seqs: list[Sequence]):