make TypedDict types for piles

This commit is contained in:
Evian Schlenz
2025-12-30 18:31:47 +01:00
parent 831ef9bfca
commit b955897c18
2 changed files with 64 additions and 11 deletions

View File

@@ -189,7 +189,6 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
chosen_devices = [] chosen_devices = []
for device_type in template_device_types: for device_type in template_device_types:
device_dict = random.choice(piles.stacks_of_device_names[device_type]) device_dict = random.choice(piles.stacks_of_device_names[device_type])
device_dict["type"] = device_type
chosen_devices.append(device_dict) chosen_devices.append(device_dict)
device_list, device_types, extra_exposed_attributes = random_device_list( device_list, device_types, extra_exposed_attributes = random_device_list(

View File

@@ -2,6 +2,7 @@ import random
import re import re
import os import os
import csv import csv
from typing import TypedDict
import pandas import pandas
from datetime import datetime, timedelta from datetime import datetime, timedelta
import webcolors import webcolors
@@ -82,6 +83,58 @@ def get_random_response(pile_of_responses, *, service: str, persona: str, questi
return possible_results.sample()["response_starting"].values[0], possible_results.sample()["response_confirmed"].values[0] return possible_results.sample()["response_starting"].values[0], possible_results.sample()["response_confirmed"].values[0]
class DeviceType(TypedDict):
device_name: str
description: str
type: str
class SpecificActionType(TypedDict):
service_name: str
device_name: str
phrase: str
class TemplatedActionType(TypedDict):
device_type: str
service: str
phrase: str
multiplier: int
class StatusRequestType(TypedDict):
device_type: str
state: str
phrase: str
assistant_response: str
class HallucinatedServiceType(TypedDict):
real_service: str
hallucinated_service: str
class FailedToolcallType(TypedDict):
service_name: str
correct_device_name: str
correct_friendly_name: str
bad_device_name: str
phrase: str
error_result: str
retry_prompt: str
class RefusalsType(TypedDict):
reason_type: str
service_name: str
device_name: str
friendly_name: str
desired_state: str
phrase: str
response: str
class DatasetPiles: class DatasetPiles:
def __init__(self, supported_devices, language="english"): def __init__(self, supported_devices, language="english"):
self.language = language self.language = language
@@ -93,7 +146,7 @@ class DatasetPiles:
with open(f"{cwd}/piles/{language}/pile_of_durations.csv", encoding="utf8") as f: with open(f"{cwd}/piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
reader = csv.DictReader(f) reader = csv.DictReader(f)
self.pile_of_durations = { x["duration"]: x["name"] for x in reader } self.pile_of_durations: dict[str, str] = { x["duration"]: x["name"] for x in reader }
# media names are not translated # media names are not translated
with open(f"{cwd}/piles/english/pile_of_media_names.txt", encoding="utf8") as f: with open(f"{cwd}/piles/english/pile_of_media_names.txt", encoding="utf8") as f:
@@ -102,14 +155,15 @@ class DatasetPiles:
with open(f"{cwd}/piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f: with open(f"{cwd}/piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
self.pile_of_todo_items = [ x.strip() for x in f.readlines() ] self.pile_of_todo_items = [ x.strip() for x in f.readlines() ]
self.stacks_of_device_names = { x: [] for x in supported_devices } self.stacks_of_device_names: dict[str, list[DeviceType]] = { x: [] for x in supported_devices }
with open(f"{cwd}/piles/{language}/pile_of_device_names.csv", encoding="utf8") as f: with open(f"{cwd}/piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
reader = csv.DictReader(f) reader = csv.DictReader(f)
pile_of_device_names = list(reader) pile_of_device_names = list(reader)
for device_dict in pile_of_device_names: for device_dict in pile_of_device_names:
try: try:
device_type = device_dict["device_name"].split(".")[0] device_type = device_dict["device_name"].split(".")[0]
self.stacks_of_device_names[device_type].append(device_dict) device_dict["type"] = device_type
self.stacks_of_device_names[device_type].append(device_dict) # type: ignore
except KeyError as ex: except KeyError as ex:
print(ex) print(ex)
@@ -125,41 +179,41 @@ class DatasetPiles:
for x in range(multiplier): for x in range(multiplier):
processed_pile_of_templated_actions.append(action) processed_pile_of_templated_actions.append(action)
self.pile_of_templated_actions = processed_pile_of_templated_actions self.pile_of_templated_actions: list[TemplatedActionType] = processed_pile_of_templated_actions
with open(f"{cwd}/piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f: with open(f"{cwd}/piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
reader = csv.DictReader(f) reader = csv.DictReader(f)
self.pile_of_specific_actions = list(reader) self.pile_of_specific_actions: list[SpecificActionType] = list(reader) # type: ignore
self.pile_of_responses = pandas.read_csv(f"{cwd}/piles/{language}/pile_of_responses.csv") self.pile_of_responses = pandas.read_csv(f"{cwd}/piles/{language}/pile_of_responses.csv")
self.pile_of_responses["contains_vars"] = self.pile_of_responses["response_starting"].apply(get_included_vars) self.pile_of_responses["contains_vars"] = self.pile_of_responses["response_starting"].apply(get_included_vars)
with open(f"{cwd}/piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f: with open(f"{cwd}/piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
reader = csv.DictReader(f) reader = csv.DictReader(f)
self.pile_of_status_requests = list(reader) self.pile_of_status_requests: list[StatusRequestType] = list(reader) # type: ignore
with open(f"{cwd}/piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f: with open(f"{cwd}/piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
reader = csv.DictReader(f) reader = csv.DictReader(f)
self.pile_of_system_prompts = { line["persona"]: line["prompt"] for line in reader } self.pile_of_system_prompts: dict[str, str] = { line["persona"]: line["prompt"] for line in reader }
# service names are not translated # service names are not translated
with open(f"{cwd}/piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f: with open(f"{cwd}/piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f:
reader = csv.DictReader(f) reader = csv.DictReader(f)
self.pile_of_hallucinated_service_names = list(reader) self.pile_of_hallucinated_service_names: list[HallucinatedServiceType] = list(reader) # type: ignore
failed_tool_calls_path = f"{cwd}/piles/{language}/pile_of_failed_tool_calls.csv" failed_tool_calls_path = f"{cwd}/piles/{language}/pile_of_failed_tool_calls.csv"
self.pile_of_failed_tool_calls = [] self.pile_of_failed_tool_calls = []
if os.path.exists(failed_tool_calls_path): if os.path.exists(failed_tool_calls_path):
with open(failed_tool_calls_path, encoding="utf8") as f: with open(failed_tool_calls_path, encoding="utf8") as f:
reader = csv.DictReader(f) reader = csv.DictReader(f)
self.pile_of_failed_tool_calls = list(reader) self.pile_of_failed_tool_calls: list[FailedToolcallType] = list(reader) # type: ignore
refusals_path = f"{cwd}/piles/{language}/pile_of_refusals.csv" refusals_path = f"{cwd}/piles/{language}/pile_of_refusals.csv"
self.pile_of_refusals = [] self.pile_of_refusals = []
if os.path.exists(refusals_path): if os.path.exists(refusals_path):
with open(refusals_path, encoding="utf8") as f: with open(refusals_path, encoding="utf8") as f:
reader = csv.DictReader(f) reader = csv.DictReader(f)
self.pile_of_refusals = list(reader) self.pile_of_refusals: list[RefusalsType] = list(reader) # type: ignore
def __getitem__(self, key): def __getitem__(self, key):
return getattr(self, key) return getattr(self, key)