[WIP] Need to fix the model so it decodes normally.

Zijie Tian
2026-01-01 05:18:27 +08:00
parent 62b8a63314
commit 74ee6d0895
3 changed files with 317 additions and 123 deletions

@@ -118,6 +118,24 @@ class OffloadEngine:
            dtype=dtype, device="cuda"
        )
        # ========== Per-layer decode buffer ==========
        # During decode, all layers share decode_slot (the GPU cache has no
        # layer dimension), so each layer overwrites the tokens accumulated
        # by the previous layer.
        # Solution: maintain a separate per-layer buffer for decode tokens.
        # Shape: [num_layers, block_size, kv_heads, head_dim]
        # Memory per buffer: num_layers * block_size * kv_heads * head_dim * dtype_size
        # e.g., 28 * 1024 * 8 * 128 * 2 bytes ≈ 58.7 MB each for K and V (acceptable)
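        # Intended decode-path write, sketched with hypothetical names
        # (layer_idx, pos, k_new, v_new are illustrative, not part of this commit):
        #   self.decode_k_buffer[layer_idx, pos] = k_new
        #   self.decode_v_buffer[layer_idx, pos] = v_new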
        self.decode_k_buffer = torch.zeros(
            num_layers, block_size, num_kv_heads, head_dim,
            dtype=dtype, device="cuda"
        )
        self.decode_v_buffer = torch.zeros(
            num_layers, block_size, num_kv_heads, head_dim,
            dtype=dtype, device="cuda"
        )
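        # The factor of 2 below counts both the K and V buffers; dtype.itemsize
        # converts element count to bytes, and dividing by 1024**2 yields MiB
        # (logged with the label "MB").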
        decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
        logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB")
        # ========== Fixed-address CPU KV cache (pinned memory) ==========
        self.k_cache_cpu = torch.zeros(
            num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,