Use tuned configs on custom models with ckpt_loc (#1038)

This commit is contained in:
yzhang93
2023-02-16 17:06:21 -08:00
committed by GitHub
parent c96d25c3e2
commit cf126e4839
2 changed files with 24 additions and 3 deletions

View File

@@ -78,8 +78,15 @@ def load_winograd_configs():
def load_lower_configs():
from apps.stable_diffusion.src.models import get_variant_version
from apps.stable_diffusion.src.utils.utils import (
fetch_and_update_base_model_id,
)
variant, version = get_variant_version(args.hf_model_id)
base_model_id = args.hf_model_id
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
variant, version = get_variant_version(base_model_id)
config_bucket = "gs://shark_tank/sd_tuned_configs/"

View File

@@ -239,13 +239,16 @@ def set_init_device_flags():
args.max_length = 64
# Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
base_model_id = args.hf_model_id
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
if (
args.hf_model_id
in [
"runwayml/stable-diffusion-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
]
or args.ckpt_loc != ""
or args.precision != "fp16"
or args.height != 512
or args.width != 512
@@ -254,6 +257,17 @@ def set_init_device_flags():
):
args.use_tuned = False
elif args.ckpt_loc != "" and base_model_id not in [
"Linaqruf/anything-v3.0",
"dreamlike-art/dreamlike-diffusion-1.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.use_tuned = False
elif "vulkan" in args.device and not any(
x in args.iree_vulkan_target_triple for x in ["rdna2", "rdna3"]
):
@@ -269,7 +283,7 @@ def set_init_device_flags():
args.use_tuned = False
if args.use_tuned:
print(f"Using tuned models for {args.hf_model_id}/fp16/{args.device}.")
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
else:
print("Tuned models are currently not supported for this setting.")