code is reorganized and mostly functional. Grid needs to be brought back online, as well as naming of img2img variants (currently the variants get written but not logged)

This commit is contained in:
Lincoln Stein
2022-08-24 19:47:59 -04:00
parent b12955c963
commit b978536385
3 changed files with 101 additions and 128 deletions

View File

@@ -6,7 +6,7 @@ import shlex
import os
import sys
import copy
from ldm.dream_util import Completer,PngWriter
from ldm.dream_util import Completer,PngWriter,PromptFormatter
debugging = False
@@ -27,10 +27,6 @@ def main():
config = "configs/stable-diffusion/v1-inference.yaml"
weights = "models/ldm/stable-diffusion-v1/model.ckpt"
# command line history will be stored in a file called "~/.dream_history"
if readline_available:
setup_readline()
print("* Initializing, be patient...\n")
sys.path.append('.')
from pytorch_lightning import logging
@@ -46,8 +42,6 @@ def main():
# the user input loop
t2i = T2I(width=width,
height=height,
batch_size=opt.batch_size,
outdir=opt.outdir,
sampler_name=opt.sampler_name,
weights=weights,
full_precision=opt.full_precision,
@@ -79,13 +73,13 @@ def main():
log_path = os.path.join(opt.outdir,'dream_log.txt')
with open(log_path,'a') as log:
cmd_parser = create_cmd_parser()
main_loop(t2i,cmd_parser,log,infile)
main_loop(t2i,opt.outdir,cmd_parser,log,infile)
log.close()
if infile:
infile.close()
def main_loop(t2i,parser,log,infile):
def main_loop(t2i,outdir,parser,log,infile):
''' prompt/read/execute loop '''
done = False
@@ -123,13 +117,13 @@ def main_loop(t2i,parser,log,infile):
if elements[0]=='cd' and len(elements)>1:
if os.path.exists(elements[1]):
print(f"setting image output directory to {elements[1]}")
opt.outdir=elements[1]
outdir=elements[1]
else:
print(f"directory {elements[1]} does not exist")
continue
if elements[0]=='pwd':
print(f"current output directory is {opt.outdir}")
print(f"current output directory is {outdir}")
continue
if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command
@@ -158,88 +152,41 @@ def main_loop(t2i,parser,log,infile):
print("Try again with a prompt!")
continue
normalized_prompt = PromptFormatter(t2i,opt).normalize_prompt()
try:
file_writer = PngWriter(opt)
opt.callback = file_writer(write_image)
run_generator(**vars(opt))
file_writer = PngWriter(outdir,opt,normalized_prompt)
callback = file_writer.write_image
t2i.prompt2image(image_callback=callback,
**vars(opt))
results = file_writer.files_written
except AssertionError as e:
print(e)
continue
print("Outputs:")
write_log_message(t2i,opt,results,log)
write_log_message(t2i,normalized_prompt,results,log)
print("goodbye!")
def write_log_message(t2i,opt,results,logfile):
def write_log_message(t2i,prompt,results,logfile):
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata '''
switches = _reconstruct_switches(t2i,opt)
prompt_str = ' '.join(switches)
# when multiple images are produced in batch, then we keep track of where each starts
last_seed = None
img_num = 1
batch_size = opt.batch_size or t2i.batch_size
seenit = {}
seeds = [a[1] for a in results]
if batch_size > 1:
seeds = f"(seeds for each batch row: {seeds})"
else:
seeds = f"(seeds for individual images: {seeds})"
seeds = f"(seeds for individual images: {seeds})"
for r in results:
seed = r[1]
log_message = (f'{r[0]}: {prompt_str} -S{seed}')
log_message = (f'{r[0]}: {prompt} -S{seed}')
if batch_size > 1:
if seed != last_seed:
img_num = 1
log_message += f' # (batch image {img_num} of {batch_size})'
else:
img_num += 1
log_message += f' # (batch image {img_num} of {batch_size})'
last_seed = seed
print(log_message)
logfile.write(log_message+"\n")
logfile.flush()
if r[0] not in seenit:
seenit[r[0]] = True
try:
if opt.grid:
_write_prompt_to_png(r[0],f'{prompt_str} -g -S{seed} {seeds}')
else:
_write_prompt_to_png(r[0],f'{prompt_str} -S{seed}')
except FileNotFoundError:
print(f"Could not open file '{r[0]}' for reading")
def _reconstruct_switches(t2i,opt):
'''Normalize the prompt and switches'''
switches = list()
switches.append(f'"{opt.prompt}"')
switches.append(f'-s{opt.steps or t2i.steps}')
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
switches.append(f'-m{t2i.sampler_name}')
if opt.variants:
switches.append(f'-v{opt.variants}')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}')
if t2i.full_precision:
switches.append('-F')
return switches
def _write_prompt_to_png(path,prompt):
info = PngImagePlugin.PngInfo()
info.add_text("Dream",prompt)
im = Image.open(path)
im.save(path,"PNG",pnginfo=info)
def create_argv_parser():
parser = argparse.ArgumentParser(description="Parse script's command line args")
parser.add_argument("--laion400m",
@@ -260,10 +207,6 @@ def create_argv_parser():
dest='full_precision',
action='store_true',
help="use slower full precision math for calculations")
parser.add_argument('-b','--batch_size',
type=int,
default=1,
help="number of images to produce per iteration (faster, but doesn't generate individual seeds")
parser.add_argument('--sampler','-m',
dest="sampler_name",
choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],