mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-08 21:38:04 -05:00
Add igpu and custom triple support. (#2148)
This commit is contained in:
@@ -73,6 +73,7 @@ class StableDiffusion:
|
||||
scheduler: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
target_triple: str = None,
|
||||
custom_vae: str = None,
|
||||
num_loras: int = 0,
|
||||
import_ir: bool = True,
|
||||
@@ -91,7 +92,7 @@ class StableDiffusion:
|
||||
self.model_map = EMPTY_SD_MAP
|
||||
external_weights = "safetensors"
|
||||
max_length = 64
|
||||
target_backend, self.rt_device, triple = parse_device(device)
|
||||
target_backend, self.rt_device, triple = parse_device(device, target_triple)
|
||||
pipe_id_list = [
|
||||
safe_name(base_model_id),
|
||||
str(batch_size),
|
||||
@@ -273,6 +274,7 @@ def shark_sd_fn(
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
target_triple: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
resample_type: str,
|
||||
@@ -326,6 +328,7 @@ def shark_sd_fn(
|
||||
"batch_size": batch_size,
|
||||
"precision": precision,
|
||||
"device": device,
|
||||
"target_triple": target_triple,
|
||||
"custom_vae": custom_vae,
|
||||
"num_loras": num_loras,
|
||||
"import_ir": import_ir,
|
||||
|
||||
@@ -77,6 +77,21 @@ def get_available_devices():
|
||||
available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
available_devices.extend(cpu_device)
|
||||
print(available_devices)
|
||||
for idx, device_str in enumerate(available_devices):
|
||||
if "AMD Radeon(TM) Graphics =>" in device_str:
|
||||
igpu_id_candidates = [
|
||||
x.split("w/")[-1].split("=>")[0]
|
||||
for x in available_devices
|
||||
if "M Graphics" in x
|
||||
]
|
||||
for igpu_name in igpu_id_candidates:
|
||||
if igpu_name:
|
||||
print(f"Found iGPU: {igpu_name} for {device_str}")
|
||||
available_devices[idx] = device_str.replace(
|
||||
"AMD Radeon(TM) Graphics", f"AMD iGPU: {igpu_name}"
|
||||
)
|
||||
break
|
||||
return available_devices
|
||||
|
||||
|
||||
@@ -129,7 +144,7 @@ def set_iree_runtime_flags():
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
|
||||
def parse_device(device_str):
|
||||
def parse_device(device_str, target_override=""):
|
||||
from shark.iree_utils.compile_utils import (
|
||||
clean_device_info,
|
||||
get_iree_target_triple,
|
||||
@@ -143,6 +158,8 @@ def parse_device(device_str):
|
||||
else:
|
||||
rt_device = rt_driver
|
||||
|
||||
if target_override:
|
||||
return target_backend, rt_device, target_override
|
||||
match target_backend:
|
||||
case "vulkan-spirv":
|
||||
triple = get_iree_target_triple(device_str)
|
||||
@@ -168,6 +185,7 @@ def get_rocm_target_chip(device_str):
|
||||
"MI100": "gfx908",
|
||||
"MI50": "gfx906",
|
||||
"MI60": "gfx906",
|
||||
"780M": "gfx1103",
|
||||
}
|
||||
for key in rocm_chip_map:
|
||||
if key in device_str:
|
||||
|
||||
@@ -117,6 +117,7 @@ def pull_sd_configs(
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
@@ -175,6 +176,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
|
||||
sd_json["custom_vae"],
|
||||
sd_json["precision"],
|
||||
sd_json["device"],
|
||||
sd_json["target_triple"],
|
||||
sd_json["ondemand"],
|
||||
sd_json["repeatable_seeds"],
|
||||
sd_json["resample_type"],
|
||||
@@ -253,6 +255,11 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
choices=global_obj.get_device_list(),
|
||||
allow_custom_value=False,
|
||||
)
|
||||
target_triple = gr.Textbox(
|
||||
elem_id="triple",
|
||||
label="Architecture",
|
||||
value="",
|
||||
)
|
||||
with gr.Row():
|
||||
ondemand = gr.Checkbox(
|
||||
value=cmd_opts.lowvram,
|
||||
@@ -691,6 +698,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
@@ -730,6 +738,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
|
||||
Reference in New Issue
Block a user