""" Utility functions for sparse attention policies. Copied from COMPASS/compass/src/utils.py for XAttention integration. """ import torch def find_blocks_chunked( input_tensor, current_index, threshold, num_to_choose, decoding: bool, mode: str = "both", causal=True ): """ Finds and selects relevant blocks of attention for transformer-based models based on a threshold or a predefined number of blocks. Parameters: - input_tensor (torch.Tensor): The input tensor of shape (batch_size, head_num, chunk_num, block_num). - current_index (int): The current index in the sequence processing. - threshold (float or None): A threshold value used to determine the minimum attention weight sum. - num_to_choose (int or None): The number of blocks to be selected, ensuring sufficient information retrieval. - decoding (bool): If True, operates in decoding mode; otherwise, it's in encoding mode. - mode (str): Defines the processing mode, either 'both', 'prefill', or 'decode'. - causal (bool): If True, applies causal masking to prevent future information leakage. Returns: - torch.Tensor: A boolean mask of shape (batch_size, head_num, chunk_num, block_num), indicating which blocks should be attended to. """ assert threshold is None or num_to_choose is None batch_size, head_num, chunk_num, block_num = input_tensor.shape if mode == "prefill" and decoding: return torch.ones_like(input_tensor, dtype=torch.bool) if mode == "decode" and not decoding: mask = torch.ones_like(input_tensor, dtype=torch.bool) if causal: mask[:, :, :, current_index : current_index + chunk_num] = torch.tril( torch.ones(1, head_num, chunk_num, chunk_num, device=input_tensor.device) ) mask[:, :, current_index + chunk_num :, :] = 0 return torch.cat( [ torch.ones_like(input_tensor, dtype=torch.bool)[:, :, 0 : current_index + 1], torch.zeros_like(input_tensor, dtype=torch.bool)[:, :, current_index + 1 :], ], dim=-1, ) else: return mask input_tensor = input_tensor.to(float) if threshold is not None: total_sum = input_tensor.sum(dim=-1, keepdim=True) if isinstance(threshold, torch.Tensor): threshold = threshold.to(float) required_sum = total_sum * threshold.unsqueeze(0).unsqueeze(-1).unsqueeze( -1 ).expand((batch_size, head_num, chunk_num, 1)).to(input_tensor.device) else: required_sum = total_sum * threshold if causal: mask = torch.zeros_like(input_tensor, dtype=torch.bool) mask[:, :, :, 0] = 1 mask[:, :, :, current_index : current_index + chunk_num] = ( torch.eye(chunk_num, device=mask.device) .unsqueeze(0) .unsqueeze(0) .expand(1, head_num, chunk_num, chunk_num) ) other_values = input_tensor.masked_fill(mask, 0) sorted_values, _ = torch.sort( other_values, dim=-1, descending=True ) sorted_values = sorted_values.to(input_tensor.device) sorted_values = torch.cat( [ torch.zeros( (batch_size, head_num, chunk_num, 1), device=input_tensor.device ), torch.where(mask, input_tensor, 0).sum(dim=-1, keepdim=True), sorted_values[:, :, :, :-2], ], dim=-1, ) _, index = torch.sort( torch.where(mask, 100000 * (1 + input_tensor), input_tensor), dim=-1, descending=True ) cumulative_sum_without_self = torch.cat( [ torch.zeros( (batch_size, head_num, chunk_num, 1), device=input_tensor.device ), sorted_values[:, :, :, 0:-1], ], dim=-1, ).cumsum(dim=-1) index_mask = cumulative_sum_without_self < required_sum index = torch.where(index_mask, index, 0) mask = mask.view(batch_size, head_num * chunk_num, block_num) index = index.view(batch_size, head_num * chunk_num, block_num) mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True mask = 
        else:
            # Non-causal: no forced-keep set; take the highest-mass blocks until
            # the required mass is covered.
            mask = torch.zeros_like(input_tensor, dtype=torch.bool)
            sorted_values, index = torch.sort(input_tensor, dim=-1, descending=True)
            sorted_values = sorted_values.to(input_tensor.device)
            cumulative_sum_without_self = torch.cat(
                [
                    torch.zeros(
                        (batch_size, head_num, chunk_num, 1), device=input_tensor.device
                    ),
                    sorted_values[:, :, :, 0:-1],
                ],
                dim=-1,
            ).cumsum(dim=-1)
            index_mask = cumulative_sum_without_self < required_sum
            index = torch.where(index_mask, index, 0)

            mask = mask.view(batch_size, head_num * chunk_num, block_num)
            index = index.view(batch_size, head_num * chunk_num, block_num)
            mask[
                :,
                torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1),
                index,
            ] = True
            mask = mask.view(batch_size, head_num, chunk_num, block_num)
    else:
        # Fixed-budget (num_to_choose) selection is not supported yet.
        raise NotImplementedError("block num chunk prefill not implemented")

    # Enforce causality: no block beyond the current chunk window may be selected.
    try:
        if causal:
            assert (~mask[:, :, :, current_index + chunk_num :]).all()
    except AssertionError:
        mask[:, :, :, current_index + chunk_num :] = False

    if causal:
        if decoding:
            # When decoding, the first block and the current (last) block must be kept.
            assert mask[:, :, :, 0].all() and mask[:, :, :, -1].all()
        else:
            # During prefill, the first block and the diagonal block of each chunk
            # must be kept.
            lambda_mask = torch.zeros_like(
                input_tensor, dtype=torch.bool, device=input_tensor.device
            )
            lambda_mask[:, :, :, 0] = 1
            lambda_mask[:, :, :, current_index : current_index + chunk_num] = (
                torch.eye(chunk_num, device=lambda_mask.device)
                .unsqueeze(0)
                .unsqueeze(0)
                .expand(1, head_num, chunk_num, chunk_num)
            )
            assert torch.where(lambda_mask, mask, True).all()
    return mask
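

# Minimal usage sketch (illustrative only): the shapes, threshold, and
# current_index below are assumptions for a quick smoke test, not values taken
# from the XAttention pipeline. It exercises the threshold-based causal path
# during prefill and checks the shape and dtype of the returned block mask.
if __name__ == "__main__":
    batch_size, head_num, chunk_num, block_num = 1, 2, 4, 16
    # Assume the current chunk of queries covers the last `chunk_num` blocks.
    current_index = block_num - chunk_num
    # Non-negative per-block attention mass, one row per (batch, head, chunk).
    scores = torch.rand(batch_size, head_num, chunk_num, block_num)
    mask = find_blocks_chunked(
        scores,
        current_index=current_index,
        threshold=0.9,
        num_to_choose=None,
        decoding=False,
        mode="both",
        causal=True,
    )
    assert mask.shape == scores.shape and mask.dtype == torch.bool
    print("selected blocks per chunk:", mask.sum(dim=-1))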