"""
|
|
Full attention policy - standard FlashAttention without sparsity.
|
|
|
|
This serves as a baseline and default policy when sparse
|
|
attention is not needed.
|
|
"""
|
|
|
|
from typing import Optional
|
|
import torch
|
|
from .policy import AttentionPolicy
|
|
|
|
|
|
class FullAttentionPolicy(AttentionPolicy):
    """
    Full attention policy using FlashAttention (no sparsity).

    This is the default behavior: standard causal attention in which
    every token attends to all previous tokens.

    Use this policy:
    - As a baseline for comparing sparse policies
    - When you need full-attention accuracy
    - For short sequences where sparsity isn't beneficial
    """

    # Full attention supports both prefill and decode
    supports_prefill = True
    supports_decode = True

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Full attention: no sparse mask is needed.

        Returns None to indicate full attention should be used.
        """
        return None

    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute full causal attention using FlashAttention.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor (1/sqrt(head_dim))

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        # The varlen interface packs sequences without a batch dimension;
        # cu_seqlens = [0, seq_len] describes a single packed sequence.
        seq_len = q.shape[0]
        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=seq_len,
            max_seqlen_k=seq_len,
            softmax_scale=softmax_scale,
            causal=True,
        )

    def __repr__(self) -> str:
        return "FullAttentionPolicy()"
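

# ---------------------------------------------------------------------------
# Usage sketch (added illustration, not part of the original module). It shows
# how a caller might exercise this policy: estimate() returning None signals
# that the dense FlashAttention path should be taken, and compute_prefill()
# runs the actual kernel. Shapes follow the docstrings above; the head counts,
# sequence length, dtype, the no-argument construction (as __repr__ suggests),
# and the CUDA / flash-attn availability checks are illustrative assumptions.
# Because of the relative import above, run this as a module inside its
# package (python -m <package>.<this_module>), not as a standalone script.
if __name__ == "__main__":
    import math

    policy = FullAttentionPolicy()
    print(policy)  # FullAttentionPolicy()

    seq_len, num_heads, head_dim = 128, 8, 64
    softmax_scale = 1.0 / math.sqrt(head_dim)

    # estimate() never builds a sparse mask for this policy: None means
    # "fall through to full attention".
    dummy = torch.empty(seq_len, num_heads, head_dim)
    assert policy.estimate(dummy, dummy, layer_id=0) is None

    # compute_prefill() needs a CUDA device and the flash-attn package.
    if torch.cuda.is_available():
        try:
            q = torch.randn(seq_len, num_heads, head_dim,
                            device="cuda", dtype=torch.float16)
            k = torch.randn_like(q)
            v = torch.randn_like(q)
            out = policy.compute_prefill(
                q, k, v, layer_id=0, softmax_scale=softmax_scale
            )
            print(out.shape)  # torch.Size([128, 8, 64])
        except ImportError:
            print("flash-attn not installed; skipping compute_prefill demo")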