Compare commits

...

5 Commits

Author         SHA1        Message                                                                         Date
Lincoln Stein  87fb4186d4  folded in changes from img2img-dev                                              2022-08-18 12:45:02 -04:00
Lincoln Stein  750408f793  added command-line completion                                                   2022-08-18 12:43:59 -04:00
Lincoln Stein  bf76c4f283  img2img is now working; small refactoring of grid code in simplet2i.py         2022-08-18 10:47:53 -04:00
Lincoln Stein  831bbd7a54  improved error reporting when a missing online dependency can't be downloaded  2022-08-17 18:06:30 -04:00
Lincoln Stein  c477525036  catch and handle malformed user inputs; documentation fixes                     2022-08-17 12:35:49 -04:00
6 changed files with 302 additions and 60 deletions

View File

@@ -12,10 +12,11 @@ lets you create images from a prompt in just three lines of code:
~~~~
from ldm.simplet2i import T2I
model = T2I()
model.text2image("a unicorn in manhattan")
model = T2I()
outputs = model.text2image("a unicorn in manhattan")
~~~~
Outputs is a list of lists in the format [[filename1,seed1],[filename2,seed2],...]
Please see ldm/simplet2i.py for more information.
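For example, a caller can unpack each [filename,seed] pair directly (a minimal sketch, assuming the weights are installed and using the same entry point as above):
~~~~
from ldm.simplet2i import T2I
model   = T2I()
outputs = model.text2image("a unicorn in manhattan")
for filename, seed in outputs:
    print(f'{filename} (seed {seed})')   # the seed lets you regenerate the same image
~~~~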
## Interactive command-line interface similar to the Discord bot
@@ -27,6 +28,9 @@ server. The advantage of this is that the lengthy model
initialization only happens once. After that image generation is
fast.
The script uses the readline library to allow for in-line editing,
command history (up and down arrows) and more.
Note that this has only been tested in the Linux environment!
~~~~
@@ -48,14 +52,18 @@ Outputs:
outputs/txt2img-samples/00010.png: "ashley judd riding a camel" -n2 -S 1362479620
dream> "your prompt here" -n6 -g
...
outputs/txt2img-samples/00041.png: "your prompt here" -n6 -g -S 2685670268
seeds for individual rows: [2685670268, 1216708065, 2335773498, 822223658, 714542046, 3395302430]
~~~~
Command-line arguments (`./scripts/dream.py -h`) allow you to change
Command-line arguments passed to the script allow you to change
various defaults, and select between the mature stable-diffusion
weights (512x512) and the older (256x256) latent diffusion weights
(laion400m). Within the script, the switches are (mostly) identical to
those used in the Discord bot, except you don't need to type "!dream".
(laion400m). From the dream> prompt, the arguments are (mostly)
identical to those used in the Discord bot, except you don't need to
type "!dream". Pass "-h" (or "--help") to list the arguments.
For command-line help, type -h (or --help) at the dream> prompt.
## Workaround for machines with limited internet connectivity

View File

@@ -17,6 +17,7 @@ from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only
import urllib.error
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
@@ -524,7 +525,10 @@ class LatentDiffusion(DDPM):
else:
assert config != '__is_first_stage__'
assert config != '__is_unconditional__'
model = instantiate_from_config(config)
try:
model = instantiate_from_config(config)
except urllib.error.URLError:
raise SystemExit("* Couldn't load a dependency. Try running scripts/preload_models.py from an internet-connected machine.")
self.cond_stage_model = model
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
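The same guard can wrap any loader that may reach the network. A minimal self-contained sketch of the pattern; `load_or_exit` and its `loader` argument are illustrative stand-ins, not names from this codebase:
~~~~
import urllib.error

def load_or_exit(loader, *args, **kwargs):
    # Convert a connectivity failure into one actionable message
    # instead of a long traceback from deep inside the loader.
    try:
        return loader(*args, **kwargs)
    except urllib.error.URLError as e:
        raise SystemExit(
            f"* Couldn't load a dependency ({e.reason}). Try running "
            "scripts/preload_models.py from an internet-connected machine."
        )
~~~~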

View File

@@ -60,7 +60,10 @@ class BERTTokenizer(AbstractEncoder):
# by running:
# from transformers import BertTokenizerFast
# BertTokenizerFast.from_pretrained("bert-base-uncased")
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",local_files_only=True)
try:
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",local_files_only=True)
except OSError:
raise SystemExit("* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-connected machine.")
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length
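The flip side of `local_files_only=True` is that the Hugging Face cache must already be populated; the one-time warm-up is exactly the call `scripts/preload_models.py` makes (see the last file in this diff):
~~~~
# Run once on an internet-connected machine that shares the same ~/.cache;
# afterwards from_pretrained(..., local_files_only=True) succeeds offline.
from transformers import BertTokenizerFast
BertTokenizerFast.from_pretrained("bert-base-uncased")
~~~~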

View File

@@ -7,7 +7,8 @@ from ldm.simplet2i import T2I
t2i = T2I(outdir = <path> // outputs/txt2img-samples
model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
config = <path> // configs/stable-diffusion/v1-inference.yaml
batch = <integer> // 1
iterations = <integer> // how many times to run the sampling (1)
batch_size = <integer> // how many images to generate per sampling (1)
steps = <integer> // 50
seed = <integer> // current system time
sampler = ['ddim','plms'] // ddim
@@ -17,27 +18,32 @@ t2i = T2I(outdir = <path> // outputs/txt2img-samples
cfg_scale = <float> // unconditional guidance scale (7.5)
fixed_code = <boolean> // False
)
# do the slow model initialization
t2i.load_model()
# Do the fast inference & image generation. Any options passed here
# override the default values assigned during class initialization
# Will call load_model() if the model was not previously loaded.
t2i.txt2img(prompt = <string> // required
// the remaining option arguments override constructor values when present
outdir = <path>
iterations = <integer>
batch = <integer>
steps = <integer>
seed = <integer>
sampler = ['ddim','plms']
grid = <boolean>
width = <integer>
height = <integer>
cfg_scale = <float>
) -> boolean
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
results = t2i.txt2img(prompt = "an astronaut riding a horse",
outdir = "./outputs/txt2img-samples")
for row in results:
print(f'filename={row[0]}')
print(f'seed ={row[1]}')
# Same thing, but using an initial image.
results = t2i.img2img(prompt = "an astronaut riding a horse",
outdir = "./outputs/img2img-samples",
init_img = "./sketches/horse+rider.png")
for row in results:
print(f'filename={row[0]}')
print(f'seed ={row[1]}')
"""
import torch
import numpy as np
import random
@@ -47,7 +53,7 @@ from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from einops import rearrange, repeat
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from torch import autocast
@@ -67,7 +73,7 @@ class T2I:
model
config
iterations
batch
batch_size
steps
seed
sampler
@@ -80,10 +86,11 @@ class T2I:
latent_channels
downsampling_factor
precision
strength
"""
def __init__(self,
outdir="outputs/txt2img-samples",
batch=1,
batch_size=1,
iterations = 1,
width=512,
height=512,
@@ -99,10 +106,11 @@ class T2I:
downsampling_factor=8,
ddim_eta=0.0, # deterministic
fixed_code=False,
precision='autocast'
precision='autocast',
strength=0.75 # default in scripts/img2img.py
):
self.outdir = outdir
self.batch = batch
self.batch_size = batch_size
self.iterations = iterations
self.width = width
self.height = height
@@ -117,16 +125,21 @@ class T2I:
self.downsampling_factor = downsampling_factor
self.ddim_eta = ddim_eta
self.precision = precision
self.strength = strength
self.model = None # empty for now
self.sampler = None
if seed is None:
self.seed = self._new_seed()
else:
self.seed = seed
def txt2img(self,prompt,outdir=None,batch=None,iterations=None,
def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None):
""" generate an image from the prompt, writing iteration images into the outdir """
cfg_scale=None,ddim_eta=None,strength=None,init_img=None):
"""
Generate an image from the prompt, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
"""
outdir = outdir or self.outdir
steps = steps or self.steps
seed = seed or self.seed
@@ -134,8 +147,9 @@ class T2I:
height = height or self.height
cfg_scale = cfg_scale or self.cfg_scale
ddim_eta = ddim_eta or self.ddim_eta
batch = batch or self.batch
batch_size = batch_size or self.batch_size
iterations = iterations or self.iterations
strength = strength or self.strength # not actually used here, but preserved for code refactoring
model = self.load_model() # will instantiate the model or return it from cache
@@ -146,7 +160,7 @@ class T2I:
if individual:
grid = False
data = [batch * [prompt]]
data = [batch_size * [prompt]]
# make directories and establish names for the output files
os.makedirs(outdir, exist_ok=True)
@@ -154,7 +168,7 @@ class T2I:
start_code = None
if self.fixed_code:
start_code = torch.randn([batch,
start_code = torch.randn([batch_size,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
@@ -176,14 +190,14 @@ class T2I:
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
uc = None
if cfg_scale != 1.0:
uc = model.get_learned_conditioning(batch * [""])
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
samples_ddim, _ = sampler.sample(S=steps,
conditioning=c,
batch_size=batch,
batch_size=batch_size,
shape=shape,
verbose=False,
unconditional_guidance_scale=cfg_scale,
@@ -208,24 +222,146 @@ class T2I:
seed = self._new_seed()
if grid:
n_rows = batch if batch>1 else int(math.sqrt(batch * iterations))
# save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
filename = os.path.join(outdir, f"{base_count:05}.png")
Image.fromarray(grid.astype(np.uint8)).save(filename)
for s in seeds:
images.append([filename,s])
images = self._make_grid(samples=all_samples,
seeds=seeds,
batch_size=batch_size,
iterations=iterations,
outdir=outdir)
toc = time.time()
print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic))
return images
# There is a lot of shared code between this and txt2img that should be refactored.
def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None,strength=None):
"""
Generate an image from the prompt and the initial image, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
"""
outdir = outdir or self.outdir
steps = steps or self.steps
seed = seed or self.seed
cfg_scale = cfg_scale or self.cfg_scale
ddim_eta = ddim_eta or self.ddim_eta
batch_size = batch_size or self.batch_size
iterations = iterations or self.iterations
strength = strength or self.strength
if init_img is None:
print("no init_img provided!")
return []
model = self.load_model() # will instantiate the model or return it from cache
# grid and individual are mutually exclusive, with individual taking priority.
# not necessary, but needed for compatibility with dream bot
if (grid is None):
grid = self.grid
if individual:
grid = False
data = [batch_size * [prompt]]
# PLMS sampler not supported yet, so ignore previous sampler
if self.sampler_name!='ddim':
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
sampler = DDIMSampler(model)
else:
sampler = self.sampler
# make directories and establish names for the output files
os.makedirs(outdir, exist_ok=True)
base_count = len(os.listdir(outdir))-1
assert os.path.isfile(init_img)
init_image = self._load_img(init_img).to(self.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
try:
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
except AssertionError:
print(f"strength must be between 0.0 and 1.0, but received value {strength}")
return []
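# strength selects how far along the noise schedule the init image is pushed:
# 0.0 leaves it untouched, 1.0 re-noises it completely, so t_enc below is the
# number of denoising steps that will actually run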
t_enc = int(strength * steps)
print(f"target t_enc is {t_enc} steps")
precision_scope = autocast if self.precision=="autocast" else nullcontext
images = list()
seeds = list()
tic = time.time()
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
all_samples = list()
for n in trange(iterations, desc="Sampling"):
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
uc = None
if cfg_scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
# encode (scaled latent)
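# stochastic_encode pushes the init latent t_enc steps along the noise
# schedule; sampler.decode below denoises those same t_enc steps under the
# prompt conditioning, which is the core of img2img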
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,)
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
if not grid:
for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
filename = os.path.join(outdir, f"{base_count:05}.png")
Image.fromarray(x_sample.astype(np.uint8)).save(filename)
images.append([filename,seed])
base_count += 1
else:
all_samples.append(x_samples)
seeds.append(seed)
seed = self._new_seed()
if grid:
images = self._make_grid(samples=all_samples,
seeds=seeds,
batch_size=batch_size,
iterations=iterations,
outdir=outdir)
toc = time.time()
print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic))
return images
def _make_grid(self,samples,seeds,batch_size,iterations,outdir):
images = list()
base_count = len(os.listdir(outdir))-1
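# lay the images out batch_size per row when batching; otherwise aim for a
# roughly square grid of all iterations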
n_rows = batch_size if batch_size>1 else int(math.sqrt(batch_size * iterations))
# save as grid
grid = torch.stack(samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
filename = os.path.join(outdir, f"{base_count:05}.png")
Image.fromarray(grid.astype(np.uint8)).save(filename)
for s in seeds:
images.append([filename,s])
return images
def _new_seed(self):
self.seed = random.randrange(0,np.iinfo(np.uint32).max)
@@ -267,3 +403,13 @@ class T2I:
model.eval()
return model
def _load_img(self,path):
image = Image.open(path).convert("RGB")
w, h = image.size
print(f"loaded input image of size ({w}, {h}) from {path}")
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.*image - 1.
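The final line above and the `torch.clamp((x_samples + 1.0) / 2.0, ...)` calls in the sampling code are inverses: pixels travel through the latent encoder in [-1, 1] and come back to [0, 1] for saving. A small self-contained check of the round trip (synthetic data, no model required):
~~~~
import numpy as np
import torch

# forward mapping, as in _load_img: uint8 HWC -> float32 NCHW in [-1, 1]
pixels = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
x = torch.from_numpy(pixels.astype(np.float32) / 255.0)
x = x.permute(2, 0, 1).unsqueeze(0)   # 1 x 3 x 64 x 64
x = 2.0 * x - 1.0                     # now in [-1, 1]

# inverse mapping, as applied after decode_first_stage: back to [0, 1]
y = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
expected = torch.from_numpy(pixels.astype(np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0)
assert torch.allclose(y, expected, atol=1e-6)   # equal up to float rounding
~~~~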

View File

@@ -6,6 +6,8 @@ import shlex
import atexit
import os
debugging = False
def main():
''' Initialize command-line parsers and the diffusion model '''
arg_parser = create_argv_parser()
@@ -24,7 +26,7 @@ def main():
weights = "models/ldm/stable-diffusion-v1/model.ckpt"
# command line history will be stored in a file called "~/.dream_history"
load_history()
setup_readline()
print("* Initializing, be patient...\n")
from pytorch_lightning import logging
@@ -36,7 +38,7 @@ def main():
# the user input loop
t2i = T2I(width=width,
height=height,
batch=opt.batch,
batch_size=opt.batch_size,
outdir=opt.outdir,
sampler=opt.sampler,
weights=weights,
@@ -50,8 +52,9 @@ def main():
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
# preload the model
t2i.load_model()
print("\n* Initialization done! Awaiting your command...")
if not debugging:
t2i.load_model()
print("\n* Initialization done! Awaiting your command (-h for help)...")
log_path = os.path.join(opt.outdir,"dream_log.txt")
with open(log_path,'a') as log:
@@ -68,6 +71,7 @@ def main_loop(t2i,parser,log):
print("goodbye!")
break
# rearrange the arguments to mimic how it works in the Dream bot.
elements = shlex.split(command)
switches = ['']
switches_started = False
@@ -81,12 +85,20 @@ def main_loop(t2i,parser,log):
switches[0] += el
switches[0] += ' '
switches[0] = switches[0].rstrip()
try:
opt = parser.parse_args(switches)
except SystemExit:
parser.print_help()
results = t2i.txt2img(**vars(opt))
continue
if len(opt.prompt)==0:
print("Try again with a prompt!")
continue
if opt.init_img is None:
results = t2i.txt2img(**vars(opt))
else:
results = t2i.img2img(**vars(opt))
print("Outputs:")
write_log_message(opt,switches,results,log)
@@ -130,7 +142,7 @@ def create_argv_parser():
type=int,
default=1,
help="number of images to generate")
parser.add_argument('-b','--batch',
parser.add_argument('-b','--batch_size',
type=int,
default=1,
help="number of images to produce per iteration (currently not working properly - producing too many images)")
@@ -147,19 +159,29 @@ def create_argv_parser():
def create_cmd_parser():
parser = argparse.ArgumentParser(description="Parse terminal input in a discord 'dreambot' fashion")
parser = argparse.ArgumentParser(description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12')
parser.add_argument('prompt')
parser.add_argument('-s','--steps',type=int,help="number of steps")
parser.add_argument('-S','--seed',type=int,help="image seed")
parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform")
parser.add_argument('-b','--batch',type=int,default=1,help="number of images to produce per sampling (currently broken)")
parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling (currently broken)")
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
parser.add_argument('-C','--cfg_scale',type=float,help="prompt configuration scale (7.5)")
parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")
parser.add_argument('-g','--grid',action='store_true',help="generate a grid")
parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
parser.add_argument('-I','--init_img',type=str,help="path to input image (supersedes width and height)")
parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
return parser
def setup_readline():
readline.set_completer(Completer(['--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b',
'--width','-W','--height','-H','--cfg_scale','-C','--grid','-g',
'--individual','-i','--init_img','-I','--strength','-f']).complete)
readline.set_completer_delims(" ")
readline.parse_and_bind('tab: complete')
load_history()
def load_history():
histfile = os.path.join(os.path.expanduser('~'),".dream_history")
try:
@@ -169,5 +191,64 @@ def load_history():
pass
atexit.register(readline.write_history_file,histfile)
class Completer():
def __init__(self,options):
self.options = sorted(options)
return
def complete(self,text,state):
if text.startswith('-I') or text.startswith('--init_img'):
return self._image_completions(text,state)
response = None
if state == 0:
# This is the first time for this text, so build a match list.
if text:
self.matches = [s
for s in self.options
if s and s.startswith(text)]
else:
self.matches = self.options[:]
# Return the state'th item from the match list,
# if we have that many.
try:
response = self.matches[state]
except IndexError:
response = None
return response
def _image_completions(self,text,state):
# get the path so far
if text.startswith('-I'):
path = text.replace('-I','',1).lstrip()
elif text.startswith('--init_img'):
path = text.replace('--init_img','',1).lstrip(' =')
matches = list()
path = os.path.expanduser(path)
if len(path)==0:
matches.append(text+'./')
else:
dir = os.path.dirname(path) or '.' # avoid os.listdir('') when the path has no directory part yet
dir_list = os.listdir(dir)
for n in dir_list:
if n.startswith('.') and len(n)>1:
continue
full_path = os.path.join(dir,n)
if full_path.startswith(path):
if os.path.isdir(full_path):
matches.append(os.path.join(os.path.dirname(text),n)+'/')
elif n.endswith('.png'):
matches.append(os.path.join(os.path.dirname(text),n))
try:
response = matches[state]
except IndexError:
response = None
return response
if __name__ == "__main__":
main()

View File

@@ -5,13 +5,13 @@
# two machines must share a common .cache directory.
# this will preload the Bert tokenizer files
print("preloading bert tokenizer...",end='')
print("preloading bert tokenizer...")
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
print("...success")
# this will download requirements for Kornia
print("preloading Kornia requirements...",end='')
print("preloading Kornia requirements...")
import kornia
print("...success")