Mirror of https://github.com/nod-ai/SHARK-Studio.git, synced 2026-01-09 22:07:55 -05:00
Add option to save and load precompiled flatbuffer (#425)
cpp/.gitignore (new file, vendored, 3 additions)
@@ -0,0 +1,3 @@
+*.mlir
+*.vmfb
+*.ini
shark/examples/shark_inference/stable_diffusion/.gitignore (new file, vendored, 2 additions)
@@ -0,0 +1,2 @@
+*.vmfb
+*.jpg
@@ -24,7 +24,9 @@ UNET_FP32 = "unet_fp32"
|
|||||||
def get_models():
|
def get_models():
|
||||||
if args.precision == "fp16":
|
if args.precision == "fp16":
|
||||||
if args.import_mlir == True:
|
if args.import_mlir == True:
|
||||||
return get_vae16(), get_unet16_wrapped()
|
return get_vae16(model_name=VAE_FP16), get_unet16_wrapped(
|
||||||
|
model_name=UNET_FP16
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return get_shark_model(
|
return get_shark_model(
|
||||||
GCLOUD_BUCKET,
|
GCLOUD_BUCKET,
|
||||||
@@ -46,7 +48,9 @@ def get_models():
 
     elif args.precision == "fp32":
         if args.import_mlir == True:
-            return get_vae32(), get_unet32_wrapped()
+            return get_vae32(model_name=VAE_FP32), get_unet32_wrapped(
+                model_name=UNET_FP32
+            )
         else:
             return get_shark_model(
                 GCLOUD_BUCKET,
@@ -87,6 +91,7 @@ if __name__ == "__main__":
     batch_size = len(prompt)
 
     vae, unet = get_models()
+
     tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
     text_encoder = CLIPTextModel.from_pretrained(
         "openai/clip-vit-large-patch14"
@@ -128,6 +133,7 @@ if __name__ == "__main__":
     )
 
     scheduler.set_timesteps(num_inference_steps)
+    scheduler.is_scale_input_called = True
 
     latents = latents * scheduler.sigmas[0]
     text_embeddings_numpy = text_embeddings.detach().numpy()
@@ -6,7 +6,7 @@ import torch
 YOUR_TOKEN = "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"
 
 
-def get_vae32():
+def get_vae32(model_name="vae_fp32"):
     class VaeModel(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -25,11 +25,12 @@ def get_vae32():
     shark_vae = compile_through_fx(
         vae,
         (vae_input,),
+        model_name,
     )
     return shark_vae
 
 
-def get_vae16():
+def get_vae16(model_name="vae_fp16"):
     class VaeModel(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -50,11 +51,12 @@ def get_vae16():
     shark_vae = compile_through_fx(
         vae,
         (vae_input,),
+        model_name,
     )
     return shark_vae
 
 
-def get_unet32():
+def get_unet32(model_name="unet_fp32"):
     class UnetModel(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -75,11 +77,12 @@ def get_unet32():
     shark_unet = compile_through_fx(
         unet,
         (latent_model_input, torch.tensor([1.0]), text_embeddings),
+        model_name,
     )
     return shark_unet
 
 
-def get_unet16():
+def get_unet16(model_name="unet_fp16"):
     class UnetModel(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -105,12 +108,13 @@ def get_unet16():
             latent_model_input,
             torch.tensor([1.0]).half().cuda(),
             text_embeddings,
+            model_name,
         ),
     )
     return shark_unet
 
 
-def get_unet16_wrapped(guidance_scale=7.5):
+def get_unet16_wrapped(guidance_scale=7.5, model_name="unet_fp16_wrapped"):
     class UnetModel(torch.nn.Module):
         def __init__(self, guidance_scale=guidance_scale):
             super().__init__()
@@ -149,12 +153,13 @@ def get_unet16_wrapped(guidance_scale=7.5):
             torch.tensor([1.0]).half().cuda(),
             text_embeddings,
             sigma,
+            model_name,
         ),
     )
     return shark_unet
 
 
-def get_unet32_wrapped(guidance_scale=7.5):
+def get_unet32_wrapped(guidance_scale=7.5, model_name="unet_fp32_wrapped"):
     class UnetModel(torch.nn.Module):
         def __init__(self, guidance_scale=guidance_scale):
             super().__init__()
@@ -186,5 +191,6 @@ def get_unet32_wrapped(guidance_scale=7.5):
     shark_unet = compile_through_fx(
         unet,
         (latent_model_input, torch.tensor([1.0]), text_embeddings, sigma),
+        model_name,
     )
     return shark_unet
@@ -50,5 +50,19 @@ p.add_argument(
     help="max length of the tokenizer output.",
 )
 
+p.add_argument(
+    "--load_vmfb",
+    default=True,
+    action=argparse.BooleanOptionalAction,
+    help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
+)
+
+p.add_argument(
+    "--save_vmfb",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="saves the compiled flatbuffer to the local directory",
+)
+
 
 args = p.parse_args()
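Because both flags use argparse.BooleanOptionalAction, each one also gets an automatic negated form. A minimal, self-contained sketch of the resulting CLI behavior (the parser name p matches the diff; the argv values below are illustrative):

import argparse

p = argparse.ArgumentParser()
p.add_argument("--load_vmfb", default=True, action=argparse.BooleanOptionalAction)
p.add_argument("--save_vmfb", default=False, action=argparse.BooleanOptionalAction)

print(p.parse_args([]))                  # Namespace(load_vmfb=True, save_vmfb=False)
print(p.parse_args(["--no-load_vmfb"]))  # force a fresh compile, never read a cached .vmfb
print(p.parse_args(["--save_vmfb"]))     # compile and write the flatbuffer to the local directory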
@@ -1,3 +1,5 @@
+import os
+
 import torch
 from shark.shark_inference import SharkInference
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -6,6 +8,31 @@ from torch._decomp import get_decompositions
 import torch_mlir
 
 
+def _compile_module(shark_module, model_name, extra_args=[]):
+    if args.load_vmfb or args.save_vmfb:
+        extended_name = "{}_{}".format(model_name, args.device)
+        vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
+        if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
+            print("Loading flatbuffer from {}".format(vmfb_path))
+            shark_module.load_module(vmfb_path)
+        else:
+            if args.save_vmfb:
+                print("Saving to {}".format(vmfb_path))
+            else:
+                print(
+                    "No vmfb found. Compiling and saving to {}".format(
+                        vmfb_path
+                    )
+                )
+            path = shark_module.save_module(
+                os.getcwd(), extended_name, extra_args
+            )
+            shark_module.load_module(path)
+    else:
+        shark_module.compile(extra_args)
+    return shark_module
+
+
 # Downloads the model from shark_tank and returns the shark_module.
 def get_shark_model(tank_url, model_name, extra_args=[]):
     from shark.shark_downloader import download_torch_model
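The cache key produced by _compile_module is simply the model name suffixed with the target device, stored in the current working directory. A small runnable sketch with illustrative values (the device string is an assumption; at runtime it comes from args.device):

import os

model_name = "vae_fp16"   # default name from get_vae16 above
device = "cuda"           # illustrative; taken from args.device in the real flow
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
print(vmfb_path)  # e.g. <cwd>/vae_fp16_cuda.vmfb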
@@ -16,12 +43,11 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
     shark_module = SharkInference(
         mlir_model, func_name, device=args.device, mlir_dialect="linalg"
     )
-    shark_module.compile(extra_args)
-    return shark_module
+    return _compile_module(shark_module, model_name, extra_args)
 
 
 # Converts the torch-module into shark_module.
-def compile_through_fx(model, inputs, extra_args=[]):
+def compile_through_fx(model, inputs, model_name, extra_args=[]):
 
     fx_g = make_fx(
         model,
@@ -75,6 +101,5 @@ def compile_through_fx(model, inputs, extra_args=[]):
         device=args.device,
         mlir_dialect="linalg",
     )
-    shark_module.compile(extra_args)
 
-    return shark_module
+    return _compile_module(shark_module, model_name, extra_args)
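Since model_name is now a required positional parameter, every call into compile_through_fx must name the artifact it produces so the vmfb cache key is well defined. A hedged usage sketch (the import path of this helper is not shown in the diff, so the module name below is an assumption; the toy model and the name "doubler_fp32" are illustrative):

import torch
from utils import compile_through_fx  # assumed location of the helper changed above

class Doubler(torch.nn.Module):
    def forward(self, x):
        return x * 2

# The third argument becomes the flatbuffer cache key, e.g. doubler_fp32_<device>.vmfb.
shark_doubler = compile_through_fx(Doubler(), (torch.randn(1, 4),), "doubler_fp32")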
@@ -148,12 +148,14 @@ def export_iree_module_to_vmfb(
     mlir_dialect: str = "linalg",
     func_name: str = "forward",
     model_config_path: str = None,
+    module_name: str = None,
     extra_args: list = [],
 ):
     # Compiles the module given specs and saves it as .vmfb file.
     flatbuffer_blob = compile_module_to_flatbuffer(
         module, device, mlir_dialect, func_name, model_config_path, extra_args
     )
-    module_name = f"{mlir_dialect}_{func_name}_{device}"
+    if module_name is None:
+        module_name = f"{mlir_dialect}_{func_name}_{device}"
     filename = os.path.join(directory, module_name + ".vmfb")
     print(f"Saved vmfb in {filename}.")
@@ -146,13 +146,15 @@ class SharkInference:
 
     # TODO: Instead of passing directory and having names decided by the module
     # , user may want to save the module with manual names.
-    def save_module(self, dir=os.getcwd()):
+    def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
         return export_iree_module_to_vmfb(
             self.mlir_module,
             self.device,
             dir,
             self.mlir_dialect,
             self.function_name,
+            module_name=module_name,
+            extra_args=extra_args,
         )
 
     # load and return the module.
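Taken together, the extended save_module and the existing load_module give callers an explicit save-then-load round trip. A hedged sketch, not runnable as-is: mlir_model and func_name must come from download_torch_model or from the FX tracing path, the device string is illustrative, and the module name "unet_fp16_cpu" is a made-up example.

import os
from shark.shark_inference import SharkInference

shark_module = SharkInference(
    mlir_model, func_name, device="cpu", mlir_dialect="linalg"
)
# Save under a caller-chosen name instead of the old f"{dialect}_{func}_{device}" default.
path = shark_module.save_module(os.getcwd(), "unet_fp16_cpu", [])
shark_module.load_module(path)  # later runs can load this path and skip compilation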