From d8952fc5756e4df0d3d5be458e7b3e2420f930de Mon Sep 17 00:00:00 2001 From: jaredeh Date: Wed, 13 Dec 2023 21:54:47 -0800 Subject: [PATCH] updating to work with new internal apis (#2755) --- examples/compile_tensorflow.py | 51 +++++++++++++++------------------- 1 file changed, 23 insertions(+), 28 deletions(-) diff --git a/examples/compile_tensorflow.py b/examples/compile_tensorflow.py index 43e5685b29..4fee2a41c2 100644 --- a/examples/compile_tensorflow.py +++ b/examples/compile_tensorflow.py @@ -2,15 +2,14 @@ import os, sys os.environ["CLANG"] = '1' -os.environ["GPU"] = '1' import numpy as np import subprocess import tensorflow as tf import tf2onnx -from examples.compile_efficientnet import compile_net from extra.onnx import get_run_onnx from tinygrad.tensor import Tensor +from extra.export_model import export_model_clang, compile_net, jit_model def get_uncompiled_model2(dataset_size=32, output_size=4): inputs = tf.keras.Input(shape=(dataset_size,), name="inputs") @@ -21,44 +20,40 @@ def get_uncompiled_model2(dataset_size=32, output_size=4): model = tf.keras.Model(inputs=inputs, outputs=outputs) return model -def create_onnx_model(keras_model): - input_signature = [tf.TensorSpec([1,32], tf.float32, name='x')] - onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13) - return onnx_model +class TinyOnnx: + def __init__(self, keras_model): + input_signature = [tf.TensorSpec([1,32], tf.float32, name='x')] + onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13) + self.run_onnx = get_run_onnx(onnx_model) + + def forward(self, x): + return self.run_onnx({"x": x}, debug=False)['predictions'] def compile_onnx_model(onnx_model): - run_onnx = get_run_onnx(onnx_model) - - from tinygrad.jit import TinyJit - @TinyJit - def run(x): return run_onnx({"x": x}, debug=False)['predictions'].realize() - + tinyonnx = TinyOnnx(onnx_model) the_input = Tensor.randn(1,32) - the_output = run(the_input) - the_output = run(the_input) - special_names = {id(the_input.lazydata.realized.cl): "input", id(the_output.lazydata.realized.cl): "outputs"} - cprog, statements, bufs, bufs_to_save = compile_net(run, special_names) - cprog = ["#include ", "#include ", "#include "] + cprog + run, special_names = jit_model(tinyonnx, the_input) - # buffers (all except input) - cprog += [f"float {x[0]}[{x[1]}];" for x in bufs.values() if x[0] != "input"] + functions, statements, bufs, bufs_to_save = compile_net(run, special_names) + prg = export_model_clang(functions, statements, bufs, {}, ["input0"], ["output0"]) + + the_output = run(the_input) + cprog = ["#include ", "#include ", "#include "] + cprog.append(prg) # weights cprog.append("void initialize(float *weights) {") weights = bytes() for name,cl in bufs_to_save.items(): - cprog.append(f"memcpy({name}, weights + {len(weights)//4}, {len(cl)});") - weights += bytes(memoryview(cl)[0:len(cl)//4]) + cprog.append(f"memcpy({name}, weights + {len(weights)//4}, {len(cl._buf)*4});") + weights += bytes(cl._buf) cprog.append("}") # write the weights to disk with open("/tmp/tf_weights", "wb") as f: f.write(weights) - # the net - cprog += ["float *infer(float *input) {"] + statements + ["return outputs;", "}"] - # test program cprog.append(f"""int main(int argc, char *argv[]) {{ // read in the weights from disk @@ -72,8 +67,9 @@ def compile_onnx_model(onnx_model): // test run float input[32]; + float outputs[4]; for (int i = 0; i < 32; i++) scanf("%f", &input[i]); - float *outputs = infer(input); + net(input, outputs); printf("%f %f %f %f\\n", outputs[0], outputs[1], outputs[2], outputs[3]); }}""") @@ -84,7 +80,7 @@ def compile_onnx_model(onnx_model): # add test weights subprocess.check_output(['clang', '-O2', '-lm', '-fPIC', '-x', 'c', '-', '-o', "/tmp/tf_test"], input=prg.encode('utf-8')) - tinygrad_output = [x for x in the_output.numpy()[0]] + tinygrad_output = the_output[0].numpy()[0].tolist() print("tinygrad:", tinygrad_output, file=sys.stderr) c_input = ' '.join(["%f" % x for x in the_input[0].numpy()])+"\n" @@ -96,8 +92,7 @@ def compile_onnx_model(onnx_model): if __name__ == "__main__": keras_model = get_uncompiled_model2() - onnx_model = create_onnx_model(keras_model) - test_input, test_output = compile_onnx_model(onnx_model) + test_input, test_output = compile_onnx_model(keras_model) tf_output = keras_model(test_input).numpy()[0] print("keras: ", tf_output, file=sys.stderr) np.testing.assert_allclose(tf_output, test_output, atol=1e-5, rtol=1e-5)