diff --git a/configs/models.yaml b/configs/models.yaml
index 332ee26409..162da38da2 100644
--- a/configs/models.yaml
+++ b/configs/models.yaml
@@ -13,6 +13,13 @@ stable-diffusion-1.4:
   width: 512
   height: 512
   default: true
+inpainting-1.5:
+  description: runwayML tuned inpainting model v1.5
+  weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
+  config: configs/stable-diffusion/v1-inpainting-inference.yaml
+# vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
+  width: 512
+  height: 512
 stable-diffusion-1.5:
   config: configs/stable-diffusion/v1-inference.yaml
   weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
diff --git a/configs/stable-diffusion/v1-inpainting-inference.yaml b/configs/stable-diffusion/v1-inpainting-inference.yaml
new file mode 100644
index 0000000000..5652e04374
--- /dev/null
+++ b/configs/stable-diffusion/v1-inpainting-inference.yaml
@@ -0,0 +1,79 @@
+model:
+  base_learning_rate: 7.5e-05
+  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: hybrid   # important
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    finetune_keys: null
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    personalization_config:
+      target: ldm.modules.embedding_manager.EmbeddingManager
+      params:
+        placeholder_strings: ["*"]
+        initializer_words: ['face', 'man', 'photo', 'africanmale']
+        per_image_tokens: false
+        num_vectors_per_token: 1
+        progressive_words: False
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 9  # 4 data + 4 downscaled image + 1 mask
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
diff --git a/ldm/generate.py b/ldm/generate.py
index 43ed28eecd..3785be56bb 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -421,7 +421,10 @@ class Generate:
         )

         # TODO: Hacky selection of operation to perform. Needs to be refactored.
-        if (init_image is not None) and (mask_image is not None):
+        if self.sampler.conditioning_key() in ('hybrid','concat'):
+            print(f'** Inpainting model detected. Will try it! **')
+            generator = self._make_omnibus()
+        elif (init_image is not None) and (mask_image is not None):
             generator = self._make_inpaint()
         elif (embiggen != None or embiggen_tiles != None):
             generator = self._make_embiggen()
@@ -677,6 +680,7 @@ class Generate:

         return init_image,init_mask

+    # lots o' repeated code here! Turn into a make_func()
     def _make_base(self):
         if not self.generators.get('base'):
             from ldm.invoke.generator import Generator
@@ -687,6 +691,7 @@ class Generate:
         if not self.generators.get('img2img'):
             from ldm.invoke.generator.img2img import Img2Img
             self.generators['img2img'] = Img2Img(self.model, self.precision)
+            self.generators['img2img'].free_gpu_mem = self.free_gpu_mem
         return self.generators['img2img']

     def _make_embiggen(self):
@@ -715,6 +720,15 @@ class Generate:
             self.generators['inpaint'] = Inpaint(self.model, self.precision)
         return self.generators['inpaint']

+    # "omnibus" supports the runwayML custom inpainting model, which does
+    # txt2img, img2img and inpainting using slight variations on the same code
+    def _make_omnibus(self):
+        if not self.generators.get('omnibus'):
+            from ldm.invoke.generator.omnibus import Omnibus
+            self.generators['omnibus'] = Omnibus(self.model, self.precision)
+            self.generators['omnibus'].free_gpu_mem = self.free_gpu_mem
+        return self.generators['omnibus']
+
     def load_model(self):
         '''
         preload model identified in self.model_name
diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py
index abde269acf..143e49150e 100644
--- a/ldm/invoke/args.py
+++ b/ldm/invoke/args.py
@@ -181,7 +181,9 @@ class Args(object):
         switches_started = False

         for element in elements:
-            if element[0] == '-' and not switches_started:
+            if len(element) == 0:  # empty prompt
+                pass
+            elif element[0] == '-' and not switches_started:
                 switches_started = True
             if switches_started:
                 switches.append(element)
diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py
index 7c095de7b7..7365bc9a87 100644
--- a/ldm/invoke/conditioning.py
+++ b/ldm/invoke/conditioning.py
@@ -123,8 +123,8 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
     else:
         conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
-
     unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
+    conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
     return (
         unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
             cross_attention_control_args=cac_args
@@ -166,4 +166,25 @@ def get_tokens_length(model, fragments: list[Fragment]):
     tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
     return sum([len(x) for x in tokens])

+def flatten_hybrid_conditioning(uncond, cond):
+    '''
+    This handles the choice between a conditional conditioning
+    that is a tensor (used by cross attention) vs one that has additional
+    dimensions as well, as used by 'hybrid'
+    '''
+    if isinstance(cond, dict):
+        assert isinstance(uncond, dict)
+        cond_in = dict()
+        for k in cond:
+            if isinstance(cond[k], list):
+                cond_in[k] = [
+                    torch.cat([uncond[k][i], cond[k][i]])
+                    for i in range(len(cond[k]))
+                ]
+            else:
+                cond_in[k] = torch.cat([uncond[k], cond[k]])
+        return cond_in
+    else:
+        return cond
+
diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py
index d42013ea73..2e96c93cbb 100644
--- a/ldm/invoke/generator/base.py
+++ b/ldm/invoke/generator/base.py
@@ -6,6 +6,7 @@ import torch
 import numpy as np
 import random
 import os
+import traceback
 from tqdm import tqdm, trange
 from PIL import Image, ImageFilter
 from einops import rearrange, repeat
@@ -43,7 +44,7 @@ class Generator():
         self.variation_amount = variation_amount
         self.with_variations = with_variations

-    def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
+    def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
                  **kwargs):
@@ -51,6 +52,7 @@ class Generator():
         self.safety_checker = safety_checker
         make_image = self.get_make_image(
             prompt,
+            sampler = sampler,
             init_image = init_image,
             width = width,
             height = height,
@@ -59,12 +61,14 @@ class Generator():
             perlin = perlin,
             **kwargs
         )
-
         results = []
         seed = seed if seed is not None else self.new_seed()
         first_seed = seed
         seed, initial_noise = self.generate_initial_noise(seed, width, height)
-        with scope(self.model.device.type), self.model.ema_scope():
+
+        # There used to be an additional self.model.ema_scope() here, but it breaks
+        # the inpaint-1.5 model. Not sure what it did.... ?
+        with scope(self.model.device.type):
             for n in trange(iterations, desc='Generating'):
                 x_T = None
                 if self.variation_amount > 0:
@@ -79,7 +83,8 @@ class Generator():
                     try:
                         x_T = self.get_noise(width,height)
                     except:
-                        pass
+                        print('** An error occurred while getting initial noise **')
+                        print(traceback.format_exc())

                 image = make_image(x_T)

@@ -95,10 +100,10 @@ class Generator():

         return results

-    def sample_to_image(self,samples):
+    def sample_to_image(self,samples)->Image.Image:
         """
-        Returns a function returning an image derived from the prompt and the initial image
-        Return value depends on the seed at the time you call it
+        Given samples returned from a sampler, converts
+        it into a PIL Image
         """
         x_samples = self.model.decode_first_stage(samples)
         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py
index 79b943024c..c4810de385 100644
--- a/ldm/invoke/generator/img2img.py
+++ b/ldm/invoke/generator/img2img.py
@@ -15,7 +15,7 @@ from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserCompo
 class Img2Img(Generator):
     def __init__(self, model, precision):
         super().__init__(model, precision)
-        self.init_latent = None # by get_noise()
+        self.init_latent = None    # by get_noise()

     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                        conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
@@ -80,7 +80,10 @@ class Img2Img(Generator):

     def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
         image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
+        if len(image.shape) == 2:  # 'L' image, as in a mask
+            image = image[None,None]
+        else:                      # 'RGB' image
+            image = image[None].transpose(0, 3, 1, 2)
         image = torch.from_numpy(image)
         if normalize:
             image = 2.0 * image - 1.0
diff --git a/ldm/invoke/generator/omnibus.py b/ldm/invoke/generator/omnibus.py
new file mode 100644
index 0000000000..e0705ec397
--- /dev/null
+++ b/ldm/invoke/generator/omnibus.py
@@ -0,0 +1,151 @@
+"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
+
+import torch
+import numpy as np
+from einops import repeat
+from PIL import Image, ImageOps
+from ldm.invoke.devices import choose_autocast
+from ldm.invoke.generator.base import downsampling
+from ldm.invoke.generator.img2img import Img2Img
+from ldm.invoke.generator.txt2img import Txt2Img
+
+class Omnibus(Img2Img,Txt2Img):
+    def __init__(self, model, precision):
+        super().__init__(model, precision)
+
+    def get_make_image(
+            self,
+            prompt,
+            sampler,
+            steps,
+            cfg_scale,
+            ddim_eta,
+            conditioning,
+            width,
+            height,
+            init_image = None,
+            mask_image = None,
+            strength = None,
+            step_callback=None,
+            threshold=0.0,
+            perlin=0.0,
+            **kwargs):
+        """
+        Returns a function returning an image derived from the prompt and the initial image
+        Return value depends on the seed at the time you call it.
+        """
+        self.perlin = perlin
+        num_samples = 1
+
+        sampler.make_schedule(
+            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
+        )
+
+        if isinstance(init_image, Image.Image):
+            init_image = self._image_to_tensor(init_image)
+
+        if isinstance(mask_image, Image.Image):
+            mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False)
+
+        t_enc = steps
+
+        if init_image is not None and mask_image is not None: # inpainting
+            masked_image = init_image * (1 - mask_image)  # masked image is the image masked by mask - masked regions zero
+
+        elif init_image is not None: # img2img
+            scope = choose_autocast(self.precision)
+
+            with scope(self.model.device.type):
+                self.init_latent = self.model.get_first_stage_encoding(
+                    self.model.encode_first_stage(init_image)
+                ) # move to latent space
+
+            # create a completely black mask (1s)
+            mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
+            # and the masked image is just a copy of the original
+            masked_image = init_image
+
+        else: # txt2img
+            init_image = torch.zeros(1, 3, height, width, device=self.model.device)
+            mask_image = torch.ones(1, 1, height, width, device=self.model.device)
+            masked_image = init_image
+
+        self.init_latent = init_image
+        height = init_image.shape[2]
+        width = init_image.shape[3]
+        model = self.model
+
+        def make_image(x_T):
+            with torch.no_grad():
+                scope = choose_autocast(self.precision)
+                with scope(self.model.device.type):
+
+                    batch = self.make_batch_sd(
+                        init_image,
+                        mask_image,
+                        masked_image,
+                        prompt=prompt,
+                        device=model.device,
+                        num_samples=num_samples,
+                    )
+
+                    c = model.cond_stage_model.encode(batch["txt"])
+                    c_cat = list()
+                    for ck in model.concat_keys:
+                        cc = batch[ck].float()
+                        if ck != model.masked_image_key:
+                            bchw = [num_samples, 4, height//8, width//8]
+                            cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+                        else:
+                            cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
+                        c_cat.append(cc)
+                    c_cat = torch.cat(c_cat, dim=1)
+
+                    # cond
+                    cond={"c_concat": [c_cat], "c_crossattn": [c]}
+
+                    # uncond cond
+                    uc_cross = model.get_unconditional_conditioning(num_samples, "")
+                    uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
+                    shape = [model.channels, height//8, width//8]
+
+                    samples, _ = sampler.sample(
+                        batch_size = 1,
+                        S = steps,
+                        x_T = x_T,
+                        conditioning = cond,
+                        shape = shape,
+                        verbose = False,
+                        unconditional_guidance_scale = cfg_scale,
+                        unconditional_conditioning = uc_full,
+                        eta = 1.0,
+                        img_callback = step_callback,
+                        threshold = threshold,
+                    )
+                    if self.free_gpu_mem:
+                        self.model.model.to("cpu")
+            return self.sample_to_image(samples)
+
+        return make_image
+
+    def make_batch_sd(
+            self,
+            image,
+            mask,
+            masked_image,
+            prompt,
+            device,
+            num_samples=1):
+        batch = {
+                "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
+                "txt": num_samples * [prompt],
+                "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
+                "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
+                }
+        return batch
+
+    def get_noise(self, width:int, height:int):
+        if self.init_latent is not None:
+            height = self.init_latent.shape[2]
+            width = self.init_latent.shape[3]
+        return Txt2Img.get_noise(self,width,height)
diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index f580dfba25..f972a9eb16 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -13,6 +13,7 @@ import gc
 import hashlib
 import psutil
 import transformers
+import traceback
 import os
 from sys import getrefcount
 from omegaconf import OmegaConf
@@ -73,6 +74,7 @@ class ModelCache(object):
                 self.models[model_name]['hash'] = hash
         except Exception as e:
             print(f'** model {model_name} could not be loaded: {str(e)}')
+            print(traceback.format_exc())
             print(f'** restoring {self.current_model}')
             self.get_model(self.current_model)
             return None
diff --git a/ldm/invoke/restoration/outcrop.py b/ldm/invoke/restoration/outcrop.py
index 017d9de7e1..0c4831dd84 100644
--- a/ldm/invoke/restoration/outcrop.py
+++ b/ldm/invoke/restoration/outcrop.py
@@ -89,6 +89,9 @@ class Outcrop(object):
     def _extend(self,image:Image,pixels:int)-> Image:
         extended_img = Image.new('RGBA',(image.width,image.height+pixels))

+        mask_height = pixels if self.generate.model.model.conditioning_key in ('hybrid','concat') \
+            else pixels *2
+
         # first paste places old image at top of extended image, stretch
         # it, and applies a gaussian blur to it
         # take the top half region, stretch and paste it
@@ -105,7 +108,9 @@ class Outcrop(object):

         # now make the top part transparent to use as a mask
         alpha = extended_img.getchannel('A')
-        alpha.paste(0,(0,0,extended_img.width,pixels*2))
+        alpha.paste(0,(0,0,extended_img.width,mask_height))
         extended_img.putalpha(alpha)

+        extended_img.save('outputs/curly_extended.png')
+
         return extended_img
diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py
index 359f5688d1..3db7b6fd73 100644
--- a/ldm/models/autoencoder.py
+++ b/ldm/models/autoencoder.py
@@ -66,7 +66,7 @@ class VQModel(pl.LightningModule):
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self)
-            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
+            print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.')

         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index b11e8578e7..e5a502f977 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -53,12 +53,14 @@ class DDIMSampler(Sampler):
             # damian0815 would like to know when/if this code path is used
             e_t = self.model.apply_model(x, t, c)
         else:
+            # step_index counts in the opposite direction to index
             step_index = step_count-(index+1)
-            e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
-                                                           unconditional_conditioning, c,
-                                                           unconditional_guidance_scale,
-                                                           step_index=step_index)
-
+            e_t = self.invokeai_diffuser.do_diffusion_step(
+                x, t,
+                unconditional_conditioning, c,
+                unconditional_guidance_scale,
+                step_index=step_index
+            )
         if score_corrector is not None:
             assert self.model.parameterization == 'eps'
             e_t = score_corrector.modify_score(
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
index 57027b224c..827ab4e890 100644
--- a/ldm/models/diffusion/ddpm.py
+++ b/ldm/models/diffusion/ddpm.py
@@ -19,6 +19,7 @@ from functools import partial
 from tqdm import tqdm
 from torchvision.utils import make_grid
 from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
 import urllib

 from ldm.util import (
@@ -120,7 +121,7 @@ class DDPM(pl.LightningModule):
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self.model)
-            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
+            print(f' | Keeping EMAs of {len(list(self.model_ema.buffers()))}.')

         self.use_scheduler = scheduler_config is not None
         if self.use_scheduler:
@@ -1883,6 +1884,24 @@ class LatentDiffusion(DDPM):

         return samples, intermediates

+    @torch.no_grad()
+    def get_unconditional_conditioning(self, batch_size, null_label=None):
+        if null_label is not None:
+            xc = null_label
+            if isinstance(xc, ListConfig):
+                xc = list(xc)
+            if isinstance(xc, dict) or isinstance(xc, list):
+                c = self.get_learned_conditioning(xc)
+            else:
+                if hasattr(xc, "to"):
+                    xc = xc.to(self.device)
+                c = self.get_learned_conditioning(xc)
+        else:
+            # todo: get null label from cond_stage_model
+            raise NotImplementedError()
+        c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
+        return c
+
     @torch.no_grad()
     def log_images(
         self,
@@ -2147,8 +2166,8 @@ class DiffusionWrapper(pl.LightningModule):
             cc = torch.cat(c_crossattn, 1)
             out = self.diffusion_model(x, t, context=cc)
         elif self.conditioning_key == 'hybrid':
-            xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
+            xc = torch.cat([x] + c_concat, dim=1)
             out = self.diffusion_model(xc, t, context=cc)
         elif self.conditioning_key == 'adm':
             cc = c_crossattn[0]
@@ -2187,3 +2206,58 @@ class Layout2ImgDiffusion(LatentDiffusion):
         cond_img = torch.stack(bbox_imgs, dim=0)
         logs['bbox_image'] = cond_img
         return logs
+
+class LatentInpaintDiffusion(LatentDiffusion):
+    def __init__(
+        self,
+        concat_keys=("mask", "masked_image"),
+        masked_image_key="masked_image",
+        finetune_keys=None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.masked_image_key = masked_image_key
+        assert self.masked_image_key in concat_keys
+        self.concat_keys = concat_keys
+
+
+    @torch.no_grad()
+    def get_input(
+        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
+    ):
+        # note: restricted to non-trainable encoders currently
+        assert (
+            not self.cond_stage_trainable
+        ), "trainable cond stages not yet supported for inpainting"
+        z, c, x, xrec, xc = super().get_input(
+            batch,
+            self.first_stage_key,
+            return_first_stage_outputs=True,
+            force_c_encode=True,
+            return_original_cond=True,
+            bs=bs,
+        )
+
+        assert exists(self.concat_keys)
+        c_cat = list()
+        for ck in self.concat_keys:
+            cc = (
+                rearrange(batch[ck], "b h w c -> b c h w")
+                .to(memory_format=torch.contiguous_format)
+                .float()
+            )
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            bchw = z.shape
+            if ck != self.masked_image_key:
+                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+            else:
+                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
index 2f5bf53850..5a63313b32 100644
--- a/ldm/models/diffusion/ksampler.py
+++ b/ldm/models/diffusion/ksampler.py
@@ -23,9 +23,10 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):


 class CFGDenoiser(nn.Module):
-    def __init__(self, model, threshold = 0, warmup = 0):
+    def __init__(self, sampler, threshold = 0, warmup = 0):
         super().__init__()
-        self.inner_model = model
+        self.inner_model = sampler.model
+        self.sampler = sampler
         self.threshold = threshold
         self.warmup_max = warmup
         self.warmup = max(warmup / 10, 1)
@@ -43,10 +44,14 @@ class CFGDenoiser(nn.Module):


     def forward(self, x, sigma, uncond, cond, cond_scale):
-
-        next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
-
-        # apply threshold
+        if isinstance(cond,dict):   # hybrid model
+            x_in = torch.cat([x] * 2)
+            sigma_in = torch.cat([sigma] * 2)
+            cond_in = self.sampler.make_cond_in(uncond,cond)
+            uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+            next_x = uncond + (cond - uncond) * cond_scale
+        else:                       # cross attention model
+            next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
             self.warmup += 1
@@ -56,8 +61,6 @@ class CFGDenoiser(nn.Module):
             thresh = self.threshold
         return cfg_apply_threshold(next_x, thresh)

-
-
 class KSampler(Sampler):
     def __init__(self, model, schedule='lms', device=None, **kwargs):
         denoiser = K.external.CompVisDenoiser(model)
@@ -286,3 +289,6 @@ class KSampler(Sampler):
         '''
         return self.model.inner_model.q_sample(x0,ts)

+    def conditioning_key(self)->str:
+        return self.model.inner_model.model.conditioning_key
+
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 6bd519b63b..5124badcd1 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -14,9 +14,6 @@ class PLMSSampler(Sampler):
     def __init__(self, model, schedule='linear', device=None, **kwargs):
         super().__init__(model,schedule,model.num_timesteps, device)

-        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
-                                                           model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
-
     def prepare_to_sample(self, t_enc, **kwargs):
         super().prepare_to_sample(t_enc, **kwargs)

@@ -67,7 +64,6 @@ class PLMSSampler(Sampler):
                                                            unconditional_conditioning, c,
                                                            unconditional_guidance_scale,
                                                            step_index=step_index)
-
         if score_corrector is not None:
             assert self.model.parameterization == 'eps'
             e_t = score_corrector.modify_score(
diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py
index 853702ef68..f31f5b1758 100644
--- a/ldm/models/diffusion/sampler.py
+++ b/ldm/models/diffusion/sampler.py
@@ -11,6 +11,7 @@ import numpy as np
 from tqdm import tqdm
 from functools import partial
 from ldm.invoke.devices import choose_torch_device
+from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

 from ldm.modules.diffusionmodules.util import (
     make_ddim_sampling_parameters,
@@ -26,6 +27,8 @@ class Sampler(object):
         self.ddpm_num_timesteps = steps
         self.schedule = schedule
         self.device = device or choose_torch_device()
+        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
+                                                           model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))

     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
@@ -160,6 +163,18 @@ class Sampler(object):
         **kwargs,
     ):

+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
         # check to see if make_schedule() has run, and if not, run it
         if self.ddim_timesteps is None:
             self.make_schedule(
@@ -196,7 +211,7 @@ class Sampler(object):
         )
         return samples, intermediates

-    #torch.no_grad()
+    @torch.no_grad()
     def do_sampling(
         self,
         cond,
@@ -257,6 +272,7 @@ class Sampler(object):
             )

             if mask is not None:
+                print('DEBUG: in masking routine')
                 assert x0 is not None
                 img_orig = self.model.q_sample(
                     x0, ts
@@ -313,7 +329,6 @@ class Sampler(object):
         all_timesteps_count = None,
         **kwargs
     ):
-
         timesteps = (
             np.arange(self.ddpm_num_timesteps)
             if use_original_steps
@@ -420,3 +435,27 @@ class Sampler(object):
         '''
         return self.model.q_sample(x0,ts)

+    def conditioning_key(self)->str:
+        return self.model.model.conditioning_key
+
+    # def make_cond_in(self, uncond, cond):
+    #     '''
+    #     This handles the choice between a conditional conditioning
+    #     that is a tensor (used by cross attention) vs one that is a dict
+    #     used by 'hybrid'
+    #     '''
+    #     if isinstance(cond, dict):
+    #         assert isinstance(uncond, dict)
+    #         cond_in = dict()
+    #         for k in cond:
+    #             if isinstance(cond[k], list):
+    #                 cond_in[k] = [
+    #                     torch.cat([uncond[k][i], cond[k][i]])
+    #                     for i in range(len(cond[k]))
+    #                 ]
+    #             else:
+    #                 cond_in[k] = torch.cat([uncond[k], cond[k]])
+    #     else:
+    #         cond_in = torch.cat([uncond, cond])
+    #     return cond_in
+
diff --git a/scripts/invoke.py b/scripts/invoke.py
index faa85de80e..466536bc46 100644
--- a/scripts/invoke.py
+++ b/scripts/invoke.py
@@ -171,9 +171,9 @@ def main_loop(gen, opt):
         except (OSError, AttributeError, KeyError):
             pass

-        if len(opt.prompt) == 0:
-            print('\nTry again with a prompt!')
-            continue
+#        if len(opt.prompt) == 0:
+#            print('\nTry again with a prompt!')
+#            continue

         # width and height are set by model if not specified
         if not opt.width:
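
For orientation, the `hybrid`/`c_concat` plumbing added above feeds the runwayML inpainting UNet a 9-channel input (hence `in_channels: 9` in the new config): the 4-channel noisy latent, a 1-channel mask downscaled to latent resolution, and the 4-channel VAE encoding of the masked image. The sketch below is illustrative only and is not part of the diff; the function and variable names are hypothetical, and it simply mirrors what `Omnibus.get_make_image()` and the `hybrid` branch of `DiffusionWrapper` do together.

# Illustrative sketch (assumed names, not from the diff): assembling the
# 9-channel input consumed by the runwayML inpainting UNet.
import torch

def assemble_hybrid_unet_input(x, mask, masked_image_latent):
    # x:                   (B, 4, H/8, W/8) noisy latent being denoised
    # mask:                (B, 1, H,   W)   binary inpainting mask at image resolution
    # masked_image_latent: (B, 4, H/8, W/8) VAE encoding of the masked init image
    mask_lowres = torch.nn.functional.interpolate(mask, size=x.shape[-2:])
    # concat_keys order in LatentInpaintDiffusion is ("mask", "masked_image")
    c_concat = torch.cat([mask_lowres, masked_image_latent], dim=1)
    # the 'hybrid' branch then concatenates the latent with c_concat: 4 + 1 + 4 = 9 channels
    return torch.cat([x, c_concat], dim=1)

# shape check with toy tensors
x      = torch.randn(1, 4, 64, 64)
mask   = torch.ones(1, 1, 512, 512)
masked = torch.randn(1, 4, 64, 64)
assert assemble_hybrid_unet_input(x, mask, masked).shape[1] == 9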