Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git, synced 2026-04-03 03:00:17 -04:00
[SD-CLI] Make custom_model take highest priority for generating models if present

This commit makes `custom_model` take highest priority for generating models if present.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Committed by: Prashant Kumar
Parent: b9d947ce6f
Commit: df7eb80e5b
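
Note: the diff below applies one pattern three times (CLIP text encoder, VAE, UNet): if the user passed a custom_model, load that checkpoint; otherwise fall back to the variant's default model id and revision. A minimal sketch of the pattern, assuming the transformers package is installed; the helper name load_text_encoder is illustrative, not from the diff:

    from transformers import CLIPTextModel

    def load_text_encoder(custom_model, model_id, revision):
        # A user-supplied checkpoint wins over the default model id.
        if custom_model != "":
            print("Getting custom CLIP")
            return CLIPTextModel.from_pretrained(
                custom_model, subfolder="text_encoder"
            )
        # Fall back to the default Hugging Face model id and revision.
        return CLIPTextModel.from_pretrained(
            model_id, subfolder="text_encoder", revision=revision
        )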

@@ -70,6 +70,7 @@ model_revision = {
+version = args.version if args.variant == "stablediffusion" else "v1_4"


 def get_configs():
     model_id_key = f"{args.variant}/{version}"
     revision_key = f"{args.variant}/{args.precision}"

@@ -83,23 +84,26 @@ def get_configs():
     return model_id, revision


 def get_clip_mlir(model_name="clip_text", extra_args=[]):
     model_id, revision = get_configs()

     class CLIPText(torch.nn.Module):
         def __init__(self):
             super().__init__()
-            self.text_encoder = CLIPTextModel.from_pretrained(
-                model_id,
-                subfolder="text_encoder",
-                revision=revision,
-            )
+            self.text_encoder = None
+            if args.custom_model != "":
+                print("Getting custom CLIP")
+                self.text_encoder = CLIPTextModel.from_pretrained(
+                    args.custom_model,
+                    subfolder="text_encoder",
+                )
+            else:
+                self.text_encoder = CLIPTextModel.from_pretrained(
+                    model_id,
+                    subfolder="text_encoder",
+                    revision=revision,
+                )

         def forward(self, input):
             return self.text_encoder(input)[0]
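
Note: the same custom_model value is handed to all three wrappers with different subfolder arguments, so it can be any diffusers-style checkpoint (a Hub id or a local directory) containing text_encoder/, vae/, and unet/ subfolders. A usage sketch; the path ./dreambooth-output is hypothetical:

    from diffusers import AutoencoderKL, UNet2DConditionModel
    from transformers import CLIPTextModel

    custom = "./dreambooth-output"  # hypothetical fine-tuned checkpoint
    text_encoder = CLIPTextModel.from_pretrained(custom, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(custom, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(custom, subfolder="unet")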

@@ -114,6 +118,7 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
     )
     return shark_clip


 # We might not even need this function anymore! We just need to change
 # the forward function.
 def get_base_vae_mlir(model_name="vae", extra_args=[]):

@@ -159,18 +164,20 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
     class VaeModel(torch.nn.Module):
         def __init__(self):
             super().__init__()
-            self.vae = AutoencoderKL.from_pretrained(
-                model_config[args.version]
-                if args.variant == "stablediffusion"
-                else model_variant[args.variant],
-                subfolder="vae",
-            )
+            self.vae = None
+            if args.custom_model != "":
+                print("Getting custom VAE")
+                self.vae = AutoencoderKL.from_pretrained(
+                    args.custom_model,
+                    subfolder="vae",
+                )
+            else:
+                self.vae = AutoencoderKL.from_pretrained(
+                    model_config[args.version]
+                    if args.variant == "stablediffusion"
+                    else model_variant[args.variant],
+                    subfolder="vae",
+                )

         def forward(self, input):
             input = 1 / 0.18215 * input
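
Note: the 1 / 0.18215 in VaeModel.forward undoes Stable Diffusion v1's latent scaling (the VAE's configured scaling_factor): latents circulating in the denoising loop are scaled by 0.18215, so they must be divided by it again before decoding. A tiny numeric illustration with made-up values:

    scaling_factor = 0.18215
    latent = 0.5                        # a scaled latent value (illustrative)
    unscaled = 1 / scaling_factor * latent
    print(round(unscaled, 3))           # 2.745 -- what the decoder expects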

@@ -207,18 +214,20 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
     class UnetModel(torch.nn.Module):
         def __init__(self):
             super().__init__()
-            self.unet = UNet2DConditionModel.from_pretrained(
-                model_config[args.version]
-                if args.variant == "stablediffusion"
-                else model_variant[args.variant],
-                subfolder="unet",
-            )
+            self.unet = None
+            if args.custom_model != "":
+                print("Getting custom UNET")
+                self.unet = UNet2DConditionModel.from_pretrained(
+                    args.custom_model,
+                    subfolder="unet",
+                )
+            else:
+                self.unet = UNet2DConditionModel.from_pretrained(
+                    model_config[args.version]
+                    if args.variant == "stablediffusion"
+                    else model_variant[args.variant],
+                    subfolder="unet",
+                )
             self.in_channels = self.unet.in_channels
             self.train(False)

@@ -72,7 +72,9 @@ def get_unet():
     bucket, model_name, iree_flags = get_params(
         bucket_key, model_key, "unet", is_tuned, args.precision
     )
-    if not args.use_tuned and (args.import_mlir or args.custom_model != ""):
+    if args.custom_model != "":
+        return get_unet_mlir(model_name, iree_flags)
+    if not args.use_tuned and args.import_mlir:
         return get_unet_mlir(model_name, iree_flags)
     return get_shark_model(bucket, model_name, iree_flags)

@@ -91,7 +93,11 @@ def get_vae():
     bucket, model_name, iree_flags = get_params(
         bucket_key, model_key, "vae", is_tuned, args.precision
     )
-    if not args.use_tuned and (args.import_mlir or args.custom_model != ""):
+    if args.custom_model != "":
+        if args.use_base_vae:
+            return get_base_vae_mlir(model_name, iree_flags)
+        return get_vae_mlir(model_name, iree_flags)
+    if not args.use_tuned and args.import_mlir:
         if args.use_base_vae:
             return get_base_vae_mlir(model_name, iree_flags)
         return get_vae_mlir(model_name, iree_flags)
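
Note: before this commit, a custom model was only consulted when use_tuned was off; now it short-circuits the other flags. A self-contained restatement of the dispatch order in get_unet, with the two loaders stubbed out (in the repo they import MLIR or fetch a prebuilt artifact):

    def get_unet_mlir(name, flags):              # stub for illustration
        return f"imported {name} from MLIR"

    def get_shark_model(bucket, name, flags):    # stub for illustration
        return f"downloaded {name} from {bucket}"

    def select_unet(args, bucket, model_name, iree_flags):
        if args.custom_model != "":
            # 1. A custom checkpoint always wins, even with --use_tuned.
            return get_unet_mlir(model_name, iree_flags)
        if not args.use_tuned and args.import_mlir:
            # 2. Otherwise honor an explicit --import_mlir request.
            return get_unet_mlir(model_name, iree_flags)
        # 3. Default: fetch the prebuilt (possibly tuned) artifact.
        return get_shark_model(bucket, model_name, iree_flags)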

@@ -25,10 +25,13 @@ def _compile_module(shark_module, model_name, extra_args=[]):
     # .vmfb file.
     # TODO: Have a better way of naming the vmfbs.
     import re
-    custom_model_name = re.sub(r'\W+', '_', args.custom_model)
+
+    custom_model_name = re.sub(r"\W+", "_", args.custom_model)
+    if custom_model_name != "" and custom_model_name[0] == "_":
+        custom_model_name = custom_model_name[1:]
-    extended_name = "{}_{}_{}".format(model_name, device, custom_model_name)
+    extended_name = "{}_{}_{}".format(
+        model_name, device, custom_model_name
+    )
     vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
     if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
         print(f"loading existing vmfb from: {vmfb_path}")
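
Note: the sanitization lets the custom model id be embedded in the .vmfb filename. A quick illustration of what the two steps produce; the input values are hypothetical, and device would come from the active backend:

    import re

    custom_model = "runwayml/stable-diffusion-v1-5"   # hypothetical input
    model_name, device = "unet", "vulkan"             # hypothetical input

    # Collapse every run of non-word characters to "_" ...
    custom_model_name = re.sub(r"\W+", "_", custom_model)
    # ... and drop a leading "_" (e.g. produced by a "./local-dir" path).
    if custom_model_name != "" and custom_model_name[0] == "_":
        custom_model_name = custom_model_name[1:]

    extended_name = "{}_{}_{}".format(model_name, device, custom_model_name)
    print(extended_name + ".vmfb")
    # unet_vulkan_runwayml_stable_diffusion_v1_5.vmfb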