manually set roles to train on

Alex O'Connell
2025-12-21 22:09:18 -05:00
parent cf01fd29ae
commit 1811a907f7
4 changed files with 29 additions and 12 deletions

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
 import logging
 import os
+import shutil
 
 from typing import Final
 
 from homeassistant.config_entries import ConfigEntry, ConfigSubentry
@@ -96,9 +97,18 @@ async def _async_update_listener(hass: HomeAssistant, entry: LocalLLMConfigEntry
     await hass.config_entries.async_reload(entry.entry_id)
 
 
 async def async_unload_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:
-    """Unload Ollama."""
+    """Unload the integration."""
     if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
         return False
+
+    if entry.data[CONF_BACKEND_TYPE] == BACKEND_TYPE_LLAMA_CPP:
+        # clean up any disk cache resources
+        def cleanup_cache_dir():
+            cache_dir = entry.data[CONF_CHAT_MODEL].strip().replace(" ", "_").lower()
+            full_path = os.path.join(hass.config.media_dirs.get("local", hass.config.path("media")), "kv_cache", cache_dir)
+            shutil.rmtree(full_path, ignore_errors=True)
+        await hass.async_add_executor_job(cleanup_cache_dir)
+
     hass.data[DOMAIN].pop(entry.entry_id)
     return True
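
For reference, a minimal sketch of the directory this unload hook ends up removing, assuming a hypothetical model name and media path (neither value comes from this repo):

import os

model_name = "Acme Home 3B"                                 # hypothetical entry.data[CONF_CHAT_MODEL]
media_root = "/config/media"                                # hypothetical media_dirs fallback
cache_dir = model_name.strip().replace(" ", "_").lower()    # -> "acme_home_3b"
full_path = os.path.join(media_root, "kv_cache", cache_dir)
print(full_path)                                            # -> /config/media/kv_cache/acme_home_3b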

View File

@@ -180,11 +180,13 @@ class LlamaCppClient(LocalLLMClient):
         _LOGGER.debug("Model loaded")
 
         # create disk cache if enabled
+        # cache must be per-model to avoid conflicts with different hidden state sizes
         cache_size = model_settings.get(CONF_LLAMACPP_CACHE_SIZE_MB, DEFAULT_LLAMACPP_CACHE_SIZE_MB)
+        cache_dir = model_name.strip().replace(" ", "_").lower()
         if cache_size > 0:
             llm.set_cache(LlamaDiskCache(
-                capacity_bytes=int(cache_size * (1024 ** 3)),
-                cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
+                capacity_bytes=int(cache_size * (1024 ** 2)), # MB to bytes
+                cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache", cache_dir)
             ))
 
         if model_settings[CONF_PROMPT_CACHING_ENABLED]:
@@ -406,7 +408,7 @@ class LlamaCppClient(LocalLLMClient):
             max_tokens=1,
             grammar=grammar,
             stream=False,
-            stop=["<end_of_turn>", "<end_function_call>"]
+            # stop=["<end_of_turn>", "<end_function_call>"]
         )
 
         self.last_cache_prime = time.time()
@@ -480,7 +482,7 @@ class LlamaCppClient(LocalLLMClient):
             grammar=grammar,
             stream=True,
             response_format=response_format,
-            stop=["<end_of_turn>", "<end_function_call>"] # FIXME: make configurable (pull from tool end token?)
+            # stop=["<end_of_turn>", "<end_function_call>"] # FIXME: make configurable (pull from tool end token?)
         )
 
         def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:
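
A quick sanity check of the unit fix above, using a hypothetical 512 MB cache setting (the actual DEFAULT_LLAMACPP_CACHE_SIZE_MB is not shown in this diff):

cache_size_mb = 512                               # hypothetical config value, in MB
old_bytes = int(cache_size_mb * (1024 ** 3))      # 549_755_813_888 bytes (~512 GiB): treated the MB value as GiB
new_bytes = int(cache_size_mb * (1024 ** 2))      # 536_870_912 bytes (512 MiB): correct MB-to-bytes conversion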

View File

@@ -614,18 +614,21 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
         conversation = [
             {
                 "role": "system",
-                "content": [{"type": "text", "text": sys_prompt}]
+                "content": [{"type": "text", "text": sys_prompt}],
+                "train_on_turn": False,
             },
             {
                 "role": "user",
-                "content": [{ "type": "text", "text": question }]
+                "content": [{ "type": "text", "text": question }],
+                "train_on_turn": False,
             }
         ]
     else:
         conversation = [
             {
                 "role": "user",
-                "content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
+                "content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }],
+                "train_on_turn": False,
             }
         ]
@@ -646,7 +649,7 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
             call_names.append(call_name)
             formatted_calls.append({
                 "name": call_name,
-                "arguments": json.dumps(tool_call["tool_args"])
+                "arguments": json.dumps(tool_call["tool_args"]),
             })
 
         if formatted_calls:
@@ -679,12 +682,14 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
         if tool_response_format == "text":
             conversation.append({
                 "role": "tool",
-                "content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results]
+                "content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results],
+                "train_on_turn": False,
             })
         elif tool_response_format == "functiongemma":
             conversation.append({
                 "role": "tool",
-                "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results]
+                "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results],
+                "train_on_turn": False,
             })
 
     return {
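
For illustration, a minimal sketch of one formatted conversation after this change; the user and tool flags follow the diff above, while the assistant turns' flags are an assumption (their handling is not shown in these hunks):

conversation = [
    {"role": "user", "content": [{"type": "text", "text": "turn on the kitchen light"}], "train_on_turn": False},
    {"role": "assistant", "content": [{"type": "text", "text": "<tool call>"}], "train_on_turn": True},   # assumed trainable
    {"role": "tool", "content": [{"type": "text", "text": "{\"result\": \"ok\"}"}], "train_on_turn": False},
    {"role": "assistant", "content": [{"type": "text", "text": "Done, the light is on."}], "train_on_turn": True},  # assumed trainable
]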

View File

@@ -308,7 +308,7 @@ datasets:
   - path: /workspace/data/datasets/sample.jsonl
     ds_type: json
     type: chat_template
-    # roles_to_train: [ "assistant" ]
+    roles_to_train: []
     message_field_training: train_on_turn
 
 dataset_prepared_path: /workspace/data/datasets/prepared/
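
With roles_to_train emptied, no turn is selected for training by role alone; the chat_template strategy is instead expected to consult the per-message boolean named by message_field_training (this reading of the interplay is inferred from the option names, not shown in the diff). A sketch of the assumed selection rule:

def turn_is_trained(msg: dict) -> bool:
    # roles_to_train: [] means role membership never enables training on its own;
    # the per-message flag (message_field_training: train_on_turn) decides instead.
    # Assumed semantics: a turn with no flag is treated as not trainable.
    return bool(msg.get("train_on_turn", False))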