feat: integrate sparse policy architecture into GPU-only mode

- Add compute_prefill() and compute_decode() GPU-only methods to SparsePolicy base class
- Implement GPU-only methods in FullAttentionPolicy using flash_attn
- Add sparse_policy parameter to GPUOnlyManager
- Update create_kvcache_manager() to create FullAttentionPolicy for GPU-only mode
- Route GPU-only attention through sparse_policy in attention.py
- Pass kvcache_manager to context for policy access
- Add --enable-policy flag to bench.py for testing
- Handle warmup phase when kvcache_manager is not yet allocated

These changes let GPU-only mode use the same policy architecture as CPU offload mode,
enabling future sparse attention implementations (Quest, XAttention) to run in GPU-only mode.

Performance verified: ~4890 tok/s (unchanged from baseline)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-27 05:08:02 +08:00
parent 05ce57ee8e
commit 09b2136e9f
7 changed files with 287 additions and 25 deletions

View File

@@ -195,19 +195,23 @@ class ModelRunner:
dtype=hf_config.torch_dtype,
)
# Initialize sparse policy if manager has one (CPU offload mode)
# Initialize sparse policy if manager has one (works for both CPU offload and GPU-only modes)
if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
# Use CPU blocks for offload mode, GPU blocks for GPU-only mode
num_blocks_for_init = config.num_cpu_kvcache_blocks if config.enable_cpu_offload else config.num_kvcache_blocks
self.kvcache_manager.sparse_policy.initialize(
num_layers=hf_config.num_hidden_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
num_cpu_blocks=config.num_cpu_kvcache_blocks,
num_cpu_blocks=num_blocks_for_init,
dtype=hf_config.torch_dtype,
device=torch.device("cuda"),
)
# Log policy info (handle both enum and None cases)
policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL"
logger.info(
f"Sparse policy initialized: {config.sparse_policy.name} "
f"Sparse policy initialized: {policy_name} "
f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
)
@@ -368,7 +372,16 @@ class ModelRunner:
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
set_context(
is_prefill=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
slot_mapping=slot_mapping,
block_tables=block_tables,
kvcache_manager=getattr(self, 'kvcache_manager', None),
)
return input_ids, positions
def prepare_decode(self, seqs: list[Sequence]):
@@ -397,7 +410,13 @@ class ModelRunner:
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# Use GPU physical block tables for attention
block_tables = self._prepare_gpu_block_tables(gpu_block_tables)
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
set_context(
is_prefill=False,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
kvcache_manager=self.kvcache_manager,
)
return input_ids, positions
def _prepare_gpu_block_tables(self, gpu_block_tables: list[list[int]]):
@@ -698,7 +717,13 @@ class ModelRunner:
for bs in reversed(self.graph_bs):
graph = torch.cuda.CUDAGraph()
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
set_context(
is_prefill=False,
slot_mapping=slot_mapping[:bs],
context_lens=context_lens[:bs],
block_tables=block_tables[:bs],
kvcache_manager=self.kvcache_manager,
)
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
with torch.cuda.graph(graph, self.graph_pool):
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture