cerebras.modelzoo.common.run_cstorch_flow.run_cstorch_flow
cerebras.modelzoo.common.run_cstorch_flow.run_cstorch_flow(params, params_obj, model_fn, train_data_fn, eval_data_fn)
Sets up the cstorch (cerebras.pytorch) run and calls the appropriate helper for the run mode specified in the params.
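Dispatching on the mode usually means looking up a handler keyed by a mode string. The sketch below shows only that pattern. It is not the ModelZoo implementation. The helper functions are hypothetical, and so is the assumption that the mode is stored under `params["runconfig"]["mode"]`.

```python
from typing import Any, Callable, Dict

# Hypothetical helpers standing in for the real train/eval routines.
def run_train(params: Dict[str, Any]) -> str:
    return "train"

def run_eval(params: Dict[str, Any]) -> str:
    return "eval"

def run_train_and_eval(params: Dict[str, Any]) -> str:
    return "train_and_eval"

# Map each supported mode string to the helper that handles it.
HELPERS: Dict[str, Callable[[Dict[str, Any]], str]] = {
    "train": run_train,
    "eval": run_eval,
    "train_and_eval": run_train_and_eval,
}

def dispatch(params: Dict[str, Any]) -> str:
    # Assumption: the mode lives under params["runconfig"]["mode"].
    mode = params["runconfig"]["mode"]
    try:
        helper = HELPERS[mode]
    except KeyError:
        raise ValueError(f"Unsupported mode: {mode!r}") from None
    return helper(params)
```

With a dictionary lookup, adding a new mode only requires registering one more entry. An unknown mode fails immediately with a clear error instead of silently falling through.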
Parameters
params – The params dictionary loaded from the params.yaml file used for the run.
params_obj – A config object constructed from the params dictionary.
model_fn – A callable that takes the params dictionary and returns a torch.nn.Module.
train_data_fn – A callable that takes the params dictionary and returns a torch.utils.data.DataLoader for training.
eval_data_fn – A callable that takes the params dictionary and returns a torch.utils.data.DataLoader for evaluation.
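The three callables only need to satisfy the contracts listed above: each takes the params dictionary and returns a module or a DataLoader. The sketch below shows callables that meet those contracts. The keys `model.hidden_size`, `train_input`, `eval_input`, `num_samples`, and `batch_size` are hypothetical, and the random tensors are synthetic placeholders rather than a real dataset.

```python
from typing import Any, Dict

import torch
from torch.utils.data import DataLoader, TensorDataset

def model_fn(params: Dict[str, Any]) -> torch.nn.Module:
    # Hypothetical key: hidden layer width under params["model"].
    hidden = params["model"]["hidden_size"]
    return torch.nn.Sequential(
        torch.nn.Linear(4, hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden, 2),
    )

def _make_loader(params: Dict[str, Any], split: str, shuffle: bool) -> DataLoader:
    # Hypothetical per-split config section, e.g. params["train_input"].
    cfg = params[split]
    n = cfg["num_samples"]
    # Synthetic random features and binary labels, for illustration only.
    features = torch.randn(n, 4)
    labels = torch.randint(0, 2, (n,))
    dataset = TensorDataset(features, labels)
    return DataLoader(dataset, batch_size=cfg["batch_size"], shuffle=shuffle)

def train_data_fn(params: Dict[str, Any]) -> DataLoader:
    return _make_loader(params, "train_input", shuffle=True)

def eval_data_fn(params: Dict[str, Any]) -> DataLoader:
    return _make_loader(params, "eval_input", shuffle=False)
```

Each function builds its object from the params dictionary alone. The flow can therefore construct the model and data pipelines lazily, when the selected mode needs them.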