Compare commits

..

9 Commits

Author SHA1 Message Date
Boian Petkantchin
20114deea0 In MiniLM JAX example verify MLIR result against JAX 2023-05-16 09:54:07 -07:00
Boian Petkantchin
9acf519078 Add option to skip venv creation in setup script 2023-05-16 09:54:07 -07:00
Boian Petkantchin
bdf37b5311 If device/backend is unknown pass it to IREE verbatim 2023-05-16 09:54:07 -07:00
powderluv
8ee2ac89f8 Rename sharded_vicuna_fp32_web.py to vicuna_web.py 2023-05-16 09:41:35 -07:00
powderluv
60cb48be2e Rename sharded_vicuna_fp32.py to vicuna.py 2023-05-16 09:40:51 -07:00
powderluv
86a215b063 Delete sharded_vicunia.py 2023-05-16 09:37:39 -07:00
powderluv
d6e3a9a236 Delete standalone_vicuna.py 2023-05-16 09:37:26 -07:00
Chi_Liu
a0097a1ead Add mlir_type for torch_model_list.csv (#1428)
- Enable stablehlo/tosa mlir output for torch model
- Add BERT stablehlo support
2023-05-15 10:23:54 -07:00
Ean Garvey
a9bae00606 Fix vulkan device selection at compile time and adapt to IREE python changes. (#1407)
* Add support for vulkan device selection at compile time.

* Don't convert device ID to int and fix .exe imports
2023-05-12 23:31:50 -07:00
15 changed files with 98 additions and 1057 deletions

View File

@@ -1,303 +0,0 @@
import torch
import argparse
import torch_mlir
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
from shark.shark_importer import transform_fx as transform_fx_
import re
from shark.shark_inference import SharkInference
from tqdm import tqdm
parser = argparse.ArgumentParser(
prog="ProgramName",
description="What the program does",
epilog="Text at the bottom of help",
)
parser.add_argument("--precision", "-p", default="fp32", help="fp32, fp16")
parser.add_argument(
"--device", "-d", default="vulkan", help="vulkan, cpu, cuda"
)
class VicunaLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states, attention_mask, position_ids):
outputs = self.model(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
)
next_hidden_states = outputs[0]
return next_hidden_states
class CompiledVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value=None,
output_attentions=False,
use_cache=False,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
),
)
print(output)
output = torch.tensor(output)
return (output,)
class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers):
super().__init__()
self.model = model
assert len(layers) == len(model.model.layers)
self.model.model.layers = torch.nn.modules.container.ModuleList(layers)
self.model.model.config.use_cache = False
self.model.model.config.output_attentions = False
def forward(self, input_ids, attention_mask=None):
return self.model.forward(input_ids, attention_mask=attention_mask)
def compile_vicuna_layer(
vicuna_layer, hidden_states, attention_mask, position_ids
):
fx_g = make_fx(
vicuna_layer,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(hidden_states, attention_mask, position_ids)
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
transform_fx(fx_g)
if args.precision == "fp16":
fx_g = fx_g.half()
fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
was_unwrapped = _unwrap_single_tuple_return(fx_g)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
print("FX_G recompile")
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
return ts_g
path = "TheBloke/vicuna-7B-1.1-HF"
kwargs = {"torch_dtype": torch.float32}
vicuna_model = AutoModelForCausalLM.from_pretrained(
path, low_cpu_mem_usage=True, **kwargs
)
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
print(type(vicuna_model.model.layers))
def compile_to_vmfb(inputs, layers):
mlirs, modules = [], []
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
mlir_path = Path(f"{idx}.mlir")
if mlir_path.exists():
# print(f"Found layer {idx} mlir")
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
else:
print(f"Compiling layer {idx} mlir")
ts_g = compile_vicuna_layer(layer, inputs[0], inputs[1], inputs[2])
module = torch_mlir.compile(
ts_g,
inputs,
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
mlirs.append(bytecode)
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
device = args.device if idx < 25 else "cpu"
vmfb_path = Path(f"{idx}.vmfb")
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module("", f"{idx}")
module.load_module(vmfb_path)
modules.append(module)
return mlirs, modules
if __name__ == "__main__":
args = parser.parse_args()
# prompt = input("Enter Prompt: ")
dtype = torch.float32 if args.precision == "fp32" else torch.float16
placeholder_input = (
torch.zeros([1, 256, 4096], dtype=dtype),
torch.zeros([1, 1, 256, 256], dtype=dtype),
torch.zeros([1, 256], dtype=torch.int64),
)
_, modules = compile_to_vmfb(placeholder_input, vicuna_model.model.layers)
shark_layers = [CompiledVicunaLayer(m) for m in modules]
sharded_model = ShardedVicunaModel(vicuna_model, shark_layers)
prompt = "It was a dark and stormy"
prompt = prompt.strip()
input_ids = tokenizer(prompt).input_ids
original_input_ids = input_ids
input_id_len = len(input_ids)
pad_len = 256 - input_id_len
attention_mask = torch.ones([1, input_id_len], dtype=torch.int64)
input_ids = torch.nn.functional.pad(
torch.tensor(input_ids), (0, pad_len), mode="constant", value=259
)
input_ids = input_ids.reshape([1, 256])
attention_mask = torch.nn.functional.pad(
torch.tensor(attention_mask),
(0, pad_len),
mode="constant",
value=0,
)
# print(input_ids)
if args.precision == "fp16":
input_ids = input_ids.to(torch.float16)
print(attention_mask)
logits = sharded_model.forward(input_ids, attention_mask=attention_mask)[
"logits"
]
print(logits)

View File

@@ -1,695 +0,0 @@
import torch
import torch_mlir
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
from shark.shark_importer import transform_fx as transform_fx_
import re
def get_tank_vicuna_mlir(num):
# name can be 1 or 2 for first and second vicuna model
mname = {1: "FirstVicuna", 2: "SecondVicuna"}
tank_url = "gs://shark_tank/FastChat/"
download_public_file(tank_url, mname[num])
print(f"Downloaded model : {mname[num]} from tank")
def get_torch_mlir_module_bytecode(model, model_inputs):
fx_g = make_fx(
model,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(*model_inputs)
print("Got FX_G")
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def transform_fx(fx_g):
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.empty,
]:
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
fx_g.graph.lint()
transform_fx(fx_g)
fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
was_unwrapped = _unwrap_single_tuple_return(fx_g)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
print("FX_G recompile")
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
print("Got TS_G")
return ts_g
def compile_vicuna(model, model_inputs, model_name, model_vmfb_name):
# ADD Device Arg
from shark.shark_inference import SharkInference
vmfb_path = Path(model_vmfb_name + ".vmfb")
if vmfb_path.exists():
shark_module = SharkInference(
None, device="cuda", mlir_dialect="tm_tensor"
)
shark_module.load_module(vmfb_path)
return shark_module
mlir_path = Path(model_name + ".mlir")
print(
f"[DEBUG] mlir path { mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
# model_inputs = list(model_inputs)
# model_inputs[0] = torch_mlir.TensorPlaceholder.like(model_inputs[0], dynamic_axes=[1])
# model_inputs = tuple(model_inputs)
module = torch_mlir.compile(
ts_graph,
[*model_inputs],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
def remove_constant_dim(line):
if "19x" in line:
line = re.sub("19x", "?x", line)
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
)
if "arith.cmpi" in line:
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
return line
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(model_name + ".mlir", "wb")
f_.write(bytecode)
print("Saved mlir")
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device="cuda", mlir_dialect="tm_tensor"
)
# shark_module.compile()
import os
path = shark_module.save_module(os.getcwd(), model_vmfb_name, [])
print("Saved vmfb at ", str(path))
return shark_module
kwargs = {"torch_dtype": torch.float32} # 16
model_path = "TheBloke/vicuna-7B-1.1-HF"
# Requires input_ids as tensor(1x40)
class FirstVicuna(torch.nn.Module):
def __init__(self, model_path):
super().__init__()
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
) # .cuda().half()
def forward(self, input_ids, attention_mask):
# input_len = input_id_len
# input_ids = input_ids[:,:input_len].reshape([1,input_len])
op = self.model(
input_ids=input_ids, use_cache=True, attention_mask=attention_mask
)
return_vals = []
return_vals.append(op.logits)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
# Uncomment this after verifying that SecondVicuna compiles as well.
# Might have to cast to_numpy.
# Requires input_ids as tensor(1x1),
# past_key_values = 32 length tuple containing tuple of tensor pairs, which is same as output
# of firstVicuna[1:]
class SecondVicuna_(torch.nn.Module):
def __init__(self, model_path):
super().__init__()
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
def forward(self, input_tuple):
# input_ids = input_tuple[0]
# input_tuple = torch.unbind(pkv, dim=0)
past_key_values = [
(
input_tuple[i],
input_tuple[i + 1],
)
for i in range(0, len(input_tuple) - 1, 2)
]
# for e1, e2 in zip(input_tuple, input_tuple[1:]):
# past_key_values.append(tuple(e1, e2))
past_key_values = tuple(past_key_values)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
return_vals = []
return_vals.append(op.logits)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class SecondVicuna(torch.nn.Module):
def __init__(self, model_path):
super().__init__()
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
) # .cuda().half()
def forward(
self,
i0,
i1,
i2,
i3,
i4,
i5,
i6,
i7,
i8,
i9,
i10,
i11,
i12,
i13,
i14,
i15,
i16,
i17,
i18,
i19,
i20,
i21,
i22,
i23,
i24,
i25,
i26,
i27,
i28,
i29,
i30,
i31,
i32,
i33,
i34,
i35,
i36,
i37,
i38,
i39,
i40,
i41,
i42,
i43,
i44,
i45,
i46,
i47,
i48,
i49,
i50,
i51,
i52,
i53,
i54,
i55,
i56,
i57,
i58,
i59,
i60,
i61,
i62,
i63,
i64,
):
# input_ids = input_tuple[0]
# input_tuple = torch.unbind(pkv, dim=0)
token = i0
past_key_values = (
(i1, i2),
(
i3,
i4,
),
(
i5,
i6,
),
(
i7,
i8,
),
(
i9,
i10,
),
(
i11,
i12,
),
(
i13,
i14,
),
(
i15,
i16,
),
(
i17,
i18,
),
(
i19,
i20,
),
(
i21,
i22,
),
(
i23,
i24,
),
(
i25,
i26,
),
(
i27,
i28,
),
(
i29,
i30,
),
(
i31,
i32,
),
(
i33,
i34,
),
(
i35,
i36,
),
(
i37,
i38,
),
(
i39,
i40,
),
(
i41,
i42,
),
(
i43,
i44,
),
(
i45,
i46,
),
(
i47,
i48,
),
(
i49,
i50,
),
(
i51,
i52,
),
(
i53,
i54,
),
(
i55,
i56,
),
(
i57,
i58,
),
(
i59,
i60,
),
(
i61,
i62,
),
(
i63,
i64,
),
)
# for e1, e2 in zip(input_tuple, input_tuple[1:]):
# past_key_values.append(tuple(e1, e2))
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
return_vals = []
return_vals.append(op.logits)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class wrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids):
pkv = [
torch.rand([1, 32, 40, 128], dtype=torch.float32)
for _ in range(64)
]
return self.model(input_ids, past_key_values=pkv)
if __name__ == "__main__":
import sys
vicuna_number = 1
# input_tuple = (torch.ones([1,1], dtype=torch.int),) + tuple(torch.rand([1, 32, 40, 128], dtype=torch.float32) for _ in range(64))
# input_tuple = torch.rand([1,2])
# secondVicuna = SecondVicuna(model_path)
# shark_second_vicuna = compile_vicuna(secondVicuna, (input_tuple,), "second_vicuna.mlir", "second_vicuna")
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
# prompt = "INPUT: The SQL command to extract all the users whose name starts with A is:"
prompt = "".join(["0" for _ in range(254)])
input_ids = tokenizer(prompt).input_ids
# print("Got input_ids from the tokenizer")
if vicuna_number == 1:
prompt = input("Enter Prompt: ")
prompt = prompt.strip()
input_ids = tokenizer(prompt).input_ids
original_input_ids = input_ids
input_id_len = len(input_ids)
pad_len = 256 - input_id_len
attention_mask = torch.ones([1, input_id_len], dtype=torch.int64)
input_ids = torch.nn.functional.pad(
torch.tensor(input_ids), (0, pad_len), mode="constant", value=259
)
input_ids = input_ids.reshape([1, 256])
attention_mask = torch.nn.functional.pad(
torch.tensor(attention_mask),
(0, pad_len),
mode="constant",
value=0,
)
firstVicuna = FirstVicuna(model_path)
prompt2 = "".join(["0" for _ in range(254)])
input_ids2 = tokenizer(prompt2).input_ids
input_ids2 = torch.tensor(input_ids2).reshape([1, 256])
# firstVicunaInput = tuple([torch.as_tensor([input_ids])])#.cuda()
# firstVicunaCompileInput = (input_ids2, torch.tensor([input_id_len]))
firstVicunaCompileInput = (input_ids2, attention_mask)
len_ = int(torch.tensor([input_id_len]))
# firstVicunaInput = (input_ids,int(torch.tensor([input_id_len])), )
firstVicunaInput = (
input_ids,
attention_mask,
)
shark_first_vicuna = compile_vicuna(
firstVicuna,
firstVicunaCompileInput,
"first_vicuna",
"first_vicuna",
)
# input_ids = torch.tensor(input_ids)
# output_first_vicuna = shark_first_vicuna("forward", (input_ids.reshape([1, input_ids.shape[0]]),))
output_first_vicuna = shark_first_vicuna("forward", firstVicunaInput)
output_first_vicuna_tensor = torch.tensor(output_first_vicuna[1:])
torch.save(output_first_vicuna_tensor, "outpt_first_vicuna_tensor.pt")
logits_first_vicuna = torch.tensor(output_first_vicuna[0])
torch.save(logits_first_vicuna, "logits_first_vicuna_tensor.pt")
# output_non_shark_first_vicuna = firstVicuna.forward(firstVicunaInput[0])
for i in range(40):
original_input_ids.append(
torch.argmax(logits_first_vicuna[:, len_ + i - 1, :], dim=1)
)
print(
torch.argmax(logits_first_vicuna[:, len_ + i - 1, :], dim=1),
tokenizer.decode(
torch.argmax(
logits_first_vicuna[:, len_ + i - 1, :], dim=1
)
),
)
input_id_len = len(original_input_ids)
pad_len = 256 - input_id_len
attention_mask = torch.ones([1, input_id_len], dtype=torch.int64)
input_ids = torch.nn.functional.pad(
torch.tensor(original_input_ids),
(0, pad_len),
mode="constant",
value=259,
)
input_ids = input_ids.reshape([1, 256])
attention_mask = torch.nn.functional.pad(
torch.tensor(attention_mask),
(0, pad_len),
mode="constant",
value=0,
)
firstVicunaInput = (
input_ids,
attention_mask,
)
output_first_vicuna = shark_first_vicuna(
"forward", firstVicunaInput
)
output_first_vicuna_tensor = torch.tensor(output_first_vicuna[1:])
logits_first_vicuna = torch.tensor(output_first_vicuna[0])
print(
tokenizer.decode(
torch.argmax(logits_first_vicuna[:, len_ - 1, :], dim=1)
)
)
if vicuna_number == 2:
# last_token_logits = output_first_vicuna[0][0][-1]
# print("SHARK firstVicuna = ", str(last_token_logits))
# print("NonSHARK firstVicuna = ", str(output_non_shark_first_vicuna[0][0][-1]))
# temperature = 0.7
# probs = torch.softmax(torch.tensor(last_token_logits / temperature, dim=-1))
# token = torch.tensor(int(torch.multinomial(probs, num_samples=1))).reshape([1,1])
# token = torch.ones([1,1], dtype=torch.int64)#.cuda()
# pkvt = []
# for i in range(64):
# pkvt.append(torch.randn(1, 32, 40, 128, dtype=torch.float32))
# pkvt = tuple(pkvt)
# token = torch.ones([1,1], dtype=torch.int64)#.cuda()
output_first_vicuna = torch.load("outpt_first_vicuna_tensor.pt")
logits_first_vicuna = torch.load("logits_first_vicuna_tensor.pt")
print(logits_first_vicuna.shape)
for i in range(logits_first_vicuna.shape[1]):
token = torch.argmax(
torch.tensor(logits_first_vicuna)[:, i, :], dim=1
).reshape([1, 1])
print(token, tokenizer.decode(token[0][0]))
token = torch.argmax(
torch.tensor(logits_first_vicuna)[:, 8, :], dim=1
).reshape([1, 1])
print(logits_first_vicuna)
print(torch.tensor(logits_first_vicuna)[:, -1, :])
print(token, tokenizer.decode(token[0][0]))
result = [tokenizer.decode(token[0][0])]
pkvt = tuple(torch.tensor(x) for x in output_first_vicuna)
# pkv = torch.stack(pkvt, dim=0)
secondVicuna = SecondVicuna(model_path)
# del shark_first_vicuna
# del output_first_vicuna
# torch.cuda.empty_cache()
shark_second_vicuna = compile_vicuna(
secondVicuna, (token,) + pkvt, "second_vicuna", "second_vicuna"
)
print(len(pkvt))
output_second_vicuna = shark_second_vicuna("forward", (token,) + pkvt)
import time
f_ = open("all-logit-outputs.txt", "w+")
print(output_second_vicuna[0].shape)
for _ in range(10):
f_.write(
f"{_}:------------------------------------------------------------------------\n"
)
t1 = time.time()
start_point = output_second_vicuna[1].shape[2] - 256
for j in range(output_second_vicuna[0].shape[1]):
token_test = torch.argmax(
torch.tensor(output_second_vicuna[0])[:, j, :], dim=1
).reshape([1, 1])
sym = token_test, tokenizer.decode(token_test[0][0])
f_.write(f"{i}: {token_test} | {sym}")
token = torch.argmax(
torch.tensor(output_second_vicuna[0])[:, -1, :], dim=1
).reshape([1, 1])
# print(token, tokenizer.decode(token[0][0]))
result.append(tokenizer.decode(token[0][0]))
truncated_outputs = tuple(
x[:, :, :256, :] for x in output_second_vicuna[1:]
)
output_second_vicuna = shark_second_vicuna(
"forward", (token,) + truncated_outputs
)
# print(f"Token Generated in {time.time() - t1} seconds")
f_.write("\n")
f_.close()
print(result)

