cerebras.modelzoo.common.utils.model.transformer_utils.get_extended_attention_mask

cerebras.modelzoo.common.utils.model.transformer_utils.get_extended_attention_mask(attention_mask, input_shape=None, causal=False, device=None, dtype=None)

Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

Parameters

attention_mask (torch.Tensor): Mask with ones indicating tokens to attend to, zeros for tokens to ignore.

input_shape (Tuple[int]): The shape of the input to the model (required for causal masks).

causal (bool): If enabled, the returned mask will be causal.

device (torch.device): The device of the input to the model.

Returns

The extended attention mask, with the same dtype as attention_mask.

Return type

torch.Tensor
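
Example

The following is a minimal sketch of the general technique the description refers to (expanding a padding mask so it broadcasts over attention scores, and optionally combining it with a causal mask). It is not the modelzoo implementation: the function name sketch_extended_attention_mask is hypothetical, and the (batch, 1, query, key) layout, the additive convention (0 for attended positions, a large negative value for ignored ones), and the float32 fallback dtype are assumptions made for illustration only.

```python
import torch


def sketch_extended_attention_mask(attention_mask, causal=False):
    """Illustrative sketch only; NOT the modelzoo implementation.

    attention_mask: tensor of shape (batch, seq_len), 1 = attend, 0 = ignore.
    Returns an additive mask (assumed convention): 0.0 where attention is
    allowed, the most negative representable value where it is not.
    """
    batch_size, seq_len = attention_mask.shape

    # Add singleton head and query dimensions so the mask broadcasts
    # against scores of shape (batch, heads, query, key).
    keep = attention_mask[:, None, None, :].bool()

    if causal:
        # Lower-triangular matrix: query position i may see keys 0..i.
        causal_keep = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool,
                       device=attention_mask.device)
        )
        keep = keep & causal_keep[None, None, :, :]

    # Assumption: keep a floating-point input dtype, otherwise use float32.
    out_dtype = (attention_mask.dtype
                 if attention_mask.is_floating_point() else torch.float32)

    extended = torch.zeros(keep.shape, dtype=out_dtype,
                           device=attention_mask.device)
    return extended.masked_fill(~keep, torch.finfo(out_dtype).min)


# One sequence of length 4 whose last token is padding.
mask = torch.tensor([[1, 1, 1, 0]])
padding_only = sketch_extended_attention_mask(mask)           # shape (1, 1, 1, 4)
with_causal = sketch_extended_attention_mask(mask, causal=True)  # shape (1, 1, 4, 4)
```

In this sketch the result is meant to be added to the raw attention scores before the softmax, so masked positions receive a negligible weight. Whether the library function follows the same convention should be checked against its source.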