Merge branch 'release/v0.2.3'

This commit is contained in:
Alex O'Connell
2024-01-21 21:46:42 -05:00
7 changed files with 543 additions and 260 deletions

View File

@@ -103,8 +103,9 @@ When setting up the component, there are 4 different "backend" options to choose
1. Llama.cpp with a model from HuggingFace
2. Llama.cpp with a locally provided model
3. A remote instance of text-generation-webui
4. A generic OpenAI API compatible interface
- *should* be compatible with: LocalAI, LM Studio, and all other OpenAI compatible backends
4. A generic OpenAI API compatible interface; *should* be compatible with LocalAI, LM Studio, and all other OpenAI compatible backends
See [docs/Backend Configuration.md](/docs/Backend%20Configuration.md) for more info.
**Installing llama-cpp-python for local model usage**:
To run a model directly as part of your Home Assistant installation, you will need to install one of the pre-built wheels because there are no existing musllinux wheels for the package. Compatible wheels for x86_64 and arm64 are provided in the [dist](./dist) folder. Copy the `*.whl` files to the `custom_components/llama_conversation/` folder. They will be installed while setting up the component.
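As a quick sanity check before starting the config flow, something like the following could confirm that a wheel matching your platform is in place. This is only an illustrative sketch; the folder path and wheel filename pattern are assumptions, not part of the integration.

```python
# Illustrative only: look for a llama-cpp-python wheel for this CPU
# architecture inside the integration folder.
import platform
from pathlib import Path

COMPONENT_DIR = Path("custom_components/llama_conversation")

def find_matching_wheel() -> Path | None:
    arch = platform.machine()  # e.g. "x86_64" or "aarch64"
    aliases = {"aarch64": ("aarch64", "arm64"), "x86_64": ("x86_64",)}
    for wheel in sorted(COMPONENT_DIR.glob("llama_cpp_python-*.whl")):
        if any(alias in wheel.name for alias in aliases.get(arch, (arch,))):
            return wheel
    return None

wheel = find_matching_wheel()
print(wheel if wheel else "No matching wheel found; copy one from ./dist first")
```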
@@ -170,6 +171,7 @@ It is highly recommend to set up text-generation-webui on a separate machine tha
## Version History
| Version | Description |
| ------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| v0.2.3 | Fix API key auth, Support chat completion endpoint, and refactor to make it easier to add more remote backends |
| v0.2.2 | Fix options window after upgrade, fix training script for new Phi model format, and release new models |
| v0.2.1 | Properly expose generation parameters for each backend, handle config entry updates without reloading, support remote backends with an API key |
| v0.2 | Bug fixes, support more backends, support for climate + switch devices, JSON style function calling with parameters, GBNF grammars |

TODO.md
View File

@@ -1,31 +1,32 @@
# TODO
[x] ChatML format (actually need to add special tokens)
[x] Vicuna dataset merge (yahma/alpaca-cleaned)
[x] Phi-2 fine tuning
[x] Quantize /w llama.cpp
[x] Make custom component use llama.cpp + ChatML
[x] Continued synthetic dataset improvements (there are a bunch of TODOs in there)
[x] Licenses + Attributions
[x] Finish Readme/docs for initial release
[x] Function calling as JSON
[ ] multi-turn prompts; better instruct dataset like dolphin/wizardlm?
[x] Fine tune Phi-1.5 version
[ ] "context requests"
- [x] ChatML format (actually need to add special tokens)
- [x] Vicuna dataset merge (yahma/alpaca-cleaned)
- [x] Phi-2 fine tuning
- [x] Quantize /w llama.cpp
- [x] Make custom component use llama.cpp + ChatML
- [x] Continued synthetic dataset improvements (there are a bunch of TODOs in there)
- [x] Licenses + Attributions
- [x] Finish Readme/docs for initial release
- [x] Function calling as JSON
- [ ] multi-turn prompts; better instruct dataset like dolphin/wizardlm?
- [x] Fine tune Phi-1.5 version
- [ ] "context requests"
- basically just let the model decide what RAG/extra context it wants
- the model predicts special tokens as the first few tokens of its output
- the requested content is added to the context after the request tokens and then generation continues
- needs more complicated training b/c multi-turn + there will be some weird masking going on for training the responses properly
[ ] RAG for getting info for setting up new devices
- [ ] RAG for getting info for setting up new devices
- set up vectordb
- ingest home assistant docs
- "context request" from above to initiate a RAG search
[x] make llama-cpp-python wheels for "llama-cpp-python>=0.2.24"
[ ] prime kv cache with current "state" so that requests are faster
[ ] make a proper evaluation framework to run. not just loss. should test accuracy on the function calling
[ ] add more remote backends
- [x] make llama-cpp-python wheels for "llama-cpp-python>=0.2.24"
- [ ] prime kv cache with current "state" so that requests are faster
- [ ] make a proper evaluation framework to run. not just loss. should test accuracy on the function calling
- [ ] add more remote backends
- LocalAI (openai compatible)
- Ollama
[x] more config options for prompt template (allow other than chatml)
[ ] publish snapshot of dataset on HF
[ ] figure out DPO for refusals + fixing incorrect entity id
[ ] mixtral + prompting (no fine tuning)
- support chat completions API (might fix Ollama + adds support for text-gen-ui characters)
- [x] more config options for prompt template (allow other than chatml)
- [ ] publish snapshot of dataset on HF
- [ ] figure out DPO for refusals + fixing incorrect entity id
- [ ] mixtral + prompting (no fine tuning)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import logging
import importlib
from typing import Literal
from typing import Literal, Any
import requests
import re
@@ -52,6 +52,8 @@ from .const import (
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
@@ -64,9 +66,18 @@ from .const import (
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_SERVICE_CALL_REGEX,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
DEFAULT_OPTIONS,
BACKEND_TYPE_LLAMA_HF,
BACKEND_TYPE_LLAMA_EXISTING,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
BACKEND_TYPE_OLLAMA,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
DOMAIN,
GBNF_GRAMMAR_FILE,
PROMPT_TEMPLATE_DESCRIPTIONS,
@@ -76,9 +87,6 @@ _LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
def is_local_backend(backend):
return backend not in [BACKEND_TYPE_TEXT_GEN_WEBUI, BACKEND_TYPE_GENERIC_OPENAI]
async def update_listener(hass, entry):
"""Handle options update."""
hass.data[DOMAIN][entry.entry_id] = entry
@@ -87,20 +95,31 @@ async def update_listener(hass, entry):
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up Local LLaMA Conversation from a config entry."""
use_local_backend = is_local_backend(
entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
)
# TODO: figure out how to make this happen as part of the config flow. when I tried it errored out passing options in
if len(entry.options) == 0:
entry.options = { **DEFAULT_OPTIONS }
copy_to_options = [ CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET ]
for item in copy_to_options:
value = entry.data.get(item)
if value:
entry.options[item] = value
if use_local_backend:
_LOGGER.info(
"Using model file '%s'", entry.data.get(CONF_DOWNLOADED_MODEL_FILE)
)
def create_agent(backend_type):
agent_cls = None
def create_agent():
return LLaMAAgent(hass, entry)
if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
agent_cls = LocalLLaMAAgent
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
agent_cls = GenericOpenAIAPIAgent
elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
agent_cls = TextGenerationWebuiAgent
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
agent_cls = LlamaCppPythonAPIAgent
return agent_cls(hass, entry)
# load the model in an executor job because it takes a while and locks up the UI otherwise
agent = await hass.async_add_executor_job(create_agent)
agent = await hass.async_add_executor_job(create_agent, entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE))
# handle updates to the options
entry.async_on_unload(entry.add_update_listener(update_listener))
@@ -151,35 +170,21 @@ def closest_color(requested_color):
class LLaMAAgent(AbstractConversationAgent):
"""Local LLaMA conversation agent."""
hass: Any
entry_id: str
history: dict[str, list[dict]]
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Initialize the agent."""
self.hass = hass
self.entry_id = entry.entry_id
self.history: dict[str, list[dict]] = {}
self.history = {}
self.backend_type = entry.data.get(
CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE
)
self.use_local_backend = is_local_backend(self.backend_type)
self.api_host = None
self.llm = None
self.grammar = None
self.model_path = entry.data.get(CONF_DOWNLOADED_MODEL_FILE)
self.model_name = entry.data.get(CONF_CHAT_MODEL, self.model_path)
if self.use_local_backend:
self._load_local_model()
else:
host = entry.data[CONF_HOST]
port = entry.data[CONF_PORT]
self.api_host = f"http://{host}:{port}"
# only load model if using text-generation-webui
if self.backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
api_key = entry.data.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, entry.data.get(CONF_OPENAI_API_KEY))
self._load_remote_model(api_key)
self._load_model(entry)
@property
def entry(self):
@@ -189,6 +194,17 @@ class LLaMAAgent(AbstractConversationAgent):
def supported_languages(self) -> list[str] | Literal["*"]:
"""Return a list of supported languages."""
return MATCH_ALL
def _load_model(self, entry: ConfigEntry) -> None:
raise NotImplementedError()
def _generate(self, conversation: dict) -> str:
raise NotImplementedError()
async def _async_generate(self, conversation: dict) -> str:
return await self.hass.async_add_executor_job(
self._generate, conversation
)
async def async_process(
self, user_input: ConversationInput
@@ -222,7 +238,7 @@ class LLaMAAgent(AbstractConversationAgent):
if len(conversation) == 0 or refresh_system_prompt:
try:
message = self._async_generate_system_prompt(raw_prompt)
message = self._generate_system_prompt(raw_prompt)
except TemplateError as err:
_LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
@@ -243,13 +259,9 @@ class LLaMAAgent(AbstractConversationAgent):
conversation.append({"role": "user", "message": user_input.text})
# _LOGGER.debug("Prompt: %s", prompt)
try:
prompt = await self._async_format_prompt(conversation)
_LOGGER.debug(prompt)
response = await self._async_generate(prompt)
_LOGGER.debug(conversation)
response = await self._async_generate(conversation)
_LOGGER.debug(response)
except Exception as err:
@@ -316,156 +328,6 @@ class LLaMAAgent(AbstractConversationAgent):
response=intent_response, conversation_id=conversation_id
)
def _load_remote_model(self, admin_key: str | None):
try:
currently_loaded_result = requests.get(f"{self.api_host}/v1/internal/model/info")
currently_loaded_result.raise_for_status()
loaded_model = currently_loaded_result.json()["model_name"]
if loaded_model == self.model_name:
_LOGGER.info(f"Model {self.model_name} is already loaded on the remote backend.")
else:
_LOGGER.info(f"Model is not {self.model_name} loaded on the remote backend. Loading it now...")
headers = {}
if admin_key:
headers["Authorization"] = f"Basic {admin_key}"
load_result = requests.post(
f"{self.api_host}/v1/internal/model/load",
json={
"model_name": self.model_name,
# TODO: expose arguments to the user in home assistant UI
# "args": {},
}
)
load_result.raise_for_status()
except Exception as ex:
_LOGGER.error("Connection error was: %s", repr(ex))
def _generate_remote(self, prompt: str) -> str:
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
request_params = {
"prompt": prompt,
"model": self.model_name,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
}
headers = {}
api_key = self.entry.data.get(CONF_OPENAI_API_KEY)
if api_key:
headers["Authorization"] = f"Basic {api_key}"
if self.backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET)
if preset:
request_params["preset"] = preset
result = requests.post(
f"{self.api_host}/v1/completions",
json=request_params,
timeout=timeout,
headers=headers,
)
try:
result.raise_for_status()
except requests.RequestException as err:
_LOGGER.debug(f"Err was: {err}")
_LOGGER.debug(f"Request was: {request_params}")
_LOGGER.debug(f"Result was: {result.text}")
return f"Failed to communicate with the API! {err}"
choices = result.json()["choices"]
if choices[0]["finish_reason"] != "stop":
_LOGGER.warn("Model response did not end on a stop token (unfinished sentence)")
return choices[0]["text"]
def _load_local_model(self):
if not self.model_path:
raise Exception(f"Model was not found at '{self.model_path}'!")
# don't import it until now because the wheel is installed by config_flow.py
module = importlib.import_module("llama_cpp")
Llama = getattr(module, "Llama")
LlamaGrammar = getattr(module, "LlamaGrammar")
_LOGGER.debug("Loading model...")
self.llm = Llama(
model_path=self.model_path,
n_ctx=2048,
n_batch=2048,
# TODO: expose arguments to the user in home assistant UI
# n_threads=16,
# n_threads_batch=4,
)
_LOGGER.debug("Loading grammar...")
try:
# TODO: make grammar configurable
with open(os.path.join(os.path.dirname(__file__), GBNF_GRAMMAR_FILE)) as f:
grammar_str = "".join(f.readlines())
self.grammar = LlamaGrammar.from_string(grammar_str)
_LOGGER.debug("Loaded grammar")
except Exception:
_LOGGER.exception("Failed to load grammar!")
self.grammar = None
def _generate_local(self, prompt: str) -> str:
input_tokens = self.llm.tokenize(
prompt.encode(), add_bos=False
)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
grammar = self.grammar if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None
_LOGGER.debug(f"Options: {self.entry.options}")
_LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")
output_tokens = self.llm.generate(
input_tokens,
temp=temperature,
top_k=top_k,
top_p=top_p,
grammar=grammar
)
result_tokens = []
for token in output_tokens:
if token == self.llm.token_eos():
break
result_tokens.append(token)
if len(result_tokens) >= max_tokens:
break
result = self.llm.detokenize(result_tokens).decode()
return result
async def _async_generate(self, prompt: str) -> str:
if self.use_local_backend:
return await self.hass.async_add_executor_job(
self._generate_local, prompt
)
else:
return await self.hass.async_add_executor_job(
self._generate_remote, prompt
)
def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
"""Gather exposed entity states"""
entity_states = {}
@@ -479,11 +341,11 @@ class LLaMAAgent(AbstractConversationAgent):
entity_states[state.entity_id] = attributes
domains.add(state.domain)
# _LOGGER.debug(f"Exposed entities: {entity_states}")
_LOGGER.debug(f"Exposed entities: {entity_states}")
return entity_states, list(domains)
async def _async_format_prompt(
def _format_prompt(
self, prompt: list[dict], include_generation_prompt: bool = True
) -> str:
formatted_prompt = ""
@@ -502,9 +364,10 @@ class LLaMAAgent(AbstractConversationAgent):
if include_generation_prompt:
formatted_prompt = formatted_prompt + template_desc["generation_prompt"]
# _LOGGER.debug(formatted_prompt)
return formatted_prompt
def _async_generate_system_prompt(self, prompt_template: str) -> str:
def _generate_system_prompt(self, prompt_template: str) -> str:
"""Generate a prompt for the user."""
entities_to_expose, domains = self._async_get_exposed_entities()
@@ -555,3 +418,269 @@ class LLaMAAgent(AbstractConversationAgent):
},
parse_result=False,
)
class LocalLLaMAAgent(LLaMAAgent):
model_path: str
llm: Any
grammar: Any
def _load_model(self, entry: ConfigEntry) -> None:
self.model_path = entry.data.get(CONF_DOWNLOADED_MODEL_FILE)
_LOGGER.info(
"Using model file '%s'", self.model_path
)
if not self.model_path:
raise Exception(f"Model was not found at '{self.model_path}'!")
# don't import it until now because the wheel is installed by config_flow.py
module = importlib.import_module("llama_cpp")
Llama = getattr(module, "Llama")
LlamaGrammar = getattr(module, "LlamaGrammar")
_LOGGER.debug("Loading model...")
self.llm = Llama(
model_path=self.model_path,
n_ctx=2048,
n_batch=2048,
# TODO: expose arguments to the user in home assistant UI
# n_threads=16,
# n_threads_batch=4,
)
_LOGGER.debug("Loading grammar...")
try:
# TODO: make grammar configurable
with open(os.path.join(os.path.dirname(__file__), GBNF_GRAMMAR_FILE)) as f:
grammar_str = "".join(f.readlines())
self.grammar = LlamaGrammar.from_string(grammar_str)
_LOGGER.debug("Loaded grammar")
except Exception:
_LOGGER.exception("Failed to load grammar!")
self.grammar = None
def _generate(self, conversation: dict) -> str:
prompt = self._format_prompt(conversation)
input_tokens = self.llm.tokenize(
prompt.encode(), add_bos=False
)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
grammar = self.grammar if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None
_LOGGER.debug(f"Options: {self.entry.options}")
_LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")
output_tokens = self.llm.generate(
input_tokens,
temp=temperature,
top_k=top_k,
top_p=top_p,
grammar=grammar
)
result_tokens = []
for token in output_tokens:
if token == self.llm.token_eos():
break
result_tokens.append(token)
if len(result_tokens) >= max_tokens:
break
result = self.llm.detokenize(result_tokens).decode()
return result
class GenericOpenAIAPIAgent(LLaMAAgent):
api_host: str
api_key: str
model_name: str
def _load_model(self, entry: ConfigEntry) -> None:
# TODO: https
self.api_host = f"http://{entry.data[CONF_HOST]}:{entry.data[CONF_PORT]}"
self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
self.model_name = entry.data.get(CONF_CHAT_MODEL)
def _chat_completion_params(self, conversation: dict) -> (str, dict):
request_params = {}
endpoint = "/v1/chat/completions"
request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]
return endpoint, request_params
def _completion_params(self, conversation: dict) -> (str, dict):
request_params = {}
endpoint = "/v1/completions"
request_params["prompt"] = self._format_prompt(conversation)
return endpoint, request_params
def _extract_response(self, response_json: dict) -> str:
choices = response_json["choices"]
if choices[0]["finish_reason"] != "stop":
_LOGGER.warn("Model response did not end on a stop token (unfinished sentence)")
if response_json["object"] == "chat.completion":
return choices[0]["message"]["content"]
else:
return choices[0]["text"]
def _generate(self, conversation: dict) -> str:
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
request_params = {
"model": self.model_name,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
}
if use_chat_api:
endpoint, additional_params = self._chat_completion_params(conversation)
else:
endpoint, additional_params = self._completion_params(conversation)
request_params.update(additional_params)
headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
result = requests.post(
f"{self.api_host}{endpoint}",
json=request_params,
timeout=timeout,
headers=headers,
)
try:
result.raise_for_status()
except requests.RequestException as err:
_LOGGER.debug(f"Err was: {err}")
_LOGGER.debug(f"Request was: {request_params}")
_LOGGER.debug(f"Result was: {result.text}")
return f"Failed to communicate with the API! {err}"
_LOGGER.debug(result.json())
return self._extract_response(result.json())
class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
admin_key: str
def _load_model(self, entry: ConfigEntry) -> None:
super()._load_model(entry)
self.admin_key = entry.data.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, self.api_key)
try:
currently_loaded_result = requests.get(f"{self.api_host}/v1/internal/model/info")
currently_loaded_result.raise_for_status()
loaded_model = currently_loaded_result.json()["model_name"]
if loaded_model == self.model_name:
_LOGGER.info(f"Model {self.model_name} is already loaded on the remote backend.")
return
else:
_LOGGER.info(f"Model is not {self.model_name} loaded on the remote backend. Loading it now...")
headers = {}
if self.admin_key:
headers["Authorization"] = f"Bearer {self.admin_key}"
load_result = requests.post(
f"{self.api_host}/v1/internal/model/load",
json={
"model_name": self.model_name,
# TODO: expose arguments to the user in home assistant UI
# "args": {},
}
)
load_result.raise_for_status()
except Exception as ex:
_LOGGER.debug("Connection error was: %s", repr(ex))
raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex
def _chat_completion_params(self, conversation: dict) -> (str, dict):
preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET)
chat_mode = self.entry.options.get(CONF_TEXT_GEN_WEBUI_CHAT_MODE, DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE)
endpoint, request_params = super()._chat_completion_params(conversation)
request_params["mode"] = chat_mode
if chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_CHAT or chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT:
if preset:
request_params["character"] = preset
elif chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT:
# TODO: handle uppercase properly?
request_params["instruction_template"] = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
return endpoint, request_params
def _completion_params(self, conversation: dict) -> (str, dict):
preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET)
endpoint, request_params = super()._completion_params(conversation)
if preset:
request_params["preset"] = preset
return endpoint, request_params
def _extract_response(self, response_json: dict) -> str:
choices = response_json["choices"]
if choices[0]["finish_reason"] != "stop":
_LOGGER.warn("Model response did not end on a stop token (unfinished sentence)")
# text-gen-webui has a typo where it is 'chat.completions' not 'chat.completion'
if response_json["object"] == "chat.completions":
return choices[0]["message"]["content"]
else:
return choices[0]["text"]
class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
"""https://llama-cpp-python.readthedocs.io/en/latest/server/"""
grammar: str
def _load_model(self, entry: ConfigEntry):
super()._load_model(entry)
with open(os.path.join(os.path.dirname(__file__), GBNF_GRAMMAR_FILE)) as f:
self.grammar = "".join(f.readlines())
def _chat_completion_params(self, conversation: dict) -> (str, dict):
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
endpoint, request_params = super()._chat_completion_params(conversation)
request_params["top_k"] = top_k
if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
request_params["grammar"] = self.grammar
return endpoint, request_params
def _completion_params(self, conversation: dict) -> (str, dict):
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
endpoint, request_params = super()._completion_params(conversation)
request_params["top_k"] = top_k
if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
request_params["grammar"] = self.grammar
return endpoint, request_params

View File

@@ -57,6 +57,8 @@ from .const import (
CONF_OPENAI_API_KEY,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
DEFAULT_CHAT_MODEL,
DEFAULT_HOST,
DEFAULT_PORT,
@@ -74,22 +76,29 @@ from .const import (
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_SERVICE_CALL_REGEX,
DEFAULT_OPTIONS,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
BACKEND_TYPE_LLAMA_HF,
BACKEND_TYPE_LLAMA_EXISTING,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
BACKEND_TYPE_OLLAMA,
PROMPT_TEMPLATE_CHATML,
PROMPT_TEMPLATE_ALPACA,
PROMPT_TEMPLATE_VICUNA,
PROMPT_TEMPLATE_MISTRAL,
PROMPT_TEMPLATE_NONE,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
DOMAIN,
)
_LOGGER = logging.getLogger(__name__)
def is_local_backend(backend):
return backend not in [BACKEND_TYPE_TEXT_GEN_WEBUI, BACKEND_TYPE_GENERIC_OPENAI]
return backend in [BACKEND_TYPE_LLAMA_EXISTING, BACKEND_TYPE_LLAMA_HF]
def STEP_INIT_DATA_SCHEMA(backend_type=None):
return vol.Schema(
@@ -98,10 +107,16 @@ def STEP_INIT_DATA_SCHEMA(backend_type=None):
CONF_BACKEND_TYPE,
default=backend_type if backend_type else DEFAULT_BACKEND_TYPE
): SelectSelector(SelectSelectorConfig(
options=[ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING, BACKEND_TYPE_TEXT_GEN_WEBUI, BACKEND_TYPE_GENERIC_OPENAI ],
options=[
BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
# BACKEND_TYPE_OLLAMA
],
translation_key=CONF_BACKEND_TYPE,
multiple=False,
mode=SelectSelectorMode.LIST,
mode=SelectSelectorMode.DROPDOWN,
))
}
)
@@ -121,19 +136,33 @@ def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_q
}
)
def STEP_REMOTE_SETUP_DATA_SCHEMA(include_admin_key: bool, *, host=None, port=None, chat_model=None):
def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, chat_model=None, use_chat_endpoint=None, webui_preset=None, webui_chat_mode=None):
extra = {}
if include_admin_key:
extra[vol.Optional(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)] = TextSelector(TextSelectorConfig(type="password"))
extra1, extra2 = ({}, {})
default_port = DEFAULT_PORT
if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
extra1[vol.Optional(CONF_TEXT_GEN_WEBUI_PRESET, default=webui_preset)] = str
extra1[vol.Optional(CONF_TEXT_GEN_WEBUI_CHAT_MODE, default=webui_chat_mode)] = SelectSelector(SelectSelectorConfig(
options=[TEXT_GEN_WEBUI_CHAT_MODE_CHAT, TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT],
translation_key=CONF_TEXT_GEN_WEBUI_CHAT_MODE,
multiple=False,
mode=SelectSelectorMode.DROPDOWN,
))
extra2[vol.Optional(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)] = TextSelector(TextSelectorConfig(type="password"))
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
default_port = "8000"
return vol.Schema(
{
vol.Required(CONF_HOST, default=host if host else DEFAULT_HOST): str,
vol.Required(CONF_PORT, default=port if port else DEFAULT_PORT): str,
vol.Required(CONF_PORT, default=port if port else default_port): str,
vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): str,
vol.Required(CONF_REMOTE_USE_CHAT_ENDPOINT, default=use_chat_endpoint if use_chat_endpoint else DEFAULT_REMOTE_USE_CHAT_ENDPOINT): bool,
**extra1,
vol.Optional(CONF_OPENAI_API_KEY): TextSelector(TextSelectorConfig(type="password")),
**extra
**extra2
}
)
@@ -242,7 +271,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
install_wheel_error = None
download_task = None
download_error = None
model_options: dict[str, Any] = {}
model_config: dict[str, Any] = {}
@property
def flow_manager(self) -> config_entries.ConfigEntriesFlowManager:
@@ -281,7 +310,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
if user_input:
try:
local_backend = is_local_backend(user_input[CONF_BACKEND_TYPE])
self.model_options.update(user_input)
self.model_config.update(user_input)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
@@ -353,7 +382,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
) -> FlowResult:
errors = {}
backend_type = self.model_options[CONF_BACKEND_TYPE]
backend_type = self.model_config[CONF_BACKEND_TYPE]
if backend_type == BACKEND_TYPE_LLAMA_HF:
schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA()
elif backend_type == BACKEND_TYPE_LLAMA_EXISTING:
@@ -364,13 +393,13 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
if self.download_error:
errors["base"] = "download_failed"
schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(
chat_model=self.model_options[CONF_CHAT_MODEL],
downloaded_model_quantization=self.model_options[CONF_DOWNLOADED_MODEL_QUANTIZATION]
chat_model=self.model_config[CONF_CHAT_MODEL],
downloaded_model_quantization=self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION]
)
if user_input:
try:
self.model_options.update(user_input)
self.model_config.update(user_input)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
@@ -379,7 +408,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
if backend_type == BACKEND_TYPE_LLAMA_HF:
return await self.async_step_download()
else:
model_file = self.model_options[CONF_DOWNLOADED_MODEL_FILE]
model_file = self.model_config[CONF_DOWNLOADED_MODEL_FILE]
if os.path.exists(model_file):
return await self.async_step_finish()
else:
@@ -401,8 +430,8 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
progress_action="download",
)
model_name = self.model_options[CONF_CHAT_MODEL]
quantization_type = self.model_options[CONF_DOWNLOADED_MODEL_QUANTIZATION]
model_name = self.model_config[CONF_CHAT_MODEL]
quantization_type = self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION]
storage_folder = os.path.join(self.hass.config.media_dirs["local"], "models")
self.download_task = self.hass.async_add_executor_job(
@@ -422,7 +451,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
self.download_error = download_result
return self.async_show_progress_done(next_step_id="local_model")
else:
self.model_options[CONF_DOWNLOADED_MODEL_FILE] = download_result
self.model_config[CONF_DOWNLOADED_MODEL_FILE] = download_result
return self.async_show_progress_done(next_step_id="finish")
@@ -431,10 +460,10 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
headers = {}
api_key = user_input.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, user_input.get(CONF_OPENAI_API_KEY))
if api_key:
headers["Authorization"] = f"Basic {api_key}"
headers["Authorization"] = f"Bearer {api_key}"
models_result = requests.get(
f"http://{self.model_options[CONF_HOST]}:{self.model_options[CONF_PORT]}/v1/internal/model/list",
f"http://{self.model_config[CONF_HOST]}:{self.model_config[CONF_PORT]}/v1/internal/model/list",
headers=headers
)
models_result.raise_for_status()
@@ -442,7 +471,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
models = models_result.json()
for model in models["model_names"]:
if model == self.model_options[CONF_CHAT_MODEL].replace("/", "_"):
if model == self.model_config[CONF_CHAT_MODEL].replace("/", "_"):
return ""
return "missing_model_api"
@@ -455,12 +484,12 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
errors = {}
backend_type = self.model_options[CONF_BACKEND_TYPE]
backend_type = self.model_config[CONF_BACKEND_TYPE]
schema = STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI)
if user_input:
try:
self.model_options.update(user_input)
self.model_config.update(user_input)
# only validate and load when using text-generation-webui
if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
@@ -474,6 +503,9 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
host=user_input[CONF_HOST],
port=user_input[CONF_PORT],
chat_model=user_input[CONF_CHAT_MODEL],
use_chat_endpoint=user_input[CONF_REMOTE_USE_CHAT_ENDPOINT],
webui_preset=user_input.get(CONF_TEXT_GEN_WEBUI_PRESET),
webui_chat_mode=user_input.get(CONF_TEXT_GEN_WEBUI_CHAT_MODE),
)
else:
return await self.async_step_finish()
@@ -492,16 +524,16 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
model_name = self.model_options.get(CONF_CHAT_MODEL)
backend = self.model_options[CONF_BACKEND_TYPE]
model_name = self.model_config.get(CONF_CHAT_MODEL)
backend = self.model_config[CONF_BACKEND_TYPE]
if backend == BACKEND_TYPE_LLAMA_EXISTING:
model_name = os.path.basename(self.model_options.get(CONF_DOWNLOADED_MODEL_FILE))
model_name = os.path.basename(self.model_config.get(CONF_DOWNLOADED_MODEL_FILE))
location = "llama.cpp" if is_local_backend(backend) else "remote"
return self.async_create_entry(
title=f"LLM Model '{model_name}' ({location})",
description="A Large Language Model Chat Agent",
data=self.model_options,
data=self.model_config,
)
@staticmethod
@@ -583,6 +615,11 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
description={"suggested_value": options.get(CONF_SERVICE_CALL_REGEX)},
default=DEFAULT_SERVICE_CALL_REGEX,
): str,
vol.Required(
CONF_REFRESH_SYSTEM_PROMPT,
description={"suggested_value": options.get(CONF_REFRESH_SYSTEM_PROMPT)},
default=DEFAULT_REFRESH_SYSTEM_PROMPT,
): bool,
}
if is_local_backend(backend_type):
@@ -606,12 +643,7 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
CONF_USE_GBNF_GRAMMAR,
description={"suggested_value": options.get(CONF_USE_GBNF_GRAMMAR)},
default=DEFAULT_USE_GBNF_GRAMMAR,
): bool,
vol.Required(
CONF_REFRESH_SYSTEM_PROMPT,
description={"suggested_value": options.get(CONF_REFRESH_SYSTEM_PROMPT)},
default=DEFAULT_REFRESH_SYSTEM_PROMPT,
): bool,
): bool
})
elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
result = insert_after_key(result, CONF_MAX_TOKENS, {
@@ -620,13 +652,66 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
default=DEFAULT_REQUEST_TIMEOUT,
): int,
vol.Required(
CONF_REMOTE_USE_CHAT_ENDPOINT,
description={"suggested_value": options.get(CONF_REMOTE_USE_CHAT_ENDPOINT)},
default=DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
): bool,
vol.Optional(
CONF_TEXT_GEN_WEBUI_PRESET,
description={"suggested_value": options.get(CONF_TEXT_GEN_WEBUI_PRESET)},
): str,
vol.Required(
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
description={"suggested_value": options.get(CONF_TEXT_GEN_WEBUI_CHAT_MODE)},
default=DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
): SelectSelector(SelectSelectorConfig(
options=[TEXT_GEN_WEBUI_CHAT_MODE_CHAT, TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT],
translation_key=CONF_TEXT_GEN_WEBUI_CHAT_MODE,
multiple=False,
mode=SelectSelectorMode.DROPDOWN,
)),
})
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
result = insert_after_key(result, CONF_MAX_TOKENS, {
vol.Required(
CONF_REQUEST_TIMEOUT,
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
default=DEFAULT_REQUEST_TIMEOUT,
): int,
vol.Required(
CONF_REMOTE_USE_CHAT_ENDPOINT,
description={"suggested_value": options.get(CONF_REMOTE_USE_CHAT_ENDPOINT)},
default=DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
): bool,
vol.Required(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=DEFAULT_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
})
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
result = insert_after_key(result, CONF_MAX_TOKENS, {
vol.Required(
CONF_REQUEST_TIMEOUT,
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
default=DEFAULT_REQUEST_TIMEOUT,
): int,
vol.Required(
CONF_REMOTE_USE_CHAT_ENDPOINT,
description={"suggested_value": options.get(CONF_REMOTE_USE_CHAT_ENDPOINT)},
default=DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
): bool,
vol.Required(
CONF_TOP_K,
description={"suggested_value": options.get(CONF_TOP_K)},
default=DEFAULT_TOP_K,
): NumberSelector(NumberSelectorConfig(min=1, max=256, step=1)),
vol.Required(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE)},
@@ -638,10 +723,10 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
CONF_USE_GBNF_GRAMMAR,
description={"suggested_value": options.get(CONF_USE_GBNF_GRAMMAR)},
default=DEFAULT_USE_GBNF_GRAMMAR,
): bool
})
return result

View File

@@ -24,6 +24,8 @@ BACKEND_TYPE_LLAMA_HF = "llama_cpp_hf"
BACKEND_TYPE_LLAMA_EXISTING = "llama_cpp_existing"
BACKEND_TYPE_TEXT_GEN_WEBUI = "text-generation-webui_api"
BACKEND_TYPE_GENERIC_OPENAI = "generic_openai"
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER = "llama_cpp_python_server"
BACKEND_TYPE_OLLAMA = "ollama"
DEFAULT_BACKEND_TYPE = BACKEND_TYPE_LLAMA_HF
CONF_DOWNLOADED_MODEL_QUANTIZATION = "downloaded_model_quantization"
CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS = ["Q8_0", "Q5_K_M", "Q4_K_M", "Q3_K_M"]
@@ -84,6 +86,13 @@ CONF_REFRESH_SYSTEM_PROMPT = "refresh_prompt_per_tern"
DEFAULT_REFRESH_SYSTEM_PROMPT = True
CONF_SERVICE_CALL_REGEX = "service_call_regex"
DEFAULT_SERVICE_CALL_REGEX = r"```homeassistant\n([\S \t\n]*?)```"
CONF_REMOTE_USE_CHAT_ENDPOINT = "remote_use_chat_endpoint"
DEFAULT_REMOTE_USE_CHAT_ENDPOINT = False
CONF_TEXT_GEN_WEBUI_CHAT_MODE = "text_generation_webui_chat_mode"
TEXT_GEN_WEBUI_CHAT_MODE_CHAT = "chat"
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT = "instruct"
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT = "chat-instruct"
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE = TEXT_GEN_WEBUI_CHAT_MODE_CHAT
DEFAULT_OPTIONS = types.MappingProxyType(
{
@@ -97,6 +106,8 @@ DEFAULT_OPTIONS = types.MappingProxyType(
CONF_USE_GBNF_GRAMMAR: DEFAULT_USE_GBNF_GRAMMAR,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
CONF_TEXT_GEN_WEBUI_CHAT_MODE: DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
}
)

View File

@@ -30,7 +30,10 @@
"huggingface_model": "Model Name",
"port": "API Port",
"openai_api_key": "API Key",
"text_generation_webui_admin_key": "Admin Key"
"text_generation_webui_admin_key": "Admin Key",
"text_generation_webui_preset": "Generation Preset/Character Name",
"remote_use_chat_endpoint": "Use chat completions endpoint",
"text_generation_webui_chat_mode": "Chat Mode"
},
"description": "Provide the connection details for an instance of text-generation-webui that is hosting the model.",
"title": "Configure connection to remote API"
@@ -40,7 +43,7 @@
"download_model_from_hf": "Download model from HuggingFace",
"use_local_backend": "Use Llama.cpp"
},
"description": "Select the backend for running the model. Either Llama.cpp (locally) or text-generation-webui (remote).",
"description": "Select the backend for running the model. The options are:\n1. Llama.cpp with a model from HuggingFace\n2. Llama.cpp with a model stored on the disk\n3. [text-generation-webui API](https://github.com/oobabooga/text-generation-webui)\n4. Generic OpenAI API Compatible API\n5. [llama-cpp-python Server](https://llama-cpp-python.readthedocs.io/en/latest/server/)\n\nIf using Llama.cpp locally, make sure you copied the correct wheel file to the same directory as the integration.",
"title": "Select Backend"
}
}
@@ -62,7 +65,9 @@
"text_generation_webui_admin_key": "Admin Key",
"service_call_regex": "Service Call Regex",
"refresh_prompt_per_tern": "Refresh System Prompt Every Turn",
"text_generation_webui_preset": "Generation Preset Name"
"text_generation_webui_preset": "Generation Preset/Character Name",
"remote_use_chat_endpoint": "Use chat completions endpoint",
"text_generation_webui_chat_mode": "Chat Mode"
}
}
}
@@ -73,6 +78,7 @@
"chatml": "ChatML",
"vicuna": "Vicuna",
"alpaca": "Alpaca",
"mistral": "Mistral",
"no_prompt_template": "None"
}
},
@@ -81,7 +87,17 @@
"llama_cpp_hf": "Llama.cpp (HuggingFace)",
"llama_cpp_existing": "Llama.cpp (existing model)",
"text-generation-webui_api": "text-generation-webui API",
"generic_openai": "Generic OpenAI Compatible API"
"generic_openai": "Generic OpenAI Compatible API",
"llama_cpp_python_server": "llama-cpp-python Server",
"ollama": "Ollama"
}
},
"text_generation_webui_chat_mode": {
"options": {
"chat": "Chat",
"instruct": "Instruct",
"chat-instruct": "Chat-Instruct"
}
}
}

View File

@@ -0,0 +1,39 @@
# Backend Configuration
There are multiple backends to choose from for running the model that the Home Assistant integration uses. Here is a description of the options for each backend.
## Common Options
| Option Name | Description | Suggested Value |
| ------------ | --------- | ------------ |
| System Prompt | [see here](./Model%20Prompting.md) | |
| Prompt Format | The format for the context of the model | |
| Maximum tokens to return in response | Limits the number of tokens that can be produced by each model response | 512 |
| Additional attributes to expose in the context | Extra attributes that will be exposed to the model via the `{{ devices }}` template variable | |
| Service Call Regex | The regular expression used to extract service calls from the model response; should contain 1 repeated capture group | |
| Refresh System Prompt Every Turn | Flag to update the system prompt with updated device states on every chat turn. Disabling can significantly improve agent response times when using a backend that supports prefix caching (Llama.cpp) | Enabled |
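As a quick illustration of the Service Call Regex option, the snippet below uses the default pattern from `const.py` in this release to pull the service call block out of a model response; the sample response text is made up.

```python
import re

# Default value of CONF_SERVICE_CALL_REGEX in const.py: captures the body of
# each ```homeassistant``` code block in the model's response.
SERVICE_CALL_REGEX = r"```homeassistant\n([\S \t\n]*?)```"

model_response = (
    "Turning on the kitchen light for you.\n"
    "```homeassistant\n"
    '{"service": "light.turn_on", "target_device": "light.kitchen"}\n'
    "```"
)

# Each captured group is one service call payload emitted by the model
for service_call in re.findall(SERVICE_CALL_REGEX, model_response):
    print(service_call.strip())
```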
## Llama.cpp
For details about the sampling parameters, see here: https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab#parameters-description
| Option Name | Description | Suggested Value |
| ------------ | --------- | ------------ |
| Top K | Sampling parameter; see above link | 40 |
| Top P | Sampling parameter; see above link | 1.0 |
| Temperature | Sampling parameter; see above link | 0.1 |
| Enable GBNF Grammar | Restricts the output of the model to follow a pre-defined syntax; eliminates function calling syntax errors on quantized models | Enabled |
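For reference, the sampling options above map onto the llama-cpp-python calls used by the local backend; below is a condensed sketch adapted from `LocalLLaMAAgent._generate` in this release. The model path, prompt, grammar file name, and option values are placeholders.

```python
# Condensed sketch of how the local Llama.cpp backend applies these options.
from llama_cpp import Llama, LlamaGrammar

llm = Llama(model_path="/media/models/home-llm.Q4_K_M.gguf", n_ctx=2048, n_batch=2048)
grammar = LlamaGrammar.from_string(open("output.gbnf").read())  # "Enable GBNF Grammar"

prompt = "<|im_start|>user\nturn on the kitchen light<|im_end|>\n<|im_start|>assistant\n"
input_tokens = llm.tokenize(prompt.encode(), add_bos=False)

result_tokens = []
for token in llm.generate(input_tokens, temp=0.1, top_k=40, top_p=1.0, grammar=grammar):
    if token == llm.token_eos():
        break
    result_tokens.append(token)
    if len(result_tokens) >= 512:  # "Maximum tokens to return in response"
        break

print(llm.detokenize(result_tokens).decode())
```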
## text-generation-webui
| Option Name | Description | Suggested Value |
| ------------ | --------- | ------------ |
| Request Timeout | The maximum time in seconds that the integration will wait for a response from the remote server | 90 (higher if running on low resource hardware) |
| Use chat completions endpoint | Flag to use `/v1/chat/completions` as the remote endpoint instead of `/v1/completions` | |
| Generation Preset/Character Name | The preset or character name to pass to the backend. If none is provided then the settings that are currently selected in the UI will be applied | |
| Chat Mode | [see here](https://github.com/oobabooga/text-generation-webui/wiki/01-%E2%80%90-Chat-Tab#mode) | Instruct |
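To make the options above concrete, here is roughly the request the integration builds for text-generation-webui when "Use chat completions endpoint" is enabled with Chat Mode set to "chat", based on `TextGenerationWebuiAgent` in this commit. The host, API key, model name, and message contents are placeholders.

```python
import requests

request_params = {
    "model": "Home-3B-v2.q4_k_m.gguf",   # placeholder model name
    "max_tokens": 512,
    "temperature": 0.1,
    "top_p": 1.0,
    "messages": [
        {"role": "system", "content": "<rendered system prompt with exposed devices>"},
        {"role": "user", "content": "turn on the kitchen light"},
    ],
    "mode": "chat",                      # Chat Mode option
    "character": "Assistant",            # Generation Preset/Character Name option
}
result = requests.post(
    "http://192.168.1.10:5000/v1/chat/completions",
    json=request_params,
    headers={"Authorization": "Bearer <your api key>"},
    timeout=90,                          # Request Timeout option
)
result.raise_for_status()
print(result.json()["choices"][0]["message"]["content"])
```

In Instruct mode, the `character` field is replaced by `instruction_template`, which is set from the Prompt Format option.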
## Generic OpenAI API Compatible
For details about the sampling parameters, see here: https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab#parameters-description
| Option Name | Description | Suggested Value |
| ------------ | --------- | ------------ |
| Request Timeout | The maximum time in seconds that the integration will wait for a response from the remote server | 90 (higher if running on low resource hardware) |
| Use chat completions endpoint | Flag to use `/v1/chat/completions` as the remote endpoint instead of `/v1/completions` | Backend Dependent |
| Top P | Sampling parameter; see above link | 1.0 |
| Temperature | Sampling parameter; see above link | 0.1 |
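The "Use chat completions endpoint" flag changes the shape of the response the integration has to parse; the sketch below mirrors `GenericOpenAIAPIAgent._extract_response` from this commit (the sample payload is made up).

```python
# How a response from an OpenAI-compatible backend is read: chat completions
# return the text under choices[0]["message"]["content"], while plain
# completions return it under choices[0]["text"].
def extract_response(response_json: dict) -> str:
    choices = response_json["choices"]
    if choices[0]["finish_reason"] != "stop":
        print("warning: response did not end on a stop token (unfinished sentence)")
    if response_json["object"] == "chat.completion":  # chat completions endpoint
        return choices[0]["message"]["content"]
    return choices[0]["text"]                         # plain completions endpoint

print(extract_response({
    "object": "text_completion",
    "choices": [{"finish_reason": "stop", "text": "Turning on the kitchen light."}],
}))
```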