View File

@@ -2,9 +2,10 @@
# Sets up a venv suitable for running samples.
# e.g:
# ./setup_venv.sh #setup a default $PYTHON3 shark.venv
# Environment Variables by the script.
# Environment variables used by the script.
# PYTHON=$PYTHON3.10 ./setup_venv.sh #pass a version of $PYTHON to use
# VENV_DIR=myshark.venv #create a venv called myshark.venv
# SKIP_VENV=1 #Don't create and activate a Python venv. Use the current environment.
# USE_IREE=1 #use stock IREE instead of Nod.ai's SHARK build
# IMPORTER=1 #Install importer deps
# BENCHMARK=1 #Install benchmark deps
@@ -26,15 +27,17 @@ PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; prin
echo "Python: $PYTHON"
echo "Python version: $PYTHON_VERSION_X_Y"
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
VENV_DIR=${VENV_DIR:-shark.venv}
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
PYTHON="$(which python3)"
else
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
if [[ "$SKIP_VENV" != "1" ]]; then
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
VENV_DIR=${VENV_DIR:-shark.venv}
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
PYTHON="$(which python3)"
else
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
fi
fi
Red=`tput setaf 1`
@@ -147,8 +150,7 @@ if [[ ! -z "${ONNX}" ]]; then
fi
fi
if [[ -z "${CONDA_PREFIX}" ]]; then
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
echo "${Green}Before running examples activate venv with:"
echo " ${Green}source $VENV_DIR/bin/activate"
fi

