[SHARK][SD] Add support for negative prompts

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2022-12-14 04:08:39 +05:30
parent 0eee7616b9
commit 986c126a5c
6 changed files with 52 additions and 13 deletions

View File

@@ -14,6 +14,7 @@ from stable_args import args
from utils import get_shark_model, set_iree_runtime_flags
from opt_params import get_unet, get_vae, get_clip
import time
import sys
from model_wrappers import get_vae_mlir
from shark.iree_utils.compile_utils import dump_isas
@@ -39,6 +40,7 @@ if __name__ == "__main__":
dtype = torch.float32 if args.precision == "fp32" else torch.half
prompt = args.prompts
neg_prompt = args.negative_prompts
height = 512 # default height of Stable Diffusion
width = 512 # default width of Stable Diffusion
if args.version == "v2":
@@ -54,7 +56,12 @@ if __name__ == "__main__":
args.seed
) # Seed generator to create the initial latent noise
# TODO: Add support for batch_size > 1.
batch_size = len(prompt)
if batch_size != 1:
sys.exit("More than one prompt is not supported yet.")
if batch_size != len(neg_prompt):
sys.exit("prompts and negative prompts must be of same length")
set_iree_runtime_flags()
unet = get_unet()
@@ -103,9 +110,10 @@ if __name__ == "__main__":
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
[""] * batch_size,
neg_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
uncond_clip_inf_start = time.time()

View File

@@ -1,11 +1,9 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from utils import compile_through_fx
from stable_args import args
import torch
BATCH_SIZE = len(args.prompts)
model_config = {
"v2": "stabilityai/stable-diffusion-2",
"v2.1base": "stabilityai/stable-diffusion-2-1-base",

View File

@@ -7,12 +7,21 @@ p = argparse.ArgumentParser(
p.add_argument(
"--prompts",
nargs="+",
default=["a photograph of an astronaut riding a horse"],
default=["cyberpunk forest by Salvador Dali"],
help="text of which images to be generated.",
)
p.add_argument(
"--negative-prompts",
nargs="+",
default=["trees, green"],
help="text you don't want to see in the generated image.",
)
p.add_argument(
"--device", type=str, default="cpu", help="device to run the model."
)
p.add_argument(
"--steps",
type=int,
@@ -33,6 +42,7 @@ p.add_argument(
default=42,
help="the seed to use.",
)
p.add_argument(
"--guidance_scale",
type=float,

View File

@@ -78,7 +78,13 @@ with gr.Blocks(css=demo_css) as shark_web:
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value="A photograph of an astronaut riding a horse",
value="cyberpunk forest by Salvador Dali",
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value="trees, green",
lines=1,
elem_id="prompt_box",
)
@@ -148,6 +154,7 @@ with gr.Blocks(css=demo_css) as shark_web:
stable_diff_inf,
inputs=[
prompt,
negative_prompt,
steps,
guidance,
seed,
@@ -159,6 +166,7 @@ with gr.Blocks(css=demo_css) as shark_web:
stable_diff_inf,
inputs=[
prompt,
negative_prompt,
steps,
guidance,
seed,

View File

@@ -11,8 +11,11 @@ import numpy as np
import time
def set_ui_params(prompt, steps, guidance, seed, scheduler_key):
args.prompt = [prompt]
def set_ui_params(
prompt, negative_prompt, steps, guidance, seed, scheduler_key
):
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.steps = steps
args.guidance = guidance
args.seed = seed
@@ -21,6 +24,7 @@ def set_ui_params(prompt, steps, guidance, seed, scheduler_key):
def stable_diff_inf(
prompt: str,
negative_prompt: str,
steps: int,
guidance: float,
seed: int,
@@ -33,7 +37,9 @@ def stable_diff_inf(
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
set_ui_params(prompt, steps, guidance, seed, scheduler_key)
set_ui_params(
prompt, negative_prompt, steps, guidance, seed, scheduler_key
)
dtype = torch.float32 if args.precision == "fp32" else torch.half
generator = torch.manual_seed(
args.seed
@@ -50,7 +56,7 @@ def stable_diff_inf(
start = time.time()
text_input = tokenizer(
args.prompt,
args.prompts,
padding="max_length",
max_length=args.max_length,
truncation=True,
@@ -64,9 +70,10 @@ def stable_diff_inf(
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
[""],
args.negative_prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
uncond_clip_inf_start = time.time()
@@ -127,7 +134,8 @@ def stable_diff_inf(
avg_ms = 1000 * avg_ms / args.steps
total_time = time.time() - start
text_output = f"prompt={args.prompt}"
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance}, scheduler={args.scheduler}, seed={args.seed}, size={args.height}x{args.width}, version={args.version}"
text_output += "\nAverage step time: {0:.2f}ms/it".format(avg_ms)
print(f"\nAverage step time: {avg_ms}ms/it")

View File

@@ -7,10 +7,17 @@ p = argparse.ArgumentParser(
p.add_argument(
"--prompts",
nargs="+",
default=["a photograph of an astronaut riding a horse"],
default=["cyberpunk forest by Salvador Dali"],
help="text of which images to be generated.",
)
p.add_argument(
"--negative-prompts",
nargs="+",
default=["trees, green"],
help="text you don't want to see in the generated image.",
)
p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
)