diff --git a/ldm/invoke/ckpt_to_diffuser.py b/ldm/invoke/ckpt_to_diffuser.py
index a1c99bc5b1..5050f53556 100644
--- a/ldm/invoke/ckpt_to_diffuser.py
+++ b/ldm/invoke/ckpt_to_diffuser.py
@@ -803,6 +803,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     extract_ema:bool=True,
     upcast_attn:bool=False,
     vae:AutoencoderKL=None,
+    precision:torch.dtype=torch.float32,
     return_generator_pipeline:bool=False,
 )->Union[StableDiffusionPipeline,StableDiffusionGeneratorPipeline]:
     '''
@@ -828,6 +829,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
      checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights
      or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
      inference. Non-EMA weights are usually better to continue fine-tuning.
+    :param precision: torch dtype the pipeline submodels are cast to - torch.float16 or torch.float32 (default)
     :param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
      running stable diffusion 2.1.
     '''
@@ -837,7 +839,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     verbosity = dlogging.get_verbosity()
     dlogging.set_verbosity_error()
 
-    checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
+    checkpoint = load_file(checkpoint_path,device='cpu') if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path,map_location='cpu')
     cache_dir = global_cache_dir('hub')
     pipeline_class = StableDiffusionGeneratorPipeline if return_generator_pipeline else StableDiffusionPipeline
 
@@ -988,12 +990,12 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
         safety_checker = StableDiffusionSafetyChecker.from_pretrained('CompVis/stable-diffusion-safety-checker',cache_dir=global_cache_dir("hub"))
         feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
         pipe = pipeline_class(
-            vae=vae,
-            text_encoder=text_model,
+            vae=vae.to(precision),
+            text_encoder=text_model.to(precision),
             tokenizer=tokenizer,
-            unet=unet,
+            unet=unet.to(precision),
             scheduler=scheduler,
-            safety_checker=safety_checker,
+            safety_checker=safety_checker.to(precision),
             feature_extractor=feature_extractor,
         )
     else:
diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py
index 0738140171..097b8809d7 100644
--- a/ldm/invoke/generator/base.py
+++ b/ldm/invoke/generator/base.py
@@ -336,7 +336,6 @@ class Generator:
         if self.caution_img:
             return self.caution_img
         path = Path(web_assets.__path__[0]) / CAUTION_IMG
-        print(f'DEBUG: path to caution = {path}')
         caution = Image.open(path)
         self.caution_img = caution.resize((caution.width // 2, caution.height //2))
         return self.caution_img
diff --git a/ldm/invoke/generator/omnibus.py b/ldm/invoke/generator/omnibus.py
index 4d5d0d13fc..a6fae3e567 100644
--- a/ldm/invoke/generator/omnibus.py
+++ b/ldm/invoke/generator/omnibus.py
@@ -40,8 +40,6 @@ class Omnibus(Img2Img,Txt2Img):
         self.perlin = perlin
         num_samples = 1
 
-        print('DEBUG: IN OMNIBUS')
-
         sampler.make_schedule(
             ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
         )
diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py
index 1c3b49f755..6d43e368a8 100644
--- a/ldm/invoke/model_manager.py
+++ b/ldm/invoke/model_manager.py
@@ -385,19 +385,25 @@ class ModelManager(object):
         from ldm.invoke.ckpt_to_diffuser import (
             load_pipeline_from_original_stable_diffusion_ckpt,
         )
-
+        self.offload_model(self.current_model)
         if vae_config := self._choose_diffusers_vae(model_name):
             vae = self._load_vae(vae_config)
+        if self._has_cuda():
+            torch.cuda.empty_cache()
         pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
             checkpoint_path=weights,
             original_config_file=config,
             vae=vae,
             return_generator_pipeline=True,
+            precision=torch.float16 if self.precision=='float16' else torch.float32,
         )
+        if self.sequential_offload:
+            pipeline.enable_offload_submodels(self.device)
+        else:
+            pipeline.to(self.device)
+
         return (
-            pipeline.to(self.device).to(
-                torch.float16 if self.precision == "float16" else torch.float32
-            ),
+            pipeline,
             width,
             height,
             "NOHASH",
diff --git a/ldm/invoke/training/textual_inversion.py b/ldm/invoke/training/textual_inversion.py
index 7ea7970ecf..4c76270e8a 100755
--- a/ldm/invoke/training/textual_inversion.py
+++ b/ldm/invoke/training/textual_inversion.py
@@ -421,7 +421,6 @@ def do_front_end(args: Namespace):
         save_args(args)
 
         try:
-            print(f"DEBUG: args = {args}")
             do_textual_inversion_training(**args)
             copy_to_embeddings_folder(args)
         except Exception as e: