# Source: home-llm/data/generate_data.py (648 lines, 28 KiB, Python)
import argparse
import json
import numpy as np
import random
from datasets import load_dataset, concatenate_datasets
from difflib import SequenceMatcher
from typing import Callable
from tqdm import tqdm
import webcolors
from device_types import *
from prompting import generate_system_prompt, USER_INSTRUCTION_PROMPT
from utils import get_random_response, generate_random_parameter, closest_color, get_dataset_piles, NoResponseAvailableException
# Registry of every device domain this generator can produce examples for.
# Keys are Home Assistant entity-id prefixes (the part before the "." in an
# entity id such as "light.bedroom"); values are the *DeviceType handlers
# star-imported from device_types. Handlers supply random states, random
# parameters, and the tool definitions for their domain.
SUPPORTED_DEVICES = {
    "light": LightDeviceType(),
    "switch": SwitchDeviceType(),
    "fan": FanDeviceType(),
    "garage_door": GarageDoorDeviceType(),
    "blinds": BlindsDeviceType(),
    "lock": LockDeviceType(),
    "media_player": MediaPlayerDeviceType(),
    "climate": ClimateDeviceType(),
    "vacuum": VacuumDeviceType(),
    "timer": TimerDeviceType(),
    "todo": TodoDeviceType(),
}
def format_device_line(*, device_name: str, friendly_name: str, state: str):
    """Render a single exposed-device context line: ``<entity> '<name>' = <state>``."""
    return f"{device_name} '{friendly_name}' = {state}"
# generate a random list of devices for the context
def random_device_list(max_devices: int, avoid_device_names: list[str], language: str = "english"):
    """Build a random device context block.

    Returns ``(device_lines, device_types, extra_exposed_attributes)``:
    formatted state lines, the list of domains that appear in them, and the
    attribute names that were exposed when generating random states.

    Candidates whose entity name is too similar to any entry in
    ``avoid_device_names`` (SequenceMatcher ratio >= 0.4) are filtered out so
    the example's target device stays unambiguous.
    """
    num_devices = random.randint(2, max_devices)
    piles = get_dataset_piles(language)
    # shallow-copy each pile so filtering never mutates the shared dataset
    local_device_names = {k: v[:] for k, v in piles.stacks_of_device_names.items()}

    avoid_climate = False
    for avoid_device in avoid_device_names:
        avoid_type = avoid_device.split(".")[0]
        local_device_names[avoid_type] = [
            possible_device
            for possible_device in local_device_names[avoid_type]
            if SequenceMatcher(None, avoid_device, possible_device["device_name"].split(".")[1]).ratio() < 0.4
        ]
        # don't add random thermostats when one is already the target; we need
        # to be careful about how we handle multiple thermostats
        if avoid_type == "climate":
            avoid_climate = True

    possible_choices = []
    for device_type in local_device_names:
        possible_choices.extend(local_device_names[device_type])

    device_types = set()
    device_list = []
    device_lines = []
    # TODO: randomly pick attributes for this list
    extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "duration", "remaining", "item"]

    # BUG FIX: draw WITHOUT replacement and stop when the pool is exhausted.
    # The original re-drew with replacement forever; once every remaining
    # candidate was a duplicate (or a skipped climate device) the loop never
    # terminated.
    while len(device_list) < num_devices and possible_choices:
        choice = possible_choices.pop(random.randrange(len(possible_choices)))
        if choice["device_name"] in device_list:
            continue
        try:
            device_name = choice["device_name"]
            device_type = device_name.split(".")[0]
            friendly_name = choice["description"]
            if avoid_climate and device_type == "climate":
                continue
            state = SUPPORTED_DEVICES[device_type].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
            device_lines.append(format_device_line(
                device_name=device_name,
                friendly_name=friendly_name,
                state=state,
            ))
            device_list.append(device_name)
            device_types.add(device_type)
        except Exception as ex:
            # best-effort: one malformed pile entry should not abort generation
            print(f"bad device name: {choice}")
            print(repr(ex))
    return device_lines, list(device_types), list(extra_exposed_attributes)
