more training updates

This commit is contained in:
Alex O'Connell
2025-11-30 15:31:58 -05:00
parent 1a5445e68a
commit 9f51dd0e94
4 changed files with 93 additions and 25 deletions

View File

@@ -8,6 +8,11 @@ from typing import Callable
from tqdm import tqdm
import webcolors
# ensure we can import from the data/ directory
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from device_types import *
from prompting import generate_system_prompt, USER_INSTRUCTION_PROMPT
from utils import get_random_response, generate_random_parameter, closest_color, get_dataset_piles, NoResponseAvailableException

View File

@@ -0,0 +1,61 @@
{#- ChatML chat template with Hermes/Qwen-style <tool_call> tool calling.
    Expected inputs: `messages` (list of {role, content, tool_calls?}),
    optional `tools` (list of tool specs), and `add_generation_prompt`
    (bool). All comments use `{#- … -#}` so rendered output is unchanged. -#}
{%- if tools %}
{#- System turn: optional system prompt followed by the tool signatures. -#}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{#- No tools: emit the system prompt (if any) with a notice that device
    interaction is unavailable. -#}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '\nNo tools were provided. If the user requests you interact with a device, tell them you are unable to do so.' + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- for message in messages %}
{#- Non-string content (e.g. structured content parts) renders as empty. -#}
{%- if message.content is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- endif %}
{#- A leading system message was already emitted above, so only
    non-first system messages are rendered inside the loop. -#}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- if message.tool_calls %}
{#- Serialize each tool call as a <tool_call> JSON object, newline-separated
    from any preceding text content and from each other. -#}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{#- Accept both OpenAI-style {"function": {...}} wrappers and flat calls. -#}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{#- Arguments may arrive pre-serialized (string) or as an object. -#}
{%- if tool_call.arguments is string %}
{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}
{{- '}\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{#- Consecutive tool results are merged into a single user turn, each
    wrapped in <tool_response> tags; <|im_start|>/<|im_end|> only at the
    run boundaries (loop.last short-circuits the index0+1 lookahead). -#}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}

View File

@@ -9,7 +9,7 @@
"source": [
"%%bash\n",
"git config --global --add safe.directory /home/jovyan/workspace\n",
"git checkout -b <your branch>"
"git checkout <your branch>"
]
},
{
@@ -56,8 +56,11 @@
"source": [
"%pip install -r data/requirements.txt\n",
"from data.generate_data import main as generate_data\n",
"import os\n",
"\n",
"generate_data([\"--train\", \"--test\", \"--large\", \"--language\", \"english\", \"german\", \"french\", \"spanish\", \"polish\"])"
"os.chdir(\"./data\")\n",
"generate_data([\"--train\", \"--test\", \"--large\", \"--language\", \"english\", \"german\", \"french\", \"spanish\", \"polish\"])\n",
"os.chdir(\"..\")"
]
},
{
@@ -90,8 +93,8 @@
" run_name=get_next_run_name(\"Home-Llama-3.2-1B\"),\n",
" base_model=\"meta-llama/Llama-3.2-1B-Instruct\",\n",
" bf16=True,\n",
" train_dataset=\"data/home_assistant_train.jsonl\",\n",
" test_dataset=\"data/home_assistant_test.jsonl\",\n",
" train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
" test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
" batch_size=64, micro_batch_size=2, epochs=1,\n",
" ctx_size=2048,\n",
@@ -120,8 +123,8 @@
" run_name=get_next_run_name(\"Home-Qwen-3-1.7B\"),\n",
" base_model=\"Qwen/Qwen3-1.7B\",\n",
" bf16=True,\n",
" train_dataset=\"data/home_assistant_train.jsonl\",\n",
" test_dataset=\"data/home_assistant_test.jsonl\",\n",
" train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
" test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
" batch_size=64, micro_batch_size=2, epochs=1,\n",
" ctx_size=2048,\n",
@@ -149,8 +152,8 @@
" run_name=get_next_run_name(\"Home-Qwen-2.5-0.6B\"),\n",
" base_model=\"Qwen/Qwen2.5-0.6B-Instruct\",\n",
" bf16=True,\n",
" train_dataset=\"data/home_assistant_train.jsonl\",\n",
" test_dataset=\"data/home_assistant_test.jsonl\",\n",
" train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
" test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
" batch_size=64, micro_batch_size=2, epochs=1,\n",
" ctx_size=2048,\n",
@@ -170,8 +173,8 @@
" run_name=get_next_run_name(\"Home-Qwen-2.5-1.5B\"),\n",
" base_model=\"Qwen/Qwen2.5-1.5B-Instruct\",\n",
" bf16=True,\n",
" train_dataset=\"data/home_assistant_train.jsonl\",\n",
" test_dataset=\"data/home_assistant_test.jsonl\",\n",
" train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
" test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
" batch_size=64, micro_batch_size=2, epochs=1,\n",
" ctx_size=2048,\n",
@@ -199,8 +202,8 @@
" run_name=get_next_run_name(\"Home-Gemma-3-1B\"),\n",
" base_model=\"google/gemma-3-1b-it\",\n",
" bf16=True,\n",
" train_dataset=\"data/home_assistant_train.jsonl\",\n",
" test_dataset=\"data/home_assistant_test.jsonl\",\n",
" train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
" test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
" batch_size=64, micro_batch_size=2, epochs=1,\n",
" ctx_size=2048,\n",
@@ -213,7 +216,7 @@
"id": "21865d91",
"metadata": {},
"source": [
"# Gemma 3 270m"
"## Gemma 3 270m"
]
},
{
@@ -228,8 +231,8 @@
" run_name=get_next_run_name(\"Home-Gemma-3-270m\"),\n",
" base_model=\"google/gemma-3-270m\",\n",
" bf16=True,\n",
" train_dataset=\"data/home_assistant_train.jsonl\",\n",
" test_dataset=\"data/home_assistant_test.jsonl\",\n",
" train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
" test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
" batch_size=64, micro_batch_size=2, epochs=1,\n",
" ctx_size=8192,\n",

View File

@@ -71,6 +71,7 @@ class TrainingRunArguments:
# token options
add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
add_tool_calling_tokens: bool = field(default=False, metadata={"help": "If set, tokens for tool calling will be added specifically"})
add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"})
prefix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
suffix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
@@ -448,15 +449,14 @@ def do_training_run(training_run_args: TrainingRunArguments):
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
if training_run_args.add_tool_calling_tokens:
tokenizer.add_special_tokens({
'additional_special_tokens': ['<tool_call>', '</tool_call>']
})
if training_run_args.add_chatml_prompt_template:
tokenizer.chat_template = (
"{% for message in messages %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|im_start|>assistant\n' }}"
"{% endif %}"
)
with open("scripts/chatml_template.j2", "r") as f:
tokenizer.chat_template = f.read()
# resize embeddings if added tokens require it
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
@@ -474,8 +474,7 @@ def do_training_run(training_run_args: TrainingRunArguments):
target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
task_type=TaskType.CAUSAL_LM, inference_mode=False,
r=training_run_args.lora_rank,
lora_alpha=training_run_args.lora_alpha,
lora_dropout=training_run_args.lora_dropout,