Files
home-llm/data/generate_home_assistant_data.py
2024-08-04 19:19:10 +00:00

1166 lines
48 KiB
Python

import argparse
import json
import csv
import pandas
import numpy as np
import random
import re
import copy
from dataclasses import dataclass
from datetime import datetime, timedelta
from datasets import load_dataset, concatenate_datasets
from difflib import SequenceMatcher
from typing import Final, Any, Callable, Optional
from tqdm import tqdm
import webcolors
# #### STATES ####
# String values used as entity states in the generated device lists.
# The device type definitions further down reference these; not every
# constant is used in the visible part of this file.
STATE_ON: Final = "on"
STATE_OFF: Final = "off"
STATE_ACTIVE: Final = "active"
STATE_UNKNOWN: Final = "unknown"
STATE_OPEN: Final = "open"
STATE_OPENING: Final = "opening"
STATE_CLOSED: Final = "closed"
STATE_CLOSING: Final = "closing"
STATE_BUFFERING: Final = "buffering"
STATE_PLAYING: Final = "playing"
STATE_PAUSED: Final = "paused"
STATE_IDLE: Final = "idle"
STATE_STANDBY: Final = "standby"
STATE_LOCKED: Final = "locked"
STATE_UNLOCKED: Final = "unlocked"
STATE_LOCKING: Final = "locking"
STATE_UNLOCKING: Final = "unlocking"
STATE_JAMMED: Final = "jammed"
STATE_UNAVAILABLE: Final = "unavailable"
STATE_OK: Final = "ok"
STATE_PROBLEM: Final = "problem"
STATE_CLEANING: Final = "cleaning"
STATE_DOCKED: Final = "docked"
STATE_RETURNING: Final = "returning"
# define piles for global access
# These module-level placeholders start as None and are read by the generator
# functions below. NOTE(review): the code that fills them is not in this part
# of the file -- presumably they are loaded before any generator runs; confirm.
# mapping whose keys are duration values and whose values are the phrase shown in questions/answers
pile_of_durations = None
# sequence of media titles sampled with random.choice
pile_of_media_names = None
# sequence of todo item texts sampled with random.choice
pile_of_todo_items = None
# mapping of device type -> list of entries carrying "device_name" and "description"
stacks_of_device_names = None
# iterable of templates consumed by generate_templated_example
pile_of_templated_actions = None
# iterable of actions consumed by generate_static_example
pile_of_specific_actions = None
# table filtered with .loc on the columns service / persona / short / contains_vars; "response" holds the text
pile_of_responses = None
# iterable of templates consumed by generate_status_request
pile_of_status_requests = None
# mapping of persona -> system prompt text
pile_of_system_prompts = None
# iterable of entries carrying "hallucinated_service" and "real_service"
pile_of_hallucinated_service_names = None
# word placed between the per-device answers of multi-device templates
and_word = None
def closest_color(requested_color):
    """Return the CSS3 color name nearest to ``requested_color``.

    ``requested_color`` is indexed as (red, green, blue). Nearness is the
    squared Euclidean distance to each entry of
    ``webcolors.CSS3_HEX_TO_NAMES``; when several names are equally near,
    the one iterated last is returned.
    """
    best_name = None
    best_distance = None
    for hex_code, color_name in webcolors.CSS3_HEX_TO_NAMES.items():
        red, green, blue = webcolors.hex_to_rgb(hex_code)
        distance = (
            (red - requested_color[0]) ** 2
            + (green - requested_color[1]) ** 2
            + (blue - requested_color[2]) ** 2
        )
        # "<=" so that a later color with the same distance replaces an earlier one
        if best_distance is None or distance <= best_distance:
            best_distance = distance
            best_name = color_name
    return best_name
def generate_random_date():
    """Return a random datetime starting from 2022-01-01.

    Draws a whole-day offset within the 2022-01-01 .. 2030-12-31 span and
    then a second offset between 0 and 86400 (both bounds inclusive), and
    adds both to the start date.
    """
    earliest = datetime(2022, 1, 1)
    latest = datetime(2030, 12, 31)
    day_offset = random.randint(0, (latest - earliest).days)
    second_offset = random.randint(0, 24 * 60 * 60)
    return earliest + timedelta(days=day_offset, seconds=second_offset)
# matches "<name>" placeholders; the capture group is the text between the brackets
var_pattern = re.compile("<(.*?)>")


def get_included_vars(response: str):
    """Return the placeholder names found in ``response`` as a sorted, comma-joined string.

    The exact name ``device_name`` is left out; repeated placeholders are
    kept as repeated entries.
    """
    names = [name for name in var_pattern.findall(response) if name != "device_name"]
    names.sort()
    return ",".join(names)
@dataclass
class DeviceType:
    """Description of one kind of device: its states, services and random parameter sources."""

    # device type prefix used in service names, e.g. "light"
    name: str
    # (state, weight) pairs sampled by get_random_state
    possible_states: list[tuple[str, float]]
    # service name -> names of the arguments that service accepts
    services: dict[str, list]
    # parameter name -> zero-argument callable producing a random value
    random_parameter_generator: Optional[dict[str, Callable]] = None

    def get_all_services(self, extra_exposed_attributes):
        """Return one ``"<name>.<service>(<args>)"`` string per service.

        Only arguments that also appear in ``extra_exposed_attributes`` are
        listed. They are emitted in the order declared in ``services`` so the
        result is identical from run to run (iterating a set intersection
        would make the order depend on string hashing).
        """
        exposed = set(extra_exposed_attributes)
        result = []
        for service, service_args in self.services.items():
            # dict.fromkeys drops duplicate argument names while keeping declaration order
            args = [arg for arg in dict.fromkeys(service_args) if arg in exposed]
            result.append(f"{self.name}.{service}({','.join(args)})")
        return result

    def get_random_parameter(self, param_name):
        """Call and return the generator registered under ``param_name``.

        Raises KeyError for an unknown name, and TypeError when the device
        type was created without any generators.
        """
        return self.random_parameter_generator[param_name]()

    def get_random_state(self, extra_exposed_attributes=[]):
        """Pick one state according to the weights in ``possible_states``.

        ``extra_exposed_attributes`` is not used here (and never mutated);
        subclasses use it to decide which attributes to append.
        """
        states = [state for state, _ in self.possible_states]
        weights = [weight for _, weight in self.possible_states]
        return random.choices(states, weights=weights, k=1)[0]
class LightDeviceType(DeviceType):
    """Light device: on/off state with optional color and brightness details."""

    def __init__(self):
        light_states = [
            (STATE_ON, 0.5),
            (STATE_OFF, 0.5),
        ]
        light_services = {
            "turn_on": ["rgb_color", "brightness"],
            "turn_off": [],
            "toggle": [],
        }
        generators = {
            "rgb_color": lambda: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
            "brightness": lambda: random.randint(0, 100),
        }
        super().__init__(
            "light",
            possible_states=light_states,
            services=light_services,
            random_parameter_generator=generators,
        )

    def get_random_state(self, extra_exposed_attributes=[]):
        """Return ``state[;<color name> <rgb>][;<brightness>%]``.

        A random number is drawn for each optional part before checking
        whether the attribute is exposed: color is added on a draw below 0.5,
        brightness on a draw below 0.7.
        """
        parts = [super().get_random_state(extra_exposed_attributes=extra_exposed_attributes)]

        color_roll = random.random()
        if color_roll < 0.5 and "rgb_color" in extra_exposed_attributes:
            rgb = self.get_random_parameter("rgb_color")
            parts.append(closest_color(rgb) + " " + str(rgb))

        brightness_roll = random.random()
        if brightness_roll < 0.7 and "brightness" in extra_exposed_attributes:
            parts.append(str(self.get_random_parameter("brightness")) + "%")

        return ";".join(parts)
