mostly working gemma implementation

This commit is contained in:
Alex O'Connell
2025-12-20 20:29:09 -05:00
parent 672a9de65c
commit 29d839eea8
8 changed files with 694 additions and 38 deletions

View File

@@ -67,6 +67,7 @@ if TYPE_CHECKING:
from llama_cpp import (
Llama as LlamaType,
LlamaGrammar as LlamaGrammarType,
LlamaDiskCache as LlamaDiskCacheType,
ChatCompletionRequestResponseFormat
)
else:
@@ -156,6 +157,7 @@ class LlamaCppClient(LocalLLMClient):
self.llama_cpp_module = importlib.import_module("llama_cpp")
Llama: type[LlamaType] = getattr(self.llama_cpp_module, "Llama")
LlamaDiskCache: type[LlamaDiskCacheType] = getattr(self.llama_cpp_module, "LlamaDiskCache")
_LOGGER.debug(f"Loading model '{model_path}'...")
model_settings = snapshot_settings(entity_options)
@@ -170,11 +172,11 @@ class LlamaCppClient(LocalLLMClient):
)
_LOGGER.debug("Model loaded")
# TODO: check about disk caching
# self.llm.set_cache(self.llama_cpp_module.LlamaDiskCache(
# capacity_bytes=(512 * 10e8),
# cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
# ))
# FIXME: make cache size configurable (0 means disabled)
llm.set_cache(LlamaDiskCache(
capacity_bytes=int(512 * 10e8),
cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
))
if model_settings[CONF_PROMPT_CACHING_ENABLED]:
@callback
@@ -393,6 +395,7 @@ class LlamaCppClient(LocalLLMClient):
max_tokens=1,
grammar=grammar,
stream=False,
stop=["<end_of_turn>", "<end_function_call>"]
)
self.last_cache_prime = time.time()
@@ -464,6 +467,7 @@ class LlamaCppClient(LocalLLMClient):
grammar=grammar,
stream=True,
response_format=response_format,
stop=["<end_of_turn>", "<end_function_call>"] # FIXME: make configurable (pull from tool end token?)
)
def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:

View File

