[fix] Fixed kvcache offload bugs.
@@ -148,19 +148,39 @@ class ModelRunner:
             dtype=hf_config.torch_dtype,
         )
 
-        # Log KV cache allocation info
+        # Log KV cache allocation info with detailed per-token breakdown
         gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2)
         cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2)
         total_memory_mb = gpu_memory_mb + cpu_memory_mb
 
+        # Calculate per-token KV cache usage
+        # KV per token = 2 (K+V) * num_layers * kv_heads * head_dim * dtype_size
+        dtype_size = 2 if hf_config.torch_dtype in [torch.float16, torch.bfloat16] else 4
+        per_token_kv_bytes = 2 * hf_config.num_hidden_layers * num_kv_heads * head_dim * dtype_size
+        per_token_kv_kb = per_token_kv_bytes / 1024
+
+        logger.info(
+            f"KV Cache per-token: {per_token_kv_kb:.2f}KB "
+            f"(2 * {hf_config.num_hidden_layers}layers * {num_kv_heads}kv_heads * {head_dim}head_dim * {dtype_size}bytes)"
+        )
+        logger.info(
+            f"KV Cache per-block: {block_bytes / (1024**2):.2f}MB "
+            f"({per_token_kv_kb:.2f}KB * {self.block_size}tokens)"
+        )
+
         if config.enable_cpu_offload:
+            ping_size = config.num_gpu_kvcache_blocks // 2
+            tokens_per_ping = ping_size * self.block_size
             logger.info(
                 f"KV Cache allocated (Ping-Pong mode): "
                 f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
                 f"CPU={config.num_cpu_kvcache_blocks} blocks ({cpu_memory_mb:.1f}MB), "
-                f"Total={total_memory_mb:.1f}MB"
+                f"Total={total_memory_mb:.1f}MB, "
+                f"block_size={self.block_size}, "
+                f"ping_size={config.num_gpu_kvcache_blocks // 2}"
             )
+            logger.info(
+                f"Ping-Pong config: ping_size={ping_size} blocks, "
+                f"tokens_per_chunk={tokens_per_ping}, "
+                f"block_size={self.block_size}"
+            )
         else:
             logger.info(
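
As a sanity check on the per-token formula above, a minimal standalone sketch. The hyperparameters are illustrative assumptions (Llama-2-7B-like: 32 layers, 32 KV heads, head_dim 128, fp16) and a 16-token block, not values read from this repo's config:

    # Per-token KV size = 2 (K+V) * layers * kv_heads * head_dim * dtype_size
    num_hidden_layers = 32   # assumed
    num_kv_heads = 32        # assumed
    head_dim = 128           # assumed
    dtype_size = 2           # fp16/bf16 -> 2 bytes; fp32 -> 4
    block_size = 16          # assumed tokens per KV block

    per_token_kv_bytes = 2 * num_hidden_layers * num_kv_heads * head_dim * dtype_size
    per_token_kv_kb = per_token_kv_bytes / 1024
    block_bytes = per_token_kv_bytes * block_size

    print(f"per-token: {per_token_kv_kb:.2f}KB")            # 512.00KB
    print(f"per-block: {block_bytes / (1024 ** 2):.2f}MB")  # 8.00MB

At these sizes, every 1024 prompt tokens cost 512MB of KV, which is why long sequences have to spill to CPU.
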
@@ -392,12 +412,12 @@ class ModelRunner:
     def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool:
         """
-        Check if Ping-Pong mode should be used.
+        Check if three-zone mode should be used.
 
-        Use Ping-Pong when:
+        Use three-zone when:
         - CPU offload is enabled
         - There are blocks on CPU (either allocated there or offloaded)
-        - Sequence exceeds GPU capacity
+        - Sequence exceeds GPU Compute zone capacity
         """
         if not hasattr(self.kvcache_manager, 'offload_engine'):
             return False
@@ -409,12 +429,12 @@ class ModelRunner:
         # Check if any blocks are on CPU
         cpu_blocks, _ = self.kvcache_manager.get_all_cpu_blocks(seq)
         if cpu_blocks:
-            # Has CPU blocks - use Ping-Pong
+            # Has CPU blocks - use three-zone
             return True
 
-        # Check if sequence needs more blocks than GPU can hold
-        ping_size = self.kvcache_manager.offload_engine.ping_size
-        if seq.num_blocks > ping_size:
+        # Check if sequence needs more blocks than GPU Compute zone can hold
+        compute_size = self.kvcache_manager.offload_engine.num_compute_blocks
+        if seq.num_blocks > compute_size:
             # Needs chunked processing
             return True
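
The dispatch logic in the two hunks above reduces to a small predicate. A standalone sketch with plain values standing in for the manager and offload-engine state; `num_compute_blocks` mirrors the diff, the other names are hypothetical:

    def should_use_three_zone(has_offload_engine: bool,
                              has_cpu_blocks: bool,
                              num_blocks: int,
                              num_compute_blocks: int) -> bool:
        if not has_offload_engine:
            return False         # no offload support -> normal path
        if has_cpu_blocks:
            return True          # KV already on CPU -> must use three-zone
        return num_blocks > num_compute_blocks   # larger than Compute zone -> chunk

    # With block_size=16 and an assumed 8-block Compute zone, a 200-token
    # sequence occupies ceil(200 / 16) = 13 blocks, so it takes the chunked path:
    print(should_use_three_zone(True, False, 13, 8))   # True
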
@@ -610,29 +630,28 @@ class ModelRunner:
     def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]:
         """
-        Run prefill with Ping-Pong dual buffer (CPU is primary storage).
+        Run prefill with three-zone GPU buffer (CPU is primary storage).
 
         Flow:
         1. All blocks are allocated to CPU (primary storage)
-        2. Process tokens in chunks using Ping/Pong GPU buffers
-        3. After each chunk, offload from GPU to CPU
-        4. Alternate between Ping and Pong buffers
+        2. Process tokens in chunks using the Compute zone GPU buffer
+        3. After each chunk, offload from the Compute zone to CPU
+        4. The Prefetch zone is used to load previous KV (if any)
         """
         import sys
 
-        assert len(seqs) == 1, "Ping-Pong prefill only supports single sequence"
+        assert len(seqs) == 1, "Three-zone prefill only supports single sequence"
         seq = seqs[0]
 
         offload_engine = self.kvcache_manager.offload_engine
-        ping_size = offload_engine.ping_size
-        tokens_per_chunk = ping_size * self.block_size
+        compute_size = offload_engine.num_compute_blocks
+        tokens_per_chunk = compute_size * self.block_size
 
         total_tokens = len(seq)
-        print(f"[Ping-Pong Prefill] Starting: {total_tokens} tokens, "
-              f"ping_size={ping_size} blocks, chunk={tokens_per_chunk} tokens",
+        print(f"[Three-Zone Prefill] Starting: {total_tokens} tokens, "
+              f"compute_size={compute_size} blocks, chunk={tokens_per_chunk} tokens",
               file=sys.stderr)
 
-        current_buffer = "ping"
         chunk_num = 0
         logits = None
         processed_tokens = 0
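
The chunk loop that follows walks the prompt in Compute-zone-sized windows. A sketch of the boundary arithmetic, using block_size=16 and an 8-block Compute zone as assumed example values:

    block_size = 16        # assumed
    compute_size = 8       # assumed num_compute_blocks
    tokens_per_chunk = compute_size * block_size   # 128

    total_tokens = 300
    processed_tokens = 0
    while processed_tokens < total_tokens:
        chunk_start = processed_tokens
        chunk_end = min(chunk_start + tokens_per_chunk, total_tokens)
        start_block_idx = chunk_start // block_size
        end_block_idx = (chunk_end + block_size - 1) // block_size   # ceil
        print(f"tokens {chunk_start}-{chunk_end}, blocks {start_block_idx}-{end_block_idx - 1}")
        processed_tokens = chunk_end
    # tokens 0-128, blocks 0-7
    # tokens 128-256, blocks 8-15
    # tokens 256-300, blocks 16-18
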
@@ -651,15 +670,13 @@ class ModelRunner:
             end_block_idx = (chunk_end + self.block_size - 1) // self.block_size
             num_blocks = end_block_idx - start_block_idx
 
-            print(f"[Ping-Pong Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
-                  f"blocks {start_block_idx}-{end_block_idx-1}, buffer={current_buffer}",
+            print(f"[Three-Zone Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
+                  f"blocks {start_block_idx}-{end_block_idx-1}, "
+                  f"compute_slots={offload_engine.compute_slots[:num_blocks]}",
                   file=sys.stderr)
 
-            # Get GPU slots for this chunk (Ping or Pong buffer)
-            if current_buffer == "ping":
-                gpu_slots = offload_engine.ping_slots[:num_blocks]
-            else:
-                gpu_slots = offload_engine.pong_slots[:num_blocks]
+            # Get GPU slots for this chunk (use the Compute zone)
+            gpu_slots = offload_engine.compute_slots[:num_blocks]
 
             # Prepare inputs
             input_ids, positions = self._prepare_pingpong_chunk(
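
For orientation, a plausible physical layout of the three zones over the GPU slot pool. The diff only pins down that the Decode zone is slot 0 and that a chunk writes into `compute_slots[:num_blocks]`; the split below is an assumption:

    num_gpu_slots = 17        # assumed pool size
    decode_slot = 0           # Decode zone: slot 0 (per the diff)
    num_compute_blocks = 8    # assumed Compute zone size
    compute_slots = list(range(1, 1 + num_compute_blocks))               # [1..8]
    prefetch_slots = list(range(1 + num_compute_blocks, num_gpu_slots))  # [9..16]

    num_blocks = 3            # blocks in the current chunk
    gpu_slots = compute_slots[:num_blocks]
    print(gpu_slots)          # [1, 2, 3] -- where this chunk's KV lands
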
@@ -678,24 +695,19 @@ class ModelRunner:
                 logical_id = seq.block_table[i]
                 self.kvcache_manager.prefilled_blocks.add(logical_id)
 
-            # Offload this chunk from GPU to CPU (async)
+            # Offload this chunk from the Compute zone to CPU (async)
             chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx]
-            offload_engine.offload_buffer_to_cpu(current_buffer, chunk_cpu_blocks)
+            offload_engine.offload_compute_to_cpu(chunk_cpu_blocks)
 
-            # Switch buffer for next chunk
-            if current_buffer == "ping":
-                offload_engine.wait_ping_offload_done()
-                current_buffer = "pong"
-            else:
-                offload_engine.wait_pong_offload_done()
-                current_buffer = "ping"
+            # Wait for offload to complete before next chunk
+            offload_engine.wait_all_offload_done()
 
             processed_tokens = chunk_end
 
         # Wait for all offloads to complete
         offload_engine.wait_all_offload_done()
 
-        print(f"[Ping-Pong Prefill] Complete: {chunk_num} chunks", file=sys.stderr)
+        print(f"[Three-Zone Prefill] Complete: {chunk_num} chunks", file=sys.stderr)
 
         # Sample from last logits
         temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
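
This hunk replaces the per-buffer ping/pong waits with one `wait_all_offload_done()` per chunk. A minimal sketch of the underlying copy-then-wait pattern on a side stream; names are illustrative and a CUDA device is required (this is not the repo's offload-engine API):

    import torch

    offload_stream = torch.cuda.Stream()     # side stream for D2H copies
    offload_done = torch.cuda.Event()

    gpu_block = torch.randn(16, 8, 128, device="cuda", dtype=torch.float16)
    cpu_block = torch.empty(gpu_block.shape, dtype=gpu_block.dtype,
                            device="cpu", pin_memory=True)

    with torch.cuda.stream(offload_stream):
        cpu_block.copy_(gpu_block, non_blocking=True)   # async device-to-host copy
        offload_done.record()                           # marks the copy on the side stream

    # the next chunk's compute could be queued on the default stream here

    offload_done.synchronize()   # block the host until the copy has landed

Waiting after every chunk is the simple, correct choice; overlapping offload with the next chunk's compute would require double-buffering the Compute zone again.
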
@@ -764,25 +776,26 @@ class ModelRunner:
     def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]:
         """
-        Run decode with Ping-Pong dual buffer.
+        Run decode with three-zone GPU buffer.
 
-        All KV is on CPU. Uses Ping-Pong to load KV chunks and compute attention.
-        New token's KV is written to GPU then offloaded to CPU.
+        All KV is on CPU. Uses the Decode zone to write new KV, and the Compute/Prefetch zones to load KV chunks.
+        New token's KV is written to the Decode zone (slot 0) then offloaded to CPU.
+
+        Key: the Decode zone is never overwritten by Compute/Prefetch; it is dedicated to writing new KV.
         """
-        assert len(seqs) == 1, "Ping-Pong decode only supports single sequence"
+        assert len(seqs) == 1, "Three-zone decode only supports single sequence"
         seq = seqs[0]
 
         offload_engine = self.kvcache_manager.offload_engine
 
         # Prepare inputs
         input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
 
-        # Get write slot for new KV (will use last slot of the buffer used for final chunk)
-        write_slot = self.kvcache_manager.get_write_slot_for_pingpong(seq)
-
-        # Calculate position in block for slot mapping
-        last_block_idx = seq.num_blocks - 1
+        # Use the Decode zone (slot 0) to write new KV
+        decode_slot = offload_engine.decode_slot  # = 0
         pos_in_block = (len(seq) - 1) % self.block_size
-        slot = write_slot * self.block_size + pos_in_block
+        slot = decode_slot * self.block_size + pos_in_block
         slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
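
The decode-side slot mapping is plain integer arithmetic over the Decode zone; a standalone check with block_size=16 as an assumed value:

    block_size = 16       # assumed
    decode_slot = 0       # Decode zone is physical slot 0, per the diff

    seq_len = 53          # current sequence length
    pos_in_block = (seq_len - 1) % block_size        # 52 % 16 = 4
    slot = decode_slot * block_size + pos_in_block   # flat index into the KV cache
    print(slot)           # 4 -> position 4 of physical block 0

Because the Decode zone is slot 0, the flat index equals the in-block offset; after the forward pass the slot's KV is offloaded to the sequence's last CPU block.
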
@@ -794,17 +807,18 @@ class ModelRunner:
             is_chunked_prefill=True,  # Use chunked attention path
             offload_engine=self.kvcache_manager,
             chunked_seq=seq,
+            decode_pos_in_block=pos_in_block,
         )
 
         # Run model forward pass
         logits = self.run_model(input_ids, positions, is_prefill=False)
         reset_context()
 
-        # Offload new KV from write_slot to CPU
+        # Offload new KV from the Decode zone to CPU
         last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
         if last_cpu_block >= 0:
-            self.kvcache_manager.offload_engine.offload_slot_to_cpu(write_slot, last_cpu_block)
-            self.kvcache_manager.offload_engine.wait_all_offload_done()
+            offload_engine.offload_decode_slot(last_cpu_block)
+            offload_engine.wait_all_offload_done()
 
         # Sample
         temperatures = self.prepare_sample(seqs) if self.rank == 0 else None