class ClimateDeviceType(DeviceType):
    """Climate device whose state string is assembled from random parameters."""

    def __init__(self):
        climate_services = {
            "turn_on": [],
            "turn_off": [],
            "toggle": [],
            "set_temperature": ["temperature"],
            "set_humidity": ["humidity"],
            "set_fan_mode": ["fan_mode"],
            "set_hvac_mode": ["hvac_mode"],
            "set_preset_mode": ["preset_mode"],
        }
        generators = {
            "fan_mode": lambda: random.choice(["On Low", "On High", "Auto Low", "Auto High", "Off"]),
            "temp_f": lambda: random.randint(60, 80),
            "temp_c": lambda: random.randint(15, 25),
            "humidity": lambda: random.randint(10, 90),
            "preset_mode": lambda: random.choice(["home", "eco", "away", "auto"]),
            "hvac_mode": lambda: random.choice(["heat", "cool", "heat_cool", "off", "auto", "fan_only"]),
        }
        # no weighted states: the hvac mode generator supplies the state instead
        super().__init__(
            "climate",
            possible_states=[],
            services=climate_services,
            random_parameter_generator=generators,
        )

    def get_random_state(self, extra_exposed_attributes=[]):
        """Return ``hvac_mode[;fan_mode][;temperature][;humidity][;preset_mode]``.

        Each optional part is only added when its attribute is exposed. The
        temperature is shown in Fahrenheit or Celsius depending on a random
        draw, and the preset is added only when its draw is below 0.8.
        """
        parts = [self.get_random_parameter("hvac_mode")]

        if "fan_mode" in extra_exposed_attributes:
            parts.append(self.get_random_parameter("fan_mode"))

        if "temperature" in extra_exposed_attributes:
            use_fahrenheit = random.random() > 0.5
            if use_fahrenheit:
                parts.append(str(self.get_random_parameter("temp_f")) + "F")
            else:
                parts.append(str(self.get_random_parameter("temp_c")) + "C")

        if "humidity" in extra_exposed_attributes:
            parts.append(str(self.get_random_parameter("humidity")) + "%")

        # the draw happens whether or not preset_mode is exposed
        preset_roll = random.random()
        if preset_roll < 0.8 and "preset_mode" in extra_exposed_attributes:
            # if it is not "on a preset" then don't add the mode
            parts.append(self.get_random_parameter("preset_mode"))

        return ";".join(parts)
class MediaPlayerDeviceType(DeviceType):
    """Media player device with optional media title and volume in its state."""

    def __init__(self):
        player_states = [
            (STATE_ON, 0.15),
            (STATE_OFF, 0.54),
            (STATE_IDLE, 0.1),
            (STATE_PLAYING, 0.1),
            (STATE_PAUSED, 0.05),
            (STATE_STANDBY, 0.05),
            (STATE_BUFFERING, 0.01),
        ]
        player_services = {
            service: []
            for service in (
                "turn_on",
                "turn_off",
                "toggle",
                "volume_up",
                "volume_down",
                "volume_mute",
                "media_play_pause",
                "media_play",
                "media_pause",
                "media_stop",
                "media_next_track",
                "media_previous_track",
            )
        }
        generators = {
            "media": lambda: random.choice(pile_of_media_names),
            "volume": lambda: round(random.random(), 2),
        }
        super().__init__(
            "media_player",
            possible_states=player_states,
            services=player_services,
            random_parameter_generator=generators,
        )

    def get_random_state(self, extra_exposed_attributes=[]):
        """Return ``state[;<media title>][;vol=<level>]``.

        The title is added only for the playing, paused, buffering and on
        states; the volume is added for every state except off.
        """
        state = super().get_random_state(extra_exposed_attributes=extra_exposed_attributes)

        shows_media = state in (STATE_PLAYING, STATE_PAUSED, STATE_BUFFERING, STATE_ON)
        if shows_media and "media_title" in extra_exposed_attributes:
            state = state + ";" + self.get_random_parameter("media")

        if state != STATE_OFF and "volume_level" in extra_exposed_attributes:
            state = f"{state};vol={self.get_random_parameter('volume')}"

        return state
# Registry of every device type the generator knows about, keyed by the prefix
# used in device and service names. State weights are relative (random.choices
# does not require them to sum to 1).
SUPPORTED_DEVICES = {
    "light": LightDeviceType(),
    "switch": DeviceType(
        name="switch",
        possible_states=[
            (STATE_ON, 0.5),
            (STATE_OFF, 0.5)
        ],
        services={
            "turn_on": [],
            "turn_off": [],
            "toggle": []
        },
    ),
    "fan": DeviceType(
        name="fan",
        possible_states=[
            (STATE_ON, 0.5),
            (STATE_OFF, 0.5)
        ],
        services={
            "turn_on": [],
            "turn_off": [],
            "toggle": [],
            "increase_speed": [],
            "decrease_speed": [],
        },
    ),
    "garage_door": DeviceType(
        name="garage_door",
        possible_states=[
            (STATE_OPEN, 0.49),
            (STATE_CLOSED, 0.49),
            (STATE_OPENING, 0.01),
            (STATE_CLOSING, 0.01)
        ],
        services={
            "open_cover": [],
            "close_cover": [],
            "stop_cover": [],
            "toggle": [],
        },
    ),
    "blinds": DeviceType(
        name="blinds",
        possible_states=[
            (STATE_OPEN, 0.49),
            (STATE_CLOSED, 0.49),
            (STATE_OPENING, 0.01),
            (STATE_CLOSING, 0.01)
        ],
        services={
            "open_cover": [],
            "close_cover": [],
            "stop_cover": [],
            "toggle": [],
        },
    ),
    "lock": DeviceType(
        name="lock",
        possible_states=[
            (STATE_LOCKED, 0.5),
            (STATE_UNLOCKED, 0.5),
        ],
        services={
            "lock": [],
            "unlock": [],
        },
    ),
    "media_player": MediaPlayerDeviceType(),
    "climate": ClimateDeviceType(),
    "vacuum": DeviceType(
        name="vacuum",
        possible_states=[
            (STATE_CLEANING, 0.2),
            (STATE_DOCKED, 0.6),
            (STATE_RETURNING, 0.1),
            (STATE_IDLE, 0.05),
            (STATE_PAUSED, 0.05),
        ],
        services={
            "start": [],
            "pause": [],
            "stop": [],
            "return_to_base": [],
        },
    ),
    "timer": DeviceType(
        name="timer",
        possible_states=[
            (STATE_IDLE, 0.2),
            (STATE_ACTIVE, 0.6),
            (STATE_PAUSED, 0.1),
        ],
        services={
            "start": ["duration"],
            "pause": [],
            "cancel": [],
        },
        random_parameter_generator={
            "duration": lambda: random.choice(list(pile_of_durations.keys())),
            # HH:MM:SS with every field zero-padded; randint is inclusive on both
            # ends, so minutes and seconds stop at 59 to stay a valid clock value
            "remaining": lambda: f"{random.randint(0, 3):02}:{random.randint(0, 59):02}:{random.randint(0, 59):02}"
        }
    ),
    "todo": DeviceType(
        name="todo",
        # the state is a number from 0 to 31 rendered as text, each equally likely
        possible_states=[ (f"{i}", (1/32)) for i in range(32) ],
        services={
            "add_item": ["item"],
        },
        random_parameter_generator={
            # keyed "todo" (the template placeholder name), not "item" (the service argument)
            "todo": lambda: random.choice(pile_of_todo_items),
        }
    ),
}
class NoResponseAvailableException(Exception):
    """Raised when no stored response matches the requested filters."""
