[WEB] Save vmfb and add live preview

This commit updates the SD script to save the compiled module and also adds
a live preview of generated images.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2022-10-26 19:56:23 +05:30
parent fbd77dc936
commit e52f533c16
4 changed files with 116 additions and 37 deletions

View File

@@ -135,6 +135,12 @@ with gr.Blocks() as shark_web:
precision
) = (
device
) = (
load_vmfb
) = (
save_vmfb
) = (
iree_vulkan_target_triple
) = debug = stable_diffusion = generated_img = std_output = None
with gr.Row():
with gr.Column(scale=1, min_width=600):
@@ -186,6 +192,13 @@ with gr.Blocks() as shark_web:
value="vulkan",
choices=["cpu", "cuda", "vulkan"],
)
load_vmfb = gr.Checkbox(label="Load vmfb", value=True)
save_vmfb = gr.Checkbox(label="Save vmfb", value=False)
iree_vulkan_target_triple = gr.Textbox(
value="",
max_lines=1,
label="IREE VULKAN TARGET TRIPLE",
)
debug = gr.Checkbox(label="DEBUG", value=False)
stable_diffusion = gr.Button("Generate image from prompt")
with gr.Column(scale=1, min_width=600):
@@ -216,8 +229,12 @@ with gr.Blocks() as shark_web:
seed,
precision,
device,
load_vmfb,
save_vmfb,
iree_vulkan_target_triple,
],
outputs=[generated_img, std_output],
)
shark_web.queue()
shark_web.launch(share=True, server_port=8080, enable_queue=True)

View File

