Mirror of https://github.com/acon96/home-llm.git (synced 2026-01-10 14:18:00 -05:00)

Merge branch 'feature/proper-functioncalling-args' into develop

.gitignore (vendored): 2 changed lines
@@ -7,3 +7,5 @@ data/*.json
*.pyc
main.log
.venv
*.xlsx
notes.txt
README.md: 15 changed lines
@@ -6,12 +6,13 @@ The "Home" models are a fine tuning of the Phi model series from Microsoft. The

The latest models can be found on HuggingFace:
3B v2 (Based on Phi-2): https://huggingface.co/acon96/Home-3B-v2-GGUF
1B v1 (Based on Phi-1.5): https://huggingface.co/acon96/Home-1B-v1-GGUF
1B v2 (Based on Phi-1.5): https://huggingface.co/acon96/Home-1B-v2-GGUF

Make sure you have `llama-cpp-python>=0.2.29` in order to run these models.

Old Models:
3B v1 (Based on Phi-2): https://huggingface.co/acon96/Home-3B-v1-GGUF
3B v1 (Based on Phi-2): https://huggingface.co/acon96/Home-3B-v1-GGUF
1B v1 (Based on Phi-1.5): https://huggingface.co/acon96/Home-1B-v1-GGUF

The main difference between the 2 models (besides parameter count) is the training data. The 1B model is ONLY trained on the synthetic dataset provided in this project, while the 3B model is trained on a mixture of this synthetic dataset, and the cleaned Stanford Alpaca dataset.

@@ -23,11 +24,12 @@ Example "system" prompt:
```
<|im_start|>system
You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.
Services: light.turn_off, light.turn_on, fan.turn_on, fan.turn_off
Services: light.turn_off(), light.turn_on(brightness,rgb_color), fan.turn_on(), fan.turn_off()
Devices:
light.office 'Office Light' = on
light.office 'Office Light' = on;80%
fan.office 'Office fan' = off
light.kitchen 'Kitchen Light' = on<|im_end|>
light.kitchen 'Kitchen Light' = on;80%;red
light.bedroom 'Bedroom Light' = off<|im_end|>
```

For more about how the model is prompted see [./docs/Model Prompting.md]
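To make the argument-aware format above concrete, here is a minimal sketch (not part of the README) of the call the model is expected to emit for a request against a prompt like the one above; the phrasing is illustrative, but the JSON shape matches what the integration's parser and the new evaluate.py expect.

```python
# Prompt lists services with their accepted arguments, e.g.
#   Services: light.turn_off(), light.turn_on(brightness,rgb_color), ...
#   Devices:  light.office 'Office Light' = on;80%
# For "Dim the office light to 50 percent." the assistant replies with a short
# acknowledgement plus a fenced ```homeassistant block, one JSON call per line:
expected_call = {
    "service": "light.turn_on",        # domain.service to invoke
    "target_device": "light.office",   # entity_id the call targets
    "brightness": 0.5,                 # extra argument; 0-1 here, rescaled to 0-255 by the integration
}
```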
@@ -69,11 +71,11 @@ python3 train.py \
--add_chatml_tokens \
--bf16 \
--train_dataset data/home_assistant_alpaca_merged_train.json \
--test_dataset data/home_assistant_alpaca_merged_test.json \
--learning_rate 1e-5 \
--save_steps 1000 \
--micro_batch_size 2 --gradient_checkpointing \
--ctx_size 2048 \
--group_by_length \
--use_lora --lora_rank 32 --lora_alpha 64 --lora_modules fc1,fc2,q_proj,v_proj,dense --lora_modules_to_save embed_tokens,lm_head --lora_merge
```

@@ -87,7 +89,6 @@ python3 train.py \
--add_chatml_tokens \
--bf16 \
--train_dataset data/home_assistant_train.json \
--test_dataset data/home_assistant_test.json \
--learning_rate 1e-5 \
--micro_batch_size 4 --gradient_checkpointing \
--ctx_size 2048
@@ -11,6 +11,9 @@ import os
import json
import webcolors

import voluptuous as vol
from collections.abc import Iterable

import homeassistant.components.conversation as ha_conversation
from homeassistant.components.conversation import ConversationInput, ConversationResult, AbstractConversationAgent
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
@@ -169,6 +172,24 @@ def closest_color(requested_color):
        min_colors[(rd + gd + bd)] = name
    return min_colors[min(min_colors.keys())]

def flatten_schema(schema):
    flattened = []
    def _flatten(current_schema, prefix=''):
        if isinstance(current_schema, vol.Schema):
            if isinstance(current_schema.schema, vol.validators._WithSubValidators):
                for subval in current_schema.schema.validators:
                    _flatten(subval, prefix)
            else:
                for key, val in current_schema.schema.items():
                    _flatten(val, prefix + str(key) + '/')
        elif isinstance(current_schema, vol.validators._WithSubValidators):
            for subval in current_schema.validators:
                _flatten(subval, prefix)
        elif callable(current_schema):
            flattened.append(prefix[:-1] if prefix else prefix)
    _flatten(schema)
    return flattened
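For context, a minimal sketch of what the new `flatten_schema` helper returns for a simplified, hypothetical service schema (real Home Assistant service schemas are larger):

```python
import voluptuous as vol

# hypothetical light.turn_on-style schema
schema = vol.Schema({
    vol.Optional("brightness"): int,
    vol.Optional("rgb_color"): list,
})

flatten_schema(schema)
# -> ["brightness", "rgb_color"]
# The caller intersects this list with extra_attributes_to_expose to build
# strings such as "light.turn_on(rgb_color,brightness)" for the prompt.
```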
class LLaMAAgent(AbstractConversationAgent):
    """Local LLaMA conversation agent."""
@@ -216,6 +237,8 @@ class LLaMAAgent(AbstractConversationAgent):
|
||||
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
|
||||
refresh_system_prompt = self.entry.options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)
|
||||
service_call_regex = self.entry.options.get(CONF_SERVICE_CALL_REGEX, DEFAULT_SERVICE_CALL_REGEX)
|
||||
extra_attributes_to_expose = self.entry.options \
|
||||
.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE)
|
||||
|
||||
try:
|
||||
service_call_pattern = re.compile(service_call_regex)
|
||||
@@ -298,24 +321,37 @@ class LLaMAAgent(AbstractConversationAgent):
    service = json_output["service"]
    entity = json_output["target_device"]
    domain, service = tuple(service.split("."))
    extra_arguments = { k: v for k, v in json_output.items() if k not in [ "service", "target_device" ] }
except Exception:
    try:
        service = line.split("(")[0]
        entity = line.split("(")[1][:-1]
        domain, service = tuple(service.split("."))
        extra_arguments = {}
    except Exception:
        to_say += f" Failed to parse call from '{line}'!"
        continue

# fix certain arguments
# make sure brightness is 0-255 and not a percentage
if "brightness" in extra_arguments and 0.0 < extra_arguments["brightness"] < 1.0:
    extra_arguments["brightness"] = int(extra_arguments["brightness"] * 255)

# only acknowledge requests to exposed entities
if entity not in exposed_entities:
    to_say += f" Can't find device '{entity}'!"
else:
    # copy arguments to service call
    service_data = {ATTR_ENTITY_ID: entity}
    for attr in extra_attributes_to_expose:
        if attr in extra_arguments.keys():
            service_data[attr] = extra_arguments[attr]

    try:
        await self.hass.services.async_call(
            domain,
            service,
            service_data={ATTR_ENTITY_ID: entity},
            service_data=service_data,
            blocking=True,
        )
    except Exception as err:
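A minimal sketch of the two call formats this block accepts: the JSON form added on this branch, and the older `domain.service(entity)` form kept as a fallback (values are illustrative):

```python
import json

# New format: one JSON object per line inside the model's tool-call block
line = '{"service": "light.turn_on", "target_device": "light.office", "brightness": 0.8}'
json_output = json.loads(line)
domain, service = tuple(json_output["service"].split("."))    # ("light", "turn_on")
entity = json_output["target_device"]                          # "light.office"
extra_arguments = {k: v for k, v in json_output.items()
                   if k not in ["service", "target_device"]}   # {"brightness": 0.8}
# 0.0 < 0.8 < 1.0, so the brightness fix-up rescales it: int(0.8 * 255) == 204

# Legacy fallback, still parsed if JSON decoding fails
legacy = "light.turn_on(light.office)"
service = legacy.split("(")[0]        # "light.turn_on"
entity = legacy.split("(")[1][:-1]    # "light.office"
```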
@@ -386,7 +422,7 @@ class LLaMAAgent(AbstractConversationAgent):
|
||||
|
||||
value = attributes[attribute_name]
|
||||
if value is not None:
|
||||
if attribute_name == "current_temperature":
|
||||
if attribute_name == "temperature":
|
||||
value = int(value)
|
||||
if value > 50:
|
||||
value = f"{value}F"
|
||||
@@ -396,6 +432,10 @@ class LLaMAAgent(AbstractConversationAgent):
|
||||
value = F"{closest_color(value)} {value}"
|
||||
elif attribute_name == "volume_level":
|
||||
value = f"vol={int(value*100)}"
|
||||
elif attribute_name == "brightness":
|
||||
value = f"{int(value/255*100)}%"
|
||||
elif attribute_name == "humidity":
|
||||
value = f"{value}%"
|
||||
|
||||
result = result + ";" + str(value)
|
||||
return result
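For reference, a sketch of how the attribute fix-ups above render individual values into the `;`-separated device state string (example values assumed; fragments are appended in the order of the configured attribute list):

```python
# brightness 204        -> "80%"                int(204 / 255 * 100)
# rgb_color (255, 0, 0) -> "red (255, 0, 0)"    closest_color() prefixed to the raw tuple
# volume_level 0.35     -> "vol=35"
# humidity 45           -> "45%"
# temperature 72        -> "72F"                values above 50 are treated as Fahrenheit
# An exposed entity line then looks roughly like:
#   light.office 'Office Light' = on;80%;red (255, 0, 0)
```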
|
||||
@@ -407,10 +447,10 @@ class LLaMAAgent(AbstractConversationAgent):
|
||||
service_dict = self.hass.services.async_services()
|
||||
all_services = []
|
||||
for domain in domains:
|
||||
# all_services.extend(service_dict.get(domain, {}).keys())
|
||||
all_services.extend(
|
||||
[f"{domain}.{name}" for name in service_dict.get(domain, {}).keys()]
|
||||
)
|
||||
for name, service in service_dict.get(domain, {}).items():
|
||||
args = flatten_schema(service.schema)
|
||||
args_to_expose = set(args).intersection(extra_attributes_to_expose)
|
||||
all_services.append(f"{domain}.{name}({','.join(args_to_expose)})")
|
||||
formatted_services = ", ".join(all_services)
|
||||
|
||||
return template.Template(prompt_template, self.hass).async_render(
|
||||
|
||||
@@ -136,7 +136,7 @@ def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_q
|
||||
}
|
||||
)
|
||||
|
||||
def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, chat_model=None, use_chat_endpoint=None, webui_preset=None, webui_chat_mode=None):
|
||||
def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, chat_model=None, use_chat_endpoint=None, webui_preset="", webui_chat_mode=""):
|
||||
|
||||
extra1, extra2 = ({}, {})
|
||||
default_port = DEFAULT_PORT
|
||||
@@ -144,7 +144,7 @@ def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ch
|
||||
if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
|
||||
extra1[vol.Optional(CONF_TEXT_GEN_WEBUI_PRESET, default=webui_preset)] = str
|
||||
extra1[vol.Optional(CONF_TEXT_GEN_WEBUI_CHAT_MODE, default=webui_chat_mode)] = SelectSelector(SelectSelectorConfig(
|
||||
options=[TEXT_GEN_WEBUI_CHAT_MODE_CHAT, TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT],
|
||||
options=["", TEXT_GEN_WEBUI_CHAT_MODE_CHAT, TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT],
|
||||
translation_key=CONF_TEXT_GEN_WEBUI_CHAT_MODE,
|
||||
multiple=False,
|
||||
mode=SelectSelectorMode.DROPDOWN,
|
||||
@@ -485,7 +485,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
|
||||
) -> FlowResult:
|
||||
errors = {}
|
||||
backend_type = self.model_config[CONF_BACKEND_TYPE]
|
||||
schema = STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI)
|
||||
schema = STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type)
|
||||
|
||||
if user_input:
|
||||
try:
|
||||
@@ -499,7 +499,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
|
||||
if error_reason:
|
||||
errors["base"] = error_reason
|
||||
schema = STEP_REMOTE_SETUP_DATA_SCHEMA(
|
||||
True,
|
||||
backend_type,
|
||||
host=user_input[CONF_HOST],
|
||||
port=user_input[CONF_PORT],
|
||||
chat_model=user_input[CONF_CHAT_MODEL],
|
||||
|
||||
@@ -28,14 +28,14 @@ BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER = "llama_cpp_python_server"
|
||||
BACKEND_TYPE_OLLAMA = "ollama"
|
||||
DEFAULT_BACKEND_TYPE = BACKEND_TYPE_LLAMA_HF
|
||||
CONF_DOWNLOADED_MODEL_QUANTIZATION = "downloaded_model_quantization"
|
||||
CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS = ["Q8_0", "Q5_K_M", "Q4_K_M", "Q3_K_M"]
|
||||
CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS = ["F16", "Q8_0", "Q5_K_M", "Q4_K_M", "Q3_K_M"]
|
||||
DEFAULT_DOWNLOADED_MODEL_QUANTIZATION = "Q5_K_M"
|
||||
CONF_DOWNLOADED_MODEL_FILE = "downloaded_model_file"
|
||||
DEFAULT_DOWNLOADED_MODEL_FILE = ""
|
||||
DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = "5000"
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE = "extra_attributes_to_expose"
|
||||
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE = ["rgb_color", "current_temperature", "fan_mode", "media_title", "volume_level" ]
|
||||
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level"]
|
||||
GBNF_GRAMMAR_FILE = "output.gbnf"
|
||||
CONF_PROMPT_TEMPLATE = "prompt_template"
|
||||
PROMPT_TEMPLATE_CHATML = "chatml"
|
||||
|
||||
@@ -45,9 +45,16 @@ def closest_color(requested_color):
|
||||
class DeviceType:
|
||||
name: str
|
||||
possible_states: list[(str, float)]
|
||||
services: list[str]
|
||||
services: dict[str, list]
|
||||
|
||||
def get_random_state(self, **kwargs):
|
||||
def get_all_services(self, extra_exposed_attributes):
|
||||
result = []
|
||||
for service in self.services.keys():
|
||||
args = set(extra_exposed_attributes).intersection(self.services[service])
|
||||
result.append(f"{self.name}.{service}({','.join(args)})")
|
||||
return result
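A small sketch of what `get_all_services` produces for the light device type defined below, assuming its `DeviceType.name` is "light" (argument order inside the parentheses can vary because a set intersection is used):

```python
light = LightDeviceType()   # services = {"turn_on": ["rgb_color", "brightness"], "turn_off": [], "toggle": []}
light.get_all_services(["rgb_color", "brightness", "temperature"])
# -> ["light.turn_on(rgb_color,brightness)", "light.turn_off()", "light.toggle()"]
```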
|
||||
|
||||
def get_random_state(self, extra_exposed_attributes=[]):
|
||||
states = [ x[0] for x in self.possible_states ]
|
||||
weights = [ x[1] for x in self.possible_states ]
|
||||
return random.choices(states, weights=weights, k=1)[0]
|
||||
@@ -59,48 +66,62 @@ class LightDeviceType(DeviceType):
|
||||
(STATE_ON, 0.5),
|
||||
(STATE_OFF, 0.5)
|
||||
],
|
||||
services=[
|
||||
"turn_on",
|
||||
"turn_off",
|
||||
"toggle"
|
||||
],
|
||||
services={
|
||||
"turn_on": [ "rgb_color", "brightness" ],
|
||||
"turn_off": [],
|
||||
"toggle": []
|
||||
},
|
||||
)
|
||||
|
||||
def get_random_state(self, extra_exposed_attributes=[]):
|
||||
state = super().get_random_state(extra_exposed_attributes=extra_exposed_attributes)
|
||||
|
||||
|
||||
def get_random_state(self, force_rgb=False, force_brightness=False, **kwargs):
|
||||
state = super().get_random_state()
|
||||
|
||||
if random.random() < 0.05 or force_rgb:
|
||||
if random.random() < 0.5 and "rgb_color" in extra_exposed_attributes:
|
||||
random_rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
|
||||
state = state + ";" + closest_color(random_rgb) + " " + str(random_rgb)
|
||||
|
||||
if random.random() < 0.3 or force_brightness:
|
||||
if random.random() < 0.7 and "brightness" in extra_exposed_attributes:
|
||||
state = state + ";" + str(random.randint(0, 100)) + "%"
|
||||
|
||||
return state
|
||||
|
||||
class ClimateDeviceType(DeviceType):
|
||||
def __init__(self):
|
||||
super().__init__("climate", [], [
|
||||
"turn_on",
|
||||
"turn_off",
|
||||
"toggle",
|
||||
"set_temperature",
|
||||
"set_humidity",
|
||||
"set_fan_mode",
|
||||
"set_hvac_mode",
|
||||
"set_preset_mode"
|
||||
])
|
||||
super().__init__("climate", [], {
|
||||
"turn_on": [],
|
||||
"turn_off": [],
|
||||
"toggle": [],
|
||||
"set_temperature": ["temperature"],
|
||||
"set_humidity": ["humidity"],
|
||||
"set_fan_mode": ["fan_mode"],
|
||||
"set_hvac_mode": ["hvac_mode"],
|
||||
"set_preset_mode": ["preset_mode"]
|
||||
})
|
||||
|
||||
def get_random_state(self, **kwargs):
|
||||
hvac = random.choice(["heat", "cool", "heat_cool", "off", "auto", "fan_only"])
|
||||
fan = random.choice(["On Low", "On High", "Auto Low", "Auto High", "Off"])
|
||||
if random.random() > 0.5:
|
||||
temp = str(random.randint(60, 80)) + "F"
|
||||
else:
|
||||
temp = str(random.randint(15, 25)) + "C"
|
||||
return f"{hvac};{fan};{temp}"
|
||||
def get_random_state(self, extra_exposed_attributes=[]):
|
||||
"""state;fan_mode;temperature;humidity"""
|
||||
state = random.choice(["heat", "cool", "heat_cool", "off", "auto", "fan_only"])
|
||||
|
||||
if "fan_mode" in extra_exposed_attributes:
|
||||
state = state + ";" + random.choice(["On Low", "On High", "Auto Low", "Auto High", "Off"])
|
||||
if "temperature" in extra_exposed_attributes:
|
||||
if random.random() > 0.5:
|
||||
state = state + ";" + str(random.randint(60, 80)) + "F"
|
||||
else:
|
||||
state = state + ";" + str(random.randint(15, 25)) + "C"
|
||||
if "humidity" in extra_exposed_attributes:
|
||||
state = state + ";" + str(random.randint(10, 90)) + "%"
|
||||
|
||||
if "preset_mode" in extra_exposed_attributes:
|
||||
# if it is not "on a preset" then don't add the mode
|
||||
random_mode = random.choice(["home", "eco", "away", "auto", None, None, None])
|
||||
if random_mode:
|
||||
state = state + ";" + random_mode
|
||||
|
||||
return state
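With the default extra attributes exposed (fan_mode, temperature and humidity, but not preset_mode), the `state;fan_mode;temperature;humidity` format sketched in the docstring comes out like this (illustrative random values):

```python
ClimateDeviceType().get_random_state(
    extra_exposed_attributes=["fan_mode", "temperature", "humidity"])
# -> e.g. "cool;Auto Low;72F;45%"
```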
|
||||
|
||||
with open("piles/pile_of_media_names.csv") as f:
|
||||
pile_of_media_names = [ x.strip() for x in f.readlines() ]
|
||||
|
||||
class MediaPlayerDeviceType(DeviceType):
|
||||
def __init__(self):
|
||||
@@ -112,31 +133,29 @@ class MediaPlayerDeviceType(DeviceType):
|
||||
(STATE_PAUSED, 0.05),
|
||||
(STATE_STANDBY, 0.05),
|
||||
(STATE_BUFFERING, 0.01),
|
||||
], [
|
||||
"turn_on",
|
||||
"turn_off",
|
||||
"toggle",
|
||||
"volume_up",
|
||||
"volume_down",
|
||||
"volume_mute",
|
||||
"media_play_pause",
|
||||
"media_play",
|
||||
"media_pause",
|
||||
"media_stop",
|
||||
"media_next_track",
|
||||
"media_previous_track"
|
||||
])
|
||||
], {
|
||||
"turn_on": [],
|
||||
"turn_off": [],
|
||||
"toggle": [],
|
||||
"volume_up": [],
|
||||
"volume_down": [],
|
||||
"volume_mute": [],
|
||||
"media_play_pause": [],
|
||||
"media_play": [],
|
||||
"media_pause": [],
|
||||
"media_stop": [],
|
||||
"media_next_track": [],
|
||||
"media_previous_track": []
|
||||
})
|
||||
|
||||
|
||||
with open("piles/pile_of_media_names.csv") as f:
|
||||
self.media_names = [ x.strip() for x in f.readlines() ]
|
||||
def get_random_state(self, extra_exposed_attributes=[]):
|
||||
state = super().get_random_state(extra_exposed_attributes=extra_exposed_attributes)
|
||||
|
||||
def get_random_state(self, **kwargs):
|
||||
state = super().get_random_state()
|
||||
if "media_title" in extra_exposed_attributes and state in [STATE_PLAYING, STATE_PAUSED, STATE_BUFFERING, STATE_ON]:
|
||||
state = state + ";" + random.choice(pile_of_media_names)
|
||||
|
||||
if state in [STATE_PLAYING, STATE_PAUSED, STATE_BUFFERING, STATE_ON]:
|
||||
state = state + ";" + random.choice(self.media_names)
|
||||
|
||||
if state != STATE_OFF:
|
||||
if "volume_level" in extra_exposed_attributes and state != STATE_OFF:
|
||||
state = state + ";vol=" + str(round(random.random(), 2))
|
||||
return state
|
||||
|
||||
@@ -148,11 +167,11 @@ SUPPORTED_DEVICES = {
|
||||
(STATE_ON, 0.5),
|
||||
(STATE_OFF, 0.5)
|
||||
],
|
||||
services=[
|
||||
"turn_on",
|
||||
"turn_off",
|
||||
"toggle"
|
||||
],
|
||||
services={
|
||||
"turn_on": [],
|
||||
"turn_off": [],
|
||||
"toggle": []
|
||||
},
|
||||
),
|
||||
"fan": DeviceType(
|
||||
name="fan",
|
||||
@@ -160,13 +179,13 @@ SUPPORTED_DEVICES = {
|
||||
(STATE_ON, 0.5),
|
||||
(STATE_OFF, 0.5)
|
||||
],
|
||||
services=[
|
||||
"turn_on",
|
||||
"turn_off",
|
||||
"toggle",
|
||||
"increase_speed",
|
||||
"decrease_speed",
|
||||
],
|
||||
services={
|
||||
"turn_on": [],
|
||||
"turn_off": [],
|
||||
"toggle": [],
|
||||
"increase_speed": [],
|
||||
"decrease_speed": [],
|
||||
},
|
||||
),
|
||||
"garage_door": DeviceType(
|
||||
name="garage_door",
|
||||
@@ -176,12 +195,12 @@ SUPPORTED_DEVICES = {
|
||||
(STATE_OPENING, 0.01),
|
||||
(STATE_CLOSING, 0.01)
|
||||
],
|
||||
services=[
|
||||
"open_cover",
|
||||
"close_cover",
|
||||
"stop_cover",
|
||||
"toggle",
|
||||
],
|
||||
services={
|
||||
"open_cover": [],
|
||||
"close_cover": [],
|
||||
"stop_cover": [],
|
||||
"toggle": [],
|
||||
},
|
||||
),
|
||||
"blinds": DeviceType(
|
||||
name="blinds",
|
||||
@@ -191,12 +210,12 @@ SUPPORTED_DEVICES = {
|
||||
(STATE_OPENING, 0.01),
|
||||
(STATE_CLOSING, 0.01)
|
||||
],
|
||||
services=[
|
||||
"open_cover",
|
||||
"close_cover",
|
||||
"stop_cover",
|
||||
"toggle",
|
||||
],
|
||||
services={
|
||||
"open_cover": [],
|
||||
"close_cover": [],
|
||||
"stop_cover": [],
|
||||
"toggle": [],
|
||||
},
|
||||
),
|
||||
"lock": DeviceType(
|
||||
name="lock",
|
||||
@@ -204,10 +223,10 @@ SUPPORTED_DEVICES = {
|
||||
(STATE_LOCKED, 0.5),
|
||||
(STATE_UNLOCKED, 0.5),
|
||||
],
|
||||
services=[
|
||||
"lock",
|
||||
"unlock",
|
||||
],
|
||||
services={
|
||||
"lock": [],
|
||||
"unlock": [],
|
||||
},
|
||||
),
|
||||
"media_player": MediaPlayerDeviceType(),
|
||||
"climate": ClimateDeviceType()
|
||||
@@ -284,6 +303,9 @@ def random_device_list(max_devices: int, avoid_device_names: list[str]):
|
||||
device_types = set()
|
||||
device_list = []
|
||||
device_lines = []
|
||||
# TODO: randomly pick attributes for this list
|
||||
extra_exposed_attributes = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level"]
|
||||
|
||||
while len(device_list) < num_devices:
|
||||
choice = random.choice(possible_choices)
|
||||
if choice["device_name"] in device_list:
|
||||
@@ -298,7 +320,7 @@ def random_device_list(max_devices: int, avoid_device_names: list[str]):
|
||||
if avoid_climate and device_type == "climate":
|
||||
continue
|
||||
|
||||
state = SUPPORTED_DEVICES[device_type].get_random_state()
|
||||
state = SUPPORTED_DEVICES[device_type].get_random_state(extra_exposed_attributes=extra_exposed_attributes)
|
||||
device_lines.append(format_device_line(
|
||||
device_name=device_name,
|
||||
friendly_name=friendly_name,
|
||||
@@ -310,7 +332,7 @@ def random_device_list(max_devices: int, avoid_device_names: list[str]):
|
||||
print(f"bad device name: {choice}")
|
||||
print(repr(ex))
|
||||
|
||||
return device_lines, list(device_types)
|
||||
return device_lines, list(device_types), list(extra_exposed_attributes)
|
||||
|
||||
def generate_static_example(action: dict, max_devices: int = 32):
|
||||
question = action["english_phrase"]
|
||||
@@ -319,11 +341,12 @@ def generate_static_example(action: dict, max_devices: int = 32):
|
||||
service_name = f"{device_type}.{action['service_name']}"
|
||||
friendly_name = target_device.split(".")[1].replace("_", " ")
|
||||
|
||||
device_list, device_types = random_device_list(max_devices=max_devices, avoid_device_names=[target_device])
|
||||
device_list, device_types, extra_exposed_attributes = random_device_list(
|
||||
max_devices=max_devices, avoid_device_names=[target_device])
|
||||
|
||||
# insert our target device somewhere random in the list
|
||||
index = random.randint(0, len(device_list))
|
||||
state = SUPPORTED_DEVICES[device_type].get_random_state()
|
||||
state = SUPPORTED_DEVICES[device_type].get_random_state(extra_exposed_attributes=extra_exposed_attributes)
|
||||
|
||||
device_list.insert(index, format_device_line(
|
||||
device_name=target_device,
|
||||
@@ -334,7 +357,7 @@ def generate_static_example(action: dict, max_devices: int = 32):
|
||||
# gather a list of all available services
|
||||
available_services = []
|
||||
for x in set(device_types + [device_type]):
|
||||
available_services.extend([ f"{x}.{y}" for y in SUPPORTED_DEVICES[x].services ])
|
||||
available_services.extend(SUPPORTED_DEVICES[x].get_all_services(extra_exposed_attributes))
|
||||
|
||||
return {
|
||||
"states": device_list,
|
||||
@@ -357,17 +380,23 @@ def generate_templated_example(template: dict, max_devices: int = 32):
|
||||
device_dict["type"] = device_type
|
||||
chosen_devices.append(device_dict)
|
||||
|
||||
device_list, device_types = random_device_list(max_devices=max_devices, avoid_device_names=[d["device_name"] for d in chosen_devices])
|
||||
device_list, device_types, extra_exposed_attributes = random_device_list(
|
||||
max_devices=max_devices, avoid_device_names=[d["device_name"] for d in chosen_devices])
|
||||
|
||||
# insert our target device somewhere random in the list
|
||||
for device_dict in chosen_devices:
|
||||
index = random.randint(0, len(device_list))
|
||||
state_kwargs = {}
|
||||
if "<brightness>" in question_template:
|
||||
state_kwargs["force_brightness"] = True
|
||||
if "<color>" in question_template:
|
||||
state_kwargs["force_rgb"] = True
|
||||
state = SUPPORTED_DEVICES[device_dict["type"]].get_random_state(**state_kwargs)
|
||||
if "<brightness>" in question_template and "brightness" not in extra_exposed_attributes:
|
||||
extra_exposed_attributes.append("brightness")
|
||||
if "<color>" in question_template and "rgb_color" not in extra_exposed_attributes:
|
||||
extra_exposed_attributes.append("rgb_color")
|
||||
if ("<temp_f>" in question_template or "<temp_c>" in question_template) \
|
||||
and "temperature" not in extra_exposed_attributes:
|
||||
extra_exposed_attributes.append("temperature")
|
||||
if "<humidity>" in question_template and "humidity" not in extra_exposed_attributes:
|
||||
extra_exposed_attributes.append("humidity")
|
||||
|
||||
state = SUPPORTED_DEVICES[device_dict["type"]].get_random_state(extra_exposed_attributes=extra_exposed_attributes)
|
||||
device_name = device_dict["device_name"]
|
||||
friendly_name = device_dict["description"]
|
||||
|
||||
@@ -377,10 +406,10 @@ def generate_templated_example(template: dict, max_devices: int = 32):
|
||||
state=state
|
||||
))
|
||||
|
||||
# gather a list of all available services
|
||||
# gather a list of all available services with arguments
|
||||
available_services = []
|
||||
for x in set(device_types + template_device_types):
|
||||
available_services.extend([ f"{x}.{y}" for y in SUPPORTED_DEVICES[x].services ])
|
||||
available_services.extend(SUPPORTED_DEVICES[x].get_all_services(extra_exposed_attributes))
|
||||
|
||||
# generate the question
|
||||
if len(template_device_types) == 1:
|
||||
@@ -422,7 +451,7 @@ def generate_templated_example(template: dict, max_devices: int = 32):
|
||||
brightness = random.randint(0, 100)
|
||||
question = question.replace("<brightness>", str(brightness))
|
||||
answer = answer.replace("<brightness>", str(brightness))
|
||||
service_calls = [ { **call, "brightness_pct": brightness} for call in service_calls ]
|
||||
service_calls = [ { **call, "brightness": round(brightness / 100, 2) } for call in service_calls ]
|
||||
|
||||
if "<color>" in question:
|
||||
random_rgb = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
|
||||
@@ -452,7 +481,7 @@ def generate_status_request(template: dict, max_devices: int = 32):
|
||||
chosen_device = random.choice(stacks_of_device_names[device_type])
|
||||
|
||||
# build a random list of devices
|
||||
device_list, device_types = random_device_list(max_devices=max_devices, avoid_device_names=[ chosen_device["device_name"] ])
|
||||
device_list, device_types, extra_exposed_attributes = random_device_list(max_devices=max_devices, avoid_device_names=[ chosen_device["device_name"] ])
|
||||
|
||||
# insert our target device somewhere random in the list
|
||||
index = random.randint(0, len(device_list))
|
||||
@@ -489,16 +518,20 @@ def generate_status_request(template: dict, max_devices: int = 32):
|
||||
|
||||
if device_type == "media_player":
|
||||
volume = random.randint(0, 100)
|
||||
random_media = random.choice(pile_of_media_names)
|
||||
|
||||
answer = answer.replace("<volume>", str(volume) + "%")
|
||||
state_name = state_name.replace("<volume>", str(volume) + "%")
|
||||
|
||||
answer = answer.replace("<media>", random_media)
|
||||
state_name = state_name.replace("<media>", random_media)
|
||||
|
||||
device_list.insert(index, f"{chosen_device['device_name']} = {state_name}")
|
||||
|
||||
# gather a list of all available services
|
||||
available_services = []
|
||||
for x in set(device_types + [device_type]):
|
||||
available_services.extend([ f"{x}.{y}" for y in SUPPORTED_DEVICES[x].services ])
|
||||
available_services.extend(SUPPORTED_DEVICES[x].get_all_services(extra_exposed_attributes))
|
||||
|
||||
return {
|
||||
"states": device_list,
|
||||
@@ -512,8 +545,6 @@ def format_example(example):
|
||||
sys_prompt = "You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed or answer the following question with the information provided only."
|
||||
services_block = "Services: " + ", ".join(sorted(example["available_services"]))
|
||||
states_block = "Devices:\n" + "\n".join(example["states"])
|
||||
# question = "Request:\n" + example["question"]
|
||||
# answers = "Response:\n" + " ".join(example["answers"])
|
||||
question = example["question"]
|
||||
answers = " ".join(example["answers"])
|
||||
|
||||
@@ -569,16 +600,17 @@ def generate_example_file(filename: str, seed: int, *, static_factor: int, templ
|
||||
|
||||
def format_alpaca(example):
|
||||
question = example["instruction"]
|
||||
if example["input"]:
|
||||
if "input" in example and example["input"]:
|
||||
question = question + "\n" + example["input"]
|
||||
|
||||
answer = example["output"]
|
||||
|
||||
device_list, device_types = random_device_list(max_devices=32, avoid_device_names=[])
|
||||
device_list, device_types, extra_exposed_attributes = random_device_list(
|
||||
max_devices=32, avoid_device_names=[])
|
||||
|
||||
available_services = []
|
||||
for x in device_types:
|
||||
available_services.extend([ f"{x}.{y}" for y in SUPPORTED_DEVICES[x].services ])
|
||||
available_services.extend(SUPPORTED_DEVICES[x].get_all_services(extra_exposed_attributes))
|
||||
|
||||
text = format_example(example={
|
||||
"states": device_list,
|
||||
@@ -594,13 +626,13 @@ def format_alpaca(example):
|
||||
|
||||
return result
|
||||
|
||||
def merge_with_dataset(dataset_name, seed, outupt_name, format_function):
|
||||
def merge_with_dataset(dataset_name, seed, outupt_name, format_function, dataset_column_names):
|
||||
alpaca_dataset = load_dataset(dataset_name)["train"].train_test_split(test_size=0.1)
|
||||
home_assistant_dataset = load_dataset("json", data_files={ "train": "home_assistant_train.json", "test": "home_assistant_test.json" })
|
||||
|
||||
random.seed(seed)
|
||||
|
||||
alpaca_dataset = alpaca_dataset.map(format_function).remove_columns(["input", "output", "instruction"])
|
||||
alpaca_dataset = alpaca_dataset.map(format_function).remove_columns(dataset_column_names)
|
||||
|
||||
combined_dataset_train = concatenate_datasets([home_assistant_dataset["train"], alpaca_dataset["train"]]).shuffle(seed=42)
|
||||
combined_dataset_test = concatenate_datasets([home_assistant_dataset["test"], alpaca_dataset["test"]]).shuffle(seed=42)
|
||||
@@ -618,20 +650,36 @@ def main():
|
||||
parser.add_argument("--sample", action="store_true", help="Set this flag to enable generation of the train dataset.")
|
||||
parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset..")
|
||||
parser.add_argument("--train", action="store_true", help="Set this flag to enable generation of the train dataset.")
|
||||
parser.add_argument("--merge-alpaca", action="store_true", help="Set this flag to merge the generated datasets with the alpaca-cleaned dataset.")
|
||||
parser.add_argument("--merge", help="Set this flag to merge the generated datasets with the specified dataset.")
|
||||
train_size_group = parser.add_mutually_exclusive_group()
|
||||
train_size_group.add_argument('--small', action='store_const', const='small', dest='size')
|
||||
train_size_group.add_argument('--medium', action='store_const', const='medium', dest='size')
|
||||
train_size_group.add_argument('--large', action='store_const', const='large', dest='size')
|
||||
train_size_group.add_argument('--xl', action='store_const', const='xl', dest='size')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.sample:
|
||||
generate_example_file("sample", 42, static_factor=1, template_factor=1, status_request_factor=1)
|
||||
if args.train:
|
||||
# TODO: add small, medium, large cli clags
|
||||
generate_example_file("home_assistant_train", 42, static_factor=1, template_factor=10, status_request_factor=8)
|
||||
# generate_example_file("home_assistant_train", 42, static_factor=5, template_factor=15, status_request_factor=12)
|
||||
# generate_example_file("home_assistant_train", 42, static_factor=5, template_factor=20, status_request_factor=15)
|
||||
if args.size == "small":
|
||||
generate_example_file("home_assistant_train", 42, static_factor=1, template_factor=10, status_request_factor=8)
|
||||
elif args.size == "medium":
|
||||
generate_example_file("home_assistant_train", 42, static_factor=5, template_factor=15, status_request_factor=12)
|
||||
elif args.size == "large":
|
||||
generate_example_file("home_assistant_train", 42, static_factor=5, template_factor=20, status_request_factor=15)
|
||||
elif args.size == "xl":
|
||||
generate_example_file("home_assistant_train", 42, static_factor=7, template_factor=25, status_request_factor=18)
|
||||
else:
|
||||
raise Exception(f"Unrecognized dataset size: {args.size}")
|
||||
if args.test:
|
||||
generate_example_file("home_assistant_test", 12345, static_factor=0.25, template_factor=3, status_request_factor=2)
|
||||
if args.merge_alpaca:
|
||||
merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca)
|
||||
|
||||
if args.merge == "alpaca":
|
||||
merge_with_dataset("yahma/alpaca-cleaned", 42, "alpaca", format_alpaca, ["input", "output", "instruction"])
|
||||
elif args.merge == "wizardlm70k":
|
||||
merge_with_dataset("WizardLM/WizardLM_evol_instruct_70k", 42, "wizardlm70k", format_alpaca, ["output", "instruction"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -200,24 +200,24 @@ fan,on,How's the <device_name>,The <device_name> is on
|
||||
fan,off,Status of <device_name>,The <device_name> is off
|
||||
light,on,<device_name> condition,The <device_name> is on
|
||||
light,off,What's up with <device_name>,The <device_name> is off
|
||||
media_player,on;<volume>,"What's the volume on <device_name>?","The volume on <device_name> is at <volume_level>."
|
||||
media_player,on;<volume>,"What's the volume on <device_name>?","The volume on <device_name> is at <volume>."
|
||||
media_player,off,"Can you check the volume level of <device_name>?","<device_name> is off, so I can't check the volume level currently."
|
||||
media_player,idle,"I need the current volume of <device_name>, please.","The current volume of <device_name> is <volume_level>."
|
||||
media_player,playing,"Is the volume too loud on <device_name>?","The volume on <device_name> is currently set to <volume_level>."
|
||||
media_player,paused,"What was the volume set to before <device_name> was paused?","Before being paused, <device_name> had a volume level of <volume_level>."
|
||||
media_player,playing,"Could you increase <device_name>'s volume a bit?","I've increased the volume on <device_name>. It's now at <volume_level>."
|
||||
media_player,playing,"Is the volume muted on <device_name>?","The volume on <device_name> is not muted. It's at <volume_level>."
|
||||
media_player,on,"What's playing on <device_name> right now?","<media_name> is currently playing on <device_name>."
|
||||
media_player,idle,"I need the current volume of <device_name>, please.","The current volume of <device_name> is <volume>."
|
||||
media_player,playing,"Is the volume too loud on <device_name>?","The volume on <device_name> is currently set to <volume>."
|
||||
media_player,paused,"What was the volume set to before <device_name> was paused?","Before being paused, <device_name> had a volume level of <volume>."
|
||||
media_player,playing,"Could you increase <device_name>'s volume a bit?","I've increased the volume on <device_name>. It's now at <volume>."
|
||||
media_player,playing,"Is the volume muted on <device_name>?","The volume on <device_name> is not muted. It's at <volume>."
|
||||
media_player,on,"What's playing on <device_name> right now?","<media> is currently playing on <device_name>."
|
||||
media_player,off,"Is there anything playing on <device_name>?","Nothing is playing on <device_name> as it is currently off."
|
||||
media_player,idle,"Can you tell me what was last played on <device_name>?","The last media played on <device_name> was <media_name>."
|
||||
media_player,playing,"What is playing on <device_name> at the moment?","The song <media_name> is playing on <device_name> right now."
|
||||
media_player,paused,"What media was paused on <device_name>?","<media_name> was paused on <device_name>."
|
||||
media_player,idle,"Can you tell me what was last played on <device_name>?","The last media played on <device_name> was <media>."
|
||||
media_player,playing,"What is playing on <device_name> at the moment?","The song <media> is playing on <device_name> right now."
|
||||
media_player,paused,"What media was paused on <device_name>?","<media> was paused on <device_name>."
|
||||
media_player,on,"I hear something from <device_name>, what is it?","<device_name> is not currently playing anything."
|
||||
media_player,playing,"I like this track on <device_name>, what is it?","You're listening to <media_name> on <device_name>."
|
||||
media_player,paused,"Can you name the movie that's on <device_name>?","The movie currently paused on <device_name> is <media_name>."
|
||||
media_player,standby,"What was the last show watched on <device_name>?","The last show watched on <device_name> was <media_name>."
|
||||
media_player,buffering,"Which episode is on <device_name> right now?","The episode on <device_name> is <media_name>."
|
||||
light,on;<color>;<brightness>,"What's the brightness of <device_name>?","The brightness of <device_name> is currently at <brightness>%."
|
||||
media_player,playing,"I like this track on <device_name>, what is it?","You're listening to <media> on <device_name>."
|
||||
media_player,paused,"Can you name the movie that's on <device_name>?","The movie currently paused on <device_name> is <media>."
|
||||
media_player,standby,"What was the last show watched on <device_name>?","The last show watched on <device_name> was <media>."
|
||||
media_player,buffering,"Which episode is on <device_name> right now?","The episode on <device_name> is <media>."
|
||||
light,on;<color>;<brightness>,"What's the brightness of <device_name>?","The brightness of <device_name> is currently at <brightness>."
|
||||
light,on;<color>;<brightness>,"Can you tell me the color of <device_name>?","<device_name> is currently showing <color>."
|
||||
light,on;<color>;<brightness>,"Is <device_name> set to a bright or dim setting?","<device_name> is set to a <brightness> brightness level."
|
||||
light,on;<color>;<brightness>,"What hue is <device_name> displaying?","<device_name> is displaying a <color> hue."
|
||||
@@ -225,7 +225,7 @@ light,on;<color>;<brightness>,"How bright is <device_name> right now?","<device_
|
||||
light,on;<color>;<brightness>,"I'd like to know the current color of <device_name>.","The current color of <device_name> is <color>."
|
||||
light,on;<color>;<brightness>,"Is <device_name> more on the dim side or the bright side?","<device_name> is more on the <brightness> side."
|
||||
light,on;<color>;<brightness>,"What color setting is <device_name> on?","<device_name> is set to <color> color setting."
|
||||
light,on;<color>;<brightness>,"Could you check the brightness level of <device_name> for me?","Sure, the brightness level of <device_name> is <brightness>%."
|
||||
light,on;<color>;<brightness>,"Could you check the brightness level of <device_name> for me?","Sure, the brightness level of <device_name> is <brightness>."
|
||||
light,on;<color>;<brightness>,"Tell me, what's the current hue of <device_name>?","The current hue of <device_name> is <color>."
|
||||
light,off,"What's the brightness of <device_name>?","The device <device_name> is turned off."
|
||||
light,off,"Can you tell me the color of <device_name>?","<device_name> is turned off."
|
||||
|
||||
|
@@ -193,21 +193,31 @@ climate,set_hvac_mode,"Switch the system to cooling mode.","Switching to cooling
|
||||
climate,set_hvac_mode,"Can we set the HVAC to heat?","Setting the HVAC to heat.",8
|
||||
climate,set_hvac_mode,"Change the HVAC to automatic.","Changing HVAC to automatic mode.",8
|
||||
light,turn_on,"Set the brightness of <device_name> to <brightness>%.","Setting the brightness of <device_name> to <brightness>%.",8
|
||||
light,turn_on,"Can you make <device_name> <color>?","Turning <device_name> <color>.",8
|
||||
light,turn_on,"Change the color of <device_name> to <color>.","Changing the color of <device_name> to <color>.",8
|
||||
light,turn_on,"Dim <device_name> to <brightness> percent brightness.","Dimming <device_name> to <brightness>% brightness.",8
|
||||
light,turn_on,"Please set <device_name> to a <color> color.","Setting <device_name> to a <color> color.",8
|
||||
light,turn_on,"Brighten <device_name> to <brightness>.","Brightening <device_name> to <brightness>%.",8
|
||||
light,turn_on,"Turn <device_name> <color>.","Turning <device_name> <color>.",8
|
||||
light,turn_on,"I want <device_name> at a <color> setting.","Setting <device_name> to a <color>.",8
|
||||
light,turn_on,"Set <device_name> to a <color> color.","Setting <device_name> to a <color> color.",8
|
||||
light,turn_on,"Adjust <device_name> brightness to <brightness>.","Adjusting <device_name> brightness to <brightness>%.",8
|
||||
light,turn_on,"Increase <device_name>'s brightness to <brightness>.","Increasing <device_name>'s brightness to <brightness>%.",8
|
||||
light,turn_on,"Lower the brightness of <device_name> to <brightness>.","Lowering the brightness of <device_name> to <brightness>%.",8
|
||||
light,turn_on,"Adjust <device_name> brightness to <brightness>.","Adjusting <device_name> brightness to <brightness>.",8
|
||||
light,turn_on,"Increase <device_name>'s brightness to <brightness>.","Increasing <device_name>'s brightness to <brightness>.",8
|
||||
light,turn_on,"Lower the brightness of <device_name> to <brightness>.","Lowering the brightness of <device_name> to <brightness>.",8
|
||||
light,turn_on,"Can you set <device_name>'s brightness level to <brightness> percent?","Setting <device_name>'s brightness level to <brightness>%.",8
|
||||
light,turn_on,"I'd like <device_name> at <brightness> percent brightness, please.","Setting <device_name> to <brightness>% brightness.",8
|
||||
light,turn_on,"Can you make <device_name> <color>?","Turning <device_name> <color>.",8
|
||||
light,turn_on,"Change the color of <device_name> to <color>.","Changing the color of <device_name> to <color>.",8
|
||||
light,turn_on,"Change <device_name> to a <color> hue.","Changing <device_name> to a <color> hue.",8
|
||||
light,turn_on,"Set <device_name> to be <color>.","Setting <device_name> to be <color>.",8
|
||||
light,turn_on,"I want <device_name> to be <color>, please.","Setting <device_name> to be <color>.",8
|
||||
light,turn_on,"Can you make <device_name> shine in <color>?","Making <device_name> shine in <color>.",8
|
||||
light,turn_on,"Turn <device_name> to a <color> shade.","Turning <device_name> to a <color> shade.",8
|
||||
light,turn_on,"Turn <device_name> to a <color> shade.","Turning <device_name> to a <color> shade.",8
|
||||
light,turn_on,"Turn <device_name> <color>.","Turning <device_name> <color>.",8
|
||||
light,turn_on,"I want <device_name> at a <color> setting.","Setting <device_name> to a <color>.",8
|
||||
light,turn_on,"Set <device_name> to a <color> color.","Setting <device_name> to a <color> color.",8
|
||||
light,turn_on,"Please set <device_name> to a <color> color.","Setting <device_name> to a <color> color.",8
|
||||
light,turn_on,"Make <device_name> glow <color>.","Making <device_name> glow <color>.",8
|
||||
light,turn_on,"Could you turn <device_name> to <color>?","Turning <device_name> to <color>.",8
|
||||
light,turn_on,"Please change <device_name> to <color>.","Changing <device_name> to <color>.",8
|
||||
light,turn_on,"Adjust <device_name> to <color> color.","Adjusting <device_name> to <color> color.",8
|
||||
light,turn_on,"Switch <device_name> color to <color>.","Switching <device_name> color to <color>.",8
|
||||
light,turn_on,"Can <device_name> be <color> now?","Setting <device_name> to be <color>.",8
|
||||
light,turn_on,"Let's have <device_name> in <color>.","Setting <device_name> in <color>.",8
|
||||
light,turn_on,"I'd like <device_name> to change to <color>.","Changing <device_name> to <color>.",8
|
||||
light,turn_on,"Can <device_name> display a <color> light?","Making <device_name> display a <color> light.",8
|
||||
light,turn_on,"Set <device_name> color to <color>, please.","Setting <device_name> color to <color>.",8
|
||||
|
||||
|
@@ -1,3 +1,4 @@
# home-llm experiments (phi1.5)
rev1 - original test
- 1 epoch
- train ctx 1900
||||
@@ -217,7 +218,8 @@ rev 9 - reduced dataset size

------

home-1b-rev1
# Home 1B
## home-1b-rev1
- 1 epoch
- 2048 train ctx
- batch size 8
@@ -231,4 +233,98 @@ home-1b-rev1
+ it works OK with low temperatures
+ seems to handle the alpaca dataset not so well

home-1b-rev2
Eval results for existing models:
Home-1b-v1: 0.767816091954023
Home-3b-v2: 0.6908045977011494

## home-1b-rev5 series
- 1 epoch
- 2048 train ctx
- batch size 8
- learning rate 1e-5
- weight decay 0.1
- gradient clipping 1.0
- save model every 200 steps

home-1b-rev5
- dataset size: medium
- evaluation results:
- 200: 0.553448275862069
- 400: 0.7482758620689656 (+.19)
- 600: 0.8103448275862069 (+.06)
- 800: 0.8316091954022988 (+.02)
- 1000: 0.8396551724137931 (+.008)
- 1200: 0.8488505747126437 (+.009)
- Final (1467): 0.8494252873563218 (+.00005)

home-1b-rev5_1
- dataset size: small
- evaluation results:
- 200: 0.6057471264367816
- 400: 0.7494252873563219 (+.143)
- 600: 0.7683908045977011 (+.018)
- 800: 0.7729885057471264 (+.0046)
- Final (869): bad

home-1b-rev5_2
- dataset size: large
- evaluation results:
- 200: --
- 400: --
- 600: 0.8425287356321839
- 800: 0.8666666666666667
- 1000: 0.8770114942528736
- 1200: 0.8844827586206897
- 1400: 0.8879310344827587
- 1600: 0.8844827586206897
- Final (1848): 0.8833333333333333

home-3b-v3-rev1
- dataset size: large
- evaluation results: 0.9091954022988505

home-1b-rev6
- dataset size: large (fixed templates + function calling arguments; brightness is broken)
- evaluation results: 0.8254149971379507

home-1b-rev6_1
- dataset size: xl (fixed templates + function calling arguments; 0-255 brightness is broken)
- evaluation results:
- 400: 0.7240984544934173
- 800: 0.8311390955924441
- 1200: 0.8471665712650257
- 1600: 0.8597595878649112
- 2000: 0.8551803091013166
- Final (2322): 0.8586147681740126

home-1b-rev6_2 = Home-1B-v2-GGUF
- dataset size: large (change brightness back to percentages; increase color references by ~2x)
- evaluation results:
- 400: 0.7856064418721691
- 800: 0.864116759
- 1200: 0.882234524
- 1600: 0.885254152
- 2000: 0.8852541519879215
- Final (2048):

home-3b-v3-rev2 = Home-3B-v2-GGUF
- dataset size: xl + alpaca
- evaluation results: 0.8731756416708606

Home-3B-v2-GGUF:ha_only
- dataset size: large
- evaluation results: FAILED (again.....)

# Datasets

## SFT
Alpaca: https://huggingface.co/datasets/yahma/alpaca-cleaned
Alpaca (Translated): https://huggingface.co/datasets/saillab/taco-datasets
WizardLM 200k: https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k
WizardLM 70k: https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_70k
Huggingface Ultrachat 200k: https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k
OpenOrca Slim Deduped (363k): https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup

## DPO
Intel Orca DPO Pairs: https://huggingface.co/datasets/Intel/orca_dpo_pairs
Huggingface Ultrachat: https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized
evaluate.py (new file): 129 lines
@@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse, os, re, json
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
from tqdm import tqdm
|
||||
|
||||
CTX_SIZE = 2048
|
||||
|
||||
def tokenize(tokenizer, prompt):
|
||||
return tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=CTX_SIZE)
|
||||
|
||||
def generate(model, tokenizer, prompt):
|
||||
inputs = tokenize(tokenizer, prompt)
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(**inputs)
|
||||
text = tokenizer.batch_decode(outputs)
|
||||
return text
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Evaluate the function calling for a model")
|
||||
parser.add_argument("model")
|
||||
parser.add_argument("--dataset_file", default="./data/home_assistant_test.json")
|
||||
parser.add_argument("--split", default="<|im_start|>assistant")
|
||||
parser.add_argument("--batch-size", default=8)
|
||||
|
||||
args = parser.parse_args()
|
||||
model_folder = f"./models/{args.model}"
|
||||
split = args.split
|
||||
|
||||
dataset = load_dataset("json", data_files={ "train": args.dataset_file })["train"]
|
||||
|
||||
# filter out examples that are status requests
|
||||
dataset = dataset.filter(lambda example: "```homeassistant" in example["text"])
|
||||
|
||||
service_call_regex = re.compile(r"```homeassistant\n([\S \t\n]*?)```")
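To see what this regex extracts, here is a sketch with a hypothetical model response (in the dataset the tool-call block is fenced with three backticks and the `homeassistant` tag):

```python
sample = "Turning on the office light.\n```homeassistant\n" \
         '{"service": "light.turn_on", "target_device": "light.office"}\n```'
service_call_regex.findall(sample)
# -> ['{"service": "light.turn_on", "target_device": "light.office"}\n']
# Each captured block is then split into lines and json.loads()-ed below.
```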
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
print(f"Loading model from {model_folder}...")
|
||||
trained_model = AutoModelForCausalLM.from_pretrained(model_folder, trust_remote_code=True, torch_dtype=torch.bfloat16) #, code_revision="834565c23f9b28b96ccbeabe614dd906b6db551a")
|
||||
trained_tokenizer = AutoTokenizer.from_pretrained(model_folder, trust_remote_code=True, padding_side='left')
|
||||
|
||||
trained_model.generation_config = GenerationConfig(
|
||||
max_new_tokens=128,
|
||||
use_cache=True,
|
||||
do_sample=True,
|
||||
temperature=0.1,
|
||||
top_k=40,
|
||||
top_p=1.0,
|
||||
repetition_penalty=1.15,
|
||||
eos_token_id=trained_model.config.eos_token_id,
|
||||
pad_token_id=trained_model.config.pad_token_id,
|
||||
)
|
||||
|
||||
print("Evaluating...")
|
||||
batch_size = int(args.batch_size)
|
||||
correct_answers = 0
|
||||
total_answers = 0
|
||||
color_mismatches = 0
|
||||
|
||||
failed_examples = []
|
||||
with tqdm(total=len(dataset), desc="Accuracy") as pbar:
|
||||
for batch_start in range(0, len(dataset), batch_size):
|
||||
batch = dataset[batch_start:batch_start + batch_size]
|
||||
prompts = [ example.split(split)[0] + split for example in batch["text"] ]
|
||||
expected_responses = [ example.split(split)[1] for example in batch["text"] ]
|
||||
output = generate(trained_model, trained_tokenizer, prompts)
|
||||
|
||||
for model_output, expected_response in zip(output, expected_responses):
|
||||
response = model_output.replace(trained_tokenizer.pad_token, "").replace(trained_tokenizer.eos_token, "").split(split)[1]
|
||||
|
||||
expected_service_calls = []
|
||||
|
||||
for block in service_call_regex.findall(expected_response.strip()):
|
||||
for line in block.split("\n"):
|
||||
if len(line) == 0:
|
||||
continue
|
||||
expected_service_calls.append(json.loads(line))
|
||||
total_answers = total_answers + 1
|
||||
|
||||
for block in service_call_regex.findall(response.strip()):
|
||||
for line in block.split("\n"):
|
||||
if len(line) == 0:
|
||||
continue
|
||||
try:
|
||||
json_output = json.loads(line)
|
||||
except:
|
||||
failed_examples.append({ "expected": expected_response, "actual": response, "invalid_json": True })
|
||||
continue
|
||||
|
||||
if json_output in expected_service_calls:
|
||||
expected_service_calls.pop(expected_service_calls.index(json_output))
|
||||
correct_answers = correct_answers + 1
|
||||
elif "rgb_color" in json_output:
|
||||
for sc in expected_service_calls:
|
||||
sc = { **sc }
|
||||
json_output_copy = { **json_output }
|
||||
if not "rgb_color" in sc:
|
||||
continue
|
||||
del sc["rgb_color"]
|
||||
del json_output_copy["rgb_color"]
|
||||
if sc == json_output_copy:
|
||||
correct_answers = correct_answers + 1
|
||||
color_mismatches = color_mismatches + 1
|
||||
else:
|
||||
failed_examples.append({ "expected": expected_response, "actual": response })
|
||||
else:
|
||||
failed_examples.append({ "expected": expected_response, "actual": response })
|
||||
|
||||
pbar.update(batch_size)
|
||||
pbar.set_description(f"Accuracy: {correct_answers/total_answers*100:.2f}% ({correct_answers}/{total_answers})")
|
||||
|
||||
accuracy = correct_answers/total_answers
|
||||
print(f"Final Accuracy Rating: {accuracy*100:.2f}%")
|
||||
print(f"Color Mismatches: {color_mismatches}")
|
||||
|
||||
with open(os.path.join(model_folder, "eval_results.json"), "w") as f:
|
||||
json.dump({
|
||||
"possible_answers": total_answers,
|
||||
"correct_answers": correct_answers,
|
||||
"accuracy": accuracy,
|
||||
"color_mismatches": color_mismatches,
|
||||
"failed_examples": failed_examples,
|
||||
}, f, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
train.py: 91 changed lines
@@ -3,8 +3,10 @@
|
||||
import math
|
||||
import copy
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, \
|
||||
PreTrainedTokenizerFast, HfArgumentParser, GPTQConfig, AutoConfig
|
||||
from transformers.trainer_utils import EvalPrediction
|
||||
from datasets import load_dataset
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional, Sequence
|
||||
@@ -15,13 +17,12 @@ Phi Modules: fc1,fc2,q_proj,v_proj,k_proj,dense,embed_tokens,lm_head
|
||||
|
||||
"""
|
||||
python3 train.py \
|
||||
--run_name home-3b-v2-rev3 \
|
||||
--run_name Home-3B-v2_ha-GGUF \
|
||||
--base_model microsoft/phi-2 \
|
||||
--add_pad_token \
|
||||
--add_chatml_tokens \
|
||||
--bf16 \
|
||||
--train_dataset data/home_assistant_alpaca_merged_train.json \
|
||||
--test_dataset data/home_assistant_alpaca_merged_test.json \
|
||||
--train_dataset data/home_assistant_train.json \
|
||||
--learning_rate 1e-5 \
|
||||
--save_steps 1000 \
|
||||
--micro_batch_size 2 --gradient_checkpointing \
|
||||
@@ -31,7 +32,7 @@ python3 train.py \
|
||||
|
||||
"""
|
||||
python3 train.py \
|
||||
--run_name home-1b-rev4 \
|
||||
--run_name home-1b-rev6 \
|
||||
--base_model microsoft/phi-1_5 \
|
||||
--add_pad_token \
|
||||
--add_chatml_tokens \
|
||||
@@ -40,7 +41,7 @@ python3 train.py \
|
||||
--test_dataset data/home_assistant_test.json \
|
||||
--learning_rate 1e-5 \
|
||||
--micro_batch_size 4 --gradient_checkpointing \
|
||||
--ctx_size 2048
|
||||
--ctx_size 2048 --save_steps 200
|
||||
"""
|
||||
|
||||
"""
|
||||
@@ -56,21 +57,23 @@ python3 train.py \
|
||||
@dataclass
|
||||
class TrainingRunArguments:
|
||||
run_name: str = field(metadata={"help": "The folder to save the output model under"})
|
||||
train_dataset: str = field(metadata={"help": "The JSON file containing the training dataset"})
|
||||
test_dataset: str = field(metadata={"help": "The JSON file containing the evaluation dataset"})
|
||||
base_model: str = field(metadata={"help": "The base model to load for fine-tuning"})
|
||||
train_dataset: str = field(metadata={"help": "The JSON file containing the training dataset"})
|
||||
test_dataset: str = field(default=None, metadata={"help": "The JSON file containing the evaluation dataset"})
|
||||
ctx_size: int = field(default=2048, metadata={"help": "The number of tokens to pad & truncate the input examples to"})
|
||||
bf16: bool = field(default=False, metadata={"help": "If set, the model will the loaded and trained in bf16 instead of fp16"})
|
||||
batch_size: int = field(default=8, metadata={"help": "The simulated 'batch size' that we will train on. will tweak gradient accumulations steps"})
|
||||
micro_batch_size: int = field(default=2, metadata={"help": "The actual batch size that will fit into VRAM on this machine"})
|
||||
eval_batch_size: int = field(default=1, metadata={"help": "The batch size for generation used while performing evaluation"})
|
||||
epochs: int = field(default=1, metadata={"help": "The number of times to train the model on each example"})
|
||||
learning_rate: float = field(default=1e-5, metadata={"help": "The starting learning rate (speed at which the model trains)"})
|
||||
learning_rate_schedule: str = field(default="cosine", metadata={"help": "How fast the learning rate is reduced during training"})
|
||||
weight_decay: float = field(default=0.1, metadata={"help": ""})
|
||||
gradient_clip: float = field(default=1.0, metadata={"help": ""})
|
||||
resume_from_checkpoint: str = field(default="", metadata={"help": "The name of the checkpoint to resume training from"})
|
||||
eval_steps: int = field(default=100, metadata={"help": "The number of steps in between evaluations of the model"})
|
||||
eval_steps: int = field(default=200, metadata={"help": "The number of steps in between evaluations of the model; set to -1 to evaluate every epoch"})
|
||||
save_steps: int = field(default=-1, metadata={"help": "The number of steps in between model checkpoints; set to -1 to save every epoch"})
|
||||
save_total_limit: int = field(default=1, metadata={"help": "The number of recent checkpoints of the model to save (not including the final model)"})
|
||||
group_by_length: bool = field(default=False, metadata={"help": "If enabled, the training data will be grouped by length to optimize use of padding"})
|
||||
|
||||
# Quantization
@@ -99,8 +102,6 @@ training_run_args, _ = parser.parse_args_into_dataclasses(return_remaining_strin
if sum([training_run_args.load_in_8bit, training_run_args.load_in_4bit, training_run_args.load_as_gptq]) > 1:
    raise Exception("Please select exactly one of 'load_in_8bit', 'load_in_4bit', or 'load_as_gptq'")

# TODO: write a proper evaluation script

print(f"Loading model '{training_run_args.base_model}'...")

model_kwargs = {}
@@ -118,6 +119,7 @@ else:
    model_kwargs["torch_dtype"] = torch.float16

# model_kwargs["resid_pdrop"] = 0.0
# model_kwargs["revision"] = "accfee56d8988cae60915486310362db5831b1bd"

def find_max_vram(min_buffer_mib=800):
    total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
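Most of `find_max_vram` sits outside this hunk. As a rough sketch (an assumption, not the file's actual body), a helper like this usually subtracts a safety buffer from each GPU's reported memory and returns the `max_memory` mapping that `from_pretrained` consumes below:

```python
import torch

def find_max_vram_sketch(min_buffer_mib=800):
    # Illustrative only: leave `min_buffer_mib` MiB free on every visible GPU.
    max_memory = {}
    for device in range(torch.cuda.device_count()):
        total_mib = torch.cuda.get_device_properties(device).total_memory / (1024 * 1024)
        max_memory[device] = f"{int(total_mib - min_buffer_mib)}MiB"
    # Mapping of device index -> "NNNNMiB" strings, the format accepted by from_pretrained.
    return max_memory
```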
@@ -136,10 +138,11 @@ model = AutoModelForCausalLM.from_pretrained(
    max_memory=find_max_vram(),
    **model_kwargs
)
tokenizer = AutoTokenizer.from_pretrained(training_run_args.base_model, trust_remote_code=True, use_fast=False)
tokenizer = AutoTokenizer.from_pretrained(training_run_args.base_model, trust_remote_code=True)

if training_run_args.add_pad_token:
    tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
    model.config.pad_token_id = tokenizer.pad_token_id

if training_run_args.add_chatml_tokens:
    tokenizer.add_special_tokens({
@@ -150,6 +153,8 @@ if training_run_args.add_chatml_tokens:
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.eos_token_id = tokenizer.eos_token_id

# TODO: add chatml template to tokenizer for auto-formatting

embeddings_len = math.ceil(len(tokenizer) / 32) * 32
if model.get_input_embeddings().num_embeddings < embeddings_len:
    model.resize_token_embeddings(embeddings_len)
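`math.ceil(len(tokenizer) / 32) * 32` rounds the embedding table up to the next multiple of 32, so there is room for the newly added pad/ChatML tokens and the matrix keeps an aligned size. A worked example with a hypothetical vocabulary count:

```python
import math

# Hypothetical: suppose the tokenizer holds 50,298 entries after the extra special tokens.
vocab_len = 50298
embeddings_len = math.ceil(vocab_len / 32) * 32
assert embeddings_len == 50304  # next multiple of 32, so resize_token_embeddings grows the table slightly
```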
@@ -183,31 +188,37 @@ if training_run_args.use_lora:
base_dir = "loras" if training_run_args.use_lora else "models"
model_dir = f"./{base_dir}/{training_run_args.run_name}"

# TODO: eval is broken (returning NaN for loss)
training_kwargs = {}

if training_run_args.test_dataset:
    training_kwargs.update({
        "per_device_eval_batch_size": training_run_args.eval_batch_size,
        "evaluation_strategy": ("steps" if training_run_args.eval_steps != -1 else "epoch"),
        "eval_steps": (training_run_args.eval_steps if training_run_args.eval_steps != -1 else None),
        "bf16_full_eval": training_run_args.bf16,
    })

training_args = TrainingArguments(
    per_device_train_batch_size=training_run_args.micro_batch_size,
    # per_device_eval_batch_size=training_run_args.micro_batch_size,
    gradient_accumulation_steps=training_run_args.batch_size//training_run_args.micro_batch_size,
    gradient_checkpointing=training_run_args.gradient_checkpointing,
    # weight_decay=training_run_args.weight_decay,
    # max_grad_norm=training_run_args.gradient_clip,
    # evaluation_strategy="steps",
    # eval_steps=training_run_args.eval_steps,
    weight_decay=training_run_args.weight_decay,
    max_grad_norm=training_run_args.gradient_clip,
    save_strategy=("steps" if training_run_args.save_steps != -1 else "epoch"),
    save_steps=(training_run_args.save_steps if training_run_args.save_steps != -1 else None),
    save_safetensors=True,
    logging_steps=5,
    output_dir=model_dir,
    num_train_epochs=training_run_args.epochs,
    save_total_limit=1,
    # dataloader_pin_memory=False,
    save_total_limit=training_run_args.save_total_limit,
    report_to="tensorboard",
    learning_rate=training_run_args.learning_rate,
    lr_scheduler_type=training_run_args.learning_rate_schedule,
    log_level="info",
    bf16=training_run_args.bf16,
    # bf16_full_eval=training_run_args.bf16,
    group_by_length=training_run_args.group_by_length
    group_by_length=training_run_args.group_by_length,
    **training_kwargs,
    # include_inputs_for_metrics=True,
)
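Two details of this block are easy to miss: the evaluation-related arguments are only injected through `**training_kwargs` when a test dataset was supplied, and the true optimizer batch size is recovered from the two batch-size fields via gradient accumulation. A quick worked example using the dataclass defaults:

```python
# Hypothetical values equal to the dataclass defaults above.
batch_size = 8        # the "simulated" batch size requested on the command line
micro_batch_size = 2  # what actually fits in VRAM at once

gradient_accumulation_steps = batch_size // micro_batch_size            # -> 4
effective_batch_size = micro_batch_size * gradient_accumulation_steps   # -> 8 examples per optimizer step
```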

class DataCollatorForSupervisedFineTuning(object):
@@ -315,7 +326,10 @@ class DataCollatorForSupervisedFineTuning(object):
        )

print("Loading dataset...")
datasets = load_dataset("json", data_files={ "train": training_run_args.train_dataset, "test": training_run_args.test_dataset })
data_files = { "train": training_run_args.train_dataset }
if training_run_args.test_dataset:
    data_files["test"] = training_run_args.test_dataset
datasets = load_dataset("json", data_files=data_files)
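`load_dataset("json", ...)` reads each file as a set of JSON records, and the `tokenize` function below only consumes a `text` column (which is dropped again after mapping). A hedged sketch of what one training record presumably looks like — everything other than the `text` field name and the ChatML framing is an assumption:

```python
# Illustrative record only; the real dataset rows may carry additional fields.
example_record = {
    "text": (
        "<|im_start|>system\n...house state and available services...<|im_end|>\n"
        "<|im_start|>user\nturn off the office light<|im_end|>\n"
        "<|im_start|>assistant\n...response text and service call...<|im_end|>"
    )
}
```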

def tokenize(example):
    return tokenizer(
@@ -326,18 +340,24 @@ def tokenize(example):
    )

print("Tokenizing datasets...")
tokenized_test_dataset = None
tokenized_train_dataset = datasets["train"].map(tokenize, batched=True).remove_columns(["text"])
tokenized_test_dataset = datasets["test"].map(tokenize, batched=True).remove_columns(["text"])
if training_run_args.test_dataset:
    tokenized_test_dataset = datasets["test"].map(tokenize, batched=True).remove_columns(["text"])

tokens_in_train_set = sum(len(example) for example in tokenized_train_dataset["input_ids"])
print(f"Train dataset has {int(tokens_in_train_set / 1000000)}M tokens")

data_collator = DataCollatorForSupervisedFineTuning(tokenizer=tokenizer)

import random
from torch.utils.data import SequentialSampler, Subset, RandomSampler
class RandomEvalSubsetTrainer(Trainer):
    def __init__(self, random_eval_sample_pct=0.1, *args, **kwargs):
    def __init__(self, random_eval_sample_pct=0.1, learning_rate_overshoot=1.15, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.random_eval_sample_pct = random_eval_sample_pct
        self.evaluate_full_dataset = False
        self.learning_rate_overshoot = learning_rate_overshoot

    def evaluate_all(self):
        self.evaluate_full_dataset = True
@@ -359,13 +379,34 @@ class RandomEvalSubsetTrainer(Trainer):
            return super()._get_train_sampler()

        return RandomSampler(self.train_dataset, generator=torch.Generator(device='cpu'))

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Overshoot the learning-rate schedule slightly, an idea taken from the Chinchilla paper
        (which advises not going past a 25% overshoot). Skipping the final fine-tuning tail of
        the schedule barely affects accuracy and should speed up training.
        """
        return super().create_scheduler(int(num_training_steps * self.learning_rate_overshoot), optimizer=optimizer)
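With the default `learning_rate_overshoot=1.15`, the scheduler is built for 15% more steps than will actually run, so the cosine decay never reaches its tail. A rough illustration, assuming the standard Hugging Face cosine schedule with no warmup:

```python
import math

# Hypothetical run: 1000 real optimizer steps, scheduler horizon stretched to 1150.
actual_steps = 1000
scheduler_horizon = int(actual_steps * 1.15)

# Cosine multiplier at the last real step (progress = 1000 / 1150):
lr_fraction = 0.5 * (1 + math.cos(math.pi * actual_steps / scheduler_horizon))
# ~0.04, i.e. training ends at roughly 4% of the peak learning rate instead of decaying to ~0.
```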

def compute_metrics(pred: EvalPrediction):
    inputs = pred.inputs
    label_ids = pred.label_ids
    logits = pred.predictions

    return {"accuracy": 1.0 }

def preprocess_logits_for_metrics(logits, labels):
    """https://discuss.huggingface.co/t/cuda-out-of-memory-when-using-trainer-with-compute-metrics/2941/22"""
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids, labels
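Reducing the logits to `argmax` token ids inside `preprocess_logits_for_metrics` is what keeps evaluation from exhausting memory (the linked thread explains the trick): the Trainer then accumulates one integer per position instead of a vocabulary-sized vector. Rough numbers for a single sequence, assuming a ~51k-entry vocabulary:

```python
# Illustrative sizes only.
seq_len, vocab_size = 2048, 51200

logit_bytes = seq_len * vocab_size * 2   # fp16 logits: ~210 MB per 2048-token sequence
pred_id_bytes = seq_len * 8              # int64 argmax ids: ~16 KB for the same sequence
```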

trainer = RandomEvalSubsetTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    # eval_dataset=tokenized_test_dataset,
    eval_dataset=tokenized_test_dataset,
    data_collator=data_collator,
    # compute_metrics=compute_metrics,
    # preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
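The remainder of the script falls outside this diff. Purely as a hedged sketch of how a Trainer configured this way is usually driven (not necessarily the file's exact code), the run would be resumed from a checkpoint if one was named, then the final weights saved:

```python
# Illustrative only -- the actual launch/teardown logic is not shown in this diff.
checkpoint = training_run_args.resume_from_checkpoint or None
trainer.train(resume_from_checkpoint=checkpoint)

trainer.save_model(model_dir)
tokenizer.save_pretrained(model_dir)
```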

tensorboard_process = None