class NoServicesAvailableException(Exception):
    """Raised when no available service belongs to the target device's type."""
def get_random_response(*, service: str, persona: str, question_template: str, short: bool) -> str:
    """Sample one response text matching the service, persona, length and placeholders.

    The placeholders required are those found in ``question_template``,
    ignoring any whose name contains ``device_name``.

    Raises NoResponseAvailableException when no row of ``pile_of_responses``
    satisfies all four filters.
    """
    required_vars = list(set([name for name in var_pattern.findall(question_template) if "device_name" not in name]))
    wanted_vars = ",".join(sorted(required_vars))
    wanted_short = 1 if short else 0

    service_matches = pile_of_responses['service'] == service
    persona_matches = pile_of_responses['persona'] == persona
    length_matches = pile_of_responses['short'] == wanted_short
    vars_match = pile_of_responses['contains_vars'] == wanted_vars

    possible_results = pile_of_responses.loc[service_matches & persona_matches & length_matches & vars_match]
    if not len(possible_results):
        raise NoResponseAvailableException(f"No responses matched the provided filters: {persona}, {service}, {required_vars}, {short}")

    sampled_row = possible_results.sample()
    return sampled_row["response"].values[0]
def format_device_line(*, device_name: str, friendly_name: str, state: str):
    """Render one device as ``<device_name> '<friendly_name>' = <state>``."""
    return "{} '{}' = {}".format(device_name, friendly_name, state)
# generate a random list of devices for the context
def random_device_list(max_devices: int, avoid_device_names: list[str]):
    """Build a random set of device lines to surround a target device.

    Args:
        max_devices: upper bound (inclusive) for the number of devices drawn;
            the lower bound is 2.
        avoid_device_names: ``"<type>.<name>"`` identifiers to keep away from.
            Candidates of the same type whose name is too similar
            (SequenceMatcher ratio of 0.4 or more) are removed, and when a
            climate device is avoided no climate device is added at all.

    Returns:
        A tuple ``(device_lines, device_types, extra_exposed_attributes)``:
        the formatted lines, the distinct device types that appear in them,
        and the attribute names used when generating states.
    """
    num_devices = random.randint(2, max_devices)

    # shallow-copy each stack so filtering does not touch the shared global lists
    local_device_names = { k: v[:] for k, v in stacks_of_device_names.items() }

    avoid_climate = False
    for avoid_device in avoid_device_names:
        avoid_type = avoid_device.split(".")[0]

        local_device_names[avoid_type] = [
            possible_device
            for possible_device in local_device_names[avoid_type]
            if SequenceMatcher(None, avoid_device, possible_device["device_name"].split(".")[1]).ratio() < 0.4
        ]

        if avoid_type == "climate":
            avoid_climate = True

    possible_choices = []
    for candidates in local_device_names.values():
        possible_choices.extend(candidates)

    # Cap the target at the number of distinct candidates the loop below can
    # actually accept; otherwise a small pool would make it spin forever.
    usable_names = set()
    for candidate in possible_choices:
        try:
            candidate_name = candidate["device_name"]
            candidate_type = candidate_name.split(".")[0]
            candidate["description"]
        except (KeyError, AttributeError, TypeError):
            continue
        if candidate_type not in SUPPORTED_DEVICES:
            continue
        if avoid_climate and candidate_type == "climate":
            continue
        usable_names.add(candidate_name)
    num_devices = min(num_devices, len(usable_names))

    device_types = set()
    device_list = []
    device_lines = []
    # TODO: randomly pick attributes for this list
    extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "duration", "remaining", "item"]

    while len(device_list) < num_devices:
        choice = random.choice(possible_choices)
        if choice["device_name"] in device_list:
            continue

        try:
            device_name = choice["device_name"]
            device_type = device_name.split(".")[0]
            friendly_name = choice["description"]

            # don't add random thermostats. we need to be careful about how we handle multiple thermostats
            if avoid_climate and device_type == "climate":
                continue

            state = SUPPORTED_DEVICES[device_type].get_random_state(extra_exposed_attributes=extra_exposed_attributes)
            device_lines.append(format_device_line(
                device_name=device_name,
                friendly_name=friendly_name,
                state=state
            ))
            device_list.append(device_name)
            device_types.add(device_type)
        except Exception as ex:
            # best effort: report the entry that could not be used and keep drawing
            print(f"bad device name: {choice}")
            print(repr(ex))

    return device_lines, list(device_types), list(extra_exposed_attributes)
def generate_static_example(action: dict, persona: str, max_devices: int = 32):
    """Build one example from a fixed action (phrase, service and device name).

    The target device is placed at a random position among randomly drawn
    devices, and a long-form response for the service and persona is used as
    the answer.

    Returns a dict with the keys ``states``, ``available_services``,
    ``question``, ``answers`` and ``service_calls``.
    """
    question = action["phrase"]
    service_name = action["service_name"]
    device_type = service_name.split(".")[0]
    target_device = f"{device_type}.{action['device_name']}"
    # "some_device_name" -> "Some Device Name"
    friendly_name = target_device.split(".")[1].replace("_", " ").title()

    device_list, device_types, extra_exposed_attributes = random_device_list(
        max_devices=max_devices, avoid_device_names=[target_device])

    # insert our target device somewhere random in the list
    insert_at = random.randint(0, len(device_list))
    target_state = SUPPORTED_DEVICES[device_type].get_random_state(extra_exposed_attributes=extra_exposed_attributes)
    target_line = format_device_line(
        device_name=target_device,
        friendly_name=friendly_name,
        state=target_state
    )
    device_list.insert(insert_at, target_line)

    # gather a list of all available services
    available_services = []
    for present_type in set(device_types + [device_type]):
        available_services += SUPPORTED_DEVICES[present_type].get_all_services(extra_exposed_attributes)

    # the response is lower-cased before the friendly name goes in, so the name keeps its capitals
    response_template = get_random_response(
        service=action["service_name"],
        persona=persona,
        question_template="",
        short=False
    ).lower()
    response = response_template.replace("<device_name>", friendly_name)

    return {
        "states": device_list,
        "available_services": list(available_services),
        "question": question.lower(),
        "answers": [ response ],
        "service_calls": [ { "service": service_name, "target_device": target_device } ]
    }
