Files
SHARK-Studio/amdshark/examples/amdshark_inference/t5_tf.py
pdhirajkumarprasad fe03539901 Migration to AMDShark (#2182)
Signed-off-by: pdhirajkumarprasad <dhirajp@amd.com>
2025-11-20 12:52:07 +05:30

36 lines
1.1 KiB
Python

from PIL import Image
import requests
from transformers import T5Tokenizer, TFT5Model
import tensorflow as tf
from amdshark.amdshark_inference import AMDSharkInference
# Input signature for T5Module.forward: two int32 tensors of shape [1, 10],
# first the encoder ``input_ids`` and then the ``decoder_input_ids``.
t5_inputs = [tf.TensorSpec(shape=[1, 10], dtype=tf.int32) for _ in range(2)]
class T5Module(tf.Module):
    """tf.Module wrapper around the pretrained ``t5-small`` TFT5Model.

    ``forward`` is a ``tf.function`` with the fixed input signature
    ``t5_inputs`` and ``jit_compile=True``, giving AMDShark a single
    static-shape graph to import and compile.
    """

    def __init__(self):
        super().__init__()
        # Pretrained weights loaded through Hugging Face ``from_pretrained``.
        self.m = TFT5Model.from_pretrained("t5-small")

    @tf.function(input_signature=t5_inputs, jit_compile=True)
    def forward(self, input_ids, decoder_input_ids):
        """Run one encoder/decoder pass of T5.

        Args:
            input_ids: int32 encoder token ids matching ``t5_inputs[0]``.
            decoder_input_ids: int32 decoder token ids matching
                ``t5_inputs[1]``.

        Returns:
            Whatever ``TFT5Model`` returns for these inputs (a model-output
            object; exact type depends on the installed transformers version).
        """
        # Call the model directly. The previous version overwrote Keras'
        # ``Model.predict`` on ``self.m`` with a lambda just to forward these
        # two arguments, which shadowed the real batch-inference API.
        return self.m(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
if __name__ == "__main__":
    # Prepping Data
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    text = "I love the distilled version of models."
    # ``forward`` is compiled for a fixed sequence length (see t5_inputs), so
    # pad/truncate the tokenized text to exactly that length instead of
    # relying on the sentence happening to produce the right number of ids.
    seq_len = t5_inputs[0].shape[1]
    inputs = tokenizer(
        text,
        return_tensors="tf",
        padding="max_length",
        truncation=True,
        max_length=seq_len,
    ).input_ids
    # For this demo the same ids serve as both encoder and decoder input.
    amdshark_module = AMDSharkInference(T5Module(), (inputs, inputs))
    amdshark_module.set_frontend("tensorflow")
    amdshark_module.compile()
    print(amdshark_module.forward((inputs, inputs)))