very basic DPO data generator is working now

This commit is contained in:
Alex O'Connell
2024-04-13 22:47:07 -04:00
parent 547d2d9989
commit ce75bf0d7c
2 changed files with 72 additions and 14 deletions

View File

@@ -18,7 +18,7 @@ size_categories:
This dataset contains a list of requests and responses for a user interacting with a personal assistant that controls an instance of [Home Assistant](https://www.home-assistant.io/).
The dataset is generated from the different CSV "piles". The "piles" contain different chunks of requests that are assembled into a final context that is presented to the LLM. For example, `piles/pile_of_device_names.csv` contains only names of various devices to be used as part of context as well as inserted into `piles/pile_of_templated_actions.csv` and `piles/pile_of_status_requests.csv`. The logic for assembling the final dataset from the piles is contained in [generate_home_assistant_data.py](./generate_home_assistant_data.py).
The dataset is generated from the different CSV "piles". The "piles" contain different chunks of requests that are assembled into a final context that is presented to the LLM. For example, `piles/<language>/pile_of_device_names.csv` contains only names of various devices to be used as part of context as well as inserted into `piles/<language>/pile_of_templated_actions.csv` and `piles/<language>/pile_of_status_requests.csv`. The logic for assembling the final dataset from the piles is contained in [generate_home_assistant_data.py](./generate_home_assistant_data.py).
## Generating the dataset from piles
@@ -27,6 +27,7 @@ The dataset is generated from the different CSV "piles". The "piles" contain dif
Supported dataset splits are `--test`, `--train`, & `--sample`
Arguments to set the train dataset size are `--small`, `--medium`, `--large`, & `--xl`.
Supported formats are `--raw_corpus` (chatml formatted) & `--sharegpt`
Languages can be enabled using `--language english german french spanish`
## Merging with other instruct-datasets for training
@@ -48,7 +49,6 @@ There are 2 columns in `pile_of_system_prompts.csv`:
The response pile is a CSV with the following headers: `service,response,language,persona,short`
- `service`: the service name that we are responding to. Make sure you cover enough different services so that the model can learn how to respond in all situations.
- `response`: the text of the response. Recommended to put this in quotes in case the response also has commas in it
- `language`: the language code of the response (currently only `en` is supported)
- `persona`: the name of the persona the response belongs to. Use the name of your persona here
- `short`: either 0 or 1. If it is 1 then the response is considered "short", and can be combined together with other "short" responses using "and". These are used for examples where there are multiple service calls

View File

@@ -5,6 +5,7 @@ import pandas
import numpy as np
import random
import re
import copy
from dataclasses import dataclass
from datasets import load_dataset, concatenate_datasets
from difflib import SequenceMatcher
@@ -686,16 +687,18 @@ def generate_status_request(template: dict, persona: str, max_devices: int = 32)
def generate_dpo_wrong_argument(template: dict, persona: str, max_devices: int = 32):
example = generate_templated_example(template, persona, max_devices)
rejected_example = {**example}
rejected_example = copy.deepcopy(example)
call_idx = random.randint(0, len(example["service_calls"]))
call_idx = random.randint(0, len(example["service_calls"]) - 1)
call = example["service_calls"][call_idx]
random_device = random.choice(example["states"]).split(" ")[0]
if random_device == call["target_device"]:
# TODO: random device type should probably match
while random_device == call["target_device"]:
random_device = random.choice(example["states"]).split(" ")[0]
random_service = random.choice(example["available_services"] - [ call["service"] ])
# random service should probably be "related"
random_service = random.choice([ x for x in example["available_services"] if call["service"] not in x ])[:-2]
# need examples of hallucinated names, not incorrect names
# need examples of hallucinated service names, not incorrect service
@@ -717,6 +720,9 @@ def generate_dpo_wrong_argument(template: dict, persona: str, max_devices: int =
def generate_dpo_no_service_call():
    """Placeholder for generating a DPO pair whose rejected completion
    makes no service call. Not implemented yet; returns None."""
    return None
def generate_dpo_extra_service_call():
    """Placeholder for generating a DPO pair whose rejected completion
    includes a spurious extra service call. Not implemented yet; returns None."""
    return None
def generate_dpo_incorrect_persona():
    """Placeholder for generating a DPO pair whose rejected completion
    answers in the wrong persona's voice. Not implemented yet; returns None."""
    return None
@@ -797,6 +803,7 @@ def format_example_dpo(example, persona):
# replace aliases with their actual values
assistant_block = assistant_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
rejected_assistant_block = rejected_assistant_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
states_block = states_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
services_block = services_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
@@ -807,7 +814,7 @@ def format_example_dpo(example, persona):
"rejected": rejected_assistant_block,
}
def generate_example_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, static_factor: int, template_factor: int, status_request_factor: int):
def generate_sft_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, static_factor: int, template_factor: int, status_request_factor: int):
random.seed(seed)
np.random.seed(seed)
@@ -853,6 +860,53 @@ def generate_example_file(filename: str, seed: int, format_func: Callable, perso
print("Done!")
def generate_dpo_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, wrong_argument_factor: int, no_argument_factor: int, extra_service_call_factor: int, incorrect_persona_factor: int):
    """Generate a DPO (chosen/rejected pair) dataset and write it to `<filename>.jsonl`.

    Args:
        filename: output path without the `.jsonl` extension.
        seed: seed applied to both `random` and `numpy` for reproducible output.
        format_func: callable mapping (example, persona) -> JSON-serializable record
            (e.g. `format_example_dpo`, which emits chosen/rejected blocks).
        personas: persona names to generate examples for.
        wrong_argument_factor: how many "wrong argument" rejected examples to
            produce per templated action (fractional values are treated as an
            inclusion probability — see `run_factor_times`).
        no_argument_factor, extra_service_call_factor, incorrect_persona_factor:
            accepted for interface parity but not used yet — the corresponding
            generators (`generate_dpo_no_service_call`, etc.) are still stubs.
    """
    random.seed(seed)
    np.random.seed(seed)

    print("Generating...")

    def run_factor_times(func, examples, data, persona, factor):
        # factor >= 1: emit `factor` examples; 0 < factor < 1: emit one
        # example with probability `factor` (lets callers thin a pile).
        if factor >= 1:
            for _ in range(factor):
                examples.append(format_func(func(data, persona), persona))
        elif random.random() < factor:
            examples.append(format_func(func(data, persona), persona))

    generated_examples = []
    missing_responses = set()
    for person in personas:
        # TODO: also wire up generate_dpo_no_service_call /
        # generate_dpo_extra_service_call / generate_dpo_incorrect_persona
        # once those generators are implemented, using their factors above.
        for templated_action in tqdm(pile_of_templated_actions):
            try:
                run_factor_times(generate_dpo_wrong_argument, generated_examples, templated_action, person, wrong_argument_factor)
            except NoResponseAvailableException as ex:
                missing_responses.add(str(ex))

    print(f"Generated {len(generated_examples)} DPO examples. Saving...")
    # Surface every distinct missing-response complaint so pile gaps are visible.
    for missing in sorted(missing_responses):
        print(missing)

    # BUGFIX: previously wrote to a hard-coded literal path, ignoring the
    # `filename` argument; now mirrors generate_sft_file's behavior.
    with open(f"{filename}.jsonl", "w") as f:
        for item in generated_examples:
            f.write(json.dumps(item) + "\n")
    print("Done!")
