mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Diffusers hijack imminent
This commit is contained in:
File diff suppressed because it is too large
Load Diff
218
apps/shark_studio/modules/meta_model.py
Normal file
218
apps/shark_studio/modules/meta_model.py
Normal file
@@ -0,0 +1,218 @@
|
||||
from msvcrt import kbhit
|
||||
from shark.iree_utils.compile_utils import (
|
||||
get_iree_compiled_module,
|
||||
load_vmfb_using_mmap,
|
||||
clean_device_info,
|
||||
get_iree_target_triple,
|
||||
)
|
||||
from apps.shark_studio.web.utils.file_utils import (
|
||||
get_checkpoints_path,
|
||||
get_resource_path,
|
||||
)
|
||||
from apps.shark_studio.modules.shared_cmd_opts import (
|
||||
cmd_opts,
|
||||
)
|
||||
from iree import runtime as ireert
|
||||
from pathlib import Path
|
||||
import gc
|
||||
import os
|
||||
|
||||
|
||||
class SharkMetaModelBase:
|
||||
# This class is a lightweight base for managing an
|
||||
# inference API class. It should provide methods for:
|
||||
# - compiling a set (model map) of torch IR modules
|
||||
# - preparing weights for an inference job
|
||||
# - loading weights for an inference job
|
||||
# - utilites like benchmarks, tests
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_map: dict,
|
||||
device: str,
|
||||
dtype: str = "f16",
|
||||
import_mlir: bool = True,
|
||||
):
|
||||
self.model_map = model_map
|
||||
self.pipe_map = {}
|
||||
self.triple = get_iree_target_triple(device)
|
||||
self._device, self.device_id = clean_device_info(device)
|
||||
self.import_mlir = import_mlir
|
||||
self.iree_module_dict = {}
|
||||
self.tmp_dir = get_resource_path(os.path.join("..", "shark_tmp"))
|
||||
if not os.path.exists(self.tmp_dir):
|
||||
os.mkdir(self.tmp_dir)
|
||||
self.tempfiles = {}
|
||||
self.pipe_vmfb_path = ""
|
||||
self._dtype = dtype
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._device
|
||||
|
||||
@device.setter
|
||||
def device(self, device_str):
|
||||
self._device = device_str
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@dtype.setter
|
||||
def dtype(self, dtype_val):
|
||||
self._dtype = dtype_val
|
||||
|
||||
def get_compiled_map(self, pipe_id, static_kwargs, submodel="None", init_kwargs={}) -> None:
|
||||
# First checks whether we have .vmfbs precompiled, then populates the map
|
||||
# with the precompiled executables and fetches executables for the rest of the map.
|
||||
# The weights aren't static here anymore so this function should be a part of pipeline
|
||||
# initialization. As soon as you have a pipeline ID unique to your static torch IR parameters,
|
||||
# and your model map is populated with any IR - unique model IDs and their static params,
|
||||
# call this method to get the artifacts associated with your map.
|
||||
self.pipe_id = self.safe_name(pipe_id)
|
||||
self.pipe_vmfb_path = Path(
|
||||
os.path.join(get_checkpoints_path(".."), self.pipe_id)
|
||||
)
|
||||
self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True)
|
||||
if submodel == "None":
|
||||
print("\n[LOG] Gathering any pre-compiled artifacts....")
|
||||
for key in self.model_map:
|
||||
self.get_compiled_map(pipe_id, static_kwargs, submodel=key)
|
||||
else:
|
||||
self.pipe_map[submodel] = {}
|
||||
self.get_precompiled(self.pipe_id, submodel)
|
||||
ireec_flags = []
|
||||
if submodel in self.iree_module_dict:
|
||||
return
|
||||
elif "vmfb_path" in self.pipe_map[submodel]:
|
||||
return
|
||||
elif submodel not in self.tempfiles:
|
||||
print(
|
||||
f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR..."
|
||||
)
|
||||
if submodel in static_kwargs:
|
||||
init_kwargs = static_kwargs[submodel]
|
||||
for key in static_kwargs["pipe"]:
|
||||
if key not in init_kwargs:
|
||||
init_kwargs[key] = static_kwargs["pipe"][key]
|
||||
self.import_torch_ir(submodel, init_kwargs)
|
||||
self.get_compiled_map(pipe_id, submodel)
|
||||
else:
|
||||
ireec_flags = (
|
||||
self.model_map[submodel]["ireec_flags"]
|
||||
if "ireec_flags" in self.model_map[submodel]
|
||||
else []
|
||||
)
|
||||
|
||||
self.iree_module_dict[submodel] = get_iree_compiled_module(
|
||||
self.tempfiles[submodel],
|
||||
device=self.device,
|
||||
frontend="torch",
|
||||
mmap=True,
|
||||
external_weight_file=self.get_io_params(submodel),
|
||||
extra_args=ireec_flags,
|
||||
write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
|
||||
)
|
||||
return
|
||||
|
||||
def get_io_params(self, submodel):
|
||||
if "external_weight_file" in self.static_kwargs[submodel]:
|
||||
# we are using custom weights
|
||||
weights_path = self.static_kwargs[submodel]["external_weight_file"]
|
||||
elif "external_weight_path" in self.static_kwargs[submodel]:
|
||||
# we are using the default weights for the HF model
|
||||
weights_path = self.static_kwargs[submodel]["external_weight_path"]
|
||||
else:
|
||||
# assume the torch IR contains the weights.
|
||||
weights_path = None
|
||||
return weights_path
|
||||
|
||||
def get_precompiled(self, pipe_id, submodel="None"):
|
||||
if submodel == "None":
|
||||
for model in self.model_map:
|
||||
self.get_precompiled(pipe_id, model)
|
||||
vmfbs = []
|
||||
for dirpath, dirnames, filenames in os.walk(self.pipe_vmfb_path):
|
||||
vmfbs.extend(filenames)
|
||||
break
|
||||
for file in vmfbs:
|
||||
if submodel in file:
|
||||
self.pipe_map[submodel]["vmfb_path"] = os.path.join(
|
||||
self.pipe_vmfb_path, file
|
||||
)
|
||||
return
|
||||
|
||||
def import_torch_ir(self, submodel, kwargs):
|
||||
torch_ir = self.model_map[submodel]["initializer"](
|
||||
**self.safe_dict(kwargs), compile_to="torch"
|
||||
)
|
||||
if submodel == "clip":
|
||||
# clip.export_clip_model returns (torch_ir, tokenizer)
|
||||
torch_ir = torch_ir[0]
|
||||
|
||||
self.tempfiles[submodel] = os.path.join(
|
||||
self.tmp_dir, f"{submodel}.torch.tempfile"
|
||||
)
|
||||
|
||||
with open(self.tempfiles[submodel], "w+") as f:
|
||||
f.write(torch_ir)
|
||||
del torch_ir
|
||||
gc.collect()
|
||||
return
|
||||
|
||||
def load_submodels(self, submodels: list):
|
||||
for submodel in submodels:
|
||||
if submodel in self.iree_module_dict:
|
||||
print(f"\n[LOG] {submodel} is ready for inference.")
|
||||
continue
|
||||
if "vmfb_path" in self.pipe_map[submodel]:
|
||||
weights_path = self.get_io_params(submodel)
|
||||
# print(
|
||||
# f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}"
|
||||
# )
|
||||
self.iree_module_dict[submodel] = {}
|
||||
(
|
||||
self.iree_module_dict[submodel]["vmfb"],
|
||||
self.iree_module_dict[submodel]["config"],
|
||||
self.iree_module_dict[submodel]["temp_file_to_unlink"],
|
||||
) = load_vmfb_using_mmap(
|
||||
self.pipe_map[submodel]["vmfb_path"],
|
||||
self.device,
|
||||
device_idx=0,
|
||||
rt_flags=[],
|
||||
external_weight_file=weights_path,
|
||||
)
|
||||
else:
|
||||
self.get_compiled_map(self.pipe_id, submodel)
|
||||
return
|
||||
|
||||
def unload_submodels(self, submodels: list):
|
||||
for submodel in submodels:
|
||||
if submodel in self.iree_module_dict:
|
||||
del self.iree_module_dict[submodel]
|
||||
gc.collect()
|
||||
return
|
||||
|
||||
def run(self, submodel, inputs):
|
||||
if not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
inp = [
|
||||
ireert.asdevicearray(
|
||||
self.iree_module_dict[submodel]["config"].device, input
|
||||
)
|
||||
for input in inputs
|
||||
]
|
||||
return self.iree_module_dict[submodel]["vmfb"]["main"](*inp)
|
||||
|
||||
def safe_name(self, name):
|
||||
return name.replace("/", "_").replace("-", "_").replace("\\", "_")
|
||||
|
||||
def safe_dict(self, kwargs: dict):
|
||||
flat_args = {}
|
||||
for i in kwargs:
|
||||
if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]:
|
||||
flat_args[i] = [kwargs[i][j] for j in kwargs[i]]
|
||||
else:
|
||||
flat_args[i] = kwargs[i]
|
||||
|
||||
return flat_args
|
||||
47
apps/shark_studio/tests/diffusers_pipeline_test.py
Normal file
47
apps/shark_studio/tests/diffusers_pipeline_test.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright 2023 Nod Labs, Inc
|
||||
#
|
||||
# Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import unittest
|
||||
import PIL
|
||||
from typing import List
|
||||
from apps.shark_studio.api.sd import SharkDiffusionPipeline
|
||||
#from diffusers import DiffusionPipeline
|
||||
|
||||
|
||||
class SDBaseAPITest(unittest.TestCase):
|
||||
def testPipeSimple(self):
|
||||
pipe = SharkDiffusionPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path="hf-internal-testing/tiny-stable-diffusion-torch",
|
||||
device="vulkan",
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
pipe.setup_shark(
|
||||
base_model_id="hf-internal-testing/tiny-stable-diffusion-torch",
|
||||
height=512,
|
||||
width=512,
|
||||
batch_size=1,
|
||||
precision="f32",
|
||||
device="vulkan",
|
||||
)
|
||||
|
||||
pipe.prepare_pipe(
|
||||
custom_weights="",
|
||||
adapters=[],
|
||||
embeddings=[],
|
||||
is_img2img=False,
|
||||
)
|
||||
|
||||
prompt = ["An astronaut riding a fearsome shark"]
|
||||
negative_prompt = [""]
|
||||
image = pipe(prompt, negative_prompt).images[0]
|
||||
assert isinstance(image, List(PIL.Image.Image))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user