cerebras.modelzoo.losses.T5ForConditionalGenerationLoss.T5ForConditionalGenerationLoss#
- class cerebras.modelzoo.losses.T5ForConditionalGenerationLoss.T5ForConditionalGenerationLoss(*args, **kwargs)[source]#
Bases:
torch.nn.Module
Methods
forward — Per-token loss is averaged across the batch.
- forward(lm_logits, labels, decoder_attention_mask, loss_weight=None)[source]#
- Per-token loss is averaged across the batch by:
1. Summing across all tokens in the batch
2. Dividing by the batch size
3. Multiplying by the provided loss weight (expected to be roughly equal to batch_size / num_tokens_in_batch)
The loss weight can be specified once and reused for every batch (by setting self.global_loss_weight and not passing loss_weight to the forward function), or a different weight can be supplied per batch (by passing loss_weight to the forward function).
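The averaging scheme above can be sketched as follows. This is a minimal illustration, not the actual ModelZoo implementation; the function name `t5_loss` and the `global_loss_weight` parameter standing in for the class attribute are assumptions for the sake of a self-contained example.

```python
import torch
import torch.nn.functional as F

def t5_loss(lm_logits, labels, decoder_attention_mask,
            loss_weight=None, global_loss_weight=None):
    # Hypothetical sketch of the forward computation described above;
    # the real T5ForConditionalGenerationLoss may differ in details.
    vocab_size = lm_logits.shape[-1]
    # Per-token cross-entropy over the vocabulary dimension.
    per_token = F.cross_entropy(
        lm_logits.view(-1, vocab_size), labels.view(-1), reduction="none"
    )
    # Zero out padding positions, then sum across all tokens in the batch.
    masked = per_token * decoder_attention_mask.view(-1).to(per_token.dtype)
    total = masked.sum()
    # Divide by the batch size and scale by the loss weight
    # (roughly batch_size / num_tokens_in_batch).
    batch_size = labels.shape[0]
    weight = loss_weight if loss_weight is not None else global_loss_weight
    return total / batch_size * weight
```

When the weight is exactly batch_size / num_tokens_in_batch, this reduces to the plain mean cross-entropy over the non-padded tokens, which is why the weight is described as "roughly equal" to that ratio.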