mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[SD] Release memory used by upscaler when not in use
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
@@ -243,7 +243,7 @@ def img2img_inf(
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, generated_imgs[0], global_obj.get_sd_obj().log
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
|
||||
@@ -172,7 +172,7 @@ def inpaint_inf(
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, generated_imgs[0], global_obj.get_sd_obj().log
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
|
||||
@@ -181,7 +181,7 @@ def outpaint_inf(
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, generated_imgs[0], global_obj.get_sd_obj().log
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
|
||||
@@ -173,7 +173,7 @@ def txt2img_inf(
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, generated_imgs[0], global_obj.get_sd_obj().log
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import sys
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
from dataclasses import dataclass
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
UpscalerPipeline,
|
||||
@@ -14,20 +12,6 @@ from apps.stable_diffusion.src import (
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
model_id: str
|
||||
ckpt_loc: str
|
||||
precision: str
|
||||
batch_size: int
|
||||
max_length: int
|
||||
height: int
|
||||
width: int
|
||||
device: str
|
||||
|
||||
|
||||
upscaler_obj = None
|
||||
config_obj = None
|
||||
schedulers = None
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
@@ -58,8 +42,12 @@ def upscaler_inf(
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
):
|
||||
global upscaler_obj
|
||||
global config_obj
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
global schedulers
|
||||
|
||||
args.prompts = [prompt]
|
||||
@@ -75,7 +63,6 @@ def upscaler_inf(
|
||||
return None, "An Initial Image is required"
|
||||
image = init_image.convert("RGB").resize((args.height, args.width))
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
@@ -101,6 +88,7 @@ def upscaler_inf(
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
"upscaler",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
precision,
|
||||
@@ -110,8 +98,12 @@ def upscaler_inf(
|
||||
width,
|
||||
device,
|
||||
)
|
||||
if not upscaler_obj or config_obj != new_config_obj:
|
||||
config_obj = new_config_obj
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
@@ -126,27 +118,29 @@ def upscaler_inf(
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
upscaler_obj = UpscalerPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
global_obj.set_sd_obj(
|
||||
UpscalerPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
)
|
||||
)
|
||||
|
||||
upscaler_obj.scheduler = schedulers[scheduler]
|
||||
upscaler_obj.low_res_scheduler = schedulers["DDPM"]
|
||||
global_obj.set_schedulers(schedulers[scheduler])
|
||||
global_obj.get_sd_obj().low_res_scheduler = schedulers["DDPM"]
|
||||
|
||||
start_time = time.time()
|
||||
upscaler_obj.log = ""
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
@@ -154,7 +148,7 @@ def upscaler_inf(
|
||||
for current_batch in range(batch_count):
|
||||
if current_batch > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = upscaler_obj.generate_images(
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
@@ -173,7 +167,8 @@ def upscaler_inf(
|
||||
save_output_img(out_imgs[0], img_seed, extra_info)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
upscaler_obj.log += "\n"
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
@@ -185,7 +180,7 @@ def upscaler_inf(
|
||||
text_output += upscaler_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
return generated_imgs, text_output
|
||||
yield generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -160,14 +160,18 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Column(scale=1, min_width=150):
|
||||
clear_queue = gr.Button("Clear Queue")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -222,6 +226,9 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
negative_prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
clear_queue.click(
|
||||
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user