[feat] Added bench_offload.py and GreedySampler.
@@ -5,7 +5,7 @@ from nanovllm import LLM, SamplingParams
 
 
 def bench_decode(llm, num_seqs, max_input_len, max_output_len):
-    """Benchmark decode performance"""
+    """Benchmark decode performance (original test)"""
     seed(0)
     prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
     sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)]
@@ -21,6 +21,7 @@ def bench_decode(llm, num_seqs, max_input_len, max_output_len):
 def bench_prefill(llm, num_seqs, input_len):
     """Benchmark prefill performance"""
     seed(0)
+    # Fixed length input, minimal output to focus on prefill
     prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
     sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
 
@@ -40,6 +41,8 @@ def main():
         max_model_len=128 * 1024,
         max_num_batched_tokens=128 * 1024,
         enable_cpu_offload=True,
+        num_gpu_blocks=6,
+        num_prefetch_blocks=2,
     )
 
     # Warmup
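
For orientation, a hedged reconstruction of the full constructor call in main(). The surrounding lines are not part of this hunk, so the model path below is a placeholder, and the reading of the two new knobs (GPU-resident blocks and prefetch depth) is an assumption, not documented behavior.

from nanovllm import LLM

llm = LLM(
    "/path/to/model",          # placeholder; the real path is not shown in this hunk
    max_model_len=128 * 1024,
    max_num_batched_tokens=128 * 1024,
    enable_cpu_offload=True,
    num_gpu_blocks=6,          # assumed: KV-cache blocks kept resident on the GPU
    num_prefetch_blocks=2,     # assumed: blocks staged ahead of use from CPU memory
)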
@@ -48,15 +51,16 @@ def main():
     print("=" * 60)
     print("Prefill Benchmark (CPU Offload)")
     print("=" * 60)
-    bench_prefill(llm, num_seqs=1, input_len=64*1024)
-    # bench_prefill(llm, num_seqs=1, input_len=16384)
-    # bench_prefill(llm, num_seqs=1, input_len=32000)
+    # bench_prefill(llm, num_seqs=1, input_len=1024)
+    # bench_prefill(llm, num_seqs=1, input_len=2048)
+    # bench_prefill(llm, num_seqs=1, input_len=4096)
+    bench_prefill(llm, num_seqs=1, input_len=8192)
 
     print("=" * 60)
     print("Decode Benchmark (CPU Offload)")
     print("=" * 60)
-    bench_decode(llm, num_seqs=1, max_input_len=64*1024, max_output_len=256)
-    # bench_decode(llm, num_seqs=1, max_input_len=16384, max_output_len=256)
+    bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=128)
+    # bench_decode(llm, num_seqs=1, max_input_len=2048, max_output_len=128)
 
 
 if __name__ == "__main__":
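
The hunks above only show how the benchmark inputs are built and where they are invoked; the timing itself lives outside the diff. A minimal sketch of what the measurement around llm.generate presumably looks like (the generate signature and the helper name are assumptions):

import time

def run_and_report(llm, prompt_token_ids, sampling_params, label):
    # Time a single generate() call and report input-token throughput; with
    # max_tokens=1 this is dominated by prefill, as the added comment intends.
    t0 = time.perf_counter()
    llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
    elapsed = time.perf_counter() - t0
    total_input = sum(len(ids) for ids in prompt_token_ids)
    print(f"{label}: {total_input} input tokens in {elapsed:.2f}s "
          f"({total_input / elapsed:.1f} tok/s)")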
@@ -7,7 +7,7 @@ from multiprocessing.shared_memory import SharedMemory
 from nanovllm.config import Config
 from nanovllm.engine.sequence import Sequence
 from nanovllm.models.qwen3 import Qwen3ForCausalLM
-from nanovllm.layers.sampler import Sampler
+from nanovllm.layers.sampler import GreedySampler
 from nanovllm.utils.context import set_context, get_context, reset_context
 from nanovllm.utils.loader import load_model
 from nanovllm.utils.logger import get_logger
@@ -34,7 +34,7 @@ class ModelRunner:
         torch.set_default_device("cuda")
         self.model = Qwen3ForCausalLM(hf_config)
         load_model(self.model, config.model)
-        self.sampler = Sampler()
+        self.sampler = GreedySampler()
         self.warmup_model()
         self.allocate_kv_cache()
         if not self.enforce_eager:
@@ -1039,6 +1039,8 @@ class HybridKVCacheManager(KVCacheManager):
         """
         assert not seq.block_table, "Sequence already has blocks"
 
+        h = -1  # Running hash for prefix cache
+
         for i in range(seq.num_blocks):
             # Allocate CPU block
             if not self.free_cpu_blocks:
@@ -1049,10 +1051,19 @@ class HybridKVCacheManager(KVCacheManager):
 
             cpu_block_id = self.free_cpu_blocks.popleft()
 
+            # Get token IDs for this block and compute hash
+            token_ids = seq.block(i)
+            if len(token_ids) == self._block_size:
+                h = self.compute_hash(token_ids, h)
+            else:
+                h = -1  # Incomplete block
+
             # Allocate logical block
             logical_id = self.free_logical_ids.popleft()
             block = self.logical_blocks[logical_id]
             block.ref_count = 1
+            block.hash = h
+            block.token_ids = token_ids.copy() if len(token_ids) == self._block_size else []
             block.location = BlockLocation.CPU
             block.cpu_block_id = cpu_block_id
             block.gpu_slot = -1
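
The loop above chains a running hash across full blocks, but compute_hash itself is not part of the diff. A minimal sketch of such a chained block hash, under the assumption that it folds the previous block's hash into the current block's tokens (the helper name and hashing scheme are illustrative only):

from hashlib import sha256

def compute_hash_sketch(token_ids, prefix_hash=-1):
    # Fold the previous block's hash (if any) into this block's tokens, so the
    # result identifies the whole prefix up to and including this block.
    h = sha256()
    if prefix_hash != -1:
        h.update(prefix_hash.to_bytes(8, "little"))
    h.update(b",".join(str(t).encode() for t in token_ids))
    return int.from_bytes(h.digest()[:8], "little")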
@@ -1060,6 +1071,10 @@ class HybridKVCacheManager(KVCacheManager):
             self.cpu_block_to_logical[cpu_block_id] = logical_id
             seq.block_table.append(logical_id)
 
+            # Update prefix cache
+            if h != -1:
+                self.hash_to_logical_id[h] = logical_id
+
     def get_cpu_block_table(self, seq: Sequence) -> List[int]:
         """
         Get CPU block ID list for sequence.
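
Storing the hash in hash_to_logical_id only pays off on the lookup side, which is not shown in this commit. A self-contained sketch of how a prefix-cache lookup could reuse an already-allocated logical block (the helper and its control flow are assumptions, not the repository's code):

def lookup_or_allocate(h, hash_to_logical_id, allocate_new):
    # Prefix-cache hit: a block with the same running hash already exists, so
    # the caller can reuse its logical id (and bump ref_count) instead of
    # filling a fresh CPU block.
    if h != -1 and h in hash_to_logical_id:
        return hash_to_logical_id[h], True
    # Miss: allocate a new block as in the loop above, then publish its hash
    # for later sequences that share the same prefix.
    logical_id = allocate_new()
    if h != -1:
        hash_to_logical_id[h] = logical_id
    return logical_id, False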
@@ -13,3 +13,13 @@ class Sampler(nn.Module):
         probs = torch.softmax(logits, dim=-1)
         sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
         return sample_tokens
+
+
+class GreedySampler(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @torch.compile
+    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor = None):
+        return logits.argmax(dim=-1)
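
A hedged usage example contrasting the two samplers (not part of the commit; it assumes Sampler's forward also accepts (logits, temperatures), as GreedySampler's signature suggests): the existing Sampler perturbs the softmax with exponential noise (the Gumbel-max trick) and so draws stochastically, while GreedySampler always returns the argmax, making decoding deterministic.

import torch
from nanovllm.layers.sampler import Sampler, GreedySampler

logits = torch.randn(4, 151936)               # [num_seqs, vocab_size]
temperatures = torch.full((4,), 0.6)

stochastic = Sampler()(logits, temperatures)  # varies from run to run
greedy = GreedySampler()(logits)              # deterministic
assert torch.equal(greedy, logits.argmax(dim=-1))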