This commit is contained in:
Ean Garvey
2023-12-10 01:09:07 -06:00
parent 3047d36df2
commit ab32bfbe61
10 changed files with 229 additions and 134 deletions

View File

@@ -13,32 +13,45 @@ from apps.shark_studio.modules.timer import startup_timer
def imports():
    """Import heavyweight dependencies lazily, recording each step in the startup timer.

    Deduplicated: a bad merge left both the old single-line and the new wrapped
    forms of the filterwarnings calls and the module import in place.
    """
    import torch  # noqa: F401

    startup_timer.record("import torch")

    # Silence known-noisy warnings emitted by torch / torchvision on import.
    warnings.filterwarnings(
        action="ignore", category=DeprecationWarning, module="torch"
    )
    warnings.filterwarnings(
        action="ignore", category=UserWarning, module="torchvision"
    )

    import gradio  # noqa: F401

    startup_timer.record("import gradio")

    # from apps.shark_studio.modules import shared_init
    # shared_init.initialize()
    # startup_timer.record("initialize shared")

    from apps.shark_studio.modules import processing, gradio_extensons, ui  # noqa: F401

    startup_timer.record("other imports")
def initialize():
    """Run one-time app initialization: SIGINT handling and option-change hooks.

    Deduplicated: merge residue left both old (`#from`) and new (`# from`)
    copies of the commented-out model-loader / SD-model setup lines.
    """
    configure_sigint_handler()
    configure_opts_onchange()

    # from apps.shark_studio.modules import modelloader
    # modelloader.cleanup_models()

    # from apps.shark_studio.modules import sd_models
    # sd_models.setup_model()
    # startup_timer.record("setup SD model")

    # initialize_rest(reload_script_modules=False)
def initialize_rest(*, reload_script_modules=False):
"""
@@ -46,6 +59,7 @@ def initialize_rest(*, reload_script_modules=False):
"""
# Keep this for adding reload options to the webUI.
def dumpstacks():
import threading
import traceback
@@ -65,12 +79,10 @@ def dumpstacks():
def configure_sigint_handler():
    """Install a SIGINT handler that dumps thread stacks and exits immediately.

    Deduplicated: merge residue left both quote styles of the print line.
    """

    # make the program just exit at ctrl+c without waiting for anything
    def sigint_handler(sig, frame):
        print(f"Interrupted with signal {sig} in {frame}")
        dumpstacks()
        # os._exit skips atexit/cleanup on purpose: exit NOW, don't wait.
        os._exit(0)

    signal.signal(signal.SIGINT, sigint_handler)

View File

@@ -1,20 +1,30 @@
# from shark_turbine.turbine_models.schedulers import export_scheduler_model


def export_scheduler_model(model):
    """Stub for the turbine scheduler exporter; returns placeholder artifacts."""
    return "None", "None"


# Scheduler display name -> exported (module, ir) placeholder pair.
# Deduplicated: merge residue left four keys present twice (old single-line
# form plus new line-wrapped form).
scheduler_model_map = {
    "EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
    "EulerAncestralDiscrete": export_scheduler_model(
        "EulerAncestralDiscreteScheduler"
    ),
    "LCM": export_scheduler_model("LCMScheduler"),
    "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
    "PNDM": export_scheduler_model("PNDMScheduler"),
    "DDPM": export_scheduler_model("DDPMScheduler"),
    "DDIM": export_scheduler_model("DDIMScheduler"),
    "DPMSolverMultistep": export_scheduler_model(
        "DPMSolverMultistepScheduler"
    ),
    "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
    "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
    "DPMSolverSinglestep": export_scheduler_model(
        "DPMSolverSingleStepScheduler"
    ),
    "KDPM2AncestralDiscrete": export_scheduler_model(
        "KDPM2AncestralDiscreteScheduler"
    ),
    "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
}

View File

