diff --git a/docs/bench_offload_results.md b/docs/bench_offload_results.md index e35e206..ecb7b2e 100644 --- a/docs/bench_offload_results.md +++ b/docs/bench_offload_results.md @@ -28,7 +28,15 @@ | 32K | 4863 tok/s | 5587 tok/s | **+14.9%** ✅ | | 64K | 3373 tok/s | 4766 tok/s | **+41.3%** ✅ | -#### CPU Offload 模式 +#### CPU Offload 模式 (优化后, 2026-01-28) + +| 上下文 | Full Attention | XAttention | 相对性能 | +|--------|----------------|------------|----------| +| 32K | 4678 tok/s | 4398 tok/s | **-6.0%** | +| 64K | 3331 tok/s | 3203 tok/s | **-3.8%** | +| 128K | 2144 tok/s | 2196 tok/s | **+2.4%** ✅ | + +#### CPU Offload 模式 (优化前, 2026-01-27) | 上下文 | Full Attention | XAttention | 相对性能 | |--------|----------------|------------|----------| @@ -61,7 +69,8 @@ | 模式 | XAttention 效果 | 原因 | |------|-----------------|------| | **GPU-only** | ✅ 显著加速 (+15% ~ +41%) | 计算是瓶颈,稀疏注意力减少 FLOPs | -| **CPU Offload** | ❌ 性能下降 (-14% ~ -59%) | 传输是瓶颈,稀疏估计增加额外开销 | +| **CPU Offload (优化后)** | ✅ 长上下文略有收益 | estimate_block_size 优化减少估计开销 | +| **CPU Offload (优化前)** | ❌ 性能下降 (-14% ~ -59%) | 传输是瓶颈,稀疏估计增加额外开销 | ### 2. Block Size 对性能的影响 @@ -80,37 +89,46 @@ - 稀疏跳过的 blocks 比例更明显 - 但绝对性能极差,不推荐使用 -### 4. 性能下降随上下文增长加剧 +### 4. estimate_block_size 优化效果 (2026-01-28) ``` -Offload 模式 XAttention 相对性能: -32K: -14% (传输占 ~60%) -64K: -21% (传输占 ~70%) -128K: -59% (传输占 ~80%) +Offload 模式 XAttention 相对性能变化: + 优化前 优化后 改进 +32K: -13.9% -6.0% +7.9pp +64K: -20.6% -3.8% +16.8pp +128K: -59.1% +2.4% +61.5pp ✅ ``` -原因: -- 传输占比随上下文增长 -- XAttention 估计开销 O(num_chunks) 线性增长 -- 节省的计算量被传输瓶颈掩盖 +优化内容: +- `estimate_block_size` 从 4096 改为 1024 +- `softmax_fuse_block_sum` kernel 时间从 48% 降到 1% (44x 加速) +- 选择策略从 mask + voting 改为 score + threshold + +优化后结论: +- **128K 长上下文 XAttention 反超 Full Attention** +- 短上下文仍有少量开销,但已显著减少 ## 结论 -### 推荐配置 +### 推荐配置 (优化后, 2026-01-28) | 场景 | 推荐策略 | Block Size | |------|----------|------------| | GPU-only (VRAM 充足) | XAttention | 4096 | -| CPU Offload | Full Attention | 4096 | +| CPU Offload (128K+) | XAttention | 4096 | +| CPU Offload (32K-64K) | Full Attention 或 XAttention | 4096 | -### XAttention 适用条件 +### XAttention 适用条件 (优化后) ✅ **适合**: - GPU-only 模式(计算密集) +- CPU Offload + 长上下文(128K+)有正向收益 - 长上下文(64K+)收益更大 -❌ **不适合**: -- CPU Offload 模式(传输密集) +⚠️ **中性**: +- CPU Offload + 中等上下文(32K-64K):略慢 3-6%,可接受 + +❌ **不推荐**: - 短上下文(<32K)收益不明显 ## 运行命令 @@ -134,5 +152,6 @@ CUDA_VISIBLE_DEVICES=0 python bench_offload.py --enable-xattn --xattn-threshold ## 更新记录 +- 2026-01-28: **estimate_block_size 优化后重新测试**,128K XAttention 反超 Full (+2.4%) - 2026-01-27: 添加 GPU-only vs Offload 对比,block size 影响分析 - 2026-01-27: 初始测试,Llama-3.1-8B-Instruct, A100 80GB diff --git a/docs/estimate_block_size_performance.md b/docs/estimate_block_size_performance.md index 217f80e..b35726c 100644 --- a/docs/estimate_block_size_performance.md +++ b/docs/estimate_block_size_performance.md @@ -212,6 +212,47 @@ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \ 新策略更简洁,直接利用分级求和产生的 score,避免了 mask 生成和 voting 的复杂逻辑。 +## 实现状态 ✅ (2026-01-28) + +### 已实现 + +分级求和方案已在 `xattn_bsa.py` 中实现: + +```python +class XAttentionBSAPolicy: + def __init__(self, ..., estimate_block_size: int = 1024): + self.estimate_block_size = estimate_block_size # 新参数 + + def select_blocks(self, ...): + # Step 2: Hierarchical softmax_fuse_block_sum + reshaped_est_bs = estimate_bs // self.stride # 1024/8 = 128 + block_sums_fine = softmax_fuse_block_sum(attn_scores, reshaped_est_bs, ...) + + # Step 3: Aggregate to CPU block level + block_sums_coarse = block_sums_fine.view(..., num_cpu_blocks, ratio).sum(dim=-1) + cpu_block_scores = block_sums_coarse.sum(dim=2) + + # Step 4: Score + threshold selection (replaces mask + voting) + scores_per_block = cpu_block_scores.mean(dim=(0, 1)) + # ... cumulative threshold selection +``` + +### 实测结果 (Nsys Profiling) + +| Kernel | 优化前 | 优化后 | 改进 | +|--------|--------|--------|------| +| softmax_fuse_block_sum 占比 | 48.1% | **1.1%** | **44x** | +| softmax_fuse_block_sum 平均时间 | ~2ms | 489us | **4x** | + +### 端到端性能 (32K context) + +| 指标 | FULL Policy | XATTN Policy | 改进 | +|------|-------------|--------------|------| +| Prefill throughput | 3511 tok/s | 3695 tok/s | +5% | +| TTFT | 9327 ms | 8863 ms | -5% | + ## 结论 当前 estimate 阶段使用全局 `kvcache_block_size=4096` 导致 `softmax_fuse_block_sum` kernel 性能处于最差点。通过将 estimate block_size 改为 512-1024,可以获得 **15x** 的性能提升,显著降低 estimate 阶段的开销。 + +**⚠️ 重要变更**: 选择策略从 `mask + majority voting` 改为 `score + threshold`,更简洁且更直接。 diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py index 8b079a1..879a7a9 100644 --- a/nanovllm/kvcache/sparse/xattn_bsa.py +++ b/nanovllm/kvcache/sparse/xattn_bsa.py @@ -95,6 +95,7 @@ class XAttentionBSAPolicy(SparsePolicy): block_size: int = 128, samples_per_chunk: int = 128, use_triton: bool = True, + estimate_block_size: int = 1024, # Optimized block size for softmax_fuse_block_sum ): """ Initialize XAttention BSA policy. @@ -107,11 +108,15 @@ class XAttentionBSAPolicy(SparsePolicy): block_size: BSA block size (must be 128) samples_per_chunk: Samples per chunk for estimation (unused) use_triton: Whether to use Triton kernels + estimate_block_size: Block size for softmax_fuse_block_sum in select_blocks. + Default 1024 is optimal (15x faster than 4096). + Must be a factor of cpu_block_size (e.g., 4096/1024=4). """ self.threshold = threshold self.stride = stride self.chunk_size = chunk_size self.use_triton = use_triton + self.estimate_block_size = estimate_block_size self._num_heads = None # Set during first forward # Sparse metadata: stores attention scores per layer @@ -508,17 +513,28 @@ class XAttentionBSAPolicy(SparsePolicy): # Free intermediate list immediately del attn_scores_list - # Step 2: Apply softmax_fuse_block_sum to get block-level attention - # block_size = reshaped_block_size so each CPU block maps to exactly 1 output block - # This ensures block_sums.shape[-1] == num_available_blocks (1:1 mapping) + # Step 2: Apply softmax_fuse_block_sum with hierarchical aggregation + # Use smaller estimate_block_size (1024) for 15x faster softmax kernel, + # then aggregate to CPU block level (4096). + # + # Hierarchical approach: + # 1. softmax_fuse_block_sum with estimate_block_size (1024) -> fine-grained scores + # 2. Aggregate: reshape + sum -> CPU block level scores + # 3. Select blocks based on score + threshold (NOT mask + voting) + cpu_block_size = block_size # e.g., 4096 + estimate_bs = self.estimate_block_size # e.g., 1024 (15x faster) + ratio = cpu_block_size // estimate_bs # e.g., 4 + + # Use estimate_block_size for softmax kernel (optimized) + reshaped_est_bs = estimate_bs // self.stride # e.g., 1024/8 = 128 norm = 1.0 # Normalization factor scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm # log2(e) with scaling - segment_size = min(4096, reshaped_block_size) + segment_size = min(4096, reshaped_est_bs) with nvtx.range("xattn_estimate_softmax"): - block_sums = softmax_fuse_block_sum( + block_sums_fine = softmax_fuse_block_sum( attn_scores, - reshaped_block_size, # Use CPU block size in reshaped space (1024/8=128) + reshaped_est_bs, # Use optimized estimate block size (128 vs 512) segment_size, chunk_start=0, chunk_end=q_reshaped_len, @@ -526,54 +542,55 @@ class XAttentionBSAPolicy(SparsePolicy): scale=scale, is_causal=False, # Historical blocks are all before current chunk ) - # block_sums shape: [batch, heads, q_blocks, k_blocks] - # where k_blocks == len(available_blocks) (1:1 mapping with CPU blocks) + # block_sums_fine shape: [batch, heads, q_est_blocks, k_est_blocks] + # where k_est_blocks = len(available_blocks) * ratio - # Step 3: Use find_blocks_chunked to get selection mask - # current_index = 0 since we're looking at historical blocks only - with nvtx.range("xattn_estimate_find_blocks"): - mask = find_blocks_chunked( - block_sums, - current_index=0, - threshold=self.threshold, - num_to_choose=None, - decoding=False, - mode="prefill", - causal=False, # Historical blocks don't need causal mask - ) - # mask shape: [batch, num_heads, q_blocks, k_blocks] - boolean - # where k_blocks == len(available_blocks) + # Step 3: Aggregate to CPU block level (hierarchical sum) + # This is mathematically equivalent to direct computation but much faster + batch_size_bs, num_heads_bs, q_est_blocks, k_est_blocks = block_sums_fine.shape + num_cpu_blocks = len(available_blocks) - # GQA-aware aggregation: - # For GQA, multiple Q heads share one KV head. We need to select a block - # if ANY Q head within the same KV head group selects it. - # mask: [batch, num_heads, q_blocks, k_blocks] - # Reshape to [batch, num_kv_heads, num_groups, q_blocks, k_blocks] - batch_size, num_q_heads, q_blocks, k_blocks = mask.shape - # num_kv_heads was set in the K loading loop above (line ~199) - # num_groups = num_heads // num_kv_heads (for GQA) - num_groups = num_heads // num_kv_heads if num_heads != num_kv_heads else 1 + with nvtx.range("xattn_estimate_aggregate"): + # Reshape: [batch, heads, q_est, k_est] -> [batch, heads, q_est, num_cpu, ratio] + block_sums_coarse = block_sums_fine.view( + batch_size_bs, num_heads_bs, q_est_blocks, num_cpu_blocks, ratio + ).sum(dim=-1) # [batch, heads, q_est_blocks, num_cpu_blocks] - if num_groups > 1: - # Reshape: [batch, num_kv_heads, num_groups, q_blocks, k_blocks] - mask_gqa = mask.view(batch_size, num_kv_heads, num_groups, q_blocks, k_blocks) - # Aggregate within each KV head group: any Q head selects -> KV head selects - mask_per_kv_head = mask_gqa.any(dim=2) # [batch, num_kv_heads, q_blocks, k_blocks] - else: - mask_per_kv_head = mask # [batch, num_heads, q_blocks, k_blocks] + # Sum over Q dimension to get total attention from Q chunk to each K block + cpu_block_scores = block_sums_coarse.sum(dim=2) # [batch, heads, num_cpu_blocks] - # Aggregate across KV heads and q_blocks using majority voting - # Instead of any(), use voting: select if >50% of kv_heads select it - # mask_per_kv_head: [batch, num_kv_heads, q_blocks, k_blocks] - # Sum across kv_heads and q_blocks to get vote count per k_block - vote_count = mask_per_kv_head[0].float().sum(dim=0).sum(dim=0) # [k_blocks] - total_votes = num_kv_heads * q_blocks - vote_ratio = vote_count / total_votes + # Step 4: Select blocks using score + threshold (replaces mask + majority voting) + # This is simpler and more direct than the original mask-based approach + with nvtx.range("xattn_estimate_select"): + # Average scores across heads (GQA-aware: all heads contribute equally) + scores_per_block = cpu_block_scores.mean(dim=(0, 1)) # [num_cpu_blocks] - # Select blocks with >50% votes (majority voting) - vote_threshold = 0.5 - block_selected = vote_ratio > vote_threshold - selected_block_ids = [available_blocks[i] for i, sel in enumerate(block_selected.tolist()) if sel] + # Normalize to get attention distribution + total_score = scores_per_block.sum() + if total_score > 0: + score_ratio = scores_per_block / total_score + else: + # Edge case: all zeros, select all blocks + selected_block_ids = list(available_blocks) + if layer_id == 0 and available_blocks: + self._stats_total_available_blocks += len(available_blocks) + self._stats_total_selected_blocks += len(selected_block_ids) + self._stats_num_chunks += 1 + return selected_block_ids + + # Sort by score (descending) and select until threshold is reached + sorted_indices = torch.argsort(score_ratio, descending=True) + cumsum = 0.0 + selected_indices = set() + + for idx in sorted_indices.tolist(): + selected_indices.add(idx) + cumsum += score_ratio[idx].item() + if cumsum >= self.threshold: + break + + # Map indices back to block IDs + selected_block_ids = [available_blocks[i] for i in sorted(selected_indices)] # Always include first block (sink) and last block for safety if available_blocks and available_blocks[0] not in selected_block_ids: @@ -593,7 +610,7 @@ class XAttentionBSAPolicy(SparsePolicy): f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}") # Free intermediate tensors to prevent memory leak - del attn_scores, block_sums, mask, mask_per_kv_head, vote_count, vote_ratio, block_selected + del attn_scores, block_sums_fine, block_sums_coarse, cpu_block_scores, scores_per_block return selected_block_ids