mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 21:28:05 -05:00
start working on dpo for the datasets
This commit is contained in:
@@ -652,9 +652,6 @@ def generate_status_request(template: dict, language: str, persona: str, max_dev
|
||||
# build a random list of devices
|
||||
device_list, device_types, extra_exposed_attributes = random_device_list(max_devices=max_devices, avoid_device_names=[ chosen_device["device_name"] ])
|
||||
|
||||
# insert our target device somewhere random in the list
|
||||
index = random.randint(0, len(device_list))
|
||||
|
||||
# generate the question
|
||||
question = question_template.replace("<device_name>", chosen_device["description"])
|
||||
answer = answer_template.replace("<device_name>", chosen_device["description"])
|
||||
@@ -711,7 +708,13 @@ def generate_status_request(template: dict, language: str, persona: str, max_dev
|
||||
answer = answer.replace("<remaining>", remaining)
|
||||
state_name = state_name.replace("<remaining>", remaining)
|
||||
|
||||
device_list.insert(index, f"{chosen_device['device_name']} = {state_name}")
|
||||
# insert our target device somewhere random in the list
|
||||
index = random.randint(0, len(device_list))
|
||||
device_list.insert(index, format_device_line(
|
||||
device_name=chosen_device["device_name"],
|
||||
friendly_name=chosen_device["description"],
|
||||
state=state_name
|
||||
))
|
||||
|
||||
# gather a list of all available services
|
||||
available_services = []
|
||||
@@ -726,6 +729,42 @@ def generate_status_request(template: dict, language: str, persona: str, max_dev
|
||||
"service_calls": []
|
||||
}
|
||||
|
||||
def generate_dpo_wrong_argument(template: dict, language: str, persona: str, max_devices: int = 32):
|
||||
example = generate_templated_example(template, language, persona, max_devices)
|
||||
rejected_example = {**example}
|
||||
|
||||
call_idx = random.randint(0, len(example["service_calls"]))
|
||||
call = example["service_calls"][call_idx]
|
||||
|
||||
random_device = random.choice(example["states"]).split(" ")[0]
|
||||
if random_device == call["target_device"]:
|
||||
random_device = random.choice(example["states"]).split(" ")[0]
|
||||
|
||||
random_service = random.choice(example["available_services"] - [ call["service"] ])
|
||||
|
||||
# need examples of hallucinated names, not incorrect names
|
||||
# need examples of hallucinated service names, not incorrect service
|
||||
# should make a csv that maps "real name" to "hallucinated name"
|
||||
|
||||
update_dict = random.choice([
|
||||
{ "service": random_service },
|
||||
{ "target_device": random_device },
|
||||
])
|
||||
|
||||
# need to replace the response text with what the incorrect response would have been
|
||||
if len(rejected_example["service_calls"]) == 1:
|
||||
pass
|
||||
|
||||
rejected_example["service_calls"][call_idx].update(update_dict)
|
||||
|
||||
return { "accepted": example, "rejected": rejected_example }
|
||||
|
||||
def generate_dpo_no_service_call():
|
||||
pass
|
||||
|
||||
def generate_dpo_incorrect_persona():
|
||||
pass
|
||||
|
||||
def format_example_raw_chatml(example, persona):
|
||||
"""Don't use this one anymore"""
|
||||
sys_prompt = pile_of_system_prompts[persona]
|
||||
@@ -780,6 +819,38 @@ def format_example_sharegpt(example, persona):
|
||||
|
||||
return { "conversations": conversation }
|
||||
|
||||
def format_example_dpo(example, persona):
|
||||
rejected_example = example["rejected"]
|
||||
example = example["accepted"]
|
||||
|
||||
sys_prompt = pile_of_system_prompts[persona]
|
||||
services_block = "Services: " + ", ".join(sorted(example["available_services"]))
|
||||
states_block = "Devices:\n" + "\n".join(example["states"])
|
||||
question = example["question"]
|
||||
|
||||
assistant_block = " ".join(example["answers"])
|
||||
if len(example["service_calls"]) > 0:
|
||||
json_calls = [ json.dumps(x) for x in example["service_calls"] ]
|
||||
code_block = "\n```homeassistant\n" + "\n".join(json_calls) + "\n```"
|
||||
assistant_block = assistant_block + code_block
|
||||
|
||||
rejected_assistant_block = " ".join(rejected_example["answers"])
|
||||
if len(rejected_example["service_calls"]) > 0:
|
||||
json_calls = [ json.dumps(x) for x in rejected_example["service_calls"] ]
|
||||
code_block = "\n```homeassistant\n" + "\n".join(json_calls) + "\n```"
|
||||
rejected_assistant_block = rejected_assistant_block + code_block
|
||||
|
||||
# replace aliases with their actual values
|
||||
assistant_block = assistant_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
||||
states_block = states_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
||||
services_block = services_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
||||
|
||||
return {
|
||||
"system": "\n".join([ sys_prompt, services_block, states_block ]),
|
||||
"question": question,
|
||||
"chosen": assistant_block,
|
||||
"rejected": rejected_assistant_block,
|
||||
}
|
||||
|
||||
def generate_example_file(filename: str, seed: int, format_func: Callable, languages: list[str], personas: list[str], *, static_factor: int, template_factor: int, status_request_factor: int):
|
||||
random.seed(seed)
|
||||
@@ -871,12 +942,11 @@ def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset
|
||||
combined_dataset_train.to_json(f"home_assistant_{output_name}_merged_train.jsonl")
|
||||
combined_dataset_test.to_json(f"home_assistant_{output_name}_merged_test.jsonl")
|
||||
|
||||
|
||||
# TODO: add examples for ambiguous requests. asking a clarifying question
|
||||
# TODO: support rejection when asking to do a service that isn't exposed
|
||||
# TODO: make more randomized names for devices (random words or people's names)
|
||||
# TODO: answer questions about more than one thing in the state list at once
|
||||
# TODO: add examples for rooms/groups of devices. i.e. "turn off all the lights in the kitchen"
|
||||
# TODO: add personas for responses. different system prompts should invoke different response tones (pirate, robot, and mean)
|
||||
# TODO: add time, weather, and calendar/reminders (next 3 events?)
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Generate the full dataset from the CSV piles")
|
||||
|
||||
Reference in New Issue
Block a user