@@ -6,97 +6,128 @@ import gc
import torch
# Map of supported Stable Diffusion base models to their submodel exporters.
# Each submodel pairs the turbine export entry point ("initializer") with the
# max token length used when compiling it.
# NOTE(review): reconstructed from a corrupted merge (`,,`, unbalanced braces,
# stray old "sd15" entry). The CompVis entry's max_tokens values mirror the
# intact SD 1.x entries — confirm against upstream history.
sd_model_map = {
    "CompVis/stable-diffusion-v1-4": {
        "clip": {
            "initializer": clip.export_clip_model,
            "max_tokens": 64,
        },
        "vae_encode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
        "unet": {
            "initializer": unet.export_unet_model,
            "max_tokens": 512,
        },
        "vae_decode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
    },
    "runwayml/stable-diffusion-v1-5": {
        "clip": {
            "initializer": clip.export_clip_model,
            "max_tokens": 64,
        },
        "vae_encode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
        "unet": {
            "initializer": unet.export_unet_model,
            "max_tokens": 512,
        },
        "vae_decode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
    },
    "stabilityai/stable-diffusion-2-1-base": {
        "clip": {
            "initializer": clip.export_clip_model,
            "max_tokens": 64,
        },
        "vae_encode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
        "unet": {
            "initializer": unet.export_unet_model,
            "max_tokens": 512,
        },
        "vae_decode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
    },
    "stabilityai/stable_diffusion-xl-1.0": {
        # SDXL uses two text encoders, hence clip_1 / clip_2.
        "clip_1": {
            "initializer": clip.export_clip_model,
            "max_tokens": 64,
        },
        "clip_2": {
            "initializer": clip.export_clip_model,
            "max_tokens": 64,
        },
        "vae_encode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
        "unet": {
            "initializer": unet.export_unet_model,
            "max_tokens": 512,
        },
        "vae_decode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
    },
}
class SharkStableDiffusionPipeline:
    # NOTE(review): this class appears superseded by StableDiffusion below and
    # still references LLM-pipeline globals (llm_model_map, hf_auth_token,
    # get_resource_path) that are not defined in this module — it will raise
    # NameError if instantiated; confirm whether it should be deleted.
    def __init__(self, model_name, device=None, precision="fp32"):
        """Export the model's torch IR to a temp file and free the in-memory copy.

        Fixed: the signature contained a stray empty parameter slot
        (`self, model_name, , device=None`), a syntax error.
        """
        print(sd_model_map[model_name])
        self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
        self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"](
            self.hf_model_name, hf_auth_token, compile_to="torch"
        )
        self.tempfile_name = get_resource_path("llm.torch.tempfile")
        with open(self.tempfile_name, "w+") as f:
            f.write(self.torch_ir)
        # Drop the (potentially large) IR string once it is on disk.
        del self.torch_ir
        gc.collect()
