#!/usr/bin/env python3
import math
import copy
import torch
import os
import random
import time
import traceback

from torch.utils.data import SequentialSampler, Subset, RandomSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, \
    HfArgumentParser, GPTQConfig, AutoConfig, TrainerCallback, BitsAndBytesConfig
from transformers.integrations.integration_utils import TensorBoardCallback
from datasets import load_dataset
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

IS_DDP_ENABLED = "LOCAL_RANK" in os.environ
MULTI_GPU_WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1"))
MULTI_GPU_RANK = int(os.environ.get("RANK", "0"))
IS_MULTI_GPU = os.environ.get("RANK") is not None
IS_MASTER_PROCESS = MULTI_GPU_RANK == 0


@dataclass
class TrainingRunArguments:
    run_name: str = field(metadata={"help": "The folder to save the output model under"})
    base_model: str = field(metadata={"help": "The base model to load for fine-tuning"})
    train_dataset: str = field(metadata={"help": "The JSON file containing the training dataset"})
    test_dataset: str = field(default=None, metadata={"help": "The JSON file containing the evaluation dataset"})
    dataset_processing_threads: int = field(default=None, metadata={"help": "The number of threads to use to tokenize the dataset"})
    ctx_size: int = field(default=2048, metadata={"help": "The number of tokens to pad & truncate the input examples to"})
    bf16: bool = field(default=False, metadata={"help": "If set, the model will be loaded and trained in bf16 instead of fp32"})
    batch_size: int = field(default=8, metadata={"help": "The simulated 'batch size' that we will train on; used to derive the number of gradient accumulation steps"})
    micro_batch_size: int = field(default=2, metadata={"help": "The actual batch size that will fit into VRAM on this machine"})
    epochs: int = field(default=1, metadata={"help": "The number of times to train the model on each example"})
    learning_rate: float = field(default=1e-5, metadata={"help": "The starting learning rate (speed at which the model trains)"})
    learning_rate_schedule: str = field(default="cosine", metadata={"help": "How fast the learning rate is reduced during training"})
    learning_rate_warmup: float = field(default=0.0, metadata={"help": "The fraction of training steps spent warming the learning rate up to its starting value"})
    weight_decay: float = field(default=0.1, metadata={"help": "Weight Decay rate for regularization. Rate to reduce all neuron weights towards zero."})
    # dropout: float = field(default=0.01, metadata={"help": "Dropout percent for regularization. Determines the fraction of neurons randomly deactivated during training."})
    gradient_clip: float = field(default=1.0, metadata={"help": "Maximum gradient norm for clipping to prevent exploding gradients during training."})
    resume_from_checkpoint: str = field(default="", metadata={"help": "The name of the checkpoint to resume training from"})
    eval_steps: int = field(default=200, metadata={"help": "The number of steps in between evaluations of the model; set to -1 to evaluate every epoch"})
    save_steps: int = field(default=-1, metadata={"help": "The number of steps in between model checkpoints; set to -1 to save every epoch"})
    save_total_limit: int = field(default=1, metadata={"help": "The number of recent checkpoints of the model to save (not including the final model)"})
    logging_steps: int = field(default=5, metadata={"help": "Sets the number of steps in between log output for the training run"})
    group_by_length: bool = field(default=False, metadata={"help": "If enabled, the training data will be grouped by length to optimize use of padding. Runs from longest to shortest examples."})
    gradient_checkpointing: bool = field(default=False, metadata={"help": "Enables gradient checkpointing to save VRAM at the cost of re-computing activations during the backwards pass"})
    pre_allocate_cuda_buffers: bool = field(default=True, metadata={"help": "If enabled, runs a forward and backward pass on the model before training to force pytorch to allocate the correct size CUDA buffers up front"})

    # Quantization
    load_in_8bit: bool = field(default=False, metadata={"help": "Set to load the base model in 8-bit mode using bitsandbytes"})
    load_in_4bit: bool = field(default=False, metadata={"help": "Set to load the base model in 4-bit mode using bitsandbytes"})
    load_as_gptq: bool = field(default=False, metadata={"help": "Set to load the base model as a GPTQ using AutoGPTQ"})

    # lora config
    use_lora: bool = field(default=False, metadata={"help": "If set, then the trained model will be a LoRA"})
    lora_rank: int = field(default=4, metadata={"help": "Rank which determines LoRA matrix size. Rank typically starts at 8 but can go up to 256. Higher ranks can store more information but increase the computational and memory cost of LoRA."})
    lora_alpha: int = field(default=32, metadata={"help": "Alpha is a scaling factor for updates. Alpha directly impacts the adapter's contribution and is often set to 1x or 2x the rank value."})
    lora_dropout: float = field(default=0.05)
    lora_modules: str = field(default=None, metadata={"help": "Target modules: LoRA can be applied to various model components, including attention mechanisms (Q, K, V matrices), output projections, feed-forward blocks, and linear output layers. While initially focused on attention mechanisms, extending LoRA to other components has shown benefits. However, adapting more modules increases the number of trainable parameters and memory needs."})
    lora_modules_to_save: str = field(default=None, metadata={"help": "Additional modules to save"})
    lora_merge: bool = field(default=False, metadata={"help": "If set, the LoRA will be merged back into the base model and saved"})

    # dpo config
    dpo: bool = field(default=False, metadata={"help": "If set, performs Direct Preference Optimization instead of Supervised Fine Tuning"})
    beta: float = field(default=0.1, metadata={"help": "The implicit reward value used during DPO training"})
    dpo_loss: str = field(default="sigmoid", metadata={"help": "The loss type to use during DPO training"})

    # token options
    add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
    add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
    add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"})
    prefix_ids: str = field(default=None, metadata={"help": "The prefix tokens that surround the response from the assistant for SFT, for when the model cannot correctly recognize the response."})
    suffix_ids: str = field(default=None, metadata={"help": "The suffix tokens that surround the response from the assistant for SFT, for when the model cannot correctly recognize the response."})

    # custom trainer tweaks
    sync_to_bucket: str = field(default=None, metadata={"help": "If set, checkpoints will be synced to the s3 bucket specified by this argument"})
    bucket_save_limit: int = field(default=None, metadata={"help": "The number of recent checkpoints of the model to save in S3 (not including the final model)"})
    flops_baseline: str = field(default=None, metadata={"help": "The baseline flops for the GPUs used for the training run. Outputs MFU"})
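
# Illustrative invocation (not part of the original script): the flag names map 1:1 to the
# dataclass fields above via HfArgumentParser; the model name and dataset paths below are
# placeholders.
#
#   python3 train.py \
#       --run_name my-sft-run \
#       --base_model meta-llama/Llama-2-7b-hf \
#       --train_dataset data/train.json \
#       --test_dataset data/test.json \
#       --bf16 true \
#       --batch_size 64 --micro_batch_size 2 \
#       --epochs 1 --learning_rate 1e-5 \
#       --use_lora true --lora_rank 16 --lora_alpha 32 --lora_modules q_proj,v_proj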


class UploadToS3Callback(TrainerCallback):
    def __init__(self, s3_bucket, s3_prefix, save_total_limit=None):
        import boto3
        self.s3_client = boto3.client('s3')
        self.s3_bucket = s3_bucket
        self.s3_prefix = s3_prefix
        self.save_total_limit = save_total_limit

    def on_save(self, args, state, control, **kwargs):
        # Upload current checkpoint
        checkpoint = f"checkpoint-{state.global_step}"
        output_dir = f"{args.output_dir}/{checkpoint}"

        for root, dirs, files in os.walk(output_dir):
            for file in files:
                local_path = os.path.join(root, file)
                s3_path = os.path.join(self.s3_prefix, checkpoint, os.path.relpath(local_path, start=output_dir))
                self.s3_client.upload_file(local_path, self.s3_bucket, s3_path)
                print(f"Uploaded {local_path} to s3://{self.s3_bucket}/{s3_path}")

        # Delete prior checkpoints from S3
        if self.save_total_limit:
            s3_checkpoints = self.list_s3_checkpoints()
            if len(s3_checkpoints) > self.save_total_limit:
                sorted_checkpoints = sorted(s3_checkpoints)
                to_delete = sorted_checkpoints[:-self.save_total_limit]
                for checkpoint in to_delete:
                    self.delete_checkpoint_from_s3(checkpoint)

    def list_s3_checkpoints(self):
        paginator = self.s3_client.get_paginator('list_objects_v2')
        page_iterator = paginator.paginate(Bucket=self.s3_bucket, Prefix=self.s3_prefix + '/', Delimiter='/')
        return [prefix.get('Prefix').rstrip('/').split('/')[-1] for page in page_iterator for prefix in page.get('CommonPrefixes', [])]

    def delete_checkpoint_from_s3(self, checkpoint_name):
        resp = self.s3_client.list_objects_v2(Bucket=self.s3_bucket, Prefix=os.path.join(self.s3_prefix, checkpoint_name))
        for obj in resp.get('Contents', []):
            self.s3_client.delete_object(Bucket=self.s3_bucket, Key=obj['Key'])
            print(f"Deleted s3://{self.s3_bucket}/{obj['Key']}")


class MFUCallback(TrainerCallback):
    def __init__(self, peak_flops):
        self.total_iterations = 0
        self.start_time = time.time()
        self.flops_promised = peak_flops
        self.last_total_flos = 0

    def on_log(self, args, state, control, **kwargs):
        if state.global_step == 0:  # Avoid computation at the very beginning
            return

        current_time = time.time()
        elapsed_time = current_time - self.start_time

        # Calculate and log MFU
        new_flops = state.total_flos - self.last_total_flos
        kwargs['logs']['mfu'] = round(new_flops / elapsed_time / self.flops_promised, 4)

        self.start_time = current_time
        self.last_total_flos = state.total_flos


def ddp_print(*args, **kwargs):
    if not IS_DDP_ENABLED or IS_MASTER_PROCESS:
        print(*args, **kwargs)


def find_max_vram(min_buffer_mib=800):
    max_memory = {}
    for i in range(torch.cuda.device_count()):
        gpu_properties = torch.cuda.get_device_properties(i)
        total_memory_mib = (gpu_properties.total_memory / (1000 * 1000))
        suggestion = max(total_memory_mib - 1000, min_buffer_mib)

        ddp_print(f"GPU {i}: {gpu_properties.name}, Total Memory: {gpu_properties.total_memory / (1024**3):.2f} GB")
        ddp_print(f"Model will target using {suggestion}MiB of VRAM on GPU {i}")

        max_memory[i] = f'{suggestion}MiB'

    return max_memory


class DataCollatorForSupervisedFineTuning(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: AutoTokenizer
    prompt_split: str
    response_prefix: str
    response_suffix: str
    prefix_ids: list[int]
    suffix_ids: list[int]

    def __init__(self, *, tokenizer: AutoTokenizer, prefix_ids: Optional[list[int]] = None, suffix_ids: Optional[list[int]] = None):
        self.tokenizer = tokenizer

        if not prefix_ids and not suffix_ids:
            assistant_prompt = tokenizer.apply_chat_template(
                conversation=[{"role": "assistant", "content": r"%%%%%%%%%%%%%%%%"}],
                tokenize=False).split(r"%%%%%%%%%%%%%%%%")
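            # The sentinel string above is used to locate the assistant turn inside the rendered
            # chat template: everything before it becomes response_prefix and everything after it
            # becomes response_suffix. With the ChatML template configured further below, for
            # example, that would yield roughly "<|im_start|>assistant\n" and "<|im_end|>\n"
            # (illustrative; the exact strings depend on the tokenizer's chat template).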
            self.response_prefix = assistant_prompt[0]
            self.response_suffix = assistant_prompt[1]

            # check for inserted system prompt and remove it
            if tokenizer.eos_token in self.response_prefix:
                self.response_prefix = self.response_prefix.split(tokenizer.eos_token)[1].lstrip()

            # some chat templates ALWAYS add the bos token
            if tokenizer.bos_token in self.response_prefix:
                self.response_prefix = self.response_prefix.replace(tokenizer.bos_token, "")

        if prefix_ids:
            self.prefix_ids = prefix_ids
        else:
            self.prefix_ids = self.tokenizer(self.response_prefix, add_special_tokens=False)["input_ids"]

        if suffix_ids:
            self.suffix_ids = suffix_ids
        else:
            self.suffix_ids = self.tokenizer(self.response_suffix, add_special_tokens=False)["input_ids"]

    def _find_mask_ranges(self, input_ids):
        """
        Returns the spans of input_ids to mask out, i.e. everything except the responses from the
        assistant. The masked-out region includes the response_prefix but not the response_suffix,
        so the model treats the prefix as part of the prompt and learns to emit the suffix as its
        end-of-response token.
        """
        ranges = []
        i = 0

        while i < len(input_ids):
            try:
                # Find the start index of the prefix
                start_idx = input_ids.index(self.prefix_ids[0], i)
            except ValueError:
                break

            # Check if the entire prefix is present
            if input_ids[start_idx:start_idx + len(self.prefix_ids)] == self.prefix_ids:
                end_prefix_idx = start_idx + len(self.prefix_ids)
                start_response_idx = end_prefix_idx + 1

                try:
                    # Find the start index of the suffix
                    suffix_start_idx = input_ids.index(self.suffix_ids[0], end_prefix_idx)
                except ValueError:
                    ranges.append((start_response_idx, len(input_ids)))
                    break

                # Check if the entire suffix is present
                if input_ids[suffix_start_idx:suffix_start_idx + len(self.suffix_ids)] == self.suffix_ids:
                    ranges.append((start_response_idx, suffix_start_idx))
                    i = suffix_start_idx + len(self.suffix_ids)
                else:
                    i = suffix_start_idx + 1
            else:
                i = start_idx + 1

        inverse_ranges = []
        current = 0

        for start, end in sorted(ranges):
            if start > current:
                inverse_ranges.append((current, start - 1))
            current = max(current, end + 1)

        if current < len(input_ids):
            inverse_ranges.append((current, len(input_ids) - 1))

        return inverse_ranges

    def _pad(self, examples, pad_value):
        longest = max([len(ex) for ex in examples])

        result = []
        for example in examples:
            cur_len = len(example)
            result.append(example + [pad_value] * (longest - cur_len))
        return result

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids = [instance["input_ids"] for instance in instances]
        labels = copy.deepcopy(input_ids)

        for label in labels:
            mask_ranges = self._find_mask_ranges(label)
            for start, end in mask_ranges:
                if end - start == len(label) - 1:
                    print("warning! example had no assistant response in it!")
                    print(input_ids)
                label[start:end] = [-100] * (end - start)

        input_ids = torch.LongTensor(self._pad(input_ids, self.tokenizer.pad_token_id or self.tokenizer.eos_token_id))
        labels = torch.LongTensor(self._pad(labels, -100))

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id or self.tokenizer.eos_token_id),
        )


def tokenize_raw_example(batch, tokenizer=None, training_run_args=None):
    return tokenizer(
        text=batch["text"],
        max_length=training_run_args.ctx_size,
        truncation=True,
        add_special_tokens=False,
    )


def tokenize_sharegpt_example(batch, tokenizer=None, training_run_args=None):
    # TODO: figure out how to properly batch this
    result = []
    for example in batch["conversations"]:
        conversation = [
            {"role": x["from"], "content": x["value"]} for x in example
        ]
        result.append(tokenizer.apply_chat_template(
            conversation=conversation,
            max_length=training_run_args.ctx_size,
            truncation=True,
        ))

    return {"input_ids": result}


def template_dpo_example(batch, tokenizer=None, training_run_args=None):
    # TODO: figure out how to properly batch this
    result = []
    for example in zip(batch["system"], batch["question"]):
        conversation = [
            {"role": "system", "content": example[0]},
            {"role": "user", "content": example[1]},
        ]
        result.append(tokenizer.apply_chat_template(
            conversation=conversation,
            max_length=training_run_args.ctx_size,
            truncation=True,
            tokenize=False,
            add_generation_prompt=True,
        ))

    return {"prompt": result}


class CustomSFTTrainer(Trainer):
    """Implement different training tweaks"""

    def __init__(self, random_eval_sample_pct=0.1, learning_rate_overshoot=1.15, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.random_eval_sample_pct = random_eval_sample_pct
        self.evaluate_full_dataset = False
        self.learning_rate_overshoot = learning_rate_overshoot

    def evaluate_all(self):
        self.evaluate_full_dataset = True
        super().evaluate()
        self.evaluate_full_dataset = False

    # Randomly sample the eval dataset
    def _get_eval_sampler(self, eval_dataset):
        if self.evaluate_full_dataset:
            return SequentialSampler(eval_dataset)
        else:
            num_samples = int(self.random_eval_sample_pct * len(eval_dataset))
            random_indices = random.sample(range(len(eval_dataset)), num_samples)

            subset_eval_dataset = Subset(eval_dataset, random_indices)
            return SequentialSampler(subset_eval_dataset)

    def _get_train_sampler(self):
        if self.args.group_by_length:
            return super()._get_train_sampler()
        return RandomSampler(self.train_dataset, generator=torch.Generator(device='cpu'))

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Saw this in the Chinchilla paper (it says not to go over 25% overshoot). Should improve
        training efficiency by skipping the final fine-tuning part of the schedule, which doesn't
        affect accuracy much.
        """
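        # e.g. with the default learning_rate_overshoot of 1.15 and 1,000 optimizer steps, the
        # scheduler below is built for 1,150 steps, so training finishes before the LR fully decays.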
        return super().create_scheduler(int(num_training_steps * self.learning_rate_overshoot), optimizer=optimizer)

    def floating_point_ops(self, inputs):
        config = self.model.config
        examples_length = len(inputs["input_ids"][0])
        batch_size = len(inputs["input_ids"])

        # mfu is approximated using throughput and param count
        # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network
        # each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param
        # there are 3 passes of a NN (fwd, bwd, delta) - we multiply by 3 ie 2 * 3 * n_param
        # this gets us FLOPs / token
        flops_per_token = 2 * sum(p.numel() for p in self.model.parameters())
        flops_per_seq = flops_per_token * examples_length

        # there are 2 FLOPS per mac; there is A=Q*K^T and out=A*V ops (ie mult by 2)
        attn_flops_per_seq = config.num_hidden_layers * 2 * 2 * (config.hidden_size * (examples_length**2))

        # there are 2 ops in bwd pass and 1 in fwd pass so we mult by 3
        result = (3 * flops_per_seq + 3 * attn_flops_per_seq) * batch_size
        return result


def do_training_run(training_run_args: TrainingRunArguments):
    # validate args + build model kwargs
    if sum([training_run_args.load_in_8bit, training_run_args.load_in_4bit, training_run_args.load_as_gptq]) > 1:
        raise Exception("Please select at most one of 'load_in_8bit', 'load_in_4bit', or 'load_as_gptq'")

    model_kwargs = {}
    if training_run_args.load_in_8bit:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    elif training_run_args.load_in_4bit:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
    elif training_run_args.load_as_gptq:
        model_kwargs["quantization_config"] = GPTQConfig(bits=4, disable_exllama=True)

    if training_run_args.bf16:
        model_kwargs["torch_dtype"] = torch.bfloat16
    elif training_run_args.use_lora and "quantization_config" not in model_kwargs:
        model_kwargs["torch_dtype"] = torch.float16
    else:
        # auto detect 'best' format with fallback to fp32
        model_kwargs["torch_dtype"] = "auto"

    # model_kwargs["resid_pdrop"] = training_run_args.dropout
    model_kwargs["use_cache"] = False

    if not IS_DDP_ENABLED:
        model_kwargs["device_map"] = "auto"

    # load the model
    ddp_print(f"Loading model '{training_run_args.base_model}'...")
    model = AutoModelForCausalLM.from_pretrained(
        training_run_args.base_model,
        max_memory=find_max_vram(),
        token=os.environ.get("HF_TOKEN"),
        **model_kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(training_run_args.base_model, token=os.environ.get("HF_TOKEN"))

    # mess with tokens + prompt template
    if training_run_args.add_pad_token:
        tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
        model.config.pad_token_id = tokenizer.pad_token_id

    if training_run_args.add_chatml_tokens:
        tokenizer.add_special_tokens({
            'bos_token': '<|im_start|>',
            'eos_token': '<|im_end|>'
        })
        model.config.bos_token_id = tokenizer.bos_token_id
        model.config.eos_token_id = tokenizer.eos_token_id

    if training_run_args.add_chatml_prompt_template:
        tokenizer.chat_template = (
            "{% for message in messages %}"
            "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
            "{% endfor %}"
            "{% if add_generation_prompt %}"
            "{{ '<|im_start|>assistant\n' }}"
            "{% endif %}"
        )
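
    # For reference (illustrative, not part of the original script): rendering a single user turn
    # with the ChatML template above and add_generation_prompt=True produces
    #   "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"
    # which is the prefix/suffix layout the data collator later searches for.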

    # resize embeddings if added tokens require it
    embeddings_len = math.ceil(len(tokenizer) / 32) * 32
    if model.get_input_embeddings().num_embeddings < embeddings_len:
        model.resize_token_embeddings(embeddings_len)
    else:
        model.tie_weights()

    # create LoRA model if config says so
    original_model = model
    peft_config = None
    if training_run_args.use_lora:
        from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
        ddp_print("Creating LoRA for model...")

        target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
        modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=training_run_args.lora_rank,
            lora_alpha=training_run_args.lora_alpha,
            lora_dropout=training_run_args.lora_dropout,
            target_modules=target_modules,
            modules_to_save=modules_to_save,
        )

        if training_run_args.load_in_8bit or training_run_args.load_in_4bit or training_run_args.load_as_gptq:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=training_run_args.gradient_checkpointing
            )

        model = get_peft_model(model, peft_config)
        model.enable_input_require_grads()
        model.print_trainable_parameters()

    base_dir = "loras" if training_run_args.use_lora else "models"
    model_dir = f"./{base_dir}/{training_run_args.run_name}"

    # set up HuggingFace Trainer args
    training_kwargs = {}

    if training_run_args.test_dataset:
        training_kwargs.update({
            "per_device_eval_batch_size": training_run_args.micro_batch_size,
            "eval_strategy": ("steps" if training_run_args.eval_steps != -1 else "epoch"),
            "eval_steps": (training_run_args.eval_steps if training_run_args.eval_steps != -1 else None),
            "bf16_full_eval": training_run_args.bf16,
        })

    training_args = TrainingArguments(
        per_device_train_batch_size=training_run_args.micro_batch_size,
        gradient_accumulation_steps=training_run_args.batch_size // training_run_args.micro_batch_size,
        gradient_checkpointing=training_run_args.gradient_checkpointing,
        weight_decay=training_run_args.weight_decay,
        max_grad_norm=training_run_args.gradient_clip,
        save_strategy=("steps" if training_run_args.save_steps != -1 else "epoch"),
        save_steps=(training_run_args.save_steps if training_run_args.save_steps != -1 else None),
        save_safetensors=True,
        logging_steps=training_run_args.logging_steps,
        output_dir=model_dir,
        num_train_epochs=training_run_args.epochs,
        save_total_limit=training_run_args.save_total_limit,
        report_to='none',
        learning_rate=training_run_args.learning_rate,
        lr_scheduler_type=training_run_args.learning_rate_schedule,
        warmup_ratio=training_run_args.learning_rate_warmup,
        log_level="info",
        bf16=training_run_args.bf16,
        group_by_length=training_run_args.group_by_length,
        # include_num_input_tokens_seen=True,
        **training_kwargs,
    )

    # set up trainer callbacks
    training_callbacks = []
    if training_run_args.sync_to_bucket:
        training_callbacks.append(UploadToS3Callback(
            s3_bucket=training_run_args.sync_to_bucket,
            s3_prefix=training_run_args.run_name,
            save_total_limit=training_run_args.bucket_save_limit if training_run_args.bucket_save_limit else training_run_args.save_total_limit
        ))

    if training_run_args.flops_baseline:
        # A100 40/80GB GPU bfloat16 peak flops is 312 TFLOPS (312e12)
        # 4090 24GB GPU bfloat16 peak flops is 165.2 TFLOPS (165.2e12)
        # A40 48GB GPU bfloat16 peak flops is 149.7 TFLOPS (149.7e12)
        # 3090 24GB GPU bfloat16 peak flops is 71 TFLOPS (71e12)
        training_callbacks.append(MFUCallback(peak_flops=float(training_run_args.flops_baseline)))

    # log to tensorboard (but after MFU)
    training_callbacks.append(TensorBoardCallback())

    if not training_run_args.dpo:
        ddp_print("Loading dataset...")
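
        # The SFT path below expects a JSON/JSONL dataset with either a raw "text" column or a
        # ShareGPT-style "conversations" column. An illustrative "conversations" record (the role
        # names must match what the tokenizer's chat template expects) might look like:
        #   {"conversations": [{"from": "user", "value": "Hi"}, {"from": "assistant", "value": "Hello!"}]}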
dataset...") data_files = { "train": training_run_args.train_dataset } if training_run_args.test_dataset: data_files["test"] = training_run_args.test_dataset datasets = load_dataset("json", data_files=data_files) # prepare the dataset ddp_print("Tokenizing datasets...") if "text" in datasets["train"].column_names: tokenize_function = tokenize_raw_example columns_to_remove = ["text"] elif "conversations" in datasets["train"].column_names: tokenize_function = tokenize_sharegpt_example columns_to_remove = ["conversations"] else: raise Exception("Unknown dataset input format (not raw corpus or sharegpt)") tokenized_test_dataset = None num_proc = None if training_run_args.dataset_processing_threads: num_proc = training_run_args.dataset_processing_threads // MULTI_GPU_WORLD_SIZE tokenized_train_dataset = datasets["train"].map(tokenize_function, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove) if training_run_args.test_dataset: tokenized_test_dataset = datasets["test"].map(tokenize_function, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove) example_lengths = [ len(example) for example in tokenized_train_dataset["input_ids"] ] tokens_in_train_set, longest_example = sum(example_lengths), max(example_lengths) ddp_print(f"Train dataset has {int(tokens_in_train_set / 1000000)}M tokens. Longest Example: {longest_example} tokens") provided_prefix_ids = None provided_suffix_ids = None try: if training_run_args.prefix_ids: provided_prefix_ids = [ int(x) for x in training_run_args.prefix_ids.split(",") ] if training_run_args.suffix_ids: provided_suffix_ids = [ int(x) for x in training_run_args.suffix_ids.split(",") ] except ValueError as ex: print(f"Error parsing prefix_ids or suffix_ids: '{ex}'") exit(-1) trainer = CustomSFTTrainer( model=model, args=training_args, train_dataset=tokenized_train_dataset, eval_dataset=tokenized_test_dataset, data_collator=DataCollatorForSupervisedFineTuning( tokenizer=tokenizer, prefix_ids=provided_prefix_ids, suffix_ids=provided_suffix_ids, ), callbacks=training_callbacks, ) else: raise NotImplementedError("DPO Trainer doesn't work yet!") # from trl import DPOTrainer # max_prompt_length = 0 # train_dataset = datasets["train"].map(lambda x: { "prompt_len": len(x["system"]) }) # test_dataset = None # if training_run_args.test_dataset: # test_dataset = datasets["test"] # max_prompt_length = max(train_dataset["prompt_len"]) # print("Templating DPO Examples...") # templated_test_dataset = None # templated_train_dataset = train_dataset.map(template_dpo_example, batched=True).remove_columns(["system", "question"]) # if training_run_args.test_dataset: # templated_test_dataset = datasets["test"].map(template_dpo_example, batched=True).remove_columns(["system", "question"]) # # tokenizer.model_input_names = [ "chosen_input_ids" ] # # group_by_length doesn't work here # # templated_train_dataset = templated_train_dataset.sort("prompt_len", reverse=True) # training_args.length_column_name = "prompt_len" # model.enable_input_require_grads() # trainer = DPOTrainer( # model, # ref_model=None, # # ref_model=original_model, # peft_config=peft_config, # args=training_args, # beta=training_run_args.beta, # loss_type=training_run_args.dpo_loss, # train_dataset=templated_train_dataset, # eval_dataset=templated_test_dataset, # tokenizer=tokenizer, # max_length=training_run_args.ctx_size, # 
        #     max_prompt_length=max_prompt_length,
        #     truncation_mode="keep_start",
        #     callbacks=training_callbacks,
        # )

    try:
        trainer.train(resume_from_checkpoint=training_run_args.resume_from_checkpoint if training_run_args.resume_from_checkpoint else None)

        if training_run_args.test_dataset:
            trainer.evaluate_all()

        if trainer.is_fsdp_enabled:
            trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

        if training_run_args.use_lora and training_run_args.lora_merge:
            trainer.save_model()  # save lora

            merged_model = model.merge_and_unload(progressbar=True)

            merged_model_dir = f"./models/{training_run_args.run_name}"
            merged_model.save_pretrained(merged_model_dir, safe_serialization=True, max_shard_size="2GB")
            tokenizer.save_pretrained(merged_model_dir)
        else:
            trainer.save_model()
            tokenizer.save_pretrained(model_dir)

        if training_run_args.sync_to_bucket:
            import boto3
            s3_client = boto3.client('s3')
            for root, dirs, files in os.walk(model_dir):
                for file in files:
                    local_path = os.path.join(root, file)
                    s3_path = os.path.join(training_run_args.run_name, os.path.relpath(local_path, start="."))
                    s3_client.upload_file(local_path, training_run_args.sync_to_bucket, s3_path)
                    print(f"Uploaded {local_path} to s3://{training_run_args.sync_to_bucket}/{s3_path}")
    except Exception as ex:
        if trainer.is_fsdp_enabled:
            raise ex  # this doesn't play nice with FSDP so don't even try

        traceback.print_exc()
        if input("Something bad happened! Try and save it? (Y/n) ").lower().startswith("y"):
            trainer._save_checkpoint(model, None)
            print("Saved Checkpoint!")
        exit(-1)


if __name__ == "__main__":
    parser = HfArgumentParser([TrainingRunArguments])
    training_run_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)

    do_training_run(training_run_args)
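
# For multi-GPU DDP runs the script keys off the LOCAL_RANK/RANK/WORLD_SIZE environment variables,
# which launchers such as torchrun set automatically, e.g. (illustrative; "train.py" is a
# placeholder for this file's name):
#   torchrun --nproc_per_node=2 train.py --run_name my-sft-run --base_model ... --train_dataset ...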