# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FIMTokenGenerator Module
This module offers the FIMTokenGenerator class, an extension of the
PretrainingTokenGenerator class, tailored for fill in the middle (FIM) tasks.
Usage:
from your_module_name import FIMTokenGenerator
# Initialize the token generator with the required parameters
tokenizer = FIMTokenGenerator(params, tokenizer_impl, eos_id, pad_id)
# Tokenize and encode text data
tokenized_data, stats = tokenizer.encode("Your sample text to process.")
"""
import logging
from collections import defaultdict
from typing import Any, Dict, List, Tuple
from cerebras.modelzoo.data_preparation.data_preprocessing.pretraining_token_generator import (
PretrainingTokenGenerator,
)
from cerebras.modelzoo.data_preparation.data_preprocessing.utils import (
check_fim_special_tokens,
fim,
handle_bos_token_default,
)
# Module-level logger for this file.
# NOTE(review): the logger is named after __file__ (this module's path) rather
# than the conventional __name__, so it does not sit under the package's dotted
# logger hierarchy -- confirm this is intentional before changing it.
logger = logging.getLogger(__file__)
# Set this logger's own threshold to INFO.
logger.setLevel(logging.INFO)
class FIMTokenGenerator(PretrainingTokenGenerator):
    """
    Token generator for fill in the middle (FIM) tasks.

    Extends PretrainingTokenGenerator: samples are first tokenized by the
    parent's ``tokenize_data`` and each non-empty sample is then passed
    through the ``fim`` utility together with the FIM special-token IDs
    resolved at construction time.
    """

    def __init__(self, params, tokenizer, eos_id, pad_id):
        """
        Initialize the FIMTokenGenerator class.

        Args:
            params (Dict[str, Any]): Params from config file. The optional
                ``fim_rate`` and ``spm_rate`` entries are popped from
                ``params["processing"]``, so the caller's dict is mutated.
            tokenizer: Tokenizer instance.
            eos_id (int): End of sequence token ID.
            pad_id (int): Padding token ID.
        """
        super().__init__(params, tokenizer, eos_id, pad_id)

        processing_params = params["processing"]
        # `pop` (not `get`) on purpose: the rates are removed from the
        # processing config once they have been read. Both default to None.
        self.fim_rate = processing_params.pop("fim_rate", None)
        self.spm_rate = processing_params.pop("spm_rate", None)

        # Ensures that FIM tokens are specified in config, and that
        # the specified tokens are actually in the tokenizer
        check_fim_special_tokens(params, self.tokenizer)

        # Some tokenizers use BOS ID at the beginning and others do not.
        # Here we get a flag for whether to use BOS by default
        # and the BOS id if needed.
        self.default_bos_token, self.opt_bos_tok_id = handle_bos_token_default(
            self.tokenizer
        )

        # Resolve each FIM special token to an ID by encoding the configured
        # token string and keeping the last ID of the encoding.
        self.suffix_tok_id = self.tokenizer.encode(
            processing_params.get("fim_suffix_tok")
        )[-1]
        self.prefix_tok_id = self.tokenizer.encode(
            processing_params.get("fim_prefix_tok")
        )[-1]
        self.middle_tok_id = self.tokenizer.encode(
            processing_params.get("fim_middle_tok")
        )[-1]

    def encode(
        self, semantic_data_array: List[Dict[str, Any]]
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
        """
        Tokenize and encode the data for auto-regressive language modeling.

        Args:
            semantic_data_array (List[Dict[str, Any]]): Data to encode.

        Returns:
            Tuple[Dict[str, Any], Dict[str, int]]: Tuple of encoded features
            for auto-regressive language modeling and dataset stats. The
            features dict is ``{"data": [...]}``, or ``{}`` when tokenization
            yields nothing or every sample is empty.
        """
        tokenized_data, data_stats = self.tokenize_data(semantic_data_array)
        if not tokenized_data:
            return {}, data_stats

        result = []
        # Stats are recomputed on the FIM-transformed samples and merged over
        # the ones reported by `tokenize_data` below.
        tokenized_data_stats = defaultdict(int)
        for i, sample in enumerate(tokenized_data["data"]):
            # Explicit comparison with [] is kept instead of a truthiness
            # test: samples may be array-like, where truthiness is ambiguous.
            # Empty samples are skipped entirely (no FIM, no stats, no output).
            if sample != []:
                sample = fim(
                    sample,
                    i,
                    self.tokenizer,
                    self.fim_rate,
                    self.spm_rate,
                    self.suffix_tok_id,
                    self.prefix_tok_id,
                    self.middle_tok_id,
                    self.pad_id,
                    self.eos_id,
                    self.opt_bos_tok_id,
                )
                sample_data_stats = self.get_data_stats(sample)
                for key in sample_data_stats:
                    tokenized_data_stats[key] += sample_data_stats[key]
                result.append(sample)

        data = {"data": result} if result else {}
        # Keys recomputed after FIM overwrite the same keys in `data_stats`.
        data_stats.update(tokenized_data_stats)
        return data, data_stats