[WIP] need to fix model to decode correctly.
@@ -118,6 +118,24 @@ class OffloadEngine:
             dtype=dtype, device="cuda"
         )
 
+        # ========== Per-layer decode buffer ==========
+        # During decode, all layers share decode_slot (the GPU cache has no
+        # layer dimension), so tokens accumulated by one layer get
+        # overwritten by the next.
+        # Solution: maintain a separate per-layer buffer for decode tokens.
+        # Shape: [num_layers, block_size, kv_heads, head_dim]
+        # Memory: num_layers * block_size * kv_heads * head_dim * dtype_size
+        # e.g., 28 * 1024 * 8 * 128 * 2 bytes ≈ 58.7 MB per buffer (acceptable)
+        self.decode_k_buffer = torch.zeros(
+            num_layers, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        self.decode_v_buffer = torch.zeros(
+            num_layers, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
+        logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB")
+
         # ========== Fixed-address CPU KV cache (pinned memory) ==========
         self.k_cache_cpu = torch.zeros(
             num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
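
A minimal sketch of how these per-layer buffers would be used, assuming hypothetical helper names (append_decode_kv, flush_decode_block) and a v_cache_cpu tensor mirroring the k_cache_cpu shown above; none of these names appear in the commit itself:

import torch

def append_decode_kv(engine, layer_idx: int, pos: int,
                     k: torch.Tensor, v: torch.Tensor) -> None:
    # Each layer writes the new token's K/V into its own row of the
    # per-layer buffer, so later layers no longer overwrite earlier ones.
    # k, v: [num_kv_heads, head_dim] for the single decoded token.
    engine.decode_k_buffer[layer_idx, pos] = k
    engine.decode_v_buffer[layer_idx, pos] = v

def flush_decode_block(engine, block_id: int) -> None:
    # Once block_size tokens have accumulated, copy every layer's slice
    # into the pinned CPU cache in one shot, then reuse the buffer.
    # Assumes a v_cache_cpu allocated alongside k_cache_cpu.
    engine.k_cache_cpu[:, block_id].copy_(engine.decode_k_buffer, non_blocking=True)
    engine.v_cache_cpu[:, block_id].copy_(engine.decode_v_buffer, non_blocking=True)
    torch.cuda.synchronize()  # async D2H copies must finish before the buffer is reused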
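The hunk ends before the k_cache_cpu allocation completes. As a rough illustration of the "pinned memory" note in its comment (the sizes and the pin_memory flag are assumptions about intent, not the commit's actual code):

import torch

# Illustrative sizes matching the example in the comment above;
# num_cpu_blocks is a made-up value.
num_layers, num_cpu_blocks, block_size = 28, 64, 1024
num_kv_heads, head_dim, dtype = 8, 128, torch.float16

# pin_memory=True allocates page-locked host memory at a fixed address,
# which is what makes asynchronous (non_blocking=True) GPU<->CPU copies possible.
k_cache_cpu = torch.zeros(
    num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
    dtype=dtype, pin_memory=True,
)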