BERT training (WIP)

Author: dan
Date: 2022-09-20 17:15:33 +00:00
Committed by: dan
Parent: a63755bc24
Commit: 6eea111f5e
9 changed files with 219 additions and 115 deletions

View File

@@ -13,6 +13,7 @@ import os
import csv
import argparse
from shark.shark_importer import SharkImporter
from shark.examples.shark_training.bert_training import get_model_and_test_values
from shark.parser import shark_args
import tensorflow as tf
import subprocess as sp
@@ -59,30 +60,59 @@ def save_torch_model(torch_model_list):
model = None
input = None
if model_type == "vision":
model, input, _ = get_vision_model(torch_model_name)
elif model_type == "hf":
model, input, _ = get_hf_model(torch_model_name)
elif model_type == "hf_img_cls":
model, input, _ = get_hf_img_cls_model(torch_model_name)
if model_type == "Training":
asm, np_inputs, train_func, func_name = None, None, None, None
# TODO(Dan): replace this with a generic AutoModelForMaskedLM generator
if torch_model_name == "bert-large-uncased_training":
(
asm,
np_inputs,
train_func,
func_name,
) = get_model_and_test_values()
torch_model_name = torch_model_name.replace("/", "_")
torch_model_dir = os.path.join(
WORKDIR, str(torch_model_name) + "_torch"
)
os.makedirs(torch_model_dir, exist_ok=True)
torch_model_name = torch_model_name.replace("/", "_")
torch_model_dir = os.path.join(
WORKDIR, str(torch_model_name) + "_torch"
)
os.makedirs(torch_model_dir, exist_ok=True)
print("saving to:")
print(torch_model_dir)
model_path = os.path.join(
torch_model_dir, torch_model_name + "_torch" + ".mlir"
)
with open(model_path, "w+") as f:
f.write(asm)
with open(os.path.join(torch_model_dir, "inputs.npz"), "wb") as f:
[np.save(f, x.numpy()) for x in np_inputs]
np.save(os.path.join(torch_model_dir, "function_name"), np.array(func_name))
mlir_importer = SharkImporter(
model,
(input,),
frontend="torch",
)
mlir_importer.import_debug(
is_dynamic=False,
tracing_required=tracing_required,
dir=torch_model_dir,
model_name=torch_model_name,
)
else:
if model_type == "vision":
model, input, _ = get_vision_model(torch_model_name)
elif model_type == "hf":
model, input, _ = get_hf_model(torch_model_name)
elif model_type == "hf_img_cls":
model, input, _ = get_hf_img_cls_model(torch_model_name)
torch_model_name = torch_model_name.replace("/", "_")
torch_model_dir = os.path.join(
WORKDIR, str(torch_model_name) + "_torch"
)
os.makedirs(torch_model_dir, exist_ok=True)
print(torch_model_name)
print(model_type)
mlir_importer = SharkImporter(
model,
(input,),
frontend="torch",
)
mlir_importer.import_debug(
is_dynamic=False,
tracing_required=tracing_required,
dir=torch_model_dir,
model_name=torch_model_name,
)
mlir_hash = create_hash(
os.path.join(
torch_model_dir, torch_model_name + "_torch" + ".mlir"
@@ -239,11 +269,11 @@ if __name__ == "__main__":
if args.torch_model_csv:
save_torch_model(args.torch_model_csv)
if args.tf_model_csv:
save_tf_model(args.tf_model_csv)
# if args.tf_model_csv:
# save_tf_model(args.tf_model_csv)
if args.tflite_model_csv:
save_tflite_model(args.tflite_model_csv)
# if args.tflite_model_csv:
# save_tflite_model(args.tflite_model_csv)
if args.upload:
git_hash = sp.getoutput("git log -1 --format='%h'") + "/"
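Note on the Training branch above: it serializes three artifacts per model (the MLIR text to <name>_torch.mlir, the flat input tensors to inputs.npz via repeated np.save calls on one open handle, and the entry-point name to function_name.npy). A minimal sketch of reading those files back, assuming the same layout and a known input count; the reader function and its arguments are illustrative and not part of this commit:

import os
import numpy as np

# Hypothetical reader for the artifacts written by the Training branch above.
# Assumes the layout <dir>/<name>_torch.mlir, inputs.npz, function_name.npy.
# num_inputs must be known up front because inputs.npz is a sequence of raw
# np.save records written to one handle, not a real .npz archive.
def load_training_artifacts(torch_model_dir, torch_model_name, num_inputs):
    mlir_path = os.path.join(
        torch_model_dir, torch_model_name + "_torch" + ".mlir"
    )
    with open(mlir_path) as f:
        asm = f.read()
    inputs = []
    with open(os.path.join(torch_model_dir, "inputs.npz"), "rb") as f:
        for _ in range(num_inputs):
            inputs.append(np.load(f))
    func_name = str(np.load(os.path.join(torch_model_dir, "function_name.npy")))
    return asm, inputs, func_name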

View File

@@ -15,6 +15,7 @@
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx
from functorch._src.compile_utils import strip_overloads
from torch.nn.utils import _stateless
from torch import fx
@@ -69,6 +70,7 @@ class MakeFxModule:
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
fx_g = self.change_fx_graph_return_to_tuple(fx_g)
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
temp = tempfile.NamedTemporaryFile(
suffix="_shark_ts", prefix="temp_ts_"
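For context, strip_overloads is needed here because make_fx produces an fx.GraphModule whose node targets are specific ATen overloads (for example aten.add.Tensor), which torch.jit.script cannot resolve; rewriting them back to the base op makes the graph scriptable. A minimal standalone sketch of the same pattern, assuming the functorch version pinned by this repo:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch._src.compile_utils import strip_overloads

def f(x):
    return torch.relu(x) + 1.0

# make_fx traces f into a torch.fx.GraphModule with ATen overload targets.
fx_g = make_fx(f)(torch.randn(4))

# Rewrite overload targets to their base ops so TorchScript can script the
# graph, mirroring the call added in MakeFxModule above.
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
print(ts_g(torch.randn(4)))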

View File

@@ -1,14 +1,27 @@
import torch
import time
import numpy as np
from torch.nn.utils import _stateless
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from shark.shark_runner import SharkTrainer
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
BertModel,
AutoModelForMaskedLM,
)
from shark.shark_trainer import SharkTrainer
def get_torch_params(model):
params = {v: i for v, i in model.named_parameters()}
buffers = {v: i for v, i in model.named_buffers()}
return params, buffers
class MiniLMSequenceClassification(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = AutoModelForSequenceClassification.from_pretrained(
"microsoft/MiniLM-L12-H384-uncased", # The pretrained model.
self.model = AutoModelForMaskedLM.from_pretrained(
"bert-large-uncased", # The pretrained model.
num_labels=2, # The number of output labels--2 for binary classification.
output_attentions=False, # Whether the model returns attention weights.
output_hidden_states=False, # Whether the model returns all hidden-states.
@@ -19,29 +32,49 @@ class MiniLMSequenceClassification(torch.nn.Module):
return self.model.forward(tokens)[0]
mod = MiniLMSequenceClassification()
def get_sorted_params(named_params):
return [i[1] for i in sorted(named_params.items())]
print(dict(mod.named_buffers()))
def get_model_and_test_values():
mod = MiniLMSequenceClassification() # .to("cuda")
inp = torch.randint(2, (32, 128)) # .to("cuda")
inp = (torch.randint(2, (1, 128)),)
training_inputs = [i.detach() for i in mod.parameters()]
for i in mod.buffers():
training_inputs.append(i.detach())
training_inputs.append(inp.detach())
# np.savez("/home/dan/inputs.npz", *[x.numpy() for x in training_inputs])
def forward(params, buffers, args):
params_and_buffers = {**params, **buffers}
_stateless.functional_call(
mod, params_and_buffers, args, {}
).sum().backward()
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
# optim.load_state_dict(optim_state)
optim.step()
return params, buffers
shark_module = SharkTrainer(mod, (inp,), from_aot=True)
shark_module.compile(forward)
rr = shark_module.shark_runner
asm = rr.mlir_module.operation.get_asm()
return asm, training_inputs, forward, "forward"
def forward(params, buffers, args):
params_and_buffers = {**params, **buffers}
_stateless.functional_call(
mod, params_and_buffers, args, {}
).sum().backward()
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
# optim.load_state_dict(optim_state)
optim.step()
return params, buffers
shark_module = SharkTrainer(mod, inp, custom_inference_fn=forward)
print(shark_module.forward())
def custom_benchmark_func(mod, shark_module):
p, b = get_torch_params(mod)
shark_params_and_buffers = shark_module.shark_runner.run(training_inputs)
iterations = 1
start = time.time()
for i in range(iterations):
p, b = forward(p, b, inp)
end = time.time()
total_time = end - start
print("total_time(ms)/iter: " + str(1000 * total_time / iterations))
golden = [v for v in p.values()][0].shape
test = shark_params_and_buffers[0].shape
return np.allclose(golden, test)
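The training step above keeps the model stateless: parameters and buffers are passed in explicitly, _stateless.functional_call runs the module against them, the summed output is backpropagated, and plain SGD updates the parameter tensors in place. A minimal sketch of the same pattern on a toy module, where torch.nn.Linear stands in for the BERT model and the names are illustrative only:

import torch
from torch.nn.utils import _stateless

mod = torch.nn.Linear(4, 2)
params = dict(mod.named_parameters())
buffers = dict(mod.named_buffers())
inp = (torch.randn(8, 4),)

def get_sorted_params(named_params):
    return [i[1] for i in sorted(named_params.items())]

def train_step(params, buffers, args):
    params_and_buffers = {**params, **buffers}
    # Run the module with explicit parameters/buffers and backprop a scalar loss.
    _stateless.functional_call(mod, params_and_buffers, args, {}).sum().backward()
    # The dict values are the module's own parameter tensors, so their .grad
    # fields are populated and a plain SGD step can be applied directly.
    optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
    optim.step()
    return params, buffers

params, buffers = train_step(params, buffers, inp)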

View File

@@ -0,0 +1,84 @@
import torch
import time
import numpy as np
from torch.nn.utils import _stateless
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
BertModel,
AutoModelForMaskedLM,
)
from shark.shark_trainer import SharkTrainer
def get_torch_params(model):
params = {v: i for v, i in model.named_parameters()}
buffers = {v: i for v, i in model.named_buffers()}
return params, buffers
class MiniLMSequenceClassification(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = AutoModelForMaskedLM.from_pretrained(
"bert-large-uncased", # The pretrained model.
num_labels=2, # The number of output labels--2 for binary classification.
output_attentions=False, # Whether the model returns attention weights.
output_hidden_states=False, # Whether the model returns all hidden-states.
torchscript=True,
)
def forward(self, tokens):
return self.model.forward(tokens)[0]
def get_sorted_params(named_params):
return [i[1] for i in sorted(named_params.items())]
#def get_model_and_test_values():
mod = MiniLMSequenceClassification().to("cuda")
inp = torch.randint(2, (1, 128)).to("cuda")
training_inputs = [i.detach() for i in mod.parameters()]
for i in mod.buffers():
training_inputs.append(i.detach())
training_inputs.append(inp.detach())
# np.savez("/home/dan/inputs.npz", *[x.numpy() for x in training_inputs])
def forward(params, buffers, args):
params_and_buffers = {**params, **buffers}
params_and_buffers = {k:v.to("cuda") for k, v in params_and_buffers.items()}
_stateless.functional_call(
mod, params_and_buffers, args, {}
).sum().backward()
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
# optim.load_state_dict(optim_state)
optim.step()
return params, buffers
shark_module = SharkTrainer(mod, (inp,), from_aot=True)
#shark_module.compile(forward)
#rr = shark_module.shark_runner
#asm = rr.mlir_module.operation.get_asm()
#return asm, training_inputs, forward, "forward"
#def custom_benchmark_func(mod, shark_module):
p, b = get_torch_params(mod)
#shark_params_and_buffers = shark_module.shark_runner.run(training_inputs)
iterations = 200
for i in range(iterations):
if i==1:
start = time.time()
p, b = forward(p, b, inp)
b = {k:v.to("cpu") for k, v in b.items()}
#p = {k:v.to("cpu") for k, v in p.items()}
end = time.time()
total_time = end - start
print("total_time(ms)/iter: " + str(1000 * total_time / (iterations-1)))
golden = [v for v in p.values()][0].shape
#test = shark_params_and_buffers[0].shape
#return np.allclose(golden, test)
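One caveat with the eager CUDA baseline above: CUDA kernels launch asynchronously, so wall-clock timing around the loop should synchronize the device before reading the clock. A small sketch of the timing harness with explicit synchronization; the helper name is illustrative, and step is meant to be a closure over the training step, for example lambda: forward(p, b, inp):

import time
import torch

def timed_iterations(step, iterations=200):
    # Discard the first iteration (allocator and kernel warm-up), then time the
    # rest, synchronizing so queued CUDA work is actually counted.
    step()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iterations - 1):
        step()
    torch.cuda.synchronize()
    total_time = time.time() - start
    print("total_time(ms)/iter: " + str(1000 * total_time / (iterations - 1)))

# Example: timed_iterations(lambda: forward(p, b, inp), iterations=200)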

View File

@@ -25,7 +25,7 @@ import sys
# supported dialects by the shark-runtime.
supported_dialects = {"linalg", "mhlo", "tosa", "tf-lite"}
supported_dialects = {"linalg", "mhlo", "tosa", "tf-lite", "tm_tensor"}
class SharkRunner:

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from shark.shark_importer import SharkImporter
from shark.parser import shark_args
from shark.shark_runner import SharkRunner
from shark.backward_makefx import MakeFxModule
@@ -76,14 +77,15 @@ class SharkTrainer:
# Returns the backward graph.
training_graph = aot_module.training_graph
weights = self.get_torch_params()
mlir_importer = SharkImporter(
training_graph, weights + self.input, "torch"
)
self.imported_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=self.jit_trace
)
self.shark_runner = SharkRunner(
training_graph,
weights + self.input,
self.dynamic,
self.device,
self.jit_trace,
self.from_aot,
self.frontend,
self.imported_mlir, func_name, self.device, "tm_tensor"
)
elif self.frontend in ["tensorflow", "tf", "mhlo"]:
self.shark_runner = SharkRunner(
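With this change the torch path no longer hands the raw training graph straight to SharkRunner; it first lowers it to MLIR through SharkImporter.import_mlir and then runs the result through the tm_tensor dialect. A rough standalone sketch of that wiring, assuming a traced training_graph plus a flat list of parameter/buffer/input tensors are already in hand; the device string and variable names are placeholders, and the run(...) call mirrors shark_runner.run(training_inputs) in bert_training.py above:

from shark.shark_importer import SharkImporter
from shark.shark_runner import SharkRunner

# training_graph, weights and inputs are assumed to exist, as produced by the
# MakeFxModule/AOT path in SharkTrainer; this only mirrors the new wiring above.
mlir_importer = SharkImporter(training_graph, weights + inputs, "torch")
imported_mlir, func_name = mlir_importer.import_mlir(
    is_dynamic=False, tracing_required=True
)
runner = SharkRunner(imported_mlir, func_name, "cpu", "tm_tensor")
updated_tensors = runner.run(weights + inputs)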

View File

@@ -1,34 +1 @@
resnet50,mhlo,tf,1e-02,1e-3,default
albert-base-v2,mhlo,tf,1e-02,1e-3,default
roberta-base,mhlo,tf,1e-02,1e-3,default
bert-base-uncased,mhlo,tf,1e-2,1e-3,default
camembert-base,mhlo,tf,1e-2,1e-3,default
dbmdz/convbert-base-turkish-cased,mhlo,tf,1e-2,1e-3,default
distilbert-base-uncased,mhlo,tf,1e-2,1e-3,default
facebook/convnext-tiny-224,mhlo,tf,1e-2,1e-3,tf_vit
funnel-transformer/small,mhlo,tf,1e-2,1e-3,default
google/electra-small-discriminator,mhlo,tf,1e-2,1e-3,default
google/mobilebert-uncased,mhlo,tf,1e-2,1e-3,default
google/vit-base-patch16-224,mhlo,tf,1e-2,1e-3,tf_vit
hf-internal-testing/tiny-random-flaubert,mhlo,tf,1e-2,1e-3,default
microsoft/MiniLM-L12-H384-uncased,mhlo,tf,1e-2,1e-3,tf_hf
microsoft/layoutlm-base-uncased,mhlo,tf,1e-2,1e-3,default
microsoft/mpnet-base,mhlo,tf,1e-2,1e-3,default
albert-base-v2,linalg,torch,1e-2,1e-3,default
alexnet,linalg,torch,1e-2,1e-3,default
bert-base-cased,linalg,torch,1e-2,1e-3,default
bert-base-uncased,linalg,torch,1e-2,1e-3,default
distilbert-base-uncased,linalg,torch,1e-2,1e-3,default
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default
microsoft/resnet-50,linalg,torch,1e-2,1e-3,default
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default
mobilenet_v3_small,linalg,torch,1e-2,1e-3,default
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default
resnet101,linalg,torch,1e-2,1e-3,default
resnet18,linalg,torch,1e-2,1e-3,default
resnet50,linalg,torch,1e-2,1e-3,default
squeezenet1_0,linalg,torch,1e-2,1e-3,default
wide_resnet50_2,linalg,torch,1e-2,1e-3,default
bert-large-uncased_training,tm_tensor,torch,1e-2,1e-3,default

View File

@@ -1,7 +1,6 @@
import numpy as np
from iree import runtime as ireert
from iree.tf.support import module_utils
from iree.compiler import tf as tfc
from iree.compiler import compile_str
@@ -165,18 +164,21 @@ if __name__ == "__main__":
BertCompiled = ctx.modules.module
# compare output losses:
iterations = 10
start = time.time()
iterations = 100
for i in range(iterations):
example_inputs, example_labels = next(iter(glue_train))
example_labels = tf.cast(example_labels, tf.int32)
example_inputs = [value for key, value in example_inputs.items()]
# iree version
iree_loss = BertCompiled.learn(
example_inputs, example_labels
).to_host()
# iree_loss = BertCompiled.learn(
# example_inputs, example_labels
# ).to_host()
# base tensorflow
tf_loss = np.array(bert_model.learn(example_inputs, example_labels))
print(np.allclose(iree_loss, tf_loss))
# print(np.allclose(iree_loss, tf_loss))
end = time.time()
total = (end - start) * 1000
print("total time/iter (ms): " + str(total / iterations))

View File

@@ -1,18 +1,2 @@
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
microsoft/MiniLM-L12-H384-uncased,True,hf,True,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
albert-base-v2,True,hf,True,11M,"nlp;bert-variant;transformer-encoder","12 layers; 128 embedding dim; 768 hidden dim; 12 attention heads; Smaller than BERTbase (11M params vs 109M params); Uses weight sharing to reduce # params but computational cost is similar to BERT."
bert-base-uncased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-cased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
google/mobilebert-uncased,True,hf,True,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
alexnet,False,vision,True,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
resnet18,False,vision,True,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
resnet50,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
resnet101,False,vision,True,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
squeezenet1_0,False,vision,True,1.25M,"cnn,image-classification,mobile,parallel-layers","Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)"
wide_resnet50_2,False,vision,True,69M,"cnn,image-classification,residuals,resnet-variant","Resnet variant where model depth is decreased and width is increased."
mobilenet_v3_small,False,vision,True,2.5M,"image-classification,cnn,mobile",N/A
google/vit-base-patch16-224,True,hf_img_cls,False,86M,"image-classification,vision-transformer,transformer-encoder",N/A
microsoft/resnet-50,True,hf_img_cls,False,23M,"image-classification,cnn,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,22M,"image-classification,vision-transformer,cnn",N/A
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encoder",SegFormer
bert-large-uncased_training,True,Training,False,336M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden; 16 attention heads"