mirror of https://github.com/acon96/home-llm.git
synced 2026-01-10 06:07:58 -05:00
split up even more + add llama-cpp-python server
@@ -176,7 +176,7 @@ class LLaMAAgent(AbstractConversationAgent):
             CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE
         )
 
-        self._load_model(entry)
+        self._load_model(entry)
 
     @property
     def entry(self):
@@ -516,6 +516,16 @@ class GenericOpenAIAPIAgent(LLaMAAgent):
 
         return endpoint, request_params
 
+    def _extract_response(self, response_json: dict) -> str:
+        choices = response_json["choices"]
+        if choices[0]["finish_reason"] != "stop":
+            _LOGGER.warn("Model response did not end on a stop token (unfinished sentence)")
+
+        if response_json["object"] == "chat.completion":
+            return choices[0]["message"]["content"]
+        else:
+            return choices[0]["text"]
+
     def _generate(self, conversation: dict) -> str:
         max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
         temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
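The new _extract_response helper keys off the response's "object" field to locate the generated text. A minimal sketch of the two OpenAI-style payload shapes it distinguishes (values are illustrative, not taken from a real server):

# Chat-completion shape: text lives under choices[0]["message"]["content"]
chat_response = {
    "object": "chat.completion",
    "choices": [{"finish_reason": "stop",
                 "message": {"content": "The kitchen light is now off."}}],
}

# Plain completion shape: text lives under choices[0]["text"]
text_response = {
    "object": "text_completion",
    "choices": [{"finish_reason": "stop",
                 "text": "The kitchen light is now off."}],
}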
@@ -536,7 +546,7 @@ class GenericOpenAIAPIAgent(LLaMAAgent):
         else:
             endpoint, additional_params = self._completion_params(conversation)
 
-        request_params.update(additional_params)
+        request_params.update(additional_params)
 
         headers = {}
         if self.api_key:
@@ -557,25 +567,16 @@ class GenericOpenAIAPIAgent(LLaMAAgent):
             _LOGGER.debug(f"Result was: {result.text}")
             return f"Failed to communicate with the API! {err}"
 
-        choices = result.json()["choices"]
-
         _LOGGER.debug(result.json())
 
-        if choices[0]["finish_reason"] != "stop":
-            _LOGGER.warn("Model response did not end on a stop token (unfinished sentence)")
-
-        # text-gen-webui has a typo where it is 'chat.completions' not 'chat.completion'
-        if result.json()["object"] in ["chat.completion", "chat.completions" ]:
-            return choices[0]["message"]["content"]
-        else:
-            return choices[0]["text"]
+        return self._extract_response(result.json())
 
 class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
     admin_key: str
 
     def _load_model(self, entry: ConfigEntry) -> None:
         super()._load_model(entry)
-        self.admin_key = entry.data.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, entry.data.get(CONF_OPENAI_API_KEY))
+        self.admin_key = entry.data.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, self.api_key)
 
         try:
             currently_loaded_result = requests.get(f"{self.api_host}/v1/internal/model/info")
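TextGenerationWebuiAgent now falls back to the regular API key when no separate admin key is configured. A hedged sketch of how an admin-authenticated call to the internal endpoint could look, assuming text-gen-webui expects a Bearer token; the host and key values below are placeholders:

import requests

api_host = "http://127.0.0.1:5000"   # assumption: local text-gen-webui instance
admin_key = "example-admin-key"      # assumption: placeholder credential

headers = {"Authorization": f"Bearer {admin_key}"} if admin_key else {}
result = requests.get(f"{api_host}/v1/internal/model/info", headers=headers)
print(result.json())                 # expected to describe the loaded model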
@@ -629,4 +630,47 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
         if preset:
             request_params["preset"] = preset
 
         return endpoint, request_params
+
+    def _extract_response(self, response_json: dict) -> str:
+        choices = response_json["choices"]
+        if choices[0]["finish_reason"] != "stop":
+            _LOGGER.warn("Model response did not end on a stop token (unfinished sentence)")
+
+        # text-gen-webui has a typo where it is 'chat.completions' not 'chat.completion'
+        if response_json["object"] == "chat.completions":
+            return choices[0]["message"]["content"]
+        else:
+            return choices[0]["text"]
+
+class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
+    """https://llama-cpp-python.readthedocs.io/en/latest/server/"""
+    grammar: str
+
+    def _load_model(self, entry: ConfigEntry):
+        super()._load_model(entry)
+
+        with open(os.path.join(os.path.dirname(__file__), GBNF_GRAMMAR_FILE)) as f:
+            self.grammar = "".join(f.readlines())
+
+    def _chat_completion_params(self, conversation: dict) -> (str, dict):
+        top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
+        endpoint, request_params = super()._chat_completion_params(conversation)
+
+        request_params["top_k"] = top_k
+
+        if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
+            request_params["grammar"] = self.grammar
+
+        return endpoint, request_params
+
+    def _completion_params(self, conversation: dict) -> (str, dict):
+        top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
+        endpoint, request_params = super()._completion_params(conversation)
+
+        request_params["top_k"] = top_k
+
+        if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
+            request_params["grammar"] = self.grammar
+
+        return endpoint, request_params
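For the new LlamaCppPythonAPIAgent, top_k and the GBNF grammar ride along in the JSON body as llama-cpp-python server extensions to the OpenAI schema. A minimal sketch of the resulting request, with placeholder host, model name, and grammar:

import requests

request_params = {
    "model": "example-model",            # assumption: placeholder model name
    "messages": [{"role": "user", "content": "turn off the kitchen lights"}],
    "max_tokens": 128,
    "top_k": 40,                         # extra sampling parameter forwarded by the agent
    "grammar": 'root ::= "on" | "off"',  # tiny GBNF grammar, for illustration only
}

result = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=request_params)
print(result.json()["choices"][0]["message"]["content"])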