training fixes, default values + other fixes

Author: Alex O'Connell
Date: 2024-04-21 23:40:28 -04:00
parent 4ed5b08323
commit adae87addd
8 changed files with 40 additions and 11 deletions

View File

@@ -193,6 +193,13 @@ DEFAULT_OPTIONS = types.MappingProxyType(
 )
 OPTIONS_OVERRIDES = {
+    "home-3b-v4": {
+        CONF_PROMPT: DEFAULT_PROMPT_BASE,
+        CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR,
+        CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
+        CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX,
+        CONF_USE_GBNF_GRAMMAR: True,
+    },
     "home-3b-v3": {
         CONF_PROMPT: DEFAULT_PROMPT_BASE,
         CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR,
@@ -207,6 +214,12 @@ OPTIONS_OVERRIDES = {
         CONF_USE_GBNF_GRAMMAR: True,
     },
+    "home-3b-v1": {
+        CONF_PROMPT: DEFAULT_PROMPT_BASE,
+        CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR,
+        CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
+        CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX,
+    },
     "home-1b-v3": {
         CONF_PROMPT: DEFAULT_PROMPT_BASE,
         CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
         CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX,
@@ -224,10 +237,14 @@ OPTIONS_OVERRIDES = {
"mistral": {
CONF_PROMPT: DEFAULT_PROMPT_BASE + ICL_NO_SYSTEM_PROMPT_EXTRAS,
CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_MISTRAL,
CONF_MIN_P: 0.1,
CONF_TYPICAL_P: 0.9,
},
"mixtral": {
CONF_PROMPT: DEFAULT_PROMPT_BASE + ICL_NO_SYSTEM_PROMPT_EXTRAS,
CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_MISTRAL,
CONF_MIN_P: 0.1,
CONF_TYPICAL_P: 0.9,
},
"llama-3": {
CONF_PROMPT: DEFAULT_PROMPT_BASE + ICL_EXTRAS,
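
Note: these per-model blocks layer on top of the global defaults. A minimal sketch of how such overrides are typically resolved (the helper name is hypothetical; DEFAULT_OPTIONS and OPTIONS_OVERRIDES are the mappings shown above):

def options_for_model(model_name: str) -> dict:
    # Start from the global defaults, then let the model-specific block
    # win for any key it names (e.g. CONF_MIN_P for mistral/mixtral).
    options = dict(DEFAULT_OPTIONS)
    options.update(OPTIONS_OVERRIDES.get(model_name, {}))
    return options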

View File

@@ -92,7 +92,7 @@ def install_llama_cpp_python(config_dir: str):
     instruction_extensions_suffix = ""
     if platform_suffix == "amd64" or platform_suffix == "i386":
         with open("/proc/cpuinfo") as f:
-            cpu_features = [ line for line in f.readlines() if line.startswith("Features")][0]
+            cpu_features = [ line for line in f.readlines() if line.startswith("Features") or line.startswith("flags")][0]
         if "avx512f" in cpu_features and "avx512bw" in cpu_features:
             instruction_extensions_suffix = "-avx512"
         elif "avx" not in cpu_features:

View File

@@ -1111,6 +1111,11 @@ def main():
     if not args.sample and not args.train and not args.test and not args.merge and not args.dpo:
         parser.print_usage()
         exit(-1)
+    if args.size and not args.train:
+        print("Train size was provided but not generating the training set!")
+        exit(-1)
     if not args.format or args.format == "raw":
         format_func = format_example_raw_chatml
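
Note: the new guard rejects a flag combination that would otherwise be silently ignored, since --size only matters when generating the training set. The same check could lean on argparse's own error helper (a sketch; flag names taken from the diff):

# parser.error() prints the usage text and exits non-zero, matching the
# behavior of the print_usage()/exit(-1) guard just above it.
if args.size and not args.train:
    parser.error("--size requires --train; no training set is being generated")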

View File

@@ -205,7 +205,6 @@ media_player,off,"Can you check the volume level of <device_name>?","<device_nam
 media_player,idle,"I need the current volume of <device_name>, please.","The current volume of <device_name> is <volume>."
 media_player,playing,"Is the volume too loud on <device_name>?","The volume on <device_name> is currently set to <volume>."
 media_player,paused,"What was the volume set to before <device_name> was paused?","Before being paused, <device_name> had a volume level of <volume>."
-media_player,playing,"Could you increase <device_name>'s volume a bit?","I've increased the volume on <device_name>. It's now at <volume>."
 media_player,playing,"Is the volume muted on <device_name>?","The volume on <device_name> is not muted. It's at <volume>."
 media_player,on,"What's playing on <device_name> right now?","<media> is currently playing on <device_name>."
 media_player,off,"Is there anything playing on <device_name>?","Nothing is playing on <device_name> as it is currently off."

View File

@@ -36,4 +36,4 @@
 - 800:
 - 900:
 - 1000:
-- Final:
+- Final: 0.9817813765182186

View File

@@ -79,6 +79,7 @@ def evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_siz
         found_responses = service_call_regex.findall(response.strip())
         if len(expected_service_calls) == 0:
             total_answers = total_answers + 1
+            if len(found_responses) == 0:
                 correct_answers = correct_answers + 1
             continue
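
Note: this fixes the scoring of examples that expect no service calls. Previously such an example was counted correct unconditionally; now it is correct only if the model also emitted no calls. Restated as a standalone sketch (variable names from the diff):

if len(expected_service_calls) == 0:
    total_answers += 1              # example expects zero service calls
    if len(found_responses) == 0:   # model correctly produced none
        correct_answers += 1
    continue                        # nothing further to compare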

View File

@@ -10,7 +10,8 @@ if [[ ! -d "./models/$MODEL_NAME" ]]; then
 fi
 echo "Converting to GGUF..."
-$LLAMA_CPP/convert-hf-to-gguf.py --outfile ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf --outtype f16 ./models/$MODEL_NAME/
+$LLAMA_CPP/convert.py --outfile ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf --outtype f16 ./models/$MODEL_NAME/
+# $LLAMA_CPP/convert-hf-to-gguf.py --outfile ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf --outtype f16 ./models/$MODEL_NAME/
 DESIRED_QUANTS=("Q8_0" "Q5_K_M" "Q4_K_M" "Q3_K_M" "Q2_K")
 for QUANT in "${DESIRED_QUANTS[@]}"

View File

@@ -133,13 +133,13 @@ python3 train.py \
"""
python3 train.py \
--run_name tinyhome-rev2 \
--run_name tinyhome-rev3 \
--base_model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--bf16 \
--train_dataset data/home_assistant_train.jsonl \
--test_dataset data/home_assistant_test.jsonl \
--learning_rate 2e-5 --batch_size 32 \
--micro_batch_size 4 --gradient_checkpointing --group_by_length \
--micro_batch_size 8 --gradient_checkpointing --group_by_length \
--ctx_size 2048 --save_steps 100 --save_total_limit 10
"""
@@ -166,7 +166,6 @@ class TrainingRunArguments:
     bf16: bool = field(default=False, metadata={"help": "If set, the model will be loaded and trained in bf16 instead of fp16"})
     batch_size: int = field(default=8, metadata={"help": "The simulated 'batch size' that we will train on; will tweak gradient accumulation steps"})
     micro_batch_size: int = field(default=2, metadata={"help": "The actual batch size that will fit into VRAM on this machine"})
-    eval_batch_size: int = field(default=1, metadata={"help": "The batch size for generation used while performing evaluation"})
     epochs: int = field(default=1, metadata={"help": "The number of times to train the model on each example"})
     learning_rate: float = field(default=1e-5, metadata={"help": "The starting learning rate (speed at which the model trains)"})
     learning_rate_schedule: str = field(default="cosine", metadata={"help": "How fast the learning rate is reduced during training"})
@@ -391,7 +390,7 @@ training_kwargs = {}
 if training_run_args.test_dataset:
     training_kwargs.update({
-        "per_device_eval_batch_size": training_run_args.eval_batch_size,
+        "per_device_eval_batch_size": training_run_args.batch_size,
         "evaluation_strategy": ("steps" if training_run_args.eval_steps != -1 else "epoch"),
         "eval_steps": (training_run_args.eval_steps if training_run_args.eval_steps != -1 else None),
         "bf16_full_eval": training_run_args.bf16,
@@ -431,15 +430,22 @@ class DataCollatorForSupervisedFineTuning(object):
     prefix_ids: list[int]
     suffix_ids: list[int]
-    def __init__(self, *, tokenizer: AutoTokenizer):
+    def __init__(self, *, tokenizer: AutoTokenizer, prefix_ids = None, suffix_ids = None):
         self.tokenizer = tokenizer
         assistant_prompt = tokenizer.apply_chat_template(conversation=[{"role": "assistant", "content": r"%%%%%%%%%%%%%%%%"}], tokenize=False).split( r"%%%%%%%%%%%%%%%%")
         self.response_prefix = assistant_prompt[0]
         self.response_suffix = assistant_prompt[1]
-        self.prefix_ids = self.tokenizer(self.response_prefix, add_special_tokens=False)["input_ids"]
-        self.suffix_ids = self.tokenizer(self.response_suffix, add_special_tokens=False)["input_ids"]
+        if prefix_ids:
+            self.prefix_ids = prefix_ids
+        else:
+            self.prefix_ids = self.tokenizer(self.response_prefix, add_special_tokens=False)["input_ids"]
+        if suffix_ids:
+            self.suffix_ids = suffix_ids
+        else:
+            self.suffix_ids = self.tokenizer(self.response_suffix, add_special_tokens=False)["input_ids"]
     def _find_mask_ranges(self, input_ids):
         """