Model loading

Ean Garvey
2023-12-16 23:14:44 -06:00
parent 82a68ee6f6
commit 1f13facde4
10 changed files with 883 additions and 221 deletions

View File

@@ -1,14 +1,17 @@
import importlib
import logging
import os
import signal
import sys
import re
import warnings
import json
from threading import Thread
from apps.shark_studio.modules.timer import startup_timer
from apps.shark_studio.web.utils.tmp_configs import (
config_tmp,
clear_tmp_mlir,
clear_tmp_imgs,
)
def imports():
@@ -21,6 +24,9 @@ def imports():
warnings.filterwarnings(
action="ignore", category=UserWarning, module="torchvision"
)
warnings.filterwarnings(
action="ignore", category=UserWarning, message='.*is deprecated, please use.*', module="*torch*"
)
import gradio # noqa: F401
@@ -34,20 +40,27 @@ def imports():
from apps.shark_studio.modules import (
img_processing,
) # noqa: F401
from apps.shark_studio.modules.schedulers import scheduler_model_map
startup_timer.record("other imports")
def initialize():
configure_sigint_handler()
# Set up shark_tmp for gradio's temporary image files and clear any existing
# temporary images there. This must happen before gradio is imported, or
# gradio ignores what we've set up.
# from apps.shark_studio.modules import modelloader
# modelloader.cleanup_models()
config_tmp()
clear_tmp_mlir()
clear_tmp_imgs()
# from apps.shark_studio.modules import sd_models
# sd_models.setup_model()
# startup_timer.record("setup SD model")
from apps.shark_studio.web.utils.file_utils import (
create_checkpoint_folders,
)
# Create custom models folders if they don't exist
create_checkpoint_folders()
import gradio as gr
# initialize_rest(reload_script_modules=False)

View File

@@ -1,125 +1,61 @@
import gc
from unittest import registerResult
import torch
import time
import os
import json
import numpy as np
from pathlib import Path
from turbine_models.custom_models.sd_inference import clip, unet, vae
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.web.utils.state import status_label
from apps.shark_studio.web.utils.file_utils import safe_name, get_resource_path
from apps.shark_studio.web.utils.file_utils import safe_name, get_resource_path, get_checkpoints_path
from apps.shark_studio.modules.pipeline import SharkPipelineBase
from apps.shark_studio.modules.schedulers import get_schedulers
from apps.shark_studio.modules.prompt_encoding import get_weighted_text_embeddings
from apps.shark_studio.modules.img_processing import (
resize_stencil,
save_output_img,
)
from apps.shark_studio.modules.ckpt_processing import (
preprocessCKPT,
process_custom_pipe_weights,
)
from transformers import CLIPTokenizer
from math import ceil
from PIL import Image
sd_model_map = {
"CompVis/stable-diffusion-v1-4": {
"clip": {
"initializer": clip.export_clip_model,
"max_tokens": 64,
},
"vae_encode": {
"initializer": vae.export_vae_model,
"max_tokens": 64,
},
"unet": {
"initializer": unet.export_unet_model,
"max_tokens": 512,
},
"vae_decode": {
"initializer": vae.export_vae_model,
"max_tokens": 64,
},
"clip": {
"initializer": clip.export_clip_model,
"external_weight_file": None,
"ireec_flags": ["--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
],
},
"runwayml/stable-diffusion-v1-5": {
"clip": {
"initializer": clip.export_clip_model,
"max_tokens": 64,
},
"vae_encode": {
"initializer": vae.export_vae_model,
"max_tokens": 64,
},
"unet": {
"initializer": unet.export_unet_model,
"max_tokens": 512,
},
"vae_decode": {
"initializer": vae.export_vae_model,
"max_tokens": 64,
},
"vae_encode": {
"initializer": vae.export_vae_model,
"external_weight_file": None,
},
"stabilityai/stable-diffusion-2-1-base": {
"clip": {
"initializer": clip.export_clip_model,
"max_tokens": 64,
},
"vae_encode": {
"initializer": vae.export_vae_model,
"max_tokens": 64,
},
"unet": {
"initializer": unet.export_unet_model,
"max_tokens": 512,
},
"vae_decode": {
"initializer": vae.export_vae_model,
"max_tokens": 64,
},
"unet": {
"initializer": unet.export_unet_model,
"ireec_flags": ["--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
],
"external_weight_file": None,
},
"stabilityai/stable_diffusion-xl-1.0": {
"clip_1": {
"initializer": clip.export_clip_model,
"max_tokens": 64,
},
"clip_2": {
"initializer": clip.export_clip_model,
"max_tokens": 64,
},
"vae_encode": {
"initializer": vae.export_vae_model,
"max_tokens": 64,
},
"unet": {
"initializer": unet.export_unet_model,
"max_tokens": 512,
},
"vae_decode": {
"initializer": vae.export_vae_model,
"max_tokens": 64,
},
"vae_decode": {
"initializer": vae.export_vae_model,
"external_weight_file": None,
},
}
def get_spec(custom_sd_map: dict, sd_embeds: dict):
spec = []
for key in custom_sd_map:
if "control" in key.split("_"):
spec.append("controlled")
elif key == "custom_vae":
spec.append(custom_sd_map[key]["custom_weights"].split(".")[0])
num_embeds = 0
embeddings_spec = None
for embed in sd_embeds:
if embed is not None:
num_embeds += 1
embeddings_spec = str(num_embeds) + "embeds"
if embeddings_spec:
spec.append(embeddings_spec)
return "_".join(spec)
class StableDiffusion(SharkPipelineBase):
# This class is responsible for executing image generation and creating
# /managing a set of compiled modules to run Stable Diffusion. The init
# aims to be as general as possible, and the class will infer and compile
@@ -132,39 +68,143 @@ class StableDiffusion(SharkPipelineBase):
# embeddings: a dict of embedding checkpoints or model IDs to use when
# initializing the compiled modules.
def __init__(
self,
base_model_id: str = "runwayml/stable-diffusion-v1-5",
height: int = 512,
width: int = 512,
precision: str = "fp16",
device: str = None,
custom_model_map: dict = {},
embeddings: dict = {},
base_model_id,
height: int,
width: int,
batch_size: int,
precision: str,
device: str,
custom_vae: str = None,
num_loras: int = 0,
import_ir: bool = True,
is_img2img: bool = False,
is_controlled: bool = False,
):
super().__init__(
sd_model_map[base_model_id], base_model_id, device, import_ir
)
self.model_max_length = 77
self.batch_size = batch_size
self.precision = precision
self.is_img2img = is_img2img
self.pipe_id = (
safe_name(base_model_id)
+ str(height)
+ str(width)
+ precision
+ device
+ get_spec(custom_model_map, embeddings)
self.scheduler_obj = {}
self.precision = precision
static_kwargs = {
"pipe": {},
"clip": {"hf_model_name": base_model_id},
"unet": {
"hf_model_name": base_model_id,
"unet_model": unet.UnetModel(hf_model_name=base_model_id, hf_auth_token=None),
"batch_size": batch_size,
#"is_controlled": is_controlled,
#"num_loras": num_loras,
"height": height,
"width": width,
},
"vae_encode": {
"hf_model_name": custom_vae if custom_vae else base_model_id,
"vae_model": vae.VaeModel(hf_model_name=base_model_id, hf_auth_token=None),
"batch_size": batch_size,
"height": height,
"width": width,
},
"vae_decode": {
"hf_model_name": custom_vae,
"vae_model": vae.VaeModel(hf_model_name=base_model_id, hf_auth_token=None),
"batch_size": batch_size,
"height": height,
"width": width,
},
}
super().__init__(
sd_model_map, base_model_id, static_kwargs, device, import_ir
)
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}")
pipe_id_list = [
safe_name(base_model_id),
str(batch_size),
f"{str(height)}x{str(width)}",
precision,
]
if num_loras > 0:
pipe_id_list.append(str(num_loras)+"lora")
if is_controlled:
pipe_id_list.append("controlled")
if custom_vae:
pipe_id_list.append(custom_vae)
self.pipe_id = "_".join(pipe_id_list)
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
del static_kwargs
gc.collect()
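For reference (illustrative values, not part of the commit), the pipe_id assembled above for the v1-5 defaults looks like:
# base_model_id="runwayml/stable-diffusion-v1-5", batch_size=1, 512x512, fp16,
# no LoRA, no ControlNet, no custom VAE:
#   safe_name(base_model_id) -> "runwayml_stable_diffusion_v1_5"
#   self.pipe_id             -> "runwayml_stable_diffusion_v1_5_1_512x512_fp16"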
def prepare_pipe(self, scheduler, custom_model_map, embeddings):
def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings):
print(
f"\n[LOG] Preparing pipeline with scheduler {scheduler}, custom map {json.dumps(custom_model_map)}, and embeddings {json.dumps(embeddings)}."
f"\n[LOG] Preparing pipeline with scheduler {scheduler}"
f"\n[LOG] Custom embeddings currently unsupported."
)
self.get_compiled_map(device=self.device, pipe_id=self.pipe_id)
return None
schedulers = get_schedulers(self.base_model_id)
self.weights_path = get_checkpoints_path(self.pipe_id)
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)
# accepting a list of schedulers in batched cases.
for i in scheduler:
self.scheduler_obj[i] = schedulers[i]
print(f"[LOG] Loaded scheduler: {i}")
for model in adapters:
self.model_map[model] = adapters[model]
if os.path.isfile(custom_weights):
for i in self.model_map:
self.model_map[i]["external_weights_file"] = None
elif custom_weights != "":
print(f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?")
self.static_kwargs["pipe"] = {
# "external_weight_path": self.weights_path,
# "external_weights": "safetensors",
}
self.get_compiled_map(pipe_id=self.pipe_id)
print("\n[LOG] Pipeline successfully prepared for runtime.")
return
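A minimal usage sketch (hypothetical values; assumes the constructor above has run):
sd = StableDiffusion(
    base_model_id="runwayml/stable-diffusion-v1-5",
    height=512, width=512, batch_size=1,
    precision="fp16", device="vulkan",  # the device string is an assumption
)
sd.prepare_pipe(
    scheduler=["EulerDiscrete"],  # iterated, so pass a list even for one scheduler
    custom_weights="",            # empty string skips the custom-checkpoint branch
    adapters={},                  # no ControlNet adapters
    embeddings=[],                # custom embeddings are logged as unsupported
)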
def encode_prompts_weight(
self,
prompt,
negative_prompt,
do_classifier_free_guidance=True,
):
# Encodes the prompt into text encoder hidden states.
self.load_submodels(["clip"])
self.tokenizer = CLIPTokenizer.from_pretrained(
self.base_model_id,
subfolder="tokenizer",
)
clip_inf_start = time.time()
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt
if do_classifier_free_guidance
else None,
)
if do_classifier_free_guidance:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
pad = (0, 0) * (len(text_embeddings.shape) - 2)
pad = pad + (0, 512 - text_embeddings.shape[1])
text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
# SHARK: Report clip inference time
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
gc.collect()
print(f"\n[LOG] Clip Inference time (ms) = {clip_inf_time:.3f}")
return text_embeddings.numpy().astype(np.float16)
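To make the padding above concrete (worked example, not part of the commit), for an embedding tensor of shape (2, 77, 768):
# pad = (0, 0) * (3 - 2)      -> (0, 0)           last dim (768) untouched
# pad = pad + (0, 512 - 77)   -> (0, 0, 0, 435)   pad the token dim on the right
# torch.nn.functional.pad(text_embeddings, pad) -> shape (2, 512, 768),
# i.e. the embeddings are padded out to the fixed 512-token length expected downstream.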
def generate_images(
self,
@@ -181,11 +221,35 @@ class StableDiffusion(SharkPipelineBase):
hints,
):
print("\n[LOG] Generating images...")
batched_args=[
prompt,
negative_prompt,
steps,
strength,
guidance_scale,
seed,
resample_type,
control_mode,
hints,
]
for i, arg in enumerate(batched_args):  # write the normalized value back so batching takes effect
if not isinstance(arg, list):
arg = [arg] * self.batch_size
if len(arg) < self.batch_size:
arg = arg * self.batch_size
else:
arg = [arg[j] for j in range(self.batch_size)]
batched_args[i] = arg
text_embeddings = self.encode_prompts_weight(
prompt,
negative_prompt,
)
print(text_embeddings)
test_img = [
Image.open(
get_resource_path("../../tests/jupiter.png"), mode="r"
).convert("RGB")
]
] * self.batch_size
return test_img
@@ -257,29 +321,31 @@ def shark_sd_fn(
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj
custom_model_map = {}
adapters = {}
is_controlled = False
control_mode = None
hints = []
if custom_weights != "None":
custom_model_map["unet"] = {"custom_weights": custom_weights}
if custom_vae != "None":
custom_model_map["vae"] = {"custom_weights": custom_vae}
num_loras = 0
for i in embeddings:
num_loras += 1 if embeddings[i] else 0
if "model" in controlnets:
for i, model in enumerate(controlnets["model"]):
if "xl" not in base_model_id.lower():
custom_model_map[f"control_adapter_{model}"] = {
adapters[f"control_adapter_{model}"] = {
"hf_id": control_adapter_map[
"runwayml/stable-diffusion-v1-5"
][model],
"strength": controlnets["strength"][i],
}
else:
custom_model_map[f"control_adapter_{model}"] = {
adapters[f"control_adapter_{model}"] = {
"hf_id": control_adapter_map[
"stabilityai/stable-diffusion-xl-1.0"
][model],
"strength": controlnets["strength"][i],
}
if model is not None:
is_controlled=True
control_mode = controlnets["control_mode"]
for i in controlnets["hint"]:
hints.append(i)
@@ -288,16 +354,19 @@ def shark_sd_fn(
"base_model_id": base_model_id,
"height": height,
"width": width,
"batch_size": batch_size,
"precision": precision,
"device": device,
"custom_model_map": custom_model_map,
"embeddings": embeddings,
"custom_vae": custom_vae,
"num_loras": num_loras,
"import_ir": cmd_opts.import_mlir,
"is_img2img": is_img2img,
"is_controlled": is_controlled,
}
submit_prep_kwargs = {
"scheduler": scheduler,
"custom_model_map": custom_model_map,
"custom_weights": custom_weights,
"adapters": adapters,
"embeddings": embeddings,
}
submit_run_kwargs = {
@@ -313,6 +382,7 @@ def shark_sd_fn(
"control_mode": control_mode,
"hints": hints,
}
print(submit_pipe_kwargs)
if (
not global_obj.get_sd_obj()
or global_obj.get_pipe_kwargs() != submit_pipe_kwargs

View File

@@ -92,6 +92,8 @@ def process_custom_pipe_weights(custom_weights):
custom_weights_tgt = get_path_to_diffusers_checkpoint(
custom_weights
)
custom_weights_params = custom_weights
return custom_weights_params, custom_weights_tgt
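For context (illustrative; the path is hypothetical), the two return values above are the original checkpoint path and its diffusers-converted counterpart:
# params, tgt = process_custom_pipe_weights("models/anything-v3.safetensors")
# params -> "models/anything-v3.safetensors"              (custom_weights_params)
# tgt    -> get_path_to_diffusers_checkpoint(...) result  (custom_weights_tgt)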
def get_civitai_checkpoint(url: str):

View File

@@ -1,8 +1,14 @@
from shark.iree_utils.compile_utils import get_iree_compiled_module
from shark.iree_utils.compile_utils import get_iree_compiled_module, load_vmfb_using_mmap
from apps.shark_studio.web.utils.file_utils import (
get_checkpoints_path,
get_resource_path,
)
from apps.shark_studio.modules.shared_cmd_opts import (
cmd_opts,
)
from iree import runtime as ireert
from pathlib import Path
import gc
import os
@@ -19,86 +25,152 @@ class SharkPipelineBase:
self,
model_map: dict,
base_model_id: str,
static_kwargs: dict,
device: str,
import_mlir: bool = True,
):
self.model_map = model_map
self.static_kwargs = static_kwargs
self.base_model_id = base_model_id
self.device = device
self.import_mlir = import_mlir
self.iree_module_dict = {}
self.tempfiles = {}
def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
# First checks whether we have precompiled .vmfbs, then populates the map
# with those executables and fetches executables for the rest of the map.
# The weights aren't static here anymore, so this function should be part of
# pipeline initialization: once you have a pipeline ID unique to your static
# torch IR parameters, and your model map is populated with IR-unique model IDs
# and their static params, call this method to get the artifacts for your map.
self.pipe_id = pipe_id
self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(".."), self.pipe_id))
self.pipe_vmfb_path.mkdir(parents=True, exist_ok=True)
print("\n[LOG] Checking for pre-compiled artifacts.")
if submodel == "None":
for key in self.model_map:
self.get_compiled_map(pipe_id, submodel=key)
else:
self.get_precompiled(pipe_id, submodel)
ireec_flags = []
if submodel in self.iree_module_dict:
if "vmfb" in self.iree_module_dict[submodel]:
print(f"[LOG] Found executable for {submodel} at {self.iree_module_dict[submodel]['vmfb']}...")
return
elif submodel not in self.tempfiles:
print(f"[LOG] Tempfile for {submodel} not found. Fetching torch IR...")
if submodel in self.static_kwargs:
init_kwargs = self.static_kwargs[submodel]
for key in self.static_kwargs["pipe"]:
if key not in init_kwargs:
init_kwargs[key] = self.static_kwargs["pipe"][key]
self.import_torch_ir(
submodel, init_kwargs
)
self.get_compiled_map(pipe_id, submodel)
else:
ireec_flags = self.model_map[submodel]["ireec_flags"] if "ireec_flags" in self.model_map[submodel] else []
if "external_weights_file" in self.model_map[submodel]:
weights_path = self.model_map[submodel]["external_weights_file"]
else:
weights_path = None
self.iree_module_dict[submodel] = get_iree_compiled_module(
self.tempfiles[submodel],
device=self.device,
frontend="torch",
mmap=True,
external_weight_file=weights_path,
extra_args=ireec_flags,
write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb")
)
return
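Sketch of the lookup order above (illustrative; `pipe` is a hypothetical SharkPipelineBase with static_kwargs populated):
# pipe.get_compiled_map(pipe_id="runwayml_stable_diffusion_v1_5_1_512x512_fp16")
# with submodel="None", every key in model_map is visited; per submodel:
#   1. get_precompiled(...)          reuse <pipe_vmfb_path>/<submodel>.vmfb if found
#   2. import_torch_ir(...)          otherwise fetch torch IR into a tempfile
#   3. get_iree_compiled_module(...) compile the tempfile and write <submodel>.vmfb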
def hijack_weights(self, weights_path, submodel="None"):
if submodel == "None":
for i in self.model_map:
self.hijack_weights(weights_path, i)
else:
if submodel in self.iree_module_dict:
self.model_map[submodel]["external_weights_file"] = weights_path
return
def get_precompiled(self, pipe_id, submodel="None"):
if submodel == "None":
for model in self.model_map:
self.get_precompiled(pipe_id, model)
vmfbs = []
vmfb_matches = {}
vmfbs_path = self.pipe_vmfb_path
for dirpath, dirnames, filenames in os.walk(vmfbs_path):
vmfbs.extend(filenames)
break
for file in vmfbs:
if submodel in file:
print(f"Found existing .vmfb at {file}")
self.iree_module_dict[submodel] = {}
(
self.iree_module_dict[submodel]["vmfb"],
self.iree_module_dict[submodel]["config"],
self.iree_module_dict[submodel]["temp_file_to_unlink"],
) = load_vmfb_using_mmap(
os.path.join(vmfbs_path, file),
self.device,
device_idx=0,
rt_flags=[],
external_weight_file=self.model_map[submodel]['external_weight_file'],
)
return
def safe_dict(self, kwargs: dict):
flat_args = {}
for i in kwargs:
if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]:
flat_args[i] = [kwargs[i][j] for j in kwargs[i]]
else:
flat_args[i] = kwargs[i]
return flat_args
def import_torch_ir(self, submodel, kwargs):
weights = (
submodel["custom_weights"] if submodel["custom_weights"] else None
)
torch_ir = self.model_map[submodel]["initializer"](
self.base_model_id, **kwargs, compile_to="torch"
**self.safe_dict(kwargs), compile_to="torch"
)
self.model_map[submodel]["tempfile_name"] = get_resource_path(
f"{submodel}.torch.tempfile"
)
with open(self.model_map[submodel]["tempfile_name"], "w+") as f:
if submodel == "clip":
# clip.export_clip_model returns (torch_ir, tokenizer)
torch_ir = torch_ir[0]
self.tempfiles[submodel] = get_resource_path(os.path.join(
"..", "shark_tmp", f"{submodel}.torch.tempfile"
))
with open(self.tempfiles[submodel], "w+") as f:
f.write(torch_ir)
del torch_ir
gc.collect()
return
def load_vmfb(self, submodel):
if submodel in self.iree_module_dict:
print(
f".vmfb for {submodel} found at {self.iree_module_dict[submodel]['vmfb']}"
)
elif self.model_map[submodel]["tempfile_name"]:
submodel["tempfile_name"]
return submodel["vmfb"]
def merge_custom_map(self, custom_model_map):
for submodel in custom_model_map:
for key in submodel:
self.model_map[submodel][key] = key
print(self.model_map)
def get_local_vmfbs(self, pipe_id):
for submodel in self.model_map:
vmfbs = []
vmfb_matches = {}
vmfbs_path = get_checkpoints_path("../vmfbs")
for dirpath, dirnames, filenames in os.walk(vmfbs_path):
vmfbs.extend(filenames)
break
for file in vmfbs:
if all(keys in file for keys in [submodel, pipe_id]):
print(f"Found existing .vmfb at {file}")
self.iree_module_dict[submodel] = {"vmfb": file}
def get_compiled_map(self, device, pipe_id) -> None:
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
if not self.import_mlir:
self.get_local_vmfbs(pipe_id)
for submodel in self.model_map:
def load_submodels(self, submodels: list):
for submodel in submodels:
if submodel in self.iree_module_dict:
if "vmfb" in self.iree_module_dict[submodel]:
continue
if "tempfile_name" not in self.model_map[submodel]:
sub_kwargs = (
self.model_map[submodel]["kwargs"]
if self.model_map[submodel]["kwargs"]
else {}
)
self.import_torch_ir(
submodel, self.base_model_id, **sub_kwargs
)
self.iree_module_dict[submodel] = get_iree_compiled_module(
submodel["tempfile_name"],
device=self.device,
frontend="torch",
external_weight_file=submodel["custom_weights"],
print(
f"\n[LOG] Loading .vmfb for {submodel} from {self.iree_module_dict[submodel]['vmfb']}"
)
# TODO: delete the temp file
else:
self.get_compiled_map(self.pipe_id, submodel)
return
def run(self, submodel, inputs):
return
inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, inputs)]
return self.iree_module_dict[submodel]['vmfb']['main'](*inp)
def safe_name(name):
return name.replace("/", "_").replace("-", "_")

View File

@@ -0,0 +1,431 @@
from typing import List, Optional, Union
from iree import runtime as ireert
import re
import torch
import numpy as np
re_attention = re.compile(
r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs:
text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(
pipe, prompt: List[str], max_length: int
):
r"""
Tokenize a list of prompts and return their tokens together with each token's weight.
No padding, start, or end tokens are included.
"""
tokens = []
weights = []
truncated = False
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
text_weight = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
truncated = True
break
# truncate
if len(text_token) > max_length:
truncated = True
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print(
"Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
)
return tokens, weights
def pad_tokens_and_weights(
tokens,
weights,
max_length,
bos,
eos,
no_boseos_middle=True,
chunk_length=77,
):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = 8
weights_length = (
max_length
if no_boseos_middle
else max_embeddings_multiples * chunk_length
)
for i in range(len(tokens)):
tokens[i] = (
[bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
)
if no_boseos_middle:
weights[i] = (
[1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
)
else:
w = []
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][
j
* (chunk_length - 2) : min(
len(weights[i]), (j + 1) * (chunk_length - 2)
)
]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
return tokens, weights
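A small worked example of the no_boseos_middle path above (hypothetical token ids; 49406/49407 are CLIP's usual bos/eos):
# tokens = [[320, 1125]], weights = [[1.0, 1.3]], max_length = 77
# tokens[0]  -> [49406, 320, 1125, 49407, 49407, ..., 49407]   length 77, eos-padded
# weights[0] -> [1.0,   1.0, 1.3,  1.0,   1.0,   ..., 1.0]     length 77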
def get_unweighted_text_embeddings(
pipe,
text_input: torch.Tensor,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
):
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.run("clip", text_input_chunk)[0]
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
# SHARK: Convert the result to tensor
# text_embeddings = torch.concat(text_embeddings, axis=1)
text_embeddings_np = np.concatenate(np.array(text_embeddings))
text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
else:
text_embeddings = pipe.run("clip", text_input)[0]
# text_embeddings = torch.from_numpy(text_embeddings)[None, :]
return torch.from_numpy(text_embeddings.to_host())
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = 8
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
# text_embedding = pipe.text_encoder(text_input_chunk)[0]
print(text_input_chunk)
breakpoint()
text_embedding = pipe.run("clip", text_input_chunk)
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
# SHARK: Convert the result to tensor
# text_embeddings = torch.concat(text_embeddings, axis=1)
text_embeddings_np = np.concatenate(np.array(text_embeddings))
text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
return text_embeddings
# This function deals with NoneType values occurring in tokens after padding.
# It switches out None for 49407, since truncating None values causes matrix dimension errors.
def filter_nonetype_tokens(tokens: List[List]):
return [[49407 if token is None else token for token in tokens[0]]]
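For example (illustrative input):
# filter_nonetype_tokens([[49406, 320, None, None]]) -> [[49406, 320, 49407, 49407]]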
def get_weighted_text_embeddings(
pipe,
prompt: List[str],
uncond_prompt: List[str] = None,
max_embeddings_multiples: Optional[int] = 8,
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
):
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(
pipe, prompt, max_length - 2
)
if uncond_prompt is not None:
uncond_tokens, uncond_weights = get_prompts_with_weights(
pipe, uncond_prompt, max_length - 2
)
else:
prompt_tokens = [
token[1:-1]
for token in pipe.tokenizer(
prompt, max_length=max_length, truncation=True
).input_ids
]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens = [
token[1:-1]
for token in pipe.tokenizer(
uncond_prompt, max_length=max_length, truncation=True
).input_ids
]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
if uncond_prompt is not None:
max_length = max(
max_length, max([len(token) for token in uncond_tokens])
)
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (pipe.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix caused by tokenizer padding with None values
prompt_tokens = filter_nonetype_tokens(prompt_tokens)
# prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
if uncond_prompt is not None:
uncond_tokens, uncond_weights = pad_tokens_and_weights(
uncond_tokens,
uncond_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix caused by tokenizer padding with None values
uncond_tokens = filter_nonetype_tokens(uncond_tokens)
# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
uncond_tokens = torch.tensor(
uncond_tokens, dtype=torch.long, device="cpu"
)
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe,
prompt_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
prompt_weights = torch.tensor(
prompt_weights, dtype=torch.float, device="cpu"
)
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe,
uncond_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
uncond_weights = torch.tensor(
uncond_weights, dtype=torch.float, device="cpu"
)
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
if (not skip_parsing) and (not skip_weighting):
previous_mean = (
text_embeddings.float()
.mean(axis=[-2, -1])
.to(text_embeddings.dtype)
)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = (
text_embeddings.float()
.mean(axis=[-2, -1])
.to(text_embeddings.dtype)
)
text_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
previous_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
return text_embeddings, uncond_embeddings
return text_embeddings, None
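For context (illustrative prompts), this is how StableDiffusion.encode_prompts_weight in sd.py calls it; the pipe only needs a CLIP tokenizer, a model_max_length (77 there), and a run("clip", token_ids) method:
# emb, uncond = get_weighted_text_embeddings(
#     pipe=sd,
#     prompt=["a (red:1.3) bicycle"],
#     uncond_prompt=["blurry, low quality"],
# )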

View File

@@ -1,4 +1,105 @@
# from shark_turbine.turbine_models.schedulers import export_scheduler_model
from diffusers import (
LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDPMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
def get_schedulers(model_id):
#TODO: switch over to turbine and run all on GPU
print(f"[LOG] Initializing schedulers from model id: {model_id}")
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DDPM"] = DDPMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver"
)
schedulers[
"DPMSolverMultistep++"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
)
schedulers[
"DPMSolverMultistepKarras"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
use_karras_sigmas=True,
)
schedulers[
"DPMSolverMultistepKarras++"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
algorithm_type="dpmsolver++",
use_karras_sigmas=True,
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"EulerAncestralDiscrete"
] = EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverSinglestep"
] = DPMSolverSinglestepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"KDPM2AncestralDiscrete"
] = KDPM2AncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
return schedulers
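A short usage sketch (illustrative; prepare_pipe in sd.py looks schedulers up by these same string keys):
# schedulers = get_schedulers("runwayml/stable-diffusion-v1-5")
# scheduler = schedulers["EulerDiscrete"]
# scheduler.set_timesteps(20)   # standard diffusers scheduler API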
def export_scheduler_model(model):

View File

@@ -453,12 +453,6 @@ p.add_argument(
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
p.add_argument(
"--custom_model_map",
type=str,
default="",
help="path to custom model map to import. This should be a .json file",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################

View File

@@ -126,26 +126,6 @@ def webui():
#
# uvicorn.run(api, host="0.0.0.0", port=args.server_port)
# sys.exit(0)
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
from apps.shark_studio.web.utils.tmp_configs import (
config_tmp,
clear_tmp_mlir,
clear_tmp_imgs,
)
from apps.shark_studio.web.utils.file_utils import (
create_checkpoint_folders,
)
import gradio as gr
config_tmp()
clear_tmp_mlir()
clear_tmp_imgs()
# Create custom models folders if they don't exist
create_checkpoint_folders()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""

View File

@@ -50,7 +50,7 @@ def get_generated_imgs_todays_subdir() -> str:
def create_checkpoint_folders():
dir = ["vae", "lora"]
dir = ["vae", "lora", "../vmfb"]
if not cmd_opts.ckpt_dir:
dir.insert(0, "models")
else:

View File

@@ -106,7 +106,6 @@ def get_iree_frontend_args(frontend):
# Common args to be used given any frontend or device.
def get_iree_common_args(debug=False):
common_args = [
"--iree-stream-resource-max-allocation-size=4294967295",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
]