more synthesizing scenarios + clean up example formatting

This commit is contained in:
Alex O'Connell
2025-12-21 13:31:43 -05:00
parent ecf9586b5a
commit 4407aefdf5
5 changed files with 584 additions and 253 deletions

View File

@@ -67,6 +67,26 @@ The response pile is a CSV with the following headers: `service,response,languag
Generating the full dataset using the Python script will print out a warning for any responses that are missing for a persona.
## Synthesizing new pile data
You can quickly append fresh examples to the CSV piles without editing them manually by running `synthesize.py`. The script talks to the configured LLM and writes the generated rows directly into the per-language pile files.
Examples:
```bash
# Append 25 failed tool-call recoveries and 25 refusals in Spanish
python3 synthesize.py --languages spanish --model gpt-oss-120b --failed-tool-calls 25 --refusals 25 --concurrency 6

# Generate new actions plus matching refusal samples in German
python3 synthesize.py --languages german --actions 100 --refusals 40 --model gpt-oss-120b
```
Useful flags:
- `--failed-tool-calls`: number of `pile_of_failed_tool_calls.csv` rows to synthesize.
- `--refusals`: number of `pile_of_refusals.csv` rows to synthesize.
- `--actions`, `--status`, `--devices`: existing knobs for the other piles.
The script automatically routes generations to the correct language-specific pile under `data/piles/<language>/`.
## Adding new Home Assistant functionality
TODO
<!-- In order to add new home assistant device types, you will need to add data to a handful of piles, as well as make small modifications to the `generate_data.py` script.

View File

@@ -1,6 +1,7 @@
import random
from dataclasses import dataclass
from typing import Final, Callable, List
from difflib import SequenceMatcher
from tools import *
from utils import closest_color, generate_random_parameter, get_dataset_piles
@@ -31,6 +32,9 @@ STATE_CLEANING: Final = "cleaning"
STATE_DOCKED: Final = "docked" STATE_DOCKED: Final = "docked"
STATE_RETURNING: Final = "returning" STATE_RETURNING: Final = "returning"
def format_device_line(*, device_name: str, friendly_name: str, state: str):
return (f"{device_name} '{friendly_name}' = {state}")
@dataclass @dataclass
class DeviceType: class DeviceType:
name: str name: str
@@ -222,3 +226,79 @@ class MediaPlayerDeviceType(DeviceType):
if "volume_level" in extra_exposed_attributes: if "volume_level" in extra_exposed_attributes:
tools.append(TOOL_SET_VOLUME) tools.append(TOOL_SET_VOLUME)
return tools return tools
SUPPORTED_DEVICES = {
    "light": LightDeviceType(),
    "switch": SwitchDeviceType(),
    "fan": FanDeviceType(),
    "garage_door": GarageDoorDeviceType(),
    "blinds": BlindsDeviceType(),
    "lock": LockDeviceType(),
    "media_player": MediaPlayerDeviceType(),
    "climate": ClimateDeviceType(),
    "vacuum": VacuumDeviceType(),
    "timer": TimerDeviceType(),
    "todo": TodoDeviceType(),
}

# generate a random list of devices for the context
def random_device_list(max_devices: int, avoid_device_names: list[str], language: str = "english"):
    num_devices = random.randint(2, max_devices)
    piles = get_dataset_piles(language)
    local_device_names = { k: v[:] for k, v in piles.stacks_of_device_names.items() }
    avoid_climate = False
    for avoid_device in avoid_device_names:
        avoid_type = avoid_device.split(".")[0]
        filtered_possible_devices = []
        for possible_device in local_device_names[avoid_type]:
            similarity_ratio = SequenceMatcher(None, avoid_device, possible_device["device_name"].split(".")[1]).ratio()
            if similarity_ratio < 0.4:
                filtered_possible_devices.append(possible_device)
        local_device_names[avoid_type] = filtered_possible_devices
        if avoid_type == "climate":
            avoid_climate = True
    possible_choices = []
    for device_type in local_device_names.keys():
        possible_choices.extend(local_device_names[device_type])
    device_types = set()
    device_list = []
    device_lines = []
    # TODO: randomly pick attributes for this list
    extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "duration", "remaining", "item"]
    while len(device_list) < num_devices:
        choice = random.choice(possible_choices)
        if choice["device_name"] in device_list:
            continue
        try:
            device_name = choice["device_name"]
            device_type = device_name.split(".")[0]
            friendly_name = choice["description"]
            # don't add random thermostats. we need to be careful about how we handle multiple thermostats
            if avoid_climate and device_type == "climate":
                continue
            state = SUPPORTED_DEVICES[device_type].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
            device_lines.append(format_device_line(
                device_name=device_name,
                friendly_name=friendly_name,
                state=state
            ))
            device_list.append(device_name)
            device_types.add(device_type)
        except Exception as ex:
            print(f"bad device name: {choice}")
            print(repr(ex))
    return device_lines, list(device_types), list(extra_exposed_attributes)
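For orientation, here is a minimal sketch of the exposed-state line these helpers produce; the entity and state below are invented purely for illustration:

```python
from devices import format_device_line

# Hypothetical entity, illustration only.
line = format_device_line(device_name="light.kitchen",
                          friendly_name="Kitchen Light",
                          state="on")
print(line)  # light.kitchen 'Kitchen Light' = on
```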

View File

@@ -3,7 +3,6 @@ import json
import numpy as np
import random
from datasets import load_dataset, concatenate_datasets
-from difflib import SequenceMatcher
from typing import Callable
from tqdm import tqdm
import webcolors
@@ -13,87 +12,13 @@ import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

-from device_types import *
+from devices import SUPPORTED_DEVICES, format_device_line, random_device_list, \
+    TOOL_TURN_ON, TOOL_CLIMATE_SET_TEMPERATURE, TOOL_SET_HUMIDITY, \
+    TOOL_LIGHT_SET, TOOL_START_TIMER, TOOL_LIST_ADD_ITEM, SERVICE_TO_TOOL_MAP, \
+    HASS_TOOLS, SERVICE_TOOLS
from prompting import generate_system_prompt, USER_INSTRUCTION_PROMPT
-from utils import get_random_response, generate_random_parameter, closest_color, get_dataset_piles, NoResponseAvailableException
+from utils import get_random_response, generate_random_parameter, closest_color, \
+    get_dataset_piles, NoResponseAvailableException
-SUPPORTED_DEVICES = {
-    "light": LightDeviceType(),
-    "switch": SwitchDeviceType(),
-    "fan": FanDeviceType(),
-    "garage_door": GarageDoorDeviceType(),
-    "blinds": BlindsDeviceType(),
-    "lock": LockDeviceType(),
-    "media_player": MediaPlayerDeviceType(),
-    "climate": ClimateDeviceType(),
-    "vacuum": VacuumDeviceType(),
-    "timer": TimerDeviceType(),
-    "todo": TodoDeviceType(),
-}
-
-def format_device_line(*, device_name: str, friendly_name: str, state: str):
-    return (f"{device_name} '{friendly_name}' = {state}")
-
-# generate a random list of devices for the context
-def random_device_list(max_devices: int, avoid_device_names: list[str], language: str = "english"):
-    num_devices = random.randint(2, max_devices)
-    piles = get_dataset_piles(language)
-    local_device_names = { k: v[:] for k, v in piles.stacks_of_device_names.items() }
-    avoid_climate = False
-    for avoid_device in avoid_device_names:
-        avoid_type = avoid_device.split(".")[0]
-        filtered_possible_devices = []
-        for possible_device in local_device_names[avoid_type]:
-            similarity_ratio = SequenceMatcher(None, avoid_device, possible_device["device_name"].split(".")[1]).ratio()
-            if similarity_ratio < 0.4:
-                filtered_possible_devices.append(possible_device)
-        local_device_names[avoid_type] = filtered_possible_devices
-        if avoid_type == "climate":
-            avoid_climate = True
-    possible_choices = []
-    for device_type in local_device_names.keys():
-        possible_choices.extend(local_device_names[device_type])
-    device_types = set()
-    device_list = []
-    device_lines = []
-    # TODO: randomly pick attributes for this list
-    extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "duration", "remaining", "item"]
-    while len(device_list) < num_devices:
-        choice = random.choice(possible_choices)
-        if choice["device_name"] in device_list:
-            continue
-        try:
-            device_name = choice["device_name"]
-            device_type = device_name.split(".")[0]
-            friendly_name = choice["description"]
-            # don't add random thermostats. we need to be careful about how we handle multiple thermostats
-            if avoid_climate and device_type == "climate":
-                continue
-            state = SUPPORTED_DEVICES[device_type].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
-            device_lines.append(format_device_line(
-                device_name=device_name,
-                friendly_name=friendly_name,
-                state=state
-            ))
-            device_list.append(device_name)
-            device_types.add(device_type)
-        except Exception as ex:
-            print(f"bad device name: {choice}")
-            print(repr(ex))
-    return device_lines, list(device_types), list(extra_exposed_attributes)
def generate_static_example(action: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
    question = action["phrase"]
@@ -126,7 +51,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
    # Map service name to tool name
    service_action = service_name.split(".")[1]
-    tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
+    tool_name = SERVICE_TO_TOOL_MAP[service_action]

    response_starting, response_confirmed = get_random_response(
        piles.pile_of_responses,
@@ -225,13 +150,17 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
    except Exception as e:
        print(f"Failed to parse arguments for {action}: {e}")

    final_answer = " ".join(answer_list)
    assistant_turns = [
        create_assistant_turn(response_starting, [tool_call]),
        create_assistant_turn(final_answer, [])
    ]

    return {
        "states": device_list,
        "available_tools": available_tools,
        "question": question.lower(),
-        "answers": answer_list,
-        "answer_starting": response_starting,
-        "tool_calls": [ tool_call ]
+        "assistant_turns": assistant_turns
    }
def replace_answer(list_of_answer, var, value):
@@ -240,6 +169,16 @@ def replace_answer(list_of_answer, var, value):
        new_list.append(answer.replace(var, value))
    return new_list

def create_assistant_turn(answer: str, tool_call_sequence=None, *, tool_results=None, train_on_turn: bool = True):
    """Bundle the assistant utterance with any tool interaction for that turn."""
    return {
        "answer": answer,
        "tool_call_sequence": tool_call_sequence or [],
        "tool_results": tool_results if tool_results is not None else [],
        "train_on_turn": train_on_turn,
    }

def generate_templated_example(template: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
    template_device_types: list[str] = template["device_type"].split("|")
    service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ]
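Every generator now expresses its output as a list of these turn dicts. A hedged sketch of the two-turn shape for a simple action (the strings and arguments are invented; the keys come from `create_assistant_turn` above):

```python
from devices import TOOL_TURN_ON

# Illustrative values only; real examples are filled in from the data piles.
assistant_turns = [
    create_assistant_turn("turning on the kitchen light.", [{
        "tool_name": TOOL_TURN_ON,
        "service_name": "light.turn_on",
        "tool_args": {"name": "kitchen light"}
    }]),
    create_assistant_turn("the kitchen light is now on.", []),
]
```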
@@ -328,7 +267,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
    tool_calls = []
    for device_dict, service in zip(chosen_devices, service_names):
        service_action = service.split(".")[1]
-        tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
+        tool_name = SERVICE_TO_TOOL_MAP[service_action]

        tool_call = {
            "tool_name": tool_name,
            "service_name": service,
@@ -415,13 +354,19 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
if call["tool_name"] == TOOL_LIST_ADD_ITEM: if call["tool_name"] == TOOL_LIST_ADD_ITEM:
call["tool_args"]["item"] = todo call["tool_args"]["item"] = todo
starting_answer = answer_starting.strip().lower()
normalized_answers = [ sentence.lower() for sentence in answer_list ]
final_answer = " ".join(normalized_answers)
assistant_turns = [
create_assistant_turn(starting_answer, tool_calls),
create_assistant_turn(final_answer, [])
]
return { return {
"states": device_list, "states": device_list,
"available_tools": available_tools, "available_tools": available_tools,
"question": question.lower(), "question": question.lower(),
"answer_starting": answer_starting.lower(), "assistant_turns": assistant_turns
"answers": [ sentence.lower() for sentence in answer_list ],
"tool_calls": tool_calls
} }
def generate_status_request(template: dict, persona: str, language: str, max_devices: int = 128, return_target_device: bool = False, use_service_names: bool = False): def generate_status_request(template: dict, persona: str, language: str, max_devices: int = 128, return_target_device: bool = False, use_service_names: bool = False):
@@ -509,12 +454,13 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
    # Remove duplicates while preserving order
    available_tools = list(dict.fromkeys(available_tools))

    assistant_turns = [create_assistant_turn(answer.lower(), [])]
    result = {
        "states": device_list,
        "available_tools": available_tools,
        "question": question.lower(),
-        "answers": [ answer.lower() ],
-        "tool_calls": []
+        "assistant_turns": assistant_turns
    }
    if return_target_device:
        return result, chosen_device
@@ -578,42 +524,42 @@ def generate_tool_failure_example(failure_case: dict, persona: str, language: st
    retry_prompt = failure_case.get("retry_prompt", f"Trying again with {friendly_name}.").replace("<device_name>", friendly_name)
    error_result = failure_case.get("error_result", "Error").replace("<device_name>", friendly_name)

-    tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
+    tool_name = SERVICE_TO_TOOL_MAP[service_action]

    first_args = {"entity_id": bad_device} if use_service_names else {"name": bad_device}
    retry_args = {"entity_id": target_device} if use_service_names else {"name": target_device}
    first_args.update(tool_args_extra)
    retry_args.update(tool_args_extra)

-    tool_call_sequence = [
-        {
-            "answer_starting": response_starting,
-            "tool_calls": [{
-                "tool_name": tool_name,
-                "service_name": service_name,
-                "tool_args": first_args
-            }],
-            "tool_results": [{
-                "tool_name": service_name if use_service_names else tool_name,
-                "tool_result": error_result
-            }]
-        },
-        {
-            "answer_starting": retry_prompt,
-            "tool_calls": [{
-                "tool_name": tool_name,
-                "service_name": service_name,
-                "tool_args": retry_args
-            }]
-        }
-    ]
+    first_turn = create_assistant_turn(
+        response_starting,
+        [{
+            "tool_name": tool_name,
+            "service_name": service_name,
+            "tool_args": first_args
+        }],
+        tool_results=[{
+            "tool_name": service_name if use_service_names else tool_name,
+            "tool_result": error_result
+        }],
+        train_on_turn=False
+    )
+    second_turn = create_assistant_turn(
+        retry_prompt,
+        [{
+            "tool_name": tool_name,
+            "service_name": service_name,
+            "tool_args": retry_args
+        }]
+    )
+    final_turn = create_assistant_turn(response_confirmed, [])

    return {
        "states": device_list,
        "available_tools": available_tools,
        "question": question,
-        "answers": [response_confirmed],
-        "tool_call_sequence": tool_call_sequence,
-        "tool_calls": []
+        "assistant_turns": [first_turn, second_turn, final_turn]
    }

def generate_refusal_example(refusal_case: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
@@ -645,41 +591,20 @@ def generate_refusal_example(refusal_case: dict, persona: str, language: str, ma
response_text = refusal_case["response"].replace("<device_name>", friendly_name).lower() response_text = refusal_case["response"].replace("<device_name>", friendly_name).lower()
question = refusal_case["phrase"].replace("<device_name>", friendly_name).lower() question = refusal_case["phrase"].replace("<device_name>", friendly_name).lower()
assistant_turns = [create_assistant_turn(response_text, [])]
return { return {
"states": device_list, "states": device_list,
"available_tools": available_tools, "available_tools": available_tools,
"question": question, "question": question,
"answers": [response_text], "assistant_turns": assistant_turns
"tool_calls": []
} }
def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format):
    piles = get_dataset_piles(language)
    sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
    question = example["question"]
-    answers = " ".join(example["answers"])
-    answer_starting = example.get("answer_starting", "")
-    tool_call_sequence = example.get("tool_call_sequence")
-    tool_calls = []
-    tool_results = []
-    if not tool_call_sequence:
-        # Add tool use blocks if there are tool calls
-        if len(example["tool_calls"]) > 0:
-            for tool_call in example["tool_calls"]:
-                # Use service_name if in service mode, otherwise use tool_name
-                call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
-                tool_calls.append({
-                    "name": call_name,
-                    "arguments": json.dumps(tool_call["tool_args"])
-                })
-                tool_results.append({
-                    "tool_name": call_name,
-                    "tool_call_id": f"call_{len(tool_results) + 1}",
-                    "tool_result": "Success"
-                })
+    assistant_turns = example["assistant_turns"]

    if append_user_instruction_prompt:
        user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
@@ -703,95 +628,64 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }] "content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
} }
] ]
-    if tool_call_sequence:
-        call_id_counter = 1
-        for step in tool_call_sequence:
-            step_tool_calls = []
-            for tool_call in step.get("tool_calls", []):
-                call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
-                step_tool_calls.append({
-                    "name": call_name,
-                    "arguments": json.dumps(tool_call["tool_args"])
-                })
-            assistant_block = {
-                "role": "assistant",
-                "content": [{ "type": "text", "text": step.get("answer_starting", "") }]
-            }
-            if step_tool_calls:
-                assistant_block["tool_calls"] = [ { "function": tc } for tc in step_tool_calls ]
-            conversation.append(assistant_block)
-            if step_tool_calls:
-                provided_results = step.get("tool_results")
-                step_tool_results = []
-                if provided_results:
-                    for provided in provided_results:
-                        step_tool_results.append({
-                            "tool_name": provided.get("tool_name", step_tool_calls[0]["name"]),
-                            "tool_call_id": provided.get("tool_call_id", f"call_{call_id_counter}"),
-                            "tool_result": provided.get("tool_result", "Success")
-                        })
-                        call_id_counter += 1
-                else:
-                    for call in step_tool_calls:
-                        step_tool_results.append({
-                            "tool_name": call["name"],
-                            "tool_call_id": f"call_{call_id_counter}",
-                            "tool_result": "Success"
-                        })
-                        call_id_counter += 1
-                if tool_response_format == "text":
-                    conversation.append({
-                        "role": "tool",
-                        "content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results]
-                    })
-                elif tool_response_format == "functiongemma":
-                    conversation.append({
-                        "role": "tool",
-                        "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results]
-                    })
-            if step.get("post_tool_response"):
-                conversation.append({
-                    "role": "assistant",
-                    "content": [{ "type": "text", "text": step["post_tool_response"] }]
-                })
-        conversation.append({
-            "role": "assistant",
-            "content": [{ "type": "text", "text": answers }],
-        })
-    elif len(tool_calls) > 0:
-        assistant_starting_block = {
-            "role": "assistant",
-            "content": [{ "type": "text", "text": answer_starting }],
-            "tool_calls": [ { "function": tc } for tc in tool_calls ]
-        }
-        if tool_response_format == "text":
-            tool_response_block = {
-                "role": "tool",
-                "content": [{ "type": "text", "text": json.dumps(result) } for result in tool_results]
-            }
-        elif tool_response_format == "functiongemma":
-            tool_response_block = {
-                "role": "tool",
-                "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in tool_results]
-            }
-        assistant_confirmation_block = {
-            "role": "assistant",
-            "content": [{ "type": "text", "text": answers }],
-        }
-        conversation.extend([assistant_starting_block, tool_response_block, assistant_confirmation_block])
-    else:
-        conversation.extend([
-            {
-                "role": "assistant",
-                "content": [{ "type": "text", "text": answer_starting + answers }],
-            }
-        ])
+    call_id_counter = 1
+    for turn in assistant_turns:
+        answer_text = turn.get("answer", "")
+        assistant_block = {
+            "role": "assistant",
+            "content": [{ "type": "text", "text": answer_text }],
+            "train_on_turn": turn.get("train_on_turn", True),
+        }
+
+        tool_call_sequence = turn.get("tool_call_sequence", [])
+        formatted_calls = []
+        call_names = []
+        for tool_call in tool_call_sequence:
+            call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
+            call_names.append(call_name)
+            formatted_calls.append({
+                "name": call_name,
+                "arguments": json.dumps(tool_call["tool_args"])
+            })
+
+        if formatted_calls:
+            assistant_block["tool_calls"] = [{ "function": call } for call in formatted_calls]
+
+        conversation.append(assistant_block)
+
+        if formatted_calls:
+            provided_results = turn.get("tool_results") or []
+            step_tool_results = []
+
+            if provided_results:
+                for idx, provided in enumerate(provided_results):
+                    result = dict(provided)
+                    if "tool_name" not in result and call_names:
+                        result["tool_name"] = call_names[min(idx, len(call_names) - 1)]
+                    if "tool_call_id" not in result:
+                        result["tool_call_id"] = f"call_{call_id_counter}"
+                    call_id_counter += 1
+                    step_tool_results.append(result)
+            else:
+                for call_name in call_names:
+                    step_tool_results.append({
+                        "tool_name": call_name,
+                        "tool_call_id": f"call_{call_id_counter}",
+                        "tool_result": "Success"
+                    })
+                    call_id_counter += 1
+
+            if tool_response_format == "text":
+                conversation.append({
+                    "role": "tool",
+                    "content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results]
+                })
+            elif tool_response_format == "functiongemma":
+                conversation.append({
+                    "role": "tool",
+                    "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results]
+                })
    return {
        "messages": conversation,
@@ -812,8 +706,8 @@ def generate_sft_file(
        static_factor: float,
        template_factor: int,
        status_request_factor: int,
-        failure_factor: float = 1,
-        refusal_factor: float = 1):
+        failure_factor: int,
+        refusal_factor: int):
    random.seed(seed)
    np.random.seed(seed)
    piles = get_dataset_piles(language)
@@ -929,7 +823,7 @@ def main(args=None):
    args = parser.parse_args(args=args)

-    if not args.sample and not args.train and not args.test and not args.merge:
+    if not args.sample and not args.train and not args.test:
        parser.print_usage()
        exit(-1)
@@ -950,20 +844,20 @@ def main(args=None):
suffix = f"_{language}" if len(args.language) > 1 else "" suffix = f"_{language}" if len(args.language) > 1 else ""
if args.sample: if args.sample:
generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1) generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1, refusal_factor=1, failure_factor=1)
if args.train: if args.train:
if args.size == "small": if args.size == "small":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8) generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8, refusal_factor=3, failure_factor=1)
elif args.size == "medium": elif args.size == "medium":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12) generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12, refusal_factor=5, failure_factor=1)
elif args.size == "large": elif args.size == "large":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15) generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15, refusal_factor=6, failure_factor=1)
elif args.size == "xl": elif args.size == "xl":
generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18) generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18, refusal_factor=8, failure_factor=2)
else: else:
raise Exception(f"Unrecognized dataset size: {args.size}") raise Exception(f"Unrecognized dataset size: {args.size}")
if args.test: if args.test:
generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2) generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2, refusal_factor=1, failure_factor=1)
if len(args.language) > 1: if len(args.language) > 1:
if args.sample: if args.sample:
merge_languages("sample", args.language) merge_languages("sample", args.language)

View File

@@ -3,13 +3,19 @@ import asyncio
import csv
import json
import random
import string
import aiohttp
from tqdm import tqdm
import os
from utils import get_dataset_piles
from devices import random_device_list

LLM_ENDPOINT = "https://ai.cloud.alexoconnell.net/v1/chat/completions"
cwd = os.path.dirname(os.path.abspath(__file__))

def get_hass_match_error_message(bad_device_name: str) -> str:
    return f"<MatchFailedError result=MatchTargetsResult(is_match=False, no_match_reason=<MatchFailedReason.NAME: 1>, states=[], no_match_name=None, areas=[], floors=[]), constraints=MatchTargetsConstraints(name='{bad_device_name}', area_name=None, floor_name=None, domains=None, device_classes=None, features=None, states=None, assistant='conversation', allow_duplicate_names=False, single_target=False), preferences=MatchTargetsPreferences(area_id=None, floor_id=None)>"
class SyntheticDataGenerator:
    def __init__(self, model_name: str, language: str, concurrency: int):
@@ -18,6 +24,130 @@ class SyntheticDataGenerator:
        self.model_name = model_name
        self.piles = get_dataset_piles(language)
        self.synthetic_devices = {}  # device_type -> list of {device_name, description}
        self.failed_tool_fieldnames = [
            "service_name",
            "correct_device_name",
            "correct_friendly_name",
            "bad_device_name",
            "phrase",
            "error_result",
            "retry_prompt"
        ]
        self.refusal_fieldnames = [
            "reason_type",
            "service_name",
            "device_name",
            "friendly_name",
            "desired_state",
            "phrase",
            "response"
        ]
    async def _chat_completion(self, session, system_prompt: str, user_prompt: str, *, temperature: float = 0.8, max_tokens: int = 300, structured_response: dict = {}):
        payload = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            "temperature": temperature,
            "max_tokens": max_tokens if not structured_response else None,  # don't limit if structured (causes failed generations)
            "response_format": structured_response if structured_response else None,
        }
        try:
            async with session.post(LLM_ENDPOINT, json=payload) as response:
                if response.status < 400:
                    data = await response.json()
                    return data['choices'][0]['message']['content'].strip()
                print(f"LLM request failed with status {response.status}")
        except Exception as exc:
            print(f"Completion request failed: {exc}")
        return None

    @staticmethod
    def _strip_code_fence(content: str) -> str:
        if not content:
            return ""
        text = content.strip()
        if text.startswith("```") and text.endswith("```"):
            lines = [line for line in text.splitlines() if not line.strip().startswith("```")]
            return "\n".join(lines).strip()
        return text

    def _parse_json_object(self, content: str):
        if not content:
            return None
        cleaned = self._strip_code_fence(content)
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start == -1 or end == -1:
            return None
        snippet = cleaned[start:end+1]
        try:
            return json.loads(snippet)
        except json.JSONDecodeError:
            return None

    @staticmethod
    def _ensure_placeholder(text: str, friendly_name: str, placeholder: str) -> str:
        if placeholder in text:
            return text
        lower_text = text.lower()
        lower_name = friendly_name.lower()
        idx = lower_text.find(lower_name)
        if idx == -1:
            return text
        return text[:idx] + placeholder + text[idx + len(friendly_name):]

    @staticmethod
    def _describe_service(service_name: str, service_data: dict) -> str:
        domain, action = service_name.split(".", 1)
        description = f"{action.replace('_', ' ')} the {domain}"
        if service_data:
            description += f" with arguments {json.dumps(service_data, ensure_ascii=False)}"
        return description

    @staticmethod
    def _desired_state_for_service(service_name: str) -> str:
        action = service_name.split(".", 1)[1]
        action_mapping = {
            "turn_on": "on",
            "turn_off": "off",
            "open_cover": "open",
            "close_cover": "closed",
            "lock": "locked",
            "unlock": "unlocked",
            "media_play_pause": "playing",
        }
        if action in action_mapping:
            return action_mapping[action]
        service_mapping = {
            "media_player.turn_on": "on",
            "media_player.turn_off": "off",
            "vacuum.start": "cleaning",
            "vacuum.return_to_base": "docked",
            "garage_door.open_cover": "open",
            "garage_door.close_cover": "closed",
            "blinds.open_cover": "open",
            "blinds.close_cover": "closed",
        }
        return service_mapping.get(service_name, "")

    @staticmethod
    def _append_rows_to_csv(path: str, fieldnames, rows):
        if not rows:
            return
        os.makedirs(os.path.dirname(path), exist_ok=True)
        file_exists = os.path.exists(path)
        if not file_exists:
            with open(path, "w", newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
        with open(path, "a", newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            for row in rows:
                writer.writerow(row)
    async def generate_device_names(self, session, device_type, count=10):
        """
@@ -144,6 +274,117 @@ class SyntheticDataGenerator:
print(f"Request failed: {e}") print(f"Request failed: {e}")
return None return None
    async def generate_failed_tool_call_entry(self, session, context):
        service_name = context["service_name"]
        friendly_name = context["friendly_name"]
        correct_device = context["device_name"]
        action_description = self._describe_service(service_name, context.get("service_data", {}))
        device_list, device_types, extra_exposed_attributes = random_device_list(max_devices=10, avoid_device_names=[], language=self.language)

        extra_phrase_instructions = ""
        if service_name == "climate.set_temperature":
            extra_phrase_instructions = "Always reference the target temperature using the literal placeholder <temp_f>."

        system_prompt = "You create high-quality multilingual training data for a smart home assistant."
        user_prompt = f"""
Generate content for a failed tool call in {self.language}.
Service: {service_name}
Action description: {action_description}
Device name: {correct_device} (friendly name: {friendly_name})
The assistant first misidentifies the entity using the "bad_device" name, receives an error, then retries with {correct_device}.
Output a single JSON object with the keys "phrase", "bad_device", "error_result", and "retry_prompt".
- phrase: A natural voice command template the user would say. {extra_phrase_instructions}
- bad_device: The incorrect version of '{correct_device}'. Ensure it is a truncation, mutation, transposition, or inclusion of an extra suffix/prefix. Avoid simple typos in words.
- retry_prompt: A short acknowledgement from the assistant that it will try the correct device ({correct_device}).
Here are potential device names to help generate a bad device name: {', '.join(device_list)}
Keep the tone conversational and stay entirely in {self.language}. Do not add explanations or code fences.
"""
        raw = await self._chat_completion(
            session,
            system_prompt,
            user_prompt,
            temperature=0.7,
            max_tokens=512,
            structured_response={ "type": "json_object" }
        )
        parsed = self._parse_json_object(raw)
        if not parsed:
            print(f"Failed to parse JSON for failed tool call generation: {raw}")
            return None

        phrase = parsed.get("phrase", "").strip()
        bad_device = parsed.get("bad_device", "").strip()
        retry_prompt = parsed.get("retry_prompt", "").strip()
        if not phrase or not bad_device or not retry_prompt:
            return None

        return {
            "service_name": service_name,
            "correct_device_name": correct_device,
            "correct_friendly_name": friendly_name,
            "bad_device_name": bad_device,
            "phrase": phrase,
            "error_result": get_hass_match_error_message(bad_device),
            "retry_prompt": retry_prompt
        }
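A successful generation appends one row like the following to `pile_of_failed_tool_calls.csv`; every value below is invented for illustration, but the column names match `failed_tool_fieldnames`:

```python
row = {
    "service_name": "light.turn_on",
    "correct_device_name": "light.kitchen",
    "correct_friendly_name": "kitchen light",
    "bad_device_name": "light.kitch",  # truncated variant, invented
    "phrase": "turn on the kitchen light",
    "error_result": get_hass_match_error_message("light.kitch"),
    "retry_prompt": "trying again with light.kitchen.",
}
```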
    async def generate_refusal_entry(self, session, context, reason_type: str, desired_state: str):
        service_name = context["service_name"]
        friendly_name = context["friendly_name"]
        action_description = self._describe_service(service_name, context.get("service_data", {}))
        device_suffix = context["device_name"].split(".", 1)[1]

        if reason_type == "not_available":
            reason_detail = "Explain that the assistant cannot locate that device in the home."
        else:
            reason_detail = f"Explain that the device is already {desired_state} and no change is needed."

        system_prompt = "You write concise refusal-style responses for a multilingual smart home assistant."
        user_prompt = f"""
Create a refusal example in {self.language} for the following request.
Service: {service_name}
Action description: {action_description}
Reason: {reason_detail}
Output a JSON object with:
- "phrase": the user's natural command template. Use the literal placeholder <device_name> anywhere the user mentions the device.
- "response": the assistant's refusal message in {self.language} describing the reason.
Keep both fields brief, conversational, and free of extra narration or code fences.
"""
        raw = await self._chat_completion(
            session,
            system_prompt,
            user_prompt,
            temperature=0.7,
            max_tokens=512,
            structured_response={ "type": "json_object" }
        )
        parsed = self._parse_json_object(raw)
        if not parsed:
            print("Failed to parse JSON for refusal generation")
            return None

        phrase = parsed.get("phrase", "").strip()
        response = parsed.get("response", "").strip()
        if not phrase or not response:
            return None

        phrase = self._ensure_placeholder(phrase, friendly_name, "<device_name>")
        if "<device_name>" not in phrase:
            return None

        return {
            "reason_type": reason_type,
            "service_name": service_name,
            "device_name": device_suffix,
            "friendly_name": friendly_name,
            "desired_state": desired_state if reason_type == "already_state" else "",
            "phrase": phrase,
            "response": response
        }
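Refusal rows land in `pile_of_refusals.csv` with the columns from `refusal_fieldnames`; an invented `already_state` example:

```python
row = {
    "reason_type": "already_state",
    "service_name": "lock.lock",
    "device_name": "front_door",        # suffix of the entity id
    "friendly_name": "front door lock",
    "desired_state": "locked",
    "phrase": "lock the <device_name>",
    "response": "the front door lock is already locked, so there's nothing to do.",
}
```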
    def sample_context(self, request_type: str):
        """
        Creates a random scenario: device, service, and arguments.
@@ -245,14 +486,12 @@ class SyntheticDataGenerator:
        }
        raise ValueError(f"Unknown request type {request_type}")

-    async def run(self, num_actions: int, num_status_requests: int, num_devices: int, output_file, persona_name=None, persona_description=None):
+    async def run(self, num_actions: int, num_status_requests: int, num_devices: int,
+                  persona_name=None, persona_description=None, num_failed_tool_calls: int = 0,
+                  num_refusals: int = 0):
        print(f"Starting generation...")
        print(f"Language: {self.language}")

-        # Ensure output directory exists
-        if output_file:
-            os.makedirs(os.path.dirname(output_file), exist_ok=True)
        async with aiohttp.ClientSession() as session:
            if num_devices > 0:
@@ -268,7 +507,7 @@ class SyntheticDataGenerator:
                all_new_devices = [item for sublist in generated_lists if sublist for item in sublist]
                if all_new_devices:
-                    csv_path = f"data/piles/{self.language}/pile_of_device_names.csv"
+                    csv_path = f"{cwd}/piles/{self.language}/pile_of_device_names.csv"
                    try:
                        with open(csv_path, "a", newline='', encoding='utf-8') as f:
                            writer = csv.DictWriter(f, fieldnames=["device_name", "description"])
@@ -280,7 +519,6 @@ class SyntheticDataGenerator:
            if num_actions > 0 or num_status_requests > 0:
                print(f"Generating {num_actions} actions and {num_status_requests} status requests...")
-                print(f"Output file: {output_file}")

                tasks = {}
                results = []
@@ -312,7 +550,7 @@ class SyntheticDataGenerator:
if entry["type"] == "action": if entry["type"] == "action":
# Write to pile_of_specific_actions.csv # Write to pile_of_specific_actions.csv
csv_path = f"data/piles/{self.language}/pile_of_specific_actions.csv" csv_path = f"{cwd}/piles/{self.language}/pile_of_specific_actions.csv"
# Prepare row # Prepare row
# device_name in CSV is the suffix (e.g. 'kitchen' from 'light.kitchen') # device_name in CSV is the suffix (e.g. 'kitchen' from 'light.kitchen')
@@ -383,7 +621,7 @@ class SyntheticDataGenerator:
                            # For now, just skip if we can't templatize.
                            pass
                    else:
-                        csv_path = f"data/piles/{self.language}/pile_of_status_requests.csv"
+                        csv_path = f"{cwd}/piles/{self.language}/pile_of_status_requests.csv"
                        # Columns: device_type,state,phrase,assistant_response
                        # We don't have assistant_response.
                        # We can generate a generic one?
@@ -398,6 +636,85 @@ class SyntheticDataGenerator:
                pbar.close()

            if num_failed_tool_calls > 0:
                print(f"Generating {num_failed_tool_calls} failed tool call scenarios...")
                failed_entries = []
                tasks = {}
                pbar_failed = tqdm(total=num_failed_tool_calls, desc="Failed tool calls")
                while len(failed_entries) < num_failed_tool_calls:
                    while len(tasks) < self.concurrency and (len(failed_entries) + len(tasks)) < num_failed_tool_calls:
                        context = self.sample_context("action")
                        if not context:
                            continue
                        task = asyncio.create_task(self.generate_failed_tool_call_entry(session, context))
                        tasks[task] = None
                    if not tasks:
                        break
                    done, _ = await asyncio.wait(list(tasks.keys()), return_when=asyncio.FIRST_COMPLETED)
                    for task in done:
                        tasks.pop(task, None)
                        try:
                            entry = await task
                        except Exception as exc:
                            print(f"Failed tool call task error: {exc}")
                            entry = None
                        if entry:
                            failed_entries.append(entry)
                            pbar_failed.update(1)
                pbar_failed.close()
                failed_path = f"{cwd}/piles/{self.language}/pile_of_failed_tool_calls.csv"
                self._append_rows_to_csv(failed_path, self.failed_tool_fieldnames, failed_entries)

            if num_refusals > 0:
                print(f"Generating {num_refusals} refusal scenarios...")
                refusal_entries = []
                tasks = {}
                pbar_refusals = tqdm(total=num_refusals, desc="Refusals")
                while len(refusal_entries) < num_refusals:
                    while len(tasks) < self.concurrency and (len(refusal_entries) + len(tasks)) < num_refusals:
                        context = self.sample_context("action")
                        if not context:
                            continue
                        reason_type = random.choice(["not_available", "already_state"])
                        desired_state = ""
                        if reason_type == "already_state":
                            desired_state = self._desired_state_for_service(context["service_name"])
                            if not desired_state:
                                reason_type = "not_available"
                        task = asyncio.create_task(self.generate_refusal_entry(
                            session,
                            context,
                            reason_type,
                            desired_state
                        ))
                        tasks[task] = None
                    if not tasks:
                        break
                    done, _ = await asyncio.wait(list(tasks.keys()), return_when=asyncio.FIRST_COMPLETED)
                    for task in done:
                        tasks.pop(task, None)
                        try:
                            entry = await task
                        except Exception as exc:
                            print(f"Refusal generation error: {exc}")
                            entry = None
                        if entry:
                            refusal_entries.append(entry)
                            pbar_refusals.update(1)
                pbar_refusals.close()
                refusal_path = f"{cwd}/piles/{self.language}/pile_of_refusals.csv"
                self._append_rows_to_csv(refusal_path, self.refusal_fieldnames, refusal_entries)
            if persona_name and persona_description:
                await self.generate_persona(session, persona_name, persona_description)
@@ -442,7 +759,7 @@ class SyntheticDataGenerator:
            return

        # 2. Get list of services to generate responses for
-        responses_csv_path = f"data/piles/{self.language}/pile_of_responses.csv"
+        responses_csv_path = f"{cwd}/piles/{self.language}/pile_of_responses.csv"
        services = set()
        try:
            with open(responses_csv_path, "r", encoding='utf-8') as f:
@@ -527,7 +844,7 @@ class SyntheticDataGenerator:
# 4. Write to files
        # Append system prompt
-        sys_prompts_path = f"data/piles/{self.language}/pile_of_system_prompts.csv"
+        sys_prompts_path = f"{cwd}/piles/{self.language}/pile_of_system_prompts.csv"
        try:
            with open(sys_prompts_path, "a", newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
@@ -555,13 +872,25 @@ if __name__ == "__main__":
parser.add_argument("--actions", type=int, default=0, help="Number of actions to generate") parser.add_argument("--actions", type=int, default=0, help="Number of actions to generate")
parser.add_argument("--status", type=int, default=0, help="Number of status requests to generate") parser.add_argument("--status", type=int, default=0, help="Number of status requests to generate")
parser.add_argument("--devices", type=int, default=0, help="Number of new devices to generate") parser.add_argument("--devices", type=int, default=0, help="Number of new devices to generate")
parser.add_argument("--concurrency", type=int, default=8, help="Number of concurrent requests") parser.add_argument("--failed-tool-calls", type=int, default=0, help="Number of failed tool call scenarios to generate")
parser.add_argument("--language", type=str, default="english", help="Language") parser.add_argument("--refusals", type=int, default=0, help="Number of refusal scenarios to generate")
parser.add_argument("--concurrency", type=int, default=4, help="Number of concurrent requests")
parser.add_argument("--languages", type=str, nargs='+', default=["english"], help="Languages to generate data for")
parser.add_argument("--model", type=str, default="gpt-oss-120b", help="LLM model to use") parser.add_argument("--model", type=str, default="gpt-oss-120b", help="LLM model to use")
parser.add_argument("--persona-name", type=str, help="Name of the new persona to generate") parser.add_argument("--persona-name", type=str, help="Name of the new persona to generate")
parser.add_argument("--persona-description", type=str, help="Description of the new persona") parser.add_argument("--persona-description", type=str, help="Description of the new persona")
args = parser.parse_args() args = parser.parse_args()
generator = SyntheticDataGenerator(model_name=args.model, language=args.language, concurrency=args.concurrency) for language in args.languages:
asyncio.run(generator.run(num_actions=args.actions, num_status_requests=args.status, num_devices=args.devices, output_file="", persona_name=args.persona_name, persona_description=args.persona_description)) print(f"=== Generating data for language: {language} ===")
generator = SyntheticDataGenerator(model_name=args.model, language=language, concurrency=args.concurrency)
asyncio.run(generator.run(
num_actions=args.actions,
num_status_requests=args.status,
num_devices=args.devices,
persona_name=args.persona_name,
persona_description=args.persona_description,
num_failed_tool_calls=args.failed_tool_calls,
num_refusals=args.refusals
))
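The generator can also be driven from Python rather than the CLI; this sketch mirrors one iteration of the loop above (argument values are illustrative):

```python
import asyncio
from synthesize import SyntheticDataGenerator

# Equivalent to: python3 synthesize.py --languages german --failed-tool-calls 25 --refusals 25
generator = SyntheticDataGenerator(model_name="gpt-oss-120b", language="german", concurrency=4)
asyncio.run(generator.run(
    num_actions=0,
    num_status_requests=0,
    num_devices=0,
    num_failed_tool_calls=25,
    num_refusals=25,
))
```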

View File

@@ -37,6 +37,7 @@ SERVICE_TO_TOOL_MAP = {
"unlock": TOOL_TURN_OFF, "unlock": TOOL_TURN_OFF,
"increase_speed": TOOL_TURN_ON, "increase_speed": TOOL_TURN_ON,
"decrease_speed": TOOL_TURN_OFF, "decrease_speed": TOOL_TURN_OFF,
"media_stop": TOOL_TURN_OFF,
"media_play_pause": TOOL_TOGGLE, "media_play_pause": TOOL_TOGGLE,
"media_pause": TOOL_MEDIA_PAUSE, "media_pause": TOOL_MEDIA_PAUSE,
"media_play": TOOL_MEDIA_UNPAUSE, "media_play": TOOL_MEDIA_UNPAUSE,
@@ -49,6 +50,13 @@ SERVICE_TO_TOOL_MAP = {
"set_hvac_mode": TOOL_CLIMATE_SET_TEMPERATURE, "set_hvac_mode": TOOL_CLIMATE_SET_TEMPERATURE,
"set_fan_mode": TOOL_CLIMATE_SET_TEMPERATURE, "set_fan_mode": TOOL_CLIMATE_SET_TEMPERATURE,
"set_preset_mode": TOOL_CLIMATE_SET_TEMPERATURE, "set_preset_mode": TOOL_CLIMATE_SET_TEMPERATURE,
"cancel": TOOL_CANCEL_TIMER,
"volume_down": TOOL_SET_VOLUME,
"volume_up": TOOL_SET_VOLUME,
"volume_mute": TOOL_SET_VOLUME,
"stop": TOOL_TURN_OFF,
"pause": TOOL_TURN_OFF,
"add_item": TOOL_LIST_ADD_ITEM
} }
# Home Assistant Intent Tools Definition # Home Assistant Intent Tools Definition
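Because the generators now index `SERVICE_TO_TOOL_MAP` directly instead of calling `.get(..., TOOL_TURN_ON)`, any service action sampled from the piles must have an entry in this map. A quick sketch of the new failure mode (the `brew_coffee` action is hypothetical):

```python
from tools import SERVICE_TO_TOOL_MAP

SERVICE_TO_TOOL_MAP["volume_up"]    # -> TOOL_SET_VOLUME
SERVICE_TO_TOOL_MAP["brew_coffee"]  # unmapped action now raises KeyError
```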