# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor for PyTorch BERT training.
"""
import csv
import random
import numpy as np
import torch
from cerebras.modelzoo.common.input_utils import (
bucketed_batch,
get_streaming_batch_size,
)
from cerebras.modelzoo.common.registry import registry
from cerebras.modelzoo.data.common.input_utils import (
get_data_for_task,
num_tasks,
shard_list_interleaved,
task_id,
)
from cerebras.modelzoo.data.nlp.bert.bert_utils import (
build_vocab,
create_masked_lm_predictions,
get_meta_data,
parse_text,
)
@registry.register_datasetprocessor("BertCSVDynamicMaskDataProcessor")
class BertCSVDynamicMaskDataProcessor(torch.utils.data.IterableDataset):
    """
    Reads csv files containing the input text tokens, adds MLM features
    on the fly.

    :param <dict> params: dict containing input parameters for creating dataset.
    Expects the following fields:

    - "data_dir" (string): path to the data files to use.
    - "vocab_file" (string): path to the vocabulary file.
    - "batch_size" (int): Batch size.
    - "shuffle" (bool): Flag to enable data shuffling. Defaults to True.
    - "shuffle_seed" (int): Shuffle seed. Defaults to None.
    - "shuffle_buffer" (int): Shuffle buffer size. Defaults to
      ``10 * batch_size``.
    - "mask_whole_word" (bool): Flag to mask the entire word.
    - "do_lower" (bool): Flag to lower case the texts.
    - "dynamic_mlm_scale" (bool): If True, every batch gets an extra
      "mlm_loss_scale" feature equal to ``batch_size / num_masked_tokens``.
    - "disable_nsp" (bool): If True, no NSP features are produced and token
      sequences are truncated to ``max_sequence_length``. Defaults to False.
    - "buckets": Optional bucket specification forwarded to
      ``bucketed_batch``; samples are grouped by attention-mask length.
    - "num_workers" (int): How many subprocesses to use for data loading.
    - "drop_last" (bool): If True and the dataset size is not divisible
      by the batch size, the last incomplete batch will be dropped.
    - "prefetch_factor" (int): Number of batches loaded in advance by each
      worker.
    - "persistent_workers" (bool): If True, the data loader will not shutdown
      the worker processes after a dataset has been consumed once.
    - "oov_token" (string): Out of vocabulary token.
    - "mask_token" (string): Mask token.
    - "document_separator_token" (string): Separator token.
    - "exclude_from_masking" list(string): tokens that should be excluded
      from being masked.
    - "labels_pad_id", "input_pad_id", "attn_mask_pad_id",
      "segment_pad_id" (int): Padding values; all default to 0.
    - "max_sequence_length" (int): Maximum length of the sequence to generate.
    - "max_predictions_per_seq" (int): Maximum number of masked tokens per
      sequence.
    - "masked_lm_prob" (float): Ratio of the masked tokens over the sequence
      length. Defaults to 0.15.
    - "gather_mlm_labels" (bool): Flag to gather mlm labels. Defaults to True.
    - "mixed_precision" (bool): Listed for config compatibility; it is not
      read anywhere in this class.
    """

    def __init__(self, params):
        super(BertCSVDynamicMaskDataProcessor, self).__init__()

        # Input params.
        self.meta_data = get_meta_data(params["data_dir"])
        self.meta_data_values = list(self.meta_data.values())
        self.meta_data_filenames = list(self.meta_data.keys())
        # Please note the appending of [0]: entry i of the cumulative sum is
        # the total number of examples in all files before file i.
        self.meta_data_values_cum_sum = np.cumsum([0] + self.meta_data_values)

        self.num_examples = sum(map(int, self.meta_data.values()))
        self.disable_nsp = params.get("disable_nsp", False)
        self.batch_size = get_streaming_batch_size(params["batch_size"])

        self.num_batches = self.num_examples // self.batch_size
        assert (
            self.num_batches > 0
        ), "Dataset does not contain enough samples for one batch. Please choose a smaller batch size"
        self.num_tasks = num_tasks()
        self.task_id = task_id()
        self.num_batch_per_task = self.num_batches // self.num_tasks

        assert (
            self.num_batch_per_task > 0
        ), "Dataset cannot be evenly distributed across the given tasks. Please choose fewer tasks to run with"

        # Every task is sized to the same whole number of batches.
        self.num_examples_per_task = self.num_batch_per_task * self.batch_size
        self.files_in_task = get_data_for_task(
            self.task_id,
            self.meta_data_values_cum_sum,
            self.num_examples_per_task,
            self.meta_data_values,
            self.meta_data_filenames,
        )

        self.shuffle = params.get("shuffle", True)
        self.shuffle_seed = params.get("shuffle_seed", None)
        # Keep the configured seed so per-worker seeds are always derived
        # from it, even if `_worker_init_fn` runs more than once.
        self._base_shuffle_seed = self.shuffle_seed
        self.shuffle_buffer = params.get("shuffle_buffer", 10 * self.batch_size)

        self.mask_whole_word = params.get("mask_whole_word", False)
        self.do_lower = params.get("do_lower", False)
        self.dynamic_mlm_scale = params.get("dynamic_mlm_scale", False)
        self.buckets = params.get("buckets", None)

        # Multi-processing params.
        self.num_workers = params.get("num_workers", 0)
        self.drop_last = params.get("drop_last", True)
        self.prefetch_factor = params.get("prefetch_factor", 10)
        self.persistent_workers = params.get("persistent_workers", True)

        # Get special tokens and tokens that should not be masked.
        self.special_tokens = {
            "oov_token": params.get("oov_token", "[UNK]"),
            "mask_token": params.get("mask_token", "[MASK]"),
            "document_separator_token": params.get(
                "document_separator_token", "[SEP]"
            ),
        }
        self.exclude_from_masking = params.get(
            "exclude_from_masking", ["[CLS]", "[SEP]", "[PAD]", "[MASK]"]
        )

        if self.do_lower:
            self.special_tokens = {
                key: value.lower() for key, value in self.special_tokens.items()
            }
            self.exclude_from_masking = [
                token.lower() for token in self.exclude_from_masking
            ]

        # Get vocab file and size.
        self.vocab_file = params["vocab_file"]
        self.vocab, self.vocab_size = build_vocab(
            self.vocab_file, self.do_lower, self.special_tokens["oov_token"]
        )

        # Init tokenizer.
        self.tokenize = self.vocab.forward

        # Getting indices for special tokens.
        self.special_tokens_indices = {
            key: self.tokenize([value])[0]
            for key, value in self.special_tokens.items()
        }

        self.exclude_from_masking_ids = [
            self.tokenize([token])[0] for token in self.exclude_from_masking
        ]

        # We create a pool with tokens that can be used to randomly replace
        # input tokens for BERT MLM task.
        self.replacement_pool = list(
            set(range(self.vocab_size)) - set(self.exclude_from_masking_ids)
        )

        # Padding indices.
        # See https://huggingface.co/transformers/glossary.html#labels.
        self.labels_pad_id = params.get("labels_pad_id", 0)
        self.input_pad_id = params.get("input_pad_id", 0)
        self.attn_mask_pad_id = params.get("attn_mask_pad_id", 0)
        if not self.disable_nsp:
            self.segment_pad_id = params.get("segment_pad_id", 0)

        # Max sequence lengths size params.
        self.max_sequence_length = params["max_sequence_length"]
        self.max_predictions_per_seq = params["max_predictions_per_seq"]
        self.masked_lm_prob = params.get("masked_lm_prob", 0.15)
        self.gather_mlm_labels = params.get("gather_mlm_labels", True)

        # Store params.
        self.data_buffer = []
        self.csv_files_per_task_per_worker = []
        self.processed_buffers = 0

    def load_buffer(self):
        """
        Generator to read the data in chunks of size `shuffle_buffer`.

        Walks the ``(file_path, num_examples, start_id)`` entries assigned to
        this worker. From each csv file it collects the rows with
        ``start_id <= row_id < start_id + num_examples``. Once
        `shuffle_buffer` rows are collected they are yielded, shuffled first
        if `shuffle` is set. Leftover rows are flushed after the last file.

        :returns: Yields the data stored in the `data_buffer`, one csv row
            (a dict from ``csv.DictReader``) at a time.
        """
        self.processed_buffers = 0
        self.data_buffer = []
        while self.processed_buffers < len(self.csv_files_per_task_per_worker):
            (
                current_file_path,
                num_examples,
                start_id,
            ) = self.csv_files_per_task_per_worker[self.processed_buffers]
            end_id = start_id + num_examples

            with open(current_file_path, "r", newline="") as fin:
                data_reader = csv.DictReader(fin)
                for row_id, row in enumerate(data_reader):
                    if row_id < start_id:
                        continue
                    if row_id >= end_id:
                        # Rows arrive in order, so nothing past this point
                        # belongs to the shard: stop reading the file.
                        break
                    self.data_buffer.append(row)
                    if len(self.data_buffer) == self.shuffle_buffer:
                        if self.shuffle:
                            self.rng.shuffle(self.data_buffer)
                        yield from self.data_buffer
                        self.data_buffer = []

            self.processed_buffers += 1

        if self.shuffle:
            self.rng.shuffle(self.data_buffer)
        yield from self.data_buffer
        self.data_buffer = []

    def __len__(self):
        """
        Returns the number of batches produced on this task process.

        With bucketing and ``drop_last`` this is an under-estimate, because
        some buckets may never be completely filled.
        """
        if not self.drop_last:
            return (
                self.num_examples_per_task + self.batch_size - 1
            ) // self.batch_size
        elif self.buckets is None:
            return self.num_examples_per_task // self.batch_size
        else:
            # give an under-estimate in case we don't fully fill some buckets;
            # never report a negative length (len() raises ValueError on it).
            length = self.num_examples_per_task // self.batch_size
            length -= len(self.buckets)
            return max(length, 0)

    def get_single_item(self):
        """
        Iterating over the data to construct input features.

        :return: Yields one dict of training features per csv row:
            * input_ids: input token indices, as returned by
              ``create_masked_lm_predictions``.
              Shape: (`max_sequence_length`).
            * attention_mask: as returned by ``create_masked_lm_predictions``.
              Shape: (`max_sequence_length`).
            * labels: MLM labels. np.int32 array of shape
              (`max_predictions_per_seq`) when `gather_mlm_labels` is set,
              otherwise the per-position labels from
              ``create_masked_lm_predictions``.
            * masked_lm_mask: mask of predicted tokens; `0` indicates the non
              masked token, and `1` indicates the masked token. np.int32
              array of shape (`max_predictions_per_seq`) when
              `gather_mlm_labels` is set, otherwise the per-position mask
              from ``create_masked_lm_predictions``.
            * masked_lm_positions: np.int32 array of shape
              (`max_predictions_per_seq`) with the masked positions; only
              present when `gather_mlm_labels` is set.
            * token_type_ids: np.int32 segment indices of shape
              (`max_sequence_length`); only present when NSP is enabled.
            * next_sentence_label: np.int32 array of shape (1) with the NSP
              label; only present when NSP is enabled.
        """
        # Iterate over the data rows to create input features.
        for data_row in self.load_buffer():
            # `data_row` is a dict with keys:
            # ["tokens", "segment_ids", "is_random_next"].
            tokens = parse_text(data_row["tokens"], do_lower=self.do_lower)
            if self.disable_nsp:
                # truncate tokens to MSL
                tokens = tokens[: self.max_sequence_length]
            else:
                assert (
                    len(tokens) <= self.max_sequence_length
                ), "When using NSP head, make sure that len(tokens) <= MSL."

            # NOTE(review): this passes the token strings, not
            # `exclude_from_masking_ids`; confirm which one
            # create_masked_lm_predictions expects.
            (
                input_ids,
                labels,
                attention_mask,
                masked_lm_mask,
            ) = create_masked_lm_predictions(
                tokens,
                self.max_sequence_length,
                self.special_tokens_indices["mask_token"],
                self.max_predictions_per_seq,
                self.input_pad_id,
                self.attn_mask_pad_id,
                self.labels_pad_id,
                self.tokenize,
                self.vocab_size,
                self.masked_lm_prob,
                self.rng,
                self.exclude_from_masking,
                self.mask_whole_word,
                self.replacement_pool,
            )

            features = {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
            }

            if self.gather_mlm_labels:
                # Gather MLM positions: compact the masked positions and
                # their labels into fixed-size, zero-padded arrays.
                _mlm_positions = np.nonzero(masked_lm_mask)[0]
                _num_preds = len(_mlm_positions)
                gathered_mlm_positions = np.zeros(
                    (self.max_predictions_per_seq,), dtype=np.int32
                )
                gathered_mlm_positions[:_num_preds] = _mlm_positions
                gathered_labels = np.zeros(
                    (self.max_predictions_per_seq,), dtype=np.int32
                )
                gathered_labels[:_num_preds] = labels[_mlm_positions]
                gathered_mlm_mask = np.zeros(
                    (self.max_predictions_per_seq,), dtype=np.int32
                )
                gathered_mlm_mask[:_num_preds] = masked_lm_mask[_mlm_positions]
                features["labels"] = gathered_labels
                features["masked_lm_mask"] = gathered_mlm_mask
                features["masked_lm_positions"] = gathered_mlm_positions
            else:
                features["labels"] = labels
                features["masked_lm_mask"] = masked_lm_mask

            if not self.disable_nsp:
                next_sentence_label = np.zeros((1,), dtype=np.int32)
                token_type_ids = (
                    np.ones((self.max_sequence_length,), dtype=np.int32)
                    * self.segment_pad_id
                )
                # "segment_ids" is stored as a list literal string,
                # e.g. "[0, 0, 1]".
                segment_ids = data_row["segment_ids"].strip("[]").split(", ")
                token_type_ids[: len(segment_ids)] = list(map(int, segment_ids))
                next_sentence_label[0] = int(data_row["is_random_next"])
                features["token_type_ids"] = token_type_ids
                features["next_sentence_label"] = next_sentence_label

            yield features

    def __iter__(self):
        """
        Yields batches of features, optionally bucketed by the sum of each
        sample's attention mask.

        When `dynamic_mlm_scale` is set, each batch also gets an
        "mlm_loss_scale" entry of shape (batch, 1) holding
        ``batch / num_masked_tokens``.
        """
        batched_dataset = bucketed_batch(
            self.get_single_item(),
            self.batch_size,
            buckets=self.buckets,
            element_length_fn=lambda feats: np.sum(feats["attention_mask"]),
            drop_last=self.drop_last,
            seed=self.shuffle_seed,
        )
        for batch in batched_dataset:
            if self.dynamic_mlm_scale:
                # Use the real batch size: with `drop_last=False` the final
                # batch can be smaller than `self.batch_size`.
                current_batch_size = batch["masked_lm_mask"].shape[0]
                # A batch with no masked tokens would otherwise give an
                # infinite scale (and inf * 0 = NaN in the loss).
                num_masked = torch.clamp(
                    torch.sum(batch["masked_lm_mask"]), min=1
                )
                scale = current_batch_size / num_masked
                batch["mlm_loss_scale"] = scale.expand(current_batch_size, 1)
            yield batch

    def _worker_init_fn(self, worker_id):
        """
        Per-worker setup: seeds `self.rng` and shards this task's files.

        Passed as ``worker_init_fn`` to the DataLoader when ``num_workers > 0``.
        Otherwise `create_dataloader` calls it directly, in the current
        process.

        :param int worker_id: Not used; the id comes from
            ``torch.utils.data.get_worker_info()`` (0 in single-process mode).
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
        else:
            # Single-process
            worker_id = 0
            num_workers = 1

        self.processed_buffers = 0
        # Derive the worker seed from the configured seed rather than the
        # current value, so repeated calls do not keep shifting it.
        if self._base_shuffle_seed is not None:
            self.shuffle_seed = self._base_shuffle_seed + worker_id + 1
        self.rng = random.Random(self.shuffle_seed)

        # Shard the data across multiple processes.
        self.csv_files_per_task_per_worker = shard_list_interleaved(
            self.files_in_task, worker_id, num_workers
        )
        if self.shuffle:
            self.rng.shuffle(self.csv_files_per_task_per_worker)

    def create_dataloader(self):
        """
        Create the dataloader object for this dataset.

        Batching is done in `__iter__`, so the DataLoader is built with
        ``batch_size=None``. With ``num_workers > 0``, `_worker_init_fn` runs
        inside each worker. Otherwise it is run here, in the current process.

        :returns: A ``torch.utils.data.DataLoader`` over this dataset.
        """
        if self.num_workers:
            dataloader = torch.utils.data.DataLoader(
                self,
                batch_size=None,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
                persistent_workers=self.persistent_workers,
                worker_init_fn=self._worker_init_fn,
            )
        else:
            dataloader = torch.utils.data.DataLoader(self, batch_size=None)
            self._worker_init_fn(0)
        return dataloader