Files
nano-vllm/nanovllm/kvcache/sparse/full_policy.py
2026-01-22 22:20:34 +08:00

81 lines
2.2 KiB
Python

"""
Full attention policy - standard FlashAttention without sparsity.
This serves as a baseline and default policy when sparse
attention is not needed.
"""
from typing import Optional
import torch
from .policy import AttentionPolicy
class FullAttentionPolicy(AttentionPolicy):
    """
    Full attention policy using FlashAttention (no sparsity).

    This is the default behavior with standard causal attention.
    All tokens attend to all previous tokens.

    Use this as:
    - A baseline for comparing sparse policies
    - When you need full attention accuracy
    - For short sequences where sparsity isn't beneficial
    """

    # Full attention supports both prefill and decode
    supports_prefill = True
    supports_decode = True

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Full attention - no sparse mask needed.

        Returns None to indicate full attention should be used.
        """
        return None

    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute full causal attention using FlashAttention.

        Args:
            q: Query tensor [seq_len_q, num_heads, head_dim]
            k: Key tensor [seq_len_k, num_kv_heads, head_dim]
            v: Value tensor [seq_len_k, num_kv_heads, head_dim]
            layer_id: Transformer layer index (unused here; kept for a
                uniform policy interface)
            softmax_scale: Softmax scaling factor (1/sqrt(head_dim))

        Returns:
            Attention output [seq_len_q, num_heads, head_dim]
        """
        # Imported lazily so this module can be loaded on machines without
        # flash-attn installed (e.g. CPU-only tooling).
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        seqlen_q = q.shape[0]
        seqlen_k = k.shape[0]
        # Single-sequence "varlen" layout: one segment spanning each tensor.
        # Fix: size the K/V cu_seqlens and max_seqlen from k rather than
        # reusing q's length, so a longer key/value prefix (seqlen_k >
        # seqlen_q, e.g. cached context during chunked prefill) is attended
        # to instead of being silently truncated. Identical behavior when
        # seqlen_q == seqlen_k.
        cu_seqlens_q = torch.tensor(
            [0, seqlen_q], dtype=torch.int32, device=q.device
        )
        cu_seqlens_k = torch.tensor(
            [0, seqlen_k], dtype=torch.int32, device=k.device
        )
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=seqlen_q,
            max_seqlen_k=seqlen_k,
            softmax_scale=softmax_scale,
            causal=True,
        )

    def __repr__(self) -> str:
        return "FullAttentionPolicy()"