Migration to AMDShark (#2182)

Signed-off-by: pdhirajkumarprasad <dhirajp@amd.com>

Author:       pdhirajkumarprasad
Committed by: GitHub
Date:         2025-11-20 12:52:07 +05:30
Commit:       fe03539901
Parent:       dba2c8a567

232 changed files with 1719 additions and 1719 deletions
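
The change is a mechanical rename: module paths, class names, temp-directory names, and user-facing strings move from shark / Shark / SHARK to amdshark / AMDShark / AMDSHARK. The sketch below shows how such a bulk rename could be scripted; it is not the tooling used for this commit, and the migrate() helper, the file-extension filter, and the case mapping are illustrative assumptions.

# Hypothetical rename sketch, NOT the script used for this commit.
# Rewrites shark/Shark/SHARK identifiers to their AMD-prefixed forms in a
# single regex pass, so already-rewritten text is never touched twice.
import os
import re

CASE_MAP = {"SHARK": "AMDSHARK", "Shark": "AMDShark", "shark": "amdshark"}
PATTERN = re.compile("|".join(CASE_MAP))  # "SHARK|Shark|shark"


def migrate(root=".", exts=(".py", ".spec", ".css")):
    for dirpath, dirnames, filenames in os.walk(root):
        dirnames[:] = [d for d in dirnames if d != ".git"]  # skip git metadata
        for name in filenames:
            if not name.endswith(exts):
                continue
            path = os.path.join(dirpath, name)
            with open(path, encoding="utf-8") as f:
                text = f.read()
            new_text = PATTERN.sub(lambda m: CASE_MAP[m.group(0)], text)
            if new_text != text:
                with open(path, "w", encoding="utf-8") as f:
                    f.write(new_text)


if __name__ == "__main__":
    migrate()

Directory renames (for example apps/shark_studio to apps/amdshark_studio) would still need a separate git mv pass; a content rewrite like the one above only covers file bodies such as the diffs that follow.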

View File

@@ -1,5 +1,5 @@
 # -*- mode: python ; coding: utf-8 -*-
-from apps.shark_studio.studio_imports import pathex, datas, hiddenimports
+from apps.amdshark_studio.studio_imports import pathex, datas, hiddenimports
 binaries = []
@@ -32,7 +32,7 @@ exe = EXE(
     a.zipfiles,
     a.datas,
     [],
-    name='nodai_shark_studio',
+    name='nodai_amdshark_studio',
     debug=False,
     bootloader_ignore_signals=False,
     strip=False,

View File

@@ -2,7 +2,7 @@
 import os
 import PIL
 import numpy as np
-from apps.shark_studio.web.utils.file_utils import (
+from apps.amdshark_studio.web.utils.file_utils import (
     get_generated_imgs_path,
 )
 from datetime import datetime

View File

@@ -6,13 +6,13 @@ import warnings
 import json
 from threading import Thread
-from apps.shark_studio.modules.timer import startup_timer
+from apps.amdshark_studio.modules.timer import startup_timer
-from apps.shark_studio.web.utils.tmp_configs import (
+from apps.amdshark_studio.web.utils.tmp_configs import (
     config_tmp,
     clear_tmp_mlir,
     clear_tmp_imgs,
-    shark_tmp,
+    amdshark_tmp,
 )
@@ -30,12 +30,12 @@ def imports():
     startup_timer.record("import gradio")
-    import apps.shark_studio.web.utils.globals as global_obj
+    import apps.amdshark_studio.web.utils.globals as global_obj
     global_obj._init()
     startup_timer.record("initialize globals")
-    from apps.shark_studio.modules import (
+    from apps.amdshark_studio.modules import (
         img_processing,
     ) # noqa: F401
@@ -44,7 +44,7 @@ def imports():
 def initialize():
     configure_sigint_handler()
-    # Setup to use shark_tmp for gradio's temporary image files and clear any
+    # Setup to use amdshark_tmp for gradio's temporary image files and clear any
     # existing temporary images there if they exist. Then we can import gradio.
     # It has to be in this order or gradio ignores what we've set up.
@@ -52,7 +52,7 @@ def initialize():
     # clear_tmp_mlir()
     clear_tmp_imgs()
-    from apps.shark_studio.web.utils.file_utils import (
+    from apps.amdshark_studio.web.utils.file_utils import (
         create_model_folders,
     )
@@ -83,7 +83,7 @@ def dumpstacks():
             code.append(f"""File: "{filename}", line {lineno}, in {name}""")
             if line:
                 code.append(" " + line.strip())
-    with open(os.path.join(shark_tmp, "stack_dump.log"), "w") as f:
+    with open(os.path.join(amdshark_tmp, "stack_dump.log"), "w") as f:
         f.write("\n".join(code))
@@ -100,7 +100,7 @@ def setup_middleware(app):
 def configure_cors_middleware(app):
     from starlette.middleware.cors import CORSMiddleware
-    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
     cors_options = {
         "allow_methods": ["*"],

View File

@@ -2,13 +2,13 @@ from turbine_models.custom_models import stateless_llama
 from turbine_models.model_runner import vmfbRunner
 from turbine_models.gen_external_params.gen_external_params import gen_external_params
 import time
-from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
-from apps.shark_studio.web.utils.file_utils import (
+from amdshark.iree_utils.compile_utils import compile_module_to_flatbuffer
+from apps.amdshark_studio.web.utils.file_utils import (
     get_resource_path,
     get_checkpoints_path,
 )
-from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
-from apps.shark_studio.api.utils import parse_device
+from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
+from apps.amdshark_studio.api.utils import parse_device
 from urllib.request import urlopen
 import iree.runtime as ireert
 from itertools import chain
@@ -366,7 +366,7 @@ def get_mfma_spec_path(target_chip, save_dir):
 def llm_chat_api(InputData: dict):
     from datetime import datetime as dt
-    import apps.shark_studio.web.utils.globals as global_obj
+    import apps.amdshark_studio.web.utils.globals as global_obj
     print(f"Input keys : {InputData.keys()}")

View File

@@ -12,26 +12,26 @@ from tqdm.auto import tqdm
 from pathlib import Path
 from random import randint
-from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
+from turbine_models.custom_models.sd_inference.sd_pipeline import AMDSharkSDPipeline
 from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
-    SharkSDXLPipeline,
+    AMDSharkSDXLPipeline,
 )
-from apps.shark_studio.api.controlnet import control_adapter_map
-from apps.shark_studio.api.utils import parse_device
-from apps.shark_studio.web.utils.state import status_label
-from apps.shark_studio.web.utils.file_utils import (
+from apps.amdshark_studio.api.controlnet import control_adapter_map
+from apps.amdshark_studio.api.utils import parse_device
+from apps.amdshark_studio.web.utils.state import status_label
+from apps.amdshark_studio.web.utils.file_utils import (
     safe_name,
     get_resource_path,
     get_checkpoints_path,
 )
-from apps.shark_studio.modules.img_processing import (
+from apps.amdshark_studio.modules.img_processing import (
     save_output_img,
 )
-from apps.shark_studio.modules.ckpt_processing import (
+from apps.amdshark_studio.modules.ckpt_processing import (
     preprocessCKPT,
     save_irpa,
 )
@@ -114,10 +114,10 @@ class StableDiffusion:
             self.turbine_pipe = custom_module.StudioPipeline
             self.model_map = custom_module.MODEL_MAP
         elif self.is_sdxl:
-            self.turbine_pipe = SharkSDXLPipeline
+            self.turbine_pipe = AMDSharkSDXLPipeline
             self.model_map = EMPTY_SDXL_MAP
         else:
-            self.turbine_pipe = SharkSDPipeline
+            self.turbine_pipe = AMDSharkSDPipeline
             self.model_map = EMPTY_SD_MAP
         max_length = 64
         target_backend, self.rt_device, triple = parse_device(device, target_triple)
@@ -273,7 +273,7 @@ class StableDiffusion:
         return img
-def shark_sd_fn_dict_input(
+def amdshark_sd_fn_dict_input(
     sd_kwargs: dict,
 ):
     print("\n[LOG] Submitting Request...")
@@ -312,11 +312,11 @@ def shark_sd_fn_dict_input(
         )
         return None, ""
-    generated_imgs = yield from shark_sd_fn(**sd_kwargs)
+    generated_imgs = yield from amdshark_sd_fn(**sd_kwargs)
     return generated_imgs
-def shark_sd_fn(
+def amdshark_sd_fn(
     prompt,
     negative_prompt,
     sd_init_image: list,
@@ -346,8 +346,8 @@ def shark_sd_fn(
         sd_init_image = [sd_init_image]
     is_img2img = True if sd_init_image[0] is not None else False
-    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
-    import apps.shark_studio.web.utils.globals as global_obj
+    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
+    import apps.amdshark_studio.web.utils.globals as global_obj
     adapters = {}
     is_controlled = False
@@ -466,7 +466,7 @@ def shark_sd_fn(
 def unload_sd():
     print("Unloading models.")
-    import apps.shark_studio.web.utils.globals as global_obj
+    import apps.amdshark_studio.web.utils.globals as global_obj
     global_obj.clear_cache()
     gc.collect()
@@ -489,8 +489,8 @@ def safe_name(name):
 if __name__ == "__main__":
-    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
-    import apps.shark_studio.web.utils.globals as global_obj
+    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
+    import apps.amdshark_studio.web.utils.globals as global_obj
     global_obj._init()
@@ -501,5 +501,5 @@ if __name__ == "__main__":
     for arg in vars(cmd_opts):
         if arg in sd_kwargs:
             sd_kwargs[arg] = getattr(cmd_opts, arg)
-    for i in shark_sd_fn_dict_input(sd_kwargs):
+    for i in amdshark_sd_fn_dict_input(sd_kwargs):
         print(i)

View File

@@ -8,11 +8,11 @@ from random import (
 )
 from pathlib import Path
-from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
 from cpuinfo import get_cpu_info
 # TODO: migrate these utils to studio
-from shark.iree_utils.vulkan_utils import (
+from amdshark.iree_utils.vulkan_utils import (
     set_iree_vulkan_runtime_flags,
     get_vulkan_target_triple,
     get_iree_vulkan_runtime_flags,
@@ -21,7 +21,7 @@ from shark.iree_utils.vulkan_utils import (
 def get_available_devices():
     def get_devices_by_name(driver_name):
-        from shark.iree_utils._common import iree_device_map
+        from amdshark.iree_utils._common import iree_device_map
         device_list = []
         try:
@@ -59,7 +59,7 @@ def get_available_devices():
     cpu_device = get_devices_by_name("cpu-task")
     available_devices.extend(cpu_device)
-    from shark.iree_utils.vulkan_utils import (
+    from amdshark.iree_utils.vulkan_utils import (
         get_all_vulkan_devices,
     )
@@ -116,7 +116,7 @@ def set_init_device_flags():
     elif "metal" in cmd_opts.device:
         device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
         if not cmd_opts.iree_metal_target_platform:
-            from shark.iree_utils.metal_utils import get_metal_target_triple
+            from amdshark.iree_utils.metal_utils import get_metal_target_triple
             triple = get_metal_target_triple(device_name)
             if triple is not None:
@@ -146,7 +146,7 @@ def set_iree_runtime_flags():
 def parse_device(device_str, target_override=""):
-    from shark.iree_utils.compile_utils import (
+    from amdshark.iree_utils.compile_utils import (
         clean_device_info,
         get_iree_target_triple,
         iree_target_map,
@@ -192,7 +192,7 @@ def get_rocm_target_chip(device_str):
         if key in device_str:
             return rocm_chip_map[key]
     raise AssertionError(
-        f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK-Studio/issues."
+        f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/AMDSHARK-Studio/issues."
     )
@@ -225,7 +225,7 @@ def get_device_mapping(driver, key_combination=3):
        dict: map to possible device names user can input mapped to desired
        combination of name/path.
     """
-    from shark.iree_utils._common import iree_device_map
+    from amdshark.iree_utils._common import iree_device_map
     driver = iree_device_map(driver)
     device_list = get_all_devices(driver)
@@ -256,7 +256,7 @@ def get_opt_flags(model, precision="fp16"):
             f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
         )
     if "rocm" in cmd_opts.device:
-        from shark.iree_utils.gpu_utils import get_iree_rocm_args
+        from amdshark.iree_utils.gpu_utils import get_iree_rocm_args
         rocm_args = get_iree_rocm_args()
         iree_flags.extend(rocm_args)
@@ -301,7 +301,7 @@ def map_device_to_name_path(device, key_combination=3):
     return device_mapping
 def get_devices_by_name(driver_name):
-    from shark.iree_utils._common import iree_device_map
+    from amdshark.iree_utils._common import iree_device_map
     device_list = []
     try:
@@ -332,7 +332,7 @@ def map_device_to_name_path(device, key_combination=3):
     set_iree_runtime_flags()
     available_devices = []
-    from shark.iree_utils.vulkan_utils import (
+    from amdshark.iree_utils.vulkan_utils import (
         get_all_vulkan_devices,
     )

View File

@@ -12,7 +12,7 @@ from pathlib import Path
 from tqdm import tqdm
 from omegaconf import OmegaConf
 from diffusers import StableDiffusionPipeline
-from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
 from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
     download_from_original_stable_diffusion_ckpt,
     create_vae_diffusers_config,

View File

@@ -5,7 +5,7 @@ import json
 import safetensors
 from dataclasses import dataclass
 from safetensors.torch import load_file
-from apps.shark_studio.web.utils.file_utils import (
+from apps.amdshark_studio.web.utils.file_utils import (
     get_checkpoint_pathfile,
     get_path_stem,
 )

View File

@@ -25,11 +25,11 @@ resampler_list = resamplers.keys()
 # save output images and the inputs corresponding to it.
 def save_output_img(output_img, img_seed, extra_info=None):
-    from apps.shark_studio.web.utils.file_utils import (
+    from apps.amdshark_studio.web.utils.file_utils import (
         get_generated_imgs_path,
         get_generated_imgs_todays_subdir,
     )
-    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
     if extra_info is None:
         extra_info = {}

View File

@@ -30,8 +30,8 @@ def logger_test(x):
 def read_sd_logs():
     sys.stdout.flush()
-    with open("shark_tmp/sd.log", "r") as f:
+    with open("amdshark_tmp/sd.log", "r") as f:
         return f.read()
-sys.stdout = Logger("shark_tmp/sd.log", filter="[LOG]")
+sys.stdout = Logger("amdshark_tmp/sd.log", filter="[LOG]")

View File

@@ -1,14 +1,14 @@
-from shark.iree_utils.compile_utils import (
+from amdshark.iree_utils.compile_utils import (
     get_iree_compiled_module,
     load_vmfb_using_mmap,
     clean_device_info,
     get_iree_target_triple,
 )
-from apps.shark_studio.web.utils.file_utils import (
+from apps.amdshark_studio.web.utils.file_utils import (
     get_checkpoints_path,
     get_resource_path,
 )
-from apps.shark_studio.modules.shared_cmd_opts import (
+from apps.amdshark_studio.modules.shared_cmd_opts import (
     cmd_opts,
 )
 from iree import runtime as ireert
@@ -17,7 +17,7 @@ import gc
 import os
-class SharkPipelineBase:
+class AMDSharkPipelineBase:
     # This class is a lightweight base for managing an
     # inference API class. It should provide methods for:
     # - compiling a set (model map) of torch IR modules

View File

@@ -224,7 +224,7 @@ def get_unweighted_text_embeddings(
             text_embedding = text_embedding[:, 1:-1]
             text_embeddings.append(text_embedding)
-        # SHARK: Convert the result to tensor
+        # AMDSHARK: Convert the result to tensor
         # text_embeddings = torch.concat(text_embeddings, axis=1)
         text_embeddings_np = np.concatenate(np.array(text_embeddings))
         text_embeddings = torch.from_numpy(text_embeddings_np)

View File

@@ -1,4 +1,4 @@
-# from shark_turbine.turbine_models.schedulers import export_scheduler_model
+# from amdshark_turbine.turbine_models.schedulers import export_scheduler_model
 from diffusers import (
     LCMScheduler,
     LMSDiscreteScheduler,

View File

@@ -2,7 +2,7 @@ import argparse
 import os
 from pathlib import Path
-from apps.shark_studio.modules.img_processing import resampler_list
+from apps.amdshark_studio.modules.img_processing import resampler_list
 def path_expand(s):
@@ -299,8 +299,8 @@ p.add_argument(
     "--import_mlir",
     default=True,
     action=argparse.BooleanOptionalAction,
-    help="Imports the model from torch module to shark_module otherwise "
-    "downloads the model from shark_tank.",
+    help="Imports the model from torch module to amdshark_module otherwise "
+    "downloads the model from amdshark_tank.",
 )
 p.add_argument(
@@ -487,8 +487,8 @@ p.add_argument(
 p.add_argument(
     "--local_tank_cache",
     default="",
-    help="Specify where to save downloaded shark_tank artifacts. "
-    "If this is not set, the default is ~/.local/shark_tank/.",
+    help="Specify where to save downloaded amdshark_tank artifacts. "
+    "If this is not set, the default is ~/.local/amdshark_tank/.",
 )
 p.add_argument(
@@ -562,7 +562,7 @@ p.add_argument(
     default=False,
     action=argparse.BooleanOptionalAction,
     help="If import_mlir is True, saves mlir via the debug option "
-    "in shark importer. Does nothing if import_mlir is false (the default).",
+    "in amdshark importer. Does nothing if import_mlir is false (the default).",
 )
 p.add_argument(
@@ -615,7 +615,7 @@ p.add_argument(
 p.add_argument(
     "--tmp_dir",
     type=str,
-    default=os.path.join(os.getcwd(), "shark_tmp"),
+    default=os.path.join(os.getcwd(), "amdshark_tmp"),
     help="Path to tmp directory",
 )

View File

@@ -38,7 +38,7 @@ datas += collect_data_files("transformers")
 datas += collect_data_files("gradio")
 datas += collect_data_files("gradio_client")
 datas += collect_data_files("iree", include_py_files=True)
-datas += collect_data_files("shark", include_py_files=True)
+datas += collect_data_files("amdshark", include_py_files=True)
 datas += collect_data_files("tqdm")
 datas += collect_data_files("tkinter")
 datas += collect_data_files("sentencepiece")
@@ -54,7 +54,7 @@ datas += [
 # hidden imports for pyinstaller
-hiddenimports = ["shark", "apps"]
+hiddenimports = ["amdshark", "apps"]
 hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
 hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x]
 blacklist = ["tests", "convert"]

View File

@@ -8,14 +8,14 @@ import logging
 import unittest
 import json
 import gc
-from apps.shark_studio.api.llm import LanguageModel, llm_chat_api
-from apps.shark_studio.api.sd import shark_sd_fn_dict_input, view_json_file
-from apps.shark_studio.web.utils.file_utils import get_resource_path
+from apps.amdshark_studio.api.llm import LanguageModel, llm_chat_api
+from apps.amdshark_studio.api.sd import amdshark_sd_fn_dict_input, view_json_file
+from apps.amdshark_studio.web.utils.file_utils import get_resource_path
 # class SDAPITest(unittest.TestCase):
 # def testSDSimple(self):
-# from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
-# import apps.shark_studio.web.utils.globals as global_obj
+# from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
+# import apps.amdshark_studio.web.utils.globals as global_obj
 # global_obj._init()
@@ -24,7 +24,7 @@ from apps.shark_studio.web.utils.file_utils import get_resource_path
 # for arg in vars(cmd_opts):
 # if arg in sd_kwargs:
 # sd_kwargs[arg] = getattr(cmd_opts, arg)
-# for i in shark_sd_fn_dict_input(sd_kwargs):
+# for i in amdshark_sd_fn_dict_input(sd_kwargs):
 # print(i)

View File

[Binary image file changed: 347 KiB before, 347 KiB after]

View File

@@ -38,8 +38,8 @@ def llm_chat_test(verbose=False):
 if __name__ == "__main__":
-    # "Exercises the chatbot REST API of Shark. Make sure "
-    # "Shark is running in API mode on 127.0.0.1:8080 before running"
+    # "Exercises the chatbot REST API of AMDShark. Make sure "
+    # "AMDShark is running in API mode on 127.0.0.1:8080 before running"
     # "this script."
     llm_chat_test(verbose=True)

View File

@@ -18,10 +18,10 @@ from fastapi.exceptions import HTTPException
 from fastapi.responses import JSONResponse
 from fastapi.encoders import jsonable_encoder
-from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
-# from sdapi_v1 import shark_sd_api
-from apps.shark_studio.api.llm import llm_chat_api
+# from sdapi_v1 import amdshark_sd_api
+from apps.amdshark_studio.api.llm import llm_chat_api
 def decode_base64_to_image(encoding):
@@ -183,8 +183,8 @@ class ApiCompat:
         self.app = app
         self.queue_lock = queue_lock
         api_middleware(self.app)
-        # self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["POST"])
+        # self.add_api_route("/sdapi/v1/txt2img", amdshark_sd_api, methods=["POST"])
+        # self.add_api_route("/sdapi/v1/img2img", amdshark_sd_api, methods=["POST"])
         # self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["POST"])
         # self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
         # self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)

View File

@@ -7,10 +7,10 @@ import os
 import time
 import sys
 import logging
-import apps.shark_studio.api.initializers as initialize
+import apps.amdshark_studio.api.initializers as initialize
-from apps.shark_studio.modules import timer
+from apps.amdshark_studio.modules import timer
 startup_timer = timer.startup_timer
 startup_timer.record("launcher")
@@ -24,7 +24,7 @@ if sys.platform == "darwin":
 def create_api(app):
-    from apps.shark_studio.web.api.compat import ApiCompat, FIFOLock
+    from apps.amdshark_studio.web.api.compat import ApiCompat, FIFOLock
     queue_lock = FIFOLock()
     api = ApiCompat(app, queue_lock)
@@ -33,7 +33,7 @@ def create_api(app):
 def api_only():
     from fastapi import FastAPI
-    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
     initialize.initialize()
@@ -64,7 +64,7 @@ def launch_webui(address):
     width = int(window.winfo_screenwidth() * 0.81)
     height = int(window.winfo_screenheight() * 0.91)
     webview.create_window(
-        "SHARK AI Studio",
+        "AMDSHARK AI Studio",
         url=address,
         width=width,
         height=height,
@@ -74,8 +74,8 @@ def launch_webui(address):
 def webui():
-    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
-    from apps.shark_studio.web.ui.utils import (
+    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
+    from apps.amdshark_studio.web.ui.utils import (
         amdicon_loc,
         amdlogo_loc,
     )
@@ -91,10 +91,10 @@ def webui():
     freeze_support()
     # if args.api or "api" in args.ui.split(","):
-    # from apps.shark_studio.api.llm import (
+    # from apps.amdshark_studio.api.llm import (
     # chat,
     # )
-    # from apps.shark_studio.web.api import sdapi
+    # from apps.amdshark_studio.web.api import sdapi
     #
     # from fastapi import FastAPI, APIRouter
     # from fastapi.middleware.cors import CORSMiddleware
@@ -144,7 +144,7 @@ def webui():
     dark_theme = resource_path("ui/css/sd_dark_theme.css")
     gradio_workarounds = resource_path("ui/js/sd_gradio_workarounds.js")
-    # from apps.shark_studio.web.ui import load_ui_from_script
+    # from apps.amdshark_studio.web.ui import load_ui_from_script
     def register_button_click(button, selectedid, inputs, outputs):
         button.click(
@@ -170,7 +170,7 @@ def webui():
         css=dark_theme,
         js=gradio_workarounds,
         analytics_enabled=False,
-        title="Shark Studio 2.0 Beta",
+        title="AMDShark Studio 2.0 Beta",
     ) as studio_web:
         amd_logo = Image.open(amdlogo_loc)
         gr.Image(
@@ -214,7 +214,7 @@ def webui():
 if __name__ == "__main__":
-    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
     if cmd_opts.webui == False:
         api_only()

View File

@@ -5,12 +5,12 @@ from pathlib import Path
 from datetime import datetime as dt
 import json
 import sys
-from apps.shark_studio.api.llm import (
+from apps.amdshark_studio.api.llm import (
     llm_model_map,
     LanguageModel,
 )
-from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
-import apps.shark_studio.web.utils.globals as global_obj
+from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
+import apps.amdshark_studio.web.utils.globals as global_obj
 B_SYS, E_SYS = "<s>", "</s>"
@@ -129,7 +129,7 @@ with gr.Blocks(title="Chat") as chat_element:
             tokens_time = gr.Textbox(label="Tokens generated per second")
         with gr.Column():
             download_vmfb = gr.Checkbox(
-                label="Download vmfb from Shark tank if available",
+                label="Download vmfb from AMDShark tank if available",
                 value=False,
                 interactive=True,
                 visible=False,

View File

@@ -1,8 +1,8 @@
-from apps.shark_studio.web.ui.utils import (
+from apps.amdshark_studio.web.ui.utils import (
     HSLHue,
     hsl_color,
 )
-from apps.shark_studio.modules.embeddings import get_lora_metadata
+from apps.amdshark_studio.modules.embeddings import get_lora_metadata
 # Answers HTML to show the most frequent tags used when a LoRA was trained,

View File

@@ -100,7 +100,7 @@ Procedure to upgrade the dark theme:
     --input-border-width: 1px;
 }
-/* SHARK theme */
+/* AMDSHARK theme */
 body {
     background-color: var(--background-fill-primary);
 }

View File

[Binary image file changed: 7.1 KiB before, 7.1 KiB after]

View File

[Binary image file changed: 7.4 KiB before, 7.4 KiB after]

View File

@@ -5,13 +5,13 @@ import subprocess
 import sys
 from PIL import Image
-from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
-from apps.shark_studio.web.utils.file_utils import (
+from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
+from apps.amdshark_studio.web.utils.file_utils import (
     get_generated_imgs_path,
     get_generated_imgs_todays_subdir,
 )
-from apps.shark_studio.web.ui.utils import amdlogo_loc
-from apps.shark_studio.web.utils.metadata import displayable_metadata
+from apps.amdshark_studio.web.ui.utils import amdlogo_loc
+from apps.amdshark_studio.web.utils.metadata import displayable_metadata
 # -- Functions for file, directory and image info querying

View File

@@ -9,40 +9,40 @@ from datetime import datetime as dt
 from gradio.components.image_editor import (
     EditorValue,
 )
-from apps.shark_studio.web.utils.file_utils import (
+from apps.amdshark_studio.web.utils.file_utils import (
     get_generated_imgs_path,
     get_checkpoints_path,
     get_checkpoints,
     get_configs_path,
     write_default_sd_configs,
 )
-from apps.shark_studio.api.sd import (
-    shark_sd_fn_dict_input,
+from apps.amdshark_studio.api.sd import (
+    amdshark_sd_fn_dict_input,
     cancel_sd,
     unload_sd,
 )
-from apps.shark_studio.api.controlnet import (
+from apps.amdshark_studio.api.controlnet import (
     cnet_preview,
 )
-from apps.shark_studio.modules.schedulers import (
+from apps.amdshark_studio.modules.schedulers import (
     scheduler_model_map,
 )
-from apps.shark_studio.modules.img_processing import (
+from apps.amdshark_studio.modules.img_processing import (
     resampler_list,
     resize_stencil,
 )
-from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
-from apps.shark_studio.web.ui.utils import (
+from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
+from apps.amdshark_studio.web.ui.utils import (
     amdlogo_loc,
     none_to_str_none,
     str_none_to_none,
 )
-from apps.shark_studio.web.utils.state import (
+from apps.amdshark_studio.web.utils.state import (
     status_label,
 )
-from apps.shark_studio.web.ui.common_events import lora_changed
-from apps.shark_studio.modules import logger
-import apps.shark_studio.web.utils.globals as global_obj
+from apps.amdshark_studio.web.ui.common_events import lora_changed
+from apps.amdshark_studio.modules import logger
+import apps.amdshark_studio.web.utils.globals as global_obj
 sd_default_models = [
     "runwayml/stable-diffusion-v1-5",
@@ -758,7 +758,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
             )
             gen_kwargs = dict(
-                fn=shark_sd_fn_dict_input,
+                fn=amdshark_sd_fn_dict_input,
                 inputs=[sd_json],
                 outputs=[
                     sd_gallery,

View File

@@ -4,14 +4,14 @@ import glob
 from datetime import datetime as dt
 from pathlib import Path
-from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
 checkpoints_filetypes = (
     "*.ckpt",
     "*.safetensors",
 )
-from apps.shark_studio.web.utils.default_configs import default_sd_configs
+from apps.amdshark_studio.web.utils.default_configs import default_sd_configs
 def write_default_sd_configs(path):

View File

@@ -1,4 +1,4 @@
-# As SHARK has evolved more columns have been added to images_details.csv. However, since
+# As AMDSHARK has evolved more columns have been added to images_details.csv. However, since
 # no version of the CSV has any headers (yet) we don't actually have anything within the
 # file that tells us which parameter each column is for. So this is a list of known patterns
 # indexed by length which is what we're going to have to use to guess which columns are the

View File

@@ -1,11 +1,11 @@
 import re
 from pathlib import Path
-from apps.shark_studio.web.utils.file_utils import (
+from apps.amdshark_studio.web.utils.file_utils import (
     get_checkpoint_pathfile,
 )
-from apps.shark_studio.api.sd import EMPTY_SD_MAP as sd_model_map
+from apps.amdshark_studio.api.sd import EMPTY_SD_MAP as sd_model_map
-from apps.shark_studio.modules.schedulers import (
+from apps.amdshark_studio.modules.schedulers import (
     scheduler_model_map,
 )

View File

@@ -1,4 +1,4 @@
-import apps.shark_studio.web.utils.globals as global_obj
+import apps.amdshark_studio.web.utils.globals as global_obj
 import gc

View File

@@ -2,9 +2,9 @@ import os
 import shutil
 from time import time
-from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
-shark_tmp = cmd_opts.tmp_dir # os.path.join(os.getcwd(), "shark_tmp/")
+amdshark_tmp = cmd_opts.tmp_dir # os.path.join(os.getcwd(), "amdshark_tmp/")
 def clear_tmp_mlir():
@@ -12,20 +12,20 @@ def clear_tmp_mlir():
     print("Clearing .mlir temporary files from a prior run. This may take some time...")
     mlir_files = [
         filename
-        for filename in os.listdir(shark_tmp)
-        if os.path.isfile(os.path.join(shark_tmp, filename))
+        for filename in os.listdir(amdshark_tmp)
+        if os.path.isfile(os.path.join(amdshark_tmp, filename))
         and filename.endswith(".mlir")
     ]
     for filename in mlir_files:
-        os.remove(os.path.join(shark_tmp, filename))
+        os.remove(os.path.join(amdshark_tmp, filename))
     print(f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds.")
 def clear_tmp_imgs():
-    # tell gradio to use a directory under shark_tmp for its temporary
+    # tell gradio to use a directory under amdshark_tmp for its temporary
     # image files unless somewhere else has been set
     if "GRADIO_TEMP_DIR" not in os.environ:
-        os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
+        os.environ["GRADIO_TEMP_DIR"] = os.path.join(amdshark_tmp, "gradio")
         print(
             f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
@@ -43,22 +43,22 @@ def clear_tmp_imgs():
             f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
         )
-    # older SHARK versions had to workaround gradio bugs and stored things differently
+    # older AMDSHARK versions had to workaround gradio bugs and stored things differently
     else:
         image_files = [
             filename
-            for filename in os.listdir(shark_tmp)
-            if os.path.isfile(os.path.join(shark_tmp, filename))
+            for filename in os.listdir(amdshark_tmp)
+            if os.path.isfile(os.path.join(amdshark_tmp, filename))
            and filename.startswith("tmp")
            and filename.endswith(".png")
        ]
        if len(image_files) > 0:
            print(
-                "Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
+                "Clearing temporary image files of a prior run of a previous AMDSHARK version. This may take some time..."
            )
            cleanup_start = time()
            for filename in image_files:
-                os.remove(shark_tmp + filename)
+                os.remove(amdshark_tmp + filename)
            print(
                f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
            )
@@ -67,9 +67,9 @@ def clear_tmp_imgs():
 def config_tmp():
-    # create shark_tmp if it does not exist
-    if not os.path.exists(shark_tmp):
-        os.mkdir(shark_tmp)
+    # create amdshark_tmp if it does not exist
+    if not os.path.exists(amdshark_tmp):
+        os.mkdir(amdshark_tmp)
     clear_tmp_mlir()
     clear_tmp_imgs()