mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(mm): wip port of main models to new api
This commit is contained in:
@@ -1164,7 +1164,9 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
||||
raise NotAMatch(cls, "unable to determine base type")
|
||||
|
||||
@classmethod
|
||||
def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> SchedulerPredictionType:
|
||||
def _get_scheduler_prediction_type_or_raise(
|
||||
cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases
|
||||
) -> SchedulerPredictionType:
|
||||
if base not in {
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
@@ -1186,7 +1188,7 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
||||
raise NotAMatch(cls, f"unrecognized scheduler prediction type {prediction_type}")
|
||||
|
||||
@classmethod
|
||||
def _get_variant_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> ModelVariantType:
|
||||
def _get_variant_or_raise(cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases) -> ModelVariantType:
|
||||
if base not in {
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
@@ -1197,20 +1199,29 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
||||
unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json")
|
||||
in_channels = unet_config.get("in_channels")
|
||||
|
||||
match in_channels:
|
||||
case 4:
|
||||
return ModelVariantType.Normal
|
||||
case 5:
|
||||
if base is not BaseModelType.StableDiffusion2:
|
||||
raise NotAMatch(cls, "in_channels=5 is only valid for Stable Diffusion 2 models")
|
||||
return ModelVariantType.Depth
|
||||
case 9:
|
||||
return ModelVariantType.Inpaint
|
||||
case _:
|
||||
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels}")
|
||||
if base is BaseModelType.StableDiffusion2:
|
||||
match in_channels:
|
||||
case 4:
|
||||
return ModelVariantType.Normal
|
||||
case 9:
|
||||
return ModelVariantType.Inpaint
|
||||
case 5:
|
||||
return ModelVariantType.Depth
|
||||
case _:
|
||||
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
|
||||
else:
|
||||
match in_channels:
|
||||
case 4:
|
||||
return ModelVariantType.Normal
|
||||
case 9:
|
||||
return ModelVariantType.Inpaint
|
||||
case _:
|
||||
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
|
||||
|
||||
@classmethod
|
||||
def _get_submodels_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> dict[SubModelType, SubmodelDefinition]:
|
||||
def _get_submodels_or_raise(
|
||||
cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases
|
||||
) -> dict[SubModelType, SubmodelDefinition]:
|
||||
if base is not BaseModelType.StableDiffusion3:
|
||||
raise ValueError(f"Attempted to get submodels for non-SD3 model base '{base}'")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user