allow setting prompt more generically

Alex O'Connell
2025-11-30 15:59:20 -05:00
parent 9f51dd0e94
commit d352d88fd2
2 changed files with 20 additions and 9 deletions

File 1 of 2: the training notebook (.ipynb)

@@ -20,7 +20,18 @@
"outputs": [],
"source": [
"%pip install -r requirements.txt\n",
"\n",
"import os\n",
"os.environ[\"HF_HOME\"] = \"/home/jovyan/workspace/models\"\n",
"os.environ[\"HF_TOKEN\"] = \"your_huggingface_token_here\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ae5f449",
"metadata": {},
"outputs": [],
"source": [
"import os, re\n",
"from train import TrainingRunArguments, do_training_run\n",
"\n",
@@ -33,10 +44,7 @@
" if match:\n",
" max_rev = max(max_rev, int(match.group(1)))\n",
"\n",
" return f\"{model}-rev{max_rev + 1}\"\n",
"\n",
"os.environ[\"HF_HOME\"] = \"/workspace/\"\n",
"os.environ[\"HF_TOKEN\"] = \"your_huggingface_token_here\""
" return f\"{model}-rev{max_rev + 1}\""
]
},
{
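Only the tail of the notebook's run-name helper is visible in the hunk above. A self-contained sketch of what the full helper plausibly looks like; the function name next_run_name, the scanned directory, and the exact regex are assumptions rather than code from this commit:

import os
import re

def next_run_name(model: str, runs_dir: str = ".") -> str:
    # Hypothetical reconstruction: scan existing run directories named
    # "<model>-revN" and return the next free revision name.
    max_rev = 0
    pattern = re.compile(rf"{re.escape(model)}-rev(\d+)$")
    for name in os.listdir(runs_dir):
        match = pattern.match(name)
        if match:
            max_rev = max(max_rev, int(match.group(1)))
    return f"{model}-rev{max_rev + 1}"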
@@ -237,6 +245,8 @@
" batch_size=64, micro_batch_size=2, epochs=1,\n",
" ctx_size=8192,\n",
" save_steps=200, save_total_limit=1, eval_steps=200, logging_steps=2,\n",
" prompt_template_file=\"scripts/chatml_template.j2\",\n",
" add_chatml_tokens=True, add_pad_token=True, add_tool_calling_tokens=True\n",
"))"
]
}
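The notebook now points prompt_template_file at scripts/chatml_template.j2, but the template file itself is not part of this diff. For reference, a minimal ChatML chat template in Jinja2, written as a Python string and installed on a tokenizer the way do_training_run does in the second file; the gpt2 tokenizer is illustrative only:

from transformers import AutoTokenizer

# Minimal ChatML template; the real scripts/chatml_template.j2 may differ.
CHATML_TEMPLATE = (
    "{% for message in messages %}"
    "{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}"
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative base model
tokenizer.chat_template = CHATML_TEMPLATE

rendered = tokenizer.apply_chat_template(
    [{"role": "user", "content": "hello"}],
    tokenize=False,
    add_generation_prompt=True,
)
# rendered == "<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n"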

File 2 of 2: train.py

@@ -72,7 +72,7 @@ class TrainingRunArguments:
    add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
    add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
    add_tool_calling_tokens: bool = field(default=False, metadata={"help": "If set, tokens for tool calling will be added specifically"})
-    add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"})
+    prompt_template_file: Optional[str] = field(default=None, metadata={"help": "If set, the prompt template from this file will be set as the model's Jinja2 template"})
    prefix_ids: Optional[str] = field(default=None, metadata={"help": "Determines the prefix tokens that surround the assistant's response for SFT, if the model cannot correctly recognize the response."})
    suffix_ids: Optional[str] = field(default=None, metadata={"help": "Determines the suffix tokens that surround the assistant's response for SFT, if the model cannot correctly recognize the response."})
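The field(..., metadata={"help": ...}) pattern suggests these arguments are parsed with HuggingFace's HfArgumentParser, although the parser itself is not shown in this diff. Under that assumption, the new field also surfaces as a --prompt_template_file command-line flag; a pared-down sketch with a hypothetical Args class:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class Args:
    # Mirrors only the new field; the real TrainingRunArguments has many more.
    prompt_template_file: Optional[str] = field(
        default=None,
        metadata={"help": "If set, the prompt template from this file will be set as the model's Jinja2 template"},
    )

parser = HfArgumentParser(Args)
(args,) = parser.parse_args_into_dataclasses(
    args=["--prompt_template_file", "scripts/chatml_template.j2"]
)
assert args.prompt_template_file == "scripts/chatml_template.j2"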
@@ -454,8 +454,8 @@ def do_training_run(training_run_args: TrainingRunArguments):
        'additional_special_tokens': ['<tool_call>', '</tool_call>']
    })
-    if training_run_args.add_chatml_prompt_template:
-        with open("scripts/chatml_template.j2", "r") as f:
+    if training_run_args.prompt_template_file:
+        with open(training_run_args.prompt_template_file, "r") as f:
            tokenizer.chat_template = f.read()
    # resize embeddings if added tokens require it
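For context, a runnable sketch of the pattern this hunk sits in: add special tokens, install the template from a file, and resize the embedding matrix if the vocabulary grew. The gpt2 checkpoint is illustrative, and the surrounding logic in train.py may differ in detail:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

num_added = tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<tool_call>", "</tool_call>"]}
)

with open("scripts/chatml_template.j2", "r") as f:  # path from the notebook
    tokenizer.chat_template = f.read()

# Grow the embedding matrix only when new tokens were actually added.
if num_added > 0:
    model.resize_token_embeddings(len(tokenizer))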
@@ -474,7 +474,8 @@ def do_training_run(training_run_args: TrainingRunArguments):
    target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
    modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
    peft_config = LoraConfig(
-        task_type=TaskType.CAUSAL_LM, inference_mode=False,
+        task_type=TaskType.CAUSAL_LM,
+        inference_mode=False,
        r=training_run_args.lora_rank,
        lora_alpha=training_run_args.lora_alpha,
        lora_dropout=training_run_args.lora_dropout,
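A self-contained version of the LoraConfig construction in this hunk, with illustrative stand-in values for the TrainingRunArguments fields:

from peft import LoraConfig, TaskType

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # required for causal-LM fine-tuning
    inference_mode=False,          # training mode
    r=16,                          # stand-in for training_run_args.lora_rank
    lora_alpha=32,                 # stand-in for training_run_args.lora_alpha
    lora_dropout=0.05,             # stand-in for training_run_args.lora_dropout
    target_modules=["q_proj", "v_proj"],          # stand-in for lora_modules
    modules_to_save=["embed_tokens", "lm_head"],  # stand-in for lora_modules_to_save
)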