mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 13:57:54 -05:00
[WEB] Add SD2.1 web support
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
@@ -99,6 +99,11 @@ with gr.Blocks(css=demo_css) as shark_web:
|
||||
step=0.1,
|
||||
label="Guidance Scale",
|
||||
)
|
||||
version = gr.Radio(
|
||||
label="Version",
|
||||
value="v1.4",
|
||||
choices=["v1.4", "v2.1base"],
|
||||
)
|
||||
with gr.Row():
|
||||
scheduler_key = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
@@ -108,6 +113,7 @@ with gr.Blocks(css=demo_css) as shark_web:
|
||||
"PNDM",
|
||||
"LMSDiscrete",
|
||||
"DPMSolverMultistep",
|
||||
"EulerDiscrete",
|
||||
],
|
||||
)
|
||||
with gr.Group():
|
||||
@@ -151,6 +157,7 @@ with gr.Blocks(css=demo_css) as shark_web:
|
||||
guidance,
|
||||
seed,
|
||||
scheduler_key,
|
||||
version,
|
||||
],
|
||||
outputs=[generated_img, std_output],
|
||||
)
|
||||
@@ -162,6 +169,7 @@ with gr.Blocks(css=demo_css) as shark_web:
|
||||
guidance,
|
||||
seed,
|
||||
scheduler_key,
|
||||
version,
|
||||
],
|
||||
outputs=[generated_img, std_output],
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from diffusers import (
|
||||
PNDMScheduler,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
)
|
||||
from models.stable_diffusion.opt_params import get_unet, get_vae, get_clip
|
||||
from models.stable_diffusion.utils import set_iree_runtime_flags
|
||||
@@ -27,9 +28,38 @@ schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
schedulers2 = dict()
|
||||
schedulers2["PNDM"] = PNDMScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers2["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers2["DDIM"] = DDIMScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers2[
|
||||
"DPMSolverMultistep"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers2["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
# set iree-runtime flags
|
||||
set_iree_runtime_flags(args)
|
||||
args.version = "v1.4"
|
||||
|
||||
cache_obj = dict()
|
||||
|
||||
@@ -44,3 +74,16 @@ cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
|
||||
cache_obj["unet"],
|
||||
cache_obj["clip"],
|
||||
) = (get_vae(args), get_unet(args), get_clip(args))
|
||||
|
||||
args.version = "v2.1base"
|
||||
# cache tokenizer
|
||||
cache_obj["tokenizer2"] = CLIPTokenizer.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
|
||||
)
|
||||
|
||||
# cache vae, unet and clip.
|
||||
(
|
||||
cache_obj["vae2"],
|
||||
cache_obj["unet2"],
|
||||
cache_obj["clip2"],
|
||||
) = (get_vae(args), get_unet(args), get_clip(args))
|
||||
|
||||
@@ -1,19 +1,24 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
from models.stable_diffusion.cache_objects import cache_obj, schedulers
|
||||
from models.stable_diffusion.cache_objects import (
|
||||
cache_obj,
|
||||
schedulers,
|
||||
schedulers2,
|
||||
)
|
||||
from models.stable_diffusion.stable_args import args
|
||||
from random import randint
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
def set_ui_params(prompt, steps, guidance, seed, scheduler_key):
|
||||
def set_ui_params(prompt, steps, guidance, seed, scheduler_key, version):
|
||||
args.prompt = [prompt]
|
||||
args.steps = steps
|
||||
args.guidance = guidance
|
||||
args.seed = seed
|
||||
args.scheduler = scheduler_key
|
||||
args.version = version
|
||||
|
||||
|
||||
def stable_diff_inf(
|
||||
@@ -22,6 +27,7 @@ def stable_diff_inf(
|
||||
guidance: float,
|
||||
seed: int,
|
||||
scheduler_key: str,
|
||||
version: str,
|
||||
):
|
||||
|
||||
# Handle out of range seeds.
|
||||
@@ -30,20 +36,29 @@ def stable_diff_inf(
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
|
||||
set_ui_params(prompt, steps, guidance, seed, scheduler_key)
|
||||
set_ui_params(prompt, steps, guidance, seed, scheduler_key, version)
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
generator = torch.manual_seed(
|
||||
args.seed
|
||||
) # Seed generator to create the inital latent noise
|
||||
guidance_scale = torch.tensor(args.guidance).to(torch.float32)
|
||||
# Initialize vae and unet models.
|
||||
vae, unet, clip, tokenizer = (
|
||||
cache_obj["vae"],
|
||||
cache_obj["unet"],
|
||||
cache_obj["clip"],
|
||||
cache_obj["tokenizer"],
|
||||
)
|
||||
scheduler = schedulers[args.scheduler]
|
||||
if args.version == "v2.1base":
|
||||
vae, unet, clip, tokenizer = (
|
||||
cache_obj["vae2"],
|
||||
cache_obj["unet2"],
|
||||
cache_obj["clip2"],
|
||||
cache_obj["tokenizer2"],
|
||||
)
|
||||
scheduler = schedulers2[args.scheduler]
|
||||
else:
|
||||
vae, unet, clip, tokenizer = (
|
||||
cache_obj["vae"],
|
||||
cache_obj["unet"],
|
||||
cache_obj["clip"],
|
||||
cache_obj["tokenizer"],
|
||||
)
|
||||
scheduler = schedulers[args.scheduler]
|
||||
|
||||
start = time.time()
|
||||
text_input = tokenizer(
|
||||
@@ -125,7 +140,7 @@ def stable_diff_inf(
|
||||
total_time = time.time() - start
|
||||
|
||||
text_output = f"prompt={args.prompt}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance}, scheduler={args.scheduler}, seed={args.seed}, size={args.height}x{args.width}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance}, scheduler={args.scheduler}, seed={args.seed}, size={args.height}x{args.width}, version={args.version}"
|
||||
text_output += "\nAverage step time: {0:.2f}ms/it".format(avg_ms)
|
||||
print(f"\nAverage step time: {avg_ms}ms/it")
|
||||
text_output += "\nTotal image generation time: {0:.2f}sec".format(
|
||||
|
||||
@@ -3,7 +3,6 @@ from transformers import CLIPTextModel
|
||||
from models.stable_diffusion.utils import compile_through_fx
|
||||
import torch
|
||||
|
||||
|
||||
model_config = {
|
||||
"v2": "stabilityai/stable-diffusion-2",
|
||||
"v1.4": "CompVis/stable-diffusion-v1-4",
|
||||
@@ -34,9 +33,14 @@ model_input = {
|
||||
|
||||
|
||||
def get_clip_mlir(args, model_name="clip_text", extra_args=[]):
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
if args.version == "v2":
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_config[args.version], subfolder="text_encoder"
|
||||
)
|
||||
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@@ -58,13 +62,16 @@ def get_clip_mlir(args, model_name="clip_text", extra_args=[]):
|
||||
|
||||
|
||||
def get_vae_mlir(args, model_name="vae", extra_args=[]):
|
||||
# revision param for from_pretrained defaults to "main" => fp32
|
||||
model_revision = "fp16" if args.precision == "fp16" else "main"
|
||||
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_config[args.version],
|
||||
subfolder="vae",
|
||||
revision="fp16",
|
||||
revision=model_revision,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
@@ -72,10 +79,17 @@ def get_vae_mlir(args, model_name="vae", extra_args=[]):
|
||||
return (x / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
vae = VaeModel()
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[inputs.half().cuda() for inputs in model_input[args.version]["vae"]]
|
||||
)
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda()
|
||||
for inputs in model_input[args.version]["vae"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["vae"]
|
||||
|
||||
shark_vae = compile_through_fx(
|
||||
args,
|
||||
vae,
|
||||
@@ -116,13 +130,15 @@ def get_vae_encode_mlir(args, model_name="vae_encode", extra_args=[]):
|
||||
|
||||
|
||||
def get_unet_mlir(args, model_name="unet", extra_args=[]):
|
||||
model_revision = "fp16" if args.precision == "fp16" else "main"
|
||||
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_config[args.version],
|
||||
subfolder="unet",
|
||||
revision="fp16",
|
||||
revision=model_revision,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
@@ -140,13 +156,16 @@ def get_unet_mlir(args, model_name="unet", extra_args=[]):
|
||||
return noise_pred
|
||||
|
||||
unet = UnetModel()
|
||||
unet = unet.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
|
||||
for inputs in model_input[args.version]["unet"]
|
||||
]
|
||||
)
|
||||
if args.precision == "fp16":
|
||||
unet = unet.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
|
||||
for inputs in model_input[args.version]["unet"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["unet"]
|
||||
shark_unet = compile_through_fx(
|
||||
args,
|
||||
unet,
|
||||
|
||||
@@ -23,6 +23,8 @@ def get_unet(args):
|
||||
else:
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "unet_1dec_fp16"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "unet2base_8dec_fp16"
|
||||
iree_flags += [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
@@ -55,6 +57,8 @@ def get_vae(args):
|
||||
if args.precision == "fp16":
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "vae_1dec_fp16"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "vae2base_8dec_fp16"
|
||||
iree_flags += [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
@@ -116,6 +120,8 @@ def get_clip(args):
|
||||
)
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "clip_1dec_fp32"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "clip2base_8dec_fp32"
|
||||
iree_flags += [
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
|
||||
@@ -12,6 +12,7 @@ def set_iree_runtime_flags(args):
|
||||
]
|
||||
if "vulkan" in args.device:
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
return
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user