✨ feat: integrate sparse policy architecture into GPU-only mode
- Add compute_prefill() and compute_decode() GPU-only methods to SparsePolicy base class
- Implement GPU-only methods in FullAttentionPolicy using flash_attn
- Add sparse_policy parameter to GPUOnlyManager
- Update create_kvcache_manager() to create FullAttentionPolicy for GPU-only mode
- Route GPU-only attention through sparse_policy in attention.py (see the sketch below)
- Pass kvcache_manager to context for policy access
- Add --enable-policy flag to bench.py for testing
- Handle warmup phase when kvcache_manager is not yet allocated

This allows GPU-only mode to use the same policy architecture as CPU offload mode, enabling future sparse attention implementations (Quest, XAttention) in GPU-only mode.

Performance verified: ~4890 tok/s (unchanged from baseline)

Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
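The routing and warmup handling called out above live in attention.py and are not part of the FullAttentionPolicy hunk shown below. A minimal sketch of how that dispatch could look follows; the names context.kvcache_manager, sparse_policy, cache_seqlens, and block_tables are assumptions taken from the bullet list, not the actual attention.py code.

# Sketch only: dispatch inferred from the commit message, not copied from attention.py.
from flash_attn import flash_attn_with_kvcache

def gpu_only_decode(q, k_cache, v_cache, context, layer_id, softmax_scale):
    manager = getattr(context, "kvcache_manager", None)  # assumed attribute passed via the context
    if manager is not None and manager.sparse_policy is not None:
        # Normal path: the policy decides how to attend over the paged cache.
        return manager.sparse_policy.compute_decode(
            q, k_cache, v_cache,
            cache_seqlens=context.cache_seqlens,
            softmax_scale=softmax_scale,
            layer_id=layer_id,
            block_tables=context.block_tables,
        )
    # Warmup path: kvcache_manager is not allocated yet, so call flash attention
    # directly (the same computation FullAttentionPolicy.compute_decode performs).
    return flash_attn_with_kvcache(
        q.unsqueeze(1),
        k_cache, v_cache,
        cache_seqlens=context.cache_seqlens,
        block_table=context.block_tables,
        softmax_scale=softmax_scale,
        causal=True,
    )

With FullAttentionPolicy installed, both branches compute the same dense attention, which is consistent with the unchanged ~4890 tok/s figure.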
@@ -76,6 +76,75 @@ class FullAttentionPolicy(SparsePolicy):
        logger.info(f"[Full Policy] Density Stats: chunks={stats['num_chunks']}, "
                    f"blocks={stats['total_available_blocks']}, density=100.0%")

    # =========================================================================
    # GPU-only methods (non-chunked)
    # =========================================================================

    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        layer_id: int,
        block_tables: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        GPU-only prefill attention using flash_attn_varlen_func.

        This is the simplest implementation - just call flash attention directly.
        For sparse policies, this method would implement block selection.
        """
        from flash_attn import flash_attn_varlen_func

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=softmax_scale,
            causal=True,
            block_table=block_tables,
        )

    def compute_decode(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        cache_seqlens: torch.Tensor,
        softmax_scale: float,
        layer_id: int,
        block_tables: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        GPU-only decode attention using flash_attn_with_kvcache.

        This is the simplest implementation - just call flash attention directly.
        For sparse policies, this method would implement block selection.
        """
        from flash_attn import flash_attn_with_kvcache

        # q is [batch, num_heads, head_dim], need to add seq dim
        return flash_attn_with_kvcache(
            q.unsqueeze(1),  # [batch, 1, heads, dim]
            k_cache,
            v_cache,
            cache_seqlens=cache_seqlens,
            block_table=block_tables,
            softmax_scale=softmax_scale,
            causal=True,
        )

    # =========================================================================
    # Chunked offload methods
    # =========================================================================

    def compute_chunked_prefill(
        self,
        q: torch.Tensor,
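As a standalone sanity check of the flash_attn_with_kvcache call used by compute_decode above, here is a small example with arbitrarily chosen shapes and a paged-cache layout; it assumes a CUDA device with flash-attn installed and is purely illustrative.

import torch
from flash_attn import flash_attn_with_kvcache

batch, heads, head_dim = 2, 8, 128
num_blocks, block_size = 16, 256  # paged KV cache: 8 blocks of 256 tokens per sequence

q = torch.randn(batch, heads, head_dim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(num_blocks, block_size, heads, head_dim, dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)
cache_seqlens = torch.tensor([300, 180], dtype=torch.int32, device="cuda")
block_tables = torch.arange(num_blocks, dtype=torch.int32, device="cuda").reshape(batch, -1)

out = flash_attn_with_kvcache(
    q.unsqueeze(1),                # [batch, 1, heads, head_dim], as in compute_decode
    k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    block_table=block_tables,
    softmax_scale=head_dim ** -0.5,
    causal=True,
)
print(out.shape)                   # torch.Size([2, 1, 8, 128])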