working version of in context examples

This commit is contained in:
Alex O'Connell
2024-03-20 23:03:31 -04:00
parent 0f1c773bff
commit fa31682c51
5 changed files with 82 additions and 13 deletions

View File

@@ -89,10 +89,11 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
async def update_listener(hass, entry):
"""Handle options update."""
hass.data[DOMAIN][entry.entry_id] = entry
# call update handler
agent = await ha_conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
agent._update_options()
hass.data[DOMAIN][entry.entry_id] = entry
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
@@ -182,12 +183,22 @@ class LLaMAAgent(AbstractConversationAgent):
if entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
self._load_icl_examples()
else:
self.in_context_examples = None
self._load_model(entry)
def _load_icl_examples(self):
try:
self.in_context_examples = list(csv.DictReader(os.path.join(os.path.dirname(__file__), IN_CONTEXT_EXAMPLES_FILE)))
icl_filename = os.path.join(os.path.dirname(__file__), IN_CONTEXT_EXAMPLES_FILE)
with open(icl_filename) as f:
self.in_context_examples = list(csv.DictReader(f))
if set(self.in_context_examples[0].keys()) != set(["service", "response" ]):
raise Exception("ICL csv file did not have 2 columns: service & response")
_LOGGER.debug(f"Loaded {len(self.in_context_examples)} examples for ICL")
except Exception:
_LOGGER.exception("Failed to load in context learning examples!")
self.in_context_examples = None
@@ -312,13 +323,15 @@ class LLaMAAgent(AbstractConversationAgent):
if len(line) == 0:
break
# TODO: handle json only format for things like mistral
# parse old format or JSON format
try:
json_output = json.loads(line)
service = json_output["service"]
entity = json_output["target_device"]
domain, service = tuple(service.split("."))
if "to_say" in json_output:
to_say = to_say + json_output.pop("to_say")
extra_arguments = { k: v for k, v in json_output.items() if k not in [ "service", "target_device" ] }
except Exception:
try:
@@ -400,6 +413,11 @@ class LLaMAAgent(AbstractConversationAgent):
prompt_template = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
template_desc = PROMPT_TEMPLATE_DESCRIPTIONS[prompt_template]
# handle models without a system prompt
if prompt[0]["role"] == "system" and "system" not in template_desc:
system_prompt = prompt.pop(0)
prompt[0]["message"] = system_prompt["message"] + prompt[0]["message"]
for message in prompt:
role = message["role"]
message = message["message"]
@@ -411,7 +429,7 @@ class LLaMAAgent(AbstractConversationAgent):
if include_generation_prompt:
formatted_prompt = formatted_prompt + template_desc["generation_prompt"]
# _LOGGER.debug(formatted_prompt)
_LOGGER.debug(formatted_prompt)
return formatted_prompt
def _generate_system_prompt(self, prompt_template: str) -> str:
@@ -428,13 +446,19 @@ class LLaMAAgent(AbstractConversationAgent):
entity_names = entity_names[:]
# filter out examples for disabled services
in_context_examples = [ x for x in self.in_context_examples if x["service"] in service_names and x["service"].split(".")[0] in entity_domains ]
# in_context_examples = [ x for x in self.in_context_examples if x["service"] in service_names and x["service"].split(".")[0] in entity_domains ]
selected_in_context_examples = []
_LOGGER.debug(service_names)
for x in self.in_context_examples:
_LOGGER.debug(str(x))
if x["service"] in service_names and x["service"].split(".")[0] in entity_domains:
selected_in_context_examples.append(x)
random.shuffle(in_context_examples)
random.shuffle(selected_in_context_examples)
random.shuffle(entity_names)
for x in range(num_examples):
chosen_example = in_context_examples.pop()
chosen_example = selected_in_context_examples.pop()
chosen_service = chosen_example["service"]
device = [ x for x in entity_names if x.split(".")[0] == chosen_service.split(".")[0] ][0]
example = {
@@ -484,11 +508,13 @@ class LLaMAAgent(AbstractConversationAgent):
service_dict = self.hass.services.async_services()
all_services = []
all_service_names = []
for domain in domains:
for name, service in service_dict.get(domain, {}).items():
args = flatten_vol_schema(service.schema)
args_to_expose = set(args).intersection(allowed_service_call_arguments)
all_services.append(f"{domain}.{name}({','.join(args_to_expose)})")
all_service_names.append(f"{domain}.{name}")
formatted_services = ", ".join(all_services)
render_variables = {
@@ -497,7 +523,8 @@ class LLaMAAgent(AbstractConversationAgent):
}
if self.in_context_examples:
render_variables["response_examples"] = icl_example_generator(10, list(entities_to_expose.keys()), all_services)
# TODO: make number of examples configurable
render_variables["response_examples"] = "\n".join(icl_example_generator(4, list(entities_to_expose.keys()), all_service_names))
return template.Template(prompt_template, self.hass).async_render(
render_variables,
@@ -560,6 +587,7 @@ class LocalLLaMAAgent(LLaMAAgent):
self.grammar = None
def _update_options(self):
LLaMAAgent._update_options(self)
if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
self._load_grammar()
else:

View File

@@ -53,6 +53,7 @@ from .const import (
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_OPENAI_API_KEY,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_SERVICE_CALL_REGEX,
@@ -77,6 +78,7 @@ from .const import (
DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
DEFAULT_SERVICE_CALL_REGEX,
DEFAULT_OPTIONS,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
@@ -616,6 +618,11 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
CONF_REMEMBER_NUM_INTERACTIONS,
description={"suggested_value": options.get(CONF_REMEMBER_NUM_INTERACTIONS)},
): int,
vol.Required(
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
): bool,
}
if is_local_backend(backend_type):

View File

@@ -76,8 +76,7 @@ PROMPT_TEMPLATE_DESCRIPTIONS = {
"generation_prompt": ""
},
PROMPT_TEMPLATE_MISTRAL: {
"system": { "prefix": "<s>", "suffix": "" },
"user": { "prefix": "[INST]", "suffix": "[/INST]" },
"user": { "prefix": "<s>[INST] ", "suffix": " [/INST] " },
"assistant": { "prefix": "", "suffix": "</s>" },
"generation_prompt": ""
},

View File

@@ -70,6 +70,7 @@
"refresh_prompt_per_tern": "Refresh System Prompt Every Turn",
"remember_conversation": "Remember conversation",
"remember_num_interactions": "Number of past interactions to remember",
"in_context_examples": "Enable in context learning (ICL) examples",
"text_generation_webui_preset": "Generation Preset/Character Name",
"remote_use_chat_endpoint": "Use chat completions endpoint",
"text_generation_webui_chat_mode": "Chat Mode"

View File

@@ -15,8 +15,8 @@ The `services` and `devices` variables are special variables that are provided b
- `services` expands into a comma separated list of the services that correlate with the devices that have been exposed to the Voice Assistant.
- `devices` expands into a multi-line block where each line has the format `<entity_id> '<friendly_name>' = <state>;<extra_attributes_to_expose>` (see the example below)
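For example, a single line of the `devices` block might look like this (the entity and attribute values are illustrative):
```
light.kitchen_lamp 'Kitchen Lamp' = on;brightness=128
```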
### Model "Persona"
The model is trained with a few different personas. They can be activated by using their system prompt found below:
### Home Model "Persona"
The Home model is trained with a few different personas. They can be activated by using their system prompt found below:
Al the Assistant - Responds politely and concisely
```
@@ -42,3 +42,37 @@ Currently supported prompt formats are:
3. Alpaca
4. Mistral
5. None (useful for foundation models)
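For instance, the Mistral template as updated in this commit defines no separate system role, so the system prompt is merged into the first user message and each turn is wrapped roughly as follows (illustrative):
```
<s>[INST] {system prompt}{user message} [/INST] {assistant response}</s>
```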
## Prompting other models with In Context Learning
It is possible to use models that have not been fine-tuned on the dataset by providing In Context Learning (ICL) examples. These examples condition the model to output the correct JSON schema without any fine-tuning.
Here is an example configuration using Mistral-7B-Instruct-v0.2.
First, download and set up the model on the desired backend.
Then, navigate to the conversation agent's configuration page and set the following options:
System Prompt:
```
You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed, using only the information provided.
Services: {{ services }}
Devices:
{{ devices }}
Respond to the following user instruction in the same format as the following examples:
{{ response_examples }}
User instruction:
```
Prompt Format: `Mistral`
Service Call Regex: `({[\S \t]*?})`
Enable in context learning (ICL) examples: Checked
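With these settings, a user instruction such as "Turn off the ceiling fan" should yield a single line of JSON that the Service Call Regex above can extract, along the lines of the following (illustrative):
```
{"to_say": "Switching off the fan as requested.", "service": "fan.turn_off", "target_device": "fan.ceiling_fan"}
```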
### Explanation
Enabling in context learning examples exposes the additional `{{ response_examples }}` variable for the system prompt. This variable is expanded to include various examples in the following format:
```
{"to_say": "Switching off the fan as requested.", "service": "fan.turn_off", "target_device": "fan.ceiling_fan"}
{"to_say": "the todo has been added to your todo list.", "service": "todo.add_item", "target_device": "todo.shopping_list"}
{"to_say": "Starting media playback.", "service": "media_player.media_play", "target_device": "media_player.bedroom"}
```
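For reference, here is a condensed sketch of the selection logic introduced in this commit, simplified from the diff above (the real code also picks a matching exposed device for each example before rendering it):
```python
import random

def select_icl_examples(in_context_examples, service_names, entity_domains, num_examples):
    """Pick random ICL examples whose service is exposed and whose domain has devices."""
    selected = [
        x for x in in_context_examples
        if x["service"] in service_names
        and x["service"].split(".")[0] in entity_domains
    ]
    random.shuffle(selected)
    # Guard against requesting more examples than are available
    return [selected.pop() for _ in range(min(num_examples, len(selected)))]
```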
These examples are loaded from the `in_context_examples.csv` file in the `/custom_components/llama_conversation/` folder.
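The loader checks that the file contains exactly two columns, `service` and `response`, so a minimal file might look like this (the rows are illustrative, based on the rendered examples above):
```csv
service,response
fan.turn_off,Switching off the fan as requested.
todo.add_item,the todo has been added to your todo list.
media_player.media_play,Starting media playback.
```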