[claudesquad] update from 'layer-prefill-1' on 08 Jan 26 03:36 CST

This commit is contained in:
Zijie Tian
2026-01-08 03:36:39 +08:00
parent 6575099a06
commit d8a87da1c3
10 changed files with 822 additions and 32 deletions

View File

@@ -183,5 +183,32 @@ class SparsePolicy(ABC):
"""
pass
def sparse_prefill_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
) -> torch.Tensor:
"""
Compute sparse attention for prefill phase.
This method is called when supports_prefill=True and the policy
is used for GPU-only sparse prefill (no CPU offload).
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Current transformer layer index
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement sparse_prefill_attention. "
"Set supports_prefill=False or implement this method."
)
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"