# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import Any, Dict, List

from cerebras.modelzoo.data_preparation.data_preprocessing.utils import (
    SYSTEM_PROMPT_REGISTRY,
)

logger = logging.getLogger(__name__)


def finetuning_llava_hook(
example: Dict[str, Any], **read_hook_kwargs: Any
) -> List[Dict[str, Any]]:
"""
Transforms conversation data for finetuning LLaVA.
Args:
example (Dict[str, Any]): The input data containing conversation and image paths.
**read_hook_kwargs (Any): Additional keyword arguments containing data_keys, system_prompt, image_token, multi_turn_content_key, and phase.
Returns:
List[Dict[str, Any]]: Transformed data suitable for finetuning LLaVA.
Raises:
AssertionError: If required keys are not provided in read_hook_kwargs.
ValueError: If image_token is not provided, or if there are multiple image tokens in the user's role, or if image tokens are found in the assistant's response.
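
    Example of the expected input and the resulting call (an illustrative
    sketch; the dataset field names ``conversations``, ``value`` and
    ``image`` are hypothetical and depend on your dataset/config)::

        example = {
            "conversations": [
                {"value": "<image> What is shown in the picture?"},
                {"value": "A dog playing fetch in a park."},
            ],
            "image": "images/0001.jpg",
        }
        sda = finetuning_llava_hook(
            example,
            data_keys={
                "multi_turn_key": "conversations",
                "image_key": "image",
            },
            image_token="<image>",
            phase=2,
        )
        # ``sda`` is a list of dicts with "type", "content" and
        # "semantic_drop_mask" keys, one per system/user/assistant turn.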
"""
data_keys = read_hook_kwargs.get("data_keys")
assert (
        data_keys is not None
), "data_keys should be provided in the read_hook_kwargs section"
multi_turn_key = data_keys.get("multi_turn_key")
image_key = data_keys.get("image_key")
system_prompt = read_hook_kwargs.get("system_prompt")
image_token = read_hook_kwargs.get("image_token", None)
multi_turn_content_key = read_hook_kwargs.get(
"multi_turn_content_key", "value"
)
phase = read_hook_kwargs.get("phase")
assert (
        phase is not None
), "phase should be provided in the read_hook_kwargs section for llava"
assert (
        image_token is not None
), "image_token should be provided in the read_hook_kwargs section for llava"
conversation_data = example.get(multi_turn_key, [])
if conversation_data is None:
conversation_data = []
image_path = example.get(image_key, None)
transformed_data = []
if system_prompt:
system_prompt_text = SYSTEM_PROMPT_REGISTRY.get(system_prompt, "")
system_data = {
"type": "system",
"content": [{"text": system_prompt_text.strip()}],
}
transformed_data.append(system_data)
if image_path and not image_token:
raise ValueError(
"Image token has not been provided inside read_hook_kwargs within the processing section in the config file for llava finetuning datasets."
)
for i, turn in enumerate(conversation_data):
semantic_drop_mask = []
role = "user" if i % 2 == 0 else "assistant"
content_parts = []
if role == "user":
# Check for multiple image tokens in the user's role
if turn[multi_turn_content_key].count(image_token) > 1:
raise ValueError(
"Multiple image tokens found in user's role. Only one image token is allowed."
)
# Assume there's only one image token in the user's role
parts = re.split(
re.escape(image_token), turn[multi_turn_content_key]
)
if len(parts) == 2:
# Add the image part before the text
content_parts.append({"image": image_path})
parts = [part.strip() for part in parts]
if parts[0] != '' and parts[1] != '':
content_parts.append({"text": parts[0] + parts[1]})
if phase == 1:
semantic_drop_mask.extend([False, True])
else:
semantic_drop_mask.extend([False, False])
else:
semantic_drop_mask.extend([False])
else:
# No image token found, add the text as is
content_parts.append({"text": turn[multi_turn_content_key]})
if phase == 1:
semantic_drop_mask.extend([True])
else:
semantic_drop_mask.extend([False])
elif role == "assistant":
# Check that no image tokens are present in the assistant's response
if image_token and image_token in turn[multi_turn_content_key]:
raise ValueError(
"Image tokens are not allowed in the assistant's response."
)
content_parts.append({"text": turn[multi_turn_content_key]})
semantic_drop_mask.extend([False])
transformed_data.append(
{
"type": role,
"content": content_parts,
"semantic_drop_mask": semantic_drop_mask,
}
)
return transformed_data


def pretraining_image_captions_hook(
example: Dict[str, Any], **read_hook_kwargs: Any
) -> List[Dict[str, Any]]:
"""
Transforms image and caption data for pretraining.
Args:
example (Dict[str, Any]): The input data containing image and caption information.
**read_hook_kwargs (Any): Additional keyword arguments containing data_keys.
Returns:
List[Dict[str, Any]]: Transformed data suitable for pretraining.
Raises:
AssertionError: If required keys are not provided in read_hook_kwargs.
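
    Example of the expected input and output (an illustrative sketch; the
    field names ``image`` and ``caption`` are hypothetical)::

        example = {"image": "images/0001.jpg", "caption": "A red bicycle."}
        sda = pretraining_image_captions_hook(
            example,
            data_keys={"image_key": "image", "caption_key": "caption"},
        )
        # -> [{"content": [{"image": "images/0001.jpg"},
        #                  {"text": "A red bicycle."}]}]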
"""
data_keys = read_hook_kwargs.get("data_keys")
assert (
        data_keys is not None
), "data_keys should be provided in the read_hook_kwargs section"
image_key = data_keys.get('image_key', None)
caption_key = data_keys.get('caption_key', None)
assert (
        image_key is not None
    ), "pretraining_image_captions_hook requires an image_key"
if isinstance(example.get(image_key), dict):
        # Datasets downloaded directly from Hugging Face come in this format.
return [
{
"content": [
{"image": example.get(image_key).get("path")},
{"text": example.get(caption_key)},
],
}
]
else:
return [
{
"content": [
{"image": example.get(image_key)},
{"text": example.get(caption_key)},
],
}
]


def text_read_hook(
example: Dict[str, Any], **read_hook_kwargs: Any
) -> List[Dict[str, Any]]:
"""
Transforms text data for reading.
Args:
example (Dict[str, Any]): The input data containing text information.
**read_hook_kwargs (Any): Additional keyword arguments containing data_keys.
Returns:
        List[Dict[str, Any]]: Transformed data in semantic data array format.
Raises:
AssertionError: If required keys are not provided in read_hook_kwargs.
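
    Example (illustrative sketch; the field name ``text`` is hypothetical)::

        example = {"text": "Hello, world."}
        sda = text_read_hook(example, data_keys={"text_key": "text"})
        # -> [{"content": [{"text": "Hello, world."}]}]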
"""
data_keys = read_hook_kwargs.get("data_keys")
assert (
        data_keys is not None
), "data_keys should be provided in the read_hook_kwargs section"
text_key = data_keys.get('text_key', None)
    assert text_key is not None, "text_read_hook requires a text_key"
return [
{
"content": [
{"text": example.get(text_key, "")},
],
}
]


def nlg_read_hook(
example: Dict[str, Any], **read_hook_kwargs: Any
) -> List[Dict[str, Any]]:
"""
Transforms natural language generation (NLG) data for reading.
Args:
example (Dict[str, Any]): The input data containing NLG information.
**read_hook_kwargs (Any): Additional keyword arguments containing data_keys.
Returns:
        List[Dict[str, Any]]: Transformed data in semantic data array format.
Raises:
AssertionError: If required keys are not provided in read_hook_kwargs.
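
    Example (illustrative sketch; the field names ``context`` and ``target``
    are hypothetical)::

        example = {
            "context": "name[Aromi], food[Italian]",
            "target": "Aromi serves Italian food.",
        }
        sda = nlg_read_hook(
            example,
            data_keys={"context_key": "context", "completion_key": "target"},
        )
        # -> [{"type": "context", ...}, {"type": "completion", ...}]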
"""
data_keys = read_hook_kwargs.get("data_keys")
assert (
        data_keys is not None
), "data_keys should be provided in the read_hook_kwargs section"
context_key = data_keys.get('context_key', None)
completion_key = data_keys.get('completion_key', None)
assert (
context_key is not None and completion_key is not None
), "nlg_read_hook requires a context_key and a completion_key"
return [
{
"type": "context",
"content": [
{"text": example.get(context_key, "")},
],
},
{
"type": "completion",
"content": [
{"text": example.get(completion_key, "")},
],
},
]


def prompt_completion_text_read_hook(
example: Dict[str, Any], **read_hook_kwargs
) -> List[Dict[str, Any]]:
"""
Process prompt and completion text data into a semantic_data_array format.
Args:
example (Dict[str, Any]): The example data to process.
**read_hook_kwargs: Additional keyword arguments for processing.
Returns:
List[Dict[str, Any]]: A list of dictionaries in semantic_data_array format.
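
    Example (illustrative sketch; the field names ``prompt`` and
    ``completion`` are hypothetical)::

        example = {
            "prompt": "Translate to French: Hello",
            "completion": "Bonjour",
        }
        sda = prompt_completion_text_read_hook(
            example,
            data_keys={"prompt_key": "prompt", "completion_key": "completion"},
        )
        # -> [{"type": "prompt", ...}, {"type": "completion", ...}]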
"""
data_keys = read_hook_kwargs.get("data_keys")
assert (
        data_keys is not None
), "data_keys should be provided in the read_hook_kwargs section"
prompt_key = data_keys.get('prompt_key', None)
completion_key = data_keys.get('completion_key', None)
assert (
prompt_key is not None and completion_key is not None
), "prompt_completion_read_hook requires a prompt_key and a completion_key"
return [
{
"type": "prompt",
"content": [
{"text": example.get(prompt_key)},
],
},
{
"type": "completion",
"content": [
{"text": example.get(completion_key)},
],
},
]


def chat_read_hook(
example: Dict[str, Any], **read_hook_kwargs: Any
) -> List[Dict[str, Any]]:
"""
Transforms chat data for reading.
Args:
example (Dict[str, Any]): The input data containing chat messages.
**read_hook_kwargs (Any): Additional keyword arguments containing data_keys.
Returns:
        List[Dict[str, Any]]: Transformed data in semantic data array format.
Raises:
AssertionError: If required keys are not provided in read_hook_kwargs.
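
    Example (illustrative sketch; the field name ``messages`` is
    hypothetical, and the per-turn content key defaults to ``content``)::

        example = {
            "messages": [
                {"role": "user", "content": "Hi there!"},
                {"role": "assistant", "content": "Hello! How can I help?"},
            ]
        }
        sda = chat_read_hook(example, data_keys={"multi_turn_key": "messages"})
        # -> [{"type": "user", ...}, {"type": "assistant", ...}]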
"""
    # This API assumes the dataset is in ChatML format.
data_keys = read_hook_kwargs.get("data_keys")
assert (
        data_keys is not None
), "data_keys should be provided in the read_hook_kwargs section"
multi_turn_key = data_keys.get('multi_turn_key')
assert (
multi_turn_key is not None
), "multi_turn_chat_read_hook requires a multi_turn_key"
conversation_data = example.get(multi_turn_key, [])
content_key = read_hook_kwargs.get('multi_turn_content_key', "content")
has_system_prompt = read_hook_kwargs.get('has_system_prompt', False)
semantic_data_array = []
    if has_system_prompt:
        # The first element of the conversation holds the system prompt;
        # extract its text content rather than appending the whole turn dict.
        system_turn = conversation_data.pop(0)
        semantic_data_array.append(
            {"type": "system", "content": [{"text": system_turn.get(content_key)}]}
        )
for i, turn in enumerate(conversation_data):
role = "user" if i % 2 == 0 else "assistant"
content = turn.get(content_key)
if content:
            # Some tokenizers (e.g., Llama 3) strip whitespace around each turn
            # when applying the chat template. The semantic region content must
            # stay in sync with the string produced by the chat template.
content = content.strip()
semantic_data_array.append(
{"type": role, "content": [{"text": content}]}
)
return semantic_data_array


def dpo_read_hook(
example: Dict[str, Any],
**read_hook_kwargs: Any,
) -> List[Dict[str, Any]]:
"""
Transforms data for the Direct Preference Optimization (DPO) task into a semantic data array format.
Args:
example (Dict[str, Any]): The input example data.
**read_hook_kwargs (Any): Additional keyword arguments containing data_keys.
Returns:
List[Dict[str, Any]]: Transformed data suitable for the DPO task.
Raises:
AssertionError: If required keys are not provided in read_hook_kwargs.
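
    Example (illustrative sketch; the field names ``prompt``, ``chosen``
    and ``rejected`` are hypothetical)::

        example = {
            "prompt": "What is the capital of France?",
            "chosen": "The capital of France is Paris.",
            "rejected": "I am not sure.",
        }
        sda = dpo_read_hook(
            example,
            data_keys={
                "prompt_key": "prompt",
                "chosen_key": "chosen",
                "rejected_key": "rejected",
            },
        )
        # -> [{"type": "prompt", ...}, {"type": "chosen", ...},
        #     {"type": "rejected", ...}]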
"""
data_keys = read_hook_kwargs.get("data_keys")
assert (
        data_keys is not None
), "data_keys should be provided in the read_hook_kwargs section"
prompt_key = data_keys.get("prompt_key", None)
chosen_key = data_keys.get("chosen_key", None)
rejected_key = data_keys.get("rejected_key", None)
assistant_role = read_hook_kwargs.get("assistant_role", "assistant:")
    semantic_data_array = []
if isinstance(example, dict) and all(
isinstance(k, str) and isinstance(v, str) for k, v in example.items()
):
if prompt_key:
prompt = {}
prompt['content'] = [{"text": example.get(prompt_key, "")}]
prompt['type'] = "prompt"
chosen = {}
chosen['content'] = [{"text": example.get(chosen_key, "")}]
chosen['type'] = "chosen"
rejected = {}
rejected['content'] = [{"text": example.get(rejected_key, "")}]
rejected['type'] = "rejected"
            semantic_data_array.append(prompt)
            semantic_data_array.append(chosen)
            semantic_data_array.append(rejected)
else:
chosen_str = example.get(chosen_key, "")
rejected_str = example.get(rejected_key, "")
last_assistant_index = chosen_str.lower().rfind(assistant_role)
if last_assistant_index == -1:
logger.warning(
f"Can't determine prompt from the chosen string. No demarcation found. Skipping this doc..."
)
return []
prompt_str = chosen_str[
: last_assistant_index + len(assistant_role)
]
chosen_str = chosen_str[
last_assistant_index + len(assistant_role) :
]
rejected_str = rejected_str[
last_assistant_index + len(assistant_role) :
]
prompt = {}
prompt['content'] = [{"text": prompt_str}]
prompt['type'] = "prompt"
chosen = {}
chosen['content'] = [{"text": chosen_str}]
chosen['type'] = "chosen"
rejected = {}
rejected['content'] = [{"text": rejected_str}]
rejected['type'] = "rejected"
            semantic_data_array.append(prompt)
            semantic_data_array.append(chosen)
            semantic_data_array.append(rejected)
elif isinstance(example, dict) and all(
isinstance(k, str) and isinstance(v, list) for k, v in example.items()
):
chosen_list = example.get(chosen_key, None)
assert chosen_list, "chosen list must be provided"
rejected_list = example.get(rejected_key, None)
assert rejected_list, "rejected list must be provided"
        # The only dataset available with a list of dicts has only
        # prompt and response entries, hence the size is assumed
        # to be 2.
prompt_str = chosen_list[0]['content']
chosen_str = chosen_list[1]['content']
rejected_str = rejected_list[1]['content']
prompt = {}
prompt['content'] = [{"text": prompt_str}]
prompt['type'] = "prompt"
chosen = {}
chosen['content'] = [{"text": chosen_str}]
chosen['type'] = "chosen"
rejected = {}
rejected['content'] = [{"text": rejected_str}]
rejected['type'] = "rejected"
        semantic_data_array.append(prompt)
        semantic_data_array.append(chosen)
        semantic_data_array.append(rejected)
    return semantic_data_array


def prompt_completion_chat_read_hook(
example: Dict[str, Any], **read_hook_kwargs: Any
) -> List[Dict[str, Any]]:
"""
Process prompt and completion data from a chat into a semantic_data_array format.
Args:
example (Dict[str, Any]): The example data to process.
**read_hook_kwargs: Additional keyword arguments for processing.
Returns:
List[Dict[str, Any]]: A list of dictionaries in semantic_data_array format.
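
    Example (illustrative sketch; the field names ``prompt`` and
    ``completion`` are hypothetical)::

        example = {"prompt": "What is 2 + 2?", "completion": "4"}
        sda = prompt_completion_chat_read_hook(
            example,
            data_keys={"prompt_key": "prompt", "completion_key": "completion"},
        )
        # -> [{"type": "user", ...}, {"type": "assistant", ...}]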
"""
data_keys = read_hook_kwargs.get("data_keys")
assert (
        data_keys is not None
), "data_keys should be provided in the read_hook_kwargs section"
prompt_key = data_keys.get('prompt_key', None)
completion_key = data_keys.get('completion_key', None)
assert (
prompt_key is not None and completion_key is not None
), "prompt_completion_chat_read_hook requires a prompt_key and a completion_key"
return [
{
"type": "user",
"content": [
{
"text": (
example.get(prompt_key).strip()
if example.get(prompt_key)
else None
)
},
],
},
{
"type": "assistant",
"content": [
{
"text": (
example.get(completion_key).strip()
if example.get(completion_key)
else None
)
},
],
},
]


def finetuning_image_captions_hook(
example: Dict[str, Any], **read_hook_kwargs
) -> List[Dict[str, Any]]:
"""
Process finetuning image captions data into a semantic_data_array format.
Args:
example (Dict[str, Any]): The example data to process.
**read_hook_kwargs: Additional keyword arguments for processing.
Returns:
List[Dict[str, Any]]: A list of dictionaries in semantic_data_array format.
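
    Example (illustrative sketch; the field names ``image`` and ``caption``
    are hypothetical)::

        example = {"image": "images/0002.jpg", "caption": "A snowy mountain."}
        sda = finetuning_image_captions_hook(
            example,
            data_keys={"image_key": "image", "caption_key": "caption"},
        )
        # -> [{"type": "prompt", "content": [{"image": "images/0002.jpg"}]},
        #     {"type": "completion", "content": [{"text": "A snowy mountain."}]}]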
"""
data_keys = read_hook_kwargs.get("data_keys")
assert (
        data_keys is not None
), "data_keys should be provided in the read_hook_kwargs section"
image_key = data_keys.get('image_key', None)
caption_key = data_keys.get('caption_key', None)
assert (
        image_key is not None
    ), "finetuning_image_captions_hook requires an image_key"
if isinstance(example.get(image_key), dict):
        # Datasets downloaded directly from Hugging Face come in this format.
return [
{
"type": "prompt",
"content": [
{"image": example.get(image_key).get("path")},
],
},
{
"type": "completion",
"content": [
{"text": example.get(caption_key)},
],
},
]
else:
return [
{
"type": "prompt",
"content": [
{"image": example.get(image_key)},
],
},
{
"type": "completion",
"content": [
{"text": example.get(caption_key)},
],
},
]


def finetuning_llava_hook_prompt_completion(
example: Dict[str, Any], **read_hook_kwargs: Any
) -> List[Dict[str, Any]]:
"""
Transforms conversation data for finetuning LLaVA.
Args:
example (Dict[str, Any]): The input data containing conversation and image paths.
**read_hook_kwargs (Any): Additional keyword arguments containing data_keys, system_prompt, image_token, multi_turn_content_key, and phase.
Returns:
List[Dict[str, Any]]: Transformed data suitable for finetuning LLaVA.
Raises:
AssertionError: If required keys are not provided in read_hook_kwargs.
ValueError: If image_token is not provided, or if there are multiple image tokens in the user's role, or if image tokens are found in the assistant's response.
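
    Example (illustrative sketch; the dataset field names ``conversations``,
    ``value`` and ``image`` are hypothetical; ``phase`` defaults to 1)::

        example = {
            "conversations": [
                {"value": "Describe <image> in one short sentence."},
                {"value": "A dog playing fetch in a park."},
            ],
            "image": "images/0001.jpg",
        }
        sda = finetuning_llava_hook_prompt_completion(
            example,
            data_keys={
                "multi_turn_key": "conversations",
                "image_key": "image",
            },
            image_token="<image>",
        )
        # -> [{"type": "prompt", ...}, {"type": "completion", ...}]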
"""
data_keys = read_hook_kwargs.get("data_keys")
assert (
        data_keys is not None
), "data_keys should be provided in the read_hook_kwargs section"
multi_turn_key = data_keys.get("multi_turn_key")
image_key = data_keys.get("image_key")
system_prompt = read_hook_kwargs.get("system_prompt")
image_token = read_hook_kwargs.get("image_token", None)
multi_turn_content_key = read_hook_kwargs.get(
"multi_turn_content_key", "value"
)
phase = read_hook_kwargs.get("phase", 1)
assert (
        image_token is not None
), "image_token should be provided in the read_hook_kwargs section for llava"
conversation_data = example.get(multi_turn_key, [])
if conversation_data is None:
conversation_data = []
image_path = example.get(image_key, None)
transformed_data = []
if system_prompt:
system_prompt_text = SYSTEM_PROMPT_REGISTRY.get(system_prompt, "")
system_data = {
"type": "system",
"content": [{"text": system_prompt_text.strip()}],
}
transformed_data.append(system_data)
if image_path and not image_token:
raise ValueError(
"Image token has not been provided inside read_hook_kwargs within the processing section in the config file for llava finetuning datasets."
)
for i, turn in enumerate(conversation_data):
semantic_drop_mask = []
role = "prompt" if i % 2 == 0 else "completion"
content_parts = []
if role == "prompt":
            # Check for multiple image tokens in the prompt (user turn)
            if turn[multi_turn_content_key].count(image_token) > 1:
                raise ValueError(
                    "Multiple image tokens found in the prompt. Only one image token is allowed."
                )
            # Assume there's only one image token in the prompt
parts = re.split(
re.escape(image_token), turn[multi_turn_content_key]
)
if len(parts) == 2:
# Add the image part before the text
content_parts.append({"image": image_path})
parts = [part.strip() for part in parts]
if parts[0] != '' and parts[1] != '':
content_parts.append({"text": parts[0] + parts[1]})
if phase == 1:
semantic_drop_mask.extend([False, True])
else:
semantic_drop_mask.extend([False, False])
else:
semantic_drop_mask.extend([False])
else:
# No image token found, add the text as is
content_parts.append({"text": turn[multi_turn_content_key]})
if phase == 1:
semantic_drop_mask.extend([True])
else:
semantic_drop_mask.extend([False])
elif role == "completion":
            # Check that no image tokens are present in the completion
            if image_token and image_token in turn[multi_turn_content_key]:
                raise ValueError(
                    "Image tokens are not allowed in the completion."
)
content_parts.append({"text": turn[multi_turn_content_key]})
semantic_drop_mask.extend([False])
transformed_data.append(
{
"type": role,
"content": content_parts,
"semantic_drop_mask": semantic_drop_mask,
}
)
return transformed_data