mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-07 21:04:08 -05:00
mostly working gemma implementation
This commit is contained in:
@@ -67,6 +67,7 @@ if TYPE_CHECKING:
|
||||
from llama_cpp import (
|
||||
Llama as LlamaType,
|
||||
LlamaGrammar as LlamaGrammarType,
|
||||
LlamaDiskCache as LlamaDiskCacheType,
|
||||
ChatCompletionRequestResponseFormat
|
||||
)
|
||||
else:
|
||||
@@ -156,6 +157,7 @@ class LlamaCppClient(LocalLLMClient):
|
||||
self.llama_cpp_module = importlib.import_module("llama_cpp")
|
||||
|
||||
Llama: type[LlamaType] = getattr(self.llama_cpp_module, "Llama")
|
||||
LlamaDiskCache: type[LlamaDiskCacheType] = getattr(self.llama_cpp_module, "LlamaDiskCache")
|
||||
|
||||
_LOGGER.debug(f"Loading model '{model_path}'...")
|
||||
model_settings = snapshot_settings(entity_options)
|
||||
@@ -170,11 +172,11 @@ class LlamaCppClient(LocalLLMClient):
|
||||
)
|
||||
_LOGGER.debug("Model loaded")
|
||||
|
||||
# TODO: check about disk caching
|
||||
# self.llm.set_cache(self.llama_cpp_module.LlamaDiskCache(
|
||||
# capacity_bytes=(512 * 10e8),
|
||||
# cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
|
||||
# ))
|
||||
# FIXME: make cache size configurable (0 means disabled)
|
||||
llm.set_cache(LlamaDiskCache(
|
||||
capacity_bytes=int(512 * 10e8),
|
||||
cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
|
||||
))
|
||||
|
||||
if model_settings[CONF_PROMPT_CACHING_ENABLED]:
|
||||
@callback
|
||||
@@ -393,6 +395,7 @@ class LlamaCppClient(LocalLLMClient):
|
||||
max_tokens=1,
|
||||
grammar=grammar,
|
||||
stream=False,
|
||||
stop=["<end_of_turn>", "<end_function_call>"]
|
||||
)
|
||||
|
||||
self.last_cache_prime = time.time()
|
||||
@@ -464,6 +467,7 @@ class LlamaCppClient(LocalLLMClient):
|
||||
grammar=grammar,
|
||||
stream=True,
|
||||
response_format=response_format,
|
||||
stop=["<end_of_turn>", "<end_function_call>"] # FIXME: make configurable (pull from tool end token?)
|
||||
)
|
||||
|
||||
def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:
|
||||
|
||||
@@ -261,12 +261,15 @@ class LocalLLMClient:
|
||||
tool_content += content
|
||||
|
||||
if think_prefix in potential_block and not in_thinking:
|
||||
_LOGGER.debug("Entering thinking block")
|
||||
in_thinking = True
|
||||
last_5_tokens.clear()
|
||||
elif think_suffix in potential_block and in_thinking:
|
||||
_LOGGER.debug("Exiting thinking block")
|
||||
in_thinking = False
|
||||
content = content.replace(think_suffix, "").strip()
|
||||
elif tool_prefix in potential_block and not in_tool_call:
|
||||
_LOGGER.debug("Entering tool call block")
|
||||
in_tool_call = True
|
||||
last_5_tokens.clear()
|
||||
elif tool_suffix in potential_block and in_tool_call:
|
||||
@@ -307,6 +310,20 @@ class LocalLLMClient:
|
||||
if not in_thinking and not in_tool_call and (cur_match_length == 0 or result.tool_calls):
|
||||
yield result
|
||||
|
||||
if in_tool_call and tool_content:
|
||||
# flush any unclosed tool calls because using the tool_suffix as a stop token can
|
||||
# cause the tool_suffix to be omitted when the model streams output
|
||||
tool_block = tool_content.strip().removeprefix(tool_prefix)
|
||||
_LOGGER.debug("Raw tool block extracted at end: %s", tool_block)
|
||||
tool_call, to_say = parse_raw_tool_call(tool_block, agent_id)
|
||||
if tool_call:
|
||||
_LOGGER.debug("Tool call parsed at end: %s", tool_call)
|
||||
yield TextGenerationResult(
|
||||
response=to_say,
|
||||
response_streamed=True,
|
||||
tool_calls=[tool_call]
|
||||
)
|
||||
|
||||
async def _async_parse_completion(
|
||||
self,
|
||||
llm_api: llm.APIInstance | None,
|
||||
|
||||
@@ -356,7 +356,12 @@ def get_oai_formatted_messages(conversation: Sequence[conversation.Content], use
|
||||
elif message.role == "tool_result":
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": json.dumps(message.tool_result),
|
||||
# FIXME: what is the correct format for content here? gemma expects name and result
|
||||
# "content": json.dumps(message.tool_result),
|
||||
"content": {
|
||||
"name": message.tool_name,
|
||||
"response": { "result": message.tool_result },
|
||||
},
|
||||
"tool_call_id": message.tool_call_id
|
||||
})
|
||||
|
||||
@@ -404,7 +409,26 @@ def parse_raw_tool_call(raw_block: str | dict, agent_id: str) -> tuple[llm.ToolI
|
||||
if isinstance(raw_block, dict):
|
||||
parsed_tool_call = raw_block
|
||||
else:
|
||||
parsed_tool_call: dict = json.loads(raw_block)
|
||||
try:
|
||||
parsed_tool_call: dict = json.loads(raw_block)
|
||||
except json.JSONDecodeError:
|
||||
# handle the "gemma" tool calling format
|
||||
# call:HassTurnOn{name:<escape>light.living_room_rgbww<escape>}
|
||||
gemma_match = re.finditer(r"call:(?P<name>\w+){(?P<args>.+)}", raw_block)
|
||||
for match in gemma_match:
|
||||
tool_name = match.group("name")
|
||||
raw_args = match.group("args")
|
||||
args_dict = {}
|
||||
for arg_match in re.finditer(r"(?P<key>\w+):<escape>(?P<value>.+?)<escape>", raw_args):
|
||||
args_dict[arg_match.group("key")] = arg_match.group("value")
|
||||
|
||||
parsed_tool_call = {
|
||||
"name": tool_name,
|
||||
"arguments": args_dict
|
||||
}
|
||||
break # TODO: how do we properly handle multiple tool calls in one response?
|
||||
else:
|
||||
raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted JSON")
|
||||
|
||||
# try to validate either format
|
||||
is_services_tool_call = False
|
||||
|
||||
@@ -140,6 +140,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
|
||||
tool_args = {}
|
||||
|
||||
question = question.replace("<device_name>", target_device)
|
||||
response_starting = response_starting.replace("<device_name>", target_device)
|
||||
answer_list = replace_answer(answer_list, "<device_name>", target_device)
|
||||
|
||||
if "climate" in service_action:
|
||||
@@ -520,7 +521,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
|
||||
else:
|
||||
return result
|
||||
|
||||
def format_example_sharegpt(example, persona, language, use_system_role, use_service_names, tool_response_format):
|
||||
def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format):
|
||||
piles = get_dataset_piles(language)
|
||||
sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
|
||||
question = example["question"]
|
||||
@@ -546,6 +547,10 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
|
||||
"tool_result": "Success"
|
||||
})
|
||||
|
||||
if append_user_instruction_prompt:
|
||||
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
|
||||
sys_prompt = "\n".join([ sys_prompt, user_instruction_words ])
|
||||
|
||||
if use_system_role:
|
||||
conversation = [
|
||||
{
|
||||
@@ -558,11 +563,10 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
|
||||
}
|
||||
]
|
||||
else:
|
||||
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, user_instruction_words, question ]) }]
|
||||
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
|
||||
}
|
||||
]
|
||||
|
||||
@@ -605,6 +609,7 @@ def generate_sft_file(
|
||||
seed: int,
|
||||
format_func: Callable,
|
||||
use_system_role: bool,
|
||||
append_user_instruction_prompt: bool,
|
||||
use_service_names: bool,
|
||||
personas: list[str],
|
||||
language: str,
|
||||
@@ -622,10 +627,10 @@ def generate_sft_file(
|
||||
def run_factor_times(func, examples, data, persona, factor, language):
|
||||
if factor >= 1:
|
||||
for i in range(factor):
|
||||
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names, tool_response_format))
|
||||
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format))
|
||||
else:
|
||||
if random.random() < factor:
|
||||
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names, tool_response_format))
|
||||
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format))
|
||||
|
||||
generated_examples = []
|
||||
|
||||
@@ -652,7 +657,8 @@ def generate_sft_file(
|
||||
for missing in sorted(missing_responses):
|
||||
print(missing)
|
||||
|
||||
with open(f"output/{filename}.jsonl", "w") as f:
|
||||
cwd = os.path.dirname(__file__)
|
||||
with open(f"{cwd}/output/{filename}.jsonl", "w") as f:
|
||||
for item in generated_examples:
|
||||
json_record = json.dumps(item)
|
||||
f.write(json_record + '\n')
|
||||
@@ -676,11 +682,13 @@ def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset
|
||||
|
||||
def merge_languages(filename_prefix: str, languages: list):
|
||||
all_examples = []
|
||||
cwd = os.path.dirname(__file__)
|
||||
|
||||
for language in languages:
|
||||
with open(f"output/{filename_prefix}_{language}.jsonl") as f:
|
||||
with open(f"{cwd}/output/{filename_prefix}_{language}.jsonl") as f:
|
||||
all_examples.extend(f.readlines())
|
||||
|
||||
with open(f"output/{filename_prefix}.jsonl", "w") as f:
|
||||
with open(f"{cwd}/output/{filename_prefix}.jsonl", "w") as f:
|
||||
f.writelines(all_examples)
|
||||
|
||||
|
||||
@@ -696,9 +704,12 @@ def main(args=None):
|
||||
parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset.")
|
||||
parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
|
||||
parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate: english, german, french, spanish, polish")
|
||||
parser.add_argument("--no-system-role", action="store_true", help="Set this flag to disable the system role. It will be combined with the user role")
|
||||
parser.add_argument("--tool-response-format", default="text", choices=["text", "functiongemma"], help="Format to use for tool responses.")
|
||||
|
||||
role_tweaks = parser.add_mutually_exclusive_group()
|
||||
role_tweaks.add_argument("--no-system-role", action="store_true", help="Set this flag to disable the system role. The house context will be combined with the user role")
|
||||
role_tweaks.add_argument("--merged-system-role", action="store_true", help="Set this flag to still emit a system role, but assume it will be merged by the chat template into the user role.")
|
||||
|
||||
train_size_group = parser.add_mutually_exclusive_group()
|
||||
train_size_group.add_argument('--small', action='store_const', const='small', dest='size')
|
||||
train_size_group.add_argument('--medium', action='store_const', const='medium', dest='size')
|
||||
@@ -721,6 +732,7 @@ def main(args=None):
|
||||
format_func = format_example_sharegpt
|
||||
|
||||
use_system_role = not args.no_system_role
|
||||
append_user_instruction_prompt = args.merged_system_role or not args.no_system_role
|
||||
use_service_names = args.use_service_names
|
||||
tool_response_format = args.tool_response_format
|
||||
|
||||
@@ -730,21 +742,20 @@ def main(args=None):
|
||||
suffix = f"_{language}" if len(args.language) > 1 else ""
|
||||
|
||||
if args.sample:
|
||||
generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1)
|
||||
generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1)
|
||||
if args.train:
|
||||
if args.size == "small":
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8)
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8)
|
||||
elif args.size == "medium":
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12)
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12)
|
||||
elif args.size == "large":
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15)
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15)
|
||||
elif args.size == "xl":
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18)
|
||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18)
|
||||
else:
|
||||
raise Exception(f"Unrecognized dataset size: {args.size}")
|
||||
if args.test:
|
||||
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2)
|
||||
|
||||
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2)
|
||||
if len(args.language) > 1:
|
||||
if args.sample:
|
||||
merge_languages("sample", args.language)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import random
|
||||
import re
|
||||
import os
|
||||
import csv
|
||||
import pandas
|
||||
from datetime import datetime, timedelta
|
||||
@@ -84,23 +85,25 @@ def get_random_response(pile_of_responses, *, service: str, persona: str, questi
|
||||
class DatasetPiles:
|
||||
def __init__(self, supported_devices, language="english"):
|
||||
self.language = language
|
||||
|
||||
cwd = os.path.dirname(__file__)
|
||||
|
||||
with open(f"piles/{language}/pile_of_and_words.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_and_words.csv", encoding="utf8") as f:
|
||||
self.and_words = [ x.strip() for x in f.readlines() ]
|
||||
|
||||
with open(f"piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_durations = { x["duration"]: x["name"] for x in reader }
|
||||
|
||||
# media names are not translated
|
||||
with open(f"piles/english/pile_of_media_names.txt", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/english/pile_of_media_names.txt", encoding="utf8") as f:
|
||||
self.pile_of_media_names = [ x.strip() for x in f.readlines() ]
|
||||
|
||||
with open(f"piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
|
||||
self.pile_of_todo_items = [ x.strip() for x in f.readlines() ]
|
||||
|
||||
self.stacks_of_device_names = { x: [] for x in supported_devices }
|
||||
with open(f"piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
pile_of_device_names = list(reader)
|
||||
for device_dict in pile_of_device_names:
|
||||
@@ -110,7 +113,7 @@ class DatasetPiles:
|
||||
except KeyError as ex:
|
||||
print(ex)
|
||||
|
||||
with open(f"piles/{language}/pile_of_templated_actions.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_templated_actions.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
pile_of_templated_actions = list(reader)
|
||||
processed_pile_of_templated_actions = []
|
||||
@@ -124,23 +127,23 @@ class DatasetPiles:
|
||||
|
||||
self.pile_of_templated_actions = processed_pile_of_templated_actions
|
||||
|
||||
with open(f"piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_specific_actions = list(reader)
|
||||
|
||||
self.pile_of_responses = pandas.read_csv(f"piles/{language}/pile_of_responses.csv")
|
||||
self.pile_of_responses = pandas.read_csv(f"{cwd}/piles/{language}/pile_of_responses.csv")
|
||||
self.pile_of_responses["contains_vars"] = self.pile_of_responses["response_starting"].apply(get_included_vars)
|
||||
|
||||
with open(f"piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_status_requests = list(reader)
|
||||
|
||||
with open(f"piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_system_prompts = { line["persona"]: line["prompt"] for line in reader }
|
||||
|
||||
# service names are not translated
|
||||
with open(f"piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f:
|
||||
with open(f"{cwd}/piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_hallucinated_service_names = list(reader)
|
||||
|
||||
|
||||
290
train/chat_templates/functiongemma_autotools.j2
Normal file
290
train/chat_templates/functiongemma_autotools.j2
Normal file
@@ -0,0 +1,290 @@
|
||||
{%- macro format_parameters(properties, required) -%}
|
||||
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in properties | dictsort -%}
|
||||
{%- if key not in standard_keys -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{{- key }}:{description:<escape>{{ value['description'] }}<escape>
|
||||
{%- if value['type'] | upper == 'STRING' -%}
|
||||
{%- if value['enum'] -%}
|
||||
,enum:{{ format_argument(value['enum']) }}
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'OBJECT' -%}
|
||||
,properties:{
|
||||
{%- if value['properties'] is defined and value['properties'] is mapping -%}
|
||||
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
|
||||
{%- elif value is mapping -%}
|
||||
{{- format_parameters(value, value['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- if value['required'] -%}
|
||||
,required:[
|
||||
{%- for item in value['required'] | default([]) -%}
|
||||
<escape>{{- item -}}<escape>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'ARRAY' -%}
|
||||
{%- if value['items'] is mapping and value['items'] -%}
|
||||
,items:{
|
||||
{%- set ns_items = namespace(found_first=false) -%}
|
||||
{%- for item_key, item_value in value['items'].items() -%}
|
||||
{%- if item_value is not none -%}
|
||||
{%- if ns_items.found_first %},{% endif -%}
|
||||
{%- set ns_items.found_first = true -%}
|
||||
{%- if item_key == 'properties' -%}
|
||||
properties:{
|
||||
{%- if item_value is mapping -%}
|
||||
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- elif item_key == 'required' -%}
|
||||
required:[
|
||||
{%- for req_item in item_value -%}
|
||||
<escape>{{- req_item -}}<escape>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- elif item_key == 'type' -%}
|
||||
{%- if item_value is string -%}
|
||||
type:{{ format_argument(item_value | upper) }}
|
||||
{%- else -%}
|
||||
type:{{ format_argument(item_value | map('upper') | list) }}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{{ item_key }}:{{ format_argument(item_value) }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
,type:<escape>{{ value['type'] | upper }}<escape>}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endmacro -%}
|
||||
{% macro format_function_declaration(tool_data) -%}
|
||||
declaration:{{- tool_data['function']['name'] -}}
|
||||
{description:<escape>{{- tool_data['function']['description'] -}}<escape>
|
||||
{%- set params = tool_data['function']['parameters'] -%}
|
||||
{%- if params -%}
|
||||
,parameters:{
|
||||
{%- if params['properties'] -%}
|
||||
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
|
||||
{%- endif -%}
|
||||
{%- if params['required'] -%}
|
||||
required:[
|
||||
{%- for item in params['required'] -%}
|
||||
<escape>{{- item -}}<escape>
|
||||
{{- ',' if not loop.last -}}
|
||||
{%- endfor -%}
|
||||
],
|
||||
{%- endif -%}
|
||||
{%- if params['type'] -%}
|
||||
type:<escape>{{- params['type'] | upper -}}<escape>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- endmacro -%}
|
||||
{% macro format_argument(argument, escape_keys=True) -%}
|
||||
{%- if argument is string -%}
|
||||
{{- '<escape>' + argument + '<escape>' -}}
|
||||
{%- elif argument is boolean -%}
|
||||
{%- if argument -%}
|
||||
{{- 'true' -}}
|
||||
{%- else -%}
|
||||
{{- 'false' -}}
|
||||
{%- endif -%}
|
||||
{%- elif argument is mapping -%}
|
||||
{{- '{' -}}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in argument.items() -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{%- if escape_keys -%}
|
||||
{{- '<escape>' + key + '<escape>' -}}
|
||||
{%- else -%}
|
||||
{{- key -}}
|
||||
{%- endif -%}
|
||||
:{{- format_argument(value, escape_keys=escape_keys) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- elif argument is iterable -%}
|
||||
{{- '[' -}}
|
||||
{%- for item in argument -%}
|
||||
{{- format_argument(item, escape_keys=escape_keys) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ']' -}}
|
||||
{%- else -%}
|
||||
{{- argument -}}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{{ bos_token }}
|
||||
{%- set ns = namespace(prev_message_type=None) -%}
|
||||
{#- extract system prompt for merging with user role -#}
|
||||
{%- set loop_messages = messages -%}
|
||||
{%- set system_message_content = '' %}
|
||||
{%- if messages[0]['role'] == 'system' or messages[0]['role'] == 'developer' -%}
|
||||
{%- set system_message_content = messages[0]['content'] -%}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
{#- 'static' system prompt. -#}
|
||||
{%- if tools -%}
|
||||
{{- '<start_of_turn>developer\nYou are a model that can do function calling with the following functions' -}}
|
||||
{%- for tool in tools %}
|
||||
{{- '<start_function_declaration>' -}}
|
||||
{{- format_function_declaration(tool) | trim }}
|
||||
{{- '<end_function_declaration>' -}}
|
||||
{%- endfor %}
|
||||
{{- '<end_of_turn>\n' -}}
|
||||
{%- else -%}
|
||||
{{- '<start_of_turn>developer\nNo tools have been provided. Only respond with answers that do not require tool usage.<end_of_turn>\n' -}}
|
||||
{%- endif -%}
|
||||
{#- Loop through messages. -#}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- if (message['role'] == 'assistant') -%}
|
||||
{#- Rename "assistant" to "model". -#}
|
||||
{%- set role = "model" -%}
|
||||
{%- else -%}
|
||||
{%- set role = message['role'] -%}
|
||||
{%- endif -%}
|
||||
{%- if role != 'tool' -%}
|
||||
{%- if ns.prev_message_type != 'tool_response' -%}
|
||||
{{- '<start_of_turn>' + role + '\n' }}
|
||||
{%- endif -%}
|
||||
{%- set ns.prev_message_type = None -%}
|
||||
{%- if loop.first and system_message_content -%}
|
||||
{%- if system_message_content is string -%}
|
||||
{{ system_message_content | trim }}
|
||||
{%- elif system_message_content is iterable -%}
|
||||
{%- for item in system_message_content -%}
|
||||
{%- if item['type'] == 'image' -%}
|
||||
{{ raise_exception("Invalid content type 'image' in system message") }}
|
||||
{%- elif item['type'] == 'text' -%}
|
||||
{{ item['text'] | trim }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid content type in system message") }}
|
||||
{%- endif -%}
|
||||
{{- '\n' -}}
|
||||
{%- endif -%}
|
||||
{#- User/Assistant Messages -#}
|
||||
{%- if 'content' in message and message['content'] is not none -%}
|
||||
{%- if message['content'] is string -%}
|
||||
{{ message['content'] | trim }}
|
||||
{%- elif message['content'] is iterable -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'image' -%}
|
||||
{{ '<start_of_image>' }}
|
||||
{%- elif item['type'] == 'text' -%}
|
||||
{{ item['text'] | trim }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid content type in user/assistant message") }}
|
||||
{%- endif -%}
|
||||
{%- set ns.prev_message_type = 'content' -%}
|
||||
{%- endif -%}
|
||||
{%- if 'tool_calls' in message and message['tool_calls'] and message['tool_calls'] is iterable -%}
|
||||
{#- Tool Calls -#}
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
{% set function = tool_call['function'] %}
|
||||
{{- '<start_function_call>call:' + function['name'] + '{' -}}
|
||||
{%- if 'arguments' in function -%}
|
||||
{%- if function['arguments'] is mapping -%}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in function['arguments'] | dictsort -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{%- elif function['arguments'] is string -%}
|
||||
{# This handles string-JSON, just in case #}
|
||||
{{ function['arguments'] }}
|
||||
{%- endif %}
|
||||
{%- endif -%}
|
||||
{{- '}<end_function_call>' -}}
|
||||
{%- endfor -%}
|
||||
{%- if loop.last -%}
|
||||
{{ '<start_function_response>' }}
|
||||
{%- endif -%}
|
||||
{%- set ns.prev_message_type = 'tool_call' -%}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{#- Tool Responses -#}
|
||||
{%- if 'content' in message and message['content'] -%}
|
||||
{%- if message['content'] is mapping -%}
|
||||
{%- if 'name' in message['content'] and 'response' in message['content'] -%}
|
||||
{{ '<start_function_response>response:' + message['content']['name'] | trim + '{' }}
|
||||
{%- set response_ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in message['content']['response'] | dictsort -%}
|
||||
{%- if response_ns.found_first %},{% endif -%}
|
||||
{%- set response_ns.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}<end_function_response>' -}}
|
||||
{%- elif 'name' in message -%}
|
||||
{{ '<start_function_response>response:' + message['name'] | trim + '{' }}
|
||||
{%- set response_ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in message['content'].items() -%}
|
||||
{%- if response_ns.found_first %},{% endif -%}
|
||||
{%- set response_ns.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}<end_function_response>' -}}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid tool response mapping: must contain 'name' and 'response' keys, or 'name' must be in the message.") }}
|
||||
{%- endif -%}
|
||||
{%- elif message['content'] is string -%}
|
||||
{%- if 'name' in message -%}
|
||||
{{ '<start_function_response>response:' + message['name'] | trim + '{value:' + format_argument(message['content'], escape_keys=False) + '}<end_function_response>' }}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid tool response: 'name' must be provided.") }}
|
||||
{%- endif -%}
|
||||
{%- elif message['content'] is iterable -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item is mapping -%}
|
||||
{%- if 'name' in item and 'response' in item -%}
|
||||
{{ '<start_function_response>response:' + item['name'] | trim + '{' }}
|
||||
{%- set response_ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in item['response'].items() -%}
|
||||
{%- if response_ns.found_first %},{% endif -%}
|
||||
{%- set response_ns.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}<end_function_response>' -}}
|
||||
{%- elif 'name' in message -%}
|
||||
{{ '<start_function_response>response:' + message['name'] | trim + '{' }}
|
||||
{%- set response_ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in item.items() -%}
|
||||
{%- if response_ns.found_first %},{% endif -%}
|
||||
{%- set response_ns.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}<end_function_response>' -}}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid tool response mapping: must contain 'name' and 'response' keys, or 'name' must be in the message.") }}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid tool response message: multiple responses must all be mappings") }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid content type in tool message: must be mapping, iterable of mappings, or string.") }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endif -%}
|
||||
{%- if ns.prev_message_type not in ['tool_call', 'tool_response'] -%}
|
||||
{{ '<end_of_turn>\n' }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{%- if ns.prev_message_type != 'tool_response' -%}
|
||||
{{- '<start_of_turn>model\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
@@ -4,6 +4,306 @@ model_type: Gemma3ForCausalLM
|
||||
# gemma3 doesn't seem to play nice with ddp
|
||||
ddp_find_unused_parameters: true
|
||||
|
||||
chat_template: jinja
|
||||
chat_template_jinja: |
|
||||
{%- macro format_parameters(properties, required) -%}
|
||||
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in properties | dictsort -%}
|
||||
{%- if key not in standard_keys -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{{- key }}:{description:<escape>{{ value['description'] }}<escape>
|
||||
{%- if value['type'] | upper == 'STRING' -%}
|
||||
{%- if value['enum'] -%}
|
||||
,enum:{{ format_argument(value['enum']) }}
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'OBJECT' -%}
|
||||
,properties:{
|
||||
{%- if value['properties'] is defined and value['properties'] is mapping -%}
|
||||
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
|
||||
{%- elif value is mapping -%}
|
||||
{{- format_parameters(value, value['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- if value['required'] -%}
|
||||
,required:[
|
||||
{%- for item in value['required'] | default([]) -%}
|
||||
<escape>{{- item -}}<escape>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'ARRAY' -%}
|
||||
{%- if value['items'] is mapping and value['items'] -%}
|
||||
,items:{
|
||||
{%- set ns_items = namespace(found_first=false) -%}
|
||||
{%- for item_key, item_value in value['items'].items() -%}
|
||||
{%- if item_value is not none -%}
|
||||
{%- if ns_items.found_first %},{% endif -%}
|
||||
{%- set ns_items.found_first = true -%}
|
||||
{%- if item_key == 'properties' -%}
|
||||
properties:{
|
||||
{%- if item_value is mapping -%}
|
||||
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- elif item_key == 'required' -%}
|
||||
required:[
|
||||
{%- for req_item in item_value -%}
|
||||
<escape>{{- req_item -}}<escape>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- elif item_key == 'type' -%}
|
||||
{%- if item_value is string -%}
|
||||
type:{{ format_argument(item_value | upper) }}
|
||||
{%- else -%}
|
||||
type:{{ format_argument(item_value | map('upper') | list) }}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{{ item_key }}:{{ format_argument(item_value) }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
,type:<escape>{{ value['type'] | upper }}<escape>}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endmacro -%}
|
||||
{% macro format_function_declaration(tool_data) -%}
|
||||
declaration:{{- tool_data['function']['name'] -}}
|
||||
{description:<escape>{{- tool_data['function']['description'] -}}<escape>
|
||||
{%- set params = tool_data['function']['parameters'] -%}
|
||||
{%- if params -%}
|
||||
,parameters:{
|
||||
{%- if params['properties'] -%}
|
||||
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
|
||||
{%- endif -%}
|
||||
{%- if params['required'] -%}
|
||||
required:[
|
||||
{%- for item in params['required'] -%}
|
||||
<escape>{{- item -}}<escape>
|
||||
{{- ',' if not loop.last -}}
|
||||
{%- endfor -%}
|
||||
],
|
||||
{%- endif -%}
|
||||
{%- if params['type'] -%}
|
||||
type:<escape>{{- params['type'] | upper -}}<escape>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- endmacro -%}
|
||||
{% macro format_argument(argument, escape_keys=True) -%}
|
||||
{%- if argument is string -%}
|
||||
{{- '<escape>' + argument + '<escape>' -}}
|
||||
{%- elif argument is boolean -%}
|
||||
{%- if argument -%}
|
||||
{{- 'true' -}}
|
||||
{%- else -%}
|
||||
{{- 'false' -}}
|
||||
{%- endif -%}
|
||||
{%- elif argument is mapping -%}
|
||||
{{- '{' -}}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in argument.items() -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{%- if escape_keys -%}
|
||||
{{- '<escape>' + key + '<escape>' -}}
|
||||
{%- else -%}
|
||||
{{- key -}}
|
||||
{%- endif -%}
|
||||
:{{- format_argument(value, escape_keys=escape_keys) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- elif argument is iterable -%}
|
||||
{{- '[' -}}
|
||||
{%- for item in argument -%}
|
||||
{{- format_argument(item, escape_keys=escape_keys) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ']' -}}
|
||||
{%- else -%}
|
||||
{{- argument -}}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{{ bos_token }}
|
||||
{%- set ns = namespace(prev_message_type=None) -%}
|
||||
{#- extract system prompt for merging with user role -#}
|
||||
{%- set loop_messages = messages -%}
|
||||
{%- set system_message_content = '' %}
|
||||
{%- if messages[0]['role'] == 'system' or messages[0]['role'] == 'developer' -%}
|
||||
{%- set system_message_content = messages[0]['content'] -%}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
{#- 'static' system prompt. -#}
|
||||
{%- if tools -%}
|
||||
{{- '<start_of_turn>developer\nYou are a model that can do function calling with the following functions' -}}
|
||||
{%- for tool in tools %}
|
||||
{{- '<start_function_declaration>' -}}
|
||||
{{- format_function_declaration(tool) | trim }}
|
||||
{{- '<end_function_declaration>' -}}
|
||||
{%- endfor %}
|
||||
{{- '<end_of_turn>\n' -}}
|
||||
{%- else -%}
|
||||
{{- '<start_of_turn>developer\nNo tools have been provided. Only respond with answers that do not require tool usage.<end_of_turn>\n' -}}
|
||||
{%- endif -%}
|
||||
{#- Loop through messages. -#}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- if (message['role'] == 'assistant') -%}
|
||||
{#- Rename "assistant" to "model". -#}
|
||||
{%- set role = "model" -%}
|
||||
{%- else -%}
|
||||
{%- set role = message['role'] -%}
|
||||
{%- endif -%}
|
||||
{%- if role != 'tool' -%}
|
||||
{%- if ns.prev_message_type != 'tool_response' -%}
|
||||
{{- '<start_of_turn>' + role + '\n' }}
|
||||
{%- endif -%}
|
||||
{%- set ns.prev_message_type = None -%}
|
||||
{%- if loop.first and system_message_content -%}
|
||||
{%- if system_message_content is string -%}
|
||||
{{ system_message_content | trim }}
|
||||
{%- elif system_message_content is iterable -%}
|
||||
{%- for item in system_message_content -%}
|
||||
{%- if item['type'] == 'image' -%}
|
||||
{{ raise_exception("Invalid content type 'image' in system message") }}
|
||||
{%- elif item['type'] == 'text' -%}
|
||||
{{ item['text'] | trim }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid content type in system message") }}
|
||||
{%- endif -%}
|
||||
{{- '\n' -}}
|
||||
{%- endif -%}
|
||||
{#- User/Assistant Messages -#}
|
||||
{%- if 'content' in message and message['content'] is not none -%}
|
||||
{%- if message['content'] is string -%}
|
||||
{{ message['content'] | trim }}
|
||||
{%- elif message['content'] is iterable -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'image' -%}
|
||||
{{ '<start_of_image>' }}
|
||||
{%- elif item['type'] == 'text' -%}
|
||||
{{ item['text'] | trim }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid content type in user/assistant message") }}
|
||||
{%- endif -%}
|
||||
{%- set ns.prev_message_type = 'content' -%}
|
||||
{%- endif -%}
|
||||
{%- if 'tool_calls' in message and message['tool_calls'] and message['tool_calls'] is iterable -%}
|
||||
{#- Tool Calls -#}
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
{% set function = tool_call['function'] %}
|
||||
{{- '<start_function_call>call:' + function['name'] + '{' -}}
|
||||
{%- if 'arguments' in function -%}
|
||||
{%- if function['arguments'] is mapping -%}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in function['arguments'] | dictsort -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{%- elif function['arguments'] is string -%}
|
||||
{# This handles string-JSON, just in case #}
|
||||
{{ function['arguments'] }}
|
||||
{%- endif %}
|
||||
{%- endif -%}
|
||||
{{- '}<end_function_call>' -}}
|
||||
{%- endfor -%}
|
||||
{%- if loop.last -%}
|
||||
{{ '<start_function_response>' }}
|
||||
{%- endif -%}
|
||||
{%- set ns.prev_message_type = 'tool_call' -%}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{#- Tool Responses -#}
|
||||
{%- if 'content' in message and message['content'] -%}
|
||||
{%- if message['content'] is mapping -%}
|
||||
{%- if 'name' in message['content'] and 'response' in message['content'] -%}
|
||||
{{ '<start_function_response>response:' + message['content']['name'] | trim + '{' }}
|
||||
{%- set response_ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in message['content']['response'] | dictsort -%}
|
||||
{%- if response_ns.found_first %},{% endif -%}
|
||||
{%- set response_ns.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}<end_function_response>' -}}
|
||||
{%- elif 'name' in message -%}
|
||||
{{ '<start_function_response>response:' + message['name'] | trim + '{' }}
|
||||
{%- set response_ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in message['content'].items() -%}
|
||||
{%- if response_ns.found_first %},{% endif -%}
|
||||
{%- set response_ns.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}<end_function_response>' -}}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid tool response mapping: must contain 'name' and 'response' keys, or 'name' must be in the message.") }}
|
||||
{%- endif -%}
|
||||
{%- elif message['content'] is string -%}
|
||||
{%- if 'name' in message -%}
|
||||
{{ '<start_function_response>response:' + message['name'] | trim + '{value:' + format_argument(message['content'], escape_keys=False) + '}<end_function_response>' }}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid tool response: 'name' must be provided.") }}
|
||||
{%- endif -%}
|
||||
{%- elif message['content'] is iterable -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item is mapping -%}
|
||||
{%- if 'name' in item and 'response' in item -%}
|
||||
{{ '<start_function_response>response:' + item['name'] | trim + '{' }}
|
||||
{%- set response_ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in item['response'].items() -%}
|
||||
{%- if response_ns.found_first %},{% endif -%}
|
||||
{%- set response_ns.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}<end_function_response>' -}}
|
||||
{%- elif 'name' in message -%}
|
||||
{{ '<start_function_response>response:' + message['name'] | trim + '{' }}
|
||||
{%- set response_ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in item.items() -%}
|
||||
{%- if response_ns.found_first %},{% endif -%}
|
||||
{%- set response_ns.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}<end_function_response>' -}}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid tool response mapping: must contain 'name' and 'response' keys, or 'name' must be in the message.") }}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid tool response message: multiple responses must all be mappings") }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid content type in tool message: must be mapping, iterable of mappings, or string.") }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endif -%}
|
||||
{%- if ns.prev_message_type not in ['tool_call', 'tool_response'] -%}
|
||||
{{ '<end_of_turn>\n' }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{%- if ns.prev_message_type != 'tool_response' -%}
|
||||
{{- '<start_of_turn>model\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- <end_of_turn>
|
||||
- <end_function_call>
|
||||
eos_token: <end_of_turn>
|
||||
|
||||
datasets:
|
||||
- path: /workspace/data/datasets/sample.jsonl
|
||||
ds_type: json
|
||||
@@ -12,10 +312,10 @@ datasets:
|
||||
- assistant
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: /workspace/data/training-runs/Home-Gemma3-270m
|
||||
output_dir: /workspace/data/training-runs/Home-FunctionGemma-270m
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
sequence_len: 5130
|
||||
sample_packing: false
|
||||
eval_sample_packing: false
|
||||
|
||||
use_tensorboard: true
|
||||
|
||||
@@ -18,10 +18,15 @@ spec:
|
||||
command:
|
||||
- axolotl
|
||||
- train
|
||||
- /workspace/configs/gemma3-270m.yml
|
||||
- /workspace/configs/functiongemma-270m.yml
|
||||
env:
|
||||
- name: AXOLOTL_DO_NOT_TRACK
|
||||
value: "1"
|
||||
- name: HF_TOKEN
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: hf-token
|
||||
key: token
|
||||
volumeMounts:
|
||||
- name: training-runs
|
||||
mountPath: /workspace/data/training-runs
|
||||
@@ -34,9 +39,11 @@ spec:
|
||||
resources:
|
||||
limits:
|
||||
nvidia.com/gpu: 2
|
||||
initContainers:
|
||||
- name: tensorboard
|
||||
image: python:3.11-slim
|
||||
imagePullPolicy: IfNotPresent
|
||||
restartPolicy: Always # mark as sidecar
|
||||
command:
|
||||
- bash
|
||||
- -c
|
||||
|
||||
Reference in New Issue
Block a user