more DPO example types

This commit is contained in:
Alex O'Connell
2024-04-14 08:02:20 -04:00
parent ce75bf0d7c
commit 85cd5ec036

View File

@@ -323,6 +323,9 @@ SUPPORTED_DEVICES = {
class NoResponseAvailableException(Exception):
    """Error caught by generate_dpo_file around the templated-action generators.

    The visible handler stores ``str(ex)`` of each caught instance in its
    ``missing_responses`` set.
    NOTE(review): the raise sites are not visible here; presumably raised when
    no canned response matches the requested service/persona -- confirm.
    """
class NoServicesAvailableException(Exception):
    """Raised when no available service belongs to the targeted device's type.

    generate_dpo_extra_service_call raises it when the filtered service list is
    empty; generate_dpo_file catches it and skips that status request.
    """
def get_random_response(*, service: str, persona: str, question_template: str, short: bool) -> str:
required_vars = list(set([var for var in var_pattern.findall(question_template) if "device_name" not in var]))
@@ -596,7 +599,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
"service_calls": service_calls
}
def generate_status_request(template: dict, persona: str, max_devices: int = 32):
def generate_status_request(template: dict, persona: str, max_devices: int = 32, return_target_device: bool = False):
device_type: str = template["device_type"]
state_name: str = template["state"]
question_template: str = template["phrase"]
@@ -677,53 +680,91 @@ def generate_status_request(template: dict, persona: str, max_devices: int = 32)
for x in set(device_types + [device_type]):
available_services.extend(SUPPORTED_DEVICES[x].get_all_services(extra_exposed_attributes))
return {
result = {
"states": device_list,
"available_services": list(available_services),
"question": question.lower(),
"answers": [ answer.lower() ],
"service_calls": []
}
if return_target_device:
return result, chosen_device
else:
return result
def generate_dpo_wrong_argument(template: dict, persona: str, max_devices: int = 32):
    """Generates examples of the model passing incorrect service call arguments

    A correct example is produced by generate_templated_example. One of its
    service calls is picked at random and, in a deep copy of the example,
    either its "target_device" is swapped for another device of the same type
    or its "service" is swapped for a different available service.

    If a generated example offers nothing to swap in (no other device of the
    same type and no other service), a fresh example is generated, up to a
    fixed number of attempts.

    Returns:
        A dict with the untouched example under "accepted" and the corrupted
        copy under "rejected".

    Raises:
        ValueError: if no corruptible example was produced within the attempt
            limit.
    """
    max_attempts = 16  # bounded so an unsuitable template cannot spin forever

    for _ in range(max_attempts):
        example = generate_templated_example(template, persona, max_devices)
        service_calls = example["service_calls"]
        if not service_calls:
            continue

        call_idx = random.randrange(len(service_calls))
        call = service_calls[call_idx]
        target_device = call["target_device"]
        target_device_type = target_device.split(".")[0]

        # Entity ids (text before the first space of each state line) of every
        # other device sharing the target's type. Filtering up front replaces
        # the original rejection-sampling loop with the same uniform choice.
        other_devices = [
            state.split(" ")[0]
            for state in example["states"]
            if state.split(".")[0] == target_device_type
            and state.split(" ")[0] != target_device
        ]
        random_device = random.choice(other_devices) if other_devices else None

        # random service should probably be "related"
        other_services = [ x for x in example["available_services"] if call["service"] not in x ]
        # The last two characters of the entry are dropped.
        # NOTE(review): presumably a trailing "()" -- confirm against the
        # format returned by get_all_services.
        random_service = random.choice(other_services)[:-2] if other_services else None
        random_argument = None # based on the service, add arguments that might be there like rgb, temperature, etc

        # need examples of hallucinated names, not incorrect names
        # need examples of hallucinated service names, not incorrect service
        # should make a csv that maps "real name" to "hallucinated name"

        update_choices = []
        if random_device:
            update_choices.append({ "target_device": random_device })
        if random_service:
            update_choices.append({ "service": random_service })
        if random_argument:
            update_choices.append(random_argument)

        if not update_choices:
            # nothing to corrupt in this example; try a freshly generated one
            continue

        update_dict = random.choice(update_choices)

        # TODO: need to replace the response text with what the incorrect response would have been
        rejected_example = copy.deepcopy(example)
        rejected_example["service_calls"][call_idx].update(update_dict)

        return { "accepted": example, "rejected": rejected_example }

    raise ValueError(f"could not build a wrong-argument example after {max_attempts} attempts")
def generate_dpo_no_service_call(template: dict, persona: str, max_devices: int = 32):
    """Generates examples of the model saying 'i'll do that for you' and generating no service calls

    The accepted sample is the example produced by generate_templated_example.
    The rejected sample is a deep copy of it whose "service_calls" list is
    emptied; every other field is left as generated.

    Returns:
        A dict with the original example under "accepted" and the copy without
        service calls under "rejected".
    """
    example = generate_templated_example(template, persona, max_devices)
    rejected_example = copy.deepcopy(example)

    # Only the calls are dropped. No device/service mutation is computed here:
    # it would be thrown away by this assignment anyway.
    rejected_example["service_calls"] = []

    return { "accepted": example, "rejected": rejected_example }
def generate_dpo_extra_service_call(template: dict, persona: str, max_devices: int = 32):
    """Generates examples of the model adding random service calls to the end of status requests

    The accepted sample is a status-request example from
    generate_status_request. The rejected sample is a deep copy of it that
    carries one service call aimed at the device the question was about, using
    a randomly chosen service whose prefix (text before the first ".") matches
    that device's type.

    Returns:
        A dict with the status-request example under "accepted" and the copy
        with the spurious call under "rejected".

    Raises:
        NoServicesAvailableException: if none of the example's available
            services belongs to the target device's type.
    """
    example, target_device = generate_status_request(template, persona, max_devices, return_target_device=True)

    device_name = target_device["device_name"]
    device_type = device_name.split(".")[0]
    random_device_services = [ x for x in example["available_services"] if x.split(".")[0] == device_type ]

    if not random_device_services:
        raise NoServicesAvailableException(f"no services available for device type '{device_type}'")

    rejected_example = copy.deepcopy(example)
    # NOTE(review): the service string is used verbatim here, whereas
    # generate_dpo_wrong_argument trims the last two characters of the same
    # kind of entry -- confirm which form "service" is expected to hold.
    rejected_example["service_calls"] = [{ "service": random.choice(random_device_services), "target_device": device_name }]

    return { "accepted": example, "rejected": rejected_example }
def generate_dpo_incorrect_persona(template: dict, persona: str, max_devices: int = 32):
    """Placeholder for the "incorrect persona" DPO generator.

    Not implemented yet: the arguments are ignored and None is returned.
    """
def format_example_raw_chatml(example, persona):
@@ -880,20 +921,19 @@ def generate_dpo_file(filename: str, seed: int, format_func: Callable, personas:
missing_responses = set()
for person in personas:
# for action in tqdm(pile_of_specific_actions):
# try:
# run_factor_times(generate_static_example, generated_examples, action, person, static_factor)
# except NoResponseAvailableException as ex:
# missing_responses.add(str(ex))
for templated_action in tqdm(pile_of_templated_actions):
try:
run_factor_times(generate_dpo_wrong_argument, generated_examples, templated_action, person, wrong_argument_factor)
run_factor_times(generate_dpo_no_service_call, generated_examples, templated_action, person, no_argument_factor)
# run_factor_times(generate_dpo_incorrect_persona, generated_examples, templated_action, person, incorrect_persona_factor)
except NoResponseAvailableException as ex:
missing_responses.add(str(ex))
# for status_request in tqdm(pile_of_status_requests):
# run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor)
for status_request in tqdm(pile_of_status_requests):
try:
run_factor_times(generate_dpo_extra_service_call, generated_examples, status_request, "assistant", extra_service_call_factor)
except NoServicesAvailableException as ex:
pass # TODO: warn here?
print(f"Generated {len(generated_examples)} DPO examples. Saving...")
@@ -965,19 +1005,19 @@ def load_dataset_piles(language):
pile_of_templated_actions, pile_of_specific_actions, pile_of_responses, pile_of_status_requests, \
pile_of_system_prompts
with open(f"piles/{language}/pile_of_durations.csv") as f:
with open(f"piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
pile_of_durations = { x["duration"]: x["name"] for x in reader }
# media names are not translated
with open(f"piles/english/pile_of_media_names.txt") as f:
with open(f"piles/english/pile_of_media_names.txt", encoding="utf8") as f:
pile_of_media_names = [ x.strip() for x in f.readlines() ]
with open(f"piles/{language}/pile_of_todo_items.txt") as f:
with open(f"piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
pile_of_todo_items = [ x.strip() for x in f.readlines() ]
stacks_of_device_names = { x: [] for x in SUPPORTED_DEVICES.keys() }
with open(f"piles/{language}/pile_of_device_names.csv") as f:
with open(f"piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
pile_of_device_names = list(reader)
for device_dict in pile_of_device_names:
@@ -987,7 +1027,7 @@ def load_dataset_piles(language):
except KeyError as ex:
print(ex)
with open(f"piles/{language}/pile_of_templated_actions.csv") as f:
with open(f"piles/{language}/pile_of_templated_actions.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
pile_of_templated_actions = list(reader)
processed_pile_of_templated_actions = []
@@ -1001,18 +1041,18 @@ def load_dataset_piles(language):
pile_of_templated_actions = processed_pile_of_templated_actions
with open(f"piles/{language}/pile_of_specific_actions.csv") as f:
with open(f"piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
pile_of_specific_actions = list(reader)
pile_of_responses = pandas.read_csv(f"piles/{language}/pile_of_responses.csv")
pile_of_responses["contains_vars"] = pile_of_responses["response"].apply(get_included_vars)
with open(f"piles/{language}/pile_of_status_requests.csv") as f:
with open(f"piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
pile_of_status_requests = list(reader)
with open(f"piles/{language}/pile_of_system_prompts.csv") as f:
with open(f"piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
pile_of_system_prompts = { line["persona"]: line["prompt"] for line in reader }
@@ -1081,7 +1121,7 @@ def main():
merge_languages("home_assistant_test", args.language)
if args.dpo:
generate_dpo_file(f"home_assistant_dpo", 42, format_example_dpo, personas, wrong_argument_factor=1, no_argument_factor=0, extra_service_call_factor=0, incorrect_persona_factor=0)
generate_dpo_file(f"home_assistant_dpo", 42, format_example_dpo, personas, wrong_argument_factor=1, no_argument_factor=1, extra_service_call_factor=1, incorrect_persona_factor=1)
if args.merge == "alpaca":
merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca, ["input", "output", "instruction"], format_func)