♻️ refactor: remove cross-layer pipeline and rename compute_chunked_prefill

- Remove cross-layer pipeline from OffloadEngine (saves ~1GB GPU memory for long sequences)
  - Delete layer_k/v_buffer_a/b double buffers
  - Remove start_decode_pipeline, get_decode_layer_kv, end_decode_pipeline methods
  - Remove pipeline state tracking variables
- Simplify decode to use ring buffer pipeline only (more efficient for long sequences)
- Rename compute_chunked_attention → compute_chunked_prefill for clarity
- Add mandatory needle test requirements: --enable-offload --input-len 32768

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Author: Zijie Tian
Date: 2026-01-20 02:10:40 +08:00
Parent: 6080bf7554
Commit: fa7601f4b8
9 changed files with 67 additions and 299 deletions


@@ -10,7 +10,7 @@ SparsePolicy is an abstract base class that defines how attention is computed du
 attention.py SparsePolicy
 | |
 | _chunked_prefill_attention |
-| ────────────────────────────> | compute_chunked_attention()
+| ────────────────────────────> | compute_chunked_prefill()
 | |
 | _chunked_decode_attention |
 | ────────────────────────────> | compute_chunked_decode()
@@ -51,7 +51,7 @@ def select_blocks(
 pass
 @abstractmethod
-def compute_chunked_attention(
+def compute_chunked_prefill(
 self,
 q: torch.Tensor,
 k: torch.Tensor,
@@ -105,7 +105,7 @@ supports_prefill = True
 supports_decode = True
 ```
-### Prefill Flow (`compute_chunked_attention`)
+### Prefill Flow (`compute_chunked_prefill`)
 ```
 1. Get historical blocks from kvcache_manager
@@ -143,11 +143,8 @@ supports_decode = True
 3. Apply select_blocks for block filtering
 └── cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, ctx)
-4. Load prefilled blocks via pipeline
-└── IF is_pipeline_active():
-└── _decode_with_layer_pipeline() # Cross-layer pipeline
-└── ELSE:
-└── _decode_ring_buffer_pipeline() # Ring buffer fallback
+4. Load prefilled blocks via ring buffer pipeline
+└── _decode_ring_buffer_pipeline()
 5. Read accumulated decode tokens from decode buffer
 └── decode_k = offload_engine.decode_k_buffer[layer_id, start:end]
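Step 3's `select_blocks` hook is where a policy thins out the CPU block table before any blocks are loaded. A hypothetical sliding-window policy (the function name and `window` parameter are illustrative, not from this repo) could filter like this:

```python
def select_blocks_sliding_window(cpu_block_table, window=8):
    # Hypothetical filter mirroring the
    # select_blocks(cpu_block_table, offload_engine, ctx) hook:
    # keep only the most recent `window` blocks of history.
    if len(cpu_block_table) <= window:
        return cpu_block_table
    return cpu_block_table[-window:]
```

Whatever the policy returns here is exactly what the ring buffer pipeline in step 4 will load, so dropping blocks at this stage directly reduces H2D traffic.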
@@ -160,11 +157,9 @@ supports_decode = True
 ---
-## Pipeline Modes
+## Ring Buffer Pipeline
-### Ring Buffer Pipeline (`_decode_ring_buffer_pipeline`)
-Used when cross-layer pipeline is not active. Loads blocks one by one using ring buffer slots.
+The ring buffer pipeline (`_decode_ring_buffer_pipeline`) loads blocks one by one using GPU ring buffer slots. This approach is memory-efficient and works well for both short and long sequences.
 ```
 Slot[0]: Block A ──> Compute ──> Block C ──> Compute
@@ -172,8 +167,9 @@ Slot[1]: Block B ──> Compute ──> Block D ──> Compute
 ```
 **Advantages**:
-- Simple, proven correctness
-- Works with any number of slots
+- Memory efficient (only needs a few GPU slots)
+- Fine-grained overlap between H2D transfer and compute
+- Works well for long sequences
 **Flow**:
 ```python
@@ -201,38 +197,6 @@ for block_idx in range(num_blocks):
 o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
 ```
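The per-block loop above combines partial attention results with a log-sum-exp merge. A scalar sketch of what `merge_attention_outputs` plausibly computes (the real version operates on per-head tensors, and this exact formulation is an assumption):

```python
import math


def merge_attention_outputs(o1, lse1, o2, lse2):
    """Merge two partial attention results over disjoint KV ranges.

    o1, o2 are partial outputs; lse1, lse2 are the log-sum-exp of the
    attention logits over each range. Scalar version for clarity.
    """
    m = max(lse1, lse2)                  # stabilize the exponentials
    w1 = math.exp(lse1 - m)
    w2 = math.exp(lse2 - m)
    o = (w1 * o1 + w2 * o2) / (w1 + w2)  # softmax-weighted combination
    lse = m + math.log(w1 + w2)          # combined normalizer for later merges
    return o, lse
```

Because the merged `lse` is itself a valid normalizer, the merge can be applied incrementally block by block, which is what lets the ring buffer compute attention one block at a time.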
-### Cross-Layer Pipeline (`_decode_with_layer_pipeline`)
-Optimized for decode when all layers need the same blocks. Uses double-buffered layer cache.
-```
-Layer 0: Wait Layer 0 ──> Compute ──> Trigger Layer 1 load
-Layer 1: Wait Layer 1 ──> Compute ──> Trigger Layer 2 load
-Layer 2: Wait Layer 2 ──> Compute ──> ...
-```
-**Advantages**:
-- Overlaps H2D transfer with computation across layers
-- Reduces effective latency: O(transfer + layers × compute) vs O(layers × transfer)
-**Flow**:
-```python
-# Get KV from pre-loaded layer buffer (triggers next layer loading)
-prev_k, prev_v = offload_engine.get_decode_layer_kv(layer_id, num_blocks)
-# Reshape for FlashAttention
-# prev_k, prev_v: [num_blocks, block_size, kv_heads, head_dim]
-# -> [1, total_tokens, kv_heads, head_dim]
-# Handle partial last block
-if last_block_valid_tokens < block_size:
-actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
-prev_k_flat = prev_k.reshape(-1, kv_heads, head_dim)[:actual_tokens]
-# Compute attention on all prefilled blocks at once
-o_acc, lse_acc = flash_attn_with_lse(q, prev_k_batched, prev_v_batched, causal=False)
-```
 ---
 ## Code Conventions
@@ -246,7 +210,7 @@ class PrefillOnlyPolicy(SparsePolicy):
 supports_prefill = True
 supports_decode = False
-def compute_chunked_attention(self, ...):
+def compute_chunked_prefill(self, ...):
 # Normal prefill implementation
 ...