⚡️ perf: replace Triton merge with FlashInfer merge_state
Use FlashInfer's optimized merge_state kernel for attention output merging in chunked prefill. End-to-end improvement: +0.8% (32K) to +2.4% (64K).

Key changes:
- Add merge_attention_outputs_flashinfer() with LSE format conversion
- FlashInfer uses log2 for LSE, flash_attn uses ln: convert via LOG2_E/LN_2
- Keep the original Triton kernel as a fallback

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -185,7 +185,11 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
Returns:
|
||||
Attention output [seq_len, num_heads, head_dim]
|
||||
"""
|
||||
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
# Use FlashInfer-based implementations (more optimized)
|
||||
from nanovllm.ops.chunked_attention import (
|
||||
flash_attn_with_lse_flashinfer as flash_attn_with_lse,
|
||||
merge_attention_outputs_flashinfer as merge_attention_outputs,
|
||||
)
|
||||
|
||||
logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
|
||||
f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}, "
|
||||
@@ -313,7 +317,11 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
Returns:
|
||||
Attention output [batch_size, 1, num_heads, head_dim]
|
||||
"""
|
||||
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
# Use FlashInfer-based implementations (more optimized)
|
||||
from nanovllm.ops.chunked_attention import (
|
||||
flash_attn_with_lse_flashinfer as flash_attn_with_lse,
|
||||
merge_attention_outputs_flashinfer as merge_attention_outputs,
|
||||
)
|
||||
|
||||
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
||||
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
||||
@@ -405,7 +413,11 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
Loads one block at a time, computes attention, and merges results.
|
||||
Uses load_to_slot_layer / wait_slot_layer / get_kv_for_slot methods.
|
||||
"""
|
||||
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
# Use FlashInfer-based implementations (more optimized)
|
||||
from nanovllm.ops.chunked_attention import (
|
||||
flash_attn_with_lse_flashinfer as flash_attn_with_lse,
|
||||
merge_attention_outputs_flashinfer as merge_attention_outputs,
|
||||
)
|
||||
|
||||
num_blocks = len(cpu_block_table)
|
||||
if num_blocks == 0:
|
||||
|
||||
Reference in New Issue
Block a user