From bc18a94d8cb118985115394a9228ea2e772a478f Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 13 Feb 2023 14:11:36 -0500 Subject: [PATCH] add ability to retrieve current list of embedding trigger strings This PR adds a new attribute, `embedding_trigger_strings`, to the `Generate` class in ldm.generate: ``` gen = Generate(...) strings = gen.embedding_trigger_strings ``` The attribute is a read-only property, so it is accessed without parentheses. The list of trigger strings changes when the model is switched, so that only strings compatible with the currently loaded model are shown. Dynamically-downloaded triggers from the HF Concepts Library will only show up after they are used for the first time. However, the full list of concepts available for download can be retrieved programmatically like this: ``` from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary concepts = HuggingFaceConceptsLibrary() trigger_strings = concepts.list_concepts() ``` A fuller end-to-end usage sketch appears after the diff below. --- ldm/generate.py | 978 ++++++++++++----------- ldm/invoke/CLI.py | 947 ++++++++++++---------- ldm/invoke/args.py | 3 + ldm/invoke/readline.py | 2 +- ldm/modules/textual_inversion_manager.py | 275 ++++--- 5 files changed, 1253 insertions(+), 952 deletions(-) diff --git a/ldm/generate.py b/ldm/generate.py index 8cb3058694..aa04b24df3 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -11,6 +11,7 @@ import re import sys import time import traceback +from typing import List import cv2 import diffusers @@ -18,19 +19,19 @@ import numpy as np import skimage import torch import transformers -from PIL import Image, ImageOps from diffusers.pipeline_utils import DiffusionPipeline from diffusers.utils.import_utils import is_xformers_available from omegaconf import OmegaConf -from pytorch_lightning import seed_everything, logging +from PIL import Image, ImageOps +from pytorch_lightning import logging, seed_everything import ldm.invoke.conditioning from ldm.invoke.args import metadata_from_png from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary from ldm.invoke.conditioning import get_uc_and_c_and_ec -from ldm.invoke.devices import choose_torch_device, choose_precision +from ldm.invoke.devices import choose_precision, choose_torch_device from ldm.invoke.generator.inpaint import infill_methods -from ldm.invoke.globals import global_cache_dir, Globals +from ldm.invoke.globals import Globals, global_cache_dir from ldm.invoke.image_util import InitImageResizer from ldm.invoke.model_manager import ModelManager from ldm.invoke.pngwriter import PngWriter @@ -42,14 +43,17 @@ from ldm.models.diffusion.plms import PLMSSampler def fix_func(orig): - if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + def new_func(*args, **kw): device = kw.get("device", "mps") - kw["device"]="cpu" + kw["device"] = "cpu" return orig(*args, **kw).to(device) + return new_func return orig + torch.rand = fix_func(torch.rand) torch.rand_like = fix_func(torch.rand_like) torch.randn = fix_func(torch.randn) @@ -60,7 +64,7 @@ torch.bernoulli = fix_func(torch.bernoulli) torch.multinomial = fix_func(torch.multinomial) # this is fallback model in case no default is defined -FALLBACK_MODEL_NAME='stable-diffusion-1.5' +FALLBACK_MODEL_NAME = "stable-diffusion-1.5" """Simplified text to image API for stable diffusion/latent diffusion @@ -129,59 +133,60 @@ gr = Generate( """ + class Generate: """Generate class Stores default values for multiple configuration items """ def __init__( - self, - model = None, - conf = 'configs/models.yaml', - embedding_path
= None, - sampler_name = 'k_lms', - ddim_eta = 0.0, # deterministic - full_precision = False, - precision = 'auto', - outdir = 'outputs/img-samples', - gfpgan=None, - codeformer=None, - esrgan=None, - free_gpu_mem: bool=False, - safety_checker:bool=False, - max_loaded_models:int=2, - # these are deprecated; if present they override values in the conf file - weights = None, - config = None, + self, + model=None, + conf="configs/models.yaml", + embedding_path=None, + sampler_name="k_lms", + ddim_eta=0.0, # deterministic + full_precision=False, + precision="auto", + outdir="outputs/img-samples", + gfpgan=None, + codeformer=None, + esrgan=None, + free_gpu_mem: bool = False, + safety_checker: bool = False, + max_loaded_models: int = 2, + # these are deprecated; if present they override values in the conf file + weights=None, + config=None, ): - mconfig = OmegaConf.load(conf) - self.height = None - self.width = None - self.model_manager = None - self.iterations = 1 - self.steps = 50 - self.cfg_scale = 7.5 - self.sampler_name = sampler_name - self.ddim_eta = ddim_eta # same seed always produces same image - self.precision = precision - self.strength = 0.75 - self.seamless = False - self.seamless_axes = {'x','y'} - self.hires_fix = False + mconfig = OmegaConf.load(conf) + self.height = None + self.width = None + self.model_manager = None + self.iterations = 1 + self.steps = 50 + self.cfg_scale = 7.5 + self.sampler_name = sampler_name + self.ddim_eta = ddim_eta # same seed always produces same image + self.precision = precision + self.strength = 0.75 + self.seamless = False + self.seamless_axes = {"x", "y"} + self.hires_fix = False self.embedding_path = embedding_path - self.model = None # empty for now - self.model_hash = None - self.sampler = None - self.device = None + self.model = None # empty for now + self.model_hash = None + self.sampler = None + self.device = None self.session_peakmem = None self.base_generator = None - self.seed = None + self.seed = None self.outdir = outdir self.gfpgan = gfpgan self.codeformer = codeformer self.esrgan = esrgan self.free_gpu_mem = free_gpu_mem - self.max_loaded_models = max_loaded_models, + self.max_loaded_models = (max_loaded_models,) self.size_matters = True # used to warn once about large image sizes and VRAM self.txt2mask = None self.safety_checker = None @@ -192,62 +197,77 @@ class Generate: # device to Generate(). However the device was then ignored, so # it wasn't actually doing anything. This logic could be reinstated. 
device_type = choose_torch_device() - print(f'>> Using device_type {device_type}') + print(f">> Using device_type {device_type}") self.device = torch.device(device_type) if full_precision: - if self.precision != 'auto': - raise ValueError('Remove --full_precision / -F if using --precision') - print('Please remove deprecated --full_precision / -F') - print('If auto config does not work you can use --precision=float32') - self.precision = 'float32' - if self.precision == 'auto': + if self.precision != "auto": + raise ValueError("Remove --full_precision / -F if using --precision") + print("Please remove deprecated --full_precision / -F") + print("If auto config does not work you can use --precision=float32") + self.precision = "float32" + if self.precision == "auto": self.precision = choose_precision(self.device) - Globals.full_precision = self.precision=='float32' + Globals.full_precision = self.precision == "float32" if is_xformers_available(): if not Globals.disable_xformers: - print('>> xformers memory-efficient attention is available and enabled') + print(">> xformers memory-efficient attention is available and enabled") else: - print('>> xformers memory-efficient attention is available but disabled') + print( + ">> xformers memory-efficient attention is available but disabled" + ) else: - print('>> xformers not installed') + print(">> xformers not installed") # model caching system for fast switching - self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models) + self.model_manager = ModelManager( + mconfig, self.device, self.precision, max_loaded_models=max_loaded_models + ) # don't accept invalid models fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME model = model or fallback if not self.model_manager.valid_model(model): - print(f'** "{model}" is not a known model name; falling back to {fallback}.') + print( + f'** "{model}" is not a known model name; falling back to {fallback}.' 
+ ) model = None - self.model_name = model or fallback + self.model_name = model or fallback # for VRAM usage statistics - self.session_peakmem = torch.cuda.max_memory_allocated(self.device) if self._has_cuda else None + self.session_peakmem = ( + torch.cuda.max_memory_allocated(self.device) if self._has_cuda else None + ) transformers.logging.set_verbosity_error() # gets rid of annoying messages about random seed - logging.getLogger('pytorch_lightning').setLevel(logging.ERROR) + logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) # load safety checker if requested if safety_checker: try: - print('>> Initializing safety checker') - from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + print(">> Initializing safety checker") + from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, + ) from transformers import AutoFeatureExtractor + safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_model_path = global_cache_dir("hub") - self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, - local_files_only=True, - cache_dir=safety_model_path, + self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( + safety_model_id, + local_files_only=True, + cache_dir=safety_model_path, ) - self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, - local_files_only=True, - cache_dir=safety_model_path, + self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained( + safety_model_id, + local_files_only=True, + cache_dir=safety_model_path, ) self.safety_checker.to(self.device) except Exception: - print('** An error was encountered while installing the safety checker:') + print( + "** An error was encountered while installing the safety checker:" + ) print(traceback.format_exc()) def prompt2png(self, prompt, outdir, **kwargs): @@ -256,95 +276,95 @@ class Generate: of PNG files, and returns an array of [[filename,seed],[filename,seed]...] 
Optional named arguments are the same as those passed to Generate and prompt2image() """ - results = self.prompt2image(prompt, **kwargs) + results = self.prompt2image(prompt, **kwargs) pngwriter = PngWriter(outdir) - prefix = pngwriter.unique_prefix() - outputs = [] + prefix = pngwriter.unique_prefix() + outputs = [] for image, seed in results: - name = f'{prefix}.{seed}.png' + name = f"{prefix}.{seed}.png" path = pngwriter.save_image_and_prompt_to_png( - image, dream_prompt=f'{prompt} -S{seed}', name=name) + image, dream_prompt=f"{prompt} -S{seed}", name=name + ) outputs.append([path, seed]) return outputs def txt2img(self, prompt, **kwargs): - outdir = kwargs.pop('outdir', self.outdir) + outdir = kwargs.pop("outdir", self.outdir) return self.prompt2png(prompt, outdir, **kwargs) def img2img(self, prompt, **kwargs): - outdir = kwargs.pop('outdir', self.outdir) + outdir = kwargs.pop("outdir", self.outdir) assert ( - 'init_img' in kwargs - ), 'call to img2img() must include the init_img argument' + "init_img" in kwargs + ), "call to img2img() must include the init_img argument" return self.prompt2png(prompt, outdir, **kwargs) def prompt2image( - self, - # these are common - prompt, - iterations = None, - steps = None, - seed = None, - cfg_scale = None, - ddim_eta = None, - skip_normalize = False, - image_callback = None, - step_callback = None, - width = None, - height = None, - sampler_name = None, - seamless = False, - seamless_axes = {'x','y'}, - log_tokenization = False, - with_variations = None, - variation_amount = 0.0, - threshold = 0.0, - perlin = 0.0, - karras_max = None, - outdir = None, - # these are specific to img2img and inpaint - init_img = None, - init_mask = None, - text_mask = None, - invert_mask = False, - fit = False, - strength = None, - init_color = None, - # these are specific to embiggen (which also relies on img2img args) - embiggen = None, - embiggen_tiles = None, - embiggen_strength = None, - # these are specific to GFPGAN/ESRGAN - gfpgan_strength= 0, - facetool = None, - facetool_strength = 0, - codeformer_fidelity = None, - save_original = False, - upscale = None, - upscale_denoise_str = 0.75, - # this is specific to inpainting and causes more extreme inpainting - inpaint_replace = 0.0, - # This controls the size at which inpaint occurs (scaled up for inpaint, then back down for the result) - inpaint_width = None, - inpaint_height = None, - # This will help match inpainted areas to the original image more smoothly - mask_blur_radius: int = 8, - # Set this True to handle KeyboardInterrupt internally - catch_interrupts = False, - hires_fix = False, - use_mps_noise = False, - # Seam settings for outpainting - seam_size: int = 0, - seam_blur: int = 0, - seam_strength: float = 0.7, - seam_steps: int = 10, - tile_size: int = 32, - infill_method = None, - force_outpaint: bool = False, - enable_image_debugging = False, - - **args, - ): # eat up additional cruft + self, + # these are common + prompt, + iterations=None, + steps=None, + seed=None, + cfg_scale=None, + ddim_eta=None, + skip_normalize=False, + image_callback=None, + step_callback=None, + width=None, + height=None, + sampler_name=None, + seamless=False, + seamless_axes={"x", "y"}, + log_tokenization=False, + with_variations=None, + variation_amount=0.0, + threshold=0.0, + perlin=0.0, + karras_max=None, + outdir=None, + # these are specific to img2img and inpaint + init_img=None, + init_mask=None, + text_mask=None, + invert_mask=False, + fit=False, + strength=None, + init_color=None, + # these are specific to 
embiggen (which also relies on img2img args) + embiggen=None, + embiggen_tiles=None, + embiggen_strength=None, + # these are specific to GFPGAN/ESRGAN + gfpgan_strength=0, + facetool=None, + facetool_strength=0, + codeformer_fidelity=None, + save_original=False, + upscale=None, + upscale_denoise_str=0.75, + # this is specific to inpainting and causes more extreme inpainting + inpaint_replace=0.0, + # This controls the size at which inpaint occurs (scaled up for inpaint, then back down for the result) + inpaint_width=None, + inpaint_height=None, + # This will help match inpainted areas to the original image more smoothly + mask_blur_radius: int = 8, + # Set this True to handle KeyboardInterrupt internally + catch_interrupts=False, + hires_fix=False, + use_mps_noise=False, + # Seam settings for outpainting + seam_size: int = 0, + seam_blur: int = 0, + seam_strength: float = 0.7, + seam_steps: int = 10, + tile_size: int = 32, + infill_method=None, + force_outpaint: bool = False, + enable_image_debugging=False, + **args, + ): # eat up additional cruft self.clear_cuda_stats() """ ldm.generate.prompt2image() is the common entry point for txt2img() and img2img() @@ -401,12 +421,14 @@ class Generate: ddim_eta = ddim_eta or self.ddim_eta iterations = iterations or self.iterations strength = strength or self.strength - outdir = outdir or self.outdir + outdir = outdir or self.outdir self.seed = seed self.log_tokenization = log_tokenization self.step_callback = step_callback self.karras_max = karras_max - self.infill_method = infill_method or infill_methods()[0], # The infill method to use + self.infill_method = ( + infill_method or infill_methods()[0], + ) # The infill method to use with_variations = [] if with_variations is None else with_variations # will instantiate the model or return it from cache @@ -423,33 +445,33 @@ class Generate: else: configure_model_padding(model, seamless, seamless_axes) - assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0' - assert threshold >= 0.0, '--threshold must be >=0.0' + assert cfg_scale > 1.0, "CFG_Scale (-C) must be >1.0" + assert threshold >= 0.0, "--threshold must be >=0.0" assert ( 0.0 < strength <= 1.0 - ), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0' + ), "img2img and inpaint strength can only work with 0.0 < strength < 1.0" assert ( 0.0 <= variation_amount <= 1.0 - ), '-v --variation_amount must be in [0.0, 1.0]' - assert ( - 0.0 <= perlin <= 1.0 - ), '--perlin must be in [0.0, 1.0]' - assert ( - (embiggen == None and embiggen_tiles == None) or ( - (embiggen != None or embiggen_tiles != None) and init_img != None) - ), 'Embiggen requires an init/input image to be specified' + ), "-v --variation_amount must be in [0.0, 1.0]" + assert 0.0 <= perlin <= 1.0, "--perlin must be in [0.0, 1.0]" + assert (embiggen == None and embiggen_tiles == None) or ( + (embiggen != None or embiggen_tiles != None) and init_img != None + ), "Embiggen requires an init/input image to be specified" if len(with_variations) > 0 or variation_amount > 1.0: - assert seed is not None,\ - 'seed must be specified when using with_variations' + assert seed is not None, "seed must be specified when using with_variations" if variation_amount == 0.0: - assert iterations == 1,\ - 'when using --with_variations, multiple iterations are only possible when using --variation_amount' - assert all(0 <= weight <= 1 for _, weight in with_variations),\ - f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}' + assert ( + iterations == 1 + 
), "when using --with_variations, multiple iterations are only possible when using --variation_amount" + assert all( + 0 <= weight <= 1 for _, weight in with_variations + ), f"variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}" width, height, _ = self._resolution_check(width, height, log=True) - assert inpaint_replace >=0.0 and inpaint_replace <= 1.0,'inpaint_replace must be between 0.0 and 1.0' + assert ( + inpaint_replace >= 0.0 and inpaint_replace <= 1.0 + ), "inpaint_replace must be between 0.0 and 1.0" if sampler_name and (sampler_name != self.sampler_name): self.sampler_name = sampler_name @@ -459,12 +481,12 @@ class Generate: prompt = self.huggingface_concepts_library.replace_concepts_with_triggers( prompt, lambda concepts: self.load_huggingface_concepts(concepts), - self.model.textual_inversion_manager.get_all_trigger_strings() + self.model.textual_inversion_manager.get_all_trigger_strings(), ) # bit of a hack to change the cached sampler's karras threshold to # whatever the user asked for - if karras_max is not None and isinstance(self.sampler,KSampler): + if karras_max is not None and isinstance(self.sampler, KSampler): self.sampler.adjust_settings(karras_max=karras_max) tic = time.time() @@ -476,18 +498,24 @@ class Generate: mask_image = None try: - if self.free_gpu_mem and self.model.cond_stage_model.device != self.model.device: + if ( + self.free_gpu_mem + and self.model.cond_stage_model.device != self.model.device + ): self.model.cond_stage_model.device = self.model.device self.model.cond_stage_model.to(self.model.device) except AttributeError: - print(">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser.") + print( + ">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser." + ) pass try: uc, c, extra_conditioning_info = get_uc_and_c_and_ec( - prompt, model =self.model, + prompt, + model=self.model, skip_normalize_legacy_blend=skip_normalize, - log_tokens =self.log_tokenization + log_tokens=self.log_tokenization, ) init_image, mask_image = self._make_images( @@ -502,17 +530,21 @@ class Generate: ) # TODO: Hacky selection of operation to perform. Needs to be refactored. 
- generator = self.select_generator(init_image, mask_image, embiggen, hires_fix, force_outpaint) - - generator.set_variation( - self.seed, variation_amount, with_variations + generator = self.select_generator( + init_image, mask_image, embiggen, hires_fix, force_outpaint ) + + generator.set_variation(self.seed, variation_amount, with_variations) generator.use_mps_noise = use_mps_noise - checker = { - 'checker':self.safety_checker, - 'extractor':self.safety_feature_extractor - } if self.safety_checker else None + checker = ( + { + "checker": self.safety_checker, + "extractor": self.safety_feature_extractor, + } + if self.safety_checker + else None + ) results = generator.generate( prompt, @@ -524,11 +556,11 @@ class Generate: conditioning=(uc, c, extra_conditioning_info), ddim_eta=ddim_eta, image_callback=image_callback, # called after the final image is generated - step_callback=step_callback, # called after each intermediate image is generated + step_callback=step_callback, # called after each intermediate image is generated width=width, height=height, - init_img=init_img, # embiggen needs to manipulate from the unmodified init_img - init_image=init_image, # notice that init_image is different from init_img + init_img=init_img, # embiggen needs to manipulate from the unmodified init_img + init_image=init_image, # notice that init_image is different from init_img mask_image=mask_image, strength=strength, threshold=threshold, @@ -539,41 +571,45 @@ class Generate: inpaint_replace=inpaint_replace, mask_blur_radius=mask_blur_radius, safety_checker=checker, - seam_size = seam_size, - seam_blur = seam_blur, - seam_strength = seam_strength, - seam_steps = seam_steps, - tile_size = tile_size, - infill_method = infill_method, - force_outpaint = force_outpaint, - inpaint_height = inpaint_height, - inpaint_width = inpaint_width, - enable_image_debugging = enable_image_debugging, + seam_size=seam_size, + seam_blur=seam_blur, + seam_strength=seam_strength, + seam_steps=seam_steps, + tile_size=tile_size, + infill_method=infill_method, + force_outpaint=force_outpaint, + inpaint_height=inpaint_height, + inpaint_width=inpaint_width, + enable_image_debugging=enable_image_debugging, free_gpu_mem=self.free_gpu_mem, - clear_cuda_cache=self.clear_cuda_cache + clear_cuda_cache=self.clear_cuda_cache, ) if init_color: - self.correct_colors(image_list = results, - reference_image_path = init_color, - image_callback = image_callback) + self.correct_colors( + image_list=results, + reference_image_path=init_color, + image_callback=image_callback, + ) if upscale is not None or facetool_strength > 0: - self.upscale_and_reconstruct(results, - upscale = upscale, - upscale_denoise_str = upscale_denoise_str, - facetool = facetool, - strength = facetool_strength, - codeformer_fidelity = codeformer_fidelity, - save_original = save_original, - image_callback = image_callback) + self.upscale_and_reconstruct( + results, + upscale=upscale, + upscale_denoise_str=upscale_denoise_str, + facetool=facetool, + strength=facetool_strength, + codeformer_fidelity=codeformer_fidelity, + save_original=save_original, + image_callback=image_callback, + ) except KeyboardInterrupt: # Clear the CUDA cache on an exception self.clear_cuda_cache() if catch_interrupts: - print('**Interrupted** Partial results will be returned.') + print("**Interrupted** Partial results will be returned.") else: raise KeyboardInterrupt except RuntimeError: @@ -581,30 +617,24 @@ class Generate: self.clear_cuda_cache() print(traceback.format_exc(), file=sys.stderr) - 
print('>> Could not generate image.') + print(">> Could not generate image.") toc = time.time() - print('\n>> Usage stats:') - print( - f'>> {len(results)} image(s) generated in', '%4.2fs' % ( - toc - tic) - ) + print("\n>> Usage stats:") + print(f">> {len(results)} image(s) generated in", "%4.2fs" % (toc - tic)) self.print_cuda_stats() return results def gather_cuda_stats(self): if self._has_cuda(): self.max_memory_allocated = max( - self.max_memory_allocated, - torch.cuda.max_memory_allocated(self.device) + self.max_memory_allocated, torch.cuda.max_memory_allocated(self.device) ) self.memory_allocated = max( - self.memory_allocated, - torch.cuda.memory_allocated(self.device) + self.memory_allocated, torch.cuda.memory_allocated(self.device) ) self.session_peakmem = max( - self.session_peakmem, - torch.cuda.max_memory_allocated(self.device) + self.session_peakmem, torch.cuda.max_memory_allocated(self.device) ) def clear_cuda_cache(self): @@ -620,35 +650,35 @@ class Generate: if self._has_cuda(): self.gather_cuda_stats() print( - '>> Max VRAM used for this generation:', - '%4.2fG.' % (self.max_memory_allocated / 1e9), - 'Current VRAM utilization:', - '%4.2fG' % (self.memory_allocated / 1e9), + ">> Max VRAM used for this generation:", + "%4.2fG." % (self.max_memory_allocated / 1e9), + "Current VRAM utilization:", + "%4.2fG" % (self.memory_allocated / 1e9), ) print( - '>> Max VRAM used since script start: ', - '%4.2fG' % (self.session_peakmem / 1e9), + ">> Max VRAM used since script start: ", + "%4.2fG" % (self.session_peakmem / 1e9), ) # this needs to be generalized to all sorts of postprocessors, which should be wrapped # in a nice harmonized call signature. For now we have a bunch of if/elses! def apply_postprocessor( - self, - image_path, - tool = 'gfpgan', # one of 'upscale', 'gfpgan', 'codeformer', 'outpaint', or 'embiggen' - facetool_strength = 0.0, - codeformer_fidelity = 0.75, - upscale = None, - upscale_denoise_str = 0.75, - out_direction = None, - outcrop = [], - save_original = True, # to get new name - callback = None, - opt = None, - ): + self, + image_path, + tool="gfpgan", # one of 'upscale', 'gfpgan', 'codeformer', 'outpaint', or 'embiggen' + facetool_strength=0.0, + codeformer_fidelity=0.75, + upscale=None, + upscale_denoise_str=0.75, + out_direction=None, + outcrop=[], + save_original=True, # to get new name + callback=None, + opt=None, + ): # retrieve the seed from the image; - seed = None + seed = None prompt = None args = metadata_from_png(image_path) @@ -656,13 +686,13 @@ class Generate: if seed is None or seed < 0: seed = random.randrange(0, np.iinfo(np.uint32).max) - prompt = opt.prompt or args.prompt or '' + prompt = opt.prompt or args.prompt or "" print(f'>> using seed {seed} and prompt "{prompt}" for {image_path}') # try to reuse the same filename prefix as the original file. 
# we take everything up to the first period prefix = None - m = re.match(r'^([^.]+)\.',os.path.basename(image_path)) + m = re.match(r"^([^.]+)\.", os.path.basename(image_path)) if m: prefix = m.groups()[0] @@ -672,99 +702,106 @@ class Generate: # used by multiple postfixers # todo: cross-attention control uc, c, extra_conditioning_info = get_uc_and_c_and_ec( - prompt, model=self.model, + prompt, + model=self.model, skip_normalize_legacy_blend=opt.skip_normalize, - log_tokens=ldm.invoke.conditioning.log_tokenization + log_tokens=ldm.invoke.conditioning.log_tokenization, ) - if tool in ('gfpgan','codeformer','upscale'): - if tool == 'gfpgan': - facetool = 'gfpgan' - elif tool == 'codeformer': - facetool = 'codeformer' - elif tool == 'upscale': - facetool = 'gfpgan' # but won't be run + if tool in ("gfpgan", "codeformer", "upscale"): + if tool == "gfpgan": + facetool = "gfpgan" + elif tool == "codeformer": + facetool = "codeformer" + elif tool == "upscale": + facetool = "gfpgan" # but won't be run facetool_strength = 0 return self.upscale_and_reconstruct( - [[image,seed]], - facetool = facetool, - strength = facetool_strength, - codeformer_fidelity = codeformer_fidelity, - save_original = save_original, - upscale = upscale, - upscale_denoise_str = upscale_denoise_str, - image_callback = callback, - prefix = prefix, + [[image, seed]], + facetool=facetool, + strength=facetool_strength, + codeformer_fidelity=codeformer_fidelity, + save_original=save_original, + upscale=upscale, + upscale_denoise_str=upscale_denoise_str, + image_callback=callback, + prefix=prefix, ) - elif tool == 'outcrop': + elif tool == "outcrop": from ldm.invoke.restoration.outcrop import Outcrop + extend_instructions = {} - for direction,pixels in _pairwise(opt.outcrop): + for direction, pixels in _pairwise(opt.outcrop): try: - extend_instructions[direction]=int(pixels) + extend_instructions[direction] = int(pixels) except ValueError: - print('** invalid extension instruction. Use ..., as in "top 64 left 128 right 64 bottom 64"') + print( + '** invalid extension instruction. Use ..., as in "top 64 left 128 right 64 bottom 64"' + ) opt.seed = seed opt.prompt = prompt if len(extend_instructions) > 0: - restorer = Outcrop(image,self,) - return restorer.process ( + restorer = Outcrop( + image, + self, + ) + return restorer.process( extend_instructions, - opt = opt, - orig_opt = args, - image_callback = callback, - prefix = prefix, + opt=opt, + orig_opt=args, + image_callback=callback, + prefix=prefix, ) - elif tool == 'embiggen': + elif tool == "embiggen": # fetch the metadata from the image generator = self.select_generator(embiggen=True) opt.strength = opt.embiggen_strength or 0.40 - print(f'>> Setting img2img strength to {opt.strength} for happy embiggening') + print( + f">> Setting img2img strength to {opt.strength} for happy embiggening" + ) generator.generate( prompt, - sampler = self.sampler, - steps = opt.steps, - cfg_scale = opt.cfg_scale, - ddim_eta = self.ddim_eta, - conditioning= (uc, c, extra_conditioning_info), - init_img = image_path, # not the Image! (sigh) - init_image = image, # embiggen wants both! (sigh) - strength = opt.strength, - width = opt.width, - height = opt.height, - embiggen = opt.embiggen, - embiggen_tiles = opt.embiggen_tiles, - embiggen_strength = opt.embiggen_strength, - image_callback = callback, + sampler=self.sampler, + steps=opt.steps, + cfg_scale=opt.cfg_scale, + ddim_eta=self.ddim_eta, + conditioning=(uc, c, extra_conditioning_info), + init_img=image_path, # not the Image! 
(sigh) + init_image=image, # embiggen wants both! (sigh) + strength=opt.strength, + width=opt.width, + height=opt.height, + embiggen=opt.embiggen, + embiggen_tiles=opt.embiggen_tiles, + embiggen_strength=opt.embiggen_strength, + image_callback=callback, ) - elif tool == 'outpaint': + elif tool == "outpaint": from ldm.invoke.restoration.outpaint import Outpaint - restorer = Outpaint(image,self) - return restorer.process( - opt, - args, - image_callback = callback, - prefix = prefix - ) + + restorer = Outpaint(image, self) + return restorer.process(opt, args, image_callback=callback, prefix=prefix) elif tool is None: - print('* please provide at least one postprocessing option, such as -G or -U') + print( + "* please provide at least one postprocessing option, such as -G or -U" + ) return None else: - print(f'* postprocessing tool {tool} is not yet supported') + print(f"* postprocessing tool {tool} is not yet supported") return None def select_generator( - self, - init_image:Image.Image=None, - mask_image:Image.Image=None, - embiggen:bool=False, - hires_fix:bool=False, - force_outpaint:bool=False, + self, + init_image: Image.Image = None, + mask_image: Image.Image = None, + embiggen: bool = False, + hires_fix: bool = False, + force_outpaint: bool = False, ): inpainting_model_in_use = self.sampler.uses_inpainting_model() @@ -786,40 +823,46 @@ class Generate: return self._make_txt2img() def _make_images( - self, - img, - mask, - width, - height, - fit=False, - text_mask=None, - invert_mask=False, - force_outpaint=False, + self, + img, + mask, + width, + height, + fit=False, + text_mask=None, + invert_mask=False, + force_outpaint=False, ): - init_image = None - init_mask = None + init_image = None + init_mask = None if not img: return None, None image = self._load_img(img) if image.width < self.width and image.height < self.height: - print(f'>> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions') + print( + f">> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions" + ) # if image has a transparent area and no mask was provided, then try to generate mask if self._has_transparency(image): self._transparency_check_and_warning(image, mask, force_outpaint) init_mask = self._create_init_mask(image, width, height, fit=fit) - if (image.width * image.height) > (self.width * self.height) and self.size_matters: - print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.") + if (image.width * image.height) > ( + self.width * self.height + ) and self.size_matters: + print( + ">> This input is larger than your defaults. If you run out of memory, please use a smaller image." 
+ ) self.size_matters = False - init_image = self._create_init_image(image,width,height,fit=fit) + init_image = self._create_init_image(image, width, height, fit=fit) if mask: mask_image = self._load_img(mask) - init_mask = self._create_init_mask(mask_image,width,height,fit=fit) + init_mask = self._create_init_mask(mask_image, width, height, fit=fit) elif text_mask: init_mask = self._txt2mask(image, text_mask, width, height, fit=fit) @@ -827,47 +870,47 @@ class Generate: if init_mask and invert_mask: init_mask = ImageOps.invert(init_mask) - return init_image,init_mask + return init_image, init_mask def _make_base(self): - return self._load_generator('','Generator') + return self._load_generator("", "Generator") def _make_txt2img(self): - return self._load_generator('.txt2img','Txt2Img') + return self._load_generator(".txt2img", "Txt2Img") def _make_img2img(self): - return self._load_generator('.img2img','Img2Img') + return self._load_generator(".img2img", "Img2Img") def _make_embiggen(self): - return self._load_generator('.embiggen','Embiggen') + return self._load_generator(".embiggen", "Embiggen") def _make_txt2img2img(self): - return self._load_generator('.txt2img2img','Txt2Img2Img') + return self._load_generator(".txt2img2img", "Txt2Img2Img") def _make_inpaint(self): - return self._load_generator('.inpaint','Inpaint') + return self._load_generator(".inpaint", "Inpaint") def _make_omnibus(self): - return self._load_generator('.omnibus','Omnibus') + return self._load_generator(".omnibus", "Omnibus") def _load_generator(self, module, class_name): if self.is_legacy_model(self.model_name): - mn = f'ldm.invoke.ckpt_generator{module}' - cn = f'Ckpt{class_name}' + mn = f"ldm.invoke.ckpt_generator{module}" + cn = f"Ckpt{class_name}" else: - mn = f'ldm.invoke.generator{module}' + mn = f"ldm.invoke.generator{module}" cn = class_name module = importlib.import_module(mn) - constructor = getattr(module,cn) + constructor = getattr(module, cn) return constructor(self.model, self.precision) def load_model(self): - ''' + """ preload model identified in self.model_name - ''' + """ return self.set_model(self.model_name) - def set_model(self,model_name): + def set_model(self, model_name): """ Given the name of a model defined in models.yaml, will load and initialize it and return the model object. Previously-used models will be cached. @@ -884,7 +927,9 @@ class Generate: # the model cache does the loading and offloading cache = self.model_manager if not cache.valid_model(model_name): - raise KeyError(f'** "{model_name}" is not a known model name. Cannot change.') + raise KeyError( + f'** "{model_name}" is not a known model name. Cannot change.' 
+ ) cache.print_vram_usage() @@ -897,20 +942,20 @@ class Generate: try: model_data = cache.get_model(model_name) except Exception as e: - print(f'** model {model_name} could not be loaded: {str(e)}') + print(f"** model {model_name} could not be loaded: {str(e)}") print(traceback.format_exc(), file=sys.stderr) if previous_model_name is None: raise e - print(f'** trying to reload previous model') - model_data = cache.get_model(previous_model_name) # load previous + print("** trying to reload previous model") + model_data = cache.get_model(previous_model_name) # load previous if model_data is None: raise e model_name = previous_model_name - self.model = model_data['model'] - self.width = model_data['width'] - self.height= model_data['height'] - self.model_hash = model_data['hash'] + self.model = model_data["model"] + self.width = model_data["width"] + self.height = model_data["height"] + self.model_hash = model_data["hash"] # uncache generators so they pick up new models self.generators = {} @@ -920,35 +965,37 @@ class Generate: for root, _, files in os.walk(self.embedding_path): for name in files: ti_path = os.path.join(root, name) - self.model.textual_inversion_manager.load_textual_inversion(ti_path, - defer_injecting_tokens=True) - print(f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}') + self.model.textual_inversion_manager.load_textual_inversion( + ti_path, defer_injecting_tokens=True + ) + print( + f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}' + ) self.model_name = model_name self._set_sampler() # requires self.model_name to be set first return self.model - def load_huggingface_concepts(self, concepts:list[str]): + def load_huggingface_concepts(self, concepts: list[str]): self.model.textual_inversion_manager.load_huggingface_concepts(concepts) @property def huggingface_concepts_library(self) -> HuggingFaceConceptsLibrary: return self.model.textual_inversion_manager.hf_concepts_library - def correct_colors(self, - image_list, - reference_image_path, - image_callback = None): + @property + def embedding_trigger_strings(self) -> List[str]: + return self.model.textual_inversion_manager.get_all_trigger_strings() + + def correct_colors(self, image_list, reference_image_path, image_callback=None): reference_image = Image.open(reference_image_path) - correction_target = cv2.cvtColor(np.asarray(reference_image), - cv2.COLOR_RGB2LAB) + correction_target = cv2.cvtColor(np.asarray(reference_image), cv2.COLOR_RGB2LAB) for r in image_list: image, seed = r - image = cv2.cvtColor(np.asarray(image), - cv2.COLOR_RGB2LAB) - image = skimage.exposure.match_histograms(image, - correction_target, - channel_axis=2) + image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2LAB) + image = skimage.exposure.match_histograms( + image, correction_target, channel_axis=2 + ) image = Image.fromarray( cv2.cvtColor(image, cv2.COLOR_LAB2RGB).astype("uint8") ) @@ -957,34 +1004,46 @@ class Generate: else: r[0] = image - def upscale_and_reconstruct(self, - image_list, - facetool = 'gfpgan', - upscale = None, - upscale_denoise_str = 0.75, - strength = 0.0, - codeformer_fidelity = 0.75, - save_original = False, - image_callback = None, - prefix = None, + def upscale_and_reconstruct( + self, + image_list, + facetool="gfpgan", + upscale=None, + upscale_denoise_str=0.75, + strength=0.0, + codeformer_fidelity=0.75, + save_original=False, + image_callback=None, + prefix=None, ): - for r in image_list: 
image, seed = r try: if strength > 0: if self.gfpgan is not None or self.codeformer is not None: - if facetool == 'gfpgan': + if facetool == "gfpgan": if self.gfpgan is None: - print('>> GFPGAN not found. Face restoration is disabled.') + print( + ">> GFPGAN not found. Face restoration is disabled." + ) else: - image = self.gfpgan.process(image, strength, seed) - if facetool == 'codeformer': + image = self.gfpgan.process(image, strength, seed) + if facetool == "codeformer": if self.codeformer is None: - print('>> CodeFormer not found. Face restoration is disabled.') + print( + ">> CodeFormer not found. Face restoration is disabled." + ) else: - cf_device = 'cpu' if str(self.device) == 'mps' else self.device - image = self.codeformer.process(image=image, strength=strength, device=cf_device, seed=seed, fidelity=codeformer_fidelity) + cf_device = ( + "cpu" if str(self.device) == "mps" else self.device + ) + image = self.codeformer.process( + image=image, + strength=strength, + device=cf_device, + seed=seed, + fidelity=codeformer_fidelity, + ) else: print(">> Face Restoration is disabled.") if upscale is not None: @@ -992,12 +1051,17 @@ class Generate: if len(upscale) < 2: upscale.append(0.75) image = self.esrgan.process( - image, upscale[1], seed, int(upscale[0]), denoise_str=upscale_denoise_str) + image, + upscale[1], + seed, + int(upscale[0]), + denoise_str=upscale_denoise_str, + ) else: print(">> ESRGAN is disabled. Image not upscaled.") except Exception as e: print( - f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}' + f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}" ) if image_callback is not None: @@ -1005,22 +1069,26 @@ class Generate: else: r[0] = image - def apply_textmask(self, image_path:str, prompt:str, callback, threshold:float=0.5): - assert os.path.exists(image_path), f'** "{image_path}" not found. Please enter the name of an existing image file to mask **' - basename,_ = os.path.splitext(os.path.basename(image_path)) + def apply_textmask( + self, image_path: str, prompt: str, callback, threshold: float = 0.5 + ): + assert os.path.exists( + image_path + ), f'** "{image_path}" not found. 
Please enter the name of an existing image file to mask **' + basename, _ = os.path.splitext(os.path.basename(image_path)) if self.txt2mask is None: - self.txt2mask = Txt2Mask(device = self.device, refined=True) - segmented = self.txt2mask.segment(image_path,prompt) + self.txt2mask = Txt2Mask(device=self.device, refined=True) + segmented = self.txt2mask.segment(image_path, prompt) trans = segmented.to_transparent() inverse = segmented.to_transparent(invert=True) mask = segmented.to_mask(threshold) path_filter = re.compile(r'[<>:"/\\|?*]') - safe_prompt = path_filter.sub('_', prompt)[:50].rstrip(' .') + safe_prompt = path_filter.sub("_", prompt)[:50].rstrip(" .") - callback(trans,f'{safe_prompt}.deselected',use_prefix=basename) - callback(inverse,f'{safe_prompt}.selected',use_prefix=basename) - callback(mask,f'{safe_prompt}.masked',use_prefix=basename) + callback(trans, f"{safe_prompt}.deselected", use_prefix=basename) + callback(inverse, f"{safe_prompt}.selected", use_prefix=basename) + callback(mask, f"{safe_prompt}.masked", use_prefix=basename) # to help WebGUI - front end to generator util function def sample_to_image(self, samples): @@ -1029,7 +1097,7 @@ class Generate: def sample_to_lowres_estimated_image(self, samples): return self._make_base().sample_to_lowres_estimated_image(samples) - def is_legacy_model(self,model_name)->bool: + def is_legacy_model(self, model_name) -> bool: return self.model_manager.is_legacy(model_name) def _set_sampler(self): @@ -1041,29 +1109,31 @@ class Generate: # very repetitive code - can this be simplified? The KSampler names are # consistent, at least def _set_sampler_legacy(self): - msg = f'>> Setting Sampler to {self.sampler_name}' - if self.sampler_name == 'plms': + msg = f">> Setting Sampler to {self.sampler_name}" + if self.sampler_name == "plms": self.sampler = PLMSSampler(self.model, device=self.device) - elif self.sampler_name == 'ddim': + elif self.sampler_name == "ddim": self.sampler = DDIMSampler(self.model, device=self.device) - elif self.sampler_name == 'k_dpm_2_a': - self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device) - elif self.sampler_name == 'k_dpm_2': - self.sampler = KSampler(self.model, 'dpm_2', device=self.device) - elif self.sampler_name == 'k_dpmpp_2_a': - self.sampler = KSampler(self.model, 'dpmpp_2s_ancestral', device=self.device) - elif self.sampler_name == 'k_dpmpp_2': - self.sampler = KSampler(self.model, 'dpmpp_2m', device=self.device) - elif self.sampler_name == 'k_euler_a': - self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device) - elif self.sampler_name == 'k_euler': - self.sampler = KSampler(self.model, 'euler', device=self.device) - elif self.sampler_name == 'k_heun': - self.sampler = KSampler(self.model, 'heun', device=self.device) - elif self.sampler_name == 'k_lms': - self.sampler = KSampler(self.model, 'lms', device=self.device) + elif self.sampler_name == "k_dpm_2_a": + self.sampler = KSampler(self.model, "dpm_2_ancestral", device=self.device) + elif self.sampler_name == "k_dpm_2": + self.sampler = KSampler(self.model, "dpm_2", device=self.device) + elif self.sampler_name == "k_dpmpp_2_a": + self.sampler = KSampler( + self.model, "dpmpp_2s_ancestral", device=self.device + ) + elif self.sampler_name == "k_dpmpp_2": + self.sampler = KSampler(self.model, "dpmpp_2m", device=self.device) + elif self.sampler_name == "k_euler_a": + self.sampler = KSampler(self.model, "euler_ancestral", device=self.device) + elif self.sampler_name == "k_euler": + self.sampler = 
KSampler(self.model, "euler", device=self.device) + elif self.sampler_name == "k_heun": + self.sampler = KSampler(self.model, "heun", device=self.device) + elif self.sampler_name == "k_lms": + self.sampler = KSampler(self.model, "lms", device=self.device) else: - msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms' + msg = f">> Unsupported Sampler: {self.sampler_name}, Defaulting to plms" self.sampler = PLMSSampler(self.model, device=self.device) print(msg) @@ -1090,51 +1160,59 @@ class Generate: if self.sampler_name in scheduler_map: sampler_class = scheduler_map[self.sampler_name] - msg = f'>> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})' + msg = ( + f">> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})" + ) self.sampler = sampler_class.from_config(self.model.scheduler.config) else: - msg = (f'>> Unsupported Sampler: {self.sampler_name} ' - f'Defaulting to {default}') + msg = ( + f">> Unsupported Sampler: {self.sampler_name} " + f"Defaulting to {default}" + ) self.sampler = default print(msg) - if not hasattr(self.sampler, 'uses_inpainting_model'): + if not hasattr(self.sampler, "uses_inpainting_model"): # FIXME: terrible kludge! self.sampler.uses_inpainting_model = lambda: False - def _load_img(self, img)->Image: + def _load_img(self, img) -> Image: if isinstance(img, Image.Image): image = img - print( - f'>> using provided input image of size {image.width}x{image.height}' - ) + print(f">> using provided input image of size {image.width}x{image.height}") elif isinstance(img, str): - assert os.path.exists(img), f'>> {img}: File not found' + assert os.path.exists(img), f">> {img}: File not found" image = Image.open(img) print( - f'>> loaded input image of size {image.width}x{image.height} from {img}' + f">> loaded input image of size {image.width}x{image.height} from {img}" ) else: image = Image.open(img) - print( - f'>> loaded input image of size {image.width}x{image.height}' - ) + print(f">> loaded input image of size {image.width}x{image.height}") image = ImageOps.exif_transpose(image) return image def _create_init_image(self, image: Image.Image, width, height, fit=True): - if image.mode != 'RGBA': - image = image.convert('RGBA') - image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image) + if image.mode != "RGBA": + image = image.convert("RGBA") + image = ( + self._fit_image(image, (width, height)) + if fit + else self._squeeze_image(image) + ) return image def _create_init_mask(self, image, width, height, fit=True): # convert into a black/white mask image = self._image_to_mask(image) - image = image.convert('RGB') - image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image) + image = image.convert("RGB") + image = ( + self._fit_image(image, (width, height)) + if fit + else self._squeeze_image(image) + ) return image # The mask is expected to have the region to be inpainted @@ -1142,10 +1220,10 @@ class Generate: # image with the transparent part black. 
def _image_to_mask(self, mask_image: Image.Image, invert=False) -> Image: # Obtain the mask from the transparency channel - if mask_image.mode == 'L': + if mask_image.mode == "L": mask = mask_image - elif mask_image.mode in ('RGB', 'P'): - mask = mask_image.convert('L') + elif mask_image.mode in ("RGB", "P"): + mask = mask_image.convert("L") else: # Obtain the mask from the transparency channel mask = Image.new(mode="L", size=mask_image.size, color=255) @@ -1154,16 +1232,20 @@ class Generate: mask = ImageOps.invert(mask) return mask - def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image: + def _txt2mask( + self, image: Image, text_mask: list, width, height, fit=True + ) -> Image: prompt = text_mask[0] - confidence_level = text_mask[1] if len(text_mask)>1 else 0.5 + confidence_level = text_mask[1] if len(text_mask) > 1 else 0.5 if self.txt2mask is None: - self.txt2mask = Txt2Mask(device = self.device) + self.txt2mask = Txt2Mask(device=self.device) segmented = self.txt2mask.segment(image, prompt) mask = segmented.to_mask(float(confidence_level)) - mask = mask.convert('RGB') - mask = self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask) + mask = mask.convert("RGB") + mask = ( + self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask) + ) return mask def _has_transparency(self, image): @@ -1180,8 +1262,8 @@ class Generate: return True return False - def _check_for_erasure(self, image:Image.Image)->bool: - if image.mode not in ('RGBA','RGB'): + def _check_for_erasure(self, image: Image.Image) -> bool: + if image.mode not in ("RGBA", "RGB"): return False width, height = image.size pixdata = image.load() @@ -1190,20 +1272,20 @@ class Generate: for x in range(width): if pixdata[x, y][3] == 0: r, g, b, _ = pixdata[x, y] - if (r, g, b) != (0, 0, 0) and \ - (r, g, b) != (255, 255, 255): + if (r, g, b) != (0, 0, 0) and (r, g, b) != (255, 255, 255): colored += 1 return colored == 0 - def _transparency_check_and_warning(self,image, mask, force_outpaint=False): + def _transparency_check_and_warning(self, image, mask, force_outpaint=False): if not mask: print( - '>> Initial image has transparent areas. Will inpaint in these regions.') + ">> Initial image has transparent areas. Will inpaint in these regions." + ) if (not force_outpaint) and self._check_for_erasure(image): print( - '>> WARNING: Colors underneath the transparent region seem to have been erased.\n', - '>> Inpainting will be suboptimal. Please preserve the colors when making\n', - '>> a transparency mask, or provide mask explicitly using --init_mask (-M).' + ">> WARNING: Colors underneath the transparent region seem to have been erased.\n", + ">> Inpainting will be suboptimal. Please preserve the colors when making\n", + ">> a transparency mask, or provide mask explicitly using --init_mask (-M).", ) def _squeeze_image(self, image): @@ -1214,13 +1296,11 @@ class Generate: def _fit_image(self, image, max_dimensions): w, h = max_dimensions - print( - f'>> image will be resized to fit inside a box {w}x{h} in size.' 
- ) + print(f">> image will be resized to fit inside a box {w}x{h} in size.") # note that InitImageResizer does the multiple of 64 truncation internally image = InitImageResizer(image).resize(width=w, height=h) print( - f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}' + f">> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}" ) return image @@ -1232,30 +1312,32 @@ class Generate: if h != height or w != width: if log: print( - f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}' + f">> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}" ) height = h width = w resize_needed = True return width, height, resize_needed - def _has_cuda(self): - return self.device.type == 'cuda' + return self.device.type == "cuda" - def write_intermediate_images(self,modulus,path): + def write_intermediate_images(self, modulus, path): counter = -1 if not os.path.exists(path): os.makedirs(path) + def callback(img): nonlocal counter counter += 1 if counter % modulus != 0: - return; + return image = self.sample_to_image(img) - image.save(os.path.join(path,f'{counter:03}.png'),'PNG') + image.save(os.path.join(path, f"{counter:03}.png"), "PNG") + return callback + def _pairwise(iterable): "s -> (s0, s1), (s2, s3), (s4, s5), ..." a = iter(iterable) diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py index 32c6d816be..d56984caf3 100644 --- a/ldm/invoke/CLI.py +++ b/ldm/invoke/CLI.py @@ -1,9 +1,8 @@ import os import re -import sys import shlex +import sys import traceback - from argparse import Namespace from pathlib import Path from typing import Optional, Union @@ -11,41 +10,47 @@ from typing import Optional, Union if sys.platform == "darwin": os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" -from ldm.invoke.globals import Globals +import click # type: ignore +import pyparsing # type: ignore + +import ldm.invoke from ldm.generate import Generate -from ldm.invoke.prompt_parser import PromptParser -from ldm.invoke.readline import get_completer, Completer -from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png -from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata +from ldm.invoke.args import (Args, dream_cmd_from_png, metadata_dumps, + metadata_from_png) +from ldm.invoke.globals import Globals from ldm.invoke.image_util import make_grid from ldm.invoke.log import write_log from ldm.invoke.model_manager import ModelManager - -import click # type: ignore -import ldm.invoke -import pyparsing # type: ignore +from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata +from ldm.invoke.prompt_parser import PromptParser +from ldm.invoke.readline import Completer, get_completer # global used in multiple functions (fix) infile = None + def main(): """Initialize command-line parsers and the diffusion model""" global infile - opt = Args() + opt = Args() args = opt.parse_args() if not args: sys.exit(-1) if args.laion400m: - print('--laion400m flag has been deprecated. Please use --model laion400m instead.') + print( + "--laion400m flag has been deprecated. Please use --model laion400m instead." + ) sys.exit(-1) if args.weights: - print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.') + print( + "--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead." 
+ ) sys.exit(-1) if args.max_loaded_models is not None: if args.max_loaded_models <= 0: - print('--max_loaded_models must be >= 1; using 1') + print("--max_loaded_models must be >= 1; using 1") args.max_loaded_models = 1 # alert - setting a few globals here @@ -55,36 +60,42 @@ def main(): Globals.disable_xformers = not args.xformers Globals.ckpt_convert = args.ckpt_convert - print(f'>> Internet connectivity is {Globals.internet_available}') + print(f">> Internet connectivity is {Globals.internet_available}") if not args.conf: - config_file = os.path.join(Globals.root,'configs','models.yaml') + config_file = os.path.join(Globals.root, "configs", "models.yaml") if not os.path.exists(config_file): - report_model_error(opt, FileNotFoundError(f"The file {config_file} could not be found.")) + report_model_error( + opt, FileNotFoundError(f"The file {config_file} could not be found.") + ) - print(f'>> {ldm.invoke.__app_name__}, version {ldm.invoke.__version__}') + print(f">> {ldm.invoke.__app_name__}, version {ldm.invoke.__version__}") print(f'>> InvokeAI runtime directory is "{Globals.root}"') # loading here to avoid long delays on startup - from ldm.generate import Generate - # these two lines prevent a horrible warning message from appearing # when the frozen CLIP tokenizer is imported import transformers # type: ignore + + from ldm.generate import Generate + transformers.logging.set_verbosity_error() import diffusers + diffusers.logging.set_verbosity_error() # Loading Face Restoration and ESRGAN Modules - gfpgan,codeformer,esrgan = load_face_restoration(opt) + gfpgan, codeformer, esrgan = load_face_restoration(opt) # normalize the config directory relative to root if not os.path.isabs(opt.conf): - opt.conf = os.path.normpath(os.path.join(Globals.root,opt.conf)) + opt.conf = os.path.normpath(os.path.join(Globals.root, opt.conf)) if opt.embeddings: if not os.path.isabs(opt.embedding_path): - embedding_path = os.path.normpath(os.path.join(Globals.root,opt.embedding_path)) + embedding_path = os.path.normpath( + os.path.join(Globals.root, opt.embedding_path) + ) else: embedding_path = opt.embedding_path else: @@ -97,35 +108,35 @@ def main(): if opt.infile: try: if os.path.isfile(opt.infile): - infile = open(opt.infile, 'r', encoding='utf-8') - elif opt.infile == '-': # stdin + infile = open(opt.infile, "r", encoding="utf-8") + elif opt.infile == "-": # stdin infile = sys.stdin else: - raise FileNotFoundError(f'{opt.infile} not found.') + raise FileNotFoundError(f"{opt.infile} not found.") except (FileNotFoundError, IOError) as e: - print(f'{e}. Aborting.') + print(f"{e}. Aborting.") sys.exit(-1) # creating a Generate object: try: gen = Generate( - conf = opt.conf, - model = opt.model, - sampler_name = opt.sampler_name, - embedding_path = embedding_path, - full_precision = opt.full_precision, - precision = opt.precision, + conf=opt.conf, + model=opt.model, + sampler_name=opt.sampler_name, + embedding_path=embedding_path, + full_precision=opt.full_precision, + precision=opt.precision, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan, free_gpu_mem=opt.free_gpu_mem, safety_checker=opt.safety_checker, max_loaded_models=opt.max_loaded_models, - ) + ) except (FileNotFoundError, TypeError, AssertionError) as e: - report_model_error(opt,e) + report_model_error(opt, e) except (IOError, KeyError) as e: - print(f'{e}. Aborting.') + print(f"{e}. 
Aborting.") sys.exit(-1) if opt.seamless: @@ -160,11 +171,14 @@ def main(): try: main_loop(gen, opt) except KeyboardInterrupt: - print(f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}') + print( + f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}' + ) except Exception: print(">> An error occurred:") traceback.print_exc() + # TODO: main_loop() has gotten busy. Needs to be refactored. def main_loop(gen, opt): """prompt/read/execute loop""" @@ -177,23 +191,22 @@ def main_loop(gen, opt): # The readline completer reads history from the .dream_history file located in the # output directory specified at the time of script launch. We do not currently support # changing the history file midstream when the output directory is changed. - completer = get_completer(opt, models=gen.model_manager.list_models()) + completer = get_completer(opt, models=gen.model_manager.list_models()) set_default_output_dir(opt, completer) if gen.model: add_embedding_terms(gen, completer) - output_cntr = completer.get_current_history_length()+1 + output_cntr = completer.get_current_history_length() + 1 # os.pathconf is not available on Windows - if hasattr(os, 'pathconf'): - path_max = os.pathconf(opt.outdir, 'PC_PATH_MAX') - name_max = os.pathconf(opt.outdir, 'PC_NAME_MAX') + if hasattr(os, "pathconf"): + path_max = os.pathconf(opt.outdir, "PC_PATH_MAX") + name_max = os.pathconf(opt.outdir, "PC_NAME_MAX") else: path_max = 260 name_max = 255 while not done: - - operation = 'generate' + operation = "generate" try: command = get_next_command(infile, gen.model_name) @@ -206,17 +219,17 @@ def main_loop(gen, opt): if not command.strip(): continue - if command.startswith(('#', '//')): + if command.startswith(("#", "//")): continue - if len(command.strip()) == 1 and command.startswith('q'): + if len(command.strip()) == 1 and command.startswith("q"): done = True break - if not command.startswith('!history'): + if not command.startswith("!history"): completer.add_history(command) - if command.startswith('!'): + if command.startswith("!"): command, operation = do_command(command, gen, opt, completer) if operation is None: @@ -228,14 +241,14 @@ def main_loop(gen, opt): if opt.init_img: try: if not opt.prompt: - oldargs = metadata_from_png(opt.init_img) + oldargs = metadata_from_png(opt.init_img) opt.prompt = oldargs.prompt print(f'>> Retrieved old prompt "{opt.prompt}" from {opt.init_img}') except (OSError, AttributeError, KeyError): pass if len(opt.prompt) == 0: - opt.prompt = '' + opt.prompt = "" # width and height are set by model if not specified if not opt.width: @@ -244,36 +257,35 @@ def main_loop(gen, opt): opt.height = gen.height # retrieve previous value of init image if requested - if opt.init_img is not None and re.match('^-\\d+$', opt.init_img): + if opt.init_img is not None and re.match("^-\\d+$", opt.init_img): try: opt.init_img = last_results[int(opt.init_img)][0] - print(f'>> Reusing previous image {opt.init_img}') + print(f">> Reusing previous image {opt.init_img}") except IndexError: - print( - f'>> No previous initial image at position {opt.init_img} found') + print(f">> No previous initial image at position {opt.init_img} found") opt.init_img = None continue # the outdir can change with each command, so we adjust it here - set_default_output_dir(opt,completer) + set_default_output_dir(opt, completer) # try to relativize pathnames - for attr in ('init_img','init_mask','init_color'): - 
if getattr(opt,attr) and not os.path.exists(getattr(opt,attr)): - basename = getattr(opt,attr) - path = os.path.join(opt.outdir,basename) - setattr(opt,attr,path) + for attr in ("init_img", "init_mask", "init_color"): + if getattr(opt, attr) and not os.path.exists(getattr(opt, attr)): + basename = getattr(opt, attr) + path = os.path.join(opt.outdir, basename) + setattr(opt, attr, path) # retrieve previous value of seed if requested # Exception: for postprocess operations negative seed values # mean "discard the original seed and generate a new one" # (this is a non-obvious hack and needs to be reworked) - if opt.seed is not None and opt.seed < 0 and operation != 'postprocess': + if opt.seed is not None and opt.seed < 0 and operation != "postprocess": try: opt.seed = last_results[opt.seed][1] - print(f'>> Reusing previous seed {opt.seed}') + print(f">> Reusing previous seed {opt.seed}") except IndexError: - print(f'>> No previous seed at position {opt.seed} found') + print(f">> No previous seed at position {opt.seed} found") opt.seed = None continue @@ -283,13 +295,13 @@ def main_loop(gen, opt): if opt.with_variations is not None: opt.with_variations = split_variations(opt.with_variations) - if opt.prompt_as_dir and operation == 'generate': + if opt.prompt_as_dir and operation == "generate": # sanitize the prompt to a valid folder name - subdir = path_filter.sub('_', opt.prompt)[:name_max].rstrip(' .') + subdir = path_filter.sub("_", opt.prompt)[:name_max].rstrip(" .") # truncate path to maximum allowed length # 39 is the length of '######.##########.##########-##.png', plus two separators and a NUL - subdir = subdir[:(path_max - 39 - len(os.path.abspath(opt.outdir)))] + subdir = subdir[: (path_max - 39 - len(os.path.abspath(opt.outdir)))] current_outdir = os.path.join(opt.outdir, subdir) print('Writing files to directory: "' + current_outdir + '"') @@ -305,14 +317,26 @@ def main_loop(gen, opt): # Here is where the images are actually generated! 
last_results = [] try: - file_writer = PngWriter(current_outdir) - results = [] # list of filename, prompt pairs - grid_images = dict() # seed -> Image, only used if `opt.grid` + file_writer = PngWriter(current_outdir) + results = [] # list of filename, prompt pairs + grid_images = dict() # seed -> Image, only used if `opt.grid` prior_variations = opt.with_variations or [] prefix = file_writer.unique_prefix() - step_callback = make_step_callback(gen, opt, prefix) if opt.save_intermediates > 0 else None + step_callback = ( + make_step_callback(gen, opt, prefix) + if opt.save_intermediates > 0 + else None + ) - def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None, attention_maps_image=None): + def image_writer( + image, + seed, + upscaled=False, + first_seed=None, + use_prefix=None, + prompt_in=None, + attention_maps_image=None, + ): # note the seed is the seed of the current image # the first_seed is the original seed that noise is added to # when the -v switch is used to generate variations @@ -323,25 +347,31 @@ def main_loop(gen, opt): if opt.grid: grid_images[seed] = image - elif operation == 'mask': - filename = f'{prefix}.{use_prefix}.{seed}.png' + elif operation == "mask": + filename = f"{prefix}.{use_prefix}.{seed}.png" tm = opt.text_mask[0] - th = opt.text_mask[1] if len(opt.text_mask)>1 else 0.5 - formatted_dream_prompt = f'!mask {opt.input_file_path} -tm {tm} {th}' + th = opt.text_mask[1] if len(opt.text_mask) > 1 else 0.5 + formatted_dream_prompt = ( + f"!mask {opt.input_file_path} -tm {tm} {th}" + ) path = file_writer.save_image_and_prompt_to_png( - image = image, - dream_prompt = formatted_dream_prompt, - metadata = {}, - name = filename, - compress_level = opt.png_compression, + image=image, + dream_prompt=formatted_dream_prompt, + metadata={}, + name=filename, + compress_level=opt.png_compression, ) results.append([path, formatted_dream_prompt]) else: if use_prefix is not None: prefix = use_prefix - postprocessed = upscaled if upscaled else operation=='postprocess' - opt.prompt = gen.huggingface_concepts_library.replace_triggers_with_concepts(opt.prompt or prompt_in) # to avoid the problem of non-unique concept triggers + postprocessed = upscaled if upscaled else operation == "postprocess" + opt.prompt = ( + gen.huggingface_concepts_library.replace_triggers_with_concepts( + opt.prompt or prompt_in + ) + ) # to avoid the problem of non-unique concept triggers filename, formatted_dream_prompt = prepare_image_metadata( opt, prefix, @@ -349,23 +379,30 @@ def main_loop(gen, opt): operation, prior_variations, postprocessed, - first_seed + first_seed, ) path = file_writer.save_image_and_prompt_to_png( - image = image, - dream_prompt = formatted_dream_prompt, - metadata = metadata_dumps( + image=image, + dream_prompt=formatted_dream_prompt, + metadata=metadata_dumps( opt, - seeds = [seed if opt.variation_amount==0 and len(prior_variations)==0 else first_seed], - model_hash = gen.model_hash, + seeds=[ + seed + if opt.variation_amount == 0 + and len(prior_variations) == 0 + else first_seed + ], + model_hash=gen.model_hash, ), - name = filename, - compress_level = opt.png_compression, + name=filename, + compress_level=opt.png_compression, ) # update rfc metadata - if operation == 'postprocess': - tool = re.match('postprocess:(\w+)',opt.last_operation).groups()[0] + if operation == "postprocess": + tool = re.match( + "postprocess:(\w+)", opt.last_operation + ).groups()[0] add_postprocessing_to_metadata( opt, opt.input_file_path, @@ -379,49 
+416,51 @@ def main_loop(gen, opt): results.append([path, formatted_dream_prompt]) # so that the seed autocompletes (on linux|mac when -S or --seed specified - if completer and operation == 'generate': + if completer and operation == "generate": completer.add_seed(seed) completer.add_seed(first_seed) last_results.append([path, seed]) - if operation == 'generate': - catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts - opt.last_operation='generate' + if operation == "generate": + catch_ctrl_c = ( + infile is None + ) # if running interactively, we catch keyboard interrupts + opt.last_operation = "generate" try: gen.prompt2image( image_callback=image_writer, step_callback=step_callback, catch_interrupts=catch_ctrl_c, - **vars(opt) + **vars(opt), ) except (PromptParser.ParsingException, pyparsing.ParseException) as e: - print('** An error occurred while processing your prompt **') - print(f'** {str(e)} **') - elif operation == 'postprocess': - print(f'>> fixing {opt.prompt}') - opt.last_operation = do_postprocess(gen,opt,image_writer) + print("** An error occurred while processing your prompt **") + print(f"** {str(e)} **") + elif operation == "postprocess": + print(f">> fixing {opt.prompt}") + opt.last_operation = do_postprocess(gen, opt, image_writer) - elif operation == 'mask': - print(f'>> generating masks from {opt.prompt}') + elif operation == "mask": + print(f">> generating masks from {opt.prompt}") do_textmask(gen, opt, image_writer) if opt.grid and len(grid_images) > 0: - grid_img = make_grid(list(grid_images.values())) + grid_img = make_grid(list(grid_images.values())) grid_seeds = list(grid_images.keys()) first_seed = last_results[0][1] - filename = f'{prefix}.{first_seed}.png' - formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed,grid=True,iterations=len(grid_images)) - formatted_dream_prompt += f' # {grid_seeds}' + filename = f"{prefix}.{first_seed}.png" + formatted_dream_prompt = opt.dream_prompt_str( + seed=first_seed, grid=True, iterations=len(grid_images) + ) + formatted_dream_prompt += f" # {grid_seeds}" metadata = metadata_dumps( - opt, - seeds = grid_seeds, - model_hash = gen.model_hash - ) + opt, seeds=grid_seeds, model_hash=gen.model_hash + ) path = file_writer.save_image_and_prompt_to_png( - image = grid_img, - dream_prompt = formatted_dream_prompt, - metadata = metadata, - name = filename + image=grid_img, + dream_prompt=formatted_dream_prompt, + metadata=metadata, + name=filename, ) results = [[path, formatted_dream_prompt]] @@ -433,286 +472,321 @@ def main_loop(gen, opt): print(e) continue - print('Outputs:') - log_path = os.path.join(current_outdir, 'invoke_log') - output_cntr = write_log(results, log_path ,('txt', 'md'), output_cntr) + print("Outputs:") + log_path = os.path.join(current_outdir, "invoke_log") + output_cntr = write_log(results, log_path, ("txt", "md"), output_cntr) print() + print( + f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}' + ) - print(f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}') # TO DO: remove repetitive code and the awkward command.replace() trope # Just do a simple parse of the command! 
-def do_command(command:str, gen, opt:Args, completer) -> tuple: +def do_command(command: str, gen, opt: Args, completer) -> tuple: global infile - operation = 'generate' # default operation, alternative is 'postprocess' + operation = "generate" # default operation, alternative is 'postprocess' - if command.startswith('!dream'): # in case a stored prompt still contains the !dream command - command = command.replace('!dream ','',1) + if command.startswith( + "!dream" + ): # in case a stored prompt still contains the !dream command + command = command.replace("!dream ", "", 1) - elif command.startswith('!fix'): - command = command.replace('!fix ','',1) - operation = 'postprocess' + elif command.startswith("!fix"): + command = command.replace("!fix ", "", 1) + operation = "postprocess" - elif command.startswith('!mask'): - command = command.replace('!mask ','',1) - operation = 'mask' + elif command.startswith("!mask"): + command = command.replace("!mask ", "", 1) + operation = "mask" - elif command.startswith('!switch'): - model_name = command.replace('!switch ','',1) + elif command.startswith("!switch"): + model_name = command.replace("!switch ", "", 1) try: gen.set_model(model_name) add_embedding_terms(gen, completer) except KeyError as e: print(str(e)) except Exception as e: - report_model_error(opt,e) + report_model_error(opt, e) completer.add_history(command) operation = None - elif command.startswith('!models'): + elif command.startswith("!models"): gen.model_manager.print_models() completer.add_history(command) operation = None - elif command.startswith('!import'): + elif command.startswith("!import"): path = shlex.split(command) if len(path) < 2: - print('** please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1') + print( + "** please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1" + ) else: import_model(path[1], gen, opt, completer) completer.add_history(command) operation = None - elif command.startswith('!convert'): + elif command.startswith("!convert"): path = shlex.split(command) if len(path) < 2: - print('** please provide the path to a .ckpt or .safetensors model') + print("** please provide the path to a .ckpt or .safetensors model") elif not os.path.exists(path[1]): - print(f'** {path[1]}: model not found') + print(f"** {path[1]}: model not found") else: optimize_model(path[1], gen, opt, completer) completer.add_history(command) operation = None - - elif command.startswith('!optimize'): + elif command.startswith("!optimize"): path = shlex.split(command) if len(path) < 2: - print('** please provide an installed model name') + print("** please provide an installed model name") elif not path[1] in gen.model_manager.list_models(): - print(f'** {path[1]}: model not found') + print(f"** {path[1]}: model not found") else: optimize_model(path[1], gen, opt, completer) completer.add_history(command) operation = None - elif command.startswith('!edit'): + elif command.startswith("!edit"): path = shlex.split(command) if len(path) < 2: - print('** please provide the name of a model') + print("** please provide the name of a model") else: edit_model(path[1], gen, opt, completer) completer.add_history(command) operation = None - elif command.startswith('!del'): + elif command.startswith("!del"): path = shlex.split(command) if len(path) < 2: - print('** please provide the name 
of a model') + print("** please provide the name of a model") else: del_config(path[1], gen, opt, completer) completer.add_history(command) operation = None - elif command.startswith('!fetch'): - file_path = command.replace('!fetch','',1).strip() - retrieve_dream_command(opt,file_path,completer) + elif command.startswith("!fetch"): + file_path = command.replace("!fetch", "", 1).strip() + retrieve_dream_command(opt, file_path, completer) completer.add_history(command) operation = None - elif command.startswith('!replay'): - file_path = command.replace('!replay','',1).strip() + elif command.startswith("!replay"): + file_path = command.replace("!replay", "", 1).strip() if infile is None and os.path.isfile(file_path): - infile = open(file_path, 'r', encoding='utf-8') + infile = open(file_path, "r", encoding="utf-8") completer.add_history(command) operation = None - elif command.startswith('!history'): + elif command.startswith("!trigger"): + print("Embedding trigger strings: ", ", ".join(gen.embedding_trigger_strings)) + operation = None + + elif command.startswith("!history"): completer.show_history() operation = None - elif command.startswith('!search'): - search_str = command.replace('!search','',1).strip() + elif command.startswith("!search"): + search_str = command.replace("!search", "", 1).strip() completer.show_history(search_str) operation = None - elif command.startswith('!clear'): + elif command.startswith("!clear"): completer.clear_history() operation = None - elif re.match('^!(\d+)',command): - command_no = re.match('^!(\d+)',command).groups()[0] - command = completer.get_line(int(command_no)) + elif re.match("^!(\d+)", command): + command_no = re.match("^!(\d+)", command).groups()[0] + command = completer.get_line(int(command_no)) completer.set_line(command) operation = None else: # not a recognized command, so give the --help text - command = '-h' + command = "-h" return command, operation -def set_default_output_dir(opt:Args, completer:Completer): - ''' + +def set_default_output_dir(opt: Args, completer: Completer): + """ If opt.outdir is relative, we add the root directory to it normalize the outdir relative to root and make sure it exists. - ''' + """ if not os.path.isabs(opt.outdir): - opt.outdir=os.path.normpath(os.path.join(Globals.root,opt.outdir)) + opt.outdir = os.path.normpath(os.path.join(Globals.root, opt.outdir)) if not os.path.exists(opt.outdir): os.makedirs(opt.outdir) completer.set_default_dir(opt.outdir) def import_model(model_path: str, gen, opt, completer): - ''' + """ model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or (3) a huggingface repository id - ''' + """ model_name = None - if model_path.startswith(('http:','https:','ftp:')): + if model_path.startswith(("http:", "https:", "ftp:")): model_name = import_ckpt_model(model_path, gen, opt, completer) - elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path): + elif ( + os.path.exists(model_path) + and model_path.endswith((".ckpt", ".safetensors")) + and os.path.isfile(model_path) + ): model_name = import_ckpt_model(model_path, gen, opt, completer) elif os.path.isdir(model_path): - # Allow for a directory containing multiple models. - models = list(Path(model_path).rglob('*.ckpt')) + list(Path(model_path).rglob('*.safetensors')) + models = list(Path(model_path).rglob("*.ckpt")) + list( + Path(model_path).rglob("*.safetensors") + ) if models: # Only the last model name will be used below. 
for model in sorted(models): - - if click.confirm(f'Import {model.stem} ?', default=True): + if click.confirm(f"Import {model.stem} ?", default=True): model_name = import_ckpt_model(model, gen, opt, completer) print() else: model_name = import_diffuser_model(Path(model_path), gen, opt, completer) - elif re.match(r'^[\w.+-]+/[\w.+-]+$', model_path): + elif re.match(r"^[\w.+-]+/[\w.+-]+$", model_path): model_name = import_diffuser_model(model_path, gen, opt, completer) else: - print(f'** {model_path} is neither the path to a .ckpt file nor a diffusers repository id. Can\'t import.') + print( + f"** {model_path} is neither the path to a .ckpt file nor a diffusers repository id. Can't import." + ) if not model_name: return if not _verify_load(model_name, gen): - print('** model failed to load. Discarding configuration entry') + print("** model failed to load. Discarding configuration entry") gen.model_manager.del_model(model_name) return - if input('Make this the default model? [n] ').strip() in ('y','Y'): + if input("Make this the default model? [n] ").strip() in ("y", "Y"): gen.model_manager.set_default_model(model_name) gen.model_manager.commit(opt.conf) completer.update_models(gen.model_manager.list_models()) - print(f'>> {model_name} successfully installed') + print(f">> {model_name} successfully installed") -def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) -> Optional[str]: + +def import_diffuser_model( + path_or_repo: Union[Path, str], gen, _, completer +) -> Optional[str]: manager = gen.model_manager default_name = Path(path_or_repo).stem - default_description = f'Imported model {default_name}' + default_description = f"Imported model {default_name}" model_name, model_description = _get_model_name_and_desc( manager, completer, model_name=default_name, - model_description=default_description + model_description=default_description, ) vae = None - if input('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] ').strip() in ('y','Y'): - vae = dict(repo_id='stabilityai/sd-vae-ft-mse') + if input( + 'Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? 
[n] ' + ).strip() in ("y", "Y"): + vae = dict(repo_id="stabilityai/sd-vae-ft-mse") if not manager.import_diffuser_model( - path_or_repo, - model_name = model_name, - vae = vae, - description = model_description): - print('** model failed to import') + path_or_repo, model_name=model_name, vae=vae, description=model_description + ): + print("** model failed to import") return None return model_name -def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]: + +def import_ckpt_model( + path_or_url: Union[Path, str], gen, opt, completer +) -> Optional[str]: manager = gen.model_manager default_name = Path(path_or_url).stem - default_description = f'Imported model {default_name}' + default_description = f"Imported model {default_name}" model_name, model_description = _get_model_name_and_desc( manager, completer, model_name=default_name, - model_description=default_description + model_description=default_description, ) config_file = None - default = Path(Globals.root,'configs/stable-diffusion/v1-inpainting-inference.yaml') \ - if re.search('inpaint',default_name, flags=re.IGNORECASE) \ - else Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml') + default = ( + Path(Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml") + if re.search("inpaint", default_name, flags=re.IGNORECASE) + else Path(Globals.root, "configs/stable-diffusion/v1-inference.yaml") + ) - completer.complete_extensions(('.yaml','.yml')) + completer.complete_extensions((".yaml", ".yml")) completer.set_line(str(default)) done = False while not done: - config_file = input('Configuration file for this model: ').strip() + config_file = input("Configuration file for this model: ").strip() done = os.path.exists(config_file) - completer.complete_extensions(('.ckpt','.safetensors')) + completer.complete_extensions((".ckpt", ".safetensors")) vae = None - default = Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt') + default = Path( + Globals.root, "models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt" + ) completer.set_line(str(default)) done = False while not done: - vae = input('VAE file for this model (leave blank for none): ').strip() or None + vae = input("VAE file for this model (leave blank for none): ").strip() or None done = (not vae) or os.path.exists(vae) completer.complete_extensions(None) if not manager.import_ckpt_model( - path_or_url, - config = config_file, - vae = vae, - model_name = model_name, - model_description = model_description, - commit_to_conf = opt.conf, + path_or_url, + config=config_file, + vae=vae, + model_name=model_name, + model_description=model_description, + commit_to_conf=opt.conf, ): - print('** model failed to import') + print("** model failed to import") return None return model_name -def _verify_load(model_name:str, gen)->bool: - print('>> Verifying that new model loads...') + +def _verify_load(model_name: str, gen) -> bool: + print(">> Verifying that new model loads...") current_model = gen.model_name if not gen.model_manager.get_model(model_name): return False - do_switch = input('Keep model loaded? [y] ') - if len(do_switch)==0 or do_switch[0] in ('y','Y'): + do_switch = input("Keep model loaded? 
[y] ") + if len(do_switch) == 0 or do_switch[0] in ("y", "Y"): gen.set_model(model_name) else: - print('>> Restoring previous model') + print(">> Restoring previous model") gen.set_model(current_model) return True -def _get_model_name_and_desc(model_manager,completer,model_name:str='',model_description:str=''): - model_name = _get_model_name(model_manager.list_models(),completer,model_name) + +def _get_model_name_and_desc( + model_manager, completer, model_name: str = "", model_description: str = "" +): + model_name = _get_model_name(model_manager.list_models(), completer, model_name) completer.set_line(model_description) - model_description = input(f'Description for this model [{model_description}]: ').strip() or model_description + model_description = ( + input(f"Description for this model [{model_description}]: ").strip() + or model_description + ) return model_name, model_description -def _is_inpainting(model_name_or_path: str)->bool: - if re.search('inpaint',model_name_or_path, flags=re.IGNORECASE): - return not input('Is this an inpainting model? [y] ').startswith(('n','N')) + +def _is_inpainting(model_name_or_path: str) -> bool: + if re.search("inpaint", model_name_or_path, flags=re.IGNORECASE): + return not input("Is this an inpainting model? [y] ").startswith(("n", "N")) else: - return not input('Is this an inpainting model? [n] ').startswith(('y','Y')) + return not input("Is this an inpainting model? [n] ").startswith(("y", "Y")) + def optimize_model(model_name_or_path: str, gen, opt, completer): manager = gen.model_manager @@ -722,70 +796,76 @@ def optimize_model(model_name_or_path: str, gen, opt, completer): if model_name_or_path == gen.model_name: print("** Can't convert the active model. !switch to another model first. **") return - elif (model_info := manager.model_info(model_name_or_path)): - if 'weights' in model_info: - ckpt_path = Path(model_info['weights']) - original_config_file = Path(model_info['config']) + elif model_info := manager.model_info(model_name_or_path): + if "weights" in model_info: + ckpt_path = Path(model_info["weights"]) + original_config_file = Path(model_info["config"]) model_name = model_name_or_path - model_description = model_info['description'] + model_description = model_info["description"] else: - print(f'** {model_name_or_path} is not a legacy .ckpt weights file') + print(f"** {model_name_or_path} is not a legacy .ckpt weights file") return elif os.path.exists(model_name_or_path): ckpt_path = Path(model_name_or_path) model_name, model_description = _get_model_name_and_desc( - manager, - completer, - ckpt_path.stem, - f'Converted model {ckpt_path.stem}' + manager, completer, ckpt_path.stem, f"Converted model {ckpt_path.stem}" ) is_inpainting = _is_inpainting(model_name_or_path) original_config_file = Path( - 'configs', - 'stable-diffusion', - 'v1-inpainting-inference.yaml' if is_inpainting else 'v1-inference.yaml' + "configs", + "stable-diffusion", + "v1-inpainting-inference.yaml" if is_inpainting else "v1-inference.yaml", ) else: - print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file') + print( + f"** {model_name_or_path} is neither an existing model nor the path to a .ckpt file" + ) return if not ckpt_path.is_absolute(): - ckpt_path = Path(Globals.root,ckpt_path) + ckpt_path = Path(Globals.root, ckpt_path) if original_config_file and not original_config_file.is_absolute(): - original_config_file = Path(Globals.root,original_config_file) + original_config_file = Path(Globals.root, original_config_file) - 
diffuser_path = Path(Globals.root, 'models',Globals.converted_ckpts_dir,model_name) + diffuser_path = Path( + Globals.root, "models", Globals.converted_ckpts_dir, model_name + ) if diffuser_path.exists(): - print(f'** {model_name_or_path} is already optimized. Will not overwrite. If this is an error, please remove the directory {diffuser_path} and try again.') + print( + f"** {model_name_or_path} is already optimized. Will not overwrite. If this is an error, please remove the directory {diffuser_path} and try again." + ) return vae = None - if input('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] ').strip() in ('y','Y'): - vae = dict(repo_id='stabilityai/sd-vae-ft-mse') + if input( + 'Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] ' + ).strip() in ("y", "Y"): + vae = dict(repo_id="stabilityai/sd-vae-ft-mse") new_config = gen.model_manager.convert_and_import( ckpt_path, diffuser_path, model_name=model_name, model_description=model_description, - vae = vae, - original_config_file = original_config_file, + vae=vae, + original_config_file=original_config_file, commit_to_conf=opt.conf, ) if not new_config: return completer.update_models(gen.model_manager.list_models()) - if input(f'Load optimized model {model_name}? [y] ').strip() not in ('n','N'): + if input(f"Load optimized model {model_name}? [y] ").strip() not in ("n", "N"): gen.set_model(model_name) - response = input(f'Delete the original .ckpt file at ({ckpt_path} ? [n] ') - if response.startswith(('y','Y')): + response = input(f"Delete the original .ckpt file at ({ckpt_path} ? [n] ") + if response.startswith(("y", "Y")): ckpt_path.unlink(missing_ok=True) - print(f'{ckpt_path} deleted') + print(f"{ckpt_path} deleted") -def del_config(model_name:str, gen, opt, completer): + +def del_config(model_name: str, gen, opt, completer): current_model = gen.model_name if model_name == current_model: print("** Can't delete active model. !switch to another model first. **") @@ -794,31 +874,38 @@ def del_config(model_name:str, gen, opt, completer): print(f"** Unknown model {model_name}") return - if input(f'Remove {model_name} from the list of models known to InvokeAI? [y] ').strip().startswith(('n','N')): + if ( + input(f"Remove {model_name} from the list of models known to InvokeAI? [y] ") + .strip() + .startswith(("n", "N")) + ): return - delete_completely = input('Completely remove the model file or directory from disk? [n] ').startswith(('y','Y')) - gen.model_manager.del_model(model_name,delete_files=delete_completely) + delete_completely = input( + "Completely remove the model file or directory from disk? 
[n] " + ).startswith(("y", "Y")) + gen.model_manager.del_model(model_name, delete_files=delete_completely) gen.model_manager.commit(opt.conf) - print(f'** {model_name} deleted') + print(f"** {model_name} deleted") completer.update_models(gen.model_manager.list_models()) -def edit_model(model_name:str, gen, opt, completer): + +def edit_model(model_name: str, gen, opt, completer): manager = gen.model_manager if not (info := manager.model_info(model_name)): - print(f'** Unknown model {model_name}') + print(f"** Unknown model {model_name}") return - print(f'\n>> Editing model {model_name} from configuration file {opt.conf}') - new_name = _get_model_name(manager.list_models(),completer,model_name) + print(f"\n>> Editing model {model_name} from configuration file {opt.conf}") + new_name = _get_model_name(manager.list_models(), completer, model_name) for attribute in info.keys(): if type(info[attribute]) != str: continue - if attribute == 'format': + if attribute == "format": continue completer.set_line(info[attribute]) - info[attribute] = input(f'{attribute}: ') or info[attribute] + info[attribute] = input(f"{attribute}: ") or info[attribute] if new_name != model_name: manager.del_model(model_name) @@ -826,23 +913,26 @@ def edit_model(model_name:str, gen, opt, completer): # this does the update manager.add_model(new_name, info, True) - if input('Make this the default model? [n] ').startswith(('y','Y')): + if input("Make this the default model? [n] ").startswith(("y", "Y")): manager.set_default_model(new_name) manager.commit(opt.conf) completer.update_models(manager.list_models()) - print('>> Model successfully updated') + print(">> Model successfully updated") -def _get_model_name(existing_names,completer,default_name:str='')->str: + +def _get_model_name(existing_names, completer, default_name: str = "") -> str: done = False completer.set_line(default_name) while not done: - model_name = input(f'Short name for this model [{default_name}]: ').strip() - if len(model_name)==0: + model_name = input(f"Short name for this model [{default_name}]: ").strip() + if len(model_name) == 0: model_name = default_name - if not re.match('^[\w._+:/-]+$',model_name): - print('** model name must contain only words, digits and the characters "._+:/-" **') + if not re.match("^[\w._+:/-]+$", model_name): + print( + '** model name must contain only words, digits and the characters "._+:/-" **' + ) elif model_name != default_name and model_name in existing_names: - print(f'** the name {model_name} is already in use. Pick another.') + print(f"** the name {model_name} is already in use. Pick another.") else: done = True return model_name @@ -851,197 +941,223 @@ def _get_model_name(existing_names,completer,default_name:str='')->str: def do_textmask(gen, opt, callback): image_path = opt.prompt if not os.path.exists(image_path): - image_path = os.path.join(opt.outdir,image_path) - assert os.path.exists(image_path), '** "{opt.prompt}" not found. Please enter the name of an existing image file to mask **' - assert opt.text_mask is not None and len(opt.text_mask) >= 1, '** Please provide a text mask with -tm **' + image_path = os.path.join(opt.outdir, image_path) + assert os.path.exists( + image_path + ), '** "{opt.prompt}" not found. 
Please enter the name of an existing image file to mask **' + assert ( + opt.text_mask is not None and len(opt.text_mask) >= 1 + ), "** Please provide a text mask with -tm **" opt.input_file_path = image_path tm = opt.text_mask[0] - threshold = float(opt.text_mask[1]) if len(opt.text_mask) > 1 else 0.5 + threshold = float(opt.text_mask[1]) if len(opt.text_mask) > 1 else 0.5 gen.apply_textmask( - image_path = image_path, - prompt = tm, - threshold = threshold, - callback = callback, + image_path=image_path, + prompt=tm, + threshold=threshold, + callback=callback, ) -def do_postprocess (gen, opt, callback): - file_path = opt.prompt # treat the prompt as the file pathname + +def do_postprocess(gen, opt, callback): + file_path = opt.prompt # treat the prompt as the file pathname if opt.new_prompt is not None: opt.prompt = opt.new_prompt else: opt.prompt = None - if os.path.dirname(file_path) == '': #basename given - file_path = os.path.join(opt.outdir,file_path) + if os.path.dirname(file_path) == "": # basename given + file_path = os.path.join(opt.outdir, file_path) opt.input_file_path = file_path - tool=None + tool = None if opt.facetool_strength > 0: tool = opt.facetool elif opt.embiggen: - tool = 'embiggen' + tool = "embiggen" elif opt.upscale: - tool = 'upscale' + tool = "upscale" elif opt.out_direction: - tool = 'outpaint' + tool = "outpaint" elif opt.outcrop: - tool = 'outcrop' - opt.save_original = True # do not overwrite old image! - opt.last_operation = f'postprocess:{tool}' + tool = "outcrop" + opt.save_original = True # do not overwrite old image! + opt.last_operation = f"postprocess:{tool}" try: gen.apply_postprocessor( - image_path = file_path, - tool = tool, - facetool_strength = opt.facetool_strength, - codeformer_fidelity = opt.codeformer_fidelity, - save_original = opt.save_original, - upscale = opt.upscale, - upscale_denoise_str = opt.esrgan_denoise_str, - out_direction = opt.out_direction, - outcrop = opt.outcrop, - callback = callback, - opt = opt, + image_path=file_path, + tool=tool, + facetool_strength=opt.facetool_strength, + codeformer_fidelity=opt.codeformer_fidelity, + save_original=opt.save_original, + upscale=opt.upscale, + upscale_denoise_str=opt.esrgan_denoise_str, + out_direction=opt.out_direction, + outcrop=opt.outcrop, + callback=callback, + opt=opt, ) except OSError: print(traceback.format_exc(), file=sys.stderr) - print(f'** {file_path}: file could not be read') + print(f"** {file_path}: file could not be read") return except (KeyError, AttributeError): print(traceback.format_exc(), file=sys.stderr) return return opt.last_operation -def add_postprocessing_to_metadata(opt,original_file,new_file,tool,command): - original_file = original_file if os.path.exists(original_file) else os.path.join(opt.outdir,original_file) - new_file = new_file if os.path.exists(new_file) else os.path.join(opt.outdir,new_file) + +def add_postprocessing_to_metadata(opt, original_file, new_file, tool, command): + original_file = ( + original_file + if os.path.exists(original_file) + else os.path.join(opt.outdir, original_file) + ) + new_file = ( + new_file if os.path.exists(new_file) else os.path.join(opt.outdir, new_file) + ) try: - meta = retrieve_metadata(original_file)['sd-metadata'] + meta = retrieve_metadata(original_file)["sd-metadata"] except AttributeError: try: - meta = retrieve_metadata(new_file)['sd-metadata'] + meta = retrieve_metadata(new_file)["sd-metadata"] except AttributeError: meta = {} - if 'image' not in meta: - meta = metadata_dumps(opt,seeds=[opt.seed])['image'] 
- meta['image'] = {} - img_data = meta.get('image') - pp = img_data.get('postprocessing',[]) or [] + if "image" not in meta: + meta = metadata_dumps(opt, seeds=[opt.seed])["image"] + meta["image"] = {} + img_data = meta.get("image") + pp = img_data.get("postprocessing", []) or [] pp.append( { - 'tool':tool, - 'dream_command':command, + "tool": tool, + "dream_command": command, } ) - meta['image']['postprocessing'] = pp - write_metadata(new_file,meta) + meta["image"]["postprocessing"] = pp + write_metadata(new_file, meta) + def prepare_image_metadata( - opt, - prefix, - seed, - operation='generate', - prior_variations=[], - postprocessed=False, - first_seed=None + opt, + prefix, + seed, + operation="generate", + prior_variations=[], + postprocessed=False, + first_seed=None, ): - if postprocessed and opt.save_original: - filename = choose_postprocess_name(opt,prefix,seed) + filename = choose_postprocess_name(opt, prefix, seed) else: wildcards = dict(opt.__dict__) - wildcards['prefix'] = prefix - wildcards['seed'] = seed + wildcards["prefix"] = prefix + wildcards["seed"] = seed try: filename = opt.fnformat.format(**wildcards) except KeyError as e: - print(f'** The filename format contains an unknown key \'{e.args[0]}\'. Will use {{prefix}}.{{seed}}.png\' instead') - filename = f'{prefix}.{seed}.png' + print( + f"** The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead" + ) + filename = f"{prefix}.{seed}.png" except IndexError: - print("** The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead") - filename = f'{prefix}.{seed}.png' + print( + "** The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead" + ) + filename = f"{prefix}.{seed}.png" if opt.variation_amount > 0: - first_seed = first_seed or seed - this_variation = [[seed, opt.variation_amount]] - opt.with_variations = prior_variations + this_variation + first_seed = first_seed or seed + this_variation = [[seed, opt.variation_amount]] + opt.with_variations = prior_variations + this_variation formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed) elif len(prior_variations) > 0: formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed) - elif operation == 'postprocess': - formatted_dream_prompt = '!fix '+opt.dream_prompt_str(seed=seed,prompt=opt.input_file_path) + elif operation == "postprocess": + formatted_dream_prompt = "!fix " + opt.dream_prompt_str( + seed=seed, prompt=opt.input_file_path + ) else: formatted_dream_prompt = opt.dream_prompt_str(seed=seed) - return filename,formatted_dream_prompt + return filename, formatted_dream_prompt -def choose_postprocess_name(opt,prefix,seed) -> str: - match = re.search('postprocess:(\w+)',opt.last_operation) + +def choose_postprocess_name(opt, prefix, seed) -> str: + match = re.search("postprocess:(\w+)", opt.last_operation) if match: - modifier = match.group(1) # will look like "gfpgan", "upscale", "outpaint" or "embiggen" + modifier = match.group( + 1 + ) # will look like "gfpgan", "upscale", "outpaint" or "embiggen" else: - modifier = 'postprocessed' + modifier = "postprocessed" - counter = 0 - filename = None + counter = 0 + filename = None available = False while not available: if counter == 0: - filename = f'{prefix}.{seed}.{modifier}.png' + filename = f"{prefix}.{seed}.{modifier}.png" else: - filename = f'{prefix}.{seed}.{modifier}-{counter:02d}.png' - available = not os.path.exists(os.path.join(opt.outdir,filename)) + filename = f"{prefix}.{seed}.{modifier}-{counter:02d}.png" 
+ available = not os.path.exists(os.path.join(opt.outdir, filename)) counter += 1 return filename -def get_next_command(infile=None, model_name='no model') -> str: # command string + +def get_next_command(infile=None, model_name="no model") -> str: # command string if infile is None: - command = input(f'({model_name}) invoke> ').strip() + command = input(f"({model_name}) invoke> ").strip() else: command = infile.readline() if not command: raise EOFError else: command = command.strip() - if len(command)>0: - print(f'#{command}') + if len(command) > 0: + print(f"#{command}") return command -def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan): - print('\n* --web was specified, starting web server...') - from invokeai.backend import InvokeAIWebServer - # Change working directory to the stable-diffusion directory - os.chdir( - os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) - ) - invoke_ai_web_server = InvokeAIWebServer(generate=gen, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan) +def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan): + print("\n* --web was specified, starting web server...") + from invokeai.backend import InvokeAIWebServer + + # Change working directory to the stable-diffusion directory + os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + + invoke_ai_web_server = InvokeAIWebServer( + generate=gen, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan + ) try: invoke_ai_web_server.run() except KeyboardInterrupt: pass -def add_embedding_terms(gen,completer): - ''' + +def add_embedding_terms(gen, completer): + """ Called after setting the model, updates the autocompleter with any terms loaded by the embedding manager. - ''' + """ trigger_strings = gen.model.textual_inversion_manager.get_all_trigger_strings() completer.add_embedding_terms(trigger_strings) + def split_variations(variations_string) -> list: # shotgun parsing, woo parts = [] broken = False # python doesn't have labeled loops... 
- for part in variations_string.split(','): - seed_and_weight = part.split(':') + for part in variations_string.split(","): + seed_and_weight = part.split(":") if len(seed_and_weight) != 2: print(f'** Could not parse with_variation part "{part}"') broken = True break try: - seed = int(seed_and_weight[0]) + seed = int(seed_and_weight[0]) weight = float(seed_and_weight[1]) except ValueError: print(f'** Could not parse with_variation part "{part}"') @@ -1055,40 +1171,48 @@ def split_variations(variations_string) -> list: else: return parts + def load_face_restoration(opt): try: gfpgan, codeformer, esrgan = None, None, None if opt.restore or opt.esrgan: from ldm.invoke.restoration import Restoration + restoration = Restoration() if opt.restore: - gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_model_path) + gfpgan, codeformer = restoration.load_face_restore_models( + opt.gfpgan_model_path + ) else: - print('>> Face restoration disabled') + print(">> Face restoration disabled") if opt.esrgan: esrgan = restoration.load_esrgan(opt.esrgan_bg_tile) else: - print('>> Upscaling disabled') + print(">> Upscaling disabled") else: - print('>> Face restoration and upscaling disabled') + print(">> Face restoration and upscaling disabled") except (ModuleNotFoundError, ImportError): print(traceback.format_exc(), file=sys.stderr) - print('>> You may need to install the ESRGAN and/or GFPGAN modules') - return gfpgan,codeformer,esrgan + print(">> You may need to install the ESRGAN and/or GFPGAN modules") + return gfpgan, codeformer, esrgan + def make_step_callback(gen, opt, prefix): - destination = os.path.join(opt.outdir,'intermediates',prefix) - os.makedirs(destination,exist_ok=True) - print(f'>> Intermediate images will be written into {destination}') + destination = os.path.join(opt.outdir, "intermediates", prefix) + os.makedirs(destination, exist_ok=True) + print(f">> Intermediate images will be written into {destination}") + def callback(img, step): - if step % opt.save_intermediates == 0 or step == opt.steps-1: - filename = os.path.join(destination,f'{step:04}.png') + if step % opt.save_intermediates == 0 or step == opt.steps - 1: + filename = os.path.join(destination, f"{step:04}.png") image = gen.sample_to_image(img) - image.save(filename,'PNG') + image.save(filename, "PNG") + return callback -def retrieve_dream_command(opt,command,completer): - ''' + +def retrieve_dream_command(opt, command, completer): + """ Given a full or partial path to a previously-generated image file, will retrieve and format the dream command used to generate the image, and pop it into the readline buffer (linux, Mac), or print out a comment @@ -1097,34 +1221,35 @@ def retrieve_dream_command(opt,command,completer): Given a wildcard path to a folder with image png files, will retrieve and format the dream command used to generate the images, and save them to a file commands.txt for further processing - ''' + """ if len(command) == 0: return tokens = command.split() - dir,basename = os.path.split(tokens[0]) + dir, basename = os.path.split(tokens[0]) if len(dir) == 0: - path = os.path.join(opt.outdir,basename) + path = os.path.join(opt.outdir, basename) else: path = tokens[0] if len(tokens) > 1: return write_commands(opt, path, tokens[1]) - cmd = '' + cmd = "" try: cmd = dream_cmd_from_png(path) except OSError: - print(f'## {tokens[0]}: file could not be read') + print(f"## {tokens[0]}: file could not be read") except (KeyError, AttributeError, IndexError): - print(f'## {tokens[0]}: file has no metadata') + 
print(f"## {tokens[0]}: file has no metadata") except: - print(f'## {tokens[0]}: file could not be processed') - if len(cmd)>0: + print(f"## {tokens[0]}: file could not be processed") + if len(cmd) > 0: completer.set_line(cmd) -def write_commands(opt, file_path:str, outfilepath:str): - dir,basename = os.path.split(file_path) + +def write_commands(opt, file_path: str, outfilepath: str): + dir, basename = os.path.split(file_path) try: paths = sorted(list(Path(dir).glob(basename))) except ValueError: @@ -1137,39 +1262,46 @@ def write_commands(opt, file_path:str, outfilepath:str): try: cmd = dream_cmd_from_png(path) except (KeyError, AttributeError, IndexError): - print(f'## {path}: file has no metadata') + print(f"## {path}: file has no metadata") except: - print(f'## {path}: file could not be processed') + print(f"## {path}: file could not be processed") if cmd: - commands.append(f'# {path}') + commands.append(f"# {path}") commands.append(cmd) - if len(commands)>0: - dir,basename = os.path.split(outfilepath) - if len(dir)==0: - outfilepath = os.path.join(opt.outdir,basename) - with open(outfilepath, 'w', encoding='utf-8') as f: - f.write('\n'.join(commands)) - print(f'>> File {outfilepath} with commands created') + if len(commands) > 0: + dir, basename = os.path.split(outfilepath) + if len(dir) == 0: + outfilepath = os.path.join(opt.outdir, basename) + with open(outfilepath, "w", encoding="utf-8") as f: + f.write("\n".join(commands)) + print(f">> File {outfilepath} with commands created") -def report_model_error(opt:Namespace, e:Exception): + +def report_model_error(opt: Namespace, e: Exception): print(f'** An error occurred while attempting to initialize the model: "{str(e)}"') - print('** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models.') - yes_to_all = os.environ.get('INVOKE_MODEL_RECONFIGURE') + print( + "** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models." + ) + yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE") if yes_to_all: - print('** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE') + print( + "** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE" + ) else: - response = input('Do you want to run invokeai-configure script to select and/or reinstall models? [y] ') - if response.startswith(('n', 'N')): + response = input( + "Do you want to run invokeai-configure script to select and/or reinstall models? [y] " + ) + if response.startswith(("n", "N")): return - print('invokeai-configure is launching....\n') + print("invokeai-configure is launching....\n") # Match arguments that were set on the CLI # only the arguments accepted by the configuration script are parsed root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else [] config = ["--config", opt.conf] if opt.conf is not None else [] previous_args = sys.argv - sys.argv = [ 'invokeai-configure' ] + sys.argv = ["invokeai-configure"] sys.argv.extend(root_dir) sys.argv.extend(config) if yes_to_all is not None: @@ -1177,21 +1309,24 @@ def report_model_error(opt:Namespace, e:Exception): sys.argv.append(arg) from ldm.invoke.config import invokeai_configure + invokeai_configure.main() - print('** InvokeAI will now restart') + print("** InvokeAI will now restart") sys.argv = previous_args - main() # would rather do a os.exec(), but doesn't exist? + main() # would rather do a os.exec(), but doesn't exist? 
sys.exit(0) -def check_internet()->bool: - ''' + +def check_internet() -> bool: + """ Return true if the internet is reachable. It does this by pinging huggingface.co. - ''' + """ import urllib.request - host = 'http://huggingface.co' + + host = "http://huggingface.co" try: - urllib.request.urlopen(host,timeout=1) + urllib.request.urlopen(host, timeout=1) return True except: return False diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py index d81de4f1ca..1bd1aa46ab 100644 --- a/ldm/invoke/args.py +++ b/ldm/invoke/args.py @@ -751,6 +751,9 @@ class Args(object): !fix applies upscaling/facefixing to a previously-generated image. invoke> !fix 0000045.4829112.png -G1 -U4 -ft codeformer + *embeddings* + invoke> !triggers -- return all trigger phrases contained in loaded embedding files + *History manipulation* !fetch retrieves the command used to generate an earlier image. Provide a directory wildcard and the name of a file to write and all the commands diff --git a/ldm/invoke/readline.py b/ldm/invoke/readline.py index f14af0714f..1e9b31ea8d 100644 --- a/ldm/invoke/readline.py +++ b/ldm/invoke/readline.py @@ -60,7 +60,7 @@ COMMANDS = ( '--text_mask','-tm', '!fix','!fetch','!replay','!history','!search','!clear', '!models','!switch','!import_model','!optimize_model','!convert_model','!edit_model','!del_model', - '!mask', + '!mask','!triggers', ) MODEL_COMMANDS = ( '!switch', diff --git a/ldm/modules/textual_inversion_manager.py b/ldm/modules/textual_inversion_manager.py index 2e61be6b12..8ca1a0bf5e 100644 --- a/ldm/modules/textual_inversion_manager.py +++ b/ldm/modules/textual_inversion_manager.py @@ -1,11 +1,12 @@ import os import traceback +from dataclasses import dataclass +from pathlib import Path from typing import Optional import torch -from dataclasses import dataclass from picklescan.scanner import scan_file_path -from transformers import CLIPTokenizer, CLIPTextModel +from transformers import CLIPTextModel, CLIPTokenizer from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary @@ -21,11 +22,14 @@ class TextualInversion: def embedding_vector_length(self) -> int: return self.embedding.shape[0] -class TextualInversionManager(): - def __init__(self, - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, - full_precision: bool=True): + +class TextualInversionManager: + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + full_precision: bool = True, + ): self.tokenizer = tokenizer self.text_encoder = text_encoder self.full_precision = full_precision @@ -38,47 +42,60 @@ class TextualInversionManager(): if concept_name in self.hf_concepts_library.concepts_loaded: continue trigger = self.hf_concepts_library.concept_to_trigger(concept_name) - if self.has_textual_inversion_for_trigger_string(trigger) \ - or self.has_textual_inversion_for_trigger_string(concept_name) \ - or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered - print(f'>> Loaded local embedding for trigger {concept_name}') + if ( + self.has_textual_inversion_for_trigger_string(trigger) + or self.has_textual_inversion_for_trigger_string(concept_name) + or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>") + ): # in case a token with literal angle brackets encountered + print(f">> Loaded local embedding for trigger {concept_name}") continue bin_file = self.hf_concepts_library.get_concept_model_path(concept_name) if not bin_file: continue - print(f'>> Loaded remote embedding for trigger {concept_name}') + 
print(f">> Loaded remote embedding for trigger {concept_name}") self.load_textual_inversion(bin_file) - self.hf_concepts_library.concepts_loaded[concept_name]=True + self.hf_concepts_library.concepts_loaded[concept_name] = True def get_all_trigger_strings(self) -> list[str]: return [ti.trigger_string for ti in self.textual_inversions] - def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False): - if str(ckpt_path).endswith('.DS_Store'): + def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool = False): + if str(ckpt_path).endswith(".DS_Store"): return try: scan_result = scan_file_path(ckpt_path) if scan_result.infected_files == 1: - print(f'\n### Security Issues Found in Model: {scan_result.issues_count}') - print('### For your safety, InvokeAI will not load this embed.') + print( + f"\n### Security Issues Found in Model: {scan_result.issues_count}" + ) + print("### For your safety, InvokeAI will not load this embed.") return except Exception: - print(f"### WARNING::: Invalid or corrupt embeddings found. Ignoring: {ckpt_path}") + ckpt_path = Path(ckpt_path) + print( + f"** Notice: {ckpt_path.parents[0].stem}/{ckpt_path.stem} is incompatible with this model" + ) return embedding_info = self._parse_embedding(ckpt_path) if embedding_info: try: - self._add_textual_inversion(embedding_info['name'], - embedding_info['embedding'], - defer_injecting_tokens=defer_injecting_tokens) + self._add_textual_inversion( + embedding_info["name"], + embedding_info["embedding"], + defer_injecting_tokens=defer_injecting_tokens, + ) except ValueError as e: print(f' | Ignoring incompatible embedding {embedding_info["name"]}') - print(f' | The error was {str(e)}') + print(f" | The error was {str(e)}") else: - print(f'>> Failed to load embedding located at {ckpt_path}. Unsupported file.') + print( + f">> Failed to load embedding located at {ckpt_path}. Unsupported file." + ) - def _add_textual_inversion(self, trigger_str, embedding, defer_injecting_tokens=False) -> TextualInversion: + def _add_textual_inversion( + self, trigger_str, embedding, defer_injecting_tokens=False + ) -> TextualInversion: """ Add a textual inversion to be recognised. :param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added. @@ -86,46 +103,59 @@ class TextualInversionManager(): :return: The token id for the added embedding, either existing or newly-added. """ if trigger_str in [ti.trigger_string for ti in self.textual_inversions]: - print(f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'") + print( + f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'" + ) return if not self.full_precision: embedding = embedding.half() if len(embedding.shape) == 1: embedding = embedding.unsqueeze(0) elif len(embedding.shape) > 2: - raise ValueError(f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2.") + raise ValueError( + f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2." 
+            )
         try:
-            ti = TextualInversion(
-                trigger_string=trigger_str,
-                embedding=embedding
-            )
+            ti = TextualInversion(trigger_string=trigger_str, embedding=embedding)
             if not defer_injecting_tokens:
                 self._inject_tokens_and_assign_embeddings(ti)
             self.textual_inversions.append(ti)
             return ti
         except ValueError as e:
-            if str(e).startswith('Warning'):
+            if str(e).startswith("Warning"):
                 print(f">> {str(e)}")
             else:
                 traceback.print_exc()
-                print(f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}.")
+                print(
+                    f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
+                )
                 raise
 
     def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int:
         if ti.trigger_token_id is not None:
-            raise ValueError(f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'")
+            raise ValueError(
+                f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'"
+            )
 
-        trigger_token_id = self._get_or_create_token_id_and_assign_embedding(ti.trigger_string, ti.embedding[0])
+        trigger_token_id = self._get_or_create_token_id_and_assign_embedding(
+            ti.trigger_string, ti.embedding[0]
+        )
 
         if ti.embedding_vector_length > 1:
             # for embeddings with vector length > 1
-            pad_token_strings = [ti.trigger_string + "-!pad-" + str(pad_index) for pad_index in range(1, ti.embedding_vector_length)]
+            pad_token_strings = [
+                ti.trigger_string + "-!pad-" + str(pad_index)
+                for pad_index in range(1, ti.embedding_vector_length)
+            ]
             # todo: batched UI for faster loading when vector length >2
-            pad_token_ids = [self._get_or_create_token_id_and_assign_embedding(pad_token_str, ti.embedding[1 + i]) \
-                             for (i, pad_token_str) in enumerate(pad_token_strings)]
+            pad_token_ids = [
+                self._get_or_create_token_id_and_assign_embedding(
+                    pad_token_str, ti.embedding[1 + i]
+                )
+                for (i, pad_token_str) in enumerate(pad_token_strings)
+            ]
         else:
             pad_token_ids = []
 
@@ -133,7 +163,6 @@ class TextualInversionManager():
         ti.pad_token_ids = pad_token_ids
         return ti.trigger_token_id
 
-
     def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
         try:
             ti = self.get_textual_inversion_for_trigger_string(trigger_string)
@@ -141,32 +170,43 @@ class TextualInversionManager():
         except StopIteration:
             return False
 
-
-    def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
-        return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)
-
+    def get_textual_inversion_for_trigger_string(
+        self, trigger_string: str
+    ) -> TextualInversion:
+        return next(
+            ti for ti in self.textual_inversions if ti.trigger_string == trigger_string
+        )
 
     def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
-        return next(ti for ti in self.textual_inversions if ti.trigger_token_id == token_id)
+        return next(
+            ti for ti in self.textual_inversions if ti.trigger_token_id == token_id
+        )
 
-    def create_deferred_token_ids_for_any_trigger_terms(self, prompt_string: str) -> list[int]:
+    def create_deferred_token_ids_for_any_trigger_terms(
+        self, prompt_string: str
+    ) -> list[int]:
         injected_token_ids = []
         for ti in self.textual_inversions:
            if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
                 if ti.embedding_vector_length > 1:
-                    print(f">> Preparing tokens for textual inversion {ti.trigger_string}...")
+                    print(
+                        f">> Preparing tokens for textual inversion {ti.trigger_string}..."
+ ) try: self._inject_tokens_and_assign_embeddings(ti) except ValueError as e: - print(f' | Ignoring incompatible embedding trigger {ti.trigger_string}') - print(f' | The error was {str(e)}') + print( + f" | Ignoring incompatible embedding trigger {ti.trigger_string}" + ) + print(f" | The error was {str(e)}") continue injected_token_ids.append(ti.trigger_token_id) injected_token_ids.extend(ti.pad_token_ids) return injected_token_ids - - def expand_textual_inversion_token_ids_if_necessary(self, prompt_token_ids: list[int]) -> list[int]: + def expand_textual_inversion_token_ids_if_necessary( + self, prompt_token_ids: list[int] + ) -> list[int]: """ Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes. @@ -181,20 +221,31 @@ class TextualInversionManager(): raise ValueError("prompt_token_ids must not start with bos_token_id") if prompt_token_ids[-1] == self.tokenizer.eos_token_id: raise ValueError("prompt_token_ids must not end with eos_token_id") - textual_inversion_trigger_token_ids = [ti.trigger_token_id for ti in self.textual_inversions] + textual_inversion_trigger_token_ids = [ + ti.trigger_token_id for ti in self.textual_inversions + ] prompt_token_ids = prompt_token_ids.copy() for i, token_id in reversed(list(enumerate(prompt_token_ids))): if token_id in textual_inversion_trigger_token_ids: - textual_inversion = next(ti for ti in self.textual_inversions if ti.trigger_token_id == token_id) - for pad_idx in range(0, textual_inversion.embedding_vector_length-1): - prompt_token_ids.insert(i+pad_idx+1, textual_inversion.pad_token_ids[pad_idx]) + textual_inversion = next( + ti + for ti in self.textual_inversions + if ti.trigger_token_id == token_id + ) + for pad_idx in range(0, textual_inversion.embedding_vector_length - 1): + prompt_token_ids.insert( + i + pad_idx + 1, textual_inversion.pad_token_ids[pad_idx] + ) return prompt_token_ids - - def _get_or_create_token_id_and_assign_embedding(self, token_str: str, embedding: torch.Tensor) -> int: + def _get_or_create_token_id_and_assign_embedding( + self, token_str: str, embedding: torch.Tensor + ) -> int: if len(embedding.shape) != 1: - raise ValueError("Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2") + raise ValueError( + "Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2" + ) existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str) if existing_token_id == self.tokenizer.unk_token_id: num_tokens_added = self.tokenizer.add_tokens(token_str) @@ -207,66 +258,78 @@ class TextualInversionManager(): token_id = self.tokenizer.convert_tokens_to_ids(token_str) if token_id == self.tokenizer.unk_token_id: raise RuntimeError(f"Unable to find token id for token '{token_str}'") - if self.text_encoder.get_input_embeddings().weight.data[token_id].shape != embedding.shape: - raise ValueError(f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}.") + if ( + self.text_encoder.get_input_embeddings().weight.data[token_id].shape + != embedding.shape + ): + raise ValueError( + f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}." 
+ ) self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding return token_id def _parse_embedding(self, embedding_file: str): - file_type = embedding_file.split('.')[-1] - if file_type == 'pt': + file_type = embedding_file.split(".")[-1] + if file_type == "pt": return self._parse_embedding_pt(embedding_file) - elif file_type == 'bin': + elif file_type == "bin": return self._parse_embedding_bin(embedding_file) else: - print(f'>> Not a recognized embedding file: {embedding_file}') + print(f">> Not a recognized embedding file: {embedding_file}") def _parse_embedding_pt(self, embedding_file): - embedding_ckpt = torch.load(embedding_file, map_location='cpu') + embedding_ckpt = torch.load(embedding_file, map_location="cpu") embedding_info = {} # Check if valid embedding file - if 'string_to_token' and 'string_to_param' in embedding_ckpt: - + if "string_to_token" and "string_to_param" in embedding_ckpt: # Catch variants that do not have the expected keys or values. try: - embedding_info['name'] = embedding_ckpt['name'] or os.path.basename(os.path.splitext(embedding_file)[0]) + embedding_info["name"] = embedding_ckpt["name"] or os.path.basename( + os.path.splitext(embedding_file)[0] + ) # Check num of embeddings and warn user only the first will be used - embedding_info['num_of_embeddings'] = len(embedding_ckpt["string_to_token"]) - if embedding_info['num_of_embeddings'] > 1: - print('>> More than 1 embedding found. Will use the first one') + embedding_info["num_of_embeddings"] = len( + embedding_ckpt["string_to_token"] + ) + if embedding_info["num_of_embeddings"] > 1: + print(">> More than 1 embedding found. Will use the first one") - embedding = list(embedding_ckpt['string_to_param'].values())[0] - except (AttributeError,KeyError): + embedding = list(embedding_ckpt["string_to_param"].values())[0] + except (AttributeError, KeyError): return self._handle_broken_pt_variants(embedding_ckpt, embedding_file) - embedding_info['embedding'] = embedding - embedding_info['num_vectors_per_token'] = embedding.size()[0] - embedding_info['token_dim'] = embedding.size()[1] + embedding_info["embedding"] = embedding + embedding_info["num_vectors_per_token"] = embedding.size()[0] + embedding_info["token_dim"] = embedding.size()[1] try: - embedding_info['trained_steps'] = embedding_ckpt['step'] - embedding_info['trained_model_name'] = embedding_ckpt['sd_checkpoint_name'] - embedding_info['trained_model_checksum'] = embedding_ckpt['sd_checkpoint'] + embedding_info["trained_steps"] = embedding_ckpt["step"] + embedding_info["trained_model_name"] = embedding_ckpt[ + "sd_checkpoint_name" + ] + embedding_info["trained_model_checksum"] = embedding_ckpt[ + "sd_checkpoint" + ] except AttributeError: print(">> No Training Details Found. 
Passing ...") # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/ # They are actually .bin files - elif len(embedding_ckpt.keys())==1: - print('>> Detected .bin file masquerading as .pt file') + elif len(embedding_ckpt.keys()) == 1: + print(">> Detected .bin file masquerading as .pt file") embedding_info = self._parse_embedding_bin(embedding_file) else: - print('>> Invalid embedding format') + print(">> Invalid embedding format") embedding_info = None return embedding_info def _parse_embedding_bin(self, embedding_file): - embedding_ckpt = torch.load(embedding_file, map_location='cpu') + embedding_ckpt = torch.load(embedding_file, map_location="cpu") embedding_info = {} if list(embedding_ckpt.keys()) == 0: @@ -274,27 +337,45 @@ class TextualInversionManager(): embedding_info = None else: for token in list(embedding_ckpt.keys()): - embedding_info['name'] = token or os.path.basename(os.path.splitext(embedding_file)[0]) - embedding_info['embedding'] = embedding_ckpt[token] - embedding_info['num_vectors_per_token'] = 1 # All Concepts seem to default to 1 - embedding_info['token_dim'] = embedding_info['embedding'].size()[0] + embedding_info["name"] = token or os.path.basename( + os.path.splitext(embedding_file)[0] + ) + embedding_info["embedding"] = embedding_ckpt[token] + embedding_info[ + "num_vectors_per_token" + ] = 1 # All Concepts seem to default to 1 + embedding_info["token_dim"] = embedding_info["embedding"].size()[0] return embedding_info - def _handle_broken_pt_variants(self, embedding_ckpt:dict, embedding_file:str)->dict: - ''' + def _handle_broken_pt_variants( + self, embedding_ckpt: dict, embedding_file: str + ) -> dict: + """ This handles the broken .pt file variants. We only know of one at present. - ''' + """ embedding_info = {} - if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor): - print('>> Detected .pt file variant 1') # example at https://github.com/invoke-ai/InvokeAI/issues/1829 - for token in list(embedding_ckpt['string_to_token'].keys()): - embedding_info['name'] = token if token != '*' else os.path.basename(os.path.splitext(embedding_file)[0]) - embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token] - embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0] - embedding_info['token_dim'] = embedding_info['embedding'].size()[0] + if isinstance( + list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor + ): + print( + ">> Detected .pt file variant 1" + ) # example at https://github.com/invoke-ai/InvokeAI/issues/1829 + for token in list(embedding_ckpt["string_to_token"].keys()): + embedding_info["name"] = ( + token + if token != "*" + else os.path.basename(os.path.splitext(embedding_file)[0]) + ) + embedding_info["embedding"] = embedding_ckpt[ + "string_to_param" + ].state_dict()[token] + embedding_info["num_vectors_per_token"] = embedding_info[ + "embedding" + ].shape[0] + embedding_info["token_dim"] = embedding_info["embedding"].size()[0] else: - print('>> Invalid embedding format') + print(">> Invalid embedding format") embedding_info = None return embedding_info
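
Usage note (not part of the diff): the trigger-string plumbing surfaced by the new `!triggers` command can also be exercised directly through `TextualInversionManager`. A minimal sketch, assuming a local embedding file at `embeddings/my-style.pt` and using the stock SD1 CLIP checkpoint as a stand-in for an already-loaded model's tokenizer and text encoder:

```
from transformers import CLIPTextModel, CLIPTokenizer

from ldm.modules.textual_inversion_manager import TextualInversionManager

# In normal use these come from the loaded Stable Diffusion model; the CLIP
# checkpoint id below is only an illustrative stand-in.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

manager = TextualInversionManager(tokenizer, text_encoder, full_precision=True)

# Hypothetical local embedding file; .pt and .bin files are both accepted.
manager.load_textual_inversion("embeddings/my-style.pt", defer_injecting_tokens=True)

# This is the list surfaced by the new !triggers command.
print(manager.get_all_trigger_strings())
```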
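For multi-vector embeddings, the deferred-injection path only allocates tokenizer slots once a trigger actually appears in a prompt, and the tokenized prompt is then padded so each vector occupies its own position. A sketch of that round trip, assuming the file loaded above registers the trigger string `my-style`:

```
prompt = "a portrait in the style of my-style"

# Inject tokens for any deferred triggers present in the prompt.
manager.create_deferred_token_ids_for_any_trigger_terms(prompt)

# Tokenize without bos/eos (the expansion helper rejects them), then pad.
token_ids = tokenizer(prompt, add_special_tokens=False).input_ids
expanded = manager.expand_textual_inversion_token_ids_if_necessary(token_ids)
print(len(token_ids), "->", len(expanded))
```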