# Chunked Attention Accuracy Issues: Solutions

**Status**: 🟡 IN PROGRESS
**Related Issue**: [`ruler_32k_chunked_offload_issue.md`](ruler_32k_chunked_offload_issue.md)
**Created**: 2026-01-20
**Author**: Zijie Tian

---

## Overview

Based on an in-depth review of the chunked attention code, this document describes the code issues likely responsible for the 20% error rate observed in the RULER 32K tests, along with a proposed fix for each.

**Core problem**: When processing 32K-long sequences, the chunked prefill mechanism degrades output accuracy through the cumulative effect of several factors.

---

## Problem 1: How `flash_attn_with_lse` Obtains the LSE

### Description

**File**: `nanovllm/ops/chunked_attention.py`
**Location**: lines 264-269

```python
def flash_attn_with_lse(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    from flash_attn.flash_attn_interface import flash_attn_func
    # ...
    # Use return_attn_probs=True to obtain the LSE
    out, lse, _ = flash_attn_func(
        q, k, v,
        softmax_scale=softmax_scale,
        causal=causal,
        return_attn_probs=True,  # ⚠️ this API is not designed for LSE output
    )
    # lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
    seqlen_q = q.shape[1]
    lse = lse[:, :, :seqlen_q]  # trim the padding introduced by rounding
    return out, lse
```

### Analysis

1. **API purpose mismatch**: `return_attn_probs=True` is primarily intended to return attention probabilities for debugging, not to provide a high-precision LSE output.
2. **Potential precision loss**: Flash Attention may apply internal optimizations to the LSE that trade away precision.
3. **`S_dmask` return value**: in some versions the third return value `S_dmask` may contain a random dropout mask, which can affect results.

### Verification

```python
def verify_lse_accuracy():
    """Check whether the LSE returned by flash_attn is accurate."""
    import torch
    from flash_attn.flash_attn_interface import flash_attn_func

    # Small-scale data so the reference can be computed exactly
    batch, seqlen, nheads, headdim = 1, 128, 8, 64
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)

    # Flash Attention LSE
    _, lse_flash, _ = flash_attn_func(q, k, v, return_attn_probs=True)

    # Manually computed LSE (PyTorch reference)
    scale = 1.0 / (headdim ** 0.5)
    qk = torch.einsum('bqhd,bkhd->bhqk', q, k) * scale
    lse_manual = torch.logsumexp(qk, dim=-1)  # [batch, heads, seqlen_q]

    # Compare
    diff = (lse_flash[:, :, :seqlen] - lse_manual).abs()
    print(f"LSE max diff: {diff.max().item():.6f}")
    print(f"LSE mean diff: {diff.mean().item():.6f}")
    return diff.max().item() < 1e-3
```

### Solutions

#### Option A: Use FlashInfer's exact LSE interface (recommended)

FlashInfer provides an interface designed for chunked attention:

```python
def flash_attn_with_lse_flashinfer(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Obtain an exact LSE via FlashInfer's interface."""
    import flashinfer

    batch, seqlen_q, nheads_q, headdim = q.shape
    _, seqlen_k, nheads_kv, _ = k.shape
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)

    # FlashInfer's single_prefill_with_kv_cache_return_lse
    # is designed to return an exact LSE
    out, lse = flashinfer.single_prefill_with_kv_cache_return_lse(
        q.view(batch * seqlen_q, nheads_q, headdim),
        k.view(batch * seqlen_k, nheads_kv, headdim),
        v.view(batch * seqlen_k, nheads_kv, headdim),
        causal=causal,
        sm_scale=softmax_scale,
    )

    # Reshape outputs
    out = out.view(batch, seqlen_q, nheads_q, headdim)
    lse = lse.view(batch, seqlen_q, nheads_q).transpose(1, 2)  # [batch, nheads, seqlen_q]
    return out, lse
```
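Whichever backend supplies it, the quantity the merge step consumes is simply the per-query log-sum-exp of the scaled attention scores. A minimal pure-Python sketch of that definition (toy scores, no torch/flash-attn dependency — an illustration, not any backend's actual code):

```python
import math

def lse(scores):
    """Numerically stable log-sum-exp of one query row's attention scores."""
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))

# Hypothetical scaled scores for a single query position.
scores = [0.5, -1.2, 2.0, 0.3]

# exp(lse) is the softmax denominator; each weight is exp(s - lse).
l = lse(scores)
weights = [math.exp(s - l) for s in scores]
assert abs(sum(weights) - 1.0) < 1e-12
assert abs(l - math.log(sum(math.exp(s) for s in scores))) < 1e-12
```

Any candidate LSE implementation can be cross-checked against this definition on small inputs, exactly as `verify_lse_accuracy()` above does with `torch.logsumexp`.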
#### Option B: Custom Triton kernel

If FlashInfer is unavailable, use a custom Triton kernel:

```python
# Already implemented in chunked_attention.py: _fwd_kernel_with_lse
# This kernel computes and writes out the LSE directly in Triton.
# Pros: full control over precision
# Cons: may be slower than Flash Attention
```

#### Option C: Hybrid approach

```python
def flash_attn_with_lse_hybrid(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    use_triton_kernel: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pick the best implementation per case:
    - short sequences (< 4K): Flash Attention (fast)
    - long sequences (>= 4K): Triton kernel (exact)
    """
    seqlen = q.shape[1]
    if seqlen < 4096 and not use_triton_kernel:
        return flash_attn_with_lse(q, k, v, softmax_scale, causal)
    else:
        return triton_flash_attn_with_lse(q, k, v, softmax_scale, causal)
```

### Expected improvement

- Fixing the LSE precision issue is expected to lower the error rate by 5-10%.

---

## Problem 2: Error Accumulation from Too Many Sequential Merges

### Description

**File**: `nanovllm/kvcache/sparse/full_policy.py`
**Location**: lines 136-137 and 163-164

```python
# merge loop in compute_chunked_prefill
for block_idx in range(num_blocks):
    # ... load block ...
    prev_o, prev_lse = flash_attn_with_lse(
        q_batched, prev_k, prev_v,
        softmax_scale=softmax_scale,
        causal=False,
    )
    if o_acc is None:
        o_acc, lse_acc = prev_o, prev_lse
    else:
        # ⚠️ one merge per block
        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
```

### Analysis

A 32K context with block size 1024 requires ~32 merges. Each merge evaluates:

```
max_lse = max(lse1, lse2)
exp1 = exp(lse1 - max_lse)
exp2 = exp(lse2 - max_lse)
o_merged = (o1 * exp1 + o2 * exp2) / (exp1 + exp2)
lse_merged = max_lse + log(exp1 + exp2)
```

**Error sources**:

1. `exp()` and `log()` lose precision in boundary cases.
2. The division `/ (exp1 + exp2)` can introduce rounding error.
3. After 32 accumulations, the error can become significant.
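In exact arithmetic the pairwise merge above is lossless; the accumulation problem is purely a floating-point effect. A pure-Python sketch checking the identity for two chunks (hypothetical scalar toy values; value "vectors" are 1-D for brevity):

```python
import math

def attn(scores, values):
    """Attention output and LSE for one query over one chunk (1-D toy values)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    out = sum(e * v for e, v in zip(exps, values)) / denom
    return out, m + math.log(denom)

def merge(o1, lse1, o2, lse2):
    """The pairwise merge formula from the text."""
    m = max(lse1, lse2)
    e1, e2 = math.exp(lse1 - m), math.exp(lse2 - m)
    return (o1 * e1 + o2 * e2) / (e1 + e2), m + math.log(e1 + e2)

scores = [0.1, 1.5, -0.7, 2.2]
values = [1.0, -2.0, 0.5, 3.0]

o_full, lse_full = attn(scores, values)
o_a, lse_a = attn(scores[:2], values[:2])
o_b, lse_b = attn(scores[2:], values[2:])
o_m, lse_m = merge(o_a, lse_a, o_b, lse_b)

assert abs(o_m - o_full) < 1e-12
assert abs(lse_m - lse_full) < 1e-12
```

In float64 the chunked result matches the full computation to machine precision; the 20% error story only emerges once the same recurrence runs 32 times in bf16/fp16.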
### Mathematical Analysis

Assume each merge introduces a relative error of roughly ε:

- fp32: ε ≈ 10^-6, so 32 accumulations give a worst-case total of ≈ 32 × ε = 3.2 × 10^-5.
- bfloat16 (8 significand bits, 7 stored): ε ≈ 2^-8 ≈ 4 × 10^-3, so 32 accumulations can reach ≈ 0.13 in the worst case, i.e. a double-digit-percent relative error.

### Verification

```python
def measure_merge_error_accumulation():
    """Measure how the merge error grows with the number of chunks."""
    import torch
    from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
    from flash_attn.flash_attn_interface import flash_attn_func

    torch.manual_seed(42)
    batch, seqlen, nheads, headdim = 1, 32768, 32, 128
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)

    # Reference: full attention
    out_ref = flash_attn_func(q, k, v, causal=False)

    results = []
    for num_chunks in [4, 8, 16, 32, 64]:
        chunk_size = seqlen // num_chunks
        o_acc, lse_acc = None, None
        for i in range(num_chunks):
            start = i * chunk_size
            end = (i + 1) * chunk_size
            k_chunk = k[:, start:end, :, :]
            v_chunk = v[:, start:end, :, :]
            chunk_o, chunk_lse = flash_attn_with_lse(q, k_chunk, v_chunk, causal=False)
            if o_acc is None:
                o_acc, lse_acc = chunk_o, chunk_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, chunk_o, chunk_lse)

        diff = (out_ref - o_acc).abs()
        results.append({
            'num_chunks': num_chunks,
            'max_diff': diff.max().item(),
            'mean_diff': diff.mean().item(),
        })
        print(f"chunks={num_chunks:3d}, max_diff={diff.max().item():.6f}, mean_diff={diff.mean().item():.8f}")
    return results
```

### Solutions

#### Option A: N-way merge (recommended)

Instead of a chain of pairwise merges, merge all chunks in a single step:

```python
def merge_attention_outputs_nway(
    outputs: List[torch.Tensor],  # list of [batch, seqlen_q, nheads, headdim]
    lses: List[torch.Tensor],     # list of [batch, nheads, seqlen_q]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    N-way merge: merge all chunks at once to reduce accumulated error.

    Math:
    - max_lse = max(lse_1, lse_2, ..., lse_n)
    - sum_exp = sum(exp(lse_i - max_lse) for i in 1..n)
    - o_merged = sum(o_i * exp(lse_i - max_lse) for i in 1..n) / sum_exp
    - lse_merged = max_lse + log(sum_exp)
    """
    n = len(outputs)
    if n == 0:
        raise ValueError("Need at least one output to merge")
    if n == 1:
        return outputs[0], lses[0]

    # Stack for batch processing
    o_stack = torch.stack(outputs, dim=0)  # [n, batch, seqlen, heads, dim]
    lse_stack = torch.stack(lses, dim=0)   # [n, batch, heads, seqlen]

    # Max LSE across all chunks (in fp32 for precision)
    lse_fp32 = lse_stack.float()
    max_lse = lse_fp32.max(dim=0).values   # [batch, heads, seqlen]

    # exp(lse_i - max_lse) per chunk
    exp_weights = torch.exp(lse_fp32 - max_lse.unsqueeze(0))  # [n, batch, heads, seqlen]

    # Normalize the weights
    sum_exp = exp_weights.sum(dim=0)                          # [batch, heads, seqlen]
    normalized_weights = exp_weights / sum_exp.unsqueeze(0)   # [n, batch, heads, seqlen]

    # Weighted sum of the outputs
    # outputs: [n, batch, seqlen, heads, dim]
    # weights: [n, batch, heads, seqlen] -> [n, batch, seqlen, heads, 1]
    weights_expanded = normalized_weights.permute(0, 1, 3, 2).unsqueeze(-1)
    o_merged = (o_stack.float() * weights_expanded).sum(dim=0).to(outputs[0].dtype)

    # Merged LSE
    lse_merged = (max_lse + torch.log(sum_exp)).to(lses[0].dtype)
    return o_merged, lse_merged
```

**Usage**:

```python
# Before: merge one chunk at a time
for block_idx in range(num_blocks):
    chunk_o, chunk_lse = flash_attn_with_lse(...)
    if o_acc is None:
        o_acc, lse_acc = chunk_o, chunk_lse
    else:
        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, chunk_o, chunk_lse)

# After: collect all chunks, then merge once
chunk_outputs = []
chunk_lses = []
for block_idx in range(num_blocks):
    chunk_o, chunk_lse = flash_attn_with_lse(...)
    chunk_outputs.append(chunk_o)
    chunk_lses.append(chunk_lse)
o_acc, lse_acc = merge_attention_outputs_nway(chunk_outputs, chunk_lses)
```

#### Option B: Kahan summation

Use compensated (Kahan-style) summation to reduce floating-point accumulation error:

```python
def merge_attention_outputs_kahan(
    o1: torch.Tensor,
    lse1: torch.Tensor,
    o2: torch.Tensor,
    lse2: torch.Tensor,
    compensation: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Merge with a running compensation term that carries the rounding error
    lost when results are cast back down to the storage dtype.

    Returns:
        o_merged, lse_merged, (output_comp, lse_comp)
    """
    # Standard merge, computed in fp32
    max_lse = torch.maximum(lse1.float(), lse2.float())
    exp1 = torch.exp(lse1.float() - max_lse)
    exp2 = torch.exp(lse2.float() - max_lse)
    sum_exp = exp1 + exp2
    # lse tensors are [batch, heads, seqlen]; align weights with o [batch, seqlen, heads, dim]
    w1 = exp1.transpose(1, 2).unsqueeze(-1)
    w2 = exp2.transpose(1, 2).unsqueeze(-1)
    o_merged_raw = (o1.float() * w1 + o2.float() * w2) / (w1 + w2)
    lse_merged_raw = max_lse + torch.log(sum_exp)

    # Fold in the error carried over from the previous merge
    if compensation is not None:
        output_comp, lse_comp = compensation
        o_merged_raw = o_merged_raw + output_comp
        lse_merged_raw = lse_merged_raw + lse_comp

    o_out = o_merged_raw.to(o1.dtype)
    lse_out = lse_merged_raw.to(lse1.dtype)

    # New compensation: exactly what the downcast discarded
    output_comp_new = o_merged_raw - o_out.float()
    lse_comp_new = lse_merged_raw - lse_out.float()
    return o_out, lse_out, (output_comp_new, lse_comp_new)
```

#### Option C: Hierarchical merge (compromise)

Group the 32 chunks, merge within each group first, then merge across groups:

```
Before: c1 -> c2 -> c3 -> ... -> c32   (31 merges, depth 31)
After:  Group1: c1-c4  -> g1           (3 merges)
        Group2: c5-c8  -> g2           (3 merges)
        ...
        Group8: c29-c32 -> g8          (3 merges)
        Final:  g1-g8   -> result      (7 merges)

Total: 3*8 + 7 = 31 merges, but with recursive grouping the maximum
accumulation depth drops from 31 to about 7.
```

```python
def merge_attention_outputs_hierarchical(
    outputs: List[torch.Tensor],
    lses: List[torch.Tensor],
    group_size: int = 4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Hierarchical merge: merge within groups first, then across groups.
    Reduces the accumulation depth from O(n) to O(log n).
    """
    from nanovllm.ops.chunked_attention import merge_attention_outputs

    n = len(outputs)
    if n == 0:
        raise ValueError("Need at least one output to merge")
    if n == 1:
        return outputs[0], lses[0]

    # Level 1: merge within groups
    group_outputs = []
    group_lses = []
    for i in range(0, n, group_size):
        group = outputs[i:i+group_size]
        group_l = lses[i:i+group_size]
        o_g, lse_g = group[0], group_l[0]
        for j in range(1, len(group)):
            o_g, lse_g = merge_attention_outputs(o_g, lse_g, group[j], group_l[j])
        group_outputs.append(o_g)
        group_lses.append(lse_g)

    # Recursively merge the groups (terminates at the n == 1 base case)
    return merge_attention_outputs_hierarchical(group_outputs, group_lses, group_size)
```

### Expected improvement

- N-way merge: expected 10-15% reduction in error rate
- Hierarchical merge: expected 5-8% reduction in error rate

---

## Problem 3: Too Few Ring Buffer Slots

### Description

**File**: `nanovllm/kvcache/offload_engine.py`
**Location**: ring buffer initialization

```python
# 2-slot configuration (original)
num_gpu_blocks = 2  # only 2 GPU blocks for the ring buffer

# Problems:
# - Slot 0: used for decode
# - Slot 1: used for prefill loads
# - 32 chunks share a single prefill slot
# - High-frequency CPU<->GPU transfers can race
```

### Analysis

1. **Race condition**: while chunk N's compute is still running, chunk N+1's transfer may overwrite the same slot.
2. **Synchronization overhead**: frequent `wait_slot_layer()` calls add latency.
3. **Verified**: the 4-slot configuration improves results significantly but does not fully solve the problem.

### Verification Results (completed)

| Configuration | niah_single_1 accuracy | niah_multikey_3 accuracy |
|------|---------------------|----------------------|
| 2-slot | 94% | 48% |
| 4-slot | 98% | 56% |
| Improvement | +4% | +8% |

**Key finding**: Sample 40 still produces the same wrong output (`6171717161711716`) under the 4-slot configuration, so slot count is not the only cause.

### Solutions

#### Option A: Increase the slot count to 8

```python
# offload_engine.py configuration
class OffloadEngineConfig:
    # before
    # num_ring_slots = 2

    # proposed (pick either the constant or the property below, not both)
    num_ring_slots = 8  # increase to 8 slots

    # or compute it dynamically
    @property
    def num_ring_slots(self):
        # adjust to the available GPU memory
        available_memory_gb = get_available_gpu_memory()
        slot_memory_gb = self.block_size * self.kv_cache_dtype_size * 2 / 1e9
        max_slots = int(available_memory_gb * 0.3 / slot_memory_gb)  # use 30% of memory
        return min(max(max_slots, 4), 16)  # clamp to 4-16
```

#### Option B: Improve the synchronization

```python
def load_to_slot_layer_with_barrier(self, slot: int, layer_id: int, cpu_block_id: int):
    """
    Load with an explicit barrier so the previous operation on this slot has finished.
    """
    # wait for all pending operations on the slot
    if self.slot_events[slot] is not None:
        self.slot_events[slot].synchronize()

    # perform the transfer
    self._do_transfer(slot, layer_id, cpu_block_id)

    # record a new event
    self.slot_events[slot] = torch.cuda.Event()
    self.slot_events[slot].record(self.transfer_stream)
```

#### Option C: Double buffering

```python
class DoubleBufferedRingBuffer:
    """
    Double-buffered design: each logical slot has two physical buffers,
    one for the current compute, one for the next transfer.
    """
    def __init__(self, num_logical_slots: int, ...):
        self.num_logical_slots = num_logical_slots
        self.num_physical_slots = num_logical_slots * 2
        self.current_buffer = [0] * num_logical_slots  # buffer in use per slot (0 or 1)

    def get_compute_slot(self, logical_slot: int) -> int:
        """Physical slot used for compute."""
        return logical_slot * 2 + self.current_buffer[logical_slot]

    def get_transfer_slot(self, logical_slot: int) -> int:
        """Physical slot used for transfer (the other buffer)."""
        return logical_slot * 2 + (1 - self.current_buffer[logical_slot])

    def swap_buffer(self, logical_slot: int):
        """Swap the buffers."""
        self.current_buffer[logical_slot] = 1 - self.current_buffer[logical_slot]
```
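The index arithmetic behind the double-buffering idea can be checked in isolation. A dependency-free sketch (a standalone restatement with hypothetical names, not the engine's actual class):

```python
class DoubleBuffer:
    """Index arithmetic of a double-buffered ring buffer: two physical
    buffers per logical slot, one for compute, one for transfer."""
    def __init__(self, num_logical_slots):
        self.current = [0] * num_logical_slots  # active buffer per slot (0 or 1)

    def compute_slot(self, s):
        return s * 2 + self.current[s]

    def transfer_slot(self, s):
        return s * 2 + (1 - self.current[s])

    def swap(self, s):
        self.current[s] = 1 - self.current[s]

buf = DoubleBuffer(4)
for step in range(3):
    for s in range(4):
        # compute and transfer never touch the same physical slot
        assert buf.compute_slot(s) != buf.transfer_slot(s)
        buf.swap(s)

# after a swap, the freshly transferred buffer becomes the compute buffer
buf2 = DoubleBuffer(1)
t = buf2.transfer_slot(0)
buf2.swap(0)
assert buf2.compute_slot(0) == t
```

The invariant that compute and transfer always target distinct physical slots is what removes the race described above, at the cost of doubling the GPU memory held by the ring buffer.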
### Expected improvement

- 8-slot configuration: expected further 3-5% reduction in error rate
- Double buffering: expected further 2-3% reduction in error rate

---

## Problem 4: Sparse Policy Mapping Bug

### Description

**File**: `COMPASS/eval/RULER/scripts/pred/call_api.py` (COMPASS project)
**Location**: metric-to-policy mapping

```python
metric_to_policy = {
    'full': 'FULL',
    'xattn': 'XATTN',        # ❌ XATTN does not exist in SparsePolicyType
    'compass': 'XATTN',
    'minfer': 'MINFERENCE',  # ❌ MINFERENCE does not exist in SparsePolicyType
    'avgpool': 'FULL',
    'flex': 'FLEXPREFILL',
}
sparse_policy_name = metric_to_policy.get(self.metric, 'FULL')
sparse_policy = getattr(SparsePolicyType, sparse_policy_name, SparsePolicyType.FULL)
# ⚠️ silently falls back to FULL, with no warning!
```

### Analysis

1. **Silent fallback**: `getattr(..., SparsePolicyType.FULL)` makes invalid names fall back silently.
2. **Misconfigured tests**: the user believes they are testing `xattn` while actually testing `FULL`.
3. **Unreliable results**: the test results for every sparse method may in fact be FULL's results.

### Verification

```python
# Check the valid values of SparsePolicyType
from nanovllm.kvcache.sparse.policy import SparsePolicyType
print([e.name for e in SparsePolicyType])
# Expected output, something like: ['FULL', 'QUEST', 'XATTN_BSA', ...]

# Verify the mapping
for metric, policy_name in metric_to_policy.items():
    has_policy = hasattr(SparsePolicyType, policy_name)
    print(f"{metric} -> {policy_name}: {'✓' if has_policy else '✗'}")
```

### Solution

#### Fix the mapping and add validation

```python
# Corrected mapping (update according to the SparsePolicyType definition)
metric_to_policy = {
    'full': 'FULL',
    'xattn': 'XATTN_BSA',    # ✓ correct enum name
    'compass': 'XATTN_BSA',
    'minfer': 'MINFERENCE',  # confirm this exists
    'avgpool': 'AVGPOOL',    # confirm this exists
    'flex': 'FLEXPREFILL',   # confirm this exists
    'quest': 'QUEST',        # add quest
}

def get_sparse_policy(metric: str) -> SparsePolicyType:
    """Look up the sparse policy, with validation."""
    policy_name = metric_to_policy.get(metric)
    if policy_name is None:
        raise ValueError(f"Unknown metric: {metric}. Valid options: {list(metric_to_policy.keys())}")
    if not hasattr(SparsePolicyType, policy_name):
" f"Available: {[e.name for e in SparsePolicyType]}") return getattr(SparsePolicyType, policy_name) ``` ### 预期改进 - 这个 bug 不影响 chunked offload 准确性 - 但修复后可以正确测试各种 sparse 方法的实际效果 --- ## 问题 5: Decode Buffer 读取边界条件 ### 问题描述 **文件**: `nanovllm/kvcache/sparse/full_policy.py` **位置**: 第 269-272 行 ```python def compute_chunked_decode(self, ...): # ... seq_len = len(seq) decode_pos_in_block = (seq_len - 1) % block_size decode_start_pos = kvcache_manager.get_decode_start_pos(seq) decode_start_pos_in_block = decode_start_pos % block_size num_accumulated = decode_pos_in_block - decode_start_pos_in_block + 1 # ⚠️ 潜在问题: 当 decode_pos_in_block < decode_start_pos_in_block 时 # num_accumulated 会是负数 ``` ### 问题分析 场景:当 decode 跨越 block 边界时: - 假设 block_size = 1024 - decode_start_pos = 1020 (在 block 边界附近) - 当前 seq_len = 1026 (跨越到下一个 block) - decode_pos_in_block = (1026 - 1) % 1024 = 1 - decode_start_pos_in_block = 1020 % 1024 = 1020 - num_accumulated = 1 - 1020 + 1 = -1018 ❌ ### 解决方案 ```python def compute_chunked_decode(self, ...): # ... 
### Solution

```python
def compute_chunked_decode(self, ...):
    # ...
    seq_len = len(seq)
    decode_start_pos = kvcache_manager.get_decode_start_pos(seq)

    # Count the tokens actually held in the decode buffer,
    # using absolute positions rather than in-block offsets
    num_accumulated = seq_len - decode_start_pos

    # Sanity check
    assert num_accumulated > 0, f"Invalid decode range: seq_len={seq_len}, decode_start={decode_start_pos}"

    # In-block offsets for the read
    decode_start_pos_in_block = decode_start_pos % block_size
    decode_end_pos_in_block = (seq_len - 1) % block_size

    # Handle the cross-block case
    if decode_end_pos_in_block < decode_start_pos_in_block:
        # Crossing a block boundary requires reads from two locations;
        # this case should be handled at the kvcache_manager level
        raise NotImplementedError("Cross-block decode not yet supported")

    # Same block: plain read
    decode_k = offload_engine.decode_k_buffer[
        layer_id, decode_start_pos_in_block:decode_end_pos_in_block+1
    ]
    decode_v = offload_engine.decode_v_buffer[
        layer_id, decode_start_pos_in_block:decode_end_pos_in_block+1
    ]
```

### Expected improvement

- Fixing the boundary condition should eliminate crashes or wrong output in the affected cases.

---

## Problem 6: Merge Kernel Block Size Tuning

### Description

**File**: `nanovllm/ops/chunked_attention.py`
**Location**: line 406

```python
def merge_attention_outputs(...):
    # ...
    # Launch output merge kernel
    BLOCK_SIZE = 128  # fixed value
    grid_output = (batch, seqlen_q, nheads)
    _merge_output_kernel[grid_output](
        o1, o2, lse1, lse2, o_merged,
        batch, seqlen_q, nheads, headdim,
        BLOCK_SIZE=BLOCK_SIZE,
    )
```

### Analysis

1. **Fixed BLOCK_SIZE**: 128 may not be optimal for every headdim.
2. **Performance**: a poor BLOCK_SIZE can leave GPU occupancy low.
3. **Accuracy**: certain BLOCK_SIZE values may cause boundary-handling problems.

### Solution

```python
def merge_attention_outputs(...):
    batch, seqlen_q, nheads, headdim = o1.shape

    # Choose BLOCK_SIZE dynamically.
    # Principle: BLOCK_SIZE should divide headdim and suit the GPU warp size (32).
    if headdim <= 64:
        BLOCK_SIZE = 64
    elif headdim <= 128:
        BLOCK_SIZE = 128
    else:
        BLOCK_SIZE = 256

    # Ensure BLOCK_SIZE divides headdim
    while headdim % BLOCK_SIZE != 0 and BLOCK_SIZE > 32:
        BLOCK_SIZE //= 2
    # ...
```
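The selection rule can be unit-tested in isolation. A pure-Python restatement (hypothetical helper name `select_block_size`, same logic as the sketch above):

```python
def select_block_size(headdim: int) -> int:
    """Pick a block size <= headdim's tier, then halve until it divides headdim."""
    if headdim <= 64:
        block = 64
    elif headdim <= 128:
        block = 128
    else:
        block = 256
    while headdim % block != 0 and block > 32:
        block //= 2
    return block

assert select_block_size(64) == 64
assert select_block_size(128) == 128
assert select_block_size(256) == 256
assert select_block_size(96) == 32  # 96 = 32*3: halving from 128 stops at 32
```

Note the rule does not guarantee divisibility for every headdim (e.g. 96 still leaves a remainder at 32), so the kernel's boundary masking must remain correct regardless.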
### Expected improvement

- This is a performance optimization; the impact on accuracy is small.

---

## Combined Solution Priorities

| Priority | Problem | Solution | Expected gain | Difficulty |
|--------|------|----------|----------|----------|
| **P0** | N-way merge | implement `merge_attention_outputs_nway` | 10-15% | medium |
| **P0** | LSE precision | use FlashInfer, or verify flash_attn's LSE | 5-10% | low |
| **P1** | Ring buffer slots | increase to 8 slots + double buffering | 3-5% | medium |
| **P1** | Policy mapping | fix the mapping bug in COMPASS | N/A | low |
| **P2** | Decode boundary | add boundary checks and handling | 1-2% | low |
| **P2** | Merge block size | select BLOCK_SIZE dynamically | <1% | low |

---

## Verification Plan

### Stage 1: error tracing (do this first)

Before applying any change, add detailed error tracing:

```python
# add to full_policy.py
def compute_chunked_prefill_with_debug(self, ...):
    debug_info = {
        'chunk_max_diffs': [],
        'lse_values': [],
        'merge_count': 0,
    }
    # ... existing logic ...
    for block_idx in range(num_blocks):
        # ... compute ...
        if o_acc is not None:
            # record the difference before and after the merge
            pre_merge_o = o_acc.clone()
            o_acc, lse_acc = merge_attention_outputs(...)
            debug_info['chunk_max_diffs'].append((o_acc - pre_merge_o).abs().max().item())
            debug_info['lse_values'].append(lse_acc.mean().item())
            debug_info['merge_count'] += 1

    # save the debug info
    torch.save(debug_info, f'debug_prefill_layer{layer_id}.pt')
```

### Stage 2: test one change at a time

Apply exactly one improvement per run and measure its effect:

1. N-way merge only -> test accuracy
2. More ring buffer slots only -> test accuracy
3. FlashInfer LSE only -> test accuracy

### Stage 3: combined test

Combine the best-performing improvements and check whether the gains stack:

```
Final configuration = N-way merge + 8 slots + FlashInfer LSE
Goal: error rate < 10% (close to the 8% xattn_stride8 baseline)
```

---

## Appendix: Related File Index

| File | Key functions/classes | Problem # |
|------|-------------|----------|
| `nanovllm/ops/chunked_attention.py` | `flash_attn_with_lse`, `merge_attention_outputs` | 1, 2, 6 |
| `nanovllm/kvcache/sparse/full_policy.py` | `compute_chunked_prefill`, `compute_chunked_decode` | 2, 5 |
| `nanovllm/kvcache/offload_engine.py` | ring buffer management | 3 |
| `COMPASS/eval/RULER/scripts/pred/call_api.py` | `metric_to_policy` | 4 |

---

## Appendix: Failing Sample Details

The samples below failed in the RULER 32K tests; they are recorded here for verifying future fixes.

### Error Statistics

| Task | Total Samples | Errors | Error Rate |
|------|--------------|--------|------------|
| niah_single_1 | 100 | 19 | 19% |
| niah_single_2 | 100 | 23 | 23% |
| niah_single_3 | 100 | 8 | 8% |
| niah_multikey_1 | 100 | 16 | 16% |
| niah_multikey_2 | 100 | 30 | 30% |
| niah_multikey_3 | 100 | 24 | 24% |
| **TOTAL** | **600** | **120** | **20%** |

---

### niah_single_1 (19 errors)

**Failing sample IDs**: `28, 33, 39, 40, 41, 43, 44, 49, 51, 52, 53, 57, 61, 63, 65, 67, 72, 77, 83`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 28 | `9874152` | `:151:52<\|eot_id\|>` |
| 33 | `9196204` | `:<\|eot_id\|>` |
| 39 | `3484601` | `:<\|eot_id\|>` |
| 40 | `6171716` | `: 17: 16<\|eot_id\|>` |
| 41 | `4524499` | `:<\|eot_id\|>` |
| 43 | `3726327` | `: 16: 7<\|eot_id\|>` |
| 44 | `4009172` | `: 2<\|eot_id\|>` |
| 49 | `4240180` | `:354:180<\|eot_id\|>` |
| 51 | `9546409` | `:<\|eot_id\|>` |
| 52 | `2935113` | `: 29351113.<\|eot_id\|>` |
| 53 | `5453786` | `:354:678:90<\|eot_id\|>` |
| 57 | `8315831` | `: 5831<\|eot_id\|>` |
| 61 | `5960271` | `: 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,...<\|eot_id\|>` |
| 63 | `6049101` | `: 5 0 4 9 1 0 1<\|eot_id\|>` |
| 65 | `6406444` | `:361361361361361361361361361361361361361361361361361361361361361361361361361361...<\|eot_id\|>` |
| 67 | `2422633` | `:31<\|eot_id\|>` |
| 72 | `7442089` | ` 7953166<\|eot_id\|>` |
| 77 | `8795419` | `:<\|eot_id\|>` |
| 83 | `6363836` | `: 2<\|eot_id\|>` |

---

### niah_single_2 (23 errors)

**Failing sample IDs**: `16, 24, 30, 32, 40, 41, 42, 50, 51, 52, 55, 58, 60, 62, 64, 66, 67, 68, 69, 77, 85, 91, 93`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 16 | `2344047` | `: 23440447.<\|eot_id\|>` |
| 24 | `5449324` | `:<\|eot_id\|>` |
| 30 | `5727085` | `:<\|eot_id\|>` |
| 32 | `9196204` | `:<\|eot_id\|>` |
| 40 | `4524499` | `:460<\|eot_id\|>` |
| 41 | `7817881` | `:171.<\|eot_id\|>` |
| 42 | `3726327` | `:<\|eot_id\|>` |
| 50 | `9546409` | `:<\|eot_id\|>` |
| 51 | `2935113` | `: 3: 5113<\|eot_id\|>` |
| 52 | `5453786` | `:354<\|eot_id\|>` |
| 55 | `4188992` | `: 418899189418899, but it is not explicitly stated...` |
| 58 | `6266630` | `:5963<\|eot_id\|>` |
| 60 | `5960271` | ` 0271<\|eot_id\|>` |
| 62 | `6049101` | `:<\|eot_id\|>` |
| 64 | `6406444` | `:<\|eot_id\|>` |
| 66 | `2422633` | `:5313<\|eot_id\|>` |
| 67 | `4940441` | `:5311<\|eot_id\|>` |
| 68 | `3472189` | `:361.<\|eot_id\|>` |
| 69 | `8971465` | `:361.<\|eot_id\|>` |
| 77 | `8963715` | `: 0 8 9 7 1 5<\|eot_id\|>` |
| 85 | `2044645` | `: 20446445.<\|eot_id\|>` |
| 91 | `7783308` | `:<\|eot_id\|>` |
| 93 | `1454696` | `:<\|eot_id\|>` |

---

### niah_single_3 (8 errors)

**Failing sample IDs**: `7, 9, 14, 24, 25, 29, 31, 43`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 7 | `ee87905e-4ca4-45ea-8dfa-6a56d12dbc9a` | `: 2010-07-01T00:00:00Z<\|eot_id\|>` |
| 9 | `b7b56ea7-35eb-432d-9ad6-20ab48212ddb` | `:0:0:0:0:0:0:0:0:0:0:0:0:0:0:0:0<\|eot_id\|>` |
| 14 | `e767dcea-b0e6-4969-a213-42b0f1eedba3` | `:0e6-4969-a213-42b0f1eedba3<\|eot_id\|>` |
| 24 | `59e4b671-4774-4c58-85f8-bc16f7860b50` | `:4774:4c58:85f8:bc16f7860b50<\|eot_id\|>` |
| 25 | `54c63cd8-8945-4f27-97fa-2d8dfb2ca025` | `: 54c63c63cd8-8945-4f27-97fa-2d8dfb2ca025.<\|eot_id\|>` |
| 29 | `006ed6e3-6fa1-4735-b572-f3d00b5cea6a` | `:6e3-6fa1-4735-b572-f3d00b5cea6a<\|eot_id\|>` |
| 31 | `e6697833-b841-40a0-9fe7-71d6d9178793` | `: e6697837837833-b841-40a0-9fe7-71d6d9178793.<\|eot_id\|>` |
| 43 | `d92c9227-eadf-4085-bfcb-75468eb22579` | `: d92c922c9227-eadf-4085-bfcb-75468eb22579.<\|eot_id\|>` |

---

### niah_multikey_1 (16 errors)

**Failing sample IDs**: `20, 31, 32, 40, 41, 45, 51, 54, 59, 63, 64, 65, 67, 69, 71, 74`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 20 | `2171218` | `: 2171212181212181212181218<\|eot_id\|>` |
| 31 | `9333700` | `:<\|eot_id\|>` |
| 32 | `7121355` | `:9651<\|eot_id\|>` |
| 40 | `3112652` | `:285<\|eot_id\|>` |
| 41 | `3427461` | `:<\|eot_id\|>` |
| 45 | `8217547` | `:<\|eot_id\|>` |
| 51 | `1514340` | `: 1514343403361.<\|eot_id\|>` |
| 54 | `8212753` | `:<\|eot_id\|>` |
| 59 | `6587964` | `:<\|eot_id\|>` |
| 63 | `1688246` | `:<\|eot_id\|>` |
| 64 | `8344365` | `: 834436, but it is not explicitly mentioned.<\|eot_id\|>` |
| 65 | `6614484` | `: 4367.<\|eot_id\|>` |
| 67 | `6510922` | `:7780<\|eot_id\|>` |
| 69 | `6649968` | `: 43610.<\|eot_id\|>` |
| 71 | `9437374` | `:<\|eot_id\|>` |
| 74 | `6625238` | `:1472908<\|eot_id\|>` |

---

### niah_multikey_2 (30 errors)

**Failing sample IDs**: `2, 13, 21, 22, 23, 24, 25, 28, 32, 34, 38, 39, 40, 41, 42, 43, 45, 46, 47, 49, 50, 53, 54, 56, 57, 59, 60, 63, 64, 65`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 2 | `1535573` | `: 8651665.<\|eot_id\|>` |
| 13 | `2794159` | `: 5261593<\|eot_id\|>` |
| 21 | `8970232` | `:168<\|eot_id\|>` |
| 22 | `9134051` | `: 381:055: 381:055: 381:055: 381:055: 381:055: 381:055: 381:055: 38...` |
| 23 | `9696620` | `: 969662620969662, which is: 969662920, 96966220 is not actually me...` |
| 24 | `7071187` | ` 055055055.<\|eot_id\|>` |
| 25 | `5572782` | `: 5342494<\|eot_id\|>` |
| 28 | `4953027` | `:1687719<\|eot_id\|>` |
| 32 | `4259234` | `: 425923521250, but not found is: 425923751572250, however is: 4259...` |
| 34 | `3643022` | `: 3957500<\|eot_id\|>` |
| 38 | `2031469` | `: the text.<\|eot_id\|>` |
| 39 | `8740362` | `: 8740364 8740364 8740364 8740364 is: is: is: is: 874036...` |
| 40 | `7041770` | `:1682<\|eot_id\|>` |
| 41 | `1986258` | `:086.<\|eot_id\|>` |
| 42 | `5668574` | `:055.<\|eot_id\|>` |
| 43 | `8560471` | `:067<\|eot_id\|>` |
| 45 | `9973767` | `: 8420273<\|eot_id\|>` |
| 46 | `3960211` | `:0<\|eot_id\|>` |
| 47 | `8003271` | `: 60870870870870870870870870870870870870870870870870870870870870870...` |
| 49 | `8632309` | ` 303640 is640 640 640 640 640 640 640 640 640 640 640 640 640 640 640...` |
| 50 | `2318630` | `: 7780552.<\|eot_id\|>` |
| 53 | `3405052` | `:<\|eot_id\|>` |
| 54 | `5364945` | `: 536494, which is: 536494, which is: 536494494494494494494494494...` |
| 56 | `7319214` | `:7607607607607607607607607607607607607607607607607607607607607607607...` |
| 57 | `9206104` | `:7607607607607607607607607607607607607607607607607607607607607607607...` |
| 59 | `9555385` | `:7095<\|eot_id\|>` |
| 60 | `5727554` | `: 572755755755755755755755755755755755755755755755755755755755 is: 572...` |
| 63 | `1090767` | `:7607607607607607607607607607607607607607607607607607607607607607607...` |
| 64 | `6791240` | `:<\|eot_id\|>` |
| 65 | `7275999` | `:7607607607607607607607607607607607607607607607607607607607607607607...` |

---

### niah_multikey_3 (24 errors)

**Failing sample IDs**: `11, 18, 20, 23, 24, 25, 26, 27, 29, 30, 33, 35, 37, 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 52`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 11 | `c73ed342-6523-4d4b-aa33-beb1c9007315` | `: 1d28b88b-b6a8-46ba-8e8f-56cbafbfd897.<\|eot_id\|>` |
| 18 | `87b8a762-1d1f-4e85-a5d1-caf284c95aa6` | `: 429a6676-5295-4ea2-a694-6aa949f48e31.<\|eot_id\|>` |
| 20 | `cce29702-134a-460c-979b-6f7ee7895280` | `:<\|eot_id\|>` |
| 23 | `ed344bfe-983f-4a21-af44-722e2517244c` | `: aec431e7d880a8dce2c023de24 is: aec43163-061a-4afe-b80a-f5bfb5e3c9...` |
| 24 | `4712ef99-a8d1-4388-8ca7-b08dd3505d77` | `:<\|eot_id\|>` |
| 25 | `46969ce7-0da0-49f8-87b2-845e7b8ef100` | `:<\|eot_id\|>` |
| 26 | `7cff3c66-6860-49e6-8ba5-002162c250c0` | `:4c7e-946b-30812edf965e<\|eot_id\|>` |
| 27 | `b63b4988-40bc-44b2-bf1c-ca95adbca4e9` | `:<\|eot_id\|>` |
| 29 | `6d94011c-f28a-4b0b-a2e2-fe34bb8b19a1` | `: 6d6d6d6d4b0e-52ce-44d9-a0f6-1ae405825615<\|eot_id\|>` |
| 30 | `7c33bb00-4ab4-4e4f-a78e-39f8f06d63eb` | ` d7a2-4b23-a2c0-8c859cb1fa96<\|eot_id\|>` |
| 33 | `b7c6b586-713a-4907-ad24-5c4f25aeb769` | `:1-4d2c-b42b-933ded2633d6<\|eot_id\|>` |
| 35 | `ac8a317b-a6bb-4327-90db-2a01622cb723` | `: d2f2f2f2f2f2f2f2d2d2f2d2d2d3d2f6b3d2f- is: d2dab is: is: is: i...` |
| 37 | `b187b337-3132-4376-a500-9340102092ae` | `:<\|eot_id\|>` |
| 40 | `2559fa56-dd0a-48d4-ba82-3ae2bf0a4b33` | `:358fe0e3-724e-4cfc-9ae0-d0873162626b.<\|eot_id\|>` |
| 41 | `7842feb5-e758-44cd-b73b-8ae08aa33142` | `: 6c6adf83-36a9-4e41-9cbe-60a8c9ffba92.<\|eot_id\|>` |
| 42 | `a1196139-f6fa-4c18-b3da-b7bd50362ac7` | `: a1196131396131196131399a1196139a1196139a1196139a1196139f6a1196139...` |
| 44 | `7d3d40b2-4594-4573-b267-4c6270dd4425` | `: 613a9e-4e7d-8c9f-740a630e3c53<\|eot_id\|>` |
| 45 | `500b8a75-8f05-43f5-b9ad-46d47d4e33fc` | `: 500b8a5e0e0e0a500b is: 500b is: 500b-4 is: is: is: is: is: i...` |
| 46 | `86a867a7-6a98-4a02-b065-70a33bafafde` | `:6139a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a...` |
| 47 | `7c0f7fd2-237e-4c0f-b3f5-f43623551169` | ` 5fb71d2f0f0b4f0 is: 5fb71 is: 5fb71f-4f-4f-4f-4f-4f-4d7 is: is: ...` |
| 48 | `b0e1f3f5-6570-437e-b8a1-f1b3f654e257` | `: 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b ...` |
| 49 | `0153722a-70a8-4ec0-9f03-2b0930937e60` | `: 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b ...` |
| 50 | `0a1ead51-0c39-4eeb-ac87-d146acdb1d4a` | `: 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b ...` |
| 52 | `ff686e85-3a9f-4635-95dd-f19e8ca68eb1` | ` ff686e686e686e686e686e686f686e6f686e6fb686f686f686f686f686f- is: f...` |

---

### Error Pattern Analysis

| Error type | Example | Likely cause |
|----------|------|----------|
| **Empty output** | `:<\|eot_id\|>` | KV cache fully corrupted; the model cannot generate valid content |
| **Digit repetition** | `:361361361...` | Chunk-boundary merge errors push decode into a loop |
| **Partially correct** | `9874152` → `:151:52` | Early chunks correct, later chunks corrupted |
| **Digit insertion** | `2344047` → `23440447` | Numerical precision loss during the merge |
| **Completely wrong** | an entirely unrelated UUID | Severe KV cache or position-encoding error |

---

### Verification Script

A script for checking whether the previously failing samples now pass:

```python
# verify_samples.py
import json

ERROR_SAMPLES = {
    'niah_single_1': [28, 33, 39, 40, 41, 43, 44, 49, 51, 52, 53, 57, 61, 63, 65, 67, 72, 77, 83],
    'niah_single_2': [16, 24, 30, 32, 40, 41, 42, 50, 51, 52, 55, 58, 60, 62, 64, 66, 67, 68, 69, 77, 85, 91, 93],
    'niah_single_3': [7, 9, 14, 24, 25, 29, 31, 43],
    'niah_multikey_1': [20, 31, 32, 40, 41, 45, 51, 54, 59, 63, 64, 65, 67, 69, 71, 74],
    'niah_multikey_2': [2, 13, 21, 22, 23, 24, 25, 28, 32, 34, 38, 39, 40, 41, 42, 43, 45, 46, 47, 49, 50, 53, 54, 56, 57, 59, 60, 63, 64, 65],
    'niah_multikey_3': [11, 18, 20, 23, 24, 25, 26, 27, 29, 30, 33, 35, 37, 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 52],
}

def verify_fix(result_dir: str) -> dict:
    """Measure how many previously failing samples now pass."""
    fixed = {}
    for task, error_ids in ERROR_SAMPLES.items():
        pred_file = f"{result_dir}/{task}.jsonl"
        with open(pred_file) as f:
            predictions = [json.loads(line) for line in f]

        fixed_count = 0
        still_error = []
        for idx in error_ids:
            pred = predictions[idx]
            expected = pred['answer']
            actual = pred['model_output']
            if expected in actual:
                fixed_count += 1
            else:
                still_error.append(idx)

        fixed[task] = {
            'total_errors': len(error_ids),
            'fixed': fixed_count,
            'still_error': still_error,
            'fix_rate': f"{fixed_count/len(error_ids)*100:.1f}%"
        }
    return fixed
```