Mirror of https://github.com/acon96/home-llm.git (synced 2026-01-08 21:28:05 -05:00)
manually set roles to train on
@@ -3,6 +3,7 @@ from __future__ import annotations
 
 import logging
 import os
+import shutil
 from typing import Final
 
 from homeassistant.config_entries import ConfigEntry, ConfigSubentry
@@ -96,9 +97,18 @@ async def _async_update_listener(hass: HomeAssistant, entry: LocalLLMConfigEntry
     await hass.config_entries.async_reload(entry.entry_id)
 
 
 async def async_unload_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:
-    """Unload Ollama."""
+    """Unload the integration."""
     if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
         return False
 
+    if entry.data[CONF_BACKEND_TYPE] == BACKEND_TYPE_LLAMA_CPP:
+        # clean up any disk cache resources
+        def cleanup_cache_dir():
+            cache_dir = entry.data[CONF_CHAT_MODEL].strip().replace(" ", "_").lower()
+            full_path = os.path.join(hass.config.media_dirs.get("local", hass.config.path("media")), "kv_cache", cache_dir)
+            shutil.rmtree(full_path, ignore_errors=True)
+
+        await hass.async_add_executor_job(cleanup_cache_dir)
+
     hass.data[DOMAIN].pop(entry.entry_id)
     return True
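The cleanup above derives the directory name from the configured model with the same .strip().replace(" ", "_").lower() slug rule used when the per-model cache directory is created (see the LlamaDiskCache hunk below), so unloading removes exactly the directory that loading created. A minimal sketch of the derivation, using a made-up model name purely for illustration:

# Illustrative only: the model name below is invented, not taken from the repo.
model_name = " Home 3B v3 "
cache_dir = model_name.strip().replace(" ", "_").lower()
assert cache_dir == "home_3b_v3"
# rmtree(..., ignore_errors=True) keeps the unload path a no-op when the
# directory was never created, e.g. when the disk cache is disabled.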
@@ -180,11 +180,13 @@ class LlamaCppClient(LocalLLMClient):
         _LOGGER.debug("Model loaded")
 
         # create disk cache if enabled
+        # cache must be per-model to avoid conflicts with different hidden state sizes
         cache_size = model_settings.get(CONF_LLAMACPP_CACHE_SIZE_MB, DEFAULT_LLAMACPP_CACHE_SIZE_MB)
+        cache_dir = model_name.strip().replace(" ", "_").lower()
         if cache_size > 0:
             llm.set_cache(LlamaDiskCache(
-                capacity_bytes=int(cache_size * (1024 ** 3)),
-                cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
+                capacity_bytes=int(cache_size * (1024 ** 2)), # MB to bytes
+                cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache", cache_dir)
             ))
 
         if model_settings[CONF_PROMPT_CACHING_ENABLED]:
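The capacity change above is a unit fix: CONF_LLAMACPP_CACHE_SIZE_MB is expressed in megabytes, so multiplying by 1024 ** 3 treated the setting as gibibytes. A quick check with an illustrative value:

cache_size_mb = 512                            # example setting, in MB
old_bytes = int(cache_size_mb * (1024 ** 3))   # 549_755_813_888 bytes, about 512 GiB
new_bytes = int(cache_size_mb * (1024 ** 2))   # 536_870_912 bytes, the intended 512 MiB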
@@ -406,7 +408,7 @@ class LlamaCppClient(LocalLLMClient):
                 max_tokens=1,
                 grammar=grammar,
                 stream=False,
-                stop=["<end_of_turn>", "<end_function_call>"]
+                # stop=["<end_of_turn>", "<end_function_call>"]
             )
 
             self.last_cache_prime = time.time()
@@ -480,7 +482,7 @@ class LlamaCppClient(LocalLLMClient):
             grammar=grammar,
             stream=True,
             response_format=response_format,
-            stop=["<end_of_turn>", "<end_function_call>"] # FIXME: make configurable (pull from tool end token?)
+            # stop=["<end_of_turn>", "<end_function_call>"] # FIXME: make configurable (pull from tool end token?)
         )
 
         def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:
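With the stop list commented out in both call sites, generation now ends on the model's own end-of-sequence token, the grammar, or max_tokens rather than on the hard-coded Gemma-style markers. If the FIXME is picked up later, one possible direction is to resolve the strings from the model settings; the option key below is invented for this sketch and does not exist in the integration:

# Hypothetical sketch only; "stop_sequences" is a made-up settings key.
DEFAULT_STOP = ["<end_of_turn>", "<end_function_call>"]

def resolve_stop_sequences(model_settings: dict) -> list[str]:
    return model_settings.get("stop_sequences", DEFAULT_STOP)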
@@ -614,18 +614,21 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
         conversation = [
             {
                 "role": "system",
-                "content": [{"type": "text", "text": sys_prompt}]
+                "content": [{"type": "text", "text": sys_prompt}],
+                "train_on_turn": False,
             },
             {
                 "role": "user",
-                "content": [{ "type": "text", "text": question }]
+                "content": [{ "type": "text", "text": question }],
+                "train_on_turn": False,
             }
         ]
     else:
         conversation = [
             {
                 "role": "user",
-                "content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
+                "content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }],
+                "train_on_turn": False,
             }
         ]
 
@@ -646,7 +649,7 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
             call_names.append(call_name)
             formatted_calls.append({
                 "name": call_name,
-                "arguments": json.dumps(tool_call["tool_args"])
+                "arguments": json.dumps(tool_call["tool_args"]),
             })
 
         if formatted_calls:
@@ -679,12 +682,14 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
         if tool_response_format == "text":
             conversation.append({
                 "role": "tool",
-                "content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results]
+                "content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results],
+                "train_on_turn": False,
             })
         elif tool_response_format == "functiongemma":
             conversation.append({
                 "role": "tool",
-                "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results]
+                "content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results],
+                "train_on_turn": False,
             })
 
     return {
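Every non-assistant turn the formatter emits now carries an explicit train_on_turn flag. A hand-written example of the resulting message structure (the texts are placeholders, and the assistant turn is shown unflagged because these hunks only mark the system, user, and tool roles):

conversation = [
    {"role": "system", "content": [{"type": "text", "text": "<persona / system prompt>"}], "train_on_turn": False},
    {"role": "user", "content": [{"type": "text", "text": "<user request>"}], "train_on_turn": False},
    {"role": "assistant", "content": [{"type": "text", "text": "<tool call and reply>"}]},
    {"role": "tool", "content": [{"type": "text", "text": "{\"result\": \"ok\"}"}], "train_on_turn": False},
]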
@@ -308,7 +308,7 @@ datasets:
   - path: /workspace/data/datasets/sample.jsonl
     ds_type: json
     type: chat_template
-    # roles_to_train: [ "assistant" ]
+    roles_to_train: []
     message_field_training: train_on_turn
 
 dataset_prepared_path: /workspace/data/datasets/prepared/
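This is the half of the change the commit title refers to: with roles_to_train emptied, the chat_template loader no longer trains on every assistant turn by role, and message_field_training: train_on_turn lets the per-message flag written by the formatter decide which turns contribute to the loss. Roughly, turns whose flag is false end up label-masked; a simplified sketch of that masking idea (not axolotl's actual implementation; -100 is the conventional ignore index for Hugging Face-style label tensors):

IGNORE_INDEX = -100  # tokens labelled -100 contribute no loss

def mask_labels(token_ids, turn_spans, train_flags):
    """Copy token ids into labels, masking every turn whose flag is False."""
    labels = [IGNORE_INDEX] * len(token_ids)
    for (start, end), train in zip(turn_spans, train_flags):
        if train:
            labels[start:end] = token_ids[start:end]
    return labels

# three turns: system (masked), user (masked), assistant (trained)
mask_labels(list(range(10)), [(0, 3), (3, 6), (6, 10)], [False, False, True])
# -> [-100, -100, -100, -100, -100, -100, 6, 7, 8, 9]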