Mirror of https://github.com/acon96/home-llm.git
synced 2026-01-09 21:58:00 -05:00
more synthesizing scenarios + clean up example formatting
@@ -67,6 +67,26 @@ The response pile is a CSV with the following headers: `service,response,languag
 
 Generating the full dataset using the python script will print out a warning for any responses that are missing for a persona.
 
+## Synthesizing new pile data
+
+You can quickly append fresh examples to the CSV piles without editing them manually by running `synthesize.py`. The script talks to the configured LLM and writes the generated rows directly into the per-language pile files.
+
+Examples:
+
+```bash
+# Append 25 failed tool-call recoveries and 25 refusals in Spanish
+python3 synthesize.py --languages spanish --model gpt-oss-120b --failed-tool-calls 25 --refusals 25 --concurrency 6
+
+# Generate new actions plus matching refusal samples in German
+python3 synthesize.py --languages german --actions 100 --refusals 40 --model gpt-oss-120b
+```
+
+Useful flags:
+
+- `--failed-tool-calls`: number of `pile_of_failed_tool_calls.csv` rows to synthesize.
+- `--refusals`: number of `pile_of_refusals.csv` rows to synthesize.
+- `--languages`: one or more languages to generate data for (default: english).
+- `--actions`, `--status`, `--devices`: existing knobs for the other piles.
+
+The script automatically routes generations to the correct language-specific pile under `data/piles/<language>/`.
+
 ## Adding new Home Assistant functionality
 TODO
 <!-- In order to add new home assistant device types, you will need to add data to a handful of piles, as well as make small modifications to the `generate_data.py` script.
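For reference, the appended rows follow fixed column layouts. Below is a minimal sketch (not part of the repo) of writing one `pile_of_failed_tool_calls.csv` row with the same `csv.DictWriter` approach the script's `_append_rows_to_csv` helper uses; the column names come from `failed_tool_fieldnames` in `synthesize.py`, while the row values here are invented for illustration:

```python
import csv
import os

# Column layout of pile_of_failed_tool_calls.csv, as defined in synthesize.py.
FAILED_TOOL_FIELDS = [
    "service_name", "correct_device_name", "correct_friendly_name",
    "bad_device_name", "phrase", "error_result", "retry_prompt",
]

# Hypothetical example row; real rows are produced by the LLM.
row = {
    "service_name": "light.turn_on",
    "correct_device_name": "light.kitchen",
    "correct_friendly_name": "Kitchen Light",
    "bad_device_name": "light.kichen",
    "phrase": "turn on the kitchen light",
    "error_result": "<MatchFailedError ...>",
    "retry_prompt": "Trying again with Kitchen Light.",
}

path = "data/piles/english/pile_of_failed_tool_calls.csv"
write_header = not os.path.exists(path)
with open(path, "a", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=FAILED_TOOL_FIELDS)
    if write_header:
        writer.writeheader()  # only on first creation, like the script does
    writer.writerow(row)
```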
devices.py:
@@ -1,6 +1,7 @@
 import random
 from dataclasses import dataclass
 from typing import Final, Callable, List
+from difflib import SequenceMatcher
 
 from tools import *
 from utils import closest_color, generate_random_parameter, get_dataset_piles
@@ -31,6 +32,9 @@ STATE_CLEANING: Final = "cleaning"
 STATE_DOCKED: Final = "docked"
 STATE_RETURNING: Final = "returning"
 
+def format_device_line(*, device_name: str, friendly_name: str, state: str):
+    return (f"{device_name} '{friendly_name}' = {state}")
+
 @dataclass
 class DeviceType:
     name: str
@@ -222,3 +226,79 @@ class MediaPlayerDeviceType(DeviceType):
         if "volume_level" in extra_exposed_attributes:
             tools.append(TOOL_SET_VOLUME)
         return tools
+
+
+SUPPORTED_DEVICES = {
+    "light": LightDeviceType(),
+    "switch": SwitchDeviceType(),
+    "fan": FanDeviceType(),
+    "garage_door": GarageDoorDeviceType(),
+    "blinds": BlindsDeviceType(),
+    "lock": LockDeviceType(),
+    "media_player": MediaPlayerDeviceType(),
+    "climate": ClimateDeviceType(),
+    "vacuum": VacuumDeviceType(),
+    "timer": TimerDeviceType(),
+    "todo": TodoDeviceType(),
+}
+
+# generate a random list of devices for the context
+def random_device_list(max_devices: int, avoid_device_names: list[str], language: str = "english"):
+    num_devices = random.randint(2, max_devices)
+    piles = get_dataset_piles(language)
+
+    local_device_names = { k: v[:] for k,v in piles.stacks_of_device_names.items() }
+
+    avoid_climate = False
+    for avoid_device in avoid_device_names:
+        avoid_type = avoid_device.split(".")[0]
+
+        filtered_possible_devices = []
+        for possible_device in local_device_names[avoid_type]:
+            similarity_ratio = SequenceMatcher(None, avoid_device, possible_device["device_name"].split(".")[1]).ratio()
+
+            if similarity_ratio < 0.4:
+                filtered_possible_devices.append(possible_device)
+        local_device_names[avoid_type] = filtered_possible_devices
+
+        if avoid_type == "climate":
+            avoid_climate = True
+
+    possible_choices = []
+    for device_type in local_device_names.keys():
+        possible_choices.extend(local_device_names[device_type])
+
+    device_types = set()
+    device_list = []
+    device_lines = []
+    # TODO: randomly pick attributes for this list
+    extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "duration", "remaining", "item"]
+
+    while len(device_list) < num_devices:
+        choice = random.choice(possible_choices)
+        if choice["device_name"] in device_list:
+            continue
+
+        try:
+            device_name = choice["device_name"]
+            device_type = device_name.split(".")[0]
+            friendly_name = choice["description"]
+
+            # don't add random thermostats. we need to be careful about how we handle multiple thermostats
+            if avoid_climate and device_type == "climate":
+                continue
+
+            state = SUPPORTED_DEVICES[device_type].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
+            device_lines.append(format_device_line(
+                device_name=device_name,
+                friendly_name=friendly_name,
+                state=state
+            ))
+            device_list.append(device_name)
+            device_types.add(device_type)
+        except Exception as ex:
+            print(f"bad device name: {choice}")
+            print(repr(ex))
+
+    return device_lines, list(device_types), list(extra_exposed_attributes)
generate_data.py:
@@ -3,7 +3,6 @@ import json
 import numpy as np
 import random
 from datasets import load_dataset, concatenate_datasets
-from difflib import SequenceMatcher
 from typing import Callable
 from tqdm import tqdm
 import webcolors
@@ -13,87 +12,13 @@ import os
 import sys
 sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 
-from device_types import *
+from devices import SUPPORTED_DEVICES, format_device_line, random_device_list, \
+    TOOL_TURN_ON, TOOL_CLIMATE_SET_TEMPERATURE, TOOL_SET_HUMIDITY, \
+    TOOL_LIGHT_SET, TOOL_START_TIMER, TOOL_LIST_ADD_ITEM, SERVICE_TO_TOOL_MAP, \
+    HASS_TOOLS, SERVICE_TOOLS
 from prompting import generate_system_prompt, USER_INSTRUCTION_PROMPT
-from utils import get_random_response, generate_random_parameter, closest_color, get_dataset_piles, NoResponseAvailableException
-
-SUPPORTED_DEVICES = {
-    "light": LightDeviceType(),
-    "switch": SwitchDeviceType(),
-    "fan": FanDeviceType(),
-    "garage_door": GarageDoorDeviceType(),
-    "blinds": BlindsDeviceType(),
-    "lock": LockDeviceType(),
-    "media_player": MediaPlayerDeviceType(),
-    "climate": ClimateDeviceType(),
-    "vacuum": VacuumDeviceType(),
-    "timer": TimerDeviceType(),
-    "todo": TodoDeviceType(),
-}
-
-def format_device_line(*, device_name: str, friendly_name: str, state: str):
-    return (f"{device_name} '{friendly_name}' = {state}")
-
-# generate a random list of devices for the context
-def random_device_list(max_devices: int, avoid_device_names: list[str], language: str = "english"):
-    num_devices = random.randint(2, max_devices)
-    piles = get_dataset_piles(language)
-
-    local_device_names = { k: v[:] for k,v in piles.stacks_of_device_names.items() }
-
-    avoid_climate = False
-    for avoid_device in avoid_device_names:
-        avoid_type = avoid_device.split(".")[0]
-
-        filtered_possible_devices = []
-        for possible_device in local_device_names[avoid_type]:
-            similarity_ratio = SequenceMatcher(None, avoid_device, possible_device["device_name"].split(".")[1]).ratio()
-
-            if similarity_ratio < 0.4:
-                filtered_possible_devices.append(possible_device)
-        local_device_names[avoid_type] = filtered_possible_devices
-
-        if avoid_type == "climate":
-            avoid_climate = True
-
-    possible_choices = []
-    for device_type in local_device_names.keys():
-        possible_choices.extend(local_device_names[device_type])
-
-    device_types = set()
-    device_list = []
-    device_lines = []
-    # TODO: randomly pick attributes for this list
-    extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "duration", "remaining", "item"]
-
-    while len(device_list) < num_devices:
-        choice = random.choice(possible_choices)
-        if choice["device_name"] in device_list:
-            continue
-
-        try:
-            device_name = choice["device_name"]
-            device_type = device_name.split(".")[0]
-            friendly_name = choice["description"]
-
-            # don't add random thermostats. we need to be careful about how we handle multiple thermostats
-            if avoid_climate and device_type == "climate":
-                continue
-
-            state = SUPPORTED_DEVICES[device_type].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
-            device_lines.append(format_device_line(
-                device_name=device_name,
-                friendly_name=friendly_name,
-                state=state
-            ))
-            device_list.append(device_name)
-            device_types.add(device_type)
-        except Exception as ex:
-            print(f"bad device name: {choice}")
-            print(repr(ex))
-
-    return device_lines, list(device_types), list(extra_exposed_attributes)
+from utils import get_random_response, generate_random_parameter, closest_color, \
+    get_dataset_piles, NoResponseAvailableException
 
 def generate_static_example(action: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
     question = action["phrase"]
@@ -126,7 +51,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
 
     # Map service name to tool name
     service_action = service_name.split(".")[1]
-    tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
+    tool_name = SERVICE_TO_TOOL_MAP[service_action]
 
     response_starting, response_confirmed = get_random_response(
         piles.pile_of_responses,
@@ -225,13 +150,17 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
     except Exception as e:
         print(f"Failed to parse arguments for {action}: {e}")
 
+    final_answer = " ".join(answer_list)
+    assistant_turns = [
+        create_assistant_turn(response_starting, [tool_call]),
+        create_assistant_turn(final_answer, [])
+    ]
+
     return {
         "states": device_list,
         "available_tools": available_tools,
         "question": question.lower(),
-        "answers": answer_list,
-        "answer_starting": response_starting,
-        "tool_calls": [ tool_call ]
+        "assistant_turns": assistant_turns
     }
 
 def replace_answer(list_of_answer, var, value):
@@ -240,6 +169,16 @@ def replace_answer(list_of_answer, var, value):
         new_list.append(answer.replace(var, value))
     return new_list
 
+
+def create_assistant_turn(answer: str, tool_call_sequence=None, *, tool_results=None, train_on_turn: bool = True):
+    """Bundle the assistant utterance with any tool interaction for that turn."""
+    return {
+        "answer": answer,
+        "tool_call_sequence": tool_call_sequence or [],
+        "tool_results": tool_results if tool_results is not None else [],
+        "train_on_turn": train_on_turn,
+    }
+
 def generate_templated_example(template: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
     template_device_types: list[str] = template["device_type"].split("|")
     service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ]
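For orientation, every generator below now returns a list of these turn dicts under `"assistant_turns"`. A hypothetical two-turn example of the shape the formatter consumes (tool and device names are illustrative, not taken from the dataset):

```python
# Hypothetical shape of example["assistant_turns"] after this change:
assistant_turns = [
    {   # first turn: speak, then call a tool
        "answer": "turning on the kitchen light now.",
        "tool_call_sequence": [{
            "tool_name": "HassTurnOn",            # assumed tool name, for illustration only
            "service_name": "light.turn_on",
            "tool_args": {"name": "light.kitchen"},
        }],
        "tool_results": [],                        # empty -> formatter synthesizes "Success" results
        "train_on_turn": True,
    },
    {   # final turn: confirmation only, no tools
        "answer": "the kitchen light is on.",
        "tool_call_sequence": [],
        "tool_results": [],
        "train_on_turn": True,
    },
]
```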
@@ -328,7 +267,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
     tool_calls = []
     for device_dict, service in zip(chosen_devices, service_names):
         service_action = service.split(".")[1]
-        tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
+        tool_name = SERVICE_TO_TOOL_MAP[service_action]
         tool_call = {
             "tool_name": tool_name,
             "service_name": service,
@@ -415,13 +354,19 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
         if call["tool_name"] == TOOL_LIST_ADD_ITEM:
             call["tool_args"]["item"] = todo
 
+    starting_answer = answer_starting.strip().lower()
+    normalized_answers = [ sentence.lower() for sentence in answer_list ]
+    final_answer = " ".join(normalized_answers)
+    assistant_turns = [
+        create_assistant_turn(starting_answer, tool_calls),
+        create_assistant_turn(final_answer, [])
+    ]
+
     return {
         "states": device_list,
         "available_tools": available_tools,
         "question": question.lower(),
-        "answer_starting": answer_starting.lower(),
-        "answers": [ sentence.lower() for sentence in answer_list ],
-        "tool_calls": tool_calls
+        "assistant_turns": assistant_turns
     }
 
 def generate_status_request(template: dict, persona: str, language: str, max_devices: int = 128, return_target_device: bool = False, use_service_names: bool = False):
@@ -509,12 +454,13 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
     # Remove duplicates while preserving order
     available_tools = list(dict.fromkeys(available_tools))
 
+    assistant_turns = [create_assistant_turn(answer.lower(), [])]
+
     result = {
         "states": device_list,
         "available_tools": available_tools,
         "question": question.lower(),
-        "answers": [ answer.lower() ],
-        "tool_calls": []
+        "assistant_turns": assistant_turns
     }
     if return_target_device:
         return result, chosen_device
@@ -578,42 +524,42 @@ def generate_tool_failure_example(failure_case: dict, persona: str, language: st
     retry_prompt = failure_case.get("retry_prompt", f"Trying again with {friendly_name}.").replace("<device_name>", friendly_name)
     error_result = failure_case.get("error_result", "Error").replace("<device_name>", friendly_name)
 
-    tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
+    tool_name = SERVICE_TO_TOOL_MAP[service_action]
     first_args = {"entity_id": bad_device} if use_service_names else {"name": bad_device}
     retry_args = {"entity_id": target_device} if use_service_names else {"name": target_device}
    first_args.update(tool_args_extra)
    retry_args.update(tool_args_extra)
 
-    tool_call_sequence = [
-        {
-            "answer_starting": response_starting,
-            "tool_calls": [{
-                "tool_name": tool_name,
-                "service_name": service_name,
-                "tool_args": first_args
-            }],
-            "tool_results": [{
-                "tool_name": service_name if use_service_names else tool_name,
-                "tool_result": error_result
-            }]
-        },
-        {
-            "answer_starting": retry_prompt,
-            "tool_calls": [{
-                "tool_name": tool_name,
-                "service_name": service_name,
-                "tool_args": retry_args
-            }]
-        }
-    ]
+    first_turn = create_assistant_turn(
+        response_starting,
+        [{
+            "tool_name": tool_name,
+            "service_name": service_name,
+            "tool_args": first_args
+        }],
+        tool_results=[{
+            "tool_name": service_name if use_service_names else tool_name,
+            "tool_result": error_result
+        }],
+        train_on_turn=False
+    )
+
+    second_turn = create_assistant_turn(
+        retry_prompt,
+        [{
+            "tool_name": tool_name,
+            "service_name": service_name,
+            "tool_args": retry_args
+        }]
+    )
+
+    final_turn = create_assistant_turn(response_confirmed, [])
 
     return {
         "states": device_list,
         "available_tools": available_tools,
         "question": question,
-        "answers": [response_confirmed],
-        "tool_call_sequence": tool_call_sequence,
-        "tool_calls": []
+        "assistant_turns": [first_turn, second_turn, final_turn]
     }
@@ -645,41 +591,20 @@ def generate_refusal_example(refusal_case: dict, persona: str, language: str, ma
     response_text = refusal_case["response"].replace("<device_name>", friendly_name).lower()
     question = refusal_case["phrase"].replace("<device_name>", friendly_name).lower()
 
+    assistant_turns = [create_assistant_turn(response_text, [])]
+
     return {
         "states": device_list,
         "available_tools": available_tools,
         "question": question,
-        "answers": [response_text],
-        "tool_calls": []
+        "assistant_turns": assistant_turns
     }
 
 def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format):
     piles = get_dataset_piles(language)
     sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
     question = example["question"]
-    answers = " ".join(example["answers"])
-    answer_starting = example.get("answer_starting", "")
-
-    tool_call_sequence = example.get("tool_call_sequence")
-
-    tool_calls = []
-    tool_results = []
-    if not tool_call_sequence:
-        # Add tool use blocks if there are tool calls
-        if len(example["tool_calls"]) > 0:
-            for tool_call in example["tool_calls"]:
-                # Use service_name if in service mode, otherwise use tool_name
-                call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
-                tool_calls.append({
-                    "name": call_name,
-                    "arguments": json.dumps(tool_call["tool_args"])
-                })
-
-                tool_results.append({
-                    "tool_name": call_name,
-                    "tool_call_id": f"call_{len(tool_results) + 1}",
-                    "tool_result": "Success"
-                })
+    assistant_turns = example["assistant_turns"]
 
     if append_user_instruction_prompt:
         user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
@@ -703,95 +628,64 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
             "content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
         }
     ]
 
-    if tool_call_sequence:
-        call_id_counter = 1
-        for step in tool_call_sequence:
-            step_tool_calls = []
-            for tool_call in step.get("tool_calls", []):
-                call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
-                step_tool_calls.append({
-                    "name": call_name,
-                    "arguments": json.dumps(tool_call["tool_args"])
-                })
-
-            assistant_block = {
-                "role": "assistant",
-                "content": [{ "type": "text", "text": step.get("answer_starting", "") }]
-            }
-            if step_tool_calls:
-                assistant_block["tool_calls"] = [ { "function": tc } for tc in step_tool_calls ]
-            conversation.append(assistant_block)
-
-            if step_tool_calls:
-                provided_results = step.get("tool_results")
-                step_tool_results = []
-                if provided_results:
-                    for provided in provided_results:
-                        step_tool_results.append({
-                            "tool_name": provided.get("tool_name", step_tool_calls[0]["name"]),
-                            "tool_call_id": provided.get("tool_call_id", f"call_{call_id_counter}"),
-                            "tool_result": provided.get("tool_result", "Success")
-                        })
-                        call_id_counter += 1
-                else:
-                    for call in step_tool_calls:
-                        step_tool_results.append({
-                            "tool_name": call["name"],
-                            "tool_call_id": f"call_{call_id_counter}",
-                            "tool_result": "Success"
-                        })
-                        call_id_counter += 1
-
-                if tool_response_format == "text":
-                    conversation.append({
-                        "role": "tool",
-                        "content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results]
-                    })
-                elif tool_response_format == "functiongemma":
-                    conversation.append({
-                        "role": "tool",
-                        "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results]
-                    })
-
-            if step.get("post_tool_response"):
-                conversation.append({
-                    "role": "assistant",
-                    "content": [{ "type": "text", "text": step["post_tool_response"] }]
-                })
-
-        conversation.append({
-            "role": "assistant",
-            "content": [{ "type": "text", "text": answers }],
-        })
-    elif len(tool_calls) > 0:
-        assistant_starting_block = {
-            "role": "assistant",
-            "content": [{ "type": "text", "text": answer_starting }],
-            "tool_calls": [ { "function": tc } for tc in tool_calls ]
-        }
-        if tool_response_format == "text":
-            tool_response_block = {
-                "role": "tool",
-                "content": [{ "type": "text", "text": json.dumps(result) } for result in tool_results]
-            }
-        elif tool_response_format == "functiongemma":
-            tool_response_block = {
-                "role": "tool",
-                "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in tool_results]
-            }
-        assistant_confirmation_block = {
-            "role": "assistant",
-            "content": [{ "type": "text", "text": answers }],
-        }
-        conversation.extend([assistant_starting_block, tool_response_block, assistant_confirmation_block])
-    else:
-        conversation.extend([
-            {
-                "role": "assistant",
-                "content": [{ "type": "text", "text": answer_starting + answers }],
-            }
-        ])
+    call_id_counter = 1
+    for turn in assistant_turns:
+        answer_text = turn.get("answer", "")
+        assistant_block = {
+            "role": "assistant",
+            "content": [{ "type": "text", "text": answer_text }],
+            "train_on_turn": turn.get("train_on_turn", True),
+        }
+        tool_call_sequence = turn.get("tool_call_sequence", [])
+        formatted_calls = []
+        call_names = []
+        for tool_call in tool_call_sequence:
+            call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
+            call_names.append(call_name)
+            formatted_calls.append({
+                "name": call_name,
+                "arguments": json.dumps(tool_call["tool_args"])
+            })
+
+        if formatted_calls:
+            assistant_block["tool_calls"] = [{ "function": call } for call in formatted_calls]
+
+        conversation.append(assistant_block)
+
+        if formatted_calls:
+            provided_results = turn.get("tool_results") or []
+            step_tool_results = []
+
+            if provided_results:
+                for idx, provided in enumerate(provided_results):
+                    result = dict(provided)
+                    if "tool_name" not in result and call_names:
+                        result["tool_name"] = call_names[min(idx, len(call_names) - 1)]
+                    if "tool_call_id" not in result:
+                        result["tool_call_id"] = f"call_{call_id_counter}"
+                        call_id_counter += 1
+                    step_tool_results.append(result)
+            else:
+                for call_name in call_names:
+                    step_tool_results.append({
+                        "tool_name": call_name,
+                        "tool_call_id": f"call_{call_id_counter}",
+                        "tool_result": "Success"
+                    })
+                    call_id_counter += 1
+
+            if tool_response_format == "text":
+                conversation.append({
+                    "role": "tool",
+                    "content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results]
+                })
+            elif tool_response_format == "functiongemma":
+                conversation.append({
+                    "role": "tool",
+                    "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results]
+                })
 
     return {
         "messages": conversation,
@@ -812,8 +706,8 @@ def generate_sft_file(
         static_factor: float,
         template_factor: int,
         status_request_factor: int,
-        failure_factor: float = 1,
-        refusal_factor: float = 1):
+        failure_factor: int,
+        refusal_factor: int):
     random.seed(seed)
     np.random.seed(seed)
     piles = get_dataset_piles(language)
@@ -929,7 +823,7 @@ def main(args=None):
 
     args = parser.parse_args(args=args)
 
-    if not args.sample and not args.train and not args.test and not args.merge:
+    if not args.sample and not args.train and not args.test:
         parser.print_usage()
         exit(-1)
 
@@ -950,20 +844,20 @@ def main(args=None):
         suffix = f"_{language}" if len(args.language) > 1 else ""
 
         if args.sample:
-            generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1)
+            generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1, refusal_factor=1, failure_factor=1)
         if args.train:
             if args.size == "small":
-                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8)
+                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8, refusal_factor=3, failure_factor=1)
             elif args.size == "medium":
-                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12)
+                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12, refusal_factor=5, failure_factor=1)
             elif args.size == "large":
-                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15)
+                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15, refusal_factor=6, failure_factor=1)
             elif args.size == "xl":
-                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18)
+                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18, refusal_factor=8, failure_factor=2)
             else:
                 raise Exception(f"Unrecognized dataset size: {args.size}")
         if args.test:
-            generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2)
+            generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2, refusal_factor=1, failure_factor=1)
     if len(args.language) > 1:
         if args.sample:
             merge_languages("sample", args.language)
synthesize.py:
@@ -3,13 +3,19 @@ import asyncio
 import csv
 import json
 import random
+import string
 import aiohttp
 from tqdm import tqdm
 import os
 
 from utils import get_dataset_piles
+from devices import random_device_list
 
 LLM_ENDPOINT = "https://ai.cloud.alexoconnell.net/v1/chat/completions"
+cwd = os.path.dirname(os.path.abspath(__file__))
+
+def get_hass_match_error_message(bad_device_name: str) -> str:
+    return f"<MatchFailedError result=MatchTargetsResult(is_match=False, no_match_reason=<MatchFailedReason.NAME: 1>, states=[], no_match_name=None, areas=[], floors=[]), constraints=MatchTargetsConstraints(name='{bad_device_name}', area_name=None, floor_name=None, domains=None, device_classes=None, features=None, states=None, assistant='conversation', allow_duplicate_names=False, single_target=False), preferences=MatchTargetsPreferences(area_id=None, floor_id=None)>"
+
 class SyntheticDataGenerator:
     def __init__(self, model_name: str, language: str, concurrency: int):
@@ -18,6 +24,130 @@ class SyntheticDataGenerator:
         self.model_name = model_name
         self.piles = get_dataset_piles(language)
         self.synthetic_devices = {} # device_type -> list of {device_name, description}
+        self.failed_tool_fieldnames = [
+            "service_name",
+            "correct_device_name",
+            "correct_friendly_name",
+            "bad_device_name",
+            "phrase",
+            "error_result",
+            "retry_prompt"
+        ]
+        self.refusal_fieldnames = [
+            "reason_type",
+            "service_name",
+            "device_name",
+            "friendly_name",
+            "desired_state",
+            "phrase",
+            "response"
+        ]
+
+    async def _chat_completion(self, session, system_prompt: str, user_prompt: str, *, temperature: float = 0.8, max_tokens: int = 300, structured_response: dict = {}):
+        payload = {
+            "model": self.model_name,
+            "messages": [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt}
+            ],
+            "temperature": temperature,
+            "max_tokens": max_tokens if not structured_response else None, # don't limit if structured (causes failed generations)
+            "response_format": structured_response if structured_response else None,
+        }
+
+        try:
+            async with session.post(LLM_ENDPOINT, json=payload) as response:
+                if response.status < 400:
+                    data = await response.json()
+                    return data['choices'][0]['message']['content'].strip()
+                print(f"LLM request failed with status {response.status}")
+        except Exception as exc:
+            print(f"Completion request failed: {exc}")
+        return None
+
+    @staticmethod
+    def _strip_code_fence(content: str) -> str:
+        if not content:
+            return ""
+        text = content.strip()
+        if text.startswith("```") and text.endswith("```"):
+            lines = [line for line in text.splitlines() if not line.strip().startswith("```")]
+            return "\n".join(lines).strip()
+        return text
+
+    def _parse_json_object(self, content: str):
+        if not content:
+            return None
+        cleaned = self._strip_code_fence(content)
+        start = cleaned.find("{")
+        end = cleaned.rfind("}")
+        if start == -1 or end == -1:
+            return None
+        snippet = cleaned[start:end+1]
+        try:
+            return json.loads(snippet)
+        except json.JSONDecodeError:
+            return None
+
+    @staticmethod
+    def _ensure_placeholder(text: str, friendly_name: str, placeholder: str) -> str:
+        if placeholder in text:
+            return text
+        lower_text = text.lower()
+        lower_name = friendly_name.lower()
+        idx = lower_text.find(lower_name)
+        if idx == -1:
+            return text
+        return text[:idx] + placeholder + text[idx + len(friendly_name):]
+
+    @staticmethod
+    def _describe_service(service_name: str, service_data: dict) -> str:
+        domain, action = service_name.split(".", 1)
+        description = f"{action.replace('_', ' ')} the {domain}"
+        if service_data:
+            description += f" with arguments {json.dumps(service_data, ensure_ascii=False)}"
+        return description
+
+    @staticmethod
+    def _desired_state_for_service(service_name: str) -> str:
+        action = service_name.split(".", 1)[1]
+        action_mapping = {
+            "turn_on": "on",
+            "turn_off": "off",
+            "open_cover": "open",
+            "close_cover": "closed",
+            "lock": "locked",
+            "unlock": "unlocked",
+            "media_play_pause": "playing",
+        }
+        if action in action_mapping:
+            return action_mapping[action]
+        service_mapping = {
+            "media_player.turn_on": "on",
+            "media_player.turn_off": "off",
+            "vacuum.start": "cleaning",
+            "vacuum.return_to_base": "docked",
+            "garage_door.open_cover": "open",
+            "garage_door.close_cover": "closed",
+            "blinds.open_cover": "open",
+            "blinds.close_cover": "closed",
+        }
+        return service_mapping.get(service_name, "")
+
+    @staticmethod
+    def _append_rows_to_csv(path: str, fieldnames, rows):
+        if not rows:
+            return
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        file_exists = os.path.exists(path)
+        if not file_exists:
+            with open(path, "w", newline='', encoding='utf-8') as f:
+                writer = csv.DictWriter(f, fieldnames=fieldnames)
+                writer.writeheader()
+        with open(path, "a", newline='', encoding='utf-8') as f:
+            writer = csv.DictWriter(f, fieldnames=fieldnames)
+            for row in rows:
+                writer.writerow(row)
+
     async def generate_device_names(self, session, device_type, count=10):
         """
@@ -144,6 +274,117 @@ class SyntheticDataGenerator:
             print(f"Request failed: {e}")
             return None
 
+    async def generate_failed_tool_call_entry(self, session, context):
+        service_name = context["service_name"]
+        friendly_name = context["friendly_name"]
+        correct_device = context["device_name"]
+        action_description = self._describe_service(service_name, context.get("service_data", {}))
+        device_list, device_types, extra_exposed_attributes = random_device_list(max_devices=10, avoid_device_names=[], language=self.language)
+
+        extra_phrase_instructions = ""
+        if service_name == "climate.set_temperature":
+            extra_phrase_instructions = "Always reference the target temperature using the literal placeholder <temp_f>."
+
+        system_prompt = "You create high-quality multilingual training data for a smart home assistant."
+        user_prompt = f"""
+Generate content for a failed tool call in {self.language}.
+Service: {service_name}
+Action description: {action_description}
+Device name: {correct_device} (friendly name: {friendly_name})
+The assistant first misidentifies the entity using the "bad_device" name, receives an error, then retries with {correct_device}.
+Output a single JSON object with the keys "phrase", "bad_device", "error_result", and "retry_prompt".
+- phrase: A natural voice command template the user would say. {extra_phrase_instructions}
+- bad_device: The incorrect version of '{correct_device}'. Ensure it is a truncation, mutation, transposition, or inclusion of an extra suffix/prefix. Avoid simple typos in words.
+- retry_prompt: A short acknowledgement from the assistant that it will try the correct device ({correct_device}).
+
+Here are potential device names to help generate a bad device name: {', '.join(device_list)}
+Keep the tone conversational and stay entirely in {self.language}. Do not add explanations or code fences.
+"""
+
+        raw = await self._chat_completion(
+            session,
+            system_prompt,
+            user_prompt,
+            temperature=0.7,
+            max_tokens=512,
+            structured_response={ "type": "json_object" }
+        )
+        parsed = self._parse_json_object(raw)
+        if not parsed:
+            print(f"Failed to parse JSON for failed tool call generation: {raw}")
+            return None
+
+        phrase = parsed.get("phrase", "").strip()
+        bad_device = parsed.get("bad_device", "").strip()
+        retry_prompt = parsed.get("retry_prompt", "").strip()
+        if not phrase or not bad_device or not retry_prompt:
+            return None
+
+        return {
+            "service_name": service_name,
+            "correct_device_name": correct_device,
+            "correct_friendly_name": friendly_name,
+            "bad_device_name": bad_device,
+            "phrase": phrase,
+            "error_result": get_hass_match_error_message(bad_device),
+            "retry_prompt": retry_prompt
+        }
+
+    async def generate_refusal_entry(self, session, context, reason_type: str, desired_state: str):
+        service_name = context["service_name"]
+        friendly_name = context["friendly_name"]
+        action_description = self._describe_service(service_name, context.get("service_data", {}))
+        device_suffix = context["device_name"].split(".", 1)[1]
+
+        if reason_type == "not_available":
+            reason_detail = "Explain that the assistant cannot locate that device in the home."
+        else:
+            reason_detail = f"Explain that the device is already {desired_state} and no change is needed."
+
+        system_prompt = "You write concise refusal-style responses for a multilingual smart home assistant."
+        user_prompt = f"""
+Create a refusal example in {self.language} for the following request.
+Service: {service_name}
+Action description: {action_description}
+Reason: {reason_detail}
+Output a JSON object with:
+- "phrase": the user's natural command template. Use the literal placeholder <device_name> anywhere the user mentions the device.
+- "response": the assistant's refusal message in {self.language} describing the reason.
+Keep both fields brief, conversational, and free of extra narration or code fences.
+"""
+
+        raw = await self._chat_completion(
+            session,
+            system_prompt,
+            user_prompt,
+            temperature=0.7,
+            max_tokens=512,
+            structured_response={ "type": "json_object" }
+        )
+        parsed = self._parse_json_object(raw)
+        if not parsed:
+            print("Failed to parse JSON for refusal generation")
+            return None
+
+        phrase = parsed.get("phrase", "").strip()
+        response = parsed.get("response", "").strip()
+        if not phrase or not response:
+            return None
+
+        phrase = self._ensure_placeholder(phrase, friendly_name, "<device_name>")
+        if "<device_name>" not in phrase:
+            return None
+
+        return {
+            "reason_type": reason_type,
+            "service_name": service_name,
+            "device_name": device_suffix,
+            "friendly_name": friendly_name,
+            "desired_state": desired_state if reason_type == "already_state" else "",
+            "phrase": phrase,
+            "response": response
+        }
+
     def sample_context(self, request_type: str):
         """
         Creates a random scenario: device, service, and arguments.
@@ -245,14 +486,12 @@ class SyntheticDataGenerator:
         }
         raise ValueError(f"Unknown request type {request_type}")
 
-    async def run(self, num_actions: int, num_status_requests: int, num_devices: int, output_file, persona_name=None, persona_description=None):
+    async def run(self, num_actions: int, num_status_requests: int, num_devices: int,
+                  persona_name=None, persona_description=None, num_failed_tool_calls: int = 0,
+                  num_refusals: int = 0):
         print(f"Starting generation...")
         print(f"Language: {self.language}")
 
-        # Ensure output directory exists
-        if output_file:
-            os.makedirs(os.path.dirname(output_file), exist_ok=True)
-
         async with aiohttp.ClientSession() as session:
 
             if num_devices > 0:
@@ -268,7 +507,7 @@ class SyntheticDataGenerator:
                 all_new_devices = [item for sublist in generated_lists if sublist for item in sublist]
 
                 if all_new_devices:
-                    csv_path = f"data/piles/{self.language}/pile_of_device_names.csv"
+                    csv_path = f"{cwd}/piles/{self.language}/pile_of_device_names.csv"
                     try:
                         with open(csv_path, "a", newline='', encoding='utf-8') as f:
                             writer = csv.DictWriter(f, fieldnames=["device_name", "description"])
@@ -280,7 +519,6 @@ class SyntheticDataGenerator:
 
             if num_actions > 0 or num_status_requests > 0:
                 print(f"Generating {num_actions} actions and {num_status_requests} status requests...")
-                print(f"Output file: {output_file}")
                 tasks = {}
                 results = []
 
@@ -312,7 +550,7 @@ class SyntheticDataGenerator:
 
                         if entry["type"] == "action":
                             # Write to pile_of_specific_actions.csv
-                            csv_path = f"data/piles/{self.language}/pile_of_specific_actions.csv"
+                            csv_path = f"{cwd}/piles/{self.language}/pile_of_specific_actions.csv"
 
                             # Prepare row
                             # device_name in CSV is the suffix (e.g. 'kitchen' from 'light.kitchen')
@@ -383,7 +621,7 @@ class SyntheticDataGenerator:
                                 # For now, just skip if we can't templatize.
                                 pass
                         else:
-                            csv_path = f"data/piles/{self.language}/pile_of_status_requests.csv"
+                            csv_path = f"{cwd}/piles/{self.language}/pile_of_status_requests.csv"
                             # Columns: device_type,state,phrase,assistant_response
                             # We don't have assistant_response.
                             # We can generate a generic one?
@@ -398,6 +636,85 @@ class SyntheticDataGenerator:
 
                 pbar.close()
 
+            if num_failed_tool_calls > 0:
+                print(f"Generating {num_failed_tool_calls} failed tool call scenarios...")
+                failed_entries = []
+                tasks = {}
+                pbar_failed = tqdm(total=num_failed_tool_calls, desc="Failed tool calls")
+
+                while len(failed_entries) < num_failed_tool_calls:
+                    while len(tasks) < self.concurrency and (len(failed_entries) + len(tasks)) < num_failed_tool_calls:
+                        context = self.sample_context("action")
+                        if not context:
+                            continue
+                        task = asyncio.create_task(self.generate_failed_tool_call_entry(session, context))
+                        tasks[task] = None
+
+                    if not tasks:
+                        break
+
+                    done, _ = await asyncio.wait(list(tasks.keys()), return_when=asyncio.FIRST_COMPLETED)
+
+                    for task in done:
+                        tasks.pop(task, None)
+                        try:
+                            entry = await task
+                        except Exception as exc:
+                            print(f"Failed tool call task error: {exc}")
+                            entry = None
+                        if entry:
+                            failed_entries.append(entry)
+                            pbar_failed.update(1)
+
+                pbar_failed.close()
+                failed_path = f"{cwd}/piles/{self.language}/pile_of_failed_tool_calls.csv"
+                self._append_rows_to_csv(failed_path, self.failed_tool_fieldnames, failed_entries)
+
+            if num_refusals > 0:
+                print(f"Generating {num_refusals} refusal scenarios...")
+                refusal_entries = []
+                tasks = {}
+                pbar_refusals = tqdm(total=num_refusals, desc="Refusals")
+
+                while len(refusal_entries) < num_refusals:
+                    while len(tasks) < self.concurrency and (len(refusal_entries) + len(tasks)) < num_refusals:
+                        context = self.sample_context("action")
+                        if not context:
+                            continue
+                        reason_type = random.choice(["not_available", "already_state"])
+                        desired_state = ""
+                        if reason_type == "already_state":
+                            desired_state = self._desired_state_for_service(context["service_name"])
+                            if not desired_state:
+                                reason_type = "not_available"
+                        task = asyncio.create_task(self.generate_refusal_entry(
+                            session,
+                            context,
+                            reason_type,
+                            desired_state
+                        ))
+                        tasks[task] = None
+
+                    if not tasks:
+                        break
+
+                    done, _ = await asyncio.wait(list(tasks.keys()), return_when=asyncio.FIRST_COMPLETED)
+
+                    for task in done:
+                        tasks.pop(task, None)
+                        try:
+                            entry = await task
+                        except Exception as exc:
+                            print(f"Refusal generation error: {exc}")
+                            entry = None
+                        if entry:
+                            refusal_entries.append(entry)
+                            pbar_refusals.update(1)
+
+                pbar_refusals.close()
+                refusal_path = f"{cwd}/piles/{self.language}/pile_of_refusals.csv"
+                self._append_rows_to_csv(refusal_path, self.refusal_fieldnames, refusal_entries)
+
             if persona_name and persona_description:
                 await self.generate_persona(session, persona_name, persona_description)
 
@@ -442,7 +759,7 @@ class SyntheticDataGenerator:
             return
 
        # 2. Get list of services to generate responses for
-        responses_csv_path = f"data/piles/{self.language}/pile_of_responses.csv"
+        responses_csv_path = f"{cwd}/piles/{self.language}/pile_of_responses.csv"
         services = set()
         try:
             with open(responses_csv_path, "r", encoding='utf-8') as f:
@@ -527,7 +844,7 @@ class SyntheticDataGenerator:
 
         # 4. Write to files
         # Append system prompt
-        sys_prompts_path = f"data/piles/{self.language}/pile_of_system_prompts.csv"
+        sys_prompts_path = f"{cwd}/piles/{self.language}/pile_of_system_prompts.csv"
         try:
             with open(sys_prompts_path, "a", newline='', encoding='utf-8') as f:
                 writer = csv.writer(f)
@@ -555,13 +872,25 @@ if __name__ == "__main__":
     parser.add_argument("--actions", type=int, default=0, help="Number of actions to generate")
     parser.add_argument("--status", type=int, default=0, help="Number of status requests to generate")
     parser.add_argument("--devices", type=int, default=0, help="Number of new devices to generate")
-    parser.add_argument("--concurrency", type=int, default=8, help="Number of concurrent requests")
-    parser.add_argument("--language", type=str, default="english", help="Language")
+    parser.add_argument("--failed-tool-calls", type=int, default=0, help="Number of failed tool call scenarios to generate")
+    parser.add_argument("--refusals", type=int, default=0, help="Number of refusal scenarios to generate")
+    parser.add_argument("--concurrency", type=int, default=4, help="Number of concurrent requests")
+    parser.add_argument("--languages", type=str, nargs='+', default=["english"], help="Languages to generate data for")
     parser.add_argument("--model", type=str, default="gpt-oss-120b", help="LLM model to use")
     parser.add_argument("--persona-name", type=str, help="Name of the new persona to generate")
     parser.add_argument("--persona-description", type=str, help="Description of the new persona")
 
     args = parser.parse_args()
 
-    generator = SyntheticDataGenerator(model_name=args.model, language=args.language, concurrency=args.concurrency)
-    asyncio.run(generator.run(num_actions=args.actions, num_status_requests=args.status, num_devices=args.devices, output_file="", persona_name=args.persona_name, persona_description=args.persona_description))
+    for language in args.languages:
+        print(f"=== Generating data for language: {language} ===")
+        generator = SyntheticDataGenerator(model_name=args.model, language=language, concurrency=args.concurrency)
+        asyncio.run(generator.run(
+            num_actions=args.actions,
+            num_status_requests=args.status,
+            num_devices=args.devices,
+            persona_name=args.persona_name,
+            persona_description=args.persona_description,
+            num_failed_tool_calls=args.failed_tool_calls,
+            num_refusals=args.refusals
+        ))
tools.py:
@@ -37,6 +37,7 @@ SERVICE_TO_TOOL_MAP = {
     "unlock": TOOL_TURN_OFF,
     "increase_speed": TOOL_TURN_ON,
     "decrease_speed": TOOL_TURN_OFF,
+    "media_stop": TOOL_TURN_OFF,
     "media_play_pause": TOOL_TOGGLE,
     "media_pause": TOOL_MEDIA_PAUSE,
     "media_play": TOOL_MEDIA_UNPAUSE,
@@ -49,6 +50,13 @@ SERVICE_TO_TOOL_MAP = {
     "set_hvac_mode": TOOL_CLIMATE_SET_TEMPERATURE,
     "set_fan_mode": TOOL_CLIMATE_SET_TEMPERATURE,
     "set_preset_mode": TOOL_CLIMATE_SET_TEMPERATURE,
+    "cancel": TOOL_CANCEL_TIMER,
+    "volume_down": TOOL_SET_VOLUME,
+    "volume_up": TOOL_SET_VOLUME,
+    "volume_mute": TOOL_SET_VOLUME,
+    "stop": TOOL_TURN_OFF,
+    "pause": TOOL_TURN_OFF,
+    "add_item": TOOL_LIST_ADD_ITEM
 }
 
 # Home Assistant Intent Tools Definition
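These new map entries matter because the commit also replaces `SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)` with direct indexing in the generators: a service action missing from this map now raises `KeyError` instead of silently becoming a turn_on call. A toy illustration of that behavioral difference (the map contents here are stand-ins, not the real table):

```python
# Toy stand-in for the real SERVICE_TO_TOOL_MAP in tools.py.
SERVICE_TO_TOOL_MAP = {"turn_on": "HassTurnOn"}

service_action = "volume_up"

# Old behavior: an unmapped service silently fell back to the turn_on tool.
tool_name = SERVICE_TO_TOOL_MAP.get(service_action, "HassTurnOn")

# New behavior: a missing mapping fails loudly during dataset generation,
# surfacing gaps like volume_up/volume_down before they poison the data.
try:
    tool_name = SERVICE_TO_TOOL_MAP[service_action]
except KeyError:
    print(f"no tool mapped for service action: {service_action}")
```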