[WEB] Update stable diffusion UI and enable live preview (#447)

This commit enables the live preview feature and also updates the Stable
Diffusion web UI.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2022-10-31 16:40:15 +05:30
committed by GitHub
parent 024c5e153a
commit 25931d48a3
3 changed files with 112 additions and 64 deletions

View File

@@ -48,10 +48,8 @@ def get_vulkan_triple_flag(extra_args=[]):
elif all(x in vulkan_device for x in ("RTX", "3090")):
print(f"Found {vulkan_device} Device. Using ampere-rtx3090-linux")
return "-iree-vulkan-target-triple=ampere-rtx3090-linux"
elif any(x in vulkan_device for x in ("Radeon", "AMD")):
print(
"Found AMD Radeon RX 6000 series device. Using rdna2-unknown-linux"
)
elif "AMD" in vulkan_device:
print("Found AMD device. Using rdna2-unknown-linux")
return "-iree-vulkan-target-triple=rdna2-unknown-linux"
else:
print(

View File

@@ -141,67 +141,96 @@ with gr.Blocks() as shark_web:
save_vmfb
) = (
iree_vulkan_target_triple
) = (
live_preview
) = debug = stable_diffusion = generated_img = std_output = None
examples = [
["A high tech solarpunk utopia in the Amazon rainforest"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],
[
"A small cabin on top of a snowy mountain in the style of Disney, artstation"
],
]
with gr.Row():
with gr.Column(scale=1, min_width=600):
prompt = gr.Textbox(
label="Prompt",
value="a photograph of an astronaut riding a horse",
lines=2,
)
scheduler = gr.Radio(
label="Scheduler",
value="LMS",
choices=["PNDM", "LMS", "DDIM"],
visible=False,
)
iters_count = gr.Slider(
1,
24,
value=1,
step=1,
label="Iteration Count",
visible=False,
)
batch_size = gr.Slider(
1,
4,
value=1,
step=1,
label="Batch Size",
visible=False,
)
steps = gr.Slider(1, 100, value=20, step=1, label="Steps")
guidance = gr.Slider(
0, 50, value=7.5, step=0.1, label="Guidance Scale"
)
height = gr.Slider(
384, 768, value=512, step=64, label="Height"
)
width = gr.Slider(
384, 768, value=512, step=64, label="Width"
)
seed = gr.Textbox(value="42", max_lines=1, label="Seed")
precision = gr.Radio(
label="Precision",
value="fp32",
choices=["fp16", "fp32"],
)
device = gr.Radio(
label="Device",
value="vulkan",
choices=["cpu", "cuda", "vulkan"],
)
with gr.Group():
prompt = gr.Textbox(
label="Prompt",
value="a photograph of an astronaut riding a horse",
)
ex = gr.Examples(
examples=examples,
inputs=prompt,
cache_examples=False,
)
with gr.Row():
iters_count = gr.Slider(
1,
24,
value=1,
step=1,
label="Iteration Count",
visible=False,
)
batch_size = gr.Slider(
1,
4,
value=1,
step=1,
label="Batch Size",
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=20, step=1, label="Steps"
)
guidance = gr.Slider(
0, 50, value=7.5, step=0.1, label="Guidance Scale"
)
with gr.Row():
height = gr.Slider(
384, 768, value=512, step=64, label="Height"
)
width = gr.Slider(
384, 768, value=512, step=64, label="Width"
)
with gr.Row():
precision = gr.Radio(
label="Precision",
value="fp32",
choices=["fp16", "fp32"],
)
device = gr.Radio(
label="Device",
value="vulkan",
choices=["cpu", "cuda", "vulkan"],
)
with gr.Row():
scheduler = gr.Radio(
label="Scheduler",
value="LMS",
choices=["PNDM", "LMS", "DDIM"],
interactive=False,
)
seed = gr.Textbox(
value="42", max_lines=1, label="Seed"
)
with gr.Row():
load_vmfb = gr.Checkbox(label="Load vmfb", value=True)
save_vmfb = gr.Checkbox(label="Save vmfb", value=False)
debug = gr.Checkbox(label="DEBUG", value=False)
live_preview = gr.Checkbox(
label="live preview", value=False
)
iree_vulkan_target_triple = gr.Textbox(
value="",
max_lines=1,
label="IREE VULKAN TARGET TRIPLE",
visible=False,
)
debug = gr.Checkbox(label="DEBUG", value=False)
stable_diffusion = gr.Button("Generate image from prompt")
with gr.Column(scale=1, min_width=600):
generated_img = gr.Image(type="pil", shape=(100, 100))
@@ -211,6 +240,7 @@ with gr.Blocks() as shark_web:
lines=10,
visible=False,
)
debug.change(
debug_event,
inputs=[debug],
@@ -234,8 +264,10 @@ with gr.Blocks() as shark_web:
load_vmfb,
save_vmfb,
iree_vulkan_target_triple,
live_preview,
],
outputs=[generated_img, std_output],
)
shark_web.queue()
shark_web.launch(share=True, server_port=8080, enable_queue=True)

View File

@@ -42,6 +42,7 @@ class Arguments:
load_vmfb: bool,
save_vmfb: bool,
iree_vulkan_target_triple: str,
live_preview: bool,
import_mlir: bool = False,
max_length: int = 77,
):
@@ -59,6 +60,7 @@ class Arguments:
self.load_vmfb = load_vmfb
self.save_vmfb = save_vmfb
self.iree_vulkan_target_triple = iree_vulkan_target_triple
self.live_preview = live_preview
self.import_mlir = import_mlir
self.max_length = max_length
@@ -114,6 +116,7 @@ def stable_diff_inf(
load_vmfb: bool,
save_vmfb: bool,
iree_vulkan_target_triple: str,
live_preview: bool,
):
global IREE_EXTRA_ARGS
@@ -178,6 +181,7 @@ def stable_diff_inf(
load_vmfb,
save_vmfb,
iree_vulkan_target_triple,
live_preview,
)
dtype = torch.float32 if args.precision == "fp32" else torch.half
if len(args.iree_vulkan_target_triple) > 0:
@@ -228,6 +232,7 @@ def stable_diff_inf(
text_embeddings_numpy = text_embeddings.detach().numpy()
avg_ms = 0
out_img = None
for i, t in tqdm(enumerate(scheduler.timesteps)):
if DEBUG:
@@ -248,25 +253,38 @@ def stable_diff_inf(
log_write.write(f"time={step_ms}ms")
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
if live_preview:
time.sleep(0.1)
scaled_latents = 1 / 0.18215 * latents
latents_numpy = scaled_latents.detach().numpy()
image = vae.forward((latents_numpy,))
image = torch.from_numpy(image)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
out_img = pil_images[0]
yield out_img, ""
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
latents_numpy = latents.detach().numpy()
image = vae.forward((latents_numpy,))
image = torch.from_numpy(image)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
if not live_preview:
latents = 1 / 0.18215 * latents
latents_numpy = latents.detach().numpy()
image = vae.forward((latents_numpy,))
image = torch.from_numpy(image)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
out_img = pil_images[0]
avg_ms = 1000 * avg_ms / args.steps
if DEBUG:
log_write.write(f"\nAverage step time: {avg_ms}ms/it")
print("total images:", len(pil_images))
output = pil_images[0]
# save the output image with the prompt name.
output.save(os.path.join(output_loc))
out_img.save(os.path.join(output_loc))
log_write.close()
std_output = ""
with open(r"logs/stable_diffusion_log.txt", "r") as log_read:
std_output = log_read.read()
return output, std_output
yield out_img, std_output