mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 05:14:02 -05:00
feat: add failure and refusal dataset examples
Co-authored-by: acon96 <35843486+acon96@users.noreply.github.com>
This commit is contained in:
@@ -521,6 +521,139 @@ def generate_status_request(template: dict, persona: str, language: str, max_dev
|
||||
else:
|
||||
return result
|
||||
|
||||
def generate_tool_failure_example(failure_case: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
|
||||
piles = get_dataset_piles(language)
|
||||
service_name = failure_case["service_name"]
|
||||
device_type = service_name.split(".")[0]
|
||||
service_action = service_name.split(".")[1]
|
||||
target_device = failure_case["correct_device_name"]
|
||||
friendly_name = failure_case.get("correct_friendly_name", target_device.split(".")[1].replace("_", " ").title())
|
||||
bad_device = failure_case["bad_device_name"]
|
||||
|
||||
question_template = failure_case["phrase"]
|
||||
question = question_template.replace("<device_name>", friendly_name).lower()
|
||||
|
||||
device_list, device_types, extra_exposed_attributes = random_device_list(
|
||||
max_devices=max_devices, avoid_device_names=[target_device], language=language)
|
||||
|
||||
state = SUPPORTED_DEVICES[device_type].get_random_state(language, extra_exposed_attributes=extra_exposed_attributes)
|
||||
device_list.insert(random.randint(0, len(device_list)), format_device_line(
|
||||
device_name=target_device,
|
||||
friendly_name=friendly_name,
|
||||
state=state
|
||||
))
|
||||
if device_type not in device_types:
|
||||
device_types.append(device_type)
|
||||
|
||||
available_tools = []
|
||||
for x in set(device_types):
|
||||
available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))
|
||||
available_tools = list(dict.fromkeys(available_tools))
|
||||
|
||||
response_starting, response_confirmed = get_random_response(
|
||||
piles.pile_of_responses,
|
||||
service=service_name,
|
||||
persona=persona,
|
||||
question_template=question_template,
|
||||
short=False
|
||||
)
|
||||
response_starting = response_starting.replace("<device_name>", friendly_name)
|
||||
response_confirmed = response_confirmed.replace("<device_name>", friendly_name)
|
||||
|
||||
tool_args_extra = {}
|
||||
if "climate" in service_action:
|
||||
if "<temp_f>" in question or "<temp_f>" in response_starting or "<temp_f>" in response_confirmed:
|
||||
temp_f = generate_random_parameter("temp_f", piles)
|
||||
question = question.replace("<temp_f>", str(temp_f))
|
||||
response_starting = response_starting.replace("<temp_f>", str(temp_f))
|
||||
response_confirmed = response_confirmed.replace("<temp_f>", str(temp_f))
|
||||
tool_args_extra["temperature"] = temp_f
|
||||
if "<temp_c>" in question or "<temp_c>" in response_starting or "<temp_c>" in response_confirmed:
|
||||
temp_c = generate_random_parameter("temp_c", piles)
|
||||
question = question.replace("<temp_c>", str(temp_c))
|
||||
response_starting = response_starting.replace("<temp_c>", str(temp_c))
|
||||
response_confirmed = response_confirmed.replace("<temp_c>", str(temp_c))
|
||||
tool_args_extra["temperature"] = temp_c
|
||||
|
||||
retry_prompt = failure_case.get("retry_prompt", f"Trying again with {friendly_name}.").replace("<device_name>", friendly_name)
|
||||
error_result = failure_case.get("error_result", "Error").replace("<device_name>", friendly_name)
|
||||
|
||||
tool_name = SERVICE_TO_TOOL_MAP.get(service_action, TOOL_TURN_ON)
|
||||
first_args = {"entity_id": bad_device} if use_service_names else {"name": bad_device}
|
||||
retry_args = {"entity_id": target_device} if use_service_names else {"name": target_device}
|
||||
first_args.update(tool_args_extra)
|
||||
retry_args.update(tool_args_extra)
|
||||
|
||||
tool_call_sequence = [
|
||||
{
|
||||
"answer_starting": response_starting,
|
||||
"tool_calls": [{
|
||||
"tool_name": tool_name,
|
||||
"service_name": service_name,
|
||||
"tool_args": first_args
|
||||
}],
|
||||
"tool_results": [{
|
||||
"tool_name": service_name if use_service_names else tool_name,
|
||||
"tool_result": error_result
|
||||
}]
|
||||
},
|
||||
{
|
||||
"answer_starting": retry_prompt,
|
||||
"tool_calls": [{
|
||||
"tool_name": tool_name,
|
||||
"service_name": service_name,
|
||||
"tool_args": retry_args
|
||||
}]
|
||||
}
|
||||
]
|
||||
|
||||
return {
|
||||
"states": device_list,
|
||||
"available_tools": available_tools,
|
||||
"question": question,
|
||||
"answers": [response_confirmed],
|
||||
"tool_call_sequence": tool_call_sequence,
|
||||
"tool_calls": []
|
||||
}
|
||||
|
||||
def generate_refusal_example(refusal_case: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
|
||||
service_name = refusal_case["service_name"]
|
||||
device_type = service_name.split(".")[0]
|
||||
target_device = f"{device_type}.{refusal_case['device_name']}"
|
||||
friendly_name = refusal_case.get("friendly_name", refusal_case["device_name"].replace("_", " ").title())
|
||||
desired_state = refusal_case.get("desired_state", "")
|
||||
reason_type = refusal_case.get("reason_type", "not_available")
|
||||
|
||||
device_list, device_types, extra_exposed_attributes = random_device_list(
|
||||
max_devices=max_devices, avoid_device_names=[target_device], language=language)
|
||||
|
||||
if reason_type == "already_state":
|
||||
state = desired_state if desired_state else SUPPORTED_DEVICES[device_type].possible_states[0][0]
|
||||
device_list.insert(random.randint(0, len(device_list)), format_device_line(
|
||||
device_name=target_device,
|
||||
friendly_name=friendly_name,
|
||||
state=state
|
||||
))
|
||||
if device_type not in device_types:
|
||||
device_types.append(device_type)
|
||||
|
||||
available_tools = []
|
||||
for x in set(device_types):
|
||||
available_tools.extend(SUPPORTED_DEVICES[x].get_all_tools(extra_exposed_attributes))
|
||||
available_tools = list(dict.fromkeys(available_tools))
|
||||
|
||||
response_text = refusal_case["response"].replace("<device_name>", friendly_name).lower()
|
||||
question = refusal_case["phrase"].replace("<device_name>", friendly_name).lower()
|
||||
|
||||
return {
|
||||
"states": device_list,
|
||||
"available_tools": available_tools,
|
||||
"question": question,
|
||||
"answers": [response_text],
|
||||
"answer_starting": "",
|
||||
"tool_calls": []
|
||||
}
|
||||
|
||||
def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format):
|
||||
piles = get_dataset_piles(language)
|
||||
sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts)
|
||||
@@ -528,24 +661,26 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
|
||||
answers = " ".join(example["answers"])
|
||||
answer_starting = example.get("answer_starting", "")
|
||||
|
||||
tool_call_sequence = example.get("tool_call_sequence")
|
||||
|
||||
tool_calls = []
|
||||
tool_results = []
|
||||
|
||||
# Add tool use blocks if there are tool calls
|
||||
if len(example["tool_calls"]) > 0:
|
||||
for tool_call in example["tool_calls"]:
|
||||
# Use service_name if in service mode, otherwise use tool_name
|
||||
call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
|
||||
tool_calls.append({
|
||||
"name": call_name,
|
||||
"arguments": json.dumps(tool_call["tool_args"])
|
||||
})
|
||||
if not tool_call_sequence:
|
||||
# Add tool use blocks if there are tool calls
|
||||
if len(example["tool_calls"]) > 0:
|
||||
for tool_call in example["tool_calls"]:
|
||||
# Use service_name if in service mode, otherwise use tool_name
|
||||
call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
|
||||
tool_calls.append({
|
||||
"name": call_name,
|
||||
"arguments": json.dumps(tool_call["tool_args"])
|
||||
})
|
||||
|
||||
tool_results.append({
|
||||
"tool_name": call_name,
|
||||
"tool_call_id": f"call_{len(tool_results) + 1}",
|
||||
"tool_result": "Success"
|
||||
})
|
||||
tool_results.append({
|
||||
"tool_name": call_name,
|
||||
"tool_call_id": f"call_{len(tool_results) + 1}",
|
||||
"tool_result": "Success"
|
||||
})
|
||||
|
||||
if append_user_instruction_prompt:
|
||||
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
|
||||
@@ -569,8 +704,68 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
|
||||
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
|
||||
}
|
||||
]
|
||||
|
||||
if len(tool_calls) > 0:
|
||||
|
||||
if tool_call_sequence:
|
||||
call_id_counter = 1
|
||||
for step in tool_call_sequence:
|
||||
step_tool_calls = []
|
||||
for tool_call in step.get("tool_calls", []):
|
||||
call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
|
||||
step_tool_calls.append({
|
||||
"name": call_name,
|
||||
"arguments": json.dumps(tool_call["tool_args"])
|
||||
})
|
||||
|
||||
assistant_block = {
|
||||
"role": "assistant",
|
||||
"content": [{ "type": "text", "text": step.get("answer_starting", "") }]
|
||||
}
|
||||
if step_tool_calls:
|
||||
assistant_block["tool_calls"] = [ { "function": tc } for tc in step_tool_calls ]
|
||||
conversation.append(assistant_block)
|
||||
|
||||
if step_tool_calls:
|
||||
provided_results = step.get("tool_results")
|
||||
step_tool_results = []
|
||||
if provided_results:
|
||||
for provided in provided_results:
|
||||
step_tool_results.append({
|
||||
"tool_name": provided.get("tool_name", step_tool_calls[0]["name"]),
|
||||
"tool_call_id": provided.get("tool_call_id", f"call_{call_id_counter}"),
|
||||
"tool_result": provided.get("tool_result", "Success")
|
||||
})
|
||||
call_id_counter += 1
|
||||
else:
|
||||
for call in step_tool_calls:
|
||||
step_tool_results.append({
|
||||
"tool_name": call["name"],
|
||||
"tool_call_id": f"call_{call_id_counter}",
|
||||
"tool_result": "Success"
|
||||
})
|
||||
call_id_counter += 1
|
||||
|
||||
if tool_response_format == "text":
|
||||
conversation.append({
|
||||
"role": "tool",
|
||||
"content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results]
|
||||
})
|
||||
elif tool_response_format == "functiongemma":
|
||||
conversation.append({
|
||||
"role": "tool",
|
||||
"content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results]
|
||||
})
|
||||
|
||||
if step.get("post_tool_response"):
|
||||
conversation.append({
|
||||
"role": "assistant",
|
||||
"content": [{ "type": "text", "text": step["post_tool_response"] }]
|
||||
})
|
||||
|
||||
conversation.append({
|
||||
"role": "assistant",
|
||||
"content": [{ "type": "text", "text": answers }],
|
||||
})
|
||||
elif len(tool_calls) > 0:
|
||||
assistant_starting_block = {
|
||||
"role": "assistant",
|
||||
"content": [{ "type": "text", "text": answer_starting }],
|
||||
@@ -617,7 +812,9 @@ def generate_sft_file(
|
||||
*,
|
||||
static_factor: float,
|
||||
template_factor: int,
|
||||
status_request_factor: int):
|
||||
status_request_factor: int,
|
||||
failure_factor: float = 1,
|
||||
refusal_factor: float = 1):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
piles = get_dataset_piles(language)
|
||||
@@ -649,6 +846,15 @@ def generate_sft_file(
|
||||
except NoResponseAvailableException as ex:
|
||||
missing_responses.add(str(ex))
|
||||
|
||||
for failure_case in tqdm(piles.pile_of_failed_tool_calls):
|
||||
try:
|
||||
run_factor_times(generate_tool_failure_example, generated_examples, failure_case, person, failure_factor, language)
|
||||
except NoResponseAvailableException as ex:
|
||||
missing_responses.add(str(ex))
|
||||
|
||||
for refusal_case in tqdm(piles.pile_of_refusals):
|
||||
run_factor_times(generate_refusal_example, generated_examples, refusal_case, person, refusal_factor, language)
|
||||
|
||||
for status_request in tqdm(piles.pile_of_status_requests):
|
||||
run_factor_times(generate_status_request, generated_examples, status_request, "assistant", status_request_factor, language)
|
||||
|
||||
|
||||
4
data/piles/english/pile_of_failed_tool_calls.csv
Normal file
4
data/piles/english/pile_of_failed_tool_calls.csv
Normal file
@@ -0,0 +1,4 @@
|
||||
service_name,correct_device_name,correct_friendly_name,bad_device_name,phrase,error_result,retry_prompt
|
||||
light.turn_on,light.living_room_lamp,Living Room Lamp,light.livng_room_lamp,"Please turn on the <device_name>.","Error: Entity light.livng_room_lamp not found.","Trying the living room lamp instead."
|
||||
climate.set_temperature,climate.hallway,Hallway Thermostat,climate.halway,"Set the <device_name> to <temp_f> degrees.","Error: Entity climate.halway not found.","Retrying with the hallway thermostat."
|
||||
fan.turn_off,fan.office,Office Fan,fan.offce,"Turn off the <device_name>.","Error: Entity fan.offce not found.","I'll try the office fan again."
|
||||
|
4
data/piles/english/pile_of_refusals.csv
Normal file
4
data/piles/english/pile_of_refusals.csv
Normal file
@@ -0,0 +1,4 @@
|
||||
reason_type,service_name,device_name,friendly_name,desired_state,phrase,response
|
||||
not_available,lock.lock,back_door,Back Door Lock,,"Lock the <device_name>.","I can't find a back door lock to control."
|
||||
already_state,light.turn_on,hallway,Hallway Lights,on,"Turn on the <device_name>.","The hallway lights are already on."
|
||||
already_state,switch.turn_off,garage,Garage Outlet,off,"Turn off the <device_name>.","The garage outlet is already off."
|
||||
|
4
data/piles/french/pile_of_failed_tool_calls.csv
Normal file
4
data/piles/french/pile_of_failed_tool_calls.csv
Normal file
@@ -0,0 +1,4 @@
|
||||
service_name,correct_device_name,correct_friendly_name,bad_device_name,phrase,error_result,retry_prompt
|
||||
light.turn_on,light.living_room_lamp,Living Room Lamp,light.livng_room_lamp,"Please turn on the <device_name>.","Error: Entity light.livng_room_lamp not found.","Trying the living room lamp instead."
|
||||
climate.set_temperature,climate.hallway,Hallway Thermostat,climate.halway,"Set the <device_name> to <temp_f> degrees.","Error: Entity climate.halway not found.","Retrying with the hallway thermostat."
|
||||
fan.turn_off,fan.office,Office Fan,fan.offce,"Turn off the <device_name>.","Error: Entity fan.offce not found.","I'll try the office fan again."
|
||||
|
4
data/piles/french/pile_of_refusals.csv
Normal file
4
data/piles/french/pile_of_refusals.csv
Normal file
@@ -0,0 +1,4 @@
|
||||
reason_type,service_name,device_name,friendly_name,desired_state,phrase,response
|
||||
not_available,lock.lock,back_door,Back Door Lock,,"Lock the <device_name>.","I can't find a back door lock to control."
|
||||
already_state,light.turn_on,hallway,Hallway Lights,on,"Turn on the <device_name>.","The hallway lights are already on."
|
||||
already_state,switch.turn_off,garage,Garage Outlet,off,"Turn off the <device_name>.","The garage outlet is already off."
|
||||
|
4
data/piles/german/pile_of_failed_tool_calls.csv
Normal file
4
data/piles/german/pile_of_failed_tool_calls.csv
Normal file
@@ -0,0 +1,4 @@
|
||||
service_name,correct_device_name,correct_friendly_name,bad_device_name,phrase,error_result,retry_prompt
|
||||
light.turn_on,light.living_room_lamp,Living Room Lamp,light.livng_room_lamp,"Please turn on the <device_name>.","Error: Entity light.livng_room_lamp not found.","Trying the living room lamp instead."
|
||||
climate.set_temperature,climate.hallway,Hallway Thermostat,climate.halway,"Set the <device_name> to <temp_f> degrees.","Error: Entity climate.halway not found.","Retrying with the hallway thermostat."
|
||||
fan.turn_off,fan.office,Office Fan,fan.offce,"Turn off the <device_name>.","Error: Entity fan.offce not found.","I'll try the office fan again."
|
||||
|
4
data/piles/german/pile_of_refusals.csv
Normal file
4
data/piles/german/pile_of_refusals.csv
Normal file
@@ -0,0 +1,4 @@
|
||||
reason_type,service_name,device_name,friendly_name,desired_state,phrase,response
|
||||
not_available,lock.lock,back_door,Back Door Lock,,"Lock the <device_name>.","I can't find a back door lock to control."
|
||||
already_state,light.turn_on,hallway,Hallway Lights,on,"Turn on the <device_name>.","The hallway lights are already on."
|
||||
already_state,switch.turn_off,garage,Garage Outlet,off,"Turn off the <device_name>.","The garage outlet is already off."
|
||||
|
4
data/piles/polish/pile_of_failed_tool_calls.csv
Normal file
4
data/piles/polish/pile_of_failed_tool_calls.csv
Normal file
@@ -0,0 +1,4 @@
|
||||
service_name,correct_device_name,correct_friendly_name,bad_device_name,phrase,error_result,retry_prompt
|
||||
light.turn_on,light.living_room_lamp,Living Room Lamp,light.livng_room_lamp,"Please turn on the <device_name>.","Error: Entity light.livng_room_lamp not found.","Trying the living room lamp instead."
|
||||
climate.set_temperature,climate.hallway,Hallway Thermostat,climate.halway,"Set the <device_name> to <temp_f> degrees.","Error: Entity climate.halway not found.","Retrying with the hallway thermostat."
|
||||
fan.turn_off,fan.office,Office Fan,fan.offce,"Turn off the <device_name>.","Error: Entity fan.offce not found.","I'll try the office fan again."
|
||||
|
4
data/piles/polish/pile_of_refusals.csv
Normal file
4
data/piles/polish/pile_of_refusals.csv
Normal file
@@ -0,0 +1,4 @@
|
||||
reason_type,service_name,device_name,friendly_name,desired_state,phrase,response
|
||||
not_available,lock.lock,back_door,Back Door Lock,,"Lock the <device_name>.","I can't find a back door lock to control."
|
||||
already_state,light.turn_on,hallway,Hallway Lights,on,"Turn on the <device_name>.","The hallway lights are already on."
|
||||
already_state,switch.turn_off,garage,Garage Outlet,off,"Turn off the <device_name>.","The garage outlet is already off."
|
||||
|
4
data/piles/spanish/pile_of_failed_tool_calls.csv
Normal file
4
data/piles/spanish/pile_of_failed_tool_calls.csv
Normal file
@@ -0,0 +1,4 @@
|
||||
service_name,correct_device_name,correct_friendly_name,bad_device_name,phrase,error_result,retry_prompt
|
||||
light.turn_on,light.living_room_lamp,Living Room Lamp,light.livng_room_lamp,"Please turn on the <device_name>.","Error: Entity light.livng_room_lamp not found.","Trying the living room lamp instead."
|
||||
climate.set_temperature,climate.hallway,Hallway Thermostat,climate.halway,"Set the <device_name> to <temp_f> degrees.","Error: Entity climate.halway not found.","Retrying with the hallway thermostat."
|
||||
fan.turn_off,fan.office,Office Fan,fan.offce,"Turn off the <device_name>.","Error: Entity fan.offce not found.","I'll try the office fan again."
|
||||
|
4
data/piles/spanish/pile_of_refusals.csv
Normal file
4
data/piles/spanish/pile_of_refusals.csv
Normal file
@@ -0,0 +1,4 @@
|
||||
reason_type,service_name,device_name,friendly_name,desired_state,phrase,response
|
||||
not_available,lock.lock,back_door,Back Door Lock,,"Lock the <device_name>.","I can't find a back door lock to control."
|
||||
already_state,light.turn_on,hallway,Hallway Lights,on,"Turn on the <device_name>.","The hallway lights are already on."
|
||||
already_state,switch.turn_off,garage,Garage Outlet,off,"Turn off the <device_name>.","The garage outlet is already off."
|
||||
|
@@ -147,6 +147,20 @@ class DatasetPiles:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_hallucinated_service_names = list(reader)
|
||||
|
||||
failed_tool_calls_path = f"{cwd}/piles/{language}/pile_of_failed_tool_calls.csv"
|
||||
self.pile_of_failed_tool_calls = []
|
||||
if os.path.exists(failed_tool_calls_path):
|
||||
with open(failed_tool_calls_path, encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_failed_tool_calls = list(reader)
|
||||
|
||||
refusals_path = f"{cwd}/piles/{language}/pile_of_refusals.csv"
|
||||
self.pile_of_refusals = []
|
||||
if os.path.exists(refusals_path):
|
||||
with open(refusals_path, encoding="utf8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
self.pile_of_refusals = list(reader)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
@@ -158,4 +172,4 @@ def get_dataset_piles(language: str) -> DatasetPiles:
|
||||
"light", "switch", "fan", "garage_door", "blinds",
|
||||
"lock","media_player", "climate", "vacuum", "timer", "todo",
|
||||
], language)
|
||||
return _piles_cache[language]
|
||||
return _piles_cache[language]
|
||||
|
||||
Reference in New Issue
Block a user