nano-vllm/docs/chunked_attention_solutions.md
Zijie Tian 6180055ed8 📝 docs: add chunked attention solutions guide and update doc index
Add comprehensive documentation analyzing the 32K chunked offload
accuracy issues with proposed solutions covering LSE precision,
ring buffer state management, and position encoding validation.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 04:48:20 +08:00


Chunked Attention Accuracy Issues and Solutions

Status: 🟡 IN PROGRESS Related Issue: ruler_32k_chunked_offload_issue.md Created: 2026-01-20 Author: Zijie Tian


Overview

This document draws on an in-depth analysis of the chunked attention code to describe the code issues likely responsible for the 20% error rate on the RULER 32K test, along with a proposed solution for each.

Core problem: when handling 32K-long sequences, the chunked prefill mechanism loses output accuracy through the cumulative effect of several factors.


Problem 1: How flash_attn_with_lse Obtains the LSE

Problem description

File: nanovllm/ops/chunked_attention.py, lines 264-269

def flash_attn_with_lse(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    from flash_attn.flash_attn_interface import flash_attn_func

    seqlen_q = q.shape[1]  # needed below to trim the rounded LSE
    # ...

    # Use return_attn_probs=True to obtain the LSE
    out, lse, _ = flash_attn_func(
        q, k, v,
        softmax_scale=softmax_scale,
        causal=causal,
        return_attn_probs=True,  # ⚠️ this API is not designed for LSE output
    )

    # lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
    lse = lse[:, :, :seqlen_q]
    return out, lse

Problem analysis

  1. API purpose mismatch: return_attn_probs=True exists mainly to return attention probabilities for debugging, not to provide a high-precision LSE output
  2. Potential precision loss: Flash Attention may apply internal optimizations to the LSE that trade away precision
  3. S_dmask return value: in some versions the third return value, S_dmask, may contain a random dropout mask that affects the result

Verification

def verify_lse_accuracy():
    """Check whether the LSE returned by flash_attn is accurate"""
    import torch
    from flash_attn.flash_attn_interface import flash_attn_func

    # Small sizes, so the exact computation is cheap
    batch, seqlen, nheads, headdim = 1, 128, 8, 64
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)

    # Flash Attention LSE
    _, lse_flash, _ = flash_attn_func(q, k, v, return_attn_probs=True)

    # Manually compute the LSE with PyTorch (in fp32, so it can serve as a reference)
    scale = 1.0 / (headdim ** 0.5)
    qk = torch.einsum('bqhd,bkhd->bhqk', q.float(), k.float()) * scale
    lse_manual = torch.logsumexp(qk, dim=-1)  # [batch, heads, seqlen_q]

    # Compare
    diff = (lse_flash[:, :, :seqlen].float() - lse_manual).abs()
    print(f"LSE max diff: {diff.max().item():.6f}")
    print(f"LSE mean diff: {diff.mean().item():.6f}")

    return diff.max().item() < 1e-3

Solutions

Option A: use FlashInfer's exact-LSE interface (recommended)

FlashInfer provides an interface designed for chunked attention:

def flash_attn_with_lse_flashinfer(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Obtain an exact LSE via FlashInfer's interface"""
    import math

    import flashinfer

    batch, seqlen_q, nheads_q, headdim = q.shape
    _, seqlen_k, nheads_kv, _ = k.shape
    # single_prefill operates on one request, so flattening batch into the
    # sequence dimension is only valid for batch == 1
    assert batch == 1, "single_prefill_with_kv_cache expects a single request"

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)

    # FlashInfer's single_prefill_with_kv_cache_return_lse is designed
    # to return an exact LSE
    out, lse = flashinfer.single_prefill_with_kv_cache_return_lse(
        q.view(batch * seqlen_q, nheads_q, headdim),
        k.view(batch * seqlen_k, nheads_kv, headdim),
        v.view(batch * seqlen_k, nheads_kv, headdim),
        causal=causal,
        sm_scale=softmax_scale,
    )

    # Reshape outputs
    out = out.view(batch, seqlen_q, nheads_q, headdim)
    lse = lse.view(batch, seqlen_q, nheads_q).transpose(1, 2)  # [batch, nheads, seqlen_q]

    return out, lse

Option B: custom Triton kernel

If FlashInfer is unavailable, use the custom Triton kernel:

# Already implemented in chunked_attention.py: _fwd_kernel_with_lse
# This kernel computes and emits the LSE directly in Triton
# Pros: full control over precision
# Cons: may be slower than Flash Attention

Option C: hybrid approach

def flash_attn_with_lse_hybrid(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    use_triton_kernel: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pick the best implementation for the situation:
    - short sequences (< 4K): Flash Attention (fast)
    - long sequences (>= 4K): Triton kernel (exact)
    """
    seqlen = q.shape[1]

    if seqlen < 4096 and not use_triton_kernel:
        return flash_attn_with_lse(q, k, v, softmax_scale, causal)
    else:
        return triton_flash_attn_with_lse(q, k, v, softmax_scale, causal)

Expected improvement

  • Fixing the LSE precision issue is expected to reduce the error rate by 5-10%

Problem 2: Error Accumulation from Too Many Sequential Merges

Problem description

File: nanovllm/kvcache/sparse/full_policy.py, lines 136-137 and 163-164

# merge loop inside compute_chunked_prefill
for block_idx in range(num_blocks):
    # ... load block ...
    prev_o, prev_lse = flash_attn_with_lse(
        q_batched, prev_k, prev_v,
        softmax_scale=softmax_scale,
        causal=False,
    )
    if o_acc is None:
        o_acc, lse_acc = prev_o, prev_lse
    else:
        # ⚠️ one merge is executed for every block
        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

Problem analysis

For a 32K context (block size 1024), roughly 32 merges are executed. Each merge evaluates:

max_lse = max(lse1, lse2)
exp1 = exp(lse1 - max_lse)
exp2 = exp(lse2 - max_lse)
o_merged = (o1 * exp1 + o2 * exp2) / (exp1 + exp2)
lse_merged = max_lse + log(exp1 + exp2)
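
This pairwise merge is mathematically exact: it reconstructs the softmax over the concatenated keys, so any deviation is purely floating-point rounding. A NumPy sketch with a single query row (toy sizes, independent of any attention library) confirms the identity:

```python
import numpy as np

rng = np.random.default_rng(0)
s1, s2 = rng.standard_normal(8), rng.standard_normal(8)            # per-chunk logits
v1, v2 = rng.standard_normal((8, 4)), rng.standard_normal((8, 4))  # per-chunk values

def chunk_out_lse(s, v):
    # Per-chunk softmax output and log-sum-exp
    lse = np.log(np.exp(s).sum())
    return np.exp(s - lse) @ v, lse

o1, lse1 = chunk_out_lse(s1, v1)
o2, lse2 = chunk_out_lse(s2, v2)

# Pairwise merge, exactly as in the formula above
max_lse = max(lse1, lse2)
e1, e2 = np.exp(lse1 - max_lse), np.exp(lse2 - max_lse)
o_merged = (o1 * e1 + o2 * e2) / (e1 + e2)

# Reference: softmax over the concatenated logits
s, v = np.concatenate([s1, s2]), np.concatenate([v1, v2])
o_ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ v
assert np.allclose(o_merged, o_ref)
```

In float64 the two agree to machine precision; the accuracy problem only appears when the same algebra runs in reduced precision, as analyzed below.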

Error sources:

  1. exp() and log() lose precision in edge cases
  2. the division / (exp1 + exp2) can itself lose precision
  3. after 32 accumulations, the error can reach a significant level

Mathematical analysis

Assume each merge introduces a relative error ε ≈ 10^-6 (fp32):

  • after 32 accumulations: total error ≈ 32 × ε = 3.2 × 10^-5
  • for bfloat16 (7-bit mantissa): ε ≈ 10^-2, so total error ≈ 0.32 (32%)
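
The linear-accumulation estimate can be sanity-checked without a GPU by running the same sequential merge in float16 and in float64 against an exact reference (a NumPy sketch; the exact numbers depend on the data, so treat them as indicative):

```python
import numpy as np

rng = np.random.default_rng(0)
n_chunks, chunk, d = 32, 64, 8
s = rng.standard_normal(n_chunks * chunk)
v = rng.standard_normal((n_chunks * chunk, d))

# Exact reference in float64
w = np.exp(s - s.max())
o_ref = (w / w.sum()) @ v

def sequential_merge(dtype):
    # Chained pairwise merges, rounding every intermediate to `dtype`
    o_acc = lse_acc = None
    for i in range(n_chunks):
        sc = s[i*chunk:(i+1)*chunk].astype(dtype)
        vc = v[i*chunk:(i+1)*chunk].astype(dtype)
        lse = np.log(np.exp(sc).sum())
        o = (np.exp(sc - lse) @ vc).astype(dtype)
        if o_acc is None:
            o_acc, lse_acc = o, lse
        else:
            m = max(lse_acc, lse)
            e1, e2 = np.exp(lse_acc - m), np.exp(lse - m)
            o_acc = ((o_acc * e1 + o * e2) / (e1 + e2)).astype(dtype)
            lse_acc = m + np.log(e1 + e2)
    return o_acc

err64 = np.abs(sequential_merge(np.float64) - o_ref).max()
err16 = np.abs(sequential_merge(np.float16) - o_ref).max()
assert err64 < 1e-12        # high precision: merging is essentially exact
assert err16 > 100 * err64  # low precision: each merge adds rounding error
```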

Verification

def measure_merge_error_accumulation():
    """Measure how the merge error varies with the number of chunks"""
    import torch
    from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
    from flash_attn.flash_attn_interface import flash_attn_func

    torch.manual_seed(42)

    batch, seqlen, nheads, headdim = 1, 32768, 32, 128
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)

    # Reference: full attention
    out_ref = flash_attn_func(q, k, v, causal=False)

    results = []
    for num_chunks in [4, 8, 16, 32, 64]:
        chunk_size = seqlen // num_chunks
        o_acc, lse_acc = None, None

        for i in range(num_chunks):
            start = i * chunk_size
            end = (i + 1) * chunk_size
            k_chunk = k[:, start:end, :, :]
            v_chunk = v[:, start:end, :, :]

            chunk_o, chunk_lse = flash_attn_with_lse(q, k_chunk, v_chunk, causal=False)

            if o_acc is None:
                o_acc, lse_acc = chunk_o, chunk_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, chunk_o, chunk_lse)

        diff = (out_ref - o_acc).abs()
        results.append({
            'num_chunks': num_chunks,
            'max_diff': diff.max().item(),
            'mean_diff': diff.mean().item(),
        })
        print(f"chunks={num_chunks:3d}, max_diff={diff.max().item():.6f}, mean_diff={diff.mean().item():.8f}")

    return results

Solutions

Option A: N-way merge (recommended)

Instead of chaining pairwise merges, merge all chunks in one shot:

def merge_attention_outputs_nway(
    outputs: List[torch.Tensor],  # List of [batch, seqlen_q, nheads, headdim]
    lses: List[torch.Tensor],     # List of [batch, nheads, seqlen_q]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    N-way merge: merge all chunks at once to reduce accumulated error

    Math:
    - max_lse = max(lse_1, lse_2, ..., lse_n)
    - sum_exp = sum(exp(lse_i - max_lse) for i in 1..n)
    - o_merged = sum(o_i * exp(lse_i - max_lse) for i in 1..n) / sum_exp
    - lse_merged = max_lse + log(sum_exp)
    """
    n = len(outputs)
    if n == 0:
        raise ValueError("Need at least one output to merge")
    if n == 1:
        return outputs[0], lses[0]

    # Stack for batch processing
    o_stack = torch.stack(outputs, dim=0)  # [n, batch, seqlen, heads, dim]
    lse_stack = torch.stack(lses, dim=0)    # [n, batch, heads, seqlen]

    # Compute max_lse across all chunks (in fp32 for precision)
    lse_fp32 = lse_stack.float()
    max_lse = lse_fp32.max(dim=0).values  # [batch, heads, seqlen]

    # Compute exp(lse_i - max_lse) for each chunk
    exp_weights = torch.exp(lse_fp32 - max_lse.unsqueeze(0))  # [n, batch, heads, seqlen]

    # Normalize weights
    sum_exp = exp_weights.sum(dim=0)  # [batch, heads, seqlen]
    normalized_weights = exp_weights / sum_exp.unsqueeze(0)  # [n, batch, heads, seqlen]

    # Weighted sum of outputs
    # outputs: [n, batch, seqlen, heads, dim]
    # weights: [n, batch, heads, seqlen] -> [n, batch, seqlen, heads, 1]
    weights_expanded = normalized_weights.permute(0, 1, 3, 2).unsqueeze(-1)
    o_merged = (o_stack.float() * weights_expanded).sum(dim=0).to(outputs[0].dtype)

    # Compute merged LSE
    lse_merged = (max_lse + torch.log(sum_exp)).to(lses[0].dtype)

    return o_merged, lse_merged

Usage:

# Before: merge one chunk at a time
for block_idx in range(num_blocks):
    chunk_o, chunk_lse = flash_attn_with_lse(...)
    if o_acc is None:
        o_acc, lse_acc = chunk_o, chunk_lse
    else:
        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, chunk_o, chunk_lse)

# After: collect all chunks, then merge once
chunk_outputs = []
chunk_lses = []
for block_idx in range(num_blocks):
    chunk_o, chunk_lse = flash_attn_with_lse(...)
    chunk_outputs.append(chunk_o)
    chunk_lses.append(chunk_lse)

o_acc, lse_acc = merge_attention_outputs_nway(chunk_outputs, chunk_lses)
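
That a single n-way merge reproduces the exact result can be checked numerically with a NumPy stand-in for the GPU kernels (toy sizes; note that collecting every chunk output before merging costs O(n) extra memory, which is the trade-off of this approach):

```python
import numpy as np

rng = np.random.default_rng(1)
n, chunk, d = 16, 32, 4
s = rng.standard_normal(n * chunk).astype(np.float32)
v = rng.standard_normal((n * chunk, d)).astype(np.float32)

# Per-chunk outputs and LSEs in float32, standing in for the attention kernels
outs, lses = [], []
for i in range(n):
    sc, vc = s[i*chunk:(i+1)*chunk], v[i*chunk:(i+1)*chunk]
    lse = np.log(np.exp(sc).sum())
    outs.append(np.exp(sc - lse) @ vc)
    lses.append(lse)

# N-way merge: one global max and one normalization over all chunks
lse_arr = np.array(lses, dtype=np.float64)  # promote once, like the fp32 path above
w = np.exp(lse_arr - lse_arr.max())
o_nway = (np.stack(outs).astype(np.float64) * w[:, None]).sum(0) / w.sum()

# Exact reference over the full sequence
e = np.exp(s.astype(np.float64) - s.max())
o_ref = (e / e.sum()) @ v.astype(np.float64)
assert np.abs(o_nway - o_ref).max() < 1e-5
```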

Option B: Kahan summation

Use Kahan summation to reduce accumulated floating-point error:

def merge_attention_outputs_kahan(
    o1: torch.Tensor, lse1: torch.Tensor,
    o2: torch.Tensor, lse2: torch.Tensor,
    compensation: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Merge with Kahan-style compensation, carrying the accumulated
    rounding-error term across calls.

    Returns:
        o_merged, lse_merged, (output_comp, lse_comp)
    """
    # ... standard merge computation ...
    max_lse = torch.maximum(lse1.float(), lse2.float())
    exp1 = torch.exp(lse1.float() - max_lse)
    exp2 = torch.exp(lse2.float() - max_lse)
    sum_exp = exp1 + exp2

    # lse is [batch, nheads, seqlen_q] while o is [batch, seqlen_q, nheads, headdim],
    # so the weights must be transposed before broadcasting
    w1 = exp1.transpose(1, 2).unsqueeze(-1)
    w2 = exp2.transpose(1, 2).unsqueeze(-1)
    term1 = o1.float() * w1
    term2 = o2.float() * w2

    if compensation is None:
        output_comp = torch.zeros_like(term1)
        lse_comp = torch.zeros_like(max_lse)
    else:
        output_comp, lse_comp = compensation

    # Kahan step for the weighted-numerator sum: the new compensation term
    # captures the low-order bits lost in term1 + y
    y = term2 - output_comp
    t = term1 + y
    output_comp_new = (t - term1) - y

    o_merged_raw = t / sum_exp.transpose(1, 2).unsqueeze(-1)
    lse_merged_raw = max_lse + torch.log(sum_exp)
    lse_comp_new = lse_comp  # LSE compensation is left untouched in this sketch

    return o_merged_raw.to(o1.dtype), lse_merged_raw.to(lse1.dtype), (output_comp_new, lse_comp_new)

Option C: hierarchical merge (compromise)

Group the 32 chunks, merge within each group first, then merge across groups:

Before: c1 -> c2 -> c3 -> ... -> c32 (31 merges)
After:
  Group1: c1-c4 -> g1 (3 merges)
  Group2: c5-c8 -> g2 (3 merges)
  ...
  Group8: c29-c32 -> g8 (3 merges)
  Final: g1-g8 -> result (7 merges)
  Total: 3*8 + 7 = 31 merges, but the maximum accumulation depth drops from 31 to 7

def merge_attention_outputs_hierarchical(
    outputs: List[torch.Tensor],
    lses: List[torch.Tensor],
    group_size: int = 4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Hierarchical merge: merge within groups first, then across groups
    Reduces the accumulation depth from O(n) to O(log n)
    """
    from nanovllm.ops.chunked_attention import merge_attention_outputs

    n = len(outputs)
    if n == 0:
        raise ValueError("Need at least one output to merge")
    if n == 1:
        return outputs[0], lses[0]

    # Level 1: Merge within groups
    group_outputs = []
    group_lses = []

    for i in range(0, n, group_size):
        group = outputs[i:i+group_size]
        group_l = lses[i:i+group_size]

        o_g, lse_g = group[0], group_l[0]
        for j in range(1, len(group)):
            o_g, lse_g = merge_attention_outputs(o_g, lse_g, group[j], group_l[j])

        group_outputs.append(o_g)
        group_lses.append(lse_g)

    # Recursively merge groups
    if len(group_outputs) > 1:
        return merge_attention_outputs_hierarchical(group_outputs, group_lses, group_size)
    else:
        return group_outputs[0], group_lses[0]

Expected improvement

  • N-way merge: expected to reduce the error rate by 10-15%
  • Hierarchical merge: expected to reduce the error rate by 5-8%

Problem 3: Too Few Ring Buffer Slots

Problem description

File: nanovllm/kvcache/offload_engine.py, location: ring buffer initialization

# 2-slot configuration (original)
num_gpu_blocks = 2  # only 2 GPU blocks for the ring buffer

# Problems:
# - Slot 0: used for decode
# - Slot 1: used for prefill loading
# - 32 chunks share a single prefill slot
# - high-frequency CPU<->GPU transfers can create race conditions

Problem analysis

  1. Race condition: while chunk N's compute is still running, chunk N+1's transfer may overwrite the same slot
  2. Synchronization overhead: frequent wait_slot_layer() calls add latency
  3. Already verified: a 4-slot configuration improves things markedly but does not fully fix the problem

Verification results (completed)

| Config      | niah_single_1 accuracy | niah_multikey_3 accuracy |
|-------------|------------------------|--------------------------|
| 2-slot      | 94%                    | 48%                      |
| 4-slot      | 98%                    | 56%                      |
| Improvement | +4%                    | +8%                      |

Key finding: Sample 40 still produces the same wrong output (6171717161711716) under the 4-slot configuration, so the slot count is not the only cause.

Solutions

Option A: increase the slot count to 8

# offload_engine.py configuration
class OffloadEngineConfig:
    # before
    # num_ring_slots = 2

    # proposed
    num_ring_slots = 8  # increase to 8 slots

    # or compute it dynamically
    @property
    def num_ring_slots(self):
        # adapt to the available GPU memory
        available_memory_gb = get_available_gpu_memory()
        slot_memory_gb = self.block_size * self.kv_cache_dtype_size * 2 / 1e9
        max_slots = int(available_memory_gb * 0.3 / slot_memory_gb)  # spend 30% of memory
        return min(max(max_slots, 4), 16)  # clamp to 4-16
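
The clamping arithmetic of the dynamic property can be sketched in isolation (the function name and the memory figures below are illustrative, not part of the codebase):

```python
def pick_num_ring_slots(available_gb: float, slot_gb: float,
                        budget_frac: float = 0.3, lo: int = 4, hi: int = 16) -> int:
    """Spend a fraction of free GPU memory on ring slots, clamped to [lo, hi]."""
    max_slots = int(available_gb * budget_frac / slot_gb)
    return min(max(max_slots, lo), hi)

assert pick_num_ring_slots(40.0, 0.5) == 16  # plenty of memory: capped at 16
assert pick_num_ring_slots(2.0, 0.5) == 4    # tight memory: floor of 4 still applies
assert pick_num_ring_slots(10.0, 0.5) == 6   # in between: proportional to the budget
```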

Option B: improve the synchronization mechanism

def load_to_slot_layer_with_barrier(self, slot: int, layer_id: int, cpu_block_id: int):
    """
    Load with an explicit barrier, guaranteeing the previous operation has finished
    """
    # Wait for all pending operations on this slot to complete
    if self.slot_events[slot] is not None:
        self.slot_events[slot].synchronize()

    # Perform the transfer
    self._do_transfer(slot, layer_id, cpu_block_id)

    # Record a new event
    self.slot_events[slot] = torch.cuda.Event()
    self.slot_events[slot].record(self.transfer_stream)

Option C: double buffering

class DoubleBufferedRingBuffer:
    """
    Double-buffered design: each logical slot has two physical buffers
    - one for the current compute
    - one for the next transfer
    """
    def __init__(self, num_logical_slots: int, ...):
        self.num_logical_slots = num_logical_slots
        self.num_physical_slots = num_logical_slots * 2
        self.current_buffer = [0] * num_logical_slots  # buffer currently used by each slot (0 or 1)

    def get_compute_slot(self, logical_slot: int) -> int:
        """Physical slot used for compute"""
        return logical_slot * 2 + self.current_buffer[logical_slot]

    def get_transfer_slot(self, logical_slot: int) -> int:
        """Physical slot used for transfer (the other buffer)"""
        return logical_slot * 2 + (1 - self.current_buffer[logical_slot])

    def swap_buffer(self, logical_slot: int):
        """Swap the buffers"""
        self.current_buffer[logical_slot] = 1 - self.current_buffer[logical_slot]
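
The slot bookkeeping can be exercised without any GPU state; this minimal stand-in keeps only the index arithmetic and shows that the compute and transfer buffers never alias:

```python
class DoubleBufferedRingBuffer:
    """Index bookkeeping only: each logical slot maps to two physical buffers."""

    def __init__(self, num_logical_slots: int):
        self.current_buffer = [0] * num_logical_slots  # active buffer per slot

    def get_compute_slot(self, logical_slot: int) -> int:
        # Physical buffer the compute stream is reading from
        return logical_slot * 2 + self.current_buffer[logical_slot]

    def get_transfer_slot(self, logical_slot: int) -> int:
        # The other physical buffer, free for the next transfer
        return logical_slot * 2 + (1 - self.current_buffer[logical_slot])

    def swap_buffer(self, logical_slot: int) -> None:
        # Flip roles once the pending transfer has completed
        self.current_buffer[logical_slot] ^= 1

rb = DoubleBufferedRingBuffer(4)
assert rb.get_compute_slot(1) == 2 and rb.get_transfer_slot(1) == 3
rb.swap_buffer(1)  # transfer done: the freshly filled buffer becomes the compute buffer
assert rb.get_compute_slot(1) == 3 and rb.get_transfer_slot(1) == 2
```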

Expected improvement

  • 8-slot configuration: expected to cut the error rate by a further 3-5%
  • Double buffering: expected to cut it by a further 2-3%

Problem 4: Sparse Policy Mapping Bug

Problem description

File: COMPASS/eval/RULER/scripts/pred/call_api.py (COMPASS project), location: the metric-to-policy mapping

metric_to_policy = {
    'full': 'FULL',
    'xattn': 'XATTN',       # ❌ XATTN does not exist in SparsePolicyType
    'compass': 'XATTN',
    'minfer': 'MINFERENCE', # ❌ MINFERENCE does not exist in SparsePolicyType
    'avgpool': 'FULL',
    'flex': 'FLEXPREFILL',
}

sparse_policy_name = metric_to_policy.get(self.metric, 'FULL')
sparse_policy = getattr(SparsePolicyType, sparse_policy_name, SparsePolicyType.FULL)
# ⚠️ silently falls back to FULL, with no warning!

Problem analysis

  1. Silent fallback: the default in getattr(..., SparsePolicyType.FULL) lets invalid names fall back silently
  2. Misconfigured tests: the user believes they are testing xattn but is actually testing FULL
  3. Unreliable results: every sparse method's test results may in fact be FULL's results
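
The silent fallback is easy to reproduce with a stand-in enum (the members below are illustrative; check the real SparsePolicyType for the actual names):

```python
from enum import Enum, auto

class SparsePolicyType(Enum):  # stand-in for the real enum, for illustration only
    FULL = auto()
    XATTN_BSA = auto()

# getattr with a default never raises, so a stale name degrades silently:
policy = getattr(SparsePolicyType, "XATTN", SparsePolicyType.FULL)
assert policy is SparsePolicyType.FULL  # the benchmark quietly ran the FULL baseline

# Without the default, the bad name is caught immediately:
try:
    getattr(SparsePolicyType, "XATTN")
except AttributeError:
    pass  # this is the failure mode we actually want for a misconfigured test
else:
    raise AssertionError("expected AttributeError")
```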

Verification

# Check the currently valid values of SparsePolicyType
from nanovllm.kvcache.sparse.policy import SparsePolicyType
print([e.name for e in SparsePolicyType])
# Expected output is something like: ['FULL', 'QUEST', 'XATTN_BSA', ...]

# Verify that the mapping is correct
for metric, policy_name in metric_to_policy.items():
    has_policy = hasattr(SparsePolicyType, policy_name)
    print(f"{metric} -> {policy_name}: {'✓' if has_policy else '✗'}")

Solution

Fix the mapping and add validation:

# Corrected mapping (update it to match the SparsePolicyType definition)
metric_to_policy = {
    'full': 'FULL',
    'xattn': 'XATTN_BSA',    # ✓ use the correct enum name
    'compass': 'XATTN_BSA',
    'minfer': 'MINFERENCE',  # confirm this exists
    'avgpool': 'AVGPOOL',    # confirm this exists
    'flex': 'FLEXPREFILL',   # confirm this exists
    'quest': 'QUEST',        # add quest
}

def get_sparse_policy(metric: str) -> SparsePolicyType:
    """Resolve the sparse policy for a metric, with validation"""
    policy_name = metric_to_policy.get(metric)
    if policy_name is None:
        raise ValueError(f"Unknown metric: {metric}. Valid options: {list(metric_to_policy.keys())}")

    if not hasattr(SparsePolicyType, policy_name):
        raise ValueError(f"SparsePolicyType.{policy_name} does not exist. "
                        f"Available: {[e.name for e in SparsePolicyType]}")

    return getattr(SparsePolicyType, policy_name)

Expected improvement

  • This bug does not affect chunked-offload accuracy
  • But once fixed, the various sparse methods can actually be tested as intended

Problem 5: Decode Buffer Read Boundary Condition

Problem description

File: nanovllm/kvcache/sparse/full_policy.py, lines 269-272

def compute_chunked_decode(self, ...):
    # ...
    seq_len = len(seq)
    decode_pos_in_block = (seq_len - 1) % block_size
    decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
    decode_start_pos_in_block = decode_start_pos % block_size
    num_accumulated = decode_pos_in_block - decode_start_pos_in_block + 1

    # ⚠️ potential problem: when decode_pos_in_block < decode_start_pos_in_block,
    # num_accumulated goes negative

Problem analysis

Scenario: decode crosses a block boundary:

  • suppose block_size = 1024
  • decode_start_pos = 1020 (near the block boundary)
  • current seq_len = 1026 (spilled into the next block)
  • decode_pos_in_block = (1026 - 1) % 1024 = 1
  • decode_start_pos_in_block = 1020 % 1024 = 1020
  • num_accumulated = 1 - 1020 + 1 = -1018
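
The arithmetic above can be reproduced directly, alongside the absolute-position formula that stays positive across block boundaries:

```python
block_size = 1024
seq_len = 1026            # decode has spilled into the next block
decode_start_pos = 1020

# In-block formula from the current code:
decode_pos_in_block = (seq_len - 1) % block_size           # 1
decode_start_pos_in_block = decode_start_pos % block_size  # 1020
num_accumulated = decode_pos_in_block - decode_start_pos_in_block + 1
assert num_accumulated == -1018  # negative once the boundary is crossed

# Absolute-position formula from the proposed fix:
assert seq_len - decode_start_pos == 6  # stays positive across boundaries
```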

Solution

def compute_chunked_decode(self, ...):
    # ...
    seq_len = len(seq)
    decode_start_pos = kvcache_manager.get_decode_start_pos(seq)

    # Count the tokens actually held in the decode buffer,
    # using absolute positions rather than in-block offsets
    num_accumulated = seq_len - decode_start_pos

    # Validate
    assert num_accumulated > 0, f"Invalid decode range: seq_len={seq_len}, decode_start={decode_start_pos}"

    # In-block offsets used for reading
    decode_start_pos_in_block = decode_start_pos % block_size
    decode_end_pos_in_block = (seq_len - 1) % block_size

    # Handle the cross-block case
    if decode_end_pos_in_block < decode_start_pos_in_block:
        # Cross-block: would need to read from two locations;
        # this should be handled at the kvcache_manager level
        raise NotImplementedError("Cross-block decode not yet supported")

    # Same block: read normally
    decode_k = offload_engine.decode_k_buffer[
        layer_id,
        decode_start_pos_in_block:decode_end_pos_in_block+1
    ]
    decode_v = offload_engine.decode_v_buffer[
        layer_id,
        decode_start_pos_in_block:decode_end_pos_in_block+1
    ]

Expected improvement

  • Fixing the boundary condition should eliminate crashes or bad output in these specific cases

Problem 6: Merge Kernel Block Size Tuning

Problem description

File: nanovllm/ops/chunked_attention.py, line 406

def merge_attention_outputs(...):
    # ...
    # Launch output merge kernel
    BLOCK_SIZE = 128  # fixed value
    grid_output = (batch, seqlen_q, nheads)
    _merge_output_kernel[grid_output](
        o1, o2, lse1, lse2, o_merged,
        batch, seqlen_q, nheads, headdim,
        BLOCK_SIZE=BLOCK_SIZE,
    )

Problem analysis

  1. Fixed BLOCK_SIZE: 128 may not be optimal for every headdim
  2. Performance impact: a poor BLOCK_SIZE can lower GPU occupancy
  3. Accuracy impact: some BLOCK_SIZE values may cause boundary-handling problems

Solution

def merge_attention_outputs(...):
    batch, seqlen_q, nheads, headdim = o1.shape

    # Choose BLOCK_SIZE dynamically
    # Rule of thumb: BLOCK_SIZE should divide headdim and suit the GPU warp size (32)
    if headdim <= 64:
        BLOCK_SIZE = 64
    elif headdim <= 128:
        BLOCK_SIZE = 128
    else:
        BLOCK_SIZE = 256

    # Make sure BLOCK_SIZE divides headdim
    while headdim % BLOCK_SIZE != 0 and BLOCK_SIZE > 32:
        BLOCK_SIZE //= 2

    # ...
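
The selection heuristic can be exercised as plain Python (a sketch of the logic above; note that it can still return a non-divisor for unusual head dims such as 80, since the loop stops at the warp size):

```python
def pick_block_size(headdim: int) -> int:
    """Dynamic BLOCK_SIZE selection mirroring the heuristic above."""
    if headdim <= 64:
        block = 64
    elif headdim <= 128:
        block = 128
    else:
        block = 256
    # Shrink until it divides headdim (but not below the warp size, 32)
    while headdim % block != 0 and block > 32:
        block //= 2
    return block

assert pick_block_size(64) == 64
assert pick_block_size(128) == 128
assert pick_block_size(96) == 32  # 96 is not divisible by 128 or 64
```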

Expected improvement

  • This is a performance optimization; its effect on accuracy is small

Overall Solution Priorities

| Priority | Problem | Solution | Expected improvement | Difficulty |
|----------|---------|----------|----------------------|------------|
| P0 | N-way merge | implement merge_attention_outputs_nway | 10-15% | |
| P0 | LSE precision | use FlashInfer, or validate the flash_attn LSE | 5-10% | |
| P1 | Ring buffer slots | increase to 8 slots + double buffering | 3-5% | |
| P1 | Policy mapping | fix the mapping bug in COMPASS | N/A | |
| P2 | Decode boundary | add boundary checks and handling | 1-2% | |
| P2 | Merge block size | select BLOCK_SIZE dynamically | <1% | |

Verification Plan

Phase 1: error tracing (run this first)

Before making any changes, add detailed error tracing:

# Add to full_policy.py
def compute_chunked_prefill_with_debug(self, ...):
    debug_info = {
        'chunk_max_diffs': [],
        'lse_values': [],
        'merge_count': 0,
    }

    # ... original logic ...

    for block_idx in range(num_blocks):
        # ... compute ...

        if o_acc is not None:
            # Record how much each merge changes the accumulator
            pre_merge_o = o_acc.clone()
            o_acc, lse_acc = merge_attention_outputs(...)
            debug_info['chunk_max_diffs'].append((o_acc - pre_merge_o).abs().max().item())
            debug_info['lse_values'].append(lse_acc.mean().item())
            debug_info['merge_count'] += 1

    # Save the debug info
    torch.save(debug_info, f'debug_prefill_layer{layer_id}.pt')

Phase 2: one change at a time

Apply a single change per run and measure its effect:

  1. N-way merge only -> measure accuracy
  2. more ring buffer slots only -> measure accuracy
  3. FlashInfer LSE only -> measure accuracy

Phase 3: combined test

Combine the best-performing changes and check whether the gains stack:

Final config = N-way merge + 8 slots + FlashInfer LSE
Target: error rate < 10% (close to the 8% xattn_stride8 baseline)

Appendix: Related File Index

| File | Key functions/classes | Problem(s) |
|------|-----------------------|------------|
| nanovllm/ops/chunked_attention.py | flash_attn_with_lse, merge_attention_outputs | 1, 2, 6 |
| nanovllm/kvcache/sparse/full_policy.py | compute_chunked_prefill, compute_chunked_decode | 2, 5 |
| nanovllm/kvcache/offload_engine.py | ring buffer management | 3 |
| COMPASS/eval/RULER/scripts/pred/call_api.py | metric_to_policy | 4 |

Appendix: Failing Sample Details

The RULER 32K samples that failed, kept here for verifying later fixes.

Error statistics

| Task | Total Samples | Errors | Error Rate |
|------|---------------|--------|------------|
| niah_single_1 | 100 | 19 | 19% |
| niah_single_2 | 100 | 23 | 23% |
| niah_single_3 | 100 | 8 | 8% |
| niah_multikey_1 | 100 | 16 | 16% |
| niah_multikey_2 | 100 | 30 | 30% |
| niah_multikey_3 | 100 | 24 | 24% |
| TOTAL | 600 | 120 | 20% |

niah_single_1 (19 errors)

Error sample IDs: 28, 33, 39, 40, 41, 43, 44, 49, 51, 52, 53, 57, 61, 63, 65, 67, 72, 77, 83

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 28 | 9874152 | :151:52<\|eot_id\|> |
| 33 | 9196204 | :<\|eot_id\|> |
| 39 | 3484601 | :<\|eot_id\|> |
| 40 | 6171716 | : 17: 16<\|eot_id\|> |
| 41 | 4524499 | :<\|eot_id\|> |
| 43 | 3726327 | : 16: 7<\|eot_id\|> |
| 44 | 4009172 | : 2<\|eot_id\|> |
| 49 | 4240180 | :354:180<\|eot_id\|> |
| 51 | 9546409 | :<\|eot_id\|> |
| 52 | 2935113 | : 29351113.<\|eot_id\|> |
| 53 | 5453786 | :354:678:90<\|eot_id\|> |
| 57 | 8315831 | : 5831<\|eot_id\|> |
| 61 | 5960271 | : 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,...<\|eot_id\|> |
| 63 | 6049101 | : 5 0 4 9 1 0 1<\|eot_id\|> |
| 65 | 6406444 | :361361361361361361361361361361361361361361361361361361361361361361361361361361...<\|eot_id\|> |
| 67 | 2422633 | :31<\|eot_id\|> |
| 72 | 7442089 | 7953166<\|eot_id\|> |
| 77 | 8795419 | :<\|eot_id\|> |
| 83 | 6363836 | : 2<\|eot_id\|> |

niah_single_2 (23 errors)

Error sample IDs: 16, 24, 30, 32, 40, 41, 42, 50, 51, 52, 55, 58, 60, 62, 64, 66, 67, 68, 69, 77, 85, 91, 93

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 16 | 2344047 | : 23440447.<\|eot_id\|> |
| 24 | 5449324 | :<\|eot_id\|> |
| 30 | 5727085 | :<\|eot_id\|> |
| 32 | 9196204 | :<\|eot_id\|> |
| 40 | 4524499 | :460<\|eot_id\|> |
| 41 | 7817881 | :171.<\|eot_id\|> |
| 42 | 3726327 | :<\|eot_id\|> |
| 50 | 9546409 | :<\|eot_id\|> |
| 51 | 2935113 | : 3: 5113<\|eot_id\|> |
| 52 | 5453786 | :354<\|eot_id\|> |
| 55 | 4188992 | : 418899189418899, but it is not explicitly stated... |
| 58 | 6266630 | :5963<\|eot_id\|> |
| 60 | 5960271 | 0271<\|eot_id\|> |
| 62 | 6049101 | :<\|eot_id\|> |
| 64 | 6406444 | :<\|eot_id\|> |
| 66 | 2422633 | :5313<\|eot_id\|> |
| 67 | 4940441 | :5311<\|eot_id\|> |
| 68 | 3472189 | :361.<\|eot_id\|> |
| 69 | 8971465 | :361.<\|eot_id\|> |
| 77 | 8963715 | : 0 8 9 7 1 5<\|eot_id\|> |
| 85 | 2044645 | : 20446445.<\|eot_id\|> |
| 91 | 7783308 | :<\|eot_id\|> |
| 93 | 1454696 | :<\|eot_id\|> |

niah_single_3 (8 errors)

Error sample IDs: 7, 9, 14, 24, 25, 29, 31, 43

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 7 | ee87905e-4ca4-45ea-8dfa-6a56d12dbc9a | : 2010-07-01T00:00:00Z<\|eot_id\|> |
| 9 | b7b56ea7-35eb-432d-9ad6-20ab48212ddb | :0:0:0:0:0:0:0:0:0:0:0:0:0:0:0:0<\|eot_id\|> |
| 14 | e767dcea-b0e6-4969-a213-42b0f1eedba3 | :0e6-4969-a213-42b0f1eedba3<\|eot_id\|> |
| 24 | 59e4b671-4774-4c58-85f8-bc16f7860b50 | :4774:4c58:85f8:bc16f7860b50<\|eot_id\|> |
| 25 | 54c63cd8-8945-4f27-97fa-2d8dfb2ca025 | : 54c63c63cd8-8945-4f27-97fa-2d8dfb2ca025.<\|eot_id\|> |
| 29 | 006ed6e3-6fa1-4735-b572-f3d00b5cea6a | :6e3-6fa1-4735-b572-f3d00b5cea6a<\|eot_id\|> |
| 31 | e6697833-b841-40a0-9fe7-71d6d9178793 | : e6697837837833-b841-40a0-9fe7-71d6d9178793.<\|eot_id\|> |
| 43 | d92c9227-eadf-4085-bfcb-75468eb22579 | : d92c922c9227-eadf-4085-bfcb-75468eb22579.<\|eot_id\|> |

niah_multikey_1 (16 errors)

Error sample IDs: 20, 31, 32, 40, 41, 45, 51, 54, 59, 63, 64, 65, 67, 69, 71, 74

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 20 | 2171218 | : 2171212181212181212181218<\|eot_id\|> |
| 31 | 9333700 | :<\|eot_id\|> |
| 32 | 7121355 | :9651<\|eot_id\|> |
| 40 | 3112652 | :285<\|eot_id\|> |
| 41 | 3427461 | :<\|eot_id\|> |
| 45 | 8217547 | :<\|eot_id\|> |
| 51 | 1514340 | : 1514343403361.<\|eot_id\|> |
| 54 | 8212753 | :<\|eot_id\|> |
| 59 | 6587964 | :<\|eot_id\|> |
| 63 | 1688246 | :<\|eot_id\|> |
| 64 | 8344365 | : 834436, but it is not explicitly mentioned.<\|eot_id\|> |
| 65 | 6614484 | : 4367.<\|eot_id\|> |
| 67 | 6510922 | :7780<\|eot_id\|> |
| 69 | 6649968 | : 43610.<\|eot_id\|> |
| 71 | 9437374 | :<\|eot_id\|> |
| 74 | 6625238 | :1472908<\|eot_id\|> |

niah_multikey_2 (30 errors)

Error sample IDs: 2, 13, 21, 22, 23, 24, 25, 28, 32, 34, 38, 39, 40, 41, 42, 43, 45, 46, 47, 49, 50, 53, 54, 56, 57, 59, 60, 63, 64, 65

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 2 | 1535573 | : 8651665.<\|eot_id\|> |
| 13 | 2794159 | : 5261593<\|eot_id\|> |
| 21 | 8970232 | :168<\|eot_id\|> |
| 22 | 9134051 | : 381:055: 381:055: 381:055: 381:055: 381:055: 381:055: 381:055: 38... |
| 23 | 9696620 | : 969662620969662, which is: 969662920, 96966220 is not actually me... |
| 24 | 7071187 | 055055055.<\|eot_id\|> |
| 25 | 5572782 | : 5342494<\|eot_id\|> |
| 28 | 4953027 | :1687719<\|eot_id\|> |
| 32 | 4259234 | : 425923521250, but not found is: 425923751572250, however is: 4259... |
| 34 | 3643022 | : 3957500<\|eot_id\|> |
| 38 | 2031469 | : the text.<\|eot_id\|> |
| 39 | 8740362 | : 8740364 8740364 8740364 8740364 is: is: is: is: 874036... |
| 40 | 7041770 | :1682<\|eot_id\|> |
| 41 | 1986258 | :086.<\|eot_id\|> |
| 42 | 5668574 | :055.<\|eot_id\|> |
| 43 | 8560471 | :067<\|eot_id\|> |
| 45 | 9973767 | : 8420273<\|eot_id\|> |
| 46 | 3960211 | :0<\|eot_id\|> |
| 47 | 8003271 | : 60870870870870870870870870870870870870870870870870870870870870870... |
| 49 | 8632309 | 303640 is640 640 640 640 640 640 640 640 640 640 640 640 640 640 640... |
| 50 | 2318630 | : 7780552.<\|eot_id\|> |
| 53 | 3405052 | :<\|eot_id\|> |
| 54 | 5364945 | : 536494, which is: 536494, which is: 536494494494494494494494494... |
| 56 | 7319214 | :7607607607607607607607607607607607607607607607607607607607607607607... |
| 57 | 9206104 | :7607607607607607607607607607607607607607607607607607607607607607607... |
| 59 | 9555385 | :7095<\|eot_id\|> |
| 60 | 5727554 | : 572755755755755755755755755755755755755755755755755755755755 is: 572... |
| 63 | 1090767 | :7607607607607607607607607607607607607607607607607607607607607607607... |
| 64 | 6791240 | :<\|eot_id\|> |
| 65 | 7275999 | :7607607607607607607607607607607607607607607607607607607607607607607... |

niah_multikey_3 (24 errors)

Error sample IDs: 11, 18, 20, 23, 24, 25, 26, 27, 29, 30, 33, 35, 37, 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 52

| Index | Expected | Actual Output |
|-------|----------|---------------|
| 11 | c73ed342-6523-4d4b-aa33-beb1c9007315 | : 1d28b88b-b6a8-46ba-8e8f-56cbafbfd897.<\|eot_id\|> |
| 18 | 87b8a762-1d1f-4e85-a5d1-caf284c95aa6 | : 429a6676-5295-4ea2-a694-6aa949f48e31.<\|eot_id\|> |
| 20 | cce29702-134a-460c-979b-6f7ee7895280 | :<\|eot_id\|> |
| 23 | ed344bfe-983f-4a21-af44-722e2517244c | : aec431e7d880a8dce2c023de24 is: aec43163-061a-4afe-b80a-f5bfb5e3c9... |
| 24 | 4712ef99-a8d1-4388-8ca7-b08dd3505d77 | :<\|eot_id\|> |
| 25 | 46969ce7-0da0-49f8-87b2-845e7b8ef100 | :<\|eot_id\|> |
| 26 | 7cff3c66-6860-49e6-8ba5-002162c250c0 | :4c7e-946b-30812edf965e<\|eot_id\|> |
| 27 | b63b4988-40bc-44b2-bf1c-ca95adbca4e9 | :<\|eot_id\|> |
| 29 | 6d94011c-f28a-4b0b-a2e2-fe34bb8b19a1 | : 6d6d6d6d4b0e-52ce-44d9-a0f6-1ae405825615<\|eot_id\|> |
| 30 | 7c33bb00-4ab4-4e4f-a78e-39f8f06d63eb | d7a2-4b23-a2c0-8c859cb1fa96<\|eot_id\|> |
| 33 | b7c6b586-713a-4907-ad24-5c4f25aeb769 | :1-4d2c-b42b-933ded2633d6<\|eot_id\|> |
| 35 | ac8a317b-a6bb-4327-90db-2a01622cb723 | : d2f2f2f2f2f2f2f2d2d2f2d2d2d3d2f6b3d2f- is: d2dab is: is: is: i... |
| 37 | b187b337-3132-4376-a500-9340102092ae | :<\|eot_id\|> |
| 40 | 2559fa56-dd0a-48d4-ba82-3ae2bf0a4b33 | :358fe0e3-724e-4cfc-9ae0-d0873162626b.<\|eot_id\|> |
| 41 | 7842feb5-e758-44cd-b73b-8ae08aa33142 | : 6c6adf83-36a9-4e41-9cbe-60a8c9ffba92.<\|eot_id\|> |
| 42 | a1196139-f6fa-4c18-b3da-b7bd50362ac7 | : a1196131396131196131399a1196139a1196139a1196139a1196139f6a1196139... |
| 44 | 7d3d40b2-4594-4573-b267-4c6270dd4425 | : 613a9e-4e7d-8c9f-740a630e3c53<\|eot_id\|> |
| 45 | 500b8a75-8f05-43f5-b9ad-46d47d4e33fc | : 500b8a5e0e0e0a500b is: 500b is: 500b-4 is: is: is: is: is: i... |
| 46 | 86a867a7-6a98-4a02-b065-70a33bafafde | :6139a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a9a... |
| 47 | 7c0f7fd2-237e-4c0f-b3f5-f43623551169 | 5fb71d2f0f0b4f0 is: 5fb71 is: 5fb71f-4f-4f-4f-4f-4f-4d7 is: is: ... |
| 48 | b0e1f3f5-6570-437e-b8a1-f1b3f654e257 | : 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b ... |
| 49 | 0153722a-70a8-4ec0-9f03-2b0930937e60 | : 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b ... |
| 50 | 0a1ead51-0c39-4eeb-ac87-d146acdb1d4a | : 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b 500b ... |
| 52 | ff686e85-3a9f-4635-95dd-f19e8ca68eb1 | ff686e686e686e686e686e686f686e6f686e6fb686f686f686f686f686f- is: f... |

Error Pattern Analysis

| Error type | Example | Likely cause |
|------------|---------|--------------|
| Empty output | :<\|eot_id\|> | KV cache fully corrupted; the model cannot produce valid content |
| Repeated digits | :361361361... | a merge error at a chunk boundary sends decode into a loop |
| Partially correct | expected 9874152, got :151:52 | early chunks intact, later chunks corrupted |
| Inserted digits | expected 2344047, got 23440447 | numerical precision loss during the merge |
| Completely wrong | an unrelated UUID is output | severe KV cache or position-encoding error |
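
The repeated-digits failure mode lends itself to automatic flagging when re-running the failing samples; a small sketch (the pattern length and repeat threshold are arbitrary choices, not tuned values):

```python
import re

def looks_like_repetition_loop(text: str, min_repeats: int = 4) -> bool:
    """Flag outputs where a short substring repeats many times back-to-back."""
    # Any 2-8 character chunk repeated at least `min_repeats` times in a row
    return re.search(r"(.{2,8}?)\1{%d,}" % (min_repeats - 1), text) is not None

assert looks_like_repetition_loop(":361361361361361361")
assert looks_like_repetition_loop(":7607607607607607607")
assert not looks_like_repetition_loop(": 9874152")
```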

Verification Script

A script for checking whether previously failing samples are fixed:

# verify_samples.py
import json

ERROR_SAMPLES = {
    'niah_single_1': [28, 33, 39, 40, 41, 43, 44, 49, 51, 52, 53, 57, 61, 63, 65, 67, 72, 77, 83],
    'niah_single_2': [16, 24, 30, 32, 40, 41, 42, 50, 51, 52, 55, 58, 60, 62, 64, 66, 67, 68, 69, 77, 85, 91, 93],
    'niah_single_3': [7, 9, 14, 24, 25, 29, 31, 43],
    'niah_multikey_1': [20, 31, 32, 40, 41, 45, 51, 54, 59, 63, 64, 65, 67, 69, 71, 74],
    'niah_multikey_2': [2, 13, 21, 22, 23, 24, 25, 28, 32, 34, 38, 39, 40, 41, 42, 43, 45, 46, 47, 49, 50, 53, 54, 56, 57, 59, 60, 63, 64, 65],
    'niah_multikey_3': [11, 18, 20, 23, 24, 25, 26, 27, 29, 30, 33, 35, 37, 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 52],
}

def verify_fix(result_dir: str) -> dict:
    """Check which previously failing samples now pass"""
    fixed = {}
    for task, error_ids in ERROR_SAMPLES.items():
        pred_file = f"{result_dir}/{task}.jsonl"
        with open(pred_file) as f:
            predictions = [json.loads(line) for line in f]

        fixed_count = 0
        still_error = []
        for idx in error_ids:
            pred = predictions[idx]
            expected = pred['answer']
            actual = pred['model_output']
            if expected in actual:
                fixed_count += 1
            else:
                still_error.append(idx)

        fixed[task] = {
            'total_errors': len(error_ids),
            'fixed': fixed_count,
            'still_error': still_error,
            'fix_rate': f"{fixed_count/len(error_ids)*100:.1f}%"
        }

    return fixed
