mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 05:14:02 -05:00
correctly type example generation functions
This commit is contained in:
@@ -3,7 +3,7 @@ import json
|
||||
import numpy as np
|
||||
import random
|
||||
from datasets import load_dataset, concatenate_datasets
|
||||
from typing import Callable
|
||||
from typing import Any, Callable, TypedDict
|
||||
from tqdm import tqdm
|
||||
import webcolors
|
||||
|
||||
@@ -17,10 +17,44 @@ from devices import SUPPORTED_DEVICES, format_device_line, random_device_list, \
|
||||
TOOL_LIGHT_SET, TOOL_START_TIMER, TOOL_LIST_ADD_ITEM, SERVICE_TO_TOOL_MAP, \
|
||||
HASS_TOOLS, SERVICE_TOOLS
|
||||
from prompting import generate_system_prompt, USER_INSTRUCTION_PROMPT
|
||||
from utils import get_random_response, generate_random_parameter, closest_color, \
|
||||
from utils import PileOfDeviceType, PileOfFailedToolcallType, PileOfRefusalsType, PileOfSpecificActionType, PileOfStatusRequestType, PileOfTemplatedActionType, PileOfType, get_random_response, generate_random_parameter, closest_color, \
|
||||
get_dataset_piles, NoResponseAvailableException
|
||||
|
||||
def generate_static_example(action: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
|
||||
|
||||
class ToolCall(TypedDict):
    """One tool invocation attached to an assistant turn."""

    # Name of the tool being invoked.
    tool_name: str
    # Service identifier for the call. Callers in this file split it on "."
    # to recover the device type, so it is expected to look like
    # "<device_type>.<action>".
    service_name: str
    # Arguments passed to the tool. Callers in this file key the target
    # device under either "entity_id" or "name".
    tool_args: dict[str, Any]
|
||||
|
||||
|
||||
class ToolResult(TypedDict):
    """The outcome reported for one tool invocation."""

    # Name of the tool the result belongs to.
    tool_name: str
    # Result payload as a string.
    # NOTE(review): the exact format (plain text vs. serialized JSON) is not
    # visible in this block - confirm against the code that builds results.
    tool_result: str
|
||||
|
||||
class AssistantTurn(TypedDict):
    """One assistant turn: the spoken answer plus its tool interaction."""

    # Text of the assistant's reply for this turn.
    answer: str
    # Tool calls issued in this turn (empty when no tool is used).
    tool_call_sequence: list[ToolCall]
    # Results returned for tool calls (empty when there are none).
    tool_results: list[ToolResult]
    # Flag carried with the turn.
    # NOTE(review): presumably tells the formatter whether this turn should
    # contribute to training - confirm against the consumer.
    train_on_turn: bool
|
||||
|
||||
|
||||
class Example(TypedDict):
    """A generated training example as returned by the generate_* functions."""

    # Device state lines presented alongside the question.
    states: list[str]
    # Tools made available for this example.
    available_tools: list[str]
    # The user's request.
    question: str
    # Assistant turns that answer the request, in order.
    assistant_turns: list[AssistantTurn]
|
||||
|
||||
|
||||
def create_assistant_turn(
    answer: str,
    tool_call_sequence: list[ToolCall] | None = None,
    *,
    tool_results: list[ToolResult] | None = None,
    train_on_turn: bool = True,
) -> AssistantTurn:
    """Bundle the assistant utterance with any tool interaction for that turn.

    Args:
        answer: Text of the assistant's reply.
        tool_call_sequence: Tool calls issued in this turn. ``None`` (or any
            falsy value, such as an empty list) yields a fresh empty list.
        tool_results: Results for the tool calls. Only ``None`` is replaced
            by a fresh empty list; any other value is stored as given.
        train_on_turn: Flag stored on the turn unchanged.

    Returns:
        A dict with the keys ``answer``, ``tool_call_sequence``,
        ``tool_results`` and ``train_on_turn``.
    """
    return {
        "answer": answer,
        # `or` deliberately maps both None and an empty list to a new list,
        # so the returned turn never aliases a caller's empty list.
        "tool_call_sequence": tool_call_sequence or [],
        "tool_results": tool_results if tool_results is not None else [],
        "train_on_turn": train_on_turn,
    }
|
||||
|
||||
def generate_static_example(action: PileOfSpecificActionType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example:
|
||||
question = action["phrase"]
|
||||
service_name = action["service_name"]
|
||||
device_type = service_name.split(".")[0]
|
||||
@@ -42,7 +76,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
|
||||
))
|
||||
|
||||
# gather a list of all available tools
|
||||
available_tools = []
|
||||
available_tools: list[str] = []
|
||||
for x in set(device_types + [device_type]):
|
||||
available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))
|
||||
|
||||
@@ -130,13 +164,13 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
|
||||
tool_args["item"] = todo
|
||||
|
||||
if use_service_names:
|
||||
tool_call = {
|
||||
tool_call: ToolCall = {
|
||||
"tool_name": tool_name,
|
||||
"service_name": service_name,
|
||||
"tool_args": {"entity_id": target_device, **tool_args}
|
||||
}
|
||||
else:
|
||||
tool_call = {
|
||||
tool_call: ToolCall = {
|
||||
"tool_name": tool_name,
|
||||
"service_name": service_name,
|
||||
"tool_args": {"name": target_device, **tool_args}
|
||||
@@ -169,24 +203,14 @@ def replace_answer(list_of_answer, var, value):
|
||||
new_list.append(answer.replace(var, value))
|
||||
return new_list
|
||||
|
||||
|
||||
def create_assistant_turn(answer: str, tool_call_sequence=None, *, tool_results=None, train_on_turn: bool = True):
    """Bundle the assistant utterance with any tool interaction for that turn.

    Args:
        answer: Text of the assistant's reply.
        tool_call_sequence: Tool calls issued in this turn. ``None`` (or any
            falsy value) yields a fresh empty list.
        tool_results: Results for the tool calls. Only ``None`` is replaced
            by a fresh empty list.
        train_on_turn: Flag stored on the turn unchanged.

    Returns:
        A dict with the keys ``answer``, ``tool_call_sequence``,
        ``tool_results`` and ``train_on_turn``.
    """
    return {
        "answer": answer,
        "tool_call_sequence": tool_call_sequence or [],
        "tool_results": tool_results if tool_results is not None else [],
        "train_on_turn": train_on_turn,
    }
|
||||
|
||||
def generate_templated_example(template: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
|
||||
def generate_templated_example(template: PileOfTemplatedActionType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example:
|
||||
template_device_types: list[str] = template["device_type"].split("|")
|
||||
service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ]
|
||||
question_template: str = template["phrase"]
|
||||
piles = get_dataset_piles(language)
|
||||
|
||||
# choose a random device for this template
|
||||
chosen_devices = []
|
||||
chosen_devices: list[PileOfDeviceType] = []
|
||||
for device_type in template_device_types:
|
||||
device_dict = random.choice(piles.stacks_of_device_names[device_type])
|
||||
chosen_devices.append(device_dict)
|
||||
@@ -222,7 +246,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
|
||||
))
|
||||
|
||||
# gather a list of all available tools
|
||||
available_tools = []
|
||||
available_tools: list[str] = []
|
||||
for x in set(device_types + template_device_types):
|
||||
available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))
|
||||
|
||||
@@ -263,11 +287,11 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
|
||||
answer_list.append(f" {word} ".join(answers))
|
||||
|
||||
# generate the list of tool calls
|
||||
tool_calls = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
for device_dict, service in zip(chosen_devices, service_names):
|
||||
service_action = service.split(".")[1]
|
||||
tool_name = SERVICE_TO_TOOL_MAP[service_action]
|
||||
tool_call = {
|
||||
tool_call: ToolCall = {
|
||||
"tool_name": tool_name,
|
||||
"service_name": service,
|
||||
"tool_args": {"entity_id" if use_service_names else "name": device_dict["device_name"] if use_service_names else device_dict["description"]}
|
||||
@@ -368,7 +392,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
|
||||
"assistant_turns": assistant_turns
|
||||
}
|
||||
|
||||
def generate_status_request(template: dict, persona: str, language: str, max_devices: int = 128, return_target_device: bool = False, use_service_names: bool = False):
|
||||
def generate_status_request(template: PileOfStatusRequestType, persona: str, language: str, max_devices: int = 128, return_target_device: bool = False, use_service_names: bool = False) -> Example | tuple[Example, PileOfDeviceType]:
|
||||
device_type: str = template["device_type"]
|
||||
state_name: str = template["state"]
|
||||
question_template: str = template["phrase"]
|
||||
@@ -446,7 +470,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
|
||||
))
|
||||
|
||||
# gather a list of all available tools
|
||||
available_tools = []
|
||||
available_tools: list[str] = []
|
||||
for x in set(device_types + [device_type]):
|
||||
available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))
|
||||
|
||||
@@ -455,7 +479,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
|
||||
|
||||
assistant_turns = [create_assistant_turn(answer.lower(), [])]
|
||||
|
||||
result = {
|
||||
result: Example = {
|
||||
"states": device_list,
|
||||
"available_tools": available_tools,
|
||||
"question": question.lower(),
|
||||
@@ -466,7 +490,7 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
|
||||
else:
|
||||
return result
|
||||
|
||||
def generate_tool_failure_example(failure_case: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
|
||||
def generate_tool_failure_example(failure_case: PileOfFailedToolcallType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example:
|
||||
piles = get_dataset_piles(language)
|
||||
service_name = failure_case["service_name"]
|
||||
device_type = service_name.split(".")[0]
|
||||
@@ -490,7 +514,7 @@ def generate_tool_failure_example(failure_case: dict, persona: str, language: st
|
||||
if device_type not in device_types:
|
||||
device_types.append(device_type)
|
||||
|
||||
available_tools = []
|
||||
available_tools: list[str] = []
|
||||
for x in set(device_types):
|
||||
available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))
|
||||
available_tools = list(dict.fromkeys(available_tools))
|
||||
@@ -561,7 +585,7 @@ def generate_tool_failure_example(failure_case: dict, persona: str, language: st
|
||||
"assistant_turns": [first_turn, second_turn, final_turn]
|
||||
}
|
||||
|
||||
def generate_refusal_example(refusal_case: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
|
||||
def generate_refusal_example(refusal_case: PileOfRefusalsType, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False) -> Example:
|
||||
service_name = refusal_case["service_name"]
|
||||
device_type = service_name.split(".")[0]
|
||||
target_device = f"{device_type}.{refusal_case['device_name']}"
|
||||
@@ -582,7 +606,7 @@ def generate_refusal_example(refusal_case: dict, persona: str, language: str, ma
|
||||
if device_type not in device_types:
|
||||
device_types.append(device_type)
|
||||
|
||||
available_tools = []
|
||||
available_tools: list[str] = []
|
||||
for x in set(device_types):
|
||||
available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))
|
||||
available_tools = list(dict.fromkeys(available_tools))
|
||||
@@ -718,15 +742,15 @@ def generate_sft_file(
|
||||
|
||||
print("Generating...")
|
||||
|
||||
def run_factor_times(func: Callable[..., Example], examples: list[Example], data, persona: str, factor: int | float, language: str):
    """Generate and append formatted examples according to a repeat factor.

    For ``factor >= 1`` the generator runs ``int(factor)`` times (a float
    factor is truncated, since ``range`` rejects floats). For ``factor < 1``
    the generator runs once with probability ``factor``.

    Args:
        func: Example generator, called as
            ``func(data, persona, language, use_service_names=...)``.
        examples: List that receives each formatted example (mutated in place).
        data: Input record passed through to ``func``.
        persona: Persona name passed to ``func`` and the formatter.
        factor: Repeat count when >= 1, otherwise a sampling probability.
        language: Language code passed to ``func`` and the formatter.

    Note:
        ``format_func``, ``use_service_names``, ``use_system_role``,
        ``append_user_instruction_prompt`` and ``tool_response_format`` are
        read from the enclosing scope.
    """
    if factor >= 1:
        repeats = int(factor)
    else:
        # The random draw happens only on this branch, as before.
        repeats = 1 if random.random() < factor else 0

    for _ in range(repeats):
        examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format))
|
||||
|
||||
generated_examples = []
|
||||
generated_examples: list[Example] = []
|
||||
|
||||
missing_responses = set()
|
||||
|
||||
@@ -773,7 +797,7 @@ def generate_sft_file(
|
||||
|
||||
def merge_with_dataset(dataset_name, seed, output_name, format_function, dataset_column_names, format_func):
|
||||
alpaca_dataset = load_dataset(dataset_name)["train"].train_test_split(test_size=0.1)
|
||||
home_assistant_dataset = load_dataset("json", data_files={ "train": "home_assistant_train.jsonl", "test": "home_assistant_test.jsonl" })
|
||||
home_assistant_dataset = load_dataset("json", data_files={ "train": "home_assistant_train.jsonl", "test": "home_assistant_test.jsonl" })
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
Reference in New Issue
Block a user