Mirror of https://github.com/acon96/home-llm.git (synced 2026-01-08 21:28:05 -05:00)
add load in 8bit + enable lora
requirements.txt
@@ -1,3 +1,4 @@
 transformers
 tensorboard
 peft
+bitsandbytes
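For context: bitsandbytes is the backend for the load_in_8bit path added to train.py below; transformers delegates int8 weight quantization to it at load time. A minimal sketch of that loading path, assuming a CUDA GPU and a small stand-in model id (newer transformers releases prefer a BitsAndBytesConfig passed as quantization_config over the bare flag):

    from transformers import AutoModelForCausalLM

    # "facebook/opt-125m" is a stand-in; any causal LM on the Hub works.
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m",
        load_in_8bit=True,   # hands weight quantization to bitsandbytes
        device_map="auto",   # int8 layers must be placed on the GPU
    )
    print(model.get_memory_footprint())  # roughly half the fp16 footprint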
train.py (69 changed lines)
@@ -27,6 +27,7 @@ class TrainingRunArguments:
     epochs: int = field(default=1, metadata={"help": "The number of times to train the model on each example"})
     learning_rate: float = field(default=1e-5, metadata={"help": "The starting learning rate (speed at which the model trains)"})
     learning_rate_schedule: str = field(default="cosine", metadata={"help": "How fast the learning rate is reduced during training"})
+    load_in_8bit: bool = field(default=False, metadata={"help": "Set to load the base model in 8-bit mode using bitsandbytes"})
     use_lora: bool = field(default=False, metadata={"help": "If set, then the trained model will be a LoRA"})
     lora_rank: int = field(default=4)
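Because HfArgumentParser maps every dataclass field to a CLI flag, the new field surfaces as --load_in_8bit with no extra plumbing (bool fields defaulting to False act like store_true flags). A condensed, standalone sketch of that mapping, with the field set trimmed for illustration:

    from dataclasses import dataclass, field
    from transformers import HfArgumentParser

    @dataclass
    class Args:
        load_in_8bit: bool = field(default=False, metadata={"help": "Load the base model in 8-bit"})
        use_lora: bool = field(default=False, metadata={"help": "Train a LoRA instead of the full model"})

    parser = HfArgumentParser([Args])
    # invoked as e.g.: python train.py --load_in_8bit --use_lora
    args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    print(args.load_in_8bit, args.use_lora)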
@@ -36,38 +37,58 @@
 parser = HfArgumentParser([TrainingRunArguments])
 training_run_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
 
+if training_run_args.load_in_8bit and training_run_args.bf16:
+    raise Exception("Cannot use load_in_8bit and bf16 flags at the same time!")
+
 # TODO: write a proper evaluation script
 
 print(f"Loading model '{training_run_args.base_model}'...")
-model_dtype = torch.bfloat16 if training_run_args.bf16 else torch.float16
-model = AutoModelForCausalLM.from_pretrained(training_run_args.base_model, trust_remote_code=True, torch_dtype=model_dtype)
+model_kwargs = {}
+if training_run_args.load_in_8bit:
+    model_kwargs["load_in_8bit"] = True
+elif training_run_args.bf16:
+    model_kwargs["torch_dtype"] = torch.bfloat16
+else:
+    model_kwargs["torch_dtype"] = torch.float16
+
+model = AutoModelForCausalLM.from_pretrained(
+    training_run_args.base_model,
+    trust_remote_code=True,
+    **model_kwargs
+)
 tokenizer = AutoTokenizer.from_pretrained(training_run_args.base_model, trust_remote_code=True)
 tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
 
 training_callbacks = None
-if training_run_args.use_lora:
-    raise NotImplementedError("Need to fix the callback thing still")
-# if training_run_args.use_lora:
-#     from peft import LoraConfig, TaskType, get_peft_model
-#     print("Creating LoRA for model...")
-#     class SavePeftModelCallback(transformers.TrainerCallback):
-#         def on_save(self, args, state, control, **kwargs):
-#             checkpoint_folder_name = f"{transformers.trainer_utils.PREFIX_CHECKPOINT_DIR}-{state.global_step}"
-#             checkpoint_folder = os.path.join(args.output_dir, checkpoint_folder_name)
-#             peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
-#             kwargs["model"].save_pretrained(peft_model_path)
-#             return control
-#     training_callbacks = [SavePeftModelCallback]
-#     peft_config = LoraConfig(
-#         task_type=TaskType.CAUSAL_LM,
-#         inference_mode=False,
-#         r=training_run_args.lora_rank,
-#         lora_alpha=training_run_args.alpha,
-#         lora_dropout=training_run_args.droput,
-#         target_modules=None,
-#     )
-#     model = get_peft_model(model, peft_config)
-#     model.print_trainable_parameters()
-#     raise NotImplementedError("Need to fix the callback thing still")
+if training_run_args.use_lora:
+    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
+    print("Creating LoRA for model...")
+    class SavePeftModelCallback(transformers.TrainerCallback):
+        def on_save(self, args, state, control, **kwargs):
+            checkpoint_folder_name = f"{transformers.trainer_utils.PREFIX_CHECKPOINT_DIR}-{state.global_step}"
+            checkpoint_folder = os.path.join(args.output_dir, checkpoint_folder_name)
+            peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
+            kwargs["model"].save_pretrained(peft_model_path)
+            return control
+    training_callbacks = [SavePeftModelCallback]
+    peft_config = LoraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        inference_mode=False,
+        r=training_run_args.lora_rank,
+        lora_alpha=training_run_args.alpha,
+        lora_dropout=training_run_args.dropout,
+        target_modules=None,
+    )
+    if training_run_args.load_in_8bit:
+        model = prepare_model_for_kbit_training(
+            model, use_gradient_checkpointing=False
+        )
+    model = get_peft_model(model, peft_config)
+    model.print_trainable_parameters()
 
 def tokenize_function(example):
     result = tokenizer(example['text'] + tokenizer.eos_token,
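End to end, the two new flags compose as follows. A minimal sketch of the 8-bit + LoRA flow this commit wires up, assuming a CUDA GPU; the model id and the alpha/dropout values are illustrative, and only r=4 mirrors the lora_rank default above:

    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

    # Stand-in model id, not from the repo.
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m", load_in_8bit=True, device_map="auto"
    )

    # Readies the int8 model for training: upcasts layer norms to fp32
    # and enables input gradients so the LoRA weights receive a signal.
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=4,                  # matches the lora_rank default in the diff
        lora_alpha=32,        # illustrative value
        lora_dropout=0.05,    # illustrative value
        target_modules=None,  # peft infers q_proj/v_proj for OPT-style models
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()  # typically well under 1% trainable

The SavePeftModelCallback exists because Trainer's stock checkpointing does not know about PEFT adapters; hooking on_save writes the small adapter_model directory into each checkpoint folder so the adapter can be reloaded on its own.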