mirror of https://github.com/acon96/home-llm.git
synced 2026-01-09 13:48:05 -05:00

16  .github/workflows/create-release.yml  (vendored)
@@ -20,33 +20,33 @@ jobs:
      matrix:
        include:
          # ARM variants
-         - home_assistant_version: "2024.12.3"
+         - home_assistant_version: "2025.4.1"
            arch: "aarch64"
-         - home_assistant_version: "2024.12.3"
+         - home_assistant_version: "2025.4.1"
            arch: "armhf"

          # Base x86
-         - home_assistant_version: "2024.12.3"
+         - home_assistant_version: "2025.4.1"
            suffix: "-noavx"
            arch: "amd64"
            extra_defines: "-DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF -DGGML_F16C=OFF"
-         - home_assistant_version: "2024.12.3"
+         - home_assistant_version: "2025.4.1"
            arch: "i386"
            suffix: "-noavx"
            extra_defines: "-DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF -DGGML_F16C=OFF"

          # AVX2 and AVX512
-         - home_assistant_version: "2024.12.3"
+         - home_assistant_version: "2025.4.1"
            arch: "amd64"
            extra_defines: "-DGGML_AVX=ON -DGGML_AVX2=ON -DGGML_FMA=ON -DGGML_F16C=ON"
-         - home_assistant_version: "2024.12.3"
+         - home_assistant_version: "2025.4.1"
            arch: "amd64"
            suffix: "-avx512"
            extra_defines: "-DGGML_AVX512=ON -DGGML_FMA=ON -DGGML_F16C=ON"
-         - home_assistant_version: "2024.12.3"
+         - home_assistant_version: "2025.4.1"
            arch: "i386"
            extra_defines: "-DGGML_AVX=ON -DGGML_AVX2=ON -DGGML_FMA=ON -DGGML_F16C=ON"
-         - home_assistant_version: "2024.12.3"
+         - home_assistant_version: "2025.4.1"
            arch: "i386"
            suffix: "-avx512"
            extra_defines: "-DGGML_AVX512=ON -DGGML_FMA=ON -DGGML_F16C=ON"
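The matrix above builds one llama-cpp-python wheel per supported architecture (aarch64, armhf, amd64, i386), with "-noavx" and "-avx512" variants selected through the GGML compile flags. As a rough sketch (not the repository's actual code), a runtime can map `platform.machine()` onto those architecture names before choosing a wheel; the helper and its alias table below are assumptions for illustration only:

```python
import platform
import sys

# Assumed mapping from platform.machine() values onto the arch names used by
# the build matrix above (aarch64 / armhf / amd64 / i386).
ARCH_ALIASES = {"arm64": "aarch64", "x86_64": "amd64", "i386": "amd64"}

def wheel_arch() -> str:
    machine = platform.machine()
    arch = ARCH_ALIASES.get(machine, machine)
    # CPython tag of the running interpreter, e.g. "cp312"
    python_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    return f"{python_tag}-{arch}"

print(wheel_arch())  # e.g. "cp312-amd64"
```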
@@ -5,7 +5,7 @@ This project provides the required "glue" components to control your Home Assist

Please see the [Setup Guide](./docs/Setup.md) for more information on installation.

## Local LLM Conversation Integration
-**The latest version of this integration requires Home Assistant 2024.12.3 or newer**
+**The latest version of this integration requires Home Assistant 2025.4.1 or newer**

In order to integrate with Home Assistant, we provide a custom component that exposes the locally running LLM as a "conversation agent".

@@ -150,6 +150,7 @@ In order to facilitate running the project entirely on the system where Home Ass
## Version History
| Version | Description |
|---------|-------------|
+| v0.3.8  | Update llama.cpp, remove think blocks from "thinking" models, fix wheel detection for some Intel CPUs, fixes for compatibility with the latest Home Assistant version (2025.4), other small bug fixes |
| v0.3.7  | Update llama.cpp version to support newer models, update minimum Home Assistant version to 2024.12.3, add German In-Context Learning examples, fix multi-turn use, fix an issue with webcolors |
| v0.3.6  | Small llama.cpp backend fixes |
| v0.3.5  | Fix for llama.cpp backend installation, fix for Home LLM v1-3 API parameters, add Polish ICL examples |
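The v0.3.8 entry mentions removing "think" blocks emitted by reasoning models before the reply is used. A minimal sketch of that idea, assuming the `<think>`/`</think>` markers that appear later in the prompt-template table (illustrative, not the integration's exact code):

```python
import re

# Drop everything up to and including the closing reasoning tag, then trim.
def strip_think_block(response: str, suffix: str = "</think>") -> str:
    return re.sub(rf"^.*?{re.escape(suffix)}", "", response, flags=re.DOTALL).strip()

print(strip_think_block("<think>check the fan entity first</think>Turning on the fan."))
# -> "Turning on the fan."
```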
@@ -8,19 +8,31 @@ import homeassistant.components.conversation as ha_conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, Platform
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, llm
from homeassistant.util.json import JsonObjectType

import voluptuous as vol

from .const import (
    ALLOWED_SERVICE_CALL_ARGUMENTS,
    DOMAIN,
    HOME_LLM_API_ID,
    SERVICE_TOOL_NAME,
    SERVICE_TOOL_ALLOWED_SERVICES,
    SERVICE_TOOL_ALLOWED_DOMAINS,
    CONF_BACKEND_TYPE,
    DEFAULT_BACKEND_TYPE,
    BACKEND_TYPE_LLAMA_HF,
    BACKEND_TYPE_LLAMA_EXISTING,
    BACKEND_TYPE_TEXT_GEN_WEBUI,
    BACKEND_TYPE_GENERIC_OPENAI,
    BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
    BACKEND_TYPE_OLLAMA,
)
from .conversation import LlamaCppAgent, GenericOpenAIAPIAgent, TextGenerationWebuiAgent, \
    LlamaCppPythonAPIAgent, OllamaAPIAgent, LocalLLMAgent

type LocalLLMConfigEntry = ConfigEntry[LocalLLMAgent]

_LOGGER = logging.getLogger(__name__)

@@ -29,7 +41,7 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (Platform.CONVERSATION,)


-async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
+async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:

    # make sure the API is registered
    if not any([x.id == HOME_LLM_API_ID for x in llm.async_get_apis(hass)]):
@@ -37,18 +49,43 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:

    hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry

    def create_agent(backend_type):
        agent_cls = None

        if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
            agent_cls = LlamaCppAgent
        elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
            agent_cls = GenericOpenAIAPIAgent
        elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
            agent_cls = TextGenerationWebuiAgent
        elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
            agent_cls = LlamaCppPythonAPIAgent
        elif backend_type == BACKEND_TYPE_OLLAMA:
            agent_cls = OllamaAPIAgent

        return agent_cls(hass, entry)

    # create the agent in an executor job because the constructor calls `open()`
    backend_type = entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
    entry.runtime_data = await hass.async_add_executor_job(create_agent, backend_type)

    # call load model
    await entry.runtime_data._async_load_model(entry)

    # forward setup to platform to register the entity
    await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

    return True


-async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
+async def async_unload_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:
    """Unload Ollama."""
    if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
        return False
    hass.data[DOMAIN].pop(entry.entry_id)
    return True

-async def async_migrate_entry(hass, config_entry: ConfigEntry):
+async def async_migrate_entry(hass: HomeAssistant, config_entry: LocalLLMConfigEntry):
    """Migrate old entry."""
    _LOGGER.debug("Migrating from version %s", config_entry.version)
@@ -82,13 +119,8 @@ class HassServiceTool(llm.Tool):
        vol.Optional('item'): str,
    })

-   ALLOWED_SERVICES: Final[list[str]] = [
-       "turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover",
-       "lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item"
-   ]
-   ALLOWED_DOMAINS: Final[list[str]] = [
-       "light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script",
-   ]
+   ALLOWED_SERVICES: Final[list[str]] = SERVICE_TOOL_ALLOWED_SERVICES
+   ALLOWED_DOMAINS: Final[list[str]] = SERVICE_TOOL_ALLOWED_DOMAINS

    async def async_call(
        self, hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext
@@ -687,6 +687,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
        for key in OPTIONS_OVERRIDES.keys():
            if key in model_name:
                selected_default_options.update(OPTIONS_OVERRIDES[key])
                break

        persona = PERSONA_PROMPTS.get(self.selected_language, PERSONA_PROMPTS.get("en"))
        current_date = CURRENT_DATE_PROMPT.get(self.selected_language, CURRENT_DATE_PROMPT.get("en"))
@@ -765,15 +766,15 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
        config_entry: config_entries.ConfigEntry,
    ) -> config_entries.OptionsFlow:
        """Create the options flow."""
-       return OptionsFlow(config_entry)
+       return OptionsFlow()


class OptionsFlow(config_entries.OptionsFlow):
    """Local LLM config flow options handler."""

-   def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
-       """Initialize options flow."""
-       self.config_entry = config_entry
+   @property
+   def config_entry(self):
+       return self.hass.config_entries.async_get_entry(self.handler)

    async def async_step_init(
        self, user_input: dict[str, Any] | None = None
@@ -4,6 +4,8 @@ import types, os
DOMAIN = "llama_conversation"
HOME_LLM_API_ID = "home-llm-service-api"
SERVICE_TOOL_NAME = "HassCallService"
+SERVICE_TOOL_ALLOWED_SERVICES = ["turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover", "lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item", "set_temperature", "set_humidity", "set_fan_mode", "set_hvac_mode", "set_preset_mode"]
+SERVICE_TOOL_ALLOWED_DOMAINS = ["light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script"]
CONF_PROMPT = "prompt"
PERSONA_PROMPTS = {
    "en": "You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.",
@@ -13,7 +15,7 @@ PERSONA_PROMPTS = {
    "pl": "Jeste\u015b 'Al', pomocnym asystentem AI, kt\u00f3ry kontroluje urz\u0105dzenia w domu. Wykonaj poni\u017csze zadanie zgodnie z instrukcj\u0105 lub odpowiedz na poni\u017csze pytanie, korzystaj\u0105c wy\u0142\u0105cznie z podanych informacji."
}
CURRENT_DATE_PROMPT = {
-   "en": """The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }}""",
+   "en": """The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", True, "")) }}""",
    "de": """{% set day_name = ["Montag", "Dienstag", "Mittwoch", "Donnerstag", "Freitag", "Samstag", "Sonntag"] %}{% set month_name = ["Januar", "Februar", "März", "April", "Mai", "Juni", "Juli", "August", "September", "Oktober", "November", "Dezember"] %}Die aktuelle Uhrzeit und das aktuelle Datum sind {{ (as_timestamp(now()) | timestamp_custom("%H:%M", local=True)) }} {{ day_name[now().weekday()] }}, {{ now().day }} {{ month_name[now().month -1]}} {{ now().year }}.""",
    "fr": """{% set day_name = ["lundi", "mardi", "mercredi", "jeudi", "vendredi", "samedi", "dimanche"] %}{% set month_name = ["janvier", "février", "mars", "avril", "mai", "juin", "juillet", "août", "septembre", "octobre", "novembre", "décembre"] %} L'heure et la date actuelles sont {{ (as_timestamp(now()) | timestamp_custom("%H:%M", local=True)) }} {{ day_name[now().weekday()] }}, {{ now().day }} {{ month_name[now().month -1]}} {{ now().year }}.""",
    "es": """{% set day_name = ["lunes", "martes", "miércoles", "jueves", "viernes", "sábado", "domingo"] %}{% set month_name = ["enero", "febrero", "marzo", "abril", "mayo", "junio", "julio", "agosto", "septiembre", "octubre", "noviembre", "diciembre"] %}La hora y fecha actuales son {{ (as_timestamp(now()) | timestamp_custom("%H:%M", local=True)) }} {{ day_name[now().weekday()] }}, {{ now().day }} de {{ month_name[now().month -1]}} de {{ now().year }}.""",
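The change to the English `CURRENT_DATE_PROMPT` matters because of the filter's argument order: Home Assistant's `timestamp_custom` takes the format string, then `local`, then a default value, so the old template passed the empty string where `local` belongs. The stand-in below only mimics that argument order and is an assumption for illustration, not the real filter:

```python
from datetime import datetime, timezone

# Assumed signature: timestamp_custom(value, format_string, local=True, default=...)
def timestamp_custom(value, format_string, local=True, default=""):
    try:
        stamp = float(value)
    except (TypeError, ValueError):
        return default
    moment = datetime.fromtimestamp(stamp) if local else datetime.fromtimestamp(stamp, tz=timezone.utc)
    return moment.strftime(format_string)

# New template: True fills `local`, "" stays the fallback default.
print(timestamp_custom(datetime.now().timestamp(), "%I:%M %p on %A %B %d, %Y", True, ""))
```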
@@ -146,59 +148,69 @@ PROMPT_TEMPLATE_DESCRIPTIONS = {
        "user": { "prefix": "<|im_start|>user\n", "suffix": "<|im_end|>" },
        "assistant": { "prefix": "<|im_start|>assistant\n", "suffix": "<|im_end|>" },
        "tool": { "prefix": "<|im_start|>tool", "suffix": "<|im_end|>" },
+       "chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
        "generation_prompt": "<|im_start|>assistant"
    },
    PROMPT_TEMPLATE_COMMAND_R: {
        "system": { "prefix": "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>", "suffix": "<|END_OF_TURN_TOKEN|>" },
        "user": { "prefix": "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "suffix": "<|END_OF_TURN_TOKEN|>" },
        "assistant": { "prefix": "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", "suffix": "<|END_OF_TURN_TOKEN|>" },
+       "chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
        "generation_prompt": "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
    },
    PROMPT_TEMPLATE_ALPACA: {
        "system": { "prefix": "", "suffix": "\n" },
        "user": { "prefix": "### Instruction:\n", "suffix": "\n" },
        "assistant": { "prefix": "### Response:\n", "suffix": "\n" },
+       "chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
        "generation_prompt": "### Response:"
    },
    PROMPT_TEMPLATE_VICUNA: {
        "system": { "prefix": "", "suffix": "\n" },
        "user": { "prefix": "USER: ", "suffix": "" },
        "assistant": { "prefix": "ASSISTANT: ", "suffix": "</s>" },
+       "chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
        "generation_prompt": "ASSISTANT:"
    },
    PROMPT_TEMPLATE_NONE: {
        "system": { "prefix": "", "suffix": "" },
        "user": { "prefix": "", "suffix": "" },
        "assistant": { "prefix": "", "suffix": "" },
+       "chain_of_thought": { "prefix": "", "suffix": ""},
        "generation_prompt": ""
    },
    PROMPT_TEMPLATE_MISTRAL: {
        "user": { "prefix": "<s>[INST] ", "suffix": " [/INST] " },
        "assistant": { "prefix": "", "suffix": "</s>" },
+       "chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
        "generation_prompt": ""
    },
    PROMPT_TEMPLATE_ZEPHYR: {
        "system": { "prefix": "<|system|>\n", "suffix": "<|endoftext|>" },
        "user": { "prefix": "<|user|>\n", "suffix": "<|endoftext|>" },
        "assistant": { "prefix": "<|assistant|>\n", "suffix": "<|endoftext|>" },
+       "chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
        "generation_prompt": "<|assistant|>\n"
    },
    PROMPT_TEMPLATE_ZEPHYR2: {
        "system": { "prefix": "<|system|>\n", "suffix": "</s>" },
        "user": { "prefix": "<|user|>\n", "suffix": "</s>" },
        "assistant": { "prefix": "<|assistant|>\n", "suffix": "</s>" },
+       "chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
        "generation_prompt": "<|assistant|>\n"
    },
    PROMPT_TEMPLATE_ZEPHYR3: {
        "system": { "prefix": "<|system|>\n", "suffix": "<|end|>" },
        "user": { "prefix": "<|user|>\n", "suffix": "<|end|>" },
        "assistant": { "prefix": "<|assistant|>\n", "suffix": "<|end|>" },
+       "chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
        "generation_prompt": "<|assistant|>\n"
    },
    PROMPT_TEMPLATE_LLAMA3: {
        "system": { "prefix": "<|start_header_id|>system<|end_header_id|>\n\n", "suffix": "<|eot_id|>"},
        "user": { "prefix": "<|start_header_id|>user<|end_header_id|>\n\n", "suffix": "<|eot_id|>"},
        "assistant": { "prefix": "<|start_header_id|>assistant<|end_header_id|>\n\n", "suffix": "<|eot_id|>"},
+       "chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
        "generation_prompt": "<|start_header_id|>assistant<|end_header_id|>\n\n"
    }
}
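Each entry above maps a chat role to the prefix/suffix tokens of one prompt format, plus the `generation_prompt` that cues the model to answer and the new `chain_of_thought` markers used when stripping reasoning. A minimal sketch (not the integration's actual formatter) of how such a table renders a message history:

```python
# Illustrative role -> prefix/suffix table in the ChatML style shown above.
CHATML = {
    "system": {"prefix": "<|im_start|>system\n", "suffix": "<|im_end|>"},
    "user": {"prefix": "<|im_start|>user\n", "suffix": "<|im_end|>"},
    "assistant": {"prefix": "<|im_start|>assistant\n", "suffix": "<|im_end|>"},
    "generation_prompt": "<|im_start|>assistant",
}

def render_prompt(messages: list[dict[str, str]], template: dict) -> str:
    parts = []
    for msg in messages:
        role = template[msg["role"]]
        parts.append(f"{role['prefix']}{msg['message']}{role['suffix']}")
    # end with the generation prompt so the model continues as the assistant
    parts.append(template["generation_prompt"])
    return "\n".join(parts)

print(render_prompt([{"role": "user", "message": "turn on the kitchen light"}], CHATML))
```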
@@ -297,6 +309,14 @@ DEFAULT_OPTIONS = types.MappingProxyType(
)

OPTIONS_OVERRIDES = {
+   "home-llama-3.2": {
+       CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
+       CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_LLAMA3,
+       CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
+       CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX,
+       CONF_TOOL_FORMAT: TOOL_FORMAT_MINIMAL,
+       CONF_CONTEXT_LENGTH: 131072,
+   },
    "home-3b-v3": {
        CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
        CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR,
@@ -383,5 +403,5 @@ OPTIONS_OVERRIDES = {
    },
}

-INTEGRATION_VERSION = "0.3.7"
+INTEGRATION_VERSION = "0.3.8"
EMBEDDED_LLAMA_CPP_PYTHON_VERSION = "0.3.5"
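For reference, the `OPTIONS_OVERRIDES` keys are matched as substrings of the configured model name (see the `config_flow.py` loop earlier in this diff). A small illustrative sketch with placeholder option values, not the real option constants:

```python
# Hypothetical, trimmed-down override table; only the keys are real.
OPTIONS_OVERRIDES = {
    "home-llama-3.2": {"context_length": 131072},
    "home-3b-v3": {"prompt_template": "zephyr"},
}

def options_for(model_name: str) -> dict:
    selected: dict = {}
    for key, overrides in OPTIONS_OVERRIDES.items():
        if key in model_name.lower():
            selected.update(overrides)
            break
    return selected

print(options_for("acon96/Home-Llama-3.2-3B"))  # -> {'context_length': 131072}
```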
@@ -16,7 +16,7 @@ import voluptuous as vol
from typing import Literal, Any, Callable

from homeassistant.components.conversation import ConversationInput, ConversationResult, AbstractConversationAgent, ConversationEntity
-from homeassistant.components import assist_pipeline, conversation as ha_conversation
+from homeassistant.components import assist_pipeline, conversation as conversation
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
@@ -24,7 +24,7 @@ from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL,
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError, HomeAssistantError
from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm, \
-   area_registry as ar, device_registry as dr
+   area_registry as ar, device_registry as dr, chat_session
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.event import async_track_state_change, async_call_later
@@ -121,14 +121,10 @@ from .const import (
    TOOL_FORMAT_REDUCED,
    TOOL_FORMAT_MINIMAL,
    ALLOWED_SERVICE_CALL_ARGUMENTS,
    SERVICE_TOOL_ALLOWED_SERVICES,
    SERVICE_TOOL_ALLOWED_DOMAINS,
    CONF_BACKEND_TYPE,
    DEFAULT_BACKEND_TYPE,
    BACKEND_TYPE_LLAMA_HF,
    BACKEND_TYPE_LLAMA_EXISTING,
    BACKEND_TYPE_TEXT_GEN_WEBUI,
    BACKEND_TYPE_GENERIC_OPENAI,
    BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
    BACKEND_TYPE_OLLAMA,
)

# make type checking work for llama-cpp-python without importing it directly at runtime
@@ -147,7 +143,7 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
    hass.data[DOMAIN][entry.entry_id] = entry

    # call update handler
-   agent: LocalLLMAgent = ha_conversation.get_agent_manager(hass).async_get_agent(entry.entry_id)
+   agent: LocalLLMAgent = entry.runtime_data
    await hass.async_add_executor_job(agent._update_options)

    return True
@@ -155,42 +151,50 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback) -> bool:
    """Set up Local LLM Conversation from a config entry."""

    def create_agent(backend_type):
        agent_cls = None

        if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
            agent_cls = LlamaCppAgent
        elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
            agent_cls = GenericOpenAIAPIAgent
        elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
            agent_cls = TextGenerationWebuiAgent
        elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
            agent_cls = LlamaCppPythonAPIAgent
        elif backend_type == BACKEND_TYPE_OLLAMA:
            agent_cls = OllamaAPIAgent

        return agent_cls(hass, entry)

    # create the agent in an executor job because the constructor calls `open()`
    backend_type = entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
    agent = await hass.async_add_executor_job(create_agent, backend_type)

    # call load model
    await agent._async_load_model(entry)

    # handle updates to the options
    entry.async_on_unload(entry.add_update_listener(update_listener))

-   async_add_entities([agent])
+   # register the agent entity
+   async_add_entities([entry.runtime_data])

    return True

def _convert_content(
    chat_content: conversation.Content
) -> dict[str, str]:
    """Create tool response content."""
    role_name = None
    if isinstance(chat_content, conversation.ToolResultContent):
        role_name = "tool"
    elif isinstance(chat_content, conversation.AssistantContent):
        role_name = "assistant"
    elif isinstance(chat_content, conversation.UserContent):
        role_name = "user"
    elif isinstance(chat_content, conversation.SystemContent):
        role_name = "system"
    else:
        raise ValueError(f"Unexpected content type: {type(chat_content)}")

    return { "role": role_name, "message": chat_content.content }

def _convert_content_back(
    agent_id: str,
    message_history_entry: dict[str, str]
) -> conversation.Content:
    if message_history_entry["role"] == "tool":
        return conversation.ToolResultContent(content=message_history_entry["message"])
    if message_history_entry["role"] == "assistant":
        return conversation.AssistantContent(agent_id=agent_id, content=message_history_entry["message"])
    if message_history_entry["role"] == "user":
        return conversation.UserContent(content=message_history_entry["message"])
    if message_history_entry["role"] == "system":
        return conversation.SystemContent(content=message_history_entry["message"])

class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
    """Base Local LLM conversation agent."""

    hass: HomeAssistant
    entry_id: str
    history: dict[str, list[dict]]
    in_context_examples: list[dict]

    _attr_has_entity_name = True
@@ -202,7 +206,6 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):

        self.hass = hass
        self.entry_id = entry.entry_id
-       self.history = {}

        self.backend_type = entry.data.get(
            CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE
@@ -210,7 +213,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):

        if self.entry.options.get(CONF_LLM_HASS_API):
            self._attr_supported_features = (
-               ha_conversation.ConversationEntityFeature.CONTROL
+               conversation.ConversationEntityFeature.CONTROL
            )

        self.in_context_examples = None
@@ -223,11 +226,11 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
        assist_pipeline.async_migrate_engine(
            self.hass, "conversation", self.entry.entry_id, self.entity_id
        )
-       ha_conversation.async_set_agent(self.hass, self.entry, self)
+       conversation.async_set_agent(self.hass, self.entry, self)

    async def async_will_remove_from_hass(self) -> None:
        """When entity will be removed from Home Assistant."""
-       ha_conversation.async_unset_agent(self.hass, self.entry)
+       conversation.async_unset_agent(self.hass, self.entry)
        await super().async_will_remove_from_hass()

    def _load_icl_examples(self, filename: str):
@@ -253,7 +256,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
    def _update_options(self):
        if self.entry.options.get(CONF_LLM_HASS_API):
            self._attr_supported_features = (
-               ha_conversation.ConversationEntityFeature.CONTROL
+               conversation.ConversationEntityFeature.CONTROL
            )

        if self.entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
@@ -304,6 +307,19 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
        self, user_input: ConversationInput
    ) -> ConversationResult:
        """Process a sentence."""
        with (
            chat_session.async_get_chat_session(
                self.hass, user_input.conversation_id
            ) as session,
            conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
        ):
            return await self._async_handle_message(user_input, chat_log)

    async def _async_handle_message(
        self,
        user_input: conversation.ConversationInput,
        chat_log: conversation.ChatLog,
    ) -> conversation.ConversationResult:

        raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
        prompt_template = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
@@ -323,8 +339,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                intent.IntentResponseErrorCode.UNKNOWN,
                f"Sorry, there was a problem compiling the service call regex: {err}",
            )

            return ConversationResult(
-               response=intent_response, conversation_id=conversation_id
+               response=intent_response, conversation_id=user_input.conversation_id
            )

        llm_api: llm.APIInstance | None = None
@@ -338,7 +355,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                    context=user_input.context,
                    user_prompt=user_input.text,
                    language=user_input.language,
-                   assistant=ha_conversation.DOMAIN,
+                   assistant=conversation.DOMAIN,
                    device_id=user_input.device_id,
                )
            )
@@ -353,14 +370,10 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                    response=intent_response, conversation_id=user_input.conversation_id
                )

-       if user_input.conversation_id in self.history:
-           conversation_id = user_input.conversation_id
-           conversation = self.history[conversation_id] if remember_conversation else [self.history[conversation_id][0]]
-       else:
-           conversation_id = ulid.ulid()
-           conversation = []
+       message_history = [ _convert_content(content) for content in chat_log.content ]

-       if len(conversation) == 0 or refresh_system_prompt:
        # re-generate prompt if necessary
+       if len(message_history) == 0 or refresh_system_prompt:
            try:
                message = self._generate_system_prompt(raw_prompt, llm_api)
            except TemplateError as err:
@@ -371,24 +384,20 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                    f"Sorry, I had a problem with my template: {err}",
                )
                return ConversationResult(
-                   response=intent_response, conversation_id=conversation_id
+                   response=intent_response, conversation_id=user_input.conversation_id
                )

            system_prompt = { "role": "system", "message": message }

-           if len(conversation) == 0:
-               conversation.append(system_prompt)
-               if not remember_conversation:
-                   self.history[conversation_id] = conversation
+           if len(message_history) == 0:
+               message_history.append(system_prompt)
            else:
-               conversation[0] = system_prompt
-
-       conversation.append({"role": "user", "message": user_input.text})
+               message_history[0] = system_prompt

        # generate a response
        try:
-           _LOGGER.debug(conversation)
-           response = await self._async_generate(conversation)
+           _LOGGER.debug(message_history)
+           response = await self._async_generate(message_history)
            _LOGGER.debug(response)

        except Exception as err:
@@ -400,25 +409,28 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                f"Sorry, there was a problem talking to the backend: {repr(err)}",
            )
            return ConversationResult(
-               response=intent_response, conversation_id=conversation_id
+               response=intent_response, conversation_id=user_input.conversation_id
            )

        # remove end of text token if it was returned
        response = response.replace(template_desc["assistant"]["suffix"], "")

-       conversation.append({"role": "assistant", "message": response})
+       # remove think blocks
+       response = re.sub(rf"^.*?{template_desc["chain_of_thought"]["suffix"]}", "", response, flags=re.DOTALL)

+       message_history.append({"role": "assistant", "message": response})
        if remember_conversation:
-           if remember_num_interactions and len(conversation) > (remember_num_interactions * 2) + 1:
+           if remember_num_interactions and len(message_history) > (remember_num_interactions * 2) + 1:
                for i in range(0,2):
-                   conversation.pop(1)
-           self.history[conversation_id] = conversation
+                   message_history.pop(1)
+           # chat_log.content = [_convert_content_back(user_input.agent_id, message_history_entry) for message_history_entry in message_history ]

        if llm_api is None:
            # return the output without messing with it if there is no API exposed to the model
            intent_response = intent.IntentResponse(language=user_input.language)
            intent_response.async_set_speech(response.strip())
            return ConversationResult(
-               response=intent_response, conversation_id=conversation_id
+               response=intent_response, conversation_id=user_input.conversation_id
            )

        tool_response = None
@@ -459,7 +471,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                    f"I'm sorry, I didn't produce a correctly formatted tool call! Please see the logs for more info.",
                )
                return ConversationResult(
-                   response=intent_response, conversation_id=conversation_id
+                   response=intent_response, conversation_id=user_input.conversation_id
                )

            _LOGGER.info(f"calling tool: {block}")
@@ -503,20 +515,20 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                    f"I'm sorry! I encountered an error calling the tool. See the logs for more info.",
                )
                return ConversationResult(
-                   response=intent_response, conversation_id=conversation_id
+                   response=intent_response, conversation_id=user_input.conversation_id
                )

        # handle models that generate a function call and wait for the result before providing a response
        if self.entry.options.get(CONF_TOOL_MULTI_TURN_CHAT, DEFAULT_TOOL_MULTI_TURN_CHAT) and tool_response is not None:
            try:
-               conversation.append({"role": "tool", "message": json.dumps(tool_response)})
+               message_history.append({"role": "tool", "message": json.dumps(tool_response)})
            except:
-               conversation.append({"role": "tool", "message": "No tools were used in this response."})
+               message_history.append({"role": "tool", "message": "No tools were used in this response."})

            # generate a response based on the tool result
            try:
-               _LOGGER.debug(conversation)
-               to_say = await self._async_generate(conversation)
+               _LOGGER.debug(message_history)
+               to_say = await self._async_generate(message_history)
                _LOGGER.debug(to_say)

            except Exception as err:
@@ -528,17 +540,17 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                    f"Sorry, there was a problem talking to the backend: {repr(err)}",
                )
                return ConversationResult(
-                   response=intent_response, conversation_id=conversation_id
+                   response=intent_response, conversation_id=user_input.conversation_id
                )

-           conversation.append({"role": "assistant", "message": response})
-           conversation.append({"role": "assistant", "message": to_say})
+           message_history.append({"role": "assistant", "message": response})
+           message_history.append({"role": "assistant", "message": to_say})

        # generate intent response to Home Assistant
        intent_response = intent.IntentResponse(language=user_input.language)
        intent_response.async_set_speech(to_say.strip())
        return ConversationResult(
-           response=intent_response, conversation_id=conversation_id
+           response=intent_response, conversation_id=user_input.conversation_id
        )

    def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
@@ -813,6 +825,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
        all_services = []
        scripts_added = False
        for domain in domains:
+           if domain not in SERVICE_TOOL_ALLOWED_DOMAINS:
+               continue
+
            # scripts show up as individual services
            if domain == "script" and not scripts_added:
                all_services.extend([
@@ -825,6 +840,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                continue

            for name, service in service_dict.get(domain, {}).items():
+               if name not in SERVICE_TOOL_ALLOWED_SERVICES:
+                   continue
+
                args = flatten_vol_schema(service.schema)
                args_to_expose = set(args).intersection(ALLOWED_SERVICE_CALL_ARGUMENTS)
                service_schema = vol.Schema({
@@ -1225,7 +1243,7 @@ class GenericOpenAIAPIAgent(LocalLLMAgent):

    def _chat_completion_params(self, conversation: dict) -> (str, dict):
        request_params = {}
-       api_base_path = self.entry.options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
+       api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)

        endpoint = f"/{api_base_path}/chat/completions"
        request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]
@@ -1234,7 +1252,7 @@ class GenericOpenAIAPIAgent(LocalLLMAgent):

    def _completion_params(self, conversation: dict) -> (str, dict):
        request_params = {}
-       api_base_path = self.entry.options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
+       api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)

        endpoint = f"/{api_base_path}/completions"
        request_params["prompt"] = self._format_prompt(conversation)
@@ -1,7 +1,7 @@
{
  "domain": "llama_conversation",
  "name": "Local LLM Conversation",
- "version": "0.3.7",
+ "version": "0.3.8",
  "codeowners": ["@acon96"],
  "config_flow": true,
  "dependencies": ["conversation"],
@@ -166,13 +166,16 @@ def install_llama_cpp_python(config_dir: str):
        return True

    platform_suffix = platform.machine()
+   # remap other names for architectures to the names we use
    if platform_suffix == "arm64":
        platform_suffix = "aarch64"
+   if platform_suffix == "i386" or platform_suffix == "x86_64":
+       platform_suffix = "amd64"

    runtime_version = f"cp{sys.version_info.major}{sys.version_info.minor}"

    instruction_extensions_suffix = ""
-   if platform_suffix == "amd64" or platform_suffix == "i386":
+   if platform_suffix == "amd64":
        instruction_extensions_suffix = "-noavx"

    try:
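The "-noavx" suffix above selects the wheel built without AVX/AVX2/FMA/F16C (matching the workflow matrix at the top of this diff). A sketch of the underlying idea, choosing the no-AVX wheel when the CPU does not advertise AVX support; the `/proc/cpuinfo` check is a Linux-only assumption for illustration, not the integration's exact detection logic:

```python
def needs_noavx_wheel(cpuinfo_path: str = "/proc/cpuinfo") -> bool:
    try:
        with open(cpuinfo_path) as cpuinfo:
            for line in cpuinfo:
                if line.startswith("flags"):
                    # the flags line lists CPU features such as "avx" and "avx2"
                    return "avx" not in line.split()
    except OSError:
        pass
    return True  # be conservative and fall back to the no-AVX build

print(needs_noavx_wheel())
```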
9  custom_components/requirements-dev.txt  Normal file
@@ -0,0 +1,9 @@
# types from Home Assistant
homeassistant>=2024.6.1
hassil
home-assistant-intents

# testing requirements
pytest
pytest-asyncio
pytest-homeassistant-custom-component

2  custom_components/requirements.txt  Normal file
@@ -0,0 +1,2 @@
huggingface-hub>=0.23.0
webcolors>=24.8.0
@@ -569,7 +569,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
        )

        question = question_template.replace("<device_name>", chosen_devices[0]["description"])
-       answer = answer_template.replace("<device_name>", chosen_devices[0]["description"])
+       answer = [ answer_template.replace("<device_name>", chosen_devices[0]["description"]) ]
    else:
        question = question_template
        answers = []
@@ -583,7 +583,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
            )
            answers.append(answer.replace(f"<device_name>", chosen_devices[i]["description"]))

-       answer = []
+       answer: list[str] = []
        for word in and_words:
            answer.append(f" {word} ".join(answers))

@@ -1151,7 +1151,7 @@ def load_dataset_piles(language):
# TODO: answer questions about more than one thing in the state list at once
# TODO: add examples for rooms/groups of devices. i.e. "turn off all the lights in the kitchen"
# TODO: add time, weather, and calendar/reminders (next 3 events?)
-def main():
+def main(args=None):
    parser = argparse.ArgumentParser(description="Generate the full dataset from the CSV piles")
    parser.add_argument("--sample", action="store_true", help="Set this flag to enable generation of the train dataset.")
    parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset.")
@@ -1171,7 +1171,7 @@ def main():
    dataset_format_group.add_argument('--raw_corpus', action='store_const', const='raw', dest='format')
    dataset_format_group.add_argument('--sharegpt', action='store_const', const='sharegpt', dest='format')

-   args = parser.parse_args()
+   args = parser.parse_args(args=args)

    if not args.sample and not args.train and not args.test and not args.merge and not args.dpo:
        parser.print_usage()
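Because `main()` now accepts an argument list, the dataset generator can also be invoked programmatically; this mirrors the call added in the new `train.ipynb` further down:

```python
from data.generate_home_assistant_data import main as generate_data

# Same flags as the CLI invocation documented in docs/Training.md
generate_data(["--train", "--test", "--large", "--sharegpt",
               "--language", "english", "german", "french", "spanish"])
```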
6  data/requirements.txt  Normal file
@@ -0,0 +1,6 @@
datasets>=3.2.0
webcolors>=1.13
pandas>=2.2.3
deep-translator>=1.11.4
langcodes>=3.5.0
babel==2.15.0
@@ -6,7 +6,7 @@ This integration allows for full customization of the system prompt using Home A
The default system prompt for non-fine tuned models is:
```
You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.
-The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }}
+The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", True, "")) }}
Tools: {{ tools | to_json }}
Devices:
{% for device in devices | selectattr('area_id', 'none'): %}
@@ -230,6 +230,22 @@ python3 train.py \
    --save_steps 50 --save_total_limit 10 --eval_steps 100 --logging_steps 2
```

+#### Llama 3.2 3B Instruct
+```
+python3 generate_home_assistant_data.py --train --test --large --sharegpt --language english german french spanish
+
+python3 train.py \
+    --run_name Home-Llama-3.2-3B-rev1 \
+    --base_model meta-llama/Llama-3.2-3B-Instruct \
+    --bf16 \
+    --train_dataset data/home_assistant_train.jsonl \
+    --test_dataset data/home_assistant_test.jsonl \
+    --learning_rate 1e-5 --learning_rate_warmup 0.03 --batch_size 64 --epochs 1 \
+    --micro_batch_size 2 \
+    --ctx_size 2048 \
+    --save_steps 200 --save_total_limit 3 --eval_steps 100 --logging_steps 2
+```

### Problems

Training a model is not an easy thing. Therefore, we are not able to cover all the problems encountered during training. Here we will try to add known problems and solutions on how to deal with them.
23  evaluate.py
@@ -83,7 +83,8 @@ def generate(model, tokenizer, prompts):
    return text

def evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_size, use_icl):
-   split = trained_tokenizer.apply_chat_template(conversation=[{"role": "assistant", "content": r"%%%%%%%%%%%%%%%%"}], tokenize=False).split( r"%%%%%%%%%%%%%%%%")[0].replace(trained_tokenizer.bos_token, "")
+   # split = trained_tokenizer.apply_chat_template(conversation=[{"role": "assistant", "content": r"%%%%%%%%%%%%%%%%"}], tokenize=False).split( r"%%%%%%%%%%%%%%%%")[0].replace(trained_tokenizer.bos_token, "")
+   split = "<|start_header_id|>assistant<|end_header_id|>"

    print("Evaluating...")
    correct_answers = 0
@@ -138,7 +139,7 @@ def evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_siz
        output = generate(trained_model, trained_tokenizer, prompts)

        for model_output, expected_response in zip(output, expected_responses):
-           response = model_output.replace(trained_tokenizer.pad_token, "").replace(trained_tokenizer.eos_token, "").split(split)[1]
+           response = model_output.replace(trained_tokenizer.pad_token, "").replace(trained_tokenizer.eos_token, "").split(split)[1].strip()

            expected_service_calls = []

@@ -279,9 +280,19 @@ def load_model(model_name, is_lora, is_hf, load_in_8bit, checkpoint_name):
        padding_side='left',
    )

    eos_token_id_to_use = trained_model.config.eos_token_id
    if len(eos_token_id_to_use) > 0:
        eos_token_id_to_use = trained_model.config.eos_token_id[0]

    pad_token_id_to_use = trained_model.config.pad_token_id
    if not trained_tokenizer.pad_token:
        trained_tokenizer.pad_token = trained_tokenizer.eos_token

    if len(trained_model.config.eos_token_id) > 0:
        pad_token_id_to_use = trained_model.config.eos_token_id[0]
    else:
        pad_token_id_to_use = trained_model.config.eos_token_id

    trained_model.generation_config = GenerationConfig(
        max_new_tokens=128,
        use_cache=True,
@@ -290,9 +301,9 @@ def load_model(model_name, is_lora, is_hf, load_in_8bit, checkpoint_name):
        top_k=40,
        top_p=1.0,
        repetition_penalty=1.15,
-       # eos_token_id=trained_model.config.eos_token_id,
-       eos_token_id=128009,
-       pad_token_id=trained_model.config.pad_token_id if trained_model.config.pad_token_id else trained_model.config.eos_token_id,
+       eos_token_id=trained_model.config.eos_token_id,
+       # eos_token_id=128009,
+       pad_token_id=pad_token_id_to_use,
    )

    return trained_model, trained_tokenizer
@@ -350,7 +361,7 @@ def main():
            print(f"Evaluation already exists for {output_folder}. Skipping...")
            continue

-       trained_model, trained_tokenizer = load_model(args.model, args.lora, ckpt, False)
+       trained_model, trained_tokenizer = load_model(args.model, args.lora, False, False, ckpt)
        evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_size, False)
@@ -1,6 +1,6 @@
{
  "name": "Local LLM Conversation",
- "homeassistant": "2024.12.3",
+ "homeassistant": "2025.4.1",
  "content_in_root": false,
  "render_readme": true
}
@@ -1,28 +1,13 @@
# training + dataset requirements
-transformers
-tensorboard
-datasets
-peft
-bitsandbytes
-trl
-webcolors
-pandas
+transformers>=4.48.3
+tensorboard>=2.18.0
+datasets>=3.2.0
+peft>=0.14.0
+bitsandbytes>=0.45.2
+trl>=0.14.0
+webcolors>=1.13
+pandas>=2.2.3
# flash-attn
-sentencepiece
-deep-translator
-langcodes
-babel==2.15.0
-
-# integration requirements
-huggingface-hub>=0.23.0
-webcolors>=24.8.0
-
-# types from Home Assistant
-homeassistant>=2024.6.1
-hassil
-home-assistant-intents
-
-# testing requirements
-pytest
-pytest-asyncio
-pytest-homeassistant-custom-component
+sentencepiece>=0.2.0
+deep-translator>=1.11.4
+langcodes>=3.5.0
+babel>=2.15.0
5  scripts/convert_and_quantize.sh  Normal file → Executable file
@@ -11,8 +11,7 @@ fi

echo "Converting to GGUF..."
if [ ! -f "./models/$MODEL_NAME/$MODEL_NAME.f16.gguf" ]; then
-   $LLAMA_CPP/convert.py --outfile ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf --outtype f16 ./models/$MODEL_NAME/
-   # $LLAMA_CPP/convert-hf-to-gguf.py --outfile ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf --outtype f16 ./models/$MODEL_NAME/
+   $LLAMA_CPP/convert_hf_to_gguf.py --outfile ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf --outtype f16 ./models/$MODEL_NAME/
else
    echo "Converted model for already exists. Skipping..."
fi
@@ -23,7 +22,7 @@ for QUANT in "${DESIRED_QUANTS[@]}"
do
    QUANT_LOWER=$(echo "$QUANT" | awk '{print tolower($0)}')
    if [ ! -f "./models/$MODEL_NAME/$MODEL_NAME.$QUANT_LOWER.gguf" ]; then
-       $LLAMA_CPP/build/bin/quantize ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf ./models/$MODEL_NAME/$MODEL_NAME.$QUANT_LOWER.gguf $QUANT
+       $LLAMA_CPP/build/bin/llama-quantize ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf ./models/$MODEL_NAME/$MODEL_NAME.$QUANT_LOWER.gguf $QUANT
    else
        echo "Quantized model for '$QUANT' already exists. Skipping..."
    fi
111  train.ipynb  Normal file
@@ -0,0 +1,111 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "469a9a97-0f6b-475f-8aef-a796c1c5244f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -r requirements.txt\n",
    "\n",
    "import os, re\n",
    "from train import TrainingRunArguments, do_training_run\n",
    "\n",
    "def get_next_run_name(model):\n",
    "    pattern = re.compile(model + r\"-rev(\\d+)$\")\n",
    "    max_rev = 0\n",
    "\n",
    "    for folder in os.listdir(\"models/\"):\n",
    "        match = pattern.search(folder)\n",
    "        if match:\n",
    "            max_rev = max(max_rev, int(match.group(1)))\n",
    "\n",
    "    return f\"{model}-rev{max_rev + 1}\"\n",
    "\n",
    "os.environ[\"HF_HOME\"] = \"/workspace/\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed0807bf",
   "metadata": {},
   "source": [
    "## Generate Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaafce74",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -r data/requirements.txt\n",
    "from data.generate_home_assistant_data import main as generate_data\n",
    "\n",
    "generate_data([\"--train\", \"--test\", \"--large\", \"--sharegpt\", \"--language\", \"english\", \"german\", \"french\", \"spanish\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff011772",
   "metadata": {},
   "source": [
    "## Llama 3.2 1B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48839ce2-1939-4d7f-817c-97b047bafd42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# python3 train.py \\\n",
    "#     --run_name Home-Llama-3.2-1B-rev1 \\\n",
    "#     --base_model meta-llama/Llama-3.2-1B-Instruct \\\n",
    "#     --bf16 \\\n",
    "#     --train_dataset data/home_assistant_train.jsonl \\\n",
    "#     --test_dataset data/home_assistant_test.jsonl \\\n",
    "#     --learning_rate 2e-5 --learning_rate_warmup 0.03 --batch_size 64 --epochs 1 \\\n",
    "#     --micro_batch_size 2 \\\n",
    "#     --ctx_size 2048 \\\n",
    "#     --save_steps 200 --save_total_limit 1 --eval_steps 200 --logging_steps 2\n",
    "\n",
    "do_training_run(TrainingRunArguments(\n",
    "    run_name=get_next_run_name(\"Home-Llama-3.2-1B\"),\n",
    "    base_model=\"meta-llama/Llama-3.2-1B-Instruct\",\n",
    "    bf16=True,\n",
    "    train_dataset=\"data/home_assistant_train.jsonl\",\n",
    "    test_dataset=\"data/home_assistant_test.jsonl\",\n",
    "    learning_rate=2e-5, learning_rate_warmup=0.03, \n",
    "    batch_size=64, micro_batch_size=2, epochs=1,\n",
    "    ctx_size=2048,\n",
    "    save_steps=200, save_total_limit=1, eval_steps=200, logging_steps=2,\n",
    "))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
638
train.py
638
train.py
@@ -6,17 +6,17 @@ import torch
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import shutil
|
||||
import traceback
|
||||
from torch.utils.data import SequentialSampler, Subset, RandomSampler
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, \
|
||||
PreTrainedTokenizerFast, HfArgumentParser, GPTQConfig, AutoConfig, TrainerCallback, BitsAndBytesConfig
|
||||
from transformers.trainer_utils import EvalPrediction
|
||||
HfArgumentParser, GPTQConfig, AutoConfig, TrainerCallback, BitsAndBytesConfig
|
||||
from transformers.integrations.integration_utils import TensorBoardCallback
|
||||
from datasets import load_dataset, Dataset
|
||||
from datasets import load_dataset
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional, Sequence, Sized, Iterator
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
|
||||
IS_DDP_ENABLED = "LOCAL_RANK" in os.environ
|
||||
MULTI_GPU_WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1"))
|
||||
MULTI_GPU_RANK = int(os.environ.get("RANK", "0"))
|
||||
IS_MULTI_GPU = os.environ.get("RANK") != None
|
||||
@@ -28,22 +28,25 @@ class TrainingRunArguments:
|
||||
base_model: str = field(metadata={"help": "The base model to load for fine-tuning"})
|
||||
train_dataset: str = field(metadata={"help": "The JSON file containing the training dataset"})
|
||||
test_dataset: str = field(default=None, metadata={"help": "The JSON file containing the evaluation dataset"})
|
||||
dataset_processing_threads: int = field(default=None, metadata={"help": "The number of threads to use to tokenize the dataset"})
|
||||
ctx_size: int = field(default=2048, metadata={"help": "The number of tokens to pad & truncate the input examples to"})
|
||||
bf16: bool = field(default=False, metadata={"help": "If set, the model will the loaded and trained in bf16 instead of fp16"})
|
||||
bf16: bool = field(default=False, metadata={"help": "If set, the model will the loaded and trained in bf16 instead of fp32"})
|
||||
batch_size: int = field(default=8, metadata={"help": "The simulated 'batch size' that we will train on. will tweak gradient accumulations steps"})
|
||||
micro_batch_size: int = field(default=2, metadata={"help": "The actual batch size that will fit into VRAM on this machine"})
|
||||
epochs: int = field(default=1, metadata={"help": "The number of times to train the model on each example"})
|
||||
learning_rate: float = field(default=1e-5, metadata={"help": "The starting learning rate (speed at which the model trains)"})
|
||||
learning_rate_schedule: str = field(default="cosine", metadata={"help": "How fast the learning rate is reduced during training"})
|
||||
learning_rate_warmup: float = field(default=0.0, metadata={"help": "The starting learning rate (speed at which the model trains)"})
|
||||
weight_decay: float = field(default=0.1, metadata={"help": ""})
|
||||
gradient_clip: float = field(default=1.0, metadata={"help": ""})
|
||||
weight_decay: float = field(default=0.1, metadata={"help": "Weight Decay rate for regularization. Rate to reduce all neuron weights towards zero."})
|
||||
# dropout: float = field(default=0.01, metadata={"help": "Dropout percent for regularization. Determines the fraction of neurons randomly deactivated during training."})
|
||||
gradient_clip: float = field(default=1.0, metadata={"help": "Maximum gradient norm for clipping to prevent exploding gradients during training."})
|
||||
resume_from_checkpoint: str = field(default="", metadata={"help": "The name of the checkpoint to resume training from"})
|
||||
eval_steps: int = field(default=200, metadata={"help": "The number of steps in between evaluations of the model; set to -1 to evaluate every epoch"})
|
||||
save_steps: int = field(default=-1, metadata={"help": "The number of steps in between model checkpoints; set to -1 to save every epoch"})
|
||||
save_total_limit: int = field(default=1, metadata={"help": "The number of recent checkpoints of the model to save (not including the final model)"})
|
||||
logging_steps: int = field(default=5, metadata={"help": "Sets the number of steps in between log output for the training run"})
|
||||
group_by_length: bool = field(default=False, metadata={"help": "If enabled, the training data will be grouped by length to optimize use of padding"})
|
||||
group_by_length: bool = field(default=False, metadata={"help": "If enabled, the training data will be grouped by length to optimize use of padding. Runs from longest to shortest examples."})
|
||||
gradient_checkpointing: bool = field(default=False, metadata={"help": "Enables gradient checkpointing to saves VRAM at the cost of re-computing activations during the backwards pass"})
|
||||
pre_allocate_cuda_buffers: bool = field(default=True, metadata={"help": "If enabled, runs a forward and backward pass on the model before training to force pytorch to allocate the correct size CUDA buffers up front"})
|
||||
|
||||
# Quantization
|
||||
@@ -60,21 +63,23 @@ class TrainingRunArguments:
|
||||
lora_modules_to_save: str = field(default=None, metadata={"help": "Additional modules to save"})
|
||||
lora_merge: bool = field(default=False, metadata={"help": "If set, the Lora will be merged back into the base model an saved"})
|
||||
|
||||
# dpo config
|
||||
dpo: bool = field(default=False, metadata={"help": "If set, performs Direct Preference Optimization instead of Supervised Fine Tuning"})
|
||||
beta: float = field(default=0.1, metadata={"help": "The implicit reward value used during DPO training"})
|
||||
dpo_loss: str = field(default="sigmoid", metadata={"help": "The loss type to use during DPO training"})
|
||||
|
||||
# token options
|
||||
add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
|
||||
add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
|
||||
add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"})
|
||||
gradient_checkpointing: bool = field(default=False, metadata={"help": "Enables gradient checkpointing which saves quite a lot of VRAM"})
|
||||
prefix_ids: str = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
|
||||
suffix_ids: str = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
|
||||
|
||||
# custom trainer tweaks
|
||||
sync_to_bucket: str = field(default=None, metadata={"help": "If set, checkpoints will be synced to the s3 bucket specified by this argument"})
|
||||
bucket_save_limit: int = field(default=None, metadata={"help": "The number of recent checkpoints of the model to save in S3 (not including the final model)"})
|
||||
flops_baseline: str = field(default=None, metadata={"help": "The baseline flops for the GPUs used for the training run. Outputs MFU"})
|
||||
|
||||
prefix_ids:str = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognise response."})
|
||||
suffix_ids:str = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognise response."})
|
||||
|
||||
|
||||
class UploadToS3Callback(TrainerCallback):
|
||||
def __init__(self, s3_bucket, s3_prefix, save_total_limit=None):
|
||||
@@ -85,10 +90,11 @@ class UploadToS3Callback(TrainerCallback):
|
||||
self.save_total_limit = save_total_limit
|
||||
|
||||
def on_save(self, args, state, control, **kwargs):
|
||||
output_dir = kwargs['output_dir']
|
||||
checkpoint = os.path.basename(output_dir)
|
||||
|
||||
|
||||
# Upload current checkpoint
|
||||
checkpoint = f"checkpoint-{state.global_step}"
|
||||
output_dir = f"{args.output_dir}/{checkpoint}"
|
||||
|
||||
for root, dirs, files in os.walk(output_dir):
|
||||
for file in files:
|
||||
local_path = os.path.join(root, file)
|
||||
@@ -96,7 +102,7 @@ class UploadToS3Callback(TrainerCallback):
|
||||
self.s3_client.upload_file(local_path, self.s3_bucket, s3_path)
|
||||
print(f"Uploaded {local_path} to s3://{self.s3_bucket}/{s3_path}")
|
||||
|
||||
# Manage checkpoints in S3
|
||||
# Delete prior checkpoints from S3
|
||||
if self.save_total_limit:
|
||||
s3_checkpoints = self.list_s3_checkpoints()
|
||||
if len(s3_checkpoints) > self.save_total_limit:
|
||||
@@ -105,18 +111,9 @@ class UploadToS3Callback(TrainerCallback):
|
||||
for checkpoint in to_delete:
|
||||
self.delete_checkpoint_from_s3(checkpoint)
|
||||
|
||||
# Clean local checkpoints, keeping only the most recent
|
||||
all_checkpoints = [os.path.join(args.output_dir, d) for d in os.listdir(args.output_dir) if os.path.isdir(os.path.join(args.output_dir, d))]
|
||||
if all_checkpoints:
|
||||
latest_checkpoint = max(all_checkpoints, key=os.path.getmtime)
|
||||
for checkpoint_dir in all_checkpoints:
|
||||
if checkpoint_dir != latest_checkpoint:
|
||||
shutil.rmtree(checkpoint_dir)
|
||||
print(f"Deleted local checkpoint {checkpoint_dir}")
|
||||
|
||||
def list_s3_checkpoints(self):
|
||||
paginator = self.s3_client.get_paginator('list_objects_v2')
|
||||
page_iterator = paginator.paginate(Bucket=self.s3_bucket, Prefix=self.s3_prefix, Delimiter='/')
|
||||
page_iterator = paginator.paginate(Bucket=self.s3_bucket, Prefix=self.s3_prefix + '/', Delimiter='/')
|
||||
return [prefix.get('Prefix').rstrip('/').split('/')[-1] for page in page_iterator for prefix in page.get('CommonPrefixes', [])]
|
||||
|
||||
def delete_checkpoint_from_s3(self, checkpoint_name):
|
||||
@@ -145,153 +142,25 @@ class MFUCallback(TrainerCallback):
|
||||
|
||||
self.start_time = current_time
|
||||
self.last_total_flos = state.total_flos
|
||||
|
||||
|
||||
parser = HfArgumentParser([TrainingRunArguments])
|
||||
training_run_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||
|
||||
if sum([training_run_args.load_in_8bit, training_run_args.load_in_4bit, training_run_args.load_as_gptq]) > 1:
|
||||
raise Exception("Please select exactly one of 'load_in_8bit', 'load_in_4bit', or 'load_as_gptq")
|
||||
|
||||
if IS_MASTER_PROCESS:
|
||||
print(f"Loading model '{training_run_args.base_model}'...")
|
||||
|
||||
model_kwargs = {}
|
||||
if training_run_args.load_in_8bit:
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
elif training_run_args.load_in_4bit:
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
|
||||
elif training_run_args.load_as_gptq:
|
||||
model_kwargs["quantization_config"] = GPTQConfig(bits=4, disable_exllama=True)
|
||||
|
||||
|
||||
if training_run_args.bf16:
|
||||
model_kwargs["torch_dtype"] = torch.bfloat16
|
||||
else:
|
||||
model_kwargs["torch_dtype"] = torch.float16
|
||||
|
||||
# model_kwargs["resid_pdrop"] = 0.0
|
||||
# model_kwargs["revision"] = "accfee56d8988cae60915486310362db5831b1bd"
|
||||
model_kwargs["use_cache"] = False
|
||||
def ddp_print(*args, **kwargs):
|
||||
if not IS_DDP_ENABLED or IS_MASTER_PROCESS:
|
||||
print(*args, **kwargs)
|
||||
|
||||
def find_max_vram(min_buffer_mib=800):
|
||||
max_memory = {}
|
||||
for i in range(torch.cuda.device_count()):
|
||||
total_mem = (torch.cuda.get_device_properties(i).total_memory / (1024 * 1024))
|
||||
suggestion = round((total_mem - 1000) / 1000) * 1000
|
||||
suggestion = min(suggestion, total_mem - min_buffer_mib)
|
||||
gpu_properties = torch.cuda.get_device_properties(i)
|
||||
total_memory_mib = (gpu_properties.total_memory / (1000 * 1000))
|
||||
suggestion = max(total_memory_mib - 1000, min_buffer_mib)
|
||||
|
||||
if IS_MASTER_PROCESS:
|
||||
print(f"Model will target using {suggestion}MiB of VRAM on GPU {i}")
|
||||
ddp_print(f"GPU {i}: {gpu_properties.name}, Total Memory: {gpu_properties.total_memory / (1024**3):.2f} GB")
|
||||
ddp_print(f"Model will target using {suggestion}MiB of VRAM on GPU {i}")
|
||||
max_memory[i] = f'{suggestion}MiB'
|
||||
|
||||
return max_memory
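# Editor's note (sketch, not in the original source): the dict returned above is the format
# expected by the `max_memory` argument of `AutoModelForCausalLM.from_pretrained`, i.e. a
# mapping from CUDA device index to a per-device cap such as {0: "23000MiB", 1: "23000MiB"}.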
|
||||
|
||||
if "LOCAL_RANK" not in os.environ:
|
||||
model_kwargs["device_map"] = "auto"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
training_run_args.base_model,
|
||||
trust_remote_code=True,
|
||||
max_memory=find_max_vram(),
|
||||
**model_kwargs
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(training_run_args.base_model, trust_remote_code=True)
|
||||
|
||||
if training_run_args.add_pad_token:
|
||||
tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
|
||||
model.config.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
if training_run_args.add_chatml_tokens:
|
||||
tokenizer.add_special_tokens({
|
||||
'bos_token': '<|im_start|>',
|
||||
'eos_token': '<|im_end|>'
|
||||
})
|
||||
|
||||
model.config.bos_token_id = tokenizer.bos_token_id
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
if training_run_args.add_chatml_prompt_template:
|
||||
tokenizer.chat_template = (
|
||||
"{% for message in messages %}"
|
||||
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}"
|
||||
"{{ '<|im_start|>assistant\n' }}"
|
||||
"{% endif %}"
|
||||
)
|
||||
|
||||
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
|
||||
if model.get_input_embeddings().num_embeddings < embeddings_len:
|
||||
model.resize_token_embeddings(embeddings_len)
|
||||
else:
|
||||
model.tie_weights()
|
||||
|
||||
# model.tie_weights()
|
||||
|
||||
original_model = model
|
||||
peft_config = None
|
||||
if training_run_args.use_lora:
|
||||
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
|
||||
if IS_MASTER_PROCESS:
|
||||
print("Creating LoRA for model...")
|
||||
target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
|
||||
modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
|
||||
peft_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
r=training_run_args.lora_rank,
|
||||
lora_alpha=training_run_args.lora_alpha,
|
||||
lora_dropout=training_run_args.lora_dropout,
|
||||
target_modules=target_modules,
|
||||
modules_to_save=modules_to_save,
|
||||
)
|
||||
if training_run_args.load_in_8bit or training_run_args.load_in_4bit or training_run_args.load_as_gptq:
|
||||
model = prepare_model_for_kbit_training(
|
||||
model, use_gradient_checkpointing=training_run_args.gradient_checkpointing
|
||||
)
|
||||
model = get_peft_model(model, peft_config)
|
||||
model.enable_input_require_grads()
|
||||
|
||||
model.print_trainable_parameters()
|
||||
|
||||
|
||||
base_dir = "loras" if training_run_args.use_lora else "models"
|
||||
model_dir = f"./{base_dir}/{training_run_args.run_name}"
|
||||
|
||||
training_kwargs = {}
|
||||
|
||||
if training_run_args.test_dataset:
|
||||
training_kwargs.update({
|
||||
"per_device_eval_batch_size": training_run_args.micro_batch_size,
|
||||
"eval_strategy": ("steps" if training_run_args.eval_steps != -1 else "epoch"),
|
||||
"eval_steps": (training_run_args.eval_steps if training_run_args.eval_steps != -1 else None),
|
||||
"bf16_full_eval": training_run_args.bf16,
|
||||
})
|
||||
|
||||
training_args = TrainingArguments(
|
||||
per_device_train_batch_size=training_run_args.micro_batch_size,
|
||||
gradient_accumulation_steps=training_run_args.batch_size//training_run_args.micro_batch_size,
|
||||
gradient_checkpointing=training_run_args.gradient_checkpointing,
|
||||
weight_decay=training_run_args.weight_decay,
|
||||
max_grad_norm=training_run_args.gradient_clip,
|
||||
save_strategy=("steps" if training_run_args.save_steps != -1 else "epoch"),
|
||||
save_steps=(training_run_args.save_steps if training_run_args.save_steps != -1 else None),
|
||||
save_safetensors=True,
|
||||
logging_steps=training_run_args.logging_steps,
|
||||
output_dir=model_dir,
|
||||
num_train_epochs=training_run_args.epochs,
|
||||
save_total_limit=training_run_args.save_total_limit,
|
||||
report_to='none',
|
||||
learning_rate=training_run_args.learning_rate,
|
||||
lr_scheduler_type=training_run_args.learning_rate_schedule,
|
||||
warmup_ratio=training_run_args.learning_rate_warmup,
|
||||
log_level="info",
|
||||
bf16=training_run_args.bf16,
|
||||
group_by_length=training_run_args.group_by_length,
|
||||
# include_num_input_tokens_seen=True,
|
||||
**training_kwargs,
|
||||
)
|
||||
|
||||
class DataCollatorForSupervisedFineTuning(object):
|
||||
"""Collate examples for supervised fine-tuning."""
|
||||
@@ -413,14 +282,8 @@ class DataCollatorForSupervisedFineTuning(object):
|
||||
attention_mask=input_ids.ne(self.tokenizer.pad_token_id or self.tokenizer.eos_token_id),
|
||||
)
|
||||
|
||||
if IS_MASTER_PROCESS:
|
||||
print("Loading dataset...")
|
||||
data_files = { "train": training_run_args.train_dataset }
|
||||
if training_run_args.test_dataset:
|
||||
data_files["test"] = training_run_args.test_dataset
|
||||
datasets = load_dataset("json", data_files=data_files)
|
||||
|
||||
def tokenize_raw_example(batch):
|
||||
def tokenize_raw_example(batch, tokenizer=None, training_run_args=None):
|
||||
return tokenizer(
|
||||
text=batch["text"],
|
||||
max_length=training_run_args.ctx_size,
|
||||
@@ -428,7 +291,7 @@ def tokenize_raw_example(batch):
|
||||
add_special_tokens=False,
|
||||
)
|
||||
|
||||
def tokenize_sharegpt_example(batch):
|
||||
def tokenize_sharegpt_example(batch, tokenizer=None, training_run_args=None):
|
||||
# TODO: figure out how to properly batch this
|
||||
result = []
|
||||
for example in batch["conversations"]:
|
||||
@@ -443,7 +306,7 @@ def tokenize_sharegpt_example(batch):
|
||||
|
||||
return {"input_ids": result}
|
||||
|
||||
def template_dpo_example(batch):
|
||||
def template_dpo_example(batch, tokenizer=None, training_run_args=None):
|
||||
# TODO: figure out how to properly batch this
|
||||
result = []
|
||||
for example in zip(batch["system"], batch["question"]):
|
||||
@@ -463,24 +326,6 @@ def template_dpo_example(batch):
|
||||
|
||||
return {"prompt": result}
|
||||
|
||||
training_callbacks = []
|
||||
if training_run_args.sync_to_bucket:
|
||||
training_callbacks.append(UploadToS3Callback(
|
||||
s3_bucket=training_run_args.sync_to_bucket,
|
||||
s3_prefix=training_run_args.run_name,
|
||||
save_total_limit=training_run_args.save_total_limit
|
||||
))
|
||||
|
||||
if training_run_args.flops_baseline:
|
||||
# A100 GPU bfloat16 peak flops is 312 TFLOPS (312e12)
|
||||
# 4090 GPU bfloat16 peak flops is 165.2 TFLOPS (1652e11)
|
||||
# 3090 GPU bfloat16 peak flops is 71 TFLOPS (71e12)
|
||||
|
||||
training_callbacks.append(MFUCallback(peak_flops=float(training_run_args.flops_baseline)))
|
||||
|
||||
|
||||
# log to tensorboard (but after MFU)
|
||||
training_callbacks.append(TensorBoardCallback())
|
||||
|
||||
class CustomSFTTrainer(Trainer):
|
||||
"""Implement different training tweaks"""
|
||||
@@ -514,7 +359,7 @@ class CustomSFTTrainer(Trainer):
|
||||
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
|
||||
"""
|
||||
Saw this in the chinchilla paper. It says not to go over 25% overshoot
|
||||
Should speed up training by skipping the final fine tuning part that doesn't affect accuracy much
|
||||
Should improve training efficiency by skipping the final fine tuning part that doesn't affect accuracy much
|
||||
"""
|
||||
return super().create_scheduler(int(num_training_steps * self.learning_rate_overshoot), optimizer=optimizer)
|
||||
|
||||
@@ -523,8 +368,8 @@ class CustomSFTTrainer(Trainer):
|
||||
examples_length = len(inputs["input_ids"][0])
|
||||
batch_size = len(inputs["input_ids"])
|
||||
|
||||
# mfu is approximated using thoughtput and param count
|
||||
# the number of paramters is approximately the number of multiply-accumulates (MAC) in the network
|
||||
# mfu is approximated using throughput and param count
|
||||
# the number of parameters is approximately the number of multiply-accumulates (MAC) in the network
|
||||
# each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param
|
||||
# there are 3 passes of a NN (fwd, bwd, delta) - we multiply by 3 ie 2 * 3 * n_param
|
||||
# this gets us FLOPs / token
|
||||
@@ -538,128 +383,303 @@ class CustomSFTTrainer(Trainer):
|
||||
result = (3 * flops_per_seq + 3 * attn_flops_per_seq) * batch_size
|
||||
return result
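# Editor's sketch (assumption about MFUCallback, which is only partially shown here): the
# utilization figure is presumably derived from the FLOPs estimate above roughly as:
#   achieved_flops_per_sec = (state.total_flos - last_total_flos) / elapsed_seconds
#   mfu = achieved_flops_per_sec / (peak_flops * world_size)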
|
||||
|
||||
if not training_run_args.dpo:
|
||||
if IS_MASTER_PROCESS:
|
||||
print("Tokenizing datasets...")
|
||||
|
||||
if "text" in datasets["train"].column_names:
|
||||
tokenize_function = tokenize_raw_example
|
||||
columns_to_remove = ["text"]
|
||||
elif "conversations" in datasets["train"].column_names:
|
||||
tokenize_function = tokenize_sharegpt_example
|
||||
columns_to_remove = ["conversations"]
|
||||
def do_training_run(training_run_args: TrainingRunArguments):
|
||||
# validate args + build model kwargs
|
||||
if sum([training_run_args.load_in_8bit, training_run_args.load_in_4bit, training_run_args.load_as_gptq]) > 1:
|
||||
raise Exception("Please select exactly one of 'load_in_8bit', 'load_in_4bit', or 'load_as_gptq")
|
||||
|
||||
model_kwargs = {}
|
||||
if training_run_args.load_in_8bit:
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
elif training_run_args.load_in_4bit:
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
|
||||
elif training_run_args.load_as_gptq:
|
||||
model_kwargs["quantization_config"] = GPTQConfig(bits=4, disable_exllama=True)
|
||||
|
||||
if training_run_args.bf16:
|
||||
model_kwargs["torch_dtype"] = torch.bfloat16
|
||||
elif training_run_args.use_lora and "quantization_config" not in model_kwargs:
|
||||
model_kwargs["torch_dtype"] = torch.float16
|
||||
else:
|
||||
raise Exception("Unknown dataset input format (not raw corpus or sharegpt)")
|
||||
# auto detect 'best' format with fallback to fp32
|
||||
model_kwargs["torch_dtype"] = "auto"
|
||||
|
||||
# model_kwargs["resid_pdrop"] = training_run_args.dropout
|
||||
model_kwargs["use_cache"] = False
|
||||
|
||||
if not IS_DDP_ENABLED:
|
||||
model_kwargs["device_map"] = "auto"
|
||||
|
||||
# load the model
|
||||
ddp_print(f"Loading model '{training_run_args.base_model}'...")
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
training_run_args.base_model,
|
||||
max_memory=find_max_vram(),
|
||||
token=os.environ.get("HF_TOKEN"),
|
||||
**model_kwargs
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(training_run_args.base_model, token=os.environ.get("HF_TOKEN"))
|
||||
|
||||
# mess with tokens + prompt template
|
||||
if training_run_args.add_pad_token:
|
||||
tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
|
||||
model.config.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
if training_run_args.add_chatml_tokens:
|
||||
tokenizer.add_special_tokens({
|
||||
'bos_token': '<|im_start|>',
|
||||
'eos_token': '<|im_end|>'
|
||||
})
|
||||
|
||||
model.config.bos_token_id = tokenizer.bos_token_id
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
if training_run_args.add_chatml_prompt_template:
|
||||
tokenizer.chat_template = (
|
||||
"{% for message in messages %}"
|
||||
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}"
|
||||
"{{ '<|im_start|>assistant\n' }}"
|
||||
"{% endif %}"
|
||||
)
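# Editor's sketch (illustration, not part of the training script): with the ChatML
# template above, the standard HF API renders a turn like this:
#   msgs = [{"role": "user", "content": "Turn on the kitchen lights."}]
#   tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
#   # -> '<|im_start|>user\nTurn on the kitchen lights.<|im_end|>\n<|im_start|>assistant\n'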
|
||||
|
||||
# resize embeddings if added tokens require it
|
||||
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
|
||||
if model.get_input_embeddings().num_embeddings < embeddings_len:
|
||||
model.resize_token_embeddings(embeddings_len)
|
||||
else:
|
||||
model.tie_weights()
|
||||
|
||||
# create LoRA model if config says so
|
||||
original_model = model
|
||||
peft_config = None
|
||||
if training_run_args.use_lora:
|
||||
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
|
||||
ddp_print("Creating LoRA for model...")
|
||||
target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
|
||||
modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
|
||||
peft_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
r=training_run_args.lora_rank,
|
||||
lora_alpha=training_run_args.lora_alpha,
|
||||
lora_dropout=training_run_args.lora_dropout,
|
||||
target_modules=target_modules,
|
||||
modules_to_save=modules_to_save,
|
||||
)
|
||||
if training_run_args.load_in_8bit or training_run_args.load_in_4bit or training_run_args.load_as_gptq:
|
||||
model = prepare_model_for_kbit_training(
|
||||
model, use_gradient_checkpointing=training_run_args.gradient_checkpointing
|
||||
)
|
||||
model = get_peft_model(model, peft_config)
|
||||
model.enable_input_require_grads()
|
||||
|
||||
model.print_trainable_parameters()
|
||||
|
||||
base_dir = "loras" if training_run_args.use_lora else "models"
|
||||
model_dir = f"./{base_dir}/{training_run_args.run_name}"
|
||||
|
||||
# set up HuggingFace Trainer args
|
||||
training_kwargs = {}
|
||||
|
||||
tokenized_test_dataset = None
|
||||
num_proc = os.cpu_count() // MULTI_GPU_WORLD_SIZE
|
||||
tokenized_train_dataset = datasets["train"].map(tokenize_function, batched=True, num_proc=num_proc).remove_columns(columns_to_remove)
|
||||
if training_run_args.test_dataset:
|
||||
tokenized_test_dataset = datasets["test"].map(tokenize_function, batched=True, num_proc=num_proc).remove_columns(columns_to_remove)
|
||||
training_kwargs.update({
|
||||
"per_device_eval_batch_size": training_run_args.micro_batch_size,
|
||||
"eval_strategy": ("steps" if training_run_args.eval_steps != -1 else "epoch"),
|
||||
"eval_steps": (training_run_args.eval_steps if training_run_args.eval_steps != -1 else None),
|
||||
"bf16_full_eval": training_run_args.bf16,
|
||||
})
|
||||
|
||||
training_args = TrainingArguments(
|
||||
per_device_train_batch_size=training_run_args.micro_batch_size,
|
||||
gradient_accumulation_steps=training_run_args.batch_size//training_run_args.micro_batch_size,
|
||||
gradient_checkpointing=training_run_args.gradient_checkpointing,
|
||||
weight_decay=training_run_args.weight_decay,
|
||||
max_grad_norm=training_run_args.gradient_clip,
|
||||
save_strategy=("steps" if training_run_args.save_steps != -1 else "epoch"),
|
||||
save_steps=(training_run_args.save_steps if training_run_args.save_steps != -1 else None),
|
||||
save_safetensors=True,
|
||||
logging_steps=training_run_args.logging_steps,
|
||||
output_dir=model_dir,
|
||||
num_train_epochs=training_run_args.epochs,
|
||||
save_total_limit=training_run_args.save_total_limit,
|
||||
report_to='none',
|
||||
learning_rate=training_run_args.learning_rate,
|
||||
lr_scheduler_type=training_run_args.learning_rate_schedule,
|
||||
warmup_ratio=training_run_args.learning_rate_warmup,
|
||||
log_level="info",
|
||||
bf16=training_run_args.bf16,
|
||||
group_by_length=training_run_args.group_by_length,
|
||||
# include_num_input_tokens_seen=True,
|
||||
**training_kwargs,
|
||||
)
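# Editor's note (worked example, not in the original): gradient_accumulation_steps is
# batch_size // micro_batch_size, so e.g. --batch_size 64 --micro_batch_size 8 accumulates
# 8 micro-batches of 8 examples per optimizer step on each device.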
|
||||
|
||||
# set up trainer callbacks
|
||||
training_callbacks = []
|
||||
if training_run_args.sync_to_bucket:
|
||||
training_callbacks.append(UploadToS3Callback(
|
||||
s3_bucket=training_run_args.sync_to_bucket,
|
||||
s3_prefix=training_run_args.run_name,
|
||||
save_total_limit=training_run_args.bucket_save_limit if training_run_args.bucket_save_limit else training_run_args.save_total_limit
|
||||
))
|
||||
|
||||
if training_run_args.flops_baseline:
|
||||
# A100 40/80GB GPU bfloat16 peak flops is 312 TFLOPS (312e12)
|
||||
# 4090 24GB GPU bfloat16 peak flops is 165.2 TFLOPS (1652e11)
|
||||
# A40 48GB GPU bfloat16 peak flops is 149.7 TFLOPS (1497e11)
|
||||
# 3090 24GB GPU bfloat16 peak flops is 71 TFLOPS (71e12)
|
||||
training_callbacks.append(MFUCallback(peak_flops=float(training_run_args.flops_baseline)))
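# Editor's note (hedged example): the baseline is supplied in raw FLOPS on the command line,
# e.g. `--flops_baseline 312e12` for an A100 or `--flops_baseline 71e12` for a 3090,
# matching the peak numbers listed above.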
|
||||
|
||||
# log to tensorboard (but after MFU)
|
||||
training_callbacks.append(TensorBoardCallback())
|
||||
|
||||
if not training_run_args.dpo:
|
||||
ddp_print("Loading dataset...")
|
||||
data_files = { "train": training_run_args.train_dataset }
|
||||
if training_run_args.test_dataset:
|
||||
data_files["test"] = training_run_args.test_dataset
|
||||
datasets = load_dataset("json", data_files=data_files)
|
||||
|
||||
# prepare the dataset
|
||||
ddp_print("Tokenizing datasets...")
|
||||
|
||||
if "text" in datasets["train"].column_names:
|
||||
tokenize_function = tokenize_raw_example
|
||||
columns_to_remove = ["text"]
|
||||
elif "conversations" in datasets["train"].column_names:
|
||||
tokenize_function = tokenize_sharegpt_example
|
||||
columns_to_remove = ["conversations"]
|
||||
else:
|
||||
raise Exception("Unknown dataset input format (not raw corpus or sharegpt)")
|
||||
|
||||
tokenized_test_dataset = None
|
||||
num_proc = None
|
||||
if training_run_args.dataset_processing_threads:
|
||||
num_proc = training_run_args.dataset_processing_threads // MULTI_GPU_WORLD_SIZE
|
||||
tokenized_train_dataset = datasets["train"].map(tokenize_function, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)
|
||||
if training_run_args.test_dataset:
|
||||
tokenized_test_dataset = datasets["test"].map(tokenize_function, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)
|
||||
|
||||
example_lengths = [ len(example) for example in tokenized_train_dataset["input_ids"] ]
|
||||
tokens_in_train_set, longest_example = sum(example_lengths), max(example_lengths)
|
||||
ddp_print(f"Train dataset has {int(tokens_in_train_set / 1000000)}M tokens. Longest Example: {longest_example} tokens")
|
||||
|
||||
provided_prefix_ids = None
|
||||
provided_suffix_ids = None
|
||||
try:
|
||||
if training_run_args.prefix_ids:
|
||||
provided_prefix_ids = [ int(x) for x in training_run_args.prefix_ids.split(",") ]
|
||||
if training_run_args.suffix_ids:
|
||||
provided_suffix_ids = [ int(x) for x in training_run_args.suffix_ids.split(",") ]
|
||||
except ValueError as ex:
|
||||
print(f"Error parsing prefix_ids or suffix_ids: '{ex}'")
|
||||
exit(-1)
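# Editor's note (hypothetical values, not from the source): the ids are comma-separated
# token ids bracketing the assistant response, e.g.
#   --prefix_ids "32001,13" --suffix_ids "32000"
# where the concrete ids depend entirely on the tokenizer in use.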
|
||||
|
||||
trainer = CustomSFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_train_dataset,
|
||||
eval_dataset=tokenized_test_dataset,
|
||||
data_collator=DataCollatorForSupervisedFineTuning(
|
||||
tokenizer=tokenizer,
|
||||
prefix_ids=provided_prefix_ids,
|
||||
suffix_ids=provided_suffix_ids,
|
||||
),
|
||||
callbacks=training_callbacks,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("DPO Trainer doesn't work yet!")
|
||||
# from trl import DPOTrainer
|
||||
# max_prompt_length = 0
|
||||
|
||||
# train_dataset = datasets["train"].map(lambda x: { "prompt_len": len(x["system"]) })
|
||||
|
||||
# test_dataset = None
|
||||
# if training_run_args.test_dataset:
|
||||
# test_dataset = datasets["test"]
|
||||
|
||||
# max_prompt_length = max(train_dataset["prompt_len"])
|
||||
|
||||
# print("Templating DPO Examples...")
|
||||
# templated_test_dataset = None
|
||||
# templated_train_dataset = train_dataset.map(template_dpo_example, batched=True).remove_columns(["system", "question"])
|
||||
# if training_run_args.test_dataset:
|
||||
# templated_test_dataset = datasets["test"].map(template_dpo_example, batched=True).remove_columns(["system", "question"])
|
||||
|
||||
# # tokenizer.model_input_names = [ "chosen_input_ids" ]
|
||||
|
||||
# # group_by_length doesn't work here
|
||||
# # templated_train_dataset = templated_train_dataset.sort("prompt_len", reverse=True)
|
||||
|
||||
# training_args.length_column_name = "prompt_len"
|
||||
# model.enable_input_require_grads()
|
||||
|
||||
# trainer = DPOTrainer(
|
||||
# model,
|
||||
# ref_model=None,
|
||||
# # ref_model=original_model,
|
||||
# peft_config=peft_config,
|
||||
# args=training_args,
|
||||
# beta=training_run_args.beta,
|
||||
# loss_type=training_run_args.dpo_loss,
|
||||
# train_dataset=templated_train_dataset,
|
||||
# eval_dataset=templated_test_dataset,
|
||||
# tokenizer=tokenizer,
|
||||
# max_length=training_run_args.ctx_size,
|
||||
# max_prompt_length=max_prompt_length,
|
||||
# truncation_mode="keep_start",
|
||||
# callbacks=training_callbacks,
|
||||
# )
|
||||
|
||||
example_lengths = [ len(example) for example in tokenized_train_dataset["input_ids"] ]
|
||||
tokens_in_train_set, longest_example = sum(example_lengths), max(example_lengths)
|
||||
if IS_MASTER_PROCESS:
|
||||
print(f"Train dataset has {int(tokens_in_train_set / 1000000)}M tokens. Longest Example: {longest_example} tokens")
|
||||
|
||||
provided_prefix_ids = None
|
||||
provided_suffix_ids = None
|
||||
try:
|
||||
if training_run_args.prefix_ids:
|
||||
provided_prefix_ids = [ int(x) for x in training_run_args.prefix_ids.split(",") ]
|
||||
if training_run_args.suffix_ids:
|
||||
provided_suffix_ids = [ int(x) for x in training_run_args.suffix_ids.split(",") ]
|
||||
except ValueError as ex:
|
||||
print(f"Error parsing prefix_ids or suffix_ids: '{ex}'")
|
||||
trainer.train(resume_from_checkpoint=training_run_args.resume_from_checkpoint if training_run_args.resume_from_checkpoint else None)
|
||||
|
||||
if training_run_args.test_dataset:
|
||||
trainer.evaluate_all()
|
||||
|
||||
if trainer.is_fsdp_enabled:
|
||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
||||
|
||||
if training_run_args.use_lora and training_run_args.lora_merge:
|
||||
trainer.save_model() # save lora
|
||||
|
||||
merged_model = model.merge_and_unload(progressbar=True)
|
||||
merged_model_dir = f"./models/{training_run_args.run_name}"
|
||||
merged_model.save_pretrained(merged_model_dir, safe_serialization=True, max_shard_size="2GB")
|
||||
|
||||
tokenizer.save_pretrained(merged_model_dir)
|
||||
else:
|
||||
trainer.save_model()
|
||||
tokenizer.save_pretrained(model_dir)
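# Editor's sketch (assumption): the directory written above (merged_model_dir or model_dir)
# can be reloaded for inference with the standard API:
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   reloaded = AutoModelForCausalLM.from_pretrained("./models/<run_name>")
#   reloaded_tokenizer = AutoTokenizer.from_pretrained("./models/<run_name>")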
|
||||
|
||||
if training_run_args.sync_to_bucket:
|
||||
import boto3
|
||||
s3_client = boto3.client('s3')
|
||||
|
||||
for root, dirs, files in os.walk(model_dir):
|
||||
for file in files:
|
||||
local_path = os.path.join(root, file)
|
||||
s3_path = os.path.join(training_run_args.run_name, os.path.relpath(local_path, start="."))
|
||||
s3_client.upload_file(local_path, training_run_args.sync_to_bucket, s3_path)
|
||||
print(f"Uploaded {local_path} to s3://{training_run_args.sync_to_bucket}/{s3_path}")
|
||||
|
||||
except Exception as ex:
|
||||
if trainer.is_fsdp_enabled:
|
||||
raise ex # this doesn't play nice with FSDP so don't even try
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
if input("Something bad happened! Try and save it? (Y/n) ").lower().startswith("y"):
|
||||
trainer._save_checkpoint(model, None)
|
||||
print("Saved Checkpoint!")
|
||||
|
||||
exit(-1)
|
||||
|
||||
data_collator = DataCollatorForSupervisedFineTuning(
|
||||
tokenizer=tokenizer,
|
||||
prefix_ids=provided_prefix_ids,
|
||||
suffix_ids=provided_suffix_ids,
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser([TrainingRunArguments])
|
||||
training_run_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||
|
||||
trainer = CustomSFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_train_dataset,
|
||||
eval_dataset=tokenized_test_dataset,
|
||||
data_collator=data_collator,
|
||||
callbacks=training_callbacks,
|
||||
)
|
||||
else:
|
||||
from trl import DPOTrainer
|
||||
max_prompt_length = 0
|
||||
|
||||
train_dataset = datasets["train"].map(lambda x: { "prompt_len": len(x["system"]) })
|
||||
|
||||
test_dataset = None
|
||||
if training_run_args.test_dataset:
|
||||
test_dataset = datasets["test"]
|
||||
|
||||
max_prompt_length = max(train_dataset["prompt_len"])
|
||||
|
||||
print("Templating DPO Examples...")
|
||||
templated_test_dataset = None
|
||||
templated_train_dataset = train_dataset.map(template_dpo_example, batched=True).remove_columns(["system", "question"])
|
||||
if training_run_args.test_dataset:
|
||||
templated_test_dataset = datasets["test"].map(template_dpo_example, batched=True).remove_columns(["system", "question"])
|
||||
|
||||
# tokenizer.model_input_names = [ "chosen_input_ids" ]
|
||||
|
||||
# group_by_length doesn't work here
|
||||
# templated_train_dataset = templated_train_dataset.sort("prompt_len", reverse=True)
|
||||
|
||||
training_args.length_column_name = "prompt_len"
|
||||
model.enable_input_require_grads()
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model,
|
||||
ref_model=None,
|
||||
# ref_model=original_model,
|
||||
peft_config=peft_config,
|
||||
args=training_args,
|
||||
beta=training_run_args.beta,
|
||||
loss_type=training_run_args.dpo_loss,
|
||||
train_dataset=templated_train_dataset,
|
||||
eval_dataset=templated_test_dataset,
|
||||
tokenizer=tokenizer,
|
||||
max_length=training_run_args.ctx_size,
|
||||
max_prompt_length=max_prompt_length,
|
||||
truncation_mode="keep_start",
|
||||
callbacks=training_callbacks,
|
||||
)
|
||||
|
||||
try:
|
||||
checkpoint = training_run_args.resume_from_checkpoint
|
||||
if checkpoint:
|
||||
trainer.train(checkpoint)
|
||||
else:
|
||||
trainer.train()
|
||||
|
||||
if training_run_args.test_dataset:
|
||||
trainer.evaluate_all()
|
||||
|
||||
if trainer.is_fsdp_enabled:
|
||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
||||
|
||||
if training_run_args.use_lora and training_run_args.lora_merge:
|
||||
trainer.save_model() # save lora
|
||||
|
||||
merged_model = model.merge_and_unload(progressbar=True)
|
||||
merged_model_dir = f"./models/{training_run_args.run_name}"
|
||||
merged_model.save_pretrained(merged_model_dir, safe_serialization=True, max_shard_size="2GB")
|
||||
|
||||
tokenizer.save_pretrained(merged_model_dir)
|
||||
else:
|
||||
trainer.save_model()
|
||||
tokenizer.save_pretrained(model_dir)
|
||||
|
||||
except Exception as ex:
|
||||
if trainer.is_fsdp_enabled:
|
||||
raise ex # this doesn't play nice with FSDP so don't even try
|
||||
|
||||
print("Something bad happened! Try and save it?")
|
||||
import code, traceback
|
||||
traceback.print_exc()
|
||||
code.interact(local=locals())
|
||||
do_training_run(training_run_args)