start re-working training to use axlotl instead of the custom script

This commit is contained in:
Alex O'Connell
2025-11-30 22:29:08 -05:00
parent 04a5909214
commit 55f254149a
14 changed files with 280 additions and 1309 deletions

61
train/chatml_template.j2 Normal file
View File

@@ -0,0 +1,61 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '\nNo tools were provided. If the user requests you interact with a device, tell them you are unable to do so.' + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- for message in messages %}
{%- if message.content is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- endif %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{%- if tool_call.arguments is string %}
{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}
{{- '}\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}

107
train/gemma3-270m.yml Normal file
View File

@@ -0,0 +1,107 @@
base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: jinja
chat_template_jinja: |
{{ bos_token }}
{%- if not tools or tools | length == 0 %}No tools were provided. If the user requests you interact with a device, tell them you are unable to do so.{% else %}
Tools:
{% for tool in tools %}
- {{ tool['name'] }}({{ tool['parameters']['properties'].keys() | join(', ') }}): {{ tool['description'] }}
{% endfor -%}
{%- endif -%}
{%- for message in messages -%}
{%- if (message['role'] == 'assistant') -%}
{%- set role = "model" -%}
{%- elif message['role'] == 'system' -%}
{%- set role = "user" -%}
{%- else -%}
{%- set role = message['role'] -%}
{%- endif -%}
{{ '<start_of_turn>' + role + '
' }}
{%- if role == "tool" -%}
{{ '<tool_result>' }}
{%- endif -%}
{%- if message['content'] is string -%}
{{ message['content'] | trim }}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'image' -%}
{{ '<start_of_image>' }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- if not loop.last -%}
{{ '</tool_result>\n<tool_result>' }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type") }}
{%- endif -%}
{%- if role == "tool" -%}
{{ '</tool_result>' }}
{%- endif -%}
{%- if message['tool_calls'] is defined and message['tool_calls'] | length > 0 %}
{%- for tool_call in message["tool_calls"] -%}
{{ '\n<tool_call>{"name": "' + tool_call['name'] + '", "arguments": ' + ('"' + tool_call['arguments'] + '"' if tool_call['arguments'] is string else tool_call['arguments'] | tojson) + '"}</tool_call>' }}
{%- endfor %}
{%- endif -%}
{{ '<end_of_turn>
' }}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{'<start_of_turn>model
'}}
{%- endif -%}
special_tokens:
eot_tokens:
- <end_of_turn>
eos_token: <end_of_turn>
additional_special_tokens:
- <tool_call>
- </tool_call>
- <tool_result>
- </tool_result>
datasets:
- path: /workspace/data/datasets/sample.jsonl
ds_type: json
type: chat_template
roles_to_train:
- assistant
val_set_size: 0.0
output_dir: /workspace/data/training-runs/Home-Gemma3-270m
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
use_tensorboard: true
# batch size = 16
gradient_accumulation_steps: 16
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0

53
train/gemma3_withtools.j2 Normal file
View File

@@ -0,0 +1,53 @@
{{ bos_token }}
{%- if not tools or tools | length == 0 %}No tools were provided. If the user requests you interact with a device, tell them you are unable to do so.{% else %}
Tools:
{% for tool in tools %}
- {{ tool['name'] }}({{ tool['parameters']['properties'].keys() | join(', ') }}): {{ tool['description'] }}
{% endfor -%}
{%- endif -%}
{%- for message in messages -%}
{%- if (message['role'] == 'assistant') -%}
{%- set role = "model" -%}
{%- elif message['role'] == 'system' -%}
{%- set role = "user" -%}
{%- else -%}
{%- set role = message['role'] -%}
{%- endif -%}
{{ '<start_of_turn>' + role + '
' }}
{%- if role == "tool" -%}
{{ '<tool_result>' }}
{%- endif -%}
{%- if message['content'] is string -%}
{{ message['content'] | trim }}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'image' -%}
{{ '<start_of_image>' }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- if not loop.last -%}
{{ '</tool_result>\n<tool_result>' }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type") }}
{%- endif -%}
{%- if role == "tool" -%}
{{ '</tool_result>' }}
{%- endif -%}
{%- if message['tool_calls'] is defined and message['tool_calls'] | length > 0 %}
{%- for tool_call in message["tool_calls"] -%}
{{ '\n<tool_call>{"name": "' + tool_call['name'] + '", "arguments": ' + ('"' + tool_call['arguments'] + '"' if tool_call['arguments'] is string else tool_call['arguments'] | tojson) + '"}</tool_call>' }}
{%- endfor %}
{%- endif -%}
{{ '<end_of_turn>
' }}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{'<start_of_turn>model
'}}
{%- endif -%}

8
train/run.sh Normal file
View File

@@ -0,0 +1,8 @@
docker run -d --rm \
--gpus all \
-p 8888:8888 \
-v /mnt/data/training-runs:/workspace/data/axolotl-artifacts \
-v /mnt/data/training-data:/workspace/data/datasets \
-v /mnt/data/training-configs:/workspace/configs \
-v /mnt/data/hf-cache:/workspace/data/huggingface-cache \
axolotlai/axolotl-cloud:main-py3.11-cu128-2.8.0

17
train/zephyr_legacy.txt Normal file
View File

@@ -0,0 +1,17 @@
{% for message in messages %}
{%- if message['role'] == 'user' or message['role'] == 'tool' -%}
<|user|> {{ message['content'] }}{{ eos_token }}
{%- elif message['role'] == 'system' -%}
<|system|> {{ message['content'] }}
Services:
{%- for tool in tools %} {{ tool['function']['name'] }}({% for param in tool['function']['parameters']['properties'].keys() if param != 'target_device' %}{{ param }}{% if not loop.last %},{% endif %}{% endfor -%}),{% if not loop.last -%}
{%- if tools | length == 0 %}No tools were provided. If the user requests you interact with a device, tell them you are unable to do so.{% endif %}
{%- endif -%}{%- endfor -%}
{{ eos_token }}
{%- elif message['role'] == 'assistant' -%}
<|assistant|> {{ message['content'] }}{{ eos_token }}
{%- endif -%}
{%- if loop.last and add_generation_prompt %}
<|assistant|>
{%- endif %}
{% endfor -%}

View File

@@ -0,0 +1,11 @@
{{- range .Messages }}<|{{ .Role }}|>
{{ .Content }}
{{- if eq .Role "system" }}
Services:
{{- range $.Tools }} {{ .Function.Name }}({{- range $key, $value := .Function.Parameters.Properties }}{{- if ne $key "target_device" }}{{ $key }},{{- end }}{{- end }}),{{- end }}
{{- else if eq .Role "assistant" }}
{{ if .ToolCalls }}```homeassistant
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{ end }}```{{ end }}
{{- end }}<|endoftext|>
{{ end }}<|assistant|>