REST API support and cleanup

Ean Garvey
2024-08-08 11:37:53 -05:00
parent d5f37eaf20
commit 4759e808f2
18 changed files with 450 additions and 211 deletions

View File

@@ -46,7 +46,7 @@ jobs:
draft: true
prerelease: true
- name: Build Package
- name: Build Package (api only)
shell: powershell
run: |
./setup_venv.ps1
@@ -54,10 +54,10 @@ jobs:
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip install -e .
pip freeze -l
pyinstaller .\apps\shark_studio\shark_studio.spec
pyinstaller .\apps\shark_studio\shark_studio_apionly.spec
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
- name: Upload Release Assets
id: upload-release-assets
uses: dwenegar/upload-release-assets@v1

View File

@@ -34,9 +34,9 @@ def imports():
action="ignore", category=UserWarning, module="huggingface-hub"
)
import gradio # noqa: F401
# import gradio # noqa: F401
startup_timer.record("import gradio")
# startup_timer.record("import gradio")
import apps.shark_studio.web.utils.globals as global_obj
@@ -56,9 +56,8 @@ def initialize():
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
config_tmp()
# clear_tmp_mlir()
clear_tmp_imgs()
# config_tmp()
# clear_tmp_imgs()
from apps.shark_studio.web.utils.file_utils import (
create_model_folders,
@@ -67,8 +66,6 @@ def initialize():
# Create custom models folders if they don't exist
create_model_folders()
import gradio as gr
# initialize_rest(reload_script_modules=False)

View File

@@ -14,7 +14,6 @@ from pathlib import Path
from random import randint
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.api.utils import parse_device
from apps.shark_studio.web.utils.state import status_label
@@ -30,6 +29,7 @@ from apps.shark_studio.modules.img_processing import (
from subprocess import check_output
EMPTY_SD_MAP = {
"clip": None,
"scheduler": None,
@@ -114,11 +114,14 @@ class StableDiffusion:
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
SharkSDXLPipeline,
)
self.turbine_pipe = SharkSDXLPipeline
self.dynamic_steps = False
self.model_map = EMPTY_SDXL_MAP
else:
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
from turbine_models.custom_models.sd_inference.sd_pipeline import (
SharkSDPipeline,
)
self.turbine_pipe = SharkSDPipeline
self.dynamic_steps = True
@@ -209,6 +212,7 @@ class StableDiffusion:
preprocessCKPT,
save_irpa,
)
custom_weights = os.path.join(
get_checkpoints_path("checkpoints"),
safe_name(self.base_model_id.split("/")[-1]),
@@ -223,14 +227,20 @@ class StableDiffusion:
"diffusion_pytorch_model.safetensors",
)
weights[key] = save_irpa(unet_weights_path, "unet.")
elif key in ["clip", "prompt_encoder"]:
if not self.is_sdxl:
if key in ["mmdit"]:
mmdit_weights_path = os.path.join(
diffusers_weights_path,
"mmdit",
"diffusion_pytorch_model_fp16.safetensors",
)
weights[key] = save_irpa(mmdit_weights_path, "mmdit.")
elif key in ["clip", "prompt_encoder", "text_encoder"]:
if not self.is_sdxl and not self.is_custom:
sd1_path = os.path.join(
diffusers_weights_path, "text_encoder", "model.safetensors"
)
weights[key] = save_irpa(sd1_path, "text_encoder_model.")
else:
elif self.is_sdxl:
clip_1_path = os.path.join(
diffusers_weights_path, "text_encoder", "model.safetensors"
)
@@ -243,7 +253,27 @@ class StableDiffusion:
save_irpa(clip_1_path, "text_encoder_model_1."),
save_irpa(clip_2_path, "text_encoder_model_2."),
]
elif self.is_custom:
clip_g_path = os.path.join(
diffusers_weights_path,
"text_encoder",
"model.fp16.safetensors",
)
clip_l_path = os.path.join(
diffusers_weights_path,
"text_encoder_2",
"model.fp16.safetensors",
)
t5xxl_path = os.path.join(
diffusers_weights_path,
"text_encoder_3",
"model.fp16.safetensors",
)
weights[key] = [
save_irpa(clip_g_path, "clip_g.transformer."),
save_irpa(clip_l_path, "clip_l.transformer."),
save_irpa(t5xxl_path, "t5xxl.transformer."),
]
elif key in ["vae_decode"] and weights[key] is None:
vae_weights_path = os.path.join(
diffusers_weights_path,
@@ -251,6 +281,7 @@ class StableDiffusion:
"diffusion_pytorch_model.safetensors",
)
weights[key] = save_irpa(vae_weights_path, "vae.")
progress(0.25, desc=f"Preparing pipeline for {self.ui_device}...")
vmfbs, weights = self.sd_pipe.check_prepared(
@@ -291,49 +322,6 @@ class StableDiffusion:
return img
def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
print("\n[LOG] Submitting Request...")
for key in sd_kwargs:
if sd_kwargs[key] in [None, []]:
sd_kwargs[key] = None
if sd_kwargs[key] in ["None"]:
sd_kwargs[key] = ""
if key in ["steps", "height", "width", "batch_count", "batch_size"]:
sd_kwargs[key] = int(sd_kwargs[key])
if key == "seed":
sd_kwargs[key] = int(sd_kwargs[key])
# TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
if not sd_kwargs["device"]:
gr.Warning("No device specified. Please specify a device.")
return None, ""
if sd_kwargs["height"] not in [512, 1024]:
gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
return None, ""
if sd_kwargs["height"] != sd_kwargs["width"]:
gr.Warning("Height and width must be the same. This is a temporary limitation.")
return None, ""
if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
if sd_kwargs["steps"] > 10:
gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
return None, ""
if sd_kwargs["guidance_scale"] > 3:
gr.Warning(
"sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
)
return None, ""
if sd_kwargs["target_triple"] == "":
if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2]:
gr.Warning(
"Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
)
return None, ""
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
return generated_imgs
def shark_sd_fn(
prompt,
negative_prompt,
@@ -359,7 +347,8 @@ def shark_sd_fn(
controlnets: dict,
embeddings: dict,
seed_increment: str | int = 1,
progress=gr.Progress(),
output_type: str = "png",
# progress=gr.Progress(),
):
sd_kwargs = locals()
if not isinstance(sd_init_image, list):
@@ -464,8 +453,8 @@ def shark_sd_fn(
if submit_run_kwargs["seed"] in [-1, "-1"]:
submit_run_kwargs["seed"] = randint(0, 4294967295)
seed_increment = "random"
#print(f"\n[LOG] Random seed: {seed}")
progress(None, desc=f"Generating...")
# print(f"\n[LOG] Random seed: {seed}")
# progress(None, desc=f"Generating...")
for current_batch in range(batch_count):
start_time = time.time()
@@ -479,13 +468,14 @@ def shark_sd_fn(
# break
# else:
for batch in range(batch_size):
save_output_img(
out_imgs[batch],
seed,
sd_kwargs,
)
if output_type == "png":
save_output_img(
out_imgs[batch],
seed,
sd_kwargs,
)
generated_imgs.extend(out_imgs)
yield generated_imgs, status_label(
"Stable Diffusion", current_batch + 1, batch_count, batch_size
)
@@ -495,13 +485,56 @@ def shark_sd_fn(
return (generated_imgs, "")
def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
print("\n[LOG] Submitting Request...")
for key in sd_kwargs:
if sd_kwargs[key] in [None, []]:
sd_kwargs[key] = None
if sd_kwargs[key] in ["None"]:
sd_kwargs[key] = ""
if key in ["steps", "height", "width", "batch_count", "batch_size"]:
sd_kwargs[key] = int(sd_kwargs[key])
if key == "seed":
sd_kwargs[key] = int(sd_kwargs[key])
# TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
if not sd_kwargs["device"]:
gr.Warning("No device specified. Please specify a device.")
return None, ""
if sd_kwargs["height"] not in [512, 1024]:
gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
return None, ""
if sd_kwargs["height"] != sd_kwargs["width"]:
gr.Warning("Height and width must be the same. This is a temporary limitation.")
return None, ""
if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
if sd_kwargs["steps"] > 10:
gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
return None, ""
if sd_kwargs["guidance_scale"] > 3:
gr.Warning(
"sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
)
return None, ""
if sd_kwargs["target_triple"] == "":
if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2]:
gr.Warning(
"Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
)
return None, ""
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
return generated_imgs
def get_next_seed(seed, seed_increment: str | int = 10):
if isinstance(seed_increment, int):
#print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}")
# print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}")
return int(seed + seed_increment)
elif seed_increment == "random":
seed = randint(0, 4294967295)
#print(f"\n[LOG] Random seed: {seed}")
# print(f"\n[LOG] Random seed: {seed}")
return seed
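
A quick usage sketch for the seed helper defined above: an integer increment is added to the current seed, while "random" re-rolls a fresh 32-bit seed. The import path follows the "from apps.shark_studio.api.sd import shark_sd_fn" line later in this commit; treating get_next_seed as importable from the same module is an assumption.

# Usage sketch for get_next_seed as shown in the diff above.
from apps.shark_studio.api.sd import get_next_seed  # module path assumed

print(get_next_seed(100, 10))        # -> 110
print(get_next_seed(100, "random"))  # -> a fresh seed in [0, 4294967295]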

View File

@@ -63,9 +63,9 @@ _IREE_TARGET_MAP = {
}
def get_available_devices():
return ['rocm', 'cpu']
return ["rocm", "cpu"]
def get_devices_by_name(driver_name):
device_list = []
@@ -94,7 +94,7 @@ def get_available_devices():
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list
#set_iree_runtime_flags()
# set_iree_runtime_flags()
available_devices = []
rocm_devices = get_devices_by_name("rocm")
@@ -140,17 +140,14 @@ def get_available_devices():
break
return available_devices
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by Studio pipeline
# Multiple devices only supported for vulkan and rocm (as of now).
# default device must be selected for all others
device_id = None
device = (
raw_device
if "=>" not in raw_device
else raw_device.split("=>")[1].strip()
)
device = raw_device if "=>" not in raw_device else raw_device.split("=>")[1].strip()
if "://" in device:
device, device_id = device.split("://")
if len(device_id) <= 2:
@@ -162,6 +159,7 @@ def clean_device_info(raw_device):
device_id = 0
return device, device_id
def parse_device(device_str, target_override=""):
rt_driver, device_id = clean_device_info(device_str)
@@ -287,4 +285,4 @@ def get_all_devices(driver_name):
# # Due to lack of support for multi-reduce, we always collapse reduction
# # dims before dispatch formation right now.
# iree_flags += ["--iree-flow-collapse-reduction-dims"]
# return iree_flags
# return iree_flags
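
For context, clean_device_info above normalizes the UI's device strings into a (driver, device_id) pair. A usage sketch, assuming the module path from the imports earlier in this diff; the expected outputs in the comments are inferred from the logic shown in the hunks above, not verified against the full function body.

from apps.shark_studio.api.utils import clean_device_info

# "Name => driver://index" strings are split into (driver, index); bare
# driver names fall back to device_id 0.
print(clean_device_info("AMD Radeon RX 7900 XTX => rocm://0"))  # ("rocm", 0)
print(clean_device_info("cpu"))                                 # ("cpu", 0)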

View File

@@ -597,7 +597,7 @@ p.add_argument(
"--defaults",
default="sdxl-turbo.json",
type=str,
help="Path to the default API request .json file. Works for CLI and webui."
help="Path to the default API request .json file. Works for CLI and webui.",
)
p.add_argument(

View File

@@ -0,0 +1,45 @@
# -*- mode: python ; coding: utf-8 -*-
from apps.shark_studio.studio_imports_apionly import pathex, datas, hiddenimports
binaries = []
block_cipher = None
a = Analysis(
['web/index.py'],
pathex=pathex,
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_sd3_server',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=False,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)
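
Besides the CI's pyinstaller call, the new spec can also be built from Python. A minimal sketch, assuming PyInstaller is installed in the active venv and the command is run from the repository root:

import PyInstaller.__main__

# Equivalent to: pyinstaller apps/shark_studio/shark_studio_apionly.spec
PyInstaller.__main__.run(["apps/shark_studio/shark_studio_apionly.spec"])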

View File

@@ -22,30 +22,25 @@ datas += copy_metadata("packaging")
datas += copy_metadata("filelock")
datas += copy_metadata("numpy")
datas += copy_metadata("importlib_metadata")
datas += copy_metadata("omegaconf")
datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += copy_metadata("gradio")
datas += copy_metadata("scipy")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("accelerate")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree", include_py_files=True)
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("shark-turbine", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("scipy", include_py_files=True)
datas += [
("web/ui/css/*", "ui/css"),
("web/ui/js/*", "ui/js"),
@@ -54,7 +49,7 @@ datas += [
# hidden imports for pyinstaller
hiddenimports = ["shark", "apps"]
hiddenimports = ["apps", "shark-turbine"]
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x]
blacklist = ["tests", "convert"]
@@ -65,4 +60,3 @@ hiddenimports += [
]
hiddenimports += [x for x in collect_submodules("iree") if "test" not in x]
hiddenimports += ["iree._runtime"]
hiddenimports += [x for x in collect_submodules("scipy") if "test" not in x]

View File

@@ -0,0 +1,46 @@
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules
import sys
sys.setrecursionlimit(sys.getrecursionlimit() * 5)
# python path for pyinstaller
pathex = [
".",
]
# datafiles for pyinstaller
datas = []
datas += copy_metadata("torch")
datas += copy_metadata("tokenizers")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
datas += copy_metadata("packaging")
datas += copy_metadata("filelock")
datas += copy_metadata("numpy")
datas += copy_metadata("importlib_metadata")
datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += copy_metadata("gradio")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("iree", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
# hidden imports for pyinstaller
hiddenimports = ["apps", "shark-turbine"]
hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "test" not in x]
hiddenimports += ["iree._runtime"]

View File

@@ -20,9 +20,6 @@ from fastapi.encoders import jsonable_encoder
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
# from sdapi_v1 import shark_sd_api
from apps.shark_studio.api.llm import llm_chat_api
def decode_base64_to_image(encoding):
if encoding.startswith("http://") or encoding.startswith("https://"):
@@ -183,50 +180,8 @@ class ApiCompat:
self.app = app
self.queue_lock = queue_lock
api_middleware(self.app)
# self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["POST"])
# self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["POST"])
# self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["POST"])
# self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
# self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
# self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
# self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
# self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
# self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
# self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
# self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
# self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
# self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
# self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
# self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
# self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
# self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
# self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
# self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
# self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
# self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
# self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
# self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
# self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
# self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
# self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
# self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
# self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
# self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
# self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
# self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
# chat APIs needed for compatibility with multiple extensions using OpenAI API
self.add_api_route("/v1/chat/completions", llm_chat_api, methods=["POST"])
self.add_api_route("/v1/completions", llm_chat_api, methods=["POST"])
self.add_api_route("/chat/completions", llm_chat_api, methods=["POST"])
self.add_api_route("/completions", llm_chat_api, methods=["POST"])
self.add_api_route(
"/v1/engines/codegen/completions", llm_chat_api, methods=["POST"]
)
# self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["POST"])
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
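
Since the diff above keeps only the OpenAI-style chat routes active, here is a hedged client example. The exact request schema accepted by llm_chat_api is not shown in this diff, so the body below simply follows the OpenAI chat-completions convention the comment refers to; host and port are assumed to match the REST test script later in this commit.

import requests

resp = requests.post(
    "http://127.0.0.1:8080/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "Hello!"}]},
)
print(resp.status_code, resp.text)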
@@ -234,27 +189,6 @@ class ApiCompat:
def add_api_route(self, path: str, endpoint, **kwargs):
return self.app.add_api_route(path, endpoint, **kwargs)
# def refresh_checkpoints(self):
# with self.queue_lock:
# studio_data.refresh_checkpoints()
# def refresh_vae(self):
# with self.queue_lock:
# studio_data.refresh_vae_list()
# def unloadapi(self):
# unload_model_weights()
# return {}
# def reloadapi(self):
# reload_model_weights()
# return {}
# def skip(self):
# studio.state.skip()
def launch(self, server_name, port, root_path):
self.app.include_router(self.router)
uvicorn.run(

View File

@@ -1 +1,115 @@
import base64
from fastapi import FastAPI
from io import BytesIO
from PIL import Image
from pydantic import BaseModel, Field
from fastapi.exceptions import HTTPException
from apps.shark_studio.api.sd import shark_sd_fn
sdapi = FastAPI()
class GenerationInputData(BaseModel):
prompt: list = [""]
negative_prompt: list = [""]
hf_model_id: str | None = None
height: int = Field(default=512, ge=128, le=1024, multiple_of=8)
width: int = Field(default=512, ge=128, le=1024, multiple_of=8)
sampler_name: str = "EulerDiscrete"
cfg_scale: float = Field(default=7.5, ge=1)
steps: int = Field(default=20, ge=1, le=100)
seed: int = Field(default=-1)
n_iter: int = Field(default=1)
config: dict = None
class GenerationResponseData(BaseModel):
images: list[str] = Field(description="Generated images, Base64 encoded")
properties: dict = {}
info: str
def encode_pil_to_base64(images: list[Image.Image]):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
image.save(output_bytes, format="PNG")
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
def decode_base64_to_image(encoding: str):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=400, detail="Invalid encoded image")
@sdapi.post(
"/v1/txt2img",
summary="Does text to image generation",
response_model=GenerationResponseData,
)
def txt2img_api(InputData: GenerationInputData):
model_id = (
InputData.hf_model_id or "stabilityai/stable-diffusion-3-medium-diffusers"
)
scheduler = "FlowEulerDiscrete"
print(
f"Prompt: {InputData.prompt}, "
f"Negative Prompt: {InputData.negative_prompt}, "
f"Seed: {InputData.seed},"
f"Model: {model_id}, "
f"Scheduler: {scheduler}. "
)
if not getattr(InputData, "config"):
InputData.config = {
"precision": "fp16",
"device": "rocm",
"target_triple": "gfx1150",
}
res = shark_sd_fn(
InputData.prompt,
InputData.negative_prompt,
None,
InputData.height,
InputData.width,
InputData.steps,
None,
InputData.cfg_scale,
InputData.seed,
custom_vae=None,
batch_count=InputData.n_iter,
batch_size=1,
scheduler=scheduler,
base_model_id=model_id,
custom_weights=None,
precision=InputData.config["precision"],
device=InputData.config["device"],
target_triple=InputData.config["target_triple"],
output_type="pil",
ondemand=False,
compiled_pipeline=False,
resample_type=None,
controlnets=[],
embeddings=[],
)
# Since we're not streaming we just want the last generator result
for items_so_far in res:
items = items_so_far
return {
"images": encode_pil_to_base64(items[0]),
"parameters": {},
"info": items[1],
}
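
On the client side, the base64 strings in the "images" list can be decoded back with the inverse of encode_pil_to_base64 above. A minimal sketch; the helper name is illustrative, not part of the commit:

import base64
from io import BytesIO
from PIL import Image

def decode_response_image(encoded: str) -> Image.Image:
    # Inverse of encode_pil_to_base64: base64 -> PNG bytes -> PIL image.
    return Image.open(BytesIO(base64.b64decode(encoded)))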

View File

@@ -32,13 +32,15 @@ def create_api(app):
def api_only():
from fastapi import FastAPI
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.web.api.sd import sdapi
from fastapi import FastAPI
initialize.initialize()
app = FastAPI()
initialize.setup_middleware(app)
app.mount("/sdapi/", sdapi)
api = create_api(app)
# from modules import script_callbacks
@@ -56,6 +58,7 @@ def api_only():
def launch_webui(address):
from tkinter import Tk
import webview
import gradio as gr
window = Tk()
@@ -83,7 +86,7 @@ def webui():
launch_api = cmd_opts.api
initialize.initialize()
#from ui.chat import chat_element
# from ui.chat import chat_element
from ui.sd import sd_element
from ui.outputgallery import outputgallery_element
@@ -216,7 +219,8 @@ def webui():
if __name__ == "__main__":
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
if cmd_opts.webui == False:
api_only()
else:
webui()
api_only()
# if cmd_opts.webui == False:
# api_only()
# else:
# webui()
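
The mount added in api_only() means routes defined on the sdapi sub-app are served under the /sdapi/ prefix. A self-contained illustration of that FastAPI behavior (the names here are placeholders, not from the commit); note that Starlette strips the trailing slash from mount paths, so "/sdapi/" and "/sdapi" behave the same:

from fastapi import FastAPI

inner = FastAPI()

@inner.post("/v1/txt2img")
def txt2img():
    return {"ok": True}

outer = FastAPI()
outer.mount("/sdapi/", inner)
# Serving `outer` now exposes POST /sdapi/v1/txt2img, matching the URL used
# by the REST test script later in this commit.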

View File

@@ -51,7 +51,7 @@ sd_default_models = [
# "stabilityai/stable-diffusion-2-1-base",
# "stabilityai/stable-diffusion-2-1",
# "stabilityai/stable-diffusion-xl-base-1.0",
#"stabilityai/sdxl-turbo",
# "stabilityai/sdxl-turbo",
]
sd_default_models.extend(get_checkpoints(model_type="scripts"))
@@ -154,7 +154,9 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
elif os.path.exists(os.path.join(get_configs_path(), load_sd_config)):
config = os.path.join(get_configs_path(), load_sd_config)
else:
print("Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config.")
print(
"Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config."
)
config = sd_json
new_sd_config = none_to_str_none(json.loads(view_json_file(config)))
if sd_json:
@@ -284,6 +286,7 @@ def base_model_changed(base_model_id):
new_steps,
]
init_config = global_obj.get_init_config()
init_config = none_to_str_none(json.loads(view_json_file(init_config)))
@@ -307,15 +310,17 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
show_copy_button=True,
)
with gr.Accordion(
label="\U0001F4D0\U0000FE0F Advanced Settings", open=False
label="\U0001F4D0\U0000FE0F Advanced Settings", open=False
):
with gr.Accordion(
label="Device Settings", open=False
):
with gr.Accordion(label="Device Settings", open=False):
device = gr.Dropdown(
elem_id="device",
label="Device",
value=init_config["device"] if init_config["device"] else "rocm",
value=(
init_config["device"]
if init_config["device"]
else "rocm"
),
choices=global_obj.get_device_list(),
allow_custom_value=True,
)
@@ -347,7 +352,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
value=512,
step=512,
label="\U00002195\U0000FE0F Height",
interactive=False, # DEMO
interactive=False, # DEMO
visible=False, # DEMO
)
width = gr.Slider(
@@ -356,10 +361,10 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
value=512,
step=512,
label="\U00002194\U0000FE0F Width",
interactive=False, # DEMO
interactive=False, # DEMO
visible=False, # DEMO
)
with gr.Accordion(
label="\U0001F9EA\U0000FE0F Input Image Processing",
open=False,
@@ -379,7 +384,9 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
allow_custom_value=True,
)
with gr.Row():
sd_model_info = f"Checkpoint Path: {str(get_checkpoints_path())}"
sd_model_info = (
f"Checkpoint Path: {str(get_checkpoints_path())}"
)
base_model_id = gr.Dropdown(
label="\U000026F0\U0000FE0F Base Model",
info="Select or enter HF model ID",
@@ -413,7 +420,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
)
guidance_scale = gr.Slider(
0,
5, #DEMO
5, # DEMO
value=4,
step=0.1,
label="\U0001F5C3\U0000FE0F CFG Scale",
@@ -444,9 +451,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
visible=False, # DEMO
)
with gr.Row(elem_classes=["fill"], visible=False):
Path(get_configs_path()).mkdir(
parents=True, exist_ok=True
)
Path(get_configs_path()).mkdir(parents=True, exist_ok=True)
write_default_sd_configs(get_configs_path())
default_config_file = global_obj.get_init_config()
sd_json = gr.JSON(
@@ -463,9 +468,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
visible=False,
)
with gr.Row():
save_sd_config = gr.Button(
value="Save Config", size="sm"
)
save_sd_config = gr.Button(value="Save Config", size="sm")
clear_sd_config = gr.ClearButton(
value="Clear Config",
size="sm",
@@ -514,7 +517,11 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
label=f"Standalone LoRA Weights",
info=sd_lora_info,
elem_id="lora_weights",
value=init_config["embeddings"][0] if (len(init_config["embeddings"].keys()) > 1) else "None",
value=(
init_config["embeddings"][0]
if (len(init_config["embeddings"].keys()) > 1)
else "None"
),
multiselect=True,
choices=[] + get_checkpoints("lora"),
scale=2,

View File

@@ -100,6 +100,7 @@ def get_checkpoints(model_type="checkpoints"):
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)
def get_configs():
return sorted(
[

View File

@@ -3,17 +3,21 @@ from ...api.utils import get_available_devices
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import os
from apps.shark_studio.web.utils.file_utils import get_configs_path
"""
The global objects include SD pipeline and config.
Maintaining the global objects would avoid creating extra pipeline objects when switching modes.
Also we could avoid memory leak when switching models by clearing the cache.
"""
def view_json_file(file_path):
content = ""
with open(file_path, "r") as fopen:
content = fopen.read()
return content
def _init():
global _sd_obj
global _llm_obj
@@ -95,6 +99,7 @@ def get_device_list():
global _devices
return _devices
def get_init_config():
global _init_config
if os.path.exists(cmd_opts.defaults):
@@ -102,10 +107,13 @@ def get_init_config():
elif os.path.exists(os.path.join(get_configs_path(), cmd_opts.defaults)):
_init_config = os.path.join(get_configs_path(), cmd_opts.defaults)
else:
print("Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config.")
print(
"Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config."
)
_init_config = os.path.join(get_configs_path(), "sdxl-turbo.json")
return _init_config
def get_sd_status():
global _sd_obj
return _sd_obj.status

View File

@@ -1,3 +1,5 @@
-r https://raw.githubusercontent.com/llvm/torch-mlir/main/requirements.txt
-r https://raw.githubusercontent.com/llvm/torch-mlir/main/torchvision-requirements.txt
-f https://download.pytorch.org/whl/nightly/cpu
-f https://iree.dev/pip-release-links.html
--pre
@@ -5,40 +7,19 @@
setuptools
wheel
torch==2.3.0
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-unify-sd#subdirectory=models
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@merge_punet_sdxl#subdirectory=models
diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
# SHARK Runner
tqdm
# SHARK Downloader
google-cloud-storage
# Testing
pytest
Pillow
parameterized
# Add transformers, diffusers and scipy since it most commonly used
#accelerate is now required for diffusers import from ckpt.
accelerate
scipy
transformers==4.37.1
torchsde # Required for Stable Diffusion SDE schedulers.
transformers==4.43.3
ftfy
gradio==4.29.0
altair
omegaconf
# 0.3.2 doesn't have binaries for arm64
safetensors==0.3.1
safetensors
py-cpuinfo
pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
mpmath==1.3.0
optimum
# Testing
pytest
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile

View File

@@ -0,0 +1,77 @@
import requests
from pydantic import BaseModel, Field
import json
def view_json_file(file_path):
content = ""
with open(file_path, "r") as fopen:
content = fopen.read()
return content
# Define the URL of the REST API endpoint
api_url = "http://127.0.0.1:8080/sdapi/v1/txt2img/" # Replace with your actual API URL
class GenerationInputData(BaseModel):
prompt: list = [""]
negative_prompt: list = [""]
hf_model_id: str | None = None
height: int = Field(default=512, ge=128, le=1024, multiple_of=8)
width: int = Field(default=512, ge=128, le=1024, multiple_of=8)
sampler_name: str = "EulerDiscrete"
cfg_scale: float = Field(default=7.5, ge=1)
steps: int = Field(default=20, ge=1, le=100)
seed: int = Field(default=-1)
n_iter: int = Field(default=1)
config: dict = None
# Create an instance of GenerationInputData with example arguments
data = GenerationInputData(
prompt=[
"A phoenix made of diamond, black background, dream sequence, rising from coals"
],
negative_prompt=[
"cropped, cartoon, lowres, low quality, black and white, bad scan, pixelated"
],
hf_model_id="shark_sd3.py",
height=512,
width=512,
sampler_name="EulerDiscrete",
cfg_scale=7.5,
steps=20,
seed=-1,
n_iter=1,
config=json.loads(view_json_file("../configs/sd3_phoenix_npu.json")),
)
# Convert the data to a dictionary
data_dict = data.dict()
# Optional: Define headers if needed (e.g., for authentication)
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
def test_post_request(url, data, headers=None):
try:
# Send a POST request to the API endpoint
response = requests.post(url, json=data, headers=headers)
# Print the status code and response content
print(f"Status Code: {response.status_code}")
print("Response Content:")
# print(response.json()) # Print the JSON response
except requests.RequestException as e:
# Handle any exceptions that occur during the request
print(f"An error occurred: {e}")
# Run the test
test_post_request(api_url, data_dict, headers)
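
A possible extension of the test above (an assumption, not part of the commit): decode and save the images from a successful response, reusing api_url, data_dict, and headers from the script.

import base64
import requests  # already imported by the script above

response = requests.post(api_url, json=data_dict, headers=headers)
# Each entry in "images" is a base64-encoded PNG, per GenerationResponseData.
for i, b64_img in enumerate(response.json().get("images", [])):
    with open(f"out_{i}.png", "wb") as f:
        f.write(base64.b64decode(b64_img))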

View File

@@ -87,9 +87,8 @@ if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark.venv\}
else {python -m venv .\shark.venv\}
.\shark.venv\Scripts\activate
python -m pip install --upgrade pip
pip install wheel
pip install https://github.com/nod-ai/SRT/releases/download/candidate-20240619.291/iree_compiler-20240619.291-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240619.291/iree_runtime-20240619.291-cp311-cp311-win_amd64.whl
pip install --pre -r requirements.txt
pip install https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_compiler-20240602.283-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_runtime-20240602.283-cp311-cp311-win_amd64.whl
pip install -e .
Write-Host "Source your venv with ./shark.venv/Scripts/activate"

webui_requirements.txt (new file, 1 line)
View File

@@ -0,0 +1 @@
gradio==4.29.0