diff --git a/train.ipynb b/train.ipynb
index 0c0a50c..4878838 100644
--- a/train.ipynb
+++ b/train.ipynb
@@ -20,7 +20,18 @@
    "outputs": [],
    "source": [
     "%pip install -r requirements.txt\n",
-    "\n",
+    "import os\n",
+    "os.environ[\"HF_HOME\"] = \"/home/jovyan/workspace/models\"\n",
+    "os.environ[\"HF_TOKEN\"] = \"your_huggingface_token_here\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2ae5f449",
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "import os, re\n",
     "from train import TrainingRunArguments, do_training_run\n",
     "\n",
@@ -33,10 +44,7 @@
     "        if match:\n",
     "            max_rev = max(max_rev, int(match.group(1)))\n",
     "\n",
-    "    return f\"{model}-rev{max_rev + 1}\"\n",
-    "\n",
-    "os.environ[\"HF_HOME\"] = \"/workspace/\"\n",
-    "os.environ[\"HF_TOKEN\"] = \"your_huggingface_token_here\""
    ]
   },
   {
@@ -237,6 +245,8 @@
     "    batch_size=64, micro_batch_size=2, epochs=1,\n",
     "    ctx_size=8192,\n",
     "    save_steps=200, save_total_limit=1, eval_steps=200, logging_steps=2,\n",
+    "    prompt_template_file=\"scripts/chatml_template.j2\",\n",
+    "    add_chatml_tokens=True, add_pad_token=True, add_tool_calling_tokens=True\n",
     "))"
    ]
   }
diff --git a/train.py b/train.py
index cc4e756..e40810b 100644
--- a/train.py
+++ b/train.py
@@ -72,7 +72,7 @@ class TrainingRunArguments:
     add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
     add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
     add_tool_calling_tokens: bool = field(default=False, metadata={"help": "If set, tokens for tool calling will be added specifically"})
-    add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"})
+    prompt_template_file: Optional[str] = field(default=None, metadata={"help": "If set, the prompt template will be set as the model's Jinja2 template"})
     prefix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
     suffix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
 
@@ -454,8 +454,8 @@ def do_training_run(training_run_args: TrainingRunArguments):
             'additional_special_tokens': ['', '']
         })
 
-    if training_run_args.add_chatml_prompt_template:
-        with open("scripts/chatml_template.j2", "r") as f:
+    if training_run_args.prompt_template_file:
+        with open(training_run_args.prompt_template_file, "r") as f:
             tokenizer.chat_template = f.read()
 
     # resize embeddings if added tokens require it
@@ -474,7 +474,8 @@ def do_training_run(training_run_args: TrainingRunArguments):
     target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
     modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
     peft_config = LoraConfig(
-        task_type=TaskType. inference_mode=False,
+        task_type=TaskType.CAUSAL_LM,
+        inference_mode=False,
         r=training_run_args.lora_rank,
         lora_alpha=training_run_args.lora_alpha,
         lora_dropout=training_run_args.lora_dropout,
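
Note on the new `prompt_template_file` argument: the contents of `scripts/chatml_template.j2` are not part of this diff. For reference, the sketch below shows what a minimal ChatML-style Jinja2 template typically looks like and how it ends up on the tokenizer, mirroring what `do_training_run` now does with the file's contents (`tokenizer.chat_template = f.read()`). The `CHATML_TEMPLATE` constant and the `"gpt2"` tokenizer are illustrative assumptions, not the repo's actual template or model:

```python
from transformers import AutoTokenizer

# Illustrative ChatML template; the repo's scripts/chatml_template.j2 may differ.
# Jinja2 decodes the "\n" escapes inside its string literals into real newlines.
CHATML_TEMPLATE = (
    "{% for message in messages %}"
    "{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}"
)

# Any tokenizer works for a rendering demo; "gpt2" is just a small public one.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.chat_template = CHATML_TEMPLATE  # same assignment train.py makes from the file

print(tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}],
    tokenize=False,
    add_generation_prompt=True,
))
# <|im_start|>user
# Hello!<|im_end|>
# <|im_start|>assistant
```

Decoupling the template from the hard-coded `scripts/chatml_template.j2` path means a non-ChatML template can be swapped in per training run without touching train.py.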