diff --git a/custom_components/llama_conversation/config_flow.py b/custom_components/llama_conversation/config_flow.py index 34bf153..0eb7953 100644 --- a/custom_components/llama_conversation/config_flow.py +++ b/custom_components/llama_conversation/config_flow.py @@ -88,6 +88,12 @@ from .const import ( DEFAULT_SSL, DEFAULT_MAX_TOKENS, PERSONA_PROMPTS, + CURRENT_DATE_PROMPT, + DEVICES_PROMPT, + SERVICES_PROMPT, + TOOLS_PROMPT, + AREA_PROMPT, + USER_INSTRUCTION, DEFAULT_PROMPT_BASE, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, @@ -681,7 +687,20 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom selected_default_options.update(OPTIONS_OVERRIDES[key]) persona = PERSONA_PROMPTS.get(self.selected_language, PERSONA_PROMPTS.get("en")) + current_date = CURRENT_DATE_PROMPT.get(self.selected_language, CURRENT_DATE_PROMPT.get("en")) + devices = DEVICES_PROMPT.get(self.selected_language, DEVICES_PROMPT.get("en")) + services = SERVICES_PROMPT.get(self.selected_language, SERVICES_PROMPT.get("en")) + tools = TOOLS_PROMPT.get(self.selected_language, TOOLS_PROMPT.get("en")) + area = AREA_PROMPT.get(self.selected_language, AREA_PROMPT.get("en")) + user_instruction = USER_INSTRUCTION.get(self.selected_language, USER_INSTRUCTION.get("en")) + selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("", persona) + selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("", current_date) + selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("", devices) + selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("", services) + selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("", tools) + selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("", area) + selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("", user_instruction) schema = 
vol.Schema(local_llama_config_option_schema(self.hass, selected_default_options, backend_type)) diff --git a/custom_components/llama_conversation/const.py b/custom_components/llama_conversation/const.py index b10f792..f852fd8 100644 --- a/custom_components/llama_conversation/const.py +++ b/custom_components/llama_conversation/const.py @@ -12,24 +12,65 @@ PERSONA_PROMPTS = { "es": "Eres 'Al', un \u00fatil asistente de IA que controla los dispositivos de una casa. Complete la siguiente tarea seg\u00fan las instrucciones o responda la siguiente pregunta \u00fanicamente con la informaci\u00f3n proporcionada.", "pl": "Jeste\u015b 'Al', pomocnym asystentem AI, kt\u00f3ry kontroluje urz\u0105dzenia w domu. Wykonaj poni\u017csze zadanie zgodnie z instrukcj\u0105 lub odpowiedz na poni\u017csze pytanie, korzystaj\u0105c wy\u0142\u0105cznie z podanych informacji." } - +CURRENT_DATE_PROMPT = { # Jinja templates keyed by language code; each passes local=True to timestamp_custom so the prompt shows local time + "en": """The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", local=True)) }}""", + "de": """{% set day_name = ["Montag", "Dienstag", "Mittwoch", "Donnerstag", "Freitag", "Samstag", "Sonntag"] %}{% set month_name = ["Januar", "Februar", "März", "April", "Mai", "Juni", "Juli", "August", "September", "Oktober", "November", "Dezember"] %}Die aktuelle Uhrzeit und das aktuelle Datum sind {{ (as_timestamp(now()) | timestamp_custom("%H:%M", local=True)) }} {{ day_name[now().weekday()] }}, {{ now().day }} {{ month_name[now().month -1]}} {{ now().year }}.""", + "fr": """{% set day_name = ["lundi", "mardi", "mercredi", "jeudi", "vendredi", "samedi", "dimanche"] %}{% set month_name = ["janvier", "février", "mars", "avril", "mai", "juin", "juillet", "août", "septembre", "octobre", "novembre", "décembre"] %}L'heure et la date actuelles sont {{ (as_timestamp(now()) | timestamp_custom("%H:%M", local=True)) }} {{ day_name[now().weekday()] }}, {{ now().day }} {{ month_name[now().month -1]}} {{ now().year }}.""", + "es": """{% set day_name = ["lunes", "martes", 
"miércoles", "jueves", "viernes", "sábado", "domingo"] %}{% set month_name = ["enero", "febrero", "marzo", "abril", "mayo", "junio", "julio", "agosto", "septiembre", "octubre", "noviembre", "diciembre"] %}La hora y fecha actuales son {{ (as_timestamp(now()) | timestamp_custom("%H:%M", local=True)) }} {{ day_name[now().weekday()] }}, {{ now().day }} de {{ month_name[now().month -1]}} de {{ now().year }}.""", + "pl": """{% set day_name = ["poniedziałek", "wtorek", "środę", "czwartek", "piątek", "sobotę", "niedzielę"] %}{% set month_name = ["styczeń", "luty", "marzec", "kwiecień", "maj", "czerwiec", "lipiec", "sierpień", "wrzesień", "październik", "listopad", "grudzień"] %}Aktualna godzina i data to {{ (as_timestamp(now()) | timestamp_custom("%H:%M", local=True)) }} w {{ day_name[now().weekday()] }}, {{ now().day }} {{ month_name[now().month -1]}} {{ now().year }}.""" +} +DEVICES_PROMPT = { + "en": "Devices", + "de": "Ger\u00e4te", + "fr": "Appareils", + "es": "Dispositivos", + "pl": "Urządzenia", +} +SERVICES_PROMPT = { + "en": "Services", + "de": "Dienste", + "fr": "Services", + "es": "Servicios", + "pl": "Usługi", +} +TOOLS_PROMPT = { + "en": "Tools", + "de": "Werkzeuge", + "fr": "Outils", + "es": "Herramientas", + "pl": "Narzędzia", +} +AREA_PROMPT = { + "en": "Area", + "de": "Bereich", + "fr": "Zone", + "es": "Área", + "pl": "Obszar", +} +USER_INSTRUCTION = { + "en": "User instruction", + "de": "Benutzeranweisung", + "fr": "Instruction de l'utilisateur ", + "es": "Instrucción del usuario", + "pl": "Instrukcja użytkownika" +} DEFAULT_PROMPT_BASE = """ -The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }} -Tools: {{ tools | to_json }} -Devices: + +: {{ tools | to_json }} +: {% for device in devices | selectattr('area_id', 'none'): %} {{ device.entity_id }} '{{ device.name }}' = {{ device.state }}{{ ([""] + device.attributes) | join(";") }} {% endfor %} {% for area in devices | rejectattr('area_id', 'none') | 
groupby('area_name') %} -## Area: {{ area.grouper }} +## : {{ area.grouper }} {% for device in area.list %} {{ device.entity_id }} '{{ device.name }}' = {{ device.state }};{{ device.attributes | join(";") }} {% endfor %} {% endfor %}""" DEFAULT_PROMPT_BASE_LEGACY = """ -The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }} -Services: {{ formatted_tools }} -Devices: + +: {{ formatted_tools }} +: {{ formatted_devices }}""" ICL_EXTRAS = """ {% for item in response_examples %} @@ -43,7 +84,7 @@ ICL_NO_SYSTEM_PROMPT_EXTRAS = """ {{ item.response }} {{ item.tool | to_json }} {% endfor %} -User instruction:""" +:""" DEFAULT_PROMPT = DEFAULT_PROMPT_BASE + ICL_EXTRAS CONF_CHAT_MODEL = "huggingface_model" DEFAULT_CHAT_MODEL = "acon96/Home-3B-v3-GGUF" diff --git a/data/README.md b/data/README.md index b41d985..e8e799c 100644 --- a/data/README.md +++ b/data/README.md @@ -29,7 +29,7 @@ Then create a Python virtual environment and install all necessary library: ``` python3 -m venv .generate_data source ./.generate_data/bin/activate -pip3 install pandas=2.2.2 datasets==2.20.0 webcolors==1.13 +pip3 install pandas==2.2.2 datasets==2.20.0 webcolors==1.13 babel==2.15.0 ``` ## Generating the dataset from piles diff --git a/data/generate_home_assistant_data.py b/data/generate_home_assistant_data.py index b2546bf..daeec4b 100644 --- a/data/generate_home_assistant_data.py +++ b/data/generate_home_assistant_data.py @@ -6,6 +6,7 @@ import numpy as np import random import re import copy +import babel.dates from dataclasses import dataclass from datetime import datetime, timedelta from datasets import load_dataset, concatenate_datasets @@ -51,7 +52,7 @@ pile_of_responses = None pile_of_status_requests = None pile_of_system_prompts = None pile_of_hallucinated_service_names = None -and_word = None +and_words = None def closest_color(requested_color): min_colors = {} @@ -63,7 +64,7 @@ def closest_color(requested_color): min_colors[(rd + gd + 
bd)] = name return min_colors[min(min_colors.keys())] -def generate_random_date(): +def generate_random_datetime(): start_date = datetime(2022, 1, 1) end_date = datetime(2030, 12, 31) delta = end_date - start_date @@ -330,6 +331,47 @@ SUPPORTED_DEVICES = { ), } +CURRENT_DATE_PROMPT = { + "english": "The current time and date is", + "polish": "Aktualna godzina i data to", + "german": "Die aktuelle Uhrzeit und das aktuelle Datum sind", + "french": "L'heure et la date actuelles sont", + "spanish": "La hora y fecha actuales son" +} + +DEVICES_PROMPT = { + "english": "Devices", + "polish": "Urządzenia", + "german": "Ger\u00e4te", + "french": "Appareils", + "spanish": "Dispositivos" +} + +SERVICES_PROMPT = { + "english": "Services", + "polish": "Usługi", + "german": "Dienste", + "french": "Services", + "spanish": "Servicios" +} + +BABEL_LOCALE = { + "english": "en_US", + "polish": "pl_PL", + "german": "de_DE", + "french": "fr_FR", + "spanish": "es_ES" +} + +BABEL_FORMAT = { # LDML date patterns per language; two-letter fields (hh/HH, mm, dd) are zero-padded so generated times match the strftime-based prompt templates + "english": "hh:mm a 'on' EEEE MMMM dd, yyyy", + "polish": "HH:mm 'w' EEEE, d MMMM yyyy", + "german": "HH:mm EEEE, d MMMM yyyy", + "french": "HH:mm EEEE, d MMMM yyyy", + "spanish": "HH:mm EEEE, d 'de' MMMM 'de' yyyy" +} + + class NoResponseAvailableException(Exception): pass @@ -456,6 +498,12 @@ def generate_static_example(action: dict, persona: str, max_devices: int = 32): "service_calls": [ { "service": service_name, "target_device": target_device } ] } +def replace_answer(list_of_answer, var, value): + new_list = [] + for answer in list_of_answer: + new_list.append(answer.replace(var, value)) + return new_list + def generate_templated_example(template: dict, persona: str, max_devices: int = 32): template_device_types: list[str] = template["device_type"].split("|") service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ] @@ -527,7 +575,9 @@ def generate_templated_example(template: dict, persona: str, max_devices: int = ) 
answers.append(answer.replace(f"", chosen_devices[i]["description"])) - answer = f" {and_word} ".join(answers) + answer = [] + for word in and_words: + answer.append(f" {word} ".join(answers)) # generate the list of service calls and answers service_calls = [] @@ -539,31 +589,31 @@ def generate_templated_example(template: dict, persona: str, max_devices: int = if "" in question: hvac_mode = climate_device_type.get_random_parameter("hvac_mode") question = question.replace("", hvac_mode) - answer = answer.replace("", hvac_mode) + answer = replace_answer(answer, "", hvac_mode) service_calls = [ { **call, "hvac_mode": hvac_mode} for call in service_calls ] if "" in question: fan_mode = climate_device_type.get_random_parameter("fan_mode") question = question.replace("", fan_mode) - answer = answer.replace("", fan_mode) + answer = replace_answer(answer, "", fan_mode) service_calls = [ { **call, "fan_mode": fan_mode} for call in service_calls ] if "" in question: temp_f = climate_device_type.get_random_parameter("temp_f") question = question.replace("", str(temp_f)) - answer = answer.replace("", str(temp_f)) + answer = replace_answer(answer, "", str(temp_f)) service_calls = [ { **call, "temperature": temp_f} for call in service_calls ] if "" in question: temp_c = climate_device_type.get_random_parameter("temp_c") question = question.replace("", str(temp_c)) - answer = answer.replace("", str(temp_c)) + answer = replace_answer(answer, "", str(temp_c)) service_calls = [ { **call, "temperature": temp_c} for call in service_calls ] if "" in question: humidity = climate_device_type.get_random_parameter("humidity") question = question.replace("", str(humidity)) - answer = answer.replace("", str(humidity)) + answer = replace_answer(answer, "", str(humidity)) service_calls = [ { **call, "humidity": humidity} for call in service_calls ] if any(["light" in service for service in service_names ]): @@ -571,7 +621,7 @@ def generate_templated_example(template: dict, persona: str, 
max_devices: int = if "" in question: brightness = light_device_type.get_random_parameter("brightness") question = question.replace("", str(brightness)) - answer = answer.replace("", str(brightness)) + answer = replace_answer(answer, "", str(brightness)) service_calls = [ { **call, "brightness": round(brightness / 100, 2) } for call in service_calls ] if "" in question: @@ -580,7 +630,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int = actual_random_rgb = webcolors.name_to_rgb(random_rgb_name) actual_random_rgb = (actual_random_rgb.red, actual_random_rgb.green, actual_random_rgb.blue) question = question.replace("", str(random_rgb_name)) - answer = answer.replace("", str(random_rgb_name)) + answer = replace_answer(answer, "", str(random_rgb_name)) service_calls = [ { **call, "rgb_color": str(actual_random_rgb) } for call in service_calls ] if any(["timer" in service for service in service_names ]): @@ -589,7 +639,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int = duration = timer_device_type.get_random_parameter("duration") duration_name = pile_of_durations[duration] question = question.replace("", duration_name) - answer = answer.replace("", duration_name) + answer = replace_answer(answer, "", duration_name) service_calls = [ { **call, "duration": str(duration) } for call in service_calls ] if any(["todo" in service for service in service_names ]): @@ -597,14 +647,14 @@ def generate_templated_example(template: dict, persona: str, max_devices: int = if "" in question: todo = todo_device_type.get_random_parameter("todo") question = question.replace("", todo) - answer = answer.replace("", todo) + answer = replace_answer(answer, "", todo) service_calls = [ { **call, "item": todo } for call in service_calls ] return { "states": device_list, "available_services": list(available_services), "question": question.lower(), - "answers": [ answer.lower() ], + "answers": [ sentence.lower() for sentence in answer 
], "service_calls": service_calls } @@ -777,11 +827,11 @@ def generate_dpo_extra_service_call(template: dict, persona: str, max_devices: i def generate_dpo_incorrect_persona(template: dict, persona: str, max_devices: int = 32): pass -def format_example_raw_chatml(example, persona): +def format_example_raw_chatml(example, persona, language): """Don't use this one anymore""" sys_prompt = pile_of_system_prompts[persona] - services_block = "Services: " + ", ".join(sorted(example["available_services"])) - states_block = "Devices:\n" + "\n".join(example["states"]) + services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"])) + states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"]) question = example["question"] answers = " ".join(example["answers"]) @@ -805,11 +855,13 @@ def format_example_raw_chatml(example, persona): result = result.replace("garage_door.", "cover.") return { "text": result } -def format_example_sharegpt(example, persona): +def format_example_sharegpt(example, persona, language): sys_prompt = pile_of_system_prompts[persona] - time_block = "The current time and date is " + generate_random_date().strftime("%I:%M %p on %A %B %d, %Y") - services_block = "Services: " + ", ".join(sorted(example["available_services"])) - states_block = "Devices:\n" + "\n".join(example["states"]) + random_datetime = generate_random_datetime() + translate_datetime = babel.dates.format_datetime(random_datetime, BABEL_FORMAT[language], locale=BABEL_LOCALE[language]) + time_block = f"{CURRENT_DATE_PROMPT[language]} {translate_datetime}" + services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"])) + states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"]) question = example["question"] answers = " ".join(example["answers"]) @@ -832,13 +884,13 @@ def format_example_sharegpt(example, persona): return { "conversations": conversation } -def 
format_example_dpo(example, persona): +def format_example_dpo(example, persona, language): rejected_example = example["rejected"] example = example["accepted"] sys_prompt = pile_of_system_prompts[persona] - services_block = "Services: " + ", ".join(sorted(example["available_services"])) - states_block = "Devices:\n" + "\n".join(example["states"]) + services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"])) + states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"]) question = example["question"] assistant_block = " ".join(example["answers"]) @@ -866,19 +918,19 @@ def format_example_dpo(example, persona): "rejected": rejected_assistant_block, } -def generate_sft_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, static_factor: int, template_factor: int, status_request_factor: int): +def generate_sft_file(filename: str, seed: int, format_func: Callable, personas: list[str], language: str, *, static_factor: int, template_factor: int, status_request_factor: int): random.seed(seed) np.random.seed(seed) print("Generating...") - def run_factor_times(func, examples, data, persona, factor): + def run_factor_times(func, examples, data, persona, factor, language): if factor >= 1: for i in range(factor): - examples.append(format_func(func(data, persona), persona)) + examples.append(format_func(func(data, persona), persona, language)) else: if random.random() < factor: - examples.append(format_func(func(data, persona), persona)) + examples.append(format_func(func(data, persona), persona, language)) generated_examples = [] @@ -887,18 +939,18 @@ def generate_sft_file(filename: str, seed: int, format_func: Callable, personas: for person in personas: for action in tqdm(pile_of_specific_actions): try: - run_factor_times(generate_static_example, generated_examples, action, person, static_factor) + run_factor_times(generate_static_example, generated_examples, action, person, static_factor, 
language) except NoResponseAvailableException as ex: missing_responses.add(str(ex)) for templated_action in tqdm(pile_of_templated_actions): try: - run_factor_times(generate_templated_example, generated_examples, templated_action, person, template_factor) + run_factor_times(generate_templated_example, generated_examples, templated_action, person, template_factor, language) except NoResponseAvailableException as ex: missing_responses.add(str(ex)) for status_request in tqdm(pile_of_status_requests): - run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor) + run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor, language) print(f"Generated {len(generated_examples)} examples. Saving...") @@ -912,20 +964,19 @@ def generate_sft_file(filename: str, seed: int, format_func: Callable, personas: print("Done!") - -def generate_dpo_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, wrong_argument_factor: int, no_argument_factor: int, extra_service_call_factor: int, incorrect_persona_factor: int): +def generate_dpo_file(filename: str, seed: int, format_func: Callable, personas: list[str], language: str, *, wrong_argument_factor: int, no_argument_factor: int, extra_service_call_factor: int, incorrect_persona_factor: int): random.seed(seed) np.random.seed(seed) print("Generating...") - def run_factor_times(func, examples, data, persona, factor): + def run_factor_times(func, examples, data, persona, factor, language): if factor >= 1: for i in range(factor): - examples.append(format_func(func(data, persona), persona)) + examples.append(format_func(func(data, persona), persona, language)) else: if random.random() < factor: - examples.append(format_func(func(data, persona), persona)) + examples.append(format_func(func(data, persona), persona, language)) generated_examples = [] @@ -934,15 +985,15 @@ def generate_dpo_file(filename: str, seed: 
int, format_func: Callable, personas: for person in personas: for templated_action in tqdm(pile_of_templated_actions): try: - run_factor_times(generate_dpo_wrong_argument, generated_examples, templated_action, person, wrong_argument_factor) - run_factor_times(generate_dpo_no_service_call, generated_examples, templated_action, person, no_argument_factor) + run_factor_times(generate_dpo_wrong_argument, generated_examples, templated_action, person, wrong_argument_factor, language) + run_factor_times(generate_dpo_no_service_call, generated_examples, templated_action, person, no_argument_factor, language) # run_factor_times(generate_dpo_incorrect_persona, generated_examples, templated_action, person, incorrect_persona_factor) except NoResponseAvailableException as ex: missing_responses.add(str(ex)) for status_request in tqdm(pile_of_status_requests): try: - run_factor_times(generate_dpo_extra_service_call, generated_examples, status_request, "assistant", extra_service_call_factor) + run_factor_times(generate_dpo_extra_service_call, generated_examples, status_request, "assistant", extra_service_call_factor, language) except NoServicesAvailableException as ex: pass # TODO: warn here? 
@@ -1010,23 +1061,13 @@ def merge_languages(filename_prefix: str, languages: list): with open(f"{filename_prefix}.jsonl", "w") as f: f.writelines(all_examples) - def load_dataset_piles(language): global pile_of_durations, pile_of_media_names, pile_of_todo_items, stacks_of_device_names, \ pile_of_templated_actions, pile_of_specific_actions, pile_of_responses, pile_of_status_requests, \ - pile_of_system_prompts, pile_of_hallucinated_service_names, and_word + pile_of_system_prompts, pile_of_hallucinated_service_names, and_words - # TODO: make this properly dynamic - if language == "english": - and_word = "and" - elif language == "german": - and_word = "und" - elif language == "french": - and_word = "et" - elif language == "spanish": - and_word = "y" - elif language == "polish": - and_word = "i" + with open(f"piles/{language}/pile_of_and_words.csv", encoding="utf8") as f: + and_words = [ x.strip() for x in f.readlines() ] with open(f"piles/{language}/pile_of_durations.csv", encoding="utf8") as f: reader = csv.DictReader(f) @@ -1130,20 +1171,20 @@ def main(): suffix = f"_{language}" if len(args.language) > 1 else "" if args.sample: - generate_sft_file(f"sample{suffix}", 42, format_func, personas, static_factor=1, template_factor=1, status_request_factor=1) + generate_sft_file(f"sample{suffix}", 42, format_func, personas, language, static_factor=1, template_factor=1, status_request_factor=1) if args.train: if args.size == "small": - generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=1, template_factor=10, status_request_factor=8) + generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, language, static_factor=1, template_factor=10, status_request_factor=8) elif args.size == "medium": - generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=15, status_request_factor=12) + generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, 
language, static_factor=5, template_factor=15, status_request_factor=12) elif args.size == "large": - generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=20, status_request_factor=15) + generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, language, static_factor=5, template_factor=20, status_request_factor=15) elif args.size == "xl": - generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=7, template_factor=25, status_request_factor=18) + generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, language, static_factor=7, template_factor=25, status_request_factor=18) else: raise Exception(f"Unrecognized dataset size: {args.size}") if args.test: - generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, personas, static_factor=0.25, template_factor=1, status_request_factor=2) + generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, personas, language, static_factor=0.25, template_factor=1, status_request_factor=2) if len(args.language) > 1: if args.sample: @@ -1154,7 +1195,7 @@ def main(): merge_languages("home_assistant_test", args.language) if args.dpo: - generate_dpo_file(f"home_assistant_dpo", 42, format_example_dpo, personas, wrong_argument_factor=1, no_argument_factor=1, extra_service_call_factor=1, incorrect_persona_factor=1) + generate_dpo_file(f"home_assistant_dpo", 42, format_example_dpo, personas, language, wrong_argument_factor=1, no_argument_factor=1, extra_service_call_factor=1, incorrect_persona_factor=1) if args.merge == "alpaca": merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca, ["input", "output", "instruction"], format_func) diff --git a/data/piles/english/pile_of_and_words.csv b/data/piles/english/pile_of_and_words.csv new file mode 100644 index 0000000..c51107c --- /dev/null +++ b/data/piles/english/pile_of_and_words.csv @@ -0,0 +1 @@ +and \ No newline at 
end of file diff --git a/data/piles/french/pile_of_and_words.csv b/data/piles/french/pile_of_and_words.csv new file mode 100644 index 0000000..d8a47f8 --- /dev/null +++ b/data/piles/french/pile_of_and_words.csv @@ -0,0 +1 @@ +et \ No newline at end of file diff --git a/data/piles/german/pile_of_and_words.csv b/data/piles/german/pile_of_and_words.csv new file mode 100644 index 0000000..2babcce --- /dev/null +++ b/data/piles/german/pile_of_and_words.csv @@ -0,0 +1 @@ +und \ No newline at end of file diff --git a/data/piles/polish/pile_of_and_words.csv b/data/piles/polish/pile_of_and_words.csv new file mode 100644 index 0000000..0bda13b --- /dev/null +++ b/data/piles/polish/pile_of_and_words.csv @@ -0,0 +1,4 @@ +i +oraz +a także +również \ No newline at end of file diff --git a/data/piles/spanish/pile_of_and_words.csv b/data/piles/spanish/pile_of_and_words.csv new file mode 100644 index 0000000..e25f181 --- /dev/null +++ b/data/piles/spanish/pile_of_and_words.csv @@ -0,0 +1 @@ +y \ No newline at end of file