🔀 merge: integrate tzj/minference-exp (GPU-only sparse attention)

Merge GPU-only sparse attention support from tzj/minference-exp branch:

**GPU-only mode additions:**
- Add compute_prefill/compute_decode methods to SparsePolicy base class
- Add GPU-only attention routing in attention.py
- Add alloc_policy_metadata() for pre-allocating GQA buffers
- Add XAttention + BSA sparse attention for GPU-only prefill
- Add kvcache_manager to set_context() for policy access
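
The `compute_prefill`/`compute_decode` call sites in the attention.py hunk further down fix the argument order of the new hooks. A minimal sketch of the `SparsePolicy` interface under those assumptions (only the method names and argument order come from this commit; the abstract-base-class framing, type hints, and the `alloc_policy_metadata()` signature are guesses):

```python
# Sketch of the SparsePolicy interface implied by the attention.py call sites.
# Method names come from the commit message; argument order mirrors the calls
# in the diff. Type hints and the ABC framing are assumptions.
from abc import ABC, abstractmethod
from typing import Optional

import torch


class SparsePolicy(ABC):
    def alloc_policy_metadata(self) -> None:
        """Optionally pre-allocate per-layer scratch buffers (e.g. for GQA).

        The real signature is not shown in this excerpt.
        """

    @abstractmethod
    def compute_prefill(
        self,
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
        cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int, max_seqlen_k: int,
        scale: float, layer_id: int,
        block_tables: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Prefill attention for one layer; returns the attention output."""

    @abstractmethod
    def compute_decode(
        self,
        q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor,
        context_lens: torch.Tensor, scale: float, layer_id: int,
        block_tables: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Single-token decode attention against the paged KV cache."""
```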

**bench.py enhancements:**
- Add --model argument for configurable model path
- Add --policy argument (full, xattn) for sparse policy selection
- Add --enable-policy flag for FullAttentionPolicy routing
- Add --enforce-eager option to disable CUDA graphs
- Add --gpu-util option for GPU memory utilization
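
These flags map onto a standard argparse setup. A sketch assuming conventional defaults (the option names and the full/xattn choices come from the list above; defaults and help strings do not):

```python
# Sketch of the bench.py CLI additions. Option names and the full/xattn choices
# are taken from the commit message; defaults and help text are assumptions.
import argparse

parser = argparse.ArgumentParser(description="Sparse attention benchmark")
parser.add_argument("--model", type=str, required=True,
                    help="Path or name of the model to benchmark")
parser.add_argument("--policy", choices=["full", "xattn"], default="full",
                    help="Sparse attention policy for GPU-only mode")
parser.add_argument("--enable-policy", action="store_true",
                    help="Route attention through FullAttentionPolicy instead of direct flash_attn")
parser.add_argument("--enforce-eager", action="store_true",
                    help="Disable CUDA graph capture")
parser.add_argument("--gpu-util", type=float, default=0.9,
                    help="Fraction of GPU memory available for the KV cache")
args = parser.parse_args()
```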

**Documentation:**
- Add gpu_only_xattn_guide.md with performance analysis
- Add gpu_only_sparse_integration.md baseline document
- Add gpu-vram-requirement.md rule for GPU-only mode

Both CPU offload and GPU-only paths are preserved and functional.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Zijie Tian
2026-01-27 09:25:36 +08:00
14 changed files with 1228 additions and 27 deletions


@@ -124,24 +124,47 @@ class Attention(nn.Module):
        if k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
        # Get sparse_policy from kvcache_manager (required, never None after warmup)
        # During warmup, kvcache_manager is not yet allocated
        if context.kvcache_manager is None:
            # Warmup phase: use flash_attn directly
            if context.is_prefill:
                return flash_attn_varlen_func(
                    q, k, v,
                    max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                    max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                    softmax_scale=self.scale, causal=True,
                )
            else:
                return flash_attn_with_kvcache(
                    q.unsqueeze(1), k_cache, v_cache,
                    cache_seqlens=context.context_lens, block_table=context.block_tables,
                    softmax_scale=self.scale, causal=True,
                )
        sparse_policy = context.kvcache_manager.sparse_policy
        assert sparse_policy is not None, "sparse_policy must not be None"
        if context.is_prefill:
            if context.is_chunked_prefill:
                # Chunked prefill: merge attention from previous KV
                # Chunked prefill: merge attention from previous KV (CPU offload mode)
                o = self._chunked_prefill_attention(q, k, v, context)
            elif context.block_tables is not None:    # prefix cache
                k, v = k_cache, v_cache
                o = flash_attn_varlen_func(q, k, v,
                                           max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                           max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                           softmax_scale=self.scale, causal=True, block_table=context.block_tables)
            else:
                o = flash_attn_varlen_func(q, k, v,
                                           max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                           max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                           softmax_scale=self.scale, causal=True, block_table=context.block_tables)
                # GPU-only mode: use policy for attention
                # Use paged attention if block_tables provided, else use k, v directly
                if context.block_tables is not None:
                    k_for_attn, v_for_attn = k_cache, v_cache
                else:
                    k_for_attn, v_for_attn = k, v
                o = sparse_policy.compute_prefill(
                    q, k_for_attn, v_for_attn,
                    context.cu_seqlens_q, context.cu_seqlens_k,
                    context.max_seqlen_q, context.max_seqlen_k,
                    self.scale, self.layer_id,
                    context.block_tables,
                )
        else:    # decode
            if context.is_chunked_prefill:
                # Chunked decode: need to load all KV from CPU+GPU
                # Chunked decode: need to load all KV from CPU+GPU (CPU offload mode)
                # Store current decode token to per-layer decode buffer
                # This is needed because GPU cache has no layer dimension,
                # so all layers would overwrite each other in decode_slot.
@@ -152,9 +175,12 @@ class Attention(nn.Module):
                offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0))
                o = self._chunked_decode_attention(q, k, v, context)
            else:
                o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                            cache_seqlens=context.context_lens, block_table=context.block_tables,
                                            softmax_scale=self.scale, causal=True)
                # GPU-only mode: use policy for attention
                o = sparse_policy.compute_decode(
                    q, k_cache, v_cache,
                    context.context_lens, self.scale, self.layer_id,
                    context.block_tables,
                )
        return o

    def _chunked_prefill_attention(
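
For reference, a FullAttentionPolicy consistent with these call sites would wrap the same flash_attn calls that the diff removes from the direct path. A sketch under that assumption (not the repository's actual implementation):

```python
# Sketch: dense FullAttentionPolicy reproducing the pre-merge flash_attn behaviour
# behind the new policy hooks. Signatures mirror the compute_prefill/compute_decode
# call sites above; the body itself is an assumption. In the real code this would
# subclass the SparsePolicy base class.
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


class FullAttentionPolicy:
    def compute_prefill(self, q, k, v, cu_seqlens_q, cu_seqlens_k,
                        max_seqlen_q, max_seqlen_k, scale, layer_id, block_tables):
        # Dense varlen prefill; block_tables is None unless KV lives in the paged cache.
        return flash_attn_varlen_func(
            q, k, v,
            max_seqlen_q=max_seqlen_q, cu_seqlens_q=cu_seqlens_q,
            max_seqlen_k=max_seqlen_k, cu_seqlens_k=cu_seqlens_k,
            softmax_scale=scale, causal=True, block_table=block_tables,
        )

    def compute_decode(self, q, k_cache, v_cache, context_lens, scale, layer_id, block_tables):
        # One query token per sequence, attending over the paged KV cache.
        return flash_attn_with_kvcache(
            q.unsqueeze(1), k_cache, v_cache,
            cache_seqlens=context_lens, block_table=block_tables,
            softmax_scale=scale, causal=True,
        )
```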