type checking + comments

Alex O'Connell
2024-04-07 12:28:17 -04:00
parent 27044a8fae
commit 670e6d8625


@@ -93,14 +93,21 @@ from .const import (
     PROMPT_TEMPLATE_DESCRIPTIONS,
 )
+# make type checking work for llama-cpp-python without importing it directly at runtime
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from llama_cpp import Llama as LlamaType
+else:
+    LlamaType = Any
 _LOGGER = logging.getLogger(__name__)
 CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
 class LLaMAAgent(AbstractConversationAgent):
-    """Local LLaMA conversation agent."""
+    """Base LLaMA conversation agent."""
-    hass: Any
+    hass: HomeAssistant
     entry_id: str
     history: dict[str, list[dict]]
     in_context_examples: list[dict]
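The block added above is the standard typing.TYPE_CHECKING pattern: the heavy native llama_cpp package is imported only while a static type checker (mypy, pyright) analyzes the file, and the name falls back to Any at runtime. A minimal self-contained sketch of the pattern, runnable even when llama-cpp-python is not installed:

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # evaluated only by static type checkers, never at runtime
    from llama_cpp import Llama as LlamaType
else:
    # the name must still exist at runtime for any annotation that gets evaluated
    LlamaType = Any

def describe_model(llm: "LlamaType") -> str:
    # the string annotation defers evaluation, so this works either way
    return f"model handle: {llm!r}"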
@@ -122,6 +129,7 @@ class LLaMAAgent(AbstractConversationAgent):
         self._load_model(entry)
     def _load_icl_examples(self, filename: str):
+        """Load info used for generating in-context learning examples"""
         try:
             icl_filename = os.path.join(os.path.dirname(__file__), filename)
@@ -155,12 +163,15 @@ class LLaMAAgent(AbstractConversationAgent):
         return MATCH_ALL
     def _load_model(self, entry: ConfigEntry) -> None:
+        """Load the model on the backend. Implemented by sub-classes"""
         raise NotImplementedError()
     def _generate(self, conversation: dict) -> str:
+        """Call the backend to generate a response from the conversation. Implemented by sub-classes"""
         raise NotImplementedError()
     async def _async_generate(self, conversation: dict) -> str:
+        """Async wrapper for _generate()"""
         return await self.hass.async_add_executor_job(
             self._generate, conversation
         )
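async_add_executor_job hands the blocking _generate() call to Home Assistant's thread pool so the event loop is not stalled while the model runs. A sketch of the same idea with plain asyncio, where blocking_generate is a stand-in for the real model call:

import asyncio

def blocking_generate(conversation: list[dict]) -> str:
    # stand-in for a slow, CPU-bound model inference call
    return f"echo: {conversation[-1]['message']}"

async def async_generate(conversation: list[dict]) -> str:
    loop = asyncio.get_running_loop()
    # run the blocking call in the default thread pool executor
    return await loop.run_in_executor(None, blocking_generate, conversation)

print(asyncio.run(async_generate([{"role": "user", "message": "hi"}])))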
@@ -226,6 +237,7 @@ class LLaMAAgent(AbstractConversationAgent):
         conversation.append({"role": "user", "message": user_input.text})
+        # generate a response
         try:
             _LOGGER.debug(conversation)
             response = await self._async_generate(conversation)
@@ -250,6 +262,7 @@ class LLaMAAgent(AbstractConversationAgent):
                 conversation.pop(1)
         self.history[conversation_id] = conversation
+        # parse response
         exposed_entities = list(self._async_get_exposed_entities()[0].keys())
         to_say = service_call_pattern.sub("", response).strip()
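service_call_pattern (defined elsewhere in this file) strips the machine-readable service-call block out of the model's reply so only the natural-language part is spoken. A hedged sketch with a hypothetical pattern, since the component's actual regex is not shown in this diff:

import re

# hypothetical pattern; the component defines its own service_call_pattern
service_call_pattern = re.compile(r"```homeassistant\n(.*?)```", re.DOTALL)

response = 'Turning on the light.\n```homeassistant\n{"service": "light.turn_on"}\n```'
to_say = service_call_pattern.sub("", response).strip()
service_calls = service_call_pattern.findall(response)
print(to_say)          # Turning on the light.
print(service_calls)   # ['{"service": "light.turn_on"}\n']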
@@ -315,6 +328,7 @@ class LLaMAAgent(AbstractConversationAgent):
         if template_desc["assistant"]["suffix"]:
             to_say = to_say.replace(template_desc["assistant"]["suffix"], "")  # remove the eos token if it is returned (some backends and older models do this)
+        # generate intent response to Home Assistant
         intent_response = intent.IntentResponse(language=user_input.language)
         intent_response.async_set_speech(to_say)
         return ConversationResult(
@@ -347,6 +361,7 @@ class LLaMAAgent(AbstractConversationAgent):
     def _format_prompt(
         self, prompt: list[dict], include_generation_prompt: bool = True
     ) -> str:
+        """Format a conversation into a raw text completion using the model's prompt template"""
         formatted_prompt = ""
         prompt_template = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
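_format_prompt turns the structured conversation into one raw completion string by wrapping each message in the role prefixes and suffixes of the configured template. A minimal sketch with an illustrative ChatML-style descriptor (the real mapping lives in PROMPT_TEMPLATE_DESCRIPTIONS in const.py):

# illustrative ChatML-style template; not the component's actual table
template_desc = {
    "system":    {"prefix": "<|im_start|>system\n",    "suffix": "<|im_end|>\n"},
    "user":      {"prefix": "<|im_start|>user\n",      "suffix": "<|im_end|>\n"},
    "assistant": {"prefix": "<|im_start|>assistant\n", "suffix": "<|im_end|>\n"},
}

def format_prompt(prompt: list[dict], include_generation_prompt: bool = True) -> str:
    formatted = ""
    for message in prompt:
        role = template_desc[message["role"]]
        formatted += role["prefix"] + message["message"] + role["suffix"]
    if include_generation_prompt:
        # leave an open assistant turn so the model continues from here
        formatted += template_desc["assistant"]["prefix"]
    return formatted

print(format_prompt([{"role": "user", "message": "turn on the kitchen light"}]))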
@@ -372,7 +387,7 @@ class LLaMAAgent(AbstractConversationAgent):
         return formatted_prompt
     def _generate_system_prompt(self, prompt_template: str) -> str:
-        """Generate a prompt for the user."""
+        """Generate the system prompt with current entity states"""
         entities_to_expose, domains = self._async_get_exposed_entities()
         extra_attributes_to_expose = self.entry.options \
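The corrected docstring reflects what the method actually does: it renders the system prompt against the currently exposed entities, so every request carries fresh device state. A hedged sketch of that idea (the template syntax here is illustrative, not the component's):

# illustrative only; the component renders a configurable template instead
def generate_system_prompt(entities: dict[str, str], template: str) -> str:
    # one "entity_id = state" line per exposed entity
    states_block = "\n".join(f"{eid} = {state}" for eid, state in entities.items())
    return template.replace("{{ states }}", states_block)

print(generate_system_prompt(
    {"light.kitchen": "off", "switch.fan": "on"},
    "You control a smart home. Devices:\n{{ states }}",
))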
@@ -470,7 +485,7 @@ class LLaMAAgent(AbstractConversationAgent):
 class LocalLLaMAAgent(LLaMAAgent):
     model_path: str
-    llm: Any
+    llm: LlamaType
     grammar: Any
     llama_cpp_module: Any
     remove_prompt_caching_listener: Callable
@@ -604,7 +619,7 @@ class LocalLLaMAAgent(LLaMAAgent):
         self._set_prompt_caching(enabled=False)
     def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
-        """Takes the super class function results and sorts the entities by most recently updated at the end"""
+        """Takes the superclass results and sorts the entities so the most recently updated are at the end"""
         entities, domains = LLaMAAgent._async_get_exposed_entities(self)
         # ignore sorting if prompt caching is disabled
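Sorting the exposed entities so the most recently updated come last keeps the front of the prompt stable between requests, which is what lets llama.cpp's prompt cache reuse already-evaluated tokens. A small sketch of the ordering (timestamps are made up):

from datetime import datetime, timedelta

now = datetime.now()
# illustrative last-updated timestamps; the real code reads them from HA states
entities = {
    "light.kitchen":       now - timedelta(minutes=1),
    "sensor.outdoor_temp": now - timedelta(hours=3),
    "switch.fan":          now - timedelta(seconds=30),
}

# least recently updated first -> the prompt prefix rarely changes
ordered = sorted(entities, key=entities.get)
print(ordered)  # ['sensor.outdoor_temp', 'light.kitchen', 'switch.fan']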
@@ -939,7 +954,7 @@ class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
     def _load_model(self, entry: ConfigEntry):
         super()._load_model(entry)
-        with open(os.path.join(os.path.dirname(__file__), GBNF_GRAMMAR_FILE)) as f:
+        with open(os.path.join(os.path.dirname(__file__), DEFAULT_GBNF_GRAMMAR_FILE)) as f:
             self.grammar = "".join(f.readlines())
     def _chat_completion_params(self, conversation: dict) -> tuple[str, dict]:
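The constant rename points the API agent at the same default GBNF grammar file used elsewhere in the component. For reference, this is roughly how a GBNF grammar constrains output when calling llama-cpp-python directly (a hedged sketch; the file names and model path are placeholders):

from llama_cpp import Llama, LlamaGrammar

# placeholder paths; DEFAULT_GBNF_GRAMMAR_FILE would resolve similarly
with open("output.gbnf") as f:
    grammar = LlamaGrammar.from_string(f.read())

llm = Llama(model_path="model.gguf")
result = llm.create_completion(
    "Turn on the kitchen light.",
    grammar=grammar,  # decoding is restricted to strings the grammar accepts
    max_tokens=64,
)
print(result["choices"][0]["text"])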