mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-09 21:58:00 -05:00
Start reworking training to use Axolotl instead of the custom script
This commit is contained in:
@@ -95,7 +95,7 @@ def random_device_list(max_devices: int, avoid_device_names: list[str], language
|
||||
|
||||
return device_lines, list(device_types), list(extra_exposed_attributes)
|
||||
|
||||
def generate_static_example(action: dict, persona: str, language: str, max_devices: int = 32, use_service_names: bool = False):
|
||||
def generate_static_example(action: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
|
||||
question = action["phrase"]
|
||||
service_name = action["service_name"]
|
||||
device_type = service_name.split(".")[0]
|
||||
@@ -136,19 +136,84 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
|
||||
short=False
|
||||
).lower()
|
||||
|
||||
response = response.replace("<device_name>", friendly_name)
|
||||
answer_list = [response]
|
||||
tool_args = {}
|
||||
|
||||
question = question.replace("<device_name>", target_device)
|
||||
answer_list = replace_answer(answer_list, "<device_name>", target_device)
|
||||
|
||||
if "climate" in service_action:
|
||||
if "<hvac_mode>" in question:
|
||||
hvac_mode = generate_random_parameter("hvac_mode", piles)
|
||||
question = question.replace("<hvac_mode>", hvac_mode)
|
||||
answer_list = replace_answer(answer_list, "<hvac_mode>", hvac_mode)
|
||||
# Add hvac_mode as a parameter for the climate tool
|
||||
tool_args["hvac_mode"] = hvac_mode
|
||||
|
||||
if "<fan_mode>" in question:
|
||||
fan_mode = generate_random_parameter("fan_mode", piles)
|
||||
question = question.replace("<fan_mode>", fan_mode)
|
||||
answer_list = replace_answer(answer_list, "<fan_mode>", fan_mode)
|
||||
tool_args["fan_mode"] = fan_mode
|
||||
|
||||
if "<temp_f>" in question:
|
||||
temp_f = generate_random_parameter("temp_f", piles)
|
||||
question = question.replace("<temp_f>", str(temp_f))
|
||||
answer_list = replace_answer(answer_list, "<temp_f>", str(temp_f))
|
||||
tool_args["temperature"] = temp_f
|
||||
|
||||
if "<temp_c>" in question:
|
||||
temp_c = generate_random_parameter("temp_c", piles)
|
||||
question = question.replace("<temp_c>", str(temp_c))
|
||||
answer_list = replace_answer(answer_list, "<temp_c>", str(temp_c))
|
||||
tool_args["temperature"] = temp_c
|
||||
|
||||
if "<humidity>" in question:
|
||||
humidity = generate_random_parameter("humidity", piles)
|
||||
question = question.replace("<humidity>", str(humidity))
|
||||
answer_list = replace_answer(answer_list, "<humidity>", str(humidity))
|
||||
tool_args["humidity"] = humidity
|
||||
|
||||
if "light" in service_action:
|
||||
if "<brightness>" in question:
|
||||
brightness = generate_random_parameter("brightness", piles)
|
||||
question = question.replace("<brightness>", str(brightness))
|
||||
answer_list = replace_answer(answer_list, "<brightness>", str(brightness))
|
||||
tool_args["brightness"] = brightness
|
||||
|
||||
if "<color>" in question:
|
||||
random_rgb = generate_random_parameter("rgb_color", piles)
|
||||
random_rgb_name = closest_color(random_rgb)
|
||||
question = question.replace("<color>", str(random_rgb_name))
|
||||
answer_list = replace_answer(answer_list, "<color>", str(random_rgb_name))
|
||||
tool_args["color"] = random_rgb_name
|
||||
|
||||
if "timer" in service_action:
|
||||
if "<duration>" in question:
|
||||
duration = generate_random_parameter("duration", piles)
|
||||
duration_name = piles.pile_of_durations[duration]
|
||||
question = question.replace("<duration>", duration_name)
|
||||
answer_list = replace_answer(answer_list, "<duration>", duration_name)
|
||||
tool_args["duration"] = str(duration)
|
||||
|
||||
if "todo" in service_action:
|
||||
if "<todo>" in question:
|
||||
todo = generate_random_parameter("todo", piles)
|
||||
question = question.replace("<todo>", todo)
|
||||
answer_list = replace_answer(answer_list, "<todo>", todo)
|
||||
tool_args["item"] = todo
|
||||
|
||||
if use_service_names:
|
||||
tool_call = {
|
||||
"tool_name": tool_name,
|
||||
"service_name": service_name,
|
||||
"tool_args": {"entity_id": target_device}
|
||||
"tool_args": {"entity_id": target_device, **tool_args}
|
||||
}
|
||||
else:
|
||||
tool_call = {
|
||||
"tool_name": tool_name,
|
||||
"service_name": service_name,
|
||||
"tool_args": {"name": target_device}
|
||||
"tool_args": {"name": target_device, **tool_args}
|
||||
}
|
||||
|
||||
if "arguments" in action and action["arguments"]:
|
||||
@@ -163,7 +228,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
|
||||
"states": device_list,
|
||||
"available_tools": available_tools,
|
||||
"question": question.lower(),
|
||||
"answers": [ response ],
|
||||
"answers": answer_list,
|
||||
"tool_calls": [ tool_call ]
|
||||
}
|
||||
|
||||
@@ -173,7 +238,7 @@ def replace_answer(list_of_answer, var, value):
|
||||
new_list.append(answer.replace(var, value))
|
||||
return new_list
|
||||
|
||||
def generate_templated_example(template: dict, persona: str, language: str, max_devices: int = 32, use_service_names: bool = False):
|
||||
def generate_templated_example(template: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
|
||||
template_device_types: list[str] = template["device_type"].split("|")
|
||||
service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ]
|
||||
question_template: str = template["phrase"]
|
||||
@@ -353,7 +418,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
|
||||
"tool_calls": tool_calls
|
||||
}
|
||||
|
||||
def generate_status_request(template: dict, persona: str, language: str, max_devices: int = 32, return_target_device: bool = False, use_service_names: bool = False):
|
||||
def generate_status_request(template: dict, persona: str, language: str, max_devices: int = 128, return_target_device: bool = False, use_service_names: bool = False):
|
||||
device_type: str = template["device_type"]
|
||||
state_name: str = template["state"]
|
||||
question_template: str = template["phrase"]
|
||||
@@ -456,56 +521,72 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
|
||||
question = example["question"]
|
||||
answers = " ".join(example["answers"])
|
||||
|
||||
# Build assistant message with content blocks
|
||||
assistant_content = []
|
||||
|
||||
# Add text response
|
||||
assistant_content.append({
|
||||
"type": "text",
|
||||
"text": answers
|
||||
})
|
||||
tool_calls = []
|
||||
tool_results = []
|
||||
|
||||
# Add tool use blocks if there are tool calls
|
||||
if len(example["tool_calls"]) > 0:
|
||||
for tool_call in example["tool_calls"]:
|
||||
# Use service_name if in service mode, otherwise use tool_name
|
||||
call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
|
||||
assistant_content.append({
|
||||
"type": "tool_use",
|
||||
tool_calls.append({
|
||||
"name": call_name,
|
||||
"parameters": tool_call["tool_args"]
|
||||
"arguments": json.dumps(tool_call["tool_args"])
|
||||
})
|
||||
|
||||
tool_results.append({
|
||||
"tool_name": call_name,
|
||||
"tool_call_id": f"call_{len(tool_results) + 1}",
|
||||
"tool_result": "Success"
|
||||
})
|
||||
|
||||
if use_system_role:
|
||||
conversation = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": sys_prompt
|
||||
"content": [{"type": "text", "text": sys_prompt}]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": question
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": assistant_content
|
||||
},
|
||||
"content": [{ "type": "text", "text": question }]
|
||||
}
|
||||
]
|
||||
else:
|
||||
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "\n".join([ sys_prompt, user_instruction_words, question ])
|
||||
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, user_instruction_words, question ]) }]
|
||||
}
|
||||
]
|
||||
|
||||
if len(tool_calls) > 0:
|
||||
conversation.extend([
|
||||
{
|
||||
"role": "assistant",
|
||||
# FIXME: use the "confirmation" response here instead of a canned text
|
||||
"content": [{ "type": "text", "text": "I will perform the requested user action." }],
|
||||
"tool_calls": tool_calls
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [{ "type": "text", "text": json.dumps(result) } for result in tool_results]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": assistant_content
|
||||
"content": [{ "type": "text", "text": answers }],
|
||||
},
|
||||
]
|
||||
])
|
||||
else:
|
||||
conversation.extend([
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{ "type": "text", "text": answers }],
|
||||
}
|
||||
])
|
||||
|
||||
return {
|
||||
"conversations": conversation,
|
||||
"messages": conversation,
|
||||
"tools": SERVICE_TOOLS if use_service_names else HASS_TOOLS
|
||||
}
|
||||
|
||||
|
||||
@@ -66,6 +66,7 @@ def generate_random_parameter(param_name, piles_of_data):
|
||||
|
||||
return param_generator()
|
||||
|
||||
# FIXME: return 2 responses, 1 to confirm the action and one to confirm completion of the action
|
||||
def get_random_response(pile_of_responses, *, service: str, persona: str, question_template: str, short: bool) -> str:
|
||||
|
||||
required_vars = list(set([var for var in var_pattern.findall(question_template) if "device_name" not in var]))
|
||||
|
||||
Reference in New Issue
Block a user