🔀 merge: integrate tzj/minference-exp (GPU-only sparse attention)
Merge GPU-only sparse attention support from tzj/minference-exp branch: **GPU-only mode additions:** - Add compute_prefill/compute_decode methods to SparsePolicy base class - Add GPU-only attention routing in attention.py - Add alloc_policy_metadata() for pre-allocating GQA buffers - Add XAttention + BSA sparse attention for GPU-only prefill - Add kvcache_manager to set_context() for policy access **bench.py enhancements:** - Add --model argument for configurable model path - Add --policy argument (full, xattn) for sparse policy selection - Add --enable-policy flag for FullAttentionPolicy routing - Add --enforce-eager option to disable CUDA graphs - Add --gpu-util option for GPU memory utilization **Documentation:** - Add gpu_only_xattn_guide.md with performance analysis - Add gpu_only_sparse_integration.md baseline document - Add gpu-vram-requirement.md rule for GPU-only mode Both CPU offload and GPU-only paths are preserved and functional. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -202,19 +202,36 @@ class ModelRunner:
|
||||
dtype=hf_config.torch_dtype,
|
||||
)
|
||||
|
||||
# Initialize sparse policy if manager has one (CPU offload mode)
|
||||
# Initialize sparse policy if manager has one (works for both CPU offload and GPU-only modes)
|
||||
if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
|
||||
# Use CPU blocks for offload mode, GPU blocks for GPU-only mode
|
||||
num_blocks_for_init = config.num_cpu_kvcache_blocks if config.enable_cpu_offload else config.num_kvcache_blocks
|
||||
self.kvcache_manager.sparse_policy.initialize(
|
||||
num_layers=hf_config.num_hidden_layers,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
num_cpu_blocks=config.num_cpu_kvcache_blocks,
|
||||
num_cpu_blocks=num_blocks_for_init,
|
||||
dtype=hf_config.torch_dtype,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
# GPU-only mode: pre-allocate policy metadata buffers
|
||||
# This avoids dynamic GPU memory allocation during forward pass
|
||||
if not config.enable_cpu_offload:
|
||||
num_heads = hf_config.num_attention_heads // self.world_size
|
||||
self.kvcache_manager.sparse_policy.alloc_policy_metadata(
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
max_seq_len=config.max_model_len,
|
||||
dtype=hf_config.torch_dtype,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
# Log policy info (handle both enum and None cases)
|
||||
policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL"
|
||||
logger.info(
|
||||
f"Sparse policy initialized: {config.sparse_policy.name} "
|
||||
f"Sparse policy initialized: {policy_name} "
|
||||
f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
|
||||
)
|
||||
|
||||
@@ -375,7 +392,16 @@ class ModelRunner:
|
||||
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
|
||||
set_context(
|
||||
is_prefill=True,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
slot_mapping=slot_mapping,
|
||||
block_tables=block_tables,
|
||||
kvcache_manager=getattr(self, 'kvcache_manager', None),
|
||||
)
|
||||
return input_ids, positions
|
||||
|
||||
def prepare_decode(self, seqs: list[Sequence]):
|
||||
@@ -404,7 +430,13 @@ class ModelRunner:
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
# Use GPU physical block tables for attention
|
||||
block_tables = self._prepare_gpu_block_tables(gpu_block_tables)
|
||||
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
|
||||
set_context(
|
||||
is_prefill=False,
|
||||
slot_mapping=slot_mapping,
|
||||
context_lens=context_lens,
|
||||
block_tables=block_tables,
|
||||
kvcache_manager=self.kvcache_manager,
|
||||
)
|
||||
return input_ids, positions
|
||||
|
||||
def _prepare_gpu_block_tables(self, gpu_block_tables: list[list[int]]):
|
||||
@@ -713,7 +745,13 @@ class ModelRunner:
|
||||
|
||||
for bs in reversed(self.graph_bs):
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
|
||||
set_context(
|
||||
is_prefill=False,
|
||||
slot_mapping=slot_mapping[:bs],
|
||||
context_lens=context_lens[:bs],
|
||||
block_tables=block_tables[:bs],
|
||||
kvcache_manager=self.kvcache_manager,
|
||||
)
|
||||
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
|
||||
with torch.cuda.graph(graph, self.graph_pool):
|
||||
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
|
||||
|
||||
@@ -25,7 +25,7 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
|
||||
Factory function to create the appropriate KV cache manager.
|
||||
|
||||
Decision logic:
|
||||
1. If enable_cpu_offload=False: use GPUOnlyManager
|
||||
1. If enable_cpu_offload=False: use GPUOnlyManager (optionally with sparse policy)
|
||||
2. If enable_cpu_offload=True but all blocks fit in GPU: use GPUOnlyManager
|
||||
3. If enable_cpu_offload=True and need CPU blocks: use HybridKVCacheManager
|
||||
|
||||
@@ -37,9 +37,44 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
|
||||
"""
|
||||
if not getattr(config, 'enable_cpu_offload', False):
|
||||
# Default: pure GPU mode
|
||||
# Check if sparse policy is requested for GPU-only mode
|
||||
from nanovllm.config import SparsePolicyType
|
||||
sparse_policy_type = getattr(config, 'sparse_policy', None)
|
||||
# Handle None case - use FULL as default
|
||||
if sparse_policy_type is None:
|
||||
sparse_policy_type = SparsePolicyType.FULL
|
||||
|
||||
sparse_policy = None
|
||||
if sparse_policy_type != SparsePolicyType.FULL:
|
||||
# Create sparse policy for GPU-only mode
|
||||
from nanovllm.kvcache.sparse import create_sparse_policy
|
||||
|
||||
policy_kwargs = {}
|
||||
if sparse_policy_type == SparsePolicyType.QUEST:
|
||||
policy_kwargs = {
|
||||
'topk_blocks': getattr(config, 'sparse_topk_blocks', 8),
|
||||
'threshold_blocks': getattr(config, 'sparse_threshold_blocks', 4),
|
||||
}
|
||||
elif sparse_policy_type == SparsePolicyType.XATTN_BSA:
|
||||
policy_kwargs = {
|
||||
'block_size': getattr(config, 'sparse_block_size', 128),
|
||||
'samples_per_chunk': getattr(config, 'sparse_samples_per_chunk', 128),
|
||||
'threshold': getattr(config, 'sparse_threshold', 0.9),
|
||||
'use_triton': getattr(config, 'sparse_use_triton', True),
|
||||
'stride': getattr(config, 'sparse_stride', 8),
|
||||
'chunk_size': getattr(config, 'sparse_chunk_size', 16384),
|
||||
}
|
||||
|
||||
sparse_policy = create_sparse_policy(sparse_policy_type, **policy_kwargs)
|
||||
else:
|
||||
# FULL policy for GPU-only mode - always create for consistent API
|
||||
from nanovllm.kvcache.sparse import FullAttentionPolicy
|
||||
sparse_policy = FullAttentionPolicy()
|
||||
|
||||
return GPUOnlyManager(
|
||||
num_blocks=config.num_kvcache_blocks,
|
||||
block_size=config.kvcache_block_size,
|
||||
sparse_policy=sparse_policy,
|
||||
)
|
||||
|
||||
# CPU offload is enabled
|
||||
|
||||
@@ -7,13 +7,16 @@ the KVCacheManager interface.
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from typing import List, Tuple, Dict, Optional
|
||||
from typing import List, Tuple, Dict, Optional, TYPE_CHECKING
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
from nanovllm.kvcache.base_manager import KVCacheManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy
|
||||
|
||||
|
||||
class Block:
|
||||
"""Physical block in GPU memory."""
|
||||
@@ -50,17 +53,28 @@ class GPUOnlyManager(KVCacheManager):
|
||||
all data stays on GPU at fixed addresses.
|
||||
"""
|
||||
|
||||
def __init__(self, num_blocks: int, block_size: int):
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
sparse_policy: Optional["SparsePolicy"] = None,
|
||||
):
|
||||
"""
|
||||
Initialize GPU-only manager.
|
||||
|
||||
Args:
|
||||
num_blocks: Total number of blocks to manage
|
||||
block_size: Tokens per block (default 256)
|
||||
sparse_policy: Optional sparse attention policy for GPU-only mode
|
||||
"""
|
||||
self._block_size = block_size
|
||||
self._num_blocks = num_blocks
|
||||
|
||||
# Sparse policy for GPU-only mode (optional)
|
||||
self.sparse_policy = sparse_policy
|
||||
# No offload engine in GPU-only mode
|
||||
self.offload_engine = None
|
||||
|
||||
# Block metadata
|
||||
self.blocks: List[Block] = [Block(i) for i in range(num_blocks)]
|
||||
|
||||
|
||||
@@ -76,6 +76,75 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
logger.info(f"[Full Policy] Density Stats: chunks={stats['num_chunks']}, "
|
||||
f"blocks={stats['total_available_blocks']}, density=100.0%")
|
||||
|
||||
# =========================================================================
|
||||
# GPU-only methods (non-chunked)
|
||||
# =========================================================================
|
||||
|
||||
def compute_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
cu_seqlens_k: torch.Tensor,
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
softmax_scale: float,
|
||||
layer_id: int,
|
||||
block_tables: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
GPU-only prefill attention using flash_attn_varlen_func.
|
||||
|
||||
This is the simplest implementation - just call flash attention directly.
|
||||
For sparse policies, this method would implement block selection.
|
||||
"""
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
return flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
block_table=block_tables,
|
||||
)
|
||||
|
||||
def compute_decode(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
cache_seqlens: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
layer_id: int,
|
||||
block_tables: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
GPU-only decode attention using flash_attn_with_kvcache.
|
||||
|
||||
This is the simplest implementation - just call flash attention directly.
|
||||
For sparse policies, this method would implement block selection.
|
||||
"""
|
||||
from flash_attn import flash_attn_with_kvcache
|
||||
|
||||
# q is [batch, num_heads, head_dim], need to add seq dim
|
||||
return flash_attn_with_kvcache(
|
||||
q.unsqueeze(1), # [batch, 1, heads, dim]
|
||||
k_cache,
|
||||
v_cache,
|
||||
cache_seqlens=cache_seqlens,
|
||||
block_table=block_tables,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Chunked offload methods
|
||||
# =========================================================================
|
||||
|
||||
def compute_chunked_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
|
||||
@@ -108,6 +108,34 @@ class SparsePolicy(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def alloc_policy_metadata(
|
||||
self,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
max_seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
"""
|
||||
Pre-allocate GPU buffers for policy computation.
|
||||
|
||||
Called by the framework after KV cache allocation, but ONLY for GPU-only
|
||||
mode (not CPU offload mode). Override this to pre-allocate buffers that
|
||||
would otherwise be dynamically allocated during forward pass.
|
||||
|
||||
This is separate from initialize() which is used for CPU offload metadata.
|
||||
|
||||
Args:
|
||||
num_heads: Number of query heads
|
||||
num_kv_heads: Number of KV heads (for GQA)
|
||||
head_dim: Dimension per head
|
||||
max_seq_len: Maximum sequence length (for buffer sizing)
|
||||
dtype: Data type (typically float16/bfloat16)
|
||||
device: Target device (cuda)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def select_blocks(
|
||||
self,
|
||||
@@ -191,6 +219,87 @@ class SparsePolicy(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
# =========================================================================
|
||||
# GPU-only methods (non-chunked)
|
||||
# These methods are used when all KV cache is on GPU, no CPU offload needed.
|
||||
# =========================================================================
|
||||
|
||||
def compute_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
cu_seqlens_k: torch.Tensor,
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
softmax_scale: float,
|
||||
layer_id: int,
|
||||
block_tables: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute GPU-only prefill attention (non-chunked).
|
||||
|
||||
This method is used when all KV cache resides on GPU (no CPU offload).
|
||||
Override this to implement sparse prefill attention for GPU-only mode.
|
||||
Default implementation raises NotImplementedError.
|
||||
|
||||
Args:
|
||||
q: [total_q, num_heads, head_dim] query tensor (packed variable length)
|
||||
k: [total_kv, num_kv_heads, head_dim] key tensor
|
||||
v: [total_kv, num_kv_heads, head_dim] value tensor
|
||||
cu_seqlens_q: [batch+1] cumulative sequence lengths for queries
|
||||
cu_seqlens_k: [batch+1] cumulative sequence lengths for keys
|
||||
max_seqlen_q: maximum query sequence length
|
||||
max_seqlen_k: maximum key sequence length
|
||||
softmax_scale: softmax scaling factor (1/sqrt(head_dim))
|
||||
layer_id: transformer layer index
|
||||
block_tables: [batch, max_blocks] paged attention block tables (optional)
|
||||
|
||||
Returns:
|
||||
[total_q, num_heads, head_dim] attention output
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement compute_prefill for GPU-only mode"
|
||||
)
|
||||
|
||||
def compute_decode(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
cache_seqlens: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
layer_id: int,
|
||||
block_tables: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute GPU-only decode attention (non-chunked).
|
||||
|
||||
This method is used when all KV cache resides on GPU (no CPU offload).
|
||||
Override this to implement sparse decode attention for GPU-only mode.
|
||||
Default implementation raises NotImplementedError.
|
||||
|
||||
Args:
|
||||
q: [batch, num_heads, head_dim] query tensor (single token per sequence)
|
||||
k_cache: [num_blocks, block_size, num_kv_heads, head_dim] paged key cache
|
||||
v_cache: [num_blocks, block_size, num_kv_heads, head_dim] paged value cache
|
||||
cache_seqlens: [batch] sequence lengths in cache
|
||||
softmax_scale: softmax scaling factor (1/sqrt(head_dim))
|
||||
layer_id: transformer layer index
|
||||
block_tables: [batch, max_blocks] paged attention block tables (optional)
|
||||
|
||||
Returns:
|
||||
[batch, 1, num_heads, head_dim] attention output
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement compute_decode for GPU-only mode"
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Chunked offload methods (for CPU offload mode)
|
||||
# =========================================================================
|
||||
|
||||
@abstractmethod
|
||||
def compute_chunked_prefill(
|
||||
self,
|
||||
|
||||
@@ -122,6 +122,271 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
self._stats_total_selected_blocks = 0
|
||||
self._stats_num_chunks = 0
|
||||
|
||||
# Pre-allocated GQA expansion buffers (GPU-only mode)
|
||||
# Set by alloc_policy_metadata(), None if not pre-allocated
|
||||
self._k_expanded: torch.Tensor | None = None
|
||||
self._v_expanded: torch.Tensor | None = None
|
||||
self._max_seq_len: int = 0
|
||||
|
||||
def alloc_policy_metadata(
|
||||
self,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
max_seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
"""
|
||||
Pre-allocate GQA expansion buffers for GPU-only mode.
|
||||
|
||||
These buffers are used by compute_prefill() to avoid dynamic allocation
|
||||
during forward pass. The buffers are sized for max_seq_len and sliced
|
||||
to actual seq_len during use.
|
||||
|
||||
Memory usage: 2 * num_heads * max_seq_len * head_dim * dtype_size
|
||||
For 64K seq, 32 heads, 128 dim, fp16: 2 * 32 * 65536 * 128 * 2 = 1 GB
|
||||
|
||||
Args:
|
||||
num_heads: Number of query heads
|
||||
num_kv_heads: Number of KV heads (for GQA)
|
||||
head_dim: Dimension per head
|
||||
max_seq_len: Maximum sequence length
|
||||
dtype: Data type
|
||||
device: Target device
|
||||
"""
|
||||
# Only allocate if GQA (num_heads != num_kv_heads)
|
||||
if num_heads == num_kv_heads:
|
||||
logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
|
||||
return
|
||||
|
||||
# Shape: [1, num_heads, max_seq_len, head_dim] for xattn_estimate format
|
||||
# Also used for BSA which expects [seq_len, num_heads, head_dim]
|
||||
shape = (1, num_heads, max_seq_len, head_dim)
|
||||
self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
|
||||
self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
|
||||
self._max_seq_len = max_seq_len
|
||||
|
||||
memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024)
|
||||
logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB")
|
||||
|
||||
# =========================================================================
|
||||
# GPU-only methods (non-chunked)
|
||||
# =========================================================================
|
||||
|
||||
def compute_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
cu_seqlens_k: torch.Tensor,
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
softmax_scale: float,
|
||||
layer_id: int,
|
||||
block_tables: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
GPU-only prefill attention using XAttention + BSA.
|
||||
|
||||
This method implements sparse attention for GPU-only mode:
|
||||
1. Estimate block importance using xattn_estimate
|
||||
2. Compute sparse attention using block_sparse_attn_func
|
||||
|
||||
Args:
|
||||
q: Query tensor [total_q, num_heads, head_dim] (varlen packed)
|
||||
k: Key tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
|
||||
v: Value tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
|
||||
cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
|
||||
cu_seqlens_k: Cumulative sequence lengths for K [batch+1]
|
||||
max_seqlen_q: Maximum Q sequence length
|
||||
max_seqlen_k: Maximum K sequence length
|
||||
softmax_scale: Softmax scaling factor
|
||||
layer_id: Transformer layer index
|
||||
block_tables: Paged attention block tables (not used for XAttention)
|
||||
|
||||
Returns:
|
||||
Attention output [total_q, num_heads, head_dim]
|
||||
"""
|
||||
# When block_tables is provided (paged KV cache / prefix cache),
|
||||
# fallback to flash_attn as XAttention expects contiguous K, V
|
||||
if block_tables is not None:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
return flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
block_table=block_tables,
|
||||
)
|
||||
|
||||
if not BSA_AVAILABLE:
|
||||
# Fallback to flash attention if BSA not available
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
return flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
if not XATTN_AVAILABLE:
|
||||
# Fallback to flash attention if xattn not available
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
return flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
from nanovllm.ops.xattn import xattn_estimate
|
||||
|
||||
# Get dimensions
|
||||
total_q, num_heads, head_dim = q.shape
|
||||
total_kv, num_kv_heads, _ = k.shape
|
||||
|
||||
# For now, assume batch_size = 1 (single sequence)
|
||||
# TODO: Support batched varlen format
|
||||
batch_size = cu_seqlens_q.shape[0] - 1
|
||||
if batch_size != 1:
|
||||
# Fallback to flash attention for batched input
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
logger.warning(f"[XAttn] batch_size={batch_size} > 1, falling back to flash attention")
|
||||
return flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
q_len = max_seqlen_q
|
||||
k_len = max_seqlen_k
|
||||
|
||||
# Convert from varlen format [total, heads, dim] to [batch, heads, seq, dim]
|
||||
# q: [q_len, num_heads, head_dim] -> [1, num_heads, q_len, head_dim]
|
||||
Q = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, q_len, head_dim]
|
||||
K = k.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, k_len, head_dim]
|
||||
V = v.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, k_len, head_dim]
|
||||
|
||||
# Expand KV for GQA - use pre-allocated buffers if available
|
||||
if num_heads != num_kv_heads:
|
||||
num_groups = num_heads // num_kv_heads
|
||||
if self._k_expanded is not None and k_len <= self._max_seq_len:
|
||||
# Use pre-allocated buffers with in-place expansion
|
||||
K_exp = self._k_expanded[:, :, :k_len, :]
|
||||
V_exp = self._v_expanded[:, :, :k_len, :]
|
||||
# In-place GQA expansion: [1, num_kv_heads, k_len, head_dim] -> [1, num_heads, k_len, head_dim]
|
||||
# Reshape K to [1, num_kv_heads, 1, k_len, head_dim] and broadcast to [1, num_kv_heads, num_groups, k_len, head_dim]
|
||||
K_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_(
|
||||
K.unsqueeze(2).expand(-1, -1, num_groups, -1, -1)
|
||||
)
|
||||
V_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_(
|
||||
V.unsqueeze(2).expand(-1, -1, num_groups, -1, -1)
|
||||
)
|
||||
else:
|
||||
# Fallback: dynamic allocation (when buffers not pre-allocated or seq too long)
|
||||
K_exp, V_exp = expand_kv_for_gqa(K, V, num_heads)
|
||||
else:
|
||||
K_exp, V_exp = K, V
|
||||
|
||||
# Estimate block importance and get sparse mask
|
||||
_, mask = xattn_estimate(
|
||||
Q, K_exp,
|
||||
chunk_size=self.chunk_size,
|
||||
block_size=self.BSA_BLOCK_SIZE,
|
||||
threshold=self.threshold,
|
||||
use_triton=self.use_triton,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
# Compute block counts
|
||||
q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
|
||||
k_block_num = (k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
|
||||
|
||||
# Prepare tensors for BSA
|
||||
# q, k, v need to be [seq_len, num_heads, head_dim]
|
||||
q_bsa = q # Already [q_len, num_heads, head_dim]
|
||||
|
||||
# For GQA with BSA, reuse the expanded K_exp, V_exp (convert to BSA format)
|
||||
# K_exp: [1, num_heads, k_len, head_dim] -> [k_len, num_heads, head_dim]
|
||||
if num_heads != num_kv_heads:
|
||||
k_bsa = K_exp.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
||||
v_bsa = V_exp.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
||||
else:
|
||||
k_bsa = k
|
||||
v_bsa = v
|
||||
|
||||
# Prepare BSA inputs
|
||||
cu_seqlens_q_bsa = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
|
||||
cu_seqlens_k_bsa = torch.tensor([0, k_len], dtype=torch.int32, device=k.device)
|
||||
head_groups = torch.ones(num_heads, dtype=torch.int32, device=q.device)
|
||||
|
||||
# Trim mask to actual block counts
|
||||
mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous()
|
||||
|
||||
# Compute sparse attention using BSA
|
||||
output = block_sparse_attn_func(
|
||||
q_bsa, k_bsa, v_bsa,
|
||||
cu_seqlens_q_bsa,
|
||||
cu_seqlens_k_bsa,
|
||||
head_groups,
|
||||
None, # key_padding_mask
|
||||
mask_trimmed,
|
||||
q_len, k_len,
|
||||
p_dropout=0.0,
|
||||
deterministic=True,
|
||||
is_causal=True,
|
||||
)
|
||||
|
||||
# Update statistics (layer 0 only to avoid overcounting)
|
||||
if layer_id == 0:
|
||||
selected_blocks = mask_trimmed.sum().item()
|
||||
total_blocks = q_block_num * k_block_num * num_heads
|
||||
density = selected_blocks / total_blocks if total_blocks > 0 else 1.0
|
||||
logger.debug(f"[XAttn GPU-only] layer={layer_id}, q_blocks={q_block_num}, "
|
||||
f"k_blocks={k_block_num}, density={density:.1%}")
|
||||
|
||||
return output
|
||||
|
||||
def compute_decode(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
cache_seqlens: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
layer_id: int,
|
||||
block_tables: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
GPU-only decode attention - delegates to FullAttentionPolicy.
|
||||
|
||||
XAttention is designed for long prefill sequences. For decode (single token),
|
||||
we use FullAttentionPolicy which calls flash_attn_with_kvcache.
|
||||
"""
|
||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||
return FullAttentionPolicy().compute_decode(
|
||||
q, k_cache, v_cache, cache_seqlens, softmax_scale, layer_id, block_tables
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Chunked offload methods
|
||||
# =========================================================================
|
||||
|
||||
def select_blocks(
|
||||
self,
|
||||
available_blocks: List[int],
|
||||
|
||||
@@ -124,24 +124,47 @@ class Attention(nn.Module):
|
||||
if k_cache.numel() and v_cache.numel():
|
||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||
|
||||
# Get sparse_policy from kvcache_manager (required, never None after warmup)
|
||||
# During warmup, kvcache_manager is not yet allocated
|
||||
if context.kvcache_manager is None:
|
||||
# Warmup phase: use flash_attn directly
|
||||
if context.is_prefill:
|
||||
return flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True,
|
||||
)
|
||||
else:
|
||||
return flash_attn_with_kvcache(
|
||||
q.unsqueeze(1), k_cache, v_cache,
|
||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||
softmax_scale=self.scale, causal=True,
|
||||
)
|
||||
sparse_policy = context.kvcache_manager.sparse_policy
|
||||
assert sparse_policy is not None, "sparse_policy must not be None"
|
||||
|
||||
if context.is_prefill:
|
||||
if context.is_chunked_prefill:
|
||||
# Chunked prefill: merge attention from previous KV
|
||||
# Chunked prefill: merge attention from previous KV (CPU offload mode)
|
||||
o = self._chunked_prefill_attention(q, k, v, context)
|
||||
elif context.block_tables is not None: # prefix cache
|
||||
k, v = k_cache, v_cache
|
||||
o = flash_attn_varlen_func(q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||
else:
|
||||
o = flash_attn_varlen_func(q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||
# GPU-only mode: use policy for attention
|
||||
# Use paged attention if block_tables provided, else use k, v directly
|
||||
if context.block_tables is not None:
|
||||
k_for_attn, v_for_attn = k_cache, v_cache
|
||||
else:
|
||||
k_for_attn, v_for_attn = k, v
|
||||
o = sparse_policy.compute_prefill(
|
||||
q, k_for_attn, v_for_attn,
|
||||
context.cu_seqlens_q, context.cu_seqlens_k,
|
||||
context.max_seqlen_q, context.max_seqlen_k,
|
||||
self.scale, self.layer_id,
|
||||
context.block_tables,
|
||||
)
|
||||
else: # decode
|
||||
if context.is_chunked_prefill:
|
||||
# Chunked decode: need to load all KV from CPU+GPU
|
||||
# Chunked decode: need to load all KV from CPU+GPU (CPU offload mode)
|
||||
# Store current decode token to per-layer decode buffer
|
||||
# This is needed because GPU cache has no layer dimension,
|
||||
# so all layers would overwrite each other in decode_slot.
|
||||
@@ -152,9 +175,12 @@ class Attention(nn.Module):
|
||||
offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0))
|
||||
o = self._chunked_decode_attention(q, k, v, context)
|
||||
else:
|
||||
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||
softmax_scale=self.scale, causal=True)
|
||||
# GPU-only mode: use policy for attention
|
||||
o = sparse_policy.compute_decode(
|
||||
q, k_cache, v_cache,
|
||||
context.context_lens, self.scale, self.layer_id,
|
||||
context.block_tables,
|
||||
)
|
||||
return o
|
||||
|
||||
def _chunked_prefill_attention(
|
||||
|
||||
Reference in New Issue
Block a user