From 2712f605a5a8dbef3b48913c1246afc627c8999a Mon Sep 17 00:00:00 2001 From: Alex O'Connell Date: Mon, 10 Feb 2025 17:09:49 -0500 Subject: [PATCH] fix evaluate + add train notebook --- data/generate_home_assistant_data.py | 4 +- evaluate.py | 6 +- train.ipynb | 111 +++++ train.py | 579 ++++++++++++++------------- 4 files changed, 411 insertions(+), 289 deletions(-) create mode 100644 train.ipynb diff --git a/data/generate_home_assistant_data.py b/data/generate_home_assistant_data.py index 4d3dcc5..e601816 100644 --- a/data/generate_home_assistant_data.py +++ b/data/generate_home_assistant_data.py @@ -1151,7 +1151,7 @@ def load_dataset_piles(language): # TODO: answer questions about more than one thing in the state list at once # TODO: add examples for rooms/groups of devices. i.e. "turn off all the lights in the kitchen" # TODO: add time, weather, and calendar/reminders (next 3 events?) -def main(): +def main(args=None): parser = argparse.ArgumentParser(description="Generate the full dataset from the CSV piles") parser.add_argument("--sample", action="store_true", help="Set this flag to enable generation of the train dataset.") parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset.") @@ -1171,7 +1171,7 @@ def main(): dataset_format_group.add_argument('--raw_corpus', action='store_const', const='raw', dest='format') dataset_format_group.add_argument('--sharegpt', action='store_const', const='sharegpt', dest='format') - args = parser.parse_args() + args = parser.parse_args(args=args) if not args.sample and not args.train and not args.test and not args.merge and not args.dpo: parser.print_usage() diff --git a/evaluate.py b/evaluate.py index ffb60e0..cecdb51 100644 --- a/evaluate.py +++ b/evaluate.py @@ -290,8 +290,8 @@ def load_model(model_name, is_lora, is_hf, load_in_8bit, checkpoint_name): top_k=40, top_p=1.0, repetition_penalty=1.15, - # eos_token_id=trained_model.config.eos_token_id, - eos_token_id=128009, + eos_token_id=trained_model.config.eos_token_id, + # eos_token_id=128009, pad_token_id=trained_model.config.pad_token_id if trained_model.config.pad_token_id else trained_model.config.eos_token_id, ) @@ -350,7 +350,7 @@ def main(): print(f"Evaluation already exists for {output_folder}. 
Skipping...") continue - trained_model, trained_tokenizer = load_model(args.model, args.lora, ckpt, False) + trained_model, trained_tokenizer = load_model(args.model, args.lora, False, False, ckpt) evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_size, False) diff --git a/train.ipynb b/train.ipynb new file mode 100644 index 0000000..2e934b8 --- /dev/null +++ b/train.ipynb @@ -0,0 +1,111 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "469a9a97-0f6b-475f-8aef-a796c1c5244f", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -r requirements.txt\n", + "\n", + "import os, re\n", + "from train import TrainingRunArguments, do_training_run\n", + "\n", + "def get_next_run_name(model):\n", + " pattern = re.compile(model + r\"-rev(\\d+)$\")\n", + " max_rev = 0\n", + "\n", + " for folder in os.listdir(\"models/\"):\n", + " match = pattern.search(folder)\n", + " if match:\n", + " max_rev = max(max_rev, int(match.group(1)))\n", + "\n", + " return f\"{model}-rev{max_rev + 1}\"\n", + "\n", + "os.environ[\"HF_HOME\"] = \"/workspace/\"" + ] + }, + { + "cell_type": "markdown", + "id": "ed0807bf", + "metadata": {}, + "source": [ + "## Generate Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aaafce74", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -r data/requirements.txt\n", + "from data.generate_home_assistant_data import main as generate_data\n", + "\n", + "generate_data([\"--train\", \"--test\", \"--large\", \"--sharegpt\", \"--language\", \"english\", \"german\", \"french\", \"spanish\"])" + ] + }, + { + "cell_type": "markdown", + "id": "ff011772", + "metadata": {}, + "source": [ + "## Llama 3.2 1B" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48839ce2-1939-4d7f-817c-97b047bafd42", + "metadata": {}, + "outputs": [], + "source": [ + "# python3 train.py \\\n", + "# --run_name Home-Llama-3.2-1B-rev1 \\\n", + "# --base_model meta-llama/Llama-3.2-1B-Instruct \\\n", + "# --bf16 \\\n", + "# --train_dataset data/home_assistant_train.jsonl \\\n", + "# --test_dataset data/home_assistant_test.jsonl \\\n", + "# --learning_rate 2e-5 --learning_rate_warmup 0.03 --batch_size 64 --epochs 1 \\\n", + "# --micro_batch_size 2 \\\n", + "# --ctx_size 2048 \\\n", + "# --save_steps 200 --save_total_limit 1 --eval_steps 200 --logging_steps 2\n", + "\n", + "do_training_run(TrainingRunArguments(\n", + " run_name=get_next_run_name(\"Home-Llama-3.2-1B\"),\n", + " base_model=\"meta-llama/Llama-3.2-1B-Instruct\",\n", + " bf16=True,\n", + " train_dataset=\"data/home_assistant_train.jsonl\",\n", + " test_dataset=\"data/home_assistant_test.jsonl\",\n", + " learning_rate=2e-5, learning_rate_warmup=0.03, \n", + " batch_size=64, micro_batch_size=2, epochs=1,\n", + " ctx_size=2048,\n", + " save_steps=200, save_total_limit=1, eval_steps=200, logging_steps=2,\n", + "))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/train.py b/train.py index bcf1fb9..ec0f8bf 100644 --- a/train.py +++ b/train.py @@ -9,14 +9,14 @@ import time import shutil from torch.utils.data import SequentialSampler, Subset, 
RandomSampler
 from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, \
-    PreTrainedTokenizerFast, HfArgumentParser, GPTQConfig, AutoConfig, TrainerCallback, BitsAndBytesConfig
-from transformers.trainer_utils import EvalPrediction
+    HfArgumentParser, GPTQConfig, AutoConfig, TrainerCallback, BitsAndBytesConfig
 from transformers.integrations.integration_utils import TensorBoardCallback
-from datasets import load_dataset, Dataset
+from datasets import load_dataset
 from dataclasses import dataclass, field
-from typing import Dict, Optional, Sequence, Sized, Iterator
+from typing import Dict, Optional, Sequence
 
+IS_DDP_ENABLED = "LOCAL_RANK" in os.environ
 MULTI_GPU_WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1"))
 MULTI_GPU_RANK = int(os.environ.get("RANK", "0"))
 IS_MULTI_GPU = os.environ.get("RANK") != None
@@ -30,21 +30,23 @@ class TrainingRunArguments:
     train_dataset: str = field(metadata={"help": "The JSON file containing the training dataset"})
     test_dataset: str = field(default=None, metadata={"help": "The JSON file containing the evaluation dataset"})
     ctx_size: int = field(default=2048, metadata={"help": "The number of tokens to pad & truncate the input examples to"})
-    bf16: bool = field(default=False, metadata={"help": "If set, the model will the loaded and trained in bf16 instead of fp16"})
+    bf16: bool = field(default=False, metadata={"help": "If set, the model will be loaded and trained in bf16 instead of fp32"})
     batch_size: int = field(default=8, metadata={"help": "The simulated 'batch size' that we will train on. will tweak gradient accumulations steps"})
     micro_batch_size: int = field(default=2, metadata={"help": "The actual batch size that will fit into VRAM on this machine"})
     epochs: int = field(default=1, metadata={"help": "The number of times to train the model on each example"})
     learning_rate: float = field(default=1e-5, metadata={"help": "The starting learning rate (speed at which the model trains)"})
     learning_rate_schedule: str = field(default="cosine", metadata={"help": "How fast the learning rate is reduced during training"})
     learning_rate_warmup: float = field(default=0.0, metadata={"help": "The starting learning rate (speed at which the model trains)"})
-    weight_decay: float = field(default=0.1, metadata={"help": ""})
-    gradient_clip: float = field(default=1.0, metadata={"help": ""})
+    weight_decay: float = field(default=0.1, metadata={"help": "Weight decay rate for regularization. The rate at which all neuron weights are pulled toward zero."})
+    dropout: float = field(default=0.01, metadata={"help": "Dropout percent for regularization. 
Determines the fraction of neurons randomly deactivated during training."})
+    gradient_clip: float = field(default=1.0, metadata={"help": "Maximum gradient norm for clipping to prevent exploding gradients during training."})
     resume_from_checkpoint: str = field(default="", metadata={"help": "The name of the checkpoint to resume training from"})
     eval_steps: int = field(default=200, metadata={"help": "The number of steps in between evaluations of the model; set to -1 to evaluate every epoch"})
     save_steps: int = field(default=-1, metadata={"help": "The number of steps in between model checkpoints; set to -1 to save every epoch"})
     save_total_limit: int = field(default=1, metadata={"help": "The number of recent checkpoints of the model to save (not including the final model)"})
     logging_steps: int = field(default=5, metadata={"help": "Sets the number of steps in between log output for the training run"})
-    group_by_length: bool = field(default=False, metadata={"help": "If enabled, the training data will be grouped by length to optimize use of padding"})
+    group_by_length: bool = field(default=False, metadata={"help": "If enabled, the training data will be grouped by length to optimize use of padding. Runs from longest to shortest examples."})
+    gradient_checkpointing: bool = field(default=False, metadata={"help": "Enables gradient checkpointing to save VRAM at the cost of re-computing activations during the backward pass"})
     pre_allocate_cuda_buffers: bool = field(default=True, metadata={"help": "If enabled, runs a forward and backward pass on the model before training to force pytorch to allocate the correct size CUDA buffers up front"})
 
     # Quantization
@@ -61,21 +63,22 @@ class TrainingRunArguments:
     lora_modules_to_save: str = field(default=None, metadata={"help": "Additional modules to save"})
     lora_merge: bool = field(default=False, metadata={"help": "If set, the Lora will be merged back into the base model an saved"})
 
+    # dpo config
     dpo: bool = field(default=False, metadata={"help": "If set, performs Direct Preference Optimization instead of Supervised Fine Tuning"})
     beta: float = field(default=0.1, metadata={"help": "The implicit reward value used during DPO training"})
     dpo_loss: str = field(default="sigmoid", metadata={"help": "The loss type to use during DPO training"})
 
+    # token options
     add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
     add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
     add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"})
-    gradient_checkpointing: bool = field(default=False, metadata={"help": "Enables gradient checkpointing which saves quite a lot of VRAM"})
+    prefix_ids: str = field(default=None, metadata={"help": "The prefix tokens that surround the response from the assistant, for SFT when the model cannot correctly recognize the response."})
+    suffix_ids: str = field(default=None, metadata={"help": "The suffix tokens that surround the response from the assistant, for SFT when the model cannot correctly recognize the response."})
 
+    # custom trainer tweaks
     sync_to_bucket: str = field(default=None, metadata={"help": "If set, checkpoints will be synced to the s3 bucket specified by this argument"})
     flops_baseline: str = field(default=None, metadata={"help": "The baseline flops for the GPUs used for the training 
run. Outputs MFU"}) - prefix_ids:str = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognise response."}) - suffix_ids:str = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognise response."}) - class UploadToS3Callback(TrainerCallback): def __init__(self, s3_bucket, s3_prefix, save_total_limit=None): @@ -146,153 +149,25 @@ class MFUCallback(TrainerCallback): self.start_time = current_time self.last_total_flos = state.total_flos - - -parser = HfArgumentParser([TrainingRunArguments]) -training_run_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) - -if sum([training_run_args.load_in_8bit, training_run_args.load_in_4bit, training_run_args.load_as_gptq]) > 1: - raise Exception("Please select exactly one of 'load_in_8bit', 'load_in_4bit', or 'load_as_gptq") - -if IS_MASTER_PROCESS: - print(f"Loading model '{training_run_args.base_model}'...") - -model_kwargs = {} -if training_run_args.load_in_8bit: - model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) -elif training_run_args.load_in_4bit: - model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16) -elif training_run_args.load_as_gptq: - model_kwargs["quantization_config"] = GPTQConfig(bits=4, disable_exllama=True) -if training_run_args.bf16: - model_kwargs["torch_dtype"] = torch.bfloat16 -else: - model_kwargs["torch_dtype"] = torch.float16 - -# model_kwargs["resid_pdrop"] = 0.0 -# model_kwargs["revision"] = "accfee56d8988cae60915486310362db5831b1bd" -model_kwargs["use_cache"] = False +def ddp_print(*args, **kwargs): + if not IS_DDP_ENABLED or IS_MASTER_PROCESS: + print(*args, **kwargs) def find_max_vram(min_buffer_mib=800): max_memory = {} for i in range(torch.cuda.device_count()): - total_mem = (torch.cuda.get_device_properties(i).total_memory / (1024 * 1024)) - suggestion = round((total_mem - 1000) / 1000) * 1000 - suggestion = min(suggestion, total_mem - min_buffer_mib) + gpu_properties = torch.cuda.get_device_properties(i) + total_memory_mib = (gpu_properties.total_memory / (1000 * 1000)) + suggestion = max(total_memory_mib - 1000, min_buffer_mib) - if IS_MASTER_PROCESS: - print(f"Model will target using {suggestion}MiB of VRAM on GPU {i}") + ddp_print(f"GPU {i}: {gpu_properties.name}, Total Memory: {gpu_properties.total_memory / (1024**3):.2f} GB") + ddp_print(f"Model will target using {suggestion}MiB of VRAM on GPU {i}") max_memory[i] = f'{suggestion}MiB' return max_memory -if "LOCAL_RANK" not in os.environ: - model_kwargs["device_map"] = "auto" - -model = AutoModelForCausalLM.from_pretrained( - training_run_args.base_model, - max_memory=find_max_vram(), - token=os.environ.get("HF_TOKEN"), - **model_kwargs -) -tokenizer = AutoTokenizer.from_pretrained(training_run_args.base_model, token=os.environ.get("HF_TOKEN")) - -if training_run_args.add_pad_token: - tokenizer.add_special_tokens({'pad_token': '<|pad|>'}) - model.config.pad_token_id = tokenizer.pad_token_id - -if training_run_args.add_chatml_tokens: - tokenizer.add_special_tokens({ - 'bos_token': '<|im_start|>', - 'eos_token': '<|im_end|>' - }) - - model.config.bos_token_id = tokenizer.bos_token_id - model.config.eos_token_id = tokenizer.eos_token_id - -if training_run_args.add_chatml_prompt_template: - tokenizer.chat_template = ( - "{% for message in messages %}" - 
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}" - "{% endfor %}" - "{% if add_generation_prompt %}" - "{{ '<|im_start|>assistant\n' }}" - "{% endif %}" - ) - -embeddings_len = math.ceil(len(tokenizer) / 32) * 32 -if model.get_input_embeddings().num_embeddings < embeddings_len: - model.resize_token_embeddings(embeddings_len) -else: - model.tie_weights() - -# model.tie_weights() - -original_model = model -peft_config = None -if training_run_args.use_lora: - from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training - if IS_MASTER_PROCESS: - print("Creating LoRA for model...") - target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None - modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None - peft_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - inference_mode=False, - r=training_run_args.lora_rank, - lora_alpha=training_run_args.lora_alpha, - lora_dropout=training_run_args.lora_dropout, - target_modules=target_modules, - modules_to_save=modules_to_save, - ) - if training_run_args.load_in_8bit or training_run_args.load_in_4bit or training_run_args.load_as_gptq: - model = prepare_model_for_kbit_training( - model, use_gradient_checkpointing=training_run_args.gradient_checkpointing - ) - model = get_peft_model(model, peft_config) - model.enable_input_require_grads() - - model.print_trainable_parameters() - - -base_dir = "loras" if training_run_args.use_lora else "models" -model_dir = f"./{base_dir}/{training_run_args.run_name}" - -training_kwargs = {} - -if training_run_args.test_dataset: - training_kwargs.update({ - "per_device_eval_batch_size": training_run_args.micro_batch_size, - "eval_strategy": ("steps" if training_run_args.eval_steps != -1 else "epoch"), - "eval_steps": (training_run_args.eval_steps if training_run_args.eval_steps != -1 else None), - "bf16_full_eval": training_run_args.bf16, - }) - -training_args = TrainingArguments( - per_device_train_batch_size=training_run_args.micro_batch_size, - gradient_accumulation_steps=training_run_args.batch_size//training_run_args.micro_batch_size, - gradient_checkpointing=training_run_args.gradient_checkpointing, - weight_decay=training_run_args.weight_decay, - max_grad_norm=training_run_args.gradient_clip, - save_strategy=("steps" if training_run_args.save_steps != -1 else "epoch"), - save_steps=(training_run_args.save_steps if training_run_args.save_steps != -1 else None), - save_safetensors=True, - logging_steps=training_run_args.logging_steps, - output_dir=model_dir, - num_train_epochs=training_run_args.epochs, - save_total_limit=training_run_args.save_total_limit, - report_to='none', - learning_rate=training_run_args.learning_rate, - lr_scheduler_type=training_run_args.learning_rate_schedule, - warmup_ratio=training_run_args.learning_rate_warmup, - log_level="info", - bf16=training_run_args.bf16, - group_by_length=training_run_args.group_by_length, - # include_num_input_tokens_seen=True, - **training_kwargs, -) class DataCollatorForSupervisedFineTuning(object): """Collate examples for supervised fine-tuning.""" @@ -414,14 +289,8 @@ class DataCollatorForSupervisedFineTuning(object): attention_mask=input_ids.ne(self.tokenizer.pad_token_id or self.tokenizer.eos_token_id), ) -if IS_MASTER_PROCESS: - print("Loading dataset...") -data_files = { "train": training_run_args.train_dataset } -if training_run_args.test_dataset: - data_files["test"] = 
training_run_args.test_dataset -datasets = load_dataset("json", data_files=data_files) -def tokenize_raw_example(batch): +def tokenize_raw_example(batch, tokenizer=None): return tokenizer( text=batch["text"], max_length=training_run_args.ctx_size, @@ -429,7 +298,7 @@ def tokenize_raw_example(batch): add_special_tokens=False, ) -def tokenize_sharegpt_example(batch): +def tokenize_sharegpt_example(batch, tokenizer=None): # TODO: figure out how to properly batch this result = [] for example in batch["conversations"]: @@ -444,7 +313,7 @@ def tokenize_sharegpt_example(batch): return {"input_ids": result} -def template_dpo_example(batch): +def template_dpo_example(batch, tokenizer=None): # TODO: figure out how to properly batch this result = [] for example in zip(batch["system"], batch["question"]): @@ -464,24 +333,6 @@ def template_dpo_example(batch): return {"prompt": result} -training_callbacks = [] -if training_run_args.sync_to_bucket: - training_callbacks.append(UploadToS3Callback( - s3_bucket=training_run_args.sync_to_bucket, - s3_prefix=training_run_args.run_name, - save_total_limit=training_run_args.save_total_limit - )) - -if training_run_args.flops_baseline: - # A100 GPU bfloat16 peak flops is 312 TFLOPS (312e12) - # 4090 GPU bfloat16 peak flops is 165.2 TFLOPS (1652e11) - # 3090 GPU bfloat16 peak flops is 71 TFLOPS (71e12) - - training_callbacks.append(MFUCallback(peak_flops=float(training_run_args.flops_baseline))) - - -# log to tensorboard (but after MFU) -training_callbacks.append(TensorBoardCallback()) class CustomSFTTrainer(Trainer): """Implement different training tweaks""" @@ -515,7 +366,7 @@ class CustomSFTTrainer(Trainer): def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): """ Saw this in the chinchilla paper. 
It says not to go over 25% overshoot
-        Should speed up training by skipping the final fine tuning part that doesn't affect accuracy much
+        Should improve training efficiency by skipping the final fine tuning part that doesn't affect accuracy much
         """
         return super().create_scheduler(int(num_training_steps * self.learning_rate_overshoot), optimizer=optimizer)
 
@@ -524,8 +375,8 @@
         examples_length = len(inputs["input_ids"][0])
         batch_size = len(inputs["input_ids"])
 
-        # mfu is approximated using thoughtput and param count
-        # the number of paramters is approximately the number of multiply-accumulates (MAC) in the network
+        # mfu is approximated using throughput and param count
+        # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network
         # each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param
         # there are 3 passes of a NN (fwd, bwd, delta) - we multiply by 3 ie 2 * 3 * n_param
         # this gets us FLOPs / token
@@ -539,128 +390,288 @@
         result = (3 * flops_per_seq + 3 * attn_flops_per_seq) * batch_size
         return result
 
-if not training_run_args.dpo:
-    if IS_MASTER_PROCESS:
-        print("Tokenizing datasets...")
-    if "text" in datasets["train"].column_names:
-        tokenize_function = tokenize_raw_example
-        columns_to_remove = ["text"]
-    elif "conversations" in datasets["train"].column_names:
-        tokenize_function = tokenize_sharegpt_example
-        columns_to_remove = ["conversations"]
+def do_training_run(training_run_args: TrainingRunArguments):
+    # validate args + build model kwargs
+    if sum([training_run_args.load_in_8bit, training_run_args.load_in_4bit, training_run_args.load_as_gptq]) > 1:
+        raise Exception("Please select at most one of 'load_in_8bit', 'load_in_4bit', or 'load_as_gptq'")
+
+    model_kwargs = {}
+    if training_run_args.load_in_8bit:
+        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+    elif training_run_args.load_in_4bit:
+        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
+    elif training_run_args.load_as_gptq:
+        model_kwargs["quantization_config"] = GPTQConfig(bits=4, disable_exllama=True)
+
+    if training_run_args.bf16:
+        model_kwargs["torch_dtype"] = torch.bfloat16
+    elif training_run_args.use_lora and "quantization_config" not in model_kwargs:
+        model_kwargs["torch_dtype"] = torch.float16
     else:
-        raise Exception("Unknown dataset input format (not raw corpus or sharegpt)")
+        # auto detect 'best' format with fallback to fp32
+        model_kwargs["torch_dtype"] = "auto"
 
-    tokenized_test_dataset = None
-    num_proc = DATASET_PROCESSING_THREADS // MULTI_GPU_WORLD_SIZE
-    tokenized_train_dataset = datasets["train"].map(tokenize_function, batched=True, num_proc=num_proc).remove_columns(columns_to_remove)
-    if training_run_args.test_dataset:
-        tokenized_test_dataset = datasets["test"].map(tokenize_function, batched=True, num_proc=num_proc).remove_columns(columns_to_remove)
+    model_kwargs["resid_pdrop"] = training_run_args.dropout
+    model_kwargs["use_cache"] = False
 
-    example_lengths = [ len(example) for example in tokenized_train_dataset["input_ids"] ]
-    tokens_in_train_set, longest_example = sum(example_lengths), max(example_lengths)
-    if IS_MASTER_PROCESS:
-        print(f"Train dataset has {int(tokens_in_train_set / 1000000)}M tokens. 
Longest Example: {longest_example} tokens") - - provided_prefix_ids = None - provided_suffix_ids = None - try: - if training_run_args.prefix_ids: - provided_prefix_ids = [ int(x) for x in training_run_args.prefix_ids.split(",") ] - if training_run_args.suffix_ids: - provided_suffix_ids = [ int(x) for x in training_run_args.suffix_ids.split(",") ] - except ValueError as ex: - print(f"Error parsing prefix_ids or suffix_ids: '{ex}'") - exit(-1) - - data_collator = DataCollatorForSupervisedFineTuning( - tokenizer=tokenizer, - prefix_ids=provided_prefix_ids, - suffix_ids=provided_suffix_ids, + if not IS_DDP_ENABLED: + model_kwargs["device_map"] = "auto" + + # load the model + ddp_print(f"Loading model '{training_run_args.base_model}'...") + + model = AutoModelForCausalLM.from_pretrained( + training_run_args.base_model, + max_memory=find_max_vram(), + token=os.environ.get("HF_TOKEN"), + **model_kwargs ) + tokenizer = AutoTokenizer.from_pretrained(training_run_args.base_model, token=os.environ.get("HF_TOKEN")) - trainer = CustomSFTTrainer( - model=model, - args=training_args, - train_dataset=tokenized_train_dataset, - eval_dataset=tokenized_test_dataset, - data_collator=data_collator, - callbacks=training_callbacks, - ) -else: - from trl import DPOTrainer - max_prompt_length = 0 + # mess with tokens + prompt template + if training_run_args.add_pad_token: + tokenizer.add_special_tokens({'pad_token': '<|pad|>'}) + model.config.pad_token_id = tokenizer.pad_token_id - train_dataset = datasets["train"].map(lambda x: { "prompt_len": len(x["system"]) }) + if training_run_args.add_chatml_tokens: + tokenizer.add_special_tokens({ + 'bos_token': '<|im_start|>', + 'eos_token': '<|im_end|>' + }) - test_dataset = None - if training_run_args.test_dataset: - test_dataset = datasets["test"] + model.config.bos_token_id = tokenizer.bos_token_id + model.config.eos_token_id = tokenizer.eos_token_id - max_prompt_length = max(train_dataset["prompt_len"]) + if training_run_args.add_chatml_prompt_template: + tokenizer.chat_template = ( + "{% for message in messages %}" + "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ '<|im_start|>assistant\n' }}" + "{% endif %}" + ) - print("Templating DPO Examples...") - templated_test_dataset = None - templated_train_dataset = train_dataset.map(template_dpo_example, batched=True).remove_columns(["system", "question"]) - if training_run_args.test_dataset: - templated_test_dataset = datasets["test"].map(template_dpo_example, batched=True).remove_columns(["system", "question"]) - - # tokenizer.model_input_names = [ "chosen_input_ids" ] - - # group_by_length doesn't work here - # templated_train_dataset = templated_train_dataset.sort("prompt_len", reverse=True) - - training_args.length_column_name = "prompt_len" - model.enable_input_require_grads() - - trainer = DPOTrainer( - model, - ref_model=None, - # ref_model=original_model, - peft_config=peft_config, - args=training_args, - beta=training_run_args.beta, - loss_type=training_run_args.dpo_loss, - train_dataset=templated_train_dataset, - eval_dataset=templated_test_dataset, - tokenizer=tokenizer, - max_length=training_run_args.ctx_size, - max_prompt_length=max_prompt_length, - truncation_mode="keep_start", - callbacks=training_callbacks, - ) - -try: - checkpoint = training_run_args.resume_from_checkpoint - if checkpoint: - trainer.train(checkpoint) + # resize embeddings if added tokens require it + embeddings_len = math.ceil(len(tokenizer) 
/ 32) * 32 + if model.get_input_embeddings().num_embeddings < embeddings_len: + model.resize_token_embeddings(embeddings_len) else: - trainer.train() + model.tie_weights() + + # create LoRA model if config says so + original_model = model + peft_config = None + if training_run_args.use_lora: + from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training + ddp_print("Creating LoRA for model...") + target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None + modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=training_run_args.lora_rank, + lora_alpha=training_run_args.lora_alpha, + lora_dropout=training_run_args.lora_dropout, + target_modules=target_modules, + modules_to_save=modules_to_save, + ) + if training_run_args.load_in_8bit or training_run_args.load_in_4bit or training_run_args.load_as_gptq: + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=training_run_args.gradient_checkpointing + ) + model = get_peft_model(model, peft_config) + model.enable_input_require_grads() + + model.print_trainable_parameters() + + base_dir = "loras" if training_run_args.use_lora else "models" + model_dir = f"./{base_dir}/{training_run_args.run_name}" + + # set up HuggingFace Trainer args + training_kwargs = {} if training_run_args.test_dataset: - trainer.evaluate_all() + training_kwargs.update({ + "per_device_eval_batch_size": training_run_args.micro_batch_size, + "eval_strategy": ("steps" if training_run_args.eval_steps != -1 else "epoch"), + "eval_steps": (training_run_args.eval_steps if training_run_args.eval_steps != -1 else None), + "bf16_full_eval": training_run_args.bf16, + }) - if trainer.is_fsdp_enabled: - trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") + training_args = TrainingArguments( + per_device_train_batch_size=training_run_args.micro_batch_size, + gradient_accumulation_steps=training_run_args.batch_size//training_run_args.micro_batch_size, + gradient_checkpointing=training_run_args.gradient_checkpointing, + weight_decay=training_run_args.weight_decay, + max_grad_norm=training_run_args.gradient_clip, + save_strategy=("steps" if training_run_args.save_steps != -1 else "epoch"), + save_steps=(training_run_args.save_steps if training_run_args.save_steps != -1 else None), + save_safetensors=True, + logging_steps=training_run_args.logging_steps, + output_dir=model_dir, + num_train_epochs=training_run_args.epochs, + save_total_limit=training_run_args.save_total_limit, + report_to='none', + learning_rate=training_run_args.learning_rate, + lr_scheduler_type=training_run_args.learning_rate_schedule, + warmup_ratio=training_run_args.learning_rate_warmup, + log_level="info", + bf16=training_run_args.bf16, + group_by_length=training_run_args.group_by_length, + # include_num_input_tokens_seen=True, + **training_kwargs, + ) - if training_run_args.use_lora and training_run_args.lora_merge: - trainer.save_model() # save lora + # set up trainer callbacks + training_callbacks = [] + if training_run_args.sync_to_bucket: + training_callbacks.append(UploadToS3Callback( + s3_bucket=training_run_args.sync_to_bucket, + s3_prefix=training_run_args.run_name, + save_total_limit=training_run_args.save_total_limit + )) - merged_model = model.merge_and_unload(progressbar=True) - merged_model_dir = f"./models/{training_run_args.run_name}" - 
merged_model.save_pretrained(merged_model_dir, safe_serialization=True, max_shard_size="2GB")
+    if training_run_args.flops_baseline:
+        # A100 40/80GB GPU bfloat16 peak flops is 312 TFLOPS (312e12)
+        # 4090 24GB GPU bfloat16 peak flops is 165.2 TFLOPS (1652e11)
+        # A40 48GB GPU bfloat16 peak flops is 149.7 TFLOPS (149.7e12)
+        # 3090 24GB GPU bfloat16 peak flops is 71 TFLOPS (71e12)
+        training_callbacks.append(MFUCallback(peak_flops=float(training_run_args.flops_baseline)))
+
+    # log to tensorboard (but after MFU)
+    training_callbacks.append(TensorBoardCallback())
+
+    if not training_run_args.dpo:
+        ddp_print("Loading dataset...")
+        data_files = { "train": training_run_args.train_dataset }
+        if training_run_args.test_dataset:
+            data_files["test"] = training_run_args.test_dataset
+        datasets = load_dataset("json", data_files=data_files)
 
-        tokenizer.save_pretrained(merged_model_dir)
-    else:
-        trainer.save_model()
-        tokenizer.save_pretrained(model_dir)
+        # prepare the dataset
+        ddp_print("Tokenizing datasets...")
 
-except Exception as ex:
-    if trainer.is_fsdp_enabled:
-        raise ex # this doesn't play nice with FSDP so don't even try
+        if "text" in datasets["train"].column_names:
+            tokenize_function = tokenize_raw_example
+            columns_to_remove = ["text"]
+        elif "conversations" in datasets["train"].column_names:
+            tokenize_function = tokenize_sharegpt_example
+            columns_to_remove = ["conversations"]
+        else:
+            raise Exception("Unknown dataset input format (not raw corpus or sharegpt)")
+
+        tokenized_test_dataset = None
+        num_proc = DATASET_PROCESSING_THREADS // MULTI_GPU_WORLD_SIZE
+        tokenized_train_dataset = datasets["train"].map(tokenize_function, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer}).remove_columns(columns_to_remove)
+        if training_run_args.test_dataset:
+            tokenized_test_dataset = datasets["test"].map(tokenize_function, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer}).remove_columns(columns_to_remove)
+
+        example_lengths = [ len(example) for example in tokenized_train_dataset["input_ids"] ]
+        tokens_in_train_set, longest_example = sum(example_lengths), max(example_lengths)
+        ddp_print(f"Train dataset has {int(tokens_in_train_set / 1000000)}M tokens. 
Longest Example: {longest_example} tokens") + + provided_prefix_ids = None + provided_suffix_ids = None + try: + if training_run_args.prefix_ids: + provided_prefix_ids = [ int(x) for x in training_run_args.prefix_ids.split(",") ] + if training_run_args.suffix_ids: + provided_suffix_ids = [ int(x) for x in training_run_args.suffix_ids.split(",") ] + except ValueError as ex: + print(f"Error parsing prefix_ids or suffix_ids: '{ex}'") + exit(-1) + + trainer = CustomSFTTrainer( + model=model, + args=training_args, + train_dataset=tokenized_train_dataset, + eval_dataset=tokenized_test_dataset, + data_collator=DataCollatorForSupervisedFineTuning( + tokenizer=tokenizer, + prefix_ids=provided_prefix_ids, + suffix_ids=provided_suffix_ids, + ), + callbacks=training_callbacks, + ) + else: + raise NotImplementedError("DPO Trainer doesn't work yet!") + # from trl import DPOTrainer + # max_prompt_length = 0 + + # train_dataset = datasets["train"].map(lambda x: { "prompt_len": len(x["system"]) }) + + # test_dataset = None + # if training_run_args.test_dataset: + # test_dataset = datasets["test"] + + # max_prompt_length = max(train_dataset["prompt_len"]) + + # print("Templating DPO Examples...") + # templated_test_dataset = None + # templated_train_dataset = train_dataset.map(template_dpo_example, batched=True).remove_columns(["system", "question"]) + # if training_run_args.test_dataset: + # templated_test_dataset = datasets["test"].map(template_dpo_example, batched=True).remove_columns(["system", "question"]) + + # # tokenizer.model_input_names = [ "chosen_input_ids" ] + + # # group_by_length doesn't work here + # # templated_train_dataset = templated_train_dataset.sort("prompt_len", reverse=True) + + # training_args.length_column_name = "prompt_len" + # model.enable_input_require_grads() + + # trainer = DPOTrainer( + # model, + # ref_model=None, + # # ref_model=original_model, + # peft_config=peft_config, + # args=training_args, + # beta=training_run_args.beta, + # loss_type=training_run_args.dpo_loss, + # train_dataset=templated_train_dataset, + # eval_dataset=templated_test_dataset, + # tokenizer=tokenizer, + # max_length=training_run_args.ctx_size, + # max_prompt_length=max_prompt_length, + # truncation_mode="keep_start", + # callbacks=training_callbacks, + # ) + + try: + trainer.train(resume_from_checkpoint=training_run_args.resume_from_checkpoint) + + if training_run_args.test_dataset: + trainer.evaluate_all() + + if trainer.is_fsdp_enabled: + trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") + + if training_run_args.use_lora and training_run_args.lora_merge: + trainer.save_model() # save lora + + merged_model = model.merge_and_unload(progressbar=True) + merged_model_dir = f"./models/{training_run_args.run_name}" + merged_model.save_pretrained(merged_model_dir, safe_serialization=True, max_shard_size="2GB") + + tokenizer.save_pretrained(merged_model_dir) + else: + trainer.save_model() + tokenizer.save_pretrained(model_dir) + + except Exception as ex: + if trainer.is_fsdp_enabled: + raise ex # this doesn't play nice with FSDP so don't even try + + if input("Something bad happened! Try and save it? (Y/n)").lower().startswith("y"): + trainer._save_checkpoint(model, None) + print("Saved Checkpoint!") + + raise ex - print("Something bad happened! 
Try and save it?") - import code, traceback - traceback.print_exc() - code.interact(local=locals()) \ No newline at end of file +if __name__ == "__main__": + parser = HfArgumentParser([TrainingRunArguments]) + training_run_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) + + do_training_run(training_run_args)