allow setting prompt more generically

Alex O'Connell
2025-11-30 15:59:20 -05:00
parent 9f51dd0e94
commit d352d88fd2
2 changed files with 20 additions and 9 deletions

View File

@@ -20,7 +20,18 @@
    "outputs": [],
    "source": [
     "%pip install -r requirements.txt\n",
-    "\n",
+    "import os\n",
+    "os.environ[\"HF_HOME\"] = \"/home/jovyan/workspace/models\"\n",
+    "os.environ[\"HF_TOKEN\"] = \"your_huggingface_token_here\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2ae5f449",
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "import os, re\n",
     "from train import TrainingRunArguments, do_training_run\n",
     "\n",
@@ -33,10 +44,7 @@
     "    if match:\n",
     "        max_rev = max(max_rev, int(match.group(1)))\n",
     "\n",
-    "    return f\"{model}-rev{max_rev + 1}\"\n",
-    "\n",
-    "os.environ[\"HF_HOME\"] = \"/workspace/\"\n",
-    "os.environ[\"HF_TOKEN\"] = \"your_huggingface_token_here\""
+    "    return f\"{model}-rev{max_rev + 1}\""
    ]
   },
   {
@@ -237,6 +245,8 @@
     "    batch_size=64, micro_batch_size=2, epochs=1,\n",
     "    ctx_size=8192,\n",
     "    save_steps=200, save_total_limit=1, eval_steps=200, logging_steps=2,\n",
+    "    prompt_template_file=\"scripts/chatml_template.j2\",\n",
+    "    add_chatml_tokens=True, add_pad_token=True, add_tool_calling_tokens=True\n",
     "))"
    ]
   }
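
For orientation, a minimal sketch of the revision-naming helper that the second notebook hunk trims: only the `match` check, the `max_rev` update, and the return line appear in the diff; the function name, signature, and directory scan below are assumptions.

import os, re

def next_run_name(model: str, output_dir: str = ".") -> str:  # hypothetical name and signature
    max_rev = 0
    for entry in os.listdir(output_dir):                       # assumed: scan prior run directories
        match = re.search(rf"{re.escape(model)}-rev(\d+)$", entry)  # assumed naming pattern
        if match:
            max_rev = max(max_rev, int(match.group(1)))

    return f"{model}-rev{max_rev + 1}"                         # taken from the diff above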

View File

@@ -72,7 +72,7 @@ class TrainingRunArguments:
     add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
     add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
     add_tool_calling_tokens: bool = field(default=False, metadata={"help": "If set, tokens for tool calling will be added specifically"})
-    add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"})
+    prompt_template_file: Optional[str] = field(default=None, metadata={"help": "If set, the prompt template will be set as the model's Jinja2 template"})
     prefix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
     suffix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
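
The boolean `add_chatml_prompt_template` flag becomes a `prompt_template_file` path, so any Jinja2 template can be supplied rather than only the bundled ChatML one. A hedged sketch of how that surfaces on the command line, assuming the script parses `TrainingRunArguments` with `HfArgumentParser` (the parser setup is not part of this diff):

# assumption: train.py builds its CLI from the dataclass via HfArgumentParser
from transformers import HfArgumentParser
from train import TrainingRunArguments

parser = HfArgumentParser(TrainingRunArguments)
(args,) = parser.parse_args_into_dataclasses()
# before: python train.py ... --add_chatml_prompt_template
# after:  python train.py ... --prompt_template_file scripts/chatml_template.j2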
@@ -454,8 +454,8 @@ def do_training_run(training_run_args: TrainingRunArguments):
             'additional_special_tokens': ['<tool_call>', '</tool_call>']
         })
-    if training_run_args.add_chatml_prompt_template:
-        with open("scripts/chatml_template.j2", "r") as f:
+    if training_run_args.prompt_template_file:
+        with open(training_run_args.prompt_template_file, "r") as f:
             tokenizer.chat_template = f.read()
     # resize embeddings if added tokens require it
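
At runtime the new branch reads whatever template file is given into `tokenizer.chat_template`, so the supplied template drives prompt formatting exactly as the hard-coded ChatML file did before. A small sketch of that effect, with a placeholder model id:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("some-base-model")   # placeholder model id
with open("scripts/chatml_template.j2", "r") as f:             # any Jinja2 template file now works
    tokenizer.chat_template = f.read()

messages = [{"role": "user", "content": "hello"}]
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))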
@@ -474,7 +474,8 @@ def do_training_run(training_run_args: TrainingRunArguments):
     target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
     modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
     peft_config = LoraConfig(
-        task_type=TaskType.CAUSAL_LM, inference_mode=False,
+        task_type=TaskType.CAUSAL_LM,
+        inference_mode=False,
         r=training_run_args.lora_rank,
         lora_alpha=training_run_args.lora_alpha,
         lora_dropout=training_run_args.lora_dropout,
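
The last hunk appears to only split `task_type` and `inference_mode` onto separate lines. For reference, a hedged reconstruction of the adapter config it builds, with placeholder values standing in for the `training_run_args` fields:

from peft import LoraConfig, TaskType

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,          # decoder-only / causal-LM fine-tuning
    inference_mode=False,                  # training, not inference
    r=16,                                  # placeholder for training_run_args.lora_rank
    lora_alpha=32,                         # placeholder for training_run_args.lora_alpha
    lora_dropout=0.05,                     # placeholder for training_run_args.lora_dropout
    target_modules=["q_proj", "v_proj"],   # placeholder for --lora_modules
    modules_to_save=None,                  # placeholder for --lora_modules_to_save
)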