[SD-CLI] Make custom_model take highest priority for generating models if present

-- This commit makes `custom_model` take highest priority for generating models if present.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
Abhishek Varma
2023-01-17 17:09:29 +00:00
committed by Prashant Kumar
parent b9d947ce6f
commit df7eb80e5b
3 changed files with 39 additions and 21 deletions

View File

@@ -70,6 +70,7 @@ model_revision = {
version = args.version if args.variant == "stablediffusion" else "v1_4"
def get_configs():
model_id_key = f"{args.variant}/{version}"
revision_key = f"{args.variant}/{args.precision}"
@@ -83,23 +84,26 @@ def get_configs():
return model_id, revision
def get_clip_mlir(model_name="clip_text", extra_args=[]):
model_id, revision = get_configs()
class CLIPText(torch.nn.Module):
def __init__(self):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
revision=revision,
)
self.text_encoder = None
if args.custom_model != "":
print("Getting custom CLIP")
self.text_encoder = CLIPTextModel.from_pretrained(
args.custom_model,
subfolder="text_encoder",
)
else:
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
revision=revision,
)
def forward(self, input):
return self.text_encoder(input)[0]
@@ -114,6 +118,7 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
)
return shark_clip
# We might not even need this function anymore! We just need to change
# the forward function.
def get_base_vae_mlir(model_name="vae", extra_args=[]):
@@ -159,18 +164,20 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="vae",
)
self.vae = None
if args.custom_model != "":
print("Getting custom VAE")
self.vae = AutoencoderKL.from_pretrained(
args.custom_model,
subfolder="vae",
)
else:
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="vae",
)
def forward(self, input):
input = 1 / 0.18215 * input
@@ -207,18 +214,20 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="unet",
)
self.unet = None
if args.custom_model != "":
print("Getting custom UNET")
self.unet = UNet2DConditionModel.from_pretrained(
args.custom_model,
subfolder="unet",
)
else:
self.unet = UNet2DConditionModel.from_pretrained(
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="unet",
)
self.in_channels = self.unet.in_channels
self.train(False)

View File

@@ -72,7 +72,9 @@ def get_unet():
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "unet", is_tuned, args.precision
)
if not args.use_tuned and (args.import_mlir or args.custom_model != ""):
if args.custom_model != "":
return get_unet_mlir(model_name, iree_flags)
if not args.use_tuned and args.import_mlir:
return get_unet_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
@@ -91,7 +93,11 @@ def get_vae():
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "vae", is_tuned, args.precision
)
if not args.use_tuned and (args.import_mlir or args.custom_model != ""):
if args.custom_model != "":
if args.use_base_vae:
return get_base_vae_mlir(model_name, iree_flags)
return get_vae_mlir(model_name, iree_flags)
if not args.use_tuned and args.import_mlir:
if args.use_base_vae:
return get_base_vae_mlir(model_name, iree_flags)
return get_vae_mlir(model_name, iree_flags)

View File

@@ -25,10 +25,13 @@ def _compile_module(shark_module, model_name, extra_args=[]):
# .vmfb file.
# TODO: Have a better way of naming the vmfbs.
import re
custom_model_name = re.sub(r'\W+', '_', args.custom_model)
custom_model_name = re.sub(r"\W+", "_", args.custom_model)
if custom_model_name != "" and custom_model_name[0] == "_":
custom_model_name = custom_model_name[1:]
extended_name = "{}_{}_{}".format(model_name, device, custom_model_name)
extended_name = "{}_{}_{}".format(
model_name, device, custom_model_name
)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print(f"loading existing vmfb from: {vmfb_path}")