Merge branch 'feature/proper-functioncalling-args' into develop

Author: Alex O'Connell
Date: 2024-01-27 15:22:07 -05:00
11 changed files with 550 additions and 183 deletions

.gitignore

@@ -7,3 +7,5 @@ data/*.json
*.pyc
main.log
.venv
*.xlsx
notes.txt


@@ -6,12 +6,13 @@ The "Home" models are a fine tuning of the Phi model series from Microsoft. The
The latest models can be found on HuggingFace:
3B v2 (Based on Phi-2): https://huggingface.co/acon96/Home-3B-v2-GGUF
1B v1 (Based on Phi-1.5): https://huggingface.co/acon96/Home-1B-v1-GGUF
1B v2 (Based on Phi-1.5): https://huggingface.co/acon96/Home-1B-v2-GGUF
Make sure you have `llama-cpp-python>=0.2.29` in order to run these models.
Old Models:
3B v1 (Based on Phi-2): https://huggingface.co/acon96/Home-3B-v1-GGUF
1B v1 (Based on Phi-1.5): https://huggingface.co/acon96/Home-1B-v1-GGUF
The main difference between the 2 models (besides parameter count) is the training data. The 1B model is ONLY trained on the synthetic dataset provided in this project, while the 3B model is trained on a mixture of this synthetic dataset, and the cleaned Stanford Alpaca dataset.
@@ -23,11 +24,12 @@ Example "system" prompt:
```
<|im_start|>system
You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.
Services: light.turn_off, light.turn_on, fan.turn_on, fan.turn_off
Services: light.turn_off(), light.turn_on(brightness,rgb_color), fan.turn_on(), fan.turn_off()
Devices:
light.office 'Office Light' = on
light.office 'Office Light' = on;80%
fan.office 'Office fan' = off
light.kitchen 'Kitchen Light' = on<|im_end|>
light.kitchen 'Kitchen Light' = on;80%;red
light.bedroom 'Bedroom Light' = off<|im_end|>
```
For more about how the model is prompted, see [docs/Model Prompting.md](./docs/Model%20Prompting.md)
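For reference, a hedged sketch of the output format the integration parses back out of the model: a fenced `homeassistant` code block containing one JSON service call per line, using the `service` and `target_device` keys plus any extra arguments such as `brightness` (the device and values below are illustrative):

```homeassistant
{"service": "light.turn_on", "target_device": "light.office", "brightness": 0.8}
```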
@@ -69,11 +71,11 @@ python3 train.py \
--add_chatml_tokens \
--bf16 \
--train_dataset data/home_assistant_alpaca_merged_train.json \
--test_dataset data/home_assistant_alpaca_merged_test.json \
--learning_rate 1e-5 \
--save_steps 1000 \
--micro_batch_size 2 --gradient_checkpointing \
--ctx_size 2048 \
--group_by_length \
--use_lora --lora_rank 32 --lora_alpha 64 --lora_modules fc1,fc2,q_proj,v_proj,dense --lora_modules_to_save embed_tokens,lm_head --lora_merge
```
@@ -87,7 +89,6 @@ python3 train.py \
--add_chatml_tokens \
--bf16 \
--train_dataset data/home_assistant_train.json \
--test_dataset data/home_assistant_test.json \
--learning_rate 1e-5 \
--micro_batch_size 4 --gradient_checkpointing \
--ctx_size 2048


@@ -11,6 +11,9 @@ import os
import json
import webcolors
import voluptuous as vol
from collections.abc import Iterable
import homeassistant.components.conversation as ha_conversation
from homeassistant.components.conversation import ConversationInput, ConversationResult, AbstractConversationAgent
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
@@ -169,6 +172,24 @@ def closest_color(requested_color):
min_colors[(rd + gd + bd)] = name
return min_colors[min(min_colors.keys())]
def flatten_schema(schema):
flattened = []
def _flatten(current_schema, prefix=''):
if isinstance(current_schema, vol.Schema):
if isinstance(current_schema.schema, vol.validators._WithSubValidators):
for subval in current_schema.schema.validators:
_flatten(subval, prefix)
else:
for key, val in current_schema.schema.items():
_flatten(val, prefix + str(key) + '/')
elif isinstance(current_schema, vol.validators._WithSubValidators):
for subval in current_schema.validators:
_flatten(subval, prefix)
elif callable(current_schema):
flattened.append(prefix[:-1] if prefix else prefix)
_flatten(schema)
return flattened
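A minimal usage sketch for `flatten_schema`, with a hypothetical voluptuous schema standing in for a real Home Assistant service schema (nested keys would be joined with `/` by the prefix logic above):

```
import voluptuous as vol

# hypothetical schema; real schemas come from hass.services.async_services()
schema = vol.Schema({
    vol.Optional("brightness"): vol.Coerce(int),
    vol.Optional("rgb_color"): list,
})

print(flatten_schema(schema))  # -> ['brightness', 'rgb_color']
```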
class LLaMAAgent(AbstractConversationAgent):
"""Local LLaMA conversation agent."""
@@ -216,6 +237,8 @@ class LLaMAAgent(AbstractConversationAgent):
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
refresh_system_prompt = self.entry.options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)
service_call_regex = self.entry.options.get(CONF_SERVICE_CALL_REGEX, DEFAULT_SERVICE_CALL_REGEX)
extra_attributes_to_expose = self.entry.options \
.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE)
try:
service_call_pattern = re.compile(service_call_regex)
@@ -298,24 +321,37 @@ class LLaMAAgent(AbstractConversationAgent):
service = json_output["service"]
entity = json_output["target_device"]
domain, service = tuple(service.split("."))
extra_arguments = { k: v for k, v in json_output.items() if k not in [ "service", "target_device" ] }
except Exception:
try:
service = line.split("(")[0]
entity = line.split("(")[1][:-1]
domain, service = tuple(service.split("."))
extra_arguments = {}
except Exception:
to_say += f" Failed to parse call from '{line}'!"
continue
# fix certain arguments
# make sure brightness is 0-255 and not a percentage
if "brightness" in extra_arguments and 0.0 < extra_arguments["brightness"] < 1.0:
extra_arguments["brightness"] = int(extra_arguments["brightness"] * 255)
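A quick worked example of that brightness fix-up, assuming the model emitted a fractional value (Home Assistant's `light.turn_on` expects brightness in the 0-255 range):

```
# model output: 80% brightness encoded as a 0-1 fraction
extra_arguments = {"brightness": 0.8}
if 0.0 < extra_arguments["brightness"] < 1.0:
    extra_arguments["brightness"] = int(extra_arguments["brightness"] * 255)  # -> 204
```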
# only acknowledge requests to exposed entities
if entity not in exposed_entities:
to_say += f" Can't find device '{entity}'!"
else:
# copy arguments to service call
service_data = {ATTR_ENTITY_ID: entity}
for attr in extra_attributes_to_expose:
if attr in extra_arguments.keys():
service_data[attr] = extra_arguments[attr]
try:
await self.hass.services.async_call(
domain,
service,
service_data={ATTR_ENTITY_ID: entity},
service_data=service_data,
blocking=True,
)
except Exception as err:
@@ -386,7 +422,7 @@ class LLaMAAgent(AbstractConversationAgent):
value = attributes[attribute_name]
if value is not None:
if attribute_name == "current_temperature":
if attribute_name == "temperature":
value = int(value)
if value > 50:
value = f"{value}F"
@@ -396,6 +432,10 @@ class LLaMAAgent(AbstractConversationAgent):
value = F"{closest_color(value)} {value}"
elif attribute_name == "volume_level":
value = f"vol={int(value*100)}"
elif attribute_name == "brightness":
value = f"{int(value/255*100)}%"
elif attribute_name == "humidity":
value = f"{value}%"
result = result + ";" + str(value)
return result
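With the default `extra_attributes_to_expose` list, an exposed light therefore renders into a prompt line along these lines (values illustrative; attribute order follows the configured list):

```
light.kitchen 'Kitchen Light' = on;red (255, 0, 0);80%
```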
@@ -407,10 +447,10 @@ class LLaMAAgent(AbstractConversationAgent):
service_dict = self.hass.services.async_services()
all_services = []
for domain in domains:
# all_services.extend(service_dict.get(domain, {}).keys())
all_services.extend(
[f"{domain}.{name}" for name in service_dict.get(domain, {}).keys()]
)
for name, service in service_dict.get(domain, {}).items():
args = flatten_schema(service.schema)
args_to_expose = set(args).intersection(extra_attributes_to_expose)
all_services.append(f"{domain}.{name}({','.join(args_to_expose)})")
formatted_services = ", ".join(all_services)
return template.Template(prompt_template, self.hass).async_render(


@@ -136,7 +136,7 @@ def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_q
}
)
def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, chat_model=None, use_chat_endpoint=None, webui_preset=None, webui_chat_mode=None):
def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, chat_model=None, use_chat_endpoint=None, webui_preset="", webui_chat_mode=""):
extra1, extra2 = ({}, {})
default_port = DEFAULT_PORT
@@ -144,7 +144,7 @@ def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ch
if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
extra1[vol.Optional(CONF_TEXT_GEN_WEBUI_PRESET, default=webui_preset)] = str
extra1[vol.Optional(CONF_TEXT_GEN_WEBUI_CHAT_MODE, default=webui_chat_mode)] = SelectSelector(SelectSelectorConfig(
options=[TEXT_GEN_WEBUI_CHAT_MODE_CHAT, TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT],
options=["", TEXT_GEN_WEBUI_CHAT_MODE_CHAT, TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT],
translation_key=CONF_TEXT_GEN_WEBUI_CHAT_MODE,
multiple=False,
mode=SelectSelectorMode.DROPDOWN,
@@ -485,7 +485,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
) -> FlowResult:
errors = {}
backend_type = self.model_config[CONF_BACKEND_TYPE]
schema = STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI)
schema = STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type)
if user_input:
try:
@@ -499,7 +499,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
if error_reason:
errors["base"] = error_reason
schema = STEP_REMOTE_SETUP_DATA_SCHEMA(
True,
backend_type,
host=user_input[CONF_HOST],
port=user_input[CONF_PORT],
chat_model=user_input[CONF_CHAT_MODEL],


@@ -28,14 +28,14 @@ BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER = "llama_cpp_python_server"
BACKEND_TYPE_OLLAMA = "ollama"
DEFAULT_BACKEND_TYPE = BACKEND_TYPE_LLAMA_HF
CONF_DOWNLOADED_MODEL_QUANTIZATION = "downloaded_model_quantization"
CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS = ["Q8_0", "Q5_K_M", "Q4_K_M", "Q3_K_M"]
CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS = ["F16", "Q8_0", "Q5_K_M", "Q4_K_M", "Q3_K_M"]
DEFAULT_DOWNLOADED_MODEL_QUANTIZATION = "Q5_K_M"
CONF_DOWNLOADED_MODEL_FILE = "downloaded_model_file"
DEFAULT_DOWNLOADED_MODEL_FILE = ""
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = "5000"
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE = "extra_attributes_to_expose"
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE = ["rgb_color", "current_temperature", "fan_mode", "media_title", "volume_level" ]
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level"]
GBNF_GRAMMAR_FILE = "output.gbnf"
CONF_PROMPT_TEMPLATE = "prompt_template"
PROMPT_TEMPLATE_CHATML = "chatml"


@@ -45,9 +45,16 @@ def closest_color(requested_color):
class DeviceType:
name: str
possible_states: list[tuple[str, float]]
services: list[str]
services: dict[str, list]
def get_random_state(self, **kwargs):
def get_all_services(self, extra_exposed_attributes):
result = []
for service in self.services.keys():
args = set(extra_exposed_attributes).intersection(self.services[service])
result.append(f"{self.name}.{service}({','.join(args)})")
return result
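A sketch of what `get_all_services` now produces for the light device type defined below (arguments come from a set intersection, so their order within the parentheses may vary):

```
light = LightDeviceType()
print(light.get_all_services(["rgb_color", "brightness", "media_title"]))
# -> ['light.turn_on(rgb_color,brightness)', 'light.turn_off()', 'light.toggle()']
```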
def get_random_state(self, extra_exposed_attributes=[]):
states = [ x[0] for x in self.possible_states ]
weights = [ x[1] for x in self.possible_states ]
return random.choices(states, weights=weights, k=1)[0]
@@ -59,48 +66,62 @@ class LightDeviceType(DeviceType):
(STATE_ON, 0.5),
(STATE_OFF, 0.5)
],
services=[
"turn_on",
"turn_off",
"toggle"
],
services={
"turn_on": [ "rgb_color", "brightness" ],
"turn_off": [],
"toggle": []
},
)
def get_random_state(self, extra_exposed_attributes=[]):
state = super().get_random_state(extra_exposed_attributes=extra_exposed_attributes)
def get_random_state(self, force_rgb=False, force_brightness=False, **kwargs):
state = super().get_random_state()
if random.random() < 0.05 or force_rgb:
if random.random() < 0.5 and "rgb_color" in extra_exposed_attributes:
random_rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
state = state + ";" + closest_color(random_rgb) + " " + str(random_rgb)
if random.random() < 0.3 or force_brightness:
if random.random() < 0.7 and "brightness" in extra_exposed_attributes:
state = state + ";" + str(random.randint(0, 100)) + "%"
return state
class ClimateDeviceType(DeviceType):
def __init__(self):
super().__init__("climate", [], [
"turn_on",
"turn_off",
"toggle",
"set_temperature",
"set_humidity",
"set_fan_mode",
"set_hvac_mode",
"set_preset_mode"
])
super().__init__("climate", [], {
"turn_on": [],
"turn_off": [],
"toggle": [],
"set_temperature": ["temperature"],
"set_humidity": ["humidity"],
"set_fan_mode": ["fan_mode"],
"set_hvac_mode": ["hvac_mode"],
"set_preset_mode": ["preset_mode"]
})
def get_random_state(self, **kwargs):
hvac = random.choice(["heat", "cool", "heat_cool", "off", "auto", "fan_only"])
fan = random.choice(["On Low", "On High", "Auto Low", "Auto High", "Off"])
if random.random() > 0.5:
temp = str(random.randint(60, 80)) + "F"
else:
temp = str(random.randint(15, 25)) + "C"
return f"{hvac};{fan};{temp}"
def get_random_state(self, extra_exposed_attributes=[]):
"""state;fan_mode;temperature;humidity"""
state = random.choice(["heat", "cool", "heat_cool", "off", "auto", "fan_only"])
if "fan_mode" in extra_exposed_attributes:
state = state + ";" + random.choice(["On Low", "On High", "Auto Low", "Auto High", "Off"])
if "temperature" in extra_exposed_attributes:
if random.random() > 0.5:
state = state + ";" + str(random.randint(60, 80)) + "F"
else:
state = state + ";" + str(random.randint(15, 25)) + "C"
if "humidity" in extra_exposed_attributes:
state = state + ";" + str(random.randint(10, 90)) + "%"
if "preset_mode" in extra_exposed_attributes:
# if it is not "on a preset" then don't add the mode
random_mode = random.choice(["home", "eco", "away", "auto", None, None, None])
if random_mode:
state = state + ";" + random_mode
return state
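Following the `state;fan_mode;temperature;humidity` layout in the docstring, a fully exposed climate device can yield a state string like this (illustrative values):

```
heat;Auto Low;72F;45%;eco
```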
with open("piles/pile_of_media_names.csv") as f:
pile_of_media_names = [ x.strip() for x in f.readlines() ]
class MediaPlayerDeviceType(DeviceType):
def __init__(self):
@@ -112,31 +133,29 @@ class MediaPlayerDeviceType(DeviceType):
(STATE_PAUSED, 0.05),
(STATE_STANDBY, 0.05),
(STATE_BUFFERING, 0.01),
], [
"turn_on",
"turn_off",
"toggle",
"volume_up",
"volume_down",
"volume_mute",
"media_play_pause",
"media_play",
"media_pause",
"media_stop",
"media_next_track",
"media_previous_track"
])
], {
"turn_on": [],
"turn_off": [],
"toggle": [],
"volume_up": [],
"volume_down": [],
"volume_mute": [],
"media_play_pause": [],
"media_play": [],
"media_pause": [],
"media_stop": [],
"media_next_track": [],
"media_previous_track": []
})
with open("piles/pile_of_media_names.csv") as f:
self.media_names = [ x.strip() for x in f.readlines() ]
def get_random_state(self, extra_exposed_attributes=[]):
state = super().get_random_state(extra_exposed_attributes=extra_exposed_attributes)
def get_random_state(self, **kwargs):
state = super().get_random_state()
if "media_title" in extra_exposed_attributes and state in [STATE_PLAYING, STATE_PAUSED, STATE_BUFFERING, STATE_ON]:
state = state + ";" + random.choice(pile_of_media_names)
if state in [STATE_PLAYING, STATE_PAUSED, STATE_BUFFERING, STATE_ON]:
state = state + ";" + random.choice(self.media_names)
if state != STATE_OFF:
if "volume_level" in extra_exposed_attributes and state != STATE_OFF:
state = state + ";vol=" + str(round(random.random(), 2))
return state
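So a playing media player with `media_title` and `volume_level` exposed produces something like (illustrative values):

```
playing;Master of Puppets;vol=0.75
```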
@@ -148,11 +167,11 @@ SUPPORTED_DEVICES = {
(STATE_ON, 0.5),
(STATE_OFF, 0.5)
],
services=[
"turn_on",
"turn_off",
"toggle"
],
services={
"turn_on": [],
"turn_off": [],
"toggle": []
},
),
"fan": DeviceType(
name="fan",
@@ -160,13 +179,13 @@ SUPPORTED_DEVICES = {
(STATE_ON, 0.5),
(STATE_OFF, 0.5)
],
services=[
"turn_on",
"turn_off",
"toggle",
"increase_speed",
"decrease_speed",
],
services={
"turn_on": [],
"turn_off": [],
"toggle": [],
"increase_speed": [],
"decrease_speed": [],
},
),
"garage_door": DeviceType(
name="garage_door",
@@ -176,12 +195,12 @@ SUPPORTED_DEVICES = {
(STATE_OPENING, 0.01),
(STATE_CLOSING, 0.01)
],
services=[
"open_cover",
"close_cover",
"stop_cover",
"toggle",
],
services={
"open_cover": [],
"close_cover": [],
"stop_cover": [],
"toggle": [],
},
),
"blinds": DeviceType(
name="blinds",
@@ -191,12 +210,12 @@ SUPPORTED_DEVICES = {
(STATE_OPENING, 0.01),
(STATE_CLOSING, 0.01)
],
services=[
"open_cover",
"close_cover",
"stop_cover",
"toggle",
],
services={
"open_cover": [],
"close_cover": [],
"stop_cover": [],
"toggle": [],
},
),
"lock": DeviceType(
name="lock",
@@ -204,10 +223,10 @@ SUPPORTED_DEVICES = {
(STATE_LOCKED, 0.5),
(STATE_UNLOCKED, 0.5),
],
services=[
"lock",
"unlock",
],
services={
"lock": [],
"unlock": [],
},
),
"media_player": MediaPlayerDeviceType(),
"climate": ClimateDeviceType()
@@ -284,6 +303,9 @@ def random_device_list(max_devices: int, avoid_device_names: list[str]):
device_types = set()
device_list = []
device_lines = []
# TODO: randomly pick attributes for this list
extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level"]
while len(device_list) < num_devices:
choice = random.choice(possible_choices)
if choice["device_name"] in device_list:
@@ -298,7 +320,7 @@ def random_device_list(max_devices: int, avoid_device_names: list[str]):
if avoid_climate and device_type == "climate":
continue
state = SUPPORTED_DEVICES[device_type].get_random_state()
state = SUPPORTED_DEVICES[device_type].get_random_state(extra_exposed_attributes=extra_exposed_attributes)
device_lines.append(format_device_line(
device_name=device_name,
friendly_name=friendly_name,
@@ -310,7 +332,7 @@ def random_device_list(max_devices: int, avoid_device_names: list[str]):
print(f"bad device name: {choice}")
print(repr(ex))
return device_lines, list(device_types)
return device_lines, list(device_types), list(extra_exposed_attributes)
def generate_static_example(action: dict, max_devices: int = 32):
question = action["english_phrase"]
@@ -319,11 +341,12 @@ def generate_static_example(action: dict, max_devices: int = 32):
service_name = f"{device_type}.{action['service_name']}"
friendly_name = target_device.split(".")[1].replace("_", " ")
device_list, device_types = random_device_list(max_devices=max_devices, avoid_device_names=[target_device])
device_list, device_types, extra_exposed_attributes = random_device_list(
max_devices=max_devices, avoid_device_names=[target_device])
# insert our target device somewhere random in the list
index = random.randint(0, len(device_list))
state = SUPPORTED_DEVICES[device_type].get_random_state()
state = SUPPORTED_DEVICES[device_type].get_random_state(extra_exposed_attributes=extra_exposed_attributes)
device_list.insert(index, format_device_line(
device_name=target_device,
@@ -334,7 +357,7 @@ def generate_static_example(action: dict, max_devices: int = 32):
# gather a list of all available services
available_services = []
for x in set(device_types + [device_type]):
available_services.extend([ f"{x}.{y}" for y in SUPPORTED_DEVICES[x].services ])
available_services.extend(SUPPORTED_DEVICES[x].get_all_services(extra_exposed_attributes))
return {
"states": device_list,
@@ -357,17 +380,23 @@ def generate_templated_example(template: dict, max_devices: int = 32):
device_dict["type"] = device_type
chosen_devices.append(device_dict)
device_list, device_types = random_device_list(max_devices=max_devices, avoid_device_names=[d["device_name"] for d in chosen_devices])
device_list, device_types, extra_exposed_attributes = random_device_list(
max_devices=max_devices, avoid_device_names=[d["device_name"] for d in chosen_devices])
# insert our target device somewhere random in the list
for device_dict in chosen_devices:
index = random.randint(0, len(device_list))
state_kwargs = {}
if "<brightness>" in question_template:
state_kwargs["force_brightness"] = True
if "<color>" in question_template:
state_kwargs["force_rgb"] = True
state = SUPPORTED_DEVICES[device_dict["type"]].get_random_state(**state_kwargs)
if "<brightness>" in question_template and "brightness" not in extra_exposed_attributes:
extra_exposed_attributes.append("brightness")
if "<color>" in question_template and "rgb_color" not in extra_exposed_attributes:
extra_exposed_attributes.append("rgb_color")
if ("<temp_f>" in question_template or "<temp_c>" in question_template) \
and "temperature" not in extra_exposed_attributes:
extra_exposed_attributes.append("temperature")
if "<humidity>" in question_template and "humidity" not in extra_exposed_attributes:
extra_exposed_attributes.append("humidity")
state = SUPPORTED_DEVICES[device_dict["type"]].get_random_state(extra_exposed_attributes=extra_exposed_attributes)
device_name = device_dict["device_name"]
friendly_name = device_dict["description"]
@@ -377,10 +406,10 @@ def generate_templated_example(template: dict, max_devices: int = 32):
state=state
))
# gather a list of all available services
# gather a list of all available services with arguments
available_services = []
for x in set(device_types + template_device_types):
available_services.extend([ f"{x}.{y}" for y in SUPPORTED_DEVICES[x].services ])
available_services.extend(SUPPORTED_DEVICES[x].get_all_services(extra_exposed_attributes))
# generate the question
if len(template_device_types) == 1:
@@ -422,7 +451,7 @@ def generate_templated_example(template: dict, max_devices: int = 32):
brightness = random.randint(0, 100)
question = question.replace("<brightness>", str(brightness))
answer = answer.replace("<brightness>", str(brightness))
service_calls = [ { **call, "brightness_pct": brightness} for call in service_calls ]
service_calls = [ { **call, "brightness": round(brightness / 100, 2) } for call in service_calls ]
if "<color>" in question:
random_rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
@@ -452,7 +481,7 @@ def generate_status_request(template: dict, max_devices: int = 32):
chosen_device = random.choice(stacks_of_device_names[device_type])
# build a random list of devices
device_list, device_types = random_device_list(max_devices=max_devices, avoid_device_names=[ chosen_device["device_name"] ])
device_list, device_types, extra_exposed_attributes = random_device_list(max_devices=max_devices, avoid_device_names=[ chosen_device["device_name"] ])
# insert our target device somewhere random in the list
index = random.randint(0, len(device_list))
@@ -489,16 +518,20 @@ def generate_status_request(template: dict, max_devices: int = 32):
if device_type == "media_player":
volume = random.randint(0, 100)
random_media = random.choice(pile_of_media_names)
answer = answer.replace("<volume>", str(volume) + "%")
state_name = state_name.replace("<volume>", str(volume) + "%")
answer = answer.replace("<media>", random_media)
state_name = state_name.replace("<media>", random_media)
device_list.insert(index, f"{chosen_device['device_name']} = {state_name}")
# gather a list of all available services
available_services = []
for x in set(device_types + [device_type]):
available_services.extend([ f"{x}.{y}" for y in SUPPORTED_DEVICES[x].services ])
available_services.extend(SUPPORTED_DEVICES[x].get_all_services(extra_exposed_attributes))
return {
"states": device_list,
@@ -512,8 +545,6 @@ def format_example(example):
sys_prompt = "You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed or answer the following question with the information provided only."
services_block = "Services: " + ", ".join(sorted(example["available_services"]))
states_block = "Devices:\n" + "\n".join(example["states"])
# question = "Request:\n" + example["question"]
# answers = "Response:\n" + " ".join(example["answers"])
question = example["question"]
answers = " ".join(example["answers"])
@@ -569,16 +600,17 @@ def generate_example_file(filename: str, seed: int, *, static_factor: int, templ
def format_alpaca(example):
question = example["instruction"]
if example["input"]:
if "input" in example and example["input"]:
question = question + "\n" + example["input"]
answer = example["output"]
device_list, device_types = random_device_list(max_devices=32, avoid_device_names=[])
device_list, device_types, extra_exposed_attributes = random_device_list(
max_devices=32, avoid_device_names=[])
available_services = []
for x in device_types:
available_services.extend([ f"{x}.{y}" for y in SUPPORTED_DEVICES[x].services ])
available_services.extend(SUPPORTED_DEVICES[x].get_all_services(extra_exposed_attributes))
text = format_example(example={
"states": device_list,
@@ -594,13 +626,13 @@ def format_alpaca(example):
return result
def merge_with_dataset(dataset_name, seed, output_name, format_function):
def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset_column_names):
alpaca_dataset = load_dataset(dataset_name)["train"].train_test_split(test_size=0.1)
home_assistant_dataset = load_dataset("json", data_files={ "train": "home_assistant_train.json", "test": "home_assistant_test.json" })
random.seed(seed)
alpaca_dataset = alpaca_dataset.map(format_function).remove_columns(["input", "output", "instruction"])
alpaca_dataset = alpaca_dataset.map(format_function).remove_columns(dataset_column_names)
combined_dataset_train = concatenate_datasets([home_assistant_dataset["train"], alpaca_dataset["train"]]).shuffle(seed=42)
combined_dataset_test = concatenate_datasets([home_assistant_dataset["test"], alpaca_dataset["test"]]).shuffle(seed=42)
@@ -618,20 +650,36 @@ def main():
parser.add_argument("--sample", action="store_true", help="Set this flag to enable generation of the train dataset.")
parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset..")
parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
parser.add_argument("--merge-alpaca", action="store_true", help="Set this flag to merge the generated datasets with the alpaca-cleaned dataset.")
parser.add_argument("--merge", help="Set this flag to merge the generated datasets with the specified dataset.")
train_size_group = parser.add_mutually_exclusive_group()
train_size_group.add_argument('--small', action='store_const', const='small', dest='size')
train_size_group.add_argument('--medium', action='store_const', const='medium', dest='size')
train_size_group.add_argument('--large', action='store_const', const='large', dest='size')
train_size_group.add_argument('--xl', action='store_const', const='xl', dest='size')
args = parser.parse_args()
if args.sample:
generate_example_file("sample", 42, static_factor=1, template_factor=1, status_request_factor=1)
if args.train:
# TODO: add small, medium, large cli flags
generate_example_file("home_assistant_train", 42, static_factor=1, template_factor=10, status_request_factor=8)
# generate_example_file("home_assistant_train", 42, static_factor=5, template_factor=15, status_request_factor=12)
# generate_example_file("home_assistant_train", 42, static_factor=5, template_factor=20, status_request_factor=15)
if args.size == "small":
generate_example_file("home_assistant_train", 42, static_factor=1, template_factor=10, status_request_factor=8)
elif args.size == "medium":
generate_example_file("home_assistant_train", 42, static_factor=5, template_factor=15, status_request_factor=12)
elif args.size == "large":
generate_example_file("home_assistant_train", 42, static_factor=5, template_factor=20, status_request_factor=15)
elif args.size == "xl":
generate_example_file("home_assistant_train", 42, static_factor=7, template_factor=25, status_request_factor=18)
else:
raise Exception(f"Unrecognized dataset size: {args.size}")
if args.test:
generate_example_file("home_assistant_test", 12345, static_factor=0.25, template_factor=3, status_request_factor=2)
if args.merge_alpaca:
merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca)
if args.merge == "alpaca":
merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca, ["input", "output", "instruction"])
elif args.merge == "wizardlm70k":
merge_with_dataset("WizardLM/WizardLM_evol_instruct_70k", 42, "wizardlm70k", format_alpaca, ["output", "instruction"])
if __name__ == "__main__":
main()
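Putting the new size and merge flags together, a sample invocation might look like this (the script filename is an assumption; the diff does not show it):

```
python3 generate_home_assistant_data.py --train --xl --test --merge wizardlm70k
```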


@@ -200,24 +200,24 @@ fan,on,How's the <device_name>,The <device_name> is on
fan,off,Status of <device_name>,The <device_name> is off
light,on,<device_name> condition,The <device_name> is on
light,off,What's up with <device_name>,The <device_name> is off
media_player,on;<volume>,"What's the volume on <device_name>?","The volume on <device_name> is at <volume_level>."
media_player,on;<volume>,"What's the volume on <device_name>?","The volume on <device_name> is at <volume>."
media_player,off,"Can you check the volume level of <device_name>?","<device_name> is off, so I can't check the volume level currently."
media_player,idle,"I need the current volume of <device_name>, please.","The current volume of <device_name> is <volume_level>."
media_player,playing,"Is the volume too loud on <device_name>?","The volume on <device_name> is currently set to <volume_level>."
media_player,paused,"What was the volume set to before <device_name> was paused?","Before being paused, <device_name> had a volume level of <volume_level>."
media_player,playing,"Could you increase <device_name>'s volume a bit?","I've increased the volume on <device_name>. It's now at <volume_level>."
media_player,playing,"Is the volume muted on <device_name>?","The volume on <device_name> is not muted. It's at <volume_level>."
media_player,on,"What's playing on <device_name> right now?","<media_name> is currently playing on <device_name>."
media_player,idle,"I need the current volume of <device_name>, please.","The current volume of <device_name> is <volume>."
media_player,playing,"Is the volume too loud on <device_name>?","The volume on <device_name> is currently set to <volume>."
media_player,paused,"What was the volume set to before <device_name> was paused?","Before being paused, <device_name> had a volume level of <volume>."
media_player,playing,"Could you increase <device_name>'s volume a bit?","I've increased the volume on <device_name>. It's now at <volume>."
media_player,playing,"Is the volume muted on <device_name>?","The volume on <device_name> is not muted. It's at <volume>."
media_player,on,"What's playing on <device_name> right now?","<media> is currently playing on <device_name>."
media_player,off,"Is there anything playing on <device_name>?","Nothing is playing on <device_name> as it is currently off."
media_player,idle,"Can you tell me what was last played on <device_name>?","The last media played on <device_name> was <media_name>."
media_player,playing,"What is playing on <device_name> at the moment?","The song <media_name> is playing on <device_name> right now."
media_player,paused,"What media was paused on <device_name>?","<media_name> was paused on <device_name>."
media_player,idle,"Can you tell me what was last played on <device_name>?","The last media played on <device_name> was <media>."
media_player,playing,"What is playing on <device_name> at the moment?","The song <media> is playing on <device_name> right now."
media_player,paused,"What media was paused on <device_name>?","<media> was paused on <device_name>."
media_player,on,"I hear something from <device_name>, what is it?","<device_name> is not currently playing anything."
media_player,playing,"I like this track on <device_name>, what is it?","You're listening to <media_name> on <device_name>."
media_player,paused,"Can you name the movie that's on <device_name>?","The movie currently paused on <device_name> is <media_name>."
media_player,standby,"What was the last show watched on <device_name>?","The last show watched on <device_name> was <media_name>."
media_player,buffering,"Which episode is on <device_name> right now?","The episode on <device_name> is <media_name>."
light,on;<color>;<brightness>,"What's the brightness of <device_name>?","The brightness of <device_name> is currently at <brightness>%."
media_player,playing,"I like this track on <device_name>, what is it?","You're listening to <media> on <device_name>."
media_player,paused,"Can you name the movie that's on <device_name>?","The movie currently paused on <device_name> is <media>."
media_player,standby,"What was the last show watched on <device_name>?","The last show watched on <device_name> was <media>."
media_player,buffering,"Which episode is on <device_name> right now?","The episode on <device_name> is <media>."
light,on;<color>;<brightness>,"What's the brightness of <device_name>?","The brightness of <device_name> is currently at <brightness>."
light,on;<color>;<brightness>,"Can you tell me the color of <device_name>?","<device_name> is currently showing <color>."
light,on;<color>;<brightness>,"Is <device_name> set to a bright or dim setting?","<device_name> is set to a <brightness> brightness level."
light,on;<color>;<brightness>,"What hue is <device_name> displaying?","<device_name> is displaying a <color> hue."
@@ -225,7 +225,7 @@ light,on;<color>;<brightness>,"How bright is <device_name> right now?","<device_
light,on;<color>;<brightness>,"I'd like to know the current color of <device_name>.","The current color of <device_name> is <color>."
light,on;<color>;<brightness>,"Is <device_name> more on the dim side or the bright side?","<device_name> is more on the <brightness> side."
light,on;<color>;<brightness>,"What color setting is <device_name> on?","<device_name> is set to <color> color setting."
light,on;<color>;<brightness>,"Could you check the brightness level of <device_name> for me?","Sure, the brightness level of <device_name> is <brightness>%."
light,on;<color>;<brightness>,"Could you check the brightness level of <device_name> for me?","Sure, the brightness level of <device_name> is <brightness>."
light,on;<color>;<brightness>,"Tell me, what's the current hue of <device_name>?","The current hue of <device_name> is <color>."
light,off,"What's the brightness of <device_name>?","The device <device_name> is turned off."
light,off,"Can you tell me the color of <device_name>?","<device_name> is turned off."
(CSV columns: device_type, state, english_phrase, assistant_response)


@@ -193,21 +193,31 @@ climate,set_hvac_mode,"Switch the system to cooling mode.","Switching to cooling
climate,set_hvac_mode,"Can we set the HVAC to heat?","Setting the HVAC to heat.",8
climate,set_hvac_mode,"Change the HVAC to automatic.","Changing HVAC to automatic mode.",8
light,turn_on,"Set the brightness of <device_name> to <brightness>%.","Setting the brightness of <device_name> to <brightness>%.",8
light,turn_on,"Can you make <device_name> <color>?","Turning <device_name> <color>.",8
light,turn_on,"Change the color of <device_name> to <color>.","Changing the color of <device_name> to <color>.",8
light,turn_on,"Dim <device_name> to <brightness> percent brightness.","Dimming <device_name> to <brightness>% brightness.",8
light,turn_on,"Please set <device_name> to a <color> color.","Setting <device_name> to a <color> color.",8
light,turn_on,"Brighten <device_name> to <brightness>.","Brightening <device_name> to <brightness>%.",8
light,turn_on,"Turn <device_name> <color>.","Turning <device_name> <color>.",8
light,turn_on,"I want <device_name> at a <color> setting.","Setting <device_name> to a <color>.",8
light,turn_on,"Set <device_name> to a <color> color.","Setting <device_name> to a <color> color.",8
light,turn_on,"Adjust <device_name> brightness to <brightness>.","Adjusting <device_name> brightness to <brightness>%.",8
light,turn_on,"Increase <device_name>'s brightness to <brightness>.","Increasing <device_name>'s brightness to <brightness>%.",8
light,turn_on,"Lower the brightness of <device_name> to <brightness>.","Lowering the brightness of <device_name> to <brightness>%.",8
light,turn_on,"Adjust <device_name> brightness to <brightness>.","Adjusting <device_name> brightness to <brightness>.",8
light,turn_on,"Increase <device_name>'s brightness to <brightness>.","Increasing <device_name>'s brightness to <brightness>.",8
light,turn_on,"Lower the brightness of <device_name> to <brightness>.","Lowering the brightness of <device_name> to <brightness>.",8
light,turn_on,"Can you set <device_name>'s brightness level to <brightness> percent?","Setting <device_name>'s brightness level to <brightness>%.",8
light,turn_on,"I'd like <device_name> at <brightness> percent brightness, please.","Setting <device_name> to <brightness>% brightness.",8
light,turn_on,"Can you make <device_name> <color>?","Turning <device_name> <color>.",8
light,turn_on,"Change the color of <device_name> to <color>.","Changing the color of <device_name> to <color>.",8
light,turn_on,"Change <device_name> to a <color> hue.","Changing <device_name> to a <color> hue.",8
light,turn_on,"Set <device_name> to be <color>.","Setting <device_name> to be <color>.",8
light,turn_on,"I want <device_name> to be <color>, please.","Setting <device_name> to be <color>.",8
light,turn_on,"Can you make <device_name> shine in <color>?","Making <device_name> shine in <color>.",8
light,turn_on,"Turn <device_name> to a <color> shade.","Turning <device_name> to a <color> shade.",8
light,turn_on,"Turn <device_name> to a <color> shade.","Turning <device_name> to a <color> shade.",8
light,turn_on,"Turn <device_name> <color>.","Turning <device_name> <color>.",8
light,turn_on,"I want <device_name> at a <color> setting.","Setting <device_name> to a <color>.",8
light,turn_on,"Set <device_name> to a <color> color.","Setting <device_name> to a <color> color.",8
light,turn_on,"Please set <device_name> to a <color> color.","Setting <device_name> to a <color> color.",8
light,turn_on,"Make <device_name> glow <color>.","Making <device_name> glow <color>.",8
light,turn_on,"Could you turn <device_name> to <color>?","Turning <device_name> to <color>.",8
light,turn_on,"Please change <device_name> to <color>.","Changing <device_name> to <color>.",8
light,turn_on,"Adjust <device_name> to <color> color.","Adjusting <device_name> to <color> color.",8
light,turn_on,"Switch <device_name> color to <color>.","Switching <device_name> color to <color>.",8
light,turn_on,"Can <device_name> be <color> now?","Setting <device_name> to be <color>.",8
light,turn_on,"Let's have <device_name> in <color>.","Setting <device_name> in <color>.",8
light,turn_on,"I'd like <device_name> to change to <color>.","Changing <device_name> to <color>.",8
light,turn_on,"Can <device_name> display a <color> light?","Making <device_name> display a <color> light.",8
light,turn_on,"Set <device_name> color to <color>, please.","Setting <device_name> color to <color>.",8
(CSV columns: device_type, service, english_phrase, assistant_response, multiplier)


@@ -1,3 +1,4 @@
# home-llm experiments (phi1.5)
rev1 - original test
- 1 epoch
- train ctx 1900
@@ -217,7 +218,8 @@ rev 9 - reduced dataset size
------
home-1b-rev1
# Home 1B
## home-1b-rev1
- 1 epoch
- 2048 train ctx
- batch size 8
@@ -231,4 +233,98 @@ home-1b-rev1
+ it works OK with low temperatures
+ seems to handle the alpaca dataset not so well
home-1b-rev2
Eval results for existing models:
Home-1b-v1: 0.767816091954023
Home-3b-v2: 0.6908045977011494
## home-1b-rev5 series
- 1 epoch
- 2048 train ctx
- batch size 8
- learning rate 1e-5
- weight decay 0.1
- gradient clipping 1.0
- save model every 200 steps
home-1b-rev5
- dataset size: medium
- evaluation results:
- 200: 0.553448275862069
- 400: 0.7482758620689656 (+.19)
- 600: 0.8103448275862069 (+.06)
- 800: 0.8316091954022988 (+.02)
- 1000: 0.8396551724137931 (+.008)
- 1200: 0.8488505747126437 (+.009)
- Final (1467): 0.8494252873563218 (+.00005)
home-1b-rev5_1
- dataset size: small
- evaluation results:
- 200: 0.6057471264367816
- 400: 0.7494252873563219 (+.143)
- 600: 0.7683908045977011 (+.018)
- 800: 0.7729885057471264 (+.0046)
- Final (869): bad
home-1b-rev5_2
- dataset size: large
- evaluation results:
- 200: --
- 400: --
- 600: 0.8425287356321839
- 800: 0.8666666666666667
- 1000: 0.8770114942528736
- 1200: 0.8844827586206897
- 1400: 0.8879310344827587
- 1600: 0.8844827586206897
- Final (1848): 0.8833333333333333
home-3b-v3-rev1
- dataset size: large
- evaluation results: 0.9091954022988505
home-1b-rev6
- dataset size: large (fixed templates + function calling arguments; brightness is broken)
- evaluation results: 0.8254149971379507
home-1b-rev6_1
- dataset size: xl (fixed templates + function calling arguments; 0-255 brightness is broken)
- evaluation results:
- 400: 0.7240984544934173
- 800: 0.8311390955924441
- 1200: 0.8471665712650257
- 1600: 0.8597595878649112
- 2000: 0.8551803091013166
- Final (2322): 0.8586147681740126
home-1b-rev6_2 = Home-1B-v2-GGUF
- dataset size: large (change brightness back to percentages; increase color references by ~2x)
- evaluation results:
- 400: 0.7856064418721691
- 800: 0.864116759
- 1200: 0.882234524
- 1600: 0.885254152
- 2000: 0.8852541519879215
- Final (2048):
home-3b-v3-rev2 = Home-3B-v2-GGUF
- dataset size: xl + alpaca
- evaluation results: 0.8731756416708606
Home-3B-v2-GGUF:ha_only
- dataset size: large
- evaluation results: FAILED (again.....)
# Datasets
## SFT
Alpaca: https://huggingface.co/datasets/yahma/alpaca-cleaned
Alpaca (Translated): https://huggingface.co/datasets/saillab/taco-datasets
WizardLM 200k: https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k
WizardLM 70k: https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_70k
Huggingface Ultrachat 200k: https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k
OpenOrca Slim Deduped (363k): https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup
## DPO
Intel Orca DPO Pairs: https://huggingface.co/datasets/Intel/orca_dpo_pairs
Huggingface Ultrachat: https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized

evaluate.py (new file)

@@ -0,0 +1,129 @@
#!/usr/bin/env python3
import argparse, os, re, json
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from tqdm import tqdm
CTX_SIZE = 2048
def tokenize(tokenizer, prompt):
return tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=CTX_SIZE)
def generate(model, tokenizer, prompt):
inputs = tokenize(tokenizer, prompt)
with torch.no_grad():
outputs = model.generate(**inputs)
text = tokenizer.batch_decode(outputs)
return text
def main():
parser = argparse.ArgumentParser(description="Evaluate the function calling for a model")
parser.add_argument("model")
parser.add_argument("--dataset_file", default="./data/home_assistant_test.json")
parser.add_argument("--split", default="<|im_start|>assistant")
parser.add_argument("--batch-size", default=8)
args = parser.parse_args()
model_folder = f"./models/{args.model}"
split = args.split
dataset = load_dataset("json", data_files={ "train": args.dataset_file })["train"]
# filter out examples that are status requests
dataset = dataset.filter(lambda example: "```homeassistant" in example["text"])
service_call_regex = re.compile(r"```homeassistant\n([\S \t\n]*?)```")
torch.set_default_device("cuda")
print(f"Loading model from {model_folder}...")
trained_model = AutoModelForCausalLM.from_pretrained(model_folder, trust_remote_code=True, torch_dtype=torch.bfloat16) #, code_revision="834565c23f9b28b96ccbeabe614dd906b6db551a")
trained_tokenizer = AutoTokenizer.from_pretrained(model_folder, trust_remote_code=True, padding_side='left')
trained_model.generation_config = GenerationConfig(
max_new_tokens=128,
use_cache=True,
do_sample=True,
temperature=0.1,
top_k=40,
top_p=1.0,
repetition_penalty=1.15,
eos_token_id=trained_model.config.eos_token_id,
pad_token_id=trained_model.config.pad_token_id,
)
print("Evaluating...")
batch_size = int(args.batch_size)
correct_answers = 0
total_answers = 0
color_mismatches = 0
failed_examples = []
with tqdm(total=len(dataset), desc="Accuracy") as pbar:
for batch_start in range(0, len(dataset), batch_size):
batch = dataset[batch_start:batch_start + batch_size]
prompts = [ example.split(split)[0] + split for example in batch["text"] ]
expected_responses = [ example.split(split)[1] for example in batch["text"] ]
output = generate(trained_model, trained_tokenizer, prompts)
for model_output, expected_response in zip(output, expected_responses):
response = model_output.replace(trained_tokenizer.pad_token, "").replace(trained_tokenizer.eos_token, "").split(split)[1]
expected_service_calls = []
for block in service_call_regex.findall(expected_response.strip()):
for line in block.split("\n"):
if len(line) == 0:
continue
expected_service_calls.append(json.loads(line))
total_answers = total_answers + 1
for block in service_call_regex.findall(response.strip()):
for line in block.split("\n"):
if len(line) == 0:
continue
try:
json_output = json.loads(line)
except json.JSONDecodeError:
failed_examples.append({ "expected": expected_response, "actual": response, "invalid_json": True })
continue
if json_output in expected_service_calls:
expected_service_calls.pop(expected_service_calls.index(json_output))
correct_answers = correct_answers + 1
elif "rgb_color" in json_output:
for sc in expected_service_calls:
sc = { **sc }
json_output_copy = { **json_output }
if not "rgb_color" in sc:
continue
del sc["rgb_color"]
del json_output_copy["rgb_color"]
if sc == json_output_copy:
correct_answers = correct_answers + 1
color_mismatches = color_mismatches + 1
else:
failed_examples.append({ "expected": expected_response, "actual": response })
else:
failed_examples.append({ "expected": expected_response, "actual": response })
pbar.update(batch_size)
pbar.set_description(f"Accuracy: {correct_answers/total_answers*100:.2f}% ({correct_answers}/{total_answers})")
accuracy = correct_answers/total_answers
print(f"Final Accuracy Rating: {accuracy*100:.2f}%")
print(f"Color Mismatches: {color_mismatches}")
with open(os.path.join(model_folder, "eval_results.json"), "w") as f:
json.dump({
"possible_answers": total_answers,
"correct_answers": correct_answers,
"accuracy": accuracy,
"color_mismatches": color_mismatches,
"failed_examples": failed_examples,
}, f, indent=4)
if __name__ == "__main__":
main()
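A sample invocation, assuming a fine-tuned checkpoint has been saved under `./models/home-1b-rev6`:

```
python3 evaluate.py home-1b-rev6 --dataset_file ./data/home_assistant_test.json --batch-size 8
```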


@@ -3,8 +3,10 @@
import math
import copy
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, \
PreTrainedTokenizerFast, HfArgumentParser, GPTQConfig, AutoConfig
from transformers.trainer_utils import EvalPrediction
from datasets import load_dataset
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
@@ -15,13 +17,12 @@ Phi Modules: fc1,fc2,q_proj,v_proj,k_proj,dense,embed_tokens,lm_head
"""
python3 train.py \
--run_name home-3b-v2-rev3 \
--run_name Home-3B-v2_ha-GGUF \
--base_model microsoft/phi-2 \
--add_pad_token \
--add_chatml_tokens \
--bf16 \
--train_dataset data/home_assistant_alpaca_merged_train.json \
--test_dataset data/home_assistant_alpaca_merged_test.json \
--train_dataset data/home_assistant_train.json \
--learning_rate 1e-5 \
--save_steps 1000 \
--micro_batch_size 2 --gradient_checkpointing \
@@ -31,7 +32,7 @@ python3 train.py \
"""
python3 train.py \
--run_name home-1b-rev4 \
--run_name home-1b-rev6 \
--base_model microsoft/phi-1_5 \
--add_pad_token \
--add_chatml_tokens \
@@ -40,7 +41,7 @@ python3 train.py \
--test_dataset data/home_assistant_test.json \
--learning_rate 1e-5 \
--micro_batch_size 4 --gradient_checkpointing \
--ctx_size 2048
--ctx_size 2048 --save_steps 200
"""
"""
@@ -56,21 +57,23 @@ python3 train.py \
@dataclass
class TrainingRunArguments:
run_name: str = field(metadata={"help": "The folder to save the output model under"})
train_dataset: str = field(metadata={"help": "The JSON file containing the training dataset"})
test_dataset: str = field(metadata={"help": "The JSON file containing the evaluation dataset"})
base_model: str = field(metadata={"help": "The base model to load for fine-tuning"})
train_dataset: str = field(metadata={"help": "The JSON file containing the training dataset"})
test_dataset: str = field(default=None, metadata={"help": "The JSON file containing the evaluation dataset"})
ctx_size: int = field(default=2048, metadata={"help": "The number of tokens to pad & truncate the input examples to"})
bf16: bool = field(default=False, metadata={"help": "If set, the model will the loaded and trained in bf16 instead of fp16"})
batch_size: int = field(default=8, metadata={"help": "The effective batch size to train with; gradient accumulation steps are derived by dividing this by micro_batch_size"})
micro_batch_size: int = field(default=2, metadata={"help": "The actual batch size that will fit into VRAM on this machine"})
eval_batch_size: int = field(default=1, metadata={"help": "The batch size for generation used while performing evaluation"})
epochs: int = field(default=1, metadata={"help": "The number of times to train the model on each example"})
learning_rate: float = field(default=1e-5, metadata={"help": "The starting learning rate (speed at which the model trains)"})
learning_rate_schedule: str = field(default="cosine", metadata={"help": "How fast the learning rate is reduced during training"})
weight_decay: float = field(default=0.1, metadata={"help": "The amount of L2 weight decay regularization to apply"})
gradient_clip: float = field(default=1.0, metadata={"help": "The maximum gradient norm to clip gradients to"})
resume_from_checkpoint: str = field(default="", metadata={"help": "The name of the checkpoint to resume training from"})
eval_steps: int = field(default=100, metadata={"help": "The number of steps in between evaluations of the model"})
eval_steps: int = field(default=200, metadata={"help": "The number of steps in between evaluations of the model; set to -1 to evaluate every epoch"})
save_steps: int = field(default=-1, metadata={"help": "The number of steps in between model checkpoints; set to -1 to save every epoch"})
save_total_limit: int = field(default=1, metadata={"help": "The number of recent checkpoints of the model to save (not including the final model)"})
group_by_length: bool = field(default=False, metadata={"help": "If enabled, the training data will be grouped by length to optimize use of padding"})
# Quantization
@@ -99,8 +102,6 @@ training_run_args, _ = parser.parse_args_into_dataclasses(return_remaining_strin
if sum([training_run_args.load_in_8bit, training_run_args.load_in_4bit, training_run_args.load_as_gptq]) > 1:
raise Exception("Please select exactly one of 'load_in_8bit', 'load_in_4bit', or 'load_as_gptq")
# TODO: write a proper evaluation script
print(f"Loading model '{training_run_args.base_model}'...")
model_kwargs = {}
@@ -118,6 +119,7 @@ else:
model_kwargs["torch_dtype"] = torch.float16
# model_kwargs["resid_pdrop"] = 0.0
# model_kwargs["revision"] = "accfee56d8988cae60915486310362db5831b1bd"
def find_max_vram(min_buffer_mib=800):
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
@@ -136,10 +138,11 @@ model = AutoModelForCausalLM.from_pretrained(
max_memory=find_max_vram(),
**model_kwargs
)
tokenizer = AutoTokenizer.from_pretrained(training_run_args.base_model, trust_remote_code=True, use_fast=False)
tokenizer = AutoTokenizer.from_pretrained(training_run_args.base_model, trust_remote_code=True)
if training_run_args.add_pad_token:
tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
model.config.pad_token_id = tokenizer.pad_token_id
if training_run_args.add_chatml_tokens:
tokenizer.add_special_tokens({
@@ -150,6 +153,8 @@ if training_run_args.add_chatml_tokens:
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
# TODO: add chatml template to tokenizer for auto-formatting
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
if model.get_input_embeddings().num_embeddings < embeddings_len:
model.resize_token_embeddings(embeddings_len)
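# Worked example of the rounding above (token counts are illustrative): if
# len(tokenizer) == 50298 after the pad and ChatML tokens are added, then
# math.ceil(50298 / 32) * 32 == 50304, so the embedding matrix is padded up
# to a multiple of 32 rows rather than resized to the exact vocabulary size.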
@@ -183,31 +188,37 @@ if training_run_args.use_lora:
base_dir = "loras" if training_run_args.use_lora else "models"
model_dir = f"./{base_dir}/{training_run_args.run_name}"
# TODO: eval is broken (returning NaN for loss)
training_kwargs = {}
if training_run_args.test_dataset:
training_kwargs.update({
"per_device_eval_batch_size": training_run_args.eval_batch_size,
"evaluation_strategy": ("steps" if training_run_args.eval_steps != -1 else "epoch"),
"eval_steps": (training_run_args.eval_steps if training_run_args.eval_steps != -1 else None),
"bf16_full_eval": training_run_args.bf16,
})
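# e.g. invoking with --test_dataset and the default eval_steps=200 evaluates
# every 200 optimizer steps; passing --eval_steps -1 instead resolves to
# evaluation_strategy="epoch" with eval_steps=None.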
training_args = TrainingArguments(
per_device_train_batch_size=training_run_args.micro_batch_size,
# per_device_eval_batch_size=training_run_args.micro_batch_size,
gradient_accumulation_steps=training_run_args.batch_size//training_run_args.micro_batch_size,
gradient_checkpointing=training_run_args.gradient_checkpointing,
# weight_decay=training_run_args.weight_decay,
# max_grad_norm=training_run_args.gradient_clip,
# evaluation_strategy="steps",
# eval_steps=training_run_args.eval_steps,
weight_decay=training_run_args.weight_decay,
max_grad_norm=training_run_args.gradient_clip,
save_strategy=("steps" if training_run_args.save_steps != -1 else "epoch"),
save_steps=(training_run_args.save_steps if training_run_args.save_steps != -1 else None),
save_safetensors=True,
logging_steps=5,
output_dir=model_dir,
num_train_epochs=training_run_args.epochs,
save_total_limit=1,
# dataloader_pin_memory=False,
save_total_limit=training_run_args.save_total_limit,
report_to="tensorboard",
learning_rate=training_run_args.learning_rate,
lr_scheduler_type=training_run_args.learning_rate_schedule,
log_level="info",
bf16=training_run_args.bf16,
# bf16_full_eval=training_run_args.bf16,
group_by_length=training_run_args.group_by_length
group_by_length=training_run_args.group_by_length,
**training_kwargs,
# include_inputs_for_metrics=True,
)
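# Illustrative arithmetic for the accumulation setup above: with the default
# batch_size=8 and micro_batch_size=2, gradient_accumulation_steps == 4, so
# each optimizer step sees 8 examples while only 2 live in VRAM at a time.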
class DataCollatorForSupervisedFineTuning(object):
@@ -315,7 +326,10 @@ class DataCollatorForSupervisedFineTuning(object):
)
print("Loading dataset...")
datasets = load_dataset("json", data_files={ "train": training_run_args.train_dataset, "test": training_run_args.test_dataset })
data_files = { "train": training_run_args.train_dataset }
if training_run_args.test_dataset:
data_files["test"] = training_run_args.test_dataset
datasets = load_dataset("json", data_files=data_files)
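# With a test split supplied, load_dataset returns a DatasetDict shaped
# roughly like (row counts depend on the files):
#   DatasetDict({
#       "train": Dataset({features: ["text"], num_rows: ...}),
#       "test": Dataset({features: ["text"], num_rows: ...}),
#   })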
def tokenize(example):
return tokenizer(
@@ -326,18 +340,24 @@ def tokenize(example):
)
print("Tokenizing datasets...")
tokenized_test_dataset = None
tokenized_train_dataset = datasets["train"].map(tokenize, batched=True).remove_columns(["text"])
tokenized_test_dataset = datasets["test"].map(tokenize, batched=True).remove_columns(["text"])
if training_run_args.test_dataset:
tokenized_test_dataset = datasets["test"].map(tokenize, batched=True).remove_columns(["text"])
tokens_in_train_set = sum(len(example) for example in tokenized_train_dataset["input_ids"])
print(f"Train dataset has {int(tokens_in_train_set / 1000000)}M tokens")
data_collator = DataCollatorForSupervisedFineTuning(tokenizer=tokenizer)
import random
from torch.utils.data import SequentialSampler, Subset, RandomSampler
class RandomEvalSubsetTrainer(Trainer):
def __init__(self, random_eval_sample_pct=0.1, *args, **kwargs):
def __init__(self, random_eval_sample_pct=0.1, learning_rate_overshoot=1.15, *args, **kwargs):
super().__init__(*args, **kwargs)
self.random_eval_sample_pct = random_eval_sample_pct
self.evaluate_full_dataset = False
self.learning_rate_overshoot = learning_rate_overshoot
def evaluate_all(self):
self.evaluate_full_dataset = True
@@ -359,13 +379,34 @@ class RandomEvalSubsetTrainer(Trainer):
return super()._get_train_sampler()
return RandomSampler(self.train_dataset, generator=torch.Generator(device='cpu'))
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
"""
Saw this in the Chinchilla paper: it says not to go over 25% overshoot.
Overshooting the schedule should speed up training by skipping the final slow-decay phase, which barely affects accuracy.
"""
return super().create_scheduler(int(num_training_steps * self.learning_rate_overshoot), optimizer=optimizer)
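# Illustrative numbers for the overshoot trick: if the run would take 1000
# optimizer steps, the cosine schedule is built for int(1000 * 1.15) == 1150
# steps, so training stops at step 1000 while the learning rate is still
# partway down the curve instead of fully decayed.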
def compute_metrics(pred: EvalPrediction):
inputs = pred.inputs
label_ids = pred.label_ids
logits = pred.predictions
return {"accuracy": 1.0 }
def preprocess_logits_for_metrics(logits, labels):
"""https://discuss.huggingface.co/t/cuda-out-of-memory-when-using-trainer-with-compute-metrics/2941/22"""
pred_ids = torch.argmax(logits, dim=-1)
return pred_ids, labels
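# The argmax above means the Trainer accumulates token ids of shape
# (batch, seq_len) instead of raw logits of shape (batch, seq_len, vocab_size),
# cutting eval memory by roughly the vocabulary-size factor (modulo dtype width).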
trainer = RandomEvalSubsetTrainer(
model=model,
args=training_args,
train_dataset=tokenized_train_dataset,
# eval_dataset=tokenized_test_dataset,
eval_dataset=tokenized_test_dataset,
data_collator=data_collator,
# compute_metrics=compute_metrics,
# preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
tensorboard_process = None