87 Commits

Author SHA1 Message Date
Zijie Tian
52b12a89e3 📋 docs: add changelog for 2026-02-05
Document today's changes:
- GQA buffer OOM fix (saves 16GB for 1M seq in offload mode)
- Tests directory cleanup (removed 16 files, -4306 lines)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 03:16:39 +08:00
Zijie Tian
d35dd76e09 🗑️ chore: clean up tests directory to essential files only
Keep only core test files:
- test_ruler.py - main RULER benchmark
- test_xattn_estimate_alignment.py - XAttn kernel validation
- utils.py - shared utilities

Remove 8 files (recoverable from git history):
- bench_estimate_block_size.py
- modeling_qwen3.py
- test_chunk_attention_graph_reuse.py
- test_cudagraph_memory.py
- test_gpuonly_density_alignment.py
- test_hierarchical_estimate.py
- test_quest_policy.py
- test_sequential.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 03:13:50 +08:00
Zijie Tian
2b61c5ab57 🗑️ chore: remove test_needle* files
Remove needle tests (validation now covered by test_ruler.py):
- test_needle.py - basic needle-in-haystack test
- test_needle_ref.py - HuggingFace reference implementation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 03:11:28 +08:00
Zijie Tian
a709551072 🗑️ chore: remove redundant XAttention test files
Remove 6 obsolete test files:
- test_xattn_bsa.py - XAttn+BSA integration (covered by test_ruler)
- test_xattn_chunked.py - duplicate of test_xattn_estimate_chunked
- test_xattn_estimate_chunked.py - chunked prefill validation
- test_xattn_kernels.py - Triton kernel unit tests
- test_xattn_kv_chunking_batch.py - batch KV chunking validation
- test_chunk_attention_graph.py - superseded by graph_reuse version

Retained: test_xattn_estimate_alignment.py (critical kernel validation)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 03:11:21 +08:00
Zijie Tian
11a867f6fb 🐛 fix: skip GQA buffer allocation in XAttention offload mode
In offload mode, GQA expansion buffers (_k_expanded, _v_expanded) are not
needed since compute_chunked_prefill() handles GQA inline. Previously,
these buffers were always allocated based on max_model_len, causing OOM
on 24GB GPUs (e.g., RTX 3090) when max_model_len=1M (16GB buffer).

Changes:
- Add enable_cpu_offload parameter to alloc_policy_metadata() in base class
- Skip GQA buffer allocation when enable_cpu_offload=True in XAttentionBSAPolicy
- Pass enable_cpu_offload from model_runner to policy

Memory savings: ~16GB for 1M seq, ~1.1GB for 72K seq

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 02:57:18 +08:00
Zijie Tian
af4da454ba 📊 docs: add XAttention offload profiling analysis for 32K context
- Profile XAttn vs Full attention using nsys NVTX markers
- Key finding: estimate (41%) + find_blocks (37%) dominate, compute only 21%
- Chunk7 comparison: XAttn (38ms) vs Full (35ms) - XAttn slightly slower
- Identify optimization opportunities: reduce find_blocks overhead, merge estimate passes

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-05 02:49:59 +08:00
Zijie Tian
ef37d4f1a8 🐛 docs: document XAttention offload GQA buffer OOM issue
Document OOM issue when using XAttention BSA + CPU offload
with large models (GLM-4-9B) on 24GB GPUs.

Issue: 8GB allocation for k_expanded buffer fails due to
using num_heads instead of num_kv_heads in GQA models.

Root cause analysis and proposed fix included.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 02:46:50 +08:00
Zijie Tian
c8a5ef04c0 📝 docs: add test_ruler.py usage guide and rule
- Add comprehensive test_ruler.py usage guide with verified commands
- Add .claude/rules/test-ruler.md to enforce documentation-first approach
- Update CLAUDE.md documentation index

Tested commands on RTX 3090 (GPU 4):
- 32K/64K offload + XAttn BSA
- Multi-dataset, JSON output, quiet mode
- GLM-4 model support

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 02:46:44 +08:00
Zijie Tian
1c36d53570 🙈 chore: add ralph-tui session file to gitignore
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 02:00:44 +08:00
Zijie Tian
54fd302fa8 📝 docs: add XAttention density alignment verification results
- Add verification doc comparing GPU-only vs Offload mode density
- Test results: 32K (0.37% diff), 64K (0.09% diff) - alignment successful
- Both modes achieve 100% accuracy on RULER niah_single_1

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-05 01:59:11 +08:00
Zijie Tian
1eb7521994 📝 docs: add XAttention density types documentation
Document the difference between compute density (BSA block level)
and communication density (CPU block level).

Key finding: Even with 37% compute density, comm density can be 100%
due to any() aggregation across heads/Q-positions spreading sparse
blocks across all CPU blocks.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 01:44:11 +08:00
Zijie Tian
51bd678335 📊 feat: distinguish compute density and communication density in DensityObserver
- Add record_comm_density() call in select_blocks to track CPU block selection
- Add get_per_layer_comm_density() method for detailed analysis
- Update print_summary() to show both densities and H2D savings ratio
- Set DensityObserver mode (offload/gpu_only) in test_ruler.py
- Update get_summary() to return both density types

Key insight: Comm density can be 100% even when compute density is ~37%
because sparse BSA blocks are distributed across all CPU blocks.
Since CPU block granularity is 32x coarser (4096 vs 128 tokens),
any() aggregation across heads/Q-blocks results in all CPU blocks being needed.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 01:43:17 +08:00
Zijie Tian
1ea5afd886 📝 docs: add XAttention offload stream sync fix documentation
- Document the CUDA stream synchronization bug in XAttention BSA
- Include root cause analysis with stream timing diagrams
- Add test commands and verification results (100% accuracy)
- Update CLAUDE.md documentation index

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 01:32:50 +08:00
Zijie Tian
829b311c02 🐛 fix: stream synchronization for XAttention estimate kernels in offload mode
- Wrap all compute kernels in select_blocks with compute_stream context
  (Pass 1 historical blocks, Pass 1 current chunk, Step 2 merge,
   Pass 2 historical blocks, Pass 2 current chunk, Step 4 block selection)
- Fix K data mismatch between Pass 1 and Pass 2 by ensuring wait_slot_layer
  syncs with compute_stream where kernels actually run
- Remove STRONG SYNC code from offload_engine.py (now handled by events)
- Remove debug print statements and torch.save code
- Consolidate fallback conditions in compute_with_xattn
- Change default chunk_size from 16384 to 4096 for density alignment

The bug caused Pass 1 and Pass 2 to see different K data from the same
CPU block because compute kernels ran on default stream while
wait_slot_layer only synced compute_stream.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 01:30:23 +08:00
Zijie Tian
dd0472aea8 [plugin] Added ralph-tui setup. 2026-02-05 01:27:53 +08:00
Zijie Tian
a1c68a733e 📊 docs: add XAttention memory benchmark for 24GB GPUs
- Add memory analysis for Qwen3-0.6B @ 32K context
- Document 24GB VRAM feasibility (RTX 3090/4090)
- Recommend gpu-utilization=0.28 for 24GB GPUs
- Include KV cache breakdown and model estimations
- Update CLAUDE.md index

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-02 14:38:27 +08:00
Zijie Tian
dc51972777 📝 docs: update density alignment test with Offload mode results
- Rename doc to "Density Alignment Test Results" (covers both modes)
- Add Offload mode test results (3.7K-64.9K tokens, all passed)
- Add Layer 5 GPU-only test results (threshold=0.9, density=6.24%)
- Enhance test script to support both GPU-only and Offload data formats
- Add batch testing commands for all data files
- Update CLAUDE.md index

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-02 14:22:40 +08:00
Zijie Tian
232fcf043e 📝 docs: add GPU-only density alignment test results
Document test results verifying XAttention density calculation in
GPU-only mode matches independent xattn_estimate calls.

Test results (Llama-3.1-8B-Instruct, threshold=0.9):
- 4k:  Layer 0 density 63.8%, verified 
- 8k:  Layer 0 density 65.0%, verified 
- 16k: Layer 0 density 61.6%, verified 
- 32k: Layer 0 density 50.2%, verified 
- 64k: Layer 0 density 37.0%, verified 

All tests show exact match (attn_sums diff=0, mask exact match).

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-02 11:22:34 +08:00
Zijie Tian
aeed6ccdfb test: add GPU-only density alignment verification test
Add test to verify XAttention density calculation in GPU-only mode
matches independent xattn_estimate calls.

Changes:
- Add tests/test_gpuonly_density_alignment.py: loads saved Q/K from
  xattn_bsa.py, calls xattn_estimate independently, compares results
- Enhance debug save in xattn_bsa.py: now saves Q, K tensors and
  xattn_estimate parameters for external verification
- Set _DEBUG_SAVE_MASK = False by default

Usage:
1. Set _DEBUG_SAVE_MASK = True in xattn_bsa.py
2. Run GPU-only inference with XAttention (e.g., test_ruler.py)
3. Run tests/test_gpuonly_density_alignment.py to verify alignment

Verified on 4k/8k/16k/32k/64k contexts - all pass with exact match.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-02 11:14:46 +08:00
Zijie Tian
6c55c4d2a3 ♻️ refactor: rewrite select_blocks with 3-stage KV chunking algorithm
Implement correct 3-stage KV chunking for XAttention offload mode:
- Stage 1: Compute partial softmax stats (m, l) for each KV chunk
- Stage 2: Merge all partial stats to get global normalization factors
- Stage 3: Normalize with global stats and compute block sums

Key fixes:
- Add wait_all_prefill_offloads() before loading CPU blocks to ensure
  async offload completion (fixes stale data bug)
- Pre-allocate m/l partial buffers and block_sums buffer

This produces identical density to GPU-only xattn_estimate while using
O(S×C) peak memory instead of O(S²).

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-02 10:10:10 +08:00
Zijie Tian
6e34efd58a 📝 docs: add storage overhead analysis and batch tests for KV chunking
- Update xattn_kv_chunking_kernels.md with:
  - Detailed storage overhead analysis (O(S) vs O(S²))
  - Peak memory optimization (8x reduction)
  - Support for independent Q/KV chunk sizes
  - Batch verification results (3K-64K seqlen)
  - ASCII pipeline diagram

- Add test_xattn_kv_chunking_batch.py for batch validation
- Fix causal mask post-processing in alignment test
- Update CLAUDE.md documentation index

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-01 19:22:36 +08:00
Zijie Tian
5acd5558d6 feat: add KV chunking support for XAttention softmax kernels
Implement three-phase KV chunking for sparse attention estimation:
1. softmax_compute_partial_stats: compute (m, l) per KV chunk
2. merge_softmax_stats: merge partial stats on host
3. softmax_normalize_and_block_sum: normalize with global stats

This allows computing sparse attention masks without storing full
raw attention scores in GPU memory, reducing peak memory usage
from O(q_len * k_full_len) to O(q_len * k_chunk_len).

Key changes:
- Add softmax_partial_stats_kernel with causal mask support
- Add softmax_normalize_block_sum_kernel with kv_offset parameter
- Add Python wrappers for new kernels
- Update test script to validate KV chunking alignment
- Add documentation for the new kernels

Test results show perfect alignment with xattn_estimate API:
- Density difference: 0.000000
- Mask difference: 0.0044%

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-01 18:53:26 +08:00
Zijie Tian
193ef55d18 ♻️ refactor: use Q-chunked processing in xattn alignment test
Match xattn_estimate internal logic by processing Q in chunks:
- Reduces peak memory for attn_scores tensor
- Enables testing 64K sequences without OOM
- All 5 test files pass (3.6K to 64K)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-01 18:08:15 +08:00
Zijie Tian
f173a3f7f5 test: add xattn_estimate vs low-level kernels alignment test
Test that xattn_estimate produces the same results as manually calling:
- flat_group_gemm_fuse_reshape
- softmax_fuse_block_sum
- find_blocks_chunked

Uses real KV cache data from results/kvcache/ directory.
Verifies density calculation matches between high-level API and kernels.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-01 17:49:37 +08:00
Zijie Tian
8035e4db3d 📝 docs: add XAttention KV chunking density test results
Document the verification test for XAttention Triton kernel KV chunking:
- 32K and 64K test results with threshold 0.9/0.95/1.0
- Key finding: threshold=1.0 achieves alignment (~0% diff)
- threshold<1.0 shows 10-13% difference due to per-chunk threshold application
- Conclusion: softmax normalization is correct, issue is threshold accumulation

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-01 17:36:19 +08:00
Zijie Tian
8ab53e7331 🚧 WIP: add DEBUG code for XAttention KV chunking density verification
Add instrumentation to compare GPU-only vs Offload mode density:
- Layer 0 DEBUG output for both modes
- Accumulate selected/total counts across chunks
- Proper causal mask with Q offset handling
- Skip normal offload logic for isolated testing

Test results (threshold=1.0 achieves alignment):
- 32K: GPU-only 0.9999, Offload 0.9999 (diff ~0%)
- 64K: GPU-only 0.9995, Offload 0.9995 (diff ~0%)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-01 17:33:23 +08:00
Zijie Tian
2e96d1d97d WIP: Enhance sparse attention with density tracking and block selection improvements
- Added analysis documentation for xattn density alignment.
- Refactored ModelRunner to pre-allocate policy metadata buffers regardless of CPU offload configuration.
- Updated FullAttentionPolicy and SparsePolicy to accept query and key tensors for block selection.
- Enhanced QuestPolicy to utilize query tensor for block selection and improved handling of selected blocks.
- Expanded XAttentionBSAPolicy to support chunked prefill and improved attention score computation with historical and current chunk handling.
- Introduced DensityObserver to track compute and communication density for sparse attention layers.
- Updated attention layer to ensure block selection is always called, improving robustness in first chunk scenarios.
- Added tests for attention kernel behavior with enhanced input patterns.
2026-01-31 14:48:23 +08:00
Zijie Tian
f6ac4ccdde feat: add DensityObserver for XAttention sparse attention density tracking
- Add DensityObserver class to track per-layer density statistics
- Integrate DensityObserver into compute_prefill for GPU-only mode
- Fix stride parameter not being passed to xattn_estimate
- Add density statistics output to test_ruler.py for XATTN_BSA
- Add comprehensive density benchmark documentation

Key changes:
- nanovllm/utils/density_observer.py: New Observer for density tracking
- xattn_bsa.py: Add stride param to xattn_estimate, integrate DensityObserver
- test_ruler.py: Enable DensityObserver and print summary for XATTN_BSA
- docs/xattn_density_benchmark.md: Benchmark results for 4K-32K contexts

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-30 16:26:56 +08:00
Zijie Tian
4484a1482c [refactor] Refactor the profile_offload.sh 2026-01-29 08:39:34 +08:00
Zijie Tian
e436ec861f ⚙️ config: update test_ruler.py defaults
- max_new_tokens: 128 → 16 (sufficient for NIAH answers)
- block_size: 1024 → 4096 (better performance)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 14:21:23 +08:00
Zijie Tian
45efcf0db1 feat: add --dtype parameter to test_ruler.py
Support models with float32 default dtype (e.g., Nemotron).
FlashAttention requires fp16/bf16, so dtype must be specified.

Usage: --dtype bfloat16

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 13:56:15 +08:00
Zijie Tian
e09a2a5b10 feat: add Qwen2/2.5 model support
Separate Qwen2 from Qwen3 implementation:
- Qwen2: Uses QKV bias, no QK norm
- Qwen3: Has optional QK norm when no bias

Tested with Qwen2.5-7B-Instruct-1M, RULER niah_single_1 passed.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 13:44:32 +08:00
Zijie Tian
a239bfb40d 📚 docs: add new model integration guide
Summarizes lessons learned from GLM-4 integration:
- Config field mapping (multi_query_group_num, kv_channels, etc.)
- RoPE variants (interleaved vs half, partial vs full rotation)
- EOS token handling for multi-EOS models
- Weight name conversion patterns
- Verification checklist

Also updates CLAUDE.md to reflect GLM-4 support.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 13:36:24 +08:00
Zijie Tian
29e102720b 🐛 fix: support multiple EOS tokens for GLM-4
GLM-4 uses multiple EOS tokens [151329, 151336, 151338] where 151336
(<|user|>) should also stop generation. Previously only the first EOS
from tokenizer was used, causing generation to always hit max_tokens.

Changes:
- config.py: Change eos type to int | list[int]
- llm_engine.py: Read eos_token_id from hf_config (contains full list)
- scheduler.py: Use set for efficient multi-EOS lookup

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 13:23:53 +08:00
Zijie Tian
726e4b58cf feat: add GLM-4-9B-Chat-1M model support
Add support for GLM-4 model architecture with the following changes:

- Add glm4.py with ChatGLMForCausalLM, GLM4Model, GLM4Attention, GLM4MLP
- Add GLM4RotaryEmbedding with interleaved partial rotation (rotary_dim = head_dim // 2)
- Add apply_rotary_emb_interleaved function for GLM-4 style RoPE
- Add GLM-4 weight name conversion and loading in loader.py
- Add GLM-4 chat template conversion in test_ruler.py
- Add trust_remote_code=True for GLM-4 config loading

Key GLM-4 specific adaptations:
- QKV bias enabled (add_qkv_bias: true)
- RoPE with rope_ratio scaling (base = 10000 * rope_ratio)
- Interleaved RoPE (pairs adjacent elements, not first/second half)
- Partial rotation (only half of head_dim is rotated)
- Uses multi_query_group_num instead of num_key_value_heads
- Uses kv_channels instead of head_dim
- Uses ffn_hidden_size instead of intermediate_size

Tested with RULER niah_single_1 (5 samples): 100% accuracy
Both GPU-only and CPU offload modes verified

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 13:15:57 +08:00
Zijie Tian
8d19e61446 ️ perf: replace Triton merge with FlashInfer merge_state
Use FlashInfer's optimized merge_state kernel for attention output merging
in chunked prefill. End-to-end improvement: +0.8% (32K) to +2.4% (64K).

Key changes:
- Add merge_attention_outputs_flashinfer() with LSE format conversion
- FlashInfer uses log2, flash_attn uses ln: convert via LOG2_E/LN_2
- Keep original Triton kernel for fallback

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 10:04:38 +08:00
Zijie Tian
4484ebbb77 📚 docs: add 1M+ context length models reference list
- Add comprehensive list of 1M+ context models from Hugging Face
- Categorize by type: text-only LLM vs vision-language models
- Separate ≤10B (practical) from >10B (resource-intensive) models
- Include Qwen, GLM, InternLM, Llama, MiniMax, Gradient AI series
- Add VRAM requirements and technical comparison table
- Update CLAUDE.md documentation index

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 09:04:55 +08:00
Zijie Tian
2c2383c786 ️ perf: optimize XAttention estimate with hierarchical block sum
Replace slow softmax_fuse_block_sum (block_size=4096) with optimized
hierarchical approach (estimate_block_size=1024):

- Add estimate_block_size parameter to XAttentionBSAPolicy (default 1024)
- Rewrite select_blocks to use hierarchical aggregation:
  1. Fine-grained softmax with small block size (15x faster kernel)
  2. Aggregate to CPU block level via reshape + sum
  3. Score + threshold selection (replaces mask + voting)

Performance improvement (CPU Offload mode):
- softmax_fuse_block_sum: 48% → 1% of total time (44x faster)
- 128K: XAttention now +2.4% faster than Full (was -59%)
- 64K: -3.8% (was -21%)
- 32K: -6.0% (was -14%)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-28 06:47:13 +08:00
Zijie Tian
f049971f84 test: add hierarchical block sum estimation validation
Validate the hierarchical estimation approach for XAttention:
- Test 1: Math equivalence (diff = 0.0) between hierarchical and direct
- Test 2: Score + threshold selection strategy (replaces mask + voting)
- Test 3: Performance benchmark (41x speedup)

Uses pure torch + xattn kernels, independent of nanovllm framework.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-28 06:24:35 +08:00
Zijie Tian
c90dc196b2 📝 docs: add estimate block_size performance analysis
Document the performance impact of block_size on softmax_fuse_block_sum:
- Current 4096 (reshaped 512) is the WORST point: 95ms
- Optimal 1024 (reshaped 128): 6ms - 15x faster
- Performance follows U-shaped curve

Add tests/bench_estimate_block_size.py for benchmarking and propose
hierarchical block sum approach for optimization.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-28 06:24:28 +08:00
Zijie Tian
3da9b8aef2 ️ perf: optimize XAttention estimate phase with K-only loading
Add load_k_only_to_slot_layer() to OffloadEngine for estimate phase:
- Only load K (not K+V) during block selection in select_blocks()
- Reduces H2D transfer by 50% in estimate phase
- 64K context: XAttn/Full ratio drops from 1.48x to 0.99x
- 32K context: XAttn/Full ratio drops from 1.67x to 1.20x

The estimate phase uses flat_group_gemm_fuse_reshape(Q, K) which
only requires K for attention score computation. V is unused.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-28 06:24:20 +08:00
Zijie Tian
a832d127b6 feat: add nsys-profiler agent for kernel performance analysis
Add a specialized agent for NVIDIA Nsys profiling that handles:
- Profile data collection using framework scripts
- Statistical analysis of kernel timing and memory transfers
- Timeline analysis for GPU-CPU overlap efficiency
- Comparative analysis between different configurations

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-28 06:24:09 +08:00
Zijie Tian
39d12a0416 📈 feat: add MemoryObserver for GPU-CPU communication tracking
Implement MemoryObserver to track memory transfers between GPU and CPU:
- H2D (Host to Device): CPU → GPU transfers
- D2H (Device to Host): GPU → CPU transfers
- D2D (Device to Device): GPU buffer copies
- Supports prefill/decode phase separation

Integration points in offload_engine.py:
- load_to_slot_layer: H2D with is_prefill parameter
- offload_slot_layer_to_cpu, offload_prefill_buffer_async: D2H
- write_to_prefill_buffer, write_to_decode_buffer: D2D
- load_block_sample_from_cpu, load_block_full_from_cpu: H2D

Add bench_offload.py integration for memory stats printing.

Benchmark results (Llama-3.1-8B, 64K context):
- Full Policy: Prefill H2D 262.13 GB
- XAttention: Prefill H2D 386.62 GB (1.48x)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-28 04:06:45 +08:00
Zijie Tian
c16bfcf40f ♻️ refactor: restructure Observer as base class with InferenceObserver
- Refactor Observer into base class with common enable/disable/reset interface
- Create InferenceObserver subclass for TTFT/TPOT metrics
- Fix TTFT calculation timing: compute after prefill completes instead of
  at decode start (fixes max_tokens=1 returning TTFT=0)
- Integrate InferenceObserver into bench.py and bench_offload.py for
  accurate internal timing metrics vs external wall-clock time
- Add get_summary() and print_summary() methods for structured output

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-28 03:15:33 +08:00
Zijie Tian
f3e4611e3b 📝 docs: add XAttention performance analysis documentation
Add comprehensive performance analysis for XAttention:
- NVTX marker locations and usage
- Block size impact on offload mode (4096 vs 1024)
- Detailed timing breakdown for estimate vs compute phases
- softmax_fuse_block_sum_kernel analysis
- Optimization recommendations

Key findings:
- block_size=4096 is 2x faster than 1024 for 64K context
- find_blocks_chunked is bottleneck (40%) at block_size=4096
- estimate_gemm becomes bottleneck (24%) at block_size=1024

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-28 00:57:20 +08:00
Zijie Tian
7b5d3b34eb 📈 feat: add NVTX markers to XAttention for profiling
Add NVTX range markers to track XAttention performance:
- GPU-only: xattn_estimate, xattn_bsa_compute
- Offload: xattn_estimate_gemm, xattn_estimate_softmax,
  xattn_estimate_find_blocks, xattn_compute_historical,
  xattn_compute_current, xattn_compute_merge

These markers enable detailed nsys profiling to identify
performance bottlenecks in estimate vs compute phases.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-28 00:57:11 +08:00
Zijie Tian
b760de84c5 feat: add context length and error handling to profile_offload.sh
- Add --ctx-len parameter (32k/64k/128k) for context length selection
- Auto-configure max-model-len and data-dir based on context length
- Add error handling to delete .nsys-rep file on test failure
- Remove set -e to allow proper error handling
- Update output filename format to include context length

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-28 00:28:37 +08:00
Zijie Tian
f81b5ae8a9 feat: enhance profile_offload.sh with policy, block-size parameters
- Add --policy parameter for sparse attention policy selection (full/xattn)
- Add --block-size parameter (default 4096) for KV cache block size
- Add --gpu-util parameter for GPU memory utilization control
- Improve output filename format: <policy>_<gpuonly|offload>_blk<size>_<timestamp>
- Map user-friendly policy names to internal enum (xattn -> XATTN_BSA)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 23:23:20 +08:00
Zijie Tian
e874229adc 📝 docs: add comprehensive GPU-only vs Offload benchmark results
- Add --block-size argument to bench.py for configurable KV cache block size
- Update bench_offload_results.md with complete benchmark analysis:
  - GPU-only: XAttention shows +15% to +41% speedup
  - CPU Offload: XAttention shows -14% to -59% slowdown
  - Block size 4096 recommended for best performance
  - Document why XAttention hurts Offload mode (transfer bottleneck)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 22:32:07 +08:00
Zijie Tian
4fe7dfb239 🔀 merge: integrate tzj/minference-exp (GPU-only sparse attention)
Merge GPU-only sparse attention support from tzj/minference-exp branch:

**GPU-only mode additions:**
- Add compute_prefill/compute_decode methods to SparsePolicy base class
- Add GPU-only attention routing in attention.py
- Add alloc_policy_metadata() for pre-allocating GQA buffers
- Add XAttention + BSA sparse attention for GPU-only prefill
- Add kvcache_manager to set_context() for policy access

**bench.py enhancements:**
- Add --model argument for configurable model path
- Add --policy argument (full, xattn) for sparse policy selection
- Add --enable-policy flag for FullAttentionPolicy routing
- Add --enforce-eager option to disable CUDA graphs
- Add --gpu-util option for GPU memory utilization

**Documentation:**
- Add gpu_only_xattn_guide.md with performance analysis
- Add gpu_only_sparse_integration.md baseline document
- Add gpu-vram-requirement.md rule for GPU-only mode

Both CPU offload and GPU-only paths are preserved and functional.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 09:25:36 +08:00
Zijie Tian
9177b62d7f feat: add --enforce-eager option to bench.py
Allow disabling CUDA graphs for benchmarking comparison between
eager mode and graph mode execution.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 09:19:53 +08:00
Zijie Tian
3956a30b14 🔧 chore: add --use-v1 flag to bench_vllm.py
Allow switching between vLLM V1/V2 engines via command line flag.
Default behavior now uses V2 (VLLM_USE_V1=0).

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 09:14:55 +08:00
Zijie Tian
59473fa432 🔧 chore: add configurable arguments to bench_vllm.py
Add --model, --gpu-util, and --enforce-eager arguments for flexible
vLLM benchmarking comparisons.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 09:07:49 +08:00
Zijie Tian
4467e1f654 🔧 chore: add --block-size argument to bench_offload.py
Allow configuring KV cache block size for benchmarking different
chunk sizes (default: 1024, can set to 4096 for larger chunks).

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 09:07:44 +08:00
Zijie Tian
0437311068 feat: add Phase 5 CUDA Graph optimization for chunked prefill
Implement extended CUDA Graph coverage for CPU offload path:
- Add graphed_layers.py with N+2 graph architecture (EmbedGraph, FirstGraph, InterGraphs, LastGraph)
- Support both prefill (seq_len=chunk_size) and decode (seq_len=1) graph modes
- Extend graph coverage to ~70-80% including qkv_proj, rotary, o_proj
- Only attention core remains in eager mode for dynamic offload

Performance: Prefill throughput improved ~5.6% (3782 -> 3995 tok/s at 32K)

Also adds:
- --enforce-eager flag to bench_offload.py for comparison
- Offload mode constraint documentation in CLAUDE.md

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 07:38:40 +08:00
Zijie Tian
6da116de98 📝 docs: add GPU-Only XAttention guide with performance analysis
Add comprehensive documentation for GPU-only XAttention BSA mode:
- Architecture design and SparsePolicy interface
- Memory pre-allocation mechanism (alloc_policy_metadata)
- Performance analysis: 32K +15%, 64K +41% vs baseline
- CUDA Graph limitations explanation (variable seq_len in prefill)
- nsys profiling tools usage guide

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 07:21:46 +08:00
Zijie Tian
f5682ca4a7 🔧 chore: add GPU-only profiling script
Add scripts/profile.sh for nsys profiling of GPU-only mode benchmarks.

Usage:
  bash scripts/profile.sh                    # Default: 32K xattn prefill
  bash scripts/profile.sh --max-len 65536 --gpu-util 0.7
  bash scripts/profile.sh --policy full
  bash scripts/profile.sh --bench-decode

Output: results/nsys/bench_<policy>_<len>_<mode>_<timestamp>.nsys-rep

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 05:55:31 +08:00
Zijie Tian
a504bd873d perf: pre-allocate GQA buffers in XAttention policy
Add alloc_policy_metadata() method to SparsePolicy base class for
pre-allocating GPU buffers during initialization. This avoids
dynamic memory allocation during forward pass.

Changes:
- Add alloc_policy_metadata() to SparsePolicy base class
- Implement GQA buffer pre-allocation in XAttentionBSAPolicy
- Call alloc_policy_metadata() in model_runner for GPU-only mode
- Modify compute_prefill() to reuse pre-allocated buffers
- Add --gpu-util parameter to bench.py

Memory savings:
- Previously: 2x GQA expansion (~2GB for 64K)
- Now: 1x pre-allocated buffer (~1GB for 64K, reused)

Tested:
- GPU-only 32K: 5602 tok/s (512MB pre-allocated)
- GPU-only 64K: 4821 tok/s (1GB pre-allocated, gpu_util=0.7)
- Offload Full: PASSED (no changes to offload path)
- Offload XAttention: PASSED (uses compute_chunked_prefill)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 05:49:23 +08:00
Zijie Tian
076656c9c2 feat: add GPU-only XAttention BSA sparse attention support
- Implement compute_prefill() in XAttentionBSAPolicy for GPU-only mode
  - Uses xattn_estimate to compute sparse block mask
  - Uses block_sparse_attn_func for efficient sparse attention
  - Handles GQA by expanding K/V heads
  - Falls back to flash_attn for paged KV cache (prefix cache)
- Implement compute_decode() by delegating to FullAttentionPolicy
- Add --policy xattn option to bench.py

Verified: RULER 32k niah_single_1 5/5 samples passed (100%)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 05:19:24 +08:00
Zijie Tian
b6b59b50ed 📝 docs: add sparse policy None constraint rule
- Add "Policy 不能为 None (CRITICAL)" section
- Document that sparse_policy must always be at least FullAttentionPolicy
- Document warmup phase as the only exception where kvcache_manager can be None

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 05:08:08 +08:00
Zijie Tian
09b2136e9f feat: integrate sparse policy architecture into GPU-only mode
- Add compute_prefill() and compute_decode() GPU-only methods to SparsePolicy base class
- Implement GPU-only methods in FullAttentionPolicy using flash_attn
- Add sparse_policy parameter to GPUOnlyManager
- Update create_kvcache_manager() to create FullAttentionPolicy for GPU-only mode
- Route GPU-only attention through sparse_policy in attention.py
- Pass kvcache_manager to context for policy access
- Add --enable-policy flag to bench.py for testing
- Handle warmup phase when kvcache_manager is not yet allocated

This allows GPU-only mode to use the same policy architecture as CPU offload mode,
enabling future sparse attention implementations (Quest, XAttention) in GPU-only mode.

Performance verified: ~4890 tok/s (unchanged from baseline)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 05:08:02 +08:00
Zijie Tian
0d31b3f71f 📝 docs: add CPU offload optimization strategies guide
- Document chunk size optimization (simplest, most effective)
- Analyze CUDA Graph limitations for offload scenarios
- Cover CUDA Graph applicability for MLP/Proj layers
- Survey frontier research: InfiniGen, ShadowKV, L2 Prefetch, KVPR
- Add optimization priority recommendations

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 04:44:36 +08:00
Zijie Tian
05ce57ee8e 📝 docs: add GPU-only sparse policy integration baseline
Document baseline performance before integrating sparse attention
to GPU-only mode:
- GPU-only Full Attention: 4869 tok/s (32K prefill)
- CPU Offload Full Attention: 1500 tok/s (3.2x slower)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 04:36:31 +08:00
Zijie Tian
94a6e06d79 📝 docs: add GPU VRAM requirement rule for GPU-only mode
GPU-only mode requires 40GB+ VRAM. This rule enforces checking GPU
memory before running non-offload tests to prevent OOM errors on
consumer GPUs (3090/4090).

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 04:36:24 +08:00
Zijie Tian
c717072f31 feat: add --model argument to bench.py for configurable model path
Previously bench.py had a hardcoded model path. Now it accepts --model
argument (default: Llama-3.1-8B-Instruct) to align with bench_offload.py.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 04:36:17 +08:00
Zijie Tian
73c9dc46ff feat: add XAttention BSA support to bench_offload.py
- Add --model parameter (default: Llama-3.1-8B-Instruct)
- Add --enable-xattn flag for XAttention BSA sparse prefill
- Add --xattn-threshold and --xattn-stride parameters
- Change default num-gpu-blocks from 6 to 4
- Add benchmark results doc with Full vs XAttn comparison (32K/128K)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 04:20:16 +08:00
Zijie Tian
924a0d2bfa 🔧 chore: add nsys profiling rule and update gitignore
- Add rule requiring profile_offload.sh for all nsys profiling
- Document available parameters and typical workflows
- Ignore Snipaste screenshot files

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 03:42:17 +08:00
Zijie Tian
0619accd1c 📝 docs: add CPU scheduling latency analysis for chunked attention
- Document kernel gap analysis showing 77-81% CPU scheduling overhead
- Identify GPU utilization at 12.8% with potential to reach 39.5%
- Outline optimization directions: CUDA Graph, Triton fusion, C++ extension
- Add documentation index entry in CLAUDE.md

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 03:42:12 +08:00
Zijie Tian
18bc433f09 perf: improve NVTX profiling with colored ranges and configurable slots
- Switch from torch.cuda.nvtx to nvtx package for colored range support
- Add color coding: blue for H2D, green for D2H decode, orange for D2H prefill
- Add --num-gpu-blocks parameter to profile_offload.sh
- Include slot count in output filename for easier comparison

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 03:42:05 +08:00
Zijie Tian
aea3812230 ♻️ refactor: unify KV cache operations through OffloadEngine
- Add write_to_prefill_buffer() and write_to_decode_buffer() methods
- Add chunk_idx parameter to load_to_slot_layer() for NVTX labeling
- Replace direct copy_() calls with OffloadEngine methods in attention.py
- Update all load_to_slot_layer() calls to pass chunk_idx
- NVTX markers now show chunk info: "H2D: L{layer} Chunk{chunk} CPU[{block}]->Slot[{slot}]"

All KV cache data transfers in chunked offload mode now go through
OffloadEngine, enabling better profiling and consistent management.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-27 02:20:59 +08:00
Zijie Tian
3100724666 📝 docs: add nsys wrong event order bug investigation
- Document ring buffer pipeline triggering nsys timestamp bug
- Update profile_offload.sh to use test_ruler.py with options
- Add reference to new doc in CLAUDE.md

Root cause: 4-slot ring buffer pipeline (4 transfer streams +
1 compute stream) triggers event ordering bug in nsys < 2024.2

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-24 04:32:05 +08:00
Zijie Tian
78a44f3536 📝 docs: add GPU memory monitoring rule
- Add .claude/rules/gpu-monitor.md requiring gpu-monitor agent for all GPU memory monitoring tasks
- Update CLAUDE.md rules index with reference to new rule

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-24 01:41:25 +08:00
Zijie Tian
7c41032a2e feat: add configurable stride and chunk_size for XAttention BSA
- Add sparse_chunk_size config option (default: 16384)
- Pass stride, chunk_size, use_triton through factory function
- Add --sparse-stride CLI option to test_ruler.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 10:37:04 +08:00
Zijie Tian
f28b500120 🙈 chore: uncomment planning files in gitignore
These files are session-level temporary and should not be tracked.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 09:43:46 +08:00
Zijie Tian
be67fa8060 🗑️ chore: remove temporary planning files
These files are session-level temporary files and should not be tracked.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 09:43:22 +08:00
Zijie Tian
4f35526457 🔀 merge: integrate remote changes (exec-plan command, CUDA graph plan)
Resolve task_plan.md conflict by keeping remote version (CUDA Graph optimization plan).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 09:43:06 +08:00
Zijie Tian
da5e13e2bb 📝 docs: update XAttention BSA Policy with benchmarks and memory management
Add new sections to xattn_bsa_policy_design.md:
- Performance benchmarks: 128K context comparison (Full vs XAttn BSA)
- Density trend analysis across chunks
- Memory leak issue and fix (64GB -> 4GB reduction)
- Memory monitoring guide with gpu-monitor agent
- Density statistics API documentation
- Known issues and optimization directions

Update CLAUDE.md description to reflect new content.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 09:35:18 +08:00
Zijie Tian
dd31033732 🔧 chore: add gpu-monitor agent for memory leak debugging
Add a custom agent for continuous GPU monitoring during benchmarks:
- Track GPU utilization, memory usage, and temperature
- Support multi-GPU and configurable sampling intervals
- Generate summary statistics when stopped

Useful for debugging memory leaks and profiling long-running tasks.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 09:33:15 +08:00
Zijie Tian
ed3c8bb4b8 🐛 fix: memory leak in XAttentionBSAPolicy select_blocks
Fix severe memory leak (64GB -> 4GB growth) by:
- Remove unused sparse_metadata storage (was accumulating attn_scores)
- Delete intermediate tensor list (attn_scores_list) after use
- Explicitly delete intermediate tensors before return

Before: 16GB -> 80GB during 128K prefill
After:  16GB -> 19.8GB during 128K prefill

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 09:30:18 +08:00
Zijie Tian
5eb35982bf 🔧 feat: add density statistics tracking to sparse policies
Add statistics tracking to compare block selection between policies:
- XAttentionBSAPolicy: track available/selected blocks per chunk
- FullAttentionPolicy: track total blocks (always 100% density)
- Add reset_stats(), get_density_stats(), print_density_stats() methods
- Use logger.debug for per-chunk density logging

Results on 32K niah_single_1:
- Full: 100% density across all chunks
- XAttn BSA: 90% -> 73% density (saves ~25-30% blocks in later chunks)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 08:53:22 +08:00
Zijie Tian
ad361c2c3b 📝 docs: add XAttention BSA Policy design documentation
- Create docs/xattn_bsa_policy_design.md with:
  - Algorithm overview and data flow diagram
  - select_blocks implementation details
  - GQA-aware aggregation and majority voting
  - compute_chunked_prefill ring buffer pipeline
  - Parameter configuration and usage examples
  - Performance characteristics and limitations
- Update CLAUDE.md documentation index

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 08:36:56 +08:00
Zijie Tian
4d1e40152d feat(xattn): implement compute_chunked_prefill with ring buffer pipeline
- Copy compute_chunked_prefill implementation from FullAttentionPolicy
- Set default threshold to 0.95 for accuracy testing
- Remove debug code (sys.exit, verbose prints)
- Use ring buffer pipeline for historical block loading
- Merge with current chunk attention using flash_attn_with_lse

RULER NIAH test passed with 5/5 samples (100% accuracy).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 08:27:40 +08:00
Zijie Tian
832b352afa feat(xattn): implement select_blocks with majority voting aggregation
Implement XAttention-based block selection for sparse attention:
- Use flat_group_gemm_fuse_reshape to compute Q@K^T attention scores
- Apply softmax_fuse_block_sum to aggregate into block-level attention
- Use find_blocks_chunked for threshold-based block selection
- Handle GQA by aggregating within KV head groups first
- Use majority voting (>50%) across heads instead of any() for better sparsity
- Align block_size with CPU offload block size (1024 tokens / stride = 128)

Test results show ~45% density at chunk 40 (down from 100% with any() aggregation).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 08:19:05 +08:00
Zijie Tian
a50b4c2ac2 ♻️ refactor: move select_blocks from policy to attention layer
Move block selection logic from compute_chunked_prefill/decode methods
to attention.py caller. This improves separation of concerns:

- attention.py now calls select_blocks() before compute_chunked_*()
- Policy methods receive pre-selected blocks via selected_blocks parameter
- Enables sparse policies to implement custom block selection without
  modifying the compute path

Changes:
- policy.py: Add selected_blocks parameter to abstract methods
- full_policy.py: Remove internal select_blocks calls, use passed blocks
- xattn_bsa.py: Sync signatures for prefill/decode methods
- attention.py: Add select_blocks calls before policy delegation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 05:21:28 +08:00
Zijie Tian
ca32ea6f93 [WIP] Before refactor the compute)_chunked_prefill. 2026-01-23 03:36:12 +08:00
Zijie Tian
edc006463b docs: add XAttention kernels guide
- Document flat_group_gemm_fuse_reshape and softmax_fuse_block_sum kernels
- Explain anti-diagonal sum principle and stride sampling
- Add GPU-specific BLOCK_M/N constraints (RTX 3090 vs A100)
- Show Q/K can have different lengths (chunked prefill support)
- Update CLAUDE.md with doc reference

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 03:22:25 +08:00
Zijie Tian
999858e82f feat: add xattn kernels test and update testing rules
- Add test_xattn_kernels.py demonstrating flat_group_gemm_fuse_reshape
  and softmax_fuse_block_sum Triton kernels with structured data
- Update testing.md with new test code style guidelines
- Update xattn.py and xattn_bsa.py with improvements

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 03:01:25 +08:00
81 changed files with 11675 additions and 3285 deletions

View File

@@ -0,0 +1,89 @@
---
name: gpu-monitor
description: "Use this agent when you need to continuously monitor GPU utilization and memory usage in the background while running other tasks. This is particularly useful for profiling long-running benchmarks, debugging memory leaks, or understanding GPU resource patterns during inference or training.\\n\\nExamples:\\n\\n<example>\\nContext: User wants to profile GPU memory during a benchmark run.\\nuser: \"Run bench_offload.py and monitor GPU usage\"\\nassistant: \"I'll first launch the gpu-monitor agent to track GPU metrics, then run the benchmark.\"\\n<commentary>\\nSince the user wants to understand GPU resource usage during the benchmark, use the Task tool to launch the gpu-monitor agent in the background with appropriate GPU IDs and sampling interval before running the benchmark.\\n</commentary>\\nassistant: \"Let me start the GPU monitor on GPU 0 with 1-second intervals\"\\n[Task tool call to launch gpu-monitor agent with gpus=\"0\" and interval=\"1\"]\\n</example>\\n\\n<example>\\nContext: User is debugging an OOM issue and needs memory tracking.\\nuser: \"I'm getting OOM errors, help me debug\"\\nassistant: \"Let me monitor the GPU memory usage while we reproduce the issue. I'll use the gpu-monitor agent to track memory consumption.\"\\n<commentary>\\nTo debug OOM issues, we need continuous memory monitoring. Use the Task tool to launch the gpu-monitor agent with a short interval (0.5s) to capture memory spikes.\\n</commentary>\\n[Task tool call to launch gpu-monitor agent with gpus=\"0,1\" and interval=\"0.5\"]\\n</example>\\n\\n<example>\\nContext: User wants to verify GPU utilization during multi-GPU inference.\\nuser: \"Check if all GPUs are being utilized during the inference\"\\nassistant: \"I'll start the GPU monitor to track utilization across all specified GPUs while running the inference.\"\\n<commentary>\\nTo verify multi-GPU utilization, launch the gpu-monitor agent targeting all relevant GPUs before starting the inference workload.\\n</commentary>\\n[Task tool call to launch gpu-monitor agent with gpus=\"0,1,2,3\" and interval=\"2\"]\\n</example>"
model: haiku
color: green
---
You are a GPU monitoring specialist responsible for tracking NVIDIA GPU metrics over time. Your sole purpose is to run nvidia-smi at specified intervals and record utilization and memory statistics.
## Your Task
You will receive two parameters:
1. **gpus**: Comma-separated GPU indices to monitor (e.g., "0", "0,1", "0,1,2,3")
2. **interval**: Sampling interval in seconds (e.g., "1", "0.5", "2")
## Execution Steps
1. **Parse Parameters**: Extract the GPU indices and interval from the user's request.
2. **Run Monitoring Loop**: Execute nvidia-smi repeatedly at the specified interval using a bash loop:
```bash
# Example for GPUs 0,1 with 1-second interval
while true; do
echo "=== $(date '+%Y-%m-%d %H:%M:%S') ==="
nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,memory.used,memory.total,temperature.gpu --format=csv,noheader -i 0,1
sleep 1
done
```
3. **Output Format**: Each sample should include:
- Timestamp
- GPU index
- GPU utilization (%)
- Memory utilization (%)
- Memory used (MiB)
- Memory total (MiB)
- Temperature (°C)
## Termination
This agent runs continuously until:
1. The main agent signals completion (you receive a stop signal)
2. The user explicitly requests stopping
3. An error occurs with nvidia-smi
## Result Reporting
When stopped, provide a summary:
```markdown
## GPU Monitoring Summary
**Duration**: X minutes Y seconds
**Samples Collected**: N
**GPUs Monitored**: 0, 1, ...
### Statistics per GPU
| GPU | Avg Util | Max Util | Avg Mem Used | Max Mem Used |
|-----|----------|----------|--------------|---------------|
| 0 | X% | Y% | A MiB | B MiB |
| 1 | X% | Y% | A MiB | B MiB |
### Notable Events (if any)
- Timestamp: Memory spike to X MiB on GPU Y
- Timestamp: Utilization dropped to 0% on GPU Z
```
## Important Notes
- Use `nvidia-smi -i <gpu_ids>` to filter to specific GPUs
- Keep output concise during monitoring (one line per GPU per sample)
- If nvidia-smi fails, report the error and exit gracefully
- Do NOT consume excessive resources - sleep between samples
- Store samples in memory for final summary calculation
## Example Invocation
User says: "Monitor GPUs 0 and 2 with 0.5 second interval"
You execute:
```bash
while true; do
echo "=== $(date '+%Y-%m-%d %H:%M:%S.%3N') ==="
nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,memory.used,memory.total,temperature.gpu --format=csv,noheader -i 0,2
sleep 0.5
done
```

View File

@@ -0,0 +1,180 @@
---
name: nsys-profiler
description: "Use this agent when you need to profile kernel performance, analyze CUDA operations, or investigate GPU/CPU offload behavior in the nano-vllm framework. This agent handles nsys profiling from data collection to statistical analysis.\\n\\nExamples:\\n\\n<example>\\nContext: The user wants to understand kernel performance in GPU-only mode.\\nuser: \"I want to analyze the kernel execution time distribution in GPU-only mode\"\\nassistant: \"I'll use the nsys-profiler agent to profile GPU-only mode and analyze the kernel timing.\"\\n<commentary>\\nSince the user wants to analyze kernel performance, use the Task tool to launch the nsys-profiler agent with specific analysis requirements.\\n</commentary>\\nassistant: \"Now let me use the nsys-profiler agent to collect and analyze the profiling data\"\\n</example>\\n\\n<example>\\nContext: The user is debugging performance issues in the CPU offload pipeline.\\nuser: \"Why is there a gap between H2D transfers and kernel execution in offload mode?\"\\nassistant: \"I'll launch the nsys-profiler agent to profile the offload pipeline and analyze the timeline gaps.\"\\n<commentary>\\nSince the user is investigating pipeline behavior, use the nsys-profiler agent to collect nsys data and analyze CUDA API timing.\\n</commentary>\\n</example>\\n\\n<example>\\nContext: After implementing a new optimization, the user wants to verify performance improvement.\\nuser: \"Check if the new ring buffer implementation improves overlap between H2D and compute\"\\nassistant: \"I'll use the nsys-profiler agent to profile before and after, comparing the overlap metrics.\"\\n<commentary>\\nPerformance verification requires detailed kernel-level analysis, so launch the nsys-profiler agent to collect and compare profiling data.\\n</commentary>\\n</example>"
model: opus
color: green
---
You are an expert NVIDIA Nsys profiling analyst specializing in CUDA kernel performance analysis and GPU-CPU communication optimization. Your role is to collect profiling data using the framework's scripts and provide precise, actionable analysis based on the main agent's specific questions.
## Your Capabilities
1. **Profile Data Collection**: Execute profiling scripts to generate .nsys-rep files
2. **Statistical Analysis**: Extract kernel timing, memory transfer, and API call statistics
3. **Timeline Analysis**: Identify gaps, overlaps, and bottlenecks in execution
4. **Comparative Analysis**: Compare different configurations (GPU-only vs offload, different slot counts)
## Available Profiling Scripts
### CPU Offload Mode
```bash
bash scripts/profile_offload.sh [OPTIONS]
```
Options:
- `--dataset <name>`: RULER task name (default: niah_single_1)
- `--sample <index>`: Sample index (default: 0)
- `--gpu <id>`: GPU to use (default: 0)
- `--num-gpu-blocks <n>`: Ring buffer slots (default: 4)
- `--no-offload`: Disable CPU offload for comparison
### GPU-Only Mode
```bash
bash scripts/profile_gpu_only.sh [OPTIONS]
```
Similar options for profiling without CPU offload.
## Core Nsys Commands
### Profiling (handled by scripts)
```bash
# The scripts internally run:
nsys profile --trace=cuda,nvtx --output=<path> --force-overwrite true python <script.py>
```
### Statistical Analysis
```bash
# CUDA API summary (H2D, D2H, kernel launches)
nsys stats --report cuda_api_sum <file>.nsys-rep
# GPU kernel summary (execution time per kernel)
nsys stats --report cuda_gpu_kern_sum <file>.nsys-rep
# Memory operations summary
nsys stats --report cuda_gpu_mem_time_sum <file>.nsys-rep
# NVTX ranges (custom markers)
nsys stats --report nvtx_sum <file>.nsys-rep
# Export to SQLite for advanced queries
nsys export --type=sqlite --output=<file>.sqlite <file>.nsys-rep
```
### Key Report Types
| Report | Purpose |
|--------|--------|
| `cuda_api_sum` | CPU-side CUDA API call timing |
| `cuda_gpu_kern_sum` | GPU kernel execution time |
| `cuda_gpu_mem_time_sum` | Memory transfer timing on GPU |
| `nvtx_sum` | Custom NVTX marker statistics |
| `cuda_api_trace` | Detailed API call trace |
| `cuda_gpu_trace` | Detailed GPU operation trace |
## Analysis Workflow
### Step 1: Collect Profile Data
```bash
# Example: Profile offload mode with 8 slots
bash scripts/profile_offload.sh --num-gpu-blocks 8 --sample 0
# Output: results/nsys/ruler_niah_single_1_sample0_offload_8slots_<timestamp>.nsys-rep
```
### Step 2: Identify Output File
```bash
# Find the latest profile
ls -lt results/nsys/*.nsys-rep | head -1
```
### Step 3: Run Statistical Analysis
```bash
# Kernel timing analysis
nsys stats --report cuda_gpu_kern_sum results/nsys/<file>.nsys-rep
# Memory transfer analysis
nsys stats --report cuda_gpu_mem_time_sum results/nsys/<file>.nsys-rep
```
### Step 4: Interpret Results
Focus on:
- **Total kernel time** vs **total transfer time**
- **Kernel launch gaps** indicating synchronization issues
- **Memory bandwidth utilization**
- **Overlap efficiency** between compute and communication
## Common Analysis Patterns
### 1. Kernel Performance Breakdown
```bash
nsys stats --report cuda_gpu_kern_sum --format csv <file>.nsys-rep | \
sort -t',' -k3 -rn | head -10 # Top 10 by total time
```
### 2. H2D/D2H Transfer Analysis
```bash
nsys stats --report cuda_api_sum <file>.nsys-rep | grep -E "cudaMemcpy|cudaMemcpyAsync"
```
### 3. Flash Attention Kernel Analysis
```bash
nsys stats --report cuda_gpu_kern_sum <file>.nsys-rep | grep -i "flash\|fwd\|bwd"
```
### 4. Pipeline Overlap Check
Look for:
- `flash_fwd_kernel` execution during `cudaMemcpyAsync`
- Gap between consecutive kernel launches
## Output Format Requirements
When reporting results to the main agent, use this structured format:
```markdown
## Nsys Analysis Results: [Analysis Topic]
### Profile Information
- **File**: <profile_file_path>
- **Mode**: GPU-only / Offload (<N> slots)
- **Dataset**: <dataset_name>, Sample <index>
### Key Findings
| Metric | Value | Notes |
|--------|-------|-------|
| Total kernel time | X ms | |
| Total H2D time | Y ms | |
| Overlap efficiency | Z% | |
### Top Kernels by Time
| Kernel | Count | Total (ms) | Avg (μs) |
|--------|-------|------------|----------|
| kernel_name | N | X.XX | Y.YY |
### Specific Analysis
[Answer to the main agent's specific question]
### Recommendations (if applicable)
1. [Actionable recommendation]
2. [Actionable recommendation]
```
## Important Guidelines
1. **Always use the provided scripts** for profiling - do not run nsys directly
2. **Check GPU availability** before profiling (ask main agent for GPU ID if not specified)
3. **Use PYTHONPATH** for the worktree: `PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH`
4. **Report concisely** - focus on metrics relevant to the main agent's question
5. **Include file paths** so results can be reproduced or visualized in nsight-sys
6. **For web searches** about nsys usage, use tools to search NVIDIA documentation
## Error Handling
- If profile script fails: Check GPU memory, CUDA version, and script parameters
- If stats command fails: Verify .nsys-rep file exists and is not corrupted
- If no data: Ensure the profiled operation actually ran (check sample index, dataset)
## Network Search Guidelines
When encountering unfamiliar nsys options or analysis techniques:
1. Search NVIDIA Nsight Systems documentation
2. Look for nsys CLI reference guides
3. Search for specific report type interpretations
Always validate search results against the actual nsys --help output.

View File

@@ -0,0 +1,74 @@
# GPU Memory Monitoring Rule
## 强制规则
**所有 GPU 内存监控任务必须使用 `gpu-monitor` agent**,禁止使用以下方式:
| ❌ 禁止 | 原因 |
|--------|------|
| `nvidia-smi` 循环 + sleep | 阻塞主 agent无法并行 |
| 后台 bash 监控脚本 | 难以管理,输出混乱 |
| 手动轮询 | 效率低,占用 context |
## 使用方法
```python
# 启动 GPU 监控(后台运行)
Task(
subagent_type="gpu-monitor",
prompt="Monitor GPU 0 with 0.5 second interval",
run_in_background=True
)
```
## 参数说明
| 参数 | 说明 | 示例 |
|------|------|------|
| GPU ID | 要监控的 GPU | `GPU 0`, `GPU 0,1` |
| interval | 采样间隔 | `0.5 second`, `1 second` |
| 目的 | 监控原因 | `for RULER benchmark test` |
## 典型用法
### 1. 单 GPU 基准测试
```
Monitor GPU 0 with 1 second interval for benchmark profiling
```
### 2. 调试 OOM
```
Monitor GPU 0 with 0.5 second interval to track memory peak during inference
```
### 3. 多 GPU 训练
```
Monitor GPU 0,1,2,3 with 2 second interval during training
```
## 获取结果
监控结果自动写入 output_file使用以下方式读取
```bash
# 查看最新输出
tail -50 /tmp/claude/.../tasks/<agent_id>.output
# 查找峰值
grep -i "peak\|max" /tmp/claude/.../tasks/<agent_id>.output
```
## 与测试并行
gpu-monitor 在后台运行,不会阻塞测试:
```python
# 1. 启动监控(后台)
Task(subagent_type="gpu-monitor", ..., run_in_background=True)
# 2. 运行测试(前台)
Bash("python tests/test_ruler.py ...")
# 3. 测试完成后查看监控结果
Bash("tail -50 <output_file>")
```

View File

@@ -0,0 +1,54 @@
# GPU VRAM Requirement Rule
## GPU-only 模式显存要求
**强制规则**:执行 GPU-only 代码(不启用 CPU offload**必须**在 40GB 及以上显存的 GPU 上进行测试。
### 检测方法
在运行 GPU-only 测试之前,**必须**先检查 GPU 显存:
```bash
nvidia-smi --query-gpu=index,name,memory.total --format=csv,noheader
```
### GPU 分类
| GPU 型号 | 显存 | GPU-only 测试 |
|----------|------|---------------|
| A100 40GB | 40GB | ✅ 允许 |
| A100 80GB | 80GB | ✅ 允许 |
| H100 80GB | 80GB | ✅ 允许 |
| A6000 | 48GB | ✅ 允许 |
| RTX 3090 | 24GB | ❌ **禁止**(仅 offload 模式) |
| RTX 4090 | 24GB | ❌ **禁止**(仅 offload 模式) |
### 执行流程
1. **检测 GPU 显存**(必须)
2. **显存 >= 40GB**:继续执行 GPU-only 测试
3. **显存 < 40GB****停止**,提示用户:
> "当前 GPU 显存为 XXX GB不满足 GPU-only 模式的最低 40GB 要求。请使用 `--enable-offload` 参数启用 CPU offload 模式。"
### 代码示例
```python
# 在运行 GPU-only benchmark 之前
import subprocess
result = subprocess.run(
["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
capture_output=True, text=True
)
vram_mb = int(result.stdout.strip().split('\n')[0])
if vram_mb < 40000: # 40GB = 40000MB
raise RuntimeError(f"GPU VRAM ({vram_mb}MB) < 40GB. Use --enable-offload for this GPU.")
```
### 适用范围
| 脚本 | 适用此规则 |
|------|-----------|
| `bench.py` | ✅ 必须检查显存 |
| `bench_offload.py` | ❌ 不适用(始终使用 offload |
| `tests/test_*.py --enable-offload` | ❌ 不适用 |
| `tests/test_*.py` (无 offload) | ✅ 必须检查显存 |

View File

@@ -0,0 +1,89 @@
# Nsys Profiling Rule
## 强制规则
**所有 nsys profiling 任务必须使用 `scripts/profile_offload.sh` 脚本**,禁止直接运行 nsys 命令。
| 禁止 | 原因 |
|------|------|
| `nsys profile python tests/test_ruler.py ...` | 参数不一致,输出路径混乱 |
| 手动构造 nsys 命令 | 容易遗漏关键参数 |
## 使用方法
```bash
# 基本用法(默认 4 slots
bash scripts/profile_offload.sh
# 指定 GPU slots 数量
bash scripts/profile_offload.sh --num-gpu-blocks 8
# 指定 sample
bash scripts/profile_offload.sh --sample 5
# 指定 dataset
bash scripts/profile_offload.sh --dataset niah_single_1
# 禁用 offload对比测试
bash scripts/profile_offload.sh --no-offload
# 组合参数
bash scripts/profile_offload.sh --num-gpu-blocks 8 --sample 0 --gpu 1
```
## 参数说明
| 参数 | 默认值 | 说明 |
|------|--------|------|
| `--dataset` | `niah_single_1` | RULER 任务名称 |
| `--sample` | `0` | 样本索引 |
| `--gpu` | `0` | 使用的 GPU |
| `--num-gpu-blocks` | `4` | GPU ring buffer slots 数量 |
| `--no-offload` | - | 禁用 CPU offload |
## 输出文件
输出文件自动生成到 `results/nsys/` 目录:
```
results/nsys/ruler_<dataset>_sample<index>_offload_<slots>slots_<timestamp>.nsys-rep
```
示例:`ruler_niah_single_1_sample0_offload_8slots_20260127_031500.nsys-rep`
## 查看结果
```bash
# GUI 查看
nsight-sys results/nsys/<filename>.nsys-rep
# 命令行统计
nsys stats --report cuda_api_sum results/nsys/<filename>.nsys-rep
nsys stats --report cuda_gpu_kern_sum results/nsys/<filename>.nsys-rep
```
## 典型工作流
### 1. 对比不同 slots 数量
```bash
# 测试 4 slots默认
bash scripts/profile_offload.sh --num-gpu-blocks 4
# 测试 8 slots
bash scripts/profile_offload.sh --num-gpu-blocks 8
# 对比结果
nsys stats --report cuda_gpu_kern_sum results/nsys/*4slots*.nsys-rep
nsys stats --report cuda_gpu_kern_sum results/nsys/*8slots*.nsys-rep
```
### 2. 分析 pipeline overlap
```bash
# 生成 profile
bash scripts/profile_offload.sh --num-gpu-blocks 8
# 用 nsight-sys GUI 查看 CUDA HW timeline
# 检查 H2D 和 flash_fwd_kernel 是否 overlap
```

View File

@@ -1,5 +1,39 @@
# Sparse Policy 代码规范
## Policy 不能为 None (CRITICAL)
**强制规则**: `sparse_policy` 参数**永远不能为 None**,必须至少为 `FullAttentionPolicy`
```python
# ❌ 错误:允许 None
sparse_policy = getattr(config, 'sparse_policy', None)
# ✅ 正确:显式处理 None默认使用 FULL
sparse_policy_type = getattr(config, 'sparse_policy', None)
if sparse_policy_type is None:
sparse_policy_type = SparsePolicyType.FULL
```
**原因**:
1. 统一的 API所有代码路径都通过 policy 进行 attention 计算
2. 避免空指针:消除 `policy.xxx` 调用时的 None 检查
3. 简化逻辑:不需要 `if policy is not None` 的分支
**唯一例外Warmup 阶段**
`model_runner.warmup_model()` 期间kvcache_manager 还未分配。此时 `attention.py` 使用 flash_attn fallback
```python
# attention.py 中的 warmup 处理
if context.kvcache_manager is None:
# Warmup phase: use flash_attn directly
return flash_attn_varlen_func(...) if context.is_prefill else flash_attn_with_kvcache(...)
```
这是唯一允许 kvcache_manager 为 None 的情况。正式推理时policy 必须存在。
---
## 基类要求 (MANDATORY)
每个 `SparsePolicy` 子类 **必须** 遵守以下要求:

View File

@@ -0,0 +1,90 @@
# test_ruler.py 使用规则
## 强制规则
**执行 `test_ruler.py` 前必须查阅文档**,禁止运行 `--help` 或猜测参数。
| 禁止 | 原因 |
|------|------|
| `python tests/test_ruler.py --help` | 浪费交互,文档已有完整说明 |
| 猜测参数格式 | 容易出错,降低效率 |
## 必读文档
**[`docs/test_ruler_usage_guide.md`](../docs/test_ruler_usage_guide.md)** - 包含:
- 完整参数说明
- 已验证的命令示例
- GPU 模式选择指南
- max-model-len 设置指南
## 快速参考
### 标准命令格式
```bash
CUDA_VISIBLE_DEVICES=<GPU> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/<MODEL> \
--data-dir tests/data/ruler_<CTX> \
--datasets <TASK> \
--num-samples <N> \
--max-model-len <LEN> \
--enable-offload \
[--sparse-policy XATTN_BSA] \
[--sparse-threshold 0.9]
```
### 常用参数速查
| 参数 | 用途 | 示例 |
|------|------|------|
| `--datasets` | 指定任务 | `niah_single_1,qa_1` |
| `--num-samples` | 样本数 | `1`, `10`, `0`(全部) |
| `--sample-indices` | 指定索引 | `0,5,10` |
| `--enable-offload` | CPU offload | RTX 3090 必须 |
| `--sparse-policy` | 稀疏策略 | `XATTN_BSA` |
| `--json-output` | JSON 输出 | 脚本使用 |
| `--quiet` | 安静模式 | 减少输出 |
### max-model-len 速查
| 数据目录 | max-model-len |
|---------|---------------|
| ruler_32k | 40960 |
| ruler_64k | 72000 |
| ruler_128k | 135000 |
### 常用命令模板
**32K Offload + XAttn**:
```bash
CUDA_VISIBLE_DEVICES=<GPU> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload \
--sparse-policy XATTN_BSA
```
**64K Offload + XAttn**:
```bash
CUDA_VISIBLE_DEVICES=<GPU> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_64k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 72000 \
--enable-offload \
--sparse-policy XATTN_BSA
```
## 执行前检查清单
- [ ] 用户指定了 GPU否则询问
- [ ] RTX 3090/4090必须 `--enable-offload`
- [ ] data-dir 与 max-model-len 匹配?
- [ ] 需要 density 统计?添加 `--sparse-policy XATTN_BSA`

View File

@@ -1,98 +1,108 @@
# Testing
## Test File Guidelines
## Test Code Style
### Naming Convention
所有测试代码遵循以下风格:
- All test files must be named `test_*.py`
- Example: `test_offload_engine.py`, `test_ring_buffer.py`
### Purpose
Tests are **educational scripts** for understanding module behavior, NOT traditional unit tests:
- Focus on demonstrating how modules work
- Show the flow and interaction between components
- Help developers understand implementation details
### Code Style
1. **Script-based structure**: Write tests as executable scripts, not pytest-style functions
2. **Utility functions**: Extract reusable steps as helper functions at the top of the file
3. **Main flow as script**: The actual test/demonstration logic runs as top-level script code
### 文件结构
```python
# Example structure:
"""
Test: [模块名称]
[简要说明测试内容和数据流]
"""
import torch
from nanovllm.kvcache import SomeModule
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
from nanovllm.xxx import xxx
# ============================================================
# Utility Functions
# 参数配置
# ============================================================
def verify(tensor, expected, name):
actual = tensor.mean().item()
assert abs(actual - expected) < 0.01, f"{name}: {actual} != {expected}"
param1 = value1 # 说明约束条件
param2 = value2
# ============================================================
# Main Test Script
# 构造输入
# ============================================================
# 1. Initialize
module = SomeModule(param=value)
input_tensor = ... # 使用结构化数据便于验证
# 2. Test feature X
result = module.do_something()
assert result == expected_value
# ============================================================
# Step N: [操作名称]
# ============================================================
# 3. Test feature Y
...
output = some_function(input_tensor, ...)
# 验证: [验证逻辑说明]
expected = ...
actual = output[...].item()
assert actual == expected, f"xxx: {actual} != {expected}"
print("test_xxx: PASSED")
```
### Comments
### 核心原则
- Keep comments concise and clear
- Only add comments where the code isn't self-explanatory
- Use section headers (`# === Section ===`) to organize logical blocks
| 原则 | 说明 |
|------|------|
| **最小化 print** | 只在最后输出 `PASSED`,不打印中间结果 |
| **结构化数据** | 使用可预测的输入(全 1、偶奇交替等便于手算验证 |
| **注释说明验证逻辑** | 在 assert 前用注释解释预期值的计算方式 |
| **分段用 `====`** | 用 `# ============` 分隔参数、输入、各步骤 |
| **assert 验证** | 用 assert 而不是 print 比较结果 |
### Output
### 输出规范
- **Minimize print statements** - the code should be self-explanatory
- Only print a final "PASSED" message at the end
- Use `assert` for verification instead of printing results
- If the user needs explanation, they will ask
```python
# ✅ 正确
assert actual == expected, f"xxx: {actual} != {expected}"
print("test_xxx: PASSED")
# ❌ 错误
print(f"输出: {output}")
print(f"预期: {expected}, 实际: {actual}")
```
### 参数注释
```python
# ✅ 正确: 注释说明约束条件
seq_len = 512 # Triton 要求 seq_len >= stride * BLOCK_M
segment_size = 128 # 必须 >= block_size
# ❌ 错误: 无意义的注释
seq_len = 512 # 序列长度
```
### 验证逻辑注释
```python
# ✅ 正确: 解释计算过程
# 验证: 反对角线求和
# Q[奇]*K[偶] + Q[偶]*K[奇] = 2*1 + 1*2 = 4共 stride/2 对
expected = (2*1 + 1*2) * (stride // 2) * head_dim
# ❌ 错误: 只写公式不解释
expected = 4 * 2 * 128
```
## Running Tests
```bash
# Run a specific test
python tests/test_offload_engine.py
# 运行单个测试
PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_xxx.py
# Run with specific GPU
CUDA_VISIBLE_DEVICES=0 python tests/test_ring_buffer.py
# 指定 GPU
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_xxx.py
```
## Benchmarks
```bash
# Standard GPU benchmark
python bench.py
# CPU offload benchmark
python bench_offload.py
# vLLM comparison benchmark
python bench_vllm.py
```
## Quick Verification
```bash
# Import test
python -c "from nanovllm import LLM"
# Run offload benchmark (tests CPU-primary ring buffer mode)
python bench_offload.py
python bench.py # GPU benchmark
python bench_offload.py # CPU offload benchmark
python bench_vllm.py # vLLM comparison
```

8
.gitignore vendored
View File

@@ -232,10 +232,12 @@ tests/data/
.serena/
# Planning-with-files temporary files
# task_plan.md
# findings.md
# progress.md
task_plan.md
findings.md
progress.md
task_plan_*.md
findings_*.md
progress_*.md
notes.md
Snipaste*
.ralph-tui/session-meta.json

12
.ralph-tui/config.toml Normal file
View File

@@ -0,0 +1,12 @@
# Ralph TUI Configuration
# Generated by setup wizard
# See: ralph-tui config help
configVersion = "2.1"
tracker = "json"
agent = "claude"
maxIterations = 30
autoCommit = false
[trackerOptions]
[agentOptions]

View File

@@ -4,7 +4,7 @@ This file provides guidance to Claude Code when working with this repository.
## Overview
Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline LLM inference. Supports Qwen3 models with CPU offload for long-context inference.
Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline LLM inference. Supports Qwen3, Llama-3, and GLM-4 models with CPU offload for long-context inference.
## Documentation Index
@@ -15,7 +15,11 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
| [`docs/sparse_policy_implementation_guide.md`](docs/sparse_policy_implementation_guide.md) | How to implement custom SparsePolicy: required methods, hooks, ring buffer pipeline pattern |
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
| [`docs/xattention_algorithm_guide.md`](docs/xattention_algorithm_guide.md) | XAttention 算法详解: stride reshape、Triton kernels、BSA 依赖、块选择算法 |
| [`docs/xattn_kernels_guide.md`](docs/xattn_kernels_guide.md) | XAttention Triton kernels: flat_group_gemm (反对角线求和)、softmax_fuse_block_sum (block 聚合) |
| [`docs/xattn_kv_chunking_kernels.md`](docs/xattn_kv_chunking_kernels.md) | XAttention KV Chunking: 三阶段 softmax、存储开销分析 (O(S) vs O(S²))、峰值显存优化 (8x)、Q/KV 独立分块 |
| [`docs/xattn_chunked_prefill.md`](docs/xattn_chunked_prefill.md) | XAttention chunked prefill: API、使用方式、一致性要求 |
| [`docs/xattn_bsa_policy_design.md`](docs/xattn_bsa_policy_design.md) | XAttention BSA Policy: 算法设计、性能基准(128K)、内存管理、density 统计 |
| [`docs/xattn_density_benchmark.md`](docs/xattn_density_benchmark.md) | 📊 XAttention Density Benchmark: 4K-32K context、stride 参数、per-layer density 分析 |
| [`docs/block_sparse_attn_interface.md`](docs/block_sparse_attn_interface.md) | BSA (Block Sparse Attention) 接口文档: 函数签名、使用示例、约束条件 |
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
| [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
@@ -23,6 +27,27 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
| [`docs/ruler_benchmark_results_32k.md`](docs/ruler_benchmark_results_32k.md) | RULER benchmark results (32K context): 13 tasks, 92.3% accuracy, CPU offload performance |
| [`docs/ruler_32k_chunked_offload_issue.md`](docs/ruler_32k_chunked_offload_issue.md) | ⚠️ OPEN ISSUE: 32K chunked offload accuracy problem (20% error rate in RULER) |
| [`docs/chunked_attention_solutions.md`](docs/chunked_attention_solutions.md) | 🔧 SOLUTIONS: Chunked attention 准确性问题的代码分析和解决方案 |
| [`docs/nsys_wrong_event_order_bug.md`](docs/nsys_wrong_event_order_bug.md) | 🐛 NSYS BUG: Ring buffer pipeline 触发 nsys 时间戳乱序问题的调试记录 |
| [`docs/cpu_scheduling_latency_analysis.md`](docs/cpu_scheduling_latency_analysis.md) | ⚡ PERF: CPU 调度延迟分析kernel 间隙来源GPU 利用率优化方向 |
| [`docs/bench_offload_results.md`](docs/bench_offload_results.md) | 📊 BENCH: CPU offload 性能测试结果Full vs XAttention 对比 (32K/128K) |
| [`docs/cpu_offload_optimization_strategies.md`](docs/cpu_offload_optimization_strategies.md) | 🚀 OPT: CPU offload 优化策略chunk size、CUDA Graph、前沿研究(InfiniGen/ShadowKV) |
| [`docs/gpu_only_xattn_guide.md`](docs/gpu_only_xattn_guide.md) | 🚀 GPU-Only XAttention: 内存预分配、性能分析 (32K +15%, 64K +41%)、CUDA Graph 限制 |
| [`docs/xattn_performance_analysis.md`](docs/xattn_performance_analysis.md) | 📊 XAttention 性能分析: NVTX 标记、block size 影响、estimate vs compute 耗时对比 |
| [`docs/observer_architecture.md`](docs/observer_architecture.md) | 📊 Observer 架构: InferenceObserver (TTFT/TPOT)、MemoryObserver (H2D/D2H/D2D) 设计 |
| [`docs/memory_communication_benchmark.md`](docs/memory_communication_benchmark.md) | 📊 通信量测试: Full vs XAttention 通信量对比 (32K/64K)、阶段分离统计 |
| [`docs/estimate_block_size_performance.md`](docs/estimate_block_size_performance.md) | 🔥 PERF: estimate 阶段 block_size 性能分析softmax_fuse_block_sum 最优点 (512-1024),当前 4096 慢 15x |
| [`docs/long_context_models_1m.md`](docs/long_context_models_1m.md) | 📚 REF: 1M+ 上下文长度模型列表 (Qwen/GLM/InternLM/Llama/VL)≤10B 推荐模型 |
| [`docs/new_model_integration_guide.md`](docs/new_model_integration_guide.md) | 🔧 GUIDE: 新模型整合指南 - 配置映射、RoPE变体、EOS处理、权重转换、验证清单 |
| [`docs/xattn_density_alignment_analysis.md`](docs/xattn_density_alignment_analysis.md) | 📊 ANALYSIS: GPU-only vs Offload 模式 density 对齐分析chunked softmax 边界效应5-7% 差异根因 |
| [`docs/xattn_kv_chunking_density_test.md`](docs/xattn_kv_chunking_density_test.md) | 🧪 TEST: XAttention KV chunking density 验证threshold=1.0 对齐threshold<1.0 差异 10-13% |
| [`docs/gpuonly_density_alignment_test.md`](docs/gpuonly_density_alignment_test.md) | ✅ TEST: Density 对齐验证 (GPU-only + Offload, 4K-64K)xattn_estimate vs KV chunking 完全一致 |
| [`docs/xattn_memory_benchmark.md`](docs/xattn_memory_benchmark.md) | 📊 BENCH: XAttention 内存基准测试Qwen3-0.6B 32K 在 24GB 显存可行 (gpu-util=0.28) |
| [`docs/xattn_offload_stream_sync_fix.md`](docs/xattn_offload_stream_sync_fix.md) | 🐛 FIX: XAttention Offload stream 同步 bugPass1/Pass2 K 数据不一致compute_stream 包装 |
| [`docs/xattn_density_types.md`](docs/xattn_density_types.md) | 📊 Compute vs Comm density: BSA block (128) vs CPU block (4096) 粒度,聚合效应导致 comm=100% |
| [`docs/xattn_density_alignment_verification.md`](docs/xattn_density_alignment_verification.md) | ✅ VERIFIED: GPU-only vs Offload density 对齐验证 (32K 差异 0.37%, 64K 差异 0.09%) |
| [`docs/test_ruler_usage_guide.md`](docs/test_ruler_usage_guide.md) | 📖 GUIDE: test_ruler.py 使用指南RULER benchmark 测试命令,已验证的命令示例 |
| [`docs/xattn_offload_profiling_32k.md`](docs/xattn_offload_profiling_32k.md) | 📊 PROFILE: XAttn vs Full 32K nsys 分析estimate 占 41%find_blocks 占 37%compute 仅 21% |
| [`docs/changelog_2026-02-05.md`](docs/changelog_2026-02-05.md) | 📋 CHANGELOG: GQA buffer OOM 修复 (节省 16GB)tests 目录清理 (-4306 行) |
## Rules Index
@@ -32,6 +57,8 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
| [`.claude/rules/gpu-testing.md`](.claude/rules/gpu-testing.md) | GPU type detection, card assignment, needle test requirements |
| [`.claude/rules/sparse-policy.md`](.claude/rules/sparse-policy.md) | SparsePolicy implementation requirements |
| [`.claude/rules/planning-with-files.md`](.claude/rules/planning-with-files.md) | Planning file management for complex tasks |
| [`.claude/rules/gpu-monitor.md`](.claude/rules/gpu-monitor.md) | **GPU memory monitoring**: 必须使用 gpu-monitor agent禁止手动 nvidia-smi 循环 |
| [`.claude/rules/test-ruler.md`](.claude/rules/test-ruler.md) | **test_ruler.py 规则**: 禁止 --help必须查阅文档含快速参考和命令模板 |
## GPU Mutex for Multi-Instance Debugging
@@ -86,6 +113,15 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py
**Files**: `bench.py` (GPU), `bench_offload.py` (CPU offload), `bench_vllm.py` (comparison)
**GPU-only 测试模型选择**:
| GPU | 显存 | GPU-only 测试模型 |
|-----|------|------------------|
| RTX 3090 | 24GB | **Qwen3-0.6B** (必须7B+ 模型会 OOM) |
| A100 | 40GB+ | Qwen3-0.6B / 4B / 7B 均可 |
**Offload Mode Constraint**: When using `enable_cpu_offload=True`, only test with context length ≥ 32K. Shorter contexts don't exercise the chunked offload pipeline properly.
**Common Issues**:
1. `max_num_batched_tokens < max_model_len`: Set equal for long context
2. CUDA graph dimension mismatch: Ensure `input_len + output_len <= max_model_len`

View File

@@ -2,6 +2,7 @@ import os
import time
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.observer import InferenceObserver
def bench_decode(llm, num_seqs, input_len, output_len):
@@ -14,13 +15,17 @@ def bench_decode(llm, num_seqs, input_len, output_len):
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
t = time.time() - t
# Calculate metrics
prefill_tokens = num_seqs * input_len
# Get metrics from InferenceObserver
ttft_ms = InferenceObserver.ttft / 1e6
tpot_ms = InferenceObserver.tpot / 1e6
# Calculate throughput from observer metrics
decode_tokens = num_seqs * output_len
decode_throughput = decode_tokens / t
decode_throughput = 1000.0 / tpot_ms if tpot_ms > 0 else 0 # tokens/s per sequence
print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s")
print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)")
print(f" TTFT: {ttft_ms:.2f}ms, TPOT: {tpot_ms:.2f}ms")
print(f" Decode Throughput: {decode_throughput:.2f} tok/s (from observer)")
def bench_prefill(llm, num_seqs, input_len):
@@ -33,31 +38,69 @@ def bench_prefill(llm, num_seqs, input_len):
t = time.time()
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
t = time.time() - t
# Get TTFT from InferenceObserver
ttft_ms = InferenceObserver.ttft / 1e6
ttft_s = ttft_ms / 1000.0
total_input_tokens = num_seqs * input_len
throughput = total_input_tokens / t
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
# Use observer TTFT for accurate prefill throughput
throughput_observer = total_input_tokens / ttft_s if ttft_s > 0 else 0
throughput_external = total_input_tokens / t
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len})")
print(f" External Time: {t:.2f}s, Throughput: {throughput_external:.2f}tok/s")
print(f" Observer TTFT: {ttft_ms:.2f}ms, Throughput: {throughput_observer:.2f}tok/s")
def main():
import argparse
from nanovllm.config import SparsePolicyType
parser = argparse.ArgumentParser(description="Benchmark nanovllm GPU performance")
parser.add_argument("--model", type=str, default="~/models/Llama-3.1-8B-Instruct",
help="Model path (default: ~/models/Llama-3.1-8B-Instruct)")
parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
# Sparse policy option (GPU-only mode now supports policy routing)
parser.add_argument("--policy", type=str, default=None,
choices=["full", "xattn"],
help="Sparse policy: full (FullAttention), xattn (XAttention+BSA)")
parser.add_argument("--enable-policy", action="store_true",
help="Enable sparse policy routing (FullAttentionPolicy by default)")
parser.add_argument("--gpu-util", type=float, default=0.9,
help="GPU memory utilization (default: 0.9)")
parser.add_argument("--block-size", type=int, default=1024,
help="KV cache block size (default: 1024)")
parser.add_argument("--enforce-eager", action="store_true",
help="Disable CUDA graphs (default: False)")
args = parser.parse_args()
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
path = os.path.expanduser(args.model)
max_len = args.max_len
# Configure sparse policy
if args.policy == "xattn":
sparse_policy = SparsePolicyType.XATTN_BSA
print(f"\n[nanovllm GPU + XAttention BSA] max_len={max_len}")
elif args.policy == "full" or args.enable_policy:
sparse_policy = SparsePolicyType.FULL
print(f"\n[nanovllm GPU + Policy] sparse_policy=FULL, max_len={max_len}")
else:
sparse_policy = None
print(f"\n[nanovllm GPU] max_len={max_len}")
llm = LLM(
path,
enforce_eager=False,
enforce_eager=args.enforce_eager,
max_model_len=max_len,
max_num_batched_tokens=max_len,
sparse_policy=sparse_policy,
gpu_memory_utilization=args.gpu_util,
kvcache_block_size=args.block_size,
)
# Warmup

View File

@@ -2,6 +2,15 @@ import os
import time
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.observer import InferenceObserver
from nanovllm.utils.memory_observer import MemoryObserver
def print_memory_stats():
"""Print MemoryObserver communication statistics"""
fmt = MemoryObserver._fmt_bytes
print(f"[Memory] Prefill H2D: {fmt(MemoryObserver.prefill_h2d_bytes)}, D2H: {fmt(MemoryObserver.prefill_d2h_bytes)}")
print(f" Decode H2D: {fmt(MemoryObserver.decode_h2d_bytes)}, D2H: {fmt(MemoryObserver.decode_d2h_bytes)}")
def bench_decode(llm, num_seqs, input_len, output_len):
@@ -14,16 +23,18 @@ def bench_decode(llm, num_seqs, input_len, output_len):
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
t = time.time() - t
# Calculate metrics
prefill_tokens = num_seqs * input_len
decode_tokens = num_seqs * output_len
# Get metrics from InferenceObserver
ttft_ms = InferenceObserver.ttft / 1e6
tpot_ms = InferenceObserver.tpot / 1e6
# Approximate: assume prefill takes ~input_len/prefill_speed, rest is decode
# For more accurate measurement, we'd need internal timing
decode_throughput = decode_tokens / t # This includes prefill time, so it's a lower bound
# Calculate throughput from observer metrics
decode_tokens = num_seqs * output_len
decode_throughput = 1000.0 / tpot_ms if tpot_ms > 0 else 0 # tokens/s per sequence
print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s")
print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)")
print(f" TTFT: {ttft_ms:.2f}ms, TPOT: {tpot_ms:.2f}ms")
print(f" Decode Throughput: {decode_throughput:.2f} tok/s (from observer)")
print_memory_stats()
def bench_prefill(llm, num_seqs, input_len):
@@ -36,9 +47,20 @@ def bench_prefill(llm, num_seqs, input_len):
t = time.time()
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
t = time.time() - t
# Get TTFT from InferenceObserver
ttft_ms = InferenceObserver.ttft / 1e6
ttft_s = ttft_ms / 1000.0
total_input_tokens = num_seqs * input_len
throughput = total_input_tokens / t
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
# Use observer TTFT for accurate prefill throughput
throughput_observer = total_input_tokens / ttft_s if ttft_s > 0 else 0
throughput_external = total_input_tokens / t
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len})")
print(f" External Time: {t:.2f}s, Throughput: {throughput_external:.2f}tok/s")
print(f" Observer TTFT: {ttft_ms:.2f}ms, Throughput: {throughput_observer:.2f}tok/s")
print_memory_stats()
def main():
@@ -46,40 +68,67 @@ def main():
from nanovllm.config import SparsePolicyType
parser = argparse.ArgumentParser(description="Benchmark CPU offload performance")
parser.add_argument("--enable-quest", action="store_true", help="Enable Quest sparse attention for decode")
parser.add_argument("--model", type=str, default="~/models/Llama-3.1-8B-Instruct",
help="Model path (default: ~/models/Llama-3.1-8B-Instruct)")
# Sparse policy selection (mutually exclusive)
sparse_group = parser.add_mutually_exclusive_group()
sparse_group.add_argument("--enable-quest", action="store_true",
help="Enable Quest sparse attention (decode only, prefill uses full)")
sparse_group.add_argument("--enable-xattn", action="store_true",
help="Enable XAttention BSA (prefill only, decode uses full)")
# Quest parameters
parser.add_argument("--topk", type=int, default=16, help="Top-K blocks for Quest (default: 16)")
parser.add_argument("--threshold", type=int, default=4, help="Apply sparse only when blocks > threshold (default: 4)")
# XAttention parameters
parser.add_argument("--xattn-threshold", type=float, default=0.95,
help="XAttention cumulative attention threshold (default: 0.95)")
parser.add_argument("--xattn-stride", type=int, default=8,
help="XAttention Q/K downsampling stride (default: 8)")
# General parameters
parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
parser.add_argument("--num-gpu-blocks", type=int, default=6, help="Number of GPU blocks (default: 6)")
parser.add_argument("--num-gpu-blocks", type=int, default=4, help="Number of GPU blocks (default: 4)")
parser.add_argument("--block-size", type=int, default=1024, help="KV cache block size (default: 1024)")
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
parser.add_argument("--enforce-eager", action="store_true", help="Disable CUDA Graphs (use eager mode)")
args = parser.parse_args()
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
path = os.path.expanduser(args.model)
max_len = args.max_len
# Enable MemoryObserver for communication stats
MemoryObserver._enabled = True
# Setup policy configuration
if args.enable_quest:
sparse_policy = SparsePolicyType.QUEST
print(f"\n[Quest Sparse Attention] topk={args.topk}, threshold={args.threshold}")
print(f"\n[Quest Sparse Attention] decode: Quest (topk={args.topk}, threshold={args.threshold}), prefill: Full")
elif args.enable_xattn:
sparse_policy = SparsePolicyType.XATTN_BSA
print(f"\n[XAttention BSA] prefill: XAttn (tau={args.xattn_threshold}, stride={args.xattn_stride}), decode: Full")
else:
sparse_policy = SparsePolicyType.FULL
print("\n[Full Attention] baseline (no sparse)")
print(f"[Config] max_len={max_len}, num_gpu_blocks={args.num_gpu_blocks}")
print(f"[Config] max_len={max_len}, num_gpu_blocks={args.num_gpu_blocks}, block_size={args.block_size}")
llm = LLM(
path,
enforce_eager=False,
enforce_eager=args.enforce_eager,
max_model_len=max_len,
max_num_batched_tokens=max_len,
enable_cpu_offload=True,
num_gpu_blocks=args.num_gpu_blocks,
kvcache_block_size=args.block_size,
sparse_policy=sparse_policy,
# Quest parameters
sparse_topk_blocks=args.topk,
sparse_threshold_blocks=args.threshold,
# XAttention parameters
sparse_threshold=args.xattn_threshold,
sparse_stride=args.xattn_stride,
)
# Warmup

View File

@@ -1,5 +1,14 @@
import os
os.environ["VLLM_USE_V1"] = "1"
import sys
# Parse --use-v1 flag before importing vllm
use_v1 = "--use-v1" in sys.argv
if use_v1:
os.environ["VLLM_USE_V1"] = "1"
sys.argv.remove("--use-v1")
else:
os.environ["VLLM_USE_V1"] = "0"
import time
from random import randint, seed
from vllm import LLM, SamplingParams
@@ -44,24 +53,28 @@ def bench_prefill(llm, num_seqs, input_len):
def main():
import argparse
parser = argparse.ArgumentParser(description="Benchmark vLLM performance (for comparison)")
parser.add_argument("--model", type=str, default="~/models/Llama-3.1-8B-Instruct",
help="Model path (default: ~/models/Llama-3.1-8B-Instruct)")
parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
parser.add_argument("--gpu-util", type=float, default=0.9, help="GPU memory utilization (default: 0.9)")
parser.add_argument("--enforce-eager", action="store_true", help="Disable CUDA Graphs (use eager mode)")
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
args = parser.parse_args()
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
path = os.path.expanduser(args.model)
max_len = args.max_len
print(f"\n[vLLM] max_len={max_len}")
print(f"\n[vLLM] max_len={max_len}, gpu_util={args.gpu_util}, enforce_eager={args.enforce_eager}")
llm = LLM(
path,
enforce_eager=False,
enforce_eager=args.enforce_eager,
max_model_len=max_len,
max_num_seqs=128,
gpu_memory_utilization=0.9,
gpu_memory_utilization=args.gpu_util,
)
# Warmup

View File

@@ -0,0 +1,199 @@
# CPU Offload Benchmark Results
本文档记录 `bench_offload.py` 在不同配置下的性能测试结果。
## 测试环境
| 参数 | 值 |
|------|-----|
| GPU | NVIDIA A100-SXM4-80GB |
| 模型 | Llama-3.1-8B-Instruct |
| GPU slots | 4 |
## Sparse Policy 配置
| 策略 | Prefill | Decode | 说明 |
|------|---------|--------|------|
| FULL | Full Attention | Full Attention | 基线,加载所有 blocks |
| XATTN_BSA | XAttention (tau=0.95, stride=8) | Full Attention (fallback) | 稀疏 prefill |
## 测试结果
### Block Size 4096 (推荐)
#### GPU-only 模式
| 上下文 | Full Attention | XAttention | 相对性能 |
|--------|----------------|------------|----------|
| 32K | 4863 tok/s | 5587 tok/s | **+14.9%** ✅ |
| 64K | 3373 tok/s | 4766 tok/s | **+41.3%** ✅ |
#### CPU Offload 模式 (优化后, 2026-01-28)
| 上下文 | Full Attention | XAttention | 相对性能 |
|--------|----------------|------------|----------|
| 32K | 4678 tok/s | 4398 tok/s | **-6.0%** |
| 64K | 3331 tok/s | 3203 tok/s | **-3.8%** |
| 128K | 2144 tok/s | 2196 tok/s | **+2.4%** ✅ |
#### CPU Offload 模式 (优化前, 2026-01-27)
| 上下文 | Full Attention | XAttention | 相对性能 |
|--------|----------------|------------|----------|
| 32K | 4648 tok/s | 4002 tok/s | **-13.9%** ❌ |
| 64K | 3329 tok/s | 2642 tok/s | **-20.6%** ❌ |
| 128K | 2122 tok/s | 867 tok/s | **-59.1%** ❌ |
### Block Size 256 (小 block 测试)
#### CPU Offload 模式 (64K)
| 策略 | 耗时 | 吞吐量 | 相对性能 |
|------|------|--------|----------|
| Full Attention | 401.04s | 163.41 tok/s | baseline |
| XAttention BSA | 390.35s | 167.89 tok/s | **+2.7%** ✅ |
### Block Size 1024 (历史测试)
#### CPU Offload 模式
| 上下文 | Full Attention | XAttention | 相对性能 |
|--------|----------------|------------|----------|
| 32K | 1587.74 tok/s | 1172.33 tok/s | -26% |
| 128K | 552.63 tok/s | 466.17 tok/s | -16% |
## 关键发现
### 1. GPU-only vs CPU Offload 模式差异
| 模式 | XAttention 效果 | 原因 |
|------|-----------------|------|
| **GPU-only** | ✅ 显著加速 (+15% ~ +41%) | 计算是瓶颈,稀疏注意力减少 FLOPs |
| **CPU Offload (优化后)** | ✅ 长上下文略有收益 | estimate_block_size 优化减少估计开销 |
| **CPU Offload (优化前)** | ❌ 性能下降 (-14% ~ -59%) | 传输是瓶颈,稀疏估计增加额外开销 |
### 2. Block Size 对性能的影响
| Block Size | 64K Full (Offload) | 特点 |
|------------|-------------------|------|
| 4096 | 3329 tok/s | ⭐ 最佳性能 |
| 1024 | ~1500 tok/s | 中等 |
| 256 | 163 tok/s | 极慢20x 下降) |
**原因**: 更小的 block = 更多的 blocks = 更多 H2D 传输开销
### 3. XAttention 在小 Block Size 下反转
当 block size = 256 时XAttention 反而略有优势 (+2.7%)
- 256 个 blocks (vs 16 个 @ 4096)
- 稀疏跳过的 blocks 比例更明显
- 但绝对性能极差,不推荐使用
### 4. estimate_block_size 优化效果 (2026-01-28)
```
Offload 模式 XAttention 相对性能变化:
优化前 优化后 改进
32K: -13.9% -6.0% +7.9pp
64K: -20.6% -3.8% +16.8pp
128K: -59.1% +2.4% +61.5pp ✅
```
优化内容:
- `estimate_block_size` 从 4096 改为 1024
- `softmax_fuse_block_sum` kernel 时间从 48% 降到 1% (44x 加速)
- 选择策略从 mask + voting 改为 score + threshold
优化后结论:
- **128K 长上下文 XAttention 反超 Full Attention**
- 短上下文仍有少量开销,但已显著减少
## 结论
### 推荐配置 (优化后, 2026-01-28)
| 场景 | 推荐策略 | Block Size |
|------|----------|------------|
| GPU-only (VRAM 充足) | XAttention | 4096 |
| CPU Offload (128K+) | XAttention | 4096 |
| CPU Offload (32K-64K) | Full Attention 或 XAttention | 4096 |
### XAttention 适用条件 (优化后)
**适合**:
- GPU-only 模式(计算密集)
- CPU Offload + 长上下文128K+)有正向收益
- 长上下文64K+)收益更大
⚠️ **中性**:
- CPU Offload + 中等上下文32K-64K略慢 3-6%,可接受
**不推荐**:
- 短上下文(<32K收益不明显
## 运行命令
```bash
# GPU-only 模式
CUDA_VISIBLE_DEVICES=0 python bench.py --max-len 65536 --block-size 4096 --gpu-util 0.7
CUDA_VISIBLE_DEVICES=0 python bench.py --max-len 65536 --block-size 4096 --gpu-util 0.7 --policy xattn
# CPU Offload 模式 (推荐 block-size 4096)
CUDA_VISIBLE_DEVICES=0 python bench_offload.py --max-len 65536 --block-size 4096
CUDA_VISIBLE_DEVICES=0 python bench_offload.py --max-len 65536 --block-size 4096 --enable-xattn
# CPU Offload 模式 (小 block size 测试)
CUDA_VISIBLE_DEVICES=0 python bench_offload.py --max-len 65536 --block-size 256
CUDA_VISIBLE_DEVICES=0 python bench_offload.py --max-len 65536 --block-size 256 --enable-xattn
# 调整 XAttention 参数
CUDA_VISIBLE_DEVICES=0 python bench_offload.py --enable-xattn --xattn-threshold 0.8 --xattn-stride 16
```
## FlashInfer Merge 优化 (2026-01-28)
将 Triton 实现的 `merge_attention_outputs` 替换为 FlashInfer 的 `cascade.merge_state`
### 性能对比 (Full Attention, block-size 4096)
| 上下文 | Triton merge | FlashInfer merge | 提升 |
|--------|--------------|------------------|------|
| 32K | 4678 tok/s | 4717 tok/s | **+0.8%** |
| 64K | 3331 tok/s | 3411 tok/s | **+2.4%** |
| 128K | 2144 tok/s | 2178 tok/s | **+1.6%** |
### 关键发现
1. **端到端提升有限**0.8% ~ 2.4%merge 操作不是主要瓶颈
- H2D 传输占主导64K 传输 64GB
- Attention 计算是另一主要耗时
- Merge 在总耗时中占比很小
2. **Merge kernel 单独对比**(长序列时 FlashInfer 优势明显):
| seq_len | heads | Triton (ms) | FlashInfer (ms) | Speedup |
|---------|-------|-------------|-----------------|---------|
| 4096 | 32 | 0.129 | 0.087 | **1.49x** |
| 8192 | 32 | 0.251 | 0.147 | **1.70x** |
| 16384 | 32 | 0.499 | 0.274 | **1.82x** |
3. **短序列 FlashInfer 反而慢**格式转换开销squeeze, transpose, contiguous
### 技术细节
- **LSE 格式差异**FlashInfer 使用 log2flash_attn 使用 ln
- **转换系数**`LOG2_E = 1.4427`ln → log2`LN_2 = 0.6931`log2 → ln
- **FlashInfer attention JIT 问题**CUDA 版本兼容性问题,仅使用 merge_state
### 代码位置
- `nanovllm/ops/chunked_attention.py`: `merge_attention_outputs_flashinfer()`
- `nanovllm/kvcache/sparse/full_policy.py`: 3 处 import 更新
- `nanovllm/kvcache/sparse/xattn_bsa.py`: 1 处 import 更新
## 更新记录
- 2026-01-28: **FlashInfer merge 替换 Triton merge**,端到端提升 0.8% ~ 2.4%
- 2026-01-28: **estimate_block_size 优化后重新测试**128K XAttention 反超 Full (+2.4%)
- 2026-01-27: 添加 GPU-only vs Offload 对比block size 影响分析
- 2026-01-27: 初始测试Llama-3.1-8B-Instruct, A100 80GB

View File

@@ -0,0 +1,94 @@
# Changelog 2026-02-05
## Bug Fixes
### XAttention Offload GQA Buffer OOM Fix
**Issue**: `docs/issue_xattn_offload_gqa_buffer_oom.md`
**Problem**: 在 XAttention BSA + CPU Offload 模式下,`alloc_policy_metadata()` 分配了只有 GPU-only 模式才需要的 GQA expansion buffers (`_k_expanded`, `_v_expanded`),导致 24GB GPU (RTX 3090) 上 OOM。
**Root Cause**:
- GQA buffer 大小: `2 × num_heads × max_seq_len × head_dim × dtype_size`
- 对于 1M max_seq_len: 2 × 32 × 1048576 × 128 × 2 = **16 GB**
- Offload 模式的 `compute_chunked_prefill()` 不需要这些 buffer
**Fix** (commit `11a867f`):
1. `nanovllm/kvcache/sparse/policy.py`: 基类添加 `enable_cpu_offload` 参数
2. `nanovllm/kvcache/sparse/xattn_bsa.py`: offload 模式跳过 GQA buffer 分配
3. `nanovllm/engine/model_runner.py`: 传入 `enable_cpu_offload` 参数
**Memory Savings**:
| max_model_len | 修复前 | 修复后 |
|---------------|--------|--------|
| 72K | +1.1 GB | 0 GB |
| 1M | +16 GB | 0 GB |
**Verification**:
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_64k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 72000 \
--enable-offload \
--sparse-policy XATTN_BSA
```
- 日志显示: `[XAttn] Offload mode: skipping GQA expansion buffers`
- 测试结果: 100% 准确率
---
## Code Cleanup
### Tests Directory Cleanup
**Commits**: `a709551`, `2b61c5a`, `d35dd76`
删除了 16 个冗余/过时的测试文件,保留核心测试:
**保留的文件** (4 个):
| 文件 | 用途 |
|------|------|
| `test_ruler.py` | 核心 RULER benchmark (13 tasks, 100 samples) |
| `test_xattn_estimate_alignment.py` | XAttn kernel 一致性验证 |
| `utils.py` | 共享工具函数 |
| `__init__.py` | 包标记 |
**删除的文件** (16 个, -4306 行):
| 类别 | 文件 | 删除原因 |
|------|------|----------|
| XAttn 测试 | `test_xattn_bsa.py` | 功能被 test_ruler 覆盖 |
| | `test_xattn_chunked.py` | 与 estimate_chunked 重复 |
| | `test_xattn_estimate_chunked.py` | chunked prefill 验证 |
| | `test_xattn_kernels.py` | Triton kernel 单元测试 |
| | `test_xattn_kv_chunking_batch.py` | batch 验证 |
| Needle 测试 | `test_needle.py` | 被 test_ruler NIAH 任务覆盖 |
| | `test_needle_ref.py` | HF 参考实现 |
| CUDA Graph | `test_chunk_attention_graph.py` | 被 graph_reuse 取代 |
| | `test_chunk_attention_graph_reuse.py` | 实验性功能 |
| | `test_cudagraph_memory.py` | 内存分析工具 |
| 其他 | `test_gpuonly_density_alignment.py` | GPU-only 密度测试 |
| | `test_hierarchical_estimate.py` | 分层估计测试 |
| | `test_quest_policy.py` | Quest 策略测试 |
| | `test_sequential.py` | 状态隔离测试 |
| | `bench_estimate_block_size.py` | 性能 benchmark |
| | `modeling_qwen3.py` | Qwen3 参考模型 |
**Note**: 所有删除的文件可从 git 历史恢复:
```bash
git checkout <commit-hash>^ -- tests/<filename>
```
---
## Summary
| 类型 | 数量 | 影响 |
|------|------|------|
| Bug Fix | 1 | 节省 16GB 显存 (1M seq) |
| 文件删除 | 16 | -4306 行代码 |
| 新增文档 | 1 | 本文件 |

View File

@@ -0,0 +1,300 @@
# CPU Offload 优化策略
本文档记录 CPU Offload 场景下的性能优化策略分析,包括实际可行的方案和前沿研究方向。
## 问题回顾
根据 [CPU 调度延迟分析](cpu_scheduling_latency_analysis.md),当前 chunked attention pipeline 的主要问题:
| 指标 | 当前值 | 理论值 |
|------|--------|--------|
| Flash kernel 执行时间 | ~138 μs | - |
| Flash kernel 间隔 | ~942 μs | ~211 μs (仅 H2D + merge) |
| GPU 利用率 | **12.8%** | **39.5%** (理论上限) |
| CPU 调度空闲占比 | **77-81%** | 0% |
**瓶颈根源**:每个 block 都经过完整的 Python 循环,导致大量 CPU 调度延迟。
---
## 优化方案一:调大 Chunk Size推荐
### 核心洞察
**Merge 多个小 chunk 和直接使用大 chunk 是等效的**
```
方案 A: Merge 4 个小 chunks
[H2D 2K][H2D 2K][H2D 2K][H2D 2K] → concat → [Flash 8K] → merge
方案 B: 直接用大 chunk
[H2D 8K] → [Flash 8K] → merge
计算结果完全等效!
```
### 收益分析
| 指标 | 小 chunk (2K) × 4 | 大 chunk (8K) × 1 |
|------|-------------------|-------------------|
| H2D 次数 | 4 | 1 |
| Flash kernel 调用 | 4 | 1 |
| Merge 调用 | 4 | 1 |
| Python 循环次数 | 4 | 1 |
| CPU 调度开销 | 4 × ~300μs = 1200μs | 1 × ~300μs = 300μs |
**本质**CPU 调度延迟问题的根源是循环次数太多,调大 chunk size 直接减少循环次数。
### Trade-off
1. **GPU 内存增加**
- 2K chunk: 每 slot ~4MB (K+V)
- 8K chunk: 每 slot ~16MB (K+V)
- 4 slots = 64MB对 80GB A100 影响很小
2. **单次 H2D 时间变长**
- H2D 8K ≈ 350μs
- Flash 8K ≈ 550μs
- 因为 Flash > H2Dpipeline 仍然有效
### 配置方法
```bash
# 测试不同 block size
python bench_offload.py --kvcache-block-size 2048 # 基准
python bench_offload.py --kvcache-block-size 4096 # 2x
python bench_offload.py --kvcache-block-size 8192 # 4x
```
---
## 优化方案二CUDA Graph适用于非 Attention 部分)
### CUDA Graph 在 Offload 场景的局限性
CUDA Graph 的前提:所有操作在 capture 时确定,数据地址固定。
**Offload 场景的现实**
1. **H2D 源地址动态** - 每次从不同的 CPU block 加载
2. **加载决策在运行时** - 哪些 block 需要加载是动态的
3. **CPU 必须协调** - H2D 和 Compute 的同步需要 CPU 参与
```
Offload 场景:
┌─────────────────────────────────────────┐
│ 数据在 CPU需要动态加载 │
│ [H2D_i] → [Compute] → [H2D_{i+n}] → ...│
│ ↑ 动态、CPU 必须参与调度 │
└─────────────────────────────────────────┘
即使用 Graph
Python: [wait_h2d] [replay] [launch_h2d] [wait_h2d] [replay] ...
↑ CPU 参与 ↑ CPU 参与 ↑ CPU 参与
CPU 调度开销仍然存在Graph 只优化了中间的 compute 部分。
```
**结论**CUDA Graph 不是 Offload 场景的银弹。
### 适用场景MLP 和 Projection 层
LLM 每层的计算流程:
```
┌─────────────────────────────────────────────────────────────┐
│ [LayerNorm] → [QKV Proj] → [Attention] → [O Proj] → [Add] │
│ ↑ │
│ KV Offload │
│ [LayerNorm] → [MLP: gate + up + down] → [Add] │
└─────────────────────────────────────────────────────────────┘
```
| 组件 | 涉及 Offload | 能用 CUDA Graph |
|------|-------------|-----------------|
| LayerNorm | ❌ | ✅ |
| QKV Projection | ❌ | ✅ |
| **Attention** | ✅ | ❌ |
| Output Projection | ❌ | ✅ |
| MLP (FFN) | ❌ | ✅ |
**只有 Attention 涉及动态 KV Cache 加载,其余都是"纯计算",可以用 CUDA Graph。**
### 实现方案
```python
class OptimizedLayer:
def __init__(self, layer):
# Graph 1: Attention 之前
self.graph_pre_attn = capture([
layer.input_layernorm,
layer.self_attn.q_proj,
layer.self_attn.k_proj,
layer.self_attn.v_proj,
])
# Graph 2: Attention 之后 + MLP
self.graph_post_attn = capture([
layer.self_attn.o_proj,
# residual add
layer.post_attention_layernorm,
layer.mlp.gate_proj,
layer.mlp.up_proj,
layer.mlp.down_proj,
# residual add
])
def forward(self, hidden_states, kv_cache):
# Pre-attention (CUDA Graph)
self.graph_pre_attn.replay()
# Attention with offload (动态,不能用 graph)
attn_output = chunked_attention_with_offload(q, kv_cache)
# Post-attention + MLP (CUDA Graph)
self.graph_post_attn.replay()
```
### 收益估算
MLP 每层典型操作 launch 开销:
- `gate_proj`, `up_proj`, `act_fn`, `gate * up`, `down_proj`, `residual add`
- 每个操作 ~30-50μs launch 开销,总计 ~200μs/层
- 用 CUDA Graph~30μs/层
**32 层 × 170μs 节省 ≈ 5.4ms**
---
## 优化方案三:前沿研究方向
### 1. InfiniGen - 投机预取 (OSDI'24)
**核心思想**:不需要加载所有 KV只预取"重要"的 token。
```
关键洞察:相邻层的 attention pattern 高度相似
用第 L 层的 attention score 预测第 L+1 层需要哪些 token
只预取 top-k 重要的 KV entries而不是全部
```
**技术实现**
- 用当前层的 Q 和下一层的部分 K 做"预演"
- 预测下一层的 attention 分布
- 异步预取预测的重要 token
- **减少 PCIe 带宽浪费,而不是加速传输**
**效果**:最高 **3x 加速**
**参考**[InfiniGen (OSDI'24)](https://www.usenix.org/conference/osdi24/presentation/lee)
### 2. ShadowKV - 低秩压缩 + Sparse Offload (ICML'25 Spotlight)
**核心思想**Key 压缩存 GPUValue offload 到 CPU只加载 1.56% 的 KV。
```
Pre-filling:
┌─────────────────────────────────────────────────┐
│ Key Cache → SVD 低秩压缩 → 保留在 GPU │
│ Value Cache → Offload 到 CPU │
│ 计算每个 chunk 的 landmark (均值) │
│ 识别 outlier tokens → 保留在 GPU │
└─────────────────────────────────────────────────┘
Decoding:
┌─────────────────────────────────────────────────┐
│ 用 landmarks 快速估计 attention score │
│ 只加载 top-k 重要的 Value (1.56% sparse) │
│ 结合 GPU 上的 outliers 计算最终结果 │
└─────────────────────────────────────────────────┘
```
**效果**6x 更大 batch size**3.04x 吞吐提升**
**参考**[ShadowKV (ByteDance)](https://github.com/ByteDance-Seed/ShadowKV)
### 3. L2 Cache 异步预取 (2025)
**核心思想**:利用 GPU L2 Cache 做预取,在计算时预取下一批 KV。
```
传统:
Compute: [Flash_i] [Flash_{i+1}]
H2D: [H2D_{i+1}]
↑ 等待
L2 Prefetch
Compute: [Flash_i + Prefetch_{i+1} to L2] [Flash_{i+1} L2 hit]
↑ 计算时利用空闲 memory bandwidth 预取
```
**技术**
- 在 Flash Attention kernel 内部发起预取指令
- 利用计算时的空闲 memory bandwidth
- 下一次访问直接 L2 hit
**效果****2.15x attention kernel 效率**1.97x 端到端吞吐
**参考**[Asynchronous KV Cache Prefetching (2025)](https://arxiv.org/abs/2504.06319)
### 4. KVPR - I/O-Aware 调度 (ACL'25)
**核心思想**:计算最优的 recompute vs offload 比例。
```
权衡:
- Recompute: 重新计算 KV用 GPU 算力换内存)
- Offload: 从 CPU 加载(用 PCIe 带宽换算力)
KVPR: 根据当前负载动态决定最优比例
+ 预取技术重叠数据传输和计算
```
**参考**[KVPR (ACL'25)](https://aclanthology.org/2025.findings-acl.997.pdf)
---
## 优化策略总结
### 推荐优先级
| 优先级 | 方案 | 核心优化 | 实现复杂度 | 预期收益 |
|--------|------|---------|-----------|---------|
| **P0** | 调大 chunk size | 减少循环次数 | 极低(改配置) | 2-4x |
| **P1** | MLP CUDA Graph | 减少 launch 开销 | 中 | ~5ms/request |
| **P2** | InfiniGen 式预取 | 只加载重要 token | 中高 | 2-3x |
| **P3** | ShadowKV 式压缩 | Key 压缩 + Sparse | 高 | 3x |
| **P3** | C++ Extension | 消除 Python 开销 | 高 | 2-3x |
### 策略分离原则
```
┌─────────────────────────────────────────────────────────────┐
│ Attention + Offload 部分: │
│ - 瓶颈H2D 传输 + CPU 调度 │
│ - 优化:调大 chunk size / 投机预取 / Sparse │
│ │
│ MLP + Proj + Norm 部分: │
│ - 瓶颈Kernel launch 开销 │
│ - 优化CUDA Graph │
└─────────────────────────────────────────────────────────────┘
两部分优化完全正交,可以组合使用。
```
---
## 相关文件
- `nanovllm/kvcache/sparse/full_policy.py`: Chunked attention pipeline
- `nanovllm/kvcache/offload_engine.py`: H2D/D2H 传输管理
- `docs/cpu_scheduling_latency_analysis.md`: 问题分析
## 参考文献
1. [InfiniGen: Efficient Generative Inference of Large Language Models with Dynamic KV Cache Management](https://www.usenix.org/conference/osdi24/presentation/lee) - OSDI'24
2. [ShadowKV: KV Cache in Shadows for High-Throughput Long-Context LLM Inference](https://github.com/ByteDance-Seed/ShadowKV) - ICML'25 Spotlight
3. [Accelerating LLM Inference Throughput via Asynchronous KV Cache Prefetching](https://arxiv.org/abs/2504.06319) - 2025
4. [KVPR: Efficient LLM Inference with I/O-Aware KV Cache](https://aclanthology.org/2025.findings-acl.997.pdf) - ACL'25
5. [LMCache: An Efficient KV Cache Layer for Enterprise-Scale LLM Inference](https://lmcache.ai/tech_report.pdf) - 2025

View File

@@ -0,0 +1,177 @@
# CPU 调度延迟分析
## 问题概述
在分析 nsys profile 时发现chunked attention pipeline 中存在大量的 **CPU 调度延迟**,导致 GPU 利用率显著下降。
## 观察数据
### 测试环境
- GPU: NVIDIA A100-SXM4-80GB
- 模型: Llama-3.1-8B-Instruct
- 测试: RULER niah_single_1, 64K context
- Profile 文件: `ruler_8slots_test.nsys-rep`
- 时间段: 92.982s - 93.038s
### Kernel 执行时间
| Kernel | 典型执行时间 |
|--------|-------------|
| flash_fwd_kernel | ~138 μs |
| H2D memcpy (2MB) | ~87 μs |
| merge_lse_kernel | ~3.5 μs |
| merge_output_kernel | ~34 μs |
### 操作间隙分析
从 cuda_gpu_trace 观察到的间隙:
```
Start (ms) Dur (μs) Gap (μs) Type
------------------------------------------------------------
92984.680 138.3 378.3 flash_fwd_kernel ← GAP!
92985.051 86.8 232.9 H2D memcpy ← GAP!
92985.141 86.8 2.8 H2D memcpy
92985.587 135.9 360.0 flash_fwd_kernel ← GAP!
92986.026 3.4 302.4 merge_lse ← GAP!
92986.164 33.5 135.0 merge_output ← GAP!
92986.371 86.9 173.4 H2D memcpy ← GAP!
92986.461 86.8 2.7 H2D memcpy
92986.816 137.9 268.2 flash_fwd_kernel ← GAP!
```
### Flash Kernel 间隙分解
| 间隙 | 总时间 | 有效工作时间 | 空闲时间 |
|------|--------|-------------|---------|
| Flash 1 → Flash 2 | 769 μs | ~174 μs (2x H2D) | ~595 μs (77%) |
| Flash 2 → Flash 3 | 1092 μs | ~211 μs (merge + H2D) | ~881 μs (81%) |
| Flash 3 → Flash 4 | 965 μs | ~211 μs (merge + H2D) | ~754 μs (78%) |
**关键发现**: 每个 flash kernel 之间约 **77-81% 的时间是 CPU 调度空闲**
## 间隙来源分析
### 1. CPU 调度延迟类型
| 转换 | 典型延迟 | 原因 |
|------|---------|------|
| Kernel 结束 → 下一个 Kernel 开始 | 100-400 μs | CPU 准备参数、调用 CUDA driver |
| Flash 结束 → H2D 开始 | ~233 μs | Python 代码执行 + CUDA launch |
| H2D 结束 → Flash 开始 | ~360 μs | 同步等待 + kernel launch |
| Flash 结束 → merge 开始 | ~302 μs | Python 代码执行 |
### 2. 延迟产生的代码位置
```python
# full_policy.py: compute_chunked_prefill
for block_idx in range(num_blocks):
# 1. 等待 H2D 完成 (同步点)
offload_engine.wait_slot_layer(current_slot) # ← 可能引入延迟
# 2. 获取 KV 数据
k_block, v_block = offload_engine.get_kv_for_slot(current_slot)
# 3. 调用 flash attention (kernel launch)
block_out, block_lse = flash_attn_with_kvcache(...) # ← CPU 调度延迟
# 4. merge 操作
merge_output(...) # ← CPU 调度延迟
merge_lse(...) # ← CPU 调度延迟
# 5. 发起下一个 H2D (异步)
offload_engine.load_to_slot_layer(next_slot, ...) # ← CPU 调度延迟
```
### 3. 为什么 H2D 之间间隙小
注意到连续的 H2D memcpy 之间间隙只有 ~2.7 μs这是因为
- 它们在同一个 stream 上连续发起
- CUDA driver 可以批量处理
- 没有 Python 代码介入
## GPU 利用率计算
基于观察数据:
| 指标 | 值 |
|------|-----|
| Flash kernel 平均执行时间 | 138 μs |
| Flash kernel 平均间隔 | 942 μs |
| Flash kernel GPU 利用率 | 138 / (138 + 942) = **12.8%** |
如果消除 CPU 调度延迟(仅保留必要的 H2D + merge
| 指标 | 值 |
|------|-----|
| 必要间隔 (2x H2D + merge) | ~211 μs |
| 理论 GPU 利用率 | 138 / (138 + 211) = **39.5%** |
**潜在提升**: 3x GPU 利用率
## 优化方向
### 1. CUDA Graph
将整个 block 处理流程编译为 CUDA Graph消除重复的 kernel launch 开销。
```python
# 伪代码
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
# 预录制 flash + merge 操作
block_out, block_lse = flash_attn_with_kvcache(...)
merge_output(...)
merge_lse(...)
# 运行时只需 replay
for block_idx in range(num_blocks):
graph.replay() # 单次 launch无 Python 介入
```
### 2. 自定义 Triton Kernel
将 flash + merge 融合为单个 kernel减少 kernel launch 次数。
### 3. C++ Extension
将 Python 循环移到 C++ 层,减少 Python 解释器开销。
### 4. 流水线重叠优化
确保 H2D 传输与前一个 block 的计算完全重叠:
```
Block 0: [H2D slot0] [Flash slot0] [merge]
Block 1: [H2D slot1] [Flash slot1] [merge]
Block 2: [H2D slot2] [Flash slot2] [merge]
```
## 验证方法
### 1. 使用 nsys 分析间隙
```bash
# 生成 profile
bash scripts/profile_offload.sh --num-gpu-blocks 8
# 查看 kernel trace
nsys stats --report cuda_gpu_trace --format csv <file>.nsys-rep | \
awk -F',' 'NR>1 && $1 >= START && $1 <= END'
```
### 2. 计算间隙
```python
# 从 trace 数据计算
prev_end = start + duration
gap = next_start - prev_end
```
## 相关文件
- `nanovllm/kvcache/sparse/full_policy.py`: Pipeline 实现
- `nanovllm/kvcache/offload_engine.py`: H2D/D2H 传输
- `scripts/profile_offload.sh`: Profiling 脚本
## 参考
- [CUDA Graph 文档](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-graphs)
- [nsys 用户指南](https://docs.nvidia.com/nsight-systems/UserGuide/index.html)

View File

@@ -0,0 +1,258 @@
# Estimate Block Size 性能分析
本文档记录 XAttention estimate 阶段中 `block_size` 参数对 `softmax_fuse_block_sum` kernel 性能的影响。
## 问题背景
当前 `select_blocks` 中的 estimate 过程使用全局的 `kvcache_block_size`(通常为 4096
```python
# xattn_bsa.py: select_blocks()
block_size = ctx.block_size # 来自 kvcache_manager.block_size (4096)
reshaped_block_size = block_size // self.stride # 4096/8 = 512
block_sums = softmax_fuse_block_sum(
attn_scores,
reshaped_block_size, # 512 - 性能最差点!
...
)
```
这导致 `softmax_fuse_block_sum` kernel 使用 `reshaped_block_size=512`,而这正是性能曲线的最差点。
## Benchmark 结果
### 测试配置
- GPU: NVIDIA A100-SXM4-80GB
- NUM_HEADS: 32
- HEAD_DIM: 128
- STRIDE: 8
- 测试脚本: `tests/bench_estimate_block_size.py`
### softmax_fuse_block_sum 性能数据
| block_size | reshaped | 16K context | 32K context | 64K context |
|------------|----------|-------------|-------------|-------------|
| 64 | 8 | 4.86ms | 18.36ms | 70.83ms |
| 128 | 16 | 0.83ms | 3.12ms | 16.83ms |
| 256 | 32 | 0.63ms | 2.41ms | 11.24ms |
| 512 | 64 | **0.38ms** | **1.52ms** | 9.54ms |
| 1024 | 128 | 0.42ms | 1.54ms | **6.01ms** |
| 2048 | 256 | 1.08ms | 3.24ms | 12.81ms |
| **4096** | **512** | 9.66ms | 25.36ms | **95.32ms** |
### 性能曲线
```
softmax_fuse_block_sum 耗时 (64K context):
block_size=64 ████████████████████████████████████ 70.83ms
block_size=128 ████████ 16.83ms
block_size=256 █████ 11.24ms
block_size=512 ████ 9.54ms
block_size=1024 ███ 6.01ms ◀── 最优点
block_size=2048 ██████ 12.81ms
block_size=4096 ████████████████████████████████████████████████ 95.32ms ◀── 当前使用
```
### 关键发现
1. **性能呈 U 型曲线**:太小和太大的 block_size 都会导致性能下降
2. **最优点在 512-1024**:对应 `reshaped_block_size` 64-128
3. **当前配置 (4096) 是最差点**95.32ms vs 最优 6.01ms**慢 15.85x**
## 性能曲线解释
```
Performance (耗时)
│ ▲ 太小:
│ / - output blocks 数量多 (q_len / block_size)
│/ - grid 调度开销大
│ - 每个 thread block 工作量小
│ ┌─────────┐
│ / 最优 \
│ / 区域 \ ▲ 太大:
│/ \ - block_size 作为 tl.constexpr
│ \ - 寄存器压力增大 (可能 spill)
│ \ - shared memory 不足
│ \- L1 cache 效率下降
└──────────────────────────────────→ block_size
64 128 256 512 1024 2048 4096
最优点 (512-1024)
```
### Triton Kernel 内部分析
`softmax_fuse_block_sum_kernel` 中的关键约束:
```python
# 每个 thread block 处理的数据
offs_q = tl.arange(0, block_size) # block_size 个元素
m_i = tl.zeros([block_size], dtype=tl.float32) # 寄存器分配
# reshape 操作
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
# 当 block_size=512, segment_size=512 时 → (512, 1, 512) 的 3D tensor
```
`block_size` 过大时:
- 每个 thread block 需要更多寄存器
- `tl.arange(0, block_size)` 生成更大的向量
- reshape 操作的内存访问模式变差
## 优化建议
### 方案 1: 固定 estimate block_size
`select_blocks` 中使用固定的小 block_size 进行估计:
```python
# 建议修改
ESTIMATE_BLOCK_SIZE = 1024 # 或 512而非 ctx.block_size
reshaped_block_size = ESTIMATE_BLOCK_SIZE // self.stride # 128
```
**优点**:简单直接,预期提升 15x
**缺点**estimate 的 block 粒度与 CPU block 不一致,需要映射
### 方案 2: 两级 block 结构
- 外层使用 `kvcache_block_size` (4096) 管理 CPU blocks
- 内层使用 `estimate_block_size` (1024) 进行估计
- 估计结果聚合回 CPU block 粒度
### 方案 3: 自适应 block_size
根据 context length 动态选择 estimate block_size
| Context Length | Recommended block_size |
|----------------|------------------------|
| < 16K | 512 |
| 16K - 64K | 1024 |
| > 64K | 1024 |
## 与实际 Profiling 的对比
### Nsys Profiling 数据 (64K context, block_size=4096)
| 阶段 | 时间占比 | 说明 |
|------|----------|------|
| softmax_fuse_block_sum | **48.1%** | 最后一个 chunk |
| flash_fwd_kernel | 30.7% | 实际 attention 计算 |
| flat_group_gemm | 3.5% | estimate GEMM |
### 预期优化效果
如果将 estimate block_size 从 4096 改为 1024
| 指标 | 当前 (4096) | 优化后 (1024) | 提升 |
|------|-------------|---------------|------|
| softmax kernel | 95.32ms | 6.01ms | **15.85x** |
| estimate 阶段占比 | 48.1% | ~5% | 显著降低 |
| 总体 prefill 时间 | ~2s (最后chunk) | ~1.1s | ~1.8x |
## 测试命令
```bash
# 运行 benchmark
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/bench_estimate_block_size.py --gpu 0
# 指定单个 context length
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/bench_estimate_block_size.py --gpu 0 --ctx-len 65536
```
## 相关文件
| 文件 | 说明 |
|------|------|
| `nanovllm/kvcache/sparse/xattn_bsa.py` | XAttention BSA Policy 实现 |
| `nanovllm/ops/xattn.py` | Triton kernels |
| `tests/bench_estimate_block_size.py` | 性能测试脚本 |
| `docs/xattn_performance_analysis.md` | XAttention 整体性能分析 |
## 分级求和方案 (Hierarchical Block Sum)
使用小的 `estimate_block_size=1024` 计算细粒度 block_sums然后聚合到 CPU block 级别 (4096)。
### 数学等价性
```
方案1 (block_size=4096): softmax_fuse_block_sum → [1, heads, 1, 1]
方案2 (block_size=1024): softmax_fuse_block_sum → [1, heads, 4, 4] → sum → [1, heads]
验证结果: Max difference = 0.0 ✅ 完全等价
```
### 验证代码
`tests/test_hierarchical_estimate.py` - 纯 torch + xattn kernels 实现
### 性能提升
| 指标 | 当前 (4096) | 优化后 (1024) | 提升 |
|------|-------------|---------------|------|
| softmax kernel | 12.07 ms | 0.29 ms | **41x** |
| 端到端 estimate | 95 ms | ~6 ms | **15x** |
## ⚠️ 选择策略变更
**重要**: 分级求和方案使用新的选择策略:
| 特性 | 原策略 (mask + voting) | 新策略 (score + threshold) |
|------|------------------------|----------------------------|
| 输入 | `[batch, heads, q_blocks, k_blocks]` | `[batch, heads, num_cpu_blocks]` |
| 选择粒度 | Per-q-block | Per-chunk |
| 聚合方式 | majority voting | threshold on scores |
新策略更简洁,直接利用分级求和产生的 score避免了 mask 生成和 voting 的复杂逻辑。
## 实现状态 ✅ (2026-01-28)
### 已实现
分级求和方案已在 `xattn_bsa.py` 中实现:
```python
class XAttentionBSAPolicy:
def __init__(self, ..., estimate_block_size: int = 1024):
self.estimate_block_size = estimate_block_size # 新参数
def select_blocks(self, ...):
# Step 2: Hierarchical softmax_fuse_block_sum
reshaped_est_bs = estimate_bs // self.stride # 1024/8 = 128
block_sums_fine = softmax_fuse_block_sum(attn_scores, reshaped_est_bs, ...)
# Step 3: Aggregate to CPU block level
block_sums_coarse = block_sums_fine.view(..., num_cpu_blocks, ratio).sum(dim=-1)
cpu_block_scores = block_sums_coarse.sum(dim=2)
# Step 4: Score + threshold selection (replaces mask + voting)
scores_per_block = cpu_block_scores.mean(dim=(0, 1))
# ... cumulative threshold selection
```
### 实测结果 (Nsys Profiling)
| Kernel | 优化前 | 优化后 | 改进 |
|--------|--------|--------|------|
| softmax_fuse_block_sum 占比 | 48.1% | **1.1%** | **44x** |
| softmax_fuse_block_sum 平均时间 | ~2ms | 489us | **4x** |
### 端到端性能 (32K context)
| 指标 | FULL Policy | XATTN Policy | 改进 |
|------|-------------|--------------|------|
| Prefill throughput | 3511 tok/s | 3695 tok/s | +5% |
| TTFT | 9327 ms | 8863 ms | -5% |
## 结论
当前 estimate 阶段使用全局 `kvcache_block_size=4096` 导致 `softmax_fuse_block_sum` kernel 性能处于最差点。通过将 estimate block_size 改为 512-1024可以获得 **15x** 的性能提升,显著降低 estimate 阶段的开销。
**⚠️ 重要变更**: 选择策略从 `mask + majority voting` 改为 `score + threshold`,更简洁且更直接。

View File

@@ -0,0 +1,77 @@
# GPU-only Sparse Policy 整合
本文档记录将 sparse attention 策略整合到 GPU-only 模式的过程和性能对比。
## 背景
当前 sparse policyQuest、XAttention仅在 CPU offload 路径中实现。目标是将其扩展到 GPU-only 模式,以提升长上下文场景下的性能。
## 基准性能(优化前)
**测试环境**:
- GPU: NVIDIA A100-SXM4-80GB
- 模型: Llama-3.1-8B-Instruct
- 上下文长度: 32K tokens
- 日期: 2026-01-27
### Prefill Benchmark (32K context)
| 模式 | Throughput | Time | KV Cache 分配 |
|------|------------|------|---------------|
| **GPU-only (Full Attention)** | 4869.67 tok/s | 6.73s | 438 blocks (56GB GPU) |
| CPU Offload (Full Attention) | 1500.29 tok/s | 21.84s | 4 blocks GPU + 32 blocks CPU |
**性能比**: GPU-only 比 CPU Offload 快 **3.2x**
### 配置详情
**GPU-only 模式**:
```bash
CUDA_VISIBLE_DEVICES=0 python bench.py \
--model ~/models/Llama-3.1-8B-Instruct \
--max-len 32768
```
**CPU Offload 模式**:
```bash
CUDA_VISIBLE_DEVICES=0 python bench_offload.py \
--model ~/models/Llama-3.1-8B-Instruct \
--max-len 32768
```
### KV Cache 配置
| 参数 | GPU-only | CPU Offload |
|------|----------|-------------|
| block_size | 1024 tokens | 1024 tokens |
| per-token KV | 128 KB | 128 KB |
| per-block KV | 128 MB | 128 MB |
| GPU blocks | 438 | 4 |
| CPU blocks | 0 | 32 |
| Total memory | 56 GB | 4.6 GB |
## 目标
将以下 sparse policy 整合到 GPU-only 模式:
| Policy | 阶段 | 描述 |
|--------|------|------|
| Quest | Decode | Top-K block selection based on query-key scores |
| XAttention BSA | Prefill | Block sparse attention with cumulative threshold |
## 实现进度
- [ ] 分析现有 sparse policy 代码结构
- [ ] 设计 GPU-only sparse policy 接口
- [ ] 实现 GPU-only Quest decode
- [ ] 实现 GPU-only XAttention prefill
- [ ] 性能测试和对比
## 优化后性能
*待测试*
| 模式 | Throughput | Speedup vs Full |
|------|------------|-----------------|
| GPU-only + Quest (decode) | TBD | TBD |
| GPU-only + XAttn (prefill) | TBD | TBD |

View File

@@ -0,0 +1,296 @@
# GPU-Only XAttention 指南
本文档介绍 GPU-only 模式下 XAttention BSA 的实现、内存优化和性能分析。
## 概述
GPU-only 模式下,所有 KV cache 存储在 GPU 上,无需 CPU offload。XAttention 通过稀疏注意力加速 prefill 阶段。
### 执行路径对比
| 模式 | Prefill 方法 | Decode 方法 | KV 存储 |
|------|-------------|-------------|---------|
| GPU-only Full | `compute_prefill()` | `compute_decode()` | GPU |
| GPU-only XAttn | `compute_prefill()` | `compute_decode()` | GPU |
| CPU Offload | `compute_chunked_prefill()` | `compute_chunked_decode()` | CPU + GPU |
## 架构设计
### SparsePolicy 接口
```python
class SparsePolicy:
# GPU-only 方法
def compute_prefill(self, q, k, v, ...) -> Tensor
def compute_decode(self, q, k_cache, v_cache, ...) -> Tensor
# CPU Offload 方法
def compute_chunked_prefill(self, q, k, v, ...) -> Tensor
def compute_chunked_decode(self, q, ...) -> Tensor
# 初始化方法
def initialize(self, num_layers, ...) -> None # CPU offload metadata
def alloc_policy_metadata(self, num_heads, ...) -> None # GPU-only buffers
```
### XAttentionBSAPolicy 实现
```
GPU-only Prefill 流程:
┌─────────────────────────────────────────────────────────────┐
│ 1. GQA 扩展 (使用预分配 buffer) │
│ K: [seq, kv_heads, dim] → K_exp: [1, heads, seq, dim] │
│ │
│ 2. XAttention 估计 │
│ flat_group_gemm_fuse_reshape_kernel (Q@K^T) │
│ softmax_fuse_block_sum_kernel (block 重要性) │
│ → sparse mask │
│ │
│ 3. BSA 稀疏注意力 │
│ flash_fwd_block_kernel (只计算选中的 blocks) │
│ → output │
└─────────────────────────────────────────────────────────────┘
```
## 内存预分配
### 问题背景
XAttention 的 `compute_prefill()` 需要 GQA 扩展:
```python
# 之前: 动态分配 (~2GB for 64K)
K_exp = K.repeat_interleave(num_groups, dim=1) # 分配 1
k_bsa = k.repeat_interleave(num_groups, dim=1) # 分配 2 (重复!)
```
每次 prefill 都动态分配,导致:
- 内存碎片
- 分配延迟
- 可能 OOM
### 解决方案: alloc_policy_metadata()
在框架初始化时预分配 buffer
```python
class XAttentionBSAPolicy(SparsePolicy):
def alloc_policy_metadata(self, num_heads, num_kv_heads, head_dim,
max_seq_len, dtype, device):
# 预分配 GQA 扩展 buffer
shape = (1, num_heads, max_seq_len, head_dim)
self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
def compute_prefill(self, q, k, v, ...):
seq_len = k.shape[0]
# 使用预分配 buffer 的 slice
K_exp = self._k_expanded[:, :, :seq_len, :]
# 原地 GQA 扩展
K_exp.view(...).copy_(K.unsqueeze(2).expand(...))
# 复用同一 buffer 给 BSA
k_bsa = K_exp.squeeze(0).transpose(0, 1)
```
### 内存使用
| 序列长度 | 预分配大小 | 说明 |
|---------|-----------|------|
| 32K | 512 MB | `2 * 32 * 32768 * 128 * 2 bytes` |
| 64K | 1024 MB | `2 * 32 * 65536 * 128 * 2 bytes` |
优化效果:
- 之前: ~2GB 动态分配 (xattn_estimate + BSA 各一次)
- 之后: ~1GB 预分配 (复用同一 buffer)
### 框架集成
```python
# model_runner.py - allocate_kv_cache()
def allocate_kv_cache(self):
# ... KV cache 分配 ...
# GPU-only 模式: 预分配 policy buffers
if not config.enable_cpu_offload:
self.kvcache_manager.sparse_policy.alloc_policy_metadata(
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
max_seq_len=config.max_model_len,
dtype=dtype,
device=torch.device("cuda"),
)
```
## 性能分析
### 32K Prefill 性能
| Policy | Throughput | 相对提升 |
|--------|------------|----------|
| Baseline | 4880 tok/s | - |
| Full | 4892 tok/s | +0.2% |
| **XAttention** | **5602 tok/s** | **+15%** |
### 64K Prefill 性能
| Policy | Throughput | 相对提升 |
|--------|------------|----------|
| Baseline | 3386 tok/s | - |
| Full | 3355 tok/s | -0.9% |
| **XAttention** | **4775 tok/s** | **+41%** |
### Kernel 时间分解 (32K)
**XAttention:**
```
FFN GEMM: 3219 ms (54%)
BSA Attention: 1231 ms (21%)
XAttn Estimation: 415 ms (7%)
Other: 1020 ms (18%)
─────────────────────────────
Total: 5885 ms
```
**Full:**
```
FFN GEMM: 3244 ms (48%)
Dense Attention: 2861 ms (43%)
Other: 595 ms (9%)
─────────────────────────────
Total: 6700 ms
```
### 加速来源
```
Dense Attention: 2861 ms
BSA Attention: 1231 ms (节省 1630 ms, -57%)
XAttn Estimation: 415 ms (额外开销)
─────────────────────────────
净节省: 1215 ms (42% attention 时间)
```
## CUDA Graph 限制
### 为什么 Prefill 不能用 CUDA Graph
CUDA Graph 要求所有操作在 capture 时确定:
| 必须固定 | Prefill 的情况 |
|---------|---------------|
| Tensor 形状 | seq_len 可变 (1 ~ max_model_len) |
| Kernel grid | 依赖 seq_len |
| 内存地址 | 中间 tensor 大小变化 |
```python
# 不同请求的 seq_len 不同
request_1: prefill(seq_len=1024) # grid=(8, 32, 1)
request_2: prefill(seq_len=32768) # grid=(256, 32, 1)
```
### Decode 可以用 CUDA Graph
```python
# Decode 每次只处理 1 token
q: [batch_size, 1, heads, dim] # 形状固定
```
nanovllm 为每个 batch_size 预先 capture 一个 graph
```python
def capture_cudagraph(self):
for batch_size in [1, 2, 4, 8, ...]:
with torch.cuda.graph(g):
self.run_model(dummy_input, is_prefill=False)
self.graphs[batch_size] = g
```
### Nsys Profile 结果
```
XAttention 32K Prefill:
Total kernels: 41,904
Non-graph: 41,904 (100%)
Graph: 0
Full 32K Prefill:
Total kernels: 35,308
Non-graph: 35,308 (100%)
Graph: 0
```
**两者都是 100% NON-GRAPH**,这是 prefill 的本质特性。
## Profiling 工具
### 使用 profile.sh
```bash
# XAttention 32K
bash scripts/profile.sh --max-len 32768 --policy xattn
# Full 32K
bash scripts/profile.sh --max-len 32768 --policy full
# 64K (需要降低 gpu-util)
bash scripts/profile.sh --max-len 65536 --policy xattn --gpu-util 0.7
```
### 分析 nsys 结果
```bash
# 查看 kernel 统计
nsys stats --report cuda_gpu_kern_sum results/nsys/<file>.nsys-rep
# 用 sqlite 查询详细数据
sqlite3 results/nsys/<file>.sqlite "
SELECT
(SELECT value FROM StringIds WHERE id = shortName) as kernel,
COUNT(*) as count,
SUM(end-start)/1e6 as total_ms
FROM CUPTI_ACTIVITY_KIND_KERNEL
GROUP BY shortName
ORDER BY total_ms DESC
LIMIT 10
"
```
## 使用指南
### 启用 XAttention GPU-only
```python
from nanovllm import LLM
from nanovllm.config import SparsePolicyType
llm = LLM(
model_path,
max_model_len=32768,
sparse_policy=SparsePolicyType.XATTN_BSA,
gpu_memory_utilization=0.9, # 64K 时可能需要降低
)
```
### 命令行测试
```bash
# bench.py
python bench.py --max-len 32768 --policy xattn
# 64K 需要降低 gpu-util
python bench.py --max-len 65536 --policy xattn --gpu-util 0.7
```
### 最佳实践
1. **32K 及以下**: 使用默认 `gpu_memory_utilization=0.9`
2. **64K**: 降低到 `gpu_memory_utilization=0.7`
3. **Decode**: XAttention 自动 fallback 到 FullAttentionPolicy
4. **Paged KV Cache**: 当 `block_tables` 存在时自动 fallback 到 flash_attn
## 相关文档
- [Sparse Policy 架构](sparse_policy_architecture.md)
- [XAttention 算法详解](xattention_algorithm_guide.md)
- [BSA 接口文档](block_sparse_attn_interface.md)

View File

@@ -0,0 +1,246 @@
# Density Alignment Test Results
验证 GPU-only 和 Offload 模式下三阶段 KV chunking 流程的正确性。
## 测试配置
### GPU-only 模式
- **模型**: Qwen3-0.6B (28 layers, 16 heads, 8 KV heads, head_dim=128)
- **Threshold**: 0.9
- **Block Size**: 128 tokens (BSA block)
- **Stride**: 8
- **Chunk Size**: 16384 tokens
### Offload 模式
- **模型**: Llama-3.1-8B-Instruct (32 layers, 32 heads, 8 KV heads, head_dim=128)
- **Threshold**: 0.9
- **Block Size**: 128 tokens (BSA block)
- **Stride**: 4
- **Chunk Size**: 4096 tokens
## 三阶段 KV Chunking 对齐测试 (2026-02-02)
### 测试目的
验证 `xattn_estimate` 高层 API 与手动实现的三阶段 KV chunking 流程是否完全一致。
### 三阶段流程
```
┌─────────────────────────────────────────────────────────────┐
│ Stage 1: softmax_compute_partial_stats │
│ └── 每个 KV chunk 独立计算 partial stats (m_i, l_i) │
│ │
│ Stage 2: merge_softmax_stats │
│ └── Host 端合并所有 chunks: (m_global, l_global) │
│ │
│ Stage 3: softmax_normalize_and_block_sum │
│ └── 使用全局 stats 归一化并计算 block sums │
└─────────────────────────────────────────────────────────────┘
```
### 测试结果
#### CHUNK_SIZE = 16384 (默认)
| Context | Tokens | Q Chunks | KV Chunks | Density | Mask 差异 | attn_sums 差异 | 结果 |
|---------|--------|----------|-----------|---------|-----------|----------------|------|
| 4K | 3,692 | 1 | 1 | 63.84% | 0 | 0.0 | ✅ |
| 8K | 7,892 | 1 | 1 | 64.98% | 0 | 0.0 | ✅ |
| 16K | 15,689 | 1 | 1 | 61.63% | 0 | 0.0 | ✅ |
| 32K | 32,485 | 2 | 2 | 50.21% | 0 | 0.0 | ✅ |
| **64K** | **64,891** | **4** | **4** | **37.00%** | **0** | **0.0** | ✅ |
#### CHUNK_SIZE = 4096 (更多 chunks)
| Context | Tokens | Q Chunks | KV Chunks | Density | xattn_estimate vs KV chunking | 结果 |
|---------|--------|----------|-----------|---------|-------------------------------|------|
| 4K | 3,692 | 1 | 1 | 63.84% | 0.000000 | ✅ |
| 8K | 7,892 | 2 | 2 | 63.02% | 0.000000 | ✅ |
| 16K | 15,689 | 4 | 4 | 60.08% | 0.000000 | ✅ |
| 32K | 32,485 | 8 | 8 | 49.84% | 0.000000 | ✅ |
| **64K** | **64,891** | **16** | **16** | **36.91%** | **0.000000** | ✅ |
### 64K 详细验证 (CHUNK_SIZE=4096)
64K 序列使用 chunk_size=4096 时产生 16×16 的 chunk 矩阵:
```
seq_len: 64891, q_chunk_num: 16, kv_chunk_num: 16
Q chunk 0: merged 16 KV chunks → attn_sum shape=[1, 32, 32, 512]
Q chunk 1: merged 16 KV chunks → attn_sum shape=[1, 32, 32, 512]
...
Q chunk 15: merged 16 KV chunks → attn_sum shape=[1, 32, 32, 512]
```
每个 Q chunk 需要合并 16 个 KV chunks 的 softmax stats充分验证了 `merge_softmax_stats` 在大规模 chunk 合并场景下的正确性。
### 验证指标
| 指标 | 预期 | 所有长度实际结果 |
|------|------|------------------|
| attn_sums max diff | 0 | 0.000000e+00 |
| attn_sums mean diff | 0 | 0.000000e+00 |
| mask exact match | True | True |
| density diff | 0% | 0.000000% |
### 结论
**三阶段 KV chunking 与一次性处理完全等价,无任何精度损失。**
- 当 seq_len < CHUNK_SIZE (16384):单 chunk 处理
- 当 seq_len >= CHUNK_SIZE多 chunk 分段处理后合并,结果与一次性处理完全一致
---
## Offload 模式测试 (2026-02-02)
使用 Offload 模式保存的真实 KV cache 数据进行测试。
### 测试结果
| 文件 | Tokens | Layer | Saved Density | Computed Density | Q/KV Chunks | 结果 |
|------|--------|-------|---------------|------------------|-------------|------|
| `qkv_3688.pt` | 3.7K | 3 | 38.34% | 38.34% | 1/1 | ✅ PASSED |
| `qkv_7888.pt` | 7.9K | 3 | 29.06% | 27.56% | 2/2 | ✅ PASSED |
| `qkv_15685.pt` | 15.7K | 3 | 19.77% | 18.60% | 4/4 | ✅ PASSED |
| `qkv_32485.pt` | 32.5K | 5 | 15.71% | 15.62% | 8/8 | ✅ PASSED |
| `qkv_64891.pt` | 64.9K | 3 | 11.09% | 11.09% | 16/16 | ✅ PASSED |
### Layer 5 GPU-only 测试 (threshold=0.9)
| 指标 | 结果 |
|------|------|
| Q/K shape | `[1, 16, 21001, 128]` (21K tokens) |
| Density | 6.24% |
| xattn_estimate vs KV chunking | 完全一致 (0.0000%) |
| mask 差异 | 0 / 435600 blocks |
| attn_sums 差异 | max=0.0, mean=0.0 |
### 观察
1. **Density 随 context 增长而降低**: 3.7K (38%) → 64.9K (11%)
2. **xattn_estimate API 与三阶段 KV chunking 完全一致**: 所有长度差异均为 0.0000%
3. **Saved density vs Computed density 略有差异**: 这是因为 saved density 可能在不同 chunk 下记录,累积计算方式略有不同
---
## 附录xattn_bsa vs xattn_estimate 对齐
| Context | Tokens | Layer 0 Density | Compute Density | Min Layer | 验证结果 |
|---------|--------|-----------------|-----------------|-----------|----------|
| 4k | 3,692 | 63.8% | 52.9% | Layer 3 (31.3%) | ✅ PASSED |
| 8k | 7,892 | 65.0% | 52.5% | Layer 5 (27.3%) | ✅ PASSED |
| 16k | 15,689 | 61.6% | 47.8% | Layer 5 (23.5%) | ✅ PASSED |
| 32k | 32,485 | 50.2% | 40.1% | Layer 5 (18.5%) | ✅ PASSED |
| 64k | 64,891 | 37.0% | 29.6% | Layer 5 (12.4%) | ✅ PASSED |
## Density 计算公式
### Total (分母)
```python
# Causal mask: Q block i 只能看到 K block 0 到 i
causal_mask[i, j] = (j <= i + q_offset_blocks)
# Total = causal 区域内的 block 数 × batch × heads
total = causal_mask.sum() × batch × heads
= (n × (n+1) / 2) × 1 × 32 # n = valid_q_blocks
```
### Selected (分子)
```python
# 在 causal 区域内,被选中 (mask=True) 的 block 数量
selected = (mask & causal_mask).sum()
```
### Density
```python
density = selected / total
```
## 观察
1. **Density 随 context 增长而降低**: 4k (63.8%) → 64k (37.0%),这是因为长序列中 attention 更加分散
2. **Layer 5 通常是最稀疏的层**: 在所有长度测试中Layer 5 的 density 最低
3. **Layer 0 density 最高**: 第一层的 attention pattern 最密集,可能与 sink token 效应有关
4. **Threshold=0.9 对应 ~50% density**: 在 32k context 下threshold=0.9 意味着选择覆盖 90% attention 的 blocks实际 density 约 50%
## 使用方法
### Step 1: 启用 debug 保存
```python
# nanovllm/kvcache/sparse/xattn_bsa.py
_DEBUG_SAVE_MASK = True # 改为 True
```
### Step 2: 运行 GPU-only 推理
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9
```
### Step 3: 运行 KV chunking 对齐验证
```bash
# 使用 GPU-only 保存的数据
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_alignment.py --gpuonly
# 使用 Offload 模式保存的数据 (默认)
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_alignment.py
# 指定自定义数据文件
python tests/test_xattn_estimate_alignment.py --data-file /path/to/data.pt
# 批量测试所有 Offload 数据
for f in results/kvcache/qkv_*.pt; do
echo "Testing: $(basename $f)"
python tests/test_xattn_estimate_alignment.py --data-file "$f"
done
```
### 批量测试所有长度
```bash
for ctx in 4k 8k 16k 32k 64k; do
case $ctx in
4k) max_len=5000 ;;
8k) max_len=9000 ;;
16k) max_len=17000 ;;
32k) max_len=34000 ;;
64k) max_len=65664 ;;
esac
echo "Testing $ctx..."
python tests/test_ruler.py \
--data-dir tests/data/ruler_$ctx \
--max-model-len $max_len \
--sparse-policy XATTN_BSA \
--num-samples 1 --quiet
python tests/test_xattn_estimate_alignment.py --gpuonly
done
```
## 相关文件
- `nanovllm/kvcache/sparse/xattn_bsa.py`: XAttention BSA Policy 实现
- `nanovllm/ops/xattn.py`: xattn_estimate 函数及三阶段 KV chunking kernels
- `tests/test_xattn_estimate_alignment.py`: KV chunking 对齐验证脚本

View File

@@ -0,0 +1,209 @@
# Issue: XAttention Offload Mode GQA Buffer OOM
## 问题描述
在使用 XAttention BSA (Block Sparse Attention) + CPU Offload 模式运行 GLM-4-9B 等大模型时,出现 CUDA OOM 错误。
### 错误信息
```
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB.
GPU 0 has a total capacity of 23.57 GiB of which 4.19 GiB is free.
```
### 复现环境
| 项目 | 值 |
|------|-----|
| 模型 | GLM-4-9B-Chat-1M |
| GPU | RTX 3090 (24GB) |
| Context Length | 32K |
| sparse_policy | XATTN_BSA |
| enable_cpu_offload | true |
| max_model_len | 1048576 (1M) |
### 错误位置
```
File "nanovllm/kvcache/sparse/xattn_bsa.py", line 246, in alloc_policy_metadata
self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
```
---
## 问题分析
### 内存分配分析
`alloc_policy_metadata()` 在 KV cache 初始化时分配以下 buffer
| Buffer | 用途 | 大小 (GLM-4, 1M seq) |
|--------|------|----------------------|
| `_prefill_mask_buffer` | BSA mask | ~32 MB |
| `_m_partial_buffer` | KV chunking m stats | ~32 MB |
| `_l_partial_buffer` | KV chunking l stats | ~32 MB |
| `_block_sums_buffer` | Block sums | ~64 MB |
| **`_k_expanded`** | GQA K 扩展 | **~8 GB** |
| **`_v_expanded`** | GQA V 扩展 | **~8 GB** |
### GQA Buffer 计算
```python
shape = (1, num_heads, max_seq_len, head_dim)
= (1, 32, 1048576, 128)
size = 1 × 32 × 1048576 × 128 × 2 bytes (fp16)
= 8,589,934,592 bytes
= 8 GB per buffer
```
### 根本原因
1. **设计意图冲突**`_k_expanded``_v_expanded` 的文档注释明确说是 "for GPU-only mode"
2. **条件检查不完整**:代码只检查了 `num_heads == num_kv_heads` 来跳过分配,没有检查 offload 模式
3. **Offload 模式不需要这些 buffer**`compute_chunked_prefill()` 使用不同的计算路径,不依赖预分配的 GQA buffer
### 相关代码
```python
# xattn_bsa.py:238-247
# Only allocate GQA expansion buffers if GQA (num_heads != num_kv_heads)
if num_heads == num_kv_heads:
logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
return # <-- 只检查了 GQA没检查 offload 模式
# Shape: [1, num_heads, max_seq_len, head_dim] for xattn_estimate format
shape = (1, num_heads, max_seq_len, head_dim)
self._k_expanded = torch.empty(shape, dtype=dtype, device=device) # <-- OOM here
self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
```
---
## 解决思路
### 方案 1: 在 Offload 模式下跳过 GQA Buffer 分配 (推荐)
`alloc_policy_metadata()` 中添加 offload 模式检查:
```python
def alloc_policy_metadata(
self,
num_heads: int,
num_kv_heads: int,
head_dim: int,
max_seq_len: int,
dtype: torch.dtype,
device: torch.device,
enable_cpu_offload: bool = False, # <-- 新增参数
) -> None:
# ... 分配 mask buffer 和 KV chunking buffers (offload 模式需要)
# Skip GQA buffers in offload mode
# Chunked prefill uses compute_chunked_prefill() which doesn't need these
if enable_cpu_offload:
logger.info("[XAttn] Offload mode: skipping GQA expansion buffers")
return
# GPU-only mode: pre-allocate GQA buffers for compute_prefill()
if num_heads == num_kv_heads:
logger.info(f"[XAttn] No GQA expansion needed")
return
shape = (1, num_heads, max_seq_len, head_dim)
self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
```
**需要修改的文件**
1. `nanovllm/kvcache/sparse/xattn_bsa.py` - `alloc_policy_metadata()` 方法
2. `nanovllm/engine/model_runner.py` - 调用 `alloc_policy_metadata()` 时传入 `enable_cpu_offload`
### 方案 2: 延迟分配 (Lazy Allocation)
只在 `compute_prefill()` 首次调用时分配 GQA bufferoffload 模式走 `compute_chunked_prefill()` 不会触发分配。
```python
def compute_prefill(self, ...):
# Lazy allocation on first use
if self._k_expanded is None and num_heads != num_kv_heads:
self._allocate_gqa_buffers(...)
...
```
### 方案 3: 基于 chunk_size 限制 buffer 大小
不预分配 max_seq_len 大小,而是只分配 chunk_size 大小:
```python
# 原来: max_seq_len (1M tokens) -> 8 GB
# 修改后: chunk_size (16K tokens) -> ~130 MB
buffer_len = self.chunk_size if enable_cpu_offload else max_seq_len
shape = (1, num_heads, buffer_len, head_dim)
```
---
## 验证方法
修复后运行以下命令验证:
```bash
cd /home/zijie/Code/COMPASS
GPULIST=0 ./scripts/run_ruler.sh glm4-9b-xattn-nanovllm synthetic xattn --task niah_single_1
```
预期结果:
- 不再出现 8GB allocation 的 OOM 错误
- 模型正常加载并完成推理
---
## 相关文档
- `docs/xattn_bsa_policy_design.md` - XAttention BSA Policy 设计文档
- `docs/gpu_only_xattn_guide.md` - GPU-Only XAttention 指南
## 优先级
**High** - 阻塞 9B+ 模型在 24GB 显存 GPU 上使用 XAttention + Offload 模式
---
## 修复状态
**✅ 已修复** (2026-02-05)
### 修复内容
采用方案 1在 offload 模式下跳过 GQA buffer 分配:
1. `nanovllm/kvcache/sparse/policy.py`: 基类添加 `enable_cpu_offload` 参数
2. `nanovllm/kvcache/sparse/xattn_bsa.py`: 实现 offload 模式检查,跳过 GQA buffer
3. `nanovllm/engine/model_runner.py`: 传入 `enable_cpu_offload` 参数
### 验证结果
```bash
# 64K offload 测试
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_64k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 72000 \
--enable-offload \
--sparse-policy XATTN_BSA
```
- ✅ 日志显示: `[XAttn] Offload mode: skipping GQA expansion buffers`
- ✅ 测试通过: 100% 准确率
- ✅ 内存节省: ~16 GB (for 1M max_seq_len)
### 内存对比
| 配置 | 修复前 | 修复后 |
|------|--------|--------|
| max_model_len=72K | +1.1 GB | 0 GB |
| max_model_len=1M | +16 GB | 0 GB |

View File

@@ -0,0 +1,184 @@
# 1M+ 上下文长度模型列表
本文档收集了 Hugging Face 上支持 1M (1,048,576) 及以上上下文长度的开源模型。
> 更新时间: 2026-01-28
---
## 一、纯语言模型 (≤10B 参数)
### 1. 官方原版模型
| 厂商 | 模型 | 上下文 | 规模 | 下载量 | 链接 |
|------|------|--------|------|--------|------|
| **Qwen** | Qwen2.5-7B-Instruct-1M | 1M | 7B | 69.3K | [HF](https://hf.co/Qwen/Qwen2.5-7B-Instruct-1M) |
| **THUDM** | GLM-4-9B-Chat-1M | 1M | 9B | 5.0K | [HF](https://hf.co/zai-org/glm-4-9b-chat-1m) |
| **InternLM** | InternLM2.5-7B-Chat-1M | 1M | 7B | 322 | [HF](https://hf.co/internlm/internlm2_5-7b-chat-1m) |
| **NVIDIA** | Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct | 1M | 8B | 2.9K | [HF](https://hf.co/nvidia/Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct) |
| **LWM** | LWM-Text-1M | 1M | 7B | 75 | [HF](https://hf.co/LargeWorldModel/LWM-Text-1M) |
| **LWM** | LWM-Text-Chat-1M | 1M | 7B | 3.0K | [HF](https://hf.co/LargeWorldModel/LWM-Text-Chat-1M) |
### 2. Gradient AI 扩展系列 (基于 Llama 3)
| 模型 | 上下文 | 规模 | 下载量 | 链接 |
|------|--------|------|--------|------|
| Llama-3-8B-Instruct-Gradient-1048k | **1M** | 8B | 44.8K | [HF](https://hf.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k) |
| Llama-3-8B-Instruct-Gradient-4194k | **4M** | 8B | 9 | [HF](https://hf.co/gradientai/Llama-3-8B-Instruct-Gradient-4194k) |
### 3. 社区衍生版本 (Abliterated)
| 模型 | 上下文 | 基础模型 | 下载量 | 链接 |
|------|--------|----------|--------|------|
| Qwen2.5-7B-Instruct-1M-abliterated | 1M | Qwen2.5-7B | 375 | [HF](https://hf.co/huihui-ai/Qwen2.5-7B-Instruct-1M-abliterated) |
| Nemotron-8B-UltraLong-1M-Abliterated | 1M | Nemotron-8B | 46 | [HF](https://hf.co/SicariusSicariiStuff/Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct_Abliterated) |
---
## 二、视觉-语言模型 (≤10B 参数)
### Qwen3 VL 系列
#### Instruct 版本
| 模型 | 上下文 | 规模 | 下载量 | 链接 |
|------|--------|------|--------|------|
| Qwen3-VL-2B-Instruct-1M-GGUF | 1M | 2B | 824 | [HF](https://hf.co/unsloth/Qwen3-VL-2B-Instruct-1M-GGUF) |
| Qwen3-VL-4B-Instruct-1M-GGUF | 1M | 4B | 936 | [HF](https://hf.co/unsloth/Qwen3-VL-4B-Instruct-1M-GGUF) |
| Qwen3-VL-8B-Instruct-1M-GGUF | 1M | 8B | 962 | [HF](https://hf.co/unsloth/Qwen3-VL-8B-Instruct-1M-GGUF) |
#### Thinking 推理版本
| 模型 | 上下文 | 规模 | 下载量 | 链接 |
|------|--------|------|--------|------|
| Qwen3-VL-2B-Thinking-1M-GGUF | 1M | 2B | 808 | [HF](https://hf.co/unsloth/Qwen3-VL-2B-Thinking-1M-GGUF) |
| Qwen3-VL-4B-Thinking-1M-GGUF | 1M | 4B | 666 | [HF](https://hf.co/unsloth/Qwen3-VL-4B-Thinking-1M-GGUF) |
| Qwen3-VL-8B-Thinking-1M-GGUF | 1M | 8B | 4.6K | [HF](https://hf.co/unsloth/Qwen3-VL-8B-Thinking-1M-GGUF) |
---
## 三、推荐模型 (≤10B)
| 用途 | 推荐模型 | 理由 |
|------|----------|------|
| **通用对话** | Qwen2.5-7B-Instruct-1M | 官方支持RULER 93.1分Apache 2.0 |
| **中英双语** | GLM-4-9B-Chat-1M | 清华出品,中文优化 |
| **最长上下文** | Llama-3-8B-Gradient-4194k | 支持 4M 上下文 |
| **多模态** | Qwen3-VL-8B-Thinking-1M | 视觉理解 + 推理能力 |
| **无审查** | Qwen2.5-7B-Instruct-1M-abliterated | 移除安全限制 |
---
## 四、VRAM 需求参考
| 模型规模 | 1M 上下文 VRAM | 备注 |
|----------|----------------|------|
| 7B (FP16) | ~120GB | 需多卡 |
| 7B (INT4) | ~40GB | 单卡 A100 可行 |
| 8B (FP16) | ~130GB | 需多卡 |
| 9B (FP16) | ~140GB | 需多卡 |
---
## 五、技术对比
| 模型系列 | 扩展技术 | RULER 得分 | 许可证 |
|---------|---------|------------|--------|
| Qwen2.5-1M | Dual Chunk Attention | 93.1 | Apache 2.0 |
| GLM-4-1M | - | 89.9 | 自定义 |
| Gradient-Llama | 渐进式扩展 | - | Llama 3 |
| Nemotron-1M | NVIDIA 训练 | - | CC-BY-NC-4.0 |
| LWM-1M | RingAttention | - | 开源 |
---
---
# 附录:大参数模型 (>10B)
> 以下模型参数量超过 10B需要更多计算资源。
## A. 纯语言模型 (>10B)
### 官方模型
| 厂商 | 模型 | 上下文 | 规模 | 下载量 | 链接 |
|------|------|--------|------|--------|------|
| **Qwen** | Qwen2.5-14B-Instruct-1M | 1M | 14B | 4.7K | [HF](https://hf.co/Qwen/Qwen2.5-14B-Instruct-1M) |
| **MiniMax** | MiniMax-Text-01 | 1M | 456B MoE | 721 | [HF](https://hf.co/MiniMaxAI/MiniMax-Text-01) |
| **Gradient** | Llama-3-70B-Instruct-Gradient-1048k | 1M | 70B | 9 | [HF](https://hf.co/gradientai/Llama-3-70B-Instruct-Gradient-1048k) |
### Qwen3 Coder 系列 (MoE)
| 模型 | 上下文 | 总参数/激活参数 | 下载量 | 链接 |
|------|--------|-----------------|--------|------|
| Qwen3-Coder-30B-A3B-Instruct-1M-GGUF | 1M | 30B / 3B | 13.1K | [HF](https://hf.co/unsloth/Qwen3-Coder-30B-A3B-Instruct-1M-GGUF) |
| Qwen3-Coder-480B-A35B-Instruct-1M | 1M | 480B / 35B | 50 | [HF](https://hf.co/unsloth/Qwen3-Coder-480B-A35B-Instruct-1M) |
| Qwen3-Coder-480B-A35B-Instruct-1M-GGUF | 1M | 480B / 35B | 1.7K | [HF](https://hf.co/unsloth/Qwen3-Coder-480B-A35B-Instruct-1M-GGUF) |
| Qwen3-Coder-42B-A3B-TOTAL-RECALL-1M | 1M | 42B / 3B | - | [HF](https://hf.co/DavidAU/Qwen3-Coder-42B-A3B-Instruct-TOTAL-RECALL-MASTER-CODER-M-1million-ctx) |
### 社区衍生版本
| 模型 | 上下文 | 规模 | 下载量 | 链接 |
|------|--------|------|--------|------|
| Qwen2.5-14B-Instruct-1M-abliterated | 1M | 14B | 147 | [HF](https://hf.co/huihui-ai/Qwen2.5-14B-Instruct-1M-abliterated) |
---
## B. 视觉-语言模型 (>10B)
### Meta Llama 4 系列 (MoE 多模态)
| 模型 | 上下文 | 总参数/激活参数 | 下载量 | 链接 |
|------|--------|-----------------|--------|------|
| Llama-4-Scout-17B-16E-Instruct | **10M** | 109B / 17B | 180K | [HF](https://hf.co/meta-llama/Llama-4-Scout-17B-16E-Instruct) |
| Llama-4-Maverick-17B-128E-Instruct | **1M** | 400B / 17B | 32.6K | [HF](https://hf.co/meta-llama/Llama-4-Maverick-17B-128E-Instruct) |
| Llama-4-Scout-17B-16E | 10M | 109B / 17B | 8.4K | [HF](https://hf.co/meta-llama/Llama-4-Scout-17B-16E) |
| Llama-4-Maverick-17B-128E | 1M | 400B / 17B | 368 | [HF](https://hf.co/meta-llama/Llama-4-Maverick-17B-128E) |
| Llama-4-Maverick-17B-128E-Instruct-FP8 | 1M | 400B / 17B | 29.6K | [HF](https://hf.co/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8) |
### Qwen3 VL 大模型系列
#### Dense 模型
| 模型 | 上下文 | 规模 | 下载量 | 链接 |
|------|--------|------|--------|------|
| Qwen3-VL-32B-Instruct-1M-GGUF | 1M | 32B | 1.2K | [HF](https://hf.co/unsloth/Qwen3-VL-32B-Instruct-1M-GGUF) |
| Qwen3-VL-32B-Thinking-1M-GGUF | 1M | 32B | 452 | [HF](https://hf.co/unsloth/Qwen3-VL-32B-Thinking-1M-GGUF) |
#### MoE 模型
| 模型 | 上下文 | 总参数/激活参数 | 下载量 | 链接 |
|------|--------|-----------------|--------|------|
| Qwen3-VL-30B-A3B-Instruct-1M-GGUF | 1M | 30B / 3B | 821 | [HF](https://hf.co/unsloth/Qwen3-VL-30B-A3B-Instruct-1M-GGUF) |
| Qwen3-VL-30B-A3B-Thinking-1M-GGUF | 1M | 30B / 3B | 944 | [HF](https://hf.co/unsloth/Qwen3-VL-30B-A3B-Thinking-1M-GGUF) |
| Qwen3-VL-235B-A22B-Instruct-1M-GGUF | 1M | 235B / 22B | 581 | [HF](https://hf.co/unsloth/Qwen3-VL-235B-A22B-Instruct-1M-GGUF) |
| Qwen3-VL-235B-A22B-Thinking-1M-GGUF | 1M | 235B / 22B | 733 | [HF](https://hf.co/unsloth/Qwen3-VL-235B-A22B-Thinking-1M-GGUF) |
#### MXFP4 量化版本
| 模型 | 上下文 | 规模 | 下载量 | 链接 |
|------|--------|------|--------|------|
| Qwen3-VL-30B-A3B-Instruct-1M-MXFP4_MOE-GGUF | 1M | 30B MoE | 689 | [HF](https://hf.co/noctrex/Qwen3-VL-30B-A3B-Instruct-1M-MXFP4_MOE-GGUF) |
| Qwen3-VL-30B-A3B-Thinking-1M-MXFP4_MOE-GGUF | 1M | 30B MoE | 565 | [HF](https://hf.co/noctrex/Qwen3-VL-30B-A3B-Thinking-1M-MXFP4_MOE-GGUF) |
| Qwen3-VL-235B-A22B-Instruct-1M-MXFP4_MOE-GGUF | 1M | 235B MoE | 136 | [HF](https://hf.co/noctrex/Qwen3-VL-235B-A22B-Instruct-1M-MXFP4_MOE-GGUF) |
| Qwen3-VL-235B-A22B-Thinking-1M-MXFP4_MOE-GGUF | 1M | 235B MoE | 244 | [HF](https://hf.co/noctrex/Qwen3-VL-235B-A22B-Thinking-1M-MXFP4_MOE-GGUF) |
---
## 统计汇总
| 类别 | ≤10B 模型数 | >10B 模型数 | 最大上下文 |
|------|-------------|-------------|-----------|
| 纯语言模型 | 10 | 8 | 4M |
| 视觉-语言模型 | 6 | 14 | 10M |
| **合计** | **16** | **22** | **10M** |
---
## 参考资源
- [Qwen2.5-1M 官方博客](https://qwenlm.github.io/blog/qwen2.5-1m/)
- [LongRoPE 论文](https://huggingface.co/papers/2402.13753)
- [InfiniteHiP 论文](https://huggingface.co/papers/2502.08910)
- [Top LLMs for Long Context Windows](https://www.siliconflow.com/articles/en/top-LLMs-for-long-context-windows)

View File

@@ -0,0 +1,120 @@
# Memory Communication Benchmark
GPU-CPU 通信量测试结果,对比 Full Policy 和 XAttention BSA Policy。
## 测试环境
- **模型**: Llama-3.1-8B-Instruct
- **GPU**: RTX 3090 (24GB)
- **配置**: `num_gpu_blocks=4`, `block_size=1024`, `enable_cpu_offload=True`
- **XAttention 参数**: `threshold=0.95`, `stride=8`
## 32K 上下文测试结果
| 指标 | Full Policy | XAttention | 比率 |
|------|-------------|------------|------|
| **Prefill H2D** | 66.57 GB | 111.12 GB | **1.67x** |
| Prefill D2H | 4.29 GB | 4.29 GB | 1.00x |
| TTFT | 8473 ms | 10367 ms | 1.22x |
### XAttention Block Selection (32K)
| 指标 | 数值 |
|------|------|
| 可用 blocks | 465 |
| 选中 blocks | 374 |
| 选择密度 | 80.4% |
## 64K 上下文测试结果
| 指标 | Full Policy | XAttention | 比率 |
|------|-------------|------------|------|
| **Prefill H2D** | 262.13 GB | 386.62 GB | **1.48x** |
| Prefill D2H | 8.46 GB | 8.46 GB | 1.00x |
| Decode H2D (32 tokens) | 262.13 GB | 262.13 GB | 1.00x |
| TTFT | 27081 ms | 33634 ms | 1.24x |
## 通信量比率对比 (K-only 优化前)
| 上下文长度 | XAttn/Full Prefill H2D 比率 |
|------------|----------------------------|
| 32K | 1.67x |
| 64K | 1.48x |
### 分析 (优化前)
1. **XAttention 通信量增加原因**
- Estimate 阶段:加载 **100%** 历史 blocks 的 **K+V**(用于 attention score 估计)
- Compute 阶段:加载 **选中的** blocks约 70-80%
- 理论比率:`1 + selection_density`
2. **64K 比率更低的原因**
- 更长上下文时attention 分布更稀疏
- XAttention 的 block 选择更有效(选中比例更低)
- First/last block 强制包含的影响相对减小
3. **Decode 阶段通信量相同**
- XAttention 仅支持 prefill 阶段
- Decode 阶段 fallback 到 Full Policy
---
## K-only 优化 (2026-01-28)
### 优化原理
XAttention 的 `select_blocks` 估计阶段只需要 K 来计算 attention scores
```python
# flat_group_gemm_fuse_reshape 只使用 Q 和 K
attn_scores = flat_group_gemm_fuse_reshape(Q, K_chunk, stride, ...)
```
V 在估计阶段完全不使用,但之前代码会同时加载 K 和 V造成 50% 通信量浪费。
### 优化实现
1. **新增方法**: `OffloadEngine.load_k_only_to_slot_layer()` - 只加载 K
2. **修改 select_blocks**: 使用只加载 K 的新方法
### 优化后测试结果
| 上下文 | Full Policy | XAttn (优化前) | XAttn (优化后) | 优化节省 |
|--------|-------------|---------------|---------------|---------|
| 32K | 66.57 GB | 111.12 GB | **79.76 GB** | **28.2%** |
| 64K | 262.13 GB | 386.62 GB | **258.78 GB** | **33.1%** |
### XAttn/Full 比率变化
| 上下文 | 优化前比率 | 优化后比率 |
|--------|-----------|-----------|
| 32K | 1.67x | **1.20x** |
| 64K | 1.48x | **0.99x** |
### 结论
优化后64K 上下文的 XAttention 通信量与 Full Policy 基本持平 (0.99x)
而 32K 也从 1.67x 降到 1.20x。这说明估计阶段的 K-only 优化非常有效
## 测试命令
```bash
# 32K Full Policy
python bench_offload.py --max-len 32768 --input-len 32000
# 32K XAttention
python bench_offload.py --max-len 32768 --input-len 32000 --enable-xattn
# 64K Full Policy
python bench_offload.py --max-len 65536 --input-len 64000
# 64K XAttention
python bench_offload.py --max-len 65536 --input-len 64000 --enable-xattn
# 包含 decode 测试
python bench_offload.py --max-len 65536 --input-len 64000 --bench-decode --output-len 32
```
## 相关文档
- [`observer_architecture.md`](observer_architecture.md) - Observer 架构设计
- [`xattn_bsa_policy_design.md`](xattn_bsa_policy_design.md) - XAttention BSA 算法设计

View File

@@ -0,0 +1,323 @@
# 新模型整合指南
本文档总结了将新模型如GLM-4整合到nanovllm的经验和常见问题。
## 整合流程概览
```
1. 分析模型配置 (config.json)
2. 创建模型文件 (nanovllm/models/<model>.py)
3. 实现权重加载 (nanovllm/utils/loader.py)
4. 处理特殊组件 (RoPE, Attention, etc.)
5. 处理tokenizer差异 (EOS tokens, chat template)
6. 验证输出正确性
```
---
## 1. 配置字段映射
不同模型使用不同的配置字段名称,需要建立映射关系:
| 标准字段 | GLM-4 | Qwen | Llama | 说明 |
|----------|-------|------|-------|------|
| `num_key_value_heads` | `multi_query_group_num` | `num_key_value_heads` | `num_key_value_heads` | KV heads数量 |
| `head_dim` | `kv_channels` | 计算得出 | 计算得出 | 每个head的维度 |
| `intermediate_size` | `ffn_hidden_size` | `intermediate_size` | `intermediate_size` | FFN隐藏层大小 |
| `max_position_embeddings` | `seq_length` | `max_position_embeddings` | `max_position_embeddings` | 最大位置 |
| `rope_theta` | `10000 * rope_ratio` | `rope_theta` | `rope_theta` | RoPE基础频率 |
### 代码示例
```python
# 在模型 __init__ 中处理配置差异
num_kv_heads = getattr(config, 'num_key_value_heads',
getattr(config, 'multi_query_group_num', num_heads))
head_dim = getattr(config, 'head_dim',
getattr(config, 'kv_channels', hidden_size // num_heads))
intermediate_size = getattr(config, 'intermediate_size',
getattr(config, 'ffn_hidden_size', None))
max_position = getattr(config, 'max_position_embeddings',
getattr(config, 'seq_length', 4096))
```
---
## 2. RoPE实现差异
RoPE是模型整合中**最容易出错**的部分。不同模型可能使用不同的RoPE变体
### 2.1 旋转方式
| 类型 | 描述 | 使用模型 |
|------|------|----------|
| **Half rotation** | 前半和后半分别旋转 `[x0,x1,...] → [x0*cos-x_{d/2}*sin, ...]` | Llama, Qwen |
| **Interleaved rotation** | 相邻元素配对旋转 `[x0,x1,...] → [x0*cos-x1*sin, x1*cos+x0*sin, ...]` | GLM-4 |
### 2.2 旋转维度
| 类型 | 描述 | 使用模型 |
|------|------|----------|
| **Full rotation** | 旋转整个head_dim | Llama, Qwen |
| **Partial rotation** | 只旋转head_dim的一部分其余pass-through | GLM-4 (rotary_dim = head_dim // 2) |
### 2.3 GLM-4 RoPE实现
```python
class GLM4RotaryEmbedding(nn.Module):
def __init__(self, head_dim, rotary_dim, ...):
# GLM-4只旋转一半维度
self.rotary_dim = rotary_dim # = head_dim // 2
def forward(self, positions, query, key):
# 分离旋转部分和pass-through部分
q_rot = query[..., :self.rotary_dim]
q_pass = query[..., self.rotary_dim:]
# 只对旋转部分应用interleaved RoPE
q_rot = apply_rotary_emb_interleaved(q_rot, cos, sin)
# 拼接回去
return torch.cat([q_rot, q_pass], dim=-1), ...
```
### 2.4 调试RoPE问题
**症状**:模型输出乱码或重复无意义的内容(如 "The. The. The..."
**调试方法**
```python
# 对比HuggingFace参考实现的输出
hf_q, hf_k = hf_model.apply_rotary_pos_emb(query, key, cos, sin)
my_q, my_k = my_rotary_emb(positions, query, key)
print(f"Q max diff: {(hf_q - my_q).abs().max()}") # 应该 < 1e-5
print(f"K max diff: {(hf_k - my_k).abs().max()}") # 应该 < 1e-5
```
---
## 3. 权重名称映射
不同模型的权重命名规范不同:
### 3.1 常见映射
| 组件 | Llama/Qwen | GLM-4 |
|------|------------|-------|
| Attention QKV | `q_proj`, `k_proj`, `v_proj` | `query_key_value` (合并) |
| Attention Output | `o_proj` | `dense` |
| MLP Gate | `gate_proj` | `dense_h_to_4h` (部分) |
| MLP Up | `up_proj` | `dense_h_to_4h` (部分) |
| MLP Down | `down_proj` | `dense_4h_to_h` |
| LayerNorm | `input_layernorm` | `input_layernorm` |
| Post-Attention LN | `post_attention_layernorm` | `post_attention_layernorm` |
### 3.2 实现权重转换
```python
def convert_glm4_weights(name, param):
"""将GLM-4权重名称转换为nanovllm格式"""
# 处理合并的QKV权重
if "query_key_value" in name:
# 拆分为q, k, v
q, k, v = param.split([q_size, kv_size, kv_size], dim=0)
return {"q_proj": q, "k_proj": k, "v_proj": v}
# 处理合并的gate+up权重
if "dense_h_to_4h" in name:
gate, up = param.chunk(2, dim=0)
return {"gate_proj": gate, "up_proj": up}
return {name: param}
```
---
## 4. EOS Token处理
### 4.1 问题
某些模型使用**多个EOS tokens**
| 模型 | EOS Token(s) | 说明 |
|------|--------------|------|
| Llama | `128001` | 单一EOS |
| Qwen | `151643` | 单一EOS |
| GLM-4 | `[151329, 151336, 151338]` | 多个endoftext, user, observation |
**问题**`tokenizer.eos_token_id` 只返回第一个导致模型不会在其他EOS token处停止。
### 4.2 解决方案
```python
# config.py - 支持多个EOS
eos: int | list[int] = -1
# llm_engine.py - 从hf_config读取完整EOS列表
eos_from_config = getattr(config.hf_config, 'eos_token_id', None)
if eos_from_config is not None:
config.eos = eos_from_config
else:
config.eos = self.tokenizer.eos_token_id
# scheduler.py - 使用set进行高效查找
self.eos_set = set(eos) if isinstance(eos, list) else {eos}
# 检查时使用 in 而不是 ==
if token_id in self.eos_set:
# 停止生成
```
### 4.3 调试EOS问题
**症状**模型总是生成到max_tokens才停止
**调试方法**
```python
# 检查EOS配置
print(f"tokenizer.eos_token_id: {tokenizer.eos_token_id}")
print(f"hf_config.eos_token_id: {config.hf_config.eos_token_id}")
# 检查输出中的EOS tokens
output = llm.generate([prompt], params)
for eos_id in [151329, 151336, 151338]:
if eos_id in output[0]['token_ids']:
print(f"Found EOS {eos_id} at position {output[0]['token_ids'].index(eos_id)}")
```
---
## 5. Chat Template
不同模型使用不同的对话模板:
| 模型 | 模板格式 |
|------|----------|
| Llama-3 | `<\|begin_of_text\|><\|start_header_id\|>user<\|end_header_id\|>\n{content}<\|eot_id\|><\|start_header_id\|>assistant<\|end_header_id\|>\n` |
| Qwen | `<\|im_start\|>user\n{content}<\|im_end\|>\n<\|im_start\|>assistant\n` |
| GLM-4 | `[gMASK]<sop><\|user\|>\n{content}<\|assistant\|>\n` |
### 实现模板转换
```python
def convert_to_model_prompt(prompt: str, model_type: str) -> str:
"""将标准prompt转换为模型特定格式"""
if model_type == "glm4":
return f"[gMASK]<sop><|user|>\n{prompt}<|assistant|>\n"
elif model_type == "llama3":
return f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
# ...
```
---
## 6. 验证清单
整合新模型后,按以下顺序验证:
### 6.1 权重加载验证
```python
# 检查所有权重是否正确加载
for name, param in model.named_parameters():
if param.abs().sum() == 0:
print(f"WARNING: {name} is all zeros!")
```
### 6.2 单层输出验证
```python
# 对比embedding层输出
my_emb = my_model.embed_tokens(input_ids)
hf_emb = hf_model.model.embed_tokens(input_ids)
print(f"Embedding diff: {(my_emb - hf_emb).abs().max()}") # < 1e-5
# 对比第一层输出
my_out = my_model.layers[0](my_emb, ...)
hf_out = hf_model.model.layers[0](hf_emb, ...)
print(f"Layer 0 diff: {(my_out - hf_out).abs().max()}") # < 1e-4
```
### 6.3 生成质量验证
```python
# 简单问答测试
prompt = "Hello, how are you?"
output = llm.generate([prompt], SamplingParams(max_tokens=50))
print(output[0]['text']) # 应该是连贯的回答
# 检查是否正确停止
print(f"Generated {len(output[0]['token_ids'])} tokens (max=50)")
```
### 6.4 RULER基准测试
```bash
# 运行1个sample快速验证
python tests/test_ruler.py --model <path> --num-samples 1
# 验证通过后运行完整测试
python tests/test_ruler.py --model <path> --num-samples 100
```
---
## 7. 常见问题速查
| 症状 | 可能原因 | 解决方案 |
|------|----------|----------|
| 输出乱码/重复 | RoPE实现错误 | 检查旋转方式(interleaved vs half)和旋转维度(full vs partial) |
| 数值爆炸(NaN/Inf) | 权重加载错误或dtype不匹配 | 检查权重映射确保dtype一致 |
| 不停止生成 | EOS token处理错误 | 从hf_config读取完整EOS列表 |
| 输出质量差 | LayerNorm或bias缺失 | 检查add_qkv_bias等配置 |
| 位置编码错误 | max_position_embeddings读取错误 | 检查配置字段名称(seq_length等) |
---
## 8. 文件结构
新模型整合需要修改/创建的文件:
```
nanovllm/
├── models/
│ └── <model>.py # 新建:模型定义
├── layers/
│ └── rotary_embedding.py # 修改如需特殊RoPE
├── utils/
│ └── loader.py # 修改:权重加载
├── config.py # 可能修改:新配置字段
└── engine/
├── llm_engine.py # 可能修改EOS处理
└── scheduler.py # 可能修改EOS检查
tests/
└── test_ruler.py # 修改chat template
```
---
## 附录GLM-4整合案例
### 遇到的问题及解决
1. **配置字段差异** → 添加getattr fallback链
2. **Interleaved RoPE** → 实现`apply_rotary_emb_interleaved`
3. **Partial rotation (head_dim//2)** → 实现`GLM4RotaryEmbedding`
4. **多EOS tokens** → 修改config/llm_engine/scheduler支持list
5. **合并的QKV权重** → 在loader中拆分
### 关键代码位置
- RoPE实现: `nanovllm/layers/rotary_embedding.py:GLM4RotaryEmbedding`
- 模型定义: `nanovllm/models/glm4.py`
- 权重加载: `nanovllm/utils/loader.py:load_glm4_weights`
- EOS处理: `nanovllm/engine/scheduler.py:eos_set`

View File

@@ -0,0 +1,210 @@
# Nsys "Wrong Event Order" Bug 调试记录
## 问题描述
使用 `nsys profile` 对 nanovllm 的 CPU offload 模式进行性能分析时,无法生成 `.nsys-rep` 文件,报错:
```
Importer error status: Importation failed.
Wrong event order has been detected when adding events to the collection:
new event ={ StartNs=21569539222 StopNs=21569672388 ... Type=48 }
last event ={ StartNs=22046804077 StopNs=22046805343 ... Type=48 }
```
## 环境信息
- **nsys 版本**: 2023.4.4.54-234433681190v0
- **CUDA**: 12.4
- **问题状态**: nsys 已知 bug2024.2+ 版本已修复
## 调试过程
### 阶段 1确定触发条件
使用 bisect 脚本 (`tests/test_nsys_bisect.py`) 逐步测试:
| Stage | 描述 | 结果 |
|-------|------|------|
| 1 | CUDA init | ✅ |
| 2 | Import nanovllm | ✅ |
| 3 | Create LLM (offload) | ✅ |
| 4 | 短 prompt 生成 | ✅ |
| **5** | **长 prompt (~64K) prefill** | ❌ |
**结论**:问题出在长 prompt 的 chunked prefill 流程。
### 阶段 2定位具体组件
`_chunked_prefill_attention` 方法中逐步注释代码:
| 组件 | 文件位置 | 结果 |
|------|----------|------|
| 整个方法 (return zeros) | `attention.py:167` | ✅ |
| `select_blocks()` | `attention.py:217` | ✅ |
| `offload_prefill_buffer_async()` | `attention.py:241-248` | ✅ |
| `compute_chunked_prefill()` | `attention.py:225-235` | ❌ |
**结论**:问题出在 `compute_chunked_prefill` 内部。
### 阶段 3定位 Ring Buffer Pipeline
`full_policy.py` 中进一步定位:
| 组件 | 代码行 | 结果 |
|------|--------|------|
| Current chunk attention | 191-198 | ✅ |
| **Historical block loading (ring buffer)** | 133-189 | ❌ |
**根因确认**Ring buffer pipeline 的多 stream 操作触发了 nsys bug。
## 根本原因
### 触发 Bug 的代码
```python
# nanovllm/kvcache/sparse/full_policy.py:133-189
# 多 slot pipeline 模式
for block_idx in range(num_blocks):
current_slot = load_slots[block_idx % num_slots]
# 等待 slot 的 transfer stream 完成
offload_engine.wait_slot_layer(current_slot)
# 在 compute_stream 上执行 attention
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
prev_o, prev_lse = flash_attn_with_lse(...)
offload_engine.record_slot_compute_done(current_slot)
# 异步发起下一个 block 的加载
if next_block_idx < num_blocks:
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
```
### Stream 结构
```
slot_transfer_streams[0] ─┐
slot_transfer_streams[1] ─┼─ 4 个 transfer streams
slot_transfer_streams[2] ─┤
slot_transfer_streams[3] ─┘
▼ wait/record 同步
compute_stream ───────────┘
```
这种 4+1 stream 的复杂同步模式导致 nsys 2023.4.4 版本的事件时间戳排序算法出错。
### 为什么简单多 stream 测试无法复现
我们尝试用简单的测试代码 (`tests/test_multistream_nsys.py`) 复现问题:
- 4-8 streams, 2000+ iterations: ✅ 成功
- 32 threads + multi-stream: ✅ 成功
- >64k CUDA operations: ✅ 成功
但都无法触发 bug。原因是实际代码中的 stream 同步模式更复杂:
1. 跨 stream 的 event wait/record
2. 与 FlashAttention kernel 的交互
3. 长时间运行(~50 秒)累积大量事件
## 解决方案
### 方案 1升级 nsys推荐
```bash
# 下载 nsys 2024.2+ 版本
# https://developer.nvidia.com/nsight-systems
```
根据 [NVIDIA 论坛](https://forums.developer.nvidia.com/t/nsys-profiler-wrong-event-order/264881),此 bug 在 2024.2 版本已修复。
### 方案 2使用 .qdstrm 文件
即使导入失败,`.qdstrm` 文件仍然生成:
```bash
# 生成的文件
results/nsys/ruler_niah_single_1_sample0_offload_*.qdstrm
# 尝试用 GUI 直接打开
nsight-sys <file>.qdstrm
```
GUI 可能有更好的容错能力。
### 方案 3使用 PyTorch Profiler
```python
from torch.profiler import profile, ProfilerActivity
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
# your code
prof.export_chrome_trace("trace.json") # chrome://tracing 查看
```
### 方案 4临时禁用 ring buffer pipeline
`full_policy.py` 中临时使用单 slot 同步模式(仅用于调试):
```python
# 强制使用单 slot 模式
if len(load_slots) == 1 or True: # 添加 "or True"
# 同步模式,不会触发 nsys bug
...
```
## 复现步骤
### 环境准备
```bash
cd /home/zijie/Code/nano-vllm
```
### 运行 Bisect 脚本
```bash
# Stage 5 会触发 bug
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$PWD:$PYTHONPATH \
nsys profile --trace=cuda,nvtx,osrt --force-overwrite=true \
-o /tmp/bisect python tests/test_nsys_bisect.py --stage 5
```
### 验证修复
```bash
# 临时在 full_policy.py 中跳过 historical block loading
# 将第 133 行改为: if False and cpu_block_table:
# 重新运行,应该成功
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$PWD:$PYTHONPATH \
nsys profile --trace=cuda,nvtx,osrt --force-overwrite=true \
-o /tmp/bisect_fixed python tests/test_nsys_bisect.py --stage 5
# 检查是否生成 .nsys-rep
ls -la /tmp/bisect_fixed.nsys-rep
```
## 相关文件
| 文件 | 用途 |
|------|------|
| `tests/test_nsys_bisect.py` | Bisect 调试脚本 |
| `tests/test_multistream_nsys.py` | 简单多 stream 测试 |
| `scripts/profile_offload.sh` | nsys profile 脚本 |
| `nanovllm/layers/attention.py` | Attention 层 |
| `nanovllm/kvcache/sparse/full_policy.py` | Ring buffer pipeline |
## 参考资料
- [Nsys Profiler- Wrong event order - NVIDIA Forums](https://forums.developer.nvidia.com/t/nsys-profiler-wrong-event-order/264881)
- [Nsight Systems 2025.3 Release Notes](https://docs.nvidia.com/nsight-systems/2025.3/ReleaseNotes/index.html)
- [Nsight Systems User Guide](https://docs.nvidia.com/nsight-systems/UserGuide/index.html)
## 调试日期
2026-01-24

View File

@@ -0,0 +1,194 @@
# Observer Architecture
nanovllm 的 Observer 架构用于统计推理过程中的关键指标采用类变量class variable模式实现全局状态管理。
## 架构概览
```
Observer (基类)
├── InferenceObserver - 推理时间指标 (TTFT, TPOT)
└── MemoryObserver - 内存传输统计 (H2D, D2H, D2D)
```
## 设计原则
### 1. 类变量模式
所有 Observer 使用类变量(而非实例变量)存储状态:
```python
class Observer:
"""Observer 基类"""
_enabled: bool = True # 类变量,控制是否启用
class InferenceObserver(Observer):
ttft: int = 0 # 类变量,全局共享
tpot: int = 0
ttft_start: int = 0
tpot_start: int = 0
```
**优点**
- 无需实例化,任何地方都可以直接访问
- 避免跨模块传递 observer 实例
- 适合全局统计场景
### 2. 启用/禁用控制
每个 Observer 可独立启用/禁用:
```python
# 启用 MemoryObserver
MemoryObserver._enabled = True
# 禁用后record_* 方法不会记录
MemoryObserver._enabled = False
```
### 3. 阶段分离
MemoryObserver 支持 prefill/decode 阶段分离统计:
```python
@classmethod
def record_h2d(cls, num_bytes: int, is_prefill: bool = True) -> None:
if not cls._enabled:
return
cls.h2d_bytes += num_bytes
cls.h2d_count += 1
if is_prefill:
cls.prefill_h2d_bytes += num_bytes
else:
cls.decode_h2d_bytes += num_bytes
```
## Observer 实现
### InferenceObserver
**位置**: `nanovllm/utils/observer.py`
**统计指标**
| 指标 | 说明 | 单位 |
|------|------|------|
| `ttft` | Time To First Token | 纳秒 |
| `tpot` | Time Per Output Token | 纳秒 |
| `ttft_start` | TTFT 计时开始点 | 纳秒 |
| `tpot_start` | TPOT 计时开始点 | 纳秒 |
**统计位置**
| 位置 | 代码 | 说明 |
|------|------|------|
| `scheduler.py:add()` | `InferenceObserver.ttft_start = perf_counter_ns()` | 开始计时 |
| `llm_engine.py:step()` | `InferenceObserver.ttft = ... - ttft_start` | Prefill 完成后计算 TTFT |
| `llm_engine.py:step()` | `InferenceObserver.tpot = ... - tpot_start` | Decode 时计算 TPOT |
### MemoryObserver
**位置**: `nanovllm/utils/memory_observer.py`
**统计指标**
| 指标 | 说明 |
|------|------|
| `h2d_bytes` / `h2d_count` | Host to Device 传输量/次数 |
| `d2h_bytes` / `d2h_count` | Device to Host 传输量/次数 |
| `d2d_bytes` / `d2d_count` | Device to Device 复制量/次数 |
| `prefill_h2d_bytes` / `prefill_d2h_bytes` | Prefill 阶段 H2D/D2H |
| `decode_h2d_bytes` / `decode_d2h_bytes` | Decode 阶段 H2D/D2H |
**统计位置** (均在 `offload_engine.py`)
| 方法 | 传输类型 | 说明 |
|------|----------|------|
| `load_to_slot_layer()` | H2D | 从 CPU 加载 block 到 GPU slot |
| `load_block_sample_from_cpu()` | H2D | 采样加载Quest |
| `load_block_full_from_cpu()` | H2D | 完整加载 block |
| `offload_slot_layer_to_cpu()` | D2H | GPU slot 卸载到 CPU |
| `offload_prefill_buffer_async()` | D2H | Prefill buffer 异步卸载 |
| `write_to_prefill_buffer()` | D2D | 写入 prefill buffer |
| `write_to_decode_buffer()` | D2D | 写入 decode buffer |
**重置位置**
| 位置 | 代码 |
|------|------|
| `llm_engine.py:generate()` | `MemoryObserver.complete_reset()` |
| `llm_engine.py:generate()` | `InferenceObserver.complete_reset()` |
## 使用示例
### 1. 启用并统计
```python
from nanovllm.utils.memory_observer import MemoryObserver
# 启用统计
MemoryObserver._enabled = True
# 运行推理
outputs = llm.generate(prompts, sampling_params)
# 获取结果
print(f"Prefill H2D: {MemoryObserver.prefill_h2d_bytes / 1e9:.2f} GB")
print(f"Decode H2D: {MemoryObserver.decode_h2d_bytes / 1e9:.2f} GB")
# 或使用 print_summary
MemoryObserver.print_summary()
```
### 2. 在 bench_offload.py 中
```python
from nanovllm.utils.memory_observer import MemoryObserver
# 启用
MemoryObserver._enabled = True
# benchmark 结束后打印
def print_memory_stats():
fmt = MemoryObserver._fmt_bytes
print(f"[Memory] Prefill H2D: {fmt(MemoryObserver.prefill_h2d_bytes)}")
print(f" Decode H2D: {fmt(MemoryObserver.decode_h2d_bytes)}")
```
### 3. 获取结构化数据
```python
summary = MemoryObserver.get_summary()
# {
# "total": {"h2d_bytes": ..., "d2h_bytes": ..., "d2d_bytes": ...},
# "prefill": {"h2d_bytes": ..., "d2h_bytes": ...},
# "decode": {"h2d_bytes": ..., "d2h_bytes": ...}
# }
```
## 添加新 Observer
1. 继承 `Observer` 基类
2. 定义类变量存储统计数据
3. 实现 `record_*` 方法(需检查 `_enabled`
4. 实现 `complete_reset()` 方法
5. 在相关代码位置添加 `record_*` 调用
6.`llm_engine.py:generate()` 中添加 reset 调用
```python
from nanovllm.utils.observer import Observer
class MyObserver(Observer):
_enabled: bool = False
my_metric: int = 0
@classmethod
def record_event(cls, value: int) -> None:
if not cls._enabled:
return
cls.my_metric += value
@classmethod
def complete_reset(cls) -> None:
cls.my_metric = 0
```
## 相关文档
- [`memory_communication_benchmark.md`](memory_communication_benchmark.md) - 通信量测试结果
- [`architecture_guide.md`](architecture_guide.md) - 整体架构指南

View File

@@ -0,0 +1,338 @@
# test_ruler.py 使用指南
RULER benchmark 综合测试工具,用于评估 LLM 长上下文能力。
**测试日期**: 2026-02-05
**测试 GPU**: RTX 3090 (GPU 4)
---
## 支持的任务
| 类别 | 任务 |
|------|------|
| NIAH (Needle-In-A-Haystack) | `niah_single_1/2/3`, `niah_multikey_1/2/3`, `niah_multiquery`, `niah_multivalue` |
| QA (Question Answering) | `qa_1`, `qa_2` |
| Recall | `cwe`, `fwe`, `vt` |
---
## 基本命令格式
```bash
CUDA_VISIBLE_DEVICES=<GPU_ID> PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py [OPTIONS]
```
---
## 参数说明
### 必要参数
| 参数 | 默认值 | 说明 |
|------|--------|------|
| `--model` | `~/models/Llama-3.1-8B-Instruct` | 模型路径 |
| `--data-dir` | `tests/data/ruler_64k` | 数据目录 |
| `--max-model-len` | 65664 | 最大上下文长度 |
### 数据选择
| 参数 | 默认值 | 说明 |
|------|--------|------|
| `--datasets` | 全部 | 逗号分隔的数据集名 |
| `--num-samples` | 0 (全部) | 每个数据集测试样本数 |
| `--sample-indices` | - | 指定样本索引 (如 `0,5,10`) |
### Offload 配置
| 参数 | 默认值 | 说明 |
|------|--------|------|
| `--enable-offload` | False | 启用 CPU offload 模式 |
| `--num-gpu-blocks` | 4 | GPU 上的 KV cache blocks 数量 |
| `--block-size` | 4096 | KV cache block 大小 (tokens) |
| `--num-kv-buffers` | 4 | Ring buffer 数量 |
| `--gpu-utilization` | 0.9 | GPU 显存利用率 |
### Sparse Attention 配置
| 参数 | 默认值 | 说明 |
|------|--------|------|
| `--sparse-policy` | - | 稀疏策略: `FULL`, `QUEST`, `XATTN_BSA` |
| `--sparse-threshold` | 0.9 | XAttn cumulative attention 阈值 |
| `--sparse-samples` | 128 | XAttn 每 chunk 采样数 |
| `--sparse-stride` | 8 | XAttn Q/K 下采样步长 |
### 输出控制
| 参数 | 说明 |
|------|------|
| `--quiet` / `-q` | 安静模式 |
| `--json-output` | JSON 格式输出 |
| `--fresh-llm` | 每个样本重新初始化 LLM |
### 其他
| 参数 | 默认值 | 说明 |
|------|--------|------|
| `--dtype` | auto | 模型数据类型 (`bfloat16`, `float16`) |
| `--use-cuda-graph` | False | 启用 CUDA Graph |
| `--max-new-tokens` | 16 | 最大生成 token 数 |
---
## 已验证的命令示例
以下命令均在 RTX 3090 (24GB) 上测试通过。
### 1. 基础 Offload 测试 (32K)
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload
```
**结果**: 100% 准确率, 耗时 ~16s
### 2. Offload + XAttention BSA (32K)
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9
```
**结果**: 100% 准确率, compute density ~50%, 耗时 ~19s
### 3. Offload + XAttention BSA (64K)
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_64k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 72000 \
--enable-offload \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9
```
**结果**: 100% 准确率, compute density ~37%, 耗时 ~52s
### 4. 多数据集多样本测试
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1,qa_1 \
--num-samples 2 \
--max-model-len 40960 \
--enable-offload \
--sparse-policy XATTN_BSA
```
**结果**: 4/4 (100%), 耗时 ~71s
### 5. 指定样本索引测试
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--sample-indices 0,5,10 \
--max-model-len 40960 \
--enable-offload
```
### 6. JSON 输出模式 (用于脚本)
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload \
--json-output
```
**输出格式**:
```json
{
"total_correct": 1,
"total_samples": 1,
"overall_accuracy": 1.0,
"avg_score": 1.0,
"time": 30.44,
"tasks": {"niah_single_1": {"correct": 1, "total": 1, "accuracy": 1.0}},
"failed_samples": {}
}
```
### 7. 安静模式
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload \
--quiet
```
### 8. 调整 GPU blocks 数量
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload \
--num-gpu-blocks 8 \
--sparse-policy XATTN_BSA
```
### 9. GLM-4 模型测试
```bash
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/GLM-4-9B-Chat-1M \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload \
--dtype bfloat16
```
**结果**: 100% 准确率, 耗时 ~17s
---
## 数据目录结构
```
tests/data/
├── ruler_4k/ # 4K context
├── ruler_8k/ # 8K context
├── ruler_16k/ # 16K context
├── ruler_32k/ # 32K context (推荐测试)
├── ruler_64k/ # 64K context
├── ruler_128k/ # 128K context
├── ruler_256k/ # 256K context
├── ruler_512k/ # 512K context
├── ruler_768k/ # 768K context
└── ruler_1m/ # 1M context
```
每个目录包含 13 个任务子目录,每个任务有 `validation.jsonl` 文件。
---
## GPU 与模式选择
| GPU 显存 | 推荐模式 | 说明 |
|---------|---------|------|
| 24GB (3090/4090) | `--enable-offload` | 必须使用 offload |
| 40GB+ (A100) | 两种模式均可 | 可测试 GPU-only |
**RTX 3090 限制**: 由于显存限制,必须使用 `--enable-offload` 参数。
---
## max-model-len 设置指南
| 数据目录 | 推荐 max-model-len | 说明 |
|---------|-------------------|------|
| ruler_4k | 5000 | 留出 output 空间 |
| ruler_8k | 9000 | |
| ruler_16k | 17000 | |
| ruler_32k | 40960 | |
| ruler_64k | 72000 | |
| ruler_128k | 135000 | |
**公式**: `max_model_len >= max_input_len + max_new_tokens`
---
## DensityObserver 输出
使用 `--sparse-policy XATTN_BSA` 时自动启用,输出示例:
```
============================================================
Density Statistics (XAttention BSA)
============================================================
[DensityObserver] Mode: offload
Compute density: 0.3691 (min: 0.3691 @ layer 0)
Comm density: 1.0000 (CPU block granularity)
Savings ratio: 0.0% H2D transfer reduction
Num layers: 1
Layer 0 density: 0.369052
```
| 指标 | 说明 |
|------|------|
| Compute density | BSA block (128 tokens) 粒度的计算密度 |
| Comm density | CPU block (4096 tokens) 粒度的通信密度 |
| Savings ratio | H2D 传输减少比例 |
---
## 常见问题
### 1. OOM 错误
**原因**: 显存不足
**解决**:
- 使用 `--enable-offload`
- 降低 `--gpu-utilization`
- 减少 `--num-gpu-blocks`
### 2. 模型加载失败
**原因**: 模型配置不兼容
**解决**:
- 检查 `--dtype` 参数 (GLM 模型需要 `--dtype bfloat16`)
- 确认模型路径正确
### 3. 准确率异常
**原因**: 状态泄漏
**解决**: 使用 `--fresh-llm` 参数为每个样本重新初始化 LLM
---
## 相关文档
- [`docs/xattn_density_types.md`](xattn_density_types.md) - Compute vs Comm density 解释
- [`docs/xattn_density_alignment_verification.md`](xattn_density_alignment_verification.md) - GPU-only vs Offload 对齐验证
- [`docs/ruler_benchmark_results_32k.md`](ruler_benchmark_results_32k.md) - RULER 32K 基准测试结果

View File

@@ -0,0 +1,429 @@
# XAttention BSA Policy 设计文档
本文档描述 `XAttentionBSAPolicy` 的设计和实现,这是一个基于 XAttention 算法的稀疏注意力策略,用于 CPU offload 模式下的 chunked prefill。
## 概述
`XAttentionBSAPolicy` 实现了基于 XAttention 的块级稀疏注意力选择。核心思想是:
1. **估计阶段**:使用 XAttention kernels 快速估计每个 KV block 的重要性
2. **选择阶段**:基于阈值和 majority voting 选择重要的 blocks
3. **计算阶段**:只加载选中的 blocks 进行 attention 计算
```
┌─────────────────────────────────────────────────────────────┐
│ XAttention BSA Policy │
├─────────────────────────────────────────────────────────────┤
│ select_blocks() │
│ ┌─────────────┐ ┌──────────────────┐ ┌──────────────┐ │
│ │ Load K │──>│ flat_group_gemm │──>│ softmax_fuse │ │
│ │ blocks │ │ _fuse_reshape │ │ _block_sum │ │
│ └─────────────┘ └──────────────────┘ └──────────────┘ │
│ │ │ │ │
│ v v v │
│ ┌─────────────┐ ┌──────────────────┐ ┌──────────────┐ │
│ │ K: [B,H,L,D]│ │ attn_scores: │ │ block_sums: │ │
│ │ │ │ [B,H,Q/s,K/s] │ │ [B,H,Qb,Kb] │ │
│ └─────────────┘ └──────────────────┘ └──────────────┘ │
│ │ │
│ ┌──────────────────────┘ │
│ v │
│ ┌──────────────┐ │
│ │find_blocks │ │
│ │_chunked │ │
│ └──────────────┘ │
│ │ │
│ v │
│ ┌──────────────┐ │
│ │ GQA-aware │ │
│ │ aggregation │ │
│ │ + majority │ │
│ │ voting │ │
│ └──────────────┘ │
│ │ │
│ v │
│ selected_block_ids │
├─────────────────────────────────────────────────────────────┤
│ compute_chunked_prefill() │
│ ┌─────────────┐ ┌──────────────────┐ ┌──────────────┐ │
│ │ Ring buffer │──>│ flash_attn_ │──>│ merge_ │ │
│ │ pipeline │ │ with_lse │ │ attention │ │
│ └─────────────┘ └──────────────────┘ └──────────────┘ │
└─────────────────────────────────────────────────────────────┘
```
## 文件位置
**主文件**: `nanovllm/kvcache/sparse/xattn_bsa.py`
**依赖的 XAttention kernels**: `nanovllm/ops/xattn.py`
- `flat_group_gemm_fuse_reshape`: 计算 stride reshape 后的 attention scores
- `softmax_fuse_block_sum`: 对 attention scores 做 softmax 后按 block 求和
- `find_blocks_chunked`: 基于阈值选择 blocks
---
## 核心算法
### 1. select_blocks: 块选择算法
```python
def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]:
```
#### Step 1: 加载 K blocks 并计算 attention scores
对每个 CPU block加载 K 到 GPU 并使用 `flat_group_gemm_fuse_reshape` 计算:
```python
for cpu_block_id in available_blocks:
# 加载 K block: [1, block_size, num_kv_heads, head_dim]
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
k_block, _ = offload_engine.get_kv_for_slot(slot)
# 转换为 [batch, heads, k_len, head_dim]
K_chunk = k_block.transpose(1, 2)
# GQA: 扩展 K heads 匹配 Q heads
if num_heads != num_kv_heads:
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
# 计算 attention scores
attn_chunk = flat_group_gemm_fuse_reshape(Q, K_chunk, stride, ...)
attn_scores_list.append(attn_chunk)
# 拼接所有 K chunks: [1, heads, q_reshaped_len, total_k_reshaped_len]
attn_scores = torch.cat(attn_scores_list, dim=-1)
```
#### Step 2: 聚合到 block 级别
使用 `softmax_fuse_block_sum` 将 attention scores 聚合到 block 级别:
```python
# reshaped_block_size = block_size / stride = 1024 / 8 = 128
block_sums = softmax_fuse_block_sum(
attn_scores,
reshaped_block_size, # 1:1 对应 CPU blocks
segment_size,
chunk_start=0,
chunk_end=q_reshaped_len,
real_q_len=q_reshaped_len,
scale=scale,
is_causal=False,
)
# block_sums: [batch, heads, q_blocks, k_blocks]
```
**关键点**: `reshaped_block_size` 必须与 CPU block 对齐,确保输出的 `k_blocks` 维度 1:1 对应 `available_blocks`
#### Step 3: 阈值选择
使用 `find_blocks_chunked` 基于累积注意力阈值选择 blocks
```python
mask = find_blocks_chunked(
block_sums,
current_index=0,
threshold=self.threshold, # e.g., 0.95
num_to_choose=None,
decoding=False,
mode="prefill",
causal=False,
)
# mask: [batch, num_heads, q_blocks, k_blocks] - boolean
```
#### Step 4: GQA-aware 聚合 + Majority Voting
```python
# GQA: 在同一个 KV head group 内,任一 Q head 选择即选择
if num_groups > 1:
mask_gqa = mask.view(batch_size, num_kv_heads, num_groups, q_blocks, k_blocks)
mask_per_kv_head = mask_gqa.any(dim=2) # [batch, num_kv_heads, q_blocks, k_blocks]
# Majority voting: 跨 KV heads 和 q_blocks 投票
vote_count = mask_per_kv_head[0].float().sum(dim=0).sum(dim=0) # [k_blocks]
total_votes = num_kv_heads * q_blocks
vote_ratio = vote_count / total_votes
# 选择 >50% 投票的 blocks
vote_threshold = 0.5
block_selected = vote_ratio > vote_threshold
selected_block_ids = [available_blocks[i] for i, sel in enumerate(block_selected.tolist()) if sel]
# 安全措施: 始终包含第一个 (sink) 和最后一个 block
if available_blocks[0] not in selected_block_ids:
selected_block_ids.insert(0, available_blocks[0])
if available_blocks[-1] not in selected_block_ids:
selected_block_ids.append(available_blocks[-1])
```
**为什么使用 Majority Voting?**
| 聚合方式 | 问题 |
|---------|------|
| `any()` 跨所有 heads | 密度接近 100%,失去稀疏性 |
| `all()` | 太激进,可能丢失重要 blocks |
| **Majority voting (>50%)** | 平衡稀疏性和准确性 |
实验结果显示:
- 每 head 密度: 20-35%
- `any()` 聚合后: ~100%
- **Majority voting 后: ~45%**
---
### 2. compute_chunked_prefill: 注意力计算
复用 `FullAttentionPolicy` 的 ring buffer pipeline 实现:
```python
def compute_chunked_prefill(self, q, k, v, layer_id, softmax_scale,
offload_engine, kvcache_manager,
current_chunk_idx, seq, num_tokens,
selected_blocks) -> torch.Tensor:
```
#### 计算流程
1. **加载历史 blocks** (使用 selected_blocks):
```python
for block_idx in range(num_blocks):
# Ring buffer pipeline: load -> wait -> compute -> next
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
prev_o, prev_lse = flash_attn_with_lse(q, prev_k, prev_v, causal=False)
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
```
2. **计算当前 chunk** (causal mask):
```python
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
current_o, current_lse = flash_attn_with_lse(q, k_curr, v_curr, causal=True)
```
3. **合并结果**:
```python
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
```
---
## 参数配置
| 参数 | 默认值 | 说明 |
|------|--------|------|
| `threshold` | 0.95 | 累积注意力阈值 (tau),越高越保守 |
| `stride` | 8 | XAttention stride reshape 参数 |
| `chunk_size` | 16384 | 估计时的处理 chunk size |
| `block_size` | 128 | BSA block size (固定值) |
### 使用方式
```python
# 在 config 中设置
config.sparse_policy = SparsePolicyType.XATTN_BSA
config.sparse_threshold = 0.95
# 或通过命令行
python tests/test_needle.py \
--enable-offload \
--enable-xattn-bsa \
--sparse-threshold 9 # 会被除以 10 变为 0.9
```
---
## 性能特性
| 特性 | 说明 |
|------|------|
| **Prefill 支持** | ✅ 完整支持 |
| **Decode 支持** | ❌ 不支持(使用 FullAttentionPolicy |
| **稀疏度** | ~45-55%threshold=0.95majority voting |
| **准确性** | RULER NIAH 100% 通过 |
### 限制
1. **Decode 不支持**: XAttention 估计需要足够长的 Q 序列,单 token decode 不适用
2. **估计开销**: `select_blocks` 需要加载所有 K blocks 进行估计
3. **Triton 对齐**: Q/K 长度必须满足 `stride * BLOCK_M/N` 对齐要求
---
## 与其他 Policy 的对比
| Policy | select_blocks | 稀疏度 | Decode 支持 |
|--------|--------------|--------|-------------|
| FullAttentionPolicy | 返回所有 blocks | 0% | ✅ |
| QuestPolicy | 基于 min/max key | ~50% | ✅ |
| **XAttentionBSAPolicy** | XAttention + majority voting | ~45-55% | ❌ |
---
## 测试验证
```bash
# Needle test (32K)
CUDA_VISIBLE_DEVICES=0 python tests/test_needle.py \
--model ~/models/Llama-3.1-8B-Instruct \
--enable-offload \
--enable-xattn-bsa \
--input-len 32768
# RULER benchmark
CUDA_VISIBLE_DEVICES=0 python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--enable-offload \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.95 \
--data-dir tests/data/ruler_niah
```
---
## 性能基准测试
### 128K 上下文对比 (Llama-3.1-8B, A100 80GB)
| Policy | Density | 时间 | 内存峰值 | 准确率 |
|--------|---------|------|---------|--------|
| **Full** | 100% | 120.9s | 16.4GB (稳定) | 100% |
| **XAttn BSA** | ~52% | 152.3s | 19.8GB | 100% |
### Density 变化趋势
| Chunk | Full | XAttn BSA |
|-------|------|-----------|
| 10 | 100% | 90% |
| 30 | 100% | 73% |
| 60 | 100% | 50% |
| 100 | 100% | 50% |
| 126 | 100% | 52% |
**观察**XAttn BSA 的 density 随 chunks 增加而下降,最终稳定在 ~50%。
### 性能分析
**当前问题**XAttn BSA 虽然 density 只有 ~52%,但时间反而比 Full 更长152s vs 121s
**原因**`select_blocks` 需要加载所有 K blocks 来估计 attention scores导致每个 block 被加载两次:
1. 估计阶段:加载 K 计算 attention scores
2. 计算阶段:加载选中的 K/V 进行实际计算
**优化方向**
1. 跨层共享估计结果layer 0 估计,其他层复用)
2. 采样估计(只用部分 K blocks 估计)
3. 缓存估计结果避免重复计算
---
## 内存管理
### 内存泄漏问题 (已修复)
**问题**128K prefill 时 GPU 内存从 16GB 增长到 80GB。
**根因**
```python
# 问题代码:累积存储但从未使用
self.sparse_metadata[layer_id] = attn_scores
```
每个 chunk 的每个 layer 都存储 `attn_scores`,导致内存持续增长。
**修复方法**
```python
# 1. 删除无用的 sparse_metadata 存储
# 2. 立即释放中间变量
del attn_scores_list
del attn_scores, block_sums, mask, mask_per_kv_head, vote_count, vote_ratio, block_selected
```
**修复效果**
| 版本 | 内存增长 | 峰值 |
|------|---------|------|
| 修复前 | +64GB | 80GB |
| **修复后** | +4GB | 19.8GB |
### 内存监控
使用 `gpu-monitor` agent 监控内存:
```bash
# 启动监控
# 在 Claude Code 中使用 Task tool 启动 gpu-monitor agent
# 或手动监控
watch -n 1 'nvidia-smi --query-gpu=memory.used --format=csv,noheader -i 0'
```
---
## Density 统计 API
### 启用统计
```python
# 统计自动在 select_blocks 中更新(仅 layer 0
# 使用 logger.debug 输出每 chunk 的 density
```
### 获取统计
```python
policy = XAttentionBSAPolicy(threshold=0.95)
# 运行 prefill 后...
# 获取统计
stats = policy.get_density_stats()
# {
# "total_available_blocks": 8001,
# "total_selected_blocks": 4160,
# "num_chunks": 126,
# "overall_density": 0.52
# }
# 打印统计
policy.print_density_stats()
# 重置统计
policy.reset_stats()
```
### 启用 DEBUG 日志
```python
# 在 test_ruler.py 中
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
# 输出示例:
# [XAttn] chunk=30, available=30, selected=22, chunk_density=73.3%
```
---
## 已知问题
| 问题 | 状态 | 说明 |
|------|------|------|
| 估计开销过大 | 🟡 待优化 | select_blocks 需要加载所有 K blocks |
| 时间比 Full 更长 | 🟡 待优化 | 128K 场景 152s vs 121s |
| 小幅内存增长 | 🟢 可接受 | ~4GB可能来自 Triton 缓存 |
| Decode 不支持 | ✅ 设计如此 | 使用 FullAttentionPolicy |
---
## 相关文档
- [`docs/xattention_algorithm_guide.md`](xattention_algorithm_guide.md): XAttention 算法详解
- [`docs/xattn_kernels_guide.md`](xattn_kernels_guide.md): Triton kernels 实现
- [`docs/sparse_policy_architecture.md`](sparse_policy_architecture.md): SparsePolicy 架构
- [`docs/sparse_policy_implementation_guide.md`](sparse_policy_implementation_guide.md): 实现指南

View File

@@ -0,0 +1,142 @@
# XAttention Density Alignment Verification
验证 GPU-only 和 Offload 模式的 density 对齐情况。
**测试日期**: 2026-02-05
**测试模型**: Llama-3.1-8B-Instruct
**测试任务**: RULER niah_single_1
---
## 测试配置
| 参数 | 值 |
|------|-----|
| sparse_policy | XATTN_BSA |
| threshold | 0.9 |
| chunk_size | 4096 (已对齐) |
| stride | 8 |
| BSA block_size | 128 |
---
## 测试结果
### 32K Context
| 模式 | Layer 0 Density | Overall Density | 准确率 |
|------|-----------------|-----------------|--------|
| GPU-only | 0.502079 | 0.4012 | 100% |
| Offload | 0.498421 | 0.4984 | 100% |
| **差异** | **0.37%** | - | - |
### 64K Context
| 模式 | Layer 0 Density | Overall Density | 准确率 |
|------|-----------------|-----------------|--------|
| GPU-only | 0.369972 | 0.2963 | 100% |
| Offload | 0.369052 | 0.3691 | 100% |
| **差异** | **0.09%** | - | - |
---
## 关键修复
### Commit 829b311 - chunk_size 对齐 + Stream 同步修复
**问题**: 之前 GPU-only 和 Offload 模式的 density 差异达 10-13%
**根因**:
1. GPU-only 使用 `chunk_size=16384`Offload 使用 `chunk_size=4096`
2. Stream 同步 bug 导致 Pass 1/2 K 数据不一致
**修复**:
1.`XAttentionBSAPolicy.chunk_size` 默认值从 16384 改为 4096
2. 所有 compute kernels 包装在 `compute_stream` context 中
---
## 测试命令
### GPU-only 模式
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9
```
### Offload 模式
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 40960 \
--enable-offload \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9
```
---
## 详细日志
### 32K Offload 模式 Per-Chunk Density
```
Layer0 chunk: q_len=4096, k_len=4096, density=0.6234
Layer0 chunk: q_len=4096, k_len=8192, density=0.6239
Layer0 chunk: q_len=4096, k_len=12288, density=0.6026
Layer0 chunk: q_len=4096, k_len=16384, density=0.5695
Layer0 chunk: q_len=4096, k_len=20480, density=0.5285
Layer0 chunk: q_len=4096, k_len=24576, density=0.4891
Layer0 chunk: q_len=4096, k_len=28672, density=0.4514
Layer0 chunk: q_len=3813, k_len=32485, density=0.4208
```
### 64K Offload 模式 Per-Chunk Density
```
Layer0 chunk: q_len=4096, k_len=4096, density=0.6234
Layer0 chunk: q_len=4096, k_len=8192, density=0.6239
Layer0 chunk: q_len=4096, k_len=12288, density=0.6026
Layer0 chunk: q_len=4096, k_len=16384, density=0.5681
Layer0 chunk: q_len=4096, k_len=20480, density=0.5255
Layer0 chunk: q_len=4096, k_len=24576, density=0.4859
Layer0 chunk: q_len=4096, k_len=28672, density=0.4485
Layer0 chunk: q_len=4096, k_len=32768, density=0.4161
Layer0 chunk: q_len=4096, k_len=36864, density=0.3892
Layer0 chunk: q_len=4096, k_len=40960, density=0.3658
Layer0 chunk: q_len=4096, k_len=45056, density=0.3464
Layer0 chunk: q_len=4096, k_len=49152, density=0.3303
Layer0 chunk: q_len=4096, k_len=53248, density=0.3170
Layer0 chunk: q_len=4096, k_len=57344, density=0.3068
Layer0 chunk: q_len=4096, k_len=61440, density=0.2988
Layer0 chunk: q_len=3451, k_len=64891, density=0.2947
```
---
## 结论
1. **Density 对齐成功**: 差异从 10-13% 降到 <0.5%
2. **准确率一致**: 两种模式都达到 100% 准确率
3. **Density 随 context 增长下降**: 符合预期,更长的 context 稀疏性更高
---
## 相关文档
- [`docs/xattn_offload_stream_sync_fix.md`](xattn_offload_stream_sync_fix.md) - Stream 同步修复详情
- [`docs/xattn_density_types.md`](xattn_density_types.md) - Compute vs Comm density
- [`docs/gpuonly_density_alignment_test.md`](gpuonly_density_alignment_test.md) - 早期对齐测试

View File

@@ -0,0 +1,195 @@
# XAttention Density Benchmark
GPU-only 模式下 XAttention Block Sparse Attention 的 density 测试结果。
## 测试配置
| 参数 | 值 | 说明 |
|------|-----|------|
| Model | Llama-3.1-8B-Instruct | 32 layers, 32 heads, 8 KV heads |
| Block Size | 128 tokens | BSA kernel 固定要求 |
| Threshold | 0.9 / 0.95 | 累积注意力阈值 |
| Stride | 4 / 8 / 16 | Q/K 下采样步长 |
| Dataset | RULER niah_single_1 | Sample 0 |
| Mode | GPU-only | 无 CPU offload |
## Density 定义
```python
# Density = selected_blocks / total_causal_blocks
# 在 causal attention 下,只计算下三角区域的 blocks
# Overall density = 所有层的平均值
def compute_density(mask, causal=True):
"""
mask: [batch, heads, q_blocks, k_blocks] boolean tensor
"""
if causal:
causal_mask = torch.tril(torch.ones(q_blocks, k_blocks))
total = causal_mask.sum() * batch * heads
selected = (mask & causal_mask).sum()
return selected / total
```
## 测试结果
### threshold=0.9
#### Overall Density (平均)
| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.5220 (52.2%) | 0.5292 (52.9%) | 0.5430 (54.3%) |
| **8K** | 0.5152 (51.5%) | 0.5252 (52.5%) | 0.5396 (54.0%) |
| **16K** | 0.4682 (46.8%) | 0.4775 (47.8%) | 0.4888 (48.9%) |
| **32K** | 0.3700 (37.0%) | 0.4012 (40.1%) | 0.4196 (42.0%) |
#### Min Density (per layer)
| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.2805 (Layer 3) | 0.3132 (Layer 3) | 0.3376 (Layer 5) |
| **8K** | 0.2886 (Layer 5) | 0.2725 (Layer 5) | 0.2995 (Layer 5) |
| **16K** | 0.2247 (Layer 5) | 0.2349 (Layer 5) | 0.2451 (Layer 5) |
| **32K** | 0.1799 (Layer 5) | 0.1846 (Layer 5) | 0.1964 (Layer 5) |
### threshold=0.95
#### Overall Density (平均)
| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.6561 (65.6%) | 0.6699 (67.0%) | 0.6815 (68.2%) |
| **8K** | 0.6462 (64.6%) | 0.6584 (65.8%) | 0.6732 (67.3%) |
| **16K** | 0.6004 (60.0%) | 0.6114 (61.1%) | 0.6193 (61.9%) |
| **32K** | 0.4894 (48.9%) | 0.5203 (52.0%) | 0.5385 (53.9%) |
#### Min Density (per layer)
| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.3972 (Layer 3) | 0.4348 (Layer 5) | 0.4517 (Layer 4) |
| **8K** | 0.4004 (Layer 5) | 0.3906 (Layer 5) | 0.4239 (Layer 5) |
| **16K** | 0.3331 (Layer 5) | 0.3453 (Layer 5) | 0.3589 (Layer 5) |
| **32K** | 0.2656 (Layer 5) | 0.2784 (Layer 5) | 0.2917 (Layer 5) |
### threshold 对比 (stride=8)
| Context | threshold=0.9 | threshold=0.95 | 差异 |
|---------|---------------|----------------|------|
| **4K** | 0.5292 (52.9%) | 0.6699 (67.0%) | -14.1% |
| **8K** | 0.5252 (52.5%) | 0.6584 (65.8%) | -13.3% |
| **16K** | 0.4775 (47.8%) | 0.6114 (61.1%) | -13.4% |
| **32K** | 0.4012 (40.1%) | 0.5203 (52.0%) | -11.9% |
## 关键发现
### 1. Context Length 影响最大
Density 随 context length 显著下降threshold=0.9, stride=8
- 4K: 52.9% density
- 8K: 52.5% density
- 16K: 47.8% density
- 32K: 40.1% density
**结论**: 长序列有更多稀疏化机会XAttention 的优势在长序列上更明显。
### 2. Threshold 影响显著
threshold=0.9 比 0.95 的 density 低约 12-14%
- 0.9 更激进,选择更少的 blocks
- 0.95 更保守,保留更多 blocks
- 两者准确性都不受影响RULER NIAH 全部 PASS
### 3. Stride 影响较小
同一 context 下,不同 stride 的 density 差异约 2-5%
- stride 越大 → density 略高(采样越粗糙,选择更保守)
- stride=4 最激进stride=16 最保守
### 4. Min Density 集中在中间层
- 大多数情况下 min density 出现在 Layer 5
- 中间层的稀疏性最高,首尾层相对密集
- 这符合 Transformer 注意力模式的一般规律
### 5. 最佳稀疏化配置
32K + stride=4 + threshold=0.9 达到最低 density
- Overall: **37.0%** (节省 63% 计算)
- Min: **18.0%** (Layer 5)
### 6. 准确性稳定
所有配置下 RULER NIAH 测试都 PASS (score=1.0),说明:
- threshold=0.9 和 0.95 都足够保守,不损失准确性
- 不同 stride 不影响最终结果
## 推荐配置
| 场景 | threshold | stride | 说明 |
|------|-----------|--------|------|
| 精度优先 | 0.95 | 8 | 保守配置density ~52-67% |
| 平衡 | 0.9 | 8 | 默认配置density ~40-53% |
| 性能优先 | 0.9 | 4 | 激进配置density ~37-52% |
## 测试命令
```bash
# 基本测试
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--sample-indices 0 \
--max-model-len 33792 \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9 \
--sparse-stride 8 \
--gpu-utilization 0.85
# 参数说明
# --sparse-policy XATTN_BSA 启用 XAttention Block Sparse Attention
# --sparse-threshold 0.9 累积注意力阈值 (0.9-0.99)
# --sparse-stride 8 Q/K 下采样步长 (4/8/16)
```
## DensityObserver 使用
```python
from nanovllm.utils.density_observer import DensityObserver
# 启用并重置
DensityObserver.enable()
DensityObserver.complete_reset()
# ... 运行 inference (compute_prefill 自动记录) ...
# 获取结果
summary = DensityObserver.get_summary()
# {
# "mode": "gpu_only",
# "overall_density": 0.40, # 所有层的平均值
# "per_layer_density": {0: 0.55, 1: 0.45, ...},
# "num_layers": 32
# }
# 获取最低 density
min_layer, min_density = DensityObserver.get_min_density()
# 打印摘要
DensityObserver.print_summary()
# [DensityObserver] Mode: gpu_only
# Overall density: 0.4012
# Min density: 0.1846 (layer 5)
# Num layers: 32
```
## 相关文件
| 文件 | 说明 |
|------|------|
| `nanovllm/kvcache/sparse/xattn_bsa.py` | XAttention BSA Policy 实现 |
| `nanovllm/utils/density_observer.py` | Density 统计 Observer |
| `nanovllm/ops/xattn.py` | xattn_estimate 核心算法 |
| `tests/test_ruler.py` | RULER benchmark 测试脚本 |

152
docs/xattn_density_types.md Normal file
View File

@@ -0,0 +1,152 @@
# XAttention Density Types: Compute vs Communication
XAttention BSA 统计两种不同粒度的 density它们反映不同的优化效果。
## 两种 Density 的定义
### 1. Compute Density计算密度
**粒度**: BSA block (128 tokens)
**公式**:
```
compute_density = selected_bsa_blocks / total_causal_bsa_blocks
```
**含义**: 实际需要计算 attention 的 blocks 占 causal 区域的比例。
**影响**: 决定 attention 计算量的减少。
### 2. Communication Density通信密度
**粒度**: CPU block (4096 tokens = 32 BSA blocks)
**公式**:
```
comm_density = selected_cpu_blocks / total_cpu_blocks
```
**含义**: 需要从 CPU 传输到 GPU 的 blocks 占总 blocks 的比例。
**影响**: 决定 H2D 传输量的减少。
## 为什么 Comm Density 通常高于 Compute Density
### 聚合效应
由于 CPU block 粒度是 BSA block 的 32 倍CPU block 选择使用 `any()` 聚合:
```python
# BSA mask: [B, H, Q_bsa, K_bsa]
# Reshape to CPU block level
mask_per_cpu = mask.view(B, H, Q_bsa, num_cpu_blocks, bsa_per_cpu)
# Any BSA block selected -> whole CPU block needed
cpu_needed = mask_per_cpu.any(dim=-1).any(dim=2).any(dim=1)
```
只要 CPU block 中**任意一个**:
- Head 选择了该 block
- Q position 选择了该 block
- BSA sub-block 被选中
则整个 CPU block 都需要传输。
### 示例
| 场景 | Compute Density | Comm Density | 说明 |
|------|-----------------|--------------|------|
| 64K context, threshold=0.9 | 37% | 100% | 稀疏 blocks 均匀分布在所有 CPU blocks |
| 32K context, threshold=0.9 | 50% | 100% | 同上 |
## 测试结果
### 测试命令
```bash
# Offload 模式测试
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_64k \
--datasets niah_single_1 \
--num-samples 1 \
--max-model-len 72000 \
--enable-offload \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9
```
### 输出示例
```
[DensityObserver] Mode: offload
Compute density: 0.3691 (min: 0.3691 @ layer 0)
Comm density: 1.0000 (CPU block granularity)
Savings ratio: 0.0% H2D transfer reduction
Num layers: 1
Layer 0 density: 0.369052
```
## 关键发现
### 当前 XAttention 的通信优化局限
1. **Compute density 有效降低**: ~37% @ 64K context计算量减少 63%
2. **Comm density 没有降低**: 100%(通信量没有减少)
### 原因分析
Attention pattern 的特点:
- 不同 heads 关注不同位置
- 不同 Q positions 关注不同 K positions
- 稀疏选择分布在整个 sequence 上
这导致虽然每个 (head, Q, K) 组合只选择少量 blocks但聚合后覆盖了所有 CPU blocks。
### 潜在优化方向
1. **Per-head block selection**: 每个 head 独立选择 CPU blocks
2. **Block clustering**: 将相关 blocks 聚合到同一 CPU block
3. **Dynamic block size**: 根据 attention pattern 动态调整 CPU block 大小
## DensityObserver API
### 启用和重置
```python
from nanovllm.utils.density_observer import DensityObserver
DensityObserver.enable()
DensityObserver.complete_reset()
DensityObserver.set_mode("offload") # or "gpu_only"
```
### 记录
```python
# Compute density (GPU-only 模式自动记录)
DensityObserver.record(layer_id, mask, causal=True)
# Comm density (Offload 模式在 select_blocks 中记录)
DensityObserver.record_comm_density(layer_id, selected_cpu_blocks, total_cpu_blocks)
```
### 获取结果
```python
# 总体 density
overall_compute = DensityObserver.get_overall_density()
overall_comm = DensityObserver.get_overall_comm_density()
# Per-layer density
per_layer_compute = DensityObserver.get_per_layer_density()
per_layer_comm = DensityObserver.get_per_layer_comm_density()
# 打印摘要
DensityObserver.print_summary()
```
## 相关文件
- `nanovllm/utils/density_observer.py`: DensityObserver 实现
- `nanovllm/kvcache/sparse/xattn_bsa.py`: XAttention BSA Policyselect_blocks 中记录 comm density
- `tests/test_ruler.py`: RULER benchmark 测试脚本

198
docs/xattn_kernels_guide.md Normal file
View File

@@ -0,0 +1,198 @@
# XAttention Kernels Guide
本文档详细说明 XAttention 的两个核心 Triton kernel 的工作原理。
## 概述
XAttention 使用 stride 采样来快速估计 attention 分布,用于稀疏 attention 的 block 选择。
**数据流**
```
Q [batch, heads, q_len, head_dim]
K [batch, heads, kv_len, head_dim]
↓ flat_group_gemm_fuse_reshape (stride 采样 + GEMM)
attn_scores [batch, heads, q_len/stride, kv_len/stride]
↓ softmax_fuse_block_sum (softmax + block 求和)
block_sums [batch, heads, q_blocks, k_blocks]
↓ threshold 选择
sparse_mask [batch, heads, q_blocks, k_blocks]
```
**注意**Q 和 K 可以有不同的长度q_len ≠ kv_len这在 chunked prefill 场景中很常见。
## Kernel 1: flat_group_gemm_fuse_reshape
### 功能
计算 stride reshape 后的 attention scores本质是计算原始 attention 矩阵中每个 stride×stride 块的**反对角线求和**。
### 函数签名
```python
def flat_group_gemm_fuse_reshape(
query_states: torch.Tensor, # [batch, heads, q_len, head_dim]
key_states: torch.Tensor, # [batch, heads, kv_len, head_dim]
stride: int,
chunk_start: int,
chunk_end: int,
is_causal: bool = True,
) -> torch.Tensor: # [batch, heads, q_len/stride, kv_len/stride]
```
### 采样方式
```
Q 采样: (stride-1-s)::stride (逆向)
K 采样: s::stride (正向)
例如 stride=4:
Q 采样位置: 3, 7, 11, 15, ... (从位置 3 开始,每隔 4)
K 采样位置: 0, 4, 8, 12, ... (从位置 0 开始,每隔 4)
```
### 反对角线原理
对于原始 attention 矩阵的每个 stride×stride 块:
```
stride=4 的块:
K[0] K[1] K[2] K[3]
Q[0] · · · X ← 反对角线
Q[1] · · X ·
Q[2] · X · ·
Q[3] X · · ·
```
**输出值 = 反对角线元素之和**
因为:
- `Q[i]` 采样自原始位置 `(stride-1-i)`
- `K[j]` 采样自原始位置 `j`
-`i + j = stride - 1` 时,恰好在反对角线上
### Triton 约束
**GPU 相关的 BLOCK 大小**
| GPU 类型 | 显存 | BLOCK_M/N | 最小 q_len/kv_len |
|----------|------|-----------|-------------------|
| RTX 3090 | 24GB | 64 | stride × 64 = 256 |
| A100/H100 | ≥40GB | 128 | stride × 128 = 512 |
```python
# 代码中的判断逻辑
if props.total_memory < 30 * 1024**3: # < 30GB
BLOCK_M = BLOCK_N = 64
else:
BLOCK_M = BLOCK_N = 128
assert q_len % (stride * BLOCK_M) == 0
assert kv_len % (stride * BLOCK_N) == 0
```
### 验证示例
```python
# 输入: 偶数位置=1, 奇数位置=2
# q_len=512, kv_len=2048, stride=4, head_dim=128
# 反对角线元素 (stride=4):
# Q[奇]*K[偶] + Q[偶]*K[奇] = 2*1 + 1*2 = 4 (每对)
# stride=4 有 2 对
# 乘以 head_dim=128
# 预期值: 4 * 2 * 128 = 1024
# 输出 shape: [1, 1, 128, 512] (512/4=128, 2048/4=512)
```
## Kernel 2: softmax_fuse_block_sum
### 功能
`flat_group_gemm_fuse_reshape` 的输出做 softmax然后按 block 求和,得到每个 block 的 attention 权重总和。
### 参数说明
| 参数 | 含义 |
|------|------|
| `attn_weights_slice` | 输入 attention scores `[batch, heads, q_reshaped, k_reshaped]` |
| `reshaped_block_size` | Block 大小(在 reshaped 空间,= block_size / stride |
| `segment_size` | 每次迭代处理的 K 维度大小tiling |
| `chunk_start` | Q 的起始位置(用于 causal mask |
| `chunk_end` | Q 的结束位置 |
| `real_q_len` | 有效 Q 长度(用于 padding mask |
| `scale` | 缩放因子(融合多个因素) |
| `is_causal` | 是否应用 causal mask |
### Scale 因子
```python
scale = log2(e) / sqrt(head_dim) / stride / norm
= 1.4426950408889634 / sqrt(head_dim) / stride / norm
```
| 因子 | 值 | 作用 |
|------|-----|------|
| `log2(e)` | 1.4426950408889634 | Triton 用 `exp2` 而非 `exp`,需转换底数 |
| `1/sqrt(head_dim)` | 1/√128 | 标准 attention 缩放 |
| `1/stride` | 1/4 | stride 采样的归一化 |
| `1/norm` | 变化 | 额外归一化因子 |
**为什么用 exp2**Triton 的 `exp2``exp` 更快(硬件原生支持),所以把 log₂(e) 融合到 scale 里。
### Segment Size 约束
```python
assert segment_size >= reshaped_block_size
```
原因kernel 内部使用 `segment_size // block_size` 做 reshape
```python
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
```
如果 `segment_size < block_size`,则 `segment_size // block_size = 0`,导致无效维度。
### 验证示例
```python
# 输入: attn_scores [1, 1, 128, 512] (所有值相同)
# block_size=128
# softmax 后每行均匀分布 (所有值相同 → 均匀)
# 每行对一个 K block 的贡献 = block_size / kv_reshaped_len = 128/512 = 0.25
# 每个 Q block 有 block_size=128 行
# block_sum = 128 * 0.25 = 32
# 输出 shape: [1, 1, 1, 4] (128/128=1, 512/128=4)
```
## 完整示例
```python
# 参数
q_len = 512 # Q 长度
kv_len = 2048 # K/V 长度 (可以不同于 q_len)
stride = 4
block_size = 128
# Step 1: flat_group_gemm_fuse_reshape
# 输入: Q [1,1,512,128], K [1,1,2048,128]
# 输出: attn_scores [1,1,128,512]
# Step 2: softmax_fuse_block_sum
# 输入: attn_scores [1,1,128,512]
# 输出: block_sums [1,1,1,4]
# q_blocks = 128/128 = 1
# k_blocks = 512/128 = 4
```
## 测试代码
参考 `tests/test_xattn_kernels.py`,使用结构化数据验证两个 kernel 的正确性。
## 相关文档
- [`docs/xattention_algorithm_guide.md`](xattention_algorithm_guide.md): XAttention 算法详解
- [`docs/sparse_attention_guide.md`](sparse_attention_guide.md): 稀疏 attention 方法概述

View File

@@ -0,0 +1,122 @@
# XAttention KV Chunking Density 验证测试
## 背景
验证 XAttention Triton kernel 是否只能沿 Q 轴分 chunk不能沿 KV 轴分 chunk。
**假设**`softmax_fuse_block_sum` 需要完整的 K 来计算正确的归一化分母,分 chunk 后的 attention 分布与完整序列不同。
## 测试方法
1. **GPU-only 模式**:一次性对完整序列调用 `xattn_estimate`,记录 Layer 0 的 density
2. **Offload DEBUG 模式**:分 chunk 调用 `xattn_estimate`,累积 selected/total counts计算最终 density
3. 使用相同的 `_debug_k_full` buffer 收集完整 K cache确保输入数据一致
### 关键代码逻辑
```python
# Offload DEBUG: 每个 chunk 累积 selected/total
for each chunk:
K_full = _debug_k_full[:, :, :total_k_len, :] # 累积的 K
_, mask_chunk = xattn_estimate(Q_chunk, K_full, threshold=threshold, causal=True)
# 裁剪到有效区域,计算正确的 causal mask (考虑 Q 偏移量)
q_offset_blocks = k_blocks - q_blocks
causal_mask = indices <= (q_indices + q_offset_blocks)
selected += (mask_valid & causal_mask).sum()
total += causal_mask.sum()
density = selected / total
```
## 测试结果
### 64K 序列 (niah_single_1, 序列长度 64891)
| threshold | GPU-only selected | Offload selected | GPU-only density | Offload density | 差异 (selected) |
|-----------|------------------|------------------|------------------|-----------------|-----------------|
| **0.90** | 1,524,617 | 1,330,506 | **0.3700** | **0.3229** | 194,111 (12.7%) |
| **0.95** | 1,955,015 | 1,747,585 | **0.4744** | **0.4241** | 207,430 (10.6%) |
| **1.00** | 4,118,719 | 4,118,896 | **0.9995** | **0.9995** | -177 (~0%) |
- **total**: 4,120,896 (两种模式一致)
### 32K 序列 (niah_single_1, 序列长度 32485)
| threshold | GPU-only selected | Offload selected | GPU-only density | Offload density | 差异 (selected) |
|-----------|------------------|------------------|------------------|-----------------|-----------------|
| **0.90** | 520,314 | 466,937 | **0.5021** | **0.4506** | 53,377 (10.3%) |
| **0.95** | 647,765 | 602,953 | **0.6251** | **0.5818** | 44,812 (6.9%) |
| **1.00** | 1,036,295 | 1,036,264 | **0.9999** | **0.9999** | 31 (~0%) |
- **total**: 1,036,320 (两种模式一致)
### 汇总对比
| 序列长度 | threshold | GPU-only density | Offload density | density 差异 |
|---------|-----------|------------------|-----------------|--------------|
| 32K | 0.90 | 0.5021 | 0.4506 | 5.2% |
| 64K | 0.90 | 0.3700 | 0.3229 | 4.7% |
| 32K | 0.95 | 0.6251 | 0.5818 | 4.3% |
| 64K | 0.95 | 0.4744 | 0.4241 | 5.0% |
| 32K | 1.00 | 0.9999 | 0.9999 | ~0% |
| 64K | 1.00 | 0.9995 | 0.9995 | ~0% |
## 结论
### 1. Softmax 归一化本身是正确的
`threshold=1.0`(选择所有 blocksGPU-only 和 Offload 模式的 density 几乎完全对齐(差异 < 0.01%)。
这说明:
- `_debug_k_full` 正确收集了完整的 K cache
- 分 chunk 调用 `xattn_estimate`softmax 归一化在正确的 K 序列上计算
- causal mask 的 Q 偏移量处理正确
### 2. 问题在于 threshold 的应用方式
`threshold < 1.0`差异显著10-13%
- **GPU-only**:对完整序列一次性应用 threshold选择 cumulative attention >= threshold 的 blocks
- **Offload**:每个 chunk 独立应用 threshold累积 selected counts
每个 chunk 独立应用 threshold 会导致:
- 某些在 GPU-only 中被选中的 blocks在分 chunk 时因 attention 分布不同而未被选中
- 累积的 selected 比一次性计算的要少
### 3. XAttention Triton kernel 的 KV chunking 限制
**验证结论**XAttention 的 `xattn_estimate` 可以正确处理 KV chunkingsoftmax 归一化正确),但 **threshold-based block selection 不能简单累积**
如果要在 Offload 模式下获得与 GPU-only 一致的 block selection
1. 需要先累积所有 chunks 的 attention scores
2. 最后一次性应用 threshold 选择 blocks
或者接受 10-13% 的 density 差异,这对实际推理准确性的影响需要进一步评估。
## 测试命令
```bash
# GPU-only 模式
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py --dataset niah_single_1 --sample 0 \
--sparse-policy xattn_bsa --sparse-threshold 0.9
# Offload 模式 (64K)
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py --dataset niah_single_1 --sample 0 \
--sparse-policy xattn_bsa --sparse-threshold 0.9 --enable-offload
# Offload 模式 (32K)
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py --dataset niah_single_1 --sample 0 \
--sparse-policy xattn_bsa --sparse-threshold 0.9 --enable-offload \
--data-dir /home/zijie/Code/nano-vllm/tests/data/ruler_32k --max-model-len 34000
```
## 相关文件
- `nanovllm/kvcache/sparse/xattn_bsa.py`: DEBUG 代码位置
- `nanovllm/ops/xattn.py`: `xattn_estimate` 实现
- `nanovllm/utils/density_observer.py`: DensityObserver 实现

View File

@@ -0,0 +1,400 @@
# XAttention KV Chunking Kernels
## 概述
本文档描述了支持 KV 维度分 chunk 的 softmax kernels 实现。这些 kernels 允许在 CPU offload 场景下,沿 KV 维度分块计算 sparse attention estimation而不需要在 GPU 上保存完整的 raw attention scores。
## 背景
原始的 `softmax_fuse_block_sum` kernel 需要完整的 K 序列来计算正确的 softmax 归一化分母:
```
softmax(x_i) = exp(x_i) / Σ_j exp(x_j)
```
如果只有部分 K (KV chunk),分母 `Σ_j exp(x_j)` 不完整,导致归一化错误。
## 解决方案:三阶段计算
通过将 softmax 计算拆分为三个阶段,实现正确的 KV chunking
```
┌─────────────────────────────────────────────────────────────────┐
│ 三阶段 Pipeline │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ KV Chunk 0 │ │ KV Chunk 1 │ │ KV Chunk N │ │
│ │ attn_scores │ │ attn_scores │ │ attn_scores │ │
│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ 阶段 1: softmax_compute_partial_stats │ │
│ │ 计算每个 chunk 的 (m_partial, l_partial) │ │
│ └─────────────────────────────────────────────────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ (m_0, l_0) (m_1, l_1) (m_N, l_N) │
│ │ │ │ │
│ └────────────────┬┴─────────────────┘ │
│ ▼ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ 阶段 2: merge_softmax_stats │ │
│ │ Host 端合并 → (m_global, l_global) │ │
│ └─────────────────────────────────────────────────┘ │
│ │ │
│ ┌────────────────┼────────────────┐ │
│ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ 阶段 3: softmax_normalize_and_block_sum │ │
│ │ 使用全局 stats 归一化并计算 block sums │ │
│ └─────────────────────────────────────────────────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ block_sums_0 block_sums_1 block_sums_N │
│ │ │ │ │
│ └────────────────┴────────────────┘ │
│ │ │
│ ▼ │
│ torch.cat → final mask │
│ │
└─────────────────────────────────────────────────────────────────┘
```
### 阶段 1: `softmax_compute_partial_stats`
计算每个 KV chunk 的 partial statistics
- `m_partial`: 该 chunk 内的最大值 (per query row)
- `l_partial`: 该 chunk 内的 partial sum = Σ exp(x - m_partial)
```python
m_partial, l_partial = softmax_compute_partial_stats(
attn_weights_kv, # [batch, heads, q_len, k_chunk_len]
reshaped_block_size,
segment_size,
scale,
chunk_start=chunk_start,
kv_offset=kv_offset, # KV chunk 在完整序列中的偏移
is_causal=True,
)
# 输出: m_partial, l_partial 形状为 [batch, heads, q_len]
```
### 阶段 2: `merge_softmax_stats`
Host 端合并所有 KV chunks 的 statistics
```python
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
```
合并公式 (Online Softmax):
```
m_new = max(m_global, m_chunk)
l_new = l_global * exp(m_global - m_new) + l_chunk * exp(m_chunk - m_new)
```
### 阶段 3: `softmax_normalize_and_block_sum`
使用全局 statistics 归一化并计算 block sums
```python
attn_sum_kv = softmax_normalize_and_block_sum(
attn_weights_kv, # [batch, heads, q_len, k_chunk_len]
m_global, # [batch, heads, q_len]
l_global, # [batch, heads, q_len]
reshaped_block_size,
segment_size,
chunk_start=chunk_start,
real_q_len=real_q_len,
scale=scale,
kv_offset=kv_offset,
is_causal=True,
)
# 输出: [batch, heads, q_blocks, k_chunk_blocks]
```
## 数学等价性证明
原始 softmax 计算 (完整 K):
```
softmax(x_i) = exp(x_i - m) / Σ_j exp(x_j - m)
```
分 KV chunk 计算:
```
Chunk 0: m_0 = max(x[0:N/2]), l_0 = Σ exp(x[0:N/2] - m_0)
Chunk 1: m_1 = max(x[N/2:N]), l_1 = Σ exp(x[N/2:N] - m_1)
合并:
m_global = max(m_0, m_1)
l_global = l_0 * exp(m_0 - m_global) + l_1 * exp(m_1 - m_global)
= Σ exp(x[0:N] - m_global) # 等于全局 sum
归一化:
softmax(x_i) = exp(x_i - m_global) / l_global # 正确!
```
## Causal Mask 处理
两个 kernel 都正确处理了 causal attention
1. **`softmax_partial_stats_kernel`**: 通过 `kv_offset` 参数确定当前 KV chunk 在完整序列中的位置,正确计算 causal boundary
2. **`softmax_normalize_block_sum_kernel`**: 同样使用 `kv_offset`,对 causal boundary 之后的位置输出 0
## 存储开销分析
### 符号定义
| 符号 | 含义 | 典型值 |
|------|------|--------|
| S | seq_len | 64K |
| B | batch_size | 1 |
| H | num_heads | 32 |
| D | head_dim | 128 |
| T | stride | 4-8 |
| C | chunk_size | 16K |
| n | num_kv_chunks = ceil(S/C) | 4 |
### 原始方式 (无 KV chunking)
**attn_weights 峰值内存**:
```
[B, H, S/T, S/T] × 4 bytes = B × H × (S/T)² × 4
例: S=64K, T=4, B=1, H=32
= 1 × 32 × 16384² × 4 = 32 GB
```
### KV Chunking 方式的额外存储
#### 1. Partial Stats (每个 KV chunk)
```
m_partial: [B, H, C/T] × 4 bytes
l_partial: [B, H, C/T] × 4 bytes
单个 chunk = 2 × B × H × (C/T) × 4
= 2 × 1 × 32 × 4096 × 4 = 1 MB
```
#### 2. Global Stats
```
m_global: [B, H, S/T] × 4 bytes
l_global: [B, H, S/T] × 4 bytes
= 2 × B × H × (S/T) × 4
= 2 × 1 × 32 × 16384 × 4 = 4 MB
```
#### 3. 总额外开销
```
total_extra = n × partial_stats + global_stats
= 4 × 1MB + 4MB = 8 MB
```
### 存储开销随 seqlen 变化
| seqlen | num_chunks | 原始 attn_weights | 额外 stats | 比例 |
|--------|------------|-------------------|------------|------|
| 16K | 1 | 2 GB | 2 MB | 0.1% |
| 32K | 2 | 8 GB | 4 MB | 0.05% |
| 64K | 4 | 32 GB | 8 MB | 0.025% |
| 128K | 8 | 128 GB | 16 MB | 0.012% |
### 复杂度分析
| 存储组件 | 复杂度 | 说明 |
|----------|--------|------|
| 原始 attn_weights | O(S²) | 二次增长 |
| Partial/Global stats | O(S) | 线性增长 |
| **相对开销** | O(1/S) | **随 seqlen 递减** |
### 峰值显存优化
KV chunking 的主要收益是**峰值显存**从 O(S²) 降到 O(S×C)
```
原始: O(B × H × (S/T)²) # 完整 attn_weights
KV chunking: O(B × H × (S/T) × (C/T)) # 一次只处理一个 chunk
```
以 S=128K, C=16K 为例:
- 原始峰值: ~128 GB
- KV chunking 峰值: ~16 GB (降低 **8 倍**)
## 支持不同 Q/KV Chunk Size
三阶段 pipeline 支持 Q 和 KV 使用不同的 chunk size
```python
q_chunk_size = 8192 # Q 分块大小
kv_chunk_size = 16384 # KV 分块大小
for q_chunk_idx in range(q_chunk_num):
Q_chunk = Q[:, :, q_start:q_end, :] # [B, H, q_chunk_size, D]
for kv_chunk_idx in range(kv_chunk_num):
K_chunk = K[:, :, kv_start:kv_end, :] # [B, H, kv_chunk_size, D]
# ... 三阶段处理
```
### 测试验证结果
| Config | seq_len | Q chunks | KV chunks | density | 对齐 |
|--------|---------|----------|-----------|---------|------|
| Q=16K, KV=16K | 64891 | 4 | 4 | 0.1117 | ✓ 100% |
| Q=8K, KV=16K | 64891 | 8 | 4 | 0.1112 | ✓ 100% |
| Q=16K, KV=8K | 64891 | 4 | 8 | 0.1117 | ✓ 100% |
| Q=8K, KV=8K | 64891 | 8 | 8 | 0.1112 | ✓ 100% |
| Q=4K, KV=16K | 64891 | 16 | 4 | 0.1109 | ✓ 100% |
## API 参考
### `softmax_compute_partial_stats`
```python
def softmax_compute_partial_stats(
attn_weights_slice: torch.Tensor, # [batch, heads, q_len, k_chunk_len]
reshaped_block_size: int,
segment_size: int,
scale: float,
chunk_start: int = 0, # Q chunk 起始位置 (reshaped space)
kv_offset: int = 0, # KV chunk 偏移 (reshaped space)
is_causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""返回 (m, l) partial stats"""
```
### `merge_softmax_stats`
```python
def merge_softmax_stats(
m_chunks: list, # List of [batch, heads, q_len] tensors
l_chunks: list, # List of [batch, heads, q_len] tensors
) -> Tuple[torch.Tensor, torch.Tensor]:
"""返回 (m_global, l_global)"""
```
### `softmax_normalize_and_block_sum`
```python
def softmax_normalize_and_block_sum(
attn_weights_slice: torch.Tensor, # [batch, heads, q_len, k_chunk_len]
m_global: torch.Tensor, # [batch, heads, q_len]
l_global: torch.Tensor, # [batch, heads, q_len]
reshaped_block_size: int,
segment_size: int,
chunk_start: int,
real_q_len: int,
scale: float,
kv_offset: int = 0,
is_causal: bool = False,
) -> torch.Tensor:
"""返回 block sums [batch, heads, q_blocks, k_chunk_blocks]"""
```
## 使用示例
```python
from nanovllm.ops.xattn import (
flat_group_gemm_fuse_reshape,
softmax_compute_partial_stats,
softmax_normalize_and_block_sum,
merge_softmax_stats,
find_blocks_chunked,
)
# 对每个 Q chunk
for q_chunk_idx in range(q_chunk_num):
Q_chunk = Q_padded[:, :, q_start:q_end, :]
# 阶段 1: 每个 KV chunk 计算 partial stats
m_chunks, l_chunks = [], []
attn_weights_chunks = []
for kv_chunk_idx in range(kv_chunk_num):
K_chunk = K_padded[:, :, kv_start:kv_end, :]
kv_offset = kv_chunk_idx * kv_chunk_size // STRIDE
# 计算 raw scores
attn_weights = flat_group_gemm_fuse_reshape(
Q_chunk, K_chunk, STRIDE,
chunk_start=chunk_start,
chunk_end=chunk_end,
is_causal=False, # K 不完整
)
attn_weights_chunks.append(attn_weights)
# 计算 partial stats
m, l = softmax_compute_partial_stats(
attn_weights, block_size, segment_size, scale,
chunk_start=chunk_start,
kv_offset=kv_offset,
is_causal=True,
)
m_chunks.append(m)
l_chunks.append(l)
# 阶段 2: 合并 stats
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
# 阶段 3: 归一化并计算 block sums
block_sums_list = []
for kv_chunk_idx, attn_weights in enumerate(attn_weights_chunks):
kv_offset = kv_chunk_idx * kv_chunk_size // STRIDE
block_sums = softmax_normalize_and_block_sum(
attn_weights, m_global, l_global,
block_size, segment_size, chunk_start, real_q_len, scale,
kv_offset=kv_offset,
is_causal=True,
)
block_sums_list.append(block_sums)
# 拼接并选择 blocks
attn_sum = torch.cat(block_sums_list, dim=-1)
mask = find_blocks_chunked(attn_sum, ...)
```
## 性能对比
| 方面 | 原始实现 | KV Chunking 实现 |
|------|---------|-----------------|
| Kernel 数量 | 1 | 2 (stats + normalize) |
| Raw scores 读取次数 | 2 | 2 |
| 额外内存 | 0 | O(B × H × S/T × 2) for (m, l) |
| Host 计算 | 无 | merge stats (轻量) |
| **峰值显存** | O(S²) | **O(S × C)** |
## 验证测试
### 批量测试结果
测试脚本 `tests/test_xattn_kv_chunking_batch.py` 验证了不同 seqlen 下的一致性:
```
| seq_len | stride | threshold | kv_chunks | density_api | density_kv | diff | mask_diff | status |
|---------|--------|-----------|-----------|-------------|------------|----------|-----------|--------|
| 3688 | 4 | 0.90 | 1 | 0.383405 | 0.383405 | 0.000000 | 0.0000% | PASS |
| 7888 | 4 | 0.90 | 1 | 0.290611 | 0.290611 | 0.000000 | 0.0000% | PASS |
| 15685 | 4 | 0.90 | 1 | 0.197724 | 0.197724 | 0.000000 | 0.0000% | PASS |
| 32485 | 4 | 0.90 | 2 | 0.159023 | 0.159023 | 0.000000 | 0.0000% | PASS |
| 64891 | 4 | 0.90 | 4 | 0.111656 | 0.111656 | 0.000000 | 0.0000% | PASS |
```
### 关键结论
1. **数学等价性**: density_diff = 0.000000 对于所有测试
2. **Mask 完全对齐**: mask_diff = 0.0000% 对于所有测试
3. **支持任意 Q/KV chunk size 组合**
## 相关文件
- `nanovllm/ops/xattn.py`: Kernel 实现
- `tests/test_xattn_estimate_alignment.py`: 单文件验证测试
- `tests/test_xattn_kv_chunking_batch.py`: 批量验证测试
- `docs/xattn_kernels_guide.md`: 原始 kernel 文档

View File

@@ -0,0 +1,154 @@
# XAttention Memory Benchmark
GPU-only 模式下 XAttention 的内存使用分析。
## 测试配置
### 硬件
- **GPU**: NVIDIA A100 80GB (用于基准测试)
- **目标**: 验证在 RTX 3090/4090 (24GB) 上的可行性
### 模型
- **Model**: Qwen3-0.6B (28 layers, 16 heads, 8 KV heads, head_dim=128)
- **Context Length**: 32K (max_model_len=40960)
### XAttention 配置
- **Sparse Policy**: XATTN_BSA
- **Threshold**: 0.9
- **Block Size**: 128 tokens (BSA block)
- **Stride**: 8
---
## 内存使用分析
### 基准测试 (gpu-utilization=0.9)
| 指标 | 数值 |
|------|------|
| KV Cache | 157 blocks × 448 MB = 70.3 GB |
| **峰值内存** | **73,949 MiB (72.2 GB)** |
| GPU 利用率 | 90.2% |
### 24GB 显存可行性测试
| gpu-utilization | KV Cache Blocks | KV Cache Size | 峰值内存 | 测试结果 |
|-----------------|-----------------|---------------|----------|----------|
| 0.25 | 39 blocks | 17.5 GB | **20.6 GB** | ✅ 5/5 PASSED |
| 0.28 | 44 blocks | 19.7 GB | **22.8 GB** | ✅ 5/5 PASSED |
---
## 24GB 显存推荐配置
适用于 **RTX 3090 / RTX 4090 (24GB)**
```bash
CUDA_VISIBLE_DEVICES=0 python tests/test_ruler.py \
--model ~/models/Qwen3-0.6B \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 5 \
--max-model-len 40960 \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9 \
--gpu-utilization 0.28
```
### 配置说明
| 参数 | 值 | 说明 |
|------|-----|------|
| `--gpu-utilization` | 0.28 | 限制 GPU 内存使用到 ~23GB |
| `--max-model-len` | 40960 | 支持 32K+ context |
| `--sparse-policy` | XATTN_BSA | 启用 XAttention 稀疏注意力 |
| `--sparse-threshold` | 0.9 | 选择覆盖 90% attention 的 blocks |
---
## 内存分解
### Qwen3-0.6B @ 32K Context
| 组件 | 计算公式 | 大小 |
|------|----------|------|
| 模型权重 | 0.6B × 2 bytes | ~1.2 GB |
| KV Cache (per-token) | 2 × 28 layers × 8 kv_heads × 128 head_dim × 2 bytes | 112 KB |
| KV Cache (per-block) | 112 KB × 4096 tokens | 448 MB |
| KV Cache (44 blocks) | 448 MB × 44 | 19.7 GB |
| XAttention Buffers | GQA + mask + KV chunking | ~0.3 GB |
| 中间激活 | 运行时分配 | ~1.5 GB |
| **总计** | | **~22.8 GB** |
---
## 性能指标
### RULER niah_single_1 (5 samples)
| 指标 | gpu-util=0.25 | gpu-util=0.28 | gpu-util=0.9 |
|------|---------------|---------------|--------------|
| 准确率 | 100% (5/5) | 100% (5/5) | 100% (5/5) |
| 耗时 | 11.4s | 11.5s | 11.6s |
| Compute Density | 24.77% | 24.77% | 24.77% |
| Min Layer Density | 4.29% (Layer 5) | 4.29% (Layer 5) | 4.29% (Layer 5) |
**结论**: 降低 gpu-utilization 不影响准确率和性能,只影响可支持的最大序列长度。
---
## 不同模型的估算
### KV Cache 公式
```
KV Cache per-token = 2 × num_layers × num_kv_heads × head_dim × dtype_size
KV Cache per-block = per-token × block_size
```
### 常见模型估算 (32K context, block_size=4096)
| 模型 | Layers | KV Heads | Head Dim | Per-Token | 32K Tokens | 24GB 可行? |
|------|--------|----------|----------|-----------|------------|------------|
| Qwen3-0.6B | 28 | 8 | 128 | 112 KB | 3.5 GB | ✅ 是 |
| Qwen3-4B | 36 | 8 | 128 | 144 KB | 4.5 GB | ✅ 是 |
| Llama-3.1-8B | 32 | 8 | 128 | 128 KB | 4.0 GB | ⚠️ 需要 offload |
| Qwen2.5-7B | 28 | 4 | 128 | 56 KB | 1.8 GB | ✅ 是 |
注: 8B 模型权重约 16GB加上 KV cache 超过 24GB需要 CPU offload。
---
## 使用建议
### RTX 3090/4090 (24GB)
1. **小模型 (≤4B)**:可直接使用 GPU-only + XAttention
```bash
--gpu-utilization 0.28 --sparse-policy XATTN_BSA
```
2. **大模型 (7B-8B)**:需要 CPU offload
```bash
--enable-offload --num-gpu-blocks 4 --num-cpu-blocks 32
```
### A100 (40GB/80GB)
1. **所有模型**:可使用 GPU-only 模式
```bash
--gpu-utilization 0.9 --sparse-policy XATTN_BSA
```
---
## 相关文件
- `tests/test_ruler.py`: RULER 测试脚本
- `nanovllm/kvcache/sparse/xattn_bsa.py`: XAttention BSA Policy 实现
- `docs/gpuonly_density_alignment_test.md`: Density 对齐验证
---
**Date**: 2026-02-02
**Author**: Zijie Tian

View File

@@ -0,0 +1,184 @@
# XAttention Offload Profiling - 32K Context
Nsys profiling 分析 XAttention vs Full Attention 在 Offload 模式下的性能。
**测试日期**: 2026-02-05
**测试模型**: Llama-3.1-8B-Instruct
**Context**: 32K tokens
**GPU**: A100-80GB (GPU 0)
---
## 测试配置
| 参数 | Full | XAttention |
|------|------|------------|
| Policy | FULL | XATTN_BSA |
| Block size | 4096 | 4096 |
| GPU blocks | 4 | 4 |
| Threshold | - | 0.95 |
| Density | 100% | ~50% |
---
## XAttention 各阶段时间统计
### NVTX Markers Summary
| 阶段 | 总时间(ms) | 调用次数 | 平均时间(ms) | 说明 |
|------|------------|----------|--------------|------|
| xattn_find_blocks | 1155.1 | 256 | 4.51 | 块选择 (threshold-based) |
| xattn_estimate_pass1 | 588.3 | 256 | 2.30 | 第一轮: partial stats |
| xattn_compute_historical | 512.0 | 224 | 2.29 | 历史 KV attention |
| xattn_estimate_pass2 | 501.6 | 256 | 1.96 | 第二轮: block sums |
| xattn_estimate_merge | 197.9 | 256 | 0.77 | 合并 softmax stats |
| xattn_compute_merge | 93.8 | 256 | 0.37 | 计算结果合并 |
| xattn_compute_current | 59.2 | 256 | 0.23 | 当前 chunk attention |
### 时间分配
```
Total XAttention overhead: 3108 ms
Estimate 阶段: 1288 ms (41.4%)
- pass1: 588 ms
- pass2: 502 ms
- merge: 198 ms
Find blocks: 1155 ms (37.2%)
Compute 阶段: 665 ms (21.4%)
- historical: 512 ms
- merge: 94 ms
- current: 59 ms
```
---
## Chunk7 (最后一个 chunk) 对比
### Per-Layer 时间
| Policy | Layer 0 | Layer 1 | ... | Layer 31 | Avg |
|--------|---------|---------|-----|----------|-----|
| Full | 36.5 ms | 33.6 ms | ... | 32.7 ms | ~35 ms |
| XAttn | 39.7 ms | 39.3 ms | ... | 38.5 ms | ~38 ms |
### 分析
Chunk7 是序列的最后 ~4K tokens (3813 tokens),此时:
- K 长度: 32485 tokens
- Density: 42.08%
**结论**: XAttention 在 Chunk7 比 Full 慢约 8%,原因:
1. Estimate 开销无法被稀疏计算收益抵消
2. 42% density 仍然较高,稀疏收益有限
---
## Full Attention Chunk7 详细数据
```
Layer Time(ms)
L0 36.5
L1 44.3
L2 43.7
L3 38.7
L4 34.2
L5 45.2
...
L31 32.7
Avg ~35
```
---
## XAttention Chunk7 详细数据
```
Layer Time(ms)
L0 39.7
L1 39.3
L2 37.1
L3 39.1
L4 38.7
L5 39.4
...
L31 38.5
Avg ~38
```
---
## 性能瓶颈分析
### 1. xattn_find_blocks 开销过高
- 平均 4.51 ms per call
- 占总时间 37.2%
- 原因: threshold-based 块选择涉及排序和累积求和
### 2. 两轮 estimate 开销
- Pass1 + Pass2 共 1090 ms
- 需要遍历所有 KV chunks 两次
- 可优化方向: 单轮 estimate
### 3. Compute 阶段相对高效
- 只占 21.4%
- 说明 BSA 稀疏计算本身效率不错
---
## 优化建议
### 短期
1. **减少 find_blocks 开销**
- 使用 top-k 而不是 threshold
- 预分配 mask buffer 避免动态分配
2. **合并 estimate 两轮**
- 在单轮中同时计算 stats 和 block sums
### 中期
1. **estimate 阶段使用更小的 block_size**
- 当前 block_size=4096 对 estimate 不友好
- 参考 `docs/estimate_block_size_performance.md`
2. **Pipeline estimate 和 H2D**
- 将 estimate 与下一个 chunk 的 H2D 重叠
### 长期
1. **预测式块选择**
- 基于历史 pattern 预测下一个 chunk 的重要 blocks
- 减少 estimate 开销
---
## 相关文件
- `results/nsys/full_offload_32k_blk4096_20260205_023257.nsys-rep`
- `results/nsys/xattn_offload_32k_blk4096_20260205_023435.nsys-rep`
---
## 命令
### Profile Full
```bash
bash scripts/profile_offload.sh --policy full --ctx-len 32k --gpu 0 --model ~/models/Llama-3.1-8B-Instruct
```
### Profile XAttention
```bash
bash scripts/profile_offload.sh --policy xattn --ctx-len 32k --gpu 0 --model ~/models/Llama-3.1-8B-Instruct
```
### 分析 NVTX
```bash
nsys stats --report nvtx_pushpop_sum <file>.nsys-rep
```

View File

@@ -0,0 +1,307 @@
# XAttention Offload Stream Synchronization Fix
修复 XAttention BSA Policy 在 Offload 模式下的 CUDA stream 同步 bug。
**修复日期**: 2026-02-05
**Commit**: `829b311`
**影响文件**: `nanovllm/kvcache/sparse/xattn_bsa.py`, `nanovllm/kvcache/offload_engine.py`
---
## 问题描述
### 症状
在 Offload 模式下运行 RULER benchmark 时XAttention BSA 的 `select_blocks` 方法中 Pass 1 和 Pass 2 从**同一个 CPU block** 加载的 K 数据不一致:
```
Pass 1: K_chunk sum = 745472.00 (正确)
Pass 2: K_chunk sum = 0.00 (错误,数据未加载完成)
```
这导致 attention 计算结果错误RULER 准确率下降。
### 复现条件
- 模式: Offload (`--enable-offload`)
- Context: ≥ 32K tokens
- 稀疏策略: `--sparse-policy XATTN_BSA`
---
## 根因分析
### Stream 配置回顾
nano-vllm 的 CPU offload 使用多个 CUDA streams 实现 pipeline
| Stream | 用途 |
|--------|------|
| `slot_transfer_streams[i]` | H2D 传输 (CPU → GPU slot) |
| `compute_stream` | Attention 计算 |
| `prefill_offload_streams[i]` | D2H 传输 (GPU → CPU cache) |
### 同步机制
`wait_slot_layer(slot)` 使用 event 机制同步:
```python
def wait_slot_layer(self, slot_idx: int):
"""Make compute_stream wait for H2D transfer completion."""
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])
```
### Bug 根因
`select_blocks` 方法中:
1. H2D 传输在 `slot_transfer_streams` 上执行
2. `wait_slot_layer``compute_stream` 等待传输完成
3. **但是** 后续的 compute kernels 在**默认 stream** 上执行,而不是 `compute_stream`
```python
# Bug 代码
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot) # compute_stream 等待
# 这些 kernel 在默认 stream 上运行,没有等待 H2D 完成!
k_block = offload_engine.get_k_for_slot(slot)
K_chunk = k_block.transpose(1, 2)
# ... 后续计算 ...
```
### 时序图
```
slot_transfer_stream: [====H2D====]
compute_stream: |wait|
default_stream: [kernel1][kernel2] ← 没有等待!
数据未就绪
```
---
## 修复方案
### 核心修改
将所有 estimate 阶段的 compute kernels 包装在 `with torch.cuda.stream(compute_stream):` 中:
```python
# 修复后代码
compute_stream = offload_engine.compute_stream
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot) # compute_stream 等待
# 所有计算在 compute_stream 上执行
with torch.cuda.stream(compute_stream):
k_block = offload_engine.get_k_for_slot(slot)
K_chunk = k_block.transpose(1, 2)
# ... 后续计算 ...
```
### 修复位置
`select_blocks` 方法中共 6 处需要修复:
| 位置 | 阶段 | 修复内容 |
|------|------|----------|
| Pass 1 历史 blocks | `xattn_estimate_pass1` | 历史 KV chunk 处理 |
| Pass 1 当前 chunk | `xattn_estimate_pass1` | 当前 GPU 上的 K 处理 |
| Step 2 合并 | `merge_softmax_stats` | softmax stats 合并 |
| Pass 2 历史 blocks | `xattn_estimate_pass2` | 带全局 stats 的 block_sum |
| Pass 2 当前 chunk | `xattn_estimate_pass2` | 当前 chunk 的 block_sum |
| Step 4 block 选择 | `find_blocks_chunked` | 最终 block 选择 |
### 时序图(修复后)
```
slot_transfer_stream: [====H2D====]
compute_stream: |wait|[kernel1][kernel2]
数据已就绪
```
---
## 代码变更详情
### 1. Pass 1 历史 blocks 处理
```python
# Before (bug)
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
k_block = offload_engine.get_k_for_slot(slot) # 默认 stream
K_chunk = k_block.transpose(1, 2)
# ... compute ...
# After (fixed)
compute_stream = offload_engine.compute_stream
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
with torch.cuda.stream(compute_stream): # 显式指定 stream
k_block = offload_engine.get_k_for_slot(slot)
K_chunk = k_block.transpose(1, 2)
# ... compute ...
```
### 2. 移除 STRONG SYNC
`offload_engine.py` 中移除了不必要的强同步:
```python
# Removed from load_to_slot_layer() and load_k_only_to_slot_layer()
# STRONG SYNC: Synchronize all prefill offload streams before H2D
# for offload_stream in self.prefill_offload_streams:
# offload_stream.synchronize()
```
这些同步现在由 event 机制正确处理,不再需要阻塞式同步。
### 3. 其他清理
- 移除 DEBUG print 语句
- 移除 `torch.save()` debug 代码
- 合并多个 fallback 条件
-`chunk_size` 默认值从 16384 改为 4096匹配 offload Q chunk size
---
## 测试验证
### 测试命令
**GPU 0 - Offload 模式测试**:
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 10 \
--max-model-len 40960 \
--enable-offload \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9
```
**GPU 1 - GPU-only 模式测试**:
```bash
CUDA_VISIBLE_DEVICES=1 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Qwen3-0.6B \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--num-samples 10 \
--max-model-len 40960 \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9
```
### 测试结果
| 模式 | 模型 | Context | Samples | Pass Rate | Density |
|------|------|---------|---------|-----------|---------|
| Offload | Llama-3.1-8B | 32K | 10/10 | **100%** | 9.53% |
| GPU-only | Qwen3-0.6B | 32K | 10/10 | **100%** | 9.84% |
### Density 对齐验证
| 模式 | Layer 0 Density | 差异 |
|------|-----------------|------|
| GPU-only | 9.84% | - |
| Offload | 9.53% | ~3% |
~3% 的差异是预期的,因为两种模式的 KV 累积模式不同:
- GPU-only: 一次性处理所有 KV
- Offload: 分 chunk 处理,每个 chunk 独立计算 softmax stats 后合并
---
## 技术细节
### 三阶段 KV Chunking 流程
```
┌─────────────────────────────────────────────────────────────┐
│ Stage 1: softmax_compute_partial_stats │
│ └── 每个 KV chunk 独立计算 partial stats (m_i, l_i) │
│ │
│ Stage 2: merge_softmax_stats │
│ └── Host 端合并所有 chunks: (m_global, l_global) │
│ │
│ Stage 3: softmax_normalize_and_block_sum │
│ └── 使用全局 stats 归一化并计算 block sums │
└─────────────────────────────────────────────────────────────┘
```
### Stream 配置要求
| 操作类型 | Stream | 原因 |
|----------|--------|------|
| H2D 传输 | `slot_transfer_streams` | 异步传输,不阻塞计算 |
| D2H 传输 | `prefill_offload_streams` | 异步 offload不阻塞计算 |
| Estimate kernels | `compute_stream` | 与 attention 计算共享,确保同步 |
| Attention kernels | `compute_stream` | 主计算流 |
### Event 同步机制
```python
# H2D 传输完成后记录 event
self.ring_slot_ready[slot_idx].record(slot_transfer_stream)
# 计算前等待 H2D 完成
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])
# 计算完成后记录 event用于下一轮 H2D
self.ring_slot_compute_done[slot_idx].record(compute_stream)
```
---
## 相关文档
- [`docs/architecture_guide.md`](architecture_guide.md): Stream 配置和 ring buffer 架构
- [`docs/xattn_kv_chunking_kernels.md`](xattn_kv_chunking_kernels.md): 三阶段 softmax kernels
- [`docs/gpuonly_density_alignment_test.md`](gpuonly_density_alignment_test.md): Density 对齐测试
- [`docs/xattn_bsa_policy_design.md`](xattn_bsa_policy_design.md): XAttention BSA Policy 设计
---
## 经验总结
### 1. Stream 同步的隐蔽性
CUDA stream 同步 bug 很难发现:
- 数据可能"大部分时间"正确(取决于时序)
- 错误表现为随机/间歇性的结果偏差
- 需要精确的 debug logging 才能定位
### 2. Event vs Synchronize
| 方法 | 优点 | 缺点 |
|------|------|------|
| `stream.wait_event()` | 非阻塞,保持 pipeline | 只同步指定 stream |
| `stream.synchronize()` | 保证完成 | 阻塞整个 stream破坏 pipeline |
**最佳实践**: 使用 event 进行精确同步,避免 synchronize 阻塞。
### 3. 调试方法
```python
# 打印 tensor sum 验证数据一致性
print(f"K_chunk sum = {K_chunk.sum().item()}")
# 保存中间结果进行离线比较
torch.save({'K': K_chunk, 'layer': layer_id}, f'/tmp/debug_{pass}_{chunk}.pt')
```

View File

@@ -0,0 +1,170 @@
# XAttention Performance Analysis
本文档记录 XAttention 在不同配置下的性能分析结果,包括 NVTX 标记位置、block size 影响和性能瓶颈。
## NVTX 标记
XAttention 代码中添加了 NVTX 标记用于 nsys profiling便于分析 estimate 和 compute 阶段的性能。
### 标记位置
| 模式 | 标记名称 | 文件位置 | 说明 |
|------|---------|---------|------|
| GPU-only | `xattn_estimate` | `xattn_bsa.py:compute_prefill` | xattn_estimate 调用 |
| GPU-only | `xattn_bsa_compute` | `xattn_bsa.py:compute_prefill` | BSA kernel 调用 |
| Offload | `xattn_estimate_gemm` | `xattn_bsa.py:select_blocks` | flat_group_gemm 循环 |
| Offload | `xattn_estimate_softmax` | `xattn_bsa.py:select_blocks` | softmax_fuse_block_sum |
| Offload | `xattn_estimate_find_blocks` | `xattn_bsa.py:select_blocks` | find_blocks_chunked |
| Offload | `xattn_compute_historical` | `xattn_bsa.py:compute_chunked_prefill` | 历史 chunks attention |
| Offload | `xattn_compute_current` | `xattn_bsa.py:compute_chunked_prefill` | 当前 chunk attention |
| Offload | `xattn_compute_merge` | `xattn_bsa.py:compute_chunked_prefill` | merge 操作 |
### 查看 NVTX 统计
```bash
# 生成 profile
bash scripts/profile_offload.sh --policy xattn --ctx-len 64k --block-size 4096 --gpu 0
# 查看 NVTX 统计
nsys stats --report nvtx_pushpop_sum results/nsys/<filename>.nsys-rep
```
## Block Size 对 Offload 模式的影响
### 测试配置
- Model: Llama-3.1-8B-Instruct
- Context: 64K tokens
- Mode: xattn + offload
- GPU: A100 40GB
### 性能对比
| 指标 | block_size=4096 | block_size=1024 | 变化 |
|------|----------------|-----------------|------|
| **总时间** | 27.7s | 55.5s | **2x 慢** |
| **Chunks 数量** | 16 | 64 | 4x |
| **CPU blocks** | 18 | 71 | ~4x |
### 各阶段耗时分布
#### block_size=4096
| 阶段 | 占比 | 总时间 | 平均时间 | 调用次数 |
|-----|------|--------|---------|---------|
| **xattn_estimate_find_blocks** | **39.7%** | 18.0s | 37.6ms | 480 |
| xattn_compute_historical | 4.4% | 2.0s | 4.2ms | 480 |
| xattn_estimate_gemm | 3.4% | 1.5s | 3.2ms | 480 |
| xattn_compute_current | 0.2% | 113ms | 0.22ms | 512 |
| xattn_compute_merge | 0.2% | 96ms | 0.19ms | 512 |
| xattn_estimate_softmax | 0.2% | 88ms | 0.18ms | 480 |
#### block_size=1024
| 阶段 | 占比 | 总时间 | 平均时间 | 调用次数 |
|-----|------|--------|---------|---------|
| **xattn_estimate_gemm** | **23.6%** | 22.6s | 11.4ms | 1984 |
| **xattn_compute_historical** | **16.9%** | 16.2s | 8.0ms | 2016 |
| xattn_estimate_find_blocks | 1.4% | 1.3s | 0.66ms | 1984 |
| xattn_compute_current | 0.5% | 433ms | 0.21ms | 2048 |
| xattn_compute_merge | 0.4% | 373ms | 0.18ms | 2048 |
| xattn_estimate_softmax | 0.2% | 222ms | 0.11ms | 1984 |
### 关键发现
1. **Block size 对性能影响显著**
- block_size=1024 比 4096 慢约 2x
- 更小的 block size 导致更多的 chunks增加调用次数
2. **性能瓶颈随 block size 变化**
- **block_size=4096**: 瓶颈是 `find_blocks_chunked` (39.7%)
- **block_size=1024**: 瓶颈转移到 `estimate_gemm` (23.6%) 和 `compute_historical` (16.9%)
3. **Amortization 效应**
- 大 block size 虽然单次 `find_blocks` 更慢 (37.6ms vs 0.66ms)
- 但调用次数少 (480 vs 1984),总时间反而更少
4. **find_blocks_chunked 的特殊性**
- 该函数主要在 CPU 上执行 block 选择逻辑
- 处理更大的数据量时开销显著增加
- block_size=4096 时占用 40% 时间,是主要优化目标
## softmax_fuse_block_sum_kernel 性能分析
`softmax_fuse_block_sum_kernel_non_causal` 是 XAttention 估计阶段的核心 Triton kernel。
### Kernel 结构
```python
# 每个 thread block 处理的数据形状
工作负载: [block_size, segment_size] # 单个 Q block 对所有 K 的注意力
# Pass 1: 计算全局 softmax 参数 (m_i, l_i)
for iter in range(num_iters): # num_iters = k_len / segment_size
X = load [block_size, segment_size]
compute max, sum for softmax normalization
# Pass 2: Normalize + Block Sum
for iter in range(num_iters):
X = load [block_size, segment_size]
X = softmax(X)
X = reshape(X, [block_size, segment_size/block_size, block_size])
X = sum(X, axis=2) # → [block_size, segment_size/block_size]
X = sum(X, axis=0) # → [segment_size/block_size]
store output
```
### 性能随 block_size 变化的因素
| 因素 | 小 block_size (64) | 大 block_size (256) |
|------|-------------------|---------------------|
| Grid 并行度 | 高 (更多 blocks) | 低 (更少 blocks) |
| 寄存器使用 | 低 | 高 (可能 spill) |
| L2 Cache 复用 | 差 | 好 |
| 输出大小 | 大 | 小 |
### 典型性能曲线
```
Performance
│ ┌─────┐
│ / \
│ / \
│ / \
│ / \
└────/───────────────\────────→ block_size
64 128 256 512
最优点通常在 128-256 之间
```
## 优化建议
1. **优先使用 block_size=4096**
- 减少 chunk 数量,降低调度开销
- 更好的 amortization 效果
2. **优化 find_blocks_chunked**
- 当前是 block_size=4096 的主要瓶颈
- 考虑 GPU 加速或批量处理
3. **Pipeline 优化**
- 利用多 slot 的 ring buffer 实现计算和传输 overlap
- 当前已实现,但 find_blocks 是 CPU 操作,无法 overlap
## 测试命令
```bash
# GPU-only 模式 (需要 40GB+ VRAM)
bash scripts/profile_offload.sh --policy xattn --ctx-len 64k --no-offload --gpu 0
# Offload 模式block_size=4096
bash scripts/profile_offload.sh --policy xattn --ctx-len 64k --block-size 4096 --gpu 0
# Offload 模式block_size=1024
bash scripts/profile_offload.sh --policy xattn --ctx-len 64k --block-size 1024 --gpu 0
# 128K context
bash scripts/profile_offload.sh --policy xattn --ctx-len 128k --block-size 4096 --gpu 0
```

View File

@@ -1,109 +0,0 @@
# Findings: CUDA Graph for Offload Mode
## Discovery 1: 为什么 Offload Mode 不使用 CUDA Graph
**位置**: `nanovllm/engine/model_runner.py:421`
```python
use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
```
**原因**: `run_chunked_offload_decode()` 设置 `is_chunked_prefill=True`,强制使用 eager mode。
---
## Discovery 2: 当前 CUDA Graph 架构
**文件**: `model_runner.py:682-717`
```python
def capture_cudagraph(self):
# 为不同 batch size 捕获完整 model forward
for bs in [1, 2, 4, 8, 16, ...]:
with torch.cuda.graph(graph):
outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
```
**特点**:
- 捕获完整的 `model()` 调用(包含所有层)
- 使用 graph pool 共享内存
- 只用于 decodeprefill 始终 eager
---
## Discovery 3: Offload Decode 的 Attention 流程
**文件**: `nanovllm/kvcache/sparse/full_policy.py:304-379`
**Ring Buffer Pipeline**:
```
1. 预加载前 N 个 blocks 到 GPU slots
2. 对每个 block:
a. wait_slot_layer() # 等待 H2D
b. get_kv_for_slot() # 获取 KV
c. flash_attn_with_lse() # ⭐ 可 graph
d. record_slot_compute_done()
e. load_next_block() # 启动下一个 H2D
f. merge_attention_outputs() # ⭐ 可 graph但动态
```
**关键**: H2D 传输不能 graph但 attention 计算可以。
---
## Discovery 4: 验证 Graph 复用可行性
**测试**: `tests/test_chunk_attention_graph_reuse.py`
**结论**:
- 只需 2 个 graphcausal + non-causal
- 通过 `copy_()` 更新 static tensors
- 可复用于所有层和所有 chunk pairs
**测试结果**:
```
Layer 0: max_diff=3.91e-03 ✅
Layer 1: max_diff=7.81e-03 ✅
Layer 2: max_diff=3.91e-03 ✅
✅ PASSED
```
---
## Discovery 5: Chunk Size 和 Block Size 关系
**观察**:
- Prefilled blocks 的 KV size = `block_size`
- Decode buffer 的 KV size = `1``block_size`(动态)
**Graph 策略**:
- Prefilled blocks: 固定 size = block_size适合 graph
- Decode buffer: 动态 size建议保持 eager
---
## Discovery 6: 使用的 Triton 算子
**文件**: `nanovllm/ops/chunked_attention.py`
| 算子 | 功能 | 可 Graph |
|------|------|----------|
| `flash_attn_with_lse()` | Attention + LSE | ✅ |
| `merge_attention_outputs()` | 合并两个 attention 输出 | ✅ |
这两个算子是纯 GPU 计算,可以被 CUDA Graph 捕获。
---
## Discovery 7: 数据依赖分析
**Attention 输入**:
- `q`: 来自当前层的 QKV projectionshape 固定
- `k, v`: 来自 GPU slotH2D 传输后shape = [1, block_size, heads, dim]
**依赖链**:
```
H2D(block) → wait() → get_kv() → copy_to_static() → graph.replay() → clone_output()
```
**关键**: Graph 只封装 attention 计算,不包含数据传输。

View File

@@ -22,7 +22,7 @@ class Config:
tensor_parallel_size: int = 1
enforce_eager: bool = False
hf_config: AutoConfig | None = None
eos: int = -1
eos: int | list[int] = -1 # Single EOS token or list of EOS tokens (e.g., GLM-4)
kvcache_block_size: int = 1024
num_kvcache_blocks: int = -1
dtype: str | None = None # "float16", "bfloat16", or None (use model default)
@@ -48,16 +48,20 @@ class Config:
# XAttention BSA specific parameters
sparse_block_size: int = 128 # Block size for BSA (tokens per block)
sparse_samples_per_chunk: int = 128 # Samples per chunk for estimation
sparse_threshold: float = 0.9 # Cumulative attention threshold (0-1)
sparse_threshold: float = 0.95 # Cumulative attention threshold (tau in XAttention)
sparse_use_triton: bool = True # Use Triton kernels for estimation
sparse_stride: int = 8 # Stride for Q/K downsampling
sparse_chunk_size: int = 16384 # Triton kernel chunk size for estimation
def __post_init__(self):
assert os.path.isdir(self.model)
assert self.kvcache_block_size % 256 == 0
assert 1 <= self.tensor_parallel_size <= 8
self.hf_config = AutoConfig.from_pretrained(self.model)
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
self.hf_config = AutoConfig.from_pretrained(self.model, trust_remote_code=True)
# Get max position embeddings (GLM-4 uses seq_length instead of max_position_embeddings)
max_pos = getattr(self.hf_config, 'max_position_embeddings',
getattr(self.hf_config, 'seq_length', 4096))
self.max_model_len = min(self.max_model_len, max_pos)
assert self.max_num_batched_tokens >= self.max_model_len
# Override torch_dtype if user specified

View File

@@ -10,7 +10,8 @@ from nanovllm.sampling_params import SamplingParams
from nanovllm.engine.sequence import Sequence
from nanovllm.engine.scheduler import Scheduler
from nanovllm.engine.model_runner import ModelRunner
from nanovllm.utils.observer import Observer
from nanovllm.utils.observer import InferenceObserver
from nanovllm.utils.memory_observer import MemoryObserver
class LLMEngine:
@@ -29,7 +30,13 @@ class LLMEngine:
self.ps.append(process)
self.events.append(event)
self.model_runner = ModelRunner(config, 0, self.events)
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True, trust_remote_code=True)
# Get EOS token(s) from config (may be int or list, e.g., GLM-4 uses list)
# Prefer hf_config.eos_token_id which contains full list, fallback to tokenizer
eos_from_config = getattr(config.hf_config, 'eos_token_id', None)
if eos_from_config is not None:
config.eos = eos_from_config
else:
config.eos = self.tokenizer.eos_token_id
# Set Sequence.block_size to match the KV cache block size
Sequence.block_size = config.kvcache_block_size
@@ -58,15 +65,18 @@ class LLMEngine:
print(f"[DEBUG LLMEngine.step] Mode={mode}, active_sequences={len(seqs)}")
if not is_prefill:
# The end of the prefill mode. Get TTFT.
if Observer.ttft_start != 0:
Observer.ttft = perf_counter_ns() - Observer.ttft_start
Observer.reset_ttft()
# The start of the decode mode. Get TPOT.
if Observer.tpot_start != 0:
Observer.tpot = perf_counter_ns() - Observer.tpot_start
Observer.tpot_start = perf_counter_ns()
# Decode mode: calculate TPOT from previous decode step
if InferenceObserver.tpot_start != 0:
InferenceObserver.tpot = perf_counter_ns() - InferenceObserver.tpot_start
InferenceObserver.tpot_start = perf_counter_ns()
token_ids = self.model_runner.call("run", seqs, is_prefill)
if is_prefill:
# Calculate TTFT after prefill completes (including chunked prefill)
if InferenceObserver.ttft_start != 0:
InferenceObserver.ttft = perf_counter_ns() - InferenceObserver.ttft_start
InferenceObserver.reset_ttft()
self.scheduler.postprocess(seqs, token_ids)
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
@@ -91,7 +101,8 @@ class LLMEngine:
log_level = os.environ.get('NANOVLLM_LOG_LEVEL', 'INFO')
debug_enabled = log_level.upper() == 'DEBUG'
Observer.complete_reset()
InferenceObserver.complete_reset()
MemoryObserver.complete_reset()
if use_tqdm:
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
if not isinstance(sampling_params, list):
@@ -128,8 +139,8 @@ class LLMEngine:
pbar.set_postfix({
"Prefill": f"{int(prefill_throughput)}tok/s",
"Decode": f"{int(decode_throughput)}tok/s",
"ttft": f"{float(Observer.ttft) / 1e6}ms",
"tpot": f"{float(Observer.tpot) / 1e6}ms",
"ttft": f"{float(InferenceObserver.ttft) / 1e6}ms",
"tpot": f"{float(InferenceObserver.tpot) / 1e6}ms",
})
for seq_id, token_ids in output:
outputs[seq_id] = token_ids

View File

@@ -10,6 +10,7 @@ from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence
from nanovllm.models import get_model_class
from nanovllm.layers.sampler import GreedySampler
from nanovllm.layers.graphed_layers import OffloadGraphManager
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.loader import load_model
from nanovllm.utils.logger import get_logger
@@ -29,6 +30,18 @@ def _find_free_port() -> int:
return s.getsockname()[1]
def get_num_kv_heads(hf_config) -> int:
"""Get number of KV heads from config (handles GLM-4's multi_query_group_num)."""
return getattr(hf_config, 'num_key_value_heads',
getattr(hf_config, 'multi_query_group_num', hf_config.num_attention_heads))
def get_head_dim(hf_config) -> int:
"""Get head dimension from config (handles GLM-4's kv_channels)."""
return getattr(hf_config, "head_dim",
getattr(hf_config, "kv_channels", hf_config.hidden_size // hf_config.num_attention_heads))
class ModelRunner:
def __init__(self, config: Config, rank: int, event: Event | list[Event]):
@@ -63,6 +76,12 @@ class ModelRunner:
self.allocate_kv_cache()
if not self.enforce_eager:
self.capture_cudagraph()
# Initialize offload graph manager if CPU offload is enabled
self.offload_graph_manager = None
if config.enable_cpu_offload and not self.enforce_eager:
self.init_offload_graph_manager()
torch.set_default_device("cpu")
torch.set_default_dtype(default_dtype)
@@ -137,8 +156,8 @@ class ModelRunner:
used = total - free
peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
num_kv_heads = hf_config.num_key_value_heads // self.world_size
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
num_kv_heads = get_num_kv_heads(hf_config) // self.world_size
head_dim = get_head_dim(hf_config)
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
# Calculate max GPU blocks based on available memory
@@ -195,19 +214,37 @@ class ModelRunner:
dtype=hf_config.torch_dtype,
)
# Initialize sparse policy if manager has one (CPU offload mode)
# Initialize sparse policy if manager has one (works for both CPU offload and GPU-only modes)
if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
# Use CPU blocks for offload mode, GPU blocks for GPU-only mode
num_blocks_for_init = config.num_cpu_kvcache_blocks if config.enable_cpu_offload else config.num_kvcache_blocks
self.kvcache_manager.sparse_policy.initialize(
num_layers=hf_config.num_hidden_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
num_cpu_blocks=config.num_cpu_kvcache_blocks,
num_cpu_blocks=num_blocks_for_init,
dtype=hf_config.torch_dtype,
device=torch.device("cuda"),
)
# Pre-allocate policy metadata buffers
# - Offload mode: allocate chunked prefill buffers (mask, KV chunking stats)
# - GPU-only mode: additionally allocate GQA expansion buffers
num_heads = hf_config.num_attention_heads // self.world_size
self.kvcache_manager.sparse_policy.alloc_policy_metadata(
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
max_seq_len=config.max_model_len,
dtype=hf_config.torch_dtype,
device=torch.device("cuda"),
enable_cpu_offload=config.enable_cpu_offload,
)
# Log policy info (handle both enum and None cases)
policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL"
logger.info(
f"Sparse policy initialized: {config.sparse_policy.name} "
f"Sparse policy initialized: {policy_name} "
f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
)
@@ -368,7 +405,16 @@ class ModelRunner:
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
set_context(
is_prefill=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
slot_mapping=slot_mapping,
block_tables=block_tables,
kvcache_manager=getattr(self, 'kvcache_manager', None),
)
return input_ids, positions
def prepare_decode(self, seqs: list[Sequence]):
@@ -397,7 +443,13 @@ class ModelRunner:
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# Use GPU physical block tables for attention
block_tables = self._prepare_gpu_block_tables(gpu_block_tables)
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
set_context(
is_prefill=False,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
kvcache_manager=self.kvcache_manager,
)
return input_ids, positions
def _prepare_gpu_block_tables(self, gpu_block_tables: list[list[int]]):
@@ -536,6 +588,13 @@ class ModelRunner:
break
#> Run model forward
# Use graph-optimized forward if available (chunk_size == block_size), otherwise eager mode
if (hasattr(self, 'prefill_graph_manager') and
self.prefill_graph_manager is not None and
self.prefill_graph_manager.captured and
input_ids.shape[0] == self.block_size):
logits = self.run_prefill_with_offload_graph(input_ids, positions)
else:
logits = self.run_model(input_ids, positions, is_prefill=True)
reset_context()
@@ -657,6 +716,7 @@ class ModelRunner:
)
# Run model forward pass
# TODO: Phase 5 decode graph needs shape fix, use eager mode for now
logits = self.run_model(input_ids, positions, is_prefill=False)
reset_context()
@@ -698,7 +758,13 @@ class ModelRunner:
for bs in reversed(self.graph_bs):
graph = torch.cuda.CUDAGraph()
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
set_context(
is_prefill=False,
slot_mapping=slot_mapping[:bs],
context_lens=context_lens[:bs],
block_tables=block_tables[:bs],
kvcache_manager=self.kvcache_manager,
)
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
with torch.cuda.graph(graph, self.graph_pool):
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
@@ -716,3 +782,151 @@ class ModelRunner:
block_tables=block_tables,
outputs=outputs,
)
@torch.inference_mode()
def init_offload_graph_manager(self):
"""
Initialize and capture CUDA Graphs for offload path (Prefill + Decode).
Phase 5 Design:
- Creates N+2 graphs for both Prefill and Decode
- Decode graphs: seq_len=1
- Prefill graphs: seq_len=chunk_size (block_size)
Graph structure per mode:
- EmbedGraph: embed_tokens
- FirstGraph: input_norm → qkv_proj → rotary
- InterGraph[i]: o_proj → post_norm → mlp → input_norm → qkv_proj → rotary (N-1 graphs)
- LastGraph: o_proj → post_norm → mlp → final_norm
"""
hf_config = self.config.hf_config
num_kv_heads = get_num_kv_heads(hf_config) // self.world_size
head_dim = get_head_dim(hf_config)
# Create Decode Graph Manager (seq_len=1)
self.decode_graph_manager = OffloadGraphManager(
model=self.model,
seq_len=1,
hidden_size=hf_config.hidden_size,
num_heads=hf_config.num_attention_heads // self.world_size,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=hf_config.torch_dtype,
)
self.decode_graph_manager.capture_all()
# Create Prefill Graph Manager (seq_len=chunk_size)
chunk_size = self.block_size # chunk_size = block_size = 1024
self.prefill_graph_manager = OffloadGraphManager(
model=self.model,
seq_len=chunk_size,
hidden_size=hf_config.hidden_size,
num_heads=hf_config.num_attention_heads // self.world_size,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=hf_config.torch_dtype,
)
self.prefill_graph_manager.capture_all()
# Legacy compatibility (for backward compatibility)
self.offload_graph_manager = self.decode_graph_manager
logger.info(
f"Offload CUDA Graphs captured: {self.decode_graph_manager.num_graphs} decode graphs + "
f"{self.prefill_graph_manager.num_graphs} prefill graphs "
f"({self.decode_graph_manager.num_layers} layers)"
)
@torch.inference_mode()
def run_model_with_offload_graph(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
"""
Run decode with Phase 5 CUDA Graph optimization.
Graph coverage (~70-80% of computation):
- GRAPH_EMBED: embed_tokens
- GRAPH_FIRST: input_norm_0 → qkv_proj_0 → rotary_0
- GRAPH_INTER_i: o_proj_i → post_norm_i → mlp_i → input_norm_{i+1} → qkv_proj_{i+1} → rotary_{i+1}
- GRAPH_LAST: o_proj_{N-1} → post_norm_{N-1} → mlp_{N-1} → final_norm
EAGER (only attention core with offload):
- attn.forward(q, k, v) for each layer
"""
gm = self.decode_graph_manager
layers = self.model.model.layers
num_layers = len(layers)
use_graph = input_ids.shape[0] == 1 # Only use graph for batch=1
# GRAPH_EMBED: embed_tokens
hidden_states = gm.embed_graph(input_ids, use_graph=use_graph)
# GRAPH_FIRST: input_norm_0 → qkv_proj_0 → rotary_0
q, k, v, residual = gm.first_graph(hidden_states, positions, use_graph=use_graph)
for i in range(num_layers):
# EAGER: Attention core only (with offload)
# Note: attn.forward already handles store_kvcache internally
attn_output = layers[i].self_attn.attn(q, k, v)
# attn.forward returns [batch, 1, num_heads, head_dim] for decode
# graph expects [seq_len, num_heads, head_dim], so squeeze to [1, heads, dim]
if attn_output.dim() == 4:
attn_output = attn_output.squeeze(0).squeeze(0).unsqueeze(0)
if i < num_layers - 1:
# GRAPH_INTER_i: o_proj_i → post_norm_i → mlp_i → input_norm_{i+1} → qkv_proj_{i+1} → rotary_{i+1}
q, k, v, residual = gm.inter_graphs[i](
attn_output, residual, positions, use_graph=use_graph
)
else:
# GRAPH_LAST: o_proj_{N-1} → post_norm_{N-1} → mlp_{N-1} → final_norm
hidden_states = gm.last_graph(attn_output, residual, use_graph=use_graph)
return self.model.compute_logits(hidden_states)
@torch.inference_mode()
def run_prefill_with_offload_graph(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
"""
Run chunked prefill with Phase 5 CUDA Graph optimization.
Graph coverage (~70-80% of computation):
- GRAPH_EMBED: embed_tokens
- GRAPH_FIRST: input_norm_0 → qkv_proj_0 → rotary_0
- GRAPH_INTER_i: o_proj_i → post_norm_i → mlp_i → input_norm_{i+1} → qkv_proj_{i+1} → rotary_{i+1}
- GRAPH_LAST: o_proj_{N-1} → post_norm_{N-1} → mlp_{N-1} → final_norm
EAGER (only attention core with offload):
- attn.forward(q, k, v) for each layer
"""
gm = self.prefill_graph_manager
layers = self.model.model.layers
num_layers = len(layers)
use_graph = input_ids.shape[0] == self.block_size # Only use graph for chunk_size
# GRAPH_EMBED: embed_tokens
hidden_states = gm.embed_graph(input_ids, use_graph=use_graph)
# GRAPH_FIRST: input_norm_0 → qkv_proj_0 → rotary_0
q, k, v, residual = gm.first_graph(hidden_states, positions, use_graph=use_graph)
for i in range(num_layers):
# EAGER: Attention core only (with offload)
# Note: attn.forward already handles store_kvcache internally
attn_output = layers[i].self_attn.attn(q, k, v)
if i < num_layers - 1:
# GRAPH_INTER_i: o_proj_i → post_norm_i → mlp_i → input_norm_{i+1} → qkv_proj_{i+1} → rotary_{i+1}
q, k, v, residual = gm.inter_graphs[i](
attn_output, residual, positions, use_graph=use_graph
)
else:
# GRAPH_LAST: o_proj_{N-1} → post_norm_{N-1} → mlp_{N-1} → final_norm
hidden_states = gm.last_graph(attn_output, residual, use_graph=use_graph)
return self.model.compute_logits(hidden_states)

View File

@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING
from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence, SequenceStatus
from nanovllm.utils.observer import Observer
from nanovllm.utils.observer import InferenceObserver
if TYPE_CHECKING:
from nanovllm.kvcache import KVCacheManager
@@ -15,7 +15,9 @@ class Scheduler:
def __init__(self, config: Config, kvcache_manager: "KVCacheManager"):
self.max_num_seqs = config.max_num_seqs
self.max_num_batched_tokens = config.max_num_batched_tokens
self.eos = config.eos
# Convert EOS to set for efficient lookup (supports single int or list)
eos = config.eos
self.eos_set = set(eos) if isinstance(eos, list) else {eos}
self.kvcache_manager = kvcache_manager
self.waiting: deque[Sequence] = deque()
self.running: deque[Sequence] = deque()
@@ -32,8 +34,8 @@ class Scheduler:
num_seqs = 0
num_batched_tokens = 0
while self.waiting and num_seqs < self.max_num_seqs:
if Observer.ttft_start == 0:
Observer.ttft_start = perf_counter_ns()
if InferenceObserver.ttft_start == 0:
InferenceObserver.ttft_start = perf_counter_ns()
seq = self.waiting[0]
# Check if sequence is too large
@@ -94,7 +96,7 @@ class Scheduler:
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
for seq, token_id in zip(seqs, token_ids):
seq.append_token(token_id)
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
if (not seq.ignore_eos and token_id in self.eos_set) or seq.num_completion_tokens == seq.max_tokens:
seq.status = SequenceStatus.FINISHED
self.kvcache_manager.deallocate(seq)
self.running.remove(seq)

View File

@@ -25,7 +25,7 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
Factory function to create the appropriate KV cache manager.
Decision logic:
1. If enable_cpu_offload=False: use GPUOnlyManager
1. If enable_cpu_offload=False: use GPUOnlyManager (optionally with sparse policy)
2. If enable_cpu_offload=True but all blocks fit in GPU: use GPUOnlyManager
3. If enable_cpu_offload=True and need CPU blocks: use HybridKVCacheManager
@@ -37,9 +37,44 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
"""
if not getattr(config, 'enable_cpu_offload', False):
# Default: pure GPU mode
# Check if sparse policy is requested for GPU-only mode
from nanovllm.config import SparsePolicyType
sparse_policy_type = getattr(config, 'sparse_policy', None)
# Handle None case - use FULL as default
if sparse_policy_type is None:
sparse_policy_type = SparsePolicyType.FULL
sparse_policy = None
if sparse_policy_type != SparsePolicyType.FULL:
# Create sparse policy for GPU-only mode
from nanovllm.kvcache.sparse import create_sparse_policy
policy_kwargs = {}
if sparse_policy_type == SparsePolicyType.QUEST:
policy_kwargs = {
'topk_blocks': getattr(config, 'sparse_topk_blocks', 8),
'threshold_blocks': getattr(config, 'sparse_threshold_blocks', 4),
}
elif sparse_policy_type == SparsePolicyType.XATTN_BSA:
policy_kwargs = {
'block_size': getattr(config, 'sparse_block_size', 128),
'samples_per_chunk': getattr(config, 'sparse_samples_per_chunk', 128),
'threshold': getattr(config, 'sparse_threshold', 0.9),
'use_triton': getattr(config, 'sparse_use_triton', True),
'stride': getattr(config, 'sparse_stride', 8),
'chunk_size': getattr(config, 'sparse_chunk_size', 16384),
}
sparse_policy = create_sparse_policy(sparse_policy_type, **policy_kwargs)
else:
# FULL policy for GPU-only mode - always create for consistent API
from nanovllm.kvcache.sparse import FullAttentionPolicy
sparse_policy = FullAttentionPolicy()
return GPUOnlyManager(
num_blocks=config.num_kvcache_blocks,
block_size=config.kvcache_block_size,
sparse_policy=sparse_policy,
)
# CPU offload is enabled
@@ -79,6 +114,7 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
'threshold': getattr(config, 'sparse_threshold', 0.9),
'use_triton': getattr(config, 'sparse_use_triton', True),
'stride': getattr(config, 'sparse_stride', 8),
'chunk_size': getattr(config, 'sparse_chunk_size', 16384),
}
sparse_policy = create_sparse_policy(sparse_policy_type, **policy_kwargs)

View File

@@ -7,13 +7,16 @@ the KVCacheManager interface.
"""
from collections import deque
from typing import List, Tuple, Dict, Optional
from typing import List, Tuple, Dict, Optional, TYPE_CHECKING
import torch
from torch import Tensor
from nanovllm.engine.sequence import Sequence
from nanovllm.kvcache.base_manager import KVCacheManager
if TYPE_CHECKING:
from nanovllm.kvcache.sparse.policy import SparsePolicy
class Block:
"""Physical block in GPU memory."""
@@ -50,17 +53,28 @@ class GPUOnlyManager(KVCacheManager):
all data stays on GPU at fixed addresses.
"""
def __init__(self, num_blocks: int, block_size: int):
def __init__(
self,
num_blocks: int,
block_size: int,
sparse_policy: Optional["SparsePolicy"] = None,
):
"""
Initialize GPU-only manager.
Args:
num_blocks: Total number of blocks to manage
block_size: Tokens per block (default 256)
sparse_policy: Optional sparse attention policy for GPU-only mode
"""
self._block_size = block_size
self._num_blocks = num_blocks
# Sparse policy for GPU-only mode (optional)
self.sparse_policy = sparse_policy
# No offload engine in GPU-only mode
self.offload_engine = None
# Block metadata
self.blocks: List[Block] = [Block(i) for i in range(num_blocks)]

View File

@@ -9,6 +9,7 @@ Key design principles for CUDA Graph compatibility:
import torch
import torch.cuda.nvtx
import nvtx
from torch import Tensor
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
@@ -16,6 +17,7 @@ from dataclasses import dataclass
from nanovllm.kvcache.kernels import gathered_copy_kv
from nanovllm.comm import memcpy_2d_async
from nanovllm.utils.logger import get_logger
from nanovllm.utils.memory_observer import MemoryObserver
# Import for type hints only (avoid circular import)
from typing import TYPE_CHECKING
@@ -374,7 +376,10 @@ class OffloadEngine:
"""
self.ring_slot_compute_done[slot_idx].record()
def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
def load_to_slot_layer(
self, slot_idx: int, layer_id: int, cpu_block_id: int, chunk_idx: int = -1,
is_prefill: bool = True,
) -> None:
"""
Async load a single CPU block to a ring buffer slot for one layer.
@@ -389,13 +394,21 @@ class OffloadEngine:
slot_idx: Target GPU slot index
layer_id: Layer index to load (for CPU cache indexing)
cpu_block_id: Source CPU block ID
chunk_idx: Optional chunk index for NVTX labeling (-1 means not specified)
is_prefill: True if in prefill phase, False if in decode phase (for MemoryObserver)
"""
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
# Use per-slot stream for parallel transfers across different slots
stream = self.slot_transfer_streams[slot_idx]
torch.cuda.nvtx.range_push(f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]")
# Build NVTX label with optional chunk info
if chunk_idx >= 0:
nvtx_label = f"H2D: L{layer_id} Chunk{chunk_idx} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
else:
nvtx_label = f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
nvtx.push_range(message=nvtx_label, color="blue")
with torch.cuda.stream(stream):
# Wait for previous compute on this slot to complete before overwriting
# This prevents data race: transfer must not start until attention finishes reading
@@ -413,7 +426,66 @@ class OffloadEngine:
self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.ring_slot_ready[slot_idx].record(stream)
torch.cuda.nvtx.range_pop()
nvtx.pop_range()
# Record H2D transfer: K + V = 2 * block_bytes
MemoryObserver.record_h2d(2 * self.gpu_block_bytes, is_prefill=is_prefill)
def load_k_only_to_slot_layer(
self, slot_idx: int, layer_id: int, cpu_block_id: int, chunk_idx: int = -1,
is_prefill: bool = True,
) -> None:
"""
Async load only K (not V) from CPU block to GPU slot.
Used by XAttention estimate phase which only needs K for attention score
computation. Saves 50% communication compared to loading K+V.
Args:
slot_idx: Target GPU slot index
layer_id: Layer index to load (for CPU cache indexing)
cpu_block_id: Source CPU block ID
chunk_idx: Optional chunk index for NVTX labeling (-1 means not specified)
is_prefill: True if in prefill phase, False if in decode phase
"""
logger.debug(f"Ring load K-only: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
stream = self.slot_transfer_streams[slot_idx]
if chunk_idx >= 0:
nvtx_label = f"H2D K-only: L{layer_id} Chunk{chunk_idx} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
else:
nvtx_label = f"H2D K-only: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
nvtx.push_range(message=nvtx_label, color="cyan")
with torch.cuda.stream(stream):
stream.wait_event(self.ring_slot_compute_done[slot_idx])
stream.wait_event(self.ring_slot_offload_done[slot_idx])
# Only copy K, not V
self.k_cache_gpu[slot_idx].copy_(
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.ring_slot_ready[slot_idx].record(stream)
nvtx.pop_range()
# Record H2D transfer: K only = 1 * block_bytes
MemoryObserver.record_h2d(self.gpu_block_bytes, is_prefill=is_prefill)
def get_k_for_slot(self, slot_idx: int) -> Tensor:
"""
Get only K for a ring buffer slot (no V).
Used by XAttention estimate phase which only needs K for attention
score computation.
Args:
slot_idx: GPU slot index
Returns:
k_cache, shape: [1, block_size, kv_heads, head_dim]
"""
return self.k_cache_gpu[slot_idx].unsqueeze(0)
def wait_slot_layer(self, slot_idx: int) -> None:
"""
@@ -470,7 +542,8 @@ class OffloadEngine:
else:
self.sparse_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
nvtx_label = f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]"
nvtx.push_range(message=nvtx_label, color="green")
with torch.cuda.stream(self.transfer_stream_main):
# Wait for both compute_stream and default stream
# - compute_stream: for flash attention operations
@@ -486,7 +559,10 @@ class OffloadEngine:
self.v_cache_gpu[slot_idx], non_blocking=True
)
self.ring_slot_offload_done[slot_idx].record(self.transfer_stream_main)
torch.cuda.nvtx.range_pop()
nvtx.pop_range()
# Record D2H transfer: K + V = 2 * block_bytes
MemoryObserver.record_d2h(2 * self.gpu_block_bytes, is_prefill=is_prefill)
# ----- KV access methods for ring buffer -----
@@ -702,6 +778,69 @@ class OffloadEngine:
v = self.prefill_v_buffer[layer_id, :num_tokens].unsqueeze(0)
return k, v
def write_to_prefill_buffer(
self,
layer_id: int,
k: Tensor,
v: Tensor,
chunk_idx: int = -1,
) -> None:
"""
Write KV tensors to prefill buffer (D2D copy within GPU).
This is called during chunked prefill to store current chunk's KV
before computing attention.
Args:
layer_id: Layer index
k: Key tensor [num_tokens, kv_heads, head_dim]
v: Value tensor [num_tokens, kv_heads, head_dim]
chunk_idx: Current chunk index for NVTX labeling (-1 = not specified)
"""
num_tokens = k.shape[0]
# Build NVTX label
if chunk_idx >= 0:
nvtx_label = f"D2D: L{layer_id} Chunk{chunk_idx} WritePrefillBuffer"
else:
nvtx_label = f"D2D: L{layer_id} WritePrefillBuffer"
torch.cuda.nvtx.range_push(nvtx_label)
self.prefill_k_buffer[layer_id, :num_tokens].copy_(k)
self.prefill_v_buffer[layer_id, :num_tokens].copy_(v)
torch.cuda.nvtx.range_pop()
# Record D2D transfer: K + V
transfer_bytes = 2 * k.numel() * k.element_size()
MemoryObserver.record_d2d(transfer_bytes)
def write_to_decode_buffer(
self,
layer_id: int,
pos_in_block: int,
k: Tensor,
v: Tensor,
) -> None:
"""
Write KV tensors to decode buffer (D2D copy within GPU).
This is called during chunked decode to store current decode token's KV.
Args:
layer_id: Layer index
pos_in_block: Position within the current block
k: Key tensor [kv_heads, head_dim] (single token, squeezed)
v: Value tensor [kv_heads, head_dim] (single token, squeezed)
"""
torch.cuda.nvtx.range_push(f"D2D: L{layer_id} Pos{pos_in_block} WriteDecodeBuffer")
self.decode_k_buffer[layer_id, pos_in_block].copy_(k)
self.decode_v_buffer[layer_id, pos_in_block].copy_(v)
torch.cuda.nvtx.range_pop()
# Record D2D transfer: K + V (single token)
transfer_bytes = 2 * k.numel() * k.element_size()
MemoryObserver.record_d2d(transfer_bytes)
def offload_prefill_buffer_async(
self,
layer_id: int,
@@ -729,7 +868,8 @@ class OffloadEngine:
# Use per-layer stream for parallel offloads
stream = self.prefill_offload_streams[layer_id]
torch.cuda.nvtx.range_push(f"AsyncPrefillOffload: L{layer_id}->CPU[{cpu_block_id}]")
nvtx_label = f"D2H: PrefillBuffer L{layer_id}->CPU[{cpu_block_id}]"
nvtx.push_range(message=nvtx_label, color="orange")
with torch.cuda.stream(stream):
# Wait for compute to finish writing to prefill buffer
stream.wait_stream(self.compute_stream)
@@ -744,7 +884,10 @@ class OffloadEngine:
# Record completion event
self.prefill_offload_events[layer_id].record(stream)
torch.cuda.nvtx.range_pop()
nvtx.pop_range()
# Record D2H transfer: K + V = 2 * block_bytes
MemoryObserver.record_d2h(2 * self.gpu_block_bytes, is_prefill=True)
def wait_all_prefill_offloads(self) -> None:
"""Wait for all prefill buffer offloads to complete."""
@@ -784,6 +927,11 @@ class OffloadEngine:
v_sample = self.v_cache_cpu[
layer_id, cpu_block_id, :num_samples
].clone().cuda()
# Record H2D transfer: K + V samples
transfer_bytes = 2 * k_sample.numel() * k_sample.element_size()
MemoryObserver.record_h2d(transfer_bytes, is_prefill=True)
return k_sample, v_sample
def load_block_full_from_cpu(
@@ -810,4 +958,8 @@ class OffloadEngine:
v_full = self.v_cache_cpu[
layer_id, cpu_block_id
].clone().cuda()
# Record H2D transfer: K + V full block
MemoryObserver.record_h2d(2 * self.gpu_block_bytes, is_prefill=True)
return k_full, v_full

View File

@@ -61,6 +61,9 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic
block_size=kwargs.get("block_size", 128),
samples_per_chunk=kwargs.get("samples_per_chunk", 128),
threshold=kwargs.get("threshold", 0.9),
stride=kwargs.get("stride", 8),
chunk_size=kwargs.get("chunk_size", 16384),
use_triton=kwargs.get("use_triton", True),
)
else:

View File

@@ -37,15 +37,116 @@ class FullAttentionPolicy(SparsePolicy):
supports_prefill = True
supports_decode = True
def __init__(self):
"""Initialize with statistics tracking."""
self._stats_total_blocks = 0
self._stats_num_chunks = 0
def select_blocks(
self,
available_blocks: List[int],
offload_engine: "OffloadEngine",
ctx: PolicyContext,
q: torch.Tensor,
k: torch.Tensor,
) -> List[int]:
"""Return all blocks - no sparsity."""
# Update statistics (only for layer 0 to avoid overcounting)
if ctx.layer_id == 0 and available_blocks:
self._stats_total_blocks += len(available_blocks)
self._stats_num_chunks += 1
logger.debug(f"[Full] chunk={ctx.query_chunk_idx}, blocks={len(available_blocks)}, density=100.0%")
return available_blocks
def reset_stats(self) -> None:
"""Reset density statistics."""
self._stats_total_blocks = 0
self._stats_num_chunks = 0
def get_density_stats(self) -> dict:
"""Get density statistics."""
return {
"total_available_blocks": self._stats_total_blocks,
"total_selected_blocks": self._stats_total_blocks, # Full = all selected
"num_chunks": self._stats_num_chunks,
"overall_density": 1.0, # Always 100%
}
def print_density_stats(self) -> None:
"""Print density statistics summary."""
stats = self.get_density_stats()
logger.info(f"[Full Policy] Density Stats: chunks={stats['num_chunks']}, "
f"blocks={stats['total_available_blocks']}, density=100.0%")
# =========================================================================
# GPU-only methods (non-chunked)
# =========================================================================
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
layer_id: int,
block_tables: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
GPU-only prefill attention using flash_attn_varlen_func.
This is the simplest implementation - just call flash attention directly.
For sparse policies, this method would implement block selection.
"""
from flash_attn import flash_attn_varlen_func
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=True,
block_table=block_tables,
)
def compute_decode(
self,
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
cache_seqlens: torch.Tensor,
softmax_scale: float,
layer_id: int,
block_tables: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
GPU-only decode attention using flash_attn_with_kvcache.
This is the simplest implementation - just call flash attention directly.
For sparse policies, this method would implement block selection.
"""
from flash_attn import flash_attn_with_kvcache
# q is [batch, num_heads, head_dim], need to add seq dim
return flash_attn_with_kvcache(
q.unsqueeze(1), # [batch, 1, heads, dim]
k_cache,
v_cache,
cache_seqlens=cache_seqlens,
block_table=block_tables,
softmax_scale=softmax_scale,
causal=True,
)
# =========================================================================
# Chunked offload methods
# =========================================================================
def compute_chunked_prefill(
self,
q: torch.Tensor,
@@ -58,16 +159,17 @@ class FullAttentionPolicy(SparsePolicy):
current_chunk_idx: int,
seq: "Sequence",
num_tokens: int,
selected_blocks: List[int],
) -> torch.Tensor:
"""
Compute full attention for chunked prefill.
This method handles the complete chunked prefill flow:
1. Get historical blocks
2. Select blocks via select_blocks
3. Load and compute attention to historical chunks
4. Compute attention to current chunk
5. Merge all results
This method handles the chunked prefill computation:
1. Load and compute attention to historical chunks (using selected_blocks)
2. Compute attention to current chunk
3. Merge all results
Note: Block selection is done by the caller before invoking this method.
Args:
q: Query tensor [seq_len, num_heads, head_dim]
@@ -80,37 +182,28 @@ class FullAttentionPolicy(SparsePolicy):
current_chunk_idx: Current chunk index
seq: Sequence object
num_tokens: Number of tokens in current chunk
selected_blocks: List of CPU block IDs to process (already filtered)
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
# Use FlashInfer-based implementations (more optimized)
from nanovllm.ops.chunked_attention import (
flash_attn_with_lse_flashinfer as flash_attn_with_lse,
merge_attention_outputs_flashinfer as merge_attention_outputs,
)
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}, "
f"selected_blocks={len(selected_blocks)}")
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
o_acc = None
lse_acc = None
compute_stream = offload_engine.compute_stream
# Step 1: Get historical blocks
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Step 2: Apply select_blocks to filter blocks
if cpu_block_table:
num_chunks = current_chunk_idx + 1
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
num_query_chunks=num_chunks,
layer_id=layer_id,
query=None, # Prefill typically doesn't use query for selection
is_prefill=True,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
logger.debug(f"[DEBUG] select_blocks: output={len(cpu_block_table)} blocks")
# Use the pre-selected blocks directly
cpu_block_table = selected_blocks
if cpu_block_table:
load_slots = list(range(offload_engine.num_ring_slots))
@@ -121,7 +214,8 @@ class FullAttentionPolicy(SparsePolicy):
slot = load_slots[0]
for block_idx in range(num_blocks):
cpu_block_id = cpu_block_table[block_idx]
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
# cpu_block_id is the chunk index (block N = chunk N)
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
offload_engine.wait_slot_layer(slot)
with torch.cuda.stream(compute_stream):
@@ -141,7 +235,8 @@ class FullAttentionPolicy(SparsePolicy):
num_slots = len(load_slots)
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
cpu_block_id = cpu_block_table[i]
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
for block_idx in range(num_blocks):
current_slot = load_slots[block_idx % num_slots]
@@ -168,7 +263,7 @@ class FullAttentionPolicy(SparsePolicy):
if next_block_idx < num_blocks:
next_slot = load_slots[next_block_idx % num_slots]
next_cpu_block_id = cpu_block_table[next_block_idx]
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
# Step 4: Compute attention to current chunk (causal mask)
with torch.cuda.stream(compute_stream):
@@ -200,16 +295,17 @@ class FullAttentionPolicy(SparsePolicy):
offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager",
seq: "Sequence",
selected_blocks: List[int],
) -> torch.Tensor:
"""
Compute full attention for chunked decode.
This method handles the complete chunked decode flow:
1. Get prefilled CPU blocks
2. Apply select_blocks for block filtering
3. Load blocks via pipeline (ring buffer or cross-layer)
4. Read accumulated decode tokens from decode buffer
5. Merge all results
This method handles the chunked decode computation:
1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
2. Read accumulated decode tokens from decode buffer
3. Merge all results
Note: Block selection is done by the caller before invoking this method.
Args:
q: Query tensor [batch_size, num_heads, head_dim]
@@ -218,49 +314,49 @@ class FullAttentionPolicy(SparsePolicy):
offload_engine: OffloadEngine for loading blocks
kvcache_manager: KVCacheManager for block management
seq: Sequence object
selected_blocks: List of CPU block IDs to process (already filtered)
Returns:
Attention output [batch_size, 1, num_heads, head_dim]
"""
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
# Use FlashInfer-based implementations (more optimized)
from nanovllm.ops.chunked_attention import (
flash_attn_with_lse_flashinfer as flash_attn_with_lse,
merge_attention_outputs_flashinfer as merge_attention_outputs,
)
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
# Get only PREFILLED CPU blocks (exclude the current decode block)
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Use the pre-selected blocks directly
cpu_block_table = selected_blocks
if layer_id == 0:
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
logger.debug(f"Decode attention: selected_blocks={len(selected_blocks)}, seq.block_table={list(seq.block_table)}")
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
# Calculate valid tokens in the last CPU block
# CRITICAL: Use original prefill length, not current seq length!
# CPU blocks are fixed after prefill, their content doesn't change during decode.
# Note: We need to get all prefilled blocks to determine last_block_valid_tokens
block_size = kvcache_manager.block_size
num_prefill_blocks = len(cpu_block_table)
all_prefilled_blocks = kvcache_manager.get_prefilled_cpu_blocks(seq)
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
last_block_valid_tokens = total_prefill_tokens % block_size
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
last_block_valid_tokens = block_size # Last block was exactly full
# Apply sparse policy (self) for block filtering
policy_ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
layer_id=layer_id,
query=q_batched,
is_prefill=False,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
# Determine if selected_blocks contains the last prefilled block
# If not, all selected blocks are full blocks (use block_size as valid tokens)
last_prefilled_block = all_prefilled_blocks[-1] if all_prefilled_blocks else None
selected_contains_last = (cpu_block_table and cpu_block_table[-1] == last_prefilled_block)
effective_last_block_tokens = last_block_valid_tokens if selected_contains_last else block_size
# Use ring buffer pipeline for loading prefilled blocks
load_slots = offload_engine.decode_load_slots
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
q_batched, cpu_block_table, load_slots, offload_engine,
block_size, last_block_valid_tokens, layer_id, softmax_scale
block_size, effective_last_block_tokens, layer_id, softmax_scale
)
# Now attend to accumulated decode tokens from per-layer decode buffer
@@ -319,7 +415,11 @@ class FullAttentionPolicy(SparsePolicy):
Loads one block at a time, computes attention, and merges results.
Uses load_to_slot_layer / wait_slot_layer / get_kv_for_slot methods.
"""
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
# Use FlashInfer-based implementations (more optimized)
from nanovllm.ops.chunked_attention import (
flash_attn_with_lse_flashinfer as flash_attn_with_lse,
merge_attention_outputs_flashinfer as merge_attention_outputs,
)
num_blocks = len(cpu_block_table)
if num_blocks == 0:
@@ -335,7 +435,8 @@ class FullAttentionPolicy(SparsePolicy):
# Phase 1: Pre-load up to num_slots blocks
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
cpu_block_id = cpu_block_table[i]
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id, is_prefill=False)
# Phase 2: Process blocks with pipeline
for block_idx in range(num_blocks):
@@ -368,7 +469,8 @@ class FullAttentionPolicy(SparsePolicy):
# Start loading next block (pipeline)
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
offload_engine.load_to_slot_layer(current_slot, layer_id, cpu_block_table[next_block_idx])
next_cpu_block_id = cpu_block_table[next_block_idx]
offload_engine.load_to_slot_layer(current_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id, is_prefill=False)
# Merge with accumulated
with torch.cuda.stream(compute_stream):

View File

@@ -108,12 +108,45 @@ class SparsePolicy(ABC):
"""
pass
def alloc_policy_metadata(
self,
num_heads: int,
num_kv_heads: int,
head_dim: int,
max_seq_len: int,
dtype: torch.dtype,
device: torch.device,
enable_cpu_offload: bool = False,
) -> None:
"""
Pre-allocate GPU buffers for policy computation.
Called by the framework after KV cache allocation. Implementations should
use enable_cpu_offload to decide which buffers to allocate:
- Offload mode: allocate chunked prefill buffers (mask, KV chunking stats)
- GPU-only mode: additionally allocate GQA expansion buffers
This is separate from initialize() which is used for CPU offload metadata.
Args:
num_heads: Number of query heads
num_kv_heads: Number of KV heads (for GQA)
head_dim: Dimension per head
max_seq_len: Maximum sequence length (for buffer sizing)
dtype: Data type (typically float16/bfloat16)
device: Target device (cuda)
enable_cpu_offload: Whether CPU offload is enabled
"""
pass
@abstractmethod
def select_blocks(
self,
available_blocks: List[int],
offload_engine: "OffloadEngine",
ctx: PolicyContext,
q: torch.Tensor,
k: torch.Tensor,
) -> List[int]:
"""
Select which KV blocks to load for the current query chunk.
@@ -130,6 +163,8 @@ class SparsePolicy(ABC):
to load KV to make selection decisions).
ctx: PolicyContext with information about the current query
chunk, layer, phase (prefill/decode), etc.
q: Query tensor [seq_len, num_heads, head_dim] for current chunk
k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk
Returns:
List of block IDs to load (must be a subset of available_blocks).
@@ -191,6 +226,87 @@ class SparsePolicy(ABC):
"""
pass
# =========================================================================
# GPU-only methods (non-chunked)
# These methods are used when all KV cache is on GPU, no CPU offload needed.
# =========================================================================
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
layer_id: int,
block_tables: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Compute GPU-only prefill attention (non-chunked).
This method is used when all KV cache resides on GPU (no CPU offload).
Override this to implement sparse prefill attention for GPU-only mode.
Default implementation raises NotImplementedError.
Args:
q: [total_q, num_heads, head_dim] query tensor (packed variable length)
k: [total_kv, num_kv_heads, head_dim] key tensor
v: [total_kv, num_kv_heads, head_dim] value tensor
cu_seqlens_q: [batch+1] cumulative sequence lengths for queries
cu_seqlens_k: [batch+1] cumulative sequence lengths for keys
max_seqlen_q: maximum query sequence length
max_seqlen_k: maximum key sequence length
softmax_scale: softmax scaling factor (1/sqrt(head_dim))
layer_id: transformer layer index
block_tables: [batch, max_blocks] paged attention block tables (optional)
Returns:
[total_q, num_heads, head_dim] attention output
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement compute_prefill for GPU-only mode"
)
def compute_decode(
self,
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
cache_seqlens: torch.Tensor,
softmax_scale: float,
layer_id: int,
block_tables: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Compute GPU-only decode attention (non-chunked).
This method is used when all KV cache resides on GPU (no CPU offload).
Override this to implement sparse decode attention for GPU-only mode.
Default implementation raises NotImplementedError.
Args:
q: [batch, num_heads, head_dim] query tensor (single token per sequence)
k_cache: [num_blocks, block_size, num_kv_heads, head_dim] paged key cache
v_cache: [num_blocks, block_size, num_kv_heads, head_dim] paged value cache
cache_seqlens: [batch] sequence lengths in cache
softmax_scale: softmax scaling factor (1/sqrt(head_dim))
layer_id: transformer layer index
block_tables: [batch, max_blocks] paged attention block tables (optional)
Returns:
[batch, 1, num_heads, head_dim] attention output
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement compute_decode for GPU-only mode"
)
# =========================================================================
# Chunked offload methods (for CPU offload mode)
# =========================================================================
@abstractmethod
def compute_chunked_prefill(
self,
@@ -204,17 +320,20 @@ class SparsePolicy(ABC):
current_chunk_idx: int,
seq: "Sequence",
num_tokens: int,
selected_blocks: List[int],
) -> torch.Tensor:
"""
Compute chunked prefill attention (complete flow).
This is the main entry point for prefill attention computation.
It defines the complete prefill flow:
1. Get historical blocks
2. Select blocks (call select_blocks)
3. Load and compute historical blocks via offload_engine
4. Get current chunk KV from offload_engine, compute attention
5. Merge all results
1. Load and compute historical blocks via offload_engine (using selected_blocks)
2. Get current chunk KV from offload_engine, compute attention
3. Merge all results
Note: Block selection (select_blocks) is called by the caller (attention.py)
before invoking this method. The selected_blocks parameter contains the
filtered block IDs to process.
Args:
q: [seq_len, num_heads, head_dim] query for current chunk
@@ -227,6 +346,7 @@ class SparsePolicy(ABC):
current_chunk_idx: current chunk index
seq: Sequence object
num_tokens: number of tokens in current chunk
selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)
Returns:
[seq_len, num_heads, head_dim] final attention output
@@ -242,17 +362,20 @@ class SparsePolicy(ABC):
offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager",
seq: "Sequence",
selected_blocks: List[int],
) -> torch.Tensor:
"""
Compute chunked decode attention (complete flow).
This is the main entry point for decode attention computation.
It defines the complete decode flow:
1. Get prefilled blocks from CPU
2. Select blocks (call select_blocks)
3. Load blocks via pipeline (ring buffer or cross-layer)
4. Read accumulated decode tokens from decode buffer
5. Merge all results
1. Load blocks via pipeline using selected_blocks (ring buffer or cross-layer)
2. Read accumulated decode tokens from decode buffer
3. Merge all results
Note: Block selection (select_blocks) is called by the caller (attention.py)
before invoking this method. The selected_blocks parameter contains the
filtered block IDs to process.
The decode position information can be computed internally:
- decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
@@ -265,6 +388,7 @@ class SparsePolicy(ABC):
offload_engine: OffloadEngine for loading blocks
kvcache_manager: KVCacheManager for block management
seq: Sequence object
selected_blocks: list of CPU block IDs to process (already filtered by select_blocks)
Returns:
[batch_size, 1, num_heads, head_dim] final attention output

View File

@@ -191,13 +191,26 @@ class QuestPolicy(SparsePolicy):
def select_blocks(
self,
available_blocks: List[int],
offload_engine: "OffloadEngine",
ctx: PolicyContext,
q: torch.Tensor,
k: torch.Tensor,
) -> List[int]:
"""
Select Top-K blocks based on query-key similarity bounds.
If query is not available (some prefill scenarios), falls back
to loading all blocks.
Args:
available_blocks: List of CPU block IDs
offload_engine: OffloadEngine for loading KV (unused in Quest)
ctx: PolicyContext with metadata
q: Query tensor [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim] (unused in Quest, uses metadata instead)
Returns:
Selected block IDs
"""
if self.metadata is None:
raise RuntimeError(
@@ -211,7 +224,7 @@ class QuestPolicy(SparsePolicy):
if n <= self.config.threshold_blocks:
return available_blocks
if ctx.query is None:
if q is None:
# No query available - cannot compute scores
return available_blocks
@@ -221,11 +234,10 @@ class QuestPolicy(SparsePolicy):
)
# Metadata is already on GPU, same device as query
device = ctx.query.device
device = q.device
# Compute upper bound scores
# query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
q = ctx.query
# query shape: [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
if q.dim() == 4:
# Prefill: use mean over sequence length
q = q.mean(dim=1) # [1, num_heads, head_dim]

File diff suppressed because it is too large Load Diff

View File

@@ -5,6 +5,7 @@ from torch import nn
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context
from nanovllm.kvcache.sparse.policy import PolicyContext
logger = logging.getLogger(__name__)
@@ -103,50 +104,67 @@ class Attention(nn.Module):
# This enables fully async offloads since each layer has its own buffer.
offload_engine = context.kvcache_manager.offload_engine
compute_stream = offload_engine.compute_stream
chunk_idx = context.current_chunk_idx if hasattr(context, 'current_chunk_idx') else -1
# Wait for default stream to ensure slot_mapping tensor transfer is complete
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
# Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
# Write KV to per-layer prefill buffer via offload_engine
# k, v shape: [num_tokens, kv_heads, head_dim]
num_tokens = k.shape[0]
offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
#! GPU 2 GPU
offload_engine.write_to_prefill_buffer(self.layer_id, k, v, chunk_idx=chunk_idx)
elif is_chunked_offload:
# Chunked decode mode: use compute_stream for store_kvcache
# This ensures proper synchronization with per-layer offload
compute_stream = context.kvcache_manager.offload_engine.compute_stream
if k_cache.numel() and v_cache.numel():
# CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
# slot_mapping is created with non_blocking=True on default stream, but we use it
# on compute_stream. Without this sync, index_copy_ can get corrupted indices.
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
# Chunked decode mode: write KV to per-layer decode buffer via offload_engine
# KV will be written to decode buffer in the decode branch below
# No store_kvcache needed - all KV management goes through offload_engine
pass
else:
# Normal mode: store on default stream
if k_cache.numel() and v_cache.numel():
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
# Get sparse_policy from kvcache_manager (required, never None after warmup)
# During warmup, kvcache_manager is not yet allocated
if context.kvcache_manager is None:
# Warmup phase: use flash_attn directly
if context.is_prefill:
return flash_attn_varlen_func(
q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True,
)
else:
return flash_attn_with_kvcache(
q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True,
)
sparse_policy = context.kvcache_manager.sparse_policy
assert sparse_policy is not None, "sparse_policy must not be None"
if context.is_prefill:
if context.is_chunked_prefill:
# Chunked prefill: merge attention from previous KV
# Chunked prefill: merge attention from previous KV (CPU offload mode)
o = self._chunked_prefill_attention(q, k, v, context)
elif context.block_tables is not None: # prefix cache
k, v = k_cache, v_cache
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
else:
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
# GPU-only mode: use policy for attention
# Use paged attention if block_tables provided, else use k, v directly
if context.block_tables is not None:
k_for_attn, v_for_attn = k_cache, v_cache
else:
k_for_attn, v_for_attn = k, v
o = sparse_policy.compute_prefill(
q, k_for_attn, v_for_attn,
context.cu_seqlens_q, context.cu_seqlens_k,
context.max_seqlen_q, context.max_seqlen_k,
self.scale, self.layer_id,
context.block_tables,
)
else: # decode
if context.is_chunked_prefill:
# Chunked decode: need to load all KV from CPU+GPU
# Chunked decode: need to load all KV from CPU+GPU (CPU offload mode)
# Store current decode token to per-layer decode buffer
# This is needed because GPU cache has no layer dimension,
# so all layers would overwrite each other in decode_slot.
@@ -154,13 +172,15 @@ class Attention(nn.Module):
offload_engine = kvcache_manager.offload_engine
pos_in_block = context.decode_pos_in_block
# k, v shape: [1, kv_heads, head_dim]
offload_engine.decode_k_buffer[self.layer_id, pos_in_block].copy_(k.squeeze(0))
offload_engine.decode_v_buffer[self.layer_id, pos_in_block].copy_(v.squeeze(0))
offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0))
o = self._chunked_decode_attention(q, k, v, context)
else:
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
# GPU-only mode: use policy for attention
o = sparse_policy.compute_decode(
q, k_cache, v_cache,
context.context_lens, self.scale, self.layer_id,
context.block_tables,
)
return o
def _chunked_prefill_attention(
@@ -197,11 +217,29 @@ class Attention(nn.Module):
if sparse_policy is None:
raise RuntimeError("sparse_policy is required for chunked prefill")
# Step 1: Get historical CPU blocks
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill)
# Always call select_blocks even for first chunk (cpu_block_table may be empty)
num_chunks = current_chunk_idx + 1
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
num_query_chunks=num_chunks,
layer_id=self.layer_id,
query=q, # Pass query for sparse policies that need it
is_prefill=True,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size if cpu_block_table else 0,
)
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k)
logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
# [DEBUG] Verify execution path
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")
# Delegate all computation to policy (no flash_attn or merge calls here!)
# Delegate computation to policy with pre-selected blocks
final_o = sparse_policy.compute_chunked_prefill(
q, k, v,
self.layer_id,
@@ -211,6 +249,7 @@ class Attention(nn.Module):
current_chunk_idx,
seq,
num_tokens,
selected_blocks,
)
torch.cuda.nvtx.range_pop() # ChunkedPrefill
@@ -258,14 +297,36 @@ class Attention(nn.Module):
raise RuntimeError("sparse_policy is required for chunked decode")
# Check if policy supports decode phase
# If not, fallback to FullAttentionPolicy (e.g., XAttentionBSAPolicy only supports prefill)
if not sparse_policy.supports_decode:
raise RuntimeError(f"{sparse_policy} does not support decode phase")
from nanovllm.kvcache.sparse import FullAttentionPolicy
sparse_policy = FullAttentionPolicy()
logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, "
f"falling back to FullAttentionPolicy")
# Step 1: Get prefilled CPU blocks
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_decode)
selected_blocks = []
if cpu_block_table:
policy_ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
layer_id=self.layer_id,
query=q, # Pass query for sparse policies that need it
is_prefill=False,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k)
logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
# [DEBUG] Verify execution path
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
f"policy={sparse_policy}, layer={self.layer_id}")
# Delegate all computation to policy (no flash_attn or merge calls here!)
# Delegate computation to policy with pre-selected blocks
return sparse_policy.compute_chunked_decode(
q,
self.layer_id,
@@ -273,4 +334,5 @@ class Attention(nn.Module):
offload_engine,
kvcache_manager,
seq,
selected_blocks,
)

View File

@@ -0,0 +1,572 @@
"""
CUDA Graph wrapped layers for offload optimization.
This module provides Graph-wrapped versions of non-attention layers
to reduce kernel launch overhead in CPU offload path.
Phase 5 Design:
- Supports both Prefill (seq_len=chunk_size) and Decode (seq_len=1)
- Extended coverage: embed, input_norm, qkv_proj, rotary, o_proj, post_norm, mlp, final_norm
- Only attention core (attn.forward) remains in eager mode
Graph Structure (N layers):
- EmbedGraph: embed_tokens
- FirstGraph: input_norm → qkv_proj → rotary
- InterGraph[i]: o_proj → post_norm → mlp → input_norm → qkv_proj → rotary (N-1 graphs)
- LastGraph: o_proj → post_norm → mlp → final_norm
Total: N+2 graphs
"""
import torch
from torch import nn
from typing import Optional, Tuple
class EmbedGraph(nn.Module):
"""
Graph wrapper for embedding layer.
Input: input_ids [seq_len]
Output: hidden_states [seq_len, hidden_size]
"""
def __init__(
self,
embed_tokens: nn.Module,
seq_len: int,
hidden_size: int,
dtype: torch.dtype = torch.bfloat16,
):
super().__init__()
self.embed_tokens = embed_tokens
self.seq_len = seq_len
self.hidden_size = hidden_size
self.dtype = dtype
# Graph state
self.graph: Optional[torch.cuda.CUDAGraph] = None
self.ids_in: Optional[torch.Tensor] = None
self.h_out: Optional[torch.Tensor] = None
def _compute(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def capture_graph(self, graph_pool=None):
"""Capture CUDA Graph."""
# Allocate placeholders outside inference_mode
self.ids_in = torch.zeros(self.seq_len, dtype=torch.long, device="cuda")
self.h_out = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
with torch.inference_mode():
# Warmup
for _ in range(3):
h = self._compute(self.ids_in)
self.h_out.copy_(h)
torch.cuda.synchronize()
# Capture
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=graph_pool):
h = self._compute(self.ids_in)
self.h_out.copy_(h)
return self.graph.pool() if graph_pool is None else graph_pool
def forward(self, input_ids: torch.Tensor, use_graph: bool = False) -> torch.Tensor:
if use_graph and self.graph is not None and input_ids.shape[0] == self.seq_len:
self.ids_in.copy_(input_ids)
self.graph.replay()
return self.h_out.clone()
else:
return self._compute(input_ids)
class FirstGraph(nn.Module):
"""
Graph wrapper for first layer pre-attention:
input_norm → qkv_proj → split → reshape → rotary
Input: hidden_states [seq_len, hidden_size], positions [seq_len]
Output: q [seq_len, num_heads, head_dim], k [seq_len, num_kv_heads, head_dim],
v [seq_len, num_kv_heads, head_dim], residual [seq_len, hidden_size]
"""
def __init__(
self,
input_norm: nn.Module,
qkv_proj: nn.Module,
rotary_emb: nn.Module,
# Shape parameters
seq_len: int,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype = torch.bfloat16,
):
super().__init__()
self.input_norm = input_norm
self.qkv_proj = qkv_proj
self.rotary_emb = rotary_emb
self.seq_len = seq_len
self.hidden_size = hidden_size
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.dtype = dtype
# Split sizes
self.q_size = num_heads * head_dim
self.kv_size = num_kv_heads * head_dim
# Graph state
self.graph: Optional[torch.cuda.CUDAGraph] = None
def _compute(
self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
First layer computation:
1. input_layernorm (residual = hidden_states for first layer)
2. QKV projection
3. Split and reshape
4. Rotary embedding
"""
# For first layer, residual = hidden_states (before norm)
residual = hidden_states.clone()
hidden_states = self.input_norm(hidden_states)
# QKV projection
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Reshape
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
# Rotary embedding
q, k = self.rotary_emb(positions, q, k)
return q, k, v, residual
def capture_graph(self, graph_pool=None):
"""Capture CUDA Graph."""
# Allocate placeholders
self.h_in = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
self.pos_in = torch.zeros(self.seq_len, dtype=torch.long, device="cuda")
self.q_out = torch.zeros(self.seq_len, self.num_heads, self.head_dim, dtype=self.dtype, device="cuda")
self.k_out = torch.zeros(self.seq_len, self.num_kv_heads, self.head_dim, dtype=self.dtype, device="cuda")
self.v_out = torch.zeros(self.seq_len, self.num_kv_heads, self.head_dim, dtype=self.dtype, device="cuda")
self.r_out = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
with torch.inference_mode():
# Warmup
for _ in range(3):
q, k, v, r = self._compute(self.h_in, self.pos_in)
self.q_out.copy_(q)
self.k_out.copy_(k)
self.v_out.copy_(v)
self.r_out.copy_(r)
torch.cuda.synchronize()
# Capture
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=graph_pool):
q, k, v, r = self._compute(self.h_in, self.pos_in)
self.q_out.copy_(q)
self.k_out.copy_(k)
self.v_out.copy_(v)
self.r_out.copy_(r)
return self.graph.pool() if graph_pool is None else graph_pool
def forward(
self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
use_graph: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if use_graph and self.graph is not None and hidden_states.shape[0] == self.seq_len:
self.h_in.copy_(hidden_states)
self.pos_in.copy_(positions)
self.graph.replay()
return self.q_out.clone(), self.k_out.clone(), self.v_out.clone(), self.r_out.clone()
else:
return self._compute(hidden_states, positions)
class InterGraph(nn.Module):
"""
Graph wrapper for inter-layer computation:
o_proj → post_norm → mlp → input_norm → qkv_proj → rotary
Merges current layer's post-attention with next layer's pre-attention.
Input: attn_output [seq_len, num_heads, head_dim], residual [seq_len, hidden_size], positions [seq_len]
Output: q [seq_len, num_heads, head_dim], k [seq_len, num_kv_heads, head_dim],
v [seq_len, num_kv_heads, head_dim], residual [seq_len, hidden_size]
"""
def __init__(
self,
# Current layer components
o_proj: nn.Module,
post_norm: nn.Module,
mlp: nn.Module,
# Next layer components
next_input_norm: nn.Module,
next_qkv_proj: nn.Module,
next_rotary_emb: nn.Module,
# Shape parameters
seq_len: int,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype = torch.bfloat16,
):
super().__init__()
# Current layer
self.o_proj = o_proj
self.post_norm = post_norm
self.mlp = mlp
# Next layer
self.next_input_norm = next_input_norm
self.next_qkv_proj = next_qkv_proj
self.next_rotary_emb = next_rotary_emb
# Shape params
self.seq_len = seq_len
self.hidden_size = hidden_size
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.dtype = dtype
# Split sizes
self.q_size = num_heads * head_dim
self.kv_size = num_kv_heads * head_dim
# Graph state
self.graph: Optional[torch.cuda.CUDAGraph] = None
def _compute(
self,
attn_output: torch.Tensor, # [seq_len, num_heads, head_dim]
residual: torch.Tensor, # [seq_len, hidden_size]
positions: torch.Tensor, # [seq_len]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Inter-layer computation:
1. O projection (flatten first)
2. Post-attention layernorm + residual
3. MLP
4. Next layer's input layernorm + residual
5. QKV projection
6. Split and reshape
7. Rotary embedding
"""
# O projection
hidden_states = self.o_proj(attn_output.flatten(1, -1))
# Post-attention of current layer
hidden_states, residual = self.post_norm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
# Pre-attention of next layer
hidden_states, residual = self.next_input_norm(hidden_states, residual)
# QKV projection
qkv = self.next_qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Reshape
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
# Rotary embedding
q, k = self.next_rotary_emb(positions, q, k)
return q, k, v, residual
def capture_graph(self, graph_pool=None):
"""Capture CUDA Graph."""
# Allocate placeholders
self.attn_in = torch.zeros(self.seq_len, self.num_heads, self.head_dim, dtype=self.dtype, device="cuda")
self.r_in = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
self.pos_in = torch.zeros(self.seq_len, dtype=torch.long, device="cuda")
self.q_out = torch.zeros(self.seq_len, self.num_heads, self.head_dim, dtype=self.dtype, device="cuda")
self.k_out = torch.zeros(self.seq_len, self.num_kv_heads, self.head_dim, dtype=self.dtype, device="cuda")
self.v_out = torch.zeros(self.seq_len, self.num_kv_heads, self.head_dim, dtype=self.dtype, device="cuda")
self.r_out = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
with torch.inference_mode():
# Warmup
for _ in range(3):
q, k, v, r = self._compute(self.attn_in, self.r_in, self.pos_in)
self.q_out.copy_(q)
self.k_out.copy_(k)
self.v_out.copy_(v)
self.r_out.copy_(r)
torch.cuda.synchronize()
# Capture
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=graph_pool):
q, k, v, r = self._compute(self.attn_in, self.r_in, self.pos_in)
self.q_out.copy_(q)
self.k_out.copy_(k)
self.v_out.copy_(v)
self.r_out.copy_(r)
return self.graph.pool() if graph_pool is None else graph_pool
def forward(
self,
attn_output: torch.Tensor,
residual: torch.Tensor,
positions: torch.Tensor,
use_graph: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if use_graph and self.graph is not None and attn_output.shape[0] == self.seq_len:
self.attn_in.copy_(attn_output)
self.r_in.copy_(residual)
self.pos_in.copy_(positions)
self.graph.replay()
return self.q_out.clone(), self.k_out.clone(), self.v_out.clone(), self.r_out.clone()
else:
return self._compute(attn_output, residual, positions)
class LastGraph(nn.Module):
"""
Graph wrapper for last layer:
o_proj → post_norm → mlp → final_norm
Input: attn_output [seq_len, num_heads, head_dim], residual [seq_len, hidden_size]
Output: hidden_states [seq_len, hidden_size]
"""
def __init__(
self,
o_proj: nn.Module,
post_norm: nn.Module,
mlp: nn.Module,
final_norm: nn.Module,
# Shape parameters
seq_len: int,
hidden_size: int,
num_heads: int,
head_dim: int,
dtype: torch.dtype = torch.bfloat16,
):
super().__init__()
self.o_proj = o_proj
self.post_norm = post_norm
self.mlp = mlp
self.final_norm = final_norm
self.seq_len = seq_len
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = head_dim
self.dtype = dtype
# Graph state
self.graph: Optional[torch.cuda.CUDAGraph] = None
def _compute(
self,
attn_output: torch.Tensor,
residual: torch.Tensor,
) -> torch.Tensor:
"""
Last layer computation:
1. O projection
2. Post-attention layernorm + residual
3. MLP
4. Final model norm + residual
"""
hidden_states = self.o_proj(attn_output.flatten(1, -1))
hidden_states, residual = self.post_norm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
hidden_states, _ = self.final_norm(hidden_states, residual)
return hidden_states
def capture_graph(self, graph_pool=None):
"""Capture CUDA Graph."""
# Allocate placeholders
self.attn_in = torch.zeros(self.seq_len, self.num_heads, self.head_dim, dtype=self.dtype, device="cuda")
self.r_in = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
self.h_out = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
with torch.inference_mode():
# Warmup
for _ in range(3):
h = self._compute(self.attn_in, self.r_in)
self.h_out.copy_(h)
torch.cuda.synchronize()
# Capture
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=graph_pool):
h = self._compute(self.attn_in, self.r_in)
self.h_out.copy_(h)
return self.graph.pool() if graph_pool is None else graph_pool
def forward(
self,
attn_output: torch.Tensor,
residual: torch.Tensor,
use_graph: bool = False,
) -> torch.Tensor:
if use_graph and self.graph is not None and attn_output.shape[0] == self.seq_len:
self.attn_in.copy_(attn_output)
self.r_in.copy_(residual)
self.graph.replay()
return self.h_out.clone()
else:
return self._compute(attn_output, residual)
class OffloadGraphManager:
"""
Manager for all CUDA Graphs in offload path.
Creates and manages N+2 graphs for N-layer model:
- 1 EmbedGraph
- 1 FirstGraph
- N-1 InterGraphs
- 1 LastGraph
Supports both Prefill and Decode modes via seq_len parameter.
"""
def __init__(
self,
model: nn.Module,
seq_len: int,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype,
):
"""
Initialize graph manager from model.
Args:
model: The CausalLM model (e.g., LlamaForCausalLM)
seq_len: Sequence length (1 for decode, chunk_size for prefill)
hidden_size: Model hidden dimension
num_heads: Number of attention heads
num_kv_heads: Number of KV heads
head_dim: Head dimension
dtype: Data type for tensors
"""
self.seq_len = seq_len
self.hidden_size = hidden_size
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.dtype = dtype
# Access model layers
layers = model.model.layers
num_layers = len(layers)
self.num_layers = num_layers
# Create EmbedGraph
self.embed_graph = EmbedGraph(
embed_tokens=model.model.embed_tokens,
seq_len=seq_len,
hidden_size=hidden_size,
dtype=dtype,
)
# Create FirstGraph: input_norm_0 → qkv_proj_0 → rotary_0
self.first_graph = FirstGraph(
input_norm=layers[0].input_layernorm,
qkv_proj=layers[0].self_attn.qkv_proj,
rotary_emb=layers[0].self_attn.rotary_emb,
seq_len=seq_len,
hidden_size=hidden_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
)
# Create InterGraphs: o_proj_i → post_norm_i → mlp_i → input_norm_{i+1} → qkv_proj_{i+1} → rotary_{i+1}
self.inter_graphs = nn.ModuleList()
for i in range(num_layers - 1):
self.inter_graphs.append(InterGraph(
o_proj=layers[i].self_attn.o_proj,
post_norm=layers[i].post_attention_layernorm,
mlp=layers[i].mlp,
next_input_norm=layers[i + 1].input_layernorm,
next_qkv_proj=layers[i + 1].self_attn.qkv_proj,
next_rotary_emb=layers[i + 1].self_attn.rotary_emb,
seq_len=seq_len,
hidden_size=hidden_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
))
# Create LastGraph: o_proj_{N-1} → post_norm_{N-1} → mlp_{N-1} → final_norm
self.last_graph = LastGraph(
o_proj=layers[-1].self_attn.o_proj,
post_norm=layers[-1].post_attention_layernorm,
mlp=layers[-1].mlp,
final_norm=model.model.norm,
seq_len=seq_len,
hidden_size=hidden_size,
num_heads=num_heads,
head_dim=head_dim,
dtype=dtype,
)
self.captured = False
self.graph_pool = None
def capture_all(self):
"""Capture all graphs, sharing memory pool."""
graph_pool = None
# Capture embed graph
graph_pool = self.embed_graph.capture_graph(graph_pool)
# Capture first graph
graph_pool = self.first_graph.capture_graph(graph_pool)
# Capture inter-layer graphs
for inter_graph in self.inter_graphs:
graph_pool = inter_graph.capture_graph(graph_pool)
# Capture last graph
graph_pool = self.last_graph.capture_graph(graph_pool)
self.graph_pool = graph_pool
self.captured = True
@property
def num_graphs(self) -> int:
"""Total number of graphs: 1 + 1 + (N-1) + 1 = N+2"""
return 1 + 1 + len(self.inter_graphs) + 1
# Legacy compatibility aliases (for gradual migration)
FirstLayerGraph = FirstGraph
InterLayerGraph = InterGraph
LastLayerGraph = LastGraph

View File

@@ -8,12 +8,43 @@ def apply_rotary_emb(
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
"""Non-interleaved RoPE (used by Llama, Qwen, etc.)"""
x1, x2 = torch.chunk(x.float(), 2, dim=-1)
y1 = x1 * cos - x2 * sin
y2 = x2 * cos + x1 * sin
return torch.cat((y1, y2), dim=-1).to(x.dtype)
def apply_rotary_emb_interleaved(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
"""Interleaved RoPE (used by GLM-4, etc.)
Args:
x: [seq_len, num_heads, head_dim]
cos: [seq_len, 1, head_dim // 2]
sin: [seq_len, 1, head_dim // 2]
x is reshaped to [seq_len, num_heads, head_dim // 2, 2] where:
- x[..., 0] are even positions
- x[..., 1] are odd positions
"""
rot_dim = x.shape[-1]
# x_shaped: [seq_len, num_heads, rot_dim // 2, 2]
x_shaped = x.float().reshape(*x.shape[:-1], rot_dim // 2, 2)
# x_0, x_1: [seq_len, num_heads, rot_dim // 2]
x_0 = x_shaped[..., 0]
x_1 = x_shaped[..., 1]
# cos/sin: [seq_len, 1, rot_dim // 2] - broadcasts to num_heads
x_out = torch.stack([
x_0 * cos - x_1 * sin,
x_1 * cos + x_0 * sin,
], dim=-1)
return x_out.flatten(-2).to(x.dtype)
class RotaryEmbedding(nn.Module):
def __init__(
@@ -140,6 +171,76 @@ class Llama3RotaryEmbedding(nn.Module):
return query, key
class GLM4RotaryEmbedding(nn.Module):
"""
GLM-4 RoPE with interleaved rotation and partial rotation.
GLM-4 uses:
- Interleaved rotation (pairs adjacent elements, not first/second half)
- rope_ratio to scale base: base = 10000 * rope_ratio
- Partial rotation: only rotates first rotary_dim elements, rest pass through
- rotary_dim = head_dim // 2 (only half of head_dim is rotated)
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim # GLM-4: rotary_dim = head_dim // 2
# inv_freq shape: [rotary_dim // 2]
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq) # [max_pos, rotary_dim // 2]
cos = freqs.cos()
sin = freqs.sin()
# cache shape [max_pos, 1, rotary_dim // 2, 2]
cache = torch.stack((cos, sin), dim=-1).unsqueeze_(1)
self.register_buffer("cos_sin_cache", cache, persistent=False)
@torch.compile
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply RoPE to query and key.
Args:
positions: [seq_len]
query: [seq_len, num_heads, head_dim]
key: [seq_len, num_kv_heads, head_dim]
Returns:
Rotated query and key with same shapes as input.
"""
cache = self.cos_sin_cache[positions] # [seq_len, 1, rotary_dim // 2, 2]
cos = cache[..., 0] # [seq_len, 1, rotary_dim // 2]
sin = cache[..., 1] # [seq_len, 1, rotary_dim // 2]
# Split into rotated and pass-through parts
q_rot = query[..., :self.rotary_dim]
q_pass = query[..., self.rotary_dim:]
k_rot = key[..., :self.rotary_dim]
k_pass = key[..., self.rotary_dim:]
# Apply interleaved RoPE to rotated part
q_rot = apply_rotary_emb_interleaved(q_rot, cos, sin)
k_rot = apply_rotary_emb_interleaved(k_rot, cos, sin)
# Concatenate rotated and pass-through parts
query = torch.cat([q_rot, q_pass], dim=-1)
key = torch.cat([k_rot, k_pass], dim=-1)
return query, key
# Cache for RoPE instances (keyed by hashable parameters)
_rope_cache: dict[tuple, nn.Module] = {}
@@ -150,10 +251,11 @@ def get_rope(
max_position: int,
base: float,
rope_scaling: dict | None = None,
is_interleaved: bool = False,
):
# Create hashable cache key
if rope_scaling is None:
cache_key = (head_size, rotary_dim, max_position, base, None)
cache_key = (head_size, rotary_dim, max_position, base, None, is_interleaved)
else:
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
if rope_type == "llama3":
@@ -163,14 +265,18 @@ def get_rope(
rope_scaling["low_freq_factor"],
rope_scaling["high_freq_factor"],
rope_scaling["original_max_position_embeddings"],
is_interleaved,
)
else:
cache_key = (head_size, rotary_dim, max_position, base, rope_type)
cache_key = (head_size, rotary_dim, max_position, base, rope_type, is_interleaved)
if cache_key in _rope_cache:
return _rope_cache[cache_key]
if rope_scaling is None:
if is_interleaved:
rope = GLM4RotaryEmbedding(head_size, rotary_dim, max_position, base)
else:
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base)
else:
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))

View File

@@ -3,7 +3,9 @@
from nanovllm.models.registry import register_model, get_model_class, MODEL_REGISTRY
# Import models to trigger registration
from nanovllm.models import qwen2
from nanovllm.models import qwen3
from nanovllm.models import llama
from nanovllm.models import glm4
__all__ = ["register_model", "get_model_class", "MODEL_REGISTRY"]

235
nanovllm/models/glm4.py Normal file
View File

@@ -0,0 +1,235 @@
"""GLM-4 model implementation for nano-vllm."""
import torch
from torch import nn
import torch.distributed as dist
from nanovllm.layers.activation import SiluAndMul
from nanovllm.layers.attention import Attention
from nanovllm.layers.layernorm import RMSNorm
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
from nanovllm.layers.rotary_embedding import get_rope
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
from nanovllm.models.registry import register_model
class GLM4Attention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 1048576,
head_dim: int = 128,
rope_theta: float = 10000,
rope_scaling: dict | None = None,
) -> None:
super().__init__()
tp_size = dist.get_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim ** -0.5
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=True, # GLM-4 has QKV bias
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False, # GLM-4 has no output bias
)
# GLM-4 only rotates half of head_dim
rotary_dim = self.head_dim // 2
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position,
base=rope_theta,
rope_scaling=rope_scaling,
is_interleaved=True, # GLM-4 uses interleaved RoPE
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
self.num_kv_heads,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
q, k = self.rotary_emb(positions, q, k)
o = self.attn(q, k, v)
output = self.o_proj(o.flatten(1, -1))
return output
class GLM4MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False, # GLM-4 has no MLP bias
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x = self.down_proj(x)
return x
class GLM4DecoderLayer(nn.Module):
def __init__(self, config) -> None:
super().__init__()
# GLM-4 config field mapping
hidden_size = config.hidden_size
num_heads = config.num_attention_heads
num_kv_heads = getattr(config, 'multi_query_group_num', num_heads)
head_dim = getattr(config, 'kv_channels', hidden_size // num_heads)
max_position = getattr(config, 'seq_length', 1048576)
rope_ratio = getattr(config, 'rope_ratio', 1)
rope_theta = 10000 * rope_ratio # GLM-4 uses rope_ratio to scale base
intermediate_size = getattr(config, 'ffn_hidden_size', getattr(config, 'intermediate_size', None))
rms_norm_eps = getattr(config, 'layernorm_epsilon', 1e-5)
self.self_attn = GLM4Attention(
hidden_size=hidden_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
max_position=max_position,
head_dim=head_dim,
rope_theta=rope_theta,
rope_scaling=getattr(config, "rope_scaling", None),
)
self.mlp = GLM4MLP(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
)
self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(positions, hidden_states)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class GLM4Model(nn.Module):
def __init__(self, config) -> None:
super().__init__()
vocab_size = getattr(config, 'padded_vocab_size', config.vocab_size)
num_layers = getattr(config, 'num_layers', config.num_hidden_layers)
rms_norm_eps = getattr(config, 'layernorm_epsilon', 1e-5)
self.embed_tokens = VocabParallelEmbedding(vocab_size, config.hidden_size)
self.layers = nn.ModuleList([GLM4DecoderLayer(config) for _ in range(num_layers)])
self.norm = RMSNorm(config.hidden_size, eps=rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for layer in self.layers:
hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@register_model("ChatGLMModel", "ChatGLMForConditionalGeneration")
class ChatGLMForCausalLM(nn.Module):
"""
GLM-4 model for causal language modeling.
Weight mapping from HuggingFace to nanovllm:
- transformer.embedding.word_embeddings → model.embed_tokens
- transformer.encoder.layers.X.input_layernorm → model.layers.X.input_layernorm
- transformer.encoder.layers.X.self_attention.query_key_value → model.layers.X.self_attn.qkv_proj (split q/k/v)
- transformer.encoder.layers.X.self_attention.dense → model.layers.X.self_attn.o_proj
- transformer.encoder.layers.X.post_attention_layernorm → model.layers.X.post_attention_layernorm
- transformer.encoder.layers.X.mlp.dense_h_to_4h → model.layers.X.mlp.gate_up_proj (split gate/up)
- transformer.encoder.layers.X.mlp.dense_4h_to_h → model.layers.X.mlp.down_proj
- transformer.encoder.final_layernorm → model.norm
- transformer.output_layer → lm_head
"""
packed_modules_mapping = {
# QKV is merged in GLM-4 as query_key_value
"query_key_value": ("qkv_proj", None), # Special handling needed
# MLP gate and up are merged as dense_h_to_4h
"dense_h_to_4h": ("gate_up_proj", None), # Special handling needed
}
# Weight name mapping for loader
hf_to_nanovllm_mapping = {
"transformer.embedding.word_embeddings": "model.embed_tokens",
"transformer.encoder.final_layernorm": "model.norm",
"transformer.output_layer": "lm_head",
}
def __init__(self, config) -> None:
super().__init__()
vocab_size = getattr(config, 'padded_vocab_size', config.vocab_size)
self.config = config
self.model = GLM4Model(config)
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
# GLM-4 does not tie embeddings
# if getattr(config, 'tie_word_embeddings', False):
# self.lm_head.weight.data = self.model.embed_tokens.weight.data
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
return self.model(input_ids, positions)
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
return self.lm_head(hidden_states)

207
nanovllm/models/qwen2.py Normal file
View File

@@ -0,0 +1,207 @@
import torch
from torch import nn
import torch.distributed as dist
from transformers import Qwen2Config
from nanovllm.layers.activation import SiluAndMul
from nanovllm.layers.attention import Attention
from nanovllm.layers.layernorm import RMSNorm
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
from nanovllm.layers.rotary_embedding import get_rope
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
from nanovllm.models.registry import register_model
class Qwen2Attention(nn.Module):
"""Qwen2/2.5 Attention without QK norm (unlike Qwen3)."""
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
head_dim: int | None = None,
rope_theta: float = 10000,
rope_scaling: tuple | None = None,
) -> None:
super().__init__()
tp_size = dist.get_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim ** -0.5
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=True, # Qwen2/2.5 always uses bias
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
self.num_kv_heads,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
q, k = self.rotary_emb(positions, q, k)
o = self.attn(q, k, v)
output = self.o_proj(o.flatten(1, -1))
return output
class Qwen2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
)
assert hidden_act == "silu"
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x = self.down_proj(x)
return x
class Qwen2DecoderLayer(nn.Module):
def __init__(
self,
config: Qwen2Config,
) -> None:
super().__init__()
self.self_attn = Qwen2Attention(
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
max_position=config.max_position_embeddings,
head_dim=getattr(config, 'head_dim', None),
rope_theta=getattr(config, "rope_theta", 1000000),
rope_scaling=getattr(config, "rope_scaling", None),
)
self.mlp = Qwen2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(positions, hidden_states)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class Qwen2Model(nn.Module):
def __init__(
self,
config: Qwen2Config,
) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([Qwen2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for layer in self.layers:
hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@register_model("Qwen2ForCausalLM")
class Qwen2ForCausalLM(nn.Module):
packed_modules_mapping = {
"q_proj": ("qkv_proj", "q"),
"k_proj": ("qkv_proj", "k"),
"v_proj": ("qkv_proj", "v"),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
config: Qwen2Config
) -> None:
super().__init__()
self.model = Qwen2Model(config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
if config.tie_word_embeddings:
self.lm_head.weight.data = self.model.embed_tokens.weight.data
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
return self.model(input_ids, positions)
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
return self.lm_head(hidden_states)

View File

@@ -187,7 +187,7 @@ class Qwen3Model(nn.Module):
return hidden_states
@register_model("Qwen3ForCausalLM", "Qwen2ForCausalLM")
@register_model("Qwen3ForCausalLM")
class Qwen3ForCausalLM(nn.Module):
packed_modules_mapping = {
"q_proj": ("qkv_proj", "q"),

View File

@@ -414,6 +414,90 @@ def merge_attention_outputs(
return o_merged, lse_merged
# ============================================================
# FlashInfer-based implementations (recommended for merge only)
# ============================================================
# LSE conversion constants: FlashInfer uses log2, flash_attn uses ln
_LOG2_E = 1.4426950408889634 # math.log2(math.e) - ln -> log2
_LN_2 = 0.6931471805599453 # math.log(2) - log2 -> ln
# Check FlashInfer availability (only for merge_state, not attention kernel)
try:
from flashinfer.cascade import merge_state, merge_state_in_place
FLASHINFER_MERGE_AVAILABLE = True
except ImportError:
FLASHINFER_MERGE_AVAILABLE = False
def flash_attn_with_lse_flashinfer(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Flash attention that returns output and LSE.
Uses flash_attn library (FlashInfer attention has JIT compatibility issues).
Args:
q: Query tensor [batch, seqlen_q, nheads_q, headdim]
k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
softmax_scale: Scaling factor (default: 1/sqrt(headdim))
causal: Whether to apply causal masking
Returns:
out: Output tensor [batch, seqlen_q, nheads_q, headdim]
lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q] (ln format)
"""
# Use flash_attn directly (FlashInfer attention JIT has CUDA version issues)
return flash_attn_with_lse(q, k, v, softmax_scale, causal)
def merge_attention_outputs_flashinfer(
o1: torch.Tensor,
lse1: torch.Tensor,
o2: torch.Tensor,
lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Merge two attention outputs using FlashInfer's optimized kernel.
Args:
o1: First output [batch, seqlen_q, nheads, headdim]
lse1: First LSE [batch, nheads, seqlen_q] (ln format)
o2: Second output [batch, seqlen_q, nheads, headdim]
lse2: Second LSE [batch, nheads, seqlen_q] (ln format)
Returns:
o_merged: Merged output [batch, seqlen_q, nheads, headdim]
lse_merged: Merged LSE [batch, nheads, seqlen_q] (ln format)
"""
if not FLASHINFER_MERGE_AVAILABLE:
# Fallback to Triton implementation
return merge_attention_outputs(o1, lse1, o2, lse2)
# Convert to FlashInfer format
# o: [batch, seq, heads, dim] -> [seq, heads, dim]
# lse: [batch, heads, seq] -> [seq, heads] (convert ln -> log2)
v_a = o1.squeeze(0).contiguous()
s_a = (lse1.squeeze(0).transpose(0, 1).contiguous().float() * _LOG2_E)
v_b = o2.squeeze(0).contiguous()
s_b = (lse2.squeeze(0).transpose(0, 1).contiguous().float() * _LOG2_E)
# FlashInfer merge
v_merged, s_merged = merge_state(v_a, s_a, v_b, s_b)
# Convert back to flash_attn format
o_merged = v_merged.unsqueeze(0) # [1, seq, heads, dim]
lse_merged = (s_merged * _LN_2).transpose(0, 1).unsqueeze(0) # [1, heads, seq]
return o_merged, lse_merged
def chunked_attention_varlen(
q: torch.Tensor,
kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],

View File

@@ -218,6 +218,209 @@ def softmax_fuse_block_sum_kernel_non_causal(
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
# ============================================================
# KV Chunking Support Kernels
# ============================================================
@triton.jit
def softmax_partial_stats_kernel(
In,
M_out, # max per row
L_out, # sum per row (normalized by M_out)
scale,
input_stride_0,
input_stride_1,
input_stride_2,
stats_stride_0,
stats_stride_1,
k_len,
chunk_start, # Q start position (for causal)
kv_offset, # KV chunk offset (for causal)
segment_size: tl.constexpr,
block_size: tl.constexpr,
is_causal: tl.constexpr,
):
"""
Compute partial softmax statistics for a KV chunk.
For each query row, computes:
- m: max value in this chunk
- l: sum of exp(x - m) in this chunk
These can be merged across chunks using online softmax formula.
Input shape: [batch, heads, q_len, k_chunk_len]
Output shapes: M[batch, heads, q_len], L[batch, heads, q_len]
"""
block_id = tl.program_id(0)
head_id = tl.program_id(1)
batch_id = tl.program_id(2)
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
offs_k = tl.arange(0, segment_size)
num_iters = k_len // segment_size
# For causal: compute boundary
if is_causal:
# causal boundary: Q position where this KV chunk starts to be valid
# Q[i] can attend K[j] if i >= j
# For KV chunk at kv_offset, Q[i] can attend if i >= kv_offset
num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1 - kv_offset) // segment_size
num_iters_before_causal = tl.minimum(num_iters_before_causal, num_iters)
num_iters_before_causal = tl.maximum(num_iters_before_causal, 0)
else:
num_iters_before_causal = num_iters
# Online softmax state
m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")
l_i = tl.zeros([block_size], dtype=tl.float32)
# Input pointer
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
# Compute max and sum (before causal boundary)
for iter in range(0, num_iters_before_causal):
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
m_local = tl.max(X, 1)
m_new = tl.maximum(m_i, m_local)
alpha = tl.math.exp2(m_i - m_new)
X = X - m_new[:, None]
l_local = tl.sum(tl.math.exp2(X), 1)
l_i = l_i * alpha + l_local
m_i = m_new
# Handle causal boundary
if is_causal:
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
if iter < num_iters:
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
# causal mask: Q[i] >= K[j] + kv_offset
mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size + kv_offset)
X = tl.where(mask, X, -1.0e6)
m_local = tl.max(X, 1)
m_new = tl.maximum(m_i, m_local)
alpha = tl.math.exp2(m_i - m_new)
X = X - m_new[:, None]
l_local = tl.sum(tl.math.exp2(X), 1)
l_i = l_i * alpha + l_local
m_i = m_new
# Output pointers
m_ptr = M_out + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
l_ptr = L_out + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
offs = tl.arange(0, block_size)
tl.store(m_ptr + offs, m_i.to(M_out.type.element_ty))
tl.store(l_ptr + offs, l_i.to(L_out.type.element_ty))
@triton.jit
def softmax_normalize_block_sum_kernel(
In,
Out,
M_global, # global max per row
L_global, # global sum per row
scale,
input_stride_0,
input_stride_1,
input_stride_2,
output_stride_0,
output_stride_1,
output_stride_2,
stats_stride_0,
stats_stride_1,
real_q_len,
k_len,
chunk_start,
kv_offset, # KV chunk offset (for causal)
segment_size: tl.constexpr,
block_size: tl.constexpr,
is_causal: tl.constexpr,
):
"""
Normalize with global stats and compute block sums for a KV chunk.
Uses pre-computed global m and l to correctly normalize softmax
across all KV chunks.
Input shape: [batch, heads, q_len, k_chunk_len]
Output shape: [batch, heads, q_blocks, k_chunk_blocks]
"""
block_id = tl.program_id(0)
head_id = tl.program_id(1)
batch_id = tl.program_id(2)
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
offs_k = tl.arange(0, segment_size)
num_iters = k_len // segment_size
# For causal: compute boundary
if is_causal:
num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1 - kv_offset) // segment_size
num_iters_before_causal = tl.minimum(num_iters_before_causal, num_iters)
num_iters_before_causal = tl.maximum(num_iters_before_causal, 0)
else:
num_iters_before_causal = num_iters
# Load global stats
m_ptr = M_global + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
l_ptr = L_global + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
offs = tl.arange(0, block_size)
m_global = tl.load(m_ptr + offs).to(tl.float32)
l_global = tl.load(l_ptr + offs).to(tl.float32)
# Handle l_global = 0 (when all positions are masked)
l_global_safe = tl.where(l_global > 0, l_global, 1.0)
l_global_inv = 1.0 / l_global_safe
# Input pointer
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
# Output pointer
output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
output_ptr = output_ptr + tl.arange(0, segment_size // block_size)
sum_mask = offs_q[:, None] < real_q_len
# Normalize and compute block sums (before causal boundary)
for iter in range(0, num_iters_before_causal):
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
X = tl.exp2(X - m_global[:, None]) * l_global_inv[:, None]
X = tl.where(sum_mask, X, 0)
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
X = tl.sum(X, 2)
X = tl.sum(X, 0)
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
# Handle causal boundary
if is_causal:
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
if iter < num_iters:
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
# causal mask: Q[i] >= K[j] + kv_offset
causal_mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size + kv_offset)
X = tl.where(causal_mask, X, -1.0e6)
X = tl.exp2(X - m_global[:, None]) * l_global_inv[:, None]
X = tl.where(sum_mask, X, 0)
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
X = tl.sum(X, 2)
X = tl.sum(X, 0)
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
# Zero out future blocks
for iter in range(num_iters_before_causal + 1, num_iters):
X = tl.zeros([segment_size // block_size], dtype=tl.float32)
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
@triton.jit
def flat_group_gemm_fuse_reshape_kernel(
Q, K, Out,
@@ -380,6 +583,194 @@ def softmax_fuse_block_sum(
return output
def softmax_compute_partial_stats(
attn_weights_slice: torch.Tensor,
reshaped_block_size: int,
segment_size: int,
scale: float,
chunk_start: int = 0,
kv_offset: int = 0,
is_causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute partial softmax statistics for a KV chunk.
This is the first step for KV-chunked softmax computation.
For each query row, computes:
- m: max value in this chunk
- l: sum of exp(x - m) in this chunk
These partial stats can be merged across KV chunks using
`merge_softmax_stats()`, then used with `softmax_normalize_and_block_sum()`.
Args:
attn_weights_slice: Raw attention scores [batch, heads, q_len, k_chunk_len]
reshaped_block_size: Block size in reshaped space
segment_size: Processing segment size
scale: Softmax scale factor
chunk_start: Q chunk start position (in reshaped space)
kv_offset: KV chunk offset (in reshaped space, for causal masking)
is_causal: Whether to apply causal masking
Returns:
Tuple of (m, l) where:
- m: [batch, heads, q_len] max values per row
- l: [batch, heads, q_len] partial sums per row
"""
batch_size, num_heads, q_len, k_len = attn_weights_slice.shape
assert q_len % reshaped_block_size == 0
assert k_len % segment_size == 0
assert attn_weights_slice.stride(-1) == 1
m_out = torch.empty(
(batch_size, num_heads, q_len),
dtype=torch.float32,
device=attn_weights_slice.device
)
l_out = torch.empty(
(batch_size, num_heads, q_len),
dtype=torch.float32,
device=attn_weights_slice.device
)
grid = (q_len // reshaped_block_size, num_heads, batch_size)
softmax_partial_stats_kernel[grid](
attn_weights_slice,
m_out,
l_out,
scale,
attn_weights_slice.stride(0),
attn_weights_slice.stride(1),
attn_weights_slice.stride(2),
m_out.stride(0),
m_out.stride(1),
k_len,
chunk_start,
kv_offset,
segment_size,
reshaped_block_size,
is_causal,
)
return m_out, l_out
def merge_softmax_stats(
m_chunks: list,
l_chunks: list,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Merge partial softmax statistics from multiple KV chunks.
Uses the online softmax merging formula:
m_new = max(m1, m2)
l_new = l1 * exp(m1 - m_new) + l2 * exp(m2 - m_new)
Args:
m_chunks: List of max tensors [batch, heads, q_len] from each chunk
l_chunks: List of sum tensors [batch, heads, q_len] from each chunk
Returns:
Tuple of (m_global, l_global) with same shape as inputs
"""
assert len(m_chunks) == len(l_chunks)
assert len(m_chunks) > 0
# Use log2 scale to match kernel (exp2)
LOG2E = 1.4426950408889634
m_global = m_chunks[0].clone()
l_global = l_chunks[0].clone()
for i in range(1, len(m_chunks)):
m_chunk = m_chunks[i]
l_chunk = l_chunks[i]
m_new = torch.maximum(m_global, m_chunk)
# exp2(m - m_new) = 2^(m - m_new)
l_global = l_global * torch.pow(2.0, m_global - m_new) + l_chunk * torch.pow(2.0, m_chunk - m_new)
m_global = m_new
return m_global, l_global
def softmax_normalize_and_block_sum(
attn_weights_slice: torch.Tensor,
m_global: torch.Tensor,
l_global: torch.Tensor,
reshaped_block_size: int,
segment_size: int,
chunk_start: int,
real_q_len: int,
scale: float,
kv_offset: int = 0,
is_causal: bool = False,
) -> torch.Tensor:
"""
Normalize with global stats and compute block sums for a KV chunk.
This is the second step for KV-chunked softmax computation.
Uses pre-computed global m and l (from `merge_softmax_stats()`)
to correctly normalize softmax values and compute block sums.
Args:
attn_weights_slice: Raw attention scores [batch, heads, q_len, k_chunk_len]
m_global: Global max values [batch, heads, q_len]
l_global: Global sum values [batch, heads, q_len]
reshaped_block_size: Block size in reshaped space
segment_size: Processing segment size
chunk_start: Start position for this chunk (for masking)
real_q_len: Actual Q length (before padding)
scale: Softmax scale factor
kv_offset: KV chunk offset (in reshaped space, for causal masking)
is_causal: Whether to apply causal masking
Returns:
Block-level attention sums [batch, heads, q_blocks, k_chunk_blocks]
"""
batch_size, num_heads, q_len, k_len = attn_weights_slice.shape
assert q_len % reshaped_block_size == 0
assert k_len % segment_size == 0
assert segment_size % reshaped_block_size == 0
assert attn_weights_slice.stride(-1) == 1
output = torch.empty(
(batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size),
dtype=attn_weights_slice.dtype,
device=attn_weights_slice.device
)
grid = (q_len // reshaped_block_size, num_heads, batch_size)
softmax_normalize_block_sum_kernel[grid](
attn_weights_slice,
output,
m_global,
l_global,
scale,
attn_weights_slice.stride(0),
attn_weights_slice.stride(1),
attn_weights_slice.stride(2),
output.stride(0),
output.stride(1),
output.stride(2),
m_global.stride(0),
m_global.stride(1),
real_q_len,
k_len,
chunk_start,
kv_offset,
segment_size,
reshaped_block_size,
is_causal,
)
return output
def flat_group_gemm_fuse_reshape(
query_states: torch.Tensor,
key_states: torch.Tensor,
@@ -419,7 +810,9 @@ def flat_group_gemm_fuse_reshape(
assert key_states.shape[1] == num_heads
assert key_states.shape[3] == head_dim
output = torch.empty(
# Use zeros instead of empty to handle causal early-exit in kernel
# (some blocks may not be written due to causal mask optimization)
output = torch.zeros(
(batch_size, num_heads, q_len // stride, kv_len // stride),
dtype=query_states.dtype,
device=query_states.device
@@ -1067,6 +1460,7 @@ def xattn_estimate_chunked(
)
# Softmax + block sum
# segment_size should match the standard xattn_estimate for consistency
attn_sum = softmax_fuse_block_sum(
attn_weights,
reshaped_block_size,
@@ -1082,6 +1476,14 @@ def xattn_estimate_chunked(
attn_sum = attn_sum[:, :, :q_block_num, :k_block_num]
else:
# PyTorch fallback implementation
# Match Triton kernel exactly for consistency
#
# Triton uses:
# 1. exp2 (base-2 exponential) for softmax
# 2. scale factor includes log2(e) = 1.4426950408889634
# 3. causal mask: q_pos >= k_pos (not q_pos + 1 > k_pos)
# 4. chunk_start for global Q position tracking
# Reshape K: interleave positions and concatenate head dims
reshaped_key = torch.cat(
[(key_states[:, :, k::stride, :]) for k in range(stride)], dim=-1
@@ -1093,49 +1495,58 @@ def xattn_estimate_chunked(
dim=-1,
)
# Use same scale as Triton: includes log2(e) for exp2 compatibility
# Triton: scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
# Convert to float32 for numerical stability (matching Triton)
reshaped_query_f32 = reshaped_query.to(torch.float32)
reshaped_key_f32 = reshaped_key.to(torch.float32)
# Compute attention weights: (B, H, q_len/stride, k_len/stride)
attn_weights = torch.matmul(
reshaped_query, reshaped_key.transpose(2, 3)
) / math.sqrt(head_dim) / stride / norm
reshaped_query_f32, reshaped_key_f32.transpose(2, 3)
) * scale
# Apply causal mask
# Apply causal mask (matching Triton's logic exactly)
if causal:
reshaped_q_positions = reshaped_q_len
causal_mask = torch.zeros(
(batch_size, num_heads, reshaped_q_positions, reshaped_k_len),
device=key_states.device,
dtype=attn_weights.dtype,
# Triton uses: offs_q = chunk_start + block_id * block_size + arange(0, block_size)
# chunk_start = q_start_block * reshaped_block_size
chunk_start = q_start_block * reshaped_block_size
# Create position indices in reshaped space
q_positions = torch.arange(reshaped_q_len, device=attn_weights.device) + chunk_start
k_positions = torch.arange(reshaped_k_len, device=attn_weights.device)
# Triton causal mask: q_pos >= k_pos
causal_mask = q_positions[:, None] >= k_positions[None, :] # (reshaped_q_len, reshaped_k_len)
# Apply causal mask: set future positions to -1e6 (matching Triton)
attn_weights = attn_weights.masked_fill(
~causal_mask.unsqueeze(0).unsqueeze(0), -1e6
)
# Mask out padding in K
if k_pad > 0:
causal_mask[:, :, :, -(k_pad // stride):] = float("-inf")
# Softmax using exp2 (matching Triton exactly)
# Triton: X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
# All computation in float32
attn_max = attn_weights.max(dim=-1, keepdim=True).values
attn_weights_shifted = attn_weights - attn_max
attn_exp2 = torch.exp2(attn_weights_shifted)
attn_sum_exp2 = attn_exp2.sum(dim=-1, keepdim=True)
attn_weights = attn_exp2 / attn_sum_exp2
# Mask out future positions
q_start_reshaped = q_start_pos // stride
for q_idx in range(reshaped_q_positions):
q_pos_reshaped = q_start_reshaped + q_idx
if q_pos_reshaped + 1 < reshaped_k_len:
causal_mask[:, :, q_idx, q_pos_reshaped + 1:] = float("-inf")
# Mask for valid Q positions (matching Triton's sum_mask)
# Triton: sum_mask = offs_q[:, None] < real_q_len
# real_q_len = chunk_start + valid_q_reshaped
chunk_start = q_start_block * reshaped_block_size
real_q_len = chunk_start + valid_q_reshaped
q_positions = torch.arange(reshaped_q_len, device=attn_weights.device) + chunk_start
valid_q_mask = q_positions < real_q_len # (reshaped_q_len,)
# Handle padding in Q
if q_pad > 0:
q_pad_reshaped = q_pad // stride
if q_pad_reshaped > 0:
causal_mask[:, :, -q_pad_reshaped:, :] = float("-inf")
# Zero out invalid Q positions
attn_weights = attn_weights * valid_q_mask.view(1, 1, -1, 1).float()
attn_weights = attn_weights + causal_mask
# Apply softmax
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
# Zero out padded Q positions
if q_pad > 0:
q_pad_reshaped = q_pad // stride
if q_pad_reshaped > 0:
attn_weights[:, :, -q_pad_reshaped:, :] = 0
# Aggregate to block level
# Aggregate to block level (keep in float32)
attn_sum = attn_weights.view(
batch_size,
num_heads,
@@ -1145,6 +1556,9 @@ def xattn_estimate_chunked(
reshaped_block_size,
).sum(dim=-1).sum(dim=-2)
# Convert back to input dtype for consistency
attn_sum = attn_sum.to(query_states.dtype)
# Find blocks that exceed threshold
simple_mask = find_blocks_chunked(
attn_sum,

View File

@@ -0,0 +1,327 @@
"""
DensityObserver - Sparse Attention Density 统计 Observer。
统计两种 density:
1. Compute Density (计算密度): 基于 BSA block size (128)
- density = selected_bsa_blocks / total_causal_bsa_blocks
- GPU-only 和 Offload 模式应该一致
2. Communication Density (通信密度): 基于 CPU block size (如 4096)
- comm_density = selected_cpu_blocks / total_cpu_blocks
- 仅用于 Offload 模式,由于粒度更粗,必然 >= compute density
统计位置:
- GPU-only: xattn_bsa.py compute_prefill() - 只记录 compute density
- Offload: xattn_bsa.py select_blocks() - 记录两种 density
对于 Offload 模式的 Density 计算:
- 不是简单的 avg 或 min
- 而是 sum(selected) / sum(total),正确处理不同 chunk 大小的权重
"""
from typing import List, Dict, Optional, Tuple
import torch
from nanovllm.utils.observer import Observer
class DensityObserver(Observer):
"""
Sparse Attention Density Observer。
记录每层的 density用于验证 GPU-only 和 Offload 模式的一致性。
使用方式:
DensityObserver.enable()
DensityObserver.complete_reset()
# ... run inference ...
DensityObserver.record(layer_id, mask, causal=True)
# 或者使用累积模式 (offload):
DensityObserver.record_counts(layer_id, selected, total)
# ...
DensityObserver.print_summary()
"""
_enabled: bool = False # 默认禁用
# 每层的 compute density 记录 (BSA block 粒度)
# key: layer_id, value: list of density values (每次 prefill chunk 一个)
_layer_densities: Dict[int, List[float]] = {}
# 每层的 communication density 记录 (CPU block 粒度,仅 offload 模式)
_layer_comm_densities: Dict[int, List[float]] = {}
# 累积模式: 记录 selected/total counts (用于 offload 模式)
# 这样可以在所有 chunks 完成后正确计算 density = sum(selected) / sum(total)
_layer_selected_counts: Dict[int, List[int]] = {}
_layer_total_counts: Dict[int, List[int]] = {}
# Mask shape 记录 (用于调试)
_last_q_blocks: int = 0
_last_k_blocks: int = 0
# 模式标记
_mode: str = "unknown" # "gpu_only" or "offload"
@classmethod
def set_mode(cls, mode: str) -> None:
"""设置当前模式 (gpu_only / offload)"""
cls._mode = mode
@classmethod
def record(
cls,
layer_id: int,
mask: torch.Tensor,
causal: bool = True,
) -> float:
"""
记录一层的 density (适用于 GPU-only 模式)。
Args:
layer_id: 层 ID
mask: [batch, heads, q_blocks, k_blocks] boolean tensor
causal: 是否考虑 causal mask (只计算下三角)
Returns:
density 值
"""
if not cls._enabled:
return 0.0
density = cls._compute_density(mask, causal)
# 记录
if layer_id not in cls._layer_densities:
cls._layer_densities[layer_id] = []
cls._layer_densities[layer_id].append(density)
# 记录 mask shape
cls._last_q_blocks = mask.shape[2]
cls._last_k_blocks = mask.shape[3]
return density
@classmethod
def record_counts(
cls,
layer_id: int,
selected_blocks: int,
total_blocks: int,
) -> None:
"""
记录一层的 selected/total block counts (适用于 offload 累积模式)。
使用累积计数而不是直接计算 density这样在所有 chunks 处理完后可以正确计算:
overall_density = sum(selected) / sum(total)
这比 avg(density) 更准确,因为不同 chunk 的 Q 和 K 长度不同。
Args:
layer_id: 层 ID
selected_blocks: 这个 chunk 选中的 blocks 数量
total_blocks: 这个 chunk 的 total possible blocks 数量
"""
if not cls._enabled:
return
# 初始化列表
if layer_id not in cls._layer_selected_counts:
cls._layer_selected_counts[layer_id] = []
if layer_id not in cls._layer_total_counts:
cls._layer_total_counts[layer_id] = []
# 累积记录
cls._layer_selected_counts[layer_id].append(selected_blocks)
cls._layer_total_counts[layer_id].append(total_blocks)
@classmethod
def record_comm_density(
cls,
layer_id: int,
selected_cpu_blocks: int,
total_cpu_blocks: int,
) -> float:
"""
记录一层的 communication density (CPU block 粒度)。
Args:
layer_id: 层 ID
selected_cpu_blocks: 选中的 CPU blocks 数量
total_cpu_blocks: 总 CPU blocks 数量
Returns:
communication density 值
"""
if not cls._enabled:
return 0.0
if total_cpu_blocks == 0:
return 1.0
comm_density = selected_cpu_blocks / total_cpu_blocks
# 记录
if layer_id not in cls._layer_comm_densities:
cls._layer_comm_densities[layer_id] = []
cls._layer_comm_densities[layer_id].append(comm_density)
return comm_density
@classmethod
def _compute_density(cls, mask: torch.Tensor, causal: bool) -> float:
"""计算 mask 的 density"""
batch, heads, q_blocks, k_blocks = mask.shape
if causal:
# 只计算下三角区域
causal_mask = torch.tril(
torch.ones(q_blocks, k_blocks, device=mask.device, dtype=torch.bool)
)
total_blocks = causal_mask.sum().item() * batch * heads
selected_blocks = (mask & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
else:
total_blocks = mask.numel()
selected_blocks = mask.sum().item()
if total_blocks == 0:
return 1.0
return selected_blocks / total_blocks
@classmethod
def complete_reset(cls) -> None:
"""重置所有统计"""
cls._layer_densities = {}
cls._layer_comm_densities = {}
cls._layer_selected_counts = {}
cls._layer_total_counts = {}
cls._last_q_blocks = 0
cls._last_k_blocks = 0
cls._mode = "unknown"
@classmethod
def get_per_layer_density(cls) -> Dict[int, float]:
"""
获取每层的 density。
对于累积模式 (offload): density = sum(selected) / sum(total)
对于直接记录模式 (gpu_only): density = avg(density_values)
"""
result = {}
# 优先使用累积模式 (offload)
if cls._layer_selected_counts:
for layer_id in cls._layer_selected_counts:
selected_list = cls._layer_selected_counts.get(layer_id, [])
total_list = cls._layer_total_counts.get(layer_id, [])
total_selected = sum(selected_list)
total_total = sum(total_list)
if total_total > 0:
result[layer_id] = total_selected / total_total
else:
# 直接记录模式 (gpu_only)
for layer_id, densities in cls._layer_densities.items():
if densities:
result[layer_id] = sum(densities) / len(densities)
return result
@classmethod
def get_overall_density(cls) -> float:
"""
获取所有层的总体 compute density。
对于累积模式 (offload): density = sum(all_selected) / sum(all_total)
对于直接记录模式 (gpu_only): density = avg(all_density_values)
注意: 总体 density 不是简单的 avg(per_layer_density)
而是 sum(all_selected) / sum(all_total),这样可以正确处理权重。
"""
# 优先使用累积模式 (offload)
if cls._layer_selected_counts:
total_selected = 0
total_total = 0
for layer_id in cls._layer_selected_counts:
total_selected += sum(cls._layer_selected_counts[layer_id])
total_total += sum(cls._layer_total_counts.get(layer_id, []))
if total_total > 0:
return total_selected / total_total
return 0.0
# 直接记录模式 (gpu_only)
all_densities = []
for densities in cls._layer_densities.values():
all_densities.extend(densities)
if not all_densities:
return 0.0
return sum(all_densities) / len(all_densities)
@classmethod
def get_overall_comm_density(cls) -> float:
"""获取所有层的平均 communication density"""
all_densities = []
for densities in cls._layer_comm_densities.values():
all_densities.extend(densities)
if not all_densities:
return 0.0
return sum(all_densities) / len(all_densities)
@classmethod
def get_per_layer_comm_density(cls) -> Dict[int, float]:
"""
获取每层的 communication density (CPU block 粒度)。
Returns:
Dict[layer_id, avg_comm_density]
"""
result = {}
for layer_id, densities in cls._layer_comm_densities.items():
if densities:
result[layer_id] = sum(densities) / len(densities)
return result
@classmethod
def get_summary(cls) -> dict:
"""返回统计摘要"""
per_layer = cls.get_per_layer_density()
per_layer_comm = cls.get_per_layer_comm_density()
return {
"mode": cls._mode,
"overall_compute_density": cls.get_overall_density(),
"overall_comm_density": cls.get_overall_comm_density(),
"per_layer_compute_density": per_layer,
"per_layer_comm_density": per_layer_comm,
"num_layers": len(per_layer),
"last_mask_shape": {
"q_blocks": cls._last_q_blocks,
"k_blocks": cls._last_k_blocks,
},
}
@classmethod
def get_min_density(cls) -> Tuple[int, float]:
"""获取最低 density 的层和值"""
per_layer = cls.get_per_layer_density()
if not per_layer:
return -1, 0.0
min_layer = min(per_layer, key=per_layer.get)
return min_layer, per_layer[min_layer]
@classmethod
def print_summary(cls) -> None:
"""打印人类可读的摘要"""
per_layer = cls.get_per_layer_density()
overall = cls.get_overall_density()
min_layer, min_density = cls.get_min_density()
overall_comm = cls.get_overall_comm_density()
print(f"[DensityObserver] Mode: {cls._mode}")
print(f" Compute density: {overall:.4f} (min: {min_density:.4f} @ layer {min_layer})")
if overall_comm > 0:
# Offload mode: show both densities with explanation
print(f" Comm density: {overall_comm:.4f} (CPU block granularity)")
print(f" Savings ratio: {1 - overall_comm:.1%} H2D transfer reduction")
print(f" Num layers: {len(per_layer)}")
# 输出 layer 0 的 density 用于对比
if 0 in per_layer:
print(f" Layer 0 density: {per_layer[0]:.6f}")

View File

@@ -1,4 +1,5 @@
import os
import re
from glob import glob
import torch
from torch import nn
@@ -9,20 +10,146 @@ def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
param.data.copy_(loaded_weight)
# GLM-4 weight name mappings
GLM4_NAME_MAPPING = {
"transformer.embedding.word_embeddings": "model.embed_tokens",
"transformer.encoder.final_layernorm": "model.norm",
"transformer.output_layer": "lm_head",
}
GLM4_LAYER_MAPPING = {
"self_attention.query_key_value": "self_attn.qkv_proj",
"self_attention.dense": "self_attn.o_proj",
"mlp.dense_h_to_4h": "mlp.gate_up_proj",
"mlp.dense_4h_to_h": "mlp.down_proj",
}
def convert_glm4_weight_name(weight_name: str) -> tuple[str, str | None]:
"""
Convert GLM-4 weight name to nanovllm format.
Returns:
tuple: (converted_name, shard_id) where shard_id is used for packed modules
Returns (None, None) for weights that should be skipped
"""
# Skip rotary embedding weights (we use our own RoPE implementation)
if "rotary_pos_emb" in weight_name:
return None, None
# Check direct mappings first
for glm_name, nano_name in GLM4_NAME_MAPPING.items():
if weight_name.startswith(glm_name):
return weight_name.replace(glm_name, nano_name), None
# Handle layer weights: transformer.encoder.layers.X.xxx
layer_match = re.match(r"transformer\.encoder\.layers\.(\d+)\.(.+)", weight_name)
if layer_match:
layer_idx = layer_match.group(1)
remainder = layer_match.group(2)
# Handle packed modules (QKV and gate_up)
for glm_subname, nano_subname in GLM4_LAYER_MAPPING.items():
if remainder.startswith(glm_subname):
suffix = remainder[len(glm_subname):] # .weight or .bias
new_name = f"model.layers.{layer_idx}.{nano_subname}{suffix}"
# Determine shard_id for packed modules
if "qkv_proj" in nano_subname:
return new_name, "qkv" # Special marker for GLM4 QKV
elif "gate_up_proj" in nano_subname:
return new_name, "gate_up" # Special marker for GLM4 gate_up
else:
return new_name, None
# Handle non-packed layer weights (layernorms)
new_name = f"model.layers.{layer_idx}.{remainder}"
return new_name, None
# No mapping found, return original
return weight_name, None
def load_glm4_qkv(param: nn.Parameter, loaded_weight: torch.Tensor, config):
"""Load GLM-4 merged QKV weights by splitting into q, k, v."""
num_heads = config.num_attention_heads
num_kv_heads = getattr(config, 'multi_query_group_num', num_heads)
head_dim = getattr(config, 'kv_channels', config.hidden_size // num_heads)
q_size = num_heads * head_dim
kv_size = num_kv_heads * head_dim
# Split QKV: [q_size + kv_size + kv_size, hidden_size]
q, k, v = loaded_weight.split([q_size, kv_size, kv_size], dim=0)
# Load each part using the weight_loader
weight_loader = getattr(param, "weight_loader")
weight_loader(param, q, "q")
weight_loader(param, k, "k")
weight_loader(param, v, "v")
def load_glm4_gate_up(param: nn.Parameter, loaded_weight: torch.Tensor, config):
"""Load GLM-4 merged gate_up weights by splitting into gate, up."""
ffn_hidden_size = getattr(config, 'ffn_hidden_size', getattr(config, 'intermediate_size', None))
# Split gate_up: [ffn_hidden_size * 2, hidden_size]
gate, up = loaded_weight.split([ffn_hidden_size, ffn_hidden_size], dim=0)
# Load each part using the weight_loader
weight_loader = getattr(param, "weight_loader")
weight_loader(param, gate, 0) # gate_proj is shard 0
weight_loader(param, up, 1) # up_proj is shard 1
def is_glm4_model(model: nn.Module) -> bool:
"""Check if the model is a GLM-4 model."""
return model.__class__.__name__ in ("ChatGLMForCausalLM",)
def load_model(model: nn.Module, path: str):
packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
is_glm4 = is_glm4_model(model)
config = getattr(model, "config", None)
for file in glob(os.path.join(path, "*.safetensors")):
with safe_open(file, "pt", "cpu") as f:
for weight_name in f.keys():
loaded_weight = f.get_tensor(weight_name)
# GLM-4 specific handling
if is_glm4:
param_name, shard_id = convert_glm4_weight_name(weight_name)
# Skip weights that don't need to be loaded
if param_name is None:
continue
if shard_id == "qkv":
param = model.get_parameter(param_name)
load_glm4_qkv(param, loaded_weight, config)
continue
elif shard_id == "gate_up":
param = model.get_parameter(param_name)
load_glm4_gate_up(param, loaded_weight, config)
continue
else:
# Regular weight, use converted name
param = model.get_parameter(param_name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
continue
# Original loading logic for other models
for k in packed_modules_mapping:
if k in weight_name:
v, shard_id = packed_modules_mapping[k]
param_name = weight_name.replace(k, v)
param = model.get_parameter(param_name)
weight_loader = getattr(param, "weight_loader")
weight_loader(param, f.get_tensor(weight_name), shard_id)
weight_loader(param, loaded_weight, shard_id)
break
else:
param = model.get_parameter(weight_name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, f.get_tensor(weight_name))
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,133 @@
"""
MemoryObserver - 内存传输统计 Observer。
统计 GPU-CPU 间的数据传输量:
- H2D (Host to Device): CPU → GPU
- D2H (Device to Host): GPU → CPU
- D2D (Device to Device): GPU → GPU (buffer copy)
"""
from nanovllm.utils.observer import Observer
class MemoryObserver(Observer):
"""
内存传输 Observer统计 GPU-CPU 间的数据传输量。
统计类型:
- H2D (Host to Device): CPU → GPU
- D2H (Device to Host): GPU → CPU
- D2D (Device to Device): GPU → GPU (buffer copy)
统计位置(均在 offload_engine.py
- H2D: load_to_slot_layer(), load_block_sample_from_cpu(), load_block_full_from_cpu()
- D2H: offload_slot_layer_to_cpu(), offload_prefill_buffer_async()
- D2D: write_to_prefill_buffer(), write_to_decode_buffer()
- 重置: llm_engine.py:generate() - 与 InferenceObserver 一起重置
"""
_enabled: bool = False # 默认禁用,需要显式启用
# H2D 统计
h2d_bytes: int = 0
h2d_count: int = 0
# D2H 统计
d2h_bytes: int = 0
d2h_count: int = 0
# D2D 统计
d2d_bytes: int = 0
d2d_count: int = 0
# 按阶段统计
prefill_h2d_bytes: int = 0
prefill_d2h_bytes: int = 0
decode_h2d_bytes: int = 0
decode_d2h_bytes: int = 0
@classmethod
def record_h2d(cls, num_bytes: int, is_prefill: bool = True) -> None:
"""记录 H2D 传输"""
if not cls._enabled:
return
cls.h2d_bytes += num_bytes
cls.h2d_count += 1
if is_prefill:
cls.prefill_h2d_bytes += num_bytes
else:
cls.decode_h2d_bytes += num_bytes
@classmethod
def record_d2h(cls, num_bytes: int, is_prefill: bool = True) -> None:
"""记录 D2H 传输"""
if not cls._enabled:
return
cls.d2h_bytes += num_bytes
cls.d2h_count += 1
if is_prefill:
cls.prefill_d2h_bytes += num_bytes
else:
cls.decode_d2h_bytes += num_bytes
@classmethod
def record_d2d(cls, num_bytes: int) -> None:
"""记录 D2D 传输"""
if not cls._enabled:
return
cls.d2d_bytes += num_bytes
cls.d2d_count += 1
@classmethod
def complete_reset(cls) -> None:
"""重置所有统计"""
cls.h2d_bytes = cls.h2d_count = 0
cls.d2h_bytes = cls.d2h_count = 0
cls.d2d_bytes = cls.d2d_count = 0
cls.prefill_h2d_bytes = cls.prefill_d2h_bytes = 0
cls.decode_h2d_bytes = cls.decode_d2h_bytes = 0
@classmethod
def get_summary(cls) -> dict:
"""返回统计摘要"""
return {
"total": {
"h2d_bytes": cls.h2d_bytes,
"h2d_count": cls.h2d_count,
"d2h_bytes": cls.d2h_bytes,
"d2h_count": cls.d2h_count,
"d2d_bytes": cls.d2d_bytes,
"d2d_count": cls.d2d_count,
},
"prefill": {
"h2d_bytes": cls.prefill_h2d_bytes,
"d2h_bytes": cls.prefill_d2h_bytes,
},
"decode": {
"h2d_bytes": cls.decode_h2d_bytes,
"d2h_bytes": cls.decode_d2h_bytes,
},
}
@classmethod
def _fmt_bytes(cls, b: int) -> str:
"""格式化字节数"""
if b >= 1e9:
return f"{b/1e9:.2f} GB"
if b >= 1e6:
return f"{b/1e6:.2f} MB"
if b >= 1e3:
return f"{b/1e3:.2f} KB"
return f"{b} B"
@classmethod
def print_summary(cls) -> None:
"""打印人类可读的摘要"""
fmt = cls._fmt_bytes
total = cls.h2d_bytes + cls.d2h_bytes + cls.d2d_bytes
print(f"[MemoryObserver] Total: {fmt(total)}")
print(f" H2D: {fmt(cls.h2d_bytes)} ({cls.h2d_count} ops)")
print(f" D2H: {fmt(cls.d2h_bytes)} ({cls.d2h_count} ops)")
print(f" D2D: {fmt(cls.d2d_bytes)} ({cls.d2d_count} ops)")
print(f" Prefill - H2D: {fmt(cls.prefill_h2d_bytes)}, D2H: {fmt(cls.prefill_d2h_bytes)}")
print(f" Decode - H2D: {fmt(cls.decode_h2d_bytes)}, D2H: {fmt(cls.decode_d2h_bytes)}")

View File

@@ -1,17 +1,106 @@
class Observer():
ttft_start = 0
tpot_start = 0
"""
Observer 基类和 InferenceObserver 实现。
ttft = 0
tpot = 0
Observer 架构:
- Observer: 基类,定义通用接口
- InferenceObserver: 推理性能观测TTFT/TPOT
- MemoryObserver: 内存传输观测(在 memory_observer.py 中定义)
"""
class Observer:
"""
Observer 基类,提供通用的启用/禁用、重置、输出接口。
所有 Observer 子类应继承此类并实现:
- complete_reset(): 重置所有统计数据
- get_summary(): 返回统计摘要 dict
- print_summary(): 打印人类可读的摘要
"""
_enabled: bool = True # 默认启用
@classmethod
def reset_ttft(cls):
def enable(cls) -> None:
"""启用 observer"""
cls._enabled = True
@classmethod
def disable(cls) -> None:
"""禁用 observer"""
cls._enabled = False
@classmethod
def is_enabled(cls) -> bool:
"""检查是否启用"""
return cls._enabled
@classmethod
def complete_reset(cls) -> None:
"""重置所有统计数据(子类实现)"""
raise NotImplementedError
@classmethod
def get_summary(cls) -> dict:
"""返回统计摘要(子类实现)"""
raise NotImplementedError
@classmethod
def print_summary(cls) -> None:
"""打印人类可读的摘要(子类可选覆盖)"""
import json
print(json.dumps(cls.get_summary(), indent=2))
class InferenceObserver(Observer):
"""
推理性能 Observer统计 TTFT 和 TPOT。
- TTFT (Time To First Token): 首个 token 生成延迟
- TPOT (Time Per Output Token): 每个输出 token 的平均延迟
统计位置:
- TTFT 开始: scheduler.py:35-36 - 第一个 sequence 从 waiting 队列取出时
- TTFT 结束: llm_engine.py:69-72 - prefill 完成后(包括 chunked prefill 所有 chunks
- TPOT 开始: llm_engine.py:65 - 每次 decode step 结束时
- TPOT 结束: llm_engine.py:62-63 - 下一次 decode step 开始时计算(测量上一次 decode 时间)
- 重置: llm_engine.py:97 - generate() 开始时
注意TPOT 需要至少 2 个输出 token 才能计算(测量 decode step 间隔)。
"""
# 时间戳 (nanoseconds)
ttft_start: int = 0
tpot_start: int = 0
# 统计结果 (nanoseconds)
ttft: int = 0
tpot: int = 0
@classmethod
def reset_ttft(cls) -> None:
"""重置 TTFT 计时器"""
cls.ttft_start = 0
@classmethod
def complete_reset(cls):
def complete_reset(cls) -> None:
"""重置所有统计数据"""
cls.ttft_start = 0
cls.tpot_start = 0
cls.ttft = 0
cls.tpot = 0
@classmethod
def get_summary(cls) -> dict:
"""返回统计摘要"""
return {
"ttft_ns": cls.ttft,
"ttft_ms": cls.ttft / 1e6,
"tpot_ns": cls.tpot,
"tpot_ms": cls.tpot / 1e6,
}
@classmethod
def print_summary(cls) -> None:
"""打印摘要"""
print(f"[InferenceObserver] TTFT: {cls.ttft / 1e6:.2f}ms, TPOT: {cls.tpot / 1e6:.2f}ms")

View File

@@ -1,55 +0,0 @@
# Progress: CUDA Graph for Offload Mode
## Session: 2026-01-22
### 调研阶段 ✅ 完成
**完成的调研**:
1. ✅ 分析 `model_runner.py` 中的 CUDA Graph 实现
- `capture_cudagraph()`: 为不同 batch size 捕获完整 model forward
- `run_model()`: 通过 `is_chunked_prefill` 决定 eager/graph
2. ✅ 分析 offload decode 流程
- `run_chunked_offload_decode()` 设置 `is_chunked_prefill=True`
- 导致永远使用 eager mode
3. ✅ 分析 ring buffer pipeline
- `_decode_ring_buffer_pipeline()` 包含 H2D 传输 + attention 计算
- H2D 不能 graphattention 可以 graph
4. ✅ 验证 graph 复用策略
- 创建 `test_chunk_attention_graph_reuse.py`
- 确认 2 个 graph 可复用于所有层
### 计划编写 ✅ 完成
- ✅ 创建 `task_plan.md`
- ✅ 创建 `findings.md`
- ✅ 创建 `progress.md`
### 下一步: 实现
**Phase 1**: 添加 graph 捕获到 OffloadEngine
- [ ]`offload_engine.py` 添加 `capture_attention_graphs()`
- [ ] 添加 `attention_graph_causal``attention_graph_non_causal` 属性
**Phase 2**: 修改 ring buffer pipeline
- [ ]`_decode_ring_buffer_pipeline()` 使用 graph replay
- [ ] 保持 H2D 和 merge 为 eager
**Phase 3**: 测试
- [ ] 运行 needle test 验证正确性
- [ ] 对比性能
---
## 文件清单
| 文件 | 状态 | 说明 |
|------|------|------|
| `tests/test_chunk_attention_graph.py` | ✅ 已提交 | 预分配 chunk pair graphs 测试 |
| `tests/test_chunk_attention_graph_reuse.py` | 待提交 | Graph 复用验证 |
| `task_plan.md` | ✅ 创建 | 实现计划 |
| `findings.md` | ✅ 创建 | 调研发现 |
| `progress.md` | ✅ 创建 | 进度日志 |

158
scripts/profile.sh Executable file
View File

@@ -0,0 +1,158 @@
#!/bin/bash
# Profile bench.py using NVIDIA Nsight Systems (GPU-only mode)
#
# Usage:
# bash scripts/profile.sh [options]
#
# Options:
# --max-len LENGTH Max sequence length (default: 32768)
# --policy POLICY Sparse policy: full, xattn (default: xattn)
# --gpu GPU_ID GPU to use (default: 0)
# --gpu-util UTIL GPU memory utilization (default: 0.9)
# --input-len LENGTH Input length (default: max-len - 1)
# --bench-decode Run decode benchmark instead of prefill
#
# Output:
# results/nsys/bench_<policy>_<max_len>_<timestamp>.nsys-rep
#
# Examples:
# bash scripts/profile.sh
# bash scripts/profile.sh --max-len 65536 --gpu-util 0.7
# bash scripts/profile.sh --policy full --max-len 32768
# bash scripts/profile.sh --bench-decode
set -e
# Default configuration
MAX_LEN="32768"
POLICY="xattn"
GPU_ID="0"
GPU_UTIL="0.9"
INPUT_LEN=""
BENCH_MODE="prefill"
# Parse arguments
while [[ $# -gt 0 ]]; do
case $1 in
--max-len)
MAX_LEN="$2"
shift 2
;;
--policy)
POLICY="$2"
shift 2
;;
--gpu)
GPU_ID="$2"
shift 2
;;
--gpu-util)
GPU_UTIL="$2"
shift 2
;;
--input-len)
INPUT_LEN="$2"
shift 2
;;
--bench-decode)
BENCH_MODE="decode"
shift
;;
-h|--help)
echo "Usage: $0 [options]"
echo ""
echo "Options:"
echo " --max-len LENGTH Max sequence length (default: 32768)"
echo " --policy POLICY Sparse policy: full, xattn (default: xattn)"
echo " --gpu GPU_ID GPU to use (default: 0)"
echo " --gpu-util UTIL GPU memory utilization (default: 0.9)"
echo " --input-len LENGTH Input length (default: max-len - 1)"
echo " --bench-decode Run decode benchmark instead of prefill"
exit 0
;;
*)
echo "Unknown option: $1"
exit 1
;;
esac
done
# Path configuration
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
OUTPUT_DIR="$PROJECT_ROOT/results/nsys"
BENCH_SCRIPT="$PROJECT_ROOT/bench.py"
# Create output directory if needed
mkdir -p "$OUTPUT_DIR"
# Generate timestamp for unique filename
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
# Convert max_len to human-readable format (e.g., 32768 -> 32k)
if [ "$MAX_LEN" -ge 1024 ]; then
MAX_LEN_SUFFIX="$((MAX_LEN / 1024))k"
else
MAX_LEN_SUFFIX="${MAX_LEN}"
fi
OUTPUT_FILE="$OUTPUT_DIR/bench_${POLICY}_${MAX_LEN_SUFFIX}_${BENCH_MODE}_${TIMESTAMP}"
# Build bench.py arguments
BENCH_ARGS="--max-len $MAX_LEN --gpu-util $GPU_UTIL"
if [ -n "$POLICY" ]; then
BENCH_ARGS="$BENCH_ARGS --policy $POLICY"
fi
if [ -n "$INPUT_LEN" ]; then
BENCH_ARGS="$BENCH_ARGS --input-len $INPUT_LEN"
fi
if [ "$BENCH_MODE" = "decode" ]; then
BENCH_ARGS="$BENCH_ARGS --bench-decode"
fi
echo "============================================================"
echo "NVIDIA Nsight Systems Profiling (GPU-only)"
echo "============================================================"
echo "Bench script: $BENCH_SCRIPT"
echo "Policy: $POLICY"
echo "Max length: $MAX_LEN"
echo "GPU: $GPU_ID"
echo "GPU util: $GPU_UTIL"
echo "Bench mode: $BENCH_MODE"
echo "Output file: $OUTPUT_FILE.nsys-rep"
echo ""
# nsys profile options:
# --trace=cuda,nvtx : Trace CUDA API and NVTX markers
# --force-overwrite=true : Overwrite existing output file
# --output=<path> : Output file path (without .nsys-rep extension)
echo "Running nsys profile..."
echo "Command: python bench.py $BENCH_ARGS"
echo ""
CUDA_VISIBLE_DEVICES=$GPU_ID PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH" \
nsys profile \
--trace=cuda,nvtx \
--force-overwrite=true \
--output="$OUTPUT_FILE" \
python "$BENCH_SCRIPT" $BENCH_ARGS
echo ""
echo "============================================================"
echo "Profiling completed successfully!"
echo "============================================================"
echo "Output file: $OUTPUT_FILE.nsys-rep"
echo ""
echo "To view results in GUI:"
echo " nsight-sys $OUTPUT_FILE.nsys-rep"
echo ""
echo "To export statistics:"
echo " nsys stats --report cuda_api_sum $OUTPUT_FILE.nsys-rep"
echo " nsys stats --report cuda_gpu_kern_sum $OUTPUT_FILE.nsys-rep"
echo " nsys stats --report cuda_gpu_mem_size_sum $OUTPUT_FILE.nsys-rep"
echo "============================================================"

View File

@@ -1,35 +1,171 @@
#!/bin/bash
# Profile test_attention_offload.py using NVIDIA Nsight Systems
# Profile test_ruler.py using NVIDIA Nsight Systems
#
# Usage:
# bash scripts/profile_offload.sh
# bash scripts/profile_offload.sh [options]
#
# Options:
# --policy POLICY Sparse policy name (default: full)
# --ctx-len LENGTH Context length: 32k, 64k, 128k (default: 64k)
# --dataset DATASET Task name (default: niah_single_1)
# --sample INDEX Sample index (default: 0)
# --gpu GPU_ID GPU to use (default: 0)
# --num-gpu-blocks N Number of GPU blocks/slots (default: 4)
# --block-size SIZE KV cache block size (default: 4096)
# --no-offload Disable CPU offload
#
# Output:
# results/nsys/attention_offload_<timestamp>.nsys-rep
# results/nsys/<policy>_<gpuonly|offload>_<ctx-len>_blk<size>_<timestamp>.nsys-rep
#
# View results:
# nsight-sys results/nsys/attention_offload_<timestamp>.nsys-rep
# Examples:
# bash scripts/profile_offload.sh
# bash scripts/profile_offload.sh --policy xattn --ctx-len 128k --no-offload
# bash scripts/profile_offload.sh --policy full --ctx-len 32k --num-gpu-blocks 8
set -e
# Default configuration
POLICY="full"
CTX_LEN="64k"
DATASET="niah_single_1"
SAMPLE_INDEX="0"
GPU_ID="0"
NUM_GPU_BLOCKS="4"
BLOCK_SIZE="4096"
GPU_UTIL="0.9"
ENABLE_OFFLOAD="--enable-offload"
MODEL=""
DATA_DIR_OVERRIDE=""
# Configuration
# Parse arguments
while [[ $# -gt 0 ]]; do
case $1 in
--policy)
POLICY="$2"
shift 2
;;
--ctx-len)
CTX_LEN="$2"
shift 2
;;
--dataset)
DATASET="$2"
shift 2
;;
--sample)
SAMPLE_INDEX="$2"
shift 2
;;
--gpu)
GPU_ID="$2"
shift 2
;;
--no-offload)
ENABLE_OFFLOAD=""
shift
;;
--num-gpu-blocks)
NUM_GPU_BLOCKS="$2"
shift 2
;;
--gpu-util)
GPU_UTIL="$2"
shift 2
;;
--block-size)
BLOCK_SIZE="$2"
shift 2
;;
--model)
MODEL="$2"
shift 2
;;
--data-dir)
DATA_DIR_OVERRIDE="$2"
shift 2
;;
-h|--help)
echo "Usage: $0 [options]"
echo ""
echo "Options:"
echo " --policy POLICY Sparse policy name (default: full)"
echo " --ctx-len LENGTH Context length: 32k, 64k, 128k (default: 64k)"
echo " --block-size SIZE KV cache block size (default: 4096)"
echo " --dataset DATASET Task name (default: niah_single_1)"
echo " --sample INDEX Sample index (default: 0)"
echo " --gpu GPU_ID GPU to use (default: 0)"
echo " --gpu-util UTIL GPU memory utilization (default: 0.9)"
echo " --no-offload Disable CPU offload"
echo " --num-gpu-blocks N Number of GPU blocks/slots (default: 4)"
exit 0
;;
*)
echo "Unknown option: $1"
exit 1
;;
esac
done
# Path configuration
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
OUTPUT_DIR="$PROJECT_ROOT/results/nsys"
TEST_SCRIPT="$PROJECT_ROOT/tests/test_attention_offload.py"
TEST_SCRIPT="$PROJECT_ROOT/tests/test_ruler.py"
DATA_DIR="$PROJECT_ROOT/tests/data/ruler_${CTX_LEN}"
# Set max-model-len based on context length
case "$CTX_LEN" in
32k)
MAX_MODEL_LEN=36000
;;
64k)
MAX_MODEL_LEN=72000
;;
128k)
MAX_MODEL_LEN=144000
;;
256k)
MAX_MODEL_LEN=288000
;;
512k)
MAX_MODEL_LEN=576000
;;
1m)
MAX_MODEL_LEN=1100000
;;
*)
MAX_MODEL_LEN=72000
;;
esac
# Override DATA_DIR if specified
if [ -n "$DATA_DIR_OVERRIDE" ]; then
DATA_DIR="$DATA_DIR_OVERRIDE"
fi
# Create output directory if needed
mkdir -p "$OUTPUT_DIR"
# Generate timestamp for unique filename
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
OUTPUT_FILE="$OUTPUT_DIR/attention_offload_$TIMESTAMP"
if [ -n "$ENABLE_OFFLOAD" ]; then
OFFLOAD_TAG="offload"
else
OFFLOAD_TAG="gpuonly"
fi
OUTPUT_FILE="$OUTPUT_DIR/${POLICY}_${OFFLOAD_TAG}_${CTX_LEN}_blk${BLOCK_SIZE}_${TIMESTAMP}"
echo "============================================================"
echo "NVIDIA Nsight Systems Profiling"
echo "============================================================"
echo "Test script: $TEST_SCRIPT"
echo "Policy: $POLICY"
echo "Offload: $OFFLOAD_TAG"
echo "Context: $CTX_LEN"
echo "Block Size: $BLOCK_SIZE"
echo "Dataset: $DATASET"
echo "Sample: $SAMPLE_INDEX"
echo "GPU: $GPU_ID"
echo "GPU Blocks: $NUM_GPU_BLOCKS"
echo "Data Dir: $DATA_DIR"
echo "Output file: $OUTPUT_FILE.nsys-rep"
echo ""
@@ -43,13 +179,59 @@ echo ""
echo "Running nsys profile..."
echo ""
# Map policy name to internal enum name
# User-friendly name -> SparsePolicyType enum name
case "$POLICY" in
xattn)
POLICY_ENUM="XATTN_BSA"
;;
*)
POLICY_ENUM="$POLICY"
;;
esac
# Build sparse policy argument
SPARSE_POLICY_ARG=""
if [ -n "$POLICY_ENUM" ] && [ "$POLICY_ENUM" != "full" ]; then
SPARSE_POLICY_ARG="--sparse-policy $POLICY_ENUM"
fi
# Build model argument
MODEL_ARG=""
if [ -n "$MODEL" ]; then
MODEL_ARG="--model $MODEL"
fi
# Run nsys profile and capture exit code
CUDA_VISIBLE_DEVICES=$GPU_ID PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH" \
nsys profile \
--trace=cuda,nvtx,osrt,cudnn,cublas \
--cuda-memory-usage=true \
--stats=true \
--trace=cuda,nvtx \
--force-overwrite=true \
--output="$OUTPUT_FILE" \
python "$TEST_SCRIPT"
python "$TEST_SCRIPT" \
--data-dir "$DATA_DIR" \
--datasets "$DATASET" \
--sample-indices "$SAMPLE_INDEX" \
--num-gpu-blocks "$NUM_GPU_BLOCKS" \
--block-size "$BLOCK_SIZE" \
--max-model-len "$MAX_MODEL_LEN" \
--gpu-utilization "$GPU_UTIL" \
$ENABLE_OFFLOAD \
$SPARSE_POLICY_ARG \
$MODEL_ARG \
--quiet
EXIT_CODE=$?
# If test failed, delete the output file
if [ $EXIT_CODE -ne 0 ]; then
echo ""
echo "============================================================"
echo "Test FAILED! Cleaning up..."
echo "============================================================"
rm -f "$OUTPUT_FILE.nsys-rep"
echo "Deleted: $OUTPUT_FILE.nsys-rep"
exit $EXIT_CODE
fi
echo ""
echo "============================================================"

View File

@@ -1,357 +0,0 @@
# Task Plan: CUDA Graph 优化 Offload Mode Decode
## 目标
为 nanovllm 的 CPU offload 模式添加 CUDA Graph 支持,加速 decode 阶段的计算。
## 问题分析
### Transformer 层的完整结构
```
Qwen3DecoderLayer.forward:
├── input_layernorm (RMSNorm) # ✅ 纯 GPU
├── self_attn:
│ ├── qkv_proj (Linear) # ✅ 纯 GPU
│ ├── q_norm, k_norm (RMSNorm) # ✅ 纯 GPU
│ ├── rotary_emb # ✅ 纯 GPU
│ ├── attn._chunked_decode_attention: # ⚠️ 包含 CPU→GPU
│ │ ├── H2D transfer # ❌ 不能 graph
│ │ ├── flash_attn_with_lse # ✅ 可以 graph
│ │ └── merge # ✅ 纯 GPU
│ └── o_proj (Linear) # ✅ 纯 GPU
├── post_attention_layernorm # ✅ 纯 GPU
└── mlp (FFN: gate, up, down) # ✅ 纯 GPU
```
**核心问题**H2D 传输被嵌在 attention 中间,打断了整层的 graph 捕获。
### 可能的方案
| 方案 | 描述 | 优点 | 缺点 |
|------|------|------|------|
| A. 分段 Graph | 将层拆分为 pre/post attention 两段 | 覆盖面广 | 改动大,需拆分层执行 |
| B. 只 Graph Attention | 只优化 flash_attn_with_lse | 改动小 | 优化效果有限 |
| C. 重构执行流程 | 完全重写 model forward | 最优效果 | 工作量巨大 |
### 推荐:方案 A分段 Graph
将每层拆分为两个 graph
1. **pre_attention_graph**: `norm → qkv_proj → q/k_norm → rotary`
2. **post_attention_graph**: `o_proj → norm → FFN`
中间的 `_chunked_decode_attention` 保持 eager包含 H2D但内部的 `flash_attn_with_lse` 使用 graph。
---
## 当前状态分析
### 现有 CUDA Graph 实现
**文件**: `nanovllm/engine/model_runner.py`
| 方法 | 行号 | 功能 |
|------|------|------|
| `capture_cudagraph()` | 682-717 | 为不同 batch size 捕获完整 model forward |
| `run_model()` | 415-436 | 决定使用 eager 还是 graph replay |
**关键逻辑** (`run_model`):
```python
use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
```
**问题**: `run_chunked_offload_decode` 设置 `is_chunked_prefill=True`,导致**永远使用 eager mode**。
### Offload Decode 流程
**文件**: `nanovllm/kvcache/sparse/full_policy.py`
`_decode_ring_buffer_pipeline()` (L304-379):
```
for block in cpu_blocks:
1. wait_slot_layer(slot) # 等待 H2D 完成
2. k, v = get_kv_for_slot(slot) # 获取 KV
3. o, lse = flash_attn_with_lse() # ⭐ 纯 GPU 计算
4. record_slot_compute_done(slot) # 标记计算完成
5. load_next_block() # 启动下一个 H2D
6. merge_attention_outputs() # ⭐ 纯 GPU 计算
```
**可 Graph 化的部分**:
- `flash_attn_with_lse()` - 纯 GPU 计算
- 不可 Graph 化: H2D 传输、动态 merge
## 验证结果
**测试文件**: `tests/test_chunk_attention_graph_reuse.py`
| 测试 | 结果 |
|------|------|
| 2 个 Graph 复用于所有层和所有 chunk | ✅ PASSED |
| copy_() 更新 static tensors | ✅ 有效 |
| Eager merge | ✅ 用户已接受 |
**结论**: 只需 2 个 graphcausal + non-causal通过 copy_() 复用。
---
## 修改计划(方案 A分段 Graph
### 架构设计
```
每层执行流程Offload Decode:
┌─────────────────────────────────────────────────────────────┐
│ PRE-ATTENTION GRAPH (可复用于所有层) │
│ input_layernorm → qkv_proj → q/k_norm → rotary → split Q │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ CHUNKED ATTENTION (Eager + 部分 Graph) │
│ for block in cpu_blocks: │
│ H2D transfer (eager) │
│ flash_attn_with_lse (GRAPH - 2个可复用) │
│ merge (eager) │
│ decode_buffer attention (eager) │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ POST-ATTENTION GRAPH (可复用于所有层) │
│ o_proj → post_layernorm → gate_proj → up_proj → SiLU │
│ → down_proj → residual │
└─────────────────────────────────────────────────────────────┘
```
**总共需要的 Graph 数量**:
- 1 个 pre_attention_graph所有层复用
- 2 个 attention_graphcausal + non-causal所有层复用
- 1 个 post_attention_graph所有层复用
- **总计: 4 个 graph**
---
### Phase 1: 拆分 DecoderLayer 执行
**目标**: 将 `Qwen3DecoderLayer.forward` 拆分为可独立调用的三段
**修改文件**: `nanovllm/models/qwen3.py`
**新增方法**:
```python
class Qwen3DecoderLayer:
def forward_pre_attention(self, positions, hidden_states, residual):
"""Pre-attention: norm → qkv → rotary → 返回 q, k, v"""
if residual is None:
hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
qkv = self.self_attn.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
q = self.self_attn.q_norm(q)
k = self.self_attn.k_norm(k)
q, k = self.self_attn.rotary_emb(positions, q, k)
return q, k, v, hidden_states, residual
def forward_post_attention(self, attn_output, hidden_states, residual):
"""Post-attention: o_proj → norm → FFN"""
output = self.self_attn.o_proj(attn_output.flatten(1, -1))
hidden_states, residual = self.post_attention_layernorm(output, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
```
**状态**: `pending`
---
### Phase 2: 捕获 Pre/Post Attention Graph
**目标**: 捕获 pre_attention 和 post_attention 的 graph
**修改文件**: `nanovllm/engine/model_runner.py`
**新增方法**: `capture_offload_layer_graphs()`
```python
def capture_offload_layer_graphs(self):
"""捕获 offload mode 的 layer graphs"""
# 获取任意一层作为模板(所有层结构相同)
layer = self.model.model.layers[0]
# Static tensors
static_hidden = torch.zeros(1, self.hidden_size, ...)
static_residual = torch.zeros(1, self.hidden_size, ...)
static_positions = torch.zeros(1, ...)
# Pre-attention graph
self.pre_attn_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.pre_attn_graph):
static_q, static_k, static_v, _, _ = layer.forward_pre_attention(
static_positions, static_hidden, static_residual
)
# Post-attention graph
self.post_attn_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.post_attn_graph):
_, _ = layer.forward_post_attention(
static_attn_output, static_hidden, static_residual
)
```
**状态**: `pending`
---
### Phase 3: 捕获 Attention Graph
**目标**: 捕获 2 个 attention graphcausal + non-causal
**修改文件**: `nanovllm/kvcache/offload_engine.py`
```python
class OffloadEngine:
def capture_attention_graphs(self):
"""捕获 attention graphs复用于所有层"""
self.attn_graph_causal = self._capture_attn_graph(causal=True)
self.attn_graph_non_causal = self._capture_attn_graph(causal=False)
def _capture_attn_graph(self, causal: bool):
static_q = torch.zeros(1, 1, num_heads, head_dim, ...)
static_k = torch.zeros(1, block_size, num_kv_heads, head_dim, ...)
static_v = torch.zeros(1, block_size, num_kv_heads, head_dim, ...)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
output, lse = flash_attn_with_lse(static_q, static_k, static_v,
self.scale, causal)
return AttentionGraph(graph, static_q, static_k, static_v, output, lse)
```
**状态**: `pending`
---
### Phase 4: 修改 Offload Decode 执行流程
**目标**: 使用 graph replay 执行 offload decode
**修改文件**: `nanovllm/engine/model_runner.py`
**修改方法**: `run_chunked_offload_decode()`
```python
def run_chunked_offload_decode_with_graph(self, seqs):
"""使用 graph 加速的 offload decode"""
seq = seqs[0]
# 准备输入
input_ids = torch.tensor([seq.last_token], ...)
positions = torch.tensor([len(seq) - 1], ...)
# Embedding
hidden_states = self.model.model.embed_tokens(input_ids)
residual = None
for layer_id, layer in enumerate(self.model.model.layers):
# Phase 1: Pre-attention (GRAPH)
self.pre_attn_vars["hidden"].copy_(hidden_states)
self.pre_attn_vars["residual"].copy_(residual) if residual else None
self.pre_attn_vars["positions"].copy_(positions)
self.pre_attn_graph.replay()
q = self.pre_attn_vars["q"].clone()
k = self.pre_attn_vars["k"].clone()
v = self.pre_attn_vars["v"].clone()
# Phase 2: Chunked Attention (Eager + Graph)
attn_output = self._chunked_attention_with_graph(q, k, v, layer_id, ...)
# Phase 3: Post-attention (GRAPH)
self.post_attn_vars["attn_output"].copy_(attn_output)
self.post_attn_graph.replay()
hidden_states = self.post_attn_vars["hidden"].clone()
residual = self.post_attn_vars["residual"].clone()
# LM head
logits = self.model.compute_logits(hidden_states)
return logits
```
**状态**: `pending`
---
### Phase 5: 修改 Ring Buffer Pipeline
**目标**: 在 attention 内部使用 graph
**修改文件**: `nanovllm/kvcache/sparse/full_policy.py`
**修改**: `_decode_ring_buffer_pipeline()` 中的 `flash_attn_with_lse` 调用
```python
# 当前eager
prev_o, prev_lse = flash_attn_with_lse(q, k, v, scale, causal=False)
# 修改为graph replay
graph = offload_engine.attn_graph_non_causal
graph.static_q.copy_(q)
graph.static_k.copy_(k)
graph.static_v.copy_(v)
graph.graph.replay()
prev_o = graph.static_output.clone()
prev_lse = graph.static_lse.clone()
```
**状态**: `pending`
---
### Phase 6: 添加配置开关
**修改文件**: `nanovllm/config.py`
```python
enable_offload_graph: bool = True # 默认启用
```
**状态**: `pending`
---
## 文件修改清单
| 文件 | 修改类型 | 说明 |
|------|----------|------|
| `nanovllm/engine/model_runner.py` | 新增方法 | `capture_offload_attention_graph()` |
| `nanovllm/kvcache/offload_engine.py` | 新增属性+方法 | Graph 存储和访问 |
| `nanovllm/kvcache/sparse/full_policy.py` | 修改方法 | 使用 graph replay |
| `nanovllm/config.py` | 新增配置 | `enable_offload_graph` |
---
## 风险和注意事项
1. **Graph 捕获时机**: 需要在 KV cache 分配后、第一次 decode 前捕获
2. **Chunk size 匹配**: Graph 的 chunk_size 必须和 block_size 一致
3. **多 GPU**: Graph 需要在每个 GPU 上分别捕获
4. **内存**: 2 个 graph 的额外内存开销很小
---
## 测试计划
1. **单元测试**: 验证 graph replay 结果正确
2. **集成测试**: 运行 `test_needle.py --enable-offload --input-len 32768`
3. **性能测试**: 对比 eager vs graph 的 decode 延迟
---
## 预期收益
- Decode 阶段 attention 计算加速(减少 kernel launch overhead
- 与现有 ring buffer pipeline 兼容
- 内存开销极小(只有 2 个额外 graph

View File

@@ -1,757 +0,0 @@
"""
Custom Qwen3 implementation using only torch and transformers.
This file provides a clean reference implementation for understanding the model computation graph.
Computation Graph:
==================
Input: token_ids [batch, seq_len]
┌─────────────┐
│ Embedding │ embed_tokens: [vocab_size, hidden_size]
└─────────────┘
hidden_states [batch, seq_len, hidden_size]
┌─────────────────────────────────────────────────────────┐
│ Decoder Layer (x N) │
│ ┌───────────────────────────────────────────────────┐ │
│ │ Self Attention Block │ │
│ │ │ │
│ │ input_layernorm (RMSNorm) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ Qwen3Attention │ │ │
│ │ │ Q = q_proj(x) → q_norm → reshape │ │ │
│ │ │ K = k_proj(x) → k_norm → reshape │ │ │
│ │ │ V = v_proj(x) → reshape │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ Q, K = apply_rotary_pos_emb(Q, K, cos, sin)│ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ attn_output = attention(Q, K, V) │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ output = o_proj(attn_output) │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ hidden_states = residual + attn_output │ │
│ └───────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────┐ │
│ │ MLP Block │ │
│ │ │ │
│ │ post_attention_layernorm (RMSNorm) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ Qwen3MLP │ │ │
│ │ │ gate = gate_proj(x) │ │ │
│ │ │ up = up_proj(x) │ │ │
│ │ │ output = down_proj(silu(gate) * up) │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ hidden_states = residual + mlp_output │ │
│ └───────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘
┌─────────────┐
│ norm │ final RMSNorm
└─────────────┘
┌─────────────┐
│ lm_head │ [hidden_size, vocab_size]
└─────────────┘
logits [batch, seq_len, vocab_size]
"""
import math
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
class Qwen3RMSNorm(nn.Module):
"""RMSNorm implementation."""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
input_dtype = x.dtype
x = x.float()
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
return self.weight * x.to(input_dtype)
class Qwen3RotaryEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE)."""
def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Compute inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Input tensor [batch, seq_len, num_heads, head_dim] or similar
position_ids: Position indices [batch, seq_len]
Returns:
cos, sin: [batch, seq_len, head_dim]
"""
# inv_freq: [dim/2]
# position_ids: [batch, seq_len]
inv_freq_expanded = self.inv_freq[None, :, None].float() # [1, dim/2, 1]
position_ids_expanded = position_ids[:, None, :].float() # [batch, 1, seq_len]
# freqs: [batch, dim/2, seq_len]
freqs = inv_freq_expanded @ position_ids_expanded
# freqs: [batch, seq_len, dim/2]
freqs = freqs.transpose(1, 2)
# Duplicate for full head_dim: [batch, seq_len, dim]
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(x.dtype)
sin = emb.sin().to(x.dtype)
return cos, sin
def rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotate half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary position embeddings to Q and K.
Args:
q: [batch, num_heads, seq_len, head_dim]
k: [batch, num_kv_heads, seq_len, head_dim]
cos: [batch, seq_len, head_dim]
sin: [batch, seq_len, head_dim]
Returns:
q_embed, k_embed with same shapes as inputs
"""
# Unsqueeze for broadcasting: [batch, 1, seq_len, head_dim]
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class Qwen3Attention(nn.Module):
"""
Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support.
Data Flow:
---------
hidden_states [batch, seq_len, hidden_size]
├──► q_proj ──► q_norm ──► reshape ──► Q [batch, num_heads, seq_len, head_dim]
├──► k_proj ──► k_norm ──► reshape ──► K [batch, num_kv_heads, seq_len, head_dim]
└──► v_proj ──► reshape ──► V [batch, num_kv_heads, seq_len, head_dim]
apply_rotary_pos_emb(Q, K)
attention(Q, K, V) ──► attn_output [batch, num_heads, seq_len, head_dim]
reshape ──► o_proj ──► output [batch, seq_len, hidden_size]
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
attention_bias: bool = False,
rms_norm_eps: float = 1e-6,
layer_idx: int = 0,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_attention_heads
self.num_kv_heads = num_key_value_heads
self.head_dim = head_dim
self.num_kv_groups = num_attention_heads // num_key_value_heads
self.layer_idx = layer_idx
# Scaling factor
self.scaling = head_dim ** -0.5
# QKV projections
self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)
# QK normalization (Qwen3 specific)
self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
# Rotary embeddings
self.rotary_emb = Qwen3RotaryEmbedding(
head_dim,
max_position_embeddings=max_position_embeddings,
base=rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_qkv: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
"""
Args:
hidden_states: [batch, seq_len, hidden_size]
position_ids: [batch, seq_len]
attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask)
past_key_value: (k_cache, v_cache) from previous steps
use_cache: Whether to return updated cache
output_qkv: Whether to output Q, K, V tensors for debugging
Returns:
output: [batch, seq_len, hidden_size]
past_key_value: Updated cache (if use_cache=True)
qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True)
"""
batch_size, seq_len, _ = hidden_states.shape
# === QKV Projections ===
q = self.q_proj(hidden_states) # [batch, seq_len, num_heads * head_dim]
k = self.k_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
v = self.v_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
# Reshape to [batch, seq_len, num_heads, head_dim]
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
# === QK Normalization (Qwen3 specific) ===
q = self.q_norm(q)
k = self.k_norm(k)
# Transpose to [batch, num_heads, seq_len, head_dim]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# === Rotary Position Embeddings ===
cos, sin = self.rotary_emb(v, position_ids)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
# === KV Cache Update ===
if past_key_value is not None:
k_cache, v_cache = past_key_value
k = torch.cat([k_cache, k], dim=2)
v = torch.cat([v_cache, v], dim=2)
new_past_key_value = (k, v) if use_cache else None
# === Grouped Query Attention (expand KV heads if needed) ===
if self.num_kv_groups > 1:
# Repeat KV for each query group
k = k.repeat_interleave(self.num_kv_groups, dim=1)
v = v.repeat_interleave(self.num_kv_groups, dim=1)
# === Attention Computation (using SDPA for memory efficiency) ===
# Use PyTorch's scaled_dot_product_attention which can use FlashAttention backend
# is_causal only works when q_len == kv_len (prefill), not during decode
q_len, kv_len = q.shape[2], k.shape[2]
is_causal = (q_len == kv_len) and (q_len > 1)
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=0.0,
is_causal=is_causal,
scale=self.scaling,
) # [batch, num_heads, seq_len, head_dim]
# === Output Projection ===
# Transpose back and reshape
attn_output = attn_output.transpose(1, 2).contiguous() # [batch, seq_len, num_heads, head_dim]
attn_output = attn_output.view(batch_size, seq_len, -1) # [batch, seq_len, hidden_size]
output = self.o_proj(attn_output)
# Optional QKV output for debugging
qkv_dict = None
if output_qkv:
qkv_dict = {
"q": q, # [batch, num_heads, seq_len, head_dim] (post-RoPE)
"k": k, # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded)
"v": v, # [batch, num_heads, kv_seq_len, head_dim] (expanded)
}
return output, new_past_key_value, qkv_dict
class Qwen3MLP(nn.Module):
"""
Qwen3 MLP with SwiGLU activation.
Data Flow:
---------
hidden_states [batch, seq_len, hidden_size]
├──► gate_proj ──► gate [batch, seq_len, intermediate_size]
└──► up_proj ──► up [batch, seq_len, intermediate_size]
silu(gate) * up
down_proj ──► output [batch, seq_len, hidden_size]
"""
def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = self.gate_proj(x)
up = self.up_proj(x)
return self.down_proj(F.silu(gate) * up)
class Qwen3DecoderLayer(nn.Module):
"""Single Qwen3 Decoder Layer."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
layer_idx: int = 0,
):
super().__init__()
self.layer_idx = layer_idx
# Pre-attention LayerNorm
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# Self-attention
self.self_attn = Qwen3Attention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
attention_bias=attention_bias,
rms_norm_eps=rms_norm_eps,
layer_idx=layer_idx,
)
# Post-attention LayerNorm
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# MLP
self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_qkv: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
"""
Args:
hidden_states: [batch, seq_len, hidden_size]
position_ids: [batch, seq_len]
attention_mask: Causal attention mask
past_key_value: KV cache for this layer
use_cache: Whether to return updated cache
output_qkv: Whether to output Q, K, V for debugging
Returns:
hidden_states: [batch, seq_len, hidden_size]
past_key_value: Updated cache
qkv_dict: QKV tensors (if output_qkv=True)
"""
# === Self Attention Block ===
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_output, new_past_key_value, qkv_dict = self.self_attn(
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_qkv=output_qkv,
)
hidden_states = residual + attn_output
# === MLP Block ===
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, new_past_key_value, qkv_dict
class Qwen3Model(nn.Module):
"""Qwen3 Transformer Model (without LM head)."""
def __init__(
self,
vocab_size: int,
hidden_size: int,
intermediate_size: int,
num_hidden_layers: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
):
super().__init__()
self.vocab_size = vocab_size
self.num_hidden_layers = num_hidden_layers
# Token embeddings
self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
# Decoder layers
self.layers = nn.ModuleList([
Qwen3DecoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
mlp_bias=mlp_bias,
layer_idx=i,
)
for i in range(num_hidden_layers)
])
# Final LayerNorm
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
output_qkv_layers: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
"""
Args:
input_ids: [batch, seq_len]
position_ids: [batch, seq_len]
attention_mask: [batch, seq_len] or pre-computed 4D mask
past_key_values: List of (k, v) tuples for each layer
use_cache: Whether to return new cache
output_qkv_layers: List of layer indices to output QKV for
Returns:
hidden_states: [batch, seq_len, hidden_size]
new_past_key_values: Updated cache
qkv_outputs: {layer_idx: qkv_dict}
"""
batch_size, seq_len = input_ids.shape
# Embedding
hidden_states = self.embed_tokens(input_ids)
# Position IDs
if position_ids is None:
past_len = past_key_values[0][0].shape[2] if past_key_values else 0
position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Attention mask (create causal mask if not provided)
if attention_mask is None or attention_mask.dim() == 2:
kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0)
causal_mask = torch.triu(
torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device),
diagonal=kv_seq_len - seq_len + 1,
)
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, kv_seq_len]
# Initialize cache list
new_past_key_values = [] if use_cache else None
qkv_outputs = {} if output_qkv_layers else None
# Decoder layers
for i, layer in enumerate(self.layers):
past_kv = past_key_values[i] if past_key_values else None
output_qkv = output_qkv_layers is not None and i in output_qkv_layers
hidden_states, new_kv, qkv_dict = layer(
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_value=past_kv,
use_cache=use_cache,
output_qkv=output_qkv,
)
if use_cache:
new_past_key_values.append(new_kv)
if qkv_dict is not None:
qkv_outputs[i] = qkv_dict
# Final norm
hidden_states = self.norm(hidden_states)
return hidden_states, new_past_key_values, qkv_outputs
class Qwen3ForCausalLM(nn.Module):
"""Qwen3 Model with Language Modeling head."""
def __init__(
self,
vocab_size: int,
hidden_size: int,
intermediate_size: int,
num_hidden_layers: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
tie_word_embeddings: bool = True,
):
super().__init__()
self.vocab_size = vocab_size
self.tie_word_embeddings = tie_word_embeddings
# Transformer model
self.model = Qwen3Model(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
mlp_bias=mlp_bias,
)
# LM head
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
output_qkv_layers: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
"""
Args:
input_ids: [batch, seq_len]
... (same as Qwen3Model)
Returns:
logits: [batch, seq_len, vocab_size]
past_key_values: Updated KV cache
qkv_outputs: QKV tensors for specified layers
"""
hidden_states, new_past_key_values, qkv_outputs = self.model(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_qkv_layers=output_qkv_layers,
)
logits = self.lm_head(hidden_states)
return logits, new_past_key_values, qkv_outputs
@classmethod
def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM":
"""
Load weights from a pretrained Qwen3 model.
Args:
model_path: Path to model directory containing config.json and model weights
dtype: Data type for model weights
Returns:
Initialized Qwen3ForCausalLM model
"""
import json
import os
from safetensors.torch import load_file
# Load config
config_path = os.path.join(model_path, "config.json")
with open(config_path) as f:
config = json.load(f)
# Create model
model = cls(
vocab_size=config["vocab_size"],
hidden_size=config["hidden_size"],
intermediate_size=config["intermediate_size"],
num_hidden_layers=config["num_hidden_layers"],
num_attention_heads=config["num_attention_heads"],
num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]),
max_position_embeddings=config.get("max_position_embeddings", 32768),
rope_theta=config.get("rope_theta", 10000.0),
rms_norm_eps=config.get("rms_norm_eps", 1e-6),
attention_bias=config.get("attention_bias", False),
mlp_bias=config.get("mlp_bias", False),
tie_word_embeddings=config.get("tie_word_embeddings", True),
)
# Load weights
weight_files = sorted([
f for f in os.listdir(model_path)
if f.endswith(".safetensors")
])
state_dict = {}
for wf in weight_files:
state_dict.update(load_file(os.path.join(model_path, wf)))
# Load into model
model.load_state_dict(state_dict, strict=False)
# Tie lm_head weights to embed_tokens if configured
if model.tie_word_embeddings:
model.lm_head.weight = model.model.embed_tokens.weight
model = model.to(dtype)
return model
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 32,
temperature: float = 1.0,
do_sample: bool = True,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> torch.Tensor:
"""Simple autoregressive generation."""
device = input_ids.device
batch_size, seq_len = input_ids.shape
past_key_values = None
generated = input_ids.clone()
for _ in range(max_new_tokens):
if past_key_values is None:
current_input = generated
else:
current_input = generated[:, -1:]
logits, past_key_values, _ = self(
input_ids=current_input,
past_key_values=past_key_values,
use_cache=True,
)
next_token_logits = logits[:, -1, :]
if temperature > 0 and do_sample:
next_token_logits = next_token_logits / temperature
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
generated = torch.cat([generated, next_token], dim=1)
if eos_token_id is not None and (next_token == eos_token_id).all():
break
return generated
def print_computation_graph():
"""Print the computation graph for reference."""
print(__doc__)
if __name__ == "__main__":
print_computation_graph()

View File

@@ -1,151 +0,0 @@
#!/usr/bin/env python3
"""
Test: Pre-allocated chunk pair graphs for block sparse attention.
Each (Q_chunk, K_chunk) pair has its own captured CUDA graph.
Zero copy_() during replay - all data pre-filled.
Usage:
CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph.py
"""
from dataclasses import dataclass
from typing import List, Optional
import torch
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@dataclass
class ChunkAttentionGraph:
"""Container for a captured chunk attention graph."""
graph: torch.cuda.CUDAGraph
static_q: torch.Tensor
static_k: torch.Tensor
static_v: torch.Tensor
static_output: torch.Tensor
static_lse: torch.Tensor
causal: bool
def capture_chunk_attention_graph(
chunk_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
scale: float,
device: torch.device,
dtype: torch.dtype,
causal: bool = False,
) -> ChunkAttentionGraph:
"""Capture a CUDA graph for single chunk attention."""
static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device)
static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
static_q.normal_()
static_k.normal_()
static_v.normal_()
# Warmup
with torch.inference_mode():
for _ in range(3):
_ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
torch.cuda.synchronize()
# Capture
graph = torch.cuda.CUDAGraph()
with torch.inference_mode():
with torch.cuda.graph(graph):
static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
torch.cuda.synchronize()
return ChunkAttentionGraph(
graph=graph,
static_q=static_q,
static_k=static_k,
static_v=static_v,
static_output=static_output,
static_lse=static_lse,
causal=causal,
)
def main():
device = torch.device("cuda")
dtype = torch.bfloat16
chunk_size = 64
num_chunks = 4
num_heads = 8
num_kv_heads = 8
head_dim = 64
scale = 1.0 / (head_dim ** 0.5)
seq_len = chunk_size * num_chunks
print(f"Device: {torch.cuda.get_device_name()}")
print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}")
print(f"Total graphs: {num_chunks * (num_chunks + 1) // 2}")
# Test data
full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device)
full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
# Reference
with torch.inference_mode():
full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True)
# Capture all graphs
graphs: List[List[Optional[ChunkAttentionGraph]]] = [[None] * num_chunks for _ in range(num_chunks)]
for q_idx in range(num_chunks):
for k_idx in range(q_idx + 1):
graphs[q_idx][k_idx] = capture_chunk_attention_graph(
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype,
causal=(k_idx == q_idx)
)
print("All graphs captured")
# Pre-fill static tensors
for q_idx in range(num_chunks):
for k_idx in range(q_idx + 1):
g = graphs[q_idx][k_idx]
g.static_q.copy_(full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size])
g.static_k.copy_(full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size])
g.static_v.copy_(full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size])
print("Static tensors pre-filled")
# Replay and merge
chunked_output = torch.zeros_like(full_output)
for q_idx in range(num_chunks):
acc_out, acc_lse = None, None
for k_idx in range(q_idx + 1):
g = graphs[q_idx][k_idx]
g.graph.replay()
out, lse = g.static_output.clone(), g.static_lse.clone()
if acc_out is None:
acc_out, acc_lse = out, lse
else:
with torch.inference_mode():
acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse)
chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out
torch.cuda.synchronize()
# Compare
all_pass = True
for q_idx in range(num_chunks):
s, e = q_idx * chunk_size, (q_idx + 1) * chunk_size
diff = (full_output[:, s:e] - chunked_output[:, s:e]).abs().max().item()
status = "" if diff < 1e-2 else ""
print(f"Q[{q_idx}]: max_diff={diff:.2e} {status}")
if diff >= 1e-2:
all_pass = False
print("✅ PASSED" if all_pass else "❌ FAILED")
if __name__ == "__main__":
main()

View File

@@ -1,156 +0,0 @@
#!/usr/bin/env python3
"""
Test: Reuse a single CUDA Graph across all layers and all chunk pairs.
Key insight: LLM layers have identical computation structure.
We only need 2 graphs (causal + non-causal), reused for all (layer, Q_i, K_j) combinations.
Usage:
CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph_reuse.py
"""
from dataclasses import dataclass
import torch
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@dataclass
class ReusableChunkGraph:
"""A single graph that can be reused with copy_() updates."""
graph: torch.cuda.CUDAGraph
static_q: torch.Tensor
static_k: torch.Tensor
static_v: torch.Tensor
static_output: torch.Tensor
static_lse: torch.Tensor
def capture_reusable_graph(
chunk_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
scale: float,
device: torch.device,
dtype: torch.dtype,
causal: bool,
) -> ReusableChunkGraph:
"""Capture ONE graph to be reused for all chunk pairs."""
static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device)
static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
static_q.normal_()
static_k.normal_()
static_v.normal_()
# Warmup
with torch.inference_mode():
for _ in range(3):
_ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
torch.cuda.synchronize()
# Capture
graph = torch.cuda.CUDAGraph()
with torch.inference_mode():
with torch.cuda.graph(graph):
static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
torch.cuda.synchronize()
return ReusableChunkGraph(
graph=graph,
static_q=static_q,
static_k=static_k,
static_v=static_v,
static_output=static_output,
static_lse=static_lse,
)
def replay_with_copy(graph: ReusableChunkGraph, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Replay graph after updating static tensors with copy_()."""
graph.static_q.copy_(q)
graph.static_k.copy_(k)
graph.static_v.copy_(v)
graph.graph.replay()
return graph.static_output.clone(), graph.static_lse.clone()
def main():
device = torch.device("cuda")
dtype = torch.bfloat16
chunk_size = 64
num_chunks = 4
num_layers = 3 # Simulate multiple layers
num_heads = 8
num_kv_heads = 8
head_dim = 64
scale = 1.0 / (head_dim ** 0.5)
seq_len = chunk_size * num_chunks
print(f"Device: {torch.cuda.get_device_name()}")
print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}, Num layers: {num_layers}")
print(f"Only 2 graphs (causal + non-causal) for ALL layer × chunk combinations")
# Capture only 2 graphs
graph_causal = capture_reusable_graph(
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=True
)
graph_non_causal = capture_reusable_graph(
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=False
)
print("2 graphs captured (causal + non-causal)")
all_pass = True
for layer_id in range(num_layers):
# Different Q/K/V for each layer (simulating different layer outputs)
full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device)
full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
# Reference: full causal attention
with torch.inference_mode():
full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True)
# Chunked with graph reuse
chunked_output = torch.zeros_like(full_output)
for q_idx in range(num_chunks):
q_chunk = full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size]
acc_out, acc_lse = None, None
for k_idx in range(q_idx + 1):
k_chunk = full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
v_chunk = full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
# Reuse graph with copy_()
graph = graph_causal if k_idx == q_idx else graph_non_causal
out, lse = replay_with_copy(graph, q_chunk, k_chunk, v_chunk)
if acc_out is None:
acc_out, acc_lse = out, lse
else:
with torch.inference_mode():
acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse)
chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out
torch.cuda.synchronize()
# Compare
max_diff = (full_output - chunked_output).abs().max().item()
status = "" if max_diff < 1e-2 else ""
print(f"Layer {layer_id}: max_diff={max_diff:.2e} {status}")
if max_diff >= 1e-2:
all_pass = False
print("✅ PASSED - Single graph reuse across layers works!" if all_pass else "❌ FAILED")
if __name__ == "__main__":
main()

View File

@@ -1,357 +0,0 @@
#!/usr/bin/env python3
"""
CUDA Graph Memory Analysis Test
This script analyzes the memory overhead of CUDA Graph at each stage:
1. Model loading
2. StaticCache allocation
3. Warmup runs
4. Graph capture
5. Graph replay
Usage:
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --model ~/models/Qwen3-0.6B
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --max-cache-len 2048
"""
import argparse
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache
def get_memory_mb():
"""Get current allocated memory in MB."""
return torch.cuda.memory_allocated() / 1024**2
def get_memory_gb():
"""Get current allocated memory in GB."""
return torch.cuda.memory_allocated() / 1024**3
def get_peak_memory_gb():
"""Get peak allocated memory in GB."""
return torch.cuda.max_memory_allocated() / 1024**3
def print_separator(title=None):
"""Print a separator line."""
if title:
print(f"\n{'=' * 70}")
print(f" {title}")
print(f"{'=' * 70}")
else:
print("-" * 70)
def test_memory_stages(model_path: str, max_cache_len: int, batch_size: int = 1):
"""
Test memory usage at each stage of CUDA Graph setup.
Args:
model_path: Path to the model
max_cache_len: Maximum cache length for StaticCache
batch_size: Batch size for inference
"""
print_separator("CUDA Graph Memory Analysis")
print(f"Model: {model_path}")
print(f"Max cache length: {max_cache_len}")
print(f"Batch size: {batch_size}")
results = {}
# Stage 0: Initial
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
results["initial"] = get_memory_mb()
# Stage 1: Load model
print_separator("Stage 1: Model Loading")
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="cuda",
trust_remote_code=True,
)
model.eval()
results["after_model"] = get_memory_mb()
model_size = results["after_model"] - results["initial"]
print(f" Memory: {results['after_model']:.0f} MB")
print(f" Model size: {model_size:.0f} MB ({model_size/1024:.2f} GB)")
config = model.config
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
# Stage 2: Allocate StaticCache
print_separator("Stage 2: StaticCache Allocation")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
static_cache = StaticCache(
config=config,
max_batch_size=batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=dtype,
)
results["after_cache"] = get_memory_mb()
cache_size = results["after_cache"] - before
print(f" Memory: {results['after_cache']:.0f} MB")
print(f" StaticCache size: {cache_size:.0f} MB")
# Calculate theoretical cache size
num_layers = config.num_hidden_layers
num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
head_dim = config.hidden_size // config.num_attention_heads
dtype_size = 2 # bfloat16
theoretical_cache = (
num_layers * 2 * batch_size * num_kv_heads * max_cache_len * head_dim * dtype_size
) / (1024**2)
print(f" Theoretical: {theoretical_cache:.0f} MB")
print(f" Overhead: {cache_size - theoretical_cache:.0f} MB ({(cache_size/theoretical_cache - 1)*100:.1f}%)")
# Stage 3: Prepare static tensors
print_separator("Stage 3: Static Tensor Allocation")
before = get_memory_mb()
static_input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
static_position_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
static_cache_position = torch.tensor([0], dtype=torch.long, device=device)
results["after_tensors"] = get_memory_mb()
tensor_size = results["after_tensors"] - before
print(f" Memory: {results['after_tensors']:.0f} MB")
print(f" Static tensors: {tensor_size:.2f} MB (negligible)")
# Stage 4: Warmup runs
print_separator("Stage 4: Warmup Runs (3 iterations)")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
with torch.inference_mode():
for i in range(3):
_ = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
torch.cuda.synchronize()
results["after_warmup"] = get_memory_mb()
results["warmup_peak"] = get_peak_memory_gb() * 1024
warmup_size = results["after_warmup"] - before
print(f" Memory: {results['after_warmup']:.0f} MB")
print(f" Peak: {results['warmup_peak']:.0f} MB")
print(f" Warmup overhead: {warmup_size:.0f} MB")
# Stage 5: CUDA Graph capture
print_separator("Stage 5: CUDA Graph Capture")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
graph = torch.cuda.CUDAGraph()
with torch.inference_mode():
with torch.cuda.graph(graph):
outputs = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
static_logits = outputs.logits
torch.cuda.synchronize()
results["after_capture"] = get_memory_mb()
results["capture_peak"] = get_peak_memory_gb() * 1024
capture_size = results["after_capture"] - before
print(f" Memory: {results['after_capture']:.0f} MB")
print(f" Peak: {results['capture_peak']:.0f} MB")
print(f" Graph capture overhead: {capture_size:.0f} MB")
# Stage 6: Graph replay
print_separator("Stage 6: Graph Replay (10 iterations)")
torch.cuda.reset_peak_memory_stats()
before = get_memory_mb()
with torch.inference_mode():
for _ in range(10):
static_input_ids.fill_(1)
static_cache_position.fill_(0)
graph.replay()
torch.cuda.synchronize()
results["after_replay"] = get_memory_mb()
results["replay_peak"] = get_peak_memory_gb() * 1024
replay_change = results["after_replay"] - before
print(f" Memory: {results['after_replay']:.0f} MB")
print(f" Peak: {results['replay_peak']:.0f} MB")
print(f" Replay memory change: {replay_change:.0f} MB (should be ~0)")
# Summary
print_separator("SUMMARY")
total_overhead = results["after_capture"] - results["after_model"]
print(f"{'Stage':<25} {'Memory (MB)':>12} {'Delta (MB)':>12}")
print("-" * 50)
print(f"{'Model loaded':<25} {results['after_model']:>12.0f} {model_size:>+12.0f}")
print(f"{'StaticCache allocated':<25} {results['after_cache']:>12.0f} {cache_size:>+12.0f}")
print(f"{'After warmup':<25} {results['after_warmup']:>12.0f} {warmup_size:>+12.0f}")
print(f"{'After graph capture':<25} {results['after_capture']:>12.0f} {capture_size:>+12.0f}")
print(f"{'After graph replay':<25} {results['after_replay']:>12.0f} {replay_change:>+12.0f}")
print("-" * 50)
print(f"{'Total (excl. model)':<25} {'':<12} {total_overhead:>+12.0f}")
print_separator("KEY FINDINGS")
print(f" 1. Model size: {model_size/1024:.2f} GB")
print(f" 2. StaticCache: {cache_size:.0f} MB (main overhead, scales with cache_len)")
print(f" 3. Graph capture: {capture_size:.0f} MB (small, stores kernel sequence)")
print(f" 4. Graph replay: {replay_change:.0f} MB (zero allocation, reuses memory)")
print(f" 5. Total CUDA Graph overhead: {total_overhead:.0f} MB")
return results
def test_cache_length_scaling(model_path: str, cache_lengths: list):
"""
Test how memory scales with different cache lengths.
Args:
model_path: Path to the model
cache_lengths: List of cache lengths to test
"""
print_separator("Cache Length Scaling Test")
print(f"Model: {model_path}")
print(f"Cache lengths: {cache_lengths}")
# Load model once
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="cuda",
trust_remote_code=True,
)
model.eval()
config = model.config
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
model_mem = get_memory_mb()
results = []
for cache_len in cache_lengths:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Create cache and capture graph
static_cache = StaticCache(
config=config,
max_batch_size=1,
max_cache_len=cache_len,
device=device,
dtype=dtype,
)
static_input_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
static_position_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
static_cache_position = torch.tensor([0], dtype=torch.long, device=device)
with torch.inference_mode():
# Warmup
for _ in range(3):
_ = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
torch.cuda.synchronize()
# Capture
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
outputs = model(
input_ids=static_input_ids,
position_ids=static_position_ids,
past_key_values=static_cache,
cache_position=static_cache_position,
use_cache=True,
)
torch.cuda.synchronize()
total_mem = get_memory_mb()
overhead = total_mem - model_mem
results.append((cache_len, total_mem, overhead))
del static_cache, graph
torch.cuda.empty_cache()
# Print results
print()
print(f"{'Cache Length':>12} | {'Total (MB)':>12} | {'Overhead (MB)':>14} | {'Per 1K tokens':>14}")
print("-" * 60)
for cache_len, total, overhead in results:
per_1k = overhead / (cache_len / 1000)
print(f"{cache_len:>12} | {total:>12.0f} | {overhead:>14.0f} | {per_1k:>14.1f}")
return results
def main():
parser = argparse.ArgumentParser(description="CUDA Graph Memory Analysis")
parser.add_argument(
"--model",
type=str,
default="~/models/Qwen3-4B-Instruct-2507",
help="Model path",
)
parser.add_argument(
"--max-cache-len",
type=int,
default=1024,
help="Maximum cache length",
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="Batch size",
)
parser.add_argument(
"--test-scaling",
action="store_true",
help="Test cache length scaling",
)
args = parser.parse_args()
model_path = os.path.expanduser(args.model)
if not torch.cuda.is_available():
print("CUDA is not available!")
return
print(f"Device: cuda:{torch.cuda.current_device()}")
print(f"GPU: {torch.cuda.get_device_name()}")
if args.test_scaling:
cache_lengths = [256, 512, 1024, 2048, 4096]
test_cache_length_scaling(model_path, cache_lengths)
else:
test_memory_stages(model_path, args.max_cache_len, args.batch_size)
print("\ntest_cudagraph_memory: PASSED")
if __name__ == "__main__":
main()

View File

@@ -1,254 +0,0 @@
"""
Needle-in-a-haystack test for LLM.
Tests: Long context retrieval capability with configurable sequence length.
NOTE: CPU offload mode has a known bug that causes incorrect outputs for
sequences longer than ~200 tokens. Use --no-offload for correctness testing.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import argparse
from nanovllm import LLM, SamplingParams
from nanovllm.config import SparsePolicyType
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
# Main Test
# ============================================================
def run_needle_test(
model_path: str,
max_model_len: int,
input_len: int,
num_gpu_blocks: int = 4,
block_size: int = 1024,
needle_position: float = 0.5,
needle_value: str = "7492",
max_new_tokens: int = 32,
enable_cpu_offload: bool = False,
enable_quest: bool = False,
enable_xattn_bsa: bool = False,
sparse_topk: int = 8,
sparse_threshold: int = 4,
sparse_samples: int = 128,
verbose: bool = True,
) -> bool:
"""
Run a needle-in-haystack test.
Args:
model_path: Path to model
max_model_len: Maximum model context length
input_len: Target input sequence length
num_gpu_blocks: Number of GPU blocks for offload
block_size: KV cache block size
needle_position: Where to place needle (0.0-1.0)
needle_value: The secret value to find
max_new_tokens: Maximum tokens to generate
enable_cpu_offload: Enable CPU offload mode
enable_quest: Enable Quest sparse attention (decode-only Top-K)
enable_xattn_bsa: Enable XAttention BSA sparse attention (prefill-only)
sparse_topk: Top-K blocks for Quest
sparse_threshold: Threshold for sparse selection (Quest/XAttention BSA)
sparse_samples: Samples per chunk for XAttention BSA estimation
verbose: Print detailed output
Returns:
True if test passed, False otherwise
"""
# Determine sparse policy
if enable_xattn_bsa:
sparse_policy = SparsePolicyType.XATTN_BSA
elif enable_quest:
sparse_policy = SparsePolicyType.QUEST
else:
sparse_policy = SparsePolicyType.FULL
if verbose:
print(f"\n{'='*60}")
print(f"Needle-in-Haystack Test")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Max model len: {max_model_len}")
print(f"Input length: {input_len}")
print(f"Block size: {block_size}")
print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}")
print(f"CPU offload: {enable_cpu_offload}")
if enable_cpu_offload:
print(f"Sparse policy: {sparse_policy.name}")
if sparse_policy == SparsePolicyType.QUEST:
print(f" Quest: topk={sparse_topk}, threshold={sparse_threshold}")
elif sparse_policy == SparsePolicyType.XATTN_BSA:
print(f" XAttention BSA: threshold={sparse_threshold}, samples={sparse_samples}")
print(f"{'='*60}\n")
# 1. Initialize LLM
llm_kwargs = {
"enforce_eager": True,
"max_model_len": max_model_len,
"max_num_batched_tokens": max_model_len,
"enable_cpu_offload": enable_cpu_offload,
"kvcache_block_size": block_size,
}
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm_kwargs["sparse_policy"] = sparse_policy
if sparse_policy == SparsePolicyType.QUEST:
llm_kwargs["sparse_topk_blocks"] = sparse_topk
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
elif sparse_policy == SparsePolicyType.XATTN_BSA:
llm_kwargs["sparse_threshold"] = float(sparse_threshold) / 10.0 # Convert to 0.0-1.0 range
llm_kwargs["sparse_samples_per_chunk"] = sparse_samples
llm = LLM(model_path, **llm_kwargs)
# 2. Generate needle prompt
prompt, expected = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=needle_position,
needle_value=needle_value,
)
# 3. Generate output
sampling_params = SamplingParams(
temperature=0.6, # Moderate temperature
max_tokens=max_new_tokens,
)
outputs = llm.generate([prompt], sampling_params, use_tqdm=True)
# 4. Check result
output_text = outputs[0]["text"]
output_token_ids = outputs[0]["token_ids"]
passed = check_needle_answer(output_text, expected)
if verbose:
print(f"\n{'='*60}")
print(f"Result")
print(f"{'='*60}")
print(f"Expected: {expected}")
print(f"Output tokens ({len(output_token_ids)}): {output_token_ids[:20]}")
print(f"Output: {output_text[:200]}...")
print(f"Status: {'PASSED' if passed else 'FAILED'}")
print(f"{'='*60}\n")
return passed
# ============================================================
# CLI Entry Point
# ============================================================
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Needle-in-haystack test for long context LLM")
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
help="Path to model"
)
parser.add_argument(
"--max-model-len",
type=int,
default=128 * 1024,
help="Maximum model context length"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--num-gpu-blocks",
type=int,
default=2,
help="Number of GPU blocks for CPU offload"
)
parser.add_argument(
"--block-size",
type=int,
default=1024,
help="KV cache block size"
)
parser.add_argument(
"--needle-position",
type=float,
default=0.5,
help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
)
parser.add_argument(
"--needle-value",
type=str,
default="7492",
help="The secret value to hide"
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=32,
help="Maximum tokens to generate"
)
parser.add_argument(
"--enable-offload",
action="store_true",
help="Enable CPU offload (has known bug for long sequences)"
)
parser.add_argument(
"--enable-quest",
action="store_true",
help="Enable Quest sparse attention (decode-only Top-K selection)"
)
parser.add_argument(
"--enable-xattn-bsa",
action="store_true",
help="Enable XAttention BSA sparse attention (prefill-only)"
)
parser.add_argument(
"--sparse-topk",
type=int,
default=8,
help="Top-K blocks for Quest sparse attention"
)
parser.add_argument(
"--sparse-threshold",
type=int,
default=4,
help="Apply sparse only when blocks > threshold (Quest) or attention threshold 0-9 (XAttention BSA)"
)
parser.add_argument(
"--sparse-samples",
type=int,
default=128,
help="Samples per chunk for XAttention BSA estimation"
)
args = parser.parse_args()
passed = run_needle_test(
model_path=args.model,
max_model_len=args.max_model_len,
input_len=args.input_len,
num_gpu_blocks=args.num_gpu_blocks,
block_size=args.block_size,
needle_position=args.needle_position,
needle_value=args.needle_value,
max_new_tokens=args.max_new_tokens,
enable_cpu_offload=args.enable_offload,
enable_quest=args.enable_quest,
enable_xattn_bsa=args.enable_xattn_bsa,
sparse_topk=args.sparse_topk,
sparse_threshold=args.sparse_threshold,
sparse_samples=args.sparse_samples,
verbose=True,
)
if passed:
print("test_needle: PASSED")
else:
print("test_needle: FAILED")
exit(1)

View File

@@ -1,176 +0,0 @@
"""
Needle-in-a-haystack reference test using pure torch + transformers.
This is a reference implementation for comparison with nanovllm.
Uses standard HuggingFace inference (no custom KV cache, no offload).
"""
import os
import argparse
import torch
from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
# Main Test
# ============================================================
def run_needle_test(
model_path: str,
input_len: int,
needle_position: float = 0.5,
needle_value: str = "7492",
max_new_tokens: int = 32,
dtype: str = "auto",
verbose: bool = True,
) -> bool:
"""
Run a needle-in-haystack test using standard transformers inference.
Args:
model_path: Path to model
input_len: Target input sequence length
needle_position: Where to place needle (0.0-1.0)
needle_value: The secret value to find
max_new_tokens: Maximum tokens to generate
dtype: Model dtype ("auto", "float16", "bfloat16")
verbose: Print detailed output
Returns:
True if test passed, False otherwise
"""
if verbose:
print(f"\n{'='*60}")
print(f"Needle-in-Haystack Reference Test (torch + transformers)")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Input length: {input_len}")
print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}")
print(f"Dtype: {dtype}")
print(f"{'='*60}\n")
# 1. Load tokenizer
print("[1/4] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# 2. Generate needle prompt
print("[2/4] Generating needle prompt...")
prompt, expected = generate_needle_prompt(
tokenizer=tokenizer,
target_length=input_len,
needle_position=needle_position,
needle_value=needle_value,
)
# 3. Load model
print("[3/4] Loading model...")
torch_dtype = {
"auto": torch.float16, # default to float16 for custom model
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}.get(dtype, torch.float16)
model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch_dtype)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
# 4. Generate output
print("[4/4] Running inference...")
device = next(model.parameters()).device
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
print(f" Input shape: {input_ids.shape}")
with torch.no_grad():
output_ids = model.generate(
input_ids,
max_new_tokens=max_new_tokens,
temperature=0.6,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
# Decode only the new tokens
new_token_ids = output_ids[0, input_ids.shape[1]:]
output_text = tokenizer.decode(new_token_ids, skip_special_tokens=False)
# 5. Check result
passed = check_needle_answer(output_text, expected)
if verbose:
print(f"\n{'='*60}")
print(f"Result")
print(f"{'='*60}")
print(f"Expected: {expected}")
print(f"Output tokens ({len(new_token_ids)}): {new_token_ids[:20].tolist()}")
print(f"Output: {output_text[:200]}...")
print(f"Status: {'PASSED' if passed else 'FAILED'}")
print(f"{'='*60}\n")
return passed
# ============================================================
# CLI Entry Point
# ============================================================
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Needle-in-haystack reference test (torch + transformers)"
)
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
help="Path to model"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--needle-position",
type=float,
default=0.5,
help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
)
parser.add_argument(
"--needle-value",
type=str,
default="7492",
help="The secret value to hide"
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=32,
help="Maximum tokens to generate"
)
parser.add_argument(
"--dtype",
type=str,
default="auto",
choices=["auto", "float16", "bfloat16"],
help="Model dtype"
)
args = parser.parse_args()
passed = run_needle_test(
model_path=args.model,
input_len=args.input_len,
needle_position=args.needle_position,
needle_value=args.needle_value,
max_new_tokens=args.max_new_tokens,
dtype=args.dtype,
verbose=True,
)
if passed:
print("test_needle_ref: PASSED")
else:
print("test_needle_ref: FAILED")
exit(1)

View File

@@ -1,136 +0,0 @@
"""
Test for QuestPolicy block selection with GQA (Grouped Query Attention).
Demonstrates the key limitation: scores are AVERAGED across heads,
so blocks strongly needed by one head but not others may be dropped.
This is the expected Quest behavior - not a bug.
"""
import torch
from nanovllm.kvcache.sparse import (
create_sparse_policy,
SparsePolicyType,
PolicyContext,
)
# ============================================================
# Test: Per-Head Score Averaging in GQA
# ============================================================
# Determine device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running test on device: {device}")
# Setup: 2 KV heads, 4 query heads (GQA group_size=2)
# topk=2 to make selection competitive
quest = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=2, threshold_blocks=0)
quest.initialize(
num_layers=1,
num_kv_heads=2,
head_dim=4,
num_cpu_blocks=6,
dtype=torch.float32,
device=device, # Metadata stored on GPU
)
metadata = quest.metadata
def set_key(block_id, head_id, values):
"""Set both key_min and key_max to same values for deterministic scoring."""
# Values need to be on the same device as metadata
tensor = torch.tensor(values, device=device)
metadata.key_min[block_id, 0, head_id, :] = tensor
metadata.key_max[block_id, 0, head_id, :] = tensor
# ============================================================
# Design: Different heads want different blocks
# ============================================================
#
# Query = [1,1,1,1] for all heads, so score = sum(key values)
#
# Block | Head 0 | Head 1 | Average | Result
# ------|--------|--------|---------|--------
# 0 | +4 | -4 | 0 | Head0 wants, Head1 doesn't → DROPPED
# 1 | -4 | +4 | 0 | Head1 wants, Head0 doesn't → DROPPED
# 2 | +4 | +4 | +4 | Both want → SELECTED (rank 1)
# 3 | +3 | +3 | +3 | Both want → SELECTED (rank 2)
# 4 | +4 | 0 | +2 | Head0 strongly wants, Head1 neutral → rank 3
# 5 | 0 | +4 | +2 | Head1 strongly wants, Head0 neutral → rank 3
# Block 0: Head 0 strongly wants, Head 1 strongly rejects
set_key(0, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
set_key(0, 1, [-1.0, -1.0, -1.0, -1.0]) # head1: -4
# Block 1: Head 1 strongly wants, Head 0 strongly rejects
set_key(1, 0, [-1.0, -1.0, -1.0, -1.0]) # head0: -4
set_key(1, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
# Block 2: Both heads want equally (highest average)
set_key(2, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
set_key(2, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
# Block 3: Both heads want moderately
set_key(3, 0, [0.75, 0.75, 0.75, 0.75]) # head0: +3
set_key(3, 1, [0.75, 0.75, 0.75, 0.75]) # head1: +3
# Block 4: Head 0 strongly wants, Head 1 neutral
set_key(4, 0, [1.0, 1.0, 1.0, 1.0]) # head0: +4
set_key(4, 1, [0.0, 0.0, 0.0, 0.0]) # head1: 0
# Block 5: Head 1 strongly wants, Head 0 neutral
set_key(5, 0, [0.0, 0.0, 0.0, 0.0]) # head0: 0
set_key(5, 1, [1.0, 1.0, 1.0, 1.0]) # head1: +4
# ============================================================
# Run selection
# ============================================================
# Query on same device as metadata
query = torch.ones(1, 4, 4, device=device) # GQA: 4 query heads → 2 KV heads
ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
layer_id=0,
query=query,
is_prefill=False,
block_size=1024,
total_kv_len=6144,
)
available = list(range(6))
selected = quest.select_blocks(available, ctx)
# ============================================================
# Verify: Averaging behavior
# ============================================================
# topk=2, so only blocks 2 (+4 avg) and 3 (+3 avg) should be selected
assert len(selected) == 2, f"Expected 2 blocks, got {len(selected)}"
assert selected == [2, 3], f"Expected [2, 3], got {selected}"
# Key insight: blocks 0 and 1 have score +4 for ONE head,
# but they cancel out due to averaging with the other head's -4
assert 0 not in selected, "Block 0 should NOT be selected (head scores cancel out)"
assert 1 not in selected, "Block 1 should NOT be selected (head scores cancel out)"
# Blocks 4 and 5 have +4 for one head, 0 for other → avg=+2
# But +2 < +3 (block 3), so they don't make the top-2
assert 4 not in selected, "Block 4 avg=+2 < block 3 avg=+3"
assert 5 not in selected, "Block 5 avg=+2 < block 3 avg=+3"
print("✓ Block 2 selected: both heads want it (+4, +4) → avg=+4")
print("✓ Block 3 selected: both heads want it (+3, +3) → avg=+3")
print("✓ Block 0 NOT selected: head0=+4, head1=-4 → avg=0 (cancel out)")
print("✓ Block 1 NOT selected: head0=-4, head1=+4 → avg=0 (cancel out)")
print("✓ Block 4 NOT selected: head0=+4, head1=0 → avg=+2 (lower rank)")
print("✓ Block 5 NOT selected: head0=0, head1=+4 → avg=+2 (lower rank)")
# Verify metadata is on correct device
assert metadata.key_min.device.type == device.type, f"key_min on wrong device: {metadata.key_min.device}"
assert metadata.key_max.device.type == device.type, f"key_max on wrong device: {metadata.key_max.device}"
print(f"✓ Metadata stored on {device.type.upper()}")
print("\ntest_quest_policy: PASSED")

View File

@@ -41,6 +41,7 @@ from pathlib import Path
from typing import List, Dict, Tuple, Optional
from nanovllm import LLM, SamplingParams
from nanovllm.utils.density_observer import DensityObserver
# ============================================================
@@ -48,11 +49,67 @@ from nanovllm import LLM, SamplingParams
# ============================================================
DEFAULT_DATA_DIR = Path(__file__).parent / "data/ruler_64k"
# ============================================================
# Chat Template Conversion
# ============================================================
def convert_llama_to_glm4_format(prompt: str) -> str:
"""
Convert Llama 3 chat template format to GLM-4 format.
Llama 3 format:
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{user_content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{assistant_prefix}
GLM-4 format:
[gMASK]<sop><|user|>
{user_content}<|assistant|>
{assistant_prefix}
"""
# Split into user content and assistant prefix
parts = prompt.split("<|eot_id|><|start_header_id|>assistant<|end_header_id|>")
# Extract user content (remove Llama header tokens)
user_content = parts[0]
user_content = user_content.replace("<|begin_of_text|>", "")
user_content = user_content.replace("<|start_header_id|>user<|end_header_id|>", "")
user_content = user_content.strip()
# Extract assistant prefix (if exists)
assistant_prefix = ""
if len(parts) > 1:
assistant_prefix = parts[1].replace("<|eot_id|>", "").strip()
# Apply GLM-4 format
glm_prompt = f"[gMASK]<sop><|user|>\n{user_content}<|assistant|>"
if assistant_prefix:
glm_prompt += f"\n{assistant_prefix}"
return glm_prompt
def is_glm_model(model_path: str) -> bool:
"""Check if the model is a GLM model based on config."""
from transformers import AutoConfig
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
return getattr(config, 'model_type', '') == 'chatglm'
def convert_prompt_for_model(prompt: str, model_path: str) -> str:
"""Convert prompt format based on model type."""
if is_glm_model(model_path):
return convert_llama_to_glm4_format(prompt)
return prompt # Keep original format for Llama and other models
DEFAULT_MODEL = os.path.expanduser("~/models/Llama-3.1-8B-Instruct")
# Note: max_model_len must be > max_input_len to leave room for output tokens
# 64k benchmark has inputs up to 65536 tokens, so we need 65536 + 128 = 65664
DEFAULT_MAX_MODEL_LEN = 65664
DEFAULT_MAX_NEW_TOKENS = 128 # Larger for multi-value tasks
DEFAULT_MAX_NEW_TOKENS = 16 # Sufficient for NIAH single-value answers
# Task categories for evaluation
NIAH_TASKS = ["niah_single_1", "niah_single_2", "niah_single_3",
@@ -161,6 +218,7 @@ def run_task_test(
verbose: bool = True,
llm_factory: Optional[callable] = None,
fresh_llm: bool = False,
model_path: Optional[str] = None,
) -> Dict:
"""
Run test for a single RULER task.
@@ -198,6 +256,9 @@ def run_task_test(
for sample in samples:
idx = sample.get("index", sample["_local_idx"])
prompt = sample["input"]
# Convert prompt format for GLM models
if model_path:
prompt = convert_prompt_for_model(prompt, model_path)
expected = sample["outputs"]
# Fresh LLM mode: reinitialize for each sample
@@ -263,7 +324,7 @@ def run_ruler_benchmark(
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
enable_cpu_offload: bool = False,
num_gpu_blocks: int = 4,
block_size: int = 1024,
block_size: int = 4096,
num_kv_buffers: int = 4,
gpu_utilization: float = 0.9,
enforce_eager: bool = True,
@@ -274,6 +335,8 @@ def run_ruler_benchmark(
sparse_threshold: float = 0.9,
sparse_samples: int = 128,
sparse_block_size: int = 128,
sparse_stride: int = 8,
dtype: Optional[str] = None,
) -> Dict:
"""
Run RULER benchmark on multiple tasks.
@@ -319,6 +382,16 @@ def run_ruler_benchmark(
print(f"Fresh LLM mode: {fresh_llm}")
print(f"{'='*60}")
# Enable DensityObserver for XAttention BSA
if sparse_policy and sparse_policy.upper() == "XATTN_BSA":
DensityObserver.enable()
DensityObserver.complete_reset()
# Set mode for correct density interpretation
DensityObserver.set_mode("offload" if enable_cpu_offload else "gpu_only")
if not json_output:
mode_str = "offload" if enable_cpu_offload else "gpu_only"
print(f"[DensityObserver] Enabled for XAttention BSA (mode: {mode_str})")
# LLM initialization kwargs
llm_kwargs = {
"max_model_len": max_model_len,
@@ -328,6 +401,8 @@ def run_ruler_benchmark(
"kvcache_block_size": block_size,
"enable_cpu_offload": enable_cpu_offload,
}
if dtype:
llm_kwargs["dtype"] = dtype
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm_kwargs["num_kv_buffers"] = num_kv_buffers
@@ -339,6 +414,7 @@ def run_ruler_benchmark(
if sparse_policy_type == SparsePolicyType.XATTN_BSA:
llm_kwargs["sparse_threshold"] = sparse_threshold
llm_kwargs["sparse_samples_per_chunk"] = sparse_samples
llm_kwargs["sparse_stride"] = sparse_stride
# Factory function for fresh_llm mode
def create_llm():
@@ -365,6 +441,7 @@ def run_ruler_benchmark(
verbose=verbose and not json_output,
llm_factory=create_llm,
fresh_llm=fresh_llm,
model_path=model_path,
)
task_results.append(result)
@@ -405,6 +482,14 @@ def run_ruler_benchmark(
print(f"{'-'*54}")
print(f"{'TOTAL':<20} {total_correct}/{total_samples:<7} {overall_accuracy*100:>6.1f}% {avg_score:.3f}")
print(f"\nTime: {total_time:.1f}s")
# Print DensityObserver summary if enabled
if sparse_policy and sparse_policy.upper() == "XATTN_BSA" and DensityObserver.is_enabled():
print(f"\n{'='*60}")
print("Density Statistics (XAttention BSA)")
print(f"{'='*60}")
DensityObserver.print_summary()
print(f"{'='*60}\n")
results = {
@@ -462,8 +547,8 @@ if __name__ == "__main__":
help="Enable CPU offload mode")
parser.add_argument("--num-gpu-blocks", type=int, default=4,
help="Number of GPU blocks for CPU offload (default: 4)")
parser.add_argument("--block-size", type=int, default=1024,
help="KV cache block size (default: 1024)")
parser.add_argument("--block-size", type=int, default=4096,
help="KV cache block size (default: 4096)")
parser.add_argument("--num-kv-buffers", type=int, default=4,
help="Number of KV buffers for ring buffer (default: 4)")
parser.add_argument("--gpu-utilization", type=float, default=0.9,
@@ -485,6 +570,10 @@ if __name__ == "__main__":
help="XAttention BSA: samples per chunk for estimation")
parser.add_argument("--sparse-block-size", type=int, default=128,
help="XAttention BSA: block size for estimation")
parser.add_argument("--sparse-stride", type=int, default=8,
help="XAttention BSA: stride for Q/K downsampling")
parser.add_argument("--dtype", type=str, default=None,
help="Model dtype (bfloat16, float16). Required for models with float32 default.")
args = parser.parse_args()
@@ -521,6 +610,8 @@ if __name__ == "__main__":
sparse_threshold=args.sparse_threshold,
sparse_samples=args.sparse_samples,
sparse_block_size=args.sparse_block_size,
sparse_stride=args.sparse_stride,
dtype=args.dtype,
)
# Exit code (skip for json output mode)

View File

@@ -1,199 +0,0 @@
"""
Sequential inference test for LLM.
Tests: After completing one prompt, the system can correctly handle
a second prompt with a clean state (first prompt's KV cache deallocated).
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import argparse
from nanovllm import LLM, SamplingParams
from utils import generate_needle_prompt, check_needle_answer
def run_sequential_test(
model_path: str,
max_model_len: int,
input_len: int,
num_gpu_blocks: int = 4,
block_size: int = 1024,
enable_cpu_offload: bool = False,
verbose: bool = True,
) -> bool:
"""
Run sequential inference test with two different prompts.
Each prompt has a different needle value. Both must be retrieved correctly.
"""
if verbose:
print(f"\n{'='*60}")
print(f"Sequential Inference Test")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Max model len: {max_model_len}")
print(f"Input length: {input_len}")
print(f"Block size: {block_size}")
print(f"CPU offload: {enable_cpu_offload}")
print(f"{'='*60}\n")
# Initialize LLM once
llm_kwargs = {
"enforce_eager": True,
"max_model_len": max_model_len,
"max_num_batched_tokens": max_model_len,
"enable_cpu_offload": enable_cpu_offload,
"kvcache_block_size": block_size,
}
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm = LLM(model_path, **llm_kwargs)
sampling_params = SamplingParams(
temperature=0.6,
max_tokens=32,
)
# ============================================================
# Test 1: First prompt with needle value "1234"
# ============================================================
needle_value_1 = "1234"
if verbose:
print(f"\n[Test 1] Generating prompt with needle value: {needle_value_1}")
prompt_1, expected_1 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_1,
)
outputs_1 = llm.generate([prompt_1], sampling_params, use_tqdm=True)
output_text_1 = outputs_1[0]["text"]
passed_1 = check_needle_answer(output_text_1, expected_1)
if verbose:
print(f" Expected: {expected_1}")
print(f" Output: {output_text_1[:100]}...")
print(f" Status: {'PASSED' if passed_1 else 'FAILED'}")
# ============================================================
# Test 2: Second prompt with needle value "5678"
# ============================================================
needle_value_2 = "5678"
if verbose:
print(f"\n[Test 2] Generating prompt with needle value: {needle_value_2}")
prompt_2, expected_2 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_2,
)
outputs_2 = llm.generate([prompt_2], sampling_params, use_tqdm=True)
output_text_2 = outputs_2[0]["text"]
passed_2 = check_needle_answer(output_text_2, expected_2)
if verbose:
print(f" Expected: {expected_2}")
print(f" Output: {output_text_2[:100]}...")
print(f" Status: {'PASSED' if passed_2 else 'FAILED'}")
# ============================================================
# Test 3: Third prompt - repeat first needle to ensure no cross-contamination
# ============================================================
needle_value_3 = "9999"
if verbose:
print(f"\n[Test 3] Generating prompt with needle value: {needle_value_3}")
prompt_3, expected_3 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_3,
)
outputs_3 = llm.generate([prompt_3], sampling_params, use_tqdm=True)
output_text_3 = outputs_3[0]["text"]
passed_3 = check_needle_answer(output_text_3, expected_3)
if verbose:
print(f" Expected: {expected_3}")
print(f" Output: {output_text_3[:100]}...")
print(f" Status: {'PASSED' if passed_3 else 'FAILED'}")
# ============================================================
# Summary
# ============================================================
all_passed = passed_1 and passed_2 and passed_3
if verbose:
print(f"\n{'='*60}")
print(f"Summary")
print(f"{'='*60}")
print(f"Test 1 (needle={needle_value_1}): {'PASSED' if passed_1 else 'FAILED'}")
print(f"Test 2 (needle={needle_value_2}): {'PASSED' if passed_2 else 'FAILED'}")
print(f"Test 3 (needle={needle_value_3}): {'PASSED' if passed_3 else 'FAILED'}")
print(f"Overall: {'PASSED' if all_passed else 'FAILED'}")
print(f"{'='*60}\n")
return all_passed
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Sequential inference test")
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-0.6B/"),
help="Path to model"
)
parser.add_argument(
"--max-model-len",
type=int,
default=36 * 1024,
help="Maximum model context length"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--num-gpu-blocks",
type=int,
default=2,
help="Number of GPU blocks for CPU offload"
)
parser.add_argument(
"--block-size",
type=int,
default=1024,
help="KV cache block size"
)
parser.add_argument(
"--enable-offload",
action="store_true",
help="Enable CPU offload"
)
args = parser.parse_args()
passed = run_sequential_test(
model_path=args.model,
max_model_len=args.max_model_len,
input_len=args.input_len,
num_gpu_blocks=args.num_gpu_blocks,
block_size=args.block_size,
enable_cpu_offload=args.enable_offload,
verbose=True,
)
if passed:
print("test_sequential: PASSED")
else:
print("test_sequential: FAILED")
exit(1)

View File

@@ -0,0 +1,365 @@
"""
Test: 验证 xattn_estimate 与 KV chunking kernels 的一致性
使用真实 KV cache 数据,对比:
1. xattn_estimate (高层 API)
2. 三阶段 KV chunking (softmax_compute_partial_stats + merge + normalize)
三阶段 KV chunking 流程:
1. softmax_compute_partial_stats: 计算每个 KV chunk 的 (m, l)
2. merge_softmax_stats: Host 端合并所有 chunks 的 stats
3. softmax_normalize_and_block_sum: 使用全局 stats 归一化
支持两种数据格式:
1. offload 模式保存: {"query", "key", "stride", "threshold", "density", "layer_id"}
2. GPU-only 模式保存: {"Q", "K", "chunk_size", "block_size", "stride", "threshold", "mask", "attn_sums", ...}
Usage:
# 使用 offload 模式数据
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_alignment.py
# 使用 GPU-only 模式数据
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_alignment.py --gpuonly
"""
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
import argparse
import torch
import math
from nanovllm.ops.xattn import (
xattn_estimate,
flat_group_gemm_fuse_reshape,
softmax_compute_partial_stats,
softmax_normalize_and_block_sum,
merge_softmax_stats,
find_blocks_chunked,
)
# ============================================================
# 命令行参数
# ============================================================
parser = argparse.ArgumentParser()
parser.add_argument("--gpuonly", action="store_true", help="使用 GPU-only 模式保存的数据")
parser.add_argument("--data-file", type=str, default=None, help="数据文件路径")
parser.add_argument("--chunk-size", type=int, default=None, help="覆盖 CHUNK_SIZE (用于测试不同分块大小)")
args = parser.parse_args()
# ============================================================
# 参数配置
# ============================================================
if args.gpuonly:
DATA_FILE = args.data_file or "/home/zijie/Code/nano-vllm/results/mask_alignment/gpuonly_layer0.pt"
else:
DATA_FILE = args.data_file or "/home/zijie/Code/nano-vllm/results/kvcache/qkv_32485.pt"
device = "cuda"
# ============================================================
# Step 1: 加载真实数据
# ============================================================
print("=" * 60)
print("Step 1: 加载真实 KV cache 数据")
print("=" * 60)
data = torch.load(DATA_FILE, map_location="cpu")
# 检测数据格式并加载
if "Q" in data:
# GPU-only 模式保存的格式
print(f"[INFO] 检测到 GPU-only 模式数据格式")
Q = data["Q"].to(device)
K = data["K"].to(device)
BSA_BLOCK_SIZE = data.get("block_size", 128)
CHUNK_SIZE = data.get("chunk_size", 4096)
STRIDE = data.get("stride", 8)
THRESHOLD = data.get("threshold", 0.9)
if isinstance(THRESHOLD, torch.Tensor):
THRESHOLD = THRESHOLD.item()
# GPU-only 模式保存了 mask 和 attn_sums可以用于验证
saved_mask = data.get("mask", None)
saved_attn_sums = data.get("attn_sums", None)
saved_density = None # GPU-only 模式没有保存 density
layer_id = 0 # GPU-only 只保存 layer 0
else:
# offload 模式保存的格式
print(f"[INFO] 检测到 offload 模式数据格式")
Q = data["query"].to(device)
K = data["key"].to(device)
BSA_BLOCK_SIZE = 128
CHUNK_SIZE = 4096
STRIDE = data["stride"]
THRESHOLD = data["threshold"]
if isinstance(THRESHOLD, torch.Tensor):
THRESHOLD = THRESHOLD[0].item()
saved_mask = None
saved_attn_sums = None
saved_density = data.get("density", None)
layer_id = data.get("layer_id", 0)
batch_size, num_heads, seq_len, head_dim = Q.shape
# 命令行覆盖 CHUNK_SIZE
if args.chunk_size is not None:
CHUNK_SIZE = args.chunk_size
print(f"[INFO] 使用命令行指定的 CHUNK_SIZE={CHUNK_SIZE}")
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
if saved_density is not None:
print(f"Data layer_id: {layer_id}, saved density: {saved_density:.4f}")
else:
print(f"Data layer_id: {layer_id}")
print(f"使用参数: STRIDE={STRIDE}, THRESHOLD={THRESHOLD}, CHUNK_SIZE={CHUNK_SIZE}, BSA_BLOCK_SIZE={BSA_BLOCK_SIZE}")
print()
# ============================================================
# Step 2: 使用 xattn_estimate 高层 API
# ============================================================
print("=" * 60)
print("Step 2: 调用 xattn_estimate (高层 API)")
print("=" * 60)
attn_sums_api, mask_api = xattn_estimate(
Q, K,
block_size=BSA_BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
causal=True,
)
# 裁剪到有效区域
q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
mask_api_valid = mask_api[:, :, :q_blocks, :k_blocks]
# 计算 density (causal)
causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, device=device, dtype=torch.bool))
total_api = causal_mask.sum().item() * batch_size * num_heads
selected_api = (mask_api_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_api = selected_api / total_api
print(f"mask_api shape (padded): {mask_api.shape}")
print(f"mask_api_valid shape: {mask_api_valid.shape}")
print(f"[xattn_estimate] density: {density_api:.6f} (selected={selected_api}, total={total_api})")
print()
# ============================================================
# Step 3: 三阶段 KV Chunking
# ============================================================
print("=" * 60)
print("Step 3: 三阶段 KV Chunking")
print("=" * 60)
print(" 1) 每个 KV chunk 计算 partial stats")
print(" 2) Host 端合并 stats")
print(" 3) 使用全局 stats 归一化并计算 block sums")
print()
# 计算 padding 参数
k_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
q_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
q_chunk_num = (seq_len + q_num_to_pad) // CHUNK_SIZE
kv_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE
k_block_num = (seq_len + k_num_to_pad) // BSA_BLOCK_SIZE
q_block_num = (seq_len + q_num_to_pad) // BSA_BLOCK_SIZE
reshaped_chunk_size = CHUNK_SIZE // STRIDE
reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
k_reshaped_seq_len = (seq_len + k_num_to_pad) // STRIDE
k_reshaped_num_to_pad = k_num_to_pad // STRIDE
num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size
kv_reshaped_chunk_size = CHUNK_SIZE // STRIDE
print(f"seq_len: {seq_len}, q_chunk_num: {q_chunk_num}, kv_chunk_num: {kv_chunk_num}")
print()
# Padding
if k_num_to_pad > 0:
K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0)
else:
K_padded = K
if q_num_to_pad > 0:
Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, q_num_to_pad), value=0)
else:
Q_padded = Q
# Softmax scale
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm
simple_mask_list = []
for q_chunk_idx in range(q_chunk_num):
q_start = q_chunk_idx * reshaped_chunk_size * STRIDE
q_end = q_start + reshaped_chunk_size * STRIDE
Q_chunk = Q_padded[:, :, q_start:q_end, :]
chunk_start = (k_block_num - q_block_num) * reshaped_block_size + q_chunk_idx * reshaped_chunk_size
chunk_end = chunk_start + reshaped_chunk_size
# 阶段 1: 每个 KV chunk 计算 partial stats 和 raw scores
m_chunks = []
l_chunks = []
attn_weights_chunks = []
for kv_chunk_idx in range(kv_chunk_num):
kv_start = kv_chunk_idx * CHUNK_SIZE
kv_end = kv_start + CHUNK_SIZE
K_chunk = K_padded[:, :, kv_start:kv_end, :]
# KV offset in reshaped space
kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
# 计算 raw attention scores
attn_weights_kv = flat_group_gemm_fuse_reshape(
Q_chunk, K_chunk, STRIDE,
chunk_start=chunk_start,
chunk_end=chunk_end,
is_causal=False, # K 不完整,不能在这里用 causal
)
attn_weights_chunks.append(attn_weights_kv)
# 计算 partial stats (带 causal mask)
m_partial, l_partial = softmax_compute_partial_stats(
attn_weights_kv,
reshaped_block_size,
min(4096, reshaped_block_size),
scale,
chunk_start=chunk_start,
kv_offset=kv_offset_reshaped,
is_causal=True,
)
m_chunks.append(m_partial)
l_chunks.append(l_partial)
# 阶段 2: Host 端合并 stats
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
# 阶段 3: 使用全局 stats 归一化并计算 block sums
attn_sum_per_kv = []
for kv_chunk_idx, attn_weights_kv in enumerate(attn_weights_chunks):
kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
attn_sum_kv = softmax_normalize_and_block_sum(
attn_weights_kv,
m_global,
l_global,
reshaped_block_size,
min(4096, reshaped_block_size),
chunk_start=chunk_start,
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
scale=scale,
kv_offset=kv_offset_reshaped,
is_causal=True,
)
attn_sum_per_kv.append(attn_sum_kv)
# 拼接各 KV chunk 的 block sums
attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
# 选择 blocks
simple_mask = find_blocks_chunked(
attn_sum_concat,
current_index=k_block_num - q_block_num + q_chunk_idx * num_blocks_per_chunk,
threshold=THRESHOLD,
num_to_choose=None,
decoding=False,
mode="prefill",
causal=True,
)
simple_mask_list.append(simple_mask)
print(f" Q chunk {q_chunk_idx}: merged {kv_chunk_num} KV chunks, attn_sum shape={attn_sum_concat.shape}")
mask_kv_chunking = torch.cat(simple_mask_list, dim=2)
# 应用与 xattn_estimate 相同的 causal mask 后处理 (xattn.py 第 1300-1306 行)
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:] = torch.where(
torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=device), diagonal=0),
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:],
False,
)
mask_kv_chunking_valid = mask_kv_chunking[:, :, :q_blocks, :k_blocks]
selected_kv = (mask_kv_chunking_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_kv = selected_kv / total_api
print()
print(f"[KV chunking] density: {density_kv:.6f} (selected={selected_kv}, total={total_api})")
print()
# ============================================================
# Step 4: 对比结果
# ============================================================
print("=" * 60)
print("Step 4: 对比结果")
print("=" * 60)
print()
mask_total = mask_api_valid.numel()
mask_diff = (mask_api_valid != mask_kv_chunking_valid).sum().item()
print("| 方法 | density | 与 API 差异 | Mask 差异 |")
print("|------|---------|-------------|-----------|")
print(f"| xattn_estimate API | {density_api:.6f} | - | - |")
print(f"| KV chunking | {density_kv:.6f} | {abs(density_api - density_kv):.6f} | {100*mask_diff/mask_total:.4f}% |")
print()
passed = abs(density_api - density_kv) < 1e-6 and mask_diff / mask_total < 0.001
# ============================================================
# Step 5: 与 GPU-only 保存的数据对比 (如果有)
# ============================================================
if saved_mask is not None or saved_attn_sums is not None:
print("=" * 60)
print("Step 5: 与 GPU-only 保存的数据对比")
print("=" * 60)
print()
if saved_mask is not None:
saved_mask_gpu = saved_mask.to(device)
# 比较 mask
mask_saved_diff = (mask_api_valid != saved_mask_gpu).sum().item()
mask_saved_total = saved_mask_gpu.numel()
print(f"| xattn_estimate vs GPU-only saved mask | 差异 blocks: {mask_saved_diff} / {mask_saved_total} ({100*mask_saved_diff/mask_saved_total:.4f}%) |")
if mask_saved_diff == 0:
print("✅ mask 与 GPU-only 保存完全一致")
else:
print("❌ mask 与 GPU-only 保存存在差异")
passed = False
if saved_attn_sums is not None:
saved_attn_sums_gpu = saved_attn_sums.to(device)
# 需要从 xattn_estimate 获取 attn_sums
# 重新调用一次获取 attn_sums
attn_sums_check, _ = xattn_estimate(
Q, K,
block_size=BSA_BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
causal=True,
)
attn_sums_check_valid = attn_sums_check[:, :, :q_blocks, :k_blocks]
max_diff = (attn_sums_check_valid - saved_attn_sums_gpu).abs().max().item()
mean_diff = (attn_sums_check_valid - saved_attn_sums_gpu).abs().mean().item()
print(f"| xattn_estimate vs GPU-only saved attn_sums | max diff: {max_diff:.6e}, mean diff: {mean_diff:.6e} |")
if max_diff < 1e-5:
print("✅ attn_sums 与 GPU-only 保存一致")
else:
print("❌ attn_sums 与 GPU-only 保存存在差异")
passed = False
print()
if passed:
print("test_xattn_estimate_alignment: PASSED")
else:
print("test_xattn_estimate_alignment: FAILED")

View File

@@ -1,244 +0,0 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked
Verify that chunked estimation with EXTERNAL chunking produces the same mask
as standard estimation. This ensures the chunked version can be used in
chunked prefill scenarios without accuracy loss.
Usage:
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_chunked.py
"""
import sys
import traceback
import torch
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
# ============================================================
# Configuration
# ============================================================
# Configuration for xattn_estimate_chunked consistency test.
# Key requirements for 100% match:
# 1. Use matching chunk_size for both standard and chunked versions
# 2. Use same random seed for reproducibility
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
# floating point precision in cumulative sum calculations.
BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096 # External chunking size
# Test sequence lengths
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]
# ============================================================
# Utility Functions
# ============================================================
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
"""Compare two masks and report differences."""
if mask1.shape != mask2.shape:
print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
return False
diff = (mask1 != mask2).sum().item()
total = mask1.numel()
match_rate = (total - diff) / total * 100
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
if diff > 0:
diff_indices = torch.where(mask1 != mask2)
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
return diff == 0
def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
"""
Run xattn_estimate_chunked with EXTERNAL chunking.
This simulates how chunked prefill should be used in practice.
"""
batch_size, num_heads, q_len, head_dim = query.shape
_, _, k_len, _ = key.shape
q_block_num = (q_len + block_size - 1) // block_size
k_block_num = (k_len + block_size - 1) // block_size
# If Q fits in one chunk, call directly
if q_len <= chunk_size:
return xattn_estimate_chunked(
query, key,
q_start_pos=0,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# External chunking: split Q and call for each chunk
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
print(f" External chunking: {num_q_chunks} chunks")
combined_attn_sum = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=query.dtype, device=query.device
)
combined_mask = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=torch.bool, device=query.device
)
q_block_offset = 0
for q_chunk_idx in range(num_q_chunks):
q_chunk_start = q_chunk_idx * chunk_size
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
# For causal attention, K accumulates up to current Q position
# q_start_pos=0 means Q starts at position 0 in the full sequence
# K is [0, q_chunk_end) for causal attention
k_end = q_chunk_end
k_chunk = key[:, :, :k_end, :]
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
q_chunk, k_chunk,
q_start_pos=q_chunk_start,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# Place chunk results into combined output
chunk_q_blocks = mask_chunk.shape[2]
chunk_k_blocks = mask_chunk.shape[3]
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
q_block_offset += chunk_q_blocks
return combined_attn_sum, combined_mask
def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
"""Test a single sequence length."""
print(f"\nTesting seq_len={seq_len}")
print("=" * 60)
# Generate random Q/K
query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
# Run standard xattn_estimate
print("[1] Running standard xattn_estimate...")
try:
attn_sum_std, mask_std = xattn_estimate(
query, key,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
use_triton=True,
causal=True,
)
density_std = mask_std.float().mean().item()
print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}")
except Exception as e:
print(f" ERROR: {e}")
traceback.print_exc()
return False
# Run chunked xattn_estimate with EXTERNAL chunking
print("[2] Running chunked xattn_estimate (external chunking)...")
try:
attn_sum_chunked, mask_chunked = run_chunked_externally(
query, key,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
)
density_chunked = mask_chunked.float().mean().item()
print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
except Exception as e:
print(f" ERROR: {e}")
traceback.print_exc()
return False
# Compare results
print("[3] Comparing results...")
chunked_q_blocks = mask_chunked.shape[2]
chunked_k_blocks = mask_chunked.shape[3]
# Extract comparable region from standard mask
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
# Compare masks
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
# Compare attn_sums
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
print(f" Attn sum max diff: {attn_diff:.6f}")
else:
print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")
# Clean up GPU memory
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
torch.cuda.empty_cache()
return masks_match
# ============================================================
# Main Test
# ============================================================
if __name__ == "__main__":
print("XAttention Chunked vs Standard Test")
print("=" * 60)
print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
print(f"External chunk_size={CHUNK_SIZE}")
print()
# Check CUDA availability
if not torch.cuda.is_available():
print("CUDA not available!")
sys.exit(1)
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
print("✓ xattn_estimate imported")
print("✓ xattn_estimate_chunked imported")
# Run tests
all_passed = True
results = []
for seq_len in TEST_SEQ_LENS:
passed = test_single_seq_len(seq_len)
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
results.append((seq_len, chunks, passed))
if not passed:
all_passed = False
# Summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for seq_len, chunks, passed in results:
status = "PASSED" if passed else "FAILED"
print(f" seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")
print("=" * 60)
if all_passed:
print("ALL TESTS PASSED!")
sys.exit(0)
else:
print("SOME TESTS FAILED!")
sys.exit(1)