[WEB] Add Stable-Diffusion in the SHARK web (#366)

1. This commit adds stable-diffusion as a part of shark web.
2. The V-diffusion model has been disabled for now as it is not
   working (a separate patch with the fix will follow).
3. Add standard output in the web ui.
4. Add instructions to launch the shark-web.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2022-09-26 23:12:02 +05:30
committed by GitHub
parent c7b2d39ab2
commit d4eeff0a5d
7 changed files with 400 additions and 73 deletions

11
web/README.md Normal file
View File

@@ -0,0 +1,11 @@
In order to launch SHARK-web, from the root SHARK directory, run:
```shell
IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pip install diffusers scipy
cd web
wget -O models_mlir/stable_diffusion.mlir https://storage.googleapis.com/shark_tank/prashant_nod/stable_diff/stable_diff_torch.mlir
python index.py
```
This will launch a gradio server with a public URL.

View File

@@ -1,47 +1,102 @@
from models.resnet50 import resnet_inf
from models.albert_maskfill import albert_maskfill_inf
from models.diffusion.v_diffusion import vdiff_inf
from models.stable_diffusion import stable_diff_inf
# from models.diffusion.v_diffusion import vdiff_inf
import gradio as gr
shark_web = gr.Blocks()
with shark_web:
with gr.Blocks() as shark_web:
gr.Markdown("Shark Models Demo.")
with gr.Tabs():
with gr.TabItem("ResNet50"):
with gr.Group():
image = gr.Image(label="Image")
label = gr.Label(label="Output")
resnet = gr.Button("Recognize Image")
resnet.click(resnet_inf, inputs=image, outputs=label)
image = device = resnet = output = None
with gr.Row():
with gr.Column(scale=1, min_width=600):
image = gr.Image(label="Image")
device = gr.Textbox(label="Device", value="cpu")
resnet = gr.Button("Recognize Image").style(
full_width=True
)
with gr.Column(scale=1, min_width=600):
output = gr.Label(label="Output")
std_output = gr.Textbox(
label="Std Output", value="Nothing."
)
resnet.click(
resnet_inf,
inputs=[image, device],
outputs=[output, std_output],
)
with gr.TabItem("Albert MaskFill"):
with gr.Group():
masked_text = gr.Textbox(
label="Masked Text",
placeholder="Give me a sentence with [MASK] to fill",
)
decoded_res = gr.Label(label="Decoded Results")
albert_mask = gr.Button("Decode Mask")
albert_mask.click(
albert_maskfill_inf,
inputs=masked_text,
outputs=decoded_res,
)
with gr.TabItem("V-Diffusion"):
with gr.Group():
prompt = gr.Textbox(
label="Prompt", value="New York City, oil on canvas:5"
)
sample_count = gr.Number(label="Sample Count", value=1)
batch_size = gr.Number(label="Batch Size", value=1)
iters = gr.Number(label="Steps", value=2)
device = gr.Textbox(label="Device", value="gpu")
v_diffusion = gr.Button("Generate image from prompt")
generated_img = gr.Image(type="pil", shape=(100, 100))
v_diffusion.click(
vdiff_inf,
inputs=[prompt, sample_count, batch_size, iters, device],
outputs=generated_img,
)
masked_text = device = albert_mask = decoded_res = None
with gr.Row():
with gr.Column(scale=1, min_width=600):
masked_text = gr.Textbox(
label="Masked Text",
placeholder="Give me a sentence with [MASK] to fill",
)
device = gr.Textbox(label="Device", value="cpu")
albert_mask = gr.Button("Decode Mask")
with gr.Column(scale=1, min_width=600):
decoded_res = gr.Label(label="Decoded Results")
std_output = gr.Textbox(
label="Std Output", value="Nothing."
)
albert_mask.click(
albert_maskfill_inf,
inputs=[masked_text, device],
outputs=[decoded_res, std_output],
)
# with gr.TabItem("V-Diffusion"):
# prompt = sample_count = batch_size = iters = device = v_diffusion = generated_img = None
# with gr.Row():
# with gr.Column(scale=1, min_width=600):
# prompt = gr.Textbox(
# label="Prompt", value="New York City, oil on canvas:5"
# )
# sample_count = gr.Number(label="Sample Count", value=1)
# batch_size = gr.Number(label="Batch Size", value=1)
# iters = gr.Number(label="Steps", value=2)
# device = gr.Textbox(label="Device", value="gpu")
# v_diffusion = gr.Button("Generate image from prompt")
# with gr.Column(scale=1, min_width=600):
# generated_img = gr.Image(type="pil", shape=(100, 100))
# std_output = gr.Textbox(label="Std Output", value="Nothing.")
# v_diffusion.click(
# vdiff_inf,
# inputs=[prompt, sample_count, batch_size, iters, device],
# outputs=[generated_img, std_output]
# )
with gr.TabItem("Stable-Diffusion"):
prompt = (
iters
) = mlir_loc = device = stable_diffusion = generated_img = None
with gr.Row():
with gr.Column(scale=1, min_width=600):
prompt = gr.Textbox(
label="Prompt",
value="a photograph of an astronaut riding a horse",
)
iters = gr.Number(label="Steps", value=2)
mlir_loc = gr.Textbox(
label="Location of MLIR(Relative to SHARK/web/)",
value="./models_mlir/stable_diffusion.mlir",
)
device = gr.Textbox(label="Device", value="vulkan")
stable_diffusion = gr.Button("Generate image from prompt")
with gr.Column(scale=1, min_width=600):
generated_img = gr.Image(type="pil", shape=(100, 100))
std_output = gr.Textbox(
label="Std Output", value="Nothing."
)
stable_diffusion.click(
stable_diff_inf,
inputs=[prompt, iters, mlir_loc, device],
outputs=[generated_img, std_output],
)
shark_web.launch(share=True, server_port=8080, enable_queue=True)

View File

@@ -4,9 +4,7 @@ from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
import numpy as np
MAX_SEQUENCE_LENGTH = 512
BATCH_SIZE = 1
COMPILE_MODULE = None
################################## Albert Module #########################
class AlbertModule(torch.nn.Module):
@@ -21,18 +19,23 @@ class AlbertModule(torch.nn.Module):
).logits
################################## Preprocessing inputs and model ############
################################## Preprocessing inputs ####################
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
compiled_module = {}
compiled_module["tokenizer"] = AutoTokenizer.from_pretrained("albert-base-v2")
def preprocess_data(text):
global compiled_module
# Preparing Data
tokenizer = compiled_module["tokenizer"]
encoded_inputs = tokenizer(
text,
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
max_length=512,
return_tensors="pt",
)
inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"])
@@ -40,6 +43,10 @@ def preprocess_data(text):
def top5_possibilities(text, inputs, token_logits):
global compiled_module
tokenizer = compiled_module["tokenizer"]
mask_id = torch.where(inputs[0] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_id, :]
percentage = torch.nn.functional.softmax(mask_token_logits, dim=1)[0]
@@ -54,11 +61,12 @@ def top5_possibilities(text, inputs, token_logits):
##############################################################################
def albert_maskfill_inf(masked_text):
global COMPILE_MODULE
def albert_maskfill_inf(masked_text, device):
global compiled_module
inputs = preprocess_data(masked_text)
if COMPILE_MODULE == None:
print("module compiled")
if device not in compiled_module.keys():
mlir_importer = SharkImporter(
AlbertModule(),
inputs,
@@ -68,10 +76,10 @@ def albert_maskfill_inf(masked_text):
is_dynamic=False, tracing_required=True
)
shark_module = SharkInference(
minilm_mlir, func_name, mlir_dialect="linalg", device="intel-gpu"
minilm_mlir, func_name, mlir_dialect="linalg", device=device
)
shark_module.compile()
COMPILE_MODULE = shark_module
compiled_module[device] = shark_module
token_logits = torch.tensor(COMPILE_MODULE.forward(inputs))
return top5_possibilities(masked_text, inputs, token_logits)
token_logits = torch.tensor(compiled_module[device].forward(inputs))
return top5_possibilities(masked_text, inputs, token_logits), "Testing.."

