Add option to save and load precompiled flatbuffer (#425)

Quinn Dawkins
2022-10-23 19:24:09 -04:00
committed by GitHub
parent a48eaaed20
commit 1d33913d48
8 changed files with 75 additions and 15 deletions

cpp/.gitignore vendored Normal file

@@ -0,0 +1,3 @@
*.mlir
*.vmfb
*.ini


@@ -0,0 +1,2 @@
*.vmfb
*.jpg


@@ -24,7 +24,9 @@ UNET_FP32 = "unet_fp32"
def get_models():
    if args.precision == "fp16":
        if args.import_mlir == True:
            return get_vae16(), get_unet16_wrapped()
            return get_vae16(model_name=VAE_FP16), get_unet16_wrapped(
                model_name=UNET_FP16
            )
        else:
            return get_shark_model(
                GCLOUD_BUCKET,
@@ -46,7 +48,9 @@ def get_models():
    elif args.precision == "fp32":
        if args.import_mlir == True:
            return get_vae32(), get_unet32_wrapped()
            return get_vae32(model_name=VAE_FP32), get_unet32_wrapped(
                model_name=UNET_FP32
            )
        else:
            return get_shark_model(
                GCLOUD_BUCKET,
@@ -87,6 +91,7 @@ if __name__ == "__main__":
    batch_size = len(prompt)
    vae, unet = get_models()
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained(
        "openai/clip-vit-large-patch14"
@@ -128,6 +133,7 @@ if __name__ == "__main__":
    )
    scheduler.set_timesteps(num_inference_steps)
    scheduler.is_scale_input_called = True
    latents = latents * scheduler.sigmas[0]
    text_embeddings_numpy = text_embeddings.detach().numpy()
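The names passed into the wrappers above become the cache keys used by the new compile helper further down in this change. A rough, hypothetical illustration; the fp16 constant values are assumed to mirror the UNET_FP32 definition shown above, and "cuda" stands in for whatever --device is actually in use:

VAE_FP16 = "vae_fp16"    # assumed value, following the UNET_FP32 pattern
UNET_FP16 = "unet_fp16"  # assumed value
# _compile_module (introduced below) caches each compiled model as
# "<model_name>_<args.device>.vmfb" in the working directory, e.g.
#   vae_fp16_cuda.vmfb
#   unet_fp16_cuda.vmfb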


@@ -6,7 +6,7 @@ import torch
YOUR_TOKEN = "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"


def get_vae32():
def get_vae32(model_name="vae_fp32"):
    class VaeModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
@@ -25,11 +25,12 @@ def get_vae32():
    shark_vae = compile_through_fx(
        vae,
        (vae_input,),
        model_name,
    )
    return shark_vae


def get_vae16():
def get_vae16(model_name="vae_fp16"):
    class VaeModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
@@ -50,11 +51,12 @@ def get_vae16():
    shark_vae = compile_through_fx(
        vae,
        (vae_input,),
        model_name,
    )
    return shark_vae


def get_unet32():
def get_unet32(model_name="unet_fp32"):
    class UnetModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
@@ -75,11 +77,12 @@ def get_unet32():
    shark_unet = compile_through_fx(
        unet,
        (latent_model_input, torch.tensor([1.0]), text_embeddings),
        model_name,
    )
    return shark_unet


def get_unet16():
def get_unet16(model_name="unet_fp16"):
    class UnetModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
@@ -105,12 +108,13 @@ def get_unet16():
            latent_model_input,
            torch.tensor([1.0]).half().cuda(),
            text_embeddings,
        ),
        model_name,
    )
    return shark_unet


def get_unet16_wrapped(guidance_scale=7.5):
def get_unet16_wrapped(guidance_scale=7.5, model_name="unet_fp16_wrapped"):
    class UnetModel(torch.nn.Module):
        def __init__(self, guidance_scale=guidance_scale):
            super().__init__()
@@ -149,12 +153,13 @@ def get_unet16_wrapped(guidance_scale=7.5):
            torch.tensor([1.0]).half().cuda(),
            text_embeddings,
            sigma,
        ),
        model_name,
    )
    return shark_unet


def get_unet32_wrapped(guidance_scale=7.5):
def get_unet32_wrapped(guidance_scale=7.5, model_name="unet_fp32_wrapped"):
    class UnetModel(torch.nn.Module):
        def __init__(self, guidance_scale=guidance_scale):
            super().__init__()
@@ -186,5 +191,6 @@ def get_unet32_wrapped(guidance_scale=7.5):
    shark_unet = compile_through_fx(
        unet,
        (latent_model_input, torch.tensor([1.0]), text_embeddings, sigma),
        model_name,
    )
    return shark_unet


@@ -50,5 +50,19 @@ p.add_argument(
help="max length of the tokenizer output.",
)
p.add_argument(
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
)
p.add_argument(
"--save_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="saves the compiled flatbuffer to the local directory",
)
args = p.parse_args()
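Because both new flags use argparse.BooleanOptionalAction (Python 3.9+), each also gets an automatically generated negated form. A minimal standalone sketch of the resulting behavior, separate from the commit itself:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--load_vmfb", default=True, action=argparse.BooleanOptionalAction)
p.add_argument("--save_vmfb", default=False, action=argparse.BooleanOptionalAction)

print(p.parse_args([]))                  # Namespace(load_vmfb=True, save_vmfb=False)
print(p.parse_args(["--no-load_vmfb"]))  # skip the cache lookup and just compile
print(p.parse_args(["--save_vmfb"]))     # force a recompile and overwrite the cached .vmfb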


@@ -1,3 +1,5 @@
import os
import torch
from shark.shark_inference import SharkInference
from torch.fx.experimental.proxy_tensor import make_fx
@@ -6,6 +8,31 @@ from torch._decomp import get_decompositions
import torch_mlir
def _compile_module(shark_module, model_name, extra_args=[]):
    if args.load_vmfb or args.save_vmfb:
        extended_name = "{}_{}".format(model_name, args.device)
        vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
        if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
            print("Loading flatbuffer from {}".format(vmfb_path))
            shark_module.load_module(vmfb_path)
        else:
            if args.save_vmfb:
                print("Saving to {}".format(vmfb_path))
            else:
                print(
                    "No vmfb found. Compiling and saving to {}".format(
                        vmfb_path
                    )
                )
            path = shark_module.save_module(
                os.getcwd(), extended_name, extra_args
            )
            shark_module.load_module(path)
    else:
        shark_module.compile(extra_args)
    return shark_module


# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
    from shark.shark_downloader import download_torch_model
@@ -16,12 +43,11 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
    shark_module = SharkInference(
        mlir_model, func_name, device=args.device, mlir_dialect="linalg"
    )
    shark_module.compile(extra_args)
    return shark_module
    return _compile_module(shark_module, model_name, extra_args)


# Converts the torch-module into shark_module.
def compile_through_fx(model, inputs, extra_args=[]):
def compile_through_fx(model, inputs, model_name, extra_args=[]):
    fx_g = make_fx(
        model,
@@ -75,6 +101,5 @@ def compile_through_fx(model, inputs, extra_args=[]):
        device=args.device,
        mlir_dialect="linalg",
    )
    shark_module.compile(extra_args)
    return shark_module
    return _compile_module(shark_module, model_name, extra_args)
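Taken together, each model wrapper now passes its model_name through compile_through_fx, and _compile_module decides whether to load a cached flatbuffer or compile and save a new one. A hedged sketch of a call site and the file it produces; the device string "vulkan" is illustrative and not taken from the commit:

shark_vae = compile_through_fx(
    vae,           # torch.nn.Module built in the wrapper
    (vae_input,),  # sample inputs for make_fx tracing
    "vae_fp16",    # model_name -> cached as ./vae_fp16_vulkan.vmfb
)
# First run (default --load_vmfb, no file yet): prints
#   "No vmfb found. Compiling and saving to .../vae_fp16_vulkan.vmfb"
# Subsequent runs: prints "Loading flatbuffer from ..." and skips compilation.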


@@ -148,13 +148,15 @@ def export_iree_module_to_vmfb(
    mlir_dialect: str = "linalg",
    func_name: str = "forward",
    model_config_path: str = None,
    module_name: str = None,
    extra_args: list = [],
):
    # Compiles the module given specs and saves it as .vmfb file.
    flatbuffer_blob = compile_module_to_flatbuffer(
        module, device, mlir_dialect, func_name, model_config_path, extra_args
    )
    module_name = f"{mlir_dialect}_{func_name}_{device}"
    if module_name is None:
        module_name = f"{mlir_dialect}_{func_name}_{device}"
    filename = os.path.join(directory, module_name + ".vmfb")
    print(f"Saved vmfb in {filename}.")
    with open(filename, "wb") as f:


@@ -146,13 +146,15 @@ class SharkInference:
    # TODO: Instead of passing directory and having names decided by the module
    # , user may want to save the module with manual names.
    def save_module(self, dir=os.getcwd()):
    def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
        return export_iree_module_to_vmfb(
            self.mlir_module,
            self.device,
            dir,
            self.mlir_dialect,
            self.function_name,
            module_name=module_name,
            extra_args=extra_args,
        )

    # load and return the module.
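For completeness, a short sketch of how the widened save_module signature is consumed by the helper added above; the module name here is hypothetical, and when module_name is left as None the old f"{mlir_dialect}_{func_name}_{device}" naming still applies:

# shark_module is a SharkInference instance set up as in get_shark_model.
path = shark_module.save_module(os.getcwd(), "unet_fp16_vulkan", extra_args=[])
shark_module.load_module(path)  # later runs execute from the precompiled flatbuffer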