feat: integrate sparse policy architecture into GPU-only mode

- Add compute_prefill() and compute_decode() GPU-only methods to SparsePolicy base class
- Implement GPU-only methods in FullAttentionPolicy using flash_attn
- Add sparse_policy parameter to GPUOnlyManager
- Update create_kvcache_manager() to create FullAttentionPolicy for GPU-only mode
- Route GPU-only attention through sparse_policy in attention.py
- Pass kvcache_manager to context for policy access
- Add --enable-policy flag to bench.py for testing
- Handle warmup phase when kvcache_manager is not yet allocated

This allows GPU-only mode to use the same policy architecture as CPU offload mode,
enabling future sparse attention implementations (Quest, XAttention) in GPU-only mode.

Performance verified: ~4890 tok/s (unchanged from baseline)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-27 05:08:02 +08:00
parent 05ce57ee8e
commit 09b2136e9f
7 changed files with 287 additions and 25 deletions

View File

@@ -40,6 +40,8 @@ def bench_prefill(llm, num_seqs, input_len):
def main(): def main():
import argparse import argparse
from nanovllm.config import SparsePolicyType
parser = argparse.ArgumentParser(description="Benchmark nanovllm GPU performance") parser = argparse.ArgumentParser(description="Benchmark nanovllm GPU performance")
parser.add_argument("--model", type=str, default="~/models/Llama-3.1-8B-Instruct", parser.add_argument("--model", type=str, default="~/models/Llama-3.1-8B-Instruct",
help="Model path (default: ~/models/Llama-3.1-8B-Instruct)") help="Model path (default: ~/models/Llama-3.1-8B-Instruct)")
@@ -48,11 +50,20 @@ def main():
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)") parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)") parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks") parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
# Sparse policy option (GPU-only mode now supports policy routing)
parser.add_argument("--enable-policy", action="store_true",
help="Enable sparse policy routing (FullAttentionPolicy by default)")
args = parser.parse_args() args = parser.parse_args()
path = os.path.expanduser(args.model) path = os.path.expanduser(args.model)
max_len = args.max_len max_len = args.max_len
# Configure sparse policy
if args.enable_policy:
sparse_policy = SparsePolicyType.FULL
print(f"\n[nanovllm GPU + Policy] sparse_policy=FULL, max_len={max_len}")
else:
sparse_policy = None
print(f"\n[nanovllm GPU] max_len={max_len}") print(f"\n[nanovllm GPU] max_len={max_len}")
llm = LLM( llm = LLM(
@@ -60,6 +71,7 @@ def main():
enforce_eager=False, enforce_eager=False,
max_model_len=max_len, max_model_len=max_len,
max_num_batched_tokens=max_len, max_num_batched_tokens=max_len,
sparse_policy=sparse_policy,
) )
# Warmup # Warmup

View File

@@ -195,19 +195,23 @@ class ModelRunner:
dtype=hf_config.torch_dtype, dtype=hf_config.torch_dtype,
) )
# Initialize sparse policy if manager has one (CPU offload mode) # Initialize sparse policy if manager has one (works for both CPU offload and GPU-only modes)
if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None: if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
# Use CPU blocks for offload mode, GPU blocks for GPU-only mode
num_blocks_for_init = config.num_cpu_kvcache_blocks if config.enable_cpu_offload else config.num_kvcache_blocks
self.kvcache_manager.sparse_policy.initialize( self.kvcache_manager.sparse_policy.initialize(
num_layers=hf_config.num_hidden_layers, num_layers=hf_config.num_hidden_layers,
num_kv_heads=num_kv_heads, num_kv_heads=num_kv_heads,
head_dim=head_dim, head_dim=head_dim,
num_cpu_blocks=config.num_cpu_kvcache_blocks, num_cpu_blocks=num_blocks_for_init,
dtype=hf_config.torch_dtype, dtype=hf_config.torch_dtype,
device=torch.device("cuda"), device=torch.device("cuda"),
) )
# Log policy info (handle both enum and None cases)
policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL"
logger.info( logger.info(
f"Sparse policy initialized: {config.sparse_policy.name} " f"Sparse policy initialized: {policy_name} "
f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})" f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
) )
@@ -368,7 +372,16 @@ class ModelRunner:
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables) set_context(
is_prefill=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
slot_mapping=slot_mapping,
block_tables=block_tables,
kvcache_manager=getattr(self, 'kvcache_manager', None),
)
return input_ids, positions return input_ids, positions
def prepare_decode(self, seqs: list[Sequence]): def prepare_decode(self, seqs: list[Sequence]):
@@ -397,7 +410,13 @@ class ModelRunner:
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# Use GPU physical block tables for attention # Use GPU physical block tables for attention
block_tables = self._prepare_gpu_block_tables(gpu_block_tables) block_tables = self._prepare_gpu_block_tables(gpu_block_tables)
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables) set_context(
is_prefill=False,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
kvcache_manager=self.kvcache_manager,
)
return input_ids, positions return input_ids, positions
def _prepare_gpu_block_tables(self, gpu_block_tables: list[list[int]]): def _prepare_gpu_block_tables(self, gpu_block_tables: list[list[int]]):
@@ -698,7 +717,13 @@ class ModelRunner:
for bs in reversed(self.graph_bs): for bs in reversed(self.graph_bs):
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs]) set_context(
is_prefill=False,
slot_mapping=slot_mapping[:bs],
context_lens=context_lens[:bs],
block_tables=block_tables[:bs],
kvcache_manager=self.kvcache_manager,
)
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
with torch.cuda.graph(graph, self.graph_pool): with torch.cuda.graph(graph, self.graph_pool):
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture

