Enable XLA compiler for TF models (#484)

This commit is contained in:
mariecwhite
2022-11-13 20:10:47 -08:00
committed by GitHub
parent 559928e93b
commit ec461a4456
20 changed files with 33 additions and 25 deletions

View File

@@ -42,7 +42,7 @@ class TFHuggingFaceLanguage(tf.Module):
input_ids=x, attention_mask=y, token_type_ids=z, training=False
)
-@tf.function(input_signature=tf_bert_input)
+@tf.function(input_signature=tf_bert_input, jit_compile=True)
def forward(self, input_ids, attention_mask, token_type_ids):
return self.m.predict(input_ids, attention_mask, token_type_ids)

View File

@@ -22,7 +22,7 @@ class CLIPModule(tf.Module):
input_ids=x, attention_mask=y, pixel_values=z
)
-@tf.function(input_signature=clip_vit_inputs)
+@tf.function(input_signature=clip_vit_inputs, jit_compile=True)
def forward(self, input_ids, attention_mask, pixel_values):
return self.m.predict(
input_ids, attention_mask, pixel_values

View File

@@ -28,7 +28,7 @@ class AlbertModule(tf.Module):
self.m = TFAutoModelForMaskedLM.from_pretrained("albert-base-v2")
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)
-@tf.function(input_signature=t5_inputs)
+@tf.function(input_signature=t5_inputs, jit_compile=True)
def forward(self, input_ids, attention_mask):
return self.m.predict(input_ids, attention_mask)

View File

@@ -19,7 +19,7 @@ class GPT2Module(tf.Module):
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)
-@tf.function(input_signature=gpt2_inputs)
+@tf.function(input_signature=gpt2_inputs, jit_compile=True)
def forward(self, input_ids, attention_mask):
return self.m.predict(input_ids, attention_mask)

View File

@@ -26,7 +26,7 @@ class BertModule(tf.Module):
input_ids=x, attention_mask=y, token_type_ids=z, training=False
)
-@tf.function(input_signature=bert_input)
+@tf.function(input_signature=bert_input, jit_compile=True)
def forward(self, input_ids, attention_mask, token_type_ids):
return self.m.predict(input_ids, attention_mask, token_type_ids)

View File

@@ -26,7 +26,7 @@ class BertModule(tf.Module):
input_ids=x, attention_mask=y, token_type_ids=z, training=False
)
-@tf.function(input_signature=bert_input)
+@tf.function(input_signature=bert_input, jit_compile=True)
def forward(self, input_ids, attention_mask, token_type_ids):
return self.m.predict(input_ids, attention_mask, token_type_ids)

View File

@@ -18,7 +18,7 @@ class T5Module(tf.Module):
self.m = TFT5Model.from_pretrained("t5-small")
self.m.predict = lambda x, y: self.m(input_ids=x, decoder_input_ids=y)
-@tf.function(input_signature=t5_inputs)
+@tf.function(input_signature=t5_inputs, jit_compile=True)
def forward(self, input_ids, decoder_input_ids):
return self.m.predict(input_ids, decoder_input_ids)

View File

@@ -52,7 +52,8 @@ class BertModule(tf.Module):
input_signature=[
bert_input, # inputs
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
-]
+],
+jit_compile=True
)
def forward(self, inputs, labels):
with tf.GradientTape() as tape:

View File

@@ -32,7 +32,7 @@ class BertModule(tf.Module):
input_ids=x, attention_mask=y, token_type_ids=z, training=False
)
-@tf.function(input_signature=bert_input)
+@tf.function(input_signature=bert_input, jit_compile=True)
def predict(self, input_word_ids, input_mask, segment_ids):
return self.m.predict(input_word_ids, input_mask, segment_ids)

View File

@@ -33,7 +33,7 @@ class BertModule(tf.Module):
input_ids=x, attention_mask=y, token_type_ids=z, training=False
)
-@tf.function(input_signature=bert_input)
+@tf.function(input_signature=bert_input, jit_compile=True)
def predict(self, input_ids, attention_mask, token_type_ids):
return self.m.predict(input_ids, attention_mask, token_type_ids)

View File

@@ -52,7 +52,7 @@ class SeqClassification(tf.Module):
)
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)[0]
-@tf.function(input_signature=inputs_signature)
+@tf.function(input_signature=inputs_signature, jit_compile=True)
def forward(self, input_ids, attention_mask):
return tf.math.softmax(
self.m.predict(input_ids, attention_mask), axis=-1

View File

@@ -72,7 +72,8 @@ class BertModule(tf.Module):
input_signature=[
bert_input, # inputs
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
-]
+],
+jit_compile=True
)
def learn(self, inputs, labels):
with tf.GradientTape() as tape:

