cerebras.modelzoo.losses.LoadBalancingLoss.LoadBalancingLoss

class cerebras.modelzoo.losses.LoadBalancingLoss.LoadBalancingLoss(*args, **kwargs)

Bases: torch.nn.Module

Methods

forward

Takes router_weights_list and expert_mask_list, each a list with one [batch_size, seq_len, experts] tensor per hidden layer.

forward(router_weights_list, expert_mask_list, attention_mask=None)

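router_weights_list: list with one entry per hidden layer; each entry is a tensor of shape [batch_size, seq_len, experts].

expert_mask_list: list with one entry per hidden layer; each entry is a tensor of shape [batch_size, seq_len, experts].

attention_mask: optional; defaults to None.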
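The page does not state the formula the class computes, so the following is only a minimal sketch of how inputs with these shapes could feed a load-balancing loss. It assumes a Switch-Transformer-style auxiliary term (number of experts times the sum over experts of the dispatched-token fraction times the mean router probability). The function name `load_balancing_loss_sketch` and the `[batch_size, seq_len]` layout of `attention_mask` are assumptions made for illustration, not the Cerebras implementation.

```python
import torch


def load_balancing_loss_sketch(router_weights_list, expert_mask_list, attention_mask=None):
    """Hypothetical Switch-style auxiliary loss; NOT the Cerebras implementation."""
    # Stack the per-layer tensors: [num_layers, batch_size, seq_len, experts].
    router_weights = torch.stack(router_weights_list).float()
    expert_mask = torch.stack(expert_mask_list).float()
    num_layers = router_weights.shape[0]
    num_experts = router_weights.shape[-1]

    if attention_mask is None:
        # No mask given: treat every token as valid.
        token_weight = torch.ones_like(router_weights[..., :1])
    else:
        # Assumed layout: [batch_size, seq_len], 1 for real tokens, 0 for padding.
        token_weight = attention_mask.float()[None, :, :, None].expand(
            num_layers, -1, -1, 1
        )

    num_tokens = token_weight.sum().clamp(min=1.0)
    # Fraction of valid tokens dispatched to each expert.
    tokens_per_expert = (expert_mask * token_weight).sum(dim=(0, 1, 2)) / num_tokens
    # Mean router probability assigned to each expert over valid tokens.
    prob_per_expert = (router_weights * token_weight).sum(dim=(0, 1, 2)) / num_tokens
    return num_experts * torch.sum(tokens_per_expert * prob_per_expert)


# Usage: two hidden layers, uniform router probabilities, evenly spread top-1 routing.
num_layers, batch_size, seq_len, experts = 2, 2, 4, 4
uniform = torch.full((batch_size, seq_len, experts), 1.0 / experts)
balanced_mask = torch.nn.functional.one_hot(
    torch.arange(seq_len) % experts, experts
).expand(batch_size, -1, -1)
balanced_loss = load_balancing_loss_sketch(
    [uniform] * num_layers, [balanced_mask] * num_layers
)
```

Under this assumed formula, perfectly balanced routing gives a loss of 1.0, and the value grows toward the number of experts as routing collapses onto a single expert.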