[claudesquad] update from 'int-minference-1' on 08 Jan 26 23:22 CST
This commit is contained in:
@@ -336,7 +336,8 @@ class OffloadEngine:
|
||||
"""
|
||||
Async offload entire decode buffer to CPU.
|
||||
|
||||
Called when a decode block is full.
|
||||
Called when a decode block is full. Also calls sparse policy hooks
|
||||
to update metadata (e.g., Quest min/max keys).
|
||||
|
||||
Args:
|
||||
cpu_block_id: Target CPU block ID
|
||||
@@ -346,6 +347,14 @@ class OffloadEngine:
|
||||
self.decode_offload_stream.wait_stream(self.compute_stream)
|
||||
|
||||
for layer_id in range(self.num_layers):
|
||||
# Hook: notify sparse policy BEFORE offload (k still on GPU)
|
||||
if self.sparse_policy is not None:
|
||||
self.sparse_policy.on_decode_offload(
|
||||
cpu_block_id, layer_id,
|
||||
self.decode_k_buffer[layer_id],
|
||||
self.block_size # Full block
|
||||
)
|
||||
|
||||
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
|
||||
self.decode_k_buffer[layer_id], non_blocking=True
|
||||
)
|
||||
@@ -359,3 +368,42 @@ class OffloadEngine:
|
||||
def wait_decode_offload(self) -> None:
    """Make the compute stream wait until the decode-buffer offload event fires."""
    offload_done = self.decode_offload_event
    self.compute_stream.wait_event(offload_done)
|
||||
|
||||
# ========== Encapsulated Prefill Offload API (with sparse hooks) ==========
|
||||
|
||||
def offload_layer_kv_sync(
    self,
    layer_id: int,
    k: Tensor,
    v: Tensor,
    cpu_block_ids: List[int],
    total_tokens: int,
) -> None:
    """
    Synchronously copy one layer's KV tensors into the CPU cache, block by block.

    For every destination CPU block this:
      1. Invokes the sparse policy's ``on_prefill_offload`` hook while the key
         slice is still on the GPU (e.g. for Quest min/max key metadata).
      2. Copies the corresponding K/V slices into ``self.k_cache_cpu`` /
         ``self.v_cache_cpu`` with blocking ``copy_`` calls.

    Args:
        layer_id: Layer index.
        k: Key tensor [seq_len, kv_heads, head_dim].
        v: Value tensor [seq_len, kv_heads, head_dim].
        cpu_block_ids: Destination CPU block IDs, one per block of tokens.
        total_tokens: Total number of tokens held in ``k`` / ``v``.
    """
    for block_idx, cpu_block_id in enumerate(cpu_block_ids):
        tok_start = block_idx * self.block_size
        tok_end = min(tok_start + self.block_size, total_tokens)
        n_tokens = tok_end - tok_start
        k_slice = k[tok_start:tok_end]

        # Let the sparse policy observe the keys before they leave the GPU.
        if self.sparse_policy is not None:
            self.sparse_policy.on_prefill_offload(
                cpu_block_id, layer_id, k_slice, n_tokens
            )

        # Synchronous copies into the CPU-side caches (last block may be partial).
        self.k_cache_cpu[layer_id, cpu_block_id, :n_tokens].copy_(k_slice)
        self.v_cache_cpu[layer_id, cpu_block_id, :n_tokens].copy_(v[tok_start:tok_end])
|
||||
|
||||
Reference in New Issue
Block a user