cerebras.modelzoo.losses.dpr_loss.DPRLoss
- class cerebras.modelzoo.losses.dpr_loss.DPRLoss(*args, **kwargs)
Bases: torch.nn.Module
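This page does not state the loss formula. Assuming DPRLoss uses the standard in-batch softmax cross-entropy of dense passage retrieval (a sketch, not a confirmed description of the implementation; the class may add options such as temperature scaling), let B = batch_size, N = num_context, S = q2c_scores, y = labels, C = c2q_scores, and z = context_labels, where y_i and z_i index the column of the correct entry in row i. The two terms would then be:

$$
\mathcal{L}_{\text{q2c}} = -\frac{1}{B}\sum_{i=1}^{B}\log\frac{\exp\left(S_{i,y_i}\right)}{\sum_{j=1}^{BN}\exp\left(S_{i,j}\right)},
\qquad
\mathcal{L}_{\text{c2q}} = -\frac{1}{B}\sum_{i=1}^{B}\log\frac{\exp\left(C_{i,z_i}\right)}{\sum_{j=1}^{B}\exp\left(C_{i,j}\right)}
$$

The total would be the q2c term alone, or the q2c term plus the c2q term when the optional context-to-question scores are supplied. Combining them as an unweighted sum is itself an assumption.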
Methods
- forward(q2c_scores, labels, c2q_scores=None, context_labels=None)
Args:
  - q2c_scores: question-to-context scores, shape (batch_size, batch_size * num_context)
  - labels: labels for the question-to-context scores, shape (batch_size,)
Optional Args:
  - c2q_scores: context-to-question scores, shape (batch_size, batch_size)
  - context_labels: labels for the context-to-question scores, shape (batch_size,)
Please see the comment in dpr_model.py for more details.
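To make the argument shapes concrete, here is a framework-free sketch in plain Python (no torch) that builds inputs with the documented shapes and scores them with softmax cross-entropy. Everything in it is an assumption for illustration: the helper names dot_scores, q2c_labels, cross_entropy, and dpr_style_loss are hypothetical and not part of cerebras.modelzoo. The sketch also assumes three things: each question's positive context is the first of its num_context entries, the c2q matrix compares each question's positive context against all questions, and the two terms are added without weights. The actual conventions are described in dpr_model.py.

```python
import math
from typing import List, Optional, Sequence

Matrix = List[List[float]]


def dot_scores(rows: Sequence[Sequence[float]], cols: Sequence[Sequence[float]]) -> Matrix:
    """Score every embedding in `rows` against every embedding in `cols` by dot product."""
    return [[sum(a * b for a, b in zip(u, v)) for v in cols] for u in rows]


def q2c_labels(batch_size: int, num_context: int) -> List[int]:
    """Hypothetical: column of each question's positive, assumed first in its context group."""
    return [i * num_context for i in range(batch_size)]


def cross_entropy(scores: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean softmax cross-entropy; each row of `scores` holds one example's logits."""
    total = 0.0
    for row, label in zip(scores, labels):
        m = max(row)  # subtract the row max so exp() cannot overflow
        log_z = m + math.log(sum(math.exp(s - m) for s in row))
        total += log_z - row[label]
    return total / len(labels)


def dpr_style_loss(
    q2c_scores: Matrix,
    labels: List[int],
    c2q_scores: Optional[Matrix] = None,
    context_labels: Optional[List[int]] = None,
) -> float:
    """Hypothetical DPR-style loss: q2c term, plus an unweighted c2q term if given."""
    loss = cross_entropy(q2c_scores, labels)
    if c2q_scores is not None:
        if context_labels is None:
            raise ValueError("context_labels is required when c2q_scores is given")
        loss += cross_entropy(c2q_scores, context_labels)
    return loss


# Toy batch: batch_size=2 questions, num_context=2 (one positive, one hard negative each).
batch_size, num_context = 2, 2
questions = [[1.0, 0.0], [0.0, 1.0]]
contexts = [
    [1.0, 0.0], [0.5, 0.5],  # question 0: positive, hard negative
    [0.0, 1.0], [0.5, 0.5],  # question 1: positive, hard negative
]
q2c = dot_scores(questions, contexts)         # (batch_size, batch_size * num_context)
labels = q2c_labels(batch_size, num_context)  # positive column per question

positives = [contexts[y] for y in labels]
c2q = dot_scores(positives, questions)        # (batch_size, batch_size)
context_labels = list(range(batch_size))      # context i belongs to question i

print(round(dpr_style_loss(q2c, labels), 4))
print(round(dpr_style_loss(q2c, labels, c2q, context_labels), 4))
```

In PyTorch, each term corresponds to torch.nn.functional.cross_entropy(scores, labels), which takes a (batch, classes) logits tensor with integer class targets and averages over the batch by default. The plain-Python version only spells out that arithmetic so the role of each shape is visible.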