mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 21:28:05 -05:00
start re-writing dataset generation to use the new HA Assist API
This commit is contained in:
@@ -15,7 +15,7 @@ from typing import Final, Any, Callable, Optional
|
||||
from tqdm import tqdm
|
||||
import webcolors
|
||||
|
||||
# #### STATES ####
|
||||
# STATES
|
||||
STATE_ON: Final = "on"
|
||||
STATE_OFF: Final = "off"
|
||||
STATE_ACTIVE: Final = "active"
|
||||
@@ -41,6 +41,121 @@ STATE_CLEANING: Final = "cleaning"
|
||||
STATE_DOCKED: Final = "docked"
|
||||
STATE_RETURNING: Final = "returning"
|
||||
|
||||
# TOOLS
# Tool (intent) names referenced by HASS_TOOLS below. Annotated Final to
# match the STATE_* constants above: these are never meant to be rebound.
TOOL_TURN_ON: Final = "HassTurnOn"
TOOL_TURN_OFF: Final = "HassTurnOff"
TOOL_SET_POSITION: Final = "HassSetPosition"
TOOL_LIGHT_SET: Final = "HassLightSet"
TOOL_SET_VOLUME: Final = "HassSetVolume"  # not implemented yet
TOOL_MEDIA_UNPAUSE: Final = "HassMediaUnpause"
TOOL_MEDIA_PAUSE: Final = "HassMediaPause"
TOOL_MEDIA_NEXT: Final = "HassMediaNext"  # not implemented yet
TOOL_VACUUM_START: Final = "HassVacuumStart"
TOOL_VACUUM_RETURN_TO_BASE: Final = "HassVacuumReturnToBase"
TOOL_LIST_ADD_ITEM: Final = "HassListAddItem"
TOOL_START_TIMER: Final = "HassStartTimer"  # ignored if timers unsupported
TOOL_CANCEL_TIMER: Final = "HassCancelTimer"  # ignored if timers unsupported
TOOL_PAUSE_TIMER: Final = "HassPauseTimer"  # ignored if timers unsupported
TOOL_UNPAUSE_TIMER: Final = "HassUnpauseTimer"  # ignored if timers unsupported
TOOL_SET_HUMIDITY: Final = "HassHumidifierSetpoint"
TOOL_SET_HUMIDIFIER_MODE: Final = "HassHumidifierMode"
|
||||
|
||||
# TOOL DEFINITIONS
# One entry per tool: its name, a human-readable description, and a
# "parameters" mapping holding the argument names with their type labels
# ("properties") plus the list of argument names that must be supplied
# ("required").
HASS_TOOLS = [
    {
        "name": TOOL_TURN_ON,
        "description": "Turns on/opens a device or entity",
        "parameters": {
            "properties": {
                "name": "string",
                "area": "string",
                "floor": "string",
            },
            "required": [],
        },
    },
    {
        "name": TOOL_TURN_OFF,
        "description": "Turns off/closes a device or entity",
        "parameters": {
            "properties": {
                "name": "string",
                "area": "string",
                "floor": "string",
            },
            "required": [],
        },
    },
    {
        "name": TOOL_SET_POSITION,
        "description": "Sets the position of a device or entity",
        "parameters": {
            "properties": {
                "name": "string",
                "area": "string",
                "floor": "string",
                "position": "integer",
            },
            "required": ["position"],
        },
    },
    {
        "name": TOOL_LIST_ADD_ITEM,
        "description": "Add item to a todo list",
        "parameters": {
            "properties": {
                "item": "string",
                "name": "string",
            },
            # NOTE(review): nothing is marked required here even though an
            # item seems necessary to add one -- confirm this is intended.
            "required": [],
        },
    },
    {
        "name": TOOL_SET_HUMIDITY,
        "description": "Set desired humidity level",
        "parameters": {
            "properties": {
                "name": "string",
                "humidity": "integer",
            },
            "required": ["name", "humidity"],
        },
    },
    {
        "name": TOOL_SET_HUMIDIFIER_MODE,
        "description": "Set humidifier mode",
        "parameters": {
            "properties": {
                "name": "string",
                "mode": "string",
            },
            "required": ["name", "mode"],
        },
    },
    {
        "name": TOOL_LIGHT_SET,
        "description": "Sets the brightness or color of a light",
        "parameters": {
            "properties": {
                "name": "string",
                "area": "string",
                "floor": "string",
                "color": "string",
                "temperature": "integer",
                "brightness": "integer",
            },
            "required": [],
        },
    },
    {
        "name": TOOL_MEDIA_UNPAUSE,
        "description": "Resumes a media player",
        "parameters": {
            "properties": {
                "name": "string",
                "area": "string",
                "floor": "string",
            },
            "required": [],
        },
    },
    {
        "name": TOOL_MEDIA_PAUSE,
        "description": "Pauses a media player",
        "parameters": {
            "properties": {
                "name": "string",
                "area": "string",
                "floor": "string",
            },
            "required": [],
        },
    },
    {
        "name": TOOL_VACUUM_START,
        "description": "Starts a vacuum",
        "parameters": {
            "properties": {
                "name": "string",
                "area": "string",
                "floor": "string",
            },
            "required": [],
        },
    },
    {
        "name": TOOL_VACUUM_RETURN_TO_BASE,
        "description": "Returns a vacuum to base",
        "parameters": {
            "properties": {
                "name": "string",
                "area": "string",
                "floor": "string",
            },
            "required": [],
        },
    },
]
|
||||
|
||||
|
||||
# define piles for global access
|
||||
pile_of_durations = None
|
||||
pile_of_media_names = None
|
||||
@@ -83,96 +198,83 @@ def get_included_vars(response: str):
|
||||
|
||||
return ",".join(sorted(result))
|
||||
|
||||
# Zero-argument factories keyed by parameter name; calling one returns a
# freshly drawn random value for that parameter.
# The "media", "duration" and "todo" entries read the module-level pile_*
# globals when they are called (not when this dict is built), so those piles
# must be populated before these generators are invoked.
RANDOM_PARAMETER_GENERATORS = {
    "rgb_color": lambda: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
    "brightness": lambda: random.randint(0, 100),
    "fan_mode": lambda: random.choice(["On Low", "On High", "Auto Low", "Auto High", "Off"]),
    "temp_f": lambda: random.randint(60, 80),
    "temp_c": lambda: random.randint(15, 25),
    "humidity": lambda: random.randint(10, 90),
    "preset_mode": lambda: random.choice(["home", "eco", "away", "auto"]),
    "hvac_mode": lambda: random.choice(["heat", "cool", "heat_cool", "off", "auto", "fan_only"]),
    "media": lambda: random.choice(pile_of_media_names),
    "volume": lambda: round(random.random(), 2),
    "duration": lambda: random.choice(list(pile_of_durations.keys())),
    # random.randint is inclusive on both ends, so minutes/seconds are drawn
    # from 0..59 (never 60) and zero-padded like the hours field to keep a
    # consistent HH:MM:SS shape.
    "remaining": lambda: f"{random.randint(0, 3):02}:{random.randint(0, 59):02}:{random.randint(0, 59):02}",
    "todo": lambda: random.choice(pile_of_todo_items),
}
|
||||
|
||||
def generate_random_parameter(param_name):
    """Return a random value for the parameter called *param_name*.

    The matching factory is looked up in RANDOM_PARAMETER_GENERATORS and
    invoked with no arguments.

    Raises:
        Exception: if no (truthy) factory is registered under *param_name*.
    """
    factory = RANDOM_PARAMETER_GENERATORS.get(param_name)
    if factory:
        return factory()

    raise Exception(f"Unknown param to generate random value for {param_name}")
