[WIP] Needs refactoring.

This commit is contained in:
Zijie Tian
2026-01-22 22:20:34 +08:00
parent 69b779e252
commit 5fb0f67295
11 changed files with 514 additions and 548 deletions

View File

@@ -1,20 +1,21 @@
"""
Full attention policy - standard FlashAttention without sparsity.

This serves as a baseline and default policy when sparse
attention is not needed.
"""
from typing import Optional

import torch

from .policy import AttentionPolicy
class FullAttentionPolicy(AttentionPolicy):
    """
    Full attention policy using FlashAttention (no sparsity).

    This is the default behavior with standard causal attention.
    All tokens attend to all previous tokens.

    Use this as:
    - A baseline for comparing sparse policies
    """
    # NOTE(review): the diff omits unchanged lines here; the docstring may
    # continue with more usage notes in the full file — confirm against it.

    # Full attention supports both prefill and decode.
    supports_prefill = True
    supports_decode = True
def estimate(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    layer_id: int,
) -> Optional[torch.Tensor]:
    """
    Estimate a sparse-attention mask for this layer.

    Full attention never needs a mask, so this always returns ``None``,
    which signals the caller to run standard (dense) attention.

    Args:
        q: Query tensor.
        k: Key tensor.
        layer_id: Transformer layer index (unused by full attention).

    Returns:
        ``None`` — no sparse mask; use full attention.
    """
    return None
def compute_prefill(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
) -> torch.Tensor:
    """
    Run dense causal attention over a prefill chunk with FlashAttention.

    Args:
        q: Query tensor [seq_len, num_heads, head_dim]
        k: Key tensor [seq_len, num_kv_heads, head_dim]
        v: Value tensor [seq_len, num_kv_heads, head_dim]
        layer_id: Transformer layer index (unused by full attention)
        softmax_scale: Softmax scaling factor (1/sqrt(head_dim))

    Returns:
        Attention output [seq_len, num_heads, head_dim]
    """
    # Function-scope import keeps the module importable without flash-attn.
    from flash_attn.flash_attn_interface import flash_attn_varlen_func

    n_tokens = q.shape[0]
    # One sequence in the varlen batch: cumulative boundaries [0, n_tokens].
    boundaries = torch.tensor([0, n_tokens], dtype=torch.int32, device=q.device)
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=boundaries,
        cu_seqlens_k=boundaries,
        max_seqlen_q=n_tokens,
        max_seqlen_k=n_tokens,
        softmax_scale=softmax_scale,
        causal=True,
    )
def __repr__(self) -> str:
    """Return a stable, concise representation for logs and debugging."""
    return "FullAttentionPolicy()"