📝 docs: add BSA interface documentation and cleanup temp files

- Add docs/block_sparse_attn_interface.md with BSA function signatures
- Update CLAUDE.md documentation index
- Remove obsolete DEBUG_SUMMARY.md and test_report_sparse_policy_refactor.md
- Add notes.md to .gitignore

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
.gitignore (vendored, 1 addition)

```diff
@@ -238,3 +238,4 @@ progress.md
 task_plan_*.md
 findings_*.md
 progress_*.md
+notes.md
```
CLAUDE.md (1 addition)

```diff
@@ -15,6 +15,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 | [`docs/sparse_policy_implementation_guide.md`](docs/sparse_policy_implementation_guide.md) | How to implement custom SparsePolicy: required methods, hooks, ring buffer pipeline pattern |
 | [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
 | [`docs/xattention_algorithm_guide.md`](docs/xattention_algorithm_guide.md) | XAttention algorithm deep dive: stride reshape, Triton kernels, BSA dependency, block selection algorithm |
+| [`docs/block_sparse_attn_interface.md`](docs/block_sparse_attn_interface.md) | BSA (Block Sparse Attention) interface doc: function signatures, usage examples, constraints |
 | [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
 | [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
 | [`docs/known_issues.md`](docs/known_issues.md) | Documented bugs and fixes: partial last block bug, block size 4096 race condition |
```
DEBUG_SUMMARY.md (deleted, 103 lines)

@@ -1,103 +0,0 @@
# Chunked Prefill Bug Debug Summary

## Problem

`test_needle.py --enable-offload --input-len 8192` fails with garbage output.

The model generates completely wrong tokens instead of the expected "7492".

## Investigation Progress

### 1. Stream Synchronization Fix (Completed)

- Replaced the Triton `store_kvcache` kernel with pure PyTorch operations
- Moved `store_kvcache` to `compute_stream` in chunked prefill mode
- Added sync: `compute_stream.wait_event(offload_done)` after per-layer offload
- Added sync: `default_stream.wait_stream(compute_stream)` before return

### 2. KV Cache Alignment Verification (Completed)

Created alignment tests to compare K/V tensors between the torch reference and nanovllm:

**RoPE Alignment:**

- RoPE implementations match perfectly (max_diff=0.002, cosine ~1.0)
- Confirmed RoPE is NOT the cause of the bug

**K/V Cache Alignment (Chunk 0):**

- Cosine similarity: ~1.0 for all layers
- Max diff: 2-7 (grows linearly with position, characteristic of FP16 precision)
- Mean diff: < 0.001
- **Conclusion: K/V cache offload is working correctly**
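The alignment checks above reduce to two metrics per tensor pair: flattened cosine similarity and element-wise max absolute difference. A minimal pure-Python sketch of how such a comparison can be computed (hypothetical helper, not a function from the repo):

```python
import math

def compare_tensors(a, b):
    # Flattened cosine similarity and element-wise max abs diff,
    # the two metrics quoted in the alignment results above.
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    cosine = dot / (na * nb)
    max_diff = max(abs(x - y) for x, y in zip(a, b))
    return cosine, max_diff
```

Cosine near 1.0 with a bounded max diff is the FP16-precision signature described above; a cosine well below 1.0 signals real divergence.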
### 3. Layer Output Divergence Analysis (Completed)

Created a per-chunk layer output comparison:

**Chunk 0 (tokens 0-4096):**

- All layers pass with excellent cosine similarity (0.999+)
- Max diff grows in later layers but stays within an acceptable range

**Chunk 1 (tokens 4096-8192):**

- Layers 0-19: OK (cosine ~1.0)
- Layers 20-27: diverge (cosine 0.83-0.96, max_diff up to 114)
- Divergence correlates with later transformer layers

### 4. Critical Discovery: Single-Chunk Offload Also Fails

**Key finding:** Even with input_len=2048 (a single chunk, no chunked attention), the model produces garbage output with CPU offload enabled.

```
# Without offload: PASSES
python tests/test_needle.py --input-len 2048
# Output: "7492" (correct)

# With offload: FAILS
python tests/test_needle.py --enable-offload --input-len 2048
# Output: "The Ble White Th G Lopsiswin..." (garbage)
```

**This proves the bug is NOT in:**

- Chunked attention logic (merge_attention_outputs)
- Multi-chunk KV loading
- Ring buffer pipeline

**The bug IS in:**

- The decode path when CPU offload is enabled
- How prefilled KV is loaded/used during decode

### 5. Decode Path Analysis (In Progress)

The decode path in CPU offload mode:

1. Prefill writes KV to GPU, offloads to CPU
2. Decode loads prefilled KV from CPU via `_decode_ring_buffer_pipeline`
3. Attend to prefilled KV + accumulated decode tokens
4. Merge results
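Step 4 combines partial attention outputs computed over different key ranges using their log-sum-exp (LSE) statistics. A minimal scalar sketch of that merge rule (illustrative only; not the repo's `merge_attention_outputs` implementation):

```python
import math

def merge_partial_attention(o1, lse1, o2, lse2):
    # Combine two partial softmax-attention outputs computed over
    # disjoint key sets; each carries the log-sum-exp of its scores.
    lse = math.log(math.exp(lse1) + math.exp(lse2))
    w1, w2 = math.exp(lse1 - lse), math.exp(lse2 - lse)
    return w1 * o1 + w2 * o2, lse
```

Because the weights come from the LSE terms, merging in any order yields the same result as one softmax over the union of keys, which is why the chunked path can process KV blocks incrementally.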
**Observations:**

- The `prefilled_blocks` set is empty after decode (it should contain block IDs)
- The CPU cache has valid data (reasonable mean/std values)
- The decode buffer contains zeros (decode tokens not being stored correctly?)

## Current Status

### Working

- Stream synchronization fixes
- K/V cache offload to CPU (verified alignment)
- RoPE implementation
- Chunked prefill attention for the first chunk

### Not Working

- Decode with CPU offload (even for single-chunk inputs)
- Multi-chunk attention (divergence in later layers for chunk 1)

## Next Steps

1. Debug why `prefilled_blocks` is empty after decode
2. Check whether the decode path correctly loads KV from CPU
3. Verify the decode buffer is being written correctly
4. Compare decode attention outputs between offload and non-offload modes

## Key Files

- `nanovllm/layers/attention.py` - Main attention implementation with chunked paths
- `nanovllm/kvcache/offload_engine.py` - CPU-GPU transfer engine
- `nanovllm/kvcache/hybrid_manager.py` - KV cache management with `prefilled_blocks`
- `nanovllm/engine/model_runner.py` - Prefill/decode orchestration

## Hypothesis

The decode path fails because:

1. `prefilled_blocks` is not being tracked correctly, causing `get_prefilled_cpu_blocks()` to return empty
2. OR the decode attention is not correctly loading/using the prefilled KV from CPU
3. OR there is a stream synchronization issue specific to the decode path
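Hypothesis 1 is purely a bookkeeping question. A hypothetical minimal tracker illustrating the expected behavior of the `prefilled_blocks` set (names borrowed from `hybrid_manager.py`; the logic is a sketch, not the repo's implementation):

```python
class PrefilledBlockTracker:
    # Illustrative stand-in for the prefilled_blocks bookkeeping:
    # every block offloaded during prefill must be recorded, otherwise
    # get_prefilled_cpu_blocks() returns empty -- the observed symptom.
    def __init__(self):
        self.prefilled_blocks = set()

    def mark_prefilled(self, block_id):
        self.prefilled_blocks.add(block_id)

    def get_prefilled_cpu_blocks(self):
        return sorted(self.prefilled_blocks)
```

If `mark_prefilled` is never reached on the prefill path (or the set is cleared before decode), decode would attend over no historical KV, which matches the garbage output.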
docs/block_sparse_attn_interface.md (new file, 238 lines)

@@ -0,0 +1,238 @@
# Block Sparse Attention Interface

Source: [MIT-HAN-LAB/Block-Sparse-Attention](https://github.com/mit-han-lab/Block-Sparse-Attention)

This document records the BSA (Block Sparse Attention) interface used by XAttention for sparse attention computation.

## Installation

BSA is installed in the `minference` conda environment:

```
/home/zijie/anaconda3/envs/minference/lib/python3.10/site-packages/block_sparse_attn/
```

To use it in other environments, add it to PYTHONPATH:

```bash
PYTHONPATH=/home/zijie/anaconda3/envs/minference/lib/python3.10/site-packages:$PYTHONPATH python script.py
```

## Interface Code

```python
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_blocksparse_attn_interface.py

import block_sparse_attn_cuda
import torch
import torch.nn as nn


def convert_blockmask(blockmask, causal):
    """Convert from the 0-1 format to the format used by the CUDA code.

    0 means the block is skipped; nonzero means the block is not skipped.

    Argument:
        blockmask: (row, col): a 0-1 tensor
    Return:
        blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row
        indices of the nonzero blocks, padded with -1 to reach length @row.
        The indices are multiplied by 4, with the smallest bit used to encode whether
        it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is
        the last nonzero in its row.
    """
    assert not causal
    nrow, ncol = blockmask.shape
    # Sort does not support bool on CUDA
    blockmask = blockmask.to(dtype=torch.uint8)
    nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)
    nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)
    last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]
    last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row
    ]
    first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]
    first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row
    ]
    nonzero_idx = nonzero_sorted_rowidx * 4
    nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2
    nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1
    nonzero_idx[nonzero_val == 0] = -1
    return nonzero_idx.T.contiguous().to(dtype=torch.int32)


def convert_blockmask_row_reverse(blockmask, causal=False):
    blockmask = blockmask.to(dtype=torch.uint8)
    nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=-1, stable=True, descending=False)

    nonzero_idx = nonzero_sorted_rowidx
    nonzero_idx[nonzero_val == 0] = -1
    nonzero_idx = torch.flip(nonzero_idx, dims=[-1])

    return nonzero_idx.contiguous().to(dtype=torch.int32)


def convert_blockmask_col_reverse(blockmask, causal=False):
    blockmask = blockmask.to(dtype=torch.uint8)
    nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=-2, stable=True, descending=False)

    nonzero_idx = nonzero_sorted_rowidx
    nonzero_idx[nonzero_val == 0] = -1
    nonzero_idx = torch.flip(nonzero_idx, dims=[-2])
    nonzero_idx = torch.transpose(nonzero_idx, -1, -2)

    return nonzero_idx.contiguous().to(dtype=torch.int32)


def replace_ones_with_count(tensor):
    # Replace each 1 in head_mask_type with its cumulative count,
    # and return the number of block-sparse heads.
    ones_mask = tensor == 1
    ones_num = ones_mask.sum()
    count = torch.cumsum(ones_mask, dim=-1).to(tensor.dtype)
    count = count * ones_mask
    tensor = tensor.masked_scatter(ones_mask, count[ones_mask])
    return tensor, ones_num


def _block_sparse_attn_forward(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    m_block_dim, n_block_dim,
    head_mask_type,
    streaming_info,
    row_blockmask,
    max_seqlen_q_, max_seqlen_k_,
    p_dropout,
    softmax_scale,
    is_causal,
    exact_streaming,
    return_softmax,
    window_size_left,
    window_size_right
):
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = block_sparse_attn_cuda.fwd_block(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k,
        m_block_dim, n_block_dim,
        head_mask_type,
        streaming_info,
        row_blockmask,
        max_seqlen_q_, max_seqlen_k_,
        p_dropout,
        softmax_scale,
        is_causal,
        exact_streaming,
        return_softmax,
        window_size_left,
        window_size_right,
        None
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


def block_sparse_attn_func(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    head_mask_type,
    streaming_info,
    base_blockmask,
    max_seqlen_q_, max_seqlen_k_,
    p_dropout,
    deterministic=False,
    softmax_scale=None,
    is_causal=False,
    exact_streaming=False,
    return_attn_probs=False,
):
    """
    Main entry point for block sparse attention.

    Args:
        q: Query tensor [total_q, num_heads, head_dim]
        k: Key tensor [total_k, num_heads, head_dim]
        v: Value tensor [total_k, num_heads, head_dim]
        cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
        cu_seqlens_k: Cumulative sequence lengths for K [batch+1]
        head_mask_type: Per-head mask type [num_heads], 1 for block sparse
        streaming_info: Optional streaming attention info
        base_blockmask: Block mask [batch, num_heads, q_blocks, k_blocks]
        max_seqlen_q_: Maximum Q sequence length
        max_seqlen_k_: Maximum K sequence length
        p_dropout: Dropout probability (0.0 for eval)
        deterministic: Whether to use deterministic algorithms
        softmax_scale: Softmax scale (default: 1/sqrt(head_dim))
        is_causal: Whether to apply causal masking
        exact_streaming: Whether to use exact streaming attention
        return_attn_probs: Whether to return attention probabilities

    Returns:
        Attention output [total_q, num_heads, head_dim]
    """
    head_mask_type, blocksparse_head_num = replace_ones_with_count(head_mask_type)
    if base_blockmask is not None:
        assert base_blockmask.shape[1] == blocksparse_head_num

    # BlockSparseAttnFun / BlockSparseAttnFunWithS are autograd.Function
    # wrappers defined elsewhere in the package.
    func = BlockSparseAttnFun if not return_attn_probs else BlockSparseAttnFunWithS
    return func.apply(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k,
        128, 128,  # m_block_dim, n_block_dim (fixed at 128)
        head_mask_type,
        streaming_info,
        base_blockmask,
        max_seqlen_q_, max_seqlen_k_,
        p_dropout,
        softmax_scale,
        is_causal,
        exact_streaming,
        return_attn_probs,
        -1, -1,  # window_size_left, window_size_right
        deterministic
    )
```
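The row-reverse conversion above is easier to see in plain Python: for each mask row it emits the nonzero column indices in descending order, right-padded with -1. A dependency-free sketch of the same semantics (not the CUDA-backed implementation):

```python
def blockmask_row_reverse(mask_rows):
    # Mirror of convert_blockmask_row_reverse for nested lists:
    # per row, the nonzero column indices in descending order,
    # padded on the right with -1 to the row length.
    out = []
    for row in mask_rows:
        cols = [j for j, v in enumerate(row) if v]
        cols.reverse()
        out.append(cols + [-1] * (len(row) - len(cols)))
    return out
```

For example, the mask row `[1, 0, 1]` converts to `[2, 0, -1]`: the kernel walks each row's selected key blocks from the right and stops at the first -1.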
## Usage Example (from COMPASS)

```python
from block_sparse_attn import block_sparse_attn_func

# After xattn_estimate returns the sparse mask
attn_sums, approx_simple_mask = xattn_estimate(query_states, key_states, ...)

# Reshape for BSA (requires [seq_len, num_heads, head_dim] format)
query_states = query_states.transpose(1, 2).view(q_len, num_heads, head_dim)
key_states = key_states.transpose(1, 2).view(k_len, num_heads, head_dim)
value_states = value_states.transpose(1, 2).view(k_len, num_heads, head_dim)

# Cumulative sequence lengths
q_cu_seq_lens = torch.tensor([0, q_len], dtype=torch.int32, device=device)
k_cu_seq_lens = torch.tensor([0, k_len], dtype=torch.int32, device=device)

# Head mask type (1 for all heads using block sparse)
head_mask_type = torch.tensor([1] * num_heads, device=device, dtype=torch.int32)

# Call BSA
attn_output = block_sparse_attn_func(
    query_states,
    key_states,
    value_states,
    q_cu_seq_lens,
    k_cu_seq_lens,
    head_mask_type,
    None,  # streaming_info
    approx_simple_mask[:, :, :q_block_num, :k_block_num].contiguous(),
    q_len,
    k_len,
    p_dropout=0.0,
    deterministic=True,
    is_causal=True,
)

# Reshape back to [batch, num_heads, seq_len, head_dim]
attn_output = attn_output.view(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
```

## Key Constraints

- **Block size**: fixed at 128 tokens (hardcoded in BSA)
- **Batch size**: only batch_size=1 is supported in block sparse mode
- **Mask format**: `[batch, num_heads, q_blocks, k_blocks]` boolean tensor
- **Input format**: `[total_seq_len, num_heads, head_dim]` (not batched)
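A caller can sanity-check its inputs against these constraints before invoking BSA. A hedged sketch (hypothetical helper, not part of the BSA package):

```python
import math

def check_bsa_blockmask(q_len, k_len, blockmask_shape, block_size=128):
    # Validate a block mask against BSA's constraints: batch must be 1,
    # and the block grid must cover the sequence at the fixed
    # 128-token block size.
    batch, num_heads, q_blocks, k_blocks = blockmask_shape
    assert batch == 1, "block sparse mode supports batch_size=1 only"
    assert q_blocks == math.ceil(q_len / block_size), "q_blocks must cover q_len"
    assert k_blocks == math.ceil(k_len / block_size), "k_blocks must cover k_len"
    return True
```

For instance, an 8192-token sequence needs a 64x64 block grid (8192 / 128 = 64); a partial last block still requires its own mask row and column.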
test_report_sparse_policy_refactor.md (deleted, 114 lines)

@@ -1,114 +0,0 @@
# SparsePolicy Refactor Test Report

## Task Overview

Per the requirements in task_plan.md, refactor nanovllm's SparsePolicy architecture (v4): fully migrate the chunked prefill attention computation logic from attention.py into SparsePolicy.

## Scope

FullPolicy only; QuestPolicy and XAttentionBSAPolicy are untouched, and the decode-stage logic is not modified.

## Completed Changes

### 1. policy.py (SparsePolicy base class)

- Added TYPE_CHECKING imports: `OffloadEngine`, `KVCacheManager`, `Sequence`
- Changed the `select_blocks` signature: added an `offload_engine` parameter
- Added the `compute_chunked_attention` abstract method with parameters:
  - `q, k, v`: tensors
  - `layer_id`: layer index
  - `softmax_scale`: softmax scaling factor
  - `offload_engine`: OffloadEngine instance
  - `kvcache_manager`: KVCacheManager instance
  - `current_chunk_idx`: index of the current chunk
  - `seq`: Sequence object
  - `num_tokens`: number of tokens in the current chunk
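The new base-class surface can be sketched structurally as follows (signatures follow the parameter list above; bodies and anything beyond the report's names are illustrative):

```python
from abc import ABC, abstractmethod

class SparsePolicy(ABC):
    # Structural sketch of the refactored base class described above.
    @abstractmethod
    def select_blocks(self, seq, layer_id, offload_engine):
        """Pick which CPU-resident KV blocks to load for this layer."""

    @abstractmethod
    def compute_chunked_attention(self, q, k, v, layer_id, softmax_scale,
                                  offload_engine, kvcache_manager,
                                  current_chunk_idx, seq, num_tokens):
        """Own the entire chunked-prefill attention computation."""
```

Making `compute_chunked_attention` abstract forces every policy to own its full attention pipeline, which is what lets attention.py shrink to a single delegating call.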
### 2. full_policy.py (FullAttentionPolicy)

- Updated TYPE_CHECKING imports
- Added an `offload_engine` parameter to the `select_blocks` method signature
- Renamed `compute_prefill_attention` → `compute_chunked_attention`
- Added a `kvcache_manager` parameter, replacing all `seq.kvcache_manager` references
- Added debug log output

### 3. attention.py

- Simplified the `_chunked_prefill_attention` method:
  - Removed all `flash_attn_*` calls
  - Removed all `merge_attention_outputs` calls
  - Kept only the delegating call to `sparse_policy.compute_chunked_attention()`
- Removed redundant methods: `_sync_load_previous_chunks`, `_ring_buffer_pipeline_load`
- Added the `offload_engine` parameter to the `select_blocks` call on the decode path

## Acceptance Criteria

| Criterion | Status | Notes |
|------|------|------|
| test_needle.py --enable-offload passes | ✅ | test prints PASSED |
| No flash_attn_* calls on the attention.py chunked prefill path | ✅ | no direct flash_attn calls inside `_chunked_prefill_attention` (lines 169-230) |
| No merge_attention_outputs calls on the attention.py chunked prefill path | ✅ | same as above |
| All KV communication goes through offload_engine methods | ✅ | exclusively via `offload_engine.load_to_slot_layer`, `get_kv_for_slot`, `get_prefill_buffer_slice` |

## Test Results

```
============================================================
Needle-in-Haystack Test
============================================================
Model: /home/zijie/models/Llama-3.1-8B-Instruct
Max model len: 131072
Input length: 8192
Block size: 1024
Needle position: 50%
Needle value: 7492
CPU offload: True
Sparse policy: FULL
============================================================

[NeedleTest] Target: 8192, Actual: 8213 tokens (diff=21)
Expected: 7492
Output: 7492<|eot_id|>...
Status: PASSED
============================================================

test_needle: PASSED
```

## Performance Metrics

- Prefill: 3527 tok/s
- Decode: 11 tok/s
- TTFT: 2329.29 ms
- TPOT: 655.38 ms

## Architecture Change Summary

**Before the refactor**:

```
attention.py::_chunked_prefill_attention()
├── fetch cpu_block_table
├── call sparse_policy.select_blocks()
├── call flash_attn_with_lse + merge_attention_outputs directly
└── return result
```

**After the refactor**:

```
attention.py::_chunked_prefill_attention()
├── fetch context info
├── call sparse_policy.compute_chunked_attention()  # delegate all computation
└── return result

sparse_policy.compute_chunked_attention()  # in FullPolicy
├── fetch cpu_block_table
├── call self.select_blocks()
├── load and compute attention over historical KV
├── compute current-chunk attention (causal)
├── merge all results
└── return final output
```

## Conclusion

The SparsePolicy v4 architecture refactor is complete. All acceptance criteria are met and the test passes.