View File

@@ -103,10 +103,12 @@ def cache_model():
def vdiff_inf(prompts: str, n, bs, steps, _device):
global device
global model
global checkpoint
global clip_model
args = {}
target_embeds = []
weights = []
@@ -197,14 +199,17 @@ def vdiff_inf(prompts: str, n, bs, steps, _device):
mlir_model, func_name, device=args["device"], mlir_dialect="linalg"
)
shark_module.compile()
return run_all(
x,
t,
args["steps"],
args["n"],
args["batch_size"],
side_x,
side_y,
shark_module,
args,
return (
run_all(
x,
t,
args["steps"],
args["n"],
args["batch_size"],
side_x,
side_y,
shark_module,
args,
),
"Testing..",
)

View File

@@ -5,9 +5,7 @@ from torchvision import transforms
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model
################################## Preprocessing inputs and model ############
COMPILE_MODULE = None
################################## Preprocessing inputs and helper functions ########
def preprocess_image(img):
@@ -47,23 +45,31 @@ def top3_possibilities(res):
##############################################################################
compiled_module = {}
def resnet_inf(numpy_img):
img = preprocess_image(numpy_img)
## Can pass any img or input to the forward module.
global COMPILE_MODULE
if COMPILE_MODULE == None:
def resnet_inf(numpy_img, device):
global compiled_module
std_output = ""
if device not in compiled_module.keys():
std_output += "Compiling the Resnet50 module.\n"
mlir_model, func_name, inputs, golden_out = download_torch_model(
"resnet50"
)
shark_module = SharkInference(
mlir_model, func_name, device="intel-gpu", mlir_dialect="linalg"
mlir_model, func_name, device=device, mlir_dialect="linalg"
)
shark_module.compile()
COMPILE_MODULE = shark_module
std_output += "Compilation successful.\n"
compiled_module[device] = shark_module
result = COMPILE_MODULE.forward((img.detach().numpy(),))
img = preprocess_image(numpy_img)
result = compiled_module[device].forward((img.detach().numpy(),))
# print("The top 3 results obtained via shark_runner is:")
return top3_possibilities(torch.from_numpy(result))
std_output += "Retrieving top 3 possible outcomes.\n"
return top3_possibilities(torch.from_numpy(result)), std_output

View File

@@ -0,0 +1,242 @@
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
import torch
from PIL import Image
from diffusers import LMSDiscreteScheduler
from tqdm.auto import tqdm
from shark.shark_inference import SharkInference
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import torch_mlir
import tempfile
import numpy as np
##############################################################################
def load_mlir(mlir_loc):
    """Read a serialized MLIR module from disk.

    Args:
        mlir_loc: Path to the ``.mlir`` file, or ``None`` to skip loading
            (callers pass ``None`` when the model should be traced and
            compiled from scratch instead).

    Returns:
        The file contents as a string, or ``None`` when no path was given.

    Raises:
        OSError: If ``mlir_loc`` is given but the file cannot be opened.
    """
    if mlir_loc is None:
        return None
    print(f"Trying to load the model from {mlir_loc}.")
    # The file holds a textual MLIR module; return it verbatim.
    with open(mlir_loc) as f:
        return f.read()
def compile_through_fx(model, inputs, device, mlir_loc=None):
    """Compile a torch model into a ready-to-run ``SharkInference`` module.

    If ``mlir_loc`` is given, the pre-serialized MLIR module is loaded from
    disk and used directly.  Otherwise the model is traced with
    ``torch.fx`` (``make_fx``), scripted, and lowered to the
    linalg-on-tensors dialect via ``torch_mlir.compile``.

    Args:
        model: The ``torch.nn.Module`` (or callable) to compile.
        inputs: Tuple of example tensors used for tracing/compilation.
        device: Backend device string forwarded to ``SharkInference``
            (e.g. "cpu", "vulkan").
        mlir_loc: Optional path to an already-exported ``.mlir`` file.

    Returns:
        A compiled ``SharkInference`` module exposing ``forward``.
    """
    # Either a string containing the MLIR module, or None when tracing.
    module = load_mlir(mlir_loc)
    if mlir_loc == None:
        # Trace the model to an FX graph, decomposing ops that the
        # torch-mlir lowering does not support natively.
        fx_g = make_fx(
            model,
            decomposition_table=get_decompositions(
                [
                    torch.ops.aten.embedding_dense_backward,
                    torch.ops.aten.native_layer_norm_backward,
                    torch.ops.aten.slice_backward,
                    torch.ops.aten.select_backward,
                    torch.ops.aten.norm.ScalarOpt_dim,
                    torch.ops.aten.native_group_norm,
                    torch.ops.aten.upsample_bilinear2d.vec,
                    torch.ops.aten.split.Tensor,
                    torch.ops.aten.split_with_sizes,
                ]
            ),
        )(*inputs)
        # Reset codegen so the recompiled graph uses plain (non-pytree)
        # calling conventions that TorchScript can handle.
        fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
        fx_g.recompile()

        def strip_overloads(gm):
            """
            Modifies the target of graph nodes in :attr:`gm` to strip overloads.
            Args:
                gm(fx.GraphModule): The input Fx graph module to be modified
            """
            # torch.jit.script cannot handle OpOverload targets, so
            # replace each with its parent overload packet.
            for node in gm.graph.nodes:
                if isinstance(node.target, torch._ops.OpOverload):
                    node.target = node.target.overloadpacket
            gm.recompile()

        strip_overloads(fx_g)
        ts_g = torch.jit.script(fx_g)
        # Lower the TorchScript graph to linalg-on-tensors MLIR.
        module = torch_mlir.compile(
            ts_g,
            inputs,
            torch_mlir.OutputType.LINALG_ON_TENSORS,
            use_tracing=False,
            verbose=False,
        )
    mlir_model = module
    func_name = "forward"
    shark_module = SharkInference(
        mlir_model, func_name, device=device, mlir_dialect="tm_tensor"
    )
    shark_module.compile()
    return shark_module
##############################################################################
compiled_module = {}
def stable_diff_inf(prompt: str, steps, mlir_loc: str, device: str):
    """Run Stable Diffusion text-to-image inference through SHARK.

    On first call for a given ``device``, the VAE, CLIP tokenizer/encoder,
    scheduler, and a SHARK-compiled UNet are loaded and cached in the
    module-level ``compiled_module`` dict; subsequent calls reuse them.

    Args:
        prompt: Text prompt describing the image to generate.
        steps: Number of denoising steps (coerced with ``int()`` below).
        mlir_loc: Path to a pre-exported UNet MLIR file, forwarded to
            ``compile_through_fx``.
        device: SHARK backend device string (e.g. "vulkan", "cpu").

    Returns:
        Tuple of (PIL.Image of the first generated image, status string).
    """
    args = {}
    args["prompt"] = [prompt]
    args["steps"] = steps
    args["device"] = device
    args["mlir_loc"] = mlir_loc
    global compiled_module
    # One-time (per device) model loading and compilation.
    if args["device"] not in compiled_module.keys():
        # NOTE(review): hardcoded Hugging Face auth token committed to
        # source — this is a leaked credential. It should be revoked and
        # read from an environment variable instead.
        YOUR_TOKEN = "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"
        # 1. Load the autoencoder model which will be used to decode the latents into image space.
        compiled_module["vae"] = AutoencoderKL.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            subfolder="vae",
            use_auth_token=YOUR_TOKEN,
        )
        # 2. Load the tokenizer and text encoder to tokenize and encode the text.
        compiled_module["tokenizer"] = CLIPTokenizer.from_pretrained(
            "openai/clip-vit-large-patch14"
        )
        compiled_module["text_encoder"] = CLIPTextModel.from_pretrained(
            "openai/clip-vit-large-patch14"
        )

        # Wrap the unet model to return tuples.
        class UnetModel(torch.nn.Module):
            # torch-mlir tracing needs a plain-tensor return, so unwrap
            # the diffusers dict/tuple output.
            def __init__(self):
                super().__init__()
                self.unet = UNet2DConditionModel.from_pretrained(
                    "CompVis/stable-diffusion-v1-4",
                    subfolder="unet",
                    use_auth_token=YOUR_TOKEN,
                )
                self.in_channels = self.unet.in_channels
                # Inference only — disable training mode.
                self.train(False)

            def forward(self, x, y, z):
                return self.unet.forward(x, y, z, return_dict=False)[0]

        # 3. The UNet model for generating the latents.
        unet = UnetModel()
        # Example inputs for tracing: [2, 4, 64, 64] latents (batch of 2
        # for classifier-free guidance) and [2, 77, 768] CLIP embeddings.
        # Fixed shapes — presumably implies 512x512 output only; confirm.
        latent_model_input = torch.rand([2, 4, 64, 64])
        text_embeddings = torch.rand([2, 77, 768])
        shark_unet = compile_through_fx(
            unet,
            (latent_model_input, torch.tensor([1.0]), text_embeddings),
            args["device"],
            args["mlir_loc"],
        )
        compiled_module[args["device"]] = shark_unet
        compiled_module["scheduler"] = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
        )
        compiled_module["unet"] = unet
    # Pull the cached components for this run.
    shark_unet = compiled_module[args["device"]]
    vae = compiled_module["vae"]
    unet = compiled_module["unet"]
    tokenizer = compiled_module["tokenizer"]
    text_encoder = compiled_module["text_encoder"]
    scheduler = compiled_module["scheduler"]
    height = 512  # default height of Stable Diffusion
    width = 512  # default width of Stable Diffusion
    num_inference_steps = int(args["steps"])  # Number of denoising steps
    guidance_scale = 7.5  # Scale for classifier-free guidance
    generator = torch.manual_seed(
        42
    )  # Seed generator to create the inital latent noise
    batch_size = len(args["prompt"])
    # Encode the prompt with CLIP.
    text_input = tokenizer(
        args["prompt"],
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids)[0]
    max_length = text_input.input_ids.shape[-1]
    # Unconditional ("") embeddings for classifier-free guidance.
    uncond_input = tokenizer(
        [""] * batch_size,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids)[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
    # Initial latent noise (seeded above for reproducibility).
    latents = torch.randn(
        (batch_size, unet.in_channels, height // 8, width // 8),
        generator=generator,
    )
    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.sigmas[0]
    # Denoising loop.
    for i, t in tqdm(enumerate(scheduler.timesteps)):
        print(f"i = {i} t = {t}")
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
        # predict the noise residual
        # SHARK forward takes/returns numpy arrays, so convert both ways.
        latent_model_input_numpy = latent_model_input.detach().numpy()
        text_embeddings_numpy = text_embeddings.detach().numpy()
        noise_pred = shark_unet.forward(
            (
                latent_model_input_numpy,
                np.array([t]).astype(np.float32),
                text_embeddings_numpy,
            )
        )
        noise_pred = torch.from_numpy(noise_pred)
        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )
        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
    # scale and decode the image latents with vae
    # 0.18215 is the VAE scaling factor used by Stable Diffusion v1.
    latents = 1 / 0.18215 * latents
    print(latents.shape)
    image = vae.decode(latents).sample
    # Map from [-1, 1] to [0, 255] uint8 HWC for PIL.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images[0], "Testing.."

View File