Mirror of https://github.com/acon96/home-llm.git (synced 2026-01-10 06:07:58 -05:00)

.github/workflows/create-release.yml (new file, vendored, 88 lines)
@@ -0,0 +1,88 @@
name: Create Release

on:
  workflow_dispatch:
    inputs:
      release_notes:
        description: "Release Notes"
        required: true
        type: string

permissions:
  contents: write

jobs:
  build_wheels:
    name: Build wheels on ${{ matrix.arch }} (HA ${{ matrix.home_assistant_version }})
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        home_assistant_version: ["2023.12.4", "2024.2.1"]
        arch: ["aarch64", "armhf", "amd64", "i386"]

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Read llama-cpp-python version
        run: cat custom_components/llama_conversation/const.py | grep "EMBEDDED_LLAMA_CPP_PYTHON_VERSION" | tr -d ' ' | tr -d '"' >> $GITHUB_ENV

      - name: Build artifact
        uses: uraimo/run-on-arch-action@v2
        id: build
        with:
          arch: none
          distro: none
          base_image: homeassistant/${{ matrix.arch }}-homeassistant:${{ matrix.home_assistant_version }}

          # Create an artifacts directory
          setup: |
            mkdir -p "${PWD}/artifacts"

          # Mount the artifacts directory as /artifacts in the container
          dockerRunArgs: |
            --volume "${PWD}/artifacts:/artifacts"

          # The shell to run commands with in the container
          shell: /bin/bash

          # Produce a binary artifact and place it in the mounted volume
          run: |
            apk update
            apk add build-base python3-dev cmake
            pip3 install build

            cd /tmp
            git clone --quiet --recurse-submodules https://github.com/abetlen/llama-cpp-python --branch "v${{ env.EMBEDDED_LLAMA_CPP_PYTHON_VERSION }}"
            cd llama-cpp-python

            export CMAKE_ARGS="-DLLAVA_BUILD=OFF"
            python3 -m build --wheel
            cp -f ./dist/*.whl /artifacts/

      - name: Upload artifacts
        uses: actions/upload-artifact@v4
        with:
          path: ./artifacts/*.whl
          name: artifact_${{ matrix.arch }}_${{ matrix.home_assistant_version }}

  release:
    name: Create Release
    needs: [ build_wheels ]
    runs-on: ubuntu-latest
    if: "startsWith(github.event.ref, 'refs/tags/v')" # only create a release if this was run on a tag

    steps:
      - name: Download artifacts
        uses: actions/download-artifact@v4
        with:
          path: dist
          merge-multiple: true

      - name: Create GitHub release
        uses: softprops/action-gh-release@v2
        with:
          files: dist/*
          body: ${{ inputs.release_notes }}
          make_latest: true
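The "Read llama-cpp-python version" step above relies on `const.py` containing a line of the form `EMBEDDED_LLAMA_CPP_PYTHON_VERSION = "0.2.60"`; stripping spaces and quotes turns that line into a `KEY=VALUE` pair that GitHub Actions loads from `$GITHUB_ENV`. A rough Python sketch of the same extraction, for illustration only (it is not part of the workflow):

```python
# Illustrative equivalent of the shell one-liner in the workflow step above:
# pull EMBEDDED_LLAMA_CPP_PYTHON_VERSION out of const.py and append KEY=VALUE to $GITHUB_ENV.
import os
import re

with open("custom_components/llama_conversation/const.py") as f:
    source = f.read()

# const.py contains a line like: EMBEDDED_LLAMA_CPP_PYTHON_VERSION = "0.2.60"
match = re.search(r'EMBEDDED_LLAMA_CPP_PYTHON_VERSION\s*=\s*"([^"]+)"', source)
version = match.group(1)

# GitHub Actions exports KEY=VALUE lines appended to the file referenced by $GITHUB_ENV.
with open(os.environ["GITHUB_ENV"], "a") as env_file:
    env_file.write(f"EMBEDDED_LLAMA_CPP_PYTHON_VERSION={version}\n")
```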
@@ -129,6 +129,7 @@ In order to facilitate running the project entirely on the system where Home Ass
## Version History
| Version | Description |
| ------- | ----------- |
| v0.2.11 | Add prompt caching, expose llama.cpp runtime settings, build llama-cpp-python wheels using GitHub actions, and install wheels directly from GitHub |
| v0.2.10 | Allow configuring the model parameters during initial setup, attempt to auto-detect defaults for recommended models, Fix to allow lights to be set to max brightness |
| v0.2.9 | Fix HuggingFace Download, Fix llama.cpp wheel installation, Fix light color changing, Add in-context-learning support |
| v0.2.8 | Fix ollama model names with colons |
TODO.md (10 lines changed)
@@ -1,15 +1,12 @@
# TODO
- [ ] setup github actions to build wheels that are optimized for RPIs??
- [ ] setup github actions to publish docker images for text-gen-webui addon
- [ ] detection/mitigation of too many entities being exposed & blowing out the context length
- [ ] areas/room support
- [ ] figure out DPO for refusals + fixing incorrect entity id
- [ ] figure out DPO to improve response quality
- [x] setup github actions to build wheels that are optimized for RPIs
- [x] mixtral + prompting (no fine tuning)
  - add in context learning variables to sys prompt template
  - add new options to setup process for setting prompt style + picking fine-tuned/ICL
- [ ] prime kv cache with current "state" so that requests are faster
- [ ] support fine-tuning with RoPE for longer contexts
- [ ] support config via yaml instead of configflow
- [x] prime kv cache with current "state" so that requests are faster
- [x] ChatML format (actually need to add special tokens)
- [x] Vicuna dataset merge (yahma/alpaca-cleaned)
- [x] Phi-2 fine tuning
@@ -19,7 +16,6 @@
- [x] Licenses + Attributions
- [x] Finish Readme/docs for initial release
- [x] Function calling as JSON
- [ ] multi-turn prompts; better instruct dataset like dolphin/wizardlm?
- [x] Fine tune Phi-1.5 version
- [x] make llama-cpp-python wheels for "llama-cpp-python>=0.2.24"
- [x] make a proper evaluation framework to run. not just loss. should test accuracy on the function calling
@@ -1,4 +1,4 @@
# text-generation-webui - Home Assistant Addon
NOTE: This is super experimental and may or may not work on a Raspberry Pi
Installs text-generation-webui into a docker container using CPU only mode (llama.cpp)

Installs text-generation-webui into a docker container using CPU only mode (llama.cpp)
NOTE: This addon is not the preferred way to run LLama.cpp as part of Home Assistant and will not be updated.
@@ -2,8 +2,9 @@
from __future__ import annotations

import logging
import threading
import importlib
from typing import Literal, Any
from typing import Literal, Any, Callable

import requests
import re
@@ -11,6 +12,7 @@ import os
import json
import csv
import random
import time

import homeassistant.components.conversation as ha_conversation
from homeassistant.components.conversation import ConversationInput, ConversationResult, AbstractConversationAgent
@@ -18,9 +20,10 @@ from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_D
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError
from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er
from homeassistant.helpers.event import async_track_state_change, async_call_later
from homeassistant.util import ulid

from .utils import closest_color, flatten_vol_schema, install_llama_cpp_python
@@ -38,17 +41,25 @@ from .const import (
    CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
    CONF_PROMPT_TEMPLATE,
    CONF_USE_GBNF_GRAMMAR,
    CONF_GBNF_GRAMMAR_FILE,
    CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
    CONF_IN_CONTEXT_EXAMPLES_FILE,
    CONF_TEXT_GEN_WEBUI_PRESET,
    CONF_OPENAI_API_KEY,
    CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
    CONF_REFRESH_SYSTEM_PROMPT,
    CONF_REMEMBER_CONVERSATION,
    CONF_REMEMBER_NUM_INTERACTIONS,
    CONF_PROMPT_CACHING_ENABLED,
    CONF_PROMPT_CACHING_INTERVAL,
    CONF_SERVICE_CALL_REGEX,
    CONF_REMOTE_USE_CHAT_ENDPOINT,
    CONF_TEXT_GEN_WEBUI_CHAT_MODE,
    CONF_OLLAMA_KEEP_ALIVE_MIN,
    CONF_CONTEXT_LENGTH,
    CONF_BATCH_SIZE,
    CONF_THREAD_COUNT,
    CONF_BATCH_THREAD_COUNT,
    DEFAULT_MAX_TOKENS,
    DEFAULT_PROMPT,
    DEFAULT_TEMPERATURE,
@@ -60,14 +71,23 @@ from .const import (
    DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
    DEFAULT_PROMPT_TEMPLATE,
    DEFAULT_USE_GBNF_GRAMMAR,
    DEFAULT_GBNF_GRAMMAR_FILE,
    DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
    DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
    DEFAULT_REFRESH_SYSTEM_PROMPT,
    DEFAULT_REMEMBER_CONVERSATION,
    DEFAULT_REMEMBER_NUM_INTERACTIONS,
    DEFAULT_PROMPT_CACHING_ENABLED,
    DEFAULT_PROMPT_CACHING_INTERVAL,
    DEFAULT_SERVICE_CALL_REGEX,
    DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
    DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
    DEFAULT_OPTIONS,
    DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
    DEFAULT_CONTEXT_LENGTH,
    DEFAULT_BATCH_SIZE,
    DEFAULT_THREAD_COUNT,
    DEFAULT_BATCH_THREAD_COUNT,
    BACKEND_TYPE_LLAMA_HF,
    BACKEND_TYPE_LLAMA_EXISTING,
    BACKEND_TYPE_TEXT_GEN_WEBUI,
@@ -78,8 +98,6 @@ from .const import (
    TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
    TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
    DOMAIN,
    GBNF_GRAMMAR_FILE,
    IN_CONTEXT_EXAMPLES_FILE,
    PROMPT_TEMPLATE_DESCRIPTIONS,
)

@@ -87,19 +105,19 @@ _LOGGER = logging.getLogger(__name__)

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)

async def update_listener(hass, entry):
async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
    """Handle options update."""
    hass.data[DOMAIN][entry.entry_id] = entry

    # call update handler
    agent = await ha_conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
    agent: LLaMAAgent = await ha_conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
    agent._update_options()

    return True

async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Set up Local LLaMA Conversation from a config entry."""

    def create_agent(backend_type):
        agent_cls = None

@@ -117,7 +135,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
        return agent_cls(hass, entry)

    # load the model in an executor job because it takes a while and locks up the UI otherwise
    agent = await hass.async_add_executor_job(create_agent, entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE))
    backend_type = entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
    agent = await hass.async_add_executor_job(create_agent, backend_type)

    # handle updates to the options
    entry.async_on_unload(entry.add_update_listener(update_listener))
@@ -126,6 +145,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:

    hass.data.setdefault(DOMAIN, {})
    hass.data[DOMAIN][entry.entry_id] = entry

    return True

@@ -175,13 +195,13 @@ class LLaMAAgent(AbstractConversationAgent):

        self.in_context_examples = None
        if entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
            self._load_icl_examples()
            self._load_icl_examples(entry.options.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE))

        self._load_model(entry)

    def _load_icl_examples(self):
    def _load_icl_examples(self, filename: str):
        try:
            icl_filename = os.path.join(os.path.dirname(__file__), IN_CONTEXT_EXAMPLES_FILE)
            icl_filename = os.path.join(os.path.dirname(__file__), filename)

            with open(icl_filename) as f:
                self.in_context_examples = list(csv.DictReader(f))
@@ -196,13 +216,16 @@ class LLaMAAgent(AbstractConversationAgent):

    def _update_options(self):
        if self.entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
            self._load_icl_examples()
            self._load_icl_examples(self.entry.options.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE))
        else:
            self.in_context_examples = None

    @property
    def entry(self):
        return self.hass.data[DOMAIN][self.entry_id]
    def entry(self) -> ConfigEntry:
        try:
            return self.hass.data[DOMAIN][self.entry_id]
        except KeyError as ex:
            raise Exception("Attempted to use self.entry during startup.") from ex

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
@@ -230,7 +253,7 @@ class LLaMAAgent(AbstractConversationAgent):
        template_desc = PROMPT_TEMPLATE_DESCRIPTIONS[prompt_template]
        refresh_system_prompt = self.entry.options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)
        remember_conversation = self.entry.options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)
        remember_num_interactions = self.entry.options.get(CONF_REMEMBER_NUM_INTERACTIONS, False)
        remember_num_interactions = self.entry.options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)
        service_call_regex = self.entry.options.get(CONF_SERVICE_CALL_REGEX, DEFAULT_SERVICE_CALL_REGEX)
        allowed_service_call_arguments = self.entry.options \
            .get(CONF_ALLOWED_SERVICE_CALL_ARGUMENTS, DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS)
@@ -395,7 +418,7 @@ class LLaMAAgent(AbstractConversationAgent):
            entity_states[state.entity_id] = attributes
            domains.add(state.domain)

        _LOGGER.debug(f"Exposed entities: {entity_states}")
        # _LOGGER.debug(f"Exposed entities: {entity_states}")

        return entity_states, list(domains)

@@ -487,10 +510,11 @@ class LLaMAAgent(AbstractConversationAgent):
                result = result + ";" + str(value)
            return result

        device_states = [f"{name} '{attributes.get('friendly_name')}' = {expose_attributes(attributes)}" for name, attributes in entities_to_expose.items()]
        device_states = []

        # expose devices as their alias as well
        # expose devices and their alias as well
        for name, attributes in entities_to_expose.items():
            device_states.append(f"{name} '{attributes.get('friendly_name')}' = {expose_attributes(attributes)}")
            if "aliases" in attributes:
                for alias in attributes["aliases"]:
                    device_states.append(f"{name} '{alias}' = {expose_attributes(attributes)}")
@@ -527,6 +551,12 @@ class LocalLLaMAAgent(LLaMAAgent):
    llm: Any
    grammar: Any
    llama_cpp_module: Any
    remove_prompt_caching_listener: Callable
    model_lock: threading.Lock
    last_cache_prime: float
    last_updated_entities: dict[str, float]
    cache_refresh_after_cooldown: bool
    loaded_model_settings: dict[str, Any]

    def _load_model(self, entry: ConfigEntry) -> None:
        self.model_path = entry.data.get(CONF_DOWNLOADED_MODEL_FILE)
@@ -551,45 +581,241 @@ class LocalLLaMAAgent(LLaMAAgent):

        Llama = getattr(self.llama_cpp_module, "Llama")

        _LOGGER.debug("Loading model...")
        _LOGGER.debug(f"Loading model '{self.model_path}'...")
        self.loaded_model_settings = {}
        self.loaded_model_settings[CONF_CONTEXT_LENGTH] = entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        self.loaded_model_settings[CONF_BATCH_SIZE] = entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE)
        self.loaded_model_settings[CONF_THREAD_COUNT] = entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT)
        self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] = entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT)

        self.llm = Llama(
            model_path=self.model_path,
            n_ctx=2048,
            n_batch=2048,
            # TODO: expose arguments to the user in home assistant UI
            # n_threads=16,
            # n_threads_batch=4,
            n_ctx=int(self.loaded_model_settings[CONF_CONTEXT_LENGTH]),
            n_batch=int(self.loaded_model_settings[CONF_BATCH_SIZE]),
            n_threads=int(self.loaded_model_settings[CONF_THREAD_COUNT]),
            n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT])
        )
        _LOGGER.debug("Model loaded")

        self.grammar = None
        if entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
            self._load_grammar()
            self._load_grammar(entry.options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE))

    def _load_grammar(self):
        # TODO: check about disk caching
        # self.llm.set_cache(self.llama_cpp_module.LlamaDiskCache(
        #     capacity_bytes=(512 * 10e8),
        #     cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
        # ))

        self.remove_prompt_caching_listener = None
        self.last_cache_prime = None
        self.last_updated_entities = {}
        self.cache_refresh_after_cooldown = False
        self.model_lock = threading.Lock()

        self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] = entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED)
        if self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED]:
            @callback
            async def enable_caching_after_startup(_now) -> None:
                self._set_prompt_caching(enabled=True)
                await self._async_cache_prompt(None, None, None)
            async_call_later(self.hass, 5.0, enable_caching_after_startup)

    def _load_grammar(self, filename: str):
        LlamaGrammar = getattr(self.llama_cpp_module, "LlamaGrammar")
        _LOGGER.debug("Loading grammar...")
        _LOGGER.debug(f"Loading grammar {filename}...")
        try:
            # TODO: make grammar configurable
            with open(os.path.join(os.path.dirname(__file__), GBNF_GRAMMAR_FILE)) as f:
            with open(os.path.join(os.path.dirname(__file__), filename)) as f:
                grammar_str = "".join(f.readlines())
            self.grammar = LlamaGrammar.from_string(grammar_str)
            self.loaded_model_settings[CONF_GBNF_GRAMMAR_FILE] = filename
            _LOGGER.debug("Loaded grammar")
        except Exception:
            _LOGGER.exception("Failed to load grammar!")
            self.grammar = None

    def _update_options(self):
        LLaMAAgent._update_options()
        LLaMAAgent._update_options(self)

        model_reloaded = False
        if self.loaded_model_settings[CONF_CONTEXT_LENGTH] != self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) or \
            self.loaded_model_settings[CONF_BATCH_SIZE] != self.entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE) or \
            self.loaded_model_settings[CONF_THREAD_COUNT] != self.entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT) or \
            self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] != self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT):

            _LOGGER.debug(f"Reloading model '{self.model_path}'...")
            self.loaded_model_settings[CONF_CONTEXT_LENGTH] = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
            self.loaded_model_settings[CONF_BATCH_SIZE] = self.entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE)
            self.loaded_model_settings[CONF_THREAD_COUNT] = self.entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT)
            self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] = self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT)

            Llama = getattr(self.llama_cpp_module, "Llama")
            self.llm = Llama(
                model_path=self.model_path,
                n_ctx=int(self.loaded_model_settings[CONF_CONTEXT_LENGTH]),
                n_batch=int(self.loaded_model_settings[CONF_BATCH_SIZE]),
                n_threads=int(self.loaded_model_settings[CONF_THREAD_COUNT]),
                n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT])
            )
            _LOGGER.debug("Model loaded")
            model_reloaded = True

        if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
            self._load_grammar()
            current_grammar = self.entry.options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
            if not self.grammar or self.loaded_model_settings[CONF_GBNF_GRAMMAR_FILE] != current_grammar:
                self._load_grammar(current_grammar)
        else:
            self.grammar = None

        if self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
            self._set_prompt_caching(enabled=True)

            if self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] != self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED) or \
                model_reloaded:
                self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] = self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED)

                async def cache_current_prompt(_now):
                    await self._async_cache_prompt(None, None, None)
                async_call_later(self.hass, 1.0, cache_current_prompt)
        else:
            self._set_prompt_caching(enabled=False)

    def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
        """Takes the super class function results and sorts the entities by most recently updated at the end"""
        entities, domains = LLaMAAgent._async_get_exposed_entities(self)

        # ignore sorting if prompt caching is disabled
        if not self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
            return entities, domains

        entity_order = { name: None for name in entities.keys() }
        entity_order.update(self.last_updated_entities)

        def sort_key(item):
            item_name, last_updated = item
            # Handle cases where last updated is None
            if last_updated is None:
                return (False, '', item_name)
            else:
                return (True, last_updated, '')

        # Sort the items based on the sort_key function
        sorted_items = sorted(list(entity_order.items()), key=sort_key)

        _LOGGER.debug(f"sorted_items: {sorted_items}")

        sorted_entities = {}
        for item_name, _ in sorted_items:
            sorted_entities[item_name] = entities[item_name]

        return sorted_entities, domains

    def _set_prompt_caching(self, *, enabled=True):
        if enabled and not self.remove_prompt_caching_listener:
            _LOGGER.info("enabling prompt caching...")

            entity_ids = [
                state.entity_id for state in self.hass.states.async_all() \
                if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id)
            ]

            _LOGGER.debug(f"watching entities: {entity_ids}")

            self.remove_prompt_caching_listener = async_track_state_change(self.hass, entity_ids, self._async_cache_prompt)

        elif not enabled and self.remove_prompt_caching_listener:
            _LOGGER.info("disabling prompt caching...")
            self.remove_prompt_caching_listener()

    @callback
    async def _async_cache_prompt(self, entity, old_state, new_state):
        refresh_start = time.time()

        # track last update time so we can sort the context efficiently
        if entity:
            self.last_updated_entities[entity] = refresh_start

        _LOGGER.debug(f"refreshing cached prompt because {entity} changed...")
        await self.hass.async_add_executor_job(self._cache_prompt)

        refresh_end = time.time()
        _LOGGER.debug(f"cache refresh took {(refresh_end - refresh_start):.2f} sec")

    def _cache_prompt(self) -> None:
        # if a refresh is already scheduled then exit
        if self.cache_refresh_after_cooldown:
            return

        # if we are inside the cooldown period, request a refresh and exit
        current_time = time.time()
        fastest_prime_interval = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
        if self.last_cache_prime and current_time - self.last_cache_prime < fastest_prime_interval:
            self.cache_refresh_after_cooldown = True
            return

        # try to acquire the lock, if we are still running for some reason, request a refresh and exit
        lock_acquired = self.model_lock.acquire(False)
        if not lock_acquired:
            self.cache_refresh_after_cooldown = True
            return

        try:
            raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
            prompt = self._format_prompt([
                { "role": "system", "message": self._generate_system_prompt(raw_prompt)},
                { "role": "user", "message": "" }
            ], include_generation_prompt=False)

            input_tokens = self.llm.tokenize(
                prompt.encode(), add_bos=False
            )

            temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
            top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
            top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
            grammar = self.grammar if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None

            _LOGGER.debug(f"Options: {self.entry.options}")

            _LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")

            # grab just one token. should prime the kv cache with the system prompt
            next(self.llm.generate(
                input_tokens,
                temp=temperature,
                top_k=top_k,
                top_p=top_p,
                grammar=grammar
            ))

            self.last_cache_prime = time.time()
        finally:
            self.model_lock.release()

        # schedule a refresh using async_call_later
        # if the flag is set after the delay then we do another refresh

        @callback
        async def refresh_if_requested(_now):
            if self.cache_refresh_after_cooldown:
                self.cache_refresh_after_cooldown = False

                refresh_start = time.time()
                _LOGGER.debug(f"refreshing cached prompt after cooldown...")
                await self.hass.async_add_executor_job(self._cache_prompt)

                refresh_end = time.time()
                _LOGGER.debug(f"cache refresh took {(refresh_end - refresh_start):.2f} sec")

        refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
        async_call_later(self.hass, float(refresh_delay), refresh_if_requested)

    def _generate(self, conversation: dict) -> str:
        prompt = self._format_prompt(conversation)
        input_tokens = self.llm.tokenize(
            prompt.encode(), add_bos=False
        )

        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
@@ -598,26 +824,31 @@ class LocalLLaMAAgent(LLaMAAgent):

        _LOGGER.debug(f"Options: {self.entry.options}")

        _LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")
        output_tokens = self.llm.generate(
            input_tokens,
            temp=temperature,
            top_k=top_k,
            top_p=top_p,
            grammar=self.grammar
        )
        with self.model_lock:
            input_tokens = self.llm.tokenize(
                prompt.encode(), add_bos=False
            )

        result_tokens = []
        for token in output_tokens:
            if token == self.llm.token_eos():
                break
            _LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")
            output_tokens = self.llm.generate(
                input_tokens,
                temp=temperature,
                top_k=top_k,
                top_p=top_p,
                grammar=self.grammar
            )

            result_tokens.append(token)
            result_tokens = []
            for token in output_tokens:
                if token == self.llm.token_eos():
                    break

            if len(result_tokens) >= max_tokens:
                break
                result_tokens.append(token)

        result = self.llm.detokenize(result_tokens).decode()
                if len(result_tokens) >= max_tokens:
                    break

            result = self.llm.detokenize(result_tokens).decode()

        return result

@@ -917,4 +1148,4 @@ class OllamaAPIAgent(LLaMAAgent):

        _LOGGER.debug(result.json())

        return self._extract_response(result.json())
        return self._extract_response(result.json())
@@ -47,19 +47,27 @@ from .const import (
    CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS,
    CONF_PROMPT_TEMPLATE,
    CONF_USE_GBNF_GRAMMAR,
    CONF_GBNF_GRAMMAR_FILE,
    CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
    CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
    CONF_TEXT_GEN_WEBUI_PRESET,
    CONF_REFRESH_SYSTEM_PROMPT,
    CONF_REMEMBER_CONVERSATION,
    CONF_REMEMBER_NUM_INTERACTIONS,
    CONF_PROMPT_CACHING_ENABLED,
    CONF_PROMPT_CACHING_INTERVAL,
    CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
    CONF_IN_CONTEXT_EXAMPLES_FILE,
    CONF_OPENAI_API_KEY,
    CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
    CONF_SERVICE_CALL_REGEX,
    CONF_REMOTE_USE_CHAT_ENDPOINT,
    CONF_TEXT_GEN_WEBUI_CHAT_MODE,
    CONF_OLLAMA_KEEP_ALIVE_MIN,
    CONF_CONTEXT_LENGTH,
    CONF_BATCH_SIZE,
    CONF_THREAD_COUNT,
    CONF_BATCH_THREAD_COUNT,
    DEFAULT_CHAT_MODEL,
    DEFAULT_PORT,
    DEFAULT_SSL,
@@ -73,15 +81,24 @@ from .const import (
    DEFAULT_DOWNLOADED_MODEL_QUANTIZATION,
    DEFAULT_PROMPT_TEMPLATE,
    DEFAULT_USE_GBNF_GRAMMAR,
    DEFAULT_GBNF_GRAMMAR_FILE,
    DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
    DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
    DEFAULT_REFRESH_SYSTEM_PROMPT,
    DEFAULT_REMEMBER_CONVERSATION,
    DEFAULT_REMEMBER_NUM_INTERACTIONS,
    DEFAULT_PROMPT_CACHING_ENABLED,
    DEFAULT_PROMPT_CACHING_INTERVAL,
    DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
    DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
    DEFAULT_SERVICE_CALL_REGEX,
    DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
    DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
    DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
    DEFAULT_CONTEXT_LENGTH,
    DEFAULT_BATCH_SIZE,
    DEFAULT_THREAD_COUNT,
    DEFAULT_BATCH_THREAD_COUNT,
    BACKEND_TYPE_LLAMA_HF,
    BACKEND_TYPE_LLAMA_EXISTING,
    BACKEND_TYPE_TEXT_GEN_WEBUI,
@@ -301,7 +318,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
            install_exception = self.install_wheel_task.exception()
            if install_exception:
                _LOGGER.warning("Failed to install wheel: %s", repr(install_exception))
                self.install_wheel_error = install_exception
                self.install_wheel_error = "pip_wheel_error"
                next_step = "pick_backend"
            else:
                wheel_install_result = self.install_wheel_task.result()
@@ -346,7 +363,8 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
            else:
                model_file = self.model_config[CONF_DOWNLOADED_MODEL_FILE]
                if os.path.exists(model_file):
                    return await self.async_step_finish()
                    self.model_config[CONF_CHAT_MODEL] = os.path.basename(model_file)
                    return await self.async_step_model_parameters()
                else:
                    errors["base"] = "missing_model_file"
                    schema = STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA(model_file)
@@ -570,9 +588,28 @@ class OptionsFlow(config_entries.OptionsFlow):
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Manage the options."""
        errors = {}
        description_placeholders = {}

        if user_input is not None:
            # TODO: validate that files exist (GBNF + ICL examples)
            return self.async_create_entry(title="LLaMA Conversation", data=user_input)
            if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and user_input.get(CONF_PROMPT_CACHING_ENABLED):
                errors["base"] = "sys_refresh_caching_enabled"

            if user_input.get(CONF_USE_GBNF_GRAMMAR):
                filename = user_input.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
                if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
                    errors["base"] = "missing_gbnf_file"
                    description_placeholders["filename"] = filename

            if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
                filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
                if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
                    errors["base"] = "missing_icl_file"
                    description_placeholders["filename"] = filename

            if len(errors) == 0:
                return self.async_create_entry(title="LLaMA Conversation", data=user_input)

        schema = local_llama_config_option_schema(
            self.config_entry.options,
            self.config_entry.data[CONF_BACKEND_TYPE],
@@ -580,6 +617,8 @@ class OptionsFlow(config_entries.OptionsFlow):
        return self.async_show_form(
            step_id="init",
            data_schema=vol.Schema(schema),
            errors=errors,
            description_placeholders=description_placeholders,
        )


@@ -616,6 +655,16 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
            multiple=False,
            mode=SelectSelectorMode.DROPDOWN,
        )),
        vol.Required(
            CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
            description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
            default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
        ): bool,
        vol.Required(
            CONF_IN_CONTEXT_EXAMPLES_FILE,
            description={"suggested_value": options.get(CONF_IN_CONTEXT_EXAMPLES_FILE)},
            default=DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
        ): str,
        vol.Required(
            CONF_MAX_TOKENS,
            description={"suggested_value": options.get(CONF_MAX_TOKENS)},
@@ -649,12 +698,8 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
        vol.Optional(
            CONF_REMEMBER_NUM_INTERACTIONS,
            description={"suggested_value": options.get(CONF_REMEMBER_NUM_INTERACTIONS)},
            default=DEFAULT_REMEMBER_NUM_INTERACTIONS,
        ): int,
        vol.Required(
            CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
            description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},
            default=DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
        ): bool,
    }

    if is_local_backend(backend_type):
@@ -674,11 +719,47 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
            description={"suggested_value": options.get(CONF_TEMPERATURE)},
            default=DEFAULT_TEMPERATURE,
        ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
        vol.Required(
            CONF_PROMPT_CACHING_ENABLED,
            description={"suggested_value": options.get(CONF_PROMPT_CACHING_ENABLED)},
            default=DEFAULT_PROMPT_CACHING_ENABLED,
        ): bool,
        vol.Required(
            CONF_PROMPT_CACHING_INTERVAL,
            description={"suggested_value": options.get(CONF_PROMPT_CACHING_INTERVAL)},
            default=DEFAULT_PROMPT_CACHING_INTERVAL,
        ): NumberSelector(NumberSelectorConfig(min=1, max=60, step=1)),
        # TODO: add rope_scaling_type
        vol.Required(
            CONF_CONTEXT_LENGTH,
            description={"suggested_value": options.get(CONF_CONTEXT_LENGTH)},
            default=DEFAULT_CONTEXT_LENGTH,
        ): NumberSelector(NumberSelectorConfig(min=512, max=32768, step=1)),
        vol.Required(
            CONF_BATCH_SIZE,
            description={"suggested_value": options.get(CONF_BATCH_SIZE)},
            default=DEFAULT_BATCH_SIZE,
        ): NumberSelector(NumberSelectorConfig(min=1, max=8192, step=1)),
        vol.Required(
            CONF_THREAD_COUNT,
            description={"suggested_value": options.get(CONF_THREAD_COUNT)},
            default=DEFAULT_THREAD_COUNT,
        ): NumberSelector(NumberSelectorConfig(min=1, max=(os.cpu_count() * 2), step=1)),
        vol.Required(
            CONF_BATCH_THREAD_COUNT,
            description={"suggested_value": options.get(CONF_BATCH_THREAD_COUNT)},
            default=DEFAULT_BATCH_THREAD_COUNT,
        ): NumberSelector(NumberSelectorConfig(min=1, max=(os.cpu_count() * 2), step=1)),
        vol.Required(
            CONF_USE_GBNF_GRAMMAR,
            description={"suggested_value": options.get(CONF_USE_GBNF_GRAMMAR)},
            default=DEFAULT_USE_GBNF_GRAMMAR,
        ): bool
        ): bool,
        vol.Required(
            CONF_GBNF_GRAMMAR_FILE,
            description={"suggested_value": options.get(CONF_GBNF_GRAMMAR_FILE)},
            default=DEFAULT_GBNF_GRAMMAR_FILE,
        ): str
    })
    elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
        result = insert_after_key(result, CONF_MAX_TOKENS, {
@@ -761,7 +842,12 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
            CONF_USE_GBNF_GRAMMAR,
            description={"suggested_value": options.get(CONF_USE_GBNF_GRAMMAR)},
            default=DEFAULT_USE_GBNF_GRAMMAR,
        ): bool
        ): bool,
        vol.Required(
            CONF_GBNF_GRAMMAR_FILE,
            description={"suggested_value": options.get(CONF_GBNF_GRAMMAR_FILE)},
            default=DEFAULT_GBNF_GRAMMAR_FILE,
        ): str
    })
    elif backend_type == BACKEND_TYPE_OLLAMA:
        result = insert_after_key(result, CONF_MAX_TOKENS, {
@@ -1,8 +1,5 @@
"""Constants for the LLaMa Conversation integration."""
import types
# import voluptuous as vol
# import homeassistant.helpers.config_validation as cv
# from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
import types, os

DOMAIN = "llama_conversation"
CONF_PROMPT = "prompt"
@@ -57,8 +54,6 @@ CONF_EXTRA_ATTRIBUTES_TO_EXPOSE = "extra_attributes_to_expose"
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "item"]
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS = "allowed_service_call_arguments"
DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "hvac_mode", "preset_mode", "item", "duration"]
GBNF_GRAMMAR_FILE = "output.gbnf"
IN_CONTEXT_EXAMPLES_FILE = "in_context_examples.csv"
CONF_PROMPT_TEMPLATE = "prompt_template"
PROMPT_TEMPLATE_CHATML = "chatml"
PROMPT_TEMPLATE_ALPACA = "alpaca"
@@ -107,16 +102,25 @@ PROMPT_TEMPLATE_DESCRIPTIONS = {
}
CONF_USE_GBNF_GRAMMAR = "gbnf_grammar"
DEFAULT_USE_GBNF_GRAMMAR = False
CONF_GBNF_GRAMMAR_FILE = "gbnf_grammar_file"
DEFAULT_GBNF_GRAMMAR_FILE = "output.gbnf"
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES = "in_context_examples"
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES = True
CONF_IN_CONTEXT_EXAMPLES_FILE = "in_context_examples_file"
DEFAULT_IN_CONTEXT_EXAMPLES_FILE = "in_context_examples.csv"
CONF_TEXT_GEN_WEBUI_PRESET = "text_generation_webui_preset"
CONF_OPENAI_API_KEY = "openai_api_key"
CONF_TEXT_GEN_WEBUI_ADMIN_KEY = "text_generation_webui_admin_key"
CONF_REFRESH_SYSTEM_PROMPT = "refresh_prompt_per_tern"
CONF_REMEMBER_CONVERSATION = "remember_conversation"
CONF_REMEMBER_NUM_INTERACTIONS = "remember_num_interactions"
DEFAULT_REFRESH_SYSTEM_PROMPT = True
CONF_REMEMBER_CONVERSATION = "remember_conversation"
DEFAULT_REMEMBER_CONVERSATION = True
CONF_REMEMBER_NUM_INTERACTIONS = "remember_num_interactions"
DEFAULT_REMEMBER_NUM_INTERACTIONS = 5
CONF_PROMPT_CACHING_ENABLED = "prompt_caching"
DEFAULT_PROMPT_CACHING_ENABLED = False
CONF_PROMPT_CACHING_INTERVAL = "prompt_caching_interval"
DEFAULT_PROMPT_CACHING_INTERVAL = 30
CONF_SERVICE_CALL_REGEX = "service_call_regex"
DEFAULT_SERVICE_CALL_REGEX = r"({[\S \t]*?})"
FINE_TUNED_SERVICE_CALL_REGEX = r"```homeassistant\n([\S \t\n]*?)```"
@@ -130,6 +134,15 @@ DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE = TEXT_GEN_WEBUI_CHAT_MODE_CHAT
CONF_OLLAMA_KEEP_ALIVE_MIN = "ollama_keep_alive"
DEFAULT_OLLAMA_KEEP_ALIVE_MIN = 30

CONF_CONTEXT_LENGTH = "context_length"
DEFAULT_CONTEXT_LENGTH = 2048
CONF_BATCH_SIZE = "batch_size"
DEFAULT_BATCH_SIZE = 512
CONF_THREAD_COUNT = "n_threads"
DEFAULT_THREAD_COUNT = os.cpu_count()
CONF_BATCH_THREAD_COUNT = "n_batch_threads"
DEFAULT_BATCH_THREAD_COUNT = os.cpu_count()

DEFAULT_OPTIONS = types.MappingProxyType(
    {
        CONF_PROMPT: DEFAULT_PROMPT,
@@ -147,6 +160,10 @@ DEFAULT_OPTIONS = types.MappingProxyType(
        CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
        CONF_TEXT_GEN_WEBUI_CHAT_MODE: DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
        CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
        CONF_CONTEXT_LENGTH: DEFAULT_CONTEXT_LENGTH,
        CONF_BATCH_SIZE: DEFAULT_BATCH_SIZE,
        CONF_THREAD_COUNT: DEFAULT_THREAD_COUNT,
        CONF_BATCH_THREAD_COUNT: DEFAULT_BATCH_THREAD_COUNT,
    }
)

@@ -195,41 +212,5 @@ OPTIONS_OVERRIDES = {
    }
}

# TODO: need to rewrite the internal config_entry key names so they actually make sense before we expose this
# method of configuring the component. doing so will require writing a config version upgrade migration
# MODEL_CONFIG_SCHEMA = vol.Schema(
#     {
#         vol.Required(CONF_BACKEND_TYPE): vol.All(
#             vol.In([
#                 BACKEND_TYPE_LLAMA_EXISTING,
#                 BACKEND_TYPE_TEXT_GEN_WEBUI,
#                 BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
#                 BACKEND_TYPE_OLLAMA,
#                 BACKEND_TYPE_GENERIC_OPENAI,
#             ])
#         ),
#         vol.Optional(CONF_HOST): cv.string,
#         vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
#         vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean,
#         vol.Optional("options"): vol.Schema(
#             {
#                 vol.Optional(CONF_PROMPT): cv.string,
#                 vol.Optional(CONF_PROMPT_TEMPLATE): vol.All(
#                     vol.In([
#                         PROMPT_TEMPLATE_ALPACA,
#                         PROMPT_TEMPLATE_CHATML,
#                         PROMPT_TEMPLATE_LLAMA2,
#                         PROMPT_TEMPLATE_MISTRAL,
#                         PROMPT_TEMPLATE_VICUNA,
#                         PROMPT_TEMPLATE_ZEPHYR,
#                     ])
#                 ),
#             }
#         )
#     }
# )

# CONFIG_SCHEMA = vol.Schema(
#     { DOMAIN: vol.All(cv.ensure_list, [MODEL_CONFIG_SCHEMA]) },
#     extra=vol.ALLOW_EXTRA,
# )
INTEGRATION_VERSION = "0.2.11"
EMBEDDED_LLAMA_CPP_PYTHON_VERSION = "0.2.60"
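`DEFAULT_SERVICE_CALL_REGEX` and `FINE_TUNED_SERVICE_CALL_REGEX` above are the patterns the conversation agent uses to pull a service call out of the model's raw text. A minimal sketch of how they behave; the sample model outputs below (including the JSON keys) are made up for illustration and are not taken from this diff:

```python
# Demonstration of the two service-call extraction regexes defined in const.py.
import re

DEFAULT_SERVICE_CALL_REGEX = r"({[\S \t]*?})"
FINE_TUNED_SERVICE_CALL_REGEX = r"```homeassistant\n([\S \t\n]*?)```"

# Hypothetical model outputs for illustration only.
icl_style_output = 'turning on the light\n{"service": "light.turn_on", "target_device": "light.kitchen"}'
fine_tuned_output = 'turning on the light\n```homeassistant\n{"service": "light.turn_on", "target_device": "light.kitchen"}\n```'

# The default pattern grabs the first single-line {...} block.
print(re.findall(DEFAULT_SERVICE_CALL_REGEX, icl_style_output))
# ['{"service": "light.turn_on", "target_device": "light.kitchen"}']

# The fine-tuned pattern grabs the body of a ```homeassistant fenced block (possibly multi-line).
print(re.findall(FINE_TUNED_SERVICE_CALL_REGEX, fine_tuned_output))
# ['{"service": "light.turn_on", "target_device": "light.kitchen"}\n']
```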
custom_components/llama_conversation/json.gbnf (new file, 25 lines)
@@ -0,0 +1,25 @@
root   ::= object
value  ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array  ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\\x7F\x00-\x1F] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
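For context, a minimal sketch of how a GBNF grammar like this one gets used: `_load_grammar()` in the integration reads the file and calls `LlamaGrammar.from_string()`, and the resulting object is handed to the llama-cpp-python sampler so the model can only emit tokens that keep the output valid JSON. The sketch below uses the higher-level `create_completion()` API rather than the integration's low-level `generate()` call, and the model path is a placeholder:

```python
# Assumes llama-cpp-python is installed and a GGUF model is available locally.
from llama_cpp import Llama, LlamaGrammar

with open("custom_components/llama_conversation/json.gbnf") as f:
    grammar = LlamaGrammar.from_string(f.read())

llm = Llama(model_path="/path/to/model.gguf", n_ctx=2048)  # hypothetical model path

# With the grammar attached, sampling is constrained to strings accepted by the grammar.
result = llm.create_completion("Respond with a JSON object:", grammar=grammar, max_tokens=128)
print(result["choices"][0]["text"])
```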
@@ -1,7 +1,7 @@
{
  "domain": "llama_conversation",
  "name": "LLaMA Conversation",
  "version": "0.2.10",
  "version": "0.2.11",
  "codeowners": ["@acon96"],
  "config_flow": true,
  "dependencies": ["conversation"],
@@ -7,8 +7,7 @@
      "missing_model_file": "The provided file does not exist.",
      "other_existing_local": "Another model is already loaded locally. Please unload it or configure a remote model.",
      "unknown": "Unexpected error",
      "missing_wheels": "Llama.cpp is not installed and could not find any wheels to install! See the logs for more information.",
      "pip_wheel_error": "Pip returned an error while installing the wheel!"
      "pip_wheel_error": "Pip returned an error while installing the wheel! Please check the Home Assistant logs for more details."
    },
    "progress": {
      "download": "Please wait while the model is being downloaded from HuggingFace. This can take a few minutes.",
@@ -60,6 +59,7 @@
        "extra_attributes_to_expose": "Additional attribute to expose in the context",
        "allowed_service_call_arguments": "Arguments allowed to be pass to service calls",
        "gbnf_grammar": "Enable GBNF Grammar",
        "gbnf_grammar_file": "GBNF Grammar Filename",
        "openai_api_key": "API Key",
        "text_generation_webui_admin_key": "Admin Key",
        "service_call_regex": "Service Call Regex",
@@ -67,9 +67,16 @@
        "remember_conversation": "Remember conversation",
        "remember_num_interactions": "Number of past interactions to remember",
        "in_context_examples": "Enable in context learning (ICL) examples",
        "in_context_examples_file": "In context learning examples CSV filename",
        "text_generation_webui_preset": "Generation Preset/Character Name",
        "remote_use_chat_endpoint": "Use chat completions endpoint",
        "text_generation_webui_chat_mode": "Chat Mode"
        "text_generation_webui_chat_mode": "Chat Mode",
        "prompt_caching": "Enable Prompt Caching",
        "prompt_caching_interval": "Prompt Caching fastest refresh interval (sec)",
        "context_length": "Context Length",
        "batch_size": "Batch Size",
        "n_threads": "Thread Count",
        "n_batch_threads": "Batch Thread Count"
      },
      "data_description": {
        "prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
@@ -98,6 +105,7 @@
        "extra_attributes_to_expose": "Additional attribute to expose in the context",
        "allowed_service_call_arguments": "Arguments allowed to be pass to service calls",
        "gbnf_grammar": "Enable GBNF Grammar",
        "gbnf_grammar_file": "GBNF Grammar Filename",
        "openai_api_key": "API Key",
        "text_generation_webui_admin_key": "Admin Key",
        "service_call_regex": "Service Call Regex",
@@ -105,11 +113,23 @@
        "remember_conversation": "Remember conversation",
        "remember_num_interactions": "Number of past interactions to remember",
        "in_context_examples": "Enable in context learning (ICL) examples",
        "in_context_examples_file": "In context learning examples CSV filename",
        "text_generation_webui_preset": "Generation Preset/Character Name",
        "remote_use_chat_endpoint": "Use chat completions endpoint",
        "text_generation_webui_chat_mode": "Chat Mode"
        "text_generation_webui_chat_mode": "Chat Mode",
        "prompt_caching": "Enable Prompt Caching",
        "prompt_caching_interval": "Prompt Caching fastest refresh interval (sec)",
        "context_length": "Context Length",
        "batch_size": "Batch Size",
        "n_threads": "Thread Count",
        "n_batch_threads": "Batch Thread Count"
      }
    }
  },
  "error": {
    "sys_refresh_caching_enabled": "System prompt refresh must be enabled for prompt caching to work!",
    "missing_gbnf_file": "The GBNF file was not found: '{filename}'",
    "missing_icl_file": "The in context learning example CSV file was not found: '{filename}'"
  }
},
"selector": {
@@ -11,6 +11,11 @@ from huggingface_hub import hf_hub_download, HfFileSystem
from homeassistant.requirements import pip_kwargs
from homeassistant.util.package import install_package, is_installed

from .const import (
    INTEGRATION_VERSION,
    EMBEDDED_LLAMA_CPP_PYTHON_VERSION,
)

_LOGGER = logging.getLogger(__name__)

def closest_color(requested_color):
@@ -60,38 +65,37 @@ def download_model_from_hf(model_name: str, quantization_type: str, storage_fold
    )

def install_llama_cpp_python(config_dir: str):

    if is_installed("llama-cpp-python"):
        _LOGGER.info("llama-cpp-python is already installed")
        return True

    platform_suffix = platform.machine()
    if platform_suffix == "arm64":
        platform_suffix = "aarch64"

    runtime_version = f"cp{sys.version_info.major}{sys.version_info.minor}"

    github_release_url = f"https://github.com/acon96/home-llm/releases/download/v{INTEGRATION_VERSION}/llama_cpp_python-{EMBEDDED_LLAMA_CPP_PYTHON_VERSION}-{runtime_version}-{runtime_version}-musllinux_1_2_{platform_suffix}.whl"
    if install_package(github_release_url, pip_kwargs(config_dir)):
        _LOGGER.info("llama-cpp-python successfully installed from GitHub release")
        return True

    folder = os.path.dirname(__file__)
    potential_wheels = sorted([ path for path in os.listdir(folder) if path.endswith(f"{platform_suffix}.whl") ], reverse=True)
    potential_wheels = [ wheel for wheel in potential_wheels if f"cp{sys.version_info.major}{sys.version_info.minor}" in wheel ]
    if len(potential_wheels) == 0:
        # someone who is better at async can figure out why this is necessary
        time.sleep(0.5)

        if is_installed("llama-cpp-python"):
            _LOGGER.info("llama-cpp-python is already installed")
            return True

        _LOGGER.error(
            "Error installing llama-cpp-python. Could not find any wheels that match the following filters. " + \
            f"platform: {platform_suffix}, python version: {sys.version_info.major}.{sys.version_info.minor}. " + \
            "If you recently updated Home Assistant, then you may need to use a different wheel than previously. " + \
            "Make sure that the correct .whl file is located in config/custom_components/llama_conversation/*"
            "Make sure that you download the correct .whl file from the GitHub releases page"
        )
        raise Exception("missing_wheels")
        return False

    latest_wheel = potential_wheels[0]
    latest_version = latest_wheel.split("-")[1]

    if not is_installed("llama-cpp-python") or version("llama-cpp-python") != latest_version:
        _LOGGER.info("Installing llama-cpp-python from wheel")
        _LOGGER.debug(f"Wheel location: {latest_wheel}")
        return install_package(os.path.join(folder, latest_wheel), pip_kwargs(config_dir))
    else:
        # someone who is better at async can figure out why this is necessary
        time.sleep(0.5)

        _LOGGER.info("llama-cpp-python is already installed")
        return True
    _LOGGER.info("Installing llama-cpp-python from local wheel")
    _LOGGER.debug(f"Wheel location: {latest_wheel}")
    return install_package(os.path.join(folder, latest_wheel), pip_kwargs(config_dir))
|
||||
| Temperature | Sampling parameter; see above link | 0.1 |
|
||||
| Enable GBNF Grammar | Restricts the output of the model to follow a pre-defined syntax; eliminates function calling syntax errors on quantized models | Enabled |
|
||||
|
||||
## Wheels
|
||||
The wheels for `llama-cpp-python` can be built or downloaded manually for installation.
|
||||
|
||||
Take the appropriate wheel and copy it to the `custom_components/llama_conversation/` directory.
|
||||
|
||||
After the wheel file has been copied to the correct folder, attempt the wheel installation step of the integration setup. The local wheel file should be detected and installed.
|
||||
|
||||
## Pre-built
|
||||
Pre-built wheel files (`*.whl`) are located as part of the GitHub release for the integration.
|
||||
|
||||
To ensure compatibility with your Home Assistant and Python versions, select the correct `.whl` file for your hardware's architecture:
|
||||
- For Home Assistant `2024.1.4` and older, use the Python 3.11 wheels (`cp311`)
|
||||
- For Home Assistant `2024.2.0` and newer, use the Python 3.12 wheels (`cp312`)
|
||||
- **ARM devices** (e.g., Raspberry Pi 4/5):
|
||||
- Example filenames:
|
||||
- `llama_cpp_python-{version}-cp311-cp311-musllinux_1_2_aarch64.whl`
|
||||
- `llama_cpp_python-{version}-cp312-cp312-musllinux_1_2_aarch64.whl`
|
||||
- **x86_64 devices** (e.g., Intel/AMD desktops):
|
||||
- Example filenames:
|
||||
- `llama_cpp_python-{version}-cp311-cp311-musllinux_1_2_x86_64.whl`
|
||||
- `llama_cpp_python-{version}-cp312-cp312-musllinux_1_2_x86_64.whl`
|
||||
|
||||
## Build your own
|
||||
|
||||
1. Clone the repository on the target machine that will be running Home Assistant
|
||||
2. Run the `dist/run_docker.sh` script
|
||||
3. The wheel files will be placed in the `dist/` folder
|
||||
|
||||
|
||||
# text-generation-webui
|
||||
| Option Name | Description | Suggested Value |
|
||||
| ------------ | --------- | ------------ |
|
||||
|
||||
@@ -42,38 +42,20 @@ After installation, A "LLaMA Conversation" device should show up in the `Setting
This setup path involves downloading a fine-tuned model from HuggingFace and integrating it with Home Assistant using the Llama.cpp backend. This option is for Home Assistant setups without a dedicated GPU; the model is capable of running on most devices and can even run on a Raspberry Pi (although slowly).

### Step 1: Wheel Installation for llama-cpp-python
In order to run the Llama.cpp backend as part of Home Assistant, we need to install the binary "wheel" distribution that is pre-built for compatibility with Home Assistant.

The `*.whl` files are located in the [/dist](/dist) folder of this repository.

To ensure compatibility with your Home Assistant and Python versions, select the correct `.whl` file for your hardware's architecture:
- For Home Assistant `2024.1.4` and older, use the Python 3.11 wheels (`cp311`)
- For Home Assistant `2024.2.0` and newer, use the Python 3.12 wheels (`cp312`)
- **ARM devices** (e.g., Raspberry Pi 4/5):
  - Example filenames:
    - `llama_cpp_python-{version}-cp311-cp311-musllinux_1_2_aarch64.whl`
    - `llama_cpp_python-{version}-cp312-cp312-musllinux_1_2_aarch64.whl`
- **x86_64 devices** (e.g., Intel/AMD desktops):
  - Example filenames:
    - `llama_cpp_python-{version}-cp311-cp311-musllinux_1_2_x86_64.whl`
    - `llama_cpp_python-{version}-cp312-cp312-musllinux_1_2_x86_64.whl`
Download the appropriate wheel and copy it to the `custom_components/llama_conversation/` directory.

After the wheel file has been copied to the correct folder.
1. In Home Assistant: navigate to `Settings > Devices and Services`
2. Select the `+ Add Integration` button in the bottom right corner
3. Search for, and select `LLaMA Conversation`
4. With the `Llama.cpp (HuggingFace)` backend selected, click `Submit`

This will trigger the installation of the wheel. If you ever need to update the version of Llama.cpp, you can copy a newer wheel file to the same folder, and re-create the integration; this will re-trigger the install process.
This should download and install `llama-cpp-python` from GitHub. If the installation fails for any reason, follow the manual installation instructions [here](./Backend%20Configuration.md#wheels).

Once `llama-cpp-python` is installed, continue to the model selection.

### Step 2: Model Selection
The next step is to specify which model will be used by the integration. You may select any repository on HuggingFace that has a model in GGUF format in it. We will use `acon96/Home-3B-v3-GGUF` for this example. If you have less than 4GB of RAM then use ``acon96/Home-1B-v2-GGUF`.
The next step is to specify which model will be used by the integration. You may select any repository on HuggingFace that has a model in GGUF format in it. We will use `acon96/Home-3B-v3-GGUF` for this example. If you have less than 4GB of RAM then use `acon96/Home-1B-v2-GGUF`.

**Model Name**: Use either `acon96/Home-3B-v3-GGUF` or `acon96/Home-1B-v2-GGUF`
**Quantization Level**: The model will be downloaded in the selected quantization level from the HuggingFace repository. If unsure which level to choose, select `Q4_K_M`.
**Model Name**: Use either `acon96/Home-3B-v3-GGUF` or `acon96/Home-1B-v2-GGUF`
**Quantization Level**: The model will be downloaded in the selected quantization level from the HuggingFace repository. If unsure which level to choose, select `Q4_K_M`.

Pressing `Submit` will download the model from HuggingFace.

@@ -86,7 +68,7 @@ The model will be loaded into memory and should now be available to select as a

## Path 2: Using Mistral-Instruct-7B with Ollama Backend
### Overview
For those who have access to a GPU, you can also use the Mistral-Instruct-7B model to power your conversation agent. This path requires a separate machine that has a GPU that has [Ollama](https://ollama.com/) already installed on it. This path utilizes in-context learning examples, to prompt the model to produce the output that we expect.
For those who have access to a GPU, you can also use the Mistral-Instruct-7B model to power your conversation agent. This path requires a separate machine that has a GPU and has [Ollama](https://ollama.com/) already installed on it. This path utilizes in-context learning examples to prompt the model to produce the output that we expect.

### Step 1: Downloading and serving the Model
Mistral can be easily set up and downloaded on the serving machine using the `ollama pull mistral` command.
@@ -113,6 +95,8 @@ This step allows you to configure how the model is "prompted". See [here](./Mode

For now, defaults for the model should have been populated and you can just scroll to the bottom and click `Submit`.

> NOTE: The key settings in this case are that our prompt references the `{{ response_examples }}` variable and the `Enable in context learning (ICL) examples` option is turned on.

## Configuring the Integration as a Conversation Agent
Now that the integration is configured and providing the conversation agent, we need to configure Home Assistant to use our conversation agent instead of the built in intent recognition system.
@@ -6,4 +6,6 @@ bitsandbytes
webcolors
pandas
# flash-attn
sentencepiece
sentencepiece

homeassistant