mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-05 02:44:57 -05:00
cosmetic fixup to how the outputs are reported
This commit is contained in:
@@ -14,6 +14,8 @@ from math import sqrt, floor, ceil
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
# -------------------image generation utils-----
|
||||
|
||||
|
||||
class PngWriter:
|
||||
def __init__(self, outdir, prompt=None, batch_size=1):
|
||||
self.outdir = outdir
|
||||
@@ -23,18 +25,19 @@ class PngWriter:
|
||||
self.files_written = []
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
def write_image(self, image, seed):
|
||||
def write_image(self, image, seed, upscaled=False):
|
||||
self.filepath = self.unique_filename(
|
||||
seed, self.filepath
|
||||
seed, upscaled, self.filepath
|
||||
) # will increment name in some sensible way
|
||||
try:
|
||||
prompt = f'{self.prompt} -S{seed}'
|
||||
self.save_image_and_prompt_to_png(image, prompt, self.filepath)
|
||||
except IOError as e:
|
||||
print(e)
|
||||
self.files_written.append([self.filepath, seed])
|
||||
if not upscaled:
|
||||
self.files_written.append([self.filepath, seed])
|
||||
|
||||
def unique_filename(self, seed, previouspath=None):
|
||||
def unique_filename(self, seed, upscaled, previouspath=None):
|
||||
revision = 1
|
||||
|
||||
if previouspath is None:
|
||||
@@ -68,6 +71,8 @@ class PngWriter:
|
||||
if self.batch_size > 1 or os.path.exists(
|
||||
os.path.join(self.outdir, filename)
|
||||
):
|
||||
if upscaled:
|
||||
break
|
||||
filename = f'{basecount:06}.{seed}.{series:02}.png'
|
||||
finished = not os.path.exists(
|
||||
os.path.join(self.outdir, filename)
|
||||
|
||||
165
ldm/gfpgan/gfpgan_tools.py
Normal file
165
ldm/gfpgan/gfpgan_tools.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import torch
|
||||
import warnings
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
from scripts.dream import create_argv_parser
|
||||
|
||||
# Build the argument parser from scripts.dream and parse the command line.
# The functions in this file read their GFPGAN / Real-ESRGAN settings from
# `opt` (gfpgan_dir, gfpgan_model_path, gfpgan_bg_upsampler, gfpgan_bg_tile).
# NOTE(review): these look like module-level statements, so parsing would
# happen when the module is imported - confirm against the real file layout.
arg_parser = create_argv_parser()
|
||||
opt = arg_parser.parse_args()
|
||||
|
||||
|
||||
def _run_gfpgan(image, strength, prompt, seed, upsampler_scale=4):
    """Restore faces in ``image`` with GFPGAN and blend by ``strength``.

    The GFPGAN model is loaded on every call from the location given by
    ``opt.gfpgan_dir`` / ``opt.gfpgan_model_path``.  If loading fails for any
    reason the error is reported on stderr and the input image is returned
    unchanged.

    Args:
        image: PIL image to enhance.
        strength: blend factor passed to ``Image.blend``; values >= 1.0 return
            the restored image as-is, lower values mix it with the original.
        prompt: prompt text, used only for the progress message.
        seed: seed, used only for the progress message.
        upsampler_scale: upscale factor handed to GFPGANer and to the
            background upsampler loader.

    Returns:
        The restored (and possibly blended) PIL image, or ``image`` itself
        when GFPGAN could not be initialized.
    """
    print(f'\n* GFPGAN - Restoring Faces: {prompt} : seed:{seed}')

    # Bind up front so a failed load leaves a well-defined None instead of an
    # unbound local when the "not initialized" check below runs.
    gfpgan = None
    bg_upsampler = None

    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        warnings.filterwarnings('ignore', category=UserWarning)

        try:
            model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)
            if not os.path.isfile(model_path):
                raise FileNotFoundError(
                    'GFPGAN model not found at path ' + model_path
                )

            # Make the GFPGAN checkout importable, without growing sys.path
            # by one duplicate entry per processed image.
            gfpgan_dir = os.path.abspath(opt.gfpgan_dir)
            if gfpgan_dir not in sys.path:
                sys.path.append(gfpgan_dir)
            from gfpgan import GFPGANer

            bg_upsampler = _load_gfpgan_bg_upsampler(
                opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
            )

            gfpgan = GFPGANer(
                model_path=model_path,
                upscale=upsampler_scale,
                arch='clean',
                channel_multiplier=2,
                bg_upsampler=bg_upsampler,
            )
        except Exception:
            # Best effort: report the problem and fall through to the
            # "not initialized" path below.
            import traceback

            print('Error loading GFPGAN:', file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    if gfpgan is None:
        print(
            'GFPGAN not initialized, it must be loaded via the --gfpgan argument'
        )
        return image

    image = image.convert('RGB')

    # NOTE(review): the RGB array is passed to enhance() as-is; confirm the
    # channel order GFPGANer expects.
    _cropped_faces, _restored_faces, restored_img = gfpgan.enhance(
        np.array(image, dtype=np.uint8),
        has_aligned=False,
        only_center_face=False,
        paste_back=True,
    )
    res = Image.fromarray(restored_img)

    if strength < 1.0:
        # Resize the original to the restored image's size if they differ.
        # Compare PIL sizes: ndarray.size is an element count, not (w, h).
        if res.size != image.size:
            image = image.resize(res.size)
        res = Image.blend(image, res, strength)

    # Drop the model references first so the cached GPU memory they hold can
    # actually be released by empty_cache().
    gfpgan = None
    bg_upsampler = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return res
|
||||
|
||||
|
||||
def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400):
    """Build the Real-ESRGAN background upsampler, if one can be used.

    Args:
        bg_upsampler: upsampler name; only ``'realesrgan'`` is recognised.
        upsampler_scale: requested scale; weights exist for 2 and 4 only.
        bg_tile: tile size forwarded to ``RealESRGANer``.

    Returns:
        A ``RealESRGANer`` instance, or ``None`` when the name is not
        ``'realesrgan'``, CUDA is unavailable, or the scale has no weights.
    """
    if bg_upsampler != 'realesrgan':
        return None

    if not torch.cuda.is_available():  # CPU
        warnings.warn(
            'The unoptimized RealESRGAN is slow on CPU. We do not use it. '
            'If you really want to use it, please modify the corresponding codes.'
        )
        return None

    weights_by_scale = {
        2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
        4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
    }
    if upsampler_scale not in weights_by_scale:
        return None

    # Imported lazily: only needed once we know an upsampler will be built.
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer

    # The two supported networks differ only in their scale factor.
    network_scale = 4 if upsampler_scale == 4 else 2
    network = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=network_scale,
    )

    # half=True is fine here because this point is only reached with CUDA.
    return RealESRGANer(
        scale=upsampler_scale,
        model_path=weights_by_scale[upsampler_scale],
        model=network,
        tile=bg_tile,
        tile_pad=10,
        pre_pad=0,
        half=True,
    )
