diff --git a/.gitignore b/.gitignore index 10c949c5..6280b98a 100644 --- a/.gitignore +++ b/.gitignore @@ -177,7 +177,7 @@ onnx_models/ generated_imgs/ # Custom model related artefacts -apps/stable_diffusion/src/utils/resources/variants.json +variants.json models/ # models folder diff --git a/apps/stable_diffusion/src/utils/__init__.py b/apps/stable_diffusion/src/utils/__init__.py index 4f783b82..c82e79dd 100644 --- a/apps/stable_diffusion/src/utils/__init__.py +++ b/apps/stable_diffusion/src/utils/__init__.py @@ -8,7 +8,6 @@ from apps.stable_diffusion.src.utils.resources import ( base_models, opt_flags, resource_path, - fetch_and_update_base_model_id, ) from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation from apps.stable_diffusion.src.utils.stable_args import args @@ -22,6 +21,7 @@ from apps.stable_diffusion.src.utils.utils import ( get_opt_flags, preprocessCKPT, fetch_or_delete_vmfbs, + fetch_and_update_base_model_id, get_path_to_diffusers_checkpoint, sanitize_seed, ) diff --git a/apps/stable_diffusion/src/utils/resources.py b/apps/stable_diffusion/src/utils/resources.py index 41a6f246..43504b82 100644 --- a/apps/stable_diffusion/src/utils/resources.py +++ b/apps/stable_diffusion/src/utils/resources.py @@ -35,28 +35,3 @@ base_models = get_json_file("resources/base_model.json") # Contains optimization flags for different models. opt_flags = get_json_file("resources/opt_flags.json") - - -# `fetch_and_update_base_model_id` is a resource utility function which -# helps maintaining mapping of the model to run with its base model. -# If `base_model` is "", then this function tries to fetch the base model -# info for the `model_to_run`. -def fetch_and_update_base_model_id(model_to_run, base_model=""): - path = "resources/variants.json" - loc_json = resource_path(path) - data = {model_to_run: base_model} - json_data = {} - if os.path.exists(loc_json): - with open(loc_json, "r", encoding="utf-8") as jsonFile: - json_data = json.load(jsonFile) - # Return with base_model's info if base_model is "". - if base_model == "": - if model_to_run in json_data: - base_model = json_data[model_to_run] - return base_model - elif base_model == "": - return base_model - # Update JSON data to contain an entry mapping model_to_run with base_model. - json_data.update(data) - with open(loc_json, "w", encoding="utf-8") as jsonFile: - json.dump(json_data, jsonFile) diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index 1c578945..6acfa823 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -1,5 +1,6 @@ import os import gc +import json from pathlib import Path import numpy as np from random import randint @@ -435,6 +436,30 @@ def fetch_or_delete_vmfbs(basic_model_name, use_base_vae, precision="fp32"): return compiled_models +# `fetch_and_update_base_model_id` is a resource utility function which +# helps maintaining mapping of the model to run with its base model. +# If `base_model` is "", then this function tries to fetch the base model +# info for the `model_to_run`. +def fetch_and_update_base_model_id(model_to_run, base_model=""): + variants_path = os.path.join(os.getcwd(), "variants.json") + data = {model_to_run: base_model} + json_data = {} + if os.path.exists(variants_path): + with open(variants_path, "r", encoding="utf-8") as jsonFile: + json_data = json.load(jsonFile) + # Return with base_model's info if base_model is "". + if base_model == "": + if model_to_run in json_data: + base_model = json_data[model_to_run] + return base_model + elif base_model == "": + return base_model + # Update JSON data to contain an entry mapping model_to_run with base_model. + json_data.update(data) + with open(variants_path, "w", encoding="utf-8") as jsonFile: + json.dump(json_data, jsonFile) + + # Generate and return a new seed if the provided one is not in the supported range (including -1) def sanitize_seed(seed): uint32_info = np.iinfo(np.uint32)