Source code for cerebras.modelzoo.common.utils.run.write_weight_summaries

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/usr/bin/env python

import argparse

import numpy as np
import pandas as pd
import torch

from cerebras.pytorch.saver import PyTorchH5Saver


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path")
    parser.add_argument("--output_path")
    parser.add_argument("--framework", choices=["tf", "pt"], default="pt")
    parser.add_argument(
        "--output_format", choices=["text", "csv"], default="text"
    )
    return parser.parse_args()
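
# Example invocation (a sketch based on the flags defined in parse_args
# above; the checkpoint and output paths are hypothetical):
#
#   python -m cerebras.modelzoo.common.utils.run.write_weight_summaries \
#       --input_path checkpoint.mdl \
#       --output_path weight_summary.txt \
#       --framework pt \
#       --output_format text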

def read_checkpoint(input_path, framework):
    # Only PyTorch checkpoints are supported here. The membership test must
    # use a tuple: `framework in ("pt")` would be a substring check against
    # the string "pt".
    assert framework in ("pt",), f"Unsupported framework {framework}"
    if PyTorchH5Saver.is_valid_checkpoint(input_path):
        saver = PyTorchH5Saver()
        for name in saver.tensor_names(input_path):
            yield name, saver.load_tensor(input_path, name)
    else:
        try:
            state_dict = torch.load(input_path)
        except Exception as e:
            raise RuntimeError(
                f"{input_path} is not a valid H5 checkpoint or "
                f"a valid torch serialized checkpoint"
            ) from e
        state_dict = PyTorchH5Saver().flatten_state_dict(state_dict)
        for name, weight in state_dict.items():
            yield name, weight
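
# Sketch of using the generator directly (illustrative only; the path is
# hypothetical and main() below is the intended entry point):
#
#   for name, weight in read_checkpoint("checkpoint.mdl", "pt"):
#       print(name, getattr(weight, "shape", ()))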

def main():
    args = parse_args()
    if args.output_format == "text":
        f = open(args.output_path, "w")
        f.write(
            "Layer stats layer_id: min_abs_wt, max_abs_wt, mean_wt, stdev_wt, "
            "mean_abs_wt, stdev_abs_wt, mean_log_wt, stdev_log_wt, wt_norm\n"
        )
    # Running sums of squared weights; square-rooted at the end.
    weight_norm = 0.0
    model_weight_norm = 0.0
    all_data = []
    for name, weight in read_checkpoint(args.input_path, args.framework):
        nan_check = np.isnan(weight)
        if np.all(nan_check) and args.output_format == "text":
            f.write(f"Layer {name} contains ***** ALL NaNs! *****\n")
            continue
        elif np.any(nan_check) and args.output_format == "text":
            pct_nans = 100 * nan_check.sum() / nan_check.size
            f.write(
                f"WARNING: Layer {name} contains {pct_nans:.1f}% NaNs. "
                "Removing them...\n"
            )
            weight = np.where(nan_check, 0.0, weight)
        if isinstance(weight, (int, float)):
            weight = np.array(weight)
        num_weights = weight.size
        num_zeros = num_weights - np.count_nonzero(weight)
        if num_zeros > 0 and args.output_format == "text":
            f.write(
                f"Layer {name} zero: {num_zeros} "
                f"({num_zeros * 100 / num_weights:.2f}%)\n"
            )
        weight = weight.astype(float)
        norm_sq = np.sum(np.square(weight))
        weight_norm += norm_sq
        # Exclude optimizer state and bookkeeping tensors from the
        # model-only norm.
        if (
            "Adam" not in name
            and "grad_accum" not in name
            and "step" not in name
            and "scale" not in name
            and "optimizer" not in name
            and "lr_scheduler" not in name
        ):
            model_weight_norm += norm_sq
        norm = np.sqrt(norm_sq)
        mean = np.mean(weight)
        stdev = np.std(weight)
        # Stats are computed on w, then |w|, then log2(|w|), reusing the
        # same buffer in place.
        weight = np.absolute(weight)
        minimum = np.amin(weight)
        maximum = np.amax(weight)
        absmean = np.mean(weight)
        absstdev = np.std(weight)
        weight = np.log2(weight)
        logmean = np.mean(weight)
        logstdev = np.std(weight)
        if args.output_format == "text":
            f.write(
                f"Layer stats {name}: {minimum} {maximum} {mean} {stdev} "
                f"{absmean} {absstdev} {logmean} {logstdev} {norm}\n"
            )
        else:
            all_data.append(
                [
                    name,
                    minimum,
                    maximum,
                    mean,
                    stdev,
                    absmean,
                    absstdev,
                    logmean,
                    logstdev,
                    norm,
                ]
            )
    if args.output_format == "text":
        weight_norm = np.sqrt(weight_norm)
        model_weight_norm = np.sqrt(model_weight_norm)
        f.write(f"Total weights norm: {weight_norm}\n")
        f.write(f"Non-optimizer weights norm: {model_weight_norm}\n")
        f.close()
    else:
        df = pd.DataFrame(
            all_data,
            columns=[
                "name",
                "minimum",
                "maximum",
                "mean",
                "stdev",
                "absmean",
                "absstdev",
                "logmean",
                "logstdev",
                "norm",
            ],
        )
        df.to_csv(args.output_path)
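
# Output reference (derived from main() above): each tensor yields, in
# order, min|w|, max|w|, mean(w), std(w), mean|w|, std|w|, mean(log2|w|),
# std(log2|w|), and the L2 norm ||w||. In text mode these follow the
# "Layer stats" header line; in csv mode they become the columns
# name, minimum, maximum, mean, stdev, absmean, absstdev, logmean,
# logstdev, norm.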

if __name__ == "__main__":
    main()