mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-07 21:04:08 -05:00
manually set roles to train on
This commit is contained in:
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Final
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||
@@ -96,9 +97,18 @@ async def _async_update_listener(hass: HomeAssistant, entry: LocalLLMConfigEntry
|
||||
await hass.config_entries.async_reload(entry.entry_id)
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:
|
||||
"""Unload Ollama."""
|
||||
"""Unload the integration."""
|
||||
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
|
||||
return False
|
||||
|
||||
if entry.data[CONF_BACKEND_TYPE] == BACKEND_TYPE_LLAMA_CPP:
|
||||
# clean up any disk cache resources
|
||||
def cleanup_cache_dir():
|
||||
cache_dir = entry.data[CONF_CHAT_MODEL].strip().replace(" ", "_").lower()
|
||||
full_path = os.path.join(hass.config.media_dirs.get("local", hass.config.path("media")), "kv_cache", cache_dir)
|
||||
shutil.rmtree(full_path, ignore_errors=True)
|
||||
|
||||
await hass.async_add_executor_job(cleanup_cache_dir)
|
||||
hass.data[DOMAIN].pop(entry.entry_id)
|
||||
return True
|
||||
|
||||
|
||||
@@ -180,11 +180,13 @@ class LlamaCppClient(LocalLLMClient):
|
||||
_LOGGER.debug("Model loaded")
|
||||
|
||||
# create disk cache if enabled
|
||||
# cache must be per-model to avoid conflicts with different hidden state sizes
|
||||
cache_size = model_settings.get(CONF_LLAMACPP_CACHE_SIZE_MB, DEFAULT_LLAMACPP_CACHE_SIZE_MB)
|
||||
cache_dir = model_name.strip().replace(" ", "_").lower()
|
||||
if cache_size > 0:
|
||||
llm.set_cache(LlamaDiskCache(
|
||||
capacity_bytes=int(cache_size * (1024 ** 3)),
|
||||
cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
|
||||
capacity_bytes=int(cache_size * (1024 ** 2)), # MB to bytes
|
||||
cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache", cache_dir)
|
||||
))
|
||||
|
||||
if model_settings[CONF_PROMPT_CACHING_ENABLED]:
|
||||
@@ -406,7 +408,7 @@ class LlamaCppClient(LocalLLMClient):
|
||||
max_tokens=1,
|
||||
grammar=grammar,
|
||||
stream=False,
|
||||
stop=["<end_of_turn>", "<end_function_call>"]
|
||||
# stop=["<end_of_turn>", "<end_function_call>"]
|
||||
)
|
||||
|
||||
self.last_cache_prime = time.time()
|
||||
@@ -480,7 +482,7 @@ class LlamaCppClient(LocalLLMClient):
|
||||
grammar=grammar,
|
||||
stream=True,
|
||||
response_format=response_format,
|
||||
stop=["<end_of_turn>", "<end_function_call>"] # FIXME: make configurable (pull from tool end token?)
|
||||
# stop=["<end_of_turn>", "<end_function_call>"] # FIXME: make configurable (pull from tool end token?)
|
||||
)
|
||||
|
||||
def next_token() -> Generator[tuple[Optional[str], Optional[List]]]:
|
||||
|
||||
@@ -614,18 +614,21 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
|
||||
conversation = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": sys_prompt}]
|
||||
"content": [{"type": "text", "text": sys_prompt}],
|
||||
"train_on_turn": False,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{ "type": "text", "text": question }]
|
||||
"content": [{ "type": "text", "text": question }],
|
||||
"train_on_turn": False,
|
||||
}
|
||||
]
|
||||
else:
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }]
|
||||
"content": [{ "type": "text", "text": "\n".join([ sys_prompt, question ]) }],
|
||||
"train_on_turn": False,
|
||||
}
|
||||
]
|
||||
|
||||
@@ -646,7 +649,7 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
|
||||
call_names.append(call_name)
|
||||
formatted_calls.append({
|
||||
"name": call_name,
|
||||
"arguments": json.dumps(tool_call["tool_args"])
|
||||
"arguments": json.dumps(tool_call["tool_args"]),
|
||||
})
|
||||
|
||||
if formatted_calls:
|
||||
@@ -679,12 +682,14 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_
|
||||
if tool_response_format == "text":
|
||||
conversation.append({
|
||||
"role": "tool",
|
||||
"content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results]
|
||||
"content": [{ "type": "text", "text": json.dumps(result) } for result in step_tool_results],
|
||||
"train_on_turn": False,
|
||||
})
|
||||
elif tool_response_format == "functiongemma":
|
||||
conversation.append({
|
||||
"role": "tool",
|
||||
"content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results]
|
||||
"content": [{ "name": result["tool_name"], "response": {"result": result["tool_result"]} } for result in step_tool_results],
|
||||
"train_on_turn": False,
|
||||
})
|
||||
|
||||
return {
|
||||
|
||||
@@ -308,7 +308,7 @@ datasets:
|
||||
- path: /workspace/data/datasets/sample.jsonl
|
||||
ds_type: json
|
||||
type: chat_template
|
||||
# roles_to_train: [ "assistant" ]
|
||||
roles_to_train: []
|
||||
message_field_training: train_on_turn
|
||||
|
||||
dataset_prepared_path: /workspace/data/datasets/prepared/
|
||||
|
||||
Reference in New Issue
Block a user