Chunked Attention Accuracy Issues: Proposed Solutions
Status: 🟡 IN PROGRESS
Related Issue: ruler_32k_chunked_offload_issue.md
Created: 2026-01-20
Author: Zijie Tian
Overview
This document is based on an in-depth analysis of the chunked attention code. It describes the code issues that likely cause the 20% error rate in the RULER 32K test, together with the corresponding solutions.
Core problem: when handling 32K-long sequences, the chunked prefill mechanism degrades output accuracy through the cumulative effect of several factors.
Problem 1: How flash_attn_with_lse Obtains the LSE
Description
File: nanovllm/ops/chunked_attention.py
Location: lines 264-269
def flash_attn_with_lse(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    from flash_attn.flash_attn_interface import flash_attn_func
    seqlen_q = q.shape[1]
    # ...
    # Use return_attn_probs=True to obtain the LSE
    out, lse, _ = flash_attn_func(
        q, k, v,
        softmax_scale=softmax_scale,
        causal=causal,
        return_attn_probs=True,  # ⚠️ this API is not designed for LSE output
    )
    # lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
    lse = lse[:, :, :seqlen_q]
    return out, lse
Analysis
- API purpose mismatch: the primary purpose of return_attn_probs=True is to return attention probabilities for debugging, not to provide a high-precision LSE output.
- Potential precision loss: Flash Attention may apply internal optimizations to the LSE that sacrifice precision.
- S_dmask return value: in some versions the third return value S_dmask may contain a random dropout mask, affecting results.
Verification
def verify_lse_accuracy():
    """Check whether the LSE returned by flash_attn is accurate."""
    import torch
    from flash_attn.flash_attn_interface import flash_attn_func

    # Small-scale data for an exact reference computation
    batch, seqlen, nheads, headdim = 1, 128, 8, 64
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)

    # Flash Attention LSE
    _, lse_flash, _ = flash_attn_func(q, k, v, return_attn_probs=True)

    # Manually computed LSE (PyTorch reference, in fp32)
    scale = 1.0 / (headdim ** 0.5)
    qk = torch.einsum('bqhd,bkhd->bhqk', q.float(), k.float()) * scale
    lse_manual = torch.logsumexp(qk, dim=-1)  # [batch, heads, seqlen_q]

    # Compare
    diff = (lse_flash[:, :, :seqlen].float() - lse_manual).abs()
    print(f"LSE max diff: {diff.max().item():.6f}")
    print(f"LSE mean diff: {diff.mean().item():.6f}")
    return diff.max().item() < 1e-3
Solutions
Option A: use FlashInfer's exact LSE interface (recommended)
FlashInfer provides an interface designed for chunked attention:
def flash_attn_with_lse_flashinfer(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Obtain an exact LSE via the FlashInfer interface."""
    import flashinfer

    batch, seqlen_q, nheads_q, headdim = q.shape
    _, seqlen_k, nheads_kv, _ = k.shape
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)

    # FlashInfer's single_prefill_with_kv_cache_return_lse is designed
    # to return an exact LSE
    out, lse = flashinfer.single_prefill_with_kv_cache_return_lse(
        q.view(batch * seqlen_q, nheads_q, headdim),
        k.view(batch * seqlen_k, nheads_kv, headdim),
        v.view(batch * seqlen_k, nheads_kv, headdim),
        causal=causal,
        sm_scale=softmax_scale,
    )
    # Reshape outputs
    out = out.view(batch, seqlen_q, nheads_q, headdim)
    lse = lse.view(batch, seqlen_q, nheads_q).transpose(1, 2)  # [batch, nheads, seqlen_q]
    return out, lse
Option B: custom Triton kernel
If FlashInfer is unavailable, use a custom Triton kernel:
# Already implemented in chunked_attention.py: _fwd_kernel_with_lse
# This kernel computes and outputs the LSE directly in Triton
# Pros: full control over precision
# Cons: may be slower than Flash Attention
方案 C: 混合方法
def flash_attn_with_lse_hybrid(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
use_triton_kernel: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
根据情况选择最佳实现:
- 短序列 (< 4K): 使用 Flash Attention (快)
- 长序列 (>= 4K): 使用 Triton kernel (精确)
"""
seqlen = q.shape[1]
if seqlen < 4096 and not use_triton_kernel:
return flash_attn_with_lse(q, k, v, softmax_scale, causal)
else:
return triton_flash_attn_with_lse(q, k, v, softmax_scale, causal)
Expected improvement
- Fixing the LSE precision issue should reduce the error rate by an estimated 5-10%.
Problem 2: Too Many Accumulated Merges Compound the Error
Description
File: nanovllm/kvcache/sparse/full_policy.py
Location: lines 136-137 and 163-164
# merge loop in compute_chunked_prefill
for block_idx in range(num_blocks):
    # ... load block ...
    prev_o, prev_lse = flash_attn_with_lse(
        q_batched, prev_k, prev_v,
        softmax_scale=softmax_scale,
        causal=False,
    )
    if o_acc is None:
        o_acc, lse_acc = prev_o, prev_lse
    else:
        # ⚠️ one merge per block
        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
Analysis
For a 32K context with a block size of 1024, roughly 32 merges are executed. Each merge evaluates:
max_lse = max(lse1, lse2)
exp1 = exp(lse1 - max_lse)
exp2 = exp(lse2 - max_lse)
o_merged = (o1 * exp1 + o2 * exp2) / (exp1 + exp2)
lse_merged = max_lse + log(exp1 + exp2)
Error sources:
- exp() and log() lose precision in edge cases
- the division / (exp1 + exp2) can introduce rounding error
- after 32 accumulations the error can reach a significant level
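As a sanity check, the pairwise formula above can be reproduced in plain PyTorch and compared against full attention on a small example. This is a minimal sketch using a simplified [batch, heads, seqlen, dim] layout; `attn_with_lse` and `merge_pair` are illustrative helpers, not the project's API:

```python
import torch

def attn_with_lse(q, k, v, scale):
    """Naive attention returning (output, LSE); q/k/v are [B, H, S, D]."""
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
    lse = torch.logsumexp(scores, dim=-1)  # [B, H, Sq]
    return torch.einsum("bhqk,bhkd->bhqd", scores.softmax(-1), v), lse

def merge_pair(o1, lse1, o2, lse2):
    """Pairwise merge of two partial attention results (the formula above)."""
    max_lse = torch.maximum(lse1, lse2)
    e1 = torch.exp(lse1 - max_lse).unsqueeze(-1)
    e2 = torch.exp(lse2 - max_lse).unsqueeze(-1)
    o_merged = (o1 * e1 + o2 * e2) / (e1 + e2)
    lse_merged = max_lse + torch.log((e1 + e2).squeeze(-1))
    return o_merged, lse_merged

torch.manual_seed(0)
B, H, S, D = 1, 2, 8, 16
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
scale = D ** -0.5
o_ref, lse_ref = attn_with_lse(q, k, v, scale)
# Attend to the two KV halves separately, then merge
o1, l1 = attn_with_lse(q, k[:, :, :4], v[:, :, :4], scale)
o2, l2 = attn_with_lse(q, k[:, :, 4:], v[:, :, 4:], scale)
o_merged, lse_merged = merge_pair(o1, l1, o2, l2)
```

In exact arithmetic the merge reconstructs full attention perfectly; the residual difference observed in fp32 is exactly the per-merge rounding error discussed here.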
Mathematical analysis
Assume each merge introduces a relative error ε ≈ 10^-6 (fp32):
- after 32 accumulations: total error ≈ 32 × ε = 3.2 × 10^-5
- for BFloat16 (7-bit mantissa): ε ≈ 10^-2, so total error ≈ 0.32 (32%)
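The bfloat16 estimate can be made concrete with a toy experiment unrelated to attention itself: a long running sum kept in bfloat16 stalls once the addends drop below half a ulp of the accumulator, while an fp32 accumulator stays close to the fp64 reference. This is an illustrative sketch of low-precision accumulation, not the merge path:

```python
import torch

torch.manual_seed(0)
x = torch.rand(4096, dtype=torch.float64)  # addends in (0, 1)
ref = x.sum().item()

def running_sum(x, dtype):
    """Sequential accumulation, rounding to `dtype` at every step."""
    acc = torch.zeros((), dtype=dtype)
    for v in x:
        acc = acc + v.to(dtype)
    return acc.double().item()

err_bf16 = abs(running_sum(x, torch.bfloat16) - ref) / ref
err_fp32 = abs(running_sum(x, torch.float32) - ref) / ref
print(f"relative error: bf16={err_bf16:.3f}, fp32={err_fp32:.2e}")
```

Once the bf16 accumulator exceeds 256, its ulp is 2, so adding values below 1 rounds away entirely; the sum saturates and the relative error becomes enormous, while fp32 stays tight.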
Verification
def measure_merge_error_accumulation():
    """Measure how the merge error grows with the number of chunks."""
    import torch
    from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
    from flash_attn.flash_attn_interface import flash_attn_func

    torch.manual_seed(42)
    batch, seqlen, nheads, headdim = 1, 32768, 32, 128
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)

    # Reference: full attention
    out_ref = flash_attn_func(q, k, v, causal=False)

    results = []
    for num_chunks in [4, 8, 16, 32, 64]:
        chunk_size = seqlen // num_chunks
        o_acc, lse_acc = None, None
        for i in range(num_chunks):
            start = i * chunk_size
            end = (i + 1) * chunk_size
            k_chunk = k[:, start:end, :, :]
            v_chunk = v[:, start:end, :, :]
            chunk_o, chunk_lse = flash_attn_with_lse(q, k_chunk, v_chunk, causal=False)
            if o_acc is None:
                o_acc, lse_acc = chunk_o, chunk_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, chunk_o, chunk_lse)
        diff = (out_ref - o_acc).abs()
        results.append({
            'num_chunks': num_chunks,
            'max_diff': diff.max().item(),
            'mean_diff': diff.mean().item(),
        })
        print(f"chunks={num_chunks:3d}, max_diff={diff.max().item():.6f}, mean_diff={diff.mean().item():.8f}")
    return results
Solutions
Option A: N-way merge (recommended)
Instead of the sequential pairwise merge, merge all chunks in a single pass:
def merge_attention_outputs_nway(
    outputs: List[torch.Tensor],  # list of [batch, seqlen_q, nheads, headdim]
    lses: List[torch.Tensor],     # list of [batch, nheads, seqlen_q]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    N-way merge: merge all chunks in one pass to reduce accumulated error.
    Math:
    - max_lse = max(lse_1, lse_2, ..., lse_n)
    - sum_exp = sum(exp(lse_i - max_lse) for i in 1..n)
    - o_merged = sum(o_i * exp(lse_i - max_lse) for i in 1..n) / sum_exp
    - lse_merged = max_lse + log(sum_exp)
    """
    n = len(outputs)
    if n == 0:
        raise ValueError("Need at least one output to merge")
    if n == 1:
        return outputs[0], lses[0]
    # Stack for batch processing
    o_stack = torch.stack(outputs, dim=0)  # [n, batch, seqlen, heads, dim]
    lse_stack = torch.stack(lses, dim=0)   # [n, batch, heads, seqlen]
    # Compute max_lse across all chunks (in fp32 for precision)
    lse_fp32 = lse_stack.float()
    max_lse = lse_fp32.max(dim=0).values  # [batch, heads, seqlen]
    # Compute exp(lse_i - max_lse) for each chunk
    exp_weights = torch.exp(lse_fp32 - max_lse.unsqueeze(0))  # [n, batch, heads, seqlen]
    # Normalize weights
    sum_exp = exp_weights.sum(dim=0)  # [batch, heads, seqlen]
    normalized_weights = exp_weights / sum_exp.unsqueeze(0)  # [n, batch, heads, seqlen]
    # Weighted sum of outputs
    # outputs: [n, batch, seqlen, heads, dim]
    # weights: [n, batch, heads, seqlen] -> [n, batch, seqlen, heads, 1]
    weights_expanded = normalized_weights.permute(0, 1, 3, 2).unsqueeze(-1)
    o_merged = (o_stack.float() * weights_expanded).sum(dim=0).to(outputs[0].dtype)
    # Compute merged LSE
    lse_merged = (max_lse + torch.log(sum_exp)).to(lses[0].dtype)
    return o_merged, lse_merged
Usage:
# Before: merge one chunk at a time
for block_idx in range(num_blocks):
    chunk_o, chunk_lse = flash_attn_with_lse(...)
    if o_acc is None:
        o_acc, lse_acc = chunk_o, chunk_lse
    else:
        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, chunk_o, chunk_lse)

# After: collect all chunks, then merge once
chunk_outputs = []
chunk_lses = []
for block_idx in range(num_blocks):
    chunk_o, chunk_lse = flash_attn_with_lse(...)
    chunk_outputs.append(chunk_o)
    chunk_lses.append(chunk_lse)
o_acc, lse_acc = merge_attention_outputs_nway(chunk_outputs, chunk_lses)
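The N-way formula can be validated on CPU against a naive reference. This sketch uses a simplified [batch, heads, seqlen, dim] layout; `naive_attn` and `merge_nway` are illustrative stand-ins for the project's helpers:

```python
import torch

def naive_attn(q, k, v, scale):
    """Reference attention with LSE; tensors are [B, H, S, D]."""
    s = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
    return torch.einsum("bhqk,bhkd->bhqd", s.softmax(-1), v), torch.logsumexp(s, -1)

def merge_nway(outputs, lses):
    """Single-pass N-way merge (same math as merge_attention_outputs_nway)."""
    o = torch.stack(outputs)    # [n, B, H, Sq, D]
    lse = torch.stack(lses)     # [n, B, H, Sq]
    max_lse = lse.max(dim=0).values
    w = torch.exp(lse - max_lse)
    sum_exp = w.sum(dim=0)
    o_merged = (o * w.unsqueeze(-1)).sum(dim=0) / sum_exp.unsqueeze(-1)
    return o_merged, max_lse + torch.log(sum_exp)

torch.manual_seed(0)
B, H, S, D, n_chunks = 1, 2, 16, 8, 4
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
scale = D ** -0.5
o_ref, lse_ref = naive_attn(q, k, v, scale)
# Attend to each KV quarter separately, then merge all partials at once
parts = [naive_attn(q, k[:, :, i*4:(i+1)*4], v[:, :, i*4:(i+1)*4], scale)
         for i in range(n_chunks)]
o_merged, lse_merged = merge_nway([p[0] for p in parts], [p[1] for p in parts])
```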
Option B: Kahan-style compensation
Carry a compensation term across merges to reduce floating-point accumulation error:
def merge_attention_outputs_kahan(
    o1: torch.Tensor, lse1: torch.Tensor,
    o2: torch.Tensor, lse2: torch.Tensor,
    compensation: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Merge with a compensation term carried across calls (Kahan-style).
    The compensation stores the rounding error lost when the fp32 result
    is cast back to the storage dtype, and is re-applied on the next merge.

    Returns:
        o_merged, lse_merged, output_comp
    """
    # Standard merge, computed in fp32
    max_lse = torch.maximum(lse1.float(), lse2.float())
    exp1 = torch.exp(lse1.float() - max_lse)
    exp2 = torch.exp(lse2.float() - max_lse)
    sum_exp = exp1 + exp2
    # lse is [batch, nheads, seqlen_q] while o is [batch, seqlen_q, nheads, headdim]:
    # bring the weights into the output layout before broadcasting
    w1 = exp1.transpose(1, 2).unsqueeze(-1)
    w2 = exp2.transpose(1, 2).unsqueeze(-1)
    o1_f32 = o1.float()
    if compensation is not None:
        o1_f32 = o1_f32 + compensation  # re-apply the error lost in the last downcast
    o_merged_f32 = (o1_f32 * w1 + o2.float() * w2) / (w1 + w2)
    lse_merged = (max_lse + torch.log(sum_exp)).to(lse1.dtype)
    o_merged = o_merged_f32.to(o1.dtype)
    # New compensation: what the downcast just rounded away
    output_comp = o_merged_f32 - o_merged.float()
    return o_merged, lse_merged, output_comp
Option C: hierarchical merge (compromise)
Split the 32 chunks into groups, merge within each group, then merge the groups:
Before: c1 -> c2 -> c3 -> ... -> c32 (31 merges)
After:
Group1: c1-c4 -> g1 (3 merges)
Group2: c5-c8 -> g2 (3 merges)
...
Group8: c29-c32 -> g8 (3 merges)
Final: g1-g8 -> result (7 merges)
Total: 3*8 + 7 = 31 merges either way, but the maximum accumulation depth drops from 31 to roughly 10 (3 in-group + 7 in the final stage), or to 7 when the final stage is itself grouped recursively (3 + 3 + 1).
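The depth claim can be checked with a small counter. `max_merge_depth` is a hypothetical helper that assumes each level merges its groups sequentially and then recurses on the group results:

```python
import math

def max_merge_depth(n: int, group_size: int) -> int:
    """Worst-case number of merges any chunk's contribution passes through
    when groups of `group_size` are merged sequentially at every level."""
    depth = 0
    while n > 1:
        depth += min(n, group_size) - 1  # sequential merges inside one group
        n = math.ceil(n / group_size)    # one representative per group
    return depth
```

For 32 chunks: fully sequential merging gives depth 31, recursive groups of 4 give depth 7 (3 + 3 + 1), and pairwise tree merging (group size 2) gives depth 5.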
def merge_attention_outputs_hierarchical(
    outputs: List[torch.Tensor],
    lses: List[torch.Tensor],
    group_size: int = 4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Hierarchical merge: merge within groups first, then merge the groups.
    Reduces the accumulation depth from O(n) to O(log n).
    """
    from nanovllm.ops.chunked_attention import merge_attention_outputs

    n = len(outputs)
    if n == 0:
        raise ValueError("Need at least one output to merge")
    if n == 1:
        return outputs[0], lses[0]
    # Level 1: merge within groups
    group_outputs = []
    group_lses = []
    for i in range(0, n, group_size):
        group = outputs[i:i+group_size]
        group_l = lses[i:i+group_size]
        o_g, lse_g = group[0], group_l[0]
        for j in range(1, len(group)):
            o_g, lse_g = merge_attention_outputs(o_g, lse_g, group[j], group_l[j])
        group_outputs.append(o_g)
        group_lses.append(lse_g)
    # Recursively merge the group results
    return merge_attention_outputs_hierarchical(group_outputs, group_lses, group_size)
Expected improvement
- N-way merge: estimated 10-15% error-rate reduction
- Hierarchical merge: estimated 5-8% error-rate reduction
Problem 3: Too Few Ring Buffer Slots
Description
File: nanovllm/kvcache/offload_engine.py
Location: ring buffer initialization
# 2-slot configuration (original)
num_gpu_blocks = 2  # only 2 GPU blocks for the ring buffer
# Problems:
# - Slot 0: used for decode
# - Slot 1: used for prefill loading
# - 32 chunks share a single prefill slot
# - high-frequency CPU<->GPU transfers can cause race conditions
Analysis
- Race condition: while chunk N's compute is still running, chunk N+1's transfer may overwrite the same slot.
- Synchronization overhead: frequent wait_slot_layer() calls add latency.
- Already verified: the 4-slot configuration improves results noticeably but does not fully fix the problem.
Verification results (completed)
| Config | niah_single_1 accuracy | niah_multikey_3 accuracy |
|---|---|---|
| 2-slot | 94% | 48% |
| 4-slot | 98% | 56% |
| Improvement | +4% | +8% |
Key finding: Sample 40 still produces the same wrong output (6171717161711716) under the 4-slot configuration, indicating that the slot count is not the only cause.
Solutions
Option A: increase the slot count to 8
# offload_engine.py configuration
class OffloadEngineConfig:
    # before
    # num_ring_slots = 2
    # proposed
    num_ring_slots = 8  # increase to 8 slots

    # or compute dynamically
    @property
    def num_ring_slots(self):
        # adjust based on available GPU memory
        available_memory_gb = get_available_gpu_memory()
        slot_memory_gb = self.block_size * self.kv_cache_dtype_size * 2 / 1e9
        max_slots = int(available_memory_gb * 0.3 / slot_memory_gb)  # use 30% of memory
        return min(max(max_slots, 4), 16)  # clamp to 4-16
Option B: improve the synchronization mechanism
def load_to_slot_layer_with_barrier(self, slot: int, layer_id: int, cpu_block_id: int):
    """
    Load with an explicit barrier, ensuring the previous operation finished.
    """
    # Wait for all pending operations on this slot to complete
    if self.slot_events[slot] is not None:
        self.slot_events[slot].synchronize()
    # Perform the transfer
    self._do_transfer(slot, layer_id, cpu_block_id)
    # Record a new event
    self.slot_events[slot] = torch.cuda.Event()
    self.slot_events[slot].record(self.transfer_stream)
Option C: double buffering
class DoubleBufferedRingBuffer:
    """
    Double-buffered design: each logical slot has two physical buffers
    - one for the current compute
    - one for the next transfer
    """
    def __init__(self, num_logical_slots: int, ...):
        self.num_logical_slots = num_logical_slots
        self.num_physical_slots = num_logical_slots * 2
        self.current_buffer = [0] * num_logical_slots  # buffer (0 or 1) each slot currently uses

    def get_compute_slot(self, logical_slot: int) -> int:
        """Physical slot used for compute."""
        return logical_slot * 2 + self.current_buffer[logical_slot]

    def get_transfer_slot(self, logical_slot: int) -> int:
        """Physical slot used for transfer (the other buffer)."""
        return logical_slot * 2 + (1 - self.current_buffer[logical_slot])

    def swap_buffer(self, logical_slot: int):
        """Swap the buffers."""
        self.current_buffer[logical_slot] = 1 - self.current_buffer[logical_slot]
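The indexing scheme can be exercised standalone. This is a minimal self-contained sketch of the class above (engine wiring and extra constructor arguments omitted):

```python
class DoubleBufferedRingBuffer:
    """Minimal sketch: two physical buffers per logical slot."""
    def __init__(self, num_logical_slots: int):
        self.current_buffer = [0] * num_logical_slots

    def get_compute_slot(self, logical_slot: int) -> int:
        return logical_slot * 2 + self.current_buffer[logical_slot]

    def get_transfer_slot(self, logical_slot: int) -> int:
        return logical_slot * 2 + (1 - self.current_buffer[logical_slot])

    def swap_buffer(self, logical_slot: int) -> None:
        self.current_buffer[logical_slot] = 1 - self.current_buffer[logical_slot]

rb = DoubleBufferedRingBuffer(num_logical_slots=4)
# While compute reads physical slot 2, the next chunk is prefetched into slot 3
compute0 = rb.get_compute_slot(1)
transfer0 = rb.get_transfer_slot(1)
rb.swap_buffer(1)  # the prefetched buffer becomes the compute buffer
compute1 = rb.get_compute_slot(1)
```

The invariant is that compute and transfer never target the same physical buffer, so a transfer can never overwrite data a kernel is still reading.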
Expected improvement
- 8-slot configuration: estimated further 3-5% error-rate reduction
- Double buffering: estimated further 2-3% error-rate reduction
Problem 4: Sparse Policy Mapping Bug
Description
File: COMPASS/eval/RULER/scripts/pred/call_api.py (COMPASS project)
Location: metric-to-policy mapping
metric_to_policy = {
    'full': 'FULL',
    'xattn': 'XATTN',  # ❌ XATTN does not exist in SparsePolicyType
    'compass': 'XATTN',
    'minfer': 'MINFERENCE',  # ❌ MINFERENCE does not exist in SparsePolicyType
    'avgpool': 'FULL',
    'flex': 'FLEXPREFILL',
}
sparse_policy_name = metric_to_policy.get(self.metric, 'FULL')
sparse_policy = getattr(SparsePolicyType, sparse_policy_name, SparsePolicyType.FULL)
# ⚠️ silently falls back to FULL, with no warning!
Analysis
- Silent fallback: getattr(..., SparsePolicyType.FULL) makes invalid names fall back silently.
- Misconfigured tests: the user thinks they are testing xattn but is actually testing FULL.
- Unreliable results: every sparse method's test results may in fact be FULL's results.
Verification
# Check the valid values of the current SparsePolicyType
from nanovllm.kvcache.sparse.policy import SparsePolicyType
print([e.name for e in SparsePolicyType])
# Expected output resembles: ['FULL', 'QUEST', 'XATTN_BSA', ...]

# Verify the mapping
for metric, policy_name in metric_to_policy.items():
    has_policy = hasattr(SparsePolicyType, policy_name)
    print(f"{metric} -> {policy_name}: {'✓' if has_policy else '✗'}")
Solution
Fix the mapping and add validation:
# Corrected mapping (update according to the SparsePolicyType definition)
metric_to_policy = {
    'full': 'FULL',
    'xattn': 'XATTN_BSA',  # ✓ use the correct enum name
    'compass': 'XATTN_BSA',
    'minfer': 'MINFERENCE',  # confirm this exists
    'avgpool': 'AVGPOOL',  # confirm this exists
    'flex': 'FLEXPREFILL',  # confirm this exists
    'quest': 'QUEST',  # add quest
}

def get_sparse_policy(metric: str) -> SparsePolicyType:
    """Resolve the sparse policy, with validation."""
    policy_name = metric_to_policy.get(metric)
    if policy_name is None:
        raise ValueError(f"Unknown metric: {metric}. Valid options: {list(metric_to_policy.keys())}")
    if not hasattr(SparsePolicyType, policy_name):
        raise ValueError(f"SparsePolicyType.{policy_name} does not exist. "
                         f"Available: {[e.name for e in SparsePolicyType]}")
    return getattr(SparsePolicyType, policy_name)
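The strict-lookup behavior can be demonstrated in isolation. The enum below is a stand-in for `nanovllm.kvcache.sparse.policy.SparsePolicyType` (its member names here are illustrative only):

```python
from enum import Enum, auto

class SparsePolicyType(Enum):
    """Stand-in for the real SparsePolicyType; members are illustrative."""
    FULL = auto()
    QUEST = auto()
    XATTN_BSA = auto()

metric_to_policy = {"full": "FULL", "xattn": "XATTN_BSA", "quest": "QUEST"}

def get_sparse_policy(metric: str) -> SparsePolicyType:
    """Strict lookup: raise instead of silently falling back to FULL."""
    policy_name = metric_to_policy.get(metric)
    if policy_name is None:
        raise ValueError(f"Unknown metric: {metric}")
    if not hasattr(SparsePolicyType, policy_name):
        raise ValueError(f"SparsePolicyType.{policy_name} does not exist")
    return SparsePolicyType[policy_name]
```

With this shape, an unmapped metric like "minfer" fails loudly at startup instead of silently running the FULL policy and producing misleading benchmark numbers.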
Expected improvement
- This bug does not affect chunked offload accuracy.
- Once fixed, however, the actual effect of each sparse method can be tested correctly.
Problem 5: Decode Buffer Read Boundary Condition
Description
File: nanovllm/kvcache/sparse/full_policy.py
Location: lines 269-272
def compute_chunked_decode(self, ...):
    # ...
    seq_len = len(seq)
    decode_pos_in_block = (seq_len - 1) % block_size
    decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
    decode_start_pos_in_block = decode_start_pos % block_size
    num_accumulated = decode_pos_in_block - decode_start_pos_in_block + 1
    # ⚠️ potential problem: when decode_pos_in_block < decode_start_pos_in_block,
    # num_accumulated becomes negative
Analysis
Scenario: the decode range crosses a block boundary:
- assume block_size = 1024
- decode_start_pos = 1020 (near the block boundary)
- current seq_len = 1026 (spills into the next block)
- decode_pos_in_block = (1026 - 1) % 1024 = 1
- decode_start_pos_in_block = 1020 % 1024 = 1020
- num_accumulated = 1 - 1020 + 1 = -1018 ❌
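The scenario can be reproduced with the in-block arithmetic alone. These are hypothetical helpers mirroring the two expressions, not the project's functions:

```python
BLOCK_SIZE = 1024

def num_accumulated_buggy(seq_len: int, decode_start_pos: int) -> int:
    """In-block arithmetic from compute_chunked_decode: negative across blocks."""
    decode_pos_in_block = (seq_len - 1) % BLOCK_SIZE
    decode_start_pos_in_block = decode_start_pos % BLOCK_SIZE
    return decode_pos_in_block - decode_start_pos_in_block + 1

def num_accumulated_fixed(seq_len: int, decode_start_pos: int) -> int:
    """Absolute positions: always the true token count."""
    return seq_len - decode_start_pos
```

Both expressions agree while the decode range stays within one block; they only diverge once the range crosses a block boundary.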
Solution
def compute_chunked_decode(self, ...):
    # ...
    seq_len = len(seq)
    decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
    # Number of tokens actually held in the decode buffer,
    # using absolute positions rather than in-block offsets
    num_accumulated = seq_len - decode_start_pos
    # Validate
    assert num_accumulated > 0, f"Invalid decode range: seq_len={seq_len}, decode_start={decode_start_pos}"
    # In-block offsets for reading
    decode_start_pos_in_block = decode_start_pos % block_size
    decode_end_pos_in_block = (seq_len - 1) % block_size
    # Handle the cross-block case
    if decode_end_pos_in_block < decode_start_pos_in_block:
        # Cross-block: would need to read from two locations;
        # this case should be handled at the kvcache_manager level
        raise NotImplementedError("Cross-block decode not yet supported")
    # Same block: normal read
    decode_k = offload_engine.decode_k_buffer[
        layer_id,
        decode_start_pos_in_block:decode_end_pos_in_block+1
    ]
    decode_v = offload_engine.decode_v_buffer[
        layer_id,
        decode_start_pos_in_block:decode_end_pos_in_block+1
    ]
Expected improvement
- Fixing this boundary condition should eliminate crashes or wrong output in the affected cases.
Problem 6: Merge Kernel Block Size Tuning
Description
File: nanovllm/ops/chunked_attention.py
Location: line 406
def merge_attention_outputs(...):
    # ...
    # Launch output merge kernel
    BLOCK_SIZE = 128  # fixed value
    grid_output = (batch, seqlen_q, nheads)
    _merge_output_kernel[grid_output](
        o1, o2, lse1, lse2, o_merged,
        batch, seqlen_q, nheads, headdim,
        BLOCK_SIZE=BLOCK_SIZE,
    )
Analysis
- Fixed BLOCK_SIZE: 128 may not be optimal for every headdim.
- Performance impact: a poor BLOCK_SIZE can lower GPU occupancy.
- Accuracy impact: some BLOCK_SIZE values can cause boundary-handling issues.
Solution
def merge_attention_outputs(...):
    batch, seqlen_q, nheads, headdim = o1.shape
    # Choose BLOCK_SIZE dynamically
    # Principle: BLOCK_SIZE should divide headdim and suit the GPU warp size (32)
    if headdim <= 64:
        BLOCK_SIZE = 64
    elif headdim <= 128:
        BLOCK_SIZE = 128
    else:
        BLOCK_SIZE = 256
    # Ensure BLOCK_SIZE divides headdim
    while headdim % BLOCK_SIZE != 0 and BLOCK_SIZE > 32:
        BLOCK_SIZE //= 2
    # ...
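The selection logic can be factored into a pure function and unit-tested on its own. `choose_block_size` is a hypothetical helper mirroring the branches above:

```python
def choose_block_size(headdim: int) -> int:
    """Pick a BLOCK_SIZE that divides headdim, preferring larger tiles."""
    if headdim <= 64:
        block_size = 64
    elif headdim <= 128:
        block_size = 128
    else:
        block_size = 256
    # Shrink until it divides headdim (floor of 32 = one warp)
    while headdim % block_size != 0 and block_size > 32:
        block_size //= 2
    return block_size
```

Note that for a headdim like 96 the loop bottoms out at 32, so non-power-of-two head dimensions always fall back to warp-sized tiles.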
Expected improvement
- This is a performance optimization with little impact on accuracy.
Combined Solution Priorities
| Priority | Problem | Solution | Expected improvement | Difficulty |
|---|---|---|---|---|
| P0 | N-way merge | Implement merge_attention_outputs_nway | 10-15% | Medium |
| P0 | LSE precision | Use FlashInfer or validate the flash_attn LSE | 5-10% | Low |
| P1 | Ring buffer slots | Increase to 8 slots + double buffering | 3-5% | Medium |
| P1 | Policy mapping | Fix the mapping bug in COMPASS | N/A | Low |
| P2 | Decode boundary | Add boundary checks and handling | 1-2% | Low |
| P2 | Merge block size | Choose BLOCK_SIZE dynamically | <1% | Low |
Verification Plan
Phase 1: error tracing (run first)
Before implementing any change, add detailed error tracing:
# add in full_policy.py
def compute_chunked_prefill_with_debug(self, ...):
    debug_info = {
        'chunk_max_diffs': [],
        'lse_values': [],
        'merge_count': 0,
    }
    # ... existing logic ...
    for block_idx in range(num_blocks):
        # ... compute ...
        if o_acc is not None:
            # record the difference before and after the merge
            pre_merge_o = o_acc.clone()
            o_acc, lse_acc = merge_attention_outputs(...)
            debug_info['chunk_max_diffs'].append((o_acc - pre_merge_o).abs().max().item())
            debug_info['lse_values'].append(lse_acc.mean().item())
            debug_info['merge_count'] += 1
    # save debug info
    torch.save(debug_info, f'debug_prefill_layer{layer_id}.pt')
Phase 2: single-change tests
Apply one improvement at a time and measure its effect:
- N-way merge only -> measure accuracy
- extra ring buffer slots only -> measure accuracy
- FlashInfer LSE only -> measure accuracy
Phase 3: combined test
Combine the best-performing improvements and check whether the gains stack:
Final configuration = N-way merge + 8 slots + FlashInfer LSE
Target: error rate < 10% (close to the 8% xattn_stride8 baseline)
Appendix: Related File Index
| File | Key functions/classes | Problem # |
|---|---|---|
| nanovllm/ops/chunked_attention.py | flash_attn_with_lse, merge_attention_outputs | 1, 2, 6 |
| nanovllm/kvcache/sparse/full_policy.py | compute_chunked_prefill, compute_chunked_decode | 2, 5 |
| nanovllm/kvcache/offload_engine.py | Ring buffer management | 3 |
| COMPASS/eval/RULER/scripts/pred/call_api.py | metric_to_policy | 4 |
Appendix: Error Sample Details
The following samples failed in the RULER 32K test; they can be used to verify the effect of later fixes.
Error statistics
| Task | Total Samples | Errors | Error Rate |
|---|---|---|---|
| niah_single_1 | 100 | 19 | 19% |
| niah_single_2 | 100 | 23 | 23% |
| niah_single_3 | 100 | 8 | 8% |
| niah_multikey_1 | 100 | 16 | 16% |
| niah_multikey_2 | 100 | 30 | 30% |
| niah_multikey_3 | 100 | 24 | 24% |
| TOTAL | 600 | 120 | 20% |
niah_single_1 (19 errors)
Error sample IDs: 28, 33, 39, 40, 41, 43, 44, 49, 51, 52, 53, 57, 61, 63, 65, 67, 72, 77, 83
| Index | Expected | Actual Output |
|---|---|---|
| 28 | 9874152 | :151:52<|eot_id|> |
| 33 | 9196204 | :<|eot_id|> |
| 39 | 3484601 | :<|eot_id|> |
| 40 | 6171716 | : 17: 16<|eot_id|> |
| 41 | 4524499 | :<|eot_id|> |
| 43 | 3726327 | : 16: 7<|eot_id|> |
| 44 | 4009172 | : 2<|eot_id|> |
| 49 | 4240180 | :354:180<|eot_id|> |
| 51 | 9546409 | :<|eot_id|> |
| 52 | 2935113 | : 29351113.<|eot_id|> |
| 53 | 5453786 | :354:678:90<|eot_id|> |
| 57 | 8315831 | : 5831<|eot_id|> |
| 61 | 5960271 | : 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,...<|eot_id|> |
| 63 | 6049101 | : 5 0 4 9 1 0 1<|eot_id|> |
| 65 | 6406444 | :361361361361361361361361361361361361361361361361361361361361361361361361361361...<|eot_id|> |
| 67 | 2422633 | :31<|eot_id|> |
| 72 | 7442089 | 7953166<|eot_id|> |
| 77 | 8795419 | :<|eot_id|> |
| 83 | 6363836 | : 2<|eot_id|> |
niah_single_2 (23 errors)
Error sample IDs: 16, 24, 30, 32, 40, 41, 42, 50, 51, 52, 55, 58, 60, 62, 64, 66, 67, 68, 69, 77, 85, 91, 93
| Index | Expected | Actual Output |
|---|---|---|
| 16 | 2344047 | : 23440447.<|eot_id|> |
| 24 | 5449324 | :<|eot_id|> |
| 30 | 5727085 | :<|eot_id|> |
| 32 | 9196204 | :<|eot_id|> |
| 40 | 4524499 | :460<|eot_id|> |
| 41 | 7817881 | :171.<|eot_id|> |
| 42 | 3726327 | :<|eot_id|> |
| 50 | 9546409 | :<|eot_id|> |
| 51 | 2935113 | : 3: 5113<|eot_id|> |
| 52 | 5453786 | :354<|eot_id|> |
| 55 | 4188992 | : 418899189418899, but it is not explicitly stated... |
| 58 | 6266630 | :5963<|eot_id|> |
| 60 | 5960271 | 0271<|eot_id|> |
| 62 | 6049101 | :<|eot_id|> |
| 64 | 6406444 | :<|eot_id|> |
| 66 | 2422633 | :5313<|eot_id|> |
| 67 | 4940441 | :5311<|eot_id|> |
| 68 | 3472189 | :361.<|eot_id|> |
| 69 | 8971465 | :361.<|eot_id|> |
| 77 | 8963715 | : 0 8 9 7 1 5<|eot_id|> |
| 85 | 2044645 | : 20446445.<|eot_id|> |
| 91 | 7783308 | :<|eot_id|> |
| 93 | 1454696 | :<|eot_id|> |
niah_single_3 (8 errors)
Error sample IDs: 7, 9, 14, 24, 25, 29, 31, 43
| Index | Expected | Actual Output |
|---|---|---|
| 7 | ee87905e-4ca4-45ea-8dfa-6a56d12dbc9a | : 2010-07-01T00:00:00Z<|eot_id|> |
| 9 | b7b56ea7-35eb-432d-9ad6-20ab48212ddb | :0:0:0:0:0:0:0:0:0:0:0:0:0:0:0:0<|eot_id|> |
| 14 | e767dcea-b0e6-4969-a213-42b0f1eedba3 | :0e6-4969-a213-42b0f1eedba3<|eot_id|> |
| 24 | 59e4b671-4774-4c58-85f8-bc16f7860b50 | :4774:4c58:85f8:bc16f7860b50<|eot_id|> |
| 25 | 54c63cd8-8945-4f27-97fa-2d8dfb2ca025 | : 54c63c63cd8-8945-4f27-97fa-2d8dfb2ca025.<|eot_id|> |
| 29 | 006ed6e3-6fa1-4735-b572-f3d00b5cea6a | :6e3-6fa1-4735-b572-f3d00b5cea6a<|eot_id|> |
| 31 | e6697833-b841-40a0-9fe7-71d6d9178793 | : e6697837837833-b841-40a0-9fe7-71d6d9178793.<|eot_id|> |
| 43 | d92c9227-eadf-4085-bfcb-75468eb22579 | : d92c922c9227-eadf-4085-bfcb-75468eb22579.<|eot_id|> |
niah_multikey_1 (16 errors)
Error sample IDs: 20, 31, 32, 40, 41, 45, 51, 54, 59, 63, 64, 65, 67, 69, 71, 74
| Index | Expected | Actual Output |
|---|---|---|
| 20 | 2171218 | : 2171212181212181212181218<|eot_id|> |
| 31 | 9333700 | :<|eot_id|> |
| 32 | 7121355 | :9651<|eot_id|> |
| 40 | 3112652 | :285<|eot_id|> |
| 41 | 3427461 | :<|eot_id|> |
| 45 | 8217547 | :<|eot_id|> |
| 51 | 1514340 | : 1514343403361.<|eot_id|> |
| 54 | 8212753 | :<|eot_id|> |
| 59 | 6587964 | :<|eot_id|> |
| 63 | 1688246 | :<|eot_id|> |
| 64 | 8344365 | : 834436, but it is not explicitly mentioned.<|eot_id|> |
| 65 | 6614484 | : 4367.<|eot_id|> |
| 67 | 6510922 | :7780<|eot_id|> |
| 69 | 6649968 | : 43610.<|eot_id|> |
| 71 | 9437374 | :<|eot_id|> |
| 74 | 6625238 | :1472908<|eot_id|> |
niah_multikey_2 (30 errors)
Error sample IDs: 2, 13, 21, 22, 23, 24, 25, 28, 32, 34, 38, 39, 40, 41, 42, 43, 45, 46, 47, 49, 50, 53, 54, 56, 57, 59, 60, 63, 64, 65
| Index | Expected | Actual Output |
|---|---|---|
| 2 | 1535573 | : 8651665.<|eot_id|> |
| 13 | 2794159 | : 5261593<|eot_id|> |
| 21 | 8970232 | :168<|eot_id|> |
| 22 | 9134051 | : 381:055: 381:055: 381:055: 381:055: 381:055: 381:055: 381:055: 38... |
| 23 | 9696620 | : 969662620969662, which is: 969662920, 96966220 is not actually me... |
| 24 | 7071187 | 055055055.<|eot_id|> |
| 25 | 5572782 | : 5342494<|eot_id|> |
| 28 | 4953027 | :1687719<|eot_id|> |
| 32 | 4259234 | : 425923521250, but not found is: 425923751572250, however is: 4259... |
| 34 | 3643022 | : 3957500<|eot_id|> |
| 38 | 2031469 | : the text.<|eot_id|> |
| 39 | 8740362 | : 8740364 8740364 8740364 8740364 is: is: is: is: 874036... |
| 40 | 7041770 | :1682<|eot_id|> |
| 41 | 1986258 | :086.<|eot_id|> |
| 42 | 5668574 | :055.<|eot_id|> |
| 43 | 8560471 | :067<|eot_id|> |
| 45 | 9973767 | : 8420273<|eot_id|> |
| 46 | 3960211 | :0<|eot_id|> |
| 47 | 8003271 | : 60870870870870870870870870870870870870870870870870870870870870870... |
| 49 | 8632309 | 303640 is640 640 640 640 640 640 640 640 640 640 640 640 640 640 640... |
| 50 | 2318630 | : 7780552.<|eot_id|> |
| 53 | 3405052 | :<|eot_id|> |
| 54 | 5364945 | : 536494, which is: 536494, which is: 536494494494494494494494494... |
| 56 | 7319214 | :7607607607607607607607607607607607607607607607607607607607607607607... |
| 57 | 9206104 | :7607607607607607607607607607607607607607607607607607607607607607607... |
| 59 | 9555385 | :7095<|eot_id|> |
| 60 | 5727554 | : 572755755755755755755755755755755755755755755755755755755755 is: 572... |
| 63 | 1090767 | :7607607607607607607607607607607607607607607607607607607607607607607... |
| 64 | 6791240 | :<|eot_id|> |
| 65 | 7275999 | :7607607607607607607607607607607607607607607607607607607607607607607... |
niah_multikey_3 (24 errors)
Error sample IDs: 11, 18, 20, 23, 24, 25, 26, 27, 29, 30, 33, 35, 37, 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 52
| Index | Expected | Actual Output |
|---|---|---|
| 11 | c73ed342-6523-4d4b-aa33-beb1c9007315 | : 1d28b88b-b6a8-46ba-8e8f-56cbafbfd897.<|eot_id|> |
| 18 | 87b8a762-1d1f-4e85-a5d1-caf284c95aa6 | : 429a6676-5295-4ea2-a694-6aa949f48e31.<|eot_id|> |
| 20 | cce29702-134a-460c-979b-6f7ee7895280 | :<|eot_id|> |
| 23 | ed344bfe-983f-4a21-af44-722e2517244c | : aec431e7d880a8dce2c023de24 is: aec43163-061a-4afe-b80a-f5bfb5e3c9... |
| 24 | 4712ef99-a8d1-4388-8ca7-b08dd3505d77 | :<|eot_id|> |
| 25 | 46969ce7-0da0-49f8-87b2-845e7b8ef100 | :<|eot_id|> |
| 26 | 7cff3c66-6860-49e6-8ba5-002162c250c0 | :4c7e-946b-30812edf965e<|eot_id|> |
| 27 | b63b4988-40bc-44b2-bf1c-ca95adbca4e9 | :<|eot_id|> |
| 29 | 6d94011c-f28a-4b0b-a2e2-fe34bb8b19a1 | : 6d6d6d6d4b0e-52ce-44d9-a0f6-1ae405825615<|eot_id|> |
| 30 | 7c33bb00-4ab4-4e4f-a78e-39f8f06d63eb | d7a2-4b23-a2c0-8c859cb1fa96<|eot_id|> |
| 33 | b7c6b586-713a-4907-ad24-5c4f25aeb769 | :1-4d2c-b42b-933ded2633d6<|eot_id|> |
| 35 | ac8a317b-a6bb-4327-90db-2a01622cb723 | : d2f2f2f2f2f2f2f2d2d2f2d2d2d3d2f6b3d2f- is: d2dab is: is: is: i... |
| 37 | b187b337-3132-4376-a500-9340102092ae | :<|eot_id|> |
| 40 | 2559fa56-dd0a-48d4-ba82-3ae2bf0a4b33 | :358fe0e3-724e-4cfc-9ae0-d0873162626b.<|eot_id|> |
| 41 | 7842feb5-e758-44cd-b73b-8ae08aa33142 | : 6c6adf83-36a9-4e41-9cbe-60a8c9ffba92.<|eot_id|> |
| 42 | a1196139-f6fa-4c18-b3da-b7bd50362ac7 | : a1196131396131196131399a1196139a1196139a1196139a1196139f6a1196139... |
| 44 | 7d3d40b2-4594-4573-b267-4c6270dd4425 | : 613a9e-4e7d-8c9f-740a630e3c53<|eot_id|> |
| 45 | 500b8a75-8f05-43f5-b9ad-46d47d4e33fc | : 500b8a5e0e0e0a500b is: 500b is: 500b-4 is: is: is: is: is: i... |
| 46 | 86a867a7-6a98-4a02-b065-70a33bafafde | :6139a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a... |
| 47 | 7c0f7fd2-237e-4c0f-b3f5-f43623551169 | 5fb71d2f0f0b4f0 is: 5fb71 is: 5fb71f-4f-4f-4f-4f-4f-4d7 is: is: ... |
| 48 | b0e1f3f5-6570-437e-b8a1-f1b3f654e257 | : 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b ... |
| 49 | 0153722a-70a8-4ec0-9f03-2b0930937e60 | : 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b ... |
| 50 | 0a1ead51-0c39-4eeb-ac87-d146acdb1d4a | : 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b ... |
| 52 | ff686e85-3a9f-4635-95dd-f19e8ca68eb1 | ff686e686e686e686e686e686f686e6f686e6fb686f686f686f686f686f- is: f... |
Error Pattern Analysis
| Error type | Example | Likely cause |
|---|---|---|
| Empty output | :<|eot_id|> | KV cache fully corrupted; the model cannot generate valid content |
| Repeated digits | :361361361... | A merge error at a chunk boundary sends decode into a loop |
| Partially correct | 9874152 → :151:52 | The first chunks are correct, later chunks are corrupted |
| Inserted digit | 2344047 → 23440447 | Numerical precision loss during merging |
| Completely wrong | outputs an unrelated UUID | Severe KV cache or position encoding error |
Verification Script
Script for verifying the effect of fixes:
# verify_samples.py
import json

ERROR_SAMPLES = {
    'niah_single_1': [28, 33, 39, 40, 41, 43, 44, 49, 51, 52, 53, 57, 61, 63, 65, 67, 72, 77, 83],
    'niah_single_2': [16, 24, 30, 32, 40, 41, 42, 50, 51, 52, 55, 58, 60, 62, 64, 66, 67, 68, 69, 77, 85, 91, 93],
    'niah_single_3': [7, 9, 14, 24, 25, 29, 31, 43],
    'niah_multikey_1': [20, 31, 32, 40, 41, 45, 51, 54, 59, 63, 64, 65, 67, 69, 71, 74],
    'niah_multikey_2': [2, 13, 21, 22, 23, 24, 25, 28, 32, 34, 38, 39, 40, 41, 42, 43, 45, 46, 47, 49, 50, 53, 54, 56, 57, 59, 60, 63, 64, 65],
    'niah_multikey_3': [11, 18, 20, 23, 24, 25, 26, 27, 29, 30, 33, 35, 37, 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 52],
}

def verify_fix(result_dir: str) -> dict:
    """Verify the effect of a fix against the known error samples."""
    fixed = {}
    for task, error_ids in ERROR_SAMPLES.items():
        pred_file = f"{result_dir}/{task}.jsonl"
        with open(pred_file) as f:
            predictions = [json.loads(line) for line in f]
        fixed_count = 0
        still_error = []
        for idx in error_ids:
            pred = predictions[idx]
            expected = pred['answer']
            actual = pred['model_output']
            if expected in actual:
                fixed_count += 1
            else:
                still_error.append(idx)
        fixed[task] = {
            'total_errors': len(error_ids),
            'fixed': fixed_count,
            'still_error': still_error,
            'fix_rate': f"{fixed_count/len(error_ids)*100:.1f}%"
        }
    return fixed