SD3 initial support via turbine pipeline

Ean Garvey
2024-06-17 14:17:08 -05:00
parent 70e4c31f44
commit 88db3457e2
6 changed files with 125 additions and 124 deletions

View File

@@ -39,17 +39,25 @@ EMPTY_SD_MAP = {
EMPTY_SDXL_MAP = {
"prompt_encoder": None,
"scheduled_unet": None,
"unet": None,
"vae_decode": None,
"pipeline": None,
"full_pipeline": None,
"scheduler": None,
}
EMPTY_SD3_MAP = {
"clip": None,
"mmdit": None,
"vae": None,
"scheduler": None,
}
EMPTY_FLAGS = {
"clip": None,
"unet": None,
"mmdit": None,
"vae": None,
"pipeline": None,
"scheduler": None,
}
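
These empty maps serve as templates for tracking per-submodel artifacts; later in this commit they are deep-copied into separate MLIR/vmfb/weights trackers. A minimal sketch of that pattern, using the EMPTY_SD3_MAP keys above (the filled-in path is hypothetical):

import copy

EMPTY_SD3_MAP = {"clip": None, "mmdit": None, "vae": None, "scheduler": None}

# One independent tracker per artifact kind, so recording a compiled
# .vmfb never mutates the template or the other trackers.
mlirs = copy.deepcopy(EMPTY_SD3_MAP)
vmfbs = copy.deepcopy(EMPTY_SD3_MAP)
weights = copy.deepcopy(EMPTY_SD3_MAP)

vmfbs["mmdit"] = "sd3_mmdit_fp16_gfx1100.vmfb"  # hypothetical artifact path
print([k for k, v in vmfbs.items() if v is None])  # ['clip', 'vae', 'scheduler']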
@@ -86,21 +94,50 @@ class StableDiffusion:
scheduler: str,
precision: str,
device: str,
clip_device: str = None,
vae_device: str = None,
target_triple: str = None,
custom_vae: str = None,
num_loras: int = 0,
import_ir: bool = True,
is_controlled: bool = False,
external_weights: str = "safetensors",
vae_precision: str = "fp16",
progress=gr.Progress(),
):
progress(0, desc="Initializing pipeline...")
self.ui_device = device
backend, target = parse_device(device, target_triple)
if clip_device:
clip_device, clip_target = parse_device(clip_device)
else:
clip_device, clip_target = backend, target
if vae_device:
vae_device, vae_target = parse_device(vae_device)
else:
vae_device, vae_target = backend, target
devices = {
"clip": clip_device,
"mmdit": backend,
"vae": vae_device,
}
targets = {
"clip": clip_target,
"mmdit": target,
"vae": vae_target,
}
pipe_device_id = backend
target_triple = target
for key in devices:
if devices[key] != backend:
pipe_device_id = "hybrid"
target_triple = "_".join([clip_target, target, vae_target])
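
A worked example of the hybrid-device bookkeeping above, with illustrative backends and triples standing in for whatever parse_device returns on a given machine:

# Illustrative values only; the real ones come from parse_device().
backend, target = "rocm", "gfx1100"                        # main (mmdit) device
clip_device, clip_target = "llvm-cpu", "x86_64-linux-gnu"
vae_device, vae_target = "rocm", "gfx1100"

devices = {"clip": clip_device, "mmdit": backend, "vae": vae_device}

pipe_device_id, target_triple = backend, target
if any(dev != backend for dev in devices.values()):
    pipe_device_id = "hybrid"
    target_triple = "_".join([clip_target, target, vae_target])

print(pipe_device_id)  # hybrid
print(target_triple)   # x86_64-linux-gnu_gfx1100_gfx1100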
self.precision = precision
self.compiled_pipeline = False
self.base_model_id = base_model_id
self.custom_vae = custom_vae
self.is_sdxl = "xl" in self.base_model_id.lower()
self.is_sd3 = "stable-diffusion-3" in self.base_model_id.lower()
self.is_custom = ".py" in self.base_model_id.lower()
if self.is_custom:
custom_module = load_script(
@@ -115,23 +152,32 @@ class StableDiffusion:
SharkSDXLPipeline,
)
self.turbine_pipe = SharkSDXLPipeline
- self.dynamic_steps = False
+ self.dynamic_steps = True
self.model_map = EMPTY_SDXL_MAP
elif self.is_sd3:
from turbine_models.custom_models.sd3_inference.sd3_pipeline import SharkSD3Pipeline, empty_pipe_dict
self.turbine_pipe = SharkSD3Pipeline
self.dynamic_steps = True
self.model_map = EMPTY_SD3_MAP
else:
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
self.turbine_pipe = SharkSDPipeline
self.dynamic_steps = True
self.model_map = EMPTY_SD_MAP
# no multi-device yet
devices = backend
targets = target
max_length = 64
- target_backend, self.rt_device, triple = parse_device(device, target_triple)
pipe_id_list = [
safe_name(base_model_id),
str(batch_size),
str(max_length),
f"{str(height)}x{str(width)}",
precision,
- triple,
+ target_triple,
]
if num_loras > 0:
pipe_id_list.append(str(num_loras) + "lora")
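
The list above is presumably joined into a unique pipeline directory name. A sketch of the resulting ID, assuming safe_name simply makes the model ID filesystem-safe (its definition is not part of this diff):

def safe_name(name: str) -> str:
    # Assumed behavior: replace separators that are unsafe in paths.
    return name.replace("/", "_").replace("-", "_").replace(".", "_")

pipe_id_list = [
    safe_name("stabilityai/stable-diffusion-3-medium-diffusers"),
    "1",        # batch_size
    "64",       # max_length
    "512x512",  # height x width
    "fp16",     # precision
    "gfx1100",  # target_triple (illustrative)
]
num_loras = 0
if num_loras > 0:
    pipe_id_list.append(str(num_loras) + "lora")
print("_".join(pipe_id_list))
# stabilityai_stable_diffusion_3_medium_diffusers_1_64_512x512_fp16_gfx1100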
@@ -151,16 +197,16 @@ class StableDiffusion:
decomp_attn = True
attn_spec = None
if triple in ["gfx940", "gfx942", "gfx90a"]:
if target_triple in ["gfx940", "gfx942", "gfx90a"]:
decomp_attn = False
attn_spec = "mfma"
- elif triple in ["gfx1100", "gfx1103", "gfx1150"]:
+ elif target in ["gfx1100", "gfx1103", "gfx1150"]:
decomp_attn = False
attn_spec = "wmma"
if triple in ["gfx1103", "gfx1150"]:
if target in ["gfx1103", "gfx1150"]:
# external weights have issues on igpu
external_weights = None
- elif target_backend == "llvm-cpu":
+ elif backend == "llvm-cpu":
decomp_attn = False
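
The chain above selects an attention strategy per architecture: fused attention with an "mfma" spec on CDNA GPUs (gfx90a/gfx940/gfx942), "wmma" on RDNA3 (gfx1100/gfx1103/gfx1150, with external weights disabled on the iGPU variants), and decomposed attention everywhere else, except that llvm-cpu also opts out of decomposition. The same logic restated as a standalone helper, a sketch for illustration only:

def select_attn_config(target_triple, backend, external_weights="safetensors"):
    """Return (decomp_attn, attn_spec, external_weights) for a target."""
    decomp_attn, attn_spec = True, None
    if target_triple in ("gfx940", "gfx942", "gfx90a"):
        decomp_attn, attn_spec = False, "mfma"        # CDNA matrix cores
    elif target_triple in ("gfx1100", "gfx1103", "gfx1150"):
        decomp_attn, attn_spec = False, "wmma"        # RDNA3 WMMA
        if target_triple in ("gfx1103", "gfx1150"):
            external_weights = None                   # iGPU weight issue
    elif backend == "llvm-cpu":
        decomp_attn = False
    return decomp_attn, attn_spec, external_weights

print(select_attn_config("gfx1100", "rocm"))  # (False, 'wmma', 'safetensors')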
progress(0.5, desc="Initializing pipeline...")
self.sd_pipe = self.turbine_pipe(
@@ -172,15 +218,15 @@ class StableDiffusion:
max_length=max_length,
batch_size=batch_size,
num_inference_steps=steps,
- device=target_backend,
- iree_target_triple=triple,
+ device=devices,
+ iree_target_triple=targets,
ireec_flags=EMPTY_FLAGS,
attn_spec=attn_spec,
decomp_attn=decomp_attn,
pipeline_dir=self.pipeline_dir,
external_weights_dir=self.weights_path,
external_weights=external_weights,
custom_vae=custom_vae,
vae_precision=vae_precision,
)
progress(1, desc="Pipeline initialized!...")
gc.collect()
@@ -191,15 +237,23 @@ class StableDiffusion:
adapters,
embeddings,
is_img2img,
- compiled_pipeline,
+ compiled_pipeline=False,
+ cpu_scheduling=False,
progress=gr.Progress(),
):
progress(0, desc="Preparing models...")
pipe_map = copy.deepcopy(self.model_map)
if compiled_pipeline and self.is_sdxl:
pipe_map.pop("scheduler")
pipe_map.pop("unet")
pipe_map["scheduled_unet"] = None
pipe_map["full_pipeline"] = None
if cpu_scheduling:
pipe_map.pop("scheduler")
self.is_img2img = False
- mlirs = copy.deepcopy(self.model_map)
- vmfbs = copy.deepcopy(self.model_map)
- weights = copy.deepcopy(self.model_map)
+ mlirs = copy.deepcopy(pipe_map)
+ vmfbs = copy.deepcopy(pipe_map)
+ weights = copy.deepcopy(pipe_map)
if not self.is_sdxl:
compiled_pipeline = False
self.compiled_pipeline = compiled_pipeline
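
A worked example of the map pruning above for the compiled SDXL path: the standalone scheduler and unet entries are swapped for fused scheduled_unet and full_pipeline entries, so only those artifacts get exported and compiled. Illustrative, with a trimmed copy of the SDXL key set:

import copy

EMPTY_SDXL_MAP = {
    "prompt_encoder": None, "scheduled_unet": None, "unet": None,
    "vae_decode": None, "full_pipeline": None, "scheduler": None,
}

pipe_map = copy.deepcopy(EMPTY_SDXL_MAP)
compiled_pipeline, cpu_scheduling = True, False
if compiled_pipeline:
    pipe_map.pop("scheduler")
    pipe_map.pop("unet")
    pipe_map["scheduled_unet"] = None
    pipe_map["full_pipeline"] = None
if cpu_scheduling and "scheduler" in pipe_map:
    pipe_map.pop("scheduler")  # scheduling happens on the host instead
print(sorted(pipe_map))
# ['full_pipeline', 'prompt_encoder', 'scheduled_unet', 'vae_decode']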
@@ -260,7 +314,7 @@ class StableDiffusion:
progress(0.75, desc="Loading models and weights...")
self.sd_pipe.load_pipeline(
- vmfbs, weights, self.rt_device, self.compiled_pipeline
+ vmfbs, weights, self.compiled_pipeline
)
progress(1, desc="Pipeline loaded! Generating images...")
return
@@ -324,7 +378,7 @@ def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
)
return None, ""
if sd_kwargs["target_triple"] == "":
- if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2]:
+ if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[1]:
gr.Warning(
"Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
)
@@ -359,6 +413,9 @@ def shark_sd_fn(
controlnets: dict,
embeddings: dict,
seed_increment: str | int = 1,
clip_device: str = None,
vae_device: str = None,
vae_precision: str = None,
progress=gr.Progress(),
):
sd_kwargs = locals()
@@ -398,7 +455,8 @@ def shark_sd_fn(
control_mode = controlnets["control_mode"]
for i in controlnets["hint"]:
hints.append(i)
if not vae_precision:
vae_precision = precision
submit_pipe_kwargs = {
"base_model_id": base_model_id,
"height": height,
@@ -406,6 +464,8 @@ def shark_sd_fn(
"batch_size": batch_size,
"precision": precision,
"device": device,
"clip_device": clip_device,
"vae_device": vae_device,
"target_triple": target_triple,
"custom_vae": custom_vae,
"num_loras": num_loras,
@@ -413,6 +473,7 @@ def shark_sd_fn(
"is_controlled": is_controlled,
"steps": steps,
"scheduler": scheduler,
"vae_precision": vae_precision,
}
submit_prep_kwargs = {
"custom_weights": custom_weights,
@@ -485,7 +546,7 @@ def shark_sd_fn(
sd_kwargs,
)
generated_imgs.extend(out_imgs)
yield generated_imgs, status_label(
"Stable Diffusion", current_batch + 1, batch_count, batch_size
)
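
shark_sd_fn is a generator: after each batch it yields the accumulated image list plus a status string, which Gradio streams to the UI as incremental updates. A minimal self-contained sketch of the same pattern (stand-in strings instead of images):

import gradio as gr

def fake_generate(batch_count, progress=gr.Progress()):
    total = int(batch_count)
    generated = []
    for i in range(total):
        progress((i + 1) / total, desc=f"Batch {i + 1}/{total}")
        generated.append(f"image_{i}.png")  # stand-in for a PIL image
        # Each yield updates both outputs in place, like the yield above.
        yield generated, f"Batch {i + 1} of {total} done"

with gr.Blocks() as demo:
    n = gr.Number(value=3, label="Batch count")
    gallery = gr.JSON(label="Images (stand-in)")
    status = gr.Textbox(label="Status")
    gr.Button("Run").click(fake_generate, inputs=n, outputs=[gallery, status])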

View File

@@ -12,6 +12,18 @@ from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from cpuinfo import get_cpu_info
+ _IREE_DEVICE_MAP = {
+     "cpu": "local-task",
+     "cpu-task": "local-task",
+     "cpu-sync": "local-sync",
+     "cuda": "cuda",
+     "vulkan": "vulkan",
+     "metal": "metal",
+     "rocm": "rocm",
+     "hip": "hip",
+     "intel-gpu": "level_zero",
+ }
def iree_device_map(device):
uri_parts = device.split("://", 2)
iree_driver = (
@@ -31,25 +43,6 @@ def get_supported_device_list():
return list(_IREE_DEVICE_MAP.keys())
- _IREE_DEVICE_MAP = {
-     "cpu": "local-task",
-     "cpu-task": "local-task",
-     "cpu-sync": "local-sync",
-     "cuda": "cuda",
-     "vulkan": "vulkan",
-     "metal": "metal",
-     "rocm": "rocm",
-     "hip": "hip",
-     "intel-gpu": "level_zero",
- }
- def iree_target_map(device):
-     if "://" in device:
-         device = device.split("://")[0]
-     return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device
_IREE_TARGET_MAP = {
"cpu": "llvm-cpu",
"cpu-task": "llvm-cpu",
@@ -63,9 +56,13 @@ _IREE_TARGET_MAP = {
}
+ def iree_target_map(device):
+     if "://" in device:
+         device = device.split("://")[0]
+     return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device
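
A quick check of iree_target_map's semantics as defined above, inlining only the two map entries visible in this hunk:

_IREE_TARGET_MAP = {"cpu": "llvm-cpu", "cpu-task": "llvm-cpu"}  # excerpt

def iree_target_map(device):
    if "://" in device:
        device = device.split("://")[0]
    return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device

assert iree_target_map("cpu") == "llvm-cpu"
assert iree_target_map("cpu-task://0") == "llvm-cpu"  # "://0" suffix stripped first
assert iree_target_map("some-new-driver") == "some-new-driver"  # unmapped: passthrough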
def get_available_devices():
return ['rocm', 'cpu']
def get_devices_by_name(driver_name):
device_list = []
@@ -163,7 +160,6 @@ def clean_device_info(raw_device):
return device, device_id
def parse_device(device_str, target_override=""):
rt_driver, device_id = clean_device_info(device_str)
target_backend = iree_target_map(rt_driver)
if device_id:
@@ -174,19 +170,19 @@ def parse_device(device_str, target_override=""):
if target_override:
if "cpu" in device_str:
rt_device = "local-task"
- return target_backend, rt_device, target_override
+ return target_backend, target_override
match target_backend:
case "vulkan-spirv":
- triple = get_iree_target_triple(device_str)
- return target_backend, rt_device, triple
+ triple = None  # get_iree_target_triple(device_str)
+ return target_backend, triple
case "rocm":
triple = get_rocm_target_chip(device_str)
- return target_backend, rt_device, triple
+ return target_backend, triple
case "llvm-cpu":
if "Ryzen 9" in device_str:
return target_backend, "local-task", "znver4"
return target_backend, "znver4"
else:
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
return "llvm-cpu", "x86_64-linux-gnu"
def get_rocm_target_chip(device_str):
@@ -223,68 +219,4 @@ def get_all_devices(driver_name):
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
del driver
return device_list_src
- # def get_device_mapping(driver, key_combination=3):
- # """This method ensures consistent device ordering when choosing
- # specific devices for execution
- # Args:
- # driver (str): execution driver (vulkan, cuda, rocm, etc)
- # key_combination (int, optional): choice for mapping value for
- # device name.
- # 1 : path
- # 2 : name
- # 3 : (name, path)
- # Defaults to 3.
- # Returns:
- # dict: map to possible device names user can input mapped to desired
- # combination of name/path.
- # """
- # driver = iree_device_map(driver)
- # device_list = get_all_devices(driver)
- # device_map = dict()
- # def get_output_value(dev_dict):
- # if key_combination == 1:
- # return f"{driver}://{dev_dict['path']}"
- # if key_combination == 2:
- # return dev_dict["name"]
- # if key_combination == 3:
- # return dev_dict["name"], f"{driver}://{dev_dict['path']}"
- # # mapping driver name to default device (driver://0)
- # device_map[f"{driver}"] = get_output_value(device_list[0])
- # for i, device in enumerate(device_list):
- # # mapping with index
- # device_map[f"{driver}://{i}"] = get_output_value(device)
- # # mapping with full path
- # device_map[f"{driver}://{device['path']}"] = get_output_value(device)
- # return device_map
- # def get_opt_flags(model, precision="fp16"):
- # iree_flags = []
- # if len(cmd_opts.iree_vulkan_target_triple) > 0:
- # iree_flags.append(
- # f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
- # )
- # if "rocm" in cmd_opts.device:
- # from shark.iree_utils.gpu_utils import get_iree_rocm_args
- # rocm_args = get_iree_rocm_args()
- # iree_flags.extend(rocm_args)
- # if cmd_opts.iree_constant_folding == False:
- # iree_flags.append("--iree-opt-const-expr-hoisting=False")
- # iree_flags.append(
- # "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
- # )
- # if cmd_opts.data_tiling == False:
- # iree_flags.append("--iree-opt-data-tiling=False")
- # if "vae" not in model:
- # # Due to lack of support for multi-reduce, we always collapse reduction
- # # dims before dispatch formation right now.
- # iree_flags += ["--iree-flow-collapse-reduction-dims"]
- # return iree_flags

View File

@@ -89,8 +89,13 @@ def save_output_img(output_img, img_seed, extra_info=None):
f"VAE: {img_vae}, "
f"LoRA: {img_loras}",
)
- output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
+ if isinstance(output_img, list):
+     for i, img in enumerate(output_img):
+         out_img_path = Path(generated_imgs_path, f"{out_img_name}_{i}.png")
+         img.save(out_img_path, "PNG", pnginfo=pngInfo)
+ else:
+     out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
+     output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
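
The save path now handles both a single image and a list (as produced by batch sizes greater than one), suffixing each list entry with its index. A self-contained sketch of the same branching, with stand-ins for the metadata and naming built earlier in the real function:

from pathlib import Path
from PIL import Image
from PIL.PngImagePlugin import PngInfo

def save_imgs(output_img, out_dir, out_img_name):
    png_info = PngInfo()
    png_info.add_text("parameters", "prompt=..., seed=...")  # stand-in metadata
    if isinstance(output_img, list):
        for i, img in enumerate(output_img):
            img.save(Path(out_dir, f"{out_img_name}_{i}.png"), "PNG", pnginfo=png_info)
    else:
        output_img.save(Path(out_dir, f"{out_img_name}.png"), "PNG", pnginfo=png_info)

save_imgs([Image.new("RGB", (8, 8)) for _ in range(2)], Path("."), "demo")
# writes demo_0.png and demo_1.png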
if cmd_opts.output_img_format not in ["png", "jpg"]:
print(

View File

@@ -52,6 +52,7 @@ sd_default_models = [
# "stabilityai/stable-diffusion-2-1",
# "stabilityai/stable-diffusion-xl-base-1.0",
"stabilityai/sdxl-turbo",
"stabilityai/stable-diffusion-3-medium-diffusers",
]
sd_default_models.extend(get_checkpoints(model_type="scripts"))
@@ -315,7 +316,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
device = gr.Dropdown(
elem_id="device",
label="Device",
value=init_config["device"] if init_config["device"] else "rocm",
value=init_config["device"] if init_config["device"] else global_obj.get_device_list()[0],
choices=global_obj.get_device_list(),
allow_custom_value=True,
)
@@ -347,8 +348,8 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
value=512,
step=512,
label="\U00002195\U0000FE0F Height",
- interactive=False, # DEMO
- visible=False, # DEMO
+ interactive=True, # DEMO
+ visible=True, # DEMO
)
width = gr.Slider(
512,
@@ -356,8 +357,8 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
value=512,
step=512,
label="\U00002194\U0000FE0F Width",
interactive=False, # DEMO
visible=False, # DEMO
interactive=True, # DEMO
visible=True, # DEMO
)
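
Re-enabling the sliders makes resolution selectable again; with step=512 the only reachable values are multiples of 512 within the slider's range (the range endpoints are truncated out of this diff, so the 1024 below is an assumption). A minimal sketch of the two sliders wired to a live readout:

import gradio as gr

with gr.Blocks() as demo:
    height = gr.Slider(512, 1024, value=512, step=512,
                       label="\U00002195\U0000FE0F Height",
                       interactive=True, visible=True)
    width = gr.Slider(512, 1024, value=512, step=512,
                      label="\U00002194\U0000FE0F Width",
                      interactive=True, visible=True)
    size = gr.Textbox(label="Resolution")
    for s in (height, width):
        s.change(lambda h, w: f"{int(h)}x{int(w)}",
                 inputs=[height, width], outputs=size)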
with gr.Accordion(

View File

@@ -6,11 +6,13 @@ setuptools
wheel
- torch==2.3.0
+ torch>=2.3.0
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-unify-sd#subdirectory=models
- diffusers @ git+https://github.com/nod-ai/diffusers@v0.24.0-release
+ diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
peft
sentencepiece
# SHARK Runner
tqdm

View File

@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
python -m pip install --upgrade pip
pip install wheel
pip install --pre -r requirements.txt
- pip install https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_compiler-20240602.283-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_runtime-20240602.283-cp311-cp311-win_amd64.whl
+ pip install https://github.com/nod-ai/SRT/releases/download/candidate-20240617.289/iree_compiler-20240617.289-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240617.289/iree_runtime-20240617.289-cp311-cp311-win_amd64.whl
pip install -e .
Write-Host "Source your venv with ./shark.venv/Scripts/activate"