View File

@@ -1,7 +1,7 @@
from transformers import AutoTokenizer, FlaxAutoModel
import torch
import jax
from typing import Union, Dict, List
from typing import Union, Dict, List, Any
import numpy as np
from shark.shark_inference import SharkInference
import io
@@ -36,18 +36,38 @@ def get_sample_input():
)
def export_to_mlir(sample_input: NumpyTree):
model = FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
model_mlir = jax.jit(model).lower(**sample_input).compiler_ir()
return str(model_mlir).encode()
def get_jax_model():
return FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
def export_jax_to_mlir(jax_model: Any, sample_input: NumpyTree):
model_mlir = jax.jit(jax_model).lower(**sample_input).compiler_ir()
byte_stream = io.BytesIO()
model_mlir.operation.write_bytecode(file=byte_stream)
return byte_stream.getvalue()
def assert_array_list_allclose(x, y, *args, **kwargs):
assert len(x) == len(y)
for a, b in zip(x, y):
np.testing.assert_allclose(
np.asarray(a), np.asarray(b), *args, **kwargs
)
sample_input = get_sample_input()
mlir = export_to_mlir(sample_input)
jax_model = get_jax_model()
mlir = export_jax_to_mlir(jax_model, sample_input)
# Compile and load module.
shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo")
shark_inference.compile()
# Run main function.
print(shark_inference("main", jax.tree_util.tree_flatten(sample_input)[0]))
result = shark_inference("main", jax.tree_util.tree_flatten(sample_input)[0])
# Run JAX model.
reference_result = jax.tree_util.tree_flatten(jax_model(**sample_input))[0]
# Verify result.
assert_array_list_allclose(result, reference_result, atol=1e-5)

