From f1659893d7ff10697bcc64e139f2fb2baf7c37df Mon Sep 17 00:00:00 2001 From: Alex O'Connell Date: Tue, 19 Mar 2024 21:31:34 -0400 Subject: [PATCH] start working on dpo for the datasets --- TODO.md | 5 +- data/generate_home_assistant_data.py | 82 ++++++++++++++++++++++++++-- 2 files changed, 80 insertions(+), 7 deletions(-) diff --git a/TODO.md b/TODO.md index 5c91a37..5732c10 100644 --- a/TODO.md +++ b/TODO.md @@ -1,5 +1,6 @@ # TODO -- [ ] setup github actions to build wheels that are optimized for RPIs +- [ ] setup github actions to build wheels that are optimized for RPIs?? +- [ ] setup github actions to publish docker images for text-gen-webui addon - [ ] detection/mitigation of too many entities being exposed & blowing out the context length - [ ] areas/room support - [ ] figure out DPO for refusals + fixing incorrect entity id @@ -7,6 +8,8 @@ - add in context learning variables to sys prompt template - add new options to setup process for setting prompt style + picking fine-tuned/ICL - [ ] prime kv cache with current "state" so that requests are faster +- [ ] support fine-tuning with RoPE for longer contexts +- [ ] support config via yaml instead of configflow - [x] ChatML format (actually need to add special tokens) - [x] Vicuna dataset merge (yahma/alpaca-cleaned) - [x] Phi-2 fine tuning diff --git a/data/generate_home_assistant_data.py b/data/generate_home_assistant_data.py index a83bc04..02225f0 100644 --- a/data/generate_home_assistant_data.py +++ b/data/generate_home_assistant_data.py @@ -652,9 +652,6 @@ def generate_status_request(template: dict, language: str, persona: str, max_dev # build a random list of devices device_list, device_types, extra_exposed_attributes = random_device_list(max_devices=max_devices, avoid_device_names=[ chosen_device["device_name"] ]) - # insert our target device somewhere random in the list - index = random.randint(0, len(device_list)) - # generate the question question = question_template.replace("", chosen_device["description"]) answer = answer_template.replace("", chosen_device["description"]) @@ -711,7 +708,13 @@ def generate_status_request(template: dict, language: str, persona: str, max_dev answer = answer.replace("", remaining) state_name = state_name.replace("", remaining) - device_list.insert(index, f"{chosen_device['device_name']} = {state_name}") + # insert our target device somewhere random in the list + index = random.randint(0, len(device_list)) + device_list.insert(index, format_device_line( + device_name=chosen_device["device_name"], + friendly_name=chosen_device["description"], + state=state_name + )) # gather a list of all available services available_services = [] @@ -726,6 +729,42 @@ def generate_status_request(template: dict, language: str, persona: str, max_dev "service_calls": [] } +def generate_dpo_wrong_argument(template: dict, language: str, persona: str, max_devices: int = 32): + example = generate_templated_example(template, language, persona, max_devices) + rejected_example = {**example} + + call_idx = random.randint(0, len(example["service_calls"])) + call = example["service_calls"][call_idx] + + random_device = random.choice(example["states"]).split(" ")[0] + if random_device == call["target_device"]: + random_device = random.choice(example["states"]).split(" ")[0] + + random_service = random.choice(example["available_services"] - [ call["service"] ]) + + # need examples of hallucinated names, not incorrect names + # need examples of hallucinated service names, not incorrect service + # should make a csv that maps "real name" to "hallucinated name" + + update_dict = random.choice([ + { "service": random_service }, + { "target_device": random_device }, + ]) + + # need to replace the response text with what the incorrect response would have been + if len(rejected_example["service_calls"]) == 1: + pass + + rejected_example["service_calls"][call_idx].update(update_dict) + + return { "accepted": example, "rejected": rejected_example } + +def generate_dpo_no_service_call(): + pass + +def generate_dpo_incorrect_persona(): + pass + def format_example_raw_chatml(example, persona): """Don't use this one anymore""" sys_prompt = pile_of_system_prompts[persona] @@ -780,6 +819,38 @@ def format_example_sharegpt(example, persona): return { "conversations": conversation } +def format_example_dpo(example, persona): + rejected_example = example["rejected"] + example = example["accepted"] + + sys_prompt = pile_of_system_prompts[persona] + services_block = "Services: " + ", ".join(sorted(example["available_services"])) + states_block = "Devices:\n" + "\n".join(example["states"]) + question = example["question"] + + assistant_block = " ".join(example["answers"]) + if len(example["service_calls"]) > 0: + json_calls = [ json.dumps(x) for x in example["service_calls"] ] + code_block = "\n```homeassistant\n" + "\n".join(json_calls) + "\n```" + assistant_block = assistant_block + code_block + + rejected_assistant_block = " ".join(rejected_example["answers"]) + if len(rejected_example["service_calls"]) > 0: + json_calls = [ json.dumps(x) for x in rejected_example["service_calls"] ] + code_block = "\n```homeassistant\n" + "\n".join(json_calls) + "\n```" + rejected_assistant_block = rejected_assistant_block + code_block + + # replace aliases with their actual values + assistant_block = assistant_block.replace("blinds.", "cover.").replace("garage_door.", "cover.") + states_block = states_block.replace("blinds.", "cover.").replace("garage_door.", "cover.") + services_block = services_block.replace("blinds.", "cover.").replace("garage_door.", "cover.") + + return { + "system": "\n".join([ sys_prompt, services_block, states_block ]), + "question": question, + "chosen": assistant_block, + "rejected": rejected_assistant_block, + } def generate_example_file(filename: str, seed: int, format_func: Callable, languages: list[str], personas: list[str], *, static_factor: int, template_factor: int, status_request_factor: int): random.seed(seed) @@ -871,12 +942,11 @@ def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset combined_dataset_train.to_json(f"home_assistant_{output_name}_merged_train.jsonl") combined_dataset_test.to_json(f"home_assistant_{output_name}_merged_test.jsonl") - # TODO: add examples for ambiguous requests. asking a clarifying question +# TODO: support rejection when asking to do a service that isn't exposed # TODO: make more randomized names for devices (random words or people's names) # TODO: answer questions about more than one thing in the state list at once # TODO: add examples for rooms/groups of devices. i.e. "turn off all the lights in the kitchen" -# TODO: add personas for responses. different system prompts should invoke different response tones (pirate, robot, and mean) # TODO: add time, weather, and calendar/reminders (next 3 events?) def main(): parser = argparse.ArgumentParser(description="Generate the full dataset from the CSV piles")