mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-09 21:58:00 -05:00
start re-writing dataset generation to use the new HA Assist API
This commit is contained in:
@@ -15,7 +15,7 @@ from typing import Final, Any, Callable, Optional
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import webcolors
|
import webcolors
|
||||||
|
|
||||||
# #### STATES ####
|
# STATES
|
||||||
STATE_ON: Final = "on"
|
STATE_ON: Final = "on"
|
||||||
STATE_OFF: Final = "off"
|
STATE_OFF: Final = "off"
|
||||||
STATE_ACTIVE: Final = "active"
|
STATE_ACTIVE: Final = "active"
|
||||||
@@ -41,6 +41,121 @@ STATE_CLEANING: Final = "cleaning"
|
|||||||
STATE_DOCKED: Final = "docked"
|
STATE_DOCKED: Final = "docked"
|
||||||
STATE_RETURNING: Final = "returning"
|
STATE_RETURNING: Final = "returning"
|
||||||
|
|
||||||
|
# TOOLS
|
||||||
|
TOOL_TURN_ON = "HassTurnOn"
|
||||||
|
TOOL_TURN_OFF = "HassTurnOff"
|
||||||
|
TOOL_SET_POSITION = "HassSetPosition"
|
||||||
|
TOOL_LIGHT_SET = "HassLightSet"
|
||||||
|
TOOL_SET_VOLUME = "HassSetVolume" # not implemented yet
|
||||||
|
TOOL_MEDIA_UNPAUSE = "HassMediaUnpause"
|
||||||
|
TOOL_MEDIA_PAUSE = "HassMediaPause"
|
||||||
|
TOOL_MEDIA_NEXT = "HassMediaNext" # not implemented yet
|
||||||
|
TOOL_VACUUM_START = "HassVacuumStart"
|
||||||
|
TOOL_VACUUM_RETURN_TO_BASE = "HassVacuumReturnToBase"
|
||||||
|
TOOL_LIST_ADD_ITEM = "HassListAddItem"
|
||||||
|
TOOL_START_TIMER = "HassStartTimer" # ignored if timers unsupported
|
||||||
|
TOOL_CANCEL_TIMER = "HassCancelTimer" # ignored if timers unsupported
|
||||||
|
TOOL_PAUSE_TIMER = "HassPauseTimer" # ignored if timers unsupported
|
||||||
|
TOOL_UNPAUSE_TIMER = "HassUnpauseTimer" # ignored if timers unsupported
|
||||||
|
TOOL_SET_HUMIDITY = "HassHumidifierSetpoint"
|
||||||
|
TOOL_SET_HUMIDIFIER_MODE = "HassHumidifierMode"
|
||||||
|
|
||||||
|
# TOOLS
|
||||||
|
HASS_TOOLS = [
|
||||||
|
{
|
||||||
|
"name": TOOL_TURN_ON,
|
||||||
|
"description": "Turns on/opens a device or entity",
|
||||||
|
"parameters": {
|
||||||
|
"properties": { "name": "string", "area": "string", "floor": "string" },
|
||||||
|
"required": []
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": TOOL_TURN_OFF,
|
||||||
|
"description": "Turns off/closes a device or entity",
|
||||||
|
"parameters": {
|
||||||
|
"properties": { "name": "string", "area": "string", "floor": "string" },
|
||||||
|
"required": []
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": TOOL_SET_POSITION,
|
||||||
|
"description": "Sets the position of a device or entity",
|
||||||
|
"parameters": {
|
||||||
|
"properties": { "name": "string", "area": "string", "floor": "string", "position": "integer" },
|
||||||
|
"required": [ "position"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": TOOL_LIST_ADD_ITEM,
|
||||||
|
"description": "Add item to a todo list",
|
||||||
|
"parameters": {
|
||||||
|
"properties": { "item": "string", "name": "string" },
|
||||||
|
"required": []
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": TOOL_SET_HUMIDITY,
|
||||||
|
"description": "Set desired humidity level",
|
||||||
|
"parameters": {
|
||||||
|
"properties": { "name": "string", "humidity": "integer" },
|
||||||
|
"required": [ "name", "humidity"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": TOOL_SET_HUMIDIFIER_MODE,
|
||||||
|
"description": "Set humidifier mode",
|
||||||
|
"parameters": {
|
||||||
|
"properties": { "name": "string", "mode": "string" },
|
||||||
|
"required": [ "name", "mode"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": TOOL_LIGHT_SET,
|
||||||
|
"description": "Sets the brightness or color of a light",
|
||||||
|
"parameters": {
|
||||||
|
"properties": { "name": "string", "area": "string", "floor": "string", "color": "string", "temperature": "integer", "brightness": "integer" },
|
||||||
|
"required": []
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": TOOL_MEDIA_UNPAUSE,
|
||||||
|
"description": "Resumes a media player",
|
||||||
|
"parameters": {
|
||||||
|
"properties": { "name": "string", "area": "string", "floor": "string" },
|
||||||
|
"required": []
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": TOOL_MEDIA_PAUSE,
|
||||||
|
"description": "Pauses a media player",
|
||||||
|
"parameters": {
|
||||||
|
"properties": { "name": "string", "area": "string", "floor": "string" },
|
||||||
|
"required": []
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": TOOL_VACUUM_START,
|
||||||
|
"description": "Starts a vacuum",
|
||||||
|
"parameters": {
|
||||||
|
"properties": { "name": "string", "area": "string", "floor": "string" },
|
||||||
|
"required": []
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": TOOL_VACUUM_RETURN_TO_BASE,
|
||||||
|
"description": "Returns a vacuum to base",
|
||||||
|
"parameters": {
|
||||||
|
"properties": { "name": "string", "area": "string", "floor": "string" },
|
||||||
|
"required": []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
# define piles for global access
|
# define piles for global access
|
||||||
pile_of_durations = None
|
pile_of_durations = None
|
||||||
pile_of_media_names = None
|
pile_of_media_names = None
|
||||||
@@ -83,96 +198,83 @@ def get_included_vars(response: str):
|
|||||||
|
|
||||||
return ",".join(sorted(result))
|
return ",".join(sorted(result))
|
||||||
|
|
||||||
|
RANDOM_PARAMETER_GENERATORS = {
|
||||||
|
"rgb_color": lambda: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
|
||||||
|
"brightness": lambda: random.randint(0, 100),
|
||||||
|
"fan_mode": lambda: random.choice(["On Low", "On High", "Auto Low", "Auto High", "Off"]),
|
||||||
|
"temp_f": lambda: random.randint(60, 80),
|
||||||
|
"temp_c": lambda: random.randint(15, 25),
|
||||||
|
"humidity": lambda: random.randint(10, 90),
|
||||||
|
"preset_mode": lambda: random.choice(["home", "eco", "away", "auto"]),
|
||||||
|
"hvac_mode": lambda: random.choice(["heat", "cool", "heat_cool", "off", "auto", "fan_only"]),
|
||||||
|
"media": lambda: random.choice(pile_of_media_names),
|
||||||
|
"volume": lambda: round(random.random(), 2),
|
||||||
|
"duration": lambda: random.choice(list(pile_of_durations.keys())),
|
||||||
|
"remaining": lambda: f"{random.randint(0, 3):02}:{random.randint(0, 60)}:{random.randint(0, 60)}",
|
||||||
|
"todo": lambda: random.choice(pile_of_todo_items),
|
||||||
|
}
|
||||||
|
|
||||||
|
def generate_random_parameter(param_name):
|
||||||
|
param_generator = RANDOM_PARAMETER_GENERATORS.get(param_name)
|
||||||
|
|
||||||
|
if not param_generator:
|
||||||
|
raise Exception(f"Unknown param to generate random value for {param_name}")
|
||||||
|
|
||||||
|
return param_generator()
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DeviceType:
|
class DeviceType:
|
||||||
name: str
|
name: str
|
||||||
possible_states: list[(str, float)]
|
possible_states: list[(str, float)]
|
||||||
services: dict[str, list]
|
|
||||||
random_parameter_generator: Optional[dict[str, Callable]] = None
|
|
||||||
|
|
||||||
def get_all_services(self, extra_exposed_attributes):
|
|
||||||
result = []
|
|
||||||
for service in self.services.keys():
|
|
||||||
args = set(extra_exposed_attributes).intersection(self.services[service])
|
|
||||||
result.append(f"{self.name}.{service}({','.join(args)})")
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_random_parameter(self, param_name):
|
|
||||||
return self.random_parameter_generator[param_name]()
|
|
||||||
|
|
||||||
def get_random_state(self, extra_exposed_attributes=[]):
|
def get_random_state(self, extra_exposed_attributes=[]):
|
||||||
states = [ x[0] for x in self.possible_states ]
|
states = [ x[0] for x in self.possible_states ]
|
||||||
weights = [ x[1] for x in self.possible_states ]
|
weights = [ x[1] for x in self.possible_states ]
|
||||||
return random.choices(states, weights=weights, k=1)[0]
|
return random.choices(states, weights=weights, k=1)[0]
|
||||||
|
|
||||||
|
# TODO: make services into a global "tools" concept since tools are not specific to device types
|
||||||
class LightDeviceType(DeviceType):
|
class LightDeviceType(DeviceType):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("light",
|
super().__init__("light",
|
||||||
possible_states=[
|
possible_states=[
|
||||||
(STATE_ON, 0.5),
|
(STATE_ON, 0.5),
|
||||||
(STATE_OFF, 0.5)
|
(STATE_OFF, 0.5)
|
||||||
],
|
]
|
||||||
services={
|
|
||||||
"turn_on": [ "rgb_color", "brightness" ],
|
|
||||||
"turn_off": [],
|
|
||||||
"toggle": []
|
|
||||||
},
|
|
||||||
random_parameter_generator={
|
|
||||||
"rgb_color": lambda: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
|
|
||||||
"brightness": lambda: random.randint(0, 100),
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_random_state(self, extra_exposed_attributes=[]):
|
def get_random_state(self, extra_exposed_attributes=[]):
|
||||||
state = super().get_random_state(extra_exposed_attributes=extra_exposed_attributes)
|
state = super().get_random_state(extra_exposed_attributes=extra_exposed_attributes)
|
||||||
|
|
||||||
if random.random() < 0.5 and "rgb_color" in extra_exposed_attributes:
|
if random.random() < 0.5 and "rgb_color" in extra_exposed_attributes:
|
||||||
random_rgb = self.get_random_parameter("rgb_color")
|
random_rgb = generate_random_parameter("rgb_color")
|
||||||
state = state + ";" + closest_color(random_rgb) + " " + str(random_rgb)
|
state = state + ";" + closest_color(random_rgb) + " " + str(random_rgb)
|
||||||
|
|
||||||
if random.random() < 0.7 and "brightness" in extra_exposed_attributes:
|
if random.random() < 0.7 and "brightness" in extra_exposed_attributes:
|
||||||
state = state + ";" + str(self.get_random_parameter("brightness")) + "%"
|
state = state + ";" + str(generate_random_parameter("brightness")) + "%"
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
class ClimateDeviceType(DeviceType):
|
class ClimateDeviceType(DeviceType):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("climate", [], {
|
super().__init__("climate", [])
|
||||||
"turn_on": [],
|
|
||||||
"turn_off": [],
|
|
||||||
"toggle": [],
|
|
||||||
"set_temperature": ["temperature"],
|
|
||||||
"set_humidity": ["humidity"],
|
|
||||||
"set_fan_mode": ["fan_mode"],
|
|
||||||
"set_hvac_mode": ["hvac_mode"],
|
|
||||||
"set_preset_mode": ["preset_mode"]
|
|
||||||
},
|
|
||||||
random_parameter_generator={
|
|
||||||
"fan_mode": lambda: random.choice(["On Low", "On High", "Auto Low", "Auto High", "Off"]),
|
|
||||||
"temp_f": lambda: random.randint(60, 80),
|
|
||||||
"temp_c": lambda: random.randint(15, 25),
|
|
||||||
"humidity": lambda: random.randint(10, 90),
|
|
||||||
"preset_mode": lambda: random.choice(["home", "eco", "away", "auto"]),
|
|
||||||
"hvac_mode": lambda: random.choice(["heat", "cool", "heat_cool", "off", "auto", "fan_only"]),
|
|
||||||
})
|
|
||||||
|
|
||||||
def get_random_state(self, extra_exposed_attributes=[]):
|
def get_random_state(self, extra_exposed_attributes=[]):
|
||||||
"""state;fan_mode;temperature;humidity"""
|
"""state;fan_mode;temperature;humidity"""
|
||||||
state = self.get_random_parameter("hvac_mode")
|
state = generate_random_parameter("hvac_mode")
|
||||||
|
|
||||||
if "fan_mode" in extra_exposed_attributes:
|
if "fan_mode" in extra_exposed_attributes:
|
||||||
state = state + ";" + self.get_random_parameter("fan_mode")
|
state = state + ";" + generate_random_parameter("fan_mode")
|
||||||
if "temperature" in extra_exposed_attributes:
|
if "temperature" in extra_exposed_attributes:
|
||||||
if random.random() > 0.5:
|
if random.random() > 0.5:
|
||||||
state = state + ";" + str(self.get_random_parameter("temp_f")) + "F"
|
state = state + ";" + str(generate_random_parameter("temp_f")) + "F"
|
||||||
else:
|
else:
|
||||||
state = state + ";" + str(self.get_random_parameter("temp_c")) + "C"
|
state = state + ";" + str(generate_random_parameter("temp_c")) + "C"
|
||||||
if "humidity" in extra_exposed_attributes:
|
if "humidity" in extra_exposed_attributes:
|
||||||
state = state + ";" + str(self.get_random_parameter("humidity")) + "%"
|
state = state + ";" + str(generate_random_parameter("humidity")) + "%"
|
||||||
|
|
||||||
if random.random() < 0.8 and "preset_mode" in extra_exposed_attributes:
|
if random.random() < 0.8 and "preset_mode" in extra_exposed_attributes:
|
||||||
# if it is not "on a preset" then don't add the mode
|
# if it is not "on a preset" then don't add the mode
|
||||||
state = state + ";" + self.get_random_parameter("preset_mode")
|
state = state + ";" + generate_random_parameter("preset_mode")
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
@@ -186,33 +288,16 @@ class MediaPlayerDeviceType(DeviceType):
|
|||||||
(STATE_PAUSED, 0.05),
|
(STATE_PAUSED, 0.05),
|
||||||
(STATE_STANDBY, 0.05),
|
(STATE_STANDBY, 0.05),
|
||||||
(STATE_BUFFERING, 0.01),
|
(STATE_BUFFERING, 0.01),
|
||||||
], {
|
])
|
||||||
"turn_on": [],
|
|
||||||
"turn_off": [],
|
|
||||||
"toggle": [],
|
|
||||||
"volume_up": [],
|
|
||||||
"volume_down": [],
|
|
||||||
"volume_mute": [],
|
|
||||||
"media_play_pause": [],
|
|
||||||
"media_play": [],
|
|
||||||
"media_pause": [],
|
|
||||||
"media_stop": [],
|
|
||||||
"media_next_track": [],
|
|
||||||
"media_previous_track": []
|
|
||||||
},
|
|
||||||
random_parameter_generator={
|
|
||||||
"media": lambda: random.choice(pile_of_media_names),
|
|
||||||
"volume": lambda: round(random.random(), 2),
|
|
||||||
})
|
|
||||||
|
|
||||||
def get_random_state(self, extra_exposed_attributes=[]):
|
def get_random_state(self, extra_exposed_attributes=[]):
|
||||||
state = super().get_random_state(extra_exposed_attributes=extra_exposed_attributes)
|
state = super().get_random_state(extra_exposed_attributes=extra_exposed_attributes)
|
||||||
|
|
||||||
if "media_title" in extra_exposed_attributes and state in [STATE_PLAYING, STATE_PAUSED, STATE_BUFFERING, STATE_ON]:
|
if "media_title" in extra_exposed_attributes and state in [STATE_PLAYING, STATE_PAUSED, STATE_BUFFERING, STATE_ON]:
|
||||||
state = state + ";" + self.get_random_parameter("media")
|
state = state + ";" + generate_random_parameter("media")
|
||||||
|
|
||||||
if "volume_level" in extra_exposed_attributes and state != STATE_OFF:
|
if "volume_level" in extra_exposed_attributes and state != STATE_OFF:
|
||||||
state = state + ";vol=" + str(self.get_random_parameter("volume"))
|
state = state + ";vol=" + str(generate_random_parameter("volume"))
|
||||||
return state
|
return state
|
||||||
|
|
||||||
SUPPORTED_DEVICES = {
|
SUPPORTED_DEVICES = {
|
||||||
@@ -308,26 +393,11 @@ SUPPORTED_DEVICES = {
|
|||||||
(STATE_IDLE, 0.2),
|
(STATE_IDLE, 0.2),
|
||||||
(STATE_ACTIVE, 0.6),
|
(STATE_ACTIVE, 0.6),
|
||||||
(STATE_PAUSED, 0.1),
|
(STATE_PAUSED, 0.1),
|
||||||
],
|
]
|
||||||
services={
|
|
||||||
"start": ["duration"],
|
|
||||||
"pause": [],
|
|
||||||
"cancel": [],
|
|
||||||
},
|
|
||||||
random_parameter_generator={
|
|
||||||
"duration": lambda: random.choice(list(pile_of_durations.keys())),
|
|
||||||
"remaining": lambda: f"{random.randint(0, 3):02}:{random.randint(0, 60)}:{random.randint(0, 60)}"
|
|
||||||
}
|
|
||||||
),
|
),
|
||||||
"todo": DeviceType(
|
"todo": DeviceType(
|
||||||
name="todo",
|
name="todo",
|
||||||
possible_states=[ (f"{i}", (1/32)) for i in range(32) ],
|
possible_states=[ (f"{i}", (1/32)) for i in range(32) ],
|
||||||
services={
|
|
||||||
"add_item": ["item"],
|
|
||||||
},
|
|
||||||
random_parameter_generator={
|
|
||||||
"todo": lambda: random.choice(pile_of_todo_items),
|
|
||||||
}
|
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -500,7 +570,7 @@ def generate_static_example(action: dict, persona: str, max_devices: int = 32):
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"states": device_list,
|
"states": device_list,
|
||||||
"available_services": list(available_services),
|
"available_tools": list(available_services),
|
||||||
"question": question.lower(),
|
"question": question.lower(),
|
||||||
"answers": [ response ],
|
"answers": [ response ],
|
||||||
"service_calls": [ { "service": service_name, "target_device": target_device } ]
|
"service_calls": [ { "service": service_name, "target_device": target_device } ]
|
||||||
@@ -660,7 +730,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"states": device_list,
|
"states": device_list,
|
||||||
"available_services": list(available_services),
|
"available_tools": list(available_services),
|
||||||
"question": question.lower(),
|
"question": question.lower(),
|
||||||
"answers": [ sentence.lower() for sentence in answer ],
|
"answers": [ sentence.lower() for sentence in answer ],
|
||||||
"service_calls": service_calls
|
"service_calls": service_calls
|
||||||
@@ -749,7 +819,7 @@ def generate_status_request(template: dict, persona: str, max_devices: int = 32,
|
|||||||
|
|
||||||
result = {
|
result = {
|
||||||
"states": device_list,
|
"states": device_list,
|
||||||
"available_services": list(available_services),
|
"available_tools": list(available_services),
|
||||||
"question": question.lower(),
|
"question": question.lower(),
|
||||||
"answers": [ answer.lower() ],
|
"answers": [ answer.lower() ],
|
||||||
"service_calls": []
|
"service_calls": []
|
||||||
@@ -786,7 +856,7 @@ def generate_dpo_wrong_argument(template: dict, persona: str, max_devices: int =
|
|||||||
random_device = None
|
random_device = None
|
||||||
|
|
||||||
# random service should probably be "related"
|
# random service should probably be "related"
|
||||||
available_services = [ x[:-2] for x in example["available_services"] if call["service"] not in x ]
|
available_services = [ x[:-2] for x in example["available_tools"] if call["service"] not in x ]
|
||||||
hallucinated_services = [ x["hallucinated_service"] for x in pile_of_hallucinated_service_names if x["real_service"] == call["service"].split(".")[1]]
|
hallucinated_services = [ x["hallucinated_service"] for x in pile_of_hallucinated_service_names if x["real_service"] == call["service"].split(".")[1]]
|
||||||
random_service = random.choice(available_services + hallucinated_services)
|
random_service = random.choice(available_services + hallucinated_services)
|
||||||
random_argument = None # based on the service, add arguments that might be there like rgb, temperature, etc
|
random_argument = None # based on the service, add arguments that might be there like rgb, temperature, etc
|
||||||
@@ -823,7 +893,7 @@ def generate_dpo_extra_service_call(template: dict, persona: str, max_devices: i
|
|||||||
|
|
||||||
device_name = target_device["device_name"]
|
device_name = target_device["device_name"]
|
||||||
device_type = device_name.split(".")[0]
|
device_type = device_name.split(".")[0]
|
||||||
random_device_services = [ x for x in example["available_services"] if x.split(".")[0] == device_type ]
|
random_device_services = [ x for x in example["available_tools"] if x.split(".")[0] == device_type ]
|
||||||
|
|
||||||
if len(random_device_services) == 0:
|
if len(random_device_services) == 0:
|
||||||
raise NoServicesAvailableException()
|
raise NoServicesAvailableException()
|
||||||
@@ -835,45 +905,12 @@ def generate_dpo_extra_service_call(template: dict, persona: str, max_devices: i
|
|||||||
def generate_dpo_incorrect_persona(template: dict, persona: str, max_devices: int = 32):
|
def generate_dpo_incorrect_persona(template: dict, persona: str, max_devices: int = 32):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def format_example_raw_chatml(example, persona, language, use_system_role):
|
|
||||||
"""Don't use this one anymore"""
|
|
||||||
sys_prompt = pile_of_system_prompts[persona]
|
|
||||||
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"]))
|
|
||||||
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
|
|
||||||
question = example["question"]
|
|
||||||
answers = " ".join(example["answers"])
|
|
||||||
|
|
||||||
if use_system_role:
|
|
||||||
system_block = "\n".join([ "<|im_start|>system", sys_prompt, services_block, states_block ]) + "<|im_end|>"
|
|
||||||
user_block = "\n".join([ "<|im_start|>user", question]) + "<|im_end|>"
|
|
||||||
else:
|
|
||||||
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
|
|
||||||
system_block = ""
|
|
||||||
user_block = "\n".join([ "<|im_start|>user", sys_prompt, services_block, states_block, user_instruction_words, question]) + "<|im_end|>"
|
|
||||||
|
|
||||||
assistant_block = "<|im_start|>assistant\n" + answers
|
|
||||||
if len(example["service_calls"]) > 0:
|
|
||||||
json_calls = [ json.dumps(x) for x in example["service_calls"] ]
|
|
||||||
code_block = "\n```homeassistant\n" + "\n".join(json_calls) + "\n```"
|
|
||||||
assistant_block = assistant_block + code_block
|
|
||||||
assistant_block = assistant_block + "<|im_end|>"
|
|
||||||
|
|
||||||
example_lines = [system_block, user_block, assistant_block]
|
|
||||||
result = "\n".join(example_lines)
|
|
||||||
if "<device_name" in result:
|
|
||||||
print("bad templating")
|
|
||||||
|
|
||||||
# replace aliases with their actual values
|
|
||||||
result = result.replace("blinds.", "cover.")
|
|
||||||
result = result.replace("garage_door.", "cover.")
|
|
||||||
return { "text": result }
|
|
||||||
|
|
||||||
def format_example_sharegpt(example, persona, language, use_system_role):
|
def format_example_sharegpt(example, persona, language, use_system_role):
|
||||||
sys_prompt = pile_of_system_prompts[persona]
|
sys_prompt = pile_of_system_prompts[persona]
|
||||||
random_datetime = generate_random_datetime()
|
random_datetime = generate_random_datetime()
|
||||||
translate_datetime = babel.dates.format_datetime(random_datetime, BABEL_FORMAT[language], locale=BABEL_LOCALE[language])
|
translate_datetime = babel.dates.format_datetime(random_datetime, BABEL_FORMAT[language], locale=BABEL_LOCALE[language])
|
||||||
time_block = f"{CURRENT_DATE_PROMPT[language]} {translate_datetime}"
|
time_block = f"{CURRENT_DATE_PROMPT[language]} {translate_datetime}"
|
||||||
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"]))
|
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_tools"]))
|
||||||
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
|
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
|
||||||
question = example["question"]
|
question = example["question"]
|
||||||
answers = " ".join(example["answers"])
|
answers = " ".join(example["answers"])
|
||||||
@@ -887,18 +924,18 @@ def format_example_sharegpt(example, persona, language, use_system_role):
|
|||||||
# replace aliases with their actual values
|
# replace aliases with their actual values
|
||||||
assistant_block = assistant_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
assistant_block = assistant_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
||||||
states_block = states_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
states_block = states_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
||||||
services_block = services_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
tools_block = tools_block.replace("blinds.", "cover.").replace("garage_door.", "cover.")
|
||||||
|
|
||||||
if use_system_role:
|
if use_system_role:
|
||||||
conversation = [
|
conversation = [
|
||||||
{ "from": "system", "value": "\n".join([ sys_prompt, time_block, services_block, states_block ])},
|
{ "from": "system", "value": "\n".join([ sys_prompt, time_block, tools_block, states_block ])},
|
||||||
{ "from": "user", "value": question },
|
{ "from": "user", "value": question },
|
||||||
{ "from": "assistant", "value": assistant_block },
|
{ "from": "assistant", "value": assistant_block },
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
|
user_instruction_words = USER_INSTRUCTION_PROMPT[language] + ":"
|
||||||
conversation = [
|
conversation = [
|
||||||
{ "from": "user", "value": "\n".join([ sys_prompt, time_block, services_block, states_block, user_instruction_words, question ]) },
|
{ "from": "user", "value": "\n".join([ sys_prompt, time_block, tools_block, states_block, user_instruction_words, question ]) },
|
||||||
{ "from": "assistant", "value": assistant_block },
|
{ "from": "assistant", "value": assistant_block },
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -909,7 +946,7 @@ def format_example_dpo(example, persona, language):
|
|||||||
example = example["accepted"]
|
example = example["accepted"]
|
||||||
|
|
||||||
sys_prompt = pile_of_system_prompts[persona]
|
sys_prompt = pile_of_system_prompts[persona]
|
||||||
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_services"]))
|
services_block = f"{SERVICES_PROMPT[language]}: " + ", ".join(sorted(example["available_tools"]))
|
||||||
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
|
states_block = f"{DEVICES_PROMPT[language]}:\n" + "\n".join(example["states"])
|
||||||
question = example["question"]
|
question = example["question"]
|
||||||
|
|
||||||
@@ -1045,7 +1082,7 @@ def format_alpaca(example, format_func: Callable):
|
|||||||
|
|
||||||
text = format_func(example={
|
text = format_func(example={
|
||||||
"states": device_list,
|
"states": device_list,
|
||||||
"available_services": list(available_services),
|
"available_tools": list(available_services),
|
||||||
"question": question,
|
"question": question,
|
||||||
"answers": [ answer ],
|
"answers": [ answer ],
|
||||||
"service_calls": []
|
"service_calls": []
|
||||||
@@ -1181,10 +1218,7 @@ def main(args=None):
|
|||||||
print("Train size was provided but not generating the training set!")
|
print("Train size was provided but not generating the training set!")
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
if not args.format or args.format == "raw":
|
format_func = format_example_sharegpt
|
||||||
format_func = format_example_raw_chatml
|
|
||||||
elif args.format == "sharegpt":
|
|
||||||
format_func = format_example_sharegpt
|
|
||||||
|
|
||||||
use_system_role = not args.no_system_role
|
use_system_role = not args.no_system_role
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user