[tests] Added test_niah_standalone.py.

This commit is contained in:
Zijie Tian
2026-01-12 00:16:37 +08:00
parent 5895de0c97
commit a6cc703d73
6 changed files with 686 additions and 9 deletions

View File

@@ -61,6 +61,15 @@ class Config:
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
assert self.max_num_batched_tokens >= self.max_model_len
# CPU offload mode supports only a single sequence (layer-wise processing)
if self.enable_cpu_offload and self.max_num_seqs != 1:
import logging
logging.warning(
f"CPU offload mode only supports single sequence. "
f"Overriding max_num_seqs from {self.max_num_seqs} to 1."
)
self.max_num_seqs = 1
# Override torch_dtype if user specified
if self.dtype is not None:
dtype_map = {

View File

@@ -27,7 +27,9 @@ class ModelRunner:
self.rank = rank
self.event = event
dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
import os
port = os.environ.get("NANOVLLM_DIST_PORT", "2333")
dist.init_process_group("nccl", f"tcp://localhost:{port}", world_size=self.world_size, rank=rank)
torch.cuda.set_device(rank)
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(hf_config.torch_dtype)
@@ -546,8 +548,8 @@ class ModelRunner:
k = k.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
v = v.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# Q/K norms (Qwen3 specific)
if not layer.self_attn.qkv_bias:
# Q/K norms (Qwen3 specific - only when qkv_bias=False)
if not getattr(layer.self_attn, 'qkv_bias', True):
num_tokens = q.shape[0]
q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim))
q = q.view(num_tokens, layer.self_attn.num_heads, layer.self_attn.head_dim)
@@ -649,8 +651,8 @@ class ModelRunner:
k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
v_new = v_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# Q/K norms
if not layer.self_attn.qkv_bias:
# Q/K norms (Qwen3 specific - only when qkv_bias=False)
if not getattr(layer.self_attn, 'qkv_bias', True):
q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim))
q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim)
k_new = layer.self_attn.k_norm(k_new.reshape(-1, layer.self_attn.head_dim))
@@ -785,8 +787,8 @@ class ModelRunner:
k = k.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
v = v.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# Q/K norms (Qwen3 specific)
if not layer.self_attn.qkv_bias:
# Q/K norms (Qwen3 specific - only when qkv_bias=False)
if not getattr(layer.self_attn, 'qkv_bias', True):
num_tokens = q.shape[0]
q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim))
q = q.view(num_tokens, layer.self_attn.num_heads, layer.self_attn.head_dim)

View File

@@ -71,6 +71,12 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
)
# max_seq_len must be larger than max_model_len to accommodate decode tokens:
# when prefill uses ~max_model_len tokens, decode needs additional slots.
# Add a max_new_tokens buffer (default 512) for the decode phase.
max_new_tokens = getattr(config, 'max_new_tokens', 512)
max_seq_len = config.max_model_len + max_new_tokens
return HybridKVCacheManager(
num_gpu_slots=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
@@ -78,7 +84,7 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
policy=eviction_policy,
sparse_policy=sparse_policy,
num_kv_buffers=getattr(config, 'num_kv_buffers', 4),
max_seq_len=config.max_model_len,
max_seq_len=max_seq_len,
)

View File

@@ -3,7 +3,13 @@
from nanovllm.models.registry import register_model, get_model_class, MODEL_REGISTRY
# Import models to trigger registration
from nanovllm.models import qwen3
# Qwen3 requires transformers>=4.51.0 for Qwen3Config
try:
from nanovllm.models import qwen3
except ImportError as e:
import warnings
warnings.warn(f"Qwen3 model not available (requires transformers>=4.51.0): {e}")
from nanovllm.models import llama
__all__ = ["register_model", "get_model_class", "MODEL_REGISTRY"]