|
||||
|
||||
|
||||
def real_esrgan_upscale(image, strength, upsampler_scale, prompt, seed):
    """Upscale ``image`` with Real-ESRGAN and blend by ``strength``.

    The upsampler is loaded on every call using ``opt.gfpgan_bg_upsampler``
    and ``opt.gfpgan_bg_tile``.  When no upsampler is available (loading
    raised, or the loader returned ``None``) a message is written to stderr
    and the input image is returned unchanged instead of crashing.

    Args:
        image: PIL image to upscale.
        strength: blend factor passed to ``Image.blend``; values >= 1.0 return
            the upscaled image as-is, lower values mix it with the (resized)
            original.
        upsampler_scale: scale factor for the upsampler and ``outscale``.
        prompt: prompt text, used only for the progress message.
        seed: seed, used only for the progress message.

    Returns:
        The upscaled (and possibly blended) PIL image, or ``image`` itself
        when Real-ESRGAN is unavailable.
    """
    print(
        f'\n* Real-ESRGAN Upscaling: {prompt} : seed:{seed} : scale:{upsampler_scale}x'
    )

    # Bind up front so a failed load is detectable below rather than leaving
    # the name undefined.
    upsampler = None

    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        warnings.filterwarnings('ignore', category=UserWarning)

        try:
            upsampler = _load_gfpgan_bg_upsampler(
                opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
            )
        except Exception:
            import traceback

            print('Error loading Real-ESRGAN:', file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    if upsampler is None:
        # The loader returns None on CPU, for unsupported scales, or when the
        # configured upsampler is not 'realesrgan'.
        print(
            'Real-ESRGAN upsampler is not available, the image was not upscaled',
            file=sys.stderr,
        )
        return image

    output, _img_mode = upsampler.enhance(
        np.array(image, dtype=np.uint8),
        outscale=upsampler_scale,
        alpha_upsampler=opt.gfpgan_bg_upsampler,
    )

    res = Image.fromarray(output)

    if strength < 1.0:
        # Resize the original to the upscaled image's size if they differ.
        # Compare PIL sizes: ndarray.size is an element count, not (w, h).
        if res.size != image.size:
            image = image.resize(res.size)
        res = Image.blend(image, res, strength)

    # Release the model reference before emptying the cache so its GPU
    # memory can actually be freed.
    upsampler = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return res
|
||||
112
ldm/simplet2i.py
112
ldm/simplet2i.py
@@ -7,7 +7,6 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
import sys
|
||||
import os
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
@@ -21,7 +20,6 @@ from contextlib import contextmanager, nullcontext
|
||||
import transformers
|
||||
import time
|
||||
import re
|
||||
import traceback
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
@@ -133,9 +131,9 @@ class T2I:
|
||||
full_precision=False,
|
||||
strength=0.75, # default in scripts/img2img.py
|
||||
embedding_path=None,
|
||||
latent_diffusion_weights=False, # just to keep track of this parameter when regenerating prompt
|
||||
# just to keep track of this parameter when regenerating prompt
|
||||
latent_diffusion_weights=False,
|
||||
device='cuda',
|
||||
gfpgan=None,
|
||||
):
|
||||
self.batch_size = batch_size
|
||||
self.iterations = iterations
|
||||
@@ -157,7 +155,7 @@ class T2I:
|
||||
self.sampler = None
|
||||
self.latent_diffusion_weights = latent_diffusion_weights
|
||||
self.device = device
|
||||
self.gfpgan = gfpgan
|
||||
|
||||
self.session_peakmem = torch.cuda.max_memory_allocated()
|
||||
if seed is None:
|
||||
self.seed = self._new_seed()
|
||||
@@ -176,7 +174,8 @@ class T2I:
|
||||
outdir, prompt, kwargs.get('batch_size', self.batch_size)
|
||||
)
|
||||
for r in results:
|
||||
metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}' # gets written into the PNG
|
||||
# gets written into the PNG
|
||||
metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}'
|
||||
pngwriter.write_image(r[0], r[1])
|
||||
return pngwriter.files_written
|
||||
|
||||
@@ -210,6 +209,8 @@ class T2I:
|
||||
init_img=None,
|
||||
strength=None,
|
||||
gfpgan_strength=None,
|
||||
save_original=False,
|
||||
upscale=None,
|
||||
variants=None,
|
||||
**args,
|
||||
): # eat up additional cruft
|
||||
@@ -266,7 +267,7 @@ class T2I:
|
||||
f'Height and width must be multiples of 64. Resizing to {h}x{w}.'
|
||||
)
|
||||
height = h
|
||||
width = w
|
||||
width = w
|
||||
|
||||
scope = autocast if self.precision == 'autocast' else nullcontext
|
||||
|
||||
@@ -302,29 +303,47 @@ class T2I:
|
||||
)
|
||||
|
||||
with scope(self.device.type), self.model.ema_scope():
|
||||
for n in trange(iterations, desc='Sampling'):
|
||||
for n in trange(iterations, desc='Generating'):
|
||||
seed_everything(seed)
|
||||
iter_images = next(images_iterator)
|
||||
for image in iter_images:
|
||||
try:
|
||||
# if gfpgan strength is none or less than or equal to 0.0 then
|
||||
# don't even attempt to use GFPGAN.
|
||||
# if the user specified a value of -G that satisfies the condition and
|
||||
# --gfpgan wasn't specified at startup, then
|
||||
# the net result is a message gets printed - nothing else happens.
|
||||
if gfpgan_strength is not None and gfpgan_strength > 0.0:
|
||||
image = self._run_gfpgan(
|
||||
image, gfpgan_strength
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f'Error running GFPGAN - Your image was not enhanced.\n{e}'
|
||||
)
|
||||
results.append([image, seed])
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed)
|
||||
seed = self._new_seed()
|
||||
|
||||
if upscale is not None or gfpgan_strength > 0:
|
||||
for result in results:
|
||||
image, seed = result
|
||||
try:
|
||||
if upscale is not None:
|
||||
from ldm.gfpgan.gfpgan_tools import (
|
||||
real_esrgan_upscale,
|
||||
)
|
||||
|
||||
image = real_esrgan_upscale(
|
||||
image,
|
||||
upscale[1],
|
||||
int(upscale[0]),
|
||||
prompt,
|
||||
seed,
|
||||
)
|
||||
if gfpgan_strength > 0:
|
||||
from ldm.gfpgan.gfpgan_tools import _run_gfpgan
|
||||
|
||||
image = _run_gfpgan(
|
||||
image, gfpgan_strength, prompt, seed, 1
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f'Error running RealESRGAN - Your image was not upscaled.\n{e}'
|
||||
)
|
||||
if image_callback is not None:
|
||||
if save_original:
|
||||
image_callback(image, seed)
|
||||
else:
|
||||
image_callback(image, seed, upscaled=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print('*interrupted*')
|
||||
print(
|
||||
@@ -335,11 +354,21 @@ class T2I:
|
||||
print('Are you sure your system has an adequate NVIDIA GPU?')
|
||||
|
||||
toc = time.time()
|
||||
self.session_peakmem = max(self.session_peakmem,torch.cuda.max_memory_allocated() )
|
||||
self.session_peakmem = max(
|
||||
self.session_peakmem, torch.cuda.max_memory_allocated()
|
||||
)
|
||||
print('Usage stats:')
|
||||
print(f' {len(results)} image(s) generated in', '%4.2fs' % (toc - tic))
|
||||
print(f' Max VRAM used for this generation:', '%4.2fG' % (torch.cuda.max_memory_allocated()/1E9))
|
||||
print(f' Max VRAM used since script start: ', '%4.2fG' % (self.session_peakmem/1E9))
|
||||
print(
|
||||
f' {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
|
||||
)
|
||||
print(
|
||||
f' Max VRAM used for this generation:',
|
||||
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
)
|
||||
print(
|
||||
f' Max VRAM used since script start: ',
|
||||
'%4.2fG' % (self.session_peakmem / 1e9),
|
||||
)
|
||||
return results
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -498,7 +527,9 @@ class T2I:
|
||||
self.device = self._get_device()
|
||||
model = self._load_model_from_config(config, self.weights)
|
||||
if self.embedding_path is not None:
|
||||
model.embedding_manager.load(self.embedding_path, self.full_precision)
|
||||
model.embedding_manager.load(
|
||||
self.embedding_path, self.full_precision
|
||||
)
|
||||
self.model = model.to(self.device)
|
||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||
self.model.cond_stage_model.device = self.device
|
||||
@@ -561,7 +592,7 @@ class T2I:
|
||||
|
||||
def _load_img(self, path):
|
||||
with Image.open(path) as img:
|
||||
image = img.convert("RGB")
|
||||
image = img.convert('RGB')
|
||||
|
||||
w, h = image.size
|
||||
print(f'loaded input image of size ({w}, {h}) from {path}')
|
||||
@@ -620,28 +651,3 @@ class T2I:
|
||||
weights.append(1.0)
|
||||
remaining = 0
|
||||
return prompts, weights
|
||||
|
||||
def _run_gfpgan(self, image, strength):
    """Restore faces in ``image`` with the preloaded ``self.gfpgan`` model.

    Args:
        image: PIL image to enhance.
        strength: blend factor passed to ``Image.blend``; values >= 1.0 return
            the restored image as-is, lower values mix it with the original.

    Returns:
        The restored (and possibly blended) PIL image, or ``image`` unchanged
        when ``self.gfpgan`` is ``None``.
    """
    if self.gfpgan is None:
        print(
            'GFPGAN not initialized, it must be loaded via the --gfpgan argument'
        )
        return image

    image = image.convert('RGB')

    _cropped_faces, _restored_faces, restored_img = self.gfpgan.enhance(
        np.array(image, dtype=np.uint8),
        has_aligned=False,
        only_center_face=False,
        paste_back=True,
    )
    res = Image.fromarray(restored_img)

    if strength < 1.0:
        # Resize the original to the restored image's size if they differ.
        # Compare PIL sizes: ndarray.size is an element count, not (w, h),
        # so the old check was always true.
        if res.size != image.size:
            image = image.resize(res.size)
        res = Image.blend(image, res, strength)

    return res
|
||||
|
||||
Reference in New Issue
Block a user