def format_alpaca(example, format_func: Callable):
question = example["instruction"]
if "input" in example and example["input"]:
@@ -973,6 +1027,7 @@ def main():
parser.add_argument("--sample", action="store_true", help="Set this flag to enable generation of the train dataset.")
parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset..")
parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
parser.add_argument("--dpo", action="store_true", help="Set this flag to enable generation of the DPO dataset.")
parser.add_argument("--merge", help="Set this flag to merge the generated datasets with the specified dataset.")
parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate")
@@ -988,7 +1043,7 @@ def main():
args = parser.parse_args()
if not args.sample and not args.train and not args.test and not args.merge:
if not args.sample and not args.train and not args.test and not args.merge and not args.dpo:
parser.print_usage()
if not args.format or args.format == "raw":
@@ -1002,20 +1057,20 @@ def main():
suffix = f"_{language}" if len(args.language) > 1 else ""
if args.sample:
generate_example_file(f"sample{suffix}", 42, format_func, personas, static_factor=1, template_factor=1, status_request_factor=1)
generate_sft_file(f"sample{suffix}", 42, format_func, personas, static_factor=1, template_factor=1, status_request_factor=1)
if args.train:
if args.size == "small":
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=1, template_factor=10, status_request_factor=8)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=1, template_factor=10, status_request_factor=8)
elif args.size == "medium":
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=15, status_request_factor=12)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=15, status_request_factor=12)
elif args.size == "large":
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=20, status_request_factor=15)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=20, status_request_factor=15)
elif args.size == "xl":
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=7, template_factor=25, status_request_factor=18)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=7, template_factor=25, status_request_factor=18)
else:
raise Exception(f"Unrecognized dataset size: {args.size}")
if args.test:
generate_example_file(f"home_assistant_test{suffix}", 12345, format_func, personas, static_factor=0.25, template_factor=1, status_request_factor=2)
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, personas, static_factor=0.25, template_factor=1, status_request_factor=2)
if len(args.language) > 1:
if args.sample:
@@ -1025,6 +1080,9 @@ def main():
if args.test:
merge_languages("home_assistant_test", args.language)
if args.dpo:
generate_dpo_file(f"home_assistant_dpo", 42, format_example_dpo, personas, wrong_argument_factor=1, no_argument_factor=0, extra_service_call_factor=0, incorrect_persona_factor=0)
if args.merge == "alpaca":
merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca, ["input", "output", "instruction"], format_func)
elif args.merge == "wizardlm70k":