fix: only add prediction type if it exists

This commit is contained in:
dunkeroni
2024-12-31 03:05:24 -05:00
committed by Kent Keirsey
parent 59926c320c
commit ebe1873712

View File

@@ -86,7 +86,7 @@ def get_scheduler(
scheduler_info: ModelIdentifierField,
scheduler_name: str,
seed: int,
unet_config: AnyModelConfig | None = None,
unet_config: AnyModelConfig,
) -> Scheduler:
"""Load a scheduler and apply some scheduler-specific overrides."""
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
@@ -105,7 +105,7 @@ def get_scheduler(
"_backup": scheduler_config,
}
if unet_config is not None:
if hasattr(unet_config, "prediction_type"):
scheduler_config["prediction_type"] = unet_config.prediction_type
# make dpmpp_sde reproducable(seed can be passed only in initializer)