From 45677c1e23540bfbb46f72161a616bbbda0e8389 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Mon, 14 Nov 2022 02:31:01 -0600 Subject: [PATCH] Install torch version required by torch-mlir when setting up importer venv. (#486) --- setup_venv.sh | 5 +++-- shark/examples/shark_training/bert_training_tf.py | 2 +- tank/examples/bert_fine_tuning/bert_fine_tune_tf.py | 2 +- tank/examples/bert_tf/bert_large_gen.py | 2 +- tank/examples/bert_tf/bert_large_run.py | 2 +- tank/examples/bert_tf/bert_large_tf.py | 2 +- tank/examples/bert_tf/bert_small_gen.py | 2 +- tank/examples/bert_tf/bert_small_run.py | 2 +- tank/examples/bert_tf/bert_small_tf_run.py | 2 +- tank/model_utils_tf.py | 5 ++++- 10 files changed, 15 insertions(+), 11 deletions(-) diff --git a/setup_venv.sh b/setup_venv.sh index 6e778286..d8b6c21f 100755 --- a/setup_venv.sh +++ b/setup_venv.sh @@ -105,8 +105,6 @@ else echo "Not installing a backend, please make sure to add your backend to PYTHONPATH" fi -$PYTHON -m pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/ - if [[ ! -z "${IMPORTER}" ]]; then echo "${Yellow}Installing importer tools.." if [[ $(uname -s) = 'Linux' ]]; then @@ -119,6 +117,9 @@ if [[ ! -z "${IMPORTER}" ]]; then $PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer-macos.txt" -f ${RUNTIME} --extra-index-url https://download.pytorch.org/whl/nightly/cpu fi fi + +$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/ + if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then $PYTHON -m pip uninstall -y torch torchvision $PYTHON -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu116 diff --git a/shark/examples/shark_training/bert_training_tf.py b/shark/examples/shark_training/bert_training_tf.py index 43bae320..8db49c61 100644 --- a/shark/examples/shark_training/bert_training_tf.py +++ b/shark/examples/shark_training/bert_training_tf.py @@ -53,7 +53,7 @@ class BertModule(tf.Module): bert_input, # inputs tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels ], - jit_compile=True + jit_compile=True, ) def forward(self, inputs, labels): with tf.GradientTape() as tape: diff --git a/tank/examples/bert_fine_tuning/bert_fine_tune_tf.py b/tank/examples/bert_fine_tuning/bert_fine_tune_tf.py index 76a3b514..3965c6c5 100644 --- a/tank/examples/bert_fine_tuning/bert_fine_tune_tf.py +++ b/tank/examples/bert_fine_tuning/bert_fine_tune_tf.py @@ -73,7 +73,7 @@ class BertModule(tf.Module): bert_input, # inputs tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels ], - jit_compile=True + jit_compile=True, ) def learn(self, inputs, labels): with tf.GradientTape() as tape: diff --git a/tank/examples/bert_tf/bert_large_gen.py b/tank/examples/bert_tf/bert_large_gen.py index 70e92036..380b7174 100644 --- a/tank/examples/bert_tf/bert_large_gen.py +++ b/tank/examples/bert_tf/bert_large_gen.py @@ -61,7 +61,7 @@ class BertModule(tf.Module): ), # input2: segment_ids tf.TensorSpec([BATCH_SIZE], tf.int32), # input3: labels ], - jit_compile=True + jit_compile=True, ) def learn(self, input_word_ids, input_mask, segment_ids, labels): with tf.GradientTape() as tape: diff --git a/tank/examples/bert_tf/bert_large_run.py b/tank/examples/bert_tf/bert_large_run.py index 5d76424b..c0fcee3c 100644 --- a/tank/examples/bert_tf/bert_large_run.py +++ b/tank/examples/bert_tf/bert_large_run.py @@ -58,7 +58,7 @@ class BertModule(tf.Module): bert_input, # inputs tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels ], - jit_compile=True + jit_compile=True, ) def learn(self, inputs, labels): with tf.GradientTape() as tape: diff --git a/tank/examples/bert_tf/bert_large_tf.py b/tank/examples/bert_tf/bert_large_tf.py index e531d7a1..33d1ed03 100644 --- a/tank/examples/bert_tf/bert_large_tf.py +++ b/tank/examples/bert_tf/bert_large_tf.py @@ -51,7 +51,7 @@ class BertModule(tf.Module): bert_input, # inputs tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels ], - jit_compile=True + jit_compile=True, ) def learn(self, inputs, labels): with tf.GradientTape() as tape: diff --git a/tank/examples/bert_tf/bert_small_gen.py b/tank/examples/bert_tf/bert_small_gen.py index d459b062..ca48b5a2 100644 --- a/tank/examples/bert_tf/bert_small_gen.py +++ b/tank/examples/bert_tf/bert_small_gen.py @@ -58,7 +58,7 @@ class BertModule(tf.Module): ), # input2: segment_ids tf.TensorSpec([BATCH_SIZE], tf.int32), # input3: labels ], - jit_compile=True + jit_compile=True, ) def learn(self, input_word_ids, input_mask, segment_ids, labels): with tf.GradientTape() as tape: diff --git a/tank/examples/bert_tf/bert_small_run.py b/tank/examples/bert_tf/bert_small_run.py index b7d3ad34..b610985c 100644 --- a/tank/examples/bert_tf/bert_small_run.py +++ b/tank/examples/bert_tf/bert_small_run.py @@ -54,7 +54,7 @@ class BertModule(tf.Module): bert_input, # inputs tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels ], - jit_compile=True + jit_compile=True, ) def learn(self, inputs, labels): with tf.GradientTape() as tape: diff --git a/tank/examples/bert_tf/bert_small_tf_run.py b/tank/examples/bert_tf/bert_small_tf_run.py index bc67866f..277e0c20 100644 --- a/tank/examples/bert_tf/bert_small_tf_run.py +++ b/tank/examples/bert_tf/bert_small_tf_run.py @@ -47,7 +47,7 @@ class BertModule(tf.Module): bert_input, # inputs tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels ], - jit_compile=True + jit_compile=True, ) def learn(self, inputs, labels): with tf.GradientTape() as tape: diff --git a/tank/model_utils_tf.py b/tank/model_utils_tf.py index 86197e84..824b74dd 100644 --- a/tank/model_utils_tf.py +++ b/tank/model_utils_tf.py @@ -191,7 +191,10 @@ class ResNetModule(tf.Module): self.m = tf_model self.m.predict = lambda x: self.m.call(x, training=False) - @tf.function(input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)], jit_compile=True) + @tf.function( + input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)], + jit_compile=True, + ) def forward(self, inputs): return self.m.predict(inputs)