""" Full attention policy - standard FlashAttention without sparsity. This serves as a baseline and default policy when sparse attention is not needed. """ from typing import Optional import torch from .policy import AttentionPolicy class FullAttentionPolicy(AttentionPolicy): """ Full attention policy using FlashAttention (no sparsity). This is the default behavior with standard causal attention. All tokens attend to all previous tokens. Use this as: - A baseline for comparing sparse policies - When you need full attention accuracy - For short sequences where sparsity isn't beneficial """ # Full attention supports both prefill and decode supports_prefill = True supports_decode = True def estimate( self, q: torch.Tensor, k: torch.Tensor, layer_id: int, ) -> Optional[torch.Tensor]: """ Full attention - no sparse mask needed. Returns None to indicate full attention should be used. """ return None def compute_prefill( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, softmax_scale: float, ) -> torch.Tensor: """ Compute full causal attention using FlashAttention. Args: q: Query tensor [seq_len, num_heads, head_dim] k: Key tensor [seq_len, num_kv_heads, head_dim] v: Value tensor [seq_len, num_kv_heads, head_dim] layer_id: Transformer layer index softmax_scale: Softmax scaling factor (1/sqrt(head_dim)) Returns: Attention output [seq_len, num_heads, head_dim] """ from flash_attn.flash_attn_interface import flash_attn_varlen_func seq_len = q.shape[0] cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device) return flash_attn_varlen_func( q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=seq_len, max_seqlen_k=seq_len, softmax_scale=softmax_scale, causal=True, ) def __repr__(self) -> str: return "FullAttentionPolicy()"