diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index a4f563acd7..82014807ba 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -3,3 +3,4 @@ Initialization file for invokeai.backend ''' from .invoke_ai_web_server import InvokeAIWebServer + diff --git a/invokeai/backend/invoke_ai_web_server.py b/invokeai/backend/invoke_ai_web_server.py index 15bf25d5db..c93e5e2a60 100644 --- a/invokeai/backend/invoke_ai_web_server.py +++ b/invokeai/backend/invoke_ai_web_server.py @@ -27,10 +27,10 @@ from invokeai.backend.modules.parameters import parameters_to_command from ldm.generate import Generate from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash from ldm.invoke.conditioning import get_tokens_for_prompt_object, get_prompt_structure, get_tokenizer -from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState -from ldm.invoke.generator.inpaint import infill_methods -from ldm.invoke.globals import Globals, global_converted_ckpts_dir -from ldm.invoke.globals import global_models_dir +from ..generator import infill_methods, PipelineIntermediateState +from ldm.invoke.globals import ( Globals, global_converted_ckpts_dir, + global_models_dir + ) from ldm.invoke.merge_diffusers import merge_diffusion_models from ldm.invoke.pngwriter import PngWriter, retrieve_metadata diff --git a/invokeai/configs/stable-diffusion/v1-finetune.yaml b/invokeai/configs/stable-diffusion/v1-finetune.yaml index 783a7f10ec..9fea4ae01f 100644 --- a/invokeai/configs/stable-diffusion/v1-finetune.yaml +++ b/invokeai/configs/stable-diffusion/v1-finetune.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 5.0e-03 - target: ldm.models.diffusion.ddpm.LatentDiffusion + target: invokeai.models.diffusion.ddpm.LatentDiffusion params: linear_start: 0.00085 linear_end: 0.0120 @@ -45,7 +45,7 @@ model: legacy: False first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL + target: invokeai.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss diff --git a/invokeai/configs/stable-diffusion/v1-finetune_style.yaml b/invokeai/configs/stable-diffusion/v1-finetune_style.yaml index 1964d925e1..fdecca9b72 100644 --- a/invokeai/configs/stable-diffusion/v1-finetune_style.yaml +++ b/invokeai/configs/stable-diffusion/v1-finetune_style.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 5.0e-03 - target: ldm.models.diffusion.ddpm.LatentDiffusion + target: invokeai.models.diffusion.ddpm.LatentDiffusion params: linear_start: 0.00085 linear_end: 0.0120 @@ -44,7 +44,7 @@ model: legacy: False first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL + target: invokeai.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss diff --git a/invokeai/configs/stable-diffusion/v1-inference.yaml b/invokeai/configs/stable-diffusion/v1-inference.yaml index d872404f2c..913cbbf310 100644 --- a/invokeai/configs/stable-diffusion/v1-inference.yaml +++ b/invokeai/configs/stable-diffusion/v1-inference.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion + target: invokeai.models.diffusion.ddpm.LatentDiffusion params: linear_start: 0.00085 linear_end: 0.0120 @@ -53,7 +53,7 @@ model: legacy: False first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL + target: invokeai.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss diff --git a/invokeai/configs/stable-diffusion/v1-inpainting-inference.yaml b/invokeai/configs/stable-diffusion/v1-inpainting-inference.yaml index 2d25b8a4e6..78458a7e54 100644 --- a/invokeai/configs/stable-diffusion/v1-inpainting-inference.yaml +++ b/invokeai/configs/stable-diffusion/v1-inpainting-inference.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 7.5e-05 - target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + target: invokeai.models.diffusion.ddpm.LatentInpaintDiffusion params: linear_start: 0.00085 linear_end: 0.0120 @@ -53,7 +53,7 @@ model: legacy: False first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL + target: invokeai.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss diff --git a/invokeai/configs/stable-diffusion/v1-m1-finetune.yaml b/invokeai/configs/stable-diffusion/v1-m1-finetune.yaml index f2d5ddda02..e6db3ac067 100644 --- a/invokeai/configs/stable-diffusion/v1-m1-finetune.yaml +++ b/invokeai/configs/stable-diffusion/v1-m1-finetune.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 5.0e-03 - target: ldm.models.diffusion.ddpm.LatentDiffusion + target: invokeai.models.diffusion.ddpm.LatentDiffusion params: linear_start: 0.00085 linear_end: 0.0120 @@ -45,7 +45,7 @@ model: legacy: False first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL + target: invokeai.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss diff --git a/invokeai/configs/stable-diffusion/v2-inference-v.yaml b/invokeai/configs/stable-diffusion/v2-inference-v.yaml index 8ec8dfbfef..6b6828fbe7 100644 --- a/invokeai/configs/stable-diffusion/v2-inference-v.yaml +++ b/invokeai/configs/stable-diffusion/v2-inference-v.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 1.0e-4 - target: ldm.models.diffusion.ddpm.LatentDiffusion + target: invokeai.models.diffusion.ddpm.LatentDiffusion params: parameterization: "v" linear_start: 0.00085 @@ -38,7 +38,7 @@ model: legacy: False first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL + target: invokeai.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss diff --git a/ldm/generate.py b/ldm/generate.py index 256f214b25..a639360491 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -26,21 +26,19 @@ from PIL import Image, ImageOps from pytorch_lightning import logging, seed_everything import ldm.invoke.conditioning + +from invokeai.models import ModelManager +from invokeai.generator import infill_methods +from invokeai.models import (DDIMSampler, KSampler, PLMSSampler ) from ldm.invoke.args import metadata_from_png from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary from ldm.invoke.conditioning import get_uc_and_c_and_ec from ldm.invoke.devices import choose_precision, choose_torch_device -from ldm.invoke.generator.inpaint import infill_methods from ldm.invoke.globals import Globals, global_cache_dir from ldm.invoke.image_util import InitImageResizer -from ldm.invoke.model_manager import ModelManager from ldm.invoke.pngwriter import PngWriter from ldm.invoke.seamless import configure_model_padding from ldm.invoke.txt2mask import Txt2Mask -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.ksampler import KSampler -from ldm.models.diffusion.plms import PLMSSampler - def fix_func(orig): if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): @@ -816,7 +814,6 @@ class Generate: hires_fix: bool = False, force_outpaint: bool = False, ): - inpainting_model_in_use = self.sampler.uses_inpainting_model() if hires_fix: return self._make_txt2img2img() @@ -824,9 +821,6 @@ class Generate: if embiggen is not None: return self._make_embiggen() - if inpainting_model_in_use: - return self._make_omnibus() - if ((init_image is not None) and (mask_image is not None)) or force_outpaint: return self._make_inpaint() @@ -903,16 +897,9 @@ class Generate: def _make_inpaint(self): return self._load_generator(".inpaint", "Inpaint") - def _make_omnibus(self): - return self._load_generator(".omnibus", "Omnibus") - def _load_generator(self, module, class_name): - if self.is_legacy_model(self.model_name): - mn = f"ldm.invoke.ckpt_generator{module}" - cn = f"Ckpt{class_name}" - else: - mn = f"ldm.invoke.generator{module}" - cn = class_name + mn = f"invokeai.generator{module}" + cn = class_name module = importlib.import_module(mn) constructor = getattr(module, cn) return constructor(self.model, self.precision) diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py index b755eafed4..05aa4482d0 100644 --- a/ldm/invoke/CLI.py +++ b/ldm/invoke/CLI.py @@ -21,11 +21,11 @@ import ldm.invoke from ..generate import Generate from .args import (Args, dream_cmd_from_png, metadata_dumps, metadata_from_png) -from .generator.diffusers_pipeline import PipelineIntermediateState +from invokeai.generator import PipelineIntermediateState from .globals import Globals from .image_util import make_grid from .log import write_log -from .model_manager import ModelManager +from invokeai.models import ModelManager from .pngwriter import PngWriter, retrieve_metadata, write_metadata from .readline import Completer, get_completer from ..util import url_attachment_name @@ -64,7 +64,7 @@ def main(): Globals.internet_available = args.internet_available and check_internet() Globals.disable_xformers = not args.xformers Globals.sequential_guidance = args.sequential_guidance - Globals.ckpt_convert = args.ckpt_convert + Globals.ckpt_convert = True # always true now print(f">> Internet connectivity is {Globals.internet_available}") diff --git a/ldm/invoke/_version.py b/ldm/invoke/_version.py index 259b4f09e5..041471f37e 100644 --- a/ldm/invoke/_version.py +++ b/ldm/invoke/_version.py @@ -1 +1 @@ -__version__='2.3.1' +__version__='3.0.0+a0' diff --git a/ldm/invoke/app/services/generate_initializer.py b/ldm/invoke/app/services/generate_initializer.py index 39c0fe491e..0cfc3f39bb 100644 --- a/ldm/invoke/app/services/generate_initializer.py +++ b/ldm/invoke/app/services/generate_initializer.py @@ -3,7 +3,7 @@ import os import sys import traceback -from ...model_manager import ModelManager +from invokeai.models import ModelManager from ...globals import Globals from ....generate import Generate diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py index b23238cf09..1a5dbe334a 100644 --- a/ldm/invoke/args.py +++ b/ldm/invoke/args.py @@ -434,6 +434,14 @@ class Args(object): deprecated_group.add_argument('--laion400m') deprecated_group.add_argument('--weights') # deprecated + deprecated_group.add_argument( + '--ckpt_convert', + action=argparse.BooleanOptionalAction, + dest='ckpt_convert', + default=True, + help='Load legacy ckpt files as diffusers (deprecated; always true now).', + ) + general_group.add_argument( '--version','-V', action='store_true', @@ -518,13 +526,6 @@ class Args(object): help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}', default='auto', ) - model_group.add_argument( - '--ckpt_convert', - action=argparse.BooleanOptionalAction, - dest='ckpt_convert', - default=False, - help='Load legacy ckpt files as diffusers. Pass --no-ckpt-convert to inhibit this behavior', - ) model_group.add_argument( '--internet', action=argparse.BooleanOptionalAction, diff --git a/ldm/invoke/ckpt_generator/__init__.py b/ldm/invoke/ckpt_generator/__init__.py deleted file mode 100644 index d25e192149..0000000000 --- a/ldm/invoke/ckpt_generator/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -''' -Initialization file for the ldm.invoke.generator package -''' -from .base import CkptGenerator diff --git a/ldm/invoke/ckpt_generator/base.py b/ldm/invoke/ckpt_generator/base.py deleted file mode 100644 index 520b35612d..0000000000 --- a/ldm/invoke/ckpt_generator/base.py +++ /dev/null @@ -1,335 +0,0 @@ -''' -Base class for ldm.invoke.ckpt_generator.* -including img2img, txt2img, and inpaint - -THESE MODULES ARE TRANSITIONAL AND WILL BE REMOVED AT A FUTURE DATE -WHEN LEGACY CKPT MODEL SUPPORT IS DISCONTINUED. -''' -import torch -import numpy as np -import random -import os -import os.path as osp -import traceback -from tqdm import tqdm, trange -from PIL import Image, ImageFilter, ImageChops -import cv2 as cv -from einops import rearrange, repeat -from pathlib import Path -from pytorch_lightning import seed_everything -import invokeai.assets.web as web_assets -from ldm.invoke.devices import choose_autocast -from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver -from ldm.util import rand_perlin_2d - -downsampling = 8 -CAUTION_IMG = 'caution.png' - -class CkptGenerator(): - def __init__(self, model, precision): - self.model = model - self.precision = precision - self.seed = None - self.latent_channels = model.channels - self.downsampling_factor = downsampling # BUG: should come from model or config - self.safety_checker = None - self.perlin = 0.0 - self.threshold = 0 - self.variation_amount = 0 - self.with_variations = [] - self.use_mps_noise = False - self.free_gpu_mem = None - self.caution_img = None - - # this is going to be overridden in img2img.py, txt2img.py and inpaint.py - def get_make_image(self,prompt,**kwargs): - """ - Returns a function returning an image derived from the prompt and the initial image - Return value depends on the seed at the time you call it - """ - raise NotImplementedError("image_iterator() must be implemented in a descendent class") - - def set_variation(self, seed, variation_amount, with_variations): - self.seed = seed - self.variation_amount = variation_amount - self.with_variations = with_variations - - def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None, - image_callback=None, step_callback=None, threshold=0.0, perlin=0.0, - safety_checker:dict=None, - attention_maps_callback = None, - free_gpu_mem: bool=False, - **kwargs): - scope = choose_autocast(self.precision) - self.safety_checker = safety_checker - self.free_gpu_mem = free_gpu_mem - attention_maps_images = [] - attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image()) - make_image = self.get_make_image( - prompt, - sampler = sampler, - init_image = init_image, - width = width, - height = height, - step_callback = step_callback, - threshold = threshold, - perlin = perlin, - attention_maps_callback = attention_maps_callback, - **kwargs - ) - results = [] - seed = seed if seed is not None and seed >= 0 else self.new_seed() - first_seed = seed - seed, initial_noise = self.generate_initial_noise(seed, width, height) - - # There used to be an additional self.model.ema_scope() here, but it breaks - # the inpaint-1.5 model. Not sure what it did.... ? - with scope(self.model.device.type): - for n in trange(iterations, desc='Generating'): - x_T = None - if self.variation_amount > 0: - seed_everything(seed) - target_noise = self.get_noise(width,height) - x_T = self.slerp(self.variation_amount, initial_noise, target_noise) - elif initial_noise is not None: - # i.e. we specified particular variations - x_T = initial_noise - else: - seed_everything(seed) - try: - x_T = self.get_noise(width,height) - except: - print('** An error occurred while getting initial noise **') - print(traceback.format_exc()) - - image = make_image(x_T) - - if self.safety_checker is not None: - image = self.safety_check(image) - - results.append([image, seed]) - - if image_callback is not None: - attention_maps_image = None if len(attention_maps_images)==0 else attention_maps_images[-1] - image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_image) - - seed = self.new_seed() - - return results - - def sample_to_image(self,samples)->Image.Image: - """ - Given samples returned from a sampler, converts - it into a PIL Image - """ - x_samples = self.model.decode_first_stage(samples) - x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) - if len(x_samples) != 1: - raise Exception( - f'>> expected to get a single image, but got {len(x_samples)}') - x_sample = 255.0 * rearrange( - x_samples[0].cpu().numpy(), 'c h w -> h w c' - ) - return Image.fromarray(x_sample.astype(np.uint8)) - - # write an approximate RGB image from latent samples for a single step to PNG - - def repaste_and_color_correct(self, result: Image.Image, init_image: Image.Image, init_mask: Image.Image, mask_blur_radius: int = 8) -> Image.Image: - if init_image is None or init_mask is None: - return result - - # Get the original alpha channel of the mask if there is one. - # Otherwise it is some other black/white image format ('1', 'L' or 'RGB') - pil_init_mask = init_mask.getchannel('A') if init_mask.mode == 'RGBA' else init_mask.convert('L') - pil_init_image = init_image.convert('RGBA') # Add an alpha channel if one doesn't exist - - # Build an image with only visible pixels from source to use as reference for color-matching. - init_rgb_pixels = np.asarray(init_image.convert('RGB'), dtype=np.uint8) - init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8) - init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8) - - # Get numpy version of result - np_image = np.asarray(result, dtype=np.uint8) - - # Mask and calculate mean and standard deviation - mask_pixels = init_a_pixels * init_mask_pixels > 0 - np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :] - np_image_masked = np_image[mask_pixels, :] - - if np_init_rgb_pixels_masked.size > 0: - init_means = np_init_rgb_pixels_masked.mean(axis=0) - init_std = np_init_rgb_pixels_masked.std(axis=0) - gen_means = np_image_masked.mean(axis=0) - gen_std = np_image_masked.std(axis=0) - - # Color correct - np_matched_result = np_image.copy() - np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8) - matched_result = Image.fromarray(np_matched_result, mode='RGB') - else: - matched_result = Image.fromarray(np_image, mode='RGB') - - # Blur the mask out (into init image) by specified amount - if mask_blur_radius > 0: - nm = np.asarray(pil_init_mask, dtype=np.uint8) - nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2)) - pmd = Image.fromarray(nmd, mode='L') - blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius)) - else: - blurred_init_mask = pil_init_mask - - multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1]) - - # Paste original on color-corrected generation (using blurred mask) - matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask) - return matched_result - - - - def sample_to_lowres_estimated_image(self,samples): - # origingally adapted from code by @erucipe and @keturn here: - # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7 - - # these updated numbers for v1.5 are from @torridgristle - v1_5_latent_rgb_factors = torch.tensor([ - # R G B - [ 0.3444, 0.1385, 0.0670], # L1 - [ 0.1247, 0.4027, 0.1494], # L2 - [-0.3192, 0.2513, 0.2103], # L3 - [-0.1307, -0.1874, -0.7445] # L4 - ], dtype=samples.dtype, device=samples.device) - - latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors - latents_ubyte = (((latent_image + 1) / 2) - .clamp(0, 1) # change scale from -1..1 to 0..1 - .mul(0xFF) # to 0..255 - .byte()).cpu() - - return Image.fromarray(latents_ubyte.numpy()) - - def generate_initial_noise(self, seed, width, height): - initial_noise = None - if self.variation_amount > 0 or len(self.with_variations) > 0: - # use fixed initial noise plus random noise per iteration - seed_everything(seed) - initial_noise = self.get_noise(width,height) - for v_seed, v_weight in self.with_variations: - seed = v_seed - seed_everything(seed) - next_noise = self.get_noise(width,height) - initial_noise = self.slerp(v_weight, initial_noise, next_noise) - if self.variation_amount > 0: - random.seed() # reset RNG to an actually random state, so we can get a random seed for variations - seed = random.randrange(0,np.iinfo(np.uint32).max) - return (seed, initial_noise) - else: - return (seed, None) - - # returns a tensor filled with random numbers from a normal distribution - def get_noise(self,width,height): - """ - Returns a tensor filled with random numbers, either form a normal distribution - (txt2img) or from the latent image (img2img, inpaint) - """ - raise NotImplementedError("get_noise() must be implemented in a descendent class") - - def get_perlin_noise(self,width,height): - fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device - return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device) - - def new_seed(self): - self.seed = random.randrange(0, np.iinfo(np.uint32).max) - return self.seed - - def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995): - ''' - Spherical linear interpolation - Args: - t (float/np.ndarray): Float value between 0.0 and 1.0 - v0 (np.ndarray): Starting vector - v1 (np.ndarray): Final vector - DOT_THRESHOLD (float): Threshold for considering the two vectors as - colineal. Not recommended to alter this. - Returns: - v2 (np.ndarray): Interpolation vector between v0 and v1 - ''' - inputs_are_torch = False - if not isinstance(v0, np.ndarray): - inputs_are_torch = True - v0 = v0.detach().cpu().numpy() - if not isinstance(v1, np.ndarray): - inputs_are_torch = True - v1 = v1.detach().cpu().numpy() - - dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) - if np.abs(dot) > DOT_THRESHOLD: - v2 = (1 - t) * v0 + t * v1 - else: - theta_0 = np.arccos(dot) - sin_theta_0 = np.sin(theta_0) - theta_t = theta_0 * t - sin_theta_t = np.sin(theta_t) - s0 = np.sin(theta_0 - theta_t) / sin_theta_0 - s1 = sin_theta_t / sin_theta_0 - v2 = s0 * v0 + s1 * v1 - - if inputs_are_torch: - v2 = torch.from_numpy(v2).to(self.model.device) - - return v2 - - def safety_check(self,image:Image.Image): - ''' - If the CompViz safety checker flags an NSFW image, we - blur it out. - ''' - import diffusers - - checker = self.safety_checker['checker'] - extractor = self.safety_checker['extractor'] - features = extractor([image], return_tensors="pt") - features.to(self.model.device) - - # unfortunately checker requires the numpy version, so we have to convert back - x_image = np.array(image).astype(np.float32) / 255.0 - x_image = x_image[None].transpose(0, 3, 1, 2) - - diffusers.logging.set_verbosity_error() - checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values) - if has_nsfw_concept[0]: - print('** An image with potential non-safe content has been detected. A blurred image will be returned. **') - return self.blur(image) - else: - return image - - def blur(self,input): - blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32)) - try: - caution = self.get_caution_img() - if caution: - blurry.paste(caution,(0,0),caution) - except FileNotFoundError: - pass - return blurry - - def get_caution_img(self): - path = None - if self.caution_img: - return self.caution_img - path = Path(web_assets.__path__[0]) / CAUTION_IMG - caution = Image.open(path) - self.caution_img = caution.resize((caution.width // 2, caution.height //2)) - return self.caution_img - - # this is a handy routine for debugging use. Given a generated sample, - # convert it into a PNG image and store it at the indicated path - def save_sample(self, sample, filepath): - image = self.sample_to_image(sample) - dirname = os.path.dirname(filepath) or '.' - if not os.path.exists(dirname): - print(f'** creating directory {dirname}') - os.makedirs(dirname, exist_ok=True) - image.save(filepath,'PNG') - - def torch_dtype(self)->torch.dtype: - return torch.float16 if self.precision == 'float16' else torch.float32 diff --git a/ldm/invoke/ckpt_generator/embiggen.py b/ldm/invoke/ckpt_generator/embiggen.py deleted file mode 100644 index 0b43d3d19b..0000000000 --- a/ldm/invoke/ckpt_generator/embiggen.py +++ /dev/null @@ -1,501 +0,0 @@ -''' -ldm.invoke.ckpt_generator.embiggen descends from ldm.invoke.ckpt_generator -and generates with ldm.invoke.ckpt_generator.img2img -''' - -import torch -import numpy as np -from tqdm import trange -from PIL import Image -from ldm.invoke.ckpt_generator.base import CkptGenerator -from ldm.invoke.ckpt_generator.img2img import CkptImg2Img -from ldm.invoke.devices import choose_autocast -from ldm.models.diffusion.ddim import DDIMSampler - -class CkptEmbiggen(CkptGenerator): - def __init__(self, model, precision): - super().__init__(model, precision) - self.init_latent = None - - # Replace generate because Embiggen doesn't need/use most of what it does normallly - def generate(self,prompt,iterations=1,seed=None, - image_callback=None, step_callback=None, - **kwargs): - - scope = choose_autocast(self.precision) - make_image = self.get_make_image( - prompt, - step_callback = step_callback, - **kwargs - ) - results = [] - seed = seed if seed else self.new_seed() - - # Noise will be generated by the Img2Img generator when called - with scope(self.model.device.type), self.model.ema_scope(): - for n in trange(iterations, desc='Generating'): - # make_image will call Img2Img which will do the equivalent of get_noise itself - image = make_image() - results.append([image, seed]) - if image_callback is not None: - image_callback(image, seed, prompt_in=prompt) - seed = self.new_seed() - return results - - @torch.no_grad() - def get_make_image( - self, - prompt, - sampler, - steps, - cfg_scale, - ddim_eta, - conditioning, - init_img, - strength, - width, - height, - embiggen, - embiggen_tiles, - step_callback=None, - **kwargs - ): - """ - Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image - Return value depends on the seed at the time you call it - """ - assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models" - - # Construct embiggen arg array, and sanity check arguments - if embiggen == None: # embiggen can also be called with just embiggen_tiles - embiggen = [1.0] # If not specified, assume no scaling - elif embiggen[0] < 0: - embiggen[0] = 1.0 - print( - '>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !') - if len(embiggen) < 2: - embiggen.append(0.75) - elif embiggen[1] > 1.0 or embiggen[1] < 0: - embiggen[1] = 0.75 - print('>> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !') - if len(embiggen) < 3: - embiggen.append(0.25) - elif embiggen[2] < 0: - embiggen[2] = 0.25 - print('>> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !') - - # Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math - # and then sort them, because... people. - if embiggen_tiles: - embiggen_tiles = list(map(lambda n: n-1, embiggen_tiles)) - embiggen_tiles.sort() - - if strength >= 0.5: - print(f'* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45.') - - # Prep img2img generator, since we wrap over it - gen_img2img = CkptImg2Img(self.model,self.precision) - - # Open original init image (not a tensor) to manipulate - initsuperimage = Image.open(init_img) - - with Image.open(init_img) as img: - initsuperimage = img.convert('RGB') - - # Size of the target super init image in pixels - initsuperwidth, initsuperheight = initsuperimage.size - - # Increase by scaling factor if not already resized, using ESRGAN as able - if embiggen[0] != 1.0: - initsuperwidth = round(initsuperwidth*embiggen[0]) - initsuperheight = round(initsuperheight*embiggen[0]) - if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero - from ldm.invoke.restoration.realesrgan import ESRGAN - esrgan = ESRGAN() - print( - f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}') - if embiggen[0] > 2: - initsuperimage = esrgan.process( - initsuperimage, - embiggen[1], # upscale strength - self.seed, - 4, # upscale scale - ) - else: - initsuperimage = esrgan.process( - initsuperimage, - embiggen[1], # upscale strength - self.seed, - 2, # upscale scale - ) - # We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x - # but from personal experiance it doesn't greatly improve anything after 4x - # Resize to target scaling factor resolution - initsuperimage = initsuperimage.resize( - (initsuperwidth, initsuperheight), Image.Resampling.LANCZOS) - - # Use width and height as tile widths and height - # Determine buffer size in pixels - if embiggen[2] < 1: - if embiggen[2] < 0: - embiggen[2] = 0 - overlap_size_x = round(embiggen[2] * width) - overlap_size_y = round(embiggen[2] * height) - else: - overlap_size_x = round(embiggen[2]) - overlap_size_y = round(embiggen[2]) - - # With overall image width and height known, determine how many tiles we need - def ceildiv(a, b): - return -1 * (-a // b) - - # X and Y needs to be determined independantly (we may have savings on one based on the buffer pixel count) - # (initsuperwidth - width) is the area remaining to the right that we need to layers tiles to fill - # (width - overlap_size_x) is how much new we can fill with a single tile - emb_tiles_x = 1 - emb_tiles_y = 1 - if (initsuperwidth - width) > 0: - emb_tiles_x = ceildiv(initsuperwidth - width, - width - overlap_size_x) + 1 - if (initsuperheight - height) > 0: - emb_tiles_y = ceildiv(initsuperheight - height, - height - overlap_size_y) + 1 - # Sanity - assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.' - - # Prep alpha layers -------------- - # https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil - # agradientL is Left-side transparent - agradientL = Image.linear_gradient('L').rotate( - 90).resize((overlap_size_x, height)) - # agradientT is Top-side transparent - agradientT = Image.linear_gradient('L').resize((width, overlap_size_y)) - # radial corner is the left-top corner, made full circle then cut to just the left-top quadrant - agradientC = Image.new('L', (256, 256)) - for y in range(256): - for x in range(256): - # Find distance to lower right corner (numpy takes arrays) - distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0] - # Clamp values to max 255 - if distanceToLR > 255: - distanceToLR = 255 - #Place the pixel as invert of distance - agradientC.putpixel((x, y), round(255 - distanceToLR)) - - # Create alternative asymmetric diagonal corner to use on "tailing" intersections to prevent hard edges - # Fits for a left-fading gradient on the bottom side and full opacity on the right side. - agradientAsymC = Image.new('L', (256, 256)) - for y in range(256): - for x in range(256): - value = round(max(0, x-(255-y)) * (255 / max(1,y))) - #Clamp values - value = max(0, value) - value = min(255, value) - agradientAsymC.putpixel((x, y), value) - - # Create alpha layers default fully white - alphaLayerL = Image.new("L", (width, height), 255) - alphaLayerT = Image.new("L", (width, height), 255) - alphaLayerLTC = Image.new("L", (width, height), 255) - # Paste gradients into alpha layers - alphaLayerL.paste(agradientL, (0, 0)) - alphaLayerT.paste(agradientT, (0, 0)) - alphaLayerLTC.paste(agradientL, (0, 0)) - alphaLayerLTC.paste(agradientT, (0, 0)) - alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)) - # make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile - # to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space - alphaLayerTaC = alphaLayerT.copy() - alphaLayerTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0)) - alphaLayerLTaC = alphaLayerLTC.copy() - alphaLayerLTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0)) - - if embiggen_tiles: - # Individual unconnected sides - alphaLayerR = Image.new("L", (width, height), 255) - alphaLayerR.paste(agradientL.rotate( - 180), (width - overlap_size_x, 0)) - alphaLayerB = Image.new("L", (width, height), 255) - alphaLayerB.paste(agradientT.rotate( - 180), (0, height - overlap_size_y)) - alphaLayerTB = Image.new("L", (width, height), 255) - alphaLayerTB.paste(agradientT, (0, 0)) - alphaLayerTB.paste(agradientT.rotate( - 180), (0, height - overlap_size_y)) - alphaLayerLR = Image.new("L", (width, height), 255) - alphaLayerLR.paste(agradientL, (0, 0)) - alphaLayerLR.paste(agradientL.rotate( - 180), (width - overlap_size_x, 0)) - - # Sides and corner Layers - alphaLayerRBC = Image.new("L", (width, height), 255) - alphaLayerRBC.paste(agradientL.rotate( - 180), (width - overlap_size_x, 0)) - alphaLayerRBC.paste(agradientT.rotate( - 180), (0, height - overlap_size_y)) - alphaLayerRBC.paste(agradientC.rotate(180).resize( - (overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y)) - alphaLayerLBC = Image.new("L", (width, height), 255) - alphaLayerLBC.paste(agradientL, (0, 0)) - alphaLayerLBC.paste(agradientT.rotate( - 180), (0, height - overlap_size_y)) - alphaLayerLBC.paste(agradientC.rotate(90).resize( - (overlap_size_x, overlap_size_y)), (0, height - overlap_size_y)) - alphaLayerRTC = Image.new("L", (width, height), 255) - alphaLayerRTC.paste(agradientL.rotate( - 180), (width - overlap_size_x, 0)) - alphaLayerRTC.paste(agradientT, (0, 0)) - alphaLayerRTC.paste(agradientC.rotate(270).resize( - (overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0)) - - # All but X layers - alphaLayerABT = Image.new("L", (width, height), 255) - alphaLayerABT.paste(alphaLayerLBC, (0, 0)) - alphaLayerABT.paste(agradientL.rotate( - 180), (width - overlap_size_x, 0)) - alphaLayerABT.paste(agradientC.rotate(180).resize( - (overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y)) - alphaLayerABL = Image.new("L", (width, height), 255) - alphaLayerABL.paste(alphaLayerRTC, (0, 0)) - alphaLayerABL.paste(agradientT.rotate( - 180), (0, height - overlap_size_y)) - alphaLayerABL.paste(agradientC.rotate(180).resize( - (overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y)) - alphaLayerABR = Image.new("L", (width, height), 255) - alphaLayerABR.paste(alphaLayerLBC, (0, 0)) - alphaLayerABR.paste(agradientT, (0, 0)) - alphaLayerABR.paste(agradientC.resize( - (overlap_size_x, overlap_size_y)), (0, 0)) - alphaLayerABB = Image.new("L", (width, height), 255) - alphaLayerABB.paste(alphaLayerRTC, (0, 0)) - alphaLayerABB.paste(agradientL, (0, 0)) - alphaLayerABB.paste(agradientC.resize( - (overlap_size_x, overlap_size_y)), (0, 0)) - - # All-around layer - alphaLayerAA = Image.new("L", (width, height), 255) - alphaLayerAA.paste(alphaLayerABT, (0, 0)) - alphaLayerAA.paste(agradientT, (0, 0)) - alphaLayerAA.paste(agradientC.resize( - (overlap_size_x, overlap_size_y)), (0, 0)) - alphaLayerAA.paste(agradientC.rotate(270).resize( - (overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0)) - - # Clean up temporary gradients - del agradientL - del agradientT - del agradientC - - def make_image(): - # Make main tiles ------------------------------------------------- - if embiggen_tiles: - print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...') - else: - print( - f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...') - - emb_tile_store = [] - # Although we could use the same seed for every tile for determinism, at higher strengths this may - # produce duplicated structures for each tile and make the tiling effect more obvious - # instead track and iterate a local seed we pass to Img2Img - seed = self.seed - seedintlimit = np.iinfo(np.uint32).max - 1 # only retreive this one from numpy - - for tile in range(emb_tiles_x * emb_tiles_y): - # Don't iterate on first tile - if tile != 0: - if seed < seedintlimit: - seed += 1 - else: - seed = 0 - - # Determine if this is a re-run and replace - if embiggen_tiles and not tile in embiggen_tiles: - continue - # Get row and column entries - emb_row_i = tile // emb_tiles_x - emb_column_i = tile % emb_tiles_x - # Determine bounds to cut up the init image - # Determine upper-left point - if emb_column_i + 1 == emb_tiles_x: - left = initsuperwidth - width - else: - left = round(emb_column_i * (width - overlap_size_x)) - if emb_row_i + 1 == emb_tiles_y: - top = initsuperheight - height - else: - top = round(emb_row_i * (height - overlap_size_y)) - right = left + width - bottom = top + height - - # Cropped image of above dimension (does not modify the original) - newinitimage = initsuperimage.crop((left, top, right, bottom)) - # DEBUG: - # newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png' - # newinitimage.save(newinitimagepath) - - if embiggen_tiles: - print( - f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)') - else: - print( - f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles') - - # create a torch tensor from an Image - newinitimage = np.array( - newinitimage).astype(np.float32) / 255.0 - newinitimage = newinitimage[None].transpose(0, 3, 1, 2) - newinitimage = torch.from_numpy(newinitimage) - newinitimage = 2.0 * newinitimage - 1.0 - newinitimage = newinitimage.to(self.model.device) - - tile_results = gen_img2img.generate( - prompt, - iterations = 1, - seed = seed, - sampler = DDIMSampler(self.model, device=self.model.device), - steps = steps, - cfg_scale = cfg_scale, - conditioning = conditioning, - ddim_eta = ddim_eta, - image_callback = None, # called only after the final image is generated - step_callback = step_callback, # called after each intermediate image is generated - width = width, - height = height, - init_image = newinitimage, # notice that init_image is different from init_img - mask_image = None, - strength = strength, - ) - - emb_tile_store.append(tile_results[0][0]) - # DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite - # emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png') - del newinitimage - - # Sanity check we have them all - if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)): - outputsuperimage = Image.new( - "RGBA", (initsuperwidth, initsuperheight)) - if embiggen_tiles: - outputsuperimage.alpha_composite( - initsuperimage.convert('RGBA'), (0, 0)) - for tile in range(emb_tiles_x * emb_tiles_y): - if embiggen_tiles: - if tile in embiggen_tiles: - intileimage = emb_tile_store.pop(0) - else: - continue - else: - intileimage = emb_tile_store[tile] - intileimage = intileimage.convert('RGBA') - # Get row and column entries - emb_row_i = tile // emb_tiles_x - emb_column_i = tile % emb_tiles_x - if emb_row_i == 0 and emb_column_i == 0 and not embiggen_tiles: - left = 0 - top = 0 - else: - # Determine upper-left point - if emb_column_i + 1 == emb_tiles_x: - left = initsuperwidth - width - else: - left = round(emb_column_i * - (width - overlap_size_x)) - if emb_row_i + 1 == emb_tiles_y: - top = initsuperheight - height - else: - top = round(emb_row_i * (height - overlap_size_y)) - # Handle gradients for various conditions - # Handle emb_rerun case - if embiggen_tiles: - # top of image - if emb_row_i == 0: - if emb_column_i == 0: - if (tile+1) in embiggen_tiles: # Look-ahead right - if (tile+emb_tiles_x) not in embiggen_tiles: # Look-ahead down - intileimage.putalpha(alphaLayerB) - # Otherwise do nothing on this tile - elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only - intileimage.putalpha(alphaLayerR) - else: - intileimage.putalpha(alphaLayerRBC) - elif emb_column_i == emb_tiles_x - 1: - if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down - intileimage.putalpha(alphaLayerL) - else: - intileimage.putalpha(alphaLayerLBC) - else: - if (tile+1) in embiggen_tiles: # Look-ahead right - if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down - intileimage.putalpha(alphaLayerL) - else: - intileimage.putalpha(alphaLayerLBC) - elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only - intileimage.putalpha(alphaLayerLR) - else: - intileimage.putalpha(alphaLayerABT) - # bottom of image - elif emb_row_i == emb_tiles_y - 1: - if emb_column_i == 0: - if (tile+1) in embiggen_tiles: # Look-ahead right - intileimage.putalpha(alphaLayerTaC) - else: - intileimage.putalpha(alphaLayerRTC) - elif emb_column_i == emb_tiles_x - 1: - # No tiles to look ahead to - intileimage.putalpha(alphaLayerLTC) - else: - if (tile+1) in embiggen_tiles: # Look-ahead right - intileimage.putalpha(alphaLayerLTaC) - else: - intileimage.putalpha(alphaLayerABB) - # vertical middle of image - else: - if emb_column_i == 0: - if (tile+1) in embiggen_tiles: # Look-ahead right - if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down - intileimage.putalpha(alphaLayerTaC) - else: - intileimage.putalpha(alphaLayerTB) - elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only - intileimage.putalpha(alphaLayerRTC) - else: - intileimage.putalpha(alphaLayerABL) - elif emb_column_i == emb_tiles_x - 1: - if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down - intileimage.putalpha(alphaLayerLTC) - else: - intileimage.putalpha(alphaLayerABR) - else: - if (tile+1) in embiggen_tiles: # Look-ahead right - if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down - intileimage.putalpha(alphaLayerLTaC) - else: - intileimage.putalpha(alphaLayerABR) - elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only - intileimage.putalpha(alphaLayerABB) - else: - intileimage.putalpha(alphaLayerAA) - # Handle normal tiling case (much simpler - since we tile left to right, top to bottom) - else: - if emb_row_i == 0 and emb_column_i >= 1: - intileimage.putalpha(alphaLayerL) - elif emb_row_i >= 1 and emb_column_i == 0: - if emb_column_i + 1 == emb_tiles_x: # If we don't have anything that can be placed to the right - intileimage.putalpha(alphaLayerT) - else: - intileimage.putalpha(alphaLayerTaC) - else: - if emb_column_i + 1 == emb_tiles_x: # If we don't have anything that can be placed to the right - intileimage.putalpha(alphaLayerLTC) - else: - intileimage.putalpha(alphaLayerLTaC) - # Layer tile onto final image - outputsuperimage.alpha_composite(intileimage, (left, top)) - else: - print(f'Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.') - - # after internal loops and patching up return Embiggen image - return outputsuperimage - # end of function declaration - return make_image diff --git a/ldm/invoke/ckpt_generator/img2img.py b/ldm/invoke/ckpt_generator/img2img.py deleted file mode 100644 index e1f12b542e..0000000000 --- a/ldm/invoke/ckpt_generator/img2img.py +++ /dev/null @@ -1,97 +0,0 @@ -''' -ldm.invoke.ckpt_generator.img2img descends from ldm.invoke.generator -''' - -import torch -import numpy as np -import PIL -from torch import Tensor -from PIL import Image -from ldm.invoke.devices import choose_autocast -from ldm.invoke.ckpt_generator.base import CkptGenerator -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent - -class CkptImg2Img(CkptGenerator): - def __init__(self, model, precision): - super().__init__(model, precision) - self.init_latent = None # by get_noise() - - def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, - conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs): - """ - Returns a function returning an image derived from the prompt and the initial image - Return value depends on the seed at the time you call it. - """ - self.perlin = perlin - - sampler.make_schedule( - ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False - ) - - if isinstance(init_image, PIL.Image.Image): - init_image = self._image_to_tensor(init_image.convert('RGB')) - - scope = choose_autocast(self.precision) - with scope(self.model.device.type): - self.init_latent = self.model.get_first_stage_encoding( - self.model.encode_first_stage(init_image) - ) # move to latent space - - t_enc = int(strength * steps) - uc, c, extra_conditioning_info = conditioning - - def make_image(x_T): - # encode (scaled latent) - z_enc = sampler.stochastic_encode( - self.init_latent, - torch.tensor([t_enc - 1]).to(self.model.device), - noise=x_T - ) - - if self.free_gpu_mem and self.model.model.device != self.model.device: - self.model.model.to(self.model.device) - - # decode it - samples = sampler.decode( - z_enc, - c, - t_enc, - img_callback = step_callback, - unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=uc, - init_latent = self.init_latent, # changes how noising is performed in ksampler - extra_conditioning_info = extra_conditioning_info, - all_timesteps_count = steps - ) - - if self.free_gpu_mem: - self.model.model.to("cpu") - - return self.sample_to_image(samples) - - return make_image - - def get_noise(self,width,height): - device = self.model.device - init_latent = self.init_latent - assert init_latent is not None,'call to get_noise() when init_latent not set' - if device.type == 'mps': - x = torch.randn_like(init_latent, device='cpu').to(device) - else: - x = torch.randn_like(init_latent, device=device) - if self.perlin > 0.0: - shape = init_latent.shape - x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2]) - return x - - def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor: - image = np.array(image).astype(np.float32) / 255.0 - if len(image.shape) == 2: # 'L' image, as in a mask - image = image[None,None] - else: # 'RGB' image - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - if normalize: - image = 2.0 * image - 1.0 - return image.to(self.model.device) diff --git a/ldm/invoke/ckpt_generator/inpaint.py b/ldm/invoke/ckpt_generator/inpaint.py deleted file mode 100644 index 3b965b0ee3..0000000000 --- a/ldm/invoke/ckpt_generator/inpaint.py +++ /dev/null @@ -1,358 +0,0 @@ -''' -ldm.invoke.ckpt_generator.inpaint descends from ldm.invoke.ckpt_generator -''' - -import math -import torch -import torchvision.transforms as T -import numpy as np -import cv2 as cv -import PIL -from PIL import Image, ImageFilter, ImageOps, ImageChops -from skimage.exposure.histogram_matching import match_histograms -from einops import rearrange, repeat -from ldm.invoke.devices import choose_autocast -from ldm.invoke.ckpt_generator.img2img import CkptImg2Img -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.ksampler import KSampler -from ldm.invoke.generator.base import downsampling -from ldm.util import debug_image -from ldm.invoke.patchmatch import PatchMatch -from ldm.invoke.globals import Globals - -def infill_methods()->list[str]: - methods = list() - if PatchMatch.patchmatch_available(): - methods.append('patchmatch') - methods.append('tile') - return methods - -class CkptInpaint(CkptImg2Img): - def __init__(self, model, precision): - self.init_latent = None - self.pil_image = None - self.pil_mask = None - self.mask_blur_radius = 0 - self.infill_method = None - super().__init__(model, precision) - - # Outpaint support code - def get_tile_images(self, image: np.ndarray, width=8, height=8): - _nrows, _ncols, depth = image.shape - _strides = image.strides - - nrows, _m = divmod(_nrows, height) - ncols, _n = divmod(_ncols, width) - if _m != 0 or _n != 0: - return None - - return np.lib.stride_tricks.as_strided( - np.ravel(image), - shape=(nrows, ncols, height, width, depth), - strides=(height * _strides[0], width * _strides[1], *_strides), - writeable=False - ) - - def infill_patchmatch(self, im: Image.Image) -> Image: - if im.mode != 'RGBA': - return im - - # Skip patchmatch if patchmatch isn't available - if not PatchMatch.patchmatch_available(): - return im - - # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though) - im_patched_np = PatchMatch.inpaint(im.convert('RGB'), ImageOps.invert(im.split()[-1]), patch_size = 3) - im_patched = Image.fromarray(im_patched_np, mode = 'RGB') - return im_patched - - def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image: - # Only fill if there's an alpha layer - if im.mode != 'RGBA': - return im - - a = np.asarray(im, dtype=np.uint8) - - tile_size = (tile_size, tile_size) - - # Get the image as tiles of a specified size - tiles = self.get_tile_images(a,*tile_size).copy() - - # Get the mask as tiles - tiles_mask = tiles[:,:,:,:,3] - - # Find any mask tiles with any fully transparent pixels (we will be replacing these later) - tmask_shape = tiles_mask.shape - tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape)) - n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:]) - tiles_mask = (tiles_mask > 0) - tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1) - - # Get RGB tiles in single array and filter by the mask - tshape = tiles.shape - tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:])) - filtered_tiles = tiles_all[tiles_mask] - - if len(filtered_tiles) == 0: - return im - - # Find all invalid tiles and replace with a random valid tile - replace_count = (tiles_mask == False).sum() - rng = np.random.default_rng(seed = seed) - tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:] - - # Convert back to an image - tiles_all = tiles_all.reshape(tshape) - tiles_all = tiles_all.swapaxes(1,2) - st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4])) - si = Image.fromarray(st, mode='RGBA') - - return si - - - def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image: - npimg = np.asarray(mask, dtype=np.uint8) - - # Detect any partially transparent regions - npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0))) - - # Detect hard edges - npedge = cv.Canny(npimg, threshold1=100, threshold2=200) - - # Combine - npmask = npgradient + npedge - - # Expand - npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2)) - - new_mask = Image.fromarray(npmask) - - if edge_blur > 0: - new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur)) - - return ImageOps.invert(new_mask) - - - def seam_paint(self, - im: Image.Image, - seam_size: int, - seam_blur: int, - prompt,sampler,steps,cfg_scale,ddim_eta, - conditioning,strength, - noise, - step_callback - ) -> Image.Image: - hard_mask = self.pil_image.split()[-1].copy() - mask = self.mask_edge(hard_mask, seam_size, seam_blur) - - make_image = self.get_make_image( - prompt, - sampler, - steps, - cfg_scale, - ddim_eta, - conditioning, - init_image = im.copy().convert('RGBA'), - mask_image = mask.convert('RGB'), # Code currently requires an RGB mask - strength = strength, - mask_blur_radius = 0, - seam_size = 0, - step_callback = step_callback, - inpaint_width = im.width, - inpaint_height = im.height - ) - - seam_noise = self.get_noise(im.width, im.height) - - result = make_image(seam_noise) - - return result - - - @torch.no_grad() - def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, - conditioning,init_image,mask_image,strength, - mask_blur_radius: int = 8, - # Seam settings - when 0, doesn't fill seam - seam_size: int = 0, - seam_blur: int = 0, - seam_strength: float = 0.7, - seam_steps: int = 10, - tile_size: int = 32, - step_callback=None, - inpaint_replace=False, enable_image_debugging=False, - infill_method = None, - inpaint_width=None, - inpaint_height=None, - **kwargs): - """ - Returns a function returning an image derived from the prompt and - the initial image + mask. Return value depends on the seed at - the time you call it. kwargs are 'init_latent' and 'strength' - """ - - self.enable_image_debugging = enable_image_debugging - self.infill_method = infill_method or infill_methods()[0], # The infill method to use - - self.inpaint_width = inpaint_width - self.inpaint_height = inpaint_height - - if isinstance(init_image, PIL.Image.Image): - self.pil_image = init_image.copy() - - # Do infill - if infill_method == 'patchmatch' and PatchMatch.patchmatch_available(): - init_filled = self.infill_patchmatch(self.pil_image.copy()) - else: # if infill_method == 'tile': # Only two methods right now, so always use 'tile' if not patchmatch - init_filled = self.tile_fill_missing( - self.pil_image.copy(), - seed = self.seed, - tile_size = tile_size - ) - init_filled.paste(init_image, (0,0), init_image.split()[-1]) - - # Resize if requested for inpainting - if inpaint_width and inpaint_height: - init_filled = init_filled.resize((inpaint_width, inpaint_height)) - - debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging) - - # Create init tensor - init_image = self._image_to_tensor(init_filled.convert('RGB')) - - if isinstance(mask_image, PIL.Image.Image): - self.pil_mask = mask_image.copy() - debug_image(mask_image, "mask_image BEFORE multiply with pil_image", debug_status=self.enable_image_debugging) - - mask_image = ImageChops.multiply(mask_image, self.pil_image.split()[-1].convert('RGB')) - self.pil_mask = mask_image - - # Resize if requested for inpainting - if inpaint_width and inpaint_height: - mask_image = mask_image.resize((inpaint_width, inpaint_height)) - - debug_image(mask_image, "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging) - mask_image = mask_image.resize( - ( - mask_image.width // downsampling, - mask_image.height // downsampling - ), - resample=Image.Resampling.NEAREST - ) - mask_image = self._image_to_tensor(mask_image,normalize=False) - - self.mask_blur_radius = mask_blur_radius - - # klms samplers not supported yet, so ignore previous sampler - if isinstance(sampler,KSampler): - print( - f">> Using recommended DDIM sampler for inpainting." - ) - sampler = DDIMSampler(self.model, device=self.model.device) - - sampler.make_schedule( - ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False - ) - - mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0) - mask_image = repeat(mask_image, '1 ... -> b ...', b=1) - - scope = choose_autocast(self.precision) - with scope(self.model.device.type): - self.init_latent = self.model.get_first_stage_encoding( - self.model.encode_first_stage(init_image) - ) # move to latent space - - t_enc = int(strength * steps) - # todo: support cross-attention control - uc, c, _ = conditioning - - print(f">> target t_enc is {t_enc} steps") - - @torch.no_grad() - def make_image(x_T): - # encode (scaled latent) - z_enc = sampler.stochastic_encode( - self.init_latent, - torch.tensor([t_enc - 1]).to(self.model.device), - noise=x_T - ) - - # to replace masked area with latent noise, weighted by inpaint_replace strength - if inpaint_replace > 0.0: - print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}') - l_noise = self.get_noise(kwargs['width'],kwargs['height']) - inverted_mask = 1.0-mask_image # there will be 1s where the mask is - masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise - z_enc = z_enc * mask_image + masked_region - - if self.free_gpu_mem and self.model.model.device != self.model.device: - self.model.model.to(self.model.device) - - # decode it - samples = sampler.decode( - z_enc, - c, - t_enc, - img_callback = step_callback, - unconditional_guidance_scale = cfg_scale, - unconditional_conditioning = uc, - mask = mask_image, - init_latent = self.init_latent - ) - - result = self.sample_to_image(samples) - - # Seam paint if this is our first pass (seam_size set to 0 during seam painting) - if seam_size > 0: - old_image = self.pil_image or init_image - old_mask = self.pil_mask or mask_image - - result = self.seam_paint( - result, - seam_size, - seam_blur, - prompt, - sampler, - seam_steps, - cfg_scale, - ddim_eta, - conditioning, - seam_strength, - x_T, - step_callback) - - # Restore original settings - self.get_make_image(prompt,sampler,steps,cfg_scale,ddim_eta, - conditioning, - old_image, - old_mask, - strength, - mask_blur_radius, seam_size, seam_blur, seam_strength, - seam_steps, tile_size, step_callback, - inpaint_replace, enable_image_debugging, - inpaint_width = inpaint_width, - inpaint_height = inpaint_height, - infill_method = infill_method, - **kwargs) - - return result - - return make_image - - - def sample_to_image(self, samples)->Image.Image: - gen_result = super().sample_to_image(samples).convert('RGB') - debug_image(gen_result, "gen_result", debug_status=self.enable_image_debugging) - - # Resize if necessary - if self.inpaint_width and self.inpaint_height: - gen_result = gen_result.resize(self.pil_image.size) - - if self.pil_image is None or self.pil_mask is None: - return gen_result - - corrected_result = super().repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius) - debug_image(corrected_result, "corrected_result", debug_status=self.enable_image_debugging) - - return corrected_result diff --git a/ldm/invoke/ckpt_generator/omnibus.py b/ldm/invoke/ckpt_generator/omnibus.py deleted file mode 100644 index a479ac85ec..0000000000 --- a/ldm/invoke/ckpt_generator/omnibus.py +++ /dev/null @@ -1,175 +0,0 @@ -"""omnibus module to be used with the runwayml 9-channel custom inpainting model""" - -import torch -import numpy as np -from einops import repeat -from PIL import Image, ImageOps, ImageChops -from ldm.invoke.devices import choose_autocast -from ldm.invoke.ckpt_generator.base import downsampling -from ldm.invoke.ckpt_generator.img2img import CkptImg2Img -from ldm.invoke.ckpt_generator.txt2img import CkptTxt2Img - -class CkptOmnibus(CkptImg2Img,CkptTxt2Img): - def __init__(self, model, precision): - super().__init__(model, precision) - self.pil_mask = None - self.pil_image = None - - def get_make_image( - self, - prompt, - sampler, - steps, - cfg_scale, - ddim_eta, - conditioning, - width, - height, - init_image = None, - mask_image = None, - strength = None, - step_callback=None, - threshold=0.0, - perlin=0.0, - mask_blur_radius: int = 8, - **kwargs): - """ - Returns a function returning an image derived from the prompt and the initial image - Return value depends on the seed at the time you call it. - """ - self.perlin = perlin - num_samples = 1 - - sampler.make_schedule( - ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False - ) - - if isinstance(init_image, Image.Image): - self.pil_image = init_image - if init_image.mode != 'RGB': - init_image = init_image.convert('RGB') - init_image = self._image_to_tensor(init_image) - - if isinstance(mask_image, Image.Image): - self.pil_mask = mask_image - - mask_image = ImageChops.multiply(mask_image.convert('L'), self.pil_image.split()[-1]) - mask_image = self._image_to_tensor(ImageOps.invert(mask_image), normalize=False) - - self.mask_blur_radius = mask_blur_radius - - t_enc = steps - - if init_image is not None and mask_image is not None: # inpainting - masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero - - elif init_image is not None: # img2img - scope = choose_autocast(self.precision) - - with scope(self.model.device.type): - self.init_latent = self.model.get_first_stage_encoding( - self.model.encode_first_stage(init_image) - ) # move to latent space - - # create a completely black mask (1s) - mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device) - # and the masked image is just a copy of the original - masked_image = init_image - - else: # txt2img - init_image = torch.zeros(1, 3, height, width, device=self.model.device) - mask_image = torch.ones(1, 1, height, width, device=self.model.device) - masked_image = init_image - - self.init_latent = init_image - height = init_image.shape[2] - width = init_image.shape[3] - model = self.model - - def make_image(x_T): - with torch.no_grad(): - scope = choose_autocast(self.precision) - with scope(self.model.device.type): - - batch = self.make_batch_sd( - init_image, - mask_image, - masked_image, - prompt=prompt, - device=model.device, - num_samples=num_samples, - ) - - c = model.cond_stage_model.encode(batch["txt"]) - c_cat = list() - for ck in model.concat_keys: - cc = batch[ck].float() - if ck != model.masked_image_key: - bchw = [num_samples, 4, height//8, width//8] - cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) - else: - cc = model.get_first_stage_encoding(model.encode_first_stage(cc)) - c_cat.append(cc) - c_cat = torch.cat(c_cat, dim=1) - - # cond - cond={"c_concat": [c_cat], "c_crossattn": [c]} - - # uncond cond - uc_cross = model.get_unconditional_conditioning(num_samples, "") - uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]} - shape = [model.channels, height//8, width//8] - - samples, _ = sampler.sample( - batch_size = 1, - S = steps, - x_T = x_T, - conditioning = cond, - shape = shape, - verbose = False, - unconditional_guidance_scale = cfg_scale, - unconditional_conditioning = uc_full, - eta = 1.0, - img_callback = step_callback, - threshold = threshold, - ) - if self.free_gpu_mem: - self.model.model.to("cpu") - return self.sample_to_image(samples) - - return make_image - - def make_batch_sd( - self, - image, - mask, - masked_image, - prompt, - device, - num_samples=1): - batch = { - "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples), - "txt": num_samples * [prompt], - "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples), - "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples), - } - return batch - - def get_noise(self, width:int, height:int): - if self.init_latent is not None: - height = self.init_latent.shape[2] - width = self.init_latent.shape[3] - return CkptTxt2Img.get_noise(self,width,height) - - - def sample_to_image(self, samples)->Image.Image: - gen_result = super().sample_to_image(samples).convert('RGB') - - if self.pil_image is None or self.pil_mask is None: - return gen_result - if self.pil_image.size != self.pil_mask.size: - return gen_result - - corrected_result = super(CkptImg2Img, self).repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius) - - return corrected_result diff --git a/ldm/invoke/ckpt_generator/txt2img.py b/ldm/invoke/ckpt_generator/txt2img.py deleted file mode 100644 index 825b8583b9..0000000000 --- a/ldm/invoke/ckpt_generator/txt2img.py +++ /dev/null @@ -1,90 +0,0 @@ -''' -ldm.invoke.ckpt_generator.txt2img inherits from ldm.invoke.ckpt_generator -''' - -import torch -import numpy as np -from ldm.invoke.ckpt_generator.base import CkptGenerator -from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent -import gc - - -class CkptTxt2Img(CkptGenerator): - def __init__(self, model, precision): - super().__init__(model, precision) - - @torch.no_grad() - def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, - conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0, - attention_maps_callback=None, - **kwargs): - """ - Returns a function returning an image derived from the prompt and the initial image - Return value depends on the seed at the time you call it - kwargs are 'width' and 'height' - """ - self.perlin = perlin - uc, c, extra_conditioning_info = conditioning - - @torch.no_grad() - def make_image(x_T): - shape = [ - self.latent_channels, - height // self.downsampling_factor, - width // self.downsampling_factor, - ] - - if self.free_gpu_mem and self.model.model.device != self.model.device: - self.model.model.to(self.model.device) - - sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False) - - samples, _ = sampler.sample( - batch_size = 1, - S = steps, - x_T = x_T, - conditioning = c, - shape = shape, - verbose = False, - unconditional_guidance_scale = cfg_scale, - unconditional_conditioning = uc, - extra_conditioning_info = extra_conditioning_info, - eta = ddim_eta, - img_callback = step_callback, - threshold = threshold, - attention_maps_callback = attention_maps_callback, - ) - - if self.free_gpu_mem: - self.model.model.to('cpu') - self.model.cond_stage_model.device = 'cpu' - self.model.cond_stage_model.to('cpu') - gc.collect() - torch.cuda.empty_cache() - - return self.sample_to_image(samples) - - return make_image - - - # returns a tensor filled with random numbers from a normal distribution - def get_noise(self,width,height): - device = self.model.device - if self.use_mps_noise or device.type == 'mps': - x = torch.randn([1, - self.latent_channels, - height // self.downsampling_factor, - width // self.downsampling_factor], - dtype=self.torch_dtype(), - device='cpu').to(device) - else: - x = torch.randn([1, - self.latent_channels, - height // self.downsampling_factor, - width // self.downsampling_factor], - dtype=self.torch_dtype(), - device=device) - if self.perlin > 0.0: - x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor) - return x - diff --git a/ldm/invoke/ckpt_generator/txt2img2img.py b/ldm/invoke/ckpt_generator/txt2img2img.py deleted file mode 100644 index 167debb98e..0000000000 --- a/ldm/invoke/ckpt_generator/txt2img2img.py +++ /dev/null @@ -1,182 +0,0 @@ -''' -ldm.invoke.ckpt_generator.txt2img inherits from ldm.invoke.ckpt_generator -''' - -import torch -import numpy as np -import math -import gc -from ldm.invoke.ckpt_generator.base import CkptGenerator -from ldm.invoke.ckpt_generator.omnibus import CkptOmnibus -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent -from PIL import Image - -class CkptTxt2Img2Img(CkptGenerator): - def __init__(self, model, precision): - super().__init__(model, precision) - self.init_latent = None # for get_noise() - - @torch.no_grad() - def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, - conditioning,width,height,strength,step_callback=None,**kwargs): - """ - Returns a function returning an image derived from the prompt and the initial image - Return value depends on the seed at the time you call it - kwargs are 'width' and 'height' - """ - uc, c, extra_conditioning_info = conditioning - scale_dim = min(width, height) - scale = 512 / scale_dim - - init_width = math.ceil(scale * width / 64) * 64 - init_height = math.ceil(scale * height / 64) * 64 - - @torch.no_grad() - def make_image(x_T): - - shape = [ - self.latent_channels, - init_height // self.downsampling_factor, - init_width // self.downsampling_factor, - ] - - sampler.make_schedule( - ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False - ) - - #x = self.get_noise(init_width, init_height) - x = x_T - - if self.free_gpu_mem and self.model.model.device != self.model.device: - self.model.model.to(self.model.device) - - samples, _ = sampler.sample( - batch_size = 1, - S = steps, - x_T = x, - conditioning = c, - shape = shape, - verbose = False, - unconditional_guidance_scale = cfg_scale, - unconditional_conditioning = uc, - eta = ddim_eta, - img_callback = step_callback, - extra_conditioning_info = extra_conditioning_info - ) - - print( - f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling" - ) - - # resizing - samples = torch.nn.functional.interpolate( - samples, - size=(height // self.downsampling_factor, width // self.downsampling_factor), - mode="bilinear" - ) - - t_enc = int(strength * steps) - ddim_sampler = DDIMSampler(self.model, device=self.model.device) - ddim_sampler.make_schedule( - ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False - ) - - z_enc = ddim_sampler.stochastic_encode( - samples, - torch.tensor([t_enc-1]).to(self.model.device), - noise=self.get_noise(width,height,False) - ) - - # decode it - samples = ddim_sampler.decode( - z_enc, - c, - t_enc, - img_callback = step_callback, - unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=uc, - extra_conditioning_info=extra_conditioning_info, - all_timesteps_count=steps - ) - - if self.free_gpu_mem: - self.model.model.to('cpu') - self.model.cond_stage_model.device = 'cpu' - self.model.cond_stage_model.to('cpu') - gc.collect() - torch.cuda.empty_cache() - - return self.sample_to_image(samples) - - # in the case of the inpainting model being loaded, the trick of - # providing an interpolated latent doesn't work, so we transiently - # create a 512x512 PIL image, upscale it, and run the inpainting - # over it in img2img mode. Because the inpaing model is so conservative - # it doesn't change the image (much) - def inpaint_make_image(x_T): - omnibus = CkptOmnibus(self.model,self.precision) - result = omnibus.generate( - prompt, - sampler=sampler, - width=init_width, - height=init_height, - step_callback=step_callback, - steps = steps, - cfg_scale = cfg_scale, - ddim_eta = ddim_eta, - conditioning = conditioning, - **kwargs - ) - assert result is not None and len(result)>0,'** txt2img failed **' - image = result[0][0] - interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS) - print(kwargs.pop('init_image',None)) - result = omnibus.generate( - prompt, - sampler=sampler, - init_image=interpolated_image, - width=width, - height=height, - seed=result[0][1], - step_callback=step_callback, - steps = steps, - cfg_scale = cfg_scale, - ddim_eta = ddim_eta, - conditioning = conditioning, - **kwargs - ) - return result[0][0] - - if sampler.uses_inpainting_model(): - return inpaint_make_image - else: - return make_image - - # returns a tensor filled with random numbers from a normal distribution - def get_noise(self,width,height,scale = True): - # print(f"Get noise: {width}x{height}") - if scale: - trained_square = 512 * 512 - actual_square = width * height - scale = math.sqrt(trained_square / actual_square) - scaled_width = math.ceil(scale * width / 64) * 64 - scaled_height = math.ceil(scale * height / 64) * 64 - else: - scaled_width = width - scaled_height = height - - device = self.model.device - if self.use_mps_noise or device.type == 'mps': - return torch.randn([1, - self.latent_channels, - scaled_height // self.downsampling_factor, - scaled_width // self.downsampling_factor], - device='cpu').to(device) - else: - return torch.randn([1, - self.latent_channels, - scaled_height // self.downsampling_factor, - scaled_width // self.downsampling_factor], - device=device) - diff --git a/ldm/invoke/ckpt_to_diffuser.py b/ldm/invoke/ckpt_to_diffuser.py index 82ba73b0a4..f6cac0b814 100644 --- a/ldm/invoke/ckpt_to_diffuser.py +++ b/ldm/invoke/ckpt_to_diffuser.py @@ -25,7 +25,7 @@ from ldm.invoke.globals import ( global_cache_dir, global_config_dir, ) -from ldm.invoke.model_manager import ModelManager, SDLegacyType +from invokeai.models import ModelManager, SDLegacyType from safetensors.torch import load_file from typing import Union @@ -56,7 +56,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS from diffusers.utils import is_safetensors_available from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig -from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline +from invokeai.generator import StableDiffusionGeneratorPipeline def shave_segments(path, n_shave_prefix_segments=1): """ diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 7c654caf69..7ff99c252e 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -14,7 +14,7 @@ from transformers import CLIPTokenizer, CLIPTextModel from compel import Compel from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser from .devices import torch_dtype -from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent +from invokeai.models import InvokeAIDiffuserComponent from ldm.invoke.globals import Globals def get_tokenizer(model) -> CLIPTokenizer: diff --git a/ldm/invoke/config/model_install_backend.py b/ldm/invoke/config/model_install_backend.py index 60abce8c8b..186af2aaae 100644 --- a/ldm/invoke/config/model_install_backend.py +++ b/ldm/invoke/config/model_install_backend.py @@ -20,7 +20,7 @@ from typing import List import invokeai.configs as configs from ..generator.diffusers_pipeline import StableDiffusionGeneratorPipeline from ..globals import Globals, global_cache_dir, global_config_dir -from ..model_manager import ModelManager +from invokeai.models import ModelManager warnings.filterwarnings("ignore") diff --git a/ldm/invoke/generator/__init__.py b/ldm/invoke/generator/__init__.py deleted file mode 100644 index 2fa5573c84..0000000000 --- a/ldm/invoke/generator/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -''' -Initialization file for the ldm.invoke.generator package -''' -from .base import Generator diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py deleted file mode 100644 index 21d6f271ca..0000000000 --- a/ldm/invoke/generator/base.py +++ /dev/null @@ -1,374 +0,0 @@ -''' -Base class for ldm.invoke.generator.* -including img2img, txt2img, and inpaint -''' -from __future__ import annotations - -import os -import os.path as osp -import random -import traceback -from contextlib import nullcontext - -import cv2 -import numpy as np -import torch - -from PIL import Image, ImageFilter, ImageChops -from diffusers import DiffusionPipeline -from einops import rearrange -from pathlib import Path -from pytorch_lightning import seed_everything -from tqdm import trange - -import invokeai.assets.web as web_assets -from ldm.models.diffusion.ddpm import DiffusionWrapper -from ldm.util import rand_perlin_2d - -downsampling = 8 -CAUTION_IMG = 'caution.png' - -class Generator: - downsampling_factor: int - latent_channels: int - precision: str - model: DiffusionWrapper | DiffusionPipeline - - def __init__(self, model: DiffusionWrapper | DiffusionPipeline, precision: str): - self.model = model - self.precision = precision - self.seed = None - self.latent_channels = model.channels - self.downsampling_factor = downsampling # BUG: should come from model or config - self.safety_checker = None - self.perlin = 0.0 - self.threshold = 0 - self.variation_amount = 0 - self.with_variations = [] - self.use_mps_noise = False - self.free_gpu_mem = None - self.caution_img = None - - # this is going to be overridden in img2img.py, txt2img.py and inpaint.py - def get_make_image(self,prompt,**kwargs): - """ - Returns a function returning an image derived from the prompt and the initial image - Return value depends on the seed at the time you call it - """ - raise NotImplementedError("image_iterator() must be implemented in a descendent class") - - def set_variation(self, seed, variation_amount, with_variations): - self.seed = seed - self.variation_amount = variation_amount - self.with_variations = with_variations - - def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None, - image_callback=None, step_callback=None, threshold=0.0, perlin=0.0, - h_symmetry_time_pct=None, v_symmetry_time_pct=None, - safety_checker:dict=None, - free_gpu_mem: bool=False, - **kwargs): - scope = nullcontext - self.safety_checker = safety_checker - self.free_gpu_mem = free_gpu_mem - attention_maps_images = [] - attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image()) - make_image = self.get_make_image( - prompt, - sampler = sampler, - init_image = init_image, - width = width, - height = height, - step_callback = step_callback, - threshold = threshold, - perlin = perlin, - h_symmetry_time_pct = h_symmetry_time_pct, - v_symmetry_time_pct = v_symmetry_time_pct, - attention_maps_callback = attention_maps_callback, - **kwargs - ) - results = [] - seed = seed if seed is not None and seed >= 0 else self.new_seed() - first_seed = seed - seed, initial_noise = self.generate_initial_noise(seed, width, height) - - # There used to be an additional self.model.ema_scope() here, but it breaks - # the inpaint-1.5 model. Not sure what it did.... ? - with scope(self.model.device.type): - for n in trange(iterations, desc='Generating'): - x_T = None - if self.variation_amount > 0: - seed_everything(seed) - target_noise = self.get_noise(width,height) - x_T = self.slerp(self.variation_amount, initial_noise, target_noise) - elif initial_noise is not None: - # i.e. we specified particular variations - x_T = initial_noise - else: - seed_everything(seed) - try: - x_T = self.get_noise(width,height) - except: - print('** An error occurred while getting initial noise **') - print(traceback.format_exc()) - - image = make_image(x_T) - - if self.safety_checker is not None: - image = self.safety_check(image) - - results.append([image, seed]) - - if image_callback is not None: - attention_maps_image = None if len(attention_maps_images)==0 else attention_maps_images[-1] - image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_image) - - seed = self.new_seed() - - # Free up memory from the last generation. - clear_cuda_cache = kwargs['clear_cuda_cache'] if 'clear_cuda_cache' in kwargs else None - if clear_cuda_cache is not None: - clear_cuda_cache() - - return results - - def sample_to_image(self,samples)->Image.Image: - """ - Given samples returned from a sampler, converts - it into a PIL Image - """ - with torch.inference_mode(): - image = self.model.decode_latents(samples) - return self.model.numpy_to_pil(image)[0] - - def repaste_and_color_correct(self, result: Image.Image, init_image: Image.Image, init_mask: Image.Image, mask_blur_radius: int = 8) -> Image.Image: - if init_image is None or init_mask is None: - return result - - # Get the original alpha channel of the mask if there is one. - # Otherwise it is some other black/white image format ('1', 'L' or 'RGB') - pil_init_mask = init_mask.getchannel('A') if init_mask.mode == 'RGBA' else init_mask.convert('L') - pil_init_image = init_image.convert('RGBA') # Add an alpha channel if one doesn't exist - - # Build an image with only visible pixels from source to use as reference for color-matching. - init_rgb_pixels = np.asarray(init_image.convert('RGB'), dtype=np.uint8) - init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8) - init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8) - - # Get numpy version of result - np_image = np.asarray(result, dtype=np.uint8) - - # Mask and calculate mean and standard deviation - mask_pixels = init_a_pixels * init_mask_pixels > 0 - np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :] - np_image_masked = np_image[mask_pixels, :] - - if np_init_rgb_pixels_masked.size > 0: - init_means = np_init_rgb_pixels_masked.mean(axis=0) - init_std = np_init_rgb_pixels_masked.std(axis=0) - gen_means = np_image_masked.mean(axis=0) - gen_std = np_image_masked.std(axis=0) - - # Color correct - np_matched_result = np_image.copy() - np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8) - matched_result = Image.fromarray(np_matched_result, mode='RGB') - else: - matched_result = Image.fromarray(np_image, mode='RGB') - - # Blur the mask out (into init image) by specified amount - if mask_blur_radius > 0: - nm = np.asarray(pil_init_mask, dtype=np.uint8) - nmd = cv2.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2)) - pmd = Image.fromarray(nmd, mode='L') - blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius)) - else: - blurred_init_mask = pil_init_mask - - multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1]) - - # Paste original on color-corrected generation (using blurred mask) - matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask) - return matched_result - - def sample_to_lowres_estimated_image(self,samples): - # origingally adapted from code by @erucipe and @keturn here: - # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7 - - # these updated numbers for v1.5 are from @torridgristle - v1_5_latent_rgb_factors = torch.tensor([ - # R G B - [ 0.3444, 0.1385, 0.0670], # L1 - [ 0.1247, 0.4027, 0.1494], # L2 - [-0.3192, 0.2513, 0.2103], # L3 - [-0.1307, -0.1874, -0.7445] # L4 - ], dtype=samples.dtype, device=samples.device) - - latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors - latents_ubyte = (((latent_image + 1) / 2) - .clamp(0, 1) # change scale from -1..1 to 0..1 - .mul(0xFF) # to 0..255 - .byte()).cpu() - - return Image.fromarray(latents_ubyte.numpy()) - - def generate_initial_noise(self, seed, width, height): - initial_noise = None - if self.variation_amount > 0 or len(self.with_variations) > 0: - # use fixed initial noise plus random noise per iteration - seed_everything(seed) - initial_noise = self.get_noise(width,height) - for v_seed, v_weight in self.with_variations: - seed = v_seed - seed_everything(seed) - next_noise = self.get_noise(width,height) - initial_noise = self.slerp(v_weight, initial_noise, next_noise) - if self.variation_amount > 0: - random.seed() # reset RNG to an actually random state, so we can get a random seed for variations - seed = random.randrange(0,np.iinfo(np.uint32).max) - return (seed, initial_noise) - else: - return (seed, None) - - # returns a tensor filled with random numbers from a normal distribution - def get_noise(self,width,height): - """ - Returns a tensor filled with random numbers, either form a normal distribution - (txt2img) or from the latent image (img2img, inpaint) - """ - raise NotImplementedError("get_noise() must be implemented in a descendent class") - - def get_perlin_noise(self,width,height): - fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device - # limit noise to only the diffusion image channels, not the mask channels - input_channels = min(self.latent_channels, 4) - # round up to the nearest block of 8 - temp_width = int((width + 7) / 8) * 8 - temp_height = int((height + 7) / 8) * 8 - noise = torch.stack([ - rand_perlin_2d((temp_height, temp_width), - (8, 8), - device = self.model.device).to(fixdevice) for _ in range(input_channels)], dim=0).to(self.model.device) - return noise[0:4, 0:height, 0:width] - - def new_seed(self): - self.seed = random.randrange(0, np.iinfo(np.uint32).max) - return self.seed - - def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995): - ''' - Spherical linear interpolation - Args: - t (float/np.ndarray): Float value between 0.0 and 1.0 - v0 (np.ndarray): Starting vector - v1 (np.ndarray): Final vector - DOT_THRESHOLD (float): Threshold for considering the two vectors as - colineal. Not recommended to alter this. - Returns: - v2 (np.ndarray): Interpolation vector between v0 and v1 - ''' - inputs_are_torch = False - if not isinstance(v0, np.ndarray): - inputs_are_torch = True - v0 = v0.detach().cpu().numpy() - if not isinstance(v1, np.ndarray): - inputs_are_torch = True - v1 = v1.detach().cpu().numpy() - - dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) - if np.abs(dot) > DOT_THRESHOLD: - v2 = (1 - t) * v0 + t * v1 - else: - theta_0 = np.arccos(dot) - sin_theta_0 = np.sin(theta_0) - theta_t = theta_0 * t - sin_theta_t = np.sin(theta_t) - s0 = np.sin(theta_0 - theta_t) / sin_theta_0 - s1 = sin_theta_t / sin_theta_0 - v2 = s0 * v0 + s1 * v1 - - if inputs_are_torch: - v2 = torch.from_numpy(v2).to(self.model.device) - - return v2 - - def safety_check(self,image:Image.Image): - ''' - If the CompViz safety checker flags an NSFW image, we - blur it out. - ''' - import diffusers - - checker = self.safety_checker['checker'] - extractor = self.safety_checker['extractor'] - features = extractor([image], return_tensors="pt") - features.to(self.model.device) - - # unfortunately checker requires the numpy version, so we have to convert back - x_image = np.array(image).astype(np.float32) / 255.0 - x_image = x_image[None].transpose(0, 3, 1, 2) - - diffusers.logging.set_verbosity_error() - checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values) - if has_nsfw_concept[0]: - print('** An image with potential non-safe content has been detected. A blurred image will be returned. **') - return self.blur(image) - else: - return image - - def blur(self,input): - blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32)) - try: - caution = self.get_caution_img() - if caution: - blurry.paste(caution,(0,0),caution) - except FileNotFoundError: - pass - return blurry - - def get_caution_img(self): - path = None - if self.caution_img: - return self.caution_img - path = Path(web_assets.__path__[0]) / CAUTION_IMG - caution = Image.open(path) - self.caution_img = caution.resize((caution.width // 2, caution.height //2)) - return self.caution_img - - # this is a handy routine for debugging use. Given a generated sample, - # convert it into a PNG image and store it at the indicated path - def save_sample(self, sample, filepath): - image = self.sample_to_image(sample) - dirname = os.path.dirname(filepath) or '.' - if not os.path.exists(dirname): - print(f'** creating directory {dirname}') - os.makedirs(dirname, exist_ok=True) - image.save(filepath,'PNG') - - - def torch_dtype(self)->torch.dtype: - return torch.float16 if self.precision == 'float16' else torch.float32 - - # returns a tensor filled with random numbers from a normal distribution - def get_noise(self,width,height): - device = self.model.device - # limit noise to only the diffusion image channels, not the mask channels - input_channels = min(self.latent_channels, 4) - if self.use_mps_noise or device.type == 'mps': - x = torch.randn([1, - input_channels, - height // self.downsampling_factor, - width // self.downsampling_factor], - dtype=self.torch_dtype(), - device='cpu').to(device) - else: - x = torch.randn([1, - input_channels, - height // self.downsampling_factor, - width // self.downsampling_factor], - dtype=self.torch_dtype(), - device=device) - if self.perlin > 0.0: - perlin_noise = self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor) - x = (1-self.perlin)*x + self.perlin*perlin_noise - return x diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py deleted file mode 100644 index 5e65cb5d13..0000000000 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ /dev/null @@ -1,765 +0,0 @@ -from __future__ import annotations - -import dataclasses -import inspect -import psutil -import secrets -from collections.abc import Sequence -from dataclasses import dataclass, field -from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any - -import PIL.Image -import einops -import psutil -import torch -import torchvision.transforms as T -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline -from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput -from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.outputs import BaseOutput -from torchvision.transforms.functional import resize as tv_resize -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -from typing_extensions import ParamSpec - -from ldm.invoke.globals import Globals -from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings -from ldm.modules.textual_inversion_manager import TextualInversionManager -from ..devices import normalize_device, CPU_DEVICE -from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup -from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver -from compel import EmbeddingsProvider - - -@dataclass -class PipelineIntermediateState: - run_id: str - step: int - timestep: int - latents: torch.Tensor - predicted_original: Optional[torch.Tensor] = None - attention_map_saver: Optional[AttentionMapSaver] = None - - -# copied from configs/stable-diffusion/v1-inference.yaml -_default_personalization_config_params = dict( - placeholder_strings=["*"], - initializer_wods=["sculpture"], - per_image_tokens=False, - num_vectors_per_token=1, - progressive_words=False -) - - -@dataclass -class AddsMaskLatents: - """Add the channels required for inpainting model input. - - The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask - and the latent encoding of the base image. - - This class assumes the same mask and base image should apply to all items in the batch. - """ - forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] - mask: torch.Tensor - initial_image_latents: torch.Tensor - - def __call__(self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor) -> torch.Tensor: - model_input = self.add_mask_channels(latents) - return self.forward(model_input, t, text_embeddings) - - def add_mask_channels(self, latents): - batch_size = latents.size(0) - # duplicate mask and latents for each batch - mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size) - image_latents = einops.repeat(self.initial_image_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size) - # add mask and image as additional channels - model_input, _ = einops.pack([latents, mask, image_latents], 'b * h w') - return model_input - - -def are_like_tensors(a: torch.Tensor, b: object) -> bool: - return ( - isinstance(b, torch.Tensor) - and (a.size() == b.size()) - ) - -@dataclass -class AddsMaskGuidance: - mask: torch.FloatTensor - mask_latents: torch.FloatTensor - scheduler: SchedulerMixin - noise: torch.Tensor - _debug: Optional[Callable] = None - - def __call__(self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning) -> BaseOutput: - output_class = step_output.__class__ # We'll create a new one with masked data. - - # The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it. - # It's reasonable to assume the first thing is prev_sample, but then does it have other things - # like pred_original_sample? Should we apply the mask to them too? - # But what if there's just some other random field? - prev_sample = step_output[0] - # Mask anything that has the same shape as prev_sample, return others as-is. - return output_class( - {k: (self.apply_mask(v, self._t_for_field(k, t)) - if are_like_tensors(prev_sample, v) else v) - for k, v in step_output.items()} - ) - - def _t_for_field(self, field_name:str, t): - if field_name == "pred_original_sample": - return torch.zeros_like(t, dtype=t.dtype) # it represents t=0 - return t - - def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor: - batch_size = latents.size(0) - mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size) - if t.dim() == 0: - # some schedulers expect t to be one-dimensional. - # TODO: file diffusers bug about inconsistency? - t = einops.repeat(t, '-> batch', batch=batch_size) - # Noise shouldn't be re-randomized between steps here. The multistep schedulers - # get very confused about what is happening from step to step when we do that. - mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t) - # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already? - # mask_latents = self.scheduler.scale_model_input(mask_latents, t) - mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size) - masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)) - if self._debug: - self._debug(masked_input, f"t={t} lerped") - return masked_input - - -def trim_to_multiple_of(*args, multiple_of=8): - return tuple((x - x % multiple_of) for x in args) - - -def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor: - """ - - :param image: input image - :param normalize: scale the range to [-1, 1] instead of [0, 1] - :param multiple_of: resize the input so both dimensions are a multiple of this - """ - w, h = trim_to_multiple_of(*image.size) - transformation = T.Compose([ - T.Resize((h, w), T.InterpolationMode.LANCZOS), - T.ToTensor(), - ]) - tensor = transformation(image) - if normalize: - tensor = tensor * 2.0 - 1.0 - return tensor - - -def is_inpainting_model(unet: UNet2DConditionModel): - return unet.conv_in.in_channels == 9 - -CallbackType = TypeVar('CallbackType') -ReturnType = TypeVar('ReturnType') -ParamType = ParamSpec('ParamType') - -@dataclass(frozen=True) -class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]): - """Convert a generator to a function with a callback and a return value.""" - - generator_method: Callable[ParamType, ReturnType] - callback_arg_type: Type[CallbackType] - - def __call__(self, *args: ParamType.args, - callback:Callable[[CallbackType], Any]=None, - **kwargs: ParamType.kwargs) -> ReturnType: - result = None - for result in self.generator_method(*args, **kwargs): - if callback is not None and isinstance(result, self.callback_arg_type): - callback(result) - if result is None: - raise AssertionError("why was that an empty generator?") - return result - - -@dataclass(frozen=True) -class ConditioningData: - unconditioned_embeddings: torch.Tensor - text_embeddings: torch.Tensor - guidance_scale: float - """ - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). - Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate - images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - """ - extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None - scheduler_args: dict[str, Any] = field(default_factory=dict) - """ - Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing(). - """ - postprocessing_settings: Optional[PostprocessingSettings] = None - - @property - def dtype(self): - return self.text_embeddings.dtype - - def add_scheduler_args_if_applicable(self, scheduler, **kwargs): - scheduler_args = dict(self.scheduler_args) - step_method = inspect.signature(scheduler.step) - for name, value in kwargs.items(): - try: - step_method.bind_partial(**{name: value}) - except TypeError: - # FIXME: don't silently discard arguments - pass # debug("%s does not accept argument named %r", scheduler, name) - else: - scheduler_args[name] = value - return dataclasses.replace(self, scheduler_args=scheduler_args) - -@dataclass -class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput): - r""" - Output class for InvokeAI's Stable Diffusion pipeline. - - Args: - attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user - after generation completes. Optional. - """ - attention_map_saver: Optional[AttentionMapSaver] - - -class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): - r""" - Pipeline for text-to-image generation using Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Implementation note: This class started as a refactored copy of diffusers.StableDiffusionPipeline. - Hopefully future versions of diffusers provide access to more of these functions so that we don't - need to duplicate them here: https://github.com/huggingface/diffusers/issues/551#issuecomment-1281508384 - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offsensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - _model_group: ModelGroup - - ID_LENGTH = 8 - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: Optional[StableDiffusionSafetyChecker], - feature_extractor: Optional[CLIPFeatureExtractor], - requires_safety_checker: bool = False, - precision: str = 'float32', - ): - super().__init__(vae, text_encoder, tokenizer, unet, scheduler, - safety_checker, feature_extractor, requires_safety_checker) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True) - use_full_precision = (precision == 'float32' or precision == 'autocast') - self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer, - text_encoder=self.text_encoder, - full_precision=use_full_precision) - # InvokeAI's interface for text embeddings and whatnot - self.embeddings_provider = EmbeddingsProvider( - tokenizer=self.tokenizer, - text_encoder=self.text_encoder, - textual_inversion_manager=self.textual_inversion_manager - ) - - self._model_group = FullyLoadedModelGroup(self.unet.device) - self._model_group.install(*self._submodels) - - - def _adjust_memory_efficient_attention(self, latents: torch.Tensor): - """ - if xformers is available, use it, otherwise use sliced attention. - """ - if torch.cuda.is_available() and is_xformers_available() and not Globals.disable_xformers: - self.enable_xformers_memory_efficient_attention() - else: - if torch.backends.mps.is_available(): - # until pytorch #91617 is fixed, slicing is borked on MPS - # https://github.com/pytorch/pytorch/issues/91617 - # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline. - pass - else: - if self.device.type == 'cpu' or self.device.type == 'mps': - mem_free = psutil.virtual_memory().free - elif self.device.type == 'cuda': - mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device)) - else: - raise ValueError(f"unrecognized device {self.device}") - # input tensor of [1, 4, h/8, w/8] - # output tensor of [16, (h/8 * w/8), (h/8 * w/8)] - bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4 - max_size_required_for_baddbmm = \ - 16 * \ - latents.size(dim=2) * latents.size(dim=3) * latents.size(dim=2) * latents.size(dim=3) * \ - bytes_per_element_needed_for_baddbmm_duplication - if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code - self.enable_attention_slicing(slice_size='max') - else: - self.disable_attention_slicing() - - - def enable_offload_submodels(self, device: torch.device): - """ - Offload each submodel when it's not in use. - - Useful for low-vRAM situations where the size of the model in memory is a big chunk of - the total available resource, and you want to free up as much for inference as possible. - - This requires more moving parts and may add some delay as the U-Net is swapped out for the - VAE and vice-versa. - """ - models = self._submodels - if self._model_group is not None: - self._model_group.uninstall(*models) - group = LazilyLoadedModelGroup(device) - group.install(*models) - self._model_group = group - - def disable_offload_submodels(self): - """ - Leave all submodels loaded. - - Appropriate for cases where the size of the model in memory is small compared to the memory - required for inference. Avoids the delay and complexity of shuffling the submodels to and - from the GPU. - """ - models = self._submodels - if self._model_group is not None: - self._model_group.uninstall(*models) - group = FullyLoadedModelGroup(self._model_group.execution_device) - group.install(*models) - self._model_group = group - - def offload_all(self): - """Offload all this pipeline's models to CPU.""" - self._model_group.offload_current() - - def ready(self): - """ - Ready this pipeline's models. - - i.e. pre-load them to the GPU if appropriate. - """ - self._model_group.ready() - - def to(self, torch_device: Optional[Union[str, torch.device]] = None): - # overridden method; types match the superclass. - if torch_device is None: - return self - self._model_group.set_device(torch.device(torch_device)) - self._model_group.ready() - - @property - def device(self) -> torch.device: - return self._model_group.execution_device - - @property - def _submodels(self) -> Sequence[torch.nn.Module]: - module_names, _, _ = self.extract_init_dict(dict(self.config)) - values = [getattr(self, name) for name in module_names.keys()] - return [m for m in values if isinstance(m, torch.nn.Module)] - - def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, - conditioning_data: ConditioningData, - *, - noise: torch.Tensor, - callback: Callable[[PipelineIntermediateState], None]=None, - run_id=None) -> InvokeAIStableDiffusionPipelineOutput: - r""" - Function invoked when calling the pipeline for generation. - - :param conditioning_data: - :param latents: Pre-generated un-noised latents, to be used as inputs for - image generation. Can be used to tweak the same generation with different prompts. - :param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality - image at the expense of slower inference. - :param noise: Noise to add to the latents, sampled from a Gaussian distribution. - :param callback: - :param run_id: - """ - result_latents, result_attention_map_saver = self.latents_from_embeddings( - latents, num_inference_steps, - conditioning_data, - noise=noise, - run_id=run_id, - callback=callback) - # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 - torch.cuda.empty_cache() - - with torch.inference_mode(): - image = self.decode_latents(result_latents) - output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_map_saver) - return self.check_for_safety(output, dtype=conditioning_data.dtype) - - def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, - conditioning_data: ConditioningData, - *, - noise: torch.Tensor, - timesteps=None, - additional_guidance: List[Callable] = None, run_id=None, - callback: Callable[[PipelineIntermediateState], None] = None - ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: - if timesteps is None: - self.scheduler.set_timesteps(num_inference_steps, device=self._model_group.device_for(self.unet)) - timesteps = self.scheduler.timesteps - infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState) - result: PipelineIntermediateState = infer_latents_from_embeddings( - latents, timesteps, conditioning_data, - noise=noise, - additional_guidance=additional_guidance, - run_id=run_id, - callback=callback) - return result.latents, result.attention_map_saver - - def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps, - conditioning_data: ConditioningData, - *, - noise: torch.Tensor, - run_id: str = None, - additional_guidance: List[Callable] = None): - self._adjust_memory_efficient_attention(latents) - if run_id is None: - run_id = secrets.token_urlsafe(self.ID_LENGTH) - if additional_guidance is None: - additional_guidance = [] - extra_conditioning_info = conditioning_data.extra - with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info, - step_count=len(self.scheduler.timesteps) - ): - - yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps, - latents=latents) - - batch_size = latents.shape[0] - batched_t = torch.full((batch_size,), timesteps[0], - dtype=timesteps.dtype, device=self._model_group.device_for(self.unet)) - latents = self.scheduler.add_noise(latents, noise, batched_t) - - attention_map_saver: Optional[AttentionMapSaver] = None - - for i, t in enumerate(self.progress_bar(timesteps)): - batched_t.fill_(t) - step_output = self.step(batched_t, latents, conditioning_data, - step_index=i, - total_step_count=len(timesteps), - additional_guidance=additional_guidance) - latents = step_output.prev_sample - - latents = self.invokeai_diffuser.do_latent_postprocessing( - postprocessing_settings=conditioning_data.postprocessing_settings, - latents=latents, - sigma=batched_t, - step_index=i, - total_step_count=len(timesteps) - ) - - predicted_original = getattr(step_output, 'pred_original_sample', None) - - # TODO resuscitate attention map saving - #if i == len(timesteps)-1 and extra_conditioning_info is not None: - # eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1 - # attention_map_token_ids = range(1, eos_token_index) - # attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:]) - # self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver) - - yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents, - predicted_original=predicted_original, attention_map_saver=attention_map_saver) - - return latents, attention_map_saver - - @torch.inference_mode() - def step(self, t: torch.Tensor, latents: torch.Tensor, - conditioning_data: ConditioningData, - step_index:int, total_step_count:int, - additional_guidance: List[Callable] = None): - # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value - timestep = t[0] - - if additional_guidance is None: - additional_guidance = [] - - # TODO: should this scaling happen here or inside self._unet_forward? - # i.e. before or after passing it to InvokeAIDiffuserComponent - latent_model_input = self.scheduler.scale_model_input(latents, timestep) - - # predict the noise residual - noise_pred = self.invokeai_diffuser.do_diffusion_step( - latent_model_input, t, - conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings, - conditioning_data.guidance_scale, - step_index=step_index, - total_step_count=total_step_count, - ) - - # compute the previous noisy sample x_t -> x_t-1 - step_output = self.scheduler.step(noise_pred, timestep, latents, - **conditioning_data.scheduler_args) - - # TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent. - # But the way things are now, scheduler runs _after_ that, so there was - # no way to use it to apply an operation that happens after the last scheduler.step. - for guidance in additional_guidance: - step_output = guidance(step_output, timestep, conditioning_data) - - return step_output - - def _unet_forward(self, latents, t, text_embeddings, cross_attention_kwargs: Optional[dict[str,Any]] = None): - """predict the noise residual""" - if is_inpainting_model(self.unet) and latents.size(1) == 4: - # Pad out normal non-inpainting inputs for an inpainting model. - # FIXME: There are too many layers of functions and we have too many different ways of - # overriding things! This should get handled in a way more consistent with the other - # use of AddsMaskLatents. - latents = AddsMaskLatents( - self._unet_forward, - mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype), - initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype) - ).add_mask_channels(latents) - - # First three args should be positional, not keywords, so torch hooks can see them. - return self.unet(latents, t, text_embeddings, - cross_attention_kwargs=cross_attention_kwargs).sample - - def img2img_from_embeddings(self, - init_image: Union[torch.FloatTensor, PIL.Image.Image], - strength: float, - num_inference_steps: int, - conditioning_data: ConditioningData, - *, callback: Callable[[PipelineIntermediateState], None] = None, - run_id=None, - noise_func=None - ) -> InvokeAIStableDiffusionPipelineOutput: - if isinstance(init_image, PIL.Image.Image): - init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB')) - - if init_image.dim() == 3: - init_image = einops.rearrange(init_image, 'c h w -> 1 c h w') - - # 6. Prepare latent variables - initial_latents = self.non_noised_latents_from_image( - init_image, device=self._model_group.device_for(self.unet), - dtype=self.unet.dtype) - noise = noise_func(initial_latents) - - return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps, - conditioning_data, - strength, - noise, run_id, callback) - - def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps, - conditioning_data: ConditioningData, - strength, - noise: torch.Tensor, run_id=None, callback=None - ) -> InvokeAIStableDiffusionPipelineOutput: - timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, - device=self._model_group.device_for(self.unet)) - result_latents, result_attention_maps = self.latents_from_embeddings( - initial_latents, num_inference_steps, conditioning_data, - timesteps=timesteps, - noise=noise, - run_id=run_id, - callback=callback) - - # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 - torch.cuda.empty_cache() - - with torch.inference_mode(): - image = self.decode_latents(result_latents) - output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps) - return self.check_for_safety(output, dtype=conditioning_data.dtype) - - def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device) -> (torch.Tensor, int): - img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components) - assert img2img_pipeline.scheduler is self.scheduler - img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, adjusted_steps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device) - # Workaround for low strength resulting in zero timesteps. - # TODO: submit upstream fix for zero-step img2img - if timesteps.numel() == 0: - timesteps = self.scheduler.timesteps[-1:] - adjusted_steps = timesteps.numel() - return timesteps, adjusted_steps - - def inpaint_from_embeddings( - self, - init_image: torch.FloatTensor, - mask: torch.FloatTensor, - strength: float, - num_inference_steps: int, - conditioning_data: ConditioningData, - *, callback: Callable[[PipelineIntermediateState], None] = None, - run_id=None, - noise_func=None, - ) -> InvokeAIStableDiffusionPipelineOutput: - device = self._model_group.device_for(self.unet) - latents_dtype = self.unet.dtype - - if isinstance(init_image, PIL.Image.Image): - init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB')) - - init_image = init_image.to(device=device, dtype=latents_dtype) - mask = mask.to(device=device, dtype=latents_dtype) - - if init_image.dim() == 3: - init_image = init_image.unsqueeze(0) - - timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, device=device) - - # 6. Prepare latent variables - # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents - # because we have our own noise function - init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype) - noise = noise_func(init_image_latents) - - if mask.dim() == 3: - mask = mask.unsqueeze(0) - latent_mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR) \ - .to(device=device, dtype=latents_dtype) - - guidance: List[Callable] = [] - - if is_inpainting_model(self.unet): - # You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint - # (that's why there's a mask!) but it seems to really want that blanked out. - masked_init_image = init_image * torch.where(mask < 0.5, 1, 0) - masked_latents = self.non_noised_latents_from_image(masked_init_image, device=device, dtype=latents_dtype) - - # TODO: we should probably pass this in so we don't have to try/finally around setting it. - self.invokeai_diffuser.model_forward_callback = \ - AddsMaskLatents(self._unet_forward, latent_mask, masked_latents) - else: - guidance.append(AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise)) - - try: - result_latents, result_attention_maps = self.latents_from_embeddings( - init_image_latents, num_inference_steps, - conditioning_data, noise=noise, timesteps=timesteps, - additional_guidance=guidance, - run_id=run_id, callback=callback) - finally: - self.invokeai_diffuser.model_forward_callback = self._unet_forward - - # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 - torch.cuda.empty_cache() - - with torch.inference_mode(): - image = self.decode_latents(result_latents) - output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps) - return self.check_for_safety(output, dtype=conditioning_data.dtype) - - def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype): - init_image = init_image.to(device=device, dtype=dtype) - with torch.inference_mode(): - if device.type == 'mps': - # workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222 - # TODO remove this workaround once kulinseth#222 is merged to pytorch mainline - self.vae.to(CPU_DEVICE) - init_image = init_image.to(CPU_DEVICE) - else: - self._model_group.load(self.vae) - init_latent_dist = self.vae.encode(init_image).latent_dist - init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible! - if device.type == 'mps': - self.vae.to(device) - init_latents = init_latents.to(device) - - init_latents = 0.18215 * init_latents - return init_latents - - def check_for_safety(self, output, dtype): - with torch.inference_mode(): - screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype) - screened_attention_map_saver = None - if has_nsfw_concept is None or not has_nsfw_concept: - screened_attention_map_saver = output.attention_map_saver - return InvokeAIStableDiffusionPipelineOutput(screened_images, - has_nsfw_concept, - # block the attention maps if NSFW content is detected - attention_map_saver=screened_attention_map_saver) - - def run_safety_checker(self, image, device=None, dtype=None): - # overriding to use the model group for device info instead of requiring the caller to know. - if self.safety_checker is not None: - device = self._model_group.device_for(self.safety_checker) - return super().run_safety_checker(image, device, dtype) - - @torch.inference_mode() - def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None): - """ - Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion. - """ - return self.embeddings_provider.get_embeddings_for_weighted_prompt_fragments( - text_batch=c, - fragment_weights_batch=fragment_weights, - should_return_tokens=return_tokens, - device=self._model_group.device_for(self.unet)) - - @property - def cond_stage_model(self): - return self.embeddings_provider - - @torch.inference_mode() - def _tokenize(self, prompt: Union[str, List[str]]): - return self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - @property - def channels(self) -> int: - """Compatible with DiffusionWrapper""" - return self.unet.in_channels - - def decode_latents(self, latents): - # Explicit call to get the vae loaded, since `decode` isn't the forward method. - self._model_group.load(self.vae) - return super().decode_latents(latents) - - def debug_latents(self, latents, msg): - with torch.inference_mode(): - from ldm.util import debug_image - decoded = self.numpy_to_pil(self.decode_latents(latents)) - for i, img in enumerate(decoded): - debug_image(img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True) diff --git a/ldm/invoke/generator/embiggen.py b/ldm/invoke/generator/embiggen.py deleted file mode 100644 index 0a06f90b03..0000000000 --- a/ldm/invoke/generator/embiggen.py +++ /dev/null @@ -1,501 +0,0 @@ -''' -ldm.invoke.generator.embiggen descends from ldm.invoke.generator -and generates with ldm.invoke.generator.img2img -''' - -import numpy as np -import torch -from PIL import Image -from tqdm import trange - -from ldm.invoke.generator.base import Generator -from ldm.invoke.generator.img2img import Img2Img - - -class Embiggen(Generator): - def __init__(self, model, precision): - super().__init__(model, precision) - self.init_latent = None - - # Replace generate because Embiggen doesn't need/use most of what it does normallly - def generate(self,prompt,iterations=1,seed=None, - image_callback=None, step_callback=None, - **kwargs): - - make_image = self.get_make_image( - prompt, - step_callback = step_callback, - **kwargs - ) - results = [] - seed = seed if seed else self.new_seed() - - # Noise will be generated by the Img2Img generator when called - for _ in trange(iterations, desc='Generating'): - # make_image will call Img2Img which will do the equivalent of get_noise itself - image = make_image() - results.append([image, seed]) - if image_callback is not None: - image_callback(image, seed, prompt_in=prompt) - seed = self.new_seed() - return results - - @torch.no_grad() - def get_make_image( - self, - prompt, - sampler, - steps, - cfg_scale, - ddim_eta, - conditioning, - init_img, - strength, - width, - height, - embiggen, - embiggen_tiles, - step_callback=None, - **kwargs - ): - """ - Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image - Return value depends on the seed at the time you call it - """ - assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models" - - # Construct embiggen arg array, and sanity check arguments - if embiggen == None: # embiggen can also be called with just embiggen_tiles - embiggen = [1.0] # If not specified, assume no scaling - elif embiggen[0] < 0: - embiggen[0] = 1.0 - print( - '>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !') - if len(embiggen) < 2: - embiggen.append(0.75) - elif embiggen[1] > 1.0 or embiggen[1] < 0: - embiggen[1] = 0.75 - print('>> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !') - if len(embiggen) < 3: - embiggen.append(0.25) - elif embiggen[2] < 0: - embiggen[2] = 0.25 - print('>> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !') - - # Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math - # and then sort them, because... people. - if embiggen_tiles: - embiggen_tiles = list(map(lambda n: n-1, embiggen_tiles)) - embiggen_tiles.sort() - - if strength >= 0.5: - print(f'* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45.') - - # Prep img2img generator, since we wrap over it - gen_img2img = Img2Img(self.model,self.precision) - - # Open original init image (not a tensor) to manipulate - initsuperimage = Image.open(init_img) - - with Image.open(init_img) as img: - initsuperimage = img.convert('RGB') - - # Size of the target super init image in pixels - initsuperwidth, initsuperheight = initsuperimage.size - - # Increase by scaling factor if not already resized, using ESRGAN as able - if embiggen[0] != 1.0: - initsuperwidth = round(initsuperwidth*embiggen[0]) - initsuperheight = round(initsuperheight*embiggen[0]) - if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero - from ldm.invoke.restoration.realesrgan import ESRGAN - esrgan = ESRGAN() - print( - f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}') - if embiggen[0] > 2: - initsuperimage = esrgan.process( - initsuperimage, - embiggen[1], # upscale strength - self.seed, - 4, # upscale scale - ) - else: - initsuperimage = esrgan.process( - initsuperimage, - embiggen[1], # upscale strength - self.seed, - 2, # upscale scale - ) - # We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x - # but from personal experiance it doesn't greatly improve anything after 4x - # Resize to target scaling factor resolution - initsuperimage = initsuperimage.resize( - (initsuperwidth, initsuperheight), Image.Resampling.LANCZOS) - - # Use width and height as tile widths and height - # Determine buffer size in pixels - if embiggen[2] < 1: - if embiggen[2] < 0: - embiggen[2] = 0 - overlap_size_x = round(embiggen[2] * width) - overlap_size_y = round(embiggen[2] * height) - else: - overlap_size_x = round(embiggen[2]) - overlap_size_y = round(embiggen[2]) - - # With overall image width and height known, determine how many tiles we need - def ceildiv(a, b): - return -1 * (-a // b) - - # X and Y needs to be determined independantly (we may have savings on one based on the buffer pixel count) - # (initsuperwidth - width) is the area remaining to the right that we need to layers tiles to fill - # (width - overlap_size_x) is how much new we can fill with a single tile - emb_tiles_x = 1 - emb_tiles_y = 1 - if (initsuperwidth - width) > 0: - emb_tiles_x = ceildiv(initsuperwidth - width, - width - overlap_size_x) + 1 - if (initsuperheight - height) > 0: - emb_tiles_y = ceildiv(initsuperheight - height, - height - overlap_size_y) + 1 - # Sanity - assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.' - - # Prep alpha layers -------------- - # https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil - # agradientL is Left-side transparent - agradientL = Image.linear_gradient('L').rotate( - 90).resize((overlap_size_x, height)) - # agradientT is Top-side transparent - agradientT = Image.linear_gradient('L').resize((width, overlap_size_y)) - # radial corner is the left-top corner, made full circle then cut to just the left-top quadrant - agradientC = Image.new('L', (256, 256)) - for y in range(256): - for x in range(256): - # Find distance to lower right corner (numpy takes arrays) - distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0] - # Clamp values to max 255 - if distanceToLR > 255: - distanceToLR = 255 - #Place the pixel as invert of distance - agradientC.putpixel((x, y), round(255 - distanceToLR)) - - # Create alternative asymmetric diagonal corner to use on "tailing" intersections to prevent hard edges - # Fits for a left-fading gradient on the bottom side and full opacity on the right side. - agradientAsymC = Image.new('L', (256, 256)) - for y in range(256): - for x in range(256): - value = round(max(0, x-(255-y)) * (255 / max(1,y))) - #Clamp values - value = max(0, value) - value = min(255, value) - agradientAsymC.putpixel((x, y), value) - - # Create alpha layers default fully white - alphaLayerL = Image.new("L", (width, height), 255) - alphaLayerT = Image.new("L", (width, height), 255) - alphaLayerLTC = Image.new("L", (width, height), 255) - # Paste gradients into alpha layers - alphaLayerL.paste(agradientL, (0, 0)) - alphaLayerT.paste(agradientT, (0, 0)) - alphaLayerLTC.paste(agradientL, (0, 0)) - alphaLayerLTC.paste(agradientT, (0, 0)) - alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)) - # make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile - # to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space - alphaLayerTaC = alphaLayerT.copy() - alphaLayerTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0)) - alphaLayerLTaC = alphaLayerLTC.copy() - alphaLayerLTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0)) - - if embiggen_tiles: - # Individual unconnected sides - alphaLayerR = Image.new("L", (width, height), 255) - alphaLayerR.paste(agradientL.rotate( - 180), (width - overlap_size_x, 0)) - alphaLayerB = Image.new("L", (width, height), 255) - alphaLayerB.paste(agradientT.rotate( - 180), (0, height - overlap_size_y)) - alphaLayerTB = Image.new("L", (width, height), 255) - alphaLayerTB.paste(agradientT, (0, 0)) - alphaLayerTB.paste(agradientT.rotate( - 180), (0, height - overlap_size_y)) - alphaLayerLR = Image.new("L", (width, height), 255) - alphaLayerLR.paste(agradientL, (0, 0)) - alphaLayerLR.paste(agradientL.rotate( - 180), (width - overlap_size_x, 0)) - - # Sides and corner Layers - alphaLayerRBC = Image.new("L", (width, height), 255) - alphaLayerRBC.paste(agradientL.rotate( - 180), (width - overlap_size_x, 0)) - alphaLayerRBC.paste(agradientT.rotate( - 180), (0, height - overlap_size_y)) - alphaLayerRBC.paste(agradientC.rotate(180).resize( - (overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y)) - alphaLayerLBC = Image.new("L", (width, height), 255) - alphaLayerLBC.paste(agradientL, (0, 0)) - alphaLayerLBC.paste(agradientT.rotate( - 180), (0, height - overlap_size_y)) - alphaLayerLBC.paste(agradientC.rotate(90).resize( - (overlap_size_x, overlap_size_y)), (0, height - overlap_size_y)) - alphaLayerRTC = Image.new("L", (width, height), 255) - alphaLayerRTC.paste(agradientL.rotate( - 180), (width - overlap_size_x, 0)) - alphaLayerRTC.paste(agradientT, (0, 0)) - alphaLayerRTC.paste(agradientC.rotate(270).resize( - (overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0)) - - # All but X layers - alphaLayerABT = Image.new("L", (width, height), 255) - alphaLayerABT.paste(alphaLayerLBC, (0, 0)) - alphaLayerABT.paste(agradientL.rotate( - 180), (width - overlap_size_x, 0)) - alphaLayerABT.paste(agradientC.rotate(180).resize( - (overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y)) - alphaLayerABL = Image.new("L", (width, height), 255) - alphaLayerABL.paste(alphaLayerRTC, (0, 0)) - alphaLayerABL.paste(agradientT.rotate( - 180), (0, height - overlap_size_y)) - alphaLayerABL.paste(agradientC.rotate(180).resize( - (overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y)) - alphaLayerABR = Image.new("L", (width, height), 255) - alphaLayerABR.paste(alphaLayerLBC, (0, 0)) - alphaLayerABR.paste(agradientT, (0, 0)) - alphaLayerABR.paste(agradientC.resize( - (overlap_size_x, overlap_size_y)), (0, 0)) - alphaLayerABB = Image.new("L", (width, height), 255) - alphaLayerABB.paste(alphaLayerRTC, (0, 0)) - alphaLayerABB.paste(agradientL, (0, 0)) - alphaLayerABB.paste(agradientC.resize( - (overlap_size_x, overlap_size_y)), (0, 0)) - - # All-around layer - alphaLayerAA = Image.new("L", (width, height), 255) - alphaLayerAA.paste(alphaLayerABT, (0, 0)) - alphaLayerAA.paste(agradientT, (0, 0)) - alphaLayerAA.paste(agradientC.resize( - (overlap_size_x, overlap_size_y)), (0, 0)) - alphaLayerAA.paste(agradientC.rotate(270).resize( - (overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0)) - - # Clean up temporary gradients - del agradientL - del agradientT - del agradientC - - def make_image(): - # Make main tiles ------------------------------------------------- - if embiggen_tiles: - print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...') - else: - print( - f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...') - - emb_tile_store = [] - # Although we could use the same seed for every tile for determinism, at higher strengths this may - # produce duplicated structures for each tile and make the tiling effect more obvious - # instead track and iterate a local seed we pass to Img2Img - seed = self.seed - seedintlimit = np.iinfo(np.uint32).max - 1 # only retreive this one from numpy - - for tile in range(emb_tiles_x * emb_tiles_y): - # Don't iterate on first tile - if tile != 0: - if seed < seedintlimit: - seed += 1 - else: - seed = 0 - - # Determine if this is a re-run and replace - if embiggen_tiles and not tile in embiggen_tiles: - continue - # Get row and column entries - emb_row_i = tile // emb_tiles_x - emb_column_i = tile % emb_tiles_x - # Determine bounds to cut up the init image - # Determine upper-left point - if emb_column_i + 1 == emb_tiles_x: - left = initsuperwidth - width - else: - left = round(emb_column_i * (width - overlap_size_x)) - if emb_row_i + 1 == emb_tiles_y: - top = initsuperheight - height - else: - top = round(emb_row_i * (height - overlap_size_y)) - right = left + width - bottom = top + height - - # Cropped image of above dimension (does not modify the original) - newinitimage = initsuperimage.crop((left, top, right, bottom)) - # DEBUG: - # newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png' - # newinitimage.save(newinitimagepath) - - if embiggen_tiles: - print( - f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)') - else: - print( - f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles') - - # create a torch tensor from an Image - newinitimage = np.array( - newinitimage).astype(np.float32) / 255.0 - newinitimage = newinitimage[None].transpose(0, 3, 1, 2) - newinitimage = torch.from_numpy(newinitimage) - newinitimage = 2.0 * newinitimage - 1.0 - newinitimage = newinitimage.to(self.model.device) - clear_cuda_cache = kwargs['clear_cuda_cache'] if 'clear_cuda_cache' in kwargs else None - - tile_results = gen_img2img.generate( - prompt, - iterations = 1, - seed = seed, - sampler = sampler, - steps = steps, - cfg_scale = cfg_scale, - conditioning = conditioning, - ddim_eta = ddim_eta, - image_callback = None, # called only after the final image is generated - step_callback = step_callback, # called after each intermediate image is generated - width = width, - height = height, - init_image = newinitimage, # notice that init_image is different from init_img - mask_image = None, - strength = strength, - clear_cuda_cache = clear_cuda_cache - ) - - emb_tile_store.append(tile_results[0][0]) - # DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite - # emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png') - del newinitimage - - # Sanity check we have them all - if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)): - outputsuperimage = Image.new( - "RGBA", (initsuperwidth, initsuperheight)) - if embiggen_tiles: - outputsuperimage.alpha_composite( - initsuperimage.convert('RGBA'), (0, 0)) - for tile in range(emb_tiles_x * emb_tiles_y): - if embiggen_tiles: - if tile in embiggen_tiles: - intileimage = emb_tile_store.pop(0) - else: - continue - else: - intileimage = emb_tile_store[tile] - intileimage = intileimage.convert('RGBA') - # Get row and column entries - emb_row_i = tile // emb_tiles_x - emb_column_i = tile % emb_tiles_x - if emb_row_i == 0 and emb_column_i == 0 and not embiggen_tiles: - left = 0 - top = 0 - else: - # Determine upper-left point - if emb_column_i + 1 == emb_tiles_x: - left = initsuperwidth - width - else: - left = round(emb_column_i * - (width - overlap_size_x)) - if emb_row_i + 1 == emb_tiles_y: - top = initsuperheight - height - else: - top = round(emb_row_i * (height - overlap_size_y)) - # Handle gradients for various conditions - # Handle emb_rerun case - if embiggen_tiles: - # top of image - if emb_row_i == 0: - if emb_column_i == 0: - if (tile+1) in embiggen_tiles: # Look-ahead right - if (tile+emb_tiles_x) not in embiggen_tiles: # Look-ahead down - intileimage.putalpha(alphaLayerB) - # Otherwise do nothing on this tile - elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only - intileimage.putalpha(alphaLayerR) - else: - intileimage.putalpha(alphaLayerRBC) - elif emb_column_i == emb_tiles_x - 1: - if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down - intileimage.putalpha(alphaLayerL) - else: - intileimage.putalpha(alphaLayerLBC) - else: - if (tile+1) in embiggen_tiles: # Look-ahead right - if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down - intileimage.putalpha(alphaLayerL) - else: - intileimage.putalpha(alphaLayerLBC) - elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only - intileimage.putalpha(alphaLayerLR) - else: - intileimage.putalpha(alphaLayerABT) - # bottom of image - elif emb_row_i == emb_tiles_y - 1: - if emb_column_i == 0: - if (tile+1) in embiggen_tiles: # Look-ahead right - intileimage.putalpha(alphaLayerTaC) - else: - intileimage.putalpha(alphaLayerRTC) - elif emb_column_i == emb_tiles_x - 1: - # No tiles to look ahead to - intileimage.putalpha(alphaLayerLTC) - else: - if (tile+1) in embiggen_tiles: # Look-ahead right - intileimage.putalpha(alphaLayerLTaC) - else: - intileimage.putalpha(alphaLayerABB) - # vertical middle of image - else: - if emb_column_i == 0: - if (tile+1) in embiggen_tiles: # Look-ahead right - if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down - intileimage.putalpha(alphaLayerTaC) - else: - intileimage.putalpha(alphaLayerTB) - elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only - intileimage.putalpha(alphaLayerRTC) - else: - intileimage.putalpha(alphaLayerABL) - elif emb_column_i == emb_tiles_x - 1: - if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down - intileimage.putalpha(alphaLayerLTC) - else: - intileimage.putalpha(alphaLayerABR) - else: - if (tile+1) in embiggen_tiles: # Look-ahead right - if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down - intileimage.putalpha(alphaLayerLTaC) - else: - intileimage.putalpha(alphaLayerABR) - elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only - intileimage.putalpha(alphaLayerABB) - else: - intileimage.putalpha(alphaLayerAA) - # Handle normal tiling case (much simpler - since we tile left to right, top to bottom) - else: - if emb_row_i == 0 and emb_column_i >= 1: - intileimage.putalpha(alphaLayerL) - elif emb_row_i >= 1 and emb_column_i == 0: - if emb_column_i + 1 == emb_tiles_x: # If we don't have anything that can be placed to the right - intileimage.putalpha(alphaLayerT) - else: - intileimage.putalpha(alphaLayerTaC) - else: - if emb_column_i + 1 == emb_tiles_x: # If we don't have anything that can be placed to the right - intileimage.putalpha(alphaLayerLTC) - else: - intileimage.putalpha(alphaLayerLTaC) - # Layer tile onto final image - outputsuperimage.alpha_composite(intileimage, (left, top)) - else: - print('Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.') - - # after internal loops and patching up return Embiggen image - return outputsuperimage - # end of function declaration - return make_image diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py deleted file mode 100644 index 67a588234b..0000000000 --- a/ldm/invoke/generator/img2img.py +++ /dev/null @@ -1,70 +0,0 @@ -''' -ldm.invoke.generator.img2img descends from ldm.invoke.generator -''' - -import torch -from diffusers import logging - -from ldm.invoke.generator.base import Generator -from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData -from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings - - -class Img2Img(Generator): - def __init__(self, model, precision): - super().__init__(model, precision) - self.init_latent = None # by get_noise() - - def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, - conditioning,init_image,strength,step_callback=None,threshold=0.0,warmup=0.2,perlin=0.0, - h_symmetry_time_pct=None,v_symmetry_time_pct=None,attention_maps_callback=None, - **kwargs): - """ - Returns a function returning an image derived from the prompt and the initial image - Return value depends on the seed at the time you call it. - """ - self.perlin = perlin - - # noinspection PyTypeChecker - pipeline: StableDiffusionGeneratorPipeline = self.model - pipeline.scheduler = sampler - - uc, c, extra_conditioning_info = conditioning - conditioning_data = ( - ConditioningData( - uc, c, cfg_scale, extra_conditioning_info, - postprocessing_settings=PostprocessingSettings( - threshold=threshold, - warmup=warmup, - h_symmetry_time_pct=h_symmetry_time_pct, - v_symmetry_time_pct=v_symmetry_time_pct - ) - ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)) - - - def make_image(x_T): - # FIXME: use x_T for initial seeded noise - # We're not at the moment because the pipeline automatically resizes init_image if - # necessary, which the x_T input might not match. - logging.set_verbosity_error() # quench safety check warnings - pipeline_output = pipeline.img2img_from_embeddings( - init_image, strength, steps, conditioning_data, - noise_func=self.get_noise_like, - callback=step_callback - ) - if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None: - attention_maps_callback(pipeline_output.attention_map_saver) - return pipeline.numpy_to_pil(pipeline_output.images)[0] - - return make_image - - def get_noise_like(self, like: torch.Tensor): - device = like.device - if device.type == 'mps': - x = torch.randn_like(like, device='cpu').to(device) - else: - x = torch.randn_like(like, device=device) - if self.perlin > 0.0: - shape = like.shape - x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2]) - return x diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py deleted file mode 100644 index 61c5b56582..0000000000 --- a/ldm/invoke/generator/inpaint.py +++ /dev/null @@ -1,324 +0,0 @@ -''' -ldm.invoke.generator.inpaint descends from ldm.invoke.generator -''' -from __future__ import annotations - -import math - -import PIL -import cv2 -import numpy as np -import torch -from PIL import Image, ImageFilter, ImageOps, ImageChops - -from ldm.invoke.generator.diffusers_pipeline import image_resized_to_grid_as_tensor, StableDiffusionGeneratorPipeline, \ - ConditioningData -from ldm.invoke.generator.img2img import Img2Img -from ldm.invoke.patchmatch import PatchMatch -from ldm.util import debug_image - - -def infill_methods()->list[str]: - methods = [ - "tile", - "solid", - ] - if PatchMatch.patchmatch_available(): - methods.insert(0, 'patchmatch') - return methods - -class Inpaint(Img2Img): - def __init__(self, model, precision): - self.inpaint_height = 0 - self.inpaint_width = 0 - self.enable_image_debugging = False - self.init_latent = None - self.pil_image = None - self.pil_mask = None - self.mask_blur_radius = 0 - self.infill_method = None - super().__init__(model, precision) - - # Outpaint support code - def get_tile_images(self, image: np.ndarray, width=8, height=8): - _nrows, _ncols, depth = image.shape - _strides = image.strides - - nrows, _m = divmod(_nrows, height) - ncols, _n = divmod(_ncols, width) - if _m != 0 or _n != 0: - return None - - return np.lib.stride_tricks.as_strided( - np.ravel(image), - shape=(nrows, ncols, height, width, depth), - strides=(height * _strides[0], width * _strides[1], *_strides), - writeable=False - ) - - def infill_patchmatch(self, im: Image.Image) -> Image: - if im.mode != 'RGBA': - return im - - # Skip patchmatch if patchmatch isn't available - if not PatchMatch.patchmatch_available(): - return im - - # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though) - im_patched_np = PatchMatch.inpaint(im.convert('RGB'), ImageOps.invert(im.split()[-1]), patch_size = 3) - im_patched = Image.fromarray(im_patched_np, mode = 'RGB') - return im_patched - - def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image: - # Only fill if there's an alpha layer - if im.mode != 'RGBA': - return im - - a = np.asarray(im, dtype=np.uint8) - - tile_size = (tile_size, tile_size) - - # Get the image as tiles of a specified size - tiles = self.get_tile_images(a,*tile_size).copy() - - # Get the mask as tiles - tiles_mask = tiles[:,:,:,:,3] - - # Find any mask tiles with any fully transparent pixels (we will be replacing these later) - tmask_shape = tiles_mask.shape - tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape)) - n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:]) - tiles_mask = (tiles_mask > 0) - tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1) - - # Get RGB tiles in single array and filter by the mask - tshape = tiles.shape - tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:])) - filtered_tiles = tiles_all[tiles_mask] - - if len(filtered_tiles) == 0: - return im - - # Find all invalid tiles and replace with a random valid tile - replace_count = (tiles_mask == False).sum() - rng = np.random.default_rng(seed = seed) - tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:] - - # Convert back to an image - tiles_all = tiles_all.reshape(tshape) - tiles_all = tiles_all.swapaxes(1,2) - st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4])) - si = Image.fromarray(st, mode='RGBA') - - return si - - - def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image: - npimg = np.asarray(mask, dtype=np.uint8) - - # Detect any partially transparent regions - npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0))) - - # Detect hard edges - npedge = cv2.Canny(npimg, threshold1=100, threshold2=200) - - # Combine - npmask = npgradient + npedge - - # Expand - npmask = cv2.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2)) - - new_mask = Image.fromarray(npmask) - - if edge_blur > 0: - new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur)) - - return ImageOps.invert(new_mask) - - - def seam_paint(self, im: Image.Image, seam_size: int, seam_blur: int, prompt, sampler, steps, cfg_scale, ddim_eta, - conditioning, strength, noise, infill_method, step_callback) -> Image.Image: - hard_mask = self.pil_image.split()[-1].copy() - mask = self.mask_edge(hard_mask, seam_size, seam_blur) - - make_image = self.get_make_image( - prompt, - sampler, - steps, - cfg_scale, - ddim_eta, - conditioning, - init_image = im.copy().convert('RGBA'), - mask_image = mask, - strength = strength, - mask_blur_radius = 0, - seam_size = 0, - step_callback = step_callback, - inpaint_width = im.width, - inpaint_height = im.height, - infill_method = infill_method - ) - - seam_noise = self.get_noise(im.width, im.height) - - result = make_image(seam_noise) - - return result - - - @torch.no_grad() - def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, - conditioning, - init_image: PIL.Image.Image | torch.FloatTensor, - mask_image: PIL.Image.Image | torch.FloatTensor, - strength: float, - mask_blur_radius: int = 8, - # Seam settings - when 0, doesn't fill seam - seam_size: int = 0, - seam_blur: int = 0, - seam_strength: float = 0.7, - seam_steps: int = 10, - tile_size: int = 32, - step_callback=None, - inpaint_replace=False, enable_image_debugging=False, - infill_method = None, - inpaint_width=None, - inpaint_height=None, - inpaint_fill:tuple(int)=(0x7F, 0x7F, 0x7F, 0xFF), - attention_maps_callback=None, - **kwargs): - """ - Returns a function returning an image derived from the prompt and - the initial image + mask. Return value depends on the seed at - the time you call it. kwargs are 'init_latent' and 'strength' - """ - - self.enable_image_debugging = enable_image_debugging - infill_method = infill_method or infill_methods()[0] - self.infill_method = infill_method - - self.inpaint_width = inpaint_width - self.inpaint_height = inpaint_height - - if isinstance(init_image, PIL.Image.Image): - self.pil_image = init_image.copy() - - # Do infill - if infill_method == 'patchmatch' and PatchMatch.patchmatch_available(): - init_filled = self.infill_patchmatch(self.pil_image.copy()) - elif infill_method == 'tile': - init_filled = self.tile_fill_missing( - self.pil_image.copy(), - seed = self.seed, - tile_size = tile_size - ) - elif infill_method == 'solid': - solid_bg = PIL.Image.new("RGBA", init_image.size, inpaint_fill) - init_filled = PIL.Image.alpha_composite(solid_bg, init_image) - else: - raise ValueError(f"Non-supported infill type {infill_method}", infill_method) - init_filled.paste(init_image, (0,0), init_image.split()[-1]) - - # Resize if requested for inpainting - if inpaint_width and inpaint_height: - init_filled = init_filled.resize((inpaint_width, inpaint_height)) - - debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging) - - # Create init tensor - init_image = image_resized_to_grid_as_tensor(init_filled.convert('RGB')) - - if isinstance(mask_image, PIL.Image.Image): - self.pil_mask = mask_image.copy() - debug_image(mask_image, "mask_image BEFORE multiply with pil_image", debug_status=self.enable_image_debugging) - - init_alpha = self.pil_image.getchannel("A") - if mask_image.mode != "L": - # FIXME: why do we get passed an RGB image here? We can only use single-channel. - mask_image = mask_image.convert("L") - mask_image = ImageChops.multiply(mask_image, init_alpha) - self.pil_mask = mask_image - - # Resize if requested for inpainting - if inpaint_width and inpaint_height: - mask_image = mask_image.resize((inpaint_width, inpaint_height)) - - debug_image(mask_image, "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging) - mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) - else: - mask: torch.FloatTensor = mask_image - - self.mask_blur_radius = mask_blur_radius - - # noinspection PyTypeChecker - pipeline: StableDiffusionGeneratorPipeline = self.model - pipeline.scheduler = sampler - - # todo: support cross-attention control - uc, c, _ = conditioning - conditioning_data = (ConditioningData(uc, c, cfg_scale) - .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)) - - - def make_image(x_T): - pipeline_output = pipeline.inpaint_from_embeddings( - init_image=init_image, - mask=1 - mask, # expects white means "paint here." - strength=strength, - num_inference_steps=steps, - conditioning_data=conditioning_data, - noise_func=self.get_noise_like, - callback=step_callback, - ) - - if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None: - attention_maps_callback(pipeline_output.attention_map_saver) - - result = self.postprocess_size_and_mask(pipeline.numpy_to_pil(pipeline_output.images)[0]) - - # Seam paint if this is our first pass (seam_size set to 0 during seam painting) - if seam_size > 0: - old_image = self.pil_image or init_image - old_mask = self.pil_mask or mask_image - - result = self.seam_paint(result, seam_size, seam_blur, prompt, sampler, seam_steps, cfg_scale, ddim_eta, - conditioning, seam_strength, x_T, infill_method, step_callback) - - # Restore original settings - self.get_make_image(prompt,sampler,steps,cfg_scale,ddim_eta, - conditioning, - old_image, - old_mask, - strength, - mask_blur_radius, seam_size, seam_blur, seam_strength, - seam_steps, tile_size, step_callback, - inpaint_replace, enable_image_debugging, - inpaint_width = inpaint_width, - inpaint_height = inpaint_height, - infill_method = infill_method, - **kwargs) - - return result - - return make_image - - - def sample_to_image(self, samples)->Image.Image: - gen_result = super().sample_to_image(samples).convert('RGB') - return self.postprocess_size_and_mask(gen_result) - - - def postprocess_size_and_mask(self, gen_result: Image.Image) -> Image.Image: - debug_image(gen_result, "gen_result", debug_status=self.enable_image_debugging) - - # Resize if necessary - if self.inpaint_width and self.inpaint_height: - gen_result = gen_result.resize(self.pil_image.size) - - if self.pil_image is None or self.pil_mask is None: - return gen_result - - corrected_result = self.repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius) - debug_image(corrected_result, "corrected_result", debug_status=self.enable_image_debugging) - - return corrected_result diff --git a/ldm/invoke/generator/omnibus.py b/ldm/invoke/generator/omnibus.py deleted file mode 100644 index a6fae3e567..0000000000 --- a/ldm/invoke/generator/omnibus.py +++ /dev/null @@ -1,173 +0,0 @@ -"""omnibus module to be used with the runwayml 9-channel custom inpainting model""" - -import torch -from PIL import Image, ImageOps -from einops import repeat - -from ldm.invoke.devices import choose_autocast -from ldm.invoke.generator.img2img import Img2Img -from ldm.invoke.generator.txt2img import Txt2Img - - -class Omnibus(Img2Img,Txt2Img): - def __init__(self, model, precision): - super().__init__(model, precision) - self.pil_mask = None - self.pil_image = None - - def get_make_image( - self, - prompt, - sampler, - steps, - cfg_scale, - ddim_eta, - conditioning, - width, - height, - init_image = None, - mask_image = None, - strength = None, - step_callback=None, - threshold=0.0, - perlin=0.0, - mask_blur_radius: int = 8, - **kwargs): - """ - Returns a function returning an image derived from the prompt and the initial image - Return value depends on the seed at the time you call it. - """ - self.perlin = perlin - num_samples = 1 - - sampler.make_schedule( - ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False - ) - - if isinstance(init_image, Image.Image): - self.pil_image = init_image - if init_image.mode != 'RGB': - init_image = init_image.convert('RGB') - init_image = self._image_to_tensor(init_image) - - if isinstance(mask_image, Image.Image): - self.pil_mask = mask_image - - mask_image = ImageChops.multiply(mask_image.convert('L'), self.pil_image.split()[-1]) - mask_image = self._image_to_tensor(ImageOps.invert(mask_image), normalize=False) - - self.mask_blur_radius = mask_blur_radius - - if init_image is not None and mask_image is not None: # inpainting - masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero - - elif init_image is not None: # img2img - scope = choose_autocast(self.precision) - - with scope(self.model.device.type): - self.init_latent = self.model.get_first_stage_encoding( - self.model.encode_first_stage(init_image) - ) # move to latent space - - # create a completely black mask (1s) - mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device) - # and the masked image is just a copy of the original - masked_image = init_image - - else: # txt2img - init_image = torch.zeros(1, 3, height, width, device=self.model.device) - mask_image = torch.ones(1, 1, height, width, device=self.model.device) - masked_image = init_image - - self.init_latent = init_image - height = init_image.shape[2] - width = init_image.shape[3] - model = self.model - - def make_image(x_T): - with torch.no_grad(): - scope = choose_autocast(self.precision) - with scope(self.model.device.type): - - batch = self.make_batch_sd( - init_image, - mask_image, - masked_image, - prompt=prompt, - device=model.device, - num_samples=num_samples, - ) - - c = model.cond_stage_model.encode(batch["txt"]) - c_cat = list() - for ck in model.concat_keys: - cc = batch[ck].float() - if ck != model.masked_image_key: - bchw = [num_samples, 4, height//8, width//8] - cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) - else: - cc = model.get_first_stage_encoding(model.encode_first_stage(cc)) - c_cat.append(cc) - c_cat = torch.cat(c_cat, dim=1) - - # cond - cond={"c_concat": [c_cat], "c_crossattn": [c]} - - # uncond cond - uc_cross = model.get_unconditional_conditioning(num_samples, "") - uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]} - shape = [model.channels, height//8, width//8] - - samples, _ = sampler.sample( - batch_size = 1, - S = steps, - x_T = x_T, - conditioning = cond, - shape = shape, - verbose = False, - unconditional_guidance_scale = cfg_scale, - unconditional_conditioning = uc_full, - eta = 1.0, - img_callback = step_callback, - threshold = threshold, - ) - if self.free_gpu_mem: - self.model.model.to("cpu") - return self.sample_to_image(samples) - - return make_image - - def make_batch_sd( - self, - image, - mask, - masked_image, - prompt, - device, - num_samples=1): - batch = { - "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples), - "txt": num_samples * [prompt], - "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples), - "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples), - } - return batch - - def get_noise(self, width:int, height:int): - if self.init_latent is not None: - height = self.init_latent.shape[2] - width = self.init_latent.shape[3] - return Txt2Img.get_noise(self,width,height) - - - def sample_to_image(self, samples)->Image.Image: - gen_result = super().sample_to_image(samples).convert('RGB') - - if self.pil_image is None or self.pil_mask is None: - return gen_result - if self.pil_image.size != self.pil_mask.size: - return gen_result - - corrected_result = super(Img2Img, self).repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius) - - return corrected_result diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py deleted file mode 100644 index 9903de1309..0000000000 --- a/ldm/invoke/generator/txt2img.py +++ /dev/null @@ -1,61 +0,0 @@ -''' -ldm.invoke.generator.txt2img inherits from ldm.invoke.generator -''' -import PIL.Image -import torch - -from .base import Generator -from .diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData -from ...models.diffusion.shared_invokeai_diffusion import PostprocessingSettings - - -class Txt2Img(Generator): - def __init__(self, model, precision): - super().__init__(model, precision) - - @torch.no_grad() - def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, - conditioning,width,height,step_callback=None,threshold=0.0,warmup=0.2,perlin=0.0, - h_symmetry_time_pct=None,v_symmetry_time_pct=None,attention_maps_callback=None, - **kwargs): - """ - Returns a function returning an image derived from the prompt and the initial image - Return value depends on the seed at the time you call it - kwargs are 'width' and 'height' - """ - self.perlin = perlin - - # noinspection PyTypeChecker - pipeline: StableDiffusionGeneratorPipeline = self.model - pipeline.scheduler = sampler - - uc, c, extra_conditioning_info = conditioning - conditioning_data = ( - ConditioningData( - uc, c, cfg_scale, extra_conditioning_info, - postprocessing_settings=PostprocessingSettings( - threshold=threshold, - warmup=warmup, - h_symmetry_time_pct=h_symmetry_time_pct, - v_symmetry_time_pct=v_symmetry_time_pct - ) - ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)) - - def make_image(x_T) -> PIL.Image.Image: - pipeline_output = pipeline.image_from_embeddings( - latents=torch.zeros_like(x_T,dtype=self.torch_dtype()), - noise=x_T, - num_inference_steps=steps, - conditioning_data=conditioning_data, - callback=step_callback, - ) - - if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None: - attention_maps_callback(pipeline_output.attention_map_saver) - - return pipeline.numpy_to_pil(pipeline_output.images)[0] - - return make_image - - - diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py deleted file mode 100644 index a39dfccc3a..0000000000 --- a/ldm/invoke/generator/txt2img2img.py +++ /dev/null @@ -1,163 +0,0 @@ -''' -ldm.invoke.generator.txt2img inherits from ldm.invoke.generator -''' - -import math -from typing import Callable, Optional - -import torch -from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error - -from ldm.invoke.generator.base import Generator -from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \ - ConditioningData -from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings - - -class Txt2Img2Img(Generator): - def __init__(self, model, precision): - super().__init__(model, precision) - self.init_latent = None # for get_noise() - - def get_make_image(self, prompt:str, sampler, steps:int, cfg_scale:float, ddim_eta, - conditioning, width:int, height:int, strength:float, - step_callback:Optional[Callable]=None, threshold=0.0, warmup=0.2, perlin=0.0, - h_symmetry_time_pct=None, v_symmetry_time_pct=None, attention_maps_callback=None, **kwargs): - """ - Returns a function returning an image derived from the prompt and the initial image - Return value depends on the seed at the time you call it - kwargs are 'width' and 'height' - """ - self.perlin = perlin - - # noinspection PyTypeChecker - pipeline: StableDiffusionGeneratorPipeline = self.model - pipeline.scheduler = sampler - - uc, c, extra_conditioning_info = conditioning - conditioning_data = ( - ConditioningData( - uc, c, cfg_scale, extra_conditioning_info, - postprocessing_settings = PostprocessingSettings( - threshold=threshold, - warmup=0.2, - h_symmetry_time_pct=h_symmetry_time_pct, - v_symmetry_time_pct=v_symmetry_time_pct - ) - ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)) - - def make_image(x_T): - - first_pass_latent_output, _ = pipeline.latents_from_embeddings( - latents=torch.zeros_like(x_T), - num_inference_steps=steps, - conditioning_data=conditioning_data, - noise=x_T, - callback=step_callback, - ) - - # Get our initial generation width and height directly from the latent output so - # the message below is accurate. - init_width = first_pass_latent_output.size()[3] * self.downsampling_factor - init_height = first_pass_latent_output.size()[2] * self.downsampling_factor - print( - f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling" - ) - - # resizing - resized_latents = torch.nn.functional.interpolate( - first_pass_latent_output, - size=(height // self.downsampling_factor, width // self.downsampling_factor), - mode="bilinear" - ) - - # Free up memory from the last generation. - clear_cuda_cache = kwargs['clear_cuda_cache'] or None - if clear_cuda_cache is not None: - clear_cuda_cache() - - second_pass_noise = self.get_noise_like(resized_latents, override_perlin=True) - - # Clear symmetry for the second pass - from dataclasses import replace - new_postprocessing_settings = replace(conditioning_data.postprocessing_settings, h_symmetry_time_pct=None) - new_postprocessing_settings = replace(new_postprocessing_settings, v_symmetry_time_pct=None) - new_conditioning_data = replace(conditioning_data, postprocessing_settings=new_postprocessing_settings) - - verbosity = get_verbosity() - set_verbosity_error() - pipeline_output = pipeline.img2img_from_latents_and_embeddings( - resized_latents, - num_inference_steps=steps, - conditioning_data=new_conditioning_data, - strength=strength, - noise=second_pass_noise, - callback=step_callback) - set_verbosity(verbosity) - - if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None: - attention_maps_callback(pipeline_output.attention_map_saver) - - return pipeline.numpy_to_pil(pipeline_output.images)[0] - - - # FIXME: do we really need something entirely different for the inpainting model? - - # in the case of the inpainting model being loaded, the trick of - # providing an interpolated latent doesn't work, so we transiently - # create a 512x512 PIL image, upscale it, and run the inpainting - # over it in img2img mode. Because the inpaing model is so conservative - # it doesn't change the image (much) - - return make_image - - def get_noise_like(self, like: torch.Tensor, override_perlin: bool=False): - device = like.device - if device.type == 'mps': - x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device) - else: - x = torch.randn_like(like, device=device, dtype=self.torch_dtype()) - if self.perlin > 0.0 and override_perlin == False: - shape = like.shape - x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2]) - return x - - # returns a tensor filled with random numbers from a normal distribution - def get_noise(self,width,height,scale = True): - # print(f"Get noise: {width}x{height}") - if scale: - # Scale the input width and height for the initial generation - # Make their area equivalent to the model's resolution area (e.g. 512*512 = 262144), - # while keeping the minimum dimension at least 0.5 * resolution (e.g. 512*0.5 = 256) - - aspect = width / height - dimension = self.model.unet.config.sample_size * self.model.vae_scale_factor - min_dimension = math.floor(dimension * 0.5) - model_area = dimension * dimension # hardcoded for now since all models are trained on square images - - if aspect > 1.0: - init_height = max(min_dimension, math.sqrt(model_area / aspect)) - init_width = init_height * aspect - else: - init_width = max(min_dimension, math.sqrt(model_area * aspect)) - init_height = init_width / aspect - - scaled_width, scaled_height = trim_to_multiple_of(math.floor(init_width), math.floor(init_height)) - - else: - scaled_width = width - scaled_height = height - - device = self.model.device - channels = self.latent_channels - if channels == 9: - channels = 4 # we don't really want noise for all the mask channels - shape = (1, channels, - scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor) - if self.use_mps_noise or device.type == 'mps': - tensor = torch.empty(size=shape, device='cpu') - tensor = self.get_noise_like(like=tensor).to(device) - else: - tensor = torch.empty(size=shape, device=device) - tensor = self.get_noise_like(like=tensor) - return tensor diff --git a/ldm/invoke/globals.py b/ldm/invoke/globals.py index e47b5c059e..c6ee0bbc54 100644 --- a/ldm/invoke/globals.py +++ b/ldm/invoke/globals.py @@ -61,7 +61,7 @@ Globals.sequential_guidance = False Globals.full_precision = False # whether we should convert ckpt files into diffusers models on the fly -Globals.ckpt_convert = False +Globals.ckpt_convert = True # logging tokenization everywhere Globals.log_tokenization = False diff --git a/ldm/invoke/merge_diffusers.py b/ldm/invoke/merge_diffusers.py index 3cb3613ee3..5c100fcf8b 100644 --- a/ldm/invoke/merge_diffusers.py +++ b/ldm/invoke/merge_diffusers.py @@ -23,7 +23,7 @@ from omegaconf import OmegaConf from ldm.invoke.config.widgets import FloatTitleSlider from ldm.invoke.globals import (Globals, global_cache_dir, global_config_file, global_models_dir, global_set_root) -from ldm.invoke.model_manager import ModelManager +from invokeai.models import ModelManager DEST_MERGED_MODEL_DIR = "merged_models" diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py deleted file mode 100644 index 694d65c1a7..0000000000 --- a/ldm/invoke/model_manager.py +++ /dev/null @@ -1,1372 +0,0 @@ -""" -Manage a cache of Stable Diffusion model files for fast switching. -They are moved between GPU and CPU as necessary. If CPU memory falls -below a preset minimum, the least recently used model will be -cleared and loaded from disk when next needed. -""" -from __future__ import annotations - -import contextlib -import gc -import hashlib -import io -import os -import re -import sys -import textwrap -import time -import warnings -from enum import Enum -from pathlib import Path -from shutil import move, rmtree -from typing import Any, Optional, Union - -import safetensors -import safetensors.torch -import torch -import transformers -from diffusers import AutoencoderKL -from diffusers import logging as dlogging -from huggingface_hub import scan_cache_dir -from omegaconf import OmegaConf -from omegaconf.dictconfig import DictConfig -from picklescan.scanner import scan_file_path - -from ldm.invoke.devices import CPU_DEVICE -from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline -from ldm.invoke.globals import Globals, global_cache_dir -from ldm.util import ( - ask_user, - download_with_resume, - instantiate_from_config, - url_attachment_name, -) - - -class SDLegacyType(Enum): - V1 = 1 - V1_INPAINT = 2 - V2 = 3 - UNKNOWN = 99 - - -DEFAULT_MAX_MODELS = 2 -VAE_TO_REPO_ID = { # hack, see note in convert_and_import() - "vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse", -} - - -class ModelManager(object): - def __init__( - self, - config: OmegaConf, - device_type: torch.device = CPU_DEVICE, - precision: str = "float16", - max_loaded_models=DEFAULT_MAX_MODELS, - sequential_offload=False, - ): - """ - Initialize with the path to the models.yaml config file, - the torch device type, and precision. The optional - min_avail_mem argument specifies how much unused system - (CPU) memory to preserve. The cache of models in RAM will - grow until this value is approached. Default is 2G. - """ - # prevent nasty-looking CLIP log message - transformers.logging.set_verbosity_error() - self.config = config - self.precision = precision - self.device = torch.device(device_type) - self.max_loaded_models = max_loaded_models - self.models = {} - self.stack = [] # this is an LRU FIFO - self.current_model = None - self.sequential_offload = sequential_offload - - def valid_model(self, model_name: str) -> bool: - """ - Given a model name, returns True if it is a valid - identifier. - """ - return model_name in self.config - - def get_model(self, model_name: str): - """ - Given a model named identified in models.yaml, return - the model object. If in RAM will load into GPU VRAM. - If on disk, will load from there. - """ - if not self.valid_model(model_name): - print( - f'** "{model_name}" is not a known model name. Please check your models.yaml file' - ) - return self.current_model - - if self.current_model != model_name: - if model_name not in self.models: # make room for a new one - self._make_cache_room() - self.offload_model(self.current_model) - - if model_name in self.models: - requested_model = self.models[model_name]["model"] - print(f">> Retrieving model {model_name} from system RAM cache") - self.models[model_name]["model"] = self._model_from_cpu(requested_model) - width = self.models[model_name]["width"] - height = self.models[model_name]["height"] - hash = self.models[model_name]["hash"] - - else: # we're about to load a new model, so potentially offload the least recently used one - requested_model, width, height, hash = self._load_model(model_name) - self.models[model_name] = { - "model": requested_model, - "width": width, - "height": height, - "hash": hash, - } - - self.current_model = model_name - self._push_newest_model(model_name) - return { - "model": requested_model, - "width": width, - "height": height, - "hash": hash, - } - - def default_model(self) -> str | None: - """ - Returns the name of the default model, or None - if none is defined. - """ - for model_name in self.config: - if self.config[model_name].get("default"): - return model_name - return list(self.config.keys())[0] # first one - - def set_default_model(self, model_name: str) -> None: - """ - Set the default model. The change will not take - effect until you call model_manager.commit() - """ - assert model_name in self.model_names(), f"unknown model '{model_name}'" - - config = self.config - for model in config: - config[model].pop("default", None) - config[model_name]["default"] = True - - def model_info(self, model_name: str) -> dict: - """ - Given a model name returns the OmegaConf (dict-like) object describing it. - """ - if model_name not in self.config: - return None - return self.config[model_name] - - def model_names(self) -> list[str]: - """ - Return a list consisting of all the names of models defined in models.yaml - """ - return list(self.config.keys()) - - def is_legacy(self, model_name: str) -> bool: - """ - Return true if this is a legacy (.ckpt) model - """ - # if we are converting legacy files automatically, then - # there are no legacy ckpts! - if Globals.ckpt_convert: - return False - info = self.model_info(model_name) - if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")): - return True - return False - - def list_models(self) -> dict: - """ - Return a dict of models in the format: - { model_name1: {'status': ('active'|'cached'|'not loaded'), - 'description': description, - 'format': ('ckpt'|'diffusers'|'vae'), - }, - model_name2: { etc } - Please use model_manager.models() to get all the model names, - model_manager.model_info('model-name') to get the stanza for the model - named 'model-name', and model_manager.config to get the full OmegaConf - object derived from models.yaml - """ - models = {} - for name in sorted(self.config, key=str.casefold): - stanza = self.config[name] - - # don't include VAEs in listing (legacy style) - if "config" in stanza and "/VAE/" in stanza["config"]: - continue - - models[name] = dict() - format = stanza.get("format", "ckpt") # Determine Format - - # Common Attribs - description = stanza.get("description", None) - if self.current_model == name: - status = "active" - elif name in self.models: - status = "cached" - else: - status = "not loaded" - models[name].update( - description=description, - format=format, - status=status, - ) - - # Checkpoint Config Parse - if format == "ckpt": - models[name].update( - config=str(stanza.get("config", None)), - weights=str(stanza.get("weights", None)), - vae=str(stanza.get("vae", None)), - width=str(stanza.get("width", 512)), - height=str(stanza.get("height", 512)), - ) - - # Diffusers Config Parse - if vae := stanza.get("vae", None): - if isinstance(vae, DictConfig): - vae = dict( - repo_id=str(vae.get("repo_id", None)), - path=str(vae.get("path", None)), - subfolder=str(vae.get("subfolder", None)), - ) - - if format == "diffusers": - models[name].update( - vae=vae, - repo_id=str(stanza.get("repo_id", None)), - path=str(stanza.get("path", None)), - ) - - return models - - def print_models(self) -> None: - """ - Print a table of models, their descriptions, and load status - """ - models = self.list_models() - for name in models: - if models[name]["format"] == "vae": - continue - line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["format"]:10s} {models[name]["description"]}' - if models[name]["status"] == "active": - line = f"\033[1m{line}\033[0m" - print(line) - - def del_model(self, model_name: str, delete_files: bool = False) -> None: - """ - Delete the named model. - """ - omega = self.config - if model_name not in omega: - print(f"** Unknown model {model_name}") - return - # save these for use in deletion later - conf = omega[model_name] - repo_id = conf.get("repo_id", None) - path = self._abs_path(conf.get("path", None)) - weights = self._abs_path(conf.get("weights", None)) - - del omega[model_name] - if model_name in self.stack: - self.stack.remove(model_name) - if delete_files: - if weights: - print(f"** deleting file {weights}") - Path(weights).unlink(missing_ok=True) - elif path: - print(f"** deleting directory {path}") - rmtree(path, ignore_errors=True) - elif repo_id: - print(f"** deleting the cached model directory for {repo_id}") - self._delete_model_from_cache(repo_id) - - def add_model( - self, model_name: str, model_attributes: dict, clobber: bool = False - ) -> None: - """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory and the - method will return True. Will fail with an assertion error if provided - attributes are incorrect or the model name is missing. - """ - omega = self.config - assert "format" in model_attributes, 'missing required field "format"' - if model_attributes["format"] == "diffusers": - assert ( - "description" in model_attributes - ), 'required field "description" is missing' - assert ( - "path" in model_attributes or "repo_id" in model_attributes - ), 'model must have either the "path" or "repo_id" fields defined' - else: - for field in ("description", "weights", "height", "width", "config"): - assert field in model_attributes, f"required field {field} is missing" - - assert ( - clobber or model_name not in omega - ), f'attempt to overwrite existing model definition "{model_name}"' - - omega[model_name] = model_attributes - - if "weights" in omega[model_name]: - omega[model_name]["weights"].replace("\\", "/") - - if clobber: - self._invalidate_cached_model(model_name) - - def _load_model(self, model_name: str): - """Load and initialize the model from configuration variables passed at object creation time""" - if model_name not in self.config: - print( - f'"{model_name}" is not a known model name. Please check your models.yaml file' - ) - return - - mconfig = self.config[model_name] - - # for usage statistics - if self._has_cuda(): - torch.cuda.reset_peak_memory_stats() - torch.cuda.empty_cache() - - tic = time.time() - - # this does the work - model_format = mconfig.get("format", "ckpt") - if model_format == "ckpt": - weights = mconfig.weights - print(f">> Loading {model_name} from {weights}") - model, width, height, model_hash = self._load_ckpt_model( - model_name, mconfig - ) - elif model_format == "diffusers": - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - model, width, height, model_hash = self._load_diffusers_model(mconfig) - else: - raise NotImplementedError( - f"Unknown model format {model_name}: {model_format}" - ) - - # usage statistics - toc = time.time() - print(">> Model loaded in", "%4.2fs" % (toc - tic)) - if self._has_cuda(): - print( - ">> Max VRAM used to load the model:", - "%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9), - "\n>> Current VRAM usage:" - "%4.2fG" % (torch.cuda.memory_allocated() / 1e9), - ) - return model, width, height, model_hash - - def _load_ckpt_model(self, model_name, mconfig): - config = mconfig.config - weights = mconfig.weights - vae = mconfig.get("vae") - width = mconfig.width - height = mconfig.height - - if not os.path.isabs(config): - config = os.path.join(Globals.root, config) - if not os.path.isabs(weights): - weights = os.path.normpath(os.path.join(Globals.root, weights)) - - # if converting automatically to diffusers, then we do the conversion and return - # a diffusers pipeline - if Globals.ckpt_convert: - print( - f">> Converting legacy checkpoint {model_name} into a diffusers model..." - ) - from ldm.invoke.ckpt_to_diffuser import ( - load_pipeline_from_original_stable_diffusion_ckpt, - ) - - self.offload_model(self.current_model) - if vae_config := self._choose_diffusers_vae(model_name): - vae = self._load_vae(vae_config) - if self._has_cuda(): - torch.cuda.empty_cache() - pipeline = load_pipeline_from_original_stable_diffusion_ckpt( - checkpoint_path=weights, - original_config_file=config, - vae=vae, - return_generator_pipeline=True, - precision=torch.float16 - if self.precision == "float16" - else torch.float32, - ) - if self.sequential_offload: - pipeline.enable_offload_submodels(self.device) - else: - pipeline.to(self.device) - - return ( - pipeline, - width, - height, - "NOHASH", - ) - - # scan model - self.scan_model(model_name, weights) - - print(f">> Loading {model_name} from {weights}") - - # for usage statistics - if self._has_cuda(): - torch.cuda.reset_peak_memory_stats() - torch.cuda.empty_cache() - - # this does the work - if not os.path.isabs(config): - config = os.path.join(Globals.root, config) - omega_config = OmegaConf.load(config) - with open(weights, "rb") as f: - weight_bytes = f.read() - model_hash = self._cached_sha256(weights, weight_bytes) - sd = None - if weights.endswith(".safetensors"): - sd = safetensors.torch.load(weight_bytes) - else: - sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu") - del weight_bytes - # merged models from auto11 merge board are flat for some reason - if "state_dict" in sd: - sd = sd["state_dict"] - - print(" | Forcing garbage collection prior to loading new model") - gc.collect() - model = instantiate_from_config(omega_config.model) - model.load_state_dict(sd, strict=False) - - if self.precision == "float16": - print(" | Using faster float16 precision") - model = model.to(torch.float16) - else: - print(" | Using more accurate float32 precision") - - # look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py - if vae: - if not os.path.isabs(vae): - vae = os.path.normpath(os.path.join(Globals.root, vae)) - if os.path.exists(vae): - print(f" | Loading VAE weights from: {vae}") - vae_ckpt = None - vae_dict = None - if vae.endswith(".safetensors"): - vae_ckpt = safetensors.torch.load_file(vae) - vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"} - else: - vae_ckpt = torch.load(vae, map_location="cpu") - vae_dict = { - k: v - for k, v in vae_ckpt["state_dict"].items() - if k[0:4] != "loss" - } - model.first_stage_model.load_state_dict(vae_dict, strict=False) - else: - print(f" | VAE file {vae} not found. Skipping.") - - model.to(self.device) - # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here - model.cond_stage_model.device = self.device - - model.eval() - - for module in model.modules(): - if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)): - module._orig_padding_mode = module.padding_mode - return model, width, height, model_hash - - def _load_diffusers_model(self, mconfig): - name_or_path = self.model_name_or_path(mconfig) - using_fp16 = self.precision == "float16" - - print(f">> Loading diffusers model from {name_or_path}") - if using_fp16: - print(" | Using faster float16 precision") - else: - print(" | Using more accurate float32 precision") - - # TODO: scan weights maybe? - pipeline_args: dict[str, Any] = dict( - safety_checker=None, local_files_only=not Globals.internet_available - ) - if "vae" in mconfig and mconfig["vae"] is not None: - if vae := self._load_vae(mconfig["vae"]): - pipeline_args.update(vae=vae) - if not isinstance(name_or_path, Path): - pipeline_args.update(cache_dir=global_cache_dir("diffusers")) - if using_fp16: - pipeline_args.update(torch_dtype=torch.float16) - fp_args_list = [{"revision": "fp16"}, {}] - else: - fp_args_list = [{}] - - verbosity = dlogging.get_verbosity() - dlogging.set_verbosity_error() - - pipeline = None - for fp_args in fp_args_list: - try: - pipeline = StableDiffusionGeneratorPipeline.from_pretrained( - name_or_path, - **pipeline_args, - **fp_args, - ) - except OSError as e: - if str(e).startswith("fp16 is not a valid"): - pass - else: - print( - f"** An unexpected error occurred while downloading the model: {e})" - ) - if pipeline: - break - - dlogging.set_verbosity(verbosity) - assert pipeline is not None, OSError(f'"{name_or_path}" could not be loaded') - - if self.sequential_offload: - pipeline.enable_offload_submodels(self.device) - else: - pipeline.to(self.device) - - model_hash = self._diffuser_sha256(name_or_path) - - # square images??? - width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor - height = width - - print(f" | Default image dimensions = {width} x {height}") - - return pipeline, width, height, model_hash - - def model_name_or_path(self, model_name: Union[str, DictConfig]) -> str | Path: - if isinstance(model_name, DictConfig) or isinstance(model_name, dict): - mconfig = model_name - elif model_name in self.config: - mconfig = self.config[model_name] - else: - raise ValueError( - f'"{model_name}" is not a known model name. Please check your models.yaml file' - ) - - if "path" in mconfig and mconfig["path"] is not None: - path = Path(mconfig["path"]) - if not path.is_absolute(): - path = Path(Globals.root, path).resolve() - return path - elif "repo_id" in mconfig: - return mconfig["repo_id"] - else: - raise ValueError("Model config must specify either repo_id or path.") - - def offload_model(self, model_name: str) -> None: - """ - Offload the indicated model to CPU. Will call - _make_cache_room() to free space if needed. - """ - if model_name not in self.models: - return - - print(f">> Offloading {model_name} to CPU") - model = self.models[model_name]["model"] - self.models[model_name]["model"] = self._model_to_cpu(model) - - gc.collect() - if self._has_cuda(): - torch.cuda.empty_cache() - - def scan_model(self, model_name, checkpoint): - """ - Apply picklescanner to the indicated checkpoint and issue a warning - and option to exit if an infected file is identified. - """ - # scan model - print(f">> Scanning Model: {model_name}") - scan_result = scan_file_path(checkpoint) - if scan_result.infected_files != 0: - if scan_result.infected_files == 1: - print(f"\n### Issues Found In Model: {scan_result.issues_count}") - print( - "### WARNING: The model you are trying to load seems to be infected." - ) - print("### For your safety, InvokeAI will not load this model.") - print("### Please use checkpoints from trusted sources.") - print("### Exiting InvokeAI") - sys.exit() - else: - print( - "\n### WARNING: InvokeAI was unable to scan the model you are using." - ) - model_safe_check_fail = ask_user( - "Do you want to to continue loading the model?", ["y", "n"] - ) - if model_safe_check_fail.lower() != "y": - print("### Exiting InvokeAI") - sys.exit() - else: - print(">> Model scanned ok") - - def import_diffuser_model( - self, - repo_or_path: Union[str, Path], - model_name: str = None, - model_description: str = None, - vae: dict = None, - commit_to_conf: Path = None, - ) -> bool: - """ - Attempts to install the indicated diffuser model and returns True if successful. - - "repo_or_path" can be either a repo-id or a path-like object corresponding to the - top of a downloaded diffusers directory. - - You can optionally provide a model name and/or description. If not provided, - then these will be derived from the repo name. If you provide a commit_to_conf - path to the configuration file, then the new entry will be committed to the - models.yaml file. - """ - model_name = model_name or Path(repo_or_path).stem - model_description = model_description or f"Imported diffusers model {model_name}" - new_config = dict( - description=model_description, - vae=vae, - format="diffusers", - ) - if isinstance(repo_or_path, Path) and repo_or_path.exists(): - new_config.update(path=str(repo_or_path)) - else: - new_config.update(repo_id=repo_or_path) - - self.add_model(model_name, new_config, True) - if commit_to_conf: - self.commit(commit_to_conf) - return model_name - - def import_ckpt_model( - self, - weights: Union[str, Path], - config: Union[str, Path] = "configs/stable-diffusion/v1-inference.yaml", - vae: Union[str, Path] = None, - model_name: str = None, - model_description: str = None, - commit_to_conf: Path = None, - ) -> str: - """ - Attempts to install the indicated ckpt file and returns True if successful. - - "weights" can be either a path-like object corresponding to a local .ckpt file - or a http/https URL pointing to a remote model. - - "vae" is a Path or str object pointing to a ckpt or safetensors file to be used - as the VAE for this model. - - "config" is the model config file to use with this ckpt file. It defaults to - v1-inference.yaml. If a URL is provided, the config will be downloaded. - - You can optionally provide a model name and/or description. If not provided, - then these will be derived from the weight file name. If you provide a commit_to_conf - path to the configuration file, then the new entry will be committed to the - models.yaml file. - - Return value is the name of the imported file, or None if an error occurred. - """ - if str(weights).startswith(("http:", "https:")): - model_name = model_name or url_attachment_name(weights) - - weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1") - config_path = self._resolve_path(config, "configs/stable-diffusion") - - if weights_path is None or not weights_path.exists(): - return - if config_path is None or not config_path.exists(): - return - - model_name = ( - model_name or Path(weights).stem - ) # note this gives ugly pathnames if used on a URL without a Content-Disposition header - model_description = ( - model_description or f"Imported stable diffusion weights file {model_name}" - ) - new_config = dict( - weights=str(weights_path), - config=str(config_path), - description=model_description, - format="ckpt", - width=512, - height=512, - ) - if vae: - new_config["vae"] = vae - self.add_model(model_name, new_config, True) - if commit_to_conf: - self.commit(commit_to_conf) - return model_name - - @classmethod - def probe_model_type(self, checkpoint: dict) -> SDLegacyType: - """ - Given a pickle or safetensors model object, probes contents - of the object and returns an SDLegacyType indicating its - format. Valid return values include: - SDLegacyType.V1 - SDLegacyType.V1_INPAINT - SDLegacyType.V2 - SDLegacyType.UNKNOWN - """ - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: - return SDLegacyType.V2 - - try: - state_dict = checkpoint.get("state_dict") or checkpoint - in_channels = state_dict[ - "model.diffusion_model.input_blocks.0.0.weight" - ].shape[1] - if in_channels == 9: - return SDLegacyType.V1_INPAINT - elif in_channels == 4: - return SDLegacyType.V1 - else: - return SDLegacyType.UNKNOWN - except KeyError: - return SDLegacyType.UNKNOWN - - def heuristic_import( - self, - path_url_or_repo: str, - convert: bool = False, - model_name: str = None, - description: str = None, - commit_to_conf: Path = None, - ) -> str: - """ - Accept a string which could be: - - a HF diffusers repo_id - - a URL pointing to a legacy .ckpt or .safetensors file - - a local path pointing to a legacy .ckpt or .safetensors file - - a local directory containing .ckpt and .safetensors files - - a local directory containing a diffusers model - - After determining the nature of the model and downloading it - (if necessary), the file is probed to determine the correct - configuration file (if needed) and it is imported. - - The model_name and/or description can be provided. If not, they will - be generated automatically. - - If convert is true, legacy models will be converted to diffusers - before importing. - - If commit_to_conf is provided, the newly loaded model will be written - to the `models.yaml` file at the indicated path. Otherwise, the changes - will only remain in memory. - - The (potentially derived) name of the model is returned on success, or None - on failure. When multiple models are added from a directory, only the last - imported one is returned. - """ - model_path: Path = None - thing = path_url_or_repo # to save typing - - print(f">> Probing {thing} for import") - - if thing.startswith(("http:", "https:", "ftp:")): - print(f" | {thing} appears to be a URL") - model_path = self._resolve_path( - thing, "models/ldm/stable-diffusion-v1" - ) # _resolve_path does a download if needed - - elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")): - if Path(thing).stem in ["model", "diffusion_pytorch_model"]: - print( - f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import" - ) - return - else: - print(f" | {thing} appears to be a checkpoint file on disk") - model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1") - - elif Path(thing).is_dir() and Path(thing, "model_index.json").exists(): - print(f" | {thing} appears to be a diffusers file on disk") - model_name = self.import_diffuser_model( - thing, - vae=dict(repo_id="stabilityai/sd-vae-ft-mse"), - model_name=model_name, - description=description, - commit_to_conf=commit_to_conf, - ) - - elif Path(thing).is_dir(): - if (Path(thing) / "model_index.json").exists(): - print(f" | {thing} appears to be a diffusers model.") - model_name = self.import_diffuser_model( - thing, commit_to_conf=commit_to_conf - ) - else: - print( - f" |{thing} appears to be a directory. Will scan for models to import" - ) - for m in list(Path(thing).rglob("*.ckpt")) + list( - Path(thing).rglob("*.safetensors") - ): - if model_name := self.heuristic_import( - str(m), convert, commit_to_conf=commit_to_conf - ): - print(f" >> {model_name} successfully imported") - return model_name - - elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing): - print(f" | {thing} appears to be a HuggingFace diffusers repo_id") - model_name = self.import_diffuser_model( - thing, commit_to_conf=commit_to_conf - ) - pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name]) - return model_name - else: - print( - f"** {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id" - ) - - # Model_path is set in the event of a legacy checkpoint file. - # If not set, we're all done - if not model_path: - return - - if model_path.stem in self.config: # already imported - print(" | Already imported. Skipping") - return - - # another round of heuristics to guess the correct config file. - checkpoint = ( - safetensors.torch.load_file(model_path) - if model_path.suffix == ".safetensors" - else torch.load(model_path) - ) - model_type = self.probe_model_type(checkpoint) - - model_config_file = None - if model_type == SDLegacyType.V1: - print(" | SD-v1 model detected") - model_config_file = Path( - Globals.root, "configs/stable-diffusion/v1-inference.yaml" - ) - elif model_type == SDLegacyType.V1_INPAINT: - print(" | SD-v1 inpainting model detected") - model_config_file = Path( - Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml" - ) - elif model_type == SDLegacyType.V2: - print( - " | SD-v2 model detected; model will be converted to diffusers format" - ) - model_config_file = Path( - Globals.root, "configs/stable-diffusion/v2-inference-v.yaml" - ) - convert = True - else: - print( - f"** {thing} is a legacy checkpoint file but not in a known Stable Diffusion model. Skipping import" - ) - return - - if convert: - diffuser_path = Path( - Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem - ) - model_name = self.convert_and_import( - model_path, - diffusers_path=diffuser_path, - vae=dict(repo_id="stabilityai/sd-vae-ft-mse"), - model_name=model_name, - model_description=description, - original_config_file=model_config_file, - commit_to_conf=commit_to_conf, - ) - else: - model_name = self.import_ckpt_model( - model_path, - config=model_config_file, - model_name=model_name, - model_description=description, - vae=str( - Path( - Globals.root, - "models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt", - ) - ), - commit_to_conf=commit_to_conf, - ) - if commit_to_conf: - self.commit(commit_to_conf) - return model_name - - def convert_and_import( - self, - ckpt_path: Path, - diffusers_path: Path, - model_name=None, - model_description=None, - vae=None, - original_config_file: Path = None, - commit_to_conf: Path = None, - ) -> str: - """ - Convert a legacy ckpt weights file to diffuser model and import - into models.yaml. - """ - ckpt_path = self._resolve_path(ckpt_path, "models/ldm/stable-diffusion-v1") - if original_config_file: - original_config_file = self._resolve_path( - original_config_file, "configs/stable-diffusion" - ) - - new_config = None - - from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser - - if diffusers_path.exists(): - print( - f"ERROR: The path {str(diffusers_path)} already exists. Please move or remove it and try again." - ) - return - - model_name = model_name or diffusers_path.name - model_description = model_description or f"Optimized version of {model_name}" - print(f">> Optimizing {model_name} (30-60s)") - try: - # By passing the specified VAE to the conversion function, the autoencoder - # will be built into the model rather than tacked on afterward via the config file - vae_model = self._load_vae(vae) if vae else None - convert_ckpt_to_diffuser( - ckpt_path, - diffusers_path, - extract_ema=True, - original_config_file=original_config_file, - vae=vae_model, - ) - print( - f" | Success. Optimized model is now located at {str(diffusers_path)}" - ) - print(f" | Writing new config file entry for {model_name}") - new_config = dict( - path=str(diffusers_path), - description=model_description, - format="diffusers", - ) - if model_name in self.config: - self.del_model(model_name) - self.add_model(model_name, new_config, True) - if commit_to_conf: - self.commit(commit_to_conf) - print(">> Conversion succeeded") - except Exception as e: - print(f"** Conversion failed: {str(e)}") - print( - "** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)" - ) - - return model_name - - def search_models(self, search_folder): - print(f">> Finding Models In: {search_folder}") - models_folder_ckpt = Path(search_folder).glob("**/*.ckpt") - models_folder_safetensors = Path(search_folder).glob("**/*.safetensors") - - ckpt_files = [x for x in models_folder_ckpt if x.is_file()] - safetensor_files = [x for x in models_folder_safetensors if x.is_file()] - - files = ckpt_files + safetensor_files - - found_models = [] - for file in files: - location = str(file.resolve()).replace("\\", "/") - if ( - "model.safetensors" not in location - and "diffusion_pytorch_model.safetensors" not in location - ): - found_models.append({"name": file.stem, "location": location}) - - return search_folder, found_models - - def _choose_diffusers_vae( - self, model_name: str, vae: str = None - ) -> Union[dict, str]: - # In the event that the original entry is using a custom ckpt VAE, we try to - # map that VAE onto a diffuser VAE using a hard-coded dictionary. - # I would prefer to do this differently: We load the ckpt model into memory, swap the - # VAE in memory, and then pass that to convert_ckpt_to_diffuser() so that the swapped - # VAE is built into the model. However, when I tried this I got obscure key errors. - if vae: - return vae - if model_name in self.config and ( - vae_ckpt_path := self.model_info(model_name).get("vae", None) - ): - vae_basename = Path(vae_ckpt_path).stem - diffusers_vae = None - if diffusers_vae := VAE_TO_REPO_ID.get(vae_basename, None): - print( - f">> {vae_basename} VAE corresponds to known {diffusers_vae} diffusers version" - ) - vae = {"repo_id": diffusers_vae} - else: - print( - f'** Custom VAE "{vae_basename}" found, but corresponding diffusers model unknown' - ) - print( - '** Using "stabilityai/sd-vae-ft-mse"; If this isn\'t right, please edit the model config' - ) - vae = {"repo_id": "stabilityai/sd-vae-ft-mse"} - return vae - - def _make_cache_room(self) -> None: - num_loaded_models = len(self.models) - if num_loaded_models >= self.max_loaded_models: - least_recent_model = self._pop_oldest_model() - print( - f">> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}" - ) - if least_recent_model is not None: - del self.models[least_recent_model] - gc.collect() - - def print_vram_usage(self) -> None: - if self._has_cuda: - print( - ">> Current VRAM usage: ", - "%4.2fG" % (torch.cuda.memory_allocated() / 1e9), - ) - - def commit(self, config_file_path: str) -> None: - """ - Write current configuration out to the indicated file. - """ - yaml_str = OmegaConf.to_yaml(self.config) - if not os.path.isabs(config_file_path): - config_file_path = os.path.normpath( - os.path.join(Globals.root, config_file_path) - ) - tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp") - with open(tmpfile, "w", encoding="utf-8") as outfile: - outfile.write(self.preamble()) - outfile.write(yaml_str) - os.replace(tmpfile, config_file_path) - - def preamble(self) -> str: - """ - Returns the preamble for the config file. - """ - return textwrap.dedent( - """\ - # This file describes the alternative machine learning models - # available to InvokeAI script. - # - # To add a new model, follow the examples below. Each - # model requires a model config file, a weights file, - # and the width and height of the images it - # was trained on. - """ - ) - - @classmethod - def migrate_models(cls): - """ - Migrate the ~/invokeai/models directory from the legacy format used through 2.2.5 - to the 2.3.0 "diffusers" version. This should be a one-time operation, called at - script startup time. - """ - # Three transformer models to check: bert, clip and safety checker - legacy_locations = [ - Path( - "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker" - ), - Path("bert-base-uncased/models--bert-base-uncased"), - Path( - "openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14" - ), - ] - models_dir = Path(Globals.root, "models") - legacy_layout = False - for model in legacy_locations: - legacy_layout = legacy_layout or Path(models_dir, model).exists() - if not legacy_layout: - return - - print( - "** Legacy version <= 2.2.5 model directory layout detected. Reorganizing." - ) - print("** This is a quick one-time operation.") - - # transformer files get moved into the hub directory - if cls._is_huggingface_hub_directory_present(): - hub = global_cache_dir("hub") - else: - hub = models_dir / "hub" - - os.makedirs(hub, exist_ok=True) - for model in legacy_locations: - source = models_dir / model - dest = hub / model.stem - print(f"** {source} => {dest}") - if source.exists(): - if dest.exists(): - rmtree(source) - else: - move(source, dest) - - # anything else gets moved into the diffusers directory - if cls._is_huggingface_hub_directory_present(): - diffusers = global_cache_dir("diffusers") - else: - diffusers = models_dir / "diffusers" - - os.makedirs(diffusers, exist_ok=True) - for root, dirs, _ in os.walk(models_dir, topdown=False): - for dir in dirs: - full_path = Path(root, dir) - if full_path.is_relative_to(hub) or full_path.is_relative_to(diffusers): - continue - if Path(dir).match("models--*--*"): - dest = diffusers / dir - print(f"** {full_path} => {dest}") - if dest.exists(): - rmtree(full_path) - else: - move(full_path, dest) - - # now clean up by removing any empty directories - empty = [ - root - for root, dirs, files, in os.walk(models_dir) - if not len(dirs) and not len(files) - ] - for d in empty: - os.rmdir(d) - print("** Migration is done. Continuing...") - - def _resolve_path( - self, source: Union[str, Path], dest_directory: str - ) -> Optional[Path]: - resolved_path = None - if str(source).startswith(("http:", "https:", "ftp:")): - dest_directory = Path(dest_directory) - if not dest_directory.is_absolute(): - dest_directory = Globals.root / dest_directory - dest_directory.mkdir(parents=True, exist_ok=True) - resolved_path = download_with_resume(str(source), dest_directory) - else: - if not os.path.isabs(source): - source = os.path.join(Globals.root, source) - resolved_path = Path(source) - return resolved_path - - def _invalidate_cached_model(self, model_name: str) -> None: - self.offload_model(model_name) - if model_name in self.stack: - self.stack.remove(model_name) - self.models.pop(model_name, None) - - def _model_to_cpu(self, model): - if self.device == CPU_DEVICE: - return model - - if isinstance(model, StableDiffusionGeneratorPipeline): - model.offload_all() - return model - - model.cond_stage_model.device = CPU_DEVICE - model.to(CPU_DEVICE) - - for submodel in ("first_stage_model", "cond_stage_model", "model"): - try: - getattr(model, submodel).to(CPU_DEVICE) - except AttributeError: - pass - return model - - def _model_from_cpu(self, model): - if self.device == CPU_DEVICE: - return model - - if isinstance(model, StableDiffusionGeneratorPipeline): - model.ready() - return model - - model.to(self.device) - model.cond_stage_model.device = self.device - - for submodel in ("first_stage_model", "cond_stage_model", "model"): - try: - getattr(model, submodel).to(self.device) - except AttributeError: - pass - - return model - - def _pop_oldest_model(self): - """ - Remove the first element of the FIFO, which ought - to be the least recently accessed model. Do not - pop the last one, because it is in active use! - """ - return self.stack.pop(0) - - def _push_newest_model(self, model_name: str) -> None: - """ - Maintain a simple FIFO. First element is always the - least recent, and last element is always the most recent. - """ - with contextlib.suppress(ValueError): - self.stack.remove(model_name) - self.stack.append(model_name) - - def _has_cuda(self) -> bool: - return self.device.type == "cuda" - - def _diffuser_sha256( - self, name_or_path: Union[str, Path], chunksize=4096 - ) -> Union[str, bytes]: - path = None - if isinstance(name_or_path, Path): - path = name_or_path - else: - owner, repo = name_or_path.split("/") - path = Path(global_cache_dir("diffusers") / f"models--{owner}--{repo}") - if not path.exists(): - return None - hashpath = path / "checksum.sha256" - if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime: - with open(hashpath) as f: - hash = f.read() - return hash - print(" | Calculating sha256 hash of model files") - tic = time.time() - sha = hashlib.sha256() - count = 0 - for root, dirs, files in os.walk(path, followlinks=False): - for name in files: - count += 1 - with open(os.path.join(root, name), "rb") as f: - while chunk := f.read(chunksize): - sha.update(chunk) - hash = sha.hexdigest() - toc = time.time() - print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic)) - with open(hashpath, "w") as f: - f.write(hash) - return hash - - def _cached_sha256(self, path, data) -> Union[str, bytes]: - dirname = os.path.dirname(path) - basename = os.path.basename(path) - base, _ = os.path.splitext(basename) - hashpath = os.path.join(dirname, base + ".sha256") - - if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime( - hashpath - ): - with open(hashpath) as f: - hash = f.read() - return hash - - print(" | Calculating sha256 hash of weights file") - tic = time.time() - sha = hashlib.sha256() - sha.update(data) - hash = sha.hexdigest() - toc = time.time() - print(f">> sha256 = {hash}", "(%4.2fs)" % (toc - tic)) - - with open(hashpath, "w") as f: - f.write(hash) - return hash - - def _load_vae(self, vae_config) -> AutoencoderKL: - vae_args = {} - try: - name_or_path = self.model_name_or_path(vae_config) - except Exception: - return None - if name_or_path is None: - return None - using_fp16 = self.precision == "float16" - - vae_args.update( - cache_dir=global_cache_dir("diffusers"), - local_files_only=not Globals.internet_available, - ) - - print(f" | Loading diffusers VAE from {name_or_path}") - if using_fp16: - vae_args.update(torch_dtype=torch.float16) - fp_args_list = [{"revision": "fp16"}, {}] - else: - print(" | Using more accurate float32 precision") - fp_args_list = [{}] - - vae = None - deferred_error = None - - # A VAE may be in a subfolder of a model's repository. - if "subfolder" in vae_config: - vae_args["subfolder"] = vae_config["subfolder"] - - for fp_args in fp_args_list: - # At some point we might need to be able to use different classes here? But for now I think - # all Stable Diffusion VAE are AutoencoderKL. - try: - vae = AutoencoderKL.from_pretrained(name_or_path, **vae_args, **fp_args) - except OSError as e: - if str(e).startswith("fp16 is not a valid"): - pass - else: - deferred_error = e - if vae: - break - - if not vae and deferred_error: - print(f"** Could not load VAE {name_or_path}: {str(deferred_error)}") - - return vae - - @staticmethod - def _delete_model_from_cache(repo_id): - cache_info = scan_cache_dir(global_cache_dir("diffusers")) - - # I'm sure there is a way to do this with comprehensions - # but the code quickly became incomprehensible! - hashes_to_delete = set() - for repo in cache_info.repos: - if repo.repo_id == repo_id: - for revision in repo.revisions: - hashes_to_delete.add(revision.commit_hash) - strategy = cache_info.delete_revisions(*hashes_to_delete) - print( - f"** deletion of this model is expected to free {strategy.expected_freed_size_str}" - ) - strategy.execute() - - @staticmethod - def _abs_path(path: str | Path) -> Path: - if path is None or Path(path).is_absolute(): - return path - return Path(Globals.root, path).resolve() - - @staticmethod - def _is_huggingface_hub_directory_present() -> bool: - return ( - os.getenv("HF_HOME") is not None or os.getenv("XDG_CACHE_HOME") is not None - ) diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py deleted file mode 100644 index 3db7b6fd73..0000000000 --- a/ldm/models/autoencoder.py +++ /dev/null @@ -1,596 +0,0 @@ -import torch -import pytorch_lightning as pl -import torch.nn.functional as F -from contextlib import contextmanager - -from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer - -from ldm.modules.diffusionmodules.model import Encoder, Decoder -from ldm.modules.distributions.distributions import ( - DiagonalGaussianDistribution, -) - -from ldm.util import instantiate_from_config - - -class VQModel(pl.LightningModule): - def __init__( - self, - ddconfig, - lossconfig, - n_embed, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key='image', - colorize_nlabels=None, - monitor=None, - batch_resize_range=None, - scheduler_config=None, - lr_g_factor=1.0, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - use_ema=False, - ): - super().__init__() - self.embed_dim = embed_dim - self.n_embed = n_embed - self.image_key = image_key - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - self.loss = instantiate_from_config(lossconfig) - self.quantize = VectorQuantizer( - n_embed, - embed_dim, - beta=0.25, - remap=remap, - sane_index_shape=sane_index_shape, - ) - self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d( - embed_dim, ddconfig['z_channels'], 1 - ) - if colorize_nlabels is not None: - assert type(colorize_nlabels) == int - self.register_buffer( - 'colorize', torch.randn(3, colorize_nlabels, 1, 1) - ) - if monitor is not None: - self.monitor = monitor - self.batch_resize_range = batch_resize_range - if self.batch_resize_range is not None: - print( - f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.' - ) - - self.use_ema = use_ema - if self.use_ema: - self.model_ema = LitEma(self) - print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.') - - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - self.scheduler_config = scheduler_config - self.lr_g_factor = lr_g_factor - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.parameters()) - self.model_ema.copy_to(self) - if context is not None: - print(f'{context}: Switched to EMA weights') - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.parameters()) - if context is not None: - print(f'{context}: Restored training weights') - - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location='cpu')['state_dict'] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print('Deleting key {} from state_dict.'.format(k)) - del sd[k] - missing, unexpected = self.load_state_dict(sd, strict=False) - print( - f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys' - ) - if len(missing) > 0: - print(f'Missing Keys: {missing}') - print(f'Unexpected Keys: {unexpected}') - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self) - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - quant, emb_loss, info = self.quantize(h) - return quant, emb_loss, info - - def encode_to_prequant(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, quant): - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - def decode_code(self, code_b): - quant_b = self.quantize.embed_code(code_b) - dec = self.decode(quant_b) - return dec - - def forward(self, input, return_pred_indices=False): - quant, diff, (_, _, ind) = self.encode(input) - dec = self.decode(quant) - if return_pred_indices: - return dec, diff, ind - return dec, diff - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = ( - x.permute(0, 3, 1, 2) - .to(memory_format=torch.contiguous_format) - .float() - ) - if self.batch_resize_range is not None: - lower_size = self.batch_resize_range[0] - upper_size = self.batch_resize_range[1] - if self.global_step <= 4: - # do the first few batches with max size to avoid later oom - new_resize = upper_size - else: - new_resize = np.random.choice( - np.arange(lower_size, upper_size + 16, 16) - ) - if new_resize != x.shape[2]: - x = F.interpolate(x, size=new_resize, mode='bicubic') - x = x.detach() - return x - - def training_step(self, batch, batch_idx, optimizer_idx): - # https://github.com/pytorch/pytorch/issues/37142 - # try not to fool the heuristics - x = self.get_input(batch, self.image_key) - xrec, qloss, ind = self(x, return_pred_indices=True) - - if optimizer_idx == 0: - # autoencode - aeloss, log_dict_ae = self.loss( - qloss, - x, - xrec, - optimizer_idx, - self.global_step, - last_layer=self.get_last_layer(), - split='train', - predicted_indices=ind, - ) - - self.log_dict( - log_dict_ae, - prog_bar=False, - logger=True, - on_step=True, - on_epoch=True, - ) - return aeloss - - if optimizer_idx == 1: - # discriminator - discloss, log_dict_disc = self.loss( - qloss, - x, - xrec, - optimizer_idx, - self.global_step, - last_layer=self.get_last_layer(), - split='train', - ) - self.log_dict( - log_dict_disc, - prog_bar=False, - logger=True, - on_step=True, - on_epoch=True, - ) - return discloss - - def validation_step(self, batch, batch_idx): - log_dict = self._validation_step(batch, batch_idx) - with self.ema_scope(): - log_dict_ema = self._validation_step( - batch, batch_idx, suffix='_ema' - ) - return log_dict - - def _validation_step(self, batch, batch_idx, suffix=''): - x = self.get_input(batch, self.image_key) - xrec, qloss, ind = self(x, return_pred_indices=True) - aeloss, log_dict_ae = self.loss( - qloss, - x, - xrec, - 0, - self.global_step, - last_layer=self.get_last_layer(), - split='val' + suffix, - predicted_indices=ind, - ) - - discloss, log_dict_disc = self.loss( - qloss, - x, - xrec, - 1, - self.global_step, - last_layer=self.get_last_layer(), - split='val' + suffix, - predicted_indices=ind, - ) - rec_loss = log_dict_ae[f'val{suffix}/rec_loss'] - self.log( - f'val{suffix}/rec_loss', - rec_loss, - prog_bar=True, - logger=True, - on_step=False, - on_epoch=True, - sync_dist=True, - ) - self.log( - f'val{suffix}/aeloss', - aeloss, - prog_bar=True, - logger=True, - on_step=False, - on_epoch=True, - sync_dist=True, - ) - if version.parse(pl.__version__) >= version.parse('1.4.0'): - del log_dict_ae[f'val{suffix}/rec_loss'] - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def configure_optimizers(self): - lr_d = self.learning_rate - lr_g = self.lr_g_factor * self.learning_rate - print('lr_d', lr_d) - print('lr_g', lr_g) - opt_ae = torch.optim.Adam( - list(self.encoder.parameters()) - + list(self.decoder.parameters()) - + list(self.quantize.parameters()) - + list(self.quant_conv.parameters()) - + list(self.post_quant_conv.parameters()), - lr=lr_g, - betas=(0.5, 0.9), - ) - opt_disc = torch.optim.Adam( - self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9) - ) - - if self.scheduler_config is not None: - scheduler = instantiate_from_config(self.scheduler_config) - - print('Setting up LambdaLR scheduler...') - scheduler = [ - { - 'scheduler': LambdaLR( - opt_ae, lr_lambda=scheduler.schedule - ), - 'interval': 'step', - 'frequency': 1, - }, - { - 'scheduler': LambdaLR( - opt_disc, lr_lambda=scheduler.schedule - ), - 'interval': 'step', - 'frequency': 1, - }, - ] - return [opt_ae, opt_disc], scheduler - return [opt_ae, opt_disc], [] - - def get_last_layer(self): - return self.decoder.conv_out.weight - - def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): - log = dict() - x = self.get_input(batch, self.image_key) - x = x.to(self.device) - if only_inputs: - log['inputs'] = x - return log - xrec, _ = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec.shape[1] > 3 - x = self.to_rgb(x) - xrec = self.to_rgb(xrec) - log['inputs'] = x - log['reconstructions'] = xrec - if plot_ema: - with self.ema_scope(): - xrec_ema, _ = self(x) - if x.shape[1] > 3: - xrec_ema = self.to_rgb(xrec_ema) - log['reconstructions_ema'] = xrec_ema - return log - - def to_rgb(self, x): - assert self.image_key == 'segmentation' - if not hasattr(self, 'colorize'): - self.register_buffer( - 'colorize', torch.randn(3, x.shape[1], 1, 1).to(x) - ) - x = F.conv2d(x, weight=self.colorize) - x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 - return x - - -class VQModelInterface(VQModel): - def __init__(self, embed_dim, *args, **kwargs): - super().__init__(embed_dim=embed_dim, *args, **kwargs) - self.embed_dim = embed_dim - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, h, force_not_quantize=False): - # also go through quantization layer - if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) - else: - quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - -class AutoencoderKL(pl.LightningModule): - def __init__( - self, - ddconfig, - lossconfig, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key='image', - colorize_nlabels=None, - monitor=None, - ): - super().__init__() - self.image_key = image_key - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - self.loss = instantiate_from_config(lossconfig) - assert ddconfig['double_z'] - self.quant_conv = torch.nn.Conv2d( - 2 * ddconfig['z_channels'], 2 * embed_dim, 1 - ) - self.post_quant_conv = torch.nn.Conv2d( - embed_dim, ddconfig['z_channels'], 1 - ) - self.embed_dim = embed_dim - if colorize_nlabels is not None: - assert type(colorize_nlabels) == int - self.register_buffer( - 'colorize', torch.randn(3, colorize_nlabels, 1, 1) - ) - if monitor is not None: - self.monitor = monitor - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location='cpu')['state_dict'] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print('Deleting key {} from state_dict.'.format(k)) - del sd[k] - self.load_state_dict(sd, strict=False) - print(f'Restored from {path}') - - def encode(self, x): - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior - - def decode(self, z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec - - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() - dec = self.decode(z) - return dec, posterior - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = ( - x.permute(0, 3, 1, 2) - .to(memory_format=torch.contiguous_format) - .float() - ) - return x - - def training_step(self, batch, batch_idx, optimizer_idx): - inputs = self.get_input(batch, self.image_key) - reconstructions, posterior = self(inputs) - - if optimizer_idx == 0: - # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss( - inputs, - reconstructions, - posterior, - optimizer_idx, - self.global_step, - last_layer=self.get_last_layer(), - split='train', - ) - self.log( - 'aeloss', - aeloss, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=True, - ) - self.log_dict( - log_dict_ae, - prog_bar=False, - logger=True, - on_step=True, - on_epoch=False, - ) - return aeloss - - if optimizer_idx == 1: - # train the discriminator - discloss, log_dict_disc = self.loss( - inputs, - reconstructions, - posterior, - optimizer_idx, - self.global_step, - last_layer=self.get_last_layer(), - split='train', - ) - - self.log( - 'discloss', - discloss, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=True, - ) - self.log_dict( - log_dict_disc, - prog_bar=False, - logger=True, - on_step=True, - on_epoch=False, - ) - return discloss - - def validation_step(self, batch, batch_idx): - inputs = self.get_input(batch, self.image_key) - reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss( - inputs, - reconstructions, - posterior, - 0, - self.global_step, - last_layer=self.get_last_layer(), - split='val', - ) - - discloss, log_dict_disc = self.loss( - inputs, - reconstructions, - posterior, - 1, - self.global_step, - last_layer=self.get_last_layer(), - split='val', - ) - - self.log('val/rec_loss', log_dict_ae['val/rec_loss']) - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def configure_optimizers(self): - lr = self.learning_rate - opt_ae = torch.optim.Adam( - list(self.encoder.parameters()) - + list(self.decoder.parameters()) - + list(self.quant_conv.parameters()) - + list(self.post_quant_conv.parameters()), - lr=lr, - betas=(0.5, 0.9), - ) - opt_disc = torch.optim.Adam( - self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) - ) - return [opt_ae, opt_disc], [] - - def get_last_layer(self): - return self.decoder.conv_out.weight - - @torch.no_grad() - def log_images(self, batch, only_inputs=False, **kwargs): - log = dict() - x = self.get_input(batch, self.image_key) - x = x.to(self.device) - if not only_inputs: - xrec, posterior = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec.shape[1] > 3 - x = self.to_rgb(x) - xrec = self.to_rgb(xrec) - log['samples'] = self.decode(torch.randn_like(posterior.sample())) - log['reconstructions'] = xrec - log['inputs'] = x - return log - - def to_rgb(self, x): - assert self.image_key == 'segmentation' - if not hasattr(self, 'colorize'): - self.register_buffer( - 'colorize', torch.randn(3, x.shape[1], 1, 1).to(x) - ) - x = F.conv2d(x, weight=self.colorize) - x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 - return x - - -class IdentityFirstStage(torch.nn.Module): - def __init__(self, *args, vq_interface=False, **kwargs): - self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff - super().__init__() - - def encode(self, x, *args, **kwargs): - return x - - def decode(self, x, *args, **kwargs): - return x - - def quantize(self, x, *args, **kwargs): - if self.vq_interface: - return x, None, [None, None, None] - return x - - def forward(self, x, *args, **kwargs): - return x diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/ldm/models/diffusion/classifier.py b/ldm/models/diffusion/classifier.py deleted file mode 100644 index be0d8c1919..0000000000 --- a/ldm/models/diffusion/classifier.py +++ /dev/null @@ -1,355 +0,0 @@ -import os -import torch -import pytorch_lightning as pl -from omegaconf import OmegaConf -from torch.nn import functional as F -from torch.optim import AdamW -from torch.optim.lr_scheduler import LambdaLR -from copy import deepcopy -from einops import rearrange -from glob import glob -from natsort import natsorted - -from ldm.modules.diffusionmodules.openaimodel import ( - EncoderUNetModel, - UNetModel, -) -from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config - -__models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel} - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -class NoisyLatentImageClassifier(pl.LightningModule): - def __init__( - self, - diffusion_path, - num_classes, - ckpt_path=None, - pool='attention', - label_key=None, - diffusion_ckpt_path=None, - scheduler_config=None, - weight_decay=1.0e-2, - log_steps=10, - monitor='val/loss', - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.num_classes = num_classes - # get latest config of diffusion model - diffusion_config = natsorted( - glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')) - )[-1] - self.diffusion_config = OmegaConf.load(diffusion_config).model - self.diffusion_config.params.ckpt_path = diffusion_ckpt_path - self.load_diffusion() - - self.monitor = monitor - self.numd = ( - self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 - ) - self.log_time_interval = ( - self.diffusion_model.num_timesteps // log_steps - ) - self.log_steps = log_steps - - self.label_key = ( - label_key - if not hasattr(self.diffusion_model, 'cond_stage_key') - else self.diffusion_model.cond_stage_key - ) - - assert ( - self.label_key is not None - ), 'label_key neither in diffusion model nor in model.params' - - if self.label_key not in __models__: - raise NotImplementedError() - - self.load_classifier(ckpt_path, pool) - - self.scheduler_config = scheduler_config - self.use_scheduler = self.scheduler_config is not None - self.weight_decay = weight_decay - - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): - sd = torch.load(path, map_location='cpu') - if 'state_dict' in list(sd.keys()): - sd = sd['state_dict'] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print('Deleting key {} from state_dict.'.format(k)) - del sd[k] - missing, unexpected = ( - self.load_state_dict(sd, strict=False) - if not only_model - else self.model.load_state_dict(sd, strict=False) - ) - print( - f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys' - ) - if len(missing) > 0: - print(f'Missing Keys: {missing}') - if len(unexpected) > 0: - print(f'Unexpected Keys: {unexpected}') - - def load_diffusion(self): - model = instantiate_from_config(self.diffusion_config) - self.diffusion_model = model.eval() - self.diffusion_model.train = disabled_train - for param in self.diffusion_model.parameters(): - param.requires_grad = False - - def load_classifier(self, ckpt_path, pool): - model_config = deepcopy( - self.diffusion_config.params.unet_config.params - ) - model_config.in_channels = ( - self.diffusion_config.params.unet_config.params.out_channels - ) - model_config.out_channels = self.num_classes - if self.label_key == 'class_label': - model_config.pool = pool - - self.model = __models__[self.label_key](**model_config) - if ckpt_path is not None: - print( - '#####################################################################' - ) - print(f'load from ckpt "{ckpt_path}"') - print( - '#####################################################################' - ) - self.init_from_ckpt(ckpt_path) - - @torch.no_grad() - def get_x_noisy(self, x, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x)) - continuous_sqrt_alpha_cumprod = None - if self.diffusion_model.use_continuous_noise: - continuous_sqrt_alpha_cumprod = ( - self.diffusion_model.sample_continuous_noise_level( - x.shape[0], t + 1 - ) - ) - # todo: make sure t+1 is correct here - - return self.diffusion_model.q_sample( - x_start=x, - t=t, - noise=noise, - continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod, - ) - - def forward(self, x_noisy, t, *args, **kwargs): - return self.model(x_noisy, t) - - @torch.no_grad() - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = rearrange(x, 'b h w c -> b c h w') - x = x.to(memory_format=torch.contiguous_format).float() - return x - - @torch.no_grad() - def get_conditioning(self, batch, k=None): - if k is None: - k = self.label_key - assert k is not None, 'Needs to provide label key' - - targets = batch[k].to(self.device) - - if self.label_key == 'segmentation': - targets = rearrange(targets, 'b h w c -> b c h w') - for down in range(self.numd): - h, w = targets.shape[-2:] - targets = F.interpolate( - targets, size=(h // 2, w // 2), mode='nearest' - ) - - # targets = rearrange(targets,'b c h w -> b h w c') - - return targets - - def compute_top_k(self, logits, labels, k, reduction='mean'): - _, top_ks = torch.topk(logits, k, dim=1) - if reduction == 'mean': - return ( - (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() - ) - elif reduction == 'none': - return (top_ks == labels[:, None]).float().sum(dim=-1) - - def on_train_epoch_start(self): - # save some memory - self.diffusion_model.model.to('cpu') - - @torch.no_grad() - def write_logs(self, loss, logits, targets): - log_prefix = 'train' if self.training else 'val' - log = {} - log[f'{log_prefix}/loss'] = loss.mean() - log[f'{log_prefix}/acc@1'] = self.compute_top_k( - logits, targets, k=1, reduction='mean' - ) - log[f'{log_prefix}/acc@5'] = self.compute_top_k( - logits, targets, k=5, reduction='mean' - ) - - self.log_dict( - log, - prog_bar=False, - logger=True, - on_step=self.training, - on_epoch=True, - ) - self.log( - 'loss', log[f'{log_prefix}/loss'], prog_bar=True, logger=False - ) - self.log( - 'global_step', - self.global_step, - logger=False, - on_epoch=False, - prog_bar=True, - ) - lr = self.optimizers().param_groups[0]['lr'] - self.log( - 'lr_abs', - lr, - on_step=True, - logger=True, - on_epoch=False, - prog_bar=True, - ) - - def shared_step(self, batch, t=None): - x, *_ = self.diffusion_model.get_input( - batch, k=self.diffusion_model.first_stage_key - ) - targets = self.get_conditioning(batch) - if targets.dim() == 4: - targets = targets.argmax(dim=1) - if t is None: - t = torch.randint( - 0, - self.diffusion_model.num_timesteps, - (x.shape[0],), - device=self.device, - ).long() - else: - t = torch.full( - size=(x.shape[0],), fill_value=t, device=self.device - ).long() - x_noisy = self.get_x_noisy(x, t) - logits = self(x_noisy, t) - - loss = F.cross_entropy(logits, targets, reduction='none') - - self.write_logs(loss.detach(), logits.detach(), targets.detach()) - - loss = loss.mean() - return loss, logits, x_noisy, targets - - def training_step(self, batch, batch_idx): - loss, *_ = self.shared_step(batch) - return loss - - def reset_noise_accs(self): - self.noisy_acc = { - t: {'acc@1': [], 'acc@5': []} - for t in range( - 0, - self.diffusion_model.num_timesteps, - self.diffusion_model.log_every_t, - ) - } - - def on_validation_start(self): - self.reset_noise_accs() - - @torch.no_grad() - def validation_step(self, batch, batch_idx): - loss, *_ = self.shared_step(batch) - - for t in self.noisy_acc: - _, logits, _, targets = self.shared_step(batch, t) - self.noisy_acc[t]['acc@1'].append( - self.compute_top_k(logits, targets, k=1, reduction='mean') - ) - self.noisy_acc[t]['acc@5'].append( - self.compute_top_k(logits, targets, k=5, reduction='mean') - ) - - return loss - - def configure_optimizers(self): - optimizer = AdamW( - self.model.parameters(), - lr=self.learning_rate, - weight_decay=self.weight_decay, - ) - - if self.use_scheduler: - scheduler = instantiate_from_config(self.scheduler_config) - - print('Setting up LambdaLR scheduler...') - scheduler = [ - { - 'scheduler': LambdaLR( - optimizer, lr_lambda=scheduler.schedule - ), - 'interval': 'step', - 'frequency': 1, - } - ] - return [optimizer], scheduler - - return optimizer - - @torch.no_grad() - def log_images(self, batch, N=8, *args, **kwargs): - log = dict() - x = self.get_input(batch, self.diffusion_model.first_stage_key) - log['inputs'] = x - - y = self.get_conditioning(batch) - - if self.label_key == 'class_label': - y = log_txt_as_img((x.shape[2], x.shape[3]), batch['human_label']) - log['labels'] = y - - if ismap(y): - log['labels'] = self.diffusion_model.to_rgb(y) - - for step in range(self.log_steps): - current_time = step * self.log_time_interval - - _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) - - log[f'inputs@t{current_time}'] = x_noisy - - pred = F.one_hot( - logits.argmax(dim=1), num_classes=self.num_classes - ) - pred = rearrange(pred, 'b h w c -> b c h w') - - log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb( - pred - ) - - for key in log: - log[key] = log[key][:N] - - return log diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py deleted file mode 100644 index a34f22e683..0000000000 --- a/ldm/models/diffusion/cross_attention_control.py +++ /dev/null @@ -1,642 +0,0 @@ - -# adapted from bloc97's CrossAttentionControl colab -# https://github.com/bloc97/CrossAttentionControl - - -import enum -import math -from typing import Optional, Callable - -import psutil -import torch -import diffusers -from torch import nn - -from compel.cross_attention_control import Arguments -from diffusers.models.unet_2d_condition import UNet2DConditionModel -from diffusers.models.cross_attention import AttnProcessor -from ldm.invoke.devices import torch_dtype - - -class CrossAttentionType(enum.Enum): - SELF = 1 - TOKENS = 2 - - -class Context: - - cross_attention_mask: Optional[torch.Tensor] - cross_attention_index_map: Optional[torch.Tensor] - - class Action(enum.Enum): - NONE = 0 - SAVE = 1, - APPLY = 2 - - def __init__(self, arguments: Arguments, step_count: int): - """ - :param arguments: Arguments for the cross-attention control process - :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run) - """ - self.cross_attention_mask = None - self.cross_attention_index_map = None - self.self_cross_attention_action = Context.Action.NONE - self.tokens_cross_attention_action = Context.Action.NONE - self.arguments = arguments - self.step_count = step_count - - self.self_cross_attention_module_identifiers = [] - self.tokens_cross_attention_module_identifiers = [] - - self.saved_cross_attention_maps = {} - - self.clear_requests(cleanup=True) - - def register_cross_attention_modules(self, model): - for name,module in get_cross_attention_modules(model, CrossAttentionType.SELF): - if name in self.self_cross_attention_module_identifiers: - assert False, f"name {name} cannot appear more than once" - self.self_cross_attention_module_identifiers.append(name) - for name,module in get_cross_attention_modules(model, CrossAttentionType.TOKENS): - if name in self.tokens_cross_attention_module_identifiers: - assert False, f"name {name} cannot appear more than once" - self.tokens_cross_attention_module_identifiers.append(name) - - def request_save_attention_maps(self, cross_attention_type: CrossAttentionType): - if cross_attention_type == CrossAttentionType.SELF: - self.self_cross_attention_action = Context.Action.SAVE - else: - self.tokens_cross_attention_action = Context.Action.SAVE - - def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType): - if cross_attention_type == CrossAttentionType.SELF: - self.self_cross_attention_action = Context.Action.APPLY - else: - self.tokens_cross_attention_action = Context.Action.APPLY - - def is_tokens_cross_attention(self, module_identifier) -> bool: - return module_identifier in self.tokens_cross_attention_module_identifiers - - def get_should_save_maps(self, module_identifier: str) -> bool: - if module_identifier in self.self_cross_attention_module_identifiers: - return self.self_cross_attention_action == Context.Action.SAVE - elif module_identifier in self.tokens_cross_attention_module_identifiers: - return self.tokens_cross_attention_action == Context.Action.SAVE - return False - - def get_should_apply_saved_maps(self, module_identifier: str) -> bool: - if module_identifier in self.self_cross_attention_module_identifiers: - return self.self_cross_attention_action == Context.Action.APPLY - elif module_identifier in self.tokens_cross_attention_module_identifiers: - return self.tokens_cross_attention_action == Context.Action.APPLY - return False - - def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\ - -> list[CrossAttentionType]: - """ - Should cross-attention control be applied on the given step? - :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0. - :return: A list of attention types that cross-attention control should be performed for on the given step. May be []. - """ - if percent_through is None: - return [CrossAttentionType.SELF, CrossAttentionType.TOKENS] - - opts = self.arguments.edit_options - to_control = [] - if opts['s_start'] <= percent_through < opts['s_end']: - to_control.append(CrossAttentionType.SELF) - if opts['t_start'] <= percent_through < opts['t_end']: - to_control.append(CrossAttentionType.TOKENS) - return to_control - - def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int, - slice_size: Optional[int]): - if identifier not in self.saved_cross_attention_maps: - self.saved_cross_attention_maps[identifier] = { - 'dim': dim, - 'slice_size': slice_size, - 'slices': {offset or 0: slice} - } - else: - self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice - - def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int): - saved_attention_dict = self.saved_cross_attention_maps[identifier] - if requested_dim is None: - if saved_attention_dict['dim'] is not None: - raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}") - return saved_attention_dict['slices'][0] - - if saved_attention_dict['dim'] == requested_dim: - if slice_size != saved_attention_dict['slice_size']: - raise RuntimeError( - f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}") - return saved_attention_dict['slices'][requested_offset] - - if saved_attention_dict['dim'] is None: - whole_saved_attention = saved_attention_dict['slices'][0] - if requested_dim == 0: - return whole_saved_attention[requested_offset:requested_offset + slice_size] - elif requested_dim == 1: - return whole_saved_attention[:, requested_offset:requested_offset + slice_size] - - raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}") - - def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]: - saved_attention = self.saved_cross_attention_maps.get(identifier, None) - if saved_attention is None: - return None, None - return saved_attention['dim'], saved_attention['slice_size'] - - def clear_requests(self, cleanup=True): - self.tokens_cross_attention_action = Context.Action.NONE - self.self_cross_attention_action = Context.Action.NONE - if cleanup: - self.saved_cross_attention_maps = {} - - def offload_saved_attention_slices_to_cpu(self): - for key, map_dict in self.saved_cross_attention_maps.items(): - for offset, slice in map_dict['slices'].items(): - map_dict[offset] = slice.to('cpu') - - - -class InvokeAICrossAttentionMixin: - """ - Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls - through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling - and dymamic slicing strategy selection. - """ - def __init__(self): - self.mem_total_gb = psutil.virtual_memory().total // (1 << 30) - self.attention_slice_wrangler = None - self.slicing_strategy_getter = None - self.attention_slice_calculated_callback = None - - def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]): - ''' - Set custom attention calculator to be called when attention is calculated - :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size), - which returns either the suggested_attention_slice or an adjusted equivalent. - `module` is the current CrossAttention module for which the callback is being invoked. - `suggested_attention_slice` is the default-calculated attention slice - `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing. - If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length. - - Pass None to use the default attention calculation. - :return: - ''' - self.attention_slice_wrangler = wrangler - - def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]): - self.slicing_strategy_getter = getter - - def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]): - self.attention_slice_calculated_callback = callback - - def einsum_lowest_level(self, query, key, value, dim, offset, slice_size): - # calculate attention scores - #attention_scores = torch.einsum('b i d, b j d -> b i j', q, k) - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - - # calculate attention slice by taking the best scores for each latent pixel - default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) - attention_slice_wrangler = self.attention_slice_wrangler - if attention_slice_wrangler is not None: - attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size) - else: - attention_slice = default_attention_slice - - if self.attention_slice_calculated_callback is not None: - self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size) - - hidden_states = torch.bmm(attention_slice, value) - return hidden_states - - def einsum_op_slice_dim0(self, q, k, v, slice_size): - r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - for i in range(0, q.shape[0], slice_size): - end = i + slice_size - r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size) - return r - - def einsum_op_slice_dim1(self, q, k, v, slice_size): - r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size) - return r - - def einsum_op_mps_v1(self, q, k, v): - if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 - return self.einsum_lowest_level(q, k, v, None, None, None) - else: - slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) - return self.einsum_op_slice_dim1(q, k, v, slice_size) - - def einsum_op_mps_v2(self, q, k, v): - if self.mem_total_gb > 8 and q.shape[1] <= 4096: - return self.einsum_lowest_level(q, k, v, None, None, None) - else: - return self.einsum_op_slice_dim0(q, k, v, 1) - - def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb): - size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) - if size_mb <= max_tensor_mb: - return self.einsum_lowest_level(q, k, v, None, None, None) - div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() - if div <= q.shape[0]: - return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div) - return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1)) - - def einsum_op_cuda(self, q, k, v): - # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation) - slicing_strategy_getter = self.slicing_strategy_getter - if slicing_strategy_getter is not None: - (dim, slice_size) = slicing_strategy_getter(self) - if dim is not None: - # print("using saved slicing strategy with dim", dim, "slice size", slice_size) - if dim == 0: - return self.einsum_op_slice_dim0(q, k, v, slice_size) - elif dim == 1: - return self.einsum_op_slice_dim1(q, k, v, slice_size) - - # fallback for when there is no saved strategy, or saved strategy does not slice - mem_free_total = get_mem_free_total(q.device) - # Divide factor of safety as there's copying and fragmentation - return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) - - - def get_invokeai_attention_mem_efficient(self, q, k, v): - if q.device.type == 'cuda': - #print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device)) - return self.einsum_op_cuda(q, k, v) - - if q.device.type == 'mps' or q.device.type == 'cpu': - if self.mem_total_gb >= 32: - return self.einsum_op_mps_v1(q, k, v) - return self.einsum_op_mps_v2(q, k, v) - - # Smaller slices are faster due to L2/L3/SLC caches. - # Tested on i7 with 8MB L3 cache. - return self.einsum_op_tensor_mem(q, k, v, 32) - - - -def restore_default_cross_attention(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None): - if is_running_diffusers: - unet = model - unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor()) - else: - remove_attention_function(model) - - -def override_cross_attention(model, context: Context, is_running_diffusers = False): - """ - Inject attention parameters and functions into the passed in model to enable cross attention editing. - - :param model: The unet model to inject into. - :return: None - """ - - # adapted from init_attention_edit - device = context.arguments.edited_conditioning.device - - # urgh. should this be hardcoded? - max_length = 77 - # mask=1 means use base prompt attention, mask=0 means use edited prompt attention - mask = torch.zeros(max_length, dtype=torch_dtype(device)) - indices_target = torch.arange(max_length, dtype=torch.long) - indices = torch.arange(max_length, dtype=torch.long) - for name, a0, a1, b0, b1 in context.arguments.edit_opcodes: - if b0 < max_length: - if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): - # these tokens have not been edited - indices[b0:b1] = indices_target[a0:a1] - mask[b0:b1] = 1 - - context.cross_attention_mask = mask.to(device) - context.cross_attention_index_map = indices.to(device) - if is_running_diffusers: - unet = model - old_attn_processors = unet.attn_processors - if torch.backends.mps.is_available(): - # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS - unet.set_attn_processor(SwapCrossAttnProcessor()) - else: - # try to re-use an existing slice size - default_slice_size = 4 - slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size) - unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) - return old_attn_processors - else: - context.register_cross_attention_modules(model) - inject_attention_function(model, context) - return None - - - - -def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]: - from ldm.modules.attention import CrossAttention # avoid circular import - cross_attention_class: type = InvokeAIDiffusersCrossAttention if isinstance(model,UNet2DConditionModel) else CrossAttention - which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2" - attention_module_tuples = [(name,module) for name, module in model.named_modules() if - isinstance(module, cross_attention_class) and which_attn in name] - cross_attention_modules_in_model_count = len(attention_module_tuples) - expected_count = 16 - if cross_attention_modules_in_model_count != expected_count: - # non-fatal error but .swap() won't work. - print(f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " + - f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " + - f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " + - f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " + - f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not " + - f"work properly until it is fixed.") - return attention_module_tuples - - -def inject_attention_function(unet, context: Context): - # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 - - def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size): - - #memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement() - - attention_slice = suggested_attention_slice - - if context.get_should_save_maps(module.identifier): - #print(module.identifier, "saving suggested_attention_slice of shape", - # suggested_attention_slice.shape, "dim", dim, "offset", offset) - slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice - context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size) - elif context.get_should_apply_saved_maps(module.identifier): - #print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset) - saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size) - - # slice may have been offloaded to CPU - saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device) - - if context.is_tokens_cross_attention(module.identifier): - index_map = context.cross_attention_index_map - remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map) - this_attention_slice = suggested_attention_slice - - mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device)) - saved_mask = mask - this_mask = 1 - mask - attention_slice = remapped_saved_attention_slice * saved_mask + \ - this_attention_slice * this_mask - else: - # just use everything - attention_slice = saved_attention_slice - - return attention_slice - - cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF) - for identifier, module in cross_attention_modules: - module.identifier = identifier - try: - module.set_attention_slice_wrangler(attention_slice_wrangler) - module.set_slicing_strategy_getter( - lambda module: context.get_slicing_strategy(identifier) - ) - except AttributeError as e: - if is_attribute_error_about(e, 'set_attention_slice_wrangler'): - print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO - else: - raise - - -def remove_attention_function(unet): - cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF) - for identifier, module in cross_attention_modules: - try: - # clear wrangler callback - module.set_attention_slice_wrangler(None) - module.set_slicing_strategy_getter(None) - except AttributeError as e: - if is_attribute_error_about(e, 'set_attention_slice_wrangler'): - print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") - else: - raise - - -def is_attribute_error_about(error: AttributeError, attribute: str): - if hasattr(error, 'name'): # Python 3.10 - return error.name == attribute - else: # Python 3.9 - return attribute in str(error) - - - -def get_mem_free_total(device): - #only on cuda - if not torch.cuda.is_available(): - return None - stats = torch.cuda.memory_stats(device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(device) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - return mem_free_total - - - -class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin): - - def __init__(self, **kwargs): - super().__init__(**kwargs) - InvokeAICrossAttentionMixin.__init__(self) - - def _attention(self, query, key, value, attention_mask=None): - #default_result = super()._attention(query, key, value) - if attention_mask is not None: - print(f"{type(self).__name__} ignoring passed-in attention_mask") - attention_result = self.get_invokeai_attention_mem_efficient(query, key, value) - - hidden_states = self.reshape_batch_dim_to_heads(attention_result) - return hidden_states - - - - - -## 🧨diffusers implementation follows - - -""" -# base implementation - -class CrossAttnProcessor: - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) - query = attn.head_to_batch_dim(query) - - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - -""" -from dataclasses import field, dataclass - -import torch - -from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor - - -@dataclass -class SwapCrossAttnContext: - modified_text_embeddings: torch.Tensor - index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt - mask: torch.Tensor # in the target space of the index_map - cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list) - - def __int__(self, - cac_types_to_do: [CrossAttentionType], - modified_text_embeddings: torch.Tensor, - index_map: torch.Tensor, - mask: torch.Tensor): - self.cross_attention_types_to_do = cac_types_to_do - self.modified_text_embeddings = modified_text_embeddings - self.index_map = index_map - self.mask = mask - - def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool: - return attn_type in self.cross_attention_types_to_do - - @classmethod - def make_mask_and_index_map(cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int) \ - -> tuple[torch.Tensor, torch.Tensor]: - - # mask=1 means use original prompt attention, mask=0 means use modified prompt attention - mask = torch.zeros(max_length) - indices_target = torch.arange(max_length, dtype=torch.long) - indices = torch.arange(max_length, dtype=torch.long) - for name, a0, a1, b0, b1 in edit_opcodes: - if b0 < max_length: - if name == "equal": - # these tokens remain the same as in the original prompt - indices[b0:b1] = indices_target[a0:a1] - mask[b0:b1] = 1 - - return mask, indices - - -class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): - - # TODO: dynamically pick slice size based on memory conditions - - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, - # kwargs - swap_cross_attn_context: SwapCrossAttnContext=None): - - attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS - - # if cross-attention control is not in play, just call through to the base implementation. - if attention_type is CrossAttentionType.SELF or \ - swap_cross_attn_context is None or \ - not swap_cross_attn_context.wants_cross_attention_control(attention_type): - #print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass") - return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask) - #else: - # print(f"SwapCrossAttnContext for {attention_type} active") - - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask( - attention_mask=attention_mask, target_length=sequence_length, - batch_size=batch_size) - - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - - original_text_embeddings = encoder_hidden_states - modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings - original_text_key = attn.to_k(original_text_embeddings) - modified_text_key = attn.to_k(modified_text_embeddings) - original_value = attn.to_v(original_text_embeddings) - modified_value = attn.to_v(modified_text_embeddings) - - original_text_key = attn.head_to_batch_dim(original_text_key) - modified_text_key = attn.head_to_batch_dim(modified_text_key) - original_value = attn.head_to_batch_dim(original_value) - modified_value = attn.head_to_batch_dim(modified_value) - - # compute slices and prepare output tensor - batch_size_attention = query.shape[0] - hidden_states = torch.zeros( - (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype - ) - - # do slices - for i in range(max(1,hidden_states.shape[0] // self.slice_size)): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size - - query_slice = query[start_idx:end_idx] - original_key_slice = original_text_key[start_idx:end_idx] - modified_key_slice = modified_text_key[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice) - modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice) - - # because the prompt modifications may result in token sequences shifted forwards or backwards, - # the original attention probabilities must be remapped to account for token index changes in the - # modified prompt - remapped_original_attn_slice = torch.index_select(original_attn_slice, -1, - swap_cross_attn_context.index_map) - - # only some tokens taken from the original attention probabilities. this is controlled by the mask. - mask = swap_cross_attn_context.mask - inverse_mask = 1 - mask - attn_slice = \ - remapped_original_attn_slice * mask + \ - modified_attn_slice * inverse_mask - - del remapped_original_attn_slice, modified_attn_slice - - attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx]) - hidden_states[start_idx:end_idx] = attn_slice - - - # done - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser): - - def __init__(self): - super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice - diff --git a/ldm/models/diffusion/cross_attention_map_saving.py b/ldm/models/diffusion/cross_attention_map_saving.py deleted file mode 100644 index 82983573d3..0000000000 --- a/ldm/models/diffusion/cross_attention_map_saving.py +++ /dev/null @@ -1,95 +0,0 @@ -import math - -import PIL -import torch -from torchvision.transforms.functional import resize as tv_resize, InterpolationMode - -from ldm.models.diffusion.cross_attention_control import get_cross_attention_modules, CrossAttentionType - - -class AttentionMapSaver(): - - def __init__(self, token_ids: range, latents_shape: torch.Size): - self.token_ids = token_ids - self.latents_shape = latents_shape - #self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]]) - self.collated_maps = {} - - def clear_maps(self): - self.collated_maps = {} - - def add_attention_maps(self, maps: torch.Tensor, key: str): - """ - Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any). - :param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77). - :param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match. - :return: None - """ - key_and_size = f'{key}_{maps.shape[1]}' - - # extract desired tokens - maps = maps[:, :, self.token_ids] - - # merge attention heads to a single map per token - maps = torch.sum(maps, 0) - - # store - if key_and_size not in self.collated_maps: - self.collated_maps[key_and_size] = torch.zeros_like(maps, device='cpu') - self.collated_maps[key_and_size] += maps.cpu() - - def write_maps_to_disk(self, path: str): - pil_image = self.get_stacked_maps_image() - pil_image.save(path, 'PNG') - - def get_stacked_maps_image(self) -> PIL.Image: - """ - Scale all collected attention maps to the same size, blend them together and return as an image. - :return: An image containing a vertical stack of blended attention maps, one for each requested token. - """ - num_tokens = len(self.token_ids) - if num_tokens == 0: - return None - - latents_height = self.latents_shape[0] - latents_width = self.latents_shape[1] - - merged = None - - for key, maps in self.collated_maps.items(): - - # maps has shape [(H*W), N] for N tokens - # but we want [N, H, W] - this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height)) - this_maps_height = int(float(latents_height) * this_scale_factor) - this_maps_width = int(float(latents_width) * this_scale_factor) - # and we need to do some dimension juggling - maps = torch.reshape(torch.swapdims(maps, 0, 1), [num_tokens, this_maps_height, this_maps_width]) - - # scale to output size if necessary - if this_scale_factor != 1: - maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC) - - # normalize - maps_min = torch.min(maps) - maps_range = torch.max(maps) - maps_min - #print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}") - maps_normalized = (maps - maps_min) / maps_range - # expand to (-0.1, 1.1) and clamp - maps_normalized_expanded = maps_normalized * 1.1 - 0.05 - maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1) - - # merge together, producing a vertical stack - maps_stacked = torch.reshape(maps_normalized_expanded_clamped, [num_tokens * latents_height, latents_width]) - - if merged is None: - merged = maps_stacked - else: - # screen blend - merged = 1 - (1 - maps_stacked)*(1 - merged) - - if merged is None: - return None - - merged_bytes = merged.mul(0xff).byte() - return PIL.Image.fromarray(merged_bytes.numpy(), mode='L') diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py deleted file mode 100644 index 304009c1d3..0000000000 --- a/ldm/models/diffusion/ddim.py +++ /dev/null @@ -1,111 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent -from ldm.models.diffusion.sampler import Sampler -from ldm.modules.diffusionmodules.util import noise_like - -class DDIMSampler(Sampler): - def __init__(self, model, schedule='linear', device=None, **kwargs): - super().__init__(model,schedule,model.num_timesteps,device) - - self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model, - model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond)) - - def prepare_to_sample(self, t_enc, **kwargs): - super().prepare_to_sample(t_enc, **kwargs) - - extra_conditioning_info = kwargs.get('extra_conditioning_info', None) - all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) - - if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count) - else: - self.invokeai_diffuser.restore_default_cross_attention() - - - # This is the central routine - @torch.no_grad() - def p_sample( - self, - x, - c, - t, - index, - repeat_noise=False, - use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - step_count:int=1000, # total number of steps - **kwargs, - ): - b, *_, device = *x.shape, x.device - - if ( - unconditional_conditioning is None - or unconditional_guidance_scale == 1.0 - ): - # damian0815 would like to know when/if this code path is used - e_t = self.model.apply_model(x, t, c) - else: - # step_index counts in the opposite direction to index - step_index = step_count-(index+1) - e_t = self.invokeai_diffuser.do_diffusion_step( - x, t, - unconditional_conditioning, c, - unconditional_guidance_scale, - step_index=step_index - ) - if score_corrector is not None: - assert self.model.parameterization == 'eps' - e_t = score_corrector.modify_score( - self.model, e_t, x, t, c, **corrector_kwargs - ) - - alphas = ( - self.model.alphas_cumprod - if use_original_steps - else self.ddim_alphas - ) - alphas_prev = ( - self.model.alphas_cumprod_prev - if use_original_steps - else self.ddim_alphas_prev - ) - sqrt_one_minus_alphas = ( - self.model.sqrt_one_minus_alphas_cumprod - if use_original_steps - else self.ddim_sqrt_one_minus_alphas - ) - sigmas = ( - self.model.ddim_sigmas_for_original_num_steps - if use_original_steps - else self.ddim_sigmas - ) - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full( - (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device - ) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - # direction pointing to x_t - dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t - noise = ( - sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - ) - if noise_dropout > 0.0: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0, None - diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py deleted file mode 100644 index 7c7ba9f5fd..0000000000 --- a/ldm/models/diffusion/ddpm.py +++ /dev/null @@ -1,2271 +0,0 @@ -""" -wild mixture of -https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py -https://github.com/CompVis/taming-transformers --- merci -""" - -import torch - -import torch.nn as nn -import os -import numpy as np -import pytorch_lightning as pl -from torch.optim.lr_scheduler import LambdaLR -from einops import rearrange, repeat -from contextlib import contextmanager -from functools import partial -from tqdm import tqdm -from torchvision.utils import make_grid -from pytorch_lightning.utilities.distributed import rank_zero_only -from omegaconf import ListConfig -import urllib - -from ldm.modules.textual_inversion_manager import TextualInversionManager -from ldm.util import ( - log_txt_as_img, - exists, - default, - ismap, - isimage, - mean_flat, - count_params, - instantiate_from_config, -) -from ldm.modules.ema import LitEma -from ldm.modules.distributions.distributions import ( - normal_kl, - DiagonalGaussianDistribution, -) -from ldm.models.autoencoder import ( - VQModelInterface, - IdentityFirstStage, - AutoencoderKL, -) -from ldm.modules.diffusionmodules.util import ( - make_beta_schedule, - extract_into_tensor, - noise_like, -) -from ldm.models.diffusion.ddim import DDIMSampler - - -__conditioning_keys__ = { - 'concat': 'c_concat', - 'crossattn': 'c_crossattn', - 'adm': 'y', -} - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -def uniform_on_device(r1, r2, shape, device): - return (r1 - r2) * torch.rand(*shape, device=device) + r2 - - -class DDPM(pl.LightningModule): - # classic DDPM with Gaussian diffusion, in image space - def __init__( - self, - unet_config, - timesteps=1000, - beta_schedule='linear', - loss_type='l2', - ckpt_path=None, - ignore_keys=[], - load_only_unet=False, - monitor='val/loss', - use_ema=True, - first_stage_key='image', - image_size=256, - channels=3, - log_every_t=100, - clip_denoised=True, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - given_betas=None, - original_elbo_weight=0.0, - embedding_reg_weight=0.0, - v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta - l_simple_weight=1.0, - conditioning_key=None, - parameterization='eps', # all assuming fixed variance schedules - scheduler_config=None, - use_positional_encodings=False, - learn_logvar=False, - logvar_init=0.0, - ): - super().__init__() - assert parameterization in [ - 'eps', - 'x0', - ], 'currently only supporting "eps" and "x0"' - self.parameterization = parameterization - print( - f' | {self.__class__.__name__}: Running in {self.parameterization}-prediction mode' - ) - self.cond_stage_model = None - self.clip_denoised = clip_denoised - self.log_every_t = log_every_t - self.first_stage_key = first_stage_key - self.image_size = image_size # try conv? - self.channels = channels - self.use_positional_encodings = use_positional_encodings - self.model = DiffusionWrapper(unet_config, conditioning_key) - count_params(self.model, verbose=True) - self.use_ema = use_ema - if self.use_ema: - self.model_ema = LitEma(self.model) - print(f' | Keeping EMAs of {len(list(self.model_ema.buffers()))}.') - - self.use_scheduler = scheduler_config is not None - if self.use_scheduler: - self.scheduler_config = scheduler_config - - self.v_posterior = v_posterior - self.original_elbo_weight = original_elbo_weight - self.l_simple_weight = l_simple_weight - self.embedding_reg_weight = embedding_reg_weight - - if monitor is not None: - self.monitor = monitor - if ckpt_path is not None: - self.init_from_ckpt( - ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet - ) - - self.register_schedule( - given_betas=given_betas, - beta_schedule=beta_schedule, - timesteps=timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s, - ) - - self.loss_type = loss_type - - self.learn_logvar = learn_logvar - self.logvar = torch.full( - fill_value=logvar_init, size=(self.num_timesteps,) - ) - if self.learn_logvar: - self.logvar = nn.Parameter(self.logvar, requires_grad=True) - - def register_schedule( - self, - given_betas=None, - beta_schedule='linear', - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - ): - if exists(given_betas): - betas = given_betas - else: - betas = make_beta_schedule( - beta_schedule, - timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s, - ) - alphas = 1.0 - betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - - (timesteps,) = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - assert ( - alphas_cumprod.shape[0] == self.num_timesteps - ), 'alphas have to be defined for each timestep' - - to_torch = partial(torch.tensor, dtype=torch.float32) - - self.register_buffer('betas', to_torch(betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer( - 'alphas_cumprod_prev', to_torch(alphas_cumprod_prev) - ) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer( - 'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)) - ) - self.register_buffer( - 'sqrt_one_minus_alphas_cumprod', - to_torch(np.sqrt(1.0 - alphas_cumprod)), - ) - self.register_buffer( - 'log_one_minus_alphas_cumprod', - to_torch(np.log(1.0 - alphas_cumprod)), - ) - self.register_buffer( - 'sqrt_recip_alphas_cumprod', - to_torch(np.sqrt(1.0 / alphas_cumprod)), - ) - self.register_buffer( - 'sqrt_recipm1_alphas_cumprod', - to_torch(np.sqrt(1.0 / alphas_cumprod - 1)), - ) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = (1 - self.v_posterior) * betas * ( - 1.0 - alphas_cumprod_prev - ) / (1.0 - alphas_cumprod) + self.v_posterior * betas - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.register_buffer( - 'posterior_variance', to_torch(posterior_variance) - ) - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.register_buffer( - 'posterior_log_variance_clipped', - to_torch(np.log(np.maximum(posterior_variance, 1e-20))), - ) - self.register_buffer( - 'posterior_mean_coef1', - to_torch( - betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) - ), - ) - self.register_buffer( - 'posterior_mean_coef2', - to_torch( - (1.0 - alphas_cumprod_prev) - * np.sqrt(alphas) - / (1.0 - alphas_cumprod) - ), - ) - - if self.parameterization == 'eps': - lvlb_weights = self.betas**2 / ( - 2 - * self.posterior_variance - * to_torch(alphas) - * (1 - self.alphas_cumprod) - ) - elif self.parameterization == 'x0': - lvlb_weights = ( - 0.5 - * np.sqrt(torch.Tensor(alphas_cumprod)) - / (2.0 * 1 - torch.Tensor(alphas_cumprod)) - ) - else: - raise NotImplementedError('mu not supported') - # TODO how to choose this term - lvlb_weights[0] = lvlb_weights[1] - self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) - assert not torch.isnan(self.lvlb_weights).all() - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.model.parameters()) - self.model_ema.copy_to(self.model) - if context is not None: - print(f'{context}: Switched to EMA weights') - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.model.parameters()) - if context is not None: - print(f'{context}: Restored training weights') - - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): - sd = torch.load(path, map_location='cpu') - if 'state_dict' in list(sd.keys()): - sd = sd['state_dict'] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print('Deleting key {} from state_dict.'.format(k)) - del sd[k] - missing, unexpected = ( - self.load_state_dict(sd, strict=False) - if not only_model - else self.model.load_state_dict(sd, strict=False) - ) - print( - f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys' - ) - if len(missing) > 0: - print(f'Missing Keys: {missing}') - if len(unexpected) > 0: - print(f'Unexpected Keys: {unexpected}') - - def q_mean_variance(self, x_start, t): - """ - Get the distribution q(x_t | x_0). - :param x_start: the [N x C x ...] tensor of noiseless inputs. - :param t: the number of diffusion steps (minus 1). Here, 0 means one step. - :return: A tuple (mean, variance, log_variance), all of x_start's shape. - """ - mean = ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) - * x_start - ) - variance = extract_into_tensor( - 1.0 - self.alphas_cumprod, t, x_start.shape - ) - log_variance = extract_into_tensor( - self.log_one_minus_alphas_cumprod, t, x_start.shape - ) - return mean, variance, log_variance - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) - * x_t - - extract_into_tensor( - self.sqrt_recipm1_alphas_cumprod, t, x_t.shape - ) - * noise - ) - - def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) - * x_start - + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) - * x_t - ) - posterior_variance = extract_into_tensor( - self.posterior_variance, t, x_t.shape - ) - posterior_log_variance_clipped = extract_into_tensor( - self.posterior_log_variance_clipped, t, x_t.shape - ) - return ( - posterior_mean, - posterior_variance, - posterior_log_variance_clipped, - ) - - def p_mean_variance(self, x, t, clip_denoised: bool): - model_out = self.model(x, t) - if self.parameterization == 'eps': - x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == 'x0': - x_recon = model_out - if clip_denoised: - x_recon.clamp_(-1.0, 1.0) - - ( - model_mean, - posterior_variance, - posterior_log_variance, - ) = self.q_posterior(x_start=x_recon, x_t=x, t=t) - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): - b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance( - x=x, t=t, clip_denoised=clip_denoised - ) - noise = noise_like(x.shape, device, repeat_noise) - # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape( - b, *((1,) * (len(x.shape) - 1)) - ) - return ( - model_mean - + nonzero_mask * (0.5 * model_log_variance).exp() * noise - ) - - @torch.no_grad() - def p_sample_loop(self, shape, return_intermediates=False): - device = self.betas.device - b = shape[0] - img = torch.randn(shape, device=device) - intermediates = [img] - for i in tqdm( - reversed(range(0, self.num_timesteps)), - desc='Sampling t', - total=self.num_timesteps, - dynamic_ncols=True, - ): - img = self.p_sample( - img, - torch.full((b,), i, device=device, dtype=torch.long), - clip_denoised=self.clip_denoised, - ) - if i % self.log_every_t == 0 or i == self.num_timesteps - 1: - intermediates.append(img) - if return_intermediates: - return img, intermediates - return img - - @torch.no_grad() - def sample(self, batch_size=16, return_intermediates=False): - image_size = self.image_size - channels = self.channels - return self.p_sample_loop( - (batch_size, channels, image_size, image_size), - return_intermediates=return_intermediates, - ) - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) - * x_start - + extract_into_tensor( - self.sqrt_one_minus_alphas_cumprod, t, x_start.shape - ) - * noise - ) - - def get_loss(self, pred, target, mean=True): - if self.loss_type == 'l1': - loss = (target - pred).abs() - if mean: - loss = loss.mean() - elif self.loss_type == 'l2': - if mean: - loss = torch.nn.functional.mse_loss(target, pred) - else: - loss = torch.nn.functional.mse_loss( - target, pred, reduction='none' - ) - else: - raise NotImplementedError("unknown loss type '{loss_type}'") - - return loss - - def p_losses(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - model_out = self.model(x_noisy, t) - - loss_dict = {} - if self.parameterization == 'eps': - target = noise - elif self.parameterization == 'x0': - target = x_start - else: - raise NotImplementedError( - f'Paramterization {self.parameterization} not yet supported' - ) - - loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) - - log_prefix = 'train' if self.training else 'val' - - loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) - loss_simple = loss.mean() * self.l_simple_weight - - loss_vlb = (self.lvlb_weights[t] * loss).mean() - loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) - - loss = loss_simple + self.original_elbo_weight * loss_vlb - - loss_dict.update({f'{log_prefix}/loss': loss}) - - return loss, loss_dict - - def forward(self, x, *args, **kwargs): - # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size - # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' - t = torch.randint( - 0, self.num_timesteps, (x.shape[0],), device=self.device - ).long() - return self.p_losses(x, t, *args, **kwargs) - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = rearrange(x, 'b h w c -> b c h w') - x = x.to(memory_format=torch.contiguous_format).float() - return x - - def shared_step(self, batch): - x = self.get_input(batch, self.first_stage_key) - loss, loss_dict = self(x) - return loss, loss_dict - - def training_step(self, batch, batch_idx): - loss, loss_dict = self.shared_step(batch) - - self.log_dict( - loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True - ) - - self.log( - 'global_step', - self.global_step, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=False, - ) - - if self.use_scheduler: - lr = self.optimizers().param_groups[0]['lr'] - self.log( - 'lr_abs', - lr, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=False, - ) - - return loss - - @torch.no_grad() - def validation_step(self, batch, batch_idx): - _, loss_dict_no_ema = self.shared_step(batch) - with self.ema_scope(): - _, loss_dict_ema = self.shared_step(batch) - loss_dict_ema = { - key + '_ema': loss_dict_ema[key] for key in loss_dict_ema - } - self.log_dict( - loss_dict_no_ema, - prog_bar=False, - logger=True, - on_step=False, - on_epoch=True, - ) - self.log_dict( - loss_dict_ema, - prog_bar=False, - logger=True, - on_step=False, - on_epoch=True, - ) - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self.model) - - def _get_rows_from_list(self, samples): - n_imgs_per_row = len(samples) - denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') - denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) - return denoise_grid - - @torch.no_grad() - def log_images( - self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs - ): - log = dict() - x = self.get_input(batch, self.first_stage_key) - N = min(x.shape[0], N) - n_row = min(x.shape[0], n_row) - x = x.to(self.device)[:N] - log['inputs'] = x - - # get diffusion row - diffusion_row = list() - x_start = x[:n_row] - - for t in range(self.num_timesteps): - if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) - t = t.to(self.device).long() - noise = torch.randn_like(x_start) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - diffusion_row.append(x_noisy) - - log['diffusion_row'] = self._get_rows_from_list(diffusion_row) - - if sample: - # get denoise row - with self.ema_scope('Plotting'): - samples, denoise_row = self.sample( - batch_size=N, return_intermediates=True - ) - - log['samples'] = samples - log['denoise_row'] = self._get_rows_from_list(denoise_row) - - if return_keys: - if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: - return log - else: - return {key: log[key] for key in return_keys} - return log - - def configure_optimizers(self): - lr = self.learning_rate - params = list(self.model.parameters()) - if self.learn_logvar: - params = params + [self.logvar] - opt = torch.optim.AdamW(params, lr=lr) - return opt - - -class LatentDiffusion(DDPM): - """main class""" - - def __init__( - self, - first_stage_config, - cond_stage_config, - personalization_config, - num_timesteps_cond=None, - cond_stage_key='image', - cond_stage_trainable=False, - concat_mode=True, - cond_stage_forward=None, - conditioning_key=None, - scale_factor=1.0, - scale_by_std=False, - *args, - **kwargs, - ): - - self.num_timesteps_cond = default(num_timesteps_cond, 1) - self.scale_by_std = scale_by_std - assert self.num_timesteps_cond <= kwargs['timesteps'] - # for backwards compatibility after implementation of DiffusionWrapper - if conditioning_key is None: - conditioning_key = 'concat' if concat_mode else 'crossattn' - if cond_stage_config == '__is_unconditional__': - conditioning_key = None - ckpt_path = kwargs.pop('ckpt_path', None) - ignore_keys = kwargs.pop('ignore_keys', []) - super().__init__(conditioning_key=conditioning_key, *args, **kwargs) - self.concat_mode = concat_mode - self.cond_stage_trainable = cond_stage_trainable - self.cond_stage_key = cond_stage_key - - try: - self.num_downs = ( - len(first_stage_config.params.ddconfig.ch_mult) - 1 - ) - except: - self.num_downs = 0 - if not scale_by_std: - self.scale_factor = scale_factor - else: - self.register_buffer('scale_factor', torch.tensor(scale_factor)) - self.instantiate_first_stage(first_stage_config) - self.instantiate_cond_stage(cond_stage_config) - - self.cond_stage_forward = cond_stage_forward - self.clip_denoised = False - self.bbox_tokenizer = None - - self.restarted_from_ckpt = False - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys) - self.restarted_from_ckpt = True - - self.cond_stage_model.train = disabled_train - for param in self.cond_stage_model.parameters(): - param.requires_grad = False - - self.model.eval() - self.model.train = disabled_train - for param in self.model.parameters(): - param.requires_grad = False - - self.embedding_manager = self.instantiate_embedding_manager( - personalization_config, self.cond_stage_model - ) - self.textual_inversion_manager = TextualInversionManager( - tokenizer = self.cond_stage_model.tokenizer, - text_encoder = self.cond_stage_model.transformer, - full_precision = True - ) - # this circular component dependency is gross and bad, needs to be rethought - self.cond_stage_model.set_textual_inversion_manager(self.textual_inversion_manager) - - self.emb_ckpt_counter = 0 - - # if self.embedding_manager.is_clip: - # self.cond_stage_model.update_embedding_func(self.embedding_manager) - - for param in self.embedding_manager.embedding_parameters(): - param.requires_grad = True - - def make_cond_schedule( - self, - ): - self.cond_ids = torch.full( - size=(self.num_timesteps,), - fill_value=self.num_timesteps - 1, - dtype=torch.long, - ) - ids = torch.round( - torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) - ).long() - self.cond_ids[: self.num_timesteps_cond] = ids - - @rank_zero_only - @torch.no_grad() - def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None): - # only for very first batch - if ( - self.scale_by_std - and self.current_epoch == 0 - and self.global_step == 0 - and batch_idx == 0 - and not self.restarted_from_ckpt - ): - assert ( - self.scale_factor == 1.0 - ), 'rather not use custom rescaling and std-rescaling simultaneously' - # set rescale weight to 1./std of encodings - print('### USING STD-RESCALING ###') - x = super().get_input(batch, self.first_stage_key) - x = x.to(self.device) - encoder_posterior = self.encode_first_stage(x) - z = self.get_first_stage_encoding(encoder_posterior).detach() - del self.scale_factor - self.register_buffer('scale_factor', 1.0 / z.flatten().std()) - print(f'setting self.scale_factor to {self.scale_factor}') - print('### USING STD-RESCALING ###') - - def register_schedule( - self, - given_betas=None, - beta_schedule='linear', - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - ): - super().register_schedule( - given_betas, - beta_schedule, - timesteps, - linear_start, - linear_end, - cosine_s, - ) - - self.shorten_cond_schedule = self.num_timesteps_cond > 1 - if self.shorten_cond_schedule: - self.make_cond_schedule() - - def instantiate_first_stage(self, config): - model = instantiate_from_config(config) - self.first_stage_model = model.eval() - self.first_stage_model.train = disabled_train - for param in self.first_stage_model.parameters(): - param.requires_grad = False - - def instantiate_cond_stage(self, config): - if not self.cond_stage_trainable: - if config == '__is_first_stage__': - print('Using first stage also as cond stage.') - self.cond_stage_model = self.first_stage_model - elif config == '__is_unconditional__': - print( - f'Training {self.__class__.__name__} as an unconditional model.' - ) - self.cond_stage_model = None - # self.be_unconditional = True - else: - model = instantiate_from_config(config) - self.cond_stage_model = model.eval() - self.cond_stage_model.train = disabled_train - for param in self.cond_stage_model.parameters(): - param.requires_grad = False - else: - assert config != '__is_first_stage__' - assert config != '__is_unconditional__' - try: - model = instantiate_from_config(config) - except urllib.error.URLError: - raise SystemExit( - "* Couldn't load a dependency. Try running scripts/preload_models.py from an internet-conected machine." - ) - self.cond_stage_model = model - - def instantiate_embedding_manager(self, config, embedder): - model = instantiate_from_config(config, embedder=embedder) - - if config.params.get( - 'embedding_manager_ckpt', None - ): # do not load if missing OR empty string - model.load(config.params.embedding_manager_ckpt) - - return model - - def _get_denoise_row_from_list( - self, samples, desc='', force_no_decoder_quantization=False - ): - denoise_row = [] - for zd in tqdm(samples, desc=desc): - denoise_row.append( - self.decode_first_stage( - zd.to(self.device), - force_not_quantize=force_no_decoder_quantization, - ) - ) - n_imgs_per_row = len(denoise_row) - denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W - denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') - denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) - return denoise_grid - - def get_first_stage_encoding(self, encoder_posterior): - if isinstance(encoder_posterior, DiagonalGaussianDistribution): - z = encoder_posterior.sample() - elif isinstance(encoder_posterior, torch.Tensor): - z = encoder_posterior - else: - raise NotImplementedError( - f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" - ) - return self.scale_factor * z - - def get_learned_conditioning(self, c, **kwargs): - if self.cond_stage_forward is None: - if hasattr(self.cond_stage_model, 'encode') and callable( - self.cond_stage_model.encode - ): - c = self.cond_stage_model.encode( - c, embedding_manager=self.embedding_manager,**kwargs - ) - if isinstance(c, DiagonalGaussianDistribution): - c = c.mode() - else: - c = self.cond_stage_model(c, **kwargs) - else: - assert hasattr(self.cond_stage_model, self.cond_stage_forward) - c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, **kwargs) - return c - - def meshgrid(self, h, w): - y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) - x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) - - arr = torch.cat([y, x], dim=-1) - return arr - - def delta_border(self, h, w): - """ - :param h: height - :param w: width - :return: normalized distance to image border, - wtith min distance = 0 at border and max dist = 0.5 at image center - """ - lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) - arr = self.meshgrid(h, w) / lower_right_corner - dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] - dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] - edge_dist = torch.min( - torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1 - )[0] - return edge_dist - - def get_weighting(self, h, w, Ly, Lx, device): - weighting = self.delta_border(h, w) - weighting = torch.clip( - weighting, - self.split_input_params['clip_min_weight'], - self.split_input_params['clip_max_weight'], - ) - weighting = ( - weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) - ) - - if self.split_input_params['tie_braker']: - L_weighting = self.delta_border(Ly, Lx) - L_weighting = torch.clip( - L_weighting, - self.split_input_params['clip_min_tie_weight'], - self.split_input_params['clip_max_tie_weight'], - ) - - L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) - weighting = weighting * L_weighting - return weighting - - def get_fold_unfold( - self, x, kernel_size, stride, uf=1, df=1 - ): # todo load once not every time, shorten code - """ - :param x: img of size (bs, c, h, w) - :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) - """ - bs, nc, h, w = x.shape - - # number of crops in image - Ly = (h - kernel_size[0]) // stride[0] + 1 - Lx = (w - kernel_size[1]) // stride[1] + 1 - - if uf == 1 and df == 1: - fold_params = dict( - kernel_size=kernel_size, dilation=1, padding=0, stride=stride - ) - unfold = torch.nn.Unfold(**fold_params) - - fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) - - weighting = self.get_weighting( - kernel_size[0], kernel_size[1], Ly, Lx, x.device - ).to(x.dtype) - normalization = fold(weighting).view( - 1, 1, h, w - ) # normalizes the overlap - weighting = weighting.view( - (1, 1, kernel_size[0], kernel_size[1], Ly * Lx) - ) - - elif uf > 1 and df == 1: - fold_params = dict( - kernel_size=kernel_size, dilation=1, padding=0, stride=stride - ) - unfold = torch.nn.Unfold(**fold_params) - - fold_params2 = dict( - kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), - dilation=1, - padding=0, - stride=(stride[0] * uf, stride[1] * uf), - ) - fold = torch.nn.Fold( - output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2 - ) - - weighting = self.get_weighting( - kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device - ).to(x.dtype) - normalization = fold(weighting).view( - 1, 1, h * uf, w * uf - ) # normalizes the overlap - weighting = weighting.view( - (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx) - ) - - elif df > 1 and uf == 1: - fold_params = dict( - kernel_size=kernel_size, dilation=1, padding=0, stride=stride - ) - unfold = torch.nn.Unfold(**fold_params) - - fold_params2 = dict( - kernel_size=(kernel_size[0] // df, kernel_size[0] // df), - dilation=1, - padding=0, - stride=(stride[0] // df, stride[1] // df), - ) - fold = torch.nn.Fold( - output_size=(x.shape[2] // df, x.shape[3] // df), - **fold_params2, - ) - - weighting = self.get_weighting( - kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device - ).to(x.dtype) - normalization = fold(weighting).view( - 1, 1, h // df, w // df - ) # normalizes the overlap - weighting = weighting.view( - (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx) - ) - - else: - raise NotImplementedError - - return fold, unfold, normalization, weighting - - @torch.no_grad() - def get_input( - self, - batch, - k, - return_first_stage_outputs=False, - force_c_encode=False, - cond_key=None, - return_original_cond=False, - bs=None, - ): - x = super().get_input(batch, k) - if bs is not None: - x = x[:bs] - x = x.to(self.device) - encoder_posterior = self.encode_first_stage(x) - z = self.get_first_stage_encoding(encoder_posterior).detach() - - if self.model.conditioning_key is not None: - if cond_key is None: - cond_key = self.cond_stage_key - if cond_key != self.first_stage_key: - if cond_key in ['caption', 'coordinates_bbox']: - xc = batch[cond_key] - elif cond_key == 'class_label': - xc = batch - else: - xc = super().get_input(batch, cond_key).to(self.device) - else: - xc = x - if not self.cond_stage_trainable or force_c_encode: - if isinstance(xc, dict) or isinstance(xc, list): - # import pudb; pudb.set_trace() - c = self.get_learned_conditioning(xc) - else: - c = self.get_learned_conditioning(xc.to(self.device)) - else: - c = xc - if bs is not None: - c = c[:bs] - - if self.use_positional_encodings: - pos_x, pos_y = self.compute_latent_shifts(batch) - ckey = __conditioning_keys__[self.model.conditioning_key] - c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} - - else: - c = None - xc = None - if self.use_positional_encodings: - pos_x, pos_y = self.compute_latent_shifts(batch) - c = {'pos_x': pos_x, 'pos_y': pos_y} - out = [z, c] - if return_first_stage_outputs: - xrec = self.decode_first_stage(z) - out.extend([x, xrec]) - if return_original_cond: - out.append(xc) - return out - - @torch.no_grad() - def decode_first_stage( - self, z, predict_cids=False, force_not_quantize=False - ): - if predict_cids: - if z.dim() == 4: - z = torch.argmax(z.exp(), dim=1).long() - z = self.first_stage_model.quantize.get_codebook_entry( - z, shape=None - ) - z = rearrange(z, 'b h w c -> b c h w').contiguous() - - z = 1.0 / self.scale_factor * z - - if hasattr(self, 'split_input_params'): - if self.split_input_params['patch_distributed_vq']: - ks = self.split_input_params['ks'] # eg. (128, 128) - stride = self.split_input_params['stride'] # eg. (64, 64) - uf = self.split_input_params['vqf'] - bs, nc, h, w = z.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print('reducing Kernel') - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print('reducing stride') - - fold, unfold, normalization, weighting = self.get_fold_unfold( - z, ks, stride, uf=uf - ) - - z = unfold(z) # (bn, nc * prod(**ks), L) - # 1. Reshape to img shape - z = z.view( - (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) - ) # (bn, nc, ks[0], ks[1], L ) - - # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): - output_list = [ - self.first_stage_model.decode( - z[:, :, :, :, i], - force_not_quantize=predict_cids - or force_not_quantize, - ) - for i in range(z.shape[-1]) - ] - else: - - output_list = [ - self.first_stage_model.decode(z[:, :, :, :, i]) - for i in range(z.shape[-1]) - ] - - o = torch.stack( - output_list, axis=-1 - ) # # (bn, nc, ks[0], ks[1], L) - o = o * weighting - # Reverse 1. reshape to img shape - o = o.view( - (o.shape[0], -1, o.shape[-1]) - ) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization # norm is shape (1, 1, h, w) - return decoded - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode( - z, - force_not_quantize=predict_cids or force_not_quantize, - ) - else: - return self.first_stage_model.decode(z) - - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode( - z, force_not_quantize=predict_cids or force_not_quantize - ) - else: - return self.first_stage_model.decode(z) - - # same as above but without decorator - def differentiable_decode_first_stage( - self, z, predict_cids=False, force_not_quantize=False - ): - if predict_cids: - if z.dim() == 4: - z = torch.argmax(z.exp(), dim=1).long() - z = self.first_stage_model.quantize.get_codebook_entry( - z, shape=None - ) - z = rearrange(z, 'b h w c -> b c h w').contiguous() - - z = 1.0 / self.scale_factor * z - - if hasattr(self, 'split_input_params'): - if self.split_input_params['patch_distributed_vq']: - ks = self.split_input_params['ks'] # eg. (128, 128) - stride = self.split_input_params['stride'] # eg. (64, 64) - uf = self.split_input_params['vqf'] - bs, nc, h, w = z.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print('reducing Kernel') - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print('reducing stride') - - fold, unfold, normalization, weighting = self.get_fold_unfold( - z, ks, stride, uf=uf - ) - - z = unfold(z) # (bn, nc * prod(**ks), L) - # 1. Reshape to img shape - z = z.view( - (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) - ) # (bn, nc, ks[0], ks[1], L ) - - # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): - output_list = [ - self.first_stage_model.decode( - z[:, :, :, :, i], - force_not_quantize=predict_cids - or force_not_quantize, - ) - for i in range(z.shape[-1]) - ] - else: - - output_list = [ - self.first_stage_model.decode(z[:, :, :, :, i]) - for i in range(z.shape[-1]) - ] - - o = torch.stack( - output_list, axis=-1 - ) # # (bn, nc, ks[0], ks[1], L) - o = o * weighting - # Reverse 1. reshape to img shape - o = o.view( - (o.shape[0], -1, o.shape[-1]) - ) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization # norm is shape (1, 1, h, w) - return decoded - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode( - z, - force_not_quantize=predict_cids or force_not_quantize, - ) - else: - return self.first_stage_model.decode(z) - - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode( - z, force_not_quantize=predict_cids or force_not_quantize - ) - else: - return self.first_stage_model.decode(z) - - @torch.no_grad() - def encode_first_stage(self, x): - if hasattr(self, 'split_input_params'): - if self.split_input_params['patch_distributed_vq']: - ks = self.split_input_params['ks'] # eg. (128, 128) - stride = self.split_input_params['stride'] # eg. (64, 64) - df = self.split_input_params['vqf'] - self.split_input_params['original_image_size'] = x.shape[-2:] - bs, nc, h, w = x.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print('reducing Kernel') - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print('reducing stride') - - fold, unfold, normalization, weighting = self.get_fold_unfold( - x, ks, stride, df=df - ) - z = unfold(x) # (bn, nc * prod(**ks), L) - # Reshape to img shape - z = z.view( - (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) - ) # (bn, nc, ks[0], ks[1], L ) - - output_list = [ - self.first_stage_model.encode(z[:, :, :, :, i]) - for i in range(z.shape[-1]) - ] - - o = torch.stack(output_list, axis=-1) - o = o * weighting - - # Reverse reshape to img shape - o = o.view( - (o.shape[0], -1, o.shape[-1]) - ) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization - return decoded - - else: - return self.first_stage_model.encode(x) - else: - return self.first_stage_model.encode(x) - - def shared_step(self, batch, **kwargs): - x, c = self.get_input(batch, self.first_stage_key) - loss = self(x, c) - return loss - - def forward(self, x, c, *args, **kwargs): - t = torch.randint( - 0, self.num_timesteps, (x.shape[0],), device=self.device - ).long() - if self.model.conditioning_key is not None: - assert c is not None - if self.cond_stage_trainable: - c = self.get_learned_conditioning(c) - if self.shorten_cond_schedule: # TODO: drop this option - tc = self.cond_ids[t].to(self.device) - c = self.q_sample( - x_start=c, t=tc, noise=torch.randn_like(c.float()) - ) - - return self.p_losses(x, c, t, *args, **kwargs) - - def _rescale_annotations( - self, bboxes, crop_coordinates - ): # TODO: move to dataset - def rescale_bbox(bbox): - x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) - y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) - w = min(bbox[2] / crop_coordinates[2], 1 - x0) - h = min(bbox[3] / crop_coordinates[3], 1 - y0) - return x0, y0, w, h - - return [rescale_bbox(b) for b in bboxes] - - def apply_model(self, x_noisy, t, cond, return_ids=False): - - if isinstance(cond, dict): - # hybrid case, cond is exptected to be a dict - pass - else: - if not isinstance(cond, list): - cond = [cond] - key = ( - 'c_concat' - if self.model.conditioning_key == 'concat' - else 'c_crossattn' - ) - cond = {key: cond} - - if hasattr(self, 'split_input_params'): - assert ( - len(cond) == 1 - ) # todo can only deal with one conditioning atm - assert not return_ids - ks = self.split_input_params['ks'] # eg. (128, 128) - stride = self.split_input_params['stride'] # eg. (64, 64) - - h, w = x_noisy.shape[-2:] - - fold, unfold, normalization, weighting = self.get_fold_unfold( - x_noisy, ks, stride - ) - - z = unfold(x_noisy) # (bn, nc * prod(**ks), L) - # Reshape to img shape - z = z.view( - (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) - ) # (bn, nc, ks[0], ks[1], L ) - z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] - - if ( - self.cond_stage_key - in ['image', 'LR_image', 'segmentation', 'bbox_img'] - and self.model.conditioning_key - ): # todo check for completeness - c_key = next(iter(cond.keys())) # get key - c = next(iter(cond.values())) # get value - assert ( - len(c) == 1 - ) # todo extend to list with more than one elem - c = c[0] # get element - - c = unfold(c) - c = c.view( - (c.shape[0], -1, ks[0], ks[1], c.shape[-1]) - ) # (bn, nc, ks[0], ks[1], L ) - - cond_list = [ - {c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1]) - ] - - elif self.cond_stage_key == 'coordinates_bbox': - assert ( - 'original_image_size' in self.split_input_params - ), 'BoudingBoxRescaling is missing original_image_size' - - # assuming padding of unfold is always 0 and its dilation is always 1 - n_patches_per_row = int((w - ks[0]) / stride[0] + 1) - full_img_h, full_img_w = self.split_input_params[ - 'original_image_size' - ] - # as we are operating on latents, we need the factor from the original image size to the - # spatial latent size to properly rescale the crops for regenerating the bbox annotations - num_downs = self.first_stage_model.encoder.num_resolutions - 1 - rescale_latent = 2 ** (num_downs) - - # get top left positions of patches as conforming for the bbbox tokenizer, therefore we - # need to rescale the tl patch coordinates to be in between (0,1) - tl_patch_coordinates = [ - ( - rescale_latent - * stride[0] - * (patch_nr % n_patches_per_row) - / full_img_w, - rescale_latent - * stride[1] - * (patch_nr // n_patches_per_row) - / full_img_h, - ) - for patch_nr in range(z.shape[-1]) - ] - - # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) - patch_limits = [ - ( - x_tl, - y_tl, - rescale_latent * ks[0] / full_img_w, - rescale_latent * ks[1] / full_img_h, - ) - for x_tl, y_tl in tl_patch_coordinates - ] - # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] - - # tokenize crop coordinates for the bounding boxes of the respective patches - patch_limits_tknzd = [ - torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[ - None - ].to(self.device) - for bbox in patch_limits - ] # list of length l with tensors of shape (1, 2) - print(patch_limits_tknzd[0].shape) - # cut tknzd crop position from conditioning - assert isinstance( - cond, dict - ), 'cond must be dict to be fed into model' - cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) - print(cut_cond.shape) - - adapted_cond = torch.stack( - [ - torch.cat([cut_cond, p], dim=1) - for p in patch_limits_tknzd - ] - ) - adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') - print(adapted_cond.shape) - adapted_cond = self.get_learned_conditioning(adapted_cond) - print(adapted_cond.shape) - adapted_cond = rearrange( - adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1] - ) - print(adapted_cond.shape) - - cond_list = [{'c_crossattn': [e]} for e in adapted_cond] - - else: - cond_list = [ - cond for i in range(z.shape[-1]) - ] # Todo make this more efficient - - # apply model by loop over crops - output_list = [ - self.model(z_list[i], t, **cond_list[i]) - for i in range(z.shape[-1]) - ] - assert not isinstance( - output_list[0], tuple - ) # todo cant deal with multiple model outputs check this never happens - - o = torch.stack(output_list, axis=-1) - o = o * weighting - # Reverse reshape to img shape - o = o.view( - (o.shape[0], -1, o.shape[-1]) - ) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - x_recon = fold(o) / normalization - - else: - x_recon = self.model(x_noisy, t, **cond) - - if isinstance(x_recon, tuple) and not return_ids: - return x_recon[0] - else: - return x_recon - - def _predict_eps_from_xstart(self, x_t, t, pred_xstart): - return ( - extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) - * x_t - - pred_xstart - ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) - - def _prior_bpd(self, x_start): - """ - Get the prior KL term for the variational lower-bound, measured in - bits-per-dim. - This term can't be optimized, as it only depends on the encoder. - :param x_start: the [N x C x ...] tensor of inputs. - :return: a batch of [N] KL values (in bits), one per batch element. - """ - batch_size = x_start.shape[0] - t = torch.tensor( - [self.num_timesteps - 1] * batch_size, device=x_start.device - ) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) - kl_prior = normal_kl( - mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 - ) - return mean_flat(kl_prior) / np.log(2.0) - - def p_losses(self, x_start, cond, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - model_output = self.apply_model(x_noisy, t, cond) - - loss_dict = {} - prefix = 'train' if self.training else 'val' - - if self.parameterization == 'x0': - target = x_start - elif self.parameterization == 'eps': - target = noise - else: - raise NotImplementedError() - - loss_simple = self.get_loss(model_output, target, mean=False).mean( - [1, 2, 3] - ) - loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) - - logvar_t = self.logvar[t.item()].to(self.device) - loss = loss_simple / torch.exp(logvar_t) + logvar_t - # loss = loss_simple / torch.exp(self.logvar) + self.logvar - if self.learn_logvar: - loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) - loss_dict.update({'logvar': self.logvar.data.mean()}) - - loss = self.l_simple_weight * loss.mean() - - loss_vlb = self.get_loss(model_output, target, mean=False).mean( - dim=(1, 2, 3) - ) - loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() - loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) - loss += self.original_elbo_weight * loss_vlb - loss_dict.update({f'{prefix}/loss': loss}) - - if self.embedding_reg_weight > 0: - loss_embedding_reg = ( - self.embedding_manager.embedding_to_coarse_loss().mean() - ) - - loss_dict.update({f'{prefix}/loss_emb_reg': loss_embedding_reg}) - - loss += self.embedding_reg_weight * loss_embedding_reg - loss_dict.update({f'{prefix}/loss': loss}) - - return loss, loss_dict - - def p_mean_variance( - self, - x, - c, - t, - clip_denoised: bool, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, - score_corrector=None, - corrector_kwargs=None, - ): - t_in = t - model_out = self.apply_model( - x, t_in, c, return_ids=return_codebook_ids - ) - - if score_corrector is not None: - assert self.parameterization == 'eps' - model_out = score_corrector.modify_score( - self, model_out, x, t, c, **corrector_kwargs - ) - - if return_codebook_ids: - model_out, logits = model_out - - if self.parameterization == 'eps': - x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == 'x0': - x_recon = model_out - else: - raise NotImplementedError() - - if clip_denoised: - x_recon.clamp_(-1.0, 1.0) - if quantize_denoised: - x_recon, _, [_, _, indices] = self.first_stage_model.quantize( - x_recon - ) - ( - model_mean, - posterior_variance, - posterior_log_variance, - ) = self.q_posterior(x_start=x_recon, x_t=x, t=t) - if return_codebook_ids: - return ( - model_mean, - posterior_variance, - posterior_log_variance, - logits, - ) - elif return_x0: - return ( - model_mean, - posterior_variance, - posterior_log_variance, - x_recon, - ) - else: - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample( - self, - x, - c, - t, - clip_denoised=False, - repeat_noise=False, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - ): - b, *_, device = *x.shape, x.device - outputs = self.p_mean_variance( - x=x, - c=c, - t=t, - clip_denoised=clip_denoised, - return_codebook_ids=return_codebook_ids, - quantize_denoised=quantize_denoised, - return_x0=return_x0, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - ) - if return_codebook_ids: - raise DeprecationWarning('Support dropped.') - model_mean, _, model_log_variance, logits = outputs - elif return_x0: - model_mean, _, model_log_variance, x0 = outputs - else: - model_mean, _, model_log_variance = outputs - - noise = noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.0: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape( - b, *((1,) * (len(x.shape) - 1)) - ) - - if return_codebook_ids: - return model_mean + nonzero_mask * ( - 0.5 * model_log_variance - ).exp() * noise, logits.argmax(dim=1) - if return_x0: - return ( - model_mean - + nonzero_mask * (0.5 * model_log_variance).exp() * noise, - x0, - ) - else: - return ( - model_mean - + nonzero_mask * (0.5 * model_log_variance).exp() * noise - ) - - @torch.no_grad() - def progressive_denoising( - self, - cond, - shape, - verbose=True, - callback=None, - quantize_denoised=False, - img_callback=None, - mask=None, - x0=None, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - batch_size=None, - x_T=None, - start_T=None, - log_every_t=None, - ): - if not log_every_t: - log_every_t = self.log_every_t - timesteps = self.num_timesteps - if batch_size is not None: - b = batch_size if batch_size is not None else shape[0] - shape = [batch_size] + list(shape) - else: - b = batch_size = shape[0] - if x_T is None: - img = torch.randn(shape, device=self.device) - else: - img = x_T - intermediates = [] - if cond is not None: - if isinstance(cond, dict): - cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) - for key in cond - } - else: - cond = ( - [c[:batch_size] for c in cond] - if isinstance(cond, list) - else cond[:batch_size] - ) - - if start_T is not None: - timesteps = min(timesteps, start_T) - iterator = ( - tqdm( - reversed(range(0, timesteps)), - desc='Progressive Generation', - total=timesteps, - ) - if verbose - else reversed(range(0, timesteps)) - ) - if type(temperature) == float: - temperature = [temperature] * timesteps - - for i in iterator: - ts = torch.full((b,), i, device=self.device, dtype=torch.long) - if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' - tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample( - x_start=cond, t=tc, noise=torch.randn_like(cond) - ) - - img, x0_partial = self.p_sample( - img, - cond, - ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, - return_x0=True, - temperature=temperature[i], - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - ) - if mask is not None: - assert x0 is not None - img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1.0 - mask) * img - - if i % log_every_t == 0 or i == timesteps - 1: - intermediates.append(x0_partial) - if callback: - callback(i) - if img_callback: - img_callback(img, i) - return img, intermediates - - @torch.no_grad() - def p_sample_loop( - self, - cond, - shape, - return_intermediates=False, - x_T=None, - verbose=True, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - start_T=None, - log_every_t=None, - ): - - if not log_every_t: - log_every_t = self.log_every_t - device = self.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - intermediates = [img] - if timesteps is None: - timesteps = self.num_timesteps - - if start_T is not None: - timesteps = min(timesteps, start_T) - iterator = ( - tqdm( - reversed(range(0, timesteps)), - desc='Sampling t', - total=timesteps, - ) - if verbose - else reversed(range(0, timesteps)) - ) - - if mask is not None: - assert x0 is not None - assert ( - x0.shape[2:3] == mask.shape[2:3] - ) # spatial size has to match - - for i in iterator: - ts = torch.full((b,), i, device=device, dtype=torch.long) - if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' - tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample( - x_start=cond, t=tc, noise=torch.randn_like(cond) - ) - - img = self.p_sample( - img, - cond, - ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, - ) - if mask is not None: - img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1.0 - mask) * img - - if i % log_every_t == 0 or i == timesteps - 1: - intermediates.append(img) - if callback: - callback(i) - if img_callback: - img_callback(img, i) - - if return_intermediates: - return img, intermediates - return img - - @torch.no_grad() - def sample( - self, - cond, - batch_size=16, - return_intermediates=False, - x_T=None, - verbose=True, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - shape=None, - **kwargs, - ): - if shape is None: - shape = ( - batch_size, - self.channels, - self.image_size, - self.image_size, - ) - if cond is not None: - if isinstance(cond, dict): - cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) - for key in cond - } - else: - cond = ( - [c[:batch_size] for c in cond] - if isinstance(cond, list) - else cond[:batch_size] - ) - return self.p_sample_loop( - cond, - shape, - return_intermediates=return_intermediates, - x_T=x_T, - verbose=verbose, - timesteps=timesteps, - quantize_denoised=quantize_denoised, - mask=mask, - x0=x0, - ) - - @torch.no_grad() - def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): - - if ddim: - ddim_sampler = DDIMSampler(self) - shape = (self.channels, self.image_size, self.image_size) - samples, intermediates = ddim_sampler.sample( - ddim_steps, batch_size, shape, cond, verbose=False, **kwargs - ) - - else: - samples, intermediates = self.sample( - cond=cond, - batch_size=batch_size, - return_intermediates=True, - **kwargs, - ) - - return samples, intermediates - - @torch.no_grad() - def get_unconditional_conditioning(self, batch_size, null_label=None): - if null_label is not None: - xc = null_label - if isinstance(xc, ListConfig): - xc = list(xc) - if isinstance(xc, dict) or isinstance(xc, list): - c = self.get_learned_conditioning(xc) - else: - if hasattr(xc, "to"): - xc = xc.to(self.device) - c = self.get_learned_conditioning(xc) - else: - # todo: get null label from cond_stage_model - raise NotImplementedError() - c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device) - return c - - @torch.no_grad() - def log_images( - self, - batch, - N=8, - n_row=4, - sample=True, - ddim_steps=50, - ddim_eta=1.0, - return_keys=None, - quantize_denoised=True, - inpaint=False, - plot_denoise_rows=False, - plot_progressive_rows=False, - plot_diffusion_rows=False, - **kwargs, - ): - - use_ddim = ddim_steps is not None - - log = dict() - z, c, x, xrec, xc = self.get_input( - batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=N, - ) - N = min(x.shape[0], N) - n_row = min(x.shape[0], n_row) - log['inputs'] = x - log['reconstruction'] = xrec - if self.model.conditioning_key is not None: - if hasattr(self.cond_stage_model, 'decode'): - xc = self.cond_stage_model.decode(c) - log['conditioning'] = xc - elif self.cond_stage_key in ['caption']: - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch['caption']) - log['conditioning'] = xc - elif self.cond_stage_key == 'class_label': - xc = log_txt_as_img( - (x.shape[2], x.shape[3]), batch['human_label'] - ) - log['conditioning'] = xc - elif isimage(xc): - log['conditioning'] = xc - if ismap(xc): - log['original_conditioning'] = self.to_rgb(xc) - - if plot_diffusion_rows: - # get diffusion row - diffusion_row = list() - z_start = z[:n_row] - for t in range(self.num_timesteps): - if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) - t = t.to(self.device).long() - noise = torch.randn_like(z_start) - z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) - diffusion_row.append(self.decode_first_stage(z_noisy)) - - diffusion_row = torch.stack( - diffusion_row - ) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange( - diffusion_grid, 'b n c h w -> (b n) c h w' - ) - diffusion_grid = make_grid( - diffusion_grid, nrow=diffusion_row.shape[0] - ) - log['diffusion_row'] = diffusion_grid - - if sample: - # get denoise row - with self.ema_scope('Plotting'): - samples, z_denoise_row = self.sample_log( - cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta, - ) - # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) - x_samples = self.decode_first_stage(samples) - log['samples'] = x_samples - if plot_denoise_rows: - denoise_grid = self._get_denoise_row_from_list(z_denoise_row) - log['denoise_row'] = denoise_grid - - uc = self.get_learned_conditioning(len(c) * ['']) - sample_scaled, _ = self.sample_log( - cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta, - unconditional_guidance_scale=5.0, - unconditional_conditioning=uc, - ) - log['samples_scaled'] = self.decode_first_stage(sample_scaled) - - if ( - quantize_denoised - and not isinstance(self.first_stage_model, AutoencoderKL) - and not isinstance(self.first_stage_model, IdentityFirstStage) - ): - # also display when quantizing x0 while sampling - with self.ema_scope('Plotting Quantized Denoised'): - samples, z_denoise_row = self.sample_log( - cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta, - quantize_denoised=True, - ) - # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, - # quantize_denoised=True) - x_samples = self.decode_first_stage(samples.to(self.device)) - log['samples_x0_quantized'] = x_samples - - if inpaint: - # make a simple center square - b, h, w = z.shape[0], z.shape[2], z.shape[3] - mask = torch.ones(N, h, w).to(self.device) - # zeros will be filled in - mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0 - mask = mask[:, None, ...] - with self.ema_scope('Plotting Inpaint'): - - samples, _ = self.sample_log( - cond=c, - batch_size=N, - ddim=use_ddim, - eta=ddim_eta, - ddim_steps=ddim_steps, - x0=z[:N], - mask=mask, - ) - x_samples = self.decode_first_stage(samples.to(self.device)) - log['samples_inpainting'] = x_samples - log['mask'] = mask - - # outpaint - with self.ema_scope('Plotting Outpaint'): - samples, _ = self.sample_log( - cond=c, - batch_size=N, - ddim=use_ddim, - eta=ddim_eta, - ddim_steps=ddim_steps, - x0=z[:N], - mask=mask, - ) - x_samples = self.decode_first_stage(samples.to(self.device)) - log['samples_outpainting'] = x_samples - - if plot_progressive_rows: - with self.ema_scope('Plotting Progressives'): - img, progressives = self.progressive_denoising( - c, - shape=(self.channels, self.image_size, self.image_size), - batch_size=N, - ) - prog_row = self._get_denoise_row_from_list( - progressives, desc='Progressive Generation' - ) - log['progressive_row'] = prog_row - - if return_keys: - if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: - return log - else: - return {key: log[key] for key in return_keys} - return log - - def configure_optimizers(self): - lr = self.learning_rate - - if self.embedding_manager is not None: - params = list(self.embedding_manager.embedding_parameters()) - # params = list(self.cond_stage_model.transformer.text_model.embeddings.embedding_manager.embedding_parameters()) - else: - params = list(self.model.parameters()) - if self.cond_stage_trainable: - print( - f'{self.__class__.__name__}: Also optimizing conditioner params!' - ) - params = params + list(self.cond_stage_model.parameters()) - if self.learn_logvar: - print('Diffusion model optimizing logvar') - params.append(self.logvar) - opt = torch.optim.AdamW(params, lr=lr) - if self.use_scheduler: - assert 'target' in self.scheduler_config - scheduler = instantiate_from_config(self.scheduler_config) - - print('Setting up LambdaLR scheduler...') - scheduler = [ - { - 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1, - } - ] - return [opt], scheduler - return opt - - @torch.no_grad() - def to_rgb(self, x): - x = x.float() - if not hasattr(self, 'colorize'): - self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) - x = nn.functional.conv2d(x, weight=self.colorize) - x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 - return x - - @rank_zero_only - def on_save_checkpoint(self, checkpoint): - checkpoint.clear() - - if os.path.isdir(self.trainer.checkpoint_callback.dirpath): - self.embedding_manager.save( - os.path.join( - self.trainer.checkpoint_callback.dirpath, 'embeddings.pt' - ) - ) - - if (self.global_step - self.emb_ckpt_counter) > 500: - self.embedding_manager.save( - os.path.join( - self.trainer.checkpoint_callback.dirpath, - f'embeddings_gs-{self.global_step}.pt', - ) - ) - - self.emb_ckpt_counter += 500 - - -class DiffusionWrapper(pl.LightningModule): - def __init__(self, diff_model_config, conditioning_key): - super().__init__() - self.diffusion_model = instantiate_from_config(diff_model_config) - self.conditioning_key = conditioning_key - assert self.conditioning_key in [ - None, - 'concat', - 'crossattn', - 'hybrid', - 'adm', - ] - - def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): - if self.conditioning_key is None: - out = self.diffusion_model(x, t) - elif self.conditioning_key == 'concat': - xc = torch.cat([x] + c_concat, dim=1) - out = self.diffusion_model(xc, t) - elif self.conditioning_key == 'crossattn': - cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(x, t, context=cc) - elif self.conditioning_key == 'hybrid': - cc = torch.cat(c_crossattn, 1) - xc = torch.cat([x] + c_concat, dim=1) - out = self.diffusion_model(xc, t, context=cc) - elif self.conditioning_key == 'adm': - cc = c_crossattn[0] - out = self.diffusion_model(x, t, y=cc) - else: - raise NotImplementedError() - - return out - - -class Layout2ImgDiffusion(LatentDiffusion): - # TODO: move all layout-specific hacks to this class - def __init__(self, cond_stage_key, *args, **kwargs): - assert ( - cond_stage_key == 'coordinates_bbox' - ), 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' - super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) - - def log_images(self, batch, N=8, *args, **kwargs): - logs = super().log_images(batch=batch, N=N, *args, **kwargs) - - key = 'train' if self.training else 'validation' - dset = self.trainer.datamodule.datasets[key] - mapper = dset.conditional_builders[self.cond_stage_key] - - bbox_imgs = [] - map_fn = lambda catno: dset.get_textual_label( - dset.get_category_id(catno) - ) - for tknzd_bbox in batch[self.cond_stage_key][:N]: - bboximg = mapper.plot( - tknzd_bbox.detach().cpu(), map_fn, (256, 256) - ) - bbox_imgs.append(bboximg) - - cond_img = torch.stack(bbox_imgs, dim=0) - logs['bbox_image'] = cond_img - return logs - -class LatentInpaintDiffusion(LatentDiffusion): - def __init__( - self, - concat_keys=("mask", "masked_image"), - masked_image_key="masked_image", - finetune_keys=None, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.masked_image_key = masked_image_key - assert self.masked_image_key in concat_keys - self.concat_keys = concat_keys - - - @torch.no_grad() - def get_input( - self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False - ): - # note: restricted to non-trainable encoders currently - assert ( - not self.cond_stage_trainable - ), "trainable cond stages not yet supported for inpainting" - z, c, x, xrec, xc = super().get_input( - batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=bs, - ) - - assert exists(self.concat_keys) - c_cat = list() - for ck in self.concat_keys: - cc = ( - rearrange(batch[ck], "b h w c -> b c h w") - .to(memory_format=torch.contiguous_format) - .float() - ) - if bs is not None: - cc = cc[:bs] - cc = cc.to(self.device) - bchw = z.shape - if ck != self.masked_image_key: - cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) - else: - cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) - c_cat.append(cc) - c_cat = torch.cat(c_cat, dim=1) - all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} - if return_first_stage_outputs: - return z, all_conds, x, xrec, xc - return z, all_conds diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py deleted file mode 100644 index f98ca8de21..0000000000 --- a/ldm/models/diffusion/ksampler.py +++ /dev/null @@ -1,312 +0,0 @@ -"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers""" - -import k_diffusion as K -import torch -from torch import nn - -from .cross_attention_map_saving import AttentionMapSaver -from .sampler import Sampler -from .shared_invokeai_diffusion import InvokeAIDiffuserComponent - - -# at this threshold, the scheduler will stop using the Karras -# noise schedule and start using the model's schedule -STEP_THRESHOLD = 30 - -def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): - if threshold <= 0.0: - return result - maxval = 0.0 + torch.max(result).cpu().numpy() - minval = 0.0 + torch.min(result).cpu().numpy() - if maxval < threshold and minval > -threshold: - return result - if maxval > threshold: - maxval = min(max(1, scale*maxval), threshold) - if minval < -threshold: - minval = max(min(-1, scale*minval), -threshold) - return torch.clamp(result, min=minval, max=maxval) - - -class CFGDenoiser(nn.Module): - def __init__(self, model, threshold = 0, warmup = 0): - super().__init__() - self.inner_model = model - self.threshold = threshold - self.warmup_max = warmup - self.warmup = max(warmup / 10, 1) - self.invokeai_diffuser = InvokeAIDiffuserComponent(model, - model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond)) - - - def prepare_to_sample(self, t_enc, **kwargs): - - extra_conditioning_info = kwargs.get('extra_conditioning_info', None) - - if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = t_enc) - else: - self.invokeai_diffuser.restore_default_cross_attention() - - - def forward(self, x, sigma, uncond, cond, cond_scale): - next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale) - if self.warmup < self.warmup_max: - thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) - self.warmup += 1 - else: - thresh = self.threshold - if thresh > self.threshold: - thresh = self.threshold - return cfg_apply_threshold(next_x, thresh) - -class KSampler(Sampler): - def __init__(self, model, schedule='lms', device=None, **kwargs): - denoiser = K.external.CompVisDenoiser(model) - super().__init__( - denoiser, - schedule, - steps=model.num_timesteps, - ) - self.sigmas = None - self.ds = None - self.s_in = None - self.karras_max = kwargs.get('karras_max',STEP_THRESHOLD) - if self.karras_max is None: - self.karras_max = STEP_THRESHOLD - - def make_schedule( - self, - ddim_num_steps, - ddim_discretize='uniform', - ddim_eta=0.0, - verbose=False, - ): - outer_model = self.model - self.model = outer_model.inner_model - super().make_schedule( - ddim_num_steps, - ddim_discretize='uniform', - ddim_eta=0.0, - verbose=False, - ) - self.model = outer_model - self.ddim_num_steps = ddim_num_steps - # we don't need both of these sigmas, but storing them here to make - # comparison easier later on - self.model_sigmas = self.model.get_sigmas(ddim_num_steps) - self.karras_sigmas = K.sampling.get_sigmas_karras( - n=ddim_num_steps, - sigma_min=self.model.sigmas[0].item(), - sigma_max=self.model.sigmas[-1].item(), - rho=7., - device=self.device, - ) - - if ddim_num_steps >= self.karras_max: - print(f'>> Ksampler using model noise schedule (steps >= {self.karras_max})') - self.sigmas = self.model_sigmas - else: - print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})') - self.sigmas = self.karras_sigmas - - # ALERT: We are completely overriding the sample() method in the base class, which - # means that inpainting will not work. To get this to work we need to be able to - # modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way - # in the lstein/k-diffusion branch. - - @torch.no_grad() - def decode( - self, - z_enc, - cond, - t_enc, - img_callback=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - use_original_steps=False, - init_latent = None, - mask = None, - **kwargs - ): - samples,_ = self.sample( - batch_size = 1, - S = t_enc, - x_T = z_enc, - shape = z_enc.shape[1:], - conditioning = cond, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning = unconditional_conditioning, - img_callback = img_callback, - x0 = init_latent, - mask = mask, - **kwargs - ) - return samples - - # this is a no-op, provided here for compatibility with ddim and plms samplers - @torch.no_grad() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): - return x0 - - # Most of these arguments are ignored and are only present for compatibility with - # other samples - @torch.no_grad() - def sample( - self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - attention_maps_callback=None, - quantize_x0=False, - eta=0.0, - mask=None, - x0=None, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None, - threshold = 0, - perlin = 0, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs, - ): - def route_callback(k_callback_values): - if img_callback is not None: - img_callback(k_callback_values['x'],k_callback_values['i']) - - # if make_schedule() hasn't been called, we do it now - if self.sigmas is None: - self.make_schedule( - ddim_num_steps=S, - ddim_eta = eta, - verbose = False, - ) - - # sigmas are set up in make_schedule - we take the last steps items - sigmas = self.sigmas[-S-1:] - - # x_T is variation noise. When an init image is provided (in x0) we need to add - # more randomness to the starting image. - if x_T is not None: - if x0 is not None: - x = x_T + torch.randn_like(x0, device=self.device) * sigmas[0] - else: - x = x_T * sigmas[0] - else: - x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] - - model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) - model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info) - - # setup attention maps saving. checks for None are because there are multiple code paths to get here. - attention_map_saver = None - if attention_maps_callback is not None and extra_conditioning_info is not None: - eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1 - attention_map_token_ids = range(1, eos_token_index) - attention_map_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:]) - model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_map_saver) - - extra_args = { - 'cond': conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': unconditional_guidance_scale, - } - print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)') - sampling_result = ( - K.sampling.__dict__[f'sample_{self.schedule}']( - model_wrap_cfg, x, sigmas, extra_args=extra_args, - callback=route_callback - ), - None, - ) - if attention_map_saver is not None: - attention_maps_callback(attention_map_saver) - return sampling_result - - # this code will support inpainting if and when ksampler API modified or - # a workaround is found. - @torch.no_grad() - def p_sample( - self, - img, - cond, - ts, - index, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - extra_conditioning_info=None, - **kwargs, - ): - if self.model_wrap is None: - self.model_wrap = CFGDenoiser(self.model) - extra_args = { - 'cond': cond, - 'uncond': unconditional_conditioning, - 'cond_scale': unconditional_guidance_scale, - } - if self.s_in is None: - self.s_in = img.new_ones([img.shape[0]]) - if self.ds is None: - self.ds = [] - - # terrible, confusing names here - steps = self.ddim_num_steps - t_enc = self.t_enc - - # sigmas is a full steps in length, but t_enc might - # be less. We start in the middle of the sigma array - # and work our way to the end after t_enc steps. - # index starts at t_enc and works its way to zero, - # so the actual formula for indexing into sigmas: - # sigma_index = (steps-index) - s_index = t_enc - index - 1 - self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info) - img = K.sampling.__dict__[f'_{self.schedule}']( - self.model_wrap, - img, - self.sigmas, - s_index, - s_in = self.s_in, - ds = self.ds, - extra_args=extra_args, - ) - - return img, None, None - - # REVIEW THIS METHOD: it has never been tested. In particular, - # we should not be multiplying by self.sigmas[0] if we - # are at an intermediate step in img2img. See similar in - # sample() which does work. - def get_initial_image(self,x_T,shape,steps): - print(f'WARNING: ksampler.get_initial_image(): get_initial_image needs testing') - x = (torch.randn(shape, device=self.device) * self.sigmas[0]) - if x_T is not None: - return x_T + x - else: - return x - - def prepare_to_sample(self,t_enc,**kwargs): - self.t_enc = t_enc - self.model_wrap = None - self.ds = None - self.s_in = None - - def q_sample(self,x0,ts): - ''' - Overrides parent method to return the q_sample of the inner model. - ''' - return self.model.inner_model.q_sample(x0,ts) - - def conditioning_key(self)->str: - return self.model.inner_model.model.conditioning_key - diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py deleted file mode 100644 index 9edd333780..0000000000 --- a/ldm/models/diffusion/plms.py +++ /dev/null @@ -1,146 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -import numpy as np -from tqdm import tqdm -from functools import partial -from ldm.invoke.devices import choose_torch_device -from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent -from ldm.models.diffusion.sampler import Sampler -from ldm.modules.diffusionmodules.util import noise_like - - -class PLMSSampler(Sampler): - def __init__(self, model, schedule='linear', device=None, **kwargs): - super().__init__(model,schedule,model.num_timesteps, device) - - def prepare_to_sample(self, t_enc, **kwargs): - super().prepare_to_sample(t_enc, **kwargs) - - extra_conditioning_info = kwargs.get('extra_conditioning_info', None) - all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) - - if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count) - else: - self.invokeai_diffuser.restore_default_cross_attention() - - - # this is the essential routine - @torch.no_grad() - def p_sample( - self, - x, # image, called 'img' elsewhere - c, # conditioning, called 'cond' elsewhere - t, # timesteps, called 'ts' elsewhere - index, - repeat_noise=False, - use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - old_eps=[], - t_next=None, - step_count:int=1000, # total number of steps - **kwargs, - ): - b, *_, device = *x.shape, x.device - - def get_model_output(x, t): - if ( - unconditional_conditioning is None - or unconditional_guidance_scale == 1.0 - ): - # damian0815 would like to know when/if this code path is used - e_t = self.model.apply_model(x, t, c) - else: - # step_index counts in the opposite direction to index - step_index = step_count-(index+1) - e_t = self.invokeai_diffuser.do_diffusion_step(x, t, - unconditional_conditioning, c, - unconditional_guidance_scale, - step_index=step_index) - if score_corrector is not None: - assert self.model.parameterization == 'eps' - e_t = score_corrector.modify_score( - self.model, e_t, x, t, c, **corrector_kwargs - ) - - return e_t - - alphas = ( - self.model.alphas_cumprod - if use_original_steps - else self.ddim_alphas - ) - alphas_prev = ( - self.model.alphas_cumprod_prev - if use_original_steps - else self.ddim_alphas_prev - ) - sqrt_one_minus_alphas = ( - self.model.sqrt_one_minus_alphas_cumprod - if use_original_steps - else self.ddim_sqrt_one_minus_alphas - ) - sigmas = ( - self.model.ddim_sigmas_for_original_num_steps - if use_original_steps - else self.ddim_sigmas - ) - - def get_x_prev_and_pred_x0(e_t, index): - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full( - (b, 1, 1, 1), alphas_prev[index], device=device - ) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full( - (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device - ) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - # direction pointing to x_t - dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t - noise = ( - sigma_t - * noise_like(x.shape, device, repeat_noise) - * temperature - ) - if noise_dropout > 0.0: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - e_t = get_model_output(x, t) - if len(old_eps) == 0: - # Pseudo Improved Euler (2nd order) - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) - e_t_next = get_model_output(x_prev, t_next) - e_t_prime = (e_t + e_t_next) / 2 - elif len(old_eps) == 1: - # 2nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (3 * e_t - old_eps[-1]) / 2 - elif len(old_eps) == 2: - # 3nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 - elif len(old_eps) >= 3: - # 4nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = ( - 55 * e_t - - 59 * old_eps[-1] - + 37 * old_eps[-2] - - 9 * old_eps[-3] - ) / 24 - - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) - - return x_prev, pred_x0, e_t diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py deleted file mode 100644 index d7ec5bf1f4..0000000000 --- a/ldm/models/diffusion/sampler.py +++ /dev/null @@ -1,450 +0,0 @@ -''' -ldm.models.diffusion.sampler - -Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc -''' -import torch -import numpy as np -from tqdm import tqdm -from functools import partial -from ldm.invoke.devices import choose_torch_device -from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent - -from ldm.modules.diffusionmodules.util import ( - make_ddim_sampling_parameters, - make_ddim_timesteps, - noise_like, - extract_into_tensor, -) - -class Sampler(object): - def __init__(self, model, schedule='linear', steps=None, device=None, **kwargs): - self.model = model - self.ddim_timesteps = None - self.ddpm_num_timesteps = steps - self.schedule = schedule - self.device = device or choose_torch_device() - self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model, - model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond)) - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device(self.device): - attr = attr.to(torch.float32).to(torch.device(self.device)) - setattr(self, name, attr) - - # This method was copied over from ddim.py and probably does stuff that is - # ddim-specific. Disentangle at some point. - def make_schedule( - self, - ddim_num_steps, - ddim_discretize='uniform', - ddim_eta=0.0, - verbose=False, - ): - self.total_steps = ddim_num_steps - self.ddim_timesteps = make_ddim_timesteps( - ddim_discr_method=ddim_discretize, - num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps, - verbose=verbose, - ) - alphas_cumprod = self.model.alphas_cumprod - assert ( - alphas_cumprod.shape[0] == self.ddpm_num_timesteps - ), 'alphas have to be defined for each timestep' - to_torch = ( - lambda x: x.clone() - .detach() - .to(torch.float32) - .to(self.model.device) - ) - - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer( - 'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev) - ) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer( - 'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())) - ) - self.register_buffer( - 'sqrt_one_minus_alphas_cumprod', - to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), - ) - self.register_buffer( - 'log_one_minus_alphas_cumprod', - to_torch(np.log(1.0 - alphas_cumprod.cpu())), - ) - self.register_buffer( - 'sqrt_recip_alphas_cumprod', - to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())), - ) - self.register_buffer( - 'sqrt_recipm1_alphas_cumprod', - to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), - ) - - # ddim sampling parameters - ( - ddim_sigmas, - ddim_alphas, - ddim_alphas_prev, - ) = make_ddim_sampling_parameters( - alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta, - verbose=verbose, - ) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer( - 'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas) - ) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) - / (1 - self.alphas_cumprod) - * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) - ) - self.register_buffer( - 'ddim_sigmas_for_original_num_steps', - sigmas_for_original_sampling_steps, - ) - - @torch.no_grad() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): - # fast, but does not allow for exact reconstruction - # t serves as an index to gather the correct alphas - if use_original_steps: - sqrt_alphas_cumprod = self.sqrt_alphas_cumprod - sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod - else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas - - if noise is None: - noise = torch.randn_like(x0) - return ( - extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 - + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) - * noise - ) - - @torch.no_grad() - def sample( - self, - S, # S is steps - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, # TODO: this is very confusing because it is called "step_callback" elsewhere. Change. - quantize_x0=False, - eta=0.0, - mask=None, - x0=None, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - verbose=False, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs, - ): - - if conditioning is not None: - if isinstance(conditioning, dict): - ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): - ctmp = ctmp[0] - cbs = ctmp.shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - # check to see if make_schedule() has run, and if not, run it - if self.ddim_timesteps is None: - self.make_schedule( - ddim_num_steps=S, - ddim_eta = eta, - verbose = False, - ) - - ts = self.get_timesteps(S) - - # sampling - C, H, W = shape - shape = (batch_size, C, H, W) - samples, intermediates = self.do_sampling( - conditioning, - shape, - timesteps=ts, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, - x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - steps=S, - **kwargs - ) - return samples, intermediates - - @torch.no_grad() - def do_sampling( - self, - cond, - shape, - timesteps=None, - x_T=None, - ddim_use_original_steps=False, - callback=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - log_every_t=100, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - steps=None, - **kwargs - ): - b = shape[0] - time_range = ( - list(reversed(range(0, timesteps))) - if ddim_use_original_steps - else np.flip(timesteps) - ) - - total_steps=steps - - iterator = tqdm( - time_range, - desc=f'{self.__class__.__name__}', - total=total_steps, - dynamic_ncols=True, - ) - old_eps = [] - self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs) - img = self.get_initial_image(x_T,shape,total_steps) - - # probably don't need this at all - intermediates = {'x_inter': [img], 'pred_x0': [img]} - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full( - (b,), - step, - device=self.device, - dtype=torch.long - ) - ts_next = torch.full( - (b,), - time_range[min(i + 1, len(time_range) - 1)], - device=self.device, - dtype=torch.long, - ) - - if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample( - x0, ts - ) # TODO: deterministic forward pass? - img = img_orig * mask + (1.0 - mask) * img - - outs = self.p_sample( - img, - cond, - ts, - index=index, - use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, - temperature=temperature, - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, - t_next=ts_next, - step_count=steps - ) - img, pred_x0, e_t = outs - - old_eps.append(e_t) - if len(old_eps) >= 4: - old_eps.pop(0) - if callback: - callback(i) - if img_callback: - img_callback(img,i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - - return img, intermediates - - # NOTE that decode() and sample() are almost the same code, and do the same thing. - # The variable names are changed in order to be confusing. - @torch.no_grad() - def decode( - self, - x_latent, - cond, - t_start, - img_callback=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - use_original_steps=False, - init_latent = None, - mask = None, - all_timesteps_count = None, - **kwargs - ): - timesteps = ( - np.arange(self.ddpm_num_timesteps) - if use_original_steps - else self.ddim_timesteps - ) - timesteps = timesteps[:t_start] - - time_range = np.flip(timesteps) - total_steps = timesteps.shape[0] - print(f'>> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)') - - iterator = tqdm(time_range, desc='Decoding image', total=total_steps) - x_dec = x_latent - x0 = init_latent - self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs) - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full( - (x_latent.shape[0],), - step, - device=x_latent.device, - dtype=torch.long, - ) - - ts_next = torch.full( - (x_latent.shape[0],), - time_range[min(i + 1, len(time_range) - 1)], - device=self.device, - dtype=torch.long, - ) - - if mask is not None: - assert x0 is not None - xdec_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass? - x_dec = xdec_orig * mask + (1.0 - mask) * x_dec - - outs = self.p_sample( - x_dec, - cond, - ts, - index=index, - use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - t_next = ts_next, - step_count=len(self.ddim_timesteps) - ) - - x_dec, pred_x0, e_t = outs - if img_callback: - img_callback(x_dec,i) - - return x_dec - - def get_initial_image(self,x_T,shape,timesteps=None): - if x_T is None: - return torch.randn(shape, device=self.device) - else: - return x_T - - def p_sample( - self, - img, - cond, - ts, - index, - repeat_noise=False, - use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - old_eps=None, - t_next=None, - steps=None, - ): - raise NotImplementedError("p_sample() must be implemented in a descendent class") - - def prepare_to_sample(self,t_enc,**kwargs): - ''' - Hook that will be called right before the very first invocation of p_sample() - to allow subclass to do additional initialization. t_enc corresponds to the actual - number of steps that will be run, and may be less than total steps if img2img is - active. - ''' - pass - - def get_timesteps(self,ddim_steps): - ''' - The ddim and plms samplers work on timesteps. This method is called after - ddim_timesteps are created in make_schedule(), and selects the portion of - timesteps that will be used for sampling, depending on the t_enc in img2img. - ''' - return self.ddim_timesteps[:ddim_steps] - - def q_sample(self,x0,ts): - ''' - Returns self.model.q_sample(x0,ts). Is overridden in the k* samplers to - return self.model.inner_model.q_sample(x0,ts) - ''' - return self.model.q_sample(x0,ts) - - def conditioning_key(self)->str: - return self.model.model.conditioning_key - - def uses_inpainting_model(self)->bool: - return self.conditioning_key() in ('hybrid','concat') - - def adjust_settings(self,**kwargs): - ''' - This is a catch-all method for adjusting any instance variables - after the sampler is instantiated. No type-checking performed - here, so use with care! - ''' - for k in kwargs.keys(): - try: - setattr(self,k,kwargs[k]) - except AttributeError: - print(f'** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}') diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py deleted file mode 100644 index cddddd3e86..0000000000 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ /dev/null @@ -1,491 +0,0 @@ -from contextlib import contextmanager -from dataclasses import dataclass -from math import ceil -from typing import Callable, Optional, Union, Any, Dict - -import numpy as np -import torch -from diffusers.models.cross_attention import AttnProcessor -from typing_extensions import TypeAlias - -from ldm.invoke.globals import Globals -from ldm.models.diffusion.cross_attention_control import Arguments, \ - restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \ - CrossAttentionType, SwapCrossAttnContext -from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver - -ModelForwardCallback: TypeAlias = Union[ - # x, t, conditioning, Optional[cross-attention kwargs] - Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], torch.Tensor], - Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] -] - -@dataclass(frozen=True) -class PostprocessingSettings: - threshold: float - warmup: float - h_symmetry_time_pct: Optional[float] - v_symmetry_time_pct: Optional[float] - - -class InvokeAIDiffuserComponent: - ''' - The aim of this component is to provide a single place for code that can be applied identically to - all InvokeAI diffusion procedures. - - At the moment it includes the following features: - * Cross attention control ("prompt2prompt") - * Hybrid conditioning (used for inpainting) - ''' - debug_thresholding = False - sequential_guidance = False - - @dataclass - class ExtraConditioningInfo: - - tokens_count_including_eos_bos: int - cross_attention_control_args: Optional[Arguments] = None - - @property - def wants_cross_attention_control(self): - return self.cross_attention_control_args is not None - - - def __init__(self, model, model_forward_callback: ModelForwardCallback, - is_running_diffusers: bool=False, - ): - """ - :param model: the unet model to pass through to cross attention control - :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning) - """ - self.conditioning = None - self.model = model - self.is_running_diffusers = is_running_diffusers - self.model_forward_callback = model_forward_callback - self.cross_attention_control_context = None - self.sequential_guidance = Globals.sequential_guidance - - @contextmanager - def custom_attention_context(self, - extra_conditioning_info: Optional[ExtraConditioningInfo], - step_count: int): - do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control - old_attn_processor = None - if do_swap: - old_attn_processor = self.override_cross_attention(extra_conditioning_info, - step_count=step_count) - try: - yield None - finally: - if old_attn_processor is not None: - self.restore_default_cross_attention(old_attn_processor) - # TODO resuscitate attention map saving - #self.remove_attention_map_saving() - - def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]: - """ - setup cross attention .swap control. for diffusers this replaces the attention processor, so - the previous attention processor is returned so that the caller can restore it later. - """ - self.conditioning = conditioning - self.cross_attention_control_context = Context( - arguments=self.conditioning.cross_attention_control_args, - step_count=step_count - ) - return override_cross_attention(self.model, - self.cross_attention_control_context, - is_running_diffusers=self.is_running_diffusers) - - def restore_default_cross_attention(self, restore_attention_processor: Optional['AttnProcessor']=None): - self.conditioning = None - self.cross_attention_control_context = None - restore_default_cross_attention(self.model, - is_running_diffusers=self.is_running_diffusers, - restore_attention_processor=restore_attention_processor) - - def setup_attention_map_saving(self, saver: AttentionMapSaver): - def callback(slice, dim, offset, slice_size, key): - if dim is not None: - # sliced tokens attention map saving is not implemented - return - saver.add_attention_maps(slice, key) - - tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS) - for identifier, module in tokens_cross_attention_modules: - key = ('down' if identifier.startswith('down') else - 'up' if identifier.startswith('up') else - 'mid') - module.set_attention_slice_calculated_callback( - lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)) - - def remove_attention_map_saving(self): - tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS) - for _, module in tokens_cross_attention_modules: - module.set_attention_slice_calculated_callback(None) - - def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor, - unconditioning: Union[torch.Tensor,dict], - conditioning: Union[torch.Tensor,dict], - unconditional_guidance_scale: float, - step_index: Optional[int]=None, - total_step_count: Optional[int]=None, - ): - """ - :param x: current latents - :param sigma: aka t, passed to the internal model to control how much denoising will occur - :param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768] - :param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768] - :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has - :param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas . - :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning. - """ - - - cross_attention_control_types_to_do = [] - context: Context = self.cross_attention_control_context - if self.cross_attention_control_context is not None: - percent_through = self.calculate_percent_through(sigma, step_index, total_step_count) - cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through) - - wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0) - wants_hybrid_conditioning = isinstance(conditioning, dict) - - if wants_hybrid_conditioning: - unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(x, sigma, unconditioning, - conditioning) - elif wants_cross_attention_control: - unconditioned_next_x, conditioned_next_x = self._apply_cross_attention_controlled_conditioning(x, sigma, - unconditioning, - conditioning, - cross_attention_control_types_to_do) - elif self.sequential_guidance: - unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning_sequentially( - x, sigma, unconditioning, conditioning) - - else: - unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning( - x, sigma, unconditioning, conditioning) - - combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale) - - return combined_next_x - - def do_latent_postprocessing( - self, - postprocessing_settings: PostprocessingSettings, - latents: torch.Tensor, - sigma, - step_index, - total_step_count - ) -> torch.Tensor: - if postprocessing_settings is not None: - percent_through = self.calculate_percent_through(sigma, step_index, total_step_count) - latents = self.apply_threshold(postprocessing_settings, latents, percent_through) - latents = self.apply_symmetry(postprocessing_settings, latents, percent_through) - return latents - - def calculate_percent_through(self, sigma, step_index, total_step_count): - if step_index is not None and total_step_count is not None: - # 🧨diffusers codepath - percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate - else: - # legacy compvis codepath - # TODO remove when compvis codepath support is dropped - if step_index is None and sigma is None: - raise ValueError( - f"Either step_index or sigma is required when doing cross attention control, but both are None.") - percent_through = self.estimate_percent_through(step_index, sigma) - return percent_through - - # methods below are called from do_diffusion_step and should be considered private to this class. - - def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning): - # fast batched path - x_twice = torch.cat([x] * 2) - sigma_twice = torch.cat([sigma] * 2) - both_conditionings = torch.cat([unconditioning, conditioning]) - both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings) - unconditioned_next_x, conditioned_next_x = both_results.chunk(2) - if conditioned_next_x.device.type == 'mps': - # prevent a result filled with zeros. seems to be a torch bug. - conditioned_next_x = conditioned_next_x.clone() - return unconditioned_next_x, conditioned_next_x - - - def _apply_standard_conditioning_sequentially(self, x: torch.Tensor, sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor): - # low-memory sequential path - unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) - conditioned_next_x = self.model_forward_callback(x, sigma, conditioning) - if conditioned_next_x.device.type == 'mps': - # prevent a result filled with zeros. seems to be a torch bug. - conditioned_next_x = conditioned_next_x.clone() - return unconditioned_next_x, conditioned_next_x - - - def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning): - assert isinstance(conditioning, dict) - assert isinstance(unconditioning, dict) - x_twice = torch.cat([x] * 2) - sigma_twice = torch.cat([sigma] * 2) - both_conditionings = dict() - for k in conditioning: - if isinstance(conditioning[k], list): - both_conditionings[k] = [ - torch.cat([unconditioning[k][i], conditioning[k][i]]) - for i in range(len(conditioning[k])) - ] - else: - both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]]) - unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2) - return unconditioned_next_x, conditioned_next_x - - - def _apply_cross_attention_controlled_conditioning(self, - x: torch.Tensor, - sigma, - unconditioning, - conditioning, - cross_attention_control_types_to_do): - if self.is_running_diffusers: - return self._apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning, - conditioning, - cross_attention_control_types_to_do) - else: - return self._apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning, - cross_attention_control_types_to_do) - - def _apply_cross_attention_controlled_conditioning__diffusers(self, - x: torch.Tensor, - sigma, - unconditioning, - conditioning, - cross_attention_control_types_to_do): - context: Context = self.cross_attention_control_context - - cross_attn_processor_context = SwapCrossAttnContext(modified_text_embeddings=context.arguments.edited_conditioning, - index_map=context.cross_attention_index_map, - mask=context.cross_attention_mask, - cross_attention_types_to_do=[]) - # no cross attention for unconditioning (negative prompt) - unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, - {"swap_cross_attn_context": cross_attn_processor_context}) - - # do requested cross attention types for conditioning (positive prompt) - cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do - conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, - {"swap_cross_attn_context": cross_attn_processor_context}) - return unconditioned_next_x, conditioned_next_x - - - def _apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do): - # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) - # slower non-batched path (20% slower on mac MPS) - # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of - # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x. - # This messes app their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8) - # (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16, - # representing batched uncond + cond, but then when it comes to applying the saved attention, the - # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.) - # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well. - context:Context = self.cross_attention_control_context - - try: - unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) - - # process x using the original prompt, saving the attention maps - #print("saving attention maps for", cross_attention_control_types_to_do) - for ca_type in cross_attention_control_types_to_do: - context.request_save_attention_maps(ca_type) - _ = self.model_forward_callback(x, sigma, conditioning) - context.clear_requests(cleanup=False) - - # process x again, using the saved attention maps to control where self.edited_conditioning will be applied - #print("applying saved attention maps for", cross_attention_control_types_to_do) - for ca_type in cross_attention_control_types_to_do: - context.request_apply_saved_attention_maps(ca_type) - edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning - conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning) - context.clear_requests(cleanup=True) - - except: - context.clear_requests(cleanup=True) - raise - - return unconditioned_next_x, conditioned_next_x - - def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale): - # to scale how much effect conditioning has, calculate the changes it does and then scale that - scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale - combined_next_x = unconditioned_next_x + scaled_delta - return combined_next_x - - def apply_threshold( - self, - postprocessing_settings: PostprocessingSettings, - latents: torch.Tensor, - percent_through: float - ) -> torch.Tensor: - - if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0: - return latents - - threshold = postprocessing_settings.threshold - warmup = postprocessing_settings.warmup - - if percent_through < warmup: - current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup)) - else: - current_threshold = threshold - - if current_threshold <= 0: - return latents - - maxval = latents.max().item() - minval = latents.min().item() - - scale = 0.7 # default value from #395 - - if self.debug_thresholding: - std, mean = [i.item() for i in torch.std_mean(latents)] - outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold)) - print(f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n" - f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n" - f" | {outside / latents.numel() * 100:.2f}% values outside threshold") - - if maxval < current_threshold and minval > -current_threshold: - return latents - - num_altered = 0 - - # MPS torch.rand_like is fine because torch.rand_like is wrapped in generate.py! - - if maxval > current_threshold: - latents = torch.clone(latents) - maxval = np.clip(maxval * scale, 1, current_threshold) - num_altered += torch.count_nonzero(latents > maxval) - latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval - - if minval < -current_threshold: - latents = torch.clone(latents) - minval = np.clip(minval * scale, -current_threshold, -1) - num_altered += torch.count_nonzero(latents < minval) - latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval - - if self.debug_thresholding: - print(f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n" - f" | {num_altered / latents.numel() * 100:.2f}% values altered") - - return latents - - def apply_symmetry( - self, - postprocessing_settings: PostprocessingSettings, - latents: torch.Tensor, - percent_through: float - ) -> torch.Tensor: - - # Reset our last percent through if this is our first step. - if percent_through == 0.0: - self.last_percent_through = 0.0 - - if postprocessing_settings is None: - return latents - - # Check for out of bounds - h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct - if (h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0)): - h_symmetry_time_pct = None - - v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct - if (v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0)): - v_symmetry_time_pct = None - - dev = latents.device.type - - latents.to(device='cpu') - - if ( - h_symmetry_time_pct != None and - self.last_percent_through < h_symmetry_time_pct and - percent_through >= h_symmetry_time_pct - ): - # Horizontal symmetry occurs on the 3rd dimension of the latent - width = latents.shape[3] - x_flipped = torch.flip(latents, dims=[3]) - latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3) - - if ( - v_symmetry_time_pct != None and - self.last_percent_through < v_symmetry_time_pct and - percent_through >= v_symmetry_time_pct - ): - # Vertical symmetry occurs on the 2nd dimension of the latent - height = latents.shape[2] - y_flipped = torch.flip(latents, dims=[2]) - latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2) - - self.last_percent_through = percent_through - return latents.to(device=dev) - - def estimate_percent_through(self, step_index, sigma): - if step_index is not None and self.cross_attention_control_context is not None: - # percent_through will never reach 1.0 (but this is intended) - return float(step_index) / float(self.cross_attention_control_context.step_count) - # find the best possible index of the current sigma in the sigma sequence - smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma) - sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0 - # flip because sigmas[0] is for the fully denoised image - # percent_through must be <1 - return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0]) - # print('estimated percent_through', percent_through, 'from sigma', sigma.item()) - - - # todo: make this work - @classmethod - def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale): - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) # aka sigmas - - deltas = None - uncond_latents = None - weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)] - - # below is fugly omg - num_actual_conditionings = len(c_or_weighted_c_list) - conditionings = [uc] + [c for c,weight in weighted_cond_list] - weights = [1] + [weight for c,weight in weighted_cond_list] - chunk_count = ceil(len(conditionings)/2) - deltas = None - for chunk_index in range(chunk_count): - offset = chunk_index*2 - chunk_size = min(2, len(conditionings)-offset) - - if chunk_size == 1: - c_in = conditionings[offset] - latents_a = forward_func(x_in[:-1], t_in[:-1], c_in) - latents_b = None - else: - c_in = torch.cat(conditionings[offset:offset+2]) - latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2) - - # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining - if chunk_index == 0: - uncond_latents = latents_a - deltas = latents_b - uncond_latents - else: - deltas = torch.cat((deltas, latents_a - uncond_latents)) - if latents_b is not None: - deltas = torch.cat((deltas, latents_b - uncond_latents)) - - # merge the weighted deltas together into a single merged delta - per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device) - normalize = False - if normalize: - per_delta_weights /= torch.sum(per_delta_weights) - reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1)) - deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True) - - # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale) - # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale)))) - - return uncond_latents + deltas_merged * global_guidance_scale diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 6737ed2060..0cd69366ce 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat -from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttentionMixin +from invokeai.models.diffusion.cross_attention_control import InvokeAICrossAttentionMixin from ldm.modules.diffusionmodules.util import checkpoint def exists(val): diff --git a/pyproject.toml b/pyproject.toml index 3d50bd124d..4b5a5d5fda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,7 +129,12 @@ version = { attr = "ldm.invoke.__version__" } [tool.setuptools.packages.find] "where" = ["."] -"include" = ["invokeai.assets.web*", "invokeai.backend*", "invokeai.frontend.dist*", "invokeai.configs*", "ldm*"] +"include" = [ + "invokeai.assets.web*", "invokeai.models*", + "invokeai.generator*","invokeai.backend*", + "invokeai.frontend.dist*", "invokeai.configs*", + "ldm*" +] [tool.setuptools.package-data] "invokeai.assets.web" = ["**.png"]