diff --git a/CLAUDE.md b/CLAUDE.md
index ea099f6..c0f4621 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -20,7 +20,8 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 | [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
 | [`docs/known_issues.md`](docs/known_issues.md) | Documented bugs and fixes: partial last block bug, block size 4096 race condition |
 | [`docs/ruler_benchmark_results_32k.md`](docs/ruler_benchmark_results_32k.md) | RULER benchmark results (32K context): 13 tasks, 92.3% accuracy, CPU offload performance |
-| [`docs/ruler_32k_chunked_offload_issue.md`](docs/ruler_32k_chunked_offload_issue.md) | ⚠️ OPEN ISSUE: 32K chunked offload accuracy problem (35% error rate in RULER) |
+| [`docs/ruler_32k_chunked_offload_issue.md`](docs/ruler_32k_chunked_offload_issue.md) | ⚠️ OPEN ISSUE: 32K chunked offload accuracy problem (20% error rate in RULER) |
+| [`docs/chunked_attention_solutions.md`](docs/chunked_attention_solutions.md) | 🔧 SOLUTIONS: Code analysis and proposed fixes for the chunked attention accuracy problem |

 ## GPU Mutex for Multi-Instance Debugging

diff --git a/docs/chunked_attention_solutions.md b/docs/chunked_attention_solutions.md
new file mode 100644
index 0000000..c84bb43
--- /dev/null
+++ b/docs/chunked_attention_solutions.md
@@ -0,0 +1,1078 @@
# Solutions for the Chunked Attention Accuracy Problem

**Status**: 🟡 IN PROGRESS
**Related Issue**: [`ruler_32k_chunked_offload_issue.md`](ruler_32k_chunked_offload_issue.md)
**Created**: 2026-01-20
**Author**: Zijie Tian

---

## Overview

Based on a close analysis of the chunked attention code, this document details the code issues that likely cause the 20% error rate observed in the RULER 32K tests, together with proposed solutions for each.

**Core problem**: When processing long 32K sequences, the chunked prefill mechanism loses output accuracy through the cumulative effect of several independent factors.

---

## Problem 1: How `flash_attn_with_lse` Obtains the LSE

### Description

**File**: `nanovllm/ops/chunked_attention.py`
**Location**: lines 264-269

```python
def flash_attn_with_lse(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float]
= None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    from flash_attn.flash_attn_interface import flash_attn_func

    # ...

    # Use return_attn_probs=True to obtain the LSE
    out, lse, _ = flash_attn_func(
        q, k, v,
        softmax_scale=softmax_scale,
        causal=causal,
        return_attn_probs=True,  # ⚠️ this API is not designed for LSE output
    )

    # lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
    lse = lse[:, :, :seqlen_q]
    return out, lse
```

### Analysis

1. **API intent mismatch**: the main purpose of `return_attn_probs=True` is to return attention probabilities for debugging, not to provide a high-precision LSE
2. **Potential precision loss**: Flash Attention may apply internal optimizations to the LSE that trade away precision
3. **`S_dmask` return value**: in some versions, the third return value `S_dmask` may contain a random dropout mask, affecting results

### Verification

```python
def verify_lse_accuracy():
    """Check whether the LSE returned by flash_attn is accurate."""
    import torch
    from flash_attn.flash_attn_interface import flash_attn_func

    # Small sizes so the reference can be computed exactly
    batch, seqlen, nheads, headdim = 1, 128, 8, 64
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)

    # Flash Attention LSE
    _, lse_flash, _ = flash_attn_func(q, k, v, return_attn_probs=True)

    # Reference LSE computed in PyTorch
    scale = 1.0 / (headdim ** 0.5)
    qk = torch.einsum('bqhd,bkhd->bhqk', q, k) * scale
    lse_manual = torch.logsumexp(qk, dim=-1)  # [batch, heads, seqlen_q]

    # Compare
    diff = (lse_flash[:, :, :seqlen] - lse_manual).abs()
    print(f"LSE max diff: {diff.max().item():.6f}")
    print(f"LSE mean diff: {diff.mean().item():.6f}")

    return diff.max().item() < 1e-3
```

### Solutions

#### Option A: Use FlashInfer's exact LSE interface (recommended)

FlashInfer provides an interface designed for chunked attention:

```python
def flash_attn_with_lse_flashinfer(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Obtain an exact LSE through FlashInfer's interface."""
    import math

    import flashinfer

    batch, seqlen_q, nheads_q, headdim = q.shape
    _, seqlen_k, nheads_kv, _ = k.shape

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)

    # FlashInfer's single_prefill_with_kv_cache_return_lse is designed
    # specifically to return an exact LSE
    out, lse = flashinfer.single_prefill_with_kv_cache_return_lse(
        q.view(batch * seqlen_q, nheads_q, headdim),
        k.view(batch * seqlen_k, nheads_kv, headdim),
        v.view(batch * seqlen_k, nheads_kv, headdim),
        causal=causal,
        sm_scale=softmax_scale,
    )

    # Reshape outputs
    out = out.view(batch, seqlen_q, nheads_q, headdim)
    lse = lse.view(batch, seqlen_q, nheads_q).transpose(1, 2)  # [batch, nheads, seqlen_q]

    return out, lse
```

#### Option B: Custom Triton kernel

If FlashInfer is unavailable, use the custom Triton kernel:

```python
# Already implemented in chunked_attention.py: _fwd_kernel_with_lse
# This kernel computes and emits the LSE directly in Triton
# Pros: full control over precision
# Cons: may be slower than Flash Attention
```

#### Option C: Hybrid approach

```python
def flash_attn_with_lse_hybrid(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    use_triton_kernel: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pick the best implementation per case:
    - short sequences (< 4K): Flash Attention (fast)
    - long sequences (>= 4K): Triton kernel (exact)
    """
    seqlen = q.shape[1]

    if seqlen < 4096 and not use_triton_kernel:
        return flash_attn_with_lse(q, k, v, softmax_scale, causal)
    else:
        return triton_flash_attn_with_lse(q, k, v, softmax_scale, causal)
```

### Expected improvement

- Fixing the LSE precision issue should cut the error rate by an estimated 5-10%

---

## Problem 2: Error Accumulation from Too Many Sequential Merges

### Description

**File**: `nanovllm/kvcache/sparse/full_policy.py`
**Location**: lines 136-137 and 163-164

```python
# the merge loop in compute_chunked_prefill
for block_idx in range(num_blocks):
    # ... load block ...
    prev_o, prev_lse = flash_attn_with_lse(
        q_batched, prev_k, prev_v,
        softmax_scale=softmax_scale,
        causal=False,
    )
    if o_acc is None:
        o_acc, lse_acc = prev_o, prev_lse
    else:
        # ⚠️ one merge is executed per block
        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
```

### Analysis

For a 32K context (block size 1024), ~32 merges are executed. Each merge computes:

```
max_lse = max(lse1, lse2)
exp1 = exp(lse1 - max_lse)
exp2 = exp(lse2 - max_lse)
o_merged = (o1 * exp1 + o2 * exp2) / (exp1 + exp2)
lse_merged = max_lse + log(exp1 + exp2)
```

**Error sources**:
1. `exp()` and `log()` lose precision in edge cases
2. the division `/ (exp1 + exp2)` can lose precision
3. after 32 accumulations, the error can reach a significant level

### Mathematical analysis

Suppose each merge introduces a relative error of ε ≈ 10^-6 (fp32); in the worst case, with the errors adding linearly:
- after 32 accumulations: total error ≈ 32 × ε = 3.2 × 10^-5
- for BFloat16 (8-bit mantissa, 7 bits stored): ε ≈ 2^-8 ≈ 4 × 10^-3, worst-case total ≈ 32 × ε ≈ 0.13 (13%)

### Verification

```python
def measure_merge_error_accumulation():
    """Measure how the merge error grows with the number of chunks."""
    import torch
    from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
    from flash_attn.flash_attn_interface import flash_attn_func

    torch.manual_seed(42)

    batch, seqlen, nheads, headdim = 1, 32768, 32, 128
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)

    # Reference: full attention
    out_ref = flash_attn_func(q, k, v, causal=False)

    results = []
    for num_chunks in [4, 8, 16, 32, 64]:
        chunk_size = seqlen // num_chunks
        o_acc, lse_acc = None, None

        for i in range(num_chunks):
            start = i * chunk_size
            end = (i + 1) * chunk_size
            k_chunk = k[:, start:end, :, :]
            v_chunk = v[:, start:end, :, :]

            chunk_o, chunk_lse = flash_attn_with_lse(q, k_chunk, v_chunk, causal=False)

            if o_acc is None:
                o_acc, lse_acc = chunk_o, chunk_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc,
chunk_o, chunk_lse)

        diff = (out_ref - o_acc).abs()
        results.append({
            'num_chunks': num_chunks,
            'max_diff': diff.max().item(),
            'mean_diff': diff.mean().item(),
        })
        print(f"chunks={num_chunks:3d}, max_diff={diff.max().item():.6f}, mean_diff={diff.mean().item():.8f}")

    return results
```

### Solutions

#### Option A: N-way merge (recommended)

Instead of the sequential chain of pairwise merges, merge all chunks in one step:

```python
def merge_attention_outputs_nway(
    outputs: List[torch.Tensor],  # List of [batch, seqlen_q, nheads, headdim]
    lses: List[torch.Tensor],     # List of [batch, nheads, seqlen_q]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    N-way merge: merge all chunks at once to cut down accumulated error.

    Math:
    - max_lse = max(lse_1, lse_2, ..., lse_n)
    - sum_exp = sum(exp(lse_i - max_lse) for i in 1..n)
    - o_merged = sum(o_i * exp(lse_i - max_lse) for i in 1..n) / sum_exp
    - lse_merged = max_lse + log(sum_exp)
    """
    n = len(outputs)
    if n == 0:
        raise ValueError("Need at least one output to merge")
    if n == 1:
        return outputs[0], lses[0]

    # Stack for batch processing
    o_stack = torch.stack(outputs, dim=0)  # [n, batch, seqlen, heads, dim]
    lse_stack = torch.stack(lses, dim=0)   # [n, batch, heads, seqlen]

    # Compute max_lse across all chunks (in fp32 for precision)
    lse_fp32 = lse_stack.float()
    max_lse = lse_fp32.max(dim=0).values  # [batch, heads, seqlen]

    # Compute exp(lse_i - max_lse) for each chunk
    exp_weights = torch.exp(lse_fp32 - max_lse.unsqueeze(0))  # [n, batch, heads, seqlen]

    # Normalize weights
    sum_exp = exp_weights.sum(dim=0)  # [batch, heads, seqlen]
    normalized_weights = exp_weights / sum_exp.unsqueeze(0)  # [n, batch, heads, seqlen]

    # Weighted sum of outputs
    # outputs: [n, batch, seqlen, heads, dim]
    # weights: [n, batch, heads, seqlen] -> [n, batch, seqlen, heads, 1]
    weights_expanded = normalized_weights.permute(0, 1, 3, 2).unsqueeze(-1)
    o_merged = (o_stack.float() * weights_expanded).sum(dim=0).to(outputs[0].dtype)

    # Compute merged LSE
    lse_merged = (max_lse + torch.log(sum_exp)).to(lses[0].dtype)

    return o_merged, lse_merged
```

**Usage**:

```python
# before: one merge per chunk
for block_idx in range(num_blocks):
    chunk_o, chunk_lse = flash_attn_with_lse(...)
    if o_acc is None:
        o_acc, lse_acc = chunk_o, chunk_lse
    else:
        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, chunk_o, chunk_lse)

# after: collect all chunks, then merge once
chunk_outputs = []
chunk_lses = []
for block_idx in range(num_blocks):
    chunk_o, chunk_lse = flash_attn_with_lse(...)
    chunk_outputs.append(chunk_o)
    chunk_lses.append(chunk_lse)

o_acc, lse_acc = merge_attention_outputs_nway(chunk_outputs, chunk_lses)
```

#### Option B: Kahan summation

Use Kahan-style compensated summation to reduce the floating-point accumulation error:

```python
def merge_attention_outputs_kahan(
    o1: torch.Tensor, lse1: torch.Tensor,
    o2: torch.Tensor, lse2: torch.Tensor,
    compensation: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Merge with Kahan-style compensation, carrying the fp32 rounding error
    of the weighted output sum from one merge step into the next.

    Returns:
        o_merged, lse_merged, (output_comp, lse_comp)
    """
    # Standard merge computation; the weights are reshaped from
    # [batch, heads, seqlen] to [batch, seqlen, heads, 1] so that they
    # broadcast against the [batch, seqlen, heads, dim] outputs
    max_lse = torch.maximum(lse1.float(), lse2.float())
    exp1 = torch.exp(lse1.float() - max_lse)
    exp2 = torch.exp(lse2.float() - max_lse)
    sum_exp = exp1 + exp2
    w1 = (exp1 / sum_exp).transpose(1, 2).unsqueeze(-1)
    w2 = (exp2 / sum_exp).transpose(1, 2).unsqueeze(-1)

    o_merged_raw = o1.float() * w1 + o2.float() * w2
    lse_merged_raw = max_lse + torch.log(sum_exp)

    # Kahan step: o1 is the running accumulator, so its carried error is
    # rescaled by w1 before being folded into this step's sum; the new
    # rounding error of the addition is captured in output_comp_new.
    # The per-token scalar LSE is not compensated here.
    if compensation is not None:
        output_comp, lse_comp = compensation
        a = o1.float() * w1
        y = o2.float() * w2 - output_comp * w1
        o_merged_raw = a + y
        output_comp_new = (o_merged_raw - a) - y
        lse_comp_new = torch.zeros_like(lse_merged_raw)
    else:
        output_comp_new = torch.zeros_like(o_merged_raw)
        lse_comp_new = torch.zeros_like(lse_merged_raw)

    return o_merged_raw.to(o1.dtype), lse_merged_raw.to(lse1.dtype), (output_comp_new, lse_comp_new)
```

#### Option C: Hierarchical merge (compromise)

Split the 32 chunks into groups; merge within each group first, then merge the groups:

```
Before: c1 -> c2 -> c3 -> ... -> c32 (31 merges)
After:
  Group1: c1-c4   -> g1 (3 merges)
  Group2: c5-c8   -> g2 (3 merges)
  ...
  Group8: c29-c32 -> g8 (3 merges)
  Final:  g1-g8   -> result (7 merges)
  Total: 3*8 + 7 = 31 merges, but the maximum accumulation depth drops from 31 to 7
```

```python
def merge_attention_outputs_hierarchical(
    outputs: List[torch.Tensor],
    lses: List[torch.Tensor],
    group_size: int = 4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Hierarchical merge: merge within groups, then merge the groups.
    Reduces the accumulation depth from O(n) to O(log n).
    """
    from nanovllm.ops.chunked_attention import merge_attention_outputs

    n = len(outputs)
    if n == 0:
        raise ValueError("Need at least one output to merge")
    if n == 1:
        return outputs[0], lses[0]

    # Level 1: merge within groups
    group_outputs = []
    group_lses = []

    for i in range(0, n, group_size):
        group = outputs[i:i+group_size]
        group_l = lses[i:i+group_size]

        o_g, lse_g = group[0], group_l[0]
        for j in range(1, len(group)):
            o_g, lse_g = merge_attention_outputs(o_g, lse_g, group[j], group_l[j])

        group_outputs.append(o_g)
        group_lses.append(lse_g)

    # Recursively merge the groups
    if len(group_outputs) > 1:
        return merge_attention_outputs_hierarchical(group_outputs, group_lses, group_size)
    else:
        return group_outputs[0], group_lses[0]
```

### Expected improvement

- N-way merge: estimated 10-15% reduction in error rate
- Hierarchical merge: estimated 5-8% reduction in error rate

---

## Problem 3: Too Few Ring Buffer Slots

### Description

**File**: `nanovllm/kvcache/offload_engine.py`
**Location**: ring buffer initialization

```python
# 2-slot configuration (original)
num_gpu_blocks = 2  # only 2 GPU blocks for the ring buffer

# Problems:
# - Slot 0: used for decode
# - Slot 1: used for prefill loading
# - 32 chunks share a single prefill slot
# - high-frequency CPU<->GPU transfers can race
```

### Analysis

1. **Race condition**: while chunk N's compute is still running, chunk N+1's transfer may overwrite the same slot
2.
**Synchronization overhead**: frequent `wait_slot_layer()` calls add latency
3. **Verified**: the 4-slot configuration improves things significantly but does not fully solve the problem

### Verification results (completed)

| Config | niah_single_1 accuracy | niah_multikey_3 accuracy |
|--------|------------------------|--------------------------|
| 2-slot | 94% | 48% |
| 4-slot | 98% | 56% |
| Gain | +4% | +8% |

**Key finding**: Sample 40 still produces the same wrong output (`6171717161711716`) under the 4-slot configuration, so the slot count is not the only cause.

### Solutions

#### Option A: Increase the slot count to 8

```python
# offload_engine.py configuration
class OffloadEngineConfig:
    # before
    # num_ring_slots = 2

    # suggested: a fixed value
    num_ring_slots = 8  # increase to 8 slots

    # or compute it dynamically (replacing the fixed attribute above)
    @property
    def num_ring_slots(self):
        # adapt to the available GPU memory
        available_memory_gb = get_available_gpu_memory()
        slot_memory_gb = self.block_size * self.kv_cache_dtype_size * 2 / 1e9
        max_slots = int(available_memory_gb * 0.3 / slot_memory_gb)  # use 30% of memory
        return min(max(max_slots, 4), 16)  # clamp to 4-16
```

#### Option B: Stronger synchronization

```python
def load_to_slot_layer_with_barrier(self, slot: int, layer_id: int, cpu_block_id: int):
    """
    Load with an explicit barrier, ensuring the previous operation finished.
    """
    # wait for all pending operations on this slot
    if self.slot_events[slot] is not None:
        self.slot_events[slot].synchronize()

    # perform the transfer
    self._do_transfer(slot, layer_id, cpu_block_id)

    # record a fresh event
    self.slot_events[slot] = torch.cuda.Event()
    self.slot_events[slot].record(self.transfer_stream)
```

#### Option C: Double buffering

```python
class DoubleBufferedRingBuffer:
    """
    Double-buffered design: each logical slot has two physical buffers,
    - one for the current compute
    - one for the next transfer
    """
    def __init__(self, num_logical_slots: int, ...):
        self.num_logical_slots = num_logical_slots
        self.num_physical_slots = num_logical_slots * 2
        self.current_buffer = [0] * num_logical_slots  # active buffer per slot (0 or 1)

    def get_compute_slot(self, logical_slot: int) -> int:
        """Physical slot used for compute."""
        return logical_slot * 2 + self.current_buffer[logical_slot]

    def get_transfer_slot(self, logical_slot: int) -> int:
        """Physical slot used for transfer (the other buffer)."""
        return logical_slot * 2 + (1 -
self.current_buffer[logical_slot])

    def swap_buffer(self, logical_slot: int):
        """Swap the buffers."""
        self.current_buffer[logical_slot] = 1 - self.current_buffer[logical_slot]
```

### Expected improvement

- 8-slot configuration: estimated further 3-5% reduction in error rate
- Double buffering: estimated further 2-3% reduction in error rate

---

## Problem 4: Sparse Policy Mapping Bug

### Description

**File**: `COMPASS/eval/RULER/scripts/pred/call_api.py` (COMPASS project)
**Location**: the metric-to-policy mapping

```python
metric_to_policy = {
    'full': 'FULL',
    'xattn': 'XATTN',        # ❌ XATTN does not exist in SparsePolicyType
    'compass': 'XATTN',
    'minfer': 'MINFERENCE',  # ❌ MINFERENCE does not exist in SparsePolicyType
    'avgpool': 'FULL',
    'flex': 'FLEXPREFILL',
}

sparse_policy_name = metric_to_policy.get(self.metric, 'FULL')
sparse_policy = getattr(SparsePolicyType, sparse_policy_name, SparsePolicyType.FULL)
# ⚠️ silently falls back to FULL, with no warning!
```

### Analysis

1. **Silent fallback**: `getattr(..., SparsePolicyType.FULL)` makes any invalid name fall back silently
2. **Misconfigured tests**: the user believes they are testing `xattn` but is actually testing `FULL`
3. **Unreliable results**: the test results for every sparse method may in fact be FULL's results

### Verification

```python
# check the valid values of SparsePolicyType
from nanovllm.kvcache.sparse.policy import SparsePolicyType
print([e.name for e in SparsePolicyType])
# should print something like: ['FULL', 'QUEST', 'XATTN_BSA', ...]

# check whether the mapping is valid
for metric, policy_name in metric_to_policy.items():
    has_policy = hasattr(SparsePolicyType, policy_name)
    print(f"{metric} -> {policy_name}: {'✓' if has_policy else '✗'}")
```

### Solution

#### Fix the mapping and add validation

```python
# corrected mapping (update against the actual SparsePolicyType definition)
metric_to_policy = {
    'full': 'FULL',
    'xattn': 'XATTN_BSA',    # ✓ use the correct enum name
    'compass': 'XATTN_BSA',
    'minfer': 'MINFERENCE',  # confirm this member exists
    'avgpool': 'AVGPOOL',    # confirm this member exists
    'flex': 'FLEXPREFILL',   # confirm this member exists
    'quest': 'QUEST',        # add quest
}

def get_sparse_policy(metric: str) -> SparsePolicyType:
    """Resolve the sparse policy, with validation."""
    policy_name = metric_to_policy.get(metric)
    if policy_name is None:
        raise ValueError(f"Unknown metric: {metric}. 
Valid options: {list(metric_to_policy.keys())}")

    if not hasattr(SparsePolicyType, policy_name):
        raise ValueError(f"SparsePolicyType.{policy_name} does not exist. "
                         f"Available: {[e.name for e in SparsePolicyType]}")

    return getattr(SparsePolicyType, policy_name)
```

### Expected improvement

- This bug does not affect chunked offload accuracy
- But fixing it allows the various sparse methods to be tested as intended

---

## Problem 5: Decode Buffer Read Boundary Condition

### Description

**File**: `nanovllm/kvcache/sparse/full_policy.py`
**Location**: lines 269-272

```python
def compute_chunked_decode(self, ...):
    # ...
    seq_len = len(seq)
    decode_pos_in_block = (seq_len - 1) % block_size
    decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
    decode_start_pos_in_block = decode_start_pos % block_size
    num_accumulated = decode_pos_in_block - decode_start_pos_in_block + 1

    # ⚠️ potential problem: when decode_pos_in_block < decode_start_pos_in_block,
    # num_accumulated goes negative
```

### Analysis

Scenario: decode crosses a block boundary:
- assume block_size = 1024
- decode_start_pos = 1020 (near the block boundary)
- current seq_len = 1026 (spills into the next block)
- decode_pos_in_block = (1026 - 1) % 1024 = 1
- decode_start_pos_in_block = 1020 % 1024 = 1020
- num_accumulated = 1 - 1020 + 1 = -1018 ❌

### Solution

```python
def compute_chunked_decode(self, ...):
    # ...
    seq_len = len(seq)
    decode_start_pos = kvcache_manager.get_decode_start_pos(seq)

    # count the tokens actually in the decode buffer,
    # using absolute positions rather than in-block offsets
    num_accumulated = seq_len - decode_start_pos

    # sanity check
    assert num_accumulated > 0, f"Invalid decode range: seq_len={seq_len}, decode_start={decode_start_pos}"

    # in-block offsets for the reads
    decode_start_pos_in_block = decode_start_pos % block_size
    decode_end_pos_in_block = (seq_len - 1) % block_size

    # handle the cross-block case
    if decode_end_pos_in_block < decode_start_pos_in_block:
        # cross-block: would require reading from two locations;
        # this case should be handled at the kvcache_manager level
        raise NotImplementedError("Cross-block decode not yet supported")

    # same block: plain read
    decode_k = offload_engine.decode_k_buffer[
        layer_id,
        decode_start_pos_in_block:decode_end_pos_in_block+1
    ]
    decode_v = offload_engine.decode_v_buffer[
        layer_id,
        decode_start_pos_in_block:decode_end_pos_in_block+1
    ]
```

### Expected improvement

- Fixing the boundary condition should eliminate crashes or wrong output in the affected cases

---

## Problem 6: Merge Kernel Block Size Tuning

### Description

**File**: `nanovllm/ops/chunked_attention.py`
**Location**: line 406

```python
def merge_attention_outputs(...):
    # ...
    # Launch output merge kernel
    BLOCK_SIZE = 128  # fixed value
    grid_output = (batch, seqlen_q, nheads)
    _merge_output_kernel[grid_output](
        o1, o2, lse1, lse2, o_merged,
        batch, seqlen_q, nheads, headdim,
        BLOCK_SIZE=BLOCK_SIZE,
    )
```

### Analysis

1. **Fixed BLOCK_SIZE**: 128 may not be optimal for every headdim
2. **Performance impact**: a poor BLOCK_SIZE can leave the GPU under-occupied
3. **Precision impact**: some BLOCK_SIZE values may cause boundary-handling problems

### Solution

```python
def merge_attention_outputs(...):
    batch, seqlen_q, nheads, headdim = o1.shape

    # choose BLOCK_SIZE dynamically
    # rule of thumb: BLOCK_SIZE should divide headdim and suit the GPU warp size (32)
    if headdim <= 64:
        BLOCK_SIZE = 64
    elif headdim <= 128:
        BLOCK_SIZE = 128
    else:
        BLOCK_SIZE = 256

    # make sure BLOCK_SIZE divides headdim
    while headdim % BLOCK_SIZE != 0 and BLOCK_SIZE > 32:
        BLOCK_SIZE //= 2

    # ...
```

### Expected improvement

- This is a performance optimization; its effect on accuracy is small

---

## Overall Solution Priorities

| Priority | Problem | Solution | Expected gain | Difficulty |
|----------|---------|----------|---------------|------------|
| **P0** | N-way merge | Implement `merge_attention_outputs_nway` | 10-15% | Medium |
| **P0** | LSE precision | Use FlashInfer or verify the flash_attn LSE | 5-10% | Low |
| **P1** | Ring buffer slots | Increase to 8 slots + double buffering | 3-5% | Medium |
| **P1** | Policy mapping | Fix the mapping bug in COMPASS | N/A | Low |
| **P2** | Decode boundary | Add boundary checks and handling | 1-2% | Low |
| **P2** | Merge block size | Choose BLOCK_SIZE dynamically | <1% | Low |

---

## Verification Plan

### Stage 1: Error tracing (run first)

Before making any changes, add detailed error tracing:

```python
# add to full_policy.py
def compute_chunked_prefill_with_debug(self, ...):
    debug_info = {
        'chunk_max_diffs': [],
        'lse_values': [],
        'merge_count': 0,
    }

    # ... original logic ...

    for block_idx in range(num_blocks):
        # ... compute ...

        if o_acc is not None:
            # record how much each merge changes the accumulator
            pre_merge_o = o_acc.clone()
            o_acc, lse_acc = merge_attention_outputs(...)
            debug_info['chunk_max_diffs'].append((o_acc - pre_merge_o).abs().max().item())
            debug_info['lse_values'].append(lse_acc.mean().item())
            debug_info['merge_count'] += 1

    # persist the debug info
    torch.save(debug_info, f'debug_prefill_layer{layer_id}.pt')
```

### Stage 2: Single-change tests

Apply one improvement at a time and measure its effect:

1. N-way merge only -> measure accuracy
2. More ring buffer slots only -> measure accuracy
3.
FlashInfer LSE only -> measure accuracy

### Stage 3: Combined tests

Combine the best-performing improvements and check whether the gains stack:

```
Final config = N-way merge + 8 slots + FlashInfer LSE
Target: error rate < 10% (close to the 8% xattn_stride8 baseline)
```

---

## Appendix: Related File Index

| File | Key functions/classes | Problem # |
|------|-----------------------|-----------|
| `nanovllm/ops/chunked_attention.py` | `flash_attn_with_lse`, `merge_attention_outputs` | 1, 2, 6 |
| `nanovllm/kvcache/sparse/full_policy.py` | `compute_chunked_prefill`, `compute_chunked_decode` | 2, 5 |
| `nanovllm/kvcache/offload_engine.py` | Ring buffer management | 3 |
| `COMPASS/eval/RULER/scripts/pred/call_api.py` | `metric_to_policy` | 4 |

---

## Appendix: Error Sample Details

The samples below failed in the RULER 32K tests; they are recorded here so that future fixes can be verified against them.

### Error statistics

| Task | Total Samples | Errors | Error Rate |
|------|--------------|--------|------------|
| niah_single_1 | 100 | 19 | 19% |
| niah_single_2 | 100 | 23 | 23% |
| niah_single_3 | 100 | 8 | 8% |
| niah_multikey_1 | 100 | 16 | 16% |
| niah_multikey_2 | 100 | 30 | 30% |
| niah_multikey_3 | 100 | 24 | 24% |
| **TOTAL** | **600** | **120** | **20%** |

---

### niah_single_1 (19 errors)

**Error sample IDs**: `28, 33, 39, 40, 41, 43, 44, 49, 51, 52, 53, 57, 61, 63, 65, 67, 72, 77, 83`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 28 | `9874152` | `:151:52<\|eot_id\|>` |
| 33 | `9196204` | `:<\|eot_id\|>` |
| 39 | `3484601` | `:<\|eot_id\|>` |
| 40 | `6171716` | `: 17: 16<\|eot_id\|>` |
| 41 | `4524499` | `:<\|eot_id\|>` |
| 43 | `3726327` | `: 16: 7<\|eot_id\|>` |
| 44 | `4009172` | `: 2<\|eot_id\|>` |
| 49 | `4240180` | `:354:180<\|eot_id\|>` |
| 51 | `9546409` | `:<\|eot_id\|>` |
| 52 | `2935113` | `: 29351113.<\|eot_id\|>` |
| 53 | `5453786` | `:354:678:90<\|eot_id\|>` |
| 57 | `8315831` | `: 5831<\|eot_id\|>` |
| 61 | `5960271` | `: 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,...<\|eot_id\|>` |
| 63 | `6049101` | `: 5 0 4 9 1 0 1<\|eot_id\|>` |
| 65 | `6406444` | 
`:361361361361361361361361361361361361361361361361361361361361361361361361361361...<\|eot_id\|>` |
| 67 | `2422633` | `:31<\|eot_id\|>` |
| 72 | `7442089` | ` 7953166<\|eot_id\|>` |
| 77 | `8795419` | `:<\|eot_id\|>` |
| 83 | `6363836` | `: 2<\|eot_id\|>` |

---

### niah_single_2 (23 errors)

**Error sample IDs**: `16, 24, 30, 32, 40, 41, 42, 50, 51, 52, 55, 58, 60, 62, 64, 66, 67, 68, 69, 77, 85, 91, 93`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 16 | `2344047` | `: 23440447.<\|eot_id\|>` |
| 24 | `5449324` | `:<\|eot_id\|>` |
| 30 | `5727085` | `:<\|eot_id\|>` |
| 32 | `9196204` | `:<\|eot_id\|>` |
| 40 | `4524499` | `:460<\|eot_id\|>` |
| 41 | `7817881` | `:171.<\|eot_id\|>` |
| 42 | `3726327` | `:<\|eot_id\|>` |
| 50 | `9546409` | `:<\|eot_id\|>` |
| 51 | `2935113` | `: 3: 5113<\|eot_id\|>` |
| 52 | `5453786` | `:354<\|eot_id\|>` |
| 55 | `4188992` | `: 418899189418899, but it is not explicitly stated...` |
| 58 | `6266630` | `:5963<\|eot_id\|>` |
| 60 | `5960271` | ` 0271<\|eot_id\|>` |
| 62 | `6049101` | `:<\|eot_id\|>` |
| 64 | `6406444` | `:<\|eot_id\|>` |
| 66 | `2422633` | `:5313<\|eot_id\|>` |
| 67 | `4940441` | `:5311<\|eot_id\|>` |
| 68 | `3472189` | `:361.<\|eot_id\|>` |
| 69 | `8971465` | `:361.<\|eot_id\|>` |
| 77 | `8963715` | `: 0 8 9 7 1 5<\|eot_id\|>` |
| 85 | `2044645` | `: 20446445.<\|eot_id\|>` |
| 91 | `7783308` | `:<\|eot_id\|>` |
| 93 | `1454696` | `:<\|eot_id\|>` |

---

### niah_single_3 (8 errors)

**Error sample IDs**: `7, 9, 14, 24, 25, 29, 31, 43`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 7 | `ee87905e-4ca4-45ea-8dfa-6a56d12dbc9a` | `: 2010-07-01T00:00:00Z<\|eot_id\|>` |
| 9 | `b7b56ea7-35eb-432d-9ad6-20ab48212ddb` | `:0:0:0:0:0:0:0:0:0:0:0:0:0:0:0:0<\|eot_id\|>` |
| 14 | `e767dcea-b0e6-4969-a213-42b0f1eedba3` | `:0e6-4969-a213-42b0f1eedba3<\|eot_id\|>` |
| 24 | `59e4b671-4774-4c58-85f8-bc16f7860b50` | 
`:4774:4c58:85f8:bc16f7860b50<\|eot_id\|>` |
| 25 | `54c63cd8-8945-4f27-97fa-2d8dfb2ca025` | `: 54c63c63cd8-8945-4f27-97fa-2d8dfb2ca025.<\|eot_id\|>` |
| 29 | `006ed6e3-6fa1-4735-b572-f3d00b5cea6a` | `:6e3-6fa1-4735-b572-f3d00b5cea6a<\|eot_id\|>` |
| 31 | `e6697833-b841-40a0-9fe7-71d6d9178793` | `: e6697837837833-b841-40a0-9fe7-71d6d9178793.<\|eot_id\|>` |
| 43 | `d92c9227-eadf-4085-bfcb-75468eb22579` | `: d92c922c9227-eadf-4085-bfcb-75468eb22579.<\|eot_id\|>` |

---

### niah_multikey_1 (16 errors)

**Error sample IDs**: `20, 31, 32, 40, 41, 45, 51, 54, 59, 63, 64, 65, 67, 69, 71, 74`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 20 | `2171218` | `: 2171212181212181212181218<\|eot_id\|>` |
| 31 | `9333700` | `:<\|eot_id\|>` |
| 32 | `7121355` | `:9651<\|eot_id\|>` |
| 40 | `3112652` | `:285<\|eot_id\|>` |
| 41 | `3427461` | `:<\|eot_id\|>` |
| 45 | `8217547` | `:<\|eot_id\|>` |
| 51 | `1514340` | `: 1514343403361.<\|eot_id\|>` |
| 54 | `8212753` | `:<\|eot_id\|>` |
| 59 | `6587964` | `:<\|eot_id\|>` |
| 63 | `1688246` | `:<\|eot_id\|>` |
| 64 | `8344365` | `: 834436, but it is not explicitly mentioned.<\|eot_id\|>` |
| 65 | `6614484` | `: 4367.<\|eot_id\|>` |
| 67 | `6510922` | `:7780<\|eot_id\|>` |
| 69 | `6649968` | `: 43610.<\|eot_id\|>` |
| 71 | `9437374` | `:<\|eot_id\|>` |
| 74 | `6625238` | `:1472908<\|eot_id\|>` |

---

### niah_multikey_2 (30 errors)

**Error sample IDs**: `2, 13, 21, 22, 23, 24, 25, 28, 32, 34, 38, 39, 40, 41, 42, 43, 45, 46, 47, 49, 50, 53, 54, 56, 57, 59, 60, 63, 64, 65`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 2 | `1535573` | `: 8651665.<\|eot_id\|>` |
| 13 | `2794159` | `: 5261593<\|eot_id\|>` |
| 21 | `8970232` | `:168<\|eot_id\|>` |
| 22 | `9134051` | `: 381:055: 381:055: 381:055: 381:055: 381:055: 381:055: 381:055: 38...` |
| 23 | `9696620` | `: 969662620969662, which is: 969662920, 96966220 is not actually me...` |
| 24 | 
`7071187` | ` 055055055.<\|eot_id\|>` |
| 25 | `5572782` | `: 5342494<\|eot_id\|>` |
| 28 | `4953027` | `:1687719<\|eot_id\|>` |
| 32 | `4259234` | `: 425923521250, but not found is: 425923751572250, however is: 4259...` |
| 34 | `3643022` | `: 3957500<\|eot_id\|>` |
| 38 | `2031469` | `: the text.<\|eot_id\|>` |
| 39 | `8740362` | `: 8740364 8740364 8740364 8740364 is: is: is: is: 874036...` |
| 40 | `7041770` | `:1682<\|eot_id\|>` |
| 41 | `1986258` | `:086.<\|eot_id\|>` |
| 42 | `5668574` | `:055.<\|eot_id\|>` |
| 43 | `8560471` | `:067<\|eot_id\|>` |
| 45 | `9973767` | `: 8420273<\|eot_id\|>` |
| 46 | `3960211` | `:0<\|eot_id\|>` |
| 47 | `8003271` | `: 60870870870870870870870870870870870870870870870870870870870870870...` |
| 49 | `8632309` | ` 303640 is640 640 640 640 640 640 640 640 640 640 640 640 640 640 640...` |
| 50 | `2318630` | `: 7780552.<\|eot_id\|>` |
| 53 | `3405052` | `:<\|eot_id\|>` |
| 54 | `5364945` | `: 536494, which is: 536494, which is: 536494494494494494494494494...` |
| 56 | `7319214` | `:7607607607607607607607607607607607607607607607607607607607607607607...` |
| 57 | `9206104` | `:7607607607607607607607607607607607607607607607607607607607607607607...` |
| 59 | `9555385` | `:7095<\|eot_id\|>` |
| 60 | `5727554` | `: 572755755755755755755755755755755755755755755755755755755755 is: 572...` |
| 63 | `1090767` | `:7607607607607607607607607607607607607607607607607607607607607607607...` |
| 64 | `6791240` | `:<\|eot_id\|>` |
| 65 | `7275999` | `:7607607607607607607607607607607607607607607607607607607607607607607...` |

---

### niah_multikey_3 (24 errors)

**Error sample IDs**: `11, 18, 20, 23, 24, 25, 26, 27, 29, 30, 33, 35, 37, 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 52`

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 11 | `c73ed342-6523-4d4b-aa33-beb1c9007315` | `: 1d28b88b-b6a8-46ba-8e8f-56cbafbfd897.<\|eot_id\|>` |
| 18 | `87b8a762-1d1f-4e85-a5d1-caf284c95aa6` | `: 
429a6676-5295-4ea2-a694-6aa949f48e31.<\|eot_id\|>` | +| 20 | `cce29702-134a-460c-979b-6f7ee7895280` | `:<\|eot_id\|>` | +| 23 | `ed344bfe-983f-4a21-af44-722e2517244c` | `: aec431e7d880a8dce2c023de24 is: aec43163-061a-4afe-b80a-f5bfb5e3c9...` | +| 24 | `4712ef99-a8d1-4388-8ca7-b08dd3505d77` | `:<\|eot_id\|>` | +| 25 | `46969ce7-0da0-49f8-87b2-845e7b8ef100` | `:<\|eot_id\|>` | +| 26 | `7cff3c66-6860-49e6-8ba5-002162c250c0` | `:4c7e-946b-30812edf965e<\|eot_id\|>` | +| 27 | `b63b4988-40bc-44b2-bf1c-ca95adbca4e9` | `:<\|eot_id\|>` | +| 29 | `6d94011c-f28a-4b0b-a2e2-fe34bb8b19a1` | `: 6d6d6d6d4b0e-52ce-44d9-a0f6-1ae405825615<\|eot_id\|>` | +| 30 | `7c33bb00-4ab4-4e4f-a78e-39f8f06d63eb` | ` d7a2-4b23-a2c0-8c859cb1fa96<\|eot_id\|>` | +| 33 | `b7c6b586-713a-4907-ad24-5c4f25aeb769` | `:1-4d2c-b42b-933ded2633d6<\|eot_id\|>` | +| 35 | `ac8a317b-a6bb-4327-90db-2a01622cb723` | `: d2f2f2f2f2f2f2f2d2d2f2d2d2d3d2f6b3d2f- is: d2dab is: is: is: i...` | +| 37 | `b187b337-3132-4376-a500-9340102092ae` | `:<\|eot_id\|>` | +| 40 | `2559fa56-dd0a-48d4-ba82-3ae2bf0a4b33` | `:358fe0e3-724e-4cfc-9ae0-d0873162626b.<\|eot_id\|>` | +| 41 | `7842feb5-e758-44cd-b73b-8ae08aa33142` | `: 6c6adf83-36a9-4e41-9cbe-60a8c9ffba92.<\|eot_id\|>` | +| 42 | `a1196139-f6fa-4c18-b3da-b7bd50362ac7` | `: a1196131396131196131399a1196139a1196139a1196139a1196139f6a1196139...` | +| 44 | `7d3d40b2-4594-4573-b267-4c6270dd4425` | `: 613a9e-4e7d-8c9f-740a630e3c53<\|eot_id\|>` | +| 45 | `500b8a75-8f05-43f5-b9ad-46d47d4e33fc` | `: 500b8a5e0e0e0a500b is: 500b is: 500b-4 is: is: is: is: is: i...` | +| 46 | `86a867a7-6a98-4a02-b065-70a33bafafde` | `:6139a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a...` | +| 47 | `7c0f7fd2-237e-4c0f-b3f5-f43623551169` | ` 5fb71d2f0f0b4f0 is: 5fb71 is: 5fb71f-4f-4f-4f-4f-4f-4d7 is: is: ...` | +| 48 | `b0e1f3f5-6570-437e-b8a1-f1b3f654e257` | `: 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b ...` | +| 49 | `0153722a-70a8-4ec0-9f03-2b0930937e60` | `: 500b 500b 500b 
500b 500b 500b 500b 500b 500b 500b 500b ...` |
| 50 | `0a1ead51-0c39-4eeb-ac87-d146acdb1d4a` | `: 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b ...` |
| 52 | `ff686e85-3a9f-4635-95dd-f19e8ca68eb1` | ` ff686e686e686e686e686e686f686e6f686e6fb686f686f686f686f686f- is: f...` |

---

### Error Pattern Analysis

| Error type | Example | Likely cause |
|------------|---------|--------------|
| **Empty output** | `:<\|eot_id\|>` | KV cache fully corrupted; the model cannot produce valid content |
| **Repeated digits** | `:361361361...` | A bad merge at a chunk boundary drives decode into a loop |
| **Partially correct** | `9874152` → `:151:52` | The first chunks are correct; later chunks are corrupted |
| **Digit insertion** | `2344047` → `23440447` | Numerical precision loss during merging |
| **Completely wrong** | Outputs an entirely unrelated UUID | Severe KV cache or position-encoding corruption |

---

### Verification Script

A script for checking which of the failing samples a fix repairs:

```python
# verify_samples.py
import json

ERROR_SAMPLES = {
    'niah_single_1': [28, 33, 39, 40, 41, 43, 44, 49, 51, 52, 53, 57, 61, 63, 65, 67, 72, 77, 83],
    'niah_single_2': [16, 24, 30, 32, 40, 41, 42, 50, 51, 52, 55, 58, 60, 62, 64, 66, 67, 68, 69, 77, 85, 91, 93],
    'niah_single_3': [7, 9, 14, 24, 25, 29, 31, 43],
    'niah_multikey_1': [20, 31, 32, 40, 41, 45, 51, 54, 59, 63, 64, 65, 67, 69, 71, 74],
    'niah_multikey_2': [2, 13, 21, 22, 23, 24, 25, 28, 32, 34, 38, 39, 40, 41, 42, 43, 45, 46, 47, 49, 50, 53, 54, 56, 57, 59, 60, 63, 64, 65],
    'niah_multikey_3': [11, 18, 20, 23, 24, 25, 26, 27, 29, 30, 33, 35, 37, 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 52],
}

def verify_fix(result_dir: str) -> dict:
    """Check which previously failing samples a fix repairs."""
    fixed = {}
    for task, error_ids in ERROR_SAMPLES.items():
        pred_file = f"{result_dir}/{task}.jsonl"
        with open(pred_file) as f:
            predictions = [json.loads(line) for line in f]

        fixed_count = 0
        still_error = []
        for idx in error_ids:
            pred = predictions[idx]
            expected = pred['answer']
            actual = pred['model_output']
            if expected in actual:
                fixed_count += 1
            else:
                still_error.append(idx)

        fixed[task] = {
            'total_errors': len(error_ids),
            'fixed': fixed_count,
            'still_error': still_error,
            'fix_rate': 
f"{fixed_count/len(error_ids)*100:.1f}%"
        }

    return fixed
```
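---

### Appendix: Merge Formula Sanity Check

The pairwise merge formula from Problem 2 is an exact identity in exact arithmetic: weighting each chunk's output by `exp(lse_i - max_lse)` reproduces attention over the full sequence, so the observed errors must come from finite-precision accumulation and the offload machinery, not from the formula itself. A minimal pure-Python sketch of this check (single query, scalar values; `chunk_attn` and `merge` are illustrative helpers, not nanovllm code):

```python
import math

def chunk_attn(scores, values):
    # attention over one chunk: softmax-weighted average of the values,
    # plus the chunk's log-sum-exp
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    out = sum(e * v for e, v in zip(exps, values)) / denom
    return out, m + math.log(denom)

def merge(o1, lse1, o2, lse2):
    # the pairwise merge formula from Problem 2
    m = max(lse1, lse2)
    e1, e2 = math.exp(lse1 - m), math.exp(lse2 - m)
    return (o1 * e1 + o2 * e2) / (e1 + e2), m + math.log(e1 + e2)

scores = [0.3, -1.2, 2.0, 0.7, -0.5, 1.1]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# reference: attention over the full sequence in one shot
o_full, lse_full = chunk_attn(scores, values)

# chunked: two chunks, merged pairwise
o_a, lse_a = chunk_attn(scores[:3], values[:3])
o_b, lse_b = chunk_attn(scores[3:], values[3:])
o_merged, lse_merged = merge(o_a, lse_a, o_b, lse_b)

assert abs(o_merged - o_full) < 1e-12
assert abs(lse_merged - lse_full) < 1e-12
print("pairwise merge matches full attention")
```

In float64 the merged result matches the one-shot result to ~1e-16, which is why the per-step relative error estimates in the mathematical analysis above are driven entirely by the working precision (fp32/bf16), not by the merge algebra.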