cerebras.modelzoo.losses.LoadBalancingLoss.LoadBalancingLoss
- class cerebras.modelzoo.losses.LoadBalancingLoss.LoadBalancingLoss(*args, **kwargs)
Bases: torch.nn.Module
Methods
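router_weights: Num hidden layers * [[batch_size, seq_len, experts]], i.e. one tensor of shape [batch_size, seq_len, experts] per hidden layer.
expert_mask: Num hidden layers * [[batch_size, seq_len, experts]], with the same per-layer shape as router_weights.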
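The page only documents the input shapes. The sketch below shows one way inputs of those shapes can be combined into an auxiliary load-balancing loss. It is an illustration only and is not the Cerebras implementation. Several assumptions are made here. The formula is the common Switch Transformer style loss, num_experts * sum_e f_e * P_e, where f_e is the fraction of tokens routed to expert e and P_e is the mean router probability for expert e. router_weights is taken to hold post-softmax probabilities. expert_mask is taken to be a one-hot routing mask. Per-layer losses are averaged. The function name `switch_style_load_balancing_loss` is hypothetical.

```python
import torch
from typing import List


def switch_style_load_balancing_loss(
    router_weights: List[torch.Tensor],
    expert_mask: List[torch.Tensor],
) -> torch.Tensor:
    """Hypothetical Switch-Transformer-style auxiliary loss (assumed formula).

    router_weights: one tensor per hidden layer, shape [batch_size, seq_len, experts],
        assumed to contain router probabilities (softmax over the experts axis).
    expert_mask: one tensor per hidden layer, shape [batch_size, seq_len, experts],
        assumed to be a one-hot mask of the expert each token was routed to.
    """
    if len(router_weights) != len(expert_mask):
        raise ValueError("router_weights and expert_mask must have one entry per layer")
    num_experts = router_weights[0].shape[-1]

    # Stack per-layer tensors: [num_layers, batch_size, seq_len, experts]
    probs = torch.stack(router_weights)
    mask = torch.stack(expert_mask).to(probs.dtype)

    # f_e: fraction of tokens routed to each expert, per layer -> [num_layers, experts]
    tokens_per_expert = mask.mean(dim=(1, 2))
    # P_e: mean router probability given to each expert, per layer -> [num_layers, experts]
    prob_per_expert = probs.mean(dim=(1, 2))

    # num_experts * sum_e f_e * P_e for each layer, then averaged over layers (assumption)
    per_layer = num_experts * (tokens_per_expert * prob_per_expert).sum(dim=-1)
    return per_layer.mean()
```

Under this formulation, uniform router probabilities with perfectly balanced routing give a loss of 1. Sending every token to a single expert with full confidence gives a loss equal to the number of experts. The gradient therefore pushes the router toward spreading tokens more evenly.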