mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[SHARK][SD] Add support for negative prompts
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
@@ -14,6 +14,7 @@ from stable_args import args
|
||||
from utils import get_shark_model, set_iree_runtime_flags
|
||||
from opt_params import get_unet, get_vae, get_clip
|
||||
import time
|
||||
import sys
|
||||
from model_wrappers import get_vae_mlir
|
||||
from shark.iree_utils.compile_utils import dump_isas
|
||||
|
||||
@@ -39,6 +40,7 @@ if __name__ == "__main__":
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
|
||||
prompt = args.prompts
|
||||
neg_prompt = args.negative_prompts
|
||||
height = 512 # default height of Stable Diffusion
|
||||
width = 512 # default width of Stable Diffusion
|
||||
if args.version == "v2":
|
||||
@@ -54,7 +56,12 @@ if __name__ == "__main__":
|
||||
args.seed
|
||||
) # Seed generator to create the inital latent noise
|
||||
|
||||
# TODO: Add support for batch_size > 1.
|
||||
batch_size = len(prompt)
|
||||
if batch_size != 1:
|
||||
sys.exit("More than one prompt is not supported yet.")
|
||||
if batch_size != len(neg_prompt):
|
||||
sys.exit("prompts and negative prompts must be of same length")
|
||||
|
||||
set_iree_runtime_flags()
|
||||
unet = get_unet()
|
||||
@@ -103,9 +110,10 @@ if __name__ == "__main__":
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
max_length = text_input.input_ids.shape[-1]
|
||||
uncond_input = tokenizer(
|
||||
[""] * batch_size,
|
||||
neg_prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_clip_inf_start = time.time()
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
from utils import compile_through_fx
|
||||
from stable_args import args
|
||||
import torch
|
||||
|
||||
BATCH_SIZE = len(args.prompts)
|
||||
|
||||
model_config = {
|
||||
"v2": "stabilityai/stable-diffusion-2",
|
||||
"v2.1base": "stabilityai/stable-diffusion-2-1-base",
|
||||
|
||||
@@ -7,12 +7,21 @@ p = argparse.ArgumentParser(
|
||||
p.add_argument(
|
||||
"--prompts",
|
||||
nargs="+",
|
||||
default=["a photograph of an astronaut riding a horse"],
|
||||
default=["cyberpunk forest by Salvador Dali"],
|
||||
help="text of which images to be generated.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--negative-prompts",
|
||||
nargs="+",
|
||||
default=["trees, green"],
|
||||
help="text you don't want to see in the generated image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--device", type=str, default="cpu", help="device to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
@@ -33,6 +42,7 @@ p.add_argument(
|
||||
default=42,
|
||||
help="the seed to use.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
|
||||
10
web/index.py
10
web/index.py
@@ -78,7 +78,13 @@ with gr.Blocks(css=demo_css) as shark_web:
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value="A photograph of an astronaut riding a horse",
|
||||
value="cyberpunk forest by Salvador Dali",
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value="trees, green",
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
@@ -148,6 +154,7 @@ with gr.Blocks(css=demo_css) as shark_web:
|
||||
stable_diff_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
guidance,
|
||||
seed,
|
||||
@@ -159,6 +166,7 @@ with gr.Blocks(css=demo_css) as shark_web:
|
||||
stable_diff_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
guidance,
|
||||
seed,
|
||||
|
||||
@@ -11,8 +11,11 @@ import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
def set_ui_params(prompt, steps, guidance, seed, scheduler_key):
|
||||
args.prompt = [prompt]
|
||||
def set_ui_params(
|
||||
prompt, negative_prompt, steps, guidance, seed, scheduler_key
|
||||
):
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.steps = steps
|
||||
args.guidance = guidance
|
||||
args.seed = seed
|
||||
@@ -21,6 +24,7 @@ def set_ui_params(prompt, steps, guidance, seed, scheduler_key):
|
||||
|
||||
def stable_diff_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed: int,
|
||||
@@ -33,7 +37,9 @@ def stable_diff_inf(
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
|
||||
set_ui_params(prompt, steps, guidance, seed, scheduler_key)
|
||||
set_ui_params(
|
||||
prompt, negative_prompt, steps, guidance, seed, scheduler_key
|
||||
)
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
generator = torch.manual_seed(
|
||||
args.seed
|
||||
@@ -50,7 +56,7 @@ def stable_diff_inf(
|
||||
|
||||
start = time.time()
|
||||
text_input = tokenizer(
|
||||
args.prompt,
|
||||
args.prompts,
|
||||
padding="max_length",
|
||||
max_length=args.max_length,
|
||||
truncation=True,
|
||||
@@ -64,9 +70,10 @@ def stable_diff_inf(
|
||||
max_length = text_input.input_ids.shape[-1]
|
||||
|
||||
uncond_input = tokenizer(
|
||||
[""],
|
||||
args.negative_prompts,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_clip_inf_start = time.time()
|
||||
@@ -127,7 +134,8 @@ def stable_diff_inf(
|
||||
avg_ms = 1000 * avg_ms / args.steps
|
||||
total_time = time.time() - start
|
||||
|
||||
text_output = f"prompt={args.prompt}"
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance}, scheduler={args.scheduler}, seed={args.seed}, size={args.height}x{args.width}, version={args.version}"
|
||||
text_output += "\nAverage step time: {0:.2f}ms/it".format(avg_ms)
|
||||
print(f"\nAverage step time: {avg_ms}ms/it")
|
||||
|
||||
@@ -7,10 +7,17 @@ p = argparse.ArgumentParser(
|
||||
p.add_argument(
|
||||
"--prompts",
|
||||
nargs="+",
|
||||
default=["a photograph of an astronaut riding a horse"],
|
||||
default=["cyberpunk forest by Salvador Dali"],
|
||||
help="text of which images to be generated.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--negative-prompts",
|
||||
nargs="+",
|
||||
default=["trees, green"],
|
||||
help="text you don't want to see in the generated image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--device", type=str, default="vulkan", help="device to run the model."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user