[claudesquad] update from 'int-minference-1' on 08 Jan 26 23:22 CST

This commit is contained in:
Zijie Tian
2026-01-08 23:22:38 +08:00
parent 0bfe1984ef
commit ea4e904de0
11 changed files with 853 additions and 533 deletions

View File

@@ -531,16 +531,23 @@ class ModelRunner:
# RoPE
q, k = layer.self_attn.rotary_emb(positions, q, k)
# Full attention using FlashAttention
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
# Sparse or Full attention
if self.sparse_prefill_policy is not None:
# MInference or other sparse prefill policy
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
q, k, v, layer_id
)
else:
# Full attention using FlashAttention
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
# O projection
attn_output = attn_output.view(total_tokens, -1)
@@ -550,16 +557,8 @@ class ModelRunner:
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
hidden_states = layer.mlp(hidden_states)
# 2d. Offload KV to CPU (synchronous to avoid race condition)
# NOTE: Async offload has race condition where k,v memory gets reused
# before D2H copy completes. Use sync copy for correctness.
block_size = offload_engine.block_size
for i, cpu_block_id in enumerate(cpu_block_ids):
start = i * block_size
end = min(start + block_size, total_tokens)
actual_size = end - start
offload_engine.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
offload_engine.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])
# 2d. Offload KV to CPU (encapsulated with sparse policy hooks)
offload_engine.offload_layer_kv_sync(layer_id, k, v, cpu_block_ids, total_tokens)
# Step 3: Final norm
hidden_states, _ = self.model.model.norm(hidden_states, residual)

View File

@@ -336,7 +336,8 @@ class OffloadEngine:
"""
Async offload entire decode buffer to CPU.
Called when a decode block is full.
Called when a decode block is full. Also calls sparse policy hooks
to update metadata (e.g., Quest min/max keys).
Args:
cpu_block_id: Target CPU block ID
@@ -346,6 +347,14 @@ class OffloadEngine:
self.decode_offload_stream.wait_stream(self.compute_stream)
for layer_id in range(self.num_layers):
# Hook: notify sparse policy BEFORE offload (k still on GPU)
if self.sparse_policy is not None:
self.sparse_policy.on_decode_offload(
cpu_block_id, layer_id,
self.decode_k_buffer[layer_id],
self.block_size # Full block
)
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
self.decode_k_buffer[layer_id], non_blocking=True
)
@@ -359,3 +368,42 @@ class OffloadEngine:
def wait_decode_offload(self) -> None:
"""Wait for decode buffer offload to complete."""
self.compute_stream.wait_event(self.decode_offload_event)
# ========== Encapsulated Prefill Offload API (with sparse hooks) ==========
def offload_layer_kv_sync(
self,
layer_id: int,
k: Tensor,
v: Tensor,
cpu_block_ids: List[int],
total_tokens: int,
) -> None:
"""
Synchronously offload layer KV to CPU with sparse policy hooks.
This method encapsulates:
1. Block-wise copy to CPU cache
2. Sparse policy hooks (on_prefill_offload for Quest metadata)
Args:
layer_id: Layer index
k: Key tensor [seq_len, kv_heads, head_dim]
v: Value tensor [seq_len, kv_heads, head_dim]
cpu_block_ids: List of CPU block IDs to offload to
total_tokens: Total number of tokens
"""
for i, cpu_block_id in enumerate(cpu_block_ids):
start = i * self.block_size
end = min(start + self.block_size, total_tokens)
actual_size = end - start
# Hook: notify sparse policy BEFORE offload (k still on GPU)
if self.sparse_policy is not None:
self.sparse_policy.on_prefill_offload(
cpu_block_id, layer_id, k[start:end], actual_size
)
# Synchronous copy to CPU
self.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
self.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])

View File

@@ -25,6 +25,7 @@ class FullAttentionPolicy(SparsePolicy):
# Full attention supports both prefill and decode
supports_prefill = True
supports_decode = True
requires_block_selection = False # Load all blocks, no selective loading
def select_blocks(
self,

View File

@@ -30,6 +30,7 @@ class MInferencePolicy(SparsePolicy):
supports_prefill = True
supports_decode = False # MInference is prefill-only sparse strategy
requires_block_selection = False # MInference only affects attention computation, not KV load
def __init__(
self,

View File

@@ -77,6 +77,12 @@ class SparsePolicy(ABC):
supports_prefill: bool = True
supports_decode: bool = True
# Whether this policy requires selective block loading during decode
# If True: OffloadEngine will call select_blocks() before loading KV from CPU
# If False: OffloadEngine will load all blocks (select_blocks ignored for load)
# Example: MInference=False (only affects attention), Quest=True (affects load)
requires_block_selection: bool = False
def initialize(
self,
num_layers: int,

View File

@@ -158,6 +158,7 @@ class QuestPolicy(SparsePolicy):
# Quest is decode-only
supports_prefill = False
supports_decode = True
requires_block_selection = True # Quest affects KV load strategy (selective block loading)
def __init__(self, config: QuestConfig):
"""