mirror of https://github.com/acon96/home-llm.git

allow setting prompt more generically

train.ipynb (20 lines changed)
@@ -20,7 +20,18 @@
    "outputs": [],
    "source": [
+    "%pip install -r requirements.txt\n",
+    "\n",
+    "import os\n",
+    "os.environ[\"HF_HOME\"] = \"/home/jovyan/workspace/models\"\n",
+    "os.environ[\"HF_TOKEN\"] = \"your_huggingface_token_here\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2ae5f449",
+   "metadata": {},
    "outputs": [],
    "source": [
    "import os, re\n",
    "from train import TrainingRunArguments, do_training_run\n",
    "\n",
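A note on the new setup cell: HF_HOME only takes effect if it is set before transformers/huggingface_hub are first imported, because the cache location is read once at import time, which is why the notebook exports it ahead of "from train import ...". A minimal sketch of the ordering requirement (the huggingface_hub constants check is illustrative, and the token value is a placeholder):

import os

# Must run before any Hugging Face library is imported in this session;
# the cache path is captured at import time.
os.environ["HF_HOME"] = "/home/jovyan/workspace/models"
os.environ["HF_TOKEN"] = "your_huggingface_token_here"  # placeholder, not a real token

from huggingface_hub import constants
print(constants.HF_HOME)  # should now point at the workspace cache directory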
@@ -33,10 +44,7 @@
     "    if match:\n",
     "        max_rev = max(max_rev, int(match.group(1)))\n",
     "\n",
-    "    return f\"{model}-rev{max_rev + 1}\"\n",
-    "\n",
-    "os.environ[\"HF_HOME\"] = \"/workspace/\"\n",
-    "os.environ[\"HF_TOKEN\"] = \"your_huggingface_token_here\""
+    "    return f\"{model}-rev{max_rev + 1}\""
    ]
   },
   {
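The cell above only shows the tail of the notebook's revision-numbering helper, which scans for existing outputs and bumps the revision suffix. For context, here is a self-contained sketch of what such a helper looks like based on the visible fragment (the function name find_next_rev and the directory scan are assumptions, not code from the repo):

import os
import re

def find_next_rev(model: str, output_dir: str = ".") -> str:
    # Scan for existing '<model>-revN' directories and return the next revision name.
    max_rev = 0
    pattern = re.compile(rf"{re.escape(model)}-rev(\d+)$")
    for name in os.listdir(output_dir):
        match = pattern.match(name)
        if match:
            max_rev = max(max_rev, int(match.group(1)))
    return f"{model}-rev{max_rev + 1}"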
@@ -237,6 +245,8 @@
     "    batch_size=64, micro_batch_size=2, epochs=1,\n",
     "    ctx_size=8192,\n",
     "    save_steps=200, save_total_limit=1, eval_steps=200, logging_steps=2,\n",
+    "    prompt_template_file=\"scripts/chatml_template.j2\",\n",
+    "    add_chatml_tokens=True, add_pad_token=True, add_tool_calling_tokens=True\n",
     "))"
    ]
   }
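The three add_* flags passed above correspond to tokenizer vocabulary additions. The tool-calling pair is visible in the train.py hunk below; the ChatML markers and pad token shown here are the conventional choices and are an assumption, not taken from this diff. A sketch of the equivalent direct calls:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in base model

# add_pad_token / add_chatml_tokens / add_tool_calling_tokens, rolled into one call.
tokenizer.add_special_tokens({
    "pad_token": "<pad>",
    "additional_special_tokens": ["<|im_start|>", "<|im_end|>", "<tool_call>", "</tool_call>"],
})

After adding tokens, the model's embedding matrix must be resized to match the larger vocabulary, which is what the "# resize embeddings if added tokens require it" comment in train.py refers to: model.resize_token_embeddings(len(tokenizer)).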
train.py (9 lines changed)
@@ -72,7 +72,7 @@ class TrainingRunArguments:
     add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
     add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
     add_tool_calling_tokens: bool = field(default=False, metadata={"help": "If set, tokens for tool calling will be added specifically"})
-    add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"})
+    prompt_template_file: Optional[str] = field(default=None, metadata={"help": "If set, the prompt template will be set as the model's Jinja2 template"})
     prefix_ids: Optional[str] = field(default=None, metadata={"help": "Determines the prefix tokens that surround the assistant's response for SFT, if the model cannot correctly recognize the response."})
     suffix_ids: Optional[str] = field(default=None, metadata={"help": "Determines the suffix tokens that surround the assistant's response for SFT, if the model cannot correctly recognize the response."})
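The net effect of this hunk is that a boolean, ChatML-only switch becomes a path to any Jinja2 template. Assuming TrainingRunArguments is parsed with transformers.HfArgumentParser (an assumption suggested by the field(metadata={"help": ...}) style, not confirmed by this diff), the new field surfaces as a CLI flag. PromptArgs below is a trimmed-down, hypothetical stand-in for the real dataclass:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class PromptArgs:
    prompt_template_file: Optional[str] = field(
        default=None,
        metadata={"help": "If set, the prompt template will be set as the model's Jinja2 template"},
    )

parser = HfArgumentParser(PromptArgs)
# Before this commit: --add_chatml_prompt_template (ChatML only)
# After this commit:  --prompt_template_file <path> (any Jinja2 template)
(args,) = parser.parse_args_into_dataclasses(["--prompt_template_file", "scripts/chatml_template.j2"])
print(args.prompt_template_file)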
@@ -454,8 +454,8 @@ def do_training_run(training_run_args: TrainingRunArguments):
             'additional_special_tokens': ['<tool_call>', '</tool_call>']
         })

-    if training_run_args.add_chatml_prompt_template:
-        with open("scripts/chatml_template.j2", "r") as f:
+    if training_run_args.prompt_template_file:
+        with open(training_run_args.prompt_template_file, "r") as f:
             tokenizer.chat_template = f.read()

     # resize embeddings if added tokens require it
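Once tokenizer.chat_template is assigned from the file, every downstream call to apply_chat_template renders conversations with that template. A sketch of the flow, with a ChatML template inlined as a string so the example is self-contained (the template text and messages are illustrative, not taken from scripts/chatml_template.j2):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in base model

# Inline equivalent of reading the template file; Jinja2 decodes the \n escapes.
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}"
)

messages = [
    {"role": "system", "content": "You are a home assistant."},
    {"role": "user", "content": "Turn off the kitchen lights."},
]

# Render text only; add_generation_prompt appends the assistant header so the
# model knows a response is expected next.
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))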
@@ -474,7 +474,8 @@ def do_training_run(training_run_args: TrainingRunArguments):
     target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
     modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
     peft_config = LoraConfig(
-        task_type=TaskType.CAUSAL_LM, inference_mode=False,
+        task_type=TaskType.CAUSAL_LM,
+        inference_mode=False,
         r=training_run_args.lora_rank,
         lora_alpha=training_run_args.lora_alpha,
         lora_dropout=training_run_args.lora_dropout,
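For readers unfamiliar with the surrounding code, the reformatted LoraConfig above is part of a standard PEFT LoRA setup. A runnable sketch with illustrative hyperparameters in place of the TrainingRunArguments fields (gpt2 and its c_attn module are stand-ins; target module names vary by architecture, e.g. q_proj/v_proj on Llama-style models):

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,       # training run, so adapter weights stay trainable
    r=16,                       # stands in for training_run_args.lora_rank
    lora_alpha=32,              # stands in for training_run_args.lora_alpha
    lora_dropout=0.05,          # stands in for training_run_args.lora_dropout
    target_modules=["c_attn"],  # attention projection module(s) to adapt
)

model = AutoModelForCausalLM.from_pretrained("gpt2")  # stand-in base model
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # sanity check: only the LoRA adapters should be trainable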