@@ -261,12 +261,15 @@ class LocalLLMClient:
tool_content += content
if think_prefix in potential_block and not in_thinking:
_LOGGER.debug("Entering thinking block")
in_thinking = True
last_5_tokens.clear()
elif think_suffix in potential_block and in_thinking:
_LOGGER.debug("Exiting thinking block")
in_thinking = False
content = content.replace(think_suffix, "").strip()
elif tool_prefix in potential_block and not in_tool_call:
_LOGGER.debug("Entering tool call block")
in_tool_call = True
last_5_tokens.clear()
elif tool_suffix in potential_block and in_tool_call:
@@ -307,6 +310,20 @@ class LocalLLMClient:
if not in_thinking and not in_tool_call and (cur_match_length == 0 or result.tool_calls):
yield result
if in_tool_call and tool_content:
# flush any unclosed tool calls because using the tool_suffix as a stop token can
# cause the tool_suffix to be omitted when the model streams output
tool_block = tool_content.strip().removeprefix(tool_prefix)
_LOGGER.debug("Raw tool block extracted at end: %s", tool_block)
tool_call, to_say = parse_raw_tool_call(tool_block, agent_id)
if tool_call:
_LOGGER.debug("Tool call parsed at end: %s", tool_call)
yield TextGenerationResult(
response=to_say,
response_streamed=True,
tool_calls=[tool_call]
)
async def _async_parse_completion(
self,
llm_api: llm.APIInstance | None,

View File

@@ -356,7 +356,12 @@ def get_oai_formatted_messages(conversation: Sequence[conversation.Content], use
elif message.role == "tool_result":
messages.append({
"role": "tool",
"content": json.dumps(message.tool_result),
# FIXME: what is the correct format for content here? gemma expects name and result
# "content": json.dumps(message.tool_result),
"content": {
"name": message.tool_name,
"response": { "result": message.tool_result },
},
"tool_call_id": message.tool_call_id
})
@@ -404,7 +409,26 @@ def parse_raw_tool_call(raw_block: str | dict, agent_id: str) -> tuple[llm.ToolI
if isinstance(raw_block, dict):
parsed_tool_call = raw_block
else:
parsed_tool_call: dict = json.loads(raw_block)
try:
parsed_tool_call: dict = json.loads(raw_block)
except json.JSONDecodeError:
# handle the "gemma" tool calling format
# call:HassTurnOn{name:<escape>light.living_room_rgbww<escape>}
gemma_match = re.finditer(r"call:(?P<name>\w+){(?P<args>.+)}", raw_block)
for match in gemma_match:
tool_name = match.group("name")
raw_args = match.group("args")
args_dict = {}
for arg_match in re.finditer(r"(?P<key>\w+):<escape>(?P<value>.+?)<escape>", raw_args):
args_dict[arg_match.group("key")] = arg_match.group("value")
parsed_tool_call = {
"name": tool_name,
"arguments": args_dict
}
break # TODO: how do we properly handle multiple tool calls in one response?
else:
raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted JSON")
# try to validate either format
is_services_tool_call = False

View File

@@ -140,6 +140,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
tool_args = {}
question = question.replace("<device_name>", target_device)
response_starting = response_starting.replace("<device_name>", target_device)
answer_list = replace_answer(answer_list, "<device_name>", target_device)
if "climate" in service_action:
@@ -520,7 +521,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
else:
return result
def format_example_sharegpt(example, persona, language, use_system_role, use_service_names, tool_response_format):
def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format):
piles = get_dataset_piles(language)
sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
question = example["question"]
@@ -546,6 +547,10 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
"tool_result": "Success"
})
if append_user_instruction_prompt:
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
sys_prompt = "\n".join([ sys_prompt, user_instruction_words ])
if use_system_role:
conversation = [
{
@@ -558,11 +563,10 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
}
]
else:
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
conversation = [
{
"role": "user",
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, user_instruction_words, question ]) }]
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
}
]
@@ -605,6 +609,7 @@ def generate_sft_file(
seed: int,
format_func: Callable,
use_system_role: bool,
append_user_instruction_prompt: bool,
use_service_names: bool,
personas: list[str],
language: str,
@@ -622,10 +627,10 @@ def generate_sft_file(
def run_factor_times(func, examples, data, persona, factor, language):
if factor >= 1:
for i in range(factor):
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names, tool_response_format))
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format))
else:
if random.random() < factor:
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names, tool_response_format))
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format))
generated_examples = []
@@ -652,7 +657,8 @@ def generate_sft_file(
for missing in sorted(missing_responses):
print(missing)
with open(f"output/{filename}.jsonl", "w") as f:
cwd = os.path.dirname(__file__)
with open(f"{cwd}/output/{filename}.jsonl", "w") as f:
for item in generated_examples:
json_record = json.dumps(item)
f.write(json_record + '\n')
@@ -676,11 +682,13 @@ def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset
def merge_languages(filename_prefix: str, languages: list):
all_examples = []
cwd = os.path.dirname(__file__)
for language in languages:
with open(f"output/{filename_prefix}_{language}.jsonl") as f:
with open(f"{cwd}/output/{filename_prefix}_{language}.jsonl") as f:
all_examples.extend(f.readlines())
with open(f"output/{filename_prefix}.jsonl", "w") as f:
with open(f"{cwd}/output/{filename_prefix}.jsonl", "w") as f:
f.writelines(all_examples)
@@ -696,9 +704,12 @@ def main(args=None):
parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset.")
parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate: english, german, french, spanish, polish")
parser.add_argument("--no-system-role", action="store_true", help="Set this flag to disable the system role. It will be combined with the user role")
parser.add_argument("--tool-response-format", default="text", choices=["text", "functiongemma"], help="Format to use for tool responses.")
role_tweaks = parser.add_mutually_exclusive_group()
role_tweaks.add_argument("--no-system-role", action="store_true", help="Set this flag to disable the system role. The house context will be combined with the user role")
role_tweaks.add_argument("--merged-system-role", action="store_true", help="Set this flag to still emit a system role, but assume it will be merged by the chat template into the user role.")
train_size_group = parser.add_mutually_exclusive_group()
train_size_group.add_argument('--small', action='store_const', const='small', dest='size')
train_size_group.add_argument('--medium', action='store_const', const='medium', dest='size')
@@ -721,6 +732,7 @@ def main(args=None):
format_func = format_example_sharegpt
use_system_role = not args.no_system_role
append_user_instruction_prompt = args.merged_system_role or not args.no_system_role
use_service_names = args.use_service_names
tool_response_format = args.tool_response_format
@@ -730,21 +742,20 @@ def main(args=None):
suffix = f"_{language}" if len(args.language) > 1 else ""
if args.sample:
generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1)
generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1)
if args.train:
if args.size == "small":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8)
elif args.size == "medium":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12)
elif args.size == "large":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15)
elif args.size == "xl":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18)
else:
raise Exception(f"Unrecognized dataset size: {args.size}")
if args.test:
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2)
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2)
if len(args.language) > 1:
if args.sample:
merge_languages("sample", args.language)

