more synthesizing scenarios + clean up example formatting

This commit is contained in:
Alex O'Connell
2025-12-21 13:31:43 -05:00
parent ecf9586b5a
commit 4407aefdf5
5 changed files with 584 additions and 253 deletions

View File

@@ -67,6 +67,26 @@ The response pile is a CSV with the following headers: `service,response,languag
Generating the full dataset using the python script will print out a warning for any responses that are missing for a persona.
## Synthesizing new pile data
You can quickly append fresh examples to the CSV piles without editing them manually by running `synthesize.py`. The script talks to the configured LLM and writes the generated rows directly into the per-language pile files.
Examples:
```bash
# Append 25 failed tool-call recoveries and 25 refusals in Spanish
python3 synthesize.py --language spanish --model gpt-oss-120b --failed-tool-calls 25 --refusals 25 --concurrency 6
# Generate new actions plus matching refusal samples in German
python3 synthesize.py --language german --actions 100 --refusals 40 --model gpt-oss-120b
```
Useful flags:
- `--failed-tool-calls`: number of `pile_of_failed_tool_calls.csv` rows to synthesize.
- `--refusals`: number of `pile_of_refusals.csv` rows to synthesize.
- `--actions`, `--status`, `--devices`: existing knobs for the other piles.
The script automatically routes generations to the correct language-specific pile under `data/piles/<language>/`.
## Adding new Home Assistant functionality
TODO
<!-- In order to add new home assistant device types, you will need to add data to a handful of piles, as well as make small modifications to the `generate_data.py` script.

View File

@@ -1,6 +1,7 @@
import random
from dataclasses import dataclass
from typing import Final, Callable, List
from difflib import SequenceMatcher
from tools import *
from utils import closest_color, generate_random_parameter, get_dataset_piles
@@ -31,6 +32,9 @@ STATE_CLEANING: Final = "cleaning"
STATE_DOCKED: Final = "docked"
STATE_RETURNING: Final = "returning"
def format_device_line(*, device_name: str, friendly_name: str, state: str) -> str:
    """Render a single exposed-device line for the prompt context."""
    return f"{device_name} '{friendly_name}' = {state}"
@dataclass
class DeviceType:
name: str
@@ -222,3 +226,79 @@ class MediaPlayerDeviceType(DeviceType):
if "volume_level" in extra_exposed_attributes:
tools.append(TOOL_SET_VOLUME)
return tools
# Registry mapping each Home Assistant domain string to the DeviceType
# implementation that renders its states and selects its tools.
SUPPORTED_DEVICES = {
"light": LightDeviceType(),
"switch": SwitchDeviceType(),
"fan": FanDeviceType(),
"garage_door": GarageDoorDeviceType(),
"blinds": BlindsDeviceType(),
"lock": LockDeviceType(),
"media_player": MediaPlayerDeviceType(),
"climate": ClimateDeviceType(),
"vacuum": VacuumDeviceType(),
"timer": TimerDeviceType(),
"todo": TodoDeviceType(),
}
# generate a random list of devices for the context
def random_device_list(max_devices: int, avoid_device_names: list[str], language: str = "english"):
    """Build a randomized device context, excluding names similar to avoid_device_names.

    Returns (device_lines, device_types, extra_exposed_attributes).
    """
    target_count = random.randint(2, max_devices)
    piles = get_dataset_piles(language)
    # Copy each per-type name stack so filtering never mutates the shared piles.
    candidates_by_type = {kind: entries[:] for kind, entries in piles.stacks_of_device_names.items()}
    skip_climate = False
    for avoided in avoid_device_names:
        avoided_type = avoided.split(".")[0]
        # Drop any candidate whose entity suffix is too similar to the avoided name.
        candidates_by_type[avoided_type] = [
            entry for entry in candidates_by_type[avoided_type]
            if SequenceMatcher(None, avoided, entry["device_name"].split(".")[1]).ratio() < 0.4
        ]
        if avoided_type == "climate":
            skip_climate = True
    pool = []
    for kind in candidates_by_type.keys():
        pool.extend(candidates_by_type[kind])
    seen_types = set()
    chosen_names = []
    rendered_lines = []
    # TODO: randomly pick attributes for this list
    extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "duration", "remaining", "item"]
    while len(chosen_names) < target_count:
        choice = random.choice(pool)
        if choice["device_name"] in chosen_names:
            continue
        try:
            name = choice["device_name"]
            kind = name.split(".")[0]
            friendly = choice["description"]
            # don't add random thermostats. we need to be careful about how we handle multiple thermostats
            if skip_climate and kind == "climate":
                continue
            state = SUPPORTED_DEVICES[kind].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
            rendered_lines.append(format_device_line(
                device_name=name,
                friendly_name=friendly,
                state=state
            ))
            chosen_names.append(name)
            seen_types.add(kind)
        except Exception as ex:
            print(f"bad device name: {choice}")
            print(repr(ex))
    return rendered_lines, list(seen_types), list(extra_exposed_attributes)

View File

@@ -3,7 +3,6 @@ import json
import numpy as np
import random
from datasets import load_dataset, concatenate_datasets
from difflib import SequenceMatcher
from typing import Callable
from tqdm import tqdm
import webcolors
@@ -13,87 +12,13 @@ import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from device_types import *
from devices import SUPPORTED_DEVICES, format_device_line, random_device_list, \
TOOL_TURN_ON, TOOL_CLIMATE_SET_TEMPERATURE, TOOL_SET_HUMIDITY, \
TOOL_LIGHT_SET, TOOL_START_TIMER, TOOL_LIST_ADD_ITEM, SERVICE_TO_TOOL_MAP, \
HASS_TOOLS, SERVICE_TOOLS
from prompting import generate_system_prompt, USER_INSTRUCTION_PROMPT
from utils import get_random_response, generate_random_parameter, closest_color, get_dataset_piles, NoResponseAvailableException
# Registry mapping each Home Assistant domain string to the DeviceType
# implementation that renders its states and selects its tools.
SUPPORTED_DEVICES = {
"light": LightDeviceType(),
"switch": SwitchDeviceType(),
"fan": FanDeviceType(),
"garage_door": GarageDoorDeviceType(),
"blinds": BlindsDeviceType(),
"lock": LockDeviceType(),
"media_player": MediaPlayerDeviceType(),
"climate": ClimateDeviceType(),
"vacuum": VacuumDeviceType(),
"timer": TimerDeviceType(),
"todo": TodoDeviceType(),
}
def format_device_line(*, device_name: str, friendly_name: str, state: str) -> str:
    """Produce the 'entity_id 'Friendly Name' = state' context line for one device."""
    return f"{device_name} '{friendly_name}' = {state}"
# generate a random list of devices for the context
def random_device_list(max_devices: int, avoid_device_names: list[str], language: str = "english"):
    """Pick 2..max_devices random devices, avoiding names close to avoid_device_names.

    Returns (device_lines, device_types, extra_exposed_attributes).
    """
    how_many = random.randint(2, max_devices)
    piles = get_dataset_piles(language)
    # Work on copies so the shared pile data is never filtered in place.
    stacks = {dev_type: names[:] for dev_type, names in piles.stacks_of_device_names.items()}
    avoid_climate = False
    for avoid_device in avoid_device_names:
        avoid_type = avoid_device.split(".")[0]
        # Keep only candidates whose entity suffix is sufficiently dissimilar.
        stacks[avoid_type] = [
            candidate for candidate in stacks[avoid_type]
            if SequenceMatcher(None, avoid_device, candidate["device_name"].split(".")[1]).ratio() < 0.4
        ]
        if avoid_type == "climate":
            avoid_climate = True
    all_candidates = []
    for dev_type in stacks.keys():
        all_candidates.extend(stacks[dev_type])
    picked_types = set()
    picked_names = []
    lines = []
    # TODO: randomly pick attributes for this list
    extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "duration", "remaining", "item"]
    while len(picked_names) < how_many:
        pick = random.choice(all_candidates)
        if pick["device_name"] in picked_names:
            continue
        try:
            entity_id = pick["device_name"]
            dev_type = entity_id.split(".")[0]
            label = pick["description"]
            # don't add random thermostats. we need to be careful about how we handle multiple thermostats
            if avoid_climate and dev_type == "climate":
                continue
            state = SUPPORTED_DEVICES[dev_type].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
            lines.append(format_device_line(
                device_name=entity_id,
                friendly_name=label,
                state=state
            ))
            picked_names.append(entity_id)
            picked_types.add(dev_type)
        except Exception as ex:
            print(f"bad device name: {pick}")
            print(repr(ex))
    return lines, list(picked_types), list(extra_exposed_attributes)
from utils import get_random_response, generate_random_parameter, closest_color, \
get_dataset_piles, NoResponseAvailableException
def generate_static_example(action: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
question = action["phrase"]
@@ -126,7 +51,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
# Map service name to tool name
service_action = service_name.split(".")[1]
tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
tool_name = SERVICE_TO_TOOL_MAP[service_action]
response_starting, response_confirmed = get_random_response(
piles.pile_of_responses,
@@ -225,13 +150,17 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
except Exception as e:
print(f"Failed to parse arguments for {action}: {e}")
final_answer = " ".join(answer_list)
assistant_turns = [
create_assistant_turn(response_starting, [tool_call]),
create_assistant_turn(final_answer, [])
]
return {
"states": device_list,
"available_tools": available_tools,
"question": question.lower(),
"answers": answer_list,
"answer_starting": response_starting,
"tool_calls": [ tool_call ]
"assistant_turns": assistant_turns
}
def replace_answer(list_of_answer, var, value):
@@ -240,6 +169,16 @@ def replace_answer(list_of_answer, var, value):
new_list.append(answer.replace(var, value))
return new_list
def create_assistant_turn(answer: str, tool_call_sequence=None, *, tool_results=None, train_on_turn: bool = True):
    """Bundle the assistant utterance with any tool interaction for that turn."""
    turn = {
        "answer": answer,
        "tool_call_sequence": tool_call_sequence or [],
        "tool_results": [] if tool_results is None else tool_results,
        "train_on_turn": train_on_turn,
    }
    return turn
def generate_templated_example(template: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
template_device_types: list[str] = template["device_type"].split("|")
service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ]
@@ -328,7 +267,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
tool_calls = []
for device_dict, service in zip(chosen_devices, service_names):
service_action = service.split(".")[1]
tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
tool_name = SERVICE_TO_TOOL_MAP[service_action]
tool_call = {
"tool_name": tool_name,
"service_name": service,
@@ -415,13 +354,19 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
if call["tool_name"] == TOOL_LIST_ADD_ITEM:
call["tool_args"]["item"] = todo
starting_answer = answer_starting.strip().lower()
normalized_answers = [ sentence.lower() for sentence in answer_list ]
final_answer = " ".join(normalized_answers)
assistant_turns = [
create_assistant_turn(starting_answer, tool_calls),
create_assistant_turn(final_answer, [])
]
return {
"states": device_list,
"available_tools": available_tools,
"question": question.lower(),
"answer_starting": answer_starting.lower(),
"answers": [ sentence.lower() for sentence in answer_list ],
"tool_calls": tool_calls
"assistant_turns": assistant_turns
}
def generate_status_request(template: dict, persona: str, language: str, max_devices: int = 128, return_target_device: bool = False, use_service_names: bool = False):
@@ -509,12 +454,13 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
# Remove duplicates while preserving order
available_tools = list(dict.fromkeys(available_tools))
assistant_turns = [create_assistant_turn(answer.lower(), [])]
result = {
"states": device_list,
"available_tools": available_tools,
"question": question.lower(),
"answers": [ answer.lower() ],
"tool_calls": []
"assistant_turns": assistant_turns
}
if return_target_device:
return result, chosen_device
@@ -578,42 +524,42 @@ def generate_tool_failure_example(failure_case: dict, persona: str, language: st
retry_prompt = failure_case.get("retry_prompt", f"Trying again with {friendly_name}.").replace("<device_name>", friendly_name)
error_result = failure_case.get("error_result", "Error").replace("<device_name>", friendly_name)
tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
tool_name = SERVICE_TO_TOOL_MAP[service_action]
first_args = {"entity_id": bad_device} if use_service_names else {"name": bad_device}
retry_args = {"entity_id": target_device} if use_service_names else {"name": target_device}
first_args.update(tool_args_extra)
retry_args.update(tool_args_extra)
tool_call_sequence = [
{
"answer_starting": response_starting,
"tool_calls": [{
"tool_name": tool_name,
"service_name": service_name,
"tool_args": first_args
}],
"tool_results": [{
"tool_name": service_name if use_service_names else tool_name,
"tool_result": error_result
}]
},
{
"answer_starting": retry_prompt,
"tool_calls": [{
"tool_name": tool_name,
"service_name": service_name,
"tool_args": retry_args
}]
}
]
first_turn = create_assistant_turn(
response_starting,
[{
"tool_name": tool_name,
"service_name": service_name,
"tool_args": first_args
}],
tool_results=[{
"tool_name": service_name if use_service_names else tool_name,
"tool_result": error_result
}],
train_on_turn=False
)
second_turn = create_assistant_turn(
retry_prompt,
[{
"tool_name": tool_name,
"service_name": service_name,
"tool_args": retry_args
}]
)
final_turn = create_assistant_turn(response_confirmed, [])
return {
"states": device_list,
"available_tools": available_tools,
"question": question,
"answers": [response_confirmed],
"tool_call_sequence": tool_call_sequence,
"tool_calls": []
"assistant_turns": [first_turn, second_turn, final_turn]
}
def generate_refusal_example(refusal_case: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
@@ -645,41 +591,20 @@ def generate_refusal_example(refusal_case: dict, persona: str, language: str, ma
response_text = refusal_case["response"].replace("<device_name>", friendly_name).lower()
question = refusal_case["phrase"].replace("<device_name>", friendly_name).lower()
assistant_turns = [create_assistant_turn(response_text, [])]
return {
"states": device_list,
"available_tools": available_tools,
"question": question,
"answers": [response_text],
"tool_calls": []
"assistant_turns": assistant_turns
}
def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format):
piles = get_dataset_piles(language)
sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
question = example["question"]
answers = " ".join(example["answers"])
answer_starting = example.get("answer_starting", "")
tool_call_sequence = example.get("tool_call_sequence")
tool_calls = []
tool_results = []
if not tool_call_sequence:
# Add tool use blocks if there are tool calls
if len(example["tool_calls"]) > 0:
for tool_call in example["tool_calls"]:
# Use service_name if in service mode, otherwise use tool_name
call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
tool_calls.append({
"name": call_name,
"arguments": json.dumps(tool_call["tool_args"])
})
tool_results.append({
"tool_name": call_name,
"tool_call_id": f"call_{len(tool_results) + 1}",
"tool_result": "Success"
})
assistant_turns = example["assistant_turns"]
if append_user_instruction_prompt:
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
@@ -703,95 +628,64 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
}
]
if tool_call_sequence:
call_id_counter = 1
for step in tool_call_sequence:
step_tool_calls = []
for tool_call in step.get("tool_calls", []):
call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
step_tool_calls.append({
"name": call_name,
"arguments": json.dumps(tool_call["tool_args"])
})
assistant_block = {
"role": "assistant",
"content": [{ "type": "text", "text": step.get("answer_starting", "") }]
}
if step_tool_calls:
assistant_block["tool_calls"] = [ { "function": tc } for tc in step_tool_calls ]
conversation.append(assistant_block)
if step_tool_calls:
provided_results = step.get("tool_results")
step_tool_results = []
if provided_results:
for provided in provided_results:
step_tool_results.append({
"tool_name": provided.get("tool_name", step_tool_calls[0]["name"]),
"tool_call_id": provided.get("tool_call_id", f"call_{call_id_counter}"),
"tool_result": provided.get("tool_result", "Success")
})
call_id_counter += 1
else:
for call in step_tool_calls:
step_tool_results.append({
"tool_name": call["name"],
"tool_call_id": f"call_{call_id_counter}",
"tool_result": "Success"
})
call_id_counter += 1
if tool_response_format == "text":
conversation.append({
"role": "tool",
"content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results]
})
elif tool_response_format == "functiongemma":
conversation.append({
"role": "tool",
"content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results]
})
if step.get("post_tool_response"):
conversation.append({
"role": "assistant",
"content": [{ "type": "text", "text": step["post_tool_response"] }]
})
conversation.append({
call_id_counter = 1
for turn in assistant_turns:
answer_text = turn.get("answer", "")
assistant_block = {
"role": "assistant",
"content": [{ "type": "text", "text": answers }],
})
elif len(tool_calls) > 0:
assistant_starting_block = {
"role": "assistant",
"content": [{ "type": "text", "text": answer_starting }],
"tool_calls": [ { "function": tc } for tc in tool_calls ]
"content": [{ "type": "text", "text": answer_text }],
"train_on_turn": turn.get("train_on_turn", True),
}
if tool_response_format == "text":
tool_response_block = {
"role": "tool",
"content": [{ "type": "text", "text": json.dumps(result) } for result in tool_results]
}
elif tool_response_format == "functiongemma":
tool_response_block = {
"role": "tool",
"content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in tool_results]
}
assistant_confirmation_block = {
"role": "assistant",
"content": [{ "type": "text", "text": answers }],
}
conversation.extend([assistant_starting_block, tool_response_block, assistant_confirmation_block])
else:
conversation.extend([
{
"role": "assistant",
"content": [{ "type": "text", "text": answer_starting + answers }],
}
])
tool_call_sequence = turn.get("tool_call_sequence", [])
formatted_calls = []
call_names = []
for tool_call in tool_call_sequence:
call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
call_names.append(call_name)
formatted_calls.append({
"name": call_name,
"arguments": json.dumps(tool_call["tool_args"])
})
if formatted_calls:
assistant_block["tool_calls"] = [{ "function": call } for call in formatted_calls]
conversation.append(assistant_block)
if formatted_calls:
provided_results = turn.get("tool_results") or []
step_tool_results = []
if provided_results:
for idx, provided in enumerate(provided_results):
result = dict(provided)
if "tool_name" not in result and call_names:
result["tool_name"] = call_names[min(idx, len(call_names) - 1)]
if "tool_call_id" not in result:
result["tool_call_id"] = f"call_{call_id_counter}"
call_id_counter += 1
step_tool_results.append(result)
else:
for call_name in call_names:
step_tool_results.append({
"tool_name": call_name,
"tool_call_id": f"call_{call_id_counter}",
"tool_result": "Success"
})
call_id_counter += 1
if tool_response_format == "text":
conversation.append({
"role": "tool",
"content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results]
})
elif tool_response_format == "functiongemma":
conversation.append({
"role": "tool",
"content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results]
})
return {
"messages": conversation,
@@ -812,8 +706,8 @@ def generate_sft_file(
static_factor: float,
template_factor: int,
status_request_factor: int,
failure_factor: float = 1,
refusal_factor: float = 1):
failure_factor: int,
refusal_factor: int):
random.seed(seed)
np.random.seed(seed)
piles = get_dataset_piles(language)
@@ -929,7 +823,7 @@ def main(args=None):
args = parser.parse_args(args=args)
if not args.sample and not args.train and not args.test and not args.merge:
if not args.sample and not args.train and not args.test:
parser.print_usage()
exit(-1)
@@ -950,20 +844,20 @@ def main(args=None):
suffix = f"_{language}" if len(args.language) > 1 else ""
if args.sample:
generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1)
generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1, refusal_factor=1, failure_factor=1)
if args.train:
if args.size == "small":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8, refusal_factor=3, failure_factor=1)
elif args.size == "medium":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12, refusal_factor=5, failure_factor=1)
elif args.size == "large":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15, refusal_factor=6, failure_factor=1)
elif args.size == "xl":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18)
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18, refusal_factor=8, failure_factor=2)
else:
raise Exception(f"Unrecognized dataset size: {args.size}")
if args.test:
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2)
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2, refusal_factor=1, failure_factor=1)
if len(args.language) > 1:
if args.sample:
merge_languages("sample", args.language)

View File

@@ -3,13 +3,19 @@ import asyncio
import csv
import json
import random
import string
import aiohttp
from tqdm import tqdm
import os
from utils import get_dataset_piles
from devices import random_device_list
LLM_ENDPOINT = "https://ai.cloud.alexoconnell.net/v1/chat/completions"
cwd = os.path.dirname(os.path.abspath(__file__))
def get_hass_match_error_message(bad_device_name: str) -> str:
    """Reproduce Home Assistant's MatchFailedError text for an unknown entity name."""
    return f"<MatchFailedError result=MatchTargetsResult(is_match=False, no_match_reason=<MatchFailedReason.NAME: 1>, states=[], no_match_name=None, areas=[], floors=[]), constraints=MatchTargetsConstraints(name='{bad_device_name}', area_name=None, floor_name=None, domains=None, device_classes=None, features=None, states=None, assistant='conversation', allow_duplicate_names=False, single_target=False), preferences=MatchTargetsPreferences(area_id=None, floor_id=None)>"
class SyntheticDataGenerator:
def __init__(self, model_name: str, language: str, concurrency: int):
@@ -18,6 +24,130 @@ class SyntheticDataGenerator:
self.model_name = model_name
self.piles = get_dataset_piles(language)
self.synthetic_devices = {} # device_type -> list of {device_name, description}
self.failed_tool_fieldnames = [
"service_name",
"correct_device_name",
"correct_friendly_name",
"bad_device_name",
"phrase",
"error_result",
"retry_prompt"
]
self.refusal_fieldnames = [
"reason_type",
"service_name",
"device_name",
"friendly_name",
"desired_state",
"phrase",
"response"
]
async def _chat_completion(self, session, system_prompt: str, user_prompt: str, *, temperature: float = 0.8, max_tokens: int = 300, structured_response: dict | None = None):
    """POST one chat-completion request to LLM_ENDPOINT.

    Returns the stripped message content on success, or None on any HTTP
    error or exception (failures are printed, not raised).
    Fix: the default for structured_response was a mutable dict literal
    ({}); use the None sentinel instead. Both are falsy, so behavior is
    unchanged for every caller.
    """
    payload = {
        "model": self.model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        "temperature": temperature,
        # don't limit if structured (causes failed generations)
        "max_tokens": None if structured_response else max_tokens,
        "response_format": structured_response if structured_response else None,
    }
    try:
        async with session.post(LLM_ENDPOINT, json=payload) as response:
            if response.status < 400:
                data = await response.json()
                return data['choices'][0]['message']['content'].strip()
            print(f"LLM request failed with status {response.status}")
    except Exception as exc:
        print(f"Completion request failed: {exc}")
    return None
@staticmethod
def _strip_code_fence(content: str) -> str:
if not content:
return ""
text = content.strip()
if text.startswith("```") and text.endswith("```"):
lines = [line for line in text.splitlines() if not line.strip().startswith("```")]
return "\n".join(lines).strip()
return text
def _parse_json_object(self, content: str):
if not content:
return None
cleaned = self._strip_code_fence(content)
start = cleaned.find("{")
end = cleaned.rfind("}")
if start == -1 or end == -1:
return None
snippet = cleaned[start:end+1]
try:
return json.loads(snippet)
except json.JSONDecodeError:
return None
@staticmethod
def _ensure_placeholder(text: str, friendly_name: str, placeholder: str) -> str:
if placeholder in text:
return text
lower_text = text.lower()
lower_name = friendly_name.lower()
idx = lower_text.find(lower_name)
if idx == -1:
return text
return text[:idx] + placeholder + text[idx + len(friendly_name):]
@staticmethod
def _describe_service(service_name: str, service_data: dict) -> str:
domain, action = service_name.split(".", 1)
description = f"{action.replace('_', ' ')} the {domain}"
if service_data:
description += f" with arguments {json.dumps(service_data, ensure_ascii=False)}"
return description
@staticmethod
def _desired_state_for_service(service_name: str) -> str:
action = service_name.split(".", 1)[1]
action_mapping = {
"turn_on": "on",
"turn_off": "off",
"open_cover": "open",
"close_cover": "closed",
"lock": "locked",
"unlock": "unlocked",
"media_play_pause": "playing",
}
if action in action_mapping:
return action_mapping[action]
service_mapping = {
"media_player.turn_on": "on",
"media_player.turn_off": "off",
"vacuum.start": "cleaning",
"vacuum.return_to_base": "docked",
"garage_door.open_cover": "open",
"garage_door.close_cover": "closed",
"blinds.open_cover": "open",
"blinds.close_cover": "closed",
}
return service_mapping.get(service_name, "")
@staticmethod
def _append_rows_to_csv(path: str, fieldnames, rows):
if not rows:
return
os.makedirs(os.path.dirname(path), exist_ok=True)
file_exists = os.path.exists(path)
if not file_exists:
with open(path, "w", newline='', encoding='utf-8') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
with open(path, "a", newline='', encoding='utf-8') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
for row in rows:
writer.writerow(row)
async def generate_device_names(self, session, device_type, count=10):
"""
@@ -144,6 +274,117 @@ class SyntheticDataGenerator:
print(f"Request failed: {e}")
return None
async def generate_failed_tool_call_entry(self, session, context):
"""Synthesize one pile_of_failed_tool_calls.csv row via the LLM.

