Compare commits

...

6 Commits

Author SHA1 Message Date
Ean Garvey
8f988517d4 Fix conditionals. 2023-11-17 20:09:09 +00:00
Ean Garvey
d78e86bd48 Fix cases where full URI is given 2023-11-17 19:44:49 +00:00
Ean Garvey
82128672fd Fix formatting. 2023-11-17 19:01:37 +00:00
Ean Garvey
2c5b79b09f Fix .mlir writes for some user-level permissions 2023-11-17 19:00:35 +00:00
Ean Garvey
2d1b6aa35f Update compile_utils.py 2023-11-17 12:43:37 -06:00
Ean Garvey
56d0e59c0f Move clean_device_info to compile_utils 2023-11-17 12:37:49 -06:00
4 changed files with 30 additions and 46 deletions

View File

@@ -118,7 +118,7 @@ def compile_through_fx(
is_f16=False,
f16_input_mask=None,
use_tuned=False,
save_dir=tempfile.gettempdir(),
save_dir="",
debug=False,
generate_vmfb=True,
extra_args=None,

View File

@@ -6,6 +6,7 @@ from transformers import (
AutoModelForCausalLM,
)
from apps.stable_diffusion.web.ui.utils import available_devices
from shark.iree_utils.compile_utils import clean_device_info
from datetime import datetime as dt
import json
import sys
@@ -132,27 +133,6 @@ def get_default_config():
c.split_into_layers()
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by LLM pipeline
# Multiple devices only supported for vulkan and rocm (as of now).
# default device must be selected for all others
device_id = None
device = (
raw_device
if "=>" not in raw_device
else raw_device.split("=>")[1].strip()
)
if "://" in device:
device, device_id = device.split("://")
device_id = int(device_id) # using device index in webui
if device not in ["rocm", "vulkan"]:
device_id = None
return device, device_id
model_vmfb_key = ""

View File

@@ -31,24 +31,7 @@ from .benchmark_utils import *
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
print("Configuring for device:" + device)
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan", "rocm"]:
print(
f"Specific device selection only supported for vulkan and rocm."
f"Proceeding with {device} as device."
)
# device_uri can be device_num or device_path.
# assuming number of devices for a single driver will be not be >99
if len(device_uri[1]) <= 2:
# expected to be device index in range 0 - 99
device_num = int(device_uri[1])
else:
# expected to be device path
device_num = device_uri[1]
else:
device_num = 0
device, device_num = clean_device_info(device)
if "cpu" in device:
from shark.iree_utils.cpu_utils import get_iree_cpu_args
@@ -64,27 +47,48 @@ def get_iree_device_args(device, extra_args=[]):
+ stack_size_flag
+ ["--iree-global-opt-enable-quantized-matmul-reassociation"]
)
if device_uri[0] == "cuda":
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
return get_iree_gpu_args()
if device_uri[0] == "vulkan":
if device == "vulkan":
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "metal":
if device == "metal":
from shark.iree_utils.metal_utils import get_iree_metal_args
return get_iree_metal_args(extra_args=extra_args)
if device_uri[0] == "rocm":
if device == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
return []
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by Studio pipeline
# Multiple devices only supported for vulkan and rocm (as of now).
# default device must be selected for all others
device_id = None
device = (
raw_device
if "=>" not in raw_device
else raw_device.split("=>")[1].strip()
)
if "://" in device:
device, device_id = device.split("://")
device_id = int(device_id[0]) # using device index in webui
if device not in ["rocm", "vulkan"]:
device_id = None
return device, device_id
# Get the iree-compiler arguments given frontend.
def get_iree_frontend_args(frontend):
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:

View File

@@ -800,13 +800,13 @@ def save_mlir(
model_name,
mlir_dialect="linalg",
frontend="torch",
dir=tempfile.gettempdir(),
dir="",
):
model_name_mlir = (
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
)
if dir == "":
dir = tempfile.gettempdir()
dir = os.path.join(".", "shark_tmp")
mlir_path = os.path.join(dir, model_name_mlir)
print(f"saving {model_name_mlir} to {dir}")
if frontend == "torch":