finish implementing alternate dataset generation mode

commit 14640bd14b
parent 07507ee5f5
Author: Alex O'Connell
Date: 2025-11-26 22:01:08 -05:00

2 changed files with 517 additions and 371 deletions

File diff suppressed because it is too large.


@@ -27,8 +27,8 @@ class TrainingRunArguments:
     run_name: str = field(metadata={"help": "The folder to save the output model under"})
     base_model: str = field(metadata={"help": "The base model to load for fine-tuning"})
     train_dataset: str = field(metadata={"help": "The JSON file containing the training dataset"})
-    test_dataset: str = field(default=None, metadata={"help": "The JSON file containing the evaluation dataset"})
-    dataset_processing_threads: int = field(default=None, metadata={"help": "The number of threads to use to tokenize the dataset"})
+    test_dataset: Optional[str] = field(default=None, metadata={"help": "The JSON file containing the evaluation dataset"})
+    dataset_processing_threads: Optional[int] = field(default=None, metadata={"help": "The number of threads to use to tokenize the dataset"})
     ctx_size: int = field(default=2048, metadata={"help": "The number of tokens to pad & truncate the input examples to"})
     bf16: bool = field(default=False, metadata={"help": "If set, the model will the loaded and trained in bf16 instead of fp32"})
     batch_size: int = field(default=8, metadata={"help": "The simulated 'batch size' that we will train on. will tweak gradient accumulations steps"})
@@ -59,8 +59,8 @@ class TrainingRunArguments:
     lora_rank: int = field(default=4, metadata={"help": "Rank which determines LoRA matrix size. Rank typically starts at 8 but can go up to 256. Higher ranks can store more information but increase the computational and memory cost of LoRA."})
     lora_alpha: int = field(default=32, metadata={"help": "Alpha a scaling factor for updates. Alpha directly impacts the adapters contribution and is often set to 1x or 2x the rank value."})
     lora_dropout: float = field(default=0.05)
-    lora_modules: str = field(default=None, metadata={"help": "Target modules: LoRA can be applied to various model components, including attention mechanisms (Q, K, V matrices), output projections, feed-forward blocks, and linear output layers. While initially focused on attention mechanisms, extending LoRA to other components has shown benefits. However, adapting more modules increases the number of trainable parameters and memory needs."})
-    lora_modules_to_save: str = field(default=None, metadata={"help": "Additional modules to save"})
+    lora_modules: Optional[str] = field(default=None, metadata={"help": "Target modules: LoRA can be applied to various model components, including attention mechanisms (Q, K, V matrices), output projections, feed-forward blocks, and linear output layers. While initially focused on attention mechanisms, extending LoRA to other components has shown benefits. However, adapting more modules increases the number of trainable parameters and memory needs."})
+    lora_modules_to_save: Optional[str] = field(default=None, metadata={"help": "Additional modules to save"})
     lora_merge: bool = field(default=False, metadata={"help": "If set, the Lora will be merged back into the base model an saved"})
 
     # dpo config
@@ -72,13 +72,13 @@ class TrainingRunArguments:
     add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
     add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
     add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"})
-    prefix_ids: str = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
-    suffix_ids: str = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
+    prefix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
+    suffix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
 
     # custom trainer tweaks
-    sync_to_bucket: str = field(default=None, metadata={"help": "If set, checkpoints will be synced to the s3 bucket specified by this argument"})
-    bucket_save_limit: int = field(default=None, metadata={"help": "The number of recent checkpoints of the model to save in S3 (not including the final model)"})
-    flops_baseline: str = field(default=None, metadata={"help": "The baseline flops for the GPUs used for the training run. Outputs MFU"})
+    sync_to_bucket: Optional[str] = field(default=None, metadata={"help": "If set, checkpoints will be synced to the s3 bucket specified by this argument"})
+    bucket_save_limit: Optional[int] = field(default=None, metadata={"help": "The number of recent checkpoints of the model to save in S3 (not including the final model)"})
+    flops_baseline: Optional[str] = field(default=None, metadata={"help": "The baseline flops for the GPUs used for the training run. Outputs MFU"})
 
 
 class UploadToS3Callback(TrainerCallback):
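
The three hunks above only tighten the type hints: every field that defaults to None is now annotated Optional[...], so the dataclass annotations match the values they can actually hold. Assuming these arguments are parsed with transformers' HfArgumentParser (the parser wiring is outside this diff), a minimal sketch of how such a field behaves; ExampleArguments and its fields are illustrative, not the real dataclass:

# Minimal sketch, assuming HfArgumentParser drives the CLI (not shown in this diff).
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ExampleArguments:
    run_name: str = field(metadata={"help": "Required: output folder name"})
    # Optional[str] documents that the value may legitimately stay None
    test_dataset: Optional[str] = field(default=None, metadata={"help": "Optional eval set"})


parser = HfArgumentParser(ExampleArguments)
(args,) = parser.parse_args_into_dataclasses(["--run_name", "demo-run"])
print(args.run_name, args.test_dataset)  # demo-run None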
@@ -283,25 +283,37 @@ class DataCollatorForSupervisedFineTuning(object):
     )
 
 
-def tokenize_raw_example(batch, tokenizer=None, training_run_args=None):
-    return tokenizer(
-        text=batch["text"],
-        max_length=training_run_args.ctx_size,
-        truncation=True,
-        add_special_tokens=False,
-    )
-
-def tokenize_sharegpt_example(batch, tokenizer=None, training_run_args=None):
+def tokenize_example(batch, tokenizer=None, training_run_args=None):
     # TODO: figure out how to properly batch this
     result = []
-    for example in batch["conversations"]:
-        conversation = [ { "role": x["from"], "content": x["value"] } for x in example ]
+
+    # Get tools array if present (same for all examples in batch)
+    tools = batch.get("tools", [None] * len(batch["conversations"]))
+
+    for idx, example in enumerate(batch["conversations"]):
+        conversation = []
+        for message in example:
+            # Pass content directly - let chat template handle block formatting
+            conversation.append({
+                "role": message["role"],
+                "content": message["content"]
+            })
+
+        # Prepare kwargs for apply_chat_template
+        template_kwargs = {
+            "conversation": conversation,
+            "max_length": training_run_args.ctx_size,
+            "truncation": True,
+        }
+
+        # Add tools if present in this example
+        example_tools = tools[idx] if idx < len(tools) else None
+        if example_tools:
+            template_kwargs["tools"] = example_tools
+
         result.append(
-            tokenizer.apply_chat_template(
-                conversation=conversation,
-                max_length=training_run_args.ctx_size,
-                truncation=True,
-            )
+            tokenizer.apply_chat_template(**template_kwargs)
         )
+
     return {"input_ids": result}
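
For reference, the rewritten tokenize_example expects ShareGPT-style records whose messages already carry role/content keys (the old from/value remapping is gone) plus an optional per-example tools list that is forwarded to apply_chat_template. A minimal sketch of one such record follows; the message text and the OpenAI-style function schema are illustrative assumptions, not taken from the commit.

# Illustrative record shape consumed by the new tokenize_example.
# The concrete values and the tool schema layout are assumptions for this sketch;
# the exact schema your chat template accepts may differ.
example_record = {
    "conversations": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Turn on the kitchen lights."},
        {"role": "assistant", "content": "The kitchen lights are now on."},
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "light_turn_on",
                "description": "Turn on a light entity",
                "parameters": {
                    "type": "object",
                    "properties": {"entity_id": {"type": "string"}},
                    "required": ["entity_id"],
                },
            },
        },
    ],
}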
@@ -546,22 +558,22 @@ def do_training_run(training_run_args: TrainingRunArguments):
     # prepare the dataset
     ddp_print("Tokenizing datasets...")
-    if "text" in datasets["train"].column_names:
-        tokenize_function = tokenize_raw_example
-        columns_to_remove = ["text"]
-    elif "conversations" in datasets["train"].column_names:
-        tokenize_function = tokenize_sharegpt_example
-        columns_to_remove = ["conversations"]
-    else:
-        raise Exception("Unknown dataset input format (not raw corpus or sharegpt)")
+    if "conversations" not in datasets["train"].column_names:
+        raise Exception("Dataset must contain 'conversations' column in ShareGPT format")
+
+    columns_to_remove = ["conversations"]
+
+    # Remove tools column if present
+    if "tools" in datasets["train"].column_names:
+        columns_to_remove.append("tools")
 
     tokenized_test_dataset = None
     num_proc = None
     if training_run_args.dataset_processing_threads:
         num_proc = training_run_args.dataset_processing_threads // MULTI_GPU_WORLD_SIZE
 
-    tokenized_train_dataset = datasets["train"].map(tokenize_function, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)
+    tokenized_train_dataset = datasets["train"].map(tokenize_example, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)
     if training_run_args.test_dataset:
-        tokenized_test_dataset = datasets["test"].map(tokenize_function, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)
+        tokenized_test_dataset = datasets["test"].map(tokenize_example, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)
 
     example_lengths = [ len(example) for example in tokenized_train_dataset["input_ids"] ]
     tokens_in_train_set, longest_example = sum(example_lengths), max(example_lengths)
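
Since the training path now hard-requires a "conversations" column and conditionally drops "tools", one way to validate a dataset before a full run is to replay the same map() call on a tiny in-memory dataset. The sketch below assumes tokenize_example from the hunk above is already in scope, and it uses an arbitrary chat-template-capable tokenizer plus a SimpleNamespace standing in for TrainingRunArguments; the model name and the stand-in are assumptions, not part of the commit.

# Hypothetical smoke test: replay the training script's map() call on a tiny
# in-memory dataset. Assumes tokenize_example (defined in the diff above) is in scope.
from types import SimpleNamespace

from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # assumed model
fake_args = SimpleNamespace(ctx_size=2048)  # stand-in for TrainingRunArguments

tiny = Dataset.from_list([
    {
        "conversations": [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
    }
])

tokenized = tiny.map(
    tokenize_example,
    batched=True,
    fn_kwargs={"tokenizer": tokenizer, "training_run_args": fake_args},
).remove_columns(["conversations"])

print(len(tokenized[0]["input_ids"]))  # token count for the single example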