Fix for DEIS / DPM clash

This commit is contained in:
David Burnett
2024-12-12 11:48:20 +00:00
committed by Kent Keirsey
parent 607d19f4dd
commit d8da9b45cc

View File

@@ -11,6 +11,8 @@ from diffusers.configuration_utils import ConfigMixin
from diffusers.models.adapter import T2IAdapter
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from diffusers.schedulers.scheduling_tcd import TCDScheduler
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
from PIL import Image
@@ -89,6 +91,7 @@ def get_scheduler(
# possible.
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@@ -104,6 +107,10 @@ def get_scheduler(
if scheduler_class is DPMSolverSDEScheduler:
scheduler_config["noise_sampler_seed"] = seed
if scheduler_class is DPMSolverMultistepScheduler or scheduler_class is DPMSolverSinglestepScheduler:
if scheduler_config['_class_name'] == 'DEISMultistepScheduler' and scheduler_config["algorithm_type"] == 'deis':
scheduler_config["algorithm_type"] = 'dpmsolver++'
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py