add load_in_8bit + enable LoRA

Alex O'Connell
2023-10-27 00:23:13 -04:00
parent 110b988527
commit 82b4ea4c65
2 changed files with 47 additions and 25 deletions


@@ -1,3 +1,4 @@
 transformers
 tensorboard
 peft
+bitsandbytes
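Note: bitsandbytes is the new dependency backing the load_in_8bit path below; it provides the GPU kernels for LLM.int8() quantization. A quick import sanity check, assuming a CUDA-capable environment:

    # Fails fast if the wheel is missing or was built without GPU support
    # (assumption: a CUDA machine; bitsandbytes is GPU-only by default).
    import bitsandbytes as bnb
    import torch
    print(bnb.__version__, torch.cuda.is_available())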


@@ -27,6 +27,7 @@ class TrainingRunArguments:
     epochs: int = field(default=1, metadata={"help": "The number of times to train the model on each example"})
     learning_rate: float = field(default=1e-5, metadata={"help": "The starting learning rate (speed at which the model trains)"})
     learning_rate_schedule: str = field(default="cosine", metadata={"help": "How fast the learning rate is reduced during training"})
+    load_in_8bit: bool = field(default=False, metadata={"help": "Set to load the base model in 8-bit mode using bitsandbytes"})
     use_lora: bool = field(default=False, metadata={"help": "If set, then the trained model will be a LoRA"})
     lora_rank: int = field(default=4)
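Note: HfArgumentParser exposes each dataclass field as a command-line flag, so the new option is toggled per run. A hypothetical invocation (the script name and model id are placeholders, not from this repo):

    python3 train.py --base_model EleutherAI/pythia-1b --load_in_8bit --use_lora --lora_rank 8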
@@ -36,38 +37,58 @@ class TrainingRunArguments:
 parser = HfArgumentParser([TrainingRunArguments])
 training_run_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+if training_run_args.load_in_8bit and training_run_args.bf16:
+    raise Exception("Cannot use load_in_8bit and bf16 flags at the same time!")
 # TODO: write a proper evaluation script
 print(f"Loading model '{training_run_args.base_model}'...")
-model_dtype = torch.bfloat16 if training_run_args.bf16 else torch.float16
-model = AutoModelForCausalLM.from_pretrained(training_run_args.base_model, trust_remote_code=True, torch_dtype=model_dtype)
+model_kwargs = {}
+if training_run_args.load_in_8bit:
+    # bitsandbytes controls the compute dtype in 8-bit mode, so no torch_dtype here
+    model_kwargs["load_in_8bit"] = True
+elif training_run_args.bf16:
+    model_kwargs["torch_dtype"] = torch.bfloat16
+else:
+    model_kwargs["torch_dtype"] = torch.float16
+model = AutoModelForCausalLM.from_pretrained(
+    training_run_args.base_model,
+    trust_remote_code=True,
+    **model_kwargs,
+)
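Note: passing load_in_8bit directly to from_pretrained is the API of the transformers versions this script targets; more recent releases route quantization through BitsAndBytesConfig instead. An equivalent sketch under that assumption:

    # Equivalent 8-bit load on recent transformers (assumption; not what this commit uses)
    from transformers import BitsAndBytesConfig
    model = AutoModelForCausalLM.from_pretrained(
        training_run_args.base_model,
        trust_remote_code=True,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    )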
 tokenizer = AutoTokenizer.from_pretrained(training_run_args.base_model, trust_remote_code=True)
 tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
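Note: adding a pad token grows the tokenizer vocabulary. If the embedding table isn't resized elsewhere in the script (not visible in this diff), the standard companion call is:

    # May already happen elsewhere in the script; shown here only as a reminder.
    model.resize_token_embeddings(len(tokenizer))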
 training_callbacks = None
-if training_run_args.use_lora:
-    raise NotImplementedError("Need to fix the callback thing still")
-# if training_run_args.use_lora:
-#     from peft import LoraConfig, TaskType, get_peft_model
-#     print("Creating LoRA for model...")
-#     class SavePeftModelCallback(transformers.TrainerCallback):
-#         def on_save(self, args, state, control, **kwargs):
-#             checkpoint_folder_name = f"{transformers.trainer_utils.PREFIX_CHECKPOINT_DIR}-{state.global_step}"
-#             checkpoint_folder = os.path.join(args.output_dir, checkpoint_folder_name)
-#             peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
-#             kwargs["model"].save_pretrained(peft_model_path)
-#             return control
-#     training_callbacks = [SavePeftModelCallback]
-#     peft_config = LoraConfig(
-#         task_type=TaskType.CAUSAL_LM,
-#         inference_mode=False,
-#         r=training_run_args.lora_rank,
-#         lora_alpha=training_run_args.alpha,
-#         lora_dropout=training_run_args.droput,
-#         target_modules=None,
-#     )
-#     model = get_peft_model(model, peft_config)
-#     model.print_trainable_parameters()
-# raise NotImplementedError("Need to fix the callback thing still")
+if training_run_args.use_lora:
+    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
+    print("Creating LoRA for model...")
+    # Trainer checkpoints contain the full model by default; this callback also
+    # writes the LoRA adapter weights next to each checkpoint so they can be
+    # reloaded on their own.
+    class SavePeftModelCallback(transformers.TrainerCallback):
+        def on_save(self, args, state, control, **kwargs):
+            checkpoint_folder_name = f"{transformers.trainer_utils.PREFIX_CHECKPOINT_DIR}-{state.global_step}"
+            checkpoint_folder = os.path.join(args.output_dir, checkpoint_folder_name)
+            peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
+            kwargs["model"].save_pretrained(peft_model_path)
+            return control
+    training_callbacks = [SavePeftModelCallback]
+    peft_config = LoraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        inference_mode=False,
+        r=training_run_args.lora_rank,
+        lora_alpha=training_run_args.alpha,
+        lora_dropout=training_run_args.dropout,
+        target_modules=None,
+    )
+    if training_run_args.load_in_8bit:
+        # Casts norm layers to fp32 and enables input grads so the 8-bit base
+        # model can be fine-tuned through the LoRA adapters.
+        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
+    model = get_peft_model(model, peft_config)
+    model.print_trainable_parameters()
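Note: the callback above saves only the adapter weights, so inference later needs the base model plus that adapter directory. A minimal reload sketch using peft's standard API (the checkpoint path is a placeholder):

    # Hypothetical path; adapter_model is the folder written by SavePeftModelCallback.
    from peft import PeftModel
    base = AutoModelForCausalLM.from_pretrained(training_run_args.base_model, trust_remote_code=True)
    model = PeftModel.from_pretrained(base, "output/checkpoint-500/adapter_model")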
 def tokenize_function(example):
     result = tokenizer(example['text'] + tokenizer.eos_token,