fix auth header (use Bearer instead of Basic), add llama-cpp-python server backend, better in-app docs

Alex O'Connell
2024-01-21 21:38:23 -05:00
parent 7c30bb57cf
commit 1594844962
4 changed files with 79 additions and 18 deletions


@@ -73,6 +73,8 @@ from .const import (
BACKEND_TYPE_LLAMA_EXISTING,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
BACKEND_TYPE_OLLAMA,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
@@ -103,12 +105,18 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
entry.options[item] = value
def create_agent(backend_type):
agent_cls = None
if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
return LocalLLaMAAgent(hass, entry)
agent_cls = LocalLLaMAAgent
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
return GenericOpenAIAPIAgent(hass, entry)
agent_cls = GenericOpenAIAPIAgent
elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
return TextGenerationWebuiAgent(hass, entry)
agent_cls = TextGenerationWebuiAgent
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
agent_cls = LlamaCppPythonAPIAgent
return agent_cls(hass, entry)
# load the model in an executor job because it takes a while and locks up the UI otherwise
agent = await hass.async_add_executor_job(create_agent, entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE))
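
The refactor above replaces the per-backend early returns with a single agent_cls lookup, so the constructor call happens in exactly one place and adding a backend (like the still-commented-out Ollama one) only needs a new branch. A minimal standalone sketch of the same dispatch, using the backend constants from const.py and stub classes in place of the real agents:

# Sketch of the backend-to-agent dispatch; the agent classes here are stubs
# standing in for the real ones defined in the integration.
BACKEND_TYPE_LLAMA_HF = "llama_cpp_hf"
BACKEND_TYPE_LLAMA_EXISTING = "llama_cpp_existing"
BACKEND_TYPE_TEXT_GEN_WEBUI = "text-generation-webui_api"
BACKEND_TYPE_GENERIC_OPENAI = "generic_openai"
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER = "llama_cpp_python_server"

class LocalLLaMAAgent: ...
class GenericOpenAIAPIAgent: ...
class TextGenerationWebuiAgent: ...
class LlamaCppPythonAPIAgent: ...

def resolve_agent_class(backend_type: str) -> type:
    """Map a backend type string to the agent class that should handle it."""
    if backend_type in (BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING):
        return LocalLLaMAAgent
    if backend_type == BACKEND_TYPE_GENERIC_OPENAI:
        return GenericOpenAIAPIAgent
    if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
        return TextGenerationWebuiAgent
    if backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
        return LlamaCppPythonAPIAgent
    raise ValueError(f"unknown backend type: {backend_type}")

Constructing the agent inside hass.async_add_executor_job keeps the potentially slow model load off the event loop, as the comment in the hunk notes.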
@@ -333,7 +341,7 @@ class LLaMAAgent(AbstractConversationAgent):
entity_states[state.entity_id] = attributes
domains.add(state.domain)
# _LOGGER.debug(f"Exposed entities: {entity_states}")
_LOGGER.debug(f"Exposed entities: {entity_states}")
return entity_states, list(domains)
@@ -356,6 +364,7 @@ class LLaMAAgent(AbstractConversationAgent):
if include_generation_prompt:
formatted_prompt = formatted_prompt + template_desc["generation_prompt"]
# _LOGGER.debug(formatted_prompt)
return formatted_prompt
def _generate_system_prompt(self, prompt_template: str) -> str:
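
For context on what template_desc["generation_prompt"] does in the formatting code above: the prompt templates end each formatted conversation with an opening assistant turn so the model continues from there. A hedged sketch using the ChatML template (one of the PROMPT_TEMPLATE_* options this integration imports); the exact keys of the template dictionary are an assumption based on how they are used in this hunk:

# Hedged sketch of ChatML-style prompt assembly with an optional generation
# prompt; the real template_desc structure in the integration may differ.
CHATML_TEMPLATE = {
    "system": "<|im_start|>system\n{message}<|im_end|>\n",
    "user": "<|im_start|>user\n{message}<|im_end|>\n",
    "assistant": "<|im_start|>assistant\n{message}<|im_end|>\n",
    "generation_prompt": "<|im_start|>assistant\n",
}

def format_prompt(messages: list[dict], include_generation_prompt: bool = True) -> str:
    """Render a list of {"role": ..., "message": ...} dicts into one prompt string."""
    formatted_prompt = "".join(
        CHATML_TEMPLATE[m["role"]].format(message=m["message"]) for m in messages
    )
    if include_generation_prompt:
        # Leave the assistant turn open so generation picks up from here.
        formatted_prompt += CHATML_TEMPLATE["generation_prompt"]
    return formatted_prompt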
@@ -497,7 +506,7 @@ class GenericOpenAIAPIAgent(LLaMAAgent):
# TODO: https
self.api_host = f"http://{entry.data[CONF_HOST]}:{entry.data[CONF_PORT]}"
self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
self.model_name = entry.data.get(CONF_CHAT_MODEL, self.model_path)
self.model_name = entry.data.get(CONF_CHAT_MODEL)
def _chat_completion_params(self, conversation: dict) -> (str, dict):
@@ -550,7 +559,7 @@ class GenericOpenAIAPIAgent(LLaMAAgent):
headers = {}
if self.api_key:
headers["Authorization"] = f"Basic {self.api_key}"
headers["Authorization"] = f"Bearer {self.api_key}"
result = requests.post(
f"{self.api_host}{endpoint}",
@@ -591,7 +600,7 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
headers = {}
if self.admin_key:
headers["Authorization"] = f"Basic {self.admin_key}"
headers["Authorization"] = f"Bearer {self.admin_key}"
load_result = requests.post(
f"{self.api_host}/v1/internal/model/load",
@@ -604,7 +613,8 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
load_result.raise_for_status()
except Exception as ex:
_LOGGER.error("Connection error was: %s", repr(ex))
_LOGGER.debug("Connection error was: %s", repr(ex))
raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex
def _chat_completion_params(self, conversation: dict) -> (str, dict):
preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET)
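
Two things change in the error path above: the log level drops to debug (the user-facing signal is now the raised exception rather than a noisy error log), and the exception becomes ConfigEntryNotReady, which tells Home Assistant to retry setting up the entry later instead of failing it permanently. A sketch of the pattern; the payload shape for the model-load endpoint is an assumption:

import logging
import requests
from homeassistant.exceptions import ConfigEntryNotReady

_LOGGER = logging.getLogger(__name__)

def load_remote_model(api_host: str, model_name: str, headers: dict, timeout: float = 90.0) -> None:
    """Ask text-generation-webui to load a model; defer setup if it is unreachable."""
    try:
        load_result = requests.post(
            f"{api_host}/v1/internal/model/load",
            json={"model_name": model_name},  # assumed payload shape
            headers=headers,
            timeout=timeout,
        )
        load_result.raise_for_status()
    except Exception as ex:
        _LOGGER.debug("Connection error was: %s", repr(ex))
        # Home Assistant schedules a retry of async_setup_entry when this is raised.
        raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex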
@@ -654,7 +664,7 @@ class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
self.grammar = "".join(f.readlines())
def _chat_completion_params(self, conversation: dict) -> (str, dict):
top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
endpoint, request_params = super()._chat_completion_params(conversation)
request_params["top_k"] = top_k
@@ -665,7 +675,7 @@ class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
return endpoint, request_params
def _completion_params(self, conversation: dict) -> (str, dict):
top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
endpoint, request_params = super()._completion_params(conversation)
request_params["top_k"] = top_k


@@ -82,6 +82,8 @@ from .const import (
BACKEND_TYPE_LLAMA_EXISTING,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
BACKEND_TYPE_OLLAMA,
PROMPT_TEMPLATE_CHATML,
PROMPT_TEMPLATE_ALPACA,
PROMPT_TEMPLATE_VICUNA,
@@ -105,10 +107,16 @@ def STEP_INIT_DATA_SCHEMA(backend_type=None):
CONF_BACKEND_TYPE,
default=backend_type if backend_type else DEFAULT_BACKEND_TYPE
): SelectSelector(SelectSelectorConfig(
options=[ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING, BACKEND_TYPE_TEXT_GEN_WEBUI, BACKEND_TYPE_GENERIC_OPENAI ],
options=[
BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
# BACKEND_TYPE_OLLAMA
],
translation_key=CONF_BACKEND_TYPE,
multiple=False,
mode=SelectSelectorMode.LIST,
mode=SelectSelectorMode.DROPDOWN,
))
}
)
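
With five selectable backends (and Ollama waiting, commented out), the selector also switches from the radio-style LIST mode to DROPDOWN. A hedged, self-contained sketch of the same selector construction with Home Assistant's selector helpers; the config key name here is hypothetical:

import voluptuous as vol
from homeassistant.helpers.selector import (
    SelectSelector,
    SelectSelectorConfig,
    SelectSelectorMode,
)

BACKEND_OPTIONS = [
    "llama_cpp_hf",
    "llama_cpp_existing",
    "text-generation-webui_api",
    "generic_openai",
    "llama_cpp_python_server",
    # "ollama",  # constant exists but the backend is not selectable yet
]

BACKEND_SCHEMA = vol.Schema(
    {
        # "backend_type" is a stand-in for the integration's CONF_BACKEND_TYPE key.
        vol.Required("backend_type", default="llama_cpp_hf"): SelectSelector(
            SelectSelectorConfig(
                options=BACKEND_OPTIONS,
                translation_key="backend_type",
                multiple=False,
                mode=SelectSelectorMode.DROPDOWN,  # long lists read better as a dropdown
            )
        )
    }
)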
@@ -128,10 +136,12 @@ def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_q
}
)
def STEP_REMOTE_SETUP_DATA_SCHEMA(text_gen_webui: bool, *, host=None, port=None, chat_model=None, use_chat_endpoint=None, webui_preset=None, webui_chat_mode=None):
def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, chat_model=None, use_chat_endpoint=None, webui_preset=None, webui_chat_mode=None):
extra1, extra2 = ({}, {})
if text_gen_webui:
default_port = DEFAULT_PORT
if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
extra1[vol.Optional(CONF_TEXT_GEN_WEBUI_PRESET, default=webui_preset)] = str
extra1[vol.Optional(CONF_TEXT_GEN_WEBUI_CHAT_MODE, default=webui_chat_mode)] = SelectSelector(SelectSelectorConfig(
options=[TEXT_GEN_WEBUI_CHAT_MODE_CHAT, TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT],
@@ -141,10 +151,13 @@ def STEP_REMOTE_SETUP_DATA_SCHEMA(text_gen_webui: bool, *, host=None, port=None,
))
extra2[vol.Optional(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)] = TextSelector(TextSelectorConfig(type="password"))
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
default_port = "8000"
return vol.Schema(
{
vol.Required(CONF_HOST, default=host if host else DEFAULT_HOST): str,
vol.Required(CONF_PORT, default=port if port else DEFAULT_PORT): str,
vol.Required(CONF_PORT, default=port if port else default_port): str,
vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): str,
vol.Required(CONF_REMOTE_USE_CHAT_ENDPOINT, default=use_chat_endpoint if use_chat_endpoint else DEFAULT_REMOTE_USE_CHAT_ENDPOINT): bool,
**extra1,
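
Passing the backend type into the schema builder (instead of a text_gen_webui boolean) lets it choose a sensible default port per backend: the llama-cpp-python server listens on 8000 by default, while the previous DEFAULT_PORT suits text-generation-webui and other OpenAI-compatible setups. A reduced sketch of the idea; the host/port constants are placeholders:

import voluptuous as vol

DEFAULT_HOST = "127.0.0.1"  # placeholder
DEFAULT_PORT = "5000"       # placeholder

def remote_setup_schema(backend_type: str, host: str | None = None, port: str | None = None) -> vol.Schema:
    """Build the remote-backend setup form, defaulting the port per backend."""
    default_port = DEFAULT_PORT
    if backend_type == "llama_cpp_python_server":
        default_port = "8000"  # llama-cpp-python server's default port
    return vol.Schema(
        {
            vol.Required("host", default=host if host else DEFAULT_HOST): str,
            vol.Required("port", default=port if port else default_port): str,
        }
    )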
@@ -447,7 +460,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
headers = {}
api_key = user_input.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, user_input.get(CONF_OPENAI_API_KEY))
if api_key:
headers["Authorization"] = f"Basic {api_key}"
headers["Authorization"] = f"Bearer {api_key}"
models_result = requests.get(
f"http://{self.model_config[CONF_HOST]}:{self.model_config[CONF_PORT]}/v1/internal/model/list",
@@ -682,5 +695,38 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
})
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
result = insert_after_key(result, CONF_MAX_TOKENS, {
vol.Required(
CONF_REQUEST_TIMEOUT,
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
default=DEFAULT_REQUEST_TIMEOUT,
): int,
vol.Required(
CONF_REMOTE_USE_CHAT_ENDPOINT,
description={"suggested_value": options.get(CONF_REMOTE_USE_CHAT_ENDPOINT)},
default=DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
): bool,
vol.Required(
CONF_TOP_K,
description={"suggested_value": options.get(CONF_TOP_K)},
default=DEFAULT_TOP_K,
): NumberSelector(NumberSelectorConfig(min=1, max=256, step=1)),
vol.Required(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=DEFAULT_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_USE_GBNF_GRAMMAR,
description={"suggested_value": options.get(CONF_USE_GBNF_GRAMMAR)},
default=DEFAULT_USE_GBNF_GRAMMAR,
): bool
})
return result
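
insert_after_key is the integration's own helper (not shown in this diff) that splices the backend-specific fields into the options form right after CONF_MAX_TOKENS so related settings stay grouped. Its implementation is not visible here; a plausible sketch of what such a helper could do, offered purely as an assumption:

# Hypothetical sketch of an insert_after_key-style helper; the real one in the
# integration may differ. It rebuilds the schema dict in order (dicts preserve
# insertion order in Python 3.7+), splicing the new fields in after the anchor.
def insert_after_key(schema: dict, anchor_key, new_items: dict) -> dict:
    result = {}
    for key, value in schema.items():
        result[key] = value
        # vol.Required / vol.Optional markers compare equal to the key they wrap,
        # so this matches whether the anchor is given as a string or a marker.
        if key == anchor_key:
            result.update(new_items)
    return result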


@@ -24,6 +24,8 @@ BACKEND_TYPE_LLAMA_HF = "llama_cpp_hf"
BACKEND_TYPE_LLAMA_EXISTING = "llama_cpp_existing"
BACKEND_TYPE_TEXT_GEN_WEBUI = "text-generation-webui_api"
BACKEND_TYPE_GENERIC_OPENAI = "generic_openai"
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER = "llama_cpp_python_server"
BACKEND_TYPE_OLLAMA = "ollama"
DEFAULT_BACKEND_TYPE = BACKEND_TYPE_LLAMA_HF
CONF_DOWNLOADED_MODEL_QUANTIZATION = "downloaded_model_quantization"
CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS = ["Q8_0", "Q5_K_M", "Q4_K_M", "Q3_K_M"]


@@ -43,7 +43,7 @@
"download_model_from_hf": "Download model from HuggingFace",
"use_local_backend": "Use Llama.cpp"
},
"description": "Select the backend for running the model. Either Llama.cpp (locally) or text-generation-webui (remote).",
"description": "Select the backend for running the model. The options are:\n1. Llama.cpp with a model from HuggingFace\n2. Llama.cpp with a model stored on the disk\n3. [text-generation-webui API](https://github.com/oobabooga/text-generation-webui)\n4. Generic OpenAI API Compatible API\n5. [llama-cpp-python Server](https://llama-cpp-python.readthedocs.io/en/latest/server/)\n\nIf using Llama.cpp locally, make sure you copied the correct wheel file to the same directory as the integration.",
"title": "Select Backend"
}
}
@@ -87,7 +87,10 @@
"llama_cpp_hf": "Llama.cpp (HuggingFace)",
"llama_cpp_existing": "Llama.cpp (existing model)",
"text-generation-webui_api": "text-generation-webui API",
"generic_openai": "Generic OpenAI Compatible API"
"generic_openai": "Generic OpenAI Compatible API",
"llama_cpp_python_server": "llama-cpp-python Server",
"ollama": "Ollama"
}
},
"text_generation_webui_chat_mode": {