def generate_static_example(action: dict, persona: str, language: str, max_devices: int = 32, use_service_names: bool = False):
    """Build one training example from a 'specific action' pile entry.

    The entry supplies the phrase, the target service (``domain.action``) and
    the target device name; a random context of other devices is generated
    around the target. Returns a dict with keys: states, available_tools,
    question, answers, tool_calls.
    """
    question = action["phrase"]
    service_name = action["service_name"]
    device_type = service_name.split(".")[0]
    target_device = f"{device_type}.{action['device_name']}"
    friendly_name = target_device.split(".")[1].replace("_", " ").title()
    piles = get_dataset_piles(language)

    device_list, device_types, extra_exposed_attributes = random_device_list(
        max_devices=max_devices, avoid_device_names=[target_device], language=language)

    # insert our target device somewhere random in the list
    index = random.randint(0, len(device_list))
    state = SUPPORTED_DEVICES[device_type].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
    device_list.insert(index, format_device_line(
        device_name=target_device,
        friendly_name=friendly_name,
        state=state,
    ))

    # gather a list of all available tools (duplicates removed, order kept)
    available_tools = []
    for x in set(device_types + [device_type]):
        available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))
    available_tools = list(dict.fromkeys(available_tools))

    # Map service name to tool name
    service_action = service_name.split(".")[1]
    tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)

    response = get_random_response(
        piles.pile_of_responses,
        service=service_name,
        persona=persona,
        question_template="",
        short=False,
    ).lower()
    response = response.replace("<device_name>", friendly_name)

    # CONSISTENCY FIX: service mode targets the entity id; intent-tool mode
    # targets the friendly name, matching generate_templated_example. The
    # original passed the entity id for "name" as well.
    if use_service_names:
        tool_args = {"entity_id": target_device}
    else:
        tool_args = {"name": friendly_name}
    tool_call = {
        "tool_name": tool_name,
        "service_name": service_name,
        "tool_args": tool_args,
    }

    if "arguments" in action and action["arguments"]:
        # `arguments` is stored as a JSON string in the pile; json is already
        # imported at module level (the original re-imported it here)
        try:
            tool_call["tool_args"].update(json.loads(action["arguments"]))
        except Exception as e:
            print(f"Failed to parse arguments for {action}: {e}")

    return {
        "states": device_list,
        "available_tools": available_tools,
        "question": question.lower(),
        "answers": [response],
        "tool_calls": [tool_call],
    }
def replace_answer(list_of_answer, var, value):
    """Return a new list with *var* substituted by *value* in every answer string."""
    return [answer.replace(var, value) for answer in list_of_answer]
