# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Script to shard a text corpus into separate train and test dataset files
Reference: https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/LanguageModeling/BERT
"""
import multiprocessing
import os
import statistics
from collections import defaultdict
from itertools import islice
import nltk
# Import-time side effect: request NLTK's 'punkt' resource as soon as this
# module is imported. NLTKSegmenter in this file calls
# nltk.tokenize.sent_tokenize, which presumably depends on this resource
# (NOTE(review): confirm against the installed NLTK version).
nltk.download('punkt')
class Sharding:
    """Distribute one-article-per-line text files over training and test shards.

    The intended call order is ``load_articles`` (optional, it is triggered
    automatically by segmentation), ``segment_articles_into_sentences``,
    ``distribute_articles_over_shards`` and finally ``write_shards_to_disk``.
    """

    def __init__(
        self,
        input_files,
        output_name_prefix,
        n_training_shards,
        n_test_shards,
        fraction_test_set,
    ):
        """Store the configuration and register the (still empty) shards.

        Args:
            input_files: Non-empty sequence of paths to text files holding
                one article per line.
            output_name_prefix: Prefix of every shard file name. Shards are
                named ``<prefix>_training_<i>.txt`` / ``<prefix>_test_<i>.txt``.
            n_training_shards: Number of training shards (must be > 0).
            n_test_shards: Number of test shards (must be > 0).
            fraction_test_set: Fraction of all sentences that the test shards
                should nominally receive.

        Raises:
            AssertionError: If ``input_files`` is empty or a shard count is
                not positive.
        """
        assert (
            len(input_files) > 0
        ), 'The input file list must contain at least one file.'
        assert n_training_shards > 0, 'There must be at least one output shard.'
        assert n_test_shards > 0, 'There must be at least one output shard.'

        self.n_training_shards = n_training_shards
        self.n_test_shards = n_test_shards
        self.fraction_test_set = fraction_test_set

        self.input_files = input_files

        self.output_name_prefix = output_name_prefix
        self.output_training_identifier = '_training'
        self.output_test_identifier = '_test'
        self.output_file_extension = '.txt'

        self.articles = {}  # key: integer article id, value: article text
        self.sentences = {}  # key: integer article id, value: list of sentences
        # key: shard filename, value: list of article ids to go into that file
        self.output_training_files = {}
        # key: shard filename, value: list of article ids to go into that file
        self.output_test_files = {}

        self.init_output_files()

    def load_articles(self):
        """Read every input file into ``self.articles``.

        The input files contain one article per line; lines that are empty
        or whitespace-only are skipped. Articles are numbered consecutively
        from 0 across all input files, in the order they are read.
        """
        print("Start: Loading Articles")
        global_article_count = 0
        for input_file in self.input_files:
            print('input file:', input_file)
            with open(input_file, mode='r', newline='\n') as f:
                for line in f:
                    if line.strip():
                        self.articles[global_article_count] = line.rstrip()
                        global_article_count += 1
        print(
            f"End: Loading Articles: There are {len(self.articles)} articles."
        )

    def segment_articles_into_sentences(self, segmenter):
        """Split every loaded article into sentences.

        Loads the articles first when none have been loaded yet. The result
        is stored in ``self.sentences`` under the same id that the article
        has in ``self.articles``.

        Args:
            segmenter: Object exposing ``segment_string(text)`` that returns
                the list of sentences of ``text``.

        Raises:
            AssertionError: If no article is available after loading.
        """
        print("Start: Sentence Segmentation")
        if len(self.articles) == 0:
            self.load_articles()
        assert (
            len(self.articles) != 0
        ), 'Please check that input files are present and contain data.'

        for i, (article_id, article) in enumerate(self.articles.items()):
            # Key by article id (not by position) so that the ids used by
            # distribute_articles_over_shards always match self.articles.
            self.sentences[article_id] = segmenter.segment_string(article)
            if i % 5000 == 0:
                print(f"Segmenting article {i}")
        print("End: Sentence Segmentation")

    def init_output_files(self):
        """Register one empty article list per training and test shard.

        Intended to be called by the constructor only.

        Raises:
            AssertionError: If shards have already been registered.
        """
        print("Start: Init Output Files")
        assert len(self.output_training_files) == 0, (
            'Internal storage self.output_training_files already contains '
            'data. This function is intended to be used by the constructor '
            'only.'
        )
        assert len(self.output_test_files) == 0, (
            'Internal storage self.output_test_files already contains '
            'data. This function is intended to be used by the constructor '
            'only.'
        )

        for i in range(self.n_training_shards):
            name = (
                f"{self.output_name_prefix}{self.output_training_identifier}"
                f"_{i}{self.output_file_extension}"
            )
            self.output_training_files[name] = []

        for i in range(self.n_test_shards):
            name = (
                f"{self.output_name_prefix}{self.output_test_identifier}"
                f"_{i}{self.output_file_extension}"
            )
            self.output_test_files[name] = []
        print('End: Init Output Files')

    def get_sentences_per_shard(self, shard):
        """Return the total number of sentences of the articles in ``shard``.

        Args:
            shard: Iterable of article ids present in ``self.sentences``.
        """
        return sum(len(self.sentences[article_id]) for article_id in shard)

    @staticmethod
    def _largest_nonempty_size(sentence_counts, size):
        """Return the largest sentence count <= ``size`` that still has articles.

        Returns 0 when no positive bucket at or below ``size`` holds an
        article; a non-positive ``size`` is returned unchanged.
        """
        while size > 0 and not sentence_counts.get(size):
            size -= 1
        return size

    def _shard_statistics(self, files):
        """Return (per-shard sentence counts, their median) for ``files``."""
        counts = [
            self.get_sentences_per_shard(article_ids)
            for article_ids in files.values()
        ]
        return counts, statistics.median(counts)

    def distribute_articles_over_shards(self):
        """Assign every segmented article to exactly one shard.

        Each shard is first seeded with one of the longest remaining
        articles. Afterwards shards are visited repeatedly (training shards
        first, then test shards) and each shard whose sentence count does
        not exceed the median of its split receives the longest remaining
        article that still fits under the nominal shard size. When no
        progress is made, the nominal training shard size is increased so
        that leftover articles end up in training shards.

        Articles without any sentence can never be placed by this scheme;
        they are left out and reported with a warning.

        Raises:
            AssertionError: If there are fewer articles than shards, or if
                the articles have not (all) been segmented yet.
        """
        print("Start: Distribute Articles Over Shards")
        assert (
            len(self.articles) >= self.n_training_shards + self.n_test_shards
        ), (
            'There are fewer articles than shards. '
            'Please add more data or reduce the number of shards requested.'
        )
        # Without this, a partially segmented corpus would never terminate
        # the placement loop below.
        assert self.sentences.keys() == self.articles.keys(), (
            'Every loaded article must be segmented before distribution. '
            'Call segment_articles_into_sentences() first.'
        )

        # key: sentence count per article, value: ids of unplaced articles
        sentence_counts = defaultdict(list)
        max_sentences = 0
        total_sentences = 0
        for article_id, article_sentences in self.sentences.items():
            current_length = len(article_sentences)
            sentence_counts[current_length].append(article_id)
            max_sentences = max(max_sentences, current_length)
            total_sentences += current_length

        n_sentences_assigned_to_training = int(
            (1 - self.fraction_test_set) * total_sentences
        )
        # Training shards are always processed before test shards.
        splits = (
            ('training', self.output_training_files),
            ('test', self.output_test_files),
        )
        nominal_sentences_per_shard = {
            'training': n_sentences_assigned_to_training
            // self.n_training_shards,
            'test': (total_sentences - n_sentences_assigned_to_training)
            // self.n_test_shards,
        }
        unused_article_set = set(self.articles)

        # Make first pass and add one article worth of lines per file
        for split, files in splits:
            for assigned_articles in files.values():
                current_article_id = sentence_counts[max_sentences].pop()
                assigned_articles.append(current_article_id)
                unused_article_set.remove(current_article_id)

                # Maintain the max sentence count
                max_sentences = self._largest_nonempty_size(
                    sentence_counts, max_sentences
                )

                article_length = len(self.sentences[current_article_id])
                if article_length > nominal_sentences_per_shard[split]:
                    nominal_sentences_per_shard[split] = article_length
                    print(
                        f"Warning: A single article contains more"
                        f" than the nominal number of sentences per {split} shard."
                    )

        shard_counts = {}
        shard_medians = {}
        for split, files in splits:
            shard_counts[split], shard_medians[split] = self._shard_statistics(
                files
            )

        # Make subsequent passes over files to find articles to add without going over limit
        history_remaining = []
        n_history_remaining = 4
        while unused_article_set:
            max_sentences = self._largest_nonempty_size(
                sentence_counts, max_sentences
            )
            if max_sentences == 0:
                # Only zero-sentence articles are left; they can never be
                # selected, so stop instead of looping forever.
                break

            for split, files in splits:
                for fidx, assigned_articles in enumerate(files.values()):
                    if shard_counts[split][fidx] > shard_medians[split]:
                        # skip adding to this file,
                        # will come back later if no file can accept unused articles
                        continue

                    # Maintain the max sentence count
                    max_sentences = self._largest_nonempty_size(
                        sentence_counts, max_sentences
                    )
                    remaining_capacity = (
                        nominal_sentences_per_shard[split]
                        - shard_counts[split][fidx]
                    )
                    next_article_size = self._largest_nonempty_size(
                        sentence_counts,
                        min(remaining_capacity, max_sentences),
                    )
                    if next_article_size <= 0:
                        continue  # nothing fits into this shard right now

                    current_article_id = sentence_counts[
                        next_article_size
                    ].pop()
                    assigned_articles.append(current_article_id)
                    unused_article_set.remove(current_article_id)

            # If unable to place articles a few times,
            # bump up nominal sizes by fraction until articles get placed
            history_remaining.append(len(unused_article_set))
            del history_remaining[:-n_history_remaining]
            if len(set(history_remaining)) == 1:
                # Only the training size grows, so leftovers go to training.
                nominal_sentences_per_shard['training'] += 1

            for split, files in splits:
                (
                    shard_counts[split],
                    shard_medians[split],
                ) = self._shard_statistics(files)

            print(
                f"Distributing data over shards: {len(unused_article_set)} articles remaining."
            )

        if len(unused_article_set) != 0:
            print("Warning: Some articles did not make it into output files.")

        for shard in self.output_training_files:
            print(
                f"Training shard sentences: {self.get_sentences_per_shard(self.output_training_files[shard])}"
            )
        for shard in self.output_test_files:
            print(
                f"Test shard sentences:{self.get_sentences_per_shard(self.output_test_files[shard])}"
            )
        print("End: Distribute Articles Over Shards")

    def write_shards_to_disk(self):
        """Write every training and test shard to its output file."""
        print('Start: Write Shards to Disk')
        for shard in self.output_training_files:
            self.write_single_shard(
                shard, self.output_training_files[shard], 'training'
            )
        for shard in self.output_test_files:
            self.write_single_shard(
                shard, self.output_test_files[shard], 'test'
            )
        print("End: Write Shards to Disk")

    def write_single_shard(self, shard_name, shard, split):
        """Write one shard to ``<dir of shard_name>/<split>/<file name>``.

        Each sentence is written on its own line and every article is
        followed by an empty line. The ``split`` sub-directory is created
        when it does not exist yet.

        Args:
            shard_name: Shard file name, optionally with a directory part.
            shard: Iterable of article ids to write.
            split: Name of the sub-directory, e.g. ``'training'`` or ``'test'``.
        """
        directory, file_name = os.path.split(shard_name)
        # os.path.join keeps a bare file name relative to the working
        # directory instead of producing a path under the filesystem root.
        output_dir = os.path.join(directory, split)
        os.makedirs(output_dir, exist_ok=True)
        with open(
            os.path.join(output_dir, file_name), mode='w', newline='\n'
        ) as f:
            for article_id in shard:
                f.writelines(
                    line + '\n' for line in self.sentences[article_id]
                )
                f.write('\n')  # Line break between articles
class NLTKSegmenter:
    """Sentence segmenter that delegates to ``nltk.tokenize.sent_tokenize``."""

    def segment_string(self, article):
        """Split ``article`` into sentences.

        Args:
            article: Text of a single article.

        Returns:
            Whatever ``nltk.tokenize.sent_tokenize(article)`` returns.
        """
        return nltk.tokenize.sent_tokenize(article)