mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 21:28:05 -05:00
dataset tweak to remove duplicates in states block + add actual test split
This commit is contained in:
@@ -3,6 +3,7 @@ import csv
|
||||
import enum
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Final, Any
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -182,32 +183,54 @@ with open("pile_of_responses.csv") as f:
|
||||
pile_of_responses[raw["device_type"]][raw["service"]] = [ raw["response_1"], raw["response_2"], raw["response_3"] ]
|
||||
|
||||
# generate a random list of devices for the context
|
||||
def random_device_list(max_devices: int, avoid_device_names: list[str]) -> tuple[list[str], list[str]]:
    """Build a random list of device state lines for the context.

    Args:
        max_devices: Upper bound on how many devices to generate. The target
            count is drawn uniformly from ``2..max_devices`` (or exactly
            ``max_devices`` when it is smaller than 2).
        avoid_device_names: Entity ids (``"<type>.<name>"``) whose look-alikes
            must not be generated. Candidates of the same type whose name is
            too similar (``SequenceMatcher`` ratio >= 0.4) are excluded.

    Returns:
        A tuple ``(device_lines, device_types)``: the ``"<device_name> - <state>"``
        lines that were generated, and the distinct device types among them.
        Fewer lines than the drawn target are returned when the candidate
        pool runs out.

    Raises:
        KeyError: If an avoided device's type is not a key of
            ``stacks_of_device_names``.
    """
    # Keep the lower bound valid even when the caller asks for fewer than two devices.
    num_devices = random.randint(min(2, max_devices), max_devices)

    # Copy each per-type list so filtering never mutates the shared stacks.
    local_device_names = {k: v[:] for k, v in stacks_of_device_names.items()}

    for avoid_device in avoid_device_names:
        # Compare name against name: leaving the "<type>." prefix on only one
        # side of the comparison skews the similarity ratio.
        avoid_type, _, avoid_name = avoid_device.partition(".")

        local_device_names[avoid_type] = [
            possible_device
            for possible_device in local_device_names[avoid_type]
            if SequenceMatcher(None, avoid_name, possible_device["device_name"].split(".")[1]).ratio() < 0.4
        ]

    possible_choices = [
        device
        for devices_of_type in local_device_names.values()
        for device in devices_of_type
    ]

    device_types = set()
    seen_device_names = set()
    device_lines = []

    # Draw without replacement so the loop always terminates, even when the
    # pool holds fewer usable devices than requested.
    while possible_choices and len(device_lines) < num_devices:
        choice = possible_choices.pop(random.randrange(len(possible_choices)))
        device_name = choice["device_name"]
        if device_name in seen_device_names:
            continue

        device_type = device_name.split(".")[0]
        try:
            device_info = SUPPORTED_DEVICES[device_type]
        except KeyError:
            # Unknown device type: report it and move on to another candidate.
            print(f"bad device name: {choice}")
            continue

        state = device_info.get_random_state()
        device_lines.append(f"{device_name} - {state}")
        seen_device_names.add(device_name)
        device_types.add(device_type)

    return device_lines, list(device_types)
