start working on dpo for the datasets

This commit is contained in:
Alex O'Connell
2024-03-19 21:31:34 -04:00
parent b9d394f860
commit f1659893d7
2 changed files with 80 additions and 7 deletions

View File

@@ -652,9 +652,6 @@ def generate_status_request(template: dict, language: str, persona: str, max_dev
# build a random list of devices
device_list, device_types, extra_exposed_attributes = random_device_list(max_devices=max_devices, avoid_device_names=[ chosen_device["device_name"] ])
# insert our target device somewhere random in the list
index = random.randint(0, len(device_list))
# generate the question
question = question_template.replace("<device_name>", chosen_device["description"])
answer = answer_template.replace("<device_name>", chosen_device["description"])
@@ -711,7 +708,13 @@ def generate_status_request(template: dict, language: str, persona: str, max_dev
answer = answer.replace("<remaining>", remaining)
state_name = state_name.replace("<remaining>", remaining)
device_list.insert(index, f"{chosen_device['device_name']} = {state_name}")
# insert our target device somewhere random in the list
index = random.randint(0, len(device_list))
device_list.insert(index, format_device_line(
device_name=chosen_device["device_name"],
friendly_name=chosen_device["description"],
state=state_name
))
# gather a list of all available services
available_services = []
@@ -726,6 +729,42 @@ def generate_status_request(template: dict, language: str, persona: str, max_dev
"service_calls": []
}
def generate_dpo_wrong_argument(template: dict, language: str, persona: str, max_devices: int = 32) -> dict:
    """Build a DPO pair whose rejected side contains one corrupted service call.

    A templated example is generated and returned unchanged as the accepted
    side. The rejected side is a copy in which a single, randomly chosen
    service call has either its "service" or its "target_device" replaced by
    a different value taken from the same example.

    Args:
        template: forwarded unchanged to generate_templated_example.
        language: forwarded unchanged to generate_templated_example.
        persona: forwarded unchanged to generate_templated_example.
        max_devices: forwarded unchanged to generate_templated_example.

    Returns:
        A dict with the keys "accepted" (the untouched example) and
        "rejected" (the copy holding the corrupted service call).

    Raises:
        ValueError: if the generated example has no service calls, or if it
            offers neither an alternative service nor an alternative device
            to swap in.
    """
    example = generate_templated_example(template, language, persona, max_devices)

    service_calls = example["service_calls"]
    if not service_calls:
        raise ValueError("cannot build a wrong-argument DPO pair from an example without service calls")

    # Copy every call dict: with a plain {**example} both sides would share
    # the same dicts and corrupting the rejected call would silently corrupt
    # the accepted one as well.
    rejected_example = {**example, "service_calls": [dict(call) for call in service_calls]}

    # randrange excludes the upper bound, so the index is always valid
    call_idx = random.randrange(len(service_calls))
    call = service_calls[call_idx]

    # the first space-separated token of a state line is treated as the device name
    other_devices = [
        device
        for device in (state.split(" ")[0] for state in example["states"])
        if device != call["target_device"]
    ]
    # sorted so the pick stays reproducible for a fixed random seed even if
    # available_services is an unordered collection
    other_services = sorted(
        service for service in example["available_services"] if service != call["service"]
    )

    # need examples of hallucinated names, not incorrect names
    # need examples of hallucinated service names, not incorrect service
    # should make a csv that maps "real name" to "hallucinated name"
    possible_updates = []
    if other_services:
        possible_updates.append({"service": random.choice(other_services)})
    if other_devices:
        possible_updates.append({"target_device": random.choice(other_devices)})
    if not possible_updates:
        raise ValueError("example offers no alternative service or device to build a rejected call from")

    update_dict = random.choice(possible_updates)

    # TODO: need to replace the response text with what the incorrect response would have been
    rejected_example["service_calls"][call_idx].update(update_dict)

    return {"accepted": example, "rejected": rejected_example}
def generate_dpo_no_service_call():
    """Placeholder: not implemented yet, currently returns None.

    NOTE(review): judging by the name, this is presumably meant to produce a
    DPO pair whose rejected response omits the service call - confirm.
    """
def generate_dpo_incorrect_persona():
    """Placeholder: not implemented yet, currently returns None.

    NOTE(review): judging by the name, this is presumably meant to produce a
    DPO pair whose rejected response is written in the wrong persona - confirm.
    """
def format_example_raw_chatml(example, persona):
"""Don't use this one anymore"""
sys_prompt = pile_of_system_prompts[persona]
@@ -780,6 +819,38 @@ def format_example_sharegpt(example, persona):
return { "conversations": conversation }
def format_example_dpo(example, persona):
    """Format an accepted/rejected example pair as one DPO training record.

    Args:
        example: dict with the keys "accepted" and "rejected". Each value is
            an example dict providing "answers" and "service_calls"; the
            accepted one additionally provides "available_services",
            "states" and "question", which are used to build the prompt.
        persona: key into pile_of_system_prompts selecting the system prompt.

    Returns:
        A dict with the keys "system", "question", "chosen" and "rejected".
    """
    def render_response(source):
        # answer text, followed by a homeassistant block with one JSON
        # service call per line when the example has any service calls
        response = " ".join(source["answers"])
        if source["service_calls"]:
            json_calls = [json.dumps(call) for call in source["service_calls"]]
            response += "\n```homeassistant\n" + "\n".join(json_calls) + "\n```"
        return response

    def resolve_aliases(text):
        # replace aliases with their actual values
        return text.replace("blinds.", "cover.").replace("garage_door.", "cover.")

    rejected_example = example["rejected"]
    accepted_example = example["accepted"]

    sys_prompt = pile_of_system_prompts[persona]
    services_block = "Services: " + ", ".join(sorted(accepted_example["available_services"]))
    states_block = "Devices:\n" + "\n".join(accepted_example["states"])

    # The rejected response gets the same alias replacement as the chosen one
    # so the two never differ merely by an unresolved alias.
    return {
        "system": "\n".join([sys_prompt, resolve_aliases(services_block), resolve_aliases(states_block)]),
        "question": accepted_example["question"],
        "chosen": resolve_aliases(render_response(accepted_example)),
        "rejected": resolve_aliases(render_response(rejected_example)),
    }
def generate_example_file(filename: str, seed: int, format_func: Callable, languages: list[str], personas: list[str], *, static_factor: int, template_factor: int, status_request_factor: int):
random.seed(seed)
@@ -871,12 +942,11 @@ def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset
combined_dataset_train.to_json(f"home_assistant_{output_name}_merged_train.jsonl")
combined_dataset_test.to_json(f"home_assistant_{output_name}_merged_test.jsonl")
# TODO: add examples for ambiguous requests where the assistant should ask a clarifying question
# TODO: support rejection when asking to do a service that isn't exposed
# TODO: make more randomized names for devices (random words or people's names)
# TODO: answer questions about more than one thing in the state list at once
# TODO: add examples for rooms/groups of devices. i.e. "turn off all the lights in the kitchen"
# TODO: add personas for responses. different system prompts should invoke different response tones (pirate, robot, and mean)
# TODO: add time, weather, and calendar/reminders (next 3 events?)
def main():
parser = argparse.ArgumentParser(description="Generate the full dataset from the CSV piles")