feat(mm): wip port of main models to new api

This commit is contained in:
psychedelicious
2025-09-25 23:08:40 +10:00
parent 6f5720904a
commit 24bc4f5047

View File

@@ -1164,7 +1164,9 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
raise NotAMatch(cls, "unable to determine base type")
@classmethod
def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> SchedulerPredictionType:
def _get_scheduler_prediction_type_or_raise(
cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases
) -> SchedulerPredictionType:
if base not in {
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
@@ -1186,7 +1188,7 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
raise NotAMatch(cls, f"unrecognized scheduler prediction type {prediction_type}")
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> ModelVariantType:
def _get_variant_or_raise(cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases) -> ModelVariantType:
if base not in {
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
@@ -1197,20 +1199,29 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json")
in_channels = unet_config.get("in_channels")
match in_channels:
case 4:
return ModelVariantType.Normal
case 5:
if base is not BaseModelType.StableDiffusion2:
raise NotAMatch(cls, "in_channels=5 is only valid for Stable Diffusion 2 models")
return ModelVariantType.Depth
case 9:
return ModelVariantType.Inpaint
case _:
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels}")
if base is BaseModelType.StableDiffusion2:
match in_channels:
case 4:
return ModelVariantType.Normal
case 9:
return ModelVariantType.Inpaint
case 5:
return ModelVariantType.Depth
case _:
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
else:
match in_channels:
case 4:
return ModelVariantType.Normal
case 9:
return ModelVariantType.Inpaint
case _:
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
@classmethod
def _get_submodels_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> dict[SubModelType, SubmodelDefinition]:
def _get_submodels_or_raise(
cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases
) -> dict[SubModelType, SubmodelDefinition]:
if base is not BaseModelType.StableDiffusion3:
raise ValueError(f"Attempted to get submodels for non-SD3 model base '{base}'")