def generate_templated_example(template: dict, persona: str, max_devices: int = 32):
    """Build one example from a phrase template with placeholders.

    ``template["device_type"]`` and ``template["service"]`` may hold several
    "|"-separated entries; one device is picked per type. Placeholders such
    as ``<brightness>`` or ``<temp_f>`` are filled with random values, which
    are also added to every service call as arguments.

    Returns a dict with the keys ``states``, ``available_services``,
    ``question``, ``answers`` and ``service_calls``.
    """
    template_device_types: list[str] = template["device_type"].split("|")
    service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ]
    question_template: str = template["phrase"]

    # choose a random device for this template
    chosen_devices = []
    for device_type in template_device_types:
        picked = random.choice(stacks_of_device_names[device_type])
        # work on a copy so the shared entry in stacks_of_device_names is not modified
        chosen_devices.append({ **picked, "type": device_type })

    device_list, device_types, extra_exposed_attributes = random_device_list(
        max_devices=max_devices, avoid_device_names=[d["device_name"] for d in chosen_devices])

    # make sure every attribute the template mentions is exposed; this depends
    # only on the template, so it is done once before placing the devices
    required_attributes = (
        (("<brightness>",), "brightness"),
        (("<color>",), "rgb_color"),
        (("<temp_f>", "<temp_c>"), "temperature"),
        (("<humidity>",), "humidity"),
        (("<fan_mode>",), "fan_mode"),
        (("<duration>",), "duration"),
    )
    for placeholders, attribute in required_attributes:
        if attribute in extra_exposed_attributes:
            continue
        if any(placeholder in question_template for placeholder in placeholders):
            extra_exposed_attributes.append(attribute)

    # insert our target device somewhere random in the list
    for device_dict in chosen_devices:
        index = random.randint(0, len(device_list))
        state = SUPPORTED_DEVICES[device_dict["type"]].get_random_state(extra_exposed_attributes=extra_exposed_attributes)
        device_list.insert(index, format_device_line(
            device_name=device_dict["device_name"],
            friendly_name=device_dict["description"],
            state=state
        ))

    # gather a list of all available services with arguments
    available_services = []
    for x in set(device_types + template_device_types):
        available_services.extend(SUPPORTED_DEVICES[x].get_all_services(extra_exposed_attributes))

    # pick an appropriate response and generate the question
    if len(template_device_types) == 1:
        answer_template = get_random_response(
            service=service_names[0],
            persona=persona,
            question_template=question_template,
            short=False
        )

        question = question_template.replace("<device_name>", chosen_devices[0]["description"])
        answer = answer_template.replace("<device_name>", chosen_devices[0]["description"])
    else:
        # multi-device templates use numbered placeholders and one short answer per device
        question = question_template
        answers = []
        for i, device_dict in enumerate(chosen_devices):
            question = question.replace(f"<device_name{(i + 1)}>", device_dict["description"])
            short_answer = get_random_response(
                service=service_names[i],
                persona=persona,
                question_template=question_template,
                short=True
            )
            answers.append(short_answer.replace("<device_name>", device_dict["description"]))

        answer = f" {and_word} ".join(answers)

    # generate the list of service calls and answers
    service_calls = [
        { "service": service, "target_device": device_dict["device_name"] }
        for device_dict, service in zip(chosen_devices, service_names)
    ]

    if any("climate" in service for service in service_names):
        climate_device_type = SUPPORTED_DEVICES["climate"]
        if "<hvac_mode>" in question:
            hvac_mode = climate_device_type.get_random_parameter("hvac_mode")
            question = question.replace("<hvac_mode>", hvac_mode)
            answer = answer.replace("<hvac_mode>", hvac_mode)
            service_calls = [ { **call, "hvac_mode": hvac_mode} for call in service_calls ]

        if "<fan_mode>" in question:
            fan_mode = climate_device_type.get_random_parameter("fan_mode")
            question = question.replace("<fan_mode>", fan_mode)
            answer = answer.replace("<fan_mode>", fan_mode)
            service_calls = [ { **call, "fan_mode": fan_mode} for call in service_calls ]

        if "<temp_f>" in question:
            temp_f = climate_device_type.get_random_parameter("temp_f")
            question = question.replace("<temp_f>", str(temp_f))
            answer = answer.replace("<temp_f>", str(temp_f))
            service_calls = [ { **call, "temperature": temp_f} for call in service_calls ]

        if "<temp_c>" in question:
            temp_c = climate_device_type.get_random_parameter("temp_c")
            question = question.replace("<temp_c>", str(temp_c))
            answer = answer.replace("<temp_c>", str(temp_c))
            service_calls = [ { **call, "temperature": temp_c} for call in service_calls ]

        if "<humidity>" in question:
            humidity = climate_device_type.get_random_parameter("humidity")
            question = question.replace("<humidity>", str(humidity))
            answer = answer.replace("<humidity>", str(humidity))
            service_calls = [ { **call, "humidity": humidity} for call in service_calls ]

    if any("light" in service for service in service_names):
        light_device_type = SUPPORTED_DEVICES["light"]
        if "<brightness>" in question:
            brightness = light_device_type.get_random_parameter("brightness")
            question = question.replace("<brightness>", str(brightness))
            answer = answer.replace("<brightness>", str(brightness))
            # the text shows a 0-100 value; the call argument is the 0-1 fraction
            service_calls = [ { **call, "brightness": round(brightness / 100, 2) } for call in service_calls ]

        if "<color>" in question:
            random_rgb = light_device_type.get_random_parameter("rgb_color")
            random_rgb_name = closest_color(random_rgb)
            # use the named color's exact RGB so the text and the argument agree
            actual_random_rgb = webcolors.name_to_rgb(random_rgb_name)
            actual_random_rgb = (actual_random_rgb.red, actual_random_rgb.green, actual_random_rgb.blue)
            question = question.replace("<color>", str(random_rgb_name))
            answer = answer.replace("<color>", str(random_rgb_name))
            service_calls = [ { **call, "rgb_color": str(actual_random_rgb) } for call in service_calls ]

    if any("timer" in service for service in service_names):
        timer_device_type = SUPPORTED_DEVICES["timer"]
        if "<duration>" in question:
            duration = timer_device_type.get_random_parameter("duration")
            duration_name = pile_of_durations[duration]
            question = question.replace("<duration>", duration_name)
            answer = answer.replace("<duration>", duration_name)
            service_calls = [ { **call, "duration": str(duration) } for call in service_calls ]

    if any("todo" in service for service in service_names):
        todo_device_type = SUPPORTED_DEVICES["todo"]
        if "<todo>" in question:
            todo = todo_device_type.get_random_parameter("todo")
            question = question.replace("<todo>", todo)
            answer = answer.replace("<todo>", todo)
            service_calls = [ { **call, "item": todo } for call in service_calls ]

    return {
        "states": device_list,
        "available_services": list(available_services),
        "question": question.lower(),
        "answers": [ answer.lower() ],
        "service_calls": service_calls
    }
