mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[WEB] Save vmfb and add live preview
This commit updates the SD script to save the compiled module and also adds a live preview of generated images. Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
17
web/index.py
17
web/index.py
@@ -135,6 +135,12 @@ with gr.Blocks() as shark_web:
|
||||
precision
|
||||
) = (
|
||||
device
|
||||
) = (
|
||||
load_vmfb
|
||||
) = (
|
||||
save_vmfb
|
||||
) = (
|
||||
iree_vulkan_target_triple
|
||||
) = debug = stable_diffusion = generated_img = std_output = None
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
@@ -186,6 +192,13 @@ with gr.Blocks() as shark_web:
|
||||
value="vulkan",
|
||||
choices=["cpu", "cuda", "vulkan"],
|
||||
)
|
||||
load_vmfb = gr.Checkbox(label="Load vmfb", value=True)
|
||||
save_vmfb = gr.Checkbox(label="Save vmfb", value=False)
|
||||
iree_vulkan_target_triple = gr.Textbox(
|
||||
value="",
|
||||
max_lines=1,
|
||||
label="IREE VULKAN TARGET TRIPLE",
|
||||
)
|
||||
debug = gr.Checkbox(label="DEBUG", value=False)
|
||||
stable_diffusion = gr.Button("Generate image from prompt")
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
@@ -216,8 +229,12 @@ with gr.Blocks() as shark_web:
|
||||
seed,
|
||||
precision,
|
||||
device,
|
||||
load_vmfb,
|
||||
save_vmfb,
|
||||
iree_vulkan_target_triple,
|
||||
],
|
||||
outputs=[generated_img, std_output],
|
||||
)
|
||||
|
||||
shark_web.queue()
|
||||
shark_web.launch(share=True, server_port=8080, enable_queue=True)
|
||||
|
||||
@@ -20,6 +20,7 @@ VAE_FP32 = "vae_fp32"
|
||||
UNET_FP16 = "unet_fp16"
|
||||
UNET_FP32 = "unet_fp32"
|
||||
|
||||
IREE_EXTRA_ARGS = []
|
||||
args = None
|
||||
DEBUG = False
|
||||
|
||||
@@ -38,6 +39,9 @@ class Arguments:
|
||||
seed: int,
|
||||
precision: str,
|
||||
device: str,
|
||||
load_vmfb: bool,
|
||||
save_vmfb: bool,
|
||||
iree_vulkan_target_triple: str,
|
||||
import_mlir: bool = False,
|
||||
max_length: int = 77,
|
||||
):
|
||||
@@ -52,34 +56,47 @@ class Arguments:
|
||||
self.seed = seed
|
||||
self.precision = precision
|
||||
self.device = device
|
||||
self.load_vmfb = load_vmfb
|
||||
self.save_vmfb = save_vmfb
|
||||
self.iree_vulkan_target_triple = iree_vulkan_target_triple
|
||||
self.import_mlir = import_mlir
|
||||
self.max_length = max_length
|
||||
|
||||
|
||||
def get_models():
|
||||
|
||||
global IREE_EXTRA_ARGS
|
||||
global args
|
||||
|
||||
if args.precision == "fp16":
|
||||
IREE_EXTRA_ARGS += [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
"--iree-flow-enable-iterator-space-fusion",
|
||||
]
|
||||
if args.import_mlir == True:
|
||||
return get_vae16(args), get_unet16_wrapped(args)
|
||||
return get_shark_model(args, GCLOUD_BUCKET, VAE_FP16), get_shark_model(
|
||||
args, GCLOUD_BUCKET, UNET_FP16
|
||||
)
|
||||
return get_vae16(args, model_name=VAE_FP16), get_unet16_wrapped(
|
||||
args, model_name=UNET_FP16
|
||||
)
|
||||
return get_shark_model(
|
||||
args, GCLOUD_BUCKET, VAE_FP16, IREE_EXTRA_ARGS
|
||||
), get_shark_model(args, GCLOUD_BUCKET, UNET_FP16, IREE_EXTRA_ARGS)
|
||||
|
||||
elif args.precision == "fp32":
|
||||
IREE_EXTRA_ARGS += [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
]
|
||||
if args.import_mlir == True:
|
||||
return (get_vae32(args), get_unet32_wrapped(args))
|
||||
return get_shark_model(args, GCLOUD_BUCKET, VAE_FP32), get_shark_model(
|
||||
args,
|
||||
GCLOUD_BUCKET,
|
||||
UNET_FP32,
|
||||
[
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
],
|
||||
)
|
||||
return (
|
||||
get_vae32(args, model_name=VAE_FP32),
|
||||
get_unet32_wrapped(args, model_name=UNET_FP32),
|
||||
)
|
||||
return get_shark_model(
|
||||
args, GCLOUD_BUCKET, VAE_FP32, IREE_EXTRA_ARGS
|
||||
), get_shark_model(args, GCLOUD_BUCKET, UNET_FP32, IREE_EXTRA_ARGS)
|
||||
return None, None
|
||||
|
||||
|
||||
@@ -95,8 +112,12 @@ def stable_diff_inf(
|
||||
seed: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
load_vmfb: bool,
|
||||
save_vmfb: bool,
|
||||
iree_vulkan_target_triple: str,
|
||||
):
|
||||
|
||||
global IREE_EXTRA_ARGS
|
||||
global args
|
||||
global DEBUG
|
||||
|
||||
@@ -155,8 +176,15 @@ def stable_diff_inf(
|
||||
seed,
|
||||
precision,
|
||||
device,
|
||||
load_vmfb,
|
||||
save_vmfb,
|
||||
iree_vulkan_target_triple,
|
||||
)
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
IREE_EXTRA_ARGS.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
num_inference_steps = int(args.steps) # Number of denoising steps
|
||||
generator = torch.manual_seed(
|
||||
args.seed
|
||||
@@ -191,17 +219,20 @@ def stable_diff_inf(
|
||||
latents = torch.randn(
|
||||
(batch_size, 4, args.height // 8, args.width // 8),
|
||||
generator=generator,
|
||||
dtype=dtype,
|
||||
)
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
scheduler.is_scale_input_called = True
|
||||
|
||||
latents = latents * scheduler.sigmas[0]
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
|
||||
avg_ms = 0
|
||||
pil_images = []
|
||||
for i, t in tqdm(enumerate(scheduler.timesteps)):
|
||||
|
||||
time.sleep(0.1)
|
||||
if DEBUG:
|
||||
log_write.write(f"\ni = {i} t = {t} ")
|
||||
step_start = time.time()
|
||||
@@ -219,19 +250,20 @@ def stable_diff_inf(
|
||||
if DEBUG:
|
||||
log_write.write(f"time={step_ms}ms")
|
||||
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
latents_numpy = latents.detach().numpy()
|
||||
image = vae.forward((latents_numpy,))
|
||||
image = torch.from_numpy(image)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
images = (image * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
yield pil_images[0], ""
|
||||
|
||||
avg_ms = 1000 * avg_ms / args.steps
|
||||
if DEBUG:
|
||||
log_write.write(f"\nAverage step time: {avg_ms}ms/it")
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
latents_numpy = latents.detach().numpy()
|
||||
image = vae.forward((latents_numpy,))
|
||||
image = torch.from_numpy(image)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
images = (image * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
print("total images:", len(pil_images))
|
||||
output = pil_images[0]
|
||||
# save the output image with the prompt name.
|
||||
@@ -241,4 +273,4 @@ def stable_diff_inf(
|
||||
std_output = ""
|
||||
with open(r"logs/stable_diffusion_log.txt", "r") as log_read:
|
||||
std_output = log_read.read()
|
||||
return pil_images[0], std_output
|
||||
yield output, std_output
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch
|
||||
YOUR_TOKEN = "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"
|
||||
|
||||
|
||||
def get_vae32(args):
|
||||
def get_vae32(args, model_name="vae_fp32"):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -25,11 +25,12 @@ def get_vae32(args):
|
||||
args,
|
||||
vae,
|
||||
(vae_input,),
|
||||
model_name,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_vae16(args):
|
||||
def get_vae16(args, model_name="vae_fp16"):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -51,11 +52,12 @@ def get_vae16(args):
|
||||
args,
|
||||
vae,
|
||||
(vae_input,),
|
||||
model_name,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_unet32(args):
|
||||
def get_unet32(args, model_name="unet_fp32"):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -77,11 +79,12 @@ def get_unet32(args):
|
||||
args,
|
||||
unet,
|
||||
(latent_model_input, torch.tensor([1.0]), text_embeddings),
|
||||
model_name,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
|
||||
def get_unet16(args):
|
||||
def get_unet16(args, model_name="unet_fp16"):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -109,11 +112,12 @@ def get_unet16(args):
|
||||
torch.tensor([1.0]).half().cuda(),
|
||||
text_embeddings,
|
||||
),
|
||||
model_name,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
|
||||
def get_unet16_wrapped(args):
|
||||
def get_unet16_wrapped(args, model_name="unet_fp16_wrapped"):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -154,11 +158,12 @@ def get_unet16_wrapped(args):
|
||||
text_embeddings,
|
||||
sigma,
|
||||
),
|
||||
model_name,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
|
||||
def get_unet32_wrapped(args):
|
||||
def get_unet32_wrapped(args, model_name="unet_fp32_wrapped"):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -191,5 +196,6 @@ def get_unet32_wrapped(args):
|
||||
args,
|
||||
unet,
|
||||
(latent_model_input, torch.tensor([1.0]), text_embeddings, sigma),
|
||||
model_name,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
@@ -3,6 +3,32 @@ from shark.shark_inference import SharkInference
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
import torch_mlir
|
||||
import os
|
||||
|
||||
|
||||
def _compile_module(args, shark_module, model_name, extra_args=None):
    """Prepare ``shark_module`` for use, optionally via a ``.vmfb`` file on disk.

    The ``.vmfb`` path is ``<cwd>/<model_name>_<args.device>.vmfb``.

    Behavior, decided by two flags on ``args``:

    * neither ``load_vmfb`` nor ``save_vmfb``: call
      ``shark_module.compile(extra_args)`` and touch no files.
    * ``load_vmfb`` set, ``save_vmfb`` unset, and the file exists: load it
      with ``shark_module.load_module``.
    * otherwise (``save_vmfb`` set, or nothing to load): call
      ``shark_module.save_module`` and load whatever path it returns.

    Args:
        args: object exposing ``load_vmfb``, ``save_vmfb`` and ``device``.
        shark_module: object exposing ``compile``, ``save_module`` and
            ``load_module``.
        model_name: base name used to build the ``.vmfb`` file name.
        extra_args: optional list of extra flags forwarded unchanged to
            ``compile`` / ``save_module``. ``None`` means "no extra flags".

    Returns:
        The same ``shark_module`` object that was passed in.
    """
    # Use a fresh list per call instead of a shared mutable default, so a
    # callee that mutates the flags cannot leak them into later calls.
    if extra_args is None:
        extra_args = []

    if not (args.load_vmfb or args.save_vmfb):
        shark_module.compile(extra_args)
        return shark_module

    work_dir = os.getcwd()
    extended_name = f"{model_name}_{args.device}"
    vmfb_path = os.path.join(work_dir, extended_name + ".vmfb")

    # NOTE(review): the file name does not encode ``extra_args``, so a file
    # saved with different flags would presumably be reused as-is -- confirm
    # that this is acceptable for callers.
    # ``save_vmfb`` takes priority over ``load_vmfb``: an existing file is
    # regenerated rather than reused when both are set.
    if args.load_vmfb and not args.save_vmfb and os.path.isfile(vmfb_path):
        print(f"Loading flatbuffer from {vmfb_path}")
        shark_module.load_module(vmfb_path)
        return shark_module

    if args.save_vmfb:
        print(f"Saving to {vmfb_path}")
    else:
        print(f"No vmfb found. Compiling and saving to {vmfb_path}")

    # Load from the path reported by ``save_module`` rather than assuming it
    # equals ``vmfb_path``.
    saved_path = shark_module.save_module(work_dir, extended_name, extra_args)
    shark_module.load_module(saved_path)
    return shark_module
|
||||
|
||||
|
||||
# Downloads the model from shark_tank and returns the shark_module.
|
||||
@@ -15,12 +41,11 @@ def get_shark_model(args, tank_url, model_name, extra_args=[]):
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile(extra_args)
|
||||
return shark_module
|
||||
return _compile_module(args, shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
# Converts the torch-module into shark_module.
|
||||
def compile_through_fx(args, model, inputs, extra_args=[]):
|
||||
def compile_through_fx(args, model, inputs, model_name, extra_args=[]):
|
||||
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
@@ -74,6 +99,5 @@ def compile_through_fx(args, model, inputs, extra_args=[]):
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
shark_module.compile(extra_args)
|
||||
|
||||
return shark_module
|
||||
return _compile_module(args, shark_module, model_name, extra_args)
|
||||
|
||||
Reference in New Issue
Block a user