mirror of https://github.com/acon96/home-llm.git
mostly working gemma implementation

@@ -67,6 +67,7 @@ if TYPE_CHECKING:
    from llama_cpp import (
        Llama as LlamaType,
        LlamaGrammar as LlamaGrammarType,
        LlamaDiskCache as LlamaDiskCacheType,
        ChatCompletionRequestResponseFormat
    )
else:

@@ -156,6 +157,7 @@ class LlamaCppClient(LocalLLMClient):
        self.llama_cpp_module = importlib.import_module("llama_cpp")

        Llama: type[LlamaType] = getattr(self.llama_cpp_module, "Llama")
        LlamaDiskCache: type[LlamaDiskCacheType] = getattr(self.llama_cpp_module, "LlamaDiskCache")

        _LOGGER.debug(f"Loading model '{model_path}'...")
        model_settings = snapshot_settings(entity_options)

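Note: the pattern above keeps llama_cpp out of the integration's hard imports. The class names are only imported under TYPE_CHECKING for annotations, and the real classes are resolved lazily with importlib/getattr when the client loads a model. A minimal standalone sketch of the same pattern (the load_llama_class helper is illustrative, not part of the codebase):

    from typing import TYPE_CHECKING
    import importlib

    if TYPE_CHECKING:
        # only seen by type checkers, never executed at runtime
        from llama_cpp import Llama as LlamaType

    def load_llama_class():
        # resolve the class at runtime so startup works even if
        # llama_cpp has not been installed yet
        llama_cpp_module = importlib.import_module("llama_cpp")
        Llama: type[LlamaType] = getattr(llama_cpp_module, "Llama")
        return Llama
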
@@ -170,11 +172,11 @@ class LlamaCppClient(LocalLLMClient):
        )
        _LOGGER.debug("Model loaded")

        # TODO: check about disk caching
        # self.llm.set_cache(self.llama_cpp_module.LlamaDiskCache(
        # capacity_bytes=(512 * 10e8),
        # cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
        # ))
        # FIXME: make cache size configurable (0 means disabled)
        llm.set_cache(LlamaDiskCache(
            capacity_bytes=int(512 * 10e8),
            cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
        ))

        if model_settings[CONF_PROMPT_CACHING_ENABLED]:
            @callback

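Note on the disk cache: LlamaDiskCache persists llama.cpp's KV cache to disk so previously evaluated prompts do not have to be re-processed. The capacity is currently hard-coded to int(512 * 10e8) bytes (about 512 GB); per the FIXME it should eventually come from configuration, with 0 meaning disabled. A hedged sketch of what that could look like (CONF_DISK_CACHE_SIZE and its default are hypothetical, not part of this commit):

    # hypothetical option; 0 disables disk caching entirely
    cache_bytes = int(model_settings.get(CONF_DISK_CACHE_SIZE, 0))
    if cache_bytes > 0:
        llm.set_cache(LlamaDiskCache(
            capacity_bytes=cache_bytes,
            cache_dir=os.path.join(
                self.hass.config.media_dirs.get("local", self.hass.config.path("media")),
                "kv_cache"
            )
        ))
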
@@ -393,6 +395,7 @@ class LlamaCppClient(LocalLLMClient):
            max_tokens=1,
            grammar=grammar,
            stream=False,
            stop=["<end_of_turn>", "<end_function_call>"]
        )

        self.last_cache_prime = time.time()

@@ -464,6 +467,7 @@ class LlamaCppClient(LocalLLMClient):
            grammar=grammar,
            stream=True,
            response_format=response_format,
            stop=["<end_of_turn>", "<end_function_call>"] # FIXME: make configurable (pull from tool end token?)
        )

        def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:

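Both completion calls above hard-code gemma's stop strings: "<end_of_turn>" ends a model turn and "<end_function_call>" terminates a tool call. The FIXME suggests deriving them from the prompt format instead. One possible shape, assuming the client already knows its tool-call end marker (the helper below is hypothetical, not part of this commit):

    def build_stop_tokens(tool_suffix: str | None) -> list[str]:
        # hypothetical helper: end-of-turn token plus the tool end token,
        # if the prompt format defines one
        stops = ["<end_of_turn>"]
        if tool_suffix:
            stops.append(tool_suffix)
        return stops
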
@@ -261,12 +261,15 @@ class LocalLLMClient:
                tool_content += content

            if think_prefix in potential_block and not in_thinking:
                _LOGGER.debug("Entering thinking block")
                in_thinking = True
                last_5_tokens.clear()
            elif think_suffix in potential_block and in_thinking:
                _LOGGER.debug("Exiting thinking block")
                in_thinking = False
                content = content.replace(think_suffix, "").strip()
            elif tool_prefix in potential_block and not in_tool_call:
                _LOGGER.debug("Entering tool call block")
                in_tool_call = True
                last_5_tokens.clear()
            elif tool_suffix in potential_block and in_tool_call:

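This block is the streaming parser's state tracking: potential_block is the recently buffered text, and the marker pairs toggle whether incoming content is hidden reasoning (think_prefix/think_suffix) or an accumulating tool call (tool_prefix/tool_suffix). A tiny standalone illustration of the same state machine, using placeholder marker strings rather than the real prompt-format tokens (thinking blocks work the same way and are omitted here):

    in_tool_call = False
    spoken, tool_content = [], ""
    tool_prefix, tool_suffix = "<function_call>", "<end_function_call>"

    for chunk in ["Hi! ", "<function_call>", '{"name": "HassTurnOn"}', "<end_function_call>", "Done."]:
        if tool_prefix in chunk and not in_tool_call:
            in_tool_call = True          # start buffering a tool call
        elif tool_suffix in chunk and in_tool_call:
            in_tool_call = False         # tool call finished
        elif in_tool_call:
            tool_content += chunk        # buffered, not spoken to the user
        else:
            spoken.append(chunk)         # normal assistant text

    print("".join(spoken))   # Hi! Done.
    print(tool_content)      # {"name": "HassTurnOn"}
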
@@ -307,6 +310,20 @@ class LocalLLMClient:
            if not in_thinking and not in_tool_call and (cur_match_length == 0 or result.tool_calls):
                yield result

        if in_tool_call and tool_content:
            # flush any unclosed tool calls because using the tool_suffix as a stop token can
            # cause the tool_suffix to be omitted when the model streams output
            tool_block = tool_content.strip().removeprefix(tool_prefix)
            _LOGGER.debug("Raw tool block extracted at end: %s", tool_block)
            tool_call, to_say = parse_raw_tool_call(tool_block, agent_id)
            if tool_call:
                _LOGGER.debug("Tool call parsed at end: %s", tool_call)
                yield TextGenerationResult(
                    response=to_say,
                    response_streamed=True,
                    tool_calls=[tool_call]
                )

    async def _async_parse_completion(
        self,
        llm_api: llm.APIInstance | None,

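The flush at the end of the stream exists because tool_suffix doubles as a stop token: llama.cpp stops generation before the suffix is emitted, so tool_content can end without its closing marker even though the call itself is complete. Stripping the opening marker and handing the rest to parse_raw_tool_call recovers it. For example, with the same placeholder markers as above:

    tool_prefix = "<function_call>"
    # the stop token cut generation before "<end_function_call>" appeared
    tool_content = '<function_call>{"name": "HassTurnOn", "arguments": {"name": "kitchen light"}}'

    tool_block = tool_content.strip().removeprefix(tool_prefix)
    # tool_block is now plain JSON that parse_raw_tool_call can load
    print(tool_block)   # {"name": "HassTurnOn", "arguments": {"name": "kitchen light"}}
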
@@ -356,7 +356,12 @@ def get_oai_formatted_messages(conversation: Sequence[conversation.Content], use
        elif message.role == "tool_result":
            messages.append({
                "role": "tool",
                "content": json.dumps(message.tool_result),
                # FIXME: what is the correct format for content here? gemma expects name and result
                # "content": json.dumps(message.tool_result),
                "content": {
                    "name": message.tool_name,
                    "response": { "result": message.tool_result },
                },
                "tool_call_id": message.tool_call_id
            })

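With this change a tool_result is passed as a structured dict rather than a JSON string, because the gemma chat template expects the function name alongside its response (the uncommented json.dumps line appears to be the replaced line from the old side of the diff). For a hypothetical HassTurnOn result, the appended message would look roughly like:

    {
        "role": "tool",
        "content": {
            "name": "HassTurnOn",
            "response": {"result": {"success": True}},   # illustrative result payload
        },
        "tool_call_id": "call_0",                        # illustrative id
    }
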
@@ -404,7 +409,26 @@ def parse_raw_tool_call(raw_block: str | dict, agent_id: str) -> tuple[llm.ToolI
    if isinstance(raw_block, dict):
        parsed_tool_call = raw_block
    else:
        parsed_tool_call: dict = json.loads(raw_block)
        try:
            parsed_tool_call: dict = json.loads(raw_block)
        except json.JSONDecodeError:
            # handle the "gemma" tool calling format
            # call:HassTurnOn{name:<escape>light.living_room_rgbww<escape>}
            gemma_match = re.finditer(r"call:(?P<name>\w+){(?P<args>.+)}", raw_block)
            for match in gemma_match:
                tool_name = match.group("name")
                raw_args = match.group("args")
                args_dict = {}
                for arg_match in re.finditer(r"(?P<key>\w+):<escape>(?P<value>.+?)<escape>", raw_args):
                    args_dict[arg_match.group("key")] = arg_match.group("value")

                parsed_tool_call = {
                    "name": tool_name,
                    "arguments": args_dict
                }
                break # TODO: how do we properly handle multiple tool calls in one response?
            else:
                raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted JSON")

    # try to validate either format
    is_services_tool_call = False

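The new except branch is a fallback for gemma checkpoints that emit tool calls in a non-JSON call:Name{key:<escape>value<escape>} syntax instead of JSON (the unguarded json.loads line above the try appears to be the replaced line from the old side of the diff). A small standalone check of the two regexes, using the sample string from the code comment:

    import re

    raw_block = "call:HassTurnOn{name:<escape>light.living_room_rgbww<escape>}"

    for match in re.finditer(r"call:(?P<name>\w+){(?P<args>.+)}", raw_block):
        args_dict = {
            arg.group("key"): arg.group("value")
            for arg in re.finditer(r"(?P<key>\w+):<escape>(?P<value>.+?)<escape>", match.group("args"))
        }
        print(match.group("name"), args_dict)
        # prints: HassTurnOn {'name': 'light.living_room_rgbww'}
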