def generate_status_request(template: dict, persona: str, max_devices: int = 32, return_target_device: bool = False):
    """Build one example where the user asks about a device's state.

    The template supplies the device type, the state string to show for the
    target device, the question phrase and the assistant's answer. Any
    placeholders in the state and the answer are filled with the same random
    values so that the two always agree. No service calls are produced.

    Returns the example dict, or ``(example, chosen_device)`` when
    ``return_target_device`` is true.
    """
    device_type: str = template["device_type"]
    state_name: str = template["state"]
    question_template: str = template["phrase"]
    answer_template: str = template["assistant_response"]

    # choose a random device for this template
    chosen_device = random.choice(stacks_of_device_names[device_type])

    # build a random list of devices
    device_list, device_types, extra_exposed_attributes = random_device_list(max_devices=max_devices, avoid_device_names=[ chosen_device["device_name"] ])

    # generate the question
    question = question_template.replace("<device_name>", chosen_device["description"])
    answer = answer_template.replace("<device_name>", chosen_device["description"])

    # insert other templated variables
    if device_type == "climate":
        climate_device_type = SUPPORTED_DEVICES["climate"]
        temp_f = climate_device_type.get_random_parameter("temp_f")
        answer = answer.replace("<temp_f>", str(temp_f))
        state_name = state_name.replace("<temp_f>", str(temp_f))

        temp_c = climate_device_type.get_random_parameter("temp_c")
        answer = answer.replace("<temp_c>", str(temp_c))
        state_name = state_name.replace("<temp_c>", str(temp_c))

        humidity = climate_device_type.get_random_parameter("humidity")
        answer = answer.replace("<humidity>", str(humidity))
        state_name = state_name.replace("<humidity>", str(humidity))

    if device_type == "light":
        light_device_type = SUPPORTED_DEVICES["light"]

        brightness = light_device_type.get_random_parameter("brightness")
        answer = answer.replace("<brightness>", str(brightness))
        state_name = state_name.replace("<brightness>", str(brightness))

        random_rgb = light_device_type.get_random_parameter("rgb_color")
        random_rgb_name = closest_color(random_rgb)
        # the state shows the named color together with that name's exact RGB
        actual_random_rgb = webcolors.name_to_rgb(random_rgb_name)
        actual_random_rgb = (actual_random_rgb.red, actual_random_rgb.green, actual_random_rgb.blue)
        state_name = state_name.replace("<color>", str(random_rgb_name) + " " + str(actual_random_rgb))
        answer = answer.replace("<color>", str(random_rgb_name))

    if device_type == "media_player":
        media_player_device_type = SUPPORTED_DEVICES["media_player"]
        volume = media_player_device_type.get_random_parameter("volume")
        random_media = media_player_device_type.get_random_parameter("media")

        answer = answer.replace("<volume>", str(volume) + "%")
        state_name = state_name.replace("<volume>", str(volume) + "%")

        answer = answer.replace("<media>", random_media)
        state_name = state_name.replace("<media>", random_media)

    if device_type == "timer":
        timer_device_type = SUPPORTED_DEVICES["timer"]
        duration = timer_device_type.get_random_parameter("duration")
        duration_name = pile_of_durations[duration]
        remaining = timer_device_type.get_random_parameter("remaining")

        # the answer uses the spoken duration, the state uses the raw duration key
        answer = answer.replace("<duration>", duration_name)
        state_name = state_name.replace("<duration>", duration)

        answer = answer.replace("<remaining>", remaining)
        state_name = state_name.replace("<remaining>", remaining)

    # insert our target device somewhere random in the list
    index = random.randint(0, len(device_list))
    device_list.insert(index, format_device_line(
        device_name=chosen_device["device_name"],
        friendly_name=chosen_device["description"],
        state=state_name
    ))

    # gather a list of all available services
    available_services = []
    for x in set(device_types + [device_type]):
        available_services.extend(SUPPORTED_DEVICES[x].get_all_services(extra_exposed_attributes))

    result = {
        "states": device_list,
        "available_services": list(available_services),
        "question": question.lower(),
        "answers": [ answer.lower() ],
        "service_calls": []
    }
    if return_target_device:
        return result, chosen_device
    else:
        return result
def generate_dpo_wrong_argument(template: dict, persona: str, max_devices: int = 32):
    """Generates examples of the model passing incorrect service call arguments

    One service call of a freshly generated templated example is corrupted in
    the rejected copy: either its target device is swapped for another device
    of the same type, or its service is swapped for a different available or
    hallucinated service.

    Returns ``{"accepted": example, "rejected": rejected_example}``.
    """
    example = generate_templated_example(template, persona, max_devices)
    rejected_example = copy.deepcopy(example)

    call_idx = random.randint(0, len(example["service_calls"]) - 1)
    call = example["service_calls"][call_idx]
    target_device = call["target_device"]
    target_device_type = target_device.split(".")[0]

    # state lines of same-typed devices; the device name is the text before the first space
    potential_devices = [ x for x in example["states"] if x.split(".")[0] == target_device_type ]
    random_device = random.choice(potential_devices).split(" ")[0]

    # also offer similarly named devices that are not in the context at all
    for device in stacks_of_device_names[target_device_type]:
        similarity_ratio = SequenceMatcher(None, target_device, device["device_name"]).ratio()
        if similarity_ratio > 0.7 and device["device_name"] != target_device:
            potential_devices.append(device["device_name"])

    # only re-draw when some candidate differs from the target; the target can
    # be listed more than once, and re-drawing would then never finish
    has_alternative = any(x.split(" ")[0] != target_device for x in potential_devices)
    if has_alternative:
        while random_device == target_device:
            random_device = random.choice(potential_devices).split(" ")[0]
    else:
        random_device = None

    # random service should probably be "related"
    # Drop the "(args)" signature by splitting on "(" so services that list
    # arguments are not cut off mid-name. Sorted so the seeded choice below
    # does not depend on the set ordering used when the services were gathered.
    available_services = sorted(
        x.split("(")[0] for x in example["available_services"] if call["service"] not in x
    )
    hallucinated_services = [
        x["hallucinated_service"]
        for x in pile_of_hallucinated_service_names
        if x["real_service"] == call["service"].split(".")[1]
    ]
    service_pool = available_services + hallucinated_services
    random_service = random.choice(service_pool) if service_pool else None

    # TODO: based on the service, add arguments that might be there like rgb, temperature, etc
    update_choices = []
    if random_device:
        update_choices.append({ "target_device": random_device })
    if random_service:
        update_choices.append({ "service": random_service })

    update_dict = random.choice(update_choices)

    # TODO: need to replace the response text with what the incorrect response would have been
    rejected_example["service_calls"][call_idx].update(update_dict)

    return { "accepted": example, "rejected": rejected_example }
def generate_dpo_no_service_call(template: dict, persona: str, max_devices: int = 32):
    """Generates examples of the model saying 'i'll do that for you' and generating no service calls

    The rejected copy is identical to the accepted example except that its
    list of service calls is empty.
    """
    accepted = generate_templated_example(template, persona, max_devices)

    rejected = copy.deepcopy(accepted)
    rejected.update(service_calls=[])

    return { "accepted": accepted, "rejected": rejected }
