mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 21:28:05 -05:00
more DPO example types
This commit is contained in:
@@ -323,6 +323,9 @@ SUPPORTED_DEVICES = {
|
||||
class NoResponseAvailableException(Exception):
    """Error type that generate_dpo_file catches around the templated-action
    generators; the stringified exception is collected in its
    missing_responses set instead of aborting the run."""
|
||||
|
||||
class NoServicesAvailableException(Exception):
    """Raised by generate_dpo_extra_service_call when the example exposes no
    service for the target device's type; generate_dpo_file catches it and
    skips that status request."""
|
||||
|
||||
def get_random_response(*, service: str, persona: str, question_template: str, short: bool) -> str:
|
||||
|
||||
required_vars = list(set([var for var in var_pattern.findall(question_template) if "device_name" not in var]))
|
||||
@@ -596,7 +599,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
|
||||
"service_calls": service_calls
|
||||
}
|
||||
|
||||
def generate_status_request(template: dict, persona: str, max_devices: int = 32):
|
||||
def generate_status_request(template: dict, persona: str, max_devices: int = 32, return_target_device: bool = False):
|
||||
device_type: str = template["device_type"]
|
||||
state_name: str = template["state"]
|
||||
question_template: str = template["phrase"]
|
||||
@@ -677,53 +680,91 @@ def generate_status_request(template: dict, persona: str, max_devices: int = 32)
|
||||
for x in set(device_types + [device_type]):
|
||||
available_services.extend(SUPPORTED_DEVICES[x].get_all_services(extra_exposed_attributes))
|
||||
|
||||
return {
|
||||
result = {
|
||||
"states": device_list,
|
||||
"available_services": list(available_services),
|
||||
"question": question.lower(),
|
||||
"answers": [ answer.lower() ],
|
||||
"service_calls": []
|
||||
}
|
||||
if return_target_device:
|
||||
return result, chosen_device
|
||||
else:
|
||||
return result
|
||||
|
||||
def generate_dpo_wrong_argument(template: dict, persona: str, max_devices: int = 32):
    """Generates examples of the model passing incorrect service call arguments

    A templated example is generated and deep-copied. In the copy, one randomly
    picked service call gets either its "target_device" or its "service"
    replaced by a different value taken from the same example.

    Args:
        template: passed through to generate_templated_example.
        persona: passed through to generate_templated_example.
        max_devices: passed through to generate_templated_example.

    Returns:
        { "accepted": <generated example>, "rejected": <mutated deep copy> }

    Raises:
        NoServicesAvailableException: when none of the generated examples
            offered an alternative device or service to swap in.
    """
    # regenerate a bounded number of times instead of crashing on the first
    # example that has nothing to swap (and instead of looping forever)
    max_attempts = 10

    for _ in range(max_attempts):
        example = generate_templated_example(template, persona, max_devices)
        rejected_example = copy.deepcopy(example)

        call_idx = random.randint(0, len(example["service_calls"]) - 1)
        call = example["service_calls"][call_idx]
        target_device = call["target_device"]

        target_device_type = target_device.split(".")[0]

        # NOTE(review): each state entry presumably starts with the entity id
        # followed by a space -- confirm against generate_templated_example
        same_type_devices = [ x.split(" ")[0] for x in example["states"] if x.split(".")[0] == target_device_type ]

        random_device = None
        if same_type_devices:
            # drawn even when the result is discarded below, so seeded runs
            # consume the random stream in the same order as before
            random_device = random.choice(same_type_devices)

            has_alternative = any(x != target_device for x in same_type_devices)
            if len(same_type_devices) > 1 and has_alternative:
                # rejection-sample until we land on a device other than the real target
                while random_device == target_device:
                    random_device = random.choice(same_type_devices)
            else:
                random_device = None

        # random service should probably be "related"
        # NOTE(review): the [:-2] presumably strips a trailing "()" -- confirm
        other_services = [ x for x in example["available_services"] if call["service"] not in x ]
        random_service = random.choice(other_services)[:-2] if other_services else None

        # TODO: based on the service, add arguments that might be there like rgb, temperature, etc
        # need examples of hallucinated names, not incorrect names
        # need examples of hallucinated service names, not incorrect service
        # should make a csv that maps "real name" to "hallucinated name"

        update_choices = []
        if random_device:
            update_choices.append({ "target_device": random_device })
        if random_service:
            update_choices.append({ "service": random_service })

        if not update_choices:
            # nothing in this example can be swapped; try a fresh one
            continue

        update_dict = random.choice(update_choices)

        # TODO: need to replace the response text with what the incorrect response would have been
        rejected_example["service_calls"][call_idx].update(update_dict)

        return { "accepted": example, "rejected": rejected_example }

    raise NoServicesAvailableException(
        f"no alternative device or service to swap in after {max_attempts} attempts"
    )
|
||||
|
||||
def generate_dpo_no_service_call(template: dict, persona: str, max_devices: int = 32):
    """Generates examples of the model saying 'i'll do that for you' and generating no service calls

    Args:
        template: passed through to generate_templated_example.
        persona: passed through to generate_templated_example.
        max_devices: passed through to generate_templated_example.

    Returns:
        { "accepted": <generated example>, "rejected": <deep copy with an empty "service_calls" list> }
    """
    example = generate_templated_example(template, persona, max_devices)
    rejected_example = copy.deepcopy(example)

    # the rejected sample is identical to the accepted one except that it
    # performs no action at all
    rejected_example["service_calls"] = []

    return { "accepted": example, "rejected": rejected_example }
|
||||
def generate_dpo_extra_service_call(template: dict, persona: str, max_devices: int = 32):
    """Generates examples of the model adding random service calls to the end of status requests

    Args:
        template: passed through to generate_status_request.
        persona: passed through to generate_status_request.
        max_devices: passed through to generate_status_request.

    Returns:
        { "accepted": <status request example>, "rejected": <deep copy with one
        spurious service call against the queried device> }

    Raises:
        NoServicesAvailableException: if the example lists no service whose
            domain matches the queried device's type.
    """
    example, target_device = generate_status_request(template, persona, max_devices, return_target_device=True)
    rejected_example = copy.deepcopy(example)

    device_name = target_device["device_name"]
    device_type = device_name.split(".")[0]

    # only services of the same domain as the queried device are candidates
    random_device_services = [ x for x in example["available_services"] if x.split(".")[0] == device_type ]

    if not random_device_services:
        raise NoServicesAvailableException(f"no services available for device type '{device_type}'")

    # NOTE(review): the service string is used exactly as listed in
    # "available_services" (generate_dpo_wrong_argument trims two trailing
    # characters from it) -- confirm which form the formatter expects
    rejected_example["service_calls"] = [{ "service": random.choice(random_device_services), "target_device": device_name }]

    return { "accepted": example, "rejected": rejected_example }
|
||||
|
||||
def generate_dpo_incorrect_persona(template: dict, persona: str, max_devices: int = 32):
    """Placeholder: not implemented yet, ignores its arguments and returns None."""
    return None
|
||||
|
||||
def format_example_raw_chatml(example, persona):
|
||||
@@ -880,20 +921,19 @@ def generate_dpo_file(filename: str, seed: int, format_func: Callable, personas:
|
||||
missing_responses = set()
|
||||
|
||||
for person in personas:
|
||||
# for action in tqdm(pile_of_specific_actions):
|
||||
# try:
|
||||
# run_factor_times(generate_static_example, generated_examples, action, person, static_factor)
|
||||
# except NoResponseAvailableException as ex:
|
||||
# missing_responses.add(str(ex))
|
||||
|
||||
for templated_action in tqdm(pile_of_templated_actions):
|
||||
try:
|
||||
run_factor_times(generate_dpo_wrong_argument, generated_examples, templated_action, person, wrong_argument_factor)
|
||||
run_factor_times(generate_dpo_no_service_call, generated_examples, templated_action, person, no_argument_factor)
|
||||
# run_factor_times(generate_dpo_incorrect_persona, generated_examples, templated_action, person, incorrect_persona_factor)
|
||||
except NoResponseAvailableException as ex:
|
||||
missing_responses.add(str(ex))
|
||||
|
||||
# for status_request in tqdm(pile_of_status_requests):
|
||||
# run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor)
|
||||
for status_request in tqdm(pile_of_status_requests):
|
||||
try:
|
||||
run_factor_times(generate_dpo_extra_service_call, generated_examples, status_request, "assistant", extra_service_call_factor)
|
||||
except NoServicesAvailableException as ex:
|
||||
pass # TODO: warn here?
|
||||
|
||||
print(f"Generated {len(generated_examples)} DPO examples. Saving...")
|
||||
|
||||
@@ -965,19 +1005,19 @@ def load_dataset_piles(language):
|
||||
pile_of_templated_actions, pile_of_specific_actions, pile_of_responses, pile_of_status_requests, \
|
||||
pile_of_system_prompts
|
||||
|
||||
with open(f"piles/{language}/pile_of_durations.csv") as f:
|
||||
with open(f"piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
pile_of_durations = { x["duration"]: x["name"] for x in reader }
|
||||
|
||||
# media names are not translated
|
||||
with open(f"piles/english/pile_of_media_names.txt") as f:
|
||||
with open(f"piles/english/pile_of_media_names.txt", encoding="utf8") as f:
|
||||
pile_of_media_names = [ x.strip() for x in f.readlines() ]
|
||||
|
||||
with open(f"piles/{language}/pile_of_todo_items.txt") as f:
|
||||
with open(f"piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
|
||||
pile_of_todo_items = [ x.strip() for x in f.readlines() ]
|
||||
|
||||
stacks_of_device_names = { x: [] for x in SUPPORTED_DEVICES.keys() }
|
||||
with open(f"piles/{language}/pile_of_device_names.csv") as f:
|
||||
with open(f"piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
pile_of_device_names = list(reader)
|
||||
for device_dict in pile_of_device_names:
|
||||
@@ -987,7 +1027,7 @@ def load_dataset_piles(language):
|
||||
except KeyError as ex:
|
||||
print(ex)
|
||||
|
||||
with open(f"piles/{language}/pile_of_templated_actions.csv") as f:
|
||||
with open(f"piles/{language}/pile_of_templated_actions.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
pile_of_templated_actions = list(reader)
|
||||
processed_pile_of_templated_actions = []
|
||||
@@ -1001,18 +1041,18 @@ def load_dataset_piles(language):
|
||||
|
||||
pile_of_templated_actions = processed_pile_of_templated_actions
|
||||
|
||||
with open(f"piles/{language}/pile_of_specific_actions.csv") as f:
|
||||
with open(f"piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
pile_of_specific_actions = list(reader)
|
||||
|
||||
pile_of_responses = pandas.read_csv(f"piles/{language}/pile_of_responses.csv")
|
||||
pile_of_responses["contains_vars"] = pile_of_responses["response"].apply(get_included_vars)
|
||||
|
||||
with open(f"piles/{language}/pile_of_status_requests.csv") as f:
|
||||
with open(f"piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
pile_of_status_requests = list(reader)
|
||||
|
||||
with open(f"piles/{language}/pile_of_system_prompts.csv") as f:
|
||||
with open(f"piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
pile_of_system_prompts = { line["persona"]: line["prompt"] for line in reader }
|
||||
|
||||
@@ -1081,7 +1121,7 @@ def main():
|
||||
merge_languages("home_assistant_test", args.language)
|
||||
|
||||
if args.dpo:
|
||||
generate_dpo_file(f"home_assistant_dpo", 42, format_example_dpo, personas, wrong_argument_factor=1, no_argument_factor=0, extra_service_call_factor=0, incorrect_persona_factor=0)
|
||||
generate_dpo_file(f"home_assistant_dpo", 42, format_example_dpo, personas, wrong_argument_factor=1, no_argument_factor=1, extra_service_call_factor=1, incorrect_persona_factor=1)
|
||||
|
||||
if args.merge == "alpaca":
|
||||
merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca, ["input", "output", "instruction"], format_func)
|
||||
|
||||
Reference in New Issue
Block a user