[WEB] Minor changes in the shark web (#454)

1. Default steps = 50.
2. Live preview will yield intermediate image at every 5 steps.
3. Add logs to .gitignore

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2022-11-01 02:59:00 +05:30
committed by GitHub
parent f4c91df1df
commit a3fdfc81db
7 changed files with 27 additions and 15 deletions

4
.gitignore vendored
View File

@@ -167,3 +167,7 @@ shark_tmp/
# ORT related artefacts
cache_models/
onnx_models/
#web logging
web/logs/
web/stored_results/stable_diffusion/

View File

@@ -192,17 +192,27 @@ with gr.Blocks() as shark_web:
)
with gr.Row():
steps = gr.Slider(
1, 100, value=20, step=1, label="Steps"
1, 100, value=50, step=1, label="Steps"
)
guidance = gr.Slider(
0, 50, value=7.5, step=0.1, label="Guidance Scale"
)
with gr.Row():
height = gr.Slider(
384, 768, value=512, step=64, label="Height"
384,
768,
value=512,
step=64,
label="Height",
interactive=False,
)
width = gr.Slider(
384, 768, value=512, step=64, label="Width"
384,
768,
value=512,
step=64,
label="Width",
interactive=False,
)
with gr.Row():
scheduler = gr.Radio(
@@ -278,4 +288,4 @@ with gr.Blocks() as shark_web:
)
shark_web.queue()
shark_web.launch(server_port=8080, enable_queue=True)
shark_web.launch(server_name="0.0.0.0", server_port=8080, enable_queue=True)

View File

@@ -245,8 +245,7 @@ def stable_diff_inf(
text_output = text_output + f"time={step_ms}ms"
latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
if live_preview:
time.sleep(0.1)
if live_preview and i % 5 == 0:
scaled_latents = 1 / 0.18215 * latents
latents_numpy = scaled_latents.detach().numpy()
image = vae.forward((latents_numpy,))
@@ -258,15 +257,14 @@ def stable_diff_inf(
yield out_img, text_output
# scale and decode the image latents with vae
if not live_preview:
latents = 1 / 0.18215 * latents
latents_numpy = latents.detach().numpy()
image = vae.forward((latents_numpy,))
image = torch.from_numpy(image)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
out_img = pil_images[0]
latents = 1 / 0.18215 * latents
latents_numpy = latents.detach().numpy()
image = vae.forward((latents_numpy,))
image = torch.from_numpy(image)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
out_img = pil_images[0]
avg_ms = 1000 * avg_ms / args.steps
text_output = text_output + f"\nAverage step time: {avg_ms}ms/it"