Mirror of https://github.com/acon96/home-llm.git (synced 2026-01-09 21:58:00 -05:00)
finish implementing alternate dataset generation mode
File diff suppressed because it is too large
train.py (82 lines changed)
@@ -27,8 +27,8 @@ class TrainingRunArguments:
     run_name: str = field(metadata={"help": "The folder to save the output model under"})
     base_model: str = field(metadata={"help": "The base model to load for fine-tuning"})
     train_dataset: str = field(metadata={"help": "The JSON file containing the training dataset"})
-    test_dataset: str = field(default=None, metadata={"help": "The JSON file containing the evaluation dataset"})
-    dataset_processing_threads: int = field(default=None, metadata={"help": "The number of threads to use to tokenize the dataset"})
+    test_dataset: Optional[str] = field(default=None, metadata={"help": "The JSON file containing the evaluation dataset"})
+    dataset_processing_threads: Optional[int] = field(default=None, metadata={"help": "The number of threads to use to tokenize the dataset"})
     ctx_size: int = field(default=2048, metadata={"help": "The number of tokens to pad & truncate the input examples to"})
     bf16: bool = field(default=False, metadata={"help": "If set, the model will be loaded and trained in bf16 instead of fp32"})
     batch_size: int = field(default=8, metadata={"help": "The simulated 'batch size' that we will train on. Will tweak gradient accumulation steps"})
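A note on the Optional[...] changes above: the field(metadata={"help": ...}) pattern matches transformers.HfArgumentParser, which derives CLI flags from the dataclass annotations, and marking nullable arguments as Optional keeps the declared type honest when the default is None. A minimal sketch of that parsing, assuming HfArgumentParser really is the wiring used elsewhere in train.py (not shown in this diff):

# Sketch only: assumes train.py parses TrainingRunArguments with
# transformers.HfArgumentParser (the parser call is not part of this diff).
from transformers import HfArgumentParser

parser = HfArgumentParser(TrainingRunArguments)
(run_args,) = parser.parse_args_into_dataclasses()

# e.g. python3 train.py --run_name my-run --base_model <model> \
#        --train_dataset data/train.json
# test_dataset stays None unless --test_dataset is passed on the command line.
print(run_args.test_dataset)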
@@ -59,8 +59,8 @@ class TrainingRunArguments:
     lora_rank: int = field(default=4, metadata={"help": "Rank which determines LoRA matrix size. Rank typically starts at 8 but can go up to 256. Higher ranks can store more information but increase the computational and memory cost of LoRA."})
     lora_alpha: int = field(default=32, metadata={"help": "Alpha is a scaling factor for updates. Alpha directly impacts the adapter's contribution and is often set to 1x or 2x the rank value."})
     lora_dropout: float = field(default=0.05)
-    lora_modules: str = field(default=None, metadata={"help": "Target modules: LoRA can be applied to various model components, including attention mechanisms (Q, K, V matrices), output projections, feed-forward blocks, and linear output layers. While initially focused on attention mechanisms, extending LoRA to other components has shown benefits. However, adapting more modules increases the number of trainable parameters and memory needs."})
-    lora_modules_to_save: str = field(default=None, metadata={"help": "Additional modules to save"})
+    lora_modules: Optional[str] = field(default=None, metadata={"help": "Target modules: LoRA can be applied to various model components, including attention mechanisms (Q, K, V matrices), output projections, feed-forward blocks, and linear output layers. While initially focused on attention mechanisms, extending LoRA to other components has shown benefits. However, adapting more modules increases the number of trainable parameters and memory needs."})
+    lora_modules_to_save: Optional[str] = field(default=None, metadata={"help": "Additional modules to save"})
     lora_merge: bool = field(default=False, metadata={"help": "If set, the LoRA will be merged back into the base model and saved"})

     # dpo config
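For context, the lora_* fields line up with the constructor arguments of peft.LoraConfig; a hedged sketch of how they plausibly feed into it (the actual construction, and how the comma-separated module strings are split, is not part of this diff):

# Illustrative mapping from the lora_* arguments to a peft.LoraConfig.
# run_args stands for a parsed TrainingRunArguments instance (assumption).
from peft import LoraConfig

lora_config = LoraConfig(
    r=run_args.lora_rank,
    lora_alpha=run_args.lora_alpha,
    lora_dropout=run_args.lora_dropout,
    # lora_modules / lora_modules_to_save look like comma-separated strings,
    # so they are split here purely for illustration.
    target_modules=run_args.lora_modules.split(",") if run_args.lora_modules else None,
    modules_to_save=run_args.lora_modules_to_save.split(",") if run_args.lora_modules_to_save else None,
)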
@@ -72,13 +72,13 @@ class TrainingRunArguments:
     add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
     add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
     add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"})
-    prefix_ids: str = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if the model cannot correctly recognize the response."})
-    suffix_ids: str = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if the model cannot correctly recognize the response."})
+    prefix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if the model cannot correctly recognize the response."})
+    suffix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if the model cannot correctly recognize the response."})

     # custom trainer tweaks
-    sync_to_bucket: str = field(default=None, metadata={"help": "If set, checkpoints will be synced to the s3 bucket specified by this argument"})
-    bucket_save_limit: int = field(default=None, metadata={"help": "The number of recent checkpoints of the model to save in S3 (not including the final model)"})
-    flops_baseline: str = field(default=None, metadata={"help": "The baseline flops for the GPUs used for the training run. Outputs MFU"})
+    sync_to_bucket: Optional[str] = field(default=None, metadata={"help": "If set, checkpoints will be synced to the s3 bucket specified by this argument"})
+    bucket_save_limit: Optional[int] = field(default=None, metadata={"help": "The number of recent checkpoints of the model to save in S3 (not including the final model)"})
+    flops_baseline: Optional[str] = field(default=None, metadata={"help": "The baseline flops for the GPUs used for the training run. Outputs MFU"})


 class UploadToS3Callback(TrainerCallback):
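The prefix_ids/suffix_ids arguments describe the token sequences that bracket the assistant's response so the SFT loss can be restricted to it. A generic illustration of that masking idea follows; this is not the actual DataCollatorForSupervisedFineTuning code (that class is only referenced in the next hunk's header):

# Generic sketch of response-only label masking with prefix/suffix marker ids.
# Tokens outside the assistant response get label -100 (ignored by the loss).
def mask_non_response(input_ids, prefix_ids, suffix_ids):
    labels = [-100] * len(input_ids)
    i = 0
    while i < len(input_ids):
        if input_ids[i:i + len(prefix_ids)] == prefix_ids:
            pos = i + len(prefix_ids)
            # keep labels until the suffix marker (or the end of the sequence)
            while pos < len(input_ids) and input_ids[pos:pos + len(suffix_ids)] != suffix_ids:
                labels[pos] = input_ids[pos]
                pos += 1
            i = pos + len(suffix_ids)
        else:
            i += 1
    return labels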
@@ -283,25 +283,37 @@ class DataCollatorForSupervisedFineTuning(object):
     )


-def tokenize_raw_example(batch, tokenizer=None, training_run_args=None):
-    return tokenizer(
-        text=batch["text"],
-        max_length=training_run_args.ctx_size,
-        truncation=True,
-        add_special_tokens=False,
-    )
-
-def tokenize_sharegpt_example(batch, tokenizer=None, training_run_args=None):
+def tokenize_example(batch, tokenizer=None, training_run_args=None):
     # TODO: figure out how to properly batch this
     result = []
-    for example in batch["conversations"]:
-        conversation = [ { "role": x["from"], "content": x["value"] } for x in example ]
+
+    # Get tools array if present (same for all examples in batch)
+    tools = batch.get("tools", [None] * len(batch["conversations"]))
+
+    for idx, example in enumerate(batch["conversations"]):
+        conversation = []
+
+        for message in example:
+            # Pass content directly - let chat template handle block formatting
+            conversation.append({
+                "role": message["role"],
+                "content": message["content"]
+            })
+
+        # Prepare kwargs for apply_chat_template
+        template_kwargs = {
+            "conversation": conversation,
+            "max_length": training_run_args.ctx_size,
+            "truncation": True,
+        }
+
+        # Add tools if present in this example
+        example_tools = tools[idx] if idx < len(tools) else None
+        if example_tools:
+            template_kwargs["tools"] = example_tools
+
         result.append(
-            tokenizer.apply_chat_template(
-                conversation=conversation,
-                max_length=training_run_args.ctx_size,
-                truncation=True,
-            )
+            tokenizer.apply_chat_template(**template_kwargs)
         )

     return {"input_ids": result}
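To make the new input shape concrete: with Dataset.map(batched=True), tokenize_example receives a dict of columns in which each entry of "conversations" is a list of role/content messages and "tools", when present, holds one tool list per conversation. A hypothetical batch (values are illustrative, not taken from the repository's dataset):

# Illustrative batch; the structure mirrors what the function above indexes.
# tokenizer / training_run_args are supplied via Dataset.map's fn_kwargs.
batch = {
    "conversations": [
        [
            {"role": "system", "content": "You are a smart home assistant."},
            {"role": "user", "content": "Turn off the kitchen light."},
            {"role": "assistant", "content": "Turning off the kitchen light."},
        ],
    ],
    # Optional column; each entry is forwarded to apply_chat_template(tools=...).
    "tools": [
        [{"type": "function", "function": {"name": "light_turn_off", "parameters": {"type": "object", "properties": {}}}}],
    ],
}

out = tokenize_example(batch, tokenizer=tokenizer, training_run_args=training_run_args)
# out["input_ids"] holds one token-id list per conversation, truncated to ctx_size.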
@@ -546,22 +558,22 @@ def do_training_run(training_run_args: TrainingRunArguments):
     # prepare the dataset
     ddp_print("Tokenizing datasets...")

-    if "text" in datasets["train"].column_names:
-        tokenize_function = tokenize_raw_example
-        columns_to_remove = ["text"]
-    elif "conversations" in datasets["train"].column_names:
-        tokenize_function = tokenize_sharegpt_example
-        columns_to_remove = ["conversations"]
-    else:
-        raise Exception("Unknown dataset input format (not raw corpus or sharegpt)")
+    if "conversations" not in datasets["train"].column_names:
+        raise Exception("Dataset must contain 'conversations' column in ShareGPT format")
+
+    columns_to_remove = ["conversations"]
+
+    # Remove tools column if present
+    if "tools" in datasets["train"].column_names:
+        columns_to_remove.append("tools")

     tokenized_test_dataset = None
     num_proc = None
     if training_run_args.dataset_processing_threads:
         num_proc = training_run_args.dataset_processing_threads // MULTI_GPU_WORLD_SIZE
-    tokenized_train_dataset = datasets["train"].map(tokenize_function, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)
+    tokenized_train_dataset = datasets["train"].map(tokenize_example, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)
     if training_run_args.test_dataset:
-        tokenized_test_dataset = datasets["test"].map(tokenize_function, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)
+        tokenized_test_dataset = datasets["test"].map(tokenize_example, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)

     example_lengths = [ len(example) for example in tokenized_train_dataset["input_ids"] ]
     tokens_in_train_set, longest_example = sum(example_lengths), max(example_lengths)
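Since the raw-corpus ("text" column) path is gone, anything passed as --train_dataset now has to provide a "conversations" column. A hypothetical JSON-lines file and load call that would satisfy the new check (file name and load_dataset usage are assumptions for illustration; this diff does not show how the datasets dict is built):

# train.jsonl, one object per line (illustrative):
# {"conversations": [{"role": "user", "content": "..."},
#                    {"role": "assistant", "content": "..."}],
#  "tools": [...]}        # optional column, dropped after tokenization
from datasets import load_dataset

datasets = load_dataset("json", data_files={"train": "train.jsonl"})
assert "conversations" in datasets["train"].column_names  # mirrors the new check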