mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
9 Commits
20230423.7
...
20230427.7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9bdb86637d | ||
|
|
fb6f26517f | ||
|
|
aa8ada9da9 | ||
|
|
1db906a373 | ||
|
|
9d1d1617d8 | ||
|
|
7112789cb8 | ||
|
|
d6b8be2849 | ||
|
|
822171277c | ||
|
|
a5ae9d9f02 |
@@ -29,6 +29,8 @@ datas += collect_data_files('gradio_client')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
datas += collect_data_files('shark')
|
||||
datas += collect_data_files('tkinter')
|
||||
datas += collect_data_files('webview')
|
||||
datas += [
|
||||
( 'src/utils/resources/prompts.json', 'resources' ),
|
||||
( 'src/utils/resources/model_db.json', 'resources' ),
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
|
||||
from transformers import CLIPTextModel
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import safetensors.torch
|
||||
import traceback
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
@@ -91,10 +93,19 @@ class SharkifyStableDiffusionModel:
|
||||
self.custom_weights = custom_weights
|
||||
self.use_quantize = use_quantize
|
||||
if custom_weights != "":
|
||||
assert custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
|
||||
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
|
||||
if "civitai" in custom_weights:
|
||||
weights_id = custom_weights.split("/")[-1]
|
||||
# TODO: use model name and identify file type by civitai rest api
|
||||
weights_path = str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
|
||||
if not os.path.isfile(weights_path):
|
||||
subprocess.run(["wget", custom_weights, "-O", weights_path])
|
||||
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
|
||||
self.custom_weights = weights_path
|
||||
else:
|
||||
assert custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
|
||||
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
|
||||
self.model_id = model_id if custom_weights == "" else custom_weights
|
||||
# TODO: remove the following line when stable-diffusion-2-1 works
|
||||
if self.model_id == "stabilityai/stable-diffusion-2-1":
|
||||
|
||||
@@ -493,6 +493,13 @@ p.add_argument(
|
||||
default="",
|
||||
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
|
||||
)
|
||||
# TODO: replace API flag when these can be run together
|
||||
p.add_argument(
|
||||
"--web_mode",
|
||||
type=str,
|
||||
default="app",
|
||||
help="any number of: [api, app, webui]. Currently api can't be run with others.",
|
||||
)
|
||||
|
||||
|
||||
p.add_argument(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
import os
|
||||
import sys
|
||||
import transformers
|
||||
@@ -10,8 +11,22 @@ if sys.platform == "darwin":
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
|
||||
def launch_app(address):
|
||||
from tkinter import Tk
|
||||
import webview
|
||||
|
||||
tk = Tk()
|
||||
# size of the window where we show our website
|
||||
tk.geometry("1280x720")
|
||||
webview.create_window("SHARK", address)
|
||||
webview.start(private_mode=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.api:
|
||||
# required to do multiprocessing in a pyinstaller freeze
|
||||
freeze_support()
|
||||
if args.api or "api" in args.web_mode.split(","):
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_api,
|
||||
img2img_api,
|
||||
@@ -221,9 +236,14 @@ if __name__ == "__main__":
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
sd_web.queue()
|
||||
if "app" in args.web_mode.split(","):
|
||||
t = Process(
|
||||
target=launch_app, args=[f"http://localhost:{args.server_port}"]
|
||||
)
|
||||
t.start()
|
||||
sd_web.launch(
|
||||
share=args.share,
|
||||
inbrowser=True,
|
||||
inbrowser="webui" in args.web_mode.split(","),
|
||||
server_name="0.0.0.0",
|
||||
server_port=args.server_port,
|
||||
)
|
||||
|
||||
@@ -29,6 +29,7 @@ from apps.stable_diffusion.src import (
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
import numpy as np
|
||||
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
@@ -41,7 +42,7 @@ init_import_mlir = args.import_mlir
|
||||
def img2img_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
init_image,
|
||||
image_dict,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
@@ -84,9 +85,12 @@ def img2img_inf(
|
||||
args.img_path = "not none"
|
||||
args.ondemand = ondemand
|
||||
|
||||
if init_image is None:
|
||||
if image_dict is None:
|
||||
return None, "An Initial Image is required"
|
||||
image = init_image.convert("RGB")
|
||||
if use_stencil == "scribble":
|
||||
image = image_dict["mask"].convert("RGB")
|
||||
else:
|
||||
image = image_dict["image"].convert("RGB")
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
@@ -98,7 +102,10 @@ def img2img_inf(
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
@@ -292,7 +299,7 @@ def img2img_api(
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
init_image = decode_base64_to_image(InputData["image"])
|
||||
res = img2img_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
@@ -357,9 +364,9 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
lines=3,
|
||||
)
|
||||
custom_vae = gr.Dropdown(
|
||||
@@ -386,7 +393,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
)
|
||||
|
||||
img2img_init_image = gr.Image(
|
||||
label="Input Image", type="pil"
|
||||
label="Input Image",
|
||||
source="upload",
|
||||
tool="sketch",
|
||||
type="pil",
|
||||
).style(height=300)
|
||||
|
||||
with gr.Accordion(label="Stencil Options", open=False):
|
||||
@@ -397,6 +407,57 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
value="None",
|
||||
choices=["None", "canny", "openpose", "scribble"],
|
||||
)
|
||||
|
||||
def show_canvas(choice):
|
||||
if choice == "scribble":
|
||||
return (
|
||||
gr.Slider.update(visible=True),
|
||||
gr.Slider.update(visible=True),
|
||||
gr.Button.update(visible=True),
|
||||
)
|
||||
else:
|
||||
return (
|
||||
gr.Slider.update(visible=False),
|
||||
gr.Slider.update(visible=False),
|
||||
gr.Button.update(visible=False),
|
||||
)
|
||||
|
||||
def create_canvas(w, h):
|
||||
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
|
||||
|
||||
with gr.Row():
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
visible=False,
|
||||
)
|
||||
canvas_height = gr.Slider(
|
||||
label="Canvas Height",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
visible=False,
|
||||
)
|
||||
create_button = gr.Button(
|
||||
label="Start",
|
||||
value="Open drawing canvas!",
|
||||
visible=False,
|
||||
)
|
||||
create_button.click(
|
||||
fn=create_canvas,
|
||||
inputs=[canvas_width, canvas_height],
|
||||
outputs=[img2img_init_image],
|
||||
)
|
||||
use_stencil.change(
|
||||
fn=show_canvas,
|
||||
inputs=use_stencil,
|
||||
outputs=[canvas_width, canvas_height, create_button],
|
||||
)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
|
||||
@@ -91,7 +91,10 @@ def inpaint_inf(
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
@@ -315,9 +318,9 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
lines=3,
|
||||
)
|
||||
custom_vae = gr.Dropdown(
|
||||
|
||||
@@ -93,7 +93,10 @@ def outpaint_inf(
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
@@ -326,9 +329,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
lines=3,
|
||||
)
|
||||
custom_vae = gr.Dropdown(
|
||||
|
||||
@@ -85,7 +85,10 @@ def txt2img_inf(
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
@@ -290,9 +293,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
lines=3,
|
||||
)
|
||||
custom_vae = gr.Dropdown(
|
||||
|
||||
@@ -89,7 +89,10 @@ def upscaler_inf(
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
@@ -318,9 +321,9 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
lines=3,
|
||||
)
|
||||
custom_vae = gr.Dropdown(
|
||||
|
||||
@@ -26,6 +26,8 @@ safetensors
|
||||
opencv-python
|
||||
scikit-image
|
||||
pytorch_lightning # for runwayml models
|
||||
tk
|
||||
pywebview
|
||||
|
||||
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
|
||||
pefile
|
||||
|
||||
@@ -81,7 +81,7 @@ class SharkImporter:
|
||||
|
||||
# NOTE: The default function for torch is "forward" and tf-lite is "main".
|
||||
|
||||
def _torch_mlir(self, is_dynamic, tracing_required):
|
||||
def _torch_mlir(self, is_dynamic, tracing_required, mlir_type):
|
||||
from shark.torch_mlir_utils import get_torch_mlir_module
|
||||
|
||||
return get_torch_mlir_module(
|
||||
@@ -90,6 +90,7 @@ class SharkImporter:
|
||||
is_dynamic,
|
||||
tracing_required,
|
||||
self.return_str,
|
||||
mlir_type,
|
||||
)
|
||||
|
||||
def _tf_mlir(self, func_name, save_dir="."):
|
||||
@@ -120,6 +121,7 @@ class SharkImporter:
|
||||
tracing_required=False,
|
||||
func_name="forward",
|
||||
save_dir="./shark_tmp/",
|
||||
mlir_type="linalg",
|
||||
):
|
||||
if self.frontend in ["torch", "pytorch"]:
|
||||
if self.inputs == None:
|
||||
@@ -127,7 +129,10 @@ class SharkImporter:
|
||||
"Please pass in the inputs, the inputs are required to determine the shape of the mlir_module"
|
||||
)
|
||||
sys.exit(1)
|
||||
return self._torch_mlir(is_dynamic, tracing_required), func_name
|
||||
return (
|
||||
self._torch_mlir(is_dynamic, tracing_required, mlir_type),
|
||||
func_name,
|
||||
)
|
||||
if self.frontend in ["tf", "tensorflow"]:
|
||||
return self._tf_mlir(func_name, save_dir), func_name
|
||||
if self.frontend in ["tflite", "tf-lite"]:
|
||||
|
||||
@@ -19,6 +19,12 @@ import tempfile
|
||||
from shark.parser import shark_args
|
||||
import io
|
||||
|
||||
mlir_type_mapping_dict = {
|
||||
"linalg": torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
"mhlo": torch_mlir.OutputType.STABLEHLO,
|
||||
"tosa": torch_mlir.OutputType.TOSA,
|
||||
}
|
||||
|
||||
|
||||
def get_module_name_for_asm_dump(module):
|
||||
"""Gets a name suitable for an assembly dump.
|
||||
@@ -57,6 +63,7 @@ def get_torch_mlir_module(
|
||||
dynamic: bool,
|
||||
jit_trace: bool,
|
||||
return_str: bool = False,
|
||||
mlir_type: str = "linalg",
|
||||
):
|
||||
"""Get the MLIR's linalg-on-tensors module from the torchscipt module."""
|
||||
ignore_traced_shapes = False
|
||||
@@ -70,10 +77,11 @@ def get_torch_mlir_module(
|
||||
mlir_module = torch_mlir.compile(
|
||||
module,
|
||||
input,
|
||||
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
output_type=mlir_type_mapping_dict[mlir_type],
|
||||
use_tracing=jit_trace,
|
||||
ignore_traced_shapes=ignore_traced_shapes,
|
||||
)
|
||||
|
||||
if return_str:
|
||||
return mlir_module.operation.get_asm()
|
||||
bytecode_stream = io.BytesIO()
|
||||
|
||||
Reference in New Issue
Block a user