manually set roles to train on

This commit is contained in:
Alex O'Connell
2025-12-21 22:09:18 -05:00
parent cf01fd29ae
commit 1811a907f7
4 changed files with 29 additions and 12 deletions

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import logging
import os
import shutil
from typing import Final
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
@@ -96,9 +97,18 @@ async def _async_update_listener(hass: HomeAssistant, entry: LocalLLMConfigEntry
await hass.config_entries.async_reload(entry.entry_id)
async def async_unload_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:
"""Unload Ollama."""
"""Unload the integration."""
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
return False
if entry.data[CONF_BACKEND_TYPE] == BACKEND_TYPE_LLAMA_CPP:
# clean up any disk cache resources
def cleanup_cache_dir():
cache_dir = entry.data[CONF_CHAT_MODEL].strip().replace(" ", "_").lower()
full_path = os.path.join(hass.config.media_dirs.get("local", hass.config.path("media")), "kv_cache", cache_dir)
shutil.rmtree(full_path, ignore_errors=True)
await hass.async_add_executor_job(cleanup_cache_dir)
hass.data[DOMAIN].pop(entry.entry_id)
return True

View File

@@ -180,11 +180,13 @@ class LlamaCppClient(LocalLLMClient):
_LOGGER.debug("Model loaded")
# create disk cache if enabled
# cache must be per-model to avoid conflicts with different hidden state sizes
cache_size = model_settings.get(CONF_LLAMACPP_CACHE_SIZE_MB, DEFAULT_LLAMACPP_CACHE_SIZE_MB)
cache_dir = model_name.strip().replace(" ", "_").lower()
if cache_size > 0:
llm.set_cache(LlamaDiskCache(
capacity_bytes=int(cache_size * (1024 ** 3)),
cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
capacity_bytes=int(cache_size * (1024 ** 2)), # MB to bytes
cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache", cache_dir)
))
if model_settings[CONF_PROMPT_CACHING_ENABLED]:
@@ -406,7 +408,7 @@ class LlamaCppClient(LocalLLMClient):
max_tokens=1,
grammar=grammar,
stream=False,
stop=["<end_of_turn>", "<end_function_call>"]
# stop=["<end_of_turn>", "<end_function_call>"]
)
self.last_cache_prime = time.time()
@@ -480,7 +482,7 @@ class LlamaCppClient(LocalLLMClient):
grammar=grammar,
stream=True,
response_format=response_format,
stop=["<end_of_turn>", "<end_function_call>"] # FIXME: make configurable (pull from tool end token?)
# stop=["<end_of_turn>", "<end_function_call>"] # FIXME: make configurable (pull from tool end token?)
)
def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:

View File

@@ -614,18 +614,21 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
conversation = [
{
"role": "system",
"content": [{"type": "text", "text": sys_prompt}]
"content": [{"type": "text", "text": sys_prompt}],
"train_on_turn": False,
},
{
"role": "user",
"content": [{ "type": "text", "text": question }]
"content": [{ "type": "text", "text": question }],
"train_on_turn": False,
}
]
else:
conversation = [
{
"role": "user",
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }],
"train_on_turn": False,
}
]
@@ -646,7 +649,7 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
call_names.append(call_name)
formatted_calls.append({
"name": call_name,
"arguments": json.dumps(tool_call["tool_args"])
"arguments": json.dumps(tool_call["tool_args"]),
})
if formatted_calls:
@@ -679,12 +682,14 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
if tool_response_format == "text":
conversation.append({
"role": "tool",
"content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results]
"content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results],
"train_on_turn": False,
})
elif tool_response_format == "functiongemma":
conversation.append({
"role": "tool",
"content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results]
"content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results],
"train_on_turn": False,
})
return {

View File

@@ -308,7 +308,7 @@ datasets:
- path: /workspace/data/datasets/sample.jsonl
ds_type: json
type: chat_template
# roles_to_train: [ "assistant" ]
roles_to_train: []
message_field_training: train_on_turn
dataset_prepared_path: /workspace/data/datasets/prepared/