From b97b0b96a025c8a55ae68dbaa2b8dd8d7250460a Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Mon, 19 Jan 2026 22:34:44 +0800 Subject: [PATCH] [WIP] Before refactor the nanovllm sparse policy. --- .claude/rules/planning-with-files.md | 2 +- findings.md | 160 --------- nanovllm/kvcache/sparse/__init__.py | 2 - nanovllm/kvcache/sparse/full_policy.py | 129 ++++++- nanovllm/kvcache/sparse/xattn_bsa.py | 475 +------------------------ nanovllm/layers/attention.py | 41 +-- progress.md | 76 ---- task_plan.md | 427 ++++++++++++++++------ 8 files changed, 475 insertions(+), 837 deletions(-) delete mode 100644 findings.md delete mode 100644 progress.md diff --git a/.claude/rules/planning-with-files.md b/.claude/rules/planning-with-files.md index 6ce318c..5c7f4c0 100644 --- a/.claude/rules/planning-with-files.md +++ b/.claude/rules/planning-with-files.md @@ -23,7 +23,7 @@ rm -f task_plan_*.md findings_*.md progress_*.md ```bash # Step 1: 清理旧计划文件 -rm -f task_plan.md findings.md progress.md task_plan_*.md findings_*.md progress_*.md +rm -f task_plan.md findings.md progress.md # Step 2: 启动 planning-with-files 技能 # 在 Claude 中调用 /planning-with-files 或 Skill tool diff --git a/findings.md b/findings.md deleted file mode 100644 index bb77faa..0000000 --- a/findings.md +++ /dev/null @@ -1,160 +0,0 @@ -# Findings: Multi-Model Support Analysis - -## Current Architecture Analysis - -### Model Loading Flow -``` -LLM(model_path) - → LLMEngine.__init__() - → Config.__post_init__() - → hf_config = AutoConfig.from_pretrained(model) - → ModelRunner.__init__() - → model = Qwen3ForCausalLM(hf_config) ← HARDCODED - → load_model(model, config.model) -``` - -### Key Files -| File | Purpose | -|------|---------| -| `nanovllm/engine/model_runner.py` | 模型加载和运行 | -| `nanovllm/models/qwen3.py` | Qwen3 模型定义 | -| `nanovllm/utils/loader.py` | safetensors 权重加载 | -| `nanovllm/layers/rotary_embedding.py` | RoPE 实现 | - ---- - -## Llama 3.1 Config Analysis - -```json -{ - "architectures": ["LlamaForCausalLM"], - "model_type": "llama", - "attention_bias": false, - "mlp_bias": false, - "head_dim": 128, - "hidden_size": 4096, - "intermediate_size": 14336, - "num_attention_heads": 32, - "num_hidden_layers": 32, - "num_key_value_heads": 8, - "hidden_act": "silu", - "rms_norm_eps": 1e-05, - "rope_theta": 500000.0, - "rope_scaling": { - "factor": 8.0, - "high_freq_factor": 4.0, - "low_freq_factor": 1.0, - "original_max_position_embeddings": 8192, - "rope_type": "llama3" - }, - "max_position_embeddings": 131072, - "tie_word_embeddings": false, - "vocab_size": 128256 -} -``` - -### Llama 3 RoPE Scaling -Llama 3 使用特殊的 RoPE scaling 策略 (`rope_type: "llama3"`): -- 低频分量保持不变(对应短距离依赖) -- 高频分量线性插值(对应长距离依赖) -- 参数: `factor`, `low_freq_factor`, `high_freq_factor`, `original_max_position_embeddings` - -参考实现 (transformers): -```python -def _compute_llama3_parameters(config, device, inv_freq): - factor = config.factor - low_freq_factor = config.low_freq_factor - high_freq_factor = config.high_freq_factor - old_context_len = config.original_max_position_embeddings - - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - - wavelen = 2 * math.pi / inv_freq - inv_freq_llama = torch.where( - wavelen > low_freq_wavelen, - inv_freq / factor, - inv_freq - ) - smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) - smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama + smooth_factor * inv_freq - is_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen) - inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) - return inv_freq_llama -``` - ---- - -## Weight Mapping Analysis - -### Qwen3 packed_modules_mapping -```python -packed_modules_mapping = { - "q_proj": ("qkv_proj", "q"), - "k_proj": ("qkv_proj", "k"), - "v_proj": ("qkv_proj", "v"), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), -} -``` - -### Llama Weight Names (from safetensors) -预期 Llama 权重命名与 Qwen3 类似: -- `model.layers.{i}.self_attn.q_proj.weight` -- `model.layers.{i}.self_attn.k_proj.weight` -- `model.layers.{i}.self_attn.v_proj.weight` -- `model.layers.{i}.self_attn.o_proj.weight` -- `model.layers.{i}.mlp.gate_proj.weight` -- `model.layers.{i}.mlp.up_proj.weight` -- `model.layers.{i}.mlp.down_proj.weight` -- `model.layers.{i}.input_layernorm.weight` -- `model.layers.{i}.post_attention_layernorm.weight` - -**结论**: Llama 的 `packed_modules_mapping` 与 Qwen3 相同,可以复用。 - ---- - -## Shared Components (Can Reuse) - -| Component | File | Notes | -|-----------|------|-------| -| `RMSNorm` | `layers/layernorm.py` | 通用 | -| `SiluAndMul` | `layers/activation.py` | 通用 | -| `Attention` | `layers/attention.py` | FlashAttention wrapper | -| `QKVParallelLinear` | `layers/linear.py` | 支持 bias=False | -| `RowParallelLinear` | `layers/linear.py` | 通用 | -| `MergedColumnParallelLinear` | `layers/linear.py` | 通用 | -| `VocabParallelEmbedding` | `layers/embed_head.py` | 通用 | -| `ParallelLMHead` | `layers/embed_head.py` | 通用 | -| `load_model` | `utils/loader.py` | 通用 | - ---- - -## Llama vs Qwen3 Implementation Diff - -### Attention -| Feature | Qwen3Attention | LlamaAttention | -|---------|----------------|----------------| -| QKV bias | 可配置 (attention_bias) | 始终 False | -| q_norm | 有 (when bias=False) | 无 | -| k_norm | 有 (when bias=False) | 无 | -| RoPE | Standard | Llama3 scaled | - -### MLP -| Feature | Qwen3MLP | LlamaMLP | -|---------|----------|----------| -| gate/up bias | False | False | -| down bias | False | False | -| hidden_act | silu | silu | - -**结论**: Llama MLP 与 Qwen3 MLP 几乎相同,可以直接复用或简化。 - ---- - -## Risk Assessment - -| Risk | Impact | Mitigation | -|------|--------|------------| -| RoPE 实现错误 | 高 - 导致错误输出 | 参考 transformers 实现,单元测试 | -| 权重映射错误 | 高 - 模型无法加载 | 检查 safetensors 键名 | -| 注册表循环导入 | 中 - 启动失败 | 延迟导入 | diff --git a/nanovllm/kvcache/sparse/__init__.py b/nanovllm/kvcache/sparse/__init__.py index 545fe71..6c947fe 100644 --- a/nanovllm/kvcache/sparse/__init__.py +++ b/nanovllm/kvcache/sparse/__init__.py @@ -61,8 +61,6 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic block_size=kwargs.get("block_size", 128), samples_per_chunk=kwargs.get("samples_per_chunk", 128), threshold=kwargs.get("threshold", 0.9), - use_triton=kwargs.get("use_triton", True), - stride=kwargs.get("stride", 8), ) else: diff --git a/nanovllm/kvcache/sparse/full_policy.py b/nanovllm/kvcache/sparse/full_policy.py index a6cff50..8dd8b42 100644 --- a/nanovllm/kvcache/sparse/full_policy.py +++ b/nanovllm/kvcache/sparse/full_policy.py @@ -5,8 +5,11 @@ This serves as a baseline and default policy when sparse attention is not needed. """ -from typing import List +import torch +from typing import List, Optional + from .policy import SparsePolicy, PolicyContext +from nanovllm.utils.context import get_context class FullAttentionPolicy(SparsePolicy): @@ -34,5 +37,129 @@ class FullAttentionPolicy(SparsePolicy): """Return all blocks - no sparsity.""" return available_blocks + def compute_prefill_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_id: int, + softmax_scale: float, + offload_engine, + current_chunk_idx: int, + seq, + ) -> torch.Tensor: + """ + Compute full attention for chunked prefill. + + This method handles the complete chunked prefill flow: + 1. Load historical blocks from CPU + 2. Compute attention to historical chunks + 3. Compute attention to current chunk + 4. Merge all results + + Args: + q: Query tensor [seq_len, num_heads, head_dim] + k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer) + v: Value tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer) + layer_id: Current layer index + softmax_scale: Softmax scaling factor + offload_engine: OffloadEngine for loading blocks + current_chunk_idx: Current chunk index + seq: ChunkedSequence + + Returns: + Attention output [seq_len, num_heads, head_dim] + """ + from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs + + q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim] + num_tokens = q.shape[0] + o_acc = None + lse_acc = None + compute_stream = offload_engine.compute_stream + + # Step 1: Get and load historical blocks + cpu_block_table = seq.kvcache_manager.get_prefilled_cpu_blocks(seq) + + if cpu_block_table: + load_slots = list(range(offload_engine.num_ring_slots)) + num_blocks = len(cpu_block_table) + + if len(load_slots) == 1: + # Only 1 slot - use synchronous mode + slot = load_slots[0] + for block_idx in range(num_blocks): + cpu_block_id = cpu_block_table[block_idx] + offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id) + offload_engine.wait_slot_layer(slot) + + with torch.cuda.stream(compute_stream): + prev_k, prev_v = offload_engine.get_kv_for_slot(slot) + prev_o, prev_lse = flash_attn_with_lse( + q_batched, prev_k, prev_v, + softmax_scale=softmax_scale, + causal=False, + ) + if o_acc is None: + o_acc, lse_acc = prev_o, prev_lse + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) + offload_engine.record_slot_compute_done(slot) + else: + # Multiple slots - use pipeline + num_slots = len(load_slots) + num_preload = min(num_slots, num_blocks) + for i in range(num_preload): + offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i]) + + for block_idx in range(num_blocks): + current_slot = load_slots[block_idx % num_slots] + cpu_block_id = cpu_block_table[block_idx] + + offload_engine.wait_slot_layer(current_slot) + + with torch.cuda.stream(compute_stream): + prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot) + prev_o, prev_lse = flash_attn_with_lse( + q_batched, prev_k, prev_v, + softmax_scale=softmax_scale, + causal=False, + ) + offload_engine.record_slot_compute_done(current_slot) + + if o_acc is None: + o_acc, lse_acc = prev_o, prev_lse + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) + + # Issue next transfer + next_block_idx = block_idx + num_slots + if next_block_idx < num_blocks: + next_slot = load_slots[next_block_idx % num_slots] + next_cpu_block_id = cpu_block_table[next_block_idx] + offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id) + + # Step 2: Compute attention to current chunk (causal mask) + with torch.cuda.stream(compute_stream): + k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens) + current_o, current_lse = flash_attn_with_lse( + q_batched, k_curr, v_curr, + softmax_scale=softmax_scale, + causal=True, + ) + + # Step 3: Merge historical and current attention + with torch.cuda.stream(compute_stream): + if o_acc is None: + final_o = current_o + else: + final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) + + # Sync default stream with compute_stream before returning + torch.cuda.default_stream().wait_stream(compute_stream) + + # Remove batch dimension: [1, seq_len, num_heads, head_dim] -> [seq_len, num_heads, head_dim] + return final_o.squeeze(0) + def __repr__(self) -> str: return "FullAttentionPolicy()" diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py index 81c1fc6..7a21a47 100644 --- a/nanovllm/kvcache/sparse/xattn_bsa.py +++ b/nanovllm/kvcache/sparse/xattn_bsa.py @@ -1,15 +1,13 @@ """ XAttention Block Sparse Attention (BSA) Policy for nano-vllm. -This module implements XAttention-inspired block sparse attention for chunked prefill, -using block-level estimation to select important KV blocks for computation. +This module implements XAttention-inspired block sparse attention for chunked prefill. +Current implementation loads all historical blocks (FULL strategy). -Reference: COMPASS/compass/src/Xattention.py +Sparse selection to be implemented in next phase. """ -import math import torch -import torch.nn.functional as F from typing import List, Optional, Tuple from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext @@ -23,18 +21,11 @@ class XAttentionBSAPolicy(SparsePolicy): This policy uses block-level estimation to determine which KV blocks are important for the current chunk's queries, enabling sparse computation. - Key features: - - Double-loading design: estimate phase loads samples, compute phase loads selected blocks - - Block-level granularity: 128-token blocks for estimation and computation - - Triton kernels for efficient estimation (optional, falls back to PyTorch) - - Architecture: - 1. Estimate Phase: Load samples from all historical chunks, compute importance scores - 2. Selection Phase: Select top chunks by cumulative attention threshold - 3. Compute Phase: Load selected chunks fully, apply block sparse attention + Note: Current implementation loads all historical chunks (FULL strategy). + Sparse selection to be implemented in next phase. """ - supports_prefill = True + supports_prefill = False # Uses standard select_blocks interface supports_decode = False # BSA is prefill-only requires_block_selection = False # Selection happens at chunk level, not block level @@ -43,8 +34,6 @@ class XAttentionBSAPolicy(SparsePolicy): block_size: int = 128, samples_per_chunk: int = 128, threshold: float = 0.9, - use_triton: bool = True, - stride: int = 8, ): """ Initialize XAttention BSA policy. @@ -53,457 +42,29 @@ class XAttentionBSAPolicy(SparsePolicy): block_size: Number of tokens per block (default: 128) samples_per_chunk: Number of tokens to sample from each historical chunk for estimation threshold: Cumulative attention threshold for chunk selection (0-1) - use_triton: Use Triton kernels for estimation (requires SM 80+) - stride: Stride for Q/K downsampling in estimation """ self.block_size = block_size self.samples_per_chunk = samples_per_chunk self.threshold = threshold - self.use_triton = use_triton - self.stride = stride - - # Check Triton availability - if self.use_triton: - try: - import triton - props = torch.cuda.get_device_properties(torch.cuda.current_device()) - if props.major < 8: - self.use_triton = False - print(f"[XAttentionBSA] Triton requires SM 80+, got SM {props.major}{props.minor}. Falling back to PyTorch.") - except ImportError: - self.use_triton = False - print("[XAttentionBSA] Triton not available. Using PyTorch implementation.") def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]: """ - Select blocks to load from CPU (for decode compatibility, not used in prefill). + Select blocks to load from CPU. - For prefill, BSA handles chunk-level selection internally. + Current implementation returns all blocks (FULL strategy). + Sparse selection to be implemented in next phase. + + Args: + available_blocks: List of all available CPU block IDs + ctx: Policy context with query info, chunk index, etc. + + Returns: + List of selected block IDs to load """ - # For prefill, we return all blocks - selection happens in sparse_prefill_attention + # Current: Return all blocks (FULL strategy) + # TODO: Implement sparse selection based on query attention estimation return available_blocks - def sparse_prefill_attention( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - layer_id: int, - softmax_scale: float, - ) -> torch.Tensor: - """ - Compute XAttention block sparse attention for current chunk. - - This implements a simplified version that loads all historical chunks - (sparse selection to be implemented in next phase). - - Args: - q: Query tensor [seq_len, num_heads, head_dim] - k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, we use prefill buffer) - v: Value tensor [seq_len, num_kv_heads, head_dim] (unused, we use prefill buffer) - layer_id: Current transformer layer index - softmax_scale: Softmax scaling factor from attention layer - - Returns: - Attention output [seq_len, num_heads, head_dim] - """ - from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs - - context = get_context() - kvcache_manager = context.kvcache_manager - offload_engine = kvcache_manager.offload_engine if kvcache_manager else None - - if offload_engine is None: - # No offload engine, use standard attention with provided k, v - return self._full_attention(q, k, v, causal=True) - - current_chunk_idx = getattr(context, 'current_chunk_idx', 0) - seq = getattr(context, 'chunked_seq', None) - num_tokens = q.shape[0] - - if seq is None: - # No chunked sequence, fallback to full attention on current chunk only - return self._full_attention(q, k, v, causal=True) - - # Get prefilled CPU blocks (historical chunks) - cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) - - q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim] - o_acc = None - lse_acc = None - - # Get compute stream for all attention operations - compute_stream = offload_engine.compute_stream - - # Step 1: Load historical chunks from CPU using slot mechanism - if cpu_block_table: - load_slots = list(range(offload_engine.num_ring_slots)) - num_blocks = len(cpu_block_table) - - # Load ALL historical blocks (not just min(num_blocks, num_slots)) - # Use synchronous mode like standard flow when pipeline_depth=1 - if len(load_slots) == 1: - # Only 1 slot available, cannot pipeline - use synchronous mode - slot = load_slots[0] - for block_idx in range(num_blocks): - cpu_block_id = cpu_block_table[block_idx] - offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id) - offload_engine.wait_slot_layer(slot) - - with torch.cuda.stream(compute_stream): - # Get KV from slot - returns [1, block_size, kv_heads, head_dim] - prev_k, prev_v = offload_engine.get_kv_for_slot(slot) - - # Compute attention to historical chunk (non-causal, already processed) - prev_o, prev_lse = flash_attn_with_lse( - q_batched, prev_k, prev_v, - softmax_scale=softmax_scale, - causal=False, - ) - - # Merge results - if o_acc is None: - o_acc, lse_acc = prev_o, prev_lse - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) - - # Record compute done so slot can be reused - offload_engine.record_slot_compute_done(slot) - else: - # Multiple slots available - use pipeline - num_slots = len(load_slots) - - # Phase 1: Pre-load up to num_slots blocks to fill the pipeline - num_preload = min(num_slots, num_blocks) - for i in range(num_preload): - offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i]) - - # Phase 2: Main loop - compute and immediately reuse slot for next transfer - for block_idx in range(num_blocks): - # Cycle through slots: slot[block_idx % num_slots] - current_slot = load_slots[block_idx % num_slots] - cpu_block_id = cpu_block_table[block_idx] - - # Wait for current slot's transfer to complete - offload_engine.wait_slot_layer(current_slot) - - # Compute attention on current slot's data - with torch.cuda.stream(compute_stream): - # Get KV from slot - returns [1, block_size, kv_heads, head_dim] - prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot) - - # Compute attention to historical chunk (non-causal, already processed) - prev_o, prev_lse = flash_attn_with_lse( - q_batched, prev_k, prev_v, - softmax_scale=softmax_scale, - causal=False, - ) - - # Merge results - if o_acc is None: - o_acc, lse_acc = prev_o, prev_lse - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) - - # Record compute done so slot can be reused - offload_engine.record_slot_compute_done(current_slot) - - # Issue next transfer if there are more blocks - next_block_idx = block_idx + num_slots - if next_block_idx < num_blocks: - next_slot = load_slots[next_block_idx % num_slots] - next_cpu_block_id = cpu_block_table[next_block_idx] - offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id) - - # Step 2: Compute attention to current chunk (causal mask) - use prefill buffer on compute_stream - with torch.cuda.stream(compute_stream): - k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens) - - current_o, current_lse = flash_attn_with_lse( - q_batched, - k_curr, - v_curr, - softmax_scale=softmax_scale, - causal=True, - ) - - # Step 3: Merge historical and current attention - with torch.cuda.stream(compute_stream): - if o_acc is None: - # No historical chunks processed - final_o = current_o - else: - final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) - - # Sync default stream with compute_stream before returning - torch.cuda.default_stream().wait_stream(compute_stream) - - # Remove batch dimension: [1, seq_len, num_heads, head_dim] -> [seq_len, num_heads, head_dim] - return final_o.squeeze(0) - - def _estimate_historical_chunks( - self, - q: torch.Tensor, - historical_blocks: List[int], - layer_id: int, - current_chunk_idx: int, - ) -> Tuple[List[float], bool]: - """ - Estimate importance of each historical chunk for current Q. - - First load: Load samples from each historical chunk for estimation. - - Args: - q: Current chunk queries [chunk_size, num_heads, head_dim] - historical_blocks: List of historical CPU block IDs - layer_id: Current layer index - current_chunk_idx: Current chunk index - - Returns: - (List of importance scores (one per historical chunk), has_valid_data flag) - has_valid_data is True if at least one block had non-zero data - """ - chunk_estimates = [] - has_valid_data = False - - for block_idx, cpu_block_id in enumerate(historical_blocks): - # First load: Load sample from this historical chunk - k_sample, v_sample = self._load_block_sample( - cpu_block_id, layer_id, self.samples_per_chunk - ) - - # Check if loaded data is valid (non-zero) - if k_sample.abs().max().item() > 0: - has_valid_data = True - - # Quick estimation: Compute Q attention to this chunk's sample - # q [chunk_size, H, D] @ k_sample [samples, H, D] - # Result: Aggregate to chunk-level score - estimate = self._compute_chunk_estimate(q, k_sample) - chunk_estimates.append(estimate) - - return chunk_estimates, has_valid_data - - def _select_important_chunks( - self, - chunk_estimates: List[float], - ) -> List[int]: - """ - Select important chunks based on cumulative attention threshold. - - Args: - chunk_estimates: Importance scores for each historical chunk - - Returns: - Indices of selected chunks - """ - if not chunk_estimates: - return [] - - scores = torch.tensor(chunk_estimates, device='cpu') - threshold_value = scores.max() * self.threshold - - # Select chunks that contribute to cumulative attention threshold - selected_indices = [] - cumulative = 0.0 - sorted_indices = torch.argsort(scores, descending=True) - - for idx in sorted_indices: - cumulative += scores[idx].item() - selected_indices.append(idx.item()) - if cumulative >= threshold_value: - break - - return selected_indices - - def _compute_with_selected_chunks( - self, - q: torch.Tensor, - historical_blocks: List[int], - selected_indices: List[int], - layer_id: int, - current_chunk_idx: int, - ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - """ - Compute attention to selected historical chunks. - - Second load: Load full data for selected chunks. - - Args: - q: Current chunk queries - historical_blocks: All historical block IDs - selected_indices: Indices of selected blocks - layer_id: Current layer index - current_chunk_idx: Current chunk index - - Returns: - (accumulated_output, accumulated_lse) or (None, None) - """ - if not selected_indices: - return None, None - - o_acc = None - lse_acc = None - - for chunk_idx in selected_indices: - cpu_block_id = historical_blocks[chunk_idx] - - # Second load: Load full data for this selected chunk - k_full, v_full = self._load_block_full( - cpu_block_id, layer_id - ) - - # Compute attention (non-causal, already processed) - o, lse = self._full_attention( - q.unsqueeze(0), k_full.unsqueeze(0), - v_full.unsqueeze(0), causal=False, return_lse=True - ) - - # Merge results - if o_acc is None: - o_acc, lse_acc = o.squeeze(0), lse - else: - from nanovllm.kvcache.chunked_attention import merge_attention_outputs - o_acc, lse_acc = merge_attention_outputs( - o_acc.unsqueeze(0), lse_acc, - o.unsqueeze(0), lse - ) - o_acc = o_acc.squeeze(0) - - return o_acc, lse_acc - - def _load_block_sample( - self, - cpu_block_id: int, - layer_id: int, - num_samples: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Load sample tokens from a CPU block.""" - offload_engine = get_context().kvcache_manager.offload_engine - - k_sample, v_sample = offload_engine.load_block_sample_from_cpu( - cpu_block_id, layer_id, num_samples - ) - return k_sample, v_sample - - def _load_block_full( - self, - cpu_block_id: int, - layer_id: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Load full tokens from a CPU block.""" - offload_engine = get_context().kvcache_manager.offload_engine - return offload_engine.load_block_full_from_cpu( - cpu_block_id, layer_id - ) - - def _compute_chunk_estimate( - self, - q: torch.Tensor, - k_sample: torch.Tensor, - ) -> float: - """ - Compute chunk-level importance estimate. - - Args: - q: [chunk_size, num_heads, head_dim] - k_sample: [num_samples, num_kv_heads, head_dim] - - Returns: - Aggregate importance score for this chunk - """ - # Expand K to match Q's head count (GQA support) - num_heads = q.shape[1] - num_kv_heads = k_sample.shape[1] - head_dim = q.shape[2] # Last dimension is head_dim - if num_heads != num_kv_heads: - repeat_factor = num_heads // num_kv_heads - k_sample = k_sample.repeat_interleave(repeat_factor, dim=1) - - # Compute attention scores: Q @ K.T with proper scaling - # q [chunk_size, H, D], k [samples, H, D] -> need to compute per-head attention - # Use scaled dot-product attention: (Q @ K.T) / sqrt(D) - scale = 1.0 / (head_dim ** 0.5) - - # Reshape to 2D: [chunk_size * H, D] @ [D, samples * H] then aggregate - chunk_size = q.shape[0] - num_samples = k_sample.shape[0] - - # Reshape for batched matmul: merge heads and seq dims - q_2d = q.reshape(chunk_size * num_heads, head_dim) # [chunk_size*H, D] - k_2d = k_sample.reshape(num_samples * num_heads, head_dim) # [samples*H, D] - - # Compute scaled Q @ K.T: [chunk_size*H, D] @ [D, samples*H] = [chunk_size*H, samples*H] - attn_scores_2d = torch.matmul(q_2d, k_2d.T) * scale - - # Use max absolute value as importance (captures both positive and negative attention) - importance = attn_scores_2d.abs().max().item() - - return importance - - def _full_attention( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - causal: bool = False, - return_lse: bool = False, - ) -> torch.Tensor: - """ - Compute full FlashAttention (fallback when sparse not applicable). - - Args: - q: [batch_size, seq_len, num_heads, head_dim] or [seq_len, num_heads, head_dim] - k, v: Same shape as q - causal: Apply causal mask - return_lse: Whether to return log-sum-exp - - Returns: - attention output [batch_size, seq_len, num_heads, head_dim] or [seq_len, num_heads, head_dim] - """ - from nanovllm.kvcache.chunked_attention import flash_attn_with_lse - - # Handle 3D input: add batch dimension - input_3d = q.dim() == 3 - if input_3d: - q = q.unsqueeze(0) # [seq_len, H, D] -> [1, seq_len, H, D] - k = k.unsqueeze(0) - v = v.unsqueeze(0) - - if return_lse: - o, lse = flash_attn_with_lse(q, k, v, softmax_scale=self.scale, causal=causal) - result = (o, lse) - else: - o, _ = flash_attn_with_lse(q, k, v, softmax_scale=self.scale, causal=causal) - result = o - - # Remove batch dimension if input was 3D - if input_3d: - if return_lse: - result = (result[0].squeeze(0), result[1]) - else: - result = result.squeeze(0) - - return result - - @property - def scale(self) -> float: - """Get softmax scale factor from Attention layer.""" - context = get_context() - # Get scale from current Attention layer in the model - if hasattr(context, 'current_attention') and context.current_attention is not None: - return context.current_attention.scale - # Fallback: try to get from model runner - if hasattr(context, 'model_runner') and context.model_runner is not None: - model_runner = context.model_runner - if hasattr(model_runner, 'model') and hasattr(model_runner.model, 'layers'): - # Get scale from first attention layer - first_layer = model_runner.model.layers[0] - if hasattr(first_layer, 'self_attn'): - return first_layer.self_attn.scaling - # Default: 1 / sqrt(128) for Qwen models - return 1.0 / 128.0 ** 0.5 - def reset(self) -> None: """Reset policy state.""" pass diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 3150a86..d403c73 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -210,22 +210,7 @@ class Attention(nn.Module): # Apply sparse policy if enabled sparse_policy = kvcache_manager.sparse_policy - # === XAttention BSA: Policy handles entire sparse prefill === - # Check if policy has sparse_prefill_attention method (XAttention BSA) - if (sparse_policy is not None and - hasattr(sparse_policy, 'sparse_prefill_attention') and - getattr(sparse_policy, 'supports_prefill', False)): - # Use policy's sparse_prefill_attention method - # Pass softmax_scale from attention layer - # IMPORTANT: Don't return early - we still need to do KV offload below! - o = sparse_policy.sparse_prefill_attention(q, k, v, self.layer_id, self.scale) - # Convert back to batched format for consistency with standard flow - o_acc = o.unsqueeze(0) # [seq_len, heads, dim] -> [1, seq_len, heads, dim] - lse_acc = None # sparse_prefill_attention returns final output, not intermediate LSE - # Skip standard flow processing since we already computed attention - cpu_block_table = None # Signal to skip historical chunk processing - - # === Standard sparse policy (Quest, etc.) === + # === All sparse policies use select_blocks interface === if cpu_block_table and sparse_policy is not None: num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1) policy_ctx = PolicyContext( @@ -262,8 +247,7 @@ class Attention(nn.Module): compute_stream = offload_engine.compute_stream if offload_engine is not None else None # Compute attention against current chunk's KV from prefill buffer (with causal mask) - # Skip this if XAttention BSA already computed full attention (o_acc is set, lse_acc is None) - needs_current_chunk_attention = (lse_acc is not None or o_acc is None) + needs_current_chunk_attention = True if needs_current_chunk_attention: if compute_stream is not None: @@ -294,24 +278,19 @@ class Attention(nn.Module): # Merge with accumulated (all on compute_stream for consistency) if o_acc is None: - # No accumulated attention (standard flow or XAttention BSA with no historical chunks) - final_o = current_o if needs_current_chunk_attention else o_acc + # No accumulated attention (no historical chunks processed) + final_o = current_o else: - # Has accumulated attention (XAttention BSA with historical chunks) - if needs_current_chunk_attention: - # Need to merge historical (from XAttention BSA) with current chunk - if compute_stream is not None: - with torch.cuda.stream(compute_stream): - torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}") - final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) - torch.cuda.nvtx.range_pop() - else: + # Has accumulated attention (historical chunks processed) + if compute_stream is not None: + with torch.cuda.stream(compute_stream): torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}") final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) torch.cuda.nvtx.range_pop() else: - # XAttention BSA already computed everything - final_o = o_acc + torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}") + final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) + torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop() # ChunkedPrefill diff --git a/progress.md b/progress.md deleted file mode 100644 index 11a1daa..0000000 --- a/progress.md +++ /dev/null @@ -1,76 +0,0 @@ -# Progress Log: Multi-Model Support - -## Session: 2026-01-10 - -### Initial Analysis Complete - -**Time**: Session start - -**Actions:** -1. Read `nanovllm/engine/model_runner.py` - 确认硬编码位置 (line 35) -2. Read `nanovllm/models/qwen3.py` - 理解 Qwen3 模型结构 -3. Read `nanovllm/utils/loader.py` - 理解权重加载机制 -4. Read `nanovllm/layers/rotary_embedding.py` - 发现 RoPE scaling 限制 -5. Read `/home/zijie/models/Llama-3.1-8B-Instruct/config.json` - 理解 Llama 配置 - -**Key Findings:** -- 模型加载在 `model_runner.py:35` 硬编码为 Qwen3 -- RoPE 目前不支持 scaling (`assert rope_scaling is None`) -- Llama 3.1 需要 "llama3" 类型的 RoPE scaling -- Llama 无 q_norm/k_norm,无 attention bias - -**Created:** -- `task_plan.md` - 6 阶段实施计划 -- `findings.md` - 技术分析和发现 - ---- - -### Phase Status - -| Phase | Status | Notes | -|-------|--------|-------| -| 1. Model Registry | **COMPLETED** | `registry.py`, `__init__.py` | -| 2. Llama3 RoPE | **COMPLETED** | `rotary_embedding.py` | -| 3. Llama Model | **COMPLETED** | `llama.py` | -| 4. ModelRunner | **COMPLETED** | Dynamic loading | -| 5. Qwen3 Register | **COMPLETED** | `@register_model` decorator | -| 6. Testing | **COMPLETED** | Both Llama & Qwen3 pass | - ---- - -## Test Results - -### Llama 3.1-8B-Instruct (32K needle, GPU 0, offload) -``` -Input: 32768 tokens -Expected: 7492 -Output: 7492 -Status: PASSED -Prefill: 1644 tok/s -``` - -### Qwen3-4B (8K needle, GPU 1, offload) - Regression Test -``` -Input: 8192 tokens -Expected: 7492 -Output: 7492 -Status: PASSED -Prefill: 3295 tok/s -``` - ---- - -## Files Modified This Session - -| File | Action | Description | -|------|--------|-------------| -| `nanovllm/models/registry.py` | created | Model registry with `@register_model` decorator | -| `nanovllm/models/__init__.py` | created | Export registry functions, import models | -| `nanovllm/models/llama.py` | created | Llama model implementation | -| `nanovllm/models/qwen3.py` | modified | Added `@register_model` decorator | -| `nanovllm/layers/rotary_embedding.py` | modified | Added Llama3 RoPE scaling | -| `nanovllm/engine/model_runner.py` | modified | Dynamic model loading via registry | -| `.claude/rules/gpu-testing.md` | created | GPU testing rules | -| `task_plan.md` | created | Implementation plan | -| `findings.md` | created | Technical findings | -| `progress.md` | created | Progress tracking | diff --git a/task_plan.md b/task_plan.md index 87626ef..23f2406 100644 --- a/task_plan.md +++ b/task_plan.md @@ -1,144 +1,353 @@ -# Task Plan: Multi-Model Support for nanovllm +# Task Plan: Sparse Policy 架构重构 v3 ## Goal -扩展 nanovllm 框架以支持多种模型(当前只支持 Qwen3),特别是添加 Llama-3.1-8B-Instruct 支持,并建立可扩展的模型添加范式。 -## Current State Analysis +将 chunked prefill 的 attention 计算逻辑完全从 `attention.py` 移到 `SparsePolicy` 内部。attention.py 只负责调用 policy,不包含任何计算逻辑。 -### 硬编码问题位置 -- `nanovllm/engine/model_runner.py:35`: 直接实例化 `Qwen3ForCausalLM(hf_config)` -- `nanovllm/engine/model_runner.py:9`: 硬编码导入 `from nanovllm.models.qwen3 import Qwen3ForCausalLM` +## 核心设计原则(强制要求) -### Qwen3 vs Llama 3.1 架构差异 +1. **Policy 内部完成所有计算**:包括 attention 计算和结果合并 +2. **select_blocks 传入 offload_engine**:policy 通过 offload_engine 加载 blocks +3. **强制实现计算函数**:所有 policy 必须实现 `compute_block_attention` 和 `merge_attention_outputs` +4. **chunked_prefill 强制 policy 存在**:没有 policy 则报错 +5. **外部默认 FULL policy**:model_runner.py 默认创建 FullPolicy +6. **attention.py 零计算逻辑**:_chunked_prefill_attention 只调用 policy,不直接调用 flashattn 或 merge -| Feature | Qwen3 | Llama 3.1 | -|---------|-------|-----------| -| Config Class | Qwen3Config | LlamaConfig | -| attention_bias | True (可配置) | False | -| q_norm/k_norm | 有 (when bias=False) | 无 | -| mlp_bias | N/A | False | -| RoPE Scaling | None (目前) | llama3 类型 | -| RoPE theta | 1000000 | 500000 | -| hidden_act | silu | silu | -| tie_word_embeddings | True | False | +## 目标架构 -### 关键限制 -- `rotary_embedding.py:59`: `assert rope_scaling is None` - 不支持 RoPE scaling +``` +model_runner.py: + 默认创建 FullPolicy(如果没有指定 sparse policy) ---- +attention.py (_chunked_prefill_attention): + 检查 sparse_policy 是否存在 + ↓ + 调用 sparse_policy.compute_prefill_attention(q, k, v, ...) + ↓ + 返回最终输出(不包含任何计算逻辑) + +SparsePolicy.compute_prefill_attention(): + 1. select_blocks(blocks, offload_engine, ctx) → 筛选 blocks + 2. 加载 blocks(通过 offload_engine) + 3. 遍历 blocks: + - 调用 self.compute_block_attention(q, k, v, ...) + - 调用 self.merge_attention_outputs(...) + 4. 计算当前 chunk attention + 5. 合并最终结果 + 6. 返回 final_output +``` + +## 关键设计决策 + +| 决策 | 说明 | +|------|------| +| **决策 1** | `compute_block_attention` 是抽象方法,所有 policy 必须实现 | +| **决策 2** | `merge_attention_outputs` 是抽象方法,所有 policy 必须实现 | +| **决策 3** | `compute_prefill_attention` 是抽象方法,定义完整的 prefill 流程 | +| **决策 4** | `select_blocks` 接收 `offload_engine` 参数(为未来准备) | +| **决策 5** | chunked_prefill 检查 policy 是否存在,不存在则抛出错误 | +| **决策 6** | model_runner 默认创建 FullPolicy 作为兜底 | +| **决策 7** | attention.py 的 _chunked_prefill_attention 不包含任何 flashattn 或 merge 调用 | ## Phases -### Phase 1: Create Model Registry Pattern [pending] -**Files to modify:** -- `nanovllm/models/__init__.py` (new) -- `nanovllm/models/registry.py` (new) +- [ ] Phase 1: 分析当前架构,理解所有计算逻辑的位置 +- [ ] Phase 2: 在 SparsePolicy 基类中添加三个抽象方法 +- [ ] Phase 3: 修改 FullPolicy,实现三个抽象方法 +- [ ] Phase 4: 修改 QuestPolicy,实现三个抽象方法 +- [ ] Phase 5: 修改 XAttentionBSAPolicy,实现三个抽象方法 +- [ ] Phase 6: 修改 model_runner.py,默认创建 FullPolicy +- [ ] Phase 7: 修改 attention.py,移除所有计算逻辑,只调用 policy +- [ ] Phase 8: 测试验证 -**Tasks:** -1. 创建模型注册表机制 -2. 定义模型注册装饰器 `@register_model` -3. 实现 `get_model_class(hf_config)` 函数,根据 `architectures` 字段自动选择模型 +## Phase 1: 分析当前架构,理解所有计算逻辑的位置 + +### 当前 attention.py 中包含的计算逻辑 + +1. `_ring_buffer_pipeline_load` 方法: + - 调用 `offload_engine.load_to_slot_layer()` + - 调用 `offload_engine.wait_slot_layer()` + - 调用 `offload_engine.get_kv_for_slot()` + - 调用 `flash_attn_with_lse()` ← **直接调用** + - 调用 `merge_attention_outputs()` ← **直接调用** + +2. `_sync_load_previous_chunks` 方法: + - 同上,直接调用 flashattn 和 merge + +3. `_chunked_prefill_attention` 方法: + - 调用 `_ring_buffer_pipeline_load` 或 `_sync_load_previous_chunks` + - 调用 `flash_attn_with_lse()` 计算当前 chunk + - 调用 `merge_attention_outputs()` 合并结果 + +### 需要移动的计算逻辑 + +所有 `flash_attn_with_lse` 和 `merge_attention_outputs` 调用都应该在 SparsePolicy 内部。 + +## Phase 2: 在 SparsePolicy 基类中添加三个抽象方法 + +### 2.1 compute_block_attention -**Design:** ```python -MODEL_REGISTRY: dict[str, type] = {} +@abstractmethod +def compute_block_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_id: int, + softmax_scale: float, + causal: bool, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + 计算单个 block 的 attention。 -def register_model(*architectures): - """Decorator to register a model class for given architecture names.""" - def decorator(cls): - for arch in architectures: - MODEL_REGISTRY[arch] = cls - return cls - return decorator + Args: + q: [1, seq_len, num_heads, head_dim] 或 [seq_len, num_heads, head_dim] + k, v: 同上 + layer_id: 层索引 + softmax_scale: softmax 缩放因子 + causal: 是否应用因果掩码 -def get_model_class(hf_config) -> type: - """Get model class based on HF config architectures.""" - for arch in hf_config.architectures: - if arch in MODEL_REGISTRY: - return MODEL_REGISTRY[arch] - raise ValueError(f"Unsupported architecture: {hf_config.architectures}") + Returns: + (o, lse) - attention 输出和 LSE + """ + pass ``` -### Phase 2: Add Llama3 RoPE Scaling Support [pending] -**Files to modify:** -- `nanovllm/layers/rotary_embedding.py` +### 2.2 merge_attention_outputs -**Tasks:** -1. 实现 `Llama3RotaryEmbedding` 类,支持 llama3 rope_type -2. 修改 `get_rope()` 函数,根据 rope_scaling 类型选择实现 -3. 保持向后兼容(rope_scaling=None 使用原实现) - -**Llama3 RoPE Scaling Formula:** ```python -# From transformers: -# low_freq_factor, high_freq_factor, original_max_position_embeddings -# Adjust frequencies based on wavelength thresholds +@abstractmethod +def merge_attention_outputs( + self, + o_acc: torch.Tensor, + lse_acc: Optional[torch.Tensor], + o_new: torch.Tensor, + lse_new: Optional[torch.Tensor], +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + 合并两个 attention 输出。 + + Args: + o_acc: 累积的 attention 输出 [1, seq_len, num_heads, head_dim] + lse_acc: 累积的 LSE + o_new: 新的 attention 输出 + lse_new: 新的 LSE + + Returns: + (merged_o, merged_lse) + """ + pass ``` -### Phase 3: Implement Llama Model [pending] -**Files to create:** -- `nanovllm/models/llama.py` +### 2.3 compute_chunked_attention -**Tasks:** -1. 创建 `LlamaAttention` 类(无 q_norm/k_norm,无 QKV bias) -2. 创建 `LlamaMLP` 类(与 Qwen3MLP 类似,无 bias) -3. 创建 `LlamaDecoderLayer` 类 -4. 创建 `LlamaModel` 和 `LlamaForCausalLM` 类 -5. 添加 `packed_modules_mapping` 以支持权重加载 -6. 使用 `@register_model("LlamaForCausalLM")` 注册 +```python +@abstractmethod +def compute_chunked_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_id: int, + softmax_scale: float, + offload_engine: OffloadEngine, + current_chunk_idx: int, + seq: ChunkedSequence, + num_tokens: int, +) -> torch.Tensor: + """ + 计算 chunked prefill attention(完整流程)。 -### Phase 4: Modify ModelRunner for Dynamic Loading [pending] -**Files to modify:** -- `nanovllm/engine/model_runner.py` + 这是 policy 的主入口,定义完整的 prefill 计算流程: + 1. 获取历史 blocks + 2. 筛选 blocks(调用 select_blocks) + 3. 加载和计算历史 blocks + 4. 计算当前 chunk attention + 5. 合并所有结果 -**Tasks:** -1. 移除硬编码 `from nanovllm.models.qwen3 import Qwen3ForCausalLM` -2. 导入 `from nanovllm.models import get_model_class` -3. 替换 `self.model = Qwen3ForCausalLM(hf_config)` 为: - ```python - model_class = get_model_class(hf_config) - self.model = model_class(hf_config) - ``` + Args: + q, k, v: 当前 chunk 的 QKV + layer_id: 层索引 + softmax_scale: softmax 缩放因子 + offload_engine: offload engine + current_chunk_idx: 当前 chunk 索引 + seq: chunked 序列 + num_tokens: 当前 chunk 的 token 数 -### Phase 5: Register Qwen3 Model [pending] -**Files to modify:** -- `nanovllm/models/qwen3.py` + Returns: + [seq_len, num_heads, head_dim] 最终 attention输出 + """ + pass +``` -**Tasks:** -1. 导入 `from nanovllm.models.registry import register_model` -2. 添加 `@register_model("Qwen3ForCausalLM", "Qwen2ForCausalLM")` 装饰器 +### 2.4 修改 select_blocks 接口 -### Phase 6: Test with Llama-3.1-8B-Instruct [pending] -**Files:** -- `tests/test_needle.py` (existing, use for validation) +```python +def select_blocks( + self, + available_blocks: List[int], + offload_engine: OffloadEngine, + ctx: PolicyContext, +) -> List[int]: + """ + 选择要加载的 blocks。 -**Tasks:** -1. 运行 needle 测试: `python tests/test_needle.py --model ~/models/Llama-3.1-8B-Instruct` -2. 验证模型加载正确 -3. 验证推理输出正确 + Args: + available_blocks: 所有可用的 block IDs + offload_engine: offload engine(为未来准备,当前可能不使用) + ctx: policy context ---- + Returns: + 选择的 block IDs + """ + pass +``` + +## Phase 3: 修改 FullPolicy,实现三个抽象方法 + +### 3.1 FullPolicy.compute_block_attention + +直接调用 `flash_attn_with_lse`,处理 3D 输入。 + +### 3.2 FullPolicy.merge_attention_outputs + +调用 `chunked_attention.merge_attention_outputs`。 + +### 3.3 FullPolicy.compute_prefill_attention + +实现完整的 prefill 流程: +1. 获取 `cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)` +2. 调用 `select_blocks(cpu_block_table, offload_engine, ctx)` +3. 遍历 blocks: + - `offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)` + - `offload_engine.wait_slot_layer(slot)` + - `k, v = offload_engine.get_kv_for_slot(slot)` + - 调用 `self.compute_block_attention(q, k, v, layer_id, scale, causal=False)` + - 调用 `self.merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)` +4. 计算当前 chunk attention +5. 合并最终结果 + +### 需要移动的代码 + +从 `attention.py` 的 `_ring_buffer_pipeline_load` 和 `_sync_load_previous_chunks` 移动逻辑: +- slot 遍历逻辑 +- offload_engine 调用 +- 计算和合并逻辑 + +从 `attention.py` 的 `_chunked_prefill_attention` 移动逻辑: +- 当前 chunk 的 attention 计算 +- 最终合并逻辑 + +## Phase 4: 修改 QuestPolicy + +QuestPolicy 实现与 FullPolicy 类似,区别在于: +- `select_blocks` 返回 Top-K blocks +- 其他计算逻辑相同 + +## Phase 5: 修改 XAttentionBSAPolicy + +当前 XAttentionBSAPolicy 只返回所有 blocks,修改后: +- `select_blocks` 当前返回所有 blocks +- `compute_block_attention` 与 FullPolicy 相同 +- `merge_attention_outputs` 与 FullPolicy 相同 +- `compute_prefill_attention` 与 FullPolicy 相同 + +未来可以实现稀疏计算。 + +## Phase 6: 修改 model_runner.py,默认创建 FullPolicy + +### 6.1 当前创建 sparse policy 的逻辑 + +```python +# 当前:只有指定 sparse_policy_type 时才创建 +if sparse_policy_type is not None: + sparse_policy = create_sparse_policy(sparse_policy_type, **kwargs) +``` + +### 6.2 修改后 + +```python +# 默认创建 FullPolicy +if sparse_policy_type is None: + sparse_policy_type = SparsePolicyType.FULL + +sparse_policy = create_sparse_policy(sparse_policy_type, **kwargs) +``` + +### 6.3 位置 + +`model_runner.py` 中的 `allocate_kv_cache` 方法。 + +## Phase 7: 修改 attention.py,移除所有计算逻辑 + +### 7.1 _chunked_prefill_attention 简化 + +**当前(伪代码)**: +```python +# 获取 cpu_block_table +# 调用 select_blocks +# 调用 _ring_buffer_pipeline_load(包含计算逻辑) +# 计算当前 chunk(flash_attn) +# 合并结果(merge) +``` + +**修改后**: +```python +sparse_policy = kvcache_manager.sparse_policy +if sparse_policy is None: + raise RuntimeError("sparse_policy is required for chunked prefill") + +o = sparse_policy.compute_prefill_attention( + q, k, v, self.layer_id, self.scale, + offload_engine, current_chunk_idx, seq, num_tokens +) + +# 直接返回,不需要合并(policy 内部已完成所有计算) +return o +``` + +### 7.2 删除的方法 + +删除以下方法(逻辑移到 policy 中): +- `_ring_buffer_pipeline_load` - 逻辑移到 FullPolicy.compute_prefill_attention +- `_sync_load_previous_chunks` - 逻辑移到 FullPolicy.compute_prefill_attention + +### 7.3 保留的方法 + +- `_decode_with_layer_pipeline` - decode 逻辑保持不变 +- `_decode_ring_buffer_pipeline` - decode 逻辑保持不变 + +## Phase 8: 测试验证 + +- [ ] 运行 `test_needle.py --enable-offload` (FULL policy) +- [ ] 验证输出正确 (needle value: 7492) +- [ ] 验证性能无明显下降 + +## 关键文件清单 + +| 文件 | 修改内容 | +|------|----------| +| `nanovllm/kvcache/sparse/policy.py` | 添加三个抽象方法,修改 select_blocks 签名 | +| `nanovllm/kvcache/sparse/full_policy.py` | 实现三个抽象方法,移动计算逻辑 | +| `nanovllm/kvcache/sparse/quest.py` | 实现三个抽象方法 | +| `nanovllm/kvcache/sparse/xattn_bsa.py` | 实现三个抽象方法 | +| `nanovllm/engine/model_runner.py` | 默认创建 FullPolicy | +| `nanovllm/layers/attention.py` | 简化 _chunked_prefill_attention,删除计算方法 | + +## Decisions Made + +- **决策 1**: 三个方法都是抽象方法,强制所有 policy 实现 +- **决策 2**: compute_prefill_attention 定义完整的 prefill 流程,是 policy 的主入口 +- **决策 3**: attention.py 只调用 policy.compute_prefill_attention,零计算逻辑 +- **决策 4**: chunked_prefill 检查 policy 是否存在,不存在则抛出错误 +- **决策 5**: model_runner 默认创建 FullPolicy 作为兜底 +- **决策 6**: _ring_buffer_pipeline_load 和 _sync_load_previous_chunks 删除,逻辑移到 policy ## Errors Encountered -| Error | Attempt | Resolution | -|-------|---------|------------| -| (none yet) | | | ---- +(待记录) -## Success Criteria -- [x] 分析完成:理解当前架构和需要的改动 -- [ ] Phase 1: 模型注册表实现 -- [ ] Phase 2: Llama3 RoPE scaling 支持 -- [ ] Phase 3: Llama 模型实现 -- [ ] Phase 4: ModelRunner 动态加载 -- [ ] Phase 5: Qwen3 模型注册 -- [ ] Phase 6: Llama needle 测试通过 +## Status ---- - -## Notes -- 保持现有 Qwen3 功能不变 -- 遵循现有代码风格 -- 复用现有 layers 组件(Linear, RMSNorm, Embedding 等) -- 只添加必要的代码,不过度工程化 +**Currently in Phase 1** - 分析当前架构,理解所有计算逻辑的位置