def generate_templated_example(template: dict, persona: str, language: str, max_devices: int = 32, use_service_names: bool = False):
    """Build one action example from a templated phrase.

    A template may target several devices at once: its "device_type" and
    "service" fields are '|'-separated parallel lists. Placeholders in the
    phrase (<device_name>/<device_nameN>, <hvac_mode>, <fan_mode>, <temp_f>,
    <temp_c>, <humidity>, <brightness>, <color>, <duration>, <todo>) are
    filled with random values, and the matching arguments are attached to the
    generated tool calls.

    Returns a dict with keys: states, available_tools, question, answers,
    tool_calls.
    """
    template_device_types: list[str] = template["device_type"].split("|")
    # pair each device type with its service: ("light", "turn_on") -> "light.turn_on"
    service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ]
    question_template: str = template["phrase"]
    piles = get_dataset_piles(language)

    # choose a random device for this template (one per targeted device type)
    chosen_devices = []
    for device_type in template_device_types:
        device_dict = random.choice(piles.stacks_of_device_names[device_type])
        # NOTE(review): this tags the shared pile entry in place -- confirm no
        # other caller depends on these dicts staying unmodified
        device_dict["type"] = device_type
        chosen_devices.append(device_dict)

    device_list, device_types, extra_exposed_attributes = random_device_list(
        max_devices=max_devices, avoid_device_names=[d["device_name"] for d in chosen_devices])

    # insert our target device somewhere random in the list
    for device_dict in chosen_devices:
        index = random.randint(0, len(device_list))
        # expose any attributes the phrase's placeholders will need; these
        # checks depend only on the template (loop-invariant, repeated per
        # device but harmless since they test membership before appending)
        if "<brightness>" in question_template and "brightness" not in extra_exposed_attributes:
            extra_exposed_attributes.append("brightness")
        if "<color>" in question_template and "rgb_color" not in extra_exposed_attributes:
            extra_exposed_attributes.append("rgb_color")
        if ("<temp_f>" in question_template or "<temp_c>" in question_template) \
            and "temperature" not in extra_exposed_attributes:
            extra_exposed_attributes.append("temperature")
        if "<humidity>" in question_template and "humidity" not in extra_exposed_attributes:
            extra_exposed_attributes.append("humidity")
        if "<fan_mode>" in question_template and "fan_mode" not in extra_exposed_attributes:
            extra_exposed_attributes.append("fan_mode")
        if "<duration>" in question_template and "duration" not in extra_exposed_attributes:
            extra_exposed_attributes.append("duration")
        state = SUPPORTED_DEVICES[device_dict["type"]].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
        device_name = device_dict["device_name"]
        friendly_name = device_dict["description"]
        device_list.insert(index, format_device_line(
            device_name=device_name,
            friendly_name=friendly_name,
            state=state
        ))

    # gather a list of all available tools
    available_tools = []
    for x in set(device_types + template_device_types):
        available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))
    # Remove duplicates while preserving order
    available_tools = list(dict.fromkeys(available_tools))

    # pick an appropriate response and generate the question
    if len(template_device_types) == 1:
        # single target: plain <device_name> placeholder in phrase and answer
        answer_template = get_random_response(
            piles.pile_of_responses,
            service=service_names[0],
            persona=persona,
            question_template=question_template,
            short=False
        )
        question = question_template.replace("<device_name>", chosen_devices[0]["description"])
        answer_list = [ answer_template.replace("<device_name>", chosen_devices[0]["description"]) ]
    else:
        # multiple targets: numbered placeholders <device_name1>, <device_name2>,
        # ... in the question; one short response per device, joined with each
        # localized "and" word to produce several answer variants
        question = question_template
        answers = []
        for i in range(len(template_device_types)):
            question = question.replace(f"<device_name{(i + 1)}>", chosen_devices[i]["description"])
            answer_response = get_random_response(
                piles.pile_of_responses,
                service=service_names[i],
                persona=persona,
                question_template=question_template,
                short=True
            )
            answers.append(answer_response.replace(f"<device_name>", chosen_devices[i]["description"]))
        answer_list = []
        for word in piles.and_words:
            answer_list.append(f" {word} ".join(answers))

    # generate the list of tool calls: service mode targets the entity id,
    # intent-tool mode targets the friendly name (description)
    tool_calls = []
    for device_dict, service in zip(chosen_devices, service_names):
        service_action = service.split(".")[1]
        tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
        tool_call = {
            "tool_name": tool_name,
            "service_name": service,
            "tool_args": {"entity_id" if use_service_names else "name": device_dict["device_name"] if use_service_names else device_dict["description"]}
        }
        tool_calls.append(tool_call)

    # fill the remaining placeholders in the question/answers and mirror each
    # value into the matching tool call's arguments
    if any(["climate" in service for service in service_names ]):
        if "<hvac_mode>" in question:
            hvac_mode = generate_random_parameter("hvac_mode", piles)
            question = question.replace("<hvac_mode>", hvac_mode)
            answer_list = replace_answer(answer_list, "<hvac_mode>", hvac_mode)
            # Add hvac_mode as temperature parameter for climate tool
            for call in tool_calls:
                if call["tool_name"] == TOOL_CLIMATE_SET_TEMPERATURE:
                    call["tool_args"]["hvac_mode"] = hvac_mode
        if "<fan_mode>" in question:
            fan_mode = generate_random_parameter("fan_mode", piles)
            question = question.replace("<fan_mode>", fan_mode)
            answer_list = replace_answer(answer_list, "<fan_mode>", fan_mode)
            for call in tool_calls:
                if call["tool_name"] == TOOL_CLIMATE_SET_TEMPERATURE:
                    call["tool_args"]["fan_mode"] = fan_mode
        if "<temp_f>" in question:
            temp_f = generate_random_parameter("temp_f", piles)
            question = question.replace("<temp_f>", str(temp_f))
            answer_list = replace_answer(answer_list, "<temp_f>", str(temp_f))
            for call in tool_calls:
                if call["tool_name"] == TOOL_CLIMATE_SET_TEMPERATURE:
                    call["tool_args"]["temperature"] = temp_f
        if "<temp_c>" in question:
            temp_c = generate_random_parameter("temp_c", piles)
            question = question.replace("<temp_c>", str(temp_c))
            answer_list = replace_answer(answer_list, "<temp_c>", str(temp_c))
            for call in tool_calls:
                if call["tool_name"] == TOOL_CLIMATE_SET_TEMPERATURE:
                    call["tool_args"]["temperature"] = temp_c
        if "<humidity>" in question:
            humidity = generate_random_parameter("humidity", piles)
            question = question.replace("<humidity>", str(humidity))
            answer_list = replace_answer(answer_list, "<humidity>", str(humidity))
            for call in tool_calls:
                if call["tool_name"] == TOOL_SET_HUMIDITY:
                    call["tool_args"]["humidity"] = humidity
    if any(["light" in service for service in service_names ]):
        if "<brightness>" in question:
            brightness = generate_random_parameter("brightness", piles)
            question = question.replace("<brightness>", str(brightness))
            answer_list = replace_answer(answer_list, "<brightness>", str(brightness))
            for call in tool_calls:
                if call["tool_name"] == TOOL_LIGHT_SET:
                    call["tool_args"]["brightness"] = brightness
        if "<color>" in question:
            # snap the random RGB triple to the nearest named CSS color
            random_rgb = generate_random_parameter("rgb_color", piles)
            random_rgb_name = closest_color(random_rgb)
            question = question.replace("<color>", str(random_rgb_name))
            answer_list = replace_answer(answer_list, "<color>", str(random_rgb_name))
            for call in tool_calls:
                if call["tool_name"] == TOOL_LIGHT_SET:
                    call["tool_args"]["color"] = random_rgb_name
    if any(["timer" in service for service in service_names ]):
        if "<duration>" in question:
            # question/answers show the human-readable name; the tool call
            # receives the raw duration value as a string
            duration = generate_random_parameter("duration", piles)
            duration_name = piles.pile_of_durations[duration]
            question = question.replace("<duration>", duration_name)
            answer_list = replace_answer(answer_list, "<duration>", duration_name)
            for call in tool_calls:
                if call["tool_name"] == TOOL_START_TIMER:
                    call["tool_args"]["duration"] = str(duration)
    if any(["todo" in service for service in service_names ]):
        if "<todo>" in question:
            todo = generate_random_parameter("todo", piles)
            question = question.replace("<todo>", todo)
            answer_list = replace_answer(answer_list, "<todo>", todo)
            for call in tool_calls:
                if call["tool_name"] == TOOL_LIST_ADD_ITEM:
                    call["tool_args"]["item"] = todo

    return {
        "states": device_list,
        "available_tools": available_tools,
        "question": question.lower(),
        "answers": [ sentence.lower() for sentence in answer_list ],
        "tool_calls": tool_calls
    }
