Add Resnet50 fp16 variant to pytests. (#760)

Author: Ean Garvey
Committed: 2023-01-10 18:31:11 -06:00 (via GitHub)
Parent: 9570045cc3
Commit: 72f29b67d5
9 changed files with 93 additions and 11 deletions

View File

@@ -1,5 +1,5 @@
 #!/bin/bash
-IMPORTER=1 ./setup_venv.sh
+IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
 source $GITHUB_WORKSPACE/shark.venv/bin/activate
 python generate_sharktank.py --upload=False --ci_tank_dir=True
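Note: setting BENCHMARK=1 here is what activates the cu117 wheel swap added to setup_venv.sh later in this commit. A rough Python analogue of that gate, for illustration only (the real check is the bash conditional in setup_venv.sh):

import os
import platform

# Sketch of `if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]`:
if platform.system() == "Linux" and os.environ.get("BENCHMARK"):
    print("benchmark extras enabled: torch/torchvision re-pinned to cu117")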

View File

@@ -41,9 +41,12 @@ def create_hash(file_name):
 def save_torch_model(torch_model_list):
-    from tank.model_utils import get_hf_model
-    from tank.model_utils import get_vision_model
-    from tank.model_utils import get_hf_img_cls_model
+    from tank.model_utils import (
+        get_hf_model,
+        get_vision_model,
+        get_hf_img_cls_model,
+        get_fp16_model,
+    )
     with open(torch_model_list) as csvfile:
         torch_reader = csv.reader(csvfile, delimiter=",")
@@ -65,7 +68,8 @@ def save_torch_model(torch_model_list):
             model, input, _ = get_hf_model(torch_model_name)
         elif model_type == "hf_img_cls":
             model, input, _ = get_hf_img_cls_model(torch_model_name)
+        elif model_type == "fp16":
+            model, input, _ = get_fp16_model(torch_model_name)
         torch_model_name = torch_model_name.replace("/", "_")
         torch_model_dir = os.path.join(
             WORKDIR, str(torch_model_name) + "_torch"
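With the new branch, rows whose model_type column reads "fp16" are materialized through get_fp16_model. A minimal sketch of the dispatch, assuming the name,use_tracing,model_type column order used by the model-list CSVs in this commit:

def route(row):
    # Hypothetical helper mirroring the elif chain above; only the
    # "fp16" case is new in this commit.
    torch_model_name, model_type = row[0], row[2]
    if model_type == "fp16":
        return ("get_fp16_model", torch_model_name)
    return ("get_hf_model", torch_model_name)

print(route(["bert-base-uncased_fp16", "True", "fp16"]))
# -> ('get_fp16_model', 'bert-base-uncased_fp16')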

View File

@@ -123,8 +123,12 @@ fi
 $PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/
 if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
+  T_VER=$($PYTHON -m pip show torch | grep Version)
+  TORCH_VERSION=${T_VER:9:17}
+  TV_VER=$($PYTHON -m pip show torchvision | grep Version)
+  TV_VERSION=${TV_VER:9:18}
   $PYTHON -m pip uninstall -y torch torchvision
-  $PYTHON -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu117
+  $PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl
   if [ $? -eq 0 ];then
     echo "Successfully Installed torch + cu117."
   else
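The `${T_VER:9:17}` and `${TV_VER:9:18}` expansions are bash substring slices: they drop the 9-character `Version: ` prefix from the `pip show` output, leaving just the nightly version string to splice into a pinned cu117 wheel URL. The same slice in Python, with an assumed sample version string:

# `pip show torch | grep Version` prints e.g. "Version: 2.0.0.dev20230110"
t_ver = "Version: 2.0.0.dev20230110"  # sample output, assumed for the demo
torch_version = t_ver[9:9 + 17]       # -> "2.0.0.dev20230110"
wheel = (
    "https://download.pytorch.org/whl/nightly/cu117/"
    f"torch-{torch_version}%2Bcu117-cp310-cp310-linux_x86_64.whl"
)
print(wheel)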

View File

@@ -339,7 +339,10 @@ for currently supported models. Exiting benchmark ONNX."
     else:
         bench_result["shape_type"] = "static"
     bench_result["device"] = device_str
-    bench_result["data_type"] = inputs[0].dtype
+    if "fp16" in modelname:
+        bench_result["data_type"] = "float16"
+    else:
+        bench_result["data_type"] = inputs[0].dtype
     for e in engines:
         (
             bench_result["param_count"],

View File

@@ -169,9 +169,12 @@ def download_model(
         os.path.join(model_dir, "upstream_hash.npy"),
         single_file=True,
     )
-    upstream_hash = str(
-        np.load(os.path.join(model_dir, "upstream_hash.npy"))
-    )
+    try:
+        upstream_hash = str(
+            np.load(os.path.join(model_dir, "upstream_hash.npy"))
+        )
+    except FileNotFoundError:
+        upstream_hash = None
     if local_hash != upstream_hash:
         print(
             "Hash does not match upstream in gs://shark_tank/latest. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."

View File

@@ -17,6 +17,7 @@ albert-base-v2,linalg,torch,1e-2,1e-3,default,None,True,True,True,"issue with at
 alexnet,linalg,torch,1e-2,1e-3,default,None,False,False,True,"Assertion Error: Zeros Output"
 bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,False,False,""
 bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,""
+bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,False,True,""
 facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile."
 google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311"
 microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390"
@@ -28,6 +29,7 @@ nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github
 resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,True,"Vulkan Numerical Error (mostly conv)"
 resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,True,""
 resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,True,"Vulkan Numerical Error (mostly conv)"
+resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc,True,False,True,""
 squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,True,"https://github.com/nod-ai/SHARK/issues/388"
 wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,True,"Vulkan Numerical Error (mostly conv)"
 efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,True,"https://github.com/nod-ai/SHARK/issues/575"
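Both new rows carry looser tolerances than the usual 1e-2/1e-3: bert-base-uncased_fp16 uses 1e-1/1e-1 and resnet50_fp16 uses 1e-2/1e-2, since half precision cannot match fp32 golden outputs as tightly. These rtol/atol columns feed an allclose-style comparison such as the compare_tensors utility in tank/model_utils.py; a toy illustration with made-up values:

import numpy as np

expected = np.array([0.1234, 0.5678], dtype=np.float32)
actual = expected + 0.05  # fp16-sized drift, invented for the demo

print(np.allclose(actual, expected, rtol=1e-1, atol=1e-1))  # True:  fp16 row passes
print(np.allclose(actual, expected, rtol=1e-2, atol=1e-3))  # False: fp32 row would fail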

View File

@@ -2,12 +2,14 @@ model_name, use_tracing, dynamic, param_count, tags, notes
 microsoft/MiniLM-L12-H384-uncased,True,True,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
 albert-base-v2,True,True,11M,"nlp;bert-variant;transformer-encoder","12 layers; 128 embedding dim; 768 hidden dim; 12 attention heads; Smaller than BERTbase (11M params vs 109M params); Uses weight sharing to reduce # params but computational cost is similar to BERT."
 bert-base-uncased,True,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
+bert-base-uncased_fp16,True,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
 bert-base-cased,True,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
 distilbert-base-uncased,True,True,66M,"nlp;bert-variant;transformer-encoder","Smaller and faster than BERT with 97percent retained accuracy."
 google/mobilebert-uncased,True,True,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
 alexnet,False,True,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
 resnet18,False,True,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
 resnet50,False,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
+resnet50_fp16,False,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
 resnet101,False,True,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
 squeezenet1_0,False,True,1.25M,"cnn,image-classification,mobile,parallel-layers","Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)"
 wide_resnet50_2,False,True,69M,"cnn,image-classification,residuals,resnet-variant","Resnet variant where model depth is decreased and width is increased."

View File

@@ -12,6 +12,7 @@ vision_models = [
"resnet101",
"resnet18",
"resnet50",
"resnet50_fp16",
"squeezenet1_0",
"wide_resnet50_2",
"mobilenet_v3_small",
@@ -31,6 +32,8 @@ def get_torch_model(modelname):
         return get_vision_model(modelname)
     elif modelname in hf_img_cls_models:
         return get_hf_img_cls_model(modelname)
+    elif "fp16" in modelname:
+        return get_fp16_model(modelname)
     else:
         return get_hf_model(modelname)
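Note the ordering: "resnet50_fp16" is in vision_models (added above), so it takes the get_vision_model branch, which performs its own "fp16" name check (see the hunk further down); only non-vision names like "bert-base-uncased_fp16" reach the new elif. Hypothetical probes, both assuming a CUDA device:

get_torch_model("resnet50_fp16")           # vision branch; halved inside get_vision_model
get_torch_model("bert-base-uncased_fp16")  # new elif -> get_fp16_model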
@@ -114,7 +117,6 @@ class HuggingFaceLanguage(torch.nn.Module):
 def get_hf_model(name):
     from transformers import (
         BertTokenizer,
-        TFBertModel,
     )
     model = HuggingFaceLanguage(name)
@@ -146,6 +148,7 @@ def get_vision_model(torch_model):
"alexnet": models.alexnet(weights="DEFAULT"),
"resnet18": models.resnet18(weights="DEFAULT"),
"resnet50": models.resnet50(weights="DEFAULT"),
"resnet50_fp16": models.resnet50(weights="DEFAULT"),
"resnet101": models.resnet101(weights="DEFAULT"),
"squeezenet1_0": models.squeezenet1_0(weights="DEFAULT"),
"wide_resnet50_2": models.wide_resnet50_2(weights="DEFAULT"),
@@ -153,10 +156,26 @@ def get_vision_model(torch_model):
"mnasnet1_0": models.mnasnet1_0(weights="DEFAULT"),
}
if isinstance(torch_model, str):
fp16_model = None
if "fp16" in torch_model:
fp16_model = True
torch_model = vision_models_dict[torch_model]
model = VisionModule(torch_model)
test_input = torch.randn(1, 3, 224, 224)
actual_out = model(test_input)
if fp16_model is not None:
test_input_fp16 = test_input.to(
device=torch.device("cuda"), dtype=torch.half
)
model_fp16 = model.half()
model_fp16.eval()
model_fp16.to("cuda")
actual_out_fp16 = model_fp16(test_input_fp16)
model, test_input, actual_out = (
model_fp16,
test_input_fp16,
actual_out_fp16,
)
return model, test_input, actual_out
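The fp32 path is left untouched; when the name carries the fp16 suffix, the already-built module and input are converted with .half() and moved to CUDA, and the halved triple replaces the fp32 one. A hypothetical smoke test (CUDA required; resnet50_fp16 reuses the stock resnet50 weights):

model, test_input, golden = get_vision_model("resnet50_fp16")
print(test_input.dtype, test_input.device)  # torch.float16 cuda:0
print(golden.dtype)                         # torch.float16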
@@ -164,6 +183,49 @@ def get_vision_model(torch_model):
 ####################### Other PyTorch HF Models ###############################
+class BertHalfPrecisionModel(torch.nn.Module):
+    def __init__(self, hf_model_name):
+        super().__init__()
+        from transformers import AutoModelForMaskedLM
+
+        self.model = AutoModelForMaskedLM.from_pretrained(
+            hf_model_name,  # The pretrained model.
+            num_labels=2,  # The number of output labels--2 for binary classification.
+            output_attentions=False,  # Whether the model returns attention weights.
+            output_hidden_states=False,  # Whether the model returns all hidden-states.
+            torchscript=True,
+            torch_dtype=torch.float16,
+        ).to("cuda")
+
+    def forward(self, tokens):
+        return self.model.forward(tokens)[0]
+
+
+def get_fp16_model(torch_model):
+    from transformers import AutoTokenizer
+
+    modelname = torch_model.replace("_fp16", "")
+    model = BertHalfPrecisionModel(modelname)
+    tokenizer = AutoTokenizer.from_pretrained(modelname)
+    text = "Replace me by any text you like."
+    test_input_fp16 = tokenizer(
+        text,
+        truncation=True,
+        max_length=128,
+        return_tensors="pt",
+    ).input_ids.to("cuda")
+    # test_input = torch.randint(2, (1, 128))
+    # test_input_fp16 = test_input.to(
+    #     device=torch.device("cuda")
+    # )
+    model_fp16 = model.half()
+    model_fp16.eval()
+    with torch.no_grad():
+        actual_out_fp16 = model_fp16(test_input_fp16)
+    return model_fp16, test_input_fp16, actual_out_fp16
+
+
 # Utility function for comparing two tensors (torch).
 def compare_tensors(torch_tensor, numpy_tensor, rtol=1e-02, atol=1e-03):
     # torch_to_numpy = torch_tensor.detach().numpy()
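get_fp16_model strips the _fp16 suffix to recover the Hugging Face checkpoint name, loads it as a masked-LM wrapper in half precision on CUDA, and returns the model with a tokenized sample input and a golden output. A hypothetical smoke test (CUDA and network access assumed):

model, input_ids, golden = get_fp16_model("bert-base-uncased_fp16")
print(input_ids.shape)  # e.g. torch.Size([1, 10]) for the sample sentence
print(golden.dtype)     # torch.float16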

View File

@@ -16,3 +16,5 @@ facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,22M,"image-class
 microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
 nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encoder",SegFormer
 mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
+resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
+bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"