diff --git a/data/generate_data.py b/data/generate_data.py
index b306085..97bf051 100644
--- a/data/generate_data.py
+++ b/data/generate_data.py
@@ -8,6 +8,11 @@ from typing import Callable
 from tqdm import tqdm
 import webcolors
 
+# ensure we can import from the data/ directory
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
 from device_types import *
 from prompting import generate_system_prompt, USER_INSTRUCTION_PROMPT
 from utils import get_random_response, generate_random_parameter, closest_color, get_dataset_piles, NoResponseAvailableException
diff --git a/scripts/chatml_template.j2 b/scripts/chatml_template.j2
new file mode 100644
index 0000000..e09aaf3
--- /dev/null
+++ b/scripts/chatml_template.j2
@@ -0,0 +1,61 @@
+{%- if tools %}
+    {{- '<|im_start|>system\n' }}
+    {%- if messages[0].role == 'system' %}
+        {{- messages[0].content + '\n\n' }}
+    {%- endif %}
+    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+    {%- for tool in tools %}
+        {{- "\n" }}
+        {{- tool | tojson }}
+    {%- endfor %}
+    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+{%- else %}
+    {%- if messages[0].role == 'system' %}
+        {{- '<|im_start|>system\n' + messages[0].content + '\nNo tools were provided. If the user requests you interact with a device, tell them you are unable to do so.' + '<|im_end|>\n' }}
+    {%- endif %}
+{%- endif %}
+{%- for message in messages %}
+    {%- if message.content is string %}
+        {%- set content = message.content %}
+    {%- else %}
+        {%- set content = '' %}
+    {%- endif %}
+    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+        {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
+    {%- elif message.role == "assistant" %}
+        {{- '<|im_start|>' + message.role + '\n' + content }}
+        {%- if message.tool_calls %}
+            {%- for tool_call in message.tool_calls %}
+                {%- if (loop.first and content) or (not loop.first) %}
+                    {{- '\n' }}
+                {%- endif %}
+                {%- if tool_call.function %}
+                    {%- set tool_call = tool_call.function %}
+                {%- endif %}
+                {{- '<tool_call>\n{"name": "' }}
+                {{- tool_call.name }}
+                {{- '", "arguments": ' }}
+                {%- if tool_call.arguments is string %}
+                    {{- tool_call.arguments }}
+                {%- else %}
+                    {{- tool_call.arguments | tojson }}
+                {%- endif %}
+                {{- '}\n</tool_call>' }}
+            {%- endfor %}
+        {%- endif %}
+        {{- '<|im_end|>\n' }}
+    {%- elif message.role == "tool" %}
+        {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+            {{- '<|im_start|>user' }}
+        {%- endif %}
+        {{- '\n<tool_response>\n' }}
+        {{- content }}
+        {{- '\n</tool_response>' }}
+        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+            {{- '<|im_end|>\n' }}
+        {%- endif %}
+    {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+    {{- '<|im_start|>assistant\n' }}
+{%- endif %}
\ No newline at end of file
diff --git a/train.ipynb b/train.ipynb
index be0e49c..0c0a50c 100644
--- a/train.ipynb
+++ b/train.ipynb
@@ -9,7 +9,7 @@
    "source": [
     "%%bash\n",
     "git config --global --add safe.directory /home/jovyan/workspace\n",
-    "git checkout -b "
+    "git checkout "
    ]
   },
   {
@@ -56,8 +56,11 @@
    "source": [
     "%pip install -r data/requirements.txt\n",
     "from data.generate_data import main as generate_data\n",
+    "import os\n",
     "\n",
-    "generate_data([\"--train\", \"--test\", \"--large\", \"--language\", \"english\", \"german\", \"french\", \"spanish\", \"polish\"])"
+    "os.chdir(\"./data\")\n",
"generate_data([\"--train\", \"--test\", \"--large\", \"--language\", \"english\", \"german\", \"french\", \"spanish\", \"polish\"])\n", + "os.chdir(\"..\")" ] }, { @@ -90,8 +93,8 @@ " run_name=get_next_run_name(\"Home-Llama-3.2-1B\"),\n", " base_model=\"meta-llama/Llama-3.2-1B-Instruct\",\n", " bf16=True,\n", - " train_dataset=\"data/home_assistant_train.jsonl\",\n", - " test_dataset=\"data/home_assistant_test.jsonl\",\n", + " train_dataset=\"data/output/home_assistant_train.jsonl\",\n", + " test_dataset=\"data/output/home_assistant_test.jsonl\",\n", " learning_rate=2e-5, learning_rate_warmup=0.03, \n", " batch_size=64, micro_batch_size=2, epochs=1,\n", " ctx_size=2048,\n", @@ -120,8 +123,8 @@ " run_name=get_next_run_name(\"Home-Qwen-3-1.7B\"),\n", " base_model=\"Qwen/Qwen3-1.7B\",\n", " bf16=True,\n", - " train_dataset=\"data/home_assistant_train.jsonl\",\n", - " test_dataset=\"data/home_assistant_test.jsonl\",\n", + " train_dataset=\"data/output/home_assistant_train.jsonl\",\n", + " test_dataset=\"data/output/home_assistant_test.jsonl\",\n", " learning_rate=2e-5, learning_rate_warmup=0.03, \n", " batch_size=64, micro_batch_size=2, epochs=1,\n", " ctx_size=2048,\n", @@ -149,8 +152,8 @@ " run_name=get_next_run_name(\"Home-Qwen-2.5-0.6B\"),\n", " base_model=\"Qwen/Qwen2.5-0.6B-Instruct\",\n", " bf16=True,\n", - " train_dataset=\"data/home_assistant_train.jsonl\",\n", - " test_dataset=\"data/home_assistant_test.jsonl\",\n", + " train_dataset=\"data/output/home_assistant_train.jsonl\",\n", + " test_dataset=\"data/output/home_assistant_test.jsonl\",\n", " learning_rate=2e-5, learning_rate_warmup=0.03, \n", " batch_size=64, micro_batch_size=2, epochs=1,\n", " ctx_size=2048,\n", @@ -170,8 +173,8 @@ " run_name=get_next_run_name(\"Home-Qwen-2.5-1.5B\"),\n", " base_model=\"Qwen/Qwen2.5-1.5B-Instruct\",\n", " bf16=True,\n", - " train_dataset=\"data/home_assistant_train.jsonl\",\n", - " test_dataset=\"data/home_assistant_test.jsonl\",\n", + " train_dataset=\"data/output/home_assistant_train.jsonl\",\n", + " test_dataset=\"data/output/home_assistant_test.jsonl\",\n", " learning_rate=2e-5, learning_rate_warmup=0.03, \n", " batch_size=64, micro_batch_size=2, epochs=1,\n", " ctx_size=2048,\n", @@ -199,8 +202,8 @@ " run_name=get_next_run_name(\"Home-Gemma-3-1B\"),\n", " base_model=\"google/gemma-3-1b-it\",\n", " bf16=True,\n", - " train_dataset=\"data/home_assistant_train.jsonl\",\n", - " test_dataset=\"data/home_assistant_test.jsonl\",\n", + " train_dataset=\"data/output/home_assistant_train.jsonl\",\n", + " test_dataset=\"data/output/home_assistant_test.jsonl\",\n", " learning_rate=2e-5, learning_rate_warmup=0.03, \n", " batch_size=64, micro_batch_size=2, epochs=1,\n", " ctx_size=2048,\n", @@ -213,7 +216,7 @@ "id": "21865d91", "metadata": {}, "source": [ - "# Gemma 3 270m" + "## Gemma 3 270m" ] }, { @@ -228,8 +231,8 @@ " run_name=get_next_run_name(\"Home-Gemma-3-270m\"),\n", " base_model=\"google/gemma-3-270m\",\n", " bf16=True,\n", - " train_dataset=\"data/home_assistant_train.jsonl\",\n", - " test_dataset=\"data/home_assistant_test.jsonl\",\n", + " train_dataset=\"data/output/home_assistant_train.jsonl\",\n", + " test_dataset=\"data/output/home_assistant_test.jsonl\",\n", " learning_rate=2e-5, learning_rate_warmup=0.03, \n", " batch_size=64, micro_batch_size=2, epochs=1,\n", " ctx_size=8192,\n", diff --git a/train.py b/train.py index 2adbba2..cc4e756 100644 --- a/train.py +++ b/train.py @@ -71,6 +71,7 @@ class TrainingRunArguments: # token options add_pad_token: bool = field(default=False, 
metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"}) add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"}) + add_tool_calling_tokens: bool = field(default=False, metadata={"help": "If set, tokens for tool calling will be added specifically"}) add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"}) prefix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."}) suffix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."}) @@ -448,15 +449,14 @@ def do_training_run(training_run_args: TrainingRunArguments): model.config.bos_token_id = tokenizer.bos_token_id model.config.eos_token_id = tokenizer.eos_token_id + if training_run_args.add_tool_calling_tokens: + tokenizer.add_special_tokens({ + 'additional_special_tokens': ['', ''] + }) + if training_run_args.add_chatml_prompt_template: - tokenizer.chat_template = ( - "{% for message in messages %}" - "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}" - "{% endfor %}" - "{% if add_generation_prompt %}" - "{{ '<|im_start|>assistant\n' }}" - "{% endif %}" - ) + with open("scripts/chatml_template.j2", "r") as f: + tokenizer.chat_template = f.read() # resize embeddings if added tokens require it embeddings_len = math.ceil(len(tokenizer) / 32) * 32 @@ -474,8 +474,7 @@ def do_training_run(training_run_args: TrainingRunArguments): target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None peft_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - inference_mode=False, + task_type=TaskType. inference_mode=False, r=training_run_args.lora_rank, lora_alpha=training_run_args.lora_alpha, lora_dropout=training_run_args.lora_dropout,