very basic DPO data generator is working now

Alex O'Connell
2024-04-13 22:47:07 -04:00
parent 547d2d9989
commit ce75bf0d7c
2 changed files with 72 additions and 14 deletions

View File

@@ -18,7 +18,7 @@ size_categories:
This dataset contains a list of requests and responses for a user interacting with a personal assistant that controls an instance of [Home Assistant](https://www.home-assistant.io/).

-The dataset is generated from the different CSV "piles". The "piles" contain different chunks of requests that are assembled into a final context that is presented to the LLM. For example, `piles/pile_of_device_names.csv` contains only names of various devices to be used as part of context as well as inserted into `piles/pile_of_templated_actions.csv` and `piles/pile_of_status_requests.csv`. The logic for assembling the final dataset from the piles is contained in [generate_home_assistant_data.py](./generate_home_assistant_data.py).
+The dataset is generated from the different CSV "piles". The "piles" contain different chunks of requests that are assembled into a final context that is presented to the LLM. For example, `piles/<language>/pile_of_device_names.csv` contains only names of various devices to be used as part of context as well as inserted into `piles/<language>/pile_of_templated_actions.csv` and `piles/<language>/pile_of_status_requests.csv`. The logic for assembling the final dataset from the piles is contained in [generate_home_assistant_data.py](./generate_home_assistant_data.py).

## Generating the dataset from piles
@@ -27,6 +27,7 @@ The dataset is generated from the different CSV "piles". The "piles" contain dif
Supported dataset splits are `--test`, `--train`, & `--sample`
Arguments to set the train dataset size are `--small`, `--medium`, `--large`, & `--xl`.
Supported formats are `--raw_corpus` (chatml formatted) & `--sharegpt`
+Languages can be enabled using `--language english german french spanish`

## Merging with other instruct-datasets for training
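As a rough illustration of how the `--language` flag lines up with the per-language pile directories introduced above (a sketch only; the paths come from the README, the loop itself is an assumption about how the generator iterates):

```python
from pathlib import Path

# Sketch: each language requested via --language reads its own pile directory,
# e.g. piles/english/pile_of_device_names.csv, piles/german/pile_of_device_names.csv
for language in ["english", "german"]:
    pile_dir = Path("piles") / language
    print(pile_dir / "pile_of_device_names.csv")
    print(pile_dir / "pile_of_templated_actions.csv")
```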
@@ -48,7 +49,6 @@ There are 2 columns in `pile_of_system_prompts.csv`:
The response pile is a CSV with the following headers: `service,response,language,persona,short`
- `service`: the service name that we are responding to. Make sure you cover enough different services so that the model can learn how to respond in all situations.
- `response`: the text of the response. It is recommended to wrap this in quotes in case the response itself contains commas.
-- `language`: the language code of the response (currently only `en` is supported)
- `persona`: the name of the persona the response belongs to. Use the name of your persona here.
- `short`: either 0 or 1. If it is 1, the response is considered "short" and can be combined with other "short" responses using "and". These are used for examples that involve multiple service calls, as sketched below.
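To make the `short` flag concrete, here is a tiny standalone sketch (with made-up response strings) of how short responses could be chained with "and" when one request triggers several service calls:

```python
# Hypothetical illustration of the short=1 behaviour described above:
# responses marked as "short" are joined with "and" when multiple service calls occur.
short_responses = ["turning on the kitchen light", "closing the garage door"]
print(" and ".join(short_responses))
# -> "turning on the kitchen light and closing the garage door"
```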

View File

@@ -5,6 +5,7 @@ import pandas
import numpy as np
import random
import re
+import copy
from dataclasses import dataclass
from datasets import load_dataset, concatenate_datasets
from difflib import SequenceMatcher
@@ -686,16 +687,18 @@ def generate_status_request(template: dict, persona: str, max_devices: int = 32)
def generate_dpo_wrong_argument(template: dict, persona: str, max_devices: int = 32):
    example = generate_templated_example(template, persona, max_devices)
-    rejected_example = {**example}
+    rejected_example = copy.deepcopy(example)

-    call_idx = random.randint(0, len(example["service_calls"]))
+    call_idx = random.randint(0, len(example["service_calls"]) - 1)
    call = example["service_calls"][call_idx]

    random_device = random.choice(example["states"]).split(" ")[0]
-    if random_device == call["target_device"]:
+    # TODO: random device type should probably match
+    while random_device == call["target_device"]:
        random_device = random.choice(example["states"]).split(" ")[0]

-    random_service = random.choice(example["available_services"] - [ call["service"] ])
+    # random service should probably be "related"
+    random_service = random.choice([ x for x in example["available_services"] if call["service"] not in x ])[:-2]

    # need examples of hallucinated names, not incorrect names
    # need examples of hallucinated service names, not incorrect service
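One note on the `{**example}` → `copy.deepcopy(example)` change: `{**example}` is only a shallow copy, so nested structures such as the `service_calls` list stay shared, and mutating the rejected example would silently corrupt the chosen one. A standalone sketch with made-up data:

```python
import copy

# Shallow copy: the nested list is shared, so the "chosen" example is corrupted too.
example = {"service_calls": [{"service": "light.turn_on", "target_device": "light.kitchen"}]}
shallow = {**example}
shallow["service_calls"][0]["target_device"] = "light.garage"
print(example["service_calls"][0]["target_device"])  # light.garage (original mutated)

# Deep copy: the rejected example can be mutated independently.
example = {"service_calls": [{"service": "light.turn_on", "target_device": "light.kitchen"}]}
rejected = copy.deepcopy(example)
rejected["service_calls"][0]["target_device"] = "light.garage"
print(example["service_calls"][0]["target_device"])  # light.kitchen (original untouched)
```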
@@ -717,6 +720,9 @@ def generate_dpo_wrong_argument(template: dict, persona: str, max_devices: int =
def generate_dpo_no_service_call():
    pass

+def generate_dpo_extra_service_call():
+    pass

def generate_dpo_incorrect_persona():
    pass
@@ -797,6 +803,7 @@ def format_example_dpo(example, persona):
    # replace aliases with their actual values
    assistant_block = assistant_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
+    rejected_assistant_block = rejected_assistant_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
    states_block = states_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
    services_block = services_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
@@ -807,7 +814,7 @@ def format_example_dpo(example, persona):
"rejected": rejected_assistant_block, "rejected": rejected_assistant_block,
} }
def generate_example_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, static_factor: int, template_factor: int, status_request_factor: int): def generate_sft_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, static_factor: int, template_factor: int, status_request_factor: int):
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
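For orientation, the dict returned by `format_example_dpo` carries at least a `rejected` field alongside the accepted completion. A hypothetical record shape (every key other than `rejected` is an assumption, not confirmed by this diff):

```python
# Hypothetical shape of one formatted DPO example; only "rejected" is visible
# in the hunk above, the remaining keys are assumptions for illustration.
dpo_example = {
    "system": "<persona system prompt plus device states and available services>",
    "question": "<the user request>",
    "chosen": "<assistant turn with the correct service call>",
    "rejected": "<assistant turn with a wrong device or wrong service>",
}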
@@ -853,6 +860,53 @@ def generate_example_file(filename: str, seed: int, format_func: Callable, perso
print("Done!") print("Done!")
def generate_dpo_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, wrong_argument_factor: int, no_argument_factor: int, extra_service_call_factor: int, incorrect_persona_factor: int):
random.seed(seed)
np.random.seed(seed)
print("Generating...")
def run_factor_times(func, examples, data, persona, factor):
if factor >= 1:
for i in range(factor):
examples.append(format_func(func(data, persona), persona))
else:
if random.random() < factor:
examples.append(format_func(func(data, persona), persona))
generated_examples = []
missing_responses = set()
for person in personas:
# for action in tqdm(pile_of_specific_actions):
# try:
# run_factor_times(generate_static_example, generated_examples, action, person, static_factor)
# except NoResponseAvailableException as ex:
# missing_responses.add(str(ex))
for templated_action in tqdm(pile_of_templated_actions):
try:
run_factor_times(generate_dpo_wrong_argument, generated_examples, templated_action, person, wrong_argument_factor)
except NoResponseAvailableException as ex:
missing_responses.add(str(ex))
# for status_request in tqdm(pile_of_status_requests):
# run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor)
print(f"Generated {len(generated_examples)} DPO examples. Saving...")
for missing in sorted(missing_responses):
print(missing)
with open(f"{filename}.jsonl", "w") as f:
for item in generated_examples:
json_record = json.dumps(item)
f.write(json_record + '\n')
print("Done!")
def format_alpaca(example, format_func: Callable): def format_alpaca(example, format_func: Callable):
question = example["instruction"] question = example["instruction"]
if "input" in example and example["input"]: if "input" in example and example["input"]:
@@ -973,6 +1027,7 @@ def main():
parser.add_argument("--sample", action="store_true", help="Set this flag to enable generation of the train dataset.") parser.add_argument("--sample", action="store_true", help="Set this flag to enable generation of the train dataset.")
parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset..") parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset..")
parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.") parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
parser.add_argument("--dpo", action="store_true", help="Set this flag to enable generation of the DPO dataset.")
parser.add_argument("--merge", help="Set this flag to merge the generated datasets with the specified dataset.") parser.add_argument("--merge", help="Set this flag to merge the generated datasets with the specified dataset.")
parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate") parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate")
@@ -988,7 +1043,7 @@ def main():
    args = parser.parse_args()

-    if not args.sample and not args.train and not args.test and not args.merge:
+    if not args.sample and not args.train and not args.test and not args.merge and not args.dpo:
        parser.print_usage()

    if not args.format or args.format == "raw":
@@ -1002,20 +1057,20 @@ def main():
suffix = f"_{language}" if len(args.language) > 1 else "" suffix = f"_{language}" if len(args.language) > 1 else ""
if args.sample: if args.sample:
generate_example_file(f"sample{suffix}", 42, format_func, personas, static_factor=1, template_factor=1, status_request_factor=1) generate_sft_file(f"sample{suffix}", 42, format_func, personas, static_factor=1, template_factor=1, status_request_factor=1)
if args.train: if args.train:
if args.size == "small": if args.size == "small":
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=1, template_factor=10, status_request_factor=8) generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=1, template_factor=10, status_request_factor=8)
elif args.size == "medium": elif args.size == "medium":
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=15, status_request_factor=12) generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=15, status_request_factor=12)
elif args.size == "large": elif args.size == "large":
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=20, status_request_factor=15) generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=20, status_request_factor=15)
elif args.size == "xl": elif args.size == "xl":
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=7, template_factor=25, status_request_factor=18) generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=7, template_factor=25, status_request_factor=18)
else: else:
raise Exception(f"Unrecognized dataset size: {args.size}") raise Exception(f"Unrecognized dataset size: {args.size}")
if args.test: if args.test:
generate_example_file(f"home_assistant_test{suffix}", 12345, format_func, personas, static_factor=0.25, template_factor=1, status_request_factor=2) generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, personas, static_factor=0.25, template_factor=1, status_request_factor=2)
if len(args.language) > 1: if len(args.language) > 1:
if args.sample: if args.sample:
@@ -1025,6 +1080,9 @@ def main():
        if args.test:
            merge_languages("home_assistant_test", args.language)

+    if args.dpo:
+        generate_dpo_file(f"home_assistant_dpo", 42, format_example_dpo, personas, wrong_argument_factor=1, no_argument_factor=0, extra_service_call_factor=0, incorrect_persona_factor=0)

    if args.merge == "alpaca":
        merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca, ["input", "output", "instruction"], format_func)
    elif args.merge == "wizardlm70k":