🔀 merge: integrate tzj/minference-exp (GPU-only sparse attention)
Merge GPU-only sparse attention support from tzj/minference-exp branch: **GPU-only mode additions:** - Add compute_prefill/compute_decode methods to SparsePolicy base class - Add GPU-only attention routing in attention.py - Add alloc_policy_metadata() for pre-allocating GQA buffers - Add XAttention + BSA sparse attention for GPU-only prefill - Add kvcache_manager to set_context() for policy access **bench.py enhancements:** - Add --model argument for configurable model path - Add --policy argument (full, xattn) for sparse policy selection - Add --enable-policy flag for FullAttentionPolicy routing - Add --enforce-eager option to disable CUDA graphs - Add --gpu-util option for GPU memory utilization **Documentation:** - Add gpu_only_xattn_guide.md with performance analysis - Add gpu_only_sparse_integration.md baseline document - Add gpu-vram-requirement.md rule for GPU-only mode Both CPU offload and GPU-only paths are preserved and functional. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -202,19 +202,36 @@ class ModelRunner:
|
||||
dtype=hf_config.torch_dtype,
|
||||
)
|
||||
|
||||
# Initialize sparse policy if manager has one (CPU offload mode)
|
||||
# Initialize sparse policy if manager has one (works for both CPU offload and GPU-only modes)
|
||||
if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
|
||||
# Use CPU blocks for offload mode, GPU blocks for GPU-only mode
|
||||
num_blocks_for_init = config.num_cpu_kvcache_blocks if config.enable_cpu_offload else config.num_kvcache_blocks
|
||||
self.kvcache_manager.sparse_policy.initialize(
|
||||
num_layers=hf_config.num_hidden_layers,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
num_cpu_blocks=config.num_cpu_kvcache_blocks,
|
||||
num_cpu_blocks=num_blocks_for_init,
|
||||
dtype=hf_config.torch_dtype,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
# GPU-only mode: pre-allocate policy metadata buffers
|
||||
# This avoids dynamic GPU memory allocation during forward pass
|
||||
if not config.enable_cpu_offload:
|
||||
num_heads = hf_config.num_attention_heads // self.world_size
|
||||
self.kvcache_manager.sparse_policy.alloc_policy_metadata(
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
max_seq_len=config.max_model_len,
|
||||
dtype=hf_config.torch_dtype,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
# Log policy info (handle both enum and None cases)
|
||||
policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL"
|
||||
logger.info(
|
||||
f"Sparse policy initialized: {config.sparse_policy.name} "
|
||||
f"Sparse policy initialized: {policy_name} "
|
||||
f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
|
||||
)
|
||||
|
||||
@@ -375,7 +392,16 @@ class ModelRunner:
|
||||
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
|
||||
set_context(
|
||||
is_prefill=True,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
slot_mapping=slot_mapping,
|
||||
block_tables=block_tables,
|
||||
kvcache_manager=getattr(self, 'kvcache_manager', None),
|
||||
)
|
||||
return input_ids, positions
|
||||
|
||||
def prepare_decode(self, seqs: list[Sequence]):
|
||||
@@ -404,7 +430,13 @@ class ModelRunner:
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
# Use GPU physical block tables for attention
|
||||
block_tables = self._prepare_gpu_block_tables(gpu_block_tables)
|
||||
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
|
||||
set_context(
|
||||
is_prefill=False,
|
||||
slot_mapping=slot_mapping,
|
||||
context_lens=context_lens,
|
||||
block_tables=block_tables,
|
||||
kvcache_manager=self.kvcache_manager,
|
||||
)
|
||||
return input_ids, positions
|
||||
|
||||
def _prepare_gpu_block_tables(self, gpu_block_tables: list[list[int]]):
|
||||
@@ -713,7 +745,13 @@ class ModelRunner:
|
||||
|
||||
for bs in reversed(self.graph_bs):
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
|
||||
set_context(
|
||||
is_prefill=False,
|
||||
slot_mapping=slot_mapping[:bs],
|
||||
context_lens=context_lens[:bs],
|
||||
block_tables=block_tables[:bs],
|
||||
kvcache_manager=self.kvcache_manager,
|
||||
)
|
||||
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
|
||||
with torch.cuda.graph(graph, self.graph_pool):
|
||||
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
|
||||
|
||||
Reference in New Issue
Block a user