View File

@@ -1,5 +1,6 @@
flax
jax[cpu]
nodai-SHARK
orbax
transformers
torch

View File

@@ -45,10 +45,15 @@ def run_cmd(cmd, debug=False):
def iree_device_map(device):
uri_parts = device.split("://", 2)
iree_driver = (
_IREE_DEVICE_MAP[uri_parts[0]]
if uri_parts[0] in _IREE_DEVICE_MAP
else uri_parts[0]
)
if len(uri_parts) == 1:
return _IREE_DEVICE_MAP[uri_parts[0]]
return iree_driver
else:
return f"{_IREE_DEVICE_MAP[uri_parts[0]]}://{uri_parts[1]}"
return f"{iree_driver}://{uri_parts[1]}"
def get_supported_device_list():
@@ -68,7 +73,7 @@ _IREE_DEVICE_MAP = {
def iree_target_map(device):
if "://" in device:
device = device.split("://")[0]
return _IREE_TARGET_MAP[device]
return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device
_IREE_TARGET_MAP = {
@@ -110,10 +115,8 @@ def check_device_drivers(device):
subprocess.check_output("rocminfo")
except Exception:
return True
# Unknown device.
else:
return True
# Unknown device. We assume drivers are installed.
return False

View File

@@ -23,6 +23,7 @@ import re
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
print("Configuring for device:" + device)
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan"]:
@@ -30,6 +31,9 @@ def get_iree_device_args(device, extra_args=[]):
f"Specific device selection only supported for vulkan now."
f"Proceeding with {device} as device."
)
device_num = device_uri[1]
else:
device_num = 0
if device_uri[0] == "cpu":
from shark.iree_utils.cpu_utils import get_iree_cpu_args
@@ -42,7 +46,9 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] in ["metal", "vulkan"]:
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(extra_args=extra_args)
return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args

