Files
SHARK-Studio/tank/generate_amdsharktank.py
pdhirajkumarprasad f6dd02fa67 Fix for migration (#2183)
Signed-off-by: pdhirajkumarprasad <dhirajp@amd.com>
2025-11-24 15:32:02 +05:30

215 lines
7.2 KiB
Python

# Lint as: python3
"""AMD-SHARK Tank"""
# python generate_amdsharktank.py; you must supply a csv file with [model_name, model_download_url]
# will generate local amdshark tank folder like this:
# /AMDSHARK
# /gen_amdshark_tank
# /albert_lite_base
# /...model_name...
#
import os
import csv
import argparse
from amdshark.amdshark_importer import AMDSharkImporter
import subprocess as sp
import hashlib
import numpy as np
from pathlib import Path
def create_hash(file_name):
    """Return the hex BLAKE2b hash (64-byte digest) of the file at *file_name*.

    The file is streamed in 1 KiB chunks so arbitrarily large artifacts can be
    hashed without loading them into memory.
    """
    hasher = hashlib.blake2b(digest_size=64)
    with open(file_name, "rb") as fh:
        for chunk in iter(lambda: fh.read(2**10), b""):
            hasher.update(chunk)
    return hasher.hexdigest()
def save_torch_model(torch_model_list, local_tank_cache, import_args):
    """Import each torch model listed in a CSV and save its MLIR artifacts.

    The CSV is expected to have a header row followed by rows of
    [model_name, tracing_required, model_type, is_dynamic, mlir_type,
    is_decompose].  One artifact directory per model is created under
    *local_tank_cache*.

    Args:
        torch_model_list: Path to the CSV describing the models to import.
        local_tank_cache: Directory under which per-model folders are created.
        import_args: Importer options dict; only "batch_size" is read here.
    """
    from tank.model_utils import (
        get_hf_model,
        get_hf_seq2seq_model,
        get_hf_causallm_model,
        get_vision_model,
        get_hf_img_cls_model,
        get_fp16_model,
    )
    from amdshark.amdshark_importer import import_with_fx, save_mlir

    with open(torch_model_list) as csvfile:
        torch_reader = csv.reader(csvfile, delimiter=",")
        next(torch_reader)  # skip the header row
        for row in torch_reader:
            torch_model_name = row[0]
            # CSV cells are strings, so convert explicitly; a non-empty
            # "False" string is otherwise truthy.
            # (Parsed for schema completeness; the import calls below
            # currently hard-code tracing_required=True.)
            tracing_required = row[1] != "False"
            model_type = row[2]
            mlir_type = row[4]
            # BUG FIX: `is_decompose = row[5]` left a raw string, so the
            # `if is_decompose:` branch below fired even for "False" cells.
            is_decompose = row[5] == "True"
            # Dynamic-shape generation is force-disabled (row[3] is ignored);
            # presumably deliberate — the dynamic branch below is kept for
            # when this is re-enabled.  TODO confirm.
            is_dynamic = False
            print("generating artifacts for: " + torch_model_name)
            model = None
            # `sample_input` (was `input`, which shadowed the builtin).
            sample_input = None
            if model_type == "vision":
                model, sample_input, _ = get_vision_model(
                    torch_model_name, import_args
                )
            elif model_type == "hf":
                model, sample_input, _ = get_hf_model(
                    torch_model_name, import_args
                )
            elif model_type == "hf_seq2seq":
                model, sample_input, _ = get_hf_seq2seq_model(
                    torch_model_name, import_args
                )
            elif model_type == "hf_causallm":
                model, sample_input, _ = get_hf_causallm_model(
                    torch_model_name, import_args
                )
            elif model_type == "hf_img_cls":
                model, sample_input, _ = get_hf_img_cls_model(
                    torch_model_name, import_args
                )
            # HF model ids may contain "/" which would nest directories.
            torch_model_name = torch_model_name.replace("/", "_")
            if import_args["batch_size"] > 1:
                print(
                    f"Batch size for this model set to {import_args['batch_size']}"
                )
                torch_model_dir = os.path.join(
                    local_tank_cache,
                    str(torch_model_name)
                    + "_torch"
                    + f"_BS{str(import_args['batch_size'])}",
                )
            else:
                torch_model_dir = os.path.join(
                    local_tank_cache, str(torch_model_name) + "_torch"
                )
            os.makedirs(torch_model_dir, exist_ok=True)
            if is_decompose:
                # Decomposition path: lower selected torch ops during import.
                # TODO add op whitelist/blacklist
                import_with_fx(
                    model,
                    (sample_input,),
                    is_f16=False,
                    f16_input_mask=None,
                    debug=True,
                    training=False,
                    return_str=False,
                    save_dir=torch_model_dir,
                    model_name=torch_model_name,
                    mlir_type=mlir_type,
                    is_dynamic=False,
                    tracing_required=True,
                )
            else:
                mlir_importer = AMDSharkImporter(
                    model,
                    (sample_input,),
                    frontend="torch",
                )
                mlir_importer.import_debug(
                    is_dynamic=False,
                    tracing_required=True,
                    dir=torch_model_dir,
                    model_name=torch_model_name,
                    mlir_type=mlir_type,
                )
                # Generate torch dynamic models (dead while is_dynamic is
                # forced False above).
                if is_dynamic:
                    mlir_importer.import_debug(
                        is_dynamic=True,
                        tracing_required=True,
                        dir=torch_model_dir,
                        model_name=torch_model_name + "_dynamic",
                        mlir_type=mlir_type,
                    )
def check_requirements(frontend):
    """Return True when the optional packages needed by *frontend* are installed.

    Only the "torch" frontend is known here; it requires torchvision.  Any
    other frontend reports False.
    """
    # BUG FIX: a bare `import importlib` is not guaranteed to bind the
    # `importlib.util` submodule on modern CPython — import it explicitly.
    import importlib.util

    if frontend == "torch":
        return importlib.util.find_spec("torchvision") is not None
    return False
class NoImportException(Exception):
    """Raised when requirements are not met for OTF model artifact generation."""
def gen_amdshark_files(modelname, frontend, tank_dir, importer_args):
    """Generate *modelname*'s artifacts on-the-fly into *tank_dir*.

    Called by amdshark_downloader when a model's artifacts are requested but
    do not exist in the cloud.
    TODO: Add TFlite support.

    Args:
        modelname: Name of the model as listed in torch_model_list.csv.
        frontend: Importer frontend; only "torch" is handled.
        tank_dir: Destination directory for the generated artifacts.
        importer_args: Options dict forwarded to save_torch_model.

    Raises:
        NoImportException: when the packages *frontend* needs are missing.
        ValueError: when *modelname* is not listed in torch_model_list.csv.
    """
    import tempfile

    import_args = importer_args
    if not check_requirements(frontend):
        raise NoImportException
    torch_model_csv = os.path.join(
        os.path.dirname(__file__), "torch_model_list.csv"
    )
    custom_model_csv = tempfile.NamedTemporaryFile(
        dir=os.path.dirname(__file__),
        delete=True,
    )
    try:
        if frontend == "torch":
            # Locate the CSV row describing the requested model.
            target = None
            with open(torch_model_csv, mode="r") as src:
                for row in csv.reader(src):
                    if row[0] == modelname:
                        target = row
            # BUG FIX: an unmatched model originally left `target` unbound
            # and crashed with NameError; fail with a clear error instead.
            if target is None:
                raise ValueError(
                    f"Model {modelname} not found in torch_model_list.csv"
                )
            with open(custom_model_csv.name, mode="w") as trg:
                writer = csv.writer(trg)
                writer.writerow(["modelname", "src"])
                writer.writerow(target)
            save_torch_model(custom_model_csv.name, tank_dir, import_args)
    finally:
        # BUG FIX: NamedTemporaryFile(delete=True) removes the file on
        # close; the original never closed it, leaking the temp file.
        custom_model_csv.close()
def is_valid_file(arg):
    """Return *arg* unchanged when the path exists on disk, else None."""
    return arg if os.path.exists(arg) else None
if __name__ == "__main__":
# Note, all of these flags are overridden by the import of import_args from stable_args.py, flags are duplicated temporarily to preserve functionality
# parser = argparse.ArgumentParser()
# parser.add_argument(
# "--torch_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/torch_model_list.csv",
# help="""Contains the file with torch_model name and args.
# Please see: https://github.com/nod-ai/AMD-SHARK-Studio/blob/main/tank/torch_model_list.csv""",
# )
# parser.add_argument(
# "--ci_tank_dir",
# type=bool,
# default=False,
# )
# parser.add_argument("--upload", type=bool, default=False)
# old_import_args = parser.parse_import_args()
import_args = {
"batch_size": 1,
}
print(import_args)
home = str(Path.home())
WORKDIR = os.path.join(os.path.dirname(__file__), "..", "gen_amdshark_tank")
torch_model_csv = os.path.join(
os.path.dirname(__file__), "torch_model_list.csv"
)
save_torch_model(torch_model_csv, WORKDIR, import_args)