View File

@@ -1,5 +1,6 @@
import random
import re
import os
import csv
import pandas
from datetime import datetime, timedelta
@@ -84,23 +85,25 @@ def get_random_response(pile_of_responses, *, service: str, persona: str, questi
class DatasetPiles:
def __init__(self, supported_devices, language="english"):
self.language = language
cwd = os.path.dirname(__file__)
with open(f"piles/{language}/pile_of_and_words.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_and_words.csv", encoding="utf8") as f:
self.and_words = [ x.strip() for x in f.readlines() ]
with open(f"piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
self.pile_of_durations = { x["duration"]: x["name"] for x in reader }
# media names are not translated
with open(f"piles/english/pile_of_media_names.txt", encoding="utf8") as f:
with open(f"{cwd}/piles/english/pile_of_media_names.txt", encoding="utf8") as f:
self.pile_of_media_names = [ x.strip() for x in f.readlines() ]
with open(f"piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
self.pile_of_todo_items = [ x.strip() for x in f.readlines() ]
self.stacks_of_device_names = { x: [] for x in supported_devices }
with open(f"piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
pile_of_device_names = list(reader)
for device_dict in pile_of_device_names:
@@ -110,7 +113,7 @@ class DatasetPiles:
except KeyError as ex:
print(ex)
with open(f"piles/{language}/pile_of_templated_actions.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_templated_actions.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
pile_of_templated_actions = list(reader)
processed_pile_of_templated_actions = []
@@ -124,23 +127,23 @@ class DatasetPiles:
self.pile_of_templated_actions = processed_pile_of_templated_actions
with open(f"piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
self.pile_of_specific_actions = list(reader)
self.pile_of_responses = pandas.read_csv(f"piles/{language}/pile_of_responses.csv")
self.pile_of_responses = pandas.read_csv(f"{cwd}/piles/{language}/pile_of_responses.csv")
self.pile_of_responses["contains_vars"] = self.pile_of_responses["response_starting"].apply(get_included_vars)
with open(f"piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
self.pile_of_status_requests = list(reader)
with open(f"piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
self.pile_of_system_prompts = { line["persona"]: line["prompt"] for line in reader }
# service names are not translated
with open(f"piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f:
with open(f"{cwd}/piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f:
reader = csv.DictReader(f)
self.pile_of_hallucinated_service_names = list(reader)

View File

@@ -0,0 +1,290 @@
{%- macro format_parameters(properties, required) -%}
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in properties | dictsort -%}
{%- if key not in standard_keys -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{{- key }}:{description:<escape>{{ value['description'] }}<escape>
{%- if value['type'] | upper == 'STRING' -%}
{%- if value['enum'] -%}
,enum:{{ format_argument(value['enum']) }}
{%- endif -%}
{%- elif value['type'] | upper == 'OBJECT' -%}
,properties:{
{%- if value['properties'] is defined and value['properties'] is mapping -%}
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
{%- elif value is mapping -%}
{{- format_parameters(value, value['required'] | default([])) -}}
{%- endif -%}
}
{%- if value['required'] -%}
,required:[
{%- for item in value['required'] | default([]) -%}
<escape>{{- item -}}<escape>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- endif -%}
{%- elif value['type'] | upper == 'ARRAY' -%}
{%- if value['items'] is mapping and value['items'] -%}
,items:{
{%- set ns_items = namespace(found_first=false) -%}
{%- for item_key, item_value in value['items'].items() -%}
{%- if item_value is not none -%}
{%- if ns_items.found_first %},{% endif -%}
{%- set ns_items.found_first = true -%}
{%- if item_key == 'properties' -%}
properties:{
{%- if item_value is mapping -%}
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
{%- endif -%}
}
{%- elif item_key == 'required' -%}
required:[
{%- for req_item in item_value -%}
<escape>{{- req_item -}}<escape>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- elif item_key == 'type' -%}
{%- if item_value is string -%}
type:{{ format_argument(item_value | upper) }}
{%- else -%}
type:{{ format_argument(item_value | map('upper') | list) }}
{%- endif -%}
{%- else -%}
{{ item_key }}:{{ format_argument(item_value) }}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
}
{%- endif -%}
{%- endif -%}
,type:<escape>{{ value['type'] | upper }}<escape>}
{%- endif -%}
{%- endfor -%}
{%- endmacro -%}
{% macro format_function_declaration(tool_data) -%}
declaration:{{- tool_data['function']['name'] -}}
{description:<escape>{{- tool_data['function']['description'] -}}<escape>
{%- set params = tool_data['function']['parameters'] -%}
{%- if params -%}
,parameters:{
{%- if params['properties'] -%}
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
{%- endif -%}
{%- if params['required'] -%}
required:[
{%- for item in params['required'] -%}
<escape>{{- item -}}<escape>
{{- ',' if not loop.last -}}
{%- endfor -%}
],
{%- endif -%}
{%- if params['type'] -%}
type:<escape>{{- params['type'] | upper -}}<escape>}
{%- endif -%}
{%- endif -%}
}
{%- endmacro -%}
{% macro format_argument(argument, escape_keys=True) -%}
{%- if argument is string -%}
{{- '<escape>' + argument + '<escape>' -}}
{%- elif argument is boolean -%}
{%- if argument -%}
{{- 'true' -}}
{%- else -%}
{{- 'false' -}}
{%- endif -%}
{%- elif argument is mapping -%}
{{- '{' -}}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in argument.items() -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{%- if escape_keys -%}
{{- '<escape>' + key + '<escape>' -}}
{%- else -%}
{{- key -}}
{%- endif -%}
:{{- format_argument(value, escape_keys=escape_keys) -}}
{%- endfor -%}
{{- '}' -}}
{%- elif argument is iterable -%}
{{- '[' -}}
{%- for item in argument -%}
{{- format_argument(item, escape_keys=escape_keys) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- ']' -}}
{%- else -%}
{{- argument -}}
{%- endif -%}
{%- endmacro -%}
{{ bos_token }}
{%- set ns = namespace(prev_message_type=None) -%}
{#- extract system prompt for merging with user role -#}
{%- set loop_messages = messages -%}
{%- set system_message_content = '' %}
{%- if messages[0]['role'] == 'system' or messages[0]['role'] == 'developer' -%}
{%- set system_message_content = messages[0]['content'] -%}
{%- set loop_messages = messages[1:] -%}
{%- endif -%}
{#- 'static' system prompt. -#}
{%- if tools -%}
{{- '<start_of_turn>developer\nYou are a model that can do function calling with the following functions' -}}
{%- for tool in tools %}
{{- '<start_function_declaration>' -}}
{{- format_function_declaration(tool) | trim }}
{{- '<end_function_declaration>' -}}
{%- endfor %}
{{- '<end_of_turn>\n' -}}
{%- else -%}
{{- '<start_of_turn>developer\nNo tools have been provided. Only respond with answers that do not require tool usage.<end_of_turn>\n' -}}
{%- endif -%}
{#- Loop through messages. -#}
{%- for message in loop_messages -%}
{%- if (message['role'] == 'assistant') -%}
{#- Rename "assistant" to "model". -#}
{%- set role = "model" -%}
{%- else -%}
{%- set role = message['role'] -%}
{%- endif -%}
{%- if role != 'tool' -%}
{%- if ns.prev_message_type != 'tool_response' -%}
{{- '<start_of_turn>' + role + '\n' }}
{%- endif -%}
{%- set ns.prev_message_type = None -%}
{%- if loop.first and system_message_content -%}
{%- if system_message_content is string -%}
{{ system_message_content | trim }}
{%- elif system_message_content is iterable -%}
{%- for item in system_message_content -%}
{%- if item['type'] == 'image' -%}
{{ raise_exception("Invalid content type 'image' in system message") }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type in system message") }}
{%- endif -%}
{{- '\n' -}}
{%- endif -%}
{#- User/Assistant Messages -#}
{%- if 'content' in message and message['content'] is not none -%}
{%- if message['content'] is string -%}
{{ message['content'] | trim }}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'image' -%}
{{ '<start_of_image>' }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type in user/assistant message") }}
{%- endif -%}
{%- set ns.prev_message_type = 'content' -%}
{%- endif -%}
{%- if 'tool_calls' in message and message['tool_calls'] and message['tool_calls'] is iterable -%}
{#- Tool Calls -#}
{%- for tool_call in message['tool_calls'] -%}
{% set function = tool_call['function'] %}
{{- '<start_function_call>call:' + function['name'] + '{' -}}
{%- if 'arguments' in function -%}
{%- if function['arguments'] is mapping -%}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in function['arguments'] | dictsort -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{%- elif function['arguments'] is string -%}
{# This handles string-JSON, just in case #}
{{ function['arguments'] }}
{%- endif %}
{%- endif -%}
{{- '}<end_function_call>' -}}
{%- endfor -%}
{%- if loop.last -%}
{{ '<start_function_response>' }}
{%- endif -%}
{%- set ns.prev_message_type = 'tool_call' -%}
{%- endif -%}
{%- else -%}
{#- Tool Responses -#}
{%- if 'content' in message and message['content'] -%}
{%- if message['content'] is mapping -%}
{%- if 'name' in message['content'] and 'response' in message['content'] -%}
{{ '<start_function_response>response:' + message['content']['name'] | trim + '{' }}
{%- set response_ns = namespace(found_first=false) -%}
{%- for key, value in message['content']['response'] | dictsort -%}
{%- if response_ns.found_first %},{% endif -%}
{%- set response_ns.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{{- '}<end_function_response>' -}}
{%- elif 'name' in message -%}
{{ '<start_function_response>response:' + message['name'] | trim + '{' }}
{%- set response_ns = namespace(found_first=false) -%}
{%- for key, value in message['content'].items() -%}
{%- if response_ns.found_first %},{% endif -%}
{%- set response_ns.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{{- '}<end_function_response>' -}}
{%- else -%}
{{ raise_exception("Invalid tool response mapping: must contain 'name' and 'response' keys, or 'name' must be in the message.") }}
{%- endif -%}
{%- elif message['content'] is string -%}
{%- if 'name' in message -%}
{{ '<start_function_response>response:' + message['name'] | trim + '{value:' + format_argument(message['content'], escape_keys=False) + '}<end_function_response>' }}
{%- else -%}
{{ raise_exception("Invalid tool response: 'name' must be provided.") }}
{%- endif -%}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item is mapping -%}
{%- if 'name' in item and 'response' in item -%}
{{ '<start_function_response>response:' + item['name'] | trim + '{' }}
{%- set response_ns = namespace(found_first=false) -%}
{%- for key, value in item['response'].items() -%}
{%- if response_ns.found_first %},{% endif -%}
{%- set response_ns.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{{- '}<end_function_response>' -}}
{%- elif 'name' in message -%}
{{ '<start_function_response>response:' + message['name'] | trim + '{' }}
{%- set response_ns = namespace(found_first=false) -%}
{%- for key, value in item.items() -%}
{%- if response_ns.found_first %},{% endif -%}
{%- set response_ns.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{{- '}<end_function_response>' -}}
{%- else -%}
{{ raise_exception("Invalid tool response mapping: must contain 'name' and 'response' keys, or 'name' must be in the message.") }}
{%- endif -%}
{%- else -%}
{{ raise_exception("Invalid tool response message: multiple responses must all be mappings") }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type in tool message: must be mapping, iterable of mappings, or string.") }}
{%- endif -%}
{%- endif -%}
{%- set ns.prev_message_type = 'tool_response' -%}
{%- endif -%}
{%- if ns.prev_message_type not in ['tool_call', 'tool_response'] -%}
{{ '<end_of_turn>\n' }}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{%- if ns.prev_message_type != 'tool_response' -%}
{{- '<start_of_turn>model\n' -}}
{%- endif -%}
{%- endif -%}

View File

@@ -4,6 +4,306 @@ model_type: Gemma3ForCausalLM
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: jinja
chat_template_jinja: |
{%- macro format_parameters(properties, required) -%}
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in properties | dictsort -%}
{%- if key not in standard_keys -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{{- key }}:{description:<escape>{{ value['description'] }}<escape>
{%- if value['type'] | upper == 'STRING' -%}
{%- if value['enum'] -%}
,enum:{{ format_argument(value['enum']) }}
{%- endif -%}
{%- elif value['type'] | upper == 'OBJECT' -%}
,properties:{
{%- if value['properties'] is defined and value['properties'] is mapping -%}
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
{%- elif value is mapping -%}
{{- format_parameters(value, value['required'] | default([])) -}}
{%- endif -%}
}
{%- if value['required'] -%}
,required:[
{%- for item in value['required'] | default([]) -%}
<escape>{{- item -}}<escape>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- endif -%}
{%- elif value['type'] | upper == 'ARRAY' -%}
{%- if value['items'] is mapping and value['items'] -%}
,items:{
{%- set ns_items = namespace(found_first=false) -%}
{%- for item_key, item_value in value['items'].items() -%}
{%- if item_value is not none -%}
{%- if ns_items.found_first %},{% endif -%}
{%- set ns_items.found_first = true -%}
{%- if item_key == 'properties' -%}
properties:{
{%- if item_value is mapping -%}
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
{%- endif -%}
}
{%- elif item_key == 'required' -%}
required:[
{%- for req_item in item_value -%}
<escape>{{- req_item -}}<escape>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- elif item_key == 'type' -%}
{%- if item_value is string -%}
type:{{ format_argument(item_value | upper) }}
{%- else -%}
type:{{ format_argument(item_value | map('upper') | list) }}
{%- endif -%}
{%- else -%}
{{ item_key }}:{{ format_argument(item_value) }}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
}
{%- endif -%}
{%- endif -%}
,type:<escape>{{ value['type'] | upper }}<escape>}
{%- endif -%}
{%- endfor -%}
{%- endmacro -%}
{% macro format_function_declaration(tool_data) -%}
declaration:{{- tool_data['function']['name'] -}}
{description:<escape>{{- tool_data['function']['description'] -}}<escape>
{%- set params = tool_data['function']['parameters'] -%}
{%- if params -%}
,parameters:{
{%- if params['properties'] -%}
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
{%- endif -%}
{%- if params['required'] -%}
required:[
{%- for item in params['required'] -%}
<escape>{{- item -}}<escape>
{{- ',' if not loop.last -}}
{%- endfor -%}
],
{%- endif -%}
{%- if params['type'] -%}
type:<escape>{{- params['type'] | upper -}}<escape>}
{%- endif -%}
{%- endif -%}
}
{%- endmacro -%}
{% macro format_argument(argument, escape_keys=True) -%}
{%- if argument is string -%}
{{- '<escape>' + argument + '<escape>' -}}
{%- elif argument is boolean -%}
{%- if argument -%}
{{- 'true' -}}
{%- else -%}
{{- 'false' -}}
{%- endif -%}
{%- elif argument is mapping -%}
{{- '{' -}}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in argument.items() -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{%- if escape_keys -%}
{{- '<escape>' + key + '<escape>' -}}
{%- else -%}
{{- key -}}
{%- endif -%}
:{{- format_argument(value, escape_keys=escape_keys) -}}
{%- endfor -%}
{{- '}' -}}
{%- elif argument is iterable -%}
{{- '[' -}}
{%- for item in argument -%}
{{- format_argument(item, escape_keys=escape_keys) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- ']' -}}
{%- else -%}
{{- argument -}}
{%- endif -%}
{%- endmacro -%}
{{ bos_token }}
{%- set ns = namespace(prev_message_type=None) -%}
{#- extract system prompt for merging with user role -#}
{%- set loop_messages = messages -%}
{%- set system_message_content = '' %}
{%- if messages[0]['role'] == 'system' or messages[0]['role'] == 'developer' -%}
{%- set system_message_content = messages[0]['content'] -%}
{%- set loop_messages = messages[1:] -%}
{%- endif -%}
{#- 'static' system prompt. -#}
{%- if tools -%}
{{- '<start_of_turn>developer\nYou are a model that can do function calling with the following functions' -}}
{%- for tool in tools %}
{{- '<start_function_declaration>' -}}
{{- format_function_declaration(tool) | trim }}
{{- '<end_function_declaration>' -}}
{%- endfor %}
{{- '<end_of_turn>\n' -}}
{%- else -%}
{{- '<start_of_turn>developer\nNo tools have been provided. Only respond with answers that do not require tool usage.<end_of_turn>\n' -}}
{%- endif -%}
{#- Loop through messages. -#}
{%- for message in loop_messages -%}
{%- if (message['role'] == 'assistant') -%}
{#- Rename "assistant" to "model". -#}
{%- set role = "model" -%}
{%- else -%}
{%- set role = message['role'] -%}
{%- endif -%}
{%- if role != 'tool' -%}
{%- if ns.prev_message_type != 'tool_response' -%}
{{- '<start_of_turn>' + role + '\n' }}
{%- endif -%}
{%- set ns.prev_message_type = None -%}
{%- if loop.first and system_message_content -%}
{%- if system_message_content is string -%}
{{ system_message_content | trim }}
{%- elif system_message_content is iterable -%}
{%- for item in system_message_content -%}
{%- if item['type'] == 'image' -%}
{{ raise_exception("Invalid content type 'image' in system message") }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type in system message") }}
{%- endif -%}
{{- '\n' -}}
{%- endif -%}
{#- User/Assistant Messages -#}
{%- if 'content' in message and message['content'] is not none -%}
{%- if message['content'] is string -%}
{{ message['content'] | trim }}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'image' -%}
{{ '<start_of_image>' }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type in user/assistant message") }}
{%- endif -%}
{%- set ns.prev_message_type = 'content' -%}
{%- endif -%}
{%- if 'tool_calls' in message and message['tool_calls'] and message['tool_calls'] is iterable -%}
{#- Tool Calls -#}
{%- for tool_call in message['tool_calls'] -%}
{% set function = tool_call['function'] %}
{{- '<start_function_call>call:' + function['name'] + '{' -}}
{%- if 'arguments' in function -%}
{%- if function['arguments'] is mapping -%}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in function['arguments'] | dictsort -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{%- elif function['arguments'] is string -%}
{# This handles string-JSON, just in case #}
{{ function['arguments'] }}
{%- endif %}
{%- endif -%}
{{- '}<end_function_call>' -}}
{%- endfor -%}
{%- if loop.last -%}
{{ '<start_function_response>' }}
{%- endif -%}
{%- set ns.prev_message_type = 'tool_call' -%}
{%- endif -%}
{%- else -%}
{#- Tool Responses -#}
{%- if 'content' in message and message['content'] -%}
{%- if message['content'] is mapping -%}
{%- if 'name' in message['content'] and 'response' in message['content'] -%}
{{ '<start_function_response>response:' + message['content']['name'] | trim + '{' }}
{%- set response_ns = namespace(found_first=false) -%}
{%- for key, value in message['content']['response'] | dictsort -%}
{%- if response_ns.found_first %},{% endif -%}
{%- set response_ns.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{{- '}<end_function_response>' -}}
{%- elif 'name' in message -%}
{{ '<start_function_response>response:' + message['name'] | trim + '{' }}
{%- set response_ns = namespace(found_first=false) -%}
{%- for key, value in message['content'].items() -%}
{%- if response_ns.found_first %},{% endif -%}
{%- set response_ns.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{{- '}<end_function_response>' -}}
{%- else -%}
{{ raise_exception("Invalid tool response mapping: must contain 'name' and 'response' keys, or 'name' must be in the message.") }}
{%- endif -%}
{%- elif message['content'] is string -%}
{%- if 'name' in message -%}
{{ '<start_function_response>response:' + message['name'] | trim + '{value:' + format_argument(message['content'], escape_keys=False) + '}<end_function_response>' }}
{%- else -%}
{{ raise_exception("Invalid tool response: 'name' must be provided.") }}
{%- endif -%}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item is mapping -%}
{%- if 'name' in item and 'response' in item -%}
{{ '<start_function_response>response:' + item['name'] | trim + '{' }}
{%- set response_ns = namespace(found_first=false) -%}
{%- for key, value in item['response'].items() -%}
{%- if response_ns.found_first %},{% endif -%}
{%- set response_ns.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{{- '}<end_function_response>' -}}
{%- elif 'name' in message -%}
{{ '<start_function_response>response:' + message['name'] | trim + '{' }}
{%- set response_ns = namespace(found_first=false) -%}
{%- for key, value in item.items() -%}
{%- if response_ns.found_first %},{% endif -%}
{%- set response_ns.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{{- '}<end_function_response>' -}}
{%- else -%}
{{ raise_exception("Invalid tool response mapping: must contain 'name' and 'response' keys, or 'name' must be in the message.") }}
{%- endif -%}
{%- else -%}
{{ raise_exception("Invalid tool response message: multiple responses must all be mappings") }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type in tool message: must be mapping, iterable of mappings, or string.") }}
{%- endif -%}
{%- endif -%}
{%- set ns.prev_message_type = 'tool_response' -%}
{%- endif -%}
{%- if ns.prev_message_type not in ['tool_call', 'tool_response'] -%}
{{ '<end_of_turn>\n' }}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{%- if ns.prev_message_type != 'tool_response' -%}
{{- '<start_of_turn>model\n' -}}
{%- endif -%}
{%- endif -%}
special_tokens:
eot_tokens:
- <end_of_turn>
- <end_function_call>
eos_token: <end_of_turn>
datasets:
- path: /workspace/data/datasets/sample.jsonl
ds_type: json
@@ -12,10 +312,10 @@ datasets:
- assistant
val_set_size: 0.0
output_dir: /workspace/data/training-runs/Home-Gemma3-270m
output_dir: /workspace/data/training-runs/Home-FunctionGemma-270m
sequence_len: 4096
sample_packing: true
sequence_len: 5130
sample_packing: false
eval_sample_packing: false
use_tensorboard: true

View File

@@ -18,10 +18,15 @@ spec:
command:
- axolotl
- train
- /workspace/configs/gemma3-270m.yml
- /workspace/configs/functiongemma-270m.yml
env:
- name: AXOLOTL_DO_NOT_TRACK
value: "1"
- name: HF_TOKEN
valueFrom:
secretKeyRef:
name: hf-token
key: token
volumeMounts:
- name: training-runs
mountPath: /workspace/data/training-runs
@@ -34,9 +39,11 @@ spec:
resources:
limits:
nvidia.com/gpu: 2
initContainers:
- name: tensorboard
image: python:3.11-slim
imagePullPolicy: IfNotPresent
restartPolicy: Always # mark as sidecar
command:
- bash
- -c