mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-09 21:58:00 -05:00
allow setting prompt more generically
This commit is contained in:
20
train.ipynb
20
train.ipynb
@@ -20,7 +20,18 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"%pip install -r requirements.txt\n",
|
"%pip install -r requirements.txt\n",
|
||||||
"\n",
|
"import os\n",
|
||||||
|
"os.environ[\"HF_HOME\"] = \"/home/jovyan/workspace/models\"\n",
|
||||||
|
"os.environ[\"HF_TOKEN\"] = \"your_huggingface_token_here\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "2ae5f449",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
"import os, re\n",
|
"import os, re\n",
|
||||||
"from train import TrainingRunArguments, do_training_run\n",
|
"from train import TrainingRunArguments, do_training_run\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -33,10 +44,7 @@
|
|||||||
" if match:\n",
|
" if match:\n",
|
||||||
" max_rev = max(max_rev, int(match.group(1)))\n",
|
" max_rev = max(max_rev, int(match.group(1)))\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return f\"{model}-rev{max_rev + 1}\"\n",
|
" return f\"{model}-rev{max_rev + 1}\""
|
||||||
"\n",
|
|
||||||
"os.environ[\"HF_HOME\"] = \"/workspace/\"\n",
|
|
||||||
"os.environ[\"HF_TOKEN\"] = \"your_huggingface_token_here\""
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -237,6 +245,8 @@
|
|||||||
" batch_size=64, micro_batch_size=2, epochs=1,\n",
|
" batch_size=64, micro_batch_size=2, epochs=1,\n",
|
||||||
" ctx_size=8192,\n",
|
" ctx_size=8192,\n",
|
||||||
" save_steps=200, save_total_limit=1, eval_steps=200, logging_steps=2,\n",
|
" save_steps=200, save_total_limit=1, eval_steps=200, logging_steps=2,\n",
|
||||||
|
" prompt_template_file=\"scripts/chatml_template.j2\",\n",
|
||||||
|
" add_chatml_tokens=True, add_pad_token=True, add_tool_calling_tokens=True\n",
|
||||||
"))"
|
"))"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
9
train.py
9
train.py
@@ -72,7 +72,7 @@ class TrainingRunArguments:
|
|||||||
add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
|
add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
|
||||||
add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
|
add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
|
||||||
add_tool_calling_tokens: bool = field(default=False, metadata={"help": "If set, tokens for tool calling will be added specifically"})
|
add_tool_calling_tokens: bool = field(default=False, metadata={"help": "If set, tokens for tool calling will be added specifically"})
|
||||||
add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"})
|
prompt_template_file: Optional[str] = field(default=None, metadata={"help": "If set, the prompt template will be set as the model's Jinja2 template"})
|
||||||
prefix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
|
prefix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
|
||||||
suffix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
|
suffix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
|
||||||
|
|
||||||
@@ -454,8 +454,8 @@ def do_training_run(training_run_args: TrainingRunArguments):
|
|||||||
'additional_special_tokens': ['<tool_call>', '</tool_call>']
|
'additional_special_tokens': ['<tool_call>', '</tool_call>']
|
||||||
})
|
})
|
||||||
|
|
||||||
if training_run_args.add_chatml_prompt_template:
|
if training_run_args.prompt_template_file:
|
||||||
with open("scripts/chatml_template.j2", "r") as f:
|
with open(training_run_args.prompt_template_file, "r") as f:
|
||||||
tokenizer.chat_template = f.read()
|
tokenizer.chat_template = f.read()
|
||||||
|
|
||||||
# resize embeddings if added tokens require it
|
# resize embeddings if added tokens require it
|
||||||
@@ -474,7 +474,8 @@ def do_training_run(training_run_args: TrainingRunArguments):
|
|||||||
target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
|
target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
|
||||||
modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
|
modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
|
||||||
peft_config = LoraConfig(
|
peft_config = LoraConfig(
|
||||||
task_type=TaskType. inference_mode=False,
|
task_type=TaskType.CAUSAL_LM,
|
||||||
|
inference_mode=False,
|
||||||
r=training_run_args.lora_rank,
|
r=training_run_args.lora_rank,
|
||||||
lora_alpha=training_run_args.lora_alpha,
|
lora_alpha=training_run_args.lora_alpha,
|
||||||
lora_dropout=training_run_args.lora_dropout,
|
lora_dropout=training_run_args.lora_dropout,
|
||||||
|
|||||||
Reference in New Issue
Block a user