mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[WEB] Launch only one SD version at a time
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
@@ -99,11 +99,6 @@ with gr.Blocks(css=demo_css) as shark_web:
|
||||
step=0.1,
|
||||
label="Guidance Scale",
|
||||
)
|
||||
version = gr.Radio(
|
||||
label="Version",
|
||||
value="v2.1base",
|
||||
choices=["v1.4", "v2.1base"],
|
||||
)
|
||||
with gr.Row():
|
||||
scheduler_key = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
@@ -157,7 +152,6 @@ with gr.Blocks(css=demo_css) as shark_web:
|
||||
guidance,
|
||||
seed,
|
||||
scheduler_key,
|
||||
version,
|
||||
],
|
||||
outputs=[generated_img, std_output],
|
||||
)
|
||||
@@ -169,7 +163,6 @@ with gr.Blocks(css=demo_css) as shark_web:
|
||||
guidance,
|
||||
seed,
|
||||
scheduler_key,
|
||||
version,
|
||||
],
|
||||
outputs=[generated_img, std_output],
|
||||
)
|
||||
|
||||
@@ -11,63 +11,39 @@ from models.stable_diffusion.utils import set_iree_runtime_flags
|
||||
from models.stable_diffusion.stable_args import args
|
||||
|
||||
|
||||
model_config = {
|
||||
"v2": "stabilityai/stable-diffusion-2",
|
||||
"v2.1base": "stabilityai/stable-diffusion-2-1-base",
|
||||
"v1.4": "CompVis/stable-diffusion-v1-4",
|
||||
}
|
||||
|
||||
schedulers = dict()
|
||||
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
model_config[args.version],
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
model_config[args.version],
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
model_config[args.version],
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
model_config[args.version],
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
model_config[args.version],
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
schedulers2 = dict()
|
||||
schedulers2["PNDM"] = PNDMScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers2["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers2["DDIM"] = DDIMScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers2[
|
||||
"DPMSolverMultistep"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers2["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
# set iree-runtime flags
|
||||
set_iree_runtime_flags(args)
|
||||
args.version = "v1.4"
|
||||
|
||||
cache_obj = dict()
|
||||
|
||||
# cache tokenizer
|
||||
cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
|
||||
# cache vae, unet and clip.
|
||||
(
|
||||
cache_obj["vae"],
|
||||
@@ -75,15 +51,11 @@ cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
|
||||
cache_obj["clip"],
|
||||
) = (get_vae(args), get_unet(args), get_clip(args))
|
||||
|
||||
args.version = "v2.1base"
|
||||
# cache tokenizer
|
||||
cache_obj["tokenizer2"] = CLIPTokenizer.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
|
||||
cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
|
||||
# cache vae, unet and clip.
|
||||
(
|
||||
cache_obj["vae2"],
|
||||
cache_obj["unet2"],
|
||||
cache_obj["clip2"],
|
||||
) = (get_vae(args), get_unet(args), get_clip(args))
|
||||
if args.version == "v2.1base":
|
||||
cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ from tqdm.auto import tqdm
|
||||
from models.stable_diffusion.cache_objects import (
|
||||
cache_obj,
|
||||
schedulers,
|
||||
schedulers2,
|
||||
)
|
||||
from models.stable_diffusion.stable_args import args
|
||||
from random import randint
|
||||
@@ -12,13 +11,12 @@ import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
def set_ui_params(prompt, steps, guidance, seed, scheduler_key, version):
|
||||
def set_ui_params(prompt, steps, guidance, seed, scheduler_key):
|
||||
args.prompt = [prompt]
|
||||
args.steps = steps
|
||||
args.guidance = guidance
|
||||
args.seed = seed
|
||||
args.scheduler = scheduler_key
|
||||
args.version = version
|
||||
|
||||
|
||||
def stable_diff_inf(
|
||||
@@ -27,7 +25,6 @@ def stable_diff_inf(
|
||||
guidance: float,
|
||||
seed: int,
|
||||
scheduler_key: str,
|
||||
version: str,
|
||||
):
|
||||
|
||||
# Handle out of range seeds.
|
||||
@@ -36,29 +33,20 @@ def stable_diff_inf(
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
|
||||
set_ui_params(prompt, steps, guidance, seed, scheduler_key, version)
|
||||
set_ui_params(prompt, steps, guidance, seed, scheduler_key)
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
generator = torch.manual_seed(
|
||||
args.seed
|
||||
) # Seed generator to create the inital latent noise
|
||||
guidance_scale = torch.tensor(args.guidance).to(torch.float32)
|
||||
# Initialize vae and unet models.
|
||||
if args.version == "v2.1base":
|
||||
vae, unet, clip, tokenizer = (
|
||||
cache_obj["vae2"],
|
||||
cache_obj["unet2"],
|
||||
cache_obj["clip2"],
|
||||
cache_obj["tokenizer2"],
|
||||
)
|
||||
scheduler = schedulers2[args.scheduler]
|
||||
else:
|
||||
vae, unet, clip, tokenizer = (
|
||||
cache_obj["vae"],
|
||||
cache_obj["unet"],
|
||||
cache_obj["clip"],
|
||||
cache_obj["tokenizer"],
|
||||
)
|
||||
scheduler = schedulers[args.scheduler]
|
||||
vae, unet, clip, tokenizer = (
|
||||
cache_obj["vae"],
|
||||
cache_obj["unet"],
|
||||
cache_obj["clip"],
|
||||
cache_obj["tokenizer"],
|
||||
)
|
||||
scheduler = schedulers[args.scheduler]
|
||||
|
||||
start = time.time()
|
||||
text_input = tokenizer(
|
||||
|
||||
Reference in New Issue
Block a user