Use clean_device_info() by default and don't write .mlir to /tmp/ (#1984)

* Move clean_device_info to compile_utils

* Update compile_utils.py

* Fix .mlir writes for some user-level permissions

* Fix cases where full URI is given

* Fix conditionals.

* Fix device path handling in vulkan utils.
This commit is contained in:
Ean Garvey
2023-11-20 13:10:31 -06:00
committed by GitHub
parent 1b11c82c9d
commit d051c3a4a7
5 changed files with 32 additions and 47 deletions

View File

@@ -118,7 +118,7 @@ def compile_through_fx(
is_f16=False,
f16_input_mask=None,
use_tuned=False,
save_dir=tempfile.gettempdir(),
save_dir="",
debug=False,
generate_vmfb=True,
extra_args=None,

View File

@@ -6,6 +6,7 @@ from transformers import (
AutoModelForCausalLM,
)
from apps.stable_diffusion.web.ui.utils import available_devices
from shark.iree_utils.compile_utils import clean_device_info
from datetime import datetime as dt
import json
import sys
@@ -132,27 +133,6 @@ def get_default_config():
c.split_into_layers()
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by LLM pipeline
# Multiple devices only supported for vulkan and rocm (as of now).
# default device must be selected for all others
device_id = None
device = (
raw_device
if "=>" not in raw_device
else raw_device.split("=>")[1].strip()
)
if "://" in device:
device, device_id = device.split("://")
device_id = int(device_id) # using device index in webui
if device not in ["rocm", "vulkan"]:
device_id = None
return device, device_id
model_vmfb_key = ""

View File

@@ -31,24 +31,7 @@ from .benchmark_utils import *
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
print("Configuring for device:" + device)
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan", "rocm"]:
print(
f"Specific device selection only supported for vulkan and rocm."
f"Proceeding with {device} as device."
)
# device_uri can be device_num or device_path.
# assuming number of devices for a single driver will be not be >99
if len(device_uri[1]) <= 2:
# expected to be device index in range 0 - 99
device_num = int(device_uri[1])
else:
# expected to be device path
device_num = device_uri[1]
else:
device_num = 0
device, device_num = clean_device_info(device)
if "cpu" in device:
from shark.iree_utils.cpu_utils import get_iree_cpu_args
@@ -64,27 +47,49 @@ def get_iree_device_args(device, extra_args=[]):
+ stack_size_flag
+ ["--iree-global-opt-enable-quantized-matmul-reassociation"]
)
if device_uri[0] == "cuda":
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
return get_iree_gpu_args()
if device_uri[0] == "vulkan":
if device == "vulkan":
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "metal":
if device == "metal":
from shark.iree_utils.metal_utils import get_iree_metal_args
return get_iree_metal_args(extra_args=extra_args)
if device_uri[0] == "rocm":
if device == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
return []
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by Studio pipeline
# Multiple devices only supported for vulkan and rocm (as of now).
# default device must be selected for all others
device_id = None
device = (
raw_device
if "=>" not in raw_device
else raw_device.split("=>")[1].strip()
)
if "://" in device:
device, device_id = device.split("://")
if len(device_id) <= 2:
device_id = int(device_id)
if device not in ["rocm", "vulkan"]:
device_id = None
return device, device_id
# Get the iree-compiler arguments given frontend.
def get_iree_frontend_args(frontend):
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:

View File

@@ -45,7 +45,7 @@ def get_vulkan_device_name(device_num=0):
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing device: {vulkaninfo_list[device_num]}")
print(f"Choosing device: vulkan://{device_num}")
return vulkaninfo_list[device_num]

View File

@@ -800,13 +800,13 @@ def save_mlir(
model_name,
mlir_dialect="linalg",
frontend="torch",
dir=tempfile.gettempdir(),
dir="",
):
model_name_mlir = (
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
)
if dir == "":
dir = tempfile.gettempdir()
dir = os.path.join(".", "shark_tmp")
mlir_path = os.path.join(dir, model_name_mlir)
print(f"saving {model_name_mlir} to {dir}")
if frontend == "torch":