mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add support for training TF models + fix TF BERT training example (#80)
This commit is contained in:
@@ -11,7 +11,7 @@ from official.nlp.modeling import layers
|
||||
from official.nlp.modeling import networks
|
||||
from official.nlp.modeling.models import bert_classifier
|
||||
|
||||
from shark.shark_runner_tf import SharkTrainerTF
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
|
||||
vocab_size = 100
|
||||
NUM_CLASSES = 5
|
||||
@@ -65,17 +65,19 @@ class BertModule(tf.Module):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
predict_sample_input = np.asarray([
|
||||
predict_sample_input = [
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH))
|
||||
])
|
||||
sample_input_tensors = tf.convert_to_tensor(predict_sample_input)
|
||||
]
|
||||
sample_input_tensors = [tf.convert_to_tensor(val, dtype=tf.int32) for val in predict_sample_input]
|
||||
num_iter = 10
|
||||
shark_module = SharkTrainerTF(
|
||||
shark_module = SharkTrainer(
|
||||
BertModule(),
|
||||
(sample_input_tensors,
|
||||
tf.convert_to_tensor(np.random.randint(5, size=(BATCH_SIZE)))))
|
||||
tf.convert_to_tensor(np.random.randint(5, size=(BATCH_SIZE)), dtype=tf.int32)))
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
start = time.time()
|
||||
print(shark_module.train(num_iter))
|
||||
end = time.time()
|
||||
|
||||
@@ -135,12 +135,18 @@ def get_iree_frontend_args(frontend):
|
||||
def get_iree_module(module, device, input_type, args, func_name):
|
||||
flatbuffer_blob = None
|
||||
# Compile according to the input type, else just try compiling.
|
||||
if input_type in ["mhlo", "tosa"]:
|
||||
if input_type in ["tosa"]:
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
str(module),
|
||||
target_backends=[IREE_DEVICE_MAP[device]],
|
||||
extra_args=args,
|
||||
input_type=input_type)
|
||||
elif input_type in ["mhlo"]:
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
target_backends=[IREE_DEVICE_MAP[device]],
|
||||
extra_args=args,
|
||||
input_type=input_type)
|
||||
else:
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
str(module),
|
||||
|
||||
@@ -48,10 +48,16 @@ class SharkRunner:
|
||||
if self.frontend in ["pytorch", "torch"]:
|
||||
self.model = get_torch_mlir_module(self.model, input, dynamic,
|
||||
jit_trace, from_aot)
|
||||
(
|
||||
(
|
||||
self.iree_compilation_module,
|
||||
self.iree_config,
|
||||
) = get_iree_compiled_module(self.model, device)
|
||||
) = get_iree_compiled_module(self.model, device)
|
||||
|
||||
if self.frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
(
|
||||
self.iree_compilation_module,
|
||||
self.iree_config,
|
||||
) = get_iree_compiled_module(self.model, device, self.frontend)
|
||||
|
||||
# Debugging Options:
|
||||
if shark_args.save_mlir:
|
||||
|
||||
@@ -110,11 +110,13 @@ class SharkTrainer:
|
||||
input_list = []
|
||||
for x in self.input:
|
||||
if (isinstance(x, list)):
|
||||
nested_list = []
|
||||
for val in x:
|
||||
if (isinstance(val, np.ndarray)):
|
||||
input_list.append([val for val in x])
|
||||
nested_list.append(val)
|
||||
else:
|
||||
input_list.append([val.numpy() for val in x])
|
||||
nested_list.append(val.numpy())
|
||||
input_list.append(nested_list)
|
||||
elif (isinstance(x, np.ndarray)):
|
||||
input_list.append(x)
|
||||
else:
|
||||
@@ -122,7 +124,7 @@ class SharkTrainer:
|
||||
|
||||
print(f"Training started for {num_iters} iterations:")
|
||||
for i in tqdm(range(num_iters)):
|
||||
outputs = self.shark_runner.forward(input_list)
|
||||
outputs = self.shark_runner.forward(input_list, self.frontend)
|
||||
|
||||
return self.model.trainable_variables
|
||||
|
||||
|
||||
Reference in New Issue
Block a user