refactor dataset generation code + add new synthesis script
.gitignore (vendored, 4 lines changed)
@@ -3,8 +3,8 @@ loras/
core/
config/
.DS_Store
-data/*.json
-data/*.jsonl
+data/**/*.json
+data/**/*.jsonl
*.pyc
main.log
.venv
@@ -20,7 +20,7 @@ This dataset contains a list of requests and responses for a user interacting wi

This dataset is NOT distributed as a static file, but as a Python script. This is due to the multitude of different formats that are used in the LLM fine-tuning ecosystem. The goal is to be able to support using this dataset to fine-tune any desired model, and to support that, you need to be able to generate the dataset in the exact format that matches the model you want to fine-tune.

-The dataset is generated from the different CSV "piles". The "piles" contain different chunks of requests that are assembled into a final context that is presented to the LLM. For example, `piles/<language>/pile_of_device_names.csv` contains only names of various devices to be used as part of context as well as inserted into `piles/<language>/pile_of_templated_actions.csv` and `piles/<language>/pile_of_status_requests.csv`. The logic for assembling the final dataset from the piles is contained in [generate_home_assistant_data.py](./generate_home_assistant_data.py).
+The dataset is generated from the different CSV "piles". The "piles" contain different chunks of requests that are assembled into a final context that is presented to the LLM. For example, `piles/<language>/pile_of_device_names.csv` contains only names of various devices to be used as part of context as well as inserted into `piles/<language>/pile_of_templated_actions.csv` and `piles/<language>/pile_of_status_requests.csv`. The logic for assembling the final dataset from the piles is contained in [generate_data.py](./generate_data.py).
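
For orientation, here is a rough sketch (not part of the repository) of how one row of `pile_of_device_names.csv` becomes a context line. It assumes the two-column `device_name,description` layout that the new `data/device_types.py` and `data/generate_data.py` below read from the piles, and a made-up `light.kitchen` entry:

```python
import csv
import random

# Illustrative sketch only: turn one pile row into the
# "entity_id 'Friendly Name' = state" context line built by
# format_device_line() in data/generate_data.py below.
with open("piles/english/pile_of_device_names.csv", newline="", encoding="utf-8") as f:
    row = random.choice(list(csv.DictReader(f)))
    # e.g. {"device_name": "light.kitchen", "description": "Kitchen Light"} (made-up row)

state = random.choice(["on", "off"])  # the DeviceType classes pick weighted random states
print(f"{row['device_name']} '{row['description']}' = {state}")
# -> light.kitchen 'Kitchen Light' = on
```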

## Prepare environment

@@ -36,16 +36,15 @@ pip3 install pandas==2.2.2 datasets==2.20.0 webcolors==1.13 babel==2.15.0

## Generating the dataset from piles

-`python3 generate_home_assistant_data.py --train --test --large --sharegpt`
+`python3 generate_data.py --train --test --large --sharegpt`

Supported dataset splits are `--test`, `--train`, & `--sample`
Arguments to set the train dataset size are `--small`, `--medium`, `--large`, & `--xl`.
Supported formats are `--raw_corpus` (chatml formatted) & `--sharegpt`
Languages can be enabled using `--language english german french spanish polish`
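
For reference, a hand-written sketch of the record shape that `generate_data.py` (added below in this commit) writes to `output/*.jsonl` in ShareGPT style; the persona text, device names, and tool names are made up here, since the real values come from the piles and the `tools` module:

```python
# Illustrative record only; see format_example_sharegpt() in data/generate_data.py below
example_record = {
    "conversations": [
        {"role": "system", "content": "You are a helpful smart home assistant.\n"
                                      "The current time and date is ...\n"
                                      "Devices:\nlight.kitchen 'Kitchen Light' = off"},
        {"role": "user", "content": "turn on the kitchen light"},
        {"role": "assistant", "content": [
            {"type": "text", "text": "turning on the kitchen light for you."},
            {"type": "tool_use", "name": "HassTurnOn", "parameters": {"name": "Kitchen Light"}},
        ]},
    ],
    "tools": [],  # filled from HASS_TOOLS or SERVICE_TOOLS in the script below
}
```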

## Merging with other instruct-datasets for training

-`python3 generate_home_assistant_data.py --merge <dataset>`
+`python3 generate_data.py --merge <dataset>`

Supported datasets right now are:
- `alpaca`

@@ -70,7 +69,7 @@ Generating the full dataset using the python script will print out a warning for

## Adding new Home Assistant functionality
TODO
-<!-- In order to add new home assistant device types, you will need to add data to a handful of piles, as well as make small modifications to the `generate_home_assistant_data.py` script.
+<!-- In order to add new home assistant device types, you will need to add data to a handful of piles, as well as make small modifications to the `generate_data.py` script.
1. Add 15-30 new device names with the new type to the `pile_of_device_names.csv`. This should be an entity_id and a 'friendly name'
2. Add
-->
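
Until that section is written, the code-side half of the change can be inferred from `data/device_types.py` and the `SUPPORTED_DEVICES` table in `data/generate_data.py` added below: subclass `DeviceType` and register it. A hypothetical sketch (the `valve` type, its states, and its tool list are made up for illustration):

```python
# Hypothetical sketch only: names here are not part of the repository
class ValveDeviceType(DeviceType):
    def __init__(self):
        super().__init__("valve", [
            (STATE_OPEN, 0.5),
            (STATE_CLOSED, 0.5),
        ])

    def get_all_tools(self, extra_exposed_attributes: List[str]):
        return [TOOL_TURN_ON, TOOL_TURN_OFF, TOOL_TOGGLE]

# ...then register it in generate_data.py:
# SUPPORTED_DEVICES["valve"] = ValveDeviceType()
```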
data/device_types.py (new file, 224 lines)
@@ -0,0 +1,224 @@
import random
from dataclasses import dataclass
from typing import Final, Callable, List

from tools import *
from utils import closest_color, generate_random_parameter, get_dataset_piles

# STATES
STATE_ON: Final = "on"
STATE_OFF: Final = "off"
STATE_ACTIVE: Final = "active"
STATE_UNKNOWN: Final = "unknown"
STATE_OPEN: Final = "open"
STATE_OPENING: Final = "opening"
STATE_CLOSED: Final = "closed"
STATE_CLOSING: Final = "closing"
STATE_BUFFERING: Final = "buffering"
STATE_PLAYING: Final = "playing"
STATE_PAUSED: Final = "paused"
STATE_IDLE: Final = "idle"
STATE_STANDBY: Final = "standby"
STATE_LOCKED: Final = "locked"
STATE_UNLOCKED: Final = "unlocked"
STATE_LOCKING: Final = "locking"
STATE_UNLOCKING: Final = "unlocking"
STATE_JAMMED: Final = "jammed"
STATE_UNAVAILABLE: Final = "unavailable"
STATE_OK: Final = "ok"
STATE_PROBLEM: Final = "problem"
STATE_CLEANING: Final = "cleaning"
STATE_DOCKED: Final = "docked"
STATE_RETURNING: Final = "returning"


@dataclass
class DeviceType:
    name: str
    possible_states: list[tuple[str, float]]

    def get_random_state(self, language: str, extra_exposed_attributes=[]):
        states = [ x[0] for x in self.possible_states ]
        weights = [ x[1] for x in self.possible_states ]
        return random.choices(states, weights=weights, k=1)[0]

    def get_all_tools(self, extra_exposed_attributes: List[str]):
        """Return list of tool names available for this device type."""
        tools = [TOOL_TURN_ON, TOOL_TURN_OFF, TOOL_TOGGLE]
        return tools

    def get_random_parameter(self, param_name: str, language: str):
        """Generate a random parameter value."""
        return generate_random_parameter(param_name, get_dataset_piles(language))


class LightDeviceType(DeviceType):
    def __init__(self):
        super().__init__("light",
            possible_states=[
                (STATE_ON, 0.5),
                (STATE_OFF, 0.5)
            ]
        )

    def get_random_state(self, language: str, extra_exposed_attributes=[]):
        state = super().get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)

        if random.random() < 0.5 and "rgb_color" in extra_exposed_attributes:
            random_rgb = generate_random_parameter("rgb_color", get_dataset_piles(language))
            state = state + ";" + closest_color(random_rgb) + " " + str(random_rgb)

        if random.random() < 0.7 and "brightness" in extra_exposed_attributes:
            state = state + ";" + str(generate_random_parameter("brightness", get_dataset_piles(language))) + "%"
        return state

    def get_all_tools(self, extra_exposed_attributes: List[str]):
        """Return list of tool names available for lights."""
        tools = [TOOL_TURN_ON, TOOL_TURN_OFF, TOOL_TOGGLE]
        if "brightness" in extra_exposed_attributes or "rgb_color" in extra_exposed_attributes:
            tools.append(TOOL_LIGHT_SET)
        return tools


class SwitchDeviceType(DeviceType):
    def __init__(self):
        super().__init__("switch", [
            (STATE_ON, 0.5),
            (STATE_OFF, 0.5)
        ])


class FanDeviceType(DeviceType):
    def __init__(self):
        super().__init__("fan", [
            (STATE_ON, 0.5),
            (STATE_OFF, 0.5)
        ])


class GarageDoorDeviceType(DeviceType):
    def __init__(self):
        super().__init__("garage_door", [
            (STATE_OPEN, 0.49),
            (STATE_CLOSED, 0.49),
            (STATE_OPENING, 0.01),
            (STATE_CLOSING, 0.01)
        ])

    def get_all_tools(self, extra_exposed_attributes: List[str]):
        tools = [TOOL_TURN_ON, TOOL_TURN_OFF, TOOL_TOGGLE]
        if "position" in extra_exposed_attributes:
            tools.append(TOOL_SET_POSITION)
        return tools


class BlindsDeviceType(DeviceType):
    def __init__(self):
        super().__init__("blinds", [
            (STATE_OPEN, 0.49),
            (STATE_CLOSED, 0.49),
            (STATE_OPENING, 0.01),
            (STATE_CLOSING, 0.01)
        ])

    def get_all_tools(self, extra_exposed_attributes: List[str]):
        tools = [TOOL_TURN_ON, TOOL_TURN_OFF, TOOL_TOGGLE]
        if "position" in extra_exposed_attributes:
            tools.append(TOOL_SET_POSITION)
        return tools


class LockDeviceType(DeviceType):
    def __init__(self):
        super().__init__("lock", [
            (STATE_LOCKED, 0.5),
            (STATE_UNLOCKED, 0.5),
        ])


class VacuumDeviceType(DeviceType):
    def __init__(self):
        super().__init__("vacuum", [
            (STATE_CLEANING, 0.2),
            (STATE_DOCKED, 0.6),
            (STATE_RETURNING, 0.1),
            (STATE_IDLE, 0.05),
            (STATE_PAUSED, 0.05),
        ])

    def get_all_tools(self, extra_exposed_attributes: List[str]):
        return [TOOL_VACUUM_START, TOOL_VACUUM_RETURN_TO_BASE]


class TimerDeviceType(DeviceType):
    def __init__(self):
        super().__init__("timer", [
            (STATE_IDLE, 0.2),
            (STATE_ACTIVE, 0.6),
            (STATE_PAUSED, 0.1),
        ])

    def get_all_tools(self, extra_exposed_attributes: List[str]):
        tools = [TOOL_START_TIMER, TOOL_CANCEL_TIMER, TOOL_PAUSE_TIMER, TOOL_UNPAUSE_TIMER]
        if "duration" in extra_exposed_attributes:
            tools.extend([TOOL_INCREASE_TIMER, TOOL_DECREASE_TIMER, TOOL_TIMER_STATUS])
        return tools


class TodoDeviceType(DeviceType):
    def __init__(self):
        super().__init__("todo", [ (f"{i}", (1/32)) for i in range(32) ],)

    def get_all_tools(self, extra_exposed_attributes: List[str]):
        return [TOOL_LIST_ADD_ITEM]


class ClimateDeviceType(DeviceType):
    def __init__(self):
        super().__init__("climate", [])

    def get_random_state(self, language: str, extra_exposed_attributes=[]):
        """state;fan_mode;temperature;humidity"""
        state = generate_random_parameter("hvac_mode", get_dataset_piles(language))

        if "fan_mode" in extra_exposed_attributes:
            state = state + ";" + generate_random_parameter("fan_mode", get_dataset_piles(language))
        if "temperature" in extra_exposed_attributes:
            if random.random() > 0.5:
                state = state + ";" + str(generate_random_parameter("temp_f", get_dataset_piles(language))) + "F"
            else:
                state = state + ";" + str(generate_random_parameter("temp_c", get_dataset_piles(language))) + "C"
        if "humidity" in extra_exposed_attributes:
            state = state + ";" + str(generate_random_parameter("humidity", get_dataset_piles(language))) + "%"
        if random.random() < 0.8 and "preset_mode" in extra_exposed_attributes:
            # if it is not "on a preset" then don't add the mode
            state = state + ";" + generate_random_parameter("preset_mode", get_dataset_piles(language))

        return state

    def get_all_tools(self, extra_exposed_attributes: List[str]):
        """Return list of tool names available for climate devices."""
        tools = [TOOL_TURN_ON, TOOL_TURN_OFF]
        if "temperature" in extra_exposed_attributes or "fan_mode" in extra_exposed_attributes:
            tools.append(TOOL_CLIMATE_SET_TEMPERATURE)
        if "humidity" in extra_exposed_attributes:
            tools.extend([TOOL_SET_HUMIDITY, TOOL_SET_HUMIDIFIER_MODE])
        return tools


class MediaPlayerDeviceType(DeviceType):
    def __init__(self):
        super().__init__("media_player", [
            (STATE_ON, 0.15),
            (STATE_OFF, 0.54),
            (STATE_IDLE, 0.1),
            (STATE_PLAYING, 0.1),
            (STATE_PAUSED, 0.05),
            (STATE_STANDBY, 0.05),
            (STATE_BUFFERING, 0.01),
        ])

    def get_random_state(self, language: str, extra_exposed_attributes=[]):
        state = super().get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)

        if "media_title" in extra_exposed_attributes and state in [STATE_PLAYING, STATE_PAUSED, STATE_BUFFERING, STATE_ON]:
            state = state + ";" + generate_random_parameter("media", get_dataset_piles(language))
        if "volume_level" in extra_exposed_attributes and state != STATE_OFF:
            state = state + ";vol=" + str(generate_random_parameter("volume", get_dataset_piles(language))) + "%"
        return state

    def get_all_tools(self, extra_exposed_attributes: List[str]):
        """Return list of tool names available for media players."""
        tools = [TOOL_TURN_ON, TOOL_TURN_OFF, TOOL_MEDIA_PAUSE, TOOL_MEDIA_UNPAUSE, TOOL_MEDIA_NEXT]
        if "volume_level" in extra_exposed_attributes:
            tools.append(TOOL_SET_VOLUME)
        return tools
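
A small usage sketch of the classes above; it depends on the `tools` and `utils` modules, which are not part of this diff, so the printed values are illustrative:

```python
# Illustrative only: how generate_data.py (below) consumes a DeviceType
light = LightDeviceType()

# Weighted random state, optionally extended with exposed attributes,
# e.g. "on;salmon (250, 128, 114);82%"
print(light.get_random_state("english", extra_exposed_attributes=["rgb_color", "brightness"]))

# The tool list grows when the relevant attributes are exposed
print(light.get_all_tools([]))              # [TOOL_TURN_ON, TOOL_TURN_OFF, TOOL_TOGGLE]
print(light.get_all_tools(["brightness"]))  # same list plus TOOL_LIGHT_SET
```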
data/generate_data.py (new file, 647 lines)
@@ -0,0 +1,647 @@
import argparse
import json
import numpy as np
import random
from datasets import load_dataset, concatenate_datasets
from difflib import SequenceMatcher
from typing import Callable
from tqdm import tqdm
import webcolors

from device_types import *
from prompting import generate_system_prompt, USER_INSTRUCTION_PROMPT
from utils import get_random_response, generate_random_parameter, closest_color, get_dataset_piles, NoResponseAvailableException

SUPPORTED_DEVICES = {
    "light": LightDeviceType(),
    "switch": SwitchDeviceType(),
    "fan": FanDeviceType(),
    "garage_door": GarageDoorDeviceType(),
    "blinds": BlindsDeviceType(),
    "lock": LockDeviceType(),
    "media_player": MediaPlayerDeviceType(),
    "climate": ClimateDeviceType(),
    "vacuum": VacuumDeviceType(),
    "timer": TimerDeviceType(),
    "todo": TodoDeviceType(),
}

def format_device_line(*, device_name: str, friendly_name: str, state: str):
    return (f"{device_name} '{friendly_name}' = {state}")

# generate a random list of devices for the context
def random_device_list(max_devices: int, avoid_device_names: list[str], language: str = "english"):
    num_devices = random.randint(2, max_devices)
    piles = get_dataset_piles(language)

    local_device_names = { k: v[:] for k,v in piles.stacks_of_device_names.items() }

    avoid_climate = False
    for avoid_device in avoid_device_names:
        avoid_type = avoid_device.split(".")[0]

        filtered_possible_devices = []
        for possible_device in local_device_names[avoid_type]:
            similarity_ratio = SequenceMatcher(None, avoid_device, possible_device["device_name"].split(".")[1]).ratio()

            if similarity_ratio < 0.4:
                filtered_possible_devices.append(possible_device)
        local_device_names[avoid_type] = filtered_possible_devices

        if avoid_type == "climate":
            avoid_climate = True

    possible_choices = []
    for device_type in local_device_names.keys():
        possible_choices.extend(local_device_names[device_type])

    device_types = set()
    device_list = []
    device_lines = []
    # TODO: randomly pick attributes for this list
    extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "duration", "remaining", "item"]

    while len(device_list) < num_devices:
        choice = random.choice(possible_choices)
        if choice["device_name"] in device_list:
            continue

        try:
            device_name = choice["device_name"]
            device_type = device_name.split(".")[0]
            friendly_name = choice["description"]

            # don't add random thermostats. we need to be careful about how we handle multiple thermostats
            if avoid_climate and device_type == "climate":
                continue

            state = SUPPORTED_DEVICES[device_type].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
            device_lines.append(format_device_line(
                device_name=device_name,
                friendly_name=friendly_name,
                state=state
            ))
            device_list.append(device_name)
            device_types.add(device_type)
        except Exception as ex:
            print(f"bad device name: {choice}")
            print(repr(ex))

    return device_lines, list(device_types), list(extra_exposed_attributes)

def generate_static_example(action: dict, persona: str, language: str, max_devices: int = 32, use_service_names: bool = False):
    question = action["phrase"]
    service_name = action["service_name"]
    device_type = service_name.split(".")[0]
    target_device = f"{device_type}.{action['device_name']}"
    friendly_name = target_device.split(".")[1].replace("_", " ").title()
    piles = get_dataset_piles(language)

    device_list, device_types, extra_exposed_attributes = random_device_list(
        max_devices=max_devices, avoid_device_names=[target_device], language=language)

    # insert our target device somewhere random in the list
    index = random.randint(0, len(device_list))
    state = SUPPORTED_DEVICES[device_type].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)

    device_list.insert(index, format_device_line(
        device_name=target_device,
        friendly_name=friendly_name,
        state=state
    ))

    # gather a list of all available tools
    available_tools = []
    for x in set(device_types + [device_type]):
        available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))

    # Remove duplicates while preserving order
    available_tools = list(dict.fromkeys(available_tools))

    # Map service name to tool name
    service_action = service_name.split(".")[1]
    tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)

    response = get_random_response(
        piles.pile_of_responses,
        service=service_name,
        persona=persona,
        question_template="",
        short=False
    ).lower()

    response = response.replace("<device_name>", friendly_name)

    if use_service_names:
        tool_call = {
            "tool_name": tool_name,
            "service_name": service_name,
            "tool_args": {"entity_id": target_device}
        }
    else:
        tool_call = {
            "tool_name": tool_name,
            "service_name": service_name,
            "tool_args": {"name": target_device}
        }

    if "arguments" in action and action["arguments"]:
        try:
            args = json.loads(action["arguments"])
            tool_call["tool_args"].update(args)
        except Exception as e:
            print(f"Failed to parse arguments for {action}: {e}")

    return {
        "states": device_list,
        "available_tools": available_tools,
        "question": question.lower(),
        "answers": [ response ],
        "tool_calls": [ tool_call ]
    }

def replace_answer(list_of_answer, var, value):
    new_list = []
    for answer in list_of_answer:
        new_list.append(answer.replace(var, value))
    return new_list

def generate_templated_example(template: dict, persona: str, language: str, max_devices: int = 32, use_service_names: bool = False):
    template_device_types: list[str] = template["device_type"].split("|")
    service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ]
    question_template: str = template["phrase"]
    piles = get_dataset_piles(language)

    # choose a random device for this template
    chosen_devices = []
    for device_type in template_device_types:
        device_dict = random.choice(piles.stacks_of_device_names[device_type])
        device_dict["type"] = device_type
        chosen_devices.append(device_dict)

    device_list, device_types, extra_exposed_attributes = random_device_list(
        max_devices=max_devices, avoid_device_names=[d["device_name"] for d in chosen_devices], language=language)

    # insert our target device somewhere random in the list
    for device_dict in chosen_devices:
        index = random.randint(0, len(device_list))
        if "<brightness>" in question_template and "brightness" not in extra_exposed_attributes:
            extra_exposed_attributes.append("brightness")
        if "<color>" in question_template and "rgb_color" not in extra_exposed_attributes:
            extra_exposed_attributes.append("rgb_color")
        if ("<temp_f>" in question_template or "<temp_c>" in question_template) \
            and "temperature" not in extra_exposed_attributes:
            extra_exposed_attributes.append("temperature")
        if "<humidity>" in question_template and "humidity" not in extra_exposed_attributes:
            extra_exposed_attributes.append("humidity")
        if "<fan_mode>" in question_template and "fan_mode" not in extra_exposed_attributes:
            extra_exposed_attributes.append("fan_mode")
        if "<duration>" in question_template and "duration" not in extra_exposed_attributes:
            extra_exposed_attributes.append("duration")

        state = SUPPORTED_DEVICES[device_dict["type"]].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
        device_name = device_dict["device_name"]
        friendly_name = device_dict["description"]

        device_list.insert(index, format_device_line(
            device_name=device_name,
            friendly_name=friendly_name,
            state=state
        ))

    # gather a list of all available tools
    available_tools = []
    for x in set(device_types + template_device_types):
        available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))

    # Remove duplicates while preserving order
    available_tools = list(dict.fromkeys(available_tools))

    # pick an appropriate response and generate the question
    if len(template_device_types) == 1:
        answer_template = get_random_response(
            piles.pile_of_responses,
            service=service_names[0],
            persona=persona,
            question_template=question_template,
            short=False
        )

        question = question_template.replace("<device_name>", chosen_devices[0]["description"])
        answer_list = [ answer_template.replace("<device_name>", chosen_devices[0]["description"]) ]
    else:
        question = question_template
        answers = []
        for i in range(len(template_device_types)):
            question = question.replace(f"<device_name{(i + 1)}>", chosen_devices[i]["description"])
            answer_response = get_random_response(
                piles.pile_of_responses,
                service=service_names[i],
                persona=persona,
                question_template=question_template,
                short=True
            )
            answers.append(answer_response.replace("<device_name>", chosen_devices[i]["description"]))

        answer_list = []
        for word in piles.and_words:
            answer_list.append(f" {word} ".join(answers))

    # generate the list of tool calls
    tool_calls = []
    for device_dict, service in zip(chosen_devices, service_names):
        service_action = service.split(".")[1]
        tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
        tool_call = {
            "tool_name": tool_name,
            "service_name": service,
            "tool_args": {"entity_id" if use_service_names else "name": device_dict["device_name"] if use_service_names else device_dict["description"]}
        }
        tool_calls.append(tool_call)

    if any(["climate" in service for service in service_names ]):
        if "<hvac_mode>" in question:
            hvac_mode = generate_random_parameter("hvac_mode", piles)
            question = question.replace("<hvac_mode>", hvac_mode)
            answer_list = replace_answer(answer_list, "<hvac_mode>", hvac_mode)
            # pass hvac_mode along to the climate temperature tool call
            for call in tool_calls:
                if call["tool_name"] == TOOL_CLIMATE_SET_TEMPERATURE:
                    call["tool_args"]["hvac_mode"] = hvac_mode

        if "<fan_mode>" in question:
            fan_mode = generate_random_parameter("fan_mode", piles)
            question = question.replace("<fan_mode>", fan_mode)
            answer_list = replace_answer(answer_list, "<fan_mode>", fan_mode)
            for call in tool_calls:
                if call["tool_name"] == TOOL_CLIMATE_SET_TEMPERATURE:
                    call["tool_args"]["fan_mode"] = fan_mode

        if "<temp_f>" in question:
            temp_f = generate_random_parameter("temp_f", piles)
            question = question.replace("<temp_f>", str(temp_f))
            answer_list = replace_answer(answer_list, "<temp_f>", str(temp_f))
            for call in tool_calls:
                if call["tool_name"] == TOOL_CLIMATE_SET_TEMPERATURE:
                    call["tool_args"]["temperature"] = temp_f

        if "<temp_c>" in question:
            temp_c = generate_random_parameter("temp_c", piles)
            question = question.replace("<temp_c>", str(temp_c))
            answer_list = replace_answer(answer_list, "<temp_c>", str(temp_c))
            for call in tool_calls:
                if call["tool_name"] == TOOL_CLIMATE_SET_TEMPERATURE:
                    call["tool_args"]["temperature"] = temp_c

        if "<humidity>" in question:
            humidity = generate_random_parameter("humidity", piles)
            question = question.replace("<humidity>", str(humidity))
            answer_list = replace_answer(answer_list, "<humidity>", str(humidity))
            for call in tool_calls:
                if call["tool_name"] == TOOL_SET_HUMIDITY:
                    call["tool_args"]["humidity"] = humidity

    if any(["light" in service for service in service_names ]):
        if "<brightness>" in question:
            brightness = generate_random_parameter("brightness", piles)
            question = question.replace("<brightness>", str(brightness))
            answer_list = replace_answer(answer_list, "<brightness>", str(brightness))
            for call in tool_calls:
                if call["tool_name"] == TOOL_LIGHT_SET:
                    call["tool_args"]["brightness"] = brightness

        if "<color>" in question:
            random_rgb = generate_random_parameter("rgb_color", piles)
            random_rgb_name = closest_color(random_rgb)
            question = question.replace("<color>", str(random_rgb_name))
            answer_list = replace_answer(answer_list, "<color>", str(random_rgb_name))
            for call in tool_calls:
                if call["tool_name"] == TOOL_LIGHT_SET:
                    call["tool_args"]["color"] = random_rgb_name

    if any(["timer" in service for service in service_names ]):
        if "<duration>" in question:
            duration = generate_random_parameter("duration", piles)
            duration_name = piles.pile_of_durations[duration]
            question = question.replace("<duration>", duration_name)
            answer_list = replace_answer(answer_list, "<duration>", duration_name)
            for call in tool_calls:
                if call["tool_name"] == TOOL_START_TIMER:
                    call["tool_args"]["duration"] = str(duration)

    if any(["todo" in service for service in service_names ]):
        if "<todo>" in question:
            todo = generate_random_parameter("todo", piles)
            question = question.replace("<todo>", todo)
            answer_list = replace_answer(answer_list, "<todo>", todo)
            for call in tool_calls:
                if call["tool_name"] == TOOL_LIST_ADD_ITEM:
                    call["tool_args"]["item"] = todo

    return {
        "states": device_list,
        "available_tools": available_tools,
        "question": question.lower(),
        "answers": [ sentence.lower() for sentence in answer_list ],
        "tool_calls": tool_calls
    }

def generate_status_request(template: dict, persona: str, language: str, max_devices: int = 32, return_target_device: bool = False, use_service_names: bool = False):
    device_type: str = template["device_type"]
    state_name: str = template["state"]
    question_template: str = template["phrase"]
    answer_template: str = template["assistant_response"]
    piles = get_dataset_piles(language)

    # choose a random device for this template
    chosen_device = random.choice(piles.stacks_of_device_names[device_type])

    # build a random list of devices
    device_list, device_types, extra_exposed_attributes = random_device_list(
        max_devices=max_devices, avoid_device_names=[ chosen_device["device_name"] ], language=language)

    # generate the question
    question = question_template.replace("<device_name>", chosen_device["description"])
    answer = answer_template.replace("<device_name>", chosen_device["description"])

    # insert other templated variables
    if device_type == "climate":
        climate_device_type = SUPPORTED_DEVICES["climate"]
        temp_f = climate_device_type.get_random_parameter("temp_f", language)
        answer = answer.replace("<temp_f>", str(temp_f))
        state_name = state_name.replace("<temp_f>", str(temp_f))

        temp_c = climate_device_type.get_random_parameter("temp_c", language)
        answer = answer.replace("<temp_c>", str(temp_c))
        state_name = state_name.replace("<temp_c>", str(temp_c))

        humidity = climate_device_type.get_random_parameter("humidity", language)
        answer = answer.replace("<humidity>", str(humidity))
        state_name = state_name.replace("<humidity>", str(humidity))

    if device_type == "light":
        light_device_type = SUPPORTED_DEVICES["light"]

        brightness = light_device_type.get_random_parameter("brightness", language)
        answer = answer.replace("<brightness>", str(brightness))
        state_name = state_name.replace("<brightness>", str(brightness))

        random_rgb = light_device_type.get_random_parameter("rgb_color", language)
        random_rgb_name = closest_color(random_rgb)
        actual_random_rgb = webcolors.name_to_rgb(random_rgb_name)
        actual_random_rgb = (actual_random_rgb.red, actual_random_rgb.green, actual_random_rgb.blue)
        state_name = state_name.replace("<color>", str(random_rgb_name) + " " + str(actual_random_rgb))
        answer = answer.replace("<color>", str(random_rgb_name))

    if device_type == "media_player":
        media_player_device_type = SUPPORTED_DEVICES["media_player"]
        volume = media_player_device_type.get_random_parameter("volume", language)
        random_media = media_player_device_type.get_random_parameter("media", language)

        answer = answer.replace("<volume>", str(volume) + "%")
        state_name = state_name.replace("<volume>", str(volume) + "%")

        answer = answer.replace("<media>", random_media)
        state_name = state_name.replace("<media>", random_media)

    if device_type == "timer":
        timer_device_type = SUPPORTED_DEVICES["timer"]
        duration = timer_device_type.get_random_parameter("duration", language)
        duration_name = piles.pile_of_durations[duration]
        remaining = timer_device_type.get_random_parameter("remaining", language)

        answer = answer.replace("<duration>", duration_name)
        state_name = state_name.replace("<duration>", duration)

        answer = answer.replace("<remaining>", remaining)
        state_name = state_name.replace("<remaining>", remaining)

    # insert our target device somewhere random in the list
    index = random.randint(0, len(device_list))
    device_list.insert(index, format_device_line(
        device_name=chosen_device["device_name"],
        friendly_name=chosen_device["description"],
        state=state_name
    ))

    # gather a list of all available tools
    available_tools = []
    for x in set(device_types + [device_type]):
        available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))

    # Remove duplicates while preserving order
    available_tools = list(dict.fromkeys(available_tools))

    result = {
        "states": device_list,
        "available_tools": available_tools,
        "question": question.lower(),
        "answers": [ answer.lower() ],
        "tool_calls": []
    }
    if return_target_device:
        return result, chosen_device
    else:
        return result

def format_example_sharegpt(example, persona, language, use_system_role, use_service_names):
    piles = get_dataset_piles(language)
    sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
    question = example["question"]
    answers = " ".join(example["answers"])

    # Build assistant message with content blocks
    assistant_content = []

    # Add text response
    assistant_content.append({
        "type": "text",
        "text": answers
    })

    # Add tool use blocks if there are tool calls
    if len(example["tool_calls"]) > 0:
        for tool_call in example["tool_calls"]:
            # Use service_name if in service mode, otherwise use tool_name
            call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
            assistant_content.append({
                "type": "tool_use",
                "name": call_name,
                "parameters": tool_call["tool_args"]
            })

    if use_system_role:
        conversation = [
            {
                "role": "system",
                "content": sys_prompt
            },
            {
                "role": "user",
                "content": question
            },
            {
                "role": "assistant",
                "content": assistant_content
            },
        ]
    else:
        user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
        conversation = [
            {
                "role": "user",
                "content": "\n".join([ sys_prompt, user_instruction_words, question ])
            },
            {
                "role": "assistant",
                "content": assistant_content
            },
        ]

    return {
        "conversations": conversation,
        "tools": SERVICE_TOOLS if use_service_names else HASS_TOOLS
    }

def generate_sft_file(filename: str, seed: int, format_func: Callable, use_system_role: bool, use_service_names: bool, personas: list[str], language: str, *, static_factor: float, template_factor: int, status_request_factor: int):
    random.seed(seed)
    np.random.seed(seed)
    piles = get_dataset_piles(language)

    print("Generating...")

    def run_factor_times(func, examples, data, persona, factor, language):
        if factor >= 1:
            for i in range(factor):
                examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names))
        else:
            if random.random() < factor:
                examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names))

    generated_examples = []

    missing_responses = set()

    for person in personas:
        for action in tqdm(piles.pile_of_specific_actions):
            try:
                run_factor_times(generate_static_example, generated_examples, action, person, static_factor, language)
            except NoResponseAvailableException as ex:
                missing_responses.add(str(ex))

        for templated_action in tqdm(piles.pile_of_templated_actions):
            try:
                run_factor_times(generate_templated_example, generated_examples, templated_action, person, template_factor, language)
            except NoResponseAvailableException as ex:
                missing_responses.add(str(ex))

    for status_request in tqdm(piles.pile_of_status_requests):
        run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor, language)

    print(f"Generated {len(generated_examples)} examples. Saving...")

    for missing in sorted(missing_responses):
        print(missing)

    with open(f"output/{filename}.jsonl", "w") as f:
        for item in generated_examples:
            json_record = json.dumps(item)
            f.write(json_record + '\n')

    print("Done!")

def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset_column_names, format_func):
    alpaca_dataset = load_dataset(dataset_name)["train"].train_test_split(test_size=0.1)
    home_assistant_dataset = load_dataset("json", data_files={ "train": "home_assistant_train.jsonl", "test": "home_assistant_test.jsonl" })

    random.seed(seed)
    np.random.seed(seed)

    alpaca_dataset = alpaca_dataset.map(format_function).remove_columns(dataset_column_names)

    combined_dataset_train = concatenate_datasets([home_assistant_dataset["train"], alpaca_dataset["train"]]).shuffle(seed=42)
    combined_dataset_test = concatenate_datasets([home_assistant_dataset["test"], alpaca_dataset["test"]]).shuffle(seed=42)

    combined_dataset_train.to_json(f"home_assistant_{output_name}_merged_train.jsonl")
    combined_dataset_test.to_json(f"home_assistant_{output_name}_merged_test.jsonl")

def merge_languages(filename_prefix: str, languages: list):
    all_examples = []
    for language in languages:
        with open(f"output/{filename_prefix}_{language}.jsonl") as f:
            all_examples.extend(f.readlines())

    with open(f"output/{filename_prefix}.jsonl", "w") as f:
        f.writelines(all_examples)


# TODO: add examples for ambiguous requests. asking a clarifying question
# TODO: support rejection when asking to do a service that isn't exposed
# TODO: make more randomized names for devices (random words or people's names)
# TODO: answer questions about more than one thing in the state list at once
# TODO: add examples for rooms/groups of devices. i.e. "turn off all the lights in the kitchen"
# TODO: add time, weather, and calendar/reminders (next 3 events?)
def main(args=None):
    parser = argparse.ArgumentParser(description="Generate the full dataset from the CSV piles")
    parser.add_argument("--sample", action="store_true", help="Set this flag to enable generation of the sample dataset.")
    parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the test dataset.")
    parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
    parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate: english, german, french, spanish, polish")
    parser.add_argument("--no-system-role", action="store_true", help="Set this flag to disable the system role. It will be combined with the user role")
    parser.add_argument("--merge", help="Name of an instruct dataset to merge with the generated data (e.g. alpaca)")

    train_size_group = parser.add_mutually_exclusive_group()
    train_size_group.add_argument('--small', action='store_const', const='small', dest='size')
    train_size_group.add_argument('--medium', action='store_const', const='medium', dest='size')
    train_size_group.add_argument('--large', action='store_const', const='large', dest='size')
    train_size_group.add_argument('--xl', action='store_const', const='xl', dest='size')

    parser.add_argument('--use-service-names', action='store_true',
                        help='Use service names (e.g., light.turn_on) instead of intent tool names (e.g., HassTurnOn)')

    args = parser.parse_args(args=args)

    if not args.sample and not args.train and not args.test and not args.merge:
        parser.print_usage()
        exit(-1)

    if args.size and not args.train:
        print("Train size was provided but not generating the training set!")
        exit(-1)

    format_func = format_example_sharegpt

    use_system_role = not args.no_system_role
    use_service_names = args.use_service_names

    for language in args.language:
        piles = get_dataset_piles(language)
        personas = list(piles.pile_of_system_prompts.keys())
        suffix = f"_{language}" if len(args.language) > 1 else ""

        if args.sample:
            generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, static_factor=1, template_factor=1, status_request_factor=1)
        if args.train:
            if args.size == "small":
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, static_factor=1, template_factor=10, status_request_factor=8)
            elif args.size == "medium":
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, static_factor=5, template_factor=15, status_request_factor=12)
            elif args.size == "large":
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, static_factor=5, template_factor=20, status_request_factor=15)
            elif args.size == "xl":
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, static_factor=7, template_factor=25, status_request_factor=18)
            else:
                raise Exception(f"Unrecognized dataset size: {args.size}")
        if args.test:
            generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, use_service_names, personas, language, static_factor=0.25, template_factor=1, status_request_factor=2)

    if len(args.language) > 1:
        if args.sample:
            merge_languages("sample", args.language)
        if args.train:
            merge_languages("home_assistant_train", args.language)
        if args.test:
            merge_languages("home_assistant_test", args.language)

if __name__ == "__main__":
    main()
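
Because `main()` takes an optional argument list, the generator can also be driven from Python; a minimal sketch, equivalent to the `--sample` invocation documented in the README changes above:

```python
# Sketch: generate the English sample split programmatically
# (equivalent to `python3 generate_data.py --sample --language english`)
from generate_data import main

main(["--sample", "--language", "english"])  # writes output/sample.jsonl
```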
File diff suppressed because it is too large
data/output/.gitkeep (new file, empty)
data/prompting.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import babel.dates

from utils import generate_random_datetime

CURRENT_DATE_PROMPT = {
    "english": "The current time and date is",
    "polish": "Aktualna godzina i data to",
    "german": "Die aktuelle Uhrzeit und das aktuelle Datum sind",
    "french": "L'heure et la date actuelles sont",
    "spanish": "La hora y fecha actuales son"
}

DEVICES_PROMPT = {
    "english": "Devices",
    "polish": "Urządzenia",
    "german": "Ger\u00e4te",
    "french": "Appareils",
    "spanish": "Dispositivos"
}

SERVICES_PROMPT = {
    "english": "Services",
    "polish": "Usługi",
    "german": "Dienste",
    "french": "Services",
    "spanish": "Servicios"
}

BABEL_LOCALE = {
    "english": "en_US",
    "polish": "pl_PL",
    "german": "de_DE",
    "french": "fr_FR",
    "spanish": "es_ES"
}

BABEL_FORMAT = {
    "english": "h:m a 'on' EEEE, MMMM d yyyy",
    "polish": "H:m 'w' EEEE, d MMMM yyyy",
    "german": "H:m EEEE, d MMMM yyyy",
    "french": "H:m EEEE, d MMMM yyyy",
    "spanish": "H:m EEEE, d 'de' MMMM 'de' yyyy"
}

USER_INSTRUCTION_PROMPT = {
    "english": "User instruction",
    "german": "Benutzeranweisung",
    "french": "Instruction de l'utilisateur ",
    "spanish": "Instrucción del usuario",
    "polish": "Instrukcja użytkownika"
}


def generate_system_prompt(example: dict, persona: str, language: str, pile_of_system_prompts: dict) -> str:
    sys_prompt = pile_of_system_prompts[persona]
    random_datetime = generate_random_datetime()
    translate_datetime = babel.dates.format_datetime(random_datetime, BABEL_FORMAT[language], locale=BABEL_LOCALE[language])
    time_block = f"{CURRENT_DATE_PROMPT[language]} {translate_datetime}"

    states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])

    # replace aliases with their actual values
    states_block = states_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")

    return "\n".join([sys_prompt, time_block, states_block])
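
A quick sketch of what `generate_system_prompt` assembles; the persona string is made up here, the timestamp comes from `generate_random_datetime` in `utils`, and the `blinds.`/`garage_door.` pseudo-domains are rewritten to `cover.`:

```python
# Illustrative only: system prompt assembly for one example
example = {"states": [
    "light.kitchen 'Kitchen Light' = on",
    "blinds.living_room 'Living Room Blinds' = closed",
]}
prompt = generate_system_prompt(example, "assistant", "english",
                                {"assistant": "You are a helpful smart home assistant."})
print(prompt)
# You are a helpful smart home assistant.
# The current time and date is 9:5 AM on Tuesday, May 7 2024
# Devices:
# light.kitchen 'Kitchen Light' = on
# cover.living_room 'Living Room Blinds' = closed
```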
data/synthesize.py (new file, 567 lines)
@@ -0,0 +1,567 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import csv
|
||||
import json
|
||||
import random
|
||||
import aiohttp
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
from utils import get_dataset_piles
|
||||
|
||||
LLM_ENDPOINT = "https://ai.cloud.alexoconnell.net/v1/chat/completions"
|
||||
|
||||
class SyntheticDataGenerator:
|
||||
def __init__(self, model_name: str, language: str, concurrency: int):
|
||||
self.language = language
|
||||
self.concurrency = concurrency
|
||||
self.model_name = model_name
|
||||
self.piles = get_dataset_piles(language)
|
||||
self.synthetic_devices = {} # device_type -> list of {device_name, description}
|
||||
|
||||
async def generate_device_names(self, session, device_type, count=10):
|
||||
"""
|
||||
Generates a list of new device names for a given type.
|
||||
"""
|
||||
system_prompt = "You are a creative assistant that generates realistic smart home device names."
|
||||
user_prompt = f"Generate {count} realistic and diverse friendly names for a smart home device of type '{device_type}' (e.g. 'Kitchen Light', 'Porch Fan', 'Master Bedroom Blinds').\n" \
|
||||
f"Output ONLY the names, one per line. Do not number them. Do not include the device type if it's not part of the natural name."
|
||||
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
"temperature": 1.2,
|
||||
"max_tokens": 200,
|
||||
}
|
||||
|
||||
try:
|
||||
async with session.post(LLM_ENDPOINT, json=payload) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
content = data['choices'][0]['message']['content'].strip()
|
||||
names = [line.strip() for line in content.split('\n') if line.strip()]
|
||||
|
||||
# Add to synthetic devices
|
||||
if device_type not in self.synthetic_devices:
|
||||
self.synthetic_devices[device_type] = []
|
||||
|
||||
new_devices = []
|
||||
for name in names:
|
||||
# Create a fake entity ID
|
||||
slug = name.lower().replace(" ", "_").replace("'", "")
|
||||
entity_id = f"{device_type}.{slug}"
|
||||
device_entry = {
|
||||
"device_name": entity_id,
|
||||
"description": name
|
||||
}
|
||||
self.synthetic_devices[device_type].append(device_entry)
|
||||
new_devices.append(device_entry)
|
||||
|
||||
print(f"Generated {len(names)} new names for {device_type}")
|
||||
return new_devices
|
||||
else:
|
||||
print(f"Failed to generate device names: {response.status}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"Device generation failed: {e}")
|
||||
return []
|
||||
|
||||
async def generate_phrase(self, session, context):
|
||||
"""
|
||||
Generates a user phrase for a given context (device, service, args).
|
||||
"""
|
||||
task_type = context.get("type", "action")
|
||||
device_name = context["device_name"]
|
||||
friendly_name = context["friendly_name"]
|
||||
|
||||
system_prompt = "You are a helpful assistant that generates synthetic training data for a smart home voice assistant. " \
|
||||
"Your goal is to generate diverse, natural, and realistic user commands based on a specific action. " \
|
||||
"The commands should vary in complexity and phrasing."
|
||||
|
||||
if task_type == "action":
|
||||
service_name = context["service_name"]
|
||||
service_args = context["service_data"]
|
||||
|
||||
user_prompt = f"""
|
||||
Task: Generate a natural language voice command in {self.language} that a user would say to perform the following action.
|
||||
|
||||
Target Device: {friendly_name} (ID: {device_name})
|
||||
Action: {service_name}
|
||||
Arguments: {json.dumps(service_args)}
|
||||
|
||||
Instructions:
|
||||
1. The command must be in {self.language}.
|
||||
2. The command should be natural and conversational.
|
||||
3. Do not include the device ID (e.g., {device_name}) in the command, only refer to it by name or context.
|
||||
4. Include the necessary information to imply the arguments (e.g., if brightness is 50%, mention "50%" or "half brightness").
|
||||
5. Provide ONLY the command text. Do not add quotes or explanations.
|
||||
"""
|
||||
elif task_type == "status":
|
||||
attribute = context["attribute"]
|
||||
user_prompt = f"""
|
||||
Task: Generate a natural language question in {self.language} that a user would ask to check the status of a device.
|
||||
|
||||
Target Device: {friendly_name} (ID: {device_name})
|
||||
Attribute to check: {attribute}
|
||||
|
||||
Instructions:
|
||||
1. The question must be in {self.language}.
|
||||
2. The question should be natural and conversational.
|
||||
3. Do not include the device ID.
|
||||
4. Provide ONLY the question text. Do not add quotes or explanations.
|
||||
"""
|
||||
else:
|
||||
# Fallback for unknown task types
|
||||
user_prompt = "Generate a random smart home command."
|
||||
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
"temperature": 1.0, # High temperature for diversity
|
||||
"max_tokens": 60,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
try:
|
||||
async with session.post(LLM_ENDPOINT, json=payload) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
content = data['choices'][0]['message']['content'].strip()
|
||||
# Cleanup: remove leading/trailing quotes if present
|
||||
if content.startswith('"') and content.endswith('"'):
|
||||
content = content[1:-1]
|
||||
return content
|
||||
else:
|
||||
# print(f"Error from LLM: {response.status}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Request failed: {e}")
|
||||
return None
|
||||
|
||||
def sample_context(self, request_type: str):
|
||||
"""
|
||||
Creates a random scenario: device, service, and arguments.
|
||||
"""
|
||||
# 1. Pick a device from the loaded piles OR synthetic devices
|
||||
device_types = list(self.piles.stacks_of_device_names.keys())
|
||||
if not device_types:
|
||||
return None
|
||||
|
||||
dt = random.choice(device_types)
|
||||
|
||||
# Mix real and synthetic devices
|
||||
devices = self.piles.stacks_of_device_names[dt]
|
||||
if dt in self.synthetic_devices:
|
||||
devices = devices + self.synthetic_devices[dt]
|
||||
|
||||
if not devices:
|
||||
return None
|
||||
|
||||
device = random.choice(devices)
|
||||
device_name = device["device_name"]
|
||||
friendly_name = device["description"]
|
||||
|
||||
# Decide between Action and Status
|
||||
if request_type == "status":
|
||||
# Status Request
|
||||
# Determine available attributes based on domain
|
||||
domain = device_name.split(".")[0]
|
||||
attributes = ["state"] # Default
|
||||
if domain == "light":
|
||||
attributes.extend(["brightness", "color"])
|
||||
elif domain == "climate":
|
||||
attributes.extend(["temperature", "humidity", "hvac_mode"])
|
||||
elif domain == "media_player":
|
||||
attributes.extend(["volume", "media_title", "state"])
|
||||
elif domain == "cover":
|
||||
attributes.extend(["position", "state"])
|
||||
elif domain == "fan":
|
||||
attributes.extend(["speed", "state"])
|
||||
|
||||
attribute = random.choice(attributes)
|
||||
return {
|
||||
"type": "status",
|
||||
"device_name": device_name,
|
||||
"friendly_name": friendly_name,
|
||||
"attribute": attribute
|
||||
}
|
||||
|
||||
elif request_type == "action":
|
||||
# Action
|
||||
# 2. Pick a service compatible with this device type
|
||||
domain = device_name.split(".")[0]
|
||||
|
||||
services = []
|
||||
if domain == "light":
|
||||
services = ["light.turn_on", "light.turn_off", "light.toggle"]
|
||||
elif domain == "switch":
|
||||
services = ["switch.turn_on", "switch.turn_off", "switch.toggle"]
|
||||
elif domain == "cover":
|
||||
services = ["cover.open_cover", "cover.close_cover", "cover.stop_cover", "cover.toggle"]
|
||||
elif domain == "blinds":
|
||||
services = ["blinds.open_cover", "blinds.close_cover", "blinds.stop_cover", "blinds.toggle"]
|
||||
elif domain == "garage_door":
|
||||
services = ["garage_door.open_cover", "garage_door.close_cover", "garage_door.stop_cover", "garage_door.toggle"]
|
||||
elif domain == "fan":
|
||||
services = ["fan.turn_on", "fan.turn_off", "fan.toggle", "fan.increase_speed", "fan.decrease_speed"]
|
||||
elif domain == "climate":
|
||||
services = ["climate.turn_on", "climate.turn_off", "climate.set_temperature"]
|
||||
elif domain == "media_player":
|
||||
services = ["media_player.turn_on", "media_player.turn_off", "media_player.media_play_pause", "media_player.volume_up", "media_player.volume_down"]
|
||||
elif domain == "lock":
|
||||
services = ["lock.lock", "lock.unlock"]
|
||||
elif domain == "vacuum":
|
||||
services = ["vacuum.start", "vacuum.return_to_base", "vacuum.stop"]
|
||||
|
||||
if not services:
|
||||
return None
|
||||
|
||||
service_name = random.choice(services)
|
||||
|
||||
# 3. Generate Arguments
|
||||
service_data = {}
|
||||
if service_name == "light.turn_on":
|
||||
if random.random() < 0.3:
|
||||
service_data["brightness_pct"] = random.randint(10, 100)
|
||||
if random.random() < 0.3:
|
||||
# Simple colors
|
||||
colors = ["red", "blue", "green", "yellow", "purple", "white", "warm white", "cool white"]
|
||||
service_data["color_name"] = random.choice(colors)
|
||||
elif service_name == "climate.set_temperature":
|
||||
service_data["temperature"] = random.randint(18, 28)
|
||||
|
||||
return {
|
||||
"type": "action",
|
||||
"device_name": device_name,
|
||||
"friendly_name": friendly_name,
|
||||
"service_name": service_name,
|
||||
"service_data": service_data
|
||||
}
|
||||
raise ValueError(f"Unknown request type {request_type}")
|
||||
|
||||
    async def run(self, num_actions: int, num_status_requests: int, num_devices: int, output_file, persona_name=None, persona_description=None):
        print(f"Starting generation...")
        print(f"Language: {self.language}")

        # Ensure output directory exists
        if output_file:
            os.makedirs(os.path.dirname(output_file), exist_ok=True)

        async with aiohttp.ClientSession() as session:

            if num_devices > 0:
                print("Generating synthetic device names...")
                device_types = list(self.piles.stacks_of_device_names.keys())
                gen_tasks = []
                for dt in device_types:
                    gen_tasks.append(self.generate_device_names(session, dt, count=num_devices))

                generated_lists = await asyncio.gather(*gen_tasks)

                # Flatten list and write to CSV
                all_new_devices = [item for sublist in generated_lists if sublist for item in sublist]

                if all_new_devices:
                    csv_path = f"data/piles/{self.language}/pile_of_device_names.csv"
                    try:
                        with open(csv_path, "a", newline='', encoding='utf-8') as f:
                            writer = csv.DictWriter(f, fieldnames=["device_name", "description"])
                            for device in all_new_devices:
                                writer.writerow(device)
                        print(f"Appended {len(all_new_devices)} new devices to {csv_path}")
                    except Exception as e:
                        print(f"Failed to write new devices to CSV: {e}")

            if num_actions > 0 or num_status_requests > 0:
                print(f"Generating {num_actions} actions and {num_status_requests} status requests...")
                print(f"Output file: {output_file}")
                tasks = {}
                results = []

                pbar = tqdm(total=num_actions + num_status_requests, desc="Generating phrases")

                while len(results) < num_actions + num_status_requests:
                    # Fill up the task queue
                    while len(tasks) < self.concurrency and (len(results) + len(tasks)) < num_actions + num_status_requests:
                        context = self.sample_context("action" if len(results) < num_actions else "status")
                        if not context:
                            continue

                        task = asyncio.create_task(self.generate_phrase(session, context))
                        tasks[task] = context

                    if not tasks:
                        break

                    # Wait for completed tasks
                    done, pending = await asyncio.wait(tasks.keys(), return_when=asyncio.FIRST_COMPLETED)

                    for task in done:
                        context = tasks.pop(task)
                        try:
                            phrase = await task
                            if phrase:
                                entry = context.copy()
                                entry["phrase"] = phrase

                                if entry["type"] == "action":
                                    # Write to pile_of_specific_actions.csv
                                    csv_path = f"data/piles/{self.language}/pile_of_specific_actions.csv"

                                    # generate_data.py rebuilds the entity id as
                                    # f"{device_type}.{action['device_name']}" where device_type is
                                    # service_name.split(".")[0], so the CSV's device_name column must
                                    # hold only the suffix (e.g. 'kitchen' from 'light.kitchen').

                                    full_device_name = entry["device_name"]
                                    service_name = entry["service_name"]
                                    service_domain = service_name.split(".")[0]
                                    device_domain = full_device_name.split(".")[0]

                                    if service_domain != device_domain:
                                        # Only possible for cross-domain services (e.g. homeassistant.turn_on);
                                        # sample_context keeps service and device domains aligned
                                        # (blinds/garage_door use blinds.open_cover etc.), so nothing to do.
                                        pass

                                    device_suffix = full_device_name.split(".", 1)[1]

                                    row = {
                                        "service_name": service_name,
                                        "device_name": device_suffix,
                                        "phrase": phrase,
                                        "arguments": json.dumps(entry["service_data"]) if entry["service_data"] else ""
                                    }

                                    # Check if the header needs the "arguments" column added (only once)
                                    if not hasattr(self, "_action_header_updated"):
                                        self._action_header_updated = True
                                        # Read the existing rows and header
                                        with open(csv_path, "r", encoding='utf-8') as f:
                                            reader = csv.DictReader(f)
                                            all_rows = list(reader)
                                            current_fieldnames = reader.fieldnames if reader.fieldnames else []

                                        fieldnames = list(current_fieldnames) + ["arguments"]
                                        with open(csv_path, "w", newline='', encoding='utf-8') as f:
                                            writer = csv.DictWriter(f, fieldnames=fieldnames)
                                            writer.writeheader()
                                            writer.writerows(all_rows)

                                    with open(csv_path, "a", newline='', encoding='utf-8') as f:
                                        # Read the header back so the DictWriter matches the file's actual column order.
                                        with open(csv_path, "r", encoding='utf-8') as fr:
                                            reader = csv.reader(fr)
                                            header = next(reader)

                                        writer = csv.DictWriter(f, fieldnames=header)
                                        writer.writerow(row)

                                elif entry["type"] == "status":
                                    # Write to pile_of_status_requests.csv.
                                    # Templatize the phrase by replacing the friendly name with <device_name>.
                                    phrase_tmpl = phrase.replace(entry["friendly_name"], "<device_name>")
                                    phrase_tmpl = phrase_tmpl.replace(entry["friendly_name"].lower(), "<device_name>")

                                    if "<device_name>" not in phrase_tmpl:
                                        # Could not templatize the phrase; skip it.
                                        pass
                                    else:
                                        csv_path = f"data/piles/{self.language}/pile_of_status_requests.csv"
                                        # Columns: device_type,state,phrase,assistant_response.
                                        # We do not generate an assistant_response yet, so skip
                                        # writing status requests for now.
                                        pass

                                results.append(entry)
                                pbar.update(1)
                        except Exception as e:
                            print(f"Task error: {e}")

                pbar.close()

            if persona_name and persona_description:
                await self.generate_persona(session, persona_name, persona_description)

        print("Generation complete.")

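    # Illustrative sketch (the phrase is invented) of the row the action branch
    # above appends to pile_of_specific_actions.csv:
    #
    #   service_name,device_name,phrase,arguments
    #   light.turn_on,kitchen,Turn the kitchen light on at 60 percent,"{""brightness_pct"": 60}"
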
    async def generate_persona(self, session, persona_name, persona_description):
        print(f"Generating new persona: {persona_name}...")

        # 1. Generate System Prompt
        sys_prompt_instruction = (
            f"Generate a system prompt for an AI assistant named '{persona_name}' "
            f"who has the following personality: {persona_description}. "
            "The prompt should define the persona's character and instructions. "
            "It should start with 'You are ...'. "
            "Keep it under 50 words."
        )

        payload = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": "You are an expert at creating AI system prompts."},
                {"role": "user", "content": sys_prompt_instruction}
            ],
            "temperature": 0.7,
            "max_tokens": 100,
        }

        system_prompt_text = ""
        try:
            async with session.post(LLM_ENDPOINT, json=payload) as response:
                if response.status == 200:
                    data = await response.json()
                    system_prompt_text = data['choices'][0]['message']['content'].strip()
                    if system_prompt_text.startswith('"') and system_prompt_text.endswith('"'):
                        system_prompt_text = system_prompt_text[1:-1]
                    print(f"Generated system prompt: {system_prompt_text}")
                else:
                    print(f"Failed to generate system prompt: {response.status}")
                    return
        except Exception as e:
            print(f"System prompt generation failed: {e}")
            return

        # 2. Get list of services to generate responses for
        responses_csv_path = f"data/piles/{self.language}/pile_of_responses.csv"
        services = set()
        try:
            with open(responses_csv_path, "r", encoding='utf-8') as f:
                reader = csv.DictReader(f)
                for row in reader:
                    services.add(row["service"])
        except Exception as e:
            print(f"Failed to read responses CSV: {e}")
            return

        print(f"Found {len(services)} unique services to generate responses for.")

        # 3. Generate responses for each service
        new_responses = []

        async def generate_service_responses(svc):
            # We want both normal and short responses
            prompt = (
                f"You are acting as '{persona_name}', described as: {persona_description}.\n"
                f"Generate 3 diverse responses confirming that you are performing the action: '{svc}'.\n"
                "Then generate 3 SHORT/CONCISE responses for the same action.\n"
                "Format the output as follows:\n"
                "NORMAL: <response 1>\n"
                "NORMAL: <response 2>\n"
                "NORMAL: <response 3>\n"
                "SHORT: <short response 1>\n"
                "SHORT: <short response 2>\n"
                "SHORT: <short response 3>\n"
                "Do not include any other text."
            )

            payload = {
                "model": self.model_name,
                "messages": [
                    {"role": "system", "content": f"You are {persona_name}. {persona_description}"},
                    {"role": "user", "content": prompt}
                ],
                "temperature": 0.8,
                "max_tokens": 300,
            }

            try:
                async with session.post(LLM_ENDPOINT, json=payload) as response:
                    if response.status == 200:
                        data = await response.json()
                        content = data['choices'][0]['message']['content'].strip()
                        lines = content.split('\n')
                        for line in lines:
                            line = line.strip()
                            if line.startswith("NORMAL:"):
                                text = line.replace("NORMAL:", "").strip()
                                if text:
                                    new_responses.append({
                                        "service": svc,
                                        "response": text,
                                        "persona": persona_name,
                                        "short": 0
                                    })
                            elif line.startswith("SHORT:"):
                                text = line.replace("SHORT:", "").strip()
                                if text:
                                    new_responses.append({
                                        "service": svc,
                                        "response": text,
                                        "persona": persona_name,
                                        "short": 1
                                    })
            except Exception as e:
                print(f"Failed to generate responses for {svc}: {e}")

        # Run in batches
        tasks = []
        for svc in services:
            tasks.append(generate_service_responses(svc))
            if len(tasks) >= self.concurrency:
                await asyncio.gather(*tasks)
                tasks = []
        if tasks:
            await asyncio.gather(*tasks)

        print(f"Generated {len(new_responses)} responses.")

        # 4. Write to files
        # Append system prompt
        sys_prompts_path = f"data/piles/{self.language}/pile_of_system_prompts.csv"
        try:
            with open(sys_prompts_path, "a", newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                # The csv module handles line terminators, so a plain append is fine here.
                writer.writerow([persona_name, system_prompt_text])
            print(f"Appended system prompt to {sys_prompts_path}")
        except Exception as e:
            print(f"Failed to write system prompt: {e}")

        # Append responses
        try:
            with open(responses_csv_path, "a", newline='', encoding='utf-8') as f:
                fieldnames = ["service", "response", "persona", "short"]
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                for resp in new_responses:
                    writer.writerow(resp)
            print(f"Appended responses to {responses_csv_path}")
        except Exception as e:
            print(f"Failed to write responses: {e}")

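    # The parser in generate_service_responses expects the LLM reply to look like
    # this (the sentences are invented; only the NORMAL:/SHORT: prefixes matter):
    #
    #   NORMAL: Of course, I'm taking care of that right now.
    #   SHORT: Done.
    #
    # Each NORMAL line becomes a CSV row with short=0, each SHORT line a row with short=1.
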
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate synthetic data using LLM")
    parser.add_argument("--actions", type=int, default=0, help="Number of actions to generate")
    parser.add_argument("--status", type=int, default=0, help="Number of status requests to generate")
    parser.add_argument("--devices", type=int, default=0, help="Number of new devices to generate")
    parser.add_argument("--concurrency", type=int, default=8, help="Number of concurrent requests")
    parser.add_argument("--language", type=str, default="english", help="Language")
    parser.add_argument("--model", type=str, default="gpt-oss-120b", help="LLM model to use")
    parser.add_argument("--persona-name", type=str, help="Name of the new persona to generate")
    parser.add_argument("--persona-description", type=str, help="Description of the new persona")

    args = parser.parse_args()

    generator = SyntheticDataGenerator(model_name=args.model, language=args.language, concurrency=args.concurrency)
    asyncio.run(generator.run(num_actions=args.actions, num_status_requests=args.status, num_devices=args.devices, output_file="", persona_name=args.persona_name, persona_description=args.persona_description))
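# Example invocations (hypothetical script name and model; adjust to your setup):
#   python3 synthesize_data.py --devices 10 --language english --model gpt-oss-120b
#   python3 synthesize_data.py --actions 200 --status 50 --concurrency 8
#   python3 synthesize_data.py --persona-name "Jeeves" --persona-description "a very formal butler"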
823
data/tools.py
Normal file
@@ -0,0 +1,823 @@
# TOOLS
TOOL_TURN_ON = "HassTurnOn"
TOOL_TURN_OFF = "HassTurnOff"
TOOL_TOGGLE = "HassToggle"
TOOL_SET_POSITION = "HassSetPosition"
TOOL_LIGHT_SET = "HassLightSet"
TOOL_SET_VOLUME = "HassSetVolume"
TOOL_MEDIA_UNPAUSE = "HassMediaUnpause"
TOOL_MEDIA_PAUSE = "HassMediaPause"
TOOL_MEDIA_NEXT = "HassMediaNext"
TOOL_MEDIA_PREVIOUS = "HassMediaPrevious"
TOOL_VACUUM_START = "HassVacuumStart"
TOOL_VACUUM_RETURN_TO_BASE = "HassVacuumReturnToBase"
TOOL_LIST_ADD_ITEM = "HassListAddItem"
TOOL_START_TIMER = "HassStartTimer"
TOOL_CANCEL_TIMER = "HassCancelTimer"
TOOL_PAUSE_TIMER = "HassPauseTimer"
TOOL_UNPAUSE_TIMER = "HassUnpauseTimer"
TOOL_INCREASE_TIMER = "HassIncreaseTimer"
TOOL_DECREASE_TIMER = "HassDecreaseTimer"
TOOL_TIMER_STATUS = "HassTimerStatus"
TOOL_CLIMATE_SET_TEMPERATURE = "HassClimateSetTemperature"
TOOL_CLIMATE_GET_TEMPERATURE = "HassClimateGetTemperature"
TOOL_SET_HUMIDITY = "HassHumidifierSetpoint"
TOOL_SET_HUMIDIFIER_MODE = "HassHumidifierMode"

# Service name to tool name mapping for backwards compatibility with CSV files
SERVICE_TO_TOOL_MAP = {
    "turn_on": TOOL_TURN_ON,
    "turn_off": TOOL_TURN_OFF,
    "toggle": TOOL_TOGGLE,
    "open_cover": TOOL_TURN_ON,
    "close_cover": TOOL_TURN_OFF,
    "stop_cover": TOOL_TOGGLE,
    "set_cover_position": TOOL_SET_POSITION,
    "lock": TOOL_TURN_ON,
    "unlock": TOOL_TURN_OFF,
    "increase_speed": TOOL_TURN_ON,
    "decrease_speed": TOOL_TURN_OFF,
    "media_play_pause": TOOL_TOGGLE,
    "media_pause": TOOL_MEDIA_PAUSE,
    "media_play": TOOL_MEDIA_UNPAUSE,
    "media_next_track": TOOL_MEDIA_NEXT,
    "media_previous_track": TOOL_MEDIA_PREVIOUS,
    "start": TOOL_VACUUM_START,
    "return_to_base": TOOL_VACUUM_RETURN_TO_BASE,
    "set_temperature": TOOL_CLIMATE_SET_TEMPERATURE,
    "set_humidity": TOOL_SET_HUMIDITY,
    "set_hvac_mode": TOOL_CLIMATE_SET_TEMPERATURE,
    "set_fan_mode": TOOL_CLIMATE_SET_TEMPERATURE,
    "set_preset_mode": TOOL_CLIMATE_SET_TEMPERATURE,
}

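# Assumed helper (illustration only, not part of the original file): translate a
# bare service action from the CSV piles into the matching intent tool name.
def _service_to_tool(service_name: str) -> str:
    action = service_name.split(".")[-1]              # e.g. "cover.close_cover" -> "close_cover"
    return SERVICE_TO_TOOL_MAP.get(action, action)    # -> "HassTurnOff"
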
# Home Assistant Intent Tools Definition
HASS_TOOLS = [
    {
        "name": TOOL_TURN_ON,
        "description": "Turns on/opens/unlocks a device or entity",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the device or entity"},
                "area": {"type": "string", "description": "Name of the area"},
                "floor": {"type": "string", "description": "Name of the floor"},
                "domain": {"type": "array", "items": {"type": "string"}, "description": "Device domain(s)"},
                "device_class": {"type": "array", "items": {"type": "string"}, "description": "Device class(es)"}
            },
            "required": []
        }
    },
    {
        "name": TOOL_TURN_OFF,
        "description": "Turns off/closes/locks a device or entity",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the device or entity"},
                "area": {"type": "string", "description": "Name of the area"},
                "floor": {"type": "string", "description": "Name of the floor"},
                "domain": {"type": "array", "items": {"type": "string"}, "description": "Device domain(s)"},
                "device_class": {"type": "array", "items": {"type": "string"}, "description": "Device class(es)"}
            },
            "required": []
        }
    },
    {
        "name": TOOL_TOGGLE,
        "description": "Toggles a device or entity",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the device or entity"},
                "area": {"type": "string", "description": "Name of the area"},
                "floor": {"type": "string", "description": "Name of the floor"},
                "domain": {"type": "array", "items": {"type": "string"}, "description": "Device domain(s)"},
                "device_class": {"type": "array", "items": {"type": "string"}, "description": "Device class(es)"}
            },
            "required": []
        }
    },
    {
        "name": TOOL_SET_POSITION,
        "description": "Sets the position of a device or entity (e.g., blinds, covers)",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the device or entity"},
                "area": {"type": "string", "description": "Name of the area"},
                "floor": {"type": "string", "description": "Name of the floor"},
                "position": {"type": "integer", "description": "Position from 0-100", "minimum": 0, "maximum": 100}
            },
            "required": ["position"]
        }
    },
    {
        "name": TOOL_LIGHT_SET,
        "description": "Sets the brightness or color of a light",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the light"},
                "area": {"type": "string", "description": "Name of the area"},
                "floor": {"type": "string", "description": "Name of the floor"},
                "color": {"type": "string", "description": "Color name"},
                "temperature": {"type": "integer", "description": "Color temperature in Kelvin"},
                "brightness": {"type": "integer", "description": "Brightness percentage (0-100)", "minimum": 0, "maximum": 100}
            },
            "required": []
        }
    },
    {
        "name": TOOL_CLIMATE_SET_TEMPERATURE,
        "description": "Sets the target temperature of a climate device",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the climate device"},
                "area": {"type": "string", "description": "Name of the area"},
                "floor": {"type": "string", "description": "Name of the floor"},
                "temperature": {"type": "number", "description": "Target temperature"}
            },
            "required": ["temperature"]
        }
    },
    {
        "name": TOOL_SET_HUMIDITY,
        "description": "Sets the target humidity level of a humidifier device",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the humidifier"},
                "humidity": {"type": "integer", "description": "Target humidity percentage (0-100)", "minimum": 0, "maximum": 100}
            },
            "required": ["name", "humidity"]
        }
    },
    {
        "name": TOOL_SET_HUMIDIFIER_MODE,
        "description": "Sets the mode of a humidifier device",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the humidifier"},
                "mode": {"type": "string", "description": "Humidifier mode"}
            },
            "required": ["name", "mode"]
        }
    },
    {
        "name": TOOL_MEDIA_UNPAUSE,
        "description": "Resumes playback on a media player",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the media player"},
                "area": {"type": "string", "description": "Name of the area"},
                "floor": {"type": "string", "description": "Name of the floor"}
            },
            "required": []
        }
    },
    {
        "name": TOOL_MEDIA_PAUSE,
        "description": "Pauses playback on a media player",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the media player"},
                "area": {"type": "string", "description": "Name of the area"},
                "floor": {"type": "string", "description": "Name of the floor"}
            },
            "required": []
        }
    },
    {
        "name": TOOL_MEDIA_NEXT,
        "description": "Skips to the next media item on a media player",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the media player"},
                "area": {"type": "string", "description": "Name of the area"},
                "floor": {"type": "string", "description": "Name of the floor"}
            },
            "required": []
        }
    },
    {
        "name": TOOL_SET_VOLUME,
        "description": "Sets the volume of a media player",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the media player"},
                "area": {"type": "string", "description": "Name of the area"},
                "floor": {"type": "string", "description": "Name of the floor"},
                "volume_level": {"type": "number", "description": "Volume level (0.0-1.0)", "minimum": 0.0, "maximum": 1.0}
            },
            "required": ["volume_level"]
        }
    },
    {
        "name": TOOL_VACUUM_START,
        "description": "Starts a vacuum",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the vacuum"},
                "area": {"type": "string", "description": "Name of the area"},
                "floor": {"type": "string", "description": "Name of the floor"}
            },
            "required": []
        }
    },
    {
        "name": TOOL_VACUUM_RETURN_TO_BASE,
        "description": "Returns a vacuum to its base",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the vacuum"},
                "area": {"type": "string", "description": "Name of the area"},
                "floor": {"type": "string", "description": "Name of the floor"}
            },
            "required": []
        }
    },
    {
        "name": TOOL_LIST_ADD_ITEM,
        "description": "Adds an item to a todo list",
        "parameters": {
            "type": "object",
            "properties": {
                "item": {"type": "string", "description": "The item to add to the list"},
                "name": {"type": "string", "description": "Name of the todo list"}
            },
            "required": ["item"]
        }
    },
    {
        "name": TOOL_START_TIMER,
        "description": "Starts a timer",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the timer"},
                "duration": {"type": "string", "description": "Timer duration (HH:MM:SS format)"}
            },
            "required": []
        }
    },
    {
        "name": TOOL_CANCEL_TIMER,
        "description": "Cancels a timer",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the timer"}
            },
            "required": []
        }
    },
    {
        "name": TOOL_PAUSE_TIMER,
        "description": "Pauses a timer",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the timer"}
            },
            "required": []
        }
    },
    {
        "name": TOOL_UNPAUSE_TIMER,
        "description": "Resumes a paused timer",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name of the timer"}
            },
            "required": []
        }
    }
]

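# Assumed example (not part of the original file): build a HassTurnOn call that
# conforms to the schema above; only name/area/floor/domain/device_class are accepted.
def _example_turn_on_call(entity_name: str) -> dict:
    return {"name": TOOL_TURN_ON, "arguments": {"name": entity_name, "domain": ["light"]}}
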
SERVICE_TOOL_ALLOWED_SERVICES = ["turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover", "lock", "unlock",
                                 "start", "stop", "return_to_base", "pause", "cancel", "add_item", "set_temperature", "set_humidity", "set_fan_mode", "set_hvac_mode", "set_preset_mode"]
SERVICE_TOOL_ALLOWED_DOMAINS = ["light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script"]

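# Assumed helper (illustration only): check whether a domain/service pair is
# eligible for the service-style tool format using the allow-lists above.
def _is_allowed_service(domain: str, service: str) -> bool:
    return domain in SERVICE_TOOL_ALLOWED_DOMAINS and service in SERVICE_TOOL_ALLOWED_SERVICES
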
SERVICE_TOOLS = [
    {
        "name": "<sample>",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {
                "target_device": {"type": "string", "description": "The target device to turn on"},
                "rgb_color": {"type": "string", "description": "The RGB color to set"},
                "brightness": {"type": "number", "description": "The brightness level"},
                "temperature": {"type": "number", "description": "The temperature level"},
                "humidity": {"type": "number", "description": "The humidity level"},
                "fan_mode": {"type": "string", "description": "The fan mode"},
                "hvac_mode": {"type": "string", "description": "The HVAC mode"},
                "preset_mode": {"type": "string", "description": "The preset mode"},
                "duration": {"type": "string", "description": "The amount of time to apply to the chosen timer"},
                "item": {"type": "string", "description": "The item to add to the list"}
            },
            "required": ["target_device"]
        }
    },
    {
        "name": "light.turn_on",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {
                "target_device": {"type": "string", "description": "The target device to turn on"},
                "rgb_color": {"type": "string", "description": "The RGB color to set"},
                "brightness": {"type": "number", "description": "The brightness level"}
            },
            "required": ["target_device"]
        }
    },
    {
        "name": "light.turn_off",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to turn off"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "light.toggle",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to toggle"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "switch.turn_on",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to turn on"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "switch.turn_off",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to turn off"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "switch.toggle",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to toggle"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "fan.turn_on",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to turn on"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "fan.turn_off",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to turn off"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "fan.toggle",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to toggle"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "fan.set_speed",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {
                "target_device": {"type": "string", "description": "The target device to set speed"},
                "fan_mode": {"type": "string", "description": "The fan mode"}
            },
            "required": ["target_device", "fan_mode"]
        }
    },
    {
        "name": "fan.increase_speed",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to increase speed"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "fan.decrease_speed",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to decrease speed"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "button.press",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to press"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "cover.open_cover",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to open"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "cover.close_cover",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to close"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "cover.stop_cover",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to stop"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "cover.set_cover_position",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {
                "target_device": {"type": "string", "description": "The target device to set position"},
                "position": {"type": "integer", "description": "Position from 0-100", "minimum": 0, "maximum": 100}
            },
            "required": ["target_device", "position"]
        }
    },
    {
        "name": "lock.unlock",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to unlock"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "lock.lock",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to lock"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "vacuum.start",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to start"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "vacuum.stop",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to stop"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "vacuum.return_to_base",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target device to return to base"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "media_player.media_play_pause",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target media player to play/pause"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "media_player.media_pause",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target media player to pause"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "media_player.media_play",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target media player to play"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "media_player.media_next_track",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target media player to skip to next track"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "media_player.media_previous_track",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target media player to skip to previous track"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "media_player.volume_set",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {
                "target_device": {"type": "string", "description": "The target media player to set volume"},
                "volume_level": {"type": "number", "description": "Volume level (0.0-1.0)", "minimum": 0.0, "maximum": 1.0}
            },
            "required": ["target_device", "volume_level"]
        }
    },
    {
        "name": "todo.add_item",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {
                "target_device": {"type": "string", "description": "The target todo list to add item to"},
                "item": {"type": "string", "description": "The item to add to the list"}
            },
            "required": ["target_device", "item"]
        }
    },
    {
        "name": "timer.start",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {
                "target_device": {"type": "string", "description": "The target timer to start"},
                "duration": {"type": "string", "description": "Timer duration (HH:MM:SS format)"}
            },
            "required": ["target_device"]
        }
    },
    {
        "name": "timer.cancel",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {"target_device": {"type": "string", "description": "The target timer to cancel"}},
            "required": ["target_device"]
        }
    },
    {
        "name": "climate.set_temperature",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {
                "target_device": {"type": "string", "description": "The target climate device to set temperature"},
                "temperature": {"type": "number", "description": "Target temperature"}
            },
            "required": ["target_device", "temperature"]
        }
    },
    {
        "name": "climate.set_humidity",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {
                "target_device": {"type": "string", "description": "The target humidifier device to set humidity"},
                "humidity": {"type": "integer", "description": "Target humidity percentage (0-100)", "minimum": 0, "maximum": 100}
            },
            "required": ["target_device", "humidity"]
        }
    },
    {
        "name": "climate.set_hvac_mode",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {
                "target_device": {"type": "string", "description": "The target climate device to set HVAC mode"},
                "hvac_mode": {"type": "string", "description": "The HVAC mode"}
            },
            "required": ["target_device", "hvac_mode"]
        }
    },
    {
        "name": "climate.set_preset_mode",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {
                "target_device": {"type": "string", "description": "The target climate device to set preset mode"},
                "preset_mode": {"type": "string", "description": "The preset mode"}
            },
            "required": ["target_device", "preset_mode"]
        }
    },
    {
        "name": "climate.set_fan_mode",
        "description": "",
        "parameters": {
            "type": "object",
            "properties": {
                "target_device": {"type": "string", "description": "The target climate device to set fan mode"},
                "fan_mode": {"type": "string", "description": "The fan mode"}
            },
            "required": ["target_device", "fan_mode"]
        }
    }
]

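# Assumed helper (illustration only): look up the JSON schema for a fully
# qualified service name such as "climate.set_temperature".
def _find_service_tool(name: str):
    return next((tool for tool in SERVICE_TOOLS if tool["name"] == name), None)
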
158
data/utils.py
Normal file
@@ -0,0 +1,158 @@
import random
import re
import csv
import pandas
from datetime import datetime, timedelta
import webcolors

class NoResponseAvailableException(Exception):
    pass

class NoServicesAvailableException(Exception):
    pass


def closest_color(requested_color):
    min_colors = {}
    color_names = webcolors.names("css3")

    for name in color_names:
        r_c, g_c, b_c = webcolors.name_to_rgb(name)
        rd = (r_c - requested_color[0]) ** 2
        gd = (g_c - requested_color[1]) ** 2
        bd = (b_c - requested_color[2]) ** 2
        min_colors[(rd + gd + bd)] = name
    return min_colors[min(min_colors.keys())]

def generate_random_datetime():
    start_date = datetime(2022, 1, 1)
    end_date = datetime(2030, 12, 31)
    delta = end_date - start_date
    random_days = random.randint(0, delta.days)
    random_seconds = random.randint(0, 24 * 60 * 60)
    random_date_time = start_date + timedelta(days=random_days, seconds=random_seconds)
    return random_date_time

var_pattern = re.compile("<(.*?)>")
def get_included_vars(response: str):
    result = []
    for var in var_pattern.findall(response):
        if var == "device_name":
            continue
        result.append(var)

    return ",".join(sorted(result))

def generate_random_parameter(param_name, piles_of_data):
    RANDOM_PARAMETER_GENERATORS = {
        "rgb_color": lambda: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
        "brightness": lambda: random.randint(0, 100),
        "fan_mode": lambda: random.choice(["On Low", "On High", "Auto Low", "Auto High", "Off"]),
        "temp_f": lambda: random.randint(60, 80),
        "temp_c": lambda: random.randint(15, 25),
        "humidity": lambda: random.randint(10, 90),
        "preset_mode": lambda: random.choice(["home", "eco", "away", "auto"]),
        "hvac_mode": lambda: random.choice(["heat", "cool", "heat_cool", "off", "auto", "fan_only"]),
        "media": lambda: random.choice(piles_of_data["pile_of_media_names"]),
        "volume": lambda: round(random.random(), 2),
        "duration": lambda: random.choice(list(piles_of_data["pile_of_durations"].keys())),
        "remaining": lambda: f"{random.randint(0, 3):02}:{random.randint(0, 59):02}:{random.randint(0, 59):02}",
        "todo": lambda: random.choice(piles_of_data["pile_of_todo_items"]),
    }
    param_generator = RANDOM_PARAMETER_GENERATORS.get(param_name)

    if not param_generator:
        raise Exception(f"Unknown param to generate random value for {param_name}")

    return param_generator()

def get_random_response(pile_of_responses, *, service: str, persona: str, question_template: str, short: bool) -> str:

    required_vars = list(set([var for var in var_pattern.findall(question_template) if "device_name" not in var]))

    possible_results = pile_of_responses.loc[(pile_of_responses['service']==service) &
                                             (pile_of_responses['persona']==persona) &
                                             (pile_of_responses['short']==(1 if short else 0)) &
                                             (pile_of_responses['contains_vars']==",".join(sorted(required_vars)))
                                             ]

    if len(possible_results) == 0:
        raise NoResponseAvailableException(f"No responses matched the provided filters: {persona}, {service}, {required_vars}, {short}")

    return possible_results.sample()["response"].values[0]

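# Illustrative sketch (the persona and service values are assumed, not from the
# original module): draw a random brightness and a matching confirmation response.
def _example_random_draws(piles: "DatasetPiles"):
    brightness = generate_random_parameter("brightness", piles)
    reply = get_random_response(piles.pile_of_responses, service="light.turn_on",
                                persona="assistant", question_template="turn on <device_name>",
                                short=False)
    return brightness, reply
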
class DatasetPiles:
    def __init__(self, supported_devices, language="english"):
        self.language = language

        with open(f"piles/{language}/pile_of_and_words.csv", encoding="utf8") as f:
            self.and_words = [ x.strip() for x in f.readlines() ]

        with open(f"piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
            reader = csv.DictReader(f)
            self.pile_of_durations = { x["duration"]: x["name"] for x in reader }

        # media names are not translated
        with open(f"piles/english/pile_of_media_names.txt", encoding="utf8") as f:
            self.pile_of_media_names = [ x.strip() for x in f.readlines() ]

        with open(f"piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
            self.pile_of_todo_items = [ x.strip() for x in f.readlines() ]

        self.stacks_of_device_names = { x: [] for x in supported_devices }
        with open(f"piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
            reader = csv.DictReader(f)
            pile_of_device_names = list(reader)
            for device_dict in pile_of_device_names:
                try:
                    device_type = device_dict["device_name"].split(".")[0]
                    self.stacks_of_device_names[device_type].append(device_dict)
                except KeyError as ex:
                    print(ex)

        with open(f"piles/{language}/pile_of_templated_actions.csv", encoding="utf8") as f:
            reader = csv.DictReader(f)
            pile_of_templated_actions = list(reader)
            processed_pile_of_templated_actions = []
            for action in pile_of_templated_actions:
                try:
                    multiplier = int(action["multiplier"])
                except Exception:
                    raise Exception(f"line has a bad multiplier: {action}")
                for x in range(multiplier):
                    processed_pile_of_templated_actions.append(action)

            self.pile_of_templated_actions = processed_pile_of_templated_actions

        with open(f"piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
            reader = csv.DictReader(f)
            self.pile_of_specific_actions = list(reader)

        self.pile_of_responses = pandas.read_csv(f"piles/{language}/pile_of_responses.csv")
        self.pile_of_responses["contains_vars"] = self.pile_of_responses["response"].apply(get_included_vars)

        with open(f"piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
            reader = csv.DictReader(f)
            self.pile_of_status_requests = list(reader)

        with open(f"piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
            reader = csv.DictReader(f)
            self.pile_of_system_prompts = { line["persona"]: line["prompt"] for line in reader }

        # service names are not translated
        with open(f"piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f:
            reader = csv.DictReader(f)
            self.pile_of_hallucinated_service_names = list(reader)

    def __getitem__(self, key):
        return getattr(self, key)

_piles_cache = {}

def get_dataset_piles(language: str) -> DatasetPiles:
    if language not in _piles_cache:
        _piles_cache[language] = DatasetPiles( [
            "light", "switch", "fan", "garage_door", "blinds",
            "lock", "media_player", "climate", "vacuum", "timer", "todo",
        ], language)
    return _piles_cache[language]

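# Example usage (illustrative): piles are cached per language, so repeated calls
# reuse the same DatasetPiles instance.
#
#   english_piles = get_dataset_piles("english")
#   print(len(english_piles.pile_of_templated_actions), "templated actions loaded")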