diff --git a/data/generate_data.py b/data/generate_data.py
index b306085..97bf051 100644
--- a/data/generate_data.py
+++ b/data/generate_data.py
@@ -8,6 +8,11 @@ from typing import Callable
from tqdm import tqdm
import webcolors
+# ensure we can import from the data/ directory
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
from device_types import *
from prompting import generate_system_prompt, USER_INSTRUCTION_PROMPT
from utils import get_random_response, generate_random_parameter, closest_color, get_dataset_piles, NoResponseAvailableException
diff --git a/scripts/chatml_template.j2 b/scripts/chatml_template.j2
new file mode 100644
index 0000000..e09aaf3
--- /dev/null
+++ b/scripts/chatml_template.j2
@@ -0,0 +1,61 @@
+{%- if tools %}
+ {{- '<|im_start|>system\n' }}
+ {%- if messages[0].role == 'system' %}
+ {{- messages[0].content + '\n\n' }}
+ {%- endif %}
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }}
+ {%- for tool in tools %}
+ {{- "\n" }}
+ {{- tool | tojson }}
+ {%- endfor %}
+ {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }}
+{%- else %}
+ {%- if messages[0].role == 'system' %}
+ {{- '<|im_start|>system\n' + messages[0].content + '\nNo tools were provided. If the user requests you interact with a device, tell them you are unable to do so.' + '<|im_end|>\n' }}
+ {%- endif %}
+{%- endif %}
+{%- for message in messages %}
+ {%- if message.content is string %}
+ {%- set content = message.content %}
+ {%- else %}
+ {%- set content = '' %}
+ {%- endif %}
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
+ {%- elif message.role == "assistant" %}
+ {{- '<|im_start|>' + message.role + '\n' + content }}
+ {%- if message.tool_calls %}
+ {%- for tool_call in message.tool_calls %}
+ {%- if (loop.first and content) or (not loop.first) %}
+ {{- '\n' }}
+ {%- endif %}
+ {%- if tool_call.function %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+            {{- '<tool_call>\n{"name": "' }}
+ {{- tool_call.name }}
+ {{- '", "arguments": ' }}
+ {%- if tool_call.arguments is string %}
+ {{- tool_call.arguments }}
+ {%- else %}
+ {{- tool_call.arguments | tojson }}
+ {%- endif %}
+            {{- '}\n</tool_call>' }}
+ {%- endfor %}
+ {%- endif %}
+ {{- '<|im_end|>\n' }}
+ {%- elif message.role == "tool" %}
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+ {{- '<|im_start|>user' }}
+ {%- endif %}
+        {{- '\n<tool_response>\n' }}
+ {{- content }}
+        {{- '\n</tool_response>' }}
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+ {{- '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+ {{- '<|im_start|>assistant\n' }}
+{%- endif %}
\ No newline at end of file
diff --git a/train.ipynb b/train.ipynb
index be0e49c..0c0a50c 100644
--- a/train.ipynb
+++ b/train.ipynb
@@ -9,7 +9,7 @@
"source": [
"%%bash\n",
"git config --global --add safe.directory /home/jovyan/workspace\n",
- "git checkout -b "
+ "git checkout "
]
},
{
@@ -56,8 +56,11 @@
"source": [
"%pip install -r data/requirements.txt\n",
"from data.generate_data import main as generate_data\n",
+ "import os\n",
"\n",
- "generate_data([\"--train\", \"--test\", \"--large\", \"--language\", \"english\", \"german\", \"french\", \"spanish\", \"polish\"])"
+ "os.chdir(\"./data\")\n",
+ "generate_data([\"--train\", \"--test\", \"--large\", \"--language\", \"english\", \"german\", \"french\", \"spanish\", \"polish\"])\n",
+ "os.chdir(\"..\")"
]
},
{
@@ -90,8 +93,8 @@
" run_name=get_next_run_name(\"Home-Llama-3.2-1B\"),\n",
" base_model=\"meta-llama/Llama-3.2-1B-Instruct\",\n",
" bf16=True,\n",
- " train_dataset=\"data/home_assistant_train.jsonl\",\n",
- " test_dataset=\"data/home_assistant_test.jsonl\",\n",
+ " train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
+ " test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
" batch_size=64, micro_batch_size=2, epochs=1,\n",
" ctx_size=2048,\n",
@@ -120,8 +123,8 @@
" run_name=get_next_run_name(\"Home-Qwen-3-1.7B\"),\n",
" base_model=\"Qwen/Qwen3-1.7B\",\n",
" bf16=True,\n",
- " train_dataset=\"data/home_assistant_train.jsonl\",\n",
- " test_dataset=\"data/home_assistant_test.jsonl\",\n",
+ " train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
+ " test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
" batch_size=64, micro_batch_size=2, epochs=1,\n",
" ctx_size=2048,\n",
@@ -149,8 +152,8 @@
" run_name=get_next_run_name(\"Home-Qwen-2.5-0.6B\"),\n",
" base_model=\"Qwen/Qwen2.5-0.6B-Instruct\",\n",
" bf16=True,\n",
- " train_dataset=\"data/home_assistant_train.jsonl\",\n",
- " test_dataset=\"data/home_assistant_test.jsonl\",\n",
+ " train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
+ " test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
" batch_size=64, micro_batch_size=2, epochs=1,\n",
" ctx_size=2048,\n",
@@ -170,8 +173,8 @@
" run_name=get_next_run_name(\"Home-Qwen-2.5-1.5B\"),\n",
" base_model=\"Qwen/Qwen2.5-1.5B-Instruct\",\n",
" bf16=True,\n",
- " train_dataset=\"data/home_assistant_train.jsonl\",\n",
- " test_dataset=\"data/home_assistant_test.jsonl\",\n",
+ " train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
+ " test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
" batch_size=64, micro_batch_size=2, epochs=1,\n",
" ctx_size=2048,\n",
@@ -199,8 +202,8 @@
" run_name=get_next_run_name(\"Home-Gemma-3-1B\"),\n",
" base_model=\"google/gemma-3-1b-it\",\n",
" bf16=True,\n",
- " train_dataset=\"data/home_assistant_train.jsonl\",\n",
- " test_dataset=\"data/home_assistant_test.jsonl\",\n",
+ " train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
+ " test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
" batch_size=64, micro_batch_size=2, epochs=1,\n",
" ctx_size=2048,\n",
@@ -213,7 +216,7 @@
"id": "21865d91",
"metadata": {},
"source": [
- "# Gemma 3 270m"
+ "## Gemma 3 270m"
]
},
{
@@ -228,8 +231,8 @@
" run_name=get_next_run_name(\"Home-Gemma-3-270m\"),\n",
" base_model=\"google/gemma-3-270m\",\n",
" bf16=True,\n",
- " train_dataset=\"data/home_assistant_train.jsonl\",\n",
- " test_dataset=\"data/home_assistant_test.jsonl\",\n",
+ " train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
+ " test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
" batch_size=64, micro_batch_size=2, epochs=1,\n",
" ctx_size=8192,\n",
diff --git a/train.py b/train.py
index 2adbba2..cc4e756 100644
--- a/train.py
+++ b/train.py
@@ -71,6 +71,7 @@ class TrainingRunArguments:
# token options
add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
+ add_tool_calling_tokens: bool = field(default=False, metadata={"help": "If set, tokens for tool calling will be added specifically"})
add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"})
prefix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
suffix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
@@ -448,15 +449,14 @@ def do_training_run(training_run_args: TrainingRunArguments):
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
+ if training_run_args.add_tool_calling_tokens:
+ tokenizer.add_special_tokens({
+            'additional_special_tokens': ['<tool_call>', '</tool_call>']
+ })
+
if training_run_args.add_chatml_prompt_template:
- tokenizer.chat_template = (
- "{% for message in messages %}"
- "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
- "{% endfor %}"
- "{% if add_generation_prompt %}"
- "{{ '<|im_start|>assistant\n' }}"
- "{% endif %}"
- )
+ with open("scripts/chatml_template.j2", "r") as f:
+ tokenizer.chat_template = f.read()
# resize embeddings if added tokens require it
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
@@ -474,8 +474,7 @@ def do_training_run(training_run_args: TrainingRunArguments):
target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
peft_config = LoraConfig(
- task_type=TaskType.CAUSAL_LM,
- inference_mode=False,
+        task_type=TaskType.CAUSAL_LM, inference_mode=False,
r=training_run_args.lora_rank,
lora_alpha=training_run_args.lora_alpha,
lora_dropout=training_run_args.lora_dropout,