Merge pull request #265 from acon96/release/v0.3.8

Release v0.3.8
Alex O'Connell
2025-04-13 22:21:32 +00:00
committed by GitHub
20 changed files with 690 additions and 456 deletions

View File

@@ -20,33 +20,33 @@ jobs:
matrix:
include:
# ARM variants
- home_assistant_version: "2024.12.3"
- home_assistant_version: "2025.4.1"
arch: "aarch64"
- home_assistant_version: "2024.12.3"
- home_assistant_version: "2025.4.1"
arch: "armhf"
# Base x86
- home_assistant_version: "2024.12.3"
- home_assistant_version: "2025.4.1"
suffix: "-noavx"
arch: "amd64"
extra_defines: "-DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF -DGGML_F16C=OFF"
- home_assistant_version: "2024.12.3"
- home_assistant_version: "2025.4.1"
arch: "i386"
suffix: "-noavx"
extra_defines: "-DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF -DGGML_F16C=OFF"
# AVX2 and AVX512
- home_assistant_version: "2024.12.3"
- home_assistant_version: "2025.4.1"
arch: "amd64"
extra_defines: "-DGGML_AVX=ON -DGGML_AVX2=ON -DGGML_FMA=ON -DGGML_F16C=ON"
- home_assistant_version: "2024.12.3"
- home_assistant_version: "2025.4.1"
arch: "amd64"
suffix: "-avx512"
extra_defines: "-DGGML_AVX512=ON -DGGML_FMA=ON -DGGML_F16C=ON"
- home_assistant_version: "2024.12.3"
- home_assistant_version: "2025.4.1"
arch: "i386"
extra_defines: "-DGGML_AVX=ON -DGGML_AVX2=ON -DGGML_FMA=ON -DGGML_F16C=ON"
- home_assistant_version: "2024.12.3"
- home_assistant_version: "2025.4.1"
arch: "i386"
suffix: "-avx512"
extra_defines: "-DGGML_AVX512=ON -DGGML_FMA=ON -DGGML_F16C=ON"

View File

@@ -5,7 +5,7 @@ This project provides the required "glue" components to control your Home Assist
Please see the [Setup Guide](./docs/Setup.md) for more information on installation.
## Local LLM Conversation Integration
**The latest version of this integration requires Home Assistant 2024.12.3 or newer**
**The latest version of this integration requires Home Assistant 2025.4.1 or newer**
In order to integrate with Home Assistant, we provide a custom component that exposes the locally running LLM as a "conversation agent".
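Once the integration is set up, one quick way to exercise the exposed conversation agent is Home Assistant's REST API. The sketch below is illustrative only and is not part of this change; the instance URL, token, and `agent_id` are placeholder assumptions.
```
# Minimal sketch (not part of this repo): send a sentence to the conversation
# agent over Home Assistant's REST API. URL, token, and agent_id are placeholders.
import requests

HA_URL = "http://homeassistant.local:8123"     # assumed instance URL
TOKEN = "YOUR_LONG_LIVED_ACCESS_TOKEN"         # created in your HA user profile

resp = requests.post(
    f"{HA_URL}/api/conversation/process",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={
        "text": "turn on the kitchen lights",
        "language": "en",
        # agent_id is optional; omit it to use the default conversation agent
        "agent_id": "conversation.local_llm_conversation",
    },
    timeout=30,
)
print(resp.json()["response"]["speech"]["plain"]["speech"])
```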
@@ -150,6 +150,7 @@ In order to facilitate running the project entirely on the system where Home Ass
## Version History
| Version | Description |
|---------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| v0.3.8 | Update llama.cpp, remove think blocks from "thinking" models, fix wheel detection for some Intel CPUs, Fixes for compatibility with latest Home Assistant version (2025.4), other small bug fixes |
| v0.3.7 | Update llama.cpp version to support newer models, Update minimum Home Assistant version to 2024.12.3, Add German In-Context Learning examples, Fix multi-turn use, Fix an issue with webcolors |
| v0.3.6 | Small llama.cpp backend fixes |
| v0.3.5 | Fix for llama.cpp backend installation, Fix for Home LLM v1-3 API parameters, add Polish ICL examples |

View File

@@ -8,19 +8,31 @@ import homeassistant.components.conversation as ha_conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, Platform
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, llm
from homeassistant.util.json import JsonObjectType
import voluptuous as vol
from .const import (
ALLOWED_SERVICE_CALL_ARGUMENTS,
DOMAIN,
HOME_LLM_API_ID,
SERVICE_TOOL_NAME,
SERVICE_TOOL_ALLOWED_SERVICES,
SERVICE_TOOL_ALLOWED_DOMAINS,
CONF_BACKEND_TYPE,
DEFAULT_BACKEND_TYPE,
BACKEND_TYPE_LLAMA_HF,
BACKEND_TYPE_LLAMA_EXISTING,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
BACKEND_TYPE_OLLAMA,
)
from .conversation import LlamaCppAgent, GenericOpenAIAPIAgent, TextGenerationWebuiAgent, \
LlamaCppPythonAPIAgent, OllamaAPIAgent, LocalLLMAgent
type LocalLLMConfigEntry = ConfigEntry[LocalLLMAgent]
_LOGGER = logging.getLogger(__name__)
@@ -29,7 +41,7 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (Platform.CONVERSATION,)
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:
# make sure the API is registered
if not any([x.id == HOME_LLM_API_ID for x in llm.async_get_apis(hass)]):
@@ -37,18 +49,43 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry
def create_agent(backend_type):
agent_cls = None
if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
agent_cls = LlamaCppAgent
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
agent_cls = GenericOpenAIAPIAgent
elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
agent_cls = TextGenerationWebuiAgent
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
agent_cls = LlamaCppPythonAPIAgent
elif backend_type == BACKEND_TYPE_OLLAMA:
agent_cls = OllamaAPIAgent
return agent_cls(hass, entry)
# create the agent in an executor job because the constructor calls `open()`
backend_type = entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
entry.runtime_data = await hass.async_add_executor_job(create_agent, backend_type)
# call load model
await entry.runtime_data._async_load_model(entry)
# forward setup to platform to register the entity
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True
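The comment above reflects the general Home Assistant rule that blocking I/O (such as the `open()` call in the agent constructor) must not run on the event loop, which is why the constructor is wrapped in `hass.async_add_executor_job()`. A generic asyncio sketch of that pattern, independent of this integration:
```python
# Generic sketch of the "run blocking work in an executor" pattern used above.
# It is not part of this change; it only illustrates why the agent constructor
# is created via hass.async_add_executor_job().
import asyncio

def blocking_construct() -> str:
    # stands in for a constructor that calls open() / reads files from disk
    with open(__file__) as f:
        return f.readline().strip()

async def main():
    loop = asyncio.get_running_loop()
    # the event loop keeps serving other tasks while the blocking read happens
    result = await loop.run_in_executor(None, blocking_construct)
    print(result)

asyncio.run(main())
```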
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> bool:
"""Unload Ollama."""
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
return False
hass.data[DOMAIN].pop(entry.entry_id)
return True
async def async_migrate_entry(hass, config_entry: ConfigEntry):
async def async_migrate_entry(hass: HomeAssistant, config_entry: LocalLLMConfigEntry):
"""Migrate old entry."""
_LOGGER.debug("Migrating from version %s", config_entry.version)
@@ -82,13 +119,8 @@ class HassServiceTool(llm.Tool):
vol.Optional('item'): str,
})
ALLOWED_SERVICES: Final[list[str]] = [
"turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover",
"lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item"
]
ALLOWED_DOMAINS: Final[list[str]] = [
"light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script",
]
ALLOWED_SERVICES: Final[list[str]] = SERVICE_TOOL_ALLOWED_SERVICES
ALLOWED_DOMAINS: Final[list[str]] = SERVICE_TOOL_ALLOWED_DOMAINS
async def async_call(
self, hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext

View File

@@ -687,6 +687,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
for key in OPTIONS_OVERRIDES.keys():
if key in model_name:
selected_default_options.update(OPTIONS_OVERRIDES[key])
break
persona = PERSONA_PROMPTS.get(self.selected_language, PERSONA_PROMPTS.get("en"))
current_date = CURRENT_DATE_PROMPT.get(self.selected_language, CURRENT_DATE_PROMPT.get("en"))
@@ -765,15 +766,15 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
config_entry: config_entries.ConfigEntry,
) -> config_entries.OptionsFlow:
"""Create the options flow."""
return OptionsFlow(config_entry)
return OptionsFlow()
class OptionsFlow(config_entries.OptionsFlow):
"""Local LLM config flow options handler."""
def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
"""Initialize options flow."""
self.config_entry = config_entry
@property
def config_entry(self):
return self.hass.config_entries.async_get_entry(self.handler)
async def async_step_init(
self, user_input: dict[str, Any] | None = None

View File

@@ -4,6 +4,8 @@ import types, os
DOMAIN = "llama_conversation"
HOME_LLM_API_ID = "home-llm-service-api"
SERVICE_TOOL_NAME = "HassCallService"
SERVICE_TOOL_ALLOWED_SERVICES = ["turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover", "lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item", "set_temperature", "set_humidity", "set_fan_mode", "set_hvac_mode", "set_preset_mode"]
SERVICE_TOOL_ALLOWED_DOMAINS = ["light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script"]
CONF_PROMPT = "prompt"
PERSONA_PROMPTS = {
"en": "You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.",
@@ -13,7 +15,7 @@ PERSONA_PROMPTS = {
"pl": "Jeste\u015b 'Al', pomocnym asystentem AI, kt\u00f3ry kontroluje urz\u0105dzenia w domu. Wykonaj poni\u017csze zadanie zgodnie z instrukcj\u0105 lub odpowiedz na poni\u017csze pytanie, korzystaj\u0105c wy\u0142\u0105cznie z podanych informacji."
}
CURRENT_DATE_PROMPT = {
"en": """The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }}""",
"en": """The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", True, "")) }}""",
"de": """{% set day_name = ["Montag", "Dienstag", "Mittwoch", "Donnerstag", "Freitag", "Samstag", "Sonntag"] %}{% set month_name = ["Januar", "Februar", "März", "April", "Mai", "Juni", "Juli", "August", "September", "Oktober", "November", "Dezember"] %}Die aktuelle Uhrzeit und das aktuelle Datum sind {{ (as_timestamp(now()) | timestamp_custom("%H:%M", local=True)) }} {{ day_name[now().weekday()] }}, {{ now().day }} {{ month_name[now().month -1]}} {{ now().year }}.""",
"fr": """{% set day_name = ["lundi", "mardi", "mercredi", "jeudi", "vendredi", "samedi", "dimanche"] %}{% set month_name = ["janvier", "février", "mars", "avril", "mai", "juin", "juillet", "août", "septembre", "octobre", "novembre", "décembre"] %} L'heure et la date actuelles sont {{ (as_timestamp(now()) | timestamp_custom("%H:%M", local=True)) }} {{ day_name[now().weekday()] }}, {{ now().day }} {{ month_name[now().month -1]}} {{ now().year }}.""",
"es": """{% set day_name = ["lunes", "martes", "miércoles", "jueves", "viernes", "sábado", "domingo"] %}{% set month_name = ["enero", "febrero", "marzo", "abril", "mayo", "junio", "julio", "agosto", "septiembre", "octubre", "noviembre", "diciembre"] %}La hora y fecha actuales son {{ (as_timestamp(now()) | timestamp_custom("%H:%M", local=True)) }} {{ day_name[now().weekday()] }}, {{ now().day }} de {{ month_name[now().month -1]}} de {{ now().year }}.""",
@@ -146,59 +148,69 @@ PROMPT_TEMPLATE_DESCRIPTIONS = {
"user": { "prefix": "<|im_start|>user\n", "suffix": "<|im_end|>" },
"assistant": { "prefix": "<|im_start|>assistant\n", "suffix": "<|im_end|>" },
"tool": { "prefix": "<|im_start|>tool", "suffix": "<|im_end|>" },
"chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
"generation_prompt": "<|im_start|>assistant"
},
PROMPT_TEMPLATE_COMMAND_R: {
"system": { "prefix": "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>", "suffix": "<|END_OF_TURN_TOKEN|>" },
"user": { "prefix": "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "suffix": "<|END_OF_TURN_TOKEN|>" },
"assistant": { "prefix": "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", "suffix": "<|END_OF_TURN_TOKEN|>" },
"chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
"generation_prompt": "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
},
PROMPT_TEMPLATE_ALPACA: {
"system": { "prefix": "", "suffix": "\n" },
"user": { "prefix": "### Instruction:\n", "suffix": "\n" },
"assistant": { "prefix": "### Response:\n", "suffix": "\n" },
"chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
"generation_prompt": "### Response:"
},
PROMPT_TEMPLATE_VICUNA: {
"system": { "prefix": "", "suffix": "\n" },
"user": { "prefix": "USER: ", "suffix": "" },
"assistant": { "prefix": "ASSISTANT: ", "suffix": "</s>" },
"chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
"generation_prompt": "ASSISTANT:"
},
PROMPT_TEMPLATE_NONE: {
"system": { "prefix": "", "suffix": "" },
"user": { "prefix": "", "suffix": "" },
"assistant": { "prefix": "", "suffix": "" },
"chain_of_thought": { "prefix": "", "suffix": ""},
"generation_prompt": ""
},
PROMPT_TEMPLATE_MISTRAL: {
"user": { "prefix": "<s>[INST] ", "suffix": " [/INST] " },
"assistant": { "prefix": "", "suffix": "</s>" },
"chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
"generation_prompt": ""
},
PROMPT_TEMPLATE_ZEPHYR: {
"system": { "prefix": "<|system|>\n", "suffix": "<|endoftext|>" },
"user": { "prefix": "<|user|>\n", "suffix": "<|endoftext|>" },
"assistant": { "prefix": "<|assistant|>\n", "suffix": "<|endoftext|>" },
"chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
"generation_prompt": "<|assistant|>\n"
},
PROMPT_TEMPLATE_ZEPHYR2: {
"system": { "prefix": "<|system|>\n", "suffix": "</s>" },
"user": { "prefix": "<|user|>\n", "suffix": "</s>" },
"assistant": { "prefix": "<|assistant|>\n", "suffix": "</s>" },
"chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
"generation_prompt": "<|assistant|>\n"
},
PROMPT_TEMPLATE_ZEPHYR3: {
"system": { "prefix": "<|system|>\n", "suffix": "<|end|>" },
"user": { "prefix": "<|user|>\n", "suffix": "<|end|>" },
"assistant": { "prefix": "<|assistant|>\n", "suffix": "<|end|>" },
"chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
"generation_prompt": "<|assistant|>\n"
},
PROMPT_TEMPLATE_LLAMA3: {
"system": { "prefix": "<|start_header_id|>system<|end_header_id|>\n\n", "suffix": "<|eot_id|>"},
"user": { "prefix": "<|start_header_id|>user<|end_header_id|>\n\n", "suffix": "<|eot_id|>"},
"assistant": { "prefix": "<|start_header_id|>assistant<|end_header_id|>\n\n", "suffix": "<|eot_id|>"},
"chain_of_thought": { "prefix": "<think>", "suffix": "</think>"},
"generation_prompt": "<|start_header_id|>assistant<|end_header_id|>\n\n"
}
}
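Each prompt template now carries a `chain_of_thought` prefix/suffix so that reasoning blocks emitted by "thinking" models can be stripped before the reply is spoken. A standalone sketch of that stripping step, mirroring the `re.sub()` call added in `conversation.py` (the `template_desc` value below is a trimmed stand-in for one of the entries above):
```python
# Sketch of removing a <think>...</think> block using the template descriptor,
# mirroring the regex added in conversation.py for this release.
import re

template_desc = {"chain_of_thought": {"prefix": "<think>", "suffix": "</think>"}}

raw = "<think>The user wants the light on, so call light.turn_on.</think>Turning on the light."
suffix = template_desc["chain_of_thought"]["suffix"]

# drop everything up to and including the closing tag, across newlines
spoken = re.sub(rf"^.*?{suffix}", "", raw, flags=re.DOTALL)
print(spoken)  # -> "Turning on the light."
```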
@@ -297,6 +309,14 @@ DEFAULT_OPTIONS = types.MappingProxyType(
)
OPTIONS_OVERRIDES = {
"home-llama-3.2": {
CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_LLAMA3,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX,
CONF_TOOL_FORMAT: TOOL_FORMAT_MINIMAL,
CONF_CONTEXT_LENGTH: 131072,
},
"home-3b-v3": {
CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR,
@@ -383,5 +403,5 @@ OPTIONS_OVERRIDES = {
},
}
INTEGRATION_VERSION = "0.3.7"
INTEGRATION_VERSION = "0.3.8"
EMBEDDED_LLAMA_CPP_PYTHON_VERSION = "0.3.5"

View File

@@ -16,7 +16,7 @@ import voluptuous as vol
from typing import Literal, Any, Callable
from homeassistant.components.conversation import ConversationInput, ConversationResult, AbstractConversationAgent, ConversationEntity
from homeassistant.components import assist_pipeline, conversation as ha_conversation
from homeassistant.components import assist_pipeline, conversation as conversation
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
@@ -24,7 +24,7 @@ from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL,
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError, HomeAssistantError
from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm, \
area_registry as ar, device_registry as dr
area_registry as ar, device_registry as dr, chat_session
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.event import async_track_state_change, async_call_later
@@ -121,14 +121,10 @@ from .const import (
TOOL_FORMAT_REDUCED,
TOOL_FORMAT_MINIMAL,
ALLOWED_SERVICE_CALL_ARGUMENTS,
SERVICE_TOOL_ALLOWED_SERVICES,
SERVICE_TOOL_ALLOWED_DOMAINS,
CONF_BACKEND_TYPE,
DEFAULT_BACKEND_TYPE,
BACKEND_TYPE_LLAMA_HF,
BACKEND_TYPE_LLAMA_EXISTING,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
BACKEND_TYPE_OLLAMA,
)
# make type checking work for llama-cpp-python without importing it directly at runtime
@@ -147,7 +143,7 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
hass.data[DOMAIN][entry.entry_id] = entry
# call update handler
agent: LocalLLMAgent = ha_conversation.get_agent_manager(hass).async_get_agent(entry.entry_id)
agent: LocalLLMAgent = entry.runtime_data
await hass.async_add_executor_job(agent._update_options)
return True
@@ -155,42 +151,50 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback) -> bool:
"""Set up Local LLM Conversation from a config entry."""
def create_agent(backend_type):
agent_cls = None
if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
agent_cls = LlamaCppAgent
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
agent_cls = GenericOpenAIAPIAgent
elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
agent_cls = TextGenerationWebuiAgent
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
agent_cls = LlamaCppPythonAPIAgent
elif backend_type == BACKEND_TYPE_OLLAMA:
agent_cls = OllamaAPIAgent
return agent_cls(hass, entry)
# create the agent in an executor job because the constructor calls `open()`
backend_type = entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
agent = await hass.async_add_executor_job(create_agent, backend_type)
# call load model
await agent._async_load_model(entry)
# handle updates to the options
entry.async_on_unload(entry.add_update_listener(update_listener))
async_add_entities([agent])
# register the agent entity
async_add_entities([entry.runtime_data])
return True
def _convert_content(
chat_content: conversation.Content
) -> dict[str, str]:
"""Create tool response content."""
role_name = None
if isinstance(chat_content, conversation.ToolResultContent):
role_name = "tool"
elif isinstance(chat_content, conversation.AssistantContent):
role_name = "assistant"
elif isinstance(chat_content, conversation.UserContent):
role_name = "user"
elif isinstance(chat_content, conversation.SystemContent):
role_name = "system"
else:
raise ValueError(f"Unexpected content type: {type(chat_content)}")
return { "role": role_name, "message": chat_content.content }
def _convert_content_back(
agent_id: str,
message_history_entry: dict[str, str]
) -> conversation.Content:
if message_history_entry["role"] == "tool":
return conversation.ToolResultContent(content=message_history_entry["message"])
if message_history_entry["role"] == "assistant":
return conversation.AssistantContent(agent_id=agent_id, content=message_history_entry["message"])
if message_history_entry["role"] == "user":
return conversation.UserContent(content=message_history_entry["message"])
if message_history_entry["role"] == "system":
return conversation.SystemContent(content=message_history_entry["message"])
class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
"""Base Local LLM conversation agent."""
hass: HomeAssistant
entry_id: str
history: dict[str, list[dict]]
in_context_examples: list[dict]
_attr_has_entity_name = True
@@ -202,7 +206,6 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
self.hass = hass
self.entry_id = entry.entry_id
self.history = {}
self.backend_type = entry.data.get(
CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE
@@ -210,7 +213,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
if self.entry.options.get(CONF_LLM_HASS_API):
self._attr_supported_features = (
ha_conversation.ConversationEntityFeature.CONTROL
conversation.ConversationEntityFeature.CONTROL
)
self.in_context_examples = None
@@ -223,11 +226,11 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
assist_pipeline.async_migrate_engine(
self.hass, "conversation", self.entry.entry_id, self.entity_id
)
ha_conversation.async_set_agent(self.hass, self.entry, self)
conversation.async_set_agent(self.hass, self.entry, self)
async def async_will_remove_from_hass(self) -> None:
"""When entity will be removed from Home Assistant."""
ha_conversation.async_unset_agent(self.hass, self.entry)
conversation.async_unset_agent(self.hass, self.entry)
await super().async_will_remove_from_hass()
def _load_icl_examples(self, filename: str):
@@ -253,7 +256,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
def _update_options(self):
if self.entry.options.get(CONF_LLM_HASS_API):
self._attr_supported_features = (
ha_conversation.ConversationEntityFeature.CONTROL
conversation.ConversationEntityFeature.CONTROL
)
if self.entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
@@ -304,6 +307,19 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
self, user_input: ConversationInput
) -> ConversationResult:
"""Process a sentence."""
with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,
conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
):
return await self._async_handle_message(user_input, chat_log)
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,
chat_log: conversation.ChatLog,
) -> conversation.ConversationResult:
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
prompt_template = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
@@ -323,8 +339,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, there was a problem compiling the service call regex: {err}",
)
return ConversationResult(
response=intent_response, conversation_id=conversation_id
response=intent_response, conversation_id=user_input.conversation_id
)
llm_api: llm.APIInstance | None = None
@@ -338,7 +355,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=ha_conversation.DOMAIN,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
)
)
@@ -353,14 +370,10 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
response=intent_response, conversation_id=user_input.conversation_id
)
if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id
conversation = self.history[conversation_id] if remember_conversation else [self.history[conversation_id][0]]
else:
conversation_id = ulid.ulid()
conversation = []
message_history = [ _convert_content(content) for content in chat_log.content ]
if len(conversation) == 0 or refresh_system_prompt:
# re-generate prompt if necessary
if len(message_history) == 0 or refresh_system_prompt:
try:
message = self._generate_system_prompt(raw_prompt, llm_api)
except TemplateError as err:
@@ -371,24 +384,20 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
f"Sorry, I had a problem with my template: {err}",
)
return ConversationResult(
response=intent_response, conversation_id=conversation_id
response=intent_response, conversation_id=user_input.conversation_id
)
system_prompt = { "role": "system", "message": message }
if len(conversation) == 0:
conversation.append(system_prompt)
if not remember_conversation:
self.history[conversation_id] = conversation
if len(message_history) == 0:
message_history.append(system_prompt)
else:
conversation[0] = system_prompt
conversation.append({"role": "user", "message": user_input.text})
message_history[0] = system_prompt
# generate a response
try:
_LOGGER.debug(conversation)
response = await self._async_generate(conversation)
_LOGGER.debug(message_history)
response = await self._async_generate(message_history)
_LOGGER.debug(response)
except Exception as err:
@@ -400,25 +409,28 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
f"Sorry, there was a problem talking to the backend: {repr(err)}",
)
return ConversationResult(
response=intent_response, conversation_id=conversation_id
response=intent_response, conversation_id=user_input.conversation_id
)
# remove end of text token if it was returned
response = response.replace(template_desc["assistant"]["suffix"], "")
conversation.append({"role": "assistant", "message": response})
# remove think blocks
response = re.sub(rf"^.*?{template_desc["chain_of_thought"]["suffix"]}", "", response, flags=re.DOTALL)
message_history.append({"role": "assistant", "message": response})
if remember_conversation:
if remember_num_interactions and len(conversation) > (remember_num_interactions * 2) + 1:
if remember_num_interactions and len(message_history) > (remember_num_interactions * 2) + 1:
for i in range(0,2):
conversation.pop(1)
self.history[conversation_id] = conversation
message_history.pop(1)
# chat_log.content = [_convert_content_back(user_input.agent_id, message_history_entry) for message_history_entry in message_history ]
if llm_api is None:
# return the output without messing with it if there is no API exposed to the model
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response.strip())
return ConversationResult(
response=intent_response, conversation_id=conversation_id
response=intent_response, conversation_id=user_input.conversation_id
)
tool_response = None
@@ -459,7 +471,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
f"I'm sorry, I didn't produce a correctly formatted tool call! Please see the logs for more info.",
)
return ConversationResult(
response=intent_response, conversation_id=conversation_id
response=intent_response, conversation_id=user_input.conversation_id
)
_LOGGER.info(f"calling tool: {block}")
@@ -503,20 +515,20 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
f"I'm sorry! I encountered an error calling the tool. See the logs for more info.",
)
return ConversationResult(
response=intent_response, conversation_id=conversation_id
response=intent_response, conversation_id=user_input.conversation_id
)
# handle models that generate a function call and wait for the result before providing a response
if self.entry.options.get(CONF_TOOL_MULTI_TURN_CHAT, DEFAULT_TOOL_MULTI_TURN_CHAT) and tool_response is not None:
try:
conversation.append({"role": "tool", "message": json.dumps(tool_response)})
message_history.append({"role": "tool", "message": json.dumps(tool_response)})
except:
conversation.append({"role": "tool", "message": "No tools were used in this response."})
message_history.append({"role": "tool", "message": "No tools were used in this response."})
# generate a response based on the tool result
try:
_LOGGER.debug(conversation)
to_say = await self._async_generate(conversation)
_LOGGER.debug(message_history)
to_say = await self._async_generate(message_history)
_LOGGER.debug(to_say)
except Exception as err:
@@ -528,17 +540,17 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
f"Sorry, there was a problem talking to the backend: {repr(err)}",
)
return ConversationResult(
response=intent_response, conversation_id=conversation_id
response=intent_response, conversation_id=user_input.conversation_id
)
conversation.append({"role": "assistant", "message": response})
conversation.append({"role": "assistant", "message": to_say})
message_history.append({"role": "assistant", "message": response})
message_history.append({"role": "assistant", "message": to_say})
# generate intent response to Home Assistant
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(to_say.strip())
return ConversationResult(
response=intent_response, conversation_id=conversation_id
response=intent_response, conversation_id=user_input.conversation_id
)
def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
@@ -813,6 +825,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
all_services = []
scripts_added = False
for domain in domains:
if domain not in SERVICE_TOOL_ALLOWED_DOMAINS:
continue
# scripts show up as individual services
if domain == "script" and not scripts_added:
all_services.extend([
@@ -825,6 +840,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
continue
for name, service in service_dict.get(domain, {}).items():
if name not in SERVICE_TOOL_ALLOWED_SERVICES:
continue
args = flatten_vol_schema(service.schema)
args_to_expose = set(args).intersection(ALLOWED_SERVICE_CALL_ARGUMENTS)
service_schema = vol.Schema({
@@ -1225,7 +1243,7 @@ class GenericOpenAIAPIAgent(LocalLLMAgent):
def _chat_completion_params(self, conversation: dict) -> (str, dict):
request_params = {}
api_base_path = self.entry.options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
endpoint = f"/{api_base_path}/chat/completions"
request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]
@@ -1234,7 +1252,7 @@ class GenericOpenAIAPIAgent(LocalLLMAgent):
def _completion_params(self, conversation: dict) -> (str, dict):
request_params = {}
api_base_path = self.entry.options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
endpoint = f"/{api_base_path}/completions"
request_params["prompt"] = self._format_prompt(conversation)

View File

@@ -1,7 +1,7 @@
{
"domain": "llama_conversation",
"name": "Local LLM Conversation",
"version": "0.3.7",
"version": "0.3.8",
"codeowners": ["@acon96"],
"config_flow": true,
"dependencies": ["conversation"],

View File

@@ -166,13 +166,16 @@ def install_llama_cpp_python(config_dir: str):
return True
platform_suffix = platform.machine()
# remap other names for architectures to the names we use
if platform_suffix == "arm64":
platform_suffix = "aarch64"
if platform_suffix == "i386" or platform_suffix == "x86_64":
platform_suffix = "amd64"
runtime_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
instruction_extensions_suffix = ""
if platform_suffix == "amd64" or platform_suffix == "i386":
if platform_suffix == "amd64":
instruction_extensions_suffix = "-noavx"
try:

View File

@@ -0,0 +1,9 @@
# types from Home Assistant
homeassistant>=2024.6.1
hassil
home-assistant-intents
# testing requirements
pytest
pytest-asyncio
pytest-homeassistant-custom-component

View File

@@ -0,0 +1,2 @@
huggingface-hub>=0.23.0
webcolors>=24.8.0

View File

@@ -569,7 +569,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
)
question = question_template.replace("<device_name>", chosen_devices[0]["description"])
answer = answer_template.replace("<device_name>", chosen_devices[0]["description"])
answer = [ answer_template.replace("<device_name>", chosen_devices[0]["description"]) ]
else:
question = question_template
answers = []
@@ -583,7 +583,7 @@ def generate_templated_example(template: dict, persona: str, max_devices: int =
)
answers.append(answer.replace(f"<device_name>", chosen_devices[i]["description"]))
answer = []
answer: list[str] = []
for word in and_words:
answer.append(f" {word} ".join(answers))
@@ -1151,7 +1151,7 @@ def load_dataset_piles(language):
# TODO: answer questions about more than one thing in the state list at once
# TODO: add examples for rooms/groups of devices. i.e. "turn off all the lights in the kitchen"
# TODO: add time, weather, and calendar/reminders (next 3 events?)
def main():
def main(args=None):
parser = argparse.ArgumentParser(description="Generate the full dataset from the CSV piles")
parser.add_argument("--sample", action="store_true", help="Set this flag to enable generation of the train dataset.")
parser.add_argument("--test", action="store_true", help="Set this flag to enable generation of the train dataset.")
@@ -1171,7 +1171,7 @@ def main():
dataset_format_group.add_argument('--raw_corpus', action='store_const', const='raw', dest='format')
dataset_format_group.add_argument('--sharegpt', action='store_const', const='sharegpt', dest='format')
args = parser.parse_args()
args = parser.parse_args(args=args)
if not args.sample and not args.train and not args.test and not args.merge and not args.dpo:
parser.print_usage()

data/requirements.txt Normal file
View File

@@ -0,0 +1,6 @@
datasets>=3.2.0
webcolors>=1.13
pandas>=2.2.3
deep-translator>=1.11.4
langcodes>=3.5.0
babel==2.15.0

View File

@@ -6,7 +6,7 @@ This integration allows for full customization of the system prompt using Home A
The default system prompt for non-fine tuned models is:
```
You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.
The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }}
The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", True, "")) }}
Tools: {{ tools | to_json }}
Devices:
{% for device in devices | selectattr('area_id', 'none'): %}

View File

@@ -230,6 +230,22 @@ python3 train.py \
--save_steps 50 --save_total_limit 10 --eval_steps 100 --logging_steps 2
```
#### Llama 3.2 3B Instruct
```
python3 generate_home_assistant_data.py --train --test --large --sharegpt --language english german french spanish
python3 train.py \
--run_name Home-Llama-3.2-3B-rev1 \
--base_model meta-llama/Llama-3.2-3B-Instruct \
--bf16 \
--train_dataset data/home_assistant_train.jsonl \
--test_dataset data/home_assistant_test.jsonl \
--learning_rate 1e-5 --learning_rate_warmup 0.03 --batch_size 64 --epochs 1 \
--micro_batch_size 2 \
--ctx_size 2048 \
--save_steps 200 --save_total_limit 3 --eval_steps 100 --logging_steps 2
```
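The same run can also be launched programmatically with the `TrainingRunArguments`/`do_training_run` entry points introduced alongside `train.ipynb` in this release; a sketch mirroring the CLI invocation above:
```
# Programmatic equivalent of the CLI command above (see train.ipynb).
from train import TrainingRunArguments, do_training_run

do_training_run(TrainingRunArguments(
    run_name="Home-Llama-3.2-3B-rev1",
    base_model="meta-llama/Llama-3.2-3B-Instruct",
    bf16=True,
    train_dataset="data/home_assistant_train.jsonl",
    test_dataset="data/home_assistant_test.jsonl",
    learning_rate=1e-5, learning_rate_warmup=0.03,
    batch_size=64, micro_batch_size=2, epochs=1,
    ctx_size=2048,
    save_steps=200, save_total_limit=3, eval_steps=100, logging_steps=2,
))
```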
### Problems
Training a model is not easy, so we cannot cover every problem you might encounter. Below we list known problems and how to deal with them.

View File

@@ -83,7 +83,8 @@ def generate(model, tokenizer, prompts):
return text
def evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_size, use_icl):
split = trained_tokenizer.apply_chat_template(conversation=[{"role": "assistant", "content": r"%%%%%%%%%%%%%%%%"}], tokenize=False).split( r"%%%%%%%%%%%%%%%%")[0].replace(trained_tokenizer.bos_token, "")
# split = trained_tokenizer.apply_chat_template(conversation=[{"role": "assistant", "content": r"%%%%%%%%%%%%%%%%"}], tokenize=False).split( r"%%%%%%%%%%%%%%%%")[0].replace(trained_tokenizer.bos_token, "")
split = "<|start_header_id|>assistant<|end_header_id|>"
print("Evaluating...")
correct_answers = 0
@@ -138,7 +139,7 @@ def evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_siz
output = generate(trained_model, trained_tokenizer, prompts)
for model_output, expected_response in zip(output, expected_responses):
response = model_output.replace(trained_tokenizer.pad_token, "").replace(trained_tokenizer.eos_token, "").split(split)[1]
response = model_output.replace(trained_tokenizer.pad_token, "").replace(trained_tokenizer.eos_token, "").split(split)[1].strip()
expected_service_calls = []
@@ -279,9 +280,19 @@ def load_model(model_name, is_lora, is_hf, load_in_8bit, checkpoint_name):
padding_side='left',
)
eos_token_id_to_use = trained_model.config.eos_token_id
if len(eos_token_id_to_use) > 0:
eos_token_id_to_use = trained_model.config.eos_token_id[0]
pad_token_id_to_use = trained_model.config.pad_token_id
if not trained_tokenizer.pad_token:
trained_tokenizer.pad_token = trained_tokenizer.eos_token
if len(trained_model.config.eos_token_id) > 0:
pad_token_id_to_use = trained_model.config.eos_token_id[0]
else:
pad_token_id_to_use = trained_model.config.eos_token_id
trained_model.generation_config = GenerationConfig(
max_new_tokens=128,
use_cache=True,
@@ -290,9 +301,9 @@ def load_model(model_name, is_lora, is_hf, load_in_8bit, checkpoint_name):
top_k=40,
top_p=1.0,
repetition_penalty=1.15,
# eos_token_id=trained_model.config.eos_token_id,
eos_token_id=128009,
pad_token_id=trained_model.config.pad_token_id if trained_model.config.pad_token_id else trained_model.config.eos_token_id,
eos_token_id=trained_model.config.eos_token_id,
# eos_token_id=128009,
pad_token_id=pad_token_id_to_use,
)
return trained_model, trained_tokenizer
@@ -350,7 +361,7 @@ def main():
print(f"Evaluation already exists for {output_folder}. Skipping...")
continue
trained_model, trained_tokenizer = load_model(args.model, args.lora, ckpt, False)
trained_model, trained_tokenizer = load_model(args.model, args.lora, False, False, ckpt)
evaluate(output_folder, trained_model, trained_tokenizer, dataset, batch_size, False)

View File

@@ -1,6 +1,6 @@
{
"name": "Local LLM Conversation",
"homeassistant": "2024.12.3",
"homeassistant": "2025.4.1",
"content_in_root": false,
"render_readme": true
}

View File

@@ -1,28 +1,13 @@
# training + dataset requirements
transformers
tensorboard
datasets
peft
bitsandbytes
trl
webcolors
pandas
transformers>=4.48.3
tensorboard>=2.18.0
datasets>=3.2.0
peft>=0.14.0
bitsandbytes>=0.45.2
trl>=0.14.0
webcolors>=1.13
pandas>=2.2.3
# flash-attn
sentencepiece
deep-translator
langcodes
babel==2.15.0
# integration requirements
huggingface-hub>=0.23.0
webcolors>=24.8.0
# types from Home Assistant
homeassistant>=2024.6.1
hassil
home-assistant-intents
# testing requirements
pytest
pytest-asyncio
pytest-homeassistant-custom-component
sentencepiece>=0.2.0
deep-translator>=1.11.4
langcodes>=3.5.0
babel>=2.15.0

scripts/convert_and_quantize.sh Normal file → Executable file
View File

@@ -11,8 +11,7 @@ fi
echo "Converting to GGUF..."
if [ ! -f "./models/$MODEL_NAME/$MODEL_NAME.f16.gguf" ]; then
$LLAMA_CPP/convert.py --outfile ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf --outtype f16 ./models/$MODEL_NAME/
# $LLAMA_CPP/convert-hf-to-gguf.py --outfile ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf --outtype f16 ./models/$MODEL_NAME/
$LLAMA_CPP/convert_hf_to_gguf.py --outfile ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf --outtype f16 ./models/$MODEL_NAME/
else
echo "Converted model for already exists. Skipping..."
fi
@@ -23,7 +22,7 @@ for QUANT in "${DESIRED_QUANTS[@]}"
do
QUANT_LOWER=$(echo "$QUANT" | awk '{print tolower($0)}')
if [ ! -f "./models/$MODEL_NAME/$MODEL_NAME.$QUANT_LOWER.gguf" ]; then
$LLAMA_CPP/build/bin/quantize ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf ./models/$MODEL_NAME/$MODEL_NAME.$QUANT_LOWER.gguf $QUANT
$LLAMA_CPP/build/bin/llama-quantize ./models/$MODEL_NAME/$MODEL_NAME.f16.gguf ./models/$MODEL_NAME/$MODEL_NAME.$QUANT_LOWER.gguf $QUANT
else
echo "Quantized model for '$QUANT' already exists. Skipping..."
fi

train.ipynb Normal file
View File

@@ -0,0 +1,111 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "469a9a97-0f6b-475f-8aef-a796c1c5244f",
"metadata": {},
"outputs": [],
"source": [
"%pip install -r requirements.txt\n",
"\n",
"import os, re\n",
"from train import TrainingRunArguments, do_training_run\n",
"\n",
"def get_next_run_name(model):\n",
" pattern = re.compile(model + r\"-rev(\\d+)$\")\n",
" max_rev = 0\n",
"\n",
" for folder in os.listdir(\"models/\"):\n",
" match = pattern.search(folder)\n",
" if match:\n",
" max_rev = max(max_rev, int(match.group(1)))\n",
"\n",
" return f\"{model}-rev{max_rev + 1}\"\n",
"\n",
"os.environ[\"HF_HOME\"] = \"/workspace/\""
]
},
{
"cell_type": "markdown",
"id": "ed0807bf",
"metadata": {},
"source": [
"## Generate Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aaafce74",
"metadata": {},
"outputs": [],
"source": [
"%pip install -r data/requirements.txt\n",
"from data.generate_home_assistant_data import main as generate_data\n",
"\n",
"generate_data([\"--train\", \"--test\", \"--large\", \"--sharegpt\", \"--language\", \"english\", \"german\", \"french\", \"spanish\"])"
]
},
{
"cell_type": "markdown",
"id": "ff011772",
"metadata": {},
"source": [
"## Llama 3.2 1B"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48839ce2-1939-4d7f-817c-97b047bafd42",
"metadata": {},
"outputs": [],
"source": [
"# python3 train.py \\\n",
"# --run_name Home-Llama-3.2-1B-rev1 \\\n",
"# --base_model meta-llama/Llama-3.2-1B-Instruct \\\n",
"# --bf16 \\\n",
"# --train_dataset data/home_assistant_train.jsonl \\\n",
"# --test_dataset data/home_assistant_test.jsonl \\\n",
"# --learning_rate 2e-5 --learning_rate_warmup 0.03 --batch_size 64 --epochs 1 \\\n",
"# --micro_batch_size 2 \\\n",
"# --ctx_size 2048 \\\n",
"# --save_steps 200 --save_total_limit 1 --eval_steps 200 --logging_steps 2\n",
"\n",
"do_training_run(TrainingRunArguments(\n",
" run_name=get_next_run_name(\"Home-Llama-3.2-1B\"),\n",
" base_model=\"meta-llama/Llama-3.2-1B-Instruct\",\n",
" bf16=True,\n",
" train_dataset=\"data/home_assistant_train.jsonl\",\n",
" test_dataset=\"data/home_assistant_test.jsonl\",\n",
" learning_rate=2e-5, learning_rate_warmup=0.03, \n",
" batch_size=64, micro_batch_size=2, epochs=1,\n",
" ctx_size=2048,\n",
" save_steps=200, save_total_limit=1, eval_steps=200, logging_steps=2,\n",
"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

train.py
View File

@@ -6,17 +6,17 @@ import torch
import os
import random
import time
import shutil
import traceback
from torch.utils.data import SequentialSampler, Subset, RandomSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, \
PreTrainedTokenizerFast, HfArgumentParser, GPTQConfig, AutoConfig, TrainerCallback, BitsAndBytesConfig
from transformers.trainer_utils import EvalPrediction
HfArgumentParser, GPTQConfig, AutoConfig, TrainerCallback, BitsAndBytesConfig
from transformers.integrations.integration_utils import TensorBoardCallback
from datasets import load_dataset, Dataset
from datasets import load_dataset
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, Sized, Iterator
from typing import Dict, Optional, Sequence
IS_DDP_ENABLED = "LOCAL_RANK" in os.environ
MULTI_GPU_WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1"))
MULTI_GPU_RANK = int(os.environ.get("RANK", "0"))
IS_MULTI_GPU = os.environ.get("RANK") != None
@@ -28,22 +28,25 @@ class TrainingRunArguments:
base_model: str = field(metadata={"help": "The base model to load for fine-tuning"})
train_dataset: str = field(metadata={"help": "The JSON file containing the training dataset"})
test_dataset: str = field(default=None, metadata={"help": "The JSON file containing the evaluation dataset"})
dataset_processing_threads: int = field(default=None, metadata={"help": "The number of threads to use to tokenize the dataset"})
ctx_size: int = field(default=2048, metadata={"help": "The number of tokens to pad & truncate the input examples to"})
bf16: bool = field(default=False, metadata={"help": "If set, the model will be loaded and trained in bf16 instead of fp16"})
bf16: bool = field(default=False, metadata={"help": "If set, the model will be loaded and trained in bf16 instead of fp32"})
batch_size: int = field(default=8, metadata={"help": "The simulated 'batch size' that we will train on. will tweak gradient accumulations steps"})
micro_batch_size: int = field(default=2, metadata={"help": "The actual batch size that will fit into VRAM on this machine"})
epochs: int = field(default=1, metadata={"help": "The number of times to train the model on each example"})
learning_rate: float = field(default=1e-5, metadata={"help": "The starting learning rate (speed at which the model trains)"})
learning_rate_schedule: str = field(default="cosine", metadata={"help": "How fast the learning rate is reduced during training"})
learning_rate_warmup: float = field(default=0.0, metadata={"help": "The fraction of training steps used to warm up the learning rate"})
weight_decay: float = field(default=0.1, metadata={"help": ""})
gradient_clip: float = field(default=1.0, metadata={"help": ""})
weight_decay: float = field(default=0.1, metadata={"help": "Weight Decay rate for regularization. Rate to reduce all neuron weights towards zero."})
# dropout: float = field(default=0.01, metadata={"help": "Dropout percent for regularization. Determines the fraction of neurons randomly deactivated during training."})
gradient_clip: float = field(default=1.0, metadata={"help": "Maximum gradient norm for clipping to prevent exploding gradients during training."})
resume_from_checkpoint: str = field(default="", metadata={"help": "The name of the checkpoint to resume training from"})
eval_steps: int = field(default=200, metadata={"help": "The number of steps in between evaluations of the model; set to -1 to evaluate every epoch"})
save_steps: int = field(default=-1, metadata={"help": "The number of steps in between model checkpoints; set to -1 to save every epoch"})
save_total_limit: int = field(default=1, metadata={"help": "The number of recent checkpoints of the model to save (not including the final model)"})
logging_steps: int = field(default=5, metadata={"help": "Sets the number of steps in between log output for the training run"})
group_by_length: bool = field(default=False, metadata={"help": "If enabled, the training data will be grouped by length to optimize use of padding"})
group_by_length: bool = field(default=False, metadata={"help": "If enabled, the training data will be grouped by length to optimize use of padding. Runs from longest to shortest examples."})
gradient_checkpointing: bool = field(default=False, metadata={"help": "Enables gradient checkpointing to saves VRAM at the cost of re-computing activations during the backwards pass"})
pre_allocate_cuda_buffers: bool = field(default=True, metadata={"help": "If enabled, runs a forward and backward pass on the model before training to force pytorch to allocate the correct size CUDA buffers up front"})
# Quantization
@@ -60,21 +63,23 @@ class TrainingRunArguments:
lora_modules_to_save: str = field(default=None, metadata={"help": "Additional modules to save"})
lora_merge: bool = field(default=False, metadata={"help": "If set, the LoRA will be merged back into the base model and saved"})
# dpo config
dpo: bool = field(default=False, metadata={"help": "If set, performs Direct Preference Optimization instead of Supervised Fine Tuning"})
beta: float = field(default=0.1, metadata={"help": "The implicit reward value used during DPO training"})
dpo_loss: str = field(default="sigmoid", metadata={"help": "The loss type to use during DPO training"})
# token options
add_pad_token: bool = field(default=False, metadata={"help": "If set, a pad token will be added to the tokenizer's vocabulary"})
add_chatml_tokens: bool = field(default=False, metadata={"help": "If set, tokens for the ChatML format will be added specifically"})
add_chatml_prompt_template: bool = field(default=False, metadata={"help": "If set, the ChatML prompt template will be set as the model's Jinja2 template"})
gradient_checkpointing: bool = field(default=False, metadata={"help": "Enables gradient checkpointing which saves quite a lot of VRAM"})
prefix_ids: str = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
suffix_ids: str = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognize response."})
# custom trainer tweaks
sync_to_bucket: str = field(default=None, metadata={"help": "If set, checkpoints will be synced to the s3 bucket specified by this argument"})
bucket_save_limit: int = field(default=None, metadata={"help": "The number of recent checkpoints of the model to save in S3 (not including the final model)"})
flops_baseline: str = field(default=None, metadata={"help": "The baseline flops for the GPUs used for the training run. Outputs MFU"})
prefix_ids:str = field(default=None, metadata={"help": "Determine the prefix tokens that surround the response from the assistant for SFT if model can not correctly recognise response."})
suffix_ids:str = field(default=None, metadata={"help": "Determine the suffix tokens that surround the response from the assistant for SFT if model can not correctly recognise response."})
class UploadToS3Callback(TrainerCallback):
def __init__(self, s3_bucket, s3_prefix, save_total_limit=None):
@@ -85,10 +90,11 @@ class UploadToS3Callback(TrainerCallback):
self.save_total_limit = save_total_limit
def on_save(self, args, state, control, **kwargs):
output_dir = kwargs['output_dir']
checkpoint = os.path.basename(output_dir)
# Upload current checkpoint
checkpoint = f"checkpoint-{state.global_step}"
output_dir = f"{args.output_dir}/{checkpoint}"
for root, dirs, files in os.walk(output_dir):
for file in files:
local_path = os.path.join(root, file)
@@ -96,7 +102,7 @@ class UploadToS3Callback(TrainerCallback):
self.s3_client.upload_file(local_path, self.s3_bucket, s3_path)
print(f"Uploaded {local_path} to s3://{self.s3_bucket}/{s3_path}")
# Manage checkpoints in S3
# Delete prior checkpoints from S3
if self.save_total_limit:
s3_checkpoints = self.list_s3_checkpoints()
if len(s3_checkpoints) > self.save_total_limit:
@@ -105,18 +111,9 @@ class UploadToS3Callback(TrainerCallback):
for checkpoint in to_delete:
self.delete_checkpoint_from_s3(checkpoint)
# Clean local checkpoints, keeping only the most recent
all_checkpoints = [os.path.join(args.output_dir, d) for d in os.listdir(args.output_dir) if os.path.isdir(os.path.join(args.output_dir, d))]
if all_checkpoints:
latest_checkpoint = max(all_checkpoints, key=os.path.getmtime)
for checkpoint_dir in all_checkpoints:
if checkpoint_dir != latest_checkpoint:
shutil.rmtree(checkpoint_dir)
print(f"Deleted local checkpoint {checkpoint_dir}")
def list_s3_checkpoints(self):
paginator = self.s3_client.get_paginator('list_objects_v2')
page_iterator = paginator.paginate(Bucket=self.s3_bucket, Prefix=self.s3_prefix, Delimiter='/')
page_iterator = paginator.paginate(Bucket=self.s3_bucket, Prefix=self.s3_prefix + '/', Delimiter='/')
return [prefix.get('Prefix').rstrip('/').split('/')[-1] for page in page_iterator for prefix in page.get('CommonPrefixes', [])]
def delete_checkpoint_from_s3(self, checkpoint_name):
@@ -145,153 +142,25 @@ class MFUCallback(TrainerCallback):
self.start_time = current_time
self.last_total_flos = state.total_flos
parser = HfArgumentParser([TrainingRunArguments])
training_run_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
if sum([training_run_args.load_in_8bit, training_run_args.load_in_4bit, training_run_args.load_as_gptq]) > 1:
raise Exception("Please select exactly one of 'load_in_8bit', 'load_in_4bit', or 'load_as_gptq")
if IS_MASTER_PROCESS:
print(f"Loading model '{training_run_args.base_model}'...")
model_kwargs = {}
if training_run_args.load_in_8bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif training_run_args.load_in_4bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
elif training_run_args.load_as_gptq:
model_kwargs["quantization_config"] = GPTQConfig(bits=4, disable_exllama=True)
if training_run_args.bf16:
model_kwargs["torch_dtype"] = torch.bfloat16
else:
model_kwargs["torch_dtype"] = torch.float16
# model_kwargs["resid_pdrop"] = 0.0
# model_kwargs["revision"] = "accfee56d8988cae60915486310362db5831b1bd"
model_kwargs["use_cache"] = False
def ddp_print(*args, **kwargs):
if not IS_DDP_ENABLED or IS_MASTER_PROCESS:
print(*args, **kwargs)
def find_max_vram(min_buffer_mib=800):
max_memory = {}
for i in range(torch.cuda.device_count()):
total_mem = (torch.cuda.get_device_properties(i).total_memory / (1024 * 1024))
suggestion = round((total_mem - 1000) / 1000) * 1000
suggestion = min(suggestion, total_mem - min_buffer_mib)
gpu_properties = torch.cuda.get_device_properties(i)
total_memory_mib = (gpu_properties.total_memory / (1000 * 1000))
suggestion = max(total_memory_mib - 1000, min_buffer_mib)
if IS_MASTER_PROCESS:
print(f"Model will target using {suggestion}MiB of VRAM on GPU {i}")
ddp_print(f"GPU {i}: {gpu_properties.name}, Total Memory: {gpu_properties.total_memory / (1024**3):.2f} GB")
ddp_print(f"Model will target using {suggestion}MiB of VRAM on GPU {i}")
max_memory[i] = f'{suggestion}MiB'
return max_memory
if "LOCAL_RANK" not in os.environ:
model_kwargs["device_map"] = "auto"
model = AutoModelForCausalLM.from_pretrained(
training_run_args.base_model,
trust_remote_code=True,
max_memory=find_max_vram(),
**model_kwargs
)
tokenizer = AutoTokenizer.from_pretrained(training_run_args.base_model, trust_remote_code=True)
if training_run_args.add_pad_token:
tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
model.config.pad_token_id = tokenizer.pad_token_id
if training_run_args.add_chatml_tokens:
tokenizer.add_special_tokens({
'bos_token': '<|im_start|>',
'eos_token': '<|im_end|>'
})
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
if training_run_args.add_chatml_prompt_template:
tokenizer.chat_template = (
"{% for message in messages %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|im_start|>assistant\n' }}"
"{% endif %}"
)
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
if model.get_input_embeddings().num_embeddings < embeddings_len:
model.resize_token_embeddings(embeddings_len)
else:
model.tie_weights()
# model.tie_weights()
original_model = model
peft_config = None
if training_run_args.use_lora:
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
if IS_MASTER_PROCESS:
print("Creating LoRA for model...")
target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=training_run_args.lora_rank,
lora_alpha=training_run_args.lora_alpha,
lora_dropout=training_run_args.lora_dropout,
target_modules=target_modules,
modules_to_save=modules_to_save,
)
if training_run_args.load_in_8bit or training_run_args.load_in_4bit or training_run_args.load_as_gptq:
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=training_run_args.gradient_checkpointing
)
model = get_peft_model(model, peft_config)
model.enable_input_require_grads()
model.print_trainable_parameters()
base_dir = "loras" if training_run_args.use_lora else "models"
model_dir = f"./{base_dir}/{training_run_args.run_name}"
training_kwargs = {}
if training_run_args.test_dataset:
training_kwargs.update({
"per_device_eval_batch_size": training_run_args.micro_batch_size,
"eval_strategy": ("steps" if training_run_args.eval_steps != -1 else "epoch"),
"eval_steps": (training_run_args.eval_steps if training_run_args.eval_steps != -1 else None),
"bf16_full_eval": training_run_args.bf16,
})
training_args = TrainingArguments(
per_device_train_batch_size=training_run_args.micro_batch_size,
gradient_accumulation_steps=training_run_args.batch_size//training_run_args.micro_batch_size,
gradient_checkpointing=training_run_args.gradient_checkpointing,
weight_decay=training_run_args.weight_decay,
max_grad_norm=training_run_args.gradient_clip,
save_strategy=("steps" if training_run_args.save_steps != -1 else "epoch"),
save_steps=(training_run_args.save_steps if training_run_args.save_steps != -1 else None),
save_safetensors=True,
logging_steps=training_run_args.logging_steps,
output_dir=model_dir,
num_train_epochs=training_run_args.epochs,
save_total_limit=training_run_args.save_total_limit,
report_to='none',
learning_rate=training_run_args.learning_rate,
lr_scheduler_type=training_run_args.learning_rate_schedule,
warmup_ratio=training_run_args.learning_rate_warmup,
log_level="info",
bf16=training_run_args.bf16,
group_by_length=training_run_args.group_by_length,
# include_num_input_tokens_seen=True,
**training_kwargs,
)
class DataCollatorForSupervisedFineTuning(object):
"""Collate examples for supervised fine-tuning."""
@@ -413,14 +282,8 @@ class DataCollatorForSupervisedFineTuning(object):
attention_mask=input_ids.ne(self.tokenizer.pad_token_id or self.tokenizer.eos_token_id),
)
if IS_MASTER_PROCESS:
print("Loading dataset...")
data_files = { "train": training_run_args.train_dataset }
if training_run_args.test_dataset:
data_files["test"] = training_run_args.test_dataset
datasets = load_dataset("json", data_files=data_files)
def tokenize_raw_example(batch, tokenizer=None, training_run_args=None):
return tokenizer(
text=batch["text"],
max_length=training_run_args.ctx_size,
@@ -428,7 +291,7 @@ def tokenize_raw_example(batch):
add_special_tokens=False,
)
def tokenize_sharegpt_example(batch, tokenizer=None, training_run_args=None):
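    # Converts each ShareGPT-style "conversations" entry into a single token id sequence.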
# TODO: figure out how to properly batch this
result = []
for example in batch["conversations"]:
@@ -443,7 +306,7 @@ def tokenize_sharegpt_example(batch):
return {"input_ids": result}
def template_dpo_example(batch, tokenizer=None, training_run_args=None):
# TODO: figure out how to properly batch this
result = []
for example in zip(batch["system"], batch["question"]):
@@ -463,24 +326,6 @@ def template_dpo_example(batch):
return {"prompt": result}
training_callbacks = []
if training_run_args.sync_to_bucket:
training_callbacks.append(UploadToS3Callback(
s3_bucket=training_run_args.sync_to_bucket,
s3_prefix=training_run_args.run_name,
save_total_limit=training_run_args.save_total_limit
))
if training_run_args.flops_baseline:
    # A100 GPU bfloat16 peak flops is 312 TFLOPS (312e12)
    # 4090 GPU bfloat16 peak flops is 165.2 TFLOPS (165.2e12)
    # 3090 GPU bfloat16 peak flops is 71 TFLOPS (71e12)
training_callbacks.append(MFUCallback(peak_flops=float(training_run_args.flops_baseline)))
# log to tensorboard (but after MFU)
training_callbacks.append(TensorBoardCallback())
class CustomSFTTrainer(Trainer):
"""Implement different training tweaks"""
@@ -514,7 +359,7 @@ class CustomSFTTrainer(Trainer):
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
"""
        Saw this in the Chinchilla paper. It says not to go over a 25% overshoot.
        Should improve training efficiency by skipping the final fine-tuning phase that doesn't affect accuracy much.
"""
return super().create_scheduler(int(num_training_steps * self.learning_rate_overshoot), optimizer=optimizer)
@@ -523,8 +368,8 @@ class CustomSFTTrainer(Trainer):
examples_length = len(inputs["input_ids"][0])
batch_size = len(inputs["input_ids"])
        # mfu is approximated using throughput and param count
        # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network
# each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param
# there are 3 passes of a NN (fwd, bwd, delta) - we multiply by 3 ie 2 * 3 * n_param
# this gets us FLOPs / token
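        # attention FLOPs are accounted for separately because they grow with the
        # square of the sequence length rather than with the parameter count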
@@ -538,128 +383,303 @@ class CustomSFTTrainer(Trainer):
result = (3 * flops_per_seq + 3 * attn_flops_per_seq) * batch_size
return result
if not training_run_args.dpo:
if IS_MASTER_PROCESS:
print("Tokenizing datasets...")
if "text" in datasets["train"].column_names:
tokenize_function = tokenize_raw_example
columns_to_remove = ["text"]
elif "conversations" in datasets["train"].column_names:
tokenize_function = tokenize_sharegpt_example
columns_to_remove = ["conversations"]
def do_training_run(training_run_args: TrainingRunArguments):
# validate args + build model kwargs
    if sum([training_run_args.load_in_8bit, training_run_args.load_in_4bit, training_run_args.load_as_gptq]) > 1:
        raise Exception("Please select at most one of 'load_in_8bit', 'load_in_4bit', or 'load_as_gptq'")
model_kwargs = {}
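    # map the (mutually exclusive) quantization flags onto the matching transformers
    # quantization_config: bitsandbytes for 8-bit/4-bit loading, GPTQ for pre-quantized weights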
if training_run_args.load_in_8bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif training_run_args.load_in_4bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
elif training_run_args.load_as_gptq:
model_kwargs["quantization_config"] = GPTQConfig(bits=4, disable_exllama=True)
if training_run_args.bf16:
model_kwargs["torch_dtype"] = torch.bfloat16
elif training_run_args.use_lora and "quantization_config" not in model_kwargs:
model_kwargs["torch_dtype"] = torch.float16
    else:
        # auto detect 'best' format with fallback to fp32
        model_kwargs["torch_dtype"] = "auto"
# model_kwargs["resid_pdrop"] = training_run_args.dropout
model_kwargs["use_cache"] = False
if not IS_DDP_ENABLED:
model_kwargs["device_map"] = "auto"
# load the model
ddp_print(f"Loading model '{training_run_args.base_model}'...")
model = AutoModelForCausalLM.from_pretrained(
training_run_args.base_model,
max_memory=find_max_vram(),
token=os.environ.get("HF_TOKEN"),
**model_kwargs
)
tokenizer = AutoTokenizer.from_pretrained(training_run_args.base_model, token=os.environ.get("HF_TOKEN"))
# mess with tokens + prompt template
if training_run_args.add_pad_token:
tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
model.config.pad_token_id = tokenizer.pad_token_id
if training_run_args.add_chatml_tokens:
tokenizer.add_special_tokens({
'bos_token': '<|im_start|>',
'eos_token': '<|im_end|>'
})
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
if training_run_args.add_chatml_prompt_template:
tokenizer.chat_template = (
"{% for message in messages %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|im_start|>assistant\n' }}"
"{% endif %}"
)
# resize embeddings if added tokens require it
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
if model.get_input_embeddings().num_embeddings < embeddings_len:
model.resize_token_embeddings(embeddings_len)
else:
model.tie_weights()
# create LoRA model if config says so
original_model = model
peft_config = None
if training_run_args.use_lora:
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
ddp_print("Creating LoRA for model...")
target_modules = training_run_args.lora_modules.split(",") if training_run_args.lora_modules else None
modules_to_save = training_run_args.lora_modules_to_save.split(",") if training_run_args.lora_modules_to_save else None
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=training_run_args.lora_rank,
lora_alpha=training_run_args.lora_alpha,
lora_dropout=training_run_args.lora_dropout,
target_modules=target_modules,
modules_to_save=modules_to_save,
)
if training_run_args.load_in_8bit or training_run_args.load_in_4bit or training_run_args.load_as_gptq:
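            # quantized base weights can't be trained directly; this freezes them, upcasts
            # the norm/output layers, and optionally enables gradient checkpointing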
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=training_run_args.gradient_checkpointing
)
model = get_peft_model(model, peft_config)
model.enable_input_require_grads()
model.print_trainable_parameters()
base_dir = "loras" if training_run_args.use_lora else "models"
model_dir = f"./{base_dir}/{training_run_args.run_name}"
# set up HuggingFace Trainer args
training_kwargs = {}
    if training_run_args.test_dataset:
training_kwargs.update({
"per_device_eval_batch_size": training_run_args.micro_batch_size,
"eval_strategy": ("steps" if training_run_args.eval_steps != -1 else "epoch"),
"eval_steps": (training_run_args.eval_steps if training_run_args.eval_steps != -1 else None),
"bf16_full_eval": training_run_args.bf16,
})
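    # effective batch size = micro_batch_size * gradient_accumulation_steps (per device)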
training_args = TrainingArguments(
per_device_train_batch_size=training_run_args.micro_batch_size,
gradient_accumulation_steps=training_run_args.batch_size//training_run_args.micro_batch_size,
gradient_checkpointing=training_run_args.gradient_checkpointing,
weight_decay=training_run_args.weight_decay,
max_grad_norm=training_run_args.gradient_clip,
save_strategy=("steps" if training_run_args.save_steps != -1 else "epoch"),
save_steps=(training_run_args.save_steps if training_run_args.save_steps != -1 else None),
save_safetensors=True,
logging_steps=training_run_args.logging_steps,
output_dir=model_dir,
num_train_epochs=training_run_args.epochs,
save_total_limit=training_run_args.save_total_limit,
report_to='none',
learning_rate=training_run_args.learning_rate,
lr_scheduler_type=training_run_args.learning_rate_schedule,
warmup_ratio=training_run_args.learning_rate_warmup,
log_level="info",
bf16=training_run_args.bf16,
group_by_length=training_run_args.group_by_length,
# include_num_input_tokens_seen=True,
**training_kwargs,
)
# set up trainer callbacks
training_callbacks = []
if training_run_args.sync_to_bucket:
training_callbacks.append(UploadToS3Callback(
s3_bucket=training_run_args.sync_to_bucket,
s3_prefix=training_run_args.run_name,
save_total_limit=training_run_args.bucket_save_limit if training_run_args.bucket_save_limit else training_run_args.save_total_limit
))
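    # MFU (model FLOPs utilization) compares achieved training FLOPs against the
    # hardware's theoretical peak, so the baseline below must match the GPU in use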
if training_run_args.flops_baseline:
        # A100 40/80GB GPU bfloat16 peak flops is 312 TFLOPS (312e12)
        # 4090 24GB GPU bfloat16 peak flops is 165.2 TFLOPS (165.2e12)
        # A40 48GB GPU bfloat16 peak flops is 149.7 TFLOPS (149.7e12)
        # 3090 24GB GPU bfloat16 peak flops is 71 TFLOPS (71e12)
training_callbacks.append(MFUCallback(peak_flops=float(training_run_args.flops_baseline)))
# log to tensorboard (but after MFU)
training_callbacks.append(TensorBoardCallback())
if not training_run_args.dpo:
ddp_print("Loading dataset...")
data_files = { "train": training_run_args.train_dataset }
if training_run_args.test_dataset:
data_files["test"] = training_run_args.test_dataset
datasets = load_dataset("json", data_files=data_files)
# prepare the dataset
ddp_print("Tokenizing datasets...")
if "text" in datasets["train"].column_names:
tokenize_function = tokenize_raw_example
columns_to_remove = ["text"]
elif "conversations" in datasets["train"].column_names:
tokenize_function = tokenize_sharegpt_example
columns_to_remove = ["conversations"]
else:
raise Exception("Unknown dataset input format (not raw corpus or sharegpt)")
tokenized_test_dataset = None
num_proc = None
if training_run_args.dataset_processing_threads:
num_proc = training_run_args.dataset_processing_threads // MULTI_GPU_WORLD_SIZE
tokenized_train_dataset = datasets["train"].map(tokenize_function, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)
if training_run_args.test_dataset:
tokenized_test_dataset = datasets["test"].map(tokenize_function, batched=True, num_proc=num_proc, fn_kwargs={"tokenizer": tokenizer, "training_run_args": training_run_args}).remove_columns(columns_to_remove)
example_lengths = [ len(example) for example in tokenized_train_dataset["input_ids"] ]
tokens_in_train_set, longest_example = sum(example_lengths), max(example_lengths)
ddp_print(f"Train dataset has {int(tokens_in_train_set / 1000000)}M tokens. Longest Example: {longest_example} tokens")
provided_prefix_ids = None
provided_suffix_ids = None
try:
if training_run_args.prefix_ids:
provided_prefix_ids = [ int(x) for x in training_run_args.prefix_ids.split(",") ]
if training_run_args.suffix_ids:
provided_suffix_ids = [ int(x) for x in training_run_args.suffix_ids.split(",") ]
except ValueError as ex:
print(f"Error parsing prefix_ids or suffix_ids: '{ex}'")
exit(-1)
trainer = CustomSFTTrainer(
model=model,
args=training_args,
train_dataset=tokenized_train_dataset,
eval_dataset=tokenized_test_dataset,
data_collator=DataCollatorForSupervisedFineTuning(
tokenizer=tokenizer,
prefix_ids=provided_prefix_ids,
suffix_ids=provided_suffix_ids,
),
callbacks=training_callbacks,
)
else:
raise NotImplementedError("DPO Trainer doesn't work yet!")
# from trl import DPOTrainer
# max_prompt_length = 0
# train_dataset = datasets["train"].map(lambda x: { "prompt_len": len(x["system"]) })
# test_dataset = None
# if training_run_args.test_dataset:
# test_dataset = datasets["test"]
# max_prompt_length = max(train_dataset["prompt_len"])
# print("Templating DPO Examples...")
# templated_test_dataset = None
# templated_train_dataset = train_dataset.map(template_dpo_example, batched=True).remove_columns(["system", "question"])
# if training_run_args.test_dataset:
# templated_test_dataset = datasets["test"].map(template_dpo_example, batched=True).remove_columns(["system", "question"])
# # tokenizer.model_input_names = [ "chosen_input_ids" ]
# # group_by_length doesn't work here
# # templated_train_dataset = templated_train_dataset.sort("prompt_len", reverse=True)
# training_args.length_column_name = "prompt_len"
# model.enable_input_require_grads()
# trainer = DPOTrainer(
# model,
# ref_model=None,
# # ref_model=original_model,
# peft_config=peft_config,
# args=training_args,
# beta=training_run_args.beta,
# loss_type=training_run_args.dpo_loss,
# train_dataset=templated_train_dataset,
# eval_dataset=templated_test_dataset,
# tokenizer=tokenizer,
# max_length=training_run_args.ctx_size,
# max_prompt_length=max_prompt_length,
# truncation_mode="keep_start",
# callbacks=training_callbacks,
# )
example_lengths = [ len(example) for example in tokenized_train_dataset["input_ids"] ]
tokens_in_train_set, longest_example = sum(example_lengths), max(example_lengths)
if IS_MASTER_PROCESS:
print(f"Train dataset has {int(tokens_in_train_set / 1000000)}M tokens. Longest Example: {longest_example} tokens")
provided_prefix_ids = None
provided_suffix_ids = None
try:
if training_run_args.prefix_ids:
provided_prefix_ids = [ int(x) for x in training_run_args.prefix_ids.split(",") ]
if training_run_args.suffix_ids:
provided_suffix_ids = [ int(x) for x in training_run_args.suffix_ids.split(",") ]
except ValueError as ex:
print(f"Error parsing prefix_ids or suffix_ids: '{ex}'")
trainer.train(resume_from_checkpoint=training_run_args.resume_from_checkpoint if training_run_args.resume_from_checkpoint else None)
if training_run_args.test_dataset:
trainer.evaluate_all()
if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
if training_run_args.use_lora and training_run_args.lora_merge:
trainer.save_model() # save lora
merged_model = model.merge_and_unload(progressbar=True)
merged_model_dir = f"./models/{training_run_args.run_name}"
merged_model.save_pretrained(merged_model_dir, safe_serialization=True, max_shard_size="2GB")
tokenizer.save_pretrained(merged_model_dir)
else:
trainer.save_model()
tokenizer.save_pretrained(model_dir)
if training_run_args.sync_to_bucket:
import boto3
s3_client = boto3.client('s3')
for root, dirs, files in os.walk(model_dir):
for file in files:
local_path = os.path.join(root, file)
s3_path = os.path.join(training_run_args.run_name, os.path.relpath(local_path, start="."))
s3_client.upload_file(local_path, training_run_args.sync_to_bucket, s3_path)
print(f"Uploaded {local_path} to s3://{training_run_args.sync_to_bucket}/{s3_path}")
except Exception as ex:
if trainer.is_fsdp_enabled:
raise ex # this doesn't play nice with FSDP so don't even try
traceback.print_exc()
if input("Something bad happened! Try and save it? (Y/n) ").lower().startswith("y"):
trainer._save_checkpoint(model, None)
print("Saved Checkpoint!")
exit(-1)
data_collator = DataCollatorForSupervisedFineTuning(
tokenizer=tokenizer,
prefix_ids=provided_prefix_ids,
suffix_ids=provided_suffix_ids,
)
if __name__ == "__main__":
parser = HfArgumentParser([TrainingRunArguments])
training_run_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
trainer = CustomSFTTrainer(
model=model,
args=training_args,
train_dataset=tokenized_train_dataset,
eval_dataset=tokenized_test_dataset,
data_collator=data_collator,
callbacks=training_callbacks,
)
else:
from trl import DPOTrainer
max_prompt_length = 0
train_dataset = datasets["train"].map(lambda x: { "prompt_len": len(x["system"]) })
test_dataset = None
if training_run_args.test_dataset:
test_dataset = datasets["test"]
max_prompt_length = max(train_dataset["prompt_len"])
print("Templating DPO Examples...")
templated_test_dataset = None
templated_train_dataset = train_dataset.map(template_dpo_example, batched=True).remove_columns(["system", "question"])
if training_run_args.test_dataset:
templated_test_dataset = datasets["test"].map(template_dpo_example, batched=True).remove_columns(["system", "question"])
# tokenizer.model_input_names = [ "chosen_input_ids" ]
# group_by_length doesn't work here
# templated_train_dataset = templated_train_dataset.sort("prompt_len", reverse=True)
training_args.length_column_name = "prompt_len"
model.enable_input_require_grads()
trainer = DPOTrainer(
model,
ref_model=None,
# ref_model=original_model,
peft_config=peft_config,
args=training_args,
beta=training_run_args.beta,
loss_type=training_run_args.dpo_loss,
train_dataset=templated_train_dataset,
eval_dataset=templated_test_dataset,
tokenizer=tokenizer,
max_length=training_run_args.ctx_size,
max_prompt_length=max_prompt_length,
truncation_mode="keep_start",
callbacks=training_callbacks,
)
try:
checkpoint = training_run_args.resume_from_checkpoint
if checkpoint:
trainer.train(checkpoint)
else:
trainer.train()
if training_run_args.test_dataset:
trainer.evaluate_all()
if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
if training_run_args.use_lora and training_run_args.lora_merge:
trainer.save_model() # save lora
merged_model = model.merge_and_unload(progressbar=True)
merged_model_dir = f"./models/{training_run_args.run_name}"
merged_model.save_pretrained(merged_model_dir, safe_serialization=True, max_shard_size="2GB")
tokenizer.save_pretrained(merged_model_dir)
else:
trainer.save_model()
tokenizer.save_pretrained(model_dir)
except Exception as ex:
if trainer.is_fsdp_enabled:
raise ex # this doesn't play nice with FSDP so don't even try
print("Something bad happened! Try and save it?")
import code, traceback
traceback.print_exc()
code.interact(local=locals())
do_training_run(training_run_args)
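    # HfArgumentParser exposes each TrainingRunArguments field as a CLI flag, e.g.
    # (illustrative invocation; script name and flag values assumed from the fields above):
    #   python3 train.py --run_name my-run --base_model <hf_model_id> --train_dataset data/train.json --bf16 --use_lora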