Mirror of https://github.com/acon96/home-llm.git, synced 2026-01-08 05:14:02 -05:00

more synthesizing scenarios + clean up example formatting
@@ -67,6 +67,26 @@ The response pile is a CSV with the following headers: `service,response,language

Generating the full dataset using the Python script will print a warning for any responses that are missing for a persona.

## Synthesizing new pile data

You can quickly append fresh examples to the CSV piles without editing them manually by running `synthesize.py`. The script talks to the configured LLM and writes the generated rows directly into the per-language pile files.

Examples:

```bash
# Append 25 failed tool-call recoveries and 25 refusals in Spanish
python3 synthesize.py --languages spanish --model gpt-oss-120b --failed-tool-calls 25 --refusals 25 --concurrency 6

# Generate new actions plus matching refusal samples in German
python3 synthesize.py --languages german --actions 100 --refusals 40 --model gpt-oss-120b
```

Useful flags:

- `--failed-tool-calls`: number of `pile_of_failed_tool_calls.csv` rows to synthesize.
- `--refusals`: number of `pile_of_refusals.csv` rows to synthesize.
- `--actions`, `--status`, `--devices`: existing knobs for the other piles.

The script automatically routes generations to the correct language-specific pile under `data/piles/<language>/`.
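Because `--languages` accepts multiple values (the script loops over them, see the argparse change below), one run can top up several per-language piles at once, e.g.:

```bash
# Synthesize refusals for the English and German piles in one run
python3 synthesize.py --languages english german --refusals 20 --model gpt-oss-120b
```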
## Adding new Home Assistant functionality

TODO

<!-- In order to add new home assistant device types, you will need to add data to a handful of piles, as well as make small modifications to the `generate_data.py` script.
@@ -1,6 +1,7 @@
import random
from dataclasses import dataclass
from typing import Final, Callable, List
from difflib import SequenceMatcher

from tools import *
from utils import closest_color, generate_random_parameter, get_dataset_piles
@@ -31,6 +32,9 @@ STATE_CLEANING: Final = "cleaning"
STATE_DOCKED: Final = "docked"
STATE_RETURNING: Final = "returning"

def format_device_line(*, device_name: str, friendly_name: str, state: str):
    return (f"{device_name} '{friendly_name}' = {state}")

@dataclass
class DeviceType:
    name: str
@@ -222,3 +226,79 @@ class MediaPlayerDeviceType(DeviceType):
        if "volume_level" in extra_exposed_attributes:
            tools.append(TOOL_SET_VOLUME)
        return tools


SUPPORTED_DEVICES = {
    "light": LightDeviceType(),
    "switch": SwitchDeviceType(),
    "fan": FanDeviceType(),
    "garage_door": GarageDoorDeviceType(),
    "blinds": BlindsDeviceType(),
    "lock": LockDeviceType(),
    "media_player": MediaPlayerDeviceType(),
    "climate": ClimateDeviceType(),
    "vacuum": VacuumDeviceType(),
    "timer": TimerDeviceType(),
    "todo": TodoDeviceType(),
}

# generate a random list of devices for the context
def random_device_list(max_devices: int, avoid_device_names: list[str], language: str = "english"):
    num_devices = random.randint(2, max_devices)
    piles = get_dataset_piles(language)

    local_device_names = { k: v[:] for k,v in piles.stacks_of_device_names.items() }

    avoid_climate = False
    for avoid_device in avoid_device_names:
        avoid_type = avoid_device.split(".")[0]

        filtered_possible_devices = []
        for possible_device in local_device_names[avoid_type]:
            similarity_ratio = SequenceMatcher(None, avoid_device, possible_device["device_name"].split(".")[1]).ratio()

            if similarity_ratio < 0.4:
                filtered_possible_devices.append(possible_device)
        local_device_names[avoid_type] = filtered_possible_devices

        if avoid_type == "climate":
            avoid_climate = True

    possible_choices = []
    for device_type in local_device_names.keys():
        possible_choices.extend(local_device_names[device_type])


    device_types = set()
    device_list = []
    device_lines = []
    # TODO: randomly pick attributes for this list
    extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "duration", "remaining", "item"]

    while len(device_list) < num_devices:
        choice = random.choice(possible_choices)
        if choice["device_name"] in device_list:
            continue

        try:
            device_name = choice["device_name"]
            device_type = device_name.split(".")[0]
            friendly_name = choice["description"]

            # don't add random thermostats. we need to be careful about how we handle multiple thermostats
            if avoid_climate and device_type == "climate":
                continue

            state = SUPPORTED_DEVICES[device_type].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
            device_lines.append(format_device_line(
                device_name=device_name,
                friendly_name=friendly_name,
                state=state
            ))
            device_list.append(device_name)
            device_types.add(device_type)
        except Exception as ex:
            print(f"bad device name: {choice}")
            print(repr(ex))

    return device_lines, list(device_types), list(extra_exposed_attributes)
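As an editorial aside, a minimal sketch of how this helper is called (argument values are illustrative; the example output line mirrors `format_device_line`'s shape):

```python
# returns formatted context lines, the device types used, and the exposed attribute names
device_lines, device_types, exposed = random_device_list(
    max_devices=10,
    avoid_device_names=["light.kitchen"],  # names too similar to these are filtered out
    language="english",
)
# device_lines[0] might look like: "fan.bedroom 'Bedroom Fan' = off"  (illustrative)
```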
@@ -3,7 +3,6 @@ import json
import numpy as np
import random
from datasets import load_dataset, concatenate_datasets
from difflib import SequenceMatcher
from typing import Callable
from tqdm import tqdm
import webcolors
@@ -13,87 +12,13 @@ import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from device_types import *
from devices import SUPPORTED_DEVICES, format_device_line, random_device_list, \
    TOOL_TURN_ON, TOOL_CLIMATE_SET_TEMPERATURE, TOOL_SET_HUMIDITY, \
    TOOL_LIGHT_SET, TOOL_START_TIMER, TOOL_LIST_ADD_ITEM, SERVICE_TO_TOOL_MAP, \
    HASS_TOOLS, SERVICE_TOOLS
from prompting import generate_system_prompt, USER_INSTRUCTION_PROMPT
from utils import get_random_response, generate_random_parameter, closest_color, get_dataset_piles, NoResponseAvailableException

SUPPORTED_DEVICES = {
    "light": LightDeviceType(),
    "switch": SwitchDeviceType(),
    "fan": FanDeviceType(),
    "garage_door": GarageDoorDeviceType(),
    "blinds": BlindsDeviceType(),
    "lock": LockDeviceType(),
    "media_player": MediaPlayerDeviceType(),
    "climate": ClimateDeviceType(),
    "vacuum": VacuumDeviceType(),
    "timer": TimerDeviceType(),
    "todo": TodoDeviceType(),
}

def format_device_line(*, device_name: str, friendly_name: str, state: str):
    return (f"{device_name} '{friendly_name}' = {state}")

# generate a random list of devices for the context
def random_device_list(max_devices: int, avoid_device_names: list[str], language: str = "english"):
    num_devices = random.randint(2, max_devices)
    piles = get_dataset_piles(language)

    local_device_names = { k: v[:] for k,v in piles.stacks_of_device_names.items() }

    avoid_climate = False
    for avoid_device in avoid_device_names:
        avoid_type = avoid_device.split(".")[0]

        filtered_possible_devices = []
        for possible_device in local_device_names[avoid_type]:
            similarity_ratio = SequenceMatcher(None, avoid_device, possible_device["device_name"].split(".")[1]).ratio()

            if similarity_ratio < 0.4:
                filtered_possible_devices.append(possible_device)
        local_device_names[avoid_type] = filtered_possible_devices

        if avoid_type == "climate":
            avoid_climate = True

    possible_choices = []
    for device_type in local_device_names.keys():
        possible_choices.extend(local_device_names[device_type])


    device_types = set()
    device_list = []
    device_lines = []
    # TODO: randomly pick attributes for this list
    extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "duration", "remaining", "item"]

    while len(device_list) < num_devices:
        choice = random.choice(possible_choices)
        if choice["device_name"] in device_list:
            continue

        try:
            device_name = choice["device_name"]
            device_type = device_name.split(".")[0]
            friendly_name = choice["description"]

            # don't add random thermostats. we need to be careful about how we handle multiple thermostats
            if avoid_climate and device_type == "climate":
                continue

            state = SUPPORTED_DEVICES[device_type].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
            device_lines.append(format_device_line(
                device_name=device_name,
                friendly_name=friendly_name,
                state=state
            ))
            device_list.append(device_name)
            device_types.add(device_type)
        except Exception as ex:
            print(f"bad device name: {choice}")
            print(repr(ex))

    return device_lines, list(device_types), list(extra_exposed_attributes)
from utils import get_random_response, generate_random_parameter, closest_color, \
    get_dataset_piles, NoResponseAvailableException

def generate_static_example(action: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
    question = action["phrase"]
@@ -126,7 +51,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic

    # Map service name to tool name
    service_action = service_name.split(".")[1]
    tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
    tool_name = SERVICE_TO_TOOL_MAP[service_action]

    response_starting, response_confirmed = get_random_response(
        piles.pile_of_responses,
@@ -225,13 +150,17 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
        except Exception as e:
            print(f"Failed to parse arguments for {action}: {e}")

    final_answer = " ".join(answer_list)
    assistant_turns = [
        create_assistant_turn(response_starting, [tool_call]),
        create_assistant_turn(final_answer, [])
    ]

    return {
        "states": device_list,
        "available_tools": available_tools,
        "question": question.lower(),
        "answers": answer_list,
        "answer_starting": response_starting,
        "tool_calls": [ tool_call ]
        "assistant_turns": assistant_turns
    }

def replace_answer(list_of_answer, var, value):
@@ -240,6 +169,16 @@ def replace_answer(list_of_answer, var, value):
        new_list.append(answer.replace(var, value))
    return new_list


def create_assistant_turn(answer: str, tool_call_sequence=None, *, tool_results=None, train_on_turn: bool = True):
    """Bundle the assistant utterance with any tool interaction for that turn."""
    return {
        "answer": answer,
        "tool_call_sequence": tool_call_sequence or [],
        "tool_results": tool_results if tool_results is not None else [],
        "train_on_turn": train_on_turn,
    }
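As an editorial aside, a minimal sketch of the two-turn shape the generators build with this helper (string values are illustrative; the tool-call dict mirrors the ones constructed in `generate_tool_failure_example` below):

```python
# first turn: speak and call a tool; second turn: plain confirmation, no tools
assistant_turns = [
    create_assistant_turn("turning on the light now.", [{
        "tool_name": TOOL_TURN_ON,            # tool constant from tools.py
        "service_name": "light.turn_on",
        "tool_args": {"name": "kitchen light"}
    }]),
    create_assistant_turn("the kitchen light is on.", [])
]
# each element: {"answer": ..., "tool_call_sequence": [...], "tool_results": [], "train_on_turn": True}
```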
def generate_templated_example(template: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
    template_device_types: list[str] = template["device_type"].split("|")
    service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ]
@@ -328,7 +267,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
    tool_calls = []
    for device_dict, service in zip(chosen_devices, service_names):
        service_action = service.split(".")[1]
        tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
        tool_name = SERVICE_TO_TOOL_MAP[service_action]
        tool_call = {
            "tool_name": tool_name,
            "service_name": service,
@@ -415,13 +354,19 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
        if call["tool_name"] == TOOL_LIST_ADD_ITEM:
            call["tool_args"]["item"] = todo

    starting_answer = answer_starting.strip().lower()
    normalized_answers = [ sentence.lower() for sentence in answer_list ]
    final_answer = " ".join(normalized_answers)
    assistant_turns = [
        create_assistant_turn(starting_answer, tool_calls),
        create_assistant_turn(final_answer, [])
    ]

    return {
        "states": device_list,
        "available_tools": available_tools,
        "question": question.lower(),
        "answer_starting": answer_starting.lower(),
        "answers": [ sentence.lower() for sentence in answer_list ],
        "tool_calls": tool_calls
        "assistant_turns": assistant_turns
    }

def generate_status_request(template: dict, persona: str, language: str, max_devices: int = 128, return_target_device: bool = False, use_service_names: bool = False):
@@ -509,12 +454,13 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
    # Remove duplicates while preserving order
    available_tools = list(dict.fromkeys(available_tools))

    assistant_turns = [create_assistant_turn(answer.lower(), [])]

    result = {
        "states": device_list,
        "available_tools": available_tools,
        "question": question.lower(),
        "answers": [ answer.lower() ],
        "tool_calls": []
        "assistant_turns": assistant_turns
    }
    if return_target_device:
        return result, chosen_device
@@ -578,42 +524,42 @@ def generate_tool_failure_example(failure_case: dict, persona: str, language: st
    retry_prompt = failure_case.get("retry_prompt", f"Trying again with {friendly_name}.").replace("<device_name>", friendly_name)
    error_result = failure_case.get("error_result", "Error").replace("<device_name>", friendly_name)

    tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
    tool_name = SERVICE_TO_TOOL_MAP[service_action]
    first_args = {"entity_id": bad_device} if use_service_names else {"name": bad_device}
    retry_args = {"entity_id": target_device} if use_service_names else {"name": target_device}
    first_args.update(tool_args_extra)
    retry_args.update(tool_args_extra)

    tool_call_sequence = [
        {
            "answer_starting": response_starting,
            "tool_calls": [{
                "tool_name": tool_name,
                "service_name": service_name,
                "tool_args": first_args
            }],
            "tool_results": [{
                "tool_name": service_name if use_service_names else tool_name,
                "tool_result": error_result
            }]
        },
        {
            "answer_starting": retry_prompt,
            "tool_calls": [{
                "tool_name": tool_name,
                "service_name": service_name,
                "tool_args": retry_args
            }]
        }
    ]
    first_turn = create_assistant_turn(
        response_starting,
        [{
            "tool_name": tool_name,
            "service_name": service_name,
            "tool_args": first_args
        }],
        tool_results=[{
            "tool_name": service_name if use_service_names else tool_name,
            "tool_result": error_result
        }],
        train_on_turn=False
    )

    second_turn = create_assistant_turn(
        retry_prompt,
        [{
            "tool_name": tool_name,
            "service_name": service_name,
            "tool_args": retry_args
        }]
    )

    final_turn = create_assistant_turn(response_confirmed, [])

    return {
        "states": device_list,
        "available_tools": available_tools,
        "question": question,
        "answers": [response_confirmed],
        "tool_call_sequence": tool_call_sequence,
        "tool_calls": []
        "assistant_turns": [first_turn, second_turn, final_turn]
    }

def generate_refusal_example(refusal_case: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
@@ -645,41 +591,20 @@ def generate_refusal_example(refusal_case: dict, persona: str, language: str, ma
    response_text = refusal_case["response"].replace("<device_name>", friendly_name).lower()
    question = refusal_case["phrase"].replace("<device_name>", friendly_name).lower()

    assistant_turns = [create_assistant_turn(response_text, [])]

    return {
        "states": device_list,
        "available_tools": available_tools,
        "question": question,
        "answers": [response_text],
        "tool_calls": []
        "assistant_turns": assistant_turns
    }

def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format):
    piles = get_dataset_piles(language)
    sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
    question = example["question"]
    answers = " ".join(example["answers"])
    answer_starting = example.get("answer_starting", "")

    tool_call_sequence = example.get("tool_call_sequence")

    tool_calls = []
    tool_results = []
    if not tool_call_sequence:
        # Add tool use blocks if there are tool calls
        if len(example["tool_calls"]) > 0:
            for tool_call in example["tool_calls"]:
                # Use service_name if in service mode, otherwise use tool_name
                call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
                tool_calls.append({
                    "name": call_name,
                    "arguments": json.dumps(tool_call["tool_args"])
                })

                tool_results.append({
                    "tool_name": call_name,
                    "tool_call_id": f"call_{len(tool_results) + 1}",
                    "tool_result": "Success"
                })
    assistant_turns = example["assistant_turns"]

    if append_user_instruction_prompt:
        user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
@@ -703,95 +628,64 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
            "content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
        }
    ]

    if tool_call_sequence:
        call_id_counter = 1
        for step in tool_call_sequence:
            step_tool_calls = []
            for tool_call in step.get("tool_calls", []):
                call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
                step_tool_calls.append({
                    "name": call_name,
                    "arguments": json.dumps(tool_call["tool_args"])
                })

            assistant_block = {
                "role": "assistant",
                "content": [{ "type": "text", "text": step.get("answer_starting", "") }]
            }
            if step_tool_calls:
                assistant_block["tool_calls"] = [ { "function": tc } for tc in step_tool_calls ]
            conversation.append(assistant_block)

            if step_tool_calls:
                provided_results = step.get("tool_results")
                step_tool_results = []
                if provided_results:
                    for provided in provided_results:
                        step_tool_results.append({
                            "tool_name": provided.get("tool_name", step_tool_calls[0]["name"]),
                            "tool_call_id": provided.get("tool_call_id", f"call_{call_id_counter}"),
                            "tool_result": provided.get("tool_result", "Success")
                        })
                        call_id_counter += 1
                else:
                    for call in step_tool_calls:
                        step_tool_results.append({
                            "tool_name": call["name"],
                            "tool_call_id": f"call_{call_id_counter}",
                            "tool_result": "Success"
                        })
                        call_id_counter += 1

                if tool_response_format == "text":
                    conversation.append({
                        "role": "tool",
                        "content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results]
                    })
                elif tool_response_format == "functiongemma":
                    conversation.append({
                        "role": "tool",
                        "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results]
                    })

            if step.get("post_tool_response"):
                conversation.append({
                    "role": "assistant",
                    "content": [{ "type": "text", "text": step["post_tool_response"] }]
                })

        conversation.append({

    call_id_counter = 1
    for turn in assistant_turns:
        answer_text = turn.get("answer", "")
        assistant_block = {
            "role": "assistant",
            "content": [{ "type": "text", "text": answers }],
        })
    elif len(tool_calls) > 0:
        assistant_starting_block = {
            "role": "assistant",
            "content": [{ "type": "text", "text": answer_starting }],
            "tool_calls": [ { "function": tc } for tc in tool_calls ]
            "content": [{ "type": "text", "text": answer_text }],
            "train_on_turn": turn.get("train_on_turn", True),
        }
        if tool_response_format == "text":
            tool_response_block = {
                "role": "tool",
                "content": [{ "type": "text", "text": json.dumps(result) } for result in tool_results]
            }
        elif tool_response_format == "functiongemma":
            tool_response_block = {
                "role": "tool",
                "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in tool_results]
            }
        assistant_confirmation_block = {
            "role": "assistant",
            "content": [{ "type": "text", "text": answers }],
        }
        conversation.extend([assistant_starting_block, tool_response_block, assistant_confirmation_block])
    else:
        conversation.extend([
            {
                "role": "assistant",
                "content": [{ "type": "text", "text": answer_starting + answers }],
            }
        ])

        tool_call_sequence = turn.get("tool_call_sequence", [])
        formatted_calls = []
        call_names = []
        for tool_call in tool_call_sequence:
            call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
            call_names.append(call_name)
            formatted_calls.append({
                "name": call_name,
                "arguments": json.dumps(tool_call["tool_args"])
            })

        if formatted_calls:
            assistant_block["tool_calls"] = [{ "function": call } for call in formatted_calls]

        conversation.append(assistant_block)

        if formatted_calls:
            provided_results = turn.get("tool_results") or []
            step_tool_results = []

            if provided_results:
                for idx, provided in enumerate(provided_results):
                    result = dict(provided)
                    if "tool_name" not in result and call_names:
                        result["tool_name"] = call_names[min(idx, len(call_names) - 1)]
                    if "tool_call_id" not in result:
                        result["tool_call_id"] = f"call_{call_id_counter}"
                    call_id_counter += 1
                    step_tool_results.append(result)
            else:
                for call_name in call_names:
                    step_tool_results.append({
                        "tool_name": call_name,
                        "tool_call_id": f"call_{call_id_counter}",
                        "tool_result": "Success"
                    })
                    call_id_counter += 1

            if tool_response_format == "text":
                conversation.append({
                    "role": "tool",
                    "content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results]
                })
            elif tool_response_format == "functiongemma":
                conversation.append({
                    "role": "tool",
                    "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results]
                })

    return {
        "messages": conversation,
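For orientation, a minimal sketch of the message list this function now emits for a single tool-calling turn (field names are taken from the diff; the concrete strings are hypothetical):

```python
# hypothetical rendering of one tool-calling turn followed by its tool result
conversation = [
    {"role": "user", "content": [{"type": "text", "text": "<system prompt>\nturn on the kitchen light"}]},
    {
        "role": "assistant",
        "content": [{"type": "text", "text": "switching on the kitchen light for you."}],
        "train_on_turn": True,
        "tool_calls": [{"function": {"name": "light.turn_on", "arguments": "{\"entity_id\": \"light.kitchen\"}"}}],
    },
    # tool_response_format == "text" serializes each result dict as JSON text
    {"role": "tool", "content": [{"type": "text", "text": "{\"tool_name\": \"light.turn_on\", \"tool_call_id\": \"call_1\", \"tool_result\": \"Success\"}"}]},
    {"role": "assistant", "content": [{"type": "text", "text": "the kitchen light is now on."}], "train_on_turn": True},
]
```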
@@ -812,8 +706,8 @@ def generate_sft_file(
        static_factor: float,
        template_factor: int,
        status_request_factor: int,
        failure_factor: float = 1,
        refusal_factor: float = 1):
        failure_factor: int,
        refusal_factor: int):
    random.seed(seed)
    np.random.seed(seed)
    piles = get_dataset_piles(language)
@@ -929,7 +823,7 @@ def main(args=None):

    args = parser.parse_args(args=args)

    if not args.sample and not args.train and not args.test and not args.merge:
    if not args.sample and not args.train and not args.test:
        parser.print_usage()
        exit(-1)
@@ -950,20 +844,20 @@ def main(args=None):
        suffix = f"_{language}" if len(args.language) > 1 else ""

        if args.sample:
            generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1)
            generate_sft_file(f"sample{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=1, status_request_factor=1, refusal_factor=1, failure_factor=1)
        if args.train:
            if args.size == "small":
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8)
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=1, template_factor=10, status_request_factor=8, refusal_factor=3, failure_factor=1)
            elif args.size == "medium":
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12)
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=15, status_request_factor=12, refusal_factor=5, failure_factor=1)
            elif args.size == "large":
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15)
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=5, template_factor=20, status_request_factor=15, refusal_factor=6, failure_factor=1)
            elif args.size == "xl":
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18)
                generate_sft_file(f"home_assistant_train{suffix}", 42, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=7, template_factor=25, status_request_factor=18, refusal_factor=8, failure_factor=2)
            else:
                raise Exception(f"Unrecognized dataset size: {args.size}")
        if args.test:
            generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2)
            generate_sft_file(f"home_assistant_test{suffix}", 12345, format_func, use_system_role, append_user_instruction_prompt, use_service_names, personas, language, tool_response_format, static_factor=0.25, template_factor=1, status_request_factor=2, refusal_factor=1, failure_factor=1)
        if len(args.language) > 1:
            if args.sample:
                merge_languages("sample", args.language)
@@ -3,13 +3,19 @@ import asyncio
import csv
import json
import random
import string
import aiohttp
from tqdm import tqdm
import os

from utils import get_dataset_piles
from devices import random_device_list

LLM_ENDPOINT = "https://ai.cloud.alexoconnell.net/v1/chat/completions"
cwd = os.path.dirname(os.path.abspath(__file__))

def get_hass_match_error_message(bad_device_name: str) -> str:
    return f"<MatchFailedError result=MatchTargetsResult(is_match=False, no_match_reason=<MatchFailedReason.NAME: 1>, states=[], no_match_name=None, areas=[], floors=[]), constraints=MatchTargetsConstraints(name='{bad_device_name}', area_name=None, floor_name=None, domains=None, device_classes=None, features=None, states=None, assistant='conversation', allow_duplicate_names=False, single_target=False), preferences=MatchTargetsPreferences(area_id=None, floor_id=None)>"

class SyntheticDataGenerator:
    def __init__(self, model_name: str, language: str, concurrency: int):
@@ -18,6 +24,130 @@ class SyntheticDataGenerator:
        self.model_name = model_name
        self.piles = get_dataset_piles(language)
        self.synthetic_devices = {} # device_type -> list of {device_name, description}
        self.failed_tool_fieldnames = [
            "service_name",
            "correct_device_name",
            "correct_friendly_name",
            "bad_device_name",
            "phrase",
            "error_result",
            "retry_prompt"
        ]
        self.refusal_fieldnames = [
            "reason_type",
            "service_name",
            "device_name",
            "friendly_name",
            "desired_state",
            "phrase",
            "response"
        ]

    async def _chat_completion(self, session, system_prompt: str, user_prompt: str, *, temperature: float = 0.8, max_tokens: int = 300, structured_response: dict = {}):
        payload = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            "temperature": temperature,
            "max_tokens": max_tokens if not structured_response else None, # don't limit if structured (causes failed generations)
            "response_format": structured_response if structured_response else None,
        }

        try:
            async with session.post(LLM_ENDPOINT, json=payload) as response:
                if response.status < 400:
                    data = await response.json()
                    return data['choices'][0]['message']['content'].strip()
                print(f"LLM request failed with status {response.status}")
        except Exception as exc:
            print(f"Completion request failed: {exc}")
        return None

    @staticmethod
    def _strip_code_fence(content: str) -> str:
        if not content:
            return ""
        text = content.strip()
        if text.startswith("```") and text.endswith("```"):
            lines = [line for line in text.splitlines() if not line.strip().startswith("```")]
            return "\n".join(lines).strip()
        return text

    def _parse_json_object(self, content: str):
        if not content:
            return None
        cleaned = self._strip_code_fence(content)
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start == -1 or end == -1:
            return None
        snippet = cleaned[start:end+1]
        try:
            return json.loads(snippet)
        except json.JSONDecodeError:
            return None

    @staticmethod
    def _ensure_placeholder(text: str, friendly_name: str, placeholder: str) -> str:
        if placeholder in text:
            return text
        lower_text = text.lower()
        lower_name = friendly_name.lower()
        idx = lower_text.find(lower_name)
        if idx == -1:
            return text
        return text[:idx] + placeholder + text[idx + len(friendly_name):]

    @staticmethod
    def _describe_service(service_name: str, service_data: dict) -> str:
        domain, action = service_name.split(".", 1)
        description = f"{action.replace('_', ' ')} the {domain}"
        if service_data:
            description += f" with arguments {json.dumps(service_data, ensure_ascii=False)}"
        return description

    @staticmethod
    def _desired_state_for_service(service_name: str) -> str:
        action = service_name.split(".", 1)[1]
        action_mapping = {
            "turn_on": "on",
            "turn_off": "off",
            "open_cover": "open",
            "close_cover": "closed",
            "lock": "locked",
            "unlock": "unlocked",
            "media_play_pause": "playing",
        }
        if action in action_mapping:
            return action_mapping[action]
        service_mapping = {
            "media_player.turn_on": "on",
            "media_player.turn_off": "off",
            "vacuum.start": "cleaning",
            "vacuum.return_to_base": "docked",
            "garage_door.open_cover": "open",
            "garage_door.close_cover": "closed",
            "blinds.open_cover": "open",
            "blinds.close_cover": "closed",
        }
        return service_mapping.get(service_name, "")

    @staticmethod
    def _append_rows_to_csv(path: str, fieldnames, rows):
        if not rows:
            return
        os.makedirs(os.path.dirname(path), exist_ok=True)
        file_exists = os.path.exists(path)
        if not file_exists:
            with open(path, "w", newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
        with open(path, "a", newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            for row in rows:
                writer.writerow(row)
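A quick sketch of how the fence-stripping helpers above behave (hypothetical strings; `gen` is assumed to be an instance of the class):

```python
# LLMs sometimes wrap JSON output in markdown fences; these helpers recover the payload
raw = '```json\n{"phrase": "turn on <device_name>"}\n```'
print(SyntheticDataGenerator._strip_code_fence(raw))  # {"phrase": "turn on <device_name>"}
print(gen._parse_json_object(raw))                    # {'phrase': 'turn on <device_name>'}
```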
    async def generate_device_names(self, session, device_type, count=10):
        """
@@ -144,6 +274,117 @@ class SyntheticDataGenerator:
            print(f"Request failed: {e}")
            return None

    async def generate_failed_tool_call_entry(self, session, context):
        service_name = context["service_name"]
        friendly_name = context["friendly_name"]
        correct_device = context["device_name"]
        action_description = self._describe_service(service_name, context.get("service_data", {}))
        device_list, device_types, extra_exposed_attributes = random_device_list(max_devices=10, avoid_device_names=[], language=self.language)

        extra_phrase_instructions = ""
        if service_name == "climate.set_temperature":
            extra_phrase_instructions = "Always reference the target temperature using the literal placeholder <temp_f>."

        system_prompt = "You create high-quality multilingual training data for a smart home assistant."
        user_prompt = f"""
Generate content for a failed tool call in {self.language}.
Service: {service_name}
Action description: {action_description}
Device name: {correct_device} (friendly name: {friendly_name})
The assistant first misidentifies the entity using the "bad_device" name, receives an error, then retries with {correct_device}.
Output a single JSON object with the keys "phrase", "bad_device", "error_result", and "retry_prompt".
- phrase: A natural voice command template the user would say. {extra_phrase_instructions}
- bad_device: The incorrect version of '{correct_device}'. Ensure it is a truncation, mutation, transposition, or inclusion of an extra suffix/prefix. Avoid simple typos in words.
- retry_prompt: A short acknowledgement from the assistant that it will try the correct device ({correct_device}).

Here are potential device names to help generate a bad device name: {', '.join(device_list)}
Keep the tone conversational and stay entirely in {self.language}. Do not add explanations or code fences.
"""

        raw = await self._chat_completion(
            session,
            system_prompt,
            user_prompt,
            temperature=0.7,
            max_tokens=512,
            structured_response={ "type": "json_object" }
        )
        parsed = self._parse_json_object(raw)
        if not parsed:
            print(f"Failed to parse JSON for failed tool call generation: {raw}")
            return None

        phrase = parsed.get("phrase", "").strip()
        bad_device = parsed.get("bad_device", "").strip()
        retry_prompt = parsed.get("retry_prompt", "").strip()
        if not phrase or not bad_device or not retry_prompt:
            return None

        return {
            "service_name": service_name,
            "correct_device_name": correct_device,
            "correct_friendly_name": friendly_name,
            "bad_device_name": bad_device,
            "phrase": phrase,
            "error_result": get_hass_match_error_message(bad_device),
            "retry_prompt": retry_prompt
        }

    async def generate_refusal_entry(self, session, context, reason_type: str, desired_state: str):
        service_name = context["service_name"]
        friendly_name = context["friendly_name"]
        action_description = self._describe_service(service_name, context.get("service_data", {}))
        device_suffix = context["device_name"].split(".", 1)[1]

        if reason_type == "not_available":
            reason_detail = "Explain that the assistant cannot locate that device in the home."
        else:
            reason_detail = f"Explain that the device is already {desired_state} and no change is needed."

        system_prompt = "You write concise refusal-style responses for a multilingual smart home assistant."
        user_prompt = f"""
Create a refusal example in {self.language} for the following request.
Service: {service_name}
Action description: {action_description}
Reason: {reason_detail}
Output a JSON object with:
- "phrase": the user's natural command template. Use the literal placeholder <device_name> anywhere the user mentions the device.
- "response": the assistant's refusal message in {self.language} describing the reason.
Keep both fields brief, conversational, and free of extra narration or code fences.
"""

        raw = await self._chat_completion(
            session,
            system_prompt,
            user_prompt,
            temperature=0.7,
            max_tokens=512,
            structured_response={ "type": "json_object" }
        )
        parsed = self._parse_json_object(raw)
        if not parsed:
            print("Failed to parse JSON for refusal generation")
            return None

        phrase = parsed.get("phrase", "").strip()
        response = parsed.get("response", "").strip()
        if not phrase or not response:
            return None

        phrase = self._ensure_placeholder(phrase, friendly_name, "<device_name>")
        if "<device_name>" not in phrase:
            return None

        return {
            "reason_type": reason_type,
            "service_name": service_name,
            "device_name": device_suffix,
            "friendly_name": friendly_name,
            "desired_state": desired_state if reason_type == "already_state" else "",
            "phrase": phrase,
            "response": response
        }

    def sample_context(self, request_type: str):
        """
        Creates a random scenario: device, service, and arguments.
@@ -245,14 +486,12 @@ class SyntheticDataGenerator:
        }
        raise ValueError(f"Unknown request type {request_type}")

    async def run(self, num_actions: int, num_status_requests: int, num_devices: int, output_file, persona_name=None, persona_description=None):
    async def run(self, num_actions: int, num_status_requests: int, num_devices: int,
                  persona_name=None, persona_description=None, num_failed_tool_calls: int = 0,
                  num_refusals: int = 0):
        print(f"Starting generation...")
        print(f"Language: {self.language}")

        # Ensure output directory exists
        if output_file:
            os.makedirs(os.path.dirname(output_file), exist_ok=True)

        async with aiohttp.ClientSession() as session:

            if num_devices > 0:
@@ -268,7 +507,7 @@ class SyntheticDataGenerator:
            all_new_devices = [item for sublist in generated_lists if sublist for item in sublist]

            if all_new_devices:
                csv_path = f"data/piles/{self.language}/pile_of_device_names.csv"
                csv_path = f"{cwd}/piles/{self.language}/pile_of_device_names.csv"
                try:
                    with open(csv_path, "a", newline='', encoding='utf-8') as f:
                        writer = csv.DictWriter(f, fieldnames=["device_name", "description"])
@@ -280,7 +519,6 @@ class SyntheticDataGenerator:

            if num_actions > 0 or num_status_requests > 0:
                print(f"Generating {num_actions} actions and {num_status_requests} status requests...")
                print(f"Output file: {output_file}")
                tasks = {}
                results = []
@@ -312,7 +550,7 @@ class SyntheticDataGenerator:

                    if entry["type"] == "action":
                        # Write to pile_of_specific_actions.csv
                        csv_path = f"data/piles/{self.language}/pile_of_specific_actions.csv"
                        csv_path = f"{cwd}/piles/{self.language}/pile_of_specific_actions.csv"

                        # Prepare row
                        # device_name in CSV is the suffix (e.g. 'kitchen' from 'light.kitchen')
@@ -383,7 +621,7 @@ class SyntheticDataGenerator:
                        # For now, just skip if we can't templatize.
                        pass
                    else:
                        csv_path = f"data/piles/{self.language}/pile_of_status_requests.csv"
                        csv_path = f"{cwd}/piles/{self.language}/pile_of_status_requests.csv"
                        # Columns: device_type,state,phrase,assistant_response
                        # We don't have assistant_response.
                        # We can generate a generic one?
@@ -398,6 +636,85 @@ class SyntheticDataGenerator:

                pbar.close()

            if num_failed_tool_calls > 0:
                print(f"Generating {num_failed_tool_calls} failed tool call scenarios...")
                failed_entries = []
                tasks = {}
                pbar_failed = tqdm(total=num_failed_tool_calls, desc="Failed tool calls")

                while len(failed_entries) < num_failed_tool_calls:
                    while len(tasks) < self.concurrency and (len(failed_entries) + len(tasks)) < num_failed_tool_calls:
                        context = self.sample_context("action")
                        if not context:
                            continue
                        task = asyncio.create_task(self.generate_failed_tool_call_entry(session, context))
                        tasks[task] = None

                    if not tasks:
                        break

                    done, _ = await asyncio.wait(list(tasks.keys()), return_when=asyncio.FIRST_COMPLETED)

                    for task in done:
                        tasks.pop(task, None)
                        try:
                            entry = await task
                        except Exception as exc:
                            print(f"Failed tool call task error: {exc}")
                            entry = None
                        if entry:
                            failed_entries.append(entry)
                            pbar_failed.update(1)

                pbar_failed.close()
                failed_path = f"{cwd}/piles/{self.language}/pile_of_failed_tool_calls.csv"
                self._append_rows_to_csv(failed_path, self.failed_tool_fieldnames, failed_entries)

            if num_refusals > 0:
                print(f"Generating {num_refusals} refusal scenarios...")
                refusal_entries = []
                tasks = {}
                pbar_refusals = tqdm(total=num_refusals, desc="Refusals")

                while len(refusal_entries) < num_refusals:
                    while len(tasks) < self.concurrency and (len(refusal_entries) + len(tasks)) < num_refusals:
                        context = self.sample_context("action")
                        if not context:
                            continue
                        reason_type = random.choice(["not_available", "already_state"])
                        desired_state = ""
                        if reason_type == "already_state":
                            desired_state = self._desired_state_for_service(context["service_name"])
                            if not desired_state:
                                reason_type = "not_available"
                        task = asyncio.create_task(self.generate_refusal_entry(
                            session,
                            context,
                            reason_type,
                            desired_state
                        ))
                        tasks[task] = None

                    if not tasks:
                        break

                    done, _ = await asyncio.wait(list(tasks.keys()), return_when=asyncio.FIRST_COMPLETED)

                    for task in done:
                        tasks.pop(task, None)
                        try:
                            entry = await task
                        except Exception as exc:
                            print(f"Refusal generation error: {exc}")
                            entry = None
                        if entry:
                            refusal_entries.append(entry)
                            pbar_refusals.update(1)

                pbar_refusals.close()
                refusal_path = f"{cwd}/piles/{self.language}/pile_of_refusals.csv"
                self._append_rows_to_csv(refusal_path, self.refusal_fieldnames, refusal_entries)

            if persona_name and persona_description:
                await self.generate_persona(session, persona_name, persona_description)
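Both generation loops above use the same bounded-concurrency idiom: keep at most `self.concurrency` tasks in flight, harvest whichever finishes first, and top the pool back up. A stripped-down, standalone sketch of that pattern (a dummy coroutine stands in for the real generators):

```python
import asyncio

async def worker(i: int) -> int:
    await asyncio.sleep(0.01)  # placeholder for a real LLM request
    return i

async def bounded_run(total: int, concurrency: int) -> list[int]:
    results, pending = [], set()
    next_id = 0
    while len(results) < total:
        # top the pool up to the concurrency limit
        while len(pending) < concurrency and next_id < total:
            pending.add(asyncio.create_task(worker(next_id)))
            next_id += 1
        if not pending:
            break
        # harvest whichever task finishes first
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        results.extend(t.result() for t in done)
    return results

print(asyncio.run(bounded_run(10, 3)))
```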
@@ -442,7 +759,7 @@ class SyntheticDataGenerator:
            return

        # 2. Get list of services to generate responses for
        responses_csv_path = f"data/piles/{self.language}/pile_of_responses.csv"
        responses_csv_path = f"{cwd}/piles/{self.language}/pile_of_responses.csv"
        services = set()
        try:
            with open(responses_csv_path, "r", encoding='utf-8') as f:
@@ -527,7 +844,7 @@ class SyntheticDataGenerator:

        # 4. Write to files
        # Append system prompt
        sys_prompts_path = f"data/piles/{self.language}/pile_of_system_prompts.csv"
        sys_prompts_path = f"{cwd}/piles/{self.language}/pile_of_system_prompts.csv"
        try:
            with open(sys_prompts_path, "a", newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
@@ -555,13 +872,25 @@ if __name__ == "__main__":
    parser.add_argument("--actions", type=int, default=0, help="Number of actions to generate")
    parser.add_argument("--status", type=int, default=0, help="Number of status requests to generate")
    parser.add_argument("--devices", type=int, default=0, help="Number of new devices to generate")
    parser.add_argument("--concurrency", type=int, default=8, help="Number of concurrent requests")
    parser.add_argument("--language", type=str, default="english", help="Language")
    parser.add_argument("--failed-tool-calls", type=int, default=0, help="Number of failed tool call scenarios to generate")
    parser.add_argument("--refusals", type=int, default=0, help="Number of refusal scenarios to generate")
    parser.add_argument("--concurrency", type=int, default=4, help="Number of concurrent requests")
    parser.add_argument("--languages", type=str, nargs='+', default=["english"], help="Languages to generate data for")
    parser.add_argument("--model", type=str, default="gpt-oss-120b", help="LLM model to use")
    parser.add_argument("--persona-name", type=str, help="Name of the new persona to generate")
    parser.add_argument("--persona-description", type=str, help="Description of the new persona")

    args = parser.parse_args()

    generator = SyntheticDataGenerator(model_name=args.model, language=args.language, concurrency=args.concurrency)
    asyncio.run(generator.run(num_actions=args.actions, num_status_requests=args.status, num_devices=args.devices, output_file="", persona_name=args.persona_name, persona_description=args.persona_description))
    for language in args.languages:
        print(f"=== Generating data for language: {language} ===")
        generator = SyntheticDataGenerator(model_name=args.model, language=language, concurrency=args.concurrency)
        asyncio.run(generator.run(
            num_actions=args.actions,
            num_status_requests=args.status,
            num_devices=args.devices,
            persona_name=args.persona_name,
            persona_description=args.persona_description,
            num_failed_tool_calls=args.failed_tool_calls,
            num_refusals=args.refusals
        ))
@@ -37,6 +37,7 @@ SERVICE_TO_TOOL_MAP = {
    "unlock": TOOL_TURN_OFF,
    "increase_speed": TOOL_TURN_ON,
    "decrease_speed": TOOL_TURN_OFF,
    "media_stop": TOOL_TURN_OFF,
    "media_play_pause": TOOL_TOGGLE,
    "media_pause": TOOL_MEDIA_PAUSE,
    "media_play": TOOL_MEDIA_UNPAUSE,
@@ -49,6 +50,13 @@ SERVICE_TO_TOOL_MAP = {
    "set_hvac_mode": TOOL_CLIMATE_SET_TEMPERATURE,
    "set_fan_mode": TOOL_CLIMATE_SET_TEMPERATURE,
    "set_preset_mode": TOOL_CLIMATE_SET_TEMPERATURE,
    "cancel": TOOL_CANCEL_TIMER,
    "volume_down": TOOL_SET_VOLUME,
    "volume_up": TOOL_SET_VOLUME,
    "volume_mute": TOOL_SET_VOLUME,
    "stop": TOOL_TURN_OFF,
    "pause": TOOL_TURN_OFF,
    "add_item": TOOL_LIST_ADD_ITEM
}
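Note how these map additions pair with the `generate_data.py` changes above: lookups moved from `SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)` to direct indexing, so a service that appears in a pile but is missing from this map now fails loudly instead of silently defaulting to `TOOL_TURN_ON`. A minimal sketch of the behavior difference:

```python
# before: an unmapped service silently became TOOL_TURN_ON
tool_name = SERVICE_TO_TOOL_MAP.get("media_stop", TOOL_TURN_ON)

# after: a missing entry raises KeyError, which is why entries like
# "media_stop" and "cancel" are added to the map in this commit
tool_name = SERVICE_TO_TOOL_MAP["media_stop"]
```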
# Home Assistant Intent Tools Definition