From b955897c18f07e75bcce3c5b6a8d60182d00d969 Mon Sep 17 00:00:00 2001 From: Evian Schlenz Date: Tue, 30 Dec 2025 18:31:47 +0100 Subject: [PATCH 01/10] make TypedDict types for piles --- data/generate_data.py | 1 - data/utils.py | 74 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 64 insertions(+), 11 deletions(-) diff --git a/data/generate_data.py b/data/generate_data.py index eba4f0b..c95a471 100644 --- a/data/generate_data.py +++ b/data/generate_data.py @@ -189,7 +189,6 @@ def generate_templated_example(template: dict, persona: str, language: str, max_ chosen_devices = [] for device_type in template_device_types: device_dict = random.choice(piles.stacks_of_device_names[device_type]) - device_dict["type"] = device_type chosen_devices.append(device_dict) device_list, device_types, extra_exposed_attributes = random_device_list( diff --git a/data/utils.py b/data/utils.py index 63f88a8..8dd6e21 100644 --- a/data/utils.py +++ b/data/utils.py @@ -2,6 +2,7 @@ import random import re import os import csv +from typing import TypedDict import pandas from datetime import datetime, timedelta import webcolors @@ -82,6 +83,58 @@ def get_random_response(pile_of_responses, *, service: str, persona: str, questi return possible_results.sample()["response_starting"].values[0], possible_results.sample()["response_confirmed"].values[0] + +class DeviceType(TypedDict): + device_name: str + description: str + type: str + + +class SpecificActionType(TypedDict): + service_name: str + device_name: str + phrase: str + + +class TemplatedActionType(TypedDict): + device_type: str + service: str + phrase: str + multiplier: int + + +class StatusRequestType(TypedDict): + device_type: str + state: str + phrase: str + assistant_response: str + + +class HallucinatedServiceType(TypedDict): + real_service: str + hallucinated_service: str + + +class FailedToolcallType(TypedDict): + service_name: str + correct_device_name: str + correct_friendly_name: str + bad_device_name: str + phrase: str + error_result: str + retry_prompt: str + + +class RefusalsType(TypedDict): + reason_type: str + service_name: str + device_name: str + friendly_name: str + desired_state: str + phrase: str + response: str + + class DatasetPiles: def __init__(self, supported_devices, language="english"): self.language = language @@ -93,7 +146,7 @@ class DatasetPiles: with open(f"{cwd}/piles/{language}/pile_of_durations.csv", encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_durations = { x["duration"]: x["name"] for x in reader } + self.pile_of_durations: dict[str, str] = { x["duration"]: x["name"] for x in reader } # media names are not translated with open(f"{cwd}/piles/english/pile_of_media_names.txt", encoding="utf8") as f: @@ -102,14 +155,15 @@ class DatasetPiles: with open(f"{cwd}/piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f: self.pile_of_todo_items = [ x.strip() for x in f.readlines() ] - self.stacks_of_device_names = { x: [] for x in supported_devices } + self.stacks_of_device_names: dict[str, list[DeviceType]] = { x: [] for x in supported_devices } with open(f"{cwd}/piles/{language}/pile_of_device_names.csv", encoding="utf8") as f: reader = csv.DictReader(f) pile_of_device_names = list(reader) for device_dict in pile_of_device_names: try: device_type = device_dict["device_name"].split(".")[0] - self.stacks_of_device_names[device_type].append(device_dict) + device_dict["type"] = device_type + self.stacks_of_device_names[device_type].append(device_dict) # type: ignore except KeyError as ex: print(ex) @@ -125,41 +179,41 @@ class DatasetPiles: for x in range(multiplier): processed_pile_of_templated_actions.append(action) - self.pile_of_templated_actions = processed_pile_of_templated_actions + self.pile_of_templated_actions: list[TemplatedActionType] = processed_pile_of_templated_actions with open(f"{cwd}/piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_specific_actions = list(reader) + self.pile_of_specific_actions: list[SpecificActionType] = list(reader) # type: ignore self.pile_of_responses = pandas.read_csv(f"{cwd}/piles/{language}/pile_of_responses.csv") self.pile_of_responses["contains_vars"] = self.pile_of_responses["response_starting"].apply(get_included_vars) with open(f"{cwd}/piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_status_requests = list(reader) + self.pile_of_status_requests: list[StatusRequestType] = list(reader) # type: ignore with open(f"{cwd}/piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_system_prompts = { line["persona"]: line["prompt"] for line in reader } + self.pile_of_system_prompts: dict[str, str] = { line["persona"]: line["prompt"] for line in reader } # service names are not translated with open(f"{cwd}/piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_hallucinated_service_names = list(reader) + self.pile_of_hallucinated_service_names: list[HallucinatedServiceType] = list(reader) # type: ignore failed_tool_calls_path = f"{cwd}/piles/{language}/pile_of_failed_tool_calls.csv" self.pile_of_failed_tool_calls = [] if os.path.exists(failed_tool_calls_path): with open(failed_tool_calls_path, encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_failed_tool_calls = list(reader) + self.pile_of_failed_tool_calls: list[FailedToolcallType] = list(reader) # type: ignore refusals_path = f"{cwd}/piles/{language}/pile_of_refusals.csv" self.pile_of_refusals = [] if os.path.exists(refusals_path): with open(refusals_path, encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_refusals = list(reader) + self.pile_of_refusals: list[RefusalsType] = list(reader) # type: ignore def __getitem__(self, key): return getattr(self, key) From c26ef953c6e65a02154210111eb6d77ad4e10c9e Mon Sep 17 00:00:00 2001 From: Evian Schlenz Date: Tue, 30 Dec 2025 23:47:54 +0100 Subject: [PATCH 02/10] Do not use mutable objects as default --- data/devices.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/data/devices.py b/data/devices.py index 1d2dcdd..2b5ff87 100644 --- a/data/devices.py +++ b/data/devices.py @@ -40,7 +40,7 @@ class DeviceType: name: str possible_states: list[tuple[str, float]] - def get_random_state(self, language: str, extra_exposed_attributes=[]): + def get_random_state(self, language: str, extra_exposed_attributes: list | None = None): states = [ x[0] for x in self.possible_states ] weights = [ x[1] for x in self.possible_states ] return random.choices(states, weights=weights, k=1)[0] @@ -64,7 +64,8 @@ class LightDeviceType(DeviceType): ] ) - def get_random_state(self, language: str, extra_exposed_attributes=[]): + def get_random_state(self, language: str, extra_exposed_attributes: list | None = None): + extra_exposed_attributes = extra_exposed_attributes or [] state = super().get_random_state(language, extra_exposed_attributes=extra_exposed_attributes) if random.random() < 0.5 and "rgb_color" in extra_exposed_attributes: @@ -171,8 +172,9 @@ class ClimateDeviceType(DeviceType): def __init__(self): super().__init__("climate", []) - def get_random_state(self, language: str, extra_exposed_attributes=[]): + def get_random_state(self, language: str, extra_exposed_attributes: list | None = None): """state;fan_mode;temperature;humidity""" + extra_exposed_attributes = extra_exposed_attributes or [] state = generate_random_parameter("hvac_mode", get_dataset_piles(language)) if "fan_mode" in extra_exposed_attributes: @@ -211,7 +213,8 @@ class MediaPlayerDeviceType(DeviceType): (STATE_BUFFERING, 0.01), ]) - def get_random_state(self, language: str, extra_exposed_attributes=[]): + def get_random_state(self, language: str, extra_exposed_attributes: list | None = None): + extra_exposed_attributes = extra_exposed_attributes or [] state = super().get_random_state(language, extra_exposed_attributes=extra_exposed_attributes) if "media_title" in extra_exposed_attributes and state in [STATE_PLAYING, STATE_PAUSED, STATE_BUFFERING, STATE_ON]: @@ -228,7 +231,7 @@ class MediaPlayerDeviceType(DeviceType): return tools -SUPPORTED_DEVICES = { +SUPPORTED_DEVICES: dict[str, DeviceType] = { "light": LightDeviceType(), "switch": SwitchDeviceType(), "fan": FanDeviceType(), From 80669b65220363f592ad79fffdb075d5b28c87f7 Mon Sep 17 00:00:00 2001 From: Evian Schlenz Date: Tue, 30 Dec 2025 23:51:27 +0100 Subject: [PATCH 03/10] Rename pile types to not confuse with other type classes --- data/utils.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/data/utils.py b/data/utils.py index 8dd6e21..75e6747 100644 --- a/data/utils.py +++ b/data/utils.py @@ -84,38 +84,38 @@ def get_random_response(pile_of_responses, *, service: str, persona: str, questi return possible_results.sample()["response_starting"].values[0], possible_results.sample()["response_confirmed"].values[0] -class DeviceType(TypedDict): +class PileOfDeviceType(TypedDict): device_name: str description: str type: str -class SpecificActionType(TypedDict): +class PileOfSpecificActionType(TypedDict): service_name: str device_name: str phrase: str -class TemplatedActionType(TypedDict): +class PileOfTemplatedActionType(TypedDict): device_type: str service: str phrase: str multiplier: int -class StatusRequestType(TypedDict): +class PileOfStatusRequestType(TypedDict): device_type: str state: str phrase: str assistant_response: str -class HallucinatedServiceType(TypedDict): +class PileOfHallucinatedServiceType(TypedDict): real_service: str hallucinated_service: str -class FailedToolcallType(TypedDict): +class PileOfFailedToolcallType(TypedDict): service_name: str correct_device_name: str correct_friendly_name: str @@ -125,7 +125,7 @@ class FailedToolcallType(TypedDict): retry_prompt: str -class RefusalsType(TypedDict): +class PileOfRefusalsType(TypedDict): reason_type: str service_name: str device_name: str @@ -155,7 +155,7 @@ class DatasetPiles: with open(f"{cwd}/piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f: self.pile_of_todo_items = [ x.strip() for x in f.readlines() ] - self.stacks_of_device_names: dict[str, list[DeviceType]] = { x: [] for x in supported_devices } + self.stacks_of_device_names: dict[str, list[PileOfDeviceType]] = { x: [] for x in supported_devices } with open(f"{cwd}/piles/{language}/pile_of_device_names.csv", encoding="utf8") as f: reader = csv.DictReader(f) pile_of_device_names = list(reader) @@ -179,18 +179,18 @@ class DatasetPiles: for x in range(multiplier): processed_pile_of_templated_actions.append(action) - self.pile_of_templated_actions: list[TemplatedActionType] = processed_pile_of_templated_actions + self.pile_of_templated_actions: list[PileOfTemplatedActionType] = processed_pile_of_templated_actions with open(f"{cwd}/piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_specific_actions: list[SpecificActionType] = list(reader) # type: ignore + self.pile_of_specific_actions: list[PileOfSpecificActionType] = list(reader) # type: ignore self.pile_of_responses = pandas.read_csv(f"{cwd}/piles/{language}/pile_of_responses.csv") self.pile_of_responses["contains_vars"] = self.pile_of_responses["response_starting"].apply(get_included_vars) with open(f"{cwd}/piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_status_requests: list[StatusRequestType] = list(reader) # type: ignore + self.pile_of_status_requests: list[PileOfStatusRequestType] = list(reader) # type: ignore with open(f"{cwd}/piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f: reader = csv.DictReader(f) @@ -199,21 +199,21 @@ class DatasetPiles: # service names are not translated with open(f"{cwd}/piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_hallucinated_service_names: list[HallucinatedServiceType] = list(reader) # type: ignore + self.pile_of_hallucinated_service_names: list[PileOfHallucinatedServiceType] = list(reader) # type: ignore failed_tool_calls_path = f"{cwd}/piles/{language}/pile_of_failed_tool_calls.csv" self.pile_of_failed_tool_calls = [] if os.path.exists(failed_tool_calls_path): with open(failed_tool_calls_path, encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_failed_tool_calls: list[FailedToolcallType] = list(reader) # type: ignore + self.pile_of_failed_tool_calls: list[PileOfFailedToolcallType] = list(reader) # type: ignore refusals_path = f"{cwd}/piles/{language}/pile_of_refusals.csv" self.pile_of_refusals = [] if os.path.exists(refusals_path): with open(refusals_path, encoding="utf8") as f: reader = csv.DictReader(f) - self.pile_of_refusals: list[RefusalsType] = list(reader) # type: ignore + self.pile_of_refusals: list[PileOfRefusalsType] = list(reader) # type: ignore def __getitem__(self, key): return getattr(self, key) From 44b296e6df0db1a2cd2b8c6b866b80c79e87f3a0 Mon Sep 17 00:00:00 2001 From: Evian Schlenz Date: Wed, 31 Dec 2025 00:57:12 +0100 Subject: [PATCH 04/10] correctly type get_random_state --- data/devices.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/data/devices.py b/data/devices.py index 2b5ff87..e8f47eb 100644 --- a/data/devices.py +++ b/data/devices.py @@ -1,10 +1,11 @@ +from collections import defaultdict import random from dataclasses import dataclass from typing import Final, Callable, List from difflib import SequenceMatcher from tools import * -from utils import closest_color, generate_random_parameter, get_dataset_piles +from utils import PileOfDeviceType, closest_color, generate_random_parameter, get_dataset_piles # STATES STATE_ON: Final = "on" @@ -40,7 +41,7 @@ class DeviceType: name: str possible_states: list[tuple[str, float]] - def get_random_state(self, language: str, extra_exposed_attributes: list | None = None): + def get_random_state(self, language: str, extra_exposed_attributes: list[str] | None = None): states = [ x[0] for x in self.possible_states ] weights = [ x[1] for x in self.possible_states ] return random.choices(states, weights=weights, k=1)[0] @@ -64,7 +65,7 @@ class LightDeviceType(DeviceType): ] ) - def get_random_state(self, language: str, extra_exposed_attributes: list | None = None): + def get_random_state(self, language: str, extra_exposed_attributes: list[str] | None = None): extra_exposed_attributes = extra_exposed_attributes or [] state = super().get_random_state(language, extra_exposed_attributes=extra_exposed_attributes) @@ -172,7 +173,7 @@ class ClimateDeviceType(DeviceType): def __init__(self): super().__init__("climate", []) - def get_random_state(self, language: str, extra_exposed_attributes: list | None = None): + def get_random_state(self, language: str, extra_exposed_attributes: list[str] | None = None): """state;fan_mode;temperature;humidity""" extra_exposed_attributes = extra_exposed_attributes or [] state = generate_random_parameter("hvac_mode", get_dataset_piles(language)) @@ -213,7 +214,7 @@ class MediaPlayerDeviceType(DeviceType): (STATE_BUFFERING, 0.01), ]) - def get_random_state(self, language: str, extra_exposed_attributes: list | None = None): + def get_random_state(self, language: str, extra_exposed_attributes: list[str] | None = None): extra_exposed_attributes = extra_exposed_attributes or [] state = super().get_random_state(language, extra_exposed_attributes=extra_exposed_attributes) @@ -267,14 +268,14 @@ def random_device_list(max_devices: int, avoid_device_names: list[str], language if avoid_type == "climate": avoid_climate = True - possible_choices = [] + possible_choices: list[PileOfDeviceType] = [] for device_type in local_device_names.keys(): possible_choices.extend(local_device_names[device_type]) device_types = set() device_list = [] - device_lines = [] + device_lines: list[str] = [] # TODO: randomly pick attributes for this list extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "duration", "remaining", "item"] @@ -304,4 +305,4 @@ def random_device_list(max_devices: int, avoid_device_names: list[str], language print(f"bad device name: {choice}") print(repr(ex)) - return device_lines, list(device_types), list(extra_exposed_attributes) \ No newline at end of file + return device_lines, list(device_types), list(extra_exposed_attributes) From b3df3d5346a8b8ce951e30d806fb732b4316e762 Mon Sep 17 00:00:00 2001 From: Evian Schlenz Date: Wed, 31 Dec 2025 00:58:22 +0100 Subject: [PATCH 05/10] correctly type example generation functions --- data/generate_data.py | 88 +++++++++++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 32 deletions(-) diff --git a/data/generate_data.py b/data/generate_data.py index c95a471..54a3aa0 100644 --- a/data/generate_data.py +++ b/data/generate_data.py @@ -3,7 +3,7 @@ import json import numpy as np import random from datasets import load_dataset, concatenate_datasets -from typing import Callable +from typing import Any, Callable, TypedDict from tqdm import tqdm import webcolors @@ -17,10 +17,44 @@ from devices import SUPPORTED_DEVICES, format_device_line, random_device_list, \ TOOL_LIGHT_SET, TOOL_START_TIMER, TOOL_LIST_ADD_ITEM, SERVICE_TO_TOOL_MAP, \ HASS_TOOLS, SERVICE_TOOLS from prompting import generate_system_prompt, USER_INSTRUCTION_PROMPT -from utils import get_random_response, generate_random_parameter, closest_color, \ +from utils import PileOfDeviceType, PileOfFailedToolcallType, PileOfRefusalsType, PileOfSpecificActionType, PileOfStatusRequestType, PileOfTemplatedActionType, PileOfType, get_random_response, generate_random_parameter, closest_color, \ get_dataset_piles, NoResponseAvailableException -def generate_static_example(action: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False): + +class ToolCall(TypedDict): + tool_name: str + service_name: str + tool_args: dict[str, Any] + + +class ToolResult(TypedDict): + tool_name: str + tool_result: str + +class AssistantTurn(TypedDict): + answer: str + tool_call_sequence: list[ToolCall] + tool_results: list[ToolResult] + train_on_turn: bool + + +class Example(TypedDict): + states: list[str] + available_tools: list[str] + question: str + assistant_turns: list[AssistantTurn] + + +def create_assistant_turn(answer: str, tool_call_sequence: list[ToolCall] | None = None, *, tool_results: list[ToolResult] | None = None, train_on_turn: bool = True) -> AssistantTurn: + """Bundle the assistant utterance with any tool interaction for that turn.""" + return { + "answer": answer, + "tool_call_sequence": tool_call_sequence or [], + "tool_results": tool_results if tool_results is not None else [], + "train_on_turn": train_on_turn, + } + +def generate_static_example(action: PileOfSpecificActionType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example: question = action["phrase"] service_name = action["service_name"] device_type = service_name.split(".")[0] @@ -42,7 +76,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic )) # gather a list of all available tools - available_tools = [] + available_tools: list[str] = [] for x in set(device_types + [device_type]): available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes)) @@ -130,13 +164,13 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic tool_args["item"] = todo if use_service_names: - tool_call = { + tool_call: ToolCall = { "tool_name": tool_name, "service_name": service_name, "tool_args": {"entity_id": target_device, **tool_args} } else: - tool_call = { + tool_call: ToolCall = { "tool_name": tool_name, "service_name": service_name, "tool_args": {"name": target_device, **tool_args} @@ -169,24 +203,14 @@ def replace_answer(list_of_answer, var, value): new_list.append(answer.replace(var, value)) return new_list - -def create_assistant_turn(answer: str, tool_call_sequence=None, *, tool_results=None, train_on_turn: bool = True): - """Bundle the assistant utterance with any tool interaction for that turn.""" - return { - "answer": answer, - "tool_call_sequence": tool_call_sequence or [], - "tool_results": tool_results if tool_results is not None else [], - "train_on_turn": train_on_turn, - } - -def generate_templated_example(template: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False): +def generate_templated_example(template: PileOfTemplatedActionType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example: template_device_types: list[str] = template["device_type"].split("|") service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ] question_template: str = template["phrase"] piles = get_dataset_piles(language) # choose a random device for this template - chosen_devices = [] + chosen_devices: list[PileOfDeviceType] = [] for device_type in template_device_types: device_dict = random.choice(piles.stacks_of_device_names[device_type]) chosen_devices.append(device_dict) @@ -222,7 +246,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_ )) # gather a list of all available tools - available_tools = [] + available_tools: list[str] = [] for x in set(device_types + template_device_types): available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes)) @@ -263,11 +287,11 @@ def generate_templated_example(template: dict, persona: str, language: str, max_ answer_list.append(f" {word} ".join(answers)) # generate the list of tool calls - tool_calls = [] + tool_calls: list[ToolCall] = [] for device_dict, service in zip(chosen_devices, service_names): service_action = service.split(".")[1] tool_name = SERVICE_TO_TOOL_MAP[service_action] - tool_call = { + tool_call: ToolCall = { "tool_name": tool_name, "service_name": service, "tool_args": {"entity_id" if use_service_names else "name": device_dict["device_name"] if use_service_names else device_dict["description"]} @@ -368,7 +392,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_ "assistant_turns": assistant_turns } -def generate_status_request(template: dict, persona: str, language: str, max_devices: int = 128, return_target_device: bool = False, use_service_names: bool = False): +def generate_status_request(template: PileOfStatusRequestType, persona: str, language: str, max_devices: int = 128, return_target_device: bool = False, use_service_names: bool = False) -> Example | tuple[Example, PileOfDeviceType]: device_type: str = template["device_type"] state_name: str = template["state"] question_template: str = template["phrase"] @@ -446,7 +470,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev )) # gather a list of all available tools - available_tools = [] + available_tools: list[str] = [] for x in set(device_types + [device_type]): available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes)) @@ -455,7 +479,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev assistant_turns = [create_assistant_turn(answer.lower(), [])] - result = { + result: Example = { "states": device_list, "available_tools": available_tools, "question": question.lower(), @@ -466,7 +490,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev else: return result -def generate_tool_failure_example(failure_case: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False): +def generate_tool_failure_example(failure_case: PileOfFailedToolcallType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example: piles = get_dataset_piles(language) service_name = failure_case["service_name"] device_type = service_name.split(".")[0] @@ -490,7 +514,7 @@ def generate_tool_failure_example(failure_case: dict, persona: str, language: st if device_type not in device_types: device_types.append(device_type) - available_tools = [] + available_tools: list[str] = [] for x in set(device_types): available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes)) available_tools = list(dict.fromkeys(available_tools)) @@ -561,7 +585,7 @@ def generate_tool_failure_example(failure_case: dict, persona: str, language: st "assistant_turns": [first_turn, second_turn, final_turn] } -def generate_refusal_example(refusal_case: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False): +def generate_refusal_example(refusal_case: PileOfRefusalsType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example: service_name = refusal_case["service_name"] device_type = service_name.split(".")[0] target_device = f"{device_type}.{refusal_case['device_name']}" @@ -582,7 +606,7 @@ def generate_refusal_example(refusal_case: dict, persona: str, language: str, ma if device_type not in device_types: device_types.append(device_type) - available_tools = [] + available_tools: list[str] = [] for x in set(device_types): available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes)) available_tools = list(dict.fromkeys(available_tools)) @@ -718,15 +742,15 @@ def generate_sft_file( print("Generating...") - def run_factor_times(func, examples, data, persona, factor, language): + def run_factor_times(func: Callable[..., Example], examples: list[Example], data, persona: str, factor: int | float, language: str): if factor >= 1: - for i in range(factor): + for i in range(int(factor)): examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format)) else: if random.random() < factor: examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format)) - generated_examples = [] + generated_examples: list[Example] = [] missing_responses = set() @@ -773,7 +797,7 @@ def generate_sft_file( def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset_column_names, format_func): alpaca_dataset = load_dataset(dataset_name)["train"].train_test_split(test_size=0.1) - home_assistant_dataset = load_dataset("json", data_files={ "train": "home_assistant_train.jsonl", "test": "home_assistant_test.jsonl" }) + home_assistant_dataset = load_dataset("json", data_files={ "train": "home_assistant_train.jsonl", "test": "home_assistant_test.jsonl" }) random.seed(seed) np.random.seed(seed) From c1b5d912d2e2483a805a190a11e4159dd8b11fba Mon Sep 17 00:00:00 2001 From: Evian Schlenz Date: Wed, 31 Dec 2025 01:12:22 +0100 Subject: [PATCH 06/10] Annotate generate_random_parameter and get_random_response --- data/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data/utils.py b/data/utils.py index 75e6747..3baee11 100644 --- a/data/utils.py +++ b/data/utils.py @@ -45,7 +45,7 @@ def get_included_vars(response: str): return ",".join(sorted(result)) -def generate_random_parameter(param_name, piles_of_data): +def generate_random_parameter(param_name: str, piles_of_data: "DatasetPiles"): RANDOM_PARAMETER_GENERATORS = { "rgb_color": lambda: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)), "brightness": lambda: random.randint(0, 100), @@ -68,7 +68,7 @@ def generate_random_parameter(param_name, piles_of_data): return param_generator() -def get_random_response(pile_of_responses, *, service: str, persona: str, question_template: str, short: bool) -> tuple[str, str]: +def get_random_response(pile_of_responses: pandas.DataFrame, *, service: str, persona: str, question_template: str, short: bool) -> tuple[str, str]: required_vars = list(set([var for var in var_pattern.findall(question_template) if "device_name" not in var])) From 242655af848d49e230faa550c1deaa2f5f2441da Mon Sep 17 00:00:00 2001 From: Evian Schlenz Date: Wed, 31 Dec 2025 01:21:27 +0100 Subject: [PATCH 07/10] annotate random vars and small funcs --- data/devices.py | 2 +- data/generate_data.py | 6 +++--- data/utils.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/data/devices.py b/data/devices.py index e8f47eb..5f04410 100644 --- a/data/devices.py +++ b/data/devices.py @@ -273,7 +273,7 @@ def random_device_list(max_devices: int, avoid_device_names: list[str], language possible_choices.extend(local_device_names[device_type]) - device_types = set() + device_types: set[str] = set() device_list = [] device_lines: list[str] = [] # TODO: randomly pick attributes for this list diff --git a/data/generate_data.py b/data/generate_data.py index 54a3aa0..f98f268 100644 --- a/data/generate_data.py +++ b/data/generate_data.py @@ -197,8 +197,8 @@ def generate_static_example(action: PileOfSpecificActionType, persona: str, lang "assistant_turns": assistant_turns } -def replace_answer(list_of_answer, var, value): - new_list = [] +def replace_answer(list_of_answer: list[str], var: str, value: str): + new_list: list[str] = [] for answer in list_of_answer: new_list.append(answer.replace(var, value)) return new_list @@ -529,7 +529,7 @@ def generate_tool_failure_example(failure_case: PileOfFailedToolcallType, person response_starting = response_starting.replace("", friendly_name) response_confirmed = response_confirmed.replace("", friendly_name) - tool_args_extra = {} + tool_args_extra: dict[str, Any] = {} if device_type == "climate": if "" in question or "" in response_starting or "" in response_confirmed: temp_f = generate_random_parameter("temp_f", piles) diff --git a/data/utils.py b/data/utils.py index 3baee11..34c2377 100644 --- a/data/utils.py +++ b/data/utils.py @@ -14,8 +14,8 @@ class NoServicesAvailableException(Exception): pass -def closest_color(requested_color): - min_colors = {} +def closest_color(requested_color: tuple[int, int, int]): + min_colors: dict[int, str] = {} color_names = webcolors.names("css3") for name in color_names: From 5ceff59c6584ac16b859a2df18f06e7503f40aee Mon Sep 17 00:00:00 2001 From: Evian Schlenz Date: Wed, 31 Dec 2025 01:35:36 +0100 Subject: [PATCH 08/10] move PileOfTypes to utils --- data/generate_data.py | 29 ++--------------------------- data/utils.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/data/generate_data.py b/data/generate_data.py index f98f268..ebfc11a 100644 --- a/data/generate_data.py +++ b/data/generate_data.py @@ -3,7 +3,7 @@ import json import numpy as np import random from datasets import load_dataset, concatenate_datasets -from typing import Any, Callable, TypedDict +from typing import Any, Callable from tqdm import tqdm import webcolors @@ -17,34 +17,9 @@ from devices import SUPPORTED_DEVICES, format_device_line, random_device_list, \ TOOL_LIGHT_SET, TOOL_START_TIMER, TOOL_LIST_ADD_ITEM, SERVICE_TO_TOOL_MAP, \ HASS_TOOLS, SERVICE_TOOLS from prompting import generate_system_prompt, USER_INSTRUCTION_PROMPT -from utils import PileOfDeviceType, PileOfFailedToolcallType, PileOfRefusalsType, PileOfSpecificActionType, PileOfStatusRequestType, PileOfTemplatedActionType, PileOfType, get_random_response, generate_random_parameter, closest_color, \ +from utils import AssistantTurn, DatasetEntry, Example, PileOfDeviceType, PileOfFailedToolcallType, PileOfRefusalsType, PileOfSpecificActionType, PileOfStatusRequestType, PileOfTemplatedActionType, ToolCall, ToolResult, get_random_response, generate_random_parameter, closest_color, \ get_dataset_piles, NoResponseAvailableException - -class ToolCall(TypedDict): - tool_name: str - service_name: str - tool_args: dict[str, Any] - - -class ToolResult(TypedDict): - tool_name: str - tool_result: str - -class AssistantTurn(TypedDict): - answer: str - tool_call_sequence: list[ToolCall] - tool_results: list[ToolResult] - train_on_turn: bool - - -class Example(TypedDict): - states: list[str] - available_tools: list[str] - question: str - assistant_turns: list[AssistantTurn] - - def create_assistant_turn(answer: str, tool_call_sequence: list[ToolCall] | None = None, *, tool_results: list[ToolResult] | None = None, train_on_turn: bool = True) -> AssistantTurn: """Bundle the assistant utterance with any tool interaction for that turn.""" return { diff --git a/data/utils.py b/data/utils.py index 34c2377..a3666cc 100644 --- a/data/utils.py +++ b/data/utils.py @@ -2,7 +2,7 @@ import random import re import os import csv -from typing import TypedDict +from typing import Any, TypedDict import pandas from datetime import datetime, timedelta import webcolors @@ -227,3 +227,33 @@ def get_dataset_piles(language: str) -> DatasetPiles: "lock","media_player", "climate", "vacuum", "timer", "todo", ], language) return _piles_cache[language] + + + +class ToolCall(TypedDict): + tool_name: str + service_name: str + tool_args: dict[str, Any] + + +class ToolResult(TypedDict): + tool_name: str + tool_result: str + +class AssistantTurn(TypedDict): + answer: str + tool_call_sequence: list[ToolCall] + tool_results: list[ToolResult] + train_on_turn: bool + + +class Example(TypedDict): + states: list[str] + available_tools: list[str] + question: str + assistant_turns: list[AssistantTurn] + + +class DatasetEntry(TypedDict): + messages: list[dict[str, Any]] + tools: list[dict[str, Any]] From 5a3d37c56ada4bd884b14561ece19b289c79620f Mon Sep 17 00:00:00 2001 From: Evian Schlenz Date: Wed, 31 Dec 2025 01:48:44 +0100 Subject: [PATCH 09/10] Annotate format_example_sharegpt --- data/generate_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/data/generate_data.py b/data/generate_data.py index ebfc11a..edd3d35 100644 --- a/data/generate_data.py +++ b/data/generate_data.py @@ -598,7 +598,7 @@ def generate_refusal_example(refusal_case: PileOfRefusalsType, persona: str, lan "assistant_turns": assistant_turns } -def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format): +def format_example_sharegpt(example: Example, persona: str, language: str, use_system_role: bool, append_user_instruction_prompt: bool, use_service_names: bool, tool_response_format: str) -> DatasetEntry: piles = get_dataset_piles(language) sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts) question = example["question"] @@ -698,7 +698,7 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_ def generate_sft_file( filename: str, seed: int, - format_func: Callable, + format_func: Callable[[Example, str, str, bool, bool, bool, str], DatasetEntry], use_system_role: bool, append_user_instruction_prompt: bool, use_service_names: bool, @@ -717,7 +717,7 @@ def generate_sft_file( print("Generating...") - def run_factor_times(func: Callable[..., Example], examples: list[Example], data, persona: str, factor: int | float, language: str): + def run_factor_times(func: Callable[..., Example], examples: list[DatasetEntry], data, persona: str, factor: int | float, language: str): if factor >= 1: for i in range(int(factor)): examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format)) @@ -725,7 +725,7 @@ def generate_sft_file( if random.random() < factor: examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format)) - generated_examples: list[Example] = [] + generated_examples: list[DatasetEntry] = [] missing_responses = set() From f96ded3abb4de29bda69af38304bae8bb8c4e598 Mon Sep 17 00:00:00 2001 From: Evian Schlenz Date: Wed, 31 Dec 2025 01:48:47 +0100 Subject: [PATCH 10/10] Annotate generate_system_prompt --- data/prompting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data/prompting.py b/data/prompting.py index f2041e9..ab724b5 100644 --- a/data/prompting.py +++ b/data/prompting.py @@ -1,6 +1,6 @@ import babel.dates -from utils import generate_random_datetime +from utils import Example, generate_random_datetime CURRENT_DATE_PROMPT = { "english": "The current time and date is", @@ -51,7 +51,7 @@ USER_INSTRUCTION_PROMPT = { } -def generate_system_prompt(example: dict, persona: str, language: str, pile_of_system_prompts: dict) -> str: +def generate_system_prompt(example: Example, persona: str, language: str, pile_of_system_prompts: dict[str, str]) -> str: sys_prompt = pile_of_system_prompts[persona] random_datetime = generate_random_datetime() translate_datetime = babel.dates.format_datetime(random_datetime, BABEL_FORMAT[language], locale=BABEL_LOCALE[language])