View File

@@ -21,7 +21,7 @@ from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
def get_vulkan_device_name():
def get_vulkan_device_name(device_num=0):
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
@@ -31,8 +31,8 @@ def get_vulkan_device_name():
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing first one: {vulkaninfo_list[0]}")
return vulkaninfo_list[0]
print(f"Choosing device: {vulkaninfo_list[device_num]}")
return vulkaninfo_list[device_num]
def get_os_name():
@@ -119,14 +119,14 @@ def get_vulkan_target_triple(device_name):
return triple
def get_vulkan_triple_flag(device_name="", extra_args=[]):
def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
for flag in extra_args:
if "-iree-vulkan-target-triple=" in flag:
print(f"Using target triple {flag.split('=')[1]}")
return None
if device_name == "" or device_name == [] or device_name is None:
vulkan_device = get_vulkan_device_name()
vulkan_device = get_vulkan_device_name(device_num=device_num)
else:
vulkan_device = device_name
triple = get_vulkan_target_triple(vulkan_device)
@@ -144,7 +144,7 @@ def get_vulkan_triple_flag(device_name="", extra_args=[]):
return None
def get_iree_vulkan_args(extra_args=[]):
def get_iree_vulkan_args(device_num=0, extra_args=[]):
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
res_vulkan_flag = []
@@ -156,7 +156,9 @@ def get_iree_vulkan_args(extra_args=[]):
break
if vulkan_triple_flag is None:
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
vulkan_triple_flag = get_vulkan_triple_flag(
device_num=device_num, extra_args=extra_args
)
if vulkan_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)

