Switch most compile flows to use ireec.compile_file. (#1863)

* Switch most compile flows to use ireec.compile_file.

* re-add input type to compile_str path.

* Check if mlir_module exists before checking if it's a path or pyobject.

* Fix some save_dir cases
This commit is contained in:
Ean Garvey
2023-10-06 23:04:43 -05:00
committed by GitHub
parent 8614a18474
commit caf6cc5d8f
21 changed files with 184 additions and 79 deletions

View File

@@ -710,8 +710,11 @@ class SharkifyStableDiffusionModel:
return self.text_encoder(input)[0]
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
save_dir = ""
if self.debug:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["clip"]
)
os.makedirs(
save_dir,
exist_ok=True,

View File

@@ -18,7 +18,7 @@ import tempfile
import torch
from safetensors.torch import load_file
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
@@ -154,8 +154,8 @@ def compile_through_fx(
f16_input_mask=f16_input_mask,
debug=debug,
model_name=extended_model_name,
save_dir=save_dir,
)
if use_tuned:
if "vae" in extended_model_name.split("_")[0]:
args.annotation_model = "vae"
@@ -168,6 +168,14 @@ def compile_through_fx(
mlir_module, extended_model_name, base_model_id
)
if not os.path.isdir(save_dir):
save_dir = ""
mlir_module = save_mlir(
mlir_module,
model_name=extended_model_name,
dir=save_dir,
)
shark_module = SharkInference(
mlir_module,
device=args.device if device is None else device,
@@ -179,7 +187,6 @@ def compile_through_fx(
mlir_module,
)
del mlir_module
gc.collect()