mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-12 07:18:27 -05:00
Compare commits
2 Commits
diffusers-
...
20240529.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7bd54de23 | ||
|
|
1adfb3b24b |
@@ -73,10 +73,12 @@ class StableDiffusion:
|
||||
scheduler: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
target_triple: str = None,
|
||||
custom_vae: str = None,
|
||||
num_loras: int = 0,
|
||||
import_ir: bool = True,
|
||||
is_controlled: bool = False,
|
||||
external_weights: str = "safetensors",
|
||||
):
|
||||
self.precision = precision
|
||||
self.compiled_pipeline = False
|
||||
@@ -89,9 +91,8 @@ class StableDiffusion:
|
||||
else:
|
||||
self.turbine_pipe = SharkSDPipeline
|
||||
self.model_map = EMPTY_SD_MAP
|
||||
external_weights = "safetensors"
|
||||
max_length = 64
|
||||
target_backend, self.rt_device, triple = parse_device(device)
|
||||
target_backend, self.rt_device, triple = parse_device(device, target_triple)
|
||||
pipe_id_list = [
|
||||
safe_name(base_model_id),
|
||||
str(batch_size),
|
||||
@@ -121,9 +122,12 @@ class StableDiffusion:
|
||||
if triple in ["gfx940", "gfx942", "gfx90a"]:
|
||||
decomp_attn = False
|
||||
attn_spec = "mfma"
|
||||
elif triple in ["gfx1100", "gfx1103"]:
|
||||
elif triple in ["gfx1100", "gfx1103", "gfx1150"]:
|
||||
decomp_attn = False
|
||||
attn_spec = "wmma"
|
||||
if triple in ["gfx1103", "gfx1150"]:
|
||||
# external weights have issues on igpu
|
||||
external_weights = None
|
||||
elif target_backend == "llvm-cpu":
|
||||
decomp_attn = False
|
||||
|
||||
@@ -273,6 +277,7 @@ def shark_sd_fn(
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
target_triple: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
resample_type: str,
|
||||
@@ -284,8 +289,6 @@ def shark_sd_fn(
|
||||
sd_init_image = [sd_init_image]
|
||||
is_img2img = True if sd_init_image[0] is not None else False
|
||||
|
||||
print("\n[LOG] Performing Stable Diffusion Pipeline setup...")
|
||||
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
|
||||
@@ -326,6 +329,7 @@ def shark_sd_fn(
|
||||
"batch_size": batch_size,
|
||||
"precision": precision,
|
||||
"device": device,
|
||||
"target_triple": target_triple,
|
||||
"custom_vae": custom_vae,
|
||||
"num_loras": num_loras,
|
||||
"import_ir": import_ir,
|
||||
|
||||
@@ -52,6 +52,13 @@ def get_available_devices():
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
rocm_devices = get_devices_by_name("rocm")
|
||||
available_devices.extend(rocm_devices)
|
||||
cpu_device = get_devices_by_name("cpu-sync")
|
||||
available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
available_devices.extend(cpu_device)
|
||||
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
)
|
||||
@@ -64,19 +71,27 @@ def get_available_devices():
|
||||
id += 1
|
||||
if id != 0:
|
||||
print(f"vulkan devices are available.")
|
||||
|
||||
available_devices.extend(vulkan_devices)
|
||||
metal_devices = get_devices_by_name("metal")
|
||||
available_devices.extend(metal_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
rocm_devices = get_devices_by_name("rocm")
|
||||
available_devices.extend(rocm_devices)
|
||||
hip_devices = get_devices_by_name("hip")
|
||||
available_devices.extend(hip_devices)
|
||||
cpu_device = get_devices_by_name("cpu-sync")
|
||||
available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
available_devices.extend(cpu_device)
|
||||
for idx, device_str in enumerate(available_devices):
|
||||
if "AMD Radeon(TM) Graphics =>" in device_str:
|
||||
igpu_id_candidates = [
|
||||
x.split("w/")[-1].split("=>")[0]
|
||||
for x in available_devices
|
||||
if "M Graphics" in x
|
||||
]
|
||||
for igpu_name in igpu_id_candidates:
|
||||
if igpu_name:
|
||||
available_devices[idx] = device_str.replace(
|
||||
"AMD Radeon(TM) Graphics", igpu_name
|
||||
)
|
||||
break
|
||||
return available_devices
|
||||
|
||||
|
||||
@@ -129,7 +144,7 @@ def set_iree_runtime_flags():
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
|
||||
def parse_device(device_str):
|
||||
def parse_device(device_str, target_override=""):
|
||||
from shark.iree_utils.compile_utils import (
|
||||
clean_device_info,
|
||||
get_iree_target_triple,
|
||||
@@ -143,6 +158,8 @@ def parse_device(device_str):
|
||||
else:
|
||||
rt_device = rt_driver
|
||||
|
||||
if target_override:
|
||||
return target_backend, rt_device, target_override
|
||||
match target_backend:
|
||||
case "vulkan-spirv":
|
||||
triple = get_iree_target_triple(device_str)
|
||||
@@ -168,6 +185,7 @@ def get_rocm_target_chip(device_str):
|
||||
"MI100": "gfx908",
|
||||
"MI50": "gfx906",
|
||||
"MI60": "gfx906",
|
||||
"780M": "gfx1103",
|
||||
}
|
||||
for key in rocm_chip_map:
|
||||
if key in device_str:
|
||||
|
||||
@@ -24,47 +24,47 @@ def get_schedulers(model_id):
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DDPM"] = DDPMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", algorithm_type="dpmsolver"
|
||||
)
|
||||
schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
|
||||
)
|
||||
schedulers["DPMSolverMultistepKarras"] = (
|
||||
DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
use_karras_sigmas=True,
|
||||
)
|
||||
)
|
||||
schedulers["DPMSolverMultistepKarras++"] = (
|
||||
DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
algorithm_type="dpmsolver++",
|
||||
use_karras_sigmas=True,
|
||||
)
|
||||
)
|
||||
# schedulers["DDPM"] = DDPMScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# schedulers["DDIM"] = DDIMScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
# model_id, subfolder="scheduler", algorithm_type="dpmsolver"
|
||||
# )
|
||||
# schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
# model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
|
||||
# )
|
||||
# schedulers["DPMSolverMultistepKarras"] = (
|
||||
# DPMSolverMultistepScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# use_karras_sigmas=True,
|
||||
# )
|
||||
# )
|
||||
# schedulers["DPMSolverMultistepKarras++"] = (
|
||||
# DPMSolverMultistepScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# algorithm_type="dpmsolver++",
|
||||
# use_karras_sigmas=True,
|
||||
# )
|
||||
# )
|
||||
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
@@ -75,24 +75,24 @@ def get_schedulers(model_id):
|
||||
subfolder="scheduler",
|
||||
)
|
||||
)
|
||||
schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DPMSolverSinglestep"] = DPMSolverSinglestepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["KDPM2AncestralDiscrete"] = (
|
||||
KDPM2AncestralDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
)
|
||||
schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
# schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# schedulers["DPMSolverSinglestep"] = DPMSolverSinglestepScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# schedulers["KDPM2AncestralDiscrete"] = (
|
||||
# KDPM2AncestralDiscreteScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
# )
|
||||
# schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
|
||||
# model_id,
|
||||
# subfolder="scheduler",
|
||||
# )
|
||||
return schedulers
|
||||
|
||||
|
||||
@@ -102,17 +102,17 @@ def export_scheduler_model(model):
|
||||
|
||||
scheduler_model_map = {
|
||||
"PNDM": export_scheduler_model("PNDMScheduler"),
|
||||
"DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
|
||||
# "DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
|
||||
"EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
|
||||
"EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),
|
||||
"LCM": export_scheduler_model("LCMScheduler"),
|
||||
"LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
|
||||
"DDPM": export_scheduler_model("DDPMScheduler"),
|
||||
"DDIM": export_scheduler_model("DDIMScheduler"),
|
||||
"DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
|
||||
"KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
|
||||
"DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
|
||||
"DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
|
||||
"KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
|
||||
"HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
|
||||
# "LCM": export_scheduler_model("LCMScheduler"),
|
||||
# "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
|
||||
# "DDPM": export_scheduler_model("DDPMScheduler"),
|
||||
# "DDIM": export_scheduler_model("DDIMScheduler"),
|
||||
# "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
|
||||
# "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
|
||||
# "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
|
||||
# "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
|
||||
# "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
|
||||
# "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
|
||||
}
|
||||
|
||||
@@ -117,6 +117,7 @@ def pull_sd_configs(
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
@@ -175,6 +176,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
|
||||
sd_json["custom_vae"],
|
||||
sd_json["precision"],
|
||||
sd_json["device"],
|
||||
sd_json["target_triple"],
|
||||
sd_json["ondemand"],
|
||||
sd_json["repeatable_seeds"],
|
||||
sd_json["resample_type"],
|
||||
@@ -253,6 +255,11 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
choices=global_obj.get_device_list(),
|
||||
allow_custom_value=False,
|
||||
)
|
||||
target_triple = gr.Textbox(
|
||||
elem_id="triple",
|
||||
label="Architecture",
|
||||
value="",
|
||||
)
|
||||
with gr.Row():
|
||||
ondemand = gr.Checkbox(
|
||||
value=cmd_opts.lowvram,
|
||||
@@ -691,6 +698,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
@@ -730,6 +738,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
|
||||
Reference in New Issue
Block a user