View File

@@ -25,7 +25,7 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
Factory function to create the appropriate KV cache manager. Factory function to create the appropriate KV cache manager.
Decision logic: Decision logic:
1. If enable_cpu_offload=False: use GPUOnlyManager 1. If enable_cpu_offload=False: use GPUOnlyManager (optionally with sparse policy)
2. If enable_cpu_offload=True but all blocks fit in GPU: use GPUOnlyManager 2. If enable_cpu_offload=True but all blocks fit in GPU: use GPUOnlyManager
3. If enable_cpu_offload=True and need CPU blocks: use HybridKVCacheManager 3. If enable_cpu_offload=True and need CPU blocks: use HybridKVCacheManager
@@ -37,9 +37,44 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
""" """
if not getattr(config, 'enable_cpu_offload', False): if not getattr(config, 'enable_cpu_offload', False):
# Default: pure GPU mode # Default: pure GPU mode
# Check if sparse policy is requested for GPU-only mode
from nanovllm.config import SparsePolicyType
sparse_policy_type = getattr(config, 'sparse_policy', None)
# Handle None case - use FULL as default
if sparse_policy_type is None:
sparse_policy_type = SparsePolicyType.FULL
sparse_policy = None
if sparse_policy_type != SparsePolicyType.FULL:
# Create sparse policy for GPU-only mode
from nanovllm.kvcache.sparse import create_sparse_policy
policy_kwargs = {}
if sparse_policy_type == SparsePolicyType.QUEST:
policy_kwargs = {
'topk_blocks': getattr(config, 'sparse_topk_blocks', 8),
'threshold_blocks': getattr(config, 'sparse_threshold_blocks', 4),
}
elif sparse_policy_type == SparsePolicyType.XATTN_BSA:
policy_kwargs = {
'block_size': getattr(config, 'sparse_block_size', 128),
'samples_per_chunk': getattr(config, 'sparse_samples_per_chunk', 128),
'threshold': getattr(config, 'sparse_threshold', 0.9),
'use_triton': getattr(config, 'sparse_use_triton', True),
'stride': getattr(config, 'sparse_stride', 8),
'chunk_size': getattr(config, 'sparse_chunk_size', 16384),
}
sparse_policy = create_sparse_policy(sparse_policy_type, **policy_kwargs)
else:
# FULL policy for GPU-only mode - always create for consistent API
from nanovllm.kvcache.sparse import FullAttentionPolicy
sparse_policy = FullAttentionPolicy()
return GPUOnlyManager( return GPUOnlyManager(
num_blocks=config.num_kvcache_blocks, num_blocks=config.num_kvcache_blocks,
block_size=config.kvcache_block_size, block_size=config.kvcache_block_size,
sparse_policy=sparse_policy,
) )
# CPU offload is enabled # CPU offload is enabled

View File

