Mirror of https://github.com/acon96/home-llm.git, synced 2026-01-10 14:18:00 -05:00
type checking + comments
@@ -93,14 +93,21 @@ from .const import (
     PROMPT_TEMPLATE_DESCRIPTIONS,
 )
 
+# make type checking work for llama-cpp-python without importing it directly at runtime
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from llama_cpp import Llama as LlamaType
+else:
+    LlamaType = Any
+
 _LOGGER = logging.getLogger(__name__)
 
 CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
 
 class LLaMAAgent(AbstractConversationAgent):
-    """Local LLaMA conversation agent."""
+    """Base LLaMA conversation agent."""
 
-    hass: Any
+    hass: HomeAssistant
     entry_id: str
     history: dict[str, list[dict]]
     in_context_examples: list[dict]
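The TYPE_CHECKING block added here is the standard typing pattern for referencing an optional dependency in annotations without importing it at runtime. A minimal self-contained sketch of the same idea (the ModelHolder class below is illustrative, not part of this commit):

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # only evaluated by static type checkers (mypy/pyright), never at runtime
    from llama_cpp import Llama as LlamaType
else:
    # the name must still exist at runtime because the class annotation below is evaluated,
    # so fall back to Any when llama-cpp-python is not installed
    LlamaType = Any


class ModelHolder:
    # type checkers see llama_cpp.Llama here; at runtime the annotation is just Any
    llm: LlamaType

    def load(self, model_path: str) -> None:
        # import lazily so llama-cpp-python is only required when the local backend is used
        import llama_cpp
        self.llm = llama_cpp.Llama(model_path=model_path)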
@@ -122,6 +129,7 @@ class LLaMAAgent(AbstractConversationAgent):
         self._load_model(entry)
 
     def _load_icl_examples(self, filename: str):
+        """Load info used for generating in context learning examples"""
         try:
             icl_filename = os.path.join(os.path.dirname(__file__), filename)
 
@@ -155,12 +163,15 @@ class LLaMAAgent(AbstractConversationAgent):
         return MATCH_ALL
 
     def _load_model(self, entry: ConfigEntry) -> None:
+        """Load the model on the backend. Implemented by sub-classes"""
         raise NotImplementedError()
 
     def _generate(self, conversation: dict) -> str:
+        """Call the backend to generate a response from the conversation. Implemented by sub-classes"""
         raise NotImplementedError()
 
     async def _async_generate(self, conversation: dict) -> str:
+        """Async wrapper for _generate()"""
        return await self.hass.async_add_executor_job(
            self._generate, conversation
        )
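_async_generate simply pushes the blocking _generate call onto Home Assistant's thread-pool executor so the event loop is not stalled while llama.cpp (or an HTTP backend) works. A minimal sketch of the same pattern with plain asyncio, assuming hass.async_add_executor_job behaves like run_in_executor (the Agent class and the sleep stand-in are illustrative):

import asyncio
import time


class Agent:
    def _generate(self, conversation: dict) -> str:
        # stand-in for a blocking llama.cpp or HTTP call
        time.sleep(1.0)
        return "response"

    async def _async_generate(self, conversation: dict) -> str:
        # run the blocking callable in a worker thread and await its result
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._generate, conversation)


# usage: asyncio.run(Agent()._async_generate({"role": "user", "message": "hi"}))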
@@ -226,6 +237,7 @@ class LLaMAAgent(AbstractConversationAgent):
 
         conversation.append({"role": "user", "message": user_input.text})
 
+        # generate a response
         try:
             _LOGGER.debug(conversation)
             response = await self._async_generate(conversation)
@@ -250,6 +262,7 @@ class LLaMAAgent(AbstractConversationAgent):
             conversation.pop(1)
         self.history[conversation_id] = conversation
 
+        # parse response
         exposed_entities = list(self._async_get_exposed_entities()[0].keys())
 
         to_say = service_call_pattern.sub("", response).strip()
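service_call_pattern itself is defined outside the lines shown in this commit, so its exact regex is not visible here. Purely as an illustration of the parsing step, a hypothetical pattern that strips a fenced block of JSON service calls out of the model's reply could look like this (the ```homeassistant block convention is an assumption, not taken from this diff):

import json
import re

# hypothetical format: spoken text followed by a fenced block of JSON service calls
service_call_pattern = re.compile(r"```homeassistant\n(.*?)```", re.DOTALL)

response = (
    "Turning on the kitchen light for you.\n"
    "```homeassistant\n"
    '{"service": "light.turn_on", "target_device": "light.kitchen"}\n'
    "```"
)

# text to speak back to the user, with the machine-readable block removed
to_say = service_call_pattern.sub("", response).strip()

# each captured line becomes one service call
for block in service_call_pattern.findall(response):
    for line in block.strip().splitlines():
        call = json.loads(line)
        print(call["service"], call.get("target_device"))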
@@ -315,6 +328,7 @@ class LLaMAAgent(AbstractConversationAgent):
         if template_desc["assistant"]["suffix"]:
             to_say = to_say.replace(template_desc["assistant"]["suffix"], "") # remove the eos token if it is returned (some backends + the old model does this)
 
+        # generate intent response to Home Assistant
         intent_response = intent.IntentResponse(language=user_input.language)
         intent_response.async_set_speech(to_say)
         return ConversationResult(
@@ -347,6 +361,7 @@ class LLaMAAgent(AbstractConversationAgent):
     def _format_prompt(
         self, prompt: list[dict], include_generation_prompt: bool = True
     ) -> str:
+        """Format a conversation into a raw text completion using the model's prompt template"""
         formatted_prompt = ""
 
         prompt_template = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
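_format_prompt walks the role/message list and wraps each message in the prefix and suffix that the selected entry of PROMPT_TEMPLATE_DESCRIPTIONS defines for that role. A simplified sketch assuming a ChatML-style template description (the exact shape of the real template data may differ):

# assumed structure: one prefix/suffix pair per role
CHATML_TEMPLATE = {
    "system": {"prefix": "<|im_start|>system\n", "suffix": "<|im_end|>"},
    "user": {"prefix": "<|im_start|>user\n", "suffix": "<|im_end|>"},
    "assistant": {"prefix": "<|im_start|>assistant\n", "suffix": "<|im_end|>"},
}


def format_prompt(messages: list[dict], include_generation_prompt: bool = True) -> str:
    formatted = ""
    for msg in messages:
        role = CHATML_TEMPLATE[msg["role"]]
        formatted += role["prefix"] + msg["message"] + role["suffix"] + "\n"
    if include_generation_prompt:
        # leave the assistant prefix open so the model continues from here
        formatted += CHATML_TEMPLATE["assistant"]["prefix"]
    return formatted


print(format_prompt([
    {"role": "system", "message": "You control a smart home."},
    {"role": "user", "message": "turn on the kitchen light"},
]))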
@@ -372,7 +387,7 @@ class LLaMAAgent(AbstractConversationAgent):
         return formatted_prompt
 
     def _generate_system_prompt(self, prompt_template: str) -> str:
-        """Generate a prompt for the user."""
+        """Generate the system prompt with current entity states"""
         entities_to_expose, domains = self._async_get_exposed_entities()
 
         extra_attributes_to_expose = self.entry.options \
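_generate_system_prompt is where the current entity states get rendered into the prompt. A rough sketch of that step, assuming entities_to_expose maps entity IDs to a friendly name and state (the real integration builds this through the user-configurable prompt template rather than a hard-coded layout):

def render_entity_states(entities_to_expose: dict[str, dict]) -> str:
    # one "entity_id 'friendly name' = state" line per exposed entity
    lines = []
    for entity_id, info in entities_to_expose.items():
        lines.append(f"{entity_id} '{info['name']}' = {info['state']}")
    return "\n".join(lines)


example = {
    "light.kitchen": {"name": "Kitchen Light", "state": "off"},
    "switch.fan": {"name": "Bedroom Fan", "state": "on"},
}
print(render_entity_states(example))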
@@ -470,7 +485,7 @@ class LLaMAAgent(AbstractConversationAgent):
 
 class LocalLLaMAAgent(LLaMAAgent):
     model_path: str
-    llm: Any
+    llm: LlamaType
     grammar: Any
     llama_cpp_module: Any
     remove_prompt_caching_listener: Callable
@@ -604,7 +619,7 @@ class LocalLLaMAAgent(LLaMAAgent):
         self._set_prompt_caching(enabled=False)
 
     def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
-        """Takes the super class function results and sorts the entities by most recently updated at the end"""
+        """Takes the super class function results and sorts the entities with the recently updated at the end"""
         entities, domains = LLaMAAgent._async_get_exposed_entities(self)
 
         # ignore sorting if prompt caching is disabled
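This override exists for prompt caching: entities that rarely change stay at the top of the state list so llama.cpp can reuse the cached prompt prefix, while recently updated entities are moved to the end. A small sketch of that sorting step, assuming each entity dict carries a last_updated ISO timestamp (the attribute name here is an assumption):

from datetime import datetime


def sort_recently_updated_last(entities: dict[str, dict]) -> dict[str, dict]:
    # oldest updates first, newest last, so the unchanged prefix of the prompt stays stable
    return dict(sorted(
        entities.items(),
        key=lambda item: datetime.fromisoformat(item[1]["last_updated"]),
    ))


entities = {
    "light.kitchen": {"state": "off", "last_updated": "2024-01-01T08:00:00"},
    "switch.fan": {"state": "on", "last_updated": "2024-01-01T09:30:00"},
    "sensor.temp": {"state": "21.5", "last_updated": "2024-01-01T07:15:00"},
}
print(list(sort_recently_updated_last(entities)))  # sensor.temp first, switch.fan last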
@@ -939,7 +954,7 @@ class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
     def _load_model(self, entry: ConfigEntry):
         super()._load_model(entry)
 
-        with open(os.path.join(os.path.dirname(__file__), GBNF_GRAMMAR_FILE)) as f:
+        with open(os.path.join(os.path.dirname(__file__), DEFAULT_GBNF_GRAMMAR_FILE)) as f:
             self.grammar = "".join(f.readlines())
 
     def _chat_completion_params(self, conversation: dict) -> (str, dict):
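The grammar file read here is a GBNF grammar; this API-based agent forwards it as a plain string to the llama.cpp server, while the local backend can hand the same text to llama-cpp-python to constrain generation. A minimal local-usage sketch, with a placeholder model path and a deliberately trivial one-rule grammar standing in for the file the integration ships:

from llama_cpp import Llama, LlamaGrammar

# trivial stand-in grammar; the integration reads a full GBNF file instead
grammar_text = 'root ::= "turning on the light"'

llm = Llama(model_path="/path/to/model.gguf")
grammar = LlamaGrammar.from_string(grammar_text)

# grammar-constrained completion: sampled tokens must follow the grammar
output = llm("turn on the light\n", grammar=grammar, max_tokens=32)
print(output["choices"][0]["text"])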