Mirror of https://github.com/acon96/home-llm.git (synced 2026-01-09 13:48:05 -05:00)
Merge branch 'release/v0.2.3'
@@ -103,8 +103,9 @@ When setting up the component, there are 4 different "backend" options to choose
1. Llama.cpp with a model from HuggingFace
2. Llama.cpp with a locally provided model
3. A remote instance of text-generation-webui
4. A generic OpenAI API compatible interface
    - *should* be compatible with: LocalAI, LM Studio, and all other OpenAI compatible backends
4. A generic OpenAI API compatible interface; *should* be compatible with LocalAI, LM Studio, and all other OpenAI compatible backends

See [docs/Backend Configuration.md](/docs/Backend%20Configuration.md) for more info.

**Installing llama-cpp-python for local model usage**:
In order to run a model directly as part of your Home Assistant installation, you will need to install one of the pre-built wheels because there are no existing musllinux wheels for the package. Compatible wheels for x86_64 and arm64 are provided in the [dist](./dist) folder. Copy the `*.whl` files to the `custom_components/llama_conversation/` folder. They will be installed while setting up the component.
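Once the wheel is in place, the integration imports `llama_cpp` lazily via `importlib` (see the component diff below). A minimal sketch, assuming the wheel was installed into the same Python environment that runs Home Assistant, to confirm the import works:

```python
# Quick check that the llama-cpp-python wheel is importable from the environment
# Home Assistant runs in. Run it with Home Assistant's Python interpreter.
import importlib

try:
    llama_cpp = importlib.import_module("llama_cpp")
    print("llama-cpp-python version:", getattr(llama_cpp, "__version__", "unknown"))
except ImportError as err:
    print("llama_cpp is not installed in this environment:", err)
```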
@@ -170,6 +171,7 @@ It is highly recommend to set up text-generation-webui on a separate machine tha

## Version History

| Version | Description |
| ------- | ----------- |
| v0.2.3  | Fix API key auth, support the chat completion endpoint, and refactor to make it easier to add more remote backends |
| v0.2.2  | Fix options window after upgrade, fix training script for new Phi model format, and release new models |
| v0.2.1  | Properly expose generation parameters for each backend, handle config entry updates without reloading, support remote backends with an API key |
| v0.2    | Bug fixes, support more backends, support for climate + switch devices, JSON style function calling with parameters, GBNF grammars |
TODO.md (43 lines changed)
@@ -1,31 +1,32 @@
# TODO
[x] ChatML format (actually need to add special tokens)
[x] Vicuna dataset merge (yahma/alpaca-cleaned)
[x] Phi-2 fine tuning
[x] Quantize w/ llama.cpp
[x] Make custom component use llama.cpp + ChatML
[x] Continued synthetic dataset improvements (there are a bunch of TODOs in there)
[x] Licenses + Attributions
[x] Finish Readme/docs for initial release
[x] Function calling as JSON
[ ] multi-turn prompts; better instruct dataset like dolphin/wizardlm?
[x] Fine tune Phi-1.5 version
[ ] "context requests"
- [x] ChatML format (actually need to add special tokens)
- [x] Vicuna dataset merge (yahma/alpaca-cleaned)
- [x] Phi-2 fine tuning
- [x] Quantize w/ llama.cpp
- [x] Make custom component use llama.cpp + ChatML
- [x] Continued synthetic dataset improvements (there are a bunch of TODOs in there)
- [x] Licenses + Attributions
- [x] Finish Readme/docs for initial release
- [x] Function calling as JSON
- [ ] multi-turn prompts; better instruct dataset like dolphin/wizardlm?
- [x] Fine tune Phi-1.5 version
- [ ] "context requests" (a rough sketch of this idea follows after this list)
    - basically just let the model decide what RAG/extra context it wants
    - the model predicts special tokens as the first few tokens of its output
    - the requested content is added to the context after the request tokens and then generation continues
    - needs more complicated training b/c multi-turn + there will be some weird masking going on for training the responses properly
[ ] RAG for getting info for setting up new devices
- [ ] RAG for getting info for setting up new devices
    - set up vectordb
    - ingest home assistant docs
    - "context request" from above to initiate a RAG search
[x] make llama-cpp-python wheels for "llama-cpp-python>=0.2.24"
[ ] prime kv cache with current "state" so that requests are faster
[ ] make a proper evaluation framework to run. not just loss. should test accuracy on the function calling
[ ] add more remote backends
- [x] make llama-cpp-python wheels for "llama-cpp-python>=0.2.24"
- [ ] prime kv cache with current "state" so that requests are faster
- [ ] make a proper evaluation framework to run. not just loss. should test accuracy on the function calling
- [ ] add more remote backends
    - LocalAI (openai compatible)
    - Ollama
[x] more config options for prompt template (allow other than chatml)
[ ] publish snapshot of dataset on HF
[ ] figure out DPO for refusals + fixing incorrect entity id
[ ] mixtral + prompting (no fine tuning)
- support chat completions API (might fix Ollama + adds support for text-gen-ui characters)
- [x] more config options for prompt template (allow other than chatml)
- [ ] publish snapshot of dataset on HF
- [ ] figure out DPO for refusals + fixing incorrect entity id
- [ ] mixtral + prompting (no fine tuning)

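The "context requests" item above is only an idea so far; a rough, purely hypothetical sketch of the flow it describes (none of these names exist in the repository — the model emits a request marker first, the requested context is appended, and generation resumes):

```python
# Hypothetical sketch of the "context requests" TODO item above; nothing here
# exists in the repo. Flow: generate, check for a special request marker,
# splice in the requested context, then re-generate with the expanded prompt.
REQUEST_MARKERS = {
    "<|request_weather|>": lambda: "weather: sunny, 22C",
    "<|request_calendar|>": lambda: "calendar: no events today",
}

def generate(prompt: str) -> str:
    """Stand-in for the real model call."""
    return "<|request_weather|> " if "context:" not in prompt else "It is sunny and 22C."

def generate_with_context_requests(prompt: str) -> str:
    draft = generate(prompt)
    for marker, fetch in REQUEST_MARKERS.items():
        if draft.startswith(marker):
            # Append the requested context after the request token and continue.
            expanded = f"{prompt}\ncontext: {fetch()}\n"
            return generate(expanded)
    return draft

print(generate_with_context_requests("user: what's the weather?"))
```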
@@ -3,7 +3,7 @@ from __future__ import annotations

import logging
import importlib
from typing import Literal
from typing import Literal, Any

import requests
import re
@@ -52,6 +52,8 @@ from .const import (
    CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
    CONF_REFRESH_SYSTEM_PROMPT,
    CONF_SERVICE_CALL_REGEX,
    CONF_REMOTE_USE_CHAT_ENDPOINT,
    CONF_TEXT_GEN_WEBUI_CHAT_MODE,
    DEFAULT_MAX_TOKENS,
    DEFAULT_PROMPT,
    DEFAULT_TEMPERATURE,
@@ -64,9 +66,18 @@ from .const import (
    DEFAULT_USE_GBNF_GRAMMAR,
    DEFAULT_REFRESH_SYSTEM_PROMPT,
    DEFAULT_SERVICE_CALL_REGEX,
    DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
    DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
    DEFAULT_OPTIONS,
    BACKEND_TYPE_LLAMA_HF,
    BACKEND_TYPE_LLAMA_EXISTING,
    BACKEND_TYPE_TEXT_GEN_WEBUI,
    BACKEND_TYPE_GENERIC_OPENAI,
    BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
    BACKEND_TYPE_OLLAMA,
    TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
    TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
    TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
    DOMAIN,
    GBNF_GRAMMAR_FILE,
    PROMPT_TEMPLATE_DESCRIPTIONS,
@@ -76,9 +87,6 @@ _LOGGER = logging.getLogger(__name__)

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)

def is_local_backend(backend):
    return backend not in [BACKEND_TYPE_TEXT_GEN_WEBUI, BACKEND_TYPE_GENERIC_OPENAI]

async def update_listener(hass, entry):
    """Handle options update."""
    hass.data[DOMAIN][entry.entry_id] = entry
@@ -87,20 +95,31 @@ async def update_listener(hass, entry):
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Set up Local LLaMA Conversation from a config entry."""

    use_local_backend = is_local_backend(
        entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
    )
    # TODO: figure out how to make this happen as part of the config flow. when I tried it errored out passing options in
    if len(entry.options) == 0:
        entry.options = { **DEFAULT_OPTIONS }
        copy_to_options = [ CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET ]
        for item in copy_to_options:
            value = entry.data.get(item)
            if value:
                entry.options[item] = value

    if use_local_backend:
        _LOGGER.info(
            "Using model file '%s'", entry.data.get(CONF_DOWNLOADED_MODEL_FILE)
        )
    def create_agent(backend_type):
        agent_cls = None

    def create_agent():
        return LLaMAAgent(hass, entry)
        if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
            agent_cls = LocalLLaMAAgent
        elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
            agent_cls = GenericOpenAIAPIAgent
        elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
            agent_cls = TextGenerationWebuiAgent
        elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
            agent_cls = LlamaCppPythonAPIAgent

        return agent_cls(hass, entry)

    # load the model in an executor job because it takes a while and locks up the UI otherwise
    agent = await hass.async_add_executor_job(create_agent)
    agent = await hass.async_add_executor_job(create_agent, entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE))

    # handle updates to the options
    entry.async_on_unload(entry.add_update_listener(update_listener))
@@ -151,35 +170,21 @@ def closest_color(requested_color):
class LLaMAAgent(AbstractConversationAgent):
    """Local LLaMA conversation agent."""

    hass: Any
    entry_id: str
    history: dict[str, list[dict]]

    def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
        """Initialize the agent."""
        self.hass = hass
        self.entry_id = entry.entry_id
        self.history: dict[str, list[dict]] = {}
        self.history = {}

        self.backend_type = entry.data.get(
            CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE
        )
        self.use_local_backend = is_local_backend(self.backend_type)

        self.api_host = None
        self.llm = None
        self.grammar = None

        self.model_path = entry.data.get(CONF_DOWNLOADED_MODEL_FILE)
        self.model_name = entry.data.get(CONF_CHAT_MODEL, self.model_path)

        if self.use_local_backend:
            self._load_local_model()
        else:
            host = entry.data[CONF_HOST]
            port = entry.data[CONF_PORT]
            self.api_host = f"http://{host}:{port}"

            # only load model if using text-generation-webui
            if self.backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
                api_key = entry.data.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, entry.data.get(CONF_OPENAI_API_KEY))
                self._load_remote_model(api_key)
        self._load_model(entry)

    @property
    def entry(self):
@@ -189,6 +194,17 @@ class LLaMAAgent(AbstractConversationAgent):
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return a list of supported languages."""
        return MATCH_ALL

    def _load_model(self, entry: ConfigEntry) -> None:
        raise NotImplementedError()

    def _generate(self, conversation: dict) -> str:
        raise NotImplementedError()

    async def _async_generate(self, conversation: dict) -> str:
        return await self.hass.async_add_executor_job(
            self._generate, conversation
        )

    async def async_process(
        self, user_input: ConversationInput
@@ -222,7 +238,7 @@ class LLaMAAgent(AbstractConversationAgent):

        if len(conversation) == 0 or refresh_system_prompt:
            try:
                message = self._async_generate_system_prompt(raw_prompt)
                message = self._generate_system_prompt(raw_prompt)
            except TemplateError as err:
                _LOGGER.error("Error rendering prompt: %s", err)
                intent_response = intent.IntentResponse(language=user_input.language)
@@ -243,13 +259,9 @@ class LLaMAAgent(AbstractConversationAgent):

        conversation.append({"role": "user", "message": user_input.text})

        # _LOGGER.debug("Prompt: %s", prompt)

        try:
            prompt = await self._async_format_prompt(conversation)

            _LOGGER.debug(prompt)
            response = await self._async_generate(prompt)
            _LOGGER.debug(conversation)
            response = await self._async_generate(conversation)
            _LOGGER.debug(response)

        except Exception as err:
@@ -316,156 +328,6 @@ class LLaMAAgent(AbstractConversationAgent):
            response=intent_response, conversation_id=conversation_id
        )

    def _load_remote_model(self, admin_key: str | None):
        try:
            currently_loaded_result = requests.get(f"{self.api_host}/v1/internal/model/info")
            currently_loaded_result.raise_for_status()

            loaded_model = currently_loaded_result.json()["model_name"]
            if loaded_model == self.model_name:
                _LOGGER.info(f"Model {self.model_name} is already loaded on the remote backend.")
            else:
                _LOGGER.info(f"Model is not {self.model_name} loaded on the remote backend. Loading it now...")

                headers = {}
                if admin_key:
                    headers["Authorization"] = f"Basic {admin_key}"

                load_result = requests.post(
                    f"{self.api_host}/v1/internal/model/load",
                    json={
                        "model_name": self.model_name,
                        # TODO: expose arguments to the user in home assistant UI
                        # "args": {},
                    }
                )
                load_result.raise_for_status()

        except Exception as ex:
            _LOGGER.error("Connection error was: %s", repr(ex))

    def _generate_remote(self, prompt: str) -> str:
        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
        timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)

        request_params = {
            "prompt": prompt,
            "model": self.model_name,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        headers = {}
        api_key = self.entry.data.get(CONF_OPENAI_API_KEY)
        if api_key:
            headers["Authorization"] = f"Basic {api_key}"

        if self.backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
            preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET)
            if preset:
                request_params["preset"] = preset

        result = requests.post(
            f"{self.api_host}/v1/completions",
            json=request_params,
            timeout=timeout,
            headers=headers,
        )

        try:
            result.raise_for_status()
        except requests.RequestException as err:
            _LOGGER.debug(f"Err was: {err}")
            _LOGGER.debug(f"Request was: {request_params}")
            _LOGGER.debug(f"Result was: {result.text}")
            return f"Failed to communicate with the API! {err}"

        choices = result.json()["choices"]

        if choices[0]["finish_reason"] != "stop":
            _LOGGER.warn("Model response did not end on a stop token (unfinished sentence)")

        return choices[0]["text"]

    def _load_local_model(self):
        if not self.model_path:
            raise Exception(f"Model was not found at '{self.model_path}'!")

        # don't import it until now because the wheel is installed by config_flow.py
        module = importlib.import_module("llama_cpp")
        Llama = getattr(module, "Llama")
        LlamaGrammar = getattr(module, "LlamaGrammar")

        _LOGGER.debug("Loading model...")
        self.llm = Llama(
            model_path=self.model_path,
            n_ctx=2048,
            n_batch=2048,
            # TODO: expose arguments to the user in home assistant UI
            # n_threads=16,
            # n_threads_batch=4,
        )

        _LOGGER.debug("Loading grammar...")
        try:
            # TODO: make grammar configurable
            with open(os.path.join(os.path.dirname(__file__), GBNF_GRAMMAR_FILE)) as f:
                grammar_str = "".join(f.readlines())
            self.grammar = LlamaGrammar.from_string(grammar_str)
            _LOGGER.debug("Loaded grammar")
        except Exception:
            _LOGGER.exception("Failed to load grammar!")
            self.grammar = None


    def _generate_local(self, prompt: str) -> str:
        input_tokens = self.llm.tokenize(
            prompt.encode(), add_bos=False
        )

        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
        top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
        grammar = self.grammar if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None

        _LOGGER.debug(f"Options: {self.entry.options}")

        _LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")
        output_tokens = self.llm.generate(
            input_tokens,
            temp=temperature,
            top_k=top_k,
            top_p=top_p,
            grammar=grammar
        )

        result_tokens = []
        for token in output_tokens:
            if token == self.llm.token_eos():
                break

            result_tokens.append(token)

            if len(result_tokens) >= max_tokens:
                break

        result = self.llm.detokenize(result_tokens).decode()

        return result

    async def _async_generate(self, prompt: str) -> str:
        if self.use_local_backend:
            return await self.hass.async_add_executor_job(
                self._generate_local, prompt
            )

        else:
            return await self.hass.async_add_executor_job(
                self._generate_remote, prompt
            )

    def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
        """Gather exposed entity states"""
        entity_states = {}
@@ -479,11 +341,11 @@ class LLaMAAgent(AbstractConversationAgent):
            entity_states[state.entity_id] = attributes
            domains.add(state.domain)

        # _LOGGER.debug(f"Exposed entities: {entity_states}")
        _LOGGER.debug(f"Exposed entities: {entity_states}")

        return entity_states, list(domains)

    async def _async_format_prompt(
    def _format_prompt(
        self, prompt: list[dict], include_generation_prompt: bool = True
    ) -> str:
        formatted_prompt = ""
@@ -502,9 +364,10 @@ class LLaMAAgent(AbstractConversationAgent):
        if include_generation_prompt:
            formatted_prompt = formatted_prompt + template_desc["generation_prompt"]

        # _LOGGER.debug(formatted_prompt)
        return formatted_prompt

    def _async_generate_system_prompt(self, prompt_template: str) -> str:
    def _generate_system_prompt(self, prompt_template: str) -> str:
        """Generate a prompt for the user."""
        entities_to_expose, domains = self._async_get_exposed_entities()

@@ -555,3 +418,269 @@ class LLaMAAgent(AbstractConversationAgent):
            },
            parse_result=False,
        )

class LocalLLaMAAgent(LLaMAAgent):
    model_path: str
    llm: Any
    grammar: Any

    def _load_model(self, entry: ConfigEntry) -> None:
        self.model_path = entry.data.get(CONF_DOWNLOADED_MODEL_FILE)

        _LOGGER.info(
            "Using model file '%s'", self.model_path
        )

        if not self.model_path:
            raise Exception(f"Model was not found at '{self.model_path}'!")

        # don't import it until now because the wheel is installed by config_flow.py
        module = importlib.import_module("llama_cpp")
        Llama = getattr(module, "Llama")
        LlamaGrammar = getattr(module, "LlamaGrammar")

        _LOGGER.debug("Loading model...")
        self.llm = Llama(
            model_path=self.model_path,
            n_ctx=2048,
            n_batch=2048,
            # TODO: expose arguments to the user in home assistant UI
            # n_threads=16,
            # n_threads_batch=4,
        )

        _LOGGER.debug("Loading grammar...")
        try:
            # TODO: make grammar configurable
            with open(os.path.join(os.path.dirname(__file__), GBNF_GRAMMAR_FILE)) as f:
                grammar_str = "".join(f.readlines())
            self.grammar = LlamaGrammar.from_string(grammar_str)
            _LOGGER.debug("Loaded grammar")
        except Exception:
            _LOGGER.exception("Failed to load grammar!")
            self.grammar = None

    def _generate(self, conversation: dict) -> str:
        prompt = self._format_prompt(conversation)
        input_tokens = self.llm.tokenize(
            prompt.encode(), add_bos=False
        )

        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
        top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
        grammar = self.grammar if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None

        _LOGGER.debug(f"Options: {self.entry.options}")

        _LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")
        output_tokens = self.llm.generate(
            input_tokens,
            temp=temperature,
            top_k=top_k,
            top_p=top_p,
            grammar=grammar
        )

        result_tokens = []
        for token in output_tokens:
            if token == self.llm.token_eos():
                break

            result_tokens.append(token)

            if len(result_tokens) >= max_tokens:
                break

        result = self.llm.detokenize(result_tokens).decode()

        return result

class GenericOpenAIAPIAgent(LLaMAAgent):
    api_host: str
    api_key: str
    model_name: str

    def _load_model(self, entry: ConfigEntry) -> None:
        # TODO: https
        self.api_host = f"http://{entry.data[CONF_HOST]}:{entry.data[CONF_PORT]}"
        self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
        self.model_name = entry.data.get(CONF_CHAT_MODEL)

    def _chat_completion_params(self, conversation: dict) -> (str, dict):
        request_params = {}

        endpoint = "/v1/chat/completions"
        request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]

        return endpoint, request_params

    def _completion_params(self, conversation: dict) -> (str, dict):
        request_params = {}

        endpoint = "/v1/completions"
        request_params["prompt"] = self._format_prompt(conversation)

        return endpoint, request_params

    def _extract_response(self, response_json: dict) -> str:
        choices = response_json["choices"]
        if choices[0]["finish_reason"] != "stop":
            _LOGGER.warn("Model response did not end on a stop token (unfinished sentence)")

        if response_json["object"] == "chat.completion":
            return choices[0]["message"]["content"]
        else:
            return choices[0]["text"]

    def _generate(self, conversation: dict) -> str:
        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
        timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
        use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)

        request_params = {
            "model": self.model_name,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        if use_chat_api:
            endpoint, additional_params = self._chat_completion_params(conversation)
        else:
            endpoint, additional_params = self._completion_params(conversation)

        request_params.update(additional_params)

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        result = requests.post(
            f"{self.api_host}{endpoint}",
            json=request_params,
            timeout=timeout,
            headers=headers,
        )

        try:
            result.raise_for_status()
        except requests.RequestException as err:
            _LOGGER.debug(f"Err was: {err}")
            _LOGGER.debug(f"Request was: {request_params}")
            _LOGGER.debug(f"Result was: {result.text}")
            return f"Failed to communicate with the API! {err}"

        _LOGGER.debug(result.json())

        return self._extract_response(result.json())

class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
    admin_key: str

    def _load_model(self, entry: ConfigEntry) -> None:
        super()._load_model(entry)
        self.admin_key = entry.data.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, self.api_key)

        try:
            currently_loaded_result = requests.get(f"{self.api_host}/v1/internal/model/info")
            currently_loaded_result.raise_for_status()

            loaded_model = currently_loaded_result.json()["model_name"]
            if loaded_model == self.model_name:
                _LOGGER.info(f"Model {self.model_name} is already loaded on the remote backend.")
                return
            else:
                _LOGGER.info(f"Model {self.model_name} is not loaded on the remote backend. Loading it now...")

            headers = {}
            if self.admin_key:
                headers["Authorization"] = f"Bearer {self.admin_key}"

            load_result = requests.post(
                f"{self.api_host}/v1/internal/model/load",
                json={
                    "model_name": self.model_name,
                    # TODO: expose arguments to the user in home assistant UI
                    # "args": {},
                }
            )
            load_result.raise_for_status()

        except Exception as ex:
            _LOGGER.debug("Connection error was: %s", repr(ex))
            raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex

    def _chat_completion_params(self, conversation: dict) -> (str, dict):
        preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET)
        chat_mode = self.entry.options.get(CONF_TEXT_GEN_WEBUI_CHAT_MODE, DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE)

        endpoint, request_params = super()._chat_completion_params(conversation)

        request_params["mode"] = chat_mode
        if chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_CHAT or chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT:
            if preset:
                request_params["character"] = preset
        elif chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT:
            # TODO: handle uppercase properly?
            request_params["instruction_template"] = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)

        return endpoint, request_params

    def _completion_params(self, conversation: dict) -> (str, dict):
        preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET)

        endpoint, request_params = super()._completion_params(conversation)

        if preset:
            request_params["preset"] = preset

        return endpoint, request_params

    def _extract_response(self, response_json: dict) -> str:
        choices = response_json["choices"]
        if choices[0]["finish_reason"] != "stop":
            _LOGGER.warn("Model response did not end on a stop token (unfinished sentence)")

        # text-gen-webui has a typo where it is 'chat.completions' not 'chat.completion'
        if response_json["object"] == "chat.completions":
            return choices[0]["message"]["content"]
        else:
            return choices[0]["text"]

class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
    """https://llama-cpp-python.readthedocs.io/en/latest/server/"""
    grammar: str

    def _load_model(self, entry: ConfigEntry):
        super()._load_model(entry)

        with open(os.path.join(os.path.dirname(__file__), GBNF_GRAMMAR_FILE)) as f:
            self.grammar = "".join(f.readlines())

    def _chat_completion_params(self, conversation: dict) -> (str, dict):
        top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
        endpoint, request_params = super()._chat_completion_params(conversation)

        request_params["top_k"] = top_k

        if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
            request_params["grammar"] = self.grammar

        return endpoint, request_params

    def _completion_params(self, conversation: dict) -> (str, dict):
        top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
        endpoint, request_params = super()._completion_params(conversation)

        request_params["top_k"] = top_k

        if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
            request_params["grammar"] = self.grammar

        return endpoint, request_params
@@ -57,6 +57,8 @@ from .const import (
    CONF_OPENAI_API_KEY,
    CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
    CONF_SERVICE_CALL_REGEX,
    CONF_REMOTE_USE_CHAT_ENDPOINT,
    CONF_TEXT_GEN_WEBUI_CHAT_MODE,
    DEFAULT_CHAT_MODEL,
    DEFAULT_HOST,
    DEFAULT_PORT,
@@ -74,22 +76,29 @@ from .const import (
    DEFAULT_REFRESH_SYSTEM_PROMPT,
    DEFAULT_SERVICE_CALL_REGEX,
    DEFAULT_OPTIONS,
    DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
    DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
    BACKEND_TYPE_LLAMA_HF,
    BACKEND_TYPE_LLAMA_EXISTING,
    BACKEND_TYPE_TEXT_GEN_WEBUI,
    BACKEND_TYPE_GENERIC_OPENAI,
    BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
    BACKEND_TYPE_OLLAMA,
    PROMPT_TEMPLATE_CHATML,
    PROMPT_TEMPLATE_ALPACA,
    PROMPT_TEMPLATE_VICUNA,
    PROMPT_TEMPLATE_MISTRAL,
    PROMPT_TEMPLATE_NONE,
    TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
    TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
    TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
    DOMAIN,
)

_LOGGER = logging.getLogger(__name__)

def is_local_backend(backend):
    return backend not in [BACKEND_TYPE_TEXT_GEN_WEBUI, BACKEND_TYPE_GENERIC_OPENAI]
    return backend in [BACKEND_TYPE_LLAMA_EXISTING, BACKEND_TYPE_LLAMA_HF]

def STEP_INIT_DATA_SCHEMA(backend_type=None):
    return vol.Schema(
@@ -98,10 +107,16 @@ def STEP_INIT_DATA_SCHEMA(backend_type=None):
                CONF_BACKEND_TYPE,
                default=backend_type if backend_type else DEFAULT_BACKEND_TYPE
            ): SelectSelector(SelectSelectorConfig(
                options=[ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING, BACKEND_TYPE_TEXT_GEN_WEBUI, BACKEND_TYPE_GENERIC_OPENAI ],
                options=[
                    BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING,
                    BACKEND_TYPE_TEXT_GEN_WEBUI,
                    BACKEND_TYPE_GENERIC_OPENAI,
                    BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
                    # BACKEND_TYPE_OLLAMA
                ],
                translation_key=CONF_BACKEND_TYPE,
                multiple=False,
                mode=SelectSelectorMode.LIST,
                mode=SelectSelectorMode.DROPDOWN,
            ))
        }
    )
@@ -121,19 +136,33 @@ def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_q
        }
    )

def STEP_REMOTE_SETUP_DATA_SCHEMA(include_admin_key: bool, *, host=None, port=None, chat_model=None):
def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, chat_model=None, use_chat_endpoint=None, webui_preset=None, webui_chat_mode=None):

    extra = {}
    if include_admin_key:
        extra[vol.Optional(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)] = TextSelector(TextSelectorConfig(type="password"))
    extra1, extra2 = ({}, {})
    default_port = DEFAULT_PORT

    if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
        extra1[vol.Optional(CONF_TEXT_GEN_WEBUI_PRESET, default=webui_preset)] = str
        extra1[vol.Optional(CONF_TEXT_GEN_WEBUI_CHAT_MODE, default=webui_chat_mode)] = SelectSelector(SelectSelectorConfig(
            options=[TEXT_GEN_WEBUI_CHAT_MODE_CHAT, TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT],
            translation_key=CONF_TEXT_GEN_WEBUI_CHAT_MODE,
            multiple=False,
            mode=SelectSelectorMode.DROPDOWN,
        ))
        extra2[vol.Optional(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)] = TextSelector(TextSelectorConfig(type="password"))

    elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
        default_port = "8000"

    return vol.Schema(
        {
            vol.Required(CONF_HOST, default=host if host else DEFAULT_HOST): str,
            vol.Required(CONF_PORT, default=port if port else DEFAULT_PORT): str,
            vol.Required(CONF_PORT, default=port if port else default_port): str,
            vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): str,
            vol.Required(CONF_REMOTE_USE_CHAT_ENDPOINT, default=use_chat_endpoint if use_chat_endpoint else DEFAULT_REMOTE_USE_CHAT_ENDPOINT): bool,
            **extra1,
            vol.Optional(CONF_OPENAI_API_KEY): TextSelector(TextSelectorConfig(type="password")),
            **extra
            **extra2
        }
    )

@@ -242,7 +271,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
    install_wheel_error = None
    download_task = None
    download_error = None
    model_options: dict[str, Any] = {}
    model_config: dict[str, Any] = {}

    @property
    def flow_manager(self) -> config_entries.ConfigEntriesFlowManager:
@@ -281,7 +310,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
        if user_input:
            try:
                local_backend = is_local_backend(user_input[CONF_BACKEND_TYPE])
                self.model_options.update(user_input)
                self.model_config.update(user_input)

            except Exception: # pylint: disable=broad-except
                _LOGGER.exception("Unexpected exception")
@@ -353,7 +382,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
    ) -> FlowResult:
        errors = {}

        backend_type = self.model_options[CONF_BACKEND_TYPE]
        backend_type = self.model_config[CONF_BACKEND_TYPE]
        if backend_type == BACKEND_TYPE_LLAMA_HF:
            schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA()
        elif backend_type == BACKEND_TYPE_LLAMA_EXISTING:
@@ -364,13 +393,13 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
        if self.download_error:
            errors["base"] = "download_failed"
            schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(
                chat_model=self.model_options[CONF_CHAT_MODEL],
                downloaded_model_quantization=self.model_options[CONF_DOWNLOADED_MODEL_QUANTIZATION]
                chat_model=self.model_config[CONF_CHAT_MODEL],
                downloaded_model_quantization=self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION]
            )

        if user_input:
            try:
                self.model_options.update(user_input)
                self.model_config.update(user_input)

            except Exception: # pylint: disable=broad-except
                _LOGGER.exception("Unexpected exception")
@@ -379,7 +408,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
                if backend_type == BACKEND_TYPE_LLAMA_HF:
                    return await self.async_step_download()
                else:
                    model_file = self.model_options[CONF_DOWNLOADED_MODEL_FILE]
                    model_file = self.model_config[CONF_DOWNLOADED_MODEL_FILE]
                    if os.path.exists(model_file):
                        return await self.async_step_finish()
                    else:
@@ -401,8 +430,8 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
                progress_action="download",
            )

        model_name = self.model_options[CONF_CHAT_MODEL]
        quantization_type = self.model_options[CONF_DOWNLOADED_MODEL_QUANTIZATION]
        model_name = self.model_config[CONF_CHAT_MODEL]
        quantization_type = self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION]

        storage_folder = os.path.join(self.hass.config.media_dirs["local"], "models")
        self.download_task = self.hass.async_add_executor_job(
@@ -422,7 +451,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
            self.download_error = download_result
            return self.async_show_progress_done(next_step_id="local_model")
        else:
            self.model_options[CONF_DOWNLOADED_MODEL_FILE] = download_result
            self.model_config[CONF_DOWNLOADED_MODEL_FILE] = download_result
            return self.async_show_progress_done(next_step_id="finish")


@@ -431,10 +460,10 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
        headers = {}
        api_key = user_input.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, user_input.get(CONF_OPENAI_API_KEY))
        if api_key:
            headers["Authorization"] = f"Basic {api_key}"
            headers["Authorization"] = f"Bearer {api_key}"

        models_result = requests.get(
            f"http://{self.model_options[CONF_HOST]}:{self.model_options[CONF_PORT]}/v1/internal/model/list",
            f"http://{self.model_config[CONF_HOST]}:{self.model_config[CONF_PORT]}/v1/internal/model/list",
            headers=headers
        )
        models_result.raise_for_status()
@@ -442,7 +471,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
        models = models_result.json()

        for model in models["model_names"]:
            if model == self.model_options[CONF_CHAT_MODEL].replace("/", "_"):
            if model == self.model_config[CONF_CHAT_MODEL].replace("/", "_"):
                return ""

        return "missing_model_api"
@@ -455,12 +484,12 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        errors = {}
        backend_type = self.model_options[CONF_BACKEND_TYPE]
        backend_type = self.model_config[CONF_BACKEND_TYPE]
        schema = STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI)

        if user_input:
            try:
                self.model_options.update(user_input)
                self.model_config.update(user_input)

                # only validate and load when using text-generation-webui
                if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
@@ -474,6 +503,9 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
                    host=user_input[CONF_HOST],
                    port=user_input[CONF_PORT],
                    chat_model=user_input[CONF_CHAT_MODEL],
                    use_chat_endpoint=user_input[CONF_REMOTE_USE_CHAT_ENDPOINT],
                    webui_preset=user_input.get(CONF_TEXT_GEN_WEBUI_PRESET),
                    webui_chat_mode=user_input.get(CONF_TEXT_GEN_WEBUI_CHAT_MODE),
                )
            else:
                return await self.async_step_finish()
@@ -492,16 +524,16 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:

        model_name = self.model_options.get(CONF_CHAT_MODEL)
        backend = self.model_options[CONF_BACKEND_TYPE]
        model_name = self.model_config.get(CONF_CHAT_MODEL)
        backend = self.model_config[CONF_BACKEND_TYPE]
        if backend == BACKEND_TYPE_LLAMA_EXISTING:
            model_name = os.path.basename(self.model_options.get(CONF_DOWNLOADED_MODEL_FILE))
            model_name = os.path.basename(self.model_config.get(CONF_DOWNLOADED_MODEL_FILE))
        location = "llama.cpp" if is_local_backend(backend) else "remote"

        return self.async_create_entry(
            title=f"LLM Model '{model_name}' ({location})",
            description="A Large Language Model Chat Agent",
            data=self.model_options,
            data=self.model_config,
        )

    @staticmethod
@@ -583,6 +615,11 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
            description={"suggested_value": options.get(CONF_SERVICE_CALL_REGEX)},
            default=DEFAULT_SERVICE_CALL_REGEX,
        ): str,
        vol.Required(
            CONF_REFRESH_SYSTEM_PROMPT,
            description={"suggested_value": options.get(CONF_REFRESH_SYSTEM_PROMPT)},
            default=DEFAULT_REFRESH_SYSTEM_PROMPT,
        ): bool,
    }

    if is_local_backend(backend_type):
@@ -606,12 +643,7 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
                CONF_USE_GBNF_GRAMMAR,
                description={"suggested_value": options.get(CONF_USE_GBNF_GRAMMAR)},
                default=DEFAULT_USE_GBNF_GRAMMAR,
            ): bool,
            vol.Required(
                CONF_REFRESH_SYSTEM_PROMPT,
                description={"suggested_value": options.get(CONF_REFRESH_SYSTEM_PROMPT)},
                default=DEFAULT_REFRESH_SYSTEM_PROMPT,
            ): bool,
            ): bool
        })
    elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
        result = insert_after_key(result, CONF_MAX_TOKENS, {
@@ -620,13 +652,66 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
                description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
                default=DEFAULT_REQUEST_TIMEOUT,
            ): int,
            vol.Required(
                CONF_REMOTE_USE_CHAT_ENDPOINT,
                description={"suggested_value": options.get(CONF_REMOTE_USE_CHAT_ENDPOINT)},
                default=DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
            ): bool,
            vol.Optional(
                CONF_TEXT_GEN_WEBUI_PRESET,
                description={"suggested_value": options.get(CONF_TEXT_GEN_WEBUI_PRESET)},
            ): str,
            vol.Required(
                CONF_TEXT_GEN_WEBUI_CHAT_MODE,
                description={"suggested_value": options.get(CONF_TEXT_GEN_WEBUI_CHAT_MODE)},
                default=DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
            ): SelectSelector(SelectSelectorConfig(
                options=[TEXT_GEN_WEBUI_CHAT_MODE_CHAT, TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT],
                translation_key=CONF_TEXT_GEN_WEBUI_CHAT_MODE,
                multiple=False,
                mode=SelectSelectorMode.DROPDOWN,
            )),
        })
    elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
        result = insert_after_key(result, CONF_MAX_TOKENS, {
            vol.Required(
                CONF_REQUEST_TIMEOUT,
                description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
                default=DEFAULT_REQUEST_TIMEOUT,
            ): int,
            vol.Required(
                CONF_REMOTE_USE_CHAT_ENDPOINT,
                description={"suggested_value": options.get(CONF_REMOTE_USE_CHAT_ENDPOINT)},
                default=DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
            ): bool,
            vol.Required(
                CONF_TEMPERATURE,
                description={"suggested_value": options.get(CONF_TEMPERATURE)},
                default=DEFAULT_TEMPERATURE,
            ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
            vol.Required(
                CONF_TOP_P,
                description={"suggested_value": options.get(CONF_TOP_P)},
                default=DEFAULT_TOP_P,
            ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
        })
    elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
        result = insert_after_key(result, CONF_MAX_TOKENS, {
            vol.Required(
                CONF_REQUEST_TIMEOUT,
                description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
                default=DEFAULT_REQUEST_TIMEOUT,
            ): int,
            vol.Required(
                CONF_REMOTE_USE_CHAT_ENDPOINT,
                description={"suggested_value": options.get(CONF_REMOTE_USE_CHAT_ENDPOINT)},
                default=DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
            ): bool,
            vol.Required(
                CONF_TOP_K,
                description={"suggested_value": options.get(CONF_TOP_K)},
                default=DEFAULT_TOP_K,
            ): NumberSelector(NumberSelectorConfig(min=1, max=256, step=1)),
            vol.Required(
                CONF_TEMPERATURE,
                description={"suggested_value": options.get(CONF_TEMPERATURE)},
@@ -638,10 +723,10 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
                default=DEFAULT_TOP_P,
            ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
            vol.Required(
                CONF_TOP_P,
                description={"suggested_value": options.get(CONF_TOP_P)},
                default=DEFAULT_TOP_P,
            ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
                CONF_USE_GBNF_GRAMMAR,
                description={"suggested_value": options.get(CONF_USE_GBNF_GRAMMAR)},
                default=DEFAULT_USE_GBNF_GRAMMAR,
            ): bool
        })

    return result

@@ -24,6 +24,8 @@ BACKEND_TYPE_LLAMA_HF = "llama_cpp_hf"
BACKEND_TYPE_LLAMA_EXISTING = "llama_cpp_existing"
BACKEND_TYPE_TEXT_GEN_WEBUI = "text-generation-webui_api"
BACKEND_TYPE_GENERIC_OPENAI = "generic_openai"
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER = "llama_cpp_python_server"
BACKEND_TYPE_OLLAMA = "ollama"
DEFAULT_BACKEND_TYPE = BACKEND_TYPE_LLAMA_HF
CONF_DOWNLOADED_MODEL_QUANTIZATION = "downloaded_model_quantization"
CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS = ["Q8_0", "Q5_K_M", "Q4_K_M", "Q3_K_M"]
@@ -84,6 +86,13 @@ CONF_REFRESH_SYSTEM_PROMPT = "refresh_prompt_per_tern"
DEFAULT_REFRESH_SYSTEM_PROMPT = True
CONF_SERVICE_CALL_REGEX = "service_call_regex"
DEFAULT_SERVICE_CALL_REGEX = r"```homeassistant\n([\S \t\n]*?)```"
CONF_REMOTE_USE_CHAT_ENDPOINT = "remote_use_chat_endpoint"
DEFAULT_REMOTE_USE_CHAT_ENDPOINT = False
CONF_TEXT_GEN_WEBUI_CHAT_MODE = "text_generation_webui_chat_mode"
TEXT_GEN_WEBUI_CHAT_MODE_CHAT = "chat"
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT = "instruct"
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT = "chat-instruct"
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE = TEXT_GEN_WEBUI_CHAT_MODE_CHAT

DEFAULT_OPTIONS = types.MappingProxyType(
    {
@@ -97,6 +106,8 @@ DEFAULT_OPTIONS = types.MappingProxyType(
        CONF_USE_GBNF_GRAMMAR: DEFAULT_USE_GBNF_GRAMMAR,
        CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
        CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
        CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX
        CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
        CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
        CONF_TEXT_GEN_WEBUI_CHAT_MODE: DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
    }
)

@@ -30,7 +30,10 @@
"huggingface_model": "Model Name",
"port": "API Port",
"openai_api_key": "API Key",
"text_generation_webui_admin_key": "Admin Key"
"text_generation_webui_admin_key": "Admin Key",
"text_generation_webui_preset": "Generation Preset/Character Name",
"remote_use_chat_endpoint": "Use chat completions endpoint",
"text_generation_webui_chat_mode": "Chat Mode"
},
"description": "Provide the connection details for an instance of text-generation-webui that is hosting the model.",
"title": "Configure connection to remote API"
@@ -40,7 +43,7 @@
"download_model_from_hf": "Download model from HuggingFace",
"use_local_backend": "Use Llama.cpp"
},
"description": "Select the backend for running the model. Either Llama.cpp (locally) or text-generation-webui (remote).",
"description": "Select the backend for running the model. The options are:\n1. Llama.cpp with a model from HuggingFace\n2. Llama.cpp with a model stored on the disk\n3. [text-generation-webui API](https://github.com/oobabooga/text-generation-webui)\n4. Generic OpenAI API Compatible API\n5. [llama-cpp-python Server](https://llama-cpp-python.readthedocs.io/en/latest/server/)\n\nIf using Llama.cpp locally, make sure you copied the correct wheel file to the same directory as the integration.",
"title": "Select Backend"
}
}
@@ -62,7 +65,9 @@
"text_generation_webui_admin_key": "Admin Key",
"service_call_regex": "Service Call Regex",
"refresh_prompt_per_tern": "Refresh System Prompt Every Turn",
"text_generation_webui_preset": "Generation Preset Name"
"text_generation_webui_preset": "Generation Preset/Character Name",
"remote_use_chat_endpoint": "Use chat completions endpoint",
"text_generation_webui_chat_mode": "Chat Mode"
}
}
}
@@ -73,6 +78,7 @@
"chatml": "ChatML",
"vicuna": "Vicuna",
"alpaca": "Alpaca",
"mistral": "Mistral",
"no_prompt_template": "None"
}
},
@@ -81,7 +87,17 @@
"llama_cpp_hf": "Llama.cpp (HuggingFace)",
"llama_cpp_existing": "Llama.cpp (existing model)",
"text-generation-webui_api": "text-generation-webui API",
"generic_openai": "Generic OpenAI Compatible API"
"generic_openai": "Generic OpenAI Compatible API",
"llama_cpp_python_server": "llama-cpp-python Server",
"ollama": "Ollama"
}
},
"text_generation_webui_chat_mode": {
"options": {
"chat": "Chat",
"instruct": "Instruct",
"chat-instruct": "Chat-Instruct"
}
}
}

docs/Backend Configuration.md (new file, 39 lines)
@@ -0,0 +1,39 @@
# Backend Configuration

There are multiple backends to choose from for running the model that the Home Assistant integration uses. Here is a description of all the options for each backend.

# Common Options
| Option Name | Description | Suggested Value |
| ------------ | --------- | ------------ |
| System Prompt | [see here](./Model%20Prompting.md) | |
| Prompt Format | The format for the context of the model | |
| Maximum tokens to return in response | Limits the number of tokens that can be produced by each model response | 512 |
| Additional attribute to expose in the context | Extra attributes that will be exposed to the model via the `{{ devices }}` template variable | |
| Service Call Regex | The regular expression used to extract service calls from the model response; should contain 1 repeated capture group | |
| Refresh System Prompt Every Turn | Flag to update the system prompt with updated device states on every chat turn. Disabling can significantly improve agent response times when using a backend that supports prefix caching (Llama.cpp) | Enabled |
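
The default Service Call Regex shipped in `const.py` is `` ```homeassistant\n([\S \t\n]*?)``` ``. A quick sketch of how a service call block gets pulled out of a model response with it (the response text below is made up for illustration):

```python
import re

# Default value of CONF_SERVICE_CALL_REGEX from const.py; the single capture
# group is applied repeatedly to collect every ```homeassistant ...``` block.
SERVICE_CALL_REGEX = r"```homeassistant\n([\S \t\n]*?)```"

# Hypothetical model response, used only for illustration.
response = (
    "Turning on the kitchen light.\n"
    "```homeassistant\n"
    '{"service": "light.turn_on", "target_device": "light.kitchen"}\n'
    "```"
)

for block in re.findall(SERVICE_CALL_REGEX, response):
    print(block.strip())
```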

# Llama.cpp
For details about the sampling parameters, see here: https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab#parameters-description
| Option Name | Description | Suggested Value |
| ------------ | --------- | ------------ |
| Top K | Sampling parameter; see above link | 40 |
| Top P | Sampling parameter; see above link | 1.0 |
| Temperature | Sampling parameter; see above link | 0.1 |
| Enable GBNF Grammar | Restricts the output of the model to follow a pre-defined syntax; eliminates function calling syntax errors on quantized models | Enabled |
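
The integration loads the bundled grammar file and hands it to llama-cpp-python (see `LocalLLaMAAgent._load_model` in the diff above). A rough sketch of the same idea outside Home Assistant, assuming llama-cpp-python is installed and with the model/grammar file names below standing in for your own:

```python
# Rough sketch of grammar-constrained generation with llama-cpp-python,
# mirroring what the integration does internally. File names are placeholders.
from llama_cpp import Llama, LlamaGrammar

llm = Llama(model_path="Home-3B-v2.q4_k_m.gguf", n_ctx=2048)

with open("output.gbnf") as f:
    grammar = LlamaGrammar.from_string(f.read())

result = llm(
    "<|im_start|>user\nturn on the kitchen light<|im_end|>\n<|im_start|>assistant\n",
    max_tokens=128,
    temperature=0.1,
    grammar=grammar,  # output is constrained to the GBNF grammar
)
print(result["choices"][0]["text"])
```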

# text-generation-webui
| Option Name | Description | Suggested Value |
| ------------ | --------- | ------------ |
| Request Timeout | The maximum time in seconds that the integration will wait for a response from the remote server | 90 (higher if running on low resource hardware) |
| Use chat completions endpoint | Flag to use `/v1/chat/completions` as the remote endpoint instead of `/v1/completions` | |
| Generation Preset/Character Name | The preset or character name to pass to the backend. If none is provided then the settings that are currently selected in the UI will be applied | |
| Chat Mode | [see here](https://github.com/oobabooga/text-generation-webui/wiki/01-%E2%80%90-Chat-Tab#mode) | Instruct |
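
As `TextGenerationWebuiAgent._chat_completion_params` in the diff above shows, the selected Chat Mode is forwarded as extra fields on the chat completions request; a sketch of the resulting payload with illustrative values:

```python
# Illustrative payload for text-generation-webui's chat completions endpoint.
# "chat"/"chat-instruct" modes send a character name, "instruct" mode sends the
# instruction template; all concrete values below are examples only.
chat_mode = "instruct"

request_params = {
    "model": "Home-3B-v2",  # example model name
    "max_tokens": 512,
    "temperature": 0.1,
    "top_p": 1.0,
    "messages": [{"role": "user", "content": "turn on the kitchen light"}],
    "mode": chat_mode,
}

if chat_mode in ("chat", "chat-instruct"):
    request_params["character"] = "Assistant"  # Generation Preset/Character Name option
elif chat_mode == "instruct":
    request_params["instruction_template"] = "ChatML"  # derived from the Prompt Format option

print(request_params)
```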

# Generic OpenAI API Compatible
For details about the sampling parameters, see here: https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab#parameters-description
| Option Name | Description | Suggested Value |
| ------------ | --------- | ------------ |
| Request Timeout | The maximum time in seconds that the integration will wait for a response from the remote server | 90 (higher if running on low resource hardware) |
| Use chat completions endpoint | Flag to use `/v1/chat/completions` as the remote endpoint instead of `/v1/completions` | Backend Dependent |
| Top P | Sampling parameter; see above link | 1.0 |
| Temperature | Sampling parameter; see above link | 0.1 |
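
The "Use chat completions endpoint" flag switches which request shape the integration sends (see `GenericOpenAIAPIAgent._completion_params` and `_chat_completion_params` in the diff above). A minimal sketch of the two shapes against an assumed local server (host, port, model name, and API key are illustrative):

```python
import requests

# Minimal sketch of the two request shapes used by GenericOpenAIAPIAgent.
API_HOST = "http://localhost:5000"          # assumed host/port
HEADERS = {"Authorization": "Bearer sk-example"}  # only sent if an API key is configured

common = {"model": "Home-3B-v2", "max_tokens": 512, "temperature": 0.1, "top_p": 1.0}

# Flag disabled -> /v1/completions with a pre-formatted prompt string
completion = requests.post(
    f"{API_HOST}/v1/completions",
    json={**common, "prompt": "<|im_start|>user\nturn on the kitchen light<|im_end|>\n<|im_start|>assistant\n"},
    headers=HEADERS,
    timeout=90,
)
print(completion.json()["choices"][0]["text"])

# Flag enabled -> /v1/chat/completions with role/content messages
chat = requests.post(
    f"{API_HOST}/v1/chat/completions",
    json={**common, "messages": [{"role": "user", "content": "turn on the kitchen light"}]},
    headers=HEADERS,
    timeout=90,
)
print(chat.json()["choices"][0]["message"]["content"])
```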