mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 22:07:55 -05:00
36 lines
1.1 KiB
Python
36 lines
1.1 KiB
Python
from PIL import Image
|
|
import requests
|
|
|
|
from transformers import T5Tokenizer, TFT5Model
|
|
import tensorflow as tf
|
|
from amdshark.amdshark_inference import AMDSharkInference
|
|
|
|
# Input signature for T5Module.forward: two int32 tensors of shape [1, 10],
# bound positionally to (input_ids, decoder_input_ids) in that order.
t5_inputs = [tf.TensorSpec(shape=[1, 10], dtype=tf.int32) for _ in range(2)]
|
|
|
|
|
|
class T5Module(tf.Module):
    """tf.Module wrapper around the pretrained "t5-small" TFT5Model.

    The model is exposed through ``forward``, a ``tf.function`` whose input
    signature is the module-level ``t5_inputs`` list and which is built with
    ``jit_compile=True`` (XLA compilation requested).
    """

    def __init__(self):
        super().__init__()
        self.m = TFT5Model.from_pretrained("t5-small")

        def _predict(x, y):
            # Keyword call so encoder and decoder ids cannot be swapped.
            return self.m(input_ids=x, decoder_input_ids=y)

        # NOTE(review): this instance attribute shadows Keras' built-in
        # Model.predict; kept as-is in case callers outside this view use it.
        self.m.predict = _predict

    @tf.function(input_signature=t5_inputs, jit_compile=True)
    def forward(self, input_ids, decoder_input_ids):
        """Run the wrapped T5 model on one batch of token ids.

        Args:
          input_ids: encoder token ids matching ``t5_inputs[0]``.
          decoder_input_ids: decoder token ids matching ``t5_inputs[1]``.

        Returns:
          Whatever the wrapped TFT5Model call returns.
        """
        outputs = self.m.predict(input_ids, decoder_input_ids)
        return outputs
|
|
|
|
|
|
if __name__ == "__main__":
    # Prepping Data
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    text = "I love the distilled version of models."

    # T5Module.forward is traced with a fixed input signature (t5_inputs), so
    # the token ids must have exactly that sequence length. Pad/truncate to it
    # explicitly instead of relying on this sentence happening to tokenize to
    # the hard-coded length; when it already does, this is a no-op.
    # NOTE(review): no attention mask is passed to the model, so any padding
    # positions added here would be attended to.
    seq_len = t5_inputs[0].shape[1]
    inputs = tokenizer(
        text,
        return_tensors="tf",
        padding="max_length",
        max_length=seq_len,
        truncation=True,
    ).input_ids

    # NOTE(review): the same ids are fed as both encoder and decoder inputs.
    # That is enough to compile and run the module, but real T5 decoding starts
    # from the decoder start token rather than the encoder ids.
    amdshark_module = AMDSharkInference(T5Module(), (inputs, inputs))
    amdshark_module.set_frontend("tensorflow")
    amdshark_module.compile()
    print(amdshark_module.forward((inputs, inputs)))
|