start re-writing dataset generation to use the new HA Assist API

This commit is contained in:
Alex O'Connell
2025-11-26 19:08:46 -05:00
parent 00cc5a4b57
commit a16523f9e5

View File

@@ -15,7 +15,7 @@ from typing import Final, Any, Callable, Optional
from tqdm import tqdm
import webcolors
# #### STATES ####
# STATES
STATE_ON: Final = "on"
STATE_OFF: Final = "off"
STATE_ACTIVE: Final = "active"
@@ -41,6 +41,121 @@ STATE_CLEANING: Final = "cleaning"
STATE_DOCKED: Final = "docked"
STATE_RETURNING: Final = "returning"
# TOOLS
TOOL_TURN_ON = "HassTurnOn"
TOOL_TURN_OFF = "HassTurnOff"
TOOL_SET_POSITION = "HassSetPosition"
TOOL_LIGHT_SET = "HassLightSet"
TOOL_SET_VOLUME = "HassSetVolume" # not implemented yet
TOOL_MEDIA_UNPAUSE = "HassMediaUnpause"
TOOL_MEDIA_PAUSE = "HassMediaPause"
TOOL_MEDIA_NEXT = "HassMediaNext" # not implemented yet
TOOL_VACUUM_START = "HassVacuumStart"
TOOL_VACUUM_RETURN_TO_BASE = "HassVacuumReturnToBase"
TOOL_LIST_ADD_ITEM = "HassListAddItem"
TOOL_START_TIMER = "HassStartTimer" # ignored if timers unsupported
TOOL_CANCEL_TIMER = "HassCancelTimer" # ignored if timers unsupported
TOOL_PAUSE_TIMER = "HassPauseTimer" # ignored if timers unsupported
TOOL_UNPAUSE_TIMER = "HassUnpauseTimer" # ignored if timers unsupported
TOOL_SET_HUMIDITY = "HassHumidifierSetpoint"
TOOL_SET_HUMIDIFIER_MODE = "HassHumidifierMode"
# TOOLS
HASS_TOOLS = [
{
"name": TOOL_TURN_ON,
"description": "Turns on/opens a device or entity",
"parameters": {
"properties": { "name": "string", "area": "string", "floor": "string" },
"required": []
}
},
{
"name": TOOL_TURN_OFF,
"description": "Turns off/closes a device or entity",
"parameters": {
"properties": { "name": "string", "area": "string", "floor": "string" },
"required": []
}
},
{
"name": TOOL_SET_POSITION,
"description": "Sets the position of a device or entity",
"parameters": {
"properties": { "name": "string", "area": "string", "floor": "string", "position": "integer" },
"required": [ "position"
]
}
},
{
"name": TOOL_LIST_ADD_ITEM,
"description": "Add item to a todo list",
"parameters": {
"properties": { "item": "string", "name": "string" },
"required": []
}
},
{
"name": TOOL_SET_HUMIDITY,
"description": "Set desired humidity level",
"parameters": {
"properties": { "name": "string", "humidity": "integer" },
"required": [ "name", "humidity"
]
}
},
{
"name": TOOL_SET_HUMIDIFIER_MODE,
"description": "Set humidifier mode",
"parameters": {
"properties": { "name": "string", "mode": "string" },
"required": [ "name", "mode"
]
}
},
{
"name": TOOL_LIGHT_SET,
"description": "Sets the brightness or color of a light",
"parameters": {
"properties": { "name": "string", "area": "string", "floor": "string", "color": "string", "temperature": "integer", "brightness": "integer" },
"required": []
}
},
{
"name": TOOL_MEDIA_UNPAUSE,
"description": "Resumes a media player",
"parameters": {
"properties": { "name": "string", "area": "string", "floor": "string" },
"required": []
}
},
{
"name": TOOL_MEDIA_PAUSE,
"description": "Pauses a media player",
"parameters": {
"properties": { "name": "string", "area": "string", "floor": "string" },
"required": []
}
},
{
"name": TOOL_VACUUM_START,
"description": "Starts a vacuum",
"parameters": {
"properties": { "name": "string", "area": "string", "floor": "string" },
"required": []
}
},
{
"name": TOOL_VACUUM_RETURN_TO_BASE,
"description": "Returns a vacuum to base",
"parameters": {
"properties": { "name": "string", "area": "string", "floor": "string" },
"required": []
}
}
]
# define piles for global access
pile_of_durations = None
pile_of_media_names = None
@@ -83,96 +198,83 @@ def get_included_vars(response: str):
return ",".join(sorted(result))
RANDOM_PARAMETER_GENERATORS = {
"rgb_color": lambda: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
"brightness": lambda: random.randint(0, 100),
"fan_mode": lambda: random.choice(["On Low", "On High", "Auto Low", "Auto High", "Off"]),
"temp_f": lambda: random.randint(60, 80),
"temp_c": lambda: random.randint(15, 25),
"humidity": lambda: random.randint(10, 90),
"preset_mode": lambda: random.choice(["home", "eco", "away", "auto"]),
"hvac_mode": lambda: random.choice(["heat", "cool", "heat_cool", "off", "auto", "fan_only"]),
"media": lambda: random.choice(pile_of_media_names),
"volume": lambda: round(random.random(), 2),
"duration": lambda: random.choice(list(pile_of_durations.keys())),
"remaining": lambda: f"{random.randint(0, 3):02}:{random.randint(0, 60)}:{random.randint(0, 60)}",
"todo": lambda: random.choice(pile_of_todo_items),
}
def generate_random_parameter(param_name):
param_generator = RANDOM_PARAMETER_GENERATORS.get(param_name)
if not param_generator:
raise Exception(f"Unknown param to generate random value for {param_name}")
return param_generator()
@dataclass
class DeviceType:
name: str
possible_states: list[(str, float)]
services: dict[str, list]
random_parameter_generator: Optional[dict[str, Callable]] = None
def get_all_services(self, extra_exposed_attributes):
result = []
for service in self.services.keys():
args = set(extra_exposed_attributes).intersection(self.services[service])
result.append(f"{self.name}.{service}({','.join(args)})")
return result
def get_random_parameter(self, param_name):
return self.random_parameter_generator[param_name]()
def get_random_state(self, extra_exposed_attributes=[]):
states = [ x[0] for x in self.possible_states ]
weights = [ x[1] for x in self.possible_states ]
return random.choices(states, weights=weights, k=1)[0]
# TODO: make services into a global "tools" concept since tools are not specific to device types
class LightDeviceType(DeviceType):
def __init__(self):
super().__init__("light",
possible_states=[
(STATE_ON, 0.5),
(STATE_OFF, 0.5)
],
services={
"turn_on": [ "rgb_color", "brightness" ],
"turn_off": [],
"toggle": []
},
random_parameter_generator={
"rgb_color": lambda: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
"brightness": lambda: random.randint(0, 100),
}
]
)
def get_random_state(self, extra_exposed_attributes=[]):
state = super().get_random_state(extra_exposed_attributes=extra_exposed_attributes)
if random.random() < 0.5 and "rgb_color" in extra_exposed_attributes:
random_rgb = self.get_random_parameter("rgb_color")
random_rgb = generate_random_parameter("rgb_color")
state = state + ";" + closest_color(random_rgb) + " " + str(random_rgb)
if random.random() < 0.7 and "brightness" in extra_exposed_attributes:
state = state + ";" + str(self.get_random_parameter("brightness")) + "%"
state = state + ";" + str(generate_random_parameter("brightness")) + "%"
return state
class ClimateDeviceType(DeviceType):
def __init__(self):
super().__init__("climate", [], {
"turn_on": [],
"turn_off": [],
"toggle": [],
"set_temperature": ["temperature"],
"set_humidity": ["humidity"],
"set_fan_mode": ["fan_mode"],
"set_hvac_mode": ["hvac_mode"],
"set_preset_mode": ["preset_mode"]
},
random_parameter_generator={
"fan_mode": lambda: random.choice(["On Low", "On High", "Auto Low", "Auto High", "Off"]),
"temp_f": lambda: random.randint(60, 80),
"temp_c": lambda: random.randint(15, 25),
"humidity": lambda: random.randint(10, 90),
"preset_mode": lambda: random.choice(["home", "eco", "away", "auto"]),
"hvac_mode": lambda: random.choice(["heat", "cool", "heat_cool", "off", "auto", "fan_only"]),
})
super().__init__("climate", [])
def get_random_state(self, extra_exposed_attributes=[]):
"""state;fan_mode;temperature;humidity"""
state = self.get_random_parameter("hvac_mode")
state = generate_random_parameter("hvac_mode")
if "fan_mode" in extra_exposed_attributes:
state = state + ";" + self.get_random_parameter("fan_mode")
state = state + ";" + generate_random_parameter("fan_mode")
if "temperature" in extra_exposed_attributes:
if random.random() > 0.5:
state = state + ";" + str(self.get_random_parameter("temp_f")) + "F"
state = state + ";" + str(generate_random_parameter("temp_f")) + "F"
else:
state = state + ";" + str(self.get_random_parameter("temp_c")) + "C"
state = state + ";" + str(generate_random_parameter("temp_c")) + "C"
if "humidity" in extra_exposed_attributes:
state = state + ";" + str(self.get_random_parameter("humidity")) + "%"
state = state + ";" + str(generate_random_parameter("humidity")) + "%"
if random.random() < 0.8 and "preset_mode" in extra_exposed_attributes:
# if it is not "on a preset" then don't add the mode
state = state + ";" + self.get_random_parameter("preset_mode")
state = state + ";" + generate_random_parameter("preset_mode")
return state
@@ -186,33 +288,16 @@ class MediaPlayerDeviceType(DeviceType):
(STATE_PAUSED, 0.05),
(STATE_STANDBY, 0.05),
(STATE_BUFFERING, 0.01),
], {
"turn_on": [],
"turn_off": [],
"toggle": [],
"volume_up": [],
"volume_down": [],
"volume_mute": [],
"media_play_pause": [],
"media_play": [],
"media_pause": [],
"media_stop": [],
"media_next_track": [],
"media_previous_track": []
},
random_parameter_generator={
"media": lambda: random.choice(pile_of_media_names),
"volume": lambda: round(random.random(), 2),
})
])
def get_random_state(self, extra_exposed_attributes=[]):
state = super().get_random_state(extra_exposed_attributes=extra_exposed_attributes)
if "media_title" in extra_exposed_attributes and state in [STATE_PLAYING, STATE_PAUSED, STATE_BUFFERING, STATE_ON]:
state = state + ";" + self.get_random_parameter("media")
state = state + ";" + generate_random_parameter("media")
if "volume_level" in extra_exposed_attributes and state != STATE_OFF:
state = state + ";vol=" + str(self.get_random_parameter("volume"))
state = state + ";vol=" + str(generate_random_parameter("volume"))
return state
SUPPORTED_DEVICES = {
@@ -308,26 +393,11 @@ SUPPORTED_DEVICES = {
(STATE_IDLE, 0.2),
(STATE_ACTIVE, 0.6),
(STATE_PAUSED, 0.1),
],
services={
"start": ["duration"],
"pause": [],
"cancel": [],
},
random_parameter_generator={
"duration": lambda: random.choice(list(pile_of_durations.keys())),
"remaining": lambda: f"{random.randint(0, 3):02}:{random.randint(0, 60)}:{random.randint(0, 60)}"
}
]
),
"todo": DeviceType(
name="todo",
possible_states=[ (f"{i}", (1/32)) for i in range(32) ],
services={
"add_item": ["item"],
},
random_parameter_generator={
"todo": lambda: random.choice(pile_of_todo_items),
}
),
}
@@ -500,7 +570,7 @@ def generate_static_example(action: dict, persona: str, max_devices: int = 32):
return {
"states": device_list,
"available_services": list(available_services),
"available_tools": list(available_services),
"question": question.lower(),
"answers": [ response ],
"service_calls": [ { "service": service_name, "target_device": target_device } ]
@@ -660,7 +730,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
return {
"states": device_list,
"available_services": list(available_services),
"available_tools": list(available_services),
"question": question.lower(),
"answers": [ sentence.lower() for sentence in answer ],
"service_calls": service_calls
@@ -749,7 +819,7 @@ def generate_status_request(template: dict, persona: str, max_devices: int = 32,
result = {
"states": device_list,
"available_services": list(available_services),
"available_tools": list(available_services),
"question": question.lower(),
"answers": [ answer.lower() ],
"service_calls": []
@@ -786,7 +856,7 @@ def generate_dpo_wrong_argument(template: dict, persona: str, max_devices: int =
random_device = None
# random service should probably be "related"
available_services = [ x[:-2] for x in example["available_services"] if call["service"] not in x ]
available_services = [ x[:-2] for x in example["available_tools"] if call["service"] not in x ]
hallucinated_services = [ x["hallucinated_service"] for x in pile_of_hallucinated_service_names if x["real_service"] == call["service"].split(".")[1]]
random_service = random.choice(available_services + hallucinated_services)
random_argument = None # based on the service, add arguments that might be there like rgb, temperature, etc
@@ -823,7 +893,7 @@ def generate_dpo_extra_service_call(template: dict, persona: str, max_devices: i
device_name = target_device["device_name"]
device_type = device_name.split(".")[0]
random_device_services = [ x for x in example["available_services"] if x.split(".")[0] == device_type ]
random_device_services = [ x for x in example["available_tools"] if x.split(".")[0] == device_type ]
if len(random_device_services) == 0:
raise NoServicesAvailableException()
@@ -835,45 +905,12 @@ def generate_dpo_extra_service_call(template: dict, persona: str, max_devices: i
def generate_dpo_incorrect_persona(template: dict, persona: str, max_devices: int = 32):
pass
def format_example_raw_chatml(example, persona, language, use_system_role):
"""Don't use this one anymore"""
sys_prompt = pile_of_system_prompts[persona]
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"]))
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
question = example["question"]
answers = " ".join(example["answers"])
if use_system_role:
system_block = "\n".join([ "<|im_start|>system", sys_prompt, services_block, states_block ]) + "<|im_end|>"
user_block = "\n".join([ "<|im_start|>user", question]) + "<|im_end|>"
else:
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
system_block = ""
user_block = "\n".join([ "<|im_start|>user", sys_prompt, services_block, states_block, user_instruction_words, question]) + "<|im_end|>"
assistant_block = "<|im_start|>assistant\n" + answers
if len(example["service_calls"]) > 0:
json_calls = [ json.dumps(x) for x in example["service_calls"] ]
code_block = "\n```homeassistant\n" + "\n".join(json_calls) + "\n```"
assistant_block = assistant_block + code_block
assistant_block = assistant_block + "<|im_end|>"
example_lines = [system_block, user_block, assistant_block]
result = "\n".join(example_lines)
if "<device_name" in result:
print("bad templating")
# replace aliases with their actual values
result = result.replace("blinds.", "cover.")
result = result.replace("garage_door.", "cover.")
return { "text": result }
def format_example_sharegpt(example, persona, language, use_system_role):
sys_prompt = pile_of_system_prompts[persona]
random_datetime = generate_random_datetime()
translate_datetime = babel.dates.format_datetime(random_datetime, BABEL_FORMAT[language], locale=BABEL_LOCALE[language])
time_block = f"{CURRENT_DATE_PROMPT[language]} {translate_datetime}"
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"]))
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_tools"]))
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
question = example["question"]
answers = " ".join(example["answers"])
@@ -887,18 +924,18 @@ def format_example_sharegpt(example, persona, language, use_system_role):
# replace aliases with their actual values
assistant_block = assistant_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
states_block = states_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
services_block = services_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
tools_block = tools_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
if use_system_role:
conversation = [
{ "from": "system", "value": "\n".join([ sys_prompt, time_block, services_block, states_block ])},
{ "from": "system", "value": "\n".join([ sys_prompt, time_block, tools_block, states_block ])},
{ "from": "user", "value": question },
{ "from": "assistant", "value": assistant_block },
]
else:
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
conversation = [
{ "from": "user", "value": "\n".join([ sys_prompt, time_block, services_block, states_block, user_instruction_words, question ]) },
{ "from": "user", "value": "\n".join([ sys_prompt, time_block, tools_block, states_block, user_instruction_words, question ]) },
{ "from": "assistant", "value": assistant_block },
]
@@ -909,7 +946,7 @@ def format_example_dpo(example, persona, language):
example = example["accepted"]
sys_prompt = pile_of_system_prompts[persona]
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"]))
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_tools"]))
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
question = example["question"]
@@ -1045,7 +1082,7 @@ def format_alpaca(example, format_func: Callable):
text = format_func(example={
"states": device_list,
"available_services": list(available_services),
"available_tools": list(available_services),
"question": question,
"answers": [ answer ],
"service_calls": []
@@ -1181,10 +1218,7 @@ def main(args=None):
print("Train size was provided but not generating the training set!")
exit(-1)
if not args.format or args.format == "raw":
format_func = format_example_raw_chatml
elif args.format == "sharegpt":
format_func = format_example_sharegpt
format_func = format_example_sharegpt
use_system_role = not args.no_system_role