|
||||
|
||||
def generate_static_example(action: dict):
|
||||
def generate_static_example(action: dict, max_devices: int = 32):
|
||||
question = action["english_phrase"]
|
||||
target_devices = action["device_name"].split("|")
|
||||
service_names = action["service_name"].split("|")
|
||||
|
||||
device_list, device_types = random_device_list(max_devices=32)
|
||||
device_list, device_types = random_device_list(max_devices=max_devices, avoid_device_names=target_devices)
|
||||
|
||||
# insert our target device somewhere random in the list
|
||||
for device in target_devices:
|
||||
@@ -239,27 +262,27 @@ def generate_static_example(action: dict):
|
||||
"service_calls": service_calls
|
||||
}
|
||||
|
||||
def generate_templated_example(template: dict):
|
||||
def generate_templated_example(template: dict, max_devices: int = 32):
|
||||
template_device_types: list[str] = template["device_type"].split("|")
|
||||
service_names: list[str] = template["service"].split("|")
|
||||
question_template: str = template["english_phrase"]
|
||||
answer_template: str = template["assistant_response"]
|
||||
|
||||
device_list, device_types = random_device_list(max_devices=32)
|
||||
|
||||
# insert our target device somewhere random in the list
|
||||
# choose a random device for this template
|
||||
chosen_devices = []
|
||||
for device_type in template_device_types:
|
||||
device_dict = random.choice(stacks_of_device_names[device_type])
|
||||
device_dict["type"] = device_type
|
||||
chosen_devices.append(device_dict)
|
||||
|
||||
device = device_dict["device_name"]
|
||||
device_list, device_types = random_device_list(max_devices=max_devices, avoid_device_names=[d["device_name"] for d in chosen_devices])
|
||||
|
||||
# insert our target device somewhere random in the list
|
||||
for device_dict in chosen_devices:
|
||||
index = random.randint(0, len(device_list))
|
||||
state = SUPPORTED_DEVICES[device_type].get_random_state()
|
||||
state = SUPPORTED_DEVICES[device_dict["type"]].get_random_state()
|
||||
|
||||
device_list.insert(index, f"{device} - {state}")
|
||||
device_list.insert(index, f"{device_dict['device_name']} - {state}")
|
||||
|
||||
# gather a list of all available services
|
||||
available_services = set()
|
||||
@@ -300,30 +323,31 @@ def format_example(example):
|
||||
return "\n".join([sys_prompt, services_block, states_block, example["question"], answers, code_block])
|
||||
|
||||
|
||||
# TODO: add examples for ambiguous requests. asking a clarifying question
|
||||
# TODO: add examples for rooms/groups of devices. i.e. "turn off all the lights in the kitchen"
|
||||
def main():
|
||||
random.seed(42)
|
||||
def generate_example_file(filename: str, seed: int, *, static_factor: int, template_factor: int):
    """Generate formatted examples and save them to ``<filename>.json``.

    The global ``random`` generator is seeded with ``seed`` first. Every entry
    of ``pile_of_device_actions`` is rendered ``static_factor`` times and every
    entry of ``pile_of_templated_actions`` ``template_factor`` times. Each
    rendered example is stored as ``{"text": <formatted example>}``.
    """
    random.seed(seed)

    print("Generating...")

    records = []

    for static_action in tqdm(pile_of_device_actions):
        records.extend(
            {"text": format_example(generate_static_example(static_action))}
            for _ in range(static_factor)
        )

    for template in tqdm(pile_of_templated_actions):
        records.extend(
            {"text": format_example(generate_templated_example(template))}
            for _ in range(template_factor)
        )

    print(f"Generated {len(records)} examples. Saving...")

    with open(f"{filename}.json", "w") as out_file:
        json.dump(records, out_file, indent=4)

    print("Done!")
|
||||
|
||||
# TODO: add examples for ambiguous requests. asking a clarifying question
|
||||
# TODO: add examples for rooms/groups of devices. i.e. "turn off all the lights in the kitchen"
|
||||
def main():
    """Generate the train and test example files.

    Writes ``home_assistant_train.json`` and ``home_assistant_test.json`` via
    ``generate_example_file``.
    """
    generate_example_file("home_assistant_train", 42, static_factor=3, template_factor=40)
    # generate_example_file reseeds the global RNG, so the test split needs its
    # own seed: reusing 42 would restart from the same random state and make
    # the test file begin with the same example as the train file.
    generate_example_file("home_assistant_test", 12345, static_factor=1, template_factor=3)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
1901
data/home_assistant_test.json
Normal file
1901
data/home_assistant_test.json
Normal file
File diff suppressed because it is too large
Load Diff
16022
data/home_assistant_train.json
Normal file
16022
data/home_assistant_train.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,8 +0,0 @@
|
||||
You are a helpful AI Assistant that controls the devices in a house. Complete the following task ask instructed with the information provided only.
|
||||
Services: turn_off, media_previous_track, volume_up, decrease_speed, toggle, media_pause, media_play_pause, stop_cover, media_next_track, lock, volume_down, increase_speed, close_cover, turn_on, unlock, media_play, volume_mute, media_stop, open_cover
|
||||
States:
|
||||
light.master_bedroom_lamp - off
|
||||
fan.kitchen - off
|
||||
light.kitchen_1 - off
|
||||
light.bedroom_guest - off
|
||||
turn off the light in the kitchen
|
||||
3
train.py
3
train.py
@@ -29,8 +29,7 @@ def tokenize_function(example):
|
||||
return result
|
||||
|
||||
|
||||
datasets = load_dataset("json", data_files="data/home_assistant_examples.json")
|
||||
datasets = datasets["train"].train_test_split(test_size=0.1)
|
||||
datasets = load_dataset("json", data_files={ "train": "data/home_assistant_train.json", "test": "data/home_assistant_test.json" })
|
||||
tokenized_train_dataset = datasets["train"].map(tokenize_function, remove_columns=datasets["train"].column_names)
|
||||
tokenized_test_dataset = datasets["test"].map(tokenize_function, remove_columns=datasets["test"].column_names)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user