mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Enable XLA compiler for TF models (#484)
This commit is contained in:
@@ -42,7 +42,7 @@ class TFHuggingFaceLanguage(tf.Module):
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=tf_bert_input)
|
||||
@tf.function(input_signature=tf_bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ class CLIPModule(tf.Module):
|
||||
input_ids=x, attention_mask=y, pixel_values=z
|
||||
)
|
||||
|
||||
@tf.function(input_signature=clip_vit_inputs)
|
||||
@tf.function(input_signature=clip_vit_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, pixel_values):
|
||||
return self.m.predict(
|
||||
input_ids, attention_mask, pixel_values
|
||||
|
||||
@@ -28,7 +28,7 @@ class AlbertModule(tf.Module):
|
||||
self.m = TFAutoModelForMaskedLM.from_pretrained("albert-base-v2")
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)
|
||||
|
||||
@tf.function(input_signature=t5_inputs)
|
||||
@tf.function(input_signature=t5_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ class GPT2Module(tf.Module):
|
||||
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)
|
||||
|
||||
@tf.function(input_signature=gpt2_inputs)
|
||||
@tf.function(input_signature=gpt2_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class BertModule(tf.Module):
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class BertModule(tf.Module):
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class T5Module(tf.Module):
|
||||
self.m = TFT5Model.from_pretrained("t5-small")
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, decoder_input_ids=y)
|
||||
|
||||
@tf.function(input_signature=t5_inputs)
|
||||
@tf.function(input_signature=t5_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, decoder_input_ids):
|
||||
return self.m.predict(input_ids, decoder_input_ids)
|
||||
|
||||
|
||||
@@ -52,7 +52,8 @@ class BertModule(tf.Module):
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
],
|
||||
jit_compile=True
|
||||
)
|
||||
def forward(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
|
||||
@@ -32,7 +32,7 @@ class BertModule(tf.Module):
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def predict(self, input_word_ids, input_mask, segment_ids):
|
||||
return self.m.predict(input_word_ids, input_mask, segment_ids)
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class BertModule(tf.Module):
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def predict(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ class SeqClassification(tf.Module):
|
||||
)
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)[0]
|
||||
|
||||
@tf.function(input_signature=inputs_signature)
|
||||
@tf.function(input_signature=inputs_signature, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return tf.math.softmax(
|
||||
self.m.predict(input_ids, attention_mask), axis=-1
|
||||
|
||||
@@ -72,7 +72,8 @@ class BertModule(tf.Module):
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
],
|
||||
jit_compile=True
|
||||
)
|
||||
def learn(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
|
||||
@@ -60,7 +60,8 @@ class BertModule(tf.Module):
|
||||
shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32
|
||||
), # input2: segment_ids
|
||||
tf.TensorSpec([BATCH_SIZE], tf.int32), # input3: labels
|
||||
]
|
||||
],
|
||||
jit_compile=True
|
||||
)
|
||||
def learn(self, input_word_ids, input_mask, segment_ids, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
@@ -75,7 +76,7 @@ class BertModule(tf.Module):
|
||||
self.optimizer.apply_gradients(zip(gradients, variables))
|
||||
return loss
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def predict(self, input_word_ids, input_mask, segment_ids):
|
||||
inputs = [input_word_ids, input_mask, segment_ids]
|
||||
return self.m.predict(inputs)
|
||||
|
||||
@@ -57,7 +57,8 @@ class BertModule(tf.Module):
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
],
|
||||
jit_compile=True
|
||||
)
|
||||
def learn(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
|
||||
@@ -50,7 +50,8 @@ class BertModule(tf.Module):
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
],
|
||||
jit_compile=True
|
||||
)
|
||||
def learn(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
|
||||
@@ -57,7 +57,8 @@ class BertModule(tf.Module):
|
||||
shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32
|
||||
), # input2: segment_ids
|
||||
tf.TensorSpec([BATCH_SIZE], tf.int32), # input3: labels
|
||||
]
|
||||
],
|
||||
jit_compile=True
|
||||
)
|
||||
def learn(self, input_word_ids, input_mask, segment_ids, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
@@ -72,7 +73,7 @@ class BertModule(tf.Module):
|
||||
self.optimizer.apply_gradients(zip(gradients, variables))
|
||||
return loss
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def predict(self, input_word_ids, input_mask, segment_ids):
|
||||
inputs = [input_word_ids, input_mask, segment_ids]
|
||||
return self.m.predict(inputs)
|
||||
|
||||
@@ -53,7 +53,8 @@ class BertModule(tf.Module):
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
],
|
||||
jit_compile=True
|
||||
)
|
||||
def learn(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
|
||||
@@ -46,7 +46,8 @@ class BertModule(tf.Module):
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
],
|
||||
jit_compile=True
|
||||
)
|
||||
def learn(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
|
||||
@@ -52,7 +52,7 @@ class SeqClassification(tf.Module):
|
||||
)
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)[0]
|
||||
|
||||
@tf.function(input_signature=inputs_signature)
|
||||
@tf.function(input_signature=inputs_signature, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return tf.math.softmax(
|
||||
self.m.predict(input_ids, attention_mask), axis=-1
|
||||
|
||||
@@ -87,7 +87,7 @@ class TFHuggingFaceLanguage(tf.Module):
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=tf_bert_input)
|
||||
@tf.function(input_signature=tf_bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
@@ -162,7 +162,7 @@ class MaskedLM(tf.Module):
|
||||
)
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)[0]
|
||||
|
||||
@tf.function(input_signature=input_signature_maskedlm)
|
||||
@tf.function(input_signature=input_signature_maskedlm, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
@@ -191,7 +191,7 @@ class ResNetModule(tf.Module):
|
||||
self.m = tf_model
|
||||
self.m.predict = lambda x: self.m.call(x, training=False)
|
||||
|
||||
@tf.function(input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])
|
||||
@tf.function(input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)], jit_compile=True)
|
||||
def forward(self, inputs):
|
||||
return self.m.predict(inputs)
|
||||
|
||||
@@ -240,7 +240,7 @@ class AutoModelImageClassfication(tf.Module):
|
||||
)
|
||||
self.m.predict = lambda x: self.m(x)
|
||||
|
||||
@tf.function(input_signature=input_signature_img_cls)
|
||||
@tf.function(input_signature=input_signature_img_cls, jit_compile=True)
|
||||
def forward(self, inputs):
|
||||
return self.m.predict(inputs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user