diff --git a/cpp/.gitignore b/cpp/.gitignore
new file mode 100644
index 00000000..cab16f61
--- /dev/null
+++ b/cpp/.gitignore
@@ -0,0 +1,3 @@
+*.mlir
+*.vmfb
+*.ini
diff --git a/shark/examples/shark_inference/stable_diffusion/.gitignore b/shark/examples/shark_inference/stable_diffusion/.gitignore
new file mode 100644
index 00000000..e01596af
--- /dev/null
+++ b/shark/examples/shark_inference/stable_diffusion/.gitignore
@@ -0,0 +1,2 @@
+*.vmfb
+*.jpg
diff --git a/shark/examples/shark_inference/stable_diffusion/main.py b/shark/examples/shark_inference/stable_diffusion/main.py
index 649e13cf..8abbff5a 100644
--- a/shark/examples/shark_inference/stable_diffusion/main.py
+++ b/shark/examples/shark_inference/stable_diffusion/main.py
@@ -24,7 +24,9 @@ UNET_FP32 = "unet_fp32"
 def get_models():
     if args.precision == "fp16":
         if args.import_mlir == True:
-            return get_vae16(), get_unet16_wrapped()
+            return get_vae16(model_name=VAE_FP16), get_unet16_wrapped(
+                model_name=UNET_FP16
+            )
         else:
             return get_shark_model(
                 GCLOUD_BUCKET,
@@ -46,7 +48,9 @@ def get_models():
 
     elif args.precision == "fp32":
         if args.import_mlir == True:
-            return get_vae32(), get_unet32_wrapped()
+            return get_vae32(model_name=VAE_FP32), get_unet32_wrapped(
+                model_name=UNET_FP32
+            )
         else:
             return get_shark_model(
                 GCLOUD_BUCKET,
@@ -87,6 +91,7 @@ if __name__ == "__main__":
 
     batch_size = len(prompt)
     vae, unet = get_models()
+    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
     text_encoder = CLIPTextModel.from_pretrained(
         "openai/clip-vit-large-patch14"
     )
@@ -128,6 +133,7 @@ if __name__ == "__main__":
     )
 
     scheduler.set_timesteps(num_inference_steps)
+    scheduler.is_scale_input_called = True
     latents = latents * scheduler.sigmas[0]
 
     text_embeddings_numpy = text_embeddings.detach().numpy()
diff --git a/shark/examples/shark_inference/stable_diffusion/model_wrappers.py b/shark/examples/shark_inference/stable_diffusion/model_wrappers.py
index f2c3a275..d8603ef2 100644
--- a/shark/examples/shark_inference/stable_diffusion/model_wrappers.py
+++ b/shark/examples/shark_inference/stable_diffusion/model_wrappers.py
@@ -6,7 +6,7 @@ import torch
 YOUR_TOKEN = "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"
 
 
-def get_vae32():
+def get_vae32(model_name="vae_fp32"):
     class VaeModel(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -25,11 +25,12 @@ def get_vae32():
     shark_vae = compile_through_fx(
         vae,
         (vae_input,),
+        model_name,
     )
     return shark_vae
 
 
-def get_vae16():
+def get_vae16(model_name="vae_fp16"):
     class VaeModel(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -50,11 +51,12 @@ def get_vae16():
     shark_vae = compile_through_fx(
         vae,
         (vae_input,),
+        model_name,
     )
     return shark_vae
 
 
-def get_unet32():
+def get_unet32(model_name="unet_fp32"):
     class UnetModel(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -75,11 +77,12 @@ def get_unet32():
     shark_unet = compile_through_fx(
         unet,
         (latent_model_input, torch.tensor([1.0]), text_embeddings),
+        model_name,
     )
     return shark_unet
 
 
-def get_unet16():
+def get_unet16(model_name="unet_fp16"):
     class UnetModel(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -105,12 +108,13 @@ def get_unet16():
             latent_model_input,
             torch.tensor([1.0]).half().cuda(),
             text_embeddings,
+            model_name,
         ),
     )
     return shark_unet
 
 
-def get_unet16_wrapped(guidance_scale=7.5):
+def get_unet16_wrapped(guidance_scale=7.5, model_name="unet_fp16_wrapped"):
     class UnetModel(torch.nn.Module):
         def __init__(self, guidance_scale=guidance_scale):
             super().__init__()
@@ -149,12 +153,13 @@ def get_unet16_wrapped(guidance_scale=7.5):
             torch.tensor([1.0]).half().cuda(),
             text_embeddings,
             sigma,
+            model_name,
         ),
     )
     return shark_unet
 
 
-def get_unet32_wrapped(guidance_scale=7.5):
+def get_unet32_wrapped(guidance_scale=7.5, model_name="unet_fp32_wrapped"):
     class UnetModel(torch.nn.Module):
         def __init__(self, guidance_scale=guidance_scale):
             super().__init__()
@@ -186,5 +191,6 @@ def get_unet32_wrapped(guidance_scale=7.5):
     shark_unet = compile_through_fx(
         unet,
         (latent_model_input, torch.tensor([1.0]), text_embeddings, sigma),
+        model_name,
     )
     return shark_unet
diff --git a/shark/examples/shark_inference/stable_diffusion/stable_args.py b/shark/examples/shark_inference/stable_diffusion/stable_args.py
index 1cf65998..68daaabc 100644
--- a/shark/examples/shark_inference/stable_diffusion/stable_args.py
+++ b/shark/examples/shark_inference/stable_diffusion/stable_args.py
@@ -50,5 +50,19 @@ p.add_argument(
     help="max length of the tokenizer output.",
 )
 
+p.add_argument(
+    "--load_vmfb",
+    default=True,
+    action=argparse.BooleanOptionalAction,
+    help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
+)
+
+p.add_argument(
+    "--save_vmfb",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="saves the compiled flatbuffer to the local directory",
+)
+
 args = p.parse_args()
diff --git a/shark/examples/shark_inference/stable_diffusion/utils.py b/shark/examples/shark_inference/stable_diffusion/utils.py
index 7361c2ed..9429ac18 100644
--- a/shark/examples/shark_inference/stable_diffusion/utils.py
+++ b/shark/examples/shark_inference/stable_diffusion/utils.py
@@ -1,3 +1,5 @@
+import os
+
 import torch
 from shark.shark_inference import SharkInference
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -6,6 +8,31 @@ from torch._decomp import get_decompositions
 import torch_mlir
 
 
+def _compile_module(shark_module, model_name, extra_args=[]):
+    if args.load_vmfb or args.save_vmfb:
+        extended_name = "{}_{}".format(model_name, args.device)
+        vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
+        if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
+            print("Loading flatbuffer from {}".format(vmfb_path))
+            shark_module.load_module(vmfb_path)
+        else:
+            if args.save_vmfb:
+                print("Saving to {}".format(vmfb_path))
+            else:
+                print(
+                    "No vmfb found. Compiling and saving to {}".format(
+                        vmfb_path
+                    )
+                )
+            path = shark_module.save_module(
+                os.getcwd(), extended_name, extra_args
+            )
+            shark_module.load_module(path)
+    else:
+        shark_module.compile(extra_args)
+    return shark_module
+
+
 # Downloads the model from shark_tank and returns the shark_module.
 def get_shark_model(tank_url, model_name, extra_args=[]):
     from shark.shark_downloader import download_torch_model
@@ -16,12 +43,11 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
     shark_module = SharkInference(
         mlir_model, func_name, device=args.device, mlir_dialect="linalg"
     )
-    shark_module.compile(extra_args)
-    return shark_module
+    return _compile_module(shark_module, model_name, extra_args)
 
 
 # Converts the torch-module into shark_module.
-def compile_through_fx(model, inputs, extra_args=[]):
+def compile_through_fx(model, inputs, model_name, extra_args=[]):
 
     fx_g = make_fx(
         model,
@@ -75,6 +101,5 @@ def compile_through_fx(model, inputs, extra_args=[]):
         device=args.device,
         mlir_dialect="linalg",
     )
-    shark_module.compile(extra_args)
-    return shark_module
+    return _compile_module(shark_module, model_name, extra_args)
 
diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py
index 4fe56549..2acae3b9 100644
--- a/shark/iree_utils/compile_utils.py
+++ b/shark/iree_utils/compile_utils.py
@@ -148,13 +148,15 @@ def export_iree_module_to_vmfb(
     mlir_dialect: str = "linalg",
     func_name: str = "forward",
     model_config_path: str = None,
+    module_name: str = None,
     extra_args: list = [],
 ):
     # Compiles the module given specs and saves it as .vmfb file.
     flatbuffer_blob = compile_module_to_flatbuffer(
         module, device, mlir_dialect, func_name, model_config_path, extra_args
     )
-    module_name = f"{mlir_dialect}_{func_name}_{device}"
+    if module_name is None:
+        module_name = f"{mlir_dialect}_{func_name}_{device}"
     filename = os.path.join(directory, module_name + ".vmfb")
     print(f"Saved vmfb in {filename}.")
     with open(filename, "wb") as f:
diff --git a/shark/shark_inference.py b/shark/shark_inference.py
index 0fd67284..4f3a53bc 100644
--- a/shark/shark_inference.py
+++ b/shark/shark_inference.py
@@ -146,13 +146,15 @@ class SharkInference:
 
     # TODO: Instead of passing directory and having names decided by the module
     # , user may want to save the module with manual names.
-    def save_module(self, dir=os.getcwd()):
+    def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
        return export_iree_module_to_vmfb(
            self.mlir_module,
            self.device,
            dir,
            self.mlir_dialect,
            self.function_name,
+           module_name=module_name,
+           extra_args=extra_args,
        )

    # load and return the module.
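
Note on the caching flow introduced in utils.py: _compile_module keys the cached artifact as "<model_name>_<device>.vmfb" in the current working directory, reuses it when --load_vmfb (the default) finds a match, and otherwise compiles, saves via save_module, and loads the freshly written file. Below is a minimal standalone sketch of that decision, not part of the diff; FakeArgs and FakeModule are hypothetical stand-ins for the parsed stable_args namespace and SharkInference.

    # Sketch only: mirrors the load-vs-compile branching of _compile_module.
    # FakeArgs/FakeModule are hypothetical stand-ins, not SHARK APIs.
    import os


    class FakeArgs:
        load_vmfb = True   # --load_vmfb / --no-load_vmfb (default True)
        save_vmfb = False  # --save_vmfb / --no-save_vmfb (default False)
        device = "vulkan"  # whatever --device was set to


    class FakeModule:
        def load_module(self, path):
            print("loading cached flatbuffer:", path)

        def save_module(self, directory, name, extra_args=[]):
            path = os.path.join(directory, name + ".vmfb")
            print("compiling and saving flatbuffer:", path)
            return path

        def compile(self, extra_args=[]):
            print("compiling in memory, nothing cached")


    def compile_or_load(shark_module, model_name, args, extra_args=[]):
        # Cache file is "<model_name>_<device>.vmfb" in the working directory.
        if args.load_vmfb or args.save_vmfb:
            extended_name = "{}_{}".format(model_name, args.device)
            vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
            if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
                shark_module.load_module(vmfb_path)  # reuse existing .vmfb
            else:
                path = shark_module.save_module(
                    os.getcwd(), extended_name, extra_args
                )
                shark_module.load_module(path)  # compile once, then load
        else:
            shark_module.compile(extra_args)  # no caching requested
        return shark_module


    if __name__ == "__main__":
        # First call compiles and saves; rerunning with the file present loads it.
        compile_or_load(FakeModule(), "unet_fp16_wrapped", FakeArgs())

Assuming the example's existing --precision, --device, and --import_mlir flags keep their current form, a first run with --save_vmfb would write files such as unet_fp16_wrapped_vulkan.vmfb next to main.py, and later runs would pick them up without recompiling.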