mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Revert "[SD][WEB] Deduce vulkan-target-triple in the presence of multiple cards"
This reverts commit 35e623deaf.
This commit is contained in:
@@ -90,8 +90,9 @@ def set_iree_runtime_flags():
|
||||
|
||||
|
||||
def make_qualified_device_name():
|
||||
# modify device name to be fully qualified device name of the format driver://path.
|
||||
# supported for vulkan as of now.
|
||||
# modify device name to be fully qualified device name
|
||||
# of the format driver://path
|
||||
# supported for vulkan as of now
|
||||
|
||||
if "vulkan" in args.device:
|
||||
args.device = map_device_to_path(args.device)
|
||||
|
||||
@@ -7,10 +7,7 @@ from diffusers import (
|
||||
EulerDiscreteScheduler,
|
||||
)
|
||||
from models.stable_diffusion.opt_params import get_unet, get_vae, get_clip
|
||||
from models.stable_diffusion.utils import (
|
||||
set_iree_runtime_flags,
|
||||
make_qualified_device_name,
|
||||
)
|
||||
from models.stable_diffusion.utils import set_iree_runtime_flags
|
||||
from models.stable_diffusion.stable_args import args
|
||||
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
|
||||
|
||||
@@ -43,13 +40,6 @@ schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
# update device name to fully-qualified device name.
|
||||
make_qualified_device_name()
|
||||
|
||||
# use tuned version of unet in case of rdna3 cards.
|
||||
if "rdna3" in get_vulkan_triple_flag(args.device):
|
||||
args.use_tuned = True
|
||||
|
||||
# set iree-runtime flags
|
||||
set_iree_runtime_flags()
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from shark.shark_inference import SharkInference
|
||||
from models.stable_diffusion.stable_args import args
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
|
||||
from shark.iree_utils._common import map_device_to_path
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
@@ -87,11 +86,3 @@ def set_iree_runtime_flags():
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def make_qualified_device_name():
|
||||
# modify device name to be fully qualified device name of the format driver://path.
|
||||
# supported for vulkan as of now.
|
||||
|
||||
if "vulkan" in args.device:
|
||||
args.device = map_device_to_path(args.device)
|
||||
|
||||
Reference in New Issue
Block a user