def generate_dpo_extra_service_call(template: dict, persona: str, max_devices: int = 32):
    """Generates examples of the model adding random service calls to the end of status requests

    The accepted example is a plain status request with no service calls; the
    rejected copy adds one call on the queried device using a random service
    of that device's type.

    Raises NoServicesAvailableException when no available service belongs to
    the queried device's type.
    """
    example, target_device = generate_status_request(template, persona, max_devices, return_target_device=True)
    rejected_example = copy.deepcopy(example)

    device_name = target_device["device_name"]
    device_type = device_name.split(".")[0]

    random_device_services = [ x for x in example["available_services"] if x.split(".")[0] == device_type ]
    if len(random_device_services) == 0:
        raise NoServicesAvailableException()

    # available services are "<type>.<service>(<args>)" signatures; a service
    # call carries only the bare "<type>.<service>" name
    extra_service = random.choice(random_device_services).split("(")[0]
    rejected_example["service_calls"] = [{ "service": extra_service, "target_device": device_name }]

    return { "accepted": example, "rejected": rejected_example }
def generate_dpo_incorrect_persona(template: dict, persona: str, max_devices: int = 32):
    """Placeholder: not implemented; ignores its arguments and returns None."""
    return None
def format_example_raw_chatml(example, persona):
    """Don't use this one anymore

    Renders the example as a single ChatML string (system, user and assistant
    turns) and returns it as ``{"text": ...}``.
    """
    system_lines = [
        "<|im_start|>system",
        pile_of_system_prompts[persona],
        "Services: " + ", ".join(sorted(example["available_services"])),
        "Devices:\n" + "\n".join(example["states"]),
    ]
    system_block = "\n".join(system_lines) + "<|im_end|>"

    user_block = "<|im_start|>user" + "\n" + example["question"] + "<|im_end|>"

    assistant_block = "<|im_start|>assistant\n" + " ".join(example["answers"])
    if len(example["service_calls"]) > 0:
        # one JSON object per line inside a fenced homeassistant block
        call_lines = "\n".join(json.dumps(call) for call in example["service_calls"])
        assistant_block += "\n```homeassistant\n" + call_lines + "\n```"
    assistant_block += "<|im_end|>"

    result = "\n".join([system_block, user_block, assistant_block])
    if "<device_name" in result:
        print("bad templating")

    # replace aliases with their actual values
    for alias in ("blinds.", "garage_door."):
        result = result.replace(alias, "cover.")

    return { "text": result }
def format_example_sharegpt(example, persona):
    """Render the example as a three-turn conversation (system, user, assistant).

    The system turn holds the persona prompt, a random current date and time,
    the available services and the device states. Returns
    ``{"conversations": [...]}`` where each turn has "from" and "value".
    """
    def as_cover(text):
        # replace aliases with their actual values
        return text.replace("blinds.", "cover.").replace("garage_door.", "cover.")

    sys_prompt = pile_of_system_prompts[persona]
    time_block = "The current time and date is " + generate_random_date().strftime("%I:%M %p on %A %B %d, %Y")
    services_block = "Services: " + ", ".join(sorted(example["available_services"]))
    states_block = "Devices:\n" + "\n".join(example["states"])
    question = example["question"]

    assistant_block = " ".join(example["answers"])
    if len(example["service_calls"]) > 0:
        call_lines = "\n".join(json.dumps(call) for call in example["service_calls"])
        assistant_block = assistant_block + "\n```homeassistant\n" + call_lines + "\n```"

    system_value = "\n".join([ sys_prompt, time_block, as_cover(services_block), as_cover(states_block) ])

    return {
        "conversations": [
            { "from": "system", "value": system_value },
            { "from": "user", "value": question },
            { "from": "assistant", "value": as_cover(assistant_block) },
        ]
    }
def format_example_dpo(example, persona):
    """Render an accepted/rejected pair as a preference record.

    The system text and question come from the accepted example. Returns a
    dict with the keys ``system``, ``question``, ``chosen`` and ``rejected``.
    """
    def as_cover(text):
        # replace aliases with their actual values
        return text.replace("blinds.", "cover.").replace("garage_door.", "cover.")

    def render_assistant(sample):
        # answers, followed by the service calls as one JSON object per line
        text = " ".join(sample["answers"])
        if len(sample["service_calls"]) > 0:
            call_lines = "\n".join(json.dumps(call) for call in sample["service_calls"])
            text = text + "\n```homeassistant\n" + call_lines + "\n```"
        return text

    accepted = example["accepted"]
    rejected = example["rejected"]

    sys_prompt = pile_of_system_prompts[persona]
    services_block = "Services: " + ", ".join(sorted(accepted["available_services"]))
    states_block = "Devices:\n" + "\n".join(accepted["states"])
    question = accepted["question"]

    chosen_text = render_assistant(accepted)
    rejected_text = render_assistant(rejected)

    return {
        "system": "\n".join([ sys_prompt, as_cover(services_block), as_cover(states_block) ]),
        "question": question,
        "chosen": as_cover(chosen_text),
        "rejected": as_cover(rejected_text),
    }
def generate_sft_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, static_factor: int, template_factor: int, status_request_factor: int):
    """Generate supervised examples and write them to ``<filename>.jsonl``.

    Both random generators are seeded with ``seed``. For every persona, each
    static action and each templated action is generated according to its
    factor; status requests are generated once overall with the "assistant"
    persona. A factor of 1 or more is a repeat count; a factor below 1 is the
    probability of generating a single example. Filters that found no
    response are collected and printed instead of stopping the run.
    """
    random.seed(seed)
    np.random.seed(seed)

    print("Generating...")

    def run_factor_times(func, examples, data, persona, factor):
        if factor >= 1:
            for _ in range(factor):
                examples.append(format_func(func(data, persona), persona))
        elif random.random() < factor:
            examples.append(format_func(func(data, persona), persona))

    generated_examples = []
    missing_responses = set()

    for person in personas:
        for action in tqdm(pile_of_specific_actions):
            try:
                run_factor_times(generate_static_example, generated_examples, action, person, static_factor)
            except NoResponseAvailableException as missing:
                missing_responses.add(str(missing))

        for templated_action in tqdm(pile_of_templated_actions):
            try:
                run_factor_times(generate_templated_example, generated_examples, templated_action, person, template_factor)
            except NoResponseAvailableException as missing:
                missing_responses.add(str(missing))

    for status_request in tqdm(pile_of_status_requests):
        run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor)

    print(f"Generated {len(generated_examples)} examples. Saving...")

    for message in sorted(missing_responses):
        print(message)

    # one JSON document per line
    with open(f"{filename}.jsonl", "w") as f:
        for item in generated_examples:
            f.write(json.dumps(item) + '\n')

    print("Done!")
