[WEB] Add SD2.1 web support

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2022-12-13 11:21:38 +05:30
parent 48ec11c514
commit 5f5e0766dd
6 changed files with 117 additions and 25 deletions

View File

@@ -99,6 +99,11 @@ with gr.Blocks(css=demo_css) as shark_web:
step=0.1,
label="Guidance Scale",
)
version = gr.Radio(
label="Version",
value="v1.4",
choices=["v1.4", "v2.1base"],
)
with gr.Row():
scheduler_key = gr.Dropdown(
label="Scheduler",
@@ -108,6 +113,7 @@ with gr.Blocks(css=demo_css) as shark_web:
"PNDM",
"LMSDiscrete",
"DPMSolverMultistep",
"EulerDiscrete",
],
)
with gr.Group():
@@ -151,6 +157,7 @@ with gr.Blocks(css=demo_css) as shark_web:
guidance,
seed,
scheduler_key,
version,
],
outputs=[generated_img, std_output],
)
@@ -162,6 +169,7 @@ with gr.Blocks(css=demo_css) as shark_web:
guidance,
seed,
scheduler_key,
version,
],
outputs=[generated_img, std_output],
)

View File

@@ -4,6 +4,7 @@ from diffusers import (
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
)
from models.stable_diffusion.opt_params import get_unet, get_vae, get_clip
from models.stable_diffusion.utils import set_iree_runtime_flags
@@ -27,9 +28,38 @@ schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="scheduler",
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="scheduler",
)
schedulers2 = dict()
schedulers2["PNDM"] = PNDMScheduler.from_pretrained(
"stabilityai/stable-diffusion-2-1-base",
subfolder="scheduler",
)
schedulers2["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
"stabilityai/stable-diffusion-2-1-base",
subfolder="scheduler",
)
schedulers2["DDIM"] = DDIMScheduler.from_pretrained(
"stabilityai/stable-diffusion-2-1-base",
subfolder="scheduler",
)
schedulers2[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
"stabilityai/stable-diffusion-2-1-base",
subfolder="scheduler",
)
schedulers2["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
"stabilityai/stable-diffusion-2-1-base",
subfolder="scheduler",
)
# set iree-runtime flags
set_iree_runtime_flags(args)
args.version = "v1.4"
cache_obj = dict()
@@ -44,3 +74,16 @@ cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
cache_obj["unet"],
cache_obj["clip"],
) = (get_vae(args), get_unet(args), get_clip(args))
args.version = "v2.1base"
# cache tokenizer
cache_obj["tokenizer2"] = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
)
# cache vae, unet and clip.
(
cache_obj["vae2"],
cache_obj["unet2"],
cache_obj["clip2"],
) = (get_vae(args), get_unet(args), get_clip(args))

View File

@@ -1,19 +1,24 @@
import torch
from PIL import Image
from tqdm.auto import tqdm
from models.stable_diffusion.cache_objects import cache_obj, schedulers
from models.stable_diffusion.cache_objects import (
cache_obj,
schedulers,
schedulers2,
)
from models.stable_diffusion.stable_args import args
from random import randint
import numpy as np
import time
def set_ui_params(prompt, steps, guidance, seed, scheduler_key):
def set_ui_params(prompt, steps, guidance, seed, scheduler_key, version):
    """Copy the values collected from the web UI onto the shared ``args`` namespace.

    ``args`` is the module-level argument object imported from
    ``models.stable_diffusion.stable_args``; downstream model/cache lookups read
    their configuration from it rather than taking parameters directly.

    Args:
        prompt: Text prompt from the UI; stored as a single-element list,
            since the pipeline expects a batch of prompts.
        steps: Number of denoising steps.
        guidance: Classifier-free guidance scale.
        seed: RNG seed (callers are expected to have normalized it into range).
        scheduler_key: Name of the scheduler to use (e.g. "PNDM", "EulerDiscrete").
        version: Model version selector — the UI offers "v1.4" and "v2.1base";
            "v2.1base" switches the pipeline to the SD 2.1-base weights/schedulers.
    """
    args.prompt = [prompt]
    args.steps = steps
    args.guidance = guidance
    args.seed = seed
    args.scheduler = scheduler_key
    args.version = version