View File

@@ -30,8 +30,8 @@ import os
import sys
from typing import Dict, List
import iree.compiler._mlir_libs
from iree.compiler import ir
from iree.compiler.transforms import ireec as ireec_trans
def model_annotation(
@@ -409,7 +409,6 @@ def shape_list_to_string(input):
def create_context() -> ir.Context:
context = ir.Context()
ireec_trans.register_all_dialects(context)
context.allow_unregistered_dialects = True
return context

View File

@@ -191,6 +191,7 @@ class SharkImporter:
dir=tempfile.gettempdir(),
model_name="model",
golden_values=None,
mlir_type="linalg",
):
if self.inputs == None:
print(
@@ -204,6 +205,7 @@ class SharkImporter:
tracing_required,
func_name,
save_dir=artifact_path,
mlir_type=mlir_type,
)
# TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
# TODO: Check for multiple outputs.

View File

@@ -21,7 +21,7 @@ import io
mlir_type_mapping_dict = {
"linalg": torch_mlir.OutputType.LINALG_ON_TENSORS,
"mhlo": torch_mlir.OutputType.STABLEHLO,
"stablehlo": torch_mlir.OutputType.STABLEHLO,
"tosa": torch_mlir.OutputType.TOSA,
}

View File

@@ -50,6 +50,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
tracing_required = row[1]
model_type = row[2]
is_dynamic = row[3]
mlir_type = row[4]
tracing_required = False if tracing_required == "False" else True
is_dynamic = False if is_dynamic == "False" else True
@@ -121,6 +122,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
tracing_required=tracing_required,
dir=torch_model_dir,
model_name=torch_model_name,
mlir_type=mlir_type,
)
# Generate torch dynamic models.
if is_dynamic:
@@ -129,6 +131,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
tracing_required=tracing_required,
dir=torch_model_dir,
model_name=torch_model_name + "_dynamic",
mlir_type=mlir_type,
)

