diff --git a/data/devices.py b/data/devices.py index 1d2dcdd..5f04410 100644 --- a/data/devices.py +++ b/data/devices.py @@ -1,10 +1,11 @@ +from collections import defaultdict import random from dataclasses import dataclass from typing import Final, Callable, List from difflib import SequenceMatcher from tools import * -from utils import closest_color, generate_random_parameter, get_dataset_piles +from utils import PileOfDeviceType, closest_color, generate_random_parameter, get_dataset_piles # STATES STATE_ON: Final = "on" @@ -40,7 +41,7 @@ class DeviceType: name: str possible_states: list[tuple[str, float]] - def get_random_state(self, language: str, extra_exposed_attributes=[]): + def get_random_state(self, language: str, extra_exposed_attributes: list[str] | None = None): states = [ x[0] for x in self.possible_states ] weights = [ x[1] for x in self.possible_states ] return random.choices(states, weights=weights, k=1)[0] @@ -64,7 +65,8 @@ class LightDeviceType(DeviceType): ] ) - def get_random_state(self, language: str, extra_exposed_attributes=[]): + def get_random_state(self, language: str, extra_exposed_attributes: list[str] | None = None): + extra_exposed_attributes = extra_exposed_attributes or [] state = super().get_random_state(language, extra_exposed_attributes=extra_exposed_attributes) if random.random() < 0.5 and "rgb_color" in extra_exposed_attributes: @@ -171,8 +173,9 @@ class ClimateDeviceType(DeviceType): def __init__(self): super().__init__("climate", []) - def get_random_state(self, language: str, extra_exposed_attributes=[]): + def get_random_state(self, language: str, extra_exposed_attributes: list[str] | None = None): """state;fan_mode;temperature;humidity""" + extra_exposed_attributes = extra_exposed_attributes or [] state = generate_random_parameter("hvac_mode", get_dataset_piles(language)) if "fan_mode" in extra_exposed_attributes: @@ -211,7 +214,8 @@ class MediaPlayerDeviceType(DeviceType): (STATE_BUFFERING, 0.01), ]) - def get_random_state(self, language: str, extra_exposed_attributes=[]): + def get_random_state(self, language: str, extra_exposed_attributes: list[str] | None = None): + extra_exposed_attributes = extra_exposed_attributes or [] state = super().get_random_state(language, extra_exposed_attributes=extra_exposed_attributes) if "media_title" in extra_exposed_attributes and state in [STATE_PLAYING, STATE_PAUSED, STATE_BUFFERING, STATE_ON]: @@ -228,7 +232,7 @@ class MediaPlayerDeviceType(DeviceType): return tools -SUPPORTED_DEVICES = { +SUPPORTED_DEVICES: dict[str, DeviceType] = { "light": LightDeviceType(), "switch": SwitchDeviceType(), "fan": FanDeviceType(), @@ -264,14 +268,14 @@ def random_device_list(max_devices: int, avoid_device_names: list[str], language if avoid_type == "climate": avoid_climate = True - possible_choices = [] + possible_choices: list[PileOfDeviceType] = [] for device_type in local_device_names.keys(): possible_choices.extend(local_device_names[device_type]) - device_types = set() + device_types: set[str] = set() device_list = [] - device_lines = [] + device_lines: list[str] = [] # TODO: randomly pick attributes for this list extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "duration", "remaining", "item"] @@ -301,4 +305,4 @@ def random_device_list(max_devices: int, avoid_device_names: list[str], language print(f"bad device name: {choice}") print(repr(ex)) - return device_lines, list(device_types), list(extra_exposed_attributes) \ No newline at end of file + return device_lines, list(device_types), list(extra_exposed_attributes) diff --git a/data/generate_data.py b/data/generate_data.py index eba4f0b..edd3d35 100644 --- a/data/generate_data.py +++ b/data/generate_data.py @@ -3,7 +3,7 @@ import json import numpy as np import random from datasets import load_dataset, concatenate_datasets -from typing import Callable +from typing import Any, Callable from tqdm import tqdm import webcolors @@ -17,10 +17,19 @@ from devices import SUPPORTED_DEVICES, format_device_line, random_device_list, \ TOOL_LIGHT_SET, TOOL_START_TIMER, TOOL_LIST_ADD_ITEM, SERVICE_TO_TOOL_MAP, \ HASS_TOOLS, SERVICE_TOOLS from prompting import generate_system_prompt, USER_INSTRUCTION_PROMPT -from utils import get_random_response, generate_random_parameter, closest_color, \ +from utils import AssistantTurn, DatasetEntry, Example, PileOfDeviceType, PileOfFailedToolcallType, PileOfRefusalsType, PileOfSpecificActionType, PileOfStatusRequestType, PileOfTemplatedActionType, ToolCall, ToolResult, get_random_response, generate_random_parameter, closest_color, \ get_dataset_piles, NoResponseAvailableException -def generate_static_example(action: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False): +def create_assistant_turn(answer: str, tool_call_sequence: list[ToolCall] | None = None, *, tool_results: list[ToolResult] | None = None, train_on_turn: bool = True) -> AssistantTurn: + """Bundle the assistant utterance with any tool interaction for that turn.""" + return { + "answer": answer, + "tool_call_sequence": tool_call_sequence or [], + "tool_results": tool_results if tool_results is not None else [], + "train_on_turn": train_on_turn, + } + +def generate_static_example(action: PileOfSpecificActionType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example: question = action["phrase"] service_name = action["service_name"] device_type = service_name.split(".")[0] @@ -42,7 +51,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic )) # gather a list of all available tools - available_tools = [] + available_tools: list[str] = [] for x in set(device_types + [device_type]): available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes)) @@ -130,13 +139,13 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic tool_args["item"] = todo if use_service_names: - tool_call = { + tool_call: ToolCall = { "tool_name": tool_name, "service_name": service_name, "tool_args": {"entity_id": target_device, **tool_args} } else: - tool_call = { + tool_call: ToolCall = { "tool_name": tool_name, "service_name": service_name, "tool_args": {"name": target_device, **tool_args} @@ -163,33 +172,22 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic "assistant_turns": assistant_turns } -def replace_answer(list_of_answer, var, value): - new_list = [] +def replace_answer(list_of_answer: list[str], var: str, value: str): + new_list: list[str] = [] for answer in list_of_answer: new_list.append(answer.replace(var, value)) return new_list - -def create_assistant_turn(answer: str, tool_call_sequence=None, *, tool_results=None, train_on_turn: bool = True): - """Bundle the assistant utterance with any tool interaction for that turn.""" - return { - "answer": answer, - "tool_call_sequence": tool_call_sequence or [], - "tool_results": tool_results if tool_results is not None else [], - "train_on_turn": train_on_turn, - } - -def generate_templated_example(template: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False): +def generate_templated_example(template: PileOfTemplatedActionType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example: template_device_types: list[str] = template["device_type"].split("|") service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ] question_template: str = template["phrase"] piles = get_dataset_piles(language) # choose a random device for this template - chosen_devices = [] + chosen_devices: list[PileOfDeviceType] = [] for device_type in template_device_types: device_dict = random.choice(piles.stacks_of_device_names[device_type]) - device_dict["type"] = device_type chosen_devices.append(device_dict) device_list, device_types, extra_exposed_attributes = random_device_list( @@ -223,7 +221,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_ )) # gather a list of all available tools - available_tools = [] + available_tools: list[str] = [] for x in set(device_types + template_device_types): available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes)) @@ -264,11 +262,11 @@ def generate_templated_example(template: dict, persona: str, language: str, max_ answer_list.append(f" {word} ".join(answers)) # generate the list of tool calls - tool_calls = [] + tool_calls: list[ToolCall] = [] for device_dict, service in zip(chosen_devices, service_names): service_action = service.split(".")[1] tool_name = SERVICE_TO_TOOL_MAP[service_action] - tool_call = { + tool_call: ToolCall = { "tool_name": tool_name, "service_name": service, "tool_args": {"entity_id" if use_service_names else "name": device_dict["device_name"] if use_service_names else device_dict["description"]} @@ -369,7 +367,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_ "assistant_turns": assistant_turns } -def generate_status_request(template: dict, persona: str, language: str, max_devices: int = 128, return_target_device: bool = False, use_service_names: bool = False): +def generate_status_request(template: PileOfStatusRequestType, persona: str, language: str, max_devices: int = 128, return_target_device: bool = False, use_service_names: bool = False) -> Example | tuple[Example, PileOfDeviceType]: device_type: str = template["device_type"] state_name: str = template["state"] question_template: str = template["phrase"] @@ -447,7 +445,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev )) # gather a list of all available tools - available_tools = [] + available_tools: list[str] = [] for x in set(device_types + [device_type]): available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes)) @@ -456,7 +454,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev assistant_turns = [create_assistant_turn(answer.lower(), [])] - result = { + result: Example = { "states": device_list, "available_tools": available_tools, "question": question.lower(), @@ -467,7 +465,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev else: return result -def generate_tool_failure_example(failure_case: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False): +def generate_tool_failure_example(failure_case: PileOfFailedToolcallType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example: piles = get_dataset_piles(language) service_name = failure_case["service_name"] device_type = service_name.split(".")[0] @@ -491,7 +489,7 @@ def generate_tool_failure_example(failure_case: dict, persona: str, language: st if device_type not in device_types: device_types.append(device_type) - available_tools = [] + available_tools: list[str] = [] for x in set(device_types): available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes)) available_tools = list(dict.fromkeys(available_tools)) @@ -506,7 +504,7 @@ def generate_tool_failure_example(failure_case: dict, persona: str, language: st response_starting = response_starting.replace("", friendly_name) response_confirmed = response_confirmed.replace("", friendly_name) - tool_args_extra = {} + tool_args_extra: dict[str, Any] = {} if device_type == "climate": if "" in question or "" in response_starting or "" in response_confirmed: temp_f = generate_random_parameter("temp_f", piles) @@ -562,7 +560,7 @@ def generate_tool_failure_example(failure_case: dict, persona: str, language: st "assistant_turns": [first_turn, second_turn, final_turn] } -def generate_refusal_example(refusal_case: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False): +def generate_refusal_example(refusal_case: PileOfRefusalsType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example: service_name = refusal_case["service_name"] device_type = service_name.split(".")[0] target_device = f"{device_type}.{refusal_case['device_name']}" @@ -583,7 +581,7 @@ def generate_refusal_example(refusal_case: dict, persona: str, language: str, ma if device_type not in device_types: device_types.append(device_type) - available_tools = [] + available_tools: list[str] = [] for x in set(device_types): available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes)) available_tools = list(dict.fromkeys(available_tools)) @@ -600,7 +598,7 @@ def generate_refusal_example(refusal_case: dict, persona: str, language: str, ma "assistant_turns": assistant_turns } -def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format): +def format_example_sharegpt(example: Example, persona: str, language: str, use_system_role: bool, append_user_instruction_prompt: bool, use_service_names: bool, tool_response_format: str) -> DatasetEntry: piles = get_dataset_piles(language) sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts) question = example["question"] @@ -700,7 +698,7 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_ def generate_sft_file( filename: str, seed: int, - format_func: Callable, + format_func: Callable[[Example, str, str, bool, bool, bool, str], DatasetEntry], use_system_role: bool, append_user_instruction_prompt: bool, use_service_names: bool, @@ -719,15 +717,15 @@ def generate_sft_file( print("Generating...") - def run_factor_times(func, examples, data, persona, factor, language): + def run_factor_times(func: Callable[..., Example], examples: list[DatasetEntry], data, persona: str, factor: int | float, language: str): if factor >= 1: - for i in range(factor): + for i in range(int(factor)): examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format)) else: if random.random() < factor: examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format)) - generated_examples = [] + generated_examples: list[DatasetEntry] = [] missing_responses = set() @@ -774,7 +772,7 @@ def generate_sft_file( def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset_column_names, format_func): alpaca_dataset = load_dataset(dataset_name)["train"].train_test_split(test_size=0.1) - home_assistant_dataset = load_dataset("json", data_files={ "train": "home_assistant_train.jsonl", "test": "home_assistant_test.jsonl" }) + home_assistant_dataset = load_dataset("json", data_files={ "train": "home_assistant_train.jsonl", "test": "home_assistant_test.jsonl" }) random.seed(seed) np.random.seed(seed) diff --git a/data/prompting.py b/data/prompting.py index f2041e9..ab724b5 100644 --- a/data/prompting.py +++ b/data/prompting.py @@ -1,6 +1,6 @@ import babel.dates -from utils import generate_random_datetime +from utils import Example, generate_random_datetime CURRENT_DATE_PROMPT = { "english": "The current time and date is", @@ -51,7 +51,7 @@ USER_INSTRUCTION_PROMPT = { } -def generate_system_prompt(example: dict, persona: str, language: str, pile_of_system_prompts: dict) -> str: +def generate_system_prompt(example: Example, persona: str, language: str, pile_of_system_prompts: dict[str, str]) -> str: sys_prompt = pile_of_system_prompts[persona] random_datetime = generate_random_datetime() translate_datetime = babel.dates.format_datetime(random_datetime, BABEL_FORMAT[language], locale=BABEL_LOCALE[language]) diff --git a/data/utils.py b/data/utils.py index 63f88a8..a3666cc 100644 --- a/data/utils.py +++ b/data/utils.py @@ -2,6 +2,7 @@ import random import re import os import csv +from typing import Any, TypedDict import pandas from datetime import datetime, timedelta import webcolors @@ -13,8 +14,8 @@ class NoServicesAvailableException(Exception): pass -def closest_color(requested_color): - min_colors = {} +def closest_color(requested_color: tuple[int, int, int]): + min_colors: dict[int, str] = {} color_names = webcolors.names("css3") for name in color_names: @@ -44,7 +45,7 @@ def get_included_vars(response: str): return ",".join(sorted(result)) -def generate_random_parameter(param_name, piles_of_data): +def generate_random_parameter(param_name: str, piles_of_data: "DatasetPiles"): RANDOM_PARAMETER_GENERATORS = { "rgb_color": lambda: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)), "brightness": lambda: random.randint(0, 100), @@ -67,7 +68,7 @@ def generate_random_parameter(param_name, piles_of_data): return param_generator() -def get_random_response(pile_of_responses, *, service: str, persona: str, question_template: str, short: bool) -> tuple[str, str]: +def get_random_response(pile_of_responses: pandas.DataFrame, *, service: str, persona: str, question_template: str, short: bool) -> tuple[str, str]: required_vars = list(set([var for var in var_pattern.findall(question_template) if "device_name" not in var])) @@ -82,6 +83,58 @@ def get_random_response(pile_of_responses, *, service: str, persona: str, questi return possible_results.sample()["response_starting"].values[0], possible_results.sample()["response_confirmed"].values[0] + +class PileOfDeviceType(TypedDict): + device_name: str + description: str + type: str + + +class PileOfSpecificActionType(TypedDict): + service_name: str + device_name: str + phrase: str + + +class PileOfTemplatedActionType(TypedDict): + device_type: str + service: str + phrase: str + multiplier: int + + +class PileOfStatusRequestType(TypedDict): + device_type: str + state: str + phrase: str + assistant_response: str + + +class PileOfHallucinatedServiceType(TypedDict): + real_service: str + hallucinated_service: str + + +class PileOfFailedToolcallType(TypedDict): + service_name: str + correct_device_name: str + correct_friendly_name: str + bad_device_name: str + phrase: str + error_result: str + retry_prompt: str + + +class PileOfRefusalsType(TypedDict): + reason_type: str + service_name: str + device_name: str + friendly_name: str + desired_state: str + phrase: str + response: str + + class DatasetPiles: def __init__(self, supported_devices, language="english"): self.language = language @@ -93,7 +146,7 @@ class DatasetPiles: with open(f"{cwd}/piles/{language}/pile_of_durations.csv", encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_durations = { x["duration"]: x["name"] for x in reader } + self.pile_of_durations: dict[str, str] = { x["duration"]: x["name"] for x in reader } # media names are not translated with open(f"{cwd}/piles/english/pile_of_media_names.txt", encoding="utf8") as f: @@ -102,14 +155,15 @@ class DatasetPiles: with open(f"{cwd}/piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f: self.pile_of_todo_items = [ x.strip() for x in f.readlines() ] - self.stacks_of_device_names = { x: [] for x in supported_devices } + self.stacks_of_device_names: dict[str, list[PileOfDeviceType]] = { x: [] for x in supported_devices } with open(f"{cwd}/piles/{language}/pile_of_device_names.csv", encoding="utf8") as f: reader = csv.DictReader(f) pile_of_device_names = list(reader) for device_dict in pile_of_device_names: try: device_type = device_dict["device_name"].split(".")[0] - self.stacks_of_device_names[device_type].append(device_dict) + device_dict["type"] = device_type + self.stacks_of_device_names[device_type].append(device_dict) # type: ignore except KeyError as ex: print(ex) @@ -125,41 +179,41 @@ class DatasetPiles: for x in range(multiplier): processed_pile_of_templated_actions.append(action) - self.pile_of_templated_actions = processed_pile_of_templated_actions + self.pile_of_templated_actions: list[PileOfTemplatedActionType] = processed_pile_of_templated_actions with open(f"{cwd}/piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_specific_actions = list(reader) + self.pile_of_specific_actions: list[PileOfSpecificActionType] = list(reader) # type: ignore self.pile_of_responses = pandas.read_csv(f"{cwd}/piles/{language}/pile_of_responses.csv") self.pile_of_responses["contains_vars"] = self.pile_of_responses["response_starting"].apply(get_included_vars) with open(f"{cwd}/piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_status_requests = list(reader) + self.pile_of_status_requests: list[PileOfStatusRequestType] = list(reader) # type: ignore with open(f"{cwd}/piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_system_prompts = { line["persona"]: line["prompt"] for line in reader } + self.pile_of_system_prompts: dict[str, str] = { line["persona"]: line["prompt"] for line in reader } # service names are not translated with open(f"{cwd}/piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_hallucinated_service_names = list(reader) + self.pile_of_hallucinated_service_names: list[PileOfHallucinatedServiceType] = list(reader) # type: ignore failed_tool_calls_path = f"{cwd}/piles/{language}/pile_of_failed_tool_calls.csv" self.pile_of_failed_tool_calls = [] if os.path.exists(failed_tool_calls_path): with open(failed_tool_calls_path, encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_failed_tool_calls = list(reader) + self.pile_of_failed_tool_calls: list[PileOfFailedToolcallType] = list(reader) # type: ignore refusals_path = f"{cwd}/piles/{language}/pile_of_refusals.csv" self.pile_of_refusals = [] if os.path.exists(refusals_path): with open(refusals_path, encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_refusals = list(reader) + self.pile_of_refusals: list[PileOfRefusalsType] = list(reader) # type: ignore def __getitem__(self, key): return getattr(self, key) @@ -173,3 +227,33 @@ def get_dataset_piles(language: str) -> DatasetPiles: "lock","media_player", "climate", "vacuum", "timer", "todo", ], language) return _piles_cache[language] + + + +class ToolCall(TypedDict): + tool_name: str + service_name: str + tool_args: dict[str, Any] + + +class ToolResult(TypedDict): + tool_name: str + tool_result: str + +class AssistantTurn(TypedDict): + answer: str + tool_call_sequence: list[ToolCall] + tool_results: list[ToolResult] + train_on_turn: bool + + +class Example(TypedDict): + states: list[str] + available_tools: list[str] + question: str + assistant_turns: list[AssistantTurn] + + +class DatasetEntry(TypedDict): + messages: list[dict[str, Any]] + tools: list[dict[str, Any]]