import gc
import torch
import gradio as gr
import time
import os
import json
import numpy as np
import copy
import importlib.util
import sys
from tqdm.auto import tqdm

from pathlib import Path
from random import randint

from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.api.utils import parse_device
from apps.shark_studio.web.utils.state import status_label
from apps.shark_studio.web.utils.file_utils import (
    safe_name,
    get_resource_path,
    get_checkpoints_path,
)

from apps.shark_studio.modules.img_processing import (
    save_output_img,
)

from subprocess import check_output

EMPTY_SD_MAP = {
    "clip": None,
    "unet": None,
    "vae_decode": None,
}

EMPTY_SDXL_MAP = {
    "prompt_encoder": None,
    "unet": None,
    "vae_decode": None,
}

EMPTY_SD3_MAP = {
    "clip": None,
    "mmdit": None,
    "vae": None,
}

EMPTY_FLAGS = {
    "clip": None,
    "unet": None,
    "mmdit": None,
    "vae": None,
    "pipeline": None,
    "scheduler": None,
}
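
# Per-submodel placeholder maps: the keys name the compiled submodels for each
# pipeline family, and prepare_pipe() fills copies of the active map with paths
# to compiled vmfbs and external weight files. EMPTY_FLAGS mirrors the
# per-module ireec compile-flag slots handed to the turbine pipeline.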


def load_script(source, module_name):
    """
    Reads a file source and loads it as a module.

    :param source: file to load
    :param module_name: name of module to register in sys.modules
    :return: loaded module
    """
    spec = importlib.util.spec_from_file_location(module_name, source)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)

    return module
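
# A minimal usage sketch (hypothetical script name): custom pipeline scripts
# live under the "scripts" checkpoints directory and are expected to expose
# StudioPipeline and MODEL_MAP, as consumed in StableDiffusion.__init__ below.
#
#   module = load_script("scripts/my_pipeline.py", "custom_pipeline")
#   pipe_cls, model_map = module.StudioPipeline, module.MODEL_MAP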


class StableDiffusion:
    # This class is responsible for executing image generation and creating/
    # managing a set of compiled modules to run Stable Diffusion. The init
    # aims to be as general as possible, and the class will infer and compile
    # a list of necessary modules or a combined "pipeline module" for a
    # specified job based on the inference task.

    def __init__(
        self,
        base_model_id,
        height: int,
        width: int,
        batch_size: int,
        precision: str,
        device: str,
        steps: int = 50,
        scheduler_id: str = None,
        clip_device: str = None,
        vae_device: str = None,
        target_triple: str = None,
        custom_vae: str = None,
        num_loras: int = 0,
        import_ir: bool = True,
        is_controlled: bool = False,
        external_weights: str = "safetensors",
        vae_precision: str = "fp16",
        cpu_scheduling: bool = False,
        progress=gr.Progress(),
    ):
        progress(0, desc="Initializing pipeline...")
        self.ui_device = device
        backend, target = parse_device(device, target_triple)
        if clip_device:
            clip_device, clip_target = parse_device(clip_device)
        else:
            clip_device, clip_target = backend, target
        if vae_device:
            vae_device, vae_target = parse_device(vae_device)
        else:
            vae_device, vae_target = backend, target
        devices = {
            "clip": clip_device,
            "mmdit": backend,
            "unet": backend,
            "vae": vae_device,
        }
        targets = {
            "clip": clip_target,
            "mmdit": target,
            "unet": target,
            "vae": vae_target,
        }
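        # CLIP and the VAE may be pinned to their own devices; anything not
        # explicitly overridden falls back to the primary backend and target
        # parsed from `device` above.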
        self.precision = precision
        self.compiled_pipeline = False
        self.base_model_id = base_model_id
        self.custom_vae = custom_vae
        self.is_sdxl = "xl" in self.base_model_id.lower()
        self.is_sd3 = "stable-diffusion-3" in self.base_model_id.lower()
        self.is_custom = ".py" in self.base_model_id.lower()
        if self.is_custom:
            custom_module = load_script(
                os.path.join(get_checkpoints_path("scripts"), self.base_model_id),
                "custom_pipeline",
            )
            self.turbine_pipe = custom_module.StudioPipeline
            self.dynamic_steps = False
            self.model_map = custom_module.MODEL_MAP
        elif self.is_sdxl:
            from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
                SharkSDXLPipeline,
            )

            self.turbine_pipe = SharkSDXLPipeline
            self.dynamic_steps = True
            self.model_map = EMPTY_SDXL_MAP
        elif self.is_sd3:
            from turbine_models.custom_models.sd3_inference.sd3_pipeline import (
                SharkSD3Pipeline,
                empty_pipe_dict,
            )

            self.turbine_pipe = SharkSD3Pipeline
            self.dynamic_steps = True
            self.model_map = EMPTY_SD3_MAP
        else:
            from turbine_models.custom_models.sd_inference.sd_pipeline import (
                SharkSDPipeline,
            )

            self.turbine_pipe = SharkSDPipeline
            self.dynamic_steps = True
            self.model_map = EMPTY_SD_MAP
            # no multi-device yet
            devices = backend
            targets = target
        max_length = 64

        self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), "vmfbs"))
        self.weights_path = Path(os.path.join(get_checkpoints_path(), "weights"))
        if not os.path.exists(self.pipeline_dir):
            os.mkdir(self.pipeline_dir)
        if not os.path.exists(self.weights_path):
            os.mkdir(self.weights_path)

        # Pick attention codegen options based on the resolved target arch.
        decomp_attn = True
        attn_spec = None
        if target in ["gfx940", "gfx942", "gfx90a"]:
            decomp_attn = False
            attn_spec = "mfma"
        elif target in ["gfx1100", "gfx1103", "gfx1150"]:
            decomp_attn = False
            attn_spec = "wmma"
            if target in ["gfx1103", "gfx1150"]:
                # external weights have issues on igpu
                external_weights = None
        elif backend == "llvm-cpu":
            decomp_attn = False
        progress(0.5, desc="Initializing pipeline...")
        self.sd_pipe = self.turbine_pipe(
            hf_model_name=base_model_id,
            scheduler_id=scheduler_id,
            height=height,
            width=width,
            precision=precision,
            max_length=max_length,
            batch_size=batch_size,
            num_inference_steps=steps,
            device=devices,
            iree_target_triple=targets,
            ireec_flags=EMPTY_FLAGS,
            attn_spec=attn_spec,
            decomp_attn=decomp_attn,
            pipeline_dir=self.pipeline_dir,
            external_weights_dir=self.weights_path,
            external_weights=external_weights,
            vae_precision=vae_precision,
            cpu_scheduling=cpu_scheduling,
        )
        progress(1, desc="Pipeline initialized!")
        gc.collect()

    def prepare_pipe(
        self,
        custom_weights,
        adapters,
        embeddings,
        is_img2img,
        compiled_pipeline=False,
        cpu_scheduling=False,
        progress=gr.Progress(),
    ):
        progress(0, desc="Preparing models...")
        pipe_map = copy.deepcopy(self.model_map)
        if compiled_pipeline and self.is_sdxl:
            # The one-shot compiled SDXL pipeline folds scheduling into a
            # scheduled-unet module plus a full-pipeline module.
            pipe_map.pop("scheduler", None)
            pipe_map.pop("unet")
            pipe_map["scheduled_unet"] = None
            pipe_map["full_pipeline"] = None
        if cpu_scheduling:
            pipe_map.pop("scheduler", None)
        self.is_img2img = False  # img2img is currently forced off here
        mlirs = copy.deepcopy(pipe_map)
        vmfbs = copy.deepcopy(pipe_map)
        weights = copy.deepcopy(pipe_map)
        if not self.is_sdxl:
            compiled_pipeline = False
        self.compiled_pipeline = compiled_pipeline

        if custom_weights:
            from apps.shark_studio.modules.ckpt_processing import (
                preprocessCKPT,
                save_irpa,
            )

            custom_weights = os.path.join(
                get_checkpoints_path("checkpoints"),
                safe_name(self.base_model_id.split("/")[-1]),
                custom_weights,
            )
            diffusers_weights_path = preprocessCKPT(custom_weights, self.precision)
            for key in weights:
                if key in ["scheduled_unet", "unet"]:
                    unet_weights_path = os.path.join(
                        diffusers_weights_path,
                        "unet",
                        "diffusion_pytorch_model.safetensors",
                    )
                    weights[key] = save_irpa(unet_weights_path, "unet.")

                elif key in ["clip", "prompt_encoder"]:
                    if not self.is_sdxl:
                        sd1_path = os.path.join(
                            diffusers_weights_path, "text_encoder", "model.safetensors"
                        )
                        weights[key] = save_irpa(sd1_path, "text_encoder_model.")
                    else:
                        clip_1_path = os.path.join(
                            diffusers_weights_path, "text_encoder", "model.safetensors"
                        )
                        clip_2_path = os.path.join(
                            diffusers_weights_path,
                            "text_encoder_2",
                            "model.safetensors",
                        )
                        weights[key] = [
                            save_irpa(clip_1_path, "text_encoder_model_1."),
                            save_irpa(clip_2_path, "text_encoder_model_2."),
                        ]

                elif key in ["vae_decode"] and weights[key] is None:
                    vae_weights_path = os.path.join(
                        diffusers_weights_path,
                        "vae",
                        "diffusion_pytorch_model.safetensors",
                    )
                    weights[key] = save_irpa(vae_weights_path, "vae.")
        progress(0.25, desc=f"Preparing pipeline for {self.ui_device}...")

        vmfbs, weights = self.sd_pipe.check_prepared(
            mlirs, vmfbs, weights, interactive=False
        )
        progress(0.5, desc="Artifacts ready!")
        progress(0.75, desc="Loading models and weights...")

        self.sd_pipe.load_pipeline(vmfbs, weights, self.compiled_pipeline)
        progress(1, desc="Pipeline loaded! Generating images...")
        return

    def generate_images(
        self,
        prompt,
        negative_prompt,
        image,
        strength,
        guidance_scale,
        seed,
        ondemand,
        resample_type,
        control_mode,
        hints,
        steps=None,
        cpu_scheduling=False,
        scheduler_id=None,
        progress=gr.Progress(track_tqdm=True),
    ):
        img = self.sd_pipe.generate_images(
            prompt=prompt,
            negative_prompt=negative_prompt,
            batch_count=1,
            guidance_scale=guidance_scale,
            seed=seed,
            return_imgs=True,
            steps=steps,
            cpu_scheduling=cpu_scheduling,
            scheduler_id=scheduler_id,
            progress=progress,
        )
        return img
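

# A minimal lifecycle sketch (illustrative values; valid device strings and
# model IDs depend on the local IREE install and downloaded checkpoints):
#
#   sd = StableDiffusion(
#       base_model_id="stabilityai/stable-diffusion-xl-base-1.0",
#       height=1024, width=1024, batch_size=1,
#       precision="fp16", device="rocm",
#   )
#   sd.prepare_pipe(custom_weights=None, adapters={}, embeddings={},
#                   is_img2img=False)
#   imgs = sd.generate_images(prompt="a photo of a corgi", ...)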


def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
    print("\n[LOG] Submitting Request...")

    # Normalize UI values: empty selections become None, the literal string
    # "None" becomes empty, and numeric fields arrive as strings.
    for key in sd_kwargs:
        if sd_kwargs[key] in [None, []]:
            sd_kwargs[key] = None
        if sd_kwargs[key] in ["None"]:
            sd_kwargs[key] = ""
        if key in ["steps", "height", "width", "batch_count", "batch_size"]:
            sd_kwargs[key] = int(sd_kwargs[key])
        if key == "seed":
            sd_kwargs[key] = int(sd_kwargs[key])

    # TODO: move these checks into the UI code so we don't have gradio
    # warnings in a generalized dict input function.
    if not sd_kwargs["device"]:
        gr.Warning("No device specified. Please specify a device.")
        return None, ""
    if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
        if sd_kwargs["steps"] > 10:
            gr.Warning(
                "1 to 4 steps are recommended for sdxl-turbo, unless you are "
                "using a custom checkpoint."
            )
        if sd_kwargs["guidance_scale"] > 3:
            gr.Warning(
                "sdxl-turbo CFG scale should be between 1.0 and 2.0 if using "
                "negative prompt, 0 otherwise."
            )
    if sd_kwargs["target_triple"] == "":
        if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[1]:
            gr.Warning(
                "Target device architecture could not be inferred. Please "
                "specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
            )
            return None, ""

    generated_imgs = yield from shark_sd_fn(**sd_kwargs)
    return generated_imgs
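

# Example request dict (abridged; keys mirror the shark_sd_fn signature below
# and values are illustrative):
#
#   sd_kwargs = {
#       "prompt": "a photo of a corgi",
#       "negative_prompt": "",
#       "sd_init_image": [None],
#       "height": 512, "width": 512, "steps": 50,
#       "strength": 1.0, "guidance_scale": 7.5,
#       "seed": -1, "batch_count": 1, "batch_size": 1,
#       ...
#   }
#   for imgs, status in shark_sd_fn_dict_input(sd_kwargs):
#       ...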


def shark_sd_fn(
    prompt,
    negative_prompt,
    sd_init_image: list,
    height: int,
    width: int,
    steps: int,
    strength: float,
    guidance_scale: float,
    seed: list,
    batch_count: int,
    batch_size: int,
    scheduler: str,
    base_model_id: str,
    custom_weights: str,
    custom_vae: str,
    precision: str,
    device: str,
    target_triple: str,
    ondemand: bool,
    compiled_pipeline: bool,
    resample_type: str,
    controlnets: dict,
    embeddings: dict,
    seed_increment: str | int = 1,
    clip_device: str = None,
    vae_device: str = None,
    vae_precision: str = None,
    cpu_scheduling: bool = False,
    progress=gr.Progress(),
):
    sd_kwargs = locals()
    if not isinstance(sd_init_image, list):
        sd_init_image = [sd_init_image]
    is_img2img = sd_init_image[0] is not None

    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
    import apps.shark_studio.web.utils.globals as global_obj

    adapters = {}
    is_controlled = False
    control_mode = None
    hints = []
    num_loras = 0
    import_ir = True
    for i in embeddings:
        num_loras += 1 if embeddings[i] else 0
    if "model" in controlnets:
        for i, model in enumerate(controlnets["model"]):
            if "xl" not in base_model_id.lower():
                adapters[f"control_adapter_{model}"] = {
                    "hf_id": control_adapter_map["runwayml/stable-diffusion-v1-5"][
                        model
                    ],
                    "strength": controlnets["strength"][i],
                }
            else:
                adapters[f"control_adapter_{model}"] = {
                    "hf_id": control_adapter_map[
                        "stabilityai/stable-diffusion-xl-1.0"
                    ][model],
                    "strength": controlnets["strength"][i],
                }
            if model is not None:
                is_controlled = True
        control_mode = controlnets["control_mode"]
        for i in controlnets["hint"]:
            hints.append(i)
    if not vae_precision:
        vae_precision = precision
    submit_pipe_kwargs = {
        "base_model_id": base_model_id,
        "height": height,
        "width": width,
        "batch_size": batch_size,
        "precision": precision,
        "device": device,
        "clip_device": clip_device,
        "vae_device": vae_device,
        "target_triple": target_triple,
        "custom_vae": custom_vae,
        "num_loras": num_loras,
        "import_ir": import_ir,
        "is_controlled": is_controlled,
        "vae_precision": vae_precision,
    }
    submit_prep_kwargs = {
        "custom_weights": custom_weights,
        "adapters": adapters,
        "embeddings": embeddings,
        "is_img2img": is_img2img,
        "compiled_pipeline": compiled_pipeline,
    }
    submit_run_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image": sd_init_image,
        "steps": steps,
        "strength": strength,
        "guidance_scale": guidance_scale,
        "seed": seed,
        "ondemand": ondemand,
        "resample_type": resample_type,
        "control_mode": control_mode,
        "hints": hints,
        "cpu_scheduling": cpu_scheduling,
        "scheduler_id": scheduler,
    }
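    # submit_pipe_kwargs doubles as a cache key: a new StableDiffusion object
    # is only constructed when it differs from the cached pipe kwargs, and
    # submit_prep_kwargs likewise gates re-running prepare_pipe below.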
    if compiled_pipeline:
        submit_pipe_kwargs["steps"] = submit_run_kwargs["steps"]
        submit_pipe_kwargs["scheduler_id"] = submit_run_kwargs["scheduler_id"]
    if (
        not global_obj.get_sd_obj()
        or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
    ):
        print("\n[LOG] Initializing new pipeline...")
        global_obj.clear_cache()
        gc.collect()

        # Initializes the pipeline and retrieves IR based on all
        # parameters that are static in the turbine output format,
        # which is currently MLIR in the torch dialect.

        sd_pipe = StableDiffusion(
            **submit_pipe_kwargs,
        )
        global_obj.set_sd_obj(sd_pipe)
        global_obj.set_pipe_kwargs(submit_pipe_kwargs)
    if (
        not global_obj.get_prep_kwargs()
        or global_obj.get_prep_kwargs() != submit_prep_kwargs
    ):
        global_obj.set_prep_kwargs(submit_prep_kwargs)
        global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)

    generated_imgs = []
    if submit_run_kwargs["seed"] in [-1, "-1"]:
        submit_run_kwargs["seed"] = randint(0, 4294967295)
        seed_increment = "random"
        # print(f"\n[LOG] Random seed: {submit_run_kwargs['seed']}")
    progress(None, desc="Generating...")

    for current_batch in range(batch_count):
        start_time = time.time()
        out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
        if not isinstance(out_imgs, list):
            out_imgs = [out_imgs]
        # total_time = time.time() - start_time
        # text_output = f"Total image(s) generation time: {total_time:.4f}sec"
        # print(f"\n[LOG] {text_output}")
        # if global_obj.get_sd_status() == SD_STATE_CANCEL:
        #     break
        # else:
        for batch in range(batch_size):
            save_output_img(
                out_imgs[batch],
                submit_run_kwargs["seed"],
                sd_kwargs,
            )
            generated_imgs.append(out_imgs[batch])
        yield generated_imgs, status_label(
            "Stable Diffusion", current_batch + 1, batch_count, batch_size
        )
        if batch_count > 1:
            submit_run_kwargs["seed"] = get_next_seed(
                submit_run_kwargs["seed"], seed_increment
            )
    return (generated_imgs, "")


def get_next_seed(seed, seed_increment: str | int = 10):
    if isinstance(seed_increment, int):
        # print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}")
        return int(seed + seed_increment)
    elif seed_increment == "random":
        seed = randint(0, 4294967295)
        # print(f"\n[LOG] Random seed: {seed}")
        return seed
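
# For example, get_next_seed(42, 10) returns 52, while
# get_next_seed(42, "random") draws a fresh 32-bit seed. Any other increment
# type falls through and returns None.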


def unload_sd():
    print("Unloading models.")
    import apps.shark_studio.web.utils.globals as global_obj

    global_obj.clear_cache()
    gc.collect()


def cancel_sd():
    import apps.shark_studio.web.utils.globals as global_obj

    print("Cancelling...")
    global_obj.get_sd_obj()._interrupt = True
    while global_obj.get_sd_obj()._interrupt:
        time.sleep(0.1)
    return


def view_json_file(file_path):
    content = ""
    with open(file_path, "r") as fopen:
        content = fopen.read()
    return content


def safe_name(name):
    return name.replace("/", "_").replace("\\", "_").replace(".", "_")
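# NOTE: this local definition shadows the safe_name imported from
# apps.shark_studio.web.utils.file_utils at the top of the module, so this
# version is the one used at runtime.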


if __name__ == "__main__":
    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
    import apps.shark_studio.web.utils.globals as global_obj

    global_obj._init()

    sd_json = view_json_file(
        get_resource_path(os.path.join(cmd_opts.config_dir, cmd_opts.default_config))
    )
    sd_kwargs = json.loads(sd_json)
    # for arg in vars(cmd_opts):
    #     if arg in sd_kwargs:
    #         sd_kwargs[arg] = getattr(cmd_opts, arg)
    for i in shark_sd_fn_dict_input(sd_kwargs):
        print(i)