[claudesquad] update from 'int-minference-1' on 08 Jan 26 23:22 CST
This commit is contained in:
@@ -531,16 +531,23 @@ class ModelRunner:
|
||||
# RoPE
|
||||
q, k = layer.self_attn.rotary_emb(positions, q, k)
|
||||
|
||||
# Full attention using FlashAttention
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=total_tokens,
|
||||
max_seqlen_k=total_tokens,
|
||||
softmax_scale=layer.self_attn.attn.scale,
|
||||
causal=True,
|
||||
)
|
||||
# Sparse or Full attention
|
||||
if self.sparse_prefill_policy is not None:
|
||||
# MInference or other sparse prefill policy
|
||||
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
|
||||
q, k, v, layer_id
|
||||
)
|
||||
else:
|
||||
# Full attention using FlashAttention
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=total_tokens,
|
||||
max_seqlen_k=total_tokens,
|
||||
softmax_scale=layer.self_attn.attn.scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
# O projection
|
||||
attn_output = attn_output.view(total_tokens, -1)
|
||||
@@ -550,16 +557,8 @@ class ModelRunner:
|
||||
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = layer.mlp(hidden_states)
|
||||
|
||||
# 2d. Offload KV to CPU (synchronous to avoid race condition)
|
||||
# NOTE: Async offload has race condition where k,v memory gets reused
|
||||
# before D2H copy completes. Use sync copy for correctness.
|
||||
block_size = offload_engine.block_size
|
||||
for i, cpu_block_id in enumerate(cpu_block_ids):
|
||||
start = i * block_size
|
||||
end = min(start + block_size, total_tokens)
|
||||
actual_size = end - start
|
||||
offload_engine.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
|
||||
offload_engine.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])
|
||||
# 2d. Offload KV to CPU (encapsulated with sparse policy hooks)
|
||||
offload_engine.offload_layer_kv_sync(layer_id, k, v, cpu_block_ids, total_tokens)
|
||||
|
||||
# Step 3: Final norm
|
||||
hidden_states, _ = self.model.model.norm(hidden_states, residual)
|
||||
|
||||
@@ -336,7 +336,8 @@ class OffloadEngine:
|
||||
"""
|
||||
Async offload entire decode buffer to CPU.
|
||||
|
||||
Called when a decode block is full.
|
||||
Called when a decode block is full. Also calls sparse policy hooks
|
||||
to update metadata (e.g., Quest min/max keys).
|
||||
|
||||
Args:
|
||||
cpu_block_id: Target CPU block ID
|
||||
@@ -346,6 +347,14 @@ class OffloadEngine:
|
||||
self.decode_offload_stream.wait_stream(self.compute_stream)
|
||||
|
||||
for layer_id in range(self.num_layers):
|
||||
# Hook: notify sparse policy BEFORE offload (k still on GPU)
|
||||
if self.sparse_policy is not None:
|
||||
self.sparse_policy.on_decode_offload(
|
||||
cpu_block_id, layer_id,
|
||||
self.decode_k_buffer[layer_id],
|
||||
self.block_size # Full block
|
||||
)
|
||||
|
||||
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
|
||||
self.decode_k_buffer[layer_id], non_blocking=True
|
||||
)
|
||||
@@ -359,3 +368,42 @@ class OffloadEngine:
|
||||
def wait_decode_offload(self) -> None:
    """Block the compute stream until the decode-buffer offload has finished."""
    evt = self.decode_offload_event
    self.compute_stream.wait_event(evt)
|
||||
|
||||
# ========== Encapsulated Prefill Offload API (with sparse hooks) ==========
|
||||
|
||||
def offload_layer_kv_sync(
    self,
    layer_id: int,
    k: Tensor,
    v: Tensor,
    cpu_block_ids: List[int],
    total_tokens: int,
) -> None:
    """
    Synchronously copy one layer's prefill KV into the CPU cache, block by block.

    For each destination block this performs:
      1. The sparse-policy hook (e.g. Quest metadata) while `k` is still on GPU.
      2. A synchronous block-wise copy into `k_cache_cpu` / `v_cache_cpu`.

    Args:
        layer_id: Layer index.
        k: Key tensor [seq_len, kv_heads, head_dim].
        v: Value tensor [seq_len, kv_heads, head_dim].
        cpu_block_ids: Destination CPU block IDs.
        total_tokens: Total number of tokens to offload.
    """
    blk = self.block_size  # loop-invariant; hoisted for clarity
    for idx, dst_block in enumerate(cpu_block_ids):
        tok_start = idx * blk
        tok_end = min(tok_start + blk, total_tokens)
        n_tokens = tok_end - tok_start

        # Notify the sparse policy BEFORE copying, while keys live on GPU.
        policy = self.sparse_policy
        if policy is not None:
            policy.on_prefill_offload(
                dst_block, layer_id, k[tok_start:tok_end], n_tokens
            )

        # Blocking D2H copies — correctness over overlap here.
        self.k_cache_cpu[layer_id, dst_block, :n_tokens].copy_(k[tok_start:tok_end])
        self.v_cache_cpu[layer_id, dst_block, :n_tokens].copy_(v[tok_start:tok_end])
|
||||
|
||||
@@ -25,6 +25,7 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
# Full attention supports both prefill and decode
|
||||
supports_prefill = True
|
||||
supports_decode = True
|
||||
requires_block_selection = False # Load all blocks, no selective loading
|
||||
|
||||
def select_blocks(
|
||||
self,
|
||||
|
||||
@@ -30,6 +30,7 @@ class MInferencePolicy(SparsePolicy):
|
||||
|
||||
supports_prefill = True
|
||||
supports_decode = False # MInference is prefill-only sparse strategy
|
||||
requires_block_selection = False # MInference only affects attention computation, not KV load
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -77,6 +77,12 @@ class SparsePolicy(ABC):
|
||||
supports_prefill: bool = True
|
||||
supports_decode: bool = True
|
||||
|
||||
# Whether this policy requires selective block loading during decode
|
||||
# If True: OffloadEngine will call select_blocks() before loading KV from CPU
|
||||
# If False: OffloadEngine will load all blocks (select_blocks ignored for load)
|
||||
# Example: MInference=False (only affects attention), Quest=True (affects load)
|
||||
requires_block_selection: bool = False
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
num_layers: int,
|
||||
|
||||
@@ -158,6 +158,7 @@ class QuestPolicy(SparsePolicy):
|
||||
# Quest is decode-only
|
||||
supports_prefill = False
|
||||
supports_decode = True
|
||||
requires_block_selection = True # Quest affects KV load strategy (selective block loading)
|
||||
|
||||
def __init__(self, config: QuestConfig):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user