context supplies service_name / friendly_name / device_name (and optional
service_data). Returns a dict matching self.failed_tool_fieldnames, or
None when the LLM reply is missing or incomplete.
"""
# Unpack the sampled scenario.
service_name = context["service_name"]
friendly_name = context["friendly_name"]
correct_device = context["device_name"]
action_description = self._describe_service(service_name, context.get("service_data", {}))
# Random device names give the model realistic material for inventing a "bad" name.
device_list, device_types, extra_exposed_attributes = random_device_list(max_devices=10, avoid_device_names=[], language=self.language)
extra_phrase_instructions = ""
if service_name == "climate.set_temperature":
extra_phrase_instructions = "Always reference the target temperature using the literal placeholder <temp_f>."
system_prompt = "You create high-quality multilingual training data for a smart home assistant."
user_prompt = f"""
Generate content for a failed tool call in {self.language}.
Service: {service_name}
Action description: {action_description}
Device name: {correct_device} (friendly name: {friendly_name})
The assistant first misidentifies the entity using the "bad_device" name, receives an error, then retries with {correct_device}.
Output a single JSON object with the keys "phrase", "bad_device", "error_result", and "retry_prompt".
- phrase: A natural voice command template the user would say. {extra_phrase_instructions}
- bad_device: The incorrect version of '{correct_device}'. Ensure it is a truncation, mutation, transposition, or inclusion of an extra suffix/prefix. Avoid simple typos in words.
- retry_prompt: A short acknowledgement from the assistant that it will try the correct device ({correct_device}).
Here are potential device names to help generate a bad device name: {', '.join(device_list)}
Keep the tone conversational and stay entirely in {self.language}. Do not add explanations or code fences.
"""
# Structured JSON mode; max_tokens is ignored for structured responses (see _chat_completion).
raw = await self._chat_completion(
session,
system_prompt,
user_prompt,
temperature=0.7,
max_tokens=512,
structured_response={ "type": "json_object" }
)
parsed = self._parse_json_object(raw)
if not parsed:
print(f"Failed to parse JSON for failed tool call generation: {raw}")
return None
phrase = parsed.get("phrase", "").strip()
bad_device = parsed.get("bad_device", "").strip()
retry_prompt = parsed.get("retry_prompt", "").strip()
# Reject partial generations rather than writing incomplete CSV rows.
if not phrase or not bad_device or not retry_prompt:
return None
# NOTE(review): any "error_result" the model produced is discarded; the row
# always carries the canonical Home Assistant match-failure message instead.
return {
"service_name": service_name,
"correct_device_name": correct_device,
"correct_friendly_name": friendly_name,
"bad_device_name": bad_device,
"phrase": phrase,
"error_result": get_hass_match_error_message(bad_device),
"retry_prompt": retry_prompt
}
async def generate_refusal_entry(self, session, context, reason_type: str, desired_state: str):
service_name = context["service_name"]
friendly_name = context["friendly_name"]
action_description = self._describe_service(service_name, context.get("service_data", {}))
device_suffix = context["device_name"].split(".", 1)[1]
if reason_type == "not_available":
reason_detail = "Explain that the assistant cannot locate that device in the home."
else:
reason_detail = f"Explain that the device is already {desired_state} and no change is needed."
system_prompt = "You write concise refusal-style responses for a multilingual smart home assistant."
user_prompt = f"""
Create a refusal example in {self.language} for the following request.
Service: {service_name}
Action description: {action_description}
Reason: {reason_detail}
Output a JSON object with:
- "phrase": the user's natural command template. Use the literal placeholder <device_name> anywhere the user mentions the device.
- "response": the assistant's refusal message in {self.language} describing the reason.
Keep both fields brief, conversational, and free of extra narration or code fences.
"""
raw = await self._chat_completion(
session,
system_prompt,
user_prompt,
temperature=0.7,
max_tokens=512,
structured_response={ "type": "json_object" }
)
parsed = self._parse_json_object(raw)
if not parsed:
print("Failed to parse JSON for refusal generation")
return None
phrase = parsed.get("phrase", "").strip()
response = parsed.get("response", "").strip()
if not phrase or not response:
return None
phrase = self._ensure_placeholder(phrase, friendly_name, "<device_name>")
if "<device_name>" not in phrase:
return None
return {
"reason_type": reason_type,
"service_name": service_name,
"device_name": device_suffix,
"friendly_name": friendly_name,
"desired_state": desired_state if reason_type == "already_state" else "",
"phrase": phrase,
"response": response
}
def sample_context(self, request_type: str):
"""
Creates a random scenario: device, service, and arguments.
@@ -245,14 +486,12 @@ class SyntheticDataGenerator:
}
raise ValueError(f"Unknown request type {request_type}")
async def run(self, num_actions: int, num_status_requests: int, num_devices: int, output_file, persona_name=None, persona_description=None):
async def run(self, num_actions: int, num_status_requests: int, num_devices: int,
persona_name=None, persona_description=None, num_failed_tool_calls: int = 0,
num_refusals: int = 0):
print(f"Starting generation...")
print(f"Language: {self.language}")
# Ensure output directory exists
if output_file:
os.makedirs(os.path.dirname(output_file), exist_ok=True)
async with aiohttp.ClientSession() as session:
if num_devices > 0:
@@ -268,7 +507,7 @@ class SyntheticDataGenerator:
all_new_devices = [item for sublist in generated_lists if sublist for item in sublist]
if all_new_devices:
csv_path = f"data/piles/{self.language}/pile_of_device_names.csv"
csv_path = f"{cwd}/piles/{self.language}/pile_of_device_names.csv"
try:
with open(csv_path, "a", newline='', encoding='utf-8') as f:
writer = csv.DictWriter(f, fieldnames=["device_name", "description"])
@@ -280,7 +519,6 @@ class SyntheticDataGenerator:
if num_actions > 0 or num_status_requests > 0:
print(f"Generating {num_actions} actions and {num_status_requests} status requests...")
print(f"Output file: {output_file}")
tasks = {}
results = []
@@ -312,7 +550,7 @@ class SyntheticDataGenerator:
if entry["type"] == "action":
# Write to pile_of_specific_actions.csv
csv_path = f"data/piles/{self.language}/pile_of_specific_actions.csv"
csv_path = f"{cwd}/piles/{self.language}/pile_of_specific_actions.csv"
# Prepare row
# device_name in CSV is the suffix (e.g. 'kitchen' from 'light.kitchen')
@@ -383,7 +621,7 @@ class SyntheticDataGenerator:
# For now, just skip if we can't templatize.
pass
else:
csv_path = f"data/piles/{self.language}/pile_of_status_requests.csv"
csv_path = f"{cwd}/piles/{self.language}/pile_of_status_requests.csv"
# Columns: device_type,state,phrase,assistant_response
# We don't have assistant_response.
# We can generate a generic one?
@@ -398,6 +636,85 @@ class SyntheticDataGenerator:
pbar.close()
if num_failed_tool_calls > 0:
print(f"Generating {num_failed_tool_calls} failed tool call scenarios...")
failed_entries = []
tasks = {}
pbar_failed = tqdm(total=num_failed_tool_calls, desc="Failed tool calls")
while len(failed_entries) < num_failed_tool_calls:
while len(tasks) < self.concurrency and (len(failed_entries) + len(tasks)) < num_failed_tool_calls:
context = self.sample_context("action")
if not context:
continue
task = asyncio.create_task(self.generate_failed_tool_call_entry(session, context))
tasks[task] = None
if not tasks:
break
done, _ = await asyncio.wait(list(tasks.keys()), return_when=asyncio.FIRST_COMPLETED)
for task in done:
tasks.pop(task, None)
try:
entry = await task
except Exception as exc:
print(f"Failed tool call task error: {exc}")
entry = None
if entry:
failed_entries.append(entry)
pbar_failed.update(1)
pbar_failed.close()
failed_path = f"{cwd}/piles/{self.language}/pile_of_failed_tool_calls.csv"
self._append_rows_to_csv(failed_path, self.failed_tool_fieldnames, failed_entries)
if num_refusals > 0:
print(f"Generating {num_refusals} refusal scenarios...")
refusal_entries = []
tasks = {}
pbar_refusals = tqdm(total=num_refusals, desc="Refusals")
while len(refusal_entries) < num_refusals:
while len(tasks) < self.concurrency and (len(refusal_entries) + len(tasks)) < num_refusals:
context = self.sample_context("action")
if not context:
continue
reason_type = random.choice(["not_available", "already_state"])
desired_state = ""
if reason_type == "already_state":
desired_state = self._desired_state_for_service(context["service_name"])
if not desired_state:
reason_type = "not_available"
task = asyncio.create_task(self.generate_refusal_entry(
session,
context,
reason_type,
desired_state
))
tasks[task] = None
if not tasks:
break
done, _ = await asyncio.wait(list(tasks.keys()), return_when=asyncio.FIRST_COMPLETED)
for task in done:
tasks.pop(task, None)
try:
entry = await task
except Exception as exc:
print(f"Refusal generation error: {exc}")
entry = None
if entry:
refusal_entries.append(entry)
pbar_refusals.update(1)
pbar_refusals.close()
refusal_path = f"{cwd}/piles/{self.language}/pile_of_refusals.csv"
self._append_rows_to_csv(refusal_path, self.refusal_fieldnames, refusal_entries)
if persona_name and persona_description:
await self.generate_persona(session, persona_name, persona_description)
@@ -442,7 +759,7 @@ class SyntheticDataGenerator:
return
# 2. Get list of services to generate responses for
responses_csv_path = f"data/piles/{self.language}/pile_of_responses.csv"
responses_csv_path = f"{cwd}/piles/{self.language}/pile_of_responses.csv"
services = set()
try:
with open(responses_csv_path, "r", encoding='utf-8') as f:
@@ -527,7 +844,7 @@ class SyntheticDataGenerator:
# 4. Write to files
# Append system prompt
sys_prompts_path = f"data/piles/{self.language}/pile_of_system_prompts.csv"
sys_prompts_path = f"{cwd}/piles/{self.language}/pile_of_system_prompts.csv"
try:
with open(sys_prompts_path, "a", newline='', encoding='utf-8') as f:
writer = csv.writer(f)
@@ -555,13 +872,25 @@ if __name__ == "__main__":
parser.add_argument("--actions", type=int, default=0, help="Number of actions to generate")
parser.add_argument("--status", type=int, default=0, help="Number of status requests to generate")
parser.add_argument("--devices", type=int, default=0, help="Number of new devices to generate")
parser.add_argument("--concurrency", type=int, default=8, help="Number of concurrent requests")
parser.add_argument("--language", type=str, default="english", help="Language")
parser.add_argument("--failed-tool-calls", type=int, default=0, help="Number of failed tool call scenarios to generate")
parser.add_argument("--refusals", type=int, default=0, help="Number of refusal scenarios to generate")
parser.add_argument("--concurrency", type=int, default=4, help="Number of concurrent requests")
parser.add_argument("--languages", type=str, nargs='+', default=["english"], help="Languages to generate data for")
parser.add_argument("--model", type=str, default="gpt-oss-120b", help="LLM model to use")
parser.add_argument("--persona-name", type=str, help="Name of the new persona to generate")
parser.add_argument("--persona-description", type=str, help="Description of the new persona")
args = parser.parse_args()
generator = SyntheticDataGenerator(model_name=args.model, language=args.language, concurrency=args.concurrency)
asyncio.run(generator.run(num_actions=args.actions, num_status_requests=args.status, num_devices=args.devices, output_file="", persona_name=args.persona_name, persona_description=args.persona_description))
for language in args.languages:
print(f"=== Generating data for language: {language} ===")
generator = SyntheticDataGenerator(model_name=args.model, language=language, concurrency=args.concurrency)
asyncio.run(generator.run(
num_actions=args.actions,
num_status_requests=args.status,
num_devices=args.devices,
persona_name=args.persona_name,
persona_description=args.persona_description,
num_failed_tool_calls=args.failed_tool_calls,
num_refusals=args.refusals
))

View File

@@ -37,6 +37,7 @@ SERVICE_TO_TOOL_MAP = {
"unlock": TOOL_TURN_OFF,
"increase_speed": TOOL_TURN_ON,
"decrease_speed": TOOL_TURN_OFF,
"media_stop": TOOL_TURN_OFF,
"media_play_pause": TOOL_TOGGLE,
"media_pause": TOOL_MEDIA_PAUSE,
"media_play": TOOL_MEDIA_UNPAUSE,
@@ -49,6 +50,13 @@ SERVICE_TO_TOOL_MAP = {
"set_hvac_mode": TOOL_CLIMATE_SET_TEMPERATURE,
"set_fan_mode": TOOL_CLIMATE_SET_TEMPERATURE,
"set_preset_mode": TOOL_CLIMATE_SET_TEMPERATURE,
"cancel": TOOL_CANCEL_TIMER,
"volume_down": TOOL_SET_VOLUME,
"volume_up": TOOL_SET_VOLUME,
"volume_mute": TOOL_SET_VOLUME,
"stop": TOOL_TURN_OFF,
"pause": TOOL_TURN_OFF,
"add_item": TOOL_LIST_ADD_ITEM
}
# Home Assistant Intent Tools Definition