finish implementing alternate dataset generation mode

This commit is contained in:
Alex O'Connell
2025-11-26 22:01:08 -05:00
parent 07507ee5f5
commit 14640bd14b
2 changed files with 517 additions and 371 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -27,8 +27,8 @@ class TrainingRunArguments:
run_name: str = field(metadata={"help": "The folder to save the output model under"})
base_model: str = field(metadata={"help": "The base model to load for fine-tuning"})
train_dataset: str = field(metadata={"help": "The JSON file containing the training dataset"})
test_dataset: str = field(default=None, metadata={"help": "The JSON file containing the evaluation dataset"})
dataset_processing_threads: int = field(default=None, metadata={"help": "The number of threads to use to tokenize the dataset"})
test_dataset: Optional[str] = field(default=None, metadata={"help": "The JSON file containing the evaluation dataset"})
dataset_processing_threads: Optional[int] = field(default=None, metadata={"help": "The number of threads to use to tokenize the dataset"})
ctx_size: int = field(default=2048, metadata={"help": "The number of tokens to pad & truncate the input examples to"})
bf16: bool = field(default=False, metadata={"help": "If set, the model will the loaded and trained in bf16 instead of fp32"})
batch_size: int = field(default=8, metadata={"help": "The simulated 'batch size' that we will train on. will tweak gradient accumulations steps"})
@@ -59,8 +59,8 @@ class TrainingRunArguments:
lora_rank: int = field(default=4, metadata={"help": "Rank which determines LoRA matrix size. Rank typically starts at 8 but can go up to 256. Higher ranks can store more information but increase the computational and memory cost of LoRA."})
lora_alpha: int = field(default=32, metadata={"help": "Alpha a scaling factor for updates. Alpha directly impacts the adapters contribution and is often set to 1x or 2x the rank value."})
lora_dropout: float = field(default=0.05)
lora_modules: str = field(default=None, metadata={"help": "Target modules: LoRA can be applied to various model components, including attention mechanisms (Q, K, V matrices), output projections, feed-forward blocks, and linear output layers. While initially focused on attention mechanisms, extending LoRA to other components has shown benefits. However, adapting more modules increases the number of trainable parameters and memory needs."})
lora_modules_to_save: str = field(default=None, metadata={"help": "Additional modules to save"})
lora_modules: Optional[str] = field(default=None, metadata={"help": "Target modules: LoRA can be applied to various model components, including attention mechanisms (Q, K, V matrices), output projections, feed-forward blocks, and linear output layers. While initially focused on attention mechanisms, extending LoRA to other components has shown benefits. However, adapting more modules increases the number of trainable parameters and memory needs."})
lora_modules_to_save: Optional[str] = field(default=None, metadata={"help": "Additional modules to save"})
lora_merge: bool = field(default=False, metadata={"help": "If set, the Lora will be merged back into the base model an saved"})
# dpo config
@@ -72,13 +72,13 @@ class TrainingRunArguments:
add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"})
prefix_ids: str = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
suffix_ids: str = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
prefix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
suffix_ids: Optional[str] = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
# custom trainer tweaks
sync_to_bucket: str = field(default=None, metadata={"help": "If set, checkpoints will be synced to the s3 bucket specified by this argument"})
bucket_save_limit: int = field(default=None, metadata={"help": "The number of recent checkpoints of the model to save in S3 (not including the final model)"})
flops_baseline: str = field(default=None, metadata={"help": "The baseline flops for the GPUs used for the training run. Outputs MFU"})
sync_to_bucket: Optional[str] = field(default=None, metadata={"help": "If set, checkpoints will be synced to the s3 bucket specified by this argument"})
bucket_save_limit: Optional[int] = field(default=None, metadata={"help": "The number of recent checkpoints of the model to save in S3 (not including the final model)"})
flops_baseline: Optional[str] = field(default=None, metadata={"help": "The baseline flops for the GPUs used for the training run. Outputs MFU"})
class UploadToS3Callback(TrainerCallback):
@@ -283,25 +283,37 @@ class DataCollatorForSupervisedFineTuning(object):
)
def tokenize_raw_example(batch, tokenizer=None, training_run_args=None):
return tokenizer(
text=batch["text"],
max_length=training_run_args.ctx_size,
truncation=True,
add_special_tokens=False,
)
def tokenize_sharegpt_example(batch, tokenizer=None, training_run_args=None):
def tokenize_example(batch, tokenizer=None, training_run_args=None):
# TODO: figure out how to properly batch this
result = []
for example in batch["conversations"]:
conversation = [ { "role": x["from"], "content": x["value"] } for x in example ]
# Get tools array if present (same for all examples in batch)
tools = batch.get("tools", [None] * len(batch["conversations"]))
for idx, example in enumerate(batch["conversations"]):
conversation = []
for message in example:
# Pass content directly - let chat template handle block formatting
conversation.append({
"role": message["role"],
"content": message["content"]
})
# Prepare kwargs for apply_chat_template
template_kwargs = {
"conversation": conversation,
"max_length": training_run_args.ctx_size,
"truncation": True,
}
# Add tools if present in this example
example_tools = tools[idx] if idx < len(tools) else None
if example_tools:
template_kwargs["tools"] = example_tools
result.append(
tokenizer.apply_chat_template(
conversation=conversation,
max_length=training_run_args.ctx_size,
truncation=True,
)
tokenizer.apply_chat_template(**template_kwargs)
)
return {"input_ids": result}
@@ -546,22 +558,22 @@ def do_training_run(training_run_args: TrainingRunArguments):
# prepare the dataset
ddp_print("Tokenizing datasets...")
if "text" in datasets["train"].column_names:
tokenize_function = tokenize_raw_example
columns_to_remove = ["text"]
elif "conversations" in datasets["train"].column_names:
tokenize_function = tokenize_sharegpt_example
columns_to_remove = ["conversations"]
else:
raise Exception("Unknown dataset input format (not raw corpus or sharegpt)")
if "conversations" not in datasets["train"].column_names:
raise Exception("Dataset must contain 'conversations' column in ShareGPT format")
columns_to_remove = ["conversations"]
# Remove tools column if present
if "tools" in datasets["train"].column_names:
columns_to_remove.append("tools")
tokenized_test_dataset = None
num_proc = None
if training_run_args.dataset_processing_threads:
num_proc = training_run_args.dataset_processing_threads // MULTI_GPU_WORLD_SIZE
tokenized_train_dataset = datasets["train"].map(tokenize_function, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)
tokenized_train_dataset = datasets["train"].map(tokenize_example, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)
if training_run_args.test_dataset:
tokenized_test_dataset = datasets["test"].map(tokenize_function, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)
tokenized_test_dataset = datasets["test"].map(tokenize_example, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)
example_lengths = [ len(example) for example in tokenized_train_dataset["input_ids"] ]
tokens_in_train_set, longest_example = sum(example_lengths), max(example_lengths)