mirror of https://github.com/acon96/home-llm.git (synced 2026-01-10 14:18:00 -05:00)
add mfu callback
train.py (53 changed lines)
@@ -5,11 +5,12 @@ import torch
 import os
 from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, \
     HfArgumentParser, GPTQConfig
+from transformers.integrations.integration_utils import TensorBoardCallback
 from trl import DPOTrainer
 from datasets import load_dataset
 
 from train_util import TrainingRunArguments, DataCollatorForSupervisedFineTuning, CustomSFTTrainer, \
-    UploadToS3Callback
+    UploadToS3Callback, MFUCallback
 
 """
 Phi Modules:
@@ -86,7 +87,7 @@ accelerate launch --config_file fsdp_config.yaml train.py \
     --learning_rate 1e-5 --batch_size 64 --epochs 1 \
     --micro_batch_size 2 --gradient_checkpointing --group_by_length \
     --ctx_size 2048 \
-    --save_steps 50 --save_total_limit 5 --eval_steps 100 --logging_steps 2
+    --save_steps 50 --save_total_limit 10 --eval_steps 100 --logging_steps 2
 """
 
 """
@@ -149,10 +150,12 @@ def find_max_vram(min_buffer_mib=800):
 
     return max_memory
 
+if torch.cuda.device_count() == 1:
+    model_kwargs["device_map"] = "auto"
+
 model = AutoModelForCausalLM.from_pretrained(
     training_run_args.base_model,
     trust_remote_code=True,
-    # device_map="auto",
     max_memory=find_max_vram(),
     **model_kwargs
 )
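
For reference on the hunk above: the max_memory= argument to from_pretrained expects a mapping from device id to a capacity cap, and find_max_vram() (only its tail is visible here) returns such a mapping. A rough illustration of that shape, with made-up values:

    # Illustration only; the real values come from find_max_vram(), not these made-up caps.
    # Keys are CUDA device indices (or "cpu"); values are capacity strings transformers understands.
    max_memory = {0: "22GiB", 1: "22GiB", "cpu": "64GiB"}
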
@@ -189,6 +192,8 @@ else:
 
 # model.tie_weights()
 
+original_model = model
+peft_config = None
 if training_run_args.use_lora:
     from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
     print("Creating LoRA for model...")
@@ -239,13 +244,14 @@ training_args = TrainingArguments(
     output_dir=model_dir,
     num_train_epochs=training_run_args.epochs,
     save_total_limit=training_run_args.save_total_limit,
-    report_to="tensorboard",
+    report_to='none',
     learning_rate=training_run_args.learning_rate,
     lr_scheduler_type=training_run_args.learning_rate_schedule,
     warmup_ratio=training_run_args.learning_rate_warmup,
     log_level="info",
     bf16=training_run_args.bf16,
     group_by_length=training_run_args.group_by_length,
+    include_num_input_tokens_seen=True,
     **training_kwargs,
 )
 
@@ -284,6 +290,17 @@ if training_run_args.sync_to_bucket:
         save_total_limit=training_run_args.save_total_limit
     ))
 
+if training_run_args.flops_baseline:
+    # A100 GPU bfloat16 peak flops is 312 TFLOPS (312e12)
+    # 4090 GPU bfloat16 peak flops is 165.2 TFLOPS (1652e11)
+    # 3090 GPU bfloat16 peak flops is 71 TFLOPS (71e12)
+
+    training_callbacks.append(MFUCallback(peak_flops=float(training_run_args.flops_baseline)))
+
+
+# log to tensorboard (but after MFU)
+training_callbacks.append(TensorBoardCallback())
+
 if not training_run_args.dpo:
     print("Tokenizing datasets...")
 
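
The "(but after MFU)" comment is about callback ordering: the Trainer calls each registered callback's on_log() in registration order and hands them the same logs dict, so the 'mfu' key that MFUCallback adds in place is already present when the TensorBoard callback writes the dict out. A toy sketch of that ordering, with stand-in functions and made-up values rather than the actual transformers internals:

    def mfu_on_log(logs):             # stand-in for MFUCallback.on_log
        logs["mfu"] = 0.38            # made-up value

    def tensorboard_on_log(logs):     # stand-in for TensorBoardCallback.on_log
        print("writing to tensorboard:", logs)

    logs = {"loss": 1.93}             # made-up value
    for callback in (mfu_on_log, tensorboard_on_log):   # registration order matters
        callback(logs)
    # prints: writing to tensorboard: {'loss': 1.93, 'mfu': 0.38}
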
@@ -325,35 +342,29 @@ else:
     max_prompt_length = 0
 
+    train_dataset = datasets["train"].map(lambda x: { "prompt_len": len(x["system"]) })
-    test_dataset = datasets["test"]
 
+    test_dataset = None
+    if training_run_args.test_dataset:
+        test_dataset = datasets["test"]
+
+    max_prompt_length = max(train_dataset["prompt_len"])
+
     trainer = DPOTrainer(
         model,
         ref_model=original_model,
         peft_config=peft_config,
         args=training_args,
         beta=training_run_args.beta,
-        train_dataset=datasets["train"],
-        eval_dataset=datasets["test"],
+        train_dataset=train_dataset,
+        eval_dataset=test_dataset,
         tokenizer=tokenizer,
         max_length=training_run_args.ctx_size,
         max_prompt_length=max_prompt_length,
         generate_during_eval=True,
         truncation_mode="keep_start",
         callbacks=training_callbacks,
     )
 
-# pre-allocate cuda buffers by running a forwards and backwards pass with the largest possible example length
-# the trainer dumps the cuda buffers before we start... need to figure out how to disable that
-# if training_run_args.pre_allocate_cuda_buffers:
-#     print("Allocating CUDA buffers...")
-#     inputs = tokenizer([""] * training_args.per_device_train_batch_size, return_tensors="pt", max_length=longest_example, padding="max_length", truncation=True)
-#     inputs = {k: v.to(model.device) for k, v in inputs.items()}
-#     inputs["labels"] = inputs["input_ids"]
-#     outputs = model(**inputs)
-#     loss = outputs.loss if isinstance(outputs, dict) else outputs[0]
-#     loss.backward()
-#     model.zero_grad()
-
 try:
     checkpoint = training_run_args.resume_from_checkpoint
     if checkpoint:
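
The prompt-length bookkeeping in the hunk above adds a prompt_len column via datasets.map and then takes its maximum; note that len(x["system"]) is a character count of the system field, not a token count. A small sketch of the same pattern on a tiny in-memory dataset with made-up rows:

    from datasets import Dataset

    # Two made-up rows standing in for the real DPO data.
    train_dataset = Dataset.from_dict({"system": ["you are a helpful assistant", "hi"]})
    train_dataset = train_dataset.map(lambda x: {"prompt_len": len(x["system"])})

    max_prompt_length = max(train_dataset["prompt_len"])
    print(max_prompt_length)   # 27 (characters in the longer system string)
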
@@ -361,7 +372,7 @@ try:
     else:
         trainer.train()
 
-    if training_run_args.train_dataset:
+    if training_run_args.test_dataset:
         trainer.evaluate_all()
 
     if training_run_args.use_lora and training_run_args.lora_merge:
@@ -377,7 +388,7 @@ try:
     tokenizer.save_pretrained(model_dir)
 
 except Exception as ex:
-    if len(torch.cuda.device_count()) > 1:
+    if torch.cuda.device_count() > 1:
         raise ex # this doesn't play nice with FSDP so don't even try
 
     print("Something bad happened! Try and save it?")
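
The one-line change above fixes a TypeError: torch.cuda.device_count() already returns an int, so wrapping it in len() can never work. A minimal sketch of the corrected check:

    import torch

    n_gpus = torch.cuda.device_count()   # plain int, e.g. 0, 1, or 8
    # len(n_gpus) would raise "TypeError: object of type 'int' has no len()",
    # so the count is compared directly.
    if n_gpus > 1:
        print("multi-GPU run: re-raise instead of attempting a mid-failure save")
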
train_util.py

@@ -1,15 +1,16 @@
 import copy
+import time
 import torch
 import os
 import random
 import shutil
 from torch.utils.data import SequentialSampler, Subset, RandomSampler
-from transformers import TrainerCallback, AutoTokenizer, Trainer
+from transformers import TrainerCallback, AutoTokenizer, Trainer, AutoModelForCausalLM, \
+    TrainerControl, TrainerState
 from dataclasses import dataclass, field
 from typing import Dict, Optional, Sequence
+from calflops import calculate_flops
 
 import boto3
 import os
 import shutil
 
 @dataclass
 class TrainingRunArguments:
@@ -25,6 +26,7 @@ class TrainingRunArguments:
     epochs: int = field(default=1, metadata={"help": "The number of times to train the model on each example"})
     learning_rate: float = field(default=1e-5, metadata={"help": "The starting learning rate (speed at which the model trains)"})
     learning_rate_schedule: str = field(default="cosine", metadata={"help": "How fast the learning rate is reduced during training"})
     learning_rate_warmup: float = field(default=0.0, metadata={"help": "The starting learning rate (speed at which the model trains)"})
     weight_decay: float = field(default=0.1, metadata={"help": ""})
     gradient_clip: float = field(default=1.0, metadata={"help": ""})
     resume_from_checkpoint: str = field(default="", metadata={"help": "The name of the checkpoint to resume training from"})
@@ -58,6 +60,7 @@ class TrainingRunArguments:
     gradient_checkpointing: bool = field(default=False, metadata={"help": "Enables gradient checkpointing which saves quite a lot of VRAM"})
 
     sync_to_bucket: str = field(default=None, metadata={"help": "If set, checkpoints will be synced to the s3 bucket specified by this argument"})
+    flops_baseline: str = field(default=None, metadata={"help": "The baseline flops for the GPUs used for the training run. Outputs MFU"})
 
 class DataCollatorForSupervisedFineTuning(object):
     """Collate examples for supervised fine-tuning."""
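
Since flops_baseline is declared as a string and train.py casts it with float(...) before constructing MFUCallback, scientific notation works directly on the command line. A small sketch, using the A100 figure quoted in train.py's comments:

    # The value arrives from the CLI as a string, e.g. "--flops_baseline 312e12";
    # float() accepts scientific notation, so no extra parsing is needed.
    peak = float("312e12")
    assert peak == 3.12e14
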
@@ -194,15 +197,36 @@ class CustomSFTTrainer(Trainer):
         Should speed up training by skipping the final fine tuning part that doesn't affect accuracy much
         """
         return super().create_scheduler(int(num_training_steps * self.learning_rate_overshoot), optimizer=optimizer)
+
+    def floating_point_ops(self, inputs):
+        config = self.model.config
+        examples_length = len(inputs["input_ids"][0])
+        batch_size = len(inputs["input_ids"])
+
+        # mfu is approximated using throughput and param count
+        # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network
+        # each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param
+        # there are 3 passes of a NN (fwd, bwd, delta) - we multiply by 3 ie 2 * 3 * n_param
+        # this gets us FLOPs / token
+        flops_per_token = 2 * sum(p.numel() for p in self.model.parameters())
+        flops_per_seq = flops_per_token * examples_length
+
+        # there are 2 FLOPS per mac; there is A=Q*K^T and out=A*V ops (ie mult by 2)
+        attn_flops_per_seq = config.num_hidden_layers * 2 * 2 * (config.hidden_size * (examples_length**2))
+
+        # there are 2 ops in bwd pass and 1 in fwd pass so we mult by 3
+        result = (3 * flops_per_seq + 3 * attn_flops_per_seq) * batch_size
+        return result
+
 class UploadToS3Callback(TrainerCallback):
     def __init__(self, s3_bucket, s3_prefix, save_total_limit=None):
         import boto3
         self.s3_client = boto3.client('s3')
         self.s3_bucket = s3_bucket
         self.s3_prefix = s3_prefix
         self.save_total_limit = save_total_limit
 
-    def on_save(self, args, state, control, **kwargs):
+    def on_save(self, args, state: TrainerState, control: TrainerControl, **kwargs):
         output_dir = kwargs['output_dir']
         checkpoint = os.path.basename(output_dir)
 
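
The comments in floating_point_ops above estimate training FLOPs from the parameter count (2 FLOPs per multiply-accumulate, roughly one MAC per parameter per token, times 3 for forward plus backward) plus a quadratic attention term. A standalone sketch of the same arithmetic with a hypothetical helper and made-up sizes, not code from the repo:

    def estimate_step_flops(n_params, hidden_size, num_layers, seq_len, batch_size):
        # ~2 FLOPs per MAC and roughly n_params MACs per token in the forward pass.
        flops_per_token = 2 * n_params
        flops_per_seq = flops_per_token * seq_len

        # attention: A = Q @ K^T and out = A @ V, each ~2 FLOPs per MAC,
        # O(hidden_size * seq_len^2) per layer.
        attn_flops_per_seq = num_layers * 2 * 2 * hidden_size * seq_len**2

        # forward + backward is counted as ~3x the forward cost.
        return (3 * flops_per_seq + 3 * attn_flops_per_seq) * batch_size

    # Made-up example roughly the size of a 3B-parameter model:
    print(estimate_step_flops(n_params=3e9, hidden_size=2560, num_layers=32,
                              seq_len=2048, batch_size=2))   # ~8.2e13 FLOPs
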
@@ -241,4 +265,26 @@ class UploadToS3Callback(TrainerCallback):
         resp = self.s3_client.list_objects_v2(Bucket=self.s3_bucket, Prefix=os.path.join(self.s3_prefix, checkpoint_name))
         for obj in resp.get('Contents', []):
             self.s3_client.delete_object(Bucket=self.s3_bucket, Key=obj['Key'])
-            print(f"Deleted s3://{self.s3_bucket}/{obj['Key']}")
+            print(f"Deleted s3://{self.s3_bucket}/{obj['Key']}")
+
+class MFUCallback(TrainerCallback):
+    def __init__(self, peak_flops):
+        self.total_iterations = 0
+        self.start_time = time.time()
+        self.flops_promised = peak_flops
+        self.last_total_flos = 0
+
+    def on_log(self, args, state: TrainerState, control: TrainerControl, **kwargs):
+        if state.global_step == 0: # Avoid computation at the very beginning
+            return
+
+        current_time = time.time()
+        elapsed_time = current_time - self.start_time
+
+        # Calculate and log MFU
+        new_flops = state.total_flos - self.last_total_flos
+        kwargs['logs']['mfu'] = round(new_flops / elapsed_time / self.flops_promised, 4)
+
+        self.start_time = current_time
+        self.last_total_flos = state.total_flos
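
The on_log arithmetic above reduces to achieved FLOP/s divided by the hardware's peak. A worked example with made-up throughput numbers, using the A100 baseline quoted in train.py's comments:

    # Suppose state.total_flos grew by 1.2e15 FLOPs between two logging steps
    # that were 10 seconds apart (made-up numbers), against a 312e12 FLOP/s peak:
    new_flops = 1.2e15
    elapsed_time = 10.0
    flops_promised = 312e12

    mfu = round(new_flops / elapsed_time / flops_promised, 4)
    print(mfu)   # 0.3846 -> the run achieved ~38% of the GPU's peak throughput
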