mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 21:28:05 -05:00
mostly working gemma implementation
This commit is contained in:
@@ -140,6 +140,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
|
||||
tool_args = {}
|
||||
|
||||
question = question.replace("<device_name>", target_device)
|
||||
response_starting = response_starting.replace("<device_name>", target_device)
|
||||
answer_list = replace_answer(answer_list, "<device_name>", target_device)
|
||||
|
||||
if "climate" in service_action:
|
||||
@@ -520,7 +521,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
|
||||
else:
|
||||
return result
|
||||
|
||||
def format_example_sharegpt(example, persona, language, use_system_role, use_service_names, tool_response_format):
|
||||
def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format):
|
||||
piles = get_dataset_piles(language)
|
||||
sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
|
||||
question = example["question"]
|
||||
@@ -546,6 +547,10 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
|
||||
"tool_result": "Success"
|
||||
})
|
||||
|
||||
if append_user_instruction_prompt:
|
||||
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
|
||||
sys_prompt = "\n".join([ sys_prompt, user_instruction_words ])
|
||||
|
||||
if use_system_role:
|
||||
conversation = [
|
||||
{
|
||||
@@ -558,11 +563,10 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
|
||||
}
|
||||
]
|
||||
else:
|
||||
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, user_instruction_words, question ]) }]
|
||||
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
|
||||
}
|
||||
]
|
||||
|
||||
@@ -605,6 +609,7 @@ def generate_sft_file(
|
||||
seed: int,
|
||||
format_func: Callable,
|
||||
use_system_role: bool,
|
||||
append_user_instruction_prompt: bool,
|
||||
use_service_names: bool,
|
||||
personas: list[str],
|
||||
language: str,
|
||||
@@ -622,10 +627,10 @@ def generate_sft_file(
|
||||
def run_factor_times(func, examples, data, persona, factor, language):
|
||||
if factor >= 1:
|
||||
for i in range(factor):
|
||||
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names, tool_response_format))
|
||||
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format))
|
||||
else:
|
||||
if random.random() < factor:
|
||||
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names, tool_response_format))
|
||||
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format))
|
||||
|
||||
generated_examples = []
|
||||
|
||||
@@ -652,7 +657,8 @@ def generate_sft_file(
|
||||
for missing in sorted(missing_responses):
|
||||
print(missing)
|
||||
|
||||
with open(f"output/{filename}.jsonl", "w") as f:
|
||||
cwd = os.path.dirname(__file__)
|
||||
with open(f"{cwd}/output/{filename}.jsonl", "w") as f:
|
||||
for item in generated_examples:
|
||||
json_record = json.dumps(item)
|
||||
f.write(json_record + '\n')
|
||||
@@ -676,11 +682,13 @@ def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset
|
||||
|
||||
def merge_languages(filename_prefix: str, languages: list):
|
||||
all_examples = []
|
||||
cwd = os.path.dirname(__file__)
|
||||
|
||||
for language in languages:
|
||||
with open(f"output/{filename_prefix}_{language}.jsonl") as f:
|
||||
with open(f"{cwd}/output/{filename_prefix}_{language}.jsonl") as f:
|
||||
all_examples.extend(f.readlines())
|
||||
|
||||
with open(f"output/{filename_prefix}.jsonl", "w") as f:
|
||||
with open(f"{cwd}/output/{filename_prefix}.jsonl", "w") as f:
|
||||
f.writelines(all_examples)
|
||||
|
||||
|
||||
@@ -696,9 +704,12 @@ def main(args=None):
|
||||
parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset.")
|
||||
parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
|
||||
parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate: english, german, french, spanish, polish")
|
||||
parser.add_argument("--no-system-role", action="store_true", help="Set this flag to disable the system role. It will be combined with the user role")
|
||||
parser.add_argument("--tool-response-format", default="text", choices=["text", "functiongemma"], help="Format to use for tool responses.")
|
||||
|
||||
role_tweaks = parser.add_mutually_exclusive_group()
|
||||
role_tweaks.add_argument("--no-system-role", action="store_true", help="Set this flag to disable the system role. The house context will be combined with the user role")
|
||||
role_tweaks.add_argument("--merged-system-role", action="store_true", help="Set this flag to still emit a system role, but assume it will be merged by the chat template into the user role.")
|
||||
|
||||
train_size_group = parser.add_mutually_exclusive_group()
|
||||
train_size_group.add_argument('--small', action='store_const', const='small', dest='size')
|
||||
train_size_group.add_argument('--medium', action='store_const', const='medium', dest='size')
|
||||
@@ -721,6 +732,7 @@ def main(args=None):
|
||||
format_func = format_example_sharegpt
|
||||
|
||||
use_system_role = not args.no_system_role
|
||||
append_user_instruction_prompt = args.merged_system_role or not args.no_system_role
|
||||
use_service_names = args.use_service_names
|
||||
tool_response_format = args.tool_response_format
|
||||
|
||||
@@ -730,21 +742,20 @@ def main(args=None):
|
||||
suffix = f"_{language}" if len(args.language) > 1 else ""
|
||||
|
||||
if args.sample:
|
||||
generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1)
|
||||
generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1)
|
||||
if args.train:
|
||||
if args.size == "small":
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8)
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8)
|
||||
elif args.size == "medium":
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12)
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12)
|
||||
elif args.size == "large":
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15)
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15)
|
||||
elif args.size == "xl":
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18)
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18)
|
||||
else:
|
||||
raise Exception(f"Unrecognized dataset size: {args.size}")
|
||||
if args.test:
|
||||
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2)
|
||||
|
||||
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2)
|
||||
if len(args.language) > 1:
|
||||
if args.sample:
|
||||
merge_languages("sample", args.language)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import random
|
||||
import re
|
||||
import os
|
||||
import csv
|
||||
import pandas
|
||||
from datetime import datetime, timedelta
|
||||
@@ -84,23 +85,25 @@ def get_random_response(pile_of_responses, *, service: str, persona: str, questi
|
||||
class DatasetPiles:
|
||||
def __init__(self, supported_devices, language="english"):
|
||||
self.language = language
|
||||
|
||||
cwd = os.path.dirname(__file__)
|
||||
|
||||
with open(f"piles/{language}/pile_of_and_words.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_and_words.csv", encoding="utf8") as f:
|
||||
self.and_words = [ x.strip() for x in f.readlines() ]
|
||||
|
||||
with open(f"piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_durations = { x["duration"]: x["name"] for x in reader }
|
||||
|
||||
# media names are not translated
|
||||
with open(f"piles/english/pile_of_media_names.txt", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/english/pile_of_media_names.txt", encoding="utf8") as f:
|
||||
self.pile_of_media_names = [ x.strip() for x in f.readlines() ]
|
||||
|
||||
with open(f"piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
|
||||
self.pile_of_todo_items = [ x.strip() for x in f.readlines() ]
|
||||
|
||||
self.stacks_of_device_names = { x: [] for x in supported_devices }
|
||||
with open(f"piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
pile_of_device_names = list(reader)
|
||||
for device_dict in pile_of_device_names:
|
||||
@@ -110,7 +113,7 @@ class DatasetPiles:
|
||||
except KeyError as ex:
|
||||
print(ex)
|
||||
|
||||
with open(f"piles/{language}/pile_of_templated_actions.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_templated_actions.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
pile_of_templated_actions = list(reader)
|
||||
processed_pile_of_templated_actions = []
|
||||
@@ -124,23 +127,23 @@ class DatasetPiles:
|
||||
|
||||
self.pile_of_templated_actions = processed_pile_of_templated_actions
|
||||
|
||||
with open(f"piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_specific_actions = list(reader)
|
||||
|
||||
self.pile_of_responses = pandas.read_csv(f"piles/{language}/pile_of_responses.csv")
|
||||
self.pile_of_responses = pandas.read_csv(f"{cwd}/piles/{language}/pile_of_responses.csv")
|
||||
self.pile_of_responses["contains_vars"] = self.pile_of_responses["response_starting"].apply(get_included_vars)
|
||||
|
||||
with open(f"piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_status_requests = list(reader)
|
||||
|
||||
with open(f"piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_system_prompts = { line["persona"]: line["prompt"] for line in reader }
|
||||
|
||||
# service names are not translated
|
||||
with open(f"piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_hallucinated_service_names = list(reader)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user