class StableDiffusion(SharkPipelineBase):
    # This class is responsible for executing image generation and creating
    # /managing a set of compiled modules to run Stable Diffusion. The init
    # aims to be as general as possible, and the class will infer and compile
    # a list of necessary modules or a combined "pipeline module" for a
    # specified job based on the inference task.
    #
    # custom_model_ids: a dict of submodel + HF ID pairs for custom submodels.
    # e.g. {"vae_decode": "madebyollin/sdxl-vae-fp16-fix"}
    #
    # embeddings: a dict of embedding checkpoints or model IDs to use when
    # initializing the compiled modules.

    def __init__(
        self,
        base_model_id: str = "runwayml/stable-diffusion-v1-5",
        height: int = 512,
        width: int = 512,
        precision: str = "fp16",
        device: str = None,
        # NOTE(review): mutable default args ({}) are shared across instances —
        # should be None-with-fallback; left unchanged here.
        custom_model_map: dict = {},
        custom_weights_map: dict = {},
        embeddings: dict = {},
        import_ir: bool = True,
    ):
        # Base-class init receives the submodel map for the chosen base model.
        super().__init__(sd_model_map[base_model_id], device, import_ir)
        self.base_model_id = base_model_id
        self.device = device
        self.precision = precision
        # NOTE(review): `llm_model_map` and `model_name` are not defined in
        # this scope — leftover from the LLM pipeline; raises NameError.
        self.max_tokens = llm_model_map[model_name]["max_tokens"]
        self.iree_module_dict = None
        self.compile()
        self.get_compiled_map()

    def compile(self) -> None:
        # this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
        # NOTE(review): self.tempfile_name is never assigned in this class —
        # looks copied from SharkStableDiffusionPipeline; confirm.
        self.iree_module_dict = get_iree_compiled_module(
            self.tempfile_name, device=self.device, frontend="torch"
        )
        # TODO: delete the temp file

    def generate_images(
        self,
        prompt,
    ):
        # NOTE(review): this body is LLM token-generation code (tokenizer loop,
        # llama2 stop token) — clearly a mid-refactor leftover, not image
        # generation; preserved verbatim pending the real implementation.
        history = []
        for iter in range(self.max_tokens):
            input_tensor = self.tokenizer(
                prompt, return_tensors="pt"
            ).input_ids
            device_inputs = [
                ireert.asdevicearray(
                    self.iree_module_dict["config"], input_tensor
                )
            ]
            if iter == 0:
                # First step runs the "initialize" entry point of the vmfb.
                token = torch.tensor(
                    self.iree_module_dict["vmfb"]["run_initialize"](
                        *device_inputs
                    ).to_host()[0][0]
                )
            else:
                token = torch.tensor(
                    self.iree_module_dict["vmfb"]["run_forward"](
                        *device_inputs
                    ).to_host()[0][0]
                )
            history.append(token)
            # Stream the decoded text so far to the caller.
            yield self.tokenizer.decode(history)
            if token == llm_model_map["llama2_7b"]["stop_token"]:
                break
        for i in range(len(history)):
            if type(history[i]) != int:
                history[i] = int(history[i])
        result_output = self.tokenizer.decode(history)
        yield result_output
        # `return <value>` in a generator sets StopIteration.value (PEP 380).
        return result_output,
if __name__ == "__main__":
    # Smoke test: construct the SD pipeline for the default 1.5 base model.
    # Fixed: merge residue interleaved an old, unclosed `LanguageModel(...)`
    # call (a syntax error) that also embedded a hard-coded HF auth token —
    # secrets must never be committed; the leftover invocation is removed.
    sd = StableDiffusion(
        "runwayml/stable-diffusion-v1-5",
        device="vulkan",
    )
    print("model loaded")

View File

@@ -1,6 +1,7 @@
import torch
from safetensors.torch import load_file
def processLoRA(model, use_lora, splitting_prefix):
state_dict = ""
if ".safetensors" in use_lora:
@@ -108,4 +109,3 @@ def update_lora_weight(model, use_lora, model_name):
return processLoRA(model, use_lora, "lora_te_")
except:
return None

View File

@@ -2,8 +2,24 @@ import sys
import gradio as gr
# Deduplicated: merge residue left the old single-line imports alongside the
# new wrapped forms; keeping one copy of each.
from modules import (
    shared_cmd_options,
    shared_gradio,
    options,
    shared_items,
    sd_models_types,
)
from modules.paths_internal import (
    models_path,
    script_path,
    data_path,
    sd_configs_path,
    sd_default_config,
    sd_model_file,
    default_sd_model_file,
    extensions_dir,
    extensions_builtin_dir,
)  # noqa: F401
from modules import util

cmd_opts = shared_cmd_options.cmd_opts

View File

