Source code for cerebras.modelzoo.common.run_utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for running Cerebras Pytorch Models."""

import argparse
import inspect
import logging
import os
import subprocess
import sys
from shutil import which
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import yaml

from cerebras.appliance.log import (
    collect_wsc_log_settings,
    get_level_name,
    wsc_logger,
)
from cerebras.modelzoo.common.pytorch_utils import RunConfigParamsValidator
from cerebras.modelzoo.common.registry import registry
from cerebras.modelzoo.common.run_cstorch_flow import run_cstorch_flow
from cerebras.modelzoo.common.utils.run.cli_parser import get_params_from_args
from cerebras.modelzoo.common.utils.run.utils import DeviceType
from cerebras.modelzoo.config_manager.config_loader import (
    convert_to_dict,
    create_config_obj_from_params,
)

DATA_FN_TYPE = Callable[[dict], torch.utils.data.DataLoader]


[docs]def torchrun(filename: str, arguments: List[str]): """Starts a distributed GPU run using torchrun.""" torchrun_cmd = [ which("torchrun", path=os.path.dirname(sys.executable)), "--nnodes=1", f"--nproc_per_node={torch.cuda.device_count()}", filename, *arguments, ] try: print( f"Starting distributed GPU run using torchrun:\n" f"{' '.join(torchrun_cmd)}" ) subprocess.run(torchrun_cmd, check=True) except Exception as e: raise RuntimeError( f"Failed to spawn distributed GPU run using torchrun" ) from e
[docs]def run( model_fn: Union[Callable[[dict], torch.nn.Module], str], train_data_fn: Optional[DATA_FN_TYPE] = None, eval_data_fn: Optional[DATA_FN_TYPE] = None, default_params_fn: Optional[Callable[[dict], dict]] = None, extra_args_parser_fn: Optional[ Callable[[], List[argparse.ArgumentParser]] ] = None, ): """ Entry point to running pytorch models including CLI argument parsing. """ parent = inspect.getouterframes(inspect.currentframe())[1] params = get_params_from_args(extra_args_parser_fn) if default_params_fn: params = default_params_fn(params) or params main(params, model_fn, train_data_fn, eval_data_fn, script=parent.filename)
[docs]def main( params: Dict[str, Any], model_fn: Union[Callable[[dict], torch.nn.Module], str], train_data_fn: Optional[DATA_FN_TYPE] = None, eval_data_fn: Optional[DATA_FN_TYPE] = None, script: Optional[str] = None, extra_args_parser_fn: Optional[ Callable[[], List[argparse.ArgumentParser]] ] = None, ): """Entry point to running pytorch models.""" if not script: parent = inspect.getouterframes(inspect.currentframe())[1] script = parent.filename if ( "runconfig" in params # If using distributed GPU with experimental API and params["runconfig"]["target_device"] == DeviceType.GPU and params["runconfig"].get("enable_distributed", False) # If this is already set, we've already launched distributed training and os.environ.get("LOCAL_RANK") is None ): # use torchrun to launch distributed training torchrun(script, sys.argv[1:]) return None return run_with_params( params, model_fn, train_data_fn, eval_data_fn, extra_args_parser_fn=extra_args_parser_fn, )
[docs]def run_with_params( params: Dict[str, Any], model_fn: Union[Callable[[dict], torch.nn.Module], str], train_data_fn: Optional[DATA_FN_TYPE] = None, eval_data_fn: Optional[DATA_FN_TYPE] = None, extra_args_parser_fn: Optional[ Callable[[], List[argparse.ArgumentParser]] ] = None, ): """ Runs a full end-to-end CS/non-CS workflow for a given model. Args: params: The parsed YAML config dictionary. model_fn: A callable that takes in a 'params' argument which it uses to configure and return a torch.nn.Module train_data_fn: A callable that takes in a 'params' argument which it uses to configure and return a PyTorch dataloader corresponding to the training dataset eval_data_fn: A callable that takes in a 'params' argument which it uses to configure and return a PyTorch dataloader corresponding to the evaluation dataset extra_args_parser_fn: An optional callable that adds any extra parser args not covered in `get_parser` fn. """ runconfig_params = params["runconfig"] RunConfigParamsValidator(extra_args_parser_fn).validate(runconfig_params) mode = runconfig_params["mode"] if not os.environ.get("RUN_LEGACY_CSTORCH_FLOW"): logging.info("Running Trainer Flow") from cerebras.modelzoo.trainer.utils import run_trainer # Recursively update the params with the runconfig if "runconfig" in params and "trainer" in params: # runconfig was injected into params by the CLI parser # and we need to merge it with the trainer params from cerebras.modelzoo.trainer.utils import ( inject_cli_args_to_trainer_params, ) params = inject_cli_args_to_trainer_params( params.pop("runconfig"), params ) # We will be getting a params_obj if we are using config classes return run_trainer( mode, params, model_fn, train_data_fn, eval_data_fn, ) # Save the params summary_dir = os.path.join(runconfig_params["model_dir"], mode) os.makedirs(summary_dir, exist_ok=True) with open(os.path.join(summary_dir, f"params_{mode}.yaml"), "w") as f: yaml.dump(params, f, default_flow_style=False, sort_keys=False) wsc_log_level = params["runconfig"].get("wsc_log_level") or {} set_wsc_log_level(wsc_log_level) if isinstance(model_fn, str): model_name = model_fn params_obj = create_config_obj_from_params(params, model_name) model_fn = registry.get_model_class(model_name) params = convert_to_dict(params_obj) else: params_obj = None return run_cstorch_flow( params, params_obj, model_fn, train_data_fn, eval_data_fn )
[docs]def set_wsc_log_level(log_levels: Union[List[str], Dict[str, str]]): """Assert the list of log levels is valid.""" if isinstance(log_levels, dict): for task, level in log_levels.items(): level = int(level) if level.isdigit() else get_level_name(level) if task: wsc_logger.getChild(task).setLevel(level) else: wsc_logger.setLevel(level) else: raise ValueError("Invalid log levels. Input must be a dict.") # validate log level setting collect_wsc_log_settings()