feat(mm): wip port of main models to new api

This commit is contained in:
psychedelicious
2025-09-26 00:11:42 +10:00
parent 24bc4f5047
commit 0192caa90f

View File

@@ -1126,12 +1126,29 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
)
base = fields.get("base") or cls._get_base_or_raise(mod)
if base in {
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
}:
variant = fields.get("variant") or cls._get_variant_or_raise(mod, base)
prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base)
upcast_attention = fields.get("upcast_attention") or cls._get_upcast_attention_or_raise(base, prediction_type)
else:
variant= None
prediction_type = None
upcast_attention = False
return cls(**fields, base=base)
if base is BaseModelType.StableDiffusion3:
submodels = fields.get("submodels") or cls._get_submodels_or_raise(mod, base)
else:
submodels = None
return cls(**fields, base=base, variant=variant, prediction_type=prediction_type, upcast_attention=upcast_attention, submodels=submodels,)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases:
# Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL).
# Handle pipelines with a UNet (i.e SD 1.x, SD2.x, SDXL).
unet_config_path = mod.path / "unet" / "config.json"
if unet_config_path.exists():
with open(unet_config_path) as file:
@@ -1172,7 +1189,7 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
}:
raise ValueError(f"Attempted to get scheduler prediction type for non-UNet model base '{base}'")
raise ValueError(f"Attempted to get scheduler prediction_type for non-UNet model base '{base}'")
scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json")
@@ -1185,7 +1202,7 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
case "epsilon":
return SchedulerPredictionType.Epsilon
case _:
raise NotAMatch(cls, f"unrecognized scheduler prediction type {prediction_type}")
raise NotAMatch(cls, f"unrecognized scheduler prediction_type {prediction_type}")
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases) -> ModelVariantType:
@@ -1266,6 +1283,21 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
return submodels
@classmethod
def _get_upcast_attention_or_raise(cls, base: MainDiffusers_SupportedBases, prediction_type: SchedulerPredictionType) -> bool:
if base not in {
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
}:
raise ValueError(f"Attempted to get upcast_attention flag for non-UNet model base '{base}'")
if base is BaseModelType.StableDiffusion2 and prediction_type is SchedulerPredictionType.VPrediction:
# SD2 v-prediction models need upcast_attention to be True
return True
return False
class IPAdapterConfigBase(ABC, BaseModel):
type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)