def generate_status_request(template: dict, persona: str, language: str, max_devices: int = 32, return_target_device: bool = False, use_service_names: bool = False):
    """Build one status-question example (the output has no tool calls).

    Picks a random device of the template's type, substitutes templated
    variables into both the assistant answer and the device's state line so
    the two stay consistent, and returns the example dict (plus the chosen
    device when ``return_target_device`` is True). ``persona`` and
    ``use_service_names`` are accepted for signature parity with the other
    generators but are not used in this body.
    """
    device_type: str = template["device_type"]
    state_name: str = template["state"]
    question_template: str = template["phrase"]
    answer_template: str = template["assistant_response"]
    piles = get_dataset_piles(language)

    # choose a random device for this template
    chosen_device = random.choice(piles.stacks_of_device_names[device_type])

    # build a random list of devices
    device_list, device_types, extra_exposed_attributes = random_device_list(max_devices=max_devices, avoid_device_names=[ chosen_device["device_name"] ])

    # generate the question
    question = question_template.replace("<device_name>", chosen_device["description"])
    answer = answer_template.replace("<device_name>", chosen_device["description"])

    # insert other templated variables
    if device_type == "climate":
        climate_device_type = SUPPORTED_DEVICES["climate"]
        temp_f = climate_device_type.get_random_parameter("temp_f", language)
        answer = answer.replace("<temp_f>", str(temp_f))
        state_name = state_name.replace("<temp_f>", str(temp_f))
        temp_c = climate_device_type.get_random_parameter("temp_c", language)
        answer = answer.replace("<temp_c>", str(temp_c))
        state_name = state_name.replace("<temp_c>", str(temp_c))
        humidity = climate_device_type.get_random_parameter("humidity", language)
        answer = answer.replace("<humidity>", str(humidity))
        state_name = state_name.replace("<humidity>", str(humidity))
    if device_type == "light":
        light_device_type = SUPPORTED_DEVICES["light"]
        brightness = light_device_type.get_random_parameter("brightness", language)
        answer = answer.replace("<brightness>", str(brightness))
        state_name = state_name.replace("<brightness>", str(brightness))
        # snap the random RGB to the nearest named CSS color; the state line
        # carries both the color name and that name's canonical RGB triple,
        # while the spoken answer only uses the name
        random_rgb = light_device_type.get_random_parameter("rgb_color", language)
        random_rgb_name = closest_color(random_rgb)
        actual_random_rgb = webcolors.name_to_rgb(random_rgb_name)
        actual_random_rgb = (actual_random_rgb.red, actual_random_rgb.green, actual_random_rgb.blue)
        state_name = state_name.replace("<color>", str(random_rgb_name) + " " + str(actual_random_rgb))
        answer = answer.replace("<color>", str(random_rgb_name))
    if device_type == "media_player":
        media_player_device_type = SUPPORTED_DEVICES["media_player"]
        volume = media_player_device_type.get_random_parameter("volume", language)
        random_media = media_player_device_type.get_random_parameter("media", language)
        answer = answer.replace("<volume>", str(volume) + "%")
        state_name = state_name.replace("<volume>", str(volume) + "%")
        answer = answer.replace("<media>", random_media)
        state_name = state_name.replace("<media>", random_media)
    if device_type == "timer":
        timer_device_type = SUPPORTED_DEVICES["timer"]
        duration = timer_device_type.get_random_parameter("duration", language)
        duration_name = piles.pile_of_durations[duration]
        remaining = timer_device_type.get_random_parameter("remaining", language)
        answer = answer.replace("<duration>", duration_name)
        # NOTE(review): the state line uses the raw `duration` (the pile key)
        # while the answer uses the human-readable `duration_name`; `duration`
        # must already be a str for replace() to accept it -- confirm against
        # the timer pile format that both facts hold
        state_name = state_name.replace("<duration>", duration)
        answer = answer.replace("<remaining>", remaining)
        state_name = state_name.replace("<remaining>", remaining)

    # insert our target device somewhere random in the list
    index = random.randint(0, len(device_list))
    device_list.insert(index, format_device_line(
        device_name=chosen_device["device_name"],
        friendly_name=chosen_device["description"],
        state=state_name
    ))

    # gather a list of all available tools
    available_tools = []
    for x in set(device_types + [device_type]):
        available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))
    # Remove duplicates while preserving order
    available_tools = list(dict.fromkeys(available_tools))

    result = {
        "states": device_list,
        "available_tools": available_tools,
        "question": question.lower(),
        "answers": [ answer.lower() ],
        "tool_calls": []
    }
    if return_target_device:
        return result, chosen_device
    else:
        return result
