mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 06:17:55 -05:00
Fix SD restart error in exe file (#975)
-- This commit fixes SD restart error in exe file by creating variants.json in CWD instead of a relative path. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com> Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -177,7 +177,7 @@ onnx_models/
|
||||
generated_imgs/
|
||||
|
||||
# Custom model related artefacts
|
||||
apps/stable_diffusion/src/utils/resources/variants.json
|
||||
variants.json
|
||||
models/
|
||||
|
||||
# models folder
|
||||
|
||||
@@ -8,7 +8,6 @@ from apps.stable_diffusion.src.utils.resources import (
|
||||
base_models,
|
||||
opt_flags,
|
||||
resource_path,
|
||||
fetch_and_update_base_model_id,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
@@ -22,6 +21,7 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
get_opt_flags,
|
||||
preprocessCKPT,
|
||||
fetch_or_delete_vmfbs,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
sanitize_seed,
|
||||
)
|
||||
|
||||
@@ -35,28 +35,3 @@ base_models = get_json_file("resources/base_model.json")
|
||||
|
||||
# Contains optimization flags for different models.
|
||||
opt_flags = get_json_file("resources/opt_flags.json")
|
||||
|
||||
|
||||
# `fetch_and_update_base_model_id` is a resource utility function which
|
||||
# helps maintaining mapping of the model to run with its base model.
|
||||
# If `base_model` is "", then this function tries to fetch the base model
|
||||
# info for the `model_to_run`.
|
||||
def fetch_and_update_base_model_id(model_to_run, base_model=""):
|
||||
path = "resources/variants.json"
|
||||
loc_json = resource_path(path)
|
||||
data = {model_to_run: base_model}
|
||||
json_data = {}
|
||||
if os.path.exists(loc_json):
|
||||
with open(loc_json, "r", encoding="utf-8") as jsonFile:
|
||||
json_data = json.load(jsonFile)
|
||||
# Return with base_model's info if base_model is "".
|
||||
if base_model == "":
|
||||
if model_to_run in json_data:
|
||||
base_model = json_data[model_to_run]
|
||||
return base_model
|
||||
elif base_model == "":
|
||||
return base_model
|
||||
# Update JSON data to contain an entry mapping model_to_run with base_model.
|
||||
json_data.update(data)
|
||||
with open(loc_json, "w", encoding="utf-8") as jsonFile:
|
||||
json.dump(json_data, jsonFile)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import gc
|
||||
import json
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from random import randint
|
||||
@@ -435,6 +436,30 @@ def fetch_or_delete_vmfbs(basic_model_name, use_base_vae, precision="fp32"):
|
||||
return compiled_models
|
||||
|
||||
|
||||
# `fetch_and_update_base_model_id` is a resource utility function which
|
||||
# helps maintaining mapping of the model to run with its base model.
|
||||
# If `base_model` is "", then this function tries to fetch the base model
|
||||
# info for the `model_to_run`.
|
||||
def fetch_and_update_base_model_id(model_to_run, base_model=""):
|
||||
variants_path = os.path.join(os.getcwd(), "variants.json")
|
||||
data = {model_to_run: base_model}
|
||||
json_data = {}
|
||||
if os.path.exists(variants_path):
|
||||
with open(variants_path, "r", encoding="utf-8") as jsonFile:
|
||||
json_data = json.load(jsonFile)
|
||||
# Return with base_model's info if base_model is "".
|
||||
if base_model == "":
|
||||
if model_to_run in json_data:
|
||||
base_model = json_data[model_to_run]
|
||||
return base_model
|
||||
elif base_model == "":
|
||||
return base_model
|
||||
# Update JSON data to contain an entry mapping model_to_run with base_model.
|
||||
json_data.update(data)
|
||||
with open(variants_path, "w", encoding="utf-8") as jsonFile:
|
||||
json.dump(json_data, jsonFile)
|
||||
|
||||
|
||||
# Generate and return a new seed if the provided one is not in the supported range (including -1)
|
||||
def sanitize_seed(seed):
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
|
||||
Reference in New Issue
Block a user