mirror of https://github.com/acon96/home-llm.git (synced 2026-01-08 21:28:05 -05:00)
fix eval script + add notes from recent training runs
@@ -1,3 +1,4 @@
#!/bin/env python3
import argparse
import json
import csv

@@ -1011,6 +1012,19 @@ def merge_languages(filename_prefix: str, languages: list):
        f.writelines(all_examples)


def shard_languages(filename_prefix: str, main_lang: str, secondary_lang: str, ratio: float):
    all_examples = []
    with open(f"{filename_prefix}_{main_lang}.jsonl") as f:
        all_examples.extend(f.readlines())

    with open(f"{filename_prefix}_{secondary_lang}.jsonl") as f:
        lines = f.readlines()
        all_examples.extend(random.sample(lines, int(len(lines) * ratio)))

    with open(f"{filename_prefix}.jsonl", "w") as f:
        f.writelines(all_examples)


def load_dataset_piles(language):
    global pile_of_durations, pile_of_media_names, pile_of_todo_items, stacks_of_device_names, \
        pile_of_templated_actions, pile_of_specific_actions, pile_of_responses, pile_of_status_requests, \

@@ -1096,6 +1110,7 @@ def main():
    parser.add_argument("--dpo", action="store_true", help="Set this flag to enable generation of the DPO dataset.")
    parser.add_argument("--merge", help="Set this flag to merge the generated datasets with the specified dataset.")
    parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate")
    parser.add_argument("--shard", action="store_true", help="Shard the provided language with examples from another language at a 10:1 ratio. Only supports 2 languages, and the first one will be the majority language.")

    train_size_group = parser.add_mutually_exclusive_group()
    train_size_group.add_argument('--small', action='store_const', const='small', dest='size')

@@ -1116,6 +1131,10 @@ def main():
    if args.size and not args.train:
        print("Train size was provided but not generating the training set!")
        exit(-1)

    if args.shard and len(args.language) != 2:
        print("Can only shard when 2 languages are provided!")
        exit(-1)

    if not args.format or args.format == "raw":
        format_func = format_example_raw_chatml

@@ -1123,6 +1142,7 @@ def main():
        format_func = format_example_sharegpt

    for language in args.language:
        print(f"Handling {language}")
        load_dataset_piles(language)
        personas = list(pile_of_system_prompts.keys())
        suffix = f"_{language}" if len(args.language) > 1 else ""

@@ -1145,11 +1165,20 @@ def main():

    if len(args.language) > 1:
        if args.sample:
            merge_languages("sample", args.language)
            if args.shard:
                shard_languages("sample", args.language[0], args.language[1], 0.1)
            else:
                merge_languages("sample", args.language)
        if args.train:
            merge_languages("home_assistant_train", args.language)
            if args.shard:
                shard_languages("home_assistant_train", args.language[0], args.language[1], 0.1)
            else:
                merge_languages("home_assistant_train", args.language)
        if args.test:
            merge_languages("home_assistant_test", args.language)
            if args.shard:
                shard_languages("home_assistant_test", args.language[0], args.language[1], 0.1)
            else:
                merge_languages("home_assistant_test", args.language)

    if args.dpo:
        generate_dpo_file(f"home_assistant_dpo", 42, format_example_dpo, personas, wrong_argument_factor=1, no_argument_factor=1, extra_service_call_factor=1, incorrect_persona_factor=1)
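The new `--shard` path keeps the first language as the majority and mixes in a `ratio` sample of the second. Note that the nominal 10:1 split only holds when the two per-language files are roughly the same size, since `ratio` is applied to the secondary file's line count. A minimal, self-contained sketch of the same sampling arithmetic; the file names and the fixed seed are illustrative assumptions, not something the script sets:

```python
# Illustrative sketch of the --shard sampling; paths and the seed are assumptions.
import random

random.seed(42)  # assumed here for reproducibility; shard_languages itself does not seed

with open("home_assistant_train_polish.jsonl") as f:
    main_lines = f.readlines()            # majority language, kept in full
with open("home_assistant_train_english.jsonl") as f:
    secondary_lines = f.readlines()       # minority language, sampled

take = int(len(secondary_lines) * 0.1)    # ratio=0.1 of the *secondary* file's size
with open("home_assistant_train.jsonl", "w") as f:
    f.writelines(main_lines + random.sample(secondary_lines, take))

print(f"kept {len(main_lines)} main + {take} secondary examples")
```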
docs/experiment-notes-qwen.md (new file, 38 lines)
@@ -0,0 +1,38 @@
## Qwen/Qwen2-0.5B-Instruct

# tinyhome-qwen-rev1
- full fine tune
- epochs: 1
- 2048 train ctx
- batch size 32
- learning rate 2e-5
- weight decay 0.1
- gradient clipping 1.0
- dataset size: small
+ evaluation results: NEEDS RE-TEST b/c OF BAD EVAL SCRIPT

# tinyhome-qwen-rev2
- full fine tune
- epochs: 1
- 2048 train ctx
- batch size 32
- learning rate 2e-5
- weight decay 0.1
- gradient clipping 1.0
- dataset size: medium
+ evaluation results: NEEDS RE-TEST b/c OF BAD EVAL SCRIPT

# tinyhome-qwen-rev3
- full fine tune
- epochs: 1
- 2048 train ctx
- batch size 64
- learning rate 2e-5
- weight decay 0.1
- gradient clipping 1.0
- dataset size: small (4 language mix)
+ evaluation results:
    - english: 0.9842022116903634
    - german: 0.8992834394904459
    - french: 0.9307445956765412
    - spanish: 0.9406099518459069

@@ -36,4 +36,15 @@
- 800:
- 900:
- 1000:
- Final: 0.9817813765182186
- Final: 0.9817813765182186


# tinyhome-polish-rev2
- dataset size: small (polish sharded with english)
- learning rate 2e-5
+ evaluation results: NEEDS RE-TEST b/c OF BAD EVAL SCRIPT

# tinyhome-polish-rev2
- dataset size: medium (polish sharded with english)
- learning rate 2e-5
+ evaluation results: 0.8115246098439376
evaluate.py (51 changed lines)
@@ -79,11 +79,17 @@ def generate(model, tokenizer, prompts):
inputs = tokenize(tokenizer, prompts)
with torch.no_grad():
outputs = model.generate(**inputs)
text = tokenizer.batch_decode(outputs)

text = []
for batch_inputs, batch_outputs in zip(inputs["input_ids"], outputs):
text.append(tokenizer.decode(batch_outputs[batch_inputs.shape[0]:]))

return text

def evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_size, use_icl):
split = trained_tokenizer.apply_chat_template(conversation=[{"role": "assistant", "content": r"%%%%%%%%%%%%%%%%"}], tokenize=False).split(r"%%%%%%%%%%%%%%%%")[0].replace(trained_tokenizer.bos_token, "")
split = trained_tokenizer.apply_chat_template(conversation=[{"role": "assistant", "content": r"%%%%%%%%%%%%%%%%"}], tokenize=False).split(r"%%%%%%%%%%%%%%%%")[0].replace(trained_tokenizer.bos_token or "", "")
# print(split)
split = "<|im_start|> assistant"

print("Evaluating...")
correct_answers = 0
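Two things change in this hunk: generate() now decodes only the tokens produced after the prompt (so the echoed prompt is never parsed as part of the answer), and the assistant-prefix detection no longer crashes on tokenizers whose bos_token is None, since str.replace(None, "") raises a TypeError. A hedged sketch of the sentinel trick the split string is built with; it assumes any tokenizer that ships a chat template, and the model name is just the one from the notes above:

```python
# Sketch: recover the assistant-turn prefix from a chat template with a sentinel.
# Assumes a tokenizer with a chat template; the model name is illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
sentinel = r"%%%%%%%%%%%%%%%%"
rendered = tokenizer.apply_chat_template(
    conversation=[{"role": "assistant", "content": sentinel}], tokenize=False
)
# Everything before the sentinel is the assistant prefix; `or ""` guards tokenizers
# that have no BOS token (bos_token is None), which is what the fix above addresses.
assistant_prefix = rendered.split(sentinel)[0].replace(tokenizer.bos_token or "", "")
print(repr(assistant_prefix))
# ChatML-style templates end in "<|im_start|>assistant\n"; templates that inject a
# default system turn include that turn too, which may be why the script pins `split`.
```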
@@ -93,6 +99,7 @@ def evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_siz
# pre-allocate cuda buffers
inputs = trained_tokenizer([""] * batch_size, return_tensors="pt", max_length=CTX_SIZE, padding="max_length", truncation=True)
inputs = {k: v.to(trained_model.device) for k, v in inputs.items()}

with torch.no_grad():
outputs = trained_model(**inputs)

@@ -138,8 +145,13 @@ def evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_siz
output = generate(trained_model, trained_tokenizer, prompts)

for model_output, expected_response in zip(output, expected_responses):
response = model_output.replace(trained_tokenizer.pad_token, "").replace(trained_tokenizer.eos_token, "").split(split)[1]

try:
response = model_output.split(trained_tokenizer.eos_token)[0]
except Exception as ex:
print("model_output----------------------------------")
print(model_output)
raise ex

expected_service_calls = []

if use_icl:

@@ -152,12 +164,12 @@ def evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_siz
if len(line) == 0:
continue
expected_service_calls.append(json.loads(line))
total_answers = total_answers + 1

total_answers = total_answers + 1

found_responses = regex_to_use.findall(response.strip())

if len(expected_service_calls) == 0:
total_answers = total_answers + 1
if len(found_responses) == 0:
correct_answers = correct_answers + 1
continue

@@ -169,14 +181,21 @@ def evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_siz
failed_examples.append({ "expected": expected_response, "actual": response, "no_response_found": True })
continue

processed_tool_calls = 0
color_mismatch = False
failure_flags = {}
for block in found_responses:
if processed_tool_calls >= len(expected_service_calls):
failure_flags.update({"extra_service_calls": True})
break

for line in block.split("\n"):
if len(line) == 0:
continue
try:
json_output = json.loads(line)
except:
failed_examples.append({ "expected": expected_response, "actual": response, "invalid_json": True })
failure_flags.update({"invalid_json": True})
continue

if use_icl:

@@ -184,7 +203,8 @@ def evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_siz

if json_output in expected_service_calls:
expected_service_calls.pop(expected_service_calls.index(json_output))
correct_answers = correct_answers + 1

processed_tool_calls = processed_tool_calls + 1
elif "rgb_color" in json_output:
for sc in expected_service_calls:
sc = { **sc }

@@ -196,10 +216,19 @@ def evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_siz
if sc == json_output_copy:
correct_answers = correct_answers + 1
color_mismatches = color_mismatches + 1
processed_tool_calls = processed_tool_calls + 1
else:
failed_examples.append({ "expected": expected_response, "actual": response })
failure_flags.update({"bad_service_call": True})
else:
failed_examples.append({ "expected": expected_response, "actual": response })
failure_flags.update({"bad_service_call": True})

if len(failure_flags) == 0:
correct_answers = correct_answers + 1
else:
failed_examples.append({ "expected": expected_response, "actual": response, **failure_flags })

if color_mismatch:
color_mismatches = color_mismatches + 1

pbar.update(batch_size)
pbar.set_description(f"Accuracy: {correct_answers/total_answers*100:.2f}% ({correct_answers}/{total_answers})")
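The substance of the eval-script fix is in the last few hunks: instead of incrementing the totals once per parsed service call (which let an answer with one good call and one bad call still score partially), each example now accumulates failure_flags and only counts as correct when no flag was raised, with colour-only mismatches tallied separately. A condensed, self-contained sketch of that per-example scoring pattern; the names mirror the diff, the rgb_color tolerance branch is omitted, and the helper plus sample data are made up for illustration:

```python
# Condensed sketch of the per-example scoring used above; data is illustrative.
import json

def score_example(found_blocks, expected_service_calls):
    """Return (correct, failure_flags) for a single model answer."""
    expected = list(expected_service_calls)     # copy so matches can be removed
    failure_flags = {}
    processed = 0
    for block in found_blocks:
        if processed >= len(expected_service_calls):
            failure_flags["extra_service_calls"] = True
            break
        for line in filter(None, block.split("\n")):
            try:
                call = json.loads(line)
            except json.JSONDecodeError:
                failure_flags["invalid_json"] = True
                continue
            if call in expected:
                expected.remove(call)
                processed += 1
            else:
                failure_flags["bad_service_call"] = True
    # the example only counts as correct when nothing at all went wrong
    return len(failure_flags) == 0, failure_flags

ok, flags = score_example(
    ['{"service": "light.turn_on", "target_device": "light.kitchen"}'],
    [{"service": "light.turn_on", "target_device": "light.kitchen"}],
)
print(ok, flags)  # True {}
```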
@@ -350,7 +379,7 @@ def main():
print(f"Evaluation already exists for {output_folder}. Skipping...")
continue

trained_model, trained_tokenizer = load_model(args.model, args.lora, ckpt, False)
trained_model, trained_tokenizer = load_model(args.model, args.lora, False, False, ckpt)
evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_size, False)
@@ -11,8 +11,7 @@ fi

echo "Converting to GGUF..."
if [ ! -f "./models/$MODEL_NAME/$MODEL_NAME.f16.gguf" ]; then
    $LLAMA_CPP/convert.py --outfile ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf --outtype f16 ./models/$MODEL_NAME/
    # $LLAMA_CPP/convert-hf-to-gguf.py --outfile ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf --outtype f16 ./models/$MODEL_NAME/
    $LLAMA_CPP/convert-hf-to-gguf.py --outfile ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf --outtype f16 ./models/$MODEL_NAME/
else
    echo "Converted model for already exists. Skipping..."
fi
scripts/upload_to_hf.sh (new file, 14 lines)
@@ -0,0 +1,14 @@
#!/bin/bash

MODEL_NAME=$1

pushd models/
huggingface-cli upload $MODEL_NAME \
    --repo-type model \
    --commit-message "Upload model" \
    --include "*.gguf" "README.md"

# huggingface-cli upload $MODEL_NAME \
#     --repo-type model \
#     --commit-message "Upload safetensors" \
#     --include "*.safetensors" "config.json" "special_tokens_map.json" "tokenizer_config.json" "tokenizer.json" "tokenizer.model" "generation_config.json"
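The new upload script shells out to huggingface-cli; the same upload can also be driven from Python with huggingface_hub, which is sometimes handier inside a training or release script. A hedged equivalent, where the repo id and folder path are placeholders rather than values taken from this commit:

```python
# Rough Python equivalent of scripts/upload_to_hf.sh using huggingface_hub.
# repo_id and folder_path are placeholders; adjust them to the model being uploaded.
from huggingface_hub import upload_folder

upload_folder(
    repo_id="your-username/tinyhome-qwen-rev3",   # assumption: target model repo
    repo_type="model",
    folder_path="models/tinyhome-qwen-rev3",
    allow_patterns=["*.gguf", "README.md"],       # mirrors the script's --include filters
    commit_message="Upload model",
)
```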
train.py (105 changed lines)
@@ -96,6 +96,20 @@ python3 train.py \
    --use_lora --lora_rank 32 --lora_alpha 64 --lora_modules up_proj,down_proj,q_proj,v_proj,o_proj
"""

"""
python3 train.py \
    --run_name mistralhome-bielik-rev1 \
    --base_model speakleash/Bielik-7B-Instruct-v0.1 \
    --bf16 \
    --train_dataset data/home_assistant_train.jsonl \
    --test_dataset data/home_assistant_test.jsonl \
    --learning_rate 1e-5 --learning_rate_warmup 0.03 --batch_size 64 --epochs 1 \
    --micro_batch_size 2 --gradient_checkpointing --group_by_length \
    --ctx_size 2048 \
    --save_steps 100 --save_total_limit 20 --eval_steps 200 --logging_steps 5 \
    --use_lora --lora_rank 32 --lora_alpha 64 --lora_modules up_proj,down_proj,q_proj,v_proj,o_proj
"""

"""
accelerate launch --config_file fsdp_config.yaml train.py \
    --run_name stablehome-3b-rev10 \

@@ -143,6 +157,42 @@ python3 train.py \
    --ctx_size 2048 --save_steps 100 --save_total_limit 10
"""

"""
python3 train.py \
    --run_name tinyhome-qwen-rev3 \
    --base_model Qwen/Qwen2-0.5B-Instruct \
    --bf16 \
    --train_dataset data/home_assistant_train.jsonl \
    --test_dataset data/home_assistant_test.jsonl \
    --learning_rate 2e-5 --batch_size 64 \
    --micro_batch_size 8 --gradient_checkpointing --group_by_length \
    --ctx_size 2048 --save_steps 1000
"""

"""
python3 train.py \
    --run_name home-phi3-mini-rev1 \
    --base_model microsoft/Phi-3-mini-4k-instruct \
    --bf16 \
    --train_dataset data/home_assistant_train.jsonl \
    --test_dataset data/home_assistant_test.jsonl \
    --learning_rate 5e-6 --batch_size 32 \
    --micro_batch_size 8 --gradient_checkpointing --group_by_length \
    --ctx_size 2048 --save_steps 100 --save_total_limit 10
"""

"""
python3 train.py \
    --run_name tinyhome-polish-rev1 \
    --base_model eryk-mazus/polka-1.1b-chat \
    --bf16 \
    --train_dataset data/home_assistant_train.jsonl \
    --test_dataset data/home_assistant_test.jsonl \
    --learning_rate 2e-5 --batch_size 32 \
    --micro_batch_size 8 --gradient_checkpointing --group_by_length \
    --ctx_size 2048 --save_steps 100 --save_total_limit 10
"""

"""
python3 train.py \
    --run_name tinyhome-rev2-dpo \

@@ -156,6 +206,11 @@ python3 train.py \
    --save_steps 50 --save_total_limit 10 --eval_steps 100 --logging_steps 2
"""

MULTI_GPU_WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1"))
MULTI_GPU_RANK = int(os.environ.get("RANK", "0"))
IS_MULTI_GPU = os.environ.get("RANK") != None
IS_MASTER_PROCESS = MULTI_GPU_RANK == 0

@dataclass
class TrainingRunArguments:
    run_name: str = field(metadata={"help": "The folder to save the output model under"})
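Most of the remaining train.py changes gate informational print calls behind IS_MASTER_PROCESS so that multi-GPU runs, where torchrun or accelerate export RANK and WORLD_SIZE into every worker's environment, log once instead of once per process. A tiny self-contained sketch of the same pattern:

```python
# Sketch of rank-gated logging: every worker runs this, only rank 0 prints.
# torchrun / accelerate set RANK and WORLD_SIZE per worker; a plain
# single-process run has neither, so it falls back to rank 0.
import os

RANK = int(os.environ.get("RANK", "0"))
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1"))
IS_MASTER_PROCESS = RANK == 0

def log(msg: str):
    if IS_MASTER_PROCESS:
        print(msg)

log(f"Starting run on {WORLD_SIZE} process(es)")  # printed exactly once
```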
@@ -283,7 +338,8 @@ training_run_args, _ = parser.parse_args_into_dataclasses(return_remaining_strin
if sum([training_run_args.load_in_8bit, training_run_args.load_in_4bit, training_run_args.load_as_gptq]) > 1:
    raise Exception("Please select exactly one of 'load_in_8bit', 'load_in_4bit', or 'load_as_gptq'")

print(f"Loading model '{training_run_args.base_model}'...")
if IS_MASTER_PROCESS:
    print(f"Loading model '{training_run_args.base_model}'...")

model_kwargs = {}
if training_run_args.load_in_8bit:

@@ -310,7 +366,8 @@ def find_max_vram(min_buffer_mib=800):
        suggestion = round((total_mem - 1000) / 1000) * 1000
        suggestion = min(suggestion, total_mem - min_buffer_mib)

        print(f"Model will target using {suggestion}MiB of VRAM on GPU {i}")
        if IS_MASTER_PROCESS:
            print(f"Model will target using {suggestion}MiB of VRAM on GPU {i}")
        max_memory[i] = f'{suggestion}MiB'

    return max_memory

@@ -361,7 +418,8 @@ original_model = model
peft_config = None
if training_run_args.use_lora:
    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
    print("Creating LoRA for model...")
    if IS_MASTER_PROCESS:
        print("Creating LoRA for model...")
    target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
    modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
    peft_config = LoraConfig(

@@ -390,8 +448,8 @@ training_kwargs = {}

if training_run_args.test_dataset:
    training_kwargs.update({
        "per_device_eval_batch_size": training_run_args.batch_size,
        "evaluation_strategy": ("steps" if training_run_args.eval_steps != -1 else "epoch"),
        "per_device_eval_batch_size": training_run_args.micro_batch_size,
        "eval_strategy": ("steps" if training_run_args.eval_steps != -1 else "epoch"),
        "eval_steps": (training_run_args.eval_steps if training_run_args.eval_steps != -1 else None),
        "bf16_full_eval": training_run_args.bf16,
    })
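Two fixes hide in that last hunk: evaluation now uses micro_batch_size per device (the full logical batch_size could run out of memory at eval time even when training fits, since training splits it up via gradient accumulation), and the evaluation_strategy key is renamed to eval_strategy to match the newer transformers TrainingArguments name. A hedged sketch of how those kwargs end up in TrainingArguments, with illustrative values:

```python
# Illustrative only: how the eval-related kwargs above feed TrainingArguments.
# Requires a recent transformers release where the argument is `eval_strategy`
# (older releases called it `evaluation_strategy`).
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./models/example-run",   # placeholder path
    per_device_train_batch_size=8,       # micro batch per GPU
    per_device_eval_batch_size=8,        # eval now uses the micro batch size too
    eval_strategy="steps",
    eval_steps=200,
    bf16_full_eval=True,
)
```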
@@ -514,8 +572,9 @@ class DataCollatorForSupervisedFineTuning(object):
        for label in labels:
            mask_ranges = self._find_mask_ranges(label)
            for start, end in mask_ranges:
                if end - start == len(label):
                if end - start == len(label) - 1:
                    print("warning! example had no assistant response in it!")
                    print(input_ids)
                label[start:end] = [-100] * (end - start)

        input_ids = torch.LongTensor(self._pad(input_ids, self.tokenizer.pad_token_id or self.tokenizer.eos_token_id))
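The collator masks every non-assistant span of the labels with -100 so the cross-entropy loss only covers the assistant response; the tweak above fires the "no assistant response" warning when the mask covers all but the final token (typically the trailing EOS) rather than only when it covers the entire label. A self-contained sketch of the masking idea, with made-up token ids and a plain (start, end) list standing in for _find_mask_ranges:

```python
# Minimal sketch of prompt-token masking; token ids are made up for illustration.
IGNORE_INDEX = -100  # ignored by torch.nn.CrossEntropyLoss by default

def mask_prompt_tokens(label, mask_ranges):
    """Set every (start, end) span that is NOT assistant output to -100."""
    label = list(label)
    for start, end in mask_ranges:
        if end - start >= len(label) - 1:
            print("warning! example had no assistant response in it!")
        label[start:end] = [IGNORE_INDEX] * (end - start)
    return label

# fake token ids: 6 prompt tokens, 3 assistant tokens, then EOS
example = [101, 7, 8, 9, 10, 11, 55, 56, 57, 2]
print(mask_prompt_tokens(example, [(0, 6)]))
# -> [-100, -100, -100, -100, -100, -100, 55, 56, 57, 2]
```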
@@ -527,7 +586,8 @@ class DataCollatorForSupervisedFineTuning(object):
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id or self.tokenizer.eos_token_id),
        )

print("Loading dataset...")
if IS_MASTER_PROCESS:
    print("Loading dataset...")
data_files = { "train": training_run_args.train_dataset }
if training_run_args.test_dataset:
    data_files["test"] = training_run_args.test_dataset

@@ -648,7 +708,8 @@ class CustomSFTTrainer(Trainer):
        return result

if not training_run_args.dpo:
    print("Tokenizing datasets...")
    if IS_MASTER_PROCESS:
        print("Tokenizing datasets...")

    if "text" in datasets["train"].column_names:
        tokenize_function = tokenize_raw_example

@@ -660,20 +721,36 @@ if not training_run_args.dpo:
        raise Exception("Unknown dataset input format (not raw corpus or sharegpt)")

    tokenized_test_dataset = None
    tokenized_train_dataset = datasets["train"].map(tokenize_function, batched=True, num_proc=os.cpu_count()).remove_columns(columns_to_remove)
    num_proc = os.cpu_count() // MULTI_GPU_WORLD_SIZE
    tokenized_train_dataset = datasets["train"].map(tokenize_function, batched=True, num_proc=num_proc).remove_columns(columns_to_remove)
    if training_run_args.test_dataset:
        tokenized_test_dataset = datasets["test"].map(tokenize_function, batched=True, num_proc=os.cpu_count()).remove_columns(columns_to_remove)
        tokenized_test_dataset = datasets["test"].map(tokenize_function, batched=True, num_proc=num_proc).remove_columns(columns_to_remove)

    example_lengths = [ len(example) for example in tokenized_train_dataset["input_ids"] ]
    tokens_in_train_set, longest_example = sum(example_lengths), max(example_lengths)
    print(f"Train dataset has {int(tokens_in_train_set / 1000000)}M tokens. Longest Example: {longest_example} tokens")
    if IS_MASTER_PROCESS:
        print(f"Train dataset has {int(tokens_in_train_set / 1000000)}M tokens. Longest Example: {longest_example} tokens")


    prefix_ids = None
    suffix_ids = None

    # data_collator = DataCollatorForSupervisedFineTuning(tokenizer=tokenizer)
    # fix for tinyllama not detecting split properly
    # prefix_ids = [29966, 29989, 465, 22137, 29989, 29958, 13]
    # suffix_ids = [2]

    # fix for qwen2 not detecting split properly
    prefix_ids = [151644, 77091, 198]
    suffix_ids = [151645, 198]

    # fix for polka-1.1 not detecting split properly
    # prefix_ids = [43883, 20255, 13]
    # suffix_ids = [43882, 29871, 13]

    data_collator = DataCollatorForSupervisedFineTuning(
        tokenizer=tokenizer,
        prefix_ids=[29966, 29989, 465, 22137, 29989, 29958, 13],
        suffix_ids=[2],
        prefix_ids=prefix_ids,
        suffix_ids=suffix_ids,
    )

    trainer = CustomSFTTrainer(
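The hard-coded prefix_ids and suffix_ids work around the collator not detecting the assistant span automatically: for Qwen2, [151644, 77091, 198] and [151645, 198] bracket the assistant turn in ChatML, i.e. "<|im_start|>assistant\n" and "<|im_end|>\n". A quick way to derive or sanity-check those ids for whichever base model is being fine-tuned, sketched under the assumption that the tokenizer encodes the delimiter strings as stable standalone sequences:

```python
# Sketch: derive the assistant-turn delimiter ids instead of hard-coding them.
# The exact strings are template-dependent; ChatML-style models (Qwen2, etc.)
# wrap the assistant reply in <|im_start|>assistant\n ... <|im_end|>\n.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

prefix_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
suffix_ids = tokenizer.encode("<|im_end|>\n", add_special_tokens=False)
print(prefix_ids, suffix_ids)  # expected for Qwen2: [151644, 77091, 198] [151645, 198]
```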