mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(mm): add flag for updating models to allow class changes
This commit is contained in:
@@ -41,6 +41,7 @@ from invokeai.backend.model_manager.configs.factory import (
|
||||
AnyModelConfig,
|
||||
ModelConfigFactory,
|
||||
)
|
||||
from invokeai.backend.model_manager.configs.unknown import Unknown_Config
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
HuggingFaceMetadataFetch,
|
||||
@@ -608,7 +609,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
)
|
||||
|
||||
if result.config is None:
|
||||
raise InvalidModelConfigException(f"Could not identify model type for {model_path}")
|
||||
self._logger.error(f"Could not identify model for {model_path}, detailed results: {result.details}")
|
||||
raise InvalidModelConfigException(f"Could not identify model for {model_path}")
|
||||
elif isinstance(result.config, Unknown_Config):
|
||||
self._logger.error(f"Could not identify model for {model_path}, detailed results: {result.details}")
|
||||
|
||||
return result.config
|
||||
|
||||
|
||||
@@ -127,12 +127,14 @@ class ModelRecordServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
def update_model(self, key: str, changes: ModelRecordChanges, allow_class_change: bool = False) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
:param key: Unique key for the model to be updated.
|
||||
:param changes: A set of changes to apply to this model. Changes are validated before being written.
|
||||
:param allow_class_change: If True, allows changes that would change the model config class. For example,
|
||||
changing a LoRA into a Main model. This does not disable validation, so the changes must still be valid.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -134,30 +134,36 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
if cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
def update_model(self, key: str, changes: ModelRecordChanges, allow_class_change: bool = False) -> AnyModelConfig:
|
||||
with self._db.transaction() as cursor:
|
||||
record = self.get_model(key)
|
||||
|
||||
# The changes may mean the model config class changes. So we need to:
|
||||
#
|
||||
# 1. convert the existing record to a dict
|
||||
# 2. apply the changes to the dict
|
||||
# 3. create a new model config from the updated dict
|
||||
#
|
||||
# This way we ensure that the update does not inadvertently create an invalid model config.
|
||||
if allow_class_change:
|
||||
# The changes may cause the model config class to change. To handle this, we need to construct the new
|
||||
# class from scratch rather than trying to modify the existing instance in place.
|
||||
#
|
||||
# 1. Convert the existing record to a dict
|
||||
# 2. Apply the changes to the dict
|
||||
# 3. Attempt to create a new model config from the updated dict
|
||||
|
||||
# 1. convert the existing record to a dict
|
||||
record_as_dict = record.model_dump()
|
||||
# 1. Convert the existing record to a dict
|
||||
record_as_dict = record.model_dump()
|
||||
|
||||
# 2. apply the changes to the dict
|
||||
for field_name in changes.model_fields_set:
|
||||
record_as_dict[field_name] = getattr(changes, field_name)
|
||||
# 2. Apply the changes to the dict
|
||||
for field_name in changes.model_fields_set:
|
||||
record_as_dict[field_name] = getattr(changes, field_name)
|
||||
|
||||
# 3. create a new model config from the updated dict
|
||||
record = ModelConfigFactory.from_dict(record_as_dict)
|
||||
# 3. Attempt to create a new model config from the updated dict
|
||||
record = ModelConfigFactory.from_dict(record_as_dict)
|
||||
|
||||
# If we get this far, the updated model config is valid, so we can save it to the database.
|
||||
json_serialized = record.model_dump_json()
|
||||
# If we get this far, the updated model config is valid, so we can save it to the database.
|
||||
json_serialized = record.model_dump_json()
|
||||
else:
|
||||
# We are not allowing the model config class to change, so we can just update the existing instance in
|
||||
# place. If the changes are invalid for the existing class, an exception will be raised by pydantic.
|
||||
for field_name in changes.model_fields_set:
|
||||
setattr(record, field_name, getattr(changes, field_name))
|
||||
json_serialized = record.model_dump_json()
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
|
||||
Reference in New Issue
Block a user