Merge branch 'development' into development

This commit is contained in:
Peter Baylies
2022-09-21 03:10:49 -04:00
committed by GitHub
29 changed files with 1038 additions and 484 deletions

View File

@@ -23,14 +23,51 @@ from PIL import Image, ImageOps
from torch import nn
from pytorch_lightning import seed_everything, logging
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.dream.pngwriter import PngWriter
from ldm.dream.image_util import InitImageResizer
from ldm.dream.devices import choose_torch_device, choose_precision
from ldm.dream.conditioning import get_uc_and_c
from ldm.dream.pngwriter import PngWriter
from ldm.dream.args import metadata_from_png
from ldm.dream.image_util import InitImageResizer
from ldm.dream.devices import choose_torch_device, choose_precision
from ldm.dream.conditioning import get_uc_and_c
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw):
device = kw.get("device", "mps")
kw["device"]="cpu"
return orig(*args, **kw).to(device)
return new_func
return orig
torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
torch.randn_like = fix_func(torch.randn_like)
torch.randint = fix_func(torch.randint)
torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw):
device = kw.get("device", "mps")
kw["device"]="cpu"
return orig(*args, **kw).to(device)
return new_func
return orig
torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
torch.randn_like = fix_func(torch.randn_like)
torch.randint = fix_func(torch.randint)
torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
@@ -134,6 +171,9 @@ class Generate:
# these are deprecated; if present they override values in the conf file
weights = None,
config = None,
gfpgan=None,
codeformer=None,
esrgan=None
):
models = OmegaConf.load(conf)
mconfig = models[model]
@@ -157,6 +197,9 @@ class Generate:
self.generators = {}
self.base_generator = None
self.seed = None
self.gfpgan = gfpgan
self.codeformer = codeformer
self.esrgan = esrgan
# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
@@ -237,6 +280,7 @@ class Generate:
# these are specific to embiggen (which also relies on img2img args)
embiggen = None,
embiggen_tiles = None,
out_direction = None,
# these are specific to GFPGAN/ESRGAN
facetool = None,
gfpgan_strength = 0,
@@ -287,16 +331,17 @@ class Generate:
write the prompt into the PNG metadata.
"""
# TODO: convert this into a getattr() loop
steps = steps or self.steps
width = width or self.width
height = height or self.height
seamless = seamless or self.seamless
cfg_scale = cfg_scale or self.cfg_scale
ddim_eta = ddim_eta or self.ddim_eta
iterations = iterations or self.iterations
strength = strength or self.strength
self.seed = seed
steps = steps or self.steps
width = width or self.width
height = height or self.height
seamless = seamless or self.seamless
cfg_scale = cfg_scale or self.cfg_scale
ddim_eta = ddim_eta or self.ddim_eta
iterations = iterations or self.iterations
strength = strength or self.strength
self.seed = seed
self.log_tokenization = log_tokenization
self.step_callback = step_callback
with_variations = [] if with_variations is None else with_variations
# will instantiate the model or return it from cache
@@ -305,20 +350,21 @@ class Generate:
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.padding_mode = 'circular' if seamless else m._orig_padding_mode
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
assert threshold >= 0.0, '--threshold must be >=0.0'
assert (
0.0 < strength < 1.0
), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0'
assert (
0.0 <= variation_amount <= 1.0
0.0 <= variation_amount <= 1.0
), '-v --variation_amount must be in [0.0, 1.0]'
assert (
0.0 <= perlin <= 1.0
), '-v --perlin must be in [0.0, 1.0]'
assert (
(embiggen == None and embiggen_tiles == None) or ((embiggen != None or embiggen_tiles != None) and init_img != None)
(embiggen == None and embiggen_tiles == None) or (
(embiggen != None or embiggen_tiles != None) and init_img != None)
), 'Embiggen requires an init/input image to be specified'
if len(with_variations) > 0 or variation_amount > 1.0:
@@ -340,9 +386,9 @@ class Generate:
if self._has_cuda():
torch.cuda.reset_peak_memory_stats()
results = list()
init_image = None
mask_image = None
results = list()
init_image = None
mask_image = None
try:
uc, c = get_uc_and_c(
@@ -351,8 +397,14 @@ class Generate:
log_tokens =self.log_tokenization
)
(init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit)
init_image,mask_image = self._make_images(
init_img,
init_mask,
width,
height,
fit=fit,
out_direction=out_direction,
)
if (init_image is not None) and (mask_image is not None):
generator = self._make_inpaint()
elif (embiggen != None or embiggen_tiles != None):
@@ -362,28 +414,29 @@ class Generate:
else:
generator = self._make_txt2img()
generator.set_variation(self.seed, variation_amount, with_variations)
generator.set_variation(
self.seed, variation_amount, with_variations)
results = generator.generate(
prompt,
iterations = iterations,
seed = self.seed,
sampler = self.sampler,
steps = steps,
cfg_scale = cfg_scale,
conditioning = (uc,c),
ddim_eta = ddim_eta,
image_callback = image_callback, # called after the final image is generated
step_callback = step_callback, # called after each intermediate image is generated
width = width,
height = height,
init_img = init_img, # embiggen needs to manipulate from the unmodified init_img
init_image = init_image, # notice that init_image is different from init_img
mask_image = mask_image,
strength = strength,
threshold = threshold,
perlin = perlin,
embiggen = embiggen,
embiggen_tiles = embiggen_tiles,
iterations=iterations,
seed=self.seed,
sampler=self.sampler,
steps=steps,
cfg_scale=cfg_scale,
conditioning=(uc, c),
ddim_eta=ddim_eta,
image_callback=image_callback, # called after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated
width=width,
height=height,
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
init_image=init_image, # notice that init_image is different from init_img
mask_image=mask_image,
strength=strength,
threshold=threshold,
perlin=perlin,
embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
)
if init_color:
@@ -412,7 +465,8 @@ class Generate:
toc = time.time()
print('>> Usage stats:')
print(
f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
f'>> {len(results)} image(s) generated in', '%4.2fs' % (
toc - tic)
)
if self._has_cuda():
print(
@@ -431,30 +485,163 @@ class Generate:
)
return results
def _make_images(self, img_path, mask_path, width, height, fit=False):
# this needs to be generalized to all sorts of postprocessors, which should be wrapped
# in a nice harmonized call signature. For now we have a bunch of if/elses!
def apply_postprocessor(
self,
image_path,
tool = 'gfpgan', # one of 'upscale', 'gfpgan', 'codeformer', 'outpaint', or 'embiggen'
gfpgan_strength = 0.0,
codeformer_fidelity = 0.75,
upscale = None,
out_direction = None,
save_original = True, # to get new name
callback = None,
opt = None,
):
# retrieve the seed from the image;
# note that we will try both the new way and the old way, since not all files have the
# metadata (yet)
seed = None
image_metadata = None
prompt = None
try:
args = metadata_from_png(image_path)
if len(args) > 1:
print("* Can't postprocess a grid")
return
seed = args[0].seed
prompt = args[0].prompt
print(f'>> retrieved seed {seed} and prompt "{prompt}" from {image_path}')
except:
m = re.search('(\d+)\.png$',image_path)
if m:
seed = m.group(1)
if not seed:
print('* Could not recover seed for image. Replacing with 42. This will not affect image quality')
seed = 42
# face fixers and esrgan take an Image, but embiggen takes a path
image = Image.open(image_path)
# Note that we need to adopt a uniform API for the postprocessors.
# This is completely ad hoc ATCM
if tool in ('gfpgan','codeformer','upscale'):
if tool == 'gfpgan':
facetool = 'gfpgan'
elif tool == 'codeformer':
facetool = 'codeformer'
elif tool == 'upscale':
facetool = 'gfpgan' # but won't be run
gfpgan_strength = 0
return self.upscale_and_reconstruct(
[[image,seed]],
facetool = facetool,
strength = gfpgan_strength,
codeformer_fidelity = codeformer_fidelity,
save_original = save_original,
upscale = upscale,
image_callback = callback,
)
elif tool == 'embiggen':
# fetch the metadata from the image
generator = self._make_embiggen()
uc, c = get_uc_and_c(
prompt, model =self.model,
skip_normalize=opt.skip_normalize,
log_tokens =opt.log_tokenization
)
opt.strength = 0.40
print(f'>> Setting img2img strength to {opt.strength} for happy embiggening')
# embiggen takes a image path (sigh)
generator.generate(
prompt,
sampler = self.sampler,
steps = opt.steps,
cfg_scale = opt.cfg_scale,
ddim_eta = self.ddim_eta,
conditioning= (uc, c),
init_img = image_path, # not the Image! (sigh)
init_image = image, # embiggen wants both! (sigh)
strength = opt.strength,
width = opt.width,
height = opt.height,
embiggen = opt.embiggen,
embiggen_tiles = opt.embiggen_tiles,
image_callback = callback,
)
elif tool == 'outpaint':
oldargs = metadata_from_png(image_path)
opt.strength = 0.83
opt.init_img = image_path
return self.prompt2image(
oldargs.prompt,
out_direction = opt.out_direction,
sampler = self.sampler,
steps = opt.steps,
cfg_scale = opt.cfg_scale,
ddim_eta = self.ddim_eta,
conditioning= get_uc_and_c(
oldargs.prompt, model =self.model,
skip_normalize=opt.skip_normalize,
log_tokens =opt.log_tokenization
),
width = opt.width,
height = opt.height,
init_img = image_path, # not the Image! (sigh)
strength = opt.strength,
image_callback = callback,
)
else:
print(f'* postprocessing tool {tool} is not yet supported')
return None
def _make_images(
self,
img_path,
mask_path,
width,
height,
fit=False,
out_direction=None,
):
init_image = None
init_mask = None
if not img_path:
return None,None
return None, None
image = self._load_img(img_path, width, height, fit=fit) # this returns an Image
image = self._load_img(
img_path,
width,
height,
fit=fit
) # this returns an Image
if out_direction:
image = self._create_outpaint_image(image, out_direction)
init_image = self._create_init_image(image) # this returns a torch tensor
if self._has_transparency(image) and not mask_path: # if image has a transparent area and no mask was provided, then try to generate mask
print('>> Initial image has transparent areas. Will inpaint in these regions.')
# if image has a transparent area and no mask was provided, then try to generate mask
if self._has_transparency(image) and not mask_path:
print(
'>> Initial image has transparent areas. Will inpaint in these regions.')
if self._check_for_erasure(image):
print(
'>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
'>> Inpainting will be suboptimal. Please preserve the colors when making\n',
'>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
)
init_mask = self._create_init_mask(image) # this returns a torch tensor
# this returns a torch tensor
init_mask = self._create_init_mask(image)
if mask_path:
mask_image = self._load_img(mask_path, width, height, fit=fit) # this returns an Image
init_mask = self._create_init_mask(mask_image)
mask_image = self._load_img(
mask_path, width, height, fit=fit) # this returns an Image
init_mask = self._create_init_mask(mask_image)
return init_image,init_mask
return init_image, init_mask
def _make_img2img(self):
if not self.generators.get('img2img'):
@@ -536,38 +723,26 @@ class Generate:
codeformer_fidelity = 0.75,
save_original = False,
image_callback = None):
try:
if upscale is not None:
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
if strength > 0:
if facetool == 'codeformer':
from ldm.restoration.codeformer.codeformer import CodeFormerRestoration
else:
from ldm.gfpgan.gfpgan_tools import run_gfpgan
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
return
for r in image_list:
image, seed = r
try:
if upscale is not None:
if len(upscale) < 2:
upscale.append(0.75)
image = real_esrgan_upscale(
image,
upscale[1],
int(upscale[0]),
seed,
)
if strength > 0:
if facetool == 'codeformer':
image = CodeFormerRestoration().process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity)
if self.esrgan is not None:
if len(upscale) < 2:
upscale.append(0.75)
image = self.esrgan.process(
image, upscale[1], seed, int(upscale[0]))
else:
image = run_gfpgan(
image, strength, seed, 1
)
print(">> ESRGAN is disabled. Image not upscaled.")
if strength > 0:
if self.gfpgan is not None and self.codeformer is not None:
if facetool == 'codeformer':
image = self.codeformer.process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity)
else:
image = self.gfpgan.process(image, strength, seed)
else:
print(">> Face Restoration is disabled.")
except Exception as e:
print(
f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}'
@@ -579,10 +754,10 @@ class Generate:
r[0] = image
# to help WebGUI - front end to generator util function
def sample_to_image(self,samples):
def sample_to_image(self, samples):
return self._sample_to_image(samples)
def _sample_to_image(self,samples):
def _sample_to_image(self, samples):
if not self.base_generator:
from ldm.dream.generator import Generator
self.base_generator = Generator(self.model)
@@ -625,7 +800,7 @@ class Generate:
# for usage statistics
device_type = choose_torch_device()
if device_type == 'cuda':
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_peak_memory_stats()
tic = time.time()
# this does the work
@@ -640,10 +815,10 @@ class Generate:
m, u = model.load_state_dict(sd, strict=False)
if self.precision == 'float16':
print('Using faster float16 precision')
print('>> Using faster float16 precision')
model.to(torch.float16)
else:
print('Using more accurate float32 precision')
print('>> Using more accurate float32 precision')
model.to(self.device)
model.eval()
@@ -664,6 +839,7 @@ class Generate:
return model
def _load_img(self, path, width, height, fit=False):
print(f'DEBUG: path = {path}')
assert os.path.exists(path), f'>> {path}: File not found'
# with Image.open(path) as img:
@@ -673,12 +849,12 @@ class Generate:
f'>> loaded input image of size {image.width}x{image.height} from {path}'
)
if fit:
image = self._fit_image(image,(width,height))
image = self._fit_image(image, (width, height))
else:
image = self._squeeze_image(image)
return image
def _create_init_image(self,image):
def _create_init_image(self, image):
image = image.convert('RGB')
# print(
# f'>> DEBUG: writing the image to img.png'
@@ -687,16 +863,77 @@ class Generate:
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
image = 2.0 * image - 1.0
image = 2.0 * image - 1.0
return image.to(self.device)
# TODO: outpainting is a post-processing application and should be made to behave
# like the other ones.
def _create_outpaint_image(self, image, direction_args):
assert len(direction_args) in [1, 2], 'Direction (-D) must have exactly one or two arguments.'
if len(direction_args) == 1:
direction = direction_args[0]
pixels = None
elif len(direction_args) == 2:
direction = direction_args[0]
pixels = int(direction_args[1])
assert direction in ['top', 'left', 'bottom', 'right'], 'Direction (-D) must be one of "top", "left", "bottom", "right"'
image = image.convert("RGBA")
# we always extend top, but rotate to extend along the requested side
if direction == 'left':
image = image.transpose(Image.Transpose.ROTATE_270)
elif direction == 'bottom':
image = image.transpose(Image.Transpose.ROTATE_180)
elif direction == 'right':
image = image.transpose(Image.Transpose.ROTATE_90)
pixels = image.height//2 if pixels is None else int(pixels)
assert 0 < pixels < image.height, 'Direction (-D) pixels length must be in the range 0 - image.size'
# the top part of the image is taken from the source image mirrored
# coordinates (0,0) are the upper left corner of an image
top = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).convert("RGBA")
top = top.crop((0, top.height - pixels, top.width, top.height))
# setting all alpha of the top part to 0
alpha = top.getchannel("A")
alpha.paste(0, (0, 0, top.width, top.height))
top.putalpha(alpha)
# taking the bottom from the original image
bottom = image.crop((0, 0, image.width, image.height - pixels))
new_img = image.copy()
new_img.paste(top, (0, 0))
new_img.paste(bottom, (0, pixels))
# create a 10% dither in the middle
dither = min(image.height//10, pixels)
for x in range(0, image.width, 2):
for y in range(pixels - dither, pixels + dither):
(r, g, b, a) = new_img.getpixel((x, y))
new_img.putpixel((x, y), (r, g, b, 0))
# let's rotate back again
if direction == 'left':
new_img = new_img.transpose(Image.Transpose.ROTATE_90)
elif direction == 'bottom':
new_img = new_img.transpose(Image.Transpose.ROTATE_180)
elif direction == 'right':
new_img = new_img.transpose(Image.Transpose.ROTATE_270)
return new_img
def _create_init_mask(self, image):
# convert into a black/white mask
image = self._image_to_mask(image)
image = image.convert('RGB')
# BUG: We need to use the model's downsample factor rather than hardcoding "8"
from ldm.dream.generator.base import downsampling
image = image.resize((image.width//downsampling, image.height//downsampling), resample=Image.Resampling.LANCZOS)
image = image.resize((image.width//downsampling, image.height //
downsampling), resample=Image.Resampling.LANCZOS)
# print(
# f'>> DEBUG: writing the mask to mask.png'
# )
@@ -718,7 +955,7 @@ class Generate:
mask = ImageOps.invert(mask)
return mask
def _has_transparency(self,image):
def _has_transparency(self, image):
if image.info.get("transparency", None) is not None:
return True
if image.mode == "P":
@@ -732,11 +969,10 @@ class Generate:
return True
return False
def _check_for_erasure(self,image):
def _check_for_erasure(self, image):
width, height = image.size
pixdata = image.load()
colored = 0
pixdata = image.load()
colored = 0
for y in range(height):
for x in range(width):
if pixdata[x, y][3] == 0:
@@ -746,28 +982,28 @@ class Generate:
colored += 1
return colored == 0
def _squeeze_image(self,image):
x,y,resize_needed = self._resolution_check(image.width,image.height)
def _squeeze_image(self, image):
x, y, resize_needed = self._resolution_check(image.width, image.height)
if resize_needed:
return InitImageResizer(image).resize(x,y)
return InitImageResizer(image).resize(x, y)
return image
def _fit_image(self,image,max_dimensions):
w,h = max_dimensions
def _fit_image(self, image, max_dimensions):
w, h = max_dimensions
print(
f'>> image will be resized to fit inside a box {w}x{h} in size.'
)
if image.width > image.height:
h = None # by setting h to none, we tell InitImageResizer to fit into the width and calculate height
h = None # by setting h to none, we tell InitImageResizer to fit into the width and calculate height
elif image.height > image.width:
w = None # ditto for w
w = None # ditto for w
else:
pass
image = InitImageResizer(image).resize(w,h) # note that InitImageResizer does the multiple of 64 truncation internally
# note that InitImageResizer does the multiple of 64 truncation internally
image = InitImageResizer(image).resize(w, h)
print(
f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}'
)
)
return image
def _resolution_check(self, width, height, log=False):
@@ -781,7 +1017,7 @@ class Generate:
f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
)
height = h
width = w
width = w
resize_needed = True
if (width * height) > (self.width * self.height):