mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 06:17:55 -05:00
bert training wip
This commit is contained in:
@@ -13,6 +13,7 @@ import os
|
||||
import csv
|
||||
import argparse
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.examples.shark_training.bert_training import get_model_and_test_values
|
||||
from shark.parser import shark_args
|
||||
import tensorflow as tf
|
||||
import subprocess as sp
|
||||
@@ -59,30 +60,59 @@ def save_torch_model(torch_model_list):
|
||||
|
||||
model = None
|
||||
input = None
|
||||
if model_type == "vision":
|
||||
model, input, _ = get_vision_model(torch_model_name)
|
||||
elif model_type == "hf":
|
||||
model, input, _ = get_hf_model(torch_model_name)
|
||||
elif model_type == "hf_img_cls":
|
||||
model, input, _ = get_hf_img_cls_model(torch_model_name)
|
||||
if model_type == "Training":
|
||||
asm, np_inputs, train_func, func_name = None, None, None, None
|
||||
#TODO {Dan}: replace this with a generic AutoModelForMaskedLM generator
|
||||
if torch_model_name == "bert-large-uncased_training":
|
||||
(
|
||||
asm,
|
||||
np_inputs,
|
||||
train_func,
|
||||
func_name,
|
||||
) = get_model_and_test_values()
|
||||
|
||||
torch_model_name = torch_model_name.replace("/", "_")
|
||||
torch_model_dir = os.path.join(
|
||||
WORKDIR, str(torch_model_name) + "_torch"
|
||||
)
|
||||
os.makedirs(torch_model_dir, exist_ok=True)
|
||||
torch_model_name = torch_model_name.replace("/", "_")
|
||||
torch_model_dir = os.path.join(
|
||||
WORKDIR, str(torch_model_name) + "_torch"
|
||||
)
|
||||
os.makedirs(torch_model_dir, exist_ok=True)
|
||||
print("saving to:")
|
||||
print(torch_model_dir)
|
||||
model_path = os.path.join(
|
||||
torch_model_dir, torch_model_name + "_torch" + ".mlir"
|
||||
)
|
||||
with open(model_path, "w+") as f:
|
||||
f.write(asm)
|
||||
with open(os.path.join(torch_model_dir, "inputs.npz"), "wb") as f:
|
||||
[np.save(f, x.numpy()) for x in np_inputs]
|
||||
np.save(os.path.join(torch_model_dir, "function_name"), np.array(func_name))
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=False,
|
||||
tracing_required=tracing_required,
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name,
|
||||
)
|
||||
else:
|
||||
if model_type == "vision":
|
||||
model, input, _ = get_vision_model(torch_model_name)
|
||||
elif model_type == "hf":
|
||||
model, input, _ = get_hf_model(torch_model_name)
|
||||
elif model_type == "hf_img_cls":
|
||||
model, input, _ = get_hf_img_cls_model(torch_model_name)
|
||||
|
||||
torch_model_name = torch_model_name.replace("/", "_")
|
||||
torch_model_dir = os.path.join(
|
||||
WORKDIR, str(torch_model_name) + "_torch"
|
||||
)
|
||||
os.makedirs(torch_model_dir, exist_ok=True)
|
||||
print(torch_model_name)
|
||||
print(model_type)
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=False,
|
||||
tracing_required=tracing_required,
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name,
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(
|
||||
torch_model_dir, torch_model_name + "_torch" + ".mlir"
|
||||
@@ -239,11 +269,11 @@ if __name__ == "__main__":
|
||||
if args.torch_model_csv:
|
||||
save_torch_model(args.torch_model_csv)
|
||||
|
||||
if args.tf_model_csv:
|
||||
save_tf_model(args.tf_model_csv)
|
||||
# if args.tf_model_csv:
|
||||
# save_tf_model(args.tf_model_csv)
|
||||
|
||||
if args.tflite_model_csv:
|
||||
save_tflite_model(args.tflite_model_csv)
|
||||
# if args.tflite_model_csv:
|
||||
# save_tflite_model(args.tflite_model_csv)
|
||||
|
||||
if args.upload:
|
||||
git_hash = sp.getoutput("git log -1 --format='%h'") + "/"
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
import torch
|
||||
from torch._decomp import get_decompositions
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from functorch._src.compile_utils import strip_overloads
|
||||
from torch.nn.utils import _stateless
|
||||
|
||||
from torch import fx
|
||||
@@ -69,6 +70,7 @@ class MakeFxModule:
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
fx_g = self.change_fx_graph_return_to_tuple(fx_g)
|
||||
strip_overloads(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
temp = tempfile.NamedTemporaryFile(
|
||||
suffix="_shark_ts", prefix="temp_ts_"
|
||||
|
||||
@@ -1,14 +1,27 @@
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from torch.nn.utils import _stateless
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from shark.shark_runner import SharkTrainer
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForSequenceClassification,
|
||||
BertModel,
|
||||
AutoModelForMaskedLM,
|
||||
)
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
|
||||
|
||||
def get_torch_params(model):
|
||||
params = {v: i for v, i in model.named_parameters()}
|
||||
buffers = {v: i for v, i in model.named_buffers()}
|
||||
return params, buffers
|
||||
|
||||
|
||||
class MiniLMSequenceClassification(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased", # The pretrained model.
|
||||
self.model = AutoModelForMaskedLM.from_pretrained(
|
||||
"bert-large-uncased", # The pretrained model.
|
||||
num_labels=2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=False, # Whether the model returns all hidden-states.
|
||||
@@ -19,29 +32,49 @@ class MiniLMSequenceClassification(torch.nn.Module):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
mod = MiniLMSequenceClassification()
|
||||
|
||||
|
||||
def get_sorted_params(named_params):
|
||||
return [i[1] for i in sorted(named_params.items())]
|
||||
|
||||
|
||||
print(dict(mod.named_buffers()))
|
||||
def get_model_and_test_values():
|
||||
mod = MiniLMSequenceClassification() # .to("cuda")
|
||||
inp = torch.randint(2, (32, 128)) # .to("cuda")
|
||||
|
||||
inp = (torch.randint(2, (1, 128)),)
|
||||
training_inputs = [i.detach() for i in mod.parameters()]
|
||||
for i in mod.buffers():
|
||||
training_inputs.append(i.detach())
|
||||
|
||||
training_inputs.append(inp.detach())
|
||||
# np.savez("/home/dan/inputs.npz", *[x.numpy() for x in training_inputs])
|
||||
|
||||
def forward(params, buffers, args):
|
||||
params_and_buffers = {**params, **buffers}
|
||||
_stateless.functional_call(
|
||||
mod, params_and_buffers, args, {}
|
||||
).sum().backward()
|
||||
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
|
||||
# optim.load_state_dict(optim_state)
|
||||
optim.step()
|
||||
return params, buffers
|
||||
|
||||
shark_module = SharkTrainer(mod, (inp,), from_aot=True)
|
||||
shark_module.compile(forward)
|
||||
|
||||
rr = shark_module.shark_runner
|
||||
asm = rr.mlir_module.operation.get_asm()
|
||||
return asm, training_inputs, forward, "forward"
|
||||
|
||||
|
||||
def forward(params, buffers, args):
|
||||
params_and_buffers = {**params, **buffers}
|
||||
_stateless.functional_call(
|
||||
mod, params_and_buffers, args, {}
|
||||
).sum().backward()
|
||||
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
|
||||
# optim.load_state_dict(optim_state)
|
||||
optim.step()
|
||||
return params, buffers
|
||||
|
||||
|
||||
shark_module = SharkTrainer(mod, inp, custom_inference_fn=forward)
|
||||
|
||||
print(shark_module.forward())
|
||||
def custom_benchmark_func(mod, shark_module):
|
||||
p, b = get_torch_params(mod)
|
||||
shark_params_and_buffers = shark_module.shark_runner.run(training_inputs)
|
||||
iterations = 1
|
||||
start = time.time()
|
||||
for i in range(iterations):
|
||||
p, b = forward(p, b, inp)
|
||||
end = time.time()
|
||||
total_time = end - start
|
||||
print("total_time(ms)/iter: " + str(1000 * total_time / iterations))
|
||||
golden = [v for v in p.values()][0].shape
|
||||
test = shark_params_and_buffers[0].shape
|
||||
return np.allclose(golden, test)
|
||||
|
||||
84
shark/examples/shark_training/bert_training_pt_bench.py
Normal file
84
shark/examples/shark_training/bert_training_pt_bench.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from torch.nn.utils import _stateless
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForSequenceClassification,
|
||||
BertModel,
|
||||
AutoModelForMaskedLM,
|
||||
)
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
|
||||
|
||||
def get_torch_params(model):
|
||||
params = {v: i for v, i in model.named_parameters()}
|
||||
buffers = {v: i for v, i in model.named_buffers()}
|
||||
return params, buffers
|
||||
|
||||
|
||||
class MiniLMSequenceClassification(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = AutoModelForMaskedLM.from_pretrained(
|
||||
"bert-large-uncased", # The pretrained model.
|
||||
num_labels=2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=False, # Whether the model returns all hidden-states.
|
||||
torchscript=True,
|
||||
)
|
||||
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
def get_sorted_params(named_params):
|
||||
return [i[1] for i in sorted(named_params.items())]
|
||||
|
||||
|
||||
#def get_model_and_test_values():
|
||||
mod = MiniLMSequenceClassification().to("cuda")
|
||||
inp = torch.randint(2, (1, 128)).to("cuda")
|
||||
|
||||
training_inputs = [i.detach() for i in mod.parameters()]
|
||||
for i in mod.buffers():
|
||||
training_inputs.append(i.detach())
|
||||
|
||||
training_inputs.append(inp.detach())
|
||||
# np.savez("/home/dan/inputs.npz", *[x.numpy() for x in training_inputs])
|
||||
|
||||
def forward(params, buffers, args):
|
||||
params_and_buffers = {**params, **buffers}
|
||||
params_and_buffers = {k:v.to("cuda") for k, v in params_and_buffers.items()}
|
||||
_stateless.functional_call(
|
||||
mod, params_and_buffers, args, {}
|
||||
).sum().backward()
|
||||
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
|
||||
# optim.load_state_dict(optim_state)
|
||||
optim.step()
|
||||
return params, buffers
|
||||
|
||||
shark_module = SharkTrainer(mod, (inp,), from_aot=True)
|
||||
#shark_module.compile(forward)
|
||||
|
||||
#rr = shark_module.shark_runner
|
||||
#asm = rr.mlir_module.operation.get_asm()
|
||||
#return asm, training_inputs, forward, "forward"
|
||||
|
||||
|
||||
#def custom_benchmark_func(mod, shark_module):
|
||||
p, b = get_torch_params(mod)
|
||||
#shark_params_and_buffers = shark_module.shark_runner.run(training_inputs)
|
||||
iterations = 200
|
||||
for i in range(iterations):
|
||||
if i==1:
|
||||
start = time.time()
|
||||
p, b = forward(p, b, inp)
|
||||
b = {k:v.to("cpu") for k, v in b.items()}
|
||||
#p = {k:v.to("cpu") for k, v in p.items()}
|
||||
end = time.time()
|
||||
total_time = end - start
|
||||
print("total_time(ms)/iter: " + str(1000 * total_time / (iterations-1)))
|
||||
golden = [v for v in p.values()][0].shape
|
||||
#test = shark_params_and_buffers[0].shape
|
||||
#return np.allclose(golden, test)
|
||||
@@ -25,7 +25,7 @@ import sys
|
||||
|
||||
|
||||
# supported dialects by the shark-runtime.
|
||||
supported_dialects = {"linalg", "mhlo", "tosa", "tf-lite"}
|
||||
supported_dialects = {"linalg", "mhlo", "tosa", "tf-lite", "tm_tensor"}
|
||||
|
||||
|
||||
class SharkRunner:
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
from shark.shark_runner import SharkRunner
|
||||
from shark.backward_makefx import MakeFxModule
|
||||
@@ -76,14 +77,15 @@ class SharkTrainer:
|
||||
# Returns the backward graph.
|
||||
training_graph = aot_module.training_graph
|
||||
weights = self.get_torch_params()
|
||||
mlir_importer = SharkImporter(
|
||||
training_graph, weights + self.input, "torch"
|
||||
)
|
||||
|
||||
self.imported_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=self.jit_trace
|
||||
)
|
||||
self.shark_runner = SharkRunner(
|
||||
training_graph,
|
||||
weights + self.input,
|
||||
self.dynamic,
|
||||
self.device,
|
||||
self.jit_trace,
|
||||
self.from_aot,
|
||||
self.frontend,
|
||||
self.imported_mlir, func_name, self.device, "tm_tensor"
|
||||
)
|
||||
elif self.frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
self.shark_runner = SharkRunner(
|
||||
|
||||
@@ -1,34 +1 @@
|
||||
resnet50,mhlo,tf,1e-02,1e-3,default
|
||||
albert-base-v2,mhlo,tf,1e-02,1e-3,default
|
||||
roberta-base,mhlo,tf,1e-02,1e-3,default
|
||||
bert-base-uncased,mhlo,tf,1e-2,1e-3,default
|
||||
camembert-base,mhlo,tf,1e-2,1e-3,default
|
||||
dbmdz/convbert-base-turkish-cased,mhlo,tf,1e-2,1e-3,default
|
||||
distilbert-base-uncased,mhlo,tf,1e-2,1e-3,default
|
||||
facebook/convnext-tiny-224,mhlo,tf,1e-2,1e-3,tf_vit
|
||||
funnel-transformer/small,mhlo,tf,1e-2,1e-3,default
|
||||
google/electra-small-discriminator,mhlo,tf,1e-2,1e-3,default
|
||||
google/mobilebert-uncased,mhlo,tf,1e-2,1e-3,default
|
||||
google/vit-base-patch16-224,mhlo,tf,1e-2,1e-3,tf_vit
|
||||
hf-internal-testing/tiny-random-flaubert,mhlo,tf,1e-2,1e-3,default
|
||||
microsoft/MiniLM-L12-H384-uncased,mhlo,tf,1e-2,1e-3,tf_hf
|
||||
microsoft/layoutlm-base-uncased,mhlo,tf,1e-2,1e-3,default
|
||||
microsoft/mpnet-base,mhlo,tf,1e-2,1e-3,default
|
||||
albert-base-v2,linalg,torch,1e-2,1e-3,default
|
||||
alexnet,linalg,torch,1e-2,1e-3,default
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default
|
||||
distilbert-base-uncased,linalg,torch,1e-2,1e-3,default
|
||||
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default
|
||||
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default
|
||||
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default
|
||||
microsoft/resnet-50,linalg,torch,1e-2,1e-3,default
|
||||
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default
|
||||
mobilenet_v3_small,linalg,torch,1e-2,1e-3,default
|
||||
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default
|
||||
resnet101,linalg,torch,1e-2,1e-3,default
|
||||
resnet18,linalg,torch,1e-2,1e-3,default
|
||||
resnet50,linalg,torch,1e-2,1e-3,default
|
||||
squeezenet1_0,linalg,torch,1e-2,1e-3,default
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default
|
||||
bert-large-uncased_training,tm_tensor,torch,1e-2,1e-3,default
|
||||
|
||||
|
@@ -1,7 +1,6 @@
|
||||
import numpy as np
|
||||
|
||||
from iree import runtime as ireert
|
||||
from iree.tf.support import module_utils
|
||||
from iree.compiler import tf as tfc
|
||||
from iree.compiler import compile_str
|
||||
|
||||
@@ -165,18 +164,21 @@ if __name__ == "__main__":
|
||||
BertCompiled = ctx.modules.module
|
||||
|
||||
# compare output losses:
|
||||
|
||||
iterations = 10
|
||||
start = time.time()
|
||||
iterations = 100
|
||||
for i in range(iterations):
|
||||
example_inputs, example_labels = next(iter(glue_train))
|
||||
example_labels = tf.cast(example_labels, tf.int32)
|
||||
example_inputs = [value for key, value in example_inputs.items()]
|
||||
|
||||
# iree version
|
||||
iree_loss = BertCompiled.learn(
|
||||
example_inputs, example_labels
|
||||
).to_host()
|
||||
# iree_loss = BertCompiled.learn(
|
||||
# example_inputs, example_labels
|
||||
# ).to_host()
|
||||
|
||||
# base tensorflow
|
||||
tf_loss = np.array(bert_model.learn(example_inputs, example_labels))
|
||||
print(np.allclose(iree_loss, tf_loss))
|
||||
# print(np.allclose(iree_loss, tf_loss))
|
||||
end = time.time()
|
||||
total = (end - start) * 1000
|
||||
print("total time/iter (ms): " + str(total / iterations))
|
||||
|
||||
@@ -1,18 +1,2 @@
|
||||
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
|
||||
microsoft/MiniLM-L12-H384-uncased,True,hf,True,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
|
||||
albert-base-v2,True,hf,True,11M,"nlp;bert-variant;transformer-encoder","12 layers; 128 embedding dim; 768 hidden dim; 12 attention heads; Smaller than BERTbase (11M params vs 109M params); Uses weight sharing to reduce # params but computational cost is similar to BERT."
|
||||
bert-base-uncased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-base-cased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
google/mobilebert-uncased,True,hf,True,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
|
||||
alexnet,False,vision,True,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
|
||||
resnet18,False,vision,True,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
|
||||
resnet50,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
resnet101,False,vision,True,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
squeezenet1_0,False,vision,True,1.25M,"cnn,image-classification,mobile,parallel-layers","Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)"
|
||||
wide_resnet50_2,False,vision,True,69M,"cnn,image-classification,residuals,resnet-variant","Resnet variant where model depth is decreased and width is increased."
|
||||
mobilenet_v3_small,False,vision,True,2.5M,"image-classification,cnn,mobile",N/A
|
||||
google/vit-base-patch16-224,True,hf_img_cls,False,86M,"image-classification,vision-transformer,transformer-encoder",N/A
|
||||
microsoft/resnet-50,True,hf_img_cls,False,23M,"image-classification,cnn,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,22M,"image-classification,vision-transformer,cnn",N/A
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
|
||||
nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encoder",SegFormer
|
||||
bert-large-uncased_training,True,Training,False,336M,"nlp;bert-bariant;transformer-encoder","24 layers, 1024 hidden; 16 attention heads"
|
||||
|
||||
|
Reference in New Issue
Block a user