mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-09 13:48:05 -05:00
mostly working gemma implementation
This commit is contained in:
@@ -67,6 +67,7 @@ if TYPE_CHECKING:
|
|||||||
from llama_cpp import (
|
from llama_cpp import (
|
||||||
Llama as LlamaType,
|
Llama as LlamaType,
|
||||||
LlamaGrammar as LlamaGrammarType,
|
LlamaGrammar as LlamaGrammarType,
|
||||||
|
LlamaDiskCache as LlamaDiskCacheType,
|
||||||
ChatCompletionRequestResponseFormat
|
ChatCompletionRequestResponseFormat
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -156,6 +157,7 @@ class LlamaCppClient(LocalLLMClient):
|
|||||||
self.llama_cpp_module = importlib.import_module("llama_cpp")
|
self.llama_cpp_module = importlib.import_module("llama_cpp")
|
||||||
|
|
||||||
Llama: type[LlamaType] = getattr(self.llama_cpp_module, "Llama")
|
Llama: type[LlamaType] = getattr(self.llama_cpp_module, "Llama")
|
||||||
|
LlamaDiskCache: type[LlamaDiskCacheType] = getattr(self.llama_cpp_module, "LlamaDiskCache")
|
||||||
|
|
||||||
_LOGGER.debug(f"Loading model '{model_path}'...")
|
_LOGGER.debug(f"Loading model '{model_path}'...")
|
||||||
model_settings = snapshot_settings(entity_options)
|
model_settings = snapshot_settings(entity_options)
|
||||||
@@ -170,11 +172,11 @@ class LlamaCppClient(LocalLLMClient):
|
|||||||
)
|
)
|
||||||
_LOGGER.debug("Model loaded")
|
_LOGGER.debug("Model loaded")
|
||||||
|
|
||||||
# TODO: check about disk caching
|
# FIXME: make cache size configurable (0 means disabled)
|
||||||
# self.llm.set_cache(self.llama_cpp_module.LlamaDiskCache(
|
llm.set_cache(LlamaDiskCache(
|
||||||
# capacity_bytes=(512 * 10e8),
|
capacity_bytes=int(512 * 10e8),
|
||||||
# cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
|
cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
|
||||||
# ))
|
))
|
||||||
|
|
||||||
if model_settings[CONF_PROMPT_CACHING_ENABLED]:
|
if model_settings[CONF_PROMPT_CACHING_ENABLED]:
|
||||||
@callback
|
@callback
|
||||||
@@ -393,6 +395,7 @@ class LlamaCppClient(LocalLLMClient):
|
|||||||
max_tokens=1,
|
max_tokens=1,
|
||||||
grammar=grammar,
|
grammar=grammar,
|
||||||
stream=False,
|
stream=False,
|
||||||
|
stop=["<end_of_turn>", "<end_function_call>"]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.last_cache_prime = time.time()
|
self.last_cache_prime = time.time()
|
||||||
@@ -464,6 +467,7 @@ class LlamaCppClient(LocalLLMClient):
|
|||||||
grammar=grammar,
|
grammar=grammar,
|
||||||
stream=True,
|
stream=True,
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
|
stop=["<end_of_turn>", "<end_function_call>"] # FIXME: make configurable (pull from tool end token?)
|
||||||
)
|
)
|
||||||
|
|
||||||
def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:
|
def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:
|
||||||
|
|||||||
@@ -261,12 +261,15 @@ class LocalLLMClient:
|
|||||||
tool_content += content
|
tool_content += content
|
||||||
|
|
||||||
if think_prefix in potential_block and not in_thinking:
|
if think_prefix in potential_block and not in_thinking:
|
||||||
|
_LOGGER.debug("Entering thinking block")
|
||||||
in_thinking = True
|
in_thinking = True
|
||||||
last_5_tokens.clear()
|
last_5_tokens.clear()
|
||||||
elif think_suffix in potential_block and in_thinking:
|
elif think_suffix in potential_block and in_thinking:
|
||||||
|
_LOGGER.debug("Exiting thinking block")
|
||||||
in_thinking = False
|
in_thinking = False
|
||||||
content = content.replace(think_suffix, "").strip()
|
content = content.replace(think_suffix, "").strip()
|
||||||
elif tool_prefix in potential_block and not in_tool_call:
|
elif tool_prefix in potential_block and not in_tool_call:
|
||||||
|
_LOGGER.debug("Entering tool call block")
|
||||||
in_tool_call = True
|
in_tool_call = True
|
||||||
last_5_tokens.clear()
|
last_5_tokens.clear()
|
||||||
elif tool_suffix in potential_block and in_tool_call:
|
elif tool_suffix in potential_block and in_tool_call:
|
||||||
@@ -307,6 +310,20 @@ class LocalLLMClient:
|
|||||||
if not in_thinking and not in_tool_call and (cur_match_length == 0 or result.tool_calls):
|
if not in_thinking and not in_tool_call and (cur_match_length == 0 or result.tool_calls):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
|
if in_tool_call and tool_content:
|
||||||
|
# flush any unclosed tool calls because using the tool_suffix as a stop token can
|
||||||
|
# cause the tool_suffix to be omitted when the model streams output
|
||||||
|
tool_block = tool_content.strip().removeprefix(tool_prefix)
|
||||||
|
_LOGGER.debug("Raw tool block extracted at end: %s", tool_block)
|
||||||
|
tool_call, to_say = parse_raw_tool_call(tool_block, agent_id)
|
||||||
|
if tool_call:
|
||||||
|
_LOGGER.debug("Tool call parsed at end: %s", tool_call)
|
||||||
|
yield TextGenerationResult(
|
||||||
|
response=to_say,
|
||||||
|
response_streamed=True,
|
||||||
|
tool_calls=[tool_call]
|
||||||
|
)
|
||||||
|
|
||||||
async def _async_parse_completion(
|
async def _async_parse_completion(
|
||||||
self,
|
self,
|
||||||
llm_api: llm.APIInstance | None,
|
llm_api: llm.APIInstance | None,
|
||||||
|
|||||||
@@ -356,7 +356,12 @@ def get_oai_formatted_messages(conversation: Sequence[conversation.Content], use
|
|||||||
elif message.role == "tool_result":
|
elif message.role == "tool_result":
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"content": json.dumps(message.tool_result),
|
# FIXME: what is the correct format for content here? gemma expects name and result
|
||||||
|
# "content": json.dumps(message.tool_result),
|
||||||
|
"content": {
|
||||||
|
"name": message.tool_name,
|
||||||
|
"response": { "result": message.tool_result },
|
||||||
|
},
|
||||||
"tool_call_id": message.tool_call_id
|
"tool_call_id": message.tool_call_id
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -404,7 +409,26 @@ def parse_raw_tool_call(raw_block: str | dict, agent_id: str) -> tuple[llm.ToolI
|
|||||||
if isinstance(raw_block, dict):
|
if isinstance(raw_block, dict):
|
||||||
parsed_tool_call = raw_block
|
parsed_tool_call = raw_block
|
||||||
else:
|
else:
|
||||||
parsed_tool_call: dict = json.loads(raw_block)
|
try:
|
||||||
|
parsed_tool_call: dict = json.loads(raw_block)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# handle the "gemma" tool calling format
|
||||||
|
# call:HassTurnOn{name:<escape>light.living_room_rgbww<escape>}
|
||||||
|
gemma_match = re.finditer(r"call:(?P<name>\w+){(?P<args>.+)}", raw_block)
|
||||||
|
for match in gemma_match:
|
||||||
|
tool_name = match.group("name")
|
||||||
|
raw_args = match.group("args")
|
||||||
|
args_dict = {}
|
||||||
|
for arg_match in re.finditer(r"(?P<key>\w+):<escape>(?P<value>.+?)<escape>", raw_args):
|
||||||
|
args_dict[arg_match.group("key")] = arg_match.group("value")
|
||||||
|
|
||||||
|
parsed_tool_call = {
|
||||||
|
"name": tool_name,
|
||||||
|
"arguments": args_dict
|
||||||
|
}
|
||||||
|
break # TODO: how do we properly handle multiple tool calls in one response?
|
||||||
|
else:
|
||||||
|
raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted JSON")
|
||||||
|
|
||||||
# try to validate either format
|
# try to validate either format
|
||||||
is_services_tool_call = False
|
is_services_tool_call = False
|
||||||
|
|||||||
@@ -140,6 +140,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
|
|||||||
tool_args = {}
|
tool_args = {}
|
||||||
|
|
||||||
question = question.replace("<device_name>", target_device)
|
question = question.replace("<device_name>", target_device)
|
||||||
|
response_starting = response_starting.replace("<device_name>", target_device)
|
||||||
answer_list = replace_answer(answer_list, "<device_name>", target_device)
|
answer_list = replace_answer(answer_list, "<device_name>", target_device)
|
||||||
|
|
||||||
if "climate" in service_action:
|
if "climate" in service_action:
|
||||||
@@ -520,7 +521,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
|
|||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def format_example_sharegpt(example, persona, language, use_system_role, use_service_names, tool_response_format):
|
def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format):
|
||||||
piles = get_dataset_piles(language)
|
piles = get_dataset_piles(language)
|
||||||
sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
|
sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
|
||||||
question = example["question"]
|
question = example["question"]
|
||||||
@@ -546,6 +547,10 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
|
|||||||
"tool_result": "Success"
|
"tool_result": "Success"
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if append_user_instruction_prompt:
|
||||||
|
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
|
||||||
|
sys_prompt = "\n".join([ sys_prompt, user_instruction_words ])
|
||||||
|
|
||||||
if use_system_role:
|
if use_system_role:
|
||||||
conversation = [
|
conversation = [
|
||||||
{
|
{
|
||||||
@@ -558,11 +563,10 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
|
|
||||||
conversation = [
|
conversation = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, user_instruction_words, question ]) }]
|
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -605,6 +609,7 @@ def generate_sft_file(
|
|||||||
seed: int,
|
seed: int,
|
||||||
format_func: Callable,
|
format_func: Callable,
|
||||||
use_system_role: bool,
|
use_system_role: bool,
|
||||||
|
append_user_instruction_prompt: bool,
|
||||||
use_service_names: bool,
|
use_service_names: bool,
|
||||||
personas: list[str],
|
personas: list[str],
|
||||||
language: str,
|
language: str,
|
||||||
@@ -622,10 +627,10 @@ def generate_sft_file(
|
|||||||
def run_factor_times(func, examples, data, persona, factor, language):
|
def run_factor_times(func, examples, data, persona, factor, language):
|
||||||
if factor >= 1:
|
if factor >= 1:
|
||||||
for i in range(factor):
|
for i in range(factor):
|
||||||
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names, tool_response_format))
|
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format))
|
||||||
else:
|
else:
|
||||||
if random.random() < factor:
|
if random.random() < factor:
|
||||||
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names, tool_response_format))
|
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format))
|
||||||
|
|
||||||
generated_examples = []
|
generated_examples = []
|
||||||
|
|
||||||
@@ -652,7 +657,8 @@ def generate_sft_file(
|
|||||||
for missing in sorted(missing_responses):
|
for missing in sorted(missing_responses):
|
||||||
print(missing)
|
print(missing)
|
||||||
|
|
||||||
with open(f"output/{filename}.jsonl", "w") as f:
|
cwd = os.path.dirname(__file__)
|
||||||
|
with open(f"{cwd}/output/{filename}.jsonl", "w") as f:
|
||||||
for item in generated_examples:
|
for item in generated_examples:
|
||||||
json_record = json.dumps(item)
|
json_record = json.dumps(item)
|
||||||
f.write(json_record + '\n')
|
f.write(json_record + '\n')
|
||||||
@@ -676,11 +682,13 @@ def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset
|
|||||||
|
|
||||||
def merge_languages(filename_prefix: str, languages: list):
|
def merge_languages(filename_prefix: str, languages: list):
|
||||||
all_examples = []
|
all_examples = []
|
||||||
|
cwd = os.path.dirname(__file__)
|
||||||
|
|
||||||
for language in languages:
|
for language in languages:
|
||||||
with open(f"output/{filename_prefix}_{language}.jsonl") as f:
|
with open(f"{cwd}/output/{filename_prefix}_{language}.jsonl") as f:
|
||||||
all_examples.extend(f.readlines())
|
all_examples.extend(f.readlines())
|
||||||
|
|
||||||
with open(f"output/{filename_prefix}.jsonl", "w") as f:
|
with open(f"{cwd}/output/{filename_prefix}.jsonl", "w") as f:
|
||||||
f.writelines(all_examples)
|
f.writelines(all_examples)
|
||||||
|
|
||||||
|
|
||||||
@@ -696,9 +704,12 @@ def main(args=None):
|
|||||||
parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset.")
|
parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset.")
|
||||||
parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
|
parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
|
||||||
parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate: english, german, french, spanish, polish")
|
parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate: english, german, french, spanish, polish")
|
||||||
parser.add_argument("--no-system-role", action="store_true", help="Set this flag to disable the system role. It will be combined with the user role")
|
|
||||||
parser.add_argument("--tool-response-format", default="text", choices=["text", "functiongemma"], help="Format to use for tool responses.")
|
parser.add_argument("--tool-response-format", default="text", choices=["text", "functiongemma"], help="Format to use for tool responses.")
|
||||||
|
|
||||||
|
role_tweaks = parser.add_mutually_exclusive_group()
|
||||||
|
role_tweaks.add_argument("--no-system-role", action="store_true", help="Set this flag to disable the system role. The house context will be combined with the user role")
|
||||||
|
role_tweaks.add_argument("--merged-system-role", action="store_true", help="Set this flag to still emit a system role, but assume it will be merged by the chat template into the user role.")
|
||||||
|
|
||||||
train_size_group = parser.add_mutually_exclusive_group()
|
train_size_group = parser.add_mutually_exclusive_group()
|
||||||
train_size_group.add_argument('--small', action='store_const', const='small', dest='size')
|
train_size_group.add_argument('--small', action='store_const', const='small', dest='size')
|
||||||
train_size_group.add_argument('--medium', action='store_const', const='medium', dest='size')
|
train_size_group.add_argument('--medium', action='store_const', const='medium', dest='size')
|
||||||
@@ -721,6 +732,7 @@ def main(args=None):
|
|||||||
format_func = format_example_sharegpt
|
format_func = format_example_sharegpt
|
||||||
|
|
||||||
use_system_role = not args.no_system_role
|
use_system_role = not args.no_system_role
|
||||||
|
append_user_instruction_prompt = args.merged_system_role or not args.no_system_role
|
||||||
use_service_names = args.use_service_names
|
use_service_names = args.use_service_names
|
||||||
tool_response_format = args.tool_response_format
|
tool_response_format = args.tool_response_format
|
||||||
|
|
||||||
@@ -730,21 +742,20 @@ def main(args=None):
|
|||||||
suffix = f"_{language}" if len(args.language) > 1 else ""
|
suffix = f"_{language}" if len(args.language) > 1 else ""
|
||||||
|
|
||||||
if args.sample:
|
if args.sample:
|
||||||
generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1)
|
generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1)
|
||||||
if args.train:
|
if args.train:
|
||||||
if args.size == "small":
|
if args.size == "small":
|
||||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8)
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8)
|
||||||
elif args.size == "medium":
|
elif args.size == "medium":
|
||||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12)
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12)
|
||||||
elif args.size == "large":
|
elif args.size == "large":
|
||||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15)
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15)
|
||||||
elif args.size == "xl":
|
elif args.size == "xl":
|
||||||
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18)
|
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unrecognized dataset size: {args.size}")
|
raise Exception(f"Unrecognized dataset size: {args.size}")
|
||||||
if args.test:
|
if args.test:
|
||||||
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2)
|
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2)
|
||||||
|
|
||||||
if len(args.language) > 1:
|
if len(args.language) > 1:
|
||||||
if args.sample:
|
if args.sample:
|
||||||
merge_languages("sample", args.language)
|
merge_languages("sample", args.language)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
import os
|
||||||
import csv
|
import csv
|
||||||
import pandas
|
import pandas
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
@@ -84,23 +85,25 @@ def get_random_response(pile_of_responses, *, service: str, persona: str, questi
|
|||||||
class DatasetPiles:
|
class DatasetPiles:
|
||||||
def __init__(self, supported_devices, language="english"):
|
def __init__(self, supported_devices, language="english"):
|
||||||
self.language = language
|
self.language = language
|
||||||
|
|
||||||
|
cwd = os.path.dirname(__file__)
|
||||||
|
|
||||||
with open(f"piles/{language}/pile_of_and_words.csv", encoding="utf8") as f:
|
with open(f"{cwd}/piles/{language}/pile_of_and_words.csv", encoding="utf8") as f:
|
||||||
self.and_words = [ x.strip() for x in f.readlines() ]
|
self.and_words = [ x.strip() for x in f.readlines() ]
|
||||||
|
|
||||||
with open(f"piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
|
with open(f"{cwd}/piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
self.pile_of_durations = { x["duration"]: x["name"] for x in reader }
|
self.pile_of_durations = { x["duration"]: x["name"] for x in reader }
|
||||||
|
|
||||||
# media names are not translated
|
# media names are not translated
|
||||||
with open(f"piles/english/pile_of_media_names.txt", encoding="utf8") as f:
|
with open(f"{cwd}/piles/english/pile_of_media_names.txt", encoding="utf8") as f:
|
||||||
self.pile_of_media_names = [ x.strip() for x in f.readlines() ]
|
self.pile_of_media_names = [ x.strip() for x in f.readlines() ]
|
||||||
|
|
||||||
with open(f"piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
|
with open(f"{cwd}/piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
|
||||||
self.pile_of_todo_items = [ x.strip() for x in f.readlines() ]
|
self.pile_of_todo_items = [ x.strip() for x in f.readlines() ]
|
||||||
|
|
||||||
self.stacks_of_device_names = { x: [] for x in supported_devices }
|
self.stacks_of_device_names = { x: [] for x in supported_devices }
|
||||||
with open(f"piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
|
with open(f"{cwd}/piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
pile_of_device_names = list(reader)
|
pile_of_device_names = list(reader)
|
||||||
for device_dict in pile_of_device_names:
|
for device_dict in pile_of_device_names:
|
||||||
@@ -110,7 +113,7 @@ class DatasetPiles:
|
|||||||
except KeyError as ex:
|
except KeyError as ex:
|
||||||
print(ex)
|
print(ex)
|
||||||
|
|
||||||
with open(f"piles/{language}/pile_of_templated_actions.csv", encoding="utf8") as f:
|
with open(f"{cwd}/piles/{language}/pile_of_templated_actions.csv", encoding="utf8") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
pile_of_templated_actions = list(reader)
|
pile_of_templated_actions = list(reader)
|
||||||
processed_pile_of_templated_actions = []
|
processed_pile_of_templated_actions = []
|
||||||
@@ -124,23 +127,23 @@ class DatasetPiles:
|
|||||||
|
|
||||||
self.pile_of_templated_actions = processed_pile_of_templated_actions
|
self.pile_of_templated_actions = processed_pile_of_templated_actions
|
||||||
|
|
||||||
with open(f"piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
|
with open(f"{cwd}/piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
self.pile_of_specific_actions = list(reader)
|
self.pile_of_specific_actions = list(reader)
|
||||||
|
|
||||||
self.pile_of_responses = pandas.read_csv(f"piles/{language}/pile_of_responses.csv")
|
self.pile_of_responses = pandas.read_csv(f"{cwd}/piles/{language}/pile_of_responses.csv")
|
||||||
self.pile_of_responses["contains_vars"] = self.pile_of_responses["response_starting"].apply(get_included_vars)
|
self.pile_of_responses["contains_vars"] = self.pile_of_responses["response_starting"].apply(get_included_vars)
|
||||||
|
|
||||||
with open(f"piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
|
with open(f"{cwd}/piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
self.pile_of_status_requests = list(reader)
|
self.pile_of_status_requests = list(reader)
|
||||||
|
|
||||||
with open(f"piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
|
with open(f"{cwd}/piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
self.pile_of_system_prompts = { line["persona"]: line["prompt"] for line in reader }
|
self.pile_of_system_prompts = { line["persona"]: line["prompt"] for line in reader }
|
||||||
|
|
||||||
# service names are not translated
|
# service names are not translated
|
||||||
with open(f"piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f:
|
with open(f"{cwd}/piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
self.pile_of_hallucinated_service_names = list(reader)
|
self.pile_of_hallucinated_service_names = list(reader)
|
||||||
|
|
||||||
|
|||||||
290
train/chat_templates/functiongemma_autotools.j2
Normal file
290
train/chat_templates/functiongemma_autotools.j2
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
{%- macro format_parameters(properties, required) -%}
|
||||||
|
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
|
||||||
|
{%- set ns = namespace(found_first=false) -%}
|
||||||
|
{%- for key, value in properties | dictsort -%}
|
||||||
|
{%- if key not in standard_keys -%}
|
||||||
|
{%- if ns.found_first %},{% endif -%}
|
||||||
|
{%- set ns.found_first = true -%}
|
||||||
|
{{- key }}:{description:<escape>{{ value['description'] }}<escape>
|
||||||
|
{%- if value['type'] | upper == 'STRING' -%}
|
||||||
|
{%- if value['enum'] -%}
|
||||||
|
,enum:{{ format_argument(value['enum']) }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- elif value['type'] | upper == 'OBJECT' -%}
|
||||||
|
,properties:{
|
||||||
|
{%- if value['properties'] is defined and value['properties'] is mapping -%}
|
||||||
|
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
|
||||||
|
{%- elif value is mapping -%}
|
||||||
|
{{- format_parameters(value, value['required'] | default([])) -}}
|
||||||
|
{%- endif -%}
|
||||||
|
}
|
||||||
|
{%- if value['required'] -%}
|
||||||
|
,required:[
|
||||||
|
{%- for item in value['required'] | default([]) -%}
|
||||||
|
<escape>{{- item -}}<escape>
|
||||||
|
{%- if not loop.last %},{% endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
]
|
||||||
|
{%- endif -%}
|
||||||
|
{%- elif value['type'] | upper == 'ARRAY' -%}
|
||||||
|
{%- if value['items'] is mapping and value['items'] -%}
|
||||||
|
,items:{
|
||||||
|
{%- set ns_items = namespace(found_first=false) -%}
|
||||||
|
{%- for item_key, item_value in value['items'].items() -%}
|
||||||
|
{%- if item_value is not none -%}
|
||||||
|
{%- if ns_items.found_first %},{% endif -%}
|
||||||
|
{%- set ns_items.found_first = true -%}
|
||||||
|
{%- if item_key == 'properties' -%}
|
||||||
|
properties:{
|
||||||
|
{%- if item_value is mapping -%}
|
||||||
|
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
|
||||||
|
{%- endif -%}
|
||||||
|
}
|
||||||
|
{%- elif item_key == 'required' -%}
|
||||||
|
required:[
|
||||||
|
{%- for req_item in item_value -%}
|
||||||
|
<escape>{{- req_item -}}<escape>
|
||||||
|
{%- if not loop.last %},{% endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
]
|
||||||
|
{%- elif item_key == 'type' -%}
|
||||||
|
{%- if item_value is string -%}
|
||||||
|
type:{{ format_argument(item_value | upper) }}
|
||||||
|
{%- else -%}
|
||||||
|
type:{{ format_argument(item_value | map('upper') | list) }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- else -%}
|
||||||
|
{{ item_key }}:{{ format_argument(item_value) }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endif -%}
|
||||||
|
,type:<escape>{{ value['type'] | upper }}<escape>}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- endmacro -%}
|
||||||
|
{% macro format_function_declaration(tool_data) -%}
|
||||||
|
declaration:{{- tool_data['function']['name'] -}}
|
||||||
|
{description:<escape>{{- tool_data['function']['description'] -}}<escape>
|
||||||
|
{%- set params = tool_data['function']['parameters'] -%}
|
||||||
|
{%- if params -%}
|
||||||
|
,parameters:{
|
||||||
|
{%- if params['properties'] -%}
|
||||||
|
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if params['required'] -%}
|
||||||
|
required:[
|
||||||
|
{%- for item in params['required'] -%}
|
||||||
|
<escape>{{- item -}}<escape>
|
||||||
|
{{- ',' if not loop.last -}}
|
||||||
|
{%- endfor -%}
|
||||||
|
],
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if params['type'] -%}
|
||||||
|
type:<escape>{{- params['type'] | upper -}}<escape>}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endif -%}
|
||||||
|
}
|
||||||
|
{%- endmacro -%}
|
||||||
|
{% macro format_argument(argument, escape_keys=True) -%}
|
||||||
|
{%- if argument is string -%}
|
||||||
|
{{- '<escape>' + argument + '<escape>' -}}
|
||||||
|
{%- elif argument is boolean -%}
|
||||||
|
{%- if argument -%}
|
||||||
|
{{- 'true' -}}
|
||||||
|
{%- else -%}
|
||||||
|
{{- 'false' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- elif argument is mapping -%}
|
||||||
|
{{- '{' -}}
|
||||||
|
{%- set ns = namespace(found_first=false) -%}
|
||||||
|
{%- for key, value in argument.items() -%}
|
||||||
|
{%- if ns.found_first %},{% endif -%}
|
||||||
|
{%- set ns.found_first = true -%}
|
||||||
|
{%- if escape_keys -%}
|
||||||
|
{{- '<escape>' + key + '<escape>' -}}
|
||||||
|
{%- else -%}
|
||||||
|
{{- key -}}
|
||||||
|
{%- endif -%}
|
||||||
|
:{{- format_argument(value, escape_keys=escape_keys) -}}
|
||||||
|
{%- endfor -%}
|
||||||
|
{{- '}' -}}
|
||||||
|
{%- elif argument is iterable -%}
|
||||||
|
{{- '[' -}}
|
||||||
|
{%- for item in argument -%}
|
||||||
|
{{- format_argument(item, escape_keys=escape_keys) -}}
|
||||||
|
{%- if not loop.last %},{% endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{{- ']' -}}
|
||||||
|
{%- else -%}
|
||||||
|
{{- argument -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endmacro -%}
|
||||||
|
{{ bos_token }}
|
||||||
|
{%- set ns = namespace(prev_message_type=None) -%}
|
||||||
|
{#- extract system prompt for merging with user role -#}
|
||||||
|
{%- set loop_messages = messages -%}
|
||||||
|
{%- set system_message_content = '' %}
|
||||||
|
{%- if messages[0]['role'] == 'system' or messages[0]['role'] == 'developer' -%}
|
||||||
|
{%- set system_message_content = messages[0]['content'] -%}
|
||||||
|
{%- set loop_messages = messages[1:] -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{#- 'static' system prompt. -#}
|
||||||
|
{%- if tools -%}
|
||||||
|
{{- '<start_of_turn>developer\nYou are a model that can do function calling with the following functions' -}}
|
||||||
|
{%- for tool in tools %}
|
||||||
|
{{- '<start_function_declaration>' -}}
|
||||||
|
{{- format_function_declaration(tool) | trim }}
|
||||||
|
{{- '<end_function_declaration>' -}}
|
||||||
|
{%- endfor %}
|
||||||
|
{{- '<end_of_turn>\n' -}}
|
||||||
|
{%- else -%}
|
||||||
|
{{- '<start_of_turn>developer\nNo tools have been provided. Only respond with answers that do not require tool usage.<end_of_turn>\n' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{#- Loop through messages. -#}
|
||||||
|
{%- for message in loop_messages -%}
|
||||||
|
{%- if (message['role'] == 'assistant') -%}
|
||||||
|
{#- Rename "assistant" to "model". -#}
|
||||||
|
{%- set role = "model" -%}
|
||||||
|
{%- else -%}
|
||||||
|
{%- set role = message['role'] -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if role != 'tool' -%}
|
||||||
|
{%- if ns.prev_message_type != 'tool_response' -%}
|
||||||
|
{{- '<start_of_turn>' + role + '\n' }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- set ns.prev_message_type = None -%}
|
||||||
|
{%- if loop.first and system_message_content -%}
|
||||||
|
{%- if system_message_content is string -%}
|
||||||
|
{{ system_message_content | trim }}
|
||||||
|
{%- elif system_message_content is iterable -%}
|
||||||
|
{%- for item in system_message_content -%}
|
||||||
|
{%- if item['type'] == 'image' -%}
|
||||||
|
{{ raise_exception("Invalid content type 'image' in system message") }}
|
||||||
|
{%- elif item['type'] == 'text' -%}
|
||||||
|
{{ item['text'] | trim }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- else -%}
|
||||||
|
{{ raise_exception("Invalid content type in system message") }}
|
||||||
|
{%- endif -%}
|
||||||
|
{{- '\n' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{#- User/Assistant Messages -#}
|
||||||
|
{%- if 'content' in message and message['content'] is not none -%}
|
||||||
|
{%- if message['content'] is string -%}
|
||||||
|
{{ message['content'] | trim }}
|
||||||
|
{%- elif message['content'] is iterable -%}
|
||||||
|
{%- for item in message['content'] -%}
|
||||||
|
{%- if item['type'] == 'image' -%}
|
||||||
|
{{ '<start_of_image>' }}
|
||||||
|
{%- elif item['type'] == 'text' -%}
|
||||||
|
{{ item['text'] | trim }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- else -%}
|
||||||
|
{{ raise_exception("Invalid content type in user/assistant message") }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- set ns.prev_message_type = 'content' -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if 'tool_calls' in message and message['tool_calls'] and message['tool_calls'] is iterable -%}
|
||||||
|
{#- Tool Calls -#}
|
||||||
|
{%- for tool_call in message['tool_calls'] -%}
|
||||||
|
{% set function = tool_call['function'] %}
|
||||||
|
{{- '<start_function_call>call:' + function['name'] + '{' -}}
|
||||||
|
{%- if 'arguments' in function -%}
|
||||||
|
{%- if function['arguments'] is mapping -%}
|
||||||
|
{%- set ns = namespace(found_first=false) -%}
|
||||||
|
{%- for key, value in function['arguments'] | dictsort -%}
|
||||||
|
{%- if ns.found_first %},{% endif -%}
|
||||||
|
{%- set ns.found_first = true -%}
|
||||||
|
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- elif function['arguments'] is string -%}
|
||||||
|
{# This handles string-JSON, just in case #}
|
||||||
|
{{ function['arguments'] }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif -%}
|
||||||
|
{{- '}<end_function_call>' -}}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- if loop.last -%}
|
||||||
|
{{ '<start_function_response>' }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- set ns.prev_message_type = 'tool_call' -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- else -%}
|
||||||
|
{#- Tool Responses -#}
|
||||||
|
{%- if 'content' in message and message['content'] -%}
|
||||||
|
{%- if message['content'] is mapping -%}
|
||||||
|
{%- if 'name' in message['content'] and 'response' in message['content'] -%}
|
||||||
|
{{ '<start_function_response>response:' + message['content']['name'] | trim + '{' }}
|
||||||
|
{%- set response_ns = namespace(found_first=false) -%}
|
||||||
|
{%- for key, value in message['content']['response'] | dictsort -%}
|
||||||
|
{%- if response_ns.found_first %},{% endif -%}
|
||||||
|
{%- set response_ns.found_first = true -%}
|
||||||
|
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||||
|
{%- endfor -%}
|
||||||
|
{{- '}<end_function_response>' -}}
|
||||||
|
{%- elif 'name' in message -%}
|
||||||
|
{{ '<start_function_response>response:' + message['name'] | trim + '{' }}
|
||||||
|
{%- set response_ns = namespace(found_first=false) -%}
|
||||||
|
{%- for key, value in message['content'].items() -%}
|
||||||
|
{%- if response_ns.found_first %},{% endif -%}
|
||||||
|
{%- set response_ns.found_first = true -%}
|
||||||
|
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||||
|
{%- endfor -%}
|
||||||
|
{{- '}<end_function_response>' -}}
|
||||||
|
{%- else -%}
|
||||||
|
{{ raise_exception("Invalid tool response mapping: must contain 'name' and 'response' keys, or 'name' must be in the message.") }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- elif message['content'] is string -%}
|
||||||
|
{%- if 'name' in message -%}
|
||||||
|
{{ '<start_function_response>response:' + message['name'] | trim + '{value:' + format_argument(message['content'], escape_keys=False) + '}<end_function_response>' }}
|
||||||
|
{%- else -%}
|
||||||
|
{{ raise_exception("Invalid tool response: 'name' must be provided.") }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- elif message['content'] is iterable -%}
|
||||||
|
{%- for item in message['content'] -%}
|
||||||
|
{%- if item is mapping -%}
|
||||||
|
{%- if 'name' in item and 'response' in item -%}
|
||||||
|
{{ '<start_function_response>response:' + item['name'] | trim + '{' }}
|
||||||
|
{%- set response_ns = namespace(found_first=false) -%}
|
||||||
|
{%- for key, value in item['response'].items() -%}
|
||||||
|
{%- if response_ns.found_first %},{% endif -%}
|
||||||
|
{%- set response_ns.found_first = true -%}
|
||||||
|
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||||
|
{%- endfor -%}
|
||||||
|
{{- '}<end_function_response>' -}}
|
||||||
|
{%- elif 'name' in message -%}
|
||||||
|
{{ '<start_function_response>response:' + message['name'] | trim + '{' }}
|
||||||
|
{%- set response_ns = namespace(found_first=false) -%}
|
||||||
|
{%- for key, value in item.items() -%}
|
||||||
|
{%- if response_ns.found_first %},{% endif -%}
|
||||||
|
{%- set response_ns.found_first = true -%}
|
||||||
|
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||||
|
{%- endfor -%}
|
||||||
|
{{- '}<end_function_response>' -}}
|
||||||
|
{%- else -%}
|
||||||
|
{{ raise_exception("Invalid tool response mapping: must contain 'name' and 'response' keys, or 'name' must be in the message.") }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- else -%}
|
||||||
|
{{ raise_exception("Invalid tool response message: multiple responses must all be mappings") }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- else -%}
|
||||||
|
{{ raise_exception("Invalid content type in tool message: must be mapping, iterable of mappings, or string.") }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if ns.prev_message_type not in ['tool_call', 'tool_response'] -%}
|
||||||
|
{{ '<end_of_turn>\n' }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- if add_generation_prompt -%}
|
||||||
|
{%- if ns.prev_message_type != 'tool_response' -%}
|
||||||
|
{{- '<start_of_turn>model\n' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endif -%}
|
||||||
@@ -4,6 +4,306 @@ model_type: Gemma3ForCausalLM
|
|||||||
# gemma3 doesn't seem to play nice with ddp
|
# gemma3 doesn't seem to play nice with ddp
|
||||||
ddp_find_unused_parameters: true
|
ddp_find_unused_parameters: true
|
||||||
|
|
||||||
|
chat_template: jinja
|
||||||
|
chat_template_jinja: |
|
||||||
|
{%- macro format_parameters(properties, required) -%}
  {#- Render a `properties` mapping as comma-separated `name:{description:...,type:...}` entries.
      Entries are emitted in key-sorted order; keys listed in `standard_keys` are skipped.
      `required` is accepted for call-site symmetry but is not read inside this macro. -#}
  {%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
  {%- set param_ns = namespace(found_first=false) -%}
  {%- for key, value in properties | dictsort -%}
    {%- if key not in standard_keys -%}
      {#- A flag (not loop.first) drives the separator because skipped keys must not count. -#}
      {%- if param_ns.found_first %},{% endif -%}
      {%- set param_ns.found_first = true -%}
      {{- key }}:{description:<escape>{{ value['description'] }}<escape>
      {%- set kind = value['type'] | upper -%}
      {%- if kind == 'STRING' -%}
        {%- if value['enum'] -%}
          ,enum:{{ format_argument(value['enum']) }}
        {%- endif -%}
      {%- elif kind == 'OBJECT' -%}
        ,properties:{
        {%- if value['properties'] is defined and value['properties'] is mapping -%}
          {{- format_parameters(value['properties'], value['required'] | default([])) -}}
        {%- elif value is mapping -%}
          {#- No nested `properties` mapping: recurse over the object schema itself. -#}
          {{- format_parameters(value, value['required'] | default([])) -}}
        {%- endif -%}
        }
        {%- if value['required'] -%}
          ,required:[
          {%- for required_name in value['required'] | default([]) -%}
            {%- if not loop.first %},{% endif -%}
            <escape>{{- required_name -}}<escape>
          {%- endfor -%}
          ]
        {%- endif -%}
      {%- elif kind == 'ARRAY' -%}
        {%- if value['items'] is mapping and value['items'] -%}
          ,items:{
          {%- set items_ns = namespace(found_first=false) -%}
          {%- for item_key, item_value in value['items'].items() -%}
            {#- Entries whose value is none are dropped entirely, separator included. -#}
            {%- if item_value is not none -%}
              {%- if items_ns.found_first %},{% endif -%}
              {%- set items_ns.found_first = true -%}
              {%- if item_key == 'properties' -%}
                properties:{
                {%- if item_value is mapping -%}
                  {{- format_parameters(item_value, value['items']['required'] | default([])) -}}
                {%- endif -%}
                }
              {%- elif item_key == 'required' -%}
                required:[
                {%- for required_name in item_value -%}
                  {%- if not loop.first %},{% endif -%}
                  <escape>{{- required_name -}}<escape>
                {%- endfor -%}
                ]
              {%- elif item_key == 'type' -%}
                {#- `type` may be a single string or a list of strings; upper-case either form. -#}
                type:{{ format_argument((item_value | upper) if item_value is string else (item_value | map('upper') | list)) }}
              {%- else -%}
                {{ item_key }}:{{ format_argument(item_value) }}
              {%- endif -%}
            {%- endif -%}
          {%- endfor -%}
          }
        {%- endif -%}
      {%- endif -%}
      ,type:<escape>{{ kind }}<escape>}
    {%- endif -%}
  {%- endfor -%}
{%- endmacro -%}
|
||||||
|
{% macro format_function_declaration(tool_data) -%}
  {#- Render one tool as `declaration:NAME{description:...,parameters:{...}}`.
      Reads tool_data['function'] 'name', 'description' and 'parameters'.
      The parameters object is emitted only when `parameters` is truthy and may contain
      `properties`, `required` and `type`, in that order. -#}
  declaration:{{- tool_data['function']['name'] -}}
  {description:<escape>{{- tool_data['function']['description'] -}}<escape>
  {%- set params = tool_data['function']['parameters'] -%}
  {%- if params -%}
    ,parameters:{
    {%- if params['properties'] -%}
      properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
    {%- endif -%}
    {%- if params['required'] -%}
      required:[
      {%- for item in params['required'] -%}
        <escape>{{- item -}}<escape>
        {{- ',' if not loop.last -}}
      {%- endfor -%}
      ],
    {%- endif -%}
    {%- if params['type'] -%}
      type:<escape>{{- params['type'] | upper -}}<escape>
    {%- endif -%}
    {#- Close `parameters:{` unconditionally. It used to be closed only inside the `type`
        branch, which left the braces unbalanced for schemas without a `type`. -#}
    }
  {%- endif -%}
  }
{%- endmacro -%}
|
||||||
|
{% macro format_argument(argument, escape_keys=True) -%}
  {#- Serialize a value recursively:
        string   -> <escape>text<escape>
        boolean  -> true / false
        mapping  -> {key:value,...}   (keys wrapped in <escape> only when escape_keys is true)
        iterable -> [item,...]
        other    -> printed as-is
      The string test comes first because strings are iterable too. -#}
  {%- if argument is string -%}
    {{- '<escape>' ~ argument ~ '<escape>' -}}
  {%- elif argument is boolean -%}
    {{- 'true' if argument else 'false' -}}
  {%- elif argument is mapping -%}
    {{- '{' -}}
    {%- for entry_key, entry_value in argument.items() -%}
      {%- if not loop.first %},{% endif -%}
      {{- ('<escape>' + entry_key + '<escape>') if escape_keys else entry_key -}}
      :{{- format_argument(entry_value, escape_keys=escape_keys) -}}
    {%- endfor -%}
    {{- '}' -}}
  {%- elif argument is iterable -%}
    {{- '[' -}}
    {%- for element in argument -%}
      {%- if not loop.first %},{% endif -%}
      {{- format_argument(element, escape_keys=escape_keys) -}}
    {%- endfor -%}
    {{- ']' -}}
  {%- else -%}
    {{- argument -}}
  {%- endif -%}
{%- endmacro -%}
|
||||||
|
{{ bos_token }}
{#- Main template body.
    1. Emits a fixed `developer` turn that either declares the available tools or states none exist.
    2. Renders each message as a turn; a leading system/developer message is merged into the
       first non-tool turn instead of getting a turn of its own.
    3. ns.prev_message_type records what the previous message rendered (None, 'content',
       'tool_call' or 'tool_response'); it decides whether a turn is opened or closed. -#}
{%- set ns = namespace(prev_message_type=None) -%}
{#- extract system prompt for merging with user role -#}
{%- set loop_messages = messages -%}
{%- set system_message_content = '' -%}
{#- Guard against an empty conversation before indexing messages[0]. -#}
{%- if messages and (messages[0]['role'] == 'system' or messages[0]['role'] == 'developer') -%}
  {%- set system_message_content = messages[0]['content'] -%}
  {%- set loop_messages = messages[1:] -%}
{%- endif -%}
{#- 'static' system prompt. -#}
{%- if tools -%}
  {{- '<start_of_turn>developer\nYou are a model that can do function calling with the following functions' -}}
  {%- for tool in tools %}
    {{- '<start_function_declaration>' -}}
    {{- format_function_declaration(tool) | trim }}
    {{- '<end_function_declaration>' -}}
  {%- endfor %}
  {{- '<end_of_turn>\n' -}}
{%- else -%}
  {{- '<start_of_turn>developer\nNo tools have been provided. Only respond with answers that do not require tool usage.<end_of_turn>\n' -}}
{%- endif -%}
{#- Loop through messages. -#}
{%- for message in loop_messages -%}
  {%- if (message['role'] == 'assistant') -%}
    {#- Rename "assistant" to "model". -#}
    {%- set role = "model" -%}
  {%- else -%}
    {%- set role = message['role'] -%}
  {%- endif -%}
  {%- if role != 'tool' -%}
    {#- After a tool response the model turn is still open, so no new turn header is written. -#}
    {%- if ns.prev_message_type != 'tool_response' -%}
      {{- '<start_of_turn>' + role + '\n' }}
    {%- endif -%}
    {%- set ns.prev_message_type = None -%}
    {%- if loop.first and system_message_content -%}
      {%- if system_message_content is string -%}
        {{ system_message_content | trim }}
      {%- elif system_message_content is iterable -%}
        {%- for item in system_message_content -%}
          {%- if item['type'] == 'image' -%}
            {{ raise_exception("Invalid content type 'image' in system message") }}
          {%- elif item['type'] == 'text' -%}
            {{ item['text'] | trim }}
          {%- endif -%}
        {%- endfor -%}
      {%- else -%}
        {{ raise_exception("Invalid content type in system message") }}
      {%- endif -%}
      {{- '\n' -}}
    {%- endif -%}
    {#- User/Assistant Messages -#}
    {%- if 'content' in message and message['content'] is not none -%}
      {%- if message['content'] is string -%}
        {{ message['content'] | trim }}
      {%- elif message['content'] is iterable -%}
        {%- for item in message['content'] -%}
          {%- if item['type'] == 'image' -%}
            {{ '<start_of_image>' }}
          {%- elif item['type'] == 'text' -%}
            {{ item['text'] | trim }}
          {%- endif -%}
        {%- endfor -%}
      {%- else -%}
        {{ raise_exception("Invalid content type in user/assistant message") }}
      {%- endif -%}
      {%- set ns.prev_message_type = 'content' -%}
    {%- endif -%}
    {%- if 'tool_calls' in message and message['tool_calls'] and message['tool_calls'] is iterable -%}
      {#- Tool Calls -#}
      {%- for tool_call in message['tool_calls'] -%}
        {% set function = tool_call['function'] %}
        {{- '<start_function_call>call:' + function['name'] + '{' -}}
        {%- if 'arguments' in function -%}
          {%- if function['arguments'] is mapping -%}
            {#- Separate namespace for the comma flag so the turn-tracking `ns` is never shadowed. -#}
            {%- set call_ns = namespace(found_first=false) -%}
            {%- for key, value in function['arguments'] | dictsort -%}
              {%- if call_ns.found_first %},{% endif -%}
              {%- set call_ns.found_first = true -%}
              {{- key -}}:{{- format_argument(value, escape_keys=False) -}}
            {%- endfor -%}
          {%- elif function['arguments'] is string -%}
            {#- This handles string-JSON, just in case. Whitespace control keeps template
                newlines and indentation out of the rendered call. -#}
            {{- function['arguments'] -}}
          {%- endif -%}
        {%- endif -%}
        {{- '}<end_function_call>' -}}
      {%- endfor -%}
      {#- When the conversation ends on a tool call, open the response slot. -#}
      {%- if loop.last -%}
        {{ '<start_function_response>' }}
      {%- endif -%}
      {%- set ns.prev_message_type = 'tool_call' -%}
    {%- endif -%}
  {%- else -%}
    {#- Tool Responses -#}
    {%- if 'content' in message and message['content'] -%}
      {%- if message['content'] is mapping -%}
        {%- if 'name' in message['content'] and 'response' in message['content'] -%}
          {{ '<start_function_response>response:' + message['content']['name'] | trim + '{' }}
          {%- set response_ns = namespace(found_first=false) -%}
          {%- for key, value in message['content']['response'] | dictsort -%}
            {%- if response_ns.found_first %},{% endif -%}
            {%- set response_ns.found_first = true -%}
            {{- key -}}:{{- format_argument(value, escape_keys=False) -}}
          {%- endfor -%}
          {{- '}<end_function_response>' -}}
        {%- elif 'name' in message -%}
          {{ '<start_function_response>response:' + message['name'] | trim + '{' }}
          {%- set response_ns = namespace(found_first=false) -%}
          {%- for key, value in message['content'].items() -%}
            {%- if response_ns.found_first %},{% endif -%}
            {%- set response_ns.found_first = true -%}
            {{- key -}}:{{- format_argument(value, escape_keys=False) -}}
          {%- endfor -%}
          {{- '}<end_function_response>' -}}
        {%- else -%}
          {{ raise_exception("Invalid tool response mapping: must contain 'name' and 'response' keys, or 'name' must be in the message.") }}
        {%- endif -%}
      {%- elif message['content'] is string -%}
        {%- if 'name' in message -%}
          {{ '<start_function_response>response:' + message['name'] | trim + '{value:' + format_argument(message['content'], escape_keys=False) + '}<end_function_response>' }}
        {%- else -%}
          {{ raise_exception("Invalid tool response: 'name' must be provided.") }}
        {%- endif -%}
      {%- elif message['content'] is iterable -%}
        {%- for item in message['content'] -%}
          {%- if item is mapping -%}
            {%- if 'name' in item and 'response' in item -%}
              {{ '<start_function_response>response:' + item['name'] | trim + '{' }}
              {%- set response_ns = namespace(found_first=false) -%}
              {%- for key, value in item['response'].items() -%}
                {%- if response_ns.found_first %},{% endif -%}
                {%- set response_ns.found_first = true -%}
                {{- key -}}:{{- format_argument(value, escape_keys=False) -}}
              {%- endfor -%}
              {{- '}<end_function_response>' -}}
            {%- elif 'name' in message -%}
              {{ '<start_function_response>response:' + message['name'] | trim + '{' }}
              {%- set response_ns = namespace(found_first=false) -%}
              {%- for key, value in item.items() -%}
                {%- if response_ns.found_first %},{% endif -%}
                {%- set response_ns.found_first = true -%}
                {{- key -}}:{{- format_argument(value, escape_keys=False) -}}
              {%- endfor -%}
              {{- '}<end_function_response>' -}}
            {%- else -%}
              {{ raise_exception("Invalid tool response mapping: must contain 'name' and 'response' keys, or 'name' must be in the message.") }}
            {%- endif -%}
          {%- else -%}
            {{ raise_exception("Invalid tool response message: multiple responses must all be mappings") }}
          {%- endif -%}
        {%- endfor -%}
      {%- else -%}
        {{ raise_exception("Invalid content type in tool message: must be mapping, iterable of mappings, or string.") }}
      {%- endif -%}
    {%- endif -%}
    {%- set ns.prev_message_type = 'tool_response' -%}
  {%- endif -%}
  {#- A turn stays open across tool calls and tool responses; anything else closes it. -#}
  {%- if ns.prev_message_type not in ['tool_call', 'tool_response'] -%}
    {{ '<end_of_turn>\n' }}
  {%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
  {#- After a tool response the model turn is already open, so no new header is emitted. -#}
  {%- if ns.prev_message_type != 'tool_response' -%}
    {{- '<start_of_turn>model\n' -}}
  {%- endif -%}
{%- endif -%}
|
||||||
|
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
eot_tokens:
|
||||||
|
- <end_of_turn>
|
||||||
|
- <end_function_call>
|
||||||
|
eos_token: <end_of_turn>
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: /workspace/data/datasets/sample.jsonl
|
- path: /workspace/data/datasets/sample.jsonl
|
||||||
ds_type: json
|
ds_type: json
|
||||||
@@ -12,10 +312,10 @@ datasets:
|
|||||||
- assistant
|
- assistant
|
||||||
|
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: /workspace/data/training-runs/Home-Gemma3-270m
|
output_dir: /workspace/data/training-runs/Home-FunctionGemma-270m
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 5130
|
||||||
sample_packing: true
|
sample_packing: false
|
||||||
eval_sample_packing: false
|
eval_sample_packing: false
|
||||||
|
|
||||||
use_tensorboard: true
|
use_tensorboard: true
|
||||||
|
|||||||
@@ -18,10 +18,15 @@ spec:
|
|||||||
command:
|
command:
|
||||||
- axolotl
|
- axolotl
|
||||||
- train
|
- train
|
||||||
- /workspace/configs/gemma3-270m.yml
|
- /workspace/configs/functiongemma-270m.yml
|
||||||
env:
|
env:
|
||||||
- name: AXOLOTL_DO_NOT_TRACK
|
- name: AXOLOTL_DO_NOT_TRACK
|
||||||
value: "1"
|
value: "1"
|
||||||
|
- name: HF_TOKEN
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: hf-token
|
||||||
|
key: token
|
||||||
volumeMounts:
|
volumeMounts:
|
||||||
- name: training-runs
|
- name: training-runs
|
||||||
mountPath: /workspace/data/training-runs
|
mountPath: /workspace/data/training-runs
|
||||||
@@ -34,9 +39,11 @@ spec:
|
|||||||
resources:
|
resources:
|
||||||
limits:
|
limits:
|
||||||
nvidia.com/gpu: 2
|
nvidia.com/gpu: 2
|
||||||
|
initContainers:
|
||||||
- name: tensorboard
|
- name: tensorboard
|
||||||
image: python:3.11-slim
|
image: python:3.11-slim
|
||||||
imagePullPolicy: IfNotPresent
|
imagePullPolicy: IfNotPresent
|
||||||
|
restartPolicy: Always # mark as sidecar
|
||||||
command:
|
command:
|
||||||
- bash
|
- bash
|
||||||
- -c
|
- -c
|
||||||
|
|||||||
Reference in New Issue
Block a user