cerebras.modelzoo.common.run_cstorch_flow.run_cstorch_train

cerebras.modelzoo.common.run_cstorch_flow.run_cstorch_train(params, params_obj, model_fn, input_fn, cluster_config, artifact_dir)

Runs the training workflow built using the cstorch API.

Parameters
  • params – The params dictionary extracted from the params.yaml file used for the run.

  • model_fn – A callable that takes in the params dictionary and returns a torch.nn.Module

  • input_fn – A callable that takes in the params dictionary and returns a torch.utils.data.DataLoader
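The two callables described above share one contract: each receives the params dictionary and builds a PyTorch object from it. The sketch below shows one way to satisfy that contract. It is illustrative only. The params layout (the "model" and "train_input" keys and their fields) is a hypothetical assumption, not the schema of any real params.yaml. The toy linear model and random dataset stand in for a real model and data pipeline. The example builds the objects locally and does not call run_cstorch_train.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical params layout: these keys are illustrative assumptions,
# not the schema of any particular params.yaml.
params = {
    "model": {"in_features": 8, "num_classes": 2},
    "train_input": {"batch_size": 4, "num_samples": 16},
}


def model_fn(params):
    """Build a torch.nn.Module from the params dictionary (toy example)."""
    cfg = params["model"]
    return torch.nn.Linear(cfg["in_features"], cfg["num_classes"])


def input_fn(params):
    """Build a torch.utils.data.DataLoader from the params dictionary (toy example)."""
    model_cfg = params["model"]
    data_cfg = params["train_input"]
    n = data_cfg["num_samples"]
    # Random features and integer class labels stand in for a real dataset.
    features = torch.randn(n, model_cfg["in_features"])
    labels = torch.randint(0, model_cfg["num_classes"], (n,))
    dataset = TensorDataset(features, labels)
    return DataLoader(
        dataset, batch_size=data_cfg["batch_size"], shuffle=True, drop_last=True
    )


model = model_fn(params)
loader = input_fn(params)
x, y = next(iter(loader))
print(type(model).__name__, tuple(x.shape), tuple(y.shape))
```

In this sketch, all sizes come from params and none are hard-coded. As a result, the same callables work with any configuration that supplies these (hypothetical) keys.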