Revert "[SD][WEB] Deduce vulkan-target-triple in the presence of multiple cards"

This reverts commit 35e623deaf.
This commit is contained in:
Gaurav Shukla
2022-12-17 04:26:43 +05:30
parent 35e623deaf
commit 72648aa9f2
3 changed files with 4 additions and 22 deletions

View File

@@ -90,8 +90,9 @@ def set_iree_runtime_flags():
def make_qualified_device_name():
# modify device name to be fully qualified device name of the format driver://path.
# supported for vulkan as of now.
# modify device name to be fully qualified device name
# of the format driver://path
# supported for vulkan as of now
if "vulkan" in args.device:
args.device = map_device_to_path(args.device)

View File

@@ -7,10 +7,7 @@ from diffusers import (
EulerDiscreteScheduler,
)
from models.stable_diffusion.opt_params import get_unet, get_vae, get_clip
from models.stable_diffusion.utils import (
set_iree_runtime_flags,
make_qualified_device_name,
)
from models.stable_diffusion.utils import set_iree_runtime_flags
from models.stable_diffusion.stable_args import args
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
@@ -43,13 +40,6 @@ schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
subfolder="scheduler",
)
# update device name to fully-qualified device name.
make_qualified_device_name()
# use tuned version of unet in case of rdna3 cards.
if "rdna3" in get_vulkan_triple_flag(args.device):
args.use_tuned = True
# set iree-runtime flags
set_iree_runtime_flags()

View File

@@ -5,7 +5,6 @@ from shark.shark_inference import SharkInference
from models.stable_diffusion.stable_args import args
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
from shark.iree_utils._common import map_device_to_path
def _compile_module(shark_module, model_name, extra_args=[]):
@@ -87,11 +86,3 @@ def set_iree_runtime_flags():
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
return
def make_qualified_device_name():
# modify device name to be fully qualified device name of the format driver://path.
# supported for vulkan as of now.
if "vulkan" in args.device:
args.device = map_device_to_path(args.device)