Merge pull request #60 from GeeeekExplorer/warmup
README.md
@@ -14,9 +14,9 @@ A lightweight vLLM implementation built from scratch.
 pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
 ```
 
-## Manual download
+## Manual Download
 
-If you’d rather fetch the model weights yourself, you can use:
+If you prefer to download the model weights manually, use the following command:
 ```bash
 huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
   --local-dir ~/huggingface/Qwen3-0.6B/ \
@@ -25,7 +25,7 @@ huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
 
 ## Quick Start
 
-See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method.
+See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method:
 ```python
 from nanovllm import LLM, SamplingParams
 llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
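The hunk cuts the Quick Start snippet off right after the constructor. For orientation, here is a plausible continuation in the README's own style (the sampling values and the output field are illustrative assumptions, not part of this diff):

```python
from nanovllm import LLM, SamplingParams

llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
sampling_params = SamplingParams(temperature=0.6, max_tokens=256)  # illustrative values
outputs = llm.generate(["Hello, Nano-vLLM."], sampling_params)     # one result per prompt
print(outputs[0]["text"])                                          # field name as in the README example
```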
nanovllm/config.py
@@ -6,7 +6,7 @@ from transformers import AutoConfig
 @dataclass
 class Config:
     model: str
-    max_num_batched_tokens: int = 32768
+    max_num_batched_tokens: int = 16384
     max_num_seqs: int = 512
     max_model_len: int = 4096
     gpu_memory_utilization: float = 0.9
@@ -23,3 +23,4 @@ class Config:
         assert 1 <= self.tensor_parallel_size <= 8
         self.hf_config = AutoConfig.from_pretrained(self.model)
         self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
+        assert self.max_num_batched_tokens >= self.max_model_len
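The new assert closes a warmup foot-gun: if `max_num_batched_tokens` fell below `max_model_len`, the dummy-batch computation in `warmup_model` (later in this diff) would floor-divide to zero sequences and warm up nothing. Worked through with the defaults from this diff:

```python
# Why the new config assert matters for warmup (values are the new defaults).
max_num_batched_tokens = 16384
max_model_len = 4096
max_num_seqs = 512

# warmup_model() builds this many max-length dummy sequences:
num_seqs = min(max_num_batched_tokens // max_model_len, max_num_seqs)
print(num_seqs)  # 4

# With max_num_batched_tokens < max_model_len the division yields 0 and the
# warmup batch would be empty, so the config now rejects that upfront.
```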
nanovllm/engine/model_runner.py
@@ -22,7 +22,7 @@ class ModelRunner:
         self.world_size = config.tensor_parallel_size
         self.rank = rank
         self.event = event
 
         dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
         torch.cuda.set_device(rank)
         default_dtype = torch.get_default_dtype()
@@ -31,7 +31,8 @@ class ModelRunner:
         self.model = Qwen3ForCausalLM(hf_config)
         load_model(self.model, config.model)
         self.sampler = Sampler()
-        self.allocate_kv_cache(config.gpu_memory_utilization)
+        self.warmup_model()
+        self.allocate_kv_cache()
         if not self.enforce_eager:
             self.capture_cudagraph()
         torch.set_default_device("cpu")
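This reordering is the core of the PR: run a worst-case forward pass first, record the allocator's peak, then size the KV cache from what is actually left. A minimal sketch of that measure-then-allocate pattern (standalone, needs a CUDA device; not the ModelRunner code itself):

```python
import torch

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# ... run one max-sized dummy batch here, as warmup_model() does ...

stats = torch.cuda.memory_stats()
peak = stats["allocated_bytes.all.peak"]        # high-water mark during warmup
current = stats["allocated_bytes.all.current"]  # steady-state allocation after it
free, total = torch.cuda.mem_get_info()
used = total - free

# Transient headroom the forward pass needs beyond what it keeps allocated;
# allocate_kv_cache() subtracts exactly this from the KV-cache budget.
headroom = peak - current
```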
@@ -76,7 +77,6 @@ class ModelRunner:
         assert self.world_size > 1 and not self.rank
         data = pickle.dumps([method_name, *args])
         n = len(data)
-        assert n + 4 <= self.shm.size
         self.shm.buf[0:4] = n.to_bytes(4, "little")
         self.shm.buf[4:n+4] = data
         for event in self.event:
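For context on the lines around the dropped assert: `write_shm` frames each message as a 4-byte little-endian length prefix followed by a pickled `[method_name, *args]` payload. A self-contained sketch of that framing, without nano-vllm's per-rank event signaling (note that with the bounds check gone, callers must keep payloads under the segment size):

```python
# Length-prefixed pickle framing over a shared-memory segment.
import pickle
from multiprocessing.shared_memory import SharedMemory

shm = SharedMemory(create=True, size=2**20)

def write_msg(shm, method_name, *args):
    data = pickle.dumps([method_name, *args])
    n = len(data)
    shm.buf[0:4] = n.to_bytes(4, "little")   # 4-byte length prefix
    shm.buf[4:n+4] = data                    # pickled payload

def read_msg(shm):
    n = int.from_bytes(shm.buf[0:4], "little")
    return pickle.loads(shm.buf[4:n+4])

write_msg(shm, "run", [1, 2, 3], True)
print(read_msg(shm))  # ['run', [1, 2, 3], True]
shm.close(); shm.unlink()
```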
@@ -86,17 +86,28 @@ class ModelRunner:
         if self.world_size > 1 and self.rank == 0:
             self.write_shm(method_name, *args)
         method = getattr(self, method_name, None)
-        assert callable(method)
         return method(*args)
 
-    def allocate_kv_cache(self, gpu_memory_utilization):
+    def warmup_model(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
+        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
+        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+        self.run(seqs, True)
+        torch.cuda.empty_cache()
+
+    def allocate_kv_cache(self):
         config = self.config
         hf_config = config.hf_config
         free, total = torch.cuda.mem_get_info()
         used = total - free
+        peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
+        current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
         num_kv_heads = hf_config.num_key_value_heads // self.world_size
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
-        config.num_kvcache_blocks = int(total * gpu_memory_utilization - used) // block_bytes
+        config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
+        assert config.num_kvcache_blocks > 0
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
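For a feel of the numbers, here is the block arithmetic worked through with Qwen3-0.6B's published config (28 layers, 8 KV heads, head_dim 128, bf16) and hypothetical memory figures; only the `block_bytes` formula comes straight from the hunk, the rest is illustrative:

```python
# KV-cache block sizing, worked through with assumed numbers.
block_size = 256                                  # Sequence.block_size
num_layers, num_kv_heads, head_dim = 28, 8, 128   # Qwen3-0.6B, tensor_parallel_size=1
dtype_bytes = 2                                   # bfloat16

# Factor 2 covers K and V; one block caches block_size tokens per layer.
block_bytes = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype_bytes
print(block_bytes)                    # 29360128, i.e. ~28 MiB per block

GiB = 1 << 30
total, used = 24 * GiB, 2 * GiB       # hypothetical 24 GiB GPU with 2 GiB in use
peak, current = 3 * GiB, 1 * GiB      # hypothetical warmup measurements
gpu_memory_utilization = 0.9
num_blocks = int(total * gpu_memory_utilization - used - peak + current) // block_bytes
print(num_blocks)                     # 643 blocks -> 643 * 256 = 164608 cacheable tokens
```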
@@ -107,10 +118,7 @@ class ModelRunner:
 
     def prepare_block_tables(self, seqs: list[Sequence]):
         max_len = max(len(seq.block_table) for seq in seqs)
-        block_tables = [
-            seq.block_table + [-1] * (max_len - len(seq.block_table))
-            for seq in seqs
-        ]
+        block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
         block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         return block_tables
 
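The comprehension is just the old padding logic collapsed to one line; for reference, what it produces on ragged tables (values invented):

```python
# Pad ragged block tables to a rectangle with -1 sentinels.
block_tables = [[0, 1, 2], [3], [4, 5]]
max_len = max(len(bt) for bt in block_tables)
padded = [bt + [-1] * (max_len - len(bt)) for bt in block_tables]
print(padded)  # [[0, 1, 2], [3, -1, -1], [4, 5, -1]]
```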
@@ -133,6 +141,8 @@ class ModelRunner:
             cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
             max_seqlen_q = max(seqlen_q, max_seqlen_q)
             max_seqlen_k = max(seqlen_k, max_seqlen_k)
+            if not seq.block_table:
+                continue
             for i in range(seq.num_cached_blocks, seq.num_blocks):
                 start = seq.block_table[i] * self.block_size
                 if i != seq.num_blocks - 1:
@@ -140,7 +150,6 @@ class ModelRunner:
                 else:
                     end = start + seq.last_block_num_tokens
                 slot_mapping.extend(list(range(start, end)))
-        assert len(input_ids) == len(slot_mapping)
         if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
             block_tables = self.prepare_block_tables(seqs)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
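This dropped assert is a direct consequence of the new `if not seq.block_table: continue` guard above: a warmup sequence runs before any blocks exist, so it contributes input tokens but no slot-mapping entries. In miniature:

```python
# Why len(input_ids) == len(slot_mapping) no longer holds during warmup.
input_ids = [0] * 4096   # tokens of one max-length warmup sequence
slot_mapping = []        # skipped entirely: the sequence has no block table
print(len(input_ids) == len(slot_mapping))  # False -- a valid state now
```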
@@ -177,7 +186,7 @@ class ModelRunner:
         return temperatures
 
     @torch.inference_mode()
-    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill):
+    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
         if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
             return self.model.compute_logits(self.model(input_ids, positions))
         else:
@@ -206,12 +215,6 @@ class ModelRunner:
 
     @torch.inference_mode()
     def capture_cudagraph(self):
-        get_rng_state = torch.cuda.get_rng_state
-        set_rng_state = torch.cuda.set_rng_state
-        rng_state = torch.cuda.get_rng_state()
-        torch.cuda.get_rng_state = lambda: rng_state
-        torch.cuda.set_rng_state = lambda _: None
-
         config = self.config
         hf_config = config.hf_config
         max_bs = min(self.config.max_num_seqs, 512)
@@ -246,6 +249,3 @@ class ModelRunner:
             block_tables=block_tables,
             outputs=outputs,
         )
-
-        torch.cuda.get_rng_state = get_rng_state
-        torch.cuda.set_rng_state = set_rng_state
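With the RNG monkeypatching removed, `capture_cudagraph` relies on `torch.cuda.graph` directly. For readers unfamiliar with graph capture, a minimal, self-contained capture-and-replay sketch using PyTorch's public API (illustrative model and shapes, not nano-vllm's code):

```python
import torch

model = torch.nn.Linear(64, 64).cuda()
static_x = torch.zeros(8, 64, device="cuda")

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_y = model(static_x)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y = model(static_x)      # capture records the kernels, no real compute

static_x.copy_(torch.randn(8, 64, device="cuda"))
g.replay()                          # rerun the captured kernels on the new input
print(static_y.shape)               # torch.Size([8, 64])
```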
nanovllm/engine/sequence.py
@@ -15,7 +15,7 @@ class Sequence:
     block_size = 256
     counter = count()
 
-    def __init__(self, token_ids: list[int], sampling_params: SamplingParams):
+    def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
         self.seq_id = next(Sequence.counter)
         self.status = SequenceStatus.WAITING
         self.token_ids = copy(token_ids)
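The new default lets `warmup_model` build `Sequence([0] * max_model_len)` without passing sampling parameters. One standard-Python caveat worth noting: a default argument is evaluated once, so every sequence created this way shares a single SamplingParams instance, which is fine as long as it is never mutated. The mechanism, illustrated:

```python
# A default argument is created once, at function definition time.
def f(params=[]):   # same mechanism, shown with a list for visibility
    params.append(1)
    return params

print(f())  # [1]
print(f())  # [1, 1] -- the one default object is reused across calls
```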
nanovllm/layers/attention.py
@@ -61,9 +61,9 @@ class Attention(nn.Module):
         k = k.view(-1, self.num_kv_heads, self.head_dim)
         v = v.view(-1, self.num_kv_heads, self.head_dim)
         context = get_context()
-        k_cache = self.k_cache
-        v_cache = self.v_cache
-        store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
+        k_cache, v_cache = self.k_cache, self.v_cache
+        if k_cache.numel() and v_cache.numel():
+            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
         if context.is_prefill:
             if context.block_tables is not None: # prefix cache
                 k, v = k_cache, v_cache
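The `numel()` guard is what lets the warmup forward pass run before the KV cache exists: at that point each layer's `k_cache`/`v_cache` are still empty placeholder tensors, and `store_kvcache` must be skipped. A tiny illustration (the shapes are made up):

```python
import torch

k_cache = torch.empty(0)                # placeholder before allocate_kv_cache
print(bool(k_cache.numel()))            # False -> store_kvcache is skipped

k_cache = torch.zeros(2, 256, 8, 128)   # after allocation (shape illustrative)
print(bool(k_cache.numel()))            # True  -> KV pairs are written
```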