def format_example_sharegpt(example, persona, language, use_system_role, use_service_names):
    """Convert a generated example dict into the ShareGPT conversation format.

    Produces {"conversations": [...], "tools": [...]} where the assistant turn
    holds one text block followed by one tool_use block per generated call.
    """
    piles = get_dataset_piles(language)
    sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
    question = example["question"]

    # assistant turn: text first, then the tool_use blocks
    assistant_content = [{
        "type": "text",
        "text": " ".join(example["answers"]),
    }]
    for tool_call in example["tool_calls"]:
        # service mode emits the HA service name, falling back to the tool
        # name when none is recorded; intent mode always emits the tool name
        if use_service_names:
            call_name = tool_call.get("service_name", tool_call["tool_name"])
        else:
            call_name = tool_call["tool_name"]
        assistant_content.append({
            "type": "tool_use",
            "name": call_name,
            "parameters": tool_call["tool_args"],
        })

    assistant_turn = {"role": "assistant", "content": assistant_content}
    if use_system_role:
        conversation = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": question},
            assistant_turn,
        ]
    else:
        # no system role: fold the system prompt into the user turn, separated
        # by the localized instruction phrase
        user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
        merged_user_turn = "\n".join([ sys_prompt, user_instruction_words, question ])
        conversation = [
            {"role": "user", "content": merged_user_turn},
            assistant_turn,
        ]

    return {
        "conversations": conversation,
        "tools": SERVICE_TOOLS if use_service_names else HASS_TOOLS,
    }
