mostly working gemma implementation

This commit is contained in:
Alex O'Connell
2025-12-20 20:29:09 -05:00
parent 672a9de65c
commit 29d839eea8
8 changed files with 694 additions and 38 deletions

View File

@@ -140,6 +140,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
tool_args = {}
question = question.replace("<device_name>", target_device)
response_starting = response_starting.replace("<device_name>", target_device)
answer_list = replace_answer(answer_list, "<device_name>", target_device)
if "climate" in service_action:
@@ -520,7 +521,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
else:
return result
def format_example_sharegpt(example, persona, language, use_system_role, use_service_names, tool_response_format):
def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format):
piles = get_dataset_piles(language)
sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
question = example["question"]
@@ -546,6 +547,10 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
"tool_result": "Success"
})
if append_user_instruction_prompt:
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
sys_prompt = "\n".join([ sys_prompt, user_instruction_words ])
if use_system_role:
conversation = [
{
@@ -558,11 +563,10 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
}
]
else:
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
conversation = [
{
"role": "user",
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, user_instruction_words, question ]) }]
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
}
]
@@ -605,6 +609,7 @@ def generate_sft_file(
seed: int,
format_func: Callable,
use_system_role: bool,
append_user_instruction_prompt: bool,
use_service_names: bool,
personas: list[str],
language: str,
@@ -622,10 +627,10 @@ def generate_sft_file(
def run_factor_times(func, examples, data, persona, factor, language):
if factor >= 1:
for i in range(factor):
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names, tool_response_format))
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format))
else:
if random.random() < factor:
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names, tool_response_format))
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format))
generated_examples = []
@@ -652,7 +657,8 @@ def generate_sft_file(
for missing in sorted(missing_responses):
print(missing)
with open(f"output/{filename}.jsonl", "w") as f:
cwd = os.path.dirname(__file__)
with open(f"{cwd}/output/{filename}.jsonl", "w") as f:
for item in generated_examples:
json_record = json.dumps(item)
f.write(json_record + '\n')
@@ -676,11 +682,13 @@ def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset
def merge_languages(filename_prefix: str, languages: list):
all_examples = []
cwd = os.path.dirname(__file__)
for language in languages:
with open(f"output/{filename_prefix}_{language}.jsonl") as f:
with open(f"{cwd}/output/{filename_prefix}_{language}.jsonl") as f:
all_examples.extend(f.readlines())
with open(f"output/{filename_prefix}.jsonl", "w") as f:
with open(f"{cwd}/output/{filename_prefix}.jsonl", "w") as f:
f.writelines(all_examples)
@@ -696,9 +704,12 @@ def main(args=None):
parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the test dataset.")
parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate: english, german, french, spanish, polish")
parser.add_argument("--no-system-role", action="store_true", help="Set this flag to disable the system role. It will be combined with the user role")
parser.add_argument("--tool-response-format", default="text", choices=["text", "functiongemma"], help="Format to use for tool responses.")
role_tweaks = parser.add_mutually_exclusive_group()
role_tweaks.add_argument("--no-system-role", action="store_true", help="Set this flag to disable the system role. The house context will be combined with the user role")
role_tweaks.add_argument("--merged-system-role", action="store_true", help="Set this flag to still emit a system role, but assume it will be merged by the chat template into the user role.")
train_size_group = parser.add_mutually_exclusive_group()
train_size_group.add_argument('--small', action='store_const', const='small', dest='size')
train_size_group.add_argument('--medium', action='store_const', const='medium', dest='size')
@@ -721,6 +732,7 @@ def main(args=None):
format_func = format_example_sharegpt
use_system_role = not args.no_system_role
append_user_instruction_prompt = args.merged_system_role or not args.no_system_role
use_service_names = args.use_service_names
tool_response_format = args.tool_response_format
@@ -730,21 +742,20 @@ def main(args=None):
suffix = f"_{language}" if len(args.language) > 1 else ""
if args.sample:
generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1)
generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1)
if args.train:
if args.size == "small":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8)
elif args.size == "medium":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12)
elif args.size == "large":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15)
elif args.size == "xl":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18)
else:
raise Exception(f"Unrecognized dataset size: {args.size}")
if args.test:
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2)
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2)
if len(args.language) > 1:
if args.sample:
merge_languages("sample", args.language)