- Move notes.md to docs/development_notes.md - Move Xattention_analysis.md to docs/xattention_analysis.md - Delete DEBUG_SUMMARY.md (no longer needed) - Update CLAUDE.md with documentation index entries Co-Authored-By: Claude <noreply@anthropic.com>
325 lines
9.2 KiB
Markdown
# Notes: Sparsity Integration into Layerwise Offload
## Current Architecture Analysis

### GPU-Only Path vs Offload Path

| Aspect | GPU-Only | Layerwise Offload |
|--------|----------|-------------------|
| KV Storage | GPU blocks (paged) | CPU pinned + GPU ring buffer |
| Prefill | All layers → then attention | Per-layer: attention → offload |
| Decode | FlashAttn with block table | Ring buffer H2D → FlashAttn |
| Sparse Support | MInference via `attention.py` | Not integrated |

### MInference Flow (GPU-Only)

```
attention.py:101-105:
    if context.sparse_prefill_policy is not None:
        o = context.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)

minference.py:sparse_prefill_attention():
    1. estimate_pattern(q, k, layer_id) -> vertical_indices, slash_indices
    2. _triton_mixed_sparse_attention(q, k, v, indices)
    3. return output
```

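Step 1 can be illustrated with a minimal single-head sketch of vertical/slash pattern estimation in pure Python. It assumes the "last-64" heuristic listed under Performance Considerations: only the last `last_q` queries are scored against all keys (causally). The key columns with the highest accumulated attention become vertical indices, and the diagonal offsets with the highest accumulated attention become slash indices. The function name, the `top_v`/`top_s` parameters, and the plain-list tensors are hypothetical. This is a simplified stand-in, not the code in `minference.py`.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def estimate_vertical_slash(
    q: Matrix, k: Matrix, last_q: int, top_v: int, top_s: int
) -> Tuple[List[int], List[int]]:
    """Hypothetical sketch: return (vertical key indices, slash offsets) for one head.

    q and k are [seq_len][head_dim]; causal masking is applied.
    A slash offset d is the diagonal where key_pos = query_pos - d.
    """
    seq_len, head_dim = len(q), len(q[0])
    scale = 1.0 / math.sqrt(head_dim)
    vertical = [0.0] * seq_len  # accumulated attention per key column
    slash = [0.0] * seq_len     # accumulated attention per diagonal offset
    for qi in range(max(0, seq_len - last_q), seq_len):
        # Causal: query qi may only attend to keys 0..qi.
        logits = [scale * sum(a * b for a, b in zip(q[qi], k[kj])) for kj in range(qi + 1)]
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]  # numerically stable softmax
        total = sum(exps)
        for kj, e in enumerate(exps):
            p = e / total
            vertical[kj] += p
            slash[qi - kj] += p
    top_vertical = sorted(range(seq_len), key=lambda j: vertical[j], reverse=True)[:top_v]
    top_slash = sorted(range(seq_len), key=lambda d: slash[d], reverse=True)[:top_s]
    return sorted(top_vertical), sorted(top_slash)


# Usage: key 0 acts as a strong "attention sink", so it should be picked as vertical.
q = [[1.0, 0.0] for _ in range(6)]
k = [[10.0, 0.0]] + [[0.0, 0.0] for _ in range(5)]
v_idx, s_idx = estimate_vertical_slash(q, k, last_q=6, top_v=1, top_s=1)
```

The real kernel works on batched GPU tensors and feeds these indices to `_triton_mixed_sparse_attention`. The sketch only shows which statistics the indices summarize.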
### Quest Flow (GPU Block Mode)

```
hybrid_manager.py (if using CPU offload with Quest):
    select_blocks(available_blocks, ctx) -> selected block IDs
    -> load selected blocks to GPU
    -> standard FlashAttn with loaded blocks
```

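The `select_blocks` step ranks blocks by an upper bound on q·k computed from per-block channel-wise key minima and maxima. For each channel, the largest possible `q_d * k_d` is reached at either the minimum or the maximum. The pure-Python sketch below shows that scoring and top-K selection for a single head. It assumes `key_min`/`key_max` are available per block (as described under Quest Metadata Shape). The function name and signature are hypothetical and do not mirror `hybrid_manager.py`.

```python
from typing import Dict, List


def quest_select_blocks(
    q: List[float],
    key_min: Dict[int, List[float]],
    key_max: Dict[int, List[float]],
    top_k: int,
) -> List[int]:
    """Hypothetical sketch: pick the top_k block IDs by Quest's upper-bound score."""
    scores: Dict[int, float] = {}
    for block_id in key_min:
        # Per channel, the largest q_d * k_d any key in the block could produce
        # is attained at either the channel minimum or the channel maximum.
        scores[block_id] = sum(
            max(qd * lo, qd * hi)
            for qd, lo, hi in zip(q, key_min[block_id], key_max[block_id])
        )
    ranked = sorted(scores, key=lambda b: scores[b], reverse=True)
    return sorted(ranked[:top_k])  # return in block order for sequential loading


q = [1.0, -1.0]
key_min = {0: [0.0, 0.0], 1: [-1.0, -2.0], 2: [2.0, 0.0]}
key_max = {0: [1.0, 1.0], 1: [0.0, -1.0], 2: [3.0, 0.0]}
selected = quest_select_blocks(q, key_min, key_max, top_k=2)  # scores: 1, 2, 3
```

A production version would also need to combine scores across heads (for example under GQA). That choice is left open here.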
### Layerwise Offload Prefill Flow

```
model_runner.py:run_layerwise_offload_prefill():
    for layer_id in range(num_layers):
        # QKV projection
        q, k, v = qkv_proj(hidden_ln)

        # RoPE
        q, k = rotary_emb(positions, q, k)

        # FULL attention (no sparsity!)
        attn_output = flash_attn_varlen_func(q, k, v, ...)

        # MLP
        hidden_states = mlp(attn_output + residual)

        # Sync offload ALL k, v to CPU (last block may be partial)
        for i, block_id in enumerate(cpu_block_ids):
            start = i * block_size
            end = min(start + block_size, total_tokens)
            k_cache_cpu[layer_id, block_id, :end - start].copy_(k[start:end])
            v_cache_cpu[layer_id, block_id, :end - start].copy_(v[start:end])
```

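The offload loop needs, for each CPU block, the token slice `[start:end]` of the layer's K/V that the block receives. The final block is shorter when `total_tokens` is not a multiple of `block_size`. A small helper makes that arithmetic explicit. `block_token_ranges` is a hypothetical name, and it assumes blocks are filled contiguously from token 0, as in the metadata loop later in these notes.

```python
from typing import List, Tuple


def block_token_ranges(total_tokens: int, block_size: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: (start, end) token range for each CPU block, in order."""
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    return [
        (start, min(start + block_size, total_tokens))  # last block may be partial
        for start in range(0, total_tokens, block_size)
    ]


ranges = block_token_ranges(10, 4)  # [(0, 4), (4, 8), (8, 10)]
```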
### Layerwise Offload Decode Flow

```
model_runner.py:run_layerwise_offload_decode():
    # Preload first N layers to ring buffer
    for i in range(num_buffers):
        offload_engine.load_layer_kv_to_buffer(i, i, cpu_block_table, valid_tokens)

    for layer_id in range(num_layers):
        current_buffer = layer_id % num_buffers

        # Wait for buffer load
        offload_engine.wait_buffer_load(current_buffer)

        # Get prefilled KV from ring buffer (ALL blocks loaded)
        k_prefill, v_prefill = offload_engine.get_buffer_kv(current_buffer, total_prefill_tokens)

        # QKV for new token
        q, k_new, v_new = qkv_proj(hidden_ln)

        # Concat and full attention
        k_full = torch.cat([k_prefill, k_decode_prev, k_new])
        v_full = torch.cat([v_prefill, v_decode_prev, v_new])
        attn_output = flash_attn_varlen_func(q, k_full, v_full, ...)

        # Start loading the next layer into the buffer just consumed
        if layer_id + num_buffers < num_layers:
            offload_engine.load_layer_kv_to_buffer(current_buffer, layer_id + num_buffers, ...)
```

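The ring-buffer pattern above preloads `num_buffers` layers and then refills each buffer with layer `layer_id + num_buffers` once it has been consumed. That schedule can be simulated to check which buffer holds which layer at each step. The following pure-Python sketch records events instead of issuing CUDA copies. `ring_buffer_schedule` and the event tuples are hypothetical.

```python
from typing import List, Tuple

Event = Tuple[str, int, int]  # (kind, buffer_idx, layer_id), kind is "load" or "compute"


def ring_buffer_schedule(num_layers: int, num_buffers: int) -> List[Event]:
    """Hypothetical simulation of decode-time ring-buffer usage, in issue order."""
    events: List[Event] = []
    # Preload the first num_buffers layers (or fewer if the model is shallow).
    for i in range(min(num_buffers, num_layers)):
        events.append(("load", i, i))
    for layer_id in range(num_layers):
        buf = layer_id % num_buffers
        events.append(("compute", buf, layer_id))
        nxt = layer_id + num_buffers
        # Refill this buffer only if a later layer still needs it.
        if nxt < num_layers:
            events.append(("load", buf, nxt))
    return events


schedule = ring_buffer_schedule(num_layers=4, num_buffers=2)
```

Each `compute` is preceded by a `load` of the same `(buffer, layer)`, and no buffer is reloaded before its current layer has been computed. The real engine enforces this with `buffer_compute_done_events` and `wait_buffer_load`.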
## Integration Points

### 1. Prefill Sparse Integration Point

**Location:** `model_runner.py:535-543`

**Current:**
```python
attn_output = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=total_tokens,
    max_seqlen_k=total_tokens,
    softmax_scale=layer.self_attn.attn.scale,
    causal=True,
)
```

**After Integration:**
```python
if self.sparse_policy and self.sparse_policy.supports_offload_prefill:
    attn_output, k_sparse, v_sparse = self.sparse_policy.offload_prefill_attention(
        q, k, v, layer_id
    )
    k_to_offload = k_sparse if k_sparse is not None else k
    v_to_offload = v_sparse if v_sparse is not None else v
else:
    attn_output = flash_attn_varlen_func(q, k, v, ...)
    k_to_offload, v_to_offload = k, v
```

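Both integration points assume a policy object that exposes `supports_offload_prefill`, `offload_prefill_attention`, `supports_offload_decode`, and `select_offload_blocks`. One way to pin that contract down is a `typing.Protocol`. The sketch below is an assumption about the shape of the interface, not an existing class in the codebase. Argument types are left as `Any` because the tensors are framework-specific, and `FullLoadPolicy` is a hypothetical trivial policy.

```python
from typing import Any, List, Optional, Protocol, Tuple, runtime_checkable


@runtime_checkable
class OffloadSparsePolicy(Protocol):
    """Hypothetical interface for sparse policies used by layerwise offload."""

    supports_offload_prefill: bool
    supports_offload_decode: bool

    def offload_prefill_attention(
        self, q: Any, k: Any, v: Any, layer_id: int
    ) -> Tuple[Any, Optional[Any], Optional[Any]]:
        """Return (attn_output, k_sparse, v_sparse); None means offload full K/V."""
        ...

    def select_offload_blocks(
        self,
        q: Optional[Any],
        layer_id: int,
        cpu_block_table: List[int],
        valid_tokens_per_block: List[int],
    ) -> List[int]:
        """Return CPU block IDs to load; q may be None at preload time."""
        ...


class FullLoadPolicy:
    """Hypothetical baseline: no prefill sparsity, always load every block."""

    supports_offload_prefill = False
    supports_offload_decode = True

    def offload_prefill_attention(self, q, k, v, layer_id):
        raise NotImplementedError("prefill sparsity not supported")

    def select_offload_blocks(self, q, layer_id, cpu_block_table, valid_tokens_per_block):
        return list(cpu_block_table)


policy = FullLoadPolicy()
```

`@runtime_checkable` only checks that the members exist, not their signatures. It is useful as a cheap guard when wiring a policy into `model_runner.py`, but it does not replace interface tests.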
### 2. Decode Sparse Integration Point

**Location:** `model_runner.py:636-637` and `model_runner.py:704-706`

**Current (preload):**
```python
for i in range(num_preload):
    offload_engine.load_layer_kv_to_buffer(
        i, i, cpu_block_table, valid_tokens_per_block
    )
```

**After Integration:**
```python
for i in range(num_preload):
    layer_to_load = i
    if self.sparse_policy and self.sparse_policy.supports_offload_decode:
        # Prepare q for this layer (need to compute ahead)
        # OR: use previous layer's pattern as estimate
        selected_blocks = self.sparse_policy.select_offload_blocks(
            None,  # q not available yet at preload
            layer_to_load,
            cpu_block_table,
            valid_tokens_per_block
        )
    else:
        selected_blocks = cpu_block_table
    # Note: valid_tokens_per_block is aligned with cpu_block_table, not selected_blocks
    offload_engine.load_sparse_layer_kv_to_buffer(
        i, layer_to_load, selected_blocks, valid_tokens_per_block
    )
```

**Challenge:** Q is not available during the preload phase. The decode query for layer `i` is computed only when layer `i` runs, but its KV blocks must already be selected and in flight before then.

**Solutions:**
1. Skip sparse preload and apply sparsity only to non-preloaded layers
2. Use the previous decode step's pattern as an estimate
3. Add a preload hook to the sparse policy

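Solution 2 (reuse the previous decode step's pattern), combined with the "load all blocks for first step" option listed under Edge Cases, can be sketched as a per-layer cache of the last selection. Everything here is a hypothetical design shown only to make the option concrete, including `PreloadSelectionCache`, its methods, and the rule that stale block IDs are dropped.

```python
from typing import Dict, List, Optional


class PreloadSelectionCache:
    """Hypothetical sketch: remember each layer's last block selection for preload."""

    def __init__(self) -> None:
        self._last: Dict[int, List[int]] = {}

    def record(self, layer_id: int, selected_blocks: List[int]) -> None:
        # Called after a layer's blocks were selected with a real q.
        self._last[layer_id] = list(selected_blocks)

    def blocks_for_preload(self, layer_id: int, cpu_block_table: List[int]) -> List[int]:
        cached: Optional[List[int]] = self._last.get(layer_id)
        if cached is None:
            # First decode step: no q has been seen for this layer, so load everything.
            return list(cpu_block_table)
        # Keep only cached IDs still present in the table, in table order.
        allowed = set(cached)
        return [b for b in cpu_block_table if b in allowed]


cache = PreloadSelectionCache()
first_step = cache.blocks_for_preload(0, [7, 8, 9])  # full load
cache.record(0, [9, 7])
later_step = cache.blocks_for_preload(0, [7, 8, 9])  # reuses the previous pattern
```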
### 3. Offload Engine Extension

**New Method in OffloadEngine:**

```python
def load_sparse_layer_kv_to_buffer(
    self,
    buffer_idx: int,
    layer_id: int,
    selected_cpu_block_ids: List[int],
    original_valid_tokens: List[int],
) -> int:
    """
    Load only selected blocks from CPU to buffer, packed contiguously.

    Returns:
        Total tokens loaded (may be less than full sequence)
    """
    stream = self.layer_load_streams[buffer_idx]

    with torch.cuda.stream(stream):
        # Do not overwrite the buffer until compute on its previous layer is done
        stream.wait_event(self.buffer_compute_done_events[buffer_idx])

        # Build mapping: original block -> selected position
        offset = 0
        for i, cpu_block_id in enumerate(selected_cpu_block_ids):
            # Find original index to get valid tokens
            valid_tokens = original_valid_tokens[i]  # Need mapping: indexed by full-table position, not i

            self.layer_k_cache[buffer_idx, offset:offset+valid_tokens].copy_(
                self.k_cache_cpu[layer_id, cpu_block_id, :valid_tokens],
                non_blocking=True
            )
            self.layer_v_cache[buffer_idx, offset:offset+valid_tokens].copy_(
                self.v_cache_cpu[layer_id, cpu_block_id, :valid_tokens],
                non_blocking=True
            )

            offset += valid_tokens

        self.buffer_load_events[buffer_idx].record(stream)

    return offset  # Caller needs to know actual loaded tokens
```

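The `# Need mapping` gap exists because `original_valid_tokens` is indexed by position in the full CPU block table, while the loop index `i` walks the selection. A caller-side helper that reorders the counts to match the selection closes the gap without changing the method signature. `align_valid_tokens` is a hypothetical name, and it assumes CPU block IDs are unique within one block table.

```python
from typing import List


def align_valid_tokens(
    cpu_block_table: List[int],
    valid_tokens_per_block: List[int],
    selected_cpu_block_ids: List[int],
) -> List[int]:
    """Hypothetical helper: valid-token counts aligned with selected_cpu_block_ids."""
    if len(cpu_block_table) != len(valid_tokens_per_block):
        raise ValueError("block table and valid-token list differ in length")
    valid_by_block = dict(zip(cpu_block_table, valid_tokens_per_block))
    try:
        return [valid_by_block[b] for b in selected_cpu_block_ids]
    except KeyError as e:
        raise ValueError(f"selected block {e.args[0]} not in cpu_block_table") from None


aligned = align_valid_tokens([4, 9, 2], [16, 16, 5], [2, 4])  # [5, 16]
```

The caller would then pass `aligned` as `original_valid_tokens`. With that, the sum of `aligned` is exactly the token count the method returns.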
## Metadata Flow for Quest

### During Prefill Offload

**Current:** No metadata collection in the offload path

**Required:** Call `on_prefill_offload()` for each block

```python
# In run_layerwise_offload_prefill()
for i, cpu_block_id in enumerate(cpu_block_ids):
    start = i * block_size
    end = min(start + block_size, total_tokens)
    actual_size = end - start

    # BEFORE offload: update Quest metadata
    if self.sparse_policy and hasattr(self.sparse_policy, 'on_prefill_offload'):
        self.sparse_policy.on_prefill_offload(
            cpu_block_id, layer_id, k[start:end], actual_size
        )

    # Offload
    offload_engine.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
    offload_engine.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])
```

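For Quest, the metadata update in `on_prefill_offload()` amounts to reducing each block's keys to channel-wise minima and maxima before the copy to CPU. Below is a pure-Python sketch for one layer and one KV head, assuming keys arrive as `[tokens][head_dim]` lists. `block_key_bounds` is a hypothetical helper, not the `BlockMetadataManager` API.

```python
from typing import List, Tuple


def block_key_bounds(
    keys: List[List[float]], actual_size: int
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: channel-wise (min, max) over the first actual_size keys."""
    if actual_size <= 0:
        raise ValueError("block must contain at least one valid token")
    valid = keys[:actual_size]  # ignore padding beyond the valid tokens
    key_min = [min(col) for col in zip(*valid)]
    key_max = [max(col) for col in zip(*valid)]
    return key_min, key_max


# The third row is padding in a partial block and must not affect the bounds.
kmin, kmax = block_key_bounds([[1.0, -2.0], [3.0, 0.0], [99.0, 99.0]], actual_size=2)
```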
### Quest Metadata Shape

```python
# BlockMetadataManager
key_min: [num_blocks, num_layers, num_kv_heads, head_dim]  # Min key per block per layer
key_max: [num_blocks, num_layers, num_kv_heads, head_dim]  # Max key per block per layer
```

**Memory:** 2 (min + max) * num_blocks * num_layers * kv_heads * head_dim * 2 bytes (16-bit dtype)
- Example: 1000 blocks * 28 layers * 4 heads * 128 dim * 2 * 2 bytes = 57,344,000 bytes ≈ 57 MB

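The estimate is easy to recompute for other model shapes. A small helper, assuming two metadata tensors (min and max) stored in a 2-byte dtype as in the example above. The function name is hypothetical.

```python
def quest_metadata_bytes(num_blocks: int, num_layers: int, num_kv_heads: int,
                         head_dim: int, bytes_per_elem: int = 2) -> int:
    """Hypothetical helper: bytes for key_min + key_max of shape
    [num_blocks, num_layers, num_kv_heads, head_dim]."""
    num_tensors = 2  # key_min and key_max
    return num_tensors * num_blocks * num_layers * num_kv_heads * head_dim * bytes_per_elem


example = quest_metadata_bytes(1000, 28, 4, 128)
print(example, f"{example / 1e6:.1f} MB")  # 57344000 57.3 MB
```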
## Performance Considerations

### MInference Prefill Overhead

| Operation | Time (64K seq) |
|-----------|----------------|
| Pattern estimation (last-64) | ~5ms |
| Triton sparse attention | ~80ms |
| Full FlashAttention | ~100ms |
| **Net Speedup** | ~15-20% (~85ms vs ~100ms) |

### Quest Decode Overhead

| Operation | Time |
|-----------|------|
| Block scoring (GPU metadata) | ~0.1ms |
| Top-K selection | ~0.05ms |
| Sparse H2D load (8 blocks) | ~2ms |
| Full H2D load (100 blocks) | ~20ms |
| **Net Speedup** | ~10x on H2D (~2.15ms incl. selection vs ~20ms) |

### Memory Trade-offs

| Mode | GPU Memory | CPU Memory | H2D Bandwidth |
|------|------------|------------|---------------|
| Full offload | Ring buffer | Full KV | High |
| Sparse offload | Ring buffer | Full KV | Low (subset) |
| Aggressive sparse | Ring buffer | Sparse KV | Very low |

## Edge Cases

### 1. Short Sequences (< sparse threshold)

```python
if total_tokens < sparse_threshold:
    # Fall back to full attention
    use_sparse = False
```

### 2. First Decode Step (no previous Q)

Quest cannot score blocks without Q. Options:
- Use average embedding as proxy
- Load all blocks for first step
- Use prefill pattern as estimate

### 3. Variable Sequence Lengths in Batch

Layerwise offload currently only supports batch_size=1:

```python
assert len(seqs) == 1, "Layer-wise offload only supports single sequence"
```

Sparse integration should maintain this constraint.

### 4. Ring Buffer vs Sparse Load Mismatch

The ring buffer read assumes a fixed `total_prefill_tokens`:

```python
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, total_prefill_tokens)
```

A sparse load places a variable number of tokens in each buffer, so the count returned by `load_sparse_layer_kv_to_buffer` must be tracked per buffer:

```python
# Track actual loaded tokens per buffer
loaded_tokens[buffer_idx] = sparse_load_count
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, loaded_tokens[buffer_idx])
```

## Testing Strategy

### Unit Tests

1. `test_sparse_policy_interface.py` - Verify new interface methods
2. `test_minference_offload.py` - MInference in offload mode
3. `test_quest_offload.py` - Quest block selection in offload mode

### Integration Tests

1. `test_offload_sparse_e2e.py` - Full prefill+decode with sparsity
2. `test_accuracy_comparison.py` - Compare outputs: full vs sparse

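A useful first check for the full-vs-sparse comparison is that block-sparse attention reproduces full attention exactly when every block is selected, and stays close when the dropped blocks carry negligible softmax weight. Below is a hypothetical pure-Python sketch of that check for a single query. The helper names and the `[block][token][dim]` layout are assumptions, not the project's test code.

```python
import math
from typing import List, Sequence


def attend(q: Sequence[float], keys: List[List[float]], values: List[List[float]]) -> List[float]:
    """Single-query softmax attention with 1/sqrt(d) scaling."""
    scale = 1.0 / math.sqrt(len(q))
    logits = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]  # stable softmax numerators
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def sparse_attend(q, key_blocks, value_blocks, selected: List[int]) -> List[float]:
    """Hypothetical block-sparse attention: only the selected blocks, in order."""
    keys = [k for b in selected for k in key_blocks[b]]
    values = [v for b in selected for v in value_blocks[b]]
    return attend(q, keys, values)


def max_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    return max(abs(x - y) for x, y in zip(a, b))


q = [1.0, 0.0]
key_blocks = [[[4.0, 0.0], [3.0, 0.0]], [[-40.0, 0.0]]]  # block 1 gets ~0 weight
value_blocks = [[[1.0], [2.0]], [[100.0]]]
full = attend(q, [k for blk in key_blocks for k in blk], [v for blk in value_blocks for v in blk])
```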
### Benchmarks

1. `bench_offload_sparse.py` - Compare:
   - Full offload (baseline)
   - MInference prefill + Quest decode
   - Aggressive sparse offload