[SD] Release memory used by upscaler when not in use

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2023-03-16 18:52:23 +05:30
parent d8f0c4655d
commit 7ffe20b1c2
6 changed files with 58 additions and 56 deletions

View File

@@ -243,7 +243,7 @@ def img2img_inf(
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, generated_imgs[0], global_obj.get_sd_obj().log
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -172,7 +172,7 @@ def inpaint_inf(
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, generated_imgs[0], global_obj.get_sd_obj().log
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -181,7 +181,7 @@ def outpaint_inf(
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, generated_imgs[0], global_obj.get_sd_obj().log
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -173,7 +173,7 @@ def txt2img_inf(
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, generated_imgs[0], global_obj.get_sd_obj().log
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -1,8 +1,6 @@
import sys
import torch
import time
from PIL import Image
from dataclasses import dataclass
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
@@ -14,20 +12,6 @@ from apps.stable_diffusion.src import (
)
@dataclass
class Config:
model_id: str
ckpt_loc: str
precision: str
batch_size: int
max_length: int
height: int
width: int
device: str
upscaler_obj = None
config_obj = None
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
@@ -58,8 +42,12 @@ def upscaler_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
):
global upscaler_obj
global config_obj
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
args.prompts = [prompt]
@@ -75,7 +63,6 @@ def upscaler_inf(
return None, "An Initial Image is required"
image = init_image.convert("RGB").resize((args.height, args.width))
# set ckpt_loc and hf_model_id.
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
@@ -101,6 +88,7 @@ def upscaler_inf(
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"upscaler",
args.hf_model_id,
args.ckpt_loc,
precision,
@@ -110,8 +98,12 @@ def upscaler_inf(
width,
device,
)
if not upscaler_obj or config_obj != new_config_obj:
config_obj = new_config_obj
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_size = batch_size
args.max_length = max_length
args.device = device.split("=>", 1)[1].strip()
@@ -126,27 +118,29 @@ def upscaler_inf(
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
upscaler_obj = UpscalerPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
global_obj.set_sd_obj(
UpscalerPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)
)
upscaler_obj.scheduler = schedulers[scheduler]
upscaler_obj.low_res_scheduler = schedulers["DDPM"]
global_obj.set_schedulers(schedulers[scheduler])
global_obj.get_sd_obj().low_res_scheduler = schedulers["DDPM"]
start_time = time.time()
upscaler_obj.log = ""
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
@@ -154,7 +148,7 @@ def upscaler_inf(
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = upscaler_obj.generate_images(
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
image,
@@ -173,7 +167,8 @@ def upscaler_inf(
save_output_img(out_imgs[0], img_seed, extra_info)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
upscaler_obj.log += "\n"
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
@@ -185,7 +180,7 @@ def upscaler_inf(
text_output += upscaler_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
return generated_imgs, text_output
yield generated_imgs, text_output
if __name__ == "__main__":

View File

@@ -160,14 +160,18 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
choices=available_devices,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => -1",
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=150):
clear_queue = gr.Button("Clear Queue")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -222,6 +226,9 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
show_progress=args.progress_bar,
)
prompt.submit(**kwargs)
negative_prompt.submit(**kwargs)
stable_diffusion.click(**kwargs)
prompt_submit = prompt.submit(**kwargs)
neg_prompt_submit = negative_prompt.submit(**kwargs)
generate_click = stable_diffusion.click(**kwargs)
clear_queue.click(
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
)