mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 06:17:55 -05:00
Initialize TF models locally (#610)
This commit is contained in:
@@ -169,23 +169,15 @@ def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
|
||||
RESNET_INPUT_SHAPE = [1, 224, 224, 3]
|
||||
EFFICIENTNET_INPUT_SHAPE = [1, 384, 384, 3]
|
||||
|
||||
tf_resnet_model = tf.keras.applications.resnet50.ResNet50(
|
||||
weights="imagenet",
|
||||
include_top=True,
|
||||
input_shape=tuple(RESNET_INPUT_SHAPE[1:]),
|
||||
)
|
||||
|
||||
tf_efficientnet_model = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
|
||||
weights="imagenet",
|
||||
include_top=True,
|
||||
input_shape=tuple(EFFICIENTNET_INPUT_SHAPE[1:]),
|
||||
)
|
||||
|
||||
|
||||
class ResNetModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(ResNetModule, self).__init__()
|
||||
self.m = tf_resnet_model
|
||||
self.m = tf.keras.applications.resnet50.ResNet50(
|
||||
weights="imagenet",
|
||||
include_top=True,
|
||||
input_shape=tuple(RESNET_INPUT_SHAPE[1:]),
|
||||
)
|
||||
self.m.predict = lambda x: self.m.call(x, training=False)
|
||||
|
||||
@tf.function(
|
||||
@@ -205,7 +197,11 @@ class ResNetModule(tf.Module):
|
||||
class EfficientNetModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(EfficientNetModule, self).__init__()
|
||||
self.m = tf_efficientnet_model
|
||||
self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
|
||||
weights="imagenet",
|
||||
include_top=True,
|
||||
input_shape=tuple(EFFICIENTNET_INPUT_SHAPE[1:]),
|
||||
)
|
||||
self.m.predict = lambda x: self.m.call(x, training=False)
|
||||
|
||||
@tf.function(
|
||||
|
||||
Reference in New Issue
Block a user