mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-05 19:15:28 -05:00
correctly handle upscaling in webUI, including displaying status messages during GFPGAN/ESRGAN postprocessing
This commit is contained in:
@@ -68,15 +68,12 @@ class PngWriter:
|
||||
while not finished:
|
||||
series += 1
|
||||
filename = f'{basecount:06}.{seed}.png'
|
||||
if self.batch_size > 1 or os.path.exists(
|
||||
os.path.join(self.outdir, filename)
|
||||
):
|
||||
path = os.path.join(self.outdir, filename)
|
||||
if self.batch_size > 1 or os.path.exists(path):
|
||||
if upscaled:
|
||||
break
|
||||
filename = f'{basecount:06}.{seed}.{series:02}.png'
|
||||
finished = not os.path.exists(
|
||||
os.path.join(self.outdir, filename)
|
||||
)
|
||||
finished = not os.path.exists(path)
|
||||
return os.path.join(self.outdir, filename)
|
||||
|
||||
def save_image_and_prompt_to_png(self, image, prompt, path):
|
||||
|
||||
@@ -3,6 +3,7 @@ import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from ldm.dream.pngwriter import PngWriter
|
||||
|
||||
class DreamServer(BaseHTTPRequestHandler):
|
||||
model = None
|
||||
@@ -52,11 +53,63 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
seed = None if int(post_data['seed']) == -1 else int(post_data['seed'])
|
||||
|
||||
print(f"Request to generate with prompt: {prompt}")
|
||||
# In order to handle upscaled images, the PngWriter needs to maintain state
|
||||
# across images generated by each call to prompt2image(), so we define it in
|
||||
# the outer scope of image_done()
|
||||
config = post_data.copy() # Shallow copy
|
||||
config['initimg'] = ''
|
||||
|
||||
images_generated = 0 # helps keep track of when upscaling is started
|
||||
images_upscaled = 0 # helps keep track of when upscaling is completed
|
||||
pngwriter = PngWriter(
|
||||
"./outputs/img-samples/", config['prompt'], 1
|
||||
)
|
||||
|
||||
# if upscaling is requested, then this will be called twice, once when
|
||||
# the images are first generated, and then again after upscaling
|
||||
# is complete. The upscaling replaces the original file, so the second
|
||||
# entry should not be inserted into the image list.
|
||||
def image_done(image, seed, upscaled=False):
|
||||
pngwriter.write_image(image, seed, upscaled)
|
||||
|
||||
# Append post_data to log, but only once!
|
||||
if not upscaled:
|
||||
current_image = pngwriter.files_written[-1]
|
||||
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
|
||||
log.write(f"{current_image[0]}: {json.dumps(config)}\n")
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
{'event':'result', 'files':current_image, 'config':config}
|
||||
) + '\n',"utf-8"))
|
||||
|
||||
# control state of the "postprocessing..." message
|
||||
upscaling_requested = upscale or gfpgan_strength>0
|
||||
nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure.
|
||||
nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure.
|
||||
if upscaled:
|
||||
images_upscaled += 1
|
||||
else:
|
||||
images_generated +=1
|
||||
if upscaling_requested:
|
||||
action = None
|
||||
if images_generated >= iterations:
|
||||
if images_upscaled < iterations:
|
||||
action = 'upscaling-started'
|
||||
else:
|
||||
action = 'upscaling-done'
|
||||
if action:
|
||||
x = images_upscaled+1
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
{'event':action,'processed_file_cnt':f'{x}/{iterations}'}
|
||||
) + '\n',"utf-8"))
|
||||
|
||||
def image_progress(image, step):
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
{'event':'step', 'step':step}
|
||||
) + '\n',"utf-8"))
|
||||
|
||||
outputs = []
|
||||
if initimg is None:
|
||||
# Run txt2img
|
||||
outputs = self.model.txt2img(prompt,
|
||||
self.model.prompt2image(prompt,
|
||||
iterations=iterations,
|
||||
cfg_scale = cfgscale,
|
||||
width = width,
|
||||
@@ -64,8 +117,9 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
seed = seed,
|
||||
steps = steps,
|
||||
gfpgan_strength = gfpgan_strength,
|
||||
upscale = upscale
|
||||
)
|
||||
upscale = upscale,
|
||||
step_callback=image_progress,
|
||||
image_callback=image_done)
|
||||
else:
|
||||
# Decode initimg as base64 to temp file
|
||||
with open("./img2img-tmp.png", "wb") as f:
|
||||
@@ -73,30 +127,21 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
f.write(base64.b64decode(initimg))
|
||||
|
||||
# Run img2img
|
||||
outputs = self.model.img2img(prompt,
|
||||
init_img = "./img2img-tmp.png",
|
||||
iterations = iterations,
|
||||
cfg_scale = cfgscale,
|
||||
seed = seed,
|
||||
gfpgan_strength=gfpgan_strength,
|
||||
upscale = upscale,
|
||||
steps = steps
|
||||
)
|
||||
self.model.prompt2image(prompt,
|
||||
init_img = "./img2img-tmp.png",
|
||||
iterations = iterations,
|
||||
cfg_scale = cfgscale,
|
||||
seed = seed,
|
||||
steps = steps,
|
||||
gfpgan_strength=gfpgan_strength,
|
||||
upscale = upscale,
|
||||
step_callback=image_progress,
|
||||
image_callback=image_done)
|
||||
|
||||
# Remove the temp file
|
||||
os.remove("./img2img-tmp.png")
|
||||
|
||||
print(f"Prompt generated with output: {outputs}")
|
||||
|
||||
post_data['initimg'] = '' # Don't send init image back
|
||||
|
||||
# Append post_data to log
|
||||
with open("./outputs/img-samples/dream_web_log.txt", "a", encoding="utf-8") as log:
|
||||
for output in outputs:
|
||||
log.write(f"{output[0]}: {json.dumps(post_data)}\n")
|
||||
|
||||
outputs = [x + [post_data] for x in outputs] # Append config to each output
|
||||
result = {'outputs': outputs}
|
||||
self.wfile.write(bytes(json.dumps(result), "utf-8"))
|
||||
print(f"Prompt generated!")
|
||||
|
||||
|
||||
class ThreadingDreamServer(ThreadingHTTPServer):
|
||||
|
||||
@@ -61,6 +61,9 @@ class KSampler(object):
|
||||
# this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
|
||||
**kwargs,
|
||||
):
|
||||
def route_callback(k_callback_values):
|
||||
if img_callback is not None:
|
||||
img_callback(k_callback_values['x'], k_callback_values['i'])
|
||||
|
||||
sigmas = self.model.get_sigmas(S)
|
||||
if x_T:
|
||||
@@ -78,7 +81,8 @@ class KSampler(object):
|
||||
}
|
||||
return (
|
||||
K.sampling.__dict__[f'sample_{self.schedule}'](
|
||||
model_wrap_cfg, x, sigmas, extra_args=extra_args
|
||||
model_wrap_cfg, x, sigmas, extra_args=extra_args,
|
||||
callback=route_callback
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -201,6 +201,7 @@ class T2I:
|
||||
ddim_eta=None,
|
||||
skip_normalize=False,
|
||||
image_callback=None,
|
||||
step_callback=None,
|
||||
# these are specific to txt2img
|
||||
width=None,
|
||||
height=None,
|
||||
@@ -230,9 +231,14 @@ class T2I:
|
||||
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
||||
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
|
||||
step_callback // a function or method that will be called each step
|
||||
image_callback // a function or method that will be called each time an image is generated
|
||||
|
||||
To use the callback, define a function or method that receives two arguments, an Image object
|
||||
To use the step callback, define a function that receives two arguments:
|
||||
- Image GPU data
|
||||
- The step number
|
||||
|
||||
To use the image callback, define a function or method that receives two arguments, an Image object
|
||||
and the seed. You can then do whatever you like with the image, including converting it to
|
||||
different formats and manipulating it. For example:
|
||||
|
||||
@@ -292,6 +298,7 @@ class T2I:
|
||||
skip_normalize=skip_normalize,
|
||||
init_img=init_img,
|
||||
strength=strength,
|
||||
callback=step_callback,
|
||||
)
|
||||
else:
|
||||
images_iterator = self._txt2img(
|
||||
@@ -304,6 +311,7 @@ class T2I:
|
||||
skip_normalize=skip_normalize,
|
||||
width=width,
|
||||
height=height,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
with scope(self.device.type), self.model.ema_scope():
|
||||
@@ -390,6 +398,7 @@ class T2I:
|
||||
skip_normalize,
|
||||
width,
|
||||
height,
|
||||
callback,
|
||||
):
|
||||
"""
|
||||
An infinite iterator of images from the prompt.
|
||||
@@ -413,6 +422,7 @@ class T2I:
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=ddim_eta,
|
||||
img_callback=callback
|
||||
)
|
||||
yield self._samples_to_images(samples)
|
||||
|
||||
@@ -428,6 +438,7 @@ class T2I:
|
||||
skip_normalize,
|
||||
init_img,
|
||||
strength,
|
||||
callback, # Currently not implemented for img2img
|
||||
):
|
||||
"""
|
||||
An infinite iterator of images from the prompt and the initial image
|
||||
|
||||
Reference in New Issue
Block a user