mostly working gemma implementation

Alex O'Connell
2025-12-20 20:29:09 -05:00
parent 672a9de65c
commit 29d839eea8
8 changed files with 694 additions and 38 deletions

View File

@@ -67,6 +67,7 @@ if TYPE_CHECKING:
     from llama_cpp import (
         Llama as LlamaType,
         LlamaGrammar as LlamaGrammarType,
+        LlamaDiskCache as LlamaDiskCacheType,
         ChatCompletionRequestResponseFormat
     )
 else:
@@ -156,6 +157,7 @@ class LlamaCppClient(LocalLLMClient):
         self.llama_cpp_module = importlib.import_module("llama_cpp")
         Llama: type[LlamaType] = getattr(self.llama_cpp_module, "Llama")
+        LlamaDiskCache: type[LlamaDiskCacheType] = getattr(self.llama_cpp_module, "LlamaDiskCache")
 
         _LOGGER.debug(f"Loading model '{model_path}'...")
         model_settings = snapshot_settings(entity_options)
@@ -170,11 +172,11 @@ class LlamaCppClient(LocalLLMClient):
         )
         _LOGGER.debug("Model loaded")
 
-        # TODO: check about disk caching
-        # self.llm.set_cache(self.llama_cpp_module.LlamaDiskCache(
-        #     capacity_bytes=(512 * 10e8),
-        #     cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
-        # ))
+        # FIXME: make cache size configurable (0 means disabled)
+        llm.set_cache(LlamaDiskCache(
+            capacity_bytes=int(512 * 10e8),
+            cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
+        ))
 
         if model_settings[CONF_PROMPT_CACHING_ENABLED]:
             @callback
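The FIXME above asks for a configurable cache size where 0 disables caching. A minimal sketch of what that could look like, assuming a hypothetical CONF_DISK_CACHE_SIZE key in model_settings (the cache_dir/capacity_bytes parameters match llama-cpp-python's actual LlamaDiskCache signature, and set_cache accepts None):

```python
# Sketch only: CONF_DISK_CACHE_SIZE is a hypothetical option, not part of this commit.
cache_size = int(model_settings.get(CONF_DISK_CACHE_SIZE, 0))
if cache_size > 0:
    llm.set_cache(LlamaDiskCache(
        capacity_bytes=cache_size,
        cache_dir=os.path.join(
            self.hass.config.media_dirs.get("local", self.hass.config.path("media")),
            "kv_cache",
        ),
    ))
else:
    llm.set_cache(None)  # 0 means disabled, per the FIXME
```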
@@ -393,6 +395,7 @@ class LlamaCppClient(LocalLLMClient):
                 max_tokens=1,
                 grammar=grammar,
                 stream=False,
+                stop=["<end_of_turn>", "<end_function_call>"]
             )
 
             self.last_cache_prime = time.time()
@@ -464,6 +467,7 @@ class LlamaCppClient(LocalLLMClient):
             grammar=grammar,
             stream=True,
             response_format=response_format,
+            stop=["<end_of_turn>", "<end_function_call>"]  # FIXME: make configurable (pull from tool end token?)
         )
 
         def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:
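One possible shape for that second FIXME, deriving the stop list instead of hardcoding Gemma's markers; this assumes the prompt format exposes its end-of-tool-call marker as tool_suffix, the name the streaming parser in the next file uses:

```python
# Sketch: build the stop list from the active prompt format rather than
# hardcoding. "<end_of_turn>" is Gemma's turn terminator; tool_suffix is
# whatever end-of-tool-call marker the format defines (may be None).
stop_tokens = ["<end_of_turn>"]
if tool_suffix:
    stop_tokens.append(tool_suffix)
```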

View File

@@ -261,12 +261,15 @@ class LocalLLMClient:
                 tool_content += content
 
             if think_prefix in potential_block and not in_thinking:
+                _LOGGER.debug("Entering thinking block")
                 in_thinking = True
                 last_5_tokens.clear()
             elif think_suffix in potential_block and in_thinking:
+                _LOGGER.debug("Exiting thinking block")
                 in_thinking = False
                 content = content.replace(think_suffix, "").strip()
             elif tool_prefix in potential_block and not in_tool_call:
+                _LOGGER.debug("Entering tool call block")
                 in_tool_call = True
                 last_5_tokens.clear()
             elif tool_suffix in potential_block and in_tool_call:
@@ -307,6 +310,20 @@ class LocalLLMClient:
             if not in_thinking and not in_tool_call and (cur_match_length == 0 or result.tool_calls):
                 yield result
 
+        if in_tool_call and tool_content:
+            # flush any unclosed tool calls because using the tool_suffix as a stop token can
+            # cause the tool_suffix to be omitted when the model streams output
+            tool_block = tool_content.strip().removeprefix(tool_prefix)
+            _LOGGER.debug("Raw tool block extracted at end: %s", tool_block)
+            tool_call, to_say = parse_raw_tool_call(tool_block, agent_id)
+            if tool_call:
+                _LOGGER.debug("Tool call parsed at end: %s", tool_call)
+                yield TextGenerationResult(
+                    response=to_say,
+                    response_streamed=True,
+                    tool_calls=[tool_call]
+                )
+
     async def _async_parse_completion(
         self,
         llm_api: llm.APIInstance | None,
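To see why the flush branch is needed, in isolation: when tool_suffix doubles as a stop token, the backend halts generation before emitting it, so the suffix-matching branch above never fires and the accumulated tool_content still carries the prefix at end of stream. A standalone illustration, where the marker strings are assumed to be the pair matching the stop list in the previous file:

```python
tool_prefix, tool_suffix = "<function_call>", "<end_function_call>"

# What the model actually streamed: the stop token swallowed the closing suffix.
tool_content = '<function_call>{"name": "HassTurnOn", "arguments": {"name": "light.kitchen"}}'

assert tool_suffix not in tool_content       # so the normal exit branch never ran
tool_block = tool_content.strip().removeprefix(tool_prefix)
print(tool_block)  # bare JSON, ready for parse_raw_tool_call()
```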

View File

@@ -356,7 +356,12 @@ def get_oai_formatted_messages(conversation: Sequence[conversation.Content], use
         elif message.role == "tool_result":
             messages.append({
                 "role": "tool",
-                "content": json.dumps(message.tool_result),
+                # FIXME: what is the correct format for content here? gemma expects name and result
+                # "content": json.dumps(message.tool_result),
+                "content": {
+                    "name": message.tool_name,
+                    "response": { "result": message.tool_result },
+                },
                 "tool_call_id": message.tool_call_id
             })
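For reference, the tool message this branch now emits, with invented example values. The {name, response} nesting mirrors the shape Google's function-calling examples use; whether a given chat template accepts a dict here is exactly what the FIXME questions, since the OpenAI chat format types content as a string:

```python
# Invented example values, for illustration only.
{
    "role": "tool",
    "content": {
        "name": "HassTurnOn",
        "response": {"result": {"success": True}},
    },
    "tool_call_id": "call_0",
}
```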
@@ -404,7 +409,26 @@ def parse_raw_tool_call(raw_block: str | dict, agent_id: str) -> tuple[llm.ToolI
     if isinstance(raw_block, dict):
         parsed_tool_call = raw_block
     else:
-        parsed_tool_call: dict = json.loads(raw_block)
+        try:
+            parsed_tool_call: dict = json.loads(raw_block)
+        except json.JSONDecodeError:
+            # handle the "gemma" tool calling format
+            # call:HassTurnOn{name:<escape>light.living_room_rgbww<escape>}
+            gemma_match = re.finditer(r"call:(?P<name>\w+){(?P<args>.+)}", raw_block)
+            for match in gemma_match:
+                tool_name = match.group("name")
+                raw_args = match.group("args")
+
+                args_dict = {}
+                for arg_match in re.finditer(r"(?P<key>\w+):<escape>(?P<value>.+?)<escape>", raw_args):
+                    args_dict[arg_match.group("key")] = arg_match.group("value")
+                parsed_tool_call = {
+                    "name": tool_name,
+                    "arguments": args_dict
+                }
+                break  # TODO: how do we properly handle multiple tool calls in one response?
+            else:
+                raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted JSON")
 
     # try to validate either format
     is_services_tool_call = False
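The fallback regex path can be exercised standalone with the exact patterns and the example string from the comment above:

```python
import re

raw_block = "call:HassTurnOn{name:<escape>light.living_room_rgbww<escape>}"

for match in re.finditer(r"call:(?P<name>\w+){(?P<args>.+)}", raw_block):
    # collect key/value pairs delimited by <escape> markers
    args = {
        m.group("key"): m.group("value")
        for m in re.finditer(r"(?P<key>\w+):<escape>(?P<value>.+?)<escape>", match.group("args"))
    }
    print({"name": match.group("name"), "arguments": args})
    # -> {'name': 'HassTurnOn', 'arguments': {'name': 'light.living_room_rgbww'}}
```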