🔀 merge: integrate tzj/minference-exp (GPU-only sparse attention)
Merge GPU-only sparse attention support from tzj/minference-exp branch: **GPU-only mode additions:** - Add compute_prefill/compute_decode methods to SparsePolicy base class - Add GPU-only attention routing in attention.py - Add alloc_policy_metadata() for pre-allocating GQA buffers - Add XAttention + BSA sparse attention for GPU-only prefill - Add kvcache_manager to set_context() for policy access **bench.py enhancements:** - Add --model argument for configurable model path - Add --policy argument (full, xattn) for sparse policy selection - Add --enable-policy flag for FullAttentionPolicy routing - Add --enforce-eager option to disable CUDA graphs - Add --gpu-util option for GPU memory utilization **Documentation:** - Add gpu_only_xattn_guide.md with performance analysis - Add gpu_only_sparse_integration.md baseline document - Add gpu-vram-requirement.md rule for GPU-only mode Both CPU offload and GPU-only paths are preserved and functional. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -124,24 +124,47 @@ class Attention(nn.Module):
|
||||
if k_cache.numel() and v_cache.numel():
|
||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||
|
||||
# Get sparse_policy from kvcache_manager (required, never None after warmup)
|
||||
# During warmup, kvcache_manager is not yet allocated
|
||||
if context.kvcache_manager is None:
|
||||
# Warmup phase: use flash_attn directly
|
||||
if context.is_prefill:
|
||||
return flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True,
|
||||
)
|
||||
else:
|
||||
return flash_attn_with_kvcache(
|
||||
q.unsqueeze(1), k_cache, v_cache,
|
||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||
softmax_scale=self.scale, causal=True,
|
||||
)
|
||||
sparse_policy = context.kvcache_manager.sparse_policy
|
||||
assert sparse_policy is not None, "sparse_policy must not be None"
|
||||
|
||||
if context.is_prefill:
|
||||
if context.is_chunked_prefill:
|
||||
# Chunked prefill: merge attention from previous KV
|
||||
# Chunked prefill: merge attention from previous KV (CPU offload mode)
|
||||
o = self._chunked_prefill_attention(q, k, v, context)
|
||||
elif context.block_tables is not None: # prefix cache
|
||||
k, v = k_cache, v_cache
|
||||
o = flash_attn_varlen_func(q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||
else:
|
||||
o = flash_attn_varlen_func(q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||
# GPU-only mode: use policy for attention
|
||||
# Use paged attention if block_tables provided, else use k, v directly
|
||||
if context.block_tables is not None:
|
||||
k_for_attn, v_for_attn = k_cache, v_cache
|
||||
else:
|
||||
k_for_attn, v_for_attn = k, v
|
||||
o = sparse_policy.compute_prefill(
|
||||
q, k_for_attn, v_for_attn,
|
||||
context.cu_seqlens_q, context.cu_seqlens_k,
|
||||
context.max_seqlen_q, context.max_seqlen_k,
|
||||
self.scale, self.layer_id,
|
||||
context.block_tables,
|
||||
)
|
||||
else: # decode
|
||||
if context.is_chunked_prefill:
|
||||
# Chunked decode: need to load all KV from CPU+GPU
|
||||
# Chunked decode: need to load all KV from CPU+GPU (CPU offload mode)
|
||||
# Store current decode token to per-layer decode buffer
|
||||
# This is needed because GPU cache has no layer dimension,
|
||||
# so all layers would overwrite each other in decode_slot.
|
||||
@@ -152,9 +175,12 @@ class Attention(nn.Module):
|
||||
offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0))
|
||||
o = self._chunked_decode_attention(q, k, v, context)
|
||||
else:
|
||||
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||
softmax_scale=self.scale, causal=True)
|
||||
# GPU-only mode: use policy for attention
|
||||
o = sparse_policy.compute_decode(
|
||||
q, k_cache, v_cache,
|
||||
context.context_lens, self.scale, self.layer_id,
|
||||
context.block_tables,
|
||||
)
|
||||
return o
|
||||
|
||||
def _chunked_prefill_attention(
|
||||
|
||||
Reference in New Issue
Block a user