mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-02-19 11:56:43 -05:00
* Initial filesystem cleanup * More filesystem cleanup * Fix some formatting issues * Address comments
206 lines
7.8 KiB
Python
from shark.iree_utils.compile_utils import (
    get_iree_compiled_module,
    load_vmfb_using_mmap,
    clean_device_info,
    get_iree_target_triple,
)
from apps.shark_studio.web.utils.file_utils import (
    get_checkpoints_path,
    get_resource_path,
)
from apps.shark_studio.modules.shared_cmd_opts import (
    cmd_opts,
)
from iree import runtime as ireert
from pathlib import Path
import gc
import os


class SharkPipelineBase:
    # This class is a lightweight base for managing an
    # inference API class. It should provide methods for:
    # - compiling a set (model map) of torch IR modules
    # - preparing weights for an inference job
    # - loading weights for an inference job
    # - utilities like benchmarks and tests
    # A usage sketch is included at the end of this file.

    def __init__(
        self,
        model_map: dict,
        base_model_id: str,
        static_kwargs: dict,
        device: str,
        import_mlir: bool = True,
    ):
        self.model_map = model_map
        self.pipe_map = {}
        self.static_kwargs = static_kwargs
        self.base_model_id = base_model_id
        self.triple = get_iree_target_triple(device)
        self.device, self.device_id = clean_device_info(device)
        self.import_mlir = import_mlir
        self.iree_module_dict = {}
        self.tmp_dir = get_resource_path(cmd_opts.tmp_dir)
        if not os.path.exists(self.tmp_dir):
            os.mkdir(self.tmp_dir)
        self.tempfiles = {}
        self.pipe_vmfb_path = ""

    def get_compiled_map(self, pipe_id, submodel="None", init_kwargs=None) -> None:
        # First checks whether we have precompiled .vmfbs, then populates the map
        # with those precompiled executables and fetches executables for the rest
        # of the map. The weights are no longer static here, so this function
        # should be part of pipeline initialization: as soon as you have a
        # pipeline ID unique to your static torch IR parameters, and your model
        # map is populated with IR-unique model IDs and their static params,
        # call this method to fetch the artifacts associated with your map.
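        # For illustration (hypothetical IDs; the layout follows the write_to
        # path passed to get_iree_compiled_module below): a pipe_id of
        # "sd21_fp16" with submodels "clip" and "unet" produces
        #   <checkpoints>/sd21_fp16/clip.vmfb
        #   <checkpoints>/sd21_fp16/unet.vmfb
        # which get_precompiled() will pick up on subsequent runs.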
        if init_kwargs is None:
            # Use a fresh dict per call; a mutable default argument would be
            # shared across calls and accumulate keys.
            init_kwargs = {}
        self.pipe_id = self.safe_name(pipe_id)
        self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
        self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True)
        if submodel == "None":
            print("\n[LOG] Gathering any pre-compiled artifacts....")
            for key in self.model_map:
                self.get_compiled_map(pipe_id, submodel=key)
        else:
            self.pipe_map[submodel] = {}
            self.get_precompiled(self.pipe_id, submodel)
            ireec_flags = []
            if submodel in self.iree_module_dict:
                return
            elif "vmfb_path" in self.pipe_map[submodel]:
                return
            elif submodel not in self.tempfiles:
                print(
                    f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR..."
                )
                if submodel in self.static_kwargs:
                    init_kwargs = self.static_kwargs[submodel]
                for key in self.static_kwargs["pipe"]:
                    if key not in init_kwargs:
                        init_kwargs[key] = self.static_kwargs["pipe"][key]
                self.import_torch_ir(submodel, init_kwargs)
                self.get_compiled_map(pipe_id, submodel)
            else:
                ireec_flags = (
                    self.model_map[submodel]["ireec_flags"]
                    if "ireec_flags" in self.model_map[submodel]
                    else []
                )

                weights_path = self.get_io_params(submodel)
                if weights_path:
                    # Weights are externalized, so skip compile-time constant
                    # evaluation of the parameters.
                    ireec_flags.append("--iree-opt-const-eval=False")

                self.iree_module_dict[submodel] = get_iree_compiled_module(
                    self.tempfiles[submodel],
                    device=self.device,
                    frontend="torch",
                    mmap=True,
                    external_weight_file=weights_path,
                    extra_args=ireec_flags,
                    write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
                )
        return

    def get_io_params(self, submodel):
        if "external_weight_file" in self.static_kwargs[submodel]:
            # We are using custom weights.
            weights_path = self.static_kwargs[submodel]["external_weight_file"]
        elif "external_weight_path" in self.static_kwargs[submodel]:
            # We are using the default weights for the HF model.
            weights_path = self.static_kwargs[submodel]["external_weight_path"]
        else:
            # Assume the torch IR contains the weights.
            weights_path = None
        return weights_path
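
    # For reference, a hypothetical static_kwargs entry that would route a
    # submodel to custom weights via get_io_params above (the submodel name
    # and path are illustrative, not part of this module):
    #
    # static_kwargs = {
    #     "pipe": {"precision": "fp16"},
    #     "unet": {"external_weight_file": "/path/to/custom_unet.safetensors"},
    # }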

    def get_precompiled(self, pipe_id, submodel="None"):
        if submodel == "None":
            for model in self.model_map:
                self.get_precompiled(pipe_id, model)
        vmfbs = []
        for dirpath, dirnames, filenames in os.walk(self.pipe_vmfb_path):
            vmfbs.extend(filenames)
            break
        for file in vmfbs:
            if submodel in file:
                self.pipe_map[submodel]["vmfb_path"] = os.path.join(
                    self.pipe_vmfb_path, file
                )
        return

    def import_torch_ir(self, submodel, kwargs):
        torch_ir = self.model_map[submodel]["initializer"](
            **self.safe_dict(kwargs), compile_to="torch"
        )
        if submodel == "clip":
            # clip.export_clip_model returns (torch_ir, tokenizer)
            torch_ir = torch_ir[0]

        self.tempfiles[submodel] = os.path.join(
            self.tmp_dir, f"{submodel}.torch.tempfile"
        )

        with open(self.tempfiles[submodel], "w+") as f:
            f.write(torch_ir)
        del torch_ir
        gc.collect()
        return

    def load_submodels(self, submodels: list):
        for submodel in submodels:
            if submodel in self.iree_module_dict:
                print(f"\n[LOG] {submodel} is ready for inference.")
                continue
            if "vmfb_path" in self.pipe_map[submodel]:
                weights_path = self.get_io_params(submodel)
                # print(
                #     f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}"
                # )
                self.iree_module_dict[submodel] = {}
                (
                    self.iree_module_dict[submodel]["vmfb"],
                    self.iree_module_dict[submodel]["config"],
                    self.iree_module_dict[submodel]["temp_file_to_unlink"],
                ) = load_vmfb_using_mmap(
                    self.pipe_map[submodel]["vmfb_path"],
                    self.device,
                    device_idx=0,
                    rt_flags=[],
                    external_weight_file=weights_path,
                )
            else:
                self.get_compiled_map(self.pipe_id, submodel)
        return

    def unload_submodels(self, submodels: list):
        for submodel in submodels:
            if submodel in self.iree_module_dict:
                del self.iree_module_dict[submodel]
                gc.collect()
        return

    def run(self, submodel, inputs):
        if not isinstance(inputs, list):
            inputs = [inputs]
        # Move host inputs onto the submodel's device before invoking "main".
        inp = [
            ireert.asdevicearray(
                self.iree_module_dict[submodel]["config"].device, arr
            )
            for arr in inputs
        ]
        return self.iree_module_dict[submodel]["vmfb"]["main"](*inp)

    def safe_name(self, name):
        return name.replace("/", "_").replace("-", "_").replace("\\", "_")

    def safe_dict(self, kwargs: dict):
        # Flatten nested dict values into lists of their values, e.g.
        # {"a": 1, "b": {"x": 2, "y": 3}} -> {"a": 1, "b": [2, 3]},
        # unless the nested dict opts out with a "pass_dict" key.
        flat_args = {}
        for i in kwargs:
            if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]:
                flat_args[i] = [kwargs[i][j] for j in kwargs[i]]
            else:
                flat_args[i] = kwargs[i]

        return flat_args
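

# A minimal usage sketch (all names below are hypothetical: the initializer
# callables, model/submodel IDs, and input arrays are placeholders, not part
# of this module):
#
# model_map = {
#     "clip": {"initializer": export_clip_model, "ireec_flags": []},
#     "unet": {"initializer": export_unet_model, "ireec_flags": []},
# }
# static_kwargs = {"pipe": {"precision": "fp16"}, "clip": {}, "unet": {}}
# pipe = SharkPipelineBase(model_map, "stabilityai/stable-diffusion-2-1",
#                          static_kwargs, "vulkan")
# pipe.get_compiled_map("sd21_fp16_vulkan")   # compile, or pick up cached .vmfbs
# pipe.load_submodels(["clip", "unet"])       # mmap-load for inference
# noise_pred = pipe.run("unet", [latents, timestep, embeddings])
# pipe.unload_submodels(["unet"])             # free the module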