mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 13:57:54 -05:00
[WEB] Update stable diffusion UI and enable live preview (#447)
This commit enables the live preview feature and also updates the stable diffusion web UI. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com> Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
@@ -48,10 +48,8 @@ def get_vulkan_triple_flag(extra_args=[]):
|
||||
elif all(x in vulkan_device for x in ("RTX", "3090")):
|
||||
print(f"Found {vulkan_device} Device. Using ampere-rtx3090-linux")
|
||||
return "-iree-vulkan-target-triple=ampere-rtx3090-linux"
|
||||
elif any(x in vulkan_device for x in ("Radeon", "AMD")):
|
||||
print(
|
||||
"Found AMD Radeon RX 6000 series device. Using rdna2-unknown-linux"
|
||||
)
|
||||
elif "AMD" in vulkan_device:
|
||||
print("Found AMD device. Using rdna2-unknown-linux")
|
||||
return "-iree-vulkan-target-triple=rdna2-unknown-linux"
|
||||
else:
|
||||
print(
|
||||
|
||||
130
web/index.py
130
web/index.py
@@ -141,67 +141,96 @@ with gr.Blocks() as shark_web:
|
||||
save_vmfb
|
||||
) = (
|
||||
iree_vulkan_target_triple
|
||||
) = (
|
||||
live_preview
|
||||
) = debug = stable_diffusion = generated_img = std_output = None
|
||||
examples = [
|
||||
["A high tech solarpunk utopia in the Amazon rainforest"],
|
||||
["A pikachu fine dining with a view to the Eiffel Tower"],
|
||||
["A mecha robot in a favela in expressionist style"],
|
||||
["an insect robot preparing a delicious meal"],
|
||||
[
|
||||
"A small cabin on top of a snowy mountain in the style of Disney, artstation"
|
||||
],
|
||||
]
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value="a photograph of an astronaut riding a horse",
|
||||
lines=2,
|
||||
)
|
||||
scheduler = gr.Radio(
|
||||
label="Scheduler",
|
||||
value="LMS",
|
||||
choices=["PNDM", "LMS", "DDIM"],
|
||||
visible=False,
|
||||
)
|
||||
iters_count = gr.Slider(
|
||||
1,
|
||||
24,
|
||||
value=1,
|
||||
step=1,
|
||||
label="Iteration Count",
|
||||
visible=False,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=1,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
visible=False,
|
||||
)
|
||||
steps = gr.Slider(1, 100, value=20, step=1, label="Steps")
|
||||
guidance = gr.Slider(
|
||||
0, 50, value=7.5, step=0.1, label="Guidance Scale"
|
||||
)
|
||||
height = gr.Slider(
|
||||
384, 768, value=512, step=64, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 768, value=512, step=64, label="Width"
|
||||
)
|
||||
seed = gr.Textbox(value="42", max_lines=1, label="Seed")
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp32",
|
||||
choices=["fp16", "fp32"],
|
||||
)
|
||||
device = gr.Radio(
|
||||
label="Device",
|
||||
value="vulkan",
|
||||
choices=["cpu", "cuda", "vulkan"],
|
||||
)
|
||||
with gr.Group():
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value="a photograph of an astronaut riding a horse",
|
||||
)
|
||||
ex = gr.Examples(
|
||||
examples=examples,
|
||||
inputs=prompt,
|
||||
cache_examples=False,
|
||||
)
|
||||
with gr.Row():
|
||||
iters_count = gr.Slider(
|
||||
1,
|
||||
24,
|
||||
value=1,
|
||||
step=1,
|
||||
label="Iteration Count",
|
||||
visible=False,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=1,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=20, step=1, label="Steps"
|
||||
)
|
||||
guidance = gr.Slider(
|
||||
0, 50, value=7.5, step=0.1, label="Guidance Scale"
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 768, value=512, step=64, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 768, value=512, step=64, label="Width"
|
||||
)
|
||||
with gr.Row():
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp32",
|
||||
choices=["fp16", "fp32"],
|
||||
)
|
||||
device = gr.Radio(
|
||||
label="Device",
|
||||
value="vulkan",
|
||||
choices=["cpu", "cuda", "vulkan"],
|
||||
)
|
||||
with gr.Row():
|
||||
scheduler = gr.Radio(
|
||||
label="Scheduler",
|
||||
value="LMS",
|
||||
choices=["PNDM", "LMS", "DDIM"],
|
||||
interactive=False,
|
||||
)
|
||||
seed = gr.Textbox(
|
||||
value="42", max_lines=1, label="Seed"
|
||||
)
|
||||
with gr.Row():
|
||||
load_vmfb = gr.Checkbox(label="Load vmfb", value=True)
|
||||
save_vmfb = gr.Checkbox(label="Save vmfb", value=False)
|
||||
debug = gr.Checkbox(label="DEBUG", value=False)
|
||||
live_preview = gr.Checkbox(
|
||||
label="live preview", value=False
|
||||
)
|
||||
iree_vulkan_target_triple = gr.Textbox(
|
||||
value="",
|
||||
max_lines=1,
|
||||
label="IREE VULKAN TARGET TRIPLE",
|
||||
visible=False,
|
||||
)
|
||||
debug = gr.Checkbox(label="DEBUG", value=False)
|
||||
stable_diffusion = gr.Button("Generate image from prompt")
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
generated_img = gr.Image(type="pil", shape=(100, 100))
|
||||
@@ -211,6 +240,7 @@ with gr.Blocks() as shark_web:
|
||||
lines=10,
|
||||
visible=False,
|
||||
)
|
||||
|
||||
debug.change(
|
||||
debug_event,
|
||||
inputs=[debug],
|
||||
@@ -234,8 +264,10 @@ with gr.Blocks() as shark_web:
|
||||
load_vmfb,
|
||||
save_vmfb,
|
||||
iree_vulkan_target_triple,
|
||||
live_preview,
|
||||
],
|
||||
outputs=[generated_img, std_output],
|
||||
)
|
||||
|
||||
shark_web.queue()
|
||||
shark_web.launch(share=True, server_port=8080, enable_queue=True)
|
||||
|
||||
@@ -42,6 +42,7 @@ class Arguments:
|
||||
load_vmfb: bool,
|
||||
save_vmfb: bool,
|
||||
iree_vulkan_target_triple: str,
|
||||
live_preview: bool,
|
||||
import_mlir: bool = False,
|
||||
max_length: int = 77,
|
||||
):
|
||||
@@ -59,6 +60,7 @@ class Arguments:
|
||||
self.load_vmfb = load_vmfb
|
||||
self.save_vmfb = save_vmfb
|
||||
self.iree_vulkan_target_triple = iree_vulkan_target_triple
|
||||
self.live_preview = live_preview
|
||||
self.import_mlir = import_mlir
|
||||
self.max_length = max_length
|
||||
|
||||
@@ -114,6 +116,7 @@ def stable_diff_inf(
|
||||
load_vmfb: bool,
|
||||
save_vmfb: bool,
|
||||
iree_vulkan_target_triple: str,
|
||||
live_preview: bool,
|
||||
):
|
||||
|
||||
global IREE_EXTRA_ARGS
|
||||
@@ -178,6 +181,7 @@ def stable_diff_inf(
|
||||
load_vmfb,
|
||||
save_vmfb,
|
||||
iree_vulkan_target_triple,
|
||||
live_preview,
|
||||
)
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
@@ -228,6 +232,7 @@ def stable_diff_inf(
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
|
||||
avg_ms = 0
|
||||
out_img = None
|
||||
for i, t in tqdm(enumerate(scheduler.timesteps)):
|
||||
|
||||
if DEBUG:
|
||||
@@ -248,25 +253,38 @@ def stable_diff_inf(
|
||||
log_write.write(f"time={step_ms}ms")
|
||||
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
|
||||
|
||||
if live_preview:
|
||||
time.sleep(0.1)
|
||||
scaled_latents = 1 / 0.18215 * latents
|
||||
latents_numpy = scaled_latents.detach().numpy()
|
||||
image = vae.forward((latents_numpy,))
|
||||
image = torch.from_numpy(image)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
images = (image * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
out_img = pil_images[0]
|
||||
yield out_img, ""
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
latents_numpy = latents.detach().numpy()
|
||||
image = vae.forward((latents_numpy,))
|
||||
image = torch.from_numpy(image)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
images = (image * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
if not live_preview:
|
||||
latents = 1 / 0.18215 * latents
|
||||
latents_numpy = latents.detach().numpy()
|
||||
image = vae.forward((latents_numpy,))
|
||||
image = torch.from_numpy(image)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
images = (image * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
out_img = pil_images[0]
|
||||
|
||||
avg_ms = 1000 * avg_ms / args.steps
|
||||
if DEBUG:
|
||||
log_write.write(f"\nAverage step time: {avg_ms}ms/it")
|
||||
|
||||
print("total images:", len(pil_images))
|
||||
output = pil_images[0]
|
||||
# save the output image with the prompt name.
|
||||
output.save(os.path.join(output_loc))
|
||||
out_img.save(os.path.join(output_loc))
|
||||
log_write.close()
|
||||
|
||||
std_output = ""
|
||||
with open(r"logs/stable_diffusion_log.txt", "r") as log_read:
|
||||
std_output = log_read.read()
|
||||
return output, std_output
|
||||
yield out_img, std_output
|
||||
|
||||
Reference in New Issue
Block a user