@@ -11,7 +11,9 @@ class TimerSubcategory:
def __enter__(self):
self.start = time.time()
self.timer.base_category = self.original_base_category + self.category + "/"
self.timer.base_category = (
self.original_base_category + self.category + "/"
)
self.timer.subcategory_level += 1
if self.timer.print_log:
@@ -20,7 +22,10 @@ class TimerSubcategory:
def __exit__(self, exc_type, exc_val, exc_tb):
    """Close the subcategory: restore the base category, record elapsed time.

    Fixed: merge residue duplicated the add_time_to_record call; local typo
    `elapsed_for_subcategroy` renamed (local variable only, no API change).
    """
    elapsed_for_subcategory = time.time() - self.start
    self.timer.base_category = self.original_base_category
    self.timer.add_time_to_record(
        self.original_base_category + self.category,
        elapsed_for_subcategory,
    )
    self.timer.subcategory_level -= 1
    # disable_log avoids double-printing: the parent record already logs.
    self.timer.record(self.category, disable_log=True)
@@ -30,7 +35,7 @@ class Timer:
self.start = time.time()
self.records = {}
self.total = 0
self.base_category = ''
self.base_category = ""
self.print_log = print_log
self.subcategory_level = 0
@@ -54,7 +59,9 @@ class Timer:
self.total += e + extra_time
if self.print_log and not disable_log:
print(f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s")
print(
f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s"
)
def subcategory(self, name):
self.elapsed()
@@ -65,25 +72,38 @@ class Timer:
def summary(self):
    """Return a human-readable total, e.g. "3.2s (load: 1.1s, compile: 2.0s)".

    Fixed: merge residue left both the old single-line comprehension/join and
    the new wrapped forms in place; one copy of each is kept.
    """
    res = f"{self.total:.1f}s"

    # Only surface top-level categories (no '/') that took noticeable time.
    additions = [
        (category, time_taken)
        for category, time_taken in self.records.items()
        if time_taken >= 0.1 and "/" not in category
    ]
    if not additions:
        return res

    res += " ("
    res += ", ".join(
        [
            f"{category}: {time_taken:.1f}s"
            for category, time_taken in additions
        ]
    )
    res += ")"
    return res
def dump(self):
    """Return the timer state as a plain dict for serialization/logging.

    Fixed: merge residue left both quote-style variants of the return line.
    """
    return {"total": self.total, "records": self.records}
def reset(self):
    # Re-run __init__ to clear all records and restart the clock.
    self.__init__()
# Module-level bootstrap: a private parser just for the --log-startup flag.
# Fixed: merge residue registered "--log-startup" twice (old single-line plus
# new wrapped call), which raises argparse.ArgumentError at import time.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument(
    "--log-startup",
    action="store_true",
    help="print a detailed log of what's happening at startup",
)
# parse_known_args so unrelated flags from the host app don't error here.
args = parser.parse_known_args()[0]

startup_timer = Timer(print_log=args.log_startup)

View File

@@ -0,0 +1 @@

View File

