mirror of https://github.com/acon96/home-llm.git
very basic DPO data generator is working now
This commit is contained in:
@@ -18,7 +18,7 @@ size_categories:
This dataset contains a list of requests and responses for a user interacting with a personal assistant that controls an instance of [Home Assistant](https://www.home-assistant.io/).

The dataset is generated from the different CSV "piles". The "piles" contain different chunks of requests that are assembled into a final context that is presented to the LLM. For example, `piles/pile_of_device_names.csv` contains only names of various devices to be used as part of context as well as inserted into `piles/pile_of_templated_actions.csv` and `piles/pile_of_status_requests.csv`. The logic for assembling the final dataset from the piles is contained in [generate_home_assistant_data.py](./generate_home_assistant_data.py).
The dataset is generated from the different CSV "piles". The "piles" contain different chunks of requests that are assembled into a final context that is presented to the LLM. For example, `piles/<language>/pile_of_device_names.csv` contains only names of various devices to be used as part of context as well as inserted into `piles/<language>/pile_of_templated_actions.csv` and `piles/<language>/pile_of_status_requests.csv`. The logic for assembling the final dataset from the piles is contained in [generate_home_assistant_data.py](./generate_home_assistant_data.py).
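As a rough sketch of that assembly step (illustrative only: the pile paths follow the layout above, but the CSV column names and the `<device>` placeholder are assumptions, and the real logic lives in [generate_home_assistant_data.py](./generate_home_assistant_data.py)):

```python
import csv
import random

# Hypothetical sketch: pick a device name from one pile and substitute it into a
# templated action from another pile. Column names here are assumed, not taken
# from the actual CSV files.
with open("piles/english/pile_of_device_names.csv") as f:
    device_names = [row["device_name"] for row in csv.DictReader(f)]

with open("piles/english/pile_of_templated_actions.csv") as f:
    templated_actions = list(csv.DictReader(f))

template = random.choice(templated_actions)
device = random.choice(device_names)
# e.g. "turn on the <device>" -> "turn on the kitchen light"
request = template["phrase"].replace("<device>", device)
print(request)
```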
## Generating the dataset from piles
@@ -27,6 +27,7 @@ The dataset is generated from the different CSV "piles". The "piles" contain dif
Supported dataset splits are `--test`, `--train`, & `--sample`
Arguments to set the train dataset size are `--small`, `--medium`, `--large`, & `--xl`.
Supported formats are `--raw_corpus` (chatml formatted) & `--sharegpt`
Languages can be enabled using `--language english german french spanish`
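A plausible invocation combining these flags (flag spellings are taken from the lines above; adjust to your setup):

```
python3 generate_home_assistant_data.py --train --test --large --sharegpt --language english german
```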
## Merging with other instruct-datasets for training
@@ -48,7 +49,6 @@ There are 2 columns in `pile_of_system_prompts.csv`:
The response pile is a CSV with the following headers: `service,response,language,persona,short`
- `service`: the service name that we are responding to. Make sure you cover enough different services so that the model can learn how to respond in all situations.
- `response`: the text of the response. It is recommended to quote this field in case the response also contains commas.
- `language`: the language code of the response (currently only `en` is supported)
- `persona`: the name of the persona the response belongs to. Use the name of your persona here
- `short`: either 0 or 1. If it is 1, the response is considered "short" and can be combined with other "short" responses using "and". These are used for examples that involve multiple service calls
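A hypothetical example row under these headers (the service name, response text, and persona below are placeholders rather than values from the real pile):

```csv
service,response,language,persona,short
light.turn_on,"Sure, turning on the light for you now.",en,assistant,1
```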
@@ -5,6 +5,7 @@ import pandas
import numpy as np
import random
import re
import copy
from dataclasses import dataclass
from datasets import load_dataset, concatenate_datasets
from difflib import SequenceMatcher
@@ -686,16 +687,18 @@ def generate_status_request(template: dict, persona: str, max_devices: int = 32)
def generate_dpo_wrong_argument(template: dict, persona: str, max_devices: int = 32):
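# Build a DPO pair from one templated action: the normally generated example is kept as
# the preferred response, and a deep copy with one service call's device and service
# swapped for plausible-but-wrong values is used as the rejected response.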
example = generate_templated_example(template, persona, max_devices)
rejected_example = {**example}
rejected_example = copy.deepcopy(example)

call_idx = random.randint(0, len(example["service_calls"]))
call_idx = random.randint(0, len(example["service_calls"]) - 1)
call = example["service_calls"][call_idx]

random_device = random.choice(example["states"]).split(" ")[0]
if random_device == call["target_device"]:
# TODO: random device type should probably match
while random_device == call["target_device"]:
random_device = random.choice(example["states"]).split(" ")[0]

random_service = random.choice(example["available_services"] - [ call["service"] ])
# random service should probably be "related"
random_service = random.choice([ x for x in example["available_services"] if call["service"] not in x ])[:-2]

# need examples of hallucinated names, not incorrect names
# need examples of hallucinated service names, not incorrect service
@@ -717,6 +720,9 @@ def generate_dpo_wrong_argument(template: dict, persona: str, max_devices: int =
def generate_dpo_no_service_call():
pass

def generate_dpo_extra_service_call():
pass

def generate_dpo_incorrect_persona():
pass
@@ -797,6 +803,7 @@ def format_example_dpo(example, persona):
# replace aliases with their actual values
assistant_block = assistant_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
rejected_assistant_block = rejected_assistant_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
states_block = states_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
services_block = services_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
@@ -807,7 +814,7 @@ def format_example_dpo(example, persona):
"rejected": rejected_assistant_block,
}

def generate_example_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, static_factor: int, template_factor: int, status_request_factor: int):
def generate_sft_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, static_factor: int, template_factor: int, status_request_factor: int):
random.seed(seed)
np.random.seed(seed)
@@ -853,6 +860,53 @@ def generate_example_file(filename: str, seed: int, format_func: Callable, perso
print("Done!")
def generate_dpo_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, wrong_argument_factor: int, no_argument_factor: int, extra_service_call_factor: int, incorrect_persona_factor: int):
random.seed(seed)
np.random.seed(seed)

print("Generating...")
def run_factor_times(func, examples, data, persona, factor):
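# An integer factor >= 1 repeats the example that many times; a fractional factor below 1
# instead includes the example with that probability.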
if factor >= 1:
for i in range(factor):
examples.append(format_func(func(data, persona), persona))
else:
if random.random() < factor:
examples.append(format_func(func(data, persona), persona))

generated_examples = []

missing_responses = set()
for person in personas:
# for action in tqdm(pile_of_specific_actions):
# try:
# run_factor_times(generate_static_example, generated_examples, action, person, static_factor)
# except NoResponseAvailableException as ex:
# missing_responses.add(str(ex))

for templated_action in tqdm(pile_of_templated_actions):
try:
run_factor_times(generate_dpo_wrong_argument, generated_examples, templated_action, person, wrong_argument_factor)
except NoResponseAvailableException as ex:
missing_responses.add(str(ex))

# for status_request in tqdm(pile_of_status_requests):
# run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor)

print(f"Generated {len(generated_examples)} DPO examples. Saving...")

for missing in sorted(missing_responses):
print(missing)

with open(f"{filename}.jsonl", "w") as f:
for item in generated_examples:
json_record = json.dumps(item)
f.write(json_record + '\n')

print("Done!")
def format_alpaca(example, format_func: Callable):
question = example["instruction"]
if "input" in example and example["input"]:
@@ -973,6 +1027,7 @@ def main():
parser.add_argument("--sample", action="store_true", help="Set this flag to enable generation of the sample dataset.")
parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the test dataset.")
parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
parser.add_argument("--dpo", action="store_true", help="Set this flag to enable generation of the DPO dataset.")
parser.add_argument("--merge", help="Set this flag to merge the generated datasets with the specified dataset.")
parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate")
@@ -988,7 +1043,7 @@ def main():
args = parser.parse_args()
if not args.sample and not args.train and not args.test and not args.merge:
if not args.sample and not args.train and not args.test and not args.merge and not args.dpo:
parser.print_usage()

if not args.format or args.format == "raw":
@@ -1002,20 +1057,20 @@ def main():
suffix = f"_{language}" if len(args.language) > 1 else ""
if args.sample:
generate_example_file(f"sample{suffix}", 42, format_func, personas, static_factor=1, template_factor=1, status_request_factor=1)
generate_sft_file(f"sample{suffix}", 42, format_func, personas, static_factor=1, template_factor=1, status_request_factor=1)
if args.train:
if args.size == "small":
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=1, template_factor=10, status_request_factor=8)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=1, template_factor=10, status_request_factor=8)
elif args.size == "medium":
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=15, status_request_factor=12)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=15, status_request_factor=12)
elif args.size == "large":
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=20, status_request_factor=15)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=20, status_request_factor=15)
elif args.size == "xl":
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=7, template_factor=25, status_request_factor=18)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=7, template_factor=25, status_request_factor=18)
else:
raise Exception(f"Unrecognized dataset size: {args.size}")
if args.test:
generate_example_file(f"home_assistant_test{suffix}", 12345, format_func, personas, static_factor=0.25, template_factor=1, status_request_factor=2)
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, personas, static_factor=0.25, template_factor=1, status_request_factor=2)
if len(args.language) > 1:
if args.sample:
@@ -1025,6 +1080,9 @@ def main():
if args.test:
merge_languages("home_assistant_test", args.language)

if args.dpo:
generate_dpo_file(f"home_assistant_dpo", 42, format_example_dpo, personas, wrong_argument_factor=1, no_argument_factor=0, extra_service_call_factor=0, incorrect_persona_factor=0)

if args.merge == "alpaca":
merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca, ["input", "output", "instruction"], format_func)
elif args.merge == "wizardlm70k":