add mfu callback

This commit is contained in:
Alex O'Connell
2024-04-15 17:39:41 -04:00
parent b41742b9fb
commit 644a326c0f
2 changed files with 84 additions and 27 deletions

View File

@@ -5,11 +5,12 @@ import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, \
HfArgumentParser, GPTQConfig
from transformers.integrations.integration_utils import TensorBoardCallback
from trl import DPOTrainer
from datasets import load_dataset
from train_util import TrainingRunArguments, DataCollatorForSupervisedFineTuning, CustomSFTTrainer, \
UploadToS3Callback
UploadToS3Callback, MFUCallback
"""
Phi Modules:
@@ -86,7 +87,7 @@ accelerate launch --config_file fsdp_config.yaml train.py \
--learning_rate 1e-5 --batch_size 64 --epochs 1 \
--micro_batch_size 2 --gradient_checkpointing --group_by_length \
--ctx_size 2048 \
--save_steps 50 --save_total_limit 5 --eval_steps 100 --logging_steps 2
--save_steps 50 --save_total_limit 10 --eval_steps 100 --logging_steps 2
"""
"""
@@ -149,10 +150,12 @@ def find_max_vram(min_buffer_mib=800):
return max_memory
if torch.cuda.device_count() == 1:
model_kwargs["device_map"] = "auto"
model = AutoModelForCausalLM.from_pretrained(
training_run_args.base_model,
trust_remote_code=True,
# device_map="auto",
max_memory=find_max_vram(),
**model_kwargs
)
@@ -189,6 +192,8 @@ else:
# model.tie_weights()
original_model = model
peft_config = None
if training_run_args.use_lora:
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
print("Creating LoRA for model...")
@@ -239,13 +244,14 @@ training_args = TrainingArguments(
output_dir=model_dir,
num_train_epochs=training_run_args.epochs,
save_total_limit=training_run_args.save_total_limit,
report_to="tensorboard",
report_to='none',
learning_rate=training_run_args.learning_rate,
lr_scheduler_type=training_run_args.learning_rate_schedule,
warmup_ratio=training_run_args.learning_rate_warmup,
log_level="info",
bf16=training_run_args.bf16,
group_by_length=training_run_args.group_by_length,
include_num_input_tokens_seen=True,
**training_kwargs,
)
@@ -284,6 +290,17 @@ if training_run_args.sync_to_bucket:
save_total_limit=training_run_args.save_total_limit
))
if training_run_args.flops_baseline:
# A100 GPU bfloat16 peak flops is 312 TFLOPS (312e12)
# 4090 GPU bfloat16 peak flops is 165.2 TFLOPS (165.2e12)
# 3090 GPU bfloat16 peak flops is 71 TFLOPS (71e12)
training_callbacks.append(MFUCallback(peak_flops=float(training_run_args.flops_baseline)))
# log to tensorboard (but after MFU)
training_callbacks.append(TensorBoardCallback())
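Since flops_baseline arrives as a string and is converted with float() here, it can be passed in scientific notation. A rough sketch of the extra flag, reusing the peak-FLOPS figures from the comments above (shown for a single-GPU run; for multi-GPU runs the aggregate peak of all devices is probably the right denominator):
--flops_baseline 312e12     # one A100, bf16
--flops_baseline 165.2e12   # one RTX 4090, bf16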
if not training_run_args.dpo:
print("Tokenizing datasets...")
@@ -325,35 +342,29 @@ else:
max_prompt_length = 0
train_dataset = datasets["train"].map(lambda x: { "prompt_len": len(x["system"]) })
test_dataset = datasets["test"]
test_dataset = None
if training_run_args.test_dataset:
test_dataset = datasets["test"]
max_prompt_length = max(train_dataset["prompt_len"])
trainer = DPOTrainer(
model,
ref_model=original_model,
peft_config=peft_config,
args=training_args,
beta=training_run_args.beta,
train_dataset=datasets["train"],
eval_dataset=datasets["test"],
train_dataset=train_dataset,
eval_dataset=test_dataset,
tokenizer=tokenizer,
max_length=training_run_args.ctx_size,
max_prompt_length=max_prompt_length,
generate_during_eval=True,
truncation_mode="keep_start",
callbacks=training_callbacks,
)
# pre-allocate cuda buffers by running a forwards and backwards pass with the largest possible example length
# the trainer dumps the cuda buffers before we start... need to figure out how to disable that
# if training_run_args.pre_allocate_cuda_buffers:
# print("Allocating CUDA buffers...")
# inputs = tokenizer([""] * training_args.per_device_train_batch_size, return_tensors="pt", max_length=longest_example, padding="max_length", truncation=True)
# inputs = {k: v.to(model.device) for k, v in inputs.items()}
# inputs["labels"] = inputs["input_ids"]
# outputs = model(**inputs)
# loss = outputs.loss if isinstance(outputs, dict) else outputs[0]
# loss.backward()
# model.zero_grad()
try:
checkpoint = training_run_args.resume_from_checkpoint
if checkpoint:
@@ -361,7 +372,7 @@ try:
else:
trainer.train()
if training_run_args.train_dataset:
if training_run_args.test_dataset:
trainer.evaluate_all()
if training_run_args.use_lora and training_run_args.lora_merge:
@@ -377,7 +388,7 @@ try:
tokenizer.save_pretrained(model_dir)
except Exception as ex:
if len(torch.cuda.device_count()) > 1:
if torch.cuda.device_count() > 1:
raise ex # this doesn't play nice with FSDP so don't even try
print("Something bad happened! Try and save it?")

View File

@@ -1,15 +1,16 @@
import copy
import time
import torch
import os
import random
import shutil
from torch.utils.data import SequentialSampler, Subset, RandomSampler
from transformers import TrainerCallback, AutoTokenizer, Trainer
from transformers import TrainerCallback, AutoTokenizer, Trainer, AutoModelForCausalLM, \
TrainerControl, TrainerState
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
from calflops import calculate_flops
import boto3
import os
import shutil
@dataclass
class TrainingRunArguments:
@@ -25,6 +26,7 @@ class TrainingRunArguments:
epochs: int = field(default=1, metadata={"help": "The number of times to train the model on each example"})
learning_rate: float = field(default=1e-5, metadata={"help": "The starting learning rate (speed at which the model trains)"})
learning_rate_schedule: str = field(default="cosine", metadata={"help": "How fast the learning rate is reduced during training"})
learning_rate_warmup: float = field(default=0.0, metadata={"help": "The fraction of training steps used to warm up the learning rate"})
weight_decay: float = field(default=0.1, metadata={"help": ""})
gradient_clip: float = field(default=1.0, metadata={"help": ""})
resume_from_checkpoint: str = field(default="", metadata={"help": "The name of the checkpoint to resume training from"})
@@ -58,6 +60,7 @@ class TrainingRunArguments:
gradient_checkpointing: bool = field(default=False, metadata={"help": "Enables gradient checkpointing which saves quite a lot of VRAM"})
sync_to_bucket: str = field(default=None, metadata={"help": "If set, checkpoints will be synced to the s3 bucket specified by this argument"})
flops_baseline: str = field(default=None, metadata={"help": "The baseline flops for the GPUs used for the training run. Outputs MFU"})
class DataCollatorForSupervisedFineTuning(object):
"""Collate examples for supervised fine-tuning."""
@@ -194,15 +197,36 @@ class CustomSFTTrainer(Trainer):
Should speed up training by skipping the final fine tuning part that doesn't affect accuracy much
"""
return super().create_scheduler(int(num_training_steps * self.learning_rate_overshoot), optimizer=optimizer)
def floating_point_ops(self, inputs):
config = self.model.config
examples_length = len(inputs["input_ids"][0])
batch_size = len(inputs["input_ids"])
# MFU is approximated using throughput and param count
# the number of parameters is approximately the number of multiply-accumulates (MACs) in the network
# each MAC is 2 FLOPs - we multiply by 2, i.e. 2 * n_param
# there are 3 passes of a NN (fwd, bwd, delta) - we multiply by 3 below, i.e. 2 * 3 * n_param
# this gets us FLOPs / token
flops_per_token = 2 * sum(p.numel() for p in self.model.parameters())
flops_per_seq = flops_per_token * examples_length
# there are 2 FLOPs per MAC, and two attention matmuls, A = Q*K^T and out = A*V (i.e. multiply by 2 twice)
attn_flops_per_seq = config.num_hidden_layers * 2 * 2 * (config.hidden_size * (examples_length**2))
# there are 2 ops in the bwd pass and 1 in the fwd pass, so we multiply by 3
result = (3 * flops_per_seq + 3 * attn_flops_per_seq) * batch_size
return result
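For a sense of scale, a back-of-the-envelope instance of this estimate with made-up dimensions (a hypothetical 7B-parameter model; none of these numbers come from the training run):
# hypothetical numbers, purely to illustrate the formula above
n_param = 7e9                  # assumed parameter count
n_layers, d_model = 32, 4096   # assumed config.num_hidden_layers / config.hidden_size
seq_len, micro_bs = 2048, 2
flops_per_token = 2 * n_param                                    # 1.4e10
flops_per_seq = flops_per_token * seq_len                        # ~2.87e13
attn_flops_per_seq = n_layers * 2 * 2 * (d_model * seq_len**2)   # ~2.20e12
total = (3 * flops_per_seq + 3 * attn_flops_per_seq) * micro_bs  # ~1.85e14 FLOPs per micro-batch
At a 2048-token context the quadratic attention term is only a few percent of the dense term, so the parameter-count approximation dominates.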
class UploadToS3Callback(TrainerCallback):
def __init__(self, s3_bucket, s3_prefix, save_total_limit=None):
import boto3
self.s3_client = boto3.client('s3')
self.s3_bucket = s3_bucket
self.s3_prefix = s3_prefix
self.save_total_limit = save_total_limit
def on_save(self, args, state, control, **kwargs):
def on_save(self, args, state: TrainerState, control: TrainerControl, **kwargs):
output_dir = kwargs['output_dir']
checkpoint = os.path.basename(output_dir)
@@ -241,4 +265,26 @@ class UploadToS3Callback(TrainerCallback):
resp = self.s3_client.list_objects_v2(Bucket=self.s3_bucket, Prefix=os.path.join(self.s3_prefix, checkpoint_name))
for obj in resp.get('Contents', []):
self.s3_client.delete_object(Bucket=self.s3_bucket, Key=obj['Key'])
print(f"Deleted s3://{self.s3_bucket}/{obj['Key']}")
print(f"Deleted s3://{self.s3_bucket}/{obj['Key']}")
class MFUCallback(TrainerCallback):
def __init__(self, peak_flops):
self.total_iterations = 0
self.start_time = time.time()
self.flops_promised = peak_flops
self.last_total_flos = 0
def on_log(self, args, state: TrainerState, control: TrainerControl, **kwargs):
if state.global_step == 0: # Avoid computation at the very beginning
return
current_time = time.time()
elapsed_time = current_time - self.start_time
# Calculate and log MFU
new_flops = state.total_flos - self.last_total_flos
kwargs['logs']['mfu'] = round(new_flops / elapsed_time / self.flops_promised, 4)
self.start_time = current_time
self.last_total_flos = state.total_flos
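To make the logged value concrete, a small hypothetical example (the 312e12 figure is the A100 baseline quoted in train.py; the FLOP and timing numbers are invented for illustration):
# illustrative only: suppose 3.7e15 new FLOPs accumulated over a 30 s logging window
new_flops = 3.7e15
elapsed_time = 30.0
peak_flops = 312e12
mfu = round(new_flops / elapsed_time / peak_flops, 4)  # 0.3953, i.e. ~40% of bf16 peak
The result rides along in kwargs['logs'], so it appears next to the other training metrics in whichever logger runs after this callback, here the manually appended TensorBoardCallback.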