@@ -7,13 +7,16 @@ the KVCacheManager interface.
""" """
from collections import deque from collections import deque
from typing import List, Tuple, Dict, Optional from typing import List, Tuple, Dict, Optional, TYPE_CHECKING
import torch import torch
from torch import Tensor from torch import Tensor
from nanovllm.engine.sequence import Sequence from nanovllm.engine.sequence import Sequence
from nanovllm.kvcache.base_manager import KVCacheManager from nanovllm.kvcache.base_manager import KVCacheManager
if TYPE_CHECKING:
from nanovllm.kvcache.sparse.policy import SparsePolicy
class Block: class Block:
"""Physical block in GPU memory.""" """Physical block in GPU memory."""
@@ -50,17 +53,28 @@ class GPUOnlyManager(KVCacheManager):
all data stays on GPU at fixed addresses. all data stays on GPU at fixed addresses.
""" """
def __init__(self, num_blocks: int, block_size: int): def __init__(
self,
num_blocks: int,
block_size: int,
sparse_policy: Optional["SparsePolicy"] = None,
):
""" """
Initialize GPU-only manager. Initialize GPU-only manager.
Args: Args:
num_blocks: Total number of blocks to manage num_blocks: Total number of blocks to manage
block_size: Tokens per block (default 256) block_size: Tokens per block (default 256)
sparse_policy: Optional sparse attention policy for GPU-only mode
""" """
self._block_size = block_size self._block_size = block_size
self._num_blocks = num_blocks self._num_blocks = num_blocks
# Sparse policy for GPU-only mode (optional)
self.sparse_policy = sparse_policy
# No offload engine in GPU-only mode
self.offload_engine = None
# Block metadata # Block metadata
self.blocks: List[Block] = [Block(i) for i in range(num_blocks)] self.blocks: List[Block] = [Block(i) for i in range(num_blocks)]

View File