def generate_dpo_file(filename: str, seed: int, format_func: Callable, personas: list[str], *, wrong_argument_factor: int, no_argument_factor: int, extra_service_call_factor: int, incorrect_persona_factor: int):
    """Build the DPO (chosen/rejected) examples and save them as ``{filename}.jsonl``.

    Both ``random`` and ``numpy.random`` are seeded with ``seed`` first.

    For every persona, each row of ``pile_of_templated_actions`` is passed to
    ``generate_dpo_wrong_argument`` and ``generate_dpo_no_service_call``.
    Afterwards each row of ``pile_of_status_requests`` is passed to
    ``generate_dpo_extra_service_call`` with the fixed persona ``"assistant"``.

    Each ``*_factor`` is interpreted as follows: a value of at least 1 is the
    number of examples generated per row; a value below 1 is the probability
    that a single example is generated for the row.

    ``incorrect_persona_factor`` is accepted for interface compatibility but is
    currently unused: the call that would consume it is disabled below.

    ``format_func`` is called as ``format_func(example, persona)`` and its
    result is serialized, one JSON document per line.
    """
    random.seed(seed)
    np.random.seed(seed)

    print("Generating...")

    def run_factor_times(func, examples, data, persona, factor):
        if factor >= 1:
            for _ in range(factor):
                examples.append(format_func(func(data, persona), persona))
        elif random.random() < factor:
            examples.append(format_func(func(data, persona), persona))

    generated_examples = []
    missing_responses = set()
    # Number of status requests dropped because the generator raised
    # NoServicesAvailableException; reported once at the end instead of
    # being swallowed silently.
    skipped_status_requests = 0

    for person in personas:
        for templated_action in tqdm(pile_of_templated_actions):
            try:
                run_factor_times(generate_dpo_wrong_argument, generated_examples, templated_action, person, wrong_argument_factor)
                run_factor_times(generate_dpo_no_service_call, generated_examples, templated_action, person, no_argument_factor)
                # run_factor_times(generate_dpo_incorrect_persona, generated_examples, templated_action, person, incorrect_persona_factor)
            except NoResponseAvailableException as ex:
                missing_responses.add(str(ex))

    for status_request in tqdm(pile_of_status_requests):
        try:
            run_factor_times(generate_dpo_extra_service_call, generated_examples, status_request, "assistant", extra_service_call_factor)
        except NoServicesAvailableException:
            # Best effort: this request cannot be turned into a DPO example,
            # so skip it but keep count for the summary below.
            skipped_status_requests += 1

    print(f"Generated {len(generated_examples)} DPO examples. Saving...")

    for missing in sorted(missing_responses):
        print(missing)

    if skipped_status_requests:
        print(f"Skipped {skipped_status_requests} status request(s): NoServicesAvailableException")

    with open(f"{filename}.jsonl", "w") as f:
        for item in generated_examples:
            f.write(json.dumps(item) + '\n')

    print("Done!")
def format_alpaca(example, format_func: Callable):
    """Wrap an Alpaca-style record in a randomly generated device context.

    Args:
        example: Mapping with an ``"instruction"`` and an ``"output"`` entry,
            and optionally an ``"input"`` entry holding extra context.
        format_func: Callable invoked as ``format_func(example=...)`` with a
            dict containing ``states``, ``available_services``, ``question``,
            ``answers`` and ``service_calls``.

    Returns:
        ``{"text": <result of format_func>}``.
    """
    question = example["instruction"]
    if "input" in example and example["input"]:
        # Append the optional context on its own line. The previous code
        # assigned "\n" + input here and thereby dropped the instruction.
        question = question + "\n" + example["input"]

    answer = example["output"]

    device_list, device_types, extra_exposed_attributes = random_device_list(
        max_devices=32, avoid_device_names=[])

    available_services = []
    for device_type in device_types:
        available_services.extend(SUPPORTED_DEVICES[device_type].get_all_services(extra_exposed_attributes))

    # NOTE(review): other callers in this file invoke format_func with a
    # second positional persona argument; confirm the chosen format_func
    # can be called with only `example`.
    text = format_func(example={
        "states": device_list,
        "available_services": list(available_services),
        "question": question,
        "answers": [answer],
        # General-purpose instruction data never triggers a service call.
        "service_calls": [],
    })

    return {"text": text}
def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset_column_names, format_func):
    """Mix an external instruction dataset into the generated Home Assistant data.

    Args:
        dataset_name: Name passed to ``load_dataset``; its ``"train"`` split is
            divided 90/10 into train and test portions.
        seed: Seed for ``random``, ``numpy.random`` and the train/test split.
        output_name: Used in the output file names
            ``home_assistant_{output_name}_merged_train.jsonl`` and
            ``home_assistant_{output_name}_merged_test.jsonl``.
        format_function: Per-record mapper, called as
            ``format_function(example, format_func=format_func)``.
        dataset_column_names: Original columns removed after mapping.
        format_func: Example formatter forwarded to ``format_function``.

    Reads ``home_assistant_train.jsonl`` and ``home_assistant_test.jsonl`` from
    the current directory; they must already have been generated.
    """
    # Seed the split so the same records land in train/test on every run.
    alpaca_dataset = load_dataset(dataset_name)["train"].train_test_split(test_size=0.1, seed=seed)
    home_assistant_dataset = load_dataset("json", data_files={ "train": "home_assistant_train.jsonl", "test": "home_assistant_test.jsonl" })

    random.seed(seed)
    np.random.seed(seed)

    # Forward format_func explicitly: map() on its own only passes the record,
    # which left the mapper without its required formatter argument.
    alpaca_dataset = alpaca_dataset.map(
        format_function,
        fn_kwargs={"format_func": format_func},
    ).remove_columns(dataset_column_names)

    combined_dataset_train = concatenate_datasets([home_assistant_dataset["train"], alpaca_dataset["train"]]).shuffle(seed=42)
    combined_dataset_test = concatenate_datasets([home_assistant_dataset["test"], alpaca_dataset["test"]]).shuffle(seed=42)

    combined_dataset_train.to_json(f"home_assistant_{output_name}_merged_train.jsonl")
    combined_dataset_test.to_json(f"home_assistant_{output_name}_merged_test.jsonl")
def merge_languages(filename_prefix: str, languages: list):
    """Concatenate the per-language JSONL files into ``{filename_prefix}.jsonl``.

    For each entry of ``languages`` the file
    ``{filename_prefix}_{language}.jsonl`` is read in full; the lines are then
    written, in the order the languages were given, to a single output file.
    """
    merged_lines = []
    for lang_name in languages:
        per_language_path = f"{filename_prefix}_{lang_name}.jsonl"
        with open(per_language_path) as source:
            merged_lines += source.readlines()

    with open(f"{filename_prefix}.jsonl", "w") as target:
        target.writelines(merged_lines)
