Mirror of https://github.com/acon96/home-llm.git (synced 2026-01-09 13:48:05 -05:00)

Commit: add training script fixes

Changed file: train.py (118 lines changed)
@@ -9,7 +9,7 @@ import time
 import shutil
 from torch.utils.data import SequentialSampler, Subset, RandomSampler
 from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, \
-    PreTrainedTokenizerFast, HfArgumentParser, GPTQConfig, AutoConfig, TrainerCallback
+    PreTrainedTokenizerFast, HfArgumentParser, GPTQConfig, AutoConfig, TrainerCallback, BitsAndBytesConfig
 from transformers.trainer_utils import EvalPrediction
 from transformers.integrations.integration_utils import TensorBoardCallback
 from datasets import load_dataset, Dataset
@@ -96,6 +96,20 @@ python3 train.py \
     --use_lora --lora_rank 32 --lora_alpha 64 --lora_modules up_proj,down_proj,q_proj,v_proj,o_proj
 """
+
+"""
+python3 train.py \
+    --run_name mistralhome-bielik-rev1 \
+    --base_model speakleash/Bielik-7B-Instruct-v0.1 \
+    --bf16 \
+    --train_dataset data/home_assistant_train.jsonl \
+    --test_dataset data/home_assistant_test.jsonl \
+    --learning_rate 1e-5 --learning_rate_warmup 0.03 --batch_size 64 --epochs 1 \
+    --micro_batch_size 4 --gradient_checkpointing --group_by_length \
+    --ctx_size 2048 \
+    --save_steps 50 --save_total_limit 20 --eval_steps 200 --logging_steps 1 \
+    --use_lora --lora_rank 32 --lora_alpha 64 --lora_modules up_proj,down_proj,q_proj,v_proj,o_proj --load_in_4bit
+"""
 
 """
 accelerate launch --config_file fsdp_config.yaml train.py \
     --run_name stablehome-3b-rev10 \
@@ -143,6 +157,42 @@ python3 train.py \
     --ctx_size 2048 --save_steps 100 --save_total_limit 10
 """
+
+"""
+python3 train.py \
+    --run_name tinyhome-qwen-rev3 \
+    --base_model Qwen/Qwen2-0.5B-Instruct \
+    --bf16 \
+    --train_dataset data/home_assistant_train.jsonl \
+    --test_dataset data/home_assistant_test.jsonl \
+    --learning_rate 2e-5 --batch_size 64 \
+    --micro_batch_size 8 --gradient_checkpointing --group_by_length \
+    --ctx_size 2048 --save_steps 1000
+"""
+
+"""
+python3 train.py \
+    --run_name home-phi3-mini-rev1 \
+    --base_model microsoft/Phi-3-mini-4k-instruct \
+    --bf16 \
+    --train_dataset data/home_assistant_train.jsonl \
+    --test_dataset data/home_assistant_test.jsonl \
+    --learning_rate 5e-6 --batch_size 32 \
+    --micro_batch_size 8 --gradient_checkpointing --group_by_length \
+    --ctx_size 2048 --save_steps 100 --save_total_limit 10
+"""
+
+"""
+python3 train.py \
+    --run_name tinyhome-polish-rev1 \
+    --base_model eryk-mazus/polka-1.1b-chat \
+    --bf16 \
+    --train_dataset data/home_assistant_train.jsonl \
+    --test_dataset data/home_assistant_test.jsonl \
+    --learning_rate 2e-5 --batch_size 32 \
+    --micro_batch_size 8 --gradient_checkpointing --group_by_length \
+    --ctx_size 2048 --save_steps 100 --save_total_limit 10
+"""
 
 """
 python3 train.py \
     --run_name tinyhome-rev2-dpo \
@@ -156,6 +206,11 @@ python3 train.py \
     --save_steps 50 --save_total_limit 10 --eval_steps 100 --logging_steps 2
 """
 
+MULTI_GPU_WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1"))
+MULTI_GPU_RANK = int(os.environ.get("RANK", "0"))
+IS_MULTI_GPU = os.environ.get("RANK") != None
+IS_MASTER_PROCESS = MULTI_GPU_RANK == 0
+
 @dataclass
 class TrainingRunArguments:
     run_name: str = field(metadata={"help": "The folder to save the output model under"})
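For reference, the new MULTI_GPU_* globals read the environment variables that distributed launchers such as torchrun or accelerate launch export for each worker process. A minimal standalone sketch of the same rank-0 gating pattern (the log helper and message are illustrative, not part of the commit):

    import os

    # Launchers set WORLD_SIZE and RANK per worker; a plain single-process run has neither.
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    is_master = rank == 0

    def log(msg: str) -> None:
        # Print once per job instead of once per worker process.
        if is_master:
            print(f"[rank {rank}/{world_size}] {msg}")

    log("this appears only on the master process")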
@@ -283,13 +338,14 @@ training_run_args, _ = parser.parse_args_into_dataclasses(return_remaining_strin
 if sum([training_run_args.load_in_8bit, training_run_args.load_in_4bit, training_run_args.load_as_gptq]) > 1:
     raise Exception("Please select exactly one of 'load_in_8bit', 'load_in_4bit', or 'load_as_gptq")
 
-print(f"Loading model '{training_run_args.base_model}'...")
+if IS_MASTER_PROCESS:
+    print(f"Loading model '{training_run_args.base_model}'...")
 
 model_kwargs = {}
 if training_run_args.load_in_8bit:
-    model_kwargs["load_in_8bit"] = True
+    model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
 elif training_run_args.load_in_4bit:
-    model_kwargs["load_in_4bit"] = True
+    model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
 elif training_run_args.load_as_gptq:
     model_kwargs["quantization_config"] = GPTQConfig(bits=4, disable_exllama=True)
 
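The quantized-loading branches now build an explicit BitsAndBytesConfig instead of passing the bare load_in_8bit / load_in_4bit kwargs, which newer transformers releases deprecate in favor of a quantization_config object. A minimal sketch of how such a kwargs dict is consumed, with a placeholder checkpoint name (requires the bitsandbytes package):

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Illustrative values; the script assembles the same kind of dict from its CLI flags.
    model_kwargs = {
        "quantization_config": BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,  # compute in bf16, keep weights in 4-bit
        ),
    }

    model = AutoModelForCausalLM.from_pretrained(
        "example-org/example-7b",  # placeholder model id
        device_map="auto",
        **model_kwargs,
    )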
@@ -310,7 +366,8 @@ def find_max_vram(min_buffer_mib=800):
         suggestion = round((total_mem - 1000) / 1000) * 1000
         suggestion = min(suggestion, total_mem - min_buffer_mib)
 
-        print(f"Model will target using {suggestion}MiB of VRAM on GPU {i}")
+        if IS_MASTER_PROCESS:
+            print(f"Model will target using {suggestion}MiB of VRAM on GPU {i}")
         max_memory[i] = f'{suggestion}MiB'
 
     return max_memory
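find_max_vram builds a per-GPU max_memory budget; a sketch of how such a mapping is typically handed to from_pretrained alongside device_map="auto" (the buffer size and model id are placeholders):

    import torch
    from transformers import AutoModelForCausalLM

    # Cap each visible GPU slightly below its total memory so activations and
    # CUDA overhead still fit.
    max_memory = {}
    for i in range(torch.cuda.device_count()):
        total_mib = torch.cuda.get_device_properties(i).total_memory // (1024 * 1024)
        max_memory[i] = f"{max(total_mib - 800, 0)}MiB"

    model = AutoModelForCausalLM.from_pretrained(
        "example-org/example-7b",  # placeholder model id
        device_map="auto",         # let accelerate place layers across devices
        max_memory=max_memory,     # per-device budget, e.g. {0: "23000MiB"}
    )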
@@ -361,7 +418,8 @@ original_model = model
 peft_config = None
 if training_run_args.use_lora:
     from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
-    print("Creating LoRA for model...")
+    if IS_MASTER_PROCESS:
+        print("Creating LoRA for model...")
     target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
     modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
     peft_config = LoraConfig(
@@ -390,8 +448,8 @@ training_kwargs = {}
 
 if training_run_args.test_dataset:
     training_kwargs.update({
-        "per_device_eval_batch_size": training_run_args.batch_size,
-        "evaluation_strategy": ("steps" if training_run_args.eval_steps != -1 else "epoch"),
+        "per_device_eval_batch_size": training_run_args.micro_batch_size,
+        "eval_strategy": ("steps" if training_run_args.eval_steps != -1 else "epoch"),
         "eval_steps": (training_run_args.eval_steps if training_run_args.eval_steps != -1 else None),
         "bf16_full_eval": training_run_args.bf16,
     })
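Two changes here: evaluation runs with the per-device micro batch size rather than the global batch size, and the evaluation_strategy key becomes eval_strategy, matching the parameter name used by newer transformers releases. A sketch of the resulting TrainingArguments, with placeholder values:

    from transformers import TrainingArguments

    # Placeholder values; the script fills these in from its CLI flags.
    training_args = TrainingArguments(
        output_dir="./models/example-run",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,   # micro batch size, not the accumulated global batch
        eval_strategy="steps",          # formerly evaluation_strategy
        eval_steps=200,
        bf16_full_eval=True,
    )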
@@ -433,9 +491,10 @@ class DataCollatorForSupervisedFineTuning(object):
     def __init__(self, *, tokenizer: AutoTokenizer, prefix_ids = None, suffix_ids = None):
 
         self.tokenizer = tokenizer
-        assistant_prompt = tokenizer.apply_chat_template(conversation=[{"role": "assistant", "content": r"%%%%%%%%%%%%%%%%"}], tokenize=False).split( r"%%%%%%%%%%%%%%%%")
-        self.response_prefix = assistant_prompt[0]
-        self.response_suffix = assistant_prompt[1]
+        if not prefix_ids and not suffix_ids:
+            assistant_prompt = tokenizer.apply_chat_template(conversation=[{"role": "assistant", "content": r"%%%%%%%%%%%%%%%%"}], tokenize=False).split( r"%%%%%%%%%%%%%%%%")
+            self.response_prefix = assistant_prompt[0]
+            self.response_suffix = assistant_prompt[1]
 
         if prefix_ids:
             self.prefix_ids = prefix_ids
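The collator now derives the assistant prefix/suffix from the chat template only when no explicit token ids were supplied: it renders a dummy assistant turn containing a sentinel string and splits the rendered text around it. A standalone sketch of that trick, assuming any tokenizer that ships a chat template (the model id is an example):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")  # example chat-template tokenizer

    SENTINEL = r"%%%%%%%%%%%%%%%%"
    rendered = tokenizer.apply_chat_template(
        conversation=[{"role": "assistant", "content": SENTINEL}],
        tokenize=False,
    )
    response_prefix, response_suffix = rendered.split(SENTINEL)

    # Token ids of the wrapper text, usable as explicit prefix_ids / suffix_ids overrides.
    prefix_ids = tokenizer.encode(response_prefix, add_special_tokens=False)
    suffix_ids = tokenizer.encode(response_suffix, add_special_tokens=False)
    print(prefix_ids, suffix_ids)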
@@ -514,8 +573,9 @@ class DataCollatorForSupervisedFineTuning(object):
         for label in labels:
             mask_ranges = self._find_mask_ranges(label)
             for start, end in mask_ranges:
-                if end - start == len(label):
+                if end - start == len(label) - 1:
                     print("warning! example had no assistant response in it!")
+                    print(input_ids)
                 label[start:end] = [-100] * (end - start)
 
         input_ids = torch.LongTensor(self._pad(input_ids, self.tokenizer.pad_token_id or self.tokenizer.eos_token_id))
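The masked ranges are overwritten with -100, the ignore index of PyTorch's cross-entropy loss (and of the loss computed inside Hugging Face causal-LM models), so only assistant tokens contribute to training. A tiny self-contained illustration:

    import torch

    loss_fn = torch.nn.CrossEntropyLoss()        # ignore_index defaults to -100

    logits = torch.randn(5, 10)                  # 5 positions, vocabulary of 10
    labels = torch.tensor([3, -100, -100, 7, 1]) # masked positions are skipped entirely

    loss = loss_fn(logits, labels)
    print(loss)  # averaged over the 3 unmasked positions only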
@@ -527,7 +587,8 @@ class DataCollatorForSupervisedFineTuning(object):
             attention_mask=input_ids.ne(self.tokenizer.pad_token_id or self.tokenizer.eos_token_id),
         )
 
-print("Loading dataset...")
+if IS_MASTER_PROCESS:
+    print("Loading dataset...")
 data_files = { "train": training_run_args.train_dataset }
 if training_run_args.test_dataset:
     data_files["test"] = training_run_args.test_dataset
@@ -648,7 +709,8 @@ class CustomSFTTrainer(Trainer):
         return result
 
 if not training_run_args.dpo:
-    print("Tokenizing datasets...")
+    if IS_MASTER_PROCESS:
+        print("Tokenizing datasets...")
 
     if "text" in datasets["train"].column_names:
         tokenize_function = tokenize_raw_example
@@ -660,20 +722,36 @@ if not training_run_args.dpo:
         raise Exception("Unknown dataset input format (not raw corpus or sharegpt)")
 
     tokenized_test_dataset = None
-    tokenized_train_dataset = datasets["train"].map(tokenize_function, batched=True, num_proc=os.cpu_count()).remove_columns(columns_to_remove)
+    num_proc = os.cpu_count() // MULTI_GPU_WORLD_SIZE
+    tokenized_train_dataset = datasets["train"].map(tokenize_function, batched=True, num_proc=num_proc).remove_columns(columns_to_remove)
     if training_run_args.test_dataset:
-        tokenized_test_dataset = datasets["test"].map(tokenize_function, batched=True, num_proc=os.cpu_count()).remove_columns(columns_to_remove)
+        tokenized_test_dataset = datasets["test"].map(tokenize_function, batched=True, num_proc=num_proc).remove_columns(columns_to_remove)
 
     example_lengths = [ len(example) for example in tokenized_train_dataset["input_ids"] ]
     tokens_in_train_set, longest_example = sum(example_lengths), max(example_lengths)
-    print(f"Train dataset has {int(tokens_in_train_set / 1000000)}M tokens. Longest Example: {longest_example} tokens")
+    if IS_MASTER_PROCESS:
+        print(f"Train dataset has {int(tokens_in_train_set / 1000000)}M tokens. Longest Example: {longest_example} tokens")
 
+
+    prefix_ids = None
+    suffix_ids = None
+
+    # data_collator = DataCollatorForSupervisedFineTuning(tokenizer=tokenizer)
+    # fix for tinyllama not detecting split properly
+    # prefix_ids = [29966, 29989, 465, 22137, 29989, 29958, 13]
+    # suffix_ids = [2]
+
+    # fix for qwen2 not detecting split properly
+    # prefix_ids = [151644, 77091, 198]
+    # suffix_ids = [151645, 198]
+
+    # fix for polka-1.1 not detecting split properly
+    # prefix_ids = [43883, 20255, 13]
+    # suffix_ids = [43882, 29871, 13]
+
     data_collator = DataCollatorForSupervisedFineTuning(
         tokenizer=tokenizer,
-        prefix_ids=[29966, 29989, 465, 22137, 29989, 29958, 13],
-        suffix_ids=[2],
+        prefix_ids=prefix_ids,
+        suffix_ids=suffix_ids,
     )
 
     trainer = CustomSFTTrainer(