View File

@@ -60,7 +60,8 @@ class BertModule(tf.Module):
shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32
), # input2: segment_ids
tf.TensorSpec([BATCH_SIZE], tf.int32), # input3: labels
-]
+],
+jit_compile=True
)
def learn(self, input_word_ids, input_mask, segment_ids, labels):
with tf.GradientTape() as tape:
@@ -75,7 +76,7 @@ class BertModule(tf.Module):
self.optimizer.apply_gradients(zip(gradients, variables))
return loss
-@tf.function(input_signature=bert_input)
+@tf.function(input_signature=bert_input, jit_compile=True)
def predict(self, input_word_ids, input_mask, segment_ids):
inputs = [input_word_ids, input_mask, segment_ids]
return self.m.predict(inputs)

View File

@@ -57,7 +57,8 @@ class BertModule(tf.Module):
input_signature=[
bert_input, # inputs
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
-]
+],
+jit_compile=True
)
def learn(self, inputs, labels):
with tf.GradientTape() as tape:

View File

@@ -50,7 +50,8 @@ class BertModule(tf.Module):
input_signature=[
bert_input, # inputs
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
-]
+],
+jit_compile=True
)
def learn(self, inputs, labels):
with tf.GradientTape() as tape:

View File

@@ -57,7 +57,8 @@ class BertModule(tf.Module):
shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32
), # input2: segment_ids
tf.TensorSpec([BATCH_SIZE], tf.int32), # input3: labels
-]
+],
+jit_compile=True
)
def learn(self, input_word_ids, input_mask, segment_ids, labels):
with tf.GradientTape() as tape:
@@ -72,7 +73,7 @@ class BertModule(tf.Module):
self.optimizer.apply_gradients(zip(gradients, variables))
return loss
-@tf.function(input_signature=bert_input)
+@tf.function(input_signature=bert_input, jit_compile=True)
def predict(self, input_word_ids, input_mask, segment_ids):
inputs = [input_word_ids, input_mask, segment_ids]
return self.m.predict(inputs)

View File

@@ -53,7 +53,8 @@ class BertModule(tf.Module):
input_signature=[
bert_input, # inputs
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
-]
+],
+jit_compile=True
)
def learn(self, inputs, labels):
with tf.GradientTape() as tape:

View File

@@ -46,7 +46,8 @@ class BertModule(tf.Module):
input_signature=[
bert_input, # inputs
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
-]
+],
+jit_compile=True
)
def learn(self, inputs, labels):
with tf.GradientTape() as tape:

View File

@@ -52,7 +52,7 @@ class SeqClassification(tf.Module):
)
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)[0]
-@tf.function(input_signature=inputs_signature)
+@tf.function(input_signature=inputs_signature, jit_compile=True)
def forward(self, input_ids, attention_mask):
return tf.math.softmax(
self.m.predict(input_ids, attention_mask), axis=-1

View File

@@ -87,7 +87,7 @@ class TFHuggingFaceLanguage(tf.Module):
input_ids=x, attention_mask=y, token_type_ids=z, training=False
)
-@tf.function(input_signature=tf_bert_input)
+@tf.function(input_signature=tf_bert_input, jit_compile=True)
def forward(self, input_ids, attention_mask, token_type_ids):
return self.m.predict(input_ids, attention_mask, token_type_ids)
@@ -162,7 +162,7 @@ class MaskedLM(tf.Module):
)
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)[0]
-@tf.function(input_signature=input_signature_maskedlm)
+@tf.function(input_signature=input_signature_maskedlm, jit_compile=True)
def forward(self, input_ids, attention_mask):
return self.m.predict(input_ids, attention_mask)
@@ -191,7 +191,7 @@ class ResNetModule(tf.Module):
self.m = tf_model
self.m.predict = lambda x: self.m.call(x, training=False)
-@tf.function(input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])
+@tf.function(input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)], jit_compile=True)
def forward(self, inputs):
return self.m.predict(inputs)
@@ -240,7 +240,7 @@ class AutoModelImageClassfication(tf.Module):
)
self.m.predict = lambda x: self.m(x)
-@tf.function(input_signature=input_signature_img_cls)
+@tf.function(input_signature=input_signature_img_cls, jit_compile=True)
def forward(self, inputs):
return self.m.predict(inputs)