mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 21:28:05 -05:00
make TypedDict types for piles
This commit is contained in:
@@ -189,7 +189,6 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
|
|||||||
chosen_devices = []
|
chosen_devices = []
|
||||||
for device_type in template_device_types:
|
for device_type in template_device_types:
|
||||||
device_dict = random.choice(piles.stacks_of_device_names[device_type])
|
device_dict = random.choice(piles.stacks_of_device_names[device_type])
|
||||||
device_dict["type"] = device_type
|
|
||||||
chosen_devices.append(device_dict)
|
chosen_devices.append(device_dict)
|
||||||
|
|
||||||
device_list, device_types, extra_exposed_attributes = random_device_list(
|
device_list, device_types, extra_exposed_attributes = random_device_list(
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import random
|
|||||||
import re
|
import re
|
||||||
import os
|
import os
|
||||||
import csv
|
import csv
|
||||||
|
from typing import TypedDict
|
||||||
import pandas
|
import pandas
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import webcolors
|
import webcolors
|
||||||
@@ -82,6 +83,58 @@ def get_random_response(pile_of_responses, *, service: str, persona: str, questi
|
|||||||
|
|
||||||
return possible_results.sample()["response_starting"].values[0], possible_results.sample()["response_confirmed"].values[0]
|
return possible_results.sample()["response_starting"].values[0], possible_results.sample()["response_confirmed"].values[0]
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceType(TypedDict):
    """Shape of one entry in ``DatasetPiles.stacks_of_device_names``.

    Entries start out as rows read by ``csv.DictReader`` from
    ``pile_of_device_names.csv``; the loader then adds the ``type`` key.
    The rows are stored with ``# type: ignore``, so this class is a static
    typing aid only -- nothing validates the keys at runtime.
    """

    # Dotted identifier; the text before the first "." is used as the device type.
    device_name: str
    # Presumably a human-readable description column from the CSV -- confirm.
    description: str
    # Filled in at load time from ``device_name.split(".")[0]``.
    # The key must stay "type" because callers index the dict with that string.
    type: str
|
||||||
|
|
||||||
|
|
||||||
|
class SpecificActionType(TypedDict):
    """Shape of one entry in ``DatasetPiles.pile_of_specific_actions``.

    Entries are the rows that ``csv.DictReader`` yields for
    ``pile_of_specific_actions.csv``. They are assigned with
    ``# type: ignore``, so the keys are not checked at runtime.
    """

    # Presumably the service to invoke for this action -- confirm against the CSV.
    service_name: str
    # Presumably the concrete device the phrase refers to -- confirm.
    device_name: str
    # Presumably the user request text -- confirm.
    phrase: str
|
||||||
|
|
||||||
|
|
||||||
|
class TemplatedActionType(TypedDict):
    """Shape of one entry in ``DatasetPiles.pile_of_templated_actions``."""

    # Presumably the device category the template applies to -- confirm.
    device_type: str
    # Presumably the service the template invokes -- confirm.
    service: str
    # Presumably the templated user request text -- confirm.
    phrase: str
    # Appears to be the repeat count: the loader appends each action
    # ``multiplier`` times when building the processed pile.
    # NOTE(review): csv.DictReader yields strings, so an int here relies on a
    # conversion in the loader that is not visible from this view -- confirm.
    multiplier: int
|
||||||
|
|
||||||
|
|
||||||
|
class StatusRequestType(TypedDict):
    """Shape of one entry in ``DatasetPiles.pile_of_status_requests``.

    Entries are the rows that ``csv.DictReader`` yields for
    ``pile_of_status_requests.csv``. They are assigned with
    ``# type: ignore``, so the keys are not checked at runtime.
    """

    # Presumably the device category being asked about -- confirm.
    device_type: str
    # Presumably the device state this request/response pair covers -- confirm.
    state: str
    # Presumably the user question text -- confirm.
    phrase: str
    # Presumably the assistant's reply text for this state -- confirm.
    assistant_response: str
|
||||||
|
|
||||||
|
|
||||||
|
class HallucinatedServiceType(TypedDict):
    """Shape of one entry in ``DatasetPiles.pile_of_hallucinated_service_names``.

    Entries are the rows that ``csv.DictReader`` yields for
    ``pile_of_hallucinated_service_names.csv``. That file is always read from
    the ``english`` pile because service names are not translated. Rows are
    assigned with ``# type: ignore``, so the keys are not checked at runtime.
    """

    # Presumably the valid service name -- confirm against the CSV.
    real_service: str
    # Presumably the invalid, made-up counterpart of ``real_service`` -- confirm.
    hallucinated_service: str
|
||||||
|
|
||||||
|
|
||||||
|
class FailedToolcallType(TypedDict):
    """Shape of one entry in ``DatasetPiles.pile_of_failed_tool_calls``.

    Entries are the rows that ``csv.DictReader`` yields for
    ``pile_of_failed_tool_calls.csv``. The file is optional per language: when
    it does not exist the pile stays an empty list. Rows are assigned with
    ``# type: ignore``, so the keys are not checked at runtime.
    """

    # Presumably the service that was called -- confirm against the CSV.
    service_name: str
    # Presumably the entity name the call should have targeted -- confirm.
    correct_device_name: str
    # Presumably the friendly name matching ``correct_device_name`` -- confirm.
    correct_friendly_name: str
    # Presumably the wrong entity name used in the failing call -- confirm.
    bad_device_name: str
    # Presumably the user request text -- confirm.
    phrase: str
    # Presumably the error text returned for the failing call -- confirm.
    error_result: str
    # Presumably the text used to prompt the retry -- confirm.
    retry_prompt: str
|
||||||
|
|
||||||
|
|
||||||
|
class RefusalsType(TypedDict):
    """Shape of one entry in ``DatasetPiles.pile_of_refusals``.

    Entries are the rows that ``csv.DictReader`` yields for
    ``pile_of_refusals.csv``. The file is optional per language: when it does
    not exist the pile stays an empty list. Rows are assigned with
    ``# type: ignore``, so the keys are not checked at runtime.
    """

    # Presumably a category label for why the request is refused -- confirm.
    reason_type: str
    # Presumably the service the user asked for -- confirm.
    service_name: str
    # Presumably the entity name the request refers to -- confirm.
    device_name: str
    # Presumably the friendly name matching ``device_name`` -- confirm.
    friendly_name: str
    # Presumably the state the user wanted the device to reach -- confirm.
    desired_state: str
    # Presumably the user request text -- confirm.
    phrase: str
    # Presumably the assistant's refusal text -- confirm.
    response: str
|
||||||
|
|
||||||
|
|
||||||
class DatasetPiles:
|
class DatasetPiles:
|
||||||
def __init__(self, supported_devices, language="english"):
|
def __init__(self, supported_devices, language="english"):
|
||||||
self.language = language
|
self.language = language
|
||||||
@@ -93,7 +146,7 @@ class DatasetPiles:
|
|||||||
|
|
||||||
with open(f"{cwd}/piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
|
with open(f"{cwd}/piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
self.pile_of_durations = { x["duration"]: x["name"] for x in reader }
|
self.pile_of_durations: dict[str, str] = { x["duration"]: x["name"] for x in reader }
|
||||||
|
|
||||||
# media names are not translated
|
# media names are not translated
|
||||||
with open(f"{cwd}/piles/english/pile_of_media_names.txt", encoding="utf8") as f:
|
with open(f"{cwd}/piles/english/pile_of_media_names.txt", encoding="utf8") as f:
|
||||||
@@ -102,14 +155,15 @@ class DatasetPiles:
|
|||||||
with open(f"{cwd}/piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
|
with open(f"{cwd}/piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
|
||||||
self.pile_of_todo_items = [ x.strip() for x in f.readlines() ]
|
self.pile_of_todo_items = [ x.strip() for x in f.readlines() ]
|
||||||
|
|
||||||
self.stacks_of_device_names = { x: [] for x in supported_devices }
|
self.stacks_of_device_names: dict[str, list[DeviceType]] = { x: [] for x in supported_devices }
|
||||||
with open(f"{cwd}/piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
|
with open(f"{cwd}/piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
pile_of_device_names = list(reader)
|
pile_of_device_names = list(reader)
|
||||||
for device_dict in pile_of_device_names:
|
for device_dict in pile_of_device_names:
|
||||||
try:
|
try:
|
||||||
device_type = device_dict["device_name"].split(".")[0]
|
device_type = device_dict["device_name"].split(".")[0]
|
||||||
self.stacks_of_device_names[device_type].append(device_dict)
|
device_dict["type"] = device_type
|
||||||
|
self.stacks_of_device_names[device_type].append(device_dict) # type: ignore
|
||||||
except KeyError as ex:
|
except KeyError as ex:
|
||||||
print(ex)
|
print(ex)
|
||||||
|
|
||||||
@@ -125,41 +179,41 @@ class DatasetPiles:
|
|||||||
for x in range(multiplier):
|
for x in range(multiplier):
|
||||||
processed_pile_of_templated_actions.append(action)
|
processed_pile_of_templated_actions.append(action)
|
||||||
|
|
||||||
self.pile_of_templated_actions = processed_pile_of_templated_actions
|
self.pile_of_templated_actions: list[TemplatedActionType] = processed_pile_of_templated_actions
|
||||||
|
|
||||||
with open(f"{cwd}/piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
|
with open(f"{cwd}/piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
self.pile_of_specific_actions = list(reader)
|
self.pile_of_specific_actions: list[SpecificActionType] = list(reader) # type: ignore
|
||||||
|
|
||||||
self.pile_of_responses = pandas.read_csv(f"{cwd}/piles/{language}/pile_of_responses.csv")
|
self.pile_of_responses = pandas.read_csv(f"{cwd}/piles/{language}/pile_of_responses.csv")
|
||||||
self.pile_of_responses["contains_vars"] = self.pile_of_responses["response_starting"].apply(get_included_vars)
|
self.pile_of_responses["contains_vars"] = self.pile_of_responses["response_starting"].apply(get_included_vars)
|
||||||
|
|
||||||
with open(f"{cwd}/piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
|
with open(f"{cwd}/piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
self.pile_of_status_requests = list(reader)
|
self.pile_of_status_requests: list[StatusRequestType] = list(reader) # type: ignore
|
||||||
|
|
||||||
with open(f"{cwd}/piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
|
with open(f"{cwd}/piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
self.pile_of_system_prompts = { line["persona"]: line["prompt"] for line in reader }
|
self.pile_of_system_prompts: dict[str, str] = { line["persona"]: line["prompt"] for line in reader }
|
||||||
|
|
||||||
# service names are not translated
|
# service names are not translated
|
||||||
with open(f"{cwd}/piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f:
|
with open(f"{cwd}/piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
self.pile_of_hallucinated_service_names = list(reader)
|
self.pile_of_hallucinated_service_names: list[HallucinatedServiceType] = list(reader) # type: ignore
|
||||||
|
|
||||||
failed_tool_calls_path = f"{cwd}/piles/{language}/pile_of_failed_tool_calls.csv"
|
failed_tool_calls_path = f"{cwd}/piles/{language}/pile_of_failed_tool_calls.csv"
|
||||||
self.pile_of_failed_tool_calls = []
|
self.pile_of_failed_tool_calls = []
|
||||||
if os.path.exists(failed_tool_calls_path):
|
if os.path.exists(failed_tool_calls_path):
|
||||||
with open(failed_tool_calls_path, encoding="utf8") as f:
|
with open(failed_tool_calls_path, encoding="utf8") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
self.pile_of_failed_tool_calls = list(reader)
|
self.pile_of_failed_tool_calls: list[FailedToolcallType] = list(reader) # type: ignore
|
||||||
|
|
||||||
refusals_path = f"{cwd}/piles/{language}/pile_of_refusals.csv"
|
refusals_path = f"{cwd}/piles/{language}/pile_of_refusals.csv"
|
||||||
self.pile_of_refusals = []
|
self.pile_of_refusals = []
|
||||||
if os.path.exists(refusals_path):
|
if os.path.exists(refusals_path):
|
||||||
with open(refusals_path, encoding="utf8") as f:
|
with open(refusals_path, encoding="utf8") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
self.pile_of_refusals = list(reader)
|
self.pile_of_refusals: list[RefusalsType] = list(reader) # type: ignore
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
return getattr(self, key)
|
return getattr(self, key)
|
||||||
|
|||||||
Reference in New Issue
Block a user