split up even more + add llama-cpp-python server

Alex O'Connell
2024-01-21 16:25:50 -05:00
parent cc3dd4884a
commit 7c30bb57cf

@@ -176,7 +176,7 @@ class LLaMAAgent(AbstractConversationAgent):
CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE
)
        self._load_model(entry)
@property
def entry(self):
@@ -516,6 +516,16 @@ class GenericOpenAIAPIAgent(LLaMAAgent):
return endpoint, request_params
def _extract_response(self, response_json: dict) -> str:
choices = response_json["choices"]
        if choices[0]["finish_reason"] != "stop":
            _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
if response_json["object"] == "chat.completion":
return choices[0]["message"]["content"]
else:
return choices[0]["text"]
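
For reference, a minimal sketch of the two OpenAI-style payload shapes this helper distinguishes (values are illustrative, not captured from a real server):

# Illustrative payloads only; shapes follow the OpenAI completions spec.
chat_style = {
    "object": "chat.completion",
    "choices": [{"finish_reason": "stop", "message": {"content": "The light is on."}}],
}
text_style = {
    "object": "text_completion",
    "choices": [{"finish_reason": "stop", "text": "The light is on."}],
}
# _extract_response(chat_style) reads choices[0]["message"]["content"];
# _extract_response(text_style) falls through to choices[0]["text"].
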
def _generate(self, conversation: dict) -> str:
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
@@ -536,7 +546,7 @@ class GenericOpenAIAPIAgent(LLaMAAgent):
else:
endpoint, additional_params = self._completion_params(conversation)
        request_params.update(additional_params)
headers = {}
if self.api_key:
@@ -557,25 +567,16 @@ class GenericOpenAIAPIAgent(LLaMAAgent):
_LOGGER.debug(f"Result was: {result.text}")
return f"Failed to communicate with the API! {err}"
-        choices = result.json()["choices"]
-        _LOGGER.debug(result.json())
-        if choices[0]["finish_reason"] != "stop":
-            _LOGGER.warn("Model response did not end on a stop token (unfinished sentence)")
-        # text-gen-webui has a typo where it is 'chat.completions' not 'chat.completion'
-        if result.json()["object"] in ["chat.completion", "chat.completions"]:
-            return choices[0]["message"]["content"]
-        else:
-            return choices[0]["text"]
+        return self._extract_response(result.json())
class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
admin_key: str
def _load_model(self, entry: ConfigEntry) -> None:
super()._load_model(entry)
-        self.admin_key = entry.data.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, entry.data.get(CONF_OPENAI_API_KEY))
+        self.admin_key = entry.data.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, self.api_key)
try:
currently_loaded_result = requests.get(f"{self.api_host}/v1/internal/model/info")
@@ -629,4 +630,47 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
if preset:
request_params["preset"] = preset
return endpoint, request_params
def _extract_response(self, response_json: dict) -> str:
choices = response_json["choices"]
        if choices[0]["finish_reason"] != "stop":
            _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
# text-gen-webui has a typo where it is 'chat.completions' not 'chat.completion'
if response_json["object"] == "chat.completions":
return choices[0]["message"]["content"]
else:
return choices[0]["text"]
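
The override matters because, against text-gen-webui, the base class comparison with "chat.completion" never matches a chat payload, so the base _extract_response would fall into the choices[0]["text"] branch and raise a KeyError. A hypothetical payload showing the quirk:

# Hypothetical text-gen-webui chat response (note the trailing "s" in the object tag)
webui_payload = {
    "object": "chat.completions",
    "choices": [{"finish_reason": "stop", "message": {"content": "Done."}}],
}
# GenericOpenAIAPIAgent._extract_response(webui_payload)    -> KeyError: 'text'
# TextGenerationWebuiAgent._extract_response(webui_payload) -> "Done."
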
class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
"""https://llama-cpp-python.readthedocs.io/en/latest/server/"""
grammar: str
def _load_model(self, entry: ConfigEntry):
super()._load_model(entry)
with open(os.path.join(os.path.dirname(__file__), GBNF_GRAMMAR_FILE)) as f:
self.grammar = "".join(f.readlines())
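
GBNF is llama.cpp's grammar format for constraining sampling to a fixed syntax. The integration ships its own grammar in GBNF_GRAMMAR_FILE; purely to give a sense of the format, a trivially small grammar held in a Python string might look like:

# Illustrative grammar only -- the real one is loaded from GBNF_GRAMMAR_FILE above
example_grammar = 'root ::= "yes" | "no"'
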
def _chat_completion_params(self, conversation: dict) -> (str, dict):
top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
endpoint, request_params = super()._chat_completion_params(conversation)
request_params["top_k"] = top_k
if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
request_params["grammar"] = self.grammar
return endpoint, request_params
def _completion_params(self, conversation: dict) -> (str, dict):
top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
endpoint, request_params = super()._completion_params(conversation)
request_params["top_k"] = top_k
if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
request_params["grammar"] = self.grammar
return endpoint, request_params
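
Putting it together: with the grammar option enabled, the chat request this agent builds carries both the sampler setting and the grammar as extra body fields, which is what the llama-cpp-python server's OpenAI-compatible endpoint is expected to accept here. A rough sketch of the resulting call (host, model name, and message are illustrative, not taken from the integration):

import requests

# Hypothetical request body; the integration assembles this via _chat_completion_params()
body = {
    "model": "llama-2-7b-chat.Q4_K_M.gguf",  # illustrative model name
    "messages": [{"role": "user", "content": "Turn off the kitchen light."}],
    "max_tokens": 128,
    "temperature": 0.1,
    "top_k": 40,
    "grammar": 'root ::= "yes" | "no"',  # normally the contents of GBNF_GRAMMAR_FILE
}
# llama-cpp-python's server listens on port 8000 by default
requests.post("http://localhost:8000/v1/chat/completions", json=body, timeout=30)
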