From e6179af46a75afb5118ba1981da0e6594d910803 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 14 Sep 2022 07:02:31 -0400 Subject: [PATCH] Refactor generate.py and dream.py (#534) * revert inadvertent change of conda env name (#528) * Refactor generate.py and dream.py * config file path (models.yaml) is parsed inside Generate() to simplify API * Better handling of keyboard interrupts in file loading mode vs interactive * Removed oodles of unused variables. * move nonfunctional inpainting out of the scripts directory * fix ugly ddim tqdm formatting --- ldm/generate.py | 204 +++++++++++++------------- ldm/models/diffusion/ddim.py | 2 +- scripts/dream.py | 97 +++++------- scripts/{ => orig_scripts}/inpaint.py | 0 4 files changed, 141 insertions(+), 162 deletions(-) rename scripts/{ => orig_scripts}/inpaint.py (100%) diff --git a/ldm/generate.py b/ldm/generate.py index 52c8846d80..1bb8e33eb9 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -17,7 +17,7 @@ import transformers from omegaconf import OmegaConf from PIL import Image, ImageOps from torch import nn -from pytorch_lightning import seed_everything +from pytorch_lightning import seed_everything, logging from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler @@ -35,7 +35,7 @@ Example Usage: from ldm.generate import Generate # Create an object with default values -gr = Generate() +gr = Generate('stable-diffusion-1.4') # do the slow model initialization gr.load_model() @@ -79,16 +79,17 @@ still work. The full list of arguments to Generate() are: gr = Generate( + # these values are set once and shouldn't be changed + conf = path to configuration file ('configs/models.yaml') + model = symbolic name of the model in the configuration file + full_precision = False + + # this value is sticky and maintained between generation calls + sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms + + # these are deprecated - use conf and model instead weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt') - config = path to model configuraiton ('configs/stable-diffusion/v1-inference.yaml') - iterations = // how many times to run the sampling (1) - steps = // 50 - seed = // current system time - sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms - grid = // false - width = // image width, multiple of 64 (512) - height = // image height, multiple of 64 (512) - cfg_scale = // condition-free guidance scale (7.5) + config = path to model configuraiton ('configs/stable-diffusion/v1-inference.yaml') ) """ @@ -101,66 +102,62 @@ class Generate: def __init__( self, - iterations = 1, - steps = 50, - cfg_scale = 7.5, - weights = 'models/ldm/stable-diffusion-v1/model.ckpt', - config = 'configs/stable-diffusion/v1-inference.yaml', - grid = False, - width = 512, - height = 512, + model = 'stable-diffusion-1.4', + conf = 'configs/models.yaml', + embedding_path = None, sampler_name = 'k_lms', ddim_eta = 0.0, # deterministic full_precision = False, - strength = 0.75, # default in scripts/img2img.py - seamless = False, - embedding_path = None, - device_type = 'cuda', - ignore_ctrl_c = False, + # these are deprecated; if present they override values in the conf file + weights = None, + config = None, ): - self.iterations = iterations - self.width = width - self.height = height - self.steps = steps - self.cfg_scale = cfg_scale - self.weights = weights - self.config = config - self.sampler_name = sampler_name - self.grid = grid - self.ddim_eta = ddim_eta - self.full_precision = True if choose_torch_device() == 'mps' else full_precision - self.strength = strength - self.seamless = seamless - self.embedding_path = embedding_path - self.device_type = device_type - self.ignore_ctrl_c = ignore_ctrl_c # note, this logic probably doesn't belong here... - self.model = None # empty for now - self.sampler = None - self.device = None - self.generators = {} - self.base_generator = None - self.seed = None + models = OmegaConf.load(conf) + mconfig = models[model] + self.weights = mconfig.weights if weights is None else weights + self.config = mconfig.config if config is None else config + self.height = mconfig.height + self.width = mconfig.width + self.iterations = 1 + self.steps = 50 + self.cfg_scale = 7.5 + self.sampler_name = sampler_name + self.ddim_eta = 0.0 # same seed always produces same image + self.full_precision = True if choose_torch_device() == 'mps' else full_precision + self.strength = 0.75 + self.seamless = False + self.embedding_path = embedding_path + self.model = None # empty for now + self.sampler = None + self.device = None + self.session_peakmem = None + self.generators = {} + self.base_generator = None + self.seed = None - if device_type == 'cuda' and not torch.cuda.is_available(): - device_type = choose_torch_device() - print(">> cuda not available, using device", device_type) + # Note that in previous versions, there was an option to pass the + # device to Generate(). However the device was then ignored, so + # it wasn't actually doing anything. This logic could be reinstated. + device_type = choose_torch_device() self.device = torch.device(device_type) # for VRAM usage statistics - device_type = choose_torch_device() - self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None + self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None transformers.logging.set_verbosity_error() + # gets rid of annoying messages about random seed + logging.getLogger('pytorch_lightning').setLevel(logging.ERROR) + def prompt2png(self, prompt, outdir, **kwargs): """ Takes a prompt and an output directory, writes out the requested number of PNG files, and returns an array of [[filename,seed],[filename,seed]...] Optional named arguments are the same as those passed to Generate and prompt2image() """ - results = self.prompt2image(prompt, **kwargs) + results = self.prompt2image(prompt, **kwargs) pngwriter = PngWriter(outdir) - prefix = pngwriter.unique_prefix() - outputs = [] + prefix = pngwriter.unique_prefix() + outputs = [] for image, seed in results: name = f'{prefix}.{seed}.png' path = pngwriter.save_image_and_prompt_to_png( @@ -183,33 +180,35 @@ class Generate: self, # these are common prompt, - iterations = None, - steps = None, - seed = None, - cfg_scale = None, - ddim_eta = None, - skip_normalize = False, - image_callback = None, - step_callback = None, - width = None, - height = None, - sampler_name = None, - seamless = False, - log_tokenization= False, - with_variations = None, - variation_amount = 0.0, + iterations = None, + steps = None, + seed = None, + cfg_scale = None, + ddim_eta = None, + skip_normalize = False, + image_callback = None, + step_callback = None, + width = None, + height = None, + sampler_name = None, + seamless = False, + log_tokenization = False, + with_variations = None, + variation_amount = 0.0, # these are specific to img2img and inpaint - init_img = None, - init_mask = None, - fit = False, - strength = None, + init_img = None, + init_mask = None, + fit = False, + strength = None, # these are specific to embiggen (which also relies on img2img args) embiggen = None, embiggen_tiles = None, # these are specific to GFPGAN/ESRGAN - gfpgan_strength= 0, - save_original = False, - upscale = None, + gfpgan_strength = 0, + save_original = False, + upscale = None, + # Set this True to handle KeyboardInterrupt internally + catch_interrupts = False, **args, ): # eat up additional cruft """ @@ -262,10 +261,9 @@ class Generate: self.log_tokenization = log_tokenization with_variations = [] if with_variations is None else with_variations - model = ( - self.load_model() - ) # will instantiate the model or return it from cache - + # will instantiate the model or return it from cache + model = self.load_model() + for m in model.modules(): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): m.padding_mode = 'circular' if seamless else m._orig_padding_mode @@ -281,7 +279,6 @@ class Generate: (embiggen == None and embiggen_tiles == None) or ((embiggen != None or embiggen_tiles != None) and init_img != None) ), 'Embiggen requires an init/input image to be specified' - # check this logic - doesn't look right if len(with_variations) > 0 or variation_amount > 1.0: assert seed is not None,\ 'seed must be specified when using with_variations' @@ -298,7 +295,7 @@ class Generate: self._set_sampler() tic = time.time() - if torch.cuda.is_available(): + if self._has_cuda(): torch.cuda.reset_peak_memory_stats() results = list() @@ -307,9 +304,9 @@ class Generate: try: uc, c = get_uc_and_c( - prompt, model=self.model, + prompt, model =self.model, skip_normalize=skip_normalize, - log_tokens=self.log_tokenization + log_tokens =self.log_tokenization ) (init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit) @@ -352,27 +349,25 @@ class Generate: save_original = save_original, image_callback = image_callback) - except KeyboardInterrupt: - print('*interrupted*') - if not self.ignore_ctrl_c: - raise KeyboardInterrupt - print( - '>> Partial results will be returned; if --grid was requested, nothing will be returned.' - ) except RuntimeError as e: print(traceback.format_exc(), file=sys.stderr) print('>> Could not generate image.') + except KeyboardInterrupt: + if catch_interrupts: + print('**Interrupted** Partial results will be returned.') + else: + raise KeyboardInterrupt toc = time.time() print('>> Usage stats:') print( f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic) ) - if torch.cuda.is_available() and self.device.type == 'cuda': + if self._has_cuda(): print( f'>> Max VRAM used for this generation:', '%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9), - 'Current VRAM utilization:' + 'Current VRAM utilization:', '%4.2fG' % (torch.cuda.memory_allocated() / 1e9), ) @@ -439,8 +434,7 @@ class Generate: if self.model is None: seed_everything(random.randrange(0, np.iinfo(np.uint32).max)) try: - config = OmegaConf.load(self.config) - model = self._load_model_from_config(config, self.weights) + model = self._load_model_from_config(self.config, self.weights) if self.embedding_path is not None: model.embedding_manager.load( self.embedding_path, self.full_precision @@ -541,8 +535,11 @@ class Generate: print(msg) - def _load_model_from_config(self, config, ckpt): - print(f'>> Loading model from {ckpt}') + # Be warned: config is the path to the model config file, not the dream conf file! + # Also note that we can get config and weights from self, so why do we need to + # pass them as args? + def _load_model_from_config(self, config, weights): + print(f'>> Loading model from {weights}') # for usage statistics device_type = choose_torch_device() @@ -551,10 +548,11 @@ class Generate: tic = time.time() # this does the work - pl_sd = torch.load(ckpt, map_location='cpu') - sd = pl_sd['state_dict'] - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) + c = OmegaConf.load(config) + pl_sd = torch.load(weights, map_location='cpu') + sd = pl_sd['state_dict'] + model = instantiate_from_config(c.model) + m, u = model.load_state_dict(sd, strict=False) if self.full_precision: print( @@ -573,7 +571,7 @@ class Generate: print( f'>> Model loaded in', '%4.2fs' % (toc - tic) ) - if device_type == 'cuda': + if self._has_cuda(): print( '>> Max VRAM used to load the model:', '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9), @@ -710,3 +708,5 @@ class Generate: return width, height, resize_needed + def _has_cuda(self): + return self.device.type == 'cuda' diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 3868540526..b875aac331 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -225,7 +225,7 @@ class DDIMSampler(object): total_steps = ( timesteps if ddim_use_original_steps else timesteps.shape[0] ) - print(f'Running DDIM Sampling with {total_steps} timesteps') + print(f'\nRunning DDIM Sampling with {total_steps} timesteps') iterator = tqdm( time_range, diff --git a/scripts/dream.py b/scripts/dream.py index 8559c1b083..aec27506c3 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -33,53 +33,35 @@ def main(): print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.') sys.exit(-1) - try: - models = OmegaConf.load(opt.config) - width = models[opt.model].width - height = models[opt.model].height - config = models[opt.model].config - weights = models[opt.model].weights - except (FileNotFoundError, IOError, KeyError) as e: - print(f'{e}. Aborting.') - sys.exit(-1) - print('* Initializing, be patient...\n') sys.path.append('.') - from pytorch_lightning import logging from ldm.generate import Generate # these two lines prevent a horrible warning message from appearing # when the frozen CLIP tokenizer is imported import transformers - transformers.logging.set_verbosity_error() - # creating a simple text2image object with a handful of + # creating a simple Generate object with a handful of # defaults passed on the command line. # additional parameters will be added (or overriden) during # the user input loop - t2i = Generate( - width=width, - height=height, - sampler_name=opt.sampler_name, - weights=weights, - full_precision=opt.full_precision, - config=config, - grid=opt.grid, - # this is solely for recreating the prompt - seamless=opt.seamless, - embedding_path=opt.embedding_path, - device_type=opt.device, - ignore_ctrl_c=opt.infile is None, - ) + try: + gen = Generate( + conf = opt.config, + model = opt.model, + sampler_name = opt.sampler_name, + embedding_path = opt.embedding_path, + full_precision = opt.full_precision, + ) + except (FileNotFoundError, IOError, KeyError) as e: + print(f'{e}. Aborting.') + sys.exit(-1) # make sure the output directory exists if not os.path.exists(opt.outdir): os.makedirs(opt.outdir) - # gets rid of annoying messages about random seed - logging.getLogger('pytorch_lightning').setLevel(logging.ERROR) - # load the infile as a list of lines infile = None if opt.infile: @@ -98,21 +80,23 @@ def main(): print(">> changed to seamless tiling mode") # preload the model - t2i.load_model() + gen.load_model() if not infile: print( "\n* Initialization done! Awaiting your command (-h for help, 'q' to quit)" ) - cmd_parser = create_cmd_parser() + # web server loops forever if opt.web: - dream_server_loop(t2i, opt.host, opt.port, opt.outdir) - else: - main_loop(t2i, opt.outdir, opt.prompt_as_dir, cmd_parser, infile) + dream_server_loop(gen, opt.host, opt.port, opt.outdir) + sys.exit(0) + cmd_parser = create_cmd_parser() + main_loop(gen, opt.outdir, opt.prompt_as_dir, cmd_parser, infile) -def main_loop(t2i, outdir, prompt_as_dir, parser, infile): +# TODO: main_loop() has gotten busy. Needs to be refactored. +def main_loop(gen, outdir, prompt_as_dir, parser, infile): """prompt/read/execute loop""" done = False path_filter = re.compile(r'[<>:"/\\|?*]') @@ -132,9 +116,6 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile): except EOFError: done = True continue - except KeyboardInterrupt: - done = True - continue # skip empty lines if not command.strip(): @@ -184,6 +165,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile): if len(opt.prompt) == 0: print('Try again with a prompt!') continue + # retrieve previous value! if opt.init_img is not None and re.match('^-\\d+$', opt.init_img): try: @@ -204,8 +186,6 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile): opt.seed = None continue - do_grid = opt.grid or t2i.grid - if opt.with_variations is not None: # shotgun parsing, woo parts = [] @@ -258,11 +238,11 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile): file_writer = PngWriter(current_outdir) prefix = file_writer.unique_prefix() results = [] # list of filename, prompt pairs - grid_images = dict() # seed -> Image, only used if `do_grid` + grid_images = dict() # seed -> Image, only used if `opt.grid` def image_writer(image, seed, upscaled=False): path = None - if do_grid: + if opt.grid: grid_images[seed] = image else: if upscaled and opt.save_original: @@ -278,16 +258,16 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile): iter_opt.with_variations = opt.with_variations + this_variation iter_opt.variation_amount = 0 normalized_prompt = PromptFormatter( - t2i, iter_opt).normalize_prompt() + gen, iter_opt).normalize_prompt() metadata_prompt = f'{normalized_prompt} -S{iter_opt.seed}' elif opt.with_variations is not None: normalized_prompt = PromptFormatter( - t2i, opt).normalize_prompt() + gen, opt).normalize_prompt() # use the original seed - the per-iteration value is the last variation-seed metadata_prompt = f'{normalized_prompt} -S{opt.seed}' else: normalized_prompt = PromptFormatter( - t2i, opt).normalize_prompt() + gen, opt).normalize_prompt() metadata_prompt = f'{normalized_prompt} -S{seed}' path = file_writer.save_image_and_prompt_to_png( image, metadata_prompt, filename) @@ -296,16 +276,21 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile): results.append([path, metadata_prompt]) last_results.append([path, seed]) - t2i.prompt2image(image_callback=image_writer, **vars(opt)) + catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts + gen.prompt2image( + image_callback=image_writer, + catch_interrupts=catch_ctrl_c, + **vars(opt) + ) - if do_grid and len(grid_images) > 0: + if opt.grid and len(grid_images) > 0: grid_img = make_grid(list(grid_images.values())) grid_seeds = list(grid_images.keys()) first_seed = last_results[0][1] filename = f'{prefix}.{first_seed}.png' # TODO better metadata for grid images normalized_prompt = PromptFormatter( - t2i, opt).normalize_prompt() + gen, opt).normalize_prompt() metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -n{len(grid_images)} # {grid_seeds}' path = file_writer.save_image_and_prompt_to_png( grid_img, metadata_prompt, filename @@ -337,11 +322,12 @@ def get_next_command(infile=None) -> str: # command string raise EOFError else: command = command.strip() - print(f'#{command}') + if len(command)>0: + print(f'#{command}') return command -def dream_server_loop(t2i, host, port, outdir): +def dream_server_loop(gen, host, port, outdir): print('\n* --web was specified, starting web server...') # Change working directory to the stable-diffusion directory os.chdir( @@ -349,7 +335,7 @@ def dream_server_loop(t2i, host, port, outdir): ) # Start server - DreamServer.model = t2i + DreamServer.model = gen # misnomer in DreamServer - this is not the model you are looking for DreamServer.outdir = outdir dream_server = ThreadingDreamServer((host, port)) print(">> Started Stable Diffusion dream server!") @@ -519,13 +505,6 @@ def create_argv_parser(): default='model', help='Indicates the Stable Diffusion model to use.', ) - parser.add_argument( - '--device', - '-d', - type=str, - default='cuda', - help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available" - ) parser.add_argument( '--model', default='stable-diffusion-1.4', diff --git a/scripts/inpaint.py b/scripts/orig_scripts/inpaint.py similarity index 100% rename from scripts/inpaint.py rename to scripts/orig_scripts/inpaint.py