✨ feat: integrate sparse policy architecture into GPU-only mode
- Add compute_prefill() and compute_decode() GPU-only methods to SparsePolicy base class - Implement GPU-only methods in FullAttentionPolicy using flash_attn - Add sparse_policy parameter to GPUOnlyManager - Update create_kvcache_manager() to create FullAttentionPolicy for GPU-only mode - Route GPU-only attention through sparse_policy in attention.py - Pass kvcache_manager to context for policy access - Add --enable-policy flag to bench.py for testing - Handle warmup phase when kvcache_manager is not yet allocated This allows GPU-only mode to use the same policy architecture as CPU offload mode, enabling future sparse attention implementations (Quest, XAttention) in GPU-only mode. Performance verified: ~4890 tok/s (unchanged from baseline) Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
@@ -191,6 +191,87 @@ class SparsePolicy(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
# =========================================================================
|
||||
# GPU-only methods (non-chunked)
|
||||
# These methods are used when all KV cache is on GPU, no CPU offload needed.
|
||||
# =========================================================================
|
||||
|
||||
def compute_prefill(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float,
    layer_id: int,
    block_tables: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    GPU-only (non-chunked) prefill attention hook.

    Invoked when the entire KV cache lives on the GPU and no CPU offload
    is in play. Subclasses that support GPU-only sparse prefill override
    this; the base class deliberately refuses with NotImplementedError.

    Args:
        q: packed variable-length queries, [total_q, num_heads, head_dim]
        k: keys, [total_kv, num_kv_heads, head_dim]
        v: values, [total_kv, num_kv_heads, head_dim]
        cu_seqlens_q: cumulative query sequence lengths, [batch+1]
        cu_seqlens_k: cumulative key sequence lengths, [batch+1]
        max_seqlen_q: longest query sequence in the batch
        max_seqlen_k: longest key sequence in the batch
        softmax_scale: attention softmax scale (1/sqrt(head_dim))
        layer_id: index of the transformer layer being computed
        block_tables: optional paged-attention block tables, [batch, max_blocks]

    Returns:
        Attention output, [total_q, num_heads, head_dim].

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # type(self).__name__ names the concrete subclass in the error so the
    # caller can tell which policy lacks GPU-only prefill support.
    raise NotImplementedError(
        f"{type(self).__name__} does not implement compute_prefill for GPU-only mode"
    )
|
||||
|
||||
def compute_decode(
    self,
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    cache_seqlens: torch.Tensor,
    softmax_scale: float,
    layer_id: int,
    block_tables: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    GPU-only (non-chunked) decode attention hook.

    Invoked for single-token decode steps when the entire KV cache lives
    on the GPU (no CPU offload). Subclasses that support GPU-only sparse
    decode override this; the base class deliberately refuses with
    NotImplementedError.

    Args:
        q: one query token per sequence, [batch, num_heads, head_dim]
        k_cache: paged key cache, [num_blocks, block_size, num_kv_heads, head_dim]
        v_cache: paged value cache, [num_blocks, block_size, num_kv_heads, head_dim]
        cache_seqlens: per-sequence cached lengths, [batch]
        softmax_scale: attention softmax scale (1/sqrt(head_dim))
        layer_id: index of the transformer layer being computed
        block_tables: optional paged-attention block tables, [batch, max_blocks]

    Returns:
        Attention output, [batch, 1, num_heads, head_dim].

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # type(self).__name__ names the concrete subclass in the error so the
    # caller can tell which policy lacks GPU-only decode support.
    raise NotImplementedError(
        f"{type(self).__name__} does not implement compute_decode for GPU-only mode"
    )
|
||||
|
||||
# =========================================================================
|
||||
# Chunked offload methods (for CPU offload mode)
|
||||
# =========================================================================
|
||||
|
||||
@abstractmethod
|
||||
def compute_chunked_prefill(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user