Compare commits

..

9 Commits

Author SHA1 Message Date
Daniel Garvey
9bdb86637d add tkinter launch for webui (#1364) 2023-04-27 19:17:55 -05:00
jinchen62
fb6f26517f Fix webui note (#1367) 2023-04-27 16:14:43 -07:00
Chi_Liu
aa8ada9da9 Add support for torch to stablehlo and tosa in shark_importer (#1360) 2023-04-27 08:09:45 -07:00
powderluv
1db906a373 Revert "Add model manager tab for webui (#1359)" (#1362)
This reverts commit 9d1d1617d8.
2023-04-26 22:25:26 -07:00
jinchen62
9d1d1617d8 Add model manager tab for webui (#1359) 2023-04-26 13:38:18 -07:00
jinchen62
7112789cb8 Add support of using civitai model download url (#1357) 2023-04-25 23:39:52 -07:00
jinchen62
d6b8be2849 Add drawing canvas for img2img stencil scribble (#1355) 2023-04-25 14:41:01 -07:00
powderluv
822171277c Revert "[SD] Add FastChat as part of SD WebUI (#1349)" (#1350)
This reverts commit a5ae9d9f02.
2023-04-24 15:22:25 -07:00
Abhishek Varma
a5ae9d9f02 [SD] Add FastChat as part of SD WebUI (#1349)
-- This commit includes FastChat as part of SD WebUI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-04-24 11:12:58 -07:00
12 changed files with 157 additions and 29 deletions

View File

@@ -29,6 +29,8 @@ datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += collect_data_files('tkinter')
datas += collect_data_files('webview')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),

View File

@@ -1,9 +1,11 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel
from collections import defaultdict
from pathlib import Path
import torch
import safetensors.torch
import traceback
import subprocess
import sys
import os
from apps.stable_diffusion.src.utils import (
@@ -91,10 +93,19 @@ class SharkifyStableDiffusionModel:
self.custom_weights = custom_weights
self.use_quantize = use_quantize
if custom_weights != "":
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
if "civitai" in custom_weights:
weights_id = custom_weights.split("/")[-1]
# TODO: use model name and identify file type by civitai rest api
weights_path = str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
if not os.path.isfile(weights_path):
subprocess.run(["wget", custom_weights, "-O", weights_path])
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
self.custom_weights = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":

View File

@@ -493,6 +493,13 @@ p.add_argument(
default="",
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
)
# TODO: replace API flag when these can be run together
p.add_argument(
"--web_mode",
type=str,
default="app",
help="any number of: [api, app, webui]. Currently api can't be run with others.",
)
p.add_argument(

View File

@@ -1,3 +1,4 @@
from multiprocessing import Process, freeze_support
import os
import sys
import transformers
@@ -10,8 +11,22 @@ if sys.platform == "darwin":
if args.clear_all:
clear_all()
def launch_app(address):
from tkinter import Tk
import webview
tk = Tk()
# size of the window where we show our website
tk.geometry("1280x720")
webview.create_window("SHARK", address)
webview.start(private_mode=False)
if __name__ == "__main__":
if args.api:
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
if args.api or "api" in args.web_mode.split(","):
from apps.stable_diffusion.web.ui import (
txt2img_api,
img2img_api,
@@ -221,9 +236,14 @@ if __name__ == "__main__":
[outpaint_init_image, tabs],
)
sd_web.queue()
if "app" in args.web_mode.split(","):
t = Process(
target=launch_app, args=[f"http://localhost:{args.server_port}"]
)
t.start()
sd_web.launch(
share=args.share,
inbrowser=True,
inbrowser="webui" in args.web_mode.split(","),
server_name="0.0.0.0",
server_port=args.server_port,
)

View File

@@ -29,6 +29,7 @@ from apps.stable_diffusion.src import (
save_output_img,
)
from apps.stable_diffusion.src.utils import get_generation_text_info
import numpy as np
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
@@ -41,7 +42,7 @@ init_import_mlir = args.import_mlir
def img2img_inf(
prompt: str,
negative_prompt: str,
init_image,
image_dict,
height: int,
width: int,
steps: int,
@@ -84,9 +85,12 @@ def img2img_inf(
args.img_path = "not none"
args.ondemand = ondemand
if init_image is None:
if image_dict is None:
return None, "An Initial Image is required"
image = init_image.convert("RGB")
if use_stencil == "scribble":
image = image_dict["mask"].convert("RGB")
else:
image = image_dict["image"].convert("RGB")
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
@@ -98,7 +102,10 @@ def img2img_inf(
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
@@ -292,7 +299,7 @@ def img2img_api(
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
init_image = decode_base64_to_image(InputData["image"])
res = img2img_inf(
InputData["prompt"],
InputData["negative_prompt"],
@@ -357,9 +364,9 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(
@@ -386,7 +393,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
)
img2img_init_image = gr.Image(
label="Input Image", type="pil"
label="Input Image",
source="upload",
tool="sketch",
type="pil",
).style(height=300)
with gr.Accordion(label="Stencil Options", open=False):
@@ -397,6 +407,57 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
value="None",
choices=["None", "canny", "openpose", "scribble"],
)
def show_canvas(choice):
if choice == "scribble":
return (
gr.Slider.update(visible=True),
gr.Slider.update(visible=True),
gr.Button.update(visible=True),
)
else:
return (
gr.Slider.update(visible=False),
gr.Slider.update(visible=False),
gr.Button.update(visible=False),
)
def create_canvas(w, h):
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
create_button = gr.Button(
label="Start",
value="Open drawing canvas!",
visible=False,
)
create_button.click(
fn=create_canvas,
inputs=[canvas_width, canvas_height],
outputs=[img2img_init_image],
)
use_stencil.change(
fn=show_canvas,
inputs=use_stencil,
outputs=[canvas_width, canvas_height, create_button],
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(

View File

@@ -91,7 +91,10 @@ def inpaint_inf(
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
@@ -315,9 +318,9 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(

View File

@@ -93,7 +93,10 @@ def outpaint_inf(
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
@@ -326,9 +329,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(

View File

@@ -85,7 +85,10 @@ def txt2img_inf(
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
@@ -290,9 +293,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(

View File

@@ -89,7 +89,10 @@ def upscaler_inf(
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
else:
@@ -318,9 +321,9 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID",
label="HuggingFace Model ID or Civitai model download URL",
lines=3,
)
custom_vae = gr.Dropdown(

View File

@@ -26,6 +26,8 @@ safetensors
opencv-python
scikit-image
pytorch_lightning # for runwayml models
tk
pywebview
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile

View File

@@ -81,7 +81,7 @@ class SharkImporter:
# NOTE: The default function for torch is "forward" and tf-lite is "main".
def _torch_mlir(self, is_dynamic, tracing_required):
def _torch_mlir(self, is_dynamic, tracing_required, mlir_type):
from shark.torch_mlir_utils import get_torch_mlir_module
return get_torch_mlir_module(
@@ -90,6 +90,7 @@ class SharkImporter:
is_dynamic,
tracing_required,
self.return_str,
mlir_type,
)
def _tf_mlir(self, func_name, save_dir="."):
@@ -120,6 +121,7 @@ class SharkImporter:
tracing_required=False,
func_name="forward",
save_dir="./shark_tmp/",
mlir_type="linalg",
):
if self.frontend in ["torch", "pytorch"]:
if self.inputs == None:
@@ -127,7 +129,10 @@ class SharkImporter:
"Please pass in the inputs, the inputs are required to determine the shape of the mlir_module"
)
sys.exit(1)
return self._torch_mlir(is_dynamic, tracing_required), func_name
return (
self._torch_mlir(is_dynamic, tracing_required, mlir_type),
func_name,
)
if self.frontend in ["tf", "tensorflow"]:
return self._tf_mlir(func_name, save_dir), func_name
if self.frontend in ["tflite", "tf-lite"]:

View File

@@ -19,6 +19,12 @@ import tempfile
from shark.parser import shark_args
import io
mlir_type_mapping_dict = {
"linalg": torch_mlir.OutputType.LINALG_ON_TENSORS,
"mhlo": torch_mlir.OutputType.STABLEHLO,
"tosa": torch_mlir.OutputType.TOSA,
}
def get_module_name_for_asm_dump(module):
"""Gets a name suitable for an assembly dump.
@@ -57,6 +63,7 @@ def get_torch_mlir_module(
dynamic: bool,
jit_trace: bool,
return_str: bool = False,
mlir_type: str = "linalg",
):
"""Get the MLIR's linalg-on-tensors module from the torchscipt module."""
ignore_traced_shapes = False
@@ -70,10 +77,11 @@ def get_torch_mlir_module(
mlir_module = torch_mlir.compile(
module,
input,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
output_type=mlir_type_mapping_dict[mlir_type],
use_tracing=jit_trace,
ignore_traced_shapes=ignore_traced_shapes,
)
if return_str:
return mlir_module.operation.get_asm()
bytecode_stream = io.BytesIO()