def generate_sft_file(filename: str, seed: int, format_func: Callable, use_system_role: bool, use_service_names: bool, personas: list[str], language: str, *, static_factor: float, template_factor: int, status_request_factor: int):
    """Generate one SFT dataset split and write it to ``output/<filename>.jsonl``.

    For every persona, each pile entry is expanded `factor` times (or included
    with probability `factor` when the factor is below 1), formatted through
    `format_func`, and written as one JSON record per line. Missing responses
    are collected and printed instead of aborting generation.
    """
    random.seed(seed)
    np.random.seed(seed)
    piles = get_dataset_piles(language)
    print("Generating...")

    def run_factor_times(func, examples, data, persona, factor, language):
        # factor >= 1: repeat `factor` times; factor < 1: include with
        # probability `factor`
        if factor >= 1:
            for i in range(factor):
                examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names))
        else:
            if random.random() < factor:
                examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, use_service_names))

    generated_examples = []
    missing_responses = set()
    for person in personas:
        for action in tqdm(piles.pile_of_specific_actions):
            try:
                run_factor_times(generate_static_example, generated_examples, action, person, static_factor, language)
            except NoResponseAvailableException as ex:
                missing_responses.add(str(ex))
        for templated_action in tqdm(piles.pile_of_templated_actions):
            try:
                run_factor_times(generate_templated_example, generated_examples, templated_action, person, template_factor, language)
            except NoResponseAvailableException as ex:
                missing_responses.add(str(ex))
        for status_request in tqdm(piles.pile_of_status_requests):
            # status requests always use the neutral "assistant" persona
            run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor, language)

    print(f"Generated {len(generated_examples)} examples. Saving...")
    for missing in sorted(missing_responses):
        print(missing)

    # BUG FIX: the output path previously hard-coded "(unknown)" instead of
    # interpolating the `filename` parameter, so every split overwrote the
    # same file and merge_languages could never find its per-language inputs.
    with open(f"output/{filename}.jsonl", "w") as f:
        for item in generated_examples:
            f.write(json.dumps(item) + '\n')
    print("Done!")
def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset_column_names, format_func):
    """Mix the generated Home Assistant data with an external hub dataset.

    Loads `dataset_name`, splits off 10% for test, reformats it with
    `format_function` (dropping `dataset_column_names`), concatenates each
    split with the local home_assistant_{train,test}.jsonl files, shuffles,
    and writes the merged splits. (`format_func` is accepted for signature
    compatibility but is not used by this body.)
    """
    alpaca_dataset = load_dataset(dataset_name)["train"].train_test_split(test_size=0.1)
    home_assistant_dataset = load_dataset("json", data_files={ "train": "home_assistant_train.jsonl", "test": "home_assistant_test.jsonl" })

    random.seed(seed)
    np.random.seed(seed)

    alpaca_dataset = alpaca_dataset.map(format_function).remove_columns(dataset_column_names)

    # merge, shuffle, and persist each split in turn
    for split in ("train", "test"):
        merged_split = concatenate_datasets([home_assistant_dataset[split], alpaca_dataset[split]]).shuffle(seed=42)
        merged_split.to_json(f"home_assistant_{output_name}_merged_{split}.jsonl")
