From eef7c1b9328943fab0297a25acea0557c8bbe4e1 Mon Sep 17 00:00:00 2001
From: Alex O'Connell
Date: Thu, 11 Jan 2024 19:05:28 -0500
Subject: [PATCH] add RGB + brightness to lighting requests and clean up
 dataset generation script

---
 data/README.md                           |   6 +-
 data/generate_home_assistant_data.py     | 109 +++++++++++++++++++----
 data/merge_with_alpaca.py                |  43 ---------
 data/piles/pile_of_templated_actions.csv |  25 +++++-
 requirements.txt                         |   3 +-
 5 files changed, 119 insertions(+), 67 deletions(-)
 delete mode 100644 data/merge_with_alpaca.py

diff --git a/data/README.md b/data/README.md
index fb1ea01..3a9d68d 100644
--- a/data/README.md
+++ b/data/README.md
@@ -1,9 +1,11 @@
 # Dataset
 
+The dataset is generated from the different CSV "piles". The "piles" contain different chunks of requests that are assembled into a final context that is presented to the LLM. For example, `piles/pile_of_device_names.csv` contains only names of various devices to be used as part of context as well as inserted into `piles/pile_of_templated_actions.csv` and `piles/pile_of_status_requests.csv`. The logic for assembling the final dataset from the piles is contained in [generate_home_assistant_data.py](./generate_home_assistant_data.py).
+
 ## Generating the custom dataset
 
-`python3 generate_home_assistant_data.py`
+`python3 generate_home_assistant_data.py --train --test`
 
 ## Merging with Alpaca for training
 
-`python3 merge_with_alpaca.py`
\ No newline at end of file
+`python3 generate_home_assistant_data.py --merge-alpaca`
\ No newline at end of file
diff --git a/data/generate_home_assistant_data.py b/data/generate_home_assistant_data.py
index 7272c55..8dce376 100644
--- a/data/generate_home_assistant_data.py
+++ b/data/generate_home_assistant_data.py
@@ -1,8 +1,9 @@
+import argparse
 import json
 import csv
-import enum
 import random
 from dataclasses import dataclass
+from datasets import load_dataset, concatenate_datasets
 from difflib import SequenceMatcher
 from typing import Final, Any
 from tqdm import tqdm
@@ -30,6 +31,16 @@ STATE_UNAVAILABLE: Final = "unavailable"
 STATE_OK: Final = "ok"
 STATE_PROBLEM: Final = "problem"
 
+def closest_color(requested_color):
+    min_colors = {}
+    for key, name in webcolors.CSS3_HEX_TO_NAMES.items():
+        r_c, g_c, b_c = webcolors.hex_to_rgb(key)
+        rd = (r_c - requested_color[0]) ** 2
+        gd = (g_c - requested_color[1]) ** 2
+        bd = (b_c - requested_color[2]) ** 2
+        min_colors[(rd + gd + bd)] = name
+    return min_colors[min(min_colors.keys())]
+
 @dataclass
 class DeviceType:
     name: str
@@ -55,22 +66,17 @@ class LightDeviceType(DeviceType):
         ],
     )
 
-    def closest_color(requested_color):
-        min_colors = {}
-        for key, name in webcolors.CSS3_HEX_TO_NAMES.items():
-            r_c, g_c, b_c = webcolors.hex_to_rgb(key)
-            rd = (r_c - requested_color[0]) ** 2
-            gd = (g_c - requested_color[1]) ** 2
-            bd = (b_c - requested_color[2]) ** 2
-            min_colors[(rd + gd + bd)] = name
-        return min_colors[min(min_colors.keys())]
 
-    def get_random_state(self, force_rgb=False):
+
+    def get_random_state(self, force_rgb=False, force_brightness=False):
         state = super().get_random_state()
 
         if random.random() < 0.05 or force_rgb:
             random_rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
-            state = state + ";" + self.closest_color(random_rgb) + ";" + random_rgb
+            state = state + ";" + closest_color(random_rgb) + " " + str(random_rgb)
+
+        if random.random() < 0.3 or force_brightness:
+            state = state + ";" + str(random.randint(0, 100)) + "%"
 
         return state
 
@@ -395,6 +401,23 @@ def generate_templated_example(template: dict, max_devices: int = 32):
         answer = answer.replace("<humidity>", str(humidity))
         service_calls = [ { **call, "humidity": humidity} for call in service_calls ]
 
+    if any(["light" in service for service in service_names ]):
+        if "<brightness>" in question:
+            brightness = random.randint(0, 100)
+            question = question.replace("<brightness>", str(brightness))
+            answer = answer.replace("<brightness>", str(brightness))
+            service_calls = [ { **call, "brightness_pct": brightness} for call in service_calls ]
+
+        if "<color>" in question:
+            random_rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+            random_rgb_name = closest_color(random_rgb)
+            actual_random_rgb = webcolors.name_to_rgb(random_rgb_name)
+            actual_random_rgb = (actual_random_rgb.red, actual_random_rgb.green, actual_random_rgb.blue)
+            question = question.replace("<color>", str(random_rgb_name))
+            answer = answer.replace("<color>", str(random_rgb_name))
+            service_calls = [ { **call, "rgb_color": str(actual_random_rgb) } for call in service_calls ]
+
+
     return {
         "states": device_list,
         "available_services": list(available_services),
@@ -496,16 +519,68 @@ def generate_example_file(filename: str, seed: int, *, static_factor: int, templ
     print("Done!")
 
+def format_alpaca(example):
+    question = example["instruction"]
+    if example["input"]:
+        question = question + "\n" + example["input"]
+
+    answer = example["output"]
+
+    device_list, device_types = random_device_list(max_devices=32, avoid_device_names=[])
+
+    available_services = []
+    for x in device_types:
+        available_services.extend([ f"{x}.{y}" for y in SUPPORTED_DEVICES[x].services ])
+
+    text = format_example(example={
+        "states": device_list,
+        "available_services": list(available_services),
+        "question": question,
+        "answers": [ answer ],
+        "service_calls": []
+    })
+
+    result = {
+        "text": text
+    }
+
+    return result
+
+def merge_with_dataset(dataset_name, seed, output_name, format_function):
+    alpaca_dataset = load_dataset(dataset_name)["train"].train_test_split(test_size=0.1)
+    home_assistant_dataset = load_dataset("json", data_files={ "train": "home_assistant_train.json", "test": "home_assistant_test.json" })
+
+    random.seed(seed)
+
+    alpaca_dataset = alpaca_dataset.map(format_function).remove_columns(["input", "output", "instruction"])
+
+    combined_dataset_train = concatenate_datasets([home_assistant_dataset["train"], alpaca_dataset["train"]]).shuffle(seed=42)
+    combined_dataset_test = concatenate_datasets([home_assistant_dataset["test"], alpaca_dataset["test"]]).shuffle(seed=42)
+
+    combined_dataset_train.to_json(f"home_assistant_{output_name}_merged_train.json")
+    combined_dataset_test.to_json(f"home_assistant_{output_name}_merged_test.json")
+
+
 # TODO: add examples for ambiguous requests. asking a clarifying question
 # TODO: make more randomized names for devices (random words or people's names)
 # TODO: answer questions about more than one thing in the state list at once
 # TODO: add examples for rooms/groups of devices. i.e. "turn off all the lights in the kitchen"
"turn off all the lights in the kitchen" -# TODO: setup argparse to configure which partition to generate -# TODO: merge this with the alpaca merge script + support other datasets to merge with def main(): - generate_example_file("sample", 42, static_factor=1, template_factor=1, status_request_factor=1) - # generate_example_file("home_assistant_train", 42, static_factor=5, template_factor=20, status_request_factor=15) - # generate_example_file("home_assistant_test", 12345, static_factor=0.25, template_factor=3, status_request_factor=2) + parser = argparse.ArgumentParser(description="Generate the full dataset from the CSV piles") + parser.add_argument("--sample", action="store_true", help="Set this flag to enable generation of the train dataset.") + parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset..") + parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.") + parser.add_argument("--merge-alpaca", action="store_true", help="Set this flag to merge the generated datasets with the alpaca-cleaned dataset.") + args = parser.parse_args() + + if args.sample: + generate_example_file("sample", 42, static_factor=1, template_factor=1, status_request_factor=1) + if args.train: + generate_example_file("home_assistant_train", 42, static_factor=5, template_factor=20, status_request_factor=15) + if args.test: + generate_example_file("home_assistant_test", 12345, static_factor=0.25, template_factor=3, status_request_factor=2) + if args.merge_alpaca: + merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca) if __name__ == "__main__": main() \ No newline at end of file diff --git a/data/merge_with_alpaca.py b/data/merge_with_alpaca.py deleted file mode 100644 index da843bb..0000000 --- a/data/merge_with_alpaca.py +++ /dev/null @@ -1,43 +0,0 @@ -import random -from datasets import load_dataset, concatenate_datasets -from generate_home_assistant_data import format_example, random_device_list, SUPPORTED_DEVICES - -alpaca_dataset = load_dataset("yahma/alpaca-cleaned")["train"].train_test_split(test_size=0.1) -home_assistant_dataset = load_dataset("json", data_files={ "train": "home_assistant_train.json", "test": "home_assistant_test.json" }) - -random.seed(42) - -def format_alpaca(example): - question = example["instruction"] - if example["input"]: - question = question = "\n" + example["input"] - - answer = example["output"] - - device_list, device_types = random_device_list(max_devices=32, avoid_device_names=[]) - - available_services = [] - for x in device_types: - available_services.extend([ f"{x}.{y}" for y in SUPPORTED_DEVICES[x].services ]) - - text = format_example(example={ - "states": device_list, - "available_services": list(available_services), - "question": question, - "answers": [ answer ], - "service_calls": [] - }) - - result = { - "text": text - } - - return result - -alpaca_dataset = alpaca_dataset.map(format_alpaca).remove_columns(["input", "output", "instruction"]) - -combined_dataset_train = concatenate_datasets([home_assistant_dataset["train"], alpaca_dataset["train"]]).shuffle(seed=42) -combined_dataset_test = concatenate_datasets([home_assistant_dataset["test"], alpaca_dataset["test"]]).shuffle(seed=42) - -combined_dataset_train.to_json("home_assistant_alpaca_merged_train.json") -combined_dataset_test.to_json("home_assistant_alpaca_merged_test.json") \ No newline at end of file diff --git a/data/piles/pile_of_templated_actions.csv 
index 6448678..bf04f0c 100644
--- a/data/piles/pile_of_templated_actions.csv
+++ b/data/piles/pile_of_templated_actions.csv
@@ -183,14 +183,31 @@ climate,set_temperature,"I want the room cooler, set it to <temp_f> degrees.","S
 climate,set_temperature,"Make it warmer, set temperature at <temp_f> degrees.","Making it warmer, setting temperature to <temp_f> degrees."
 climate,set_temperature,"Can you lower the temperature to <temp_c>?","Lowering the temperature to <temp_c> Celsius."
 climate,set_temperature,"Raise the temperature to <temp_f> degrees, please.","Raising the temperature to <temp_f> degrees Fahrenheit."
-climate,set_humidity,"Increase the humidity to <humidity>%.","Increasing humidity to <humidity>%."
 climate,set_humidity,"Set the humidity level to <humidity> percent.","Setting humidity to <humidity> percent."
-climate,set_humidity,"Can you adjust the humidity to <humidity>%?","Adjusting humidity to <humidity>%."
 climate,set_fan_mode,"Set the fan to high speed.","Setting the fan to high speed."
 climate,set_fan_mode,"Please put the fan on low.","Putting the fan on low."
 climate,set_fan_mode,"Change the fan setting to medium.","Changing the fan to medium setting."
 climate,set_hvac_mode,"Switch the system to cooling mode.","Switching to cooling mode."
 climate,set_hvac_mode,"Can we set the HVAC to heat?","Setting the HVAC to heat."
 climate,set_hvac_mode,"Change the HVAC to automatic.","Changing HVAC to automatic mode."
-
-# TODO: add requests to set the color of lights
\ No newline at end of file
+light,turn_on,"Set the brightness of <device_name> to <brightness>%.","Setting the brightness of <device_name> to <brightness>%."
+light,turn_on,"Can you make <device_name> <color>?","Turning <device_name> <color>."
+light,turn_on,"Change the color of <device_name> to <color>.","Changing the color of <device_name> to <color>."
+light,turn_on,"Dim <device_name> to <brightness> percent brightness.","Dimming <device_name> to <brightness>% brightness."
+light,turn_on,"Please set <device_name> to a <color> color.","Setting <device_name> to a <color> color."
+light,turn_on,"Brighten <device_name> to <brightness>.","Brightening <device_name> to <brightness>%."
+light,turn_on,"Turn <device_name> <color>.","Turning <device_name> <color>."
+light,turn_on,"I want <device_name> at a <color> setting.","Setting <device_name> to a <color>."
+light,turn_on,"Set <device_name> to a <color> color.","Setting <device_name> to a <color> color."
+light,turn_on,"Adjust <device_name> brightness to <brightness>.","Adjusting <device_name> brightness to <brightness>%."
+light,turn_on,"Increase <device_name>'s brightness to <brightness>.","Increasing <device_name>'s brightness to <brightness>%."
+light,turn_on,"Lower the brightness of <device_name> to <brightness>.","Lowering the brightness of <device_name> to <brightness>%."
+light,turn_on,"Can you set <device_name>'s brightness level to <brightness> percent?","Setting <device_name>'s brightness level to <brightness>%."
+light,turn_on,"I'd like <device_name> at <brightness> percent brightness, please.","Setting <device_name> to <brightness>% brightness."
+light,turn_on,"Change <device_name> to a <color> hue.","Changing <device_name> to a <color> hue."
+light,turn_on,"Set <device_name> to be <color>.","Setting <device_name> to be <color>."
+light,turn_on,"I want <device_name> to be <color>, please.","Setting <device_name> to be <color>."
+light,turn_on,"Can you make <device_name> shine in <color>?","Making <device_name> shine in <color>."
+light,turn_on,"Turn <device_name> to a <color> shade.","Turning <device_name> to a <color> shade."
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 0158e8a..8ccd2ee 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,5 @@ transformers
 tensorboard
 peft
 bitsandbytes
-datasets
\ No newline at end of file
+datasets
+webcolors
\ No newline at end of file
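
For reference, a minimal sketch (not part of the patch) of the color snapping that the new closest_color helper and the <color> substitution rely on. It assumes a webcolors version that still exposes the CSS3_HEX_TO_NAMES mapping used by the patch:

import random
import webcolors

def closest_color(requested_color):
    # Nearest CSS3 named color by squared distance in RGB space,
    # mirroring the helper the patch moves to module level.
    min_colors = {}
    for key, name in webcolors.CSS3_HEX_TO_NAMES.items():
        r_c, g_c, b_c = webcolors.hex_to_rgb(key)
        rd = (r_c - requested_color[0]) ** 2
        gd = (g_c - requested_color[1]) ** 2
        bd = (b_c - requested_color[2]) ** 2
        min_colors[(rd + gd + bd)] = name
    return min_colors[min(min_colors.keys())]

random_rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
name = closest_color(random_rgb)         # e.g. "darkslateblue"
snapped = webcolors.name_to_rgb(name)    # exact RGB of that named color, as used for rgb_color
print(name, (snapped.red, snapped.green, snapped.blue))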
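
And a short sketch of how one of the new light templates could be filled in. The template row, device name, and the shape of the service-call dict are made up for illustration; only the brightness_pct key and the substitution pattern match what the patch adds to generate_templated_example:

import random

# Hypothetical template row: light,turn_on,"Set the brightness of <device_name> to <brightness>%.",...
question = "Set the brightness of <device_name> to <brightness>%."
answer = "Setting the brightness of <device_name> to <brightness>%."
service_calls = [{"service": "light.turn_on", "target_device": "light.kitchen_lamp"}]  # assumed shape

# Device-name substitution (drawn from pile_of_device_names.csv in the real script).
question = question.replace("<device_name>", "kitchen lamp")
answer = answer.replace("<device_name>", "kitchen lamp")

# Brightness substitution, mirroring the block the patch adds for light services.
if "<brightness>" in question:
    brightness = random.randint(0, 100)
    question = question.replace("<brightness>", str(brightness))
    answer = answer.replace("<brightness>", str(brightness))
    service_calls = [{**call, "brightness_pct": brightness} for call in service_calls]

print(question)       # e.g. "Set the brightness of kitchen lamp to 73%."
print(service_calls)  # the brightness_pct value matches the number in the question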