From 1811a907f7c7ab3810088882fb982e42aea1b88f Mon Sep 17 00:00:00 2001
From: Alex O'Connell
Date: Sun, 21 Dec 2025 22:09:18 -0500
Subject: [PATCH] manually set roles to train on

---
 .../llama_conversation/__init__.py          | 12 +++++++++++-
 .../llama_conversation/backends/llamacpp.py | 10 ++++++----
 data/generate_data.py                       | 17 +++++++++++------
 train/configs/functiongemma-270m.yml        |  2 +-
 4 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/custom_components/llama_conversation/__init__.py b/custom_components/llama_conversation/__init__.py
index 7b9bc1d..ae7b7ff 100644
--- a/custom_components/llama_conversation/__init__.py
+++ b/custom_components/llama_conversation/__init__.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 
 import logging
 import os
+import shutil
 from typing import Final
 
 from homeassistant.config_entries import ConfigEntry, ConfigSubentry
@@ -96,9 +97,18 @@ async def _async_update_listener(hass: HomeAssistant, entry: LocalLLMConfigEntry
     await hass.config_entries.async_reload(entry.entry_id)
 
 async def async_unload_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:
-    """Unload Ollama."""
+    """Unload the integration."""
     if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
         return False
+
+    if entry.data[CONF_BACKEND_TYPE] == BACKEND_TYPE_LLAMA_CPP:
+        # clean up any disk cache resources
+        def cleanup_cache_dir():
+            cache_dir = entry.data[CONF_CHAT_MODEL].strip().replace(" ", "_").lower()
+            full_path = os.path.join(hass.config.media_dirs.get("local", hass.config.path("media")), "kv_cache", cache_dir)
+            shutil.rmtree(full_path, ignore_errors=True)
+
+        await hass.async_add_executor_job(cleanup_cache_dir)
 
     hass.data[DOMAIN].pop(entry.entry_id)
     return True
diff --git a/custom_components/llama_conversation/backends/llamacpp.py b/custom_components/llama_conversation/backends/llamacpp.py
index e0edbef..c9acdc5 100644
--- a/custom_components/llama_conversation/backends/llamacpp.py
+++ b/custom_components/llama_conversation/backends/llamacpp.py
@@ -180,11 +180,13 @@ class LlamaCppClient(LocalLLMClient):
         _LOGGER.debug("Model loaded")
 
         # create disk cache if enabled
+        # cache must be per-model to avoid conflicts with different hidden state sizes
         cache_size = model_settings.get(CONF_LLAMACPP_CACHE_SIZE_MB, DEFAULT_LLAMACPP_CACHE_SIZE_MB)
+        cache_dir = model_name.strip().replace(" ", "_").lower()
         if cache_size > 0:
             llm.set_cache(LlamaDiskCache(
-                capacity_bytes=int(cache_size * (1024 ** 3)),
-                cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
+                capacity_bytes=int(cache_size * (1024 ** 2)), # MB to bytes
+                cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache", cache_dir)
             ))
 
         if model_settings[CONF_PROMPT_CACHING_ENABLED]:
@@ -406,7 +408,7 @@ class LlamaCppClient(LocalLLMClient):
             max_tokens=1,
             grammar=grammar,
             stream=False,
-            stop=["", ""]
+            # stop=["", ""]
         )
 
         self.last_cache_prime = time.time()
@@ -480,7 +482,7 @@ class LlamaCppClient(LocalLLMClient):
             grammar=grammar,
             stream=True,
             response_format=response_format,
-            stop=["", ""] # FIXME: make configurable (pull from tool end token?)
+            # stop=["", ""] # FIXME: make configurable (pull from tool end token?)
         )
 
         def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:
diff --git a/data/generate_data.py b/data/generate_data.py
index c5c9169..eba4f0b 100644
--- a/data/generate_data.py
+++ b/data/generate_data.py
@@ -614,18 +614,21 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
         conversation = [
             {
                 "role": "system",
-                "content": [{"type": "text", "text": sys_prompt}]
+                "content": [{"type": "text", "text": sys_prompt}],
+                "train_on_turn": False,
             },
             {
                 "role": "user",
-                "content": [{ "type": "text", "text": question }]
+                "content": [{ "type": "text", "text": question }],
+                "train_on_turn": False,
             }
         ]
     else:
         conversation = [
             {
                 "role": "user",
-                "content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
+                "content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }],
+                "train_on_turn": False,
             }
         ]
 
@@ -646,7 +649,7 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
             call_names.append(call_name)
             formatted_calls.append({
                 "name": call_name,
-                "arguments": json.dumps(tool_call["tool_args"])
+                "arguments": json.dumps(tool_call["tool_args"]),
             })
 
         if formatted_calls:
@@ -679,12 +682,14 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
         if tool_response_format == "text":
             conversation.append({
                 "role": "tool",
-                "content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results]
+                "content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results],
+                "train_on_turn": False,
             })
         elif tool_response_format == "functiongemma":
             conversation.append({
                 "role": "tool",
-                "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results]
+                "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results],
+                "train_on_turn": False,
             })
 
     return {
diff --git a/train/configs/functiongemma-270m.yml b/train/configs/functiongemma-270m.yml
index adb11a5..7b0f48f 100644
--- a/train/configs/functiongemma-270m.yml
+++ b/train/configs/functiongemma-270m.yml
@@ -308,7 +308,7 @@ datasets:
   - path: /workspace/data/datasets/sample.jsonl
     ds_type: json
     type: chat_template
-    # roles_to_train: [ "assistant" ]
+    roles_to_train: []
     message_field_training: train_on_turn
 
 dataset_prepared_path: /workspace/data/datasets/prepared/
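
Note on the training-mask change (sketch, not part of the patch): with roles_to_train set to an empty list, the axolotl chat_template loader no longer trains on any role by default, so the per-message field named by message_field_training (train_on_turn) decides which turns contribute to the loss. The system, user, and tool turns emitted by format_example_sharegpt are now tagged train_on_turn: False; assistant turns are assumed to be flagged True elsewhere in the generator (not shown in this diff). A minimal Python sketch of the record shape this implies, where the top-level "conversations" key, the placeholder texts, and the assistant flag are assumptions:

    # sketch only: "role"/"content"/"train_on_turn" come from the patch,
    # the surrounding key names and the assistant True flag are assumed
    import json

    sample = {
        "conversations": [
            {"role": "system", "content": [{"type": "text", "text": "<sys_prompt>"}], "train_on_turn": False},
            {"role": "user", "content": [{"type": "text", "text": "<question>"}], "train_on_turn": False},
            {"role": "assistant", "content": [{"type": "text", "text": "<response>"}], "train_on_turn": True},
            {"role": "tool", "content": [{"type": "text", "text": "{\"result\": \"ok\"}"}], "train_on_turn": False},
        ]
    }
    print(json.dumps(sample))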
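
Note on the disk-cache cleanup (sketch, not part of the patch): the per-model cache directory name is derived the same way in two places, llamacpp.py when creating the LlamaDiskCache and __init__.py when deleting it on unload (model name stripped, spaces replaced with underscores, lowercased, under <media>/kv_cache/). A hedged sketch of a shared helper both call sites could use; the function name and placement are assumptions:

    import os

    def kv_cache_dir(hass, model_name: str) -> str:
        """Return the per-model llama.cpp disk cache directory (helper name assumed)."""
        sanitized = model_name.strip().replace(" ", "_").lower()
        media_root = hass.config.media_dirs.get("local", hass.config.path("media"))
        return os.path.join(media_root, "kv_cache", sanitized)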