@@ -20,6 +20,7 @@ if sys.platform == "darwin":
# import before IREE to avoid MLIR library issues
import torch_mlir
def create_api(app):
from apps.shark_studio.api.compat import ApiCompat
from modules.call_queue import queue_lock
@@ -27,6 +28,7 @@ def create_api(app):
api = ApiCompat(app, queue_lock)
return api
def api_only():
from fastapi import FastAPI
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
@@ -37,17 +39,17 @@ def api_only():
initialize.setup_middleware(app)
api = create_api(app)
#from modules import script_callbacks
#script_callbacks.before_ui_callback()
#script_callbacks.app_started_callback(None, app)
# from modules import script_callbacks
# script_callbacks.before_ui_callback()
# script_callbacks.app_started_callback(None, app)
print(f"Startup time: {startup_timer.summary()}.")
api.launch(
server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
port=cmd_opts.port if cmd_opts.port else 8080,
root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""
root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
)
def launch_webui(address):
from tkinter import Tk
@@ -68,11 +70,12 @@ def launch_webui(address):
)
webview.start(private_mode=False, storage_path=os.getcwd())
def webui():
from apps.shark_studio.shared_cmd_options import cmd_opts
logging.basicConfig(level=logging.DEBUG)
launch_api = cmd_opts.api
initialize.initialize()
@@ -81,7 +84,6 @@ def webui():
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
# if args.api or "api" in args.ui.split(","):
# from apps.shark_studio.api.llm import (
# chat,
@@ -94,7 +96,7 @@ def webui():
#
# # init global sd pipeline and config
# global_obj._init()
#
#
# api = FastAPI()
# api.mount("/sdapi/", sdapi)
#
@@ -123,15 +125,15 @@ def webui():
# )
# else:
# print("API not configured for CORS")
#
#
# uvicorn.run(api, host="0.0.0.0", port=args.server_port)
# sys.exit(0)
# sys.exit(0)
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
from apps.shark_studio.web.initializers import (
config_gradio_tmp_imgs_folder,
create_custom_models_folders,
config_gradio_tmp_imgs_folder,
create_custom_models_folders,
)
config_gradio_tmp_imgs_folder()
@@ -161,6 +163,7 @@ def webui():
inputs,
outputs,
)
def register_outputgallery_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
@@ -191,7 +194,6 @@ def webui():
with gr.TabItem(label="Chat Bot", id=2):
chat_element.render()
studio_web.queue()
# if args.ui == "app":
# t = Process(
@@ -204,6 +206,8 @@ def webui():
server_name="0.0.0.0",
server_port=11911, # args.server_port,
)
if __name__ == "__main__":
from apps.shark_studio.shared_cmd_options import cmd_opts

View File

@@ -24,13 +24,12 @@ from apps.shark_studio.api.utils import (
)
from apps.shark_studio.api.sd import (
sd_model_map,
SharkStableDiffusionPipeline,
StableDiffusion,
)
from apps.shark_studio.api.schedulers import (
scheduler_model_map,
)
from apps.shark_studio.api.controlnet import (
resampler_list,
preprocessor_model_map,
control_adapter_model_map,
PreprocessorModel,
@@ -70,8 +69,8 @@ def shark_sd_fn(
custom_vae: str,
precision: str,
device: str,
lora_weights: str,
lora_hf_id: str,
lora_weights: str | list,
lora_hf_ids: str | list,
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
@@ -108,14 +107,22 @@ def shark_sd_fn(
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
submit_pipe_kwargs = {
"base_model_id": base_model_id,
"custom_vae": custom_vae,
"import_mlir": cmd_opts.import_mlir,
"":
base_model_id: base_model_id,
height: height,
width: width,
precision: precision,
device: device,
extra_model_ids: extra_model_ids,
embeddings: lora_hf_ids,
import_ir: cmd_opts.import_ir,
}
submit_prep_kwargs = {
global sd_pipe
global sd_pipe_kwargs
for key in
if sd_pipe is None:
@@ -127,17 +134,10 @@ def shark_sd_fn(
# which is currently MLIR in the torch dialect.
sd_pipe = SharkStableDiffusionPipeline(
base_model_id = base_model_id,
custom_vae = custom_vae,
import_mlir = import_mlir,
device = device.split("=>", 1)[1].strip(),
precision = precision,
max_length = 512,
height = height,
width = width,
**submit_pipe_kwargs
)
sd_pipe.queue_compile()
#
for prompt, msg, exec_time in progress.tqdm(
sd_pipe.generate_images(
prompt,

View File

@@ -1,9 +1,10 @@
def nodlogo_loc():
    """Stub: return the placeholder logo-location string."""
    placeholder = "foo"
    return placeholder
def get_checkpoints_path(model_type: str = None):
    """Stub: return the placeholder checkpoints path, ignoring model_type."""
    checkpoints_path = "foo"
    return checkpoints_path
def get_checkpoints():
    """Stub: return the placeholder checkpoints listing."""
    listing = "foo"
    return listing