mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 21:28:05 -05:00
more training updates
This commit is contained in:
@@ -8,6 +8,11 @@ from typing import Callable
|
||||
from tqdm import tqdm
|
||||
import webcolors
|
||||
|
||||
# ensure we can import from the data/ directory
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from device_types import *
|
||||
from prompting import generate_system_prompt, USER_INSTRUCTION_PROMPT
|
||||
from utils import get_random_response, generate_random_parameter, closest_color, get_dataset_piles, NoResponseAvailableException
|
||||
|
||||
61
scripts/chatml_template.j2
Normal file
61
scripts/chatml_template.j2
Normal file
@@ -0,0 +1,61 @@
|
||||
{%- if tools %}
|
||||
{{- '<|im_start|>system\n' }}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{{- messages[0].content + '\n\n' }}
|
||||
{%- endif %}
|
||||
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{{- "\n" }}
|
||||
{{- tool | tojson }}
|
||||
{%- endfor %}
|
||||
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
||||
{%- else %}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{{- '<|im_start|>system\n' + messages[0].content + '\nNo tools were provided. If the user requests you interact with a device, tell them you are unable to do so.' + '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- for message in messages %}
|
||||
{%- if message.content is string %}
|
||||
{%- set content = message.content %}
|
||||
{%- else %}
|
||||
{%- set content = '' %}
|
||||
{%- endif %}
|
||||
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||
{%- if message.tool_calls %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if (loop.first and content) or (not loop.first) %}
|
||||
{{- '\n' }}
|
||||
{%- endif %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '<tool_call>\n{"name": "' }}
|
||||
{{- tool_call.name }}
|
||||
{{- '", "arguments": ' }}
|
||||
{%- if tool_call.arguments is string %}
|
||||
{{- tool_call.arguments }}
|
||||
{%- else %}
|
||||
{{- tool_call.arguments | tojson }}
|
||||
{%- endif %}
|
||||
{{- '}\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
||||
{{- '<|im_start|>user' }}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_response>\n' }}
|
||||
{{- content }}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- endif %}
|
||||
33
train.ipynb
33
train.ipynb
@@ -9,7 +9,7 @@
|
||||
"source": [
|
||||
"%%bash\n",
|
||||
"git config --global --add safe.directory /home/jovyan/workspace\n",
|
||||
"git checkout -b <your branch>"
|
||||
"git checkout <your branch>"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -56,8 +56,11 @@
|
||||
"source": [
|
||||
"%pip install -r data/requirements.txt\n",
|
||||
"from data.generate_data import main as generate_data\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"generate_data([\"--train\", \"--test\", \"--large\", \"--language\", \"english\", \"german\", \"french\", \"spanish\", \"polish\"])"
|
||||
"os.chdir(\"./data\")\n",
|
||||
"generate_data([\"--train\", \"--test\", \"--large\", \"--language\", \"english\", \"german\", \"french\", \"spanish\", \"polish\"])\n",
|
||||
"os.chdir(\"..\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -90,8 +93,8 @@
|
||||
" run_name=get_next_run_name(\"Home-Llama-3.2-1B\"),\n",
|
||||
" base_model=\"meta-llama/Llama-3.2-1B-Instruct\",\n",
|
||||
" bf16=True,\n",
|
||||
" train_dataset=\"data/home_assistant_train.jsonl\",\n",
|
||||
" test_dataset=\"data/home_assistant_test.jsonl\",\n",
|
||||
" train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
|
||||
" test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
|
||||
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
|
||||
" batch_size=64, micro_batch_size=2, epochs=1,\n",
|
||||
" ctx_size=2048,\n",
|
||||
@@ -120,8 +123,8 @@
|
||||
" run_name=get_next_run_name(\"Home-Qwen-3-1.7B\"),\n",
|
||||
" base_model=\"Qwen/Qwen3-1.7B\",\n",
|
||||
" bf16=True,\n",
|
||||
" train_dataset=\"data/home_assistant_train.jsonl\",\n",
|
||||
" test_dataset=\"data/home_assistant_test.jsonl\",\n",
|
||||
" train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
|
||||
" test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
|
||||
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
|
||||
" batch_size=64, micro_batch_size=2, epochs=1,\n",
|
||||
" ctx_size=2048,\n",
|
||||
@@ -149,8 +152,8 @@
|
||||
" run_name=get_next_run_name(\"Home-Qwen-2.5-0.6B\"),\n",
|
||||
" base_model=\"Qwen/Qwen2.5-0.6B-Instruct\",\n",
|
||||
" bf16=True,\n",
|
||||
" train_dataset=\"data/home_assistant_train.jsonl\",\n",
|
||||
" test_dataset=\"data/home_assistant_test.jsonl\",\n",
|
||||
" train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
|
||||
" test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
|
||||
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
|
||||
" batch_size=64, micro_batch_size=2, epochs=1,\n",
|
||||
" ctx_size=2048,\n",
|
||||
@@ -170,8 +173,8 @@
|
||||
" run_name=get_next_run_name(\"Home-Qwen-2.5-1.5B\"),\n",
|
||||
" base_model=\"Qwen/Qwen2.5-1.5B-Instruct\",\n",
|
||||
" bf16=True,\n",
|
||||
" train_dataset=\"data/home_assistant_train.jsonl\",\n",
|
||||
" test_dataset=\"data/home_assistant_test.jsonl\",\n",
|
||||
" train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
|
||||
" test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
|
||||
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
|
||||
" batch_size=64, micro_batch_size=2, epochs=1,\n",
|
||||
" ctx_size=2048,\n",
|
||||
@@ -199,8 +202,8 @@
|
||||
" run_name=get_next_run_name(\"Home-Gemma-3-1B\"),\n",
|
||||
" base_model=\"google/gemma-3-1b-it\",\n",
|
||||
" bf16=True,\n",
|
||||
" train_dataset=\"data/home_assistant_train.jsonl\",\n",
|
||||
" test_dataset=\"data/home_assistant_test.jsonl\",\n",
|
||||
" train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
|
||||
" test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
|
||||
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
|
||||
" batch_size=64, micro_batch_size=2, epochs=1,\n",
|
||||
" ctx_size=2048,\n",
|
||||
@@ -213,7 +216,7 @@
|
||||
"id": "21865d91",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Gemma 3 270m"
|
||||
"## Gemma 3 270m"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -228,8 +231,8 @@
|
||||
" run_name=get_next_run_name(\"Home-Gemma-3-270m\"),\n",
|
||||
" base_model=\"google/gemma-3-270m\",\n",
|
||||
" bf16=True,\n",
|
||||
" train_dataset=\"data/home_assistant_train.jsonl\",\n",
|
||||
" test_dataset=\"data/home_assistant_test.jsonl\",\n",
|
||||
" train_dataset=\"data/output/home_assistant_train.jsonl\",\n",
|
||||
" test_dataset=\"data/output/home_assistant_test.jsonl\",\n",
|
||||
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
|
||||
" batch_size=64, micro_batch_size=2, epochs=1,\n",
|
||||
" ctx_size=8192,\n",
|
||||
|
||||
19
train.py
19
train.py
@@ -71,6 +71,7 @@ class TrainingRunArguments:
|
||||
# token options
|
||||
add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
|
||||
add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
|
||||
add_tool_calling_tokens: bool = field(default=False, metadata={"help": "If set, tokens for tool calling will be added specifically"})
|
||||
add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"})
|
||||
prefix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
|
||||
suffix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
|
||||
@@ -448,15 +449,14 @@ def do_training_run(training_run_args: TrainingRunArguments):
|
||||
model.config.bos_token_id = tokenizer.bos_token_id
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
if training_run_args.add_tool_calling_tokens:
|
||||
tokenizer.add_special_tokens({
|
||||
'additional_special_tokens': ['<tool_call>', '</tool_call>']
|
||||
})
|
||||
|
||||
if training_run_args.add_chatml_prompt_template:
|
||||
tokenizer.chat_template = (
|
||||
"{% for message in messages %}"
|
||||
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}"
|
||||
"{{ '<|im_start|>assistant\n' }}"
|
||||
"{% endif %}"
|
||||
)
|
||||
with open("scripts/chatml_template.j2", "r") as f:
|
||||
tokenizer.chat_template = f.read()
|
||||
|
||||
# resize embeddings if added tokens require it
|
||||
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
|
||||
@@ -474,8 +474,7 @@ def do_training_run(training_run_args: TrainingRunArguments):
|
||||
target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
|
||||
modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
|
||||
peft_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
task_type=TaskType.CAUSAL_LM, inference_mode=False,
|
||||
r=training_run_args.lora_rank,
|
||||
lora_alpha=training_run_args.lora_alpha,
|
||||
lora_dropout=training_run_args.lora_dropout,
|
||||
|
||||
Reference in New Issue
Block a user