✨ feat: integrate sparse policy architecture into GPU-only mode
- Add compute_prefill() and compute_decode() GPU-only methods to SparsePolicy base class - Implement GPU-only methods in FullAttentionPolicy using flash_attn - Add sparse_policy parameter to GPUOnlyManager - Update create_kvcache_manager() to create FullAttentionPolicy for GPU-only mode - Route GPU-only attention through sparse_policy in attention.py - Pass kvcache_manager to context for policy access - Add --enable-policy flag to bench.py for testing - Handle warmup phase when kvcache_manager is not yet allocated This allows GPU-only mode to use the same policy architecture as CPU offload mode, enabling future sparse attention implementations (Quest, XAttention) in GPU-only mode. Performance verified: ~4890 tok/s (unchanged from baseline) Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
@@ -124,24 +124,47 @@ class Attention(nn.Module):
|
||||
if k_cache.numel() and v_cache.numel():
|
||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||
|
||||
# Get sparse_policy from kvcache_manager (required, never None after warmup)
|
||||
# During warmup, kvcache_manager is not yet allocated
|
||||
if context.kvcache_manager is None:
|
||||
# Warmup phase: use flash_attn directly
|
||||
if context.is_prefill:
|
||||
return flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True,
|
||||
)
|
||||
else:
|
||||
return flash_attn_with_kvcache(
|
||||
q.unsqueeze(1), k_cache, v_cache,
|
||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||
softmax_scale=self.scale, causal=True,
|
||||
)
|
||||
sparse_policy = context.kvcache_manager.sparse_policy
|
||||
assert sparse_policy is not None, "sparse_policy must not be None"
|
||||
|
||||
if context.is_prefill:
|
||||
if context.is_chunked_prefill:
|
||||
# Chunked prefill: merge attention from previous KV
|
||||
# Chunked prefill: merge attention from previous KV (CPU offload mode)
|
||||
o = self._chunked_prefill_attention(q, k, v, context)
|
||||
elif context.block_tables is not None: # prefix cache
|
||||
k, v = k_cache, v_cache
|
||||
o = flash_attn_varlen_func(q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||
else:
|
||||
o = flash_attn_varlen_func(q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||
# GPU-only mode: use policy for attention
|
||||
# Use paged attention if block_tables provided, else use k, v directly
|
||||
if context.block_tables is not None:
|
||||
k_for_attn, v_for_attn = k_cache, v_cache
|
||||
else:
|
||||
k_for_attn, v_for_attn = k, v
|
||||
o = sparse_policy.compute_prefill(
|
||||
q, k_for_attn, v_for_attn,
|
||||
context.cu_seqlens_q, context.cu_seqlens_k,
|
||||
context.max_seqlen_q, context.max_seqlen_k,
|
||||
self.scale, self.layer_id,
|
||||
context.block_tables,
|
||||
)
|
||||
else: # decode
|
||||
if context.is_chunked_prefill:
|
||||
# Chunked decode: need to load all KV from CPU+GPU
|
||||
# Chunked decode: need to load all KV from CPU+GPU (CPU offload mode)
|
||||
# Store current decode token to per-layer decode buffer
|
||||
# This is needed because GPU cache has no layer dimension,
|
||||
# so all layers would overwrite each other in decode_slot.
|
||||
@@ -152,9 +175,12 @@ class Attention(nn.Module):
|
||||
offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0))
|
||||
o = self._chunked_decode_attention(q, k, v, context)
|
||||
else:
|
||||
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||
softmax_scale=self.scale, causal=True)
|
||||
# GPU-only mode: use policy for attention
|
||||
o = sparse_policy.compute_decode(
|
||||
q, k_cache, v_cache,
|
||||
context.context_lens, self.scale, self.layer_id,
|
||||
context.block_tables,
|
||||
)
|
||||
return o
|
||||
|
||||
def _chunked_prefill_attention(
|
||||
|
||||
Reference in New Issue
Block a user