mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git
Commit: Model loading
@@ -1,14 +1,17 @@
import importlib
import logging
import os
import signal
import sys
import re
import warnings
import json
from threading import Thread

from apps.shark_studio.modules.timer import startup_timer
from apps.shark_studio.web.utils.tmp_configs import (
    config_tmp,
    clear_tmp_mlir,
    clear_tmp_imgs,
)


def imports():
@@ -21,6 +24,9 @@ def imports():
    warnings.filterwarnings(
        action="ignore", category=UserWarning, module="torchvision"
    )
    warnings.filterwarnings(
        action="ignore",
        category=UserWarning,
        message=".*is deprecated, please use.*",
        module=".*torch.*",
    )

    import gradio  # noqa: F401

@@ -34,20 +40,27 @@ def imports():
    from apps.shark_studio.modules import (
        img_processing,
    )  # noqa: F401
    from apps.shark_studio.modules.schedulers import scheduler_model_map

    startup_timer.record("other imports")


def initialize():
    configure_sigint_handler()
    # Set up shark_tmp for gradio's temporary image files and clear any
    # existing temporary images there. This has to happen before gradio is
    # imported, or gradio ignores what we've set up.

    # from apps.shark_studio.modules import modelloader
    # modelloader.cleanup_models()
    config_tmp()
    clear_tmp_mlir()
    clear_tmp_imgs()

    # from apps.shark_studio.modules import sd_models
    # sd_models.setup_model()
    # startup_timer.record("setup SD model")
    from apps.shark_studio.web.utils.file_utils import (
        create_checkpoint_folders,
    )

    # Create custom models folders if they don't exist
    create_checkpoint_folders()

    import gradio as gr

    # initialize_rest(reload_script_modules=False)
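
The ordering constraint in the comments above matters because gradio resolves its temporary directory when it is imported. A minimal sketch of the idea, assuming config_tmp() ultimately points gradio at shark_tmp via the GRADIO_TEMP_DIR environment variable (an assumption; the real helper lives in apps.shark_studio.web.utils.tmp_configs):

import os

# Hypothetical equivalent of config_tmp(): redirect gradio's temp files
# to shark_tmp *before* gradio is imported.
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), "shark_tmp")

import gradio as gr  # safe to import now; gradio will use shark_tmp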
@@ -1,125 +1,61 @@
import gc
from unittest import registerResult
import torch
import time
import os
import json
import numpy as np

from pathlib import Path
from turbine_models.custom_models.sd_inference import clip, unet, vae
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.web.utils.state import status_label
from apps.shark_studio.web.utils.file_utils import safe_name, get_resource_path
from apps.shark_studio.web.utils.file_utils import safe_name, get_resource_path, get_checkpoints_path
from apps.shark_studio.modules.pipeline import SharkPipelineBase
from apps.shark_studio.modules.schedulers import get_schedulers
from apps.shark_studio.modules.prompt_encoding import get_weighted_text_embeddings
from apps.shark_studio.modules.img_processing import (
    resize_stencil,
    save_output_img,
)

from apps.shark_studio.modules.ckpt_processing import (
    preprocessCKPT,
    process_custom_pipe_weights,
)
from transformers import CLIPTokenizer
from math import ceil
from PIL import Image

sd_model_map = {
    "CompVis/stable-diffusion-v1-4": {
        "clip": {
            "initializer": clip.export_clip_model,
            "max_tokens": 64,
        },
        "vae_encode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
        "unet": {
            "initializer": unet.export_unet_model,
            "max_tokens": 512,
        },
        "vae_decode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
        "clip": {
            "initializer": clip.export_clip_model,
            "external_weight_file": None,
            "ireec_flags": ["--iree-flow-collapse-reduction-dims",
                "--iree-opt-const-expr-hoisting=False",
                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
            ],
        },
    "runwayml/stable-diffusion-v1-5": {
        "clip": {
            "initializer": clip.export_clip_model,
            "max_tokens": 64,
        },
        "vae_encode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
        "unet": {
            "initializer": unet.export_unet_model,
            "max_tokens": 512,
        },
        "vae_decode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
        "vae_encode": {
            "initializer": vae.export_vae_model,
            "external_weight_file": None,
        },
    "stabilityai/stable-diffusion-2-1-base": {
        "clip": {
            "initializer": clip.export_clip_model,
            "max_tokens": 64,
        },
        "vae_encode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
        "unet": {
            "initializer": unet.export_unet_model,
            "max_tokens": 512,
        },
        "vae_decode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
        "unet": {
            "initializer": unet.export_unet_model,
            "ireec_flags": ["--iree-flow-collapse-reduction-dims",
                "--iree-opt-const-expr-hoisting=False",
                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
            ],
            "external_weight_file": None,
        },
    "stabilityai/stable_diffusion-xl-1.0": {
        "clip_1": {
            "initializer": clip.export_clip_model,
            "max_tokens": 64,
        },
        "clip_2": {
            "initializer": clip.export_clip_model,
            "max_tokens": 64,
        },
        "vae_encode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
        "unet": {
            "initializer": unet.export_unet_model,
            "max_tokens": 512,
        },
        "vae_decode": {
            "initializer": vae.export_vae_model,
            "max_tokens": 64,
        },
        "vae_decode": {
            "initializer": vae.export_vae_model,
            "external_weight_file": None,
        },
}


def get_spec(custom_sd_map: dict, sd_embeds: dict):
    spec = []
    for key in custom_sd_map:
        if "control" in key.split("_"):
            spec.append("controlled")
        elif key == "custom_vae":
            spec.append(custom_sd_map[key]["custom_weights"].split(".")[0])
    num_embeds = 0
    embeddings_spec = None
    for embed in sd_embeds:
        if embed is not None:
            num_embeds += 1
            embeddings_spec = str(num_embeds) + "embeds"
    if embeddings_spec:
        spec.append(embeddings_spec)
    return "_".join(spec)
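
A quick worked example of the spec string this helper produces (hypothetical inputs; the key names follow the control_adapter_* / custom_vae conventions used above):

custom_map = {
    "control_adapter_canny": {"hf_id": "some/controlnet-repo"},  # hypothetical
    "custom_vae": {"custom_weights": "myvae.safetensors"},
}
print(get_spec(custom_map, [None, "some_embedding"]))
# -> "controlled_myvae_1embeds"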

class StableDiffusion(SharkPipelineBase):

    # This class is responsible for executing image generation and creating
    # /managing a set of compiled modules to run Stable Diffusion. The init
    # aims to be as general as possible, and the class will infer and compile
@@ -132,39 +68,143 @@ class StableDiffusion(SharkPipelineBase):
    # embeddings: a dict of embedding checkpoints or model IDs to use when
    # initializing the compiled modules.

    def __init__(
        self,
        base_model_id: str = "runwayml/stable-diffusion-v1-5",
        height: int = 512,
        width: int = 512,
        precision: str = "fp16",
        device: str = None,
        custom_model_map: dict = {},
        embeddings: dict = {},
        base_model_id,
        height: int,
        width: int,
        batch_size: int,
        precision: str,
        device: str,
        custom_vae: str = None,
        num_loras: int = 0,
        import_ir: bool = True,
        is_img2img: bool = False,
        is_controlled: bool = False,
    ):
        super().__init__(
            sd_model_map[base_model_id], base_model_id, device, import_ir
        )
        self.model_max_length = 77
        self.batch_size = batch_size
        self.precision = precision
        self.is_img2img = is_img2img
        self.pipe_id = (
            safe_name(base_model_id)
            + str(height)
            + str(width)
            + precision
            + device
            + get_spec(custom_model_map, embeddings)
        self.scheduler_obj = {}
        self.precision = precision
        static_kwargs = {
            "pipe": {},
            "clip": {"hf_model_name": base_model_id},
            "unet": {
                "hf_model_name": base_model_id,
                "unet_model": unet.UnetModel(hf_model_name=base_model_id, hf_auth_token=None),
                "batch_size": batch_size,
                # "is_controlled": is_controlled,
                # "num_loras": num_loras,
                "height": height,
                "width": width,
            },
            "vae_encode": {
                "hf_model_name": custom_vae if custom_vae else base_model_id,
                "vae_model": vae.VaeModel(hf_model_name=base_model_id, hf_auth_token=None),
                "batch_size": batch_size,
                "height": height,
                "width": width,
            },
            "vae_decode": {
                "hf_model_name": custom_vae,
                "vae_model": vae.VaeModel(hf_model_name=base_model_id, hf_auth_token=None),
                "batch_size": batch_size,
                "height": height,
                "width": width,
            },
        }
        super().__init__(
            sd_model_map, base_model_id, static_kwargs, device, import_ir
        )
        print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}")
        pipe_id_list = [
            safe_name(base_model_id),
            str(batch_size),
            f"{str(height)}x{str(width)}",
            precision,
        ]
        if num_loras > 0:
            pipe_id_list.append(str(num_loras) + "lora")
        if is_controlled:
            pipe_id_list.append("controlled")
        if custom_vae:
            pipe_id_list.append(custom_vae)
        self.pipe_id = "_".join(pipe_id_list)
        print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
        del static_kwargs
        gc.collect()
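
For concreteness, here is what the pipe_id built above looks like for a default SD 1.5 run (worked example with assumed inputs; safe_name is the helper defined later in pipeline.py, duplicated here so the snippet stands alone):

def safe_name(name):
    return name.replace("/", "_").replace("-", "_")

# Assumed: batch_size=1, 512x512, fp16, no LoRAs/ControlNet/custom VAE.
pipe_id = "_".join(
    [safe_name("runwayml/stable-diffusion-v1-5"), "1", "512x512", "fp16"]
)
print(pipe_id)  # runwayml_stable_diffusion_v1_5_1_512x512_fp16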

    def prepare_pipe(self, scheduler, custom_model_map, embeddings):
    def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings):
        print(
            f"\n[LOG] Preparing pipeline with scheduler {scheduler}, custom map {json.dumps(custom_model_map)}, and embeddings {json.dumps(embeddings)}."
            f"\n[LOG] Preparing pipeline with scheduler {scheduler}"
            f"\n[LOG] Custom embeddings currently unsupported."
        )
        self.get_compiled_map(device=self.device, pipe_id=self.pipe_id)
        return None
        schedulers = get_schedulers(self.base_model_id)
        self.weights_path = get_checkpoints_path(self.pipe_id)
        if not os.path.exists(self.weights_path):
            os.mkdir(self.weights_path)
        # accepting a list of schedulers in batched cases.
        for i in scheduler:
            self.scheduler_obj[i] = schedulers[i]
            print(f"[LOG] Loaded scheduler: {i}")
        for model in adapters:
            self.model_map[model] = adapters[model]
        if os.path.isfile(custom_weights):
            for i in self.model_map:
                self.model_map[i]["external_weights_file"] = None
        elif custom_weights != "":
            print(f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?")
        self.static_kwargs["pipe"] = {
            # "external_weight_path": self.weights_path,
            # "external_weights": "safetensors",
        }
        self.get_compiled_map(pipe_id=self.pipe_id)
        print("\n[LOG] Pipeline successfully prepared for runtime.")
        return

    def encode_prompts_weight(
        self,
        prompt,
        negative_prompt,
        do_classifier_free_guidance=True,
    ):
        # Encodes the prompt into text encoder hidden states.
        self.load_submodels(["clip"])
        self.tokenizer = CLIPTokenizer.from_pretrained(
            self.base_model_id,
            subfolder="tokenizer",
        )
        clip_inf_start = time.time()

        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
            pipe=self,
            prompt=prompt,
            uncond_prompt=negative_prompt
            if do_classifier_free_guidance
            else None,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        pad = (0, 0) * (len(text_embeddings.shape) - 2)
        pad = pad + (0, 512 - text_embeddings.shape[1])
        text_embeddings = torch.nn.functional.pad(text_embeddings, pad)

        # SHARK: Report clip inference time
        clip_inf_time = (time.time() - clip_inf_start) * 1000
        if self.ondemand:
            self.unload_clip()
            gc.collect()
        print(f"\n[LOG] Clip Inference time (ms) = {clip_inf_time:.3f}")

        return text_embeddings.numpy().astype(np.float16)
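
The padding above right-pads the token axis of the CLIP hidden states out to the 512-token budget the compiled UNet expects. A standalone illustration of that torch.nn.functional.pad call (shapes assumed):

import torch
import torch.nn.functional as F

emb = torch.zeros(2, 77, 768)            # (batch*2 for CFG, tokens, hidden)
pad = (0, 0) * (len(emb.shape) - 2)      # no padding on the hidden dim
pad = pad + (0, 512 - emb.shape[1])      # right-pad the token dim to 512
print(F.pad(emb, pad).shape)             # torch.Size([2, 512, 768])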

    def generate_images(
        self,
@@ -181,11 +221,35 @@ class StableDiffusion(SharkPipelineBase):
        hints,
    ):
        print("\n[LOG] Generating images...")
        batched_args = [
            prompt,
            negative_prompt,
            steps,
            strength,
            guidance_scale,
            seed,
            resample_type,
            control_mode,
            hints,
        ]
        # Normalize every arg to a list of exactly batch_size entries,
        # writing the result back into batched_args.
        for i, arg in enumerate(batched_args):
            if not isinstance(arg, list):
                arg = [arg] * self.batch_size
            if len(arg) < self.batch_size:
                arg = arg * self.batch_size
            else:
                arg = [arg[j] for j in range(self.batch_size)]
            batched_args[i] = arg

        text_embeddings = self.encode_prompts_weight(
            prompt,
            negative_prompt,
        )
        print(text_embeddings)
        test_img = [
            Image.open(
                get_resource_path("../../tests/jupiter.png"), mode="r"
            ).convert("RGB")
        ]
        ] * self.batch_size
        return test_img
@@ -257,29 +321,31 @@ def shark_sd_fn(
    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
    import apps.shark_studio.web.utils.globals as global_obj

    custom_model_map = {}
    adapters = {}
    is_controlled = False
    control_mode = None
    hints = []
    if custom_weights != "None":
        custom_model_map["unet"] = {"custom_weights": custom_weights}
    if custom_vae != "None":
        custom_model_map["vae"] = {"custom_weights": custom_vae}
    num_loras = 0
    for i in embeddings:
        num_loras += 1 if embeddings[i] else 0
    if "model" in controlnets:
        for i, model in enumerate(controlnets["model"]):
            if "xl" not in base_model_id.lower():
                custom_model_map[f"control_adapter_{model}"] = {
                adapters[f"control_adapter_{model}"] = {
                    "hf_id": control_adapter_map[
                        "runwayml/stable-diffusion-v1-5"
                    ][model],
                    "strength": controlnets["strength"][i],
                }
            else:
                custom_model_map[f"control_adapter_{model}"] = {
                adapters[f"control_adapter_{model}"] = {
                    "hf_id": control_adapter_map[
                        "stabilityai/stable-diffusion-xl-1.0"
                    ][model],
                    "strength": controlnets["strength"][i],
                }
            if model is not None:
                is_controlled = True
        control_mode = controlnets["control_mode"]
        for i in controlnets["hint"]:
            hints.append(i)
@@ -288,16 +354,19 @@ def shark_sd_fn(
        "base_model_id": base_model_id,
        "height": height,
        "width": width,
        "batch_size": batch_size,
        "precision": precision,
        "device": device,
        "custom_model_map": custom_model_map,
        "embeddings": embeddings,
        "custom_vae": custom_vae,
        "num_loras": num_loras,
        "import_ir": cmd_opts.import_mlir,
        "is_img2img": is_img2img,
        "is_controlled": is_controlled,
    }
    submit_prep_kwargs = {
        "scheduler": scheduler,
        "custom_model_map": custom_model_map,
        "custom_weights": custom_weights,
        "adapters": adapters,
        "embeddings": embeddings,
    }
    submit_run_kwargs = {
@@ -313,6 +382,7 @@ def shark_sd_fn(
        "control_mode": control_mode,
        "hints": hints,
    }
    print(submit_pipe_kwargs)
    if (
        not global_obj.get_sd_obj()
        or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
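
The comparison at the end amounts to a kwargs-keyed cache: the pipeline object is rebuilt only when its construction kwargs change. A minimal sketch of that pattern (hypothetical names, not the actual globals module):

_cached_kwargs = None
_cached_pipe = None

def get_or_rebuild_pipe(pipe_kwargs, factory):
    # factory is any callable that builds the pipeline from its kwargs.
    global _cached_kwargs, _cached_pipe
    if _cached_pipe is None or _cached_kwargs != pipe_kwargs:
        _cached_pipe = factory(**pipe_kwargs)
        _cached_kwargs = pipe_kwargs
    return _cached_pipe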
@@ -92,6 +92,8 @@ def process_custom_pipe_weights(custom_weights):
        custom_weights_tgt = get_path_to_diffusers_checkpoint(
            custom_weights
        )
        custom_weights_params = custom_weights

    return custom_weights_params, custom_weights_tgt


def get_civitai_checkpoint(url: str):
@@ -1,8 +1,14 @@
from shark.iree_utils.compile_utils import get_iree_compiled_module
from msvcrt import kbhit
from shark.iree_utils.compile_utils import get_iree_compiled_module, load_vmfb_using_mmap
from apps.shark_studio.web.utils.file_utils import (
    get_checkpoints_path,
    get_resource_path,
)
from apps.shark_studio.modules.shared_cmd_opts import (
    cmd_opts,
)
from iree import runtime as ireert
from pathlib import Path
import gc
import os

@@ -19,86 +25,152 @@ class SharkPipelineBase:
        self,
        model_map: dict,
        base_model_id: str,
        static_kwargs: dict,
        device: str,
        import_mlir: bool = True,
    ):
        self.model_map = model_map
        self.static_kwargs = static_kwargs
        self.base_model_id = base_model_id
        self.device = device
        self.import_mlir = import_mlir
        self.iree_module_dict = {}
        self.tempfiles = {}

    def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
        # First checks whether we have .vmfbs precompiled, then populates the map
        # with the precompiled executables and fetches executables for the rest of the map.
        # The weights aren't static here anymore, so this function should be a part of pipeline
        # initialization. As soon as you have a pipeline ID unique to your static torch IR parameters,
        # and your model map is populated with any IR-unique model IDs and their static params,
        # call this method to get the artifacts associated with your map.
        self.pipe_id = pipe_id
        self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(".."), self.pipe_id))
        self.pipe_vmfb_path.mkdir(parents=True, exist_ok=True)
        print("\n[LOG] Checking for pre-compiled artifacts.")
        if submodel == "None":
            for key in self.model_map:
                self.get_compiled_map(pipe_id, submodel=key)
        else:
            self.get_precompiled(pipe_id, submodel)
            ireec_flags = []
            if submodel in self.iree_module_dict:
                if "vmfb" in self.iree_module_dict[submodel]:
                    print(f"[LOG] Found executable for {submodel} at {self.iree_module_dict[submodel]['vmfb']}...")
                    return
            elif submodel not in self.tempfiles:
                print(f"[LOG] Tempfile for {submodel} not found. Fetching torch IR...")
                if submodel in self.static_kwargs:
                    init_kwargs = self.static_kwargs[submodel]
                for key in self.static_kwargs["pipe"]:
                    if key not in init_kwargs:
                        init_kwargs[key] = self.static_kwargs["pipe"][key]
                self.import_torch_ir(
                    submodel, init_kwargs
                )
                self.get_compiled_map(pipe_id, submodel)
            else:
                ireec_flags = self.model_map[submodel]["ireec_flags"] if "ireec_flags" in self.model_map[submodel] else []

                if "external_weights_file" in self.model_map[submodel]:
                    weights_path = self.model_map[submodel]["external_weights_file"]
                else:
                    weights_path = None
                self.iree_module_dict[submodel] = get_iree_compiled_module(
                    self.tempfiles[submodel],
                    device=self.device,
                    frontend="torch",
                    mmap=True,
                    external_weight_file=weights_path,
                    extra_args=ireec_flags,
                    write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
                )
        return

    def hijack_weights(self, weights_path, submodel="None"):
        if submodel == "None":
            for i in self.model_map:
                self.hijack_weights(weights_path, i)
        else:
            if submodel in self.iree_module_dict:
                self.model_map[submodel]["external_weights_file"] = weights_path
        return

    def get_precompiled(self, pipe_id, submodel="None"):
        if submodel == "None":
            for model in self.model_map:
                self.get_precompiled(pipe_id, model)
        vmfbs = []
        vmfb_matches = {}
        vmfbs_path = self.pipe_vmfb_path
        for dirpath, dirnames, filenames in os.walk(vmfbs_path):
            vmfbs.extend(filenames)
            break
        for file in vmfbs:
            if submodel in file:
                print(f"Found existing .vmfb at {file}")
                self.iree_module_dict[submodel] = {}
                (
                    self.iree_module_dict[submodel]["vmfb"],
                    self.iree_module_dict[submodel]["config"],
                    self.iree_module_dict[submodel]["temp_file_to_unlink"],
                ) = load_vmfb_using_mmap(
                    os.path.join(vmfbs_path, file),
                    self.device,
                    device_idx=0,
                    rt_flags=[],
                    external_weight_file=self.model_map[submodel]['external_weight_file'],
                )
        return
    def safe_dict(self, kwargs: dict):
        flat_args = {}
        for i in kwargs:
            if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]:
                flat_args[i] = [kwargs[i][j] for j in kwargs[i]]
            else:
                flat_args[i] = kwargs[i]

        return flat_args

    def import_torch_ir(self, submodel, kwargs):
        weights = (
            submodel["custom_weights"] if submodel["custom_weights"] else None
        )
        torch_ir = self.model_map[submodel]["initializer"](
            self.base_model_id, **kwargs, compile_to="torch"
            **self.safe_dict(kwargs), compile_to="torch"
        )
        self.model_map[submodel]["tempfile_name"] = get_resource_path(
            f"{submodel}.torch.tempfile"
        )
        with open(self.model_map[submodel]["tempfile_name"], "w+") as f:
        if submodel == "clip":
            # clip.export_clip_model returns (torch_ir, tokenizer)
            torch_ir = torch_ir[0]
        self.tempfiles[submodel] = get_resource_path(os.path.join(
            "..", "shark_tmp", f"{submodel}.torch.tempfile"
        ))

        with open(self.tempfiles[submodel], "w+") as f:
            f.write(torch_ir)
        del torch_ir
        gc.collect()
        return
    def load_vmfb(self, submodel):
        if submodel in self.iree_module_dict:
            print(
                f".vmfb for {submodel} found at {self.iree_module_dict[submodel]['vmfb']}"
            )
        elif self.model_map[submodel]["tempfile_name"]:
            submodel["tempfile_name"]

        return submodel["vmfb"]

    def merge_custom_map(self, custom_model_map):
        for submodel in custom_model_map:
            for key in submodel:
                self.model_map[submodel][key] = key
        print(self.model_map)

    def get_local_vmfbs(self, pipe_id):
        for submodel in self.model_map:
            vmfbs = []
            vmfb_matches = {}
            vmfbs_path = get_checkpoints_path("../vmfbs")
            for dirpath, dirnames, filenames in os.walk(vmfbs_path):
                vmfbs.extend(filenames)
                break
            for file in vmfbs:
                if all(keys in file for keys in [submodel, pipe_id]):
                    print(f"Found existing .vmfb at {file}")
                    self.iree_module_dict[submodel] = {"vmfb": file}

    def get_compiled_map(self, device, pipe_id) -> None:
        # this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
        if not self.import_mlir:
            self.get_local_vmfbs(pipe_id)
        for submodel in self.model_map:
    def load_submodels(self, submodels: list):
        for submodel in submodels:
            if submodel in self.iree_module_dict:
                if "vmfb" in self.iree_module_dict[submodel]:
                    continue
            if "tempfile_name" not in self.model_map[submodel]:
                sub_kwargs = (
                    self.model_map[submodel]["kwargs"]
                    if self.model_map[submodel]["kwargs"]
                    else {}
                )
                self.import_torch_ir(
                    submodel, self.base_model_id, **sub_kwargs
                )
            self.iree_module_dict[submodel] = get_iree_compiled_module(
                submodel["tempfile_name"],
                device=self.device,
                frontend="torch",
                external_weight_file=submodel["custom_weights"],
                print(
                    f"\n[LOG] Loading .vmfb for {submodel} from {self.iree_module_dict[submodel]['vmfb']}"
                )
            # TODO: delete the temp file
            else:
                self.get_compiled_map(self.pipe_id, submodel)
        return

    def run(self, submodel, inputs):
        return
        inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, inputs)]
        return self.iree_module_dict[submodel]['vmfb']['main'](*inp)


def safe_name(name):
    return name.replace("/", "_").replace("-", "_")
apps/shark_studio/modules/prompt_encoding.py (new file, 431 lines)
@@ -0,0 +1,431 @@

from typing import List, Optional, Union
from iree import runtime as ireert
import re
import numpy as np
import torch

re_attention = re.compile(
    r"""
    \\\(|
    \\\)|
    \\\[|
    \\]|
    \\\\|
    \\|
    \(|
    \[|
    :([+-]?[.\d]+)\)|
    \)|
    ]|
    [^\\()\[\]:]+|
    :
    """,
    re.X,
)

def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs:
    text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
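
A quick sanity check of the parser on a mixed prompt, consistent with the doctests above:

print(parse_prompt_attention("a (cat:1.2) and [dog]"))
# [['a ', 1.0], ['cat', 1.2], [' and ', 1.0], ['dog', 0.9090909090909091]]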
def get_prompts_with_weights(
    pipe, prompt: List[str], max_length: int
):
    r"""
    Tokenize a list of prompts and return its tokens with weights of each token.
    No padding, starting or ending token is included.
    """
    tokens = []
    weights = []
    truncated = False
    for text in prompt:
        texts_and_weights = parse_prompt_attention(text)
        text_token = []
        text_weight = []
        for word, weight in texts_and_weights:
            # tokenize and discard the starting and the ending token
            token = pipe.tokenizer(word).input_ids[1:-1]
            text_token += token
            # copy the weight by length of token
            text_weight += [weight] * len(token)
            # stop if the text is too long (longer than truncation limit)
            if len(text_token) > max_length:
                truncated = True
                break
        # truncate
        if len(text_token) > max_length:
            truncated = True
            text_token = text_token[:max_length]
            text_weight = text_weight[:max_length]
        tokens.append(text_token)
        weights.append(text_weight)
    if truncated:
        print(
            "Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
        )
    return tokens, weights
def pad_tokens_and_weights(
    tokens,
    weights,
    max_length,
    bos,
    eos,
    no_boseos_middle=True,
    chunk_length=77,
):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
    """
    max_embeddings_multiples = 8
    weights_length = (
        max_length
        if no_boseos_middle
        else max_embeddings_multiples * chunk_length
    )
    for i in range(len(tokens)):
        tokens[i] = (
            [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
        )
        if no_boseos_middle:
            weights[i] = (
                [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
            )
        else:
            w = []
            if len(weights[i]) == 0:
                w = [1.0] * weights_length
            else:
                for j in range(max_embeddings_multiples):
                    w.append(1.0)  # weight for starting token in this chunk
                    w += weights[i][
                        j
                        * (chunk_length - 2) : min(
                            len(weights[i]), (j + 1) * (chunk_length - 2)
                        )
                    ]
                    w.append(1.0)  # weight for ending token in this chunk
            w += [1.0] * (weights_length - len(w))
            weights[i] = w[:]

    return tokens, weights
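
A small worked example of the padding (BOS/EOS ids assumed to be CLIP's 49406/49407):

tokens = [[320, 1125]]          # two content tokens
weights = [[1.0, 1.3]]
tokens, weights = pad_tokens_and_weights(
    tokens, weights, max_length=77, bos=49406, eos=49407
)
print(len(tokens[0]), len(weights[0]))  # 77 77
print(tokens[0][:4], weights[0][:4])    # [49406, 320, 1125, 49407] [1.0, 1.0, 1.3, 1.0]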

def get_unweighted_text_embeddings(
    pipe,
    text_input: torch.Tensor,
    chunk_length: int,
    no_boseos_middle: Optional[bool] = True,
):
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[
                :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
            ].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            text_input_chunk[:, -1] = text_input[0, -1]

            text_embedding = pipe.run("clip", text_input_chunk)[0]

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        # SHARK: Convert the result to tensor
        # text_embeddings = torch.concat(text_embeddings, axis=1)
        text_embeddings_np = np.concatenate(np.array(text_embeddings))
        text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
    else:
        text_embeddings = pipe.run("clip", text_input)[0]
        # text_embeddings = torch.from_numpy(text_embeddings)[None, :]
        return torch.from_numpy(text_embeddings.to_host())
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = 8
    text_embeddings = []
    for i in range(max_embeddings_multiples):
        # extract the i-th chunk
        text_input_chunk = text_input[
            :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
        ].clone()

        # cover the head and the tail by the starting and the ending tokens
        text_input_chunk[:, 0] = text_input[0, 0]
        text_input_chunk[:, -1] = text_input[0, -1]
        # text_embedding = pipe.text_encoder(text_input_chunk)[0]

        print(text_input_chunk)
        breakpoint()
        text_embedding = pipe.run("clip", text_input_chunk)
        if no_boseos_middle:
            if i == 0:
                # discard the ending token
                text_embedding = text_embedding[:, :-1]
            elif i == max_embeddings_multiples - 1:
                # discard the starting token
                text_embedding = text_embedding[:, 1:]
            else:
                # discard both starting and ending tokens
                text_embedding = text_embedding[:, 1:-1]

        text_embeddings.append(text_embedding)
    # SHARK: Convert the result to tensor
    # text_embeddings = torch.concat(text_embeddings, axis=1)
    text_embeddings_np = np.concatenate(np.array(text_embeddings))
    text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
    return text_embeddings

# This function deals with NoneType values occurring in tokens after padding.
# It switches out None for 49407 (CLIP's end-of-text token id), since
# truncating None values causes matrix dimension errors.
def filter_nonetype_tokens(tokens: List[List]):
    return [[49407 if token is None else token for token in tokens[0]]]

def get_weighted_text_embeddings(
    pipe,
    prompt: List[str],
    uncond_prompt: List[str] = None,
    max_embeddings_multiples: Optional[int] = 8,
    no_boseos_middle: Optional[bool] = False,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
):
    max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2

    if not skip_parsing:
        prompt_tokens, prompt_weights = get_prompts_with_weights(
            pipe, prompt, max_length - 2
        )
        if uncond_prompt is not None:
            uncond_tokens, uncond_weights = get_prompts_with_weights(
                pipe, uncond_prompt, max_length - 2
            )
    else:
        prompt_tokens = [
            token[1:-1]
            for token in pipe.tokenizer(
                prompt, max_length=max_length, truncation=True
            ).input_ids
        ]
        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens = [
                token[1:-1]
                for token in pipe.tokenizer(
                    uncond_prompt, max_length=max_length, truncation=True
                ).input_ids
            ]
            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])
    if uncond_prompt is not None:
        max_length = max(
            max_length, max([len(token) for token in uncond_tokens])
        )

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (pipe.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = pipe.tokenizer.bos_token_id
    eos = pipe.tokenizer.eos_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=pipe.model_max_length,
    )

    # FIXME: This is a hacky fix caused by tokenizer padding with None values
    prompt_tokens = filter_nonetype_tokens(prompt_tokens)

    # prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
    if uncond_prompt is not None:
        uncond_tokens, uncond_weights = pad_tokens_and_weights(
            uncond_tokens,
            uncond_weights,
            max_length,
            bos,
            eos,
            no_boseos_middle=no_boseos_middle,
            chunk_length=pipe.model_max_length,
        )

        # FIXME: This is a hacky fix caused by tokenizer padding with None values
        uncond_tokens = filter_nonetype_tokens(uncond_tokens)

        # uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
        uncond_tokens = torch.tensor(
            uncond_tokens, dtype=torch.long, device="cpu"
        )

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        pipe,
        prompt_tokens,
        pipe.model_max_length,
        no_boseos_middle=no_boseos_middle,
    )
    # prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
    prompt_weights = torch.tensor(
        prompt_weights, dtype=torch.float, device="cpu"
    )
    if uncond_prompt is not None:
        uncond_embeddings = get_unweighted_text_embeddings(
            pipe,
            uncond_tokens,
            pipe.model_max_length,
            no_boseos_middle=no_boseos_middle,
        )
        # uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
        uncond_weights = torch.tensor(
            uncond_weights, dtype=torch.float, device="cpu"
        )

    # assign weights to the prompts and normalize in the sense of mean
    # TODO: should we normalize by chunk or in a whole (current implementation)?
    if (not skip_parsing) and (not skip_weighting):
        previous_mean = (
            text_embeddings.float()
            .mean(axis=[-2, -1])
            .to(text_embeddings.dtype)
        )
        text_embeddings *= prompt_weights.unsqueeze(-1)
        current_mean = (
            text_embeddings.float()
            .mean(axis=[-2, -1])
            .to(text_embeddings.dtype)
        )
        text_embeddings *= (
            (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
        )
        if uncond_prompt is not None:
            previous_mean = (
                uncond_embeddings.float()
                .mean(axis=[-2, -1])
                .to(uncond_embeddings.dtype)
            )
            uncond_embeddings *= uncond_weights.unsqueeze(-1)
            current_mean = (
                uncond_embeddings.float()
                .mean(axis=[-2, -1])
                .to(uncond_embeddings.dtype)
            )
            uncond_embeddings *= (
                (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
            )

    if uncond_prompt is not None:
        return text_embeddings, uncond_embeddings
    return text_embeddings, None
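
The weighting step above rescales the embeddings so that applying per-token weights does not shift the overall activation statistics. A standalone sketch of that mean-preserving rescale (shapes assumed):

import torch

emb = torch.randn(1, 77, 768)
weights = torch.ones(1, 77)
weights[0, 5:10] = 1.3                      # emphasize a few tokens

previous_mean = emb.float().mean(axis=[-2, -1])
emb = emb * weights.unsqueeze(-1)           # apply per-token weights
current_mean = emb.float().mean(axis=[-2, -1])
emb = emb * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

# overall mean is restored while relative token emphasis changes
print(torch.isclose(emb.float().mean(axis=[-2, -1]), previous_mean))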
@@ -1,4 +1,105 @@
# from shark_turbine.turbine_models.schedulers import export_scheduler_model
from diffusers import (
    LCMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    DDPMScheduler,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    KDPM2DiscreteScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    DEISMultistepScheduler,
    DPMSolverSinglestepScheduler,
    KDPM2AncestralDiscreteScheduler,
    HeunDiscreteScheduler,
)


def get_schedulers(model_id):
    # TODO: switch over to turbine and run all on GPU
    print(f"[LOG] Initializing schedulers from model id: {model_id}")
    schedulers = dict()
    schedulers["PNDM"] = PNDMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    schedulers["DDPM"] = DDPMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    schedulers["DDIM"] = DDIMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    schedulers[
        "DPMSolverMultistep"
    ] = DPMSolverMultistepScheduler.from_pretrained(
        model_id, subfolder="scheduler", algorithm_type="dpmsolver"
    )
    schedulers[
        "DPMSolverMultistep++"
    ] = DPMSolverMultistepScheduler.from_pretrained(
        model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
    )
    schedulers[
        "DPMSolverMultistepKarras"
    ] = DPMSolverMultistepScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
        use_karras_sigmas=True,
    )
    schedulers[
        "DPMSolverMultistepKarras++"
    ] = DPMSolverMultistepScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
        algorithm_type="dpmsolver++",
        use_karras_sigmas=True,
    )
    schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    schedulers[
        "EulerAncestralDiscrete"
    ] = EulerAncestralDiscreteScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    schedulers[
        "DPMSolverSinglestep"
    ] = DPMSolverSinglestepScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    schedulers[
        "KDPM2AncestralDiscrete"
    ] = KDPM2AncestralDiscreteScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
    )
    return schedulers
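
Typical use of the returned map (a sketch; from_pretrained fetches the scheduler config, which requires access to the Hugging Face hub):

schedulers = get_schedulers("runwayml/stable-diffusion-v1-5")
sched = schedulers["EulerDiscrete"]
sched.set_timesteps(30)    # standard diffusers scheduler API
print(sched.timesteps[:3])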


def export_scheduler_model(model):
@@ -453,12 +453,6 @@ p.add_argument(
    "Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)

p.add_argument(
    "--custom_model_map",
    type=str,
    default="",
    help="path to custom model map to import. This should be a .json file",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################
@@ -126,26 +126,6 @@ def webui():
    #
    # uvicorn.run(api, host="0.0.0.0", port=args.server_port)
    # sys.exit(0)
    # Setup to use shark_tmp for gradio's temporary image files and clear any
    # existing temporary images there if they exist. Then we can import gradio.
    # It has to be in this order or gradio ignores what we've set up.
    from apps.shark_studio.web.utils.tmp_configs import (
        config_tmp,
        clear_tmp_mlir,
        clear_tmp_imgs,
    )
    from apps.shark_studio.web.utils.file_utils import (
        create_checkpoint_folders,
    )

    import gradio as gr

    config_tmp()
    clear_tmp_mlir()
    clear_tmp_imgs()

    # Create custom models folders if they don't exist
    create_checkpoint_folders()


def resource_path(relative_path):
    """Get absolute path to resource, works for dev and for PyInstaller"""
@@ -50,7 +50,7 @@ def get_generated_imgs_todays_subdir() -> str:


def create_checkpoint_folders():
    dir = ["vae", "lora"]
    dir = ["vae", "lora", "../vmfb"]
    if not cmd_opts.ckpt_dir:
        dir.insert(0, "models")
    else:
@@ -106,7 +106,6 @@ def get_iree_frontend_args(frontend):
# Common args to be used given any frontend or device.
def get_iree_common_args(debug=False):
    common_args = [
        "--iree-stream-resource-max-allocation-size=4294967295",
        "--iree-vm-bytecode-module-strip-source-map=true",
        "--iree-util-zero-fill-elided-attrs",
    ]