♻️ refactor: create ops module and move chunked_attention
- Create nanovllm/ops/ module for low-level attention operators
- Move chunked_attention.py from kvcache/ to ops/
- Update imports in full_policy.py (3 locations)
- Fix: remove dead code in OffloadEngine.reset() referencing non-existent layer_k/v_buffer_a/b attributes

Verified with needle test (32K offload): PASSED

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -255,7 +255,6 @@ class OffloadEngine:
|
|||||||
Clears:
|
Clears:
|
||||||
- GPU ring buffer slots (k_cache_gpu, v_cache_gpu)
|
- GPU ring buffer slots (k_cache_gpu, v_cache_gpu)
|
||||||
- Per-layer decode buffers (decode_k_buffer, decode_v_buffer)
|
- Per-layer decode buffers (decode_k_buffer, decode_v_buffer)
|
||||||
- Cross-layer pipeline buffers (layer_k/v_buffer_a/b)
|
|
||||||
- Per-layer prefill buffers (prefill_k/v_buffer)
|
- Per-layer prefill buffers (prefill_k/v_buffer)
|
||||||
- All pending async transfer events
|
- All pending async transfer events
|
||||||
"""
|
"""
|
||||||
@@ -267,12 +266,6 @@ class OffloadEngine:
|
|||||||
self.decode_k_buffer.zero_()
|
self.decode_k_buffer.zero_()
|
||||||
self.decode_v_buffer.zero_()
|
self.decode_v_buffer.zero_()
|
||||||
|
|
||||||
# Clear cross-layer pipeline buffers
|
|
||||||
self.layer_k_buffer_a.zero_()
|
|
||||||
self.layer_v_buffer_a.zero_()
|
|
||||||
self.layer_k_buffer_b.zero_()
|
|
||||||
self.layer_v_buffer_b.zero_()
|
|
||||||
|
|
||||||
# Clear per-layer prefill buffers
|
# Clear per-layer prefill buffers
|
||||||
self.prefill_k_buffer.zero_()
|
self.prefill_k_buffer.zero_()
|
||||||
self.prefill_v_buffer.zero_()
|
self.prefill_v_buffer.zero_()
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
Returns:
|
Returns:
|
||||||
Attention output [seq_len, num_heads, head_dim]
|
Attention output [seq_len, num_heads, head_dim]
|
||||||
"""
|
"""
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
|
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
|
||||||
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
|
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
|
||||||
@@ -222,7 +222,7 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
Returns:
|
Returns:
|
||||||
Attention output [batch_size, 1, num_heads, head_dim]
|
Attention output [batch_size, 1, num_heads, head_dim]
|
||||||
"""
|
"""
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
||||||
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
||||||
@@ -319,7 +319,7 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
Loads one block at a time, computes attention, and merges results.
|
Loads one block at a time, computes attention, and merges results.
|
||||||
Uses load_to_slot_layer / wait_slot_layer / get_kv_for_slot methods.
|
Uses load_to_slot_layer / wait_slot_layer / get_kv_for_slot methods.
|
||||||
"""
|
"""
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
num_blocks = len(cpu_block_table)
|
num_blocks = len(cpu_block_table)
|
||||||
if num_blocks == 0:
|
if num_blocks == 0:
|
||||||
|
|||||||
19
nanovllm/ops/__init__.py
Normal file
19
nanovllm/ops/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""
|
||||||
|
Operators module for nano-vLLM.
|
||||||
|
|
||||||
|
This module contains low-level attention operators and kernels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from nanovllm.ops.chunked_attention import (
|
||||||
|
flash_attn_with_lse,
|
||||||
|
merge_attention_outputs,
|
||||||
|
chunked_attention_varlen,
|
||||||
|
ChunkedPrefillState,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"flash_attn_with_lse",
|
||||||
|
"merge_attention_outputs",
|
||||||
|
"chunked_attention_varlen",
|
||||||
|
"ChunkedPrefillState",
|
||||||
|
]
|
||||||
Reference in New Issue
Block a user