feat: integrate sparse policy architecture into GPU-only mode

- Add compute_prefill() and compute_decode() GPU-only methods to SparsePolicy base class
- Implement GPU-only methods in FullAttentionPolicy using flash_attn
- Add sparse_policy parameter to GPUOnlyManager
- Update create_kvcache_manager() to create FullAttentionPolicy for GPU-only mode
- Route GPU-only attention through sparse_policy in attention.py (see the routing sketch after this list)
- Pass kvcache_manager to context for policy access
- Add --enable-policy flag to bench.py for testing
- Handle warmup phase when kvcache_manager is not yet allocated
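
A minimal sketch of the new routing, assuming the manager exposes the policy as `kvcache_manager.sparse_policy` and that the attention context carries the fields shown; the function and attribute names here are illustrative, not the exact code in attention.py:

```
def gpu_only_attention(q, k, v, context, layer_id, softmax_scale):
    # Illustrative dispatch only: during warmup the kvcache_manager may not be
    # allocated yet, in which case the caller keeps the plain flash-attention path.
    policy = context.kvcache_manager.sparse_policy

    if context.is_prefill:
        # Packed variable-length prefill goes through the policy's prefill hook.
        return policy.compute_prefill(
            q, k, v,
            cu_seqlens_q=context.cu_seqlens_q,
            cu_seqlens_k=context.cu_seqlens_k,
            max_seqlen_q=context.max_seqlen_q,
            max_seqlen_k=context.max_seqlen_k,
            softmax_scale=softmax_scale,
            layer_id=layer_id,
            block_tables=context.block_tables,
        )

    # Single-token decode attends to the paged GPU cache through the decode hook.
    return policy.compute_decode(
        q,
        context.k_cache,
        context.v_cache,
        cache_seqlens=context.cache_seqlens,
        softmax_scale=softmax_scale,
        layer_id=layer_id,
        block_tables=context.block_tables,
    )
```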

This allows GPU-only mode to use the same policy architecture as CPU offload mode,
enabling future sparse attention implementations (Quest, XAttention) in GPU-only mode.
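
As a hedged illustration of what the new hooks enable, a Quest-style policy could override only compute_decode to keep the highest-scoring KV blocks per sequence. Everything below (the class name, the mean-key scoring heuristic, top_k_blocks, the per-sequence loop) is a hypothetical sketch, not code from this commit or the actual Quest criterion:

```
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_with_kvcache


class QuestLikePolicy(FullAttentionPolicy):
    """Hypothetical sparse decode policy: attend only to the top-k cache blocks."""

    top_k_blocks = 32  # how many KV blocks each sequence keeps at decode time

    def compute_decode(self, q, k_cache, v_cache, cache_seqlens,
                       softmax_scale, layer_id, block_tables=None):
        if block_tables is None:
            # No paged layout to prune; fall back to full attention.
            return super().compute_decode(q, k_cache, v_cache, cache_seqlens,
                                          softmax_scale, layer_id, block_tables)

        block_size = k_cache.shape[1]
        # Cheap per-block summary: mean key over the block's tokens and KV heads.
        block_keys = k_cache.mean(dim=(1, 2))          # [num_blocks, head_dim]
        q_mean = q.mean(dim=1)                         # [batch, head_dim]

        new_tables, new_seqlens = [], []
        for b in range(q.shape[0]):
            seqlen = int(cache_seqlens[b])
            n_blocks = (seqlen + block_size - 1) // block_size
            blocks = block_tables[b, :n_blocks]
            if n_blocks <= self.top_k_blocks:
                new_tables.append(blocks)
                new_seqlens.append(seqlen)
                continue
            # Score fully filled blocks; always keep the (possibly partial) last block.
            full, last = blocks[:-1], blocks[-1:]
            scores = block_keys[full.long()] @ q_mean[b]
            keep = full[scores.topk(self.top_k_blocks - 1).indices]
            tail = seqlen - (n_blocks - 1) * block_size
            new_tables.append(torch.cat([keep, last]))
            new_seqlens.append((self.top_k_blocks - 1) * block_size + tail)

        max_len = max(t.numel() for t in new_tables)
        pruned = torch.stack([F.pad(t, (0, max_len - t.numel())) for t in new_tables])
        return flash_attn_with_kvcache(
            q.unsqueeze(1), k_cache, v_cache,
            cache_seqlens=torch.tensor(new_seqlens, dtype=torch.int32, device=q.device),
            block_table=pruned,
            softmax_scale=softmax_scale,
            causal=True,
        )
```

Because the GPU-only path already dispatches through policy.compute_decode, a class like this could be plugged in via the new sparse_policy parameter on GPUOnlyManager without further changes to attention.py.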

Performance verified: ~4890 tok/s (unchanged from baseline)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
Author: Zijie Tian
Date: 2026-01-27 05:08:02 +08:00
Parent: 05ce57ee8e
Commit: 09b2136e9f
7 changed files with 287 additions and 25 deletions


@@ -76,6 +76,75 @@ class FullAttentionPolicy(SparsePolicy):
        logger.info(f"[Full Policy] Density Stats: chunks={stats['num_chunks']}, "
                    f"blocks={stats['total_available_blocks']}, density=100.0%")

    # =========================================================================
    # GPU-only methods (non-chunked)
    # =========================================================================

    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        layer_id: int,
        block_tables: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        GPU-only prefill attention using flash_attn_varlen_func.

        This is the simplest implementation: just call flash attention directly.
        For sparse policies, this method would implement block selection.
        """
        from flash_attn import flash_attn_varlen_func

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=softmax_scale,
            causal=True,
            block_table=block_tables,
        )

    def compute_decode(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        cache_seqlens: torch.Tensor,
        softmax_scale: float,
        layer_id: int,
        block_tables: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        GPU-only decode attention using flash_attn_with_kvcache.

        This is the simplest implementation: just call flash attention directly.
        For sparse policies, this method would implement block selection.
        """
        from flash_attn import flash_attn_with_kvcache

        # q is [batch, num_heads, head_dim]; need to add a seq dim.
        return flash_attn_with_kvcache(
            q.unsqueeze(1),  # [batch, 1, heads, dim]
            k_cache,
            v_cache,
            cache_seqlens=cache_seqlens,
            block_table=block_tables,
            softmax_scale=softmax_scale,
            causal=True,
        )

    # =========================================================================
    # Chunked offload methods
    # =========================================================================

    def compute_chunked_prefill(
        self,
        q: torch.Tensor,

@@ -191,6 +191,87 @@ class SparsePolicy(ABC):
        """
        pass

    # =========================================================================
    # GPU-only methods (non-chunked)
    # These methods are used when all KV cache is on GPU, no CPU offload needed.
    # =========================================================================

    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        layer_id: int,
        block_tables: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute GPU-only prefill attention (non-chunked).

        This method is used when all KV cache resides on GPU (no CPU offload).
        Override this to implement sparse prefill attention for GPU-only mode.
        Default implementation raises NotImplementedError.

        Args:
            q: [total_q, num_heads, head_dim] query tensor (packed variable length)
            k: [total_kv, num_kv_heads, head_dim] key tensor
            v: [total_kv, num_kv_heads, head_dim] value tensor
            cu_seqlens_q: [batch+1] cumulative sequence lengths for queries
            cu_seqlens_k: [batch+1] cumulative sequence lengths for keys
            max_seqlen_q: maximum query sequence length
            max_seqlen_k: maximum key sequence length
            softmax_scale: softmax scaling factor (1/sqrt(head_dim))
            layer_id: transformer layer index
            block_tables: [batch, max_blocks] paged attention block tables (optional)

        Returns:
            [total_q, num_heads, head_dim] attention output
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement compute_prefill for GPU-only mode"
        )

    def compute_decode(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        cache_seqlens: torch.Tensor,
        softmax_scale: float,
        layer_id: int,
        block_tables: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute GPU-only decode attention (non-chunked).

        This method is used when all KV cache resides on GPU (no CPU offload).
        Override this to implement sparse decode attention for GPU-only mode.
        Default implementation raises NotImplementedError.

        Args:
            q: [batch, num_heads, head_dim] query tensor (single token per sequence)
            k_cache: [num_blocks, block_size, num_kv_heads, head_dim] paged key cache
            v_cache: [num_blocks, block_size, num_kv_heads, head_dim] paged value cache
            cache_seqlens: [batch] sequence lengths in cache
            softmax_scale: softmax scaling factor (1/sqrt(head_dim))
            layer_id: transformer layer index
            block_tables: [batch, max_blocks] paged attention block tables (optional)

        Returns:
            [batch, 1, num_heads, head_dim] attention output
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement compute_decode for GPU-only mode"
        )

    # =========================================================================
    # Chunked offload methods (for CPU offload mode)
    # =========================================================================

    @abstractmethod
    def compute_chunked_prefill(
        self,