mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-08 05:24:00 -05:00
215 lines
7.2 KiB
Python
215 lines
7.2 KiB
Python
# Lint as: python3
|
|
"""AMD-SHARK Tank"""
|
|
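# Usage: python generate_amdsharktank.py. The models to import are read from a CSV file
# (by default torch_model_list.csv next to this script) whose rows are
# [model_name, tracing_required, model_type, is_dynamic, mlir_type, is_decompose].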
|
|
# will generate local amdshark tank folder like this:
|
|
# /AMDSHARK
|
|
# /gen_amdshark_tank
|
|
# /albert_lite_base
|
|
# /...model_name...
|
|
#
|
|
|
|
import os
|
|
import csv
|
|
import argparse
|
|
from amdshark.amdshark_importer import AMDSharkImporter
|
|
import subprocess as sp
|
|
import hashlib
|
|
import numpy as np
|
|
from pathlib import Path
|
|
|
|
|
|
def create_hash(file_name):
    """Return the hex BLAKE2b digest (digest_size=64) of a file's contents.

    The file is opened in binary mode and fed to the hasher in 1 KiB chunks,
    so the whole file is never held in memory at once.
    """
    digest = hashlib.blake2b(digest_size=64)
    with open(file_name, "rb") as stream:
        # iter(callable, sentinel) keeps reading until read() returns b"" (EOF).
        for block in iter(lambda: stream.read(1024), b""):
            digest.update(block)
    return digest.hexdigest()
|
|
|
|
|
|
def _csv_flag(value):
    """Interpret a model-list CSV cell as a boolean.

    An empty cell or the text ``false`` (any case, surrounding whitespace
    ignored) is False; any other value is True.
    """
    return value.strip().lower() not in ("", "false")


def _torch_model_dir(local_tank_cache, model_name, batch_size):
    """Return ``<local_tank_cache>/<model_name>_torch[_BS<batch_size>]``.

    The ``_BS<batch_size>`` suffix is only appended when ``batch_size > 1``.
    """
    dir_name = f"{model_name}_torch"
    if batch_size > 1:
        dir_name += f"_BS{batch_size}"
    return os.path.join(local_tank_cache, dir_name)


def save_torch_model(torch_model_list, local_tank_cache, import_args):
    """Import every model listed in a CSV file and save its MLIR artifacts.

    Args:
        torch_model_list: Path to a CSV file. Its first row is a header and
            is skipped. Each following row is
            ``model_name, tracing_required, model_type, is_dynamic,
            mlir_type, is_decompose``.
        local_tank_cache: Directory under which one
            ``<model_name>_torch[_BS<n>]`` directory is created per model.
        import_args: Dict passed to the ``tank.model_utils`` loaders. Its
            ``"batch_size"`` entry also selects the output directory name.

    Raises:
        ValueError: If a row's ``model_type`` has no known loader.
    """
    from tank.model_utils import (
        get_hf_model,
        get_hf_seq2seq_model,
        get_hf_causallm_model,
        get_vision_model,
        get_hf_img_cls_model,
    )
    from amdshark.amdshark_importer import import_with_fx

    # model_type column -> loader returning (model, example_input, _).
    loaders = {
        "vision": get_vision_model,
        "hf": get_hf_model,
        "hf_seq2seq": get_hf_seq2seq_model,
        "hf_causallm": get_hf_causallm_model,
        "hf_img_cls": get_hf_img_cls_model,
    }

    with open(torch_model_list, newline="") as csvfile:
        reader = csv.reader(csvfile, delimiter=",")
        next(reader, None)  # skip the header row; tolerate an empty file
        for row in reader:
            if not row:
                continue  # blank line in the CSV
            model_name = row[0]
            # row[1] (tracing_required) is not consulted: both importer paths
            # below are always called with tracing_required=True.
            model_type = row[2]
            # row[3] (is_dynamic) is overridden, so dynamic-shape artifacts
            # are never generated.
            is_dynamic = False
            mlir_type = row[4]
            # Parse the cell instead of testing string truthiness, which
            # treated the text "False" as true.
            is_decompose = _csv_flag(row[5])

            print("generating artifacts for: " + model_name)
            loader = loaders.get(model_type)
            if loader is None:
                raise ValueError(
                    f"Unsupported model_type {model_type!r} for model {model_name!r}"
                )
            model, inputs, _ = loader(model_name, import_args)

            # Replace "/" so the name maps to a single directory component.
            model_name = model_name.replace("/", "_")
            batch_size = import_args["batch_size"]
            if batch_size > 1:
                print(f"Batch size for this model set to {batch_size}")
            model_dir = _torch_model_dir(local_tank_cache, model_name, batch_size)
            os.makedirs(model_dir, exist_ok=True)

            if is_decompose:
                # Add decomposition to some torch ops
                # TODO add op whitelist/blacklist
                import_with_fx(
                    model,
                    (inputs,),
                    is_f16=False,
                    f16_input_mask=None,
                    debug=True,
                    training=False,
                    return_str=False,
                    save_dir=model_dir,
                    model_name=model_name,
                    mlir_type=mlir_type,
                    is_dynamic=False,
                    tracing_required=True,
                )
            else:
                mlir_importer = AMDSharkImporter(
                    model,
                    (inputs,),
                    frontend="torch",
                )
                mlir_importer.import_debug(
                    is_dynamic=False,
                    tracing_required=True,
                    dir=model_dir,
                    model_name=model_name,
                    mlir_type=mlir_type,
                )
                # Generate torch dynamic models. Kept inside this branch
                # because it reuses mlir_importer, which only exists here.
                if is_dynamic:
                    mlir_importer.import_debug(
                        is_dynamic=True,
                        tracing_required=True,
                        dir=model_dir,
                        model_name=model_name + "_dynamic",
                        mlir_type=mlir_type,
                    )
|
|
|
|
|
|
def check_requirements(frontend):
    """Report whether on-the-fly artifact generation is possible for *frontend*.

    Only the ``"torch"`` frontend is supported, and it requires
    ``torchvision`` to be importable.

    Returns:
        True for ``"torch"`` when torchvision is found, otherwise False.
    """
    # Import the submodule explicitly: a bare ``import importlib`` does not
    # guarantee that ``importlib.util`` has been loaded.
    import importlib.util

    if frontend != "torch":
        return False
    return importlib.util.find_spec("torchvision") is not None
|
|
|
|
|
|
class NoImportException(Exception):
    """Signal that OTF model artifact generation cannot run.

    Raised when the requirements for on-the-fly (OTF) artifact generation
    are not met.
    """
|
|
|
|
|
|
def gen_amdshark_files(modelname, frontend, tank_dir, importer_args):
    """Generate artifacts for a single model on the fly into *tank_dir*.

    Called when a model's artifacts are requested by amdshark_downloader but
    do not exist in the cloud. The model's row is looked up in
    ``torch_model_list.csv`` next to this file. That row is written to a
    temporary one-model CSV, which is passed to ``save_torch_model``.

    Args:
        modelname: Value matched against the first column of the model list.
        frontend: Frontend name; only ``"torch"`` produces artifacts.
        tank_dir: Output directory, passed to ``save_torch_model``.
        importer_args: Dict forwarded to ``save_torch_model`` as import_args.

    Raises:
        NoImportException: If ``check_requirements(frontend)`` is False.
        ValueError: If *modelname* is not listed in torch_model_list.csv.
    """
    # TODO: Add TFlite support.
    import tempfile

    if not check_requirements(frontend):
        raise NoImportException
    if frontend != "torch":
        return

    torch_model_csv = os.path.join(
        os.path.dirname(__file__), "torch_model_list.csv"
    )
    with open(torch_model_csv, mode="r", newline="") as src:
        matches = [row for row in csv.reader(src) if row and row[0] == modelname]
    if not matches:
        raise ValueError(f"Model {modelname!r} is not listed in {torch_model_csv}")
    target = matches[-1]  # if listed more than once, the last row wins

    # Use a private temp directory rather than an open NamedTemporaryFile.
    # Reopening a still-open temp file by name fails on Windows, and the
    # package directory may not be writable.
    with tempfile.TemporaryDirectory() as tmp_dir:
        custom_model_csv = os.path.join(tmp_dir, "custom_model_list.csv")
        # newline="" stops csv.writer's "\r\n" from becoming "\r\r\n" on
        # Windows, which would read back as blank rows.
        with open(custom_model_csv, mode="w", newline="") as trg:
            writer = csv.writer(trg)
            writer.writerow(["modelname", "src"])  # header, skipped by the reader
            writer.writerow(target)
        save_torch_model(custom_model_csv, tank_dir, importer_args)
|
|
|
|
|
|
# Validates whether the file is present or not.
|
|
def is_valid_file(arg):
    """Return *arg* if it names an existing path, otherwise None."""
    return arg if os.path.exists(arg) else None
|
|
|
|
|
|
if __name__ == "__main__":
    # Regenerate the local tank from a model-list CSV. With no arguments this
    # reads torch_model_list.csv next to this script, uses batch size 1, and
    # writes into ../gen_amdshark_tank relative to this script.
    parser = argparse.ArgumentParser(description="Generate a local AMD-SHARK tank.")
    parser.add_argument(
        "--torch_model_csv",
        default=os.path.join(os.path.dirname(__file__), "torch_model_list.csv"),
        help="CSV file listing the torch models to import (first row is a header).",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size passed to the model loaders; values > 1 add a _BS<n> dir suffix.",
    )
    args = parser.parse_args()

    import_args = {
        "batch_size": args.batch_size,
    }
    print(import_args)
    WORKDIR = os.path.join(os.path.dirname(__file__), "..", "gen_amdshark_tank")

    save_torch_model(args.torch_model_csv, WORKDIR, import_args)
|