|
||||
|
||||
@dataclass
|
||||
class DeviceType:
|
||||
name: str
|
||||
possible_states: list[(str, float)]
|
||||
services: dict[str, list]
|
||||
random_parameter_generator: Optional[dict[str, Callable]] = None
|
||||
|
||||
def get_all_services(self, extra_exposed_attributes):
|
||||
result = []
|
||||
for service in self.services.keys():
|
||||
args = set(extra_exposed_attributes).intersection(self.services[service])
|
||||
result.append(f"{self.name}.{service}({','.join(args)})")
|
||||
return result
|
||||
|
||||
def get_random_parameter(self, param_name):
|
||||
return self.random_parameter_generator[param_name]()
|
||||
|
||||
def get_random_state(self, extra_exposed_attributes=[]):
|
||||
states = [ x[0] for x in self.possible_states ]
|
||||
weights = [ x[1] for x in self.possible_states ]
|
||||
return random.choices(states, weights=weights, k=1)[0]
|
||||
|
||||
|
||||
# TODO: make services into a global "tools" concept since tools are not specific to device types
|
||||
class LightDeviceType(DeviceType):
|
||||
def __init__(self):
|
||||
super().__init__("light",
|
||||
possible_states=[
|
||||
(STATE_ON, 0.5),
|
||||
(STATE_OFF, 0.5)
|
||||
],
|
||||
services={
|
||||
"turn_on": [ "rgb_color", "brightness" ],
|
||||
"turn_off": [],
|
||||
"toggle": []
|
||||
},
|
||||
random_parameter_generator={
|
||||
"rgb_color": lambda: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
|
||||
"brightness": lambda: random.randint(0, 100),
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
def get_random_state(self, extra_exposed_attributes=[]):
|
||||
state = super().get_random_state(extra_exposed_attributes=extra_exposed_attributes)
|
||||
|
||||
if random.random() < 0.5 and "rgb_color" in extra_exposed_attributes:
|
||||
random_rgb = self.get_random_parameter("rgb_color")
|
||||
random_rgb = generate_random_parameter("rgb_color")
|
||||
state = state + ";" + closest_color(random_rgb) + " " + str(random_rgb)
|
||||
|
||||
if random.random() < 0.7 and "brightness" in extra_exposed_attributes:
|
||||
state = state + ";" + str(self.get_random_parameter("brightness")) + "%"
|
||||
state = state + ";" + str(generate_random_parameter("brightness")) + "%"
|
||||
|
||||
return state
|
||||
|
||||
class ClimateDeviceType(DeviceType):
|
||||
def __init__(self):
|
||||
super().__init__("climate", [], {
|
||||
"turn_on": [],
|
||||
"turn_off": [],
|
||||
"toggle": [],
|
||||
"set_temperature": ["temperature"],
|
||||
"set_humidity": ["humidity"],
|
||||
"set_fan_mode": ["fan_mode"],
|
||||
"set_hvac_mode": ["hvac_mode"],
|
||||
"set_preset_mode": ["preset_mode"]
|
||||
},
|
||||
random_parameter_generator={
|
||||
"fan_mode": lambda: random.choice(["On Low", "On High", "Auto Low", "Auto High", "Off"]),
|
||||
"temp_f": lambda: random.randint(60, 80),
|
||||
"temp_c": lambda: random.randint(15, 25),
|
||||
"humidity": lambda: random.randint(10, 90),
|
||||
"preset_mode": lambda: random.choice(["home", "eco", "away", "auto"]),
|
||||
"hvac_mode": lambda: random.choice(["heat", "cool", "heat_cool", "off", "auto", "fan_only"]),
|
||||
})
|
||||
super().__init__("climate", [])
|
||||
|
||||
def get_random_state(self, extra_exposed_attributes=[]):
|
||||
"""state;fan_mode;temperature;humidity"""
|
||||
state = self.get_random_parameter("hvac_mode")
|
||||
state = generate_random_parameter("hvac_mode")
|
||||
|
||||
if "fan_mode" in extra_exposed_attributes:
|
||||
state = state + ";" + self.get_random_parameter("fan_mode")
|
||||
state = state + ";" + generate_random_parameter("fan_mode")
|
||||
if "temperature" in extra_exposed_attributes:
|
||||
if random.random() > 0.5:
|
||||
state = state + ";" + str(self.get_random_parameter("temp_f")) + "F"
|
||||
state = state + ";" + str(generate_random_parameter("temp_f")) + "F"
|
||||
else:
|
||||
state = state + ";" + str(self.get_random_parameter("temp_c")) + "C"
|
||||
state = state + ";" + str(generate_random_parameter("temp_c")) + "C"
|
||||
if "humidity" in extra_exposed_attributes:
|
||||
state = state + ";" + str(self.get_random_parameter("humidity")) + "%"
|
||||
state = state + ";" + str(generate_random_parameter("humidity")) + "%"
|
||||
|
||||
if random.random() < 0.8 and "preset_mode" in extra_exposed_attributes:
|
||||
# if it is not "on a preset" then don't add the mode
|
||||
state = state + ";" + self.get_random_parameter("preset_mode")
|
||||
state = state + ";" + generate_random_parameter("preset_mode")
|
||||
|
||||
return state
|
||||
|
||||
@@ -186,33 +288,16 @@ class MediaPlayerDeviceType(DeviceType):
|
||||
(STATE_PAUSED, 0.05),
|
||||
(STATE_STANDBY, 0.05),
|
||||
(STATE_BUFFERING, 0.01),
|
||||
], {
|
||||
"turn_on": [],
|
||||
"turn_off": [],
|
||||
"toggle": [],
|
||||
"volume_up": [],
|
||||
"volume_down": [],
|
||||
"volume_mute": [],
|
||||
"media_play_pause": [],
|
||||
"media_play": [],
|
||||
"media_pause": [],
|
||||
"media_stop": [],
|
||||
"media_next_track": [],
|
||||
"media_previous_track": []
|
||||
},
|
||||
random_parameter_generator={
|
||||
"media": lambda: random.choice(pile_of_media_names),
|
||||
"volume": lambda: round(random.random(), 2),
|
||||
})
|
||||
])
|
||||
|
||||
def get_random_state(self, extra_exposed_attributes=[]):
|
||||
state = super().get_random_state(extra_exposed_attributes=extra_exposed_attributes)
|
||||
|
||||
if "media_title" in extra_exposed_attributes and state in [STATE_PLAYING, STATE_PAUSED, STATE_BUFFERING, STATE_ON]:
|
||||
state = state + ";" + self.get_random_parameter("media")
|
||||
state = state + ";" + generate_random_parameter("media")
|
||||
|
||||
if "volume_level" in extra_exposed_attributes and state != STATE_OFF:
|
||||
state = state + ";vol=" + str(self.get_random_parameter("volume"))
|
||||
state = state + ";vol=" + str(generate_random_parameter("volume"))
|
||||
return state
|
||||
|
||||
SUPPORTED_DEVICES = {
|
||||
@@ -308,26 +393,11 @@ SUPPORTED_DEVICES = {
|
||||
(STATE_IDLE, 0.2),
|
||||
(STATE_ACTIVE, 0.6),
|
||||
(STATE_PAUSED, 0.1),
|
||||
],
|
||||
services={
|
||||
"start": ["duration"],
|
||||
"pause": [],
|
||||
"cancel": [],
|
||||
},
|
||||
random_parameter_generator={
|
||||
"duration": lambda: random.choice(list(pile_of_durations.keys())),
|
||||
"remaining": lambda: f"{random.randint(0, 3):02}:{random.randint(0, 60)}:{random.randint(0, 60)}"
|
||||
}
|
||||
]
|
||||
),
|
||||
"todo": DeviceType(
|
||||
name="todo",
|
||||
possible_states=[ (f"{i}", (1/32)) for i in range(32) ],
|
||||
services={
|
||||
"add_item": ["item"],
|
||||
},
|
||||
random_parameter_generator={
|
||||
"todo": lambda: random.choice(pile_of_todo_items),
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
@@ -500,7 +570,7 @@ def generate_static_example(action: dict, persona: str, max_devices: int = 32):
|
||||
|
||||
return {
|
||||
"states": device_list,
|
||||
"available_services": list(available_services),
|
||||
"available_tools": list(available_services),
|
||||
"question": question.lower(),
|
||||
"answers": [ response ],
|
||||
"service_calls": [ { "service": service_name, "target_device": target_device } ]
|
||||
@@ -660,7 +730,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
|
||||
|
||||
return {
|
||||
"states": device_list,
|
||||
"available_services": list(available_services),
|
||||
"available_tools": list(available_services),
|
||||
"question": question.lower(),
|
||||
"answers": [ sentence.lower() for sentence in answer ],
|
||||
"service_calls": service_calls
|
||||
@@ -749,7 +819,7 @@ def generate_status_request(template: dict, persona: str, max_devices: int = 32,
|
||||
|
||||
result = {
|
||||
"states": device_list,
|
||||
"available_services": list(available_services),
|
||||
"available_tools": list(available_services),
|
||||
"question": question.lower(),
|
||||
"answers": [ answer.lower() ],
|
||||
"service_calls": []
|
||||
@@ -786,7 +856,7 @@ def generate_dpo_wrong_argument(template: dict, persona: str, max_devices: int =
|
||||
random_device = None
|
||||
|
||||
# random service should probably be "related"
|
||||
available_services = [ x[:-2] for x in example["available_services"] if call["service"] not in x ]
|
||||
available_services = [ x[:-2] for x in example["available_tools"] if call["service"] not in x ]
|
||||
hallucinated_services = [ x["hallucinated_service"] for x in pile_of_hallucinated_service_names if x["real_service"] == call["service"].split(".")[1]]
|
||||
random_service = random.choice(available_services + hallucinated_services)
|
||||
random_argument = None # based on the service, add arguments that might be there like rgb, temperature, etc
|
||||
@@ -823,7 +893,7 @@ def generate_dpo_extra_service_call(template: dict, persona: str, max_devices: i
|
||||
|
||||
device_name = target_device["device_name"]
|
||||
device_type = device_name.split(".")[0]
|
||||
random_device_services = [ x for x in example["available_services"] if x.split(".")[0] == device_type ]
|
||||
random_device_services = [ x for x in example["available_tools"] if x.split(".")[0] == device_type ]
|
||||
|
||||
if len(random_device_services) == 0:
|
||||
raise NoServicesAvailableException()
|
||||
@@ -835,45 +905,12 @@ def generate_dpo_extra_service_call(template: dict, persona: str, max_devices: i
|
||||
def generate_dpo_incorrect_persona(template: dict, persona: str, max_devices: int = 32):
|
||||
pass
|
||||
|
||||
def format_example_raw_chatml(example, persona, language, use_system_role):
|
||||
"""Don't use this one anymore"""
|
||||
sys_prompt = pile_of_system_prompts[persona]
|
||||
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"]))
|
||||
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
|
||||
question = example["question"]
|
||||
answers = " ".join(example["answers"])
|
||||
|
||||
if use_system_role:
|
||||
system_block = "\n".join([ "<|im_start|>system", sys_prompt, services_block, states_block ]) + "<|im_end|>"
|
||||
user_block = "\n".join([ "<|im_start|>user", question]) + "<|im_end|>"
|
||||
else:
|
||||
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
|
||||
system_block = ""
|
||||
user_block = "\n".join([ "<|im_start|>user", sys_prompt, services_block, states_block, user_instruction_words, question]) + "<|im_end|>"
|
||||
|
||||
assistant_block = "<|im_start|>assistant\n" + answers
|
||||
if len(example["service_calls"]) > 0:
|
||||
json_calls = [ json.dumps(x) for x in example["service_calls"] ]
|
||||
code_block = "\n```homeassistant\n" + "\n".join(json_calls) + "\n```"
|
||||
assistant_block = assistant_block + code_block
|
||||
assistant_block = assistant_block + "<|im_end|>"
|
||||
|
||||
example_lines = [system_block, user_block, assistant_block]
|
||||
result = "\n".join(example_lines)
|
||||
if "<device_name" in result:
|
||||
print("bad templating")
|
||||
|
||||
# replace aliases with their actual values
|
||||
result = result.replace("blinds.", "cover.")
|
||||
result = result.replace("garage_door.", "cover.")
|
||||
return { "text": result }
|
||||
|
||||
def format_example_sharegpt(example, persona, language, use_system_role):
|
||||
sys_prompt = pile_of_system_prompts[persona]
|
||||
random_datetime = generate_random_datetime()
|
||||
translate_datetime = babel.dates.format_datetime(random_datetime, BABEL_FORMAT[language], locale=BABEL_LOCALE[language])
|
||||
time_block = f"{CURRENT_DATE_PROMPT[language]} {translate_datetime}"
|
||||
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"]))
|
||||
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_tools"]))
|
||||
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
|
||||
question = example["question"]
|
||||
answers = " ".join(example["answers"])
|
||||
@@ -887,18 +924,18 @@ def format_example_sharegpt(example, persona, language, use_system_role):
|
||||
# replace aliases with their actual values
|
||||
assistant_block = assistant_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
||||
states_block = states_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
||||
services_block = services_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
||||
tools_block = tools_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
||||
|
||||
if use_system_role:
|
||||
conversation = [
|
||||
{ "from": "system", "value": "\n".join([ sys_prompt, time_block, services_block, states_block ])},
|
||||
{ "from": "system", "value": "\n".join([ sys_prompt, time_block, tools_block, states_block ])},
|
||||
{ "from": "user", "value": question },
|
||||
{ "from": "assistant", "value": assistant_block },
|
||||
]
|
||||
else:
|
||||
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
|
||||
conversation = [
|
||||
{ "from": "user", "value": "\n".join([ sys_prompt, time_block, services_block, states_block, user_instruction_words, question ]) },
|
||||
{ "from": "user", "value": "\n".join([ sys_prompt, time_block, tools_block, states_block, user_instruction_words, question ]) },
|
||||
{ "from": "assistant", "value": assistant_block },
|
||||
]
|
||||
|
||||
@@ -909,7 +946,7 @@ def format_example_dpo(example, persona, language):
|
||||
example = example["accepted"]
|
||||
|
||||
sys_prompt = pile_of_system_prompts[persona]
|
||||
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"]))
|
||||
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_tools"]))
|
||||
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
|
||||
question = example["question"]
|
||||
|
||||
@@ -1045,7 +1082,7 @@ def format_alpaca(example, format_func: Callable):
|
||||
|
||||
text = format_func(example={
|
||||
"states": device_list,
|
||||
"available_services": list(available_services),
|
||||
"available_tools": list(available_services),
|
||||
"question": question,
|
||||
"answers": [ answer ],
|
||||
"service_calls": []
|
||||
@@ -1181,10 +1218,7 @@ def main(args=None):
|
||||
print("Train size was provided but not generating the training set!")
|
||||
exit(-1)
|
||||
|
||||
if not args.format or args.format == "raw":
|
||||
format_func = format_example_raw_chatml
|
||||
elif args.format == "sharegpt":
|
||||
format_func = format_example_sharegpt
|
||||
format_func = format_example_sharegpt
|
||||
|
||||
use_system_role = not args.no_system_role
|
||||
|
||||
|
||||
Reference in New Issue
Block a user