cerebras.modelzoo.losses.T5ForConditionalGenerationLoss.T5ForConditionalGenerationLoss#

class cerebras.modelzoo.losses.T5ForConditionalGenerationLoss.T5ForConditionalGenerationLoss(*args, **kwargs)[source]#

Bases: torch.nn.Module

Methods

forward

Per-token loss is averaged across the batch.

forward(lm_logits, labels, decoder_attention_mask, loss_weight=None)[source]#
Per-token loss is averaged across the batch by:
  1. Summing across all tokens in the batch

  2. Dividing by the batch size

  3. Multiplying by the provided loss weight (expected to be roughly
     equal to batch_size / num_tokens_in_batch)

This loss weight can be specified once and reused for every batch (by setting self.global_loss_weight and not passing loss_weight to the forward function), or supplied per batch (by passing loss_weight to the forward function).
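For illustration, the following is a minimal PyTorch sketch of the documented weighting scheme. The class name T5LossSketch, the use of token-level cross-entropy, and the exact masking details are assumptions made for the example and not the Model Zoo implementation:

import torch
import torch.nn as nn


class T5LossSketch(nn.Module):
    """Illustrative sketch of the documented loss behavior (not the real class)."""

    def __init__(self, global_loss_weight=None):
        super().__init__()
        # Optional fixed weight reused for every batch when the caller
        # does not pass loss_weight to forward().
        self.global_loss_weight = global_loss_weight

    def forward(self, lm_logits, labels, decoder_attention_mask, loss_weight=None):
        # Per-token cross-entropy over the vocabulary dimension (assumed here).
        per_token = nn.functional.cross_entropy(
            lm_logits.view(-1, lm_logits.size(-1)),
            labels.view(-1),
            reduction="none",
        )
        # Zero out padded positions using the decoder attention mask.
        per_token = per_token * decoder_attention_mask.view(-1).to(per_token.dtype)

        batch_size = labels.size(0)
        # Step 1: sum across all tokens; step 2: divide by the batch size.
        loss = per_token.sum() / batch_size
        # Step 3: multiply by the loss weight (~ batch_size / num_tokens_in_batch);
        # a per-batch loss_weight takes precedence over the global weight.
        weight = loss_weight if loss_weight is not None else self.global_loss_weight
        if weight is not None:
            loss = loss * weight
        return loss


# Usage with dummy shapes (batch_size=2, seq_len=5, vocab=32):
logits = torch.randn(2, 5, 32)
labels = torch.randint(0, 32, (2, 5))
mask = torch.ones(2, 5)
loss_fn = T5LossSketch()
print(loss_fn(logits, labels, mask, loss_weight=2 / 10))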