Filesystem cleanup and custom model fixes (#2127)

* Initial filesystem cleanup

* More filesystem cleanup

* Fix some formatting issues

* Address comments
This commit is contained in:
gpetters-amd
2024-04-30 12:18:33 -04:00
committed by GitHub
parent 81d6e059ac
commit 7db1612a5c
9 changed files with 92 additions and 58 deletions

View File

@@ -53,11 +53,11 @@ def initialize():
clear_tmp_imgs()
from apps.shark_studio.web.utils.file_utils import (
create_checkpoint_folders,
create_model_folders,
)
# Create custom models folders if they don't exist
create_checkpoint_folders()
create_model_folders()
import gradio as gr

View File

@@ -602,7 +602,9 @@ if __name__ == "__main__":
global_obj._init()
sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json"))
sd_json = view_json_file(
get_resource_path(os.path.join(cmd_opts.config_dir, "default_sd_config.json"))
)
sd_kwargs = json.loads(sd_json)
for arg in vars(cmd_opts):
if arg in sd_kwargs:

View File

@@ -6,6 +6,7 @@ from io import BytesIO
from pathlib import Path
from tqdm import tqdm
from omegaconf import OmegaConf
from diffusers import StableDiffusionPipeline
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
@@ -87,6 +88,7 @@ def process_custom_pipe_weights(custom_weights):
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights)
custom_weights_params = custom_weights
return custom_weights_params, custom_weights_tgt
@@ -98,7 +100,7 @@ def get_civitai_checkpoint(url: str):
base_filename = re.findall(
'"([^"]*)"', response.headers["Content-Disposition"]
)[0]
destination_path = Path.cwd() / (cmd_opts.ckpt_dir or "models") / base_filename
destination_path = Path.cwd() / (cmd_opts.model_dir or "models") / base_filename
# we don't have this model downloaded yet
if not destination_path.is_file():

View File

@@ -41,7 +41,7 @@ class SharkPipelineBase:
self.device, self.device_id = clean_device_info(device)
self.import_mlir = import_mlir
self.iree_module_dict = {}
self.tmp_dir = get_resource_path(os.path.join("..", "shark_tmp"))
self.tmp_dir = get_resource_path(cmd_opts.tmp_dir)
if not os.path.exists(self.tmp_dir):
os.mkdir(self.tmp_dir)
self.tempfiles = {}
@@ -55,9 +55,7 @@ class SharkPipelineBase:
# and your model map is populated with any IR - unique model IDs and their static params,
# call this method to get the artifacts associated with your map.
self.pipe_id = self.safe_name(pipe_id)
self.pipe_vmfb_path = Path(
os.path.join(get_checkpoints_path(".."), self.pipe_id)
)
self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True)
if submodel == "None":
print("\n[LOG] Gathering any pre-compiled artifacts....")

View File

