|
|
|
|
@@ -6,6 +6,7 @@ import numpy as np
|
|
|
|
|
import random
|
|
|
|
|
import re
|
|
|
|
|
import copy
|
|
|
|
|
import babel.dates
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from datetime import datetime, timedelta
|
|
|
|
|
from datasets import load_dataset, concatenate_datasets
|
|
|
|
|
@@ -51,7 +52,7 @@ pile_of_responses = None
|
|
|
|
|
pile_of_status_requests = None
|
|
|
|
|
pile_of_system_prompts = None
|
|
|
|
|
pile_of_hallucinated_service_names = None
|
|
|
|
|
and_word = None
|
|
|
|
|
and_words = None
|
|
|
|
|
|
|
|
|
|
def closest_color(requested_color):
|
|
|
|
|
min_colors = {}
|
|
|
|
|
@@ -63,7 +64,7 @@ def closest_color(requested_color):
|
|
|
|
|
min_colors[(rd + gd + bd)] = name
|
|
|
|
|
return min_colors[min(min_colors.keys())]
|
|
|
|
|
|
|
|
|
|
def generate_random_date():
|
|
|
|
|
def generate_random_datetime():
|
|
|
|
|
start_date = datetime(2022, 1, 1)
|
|
|
|
|
end_date = datetime(2030, 12, 31)
|
|
|
|
|
delta = end_date - start_date
|
|
|
|
|
@@ -330,6 +331,47 @@ SUPPORTED_DEVICES = {
|
|
|
|
|
),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CURRENT_DATE_PROMPT = {
|
|
|
|
|
"english": "The current time and date is",
|
|
|
|
|
"polish": "Aktualna godzina i data to",
|
|
|
|
|
"german": "Die aktuelle Uhrzeit und das aktuelle Datum sind",
|
|
|
|
|
"french": "L'heure et la date actuelles sont",
|
|
|
|
|
"spanish": "La hora y fecha actuales son"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DEVICES_PROMPT = {
|
|
|
|
|
"english": "Devices",
|
|
|
|
|
"polish": "Urządzenia",
|
|
|
|
|
"german": "Ger\u00e4te",
|
|
|
|
|
"french": "Appareils",
|
|
|
|
|
"spanish": "Dispositivos"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SERVICES_PROMPT = {
|
|
|
|
|
"english": "Services",
|
|
|
|
|
"polish": "Usługi",
|
|
|
|
|
"german": "Dienste",
|
|
|
|
|
"french": "Services",
|
|
|
|
|
"spanish": "Servicios"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BABEL_LOCALE = {
|
|
|
|
|
"english": "en_US",
|
|
|
|
|
"polish": "pl_PL",
|
|
|
|
|
"german": "de_DE",
|
|
|
|
|
"french": "fr_FR",
|
|
|
|
|
"spanish": "es_ES"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BABEL_FORMAT = {
|
|
|
|
|
"english": "h:m a 'on' EEEE, MMMM d yyyy",
|
|
|
|
|
"polish": "H:m 'w' EEEE, d MMMM yyyy",
|
|
|
|
|
"german": "H:m EEEE, d MMMM yyyy",
|
|
|
|
|
"french": "H:m EEEE, d MMMM yyyy",
|
|
|
|
|
"spanish": "H:m EEEE, d 'de' MMMM 'de' yyyy"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NoResponseAvailableException(Exception):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@@ -456,6 +498,12 @@ def generate_static_example(action: dict, persona: str, max_devices: int = 32):
|
|
|
|
|
"service_calls": [ { "service": service_name, "target_device": target_device } ]
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def replace_answer(list_of_answer, var, value):
|
|
|
|
|
new_list = []
|
|
|
|
|
for answer in list_of_answer:
|
|
|
|
|
new_list.append(answer.replace(var, value))
|
|
|
|
|
return new_list
|
|
|
|
|
|
|
|
|
|
def generate_templated_example(template: dict, persona: str, max_devices: int = 32):
|
|
|
|
|
template_device_types: list[str] = template["device_type"].split("|")
|
|
|
|
|
service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ]
|
|
|
|
|
@@ -527,7 +575,9 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
|
|
|
|
|
)
|
|
|
|
|
answers.append(answer.replace(f"<device_name>", chosen_devices[i]["description"]))
|
|
|
|
|
|
|
|
|
|
answer = f" {and_word} ".join(answers)
|
|
|
|
|
answer = []
|
|
|
|
|
for word in and_words:
|
|
|
|
|
answer.append(f" {word} ".join(answers))
|
|
|
|
|
|
|
|
|
|
# generate the list of service calls and answers
|
|
|
|
|
service_calls = []
|
|
|
|
|
@@ -539,31 +589,31 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
|
|
|
|
|
if "<hvac_mode>" in question:
|
|
|
|
|
hvac_mode = climate_device_type.get_random_parameter("hvac_mode")
|
|
|
|
|
question = question.replace("<hvac_mode>", hvac_mode)
|
|
|
|
|
answer = answer.replace("<hvac_mode>", hvac_mode)
|
|
|
|
|
answer = replace_answer(answer, "<hvac_mode>", hvac_mode)
|
|
|
|
|
service_calls = [ { **call, "hvac_mode": hvac_mode} for call in service_calls ]
|
|
|
|
|
|
|
|
|
|
if "<fan_mode>" in question:
|
|
|
|
|
fan_mode = climate_device_type.get_random_parameter("fan_mode")
|
|
|
|
|
question = question.replace("<fan_mode>", fan_mode)
|
|
|
|
|
answer = answer.replace("<fan_mode>", fan_mode)
|
|
|
|
|
answer = replace_answer(answer, "<fan_mode>", fan_mode)
|
|
|
|
|
service_calls = [ { **call, "fan_mode": fan_mode} for call in service_calls ]
|
|
|
|
|
|
|
|
|
|
if "<temp_f>" in question:
|
|
|
|
|
temp_f = climate_device_type.get_random_parameter("temp_f")
|
|
|
|
|
question = question.replace("<temp_f>", str(temp_f))
|
|
|
|
|
answer = answer.replace("<temp_f>", str(temp_f))
|
|
|
|
|
answer = replace_answer(answer, "<temp_f>", str(temp_f))
|
|
|
|
|
service_calls = [ { **call, "temperature": temp_f} for call in service_calls ]
|
|
|
|
|
|
|
|
|
|
if "<temp_c>" in question:
|
|
|
|
|
temp_c = climate_device_type.get_random_parameter("temp_c")
|
|
|
|
|
question = question.replace("<temp_c>", str(temp_c))
|
|
|
|
|
answer = answer.replace("<temp_c>", str(temp_c))
|
|
|
|
|
answer = replace_answer(answer, "<temp_c>", str(temp_c))
|
|
|
|
|
service_calls = [ { **call, "temperature": temp_c} for call in service_calls ]
|
|
|
|
|
|
|
|
|
|
if "<humidity>" in question:
|
|
|
|
|
humidity = climate_device_type.get_random_parameter("humidity")
|
|
|
|
|
question = question.replace("<humidity>", str(humidity))
|
|
|
|
|
answer = answer.replace("<humidity>", str(humidity))
|
|
|
|
|
answer = replace_answer(answer, "<humidity>", str(humidity))
|
|
|
|
|
service_calls = [ { **call, "humidity": humidity} for call in service_calls ]
|
|
|
|
|
|
|
|
|
|
if any(["light" in service for service in service_names ]):
|
|
|
|
|
@@ -571,7 +621,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
|
|
|
|
|
if "<brightness>" in question:
|
|
|
|
|
brightness = light_device_type.get_random_parameter("brightness")
|
|
|
|
|
question = question.replace("<brightness>", str(brightness))
|
|
|
|
|
answer = answer.replace("<brightness>", str(brightness))
|
|
|
|
|
answer = replace_answer(answer, "<brightness>", str(brightness))
|
|
|
|
|
service_calls = [ { **call, "brightness": round(brightness / 100, 2) } for call in service_calls ]
|
|
|
|
|
|
|
|
|
|
if "<color>" in question:
|
|
|
|
|
@@ -580,7 +630,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
|
|
|
|
|
actual_random_rgb = webcolors.name_to_rgb(random_rgb_name)
|
|
|
|
|
actual_random_rgb = (actual_random_rgb.red, actual_random_rgb.green, actual_random_rgb.blue)
|
|
|
|
|
question = question.replace("<color>", str(random_rgb_name))
|
|
|
|
|
answer = answer.replace("<color>", str(random_rgb_name))
|
|
|
|
|
answer = replace_answer(answer, "<color>", str(random_rgb_name))
|
|
|
|
|
service_calls = [ { **call, "rgb_color": str(actual_random_rgb) } for call in service_calls ]
|
|
|
|
|
|
|
|
|
|
if any(["timer" in service for service in service_names ]):
|
|
|
|
|
@@ -589,7 +639,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
|
|
|
|
|
duration = timer_device_type.get_random_parameter("duration")
|
|
|
|
|
duration_name = pile_of_durations[duration]
|
|
|
|
|
question = question.replace("<duration>", duration_name)
|
|
|
|
|
answer = answer.replace("<duration>", duration_name)
|
|
|
|
|
answer = replace_answer(answer, "<duration>", duration_name)
|
|
|
|
|
service_calls = [ { **call, "duration": str(duration) } for call in service_calls ]
|
|
|
|
|
|
|
|
|
|
if any(["todo" in service for service in service_names ]):
|
|
|
|
|
@@ -597,14 +647,14 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
|
|
|
|
|
if "<todo>" in question:
|
|
|
|
|
todo = todo_device_type.get_random_parameter("todo")
|
|
|
|
|
question = question.replace("<todo>", todo)
|
|
|
|
|
answer = answer.replace("<todo>", todo)
|
|
|
|
|
answer = replace_answer(answer, "<todo>", todo)
|
|
|
|
|
service_calls = [ { **call, "item": todo } for call in service_calls ]
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
"states": device_list,
|
|
|
|
|
"available_services": list(available_services),
|
|
|
|
|
"question": question.lower(),
|
|
|
|
|
"answers": [ answer.lower() ],
|
|
|
|
|
"answers": [ sentence.lower() for sentence in answer ],
|
|
|
|
|
"service_calls": service_calls
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -777,11 +827,11 @@ def generate_dpo_extra_service_call(template: dict, persona: str, max_devices: i
|
|
|
|
|
def generate_dpo_incorrect_persona(template: dict, persona: str, max_devices: int = 32):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def format_example_raw_chatml(example, persona):
|
|
|
|
|
def format_example_raw_chatml(example, persona, language):
|
|
|
|
|
"""Don't use this one anymore"""
|
|
|
|
|
sys_prompt = pile_of_system_prompts[persona]
|
|
|
|
|
services_block = "Services: " + ", ".join(sorted(example["available_services"]))
|
|
|
|
|
states_block = "Devices:\n" + "\n".join(example["states"])
|
|
|
|
|
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"]))
|
|
|
|
|
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
|
|
|
|
|
question = example["question"]
|
|
|
|
|
answers = " ".join(example["answers"])
|
|
|
|
|
|
|
|
|
|
@@ -805,11 +855,13 @@ def format_example_raw_chatml(example, persona):
|
|
|
|
|
result = result.replace("garage_door.", "cover.")
|
|
|
|
|
return { "text": result }
|
|
|
|
|
|
|
|
|
|
def format_example_sharegpt(example, persona):
|
|
|
|
|
def format_example_sharegpt(example, persona, language):
|
|
|
|
|
sys_prompt = pile_of_system_prompts[persona]
|
|
|
|
|
time_block = "The current time and date is " + generate_random_date().strftime("%I:%M %p on %A %B %d, %Y")
|
|
|
|
|
services_block = "Services: " + ", ".join(sorted(example["available_services"]))
|
|
|
|
|
states_block = "Devices:\n" + "\n".join(example["states"])
|
|
|
|
|
random_datetime = generate_random_datetime()
|
|
|
|
|
translate_datetime = babel.dates.format_datetime(random_datetime, BABEL_FORMAT[language], locale=BABEL_LOCALE[language])
|
|
|
|
|
time_block = f"{CURRENT_DATE_PROMPT[language]} {translate_datetime}"
|
|
|
|
|
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"]))
|
|
|
|
|
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
|
|
|
|
|
question = example["question"]
|
|
|
|
|
answers = " ".join(example["answers"])
|
|
|
|
|
|
|
|
|
|
@@ -832,13 +884,13 @@ def format_example_sharegpt(example, persona):
|
|
|
|
|
|
|
|
|
|
return { "conversations": conversation }
|
|
|
|
|
|
|
|
|
|
def format_example_dpo(example, persona):
|
|
|
|
|
def format_example_dpo(example, persona, language):
|
|
|
|
|
rejected_example = example["rejected"]
|
|
|
|
|
example = example["accepted"]
|
|
|
|
|
|
|
|
|
|
sys_prompt = pile_of_system_prompts[persona]
|
|
|
|
|
services_block = "Services: " + ", ".join(sorted(example["available_services"]))
|
|
|
|
|
states_block = "Devices:\n" + "\n".join(example["states"])
|
|
|
|
|
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"]))
|
|
|
|
|
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
|
|
|
|
|
question = example["question"]
|
|
|
|
|
|
|
|
|
|
assistant_block = " ".join(example["answers"])
|
|
|
|
|
@@ -866,19 +918,19 @@ def format_example_dpo(example, persona):
|
|
|
|
|
"rejected": rejected_assistant_block,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def generate_sft_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, static_factor: int, template_factor: int, status_request_factor: int):
|
|
|
|
|
def generate_sft_file(filename: str, seed: int, format_func: Callable, personas: list[str], language: str, *, static_factor: int, template_factor: int, status_request_factor: int):
|
|
|
|
|
random.seed(seed)
|
|
|
|
|
np.random.seed(seed)
|
|
|
|
|
|
|
|
|
|
print("Generating...")
|
|
|
|
|
|
|
|
|
|
def run_factor_times(func, examples, data, persona, factor):
|
|
|
|
|
def run_factor_times(func, examples, data, persona, factor, language):
|
|
|
|
|
if factor >= 1:
|
|
|
|
|
for i in range(factor):
|
|
|
|
|
examples.append(format_func(func(data, persona), persona))
|
|
|
|
|
examples.append(format_func(func(data, persona), persona, language))
|
|
|
|
|
else:
|
|
|
|
|
if random.random() < factor:
|
|
|
|
|
examples.append(format_func(func(data, persona), persona))
|
|
|
|
|
examples.append(format_func(func(data, persona), persona, language))
|
|
|
|
|
|
|
|
|
|
generated_examples = []
|
|
|
|
|
|
|
|
|
|
@@ -887,18 +939,18 @@ def generate_sft_file(filename: str, seed: int, format_func: Callable, personas:
|
|
|
|
|
for person in personas:
|
|
|
|
|
for action in tqdm(pile_of_specific_actions):
|
|
|
|
|
try:
|
|
|
|
|
run_factor_times(generate_static_example, generated_examples, action, person, static_factor)
|
|
|
|
|
run_factor_times(generate_static_example, generated_examples, action, person, static_factor, language)
|
|
|
|
|
except NoResponseAvailableException as ex:
|
|
|
|
|
missing_responses.add(str(ex))
|
|
|
|
|
|
|
|
|
|
for templated_action in tqdm(pile_of_templated_actions):
|
|
|
|
|
try:
|
|
|
|
|
run_factor_times(generate_templated_example, generated_examples, templated_action, person, template_factor)
|
|
|
|
|
run_factor_times(generate_templated_example, generated_examples, templated_action, person, template_factor, language)
|
|
|
|
|
except NoResponseAvailableException as ex:
|
|
|
|
|
missing_responses.add(str(ex))
|
|
|
|
|
|
|
|
|
|
for status_request in tqdm(pile_of_status_requests):
|
|
|
|
|
run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor)
|
|
|
|
|
run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor, language)
|
|
|
|
|
|
|
|
|
|
print(f"Generated {len(generated_examples)} examples. Saving...")
|
|
|
|
|
|
|
|
|
|
@@ -912,20 +964,19 @@ def generate_sft_file(filename: str, seed: int, format_func: Callable, personas:
|
|
|
|
|
|
|
|
|
|
print("Done!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_dpo_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, wrong_argument_factor: int, no_argument_factor: int, extra_service_call_factor: int, incorrect_persona_factor: int):
|
|
|
|
|
def generate_dpo_file(filename: str, seed: int, format_func: Callable, personas: list[str], language: str, *, wrong_argument_factor: int, no_argument_factor: int, extra_service_call_factor: int, incorrect_persona_factor: int):
|
|
|
|
|
random.seed(seed)
|
|
|
|
|
np.random.seed(seed)
|
|
|
|
|
|
|
|
|
|
print("Generating...")
|
|
|
|
|
|
|
|
|
|
def run_factor_times(func, examples, data, persona, factor):
|
|
|
|
|
def run_factor_times(func, examples, data, persona, factor, language):
|
|
|
|
|
if factor >= 1:
|
|
|
|
|
for i in range(factor):
|
|
|
|
|
examples.append(format_func(func(data, persona), persona))
|
|
|
|
|
examples.append(format_func(func(data, persona), persona, language))
|
|
|
|
|
else:
|
|
|
|
|
if random.random() < factor:
|
|
|
|
|
examples.append(format_func(func(data, persona), persona))
|
|
|
|
|
examples.append(format_func(func(data, persona), persona, language))
|
|
|
|
|
|
|
|
|
|
generated_examples = []
|
|
|
|
|
|
|
|
|
|
@@ -934,15 +985,15 @@ def generate_dpo_file(filename: str, seed: int, format_func: Callable, personas:
|
|
|
|
|
for person in personas:
|
|
|
|
|
for templated_action in tqdm(pile_of_templated_actions):
|
|
|
|
|
try:
|
|
|
|
|
run_factor_times(generate_dpo_wrong_argument, generated_examples, templated_action, person, wrong_argument_factor)
|
|
|
|
|
run_factor_times(generate_dpo_no_service_call, generated_examples, templated_action, person, no_argument_factor)
|
|
|
|
|
run_factor_times(generate_dpo_wrong_argument, generated_examples, templated_action, person, wrong_argument_factor, language)
|
|
|
|
|
run_factor_times(generate_dpo_no_service_call, generated_examples, templated_action, person, no_argument_factor, language)
|
|
|
|
|
# run_factor_times(generate_dpo_incorrect_persona, generated_examples, templated_action, person, incorrect_persona_factor)
|
|
|
|
|
except NoResponseAvailableException as ex:
|
|
|
|
|
missing_responses.add(str(ex))
|
|
|
|
|
|
|
|
|
|
for status_request in tqdm(pile_of_status_requests):
|
|
|
|
|
try:
|
|
|
|
|
run_factor_times(generate_dpo_extra_service_call, generated_examples, status_request, "assistant", extra_service_call_factor)
|
|
|
|
|
run_factor_times(generate_dpo_extra_service_call, generated_examples, status_request, "assistant", extra_service_call_factor, language)
|
|
|
|
|
except NoServicesAvailableException as ex:
|
|
|
|
|
pass # TODO: warn here?
|
|
|
|
|
|
|
|
|
|
@@ -1010,23 +1061,13 @@ def merge_languages(filename_prefix: str, languages: list):
|
|
|
|
|
with open(f"{filename_prefix}.jsonl", "w") as f:
|
|
|
|
|
f.writelines(all_examples)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_dataset_piles(language):
|
|
|
|
|
global pile_of_durations, pile_of_media_names, pile_of_todo_items, stacks_of_device_names, \
|
|
|
|
|
pile_of_templated_actions, pile_of_specific_actions, pile_of_responses, pile_of_status_requests, \
|
|
|
|
|
pile_of_system_prompts, pile_of_hallucinated_service_names, and_word
|
|
|
|
|
pile_of_system_prompts, pile_of_hallucinated_service_names, and_words
|
|
|
|
|
|
|
|
|
|
# TODO: make this properly dynamic
|
|
|
|
|
if language == "english":
|
|
|
|
|
and_word = "and"
|
|
|
|
|
elif language == "german":
|
|
|
|
|
and_word = "und"
|
|
|
|
|
elif language == "french":
|
|
|
|
|
and_word = "et"
|
|
|
|
|
elif language == "spanish":
|
|
|
|
|
and_word = "y"
|
|
|
|
|
elif language == "polish":
|
|
|
|
|
and_word = "i"
|
|
|
|
|
with open(f"piles/{language}/pile_of_and_words.csv", encoding="utf8") as f:
|
|
|
|
|
and_words = [ x.strip() for x in f.readlines() ]
|
|
|
|
|
|
|
|
|
|
with open(f"piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
|
|
|
|
|
reader = csv.DictReader(f)
|
|
|
|
|
@@ -1130,20 +1171,20 @@ def main():
|
|
|
|
|
suffix = f"_{language}" if len(args.language) > 1 else ""
|
|
|
|
|
|
|
|
|
|
if args.sample:
|
|
|
|
|
generate_sft_file(f"sample{suffix}", 42, format_func, personas, static_factor=1, template_factor=1, status_request_factor=1)
|
|
|
|
|
generate_sft_file(f"sample{suffix}", 42, format_func, personas, language, static_factor=1, template_factor=1, status_request_factor=1)
|
|
|
|
|
if args.train:
|
|
|
|
|
if args.size == "small":
|
|
|
|
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=1, template_factor=10, status_request_factor=8)
|
|
|
|
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, language, static_factor=1, template_factor=10, status_request_factor=8)
|
|
|
|
|
elif args.size == "medium":
|
|
|
|
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=15, status_request_factor=12)
|
|
|
|
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, language, static_factor=5, template_factor=15, status_request_factor=12)
|
|
|
|
|
elif args.size == "large":
|
|
|
|
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=20, status_request_factor=15)
|
|
|
|
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, language, static_factor=5, template_factor=20, status_request_factor=15)
|
|
|
|
|
elif args.size == "xl":
|
|
|
|
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=7, template_factor=25, status_request_factor=18)
|
|
|
|
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, language, static_factor=7, template_factor=25, status_request_factor=18)
|
|
|
|
|
else:
|
|
|
|
|
raise Exception(f"Unrecognized dataset size: {args.size}")
|
|
|
|
|
if args.test:
|
|
|
|
|
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, personas, static_factor=0.25, template_factor=1, status_request_factor=2)
|
|
|
|
|
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, personas, language, static_factor=0.25, template_factor=1, status_request_factor=2)
|
|
|
|
|
|
|
|
|
|
if len(args.language) > 1:
|
|
|
|
|
if args.sample:
|
|
|
|
|
@@ -1154,7 +1195,7 @@ def main():
|
|
|
|
|
merge_languages("home_assistant_test", args.language)
|
|
|
|
|
|
|
|
|
|
if args.dpo:
|
|
|
|
|
generate_dpo_file(f"home_assistant_dpo", 42, format_example_dpo, personas, wrong_argument_factor=1, no_argument_factor=1, extra_service_call_factor=1, incorrect_persona_factor=1)
|
|
|
|
|
generate_dpo_file(f"home_assistant_dpo", 42, format_example_dpo, personas, language, wrong_argument_factor=1, no_argument_factor=1, extra_service_call_factor=1, incorrect_persona_factor=1)
|
|
|
|
|
|
|
|
|
|
if args.merge == "alpaca":
|
|
|
|
|
merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca, ["input", "output", "instruction"], format_func)
|
|
|
|
|
|