mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 14:27:58 -05:00
Migration to AMDShark (#2182)
Signed-off-by: pdhirajkumarprasad <dhirajp@amd.com>
This commit is contained in:
committed by
GitHub
parent
dba2c8a567
commit
fe03539901
@@ -1,5 +1,5 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from apps.shark_studio.studio_imports import pathex, datas, hiddenimports
|
||||
from apps.amdshark_studio.studio_imports import pathex, datas, hiddenimports
|
||||
|
||||
binaries = []
|
||||
|
||||
@@ -32,7 +32,7 @@ exe = EXE(
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
[],
|
||||
name='nodai_shark_studio',
|
||||
name='nodai_amdshark_studio',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
@@ -2,7 +2,7 @@
|
||||
import os
|
||||
import PIL
|
||||
import numpy as np
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_generated_imgs_path,
|
||||
)
|
||||
from datetime import datetime
|
||||
@@ -6,13 +6,13 @@ import warnings
|
||||
import json
|
||||
from threading import Thread
|
||||
|
||||
from apps.shark_studio.modules.timer import startup_timer
|
||||
from apps.amdshark_studio.modules.timer import startup_timer
|
||||
|
||||
from apps.shark_studio.web.utils.tmp_configs import (
|
||||
from apps.amdshark_studio.web.utils.tmp_configs import (
|
||||
config_tmp,
|
||||
clear_tmp_mlir,
|
||||
clear_tmp_imgs,
|
||||
shark_tmp,
|
||||
amdshark_tmp,
|
||||
)
|
||||
|
||||
|
||||
@@ -30,12 +30,12 @@ def imports():
|
||||
|
||||
startup_timer.record("import gradio")
|
||||
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
|
||||
global_obj._init()
|
||||
startup_timer.record("initialize globals")
|
||||
|
||||
from apps.shark_studio.modules import (
|
||||
from apps.amdshark_studio.modules import (
|
||||
img_processing,
|
||||
) # noqa: F401
|
||||
|
||||
@@ -44,7 +44,7 @@ def imports():
|
||||
|
||||
def initialize():
|
||||
configure_sigint_handler()
|
||||
# Setup to use shark_tmp for gradio's temporary image files and clear any
|
||||
# Setup to use amdshark_tmp for gradio's temporary image files and clear any
|
||||
# existing temporary images there if they exist. Then we can import gradio.
|
||||
# It has to be in this order or gradio ignores what we've set up.
|
||||
|
||||
@@ -52,7 +52,7 @@ def initialize():
|
||||
# clear_tmp_mlir()
|
||||
clear_tmp_imgs()
|
||||
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
create_model_folders,
|
||||
)
|
||||
|
||||
@@ -83,7 +83,7 @@ def dumpstacks():
|
||||
code.append(f"""File: "{filename}", line {lineno}, in {name}""")
|
||||
if line:
|
||||
code.append(" " + line.strip())
|
||||
with open(os.path.join(shark_tmp, "stack_dump.log"), "w") as f:
|
||||
with open(os.path.join(amdshark_tmp, "stack_dump.log"), "w") as f:
|
||||
f.write("\n".join(code))
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ def setup_middleware(app):
|
||||
|
||||
def configure_cors_middleware(app):
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
cors_options = {
|
||||
"allow_methods": ["*"],
|
||||
@@ -2,13 +2,13 @@ from turbine_models.custom_models import stateless_llama
|
||||
from turbine_models.model_runner import vmfbRunner
|
||||
from turbine_models.gen_external_params.gen_external_params import gen_external_params
|
||||
import time
|
||||
from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
from amdshark.iree_utils.compile_utils import compile_module_to_flatbuffer
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_resource_path,
|
||||
get_checkpoints_path,
|
||||
)
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.shark_studio.api.utils import parse_device
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.amdshark_studio.api.utils import parse_device
|
||||
from urllib.request import urlopen
|
||||
import iree.runtime as ireert
|
||||
from itertools import chain
|
||||
@@ -366,7 +366,7 @@ def get_mfma_spec_path(target_chip, save_dir):
|
||||
def llm_chat_api(InputData: dict):
|
||||
from datetime import datetime as dt
|
||||
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
|
||||
print(f"Input keys : {InputData.keys()}")
|
||||
|
||||
@@ -12,26 +12,26 @@ from tqdm.auto import tqdm
|
||||
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
|
||||
from turbine_models.custom_models.sd_inference.sd_pipeline import AMDSharkSDPipeline
|
||||
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
|
||||
SharkSDXLPipeline,
|
||||
AMDSharkSDXLPipeline,
|
||||
)
|
||||
|
||||
|
||||
from apps.shark_studio.api.controlnet import control_adapter_map
|
||||
from apps.shark_studio.api.utils import parse_device
|
||||
from apps.shark_studio.web.utils.state import status_label
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
from apps.amdshark_studio.api.controlnet import control_adapter_map
|
||||
from apps.amdshark_studio.api.utils import parse_device
|
||||
from apps.amdshark_studio.web.utils.state import status_label
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
safe_name,
|
||||
get_resource_path,
|
||||
get_checkpoints_path,
|
||||
)
|
||||
|
||||
from apps.shark_studio.modules.img_processing import (
|
||||
from apps.amdshark_studio.modules.img_processing import (
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
from apps.shark_studio.modules.ckpt_processing import (
|
||||
from apps.amdshark_studio.modules.ckpt_processing import (
|
||||
preprocessCKPT,
|
||||
save_irpa,
|
||||
)
|
||||
@@ -114,10 +114,10 @@ class StableDiffusion:
|
||||
self.turbine_pipe = custom_module.StudioPipeline
|
||||
self.model_map = custom_module.MODEL_MAP
|
||||
elif self.is_sdxl:
|
||||
self.turbine_pipe = SharkSDXLPipeline
|
||||
self.turbine_pipe = AMDSharkSDXLPipeline
|
||||
self.model_map = EMPTY_SDXL_MAP
|
||||
else:
|
||||
self.turbine_pipe = SharkSDPipeline
|
||||
self.turbine_pipe = AMDSharkSDPipeline
|
||||
self.model_map = EMPTY_SD_MAP
|
||||
max_length = 64
|
||||
target_backend, self.rt_device, triple = parse_device(device, target_triple)
|
||||
@@ -273,7 +273,7 @@ class StableDiffusion:
|
||||
return img
|
||||
|
||||
|
||||
def shark_sd_fn_dict_input(
|
||||
def amdshark_sd_fn_dict_input(
|
||||
sd_kwargs: dict,
|
||||
):
|
||||
print("\n[LOG] Submitting Request...")
|
||||
@@ -312,11 +312,11 @@ def shark_sd_fn_dict_input(
|
||||
)
|
||||
return None, ""
|
||||
|
||||
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
|
||||
generated_imgs = yield from amdshark_sd_fn(**sd_kwargs)
|
||||
return generated_imgs
|
||||
|
||||
|
||||
def shark_sd_fn(
|
||||
def amdshark_sd_fn(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
sd_init_image: list,
|
||||
@@ -346,8 +346,8 @@ def shark_sd_fn(
|
||||
sd_init_image = [sd_init_image]
|
||||
is_img2img = True if sd_init_image[0] is not None else False
|
||||
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
|
||||
adapters = {}
|
||||
is_controlled = False
|
||||
@@ -466,7 +466,7 @@ def shark_sd_fn(
|
||||
|
||||
def unload_sd():
|
||||
print("Unloading models.")
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
|
||||
global_obj.clear_cache()
|
||||
gc.collect()
|
||||
@@ -489,8 +489,8 @@ def safe_name(name):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
|
||||
global_obj._init()
|
||||
|
||||
@@ -501,5 +501,5 @@ if __name__ == "__main__":
|
||||
for arg in vars(cmd_opts):
|
||||
if arg in sd_kwargs:
|
||||
sd_kwargs[arg] = getattr(cmd_opts, arg)
|
||||
for i in shark_sd_fn_dict_input(sd_kwargs):
|
||||
for i in amdshark_sd_fn_dict_input(sd_kwargs):
|
||||
print(i)
|
||||
@@ -8,11 +8,11 @@ from random import (
|
||||
)
|
||||
|
||||
from pathlib import Path
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from cpuinfo import get_cpu_info
|
||||
|
||||
# TODO: migrate these utils to studio
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
from amdshark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
get_iree_vulkan_runtime_flags,
|
||||
@@ -21,7 +21,7 @@ from shark.iree_utils.vulkan_utils import (
|
||||
|
||||
def get_available_devices():
|
||||
def get_devices_by_name(driver_name):
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
from amdshark.iree_utils._common import iree_device_map
|
||||
|
||||
device_list = []
|
||||
try:
|
||||
@@ -59,7 +59,7 @@ def get_available_devices():
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
available_devices.extend(cpu_device)
|
||||
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
from amdshark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
)
|
||||
|
||||
@@ -116,7 +116,7 @@ def set_init_device_flags():
|
||||
elif "metal" in cmd_opts.device:
|
||||
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
|
||||
if not cmd_opts.iree_metal_target_platform:
|
||||
from shark.iree_utils.metal_utils import get_metal_target_triple
|
||||
from amdshark.iree_utils.metal_utils import get_metal_target_triple
|
||||
|
||||
triple = get_metal_target_triple(device_name)
|
||||
if triple is not None:
|
||||
@@ -146,7 +146,7 @@ def set_iree_runtime_flags():
|
||||
|
||||
|
||||
def parse_device(device_str, target_override=""):
|
||||
from shark.iree_utils.compile_utils import (
|
||||
from amdshark.iree_utils.compile_utils import (
|
||||
clean_device_info,
|
||||
get_iree_target_triple,
|
||||
iree_target_map,
|
||||
@@ -192,7 +192,7 @@ def get_rocm_target_chip(device_str):
|
||||
if key in device_str:
|
||||
return rocm_chip_map[key]
|
||||
raise AssertionError(
|
||||
f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK-Studio/issues."
|
||||
f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/AMDSHARK-Studio/issues."
|
||||
)
|
||||
|
||||
|
||||
@@ -225,7 +225,7 @@ def get_device_mapping(driver, key_combination=3):
|
||||
dict: map to possible device names user can input mapped to desired
|
||||
combination of name/path.
|
||||
"""
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
from amdshark.iree_utils._common import iree_device_map
|
||||
|
||||
driver = iree_device_map(driver)
|
||||
device_list = get_all_devices(driver)
|
||||
@@ -256,7 +256,7 @@ def get_opt_flags(model, precision="fp16"):
|
||||
f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
|
||||
)
|
||||
if "rocm" in cmd_opts.device:
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
from amdshark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
rocm_args = get_iree_rocm_args()
|
||||
iree_flags.extend(rocm_args)
|
||||
@@ -301,7 +301,7 @@ def map_device_to_name_path(device, key_combination=3):
|
||||
return device_mapping
|
||||
|
||||
def get_devices_by_name(driver_name):
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
from amdshark.iree_utils._common import iree_device_map
|
||||
|
||||
device_list = []
|
||||
try:
|
||||
@@ -332,7 +332,7 @@ def map_device_to_name_path(device, key_combination=3):
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
from amdshark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@ from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from omegaconf import OmegaConf
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
download_from_original_stable_diffusion_ckpt,
|
||||
create_vae_diffusers_config,
|
||||
@@ -5,7 +5,7 @@ import json
|
||||
import safetensors
|
||||
from dataclasses import dataclass
|
||||
from safetensors.torch import load_file
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_checkpoint_pathfile,
|
||||
get_path_stem,
|
||||
)
|
||||
@@ -25,11 +25,11 @@ resampler_list = resamplers.keys()
|
||||
|
||||
# save output images and the inputs corresponding to it.
|
||||
def save_output_img(output_img, img_seed, extra_info=None):
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
)
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
if extra_info is None:
|
||||
extra_info = {}
|
||||
@@ -30,8 +30,8 @@ def logger_test(x):
|
||||
|
||||
def read_sd_logs():
|
||||
sys.stdout.flush()
|
||||
with open("shark_tmp/sd.log", "r") as f:
|
||||
with open("amdshark_tmp/sd.log", "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
sys.stdout = Logger("shark_tmp/sd.log", filter="[LOG]")
|
||||
sys.stdout = Logger("amdshark_tmp/sd.log", filter="[LOG]")
|
||||
@@ -1,14 +1,14 @@
|
||||
from shark.iree_utils.compile_utils import (
|
||||
from amdshark.iree_utils.compile_utils import (
|
||||
get_iree_compiled_module,
|
||||
load_vmfb_using_mmap,
|
||||
clean_device_info,
|
||||
get_iree_target_triple,
|
||||
)
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_checkpoints_path,
|
||||
get_resource_path,
|
||||
)
|
||||
from apps.shark_studio.modules.shared_cmd_opts import (
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import (
|
||||
cmd_opts,
|
||||
)
|
||||
from iree import runtime as ireert
|
||||
@@ -17,7 +17,7 @@ import gc
|
||||
import os
|
||||
|
||||
|
||||
class SharkPipelineBase:
|
||||
class AMDSharkPipelineBase:
|
||||
# This class is a lightweight base for managing an
|
||||
# inference API class. It should provide methods for:
|
||||
# - compiling a set (model map) of torch IR modules
|
||||
@@ -224,7 +224,7 @@ def get_unweighted_text_embeddings(
|
||||
text_embedding = text_embedding[:, 1:-1]
|
||||
|
||||
text_embeddings.append(text_embedding)
|
||||
# SHARK: Convert the result to tensor
|
||||
# AMDSHARK: Convert the result to tensor
|
||||
# text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||
text_embeddings_np = np.concatenate(np.array(text_embeddings))
|
||||
text_embeddings = torch.from_numpy(text_embeddings_np)
|
||||
@@ -1,4 +1,4 @@
|
||||
# from shark_turbine.turbine_models.schedulers import export_scheduler_model
|
||||
# from amdshark_turbine.turbine_models.schedulers import export_scheduler_model
|
||||
from diffusers import (
|
||||
LCMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
@@ -2,7 +2,7 @@ import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from apps.shark_studio.modules.img_processing import resampler_list
|
||||
from apps.amdshark_studio.modules.img_processing import resampler_list
|
||||
|
||||
|
||||
def path_expand(s):
|
||||
@@ -299,8 +299,8 @@ p.add_argument(
|
||||
"--import_mlir",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Imports the model from torch module to shark_module otherwise "
|
||||
"downloads the model from shark_tank.",
|
||||
help="Imports the model from torch module to amdshark_module otherwise "
|
||||
"downloads the model from amdshark_tank.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -487,8 +487,8 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--local_tank_cache",
|
||||
default="",
|
||||
help="Specify where to save downloaded shark_tank artifacts. "
|
||||
"If this is not set, the default is ~/.local/shark_tank/.",
|
||||
help="Specify where to save downloaded amdshark_tank artifacts. "
|
||||
"If this is not set, the default is ~/.local/amdshark_tank/.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -562,7 +562,7 @@ p.add_argument(
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If import_mlir is True, saves mlir via the debug option "
|
||||
"in shark importer. Does nothing if import_mlir is false (the default).",
|
||||
"in amdshark importer. Does nothing if import_mlir is false (the default).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -615,7 +615,7 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--tmp_dir",
|
||||
type=str,
|
||||
default=os.path.join(os.getcwd(), "shark_tmp"),
|
||||
default=os.path.join(os.getcwd(), "amdshark_tmp"),
|
||||
help="Path to tmp directory",
|
||||
)
|
||||
|
||||
@@ -38,7 +38,7 @@ datas += collect_data_files("transformers")
|
||||
datas += collect_data_files("gradio")
|
||||
datas += collect_data_files("gradio_client")
|
||||
datas += collect_data_files("iree", include_py_files=True)
|
||||
datas += collect_data_files("shark", include_py_files=True)
|
||||
datas += collect_data_files("amdshark", include_py_files=True)
|
||||
datas += collect_data_files("tqdm")
|
||||
datas += collect_data_files("tkinter")
|
||||
datas += collect_data_files("sentencepiece")
|
||||
@@ -54,7 +54,7 @@ datas += [
|
||||
|
||||
|
||||
# hidden imports for pyinstaller
|
||||
hiddenimports = ["shark", "apps"]
|
||||
hiddenimports = ["amdshark", "apps"]
|
||||
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x]
|
||||
blacklist = ["tests", "convert"]
|
||||
@@ -8,14 +8,14 @@ import logging
|
||||
import unittest
|
||||
import json
|
||||
import gc
|
||||
from apps.shark_studio.api.llm import LanguageModel, llm_chat_api
|
||||
from apps.shark_studio.api.sd import shark_sd_fn_dict_input, view_json_file
|
||||
from apps.shark_studio.web.utils.file_utils import get_resource_path
|
||||
from apps.amdshark_studio.api.llm import LanguageModel, llm_chat_api
|
||||
from apps.amdshark_studio.api.sd import amdshark_sd_fn_dict_input, view_json_file
|
||||
from apps.amdshark_studio.web.utils.file_utils import get_resource_path
|
||||
|
||||
# class SDAPITest(unittest.TestCase):
|
||||
# def testSDSimple(self):
|
||||
# from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
# import apps.shark_studio.web.utils.globals as global_obj
|
||||
# from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
# import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
|
||||
# global_obj._init()
|
||||
|
||||
@@ -24,7 +24,7 @@ from apps.shark_studio.web.utils.file_utils import get_resource_path
|
||||
# for arg in vars(cmd_opts):
|
||||
# if arg in sd_kwargs:
|
||||
# sd_kwargs[arg] = getattr(cmd_opts, arg)
|
||||
# for i in shark_sd_fn_dict_input(sd_kwargs):
|
||||
# for i in amdshark_sd_fn_dict_input(sd_kwargs):
|
||||
# print(i)
|
||||
|
||||
|
||||
|
Before Width: | Height: | Size: 347 KiB After Width: | Height: | Size: 347 KiB |
@@ -38,8 +38,8 @@ def llm_chat_test(verbose=False):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# "Exercises the chatbot REST API of Shark. Make sure "
|
||||
# "Shark is running in API mode on 127.0.0.1:8080 before running"
|
||||
# "Exercises the chatbot REST API of AMDShark. Make sure "
|
||||
# "AMDShark is running in API mode on 127.0.0.1:8080 before running"
|
||||
# "this script."
|
||||
|
||||
llm_chat_test(verbose=True)
|
||||
@@ -18,10 +18,10 @@ from fastapi.exceptions import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
# from sdapi_v1 import shark_sd_api
|
||||
from apps.shark_studio.api.llm import llm_chat_api
|
||||
# from sdapi_v1 import amdshark_sd_api
|
||||
from apps.amdshark_studio.api.llm import llm_chat_api
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
@@ -183,8 +183,8 @@ class ApiCompat:
|
||||
self.app = app
|
||||
self.queue_lock = queue_lock
|
||||
api_middleware(self.app)
|
||||
# self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["POST"])
|
||||
# self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["POST"])
|
||||
# self.add_api_route("/sdapi/v1/txt2img", amdshark_sd_api, methods=["POST"])
|
||||
# self.add_api_route("/sdapi/v1/img2img", amdshark_sd_api, methods=["POST"])
|
||||
# self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["POST"])
|
||||
# self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
|
||||
# self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
|
||||
@@ -7,10 +7,10 @@ import os
|
||||
import time
|
||||
import sys
|
||||
import logging
|
||||
import apps.shark_studio.api.initializers as initialize
|
||||
import apps.amdshark_studio.api.initializers as initialize
|
||||
|
||||
|
||||
from apps.shark_studio.modules import timer
|
||||
from apps.amdshark_studio.modules import timer
|
||||
|
||||
startup_timer = timer.startup_timer
|
||||
startup_timer.record("launcher")
|
||||
@@ -24,7 +24,7 @@ if sys.platform == "darwin":
|
||||
|
||||
|
||||
def create_api(app):
|
||||
from apps.shark_studio.web.api.compat import ApiCompat, FIFOLock
|
||||
from apps.amdshark_studio.web.api.compat import ApiCompat, FIFOLock
|
||||
|
||||
queue_lock = FIFOLock()
|
||||
api = ApiCompat(app, queue_lock)
|
||||
@@ -33,7 +33,7 @@ def create_api(app):
|
||||
|
||||
def api_only():
|
||||
from fastapi import FastAPI
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
initialize.initialize()
|
||||
|
||||
@@ -64,7 +64,7 @@ def launch_webui(address):
|
||||
width = int(window.winfo_screenwidth() * 0.81)
|
||||
height = int(window.winfo_screenheight() * 0.91)
|
||||
webview.create_window(
|
||||
"SHARK AI Studio",
|
||||
"AMDSHARK AI Studio",
|
||||
url=address,
|
||||
width=width,
|
||||
height=height,
|
||||
@@ -74,8 +74,8 @@ def launch_webui(address):
|
||||
|
||||
|
||||
def webui():
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.shark_studio.web.ui.utils import (
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.amdshark_studio.web.ui.utils import (
|
||||
amdicon_loc,
|
||||
amdlogo_loc,
|
||||
)
|
||||
@@ -91,10 +91,10 @@ def webui():
|
||||
freeze_support()
|
||||
|
||||
# if args.api or "api" in args.ui.split(","):
|
||||
# from apps.shark_studio.api.llm import (
|
||||
# from apps.amdshark_studio.api.llm import (
|
||||
# chat,
|
||||
# )
|
||||
# from apps.shark_studio.web.api import sdapi
|
||||
# from apps.amdshark_studio.web.api import sdapi
|
||||
#
|
||||
# from fastapi import FastAPI, APIRouter
|
||||
# from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -144,7 +144,7 @@ def webui():
|
||||
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
gradio_workarounds = resource_path("ui/js/sd_gradio_workarounds.js")
|
||||
|
||||
# from apps.shark_studio.web.ui import load_ui_from_script
|
||||
# from apps.amdshark_studio.web.ui import load_ui_from_script
|
||||
|
||||
def register_button_click(button, selectedid, inputs, outputs):
|
||||
button.click(
|
||||
@@ -170,7 +170,7 @@ def webui():
|
||||
css=dark_theme,
|
||||
js=gradio_workarounds,
|
||||
analytics_enabled=False,
|
||||
title="Shark Studio 2.0 Beta",
|
||||
title="AMDShark Studio 2.0 Beta",
|
||||
) as studio_web:
|
||||
amd_logo = Image.open(amdlogo_loc)
|
||||
gr.Image(
|
||||
@@ -214,7 +214,7 @@ def webui():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
if cmd_opts.webui == False:
|
||||
api_only()
|
||||
@@ -5,12 +5,12 @@ from pathlib import Path
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import sys
|
||||
from apps.shark_studio.api.llm import (
|
||||
from apps.amdshark_studio.api.llm import (
|
||||
llm_model_map,
|
||||
LanguageModel,
|
||||
)
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
|
||||
B_SYS, E_SYS = "<s>", "</s>"
|
||||
|
||||
@@ -129,7 +129,7 @@ with gr.Blocks(title="Chat") as chat_element:
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
with gr.Column():
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from Shark tank if available",
|
||||
label="Download vmfb from AMDShark tank if available",
|
||||
value=False,
|
||||
interactive=True,
|
||||
visible=False,
|
||||
@@ -1,8 +1,8 @@
|
||||
from apps.shark_studio.web.ui.utils import (
|
||||
from apps.amdshark_studio.web.ui.utils import (
|
||||
HSLHue,
|
||||
hsl_color,
|
||||
)
|
||||
from apps.shark_studio.modules.embeddings import get_lora_metadata
|
||||
from apps.amdshark_studio.modules.embeddings import get_lora_metadata
|
||||
|
||||
|
||||
# Answers HTML to show the most frequent tags used when a LoRA was trained,
|
||||
@@ -100,7 +100,7 @@ Procedure to upgrade the dark theme:
|
||||
--input-border-width: 1px;
|
||||
}
|
||||
|
||||
/* SHARK theme */
|
||||
/* AMDSHARK theme */
|
||||
body {
|
||||
background-color: var(--background-fill-primary);
|
||||
}
|
||||
|
Before Width: | Height: | Size: 7.1 KiB After Width: | Height: | Size: 7.1 KiB |
|
Before Width: | Height: | Size: 7.4 KiB After Width: | Height: | Size: 7.4 KiB |
@@ -5,13 +5,13 @@ import subprocess
|
||||
import sys
|
||||
from PIL import Image
|
||||
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
)
|
||||
from apps.shark_studio.web.ui.utils import amdlogo_loc
|
||||
from apps.shark_studio.web.utils.metadata import displayable_metadata
|
||||
from apps.amdshark_studio.web.ui.utils import amdlogo_loc
|
||||
from apps.amdshark_studio.web.utils.metadata import displayable_metadata
|
||||
|
||||
# -- Functions for file, directory and image info querying
|
||||
|
||||
@@ -9,40 +9,40 @@ from datetime import datetime as dt
|
||||
from gradio.components.image_editor import (
|
||||
EditorValue,
|
||||
)
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_generated_imgs_path,
|
||||
get_checkpoints_path,
|
||||
get_checkpoints,
|
||||
get_configs_path,
|
||||
write_default_sd_configs,
|
||||
)
|
||||
from apps.shark_studio.api.sd import (
|
||||
shark_sd_fn_dict_input,
|
||||
from apps.amdshark_studio.api.sd import (
|
||||
amdshark_sd_fn_dict_input,
|
||||
cancel_sd,
|
||||
unload_sd,
|
||||
)
|
||||
from apps.shark_studio.api.controlnet import (
|
||||
from apps.amdshark_studio.api.controlnet import (
|
||||
cnet_preview,
|
||||
)
|
||||
from apps.shark_studio.modules.schedulers import (
|
||||
from apps.amdshark_studio.modules.schedulers import (
|
||||
scheduler_model_map,
|
||||
)
|
||||
from apps.shark_studio.modules.img_processing import (
|
||||
from apps.amdshark_studio.modules.img_processing import (
|
||||
resampler_list,
|
||||
resize_stencil,
|
||||
)
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.shark_studio.web.ui.utils import (
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.amdshark_studio.web.ui.utils import (
|
||||
amdlogo_loc,
|
||||
none_to_str_none,
|
||||
str_none_to_none,
|
||||
)
|
||||
from apps.shark_studio.web.utils.state import (
|
||||
from apps.amdshark_studio.web.utils.state import (
|
||||
status_label,
|
||||
)
|
||||
from apps.shark_studio.web.ui.common_events import lora_changed
|
||||
from apps.shark_studio.modules import logger
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
from apps.amdshark_studio.web.ui.common_events import lora_changed
|
||||
from apps.amdshark_studio.modules import logger
|
||||
import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
|
||||
sd_default_models = [
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
@@ -758,7 +758,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
)
|
||||
|
||||
gen_kwargs = dict(
|
||||
fn=shark_sd_fn_dict_input,
|
||||
fn=amdshark_sd_fn_dict_input,
|
||||
inputs=[sd_json],
|
||||
outputs=[
|
||||
sd_gallery,
|
||||
@@ -4,14 +4,14 @@ import glob
|
||||
from datetime import datetime as dt
|
||||
from pathlib import Path
|
||||
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
checkpoints_filetypes = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
)
|
||||
|
||||
from apps.shark_studio.web.utils.default_configs import default_sd_configs
|
||||
from apps.amdshark_studio.web.utils.default_configs import default_sd_configs
|
||||
|
||||
|
||||
def write_default_sd_configs(path):
|
||||
@@ -1,4 +1,4 @@
|
||||
# As SHARK has evolved more columns have been added to images_details.csv. However, since
|
||||
# As AMDSHARK has evolved more columns have been added to images_details.csv. However, since
|
||||
# no version of the CSV has any headers (yet) we don't actually have anything within the
|
||||
# file that tells us which parameter each column is for. So this is a list of known patterns
|
||||
# indexed by length which is what we're going to have to use to guess which columns are the
|
||||
@@ -1,11 +1,11 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
from apps.amdshark_studio.web.utils.file_utils import (
|
||||
get_checkpoint_pathfile,
|
||||
)
|
||||
from apps.shark_studio.api.sd import EMPTY_SD_MAP as sd_model_map
|
||||
from apps.amdshark_studio.api.sd import EMPTY_SD_MAP as sd_model_map
|
||||
|
||||
from apps.shark_studio.modules.schedulers import (
|
||||
from apps.amdshark_studio.modules.schedulers import (
|
||||
scheduler_model_map,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
import apps.amdshark_studio.web.utils.globals as global_obj
|
||||
import gc
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@ import os
|
||||
import shutil
|
||||
from time import time
|
||||
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
|
||||
shark_tmp = cmd_opts.tmp_dir # os.path.join(os.getcwd(), "shark_tmp/")
|
||||
amdshark_tmp = cmd_opts.tmp_dir # os.path.join(os.getcwd(), "amdshark_tmp/")
|
||||
|
||||
|
||||
def clear_tmp_mlir():
|
||||
@@ -12,20 +12,20 @@ def clear_tmp_mlir():
|
||||
print("Clearing .mlir temporary files from a prior run. This may take some time...")
|
||||
mlir_files = [
|
||||
filename
|
||||
for filename in os.listdir(shark_tmp)
|
||||
if os.path.isfile(os.path.join(shark_tmp, filename))
|
||||
for filename in os.listdir(amdshark_tmp)
|
||||
if os.path.isfile(os.path.join(amdshark_tmp, filename))
|
||||
and filename.endswith(".mlir")
|
||||
]
|
||||
for filename in mlir_files:
|
||||
os.remove(os.path.join(shark_tmp, filename))
|
||||
os.remove(os.path.join(amdshark_tmp, filename))
|
||||
print(f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds.")
|
||||
|
||||
|
||||
def clear_tmp_imgs():
|
||||
# tell gradio to use a directory under shark_tmp for its temporary
|
||||
# tell gradio to use a directory under amdshark_tmp for its temporary
|
||||
# image files unless somewhere else has been set
|
||||
if "GRADIO_TEMP_DIR" not in os.environ:
|
||||
os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
|
||||
os.environ["GRADIO_TEMP_DIR"] = os.path.join(amdshark_tmp, "gradio")
|
||||
|
||||
print(
|
||||
f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
|
||||
@@ -43,22 +43,22 @@ def clear_tmp_imgs():
|
||||
f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
|
||||
)
|
||||
|
||||
# older SHARK versions had to workaround gradio bugs and stored things differently
|
||||
# older AMDSHARK versions had to workaround gradio bugs and stored things differently
|
||||
else:
|
||||
image_files = [
|
||||
filename
|
||||
for filename in os.listdir(shark_tmp)
|
||||
if os.path.isfile(os.path.join(shark_tmp, filename))
|
||||
for filename in os.listdir(amdshark_tmp)
|
||||
if os.path.isfile(os.path.join(amdshark_tmp, filename))
|
||||
and filename.startswith("tmp")
|
||||
and filename.endswith(".png")
|
||||
]
|
||||
if len(image_files) > 0:
|
||||
print(
|
||||
"Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
|
||||
"Clearing temporary image files of a prior run of a previous AMDSHARK version. This may take some time..."
|
||||
)
|
||||
cleanup_start = time()
|
||||
for filename in image_files:
|
||||
os.remove(shark_tmp + filename)
|
||||
os.remove(amdshark_tmp + filename)
|
||||
print(
|
||||
f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
|
||||
)
|
||||
@@ -67,9 +67,9 @@ def clear_tmp_imgs():
|
||||
|
||||
|
||||
def config_tmp():
|
||||
# create shark_tmp if it does not exist
|
||||
if not os.path.exists(shark_tmp):
|
||||
os.mkdir(shark_tmp)
|
||||
# create amdshark_tmp if it does not exist
|
||||
if not os.path.exists(amdshark_tmp):
|
||||
os.mkdir(amdshark_tmp)
|
||||
|
||||
clear_tmp_mlir()
|
||||
clear_tmp_imgs()
|
||||
Reference in New Issue
Block a user