@@ -339,7 +339,7 @@ p.add_argument(
p.add_argument(
"--output_dir",
type=str,
default=None,
default=os.path.join(os.getcwd(), "generated_imgs"),
help="Directory path to save the output images and json.",
)
@@ -613,12 +613,27 @@ p.add_argument(
)
p.add_argument(
"--ckpt_dir",
"--tmp_dir",
type=str,
default="../models",
default=os.path.join(os.getcwd(), "shark_tmp"),
help="Path to tmp directory",
)
p.add_argument(
"--config_dir",
type=str,
default=os.path.join(os.getcwd(), "configs"),
help="Path to config directory",
)
p.add_argument(
"--model_dir",
type=str,
default=os.path.join(os.getcwd(), "models"),
help="Path to directory where all .ckpts are stored in order to populate "
"them in the web UI.",
)
# TODO: replace API flag when these can be run together
p.add_argument(
"--ui",

View File

@@ -231,9 +231,14 @@ def import_original(original_img, width, height):
def base_model_changed(base_model_id):
new_choices = get_checkpoints(
os.path.join("checkpoints", os.path.basename(str(base_model_id)))
) + get_checkpoints(model_type="checkpoints")
ckpt_path = Path(
os.path.join(
cmd_opts.model_dir, "checkpoints", os.path.basename(str(base_model_id))
)
)
ckpt_path.mkdir(parents=True, exist_ok=True)
new_choices = get_checkpoints(ckpt_path) + get_checkpoints(model_type="checkpoints")
return gr.Dropdown(
value=new_choices[0] if len(new_choices) > 0 else "None",
@@ -581,21 +586,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
object_fit="fit",
preview=True,
)
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
with gr.Row():
batch_count = gr.Slider(
1,
@@ -631,19 +621,18 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
stop_batch = gr.Button("Stop")
with gr.Tab(label="Config", id=102) as sd_tab_config:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
Path(get_configs_path()).mkdir(
parents=True, exist_ok=True
)
default_config_file = os.path.join(
get_configs_path(),
"default_sd_config.json",
)
write_default_sd_config(default_config_file)
sd_json = gr.JSON(
elem_classes=["fill"],
value=view_json_file(default_config_file),
)
Path(get_configs_path()).mkdir(parents=True, exist_ok=True)
default_config_file = os.path.join(
get_configs_path(),
"default_sd_config.json",
)
write_default_sd_config(default_config_file)
sd_json = gr.JSON(
label="SD Config",
elem_classes=["fill"],
value=view_json_file(default_config_file),
render=False,
)
with gr.Row():
with gr.Column(scale=3):
load_sd_config = gr.FileExplorer(
@@ -706,11 +695,30 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
with gr.Row(elem_classes=["fill"]):
sd_json.render()
save_sd_config.click(
fn=save_sd_cfg,
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
with gr.Tab(label="Log", id=103) as sd_tab_log:
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
with gr.Tab(label="Automation", id=104) as sd_tab_automation:
pass
pull_kwargs = dict(
fn=pull_sd_configs,

View File

@@ -66,33 +66,39 @@ def get_resource_path(path):
def get_configs_path() -> Path:
configs = get_resource_path(os.path.join("..", "configs"))
configs = get_resource_path(cmd_opts.config_dir)
if not os.path.exists(configs):
os.mkdir(configs)
return Path(get_resource_path("../configs"))
return Path(configs)
def get_generated_imgs_path() -> Path:
return Path(
cmd_opts.output_dir
if cmd_opts.output_dir
else get_resource_path("../generated_imgs")
)
outputs = get_resource_path(cmd_opts.output_dir)
if not os.path.exists(outputs):
os.mkdir(outputs)
return Path(outputs)
def get_tmp_path() -> Path:
tmpdir = get_resource_path(cmd_opts.model_dir)
if not os.path.exists(tmpdir):
os.mkdir(tmpdir)
return Path(tmpdir)
def get_generated_imgs_todays_subdir() -> str:
return dt.now().strftime("%Y%m%d")
def create_checkpoint_folders():
def create_model_folders():
dir = ["checkpoints", "vae", "lora", "vmfb"]
if not os.path.isdir(cmd_opts.ckpt_dir):
if not os.path.isdir(cmd_opts.model_dir):
try:
os.makedirs(cmd_opts.ckpt_dir)
os.makedirs(cmd_opts.model_dir)
except OSError:
sys.exit(
f"Invalid --ckpt_dir argument, "
f"{cmd_opts.ckpt_dir} folder does not exist, and cannot be created."
f"Invalid --model_dir argument, "
f"{cmd_opts.model_dir} folder does not exist, and cannot be created."
)
for root in dir:
@@ -100,7 +106,7 @@ def create_checkpoint_folders():
def get_checkpoints_path(model_type=""):
return get_resource_path(os.path.join(cmd_opts.ckpt_dir, model_type))
return get_resource_path(os.path.join(cmd_opts.model_dir, model_type))
def get_checkpoints(model_type="checkpoints"):

View File

@@ -2,7 +2,9 @@ import os
import shutil
from time import time
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
shark_tmp = cmd_opts.tmp_dir # os.path.join(os.getcwd(), "shark_tmp/")
def clear_tmp_mlir():

View File

@@ -6,6 +6,7 @@ import tempfile
import os
import hashlib
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
def create_hash(file_name):
with open(file_name, "rb") as f:
@@ -120,7 +121,7 @@ class SharkImporter:
is_dynamic=False,
tracing_required=False,
func_name="forward",
save_dir="./shark_tmp/",
save_dir=cmd_opts.tmp_dir, #"./shark_tmp/",
mlir_type="linalg",
):
if self.frontend in ["torch", "pytorch"]:
@@ -806,7 +807,7 @@ def save_mlir(
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
)
if dir == "":
dir = os.path.join(".", "shark_tmp")
dir = cmd_opts.tmp_dir, #os.path.join(".", "shark_tmp")
mlir_path = os.path.join(dir, model_name_mlir)
print(f"saving {model_name_mlir} to {dir}")
if not os.path.exists(dir):