# Solutions for Chunked Attention Accuracy Issues

**Status**: 🟡 IN PROGRESS
**Related Issue**: [`ruler_32k_chunked_offload_issue.md`](ruler_32k_chunked_offload_issue.md)
**Created**: 2026-01-20
**Author**: Zijie Tian

---

## Overview

Based on an in-depth review of the chunked attention code, this document describes the code issues that likely cause the 20% error rate observed in the RULER 32K benchmark, together with a fix for each.

**Core problem**: when the chunked prefill mechanism processes 32K-long sequences, several compounding factors degrade output accuracy.

---

## Problem 1: How `flash_attn_with_lse` Obtains the LSE

### Description

**File**: `nanovllm/ops/chunked_attention.py`
**Location**: lines 264-269

```python
def flash_attn_with_lse(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    from flash_attn.flash_attn_interface import flash_attn_func

    # ...

    # Use return_attn_probs=True to obtain the LSE
    out, lse, _ = flash_attn_func(
        q, k, v,
        softmax_scale=softmax_scale,
        causal=causal,
        return_attn_probs=True,  # ⚠️ this API is not designed for LSE output
    )

    # lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
    lse = lse[:, :, :seqlen_q]
    return out, lse
```

### Analysis

1. **API purpose mismatch**: `return_attn_probs=True` exists mainly to return attention probabilities for debugging, not to provide a high-precision LSE output
2. **Potential precision loss**: Flash Attention may apply internal optimizations to the LSE that trade away precision
3. **The `S_dmask` return value**: in some versions the third return value `S_dmask` may contain a random dropout mask, which can affect results

### Verification

```python
def verify_lse_accuracy():
    """Check whether the LSE returned by flash_attn is accurate."""
    import torch
    from flash_attn.flash_attn_interface import flash_attn_func

    # Small problem size so the exact reference is cheap to compute
    batch, seqlen, nheads, headdim = 1, 128, 8, 64
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)

    # Flash Attention LSE
    _, lse_flash, _ = flash_attn_func(q, k, v, return_attn_probs=True)

    # Manual LSE reference (computed in fp32 for accuracy)
    scale = 1.0 / (headdim ** 0.5)
    qk = torch.einsum('bqhd,bkhd->bhqk', q.float(), k.float()) * scale
    lse_manual = torch.logsumexp(qk, dim=-1)  # [batch, heads, seqlen_q]

    # Compare
    diff = (lse_flash[:, :, :seqlen].float() - lse_manual).abs()
    print(f"LSE max diff: {diff.max().item():.6f}")
    print(f"LSE mean diff: {diff.mean().item():.6f}")

    return diff.max().item() < 1e-3
```

### Solutions

#### Option A: Use FlashInfer's Exact LSE Interface (Recommended)

FlashInfer provides an interface designed specifically for chunked attention:

```python
import math
from typing import Optional, Tuple

import torch


def flash_attn_with_lse_flashinfer(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Obtain an exact LSE via FlashInfer's interface."""
    import flashinfer

    batch, seqlen_q, nheads_q, headdim = q.shape
    _, seqlen_k, nheads_kv, _ = k.shape

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)

    # FlashInfer's single_prefill_with_kv_cache_return_lse is designed
    # to return an exact LSE. Note: flattening batch into the sequence
    # dimension is only valid for batch == 1 (otherwise causal masking
    # would span sequence boundaries).
    assert batch == 1, "this wrapper assumes batch == 1"
    out, lse = flashinfer.single_prefill_with_kv_cache_return_lse(
        q.view(batch * seqlen_q, nheads_q, headdim),
        k.view(batch * seqlen_k, nheads_kv, headdim),
        v.view(batch * seqlen_k, nheads_kv, headdim),
        causal=causal,
        sm_scale=softmax_scale,
    )

    # Reshape outputs
    out = out.view(batch, seqlen_q, nheads_q, headdim)
    lse = lse.view(batch, seqlen_q, nheads_q).transpose(1, 2)  # [batch, nheads, seqlen_q]

    return out, lse
```

#### Option B: Custom Triton Kernel

If FlashInfer is unavailable, use the custom Triton kernel:

```python
# Already implemented in chunked_attention.py: _fwd_kernel_with_lse
# This kernel computes and writes out the LSE directly in Triton
# Pros: full control over precision
# Cons: may be slower than Flash Attention
```

#### Option C: Hybrid Approach

```python
def flash_attn_with_lse_hybrid(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    use_triton_kernel: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pick the best implementation per case:
    - short sequences (< 4K): Flash Attention (fast)
    - long sequences (>= 4K): Triton kernel (exact)
    """
    seqlen = q.shape[1]

    if seqlen < 4096 and not use_triton_kernel:
        return flash_attn_with_lse(q, k, v, softmax_scale, causal)
    else:
        return triton_flash_attn_with_lse(q, k, v, softmax_scale, causal)
```

### Expected Improvement

- Fixing the LSE precision issue is estimated to cut the error rate by 5-10%

---

## Problem 2: Error Accumulation from Too Many Sequential Merges

### Description

**File**: `nanovllm/kvcache/sparse/full_policy.py`
**Location**: lines 136-137, 163-164

```python
# The merge loop inside compute_chunked_prefill
for block_idx in range(num_blocks):
    # ... load the block ...
    prev_o, prev_lse = flash_attn_with_lse(
        q_batched, prev_k, prev_v,
        softmax_scale=softmax_scale,
        causal=False,
    )
    if o_acc is None:
        o_acc, lse_acc = prev_o, prev_lse
    else:
        # ⚠️ one merge per block
        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
```

### Analysis

For a 32K context with block size 1024, there are 32 chunks, i.e. 31 sequential merges. Each merge evaluates:

```
max_lse = max(lse1, lse2)
exp1 = exp(lse1 - max_lse)
exp2 = exp(lse2 - max_lse)
o_merged = (o1 * exp1 + o2 * exp2) / (exp1 + exp2)
lse_merged = max_lse + log(exp1 + exp2)
```

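
In exact arithmetic this pairwise merge reproduces full softmax attention over the concatenated chunks; only rounding makes repeated merges drift. A minimal scalar sketch (plain Python, illustrative only) checking the identity on a toy example:

```python
import math

def merge(o1, lse1, o2, lse2):
    """Pairwise merge of two partial attention results (scalar toy version)."""
    m = max(lse1, lse2)
    e1, e2 = math.exp(lse1 - m), math.exp(lse2 - m)
    return (o1 * e1 + o2 * e2) / (e1 + e2), m + math.log(e1 + e2)

def attn(scores, values):
    """Reference: softmax-weighted sum computed over all scores at once."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return sum(wi * vi for wi, vi in zip(w, values)) / z, m + math.log(z)

scores = [0.3, -1.2, 2.5, 0.7, -0.4, 1.1]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

o1, l1 = attn(scores[:3], values[:3])   # chunk 1 partial result
o2, l2 = attn(scores[3:], values[3:])   # chunk 2 partial result
o_merged, lse_merged = merge(o1, l1, o2, l2)

o_ref, lse_ref = attn(scores, values)   # full attention reference
print(abs(o_merged - o_ref) < 1e-12, abs(lse_merged - lse_ref) < 1e-12)
```

In fp64 the merged result matches the full-attention reference to near machine precision; in bf16 each such step leaves a small residual, which is what accumulates over the chunk loop.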

**Error sources**:
1. `exp()` and `log()` lose precision in edge cases
2. the division `/ (exp1 + exp2)` can introduce rounding error
3. after 31 sequential merges the accumulated error can become significant

### Mathematical Analysis

Assume each merge introduces a relative error ε; a worst-case linear bound after ~32 accumulations:
- fp32 (ε ≈ 10^-6): total error ≈ 32 × ε = 3.2 × 10^-5
- BFloat16 (7-bit mantissa, ε ≈ 10^-2): total error ≈ 32 × ε = 0.32 (32%)

In practice rounding errors partially cancel, so the observed error is usually well below this bound, but it still grows with the number of sequential merges.

### Verification

```python
def measure_merge_error_accumulation():
    """Measure how merge error grows with the number of chunks."""
    import torch
    from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
    from flash_attn.flash_attn_interface import flash_attn_func

    torch.manual_seed(42)

    batch, seqlen, nheads, headdim = 1, 32768, 32, 128
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)

    # Reference: full attention
    out_ref = flash_attn_func(q, k, v, causal=False)

    results = []
    for num_chunks in [4, 8, 16, 32, 64]:
        chunk_size = seqlen // num_chunks
        o_acc, lse_acc = None, None

        for i in range(num_chunks):
            start = i * chunk_size
            end = (i + 1) * chunk_size
            k_chunk = k[:, start:end, :, :]
            v_chunk = v[:, start:end, :, :]

            chunk_o, chunk_lse = flash_attn_with_lse(q, k_chunk, v_chunk, causal=False)

            if o_acc is None:
                o_acc, lse_acc = chunk_o, chunk_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, chunk_o, chunk_lse)

        diff = (out_ref - o_acc).abs()
        results.append({
            'num_chunks': num_chunks,
            'max_diff': diff.max().item(),
            'mean_diff': diff.mean().item(),
        })
        print(f"chunks={num_chunks:3d}, max_diff={diff.max().item():.6f}, mean_diff={diff.mean().item():.8f}")

    return results
```

### Solutions

#### Option A: N-way Merge (Recommended)

Instead of pairwise merging chunk by chunk, merge all chunks in one pass:

```python
from typing import List, Tuple

import torch


def merge_attention_outputs_nway(
    outputs: List[torch.Tensor],  # each [batch, seqlen_q, nheads, headdim]
    lses: List[torch.Tensor],     # each [batch, nheads, seqlen_q]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    N-way merge: combine all chunks at once to reduce accumulated error.

    Math:
    - max_lse = max(lse_1, lse_2, ..., lse_n)
    - sum_exp = sum(exp(lse_i - max_lse) for i in 1..n)
    - o_merged = sum(o_i * exp(lse_i - max_lse) for i in 1..n) / sum_exp
    - lse_merged = max_lse + log(sum_exp)
    """
    n = len(outputs)
    if n == 0:
        raise ValueError("Need at least one output to merge")
    if n == 1:
        return outputs[0], lses[0]

    # Stack for batch processing
    o_stack = torch.stack(outputs, dim=0)  # [n, batch, seqlen, heads, dim]
    lse_stack = torch.stack(lses, dim=0)   # [n, batch, heads, seqlen]

    # Compute max_lse across all chunks (in fp32 for precision)
    lse_fp32 = lse_stack.float()
    max_lse = lse_fp32.max(dim=0).values   # [batch, heads, seqlen]

    # Compute exp(lse_i - max_lse) for each chunk
    exp_weights = torch.exp(lse_fp32 - max_lse.unsqueeze(0))  # [n, batch, heads, seqlen]

    # Normalize weights
    sum_exp = exp_weights.sum(dim=0)                          # [batch, heads, seqlen]
    normalized_weights = exp_weights / sum_exp.unsqueeze(0)   # [n, batch, heads, seqlen]

    # Weighted sum of outputs
    # outputs: [n, batch, seqlen, heads, dim]
    # weights: [n, batch, heads, seqlen] -> [n, batch, seqlen, heads, 1]
    weights_expanded = normalized_weights.permute(0, 1, 3, 2).unsqueeze(-1)
    o_merged = (o_stack.float() * weights_expanded).sum(dim=0).to(outputs[0].dtype)

    # Compute merged LSE
    lse_merged = (max_lse + torch.log(sum_exp)).to(lses[0].dtype)

    return o_merged, lse_merged
```

**Usage**:

```python
# Before: merge one chunk at a time
for block_idx in range(num_blocks):
    chunk_o, chunk_lse = flash_attn_with_lse(...)
    if o_acc is None:
        o_acc, lse_acc = chunk_o, chunk_lse
    else:
        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, chunk_o, chunk_lse)

# After: collect all chunks, then merge once
chunk_outputs = []
chunk_lses = []
for block_idx in range(num_blocks):
    chunk_o, chunk_lse = flash_attn_with_lse(...)
    chunk_outputs.append(chunk_o)
    chunk_lses.append(chunk_lse)

o_acc, lse_acc = merge_attention_outputs_nway(chunk_outputs, chunk_lses)
```

#### Option B: Kahan Summation

Use Kahan-style compensated summation to reduce floating-point accumulation error. The sketch below applies the compensation to the update that moves the running accumulator `o1` toward each new chunk (ideally the accumulator and the compensation term stay in fp32 across calls):

```python
from typing import Optional, Tuple

import torch


def merge_attention_outputs_kahan(
    o1: torch.Tensor, lse1: torch.Tensor,
    o2: torch.Tensor, lse2: torch.Tensor,
    compensation: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Merge with a Kahan compensation term carried across calls.

    Returns:
        o_merged, lse_merged, output_comp (feed back in on the next call)
    """
    max_lse = torch.maximum(lse1.float(), lse2.float())
    exp1 = torch.exp(lse1.float() - max_lse)
    exp2 = torch.exp(lse2.float() - max_lse)
    sum_exp = exp1 + exp2

    if compensation is None:
        compensation = torch.zeros_like(o1, dtype=torch.float32)

    # The merge is equivalent to nudging o1 toward o2 by a weighted increment:
    #   o_merged = o1 + (o2 - o1) * exp2 / sum_exp
    # lse shapes are [batch, heads, seqlen]; align with o's [batch, seqlen, heads, dim]
    w2 = (exp2 / sum_exp).transpose(1, 2).unsqueeze(-1)
    increment = (o2.float() - o1.float()) * w2

    # Standard Kahan update: recover the low-order bits lost in the addition
    y = increment - compensation
    t = o1.float() + y
    compensation_new = (t - o1.float()) - y

    lse_merged = max_lse + torch.log(sum_exp)
    return t.to(o1.dtype), lse_merged.to(lse1.dtype), compensation_new
```

#### Option C: Hierarchical Merge (Compromise)

Split the 32 chunks into groups; merge within each group first, then merge the group results:

```
Before: c1 -> c2 -> c3 -> ... -> c32 (31 sequential merges, accumulation depth 31)
After:
  Group 1: c1-c4   -> g1 (3 merges)
  Group 2: c5-c8   -> g2 (3 merges)
  ...
  Group 8: c29-c32 -> g8 (3 merges)
  Then merge g1-g8 the same way, recursively.
Total: still 31 merges, but the maximum accumulation depth drops from 31 to 7.
```

```python
from typing import List, Tuple

import torch


def merge_attention_outputs_hierarchical(
    outputs: List[torch.Tensor],
    lses: List[torch.Tensor],
    group_size: int = 4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Hierarchical merge: merge within groups first, then merge the groups.
    Reduces the accumulation depth from O(n) to O(log n).
    """
    from nanovllm.ops.chunked_attention import merge_attention_outputs

    n = len(outputs)
    if n == 0:
        raise ValueError("Need at least one output to merge")
    if n == 1:
        return outputs[0], lses[0]

    # Level 1: merge within groups
    group_outputs = []
    group_lses = []

    for i in range(0, n, group_size):
        group = outputs[i:i+group_size]
        group_l = lses[i:i+group_size]

        o_g, lse_g = group[0], group_l[0]
        for j in range(1, len(group)):
            o_g, lse_g = merge_attention_outputs(o_g, lse_g, group[j], group_l[j])

        group_outputs.append(o_g)
        group_lses.append(lse_g)

    # Recursively merge the groups
    if len(group_outputs) > 1:
        return merge_attention_outputs_hierarchical(group_outputs, group_lses, group_size)
    else:
        return group_outputs[0], group_lses[0]
```

### Expected Improvement

- N-way merge: estimated to cut the error rate by 10-15%
- Hierarchical merge: estimated to cut the error rate by 5-8%

---

## Problem 3: Too Few Ring Buffer Slots

### Description

**File**: `nanovllm/kvcache/offload_engine.py`
**Location**: ring buffer initialization

```python
# 2-slot configuration (original)
num_gpu_blocks = 2  # only 2 GPU blocks for the ring buffer

# Problems:
# - Slot 0: used for decode
# - Slot 1: used for prefill loading
# - 32 chunks share a single prefill slot
# - high-frequency CPU<->GPU transfers can race
```

### Analysis

1. **Race condition**: while chunk N's compute is still running, chunk N+1's transfer may overwrite the same slot
2. **Synchronization overhead**: frequent `wait_slot_layer()` calls add latency
3. **Already verified**: the 4-slot configuration improves things markedly but does not fully fix the problem

### Verification Results (Completed)

| Configuration | niah_single_1 accuracy | niah_multikey_3 accuracy |
|------|---------------------|----------------------|
| 2-slot | 94% | 48% |
| 4-slot | 98% | 56% |
| Improvement | +4% | +8% |

**Key finding**: Sample 40 still produces the same wrong output (`6171717161711716`) under the 4-slot configuration, so slot count is not the only cause.

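
A toy model of why more slots reduce exposure to this race (illustrative only; `compute_lag`, the number of chunks by which compute may trail the transfers, is a made-up stand-in parameter):

```python
def unsafe_reuses(num_chunks: int, num_slots: int, compute_lag: int = 2) -> int:
    """Count slot reuses that happen within `compute_lag` chunks of the
    previous occupant, i.e. while its compute may still be in flight."""
    last_used = {}
    unsafe = 0
    for i in range(num_chunks):
        slot = i % num_slots  # round-robin slot assignment
        if slot in last_used and i - last_used[slot] <= compute_lag:
            unsafe += 1
        last_used[slot] = i
    return unsafe

for slots in (1, 2, 4, 8):
    print(slots, unsafe_reuses(32, slots))
```

Once the reuse distance (the slot count) exceeds the compute lag, unsafe reuses drop to zero, which is consistent with the partial improvement observed when going from 2 to 4 slots.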

### Solutions

#### Option A: Increase the Slot Count to 8

```python
# offload_engine.py configuration
class OffloadEngineConfig:
    # Before
    # num_ring_slots = 2

    # Suggested
    num_ring_slots = 8  # increase to 8 slots

    # Or compute it dynamically
    @property
    def num_ring_slots(self):
        # Adjust to the available GPU memory
        available_memory_gb = get_available_gpu_memory()
        slot_memory_gb = self.block_size * self.kv_cache_dtype_size * 2 / 1e9
        max_slots = int(available_memory_gb * 0.3 / slot_memory_gb)  # use 30% of memory
        return min(max(max_slots, 4), 16)  # clamp to 4-16
```

#### Option B: Improve the Synchronization

```python
def load_to_slot_layer_with_barrier(self, slot: int, layer_id: int, cpu_block_id: int):
    """
    Load with an explicit barrier, ensuring the previous operation finished.
    """
    # Wait for all pending operations on this slot
    if self.slot_events[slot] is not None:
        self.slot_events[slot].synchronize()

    # Perform the transfer
    self._do_transfer(slot, layer_id, cpu_block_id)

    # Record a new event
    self.slot_events[slot] = torch.cuda.Event()
    self.slot_events[slot].record(self.transfer_stream)
```

#### Option C: Double Buffering

```python
class DoubleBufferedRingBuffer:
    """
    Double-buffered design: each logical slot has two physical buffers
    - one for the current compute
    - one for the next transfer
    """
    def __init__(self, num_logical_slots: int, ...):
        self.num_logical_slots = num_logical_slots
        self.num_physical_slots = num_logical_slots * 2
        self.current_buffer = [0] * num_logical_slots  # buffer in use per slot (0 or 1)

    def get_compute_slot(self, logical_slot: int) -> int:
        """Physical slot used for compute."""
        return logical_slot * 2 + self.current_buffer[logical_slot]

    def get_transfer_slot(self, logical_slot: int) -> int:
        """Physical slot used for transfer (the other buffer)."""
        return logical_slot * 2 + (1 - self.current_buffer[logical_slot])

    def swap_buffer(self, logical_slot: int):
        """Swap the buffers."""
        self.current_buffer[logical_slot] = 1 - self.current_buffer[logical_slot]
```

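
The index math can be exercised standalone; the toy class below mirrors the compute/transfer slot arithmetic without the elided constructor arguments:

```python
class ToyDoubleBuffer:
    """Minimal stand-in reproducing the double-buffer slot arithmetic."""
    def __init__(self, num_logical_slots: int):
        self.current = [0] * num_logical_slots

    def compute_slot(self, s: int) -> int:
        return s * 2 + self.current[s]        # buffer currently being read

    def transfer_slot(self, s: int) -> int:
        return s * 2 + (1 - self.current[s])  # buffer safe to overwrite

    def swap(self, s: int) -> None:
        self.current[s] = 1 - self.current[s] # roles flip after each chunk

buf = ToyDoubleBuffer(2)
print(buf.compute_slot(1), buf.transfer_slot(1))  # 2 3
buf.swap(1)
print(buf.compute_slot(1), buf.transfer_slot(1))  # 3 2
```

At every point the compute and transfer buffers of a logical slot are distinct physical slots, so a transfer can never clobber data that is still being read.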

### Expected Improvement

- 8-slot configuration: estimated to cut the error rate by a further 3-5%
- Double buffering: estimated to cut the error rate by a further 2-3%

---

## Problem 4: Sparse Policy Mapping Bug

### Description

**File**: `COMPASS/eval/RULER/scripts/pred/call_api.py` (COMPASS project)
**Location**: the metric-to-policy mapping

```python
metric_to_policy = {
    'full': 'FULL',
    'xattn': 'XATTN',        # ❌ XATTN does not exist in SparsePolicyType
    'compass': 'XATTN',
    'minfer': 'MINFERENCE',  # ❌ MINFERENCE does not exist in SparsePolicyType
    'avgpool': 'FULL',
    'flex': 'FLEXPREFILL',
}

sparse_policy_name = metric_to_policy.get(self.metric, 'FULL')
sparse_policy = getattr(SparsePolicyType, sparse_policy_name, SparsePolicyType.FULL)
# ⚠️ silently falls back to FULL, with no warning!
```

### Analysis

1. **Silent fallback**: `getattr(..., SparsePolicyType.FULL)` makes invalid names fall back silently
2. **Misconfigured tests**: the user believes they are testing `xattn` but is actually testing `FULL`
3. **Unreliable results**: the results for every sparse method may in fact be FULL's results
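
The fallback is easy to reproduce with a stand-in enum (`SparsePolicyType` below is a hypothetical two-member stand-in, not the real definition):

```python
from enum import Enum, auto

class SparsePolicyType(Enum):  # stand-in for the real enum
    FULL = auto()
    XATTN_BSA = auto()

# getattr with a default swallows the bad name without any warning:
policy = getattr(SparsePolicyType, "XATTN", SparsePolicyType.FULL)
print(policy.name)  # FULL
```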

### Verification

```python
# Inspect the valid values of SparsePolicyType
from nanovllm.kvcache.sparse.policy import SparsePolicyType
print([e.name for e in SparsePolicyType])
# Should print something like: ['FULL', 'QUEST', 'XATTN_BSA', ...]

# Check whether the mapping is valid
for metric, policy_name in metric_to_policy.items():
    has_policy = hasattr(SparsePolicyType, policy_name)
    print(f"{metric} -> {policy_name}: {'✓' if has_policy else '✗'}")
```

### Solution

#### Fix the Mapping and Add Validation

```python
# Corrected mapping (update to match the SparsePolicyType definition)
metric_to_policy = {
    'full': 'FULL',
    'xattn': 'XATTN_BSA',    # ✓ use the correct enum name
    'compass': 'XATTN_BSA',
    'minfer': 'MINFERENCE',  # confirm this member exists
    'avgpool': 'AVGPOOL',    # confirm this member exists
    'flex': 'FLEXPREFILL',   # confirm this member exists
    'quest': 'QUEST',        # add quest
}

def get_sparse_policy(metric: str) -> SparsePolicyType:
    """Resolve the sparse policy, with validation."""
    policy_name = metric_to_policy.get(metric)
    if policy_name is None:
        raise ValueError(f"Unknown metric: {metric}. Valid options: {list(metric_to_policy.keys())}")

    if not hasattr(SparsePolicyType, policy_name):
        raise ValueError(f"SparsePolicyType.{policy_name} does not exist. "
                         f"Available: {[e.name for e in SparsePolicyType]}")

    return getattr(SparsePolicyType, policy_name)
```

### Expected Improvement

- This bug does not affect chunked offload accuracy
- Once fixed, though, the various sparse methods can be tested as actually intended

---

## Problem 5: Decode Buffer Read Boundary Condition

### Description

**File**: `nanovllm/kvcache/sparse/full_policy.py`
**Location**: lines 269-272

```python
def compute_chunked_decode(self, ...):
    # ...
    seq_len = len(seq)
    decode_pos_in_block = (seq_len - 1) % block_size
    decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
    decode_start_pos_in_block = decode_start_pos % block_size
    num_accumulated = decode_pos_in_block - decode_start_pos_in_block + 1

    # ⚠️ potential bug: when decode_pos_in_block < decode_start_pos_in_block,
    # num_accumulated becomes negative
```

### Analysis

Scenario: the decode crosses a block boundary:
- assume block_size = 1024
- decode_start_pos = 1020 (near the block boundary)
- current seq_len = 1026 (crossed into the next block)
- decode_pos_in_block = (1026 - 1) % 1024 = 1
- decode_start_pos_in_block = 1020 % 1024 = 1020
- num_accumulated = 1 - 1020 + 1 = -1018 ❌
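
The arithmetic can be checked directly; using absolute positions sidesteps the wrap-around entirely:

```python
block_size = 1024
decode_start_pos = 1020
seq_len = 1026

# Block-relative positions wrap at the boundary
decode_pos_in_block = (seq_len - 1) % block_size            # 1
decode_start_pos_in_block = decode_start_pos % block_size   # 1020
num_accumulated = decode_pos_in_block - decode_start_pos_in_block + 1
print(num_accumulated)       # -1018: negative, the bug

# Absolute positions do not wrap
num_accumulated_abs = seq_len - decode_start_pos
print(num_accumulated_abs)   # 6 tokens actually accumulated
```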

### Solution

```python
def compute_chunked_decode(self, ...):
    # ...
    seq_len = len(seq)
    decode_start_pos = kvcache_manager.get_decode_start_pos(seq)

    # Count the tokens actually held in the decode buffer,
    # using absolute positions rather than block-relative ones
    num_accumulated = seq_len - decode_start_pos

    # Validate
    assert num_accumulated > 0, f"Invalid decode range: seq_len={seq_len}, decode_start={decode_start_pos}"

    # Block-relative offsets for reading
    decode_start_pos_in_block = decode_start_pos % block_size
    decode_end_pos_in_block = (seq_len - 1) % block_size

    # Handle the cross-block case
    if decode_end_pos_in_block < decode_start_pos_in_block:
        # Cross-block: would need to read from two ranges;
        # this should be handled at the kvcache_manager level
        raise NotImplementedError("Cross-block decode not yet supported")

    # Same block: read normally
    decode_k = offload_engine.decode_k_buffer[
        layer_id,
        decode_start_pos_in_block:decode_end_pos_in_block+1
    ]
    decode_v = offload_engine.decode_v_buffer[
        layer_id,
        decode_start_pos_in_block:decode_end_pos_in_block+1
    ]
```

### Expected Improvement

- Fixing the boundary condition should eliminate crashes or wrong outputs in this specific case

---

## Problem 6: Merge Kernel Block Size Tuning

### Description

**File**: `nanovllm/ops/chunked_attention.py`
**Location**: line 406

```python
def merge_attention_outputs(...):
    # ...
    # Launch output merge kernel
    BLOCK_SIZE = 128  # fixed value
    grid_output = (batch, seqlen_q, nheads)
    _merge_output_kernel[grid_output](
        o1, o2, lse1, lse2, o_merged,
        batch, seqlen_q, nheads, headdim,
        BLOCK_SIZE=BLOCK_SIZE,
    )
```

### Analysis

1. **Fixed BLOCK_SIZE**: 128 may not be optimal for every headdim
2. **Performance impact**: a poor BLOCK_SIZE can lower GPU occupancy
3. **Accuracy impact**: some BLOCK_SIZE values can cause boundary-handling issues

### Solution

```python
def merge_attention_outputs(...):
    batch, seqlen_q, nheads, headdim = o1.shape

    # Choose BLOCK_SIZE dynamically
    # Principle: BLOCK_SIZE should divide headdim and suit the GPU warp size (32)
    if headdim <= 64:
        BLOCK_SIZE = 64
    elif headdim <= 128:
        BLOCK_SIZE = 128
    else:
        BLOCK_SIZE = 256

    # Ensure BLOCK_SIZE divides headdim
    while headdim % BLOCK_SIZE != 0 and BLOCK_SIZE > 32:
        BLOCK_SIZE //= 2

    # ...
```
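
The heuristic can be traced on a few head sizes with a standalone mirror of the selection logic above. Note a caveat it exposes: for head dims that are not powers of two (e.g. 48 or 96) the loop bottoms out at 32 without 32 necessarily dividing headdim, so the kernel still needs boundary masking:

```python
def choose_block_size(headdim: int) -> int:
    """Standalone mirror of the BLOCK_SIZE heuristic above."""
    if headdim <= 64:
        block = 64
    elif headdim <= 128:
        block = 128
    else:
        block = 256
    # Shrink until block divides headdim, but never below 32
    while headdim % block != 0 and block > 32:
        block //= 2
    return block

for hd in (48, 64, 96, 128, 256):
    print(hd, choose_block_size(hd))
```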

### Expected Improvement

- This is a performance optimization; its effect on accuracy is small

---

## Combined Solution Priorities

| Priority | Problem | Solution | Expected Gain | Difficulty |
|--------|------|----------|----------|----------|
| **P0** | N-way merge | implement `merge_attention_outputs_nway` | 10-15% | medium |
| **P0** | LSE precision | use FlashInfer or validate the flash_attn LSE | 5-10% | low |
| **P1** | Ring buffer slots | go to 8 slots + double buffering | 3-5% | medium |
| **P1** | Policy mapping | fix the mapping bug in COMPASS | N/A | low |
| **P2** | Decode boundary | add boundary checks and handling | 1-2% | low |
| **P2** | Merge block size | choose BLOCK_SIZE dynamically | <1% | low |

---

## Validation Plan

### Stage 1: Error Tracing (Run First)

Before applying any changes, add detailed error tracing:

```python
# Add to full_policy.py
def compute_chunked_prefill_with_debug(self, ...):
    debug_info = {
        'chunk_max_diffs': [],
        'lse_values': [],
        'merge_count': 0,
    }

    # ... existing logic ...

    for block_idx in range(num_blocks):
        # ... compute ...

        if o_acc is not None:
            # Record the difference before and after the merge
            pre_merge_o = o_acc.clone()
            o_acc, lse_acc = merge_attention_outputs(...)
            debug_info['chunk_max_diffs'].append((o_acc - pre_merge_o).abs().max().item())
            debug_info['lse_values'].append(lse_acc.mean().item())
            debug_info['merge_count'] += 1

    # Save the debug info
    torch.save(debug_info, f'debug_prefill_layer{layer_id}.pt')
```

### Stage 2: Single-Change Tests

Apply one improvement at a time and measure its effect:

1. N-way merge only -> measure accuracy
2. More ring buffer slots only -> measure accuracy
3. FlashInfer LSE only -> measure accuracy

### Stage 3: Combined Test

Combine the best-performing improvements and check whether the gains stack:

```
Final configuration = N-way merge + 8 slots + FlashInfer LSE
Goal: error rate < 10% (close to the 8% xattn_stride8 baseline)
```

---

## Appendix: Related File Index

| File | Key functions/classes | Problem # |
|------|-------------|----------|
| `nanovllm/ops/chunked_attention.py` | `flash_attn_with_lse`, `merge_attention_outputs` | 1, 2, 6 |
| `nanovllm/kvcache/sparse/full_policy.py` | `compute_chunked_prefill`, `compute_chunked_decode` | 2, 5 |
| `nanovllm/kvcache/offload_engine.py` | ring buffer management | 3 |
| `COMPASS/eval/RULER/scripts/pred/call_api.py` | `metric_to_policy` | 4 |

---

## Appendix: Failing Sample Details

The RULER 32K samples that failed, kept here for verifying fixes later.

### Error Statistics

| Task | Total Samples | Errors | Error Rate |
|------|--------------|--------|------------|
| niah_single_1 | 100 | 19 | 19% |
| niah_single_2 | 100 | 23 | 23% |
| niah_single_3 | 100 | 8 | 8% |
| niah_multikey_1 | 100 | 16 | 16% |
| niah_multikey_2 | 100 | 30 | 30% |
| niah_multikey_3 | 100 | 24 | 24% |
| **TOTAL** | **600** | **120** | **20%** |

---

### niah_single_1 (19 errors)

**Failing sample IDs**: `28, 33, 39, 40, 41, 43, 44, 49, 51, 52, 53, 57, 61, 63, 65, 67, 72, 77, 83`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 28 | `9874152` | `:151:52<\|eot_id\|>` |
| 33 | `9196204` | `:<\|eot_id\|>` |
| 39 | `3484601` | `:<\|eot_id\|>` |
| 40 | `6171716` | `: 17: 16<\|eot_id\|>` |
| 41 | `4524499` | `:<\|eot_id\|>` |
| 43 | `3726327` | `: 16: 7<\|eot_id\|>` |
| 44 | `4009172` | `: 2<\|eot_id\|>` |
| 49 | `4240180` | `:354:180<\|eot_id\|>` |
| 51 | `9546409` | `:<\|eot_id\|>` |
| 52 | `2935113` | `: 29351113.<\|eot_id\|>` |
| 53 | `5453786` | `:354:678:90<\|eot_id\|>` |
| 57 | `8315831` | `: 5831<\|eot_id\|>` |
| 61 | `5960271` | `: 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,...<\|eot_id\|>` |
| 63 | `6049101` | `: 5 0 4 9 1 0 1<\|eot_id\|>` |
| 65 | `6406444` | `:361361361361361361361361361361361361361361361361361361361361361361361361361361...<\|eot_id\|>` |
| 67 | `2422633` | `:31<\|eot_id\|>` |
| 72 | `7442089` | ` 7953166<\|eot_id\|>` |
| 77 | `8795419` | `:<\|eot_id\|>` |
| 83 | `6363836` | `: 2<\|eot_id\|>` |

---

### niah_single_2 (23 errors)

**Failing sample IDs**: `16, 24, 30, 32, 40, 41, 42, 50, 51, 52, 55, 58, 60, 62, 64, 66, 67, 68, 69, 77, 85, 91, 93`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 16 | `2344047` | `: 23440447.<\|eot_id\|>` |
| 24 | `5449324` | `:<\|eot_id\|>` |
| 30 | `5727085` | `:<\|eot_id\|>` |
| 32 | `9196204` | `:<\|eot_id\|>` |
| 40 | `4524499` | `:460<\|eot_id\|>` |
| 41 | `7817881` | `:171.<\|eot_id\|>` |
| 42 | `3726327` | `:<\|eot_id\|>` |
| 50 | `9546409` | `:<\|eot_id\|>` |
| 51 | `2935113` | `: 3: 5113<\|eot_id\|>` |
| 52 | `5453786` | `:354<\|eot_id\|>` |
| 55 | `4188992` | `: 418899189418899, but it is not explicitly stated...` |
| 58 | `6266630` | `:5963<\|eot_id\|>` |
| 60 | `5960271` | ` 0271<\|eot_id\|>` |
| 62 | `6049101` | `:<\|eot_id\|>` |
| 64 | `6406444` | `:<\|eot_id\|>` |
| 66 | `2422633` | `:5313<\|eot_id\|>` |
| 67 | `4940441` | `:5311<\|eot_id\|>` |
| 68 | `3472189` | `:361.<\|eot_id\|>` |
| 69 | `8971465` | `:361.<\|eot_id\|>` |
| 77 | `8963715` | `: 0 8 9 7 1 5<\|eot_id\|>` |
| 85 | `2044645` | `: 20446445.<\|eot_id\|>` |
| 91 | `7783308` | `:<\|eot_id\|>` |
| 93 | `1454696` | `:<\|eot_id\|>` |

---

### niah_single_3 (8 errors)

**Failing sample IDs**: `7, 9, 14, 24, 25, 29, 31, 43`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 7 | `ee87905e-4ca4-45ea-8dfa-6a56d12dbc9a` | `: 2010-07-01T00:00:00Z<\|eot_id\|>` |
| 9 | `b7b56ea7-35eb-432d-9ad6-20ab48212ddb` | `:0:0:0:0:0:0:0:0:0:0:0:0:0:0:0:0<\|eot_id\|>` |
| 14 | `e767dcea-b0e6-4969-a213-42b0f1eedba3` | `:0e6-4969-a213-42b0f1eedba3<\|eot_id\|>` |
| 24 | `59e4b671-4774-4c58-85f8-bc16f7860b50` | `:4774:4c58:85f8:bc16f7860b50<\|eot_id\|>` |
| 25 | `54c63cd8-8945-4f27-97fa-2d8dfb2ca025` | `: 54c63c63cd8-8945-4f27-97fa-2d8dfb2ca025.<\|eot_id\|>` |
| 29 | `006ed6e3-6fa1-4735-b572-f3d00b5cea6a` | `:6e3-6fa1-4735-b572-f3d00b5cea6a<\|eot_id\|>` |
| 31 | `e6697833-b841-40a0-9fe7-71d6d9178793` | `: e6697837837833-b841-40a0-9fe7-71d6d9178793.<\|eot_id\|>` |
| 43 | `d92c9227-eadf-4085-bfcb-75468eb22579` | `: d92c922c9227-eadf-4085-bfcb-75468eb22579.<\|eot_id\|>` |

---

### niah_multikey_1 (16 errors)

**Failing sample IDs**: `20, 31, 32, 40, 41, 45, 51, 54, 59, 63, 64, 65, 67, 69, 71, 74`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 20 | `2171218` | `: 2171212181212181212181218<\|eot_id\|>` |
| 31 | `9333700` | `:<\|eot_id\|>` |
| 32 | `7121355` | `:9651<\|eot_id\|>` |
| 40 | `3112652` | `:285<\|eot_id\|>` |
| 41 | `3427461` | `:<\|eot_id\|>` |
| 45 | `8217547` | `:<\|eot_id\|>` |
| 51 | `1514340` | `: 1514343403361.<\|eot_id\|>` |
| 54 | `8212753` | `:<\|eot_id\|>` |
| 59 | `6587964` | `:<\|eot_id\|>` |
| 63 | `1688246` | `:<\|eot_id\|>` |
| 64 | `8344365` | `: 834436, but it is not explicitly mentioned.<\|eot_id\|>` |
| 65 | `6614484` | `: 4367.<\|eot_id\|>` |
| 67 | `6510922` | `:7780<\|eot_id\|>` |
| 69 | `6649968` | `: 43610.<\|eot_id\|>` |
| 71 | `9437374` | `:<\|eot_id\|>` |
| 74 | `6625238` | `:1472908<\|eot_id\|>` |

---

### niah_multikey_2 (30 errors)

**Failing sample IDs**: `2, 13, 21, 22, 23, 24, 25, 28, 32, 34, 38, 39, 40, 41, 42, 43, 45, 46, 47, 49, 50, 53, 54, 56, 57, 59, 60, 63, 64, 65`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 2 | `1535573` | `: 8651665.<\|eot_id\|>` |
| 13 | `2794159` | `: 5261593<\|eot_id\|>` |
| 21 | `8970232` | `:168<\|eot_id\|>` |
| 22 | `9134051` | `: 381:055: 381:055: 381:055: 381:055: 381:055: 381:055: 381:055: 38...` |
| 23 | `9696620` | `: 969662620969662, which is: 969662920, 96966220 is not actually me...` |
| 24 | `7071187` | ` 055055055.<\|eot_id\|>` |
| 25 | `5572782` | `: 5342494<\|eot_id\|>` |
| 28 | `4953027` | `:1687719<\|eot_id\|>` |
| 32 | `4259234` | `: 425923521250, but not found is: 425923751572250, however is: 4259...` |
| 34 | `3643022` | `: 3957500<\|eot_id\|>` |
| 38 | `2031469` | `: the text.<\|eot_id\|>` |
| 39 | `8740362` | `: 8740364 8740364 8740364 8740364 is: is: is: is: 874036...` |
| 40 | `7041770` | `:1682<\|eot_id\|>` |
| 41 | `1986258` | `:086.<\|eot_id\|>` |
| 42 | `5668574` | `:055.<\|eot_id\|>` |
| 43 | `8560471` | `:067<\|eot_id\|>` |
| 45 | `9973767` | `: 8420273<\|eot_id\|>` |
| 46 | `3960211` | `:0<\|eot_id\|>` |
| 47 | `8003271` | `: 60870870870870870870870870870870870870870870870870870870870870870...` |
| 49 | `8632309` | ` 303640 is640 640 640 640 640 640 640 640 640 640 640 640 640 640 640...` |
| 50 | `2318630` | `: 7780552.<\|eot_id\|>` |
| 53 | `3405052` | `:<\|eot_id\|>` |
| 54 | `5364945` | `: 536494, which is: 536494, which is: 536494494494494494494494494...` |
| 56 | `7319214` | `:7607607607607607607607607607607607607607607607607607607607607607607...` |
| 57 | `9206104` | `:7607607607607607607607607607607607607607607607607607607607607607607...` |
| 59 | `9555385` | `:7095<\|eot_id\|>` |
| 60 | `5727554` | `: 572755755755755755755755755755755755755755755755755755755755 is: 572...` |
| 63 | `1090767` | `:7607607607607607607607607607607607607607607607607607607607607607607...` |
| 64 | `6791240` | `:<\|eot_id\|>` |
| 65 | `7275999` | `:7607607607607607607607607607607607607607607607607607607607607607607...` |

---

### niah_multikey_3 (24 errors)

**Failing sample IDs**: `11, 18, 20, 23, 24, 25, 26, 27, 29, 30, 33, 35, 37, 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 52`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 11 | `c73ed342-6523-4d4b-aa33-beb1c9007315` | `: 1d28b88b-b6a8-46ba-8e8f-56cbafbfd897.<\|eot_id\|>` |
| 18 | `87b8a762-1d1f-4e85-a5d1-caf284c95aa6` | `: 429a6676-5295-4ea2-a694-6aa949f48e31.<\|eot_id\|>` |
| 20 | `cce29702-134a-460c-979b-6f7ee7895280` | `:<\|eot_id\|>` |
| 23 | `ed344bfe-983f-4a21-af44-722e2517244c` | `: aec431e7d880a8dce2c023de24 is: aec43163-061a-4afe-b80a-f5bfb5e3c9...` |
| 24 | `4712ef99-a8d1-4388-8ca7-b08dd3505d77` | `:<\|eot_id\|>` |
| 25 | `46969ce7-0da0-49f8-87b2-845e7b8ef100` | `:<\|eot_id\|>` |
| 26 | `7cff3c66-6860-49e6-8ba5-002162c250c0` | `:4c7e-946b-30812edf965e<\|eot_id\|>` |
| 27 | `b63b4988-40bc-44b2-bf1c-ca95adbca4e9` | `:<\|eot_id\|>` |
| 29 | `6d94011c-f28a-4b0b-a2e2-fe34bb8b19a1` | `: 6d6d6d6d4b0e-52ce-44d9-a0f6-1ae405825615<\|eot_id\|>` |
| 30 | `7c33bb00-4ab4-4e4f-a78e-39f8f06d63eb` | ` d7a2-4b23-a2c0-8c859cb1fa96<\|eot_id\|>` |
| 33 | `b7c6b586-713a-4907-ad24-5c4f25aeb769` | `:1-4d2c-b42b-933ded2633d6<\|eot_id\|>` |
| 35 | `ac8a317b-a6bb-4327-90db-2a01622cb723` | `: d2f2f2f2f2f2f2f2d2d2f2d2d2d3d2f6b3d2f- is: d2dab is: is: is: i...` |
| 37 | `b187b337-3132-4376-a500-9340102092ae` | `:<\|eot_id\|>` |
| 40 | `2559fa56-dd0a-48d4-ba82-3ae2bf0a4b33` | `:358fe0e3-724e-4cfc-9ae0-d0873162626b.<\|eot_id\|>` |
| 41 | `7842feb5-e758-44cd-b73b-8ae08aa33142` | `: 6c6adf83-36a9-4e41-9cbe-60a8c9ffba92.<\|eot_id\|>` |
| 42 | `a1196139-f6fa-4c18-b3da-b7bd50362ac7` | `: a1196131396131196131399a1196139a1196139a1196139a1196139f6a1196139...` |
| 44 | `7d3d40b2-4594-4573-b267-4c6270dd4425` | `: 613a9e-4e7d-8c9f-740a630e3c53<\|eot_id\|>` |
| 45 | `500b8a75-8f05-43f5-b9ad-46d47d4e33fc` | `: 500b8a5e0e0e0a500b is: 500b is: 500b-4 is: is: is: is: is: i...` |
| 46 | `86a867a7-6a98-4a02-b065-70a33bafafde` | `:6139a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a...` |
| 47 | `7c0f7fd2-237e-4c0f-b3f5-f43623551169` | ` 5fb71d2f0f0b4f0 is: 5fb71 is: 5fb71f-4f-4f-4f-4f-4f-4d7 is: is: ...` |
| 48 | `b0e1f3f5-6570-437e-b8a1-f1b3f654e257` | `: 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b ...` |
|
||
| 49 | `0153722a-70a8-4ec0-9f03-2b0930937e60` | `: 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b ...` |
|
||
| 50 | `0a1ead51-0c39-4eeb-ac87-d146acdb1d4a` | `: 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b ...` |
|
||
| 52 | `ff686e85-3a9f-4635-95dd-f19e8ca68eb1` | ` ff686e686e686e686e686e686f686e6f686e6fb686f686f686f686f686f- is: f...` |
|
||
|
||
---

### 错误模式分析

| 错误类型 | 示例 | 可能原因 |
|----------|------|----------|
| **空输出** | `:<\|eot_id\|>` | KV cache 完全损坏,模型无法生成有效内容 |
| **数字重复** | `:361361361...` | Chunk 边界 merge 错误导致 decode 进入循环 |
| **部分正确** | `9874152` → `:151:52` | 前几个 chunk 正确,后面 chunk 损坏 |
| **数字插入** | `2344047` → `23440447` | Merge 过程中数值精度损失 |
| **完全错误** | 输出完全不相关的 UUID | 严重的 KV cache 或 position encoding 错误 |

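上表的前两类错误(「空输出」「数字重复」)可以直接从输出字符串启发式判别,便于在批量结果中统计各错误模式的占比。下面是一个最小示意(函数名、正则与重复次数阈值均为示例假设;「部分正确 / 完全错误」需要结合 expected 值才能区分,这里统一归入「其他」):

```python
import re

def classify_error(output: str) -> str:
    """按上表的错误模式对一条错误输出做启发式归类(阈值为示例假设)。"""
    # 去掉终止符与开头残留的冒号,只看实际生成内容
    text = output.replace("<|eot_id|>", "").strip().lstrip(":").strip()
    if not text:
        return "空输出"
    # 数字重复: 同一 3-6 位数字片段连续出现 5 次以上,视为 decode 循环
    if re.search(r"(\d{3,6})(?:\D{0,3}\1){5,}", text):
        return "数字重复"
    # 部分正确 / 数字插入 / 完全错误需要对照 expected,这里不区分
    return "其他"
```

对照上面的样例:`:<\|eot_id\|>` 会归入「空输出」,`: 361361361...` 这类循环输出会归入「数字重复」,而 `: 8420273<\|eot_id\|>` 这种看似正常但答案错误的输出落入「其他」。
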
---

### 验证脚本

用于验证修复效果的脚本:

```python
# verify_samples.py
import json

ERROR_SAMPLES = {
    'niah_single_1': [28, 33, 39, 40, 41, 43, 44, 49, 51, 52, 53, 57, 61, 63, 65, 67, 72, 77, 83],
    'niah_single_2': [16, 24, 30, 32, 40, 41, 42, 50, 51, 52, 55, 58, 60, 62, 64, 66, 67, 68, 69, 77, 85, 91, 93],
    'niah_single_3': [7, 9, 14, 24, 25, 29, 31, 43],
    'niah_multikey_1': [20, 31, 32, 40, 41, 45, 51, 54, 59, 63, 64, 65, 67, 69, 71, 74],
    'niah_multikey_2': [2, 13, 21, 22, 23, 24, 25, 28, 32, 34, 38, 39, 40, 41, 42, 43, 45, 46, 47, 49, 50, 53, 54, 56, 57, 59, 60, 63, 64, 65],
    'niah_multikey_3': [11, 18, 20, 23, 24, 25, 26, 27, 29, 30, 33, 35, 37, 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 52],
}

def verify_fix(result_dir: str) -> dict:
    """验证修复效果"""
    fixed = {}
    for task, error_ids in ERROR_SAMPLES.items():
        pred_file = f"{result_dir}/{task}.jsonl"
        with open(pred_file) as f:
            predictions = [json.loads(line) for line in f]

        fixed_count = 0
        still_error = []
        for idx in error_ids:
            pred = predictions[idx]
            expected = pred['answer']
            actual = pred['model_output']
            if expected in actual:
                fixed_count += 1
            else:
                still_error.append(idx)

        fixed[task] = {
            'total_errors': len(error_ids),
            'fixed': fixed_count,
            'still_error': still_error,
            'fix_rate': f"{fixed_count/len(error_ids)*100:.1f}%"
        }

    return fixed
```
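脚本的核心判定是 `expected in actual` 的包含式字符串匹配。下面的自包含小例子演示这一判定在 jsonl 预测文件上的行为:构造两条假预测记录(字段名 `answer`/`model_output` 与上面脚本的假设一致),第一条答案出现在输出中、第二条没有。文件名与数据均为示意:

```python
import json
import os
import tempfile

# 构造一个最小的假 jsonl 预测文件,演示 verify_fix 的核心判定逻辑
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "niah_single_1.jsonl")
    with open(path, "w") as f:
        # 第一条: expected 出现在 model_output 中(判定为已修复)
        f.write(json.dumps({"answer": "5342494", "model_output": ": 5342494<|eot_id|>"}) + "\n")
        # 第二条: expected 未出现(判定为仍然错误)
        f.write(json.dumps({"answer": "5572782", "model_output": ": 5342494<|eot_id|>"}) + "\n")

    with open(path) as f:
        preds = [json.loads(line) for line in f]
    # 与 verify_fix 一致的包含式匹配
    results = [p["answer"] in p["model_output"] for p in preds]
    print(results)  # [True, False]
```
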

---

**Author**: Zijie Tian
**Created**: 2026-01-20