@@ -76,6 +76,75 @@ class FullAttentionPolicy(SparsePolicy):
logger.info(f"[Full Policy] Density Stats: chunks={stats['num_chunks']}, " logger.info(f"[Full Policy] Density Stats: chunks={stats['num_chunks']}, "
f"blocks={stats['total_available_blocks']}, density=100.0%") f"blocks={stats['total_available_blocks']}, density=100.0%")
# =========================================================================
# GPU-only methods (non-chunked)
# =========================================================================
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
layer_id: int,
block_tables: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
GPU-only prefill attention using flash_attn_varlen_func.
This is the simplest implementation - just call flash attention directly.
For sparse policies, this method would implement block selection.
"""
from flash_attn import flash_attn_varlen_func
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=True,
block_table=block_tables,
)
def compute_decode(
self,
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
cache_seqlens: torch.Tensor,
softmax_scale: float,
layer_id: int,
block_tables: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
GPU-only decode attention using flash_attn_with_kvcache.
This is the simplest implementation - just call flash attention directly.
For sparse policies, this method would implement block selection.
"""
from flash_attn import flash_attn_with_kvcache
# q is [batch, num_heads, head_dim], need to add seq dim
return flash_attn_with_kvcache(
q.unsqueeze(1), # [batch, 1, heads, dim]
k_cache,
v_cache,
cache_seqlens=cache_seqlens,
block_table=block_tables,
softmax_scale=softmax_scale,
causal=True,
)
# =========================================================================
# Chunked offload methods
# =========================================================================
def compute_chunked_prefill( def compute_chunked_prefill(
self, self,
q: torch.Tensor, q: torch.Tensor,

View File

@@ -191,6 +191,87 @@ class SparsePolicy(ABC):
""" """
pass pass
# =========================================================================
# GPU-only methods (non-chunked)
# These methods are used when all KV cache is on GPU, no CPU offload needed.
# =========================================================================
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
layer_id: int,
block_tables: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Compute GPU-only prefill attention (non-chunked).
This method is used when all KV cache resides on GPU (no CPU offload).
Override this to implement sparse prefill attention for GPU-only mode.
Default implementation raises NotImplementedError.
Args:
q: [total_q, num_heads, head_dim] query tensor (packed variable length)
k: [total_kv, num_kv_heads, head_dim] key tensor
v: [total_kv, num_kv_heads, head_dim] value tensor
cu_seqlens_q: [batch+1] cumulative sequence lengths for queries
cu_seqlens_k: [batch+1] cumulative sequence lengths for keys
max_seqlen_q: maximum query sequence length
max_seqlen_k: maximum key sequence length
softmax_scale: softmax scaling factor (1/sqrt(head_dim))
layer_id: transformer layer index
block_tables: [batch, max_blocks] paged attention block tables (optional)
Returns:
[total_q, num_heads, head_dim] attention output
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement compute_prefill for GPU-only mode"
)
def compute_decode(
self,
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
cache_seqlens: torch.Tensor,
softmax_scale: float,
layer_id: int,
block_tables: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Compute GPU-only decode attention (non-chunked).
This method is used when all KV cache resides on GPU (no CPU offload).
Override this to implement sparse decode attention for GPU-only mode.
Default implementation raises NotImplementedError.
Args:
q: [batch, num_heads, head_dim] query tensor (single token per sequence)
k_cache: [num_blocks, block_size, num_kv_heads, head_dim] paged key cache
v_cache: [num_blocks, block_size, num_kv_heads, head_dim] paged value cache
cache_seqlens: [batch] sequence lengths in cache
softmax_scale: softmax scaling factor (1/sqrt(head_dim))
layer_id: transformer layer index
block_tables: [batch, max_blocks] paged attention block tables (optional)
Returns:
[batch, 1, num_heads, head_dim] attention output
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement compute_decode for GPU-only mode"
)
# =========================================================================
# Chunked offload methods (for CPU offload mode)
# =========================================================================
@abstractmethod @abstractmethod
def compute_chunked_prefill( def compute_chunked_prefill(
self, self,

View File

@@ -124,24 +124,47 @@ class Attention(nn.Module):
if k_cache.numel() and v_cache.numel(): if k_cache.numel() and v_cache.numel():
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping) store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
# Get sparse_policy from kvcache_manager (required, never None after warmup)
# During warmup, kvcache_manager is not yet allocated
if context.kvcache_manager is None:
# Warmup phase: use flash_attn directly
if context.is_prefill:
return flash_attn_varlen_func(
q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True,
)
else:
return flash_attn_with_kvcache(
q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True,
)
sparse_policy = context.kvcache_manager.sparse_policy
assert sparse_policy is not None, "sparse_policy must not be None"
if context.is_prefill: if context.is_prefill:
if context.is_chunked_prefill: if context.is_chunked_prefill:
# Chunked prefill: merge attention from previous KV # Chunked prefill: merge attention from previous KV (CPU offload mode)
o = self._chunked_prefill_attention(q, k, v, context) o = self._chunked_prefill_attention(q, k, v, context)
elif context.block_tables is not None: # prefix cache
k, v = k_cache, v_cache
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
else: else:
o = flash_attn_varlen_func(q, k, v, # GPU-only mode: use policy for attention
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q, # Use paged attention if block_tables provided, else use k, v directly
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k, if context.block_tables is not None:
softmax_scale=self.scale, causal=True, block_table=context.block_tables) k_for_attn, v_for_attn = k_cache, v_cache
else:
k_for_attn, v_for_attn = k, v
o = sparse_policy.compute_prefill(
q, k_for_attn, v_for_attn,
context.cu_seqlens_q, context.cu_seqlens_k,
context.max_seqlen_q, context.max_seqlen_k,
self.scale, self.layer_id,
context.block_tables,
)
else: # decode else: # decode
if context.is_chunked_prefill: if context.is_chunked_prefill:
# Chunked decode: need to load all KV from CPU+GPU # Chunked decode: need to load all KV from CPU+GPU (CPU offload mode)
# Store current decode token to per-layer decode buffer # Store current decode token to per-layer decode buffer
# This is needed because GPU cache has no layer dimension, # This is needed because GPU cache has no layer dimension,
# so all layers would overwrite each other in decode_slot. # so all layers would overwrite each other in decode_slot.
@@ -152,9 +175,12 @@ class Attention(nn.Module):
offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0)) offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0))
o = self._chunked_decode_attention(q, k, v, context) o = self._chunked_decode_attention(q, k, v, context)
else: else:
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, # GPU-only mode: use policy for attention
cache_seqlens=context.context_lens, block_table=context.block_tables, o = sparse_policy.compute_decode(
softmax_scale=self.scale, causal=True) q, k_cache, v_cache,
context.context_lens, self.scale, self.layer_id,
context.block_tables,
)
return o return o
def _chunked_prefill_attention( def _chunked_prefill_attention(