cerebras.modelzoo.losses.dpr_loss.DPRLoss

class cerebras.modelzoo.losses.dpr_loss.DPRLoss(*args, **kwargs)

Bases: torch.nn.Module

Methods

forward
    Args: q2c_scores: question-to-context scores, shape (batch_size, batch_size * num_context); labels: labels for question-to-context scores, shape (batch_size,).

forward(q2c_scores, labels, c2q_scores=None, context_labels=None)

Args:
    q2c_scores: question-to-context scores, shape (batch_size, batch_size * num_context)
    labels: labels for question-to-context scores, shape (batch_size,)

Optional Args:
    c2q_scores: context-to-question scores, shape (batch_size, batch_size)
    context_labels: labels for context-to-question scores, shape (batch_size,)

Please see the comment in dpr_model.py for more details.
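
To make the expected tensor shapes concrete, here is a minimal sketch of a DPR-style in-batch contrastive loss that accepts the same four arguments as forward. It is not the Model Zoo implementation: the function name dpr_style_loss, the use of cross-entropy, the averaging of the two directions, and the placement of each question's positive context at column i * num_context are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def dpr_style_loss(q2c_scores, labels, c2q_scores=None, context_labels=None):
    """Hypothetical stand-in for DPRLoss.forward; not the Model Zoo code."""
    # Cross-entropy over each question's row of context scores.
    loss = F.cross_entropy(q2c_scores, labels)
    if c2q_scores is not None and context_labels is not None:
        # Assumption: the two directions are averaged when both are given.
        loss = 0.5 * (loss + F.cross_entropy(c2q_scores, context_labels))
    return loss


batch_size, num_context, hidden = 4, 2, 8
torch.manual_seed(0)
questions = torch.randn(batch_size, hidden)
contexts = torch.randn(batch_size * num_context, hidden)

# Shape (batch_size, batch_size * num_context).
q2c_scores = questions @ contexts.T
# Assumption: question i's positive context sits at column i * num_context.
labels = torch.arange(batch_size) * num_context

# Shape (batch_size, batch_size): positive contexts scored against every question.
c2q_scores = contexts[labels] @ questions.T
context_labels = torch.arange(batch_size)

loss_q2c = dpr_style_loss(q2c_scores, labels)
loss_both = dpr_style_loss(q2c_scores, labels, c2q_scores, context_labels)
```

Both calls return a scalar tensor. Passing only q2c_scores and labels exercises the required arguments, while the second call shows how the optional context-to-question inputs would line up under the assumptions above.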