♻️ refactor: remove cross-layer pipeline and rename compute_chunked_attention → compute_chunked_prefill

- Remove cross-layer pipeline from OffloadEngine (saves ~1GB GPU memory for long sequences)
- Delete layer_k/v_buffer_a/b double buffers
- Remove start_decode_pipeline, get_decode_layer_kv, end_decode_pipeline methods
- Remove pipeline state tracking variables
- Simplify decode to use the ring buffer pipeline only (more efficient for long sequences)
- Rename compute_chunked_attention → compute_chunked_prefill for clarity
- Add mandatory needle test requirements: --enable-offload --input-len 32768

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@@ -1,9 +0,0 @@
----
-active: true
-iteration: 1
-max_iterations: 0
-completion_promise: "COMPLETE"
-started_at: "2026-01-19T17:25:00Z"
----
-
-Refactor the nanovllm code according to the requirements in task_plan.md, making sure the plan's final goals are fully achieved. Note that you may only use GPU 0 for debugging; the other GPUs must not be used. Finally, write a report of the test results. <promise>COMPLETE</promise> -max-iterations 30
@@ -77,6 +77,45 @@ Claude: Runs `python tests/test_needle.py ...`  # NO! Missing GPU specification!
 
 ---
 
+## Needle Test Requirements (MANDATORY)
+
+When running `test_needle.py`, **ALWAYS** use these settings:
+
+1. **Enable offload**: `--enable-offload` is **REQUIRED**
+2. **Use 32K context**: `--input-len 32768` is **REQUIRED**
+
+### Standard Needle Test Command
+
+```bash
+CUDA_VISIBLE_DEVICES=X PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
+python tests/test_needle.py \
+    --model ~/models/Llama-3.1-8B-Instruct \
+    --enable-offload \
+    --input-len 32768
+```
+
+### Why These Settings?
+
+| Setting | Reason |
+|---------|--------|
+| `--enable-offload` | Tests the CPU offload pipeline, which is the main feature being developed |
+| `--input-len 32768` | 32K context properly exercises the chunked prefill/decode paths; 8K is too short to catch many issues |
+
+### Do NOT Use
+
+```bash
+# ❌ Wrong: missing offload
+python tests/test_needle.py --model ~/models/Llama-3.1-8B-Instruct
+
+# ❌ Wrong: too short (default 8K)
+python tests/test_needle.py --model ~/models/Llama-3.1-8B-Instruct --enable-offload
+
+# ✅ Correct: offload + 32K
+python tests/test_needle.py --model ~/models/Llama-3.1-8B-Instruct --enable-offload --input-len 32768
+```
+
+---
+
 ## Combined Checklist
 
 Before running any GPU test:
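The 32K requirement can be sanity-checked with quick arithmetic. A sketch with hypothetical Llama-3.1-8B-style values — the layer/head counts, block size, and dtype width here are illustrative assumptions, not values read from the repo:

```python
# Hypothetical Llama-3.1-8B-style configuration; the real values live in the
# model config, so treat these numbers as assumptions.
num_layers, num_kv_heads, head_dim = 32, 8, 128
block_size = 256       # assumed KV block size
input_len = 32768      # the mandatory --input-len
dtype_bytes = 2        # fp16/bf16

# Blocks the chunked prefill/decode paths must stream through the pipeline.
num_blocks = input_len // block_size

# Total KV bytes (K and V, all layers) that must sit in CPU cache at 32K.
kv_bytes = 2 * num_layers * input_len * num_kv_heads * head_dim * dtype_bytes

print(num_blocks, kv_bytes / 2**30)  # 128 blocks, 4.0 GiB
```

Under these assumptions a 32K prompt forces 128 blocks through the offload path, whereas an 8K default would exercise only a quarter of that.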
@@ -21,7 +21,7 @@ class PrefillOnlyPolicy(SparsePolicy):
     supports_prefill = True
     supports_decode = False
 
-    def compute_chunked_attention(self, ...):
+    def compute_chunked_prefill(self, ...):
         # Normal prefill implementation
         ...
 
@@ -35,7 +35,7 @@ class DecodeOnlyPolicy(SparsePolicy):
     supports_prefill = False
     supports_decode = True
 
-    def compute_chunked_attention(self, ...):
+    def compute_chunked_prefill(self, ...):
         # Prefill is not supported; must assert False
         assert False, "DecodeOnlyPolicy does not support prefill phase"
@@ -53,7 +53,7 @@ class FullAttentionPolicy(SparsePolicy):
     supports_prefill = True
     supports_decode = True
 
-    def compute_chunked_attention(self, ...):
+    def compute_chunked_prefill(self, ...):
         # Full implementation
 
     def compute_chunked_decode(self, ...):
@@ -85,14 +85,11 @@ if not sparse_policy.supports_decode:
 In SparsePolicy's `compute_chunked_*` methods, all CPU-GPU data transfers **must** go through `OffloadEngine`; directly using `torch.Tensor.copy_()` or `.to(device)` is **forbidden**:
 
 ```python
-# ✅ Correct: use OffloadEngine's methods
+# ✅ Correct: use OffloadEngine's ring buffer methods
 offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
 offload_engine.wait_slot_layer(slot)
 k, v = offload_engine.get_kv_for_slot(slot)
 
-# ✅ Correct: use the cross-layer pipeline
-k, v = offload_engine.get_decode_layer_kv(layer_id, num_blocks)
-
 # ❌ Wrong: direct torch transfers
 gpu_tensor.copy_(cpu_tensor)
 gpu_tensor = cpu_tensor.to("cuda")
@@ -102,6 +99,6 @@ gpu_tensor = cpu_tensor.cuda()
 ### Rationale
 
 1. **Stream synchronization**: OffloadEngine manages CUDA streams internally, ensuring correct synchronization
-2. **Pipeline optimization**: OffloadEngine implements the ring buffer and cross-layer pipelines
+2. **Pipeline optimization**: OffloadEngine implements the ring buffer pipeline
 3. **Resource management**: OffloadEngine manages GPU buffer slots, avoiding memory fragmentation
 4. **Consistency**: a unified interface simplifies debugging and maintenance
@@ -10,7 +10,7 @@ SparsePolicy is an abstract base class that defines how attention is computed du
 attention.py                        SparsePolicy
     |                                  |
     | _chunked_prefill_attention       |
-    | ────────────────────────────>    | compute_chunked_attention()
+    | ────────────────────────────>    | compute_chunked_prefill()
     |                                  |
     | _chunked_decode_attention        |
     | ────────────────────────────>    | compute_chunked_decode()
@@ -51,7 +51,7 @@ def select_blocks(
         pass
 
     @abstractmethod
-    def compute_chunked_attention(
+    def compute_chunked_prefill(
         self,
         q: torch.Tensor,
         k: torch.Tensor,
@@ -105,7 +105,7 @@ supports_prefill = True
 supports_decode = True
 ```
 
-### Prefill Flow (`compute_chunked_attention`)
+### Prefill Flow (`compute_chunked_prefill`)
 
 ```
 1. Get historical blocks from kvcache_manager
@@ -143,11 +143,8 @@ supports_decode = True
 3. Apply select_blocks for block filtering
    └── cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, ctx)
 
-4. Load prefilled blocks via pipeline
-   └── IF is_pipeline_active():
-       └── _decode_with_layer_pipeline()   # Cross-layer pipeline
-   └── ELSE:
-       └── _decode_ring_buffer_pipeline()  # Ring buffer fallback
+4. Load prefilled blocks via ring buffer pipeline
+   └── _decode_ring_buffer_pipeline()
 
 5. Read accumulated decode tokens from decode buffer
    └── decode_k = offload_engine.decode_k_buffer[layer_id, start:end]
@@ -160,11 +157,9 @@ supports_decode = True
 
 ---
 
-## Pipeline Modes
+## Ring Buffer Pipeline
 
-### Ring Buffer Pipeline (`_decode_ring_buffer_pipeline`)
-
-Used when cross-layer pipeline is not active. Loads blocks one by one using ring buffer slots.
+The ring buffer pipeline (`_decode_ring_buffer_pipeline`) loads blocks one by one using GPU ring buffer slots. This approach is memory-efficient and works well for both short and long sequences.
 
 ```
 Slot[0]: Block A ──> Compute ──> Block C ──> Compute
@@ -172,8 +167,9 @@ Slot[1]: Block B ──> Compute ──> Block D ──> Compute
 ```
 
 **Advantages**:
-- Simple, proven correctness
-- Works with any number of slots
+- Memory efficient (only needs a few GPU slots)
+- Fine-grained overlap between H2D transfer and compute
+- Works well for long sequences
 
 **Flow**:
 ```python
@@ -201,38 +197,6 @@ for block_idx in range(num_blocks):
     o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
 ```
 
-### Cross-Layer Pipeline (`_decode_with_layer_pipeline`)
-
-Optimized for decode when all layers need the same blocks. Uses double-buffered layer cache.
-
-```
-Layer 0: Wait Layer 0 ──> Compute ──> Trigger Layer 1 load
-Layer 1: Wait Layer 1 ──> Compute ──> Trigger Layer 2 load
-Layer 2: Wait Layer 2 ──> Compute ──> ...
-```
-
-**Advantages**:
-- Overlaps H2D transfer with computation across layers
-- Reduces effective latency: O(transfer + layers × compute) vs O(layers × transfer)
-
-**Flow**:
-```python
-# Get KV from pre-loaded layer buffer (triggers next layer loading)
-prev_k, prev_v = offload_engine.get_decode_layer_kv(layer_id, num_blocks)
-
-# Reshape for FlashAttention
-# prev_k, prev_v: [num_blocks, block_size, kv_heads, head_dim]
-# -> [1, total_tokens, kv_heads, head_dim]
-
-# Handle partial last block
-if last_block_valid_tokens < block_size:
-    actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
-    prev_k_flat = prev_k.reshape(-1, kv_heads, head_dim)[:actual_tokens]
-
-# Compute attention on all prefilled blocks at once
-o_acc, lse_acc = flash_attn_with_lse(q, prev_k_batched, prev_v_batched, causal=False)
-```
-
 ---
 
 ## Code Conventions
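The `merge_attention_outputs` call in the flow above combines per-chunk softmax results via their log-sum-exp (LSE) values. A minimal scalar sketch of that identity — a standalone toy, not the nanovllm implementation, with hypothetical helper names mirroring the ones in the diff:

```python
import math

def attn_with_lse(q, ks, vs):
    """Toy 1-D attention over one chunk: returns (output, log-sum-exp of scores)."""
    scores = [q * k for k in ks]                 # scalar "dot products"
    m = max(scores)                              # shift for numerical stability
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    o = sum(e * v for e, v in zip(exps, vs)) / denom
    lse = m + math.log(denom)                    # log of the softmax denominator
    return o, lse

def merge_attention_outputs(o1, lse1, o2, lse2):
    """Combine two chunks' partial softmax results into the exact full result."""
    m = max(lse1, lse2)
    w1 = math.exp(lse1 - m)
    w2 = math.exp(lse2 - m)
    o = (w1 * o1 + w2 * o2) / (w1 + w2)          # renormalize by combined mass
    lse = m + math.log(w1 + w2)
    return o, lse
```

Because the merge is exact, blocks can be streamed through a few GPU slots in any chunking and still reproduce full-softmax attention.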
@@ -246,7 +210,7 @@ class PrefillOnlyPolicy(SparsePolicy):
     supports_prefill = True
     supports_decode = False
 
-    def compute_chunked_attention(self, ...):
+    def compute_chunked_prefill(self, ...):
         # Normal prefill implementation
         ...
 
@@ -644,12 +644,6 @@ class ModelRunner:
         # Get decode start position for accumulated token tracking
         decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)
 
-        # Get prefilled CPU blocks for pipeline initialization
-        cpu_block_table = self.kvcache_manager.get_prefilled_cpu_blocks(seq)
-
-        # Start cross-layer pipeline (preloads Layer 0's data)
-        offload_engine.start_decode_pipeline(cpu_block_table)
-
         # Set up context for chunked decode
         set_context(
             is_prefill=False,
@@ -666,9 +660,6 @@ class ModelRunner:
         logits = self.run_model(input_ids, positions, is_prefill=False)
         reset_context()
 
-        # End cross-layer pipeline
-        offload_engine.end_decode_pipeline()
-
         # Only offload when block is full (pos_in_block == block_size - 1)
         # This avoids unnecessary offloading on every decode step
         if pos_in_block == self.block_size - 1:
@@ -141,40 +141,6 @@ class OffloadEngine:
         decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
         logger.info(f"  Per-layer decode buffer: {decode_buf_mb:.1f} MB")
 
-        # ========== Cross-layer pipeline buffers for decode ==========
-        # Double-buffered layer cache for pipelined decode:
-        #   - Buffer A: Current layer's prefilled KV being computed
-        #   - Buffer B: Next layer's prefilled KV being loaded
-        # Shape: [max_prefill_blocks, block_size, kv_heads, head_dim]
-        # Memory: 2 * max_prefill_blocks * block_size * kv_heads * head_dim * dtype_size
-        max_prefill_blocks = num_cpu_blocks  # Can hold all prefill blocks
-        self.layer_k_buffer_a = torch.zeros(
-            max_prefill_blocks, block_size, num_kv_heads, head_dim,
-            dtype=dtype, device="cuda"
-        )
-        self.layer_v_buffer_a = torch.zeros(
-            max_prefill_blocks, block_size, num_kv_heads, head_dim,
-            dtype=dtype, device="cuda"
-        )
-        self.layer_k_buffer_b = torch.zeros(
-            max_prefill_blocks, block_size, num_kv_heads, head_dim,
-            dtype=dtype, device="cuda"
-        )
-        self.layer_v_buffer_b = torch.zeros(
-            max_prefill_blocks, block_size, num_kv_heads, head_dim,
-            dtype=dtype, device="cuda"
-        )
-        layer_buf_mb = 4 * max_prefill_blocks * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
-        logger.info(f"  Cross-layer pipeline buffers: {layer_buf_mb:.1f} MB ({max_prefill_blocks} blocks × 2)")
-
-        # Pipeline state tracking
-        self._pipeline_active = False
-        self._pipeline_current_buffer = 0  # 0 = buffer A, 1 = buffer B
-        self._pipeline_next_layer_event = torch.cuda.Event()
-        self._pipeline_cpu_blocks: list = []  # CPU block IDs to load
-        self._pipeline_num_blocks = 0
-        self._pipeline_layer_stream = torch.cuda.Stream()  # Dedicated stream for layer loading
-
         # ========== Per-layer prefill buffer for async offload ==========
         # During chunked prefill, all layers share the same GPU slot. This means
         # each layer must wait for offload to complete before the next layer can
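The buffer sizes logged above follow directly from the tensor shapes in this hunk. A sketch of both formulas; the example numbers are hypothetical (fp16, 512 CPU blocks of 256 tokens, 8 KV heads, head dim 128), chosen to show how the removed double buffers reach the ~1GB cited in the commit message:

```python
def decode_buf_mb(num_layers, block_size, num_kv_heads, head_dim, itemsize):
    """Per-layer decode buffer: K and V (factor 2) for one block per layer."""
    return 2 * num_layers * block_size * num_kv_heads * head_dim * itemsize / (1024 * 1024)

def layer_buf_mb(max_prefill_blocks, block_size, num_kv_heads, head_dim, itemsize):
    """Removed cross-layer buffers: double-buffered K and V (factor 4),
    sized to hold every prefilled block at once."""
    return 4 * max_prefill_blocks * block_size * num_kv_heads * head_dim * itemsize / (1024 * 1024)

print(decode_buf_mb(32, 256, 8, 128, 2))  # 32.0 MB, kept
print(layer_buf_mb(512, 256, 8, 128, 2))  # 1024.0 MB, freed by this commit
```

The cross-layer buffers scale with the CPU block count (i.e. with sequence length), which is why long sequences paid the ~1GB cost the ring buffer avoids.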
@@ -666,122 +632,6 @@ class OffloadEngine:
                 raise
             logger.warning(f"Debug hook error: {e}")
 
-    # ========== Cross-layer Pipeline Methods for Decode ==========
-
-    def start_decode_pipeline(self, cpu_block_ids: List[int]) -> None:
-        """
-        Start cross-layer pipeline for decode.
-
-        Called at the beginning of a decode step to initialize the pipeline.
-        Preloads Layer 0's data into buffer A.
-
-        Args:
-            cpu_block_ids: List of CPU block IDs for prefilled blocks
-        """
-        if not cpu_block_ids:
-            self._pipeline_active = False
-            return
-
-        self._pipeline_active = True
-        self._pipeline_cpu_blocks = cpu_block_ids
-        self._pipeline_num_blocks = len(cpu_block_ids)
-        self._pipeline_current_buffer = 0
-
-        # Preload Layer 0 into buffer A
-        self._load_layer_to_buffer(0, 0)  # layer_id=0, buffer_idx=0 (A)
-
-    def get_decode_layer_kv(self, layer_id: int, num_blocks: int) -> Tuple[Tensor, Tensor]:
-        """
-        Get KV cache for a layer during decode.
-
-        If pipeline is active, returns data from the current buffer.
-        Also triggers preloading of the next layer (if not last layer).
-
-        Args:
-            layer_id: Current layer ID
-            num_blocks: Number of blocks to return
-
-        Returns:
-            (k_cache, v_cache) tensors, shape: [num_blocks, block_size, kv_heads, head_dim]
-        """
-        if not self._pipeline_active:
-            raise RuntimeError("Decode pipeline not active. Call start_decode_pipeline first.")
-
-        # Wait for current layer's data to be ready
-        self.compute_stream.wait_event(self._pipeline_next_layer_event)
-
-        # Get current buffer
-        if self._pipeline_current_buffer == 0:
-            k = self.layer_k_buffer_a[:num_blocks]
-            v = self.layer_v_buffer_a[:num_blocks]
-        else:
-            k = self.layer_k_buffer_b[:num_blocks]
-            v = self.layer_v_buffer_b[:num_blocks]
-
-        # Trigger preloading of next layer (if not last layer)
-        next_layer_id = layer_id + 1
-        if next_layer_id < self.num_layers:
-            # Use the other buffer for next layer
-            next_buffer_idx = 1 - self._pipeline_current_buffer
-            self._load_layer_to_buffer(next_layer_id, next_buffer_idx)
-            # Switch to next buffer for next layer
-            self._pipeline_current_buffer = next_buffer_idx
-
-        return k, v
-
-    def _load_layer_to_buffer(self, layer_id: int, buffer_idx: int) -> None:
-        """
-        Async load a layer's prefilled blocks to the specified buffer.
-
-        Uses sgDMA for efficient strided transfer from CPU cache.
-
-        Args:
-            layer_id: Layer index to load
-            buffer_idx: 0 for buffer A, 1 for buffer B
-        """
-        num_blocks = self._pipeline_num_blocks
-        cpu_block_ids = self._pipeline_cpu_blocks
-
-        # Select target buffer
-        if buffer_idx == 0:
-            k_buffer = self.layer_k_buffer_a
-            v_buffer = self.layer_v_buffer_a
-        else:
-            k_buffer = self.layer_k_buffer_b
-            v_buffer = self.layer_v_buffer_b
-
-        # Load all blocks for this layer using dedicated stream
-        with torch.cuda.stream(self._pipeline_layer_stream):
-            for i, cpu_block_id in enumerate(cpu_block_ids):
-                # Copy from CPU cache (has layer dimension) to GPU buffer
-                k_buffer[i].copy_(
-                    self.k_cache_cpu[layer_id, cpu_block_id],
-                    non_blocking=True
-                )
-                v_buffer[i].copy_(
-                    self.v_cache_cpu[layer_id, cpu_block_id],
-                    non_blocking=True
-                )
-            # Record event when all transfers complete
-            self._pipeline_next_layer_event.record(self._pipeline_layer_stream)
-
-    def end_decode_pipeline(self) -> None:
-        """
-        End the cross-layer pipeline.
-
-        Called at the end of a decode step to clean up pipeline state.
-        """
-        if self._pipeline_active:
-            # Ensure all transfers complete before ending
-            self._pipeline_layer_stream.synchronize()
-        self._pipeline_active = False
-        self._pipeline_cpu_blocks = []
-        self._pipeline_num_blocks = 0
-
-    def is_pipeline_active(self) -> bool:
-        """Check if decode pipeline is currently active."""
-        return self._pipeline_active
-
     # ========== Per-layer Prefill Buffer Methods ==========
     # These methods enable async offload during chunked prefill by using
     # per-layer buffers instead of shared GPU slots.
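For reference, the buffer alternation the removed `get_decode_layer_kv` implemented can be modeled without any CUDA code. A toy schedule (hypothetical helper, not part of nanovllm) showing which double buffer each layer reads while the next layer's load is issued into the other:

```python
def layer_buffer_schedule(num_layers):
    """Per layer: (layer_id, buffer read from, buffer prefetched into).

    Buffer 0 is "A", buffer 1 is "B". Layer 0's data is assumed preloaded
    into buffer A before the decode step runs; the last layer prefetches
    nothing (None).
    """
    schedule = []
    current = 0  # start reading from buffer A
    for layer_id in range(num_layers):
        nxt = 1 - current if layer_id + 1 < num_layers else None
        schedule.append((layer_id, current, nxt))
        if nxt is not None:
            current = nxt  # next layer reads the buffer just prefetched
    return schedule
```

The strict A/B alternation is why only two full-size buffers were needed, and also why both had to be sized for every prefilled block — the memory cost this commit removes.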
@@ -46,7 +46,7 @@ class FullAttentionPolicy(SparsePolicy):
         """Return all blocks - no sparsity."""
         return available_blocks
 
-    def compute_chunked_attention(
+    def compute_chunked_prefill(
         self,
         q: torch.Tensor,
         k: torch.Tensor,
@@ -86,7 +86,7 @@ class FullAttentionPolicy(SparsePolicy):
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
 
-        logger.debug(f"[DEBUG] FullPolicy.compute_chunked_attention called, "
+        logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
                      f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
 
         q_batched = q.unsqueeze(0)  # [1, seq_len, num_heads, head_dim]
@@ -256,14 +256,7 @@ class FullAttentionPolicy(SparsePolicy):
         )
         cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
 
-        # Use cross-layer pipeline if active (initialized in model_runner)
-        if offload_engine.is_pipeline_active():
-            o_acc, lse_acc = self._decode_with_layer_pipeline(
-                q_batched, cpu_block_table, offload_engine,
-                block_size, last_block_valid_tokens, layer_id, softmax_scale
-            )
-        else:
-            # Fallback to original ring buffer pipeline
-            load_slots = offload_engine.decode_load_slots
-            o_acc, lse_acc = self._decode_ring_buffer_pipeline(
-                q_batched, cpu_block_table, load_slots, offload_engine,
+        # Use ring buffer pipeline for loading prefilled blocks
+        load_slots = offload_engine.decode_load_slots
+        o_acc, lse_acc = self._decode_ring_buffer_pipeline(
+            q_batched, cpu_block_table, load_slots, offload_engine,
@@ -386,62 +379,5 @@ class FullAttentionPolicy(SparsePolicy):
 
         return o_acc, lse_acc
 
-    def _decode_with_layer_pipeline(
-        self,
-        q_batched: torch.Tensor,
-        cpu_block_table: list,
-        offload_engine: "OffloadEngine",
-        block_size: int,
-        last_block_valid_tokens: int,
-        layer_id: int,
-        softmax_scale: float,
-    ):
-        """
-        Decode using cross-layer pipeline for optimized H2D transfer.
-
-        Uses pre-loaded layer buffers instead of loading blocks one by one.
-        The pipeline loads the next layer's data while the current layer
-        computes, achieving transfer/compute overlap.
-        """
-        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
-
-        num_blocks = len(cpu_block_table)
-        if num_blocks == 0:
-            return None, None
-
-        compute_stream = offload_engine.compute_stream
-
-        # Get KV from pre-loaded layer buffer (triggers next layer loading)
-        prev_k, prev_v = offload_engine.get_decode_layer_kv(layer_id, num_blocks)
-
-        # prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
-        # Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
-        total_tokens = num_blocks * block_size
-
-        # Handle partial last block
-        if last_block_valid_tokens < block_size:
-            # Only use valid tokens from last block
-            actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
-            # Flatten and truncate
-            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
-            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
-        else:
-            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
-            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
-
-        # Add batch dimension: [1, total_tokens, kv_heads, head_dim]
-        prev_k_batched = prev_k_flat.unsqueeze(0)
-        prev_v_batched = prev_v_flat.unsqueeze(0)
-
-        # Compute attention on all prefilled blocks at once
-        with torch.cuda.stream(compute_stream):
-            o_acc, lse_acc = flash_attn_with_lse(
-                q_batched, prev_k_batched, prev_v_batched,
-                softmax_scale=softmax_scale,
-                causal=False,
-            )
-
-        return o_acc, lse_acc
-
     def __repr__(self) -> str:
         return "FullAttentionPolicy()"
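With `_decode_with_layer_pipeline` gone, decode always goes through `_decode_ring_buffer_pipeline`. Its prefetch-ahead pattern (issue the load of block i+num_slots into a freed slot as soon as block i's compute is issued) can be sketched as a pure-Python issue schedule — a hypothetical helper for illustration, not nanovllm code; the real engine synchronizes each slot with CUDA events before reuse:

```python
def ring_buffer_schedule(num_blocks, num_slots=2):
    """Return the issue order of ("load"/"compute", slot, block) events
    for a ring buffer with num_slots GPU slots."""
    events = []
    # Prime the pipeline: fill every slot before the first compute.
    for b in range(min(num_slots, num_blocks)):
        events.append(("load", b % num_slots, b))
    for b in range(num_blocks):
        events.append(("compute", b % num_slots, b))
        nxt = b + num_slots
        if nxt < num_blocks:
            # Reuse the slot block b occupied for the block num_slots ahead,
            # so its H2D transfer overlaps the remaining computes.
            events.append(("load", nxt % num_slots, nxt))
    return events
```

With two slots and four blocks this reproduces the Slot[0]: A→C, Slot[1]: B→D pattern from the ring buffer diagram in the docs, using only a few slots regardless of sequence length.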
@@ -192,7 +192,7 @@ class SparsePolicy(ABC):
         pass
 
     @abstractmethod
-    def compute_chunked_attention(
+    def compute_chunked_prefill(
         self,
         q: torch.Tensor,
         k: torch.Tensor,
@@ -174,7 +174,7 @@ class Attention(nn.Module):
         Compute attention with per-layer prefill buffer for async offload.
 
         Simplified design:
-        - All computation logic is delegated to sparse_policy.compute_chunked_attention()
+        - All computation logic is delegated to sparse_policy.compute_chunked_prefill()
         - This method only handles async offload after computation
 
         The policy handles:
@@ -198,11 +198,11 @@ class Attention(nn.Module):
             raise RuntimeError("sparse_policy is required for chunked prefill")
 
         # [DEBUG] Verify execution path
-        logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_attention, "
+        logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
                      f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")
 
         # Delegate all computation to policy (no flash_attn or merge calls here!)
-        final_o = sparse_policy.compute_chunked_attention(
+        final_o = sparse_policy.compute_chunked_prefill(
             q, k, v,
             self.layer_id,
             self.scale,
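The delegation pattern this hunk completes — `Attention` calls an abstract `compute_chunked_prefill` that each policy either implements or rejects — can be sketched in isolation. Toy classes mirroring the names in the diff; the bodies are stand-ins, not the real implementations:

```python
from abc import ABC, abstractmethod

class SparsePolicy(ABC):
    supports_prefill = True
    supports_decode = True

    @abstractmethod
    def compute_chunked_prefill(self, q, k, v, layer_id, scale):
        """Subclasses must implement (or explicitly reject) the prefill path."""

class FullAttentionPolicy(SparsePolicy):
    def compute_chunked_prefill(self, q, k, v, layer_id, scale):
        # Toy stand-in: the real code calls flash_attn_with_lse + merge.
        return [scale * x for x in q]

class DecodeOnlyPolicy(SparsePolicy):
    supports_prefill = False

    def compute_chunked_prefill(self, q, k, v, layer_id, scale):
        # Matches the convention from the docs above: unsupported phases fail loudly.
        assert False, "DecodeOnlyPolicy does not support prefill phase"
```

The ABC guarantees every policy makes an explicit choice for each phase, while `Attention` stays free of per-policy branching.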