mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 05:14:02 -05:00
Merge pull request #337 from evschlenz/annotate_dataset_generation_scripts
Annotate dataset generation scripts
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
from collections import defaultdict
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Final, Callable, List
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
from tools import *
|
||||
from utils import closest_color, generate_random_parameter, get_dataset_piles
|
||||
from utils import PileOfDeviceType, closest_color, generate_random_parameter, get_dataset_piles
|
||||
|
||||
# STATES
|
||||
STATE_ON: Final = "on"
|
||||
@@ -40,7 +41,7 @@ class DeviceType:
|
||||
name: str
|
||||
possible_states: list[tuple[str, float]]
|
||||
|
||||
def get_random_state(self, language: str, extra_exposed_attributes=[]):
|
||||
def get_random_state(self, language: str, extra_exposed_attributes: list[str] | None = None):
|
||||
states = [ x[0] for x in self.possible_states ]
|
||||
weights = [ x[1] for x in self.possible_states ]
|
||||
return random.choices(states, weights=weights, k=1)[0]
|
||||
@@ -64,7 +65,8 @@ class LightDeviceType(DeviceType):
|
||||
]
|
||||
)
|
||||
|
||||
def get_random_state(self, language: str, extra_exposed_attributes=[]):
|
||||
def get_random_state(self, language: str, extra_exposed_attributes: list[str] | None = None):
|
||||
extra_exposed_attributes = extra_exposed_attributes or []
|
||||
state = super().get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
|
||||
|
||||
if random.random() < 0.5 and "rgb_color" in extra_exposed_attributes:
|
||||
@@ -171,8 +173,9 @@ class ClimateDeviceType(DeviceType):
|
||||
def __init__(self):
|
||||
super().__init__("climate", [])
|
||||
|
||||
def get_random_state(self, language: str, extra_exposed_attributes=[]):
|
||||
def get_random_state(self, language: str, extra_exposed_attributes: list[str] | None = None):
|
||||
"""state;fan_mode;temperature;humidity"""
|
||||
extra_exposed_attributes = extra_exposed_attributes or []
|
||||
state = generate_random_parameter("hvac_mode", get_dataset_piles(language))
|
||||
|
||||
if "fan_mode" in extra_exposed_attributes:
|
||||
@@ -211,7 +214,8 @@ class MediaPlayerDeviceType(DeviceType):
|
||||
(STATE_BUFFERING, 0.01),
|
||||
])
|
||||
|
||||
def get_random_state(self, language: str, extra_exposed_attributes=[]):
|
||||
def get_random_state(self, language: str, extra_exposed_attributes: list[str] | None = None):
|
||||
extra_exposed_attributes = extra_exposed_attributes or []
|
||||
state = super().get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
|
||||
|
||||
if "media_title" in extra_exposed_attributes and state in [STATE_PLAYING, STATE_PAUSED, STATE_BUFFERING, STATE_ON]:
|
||||
@@ -228,7 +232,7 @@ class MediaPlayerDeviceType(DeviceType):
|
||||
return tools
|
||||
|
||||
|
||||
SUPPORTED_DEVICES = {
|
||||
SUPPORTED_DEVICES: dict[str, DeviceType] = {
|
||||
"light": LightDeviceType(),
|
||||
"switch": SwitchDeviceType(),
|
||||
"fan": FanDeviceType(),
|
||||
@@ -264,14 +268,14 @@ def random_device_list(max_devices: int, avoid_device_names: list[str], language
|
||||
if avoid_type == "climate":
|
||||
avoid_climate = True
|
||||
|
||||
possible_choices = []
|
||||
possible_choices: list[PileOfDeviceType] = []
|
||||
for device_type in local_device_names.keys():
|
||||
possible_choices.extend(local_device_names[device_type])
|
||||
|
||||
|
||||
device_types = set()
|
||||
device_types: set[str] = set()
|
||||
device_list = []
|
||||
device_lines = []
|
||||
device_lines: list[str] = []
|
||||
# TODO: randomly pick attributes for this list
|
||||
extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "duration", "remaining", "item"]
|
||||
|
||||
@@ -301,4 +305,4 @@ def random_device_list(max_devices: int, avoid_device_names: list[str], language
|
||||
print(f"bad device name: {choice}")
|
||||
print(repr(ex))
|
||||
|
||||
return device_lines, list(device_types), list(extra_exposed_attributes)
|
||||
return device_lines, list(device_types), list(extra_exposed_attributes)
|
||||
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import numpy as np
|
||||
import random
|
||||
from datasets import load_dataset, concatenate_datasets
|
||||
from typing import Callable
|
||||
from typing import Any, Callable
|
||||
from tqdm import tqdm
|
||||
import webcolors
|
||||
|
||||
@@ -17,10 +17,19 @@ from devices import SUPPORTED_DEVICES, format_device_line, random_device_list, \
|
||||
TOOL_LIGHT_SET, TOOL_START_TIMER, TOOL_LIST_ADD_ITEM, SERVICE_TO_TOOL_MAP, \
|
||||
HASS_TOOLS, SERVICE_TOOLS
|
||||
from prompting import generate_system_prompt, USER_INSTRUCTION_PROMPT
|
||||
from utils import get_random_response, generate_random_parameter, closest_color, \
|
||||
from utils import AssistantTurn, DatasetEntry, Example, PileOfDeviceType, PileOfFailedToolcallType, PileOfRefusalsType, PileOfSpecificActionType, PileOfStatusRequestType, PileOfTemplatedActionType, ToolCall, ToolResult, get_random_response, generate_random_parameter, closest_color, \
|
||||
get_dataset_piles, NoResponseAvailableException
|
||||
|
||||
def generate_static_example(action: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
|
||||
def create_assistant_turn(
    answer: str,
    tool_call_sequence: list[ToolCall] | None = None,
    *,
    tool_results: list[ToolResult] | None = None,
    train_on_turn: bool = True,
) -> AssistantTurn:
    """Bundle the assistant utterance with any tool interaction for that turn.

    Args:
        answer: Text of the assistant reply for this turn.
        tool_call_sequence: Tool calls issued in this turn. ``None`` (or any
            other falsy value) is stored as a new empty list.
        tool_results: Results paired with the tool calls. Only ``None`` is
            replaced by a new empty list; any other value is stored as given.
        train_on_turn: Flag stored unchanged under the ``"train_on_turn"`` key.

    Returns:
        A dict with the keys ``answer``, ``tool_call_sequence``,
        ``tool_results`` and ``train_on_turn``.
    """
    return {
        "answer": answer,
        # Falsy input (None or an empty list) collapses to a fresh list.
        "tool_call_sequence": tool_call_sequence or [],
        # Only None is defaulted here, so a caller-supplied list is kept as is.
        "tool_results": tool_results if tool_results is not None else [],
        "train_on_turn": train_on_turn,
    }
|
||||
|
||||
def generate_static_example(action: PileOfSpecificActionType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example:
|
||||
question = action["phrase"]
|
||||
service_name = action["service_name"]
|
||||
device_type = service_name.split(".")[0]
|
||||
@@ -42,7 +51,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
|
||||
))
|
||||
|
||||
# gather a list of all available tools
|
||||
available_tools = []
|
||||
available_tools: list[str] = []
|
||||
for x in set(device_types + [device_type]):
|
||||
available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))
|
||||
|
||||
@@ -130,13 +139,13 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
|
||||
tool_args["item"] = todo
|
||||
|
||||
if use_service_names:
|
||||
tool_call = {
|
||||
tool_call: ToolCall = {
|
||||
"tool_name": tool_name,
|
||||
"service_name": service_name,
|
||||
"tool_args": {"entity_id": target_device, **tool_args}
|
||||
}
|
||||
else:
|
||||
tool_call = {
|
||||
tool_call: ToolCall = {
|
||||
"tool_name": tool_name,
|
||||
"service_name": service_name,
|
||||
"tool_args": {"name": target_device, **tool_args}
|
||||
@@ -163,33 +172,22 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
|
||||
"assistant_turns": assistant_turns
|
||||
}
|
||||
|
||||
def replace_answer(list_of_answer: list[str], var: str, value: str) -> list[str]:
    """Return a new list in which ``var`` is replaced by ``value`` in every answer.

    Args:
        list_of_answer: Answer strings to process; the list is not modified.
        var: Substring to search for (every occurrence is replaced).
        value: Replacement text.

    Returns:
        A new list with the same length and order as ``list_of_answer``.
    """
    return [answer.replace(var, value) for answer in list_of_answer]
|
||||
|
||||
|
||||
def create_assistant_turn(answer: str, tool_call_sequence=None, *, tool_results=None, train_on_turn: bool = True):
    """Bundle the assistant utterance with any tool interaction for that turn.

    Args:
        answer: Text of the assistant reply for this turn.
        tool_call_sequence: Tool calls issued in this turn. ``None`` (or any
            other falsy value) is stored as a new empty list.
        tool_results: Results paired with the tool calls. Only ``None`` is
            replaced by a new empty list; any other value is stored as given.
        train_on_turn: Flag stored unchanged under the ``"train_on_turn"`` key.

    Returns:
        A dict with the keys ``answer``, ``tool_call_sequence``,
        ``tool_results`` and ``train_on_turn``.
    """
    return {
        "answer": answer,
        # Falsy input (None or an empty list) collapses to a fresh list.
        "tool_call_sequence": tool_call_sequence or [],
        # Only None is defaulted here, so a caller-supplied list is kept as is.
        "tool_results": tool_results if tool_results is not None else [],
        "train_on_turn": train_on_turn,
    }
|
||||
|
||||
def generate_templated_example(template: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
|
||||
def generate_templated_example(template: PileOfTemplatedActionType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example:
|
||||
template_device_types: list[str] = template["device_type"].split("|")
|
||||
service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ]
|
||||
question_template: str = template["phrase"]
|
||||
piles = get_dataset_piles(language)
|
||||
|
||||
# choose a random device for this template
|
||||
chosen_devices = []
|
||||
chosen_devices: list[PileOfDeviceType] = []
|
||||
for device_type in template_device_types:
|
||||
device_dict = random.choice(piles.stacks_of_device_names[device_type])
|
||||
device_dict["type"] = device_type
|
||||
chosen_devices.append(device_dict)
|
||||
|
||||
device_list, device_types, extra_exposed_attributes = random_device_list(
|
||||
@@ -223,7 +221,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
|
||||
))
|
||||
|
||||
# gather a list of all available tools
|
||||
available_tools = []
|
||||
available_tools: list[str] = []
|
||||
for x in set(device_types + template_device_types):
|
||||
available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))
|
||||
|
||||
@@ -264,11 +262,11 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
|
||||
answer_list.append(f" {word} ".join(answers))
|
||||
|
||||
# generate the list of tool calls
|
||||
tool_calls = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
for device_dict, service in zip(chosen_devices, service_names):
|
||||
service_action = service.split(".")[1]
|
||||
tool_name = SERVICE_TO_TOOL_MAP[service_action]
|
||||
tool_call = {
|
||||
tool_call: ToolCall = {
|
||||
"tool_name": tool_name,
|
||||
"service_name": service,
|
||||
"tool_args": {"entity_id" if use_service_names else "name": device_dict["device_name"] if use_service_names else device_dict["description"]}
|
||||
@@ -369,7 +367,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
|
||||
"assistant_turns": assistant_turns
|
||||
}
|
||||
|
||||
def generate_status_request(template: dict, persona: str, language: str, max_devices: int = 128, return_target_device: bool = False, use_service_names: bool = False):
|
||||
def generate_status_request(template: PileOfStatusRequestType, persona: str, language: str, max_devices: int = 128, return_target_device: bool = False, use_service_names: bool = False) -> Example | tuple[Example, PileOfDeviceType]:
|
||||
device_type: str = template["device_type"]
|
||||
state_name: str = template["state"]
|
||||
question_template: str = template["phrase"]
|
||||
@@ -447,7 +445,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
|
||||
))
|
||||
|
||||
# gather a list of all available tools
|
||||
available_tools = []
|
||||
available_tools: list[str] = []
|
||||
for x in set(device_types + [device_type]):
|
||||
available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))
|
||||
|
||||
@@ -456,7 +454,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
|
||||
|
||||
assistant_turns = [create_assistant_turn(answer.lower(), [])]
|
||||
|
||||
result = {
|
||||
result: Example = {
|
||||
"states": device_list,
|
||||
"available_tools": available_tools,
|
||||
"question": question.lower(),
|
||||
@@ -467,7 +465,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
|
||||
else:
|
||||
return result
|
||||
|
||||
def generate_tool_failure_example(failure_case: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
|
||||
def generate_tool_failure_example(failure_case: PileOfFailedToolcallType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example:
|
||||
piles = get_dataset_piles(language)
|
||||
service_name = failure_case["service_name"]
|
||||
device_type = service_name.split(".")[0]
|
||||
@@ -491,7 +489,7 @@ def generate_tool_failure_example(failure_case: dict, persona: str, language: st
|
||||
if device_type not in device_types:
|
||||
device_types.append(device_type)
|
||||
|
||||
available_tools = []
|
||||
available_tools: list[str] = []
|
||||
for x in set(device_types):
|
||||
available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))
|
||||
available_tools = list(dict.fromkeys(available_tools))
|
||||
@@ -506,7 +504,7 @@ def generate_tool_failure_example(failure_case: dict, persona: str, language: st
|
||||
response_starting = response_starting.replace("<device_name>", friendly_name)
|
||||
response_confirmed = response_confirmed.replace("<device_name>", friendly_name)
|
||||
|
||||
tool_args_extra = {}
|
||||
tool_args_extra: dict[str, Any] = {}
|
||||
if device_type == "climate":
|
||||
if "<temp_f>" in question or "<temp_f>" in response_starting or "<temp_f>" in response_confirmed:
|
||||
temp_f = generate_random_parameter("temp_f", piles)
|
||||
@@ -562,7 +560,7 @@ def generate_tool_failure_example(failure_case: dict, persona: str, language: st
|
||||
"assistant_turns": [first_turn, second_turn, final_turn]
|
||||
}
|
||||
|
||||
def generate_refusal_example(refusal_case: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
|
||||
def generate_refusal_example(refusal_case: PileOfRefusalsType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example:
|
||||
service_name = refusal_case["service_name"]
|
||||
device_type = service_name.split(".")[0]
|
||||
target_device = f"{device_type}.{refusal_case['device_name']}"
|
||||
@@ -583,7 +581,7 @@ def generate_refusal_example(refusal_case: dict, persona: str, language: str, ma
|
||||
if device_type not in device_types:
|
||||
device_types.append(device_type)
|
||||
|
||||
available_tools = []
|
||||
available_tools: list[str] = []
|
||||
for x in set(device_types):
|
||||
available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))
|
||||
available_tools = list(dict.fromkeys(available_tools))
|
||||
@@ -600,7 +598,7 @@ def generate_refusal_example(refusal_case: dict, persona: str, language: str, ma
|
||||
"assistant_turns": assistant_turns
|
||||
}
|
||||
|
||||
def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format):
|
||||
def format_example_sharegpt(example: Example, persona: str, language: str, use_system_role: bool, append_user_instruction_prompt: bool, use_service_names: bool, tool_response_format: str) -> DatasetEntry:
|
||||
piles = get_dataset_piles(language)
|
||||
sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
|
||||
question = example["question"]
|
||||
@@ -700,7 +698,7 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
|
||||
def generate_sft_file(
|
||||
filename: str,
|
||||
seed: int,
|
||||
format_func: Callable,
|
||||
format_func: Callable[[Example, str, str, bool, bool, bool, str], DatasetEntry],
|
||||
use_system_role: bool,
|
||||
append_user_instruction_prompt: bool,
|
||||
use_service_names: bool,
|
||||
@@ -719,15 +717,15 @@ def generate_sft_file(
|
||||
|
||||
print("Generating...")
|
||||
|
||||
def run_factor_times(func, examples, data, persona, factor, language):
|
||||
def run_factor_times(func: Callable[..., Example], examples: list[DatasetEntry], data, persona: str, factor: int | float, language: str):
|
||||
if factor >= 1:
|
||||
for i in range(factor):
|
||||
for i in range(int(factor)):
|
||||
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format))
|
||||
else:
|
||||
if random.random() < factor:
|
||||
examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format))
|
||||
|
||||
generated_examples = []
|
||||
generated_examples: list[DatasetEntry] = []
|
||||
|
||||
missing_responses = set()
|
||||
|
||||
@@ -774,7 +772,7 @@ def generate_sft_file(
|
||||
|
||||
def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset_column_names, format_func):
|
||||
alpaca_dataset = load_dataset(dataset_name)["train"].train_test_split(test_size=0.1)
|
||||
home_assistant_dataset = load_dataset("json", data_files={ "train": "home_assistant_train.jsonl", "test": "home_assistant_test.jsonl" })
|
||||
home_assistant_dataset = load_dataset("json", data_files={ "train": "home_assistant_train.jsonl", "test": "home_assistant_test.jsonl" })
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import babel.dates
|
||||
|
||||
from utils import generate_random_datetime
|
||||
from utils import Example, generate_random_datetime
|
||||
|
||||
CURRENT_DATE_PROMPT = {
|
||||
"english": "The current time and date is",
|
||||
@@ -51,7 +51,7 @@ USER_INSTRUCTION_PROMPT = {
|
||||
}
|
||||
|
||||
|
||||
def generate_system_prompt(example: dict, persona: str, language: str, pile_of_system_prompts: dict) -> str:
|
||||
def generate_system_prompt(example: Example, persona: str, language: str, pile_of_system_prompts: dict[str, str]) -> str:
|
||||
sys_prompt = pile_of_system_prompts[persona]
|
||||
random_datetime = generate_random_datetime()
|
||||
translate_datetime = babel.dates.format_datetime(random_datetime, BABEL_FORMAT[language], locale=BABEL_LOCALE[language])
|
||||
|
||||
112
data/utils.py
112
data/utils.py
@@ -2,6 +2,7 @@ import random
|
||||
import re
|
||||
import os
|
||||
import csv
|
||||
from typing import Any, TypedDict
|
||||
import pandas
|
||||
from datetime import datetime, timedelta
|
||||
import webcolors
|
||||
@@ -13,8 +14,8 @@ class NoServicesAvailableException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def closest_color(requested_color):
|
||||
min_colors = {}
|
||||
def closest_color(requested_color: tuple[int, int, int]):
|
||||
min_colors: dict[int, str] = {}
|
||||
color_names = webcolors.names("css3")
|
||||
|
||||
for name in color_names:
|
||||
@@ -44,7 +45,7 @@ def get_included_vars(response: str):
|
||||
|
||||
return ",".join(sorted(result))
|
||||
|
||||
def generate_random_parameter(param_name, piles_of_data):
|
||||
def generate_random_parameter(param_name: str, piles_of_data: "DatasetPiles"):
|
||||
RANDOM_PARAMETER_GENERATORS = {
|
||||
"rgb_color": lambda: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
|
||||
"brightness": lambda: random.randint(0, 100),
|
||||
@@ -67,7 +68,7 @@ def generate_random_parameter(param_name, piles_of_data):
|
||||
|
||||
return param_generator()
|
||||
|
||||
def get_random_response(pile_of_responses, *, service: str, persona: str, question_template: str, short: bool) -> tuple[str, str]:
|
||||
def get_random_response(pile_of_responses: pandas.DataFrame, *, service: str, persona: str, question_template: str, short: bool) -> tuple[str, str]:
|
||||
|
||||
required_vars = list(set([var for var in var_pattern.findall(question_template) if "device_name" not in var]))
|
||||
|
||||
@@ -82,6 +83,58 @@ def get_random_response(pile_of_responses, *, service: str, persona: str, questi
|
||||
|
||||
return possible_results.sample()["response_starting"].values[0], possible_results.sample()["response_confirmed"].values[0]
|
||||
|
||||
|
||||
class PileOfDeviceType(TypedDict):
    """One device entry as held in ``DatasetPiles.stacks_of_device_names``.

    Entries are rows of ``pile_of_device_names.csv``; the ``type`` key is
    filled in from the prefix of ``device_name`` while the pile is loaded.
    """

    # Identifier whose part before the first "." is used as the device type;
    # passed as the "entity_id" tool argument when service names are in use.
    device_name: str
    # Text passed as the "name" tool argument when service names are not in use.
    description: str
    # Device type key under which the entry is stored in stacks_of_device_names.
    type: str
|
||||
|
||||
|
||||
class PileOfSpecificActionType(TypedDict):
    """One row of ``pile_of_specific_actions.csv``.

    Passed as ``action`` to ``generate_static_example``.
    """

    # Dotted service name; the part before the first "." is read as the device type.
    service_name: str
    # Device the action targets (NOTE(review): exact format not visible here).
    device_name: str
    # Text used as the user question for the example.
    phrase: str
|
||||
|
||||
|
||||
class PileOfTemplatedActionType(TypedDict):
    """One templated action, passed as ``template`` to ``generate_templated_example``."""

    # One or more device types separated by "|".
    device_type: str
    # One or more services separated by "|"; paired positionally with
    # device_type to build "<device_type>.<service>" names.
    service: str
    # Question template for the example.
    phrase: str
    # NOTE(review): presumably the number of times the row is repeated when the
    # pile is loaded — confirm against DatasetPiles.__init__.
    multiplier: int
|
||||
|
||||
|
||||
class PileOfStatusRequestType(TypedDict):
    """One row of ``pile_of_status_requests.csv``.

    Passed as ``template`` to ``generate_status_request``.
    """

    # Device type the status question is about.
    device_type: str
    # State name the request refers to.
    state: str
    # Question template for the example.
    phrase: str
    # NOTE(review): presumably the answer template for the assistant — usage
    # is not visible here, confirm in generate_status_request.
    assistant_response: str
|
||||
|
||||
|
||||
class PileOfHallucinatedServiceType(TypedDict):
    """One row of ``piles/english/pile_of_hallucinated_service_names.csv``.

    The file is always read from the English pile because service names are
    not translated.
    """

    # Name of the service as it actually exists.
    real_service: str
    # Incorrect service name associated with real_service.
    hallucinated_service: str
|
||||
|
||||
|
||||
class PileOfFailedToolcallType(TypedDict):
    """One row of ``pile_of_failed_tool_calls.csv``.

    Passed as ``failure_case`` to ``generate_tool_failure_example``.
    """

    # Dotted service name; the part before the first "." is read as the device type.
    service_name: str
    # NOTE(review): the field meanings below are inferred from their names only;
    # their use lies outside the visible code — confirm in
    # generate_tool_failure_example.
    correct_device_name: str
    correct_friendly_name: str
    bad_device_name: str
    # User request text.
    phrase: str
    error_result: str
    retry_prompt: str
|
||||
|
||||
|
||||
class PileOfRefusalsType(TypedDict):
    """One row of ``pile_of_refusals.csv``.

    Passed as ``refusal_case`` to ``generate_refusal_example``.
    """

    # NOTE(review): category of the refusal; allowed values not visible here.
    reason_type: str
    # Dotted service name; the part before the first "." is read as the device type.
    service_name: str
    # Device name without the type prefix; the target entity is built as
    # "<device_type>.<device_name>".
    device_name: str
    friendly_name: str
    desired_state: str
    # User request text.
    phrase: str
    # NOTE(review): presumably the assistant's refusal text — confirm in
    # generate_refusal_example.
    response: str
|
||||
|
||||
|
||||
class DatasetPiles:
|
||||
def __init__(self, supported_devices, language="english"):
|
||||
self.language = language
|
||||
@@ -93,7 +146,7 @@ class DatasetPiles:
|
||||
|
||||
with open(f"{cwd}/piles/{language}/pile_of_durations.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_durations = { x["duration"]: x["name"] for x in reader }
|
||||
self.pile_of_durations: dict[str, str] = { x["duration"]: x["name"] for x in reader }
|
||||
|
||||
# media names are not translated
|
||||
with open(f"{cwd}/piles/english/pile_of_media_names.txt", encoding="utf8") as f:
|
||||
@@ -102,14 +155,15 @@ class DatasetPiles:
|
||||
with open(f"{cwd}/piles/{language}/pile_of_todo_items.txt", encoding="utf8") as f:
|
||||
self.pile_of_todo_items = [ x.strip() for x in f.readlines() ]
|
||||
|
||||
self.stacks_of_device_names = { x: [] for x in supported_devices }
|
||||
self.stacks_of_device_names: dict[str, list[PileOfDeviceType]] = { x: [] for x in supported_devices }
|
||||
with open(f"{cwd}/piles/{language}/pile_of_device_names.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
pile_of_device_names = list(reader)
|
||||
for device_dict in pile_of_device_names:
|
||||
try:
|
||||
device_type = device_dict["device_name"].split(".")[0]
|
||||
self.stacks_of_device_names[device_type].append(device_dict)
|
||||
device_dict["type"] = device_type
|
||||
self.stacks_of_device_names[device_type].append(device_dict) # type: ignore
|
||||
except KeyError as ex:
|
||||
print(ex)
|
||||
|
||||
@@ -125,41 +179,41 @@ class DatasetPiles:
|
||||
for x in range(multiplier):
|
||||
processed_pile_of_templated_actions.append(action)
|
||||
|
||||
self.pile_of_templated_actions = processed_pile_of_templated_actions
|
||||
self.pile_of_templated_actions: list[PileOfTemplatedActionType] = processed_pile_of_templated_actions
|
||||
|
||||
with open(f"{cwd}/piles/{language}/pile_of_specific_actions.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_specific_actions = list(reader)
|
||||
self.pile_of_specific_actions: list[PileOfSpecificActionType] = list(reader) # type: ignore
|
||||
|
||||
self.pile_of_responses = pandas.read_csv(f"{cwd}/piles/{language}/pile_of_responses.csv")
|
||||
self.pile_of_responses["contains_vars"] = self.pile_of_responses["response_starting"].apply(get_included_vars)
|
||||
|
||||
with open(f"{cwd}/piles/{language}/pile_of_status_requests.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_status_requests = list(reader)
|
||||
self.pile_of_status_requests: list[PileOfStatusRequestType] = list(reader) # type: ignore
|
||||
|
||||
with open(f"{cwd}/piles/{language}/pile_of_system_prompts.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_system_prompts = { line["persona"]: line["prompt"] for line in reader }
|
||||
self.pile_of_system_prompts: dict[str, str] = { line["persona"]: line["prompt"] for line in reader }
|
||||
|
||||
# service names are not translated
|
||||
with open(f"{cwd}/piles/english/pile_of_hallucinated_service_names.csv", encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_hallucinated_service_names = list(reader)
|
||||
self.pile_of_hallucinated_service_names: list[PileOfHallucinatedServiceType] = list(reader) # type: ignore
|
||||
|
||||
failed_tool_calls_path = f"{cwd}/piles/{language}/pile_of_failed_tool_calls.csv"
|
||||
self.pile_of_failed_tool_calls = []
|
||||
if os.path.exists(failed_tool_calls_path):
|
||||
with open(failed_tool_calls_path, encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_failed_tool_calls = list(reader)
|
||||
self.pile_of_failed_tool_calls: list[PileOfFailedToolcallType] = list(reader) # type: ignore
|
||||
|
||||
refusals_path = f"{cwd}/piles/{language}/pile_of_refusals.csv"
|
||||
self.pile_of_refusals = []
|
||||
if os.path.exists(refusals_path):
|
||||
with open(refusals_path, encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_refusals = list(reader)
|
||||
self.pile_of_refusals: list[PileOfRefusalsType] = list(reader) # type: ignore
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
@@ -173,3 +227,33 @@ def get_dataset_piles(language: str) -> DatasetPiles:
|
||||
"lock","media_player", "climate", "vacuum", "timer", "todo",
|
||||
], language)
|
||||
return _piles_cache[language]
|
||||
|
||||
|
||||
|
||||
class ToolCall(TypedDict):
    """A single tool invocation recorded in an assistant turn."""

    # Name of the tool being called.
    tool_name: str
    # Dotted "<device_type>.<service>" name the call corresponds to.
    service_name: str
    # Call arguments; the generators put the target under "entity_id" or
    # "name" and merge any extra arguments alongside it.
    tool_args: dict[str, Any]
|
||||
|
||||
|
||||
class ToolResult(TypedDict):
    """Outcome of one tool call, stored in ``AssistantTurn.tool_results``."""

    # Name of the tool the result belongs to.
    tool_name: str
    # Result text for the call.
    tool_result: str
|
||||
|
||||
class AssistantTurn(TypedDict):
    """One assistant turn: the reply text plus any tool interaction."""

    # Text of the assistant reply.
    answer: str
    # Tool calls issued in this turn (may be empty).
    tool_call_sequence: list[ToolCall]
    # Results for the tool calls (may be empty).
    tool_results: list[ToolResult]
    # NOTE(review): presumably marks whether the turn contributes to the
    # training loss — confirm where the flag is consumed.
    train_on_turn: bool
|
||||
|
||||
|
||||
class Example(TypedDict):
    """A generated example before it is formatted into a dataset entry.

    Returned by the ``generate_*`` example builders and consumed by
    ``format_example_sharegpt`` and ``generate_system_prompt``.
    """

    # Device lines describing the exposed devices, as produced by random_device_list.
    states: list[str]
    # Tools collected from every device type present in the example.
    available_tools: list[str]
    # User question text.
    question: str
    # Assistant turns answering the question, in order.
    assistant_turns: list[AssistantTurn]
|
||||
|
||||
|
||||
class DatasetEntry(TypedDict):
    """A formatted dataset record, as returned by ``format_example_sharegpt``."""

    # Conversation messages (NOTE(review): per-message schema not visible here).
    messages: list[dict[str, Any]]
    # Tool definitions attached to the record (NOTE(review): schema not visible here).
    tools: list[dict[str, Any]]
|
||||
|
||||
Reference in New Issue
Block a user