Initialize TF models locally (#610)

This commit is contained in:
mariecwhite
2022-12-11 16:35:34 -08:00
committed by GitHub
parent 616ee9b824
commit eb8114ece8

View File

@@ -169,23 +169,15 @@ def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
RESNET_INPUT_SHAPE = [1, 224, 224, 3]
EFFICIENTNET_INPUT_SHAPE = [1, 384, 384, 3]
tf_resnet_model = tf.keras.applications.resnet50.ResNet50(
weights="imagenet",
include_top=True,
input_shape=tuple(RESNET_INPUT_SHAPE[1:]),
)
tf_efficientnet_model = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
weights="imagenet",
include_top=True,
input_shape=tuple(EFFICIENTNET_INPUT_SHAPE[1:]),
)
class ResNetModule(tf.Module):
def __init__(self):
super(ResNetModule, self).__init__()
self.m = tf_resnet_model
self.m = tf.keras.applications.resnet50.ResNet50(
weights="imagenet",
include_top=True,
input_shape=tuple(RESNET_INPUT_SHAPE[1:]),
)
self.m.predict = lambda x: self.m.call(x, training=False)
@tf.function(
@@ -205,7 +197,11 @@ class ResNetModule(tf.Module):
class EfficientNetModule(tf.Module):
def __init__(self):
super(EfficientNetModule, self).__init__()
self.m = tf_efficientnet_model
self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
weights="imagenet",
include_top=True,
input_shape=tuple(EFFICIENTNET_INPUT_SHAPE[1:]),
)
self.m.predict = lambda x: self.m.call(x, training=False)
@tf.function(