View File

@@ -1,23 +1,24 @@
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
microsoft/MiniLM-L12-H384-uncased,True,hf,True,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
bert-base-uncased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-cased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
google/mobilebert-uncased,True,hf,True,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
alexnet,False,vision,True,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
resnet18,False,vision,True,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
resnet50,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
resnet101,False,vision,True,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
squeezenet1_0,False,vision,True,1.25M,"cnn,image-classification,mobile,parallel-layers","Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)"
wide_resnet50_2,False,vision,True,69M,"cnn,image-classification,residuals,resnet-variant","Resnet variant where model depth is decreased and width is increased."
mobilenet_v3_small,False,vision,True,2.5M,"image-classification,cnn,mobile",N/A
google/vit-base-patch16-224,True,hf_img_cls,False,86M,"image-classification,vision-transformer,transformer-encoder",N/A
microsoft/resnet-50,True,hf_img_cls,False,23M,"image-classification,cnn,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,22M,"image-classification,vision-transformer,cnn",N/A
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encoder",SegFormer
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
model_name, use_tracing, model_type, dynamic, mlir_type, param_count, tags, notes
efficientnet_b0,True,vision,False,linalg,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
efficientnet_b7,True,vision,False,linalg,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
microsoft/MiniLM-L12-H384-uncased,True,hf,True,linalg,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
bert-base-uncased,True,hf,True,linalg,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-cased,True,hf,True,linalg,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
google/mobilebert-uncased,True,hf,True,linalg,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
alexnet,False,vision,True,linalg,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
resnet18,False,vision,True,linalg,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
resnet50,False,vision,True,linalg,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
resnet101,False,vision,True,linalg,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
squeezenet1_0,False,vision,True,linalg,1.25M,"cnn,image-classification,mobile,parallel-layers","Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)"
wide_resnet50_2,False,vision,True,linalg,69M,"cnn,image-classification,residuals,resnet-variant","Resnet variant where model depth is decreased and width is increased."
mobilenet_v3_small,False,vision,True,linalg,2.5M,"image-classification,cnn,mobile",N/A
google/vit-base-patch16-224,True,hf_img_cls,False,linalg,86M,"image-classification,vision-transformer,transformer-encoder",N/A
microsoft/resnet-50,True,hf_img_cls,False,linalg,23M,"image-classification,cnn,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,linalg,22M,"image-classification,vision-transformer,cnn",N/A
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,linalg,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
nvidia/mit-b0,True,hf_img_cls,False,linalg,3.7M,"image-classification,transformer-encoder",SegFormer
mnasnet1_0,False,vision,True,linalg,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,linalg,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,linalg,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-large-uncased,True,hf,True,linalg,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
bert-base-uncased,True,hf,False,stablehlo,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
1 model_name use_tracing model_type dynamic mlir_type param_count tags notes
2 efficientnet_b0 True vision False linalg 5.3M image-classification;cnn;conv2d;depthwise-conv Smallest EfficientNet variant with 224x224 input
3 efficientnet_b7 True vision False linalg 66M image-classification;cnn;conv2d;depthwise-conv Largest EfficientNet variant with 600x600 input
4 microsoft/MiniLM-L12-H384-uncased True hf True linalg 66M nlp;bert-variant;transformer-encoder Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)
5 bert-base-uncased True hf True linalg 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
6 bert-base-cased True hf True linalg 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
7 google/mobilebert-uncased True hf True linalg 25M nlp,bert-variant,transformer-encoder,mobile 24 layers, 512 hidden size, 128 embedding
8 alexnet False vision True linalg 61M cnn,parallel-layers The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod.
9 resnet18 False vision True linalg 11M cnn,image-classification,residuals,resnet-variant 1 7x7 conv2d and the rest are 3x3 conv2d
10 resnet50 False vision True linalg 23M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
11 resnet101 False vision True linalg 29M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
12 squeezenet1_0 False vision True linalg 1.25M cnn,image-classification,mobile,parallel-layers Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)
13 wide_resnet50_2 False vision True linalg 69M cnn,image-classification,residuals,resnet-variant Resnet variant where model depth is decreased and width is increased.
14 mobilenet_v3_small False vision True linalg 2.5M image-classification,cnn,mobile N/A
15 google/vit-base-patch16-224 True hf_img_cls False linalg 86M image-classification,vision-transformer,transformer-encoder N/A
16 microsoft/resnet-50 True hf_img_cls False linalg 23M image-classification,cnn,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
17 facebook/deit-small-distilled-patch16-224 True hf_img_cls False linalg 22M image-classification,vision-transformer,cnn N/A
18 microsoft/beit-base-patch16-224-pt22k-ft22k True hf_img_cls False linalg 86M image-classification,transformer-encoder,bert-variant,vision-transformer N/A
19 nvidia/mit-b0 True hf_img_cls False linalg 3.7M image-classification,transformer-encoder SegFormer
20 mnasnet1_0 False vision True linalg - cnn, torchvision, mobile, architecture-search Outperforms other mobile CNNs on Accuracy vs. Latency
21 resnet50_fp16 False vision True linalg 23M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
22 bert-base-uncased_fp16 True fp16 False linalg 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
23 bert-large-uncased True hf True linalg 330M nlp;bert-variant;transformer-encoder 24 layers, 1024 hidden units, 16 attention heads
24 bert-base-uncased True hf False stablehlo 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads