9.2 KiB
9.2 KiB
Notes: Sparsity Integration into Layerwise Offload
Current Architecture Analysis
GPU-Only Path vs Offload Path
| Aspect | GPU-Only | Layerwise Offload |
|---|---|---|
| KV Storage | GPU blocks (paged) | CPU pinned + GPU ring buffer |
| Prefill | All layers → then attention | Per-layer: attention → offload |
| Decode | FlashAttn with block table | Ring buffer H2D → FlashAttn |
| Sparse Support | MInference via attention.py |
Not integrated |
MInference Flow (GPU-Only)
attention.py:101-105:
if context.sparse_prefill_policy is not None:
o = context.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
minference.py:sparse_prefill_attention():
1. estimate_pattern(q, k, layer_id) -> vertical_indices, slash_indices
2. _triton_mixed_sparse_attention(q, k, v, indices)
3. return output
Quest Flow (GPU Block Mode)
hybrid_manager.py (if using CPU offload with Quest):
select_blocks(available_blocks, ctx) -> selected block IDs
-> load selected blocks to GPU
-> standard FlashAttn with loaded blocks
Layerwise Offload Prefill Flow
model_runner.py:run_layerwise_offload_prefill():
for layer_id in range(num_layers):
# QKV projection
q, k, v = qkv_proj(hidden_ln)
# RoPE
q, k = rotary_emb(positions, q, k)
# FULL attention (no sparsity!)
attn_output = flash_attn_varlen_func(q, k, v, ...)
# MLP
hidden_states = mlp(attn_out + residual)
# Sync offload ALL k, v to CPU
for block_id in cpu_block_ids:
k_cache_cpu[layer_id, block_id].copy_(k[start:end])
v_cache_cpu[layer_id, block_id].copy_(v[start:end])
Layerwise Offload Decode Flow
model_runner.py:run_layerwise_offload_decode():
# Preload first N layers to ring buffer
for i in range(num_buffers):
offload_engine.load_layer_kv_to_buffer(i, i, cpu_block_table, valid_tokens)
for layer_id in range(num_layers):
current_buffer = layer_id % num_buffers
# Wait for buffer load
offload_engine.wait_buffer_load(current_buffer)
# Get prefilled KV from ring buffer (ALL blocks loaded)
k_prefill, v_prefill = offload_engine.get_buffer_kv(current_buffer, total_prefill_tokens)
# QKV for new token
q, k_new, v_new = qkv_proj(hidden_ln)
# Concat and full attention
k_full = torch.cat([k_prefill, k_decode_prev, k_new])
attn_output = flash_attn_varlen_func(q, k_full, v_full, ...)
# Start loading next layer
offload_engine.load_layer_kv_to_buffer(current_buffer, layer_id + num_buffers, ...)
Integration Points
1. Prefill Sparse Integration Point
Location: model_runner.py:535-543
Current:
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
After Integration:
if self.sparse_policy and self.sparse_policy.supports_offload_prefill:
attn_output, k_sparse, v_sparse = self.sparse_policy.offload_prefill_attention(
q, k, v, layer_id
)
k_to_offload = k_sparse if k_sparse is not None else k
v_to_offload = v_sparse if v_sparse is not None else v
else:
attn_output = flash_attn_varlen_func(q, k, v, ...)
k_to_offload, v_to_offload = k, v
2. Decode Sparse Integration Point
Location: model_runner.py:636-637 and model_runner.py:704-706
Current (preload):
for i in range(num_preload):
offload_engine.load_layer_kv_to_buffer(
i, i, cpu_block_table, valid_tokens_per_block
)
After Integration:
for i in range(num_preload):
layer_to_load = i
if self.sparse_policy and self.sparse_policy.supports_offload_decode:
# Prepare q for this layer (need to compute ahead)
# OR: use previous layer's pattern as estimate
selected_blocks = self.sparse_policy.select_offload_blocks(
None, # q not available yet at preload
layer_to_load,
cpu_block_table,
valid_tokens_per_block
)
else:
selected_blocks = cpu_block_table
offload_engine.load_sparse_layer_kv_to_buffer(
i, layer_to_load, selected_blocks, valid_tokens_per_block
)
Challenge: Q is not available during preload phase!
Solutions:
- Skip sparse preload, only sparse for non-preloaded layers
- Use previous decode step's pattern as estimate
- Add preload hook to sparse policy
3. Offload Engine Extension
New Method in OffloadEngine:
def load_sparse_layer_kv_to_buffer(
self,
buffer_idx: int,
layer_id: int,
selected_cpu_block_ids: List[int],
original_valid_tokens: List[int],
) -> int:
"""
Load only selected blocks from CPU to buffer.
Returns:
Total tokens loaded (may be less than full sequence)
"""
stream = self.layer_load_streams[buffer_idx]
with torch.cuda.stream(stream):
stream.wait_event(self.buffer_compute_done_events[buffer_idx])
# Build mapping: original block -> selected position
offset = 0
for i, cpu_block_id in enumerate(selected_cpu_block_ids):
# Find original index to get valid tokens
valid_tokens = original_valid_tokens[i] # Need mapping
self.layer_k_cache[buffer_idx, offset:offset+valid_tokens].copy_(
self.k_cache_cpu[layer_id, cpu_block_id, :valid_tokens],
non_blocking=True
)
# ... v_cache same
offset += valid_tokens
self.buffer_load_events[buffer_idx].record(stream)
return offset # Caller needs to know actual loaded tokens
Metadata Flow for Quest
During Prefill Offload
Current: No metadata collection in offload path
Required: Call on_prefill_offload() for each block
# In run_layerwise_offload_prefill()
for i, cpu_block_id in enumerate(cpu_block_ids):
start = i * block_size
end = min(start + block_size, total_tokens)
actual_size = end - start
# BEFORE offload: update Quest metadata
if self.sparse_policy and hasattr(self.sparse_policy, 'on_prefill_offload'):
self.sparse_policy.on_prefill_offload(
cpu_block_id, layer_id, k[start:end], actual_size
)
# Offload
offload_engine.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
offload_engine.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])
Quest Metadata Shape
# BlockMetadataManager
key_min: [num_blocks, num_layers, num_kv_heads, head_dim] # Min key per block per layer
key_max: [num_blocks, num_layers, num_kv_heads, head_dim] # Max key per block per layer
Memory: 2 * num_blocks * num_layers * kv_heads * head_dim * 2 bytes
- Example: 1000 blocks * 28 layers * 4 heads * 128 dim * 2 * 2 = ~57 MB
Performance Considerations
MInference Prefill Overhead
| Operation | Time (64K seq) |
|---|---|
| Pattern estimation (last-64) | ~5ms |
| Triton sparse attention | ~80ms |
| Full FlashAttention | ~100ms |
| Net Speedup | ~15-20% |
Quest Decode Overhead
| Operation | Time |
|---|---|
| Block scoring (GPU metadata) | ~0.1ms |
| Top-K selection | ~0.05ms |
| Sparse H2D load (8 blocks) | ~2ms |
| Full H2D load (100 blocks) | ~20ms |
| Net Speedup | ~10x H2D |
Memory Trade-offs
| Mode | GPU Memory | CPU Memory | H2D Bandwidth |
|---|---|---|---|
| Full offload | Ring buffer | Full KV | High |
| Sparse offload | Ring buffer | Full KV | Low (subset) |
| Aggressive sparse | Ring buffer | Sparse KV | Very low |
Edge Cases
1. Short Sequences (< sparse threshold)
if total_tokens < sparse_threshold:
# Fall back to full attention
use_sparse = False
2. First Decode Step (no previous Q)
Quest can't score blocks without Q. Options:
- Use average embedding as proxy
- Load all blocks for first step
- Use prefill pattern as estimate
3. Variable Sequence Lengths in Batch
Layerwise offload currently only supports batch_size=1:
assert len(seqs) == 1, "Layer-wise offload only supports single sequence"
Sparse integration should maintain this constraint.
4. Ring Buffer vs Sparse Load Mismatch
Ring buffer assumes fixed total_prefill_tokens:
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, total_prefill_tokens)
Sparse load has variable token count. Need:
# Track actual loaded tokens per buffer
loaded_tokens[buffer_idx] = sparse_load_count
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, loaded_tokens[buffer_idx])
Testing Strategy
Unit Tests
test_sparse_policy_interface.py- Verify new interface methodstest_minference_offload.py- MInference in offload modetest_quest_offload.py- Quest block selection in offload mode
Integration Tests
test_offload_sparse_e2e.py- Full prefill+decode with sparsitytest_accuracy_comparison.py- Compare outputs: full vs sparse
Benchmarks
bench_offload_sparse.py- Compare:- Full offload (baseline)
- MInference prefill + Quest decode
- Aggressive sparse offload