support qwen2
This commit is contained in:
@@ -105,10 +105,11 @@ class ModelRunner:
|
||||
peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
|
||||
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
|
||||
num_kv_heads = hf_config.num_key_value_heads // self.world_size
|
||||
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
|
||||
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
|
||||
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
|
||||
config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
|
||||
assert config.num_kvcache_blocks > 0
|
||||
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
|
||||
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
|
||||
layer_id = 0
|
||||
for module in self.model.modules():
|
||||
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
|
||||
|
||||
Reference in New Issue
Block a user