mostly working gemma implementation

This commit is contained in:
Alex O'Connell
2025-12-20 20:29:09 -05:00
parent 672a9de65c
commit 29d839eea8
8 changed files with 694 additions and 38 deletions

View File

@@ -140,6 +140,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
tool_args = {}
question = question.replace("<device_name>", target_device)
response_starting = response_starting.replace("<device_name>", target_device)
answer_list = replace_answer(answer_list, "<device_name>", target_device)
if "climate" in service_action:
@@ -520,7 +521,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
else:
return result
def format_example_sharegpt(example, persona, language, use_system_role, use_service_names, tool_response_format):
def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format):
piles = get_dataset_piles(language)
sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
question = example["question"]
@@ -546,6 +547,10 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
"tool_result": "Success"
})
if append_user_instruction_prompt:
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
sys_prompt = "\n".join([ sys_prompt, user_instruction_words ])
if use_system_role:
conversation = [
{
@@ -558,11 +563,10 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
}
]
else:
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
conversation = [
{
"role": "user",
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, user_instruction_words, question ]) }]
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
}
]
@@ -605,6 +609,7 @@ def generate_sft_file(
seed: int,
format_func: Callable,
use_system_role: bool,
append_user_instruction_prompt: bool,
use_service_names: bool,
personas: list[str],
language: str,
@@ -622,10 +627,10 @@ def generate_sft_file(
def run_factor_times(func, examples, data, persona, factor, language):
if factor >= 1:
for i in range(factor):
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names, tool_response_format))
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format))
else:
if random.random() < factor:
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names, tool_response_format))
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format))
generated_examples = []
@@ -652,7 +657,8 @@ def generate_sft_file(
for missing in sorted(missing_responses):
print(missing)
with open(f"output/{filename}.jsonl", "w") as f:
cwd = os.path.dirname(__file__)
with open(f"{cwd}/output/{filename}.jsonl", "w") as f:
for item in generated_examples:
json_record = json.dumps(item)
f.write(json_record + '\n')
@@ -676,11 +682,13 @@ def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset
def merge_languages(filename_prefix: str, languages: list):
    """Concatenate the per-language JSONL outputs into one combined file.

    For every entry of ``languages`` (in the order given) the file
    ``output/<filename_prefix>_<language>.jsonl`` is read, and all collected
    lines are written to ``output/<filename_prefix>.jsonl``. The ``output``
    directory is resolved relative to the directory containing this module,
    not the process's current working directory.

    Args:
        filename_prefix: Shared file name stem of the per-language files and
            the name (without extension) of the merged file.
        languages: Language names used as the ``_<language>`` file suffix.

    Raises:
        OSError: If an input file cannot be read or the merged file cannot
            be written (e.g. the ``output`` directory does not exist).
    """
    # abspath() guards against a relative __file__: dirname() would then be ""
    # and a hand-built "{cwd}/output/..." path would point at the filesystem
    # root instead of next to this module.
    output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")

    all_examples = []
    for language in languages:
        source_path = os.path.join(output_dir, f"{filename_prefix}_{language}.jsonl")
        with open(source_path) as f:
            for line in f:
                # A final record without a trailing newline would otherwise be
                # fused with the first record of the next language's file.
                all_examples.append(line if line.endswith("\n") else line + "\n")

    with open(os.path.join(output_dir, f"{filename_prefix}.jsonl"), "w") as f:
        f.writelines(all_examples)
@@ -696,9 +704,12 @@ def main(args=None):
parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset.")
parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate: english, german, french, spanish, polish")
parser.add_argument("--no-system-role", action="store_true", help="Set this flag to disable the system role. It will be combined with the user role")
parser.add_argument("--tool-response-format", default="text", choices=["text", "functiongemma"], help="Format to use for tool responses.")
role_tweaks = parser.add_mutually_exclusive_group()
role_tweaks.add_argument("--no-system-role", action="store_true", help="Set this flag to disable the system role. The house context will be combined with the user role")
role_tweaks.add_argument("--merged-system-role", action="store_true", help="Set this flag to still emit a system role, but assume it will be merged by the chat template into the user role.")
train_size_group = parser.add_mutually_exclusive_group()
train_size_group.add_argument('--small', action='store_const', const='small', dest='size')
train_size_group.add_argument('--medium', action='store_const', const='medium', dest='size')
@@ -721,6 +732,7 @@ def main(args=None):
format_func = format_example_sharegpt
use_system_role = not args.no_system_role
append_user_instruction_prompt = args.merged_system_role or not args.no_system_role
use_service_names = args.use_service_names
tool_response_format = args.tool_response_format
@@ -730,21 +742,20 @@ def main(args=None):
suffix = f"_{language}" if len(args.language) > 1 else ""
if args.sample:
generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1)
generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1)
if args.train:
if args.size == "small":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8)
elif args.size == "medium":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12)
elif args.size == "large":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15)
elif args.size == "xl":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18)
else:
raise Exception(f"Unrecognized dataset size: {args.size}")
if args.test:
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2)
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2)
if len(args.language) > 1:
if args.sample:
merge_languages("sample", args.language)

View File

@@ -1,5 +1,6 @@
import random
import re
import os
import csv
import pandas
from datetime import datetime, timedelta
@@ -84,23 +85,25 @@ def get_random_response(pile_of_responses, *, service: str, persona: str, questi
class DatasetPiles:
def __init__(self, supported_devices, language="english"):
self.language = language
cwd = os.path.dirname(__file__)
with open(f"piles/{language}/pile_of_and_words.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_and_words.csv", encoding="utf8") as f:
self.and_words = [ x.strip() for x in f.readlines() ]
with open(f"piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
self.pile_of_durations = { x["duration"]: x["name"] for x in reader }
# media names are not translated
with open(f"piles/english/pile_of_media_names.txt", encoding="utf8") as f:
with open(f"{cwd}/piles/english/pile_of_media_names.txt", encoding="utf8") as f:
self.pile_of_media_names = [ x.strip() for x in f.readlines() ]
with open(f"piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
self.pile_of_todo_items = [ x.strip() for x in f.readlines() ]
self.stacks_of_device_names = { x: [] for x in supported_devices }
with open(f"piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
pile_of_device_names = list(reader)
for device_dict in pile_of_device_names:
@@ -110,7 +113,7 @@ class DatasetPiles:
except KeyError as ex:
print(ex)
with open(f"piles/{language}/pile_of_templated_actions.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_templated_actions.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
pile_of_templated_actions = list(reader)
processed_pile_of_templated_actions = []
@@ -124,23 +127,23 @@ class DatasetPiles:
self.pile_of_templated_actions = processed_pile_of_templated_actions
with open(f"piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
self.pile_of_specific_actions = list(reader)
self.pile_of_responses = pandas.read_csv(f"piles/{language}/pile_of_responses.csv")
self.pile_of_responses = pandas.read_csv(f"{cwd}/piles/{language}/pile_of_responses.csv")
self.pile_of_responses["contains_vars"] = self.pile_of_responses["response_starting"].apply(get_included_vars)
with open(f"piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
self.pile_of_status_requests = list(reader)
with open(f"piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
self.pile_of_system_prompts = { line["persona"]: line["prompt"] for line in reader }
# service names are not translated
with open(f"piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
self.pile_of_hallucinated_service_names = list(reader)