start re-working training to use axlotl instead of the custom script

This commit is contained in:
Alex O'Connell
2025-11-30 22:29:08 -05:00
parent 04a5909214
commit 55f254149a
14 changed files with 280 additions and 1309 deletions

View File

@@ -95,7 +95,7 @@ def random_device_list(max_devices: int, avoid_device_names: list[str], language
return device_lines, list(device_types), list(extra_exposed_attributes)
def generate_static_example(action: dict, persona: str, language: str, max_devices: int = 32, use_service_names: bool = False):
def generate_static_example(action: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
question = action["phrase"]
service_name = action["service_name"]
device_type = service_name.split(".")[0]
@@ -136,19 +136,84 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
short=False
).lower()
response = response.replace("<device_name>", friendly_name)
answer_list = [response]
tool_args = {}
question = question.replace("<device_name>", target_device)
answer_list = replace_answer(answer_list, "<device_name>", target_device)
if "climate" in service_action:
if "<hvac_mode>" in question:
hvac_mode = generate_random_parameter("hvac_mode", piles)
question = question.replace("<hvac_mode>", hvac_mode)
answer_list = replace_answer(answer_list, "<hvac_mode>", hvac_mode)
# Add hvac_mode as temperature parameter for climate tool
tool_args["hvac_mode"] = hvac_mode
if "<fan_mode>" in question:
fan_mode = generate_random_parameter("fan_mode", piles)
question = question.replace("<fan_mode>", fan_mode)
answer_list = replace_answer(answer_list, "<fan_mode>", fan_mode)
tool_args["fan_mode"] = fan_mode
if "<temp_f>" in question:
temp_f = generate_random_parameter("temp_f", piles)
question = question.replace("<temp_f>", str(temp_f))
answer_list = replace_answer(answer_list, "<temp_f>", str(temp_f))
tool_args["temperature"] = temp_f
if "<temp_c>" in question:
temp_c = generate_random_parameter("temp_c", piles)
question = question.replace("<temp_c>", str(temp_c))
answer_list = replace_answer(answer_list, "<temp_c>", str(temp_c))
tool_args["temperature"] = temp_c
if "<humidity>" in question:
humidity = generate_random_parameter("humidity", piles)
question = question.replace("<humidity>", str(humidity))
answer_list = replace_answer(answer_list, "<humidity>", str(humidity))
tool_args["humidity"] = humidity
if "light" in service_action:
if "<brightness>" in question:
brightness = generate_random_parameter("brightness", piles)
question = question.replace("<brightness>", str(brightness))
answer_list = replace_answer(answer_list, "<brightness>", str(brightness))
tool_args["brightness"] = brightness
if "<color>" in question:
random_rgb = generate_random_parameter("rgb_color", piles)
random_rgb_name = closest_color(random_rgb)
question = question.replace("<color>", str(random_rgb_name))
answer_list = replace_answer(answer_list, "<color>", str(random_rgb_name))
tool_args["color"] = random_rgb_name
if "timer" in service_action:
if "<duration>" in question:
duration = generate_random_parameter("duration", piles)
duration_name = piles.pile_of_durations[duration]
question = question.replace("<duration>", duration_name)
answer_list = replace_answer(answer_list, "<duration>", duration_name)
tool_args["duration"] = str(duration)
if "todo" in service_action:
if "<todo>" in question:
todo = generate_random_parameter("todo", piles)
question = question.replace("<todo>", todo)
answer_list = replace_answer(answer_list, "<todo>", todo)
tool_args["item"] = todo
if use_service_names:
tool_call = {
"tool_name": tool_name,
"service_name": service_name,
"tool_args": {"entity_id": target_device}
"tool_args": {"entity_id": target_device, **tool_args}
}
else:
tool_call = {
"tool_name": tool_name,
"service_name": service_name,
"tool_args": {"name": target_device}
"tool_args": {"name": target_device, **tool_args}
}
if "arguments" in action and action["arguments"]:
@@ -163,7 +228,7 @@ def generate_static_example(action: dict, persona: str, language: str, max_devic
"states": device_list,
"available_tools": available_tools,
"question": question.lower(),
"answers": [ response ],
"answers": answer_list,
"tool_calls": [ tool_call ]
}
@@ -173,7 +238,7 @@ def replace_answer(list_of_answer, var, value):
new_list.append(answer.replace(var, value))
return new_list
def generate_templated_example(template: dict, persona: str, language: str, max_devices: int = 32, use_service_names: bool = False):
def generate_templated_example(template: dict, persona: str, language: str, max_devices: int = 128, use_service_names: bool = False):
template_device_types: list[str] = template["device_type"].split("|")
service_names: list[str] = [ f"{x}.{y}" for x, y in zip(template_device_types, template["service"].split("|")) ]
question_template: str = template["phrase"]
@@ -353,7 +418,7 @@ def generate_templated_example(template: dict, persona: str, language: str, max_
"tool_calls": tool_calls
}
def generate_status_request(template: dict, persona: str, language: str, max_devices: int = 32, return_target_device: bool = False, use_service_names: bool = False):
def generate_status_request(template: dict, persona: str, language: str, max_devices: int = 128, return_target_device: bool = False, use_service_names: bool = False):
device_type: str = template["device_type"]
state_name: str = template["state"]
question_template: str = template["phrase"]
@@ -456,56 +521,72 @@ def format_example_sharegpt(example, persona, language, use_system_role, use_ser
question = example["question"]
answers = " ".join(example["answers"])
# Build assistant message with content blocks
assistant_content = []
# Add text response
assistant_content.append({
"type": "text",
"text": answers
})
tool_calls = []
tool_results = []
# Add tool use blocks if there are tool calls
if len(example["tool_calls"]) > 0:
for tool_call in example["tool_calls"]:
# Use service_name if in service mode, otherwise use tool_name
call_name = tool_call.get("service_name", tool_call["tool_name"]) if use_service_names else tool_call["tool_name"]
assistant_content.append({
"type": "tool_use",
tool_calls.append({
"name": call_name,
"parameters": tool_call["tool_args"]
"arguments": json.dumps(tool_call["tool_args"])
})
tool_results.append({
"tool_name": call_name,
"tool_call_id": f"call_{len(tool_results) + 1}",
"tool_result": "Success"
})
if use_system_role:
conversation = [
{
"role": "system",
"content": sys_prompt
"content": [{"type": "text", "text": sys_prompt}]
},
{
"role": "user",
"content": question
},
{
"role": "assistant",
"content": assistant_content
},
"content": [{ "type": "text", "text": question }]
}
]
else:
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
conversation = [
{
"role": "user",
"content": "\n".join([ sys_prompt, user_instruction_words, question ])
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, user_instruction_words, question ]) }]
}
]
if len(tool_calls) > 0:
conversation.extend([
{
"role": "assistant",
# FIXME: use the "confirmation" response here instead of a canned text
"content": [{ "type": "text", "text": "I will perform the requested user action." }],
"tool_calls": tool_calls
},
{
"role": "tool",
"content": [{ "type": "text", "text": json.dumps(result) } for result in tool_results]
},
{
"role": "assistant",
"content": assistant_content
"content": [{ "type": "text", "text": answers }],
},
]
])
else:
conversation.extend([
{
"role": "assistant",
"content": [{ "type": "text", "text": answers }],
}
])
return {
"conversations": conversation,
"messages": conversation,
"tools": SERVICE_TOOLS if use_service_names else HASS_TOOLS
}

View File

@@ -66,6 +66,7 @@ def generate_random_parameter(param_name, piles_of_data):
return param_generator()
# FIXME: return 2 responses, 1 to confirm the action and one to confirm completion of the action
def get_random_response(pile_of_responses, *, service: str, persona: str, question_template: str, short: bool) -> str:
required_vars = list(set([var for var in var_pattern.findall(question_template) if "device_name" not in var]))