diff --git a/docs/issue_xattn_offload_gqa_buffer_oom.md b/docs/issue_xattn_offload_gqa_buffer_oom.md
index a6511bc..a61de97 100644
--- a/docs/issue_xattn_offload_gqa_buffer_oom.md
+++ b/docs/issue_xattn_offload_gqa_buffer_oom.md
@@ -167,3 +167,43 @@ GPULIST=0 ./scripts/run_ruler.sh glm4-9b-xattn-nanovllm synthetic xattn --task n
 ## Priority
 
 **High** - Blocks 9B+ models from using XAttention + offload mode on 24 GB GPUs
+
+---
+
+## Fix Status
+
+**✅ Fixed** (2026-02-05)
+
+### Changes
+
+Adopted Option 1: skip GQA buffer allocation in offload mode.
+
+1. `nanovllm/kvcache/sparse/policy.py`: add an `enable_cpu_offload` parameter to the base class
+2. `nanovllm/kvcache/sparse/xattn_bsa.py`: check for offload mode and skip the GQA buffers
+3. `nanovllm/engine/model_runner.py`: pass the `enable_cpu_offload` argument through
+
+### Verification
+
+```bash
+# 64K offload test
+CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
+    python tests/test_ruler.py \
+    --model ~/models/Llama-3.1-8B-Instruct \
+    --data-dir tests/data/ruler_64k \
+    --datasets niah_single_1 \
+    --num-samples 1 \
+    --max-model-len 72000 \
+    --enable-offload \
+    --sparse-policy XATTN_BSA
+```
+
+- ✅ Log shows: `[XAttn] Offload mode: skipping GQA expansion buffers`
+- ✅ Test passed: 100% accuracy
+- ✅ Memory saved: ~16 GB (at max_seq_len = 1M)
+
+### Memory Comparison
+
+| Configuration | GQA buffers before fix | After fix |
+|---------------|------------------------|-----------|
+| max_model_len=72K | +1.1 GB | 0 GB |
+| max_model_len=1M | +16 GB | 0 GB |
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index 8fb8708..f595860 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -227,9 +227,9 @@ class ModelRunner:
             device=torch.device("cuda"),
         )
 
-        # GPU-only mode: pre-allocate policy metadata buffers
-        # This avoids dynamic GPU memory allocation during forward pass
-        # if not config.enable_cpu_offload:
+        # Pre-allocate policy metadata buffers
+        # - Offload mode: allocate chunked prefill buffers (mask, KV chunking stats)
+        # - GPU-only mode: additionally allocate GQA expansion buffers
         num_heads = hf_config.num_attention_heads // self.world_size
         self.kvcache_manager.sparse_policy.alloc_policy_metadata(
             num_heads=num_heads,
@@ -238,6 +238,7 @@ class ModelRunner:
             max_seq_len=config.max_model_len,
             dtype=hf_config.torch_dtype,
             device=torch.device("cuda"),
+            enable_cpu_offload=config.enable_cpu_offload,
         )
 
         # Log policy info (handle both enum and None cases)
diff --git a/nanovllm/kvcache/sparse/policy.py b/nanovllm/kvcache/sparse/policy.py
index 1a51f87..f9615ce 100644
--- a/nanovllm/kvcache/sparse/policy.py
+++ b/nanovllm/kvcache/sparse/policy.py
@@ -116,13 +116,15 @@ class SparsePolicy(ABC):
         max_seq_len: int,
         dtype: torch.dtype,
         device: torch.device,
+        enable_cpu_offload: bool = False,
     ) -> None:
         """
         Pre-allocate GPU buffers for policy computation.
 
-        Called by the framework after KV cache allocation, but ONLY for GPU-only
-        mode (not CPU offload mode). Override this to pre-allocate buffers that
-        would otherwise be dynamically allocated during forward pass.
+        Called by the framework after KV cache allocation. Implementations should
+        use enable_cpu_offload to decide which buffers to allocate:
+        - Offload mode: allocate chunked prefill buffers (mask, KV chunking stats)
+        - GPU-only mode: additionally allocate GQA expansion buffers
 
         This is separate from initialize() which is used for CPU offload metadata.
 
@@ -133,6 +135,7 @@ class SparsePolicy(ABC):
             max_seq_len: Maximum sequence length (for buffer sizing)
             dtype: Data type (typically float16/bfloat16)
             device: Target device (cuda)
+            enable_cpu_offload: Whether CPU offload is enabled
         """
         pass
 
diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py
index 4cb1c90..cf70705 100644
--- a/nanovllm/kvcache/sparse/xattn_bsa.py
+++ b/nanovllm/kvcache/sparse/xattn_bsa.py
@@ -175,6 +175,7 @@ class XAttentionBSAPolicy(SparsePolicy):
         max_seq_len: int,
         dtype: torch.dtype,
         device: torch.device,
+        enable_cpu_offload: bool = False,
     ) -> None:
         """
         Pre-allocate GQA expansion buffers for GPU-only mode.
@@ -235,7 +236,14 @@ class XAttentionBSAPolicy(SparsePolicy):
             f"m/l shape={m_partial_shape} ({m_l_memory_mb:.1f} MB), "
             f"block_sums shape={block_sums_shape} ({block_sums_memory_mb:.1f} MB)")
 
-        # Only allocate GQA expansion buffers if GQA (num_heads != num_kv_heads)
+        # Skip GQA buffers in offload mode:
+        # chunked prefill uses compute_chunked_prefill(), which handles GQA inline
+        if enable_cpu_offload:
+            logger.info("[XAttn] Offload mode: skipping GQA expansion buffers (saves ~16GB for 1M seq)")
+            return
+
+        # GPU-only mode: pre-allocate GQA buffers for compute_prefill()
+        # Only allocate if GQA (num_heads != num_kv_heads)
         if num_heads == num_kv_heads:
             logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
             return
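
---

For reviewers, here is a minimal sketch of the allocation decision plus a back-of-envelope check of the "~16 GB for 1M seq" figure. This is not the real `XAttentionBSAPolicy`: the class name `ToyGQAPolicy`, the `head_dim=128` / 32-head / bf16 values, and the two-buffer (K and V expansion) layout are illustrative assumptions; the actual shapes live in `xattn_bsa.py`.

```python
import logging

import torch

logger = logging.getLogger(__name__)


class ToyGQAPolicy:
    """Illustrative stand-in for a sparse policy; not nano-vllm's real class."""

    def __init__(self) -> None:
        self.k_expanded = None  # GQA-expanded K buffer (GPU-only mode)
        self.v_expanded = None  # GQA-expanded V buffer (GPU-only mode)

    def alloc_policy_metadata(
        self,
        num_heads: int,
        num_kv_heads: int,
        max_seq_len: int,
        head_dim: int,
        dtype: torch.dtype,
        device: torch.device,
        enable_cpu_offload: bool = False,
    ) -> None:
        # Offload mode: chunked prefill handles GQA inline, so the large
        # expansion buffers would never be read. Skip them entirely.
        if enable_cpu_offload:
            logger.info("offload mode: skipping GQA expansion buffers")
            return
        # MHA (num_heads == num_kv_heads): no expansion needed either.
        if num_heads == num_kv_heads:
            return
        # GPU-only GQA: allocate once up front so the forward pass never
        # triggers dynamic GPU allocations.
        shape = (num_heads, max_seq_len, head_dim)
        self.k_expanded = torch.empty(shape, dtype=dtype, device=device)
        self.v_expanded = torch.empty(shape, dtype=dtype, device=device)


# Sanity check of the log message's "~16 GB for 1M seq", assuming 32 query
# heads, head_dim 128, bf16 (2 bytes), and both K and V expansion buffers:
#   2 buffers * 32 * 1_000_000 * 128 * 2 bytes = 16.384e9 bytes
bytes_needed = 2 * 32 * 1_000_000 * 128 * 2
print(f"{bytes_needed / 1e9:.1f} GB")  # -> 16.4 GB, consistent with ~16 GB
```

Under these assumed shapes the savings scale linearly with `max_seq_len`, which matches the memory comparison table above (~1.1 GB at 72K vs. ~16 GB at 1M).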