Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git
Commit 2
@@ -13,32 +13,45 @@ from apps.shark_studio.modules.timer import startup_timer
 def imports():
     import torch  # noqa: F401

     startup_timer.record("import torch")

-    warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="torch")
-    warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
+    warnings.filterwarnings(
+        action="ignore", category=DeprecationWarning, module="torch"
+    )
+    warnings.filterwarnings(
+        action="ignore", category=UserWarning, module="torchvision"
+    )

     import gradio  # noqa: F401

     startup_timer.record("import gradio")

-    #from apps.shark_studio.modules import shared_init
-    #shared_init.initialize()
-    #startup_timer.record("initialize shared")
+    # from apps.shark_studio.modules import shared_init
+    # shared_init.initialize()
+    # startup_timer.record("initialize shared")

-    from apps.shark_studio.modules import processing, gradio_extensons, ui  # noqa: F401
+    from apps.shark_studio.modules import (
+        processing,
+        gradio_extensons,
+        ui,
+    )  # noqa: F401
     startup_timer.record("other imports")


 def initialize():
     configure_sigint_handler()
     configure_opts_onchange()

-    #from apps.shark_studio.modules import modelloader
-    #modelloader.cleanup_models()
+    # from apps.shark_studio.modules import modelloader
+    # modelloader.cleanup_models()

-    #from apps.shark_studio.modules import sd_models
-    #sd_models.setup_model()
-    #startup_timer.record("setup SD model")
+    # from apps.shark_studio.modules import sd_models
+    # sd_models.setup_model()
+    # startup_timer.record("setup SD model")

-    #initialize_rest(reload_script_modules=False)
+    # initialize_rest(reload_script_modules=False)


 def initialize_rest(*, reload_script_modules=False):
     """
@@ -46,6 +59,7 @@ def initialize_rest(*, reload_script_modules=False):
     """
     # Keep this for adding reload options to the webUI.


 def dumpstacks():
     import threading
     import traceback
@@ -65,12 +79,10 @@ def dumpstacks():
 def configure_sigint_handler():
     # make the program just exit at ctrl+c without waiting for anything
     def sigint_handler(sig, frame):
-        print(f'Interrupted with signal {sig} in {frame}')
+        print(f"Interrupted with signal {sig} in {frame}")

         dumpstacks()

         os._exit(0)

     signal.signal(signal.SIGINT, sigint_handler)
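The configure_sigint_handler hunk above installs a Ctrl+C handler that dumps every live thread's stack and then hard-exits. A self-contained sketch of the same pattern, with a minimal dumpstacks assumed to behave like the one referenced above (the repo's own implementation is not shown in this hunk):

import os
import signal
import sys
import threading
import traceback


def dumpstacks():
    # Print a stack trace for every live thread, similar in spirit to
    # the dumpstacks() referenced in the diff above.
    id2name = {t.ident: t.name for t in threading.enumerate()}
    for thread_id, frame in sys._current_frames().items():
        print(f"\n# Thread: {id2name.get(thread_id, '')}({thread_id})")
        traceback.print_stack(frame)


def sigint_handler(sig, frame):
    print(f"Interrupted with signal {sig} in {frame}")
    dumpstacks()
    # os._exit skips atexit handlers and thread joins, so Ctrl+C
    # terminates immediately instead of waiting for worker threads.
    os._exit(0)


signal.signal(signal.SIGINT, sigint_handler)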
@@ -1,20 +1,30 @@
-#from shark_turbine.turbine_models.schedulers import export_scheduler_model
+# from shark_turbine.turbine_models.schedulers import export_scheduler_model


 def export_scheduler_model(model):
     return "None", "None"


 scheduler_model_map = {
     "EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
-    "EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),
+    "EulerAncestralDiscrete": export_scheduler_model(
+        "EulerAncestralDiscreteScheduler"
+    ),
     "LCM": export_scheduler_model("LCMScheduler"),
     "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
     "PNDM": export_scheduler_model("PNDMScheduler"),
     "DDPM": export_scheduler_model("DDPMScheduler"),
     "DDIM": export_scheduler_model("DDIMScheduler"),
-    "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
+    "DPMSolverMultistep": export_scheduler_model(
+        "DPMSolverMultistepScheduler"
+    ),
     "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
     "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
-    "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
-    "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
+    "DPMSolverSinglestep": export_scheduler_model(
+        "DPMSolverSingleStepScheduler"
+    ),
+    "KDPM2AncestralDiscrete": export_scheduler_model(
+        "KDPM2AncestralDiscreteScheduler"
+    ),
     "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
 }
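Because export_scheduler_model is stubbed to return ("None", "None"), every entry in scheduler_model_map currently holds that placeholder, and the map is built once at import time. A sketch of how a caller might resolve a scheduler by name (the helper itself is hypothetical, not in this diff):

def get_scheduler(name: str):
    # Hypothetical lookup helper; with the stub above every entry is the
    # placeholder ("None", "None") until the shark_turbine export returns.
    try:
        return scheduler_model_map[name]
    except KeyError as e:
        raise ValueError(
            f"unknown scheduler '{name}', expected one of: "
            + ", ".join(scheduler_model_map)
        ) from e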
@@ -6,97 +6,128 @@ import gc
 import torch

 sd_model_map = {
-    "sd15": {
-        "base_model_id": "runwayml/stable-diffusion-v1-5"
-        "clip": {
-            "initializer": clip.export_clip_model,
-            "max_tokens": 77,
-        }
-        "unet": {
-            "initializer": unet.export_unet_model,
-            "max_tokens": 512,
-        }
-        "vae_decode": {
-            "initializer": vae.export_vae_model,
-        }
-    }
-}
+    "CompVis/stable-diffusion-v1-4": {
+        "clip": {
+            "initializer": clip.export_clip_model,
+            "max_tokens": 64,
+        },
+        "vae_encode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+        "unet": {
+            "initializer": unet.export_unet_model,
+            "max_tokens": 512,
+        },
+        "vae_decode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+    },
+    "runwayml/stable-diffusion-v1-5": {
+        "clip": {
+            "initializer": clip.export_clip_model,
+            "max_tokens": 64,
+        },
+        "vae_encode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+        "unet": {
+            "initializer": unet.export_unet_model,
+            "max_tokens": 512,
+        },
+        "vae_decode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+    },
+    "stabilityai/stable-diffusion-2-1-base": {
+        "clip": {
+            "initializer": clip.export_clip_model,
+            "max_tokens": 64,
+        },
+        "vae_encode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+        "unet": {
+            "initializer": unet.export_unet_model,
+            "max_tokens": 512,
+        },
+        "vae_decode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+    },
+    "stabilityai/stable_diffusion-xl-1.0": {
+        "clip_1": {
+            "initializer": clip.export_clip_model,
+            "max_tokens": 64,
+        },
+        "clip_2": {
+            "initializer": clip.export_clip_model,
+            "max_tokens": 64,
+        },
+        "vae_encode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+        "unet": {
+            "initializer": unet.export_unet_model,
+            "max_tokens": 512,
+        },
+        "vae_decode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+    },
+}


 class SharkStableDiffusionPipeline:
     def __init__(
         self, model_name, device=None, precision="fp32"
     ):
         print(sd_model_map[model_name])
         self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
         self.torch_ir, self.tokenizer = llm_model_map[model_name][
             "initializer"
         ](self.hf_model_name, hf_auth_token, compile_to="torch")
         self.tempfile_name = get_resource_path("llm.torch.tempfile")
         with open(self.tempfile_name, "w+") as f:
             f.write(self.torch_ir)
         del self.torch_ir
         gc.collect()


 class StableDiffusion(SharkPipelineBase):
     # This class is responsible for executing image generation and creating
     # /managing a set of compiled modules to run Stable Diffusion. The init
     # aims to be as general as possible, and the class will infer and compile
     # a list of necessary modules or a combined "pipeline module" for a
     # specified job based on the inference task.
     #
     # custom_model_ids: a dict of submodel + HF ID pairs for custom submodels.
     # e.g. {"vae_decode": "madebyollin/sdxl-vae-fp16-fix"}
     #
     # embeddings: a dict of embedding checkpoints or model IDs to use when
     # initializing the compiled modules.

     def __init__(
         self,
         base_model_id: str = "runwayml/stable-diffusion-v1-5",
         height: int = 512,
         width: int = 512,
         precision: str = "fp16",
         device: str = None,
         custom_model_map: dict = {},
         custom_weights_map: dict = {},
         embeddings: dict = {},
         import_ir: bool = True,
     ):
         super().__init__(sd_model_map[base_model_id], device, import_ir)
         self.base_model_id = base_model_id
         self.device = device
         self.precision = precision
         self.max_tokens = llm_model_map[model_name]["max_tokens"]
         self.iree_module_dict = None
         self.compile()
         self.get_compiled_map()

     def compile(self) -> None:
         # this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
         self.iree_module_dict = get_iree_compiled_module(
             self.tempfile_name, device=self.device, frontend="torch"
         )
         # TODO: delete the temp file

     def generate_images(
         self,
         prompt,
     ):
         history = []
         for iter in range(self.max_tokens):
             input_tensor = self.tokenizer(
                 prompt, return_tensors="pt"
             ).input_ids
             device_inputs = [
                 ireert.asdevicearray(
                     self.iree_module_dict["config"], input_tensor
                 )
             ]
             if iter == 0:
                 token = torch.tensor(
                     self.iree_module_dict["vmfb"]["run_initialize"](
                         *device_inputs
                     ).to_host()[0][0]
                 )
             else:
                 token = torch.tensor(
                     self.iree_module_dict["vmfb"]["run_forward"](
                         *device_inputs
                     ).to_host()[0][0]
                 )

             history.append(token)
             yield self.tokenizer.decode(history)

             if token == llm_model_map["llama2_7b"]["stop_token"]:
                 break

         for i in range(len(history)):
             if type(history[i]) != int:
                 history[i] = int(history[i])
         result_output = self.tokenizer.decode(history)
         yield result_output

         return result_output


 if __name__ == "__main__":
-    lm = LanguageModel(
-        "llama2_7b",
-        hf_auth_token="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",
-        device="cpu-task",
+    sd = StableDiffusion(
+        "runwayml/stable-diffusion-v1-5",
+        device="vulkan",
     )
     print("model loaded")
     for i in lm.chat("Hello, I am a robot."):
         print(i)
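Going by the class comment, constructing the new StableDiffusion pipeline might look like the following sketch. The kwargs come straight from the __init__ signature in this diff, and the fp16-safe VAE override is the example the comment itself gives; note the comment says custom_model_ids while the signature takes custom_model_map, so the sketch follows the signature. This assumes SharkPipelineBase and the clip/unet/vae initializers are importable and complete, which the diff shows is still work in progress:

sd = StableDiffusion(
    base_model_id="stabilityai/stable-diffusion-2-1-base",
    height=512,
    width=512,
    precision="fp16",
    device="vulkan",
    # per the class comment's example, swap in the fp16-safe VAE:
    custom_model_map={"vae_decode": "madebyollin/sdxl-vae-fp16-fix"},
    import_ir=True,
)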
@@ -1,6 +1,7 @@
 import torch
 from safetensors.torch import load_file


 def processLoRA(model, use_lora, splitting_prefix):
     state_dict = ""
     if ".safetensors" in use_lora:
@@ -108,4 +109,3 @@ def update_lora_weight(model, use_lora, model_name):
         return processLoRA(model, use_lora, "lora_te_")
     except:
         return None
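For reference, load_file from safetensors.torch returns a flat dict of tensor name to torch.Tensor, which is why processLoRA can filter keys by a prefix such as the "lora_te_" that update_lora_weight passes in. A minimal sketch (the checkpoint path is illustrative):

from safetensors.torch import load_file

# load_file returns a flat dict of tensor name -> torch.Tensor; LoRA
# checkpoints prefix text-encoder weights with "lora_te_", which is the
# splitting_prefix update_lora_weight passes to processLoRA.
state_dict = load_file("my_lora.safetensors")  # illustrative path
te_keys = [k for k in state_dict if k.startswith("lora_te_")]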
@@ -2,8 +2,24 @@ import sys

 import gradio as gr

-from modules import shared_cmd_options, shared_gradio, options, shared_items, sd_models_types
-from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir  # noqa: F401
+from modules import (
+    shared_cmd_options,
+    shared_gradio,
+    options,
+    shared_items,
+    sd_models_types,
+)
+from modules.paths_internal import (
+    models_path,
+    script_path,
+    data_path,
+    sd_configs_path,
+    sd_default_config,
+    sd_model_file,
+    default_sd_model_file,
+    extensions_dir,
+    extensions_builtin_dir,
+)  # noqa: F401
 from modules import util

 cmd_opts = shared_cmd_options.cmd_opts
@@ -11,7 +11,9 @@ class TimerSubcategory:

     def __enter__(self):
         self.start = time.time()
-        self.timer.base_category = self.original_base_category + self.category + "/"
+        self.timer.base_category = (
+            self.original_base_category + self.category + "/"
+        )
         self.timer.subcategory_level += 1

         if self.timer.print_log:
@@ -20,7 +22,10 @@ class TimerSubcategory:
     def __exit__(self, exc_type, exc_val, exc_tb):
         elapsed_for_subcategroy = time.time() - self.start
         self.timer.base_category = self.original_base_category
-        self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy)
+        self.timer.add_time_to_record(
+            self.original_base_category + self.category,
+            elapsed_for_subcategroy,
+        )
         self.timer.subcategory_level -= 1
         self.timer.record(self.category, disable_log=True)
@@ -30,7 +35,7 @@ class Timer:
         self.start = time.time()
         self.records = {}
         self.total = 0
-        self.base_category = ''
+        self.base_category = ""
         self.print_log = print_log
         self.subcategory_level = 0
@@ -54,7 +59,9 @@ class Timer:
         self.total += e + extra_time

         if self.print_log and not disable_log:
-            print(f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s")
+            print(
+                f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s"
+            )

     def subcategory(self, name):
         self.elapsed()
@@ -65,25 +72,38 @@ class Timer:
     def summary(self):
         res = f"{self.total:.1f}s"

-        additions = [(category, time_taken) for category, time_taken in self.records.items() if time_taken >= 0.1 and '/' not in category]
+        additions = [
+            (category, time_taken)
+            for category, time_taken in self.records.items()
+            if time_taken >= 0.1 and "/" not in category
+        ]
         if not additions:
             return res

         res += " ("
-        res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
+        res += ", ".join(
+            [
+                f"{category}: {time_taken:.1f}s"
+                for category, time_taken in additions
+            ]
+        )
         res += ")"

         return res

     def dump(self):
-        return {'total': self.total, 'records': self.records}
+        return {"total": self.total, "records": self.records}

     def reset(self):
         self.__init__()


 parser = argparse.ArgumentParser(add_help=False)
-parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup")
+parser.add_argument(
+    "--log-startup",
+    action="store_true",
+    help="print a detailed log of what's happening at startup",
+)
 args = parser.parse_known_args()[0]

 startup_timer = Timer(print_log=args.log_startup)
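Taken together, the Timer hunks describe a small nesting profiler: record() attributes the time elapsed since the previous call to a named category, and subcategory() is a context manager that prefixes records made inside it with "parent/". A usage sketch based only on the methods visible in this diff, with sleeps standing in for real work:

import time

timer = Timer(print_log=True)

time.sleep(0.2)               # stand-in for real startup work
timer.record("import torch")  # logs "import torch: done in 0.2xxs"

with timer.subcategory("load model"):
    time.sleep(0.3)           # stand-in for loading weights
    timer.record("weights")   # recorded under "load model/weights"

# summary() gives the total plus top-level records >= 0.1s, e.g.
# "0.5s (import torch: 0.2s, load model: 0.3s)".
print(timer.summary())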
apps/shark_studio/web/api/sd.py (new file, 1 line)
@@ -0,0 +1 @@
@@ -20,6 +20,7 @@ if sys.platform == "darwin":
     # import before IREE to avoid MLIR library issues
     import torch_mlir


 def create_api(app):
     from apps.shark_studio.api.compat import ApiCompat
     from modules.call_queue import queue_lock
@@ -27,6 +28,7 @@ def create_api(app):
     api = ApiCompat(app, queue_lock)
     return api


 def api_only():
     from fastapi import FastAPI
     from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
@@ -37,17 +39,17 @@ def api_only():
     initialize.setup_middleware(app)
     api = create_api(app)

-    #from modules import script_callbacks
-    #script_callbacks.before_ui_callback()
-    #script_callbacks.app_started_callback(None, app)
+    # from modules import script_callbacks
+    # script_callbacks.before_ui_callback()
+    # script_callbacks.app_started_callback(None, app)

     print(f"Startup time: {startup_timer.summary()}.")
     api.launch(
         server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
         port=cmd_opts.port if cmd_opts.port else 8080,
-        root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""
+        root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
     )


 def launch_webui(address):
     from tkinter import Tk
@@ -68,11 +70,12 @@ def launch_webui(address):
     )
     webview.start(private_mode=False, storage_path=os.getcwd())


 def webui():
     from apps.shark_studio.shared_cmd_options import cmd_opts

     logging.basicConfig(level=logging.DEBUG)

     launch_api = cmd_opts.api
     initialize.initialize()

@@ -81,7 +84,6 @@ def webui():
     # required to do multiprocessing in a pyinstaller freeze
     freeze_support()

     # if args.api or "api" in args.ui.split(","):
     #     from apps.shark_studio.api.llm import (
     #         chat,
@@ -94,7 +96,7 @@ def webui():
     #
     #     # init global sd pipeline and config
     #     global_obj._init()
     #
     #
     #     api = FastAPI()
     #     api.mount("/sdapi/", sdapi)
     #
@@ -123,15 +125,15 @@ def webui():
     #         )
     #     else:
     #         print("API not configured for CORS")
     #
     #
     #     uvicorn.run(api, host="0.0.0.0", port=args.server_port)
-    # sys.exit(0)
+    #     sys.exit(0)
     # Setup to use shark_tmp for gradio's temporary image files and clear any
     # existing temporary images there if they exist. Then we can import gradio.
     # It has to be in this order or gradio ignores what we've set up.
     from apps.shark_studio.web.initializers import (
         config_gradio_tmp_imgs_folder,
         create_custom_models_folders,
     )

     config_gradio_tmp_imgs_folder()
@@ -161,6 +163,7 @@ def webui():
         inputs,
         outputs,
     )

     def register_outputgallery_button(button, selectedid, inputs, outputs):
         button.click(
             lambda x: (
@@ -191,7 +194,6 @@ def webui():
         with gr.TabItem(label="Chat Bot", id=2):
             chat_element.render()

     studio_web.queue()
     # if args.ui == "app":
     #     t = Process(
@@ -204,6 +206,8 @@ def webui():
         server_name="0.0.0.0",
         server_port=11911,  # args.server_port,
     )


 if __name__ == "__main__":
     from apps.shark_studio.shared_cmd_options import cmd_opts
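The comment about setting up shark_tmp before importing gradio matches gradio's documented behavior: gradio picks its temp directory up from the GRADIO_TEMP_DIR environment variable, so the variable must be set before gradio is imported anywhere. A plausible sketch of what config_gradio_tmp_imgs_folder does; the actual implementation lives in apps/shark_studio/web/initializers and is not shown in this diff:

import os
import shutil


def config_gradio_tmp_imgs_folder():
    # Plausible sketch only, not the repo's code: point gradio's temp
    # dir at shark_tmp and clear stale images first.
    tmp = os.path.join(os.getcwd(), "shark_tmp")
    if os.path.isdir(tmp):
        shutil.rmtree(tmp, ignore_errors=True)  # clear stale images
    os.makedirs(tmp, exist_ok=True)
    # gradio reads GRADIO_TEMP_DIR when it is imported/used, hence the
    # "has to be in this order" comment above.
    os.environ["GRADIO_TEMP_DIR"] = tmp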
@@ -24,13 +24,12 @@ from apps.shark_studio.api.utils import (
 )
 from apps.shark_studio.api.sd import (
     sd_model_map,
-    SharkStableDiffusionPipeline,
+    StableDiffusion,
 )
 from apps.shark_studio.api.schedulers import (
     scheduler_model_map,
 )
 from apps.shark_studio.api.controlnet import (
     resampler_list,
     preprocessor_model_map,
     control_adapter_model_map,
     PreprocessorModel,
@@ -70,8 +69,8 @@ def shark_sd_fn(
     custom_vae: str,
     precision: str,
     device: str,
-    lora_weights: str,
-    lora_hf_id: str,
+    lora_weights: str | list,
+    lora_hf_ids: str | list,
     ondemand: bool,
     repeatable_seeds: bool,
     resample_type: str,
@@ -108,14 +107,22 @@ def shark_sd_fn(
     from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

-    submit_pipe_kwargs = {
-        "base_model_id": base_model_id,
-        "custom_vae": custom_vae,
-        "import_mlir": cmd_opts.import_mlir,
-    }
+    submit_pipe_kwargs = {
+        "base_model_id": base_model_id,
+        "height": height,
+        "width": width,
+        "precision": precision,
+        "device": device,
+        "extra_model_ids": extra_model_ids,
+        "embeddings": lora_hf_ids,
+        "import_ir": cmd_opts.import_ir,
+    }
     submit_prep_kwargs = {

     global sd_pipe
     global sd_pipe_kwargs

     for key in

     if sd_pipe is None:
@@ -127,17 +134,10 @@ def shark_sd_fn(
         # which is currently MLIR in the torch dialect.

         sd_pipe = SharkStableDiffusionPipeline(
-            base_model_id = base_model_id,
-            custom_vae = custom_vae,
-            import_mlir = import_mlir,
-            device = device.split("=>", 1)[1].strip(),
-            precision = precision,
-            max_length = 512,
-            height = height,
-            width = width,
+            **submit_pipe_kwargs
         )
         sd_pipe.queue_compile()

         #
         for prompt, msg, exec_time in progress.tqdm(
             sd_pipe.generate_images(
                 prompt,
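The refactor above collects the pipeline arguments into submit_pipe_kwargs and splats them with **, and the sd_pipe / sd_pipe_kwargs globals plus the unfinished "for key in" fragment hint at caching the pipeline between calls. A sketch of that caching pattern under those assumptions; the helper name is illustrative and queue_compile is taken from the old pipeline class in this diff:

sd_pipe = None
sd_pipe_kwargs = None


def get_sd_pipe(**submit_pipe_kwargs):
    # Illustrative helper, not the repo's code: rebuild the pipeline only
    # when the requested kwargs differ from those of the cached pipeline.
    global sd_pipe, sd_pipe_kwargs
    if sd_pipe is None or submit_pipe_kwargs != sd_pipe_kwargs:
        sd_pipe_kwargs = submit_pipe_kwargs
        sd_pipe = StableDiffusion(**submit_pipe_kwargs)
        sd_pipe.queue_compile()
    return sd_pipe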
@@ -1,9 +1,10 @@
 def nodlogo_loc():
     return "foo"


 def get_checkpoints_path(model_type: str = None):
     return "foo"


 def get_checkpoints():
     return "foo"