Source code for cerebras.modelzoo.data_preparation.data_preprocessing.tokenflow.launch_tokenflow

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# isort: off
import sys
import os

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../.."))
# isort: on

import json
import os
from argparse import ArgumentParser
from copy import deepcopy

import h5py
import numpy as np
from flask import Flask, jsonify, render_template, request, send_from_directory

from cerebras.modelzoo.data_preparation.data_preprocessing.tokenflow import (
    tokenizer,
)

app = Flask(__name__)


class TokenFlowDataProcessor:
    def __init__(self, filepath, data_params):
        self.filepath = filepath
        self.data_params = data_params
        self.msl = data_params["processing"].get("max_seq_length")
        self.top_tokens = None
        self.pad_id = data_params['post-process'].get('pad_id')
        self.eos_id = data_params['post-process'].get('eos_id')
        self.sample_features = data_params['post-process'].get(
            'sample_features'
        )
        self.tokenizer = tokenizer.GenericTokenizer(
            deepcopy(data_params['processing']), filepath
        )

        # image_paths and image_data_locs are used in multimodal datasets
        self.data, self.image_paths, image_data_locs = self.load_data()
        self.datadict = {'image_paths': self.image_paths}
        for i, feature in enumerate(self.sample_features):
            ## Rename attention_mask to loss_mask, since the dataloader
            ## incorrectly renames the loss mask to attention mask.
            if feature == "attention_mask":
                feature = "loss_mask"
            self.datadict[feature] = self.data[:, i, :].copy()

        self.nstrings = self.datadict['input_ids'].shape[0]

        ## The inverted attention_mask is stored as key_padding_mask in
        ## multimodal data; in other cases it is not part of the data.
        if "key_padding_mask" in self.datadict:
            self.datadict['attention_mask'] = (
                1 - self.datadict['key_padding_mask']
            ).tolist()
        else:
            # Manually construct the attention mask
            self.datadict['attention_mask'] = []
            for i in range(self.datadict['input_ids'].shape[0]):
                pad_indices = np.where(
                    self.datadict['input_ids'][i] == self.pad_id
                )[0]
                if self.eos_id != self.pad_id:
                    non_pad_len = int(
                        pad_indices[0]
                        if len(pad_indices) > 0
                        else self.datadict['input_ids'].shape[1]
                    )
                    self.datadict['attention_mask'].append(
                        [1] * non_pad_len
                        + [0]
                        * (self.datadict['input_ids'].shape[1] - non_pad_len)
                    )
                else:
                    ## When the eos id is the same as the pad id, find the
                    ## first pad index by noting that an eos will not be
                    ## followed by another eos, while the pad ids are
                    ## contiguous.
                    if len(pad_indices) > 0:
                        pad_idx = 0
                        while (
                            pad_idx + 1 < len(pad_indices)
                            and pad_indices[pad_idx] + 1
                            < self.datadict['input_ids'].shape[1]
                            and self.datadict['input_ids'][i][
                                pad_indices[pad_idx] + 1
                            ]
                            != self.pad_id
                        ):
                            pad_idx = pad_idx + 1
                        if pad_idx == len(pad_indices) - 1:
                            ## All eos, no pad
                            non_pad_len = self.datadict['input_ids'].shape[1]
                        else:
                            ## The last eos just before the pad, if present,
                            ## would be chopped off from the input ids
                            non_pad_len = pad_indices[pad_idx]
                    else:
                        non_pad_len = self.datadict['input_ids'].shape[1]
                    self.datadict['attention_mask'].append(
                        [1] * non_pad_len
                        + [0]
                        * (self.datadict['input_ids'].shape[1] - non_pad_len)
                    )

        # This should modify the dict attr
        self.datadict['images_bitmap'] = np.zeros(
            self.datadict['input_ids'].shape
        )
        if image_data_locs.size:
            for i in range(image_data_locs.shape[0]):
                image_index = 1
                for j in range(image_data_locs.shape[1]):
                    if image_data_locs[i][j][0] == self.msl:
                        break
                    for k in image_data_locs[i][j]:
                        self.datadict['images_bitmap'][i][k] = image_index
                    image_index += 1

        # Construct dummy input_strings and label_strings that are
        # overwritten once decoded on demand
        self.datadict['input_strings'] = np.full(
            self.datadict['input_ids'].shape, '', dtype='<U20'
        )
        self.datadict['label_strings'] = np.full(
            self.datadict['labels'].shape, '', dtype='<U20'
        )

    def load_data(self):
        with h5py.File(self.filepath, mode='r') as h5_file:
            # Multimodal data has this format
            if self.data_params['dataset'].get('is_multimodal', False):
                return (
                    np.array(h5_file['data']),
                    np.array(
                        [
                            [i.decode('utf-8') for i in paths]
                            for paths in h5_file.get('img_path')
                        ]
                    ),
                    np.array(h5_file.get('img_data_loc')),
                )
            elif h5_file.get('data'):
                return np.array(h5_file['data']), np.array([]), np.array([])
            return np.array(h5_file['data_data']), np.array([]), np.array([])

    def get_top_tokens(self):
        # This needs on-the-fly calculation to get the top 5 frequent tokens.
        if self.top_tokens:
            return self.top_tokens

        token_counts = {}
        for input_string in self.datadict['input_ids']:
            for token in input_string:
                if not token_counts.get(token):
                    token_counts[token] = 0
                token_counts[token] += 1

        top_tokens = dict(
            sorted(token_counts.items(), key=lambda x: x[1], reverse=True)[:10]
        )
        top_tokens = {
            self.tokenizer.decode(token): val
            for token, val in top_tokens.items()
        }
        top_tokens = {
            token.strip(): val
            for token, val in top_tokens.items()
            if token.strip()
        }
        for key in top_tokens.keys():
            top_tokens[key] = int(top_tokens[key])

        self.top_tokens = {
            key: value
            for i, (key, value) in enumerate(top_tokens.items())
            if i < 5
        }
        return self.top_tokens

    def get_stats(self):
        stats = self.data_params['processing']
        stats.update(self.data_params['post-process'])
        stats.update(self.data_params['setup'])
        stats.update(self.data_params['dataset'])
        stats['multimodal'] = self.data_params['dataset'].get(
            'is_multimodal', False
        )
        # Remove empty keys
        for attr in list(stats.keys()):
            if stats[attr] is None:
                stats.pop(attr)

        # These paths are not useful to the user
        if stats.get('data'):
            stats.pop('data')
        if stats.get('input_dir'):
            stats.pop('input_dir')

        return stats

    def get_datadict(self, sequence):
        response = {}
        response['input_ids'] = self.datadict['input_ids'][sequence]
        response['labels'] = self.datadict['labels'][sequence]
        response['input_strings'] = self.tokenizer.convert_ids_to_tokens(
            response['input_ids']
        )
        response['label_strings'] = self.tokenizer.convert_ids_to_tokens(
            response['labels']
        )
        response['images_bitmap'] = self.datadict['images_bitmap'][sequence]
        response['image_paths'] = []
        if self.data_params['dataset'].get('is_multimodal', False):
            response['image_paths'] = self.datadict['image_paths'][sequence]
        if 'loss_mask' in self.datadict:
            response['loss_mask'] = self.datadict['loss_mask'][sequence]
        response['attention_mask'] = self.datadict['attention_mask'][sequence]

        # Store back the decoded strings for a faster call next time
        self.datadict['input_strings'][sequence] = response['input_strings']
        self.datadict['label_strings'][sequence] = response['label_strings']

        response.update(
            {'top_tokens': self.get_top_tokens(), 'stats': self.get_stats()}
        )

        for key, val in response.items():
            if isinstance(val, np.ndarray):
                response[key] = val.tolist()

        return response

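# A minimal standalone sketch of how TokenFlowDataProcessor could be driven
# outside of Flask, e.g. for debugging. The paths below are hypothetical and
# assume a data_params.json produced by the preprocessing pipeline:
#
#     with open("/path/to/data_params.json", "r") as f:
#         params = json.load(f)
#     processor = TokenFlowDataProcessor("/path/to/output_0.h5", params)
#     first_sample = processor.get_datadict(0)  # decoded tokens, masks, stats
#     print(first_sample["input_strings"][:10])
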
def get_data_or_error(filename):
    global data_processors
    global data_params
    try:
        if not data_processors[filename]:
            data_processors[filename] = TokenFlowDataProcessor(
                filename, data_params
            )
        return data_processors[filename]
    except Exception as e:
        return f"File not Found or error in processing: {e}", 400

def load_params(args):
    try:
        with open(args.data_params, 'r') as json_file:
            return json.load(json_file)
    except Exception:
        return None

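# Keys that this app reads from data_params.json, shown as a partial,
# illustrative shape only (values are hypothetical):
#
#     {
#       "setup": {...},
#       "processing": {"max_seq_length": 2048, ...},
#       "post-process": {"pad_id": 0, "eos_id": 1, "sample_features": [...]},
#       "dataset": {"is_multimodal": false, "image_dir": "images"}
#     }
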
@app.route('/')
def index():
    global args
    global data_params
    global data_processors

    files = []
    initial_data = None

    # Load data_params here to make it a one-time operation
    data_params = load_params(args)
    if not data_params:
        return "Error in loading data_params json!", 400

    if os.path.isdir(args.output_dir):
        files = [
            os.path.join(args.output_dir, f)
            for f in os.listdir(args.output_dir)
            if f.endswith('.h5')
        ]
        if not files:
            return (
                "There are no HDF5 files present in the directory. Please check the directory.",
                404,
            )
    elif os.path.isfile(args.output_dir) and args.output_dir.endswith('.h5'):
        files = [args.output_dir]

    if not files:
        return (
            "The passed file is not a valid HDF5 file. Please check the file.",
            404,
        )

    data_processors = {file: None for file in files}
    initial_data = get_data_or_error(files[0])
    if not isinstance(initial_data, TokenFlowDataProcessor):
        return jsonify({"error": initial_data[0], "code": initial_data[1]})

    return render_template(
        'index.html',
        files=files,
        initial_data=initial_data.get_datadict(0),
        nstrings=initial_data.nstrings,
    )


@app.route('/data', methods=['POST'])
def data():
    filename = request.form['filename']
    sequence = request.form['sequence']

    processor = get_data_or_error(filename)
    if not isinstance(processor, TokenFlowDataProcessor):
        return jsonify({"error": processor[0], "code": processor[1]})

    response = processor.get_datadict(int(sequence))
    response['nstrings'] = processor.nstrings
    return response


@app.route('/images/<path:filename>')
def serve_image(filename):
    return send_from_directory(
        os.path.join(
            os.path.dirname(os.path.abspath(args.data_params)),
            data_params['dataset']['image_dir'],
        ),
        filename,
    )


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Directory/location of one or more HDF5 files. In case a single file is passed, data_params should also be specified",
        required=True,
    )
    parser.add_argument(
        "--data_params",
        type=str,
        help="Location of data_params, required for loading heuristics related to the preprocessed data",
    )
    parser.add_argument(
        "--port", type=int, help="Port to run the Flask app on", default=5000
    )
    global args
    args = parser.parse_args()

    if not args.data_params:
        if os.path.isdir(args.output_dir):
            args.data_params = os.path.join(
                args.output_dir, 'data_params.json'
            )
        else:
            exit(
                "Use --data_params <path/to/file> to specify the path of data_params.json. Required when passing a single HDF5 file."
            )

    app.run(debug=True, host='0.0.0.0', port=args.port)
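
# Example invocations (illustrative; paths are hypothetical):
#
#     python launch_tokenflow.py --output_dir /path/to/preprocessed_output_dir
#     python launch_tokenflow.py --output_dir /path/to/output_0.h5 \
#         --data_params /path/to/data_params.json --port 8000
#
# When --output_dir is a directory, data_params.json is picked up from that
# directory unless --data_params is given explicitly.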