def merge_languages(filename_prefix: str, languages: list):
    """Concatenate the per-language JSONL files into one combined output file.

    Reads ``output/<prefix>_<language>.jsonl`` for each language, in order,
    and writes every line to ``output/<prefix>.jsonl``.
    """
    merged_lines = []
    for lang in languages:
        with open(f"output/{filename_prefix}_{lang}.jsonl") as source:
            merged_lines += source.readlines()
    with open(f"output/{filename_prefix}.jsonl", "w") as sink:
        sink.writelines(merged_lines)
# TODO: add examples for ambiguous requests. asking a clarifying question
# TODO: support rejection when asking to do a service that isn't exposed
# TODO: make more randomized names for devices (random words or people's names)
# TODO: answer questions about more than one thing in the state list at once
# TODO: add examples for rooms/groups of devices. i.e. "turn off all the lights in the kitchen"
# TODO: add time, weather, and calendar/reminders (next 3 events?)
def main(args=None):
    """CLI entry point: parse flags and generate the requested dataset splits.

    Generates sample/train/test splits per language; when several languages
    are requested, per-language files are written first and then merged.
    """
    parser = argparse.ArgumentParser(description="Generate the full dataset from the CSV piles")
    # NOTE: help strings fixed -- the originals were all copy-pasted from --train
    parser.add_argument("--sample", action="store_true", help="Set this flag to enable generation of a small sample dataset.")
    parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the test dataset.")
    parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
    parser.add_argument("--language", nargs="+", default=["english"], help="List of languages to generate: english, german, french, spanish, polish")
    parser.add_argument("--no-system-role", action="store_true", help="Set this flag to disable the system role. It will be combined with the user role")
    train_size_group = parser.add_mutually_exclusive_group()
    train_size_group.add_argument('--small', action='store_const', const='small', dest='size')
    train_size_group.add_argument('--medium', action='store_const', const='medium', dest='size')
    train_size_group.add_argument('--large', action='store_const', const='large', dest='size')
    train_size_group.add_argument('--xl', action='store_const', const='xl', dest='size')
    parser.add_argument('--use-service-names', action='store_true',
                        help='Use service names (e.g., light.turn_on) instead of intent tool names (e.g., HassTurnOn)')
    args = parser.parse_args(args=args)

    # BUG FIX: the original also tested `args.merge`, but no --merge flag is
    # defined, so running with no flags raised AttributeError instead of
    # printing the usage message.
    if not args.sample and not args.train and not args.test:
        parser.print_usage()
        exit(-1)
    if args.size and not args.train:
        print("Train size was provided but not generating the training set!")
        exit(-1)

    format_func = format_example_sharegpt
    use_system_role = not args.no_system_role
    use_service_names = args.use_service_names

    for language in args.language:
        piles = get_dataset_piles(language)
        personas = list(piles.pile_of_system_prompts.keys())
        # with multiple languages, each split gets a per-language suffix that
        # merge_languages later combines
        suffix = f"_{language}" if len(args.language) > 1 else ""
        if args.sample:
            generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, static_factor=1, template_factor=1, status_request_factor=1)
        if args.train:
            if args.size == "small":
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, static_factor=1, template_factor=10, status_request_factor=8)
            elif args.size == "medium":
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, static_factor=5, template_factor=15, status_request_factor=12)
            elif args.size == "large":
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, static_factor=5, template_factor=20, status_request_factor=15)
            elif args.size == "xl":
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, use_service_names, personas, language, static_factor=7, template_factor=25, status_request_factor=18)
            else:
                raise Exception(f"Unrecognized dataset size: {args.size}")
        if args.test:
            generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, use_service_names, personas, language, static_factor=0.25, template_factor=1, status_request_factor=2)

    if len(args.language) > 1:
        if args.sample:
            merge_languages("sample", args.language)
        if args.train:
            merge_languages("home_assistant_train", args.language)
        if args.test:
            merge_languages("home_assistant_test", args.language)
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()