@@ -20,6 +20,7 @@ VAE_FP32 = "vae_fp32"
UNET_FP16 = "unet_fp16"
UNET_FP32 = "unet_fp32"
IREE_EXTRA_ARGS = []
args = None
DEBUG = False
@@ -38,6 +39,9 @@ class Arguments:
seed: int,
precision: str,
device: str,
load_vmfb: bool,
save_vmfb: bool,
iree_vulkan_target_triple: str,
import_mlir: bool = False,
max_length: int = 77,
):
@@ -52,34 +56,47 @@ class Arguments:
self.seed = seed
self.precision = precision
self.device = device
self.load_vmfb = load_vmfb
self.save_vmfb = save_vmfb
self.iree_vulkan_target_triple = iree_vulkan_target_triple
self.import_mlir = import_mlir
self.max_length = max_length
def get_models():
global IREE_EXTRA_ARGS
global args
if args.precision == "fp16":
IREE_EXTRA_ARGS += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-iterator-space-fusion",
]
if args.import_mlir == True:
return get_vae16(args), get_unet16_wrapped(args)
return get_shark_model(args, GCLOUD_BUCKET, VAE_FP16), get_shark_model(
args, GCLOUD_BUCKET, UNET_FP16
)
return get_vae16(args, model_name=VAE_FP16), get_unet16_wrapped(
args, model_name=UNET_FP16
)
return get_shark_model(
args, GCLOUD_BUCKET, VAE_FP16, IREE_EXTRA_ARGS
), get_shark_model(args, GCLOUD_BUCKET, UNET_FP16, IREE_EXTRA_ARGS)
elif args.precision == "fp32":
IREE_EXTRA_ARGS += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir == True:
return (get_vae32(args), get_unet32_wrapped(args))
return get_shark_model(args, GCLOUD_BUCKET, VAE_FP32), get_shark_model(
args,
GCLOUD_BUCKET,
UNET_FP32,
[
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
],
)
return (
get_vae32(args, model_name=VAE_FP32),
get_unet32_wrapped(args, model_name=UNET_FP32),
)
return get_shark_model(
args, GCLOUD_BUCKET, VAE_FP32, IREE_EXTRA_ARGS
), get_shark_model(args, GCLOUD_BUCKET, UNET_FP32, IREE_EXTRA_ARGS)
return None, None
@@ -95,8 +112,12 @@ def stable_diff_inf(
seed: str,
precision: str,
device: str,
load_vmfb: bool,
save_vmfb: bool,
iree_vulkan_target_triple: str,
):
global IREE_EXTRA_ARGS
global args
global DEBUG
@@ -155,8 +176,15 @@ def stable_diff_inf(
seed,
precision,
device,
load_vmfb,
save_vmfb,
iree_vulkan_target_triple,
)
dtype = torch.float32 if args.precision == "fp32" else torch.half
if len(args.iree_vulkan_target_triple) > 0:
IREE_EXTRA_ARGS.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
num_inference_steps = int(args.steps) # Number of denoising steps
generator = torch.manual_seed(
args.seed
@@ -191,17 +219,20 @@ def stable_diff_inf(
latents = torch.randn(
(batch_size, 4, args.height // 8, args.width // 8),
generator=generator,
dtype=dtype,
)
dtype=torch.float32,
).to(dtype)
scheduler.set_timesteps(num_inference_steps)
scheduler.is_scale_input_called = True
latents = latents * scheduler.sigmas[0]
text_embeddings_numpy = text_embeddings.detach().numpy()
avg_ms = 0
pil_images = []
for i, t in tqdm(enumerate(scheduler.timesteps)):
time.sleep(0.1)
if DEBUG:
log_write.write(f"\ni = {i} t = {t} ")
step_start = time.time()
@@ -219,19 +250,20 @@ def stable_diff_inf(
if DEBUG:
log_write.write(f"time={step_ms}ms")
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
latents_numpy = latents.detach().numpy()
image = vae.forward((latents_numpy,))
image = torch.from_numpy(image)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
yield pil_images[0], ""
avg_ms = 1000 * avg_ms / args.steps
if DEBUG:
log_write.write(f"\nAverage step time: {avg_ms}ms/it")
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
latents_numpy = latents.detach().numpy()
image = vae.forward((latents_numpy,))
image = torch.from_numpy(image)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
print("total images:", len(pil_images))
output = pil_images[0]
# save the output image with the prompt name.
@@ -241,4 +273,4 @@ def stable_diff_inf(
std_output = ""
with open(r"logs/stable_diffusion_log.txt", "r") as log_read:
std_output = log_read.read()
return pil_images[0], std_output
yield output, std_output

View File

@@ -5,7 +5,7 @@ import torch
YOUR_TOKEN = "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"
def get_vae32(args):
def get_vae32(args, model_name="vae_fp32"):
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -25,11 +25,12 @@ def get_vae32(args):
args,
vae,
(vae_input,),
model_name,
)
return shark_vae
def get_vae16(args):
def get_vae16(args, model_name="vae_fp16"):
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -51,11 +52,12 @@ def get_vae16(args):
args,
vae,
(vae_input,),
model_name,
)
return shark_vae
def get_unet32(args):
def get_unet32(args, model_name="unet_fp32"):
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -77,11 +79,12 @@ def get_unet32(args):
args,
unet,
(latent_model_input, torch.tensor([1.0]), text_embeddings),
model_name,
)
return shark_unet
def get_unet16(args):
def get_unet16(args, model_name="unet_fp16"):
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -109,11 +112,12 @@ def get_unet16(args):
torch.tensor([1.0]).half().cuda(),
text_embeddings,
),
model_name,
)
return shark_unet
def get_unet16_wrapped(args):
def get_unet16_wrapped(args, model_name="unet_fp16_wrapped"):
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -154,11 +158,12 @@ def get_unet16_wrapped(args):
text_embeddings,
sigma,
),
model_name,
)
return shark_unet
def get_unet32_wrapped(args):
def get_unet32_wrapped(args, model_name="unet_fp32_wrapped"):
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -191,5 +196,6 @@ def get_unet32_wrapped(args):
args,
unet,
(latent_model_input, torch.tensor([1.0]), text_embeddings, sigma),
model_name,
)
return shark_unet

View File

@@ -3,6 +3,32 @@ from shark.shark_inference import SharkInference
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import torch_mlir
import os
def _compile_module(args, shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
extended_name = "{}_{}".format(model_name, args.device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print("Loading flatbuffer from {}".format(vmfb_path))
shark_module.load_module(vmfb_path)
else:
if args.save_vmfb:
print("Saving to {}".format(vmfb_path))
else:
print(
"No vmfb found. Compiling and saving to {}".format(
vmfb_path
)
)
path = shark_module.save_module(
os.getcwd(), extended_name, extra_args
)
shark_module.load_module(path)
else:
shark_module.compile(extra_args)
return shark_module
# Downloads the model from shark_tank and returns the shark_module.
@@ -15,12 +41,11 @@ def get_shark_model(args, tank_url, model_name, extra_args=[]):
shark_module = SharkInference(
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
)
shark_module.compile(extra_args)
return shark_module
return _compile_module(args, shark_module, model_name, extra_args)
# Converts the torch-module into shark_module.
def compile_through_fx(args, model, inputs, extra_args=[]):
def compile_through_fx(args, model, inputs, model_name, extra_args=[]):
fx_g = make_fx(
model,
@@ -74,6 +99,5 @@ def compile_through_fx(args, model, inputs, extra_args=[]):
device=args.device,
mlir_dialect="linalg",
)
shark_module.compile(extra_args)
return shark_module
return _compile_module(args, shark_module, model_name, extra_args)