diff --git a/data/generate_data.py b/data/generate_data.py index ab4f740..919ec5f 100644 --- a/data/generate_data.py +++ b/data/generate_data.py @@ -521,6 +521,139 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev else: return result +def generate_tool_failure_example(failure_case: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False): + piles = get_dataset_piles(language) + service_name = failure_case["service_name"] + device_type = service_name.split(".")[0] + service_action = service_name.split(".")[1] + target_device = failure_case["correct_device_name"] + friendly_name = failure_case.get("correct_friendly_name", target_device.split(".")[1].replace("_", " ").title()) + bad_device = failure_case["bad_device_name"] + + question_template = failure_case["phrase"] + question = question_template.replace("", friendly_name).lower() + + device_list, device_types, extra_exposed_attributes = random_device_list( + max_devices=max_devices, avoid_device_names=[target_device], language=language) + + state = SUPPORTED_DEVICES[device_type].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes) + device_list.insert(random.randint(0, len(device_list)), format_device_line( + device_name=target_device, + friendly_name=friendly_name, + state=state + )) + if device_type not in device_types: + device_types.append(device_type) + + available_tools = [] + for x in set(device_types): + available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes)) + available_tools = list(dict.fromkeys(available_tools)) + + response_starting, response_confirmed = get_random_response( + piles.pile_of_responses, + service=service_name, + persona=persona, + question_template=question_template, + short=False + ) + response_starting = response_starting.replace("", friendly_name) + response_confirmed = response_confirmed.replace("", friendly_name) + + tool_args_extra = {} + if "climate" in service_action: + if "" in question or "" in response_starting or "" in response_confirmed: + temp_f = generate_random_parameter("temp_f", piles) + question = question.replace("", str(temp_f)) + response_starting = response_starting.replace("", str(temp_f)) + response_confirmed = response_confirmed.replace("", str(temp_f)) + tool_args_extra["temperature"] = temp_f + if "" in question or "" in response_starting or "" in response_confirmed: + temp_c = generate_random_parameter("temp_c", piles) + question = question.replace("", str(temp_c)) + response_starting = response_starting.replace("", str(temp_c)) + response_confirmed = response_confirmed.replace("", str(temp_c)) + tool_args_extra["temperature"] = temp_c + + retry_prompt = failure_case.get("retry_prompt", f"Trying again with {friendly_name}.").replace("", friendly_name) + error_result = failure_case.get("error_result", "Error").replace("", friendly_name) + + tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON) + first_args = {"entity_id": bad_device} if use_service_names else {"name": bad_device} + retry_args = {"entity_id": target_device} if use_service_names else {"name": target_device} + first_args.update(tool_args_extra) + retry_args.update(tool_args_extra) + + tool_call_sequence = [ + { + "answer_starting": response_starting, + "tool_calls": [{ + "tool_name": tool_name, + "service_name": service_name, + "tool_args": first_args + }], + "tool_results": [{ + "tool_name": service_name if use_service_names else tool_name, + "tool_result": error_result + }] + }, + { + "answer_starting": retry_prompt, + "tool_calls": [{ + "tool_name": tool_name, + "service_name": service_name, + "tool_args": retry_args + }] + } + ] + + return { + "states": device_list, + "available_tools": available_tools, + "question": question, + "answers": [response_confirmed], + "tool_call_sequence": tool_call_sequence, + "tool_calls": [] + } + +def generate_refusal_example(refusal_case: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False): + service_name = refusal_case["service_name"] + device_type = service_name.split(".")[0] + target_device = f"{device_type}.{refusal_case['device_name']}" + friendly_name = refusal_case.get("friendly_name", refusal_case["device_name"].replace("_", " ").title()) + desired_state = refusal_case.get("desired_state", "") + reason_type = refusal_case.get("reason_type", "not_available") + + device_list, device_types, extra_exposed_attributes = random_device_list( + max_devices=max_devices, avoid_device_names=[target_device], language=language) + + if reason_type == "already_state": + state = desired_state if desired_state else SUPPORTED_DEVICES[device_type].possible_states[0][0] + device_list.insert(random.randint(0, len(device_list)), format_device_line( + device_name=target_device, + friendly_name=friendly_name, + state=state + )) + if device_type not in device_types: + device_types.append(device_type) + + available_tools = [] + for x in set(device_types): + available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes)) + available_tools = list(dict.fromkeys(available_tools)) + + response_text = refusal_case["response"].replace("", friendly_name).lower() + question = refusal_case["phrase"].replace("", friendly_name).lower() + + return { + "states": device_list, + "available_tools": available_tools, + "question": question, + "answers": [response_text], + "answer_starting": "", + "tool_calls": [] + } + def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format): piles = get_dataset_piles(language) sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts) @@ -528,24 +661,26 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_ answers = " ".join(example["answers"]) answer_starting = example.get("answer_starting", "") + tool_call_sequence = example.get("tool_call_sequence") + tool_calls = [] tool_results = [] - - # Add tool use blocks if there are tool calls - if len(example["tool_calls"]) > 0: - for tool_call in example["tool_calls"]: - # Use service_name if in service mode, otherwise use tool_name - call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"] - tool_calls.append({ - "name": call_name, - "arguments": json.dumps(tool_call["tool_args"]) - }) + if not tool_call_sequence: + # Add tool use blocks if there are tool calls + if len(example["tool_calls"]) > 0: + for tool_call in example["tool_calls"]: + # Use service_name if in service mode, otherwise use tool_name + call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"] + tool_calls.append({ + "name": call_name, + "arguments": json.dumps(tool_call["tool_args"]) + }) - tool_results.append({ - "tool_name": call_name, - "tool_call_id": f"call_{len(tool_results) + 1}", - "tool_result": "Success" - }) + tool_results.append({ + "tool_name": call_name, + "tool_call_id": f"call_{len(tool_results) + 1}", + "tool_result": "Success" + }) if append_user_instruction_prompt: user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":" @@ -569,8 +704,68 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_ "content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }] } ] - - if len(tool_calls) > 0: + + if tool_call_sequence: + call_id_counter = 1 + for step in tool_call_sequence: + step_tool_calls = [] + for tool_call in step.get("tool_calls", []): + call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"] + step_tool_calls.append({ + "name": call_name, + "arguments": json.dumps(tool_call["tool_args"]) + }) + + assistant_block = { + "role": "assistant", + "content": [{ "type": "text", "text": step.get("answer_starting", "") }] + } + if step_tool_calls: + assistant_block["tool_calls"] = [ { "function": tc } for tc in step_tool_calls ] + conversation.append(assistant_block) + + if step_tool_calls: + provided_results = step.get("tool_results") + step_tool_results = [] + if provided_results: + for provided in provided_results: + step_tool_results.append({ + "tool_name": provided.get("tool_name", step_tool_calls[0]["name"]), + "tool_call_id": provided.get("tool_call_id", f"call_{call_id_counter}"), + "tool_result": provided.get("tool_result", "Success") + }) + call_id_counter += 1 + else: + for call in step_tool_calls: + step_tool_results.append({ + "tool_name": call["name"], + "tool_call_id": f"call_{call_id_counter}", + "tool_result": "Success" + }) + call_id_counter += 1 + + if tool_response_format == "text": + conversation.append({ + "role": "tool", + "content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results] + }) + elif tool_response_format == "functiongemma": + conversation.append({ + "role": "tool", + "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results] + }) + + if step.get("post_tool_response"): + conversation.append({ + "role": "assistant", + "content": [{ "type": "text", "text": step["post_tool_response"] }] + }) + + conversation.append({ + "role": "assistant", + "content": [{ "type": "text", "text": answers }], + }) + elif len(tool_calls) > 0: assistant_starting_block = { "role": "assistant", "content": [{ "type": "text", "text": answer_starting }], @@ -617,7 +812,9 @@ def generate_sft_file( *, static_factor: float, template_factor: int, - status_request_factor: int): + status_request_factor: int, + failure_factor: float = 1, + refusal_factor: float = 1): random.seed(seed) np.random.seed(seed) piles = get_dataset_piles(language) @@ -649,6 +846,15 @@ def generate_sft_file( except NoResponseAvailableException as ex: missing_responses.add(str(ex)) + for failure_case in tqdm(piles.pile_of_failed_tool_calls): + try: + run_factor_times(generate_tool_failure_example, generated_examples, failure_case, person, failure_factor, language) + except NoResponseAvailableException as ex: + missing_responses.add(str(ex)) + + for refusal_case in tqdm(piles.pile_of_refusals): + run_factor_times(generate_refusal_example, generated_examples, refusal_case, person, refusal_factor, language) + for status_request in tqdm(piles.pile_of_status_requests): run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor, language) diff --git a/data/piles/english/pile_of_failed_tool_calls.csv b/data/piles/english/pile_of_failed_tool_calls.csv new file mode 100644 index 0000000..86e08d9 --- /dev/null +++ b/data/piles/english/pile_of_failed_tool_calls.csv @@ -0,0 +1,4 @@ +service_name,correct_device_name,correct_friendly_name,bad_device_name,phrase,error_result,retry_prompt +light.turn_on,light.living_room_lamp,Living Room Lamp,light.livng_room_lamp,"Please turn on the .","Error: Entity light.livng_room_lamp not found.","Trying the living room lamp instead." +climate.set_temperature,climate.hallway,Hallway Thermostat,climate.halway,"Set the to degrees.","Error: Entity climate.halway not found.","Retrying with the hallway thermostat." +fan.turn_off,fan.office,Office Fan,fan.offce,"Turn off the .","Error: Entity fan.offce not found.","I'll try the office fan again." diff --git a/data/piles/english/pile_of_refusals.csv b/data/piles/english/pile_of_refusals.csv new file mode 100644 index 0000000..4b57822 --- /dev/null +++ b/data/piles/english/pile_of_refusals.csv @@ -0,0 +1,4 @@ +reason_type,service_name,device_name,friendly_name,desired_state,phrase,response +not_available,lock.lock,back_door,Back Door Lock,,"Lock the .","I can't find a back door lock to control." +already_state,light.turn_on,hallway,Hallway Lights,on,"Turn on the .","The hallway lights are already on." +already_state,switch.turn_off,garage,Garage Outlet,off,"Turn off the .","The garage outlet is already off." diff --git a/data/piles/french/pile_of_failed_tool_calls.csv b/data/piles/french/pile_of_failed_tool_calls.csv new file mode 100644 index 0000000..86e08d9 --- /dev/null +++ b/data/piles/french/pile_of_failed_tool_calls.csv @@ -0,0 +1,4 @@ +service_name,correct_device_name,correct_friendly_name,bad_device_name,phrase,error_result,retry_prompt +light.turn_on,light.living_room_lamp,Living Room Lamp,light.livng_room_lamp,"Please turn on the .","Error: Entity light.livng_room_lamp not found.","Trying the living room lamp instead." +climate.set_temperature,climate.hallway,Hallway Thermostat,climate.halway,"Set the to degrees.","Error: Entity climate.halway not found.","Retrying with the hallway thermostat." +fan.turn_off,fan.office,Office Fan,fan.offce,"Turn off the .","Error: Entity fan.offce not found.","I'll try the office fan again." diff --git a/data/piles/french/pile_of_refusals.csv b/data/piles/french/pile_of_refusals.csv new file mode 100644 index 0000000..4b57822 --- /dev/null +++ b/data/piles/french/pile_of_refusals.csv @@ -0,0 +1,4 @@ +reason_type,service_name,device_name,friendly_name,desired_state,phrase,response +not_available,lock.lock,back_door,Back Door Lock,,"Lock the .","I can't find a back door lock to control." +already_state,light.turn_on,hallway,Hallway Lights,on,"Turn on the .","The hallway lights are already on." +already_state,switch.turn_off,garage,Garage Outlet,off,"Turn off the .","The garage outlet is already off." diff --git a/data/piles/german/pile_of_failed_tool_calls.csv b/data/piles/german/pile_of_failed_tool_calls.csv new file mode 100644 index 0000000..86e08d9 --- /dev/null +++ b/data/piles/german/pile_of_failed_tool_calls.csv @@ -0,0 +1,4 @@ +service_name,correct_device_name,correct_friendly_name,bad_device_name,phrase,error_result,retry_prompt +light.turn_on,light.living_room_lamp,Living Room Lamp,light.livng_room_lamp,"Please turn on the .","Error: Entity light.livng_room_lamp not found.","Trying the living room lamp instead." +climate.set_temperature,climate.hallway,Hallway Thermostat,climate.halway,"Set the to degrees.","Error: Entity climate.halway not found.","Retrying with the hallway thermostat." +fan.turn_off,fan.office,Office Fan,fan.offce,"Turn off the .","Error: Entity fan.offce not found.","I'll try the office fan again." diff --git a/data/piles/german/pile_of_refusals.csv b/data/piles/german/pile_of_refusals.csv new file mode 100644 index 0000000..4b57822 --- /dev/null +++ b/data/piles/german/pile_of_refusals.csv @@ -0,0 +1,4 @@ +reason_type,service_name,device_name,friendly_name,desired_state,phrase,response +not_available,lock.lock,back_door,Back Door Lock,,"Lock the .","I can't find a back door lock to control." +already_state,light.turn_on,hallway,Hallway Lights,on,"Turn on the .","The hallway lights are already on." +already_state,switch.turn_off,garage,Garage Outlet,off,"Turn off the .","The garage outlet is already off." diff --git a/data/piles/polish/pile_of_failed_tool_calls.csv b/data/piles/polish/pile_of_failed_tool_calls.csv new file mode 100644 index 0000000..86e08d9 --- /dev/null +++ b/data/piles/polish/pile_of_failed_tool_calls.csv @@ -0,0 +1,4 @@ +service_name,correct_device_name,correct_friendly_name,bad_device_name,phrase,error_result,retry_prompt +light.turn_on,light.living_room_lamp,Living Room Lamp,light.livng_room_lamp,"Please turn on the .","Error: Entity light.livng_room_lamp not found.","Trying the living room lamp instead." +climate.set_temperature,climate.hallway,Hallway Thermostat,climate.halway,"Set the to degrees.","Error: Entity climate.halway not found.","Retrying with the hallway thermostat." +fan.turn_off,fan.office,Office Fan,fan.offce,"Turn off the .","Error: Entity fan.offce not found.","I'll try the office fan again." diff --git a/data/piles/polish/pile_of_refusals.csv b/data/piles/polish/pile_of_refusals.csv new file mode 100644 index 0000000..4b57822 --- /dev/null +++ b/data/piles/polish/pile_of_refusals.csv @@ -0,0 +1,4 @@ +reason_type,service_name,device_name,friendly_name,desired_state,phrase,response +not_available,lock.lock,back_door,Back Door Lock,,"Lock the .","I can't find a back door lock to control." +already_state,light.turn_on,hallway,Hallway Lights,on,"Turn on the .","The hallway lights are already on." +already_state,switch.turn_off,garage,Garage Outlet,off,"Turn off the .","The garage outlet is already off." diff --git a/data/piles/spanish/pile_of_failed_tool_calls.csv b/data/piles/spanish/pile_of_failed_tool_calls.csv new file mode 100644 index 0000000..86e08d9 --- /dev/null +++ b/data/piles/spanish/pile_of_failed_tool_calls.csv @@ -0,0 +1,4 @@ +service_name,correct_device_name,correct_friendly_name,bad_device_name,phrase,error_result,retry_prompt +light.turn_on,light.living_room_lamp,Living Room Lamp,light.livng_room_lamp,"Please turn on the .","Error: Entity light.livng_room_lamp not found.","Trying the living room lamp instead." +climate.set_temperature,climate.hallway,Hallway Thermostat,climate.halway,"Set the to degrees.","Error: Entity climate.halway not found.","Retrying with the hallway thermostat." +fan.turn_off,fan.office,Office Fan,fan.offce,"Turn off the .","Error: Entity fan.offce not found.","I'll try the office fan again." diff --git a/data/piles/spanish/pile_of_refusals.csv b/data/piles/spanish/pile_of_refusals.csv new file mode 100644 index 0000000..4b57822 --- /dev/null +++ b/data/piles/spanish/pile_of_refusals.csv @@ -0,0 +1,4 @@ +reason_type,service_name,device_name,friendly_name,desired_state,phrase,response +not_available,lock.lock,back_door,Back Door Lock,,"Lock the .","I can't find a back door lock to control." +already_state,light.turn_on,hallway,Hallway Lights,on,"Turn on the .","The hallway lights are already on." +already_state,switch.turn_off,garage,Garage Outlet,off,"Turn off the .","The garage outlet is already off." diff --git a/data/utils.py b/data/utils.py index 94f688d..63f88a8 100644 --- a/data/utils.py +++ b/data/utils.py @@ -147,6 +147,20 @@ class DatasetPiles: reader = csv.DictReader(f) self.pile_of_hallucinated_service_names = list(reader) + failed_tool_calls_path = f"{cwd}/piles/{language}/pile_of_failed_tool_calls.csv" + self.pile_of_failed_tool_calls = [] + if os.path.exists(failed_tool_calls_path): + with open(failed_tool_calls_path, encoding="utf8") as f: + reader = csv.DictReader(f) + self.pile_of_failed_tool_calls = list(reader) + + refusals_path = f"{cwd}/piles/{language}/pile_of_refusals.csv" + self.pile_of_refusals = [] + if os.path.exists(refusals_path): + with open(refusals_path, encoding="utf8") as f: + reader = csv.DictReader(f) + self.pile_of_refusals = list(reader) + def __getitem__(self, key): return getattr(self, key) @@ -158,4 +172,4 @@ def get_dataset_piles(language: str) -> DatasetPiles: "light", "switch", "fan", "garage_door", "blinds", "lock","media_player", "climate", "vacuum", "timer", "todo", ], language) - return _piles_cache[language] \ No newline at end of file + return _piles_cache[language]