feat(mm): add model config schema migration logic

This commit is contained in:
psychedelicious
2025-10-08 16:02:11 +11:00
parent a2f9e007ac
commit 9e03a39c3c

View File

@@ -8,7 +8,8 @@ from pydantic import ValidationError
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from invokeai.backend.model_manager.configs.factory import AnyModelConfigValidator
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, AnyModelConfigValidator
from invokeai.backend.model_manager.taxonomy import BaseModelType, FluxVariantType, ModelType, SchedulerPredictionType
class NormalizeResult(NamedTuple):
@@ -29,9 +30,8 @@ class Migration22Callback:
for model_id, config_json in rows:
try:
migrated_config_dict = self._migrate_config(config_json)
# Get the model config as a pydantic object
config = AnyModelConfigValidator.validate_python(migrated_config_dict)
# Migrate the config JSON to the latest schema
config = self._parse_and_migrate_config(config_json)
except ValidationError:
# This could happen if the config schema changed in a way that makes old configs invalid. Unlikely
# for users, more likely for devs testing out migration paths.
@@ -71,31 +71,76 @@ class Migration22Callback:
cursor.execute("ROLLBACK TO SAVEPOINT migrate_model")
cursor.execute("RELEASE SAVEPOINT migrate_model")
self._rollback_file_ops(rollback_ops)
continue
raise
cursor.execute("RELEASE SAVEPOINT migrate_model")
self._prune_empty_directories()
def _migrate_config(self, config_json: Any) -> str | None:
config_dict = json.loads(config_json)
def _parse_and_migrate_config(self, config_json: Any) -> AnyModelConfig:
config_dict: dict[str, Any] = json.loads(config_json)
# TODO: migrate fields, review changes to ensure we hit all cases for v6.7.0 to v6.8.0 upgrade.
# In v6.8.0 we made some improvements to the model taxonomy and the model config schemas. There are a changes
# we need to make to old configs to bring them up to date.
# Prior to v6.8.0, we used an awkward combination of `config_path` and `variant` to distinguish between FLUX
# variants.
#
# `config_path` was set to one of:
# - flux-dev
# - flux-dev-fill
# - flux-schnell
#
# `variant` was set to ModelVariantType.Inpaint for FLUX Fill models and ModelVariantType.Normal for all other FLUX
# models.
#
# We now use the `variant` field to directly represent the FLUX variant type, and `config_path` is no longer used.
base = config_dict.get("base")
type = config_dict.get("type")
if base == BaseModelType.Flux.value and type == ModelType.Main.value:
# Prior to v6.8.0, we used an awkward combination of `config_path` and `variant` to distinguish between FLUX
# variants.
#
# `config_path` was set to one of:
# - flux-dev
# - flux-dev-fill
# - flux-schnell
#
# `variant` was set to ModelVariantType.Inpaint for FLUX Fill models and ModelVariantType.Normal for all other FLUX
# models.
#
# We now use the `variant` field to directly represent the FLUX variant type, and `config_path` is no longer used.
return config_dict
# Extract and remove `config_path` if present.
config_path = config_dict.pop("config_path", None)
match config_path:
case "flux-dev":
config_dict["variant"] = FluxVariantType.Dev.value
case "flux-dev-fill":
config_dict["variant"] = FluxVariantType.DevFill.value
case "flux-schnell":
config_dict["variant"] = FluxVariantType.Schnell.value
case _:
# Unknown config_path - default to Dev variant
config_dict["variant"] = FluxVariantType.Dev.value
if (
base
in {
BaseModelType.StableDiffusion1.value,
BaseModelType.StableDiffusion2.value,
BaseModelType.StableDiffusionXL.value,
BaseModelType.StableDiffusionXLRefiner.value,
}
and type == "main"
):
# Prior to v6.8.0, the prediction_type field was optional and would default to Epsilon if not present.
# We now make it explicit and always present. Use the existing value if present, otherwise default to
# Epsilon, matching the probe logic.
#
# It's only on SD1.x, SD2.x, and SDXL main models.
config_dict["prediction_type"] = config_dict.get("prediction_type", SchedulerPredictionType.Epsilon.value)
if type == ModelType.CLIPVision.value:
# Prior to v6.8.0, some CLIP Vision models were associated with a specific base model architecture:
# - CLIP-ViT-bigG-14-laion2B-39B-b160k is the image encoder for SDXL IP Adapter and was associated with SDXL
# - CLIP-ViT-H-14-laion2B-s32B-b79K is the image encoder for SD1.5 IP Adapter and was associated with SD1.5
#
# While this made some sense at the time, it is more correct and flexible to treat CLIP Vision models
# as independent of any specific base model architecture.
config_dict["base"] = BaseModelType.Any.value
migrated_config = AnyModelConfigValidator.validate_python(config_dict)
return migrated_config
def _normalize_model_storage(self, key: str, path_value: str) -> NormalizeResult:
models_dir = self._models_dir