mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-09 21:58:00 -05:00
very basic DPO data generator is working now
This commit is contained in:
@@ -18,7 +18,7 @@ size_categories:
|
|||||||
|
|
||||||
This dataset contains a list of requests and responses for a user interacting with a personal assistant that controls an instance of [Home Assistant](https://www.home-assistant.io/).
|
This dataset contains a list of requests and responses for a user interacting with a personal assistant that controls an instance of [Home Assistant](https://www.home-assistant.io/).
|
||||||
|
|
||||||
The dataset is generated from the different CSV "piles". The "piles" contain different chunks of requests that are assembled into a final context that is presented to the LLM. For example, `piles/pile_of_device_names.csv` contains only names of various devices to be used as part of context as well as inserted into `piles/pile_of_templated_actions.csv` and `piles/pile_of_status_requests.csv`. The logic for assembling the final dataset from the piles is contained in [generate_home_assistant_data.py](./generate_home_assistant_data.py).
|
The dataset is generated from the different CSV "piles". The "piles" contain different chunks of requests that are assembled into a final context that is presented to the LLM. For example, `piles/<language>/pile_of_device_names.csv` contains only names of various devices to be used as part of context as well as inserted into `piles/<language>/pile_of_templated_actions.csv` and `piles/<language>/pile_of_status_requests.csv`. The logic for assembling the final dataset from the piles is contained in [generate_home_assistant_data.py](./generate_home_assistant_data.py).
|
||||||
|
|
||||||
## Generating the dataset from piles
|
## Generating the dataset from piles
|
||||||
|
|
||||||
@@ -27,6 +27,7 @@ The dataset is generated from the different CSV "piles". The "piles" contain dif
|
|||||||
Supported dataset splits are `--test`, `--train`, & `--sample`
|
Supported dataset splits are `--test`, `--train`, & `--sample`
|
||||||
Arguments to set the train dataset size are `--small`, `--medium`, `--large`, & `--xl`.
|
Arguments to set the train dataset size are `--small`, `--medium`, `--large`, & `--xl`.
|
||||||
Supported formats are `--raw_corpus` (chatml formatted) & `--sharegpt`
|
Supported formats are `--raw_corpus` (chatml formatted) & `--sharegpt`
|
||||||
|
Languages can be enabled using `--language english german french spanish`
|
||||||
|
|
||||||
## Merging with other instruct-datasets for training
|
## Merging with other instruct-datasets for training
|
||||||
|
|
||||||
@@ -48,7 +49,6 @@ There are 2 columns in `pile_of_system_prompts.csv`:
|
|||||||
The response pile is a CSV with the following headers: `service,response,language,persona,short`
|
The response pile is a CSV with the following headers: `service,response,language,persona,short`
|
||||||
- `service`: the service name that we are responding to. Make sure you cover enough different services so that the model can learn how to respond in all situations.
|
- `service`: the service name that we are responding to. Make sure you cover enough different services so that the model can learn how to respond in all situations.
|
||||||
- `response`: the text of the response. Recommended to put this in quotes in case the response also has commas in it
|
- `response`: the text of the response. Recommended to put this in quotes in case the response also has commas in it
|
||||||
- `language`: the language code of the response (currently only `en` is supported)
|
|
||||||
- `persona`: the name of the persona the response belongs to. Use the name of your persona here
|
- `persona`: the name of the persona the response belongs to. Use the name of your persona here
|
||||||
- `short`: either 0 or 1. If it is 1 then the response is considered "short", and can be combined together with other "short" responses using "and". These are used for examples where there are multiple service calls
|
- `short`: either 0 or 1. If it is 1 then the response is considered "short", and can be combined together with other "short" responses using "and". These are used for examples where there are multiple service calls
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import pandas
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
import copy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datasets import load_dataset, concatenate_datasets
|
from datasets import load_dataset, concatenate_datasets
|
||||||
from difflib import SequenceMatcher
|
from difflib import SequenceMatcher
|
||||||
@@ -686,16 +687,18 @@ def generate_status_request(template: dict, persona: str, max_devices: int = 32)
|
|||||||
|
|
||||||
def generate_dpo_wrong_argument(template: dict, persona: str, max_devices: int = 32):
|
def generate_dpo_wrong_argument(template: dict, persona: str, max_devices: int = 32):
|
||||||
example = generate_templated_example(template, persona, max_devices)
|
example = generate_templated_example(template, persona, max_devices)
|
||||||
rejected_example = {**example}
|
rejected_example = copy.deepcopy(example)
|
||||||
|
|
||||||
call_idx = random.randint(0, len(example["service_calls"]))
|
call_idx = random.randint(0, len(example["service_calls"]) - 1)
|
||||||
call = example["service_calls"][call_idx]
|
call = example["service_calls"][call_idx]
|
||||||
|
|
||||||
random_device = random.choice(example["states"]).split(" ")[0]
|
random_device = random.choice(example["states"]).split(" ")[0]
|
||||||
if random_device == call["target_device"]:
|
# TODO: random device type should probably match
|
||||||
|
while random_device == call["target_device"]:
|
||||||
random_device = random.choice(example["states"]).split(" ")[0]
|
random_device = random.choice(example["states"]).split(" ")[0]
|
||||||
|
|
||||||
random_service = random.choice(example["available_services"] - [ call["service"] ])
|
# random service should probably be "related"
|
||||||
|
random_service = random.choice([ x for x in example["available_services"] if call["service"] not in x ])[:-2]
|
||||||
|
|
||||||
# need examples of hallucinated names, not incorrect names
|
# need examples of hallucinated names, not incorrect names
|
||||||
# need examples of hallucinated service names, not incorrect service
|
# need examples of hallucinated service names, not incorrect service
|
||||||
@@ -717,6 +720,9 @@ def generate_dpo_wrong_argument(template: dict, persona: str, max_devices: int =
|
|||||||
def generate_dpo_no_service_call():
|
def generate_dpo_no_service_call():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def generate_dpo_extra_service_call():
    # Stub: not implemented yet. Takes no arguments, does nothing, returns None.
    # NOTE(review): presumably intended to build DPO pairs whose rejected
    # response contains an extra service call -- confirm before implementing.
    pass
|
||||||
|
|
||||||
def generate_dpo_incorrect_persona():
|
def generate_dpo_incorrect_persona():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -797,6 +803,7 @@ def format_example_dpo(example, persona):
|
|||||||
|
|
||||||
# replace aliases with their actual values
|
# replace aliases with their actual values
|
||||||
assistant_block = assistant_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
assistant_block = assistant_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
||||||
|
rejected_assistant_block = rejected_assistant_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
||||||
states_block = states_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
states_block = states_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
||||||
services_block = services_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
services_block = services_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
||||||
|
|
||||||
@@ -807,7 +814,7 @@ def format_example_dpo(example, persona):
|
|||||||
"rejected": rejected_assistant_block,
|
"rejected": rejected_assistant_block,
|
||||||
}
|
}
|
||||||
|
|
||||||
def generate_example_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, static_factor: int, template_factor: int, status_request_factor: int):
|
def generate_sft_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, static_factor: int, template_factor: int, status_request_factor: int):
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
|
||||||
@@ -853,6 +860,53 @@ def generate_example_file(filename: str, seed: int, format_func: Callable, perso
|
|||||||
|
|
||||||
print("Done!")
|
print("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_dpo_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, wrong_argument_factor: int, no_argument_factor: int, extra_service_call_factor: int, incorrect_persona_factor: int):
    """Generate DPO examples and write them to ``{filename}.jsonl``.

    For every persona and every entry of the module-level
    ``pile_of_templated_actions``, ``generate_dpo_wrong_argument`` is invoked
    according to ``wrong_argument_factor`` and each result is passed through
    ``format_func(example, persona)``. The formatted records are written one
    JSON object per line.

    Args:
        filename: Output path without the ``.jsonl`` extension.
        seed: Seed applied to both ``random`` and ``numpy.random``.
        format_func: Callable taking ``(example, persona)`` and returning a
            JSON-serializable record.
        personas: Persona names to generate examples for.
        wrong_argument_factor: How many wrong-argument examples to generate
            per templated action. A fractional part is treated as the
            probability of generating one additional example.
        no_argument_factor: Accepted but currently unused.
        extra_service_call_factor: Accepted but currently unused.
        incorrect_persona_factor: Accepted but currently unused.
    """
    random.seed(seed)
    np.random.seed(seed)

    print("Generating...")

    def run_factor_times(func, examples, data, persona, factor):
        # Split the factor into a guaranteed repeat count and a leftover
        # probability, so that non-integer factors >= 1 (e.g. 1.5) work
        # instead of crashing in range().
        whole = int(factor)
        remainder = factor - whole

        for _ in range(whole):
            examples.append(format_func(func(data, persona), persona))

        # A random number is drawn for every factor < 1 (including 0), exactly
        # as before, so seeded output is unchanged for previously valid inputs.
        # Integer factors >= 1 never draw.
        if factor < 1 or remainder > 0:
            if random.random() < remainder:
                examples.append(format_func(func(data, persona), persona))

    generated_examples = []

    # Messages of NoResponseAvailableException raised during generation;
    # collected as a set so each distinct message is reported once.
    missing_responses = set()

    for persona in personas:
        # Static-action DPO generation is disabled for now.
        # NOTE(review): `static_factor` is not a parameter of this function.
        # for action in tqdm(pile_of_specific_actions):
        #     try:
        #         run_factor_times(generate_static_example, generated_examples, action, person, static_factor)
        #     except NoResponseAvailableException as ex:
        #         missing_responses.add(str(ex))

        for templated_action in tqdm(pile_of_templated_actions):
            try:
                run_factor_times(generate_dpo_wrong_argument, generated_examples, templated_action, persona, wrong_argument_factor)
            except NoResponseAvailableException as ex:
                missing_responses.add(str(ex))

        # Status-request DPO generation is disabled for now.
        # NOTE(review): `status_request_factor` is not a parameter of this function.
        # for status_request in tqdm(pile_of_status_requests):
        #     run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor)

    print(f"Generated {len(generated_examples)} DPO examples. Saving...")

    for missing in sorted(missing_responses):
        print(missing)

    # Explicit encoding keeps the output independent of the platform default.
    with open(f"{filename}.jsonl", "w", encoding="utf-8") as f:
        for item in generated_examples:
            json_record = json.dumps(item)
            f.write(json_record + '\n')

    print("Done!")
|
||||||
|
|
||||||
def format_alpaca(example, format_func: Callable):
|
def format_alpaca(example, format_func: Callable):
|
||||||
question = example["instruction"]
|
question = example["instruction"]
|
||||||
if "input" in example and example["input"]:
|
if "input" in example and example["input"]:
|
||||||
@@ -973,6 +1027,7 @@ def main():
|
|||||||
parser.add_argument("--sample", action="store_true", help="Set this flag to enable generation of the train dataset.")
|
parser.add_argument("--sample", action="store_true", help="Set this flag to enable generation of the train dataset.")
|
||||||
parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset..")
|
parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset..")
|
||||||
parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
|
parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
|
||||||
|
parser.add_argument("--dpo", action="store_true", help="Set this flag to enable generation of the DPO dataset.")
|
||||||
parser.add_argument("--merge", help="Set this flag to merge the generated datasets with the specified dataset.")
|
parser.add_argument("--merge", help="Set this flag to merge the generated datasets with the specified dataset.")
|
||||||
parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate")
|
parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate")
|
||||||
|
|
||||||
@@ -988,7 +1043,7 @@ def main():
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not args.sample and not args.train and not args.test and not args.merge:
|
if not args.sample and not args.train and not args.test and not args.merge and not args.dpo:
|
||||||
parser.print_usage()
|
parser.print_usage()
|
||||||
|
|
||||||
if not args.format or args.format == "raw":
|
if not args.format or args.format == "raw":
|
||||||
@@ -1002,20 +1057,20 @@ def main():
|
|||||||
suffix = f"_{language}" if len(args.language) > 1 else ""
|
suffix = f"_{language}" if len(args.language) > 1 else ""
|
||||||
|
|
||||||
if args.sample:
|
if args.sample:
|
||||||
generate_example_file(f"sample{suffix}", 42, format_func, personas, static_factor=1, template_factor=1, status_request_factor=1)
|
generate_sft_file(f"sample{suffix}", 42, format_func, personas, static_factor=1, template_factor=1, status_request_factor=1)
|
||||||
if args.train:
|
if args.train:
|
||||||
if args.size == "small":
|
if args.size == "small":
|
||||||
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=1, template_factor=10, status_request_factor=8)
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=1, template_factor=10, status_request_factor=8)
|
||||||
elif args.size == "medium":
|
elif args.size == "medium":
|
||||||
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=15, status_request_factor=12)
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=15, status_request_factor=12)
|
||||||
elif args.size == "large":
|
elif args.size == "large":
|
||||||
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=20, status_request_factor=15)
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=20, status_request_factor=15)
|
||||||
elif args.size == "xl":
|
elif args.size == "xl":
|
||||||
generate_example_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=7, template_factor=25, status_request_factor=18)
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=7, template_factor=25, status_request_factor=18)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unrecognized dataset size: {args.size}")
|
raise Exception(f"Unrecognized dataset size: {args.size}")
|
||||||
if args.test:
|
if args.test:
|
||||||
generate_example_file(f"home_assistant_test{suffix}", 12345, format_func, personas, static_factor=0.25, template_factor=1, status_request_factor=2)
|
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, personas, static_factor=0.25, template_factor=1, status_request_factor=2)
|
||||||
|
|
||||||
if len(args.language) > 1:
|
if len(args.language) > 1:
|
||||||
if args.sample:
|
if args.sample:
|
||||||
@@ -1025,6 +1080,9 @@ def main():
|
|||||||
if args.test:
|
if args.test:
|
||||||
merge_languages("home_assistant_test", args.language)
|
merge_languages("home_assistant_test", args.language)
|
||||||
|
|
||||||
|
if args.dpo:
|
||||||
|
generate_dpo_file(f"home_assistant_dpo", 42, format_example_dpo, personas, wrong_argument_factor=1, no_argument_factor=0, extra_service_call_factor=0, incorrect_persona_factor=0)
|
||||||
|
|
||||||
if args.merge == "alpaca":
|
if args.merge == "alpaca":
|
||||||
merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca, ["input", "output", "instruction"], format_func)
|
merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca, ["input", "output", "instruction"], format_func)
|
||||||
elif args.merge == "wizardlm70k":
|
elif args.merge == "wizardlm70k":
|
||||||
|
|||||||
Reference in New Issue
Block a user