mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-08 05:24:00 -05:00
Add option to save and load precompiled flatbuffer (#425)
This commit is contained in:
3
cpp/.gitignore
vendored
Normal file
3
cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
*.mlir
|
||||
*.vmfb
|
||||
*.ini
|
||||
2
shark/examples/shark_inference/stable_diffusion/.gitignore
vendored
Normal file
2
shark/examples/shark_inference/stable_diffusion/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
*.vmfb
|
||||
*.jpg
|
||||
@@ -24,7 +24,9 @@ UNET_FP32 = "unet_fp32"
|
||||
def get_models():
|
||||
if args.precision == "fp16":
|
||||
if args.import_mlir == True:
|
||||
return get_vae16(), get_unet16_wrapped()
|
||||
return get_vae16(model_name=VAE_FP16), get_unet16_wrapped(
|
||||
model_name=UNET_FP16
|
||||
)
|
||||
else:
|
||||
return get_shark_model(
|
||||
GCLOUD_BUCKET,
|
||||
@@ -46,7 +48,9 @@ def get_models():
|
||||
|
||||
elif args.precision == "fp32":
|
||||
if args.import_mlir == True:
|
||||
return get_vae32(), get_unet32_wrapped()
|
||||
return get_vae32(model_name=VAE_FP32), get_unet32_wrapped(
|
||||
model_name=UNET_FP32
|
||||
)
|
||||
else:
|
||||
return get_shark_model(
|
||||
GCLOUD_BUCKET,
|
||||
@@ -87,6 +91,7 @@ if __name__ == "__main__":
|
||||
batch_size = len(prompt)
|
||||
|
||||
vae, unet = get_models()
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
@@ -128,6 +133,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
scheduler.is_scale_input_called = True
|
||||
|
||||
latents = latents * scheduler.sigmas[0]
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch
|
||||
YOUR_TOKEN = "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"
|
||||
|
||||
|
||||
def get_vae32():
|
||||
def get_vae32(model_name="vae_fp32"):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -25,11 +25,12 @@ def get_vae32():
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
(vae_input,),
|
||||
model_name,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_vae16():
|
||||
def get_vae16(model_name="vae_fp16"):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -50,11 +51,12 @@ def get_vae16():
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
(vae_input,),
|
||||
model_name,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_unet32():
|
||||
def get_unet32(model_name="unet_fp32"):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -75,11 +77,12 @@ def get_unet32():
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
(latent_model_input, torch.tensor([1.0]), text_embeddings),
|
||||
model_name,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
|
||||
def get_unet16():
|
||||
def get_unet16(model_name="unet_fp16"):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -105,12 +108,13 @@ def get_unet16():
|
||||
latent_model_input,
|
||||
torch.tensor([1.0]).half().cuda(),
|
||||
text_embeddings,
|
||||
model_name,
|
||||
),
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
|
||||
def get_unet16_wrapped(guidance_scale=7.5):
|
||||
def get_unet16_wrapped(guidance_scale=7.5, model_name="unet_fp16_wrapped"):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, guidance_scale=guidance_scale):
|
||||
super().__init__()
|
||||
@@ -149,12 +153,13 @@ def get_unet16_wrapped(guidance_scale=7.5):
|
||||
torch.tensor([1.0]).half().cuda(),
|
||||
text_embeddings,
|
||||
sigma,
|
||||
model_name,
|
||||
),
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
|
||||
def get_unet32_wrapped(guidance_scale=7.5):
|
||||
def get_unet32_wrapped(guidance_scale=7.5, model_name="unet_fp32_wrapped"):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, guidance_scale=guidance_scale):
|
||||
super().__init__()
|
||||
@@ -186,5 +191,6 @@ def get_unet32_wrapped(guidance_scale=7.5):
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
(latent_model_input, torch.tensor([1.0]), text_embeddings, sigma),
|
||||
model_name,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
@@ -50,5 +50,19 @@ p.add_argument(
|
||||
help="max length of the tokenizer output.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--load_vmfb",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_vmfb",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="saves the compiled flatbuffer to the local directory",
|
||||
)
|
||||
|
||||
|
||||
args = p.parse_args()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from shark.shark_inference import SharkInference
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
@@ -6,6 +8,31 @@ from torch._decomp import get_decompositions
|
||||
import torch_mlir
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
if args.load_vmfb or args.save_vmfb:
|
||||
extended_name = "{}_{}".format(model_name, args.device)
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
|
||||
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
|
||||
print("Loading flatbuffer from {}".format(vmfb_path))
|
||||
shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
if args.save_vmfb:
|
||||
print("Saving to {}".format(vmfb_path))
|
||||
else:
|
||||
print(
|
||||
"No vmfb found. Compiling and saving to {}".format(
|
||||
vmfb_path
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), extended_name, extra_args
|
||||
)
|
||||
shark_module.load_module(path)
|
||||
else:
|
||||
shark_module.compile(extra_args)
|
||||
return shark_module
|
||||
|
||||
|
||||
# Downloads the model from shark_tank and returns the shark_module.
|
||||
def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
from shark.shark_downloader import download_torch_model
|
||||
@@ -16,12 +43,11 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile(extra_args)
|
||||
return shark_module
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
# Converts the torch-module into shark_module.
|
||||
def compile_through_fx(model, inputs, extra_args=[]):
|
||||
def compile_through_fx(model, inputs, model_name, extra_args=[]):
|
||||
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
@@ -75,6 +101,5 @@ def compile_through_fx(model, inputs, extra_args=[]):
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
shark_module.compile(extra_args)
|
||||
|
||||
return shark_module
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
@@ -148,13 +148,15 @@ def export_iree_module_to_vmfb(
|
||||
mlir_dialect: str = "linalg",
|
||||
func_name: str = "forward",
|
||||
model_config_path: str = None,
|
||||
module_name: str = None,
|
||||
extra_args: list = [],
|
||||
):
|
||||
# Compiles the module given specs and saves it as .vmfb file.
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, mlir_dialect, func_name, model_config_path, extra_args
|
||||
)
|
||||
module_name = f"{mlir_dialect}_{func_name}_{device}"
|
||||
if module_name is None:
|
||||
module_name = f"{mlir_dialect}_{func_name}_{device}"
|
||||
filename = os.path.join(directory, module_name + ".vmfb")
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
with open(filename, "wb") as f:
|
||||
|
||||
@@ -146,13 +146,15 @@ class SharkInference:
|
||||
|
||||
# TODO: Instead of passing directory and having names decided by the module
|
||||
# , user may want to save the module with manual names.
|
||||
def save_module(self, dir=os.getcwd()):
|
||||
def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
|
||||
return export_iree_module_to_vmfb(
|
||||
self.mlir_module,
|
||||
self.device,
|
||||
dir,
|
||||
self.mlir_dialect,
|
||||
self.function_name,
|
||||
module_name=module_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
|
||||
# load and return the module.
|
||||
|
||||
Reference in New Issue
Block a user