Merge branch 'tzj/minference' of ssh://git.zijie-tian.site:2222/zijie-tian/nano-vllm into tzj/minference

This commit is contained in:
Zijie Tian
2026-01-20 02:16:39 +08:00
21 changed files with 1743 additions and 698 deletions

View File

@@ -64,11 +64,24 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
     # Create sparse policy from config enum
     # Quest is decode-only: prefill returns all blocks (query=None), decode does Top-K
     sparse_policy_type = getattr(config, 'sparse_policy', SparsePolicyType.FULL)
-    sparse_policy = create_sparse_policy(
-        sparse_policy_type,
-        topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
-        threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
-    )
+    # Build policy kwargs based on policy type
+    policy_kwargs = {}
+    if sparse_policy_type == SparsePolicyType.QUEST:
+        policy_kwargs = {
+            'topk_blocks': getattr(config, 'sparse_topk_blocks', 8),
+            'threshold_blocks': getattr(config, 'sparse_threshold_blocks', 4),
+        }
+    elif sparse_policy_type == SparsePolicyType.XATTN_BSA:
+        policy_kwargs = {
+            'block_size': getattr(config, 'sparse_block_size', 128),
+            'samples_per_chunk': getattr(config, 'sparse_samples_per_chunk', 128),
+            'threshold': getattr(config, 'sparse_threshold', 0.9),
+            'use_triton': getattr(config, 'sparse_use_triton', True),
+            'stride': getattr(config, 'sparse_stride', 8),
+        }
+    sparse_policy = create_sparse_policy(sparse_policy_type, **policy_kwargs)
     return HybridKVCacheManager(
         num_gpu_slots=num_gpu_blocks,
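
For reference, a configuration that exercises the new XATTN_BSA branch above could look like the sketch below. The sparse_* field names are exactly those read via getattr in this hunk; the model path and any other Config arguments are illustrative assumptions.

    from nanovllm.config import Config, SparsePolicyType

    config = Config(
        "Qwen/Qwen2.5-7B-Instruct",      # hypothetical model path
        sparse_policy=SparsePolicyType.XATTN_BSA,
        sparse_block_size=128,           # tokens per KV block
        sparse_samples_per_chunk=128,    # sampled tokens per historical chunk
        sparse_threshold=0.9,            # cumulative attention mass to keep
    )
    manager = create_kvcache_manager(config)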

View File

@@ -905,3 +905,60 @@ class OffloadEngine:
     def wait_prefill_offload(self, layer_id: int) -> None:
         """Wait for a specific layer's prefill offload to complete."""
         self.prefill_offload_events[layer_id].synchronize()
+
+    # ========== XAttention BSA Helper Methods ==========
+
+    def load_block_sample_from_cpu(
+        self,
+        cpu_block_id: int,
+        layer_id: int,
+        num_samples: int,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Load sample tokens from a CPU block for XAttention BSA estimation.
+
+        This is used in the estimate phase of XAttention BSA to load a small
+        sample of tokens from each historical chunk for importance estimation.
+
+        Args:
+            cpu_block_id: Source CPU block ID
+            layer_id: Layer index
+            num_samples: Number of tokens to sample
+
+        Returns:
+            (k_sample, v_sample) tensors, shape: [num_samples, kv_heads, head_dim]
+        """
+        # Sample from the beginning of the block
+        k_sample = self.k_cache_cpu[
+            layer_id, cpu_block_id, :num_samples
+        ].clone().cuda()
+        v_sample = self.v_cache_cpu[
+            layer_id, cpu_block_id, :num_samples
+        ].clone().cuda()
+        return k_sample, v_sample
+
+    def load_block_full_from_cpu(
+        self,
+        cpu_block_id: int,
+        layer_id: int,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Load full tokens from a CPU block for XAttention BSA computation.
+
+        This is used in the compute phase of XAttention BSA to load the full
+        data for selected important chunks.
+
+        Args:
+            cpu_block_id: Source CPU block ID
+            layer_id: Layer index
+
+        Returns:
+            (k_full, v_full) tensors, shape: [block_size, kv_heads, head_dim]
+        """
+        k_full = self.k_cache_cpu[
+            layer_id, cpu_block_id
+        ].clone().cuda()
+        v_full = self.v_cache_cpu[
+            layer_id, cpu_block_id
+        ].clone().cuda()
+        return k_full, v_full
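
Taken together, these two helpers support a two-phase access pattern: cheap sampled reads for estimation, then full reads only for the chunks that survive selection. A minimal usage sketch (engine, layer_id, and cpu_blocks are assumed to be in scope; select_chunks stands in for the not-yet-implemented BSA selection):

    # Phase 1 (estimate): pull a small sample from every historical block
    samples = [
        engine.load_block_sample_from_cpu(blk, layer_id, num_samples=128)
        for blk in cpu_blocks
    ]

    # Phase 2 (compute): load full KV only for the selected blocks
    keep = select_chunks(samples)  # hypothetical selection over sampled keys
    full_kv = [
        engine.load_block_full_from_cpu(cpu_blocks[i], layer_id)
        for i in keep
    ]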

View File

@@ -23,6 +23,7 @@ from nanovllm.config import SparsePolicyType
 from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
 from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
 from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
+from nanovllm.kvcache.sparse.xattn_bsa import XAttentionBSAPolicy


 def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
@@ -55,6 +56,13 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
         )
         return QuestPolicy(config)
+    elif policy_type == SparsePolicyType.XATTN_BSA:
+        return XAttentionBSAPolicy(
+            block_size=kwargs.get("block_size", 128),
+            samples_per_chunk=kwargs.get("samples_per_chunk", 128),
+            threshold=kwargs.get("threshold", 0.9),
+        )
     else:
         raise ValueError(f"Unknown policy type: {policy_type}")
@@ -67,5 +75,6 @@ __all__ = [
     "QuestPolicy",
     "QuestConfig",
     "BlockMetadataManager",
+    "XAttentionBSAPolicy",
     "create_sparse_policy",
 ]
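
Assuming the factory is exported from the package __init__ alongside the __all__ list above, constructing the new policy directly reduces to the following (defaults shown explicitly, mirroring the kwargs.get fallbacks):

    from nanovllm.config import SparsePolicyType
    from nanovllm.kvcache.sparse import create_sparse_policy

    policy = create_sparse_policy(
        SparsePolicyType.XATTN_BSA,
        block_size=128,
        samples_per_chunk=128,
        threshold=0.9,
    )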

View File

@@ -5,8 +5,19 @@ This serves as a baseline and default policy when sparse
 attention is not needed.
 """

-from typing import List
+import logging
+import torch
+from typing import List, Optional, TYPE_CHECKING

 from .policy import SparsePolicy, PolicyContext
 from nanovllm.utils.context import get_context

+if TYPE_CHECKING:
+    from nanovllm.kvcache.offload_engine import OffloadEngine
+    from nanovllm.kvcache.manager import KVCacheManager
+    from nanovllm.engine.sequence import Sequence
+
+logger = logging.getLogger(__name__)


 class FullAttentionPolicy(SparsePolicy):
@@ -29,10 +40,157 @@ class FullAttentionPolicy(SparsePolicy):
     def select_blocks(
         self,
         available_blocks: List[int],
+        offload_engine: "OffloadEngine",
         ctx: PolicyContext,
     ) -> List[int]:
         """Return all blocks - no sparsity."""
         return available_blocks

+    def compute_chunked_attention(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer_id: int,
+        softmax_scale: float,
+        offload_engine: "OffloadEngine",
+        kvcache_manager: "KVCacheManager",
+        current_chunk_idx: int,
+        seq: "Sequence",
+        num_tokens: int,
+    ) -> torch.Tensor:
+        """
+        Compute full attention for chunked prefill.
+
+        This method handles the complete chunked prefill flow:
+        1. Get historical blocks
+        2. Select blocks via select_blocks
+        3. Load and compute attention to historical chunks
+        4. Compute attention to current chunk
+        5. Merge all results
+
+        Args:
+            q: Query tensor [seq_len, num_heads, head_dim]
+            k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
+            v: Value tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
+            layer_id: Current layer index
+            softmax_scale: Softmax scaling factor
+            offload_engine: OffloadEngine for loading blocks
+            kvcache_manager: KVCacheManager for block management
+            current_chunk_idx: Current chunk index
+            seq: Sequence object
+            num_tokens: Number of tokens in current chunk
+
+        Returns:
+            Attention output [seq_len, num_heads, head_dim]
+        """
+        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+
+        logger.debug(f"FullPolicy.compute_chunked_attention called, "
+                     f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
+
+        q_batched = q.unsqueeze(0)  # [1, seq_len, num_heads, head_dim]
+        o_acc = None
+        lse_acc = None
+        compute_stream = offload_engine.compute_stream
+
+        # Step 1: Get historical blocks
+        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
+
+        # Step 2: Apply select_blocks to filter blocks
+        if cpu_block_table:
+            num_chunks = current_chunk_idx + 1
+            policy_ctx = PolicyContext(
+                query_chunk_idx=current_chunk_idx,
+                num_query_chunks=num_chunks,
+                layer_id=layer_id,
+                query=None,  # Full attention needs no query for selection
+                is_prefill=True,
+                block_size=kvcache_manager.block_size,
+                total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
+            )
+            cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
+            logger.debug(f"select_blocks: output={len(cpu_block_table)} blocks")
+
+        # Step 3: Load historical blocks and accumulate attention over them
+        if cpu_block_table:
+            load_slots = list(range(offload_engine.num_ring_slots))
+            num_blocks = len(cpu_block_table)
+
+            if len(load_slots) == 1:
+                # Only 1 slot - use synchronous mode
+                slot = load_slots[0]
+                for block_idx in range(num_blocks):
+                    cpu_block_id = cpu_block_table[block_idx]
+                    offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
+                    offload_engine.wait_slot_layer(slot)
+                    with torch.cuda.stream(compute_stream):
+                        prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
+                        prev_o, prev_lse = flash_attn_with_lse(
+                            q_batched, prev_k, prev_v,
+                            softmax_scale=softmax_scale,
+                            causal=False,
+                        )
+                        if o_acc is None:
+                            o_acc, lse_acc = prev_o, prev_lse
+                        else:
+                            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
+                        offload_engine.record_slot_compute_done(slot)
+            else:
+                # Multiple slots - use pipeline
+                num_slots = len(load_slots)
+                num_preload = min(num_slots, num_blocks)
+                for i in range(num_preload):
+                    offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
+
+                for block_idx in range(num_blocks):
+                    current_slot = load_slots[block_idx % num_slots]
+                    cpu_block_id = cpu_block_table[block_idx]
+                    offload_engine.wait_slot_layer(current_slot)
+                    with torch.cuda.stream(compute_stream):
+                        prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
+                        prev_o, prev_lse = flash_attn_with_lse(
+                            q_batched, prev_k, prev_v,
+                            softmax_scale=softmax_scale,
+                            causal=False,
+                        )
+                        offload_engine.record_slot_compute_done(current_slot)
+                        if o_acc is None:
+                            o_acc, lse_acc = prev_o, prev_lse
+                        else:
+                            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
+
+                    # Issue next transfer
+                    next_block_idx = block_idx + num_slots
+                    if next_block_idx < num_blocks:
+                        next_slot = load_slots[next_block_idx % num_slots]
+                        next_cpu_block_id = cpu_block_table[next_block_idx]
+                        offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
+
+        # Step 4: Compute attention to current chunk (causal mask)
+        with torch.cuda.stream(compute_stream):
+            k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
+            current_o, current_lse = flash_attn_with_lse(
+                q_batched, k_curr, v_curr,
+                softmax_scale=softmax_scale,
+                causal=True,
+            )
+
+        # Step 5: Merge historical and current attention
+        with torch.cuda.stream(compute_stream):
+            if o_acc is None:
+                final_o = current_o
+            else:
+                final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
+
+        # Sync default stream with compute_stream before returning
+        torch.cuda.default_stream().wait_stream(compute_stream)
+
+        # Remove batch dimension: [1, seq_len, num_heads, head_dim] -> [seq_len, num_heads, head_dim]
+        return final_o.squeeze(0)
+
     def __repr__(self) -> str:
         return "FullAttentionPolicy()"

View File

@@ -7,12 +7,17 @@ from CPU for each query chunk during chunked attention computation.
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List, Optional, Any
+from typing import List, Optional, Any, TYPE_CHECKING

 import torch

 # Import SparsePolicyType from config to avoid circular imports
 from nanovllm.config import SparsePolicyType

+if TYPE_CHECKING:
+    from nanovllm.kvcache.offload_engine import OffloadEngine
+    from nanovllm.kvcache.manager import KVCacheManager
+    from nanovllm.engine.sequence import Sequence


 @dataclass
 class PolicyContext:
@@ -35,8 +40,8 @@ class PolicyContext:
     query: Optional[torch.Tensor]
     """
     Query tensor for current chunk.
-    Shape: [1, num_heads, head_dim] for decode, [1, seq_len, num_heads, head_dim] for prefill.
-    May be None if not available (e.g., some prefill scenarios).
+    Shape: [1, num_heads, head_dim] for decode, [seq_len, num_heads, head_dim] for prefill.
+    Available for both prefill and decode phases.
     """

     is_prefill: bool
@@ -107,6 +112,7 @@ class SparsePolicy(ABC):
     def select_blocks(
         self,
         available_blocks: List[int],
+        offload_engine: "OffloadEngine",
         ctx: PolicyContext,
     ) -> List[int]:
         """
@@ -120,6 +126,8 @@ class SparsePolicy(ABC):
             available_blocks: List of CPU block IDs that contain KV cache
                               from previous chunks. These are ordered by
                               their position in the sequence.
+            offload_engine: OffloadEngine for loading KV (some policies need
+                            to load KV to make selection decisions).
             ctx: PolicyContext with information about the current query
                  chunk, layer, phase (prefill/decode), etc.
@@ -183,5 +191,47 @@ class SparsePolicy(ABC):
"""
pass
@abstractmethod
def compute_chunked_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager",
current_chunk_idx: int,
seq: "Sequence",
num_tokens: int,
) -> torch.Tensor:
"""
Compute chunked prefill attention (complete flow).
This is the main entry point for prefill attention computation.
It defines the complete prefill flow:
1. Get historical blocks
2. Select blocks (call select_blocks)
3. Load and compute historical blocks via offload_engine
4. Get current chunk KV from offload_engine, compute attention
5. Merge all results
Args:
q: [seq_len, num_heads, head_dim] query for current chunk
k: [seq_len, num_kv_heads, head_dim] key for current chunk (in prefill buffer)
v: [seq_len, num_kv_heads, head_dim] value for current chunk (in prefill buffer)
layer_id: transformer layer index
softmax_scale: softmax scaling factor
offload_engine: OffloadEngine for loading blocks
kvcache_manager: KVCacheManager for block management
current_chunk_idx: current chunk index
seq: Sequence object
num_tokens: number of tokens in current chunk
Returns:
[seq_len, num_heads, head_dim] final attention output
"""
pass
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"

View File

@@ -0,0 +1,70 @@
"""
XAttention Block Sparse Attention (BSA) Policy for nano-vllm.
This module implements XAttention-inspired block sparse attention for chunked prefill.
Current implementation loads all historical blocks (FULL strategy).
Sparse selection to be implemented in next phase.
"""
import torch
from typing import List, Optional, Tuple
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.utils.context import get_context
class XAttentionBSAPolicy(SparsePolicy):
"""
XAttention Block Sparse Attention policy for chunked prefill.
This policy uses block-level estimation to determine which KV blocks
are important for the current chunk's queries, enabling sparse computation.
Note: Current implementation loads all historical chunks (FULL strategy).
Sparse selection to be implemented in next phase.
"""
supports_prefill = False # Uses standard select_blocks interface
supports_decode = False # BSA is prefill-only
requires_block_selection = False # Selection happens at chunk level, not block level
def __init__(
self,
block_size: int = 128,
samples_per_chunk: int = 128,
threshold: float = 0.9,
):
"""
Initialize XAttention BSA policy.
Args:
block_size: Number of tokens per block (default: 128)
samples_per_chunk: Number of tokens to sample from each historical chunk for estimation
threshold: Cumulative attention threshold for chunk selection (0-1)
"""
self.block_size = block_size
self.samples_per_chunk = samples_per_chunk
self.threshold = threshold
def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
"""
Select blocks to load from CPU.
Current implementation returns all blocks (FULL strategy).
Sparse selection to be implemented in next phase.
Args:
available_blocks: List of all available CPU block IDs
ctx: Policy context with query info, chunk index, etc.
Returns:
List of selected block IDs to load
"""
# Current: Return all blocks (FULL strategy)
# TODO: Implement sparse selection based on query attention estimation
return available_blocks
def reset(self) -> None:
"""Reset policy state."""
pass
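
The TODO above, together with the load_block_sample_from_cpu helper added to OffloadEngine in this commit and the samples_per_chunk/threshold parameters, points at a selection scheme along the following lines. This is a speculative sketch of one plausible implementation, not the committed design (the scoring rule, the single-kv-head shortcut, and both method bodies are assumptions), and a complete policy will additionally need to implement the now-abstract compute_chunked_attention:

    def _estimate_chunk_scores(self, q, offload_engine, blocks, layer_id):
        """Mean sampled attention logit per historical chunk (sketch)."""
        scale = q.shape[-1] ** -0.5
        scores = []
        for blk in blocks:
            k_s, _ = offload_engine.load_block_sample_from_cpu(
                blk, layer_id, self.samples_per_chunk
            )
            # q: [seq_len, heads, dim]; k_s: [num_samples, kv_heads, dim]
            # GQA is ignored for brevity: only kv head 0 is scored
            logits = torch.einsum("qhd,sd->qhs", q.float(), k_s[:, 0].float()) * scale
            scores.append(logits.mean())
        return torch.stack(scores)

    def select_blocks(self, available_blocks, offload_engine, ctx):
        if ctx.query is None or len(available_blocks) <= 1:
            return available_blocks
        scores = self._estimate_chunk_scores(
            ctx.query, offload_engine, available_blocks, ctx.layer_id
        )
        probs = torch.softmax(scores, dim=0)  # attention mass per chunk
        order = torch.argsort(probs, descending=True)
        cum = torch.cumsum(probs[order], dim=0)
        # Smallest prefix of chunks whose cumulative mass reaches the threshold
        keep = order[: int((cum < self.threshold).sum().item()) + 1]
        return [available_blocks[i] for i in sorted(keep.tolist())]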