Enable tuned models for inpainting (#1102)

This commit is contained in:
yzhang93
2023-02-27 16:46:57 -08:00
committed by GitHub
parent 1344c0659a
commit c6c8ec36a1
2 changed files with 17 additions and 7 deletions

View File

@@ -82,14 +82,20 @@ def load_lower_configs():
fetch_and_update_base_model_id,
)
base_model_id = args.hf_model_id
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
if base_model_id == "runwayml/stable-diffusion-v1-5":
base_model_id = "CompVis/stable-diffusion-v1-4"
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
variant, version = get_variant_version(base_model_id)
if version == "inpaint_v1":
version = "v1_4"
elif version == "inpaint_v2":
version = "v2_1base"
config_bucket = "gs://shark_tank/sd_tuned_configs/"
device, device_spec_args = get_device_args()

View File

@@ -239,13 +239,15 @@ def set_init_device_flags():
args.max_length = 64
# Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
base_model_id = args.hf_model_id
if args.ckpt_loc != "":
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
else:
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
if base_model_id == "":
base_model_id = args.hf_model_id
if (
"inpainting" in args.hf_model_id
or args.precision != "fp16"
args.precision != "fp16"
or args.height != 512
or args.width != 512
or args.batch_size != 1
@@ -253,7 +255,7 @@ def set_init_device_flags():
):
args.use_tuned = False
elif args.ckpt_loc != "" and base_model_id not in [
elif base_model_id not in [
"Linaqruf/anything-v3.0",
"dreamlike-art/dreamlike-diffusion-1.0",
"prompthero/openjourney",
@@ -262,6 +264,8 @@ def set_init_device_flags():
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
"runwayml/stable-diffusion-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
]:
args.use_tuned = False