✨ feat: integrate sparse policy architecture into GPU-only mode
- Add compute_prefill() and compute_decode() GPU-only methods to SparsePolicy base class - Implement GPU-only methods in FullAttentionPolicy using flash_attn - Add sparse_policy parameter to GPUOnlyManager - Update create_kvcache_manager() to create FullAttentionPolicy for GPU-only mode - Route GPU-only attention through sparse_policy in attention.py - Pass kvcache_manager to context for policy access - Add --enable-policy flag to bench.py for testing - Handle warmup phase when kvcache_manager is not yet allocated This allows GPU-only mode to use the same policy architecture as CPU offload mode, enabling future sparse attention implementations (Quest, XAttention) in GPU-only mode. Performance verified: ~4890 tok/s (unchanged from baseline) Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
@@ -195,19 +195,23 @@ class ModelRunner:
|
||||
dtype=hf_config.torch_dtype,
|
||||
)
|
||||
|
||||
# Initialize sparse policy if manager has one (CPU offload mode)
|
||||
# Initialize sparse policy if manager has one (works for both CPU offload and GPU-only modes)
|
||||
if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
|
||||
# Use CPU blocks for offload mode, GPU blocks for GPU-only mode
|
||||
num_blocks_for_init = config.num_cpu_kvcache_blocks if config.enable_cpu_offload else config.num_kvcache_blocks
|
||||
self.kvcache_manager.sparse_policy.initialize(
|
||||
num_layers=hf_config.num_hidden_layers,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
num_cpu_blocks=config.num_cpu_kvcache_blocks,
|
||||
num_cpu_blocks=num_blocks_for_init,
|
||||
dtype=hf_config.torch_dtype,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
# Log policy info (handle both enum and None cases)
|
||||
policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL"
|
||||
logger.info(
|
||||
f"Sparse policy initialized: {config.sparse_policy.name} "
|
||||
f"Sparse policy initialized: {policy_name} "
|
||||
f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
|
||||
)
|
||||
|
||||
@@ -368,7 +372,16 @@ class ModelRunner:
|
||||
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
|
||||
set_context(
|
||||
is_prefill=True,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
slot_mapping=slot_mapping,
|
||||
block_tables=block_tables,
|
||||
kvcache_manager=getattr(self, 'kvcache_manager', None),
|
||||
)
|
||||
return input_ids, positions
|
||||
|
||||
def prepare_decode(self, seqs: list[Sequence]):
|
||||
@@ -397,7 +410,13 @@ class ModelRunner:
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
# Use GPU physical block tables for attention
|
||||
block_tables = self._prepare_gpu_block_tables(gpu_block_tables)
|
||||
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
|
||||
set_context(
|
||||
is_prefill=False,
|
||||
slot_mapping=slot_mapping,
|
||||
context_lens=context_lens,
|
||||
block_tables=block_tables,
|
||||
kvcache_manager=self.kvcache_manager,
|
||||
)
|
||||
return input_ids, positions
|
||||
|
||||
def _prepare_gpu_block_tables(self, gpu_block_tables: list[list[int]]):
|
||||
@@ -698,7 +717,13 @@ class ModelRunner:
|
||||
|
||||
for bs in reversed(self.graph_bs):
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
|
||||
set_context(
|
||||
is_prefill=False,
|
||||
slot_mapping=slot_mapping[:bs],
|
||||
context_lens=context_lens[:bs],
|
||||
block_tables=block_tables[:bs],
|
||||
kvcache_manager=self.kvcache_manager,
|
||||
)
|
||||
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
|
||||
with torch.cuda.graph(graph, self.graph_pool):
|
||||
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
|
||||
|
||||
Reference in New Issue
Block a user