find gsutil on linux (#557)

* find gsutil on linux

* cleaned up downloader and ditched gsutil

Co-authored-by: dan <dan@nod-labs.com>
This commit is contained in:
Daniel Garvey
2022-12-05 21:03:48 -06:00
committed by GitHub
parent b0dc19a910
commit bba8646669
20 changed files with 99 additions and 250 deletions

View File

@@ -1,7 +1,6 @@
import numpy as np
import tensorflow as tf
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_tf_model
def load_and_preprocess_image(fname: str):

View File

@@ -6,7 +6,7 @@ pyinstaller
tqdm
# SHARK Downloader
gsutil
google-cloud-storage
# Testing
pytest

View File

@@ -1,7 +1,9 @@
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model
from shark.shark_downloader import download_model
mlir_model, func_name, inputs, golden_out = download_torch_model("bloom")
mlir_model, func_name, inputs, golden_out = download_model(
"bloom", frontend="torch"
)
shark_module = SharkInference(
mlir_model, func_name, device="cpu", mlir_dialect="tm_tensor"

View File

@@ -1,9 +1,10 @@
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model
from shark.shark_downloader import download_model
mlir_model, func_name, inputs, golden_out = download_torch_model(
"microsoft/MiniLM-L12-H384-uncased"
mlir_model, func_name, inputs, golden_out = download_model(
"microsoft/MiniLM-L12-H384-uncased",
frontend="torch",
)

View File

@@ -5,7 +5,7 @@ import torchvision.models as models
from torchvision import transforms
import sys
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model
from shark.shark_downloader import download_model
################################## Preprocessing inputs and model ############
@@ -66,7 +66,9 @@ labels = load_labels()
## Can pass any img or input to the forward module.
mlir_model, func_name, inputs, golden_out = download_torch_model("resnet50")
mlir_model, func_name, inputs, golden_out = download_model(
"resnet50", frontend="torch"
)
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
shark_module.compile()

View File

@@ -37,10 +37,12 @@ args = p.parse_args()
def fp16_unet():
from shark.shark_downloader import download_torch_model
from shark.shark_downloader import download_model
mlir_model, func_name, inputs, golden_out = download_torch_model(
"stable_diff_f16_18_OCT", tank_url="gs://shark_tank/prashant_nod"
mlir_model, func_name, inputs, golden_out = download_model(
"stable_diff_f16_18_OCT",
tank_url="gs://shark_tank/prashant_nod",
frontend="torch",
)
shark_module = SharkInference(
mlir_model, func_name, device=args.device, mlir_dialect="linalg"

View File

@@ -17,7 +17,7 @@ from keras_cv.models.generative.stable_diffusion.text_encoder import (
)
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_tf_model
from shark.shark_downloader import download_model
from PIL import Image
# pip install "git+https://github.com/keras-team/keras-cv.git"
@@ -75,8 +75,8 @@ class SharkStableDiffusion:
# Create models
self.text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
mlir_model, func_name, inputs, golden_out = download_tf_model(
"stable_diff", tank_url="gs://shark_tank/quinn"
mlir_model, func_name, inputs, golden_out = download_model(
"stable_diff", tank_url="gs://shark_tank/quinn", frontend="tf"
)
shark_module = SharkInference(
mlir_model, func_name, device=device, mlir_dialect="mhlo"

View File

@@ -39,10 +39,12 @@ def _compile_module(shark_module, model_name, extra_args=[]):
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
from shark.shark_downloader import download_torch_model
from shark.shark_downloader import download_model
mlir_model, func_name, inputs, golden_out = download_torch_model(
model_name, tank_url=tank_url
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=tank_url,
frontend="torch",
)
shark_module = SharkInference(
mlir_model, func_name, device=args.device, mlir_dialect="linalg"

View File

@@ -1,8 +1,10 @@
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model
from shark.shark_downloader import download_model
mlir_model, func_name, inputs, golden_out = download_torch_model("v_diffusion")
mlir_model, func_name, inputs, golden_out = download_model(
"v_diffusion", frontend="torch"
)
shark_module = SharkInference(
mlir_model, func_name, device="vulkan", mlir_dialect="linalg"

View File

@@ -17,17 +17,22 @@ import os
import sys
from pathlib import Path
from shark.parser import shark_args
from google.cloud import storage
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
def download_public_file(full_gs_url, destination_file_name):
"""Downloads a public blob from the bucket."""
# bucket_name = "gs://your-bucket-name/path/to/file"
# destination_file_name = "local/path/to/file"
storage_client = storage.Client.create_anonymous_client()
bucket_name = full_gs_url.split("/")[2]
source_blob_name = "/".join(full_gs_url.split("/")[3:])
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(source_blob_name)
blob.download_to_filename(destination_file_name)
GSUTIL_PATH = resource_path("gsutil")
GSUTIL_FLAGS = ' -o "GSUtil:parallel_process_count=1" -m cp -r '
@@ -98,103 +103,23 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):
# Downloads the torch model from gs://shark_tank dir.
def download_torch_model(
model_name, dynamic=False, tank_url="gs://shark_tank/latest"
def download_model(
model_name,
dynamic=False,
tank_url="gs://shark_tank/latest",
frontend=None,
tuned=None,
):
model_name = model_name.replace("/", "_")
dyn_str = "_dynamic" if dynamic else ""
os.makedirs(WORKDIR, exist_ok=True)
model_dir_name = model_name + "_torch"
def gs_download_model():
gs_command = (
GSUTIL_PATH
+ GSUTIL_FLAGS
+ tank_url
+ "/"
+ model_dir_name
+ ' "'
+ WORKDIR
+ '"'
)
if os.system(gs_command) != 0:
raise Exception("model not present in the tank. Contact Nod Admin")
if not check_dir_exists(model_dir_name, frontend="torch", dynamic=dyn_str):
gs_download_model()
else:
if not _internet_connected():
print(
"No internet connection. Using the model already present in the tank."
)
else:
model_dir = os.path.join(WORKDIR, model_dir_name)
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
gs_hash = (
GSUTIL_PATH
+ GSUTIL_FLAGS
+ tank_url
+ "/"
+ model_dir_name
+ "/hash.npy"
+ " "
+ os.path.join(model_dir, "upstream_hash.npy")
)
if os.system(gs_hash) != 0:
raise Exception("hash of the model not present in the tank.")
upstream_hash = str(
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
if local_hash != upstream_hash:
if shark_args.update_tank == True:
gs_download_model()
else:
print(
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
)
model_dir = os.path.join(WORKDIR, model_dir_name)
with open(
os.path.join(model_dir, model_name + dyn_str + "_torch.mlir"),
mode="rb",
) as f:
mlir_file = f.read()
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
inputs_tuple = tuple([inputs[key] for key in inputs])
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
return mlir_file, function_name, inputs_tuple, golden_out_tuple
# Downloads the tflite model from gs://shark_tank dir.
def download_tflite_model(
model_name, dynamic=False, tank_url="gs://shark_tank/latest"
):
dyn_str = "_dynamic" if dynamic else ""
os.makedirs(WORKDIR, exist_ok=True)
model_dir_name = model_name + "_tflite"
def gs_download_model():
gs_command = (
GSUTIL_PATH
+ GSUTIL_FLAGS
+ tank_url
+ "/"
+ model_dir_name
+ ' "'
+ WORKDIR
+ '"'
)
if os.system(gs_command) != 0:
raise Exception("model not present in the tank. Contact Nod Admin")
model_dir_name = model_name + "_" + frontend
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
if not check_dir_exists(
model_dir_name, frontend="tflite", dynamic=dyn_str
model_dir_name, frontend=frontend, dynamic=dyn_str
):
gs_download_model()
download_public_file(full_gs_url, WORKDIR)
else:
if not _internet_connected():
print(
@@ -203,104 +128,34 @@ def download_tflite_model(
else:
model_dir = os.path.join(WORKDIR, model_dir_name)
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
gs_hash = (
GSUTIL_PATH
+ GSUTIL_FLAGS
+ tank_url
+ "/"
+ model_dir_name
+ "/hash.npy"
+ " "
+ os.path.join(model_dir, "upstream_hash.npy")
gs_hash_url = (
tank_url.rstrip("/") + "/" + model_dir_name + "/hash.npy"
)
download_public_file(
gs_hash_url, os.path.join(model_dir, "upstream_hash.npy")
)
if os.system(gs_hash) != 0:
raise Exception("hash of the model not present in the tank.")
upstream_hash = str(
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
if local_hash != upstream_hash:
if shark_args.update_tank == True:
gs_download_model()
download_public_file(full_gs_url, WORKDIR)
else:
print(
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
)
model_dir = os.path.join(WORKDIR, model_dir_name)
with open(
os.path.join(model_dir, model_name + dyn_str + "_tflite.mlir"),
mode="rb",
) as f:
mlir_file = f.read()
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
inputs_tuple = tuple([inputs[key] for key in inputs])
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
return mlir_file, function_name, inputs_tuple, golden_out_tuple
def download_tf_model(
model_name, tuned=None, tank_url="gs://shark_tank/latest"
):
model_name = model_name.replace("/", "_")
os.makedirs(WORKDIR, exist_ok=True)
model_dir_name = model_name + "_tf"
def gs_download_model():
gs_command = (
GSUTIL_PATH
+ GSUTIL_FLAGS
+ tank_url
+ "/"
+ model_dir_name
+ ' "'
+ WORKDIR
+ '"'
)
if os.system(gs_command) != 0:
raise Exception("model not present in the tank. Contact Nod Admin")
if not check_dir_exists(model_dir_name, frontend="tf"):
gs_download_model()
else:
if not _internet_connected():
print(
"No internet connection. Using the model already present in the tank."
)
else:
model_dir = os.path.join(WORKDIR, model_dir_name)
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
gs_hash = (
GSUTIL_PATH
+ GSUTIL_FLAGS
+ tank_url
+ "/"
+ model_dir_name
+ "/hash.npy"
+ " "
+ os.path.join(model_dir, "upstream_hash.npy")
)
if os.system(gs_hash) != 0:
raise Exception("hash of the model not present in the tank.")
upstream_hash = str(
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
if local_hash != upstream_hash:
if shark_args.update_tank == True:
gs_download_model()
else:
print(
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
)
model_dir = os.path.join(WORKDIR, model_dir_name)
suffix = "_tf.mlir" if tuned is None else "_tf_" + tuned + ".mlir"
suffix = (
"_" + frontend + ".mlir"
if tuned is None
else "_" + frontend + "_" + tuned + ".mlir"
)
filename = os.path.join(model_dir, model_name + suffix)
if not os.path.isfile(filename):
filename = os.path.join(model_dir, model_name + "_tf.mlir")
filename = os.path.join(
model_dir, model_name + "_" + frontend + ".mlir"
)
with open(filename, mode="rb") as f:
mlir_file = f.read()

View File

@@ -1,8 +1,9 @@
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model
from shark.shark_downloader import download_model
mlir_model, func_name, inputs, golden_out = download_torch_model(
"bert-base-uncased_tosa"
mlir_model, func_name, inputs, golden_out = download_model(
"bert-base-uncased_tosa",
frontend="torch",
)
shark_module = SharkInference(

View File

@@ -1,6 +1,6 @@
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_tf_model
from shark.shark_downloader import download_model
from shark.parser import shark_args
from tank.test_utils import get_valid_test_params, shark_test_name_func
from parameterized import parameterized
@@ -21,8 +21,8 @@ class DebertaBaseModuleTester:
self.benchmark = benchmark
def create_and_check_module(self, dynamic, device):
model, func_name, inputs, golden_out = download_tf_model(
"microsoft/deberta-base"
model, func_name, inputs, golden_out = download_model(
"microsoft/deberta-base", frontend="tf"
)
shark_module = SharkInference(

View File

@@ -1,5 +1,5 @@
import numpy as np
from shark.shark_downloader import download_tflite_model
from shark.shark_downloader import download_model
from shark.shark_inference import SharkInference
import pytest
import unittest
@@ -58,8 +58,8 @@ class GptTfliteModuleTester:
shark_args.save_vmfb = self.save_vmfb
# Preprocess to get SharkImporter input args
mlir_model, func_name, inputs, tflite_results = download_tflite_model(
model_name="gpt2-64"
mlir_model, func_name, inputs, tflite_results = download_model(
model_name="gpt2-64", backend="tflite"
)
shark_module = SharkInference(
mlir_module=mlir_model,

View File

@@ -20,10 +20,6 @@ class OPTModuleTester:
self.benchmark = benchmark
def create_and_check_module(self, dynamic, device, model_name):
# model_mlir, func_name, input, act_out = download_torch_model(
# "opt", dynamic
# )
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# config = OPTConfig()
# opt_model = OPTModel(config)

View File

@@ -1,6 +1,6 @@
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_tf_model
from shark.shark_downloader import download_model
from tank.test_utils import get_valid_test_params, shark_test_name_func
from parameterized import parameterized
@@ -18,8 +18,8 @@ class RemBertModuleTester:
self.benchmark = benchmark
def create_and_check_module(self, dynamic, device):
model, func_name, inputs, golden_out = download_tf_model(
"google/rembert"
model, func_name, inputs, golden_out = download_model(
"google/rembert", frontend="tf"
)
shark_module = SharkInference(

View File

@@ -1,6 +1,6 @@
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_tf_model
from shark.shark_downloader import download_model
import iree.compiler as ireec
import unittest
@@ -16,8 +16,9 @@ class TapasBaseModuleTester:
self.benchmark = benchmark
def create_and_check_module(self, dynamic, device):
model, func_name, inputs, golden_out = download_tf_model(
"google/tapas-base"
model, func_name, inputs, golden_out = download_model(
"google/tapas-base",
frontend="tf",
)
shark_module = SharkInference(

View File

@@ -15,7 +15,7 @@ from torchvision.transforms import functional as TF
from tqdm import trange
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model
from shark.shark_downloader import download_model
import numpy as np
import sys
@@ -191,7 +191,9 @@ x_in = x[0:min_batch_size, :, :, :]
ts = x_in.new_ones([x_in.shape[0]])
t_in = t[0] * ts
mlir_model, func_name, inputs, golden_out = download_torch_model("v_diffusion")
mlir_model, func_name, inputs, golden_out = download_model(
"v_diffusion", frontend="torch"
)
shark_module = SharkInference(
mlir_model, func_name, device=args.runtime_device, mlir_dialect="linalg"

View File

@@ -5,11 +5,7 @@ from shark.iree_utils._common import (
)
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
from parameterized import parameterized
from shark.shark_downloader import (
download_tf_model,
download_torch_model,
download_tflite_model,
)
from shark.shark_downloader import download_model
from shark.shark_inference import SharkInference
from shark.parser import shark_args
import iree.compiler as ireec
@@ -133,23 +129,11 @@ class SharkModuleTester:
def create_and_check_module(self, dynamic, device):
shark_args.local_tank_cache = self.local_tank_cache
shark_args.update_tank = self.update_tank
if self.config["framework"] == "tf":
model, func_name, inputs, golden_out = download_tf_model(
self.config["model_name"],
tank_url=self.tank_url,
)
elif self.config["framework"] == "torch":
model, func_name, inputs, golden_out = download_torch_model(
self.config["model_name"],
tank_url=self.tank_url,
)
elif self.config["framework"] == "tflite":
model, func_name, inputs, golden_out = download_tflite_model(
model_name=self.config["model_name"],
tank_url=self.tank_url,
)
else:
model, func_name, inputs, golden_out = None, None, None, None
model, func_name, inputs, golden_out = download_model(
self.config["model_name"],
tank_url=self.tank_url,
frontend=self.config["framework"],
)
shark_module = SharkInference(
model,

View File

@@ -3,7 +3,7 @@ import requests
import torch
from torchvision import transforms
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model
from shark.shark_downloader import download_model
################################## Preprocessing inputs and helper functions ########
@@ -69,8 +69,8 @@ def resnet_inf(numpy_img, device):
if device not in compiled_module.keys():
if DEBUG:
log_write.write("Compiling the Resnet50 module.\n")
mlir_model, func_name, inputs, golden_out = download_torch_model(
"resnet50"
mlir_model, func_name, inputs, golden_out = download_model(
"resnet50", frontend="torch"
)
shark_module = SharkInference(
mlir_model, func_name, device=device, mlir_dialect="linalg"

View File

@@ -36,10 +36,10 @@ def _compile_module(args, shark_module, model_name, extra_args=[]):
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(args, tank_url, model_name, extra_args=[]):
from shark.shark_downloader import download_torch_model
from shark.shark_downloader import download_model
mlir_model, func_name, inputs, golden_out = download_torch_model(
model_name, tank_url=tank_url
mlir_model, func_name, inputs, golden_out = download_model(
model_name, tank_url=tank_url, frontend="torch"
)
shark_module = SharkInference(
mlir_model, func_name, device=args.device, mlir_dialect="linalg"