import gc
import torch
import gradio as gr
import time
import os
import json
import numpy as np
import copy
import importlib.util
import sys

from tqdm.auto import tqdm

from pathlib import Path
from random import randint
from turbine_models.custom_models.sd_inference.sd_pipeline import AMDSharkSDPipeline
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
    AMDSharkSDXLPipeline,
)

from apps.amdshark_studio.api.controlnet import control_adapter_map
from apps.amdshark_studio.api.utils import parse_device
from apps.amdshark_studio.web.utils.state import status_label
from apps.amdshark_studio.web.utils.file_utils import (
    safe_name,
    get_resource_path,
    get_checkpoints_path,
)

from apps.amdshark_studio.modules.img_processing import (
    save_output_img,
)

from apps.amdshark_studio.modules.ckpt_processing import (
    preprocessCKPT,
    save_irpa,
)

# These maps act as templates: prepare_pipe() deep-copies the active
# model_map to track the MLIR sources, compiled modules (vmfbs), and
# external weight files for each submodel in the pipeline.
EMPTY_SD_MAP = {
    "clip": None,
    "scheduler": None,
    "unet": None,
    "vae_decode": None,
}

EMPTY_SDXL_MAP = {
    "prompt_encoder": None,
    "scheduled_unet": None,
    "vae_decode": None,
    "pipeline": None,
    "full_pipeline": None,
}

EMPTY_FLAGS = {
    "clip": None,
    "unet": None,
    "vae": None,
    "pipeline": None,
}

def load_script(source, module_name):
    """
    Reads file source and loads it as a module.

    :param source: file to load
    :param module_name: name of module to register in sys.modules
    :return: loaded module
    """
    spec = importlib.util.spec_from_file_location(module_name, source)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module

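# A minimal sketch of the custom-script contract, assuming a hypothetical
# "my_pipeline.py" placed under the scripts checkpoint folder that defines
# the StudioPipeline and MODEL_MAP attributes this module expects
# (StableDiffusion below loads custom pipelines the same way):
#
#   module = load_script(
#       os.path.join(get_checkpoints_path("scripts"), "my_pipeline.py"),
#       "custom_pipeline",
#   )
#   pipe_cls, model_map = module.StudioPipeline, module.MODEL_MAP
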
class StableDiffusion:
    """
    This class is responsible for executing image generation and creating/
    managing a set of compiled modules to run Stable Diffusion. The init
    aims to be as general as possible, and the class will infer and compile
    a list of necessary modules or a combined "pipeline module" for a
    specified job based on the inference task.
    """

    def __init__(
        self,
        base_model_id,
        height: int,
        width: int,
        batch_size: int,
        steps: int,
        scheduler: str,
        precision: str,
        device: str,
        target_triple: str = None,
        custom_vae: str = None,
        num_loras: int = 0,
        import_ir: bool = True,
        is_controlled: bool = False,
        external_weights: str = "safetensors",
    ):
        self.precision = precision
        self.compiled_pipeline = False
        self.base_model_id = base_model_id
        self.custom_vae = custom_vae
        self.is_sdxl = "xl" in self.base_model_id.lower()
        self.is_custom = ".py" in self.base_model_id.lower()
        # Select the turbine pipeline class and submodel map: a custom
        # script, SDXL, or base SD.
        if self.is_custom:
            custom_module = load_script(
                os.path.join(get_checkpoints_path("scripts"), self.base_model_id),
                "custom_pipeline",
            )
            self.turbine_pipe = custom_module.StudioPipeline
            self.model_map = custom_module.MODEL_MAP
        elif self.is_sdxl:
            self.turbine_pipe = AMDSharkSDXLPipeline
            self.model_map = EMPTY_SDXL_MAP
        else:
            self.turbine_pipe = AMDSharkSDPipeline
            self.model_map = EMPTY_SD_MAP
        max_length = 64
        target_backend, self.rt_device, triple = parse_device(device, target_triple)
        # pipe_id encodes every compile-time parameter, so cached artifacts
        # are only reused for an identical configuration.
        pipe_id_list = [
            safe_name(base_model_id),
            str(batch_size),
            str(max_length),
            f"{str(height)}x{str(width)}",
            precision,
            triple,
        ]
        if num_loras > 0:
            pipe_id_list.append(str(num_loras) + "lora")
        if is_controlled:
            pipe_id_list.append("controlled")
        if custom_vae:
            pipe_id_list.append(custom_vae)
        self.pipe_id = "_".join(pipe_id_list)
        self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
        self.weights_path = Path(
            os.path.join(
                get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision)
            )
        )
        if not os.path.exists(self.weights_path):
            os.mkdir(self.weights_path)

        # Pick attention codegen options for the target architecture:
        # MFMA intrinsics on CDNA (gfx90a/94x), WMMA on RDNA3 (gfx11xx);
        # attention is decomposed on other GPU targets.
        decomp_attn = True
        attn_spec = None
        if triple in ["gfx940", "gfx942", "gfx90a"]:
            decomp_attn = False
            attn_spec = "mfma"
        elif triple in ["gfx1100", "gfx1103", "gfx1150"]:
            decomp_attn = False
            attn_spec = "wmma"
            if triple in ["gfx1103", "gfx1150"]:
                # external weights have issues on igpu
                external_weights = None
        elif target_backend == "llvm-cpu":
            decomp_attn = False

        self.sd_pipe = self.turbine_pipe(
            hf_model_name=base_model_id,
            scheduler_id=scheduler,
            height=height,
            width=width,
            precision=precision,
            max_length=max_length,
            batch_size=batch_size,
            num_inference_steps=steps,
            device=target_backend,
            iree_target_triple=triple,
            ireec_flags=EMPTY_FLAGS,
            attn_spec=attn_spec,
            decomp_attn=decomp_attn,
            pipeline_dir=self.pipeline_dir,
            external_weights_dir=self.weights_path,
            external_weights=external_weights,
            custom_vae=custom_vae,
        )
        print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
        gc.collect()

    def prepare_pipe(
        self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
    ):
        """
        Compiles (or fetches cached) artifacts for each submodel and loads
        the pipeline onto the runtime device. Note: img2img is not wired up
        here yet, so is_img2img is currently forced to False.
        """
        print("\n[LOG] Preparing pipeline...")
        self.is_img2img = False
        mlirs = copy.deepcopy(self.model_map)
        vmfbs = copy.deepcopy(self.model_map)
        weights = copy.deepcopy(self.model_map)
        if not self.is_sdxl:
            # The one-shot compiled pipeline is only available for SDXL.
            compiled_pipeline = False
        self.compiled_pipeline = compiled_pipeline

        if custom_weights:
            # Convert a custom checkpoint to diffusers layout, then pack each
            # submodel's safetensors into .irpa archives for the runtime.
            custom_weights = os.path.join(
                get_checkpoints_path("checkpoints"),
                safe_name(self.base_model_id.split("/")[-1]),
                custom_weights,
            )
            diffusers_weights_path = preprocessCKPT(custom_weights, self.precision)
            for key in weights:
                if key in ["scheduled_unet", "unet"]:
                    unet_weights_path = os.path.join(
                        diffusers_weights_path,
                        "unet",
                        "diffusion_pytorch_model.safetensors",
                    )
                    weights[key] = save_irpa(unet_weights_path, "unet.")

                elif key in ["clip", "prompt_encoder"]:
                    if not self.is_sdxl:
                        sd1_path = os.path.join(
                            diffusers_weights_path, "text_encoder", "model.safetensors"
                        )
                        weights[key] = save_irpa(sd1_path, "text_encoder_model.")
                    else:
                        clip_1_path = os.path.join(
                            diffusers_weights_path, "text_encoder", "model.safetensors"
                        )
                        clip_2_path = os.path.join(
                            diffusers_weights_path,
                            "text_encoder_2",
                            "model.safetensors",
                        )
                        weights[key] = [
                            save_irpa(clip_1_path, "text_encoder_model_1."),
                            save_irpa(clip_2_path, "text_encoder_model_2."),
                        ]

                elif key in ["vae_decode"] and weights[key] is None:
                    vae_weights_path = os.path.join(
                        diffusers_weights_path,
                        "vae",
                        "diffusion_pytorch_model.safetensors",
                    )
                    weights[key] = save_irpa(vae_weights_path, "vae.")

        vmfbs, weights = self.sd_pipe.check_prepared(
            mlirs, vmfbs, weights, interactive=False
        )
        print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
        self.sd_pipe.load_pipeline(
            vmfbs, weights, self.rt_device, self.compiled_pipeline
        )
        print(
            "\n[LOG] Pipeline successfully prepared for runtime. Generating images..."
        )
        return

    def generate_images(
        self,
        prompt,
        negative_prompt,
        image,
        strength,
        guidance_scale,
        seed,
        ondemand,
        resample_type,
        control_mode,
        hints,
    ):
        """
        Runs the loaded pipeline and returns the generated image(s). The
        img2img/controlnet arguments (image, strength, resample_type,
        control_mode, hints) and ondemand are accepted for interface
        compatibility but are not yet forwarded to the turbine pipeline.
        """
        img = self.sd_pipe.generate_images(
            prompt,
            negative_prompt,
            1,  # a single batch per call; batching is handled by the caller
            guidance_scale,
            seed,
            return_imgs=True,
        )
        return img

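# A minimal end-to-end sketch of the class above. The scheduler id, device
# string, and generation arguments here are illustrative assumptions;
# amdshark_sd_fn below performs this same flow with caching via global_obj:
#
#   sd = StableDiffusion(
#       "stabilityai/stable-diffusion-xl-base-1.0",
#       height=1024, width=1024, batch_size=1, steps=20,
#       scheduler="EulerDiscrete", precision="fp16", device="rocm",
#   )
#   sd.prepare_pipe(None, {}, {}, False, True)
#   images = sd.generate_images(
#       "a photo of a corgi", "", None, 1.0, 7.5,
#       randint(0, 2**32), False, "Nearest", None, [],
#   )
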
def amdshark_sd_fn_dict_input(
    sd_kwargs: dict,
):
    """
    Normalizes a dict of generation parameters (e.g. from the web UI or a
    JSON config), validates it, and delegates to amdshark_sd_fn, yielding
    its (images, status) updates.
    """
    print("\n[LOG] Submitting Request...")

    for key in sd_kwargs:
        if sd_kwargs[key] in [None, []]:
            sd_kwargs[key] = None
        if sd_kwargs[key] in ["None"]:
            sd_kwargs[key] = ""
        if key == "seed":
            sd_kwargs[key] = int(sd_kwargs[key])

    # TODO: move these checks into the UI code so we don't have gradio
    # warnings in a generalized dict input function.
    if not sd_kwargs["device"]:
        gr.Warning("No device specified. Please specify a device.")
        return None, ""
    if sd_kwargs["height"] not in [512, 1024]:
        gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
        return None, ""
    if sd_kwargs["height"] != sd_kwargs["width"]:
        gr.Warning("Height and width must be the same. This is a temporary limitation.")
        return None, ""
    if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
        if sd_kwargs["steps"] > 10:
            gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
            return None, ""
        if sd_kwargs["guidance_scale"] > 3:
            gr.Warning(
                "sdxl-turbo CFG scale should be less than 2.0 if using a negative prompt, 0 otherwise."
            )
            return None, ""
    if sd_kwargs["target_triple"] == "":
        if parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2] == "":
            gr.Warning(
                "Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
            )
            return None, ""

    generated_imgs = yield from amdshark_sd_fn(**sd_kwargs)
    return generated_imgs

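# Illustrative consumption sketch; the __main__ block at the bottom drives
# this generator the same way, with sd_kwargs loaded from
# default_sd_config.json:
#
#   for imgs, status in amdshark_sd_fn_dict_input(sd_kwargs):
#       print(status)
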
def amdshark_sd_fn(
    prompt,
    negative_prompt,
    sd_init_image: list,
    height: int,
    width: int,
    steps: int,
    strength: float,
    guidance_scale: float,
    seed: list,
    batch_count: int,
    batch_size: int,
    scheduler: str,
    base_model_id: str,
    custom_weights: str,
    custom_vae: str,
    precision: str,
    device: str,
    target_triple: str,
    ondemand: bool,
    compiled_pipeline: bool,
    resample_type: str,
    controlnets: dict,
    embeddings: dict,
):
    """
    Generator for image generation jobs. Builds the pipeline, prep, and run
    kwargs, reuses the cached pipeline from global_obj when the requested
    configuration matches, and yields (generated_imgs, status) after each
    batch.
    """
    sd_kwargs = locals()
    if not isinstance(sd_init_image, list):
        sd_init_image = [sd_init_image]
    is_img2img = True if sd_init_image[0] is not None else False

    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
    import apps.amdshark_studio.web.utils.globals as global_obj

    adapters = {}
    is_controlled = False
    control_mode = None
    hints = []
    num_loras = 0
    import_ir = True
    for i in embeddings:
        num_loras += 1 if embeddings[i] else 0
    if "model" in controlnets:
        for i, model in enumerate(controlnets["model"]):
            if "xl" not in base_model_id.lower():
                adapters[f"control_adapter_{model}"] = {
                    "hf_id": control_adapter_map["runwayml/stable-diffusion-v1-5"][
                        model
                    ],
                    "strength": controlnets["strength"][i],
                }
            else:
                adapters[f"control_adapter_{model}"] = {
                    "hf_id": control_adapter_map["stabilityai/stable-diffusion-xl-1.0"][
                        model
                    ],
                    "strength": controlnets["strength"][i],
                }
            if model is not None:
                is_controlled = True
        control_mode = controlnets["control_mode"]
        for i in controlnets["hint"]:
            hints.append(i)

    submit_pipe_kwargs = {
        "base_model_id": base_model_id,
        "height": height,
        "width": width,
        "batch_size": batch_size,
        "precision": precision,
        "device": device,
        "target_triple": target_triple,
        "custom_vae": custom_vae,
        "num_loras": num_loras,
        "import_ir": import_ir,
        "is_controlled": is_controlled,
        "steps": steps,
        "scheduler": scheduler,
    }
    submit_prep_kwargs = {
        "custom_weights": custom_weights,
        "adapters": adapters,
        "embeddings": embeddings,
        "is_img2img": is_img2img,
        "compiled_pipeline": compiled_pipeline,
    }
    submit_run_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image": sd_init_image,
        "strength": strength,
        "guidance_scale": guidance_scale,
        "seed": seed,
        "ondemand": ondemand,
        "resample_type": resample_type,
        "control_mode": control_mode,
        "hints": hints,
    }
    if (
        not global_obj.get_sd_obj()
        or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
    ):
        print("\n[LOG] Initializing new pipeline...")
        global_obj.clear_cache()
        gc.collect()

        # Initializes the pipeline and retrieves IR based on all
        # parameters that are static in the turbine output format,
        # which is currently MLIR in the torch dialect.
        sd_pipe = StableDiffusion(
            **submit_pipe_kwargs,
        )
        global_obj.set_sd_obj(sd_pipe)
        global_obj.set_pipe_kwargs(submit_pipe_kwargs)
    if (
        not global_obj.get_prep_kwargs()
        or global_obj.get_prep_kwargs() != submit_prep_kwargs
    ):
        global_obj.set_prep_kwargs(submit_prep_kwargs)
        global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)

    generated_imgs = []
    for current_batch in range(batch_count):
        start_time = time.time()
        out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
        if not isinstance(out_imgs, list):
            out_imgs = [out_imgs]
        # total_time = time.time() - start_time
        # text_output = f"Total image(s) generation time: {total_time:.4f}sec"
        # print(f"\n[LOG] {text_output}")
        # if global_obj.get_sd_status() == SD_STATE_CANCEL:
        #     break
        # else:
        for batch in range(batch_size):
            save_output_img(
                out_imgs[batch],
                seed,
                sd_kwargs,
            )
        generated_imgs.extend(out_imgs)
        # TODO: make seed changes over batch counts more configurable.
        submit_run_kwargs["seed"] = submit_run_kwargs["seed"] + 1
        yield generated_imgs, status_label(
            "Stable Diffusion", current_batch + 1, batch_count, batch_size
        )
    return (generated_imgs, "")

def unload_sd():
    """Drops the cached pipeline and frees its memory."""
    print("Unloading models.")
    import apps.amdshark_studio.web.utils.globals as global_obj

    global_obj.clear_cache()
    gc.collect()


def cancel_sd():
    # Placeholder: cancellation of long-running API calls is not implemented.
    print("Inject call to cancel longer API calls.")
    return

def view_json_file(file_path):
    """Returns the contents of a JSON file as a string."""
    content = ""
    with open(file_path, "r") as fopen:
        content = fopen.read()
    return content


# Note: this overrides the safe_name imported from file_utils above for the
# rest of this module.
def safe_name(name):
    return name.replace("/", "_").replace("\\", "_").replace(".", "_")

if __name__ == "__main__":
    from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
    import apps.amdshark_studio.web.utils.globals as global_obj

    global_obj._init()

    sd_json = view_json_file(
        get_resource_path(os.path.join(cmd_opts.config_dir, "default_sd_config.json"))
    )
    sd_kwargs = json.loads(sd_json)
    # Command-line options override the defaults from the JSON config.
    for arg in vars(cmd_opts):
        if arg in sd_kwargs:
            sd_kwargs[arg] = getattr(cmd_opts, arg)
    for i in amdshark_sd_fn_dict_input(sd_kwargs):
        print(i)