Add support for training TF models + fix TF BERT training example (#80)

This commit is contained in:
Ean Garvey
2022-06-01 20:46:07 -05:00
committed by GitHub
parent 05dffd4d59
commit 8ad73d365c
4 changed files with 28 additions and 12 deletions

View File

@@ -11,7 +11,7 @@ from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_classifier
from shark.shark_runner_tf import SharkTrainerTF
from shark.shark_trainer import SharkTrainer
vocab_size = 100
NUM_CLASSES = 5
@@ -65,17 +65,19 @@ class BertModule(tf.Module):
if __name__ == "__main__":
predict_sample_input = np.asarray([
predict_sample_input = [
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH))
])
sample_input_tensors = tf.convert_to_tensor(predict_sample_input)
]
sample_input_tensors = [tf.convert_to_tensor(val, dtype=tf.int32) for val in predict_sample_input]
num_iter = 10
shark_module = SharkTrainerTF(
shark_module = SharkTrainer(
BertModule(),
(sample_input_tensors,
tf.convert_to_tensor(np.random.randint(5, size=(BATCH_SIZE)))))
tf.convert_to_tensor(np.random.randint(5, size=(BATCH_SIZE)), dtype=tf.int32)))
shark_module.set_frontend("tensorflow")
shark_module.compile()
start = time.time()
print(shark_module.train(num_iter))
end = time.time()

View File

@@ -135,12 +135,18 @@ def get_iree_frontend_args(frontend):
def get_iree_module(module, device, input_type, args, func_name):
flatbuffer_blob = None
# Compile according to the input type, else just try compiling.
if input_type in ["mhlo", "tosa"]:
if input_type in ["tosa"]:
flatbuffer_blob = ireec.compile_str(
str(module),
target_backends=[IREE_DEVICE_MAP[device]],
extra_args=args,
input_type=input_type)
elif input_type in ["mhlo"]:
flatbuffer_blob = ireec.compile_str(
module,
target_backends=[IREE_DEVICE_MAP[device]],
extra_args=args,
input_type=input_type)
else:
flatbuffer_blob = ireec.compile_str(
str(module),

View File

@@ -48,10 +48,16 @@ class SharkRunner:
if self.frontend in ["pytorch", "torch"]:
self.model = get_torch_mlir_module(self.model, input, dynamic,
jit_trace, from_aot)
(
(
self.iree_compilation_module,
self.iree_config,
) = get_iree_compiled_module(self.model, device)
) = get_iree_compiled_module(self.model, device)
if self.frontend in ["tensorflow", "tf", "mhlo"]:
(
self.iree_compilation_module,
self.iree_config,
) = get_iree_compiled_module(self.model, device, self.frontend)
# Debugging Options:
if shark_args.save_mlir:

View File

@@ -110,11 +110,13 @@ class SharkTrainer:
input_list = []
for x in self.input:
if (isinstance(x, list)):
nested_list = []
for val in x:
if (isinstance(val, np.ndarray)):
input_list.append([val for val in x])
nested_list.append(val)
else:
input_list.append([val.numpy() for val in x])
nested_list.append(val.numpy())
input_list.append(nested_list)
elif (isinstance(x, np.ndarray)):
input_list.append(x)
else:
@@ -122,7 +124,7 @@ class SharkTrainer:
print(f"Training started for {num_iters} iterations:")
for i in tqdm(range(num_iters)):
outputs = self.shark_runner.forward(input_list)
outputs = self.shark_runner.forward(input_list, self.frontend)
return self.model.trainable_variables