feat: integrate sparse policy architecture into GPU-only mode

- Add compute_prefill() and compute_decode() GPU-only methods to SparsePolicy base class
- Implement GPU-only methods in FullAttentionPolicy using flash_attn
- Add sparse_policy parameter to GPUOnlyManager
- Update create_kvcache_manager() to create FullAttentionPolicy for GPU-only mode
- Route GPU-only attention through sparse_policy in attention.py
- Pass kvcache_manager to context for policy access
- Add --enable-policy flag to bench.py for testing
- Handle warmup phase when kvcache_manager is not yet allocated

This allows GPU-only mode to use the same policy architecture as CPU offload mode,
enabling future sparse attention implementations (Quest, XAttention) in GPU-only mode.

Performance verified: ~4890 tok/s (unchanged from baseline)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-27 05:08:02 +08:00
parent 05ce57ee8e
commit 09b2136e9f
7 changed files with 287 additions and 25 deletions

View File

@@ -124,24 +124,47 @@ class Attention(nn.Module):
if k_cache.numel() and v_cache.numel():
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
# Get sparse_policy from kvcache_manager (required, never None after warmup)
# During warmup, kvcache_manager is not yet allocated
if context.kvcache_manager is None:
# Warmup phase: use flash_attn directly
if context.is_prefill:
return flash_attn_varlen_func(
q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True,
)
else:
return flash_attn_with_kvcache(
q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True,
)
sparse_policy = context.kvcache_manager.sparse_policy
assert sparse_policy is not None, "sparse_policy must not be None"
if context.is_prefill:
if context.is_chunked_prefill:
# Chunked prefill: merge attention from previous KV
# Chunked prefill: merge attention from previous KV (CPU offload mode)
o = self._chunked_prefill_attention(q, k, v, context)
elif context.block_tables is not None: # prefix cache
k, v = k_cache, v_cache
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
else:
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
# GPU-only mode: use policy for attention
# Use paged attention if block_tables provided, else use k, v directly
if context.block_tables is not None:
k_for_attn, v_for_attn = k_cache, v_cache
else:
k_for_attn, v_for_attn = k, v
o = sparse_policy.compute_prefill(
q, k_for_attn, v_for_attn,
context.cu_seqlens_q, context.cu_seqlens_k,
context.max_seqlen_q, context.max_seqlen_k,
self.scale, self.layer_id,
context.block_tables,
)
else: # decode
if context.is_chunked_prefill:
# Chunked decode: need to load all KV from CPU+GPU
# Chunked decode: need to load all KV from CPU+GPU (CPU offload mode)
# Store current decode token to per-layer decode buffer
# This is needed because GPU cache has no layer dimension,
# so all layers would overwrite each other in decode_slot.
@@ -152,9 +175,12 @@ class Attention(nn.Module):
offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0))
o = self._chunked_decode_attention(q, k, v, context)
else:
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
# GPU-only mode: use policy for attention
o = sparse_policy.compute_decode(
q, k_cache, v_cache,
context.context_lens, self.scale, self.layer_id,
context.block_tables,
)
return o
def _chunked_prefill_attention(