diff --git a/ldm/dream_util.py b/ldm/dream_util.py index 1526223cd8..ceab2940b1 100644 --- a/ldm/dream_util.py +++ b/ldm/dream_util.py @@ -1,6 +1,7 @@ '''Utilities for dealing with PNG images and their path names''' import os import atexit +import re from PIL import Image,PngImagePlugin # ---------------readline utilities--------------------- @@ -94,40 +95,43 @@ if readline_available: # -------------------image generation utils----- class PngWriter: - def __init__(self,opt): - self.opt = opt - self.filepath = None - self.files_written = [] + def __init__(self,outdir,opt,prompt): + self.outdir = outdir + self.opt = opt + self.prompt = prompt + self.filepath = None + self.files_written = [] def write_image(self,image,seed): - self.filepath = self.unique_filename(self,opt,seed,self.filepath) # will increment name in some sensible way + self.filepath = self.unique_filename(seed,self.filepath) # will increment name in some sensible way try: - image.save(self.filename) + prompt = f'{self.prompt} -S{seed}' + self.save_image_and_prompt_to_png(image,prompt,self.filepath) except IOError as e: print(e) self.files_written.append([self.filepath,seed]) - def unique_filename(self,opt,seed,previouspath): + def unique_filename(self,seed,previouspath): revision = 1 if previouspath is None: # sort reverse alphabetically until we find max+1 - dirlist = sorted(os.listdir(outdir),reverse=True) + dirlist = sorted(os.listdir(self.outdir),reverse=True) # find the first filename that matches our pattern or return 000000.0.png filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png') basecount = int(filename.split('.',1)[0]) basecount += 1 - if opt.batch_size > 1: + if self.opt.batch_size > 1: filename = f'{basecount:06}.{seed}.01.png' else: filename = f'{basecount:06}.{seed}.png' - return os.path.join(outdir,filename) + return os.path.join(self.outdir,filename) else: basename = os.path.basename(previouspath) x = re.match('^(\d+)\..*\.png',basename) if not x: - return self.unique_filename(opt,seed,previouspath) + return self.unique_filename(seed,previouspath) basecount = int(x.groups()[0]) series = 0 @@ -135,9 +139,41 @@ class PngWriter: while not finished: series += 1 filename = f'{basecount:06}.{seed}.png' - if isbatch or os.path.exists(os.path.join(outdir,filename)): + if self.opt.batch_size>1 or os.path.exists(os.path.join(self.outdir,filename)): filename = f'{basecount:06}.{seed}.{series:02}.png' - finished = not os.path.exists(os.path.join(outdir,filename)) - return os.path.join(outdir,filename) + finished = not os.path.exists(os.path.join(self.outdir,filename)) + return os.path.join(self.outdir,filename) + def save_image_and_prompt_to_png(self,image,prompt,path): + info = PngImagePlugin.PngInfo() + info.add_text("Dream",prompt) + image.save(path,"PNG",pnginfo=info) + +class PromptFormatter(): + def __init__(self,t2i,opt): + self.t2i = t2i + self.opt = opt + + def normalize_prompt(self): + '''Normalize the prompt and switches''' + t2i = self.t2i + opt = self.opt + + switches = list() + switches.append(f'"{opt.prompt}"') + switches.append(f'-s{opt.steps or t2i.steps}') + switches.append(f'-b{opt.batch_size or t2i.batch_size}') + switches.append(f'-W{opt.width or t2i.width}') + switches.append(f'-H{opt.height or t2i.height}') + switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}') + switches.append(f'-m{t2i.sampler_name}') + if opt.variants: + switches.append(f'-v{opt.variants}') + if opt.init_img: + switches.append(f'-I{opt.init_img}') + if opt.strength and opt.init_img is not None: + switches.append(f'-f{opt.strength or t2i.strength}') + if t2i.full_precision: + switches.append('-F') + return ' '.join(switches) diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 8e8b077922..3b5aaeb696 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -99,13 +99,13 @@ The vast majority of these arguments default to reasonable values. def __init__(self, batch_size=1, iterations = 1, - grid=False, - individual=None, # redundant steps=50, seed=None, cfg_scale=7.5, weights="models/ldm/stable-diffusion-v1/model.ckpt", config = "configs/stable-diffusion/v1-inference.yaml", + width=512, + height=512, sampler_name="klms", latent_channels=4, downsampling_factor=8, @@ -121,7 +121,6 @@ The vast majority of these arguments default to reasonable values. self.iterations = iterations self.width = width self.height = height - self.grid = grid self.steps = steps self.cfg_scale = cfg_scale self.weights = weights @@ -143,25 +142,26 @@ The vast majority of these arguments default to reasonable values. else: self.seed = seed - def generate(self, - # these are common - prompt, - batch_size=None, - iterations=None, - steps=None, - seed=None, - cfg_scale=None, - ddim_eta=None, - skip_normalize=False, - image_callback=None, - # these are specific to txt2img - width=None, - height=None, - # these are specific to img2img - init_img=None, - strength=None, - variants=None): - '''ldm.generate() is the common entry point for txt2img() and img2img()''' + def prompt2image(self, + # these are common + prompt, + batch_size=None, + iterations=None, + steps=None, + seed=None, + cfg_scale=None, + ddim_eta=None, + skip_normalize=False, + image_callback=None, + # these are specific to txt2img + width=None, + height=None, + # these are specific to img2img + init_img=None, + strength=None, + variants=None, + **args): # eat up additional cruft + '''ldm.prompt2image() is the common entry point for txt2img() and img2img()''' steps = steps or self.steps seed = seed or self.seed width = width or self.width @@ -178,10 +178,6 @@ The vast majority of these arguments default to reasonable values. data = [batch_size * [prompt]] scope = autocast if self.precision=="autocast" else nullcontext - if grid: - callback = self.image2png - else: - callback = None tic = time.time() if init_img: @@ -212,7 +208,7 @@ The vast majority of these arguments default to reasonable values. steps,seed,cfg_scale,ddim_eta, skip_normalize, width,height, - callback=callback): # the callback is called each time a new Image is generated + callback): # the callback is called each time a new Image is generated """ Generate an image from the prompt, writing iteration images into the outdir The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...] @@ -224,14 +220,14 @@ The vast majority of these arguments default to reasonable values. # Gawd. Too many levels of indent here. Need to refactor into smaller routines! try: - with precision_scope(self.device.type), model.ema_scope(): + with precision_scope(self.device.type), self.model.ema_scope(): all_samples = list() for n in trange(iterations, desc="Sampling"): seed_everything(seed) for prompts in tqdm(data, desc="data", dynamic_ncols=True): uc = None if cfg_scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) + uc = self.model.get_learned_conditioning(batch_size * [""]) if isinstance(prompts, tuple): prompts = list(prompts) @@ -247,9 +243,9 @@ The vast majority of these arguments default to reasonable values. weight = weights[i] if not skip_normalize: weight = weight / totalWeight - c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight) + c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight) else: # just standard 1 prompt - c = model.get_learned_conditioning(prompts) + c = self.model.get_learned_conditioning(prompts) shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor] samples_ddim, _ = sampler.sample(S=steps, @@ -261,7 +257,7 @@ The vast majority of these arguments default to reasonable values. unconditional_conditioning=uc, eta=ddim_eta) - x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = self.model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) for x_sample in x_samples_ddim: x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') @@ -277,8 +273,6 @@ The vast majority of these arguments default to reasonable values. except RuntimeError as e: print(str(e)) - toc = time.time() - print(f'{image_count} images generated in',"%4.2fs"% (toc-tic)) return images @torch.no_grad() @@ -297,14 +291,14 @@ The vast majority of these arguments default to reasonable values. # PLMS sampler not supported yet, so ignore previous sampler if self.sampler_name!='ddim': print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler") - sampler = DDIMSampler(model, device=self.device) + sampler = DDIMSampler(self.model, device=self.device) else: sampler = self.sampler init_image = self._load_img(init_img).to(self.device) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) with precision_scope(self.device.type): - init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space + init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False) @@ -314,14 +308,14 @@ The vast majority of these arguments default to reasonable values. images = list() try: - with precision_scope(self.device.type), model.ema_scope(): + with precision_scope(self.device.type), self.model.ema_scope(): all_samples = list() for n in trange(iterations, desc="Sampling"): seed_everything(seed) for prompts in tqdm(data, desc="data", dynamic_ncols=True): uc = None if cfg_scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) + uc = self.model.get_learned_conditioning(batch_size * [""]) if isinstance(prompts, tuple): prompts = list(prompts) @@ -337,9 +331,9 @@ The vast majority of these arguments default to reasonable values. weight = weights[i] if not skip_normalize: weight = weight / totalWeight - c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight) + c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight) else: # just standard 1 prompt - c = model.get_learned_conditioning(prompts) + c = self.model.get_learned_conditioning(prompts) # encode (scaled latent) z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device)) @@ -347,7 +341,7 @@ The vast majority of these arguments default to reasonable values. samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc,) - x_samples = model.decode_first_stage(samples) + x_samples = self.model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) for x_sample in x_samples: diff --git a/scripts/dream.py b/scripts/dream.py index 6ff7802fa2..ab01e8db01 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -6,7 +6,7 @@ import shlex import os import sys import copy -from ldm.dream_util import Completer,PngWriter +from ldm.dream_util import Completer,PngWriter,PromptFormatter debugging = False @@ -27,10 +27,6 @@ def main(): config = "configs/stable-diffusion/v1-inference.yaml" weights = "models/ldm/stable-diffusion-v1/model.ckpt" - # command line history will be stored in a file called "~/.dream_history" - if readline_available: - setup_readline() - print("* Initializing, be patient...\n") sys.path.append('.') from pytorch_lightning import logging @@ -46,8 +42,6 @@ def main(): # the user input loop t2i = T2I(width=width, height=height, - batch_size=opt.batch_size, - outdir=opt.outdir, sampler_name=opt.sampler_name, weights=weights, full_precision=opt.full_precision, @@ -79,13 +73,13 @@ def main(): log_path = os.path.join(opt.outdir,'dream_log.txt') with open(log_path,'a') as log: cmd_parser = create_cmd_parser() - main_loop(t2i,cmd_parser,log,infile) + main_loop(t2i,opt.outdir,cmd_parser,log,infile) log.close() if infile: infile.close() -def main_loop(t2i,parser,log,infile): +def main_loop(t2i,outdir,parser,log,infile): ''' prompt/read/execute loop ''' done = False @@ -123,13 +117,13 @@ def main_loop(t2i,parser,log,infile): if elements[0]=='cd' and len(elements)>1: if os.path.exists(elements[1]): print(f"setting image output directory to {elements[1]}") - opt.outdir=elements[1] + outdir=elements[1] else: print(f"directory {elements[1]} does not exist") continue if elements[0]=='pwd': - print(f"current output directory is {opt.outdir}") + print(f"current output directory is {outdir}") continue if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command @@ -158,88 +152,41 @@ def main_loop(t2i,parser,log,infile): print("Try again with a prompt!") continue + normalized_prompt = PromptFormatter(t2i,opt).normalize_prompt() try: - file_writer = PngWriter(opt) - opt.callback = file_writer(write_image) - run_generator(**vars(opt)) + file_writer = PngWriter(outdir,opt,normalized_prompt) + callback = file_writer.write_image + + t2i.prompt2image(image_callback=callback, + **vars(opt)) results = file_writer.files_written + except AssertionError as e: print(e) continue print("Outputs:") - write_log_message(t2i,opt,results,log) + write_log_message(t2i,normalized_prompt,results,log) print("goodbye!") -def write_log_message(t2i,opt,results,logfile): +def write_log_message(t2i,prompt,results,logfile): ''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata ''' - switches = _reconstruct_switches(t2i,opt) - prompt_str = ' '.join(switches) - - # when multiple images are produced in batch, then we keep track of where each starts last_seed = None img_num = 1 - batch_size = opt.batch_size or t2i.batch_size seenit = {} seeds = [a[1] for a in results] - if batch_size > 1: - seeds = f"(seeds for each batch row: {seeds})" - else: - seeds = f"(seeds for individual images: {seeds})" + seeds = f"(seeds for individual images: {seeds})" for r in results: seed = r[1] - log_message = (f'{r[0]}: {prompt_str} -S{seed}') + log_message = (f'{r[0]}: {prompt} -S{seed}') - if batch_size > 1: - if seed != last_seed: - img_num = 1 - log_message += f' # (batch image {img_num} of {batch_size})' - else: - img_num += 1 - log_message += f' # (batch image {img_num} of {batch_size})' - last_seed = seed print(log_message) logfile.write(log_message+"\n") logfile.flush() - if r[0] not in seenit: - seenit[r[0]] = True - try: - if opt.grid: - _write_prompt_to_png(r[0],f'{prompt_str} -g -S{seed} {seeds}') - else: - _write_prompt_to_png(r[0],f'{prompt_str} -S{seed}') - except FileNotFoundError: - print(f"Could not open file '{r[0]}' for reading") -def _reconstruct_switches(t2i,opt): - '''Normalize the prompt and switches''' - switches = list() - switches.append(f'"{opt.prompt}"') - switches.append(f'-s{opt.steps or t2i.steps}') - switches.append(f'-b{opt.batch_size or t2i.batch_size}') - switches.append(f'-W{opt.width or t2i.width}') - switches.append(f'-H{opt.height or t2i.height}') - switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}') - switches.append(f'-m{t2i.sampler_name}') - if opt.variants: - switches.append(f'-v{opt.variants}') - if opt.init_img: - switches.append(f'-I{opt.init_img}') - if opt.strength and opt.init_img is not None: - switches.append(f'-f{opt.strength or t2i.strength}') - if t2i.full_precision: - switches.append('-F') - return switches - -def _write_prompt_to_png(path,prompt): - info = PngImagePlugin.PngInfo() - info.add_text("Dream",prompt) - im = Image.open(path) - im.save(path,"PNG",pnginfo=info) - def create_argv_parser(): parser = argparse.ArgumentParser(description="Parse script's command line args") parser.add_argument("--laion400m", @@ -260,10 +207,6 @@ def create_argv_parser(): dest='full_precision', action='store_true', help="use slower full precision math for calculations") - parser.add_argument('-b','--batch_size', - type=int, - default=1, - help="number of images to produce per iteration (faster, but doesn't generate individual seeds") parser.add_argument('--sampler','-m', dest="sampler_name", choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],