feat(mm): add flag for updating models to allow class changes

This commit is contained in:
psychedelicious
2025-10-10 12:02:59 +11:00
parent d81a55401a
commit adc332b9e3
3 changed files with 31 additions and 19 deletions

View File

@@ -41,6 +41,7 @@ from invokeai.backend.model_manager.configs.factory import (
AnyModelConfig,
ModelConfigFactory,
)
from invokeai.backend.model_manager.configs.unknown import Unknown_Config
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
HuggingFaceMetadataFetch,
@@ -608,7 +609,10 @@ class ModelInstallService(ModelInstallServiceBase):
)
if result.config is None:
raise InvalidModelConfigException(f"Could not identify model type for {model_path}")
self._logger.error(f"Could not identify model for {model_path}, detailed results: {result.details}")
raise InvalidModelConfigException(f"Could not identify model for {model_path}")
elif isinstance(result.config, Unknown_Config):
self._logger.error(f"Could not identify model for {model_path}, detailed results: {result.details}")
return result.config

View File

@@ -127,12 +127,14 @@ class ModelRecordServiceBase(ABC):
pass
@abstractmethod
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
def update_model(self, key: str, changes: ModelRecordChanges, allow_class_change: bool = False) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated.
:param changes: A set of changes to apply to this model. Changes are validated before being written.
:param allow_class_change: If True, allows changes that would change the model config class. For example,
changing a LoRA into a Main model. This does not disable validation, so the changes must still be valid.
"""
pass

View File

@@ -134,30 +134,36 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
def update_model(self, key: str, changes: ModelRecordChanges, allow_class_change: bool = False) -> AnyModelConfig:
with self._db.transaction() as cursor:
record = self.get_model(key)
# The changes may mean the model config class changes. So we need to:
#
# 1. convert the existing record to a dict
# 2. apply the changes to the dict
# 3. create a new model config from the updated dict
#
# This way we ensure that the update does not inadvertently create an invalid model config.
if allow_class_change:
# The changes may cause the model config class to change. To handle this, we need to construct the new
# class from scratch rather than trying to modify the existing instance in place.
#
# 1. Convert the existing record to a dict
# 2. Apply the changes to the dict
# 3. Attempt to create a new model config from the updated dict
# 1. convert the existing record to a dict
record_as_dict = record.model_dump()
# 1. Convert the existing record to a dict
record_as_dict = record.model_dump()
# 2. apply the changes to the dict
for field_name in changes.model_fields_set:
record_as_dict[field_name] = getattr(changes, field_name)
# 2. Apply the changes to the dict
for field_name in changes.model_fields_set:
record_as_dict[field_name] = getattr(changes, field_name)
# 3. create a new model config from the updated dict
record = ModelConfigFactory.from_dict(record_as_dict)
# 3. Attempt to create a new model config from the updated dict
record = ModelConfigFactory.from_dict(record_as_dict)
# If we get this far, the updated model config is valid, so we can save it to the database.
json_serialized = record.model_dump_json()
# If we get this far, the updated model config is valid, so we can save it to the database.
json_serialized = record.model_dump_json()
else:
# We are not allowing the model config class to change, so we can just update the existing instance in
# place. If the changes are invalid for the existing class, an exception will be raised by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
json_serialized = record.model_dump_json()
cursor.execute(
"""--sql