mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
9 Commits
20230319.6
...
20230320.6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c72a8365b1 | ||
|
|
46d0842459 | ||
|
|
90c958bca2 | ||
|
|
f99903e023 | ||
|
|
c6f44ef1b3 | ||
|
|
8dcd4d5aeb | ||
|
|
d319f4684e | ||
|
|
54d7b6d83e | ||
|
|
4a622532e5 |
@@ -12,6 +12,7 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
schedulers = None
|
||||
@@ -79,6 +80,9 @@ def img2img_inf(
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
global schedulers
|
||||
|
||||
@@ -164,6 +168,7 @@ def img2img_inf(
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
@@ -232,6 +237,7 @@ def img2img_inf(
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
extra_info = {"STRENGTH": strength}
|
||||
text_output = ""
|
||||
for current_batch in range(batch_count):
|
||||
if current_batch > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
@@ -252,23 +258,20 @@ def img2img_inf(
|
||||
cpu_scheduling,
|
||||
use_stencil=use_stencil,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed, extra_info)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={steps}, strength={args.strength}, guidance_scale={guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed, extra_info)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
|
||||
yield generated_imgs, text_output
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -10,6 +10,7 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
schedulers = None
|
||||
@@ -50,6 +51,9 @@ def inpaint_inf(
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
global schedulers
|
||||
|
||||
@@ -114,6 +118,7 @@ def inpaint_inf(
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
@@ -159,6 +164,7 @@ def inpaint_inf(
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
image = image_dict["image"]
|
||||
mask_image = image_dict["mask"]
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
@@ -180,23 +186,20 @@ def inpaint_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
|
||||
yield generated_imgs, text_output
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -10,6 +10,7 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
schedulers = None
|
||||
@@ -53,6 +54,9 @@ def outpaint_inf(
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
global schedulers
|
||||
|
||||
@@ -116,6 +120,7 @@ def outpaint_inf(
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
@@ -163,6 +168,7 @@ def outpaint_inf(
|
||||
top = True if "up" in directions else False
|
||||
bottom = True if "down" in directions else False
|
||||
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
@@ -189,23 +195,20 @@ def outpaint_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
|
||||
yield generated_imgs, text_output
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -9,7 +9,7 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
schedulers = None
|
||||
|
||||
@@ -46,6 +46,9 @@ def txt2img_inf(
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
global schedulers
|
||||
|
||||
@@ -108,6 +111,7 @@ def txt2img_inf(
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
@@ -152,6 +156,7 @@ def txt2img_inf(
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
@@ -169,25 +174,20 @@ def txt2img_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += (
|
||||
f"\nsteps={steps}, guidance_scale={guidance_scale}, seed={seeds}"
|
||||
)
|
||||
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
|
||||
# text_output += txt2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
|
||||
yield generated_imgs, text_output
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -31,6 +31,9 @@ from apps.stable_diffusion.src.utils import (
|
||||
end_profiling,
|
||||
)
|
||||
|
||||
SD_STATE_IDLE = "idle"
|
||||
SD_STATE_CANCEL = "cancel"
|
||||
|
||||
|
||||
class StableDiffusionPipeline:
|
||||
def __init__(
|
||||
@@ -58,6 +61,7 @@ class StableDiffusionPipeline:
|
||||
self.scheduler = scheduler
|
||||
# TODO: Implement using logging python utility.
|
||||
self.log = ""
|
||||
self.status = SD_STATE_IDLE
|
||||
|
||||
def encode_prompts(self, prompts, neg_prompts, max_length):
|
||||
# Tokenize text and get embeddings
|
||||
@@ -226,6 +230,7 @@ class StableDiffusionPipeline:
|
||||
masked_image_latents=None,
|
||||
return_all_latents=False,
|
||||
):
|
||||
self.status = SD_STATE_IDLE
|
||||
step_time_sum = 0
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
@@ -275,6 +280,9 @@ class StableDiffusionPipeline:
|
||||
# )
|
||||
step_time_sum += step_time
|
||||
|
||||
if self.status == SD_STATE_CANCEL:
|
||||
break
|
||||
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
|
||||
@@ -32,4 +32,5 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
get_extended_name,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
get_generation_text_info,
|
||||
)
|
||||
|
||||
@@ -629,3 +629,14 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(new_entry, f, indent=4)
|
||||
|
||||
|
||||
def get_generation_text_info(seeds, device):
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}"
|
||||
|
||||
return text_output
|
||||
|
||||
@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
|
||||
|
||||
@@ -255,5 +256,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
predefined_paint_models,
|
||||
cancel_sd,
|
||||
)
|
||||
|
||||
|
||||
@@ -257,5 +258,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
predefined_paint_models,
|
||||
cancel_sd,
|
||||
)
|
||||
|
||||
|
||||
@@ -277,5 +278,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_files,
|
||||
scheduler_list_txt2img,
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
|
||||
with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
@@ -249,7 +250,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
from apps.stable_diffusion.web.utils.png_metadata import (
|
||||
|
||||
@@ -5,6 +5,10 @@ import glob
|
||||
from pathlib import Path
|
||||
from apps.stable_diffusion.src import args
|
||||
from dataclasses import dataclass
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -89,5 +93,13 @@ def get_custom_model_files():
|
||||
return sorted(ckpt_files, key=str.casefold)
|
||||
|
||||
|
||||
def cancel_sd():
|
||||
# Try catch it, as gc can delete global_obj.sd_obj while switching model
|
||||
try:
|
||||
global_obj.set_sd_status(SD_STATE_CANCEL)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
available_devices = get_available_devices()
|
||||
|
||||
@@ -38,6 +38,16 @@ def get_cfg_obj():
|
||||
return config_obj
|
||||
|
||||
|
||||
def set_sd_status(value):
|
||||
global sd_obj
|
||||
sd_obj.status = value
|
||||
|
||||
|
||||
def get_sd_status():
|
||||
global sd_obj
|
||||
return sd_obj.status
|
||||
|
||||
|
||||
def clear_cache():
|
||||
global sd_obj
|
||||
global config_obj
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
|
||||
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
|
||||
source $GITHUB_WORKSPACE/shark.venv/bin/activate
|
||||
python generate_sharktank.py
|
||||
python tank/generate_sharktank.py
|
||||
|
||||
@@ -33,6 +33,7 @@ lit
|
||||
pyyaml
|
||||
python-dateutil
|
||||
sacremoses
|
||||
sentencepiece
|
||||
|
||||
# web dependecies.
|
||||
gradio
|
||||
|
||||
@@ -35,8 +35,9 @@ def run_cmd(cmd, debug=False):
|
||||
stderr=subprocess.PIPE,
|
||||
check=True,
|
||||
)
|
||||
result_str = result.stdout.decode()
|
||||
return result_str
|
||||
stdout = result.stdout.decode()
|
||||
stderr = result.stderr.decode()
|
||||
return stdout, stderr
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e.output)
|
||||
sys.exit(f"Exiting program due to error running {cmd}")
|
||||
|
||||
@@ -90,6 +90,7 @@ def build_benchmark_args(
|
||||
benchmark_cl.append(f"--task_topology_max_group_count={num_cpus}")
|
||||
# if time_extractor:
|
||||
# benchmark_cl.append(time_extractor)
|
||||
benchmark_cl.append(f"--print_statistics=true")
|
||||
return benchmark_cl
|
||||
|
||||
|
||||
@@ -129,7 +130,8 @@ def build_benchmark_args_non_tensor_input(
|
||||
|
||||
def run_benchmark_module(benchmark_cl):
|
||||
"""
|
||||
Run benchmark command, extract result and return iteration/seconds.
|
||||
Run benchmark command, extract result and return iteration/seconds, host
|
||||
peak memory, and device peak memory.
|
||||
|
||||
# TODO: Add an example of the benchmark command.
|
||||
Input: benchmark command.
|
||||
@@ -138,15 +140,22 @@ def run_benchmark_module(benchmark_cl):
|
||||
assert os.path.exists(
|
||||
benchmark_path
|
||||
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
|
||||
bench_result = run_cmd(" ".join(benchmark_cl))
|
||||
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
|
||||
try:
|
||||
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
|
||||
match = regex_split.search(bench_result)
|
||||
time = float(match.group(1))
|
||||
match = regex_split.search(bench_stdout)
|
||||
time_ms = float(match.group(1))
|
||||
unit = match.group(3)
|
||||
except AttributeError:
|
||||
regex_split = re.compile("(\d+[.]*\d*)([a-zA-Z]+)")
|
||||
match = regex_split.search(bench_result)
|
||||
time = float(match.group(1))
|
||||
match = regex_split.search(bench_stdout)
|
||||
time_ms = float(match.group(1))
|
||||
unit = match.group(2)
|
||||
return 1.0 / (time * 0.001)
|
||||
iter_per_second = 1.0 / (time_ms * 0.001)
|
||||
|
||||
# Extract peak memory.
|
||||
host_regex = re.compile(r".*HOST_LOCAL:\s*([0-9]+)B peak")
|
||||
host_peak_b = int(host_regex.search(bench_stderr).group(1))
|
||||
device_regex = re.compile(r".*DEVICE_LOCAL:\s*([0-9]+)B peak")
|
||||
device_peak_b = int(device_regex.search(bench_stderr).group(1))
|
||||
return iter_per_second, host_peak_b, device_peak_b
|
||||
|
||||
@@ -188,21 +188,23 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
|
||||
benchmark_bash.write(" ".join(benchmark_cl))
|
||||
benchmark_bash.close()
|
||||
|
||||
benchmark_data = run_benchmark_module(benchmark_cl)
|
||||
iter_per_second, _, _ = run_benchmark_module(
|
||||
benchmark_cl
|
||||
)
|
||||
|
||||
benchmark_file = open(
|
||||
f"{bench_dir}/{d_}/{d_}_data.txt", "w+"
|
||||
)
|
||||
benchmark_file.write(f"DISPATCH: {d_}\n")
|
||||
benchmark_file.write(str(benchmark_data) + "\n")
|
||||
benchmark_file.write(str(iter_per_second) + "\n")
|
||||
benchmark_file.write(
|
||||
"SHARK BENCHMARK RESULT: "
|
||||
+ str(1 / (benchmark_data * 0.001))
|
||||
+ str(1 / (iter_per_second * 0.001))
|
||||
+ "\n"
|
||||
)
|
||||
benchmark_file.close()
|
||||
|
||||
benchmark_runtimes[d_] = 1 / (benchmark_data * 0.001)
|
||||
benchmark_runtimes[d_] = 1 / (iter_per_second * 0.001)
|
||||
|
||||
elif ".mlir" in f_ and "benchmark" not in f_:
|
||||
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
|
||||
|
||||
@@ -22,7 +22,8 @@ from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
|
||||
|
||||
def get_vulkan_device_name():
|
||||
vulkaninfo_dump = run_cmd("vulkaninfo").split(linesep)
|
||||
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
|
||||
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
|
||||
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
|
||||
@@ -21,9 +21,17 @@ from shark.iree_utils.benchmark_utils import (
|
||||
from shark.parser import shark_args
|
||||
from datetime import datetime
|
||||
import time
|
||||
from typing import Optional
|
||||
import csv
|
||||
import os
|
||||
|
||||
TF_CPU_DEVICE = "/CPU:0"
|
||||
TF_GPU_DEVICE = "/GPU:0"
|
||||
|
||||
|
||||
def _bytes_to_mb_str(bytes_: Optional[int]) -> str:
|
||||
return "" if bytes_ is None else f"{bytes_ / 1e6:.6f}"
|
||||
|
||||
|
||||
class OnnxFusionOptions(object):
|
||||
def __init__(self):
|
||||
@@ -126,18 +134,26 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
frontend_model.forward(input)
|
||||
|
||||
if self.device == "cuda":
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = frontend_model.forward(input)
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
break
|
||||
end = time.time()
|
||||
if self.device == "cuda":
|
||||
stats = torch.cuda.memory_stats()
|
||||
device_peak_b = stats["allocated_bytes.all.peak"]
|
||||
else:
|
||||
device_peak_b = None
|
||||
|
||||
print(
|
||||
f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
"", # host_peak_b (CPU usage) is not reported by PyTorch.
|
||||
_bytes_to_mb_str(device_peak_b),
|
||||
]
|
||||
|
||||
def benchmark_tf(self, modelname):
|
||||
@@ -155,8 +171,8 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
|
||||
from tank.model_utils_tf import get_tf_model
|
||||
|
||||
# tf_device = "/GPU:0" if self.device == "cuda" else "/CPU:0"
|
||||
tf_device = "/CPU:0"
|
||||
# tf_device = TF_GPU_DEVICE if self.device == "cuda" else TF_CPU_DEVICE
|
||||
tf_device = TF_CPU_DEVICE
|
||||
with tf.device(tf_device):
|
||||
(
|
||||
model,
|
||||
@@ -169,24 +185,41 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
frontend_model.forward(*input)
|
||||
|
||||
if tf_device == TF_GPU_DEVICE:
|
||||
tf.config.experimental.reset_memory_stats(tf_device)
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = frontend_model.forward(*input)
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
break
|
||||
end = time.time()
|
||||
if tf_device == TF_GPU_DEVICE:
|
||||
memory_info = tf.config.experimental.get_memory_info(tf_device)
|
||||
device_peak_b = memory_info["peak"]
|
||||
else:
|
||||
# tf.config.experimental does not currently support measuring
|
||||
# CPU memory usage.
|
||||
device_peak_b = None
|
||||
|
||||
print(
|
||||
f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
"", # host_peak_b (CPU usage) is not reported by TensorFlow.
|
||||
_bytes_to_mb_str(device_peak_b),
|
||||
]
|
||||
|
||||
def benchmark_c(self):
|
||||
result = run_benchmark_module(self.benchmark_cl)
|
||||
print(f"Shark-IREE-C benchmark:{result} iter/second")
|
||||
return [f"{result}", f"{1000/result}"]
|
||||
iter_per_second, host_peak_b, device_peak_b = run_benchmark_module(
|
||||
self.benchmark_cl
|
||||
)
|
||||
print(f"Shark-IREE-C benchmark:{iter_per_second} iter/second")
|
||||
return [
|
||||
f"{iter_per_second}",
|
||||
f"{1000/iter_per_second}",
|
||||
_bytes_to_mb_str(host_peak_b),
|
||||
_bytes_to_mb_str(device_peak_b),
|
||||
]
|
||||
|
||||
def benchmark_python(self, inputs):
|
||||
input_list = [x for x in inputs]
|
||||
@@ -196,8 +229,7 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = self.run("forward", input_list)
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
end = time.time()
|
||||
print(
|
||||
f"Shark-IREE Python benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
@@ -324,7 +356,12 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
"tags",
|
||||
"notes",
|
||||
"datetime",
|
||||
"host_memory_mb",
|
||||
"device_memory_mb",
|
||||
"measured_host_memory_mb",
|
||||
"measured_device_memory_mb",
|
||||
]
|
||||
# "frontend" must be the first element.
|
||||
engines = ["frontend", "shark_python", "shark_iree_c"]
|
||||
if shark_args.onnx_bench == True:
|
||||
engines.append("onnxruntime")
|
||||
@@ -336,75 +373,76 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
|
||||
with open("bench_results.csv", mode="a", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=field_names)
|
||||
bench_result = {}
|
||||
bench_result["model"] = modelname
|
||||
bench_info = {}
|
||||
bench_info["model"] = modelname
|
||||
bench_info["dialect"] = self.mlir_dialect
|
||||
bench_info["iterations"] = shark_args.num_iterations
|
||||
if dynamic == True:
|
||||
bench_result["shape_type"] = "dynamic"
|
||||
bench_info["shape_type"] = "dynamic"
|
||||
else:
|
||||
bench_result["shape_type"] = "static"
|
||||
bench_result["device"] = device_str
|
||||
bench_info["shape_type"] = "static"
|
||||
bench_info["device"] = device_str
|
||||
if "fp16" in modelname:
|
||||
bench_result["data_type"] = "float16"
|
||||
bench_info["data_type"] = "float16"
|
||||
else:
|
||||
bench_result["data_type"] = inputs[0].dtype
|
||||
bench_info["data_type"] = inputs[0].dtype
|
||||
|
||||
for e in engines:
|
||||
(
|
||||
bench_result["param_count"],
|
||||
bench_result["tags"],
|
||||
bench_result["notes"],
|
||||
) = ["", "", ""]
|
||||
engine_result = {}
|
||||
if e == "frontend":
|
||||
bench_result["engine"] = frontend
|
||||
engine_result["engine"] = frontend
|
||||
if check_requirements(frontend):
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
engine_result["iter/sec"],
|
||||
engine_result["ms/iter"],
|
||||
engine_result["host_memory_mb"],
|
||||
engine_result["device_memory_mb"],
|
||||
) = self.benchmark_frontend(modelname)
|
||||
self.frontend_result = bench_result["ms/iter"]
|
||||
bench_result["vs. PyTorch/TF"] = "baseline"
|
||||
self.frontend_result = engine_result["ms/iter"]
|
||||
engine_result["vs. PyTorch/TF"] = "baseline"
|
||||
(
|
||||
bench_result["param_count"],
|
||||
bench_result["tags"],
|
||||
bench_result["notes"],
|
||||
engine_result["param_count"],
|
||||
engine_result["tags"],
|
||||
engine_result["notes"],
|
||||
) = self.get_metadata(modelname)
|
||||
else:
|
||||
self.frontend_result = None
|
||||
continue
|
||||
|
||||
elif e == "shark_python":
|
||||
bench_result["engine"] = "shark_python"
|
||||
engine_result["engine"] = "shark_python"
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
engine_result["iter/sec"],
|
||||
engine_result["ms/iter"],
|
||||
) = self.benchmark_python(inputs)
|
||||
|
||||
bench_result[
|
||||
engine_result[
|
||||
"vs. PyTorch/TF"
|
||||
] = self.compare_bench_results(
|
||||
self.frontend_result, bench_result["ms/iter"]
|
||||
self.frontend_result, engine_result["ms/iter"]
|
||||
)
|
||||
|
||||
elif e == "shark_iree_c":
|
||||
bench_result["engine"] = "shark_iree_c"
|
||||
engine_result["engine"] = "shark_iree_c"
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
engine_result["iter/sec"],
|
||||
engine_result["ms/iter"],
|
||||
engine_result["host_memory_mb"],
|
||||
engine_result["device_memory_mb"],
|
||||
) = self.benchmark_c()
|
||||
|
||||
bench_result[
|
||||
engine_result[
|
||||
"vs. PyTorch/TF"
|
||||
] = self.compare_bench_results(
|
||||
self.frontend_result, bench_result["ms/iter"]
|
||||
self.frontend_result, engine_result["ms/iter"]
|
||||
)
|
||||
|
||||
elif e == "onnxruntime":
|
||||
bench_result["engine"] = "onnxruntime"
|
||||
engine_result["engine"] = "onnxruntime"
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
engine_result["iter/sec"],
|
||||
engine_result["ms/iter"],
|
||||
) = self.benchmark_onnx(modelname, inputs)
|
||||
|
||||
bench_result["dialect"] = self.mlir_dialect
|
||||
bench_result["iterations"] = shark_args.num_iterations
|
||||
bench_result["datetime"] = str(datetime.now())
|
||||
writer.writerow(bench_result)
|
||||
engine_result["datetime"] = str(datetime.now())
|
||||
writer.writerow(bench_info | engine_result)
|
||||
|
||||
@@ -194,8 +194,14 @@ def download_model(
|
||||
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
|
||||
filename = os.path.join(model_dir, model_name + suffix)
|
||||
|
||||
with open(filename, mode="rb") as f:
|
||||
mlir_file = f.read()
|
||||
try:
|
||||
with open(filename, mode="rb") as f:
|
||||
mlir_file = f.read()
|
||||
except FileNotFoundError:
|
||||
from tank.generate_sharktank import gen_shark_files
|
||||
|
||||
tank_dir = WORKDIR
|
||||
gen_shark_files(model_name, frontend, tank_dir)
|
||||
|
||||
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
|
||||
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
|
||||
|
||||
@@ -35,3 +35,14 @@ squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","mac
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
|
||||
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,"",""
|
||||
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,"",""
|
||||
gpt2,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
|
||||
|
@@ -33,9 +33,10 @@ def create_hash(file_name):
|
||||
return file_hash.hexdigest()
|
||||
|
||||
|
||||
def save_torch_model(torch_model_list):
|
||||
def save_torch_model(torch_model_list, local_tank_cache):
|
||||
from tank.model_utils import (
|
||||
get_hf_model,
|
||||
get_hf_seq2seq_model,
|
||||
get_vision_model,
|
||||
get_hf_img_cls_model,
|
||||
get_fp16_model,
|
||||
@@ -59,7 +60,7 @@ def save_torch_model(torch_model_list):
|
||||
args.use_tuned = False
|
||||
args.import_mlir = True
|
||||
args.use_tuned = False
|
||||
args.local_tank_cache = WORKDIR
|
||||
args.local_tank_cache = local_tank_cache
|
||||
|
||||
precision_values = ["fp16"]
|
||||
seq_lengths = [64, 77]
|
||||
@@ -75,7 +76,7 @@ def save_torch_model(torch_model_list):
|
||||
height=512,
|
||||
use_base_vae=False,
|
||||
debug=True,
|
||||
sharktank_dir=WORKDIR,
|
||||
sharktank_dir=local_tank_cache,
|
||||
generate_vmfb=False,
|
||||
)
|
||||
model()
|
||||
@@ -84,13 +85,15 @@ def save_torch_model(torch_model_list):
|
||||
model, input, _ = get_vision_model(torch_model_name)
|
||||
elif model_type == "hf":
|
||||
model, input, _ = get_hf_model(torch_model_name)
|
||||
elif model_type == "hf_seq2seq":
|
||||
model, input, _ = get_hf_seq2seq_model(torch_model_name)
|
||||
elif model_type == "hf_img_cls":
|
||||
model, input, _ = get_hf_img_cls_model(torch_model_name)
|
||||
elif model_type == "fp16":
|
||||
model, input, _ = get_fp16_model(torch_model_name)
|
||||
torch_model_name = torch_model_name.replace("/", "_")
|
||||
torch_model_dir = os.path.join(
|
||||
WORKDIR, str(torch_model_name) + "_torch"
|
||||
local_tank_cache, str(torch_model_name) + "_torch"
|
||||
)
|
||||
os.makedirs(torch_model_dir, exist_ok=True)
|
||||
|
||||
@@ -115,12 +118,14 @@ def save_torch_model(torch_model_list):
|
||||
)
|
||||
|
||||
|
||||
def save_tf_model(tf_model_list):
|
||||
def save_tf_model(tf_model_list, local_tank_cache):
|
||||
from tank.model_utils_tf import (
|
||||
get_causal_image_model,
|
||||
get_masked_lm_model,
|
||||
get_causal_lm_model,
|
||||
get_keras_model,
|
||||
get_TFhf_model,
|
||||
get_tfhf_seq2seq_model,
|
||||
)
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -145,16 +150,22 @@ def save_tf_model(tf_model_list):
|
||||
input = None
|
||||
print(f"Generating artifacts for model {tf_model_name}")
|
||||
if model_type == "hf":
|
||||
model, input, _ = get_causal_lm_model(tf_model_name)
|
||||
if model_type == "img":
|
||||
model, input, _ = get_masked_lm_model(tf_model_name)
|
||||
elif model_type == "img":
|
||||
model, input, _ = get_causal_image_model(tf_model_name)
|
||||
if model_type == "keras":
|
||||
elif model_type == "keras":
|
||||
model, input, _ = get_keras_model(tf_model_name)
|
||||
if model_type == "TFhf":
|
||||
elif model_type == "TFhf":
|
||||
model, input, _ = get_TFhf_model(tf_model_name)
|
||||
elif model_type == "tfhf_seq2seq":
|
||||
model, input, _ = get_tfhf_seq2seq_model(tf_model_name)
|
||||
elif model_type == "hf_causallm":
|
||||
model, input, _ = get_causal_lm_model(tf_model_name)
|
||||
|
||||
tf_model_name = tf_model_name.replace("/", "_")
|
||||
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
|
||||
tf_model_dir = os.path.join(
|
||||
local_tank_cache, str(tf_model_name) + "_tf"
|
||||
)
|
||||
os.makedirs(tf_model_dir, exist_ok=True)
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
@@ -172,7 +183,7 @@ def save_tf_model(tf_model_list):
|
||||
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
|
||||
|
||||
|
||||
def save_tflite_model(tflite_model_list):
|
||||
def save_tflite_model(tflite_model_list, local_tank_cache):
|
||||
from shark.tflite_utils import TFLitePreprocessor
|
||||
|
||||
with open(tflite_model_list) as csvfile:
|
||||
@@ -184,7 +195,7 @@ def save_tflite_model(tflite_model_list):
|
||||
print("tflite_model_name", tflite_model_name)
|
||||
print("tflite_model_link", tflite_model_link)
|
||||
tflite_model_name_dir = os.path.join(
|
||||
WORKDIR, str(tflite_model_name) + "_tflite"
|
||||
local_tank_cache, str(tflite_model_name) + "_tflite"
|
||||
)
|
||||
os.makedirs(tflite_model_name_dir, exist_ok=True)
|
||||
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
|
||||
@@ -219,6 +230,45 @@ def save_tflite_model(tflite_model_list):
|
||||
)
|
||||
|
||||
|
||||
def gen_shark_files(modelname, frontend, tank_dir):
    """Generate the artifacts for a single model on the fly.

    Called when a model's artifacts are requested by shark_downloader but
    they don't exist in the cloud. The model's row is looked up in the
    frontend's model list (a .csv that lives next to this file), copied into
    a temporary single-entry .csv, and that file is handed to the matching
    ``save_*_model`` function.

    Args:
        modelname: Value compared against the first column of the model list.
        frontend: ``"tf"`` or ``"torch"``. Any other value is a no-op.
        tank_dir: Directory passed to the save function as its tank cache.

    Raises:
        ValueError: If ``modelname`` has no row in the frontend's model list.
    """
    # TODO: Add TFlite support.
    import tempfile

    here = os.path.dirname(__file__)
    if frontend == "tf":
        model_list = os.path.join(here, "tf_model_list.csv")
    elif frontend == "torch":
        model_list = os.path.join(here, "torch_model_list.csv")
    else:
        # Unsupported frontends were silently ignored before; keep that.
        return

    # No early exit: when a model is listed more than once the last row wins,
    # which is what the original scan did. Blank lines are skipped.
    target = None
    with open(model_list, mode="r") as src:
        for row in csv.reader(src):
            if row and row[0] == modelname:
                target = row
    if target is None:
        raise ValueError(
            f"Model '{modelname}' was not found in {model_list}; "
            "cannot generate its artifacts."
        )

    save_model = save_tf_model if frontend == "tf" else save_torch_model

    # Create a temporary .csv with only the desired entry. The file is
    # written through the handle that created it, closed before the save
    # function re-opens it by name, and always removed afterwards.
    custom_model_csv = tempfile.NamedTemporaryFile(
        mode="w",
        newline="",
        dir=here,
        delete=False,
    )
    try:
        with custom_model_csv:
            writer = csv.writer(custom_model_csv)
            writer.writerow(["modelname", "src"])
            writer.writerow(target)
        save_model(custom_model_csv.name, tank_dir)
    finally:
        os.remove(custom_model_csv.name)
|
||||
|
||||
|
||||
# Validates whether the file is present or not.
|
||||
def is_valid_file(arg):
|
||||
if not os.path.exists(arg):
|
||||
@@ -259,20 +309,19 @@ if __name__ == "__main__":
|
||||
# old_args = parser.parse_args()
|
||||
|
||||
home = str(Path.home())
|
||||
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
|
||||
WORKDIR = os.path.join(os.path.dirname(__file__), "..", "gen_shark_tank")
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tank", "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tank", "tf_model_list.csv"
|
||||
os.path.dirname(__file__), "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
|
||||
tflite_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tank", "tflite", "tflite_model_list.csv"
|
||||
os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
|
||||
)
|
||||
|
||||
save_torch_model(
|
||||
os.path.join(os.path.dirname(__file__), "tank", "torch_sd_list.csv")
|
||||
os.path.join(os.path.dirname(__file__), "tank", "torch_sd_list.csv"),
|
||||
WORKDIR,
|
||||
)
|
||||
save_torch_model(torch_model_csv)
|
||||
save_tf_model(tf_model_csv)
|
||||
save_tflite_model(tflite_model_csv)
|
||||
save_torch_model(torch_model_csv, WORKDIR)
|
||||
save_tf_model(tf_model_csv, WORKDIR)
|
||||
save_tflite_model(tflite_model_csv, WORKDIR)
|
||||
@@ -31,4 +31,10 @@ xlm-roberta-base,False,False,-,-,-
|
||||
facebook/convnext-tiny-224,False,False,-,-,-
|
||||
efficientnet-v2-s,False,False,22M,"image-classification,cnn","Includes MBConv and Fused-MBConv"
|
||||
mnasnet1_0,False,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
|
||||
bert-large-uncased,True,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
t5-base,True,False,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
t5-large,True,False,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
efficientnet_b0,True,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
|
||||
efficientnet_b7,True,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
|
||||
gpt2,True,False,110M,"nlp;transformer-decoder;auto-regressive","12 layers, 768 hidden units, 12 attention heads"
|
||||
|
||||
|
@@ -7,6 +7,8 @@ import sys
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
BATCH_SIZE = 1
|
||||
|
||||
vision_models = [
|
||||
"alexnet",
|
||||
"resnet101",
|
||||
@@ -17,6 +19,8 @@ vision_models = [
|
||||
"wide_resnet50_2",
|
||||
"mobilenet_v3_small",
|
||||
"mnasnet1_0",
|
||||
"efficientnet_b0",
|
||||
"efficientnet_b7",
|
||||
]
|
||||
hf_img_cls_models = [
|
||||
"google/vit-base-patch16-224",
|
||||
@@ -25,6 +29,10 @@ hf_img_cls_models = [
|
||||
"microsoft/beit-base-patch16-224-pt22k-ft22k",
|
||||
"nvidia/mit-b0",
|
||||
]
|
||||
hf_seq2seq_models = [
|
||||
"t5-base",
|
||||
"t5-large",
|
||||
]
|
||||
|
||||
|
||||
def get_torch_model(modelname):
|
||||
@@ -32,6 +40,8 @@ def get_torch_model(modelname):
|
||||
return get_vision_model(modelname)
|
||||
elif modelname in hf_img_cls_models:
|
||||
return get_hf_img_cls_model(modelname)
|
||||
elif modelname in hf_seq2seq_models:
|
||||
return get_hf_seq2seq_model(modelname)
|
||||
elif "fp16" in modelname:
|
||||
return get_fp16_model(modelname)
|
||||
else:
|
||||
@@ -85,6 +95,7 @@ def get_hf_img_cls_model(name):
|
||||
# test_input = torch.FloatTensor(1, 3, 224, 224).uniform_(-1, 1)
|
||||
# print("test_input.shape: ", test_input.shape)
|
||||
# test_input.shape: torch.Size([1, 3, 224, 224])
|
||||
test_input = test_input.repeat(BATCH_SIZE, 1, 1, 1)
|
||||
actual_out = model(test_input)
|
||||
# print("actual_out.shape: ", actual_out.shape)
|
||||
# actual_out.shape: torch.Size([1, 1000])
|
||||
@@ -121,11 +132,52 @@ def get_hf_model(name):
|
||||
|
||||
model = HuggingFaceLanguage(name)
|
||||
# TODO: Currently the test input is set to (1,128)
|
||||
test_input = torch.randint(2, (1, 128))
|
||||
test_input = torch.randint(2, (BATCH_SIZE, 128))
|
||||
actual_out = model(test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
##################### Hugging Face Seq2SeqLM Models ###################################
|
||||
|
||||
# We use a maximum sequence length of 512 since this is the default used in the T5 config.
|
||||
T5_MAX_SEQUENCE_LENGTH = 512
|
||||
|
||||
|
||||
class HFSeq2SeqLanguageModel(torch.nn.Module):
    """Torch module bundling a Hugging Face ``T5Model`` with its tokenizer.

    ``preprocess_input`` tokenizes text with the stored tokenization kwargs
    (padding to a multiple of ``T5_MAX_SEQUENCE_LENGTH``, PyTorch tensors);
    ``forward`` returns the first element of the wrapped model's output.
    """

    def __init__(self, model_name):
        super().__init__()
        # Imported lazily so transformers is only needed when a model is built.
        from transformers import AutoTokenizer, T5Model

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenization_kwargs = dict(
            pad_to_multiple_of=T5_MAX_SEQUENCE_LENGTH,
            padding=True,
            return_tensors="pt",
        )
        self.model = T5Model.from_pretrained(model_name, return_dict=True)

    def preprocess_input(self, text):
        """Tokenize ``text`` using the kwargs stored at construction."""
        tokenize = self.tokenizer
        return tokenize(text, **self.tokenization_kwargs)

    def forward(self, input_ids, decoder_input_ids):
        """Run the wrapped model and return the first entry of its output."""
        outputs = self.model.forward(
            input_ids, decoder_input_ids=decoder_input_ids
        )
        return outputs[0]
|
||||
|
||||
|
||||
def get_hf_seq2seq_model(name):
    """Build a seq2seq model wrapper together with sample inputs and output.

    Args:
        name: Model name forwarded to ``HFSeq2SeqLanguageModel``.

    Returns:
        A ``(model, test_input, actual_out)`` tuple where ``test_input`` is
        ``(encoder input ids, right-shifted decoder input ids)`` and
        ``actual_out`` is the result of ``model.forward(*test_input)``.
    """
    m = HFSeq2SeqLanguageModel(name)

    # Replicate the prompts so the sample inputs follow BATCH_SIZE, matching
    # the other model getters in this file.
    encoder_text = [
        "Studies have been shown that owning a dog is good for you"
    ] * BATCH_SIZE
    decoder_text = ["Studies show that"] * BATCH_SIZE

    encoded_input_ids = m.preprocess_input(encoder_text).input_ids
    decoder_input_ids = m.preprocess_input(decoder_text).input_ids
    decoder_input_ids = m.model._shift_right(decoder_input_ids)

    test_input = (encoded_input_ids, decoder_input_ids)
    actual_out = m.forward(*test_input)
    return m, test_input, actual_out
|
||||
|
||||
|
||||
################################################################################
|
||||
|
||||
##################### Torch Vision Models ###################################
|
||||
@@ -144,24 +196,50 @@ class VisionModule(torch.nn.Module):
|
||||
def get_vision_model(torch_model):
|
||||
import torchvision.models as models
|
||||
|
||||
default_image_size = (224, 224)
|
||||
|
||||
vision_models_dict = {
|
||||
"alexnet": models.alexnet(weights="DEFAULT"),
|
||||
"resnet18": models.resnet18(weights="DEFAULT"),
|
||||
"resnet50": models.resnet50(weights="DEFAULT"),
|
||||
"resnet50_fp16": models.resnet50(weights="DEFAULT"),
|
||||
"resnet101": models.resnet101(weights="DEFAULT"),
|
||||
"squeezenet1_0": models.squeezenet1_0(weights="DEFAULT"),
|
||||
"wide_resnet50_2": models.wide_resnet50_2(weights="DEFAULT"),
|
||||
"mobilenet_v3_small": models.mobilenet_v3_small(weights="DEFAULT"),
|
||||
"mnasnet1_0": models.mnasnet1_0(weights="DEFAULT"),
|
||||
"alexnet": (models.alexnet(weights="DEFAULT"), default_image_size),
|
||||
"resnet18": (models.resnet18(weights="DEFAULT"), default_image_size),
|
||||
"resnet50": (models.resnet50(weights="DEFAULT"), default_image_size),
|
||||
"resnet50_fp16": (
|
||||
models.resnet50(weights="DEFAULT"),
|
||||
default_image_size,
|
||||
),
|
||||
"resnet101": (models.resnet101(weights="DEFAULT"), default_image_size),
|
||||
"squeezenet1_0": (
|
||||
models.squeezenet1_0(weights="DEFAULT"),
|
||||
default_image_size,
|
||||
),
|
||||
"wide_resnet50_2": (
|
||||
models.wide_resnet50_2(weights="DEFAULT"),
|
||||
default_image_size,
|
||||
),
|
||||
"mobilenet_v3_small": (
|
||||
models.mobilenet_v3_small(weights="DEFAULT"),
|
||||
default_image_size,
|
||||
),
|
||||
"mnasnet1_0": (
|
||||
models.mnasnet1_0(weights="DEFAULT"),
|
||||
default_image_size,
|
||||
),
|
||||
# EfficientNet input image size varies on the size of the model.
|
||||
"efficientnet_b0": (
|
||||
models.efficientnet_b0(weights="DEFAULT"),
|
||||
(224, 224),
|
||||
),
|
||||
"efficientnet_b7": (
|
||||
models.efficientnet_b7(weights="DEFAULT"),
|
||||
(600, 600),
|
||||
),
|
||||
}
|
||||
if isinstance(torch_model, str):
|
||||
fp16_model = None
|
||||
if "fp16" in torch_model:
|
||||
fp16_model = True
|
||||
torch_model = vision_models_dict[torch_model]
|
||||
torch_model, input_image_size = vision_models_dict[torch_model]
|
||||
model = VisionModule(torch_model)
|
||||
test_input = torch.randn(1, 3, 224, 224)
|
||||
test_input = torch.randn(BATCH_SIZE, 3, 224, 224)
|
||||
actual_out = model(test_input)
|
||||
if fp16_model is not None:
|
||||
test_input_fp16 = test_input.to(
|
||||
@@ -209,6 +287,7 @@ def get_fp16_model(torch_model):
|
||||
model = BertHalfPrecisionModel(modelname)
|
||||
tokenizer = AutoTokenizer.from_pretrained(modelname)
|
||||
text = "Replace me by any text you like."
|
||||
text = [text] * BATCH_SIZE
|
||||
test_input_fp16 = tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
|
||||
@@ -7,11 +7,15 @@ from transformers import (
|
||||
)
|
||||
|
||||
BATCH_SIZE = 1
|
||||
MAX_SEQUENCE_LENGTH = 128
|
||||
|
||||
################################## MHLO/TF models #########################################
|
||||
# TODO : Generate these lists or fetch model source from tank/tf/tf_model_list.csv
|
||||
keras_models = ["resnet50", "efficientnet-v2-s"]
|
||||
keras_models = [
|
||||
"resnet50",
|
||||
"efficientnet_b0",
|
||||
"efficientnet_b7",
|
||||
"efficientnet-v2-s",
|
||||
]
|
||||
maskedlm_models = [
|
||||
"albert-base-v2",
|
||||
"bert-base-uncased",
|
||||
@@ -32,9 +36,16 @@ maskedlm_models = [
|
||||
"hf-internal-testing/tiny-random-flaubert",
|
||||
"xlm-roberta",
|
||||
]
|
||||
causallm_models = [
|
||||
"gpt2",
|
||||
]
|
||||
tfhf_models = [
|
||||
"microsoft/MiniLM-L12-H384-uncased",
|
||||
]
|
||||
tfhf_seq2seq_models = [
|
||||
"t5-base",
|
||||
"t5-large",
|
||||
]
|
||||
img_models = [
|
||||
"google/vit-base-patch16-224",
|
||||
"facebook/convnext-tiny-224",
|
||||
@@ -45,23 +56,35 @@ def get_tf_model(name):
|
||||
if name in keras_models:
|
||||
return get_keras_model(name)
|
||||
elif name in maskedlm_models:
|
||||
return get_masked_lm_model(name)
|
||||
elif name in causallm_models:
|
||||
return get_causal_lm_model(name)
|
||||
elif name in tfhf_models:
|
||||
return get_TFhf_model(name)
|
||||
elif name in img_models:
|
||||
return get_causal_image_model(name)
|
||||
elif name in tfhf_seq2seq_models:
|
||||
return get_tfhf_seq2seq_model(name)
|
||||
else:
|
||||
raise Exception(
|
||||
"TF model not found! Please check that the modelname has been input correctly."
|
||||
)
|
||||
|
||||
|
||||
##################### Tensorflow Hugging Face LM Models ###################################
|
||||
##################### Tensorflow Hugging Face Bert Models ###################################
|
||||
BERT_MAX_SEQUENCE_LENGTH = 128
|
||||
|
||||
# Create a set of 2-dimensional inputs
|
||||
tf_bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -87,21 +110,31 @@ def get_TFhf_model(name):
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
text = "Replace me by any text you'd like."
|
||||
text = [text] * BATCH_SIZE
|
||||
encoded_input = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
for key in encoded_input:
|
||||
encoded_input[key] = tf.expand_dims(
|
||||
tf.convert_to_tensor(encoded_input[key]), 0
|
||||
)
|
||||
test_input = (
|
||||
encoded_input["input_ids"],
|
||||
encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"],
|
||||
max_length=BERT_MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
test_input = [
|
||||
tf.reshape(
|
||||
tf.convert_to_tensor(encoded_input["input_ids"], dtype=tf.int32),
|
||||
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
|
||||
),
|
||||
tf.reshape(
|
||||
tf.convert_to_tensor(
|
||||
encoded_input["attention_mask"], dtype=tf.int32
|
||||
),
|
||||
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
|
||||
),
|
||||
tf.reshape(
|
||||
tf.convert_to_tensor(
|
||||
encoded_input["token_type_ids"], dtype=tf.int32
|
||||
),
|
||||
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
|
||||
),
|
||||
]
|
||||
actual_out = model.forward(*test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
@@ -115,34 +148,41 @@ def compare_tensors_tf(tf_tensor, numpy_tensor):
|
||||
return np.allclose(tf_to_numpy, numpy_tensor, rtol, atol)
|
||||
|
||||
|
||||
##################### Tensorflow Hugging Face Masked LM Models ###################################
|
||||
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
|
||||
import tensorflow as tf
|
||||
|
||||
# Create a set of input signature.
|
||||
input_signature_maskedlm = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
# For supported models please see here:
|
||||
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForCasualLM
|
||||
|
||||
|
||||
# Tokenizer for language models
|
||||
def preprocess_input(
|
||||
model_name, text="This is just used to compile the model"
|
||||
model_name, max_length, text="This is just used to compile the model"
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
text = [text] * BATCH_SIZE
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
return_tensors="tf",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
max_length=max_length,
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
##################### Tensorflow Hugging Face Masked LM Models ###################################
|
||||
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
|
||||
import tensorflow as tf
|
||||
|
||||
MASKED_LM_MAX_SEQUENCE_LENGTH = 128
|
||||
|
||||
# Create a set of input signature.
|
||||
input_signature_maskedlm = [
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, MASKED_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, MASKED_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# For supported models please see here:
|
||||
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForMaskedLM
|
||||
class MaskedLM(tf.Module):
|
||||
def __init__(self, model_name):
|
||||
super(MaskedLM, self).__init__()
|
||||
@@ -156,19 +196,139 @@ class MaskedLM(tf.Module):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
|
||||
def get_masked_lm_model(hf_name, text="Hello, this is the default text."):
|
||||
model = MaskedLM(hf_name)
|
||||
encoded_input = preprocess_input(hf_name, text)
|
||||
encoded_input = preprocess_input(
|
||||
hf_name, MASKED_LM_MAX_SEQUENCE_LENGTH, text
|
||||
)
|
||||
test_input = (encoded_input["input_ids"], encoded_input["attention_mask"])
|
||||
actual_out = model.forward(*test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
##################### Tensorflow Hugging Face Causal LM Models ###################################
|
||||
|
||||
from transformers import AutoConfig, TFAutoModelForCausalLM, TFGPT2Model
|
||||
|
||||
CAUSAL_LM_MAX_SEQUENCE_LENGTH = 1024
|
||||
|
||||
input_signature_causallm = [
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, CAUSAL_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, CAUSAL_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# For supported models please see here:
|
||||
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForCausalLM
|
||||
# For more background, see:
|
||||
# https://huggingface.co/blog/tf-xla-generate
|
||||
class CausalLM(tf.Module):
    """tf.Module wrapping a Hugging Face ``TFGPT2Model`` and its tokenizer.

    The tokenizer pads on the left with ``</s>`` and pads to a multiple of
    ``CAUSAL_LM_MAX_SEQUENCE_LENGTH``; ``forward`` is traced with
    ``input_signature_causallm`` and ``jit_compile=True``.
    """

    def __init__(self, model_name):
        super().__init__()
        # Decoder-only models need left padding.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, padding_side="left", pad_token="</s>"
        )
        self.tokenization_kwargs = dict(
            pad_to_multiple_of=CAUSAL_LM_MAX_SEQUENCE_LENGTH,
            padding=True,
            return_tensors="tf",
        )
        self.model = TFGPT2Model.from_pretrained(model_name, return_dict=True)

        def _first_output(ids, mask):
            # Keep only the first element of the model's output.
            return self.model(input_ids=ids, attention_mask=mask)[0]

        self.model.predict = _first_output

    def preprocess_input(self, text):
        """Tokenize ``text`` using the kwargs stored at construction."""
        tokenize = self.tokenizer
        return tokenize(text, **self.tokenization_kwargs)

    @tf.function(input_signature=input_signature_causallm, jit_compile=True)
    def forward(self, input_ids, attention_mask):
        """Run the wrapped model through the ``predict`` hook set in __init__."""
        result = self.model.predict(input_ids, attention_mask)
        return result
|
||||
|
||||
|
||||
def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
    """Build a ``CausalLM`` plus sample inputs and the output it produces.

    Returns a ``(model, test_input, actual_out)`` tuple; ``test_input`` is
    ``(input_ids, attention_mask)`` for ``text`` repeated ``BATCH_SIZE`` times.
    """
    causal_lm = CausalLM(hf_name)
    # One copy of the prompt per batch element.
    tokens = causal_lm.preprocess_input([text for _ in range(BATCH_SIZE)])
    test_input = (tokens["input_ids"], tokens["attention_mask"])
    return causal_lm, test_input, causal_lm.forward(*test_input)
|
||||
|
||||
|
||||
##################### Tensorflow Hugging Face Seq2SeqLM Models ###################################
|
||||
|
||||
# We use a maximum sequence length of 512 since this is the default used in the T5 config.
|
||||
T5_MAX_SEQUENCE_LENGTH = 512
|
||||
|
||||
input_signature_t5 = [
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, T5_MAX_SEQUENCE_LENGTH],
|
||||
dtype=tf.int32,
|
||||
name="input_ids",
|
||||
),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, T5_MAX_SEQUENCE_LENGTH],
|
||||
dtype=tf.int32,
|
||||
name="attention_mask",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class TFHFSeq2SeqLanguageModel(tf.Module):
    """tf.Module wrapping a Hugging Face ``TFT5Model`` and its tokenizer.

    ``preprocess_input`` tokenizes text with the stored kwargs (padding to a
    multiple of ``T5_MAX_SEQUENCE_LENGTH``, TensorFlow tensors). ``forward``
    is traced with ``input_signature_t5`` and ``jit_compile=True`` and returns
    the first element of the model's output.
    """

    def __init__(self, model_name):
        super(TFHFSeq2SeqLanguageModel, self).__init__()
        # Only the names actually used below are imported.
        from transformers import AutoTokenizer, TFT5Model

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenization_kwargs = {
            "pad_to_multiple_of": T5_MAX_SEQUENCE_LENGTH,
            "padding": True,
            "return_tensors": "tf",
        }
        self.model = TFT5Model.from_pretrained(model_name, return_dict=True)
        # x is the encoder input ids, y the decoder input ids.
        self.model.predict = lambda x, y: self.model(x, decoder_input_ids=y)[0]

    def preprocess_input(self, text):
        """Tokenize ``text`` using the kwargs stored at construction."""
        return self.tokenizer(text, **self.tokenization_kwargs)

    # NOTE(review): the second TensorSpec in input_signature_t5 is named
    # "attention_mask", but the tensor passed here is decoder_input_ids.
    # Confirm whether the spec name should be updated.
    @tf.function(input_signature=input_signature_t5, jit_compile=True)
    def forward(self, input_ids, decoder_input_ids):
        """Run the wrapped model through the ``predict`` hook set in __init__."""
        return self.model.predict(input_ids, decoder_input_ids)
|
||||
|
||||
|
||||
def get_tfhf_seq2seq_model(name):
    """Build a ``TFHFSeq2SeqLanguageModel`` plus sample inputs and output.

    Returns a ``(model, test_input, actual_out)`` tuple; ``test_input`` is
    ``(encoder input ids, right-shifted decoder input ids)``, each built from
    a fixed prompt repeated ``BATCH_SIZE`` times.
    """
    seq2seq = TFHFSeq2SeqLanguageModel(name)

    encoder_prompts = [
        "Studies have been shown that owning a dog is good for you"
    ] * BATCH_SIZE
    encoded_input_ids = seq2seq.preprocess_input(encoder_prompts).input_ids

    decoder_prompts = ["Studies show that"] * BATCH_SIZE
    decoder_input_ids = seq2seq.model._shift_right(
        seq2seq.preprocess_input(decoder_prompts).input_ids
    )

    test_input = (encoded_input_ids, decoder_input_ids)
    return seq2seq, test_input, seq2seq.forward(*test_input)
|
||||
|
||||
|
||||
##################### TensorFlow Keras Resnet Models #########################################################
|
||||
# Static shape, including batch size (1).
|
||||
# Can be dynamic once dynamic shape support is ready.
|
||||
RESNET_INPUT_SHAPE = [1, 224, 224, 3]
|
||||
EFFICIENTNET_INPUT_SHAPE = [1, 384, 384, 3]
|
||||
RESNET_INPUT_SHAPE = [BATCH_SIZE, 224, 224, 3]
|
||||
EFFICIENTNET_V2_S_INPUT_SHAPE = [BATCH_SIZE, 384, 384, 3]
|
||||
EFFICIENTNET_B0_INPUT_SHAPE = [BATCH_SIZE, 224, 224, 3]
|
||||
EFFICIENTNET_B7_INPUT_SHAPE = [BATCH_SIZE, 600, 600, 3]
|
||||
|
||||
|
||||
class ResNetModule(tf.Module):
|
||||
@@ -195,25 +355,79 @@ class ResNetModule(tf.Module):
|
||||
return tf.keras.applications.resnet50.preprocess_input(image)
|
||||
|
||||
|
||||
class EfficientNetModule(tf.Module):
|
||||
class EfficientNetB0Module(tf.Module):
|
||||
def __init__(self):
|
||||
super(EfficientNetModule, self).__init__()
|
||||
self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
|
||||
super(EfficientNetB0Module, self).__init__()
|
||||
self.m = tf.keras.applications.efficientnet.EfficientNetB0(
|
||||
weights="imagenet",
|
||||
include_top=True,
|
||||
input_shape=tuple(EFFICIENTNET_INPUT_SHAPE[1:]),
|
||||
input_shape=tuple(EFFICIENTNET_B0_INPUT_SHAPE[1:]),
|
||||
)
|
||||
self.m.predict = lambda x: self.m.call(x, training=False)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[tf.TensorSpec(EFFICIENTNET_INPUT_SHAPE, tf.float32)],
|
||||
input_signature=[
|
||||
tf.TensorSpec(EFFICIENTNET_B0_INPUT_SHAPE, tf.float32)
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def forward(self, inputs):
|
||||
return self.m.predict(inputs)
|
||||
|
||||
def input_shape(self):
|
||||
return EFFICIENTNET_INPUT_SHAPE
|
||||
return EFFICIENTNET_B0_INPUT_SHAPE
|
||||
|
||||
def preprocess_input(self, image):
|
||||
return tf.keras.applications.efficientnet.preprocess_input(image)
|
||||
|
||||
|
||||
class EfficientNetB7Module(tf.Module):
    """tf.Module around Keras ``EfficientNetB7`` with ImageNet weights.

    The model is built with ``include_top=True`` for the non-batch dimensions
    of ``EFFICIENTNET_B7_INPUT_SHAPE``; ``forward`` is traced for exactly that
    shape with ``jit_compile=True``.
    """

    def __init__(self):
        super().__init__()
        # Drop the leading batch dimension; Keras wants only the image dims.
        _, *image_dims = EFFICIENTNET_B7_INPUT_SHAPE
        self.m = tf.keras.applications.efficientnet.EfficientNetB7(
            weights="imagenet",
            include_top=True,
            input_shape=tuple(image_dims),
        )
        # Replace predict with a plain call that runs with training=False.
        self.m.predict = lambda x: self.m.call(x, training=False)

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=EFFICIENTNET_B7_INPUT_SHAPE, dtype=tf.float32)
        ],
        jit_compile=True,
    )
    def forward(self, inputs):
        """Run the model through the ``predict`` hook set in __init__."""
        outputs = self.m.predict(inputs)
        return outputs

    def input_shape(self):
        """Return the static input shape, batch dimension included."""
        return EFFICIENTNET_B7_INPUT_SHAPE

    def preprocess_input(self, image):
        """Apply the Keras EfficientNet preprocessing function to ``image``."""
        preprocess = tf.keras.applications.efficientnet.preprocess_input
        return preprocess(image)
|
||||
|
||||
|
||||
class EfficientNetV2SModule(tf.Module):
    """tf.Module around Keras ``EfficientNetV2S`` with ImageNet weights.

    The model is built with ``include_top=True`` for the non-batch dimensions
    of ``EFFICIENTNET_V2_S_INPUT_SHAPE``; ``forward`` is traced for exactly
    that shape with ``jit_compile=True``.
    """

    def __init__(self):
        super().__init__()
        # Drop the leading batch dimension; Keras wants only the image dims.
        _, *image_dims = EFFICIENTNET_V2_S_INPUT_SHAPE
        self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
            weights="imagenet",
            include_top=True,
            input_shape=tuple(image_dims),
        )
        # Replace predict with a plain call that runs with training=False.
        self.m.predict = lambda x: self.m.call(x, training=False)

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=EFFICIENTNET_V2_S_INPUT_SHAPE, dtype=tf.float32)
        ],
        jit_compile=True,
    )
    def forward(self, inputs):
        """Run the model through the ``predict`` hook set in __init__."""
        outputs = self.m.predict(inputs)
        return outputs

    def input_shape(self):
        """Return the static input shape, batch dimension included."""
        return EFFICIENTNET_V2_S_INPUT_SHAPE

    def preprocess_input(self, image):
        """Apply the Keras EfficientNetV2 preprocessing function to ``image``."""
        preprocess = tf.keras.applications.efficientnet_v2.preprocess_input
        return preprocess(image)
|
||||
@@ -224,12 +438,17 @@ def load_image(path_to_image, width, height, channels):
|
||||
image = tf.image.decode_image(image, channels=channels)
|
||||
image = tf.image.resize(image, (width, height))
|
||||
image = image[tf.newaxis, :]
|
||||
image = tf.tile(image, [BATCH_SIZE, 1, 1, 1])
|
||||
return image
|
||||
|
||||
|
||||
def get_keras_model(modelname):
|
||||
if modelname == "efficientnet-v2-s":
|
||||
model = EfficientNetModule()
|
||||
model = EfficientNetV2SModule()
|
||||
elif modelname == "efficientnet_b0":
|
||||
model = EfficientNetB0Module()
|
||||
elif modelname == "efficientnet_b7":
|
||||
model = EfficientNetB7Module()
|
||||
else:
|
||||
model = ResNetModule()
|
||||
|
||||
@@ -256,7 +475,7 @@ import requests
|
||||
|
||||
# Create a set of input signature.
|
||||
input_signature_img_cls = [
|
||||
tf.TensorSpec(shape=[1, 3, 224, 224], dtype=tf.float32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, 3, 224, 224], dtype=tf.float32),
|
||||
]
|
||||
|
||||
|
||||
@@ -304,6 +523,9 @@ def preprocess_input_image(model_name):
|
||||
)
|
||||
# inputs: {'pixel_values': <tf.Tensor: shape=(1, 3, 224, 224), dtype=float32, numpy=array([[[[]]]], dtype=float32)>}
|
||||
inputs = feature_extractor(images=image, return_tensors="tf")
|
||||
inputs["pixel_values"] = tf.tile(
|
||||
inputs["pixel_values"], [BATCH_SIZE, 1, 1, 1]
|
||||
)
|
||||
|
||||
return [inputs[str(*inputs)]]
|
||||
|
||||
|
||||
@@ -19,3 +19,8 @@ facebook/convnext-tiny-224,img
|
||||
google/vit-base-patch16-224,img
|
||||
efficientnet-v2-s,keras
|
||||
bert-large-uncased,hf
|
||||
t5-base,tfhf_seq2seq
|
||||
t5-large,tfhf_seq2seq
|
||||
efficientnet_b0,keras
|
||||
efficientnet_b7,keras
|
||||
gpt2,hf_causallm
|
||||
|
||||
|
@@ -18,4 +18,8 @@ nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encod
|
||||
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
|
||||
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
t5-base,True,hf_seq2seq,True,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
t5-large,True,hf_seq2seq,True,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
|
||||
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
|
||||
|
||||
|
Reference in New Issue
Block a user