mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Remove workarounds for gradio tempfile bugs (#1548)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import transformers # ensures inclusion in pysintaller exe generation
|
||||
from apps.stable_diffusion.src import args, clear_all
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
@@ -57,15 +58,19 @@ if __name__ == "__main__":
|
||||
uvicorn.run(app, host="127.0.0.1", port=args.server_port)
|
||||
sys.exit(0)
|
||||
|
||||
import gradio as gr
|
||||
# Setup to use shark_tmp for gradio's temporary image files and clear any
|
||||
# existing temporary images there if they exist. Then we can import gradio.
|
||||
# It has to be in this order or gradio ignores what we've set up.
|
||||
from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
clear_gradio_tmp_imgs_folder,
|
||||
config_gradio_tmp_imgs_folder,
|
||||
)
|
||||
|
||||
config_gradio_tmp_imgs_folder()
|
||||
import gradio as gr
|
||||
|
||||
# Create custom models folders if they don't exist
|
||||
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
|
||||
|
||||
# Clear all gradio tmp images from the last session
|
||||
clear_gradio_tmp_imgs_folder()
|
||||
# Create custom models folders if they don't exist
|
||||
create_custom_models_folders()
|
||||
|
||||
def resource_path(relative_path):
|
||||
|
||||
@@ -9,9 +9,6 @@ from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_todays_subdir,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import nodlogo_loc
|
||||
from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
gradio_tmp_galleries_folder,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.metadata import displayable_metadata
|
||||
|
||||
# -- Functions for file, directory and image info querying
|
||||
@@ -63,19 +60,6 @@ def output_subdirs() -> list[str]:
|
||||
return result_paths
|
||||
|
||||
|
||||
# clear zero length temporary files that gradio 3.22.0 buggily creates
|
||||
# TODO: remove once gradio is upgraded to or past 3.32.0
|
||||
def clear_zero_length_temps():
|
||||
zero_length_temps = [
|
||||
os.path.join(root, file)
|
||||
for root, dirs, files in os.walk(gradio_tmp_galleries_folder)
|
||||
for file in files
|
||||
if os.path.getsize(os.path.join(root, file)) == 0
|
||||
]
|
||||
for file in zero_length_temps:
|
||||
os.remove(file)
|
||||
|
||||
|
||||
# --- Define UI layout for Gradio
|
||||
|
||||
with gr.Blocks() as outputgallery_web:
|
||||
@@ -105,7 +89,6 @@ with gr.Blocks() as outputgallery_web:
|
||||
visible=False,
|
||||
show_label=True,
|
||||
).style(columns=4)
|
||||
gallery.DEFAULT_TEMP_DIR = gradio_tmp_galleries_folder
|
||||
|
||||
with gr.Column(scale=4):
|
||||
with gr.Box():
|
||||
@@ -179,7 +162,6 @@ with gr.Blocks() as outputgallery_web:
|
||||
# --- Event handlers
|
||||
|
||||
def on_clear_gallery():
|
||||
clear_zero_length_temps()
|
||||
return [
|
||||
gr.Gallery.update(
|
||||
value=[],
|
||||
@@ -247,7 +229,6 @@ with gr.Blocks() as outputgallery_web:
|
||||
|
||||
# only update if the current subdir is the most recent one as new images only go there
|
||||
if subdir_paths[0] == subdir:
|
||||
clear_zero_length_temps()
|
||||
new_images = outputgallery_filenames(subdir)
|
||||
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)} - {status}"
|
||||
|
||||
|
||||
@@ -1,60 +1,54 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import gradio
|
||||
from time import time
|
||||
|
||||
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
|
||||
gradio_tmp_galleries_folder = os.path.join(gradio_tmp_imgs_folder, "galleries")
|
||||
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
|
||||
|
||||
|
||||
# Clear all gradio tmp images
|
||||
def clear_gradio_tmp_imgs_folder():
|
||||
if not os.path.exists(gradio_tmp_imgs_folder):
|
||||
return
|
||||
def config_gradio_tmp_imgs_folder():
|
||||
# create shark_tmp if it does not exist
|
||||
if not os.path.exists(shark_tmp):
|
||||
os.mkdir(shark_tmp)
|
||||
|
||||
# tell gradio to use a directory under shark_tmp for its temporary
|
||||
# image files unless somewhere else has been set
|
||||
if "GRADIO_TEMP_DIR" not in os.environ:
|
||||
os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
|
||||
|
||||
# clear all gradio tmp files created by generation galleries
|
||||
print(
|
||||
"Clearing gradio temporary image files from a prior run. This may take some time..."
|
||||
f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
|
||||
+ "You may change this by setting the GRADIO_TEMP_DIR environment variable."
|
||||
)
|
||||
image_files = [
|
||||
filename
|
||||
for filename in os.listdir(gradio_tmp_imgs_folder)
|
||||
if os.path.isfile(os.path.join(gradio_tmp_imgs_folder, filename))
|
||||
and filename.startswith("tmp")
|
||||
and filename.endswith(".png")
|
||||
]
|
||||
if len(image_files) > 0:
|
||||
|
||||
# Clear all gradio tmp images from the last session
|
||||
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
|
||||
cleanup_start = time()
|
||||
for filename in image_files:
|
||||
os.remove(gradio_tmp_imgs_folder + filename)
|
||||
print(
|
||||
f"Clearing generation temporary image files took {time() - cleanup_start:4f} seconds"
|
||||
"Clearing gradio UI temporary image files from a prior run. This may take some time..."
|
||||
)
|
||||
else:
|
||||
print("no generation temporary files to clear")
|
||||
|
||||
# Clear all gradio tmp files created by output galleries
|
||||
if os.path.exists(gradio_tmp_galleries_folder):
|
||||
cleanup_start = time()
|
||||
shutil.rmtree(gradio_tmp_galleries_folder, ignore_errors=True)
|
||||
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
|
||||
print(
|
||||
f"Clearing output gallery temporary image files took {time() - cleanup_start:4f} seconds"
|
||||
f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
|
||||
)
|
||||
|
||||
# older SHARK versions had to workaround gradio bugs and stored things differently
|
||||
else:
|
||||
print("no output gallery temporary files to clear")
|
||||
|
||||
|
||||
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
|
||||
def save_pil_to_file(pil_image, dir=None):
|
||||
if not os.path.exists(gradio_tmp_imgs_folder):
|
||||
os.mkdir(gradio_tmp_imgs_folder)
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=".png", dir=gradio_tmp_imgs_folder
|
||||
)
|
||||
pil_image.save(file_obj)
|
||||
return file_obj
|
||||
|
||||
|
||||
# Register save_pil_to_file override
|
||||
gradio.processing_utils.save_pil_to_file = save_pil_to_file
|
||||
image_files = [
|
||||
filename
|
||||
for filename in os.listdir(shark_tmp)
|
||||
if os.path.isfile(os.path.join(shark_tmp, filename))
|
||||
and filename.startswith("tmp")
|
||||
and filename.endswith(".png")
|
||||
]
|
||||
if len(image_files) > 0:
|
||||
print(
|
||||
"Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
|
||||
)
|
||||
cleanup_start = time()
|
||||
for filename in image_files:
|
||||
os.remove(shark_tmp + filename)
|
||||
print(
|
||||
f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
|
||||
)
|
||||
else:
|
||||
print("No temporary images files to clear.")
|
||||
|
||||
Reference in New Issue
Block a user