From 88db3457e27bfb01d3efd153e05da361cd1f0df2 Mon Sep 17 00:00:00 2001
From: Ean Garvey
Date: Mon, 17 Jun 2024 14:17:08 -0500
Subject: [PATCH] SD3 initial support via turbine pipeline

---
 apps/shark_studio/api/sd.py                 | 105 ++++++++++++++----
 apps/shark_studio/api/utils.py              | 116 ++++---------------
 apps/shark_studio/modules/img_processing.py |   9 +-
 apps/shark_studio/web/ui/sd.py              |  11 +-
 requirements.txt                            |   6 +-
 setup_venv.ps1                              |   2 +-
 6 files changed, 125 insertions(+), 124 deletions(-)

diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py
index 502b2905..8656cb39 100644
--- a/apps/shark_studio/api/sd.py
+++ b/apps/shark_studio/api/sd.py
@@ -39,17 +39,25 @@ EMPTY_SD_MAP = {
 
 EMPTY_SDXL_MAP = {
     "prompt_encoder": None,
-    "scheduled_unet": None,
+    "unet": None,
     "vae_decode": None,
-    "pipeline": None,
-    "full_pipeline": None,
+    "scheduler": None,
+}
+
+EMPTY_SD3_MAP = {
+    "clip": None,
+    "mmdit": None,
+    "vae": None,
+    "scheduler": None,
 }
 
 EMPTY_FLAGS = {
     "clip": None,
     "unet": None,
+    "mmdit": None,
     "vae": None,
     "pipeline": None,
+    "scheduler": None,
 }
 
 
@@ -86,21 +94,50 @@ class StableDiffusion:
         scheduler: str,
         precision: str,
         device: str,
+        clip_device: str = None,
+        vae_device: str = None,
         target_triple: str = None,
         custom_vae: str = None,
         num_loras: int = 0,
         import_ir: bool = True,
         is_controlled: bool = False,
         external_weights: str = "safetensors",
+        vae_precision: str = "fp16",
         progress=gr.Progress(),
     ):
         progress(0, desc="Initializing pipeline...")
         self.ui_device = device
+        backend, target = parse_device(device, target_triple)
+        if clip_device:
+            clip_device, clip_target = parse_device(clip_device)
+        else:
+            clip_device, clip_target = backend, target
+        if vae_device:
+            vae_device, vae_target = parse_device(vae_device)
+        else:
+            vae_device, vae_target = backend, target
+        devices = {
+            "clip": clip_device,
+            "mmdit": backend,
+            "vae": vae_device,
+        }
+        targets = {
+            "clip": clip_target,
+            "mmdit": target,
+            "vae": vae_target,
+        }
+        pipe_device_id = backend
+        target_triple = target
+        for key in devices:
+            if devices[key] != backend:
+                pipe_device_id = "hybrid"
+                target_triple = "_".join([clip_target, target, vae_target])
         self.precision = precision
         self.compiled_pipeline = False
         self.base_model_id = base_model_id
         self.custom_vae = custom_vae
         self.is_sdxl = "xl" in self.base_model_id.lower()
+        self.is_sd3 = "stable-diffusion-3" in self.base_model_id.lower()
         self.is_custom = ".py" in self.base_model_id.lower()
         if self.is_custom:
             custom_module = load_script(
@@ -115,23 +152,32 @@ class StableDiffusion:
                 SharkSDXLPipeline,
             )
             self.turbine_pipe = SharkSDXLPipeline
-            self.dynamic_steps = False
+            self.dynamic_steps = True
             self.model_map = EMPTY_SDXL_MAP
+        elif self.is_sd3:
+            from turbine_models.custom_models.sd3_inference.sd3_pipeline import SharkSD3Pipeline, empty_pipe_dict
+
+            self.turbine_pipe = SharkSD3Pipeline
+            self.dynamic_steps = True
+            self.model_map = EMPTY_SD3_MAP
         else:
             from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
 
             self.turbine_pipe = SharkSDPipeline
             self.dynamic_steps = True
             self.model_map = EMPTY_SD_MAP
+            # no multi-device support for SD1.x/2.x yet
+            devices = backend
+            targets = target
         max_length = 64
-        target_backend, self.rt_device, triple = parse_device(device, target_triple)
+
         pipe_id_list = [
             safe_name(base_model_id),
             str(batch_size),
             str(max_length),
             f"{str(height)}x{str(width)}",
             precision,
-            triple,
+            target_triple,
         ]
         if num_loras > 0:
             pipe_id_list.append(str(num_loras) + "lora")
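A rough standalone sketch of the model-map dispatch added above, for reference while reviewing. The helper name pick_pipeline_map and the inlined EMPTY_SD_MAP contents are illustrative assumptions, not code from this patch; only EMPTY_SDXL_MAP and EMPTY_SD3_MAP come from it. Each map lists the submodels a pipeline must export and compile, with values filled in later by prepare_pipe():

# Sketch only -- mirrors the is_sdxl / is_sd3 branching in StableDiffusion.__init__.
EMPTY_SD_MAP = {"clip": None, "unet": None, "vae_decode": None, "scheduler": None}  # assumed contents
EMPTY_SDXL_MAP = {"prompt_encoder": None, "unet": None, "vae_decode": None, "scheduler": None}
EMPTY_SD3_MAP = {"clip": None, "mmdit": None, "vae": None, "scheduler": None}

def pick_pipeline_map(base_model_id: str) -> dict:
    lowered = base_model_id.lower()
    if "xl" in lowered:
        return dict(EMPTY_SDXL_MAP)
    if "stable-diffusion-3" in lowered:
        return dict(EMPTY_SD3_MAP)  # SD3 swaps the UNet for an MMDiT
    return dict(EMPTY_SD_MAP)

assert "mmdit" in pick_pipeline_map("stabilityai/stable-diffusion-3-medium-diffusers")
assert "prompt_encoder" in pick_pipeline_map("stabilityai/sdxl-turbo")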
@@ -151,16 +197,16 @@ class StableDiffusion:
             decomp_attn = True
         attn_spec = None
-        if triple in ["gfx940", "gfx942", "gfx90a"]:
+        if target in ["gfx940", "gfx942", "gfx90a"]:
             decomp_attn = False
             attn_spec = "mfma"
-        elif triple in ["gfx1100", "gfx1103", "gfx1150"]:
+        elif target in ["gfx1100", "gfx1103", "gfx1150"]:
             decomp_attn = False
             attn_spec = "wmma"
-            if triple in ["gfx1103", "gfx1150"]:
+            if target in ["gfx1103", "gfx1150"]:
                 # external weights have issues on igpu
                 external_weights = None
-        elif target_backend == "llvm-cpu":
+        elif backend == "llvm-cpu":
             decomp_attn = False
         progress(0.5, desc="Initializing pipeline...")
         self.sd_pipe = self.turbine_pipe(
@@ -172,15 +218,15 @@ class StableDiffusion:
             max_length=max_length,
             batch_size=batch_size,
             num_inference_steps=steps,
-            device=target_backend,
-            iree_target_triple=triple,
+            device=devices,
+            iree_target_triple=targets,
             ireec_flags=EMPTY_FLAGS,
             attn_spec=attn_spec,
             decomp_attn=decomp_attn,
             pipeline_dir=self.pipeline_dir,
             external_weights_dir=self.weights_path,
             external_weights=external_weights,
-            custom_vae=custom_vae,
+            vae_precision=vae_precision,
         )
         progress(1, desc="Pipeline initialized!...")
         gc.collect()
@@ -191,15 +237,23 @@ class StableDiffusion:
         adapters,
         embeddings,
         is_img2img,
-        compiled_pipeline,
+        compiled_pipeline=False,
+        cpu_scheduling=False,
         progress=gr.Progress(),
     ):
         progress(0, desc="Preparing models...")
-
+        pipe_map = copy.deepcopy(self.model_map)
+        if compiled_pipeline and self.is_sdxl:
+            pipe_map.pop("scheduler", None)
+            pipe_map.pop("unet", None)
+            pipe_map["scheduled_unet"] = None
+            pipe_map["full_pipeline"] = None
+        if cpu_scheduling:
+            pipe_map.pop("scheduler", None)
         self.is_img2img = False
-        mlirs = copy.deepcopy(self.model_map)
-        vmfbs = copy.deepcopy(self.model_map)
-        weights = copy.deepcopy(self.model_map)
+        mlirs = copy.deepcopy(pipe_map)
+        vmfbs = copy.deepcopy(pipe_map)
+        weights = copy.deepcopy(pipe_map)
         if not self.is_sdxl:
             compiled_pipeline = False
         self.compiled_pipeline = compiled_pipeline
@@ -260,7 +314,7 @@ class StableDiffusion:
 
         progress(0.75, desc=f"Loading models and weights...")
         self.sd_pipe.load_pipeline(
-            vmfbs, weights, self.rt_device, self.compiled_pipeline
+            vmfbs, weights, self.compiled_pipeline
         )
         progress(1, desc="Pipeline loaded! Generating images...")
         return
@@ -324,7 +378,7 @@ def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
             )
             return None, ""
     if sd_kwargs["target_triple"] == "":
-        if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2]:
+        if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[1]:
            gr.Warning(
                 "Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
             )
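The prepare_pipe() change above keys export/compile bookkeeping off a pruned copy of the model map rather than the map itself. A minimal sketch of that pruning, lifted out as a free function for clarity (the function name is illustrative, not from this patch):

import copy

def build_pipe_map(model_map: dict, is_sdxl: bool, compiled_pipeline: bool, cpu_scheduling: bool) -> dict:
    pipe_map = copy.deepcopy(model_map)
    if compiled_pipeline and is_sdxl:
        # A fully compiled SDXL pipeline fuses scheduler + unet, so it tracks
        # "scheduled_unet" and "full_pipeline" artifacts instead.
        pipe_map.pop("scheduler", None)
        pipe_map.pop("unet", None)
        pipe_map["scheduled_unet"] = None
        pipe_map["full_pipeline"] = None
    if cpu_scheduling:
        # Host-side scheduling needs no compiled scheduler module.
        pipe_map.pop("scheduler", None)
    return pipe_map

The pop(..., None) default keeps the two branches order-independent when both flags are set.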
@@ -359,6 +413,9 @@ def shark_sd_fn(
     controlnets: dict,
     embeddings: dict,
     seed_increment: str | int = 1,
+    clip_device: str = None,
+    vae_device: str = None,
+    vae_precision: str = None,
     progress=gr.Progress(),
 ):
     sd_kwargs = locals()
@@ -398,7 +455,8 @@ def shark_sd_fn(
         control_mode = controlnets["control_mode"]
         for i in controlnets["hint"]:
             hints.append(i)
-
+    if not vae_precision:
+        vae_precision = precision
     submit_pipe_kwargs = {
         "base_model_id": base_model_id,
         "height": height,
@@ -406,6 +464,8 @@ def shark_sd_fn(
         "batch_size": batch_size,
         "precision": precision,
         "device": device,
+        "clip_device": clip_device,
+        "vae_device": vae_device,
         "target_triple": target_triple,
         "custom_vae": custom_vae,
         "num_loras": num_loras,
@@ -413,6 +473,7 @@ def shark_sd_fn(
         "is_controlled": is_controlled,
         "steps": steps,
         "scheduler": scheduler,
+        "vae_precision": vae_precision,
     }
     submit_prep_kwargs = {
         "custom_weights": custom_weights,
diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py
index 85a59ada..7ad398e6 100644
--- a/apps/shark_studio/api/utils.py
+++ b/apps/shark_studio/api/utils.py
@@ -12,6 +12,18 @@ from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
 from cpuinfo import get_cpu_info
 
 
+_IREE_DEVICE_MAP = {
+    "cpu": "local-task",
+    "cpu-task": "local-task",
+    "cpu-sync": "local-sync",
+    "cuda": "cuda",
+    "vulkan": "vulkan",
+    "metal": "metal",
+    "rocm": "rocm",
+    "hip": "hip",
+    "intel-gpu": "level_zero",
+}
+
 def iree_device_map(device):
     uri_parts = device.split("://", 2)
     iree_driver = (
@@ -31,25 +43,6 @@ def get_supported_device_list():
     return list(_IREE_DEVICE_MAP.keys())
 
 
-_IREE_DEVICE_MAP = {
-    "cpu": "local-task",
-    "cpu-task": "local-task",
-    "cpu-sync": "local-sync",
-    "cuda": "cuda",
-    "vulkan": "vulkan",
-    "metal": "metal",
-    "rocm": "rocm",
-    "hip": "hip",
-    "intel-gpu": "level_zero",
-}
-
-
-def iree_target_map(device):
-    if "://" in device:
-        device = device.split("://")[0]
-    return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device
-
-
 _IREE_TARGET_MAP = {
     "cpu": "llvm-cpu",
     "cpu-task": "llvm-cpu",
@@ -63,9 +56,13 @@ _IREE_TARGET_MAP = {
 }
 
+def iree_target_map(device):
+    if "://" in device:
+        device = device.split("://")[0]
+    return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device
+
 def get_available_devices():
-
     return ['rocm', 'cpu']
 
 
 def get_devices_by_name(driver_name):
     device_list = []
@@ -163,7 +160,6 @@ def clean_device_info(raw_device):
     return device, device_id
 
 def parse_device(device_str, target_override=""):
-
     rt_driver, device_id = clean_device_info(device_str)
     target_backend = iree_target_map(rt_driver)
     if device_id:
@@ -174,19 +170,19 @@ def parse_device(device_str, target_override=""):
     if target_override:
         if "cpu" in device_str:
             rt_device = "local-task"
-        return target_backend, rt_device, target_override
+        return target_backend, target_override
     match target_backend:
         case "vulkan-spirv":
-            triple = get_iree_target_triple(device_str)
-            return target_backend, rt_device, triple
+            triple = None  # get_iree_target_triple(device_str)
+            return target_backend, triple
         case "rocm":
             triple = get_rocm_target_chip(device_str)
-            return target_backend, rt_device, triple
+            return target_backend, triple
         case "llvm-cpu":
             if "Ryzen 9" in device_str:
-                return target_backend, "local-task", "znver4"
+                return target_backend, "znver4"
             else:
-                return "llvm-cpu", "local-task", "x86_64-linux-gnu"
+                return "llvm-cpu", "x86_64-linux-gnu"
 
 
 def get_rocm_target_chip(device_str):
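With the change above, parse_device() returns a two-tuple (backend, triple) instead of the old three-tuple (backend, rt_device, triple); the IREE runtime device string is now resolved inside the turbine pipeline. A usage sketch (results depend on the host; the gfx1100 value is only an example of what get_rocm_target_chip() might infer):

from apps.shark_studio.api.utils import parse_device

backend, triple = parse_device("cpu", "znver4")
# -> ("llvm-cpu", "znver4"): an explicit target override is passed straight through.

backend, triple = parse_device("rocm://0")
# -> ("rocm", "gfx1100") on a Radeon 7900xtx, for example. triple may be None
#    when the chip cannot be inferred, which is exactly what the new [1] check
#    in shark_sd_fn_dict_input guards against.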
"znver4" else: - return "llvm-cpu", "local-task", "x86_64-linux-gnu" + return "llvm-cpu", "x86_64-linux-gnu" def get_rocm_target_chip(device_str): @@ -223,68 +219,4 @@ def get_all_devices(driver_name): device_list_src = driver.query_available_devices() device_list_src.sort(key=lambda d: d["path"]) del driver - return device_list_src - - -# def get_device_mapping(driver, key_combination=3): -# """This method ensures consistent device ordering when choosing -# specific devices for execution -# Args: -# driver (str): execution driver (vulkan, cuda, rocm, etc) -# key_combination (int, optional): choice for mapping value for -# device name. -# 1 : path -# 2 : name -# 3 : (name, path) -# Defaults to 3. -# Returns: -# dict: map to possible device names user can input mapped to desired -# combination of name/path. -# """ - -# driver = iree_device_map(driver) -# device_list = get_all_devices(driver) -# device_map = dict() - -# def get_output_value(dev_dict): -# if key_combination == 1: -# return f"{driver}://{dev_dict['path']}" -# if key_combination == 2: -# return dev_dict["name"] -# if key_combination == 3: -# return dev_dict["name"], f"{driver}://{dev_dict['path']}" - -# # mapping driver name to default device (driver://0) -# device_map[f"{driver}"] = get_output_value(device_list[0]) -# for i, device in enumerate(device_list): -# # mapping with index -# device_map[f"{driver}://{i}"] = get_output_value(device) -# # mapping with full path -# device_map[f"{driver}://{device['path']}"] = get_output_value(device) -# return device_map - - -# def get_opt_flags(model, precision="fp16"): -# iree_flags = [] -# if len(cmd_opts.iree_vulkan_target_triple) > 0: -# iree_flags.append( -# f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}" -# ) -# if "rocm" in cmd_opts.device: -# from shark.iree_utils.gpu_utils import get_iree_rocm_args - -# rocm_args = get_iree_rocm_args() -# iree_flags.extend(rocm_args) -# if cmd_opts.iree_constant_folding == False: -# iree_flags.append("--iree-opt-const-expr-hoisting=False") -# iree_flags.append( -# "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807" -# ) -# if cmd_opts.data_tiling == False: -# iree_flags.append("--iree-opt-data-tiling=False") - -# if "vae" not in model: -# # Due to lack of support for multi-reduce, we always collapse reduction -# # dims before dispatch formation right now. 
diff --git a/apps/shark_studio/modules/img_processing.py b/apps/shark_studio/modules/img_processing.py
index 04676520..eebf47c2 100644
--- a/apps/shark_studio/modules/img_processing.py
+++ b/apps/shark_studio/modules/img_processing.py
@@ -89,8 +89,13 @@ def save_output_img(output_img, img_seed, extra_info=None):
             f"VAE: {img_vae}, "
             f"LoRA: {img_loras}",
         )
-
-        output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
+        if isinstance(output_img, list):
+            for i, img in enumerate(output_img):
+                out_img_path = Path(generated_imgs_path, f"{out_img_name}_{i}.png")
+                img.save(out_img_path, "PNG", pnginfo=pngInfo)
+        else:
+            out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
+            output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
 
     if cmd_opts.output_img_format not in ["png", "jpg"]:
         print(
diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py
index 60fd487d..bb56d466 100644
--- a/apps/shark_studio/web/ui/sd.py
+++ b/apps/shark_studio/web/ui/sd.py
@@ -52,6 +52,7 @@ sd_default_models = [
     # "stabilityai/stable-diffusion-2-1",
    # "stabilityai/stable-diffusion-xl-base-1.0",
     "stabilityai/sdxl-turbo",
+    "stabilityai/stable-diffusion-3-medium-diffusers",
 ]
 sd_default_models.extend(get_checkpoints(model_type="scripts"))
 
@@ -315,7 +316,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                     device = gr.Dropdown(
                         elem_id="device",
                         label="Device",
-                        value=init_config["device"] if init_config["device"] else "rocm",
+                        value=init_config["device"] if init_config["device"] else global_obj.get_device_list()[0],
                         choices=global_obj.get_device_list(),
                         allow_custom_value=True,
                     )
@@ -347,8 +348,8 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                         value=512,
                         step=512,
                         label="\U00002195\U0000FE0F Height",
-                        interactive=False,  # DEMO
-                        visible=False,  # DEMO
+                        interactive=True,
+                        visible=True,
                     )
                     width = gr.Slider(
                         512,
@@ -356,8 +357,8 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
                         value=512,
                         step=512,
                         label="\U00002194\U0000FE0F Width",
-                        interactive=False,  # DEMO
-                        visible=False,  # DEMO
+                        interactive=True,
+                        visible=True,
                     )
 
                     with gr.Accordion(
diff --git a/requirements.txt b/requirements.txt
index 407263ff..16d6712a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,11 +6,13 @@
 setuptools
 wheel
 
-torch==2.3.0
+torch>=2.3.0
 shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
 turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-unify-sd#subdirectory=models
-diffusers @ git+https://github.com/nod-ai/diffusers@v0.24.0-release
+diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark
 brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
+peft
+sentencepiece
 
 # SHARK Runner
 tqdm
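On the two new requirements: sentencepiece backs the T5 tokenizer used by SD3's third text encoder, and peft is assumed to be wanted by the pinned diffusers branch (that rationale is an inference, not stated in this patch). An optional environment check:

import importlib.util

# Fail fast if either SD3 dependency is missing from the venv.
for mod in ("sentencepiece", "peft"):
    assert importlib.util.find_spec(mod) is not None, f"missing dependency: {mod}"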
diff --git a/setup_venv.ps1 b/setup_venv.ps1
index 651b1942..2ef307a4 100644
--- a/setup_venv.ps1
+++ b/setup_venv.ps1
@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
 python -m pip install --upgrade pip
 pip install wheel
 pip install --pre -r requirements.txt
-pip install https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_compiler-20240602.283-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_runtime-20240602.283-cp311-cp311-win_amd64.whl
+pip install https://github.com/nod-ai/SRT/releases/download/candidate-20240617.289/iree_compiler-20240617.289-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240617.289/iree_runtime-20240617.289-cp311-cp311-win_amd64.whl
 pip install -e .
 Write-Host "Source your venv with ./shark.venv/Scripts/activate"
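After running the updated setup_venv.ps1, a small optional smoke test (illustrative, not part of the patch) confirms the pinned SRT wheels resolved; the iree_compiler and iree_runtime wheels install as the iree.compiler and iree.runtime packages:

# Run inside the activated shark.venv.
import iree.compiler
import iree.runtime

print("IREE compiler and runtime wheels import cleanly.")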