[WIP] Before fix needle.

This commit is contained in:
Zijie Tian
2025-12-31 23:35:25 +08:00
parent ccd1b3d4ab
commit 30462fe89a
5 changed files with 212 additions and 290 deletions

View File

@@ -489,24 +489,15 @@ class ModelRunner:
logical_id = seq.block_table[block_idx]
self.kvcache_manager.prefilled_blocks.add(logical_id)
# Offload this chunk's ring buffer slot to CPU (async)
# NOTE: Per-layer offloading is now done in attention.forward
# Each layer offloads its KV to CPU immediately after computing attention.
# We just need to wait for the last offload to complete before reusing the slot.
if block_idx < len(cpu_block_ids):
cpu_block_id = cpu_block_ids[block_idx]
# Call sparse policy hook before offload (to capture metadata)
sparse_policy = self.kvcache_manager.sparse_policy
if sparse_policy is not None:
num_tokens = chunk_end - chunk_start
for layer_id in range(offload_engine.num_layers):
k_cache = offload_engine.k_cache_gpu[layer_id, write_slot, :num_tokens]
sparse_policy.on_block_offloaded(
cpu_block_id=cpu_block_id,
layer_id=layer_id,
k_cache=k_cache,
num_valid_tokens=num_tokens,
)
offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id)
# TODO: Sparse policy hook needs update for new GPU cache architecture
# The GPU cache no longer has layer dimension, so we can't access
# k_cache_gpu[layer_id, write_slot]. Sparse policy should be called
# in attention.forward after per-layer offload.
pass
# Wait for offload to complete before next chunk
# (slot will be reused after N chunks)
@@ -628,7 +619,11 @@ class ModelRunner:
if pos_in_block == self.block_size - 1:
last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
if last_cpu_block >= 0:
offload_engine.offload_decode_slot(last_cpu_block)
# TODO: In new GPU cache architecture (no layer dimension),
# decode offload should be done per-layer in attention.forward.
# For now, offload all layers sequentially.
for layer_id in range(offload_engine.num_layers):
offload_engine.offload_decode_slot_layer(layer_id, last_cpu_block)
offload_engine.wait_all_offload_done()
# Reset decode start position for next block
self.kvcache_manager.reset_decode_start_pos(seq)