def stable_diff_inf(
@@ -22,6 +27,7 @@ def stable_diff_inf(
guidance: float,
seed: int,
scheduler_key: str,
version: str,
):
# Handle out of range seeds.
@@ -30,20 +36,29 @@ def stable_diff_inf(
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
set_ui_params(prompt, steps, guidance, seed, scheduler_key)
set_ui_params(prompt, steps, guidance, seed, scheduler_key, version)
dtype = torch.float32 if args.precision == "fp32" else torch.half
generator = torch.manual_seed(
args.seed
) # Seed generator to create the initial latent noise
guidance_scale = torch.tensor(args.guidance).to(torch.float32)
# Initialize vae and unet models.
vae, unet, clip, tokenizer = (
cache_obj["vae"],
cache_obj["unet"],
cache_obj["clip"],
cache_obj["tokenizer"],
)
scheduler = schedulers[args.scheduler]
if args.version == "v2.1base":
vae, unet, clip, tokenizer = (
cache_obj["vae2"],
cache_obj["unet2"],
cache_obj["clip2"],
cache_obj["tokenizer2"],
)
scheduler = schedulers2[args.scheduler]
else:
vae, unet, clip, tokenizer = (
cache_obj["vae"],
cache_obj["unet"],
cache_obj["clip"],
cache_obj["tokenizer"],
)
scheduler = schedulers[args.scheduler]
start = time.time()
text_input = tokenizer(
@@ -125,7 +140,7 @@ def stable_diff_inf(
total_time = time.time() - start
text_output = f"prompt={args.prompt}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance}, scheduler={args.scheduler}, seed={args.seed}, size={args.height}x{args.width}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance}, scheduler={args.scheduler}, seed={args.seed}, size={args.height}x{args.width}, version={args.version}"
text_output += "\nAverage step time: {0:.2f}ms/it".format(avg_ms)
print(f"\nAverage step time: {avg_ms}ms/it")
text_output += "\nTotal image generation time: {0:.2f}sec".format(

View File

@@ -3,7 +3,6 @@ from transformers import CLIPTextModel
from models.stable_diffusion.utils import compile_through_fx
import torch
model_config = {
"v2": "stabilityai/stable-diffusion-2",
"v1.4": "CompVis/stable-diffusion-v1-4",
@@ -34,9 +33,14 @@ model_input = {
def get_clip_mlir(args, model_name="clip_text", extra_args=[]):
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14"
)
if args.version == "v2":
text_encoder = CLIPTextModel.from_pretrained(
model_config[args.version], subfolder="text_encoder"
)
class CLIPText(torch.nn.Module):
def __init__(self):
@@ -58,13 +62,16 @@ def get_clip_mlir(args, model_name="clip_text", extra_args=[]):
def get_vae_mlir(args, model_name="vae", extra_args=[]):
# revision param for from_pretrained defaults to "main" => fp32
model_revision = "fp16" if args.precision == "fp16" else "main"
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version],
subfolder="vae",
revision="fp16",
revision=model_revision,
)
def forward(self, input):
@@ -72,10 +79,17 @@ def get_vae_mlir(args, model_name="vae", extra_args=[]):
return (x / 2 + 0.5).clamp(0, 1)
vae = VaeModel()
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input[args.version]["vae"]]
)
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
else:
inputs = model_input[args.version]["vae"]
shark_vae = compile_through_fx(
args,
vae,
@@ -116,13 +130,15 @@ def get_vae_encode_mlir(args, model_name="vae_encode", extra_args=[]):
def get_unet_mlir(args, model_name="unet", extra_args=[]):
model_revision = "fp16" if args.precision == "fp16" else "main"
class UnetModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_config[args.version],
subfolder="unet",
revision="fp16",
revision=model_revision,
)
self.in_channels = self.unet.in_channels
self.train(False)
@@ -140,13 +156,16 @@ def get_unet_mlir(args, model_name="unet", extra_args=[]):
return noise_pred
unet = UnetModel()
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input[args.version]["unet"]
]
)
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input[args.version]["unet"]
]
)
else:
inputs = model_input[args.version]["unet"]
shark_unet = compile_through_fx(
args,
unet,

View File

@@ -23,6 +23,8 @@ def get_unet(args):
else:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "unet_1dec_fp16"
if args.version == "v2.1base":
model_name = "unet2base_8dec_fp16"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
@@ -55,6 +57,8 @@ def get_vae(args):
if args.precision == "fp16":
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_1dec_fp16"
if args.version == "v2.1base":
model_name = "vae2base_8dec_fp16"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
@@ -116,6 +120,8 @@ def get_clip(args):
)
bucket = "gs://shark_tank/stable_diffusion"
model_name = "clip_1dec_fp32"
if args.version == "v2.1base":
model_name = "clip2base_8dec_fp32"
iree_flags += [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops",

View File

@@ -12,6 +12,7 @@ def set_iree_runtime_flags(args):
]
if "vulkan" in args.device:
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
return