Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-09 13:57:54 -05:00)

Commit: Rest API support and cleanup
.github/workflows/nightly.yml (vendored, 6 changes)
@@ -46,7 +46,7 @@ jobs:
           draft: true
           prerelease: true

-      - name: Build Package
+      - name: Build Package (api only)
        shell: powershell
        run: |
          ./setup_venv.ps1
@@ -54,10 +54,10 @@ jobs:
          $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
          pip install -e .
          pip freeze -l
-         pyinstaller .\apps\shark_studio\shark_studio.spec
+         pyinstaller .\apps\shark_studio\shark_studio_apionly.spec
          mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
          signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe

      - name: Upload Release Assets
        id: upload-release-assets
        uses: dwenegar/upload-release-assets@v1
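For reproducing the new api-only build step locally, a minimal sketch using PyInstaller's documented Python entry point (assumes the repo root as working directory and a venv already prepared, e.g. by setup_venv.ps1):

```python
# Hedged local equivalent of the "Build Package (api only)" CI step.
# Assumes project dependencies are installed in the active environment.
import PyInstaller.__main__

PyInstaller.__main__.run(
    ["apps/shark_studio/shark_studio_apionly.spec"]
)
```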
@@ -34,9 +34,9 @@ def imports():
         action="ignore", category=UserWarning, module="huggingface-hub"
     )

-    import gradio  # noqa: F401
+    # import gradio  # noqa: F401

-    startup_timer.record("import gradio")
+    # startup_timer.record("import gradio")

     import apps.shark_studio.web.utils.globals as global_obj
||||
@@ -56,9 +56,8 @@ def initialize():
     # existing temporary images there if they exist. Then we can import gradio.
     # It has to be in this order or gradio ignores what we've set up.

-    config_tmp()
-    # clear_tmp_mlir()
-    clear_tmp_imgs()
+    # config_tmp()
+    # clear_tmp_imgs()

     from apps.shark_studio.web.utils.file_utils import (
         create_model_folders,
@@ -67,8 +66,6 @@ def initialize():
     # Create custom models folders if they don't exist
     create_model_folders()

-    import gradio as gr
-
-    # initialize_rest(reload_script_modules=False)
||||
@@ -14,7 +14,6 @@ from pathlib import Path
 from random import randint

-from apps.shark_studio.api.controlnet import control_adapter_map
 from apps.shark_studio.api.utils import parse_device
 from apps.shark_studio.web.utils.state import status_label
@@ -30,6 +29,7 @@ from apps.shark_studio.modules.img_processing import (

 from subprocess import check_output

 EMPTY_SD_MAP = {
     "clip": None,
     "scheduler": None,
@@ -114,11 +114,14 @@ class StableDiffusion:
             from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
                 SharkSDXLPipeline,
             )

             self.turbine_pipe = SharkSDXLPipeline
             self.dynamic_steps = False
             self.model_map = EMPTY_SDXL_MAP
         else:
-            from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
+            from turbine_models.custom_models.sd_inference.sd_pipeline import (
+                SharkSDPipeline,
+            )

             self.turbine_pipe = SharkSDPipeline
             self.dynamic_steps = True
@@ -209,6 +212,7 @@ class StableDiffusion:
             preprocessCKPT,
             save_irpa,
         )

         custom_weights = os.path.join(
             get_checkpoints_path("checkpoints"),
             safe_name(self.base_model_id.split("/")[-1]),
@@ -223,14 +227,20 @@ class StableDiffusion:
                     "diffusion_pytorch_model.safetensors",
                 )
                 weights[key] = save_irpa(unet_weights_path, "unet.")

-            elif key in ["clip", "prompt_encoder"]:
-                if not self.is_sdxl:
+            if key in ["mmdit"]:
+                mmdit_weights_path = os.path.join(
+                    diffusers_weights_path,
+                    "mmdit",
+                    "diffusion_pytorch_model_fp16.safetensors",
+                )
+                weights[key] = save_irpa(mmdit_weights_path, "mmdit.")
+            elif key in ["clip", "prompt_encoder", "text_encoder"]:
+                if not self.is_sdxl and not self.is_custom:
                     sd1_path = os.path.join(
                         diffusers_weights_path, "text_encoder", "model.safetensors"
                     )
                     weights[key] = save_irpa(sd1_path, "text_encoder_model.")
-                else:
+                elif self.is_sdxl:
                     clip_1_path = os.path.join(
                         diffusers_weights_path, "text_encoder", "model.safetensors"
                     )
@@ -243,7 +253,27 @@ class StableDiffusion:
                         save_irpa(clip_1_path, "text_encoder_model_1."),
                         save_irpa(clip_2_path, "text_encoder_model_2."),
                     ]
+                elif self.is_custom:
+                    clip_g_path = os.path.join(
+                        diffusers_weights_path,
+                        "text_encoder",
+                        "model.fp16.safetensors",
+                    )
+                    clip_l_path = os.path.join(
+                        diffusers_weights_path,
+                        "text_encoder_2",
+                        "model.fp16.safetensors",
+                    )
+                    t5xxl_path = os.path.join(
+                        diffusers_weights_path,
+                        "text_encoder_3",
+                        "model.fp16.safetensors",
+                    )
+                    weights[key] = [
+                        save_irpa(clip_g_path, "clip_g.transformer."),
+                        save_irpa(clip_l_path, "clip_l.transformer."),
+                        save_irpa(t5xxl_path, "t5xxl.transformer."),
+                    ]
             elif key in ["vae_decode"] and weights[key] is None:
                 vae_weights_path = os.path.join(
                     diffusers_weights_path,
@@ -251,6 +281,7 @@ class StableDiffusion:
                     "diffusion_pytorch_model.safetensors",
                 )
                 weights[key] = save_irpa(vae_weights_path, "vae.")

         progress(0.25, desc=f"Preparing pipeline for {self.ui_device}...")

         vmfbs, weights = self.sd_pipe.check_prepared(
@@ -291,49 +322,6 @@ class StableDiffusion:
     return img


-def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
-    print("\n[LOG] Submitting Request...")
-
-    for key in sd_kwargs:
-        if sd_kwargs[key] in [None, []]:
-            sd_kwargs[key] = None
-        if sd_kwargs[key] in ["None"]:
-            sd_kwargs[key] = ""
-        if key in ["steps", "height", "width", "batch_count", "batch_size"]:
-            sd_kwargs[key] = int(sd_kwargs[key])
-        if key == "seed":
-            sd_kwargs[key] = int(sd_kwargs[key])
-
-    # TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
-    if not sd_kwargs["device"]:
-        gr.Warning("No device specified. Please specify a device.")
-        return None, ""
-    if sd_kwargs["height"] not in [512, 1024]:
-        gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
-        return None, ""
-    if sd_kwargs["height"] != sd_kwargs["width"]:
-        gr.Warning("Height and width must be the same. This is a temporary limitation.")
-        return None, ""
-    if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
-        if sd_kwargs["steps"] > 10:
-            gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
-            return None, ""
-        if sd_kwargs["guidance_scale"] > 3:
-            gr.Warning(
-                "sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
-            )
-            return None, ""
-    if sd_kwargs["target_triple"] == "":
-        if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2]:
-            gr.Warning(
-                "Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
-            )
-            return None, ""
-
-    generated_imgs = yield from shark_sd_fn(**sd_kwargs)
-    return generated_imgs
-
-
 def shark_sd_fn(
     prompt,
     negative_prompt,
@@ -359,7 +347,8 @@ def shark_sd_fn(
     controlnets: dict,
     embeddings: dict,
     seed_increment: str | int = 1,
-    progress=gr.Progress(),
+    output_type: str = "png",
+    # progress=gr.Progress(),
 ):
     sd_kwargs = locals()
     if not isinstance(sd_init_image, list):
@@ -464,8 +453,8 @@ def shark_sd_fn(
     if submit_run_kwargs["seed"] in [-1, "-1"]:
         submit_run_kwargs["seed"] = randint(0, 4294967295)
         seed_increment = "random"
-        #print(f"\n[LOG] Random seed: {seed}")
-    progress(None, desc=f"Generating...")
+        # print(f"\n[LOG] Random seed: {seed}")
+    # progress(None, desc=f"Generating...")

     for current_batch in range(batch_count):
         start_time = time.time()
@@ -479,13 +468,14 @@ def shark_sd_fn(
             # break
             # else:
         for batch in range(batch_size):
-            save_output_img(
-                out_imgs[batch],
-                seed,
-                sd_kwargs,
-            )
+            if output_type == "png":
+                save_output_img(
+                    out_imgs[batch],
+                    seed,
+                    sd_kwargs,
+                )
         generated_imgs.extend(out_imgs)

         yield generated_imgs, status_label(
             "Stable Diffusion", current_batch + 1, batch_count, batch_size
         )
@@ -495,13 +485,56 @@ def shark_sd_fn(
     return (generated_imgs, "")


+def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
+    print("\n[LOG] Submitting Request...")
+
+    for key in sd_kwargs:
+        if sd_kwargs[key] in [None, []]:
+            sd_kwargs[key] = None
+        if sd_kwargs[key] in ["None"]:
+            sd_kwargs[key] = ""
+        if key in ["steps", "height", "width", "batch_count", "batch_size"]:
+            sd_kwargs[key] = int(sd_kwargs[key])
+        if key == "seed":
+            sd_kwargs[key] = int(sd_kwargs[key])
+
+    # TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
+    if not sd_kwargs["device"]:
+        gr.Warning("No device specified. Please specify a device.")
+        return None, ""
+    if sd_kwargs["height"] not in [512, 1024]:
+        gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
+        return None, ""
+    if sd_kwargs["height"] != sd_kwargs["width"]:
+        gr.Warning("Height and width must be the same. This is a temporary limitation.")
+        return None, ""
+    if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
+        if sd_kwargs["steps"] > 10:
+            gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
+            return None, ""
+        if sd_kwargs["guidance_scale"] > 3:
+            gr.Warning(
+                "sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
+            )
+            return None, ""
+    if sd_kwargs["target_triple"] == "":
+        if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2]:
+            gr.Warning(
+                "Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
+            )
+            return None, ""
+
+    generated_imgs = yield from shark_sd_fn(**sd_kwargs)
+    return generated_imgs
+
+
 def get_next_seed(seed, seed_increment: str | int = 10):
     if isinstance(seed_increment, int):
-        #print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}")
+        # print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}")
         return int(seed + seed_increment)
     elif seed_increment == "random":
         seed = randint(0, 4294967295)
-        #print(f"\n[LOG] Random seed: {seed}")
+        # print(f"\n[LOG] Random seed: {seed}")
         return seed
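Both shark_sd_fn and the relocated shark_sd_fn_dict_input rely on generator delegation: `yield from` streams intermediate (images, status) tuples to the caller while the inner generator's return value is captured by the wrapper. A self-contained sketch of that control flow (names here are illustrative, not repo APIs):

```python
# Illustrative sketch of the `yield from` pattern used by shark_sd_fn_dict_input.
def inner(steps: int):
    for step in range(steps):
        yield f"step {step + 1}/{steps}"      # streamed progress, like status_label
    return ["image_0"]                        # final result, like generated_imgs

def dict_input(kwargs: dict):
    kwargs["steps"] = int(kwargs["steps"])    # coercion, as in the wrapper above
    result = yield from inner(**kwargs)
    return result

gen = dict_input({"steps": "3"})
try:
    while True:
        print(next(gen))
except StopIteration as stop:
    print("result:", stop.value)              # the generator's return value
```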
@@ -63,9 +63,9 @@ _IREE_TARGET_MAP = {
 }


 def get_available_devices():
-    return ['rocm', 'cpu']
+    return ["rocm", "cpu"]

     def get_devices_by_name(driver_name):
         device_list = []
@@ -94,7 +94,7 @@ def get_available_devices():
             device_list.append(f"{device_name} => {driver_name}://{i}")
         return device_list

-    #set_iree_runtime_flags()
+    # set_iree_runtime_flags()

     available_devices = []
     rocm_devices = get_devices_by_name("rocm")
@@ -140,17 +140,14 @@ def get_available_devices():
             break
     return available_devices


 def clean_device_info(raw_device):
     # return appropriate device and device_id for consumption by Studio pipeline
     # Multiple devices only supported for vulkan and rocm (as of now).
     # default device must be selected for all others

     device_id = None
-    device = (
-        raw_device
-        if "=>" not in raw_device
-        else raw_device.split("=>")[1].strip()
-    )
+    device = raw_device if "=>" not in raw_device else raw_device.split("=>")[1].strip()
     if "://" in device:
         device, device_id = device.split("://")
         if len(device_id) <= 2:
@@ -162,6 +159,7 @@ def clean_device_info(raw_device):
         device_id = 0
     return device, device_id


 def parse_device(device_str, target_override=""):
     rt_driver, device_id = clean_device_info(device_str)
@@ -287,4 +285,4 @@ def get_all_devices(driver_name):
 # # Due to lack of support for multi-reduce, we always collapse reduction
 # # dims before dispatch formation right now.
 # iree_flags += ["--iree-flow-collapse-reduction-dims"]
-# return iree_flags
+# return iree_flags
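The device strings produced by get_available_devices follow a "Device Name => driver://index" convention, which clean_device_info then unpacks. A self-contained illustration of that split (example strings and the stringly-typed index are inferred from the code above, so treat them as assumptions):

```python
# Hedged illustration of the device-string convention consumed by clean_device_info.
examples = [
    "AMD Radeon RX 7900 XTX => rocm://0",  # UI-style entry from get_device_list
    "rocm://1",
    "cpu",
]
for raw in examples:
    device = raw if "=>" not in raw else raw.split("=>")[1].strip()
    device_id = None
    if "://" in device:
        device, device_id = device.split("://")  # note: id stays a string here
    print(device, device_id)  # rocm 0 / rocm 1 / cpu None
```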
@@ -597,7 +597,7 @@ p.add_argument(
     "--defaults",
     default="sdxl-turbo.json",
     type=str,
-    help="Path to the default API request .json file. Works for CLI and webui."
+    help="Path to the default API request .json file. Works for CLI and webui.",
 )

 p.add_argument(
apps/shark_studio/shark_studio_apionly.spec (new file, 45 lines)
@@ -0,0 +1,45 @@
+# -*- mode: python ; coding: utf-8 -*-
+from apps.shark_studio.studio_imports_apionly import pathex, datas, hiddenimports
+
+binaries = []
+
+block_cipher = None
+
+a = Analysis(
+    ['web/index.py'],
+    pathex=pathex,
+    binaries=binaries,
+    datas=datas,
+    hiddenimports=hiddenimports,
+    hookspath=[],
+    hooksconfig={},
+    runtime_hooks=[],
+    excludes=[],
+    win_no_prefer_redirects=False,
+    win_private_assemblies=False,
+    cipher=block_cipher,
+    noarchive=False,
+)
+pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
+
+exe = EXE(
+    pyz,
+    a.scripts,
+    a.binaries,
+    a.zipfiles,
+    a.datas,
+    [],
+    name='shark_sd3_server',
+    debug=False,
+    bootloader_ignore_signals=False,
+    strip=False,
+    upx=False,
+    upx_exclude=[],
+    runtime_tmpdir=None,
+    console=True,
+    disable_windowed_traceback=False,
+    argv_emulation=False,
+    target_arch=None,
+    codesign_identity=None,
+    entitlements_file=None,
+)
@@ -22,30 +22,25 @@ datas += copy_metadata("packaging")
 datas += copy_metadata("filelock")
 datas += copy_metadata("numpy")
 datas += copy_metadata("importlib_metadata")
 datas += copy_metadata("omegaconf")
 datas += copy_metadata("safetensors")
 datas += copy_metadata("Pillow")
 datas += copy_metadata("sentencepiece")
 datas += copy_metadata("pyyaml")
 datas += copy_metadata("huggingface-hub")
 datas += copy_metadata("gradio")
 datas += copy_metadata("scipy")
 datas += collect_data_files("torch")
 datas += collect_data_files("tokenizers")
 datas += collect_data_files("accelerate")
 datas += collect_data_files("diffusers")
 datas += collect_data_files("transformers")
 datas += collect_data_files("gradio")
 datas += collect_data_files("gradio_client")
 datas += collect_data_files("iree", include_py_files=True)
 datas += collect_data_files("shark", include_py_files=True)
 datas += collect_data_files("shark-turbine", include_py_files=True)
 datas += collect_data_files("tqdm")
 datas += collect_data_files("tkinter")
 datas += collect_data_files("sentencepiece")
 datas += collect_data_files("jsonschema")
 datas += collect_data_files("jsonschema_specifications")
 datas += collect_data_files("cpuinfo")
 datas += collect_data_files("scipy", include_py_files=True)
 datas += [
     ("web/ui/css/*", "ui/css"),
     ("web/ui/js/*", "ui/js"),
@@ -54,7 +49,7 @@ datas += [

 # hidden imports for pyinstaller
-hiddenimports = ["shark", "apps"]
+hiddenimports = ["apps", "shark-turbine"]
 hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
 hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x]
 blacklist = ["tests", "convert"]
@@ -65,4 +60,3 @@ hiddenimports += [
 ]
 hiddenimports += [x for x in collect_submodules("iree") if "test" not in x]
 hiddenimports += ["iree._runtime"]
-hiddenimports += [x for x in collect_submodules("scipy") if "test" not in x]
apps/shark_studio/studio_imports_apionly.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+from PyInstaller.utils.hooks import collect_data_files
+from PyInstaller.utils.hooks import copy_metadata
+from PyInstaller.utils.hooks import collect_submodules
+
+import sys
+
+sys.setrecursionlimit(sys.getrecursionlimit() * 5)
+
+# python path for pyinstaller
+pathex = [
+    ".",
+]
+
+# datafiles for pyinstaller
+datas = []
+datas += copy_metadata("torch")
+datas += copy_metadata("tokenizers")
+datas += copy_metadata("tqdm")
+datas += copy_metadata("regex")
+datas += copy_metadata("requests")
+datas += copy_metadata("packaging")
+datas += copy_metadata("filelock")
+datas += copy_metadata("numpy")
+datas += copy_metadata("importlib_metadata")
+datas += copy_metadata("safetensors")
+datas += copy_metadata("Pillow")
+datas += copy_metadata("sentencepiece")
+datas += copy_metadata("pyyaml")
+datas += copy_metadata("huggingface-hub")
+datas += copy_metadata("gradio")
+datas += collect_data_files("torch")
+datas += collect_data_files("tokenizers")
+datas += collect_data_files("diffusers")
+datas += collect_data_files("transformers")
+datas += collect_data_files("iree", include_py_files=True)
+datas += collect_data_files("tqdm")
+datas += collect_data_files("jsonschema")
+datas += collect_data_files("jsonschema_specifications")
+datas += collect_data_files("cpuinfo")
+
+
+# hidden imports for pyinstaller
+hiddenimports = ["apps", "shark-turbine"]
+hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x]
+hiddenimports += [x for x in collect_submodules("iree") if "test" not in x]
+hiddenimports += ["iree._runtime"]
@@ -20,9 +20,6 @@ from fastapi.encoders import jsonable_encoder

 from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

-# from sdapi_v1 import shark_sd_api
 from apps.shark_studio.api.llm import llm_chat_api


 def decode_base64_to_image(encoding):
     if encoding.startswith("http://") or encoding.startswith("https://"):
@@ -183,50 +180,8 @@ class ApiCompat:
         self.app = app
         self.queue_lock = queue_lock
         api_middleware(self.app)
-        # self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
-        # self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
-        # self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
-        # self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
-        # self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
-        # self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
-        # self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
-        # self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
-        # self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
-        # self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
-        # self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
-        # self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
-        # self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
-        # self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
-        # self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
-        # self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
-        # self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
-        # self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
-        # self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
-        # self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
-        # self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
-        # self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
-        # self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
-        # self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])

         # chat APIs needed for compatibility with multiple extensions using OpenAI API
         self.add_api_route("/v1/chat/completions", llm_chat_api, methods=["POST"])
         self.add_api_route("/v1/completions", llm_chat_api, methods=["POST"])
         self.add_api_route("/chat/completions", llm_chat_api, methods=["POST"])
         self.add_api_route("/completions", llm_chat_api, methods=["POST"])
         self.add_api_route(
             "/v1/engines/codegen/completions", llm_chat_api, methods=["POST"]
         )
+        # self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["POST"])

         self.default_script_arg_txt2img = []
         self.default_script_arg_img2img = []
@@ -234,27 +189,6 @@ class ApiCompat:
     def add_api_route(self, path: str, endpoint, **kwargs):
         return self.app.add_api_route(path, endpoint, **kwargs)

-    # def refresh_checkpoints(self):
-    #     with self.queue_lock:
-    #         studio_data.refresh_checkpoints()
-
-    # def refresh_vae(self):
-    #     with self.queue_lock:
-    #         studio_data.refresh_vae_list()
-
-    # def unloadapi(self):
-    #     unload_model_weights()
-
-    #     return {}
-
-    # def reloadapi(self):
-    #     reload_model_weights()
-
-    #     return {}
-
-    # def skip(self):
-    #     studio.state.skip()
-
     def launch(self, server_name, port, root_path):
         self.app.include_router(self.router)
         uvicorn.run(
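For the OpenAI-compatible chat routes kept above, a hedged smoke test. The payload follows the OpenAI chat-completions shape; the host/port, the model name, and whether llm_chat_api honors every field are assumptions here:

```python
# Hedged smoke test for the /v1/chat/completions route registered above.
import requests

resp = requests.post(
    "http://127.0.0.1:8080/v1/chat/completions",
    json={
        "model": "llama2_7b",  # assumed model identifier
        "messages": [{"role": "user", "content": "Say hello."}],
    },
    timeout=600,
)
print(resp.status_code)
print(resp.json())
```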
@@ -1 +1,115 @@
 import base64
+
+from fastapi import FastAPI
+
+from io import BytesIO
+from PIL import Image
+from pydantic import BaseModel, Field
+from fastapi.exceptions import HTTPException
+
+from apps.shark_studio.api.sd import shark_sd_fn
+
+sdapi = FastAPI()
+
+
+class GenerationInputData(BaseModel):
+    prompt: list = [""]
+    negative_prompt: list = [""]
+    hf_model_id: str | None = None
+    height: int = Field(default=512, ge=128, le=1024, multiple_of=8)
+    width: int = Field(default=512, ge=128, le=1024, multiple_of=8)
+    sampler_name: str = "EulerDiscrete"
+    cfg_scale: float = Field(default=7.5, ge=1)
+    steps: int = Field(default=20, ge=1, le=100)
+    seed: int = Field(default=-1)
+    n_iter: int = Field(default=1)
+    config: dict = None
+
+
+class GenerationResponseData(BaseModel):
+    images: list[str] = Field(description="Generated images, Base64 encoded")
+    properties: dict = {}
+    info: str
+
+
+def encode_pil_to_base64(images: list[Image.Image]):
+    encoded_imgs = []
+    for image in images:
+        with BytesIO() as output_bytes:
+            image.save(output_bytes, format="PNG")
+            bytes_data = output_bytes.getvalue()
+        encoded_imgs.append(base64.b64encode(bytes_data))
+    return encoded_imgs
+
+
+def decode_base64_to_image(encoding: str):
+    if encoding.startswith("data:image/"):
+        encoding = encoding.split(";", 1)[1].split(",", 1)[1]
+    try:
+        image = Image.open(BytesIO(base64.b64decode(encoding)))
+        return image
+    except Exception as err:
+        print(err)
+        raise HTTPException(status_code=400, detail="Invalid encoded image")
+
+
+@sdapi.post(
+    "/v1/txt2img",
+    summary="Does text to image generation",
+    response_model=GenerationResponseData,
+)
+def txt2img_api(InputData: GenerationInputData):
+    model_id = (
+        InputData.hf_model_id or "stabilityai/stable-diffusion-3-medium-diffusers"
+    )
+    scheduler = "FlowEulerDiscrete"
+    print(
+        f"Prompt: {InputData.prompt}, "
+        f"Negative Prompt: {InputData.negative_prompt}, "
+        f"Seed: {InputData.seed},"
+        f"Model: {model_id}, "
+        f"Scheduler: {scheduler}. "
+    )
+    if not getattr(InputData, "config"):
+        InputData.config = {
+            "precision": "fp16",
+            "device": "rocm",
+            "target_triple": "gfx1150",
+        }
+
+    res = shark_sd_fn(
+        InputData.prompt,
+        InputData.negative_prompt,
+        None,
+        InputData.height,
+        InputData.width,
+        InputData.steps,
+        None,
+        InputData.cfg_scale,
+        InputData.seed,
+        custom_vae=None,
+        batch_count=InputData.n_iter,
+        batch_size=1,
+        scheduler=scheduler,
+        base_model_id=model_id,
+        custom_weights=None,
+        precision=InputData.config["precision"],
+        device=InputData.config["device"],
+        target_triple=InputData.config["target_triple"],
+        output_type="pil",
+        ondemand=False,
+        compiled_pipeline=False,
+        resample_type=None,
+        controlnets=[],
+        embeddings=[],
+    )
+
+    # Since we're not streaming we just want the last generator result
+    for items_so_far in res:
+        items = items_so_far
+
+    return {
+        "images": encode_pil_to_base64(items[0]),
+        "parameters": {},
+        "info": items[1],
+    }
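A self-contained round-trip check of the base64 contract the helpers above implement (standalone re-implementation for illustration; the app's functions follow the same encode/decode path):

```python
# Round-trip sketch: PNG-encode a PIL image to base64, then decode it back.
import base64
from io import BytesIO
from PIL import Image

img = Image.new("RGB", (8, 8), "red")
buf = BytesIO()
img.save(buf, format="PNG")
encoded = base64.b64encode(buf.getvalue())  # what encode_pil_to_base64 emits per image
restored = Image.open(BytesIO(base64.b64decode(encoded)))
assert restored.size == (8, 8)
```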
@@ -32,13 +32,15 @@ def create_api(app):


 def api_only():
-    from fastapi import FastAPI
     from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+    from apps.shark_studio.web.api.sd import sdapi
+    from fastapi import FastAPI

     initialize.initialize()

     app = FastAPI()
     initialize.setup_middleware(app)
+    app.mount("/sdapi/", sdapi)
     api = create_api(app)

     # from modules import script_callbacks
@@ -56,6 +58,7 @@ def api_only():
 def launch_webui(address):
     from tkinter import Tk
     import webview
     import gradio as gr

     window = Tk()
@@ -83,7 +86,7 @@ def webui():
     launch_api = cmd_opts.api
     initialize.initialize()

-    #from ui.chat import chat_element
+    # from ui.chat import chat_element
     from ui.sd import sd_element
     from ui.outputgallery import outputgallery_element
@@ -216,7 +219,8 @@ def webui():
 if __name__ == "__main__":
     from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

-    if cmd_opts.webui == False:
-        api_only()
-    else:
-        webui()
+    api_only()
+    # if cmd_opts.webui == False:
+    #     api_only()
+    # else:
+    #     webui()
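A minimal sketch of the API-only wiring introduced above, runnable outside the Studio entry point (the mount path and sdapi import match the diff; skipping initialize.initialize() plus the host/port are assumptions for illustration):

```python
# Hedged standalone version of api_only()'s FastAPI wiring.
import uvicorn
from fastapi import FastAPI

from apps.shark_studio.web.api.sd import sdapi

app = FastAPI()
app.mount("/sdapi/", sdapi)  # txt2img is then served at /sdapi/v1/txt2img

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8080)
```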
@@ -51,7 +51,7 @@ sd_default_models = [
     # "stabilityai/stable-diffusion-2-1-base",
     # "stabilityai/stable-diffusion-2-1",
     # "stabilityai/stable-diffusion-xl-base-1.0",
-    #"stabilityai/sdxl-turbo",
+    # "stabilityai/sdxl-turbo",
 ]
 sd_default_models.extend(get_checkpoints(model_type="scripts"))
@@ -154,7 +154,9 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
     elif os.path.exists(os.path.join(get_configs_path(), load_sd_config)):
         config = os.path.join(get_configs_path(), load_sd_config)
     else:
-        print("Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config.")
+        print(
+            "Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config."
+        )
         config = sd_json
     new_sd_config = none_to_str_none(json.loads(view_json_file(config)))
     if sd_json:
@@ -284,6 +286,7 @@ def base_model_changed(base_model_id):
         new_steps,
     ]


 init_config = global_obj.get_init_config()
 init_config = none_to_str_none(json.loads(view_json_file(init_config)))
@@ -307,15 +310,17 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                     show_copy_button=True,
                 )
             with gr.Accordion(
                 label="\U0001F4D0\U0000FE0F Advanced Settings", open=False
             ):
-                with gr.Accordion(
-                    label="Device Settings", open=False
-                ):
+                with gr.Accordion(label="Device Settings", open=False):
                     device = gr.Dropdown(
                         elem_id="device",
                         label="Device",
-                        value=init_config["device"] if init_config["device"] else "rocm",
+                        value=(
+                            init_config["device"]
+                            if init_config["device"]
+                            else "rocm"
+                        ),
                         choices=global_obj.get_device_list(),
                         allow_custom_value=True,
                     )
@@ -347,7 +352,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                         value=512,
                         step=512,
                         label="\U00002195\U0000FE0F Height",
-                        interactive=False, # DEMO
+                        interactive=False,  # DEMO
                         visible=False,  # DEMO
                     )
                     width = gr.Slider(
@@ -356,10 +361,10 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                         value=512,
                         step=512,
                         label="\U00002194\U0000FE0F Width",
-                        interactive=False, # DEMO
+                        interactive=False,  # DEMO
                         visible=False,  # DEMO
                     )

                 with gr.Accordion(
                     label="\U0001F9EA\U0000FE0F Input Image Processing",
                     open=False,
@@ -379,7 +384,9 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                         allow_custom_value=True,
                     )
                 with gr.Row():
-                    sd_model_info = f"Checkpoint Path: {str(get_checkpoints_path())}"
+                    sd_model_info = (
+                        f"Checkpoint Path: {str(get_checkpoints_path())}"
+                    )
                     base_model_id = gr.Dropdown(
                         label="\U000026F0\U0000FE0F Base Model",
                         info="Select or enter HF model ID",
@@ -413,7 +420,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                     )
                     guidance_scale = gr.Slider(
                         0,
-                        5, #DEMO
+                        5,  # DEMO
                         value=4,
                         step=0.1,
                         label="\U0001F5C3\U0000FE0F CFG Scale",
@@ -444,9 +451,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                         visible=False,  # DEMO
                     )
                 with gr.Row(elem_classes=["fill"], visible=False):
-                    Path(get_configs_path()).mkdir(
-                        parents=True, exist_ok=True
-                    )
+                    Path(get_configs_path()).mkdir(parents=True, exist_ok=True)
                     write_default_sd_configs(get_configs_path())
                     default_config_file = global_obj.get_init_config()
                     sd_json = gr.JSON(
@@ -463,9 +468,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                         visible=False,
                     )
                 with gr.Row():
-                    save_sd_config = gr.Button(
-                        value="Save Config", size="sm"
-                    )
+                    save_sd_config = gr.Button(value="Save Config", size="sm")
                     clear_sd_config = gr.ClearButton(
                         value="Clear Config",
                         size="sm",
@@ -514,7 +517,11 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                         label=f"Standalone LoRA Weights",
                         info=sd_lora_info,
                         elem_id="lora_weights",
-                        value=init_config["embeddings"][0] if (len(init_config["embeddings"].keys()) > 1) else "None",
+                        value=(
+                            init_config["embeddings"][0]
+                            if (len(init_config["embeddings"].keys()) > 1)
+                            else "None"
+                        ),
                         multiselect=True,
                         choices=[] + get_checkpoints("lora"),
                         scale=2,
@@ -100,6 +100,7 @@ def get_checkpoints(model_type="checkpoints"):
         ckpt_files.extend(files)
     return sorted(ckpt_files, key=str.casefold)


 def get_configs():
     return sorted(
         [
@@ -3,17 +3,21 @@ from ...api.utils import get_available_devices
 from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
 import os
 from apps.shark_studio.web.utils.file_utils import get_configs_path

 """
 The global objects include SD pipeline and config.
 Maintaining the global objects would avoid creating extra pipeline objects when switching modes.
 Also we could avoid memory leak when switching models by clearing the cache.
 """


 def view_json_file(file_path):
     content = ""
     with open(file_path, "r") as fopen:
         content = fopen.read()
     return content


 def _init():
     global _sd_obj
     global _llm_obj
@@ -95,6 +99,7 @@ def get_device_list():
     global _devices
     return _devices


 def get_init_config():
     global _init_config
     if os.path.exists(cmd_opts.defaults):
@@ -102,10 +107,13 @@ def get_init_config():
     elif os.path.exists(os.path.join(get_configs_path(), cmd_opts.defaults)):
         _init_config = os.path.join(get_configs_path(), cmd_opts.defaults)
     else:
-        print("Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config.")
+        print(
+            "Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config."
+        )
         _init_config = os.path.join(get_configs_path(), "sdxl-turbo.json")
     return _init_config


 def get_sd_status():
     global _sd_obj
     return _sd_obj.status
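A self-contained sketch of the config-resolution order implemented by get_init_config: absolute path first, then the configs folder, then the sdxl-turbo fallback (the function name and paths here are illustrative stand-ins for cmd_opts.defaults and get_configs_path()):

```python
# Hedged sketch of the defaults-file lookup used by get_init_config.
import os

def resolve_config(defaults: str, configs_dir: str) -> str:
    if os.path.exists(defaults):          # 1. absolute or CWD-relative path
        return defaults
    candidate = os.path.join(configs_dir, defaults)
    if os.path.exists(candidate):         # 2. file inside the configs folder
        return candidate
    return os.path.join(configs_dir, "sdxl-turbo.json")  # 3. fallback default

print(resolve_config("sdxl-turbo.json", "./configs"))
```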
@@ -1,3 +1,5 @@
 -r https://raw.githubusercontent.com/llvm/torch-mlir/main/requirements.txt
 -r https://raw.githubusercontent.com/llvm/torch-mlir/main/torchvision-requirements.txt
 -f https://download.pytorch.org/whl/nightly/cpu
 -f https://iree.dev/pip-release-links.html
 --pre
@@ -5,40 +7,19 @@
 setuptools
 wheel

 torch==2.3.0
 shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
-turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-unify-sd#subdirectory=models
+turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@merge_punet_sdxl#subdirectory=models
 diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark
 brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b

 # SHARK Runner
 tqdm

 # SHARK Downloader
 google-cloud-storage

 # Testing
 pytest
 Pillow
 parameterized

 # Add transformers, diffusers and scipy since it most commonly used
 #accelerate is now required for diffusers import from ckpt.
 accelerate
 scipy
-transformers==4.37.1
-torchsde # Required for Stable Diffusion SDE schedulers.
+transformers==4.43.3
 ftfy
 gradio==4.29.0
 altair
 omegaconf
 # 0.3.2 doesn't have binaries for arm64
-safetensors==0.3.1
+safetensors
 py-cpuinfo
 pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
 mpmath==1.3.0
 optimum

 # Testing
 pytest

 # Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
 pefile
rest_api_tests/sd3api_test.py (new file, 77 lines)
@@ -0,0 +1,77 @@
+import requests
+from pydantic import BaseModel, Field
+import json
+
+
+def view_json_file(file_path):
+    content = ""
+    with open(file_path, "r") as fopen:
+        content = fopen.read()
+    return content
+
+
+# Define the URL of the REST API endpoint
+api_url = "http://127.0.0.1:8080/sdapi/v1/txt2img/"  # Replace with your actual API URL
+
+
+class GenerationInputData(BaseModel):
+    prompt: list = [""]
+    negative_prompt: list = [""]
+    hf_model_id: str | None = None
+    height: int = Field(default=512, ge=128, le=1024, multiple_of=8)
+    width: int = Field(default=512, ge=128, le=1024, multiple_of=8)
+    sampler_name: str = "EulerDiscrete"
+    cfg_scale: float = Field(default=7.5, ge=1)
+    steps: int = Field(default=20, ge=1, le=100)
+    seed: int = Field(default=-1)
+    n_iter: int = Field(default=1)
+    config: dict = None
+
+
+# Create an instance of GenerationInputData with example arguments
+data = GenerationInputData(
+    prompt=[
+        "A phoenix made of diamond, black background, dream sequence, rising from coals"
+    ],
+    negative_prompt=[
+        "cropped, cartoon, lowres, low quality, black and white, bad scan, pixelated"
+    ],
+    hf_model_id="shark_sd3.py",
+    height=512,
+    width=512,
+    sampler_name="EulerDiscrete",
+    cfg_scale=7.5,
+    steps=20,
+    seed=-1,
+    n_iter=1,
+    config=json.loads(view_json_file("../configs/sd3_phoenix_npu.json")),
+)
+
+# Convert the data to a dictionary
+data_dict = data.dict()
+
+# Optional: Define headers if needed (e.g., for authentication)
+headers = {
+    "User-Agent": "PythonTest",
+    "Accept": "*/*",
+    "Accept-Encoding": "gzip, deflate, br",
+}
+
+
+def test_post_request(url, data, headers=None):
+    try:
+        # Send a POST request to the API endpoint
+        response = requests.post(url, json=data, headers=headers)
+
+        # Print the status code and response content
+        print(f"Status Code: {response.status_code}")
+        print("Response Content:")
+        # print(response.json())  # Print the JSON response
+
+    except requests.RequestException as e:
+        # Handle any exceptions that occur during the request
+        print(f"An error occurred: {e}")
+
+
+# Run the test
+test_post_request(api_url, data_dict, headers)
@@ -87,9 +87,8 @@ if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark.venv\}
 else {python -m venv .\shark.venv\}
 .\shark.venv\Scripts\activate
 python -m pip install --upgrade pip
 pip install wheel
+pip install https://github.com/nod-ai/SRT/releases/download/candidate-20240619.291/iree_compiler-20240619.291-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240619.291/iree_runtime-20240619.291-cp311-cp311-win_amd64.whl
 pip install --pre -r requirements.txt
-pip install https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_compiler-20240602.283-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_runtime-20240602.283-cp311-cp311-win_amd64.whl
 pip install -e .

 Write-Host "Source your venv with ./shark.venv/Scripts/activate"
webui_requirements.txt (new file, 1 line)
@@ -0,0 +1 @@
+gradio==4.29.0