mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 21:28:05 -05:00
Added full translations for all languages during generate data and creating default prompt system (#196)
This commit is contained in:
@@ -88,6 +88,12 @@ from .const import (
|
||||
DEFAULT_SSL,
|
||||
DEFAULT_MAX_TOKENS,
|
||||
PERSONA_PROMPTS,
|
||||
CURRENT_DATE_PROMPT,
|
||||
DEVICES_PROMPT,
|
||||
SERVICES_PROMPT,
|
||||
TOOLS_PROMPT,
|
||||
AREA_PROMPT,
|
||||
USER_INSTRUCTION,
|
||||
DEFAULT_PROMPT_BASE,
|
||||
DEFAULT_TEMPERATURE,
|
||||
DEFAULT_TOP_K,
|
||||
@@ -681,7 +687,20 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
|
||||
selected_default_options.update(OPTIONS_OVERRIDES[key])
|
||||
|
||||
persona = PERSONA_PROMPTS.get(self.selected_language, PERSONA_PROMPTS.get("en"))
|
||||
current_date = CURRENT_DATE_PROMPT.get(self.selected_language, CURRENT_DATE_PROMPT.get("en"))
|
||||
devices = DEVICES_PROMPT.get(self.selected_language, DEVICES_PROMPT.get("en"))
|
||||
services = SERVICES_PROMPT.get(self.selected_language, SERVICES_PROMPT.get("en"))
|
||||
tools = TOOLS_PROMPT.get(self.selected_language, TOOLS_PROMPT.get("en"))
|
||||
area = AREA_PROMPT.get(self.selected_language, AREA_PROMPT.get("en"))
|
||||
user_instruction = USER_INSTRUCTION.get(self.selected_language, USER_INSTRUCTION.get("en"))
|
||||
|
||||
selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<persona>", persona)
|
||||
selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<current_date>", current_date)
|
||||
selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<devices>", devices)
|
||||
selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<services>", services)
|
||||
selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<tools>", tools)
|
||||
selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<area>", area)
|
||||
selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<user_instruction>", user_instruction)
|
||||
|
||||
schema = vol.Schema(local_llama_config_option_schema(self.hass, selected_default_options, backend_type))
|
||||
|
||||
|
||||
@@ -12,24 +12,65 @@ PERSONA_PROMPTS = {
|
||||
"es": "Eres 'Al', un \u00fatil asistente de IA que controla los dispositivos de una casa. Complete la siguiente tarea seg\u00fan las instrucciones o responda la siguiente pregunta \u00fanicamente con la informaci\u00f3n proporcionada.",
|
||||
"pl": "Jeste\u015b 'Al', pomocnym asystentem AI, kt\u00f3ry kontroluje urz\u0105dzenia w domu. Wykonaj poni\u017csze zadanie zgodnie z instrukcj\u0105 lub odpowiedz na poni\u017csze pytanie, korzystaj\u0105c wy\u0142\u0105cznie z podanych informacji."
|
||||
}
|
||||
|
||||
CURRENT_DATE_PROMPT = {
|
||||
"en": """The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }}""",
|
||||
"de": """{% set day_name = ["Montag", "Dienstag", "Mittwoch", "Donnerstag", "Freitag", "Samstag", "Sonntag"] %}{% set month_name = ["Januar", "Februar", "März", "April", "Mai", "Juni", "Juli", "August", "September", "Oktober", "November", "Dezember"] %}Die aktuelle Uhrzeit und das aktuelle Datum sind {{ (as_timestamp(now()) | timestamp_custom("%H:%M", local=True)) }} {{ day_name[now().weekday()] }}, {{ now().day }} {{ month_name[now().month -1]}} {{ now().year }}.""",
|
||||
"fr": """{% set day_name = ["lundi", "mardi", "mercredi", "jeudi", "vendredi", "samedi", "dimanche"] %}{% set month_name = ["janvier", "février", "mars", "avril", "mai", "juin", "juillet", "août", "septembre", "octobre", "novembre", "décembre"] %} L'heure et la date actuelles sont {{ (as_timestamp(now()) | timestamp_custom("%H:%M", local=True)) }} {{ day_name[now().weekday()] }}, {{ now().day }} {{ month_name[now().month -1]}} {{ now().year }}.""",
|
||||
"es": """{% set day_name = ["lunes", "martes", "miércoles", "jueves", "viernes", "sábado", "domingo"] %}{% set month_name = ["enero", "febrero", "marzo", "abril", "mayo", "junio", "julio", "agosto", "septiembre", "octubre", "noviembre", "diciembre"] %}La hora y fecha actuales son {{ (as_timestamp(now()) | timestamp_custom("%H:%M", local=True)) }} {{ day_name[now().weekday()] }}, {{ now().day }} de {{ month_name[now().month -1]}} de {{ now().year }}.""",
|
||||
"pl": """{% set day_name = ["poniedziałek", "wtorek", "środę", "czwartek", "piątek", "sobotę", "niedzielę"] %}{% set month_name = ["styczeń", "luty", "marzec", "kwiecień", "maj", "czerwiec", "lipiec", "sierpień", "wrzesień", "październik", "listopad", "grudzień"] %}Aktualna godzina i data to {{ (as_timestamp(now()) | timestamp_custom("%H:%M", local=True)) }} w {{ day_name[now().weekday()] }}, {{ now().day }} {{ month_name[now().month -1]}} {{ now().year }}."""
|
||||
}
|
||||
DEVICES_PROMPT = {
|
||||
"en": "Devices",
|
||||
"de": "Ger\u00e4te",
|
||||
"fr": "Appareils",
|
||||
"es": "Dispositivos",
|
||||
"pl": "Urządzenia",
|
||||
}
|
||||
SERVICES_PROMPT = {
|
||||
"en": "Services",
|
||||
"de": "Dienste",
|
||||
"fr": "Services",
|
||||
"es": "Servicios",
|
||||
"pl": "Usługi",
|
||||
}
|
||||
TOOLS_PROMPT = {
|
||||
"en": "Tools",
|
||||
"de": "Werkzeuge",
|
||||
"fr": "Outils",
|
||||
"es": "Herramientas",
|
||||
"pl": "Narzędzia",
|
||||
}
|
||||
AREA_PROMPT = {
|
||||
"en": "Area",
|
||||
"de": "Bereich",
|
||||
"fr": "Zone",
|
||||
"es": "Área",
|
||||
"pl": "Obszar",
|
||||
}
|
||||
USER_INSTRUCTION = {
|
||||
"en": "User instruction",
|
||||
"de": "Benutzeranweisung",
|
||||
"fr": "Instruction de l'utilisateur ",
|
||||
"es": "Instrucción del usuario",
|
||||
"pl": "Instrukcja użytkownika"
|
||||
}
|
||||
DEFAULT_PROMPT_BASE = """<persona>
|
||||
The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }}
|
||||
Tools: {{ tools | to_json }}
|
||||
Devices:
|
||||
<current_date>
|
||||
<tools>: {{ tools | to_json }}
|
||||
<devices>:
|
||||
{% for device in devices | selectattr('area_id', 'none'): %}
|
||||
{{ device.entity_id }} '{{ device.name }}' = {{ device.state }}{{ ([""] + device.attributes) | join(";") }}
|
||||
{% endfor %}
|
||||
{% for area in devices | rejectattr('area_id', 'none') | groupby('area_name') %}
|
||||
## Area: {{ area.grouper }}
|
||||
## <area>: {{ area.grouper }}
|
||||
{% for device in area.list %}
|
||||
{{ device.entity_id }} '{{ device.name }}' = {{ device.state }};{{ device.attributes | join(";") }}
|
||||
{% endfor %}
|
||||
{% endfor %}"""
|
||||
DEFAULT_PROMPT_BASE_LEGACY = """<persona>
|
||||
The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }}
|
||||
Services: {{ formatted_tools }}
|
||||
Devices:
|
||||
<current_date>
|
||||
<services>: {{ formatted_tools }}
|
||||
<devices>:
|
||||
{{ formatted_devices }}"""
|
||||
ICL_EXTRAS = """
|
||||
{% for item in response_examples %}
|
||||
@@ -43,7 +84,7 @@ ICL_NO_SYSTEM_PROMPT_EXTRAS = """
|
||||
{{ item.response }}
|
||||
<functioncall> {{ item.tool | to_json }}
|
||||
{% endfor %}
|
||||
User instruction:"""
|
||||
<user_instruction>:"""
|
||||
DEFAULT_PROMPT = DEFAULT_PROMPT_BASE + ICL_EXTRAS
|
||||
CONF_CHAT_MODEL = "huggingface_model"
|
||||
DEFAULT_CHAT_MODEL = "acon96/Home-3B-v3-GGUF"
|
||||
|
||||
@@ -29,7 +29,7 @@ Then create a Python virtual environment and install all necessary library:
|
||||
```
|
||||
python3 -m venv .generate_data
|
||||
source ./.generate_data/bin/activate
|
||||
pip3 install pandas=2.2.2 datasets==2.20.0 webcolors==1.13
|
||||
pip3 install pandas==2.2.2 datasets==2.20.0 webcolors==1.13 babel==2.15.0
|
||||
```
|
||||
|
||||
## Generating the dataset from piles
|
||||
|
||||
@@ -6,6 +6,7 @@ import numpy as np
|
||||
import random
|
||||
import re
|
||||
import copy
|
||||
import babel.dates
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from datasets import load_dataset, concatenate_datasets
|
||||
@@ -51,7 +52,7 @@ pile_of_responses = None
|
||||
pile_of_status_requests = None
|
||||
pile_of_system_prompts = None
|
||||
pile_of_hallucinated_service_names = None
|
||||
and_word = None
|
||||
and_words = None
|
||||
|
||||
def closest_color(requested_color):
|
||||
min_colors = {}
|
||||
@@ -63,7 +64,7 @@ def closest_color(requested_color):
|
||||
min_colors[(rd + gd + bd)] = name
|
||||
return min_colors[min(min_colors.keys())]
|
||||
|
||||
def generate_random_date():
|
||||
def generate_random_datetime():
|
||||
start_date = datetime(2022, 1, 1)
|
||||
end_date = datetime(2030, 12, 31)
|
||||
delta = end_date - start_date
|
||||
@@ -330,6 +331,47 @@ SUPPORTED_DEVICES = {
|
||||
),
|
||||
}
|
||||
|
||||
CURRENT_DATE_PROMPT = {
|
||||
"english": "The current time and date is",
|
||||
"polish": "Aktualna godzina i data to",
|
||||
"german": "Die aktuelle Uhrzeit und das aktuelle Datum sind",
|
||||
"french": "L'heure et la date actuelles sont",
|
||||
"spanish": "La hora y fecha actuales son"
|
||||
}
|
||||
|
||||
DEVICES_PROMPT = {
|
||||
"english": "Devices",
|
||||
"polish": "Urządzenia",
|
||||
"german": "Ger\u00e4te",
|
||||
"french": "Appareils",
|
||||
"spanish": "Dispositivos"
|
||||
}
|
||||
|
||||
SERVICES_PROMPT = {
|
||||
"english": "Services",
|
||||
"polish": "Usługi",
|
||||
"german": "Dienste",
|
||||
"french": "Services",
|
||||
"spanish": "Servicios"
|
||||
}
|
||||
|
||||
BABEL_LOCALE = {
|
||||
"english": "en_US",
|
||||
"polish": "pl_PL",
|
||||
"german": "de_DE",
|
||||
"french": "fr_FR",
|
||||
"spanish": "es_ES"
|
||||
}
|
||||
|
||||
BABEL_FORMAT = {
|
||||
"english": "h:m a 'on' EEEE, MMMM d yyyy",
|
||||
"polish": "H:m 'w' EEEE, d MMMM yyyy",
|
||||
"german": "H:m EEEE, d MMMM yyyy",
|
||||
"french": "H:m EEEE, d MMMM yyyy",
|
||||
"spanish": "H:m EEEE, d 'de' MMMM 'de' yyyy"
|
||||
}
|
||||
|
||||
|
||||
class NoResponseAvailableException(Exception):
|
||||
pass
|
||||
|
||||
@@ -456,6 +498,12 @@ def generate_static_example(action: dict, persona: str, max_devices: int = 32):
|
||||
"service_calls": [ { "service": service_name, "target_device": target_device } ]
|
||||
}
|
||||
|
||||
def replace_answer(list_of_answer, var, value):
|
||||
new_list = []
|
||||
for answer in list_of_answer:
|
||||
new_list.append(answer.replace(var, value))
|
||||
return new_list
|
||||
|
||||
def generate_templated_example(template: dict, persona: str, max_devices: int = 32):
|
||||
template_device_types: list[str] = template["device_type"].split("|")
|
||||
service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ]
|
||||
@@ -527,7 +575,9 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
|
||||
)
|
||||
answers.append(answer.replace(f"<device_name>", chosen_devices[i]["description"]))
|
||||
|
||||
answer = f" {and_word} ".join(answers)
|
||||
answer = []
|
||||
for word in and_words:
|
||||
answer.append(f" {word} ".join(answers))
|
||||
|
||||
# generate the list of service calls and answers
|
||||
service_calls = []
|
||||
@@ -539,31 +589,31 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
|
||||
if "<hvac_mode>" in question:
|
||||
hvac_mode = climate_device_type.get_random_parameter("hvac_mode")
|
||||
question = question.replace("<hvac_mode>", hvac_mode)
|
||||
answer = answer.replace("<hvac_mode>", hvac_mode)
|
||||
answer = replace_answer(answer, "<hvac_mode>", hvac_mode)
|
||||
service_calls = [ { **call, "hvac_mode": hvac_mode} for call in service_calls ]
|
||||
|
||||
if "<fan_mode>" in question:
|
||||
fan_mode = climate_device_type.get_random_parameter("fan_mode")
|
||||
question = question.replace("<fan_mode>", fan_mode)
|
||||
answer = answer.replace("<fan_mode>", fan_mode)
|
||||
answer = replace_answer(answer, "<fan_mode>", fan_mode)
|
||||
service_calls = [ { **call, "fan_mode": fan_mode} for call in service_calls ]
|
||||
|
||||
if "<temp_f>" in question:
|
||||
temp_f = climate_device_type.get_random_parameter("temp_f")
|
||||
question = question.replace("<temp_f>", str(temp_f))
|
||||
answer = answer.replace("<temp_f>", str(temp_f))
|
||||
answer = replace_answer(answer, "<temp_f>", str(temp_f))
|
||||
service_calls = [ { **call, "temperature": temp_f} for call in service_calls ]
|
||||
|
||||
if "<temp_c>" in question:
|
||||
temp_c = climate_device_type.get_random_parameter("temp_c")
|
||||
question = question.replace("<temp_c>", str(temp_c))
|
||||
answer = answer.replace("<temp_c>", str(temp_c))
|
||||
answer = replace_answer(answer, "<temp_c>", str(temp_c))
|
||||
service_calls = [ { **call, "temperature": temp_c} for call in service_calls ]
|
||||
|
||||
if "<humidity>" in question:
|
||||
humidity = climate_device_type.get_random_parameter("humidity")
|
||||
question = question.replace("<humidity>", str(humidity))
|
||||
answer = answer.replace("<humidity>", str(humidity))
|
||||
answer = replace_answer(answer, "<humidity>", str(humidity))
|
||||
service_calls = [ { **call, "humidity": humidity} for call in service_calls ]
|
||||
|
||||
if any(["light" in service for service in service_names ]):
|
||||
@@ -571,7 +621,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
|
||||
if "<brightness>" in question:
|
||||
brightness = light_device_type.get_random_parameter("brightness")
|
||||
question = question.replace("<brightness>", str(brightness))
|
||||
answer = answer.replace("<brightness>", str(brightness))
|
||||
answer = replace_answer(answer, "<brightness>", str(brightness))
|
||||
service_calls = [ { **call, "brightness": round(brightness / 100, 2) } for call in service_calls ]
|
||||
|
||||
if "<color>" in question:
|
||||
@@ -580,7 +630,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
|
||||
actual_random_rgb = webcolors.name_to_rgb(random_rgb_name)
|
||||
actual_random_rgb = (actual_random_rgb.red, actual_random_rgb.green, actual_random_rgb.blue)
|
||||
question = question.replace("<color>", str(random_rgb_name))
|
||||
answer = answer.replace("<color>", str(random_rgb_name))
|
||||
answer = replace_answer(answer, "<color>", str(random_rgb_name))
|
||||
service_calls = [ { **call, "rgb_color": str(actual_random_rgb) } for call in service_calls ]
|
||||
|
||||
if any(["timer" in service for service in service_names ]):
|
||||
@@ -589,7 +639,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
|
||||
duration = timer_device_type.get_random_parameter("duration")
|
||||
duration_name = pile_of_durations[duration]
|
||||
question = question.replace("<duration>", duration_name)
|
||||
answer = answer.replace("<duration>", duration_name)
|
||||
answer = replace_answer(answer, "<duration>", duration_name)
|
||||
service_calls = [ { **call, "duration": str(duration) } for call in service_calls ]
|
||||
|
||||
if any(["todo" in service for service in service_names ]):
|
||||
@@ -597,14 +647,14 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
|
||||
if "<todo>" in question:
|
||||
todo = todo_device_type.get_random_parameter("todo")
|
||||
question = question.replace("<todo>", todo)
|
||||
answer = answer.replace("<todo>", todo)
|
||||
answer = replace_answer(answer, "<todo>", todo)
|
||||
service_calls = [ { **call, "item": todo } for call in service_calls ]
|
||||
|
||||
return {
|
||||
"states": device_list,
|
||||
"available_services": list(available_services),
|
||||
"question": question.lower(),
|
||||
"answers": [ answer.lower() ],
|
||||
"answers": [ sentence.lower() for sentence in answer ],
|
||||
"service_calls": service_calls
|
||||
}
|
||||
|
||||
@@ -777,11 +827,11 @@ def generate_dpo_extra_service_call(template: dict, persona: str, max_devices: i
|
||||
def generate_dpo_incorrect_persona(template: dict, persona: str, max_devices: int = 32):
|
||||
pass
|
||||
|
||||
def format_example_raw_chatml(example, persona):
|
||||
def format_example_raw_chatml(example, persona, language):
|
||||
"""Don't use this one anymore"""
|
||||
sys_prompt = pile_of_system_prompts[persona]
|
||||
services_block = "Services: " + ", ".join(sorted(example["available_services"]))
|
||||
states_block = "Devices:\n" + "\n".join(example["states"])
|
||||
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"]))
|
||||
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
|
||||
question = example["question"]
|
||||
answers = " ".join(example["answers"])
|
||||
|
||||
@@ -805,11 +855,13 @@ def format_example_raw_chatml(example, persona):
|
||||
result = result.replace("garage_door.", "cover.")
|
||||
return { "text": result }
|
||||
|
||||
def format_example_sharegpt(example, persona):
|
||||
def format_example_sharegpt(example, persona, language):
|
||||
sys_prompt = pile_of_system_prompts[persona]
|
||||
time_block = "The current time and date is " + generate_random_date().strftime("%I:%M %p on %A %B %d, %Y")
|
||||
services_block = "Services: " + ", ".join(sorted(example["available_services"]))
|
||||
states_block = "Devices:\n" + "\n".join(example["states"])
|
||||
random_datetime = generate_random_datetime()
|
||||
translate_datetime = babel.dates.format_datetime(random_datetime, BABEL_FORMAT[language], locale=BABEL_LOCALE[language])
|
||||
time_block = f"{CURRENT_DATE_PROMPT[language]} {translate_datetime}"
|
||||
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"]))
|
||||
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
|
||||
question = example["question"]
|
||||
answers = " ".join(example["answers"])
|
||||
|
||||
@@ -832,13 +884,13 @@ def format_example_sharegpt(example, persona):
|
||||
|
||||
return { "conversations": conversation }
|
||||
|
||||
def format_example_dpo(example, persona):
|
||||
def format_example_dpo(example, persona, language):
|
||||
rejected_example = example["rejected"]
|
||||
example = example["accepted"]
|
||||
|
||||
sys_prompt = pile_of_system_prompts[persona]
|
||||
services_block = "Services: " + ", ".join(sorted(example["available_services"]))
|
||||
states_block = "Devices:\n" + "\n".join(example["states"])
|
||||
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"]))
|
||||
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
|
||||
question = example["question"]
|
||||
|
||||
assistant_block = " ".join(example["answers"])
|
||||
@@ -866,19 +918,19 @@ def format_example_dpo(example, persona):
|
||||
"rejected": rejected_assistant_block,
|
||||
}
|
||||
|
||||
def generate_sft_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, static_factor: int, template_factor: int, status_request_factor: int):
|
||||
def generate_sft_file(filename: str, seed: int, format_func: Callable, personas: list[str], language: str, *, static_factor: int, template_factor: int, status_request_factor: int):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
print("Generating...")
|
||||
|
||||
def run_factor_times(func, examples, data, persona, factor):
|
||||
def run_factor_times(func, examples, data, persona, factor, language):
|
||||
if factor >= 1:
|
||||
for i in range(factor):
|
||||
examples.append(format_func(func(data, persona), persona))
|
||||
examples.append(format_func(func(data, persona), persona, language))
|
||||
else:
|
||||
if random.random() < factor:
|
||||
examples.append(format_func(func(data, persona), persona))
|
||||
examples.append(format_func(func(data, persona), persona, language))
|
||||
|
||||
generated_examples = []
|
||||
|
||||
@@ -887,18 +939,18 @@ def generate_sft_file(filename: str, seed: int, format_func: Callable, personas:
|
||||
for person in personas:
|
||||
for action in tqdm(pile_of_specific_actions):
|
||||
try:
|
||||
run_factor_times(generate_static_example, generated_examples, action, person, static_factor)
|
||||
run_factor_times(generate_static_example, generated_examples, action, person, static_factor, language)
|
||||
except NoResponseAvailableException as ex:
|
||||
missing_responses.add(str(ex))
|
||||
|
||||
for templated_action in tqdm(pile_of_templated_actions):
|
||||
try:
|
||||
run_factor_times(generate_templated_example, generated_examples, templated_action, person, template_factor)
|
||||
run_factor_times(generate_templated_example, generated_examples, templated_action, person, template_factor, language)
|
||||
except NoResponseAvailableException as ex:
|
||||
missing_responses.add(str(ex))
|
||||
|
||||
for status_request in tqdm(pile_of_status_requests):
|
||||
run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor)
|
||||
run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor, language)
|
||||
|
||||
print(f"Generated {len(generated_examples)} examples. Saving...")
|
||||
|
||||
@@ -912,20 +964,19 @@ def generate_sft_file(filename: str, seed: int, format_func: Callable, personas:
|
||||
|
||||
print("Done!")
|
||||
|
||||
|
||||
def generate_dpo_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, wrong_argument_factor: int, no_argument_factor: int, extra_service_call_factor: int, incorrect_persona_factor: int):
|
||||
def generate_dpo_file(filename: str, seed: int, format_func: Callable, personas: list[str], language: str, *, wrong_argument_factor: int, no_argument_factor: int, extra_service_call_factor: int, incorrect_persona_factor: int):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
print("Generating...")
|
||||
|
||||
def run_factor_times(func, examples, data, persona, factor):
|
||||
def run_factor_times(func, examples, data, persona, factor, language):
|
||||
if factor >= 1:
|
||||
for i in range(factor):
|
||||
examples.append(format_func(func(data, persona), persona))
|
||||
examples.append(format_func(func(data, persona), persona, language))
|
||||
else:
|
||||
if random.random() < factor:
|
||||
examples.append(format_func(func(data, persona), persona))
|
||||
examples.append(format_func(func(data, persona), persona, language))
|
||||
|
||||
generated_examples = []
|
||||
|
||||
@@ -934,15 +985,15 @@ def generate_dpo_file(filename: str, seed: int, format_func: Callable, personas:
|
||||
for person in personas:
|
||||
for templated_action in tqdm(pile_of_templated_actions):
|
||||
try:
|
||||
run_factor_times(generate_dpo_wrong_argument, generated_examples, templated_action, person, wrong_argument_factor)
|
||||
run_factor_times(generate_dpo_no_service_call, generated_examples, templated_action, person, no_argument_factor)
|
||||
run_factor_times(generate_dpo_wrong_argument, generated_examples, templated_action, person, wrong_argument_factor, language)
|
||||
run_factor_times(generate_dpo_no_service_call, generated_examples, templated_action, person, no_argument_factor, language)
|
||||
# run_factor_times(generate_dpo_incorrect_persona, generated_examples, templated_action, person, incorrect_persona_factor)
|
||||
except NoResponseAvailableException as ex:
|
||||
missing_responses.add(str(ex))
|
||||
|
||||
for status_request in tqdm(pile_of_status_requests):
|
||||
try:
|
||||
run_factor_times(generate_dpo_extra_service_call, generated_examples, status_request, "assistant", extra_service_call_factor)
|
||||
run_factor_times(generate_dpo_extra_service_call, generated_examples, status_request, "assistant", extra_service_call_factor, language)
|
||||
except NoServicesAvailableException as ex:
|
||||
pass # TODO: warn here?
|
||||
|
||||
@@ -1010,23 +1061,13 @@ def merge_languages(filename_prefix: str, languages: list):
|
||||
with open(f"{filename_prefix}.jsonl", "w") as f:
|
||||
f.writelines(all_examples)
|
||||
|
||||
|
||||
def load_dataset_piles(language):
|
||||
global pile_of_durations, pile_of_media_names, pile_of_todo_items, stacks_of_device_names, \
|
||||
pile_of_templated_actions, pile_of_specific_actions, pile_of_responses, pile_of_status_requests, \
|
||||
pile_of_system_prompts, pile_of_hallucinated_service_names, and_word
|
||||
pile_of_system_prompts, pile_of_hallucinated_service_names, and_words
|
||||
|
||||
# TODO: make this properly dynamic
|
||||
if language == "english":
|
||||
and_word = "and"
|
||||
elif language == "german":
|
||||
and_word = "und"
|
||||
elif language == "french":
|
||||
and_word = "et"
|
||||
elif language == "spanish":
|
||||
and_word = "y"
|
||||
elif language == "polish":
|
||||
and_word = "i"
|
||||
with open(f"piles/{language}/pile_of_and_words.csv", encoding="utf8") as f:
|
||||
and_words = [ x.strip() for x in f.readlines() ]
|
||||
|
||||
with open(f"piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
@@ -1130,20 +1171,20 @@ def main():
|
||||
suffix = f"_{language}" if len(args.language) > 1 else ""
|
||||
|
||||
if args.sample:
|
||||
generate_sft_file(f"sample{suffix}", 42, format_func, personas, static_factor=1, template_factor=1, status_request_factor=1)
|
||||
generate_sft_file(f"sample{suffix}", 42, format_func, personas, language, static_factor=1, template_factor=1, status_request_factor=1)
|
||||
if args.train:
|
||||
if args.size == "small":
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=1, template_factor=10, status_request_factor=8)
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, language, static_factor=1, template_factor=10, status_request_factor=8)
|
||||
elif args.size == "medium":
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=15, status_request_factor=12)
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, language, static_factor=5, template_factor=15, status_request_factor=12)
|
||||
elif args.size == "large":
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=5, template_factor=20, status_request_factor=15)
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, language, static_factor=5, template_factor=20, status_request_factor=15)
|
||||
elif args.size == "xl":
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, static_factor=7, template_factor=25, status_request_factor=18)
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, language, static_factor=7, template_factor=25, status_request_factor=18)
|
||||
else:
|
||||
raise Exception(f"Unrecognized dataset size: {args.size}")
|
||||
if args.test:
|
||||
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, personas, static_factor=0.25, template_factor=1, status_request_factor=2)
|
||||
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, personas, language, static_factor=0.25, template_factor=1, status_request_factor=2)
|
||||
|
||||
if len(args.language) > 1:
|
||||
if args.sample:
|
||||
@@ -1154,7 +1195,7 @@ def main():
|
||||
merge_languages("home_assistant_test", args.language)
|
||||
|
||||
if args.dpo:
|
||||
generate_dpo_file(f"home_assistant_dpo", 42, format_example_dpo, personas, wrong_argument_factor=1, no_argument_factor=1, extra_service_call_factor=1, incorrect_persona_factor=1)
|
||||
generate_dpo_file(f"home_assistant_dpo", 42, format_example_dpo, personas, language, wrong_argument_factor=1, no_argument_factor=1, extra_service_call_factor=1, incorrect_persona_factor=1)
|
||||
|
||||
if args.merge == "alpaca":
|
||||
merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca, ["input", "output", "instruction"], format_func)
|
||||
|
||||
1
data/piles/english/pile_of_and_words.csv
Normal file
1
data/piles/english/pile_of_and_words.csv
Normal file
@@ -0,0 +1 @@
|
||||
and
|
||||
|
1
data/piles/french/pile_of_and_words.csv
Normal file
1
data/piles/french/pile_of_and_words.csv
Normal file
@@ -0,0 +1 @@
|
||||
et
|
||||
|
1
data/piles/german/pile_of_and_words.csv
Normal file
1
data/piles/german/pile_of_and_words.csv
Normal file
@@ -0,0 +1 @@
|
||||
und
|
||||
|
4
data/piles/polish/pile_of_and_words.csv
Normal file
4
data/piles/polish/pile_of_and_words.csv
Normal file
@@ -0,0 +1,4 @@
|
||||
i
|
||||
oraz
|
||||
a także
|
||||
również
|
||||
|
1
data/piles/spanish/pile_of_and_words.csv
Normal file
1
data/piles/spanish/pile_of_and_words.csv
Normal file
@@ -0,0 +1 @@
|
||||
y
|
||||
|
Reference in New Issue
Block a user