diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 50816ec403..922532cfad 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -133,7 +133,8 @@ class T2I: full_precision=False, strength=0.75, # default in scripts/img2img.py embedding_path=None, - latent_diffusion_weights=False, # just to keep track of this parameter when regenerating prompt + # just to keep track of this parameter when regenerating prompt + latent_diffusion_weights=False, device='cuda', gfpgan=None, ): @@ -175,7 +176,8 @@ class T2I: outdir, prompt, kwargs.get('batch_size', self.batch_size) ) for r in results: - metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}' # gets written into the PNG + # gets written into the PNG + metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}' pngwriter.write_image(r[0], r[1]) return pngwriter.files_written @@ -210,6 +212,7 @@ class T2I: strength=None, gfpgan_strength=None, variants=None, + user_sampler=None, **args, ): # eat up additional cruft """ @@ -269,6 +272,10 @@ class T2I: scope = autocast if self.precision == 'autocast' else nullcontext + if user_sampler and (user_sampler != self.sampler_name): + self.sampler_name = user_sampler + self._set_sampler() + tic = time.time() results = list() @@ -305,12 +312,15 @@ class T2I: iter_images = next(images_iterator) for image in iter_images: try: - # if gfpgan strength is none or less than or equal to 0.0 then + # if gfpgan strength is none or less than or equal to 0.0 then # don't even attempt to use GFPGAN. - # if the user specified a value of -G that satisifies the condition and + # if the user specified a value of -G that satisifies the condition and # --gfpgan wasn't specified, at startup then # the net result is a message gets printed - nothing else happens. - if gfpgan_strength is not None and gfpgan_strength > 0.0: + if ( + gfpgan_strength is not None + and gfpgan_strength > 0.0 + ): image = self._run_gfpgan( image, gfpgan_strength ) @@ -499,39 +509,38 @@ class T2I: except AttributeError: raise SystemExit - msg = f'setting sampler to {self.sampler_name}' - if self.sampler_name == 'plms': - self.sampler = PLMSSampler(self.model, device=self.device) - elif self.sampler_name == 'ddim': - self.sampler = DDIMSampler(self.model, device=self.device) - elif self.sampler_name == 'k_dpm_2_a': - self.sampler = KSampler( - self.model, 'dpm_2_ancestral', device=self.device - ) - elif self.sampler_name == 'k_dpm_2': - self.sampler = KSampler( - self.model, 'dpm_2', device=self.device - ) - elif self.sampler_name == 'k_euler_a': - self.sampler = KSampler( - self.model, 'euler_ancestral', device=self.device - ) - elif self.sampler_name == 'k_euler': - self.sampler = KSampler( - self.model, 'euler', device=self.device - ) - elif self.sampler_name == 'k_heun': - self.sampler = KSampler(self.model, 'heun', device=self.device) - elif self.sampler_name == 'k_lms': - self.sampler = KSampler(self.model, 'lms', device=self.device) - else: - msg = f'unsupported sampler {self.sampler_name}, defaulting to plms' - self.sampler = PLMSSampler(self.model, device=self.device) - - print(msg) + self._set_sampler() return self.model + def _set_sampler(self): + msg = f'>> Setting Sampler to {self.sampler_name}' + if self.sampler_name == 'plms': + self.sampler = PLMSSampler(self.model, device=self.device) + elif self.sampler_name == 'ddim': + self.sampler = DDIMSampler(self.model, device=self.device) + elif self.sampler_name == 'k_dpm_2_a': + self.sampler = KSampler( + self.model, 'dpm_2_ancestral', device=self.device + ) + elif self.sampler_name == 'k_dpm_2': + self.sampler = KSampler(self.model, 'dpm_2', device=self.device) + elif self.sampler_name == 'k_euler_a': + self.sampler = KSampler( + self.model, 'euler_ancestral', device=self.device + ) + elif self.sampler_name == 'k_euler': + self.sampler = KSampler(self.model, 'euler', device=self.device) + elif self.sampler_name == 'k_heun': + self.sampler = KSampler(self.model, 'heun', device=self.device) + elif self.sampler_name == 'k_lms': + self.sampler = KSampler(self.model, 'lms', device=self.device) + else: + msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms' + self.sampler = PLMSSampler(self.model, device=self.device) + + print(msg) + def _load_model_from_config(self, config, ckpt): print(f'Loading model from {ckpt}') pl_sd = torch.load(ckpt, map_location='cpu') diff --git a/scripts/dream.py b/scripts/dream.py index 5d5e8db4a5..be72d25655 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -52,7 +52,8 @@ def main(): weights=weights, full_precision=opt.full_precision, config=config, - latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt + # this is solely for recreating the prompt + latent_diffusion_weights=opt.laion400m, embedding_path=opt.embedding_path, device=opt.device, ) @@ -508,6 +509,23 @@ def create_cmd_parser(): action='store_true', help='skip subprompt weight normalization', ) + parser.add_argument( + '-m', + '--user_sampler', + default=None, + type=str, + choices=[ + 'ddim', + 'k_dpm_2_a', + 'k_dpm_2', + 'k_euler_a', + 'k_euler', + 'k_heun', + 'k_lms', + 'plms', + ], + help='Change to another supported sampler using this command', + ) return parser