[Web] Stop images (#1212)

This commit is contained in:
m68k-fr
2023-03-19 22:37:30 +01:00
committed by GitHub
parent 650b2ada58
commit 4a622532e5
13 changed files with 122 additions and 63 deletions

View File

@@ -12,6 +12,7 @@ from apps.stable_diffusion.src import (
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
@@ -79,6 +80,9 @@ def img2img_inf(
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
global schedulers
@@ -164,6 +168,7 @@ def img2img_inf(
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
@@ -232,6 +237,7 @@ def img2img_inf(
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"STRENGTH": strength}
text_output = ""
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
@@ -252,23 +258,20 @@ def img2img_inf(
cpu_scheduling,
use_stencil=use_stencil,
)
save_output_img(out_imgs[0], img_seed, extra_info)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, strength={args.strength}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed, extra_info)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
yield generated_imgs, text_output
return generated_imgs, text_output
if __name__ == "__main__":

View File

@@ -10,6 +10,7 @@ from apps.stable_diffusion.src import (
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
@@ -50,6 +51,9 @@ def inpaint_inf(
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
global schedulers
@@ -114,6 +118,7 @@ def inpaint_inf(
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
@@ -159,6 +164,7 @@ def inpaint_inf(
img_seed = utils.sanitize_seed(seed)
image = image_dict["image"]
mask_image = image_dict["mask"]
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
@@ -180,23 +186,20 @@ def inpaint_inf(
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
yield generated_imgs, text_output
return generated_imgs, text_output
if __name__ == "__main__":

View File

@@ -10,6 +10,7 @@ from apps.stable_diffusion.src import (
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
@@ -53,6 +54,9 @@ def outpaint_inf(
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
global schedulers
@@ -116,6 +120,7 @@ def outpaint_inf(
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
@@ -163,6 +168,7 @@ def outpaint_inf(
top = True if "up" in directions else False
bottom = True if "down" in directions else False
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
@@ -189,23 +195,20 @@ def outpaint_inf(
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
yield generated_imgs, text_output
return generated_imgs, text_output
if __name__ == "__main__":

View File

@@ -9,7 +9,7 @@ from apps.stable_diffusion.src import (
clear_all,
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
@@ -46,6 +46,9 @@ def txt2img_inf(
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
global schedulers
@@ -108,6 +111,7 @@ def txt2img_inf(
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
@@ -152,6 +156,7 @@ def txt2img_inf(
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
@@ -169,25 +174,20 @@ def txt2img_inf(
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += (
f"\nsteps={steps}, guidance_scale={guidance_scale}, seed={seeds}"
)
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
# text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output
yield generated_imgs, text_output
return generated_imgs, text_output
if __name__ == "__main__":

View File

@@ -31,6 +31,9 @@ from apps.stable_diffusion.src.utils import (
end_profiling,
)
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
class StableDiffusionPipeline:
def __init__(
@@ -58,6 +61,7 @@ class StableDiffusionPipeline:
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
self.status = SD_STATE_IDLE
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
@@ -226,6 +230,7 @@ class StableDiffusionPipeline:
masked_image_latents=None,
return_all_latents=False,
):
self.status = SD_STATE_IDLE
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
@@ -275,6 +280,9 @@ class StableDiffusionPipeline:
# )
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"

View File

@@ -32,4 +32,5 @@ from apps.stable_diffusion.src.utils.utils import (
get_extended_name,
clear_all,
save_output_img,
get_generation_text_info,
)

View File

@@ -629,3 +629,14 @@ def save_output_img(output_img, img_seed, extra_info={}):
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
def get_generation_text_info(seeds, device):
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}"
return text_output

View File

@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
get_custom_model_files,
scheduler_list,
predefined_models,
cancel_sd,
)
@@ -255,5 +256,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
get_custom_model_files,
scheduler_list,
predefined_paint_models,
cancel_sd,
)
@@ -257,5 +258,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
get_custom_model_files,
scheduler_list,
predefined_paint_models,
cancel_sd,
)
@@ -277,5 +278,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
get_custom_model_files,
scheduler_list_txt2img,
predefined_models,
cancel_sd,
)
with gr.Blocks(title="Text-to-Image") as txt2img_web:
@@ -249,7 +250,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
stop_batch.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
from apps.stable_diffusion.web.utils.png_metadata import (

View File

@@ -5,6 +5,10 @@ import glob
from pathlib import Path
from apps.stable_diffusion.src import args
from dataclasses import dataclass
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
@dataclass
@@ -89,5 +93,13 @@ def get_custom_model_files():
return sorted(ckpt_files, key=str.casefold)
def cancel_sd():
# Try catch it, as gc can delete global_obj.sd_obj while switching model
try:
global_obj.set_sd_status(SD_STATE_CANCEL)
except Exception:
pass
nodlogo_loc = resource_path("logos/nod-logo.png")
available_devices = get_available_devices()

View File

@@ -38,6 +38,16 @@ def get_cfg_obj():
return config_obj
def set_sd_status(value):
global sd_obj
sd_obj.status = value
def get_sd_status():
global sd_obj
return sd_obj.status
def clear_cache():
global sd_obj
global config_obj