♻️ refactor: remove cross-layer pipeline and rename compute_chunked_attention to compute_chunked_prefill
- Remove cross-layer pipeline from OffloadEngine (saves ~1GB GPU memory for long sequences)
- Delete layer_k/v_buffer_a/b double buffers
- Remove start_decode_pipeline, get_decode_layer_kv, end_decode_pipeline methods
- Remove pipeline state tracking variables
- Simplify decode to use ring buffer pipeline only (more efficient for long sequences)
- Rename compute_chunked_attention → compute_chunked_prefill for clarity
- Add mandatory needle test requirements: --enable-offload --input-len 32768

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@@ -10,7 +10,7 @@ SparsePolicy is an abstract base class that defines how attention is computed du
 attention.py                              SparsePolicy
      |                                         |
      |  _chunked_prefill_attention             |
-     | ────────────────────────────────>       | compute_chunked_attention()
+     | ────────────────────────────────>       | compute_chunked_prefill()
      |                                         |
      |  _chunked_decode_attention              |
      | ────────────────────────────────>       | compute_chunked_decode()
@@ -51,7 +51,7 @@ def select_blocks(
         pass

     @abstractmethod
-    def compute_chunked_attention(
+    def compute_chunked_prefill(
         self,
         q: torch.Tensor,
         k: torch.Tensor,
@@ -105,7 +105,7 @@ supports_prefill = True
 supports_decode = True
 ```

-### Prefill Flow (`compute_chunked_attention`)
+### Prefill Flow (`compute_chunked_prefill`)

 ```
 1. Get historical blocks from kvcache_manager
@@ -143,11 +143,8 @@ supports_decode = True
 3. Apply select_blocks for block filtering
    └── cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, ctx)

-4. Load prefilled blocks via pipeline
-   └── IF is_pipeline_active():
-       └── _decode_with_layer_pipeline()   # Cross-layer pipeline
-   └── ELSE:
-       └── _decode_ring_buffer_pipeline()  # Ring buffer fallback
+4. Load prefilled blocks via ring buffer pipeline
+   └── _decode_ring_buffer_pipeline()

 5. Read accumulated decode tokens from decode buffer
    └── decode_k = offload_engine.decode_k_buffer[layer_id, start:end]
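Step 3 above is the hook where a policy drops blocks before any H2D transfer is issued. As a toy illustration (hypothetical class; the real `select_blocks` also takes `offload_engine` and `ctx` and operates on tensors), a sliding-window policy could keep only the most recent blocks:

```python
# Hypothetical sketch of a select_blocks override: a sliding-window policy
# that keeps only the last `window` entries of the CPU block table, so
# earlier blocks are never loaded through the ring buffer at all.

class SlidingWindowSelect:
    def __init__(self, window: int):
        self.window = window

    def select_blocks(self, cpu_block_table: list) -> list:
        # Plain list slicing stands in for tensor indexing here.
        return cpu_block_table[-self.window:]
```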
@@ -160,11 +157,9 @@ supports_decode = True

 ---

-## Pipeline Modes
+## Ring Buffer Pipeline

-### Ring Buffer Pipeline (`_decode_ring_buffer_pipeline`)
-
-Used when cross-layer pipeline is not active. Loads blocks one by one using ring buffer slots.
+The ring buffer pipeline (`_decode_ring_buffer_pipeline`) loads blocks one by one using GPU ring buffer slots. This approach is memory-efficient and works well for both short and long sequences.

 ```
 Slot[0]: Block A ──> Compute ──> Block C ──> Compute
@@ -172,8 +167,9 @@ Slot[1]: Block B ──> Compute ──> Block D ──> Compute
 ```

 **Advantages**:
-- Simple, proven correctness
-- Works with any number of slots
+- Memory efficient (only needs a few GPU slots)
+- Fine-grained overlap between H2D transfer and compute
+- Works well for long sequences

 **Flow**:
 ```python
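The slot diagram in the hunk above (Slot[0]: A → C, Slot[1]: B → D) amounts to indexing blocks modulo the slot count and issuing the next load before computing the current block. A hypothetical pure-Python trace of that schedule, where `load`/`compute` events stand in for the real H2D copy and attention kernel:

```python
# Sketch of the ring-buffer schedule: with N slots, block i lives in slot
# i % N, and block i+1's load is triggered before block i is computed so
# transfer overlaps compute. Event tuples are (kind, block, slot).

def run_ring_buffer(blocks, num_slots=2):
    slots = [None] * num_slots
    events = []
    # Prefetch the first block before entering the loop.
    slots[0] = blocks[0]
    events.append(("load", blocks[0], 0))
    for i, block in enumerate(blocks):
        nxt = i + 1
        if nxt < len(blocks):
            slot = nxt % num_slots
            slots[slot] = blocks[nxt]          # issue next H2D transfer early
            events.append(("load", blocks[nxt], slot))
        events.append(("compute", block, i % num_slots))
    return events
```

Running it on four blocks reproduces the diagram: blocks A and C cycle through slot 0, B and D through slot 1, and each load is issued before the preceding compute.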
@@ -201,38 +197,6 @@ for block_idx in range(num_blocks):
     o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
 ```

-### Cross-Layer Pipeline (`_decode_with_layer_pipeline`)
-
-Optimized for decode when all layers need the same blocks. Uses double-buffered layer cache.
-
-```
-Layer 0: Wait Layer 0 ──> Compute ──> Trigger Layer 1 load
-Layer 1: Wait Layer 1 ──> Compute ──> Trigger Layer 2 load
-Layer 2: Wait Layer 2 ──> Compute ──> ...
-```
-
-**Advantages**:
-- Overlaps H2D transfer with computation across layers
-- Reduces effective latency: O(transfer + layers × compute) vs O(layers × transfer)
-
-**Flow**:
-```python
-# Get KV from pre-loaded layer buffer (triggers next layer loading)
-prev_k, prev_v = offload_engine.get_decode_layer_kv(layer_id, num_blocks)
-
-# Reshape for FlashAttention
-# prev_k, prev_v: [num_blocks, block_size, kv_heads, head_dim]
-# -> [1, total_tokens, kv_heads, head_dim]
-
-# Handle partial last block
-if last_block_valid_tokens < block_size:
-    actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
-    prev_k_flat = prev_k.reshape(-1, kv_heads, head_dim)[:actual_tokens]
-
-# Compute attention on all prefilled blocks at once
-o_acc, lse_acc = flash_attn_with_lse(q, prev_k_batched, prev_v_batched, causal=False)
-```

 ---

 ## Code Conventions
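Each iteration of the surviving ring-buffer flow folds one block's partial attention into a running accumulator via `merge_attention_outputs`. The standard log-sum-exp merge behind that kind of helper can be sketched in plain Python; this is a minimal scalar-per-query sketch, not the project's tensor implementation:

```python
# Sketch of merging two partial attention results computed over disjoint key
# blocks. o_* holds one output value per query, lse_* the matching
# log-sum-exp of that block's attention scores.
import math

def merge_attention_outputs(o_a, lse_a, o_b, lse_b):
    merged_o, merged_lse = [], []
    for oa, la, ob, lb in zip(o_a, lse_a, o_b, lse_b):
        m = max(la, lb)                      # shift for numerical stability
        lse = m + math.log(math.exp(la - m) + math.exp(lb - m))
        # Reweight each partial output by its share of the total softmax mass.
        merged_o.append(oa * math.exp(la - lse) + ob * math.exp(lb - lse))
        merged_lse.append(lse)
    return merged_o, merged_lse
```

With equal scores in both blocks (lse = 0 on each side), the merged output is the plain average of the two partial outputs, which is what full-softmax attention over the union of the blocks would give.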
@@ -246,7 +210,7 @@ class PrefillOnlyPolicy(SparsePolicy):
     supports_prefill = True
     supports_decode = False

-    def compute_chunked_attention(self, ...):
+    def compute_chunked_prefill(self, ...):
         # Normal prefill implementation
         ...