def load_dataset_piles(language):
    """Load every data pile for ``language`` into the module-level globals.

    Files are read from ``piles/{language}/``, except the media names and the
    hallucinated service names, which are always read from ``piles/english/``.
    ``and_word`` is only updated for the languages listed below; for any other
    value it keeps whatever it held before the call.

    Raises:
        Exception: if a templated action row has a ``multiplier`` that cannot
            be converted to ``int``.
    """
    global pile_of_durations, pile_of_media_names, pile_of_todo_items, stacks_of_device_names, \
        pile_of_templated_actions, pile_of_specific_actions, pile_of_responses, pile_of_status_requests, \
        pile_of_system_prompts, pile_of_hallucinated_service_names, and_word

    def read_csv_rows(path):
        # All rows of a CSV file as dicts keyed by the header line.
        with open(path, encoding="utf8") as handle:
            return list(csv.DictReader(handle))

    def read_stripped_lines(path):
        # All lines of a text file with surrounding whitespace removed.
        with open(path, encoding="utf8") as handle:
            return [line.strip() for line in handle]

    # TODO: make this properly dynamic
    conjunction_by_language = {
        "english": "and",
        "german": "und",
        "french": "et",
        "spanish": "y",
        "polish": "i",
    }
    if language in conjunction_by_language:
        and_word = conjunction_by_language[language]

    pile_of_durations = {
        row["duration"]: row["name"]
        for row in read_csv_rows(f"piles/{language}/pile_of_durations.csv")
    }

    # media names are not translated
    pile_of_media_names = read_stripped_lines("piles/english/pile_of_media_names.txt")

    pile_of_todo_items = read_stripped_lines(f"piles/{language}/pile_of_todo_items.txt")

    # Group the device rows by the part of "device_name" before the first dot.
    stacks_of_device_names = {device_type: [] for device_type in SUPPORTED_DEVICES.keys()}
    for device_row in read_csv_rows(f"piles/{language}/pile_of_device_names.csv"):
        try:
            domain = device_row["device_name"].split(".")[0]
            stacks_of_device_names[domain].append(device_row)
        except KeyError as ex:
            # Rows that do not map onto a known stack are reported and skipped.
            print(ex)

    # Each templated action is repeated "multiplier" times in the final pile.
    pile_of_templated_actions = read_csv_rows(f"piles/{language}/pile_of_templated_actions.csv")
    expanded_actions = []
    for action_row in pile_of_templated_actions:
        try:
            repeat_count = int(action_row["multiplier"])
        except Exception:
            raise Exception(f"line has a bad multiplier: {action_row}")
        expanded_actions.extend([action_row] * repeat_count)
    pile_of_templated_actions = expanded_actions

    pile_of_specific_actions = read_csv_rows(f"piles/{language}/pile_of_specific_actions.csv")

    pile_of_responses = pandas.read_csv(f"piles/{language}/pile_of_responses.csv")
    pile_of_responses["contains_vars"] = pile_of_responses["response"].apply(get_included_vars)

    pile_of_status_requests = read_csv_rows(f"piles/{language}/pile_of_status_requests.csv")

    pile_of_system_prompts = {
        row["persona"]: row["prompt"]
        for row in read_csv_rows(f"piles/{language}/pile_of_system_prompts.csv")
    }

    # service names are not translated
    pile_of_hallucinated_service_names = read_csv_rows("piles/english/pile_of_hallucinated_service_names.csv")
# TODO: add examples for ambiguous requests. asking a clarifying question
# TODO: support rejection when asking to do a service that isn't exposed
# TODO: make more randomized names for devices (random words or people's names)
# TODO: answer questions about more than one thing in the state list at once
# TODO: add examples for rooms/groups of devices. i.e. "turn off all the lights in the kitchen"
# TODO: add time, weather, and calendar/reminders (next 3 events?)
def main():
    """Command-line entry point: parse the flags and generate the requested datasets.

    For every requested language the piles are loaded and the sample, train
    and test SFT files are generated as selected. When more than one language
    is requested, the per-language files are written with a ``_{language}``
    suffix and then merged into a single file per dataset. The DPO file and
    the optional merge with an external dataset run once, after the language
    loop, using the piles and personas of the last language processed.

    Raises:
        Exception: if ``--train`` is given without a recognized size flag.
    """
    parser = argparse.ArgumentParser(description="Generate the full dataset from the CSV piles")
    # The --sample and --test help texts previously described the train dataset.
    parser.add_argument("--sample", action="store_true", help="Set this flag to enable generation of the sample dataset.")
    parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the test dataset.")
    parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
    parser.add_argument("--dpo", action="store_true", help="Set this flag to enable generation of the DPO dataset.")
    parser.add_argument("--merge", help="Set this flag to merge the generated datasets with the specified dataset.")
    parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate: english, german, french, spanish, polish")

    train_size_group = parser.add_mutually_exclusive_group()
    train_size_group.add_argument('--small', action='store_const', const='small', dest='size')
    train_size_group.add_argument('--medium', action='store_const', const='medium', dest='size')
    train_size_group.add_argument('--large', action='store_const', const='large', dest='size')
    train_size_group.add_argument('--xl', action='store_const', const='xl', dest='size')

    dataset_format_group = parser.add_mutually_exclusive_group()
    dataset_format_group.add_argument('--raw_corpus', action='store_const', const='raw', dest='format')
    dataset_format_group.add_argument('--sharegpt', action='store_const', const='sharegpt', dest='format')

    args = parser.parse_args()

    if not args.sample and not args.train and not args.test and not args.merge and not args.dpo:
        parser.print_usage()
        exit(-1)

    if args.size and not args.train:
        print("Train size was provided but not generating the training set!")
        exit(-1)

    # The raw corpus format is the default when no format flag is given.
    if not args.format or args.format == "raw":
        format_func = format_example_raw_chatml
    elif args.format == "sharegpt":
        format_func = format_example_sharegpt

    # Generation factors for each training set size.
    train_factors_by_size = {
        "small": {"static_factor": 1, "template_factor": 10, "status_request_factor": 8},
        "medium": {"static_factor": 5, "template_factor": 15, "status_request_factor": 12},
        "large": {"static_factor": 5, "template_factor": 20, "status_request_factor": 15},
        "xl": {"static_factor": 7, "template_factor": 25, "status_request_factor": 18},
    }

    for language in args.language:
        load_dataset_piles(language)
        personas = list(pile_of_system_prompts.keys())
        # Only tag the file names when several languages must be kept apart.
        suffix = f"_{language}" if len(args.language) > 1 else ""

        if args.sample:
            generate_sft_file(f"sample{suffix}", 42, format_func, personas, static_factor=1, template_factor=1, status_request_factor=1)
        if args.train:
            if args.size not in train_factors_by_size:
                raise Exception(f"Unrecognized dataset size: {args.size}")
            generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, personas, **train_factors_by_size[args.size])
        if args.test:
            generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, personas, static_factor=0.25, template_factor=1, status_request_factor=2)

    if len(args.language) > 1:
        if args.sample:
            merge_languages("sample", args.language)
        if args.train:
            merge_languages("home_assistant_train", args.language)
        if args.test:
            merge_languages("home_assistant_test", args.language)

    if args.dpo:
        # Uses the piles and personas left over from the last language above.
        generate_dpo_file("home_assistant_dpo", 42, format_example_dpo, personas, wrong_argument_factor=1, no_argument_factor=1, extra_service_call_factor=1, incorrect_persona_factor=1)

    if args.merge == "alpaca":
        merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca, ["input", "output", "instruction"], format_func)
    elif args.merge == "wizardlm70k":
        merge_with_dataset("WizardLM/WizardLM_evol_instruct_70k", 42, "wizardlm70k", format_alpaca, ["output", "instruction"], format_func)


if __name__ == "__main__":
    main()