mirror of https://github.com/acon96/home-llm.git (synced 2026-01-08 21:28:05 -05:00)
working version of in context examples
@@ -89,10 +89,11 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
 async def update_listener(hass, entry):
     """Handle options update."""
+    hass.data[DOMAIN][entry.entry_id] = entry

     # call update handler
     agent = await ha_conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
     agent._update_options()

-    hass.data[DOMAIN][entry.entry_id] = entry
     return True

 async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
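For context: this listener only runs if `async_setup_entry` registers it on the config entry. A minimal sketch of that wiring, assuming the standard Home Assistant `add_update_listener` pattern (the body below is illustrative, not the repo's exact setup code; `DOMAIN` and `update_listener` come from the integration itself):

```
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant

async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    # store the entry so update_listener can replace it later
    hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry
    # run update_listener whenever the user saves new options;
    # async_on_unload removes the listener when the entry unloads
    entry.async_on_unload(entry.add_update_listener(update_listener))
    return True
```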
@@ -182,12 +183,22 @@ class LLaMAAgent(AbstractConversationAgent):
+        if entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
+            self._load_icl_examples()
+        else:
             self.in_context_examples = None

         self._load_model(entry)

     def _load_icl_examples(self):
         try:
-            self.in_context_examples = list(csv.DictReader(os.path.join(os.path.dirname(__file__), IN_CONTEXT_EXAMPLES_FILE)))
+            icl_filename = os.path.join(os.path.dirname(__file__), IN_CONTEXT_EXAMPLES_FILE)
+
+            with open(icl_filename) as f:
+                self.in_context_examples = list(csv.DictReader(f))
+
+            if set(self.in_context_examples[0].keys()) != set(["service", "response"]):
+                raise Exception("ICL csv file did not have 2 columns: service & response")
+
             _LOGGER.debug(f"Loaded {len(self.in_context_examples)} examples for ICL")
         except Exception:
             _LOGGER.exception("Failed to load in context learning examples!")
             self.in_context_examples = None
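The fix above addresses a subtle bug: `csv.DictReader` expects an iterable of lines (such as an open file object), not a path string, so the old call iterated over the characters of the filename. A minimal standalone sketch of the corrected pattern, using an in-memory stand-in for `in_context_examples.csv` (the sample row is invented):

```
import csv
import io

# stand-in for open(icl_filename); a real file object behaves the same way
csv_text = "service,response\nlight.turn_on,Turning on the light for you.\n"

in_context_examples = list(csv.DictReader(io.StringIO(csv_text)))

# the loader validates exactly these two columns
assert set(in_context_examples[0].keys()) == {"service", "response"}
print(in_context_examples)  # [{'service': 'light.turn_on', 'response': 'Turning on the light for you.'}]
```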
@@ -312,13 +323,15 @@ class LLaMAAgent(AbstractConversationAgent):
             if len(line) == 0:
                 break

-            # TODO: handle json only format for things like mistral
+            # parse old format or JSON format
             try:
                 json_output = json.loads(line)
                 service = json_output["service"]
                 entity = json_output["target_device"]
                 domain, service = tuple(service.split("."))
+                if "to_say" in json_output:
+                    to_say = to_say + json_output.pop("to_say")

                 extra_arguments = { k: v for k, v in json_output.items() if k not in [ "service", "target_device" ] }
             except Exception:
                 try:
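To illustrate the JSON branch of this parser (field names taken from the code above; the sample model output is invented):

```
import json

line = '{"to_say": "Turning on the light.", "service": "light.turn_on", "target_device": "light.kitchen"}'
to_say = ""

json_output = json.loads(line)
domain, service = tuple(json_output["service"].split("."))   # "light", "turn_on"
entity = json_output["target_device"]                        # "light.kitchen"
if "to_say" in json_output:
    to_say = to_say + json_output.pop("to_say")

# whatever keys remain become extra service-call arguments
extra_arguments = { k: v for k, v in json_output.items() if k not in [ "service", "target_device" ] }
print(domain, service, entity, to_say, extra_arguments)
```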
@@ -400,6 +413,11 @@ class LLaMAAgent(AbstractConversationAgent):
         prompt_template = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
         template_desc = PROMPT_TEMPLATE_DESCRIPTIONS[prompt_template]

+        # handle models without a system prompt
+        if prompt[0]["role"] == "system" and "system" not in template_desc:
+            system_prompt = prompt.pop(0)
+            prompt[0]["message"] = system_prompt["message"] + prompt[0]["message"]
+
         for message in prompt:
             role = message["role"]
             message = message["message"]
@@ -411,7 +429,7 @@ class LLaMAAgent(AbstractConversationAgent):
         if include_generation_prompt:
             formatted_prompt = formatted_prompt + template_desc["generation_prompt"]

-        # _LOGGER.debug(formatted_prompt)
+        _LOGGER.debug(formatted_prompt)
         return formatted_prompt

     def _generate_system_prompt(self, prompt_template: str) -> str:
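Taken together, these two hunks mean: merge the system prompt into the first user turn when the template has no `system` role, then wrap every message in its role's prefix/suffix. A rough sketch of that behavior, using the Mistral template description from this commit (simplified from the component's actual loop):

```
template_desc = {
    "user": { "prefix": "<s>[INST] ", "suffix": " [/INST] " },
    "assistant": { "prefix": "", "suffix": "</s>" },
    "generation_prompt": "",
}
prompt = [
    { "role": "system", "message": "You control the house. " },
    { "role": "user", "message": "turn on the kitchen light" },
]

# no "system" role in this template: fold the system prompt into the first message
if prompt[0]["role"] == "system" and "system" not in template_desc:
    system_prompt = prompt.pop(0)
    prompt[0]["message"] = system_prompt["message"] + prompt[0]["message"]

formatted_prompt = ""
for message in prompt:
    desc = template_desc[message["role"]]
    formatted_prompt = formatted_prompt + desc["prefix"] + message["message"] + desc["suffix"]
formatted_prompt = formatted_prompt + template_desc["generation_prompt"]

print(formatted_prompt)  # <s>[INST] You control the house. turn on the kitchen light [/INST]
```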
@@ -428,13 +446,19 @@ class LLaMAAgent(AbstractConversationAgent):
         entity_names = entity_names[:]

         # filter out examples for disabled services
-        in_context_examples = [ x for x in self.in_context_examples if x["service"] in service_names and x["service"].split(".")[0] in entity_domains ]
+        # in_context_examples = [ x for x in self.in_context_examples if x["service"] in service_names and x["service"].split(".")[0] in entity_domains ]
+        selected_in_context_examples = []
+        _LOGGER.debug(service_names)
+        for x in self.in_context_examples:
+            _LOGGER.debug(str(x))
+            if x["service"] in service_names and x["service"].split(".")[0] in entity_domains:
+                selected_in_context_examples.append(x)

-        random.shuffle(in_context_examples)
+        random.shuffle(selected_in_context_examples)
         random.shuffle(entity_names)

         for x in range(num_examples):
-            chosen_example = in_context_examples.pop()
+            chosen_example = selected_in_context_examples.pop()
             chosen_service = chosen_example["service"]
             device = [ x for x in entity_names if x.split(".")[0] == chosen_service.split(".")[0] ][0]
             example = {
@@ -484,11 +508,13 @@ class LLaMAAgent(AbstractConversationAgent):

         service_dict = self.hass.services.async_services()
         all_services = []
+        all_service_names = []
         for domain in domains:
             for name, service in service_dict.get(domain, {}).items():
                 args = flatten_vol_schema(service.schema)
                 args_to_expose = set(args).intersection(allowed_service_call_arguments)
                 all_services.append(f"{domain}.{name}({','.join(args_to_expose)})")
+                all_service_names.append(f"{domain}.{name}")
         formatted_services = ", ".join(all_services)

         render_variables = {
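For reference, this is roughly what the loop builds: one `domain.name(arg1,arg2)` signature per service, plus the new parallel list of bare service names. `flatten_vol_schema` is the component's own helper, so the sketch below substitutes plain lists of argument names (all data here is hypothetical):

```
# hypothetical stand-ins for hass.services.async_services() and flatten_vol_schema()
service_dict = {
    "light": { "turn_on": ["entity_id", "brightness", "color_name"] },
    "fan": { "turn_off": ["entity_id"] },
}
allowed_service_call_arguments = ["entity_id", "brightness"]

all_services = []
all_service_names = []
for domain, services in service_dict.items():
    for name, args in services.items():
        args_to_expose = set(args).intersection(allowed_service_call_arguments)
        all_services.append(f"{domain}.{name}({','.join(sorted(args_to_expose))})")
        all_service_names.append(f"{domain}.{name}")

print(", ".join(all_services))  # light.turn_on(brightness,entity_id), fan.turn_off(entity_id)
print(all_service_names)        # ['light.turn_on', 'fan.turn_off']
```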
@@ -497,7 +523,8 @@ class LLaMAAgent(AbstractConversationAgent):
         }

         if self.in_context_examples:
-            render_variables["response_examples"] = icl_example_generator(10, list(entities_to_expose.keys()), all_services)
+            # TODO: make number of examples configurable
+            render_variables["response_examples"] = "\n".join(icl_example_generator(4, list(entities_to_expose.keys()), all_service_names))

         return template.Template(prompt_template, self.hass).async_render(
             render_variables,
@@ -560,6 +587,7 @@ class LocalLLaMAAgent(LLaMAAgent):
         self.grammar = None

     def _update_options(self):
+        LLaMAAgent._update_options()
         if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
             self._load_grammar()
         else:

@@ -53,6 +53,7 @@ from .const import (
     CONF_REFRESH_SYSTEM_PROMPT,
     CONF_REMEMBER_CONVERSATION,
     CONF_REMEMBER_NUM_INTERACTIONS,
+    CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
     CONF_OPENAI_API_KEY,
     CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
     CONF_SERVICE_CALL_REGEX,
@@ -77,6 +78,7 @@ from .const import (
     DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
     DEFAULT_REFRESH_SYSTEM_PROMPT,
     DEFAULT_REMEMBER_CONVERSATION,
+    DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
     DEFAULT_SERVICE_CALL_REGEX,
     DEFAULT_OPTIONS,
     DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
@@ -616,6 +618,11 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
             CONF_REMEMBER_NUM_INTERACTIONS,
             description={"suggested_value": options.get(CONF_REMEMBER_NUM_INTERACTIONS)},
         ): int,
+        vol.Required(
+            CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
+            description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
+            default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
+        ): bool,
     }

     if is_local_backend(backend_type):

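The new schema entry follows the usual Home Assistant options-flow idiom: `default=` keeps the field valid when no option has been saved yet, while `description={"suggested_value": ...}` pre-fills the form with the currently saved value. A minimal self-contained sketch with plain voluptuous (the literal key name stands in for `CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES`):

```
import voluptuous as vol

options = { "in_context_examples": True }  # previously saved options (hypothetical)

schema = vol.Schema({
    vol.Required(
        "in_context_examples",
        description={ "suggested_value": options.get("in_context_examples") },
        default=True,
    ): bool,
})

print(schema({}))  # {'in_context_examples': True} -- default applied when the field is omitted
```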
@@ -76,8 +76,7 @@ PROMPT_TEMPLATE_DESCRIPTIONS = {
         "generation_prompt": ""
     },
     PROMPT_TEMPLATE_MISTRAL: {
-        "system": { "prefix": "<s>", "suffix": "" },
-        "user": { "prefix": "[INST]", "suffix": "[/INST]" },
+        "user": { "prefix": "<s>[INST] ", "suffix": " [/INST] " },
         "assistant": { "prefix": "", "suffix": "</s>" },
         "generation_prompt": ""
     },

@@ -70,6 +70,7 @@
         "refresh_prompt_per_tern": "Refresh System Prompt Every Turn",
         "remember_conversation": "Remember conversation",
         "remember_num_interactions": "Number of past interactions to remember",
+        "in_context_examples": "Enable in context learning (ICL) examples",
         "text_generation_webui_preset": "Generation Preset/Character Name",
         "remote_use_chat_endpoint": "Use chat completions endpoint",
         "text_generation_webui_chat_mode": "Chat Mode"

@@ -15,8 +15,8 @@ The `services` and `devices` variables are special variables that are provided b
 - `services` expands into a comma separated list of the services that correlate with the devices that have been exposed to the Voice Assistant.
 - `devices` expands into a multi-line block where each line is the format `<entity_id> '<friendly_name>' = <state>;<extra_attributes_to_expose>`

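For example, a rendered `devices` block might look like this (entities and states invented for illustration):

```
light.kitchen 'Kitchen Light' = off
fan.ceiling_fan 'Ceiling Fan' = on;medium
media_player.bedroom 'Bedroom Speaker' = idle
```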
-### Model "Persona"
-The model is trained with a few different personas. They can be activated by using their system prompt found below:
+### Home Model "Persona"
+The Home model is trained with a few different personas. They can be activated by using their system prompt found below:

 Al the Assistant - Responds politely and concisely
 ```
@@ -42,3 +42,37 @@ Currently supported prompt formats are:
 3. Alpaca
 4. Mistral
 5. None (useful for foundation models)
+
+## Prompting other models with In Context Learning
+It is possible to use models that are not fine-tuned with the dataset by using In Context Learning (ICL) examples. These examples condition the model to output the correct JSON schema without any fine-tuning of the model.
+
+Here is an example configuration using Mistral-7B-Instruct-v0.2.
+First, download and set up the model on the desired backend.
+
+Then, navigate to the conversation agent's configuration page and set the following options:
+
+System Prompt:
+```
+You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.
+Services: {{ services }}
+Devices:
+{{ devices }}
+
+Respond to the following user instruction by responding in the same format as the following examples:
+{{ response_examples }}
+
+User instruction:
+```
+Prompt Format: `Mistral`
+Service Call Regex: `({[\S \t]*?})`
+Enable in context learning (ICL) examples: Checked
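As an aside, the Service Call Regex above is what extracts the JSON object from the model's reply; a quick illustration of the pattern in action (the reply text is invented):

```
import re

SERVICE_CALL_REGEX = r"({[\S \t]*?})"

reply = 'Turning on the light.\n{"service": "light.turn_on", "target_device": "light.kitchen"}'
print(re.findall(SERVICE_CALL_REGEX, reply))
# ['{"service": "light.turn_on", "target_device": "light.kitchen"}']
```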
+
+### Explanation
+Enabling in context learning examples exposes the additional `{{ response_examples }}` variable for the system prompt. This variable is expanded to include various examples in the following format:
+```
+{"to_say": "Switching off the fan as requested.", "service": "fan.turn_off", "target_device": "fan.ceiling_fan"}
+{"to_say": "the todo has been added to your todo list.", "service": "todo.add_item", "target_device": "todo.shopping_list"}
+{"to_say": "Starting media playback.", "service": "media_player.media_play", "target_device": "media_player.bedroom"}
+```
+
+These examples are loaded from the `in_context_examples.csv` file in the `/custom_components/llama_conversation/` folder.
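The CSV file itself is not part of this diff, but given the loader's `service`/`response` column check, its contents plausibly look something like this (rows invented for illustration):

```
service,response
fan.turn_off,Switching off the fan as requested.
todo.add_item,The todo has been added to your todo list.
media_player.media_play,Starting media playback.
```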