From eade3647d04fa3db2697c21fd1a8991a5dc4cdbd Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Tue, 19 Sep 2023 19:51:54 +0100 Subject: [PATCH] feat: merge compiled-model and circuit to reduce pipeline errors (#495) --- examples/notebooks/data_attest.ipynb | 6 +- examples/notebooks/decision_tree.ipynb | 6 +- examples/notebooks/encrypted_vis.ipynb | 11 +- examples/notebooks/ezkl_demo.ipynb | 17 +- .../notebooks/gradient_boosted_trees.ipynb | 6 +- examples/notebooks/hashed_vis.ipynb | 8 +- examples/notebooks/keras_simple_demo.ipynb | 6 +- examples/notebooks/lightgbm.ipynb | 6 +- examples/notebooks/little_transformer.ipynb | 21 +- examples/notebooks/lstm.ipynb | 8 +- examples/notebooks/mean_postgres.ipynb | 5 +- examples/notebooks/mnist_gan.ipynb | 8 +- examples/notebooks/mnist_vae.ipynb | 17 +- .../nbeats_timeseries_forecasting.ipynb | 6 +- examples/notebooks/random_forest.ipynb | 6 +- examples/notebooks/set_membership.ipynb | 6 +- examples/notebooks/simple_demo.ipynb | 6 +- .../simple_demo_aggregated_proofs.ipynb | 6 +- examples/notebooks/svm.ipynb | 6 +- examples/notebooks/variance.ipynb | 6 +- examples/notebooks/voice_judge.ipynb | 8 +- examples/notebooks/xgboost.ipynb | 6 +- src/commands.rs | 35 +--- src/execute.rs | 165 +++++----------- src/graph/mod.rs | 179 ++++++++++-------- src/graph/model.rs | 5 +- src/python.rs | 37 +--- src/wasm.rs | 41 +--- tests/integration_tests.rs | 130 +++---------- tests/output_comparison.py | 2 +- tests/python/binding_tests.py | 33 +--- tests/wasm.rs | 1 - tests/wasm/settings.json | 6 +- tests/wasm/testWasm.test.ts | 2 +- tests/wasm/test_network.compiled | Bin 1201 -> 2179 bytes 35 files changed, 299 insertions(+), 518 deletions(-) diff --git a/examples/notebooks/data_attest.ipynb b/examples/notebooks/data_attest.ipynb index 620db479..9c903360 100644 --- a/examples/notebooks/data_attest.ipynb +++ b/examples/notebooks/data_attest.ipynb @@ -404,7 +404,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -445,7 +445,7 @@ "\n", "witness_path = \"witness.json\"\n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)" + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)" ] }, { @@ -471,7 +471,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -506,7 +505,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/decision_tree.ipynb b/examples/notebooks/decision_tree.ipynb index d3c4ca8c..2cb55197 100644 --- a/examples/notebooks/decision_tree.ipynb +++ b/examples/notebooks/decision_tree.ipynb @@ -161,7 +161,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -185,7 +185,7 @@ "source": [ "# now generate the witness file \n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -209,7 +209,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -238,7 +237,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/encrypted_vis.ipynb b/examples/notebooks/encrypted_vis.ipynb index 25113537..340d447d 100644 --- a/examples/notebooks/encrypted_vis.ipynb +++ b/examples/notebooks/encrypted_vis.ipynb @@ -290,7 +290,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -331,7 +331,7 @@ "\n", "witness_path = \"witness.json\"\n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)" + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)" ] }, { @@ -350,7 +350,7 @@ "source": [ "\n", "\n", - "res = ezkl.mock(witness_path, compiled_model_path, settings_path)" + "res = ezkl.mock(witness_path, compiled_model_path)" ] }, { @@ -376,9 +376,9 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", + "\n", "assert res == True\n", "assert os.path.isfile(vk_path)\n", "assert os.path.isfile(pk_path)\n", @@ -411,9 +411,9 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", + "\n", "print(res)\n", "assert os.path.isfile(proof_path)" ] @@ -556,7 +556,6 @@ " srs_path,\n", " \"poseidon\",\n", " \"accum\",\n", - " settings_path,\n", " )" ] }, diff --git a/examples/notebooks/ezkl_demo.ipynb b/examples/notebooks/ezkl_demo.ipynb index 4d5593dd..0e72d990 100644 --- a/examples/notebooks/ezkl_demo.ipynb +++ b/examples/notebooks/ezkl_demo.ipynb @@ -451,7 +451,7 @@ }, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -497,8 +497,8 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", + " \n", "\n", "assert res == True\n", "assert os.path.isfile(vk_path)\n", @@ -528,7 +528,7 @@ "# now generate the witness file\n", "witness_path = os.path.join('witness.json')\n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -556,7 +556,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(proof)\n", @@ -654,10 +653,10 @@ }, "outputs": [], "source": [ - " sol_code_path = os.path.join('Verifier.sol')\n", - " abi_path = os.path.join('Verifier.abi')\n", + "sol_code_path = os.path.join('Verifier.sol')\n", + "abi_path = os.path.join('Verifier.abi')\n", "\n", - " res = ezkl.create_evm_verifier(\n", + "res = ezkl.create_evm_verifier(\n", " vk_path,\n", " srs_path,\n", " settings_path,\n", @@ -665,8 +664,8 @@ " abi_path\n", " )\n", "\n", - " assert res == True\n", - " assert os.path.isfile(sol_code_path)" + "assert res == True\n", + "assert os.path.isfile(sol_code_path)" ] }, { diff --git a/examples/notebooks/gradient_boosted_trees.ipynb b/examples/notebooks/gradient_boosted_trees.ipynb index 83268014..260713c1 100644 --- a/examples/notebooks/gradient_boosted_trees.ipynb +++ b/examples/notebooks/gradient_boosted_trees.ipynb @@ -191,7 +191,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -215,7 +215,7 @@ "source": [ "# now generate the witness file \n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -239,7 +239,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -268,7 +267,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/hashed_vis.ipynb b/examples/notebooks/hashed_vis.ipynb index 9d9dd1de..9c59948d 100644 --- a/examples/notebooks/hashed_vis.ipynb +++ b/examples/notebooks/hashed_vis.ipynb @@ -273,7 +273,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -322,7 +322,7 @@ "\n", "witness_path = \"witness.json\"\n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)" + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)" ] }, { @@ -355,7 +355,7 @@ "source": [ "\n", "\n", - "res = ezkl.mock(witness_path, compiled_model_path, settings_path)" + "res = ezkl.mock(witness_path, compiled_model_path)" ] }, { @@ -403,7 +403,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -464,7 +463,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/keras_simple_demo.ipynb b/examples/notebooks/keras_simple_demo.ipynb index fa62b595..40f613f9 100644 --- a/examples/notebooks/keras_simple_demo.ipynb +++ b/examples/notebooks/keras_simple_demo.ipynb @@ -144,7 +144,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -169,7 +169,7 @@ "# now generate the witness file \n", "witness_path = \"witness.json\"\n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -191,7 +191,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -220,7 +219,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/lightgbm.ipynb b/examples/notebooks/lightgbm.ipynb index d19bf002..534fb1d6 100644 --- a/examples/notebooks/lightgbm.ipynb +++ b/examples/notebooks/lightgbm.ipynb @@ -187,7 +187,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -211,7 +211,7 @@ "source": [ "# now generate the witness file \n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -235,7 +235,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -264,7 +263,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/little_transformer.ipynb b/examples/notebooks/little_transformer.ipynb index bec4d18e..58ac5033 100644 --- a/examples/notebooks/little_transformer.ipynb +++ b/examples/notebooks/little_transformer.ipynb @@ -10,6 +10,16 @@ "## Model Architecture and training" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "c22afe46", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install pytorch_lightning\n" + ] + }, { "cell_type": "code", "execution_count": null, @@ -17,9 +27,6 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install pytorch_lightning\n", - "\n", - "\n", "import random\n", "import math\n", "import numpy as np\n", @@ -324,7 +331,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -349,7 +356,7 @@ "# now generate the witness file \n", "witness_path = \"gan_witness.json\"\n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -360,7 +367,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.mock(witness_path, compiled_model_path, settings_path)\n", + "res = ezkl.mock(witness_path, compiled_model_path)\n", "assert res == True" ] }, @@ -382,7 +389,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -411,7 +417,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/lstm.ipynb b/examples/notebooks/lstm.ipynb index 02a2e9ce..82574a7c 100644 --- a/examples/notebooks/lstm.ipynb +++ b/examples/notebooks/lstm.ipynb @@ -148,7 +148,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -173,7 +173,7 @@ "# now generate the witness file \n", "witness_path = \"lstmwitness.json\"\n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -184,7 +184,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.mock(witness_path, compiled_model_path, settings_path)\n", + "res = ezkl.mock(witness_path, compiled_model_path)\n", "assert res == True" ] }, @@ -206,7 +206,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -235,7 +234,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/mean_postgres.ipynb b/examples/notebooks/mean_postgres.ipynb index f37916d5..5ce3ee41 100644 --- a/examples/notebooks/mean_postgres.ipynb +++ b/examples/notebooks/mean_postgres.ipynb @@ -266,7 +266,7 @@ "outputs": [], "source": [ "\n", - "ezkl.compile_model(onnx_filename, compiled_filename, settings_filename)" + "ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)" ] }, { @@ -334,7 +334,7 @@ "\n", "witness_path = \"witness.json\"\n", "\n", - "res = ezkl.gen_witness(input_filename, compiled_filename, witness_path, settings_path = settings_filename)\n", + "res = ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -363,7 +363,6 @@ " params_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_filename,\n", " )\n", "\n", "\n", diff --git a/examples/notebooks/mnist_gan.ipynb b/examples/notebooks/mnist_gan.ipynb index 4014119b..0b94c1ad 100644 --- a/examples/notebooks/mnist_gan.ipynb +++ b/examples/notebooks/mnist_gan.ipynb @@ -277,7 +277,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -300,7 +300,7 @@ "# now generate the witness file \n", "witness_path = \"gan_witness.json\"\n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -311,7 +311,7 @@ "outputs": [], "source": [ "# uncomment to mock prove\n", - "# res = ezkl.mock(witness_path, compiled_model_path, settings_path)\n", + "# res = ezkl.mock(witness_path, compiled_model_path)\n", "# assert res == True" ] }, @@ -332,7 +332,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -360,7 +359,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/mnist_vae.ipynb b/examples/notebooks/mnist_vae.ipynb index 98b246cd..c4e03ad8 100644 --- a/examples/notebooks/mnist_vae.ipynb +++ b/examples/notebooks/mnist_vae.ipynb @@ -208,7 +208,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -231,7 +231,7 @@ "# now generate the witness file\n", "witness_path = \"ae_witness.json\"\n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -241,7 +241,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.mock(witness_path, compiled_model_path, settings_path)\n", + "res = ezkl.mock(witness_path, compiled_model_path)\n", "assert res == True" ] }, @@ -262,7 +262,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -290,7 +289,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", @@ -447,7 +445,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -470,7 +468,7 @@ "# now generate the witness file \n", "witness_path = \"vae_witness.json\"\n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -481,7 +479,7 @@ "outputs": [], "source": [ "# uncomment to mock prove\n", - "# res = ezkl.mock(witness_path, compiled_model_path, settings_path)\n", + "# res = ezkl.mock(witness_path, compiled_model_path)\n", "# assert res == True" ] }, @@ -502,9 +500,9 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", + "\n", "assert res == True\n", "assert os.path.isfile(vk_path)\n", "assert os.path.isfile(pk_path)\n", @@ -530,7 +528,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/nbeats_timeseries_forecasting.ipynb b/examples/notebooks/nbeats_timeseries_forecasting.ipynb index 10e3a742..c1fc853d 100644 --- a/examples/notebooks/nbeats_timeseries_forecasting.ipynb +++ b/examples/notebooks/nbeats_timeseries_forecasting.ipynb @@ -851,7 +851,7 @@ }, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -875,7 +875,7 @@ }, "outputs": [], "source": [ - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -892,7 +892,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -919,7 +918,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/random_forest.ipynb b/examples/notebooks/random_forest.ipynb index 24701bbc..b941daa3 100644 --- a/examples/notebooks/random_forest.ipynb +++ b/examples/notebooks/random_forest.ipynb @@ -187,7 +187,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -211,7 +211,7 @@ "source": [ "# now generate the witness file \n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -235,7 +235,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -264,7 +263,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/set_membership.ipynb b/examples/notebooks/set_membership.ipynb index 66f4e923..37152ef7 100644 --- a/examples/notebooks/set_membership.ipynb +++ b/examples/notebooks/set_membership.ipynb @@ -181,7 +181,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -205,7 +205,7 @@ "source": [ "# now generate the witness file \n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -238,7 +238,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -276,7 +275,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/simple_demo.ipynb b/examples/notebooks/simple_demo.ipynb index fcd08783..5ad6b983 100644 --- a/examples/notebooks/simple_demo.ipynb +++ b/examples/notebooks/simple_demo.ipynb @@ -158,7 +158,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -182,7 +182,7 @@ "source": [ "# now generate the witness file \n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -206,7 +206,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -235,7 +234,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/simple_demo_aggregated_proofs.ipynb b/examples/notebooks/simple_demo_aggregated_proofs.ipynb index 44c93501..a7a1a303 100644 --- a/examples/notebooks/simple_demo_aggregated_proofs.ipynb +++ b/examples/notebooks/simple_demo_aggregated_proofs.ipynb @@ -164,7 +164,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -188,7 +188,7 @@ "source": [ "# now generate the witness file \n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -212,7 +212,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -241,7 +240,6 @@ " srs_path,\n", " \"poseidon\", # IMPORTANT NOTE: To produce an aggregated EVM proof you will want to use poseidon for the smaller proofs\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/svm.ipynb b/examples/notebooks/svm.ipynb index c1c26901..7f04c152 100644 --- a/examples/notebooks/svm.ipynb +++ b/examples/notebooks/svm.ipynb @@ -174,7 +174,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -198,7 +198,7 @@ "source": [ "# now generate the witness file \n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -222,7 +222,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -251,7 +250,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/variance.ipynb b/examples/notebooks/variance.ipynb index 8fe72096..695d3fb9 100644 --- a/examples/notebooks/variance.ipynb +++ b/examples/notebooks/variance.ipynb @@ -250,7 +250,7 @@ "ezkl.gen_settings(onnx_filename, settings_filename)\n", "await ezkl.calibrate_settings(\n", " input_filename, onnx_filename, settings_filename, \"resources\")\n", - "ezkl.compile_model(onnx_filename, compiled_filename, settings_filename)\n", + "ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)\n", "\n", "# show the settings.json\n", "with open(\"settings.json\") as f:\n", @@ -283,7 +283,6 @@ " vk_path,\n", " pk_path,\n", " params_path,\n", - " settings_filename,\n", " )\n", "\n", "assert res == True\n", @@ -302,7 +301,7 @@ "\n", "witness_path = \"witness.json\"\n", "\n", - "res = ezkl.gen_witness(input_filename, compiled_filename, witness_path, settings_path = settings_filename)\n", + "res = ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -331,7 +330,6 @@ " params_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_filename,\n", " )\n", "\n", "\n", diff --git a/examples/notebooks/voice_judge.ipynb b/examples/notebooks/voice_judge.ipynb index 9baa8283..ee764c45 100644 --- a/examples/notebooks/voice_judge.ipynb +++ b/examples/notebooks/voice_judge.ipynb @@ -648,7 +648,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -690,7 +690,7 @@ "\n", "witness_path = \"witness.json\"\n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -710,7 +710,7 @@ "source": [ "\n", "\n", - "res = ezkl.mock(witness_path, compiled_model_path, settings_path)" + "res = ezkl.mock(witness_path, compiled_model_path)" ] }, { @@ -737,7 +737,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -772,7 +771,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/examples/notebooks/xgboost.ipynb b/examples/notebooks/xgboost.ipynb index d6024866..1b4a805f 100644 --- a/examples/notebooks/xgboost.ipynb +++ b/examples/notebooks/xgboost.ipynb @@ -188,7 +188,7 @@ "metadata": {}, "outputs": [], "source": [ - "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", "assert res == True" ] }, @@ -212,7 +212,7 @@ "source": [ "# now generate the witness file \n", "\n", - "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", "assert os.path.isfile(witness_path)" ] }, @@ -236,7 +236,6 @@ " vk_path,\n", " pk_path,\n", " srs_path,\n", - " settings_path,\n", " )\n", "\n", "assert res == True\n", @@ -265,7 +264,6 @@ " srs_path,\n", " \"evm\",\n", " \"single\",\n", - " settings_path,\n", " )\n", "\n", "print(res)\n", diff --git a/src/commands.rs b/src/commands.rs index 2ab44f9d..33e56481 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -233,13 +233,10 @@ pub enum Commands { data: PathBuf, /// The path to the compiled model file #[arg(short = 'M', long)] - compiled_model: PathBuf, + compiled_circuit: PathBuf, /// Path to the witness (public and private inputs) .json file #[arg(short = 'O', long, default_value = "witness.json")] output: PathBuf, - /// Path to circuit_settings .json file to read in - #[arg(short = 'S', long)] - settings_path: PathBuf, }, /// Produces the proving hyperparameters, from run-args @@ -314,9 +311,6 @@ pub enum Commands { /// The path to the .onnx model file #[arg(short = 'M', long)] model: PathBuf, - /// circuit params path - #[arg(short = 'S', long)] - settings_path: PathBuf, }, /// Mock aggregate proofs @@ -381,13 +375,13 @@ pub enum Commands { }, /// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements #[command(arg_required_else_help = true)] - CompileModel { + CompileCircuit { /// The path to the .onnx model file #[arg(short = 'M', long)] model: PathBuf, /// The path to output the processed model #[arg(long)] - compiled_model: PathBuf, + compiled_circuit: PathBuf, /// The path to load circuit params from #[arg(short = 'S', long)] settings_path: PathBuf, @@ -397,7 +391,7 @@ pub enum Commands { Setup { /// The path to the compiled model file #[arg(short = 'M', long)] - compiled_model: PathBuf, + compiled_circuit: PathBuf, /// The srs path #[arg(long)] srs_path: PathBuf, @@ -407,9 +401,6 @@ pub enum Commands { /// The path to output the proving key file #[arg(long, default_value = "pk.key")] pk_path: PathBuf, - /// The path to load circuit params from - #[arg(short = 'S', long)] - settings_path: PathBuf, }, #[cfg(not(target_arch = "wasm32"))] @@ -421,7 +412,7 @@ pub enum Commands { witness: PathBuf, /// The path to the processed model file #[arg(short = 'M', long)] - compiled_model: PathBuf, + compiled_circuit: PathBuf, #[arg( long, require_equals = true, @@ -430,15 +421,9 @@ pub enum Commands { value_enum )] transcript: TranscriptType, - /// proving arguments - #[clap(flatten)] - args: RunArgs, /// number of fuzz iterations #[arg(long, default_value = "10")] num_runs: usize, - /// optional circuit params path (overrides any run args set) - #[arg(short = 'S', long)] - settings_path: Option, }, #[cfg(not(target_arch = "wasm32"))] SetupTestEVMData { @@ -447,10 +432,7 @@ pub enum Commands { data: PathBuf, /// The path to the compiled model file #[arg(short = 'M', long)] - compiled_model: PathBuf, - /// The path to load circuit params from - #[arg(long)] - settings_path: PathBuf, + compiled_circuit: PathBuf, /// For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information /// derived from the file information in the data .json file. /// Should include both the network input (possibly private) and the network output (public input to the proof) @@ -476,7 +458,7 @@ pub enum Commands { witness: PathBuf, /// The path to the compiled model file #[arg(short = 'M', long)] - compiled_model: PathBuf, + compiled_circuit: PathBuf, /// The path to load the desired proving key file #[arg(long)] pk_path: PathBuf, @@ -503,9 +485,6 @@ pub enum Commands { value_enum )] strategy: StrategyType, - /// The path to load circuit params from - #[arg(short = 'S', long)] - settings_path: PathBuf, /// run sanity checks during calculations (safe or unsafe) #[arg(long, default_value = "safe")] check_mode: CheckMode, diff --git a/src/execute.rs b/src/execute.rs index 0663d031..f0b5b4da 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -116,23 +116,10 @@ pub async fn run(cli: Cli) -> Result<(), Box> { #[cfg(not(target_arch = "wasm32"))] Commands::Fuzz { witness, - compiled_model, + compiled_circuit, transcript, - args, num_runs, - settings_path, - } => { - fuzz( - compiled_model, - args.logrows, - witness, - transcript, - num_runs, - args, - settings_path, - ) - .await - } + } => fuzz(compiled_circuit, witness, transcript, num_runs).await, Commands::GenSrs { srs_path, logrows } => gen_srs_cmd(srs_path, logrows as u32), #[cfg(not(target_arch = "wasm32"))] @@ -164,17 +151,12 @@ pub async fn run(cli: Cli) -> Result<(), Box> { } => calibrate(model, data, settings_path, target, scales).await, Commands::GenWitness { data, - compiled_model, + compiled_circuit, output, - settings_path, - } => gen_witness(compiled_model, data, Some(output), settings_path) + } => gen_witness(compiled_circuit, data, Some(output)) .await .map(|_| ()), - Commands::Mock { - model, - witness, - settings_path, - } => mock(model, witness, settings_path).await, + Commands::Mock { model, witness } => mock(model, witness).await, #[cfg(not(target_arch = "wasm32"))] Commands::CreateEVMVerifier { vk_path, @@ -213,23 +195,21 @@ pub async fn run(cli: Cli) -> Result<(), Box> { abi_path, aggregation_settings, ), - Commands::CompileModel { + Commands::CompileCircuit { model, - compiled_model, + compiled_circuit, settings_path, - } => compile_model(model, compiled_model, settings_path), + } => compile_circuit(model, compiled_circuit, settings_path), Commands::Setup { - compiled_model, + compiled_circuit, srs_path, - settings_path, vk_path, pk_path, - } => setup(compiled_model, srs_path, settings_path, vk_path, pk_path), + } => setup(compiled_circuit, srs_path, vk_path, pk_path), #[cfg(not(target_arch = "wasm32"))] Commands::SetupTestEVMData { data, - compiled_model, - settings_path, + compiled_circuit, test_data, rpc_url, input_source, @@ -237,8 +217,7 @@ pub async fn run(cli: Cli) -> Result<(), Box> { } => { setup_test_evm_witness( data, - compiled_model, - settings_path, + compiled_circuit, test_data, rpc_url, input_source, @@ -249,23 +228,21 @@ pub async fn run(cli: Cli) -> Result<(), Box> { #[cfg(not(target_arch = "wasm32"))] Commands::Prove { witness, - compiled_model, + compiled_circuit, pk_path, proof_path, srs_path, transcript, strategy, - settings_path, check_mode, } => prove( witness, - compiled_model, + compiled_circuit, pk_path, Some(proof_path), srs_path, transcript, strategy, - settings_path, check_mode, ) .await @@ -426,20 +403,13 @@ pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<(), Box, - settings_path: PathBuf, ) -> Result> { // these aren't real values so the sanity checks are mostly meaningless - let circuit_settings = GraphSettings::load(&settings_path)?; - - let mut circuit = GraphCircuit::preprocessed_from_settings( - &circuit_settings, - &compiled_model_path, - CheckMode::UNSAFE, - )?; + let mut circuit = GraphCircuit::load(compiled_circuit_path)?; let data = GraphData::from_path(data)?; #[cfg(not(target_arch = "wasm32"))] @@ -454,7 +424,8 @@ pub(crate) async fn gen_witness( // print each variable tuple (symbol, value) as symbol=value trace!( "witness generation {:?} took {:?}", - circuit_settings + circuit + .settings() .run_args .variables .iter() @@ -476,7 +447,7 @@ pub(crate) fn gen_circuit_settings( run_args: RunArgs, ) -> Result<(), Box> { let circuit = GraphCircuit::from_run_args(&run_args, &model_path)?; - let params = circuit.settings; + let params = circuit.settings(); params.save(¶ms_output).map_err(Box::::from) } @@ -628,21 +599,23 @@ pub(crate) async fn calibrate( .calibrate(&data) .map_err(|e| format!("failed to calibrate: {}", e))?; + let settings = circuit.settings().clone(); + let found_run_args = RunArgs { - input_scale: circuit.settings.run_args.input_scale, - param_scale: circuit.settings.run_args.param_scale, - bits: circuit.settings.run_args.bits, - logrows: circuit.settings.run_args.logrows, - scale_rebase_multiplier: circuit.settings.run_args.scale_rebase_multiplier, + input_scale: settings.run_args.input_scale, + param_scale: settings.run_args.param_scale, + bits: settings.run_args.bits, + logrows: settings.run_args.logrows, + scale_rebase_multiplier: settings.run_args.scale_rebase_multiplier, ..run_args.clone() }; let found_settings = GraphSettings { run_args: found_run_args, - required_lookups: circuit.settings.required_lookups, - model_output_scales: circuit.settings.model_output_scales, - num_constraints: circuit.settings.num_constraints, - total_const_size: circuit.settings.total_const_size, + required_lookups: settings.required_lookups, + model_output_scales: settings.model_output_scales, + num_constraints: settings.num_constraints, + total_const_size: settings.total_const_size, ..original_settings.clone() }; @@ -775,17 +748,11 @@ pub(crate) async fn calibrate( } pub(crate) async fn mock( - compiled_model_path: PathBuf, + compiled_circuit_path: PathBuf, data_path: PathBuf, - settings_path: PathBuf, ) -> Result<(), Box> { // mock should catch any issues by default so we set it to safe - let circuit_settings = GraphSettings::load(&settings_path)?; - let mut circuit = GraphCircuit::preprocessed_from_settings( - &circuit_settings, - &compiled_model_path, - CheckMode::SAFE, - )?; + let mut circuit = GraphCircuit::load(compiled_circuit_path)?; let data = GraphWitness::from_path(data_path)?; @@ -796,7 +763,7 @@ pub(crate) async fn mock( info!("Mock proof"); let prover = halo2_proofs::dev::MockProver::run( - circuit.settings.run_args.logrows, + circuit.settings().run_args.logrows, &circuit, public_inputs, ) @@ -1082,32 +1049,27 @@ pub(crate) fn create_evm_aggregate_verifier( Ok(()) } -pub(crate) fn compile_model( +pub(crate) fn compile_circuit( model_path: PathBuf, - compiled_model: PathBuf, + compiled_circuit: PathBuf, settings_path: PathBuf, ) -> Result<(), Box> { let settings = GraphSettings::load(&settings_path)?; - let model = Model::from_run_args(&settings.run_args, &model_path)?; - model.save(compiled_model)?; + let circuit = GraphCircuit::from_settings(&settings, &model_path, CheckMode::UNSAFE)?; + circuit.save(compiled_circuit)?; Ok(()) } pub(crate) fn setup( - compiled_model: PathBuf, + compiled_circuit: PathBuf, srs_path: PathBuf, - settings_path: PathBuf, vk_path: PathBuf, pk_path: PathBuf, ) -> Result<(), Box> { // these aren't real values so the sanity checks are mostly meaningless - let circuit_settings = GraphSettings::load(&settings_path)?; - let circuit = GraphCircuit::preprocessed_from_settings( - &circuit_settings, - &compiled_model, - CheckMode::UNSAFE, - )?; - let params = load_params_cmd(srs_path, circuit_settings.run_args.logrows)?; + let circuit = GraphCircuit::load(compiled_circuit)?; + + let params = load_params_cmd(srs_path, circuit.settings().run_args.logrows)?; let pk = create_keys::, Fr, GraphCircuit>(&circuit, ¶ms) .map_err(Box::::from)?; @@ -1120,8 +1082,7 @@ pub(crate) fn setup( #[cfg(not(target_arch = "wasm32"))] pub(crate) async fn setup_test_evm_witness( data_path: PathBuf, - compiled_model_path: PathBuf, - settings_path: PathBuf, + compiled_circuit_path: PathBuf, test_data: PathBuf, rpc_url: Option, input_source: TestDataSource, @@ -1131,12 +1092,7 @@ pub(crate) async fn setup_test_evm_witness( info!("run this command in background to keep the instance running for testing"); let mut data = GraphData::from_path(data_path)?; - let circuit_settings = GraphSettings::load(&settings_path)?; - let mut circuit = GraphCircuit::preprocessed_from_settings( - &circuit_settings, - &compiled_model_path, - CheckMode::SAFE, - )?; + let mut circuit = GraphCircuit::load(compiled_circuit_path)?; // if both input and output are from files fail if matches!(input_source, TestDataSource::File) && matches!(output_source, TestDataSource::File) @@ -1164,27 +1120,21 @@ pub(crate) async fn setup_test_evm_witness( #[allow(clippy::too_many_arguments)] pub(crate) async fn prove( data_path: PathBuf, - compiled_model_path: PathBuf, + compiled_circuit_path: PathBuf, pk_path: PathBuf, proof_path: Option, srs_path: PathBuf, transcript: TranscriptType, strategy: StrategyType, - settings_path: PathBuf, check_mode: CheckMode, ) -> Result, Box> { let data = GraphWitness::from_path(data_path)?; - let circuit_settings = GraphSettings::load(&settings_path)?; - let mut circuit = GraphCircuit::preprocessed_from_settings( - &circuit_settings, - &compiled_model_path, - check_mode, - )?; + let mut circuit = GraphCircuit::load(compiled_circuit_path)?; circuit.load_graph_witness(&data)?; let public_inputs = circuit.prepare_public_inputs(&data)?; - let circuit_settings = circuit.settings.clone(); + let circuit_settings = circuit.settings().clone(); let params = load_params_cmd(srs_path, circuit_settings.run_args.logrows)?; @@ -1230,35 +1180,24 @@ pub(crate) async fn prove( #[cfg(not(target_arch = "wasm32"))] pub(crate) async fn fuzz( - compiled_model_path: PathBuf, - logrows: u32, + compiled_circuit_path: PathBuf, data_path: PathBuf, transcript: TranscriptType, num_runs: usize, - run_args: RunArgs, - settings_path: Option, ) -> Result<(), Box> { check_solc_requirement(); let passed = AtomicBool::new(true); + // these aren't real values so the sanity checks are mostly meaningless + let mut circuit = GraphCircuit::load(compiled_circuit_path)?; + let logrows = circuit.settings().run_args.logrows; + info!("setting up tests"); let _r = Gag::stdout().unwrap(); let params = gen_srs::>(logrows); let data = GraphWitness::from_path(data_path)?; - // these aren't real values so the sanity checks are mostly meaningless - let mut circuit = match settings_path { - Some(path) => { - let circuit_settings = GraphSettings::load(&path)?; - GraphCircuit::preprocessed_from_settings( - &circuit_settings, - &compiled_model_path, - CheckMode::UNSAFE, - )? - } - None => GraphCircuit::preprocessed_from_run_args(&run_args, &compiled_model_path)?, - }; let pk = create_keys::, Fr, GraphCircuit>(&circuit, ¶ms) .map_err(Box::::from)?; @@ -1418,7 +1357,7 @@ pub(crate) async fn fuzz( run_fuzz_fn(num_runs, fuzz_proof_instances, &passed); if matches!(transcript, TranscriptType::EVM) { - let num_instance = circuit.settings.total_instances(); + let num_instance = circuit.settings().total_instances(); let yul_code = gen_evm_verifier(¶ms, pk.get_vk(), num_instance)?; let deployment_code = gen_deployment_code(yul_code).unwrap(); diff --git a/src/graph/mod.rs b/src/graph/mod.rs index f128b346..3c56873e 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -391,18 +391,59 @@ pub struct GraphConfig { } /// Defines the circuit for a computational graph / model loaded from a `.onnx` file. -#[derive(Clone, Debug, Default, Serialize)] -pub struct GraphCircuit { +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct CoreCircuit { /// The model / graph of computations. pub model: Model, - /// Vector of input tensors to the model / graph of computations. - pub graph_witness: GraphWitness, - /// The settings of the model / graph of computations. + /// The settings of the model. pub settings: GraphSettings, +} + +/// Defines the circuit for a computational graph / model loaded from a `.onnx` file. +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct GraphCircuit { + /// Core circuit + pub core: CoreCircuit, + /// The witness data for the model. + pub graph_witness: GraphWitness, /// The settings of the model's modules. pub module_settings: ModuleSettings, } +impl GraphCircuit { + /// Settings for the graph + pub fn settings(&self) -> &GraphSettings { + &self.core.settings + } + /// Settings for the graph (mutable) + pub fn settings_mut(&mut self) -> &mut GraphSettings { + &mut self.core.settings + } + /// The model + pub fn model(&self) -> &Model { + &self.core.model + } + /// + pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box> { + let f = std::fs::File::create(path)?; + let writer = std::io::BufWriter::new(f); + bincode::serialize_into(writer, &self)?; + Ok(()) + } + + /// + pub fn load(path: std::path::PathBuf) -> Result> { + // read bytes from file + let mut f = std::fs::File::open(&path) + .unwrap_or_else(|_| panic!("failed to load model at {}", path.display())); + let metadata = std::fs::metadata(&path).expect("unable to read metadata"); + let mut buffer = vec![0; metadata.len() as usize]; + f.read_exact(&mut buffer).expect("buffer overflow"); + let result = bincode::deserialize(&buffer)?; + Ok(result) + } +} + #[derive(Clone, Debug, Default, Deserialize, Serialize)] /// The data source for a test pub enum TestDataSource { @@ -481,10 +522,14 @@ impl GraphCircuit { // as they occupy independent rows settings.num_constraints = std::cmp::max(settings.num_constraints, sizes.max_constraints()); - Ok(GraphCircuit { + let core = CoreCircuit { model, + settings: settings.clone(), + }; + + Ok(GraphCircuit { + core, graph_witness: GraphWitness::new(inputs, vec![]), - settings, module_settings, }) } @@ -507,10 +552,14 @@ impl GraphCircuit { settings.check_mode = check_mode; - Ok(GraphCircuit { + let core = CoreCircuit { model, + settings: settings.clone(), + }; + + Ok(GraphCircuit { + core, graph_witness: GraphWitness::new(inputs, vec![]), - settings, module_settings, }) } @@ -536,10 +585,10 @@ impl GraphCircuit { // the ordering here is important, we want the inputs to come before the outputs // as they are configured in that order as Column let mut public_inputs = vec![]; - if self.settings.run_args.input_visibility.is_public() { + if self.settings().run_args.input_visibility.is_public() { public_inputs = self.graph_witness.inputs.clone(); } - if self.settings.run_args.output_visibility.is_public() { + if self.settings().run_args.output_visibility.is_public() { public_inputs.extend(self.graph_witness.outputs.clone()); } info!( @@ -564,7 +613,7 @@ impl GraphCircuit { .collect::>>(); let module_instances = - GraphModules::public_inputs(data, VarVisibility::from_args(&self.settings.run_args)?); + GraphModules::public_inputs(data, VarVisibility::from_args(&self.settings().run_args)?); if !module_instances.is_empty() { pi_inner.extend(module_instances); @@ -579,8 +628,8 @@ impl GraphCircuit { &mut self, data: &GraphData, ) -> Result>, Box> { - let shapes = self.model.graph.input_shapes(); - let scales = self.model.graph.get_input_scales(); + let shapes = self.model().graph.input_shapes(); + let scales = self.model().graph.get_input_scales(); self.process_data_source(&data.input_data, shapes, scales) } @@ -590,8 +639,8 @@ impl GraphCircuit { &mut self, data: &GraphData, ) -> Result>, Box> { - let shapes = self.model.graph.input_shapes(); - let scales = self.model.graph.get_input_scales(); + let shapes = self.model().graph.input_shapes(); + let scales = self.model().graph.get_input_scales(); info!("input scales: {:?}", scales); self.process_data_source(&data.input_data, shapes, scales) .await @@ -715,18 +764,18 @@ impl GraphCircuit { .ceil() as usize + 1; - let min_rows_from_constraints = (self.settings.num_constraints as f32 + let min_rows_from_constraints = (self.settings().num_constraints as f32 + reserved_blinding_rows) .log2() .ceil() as usize; let mut logrows = std::cmp::max(min_bits, min_rows_from_constraints); // if public input then public inputs col will have public inputs len - if self.settings.run_args.input_visibility.is_public() - || self.settings.run_args.output_visibility.is_public() + if self.settings().run_args.input_visibility.is_public() + || self.settings().run_args.output_visibility.is_public() { let max_instance_len = self - .model + .model() .instance_shapes() .iter() .fold(0, |acc, x| std::cmp::max(acc, x.iter().product::())) @@ -740,28 +789,32 @@ impl GraphCircuit { // ensure logrows is at least 4 logrows = std::cmp::max(logrows, MIN_LOGROWS as usize); logrows = std::cmp::min(logrows, MAX_PUBLIC_SRS as usize); + let model = self.model().clone(); + let settings_mut = self.settings_mut(); + settings_mut.run_args.bits = min_bits; + settings_mut.run_args.logrows = logrows as u32; - self.settings.run_args.bits = min_bits; - self.settings.run_args.logrows = logrows as u32; - - self.settings = GraphCircuit::new(self.model.clone(), &self.settings.run_args)?.settings; + *settings_mut = GraphCircuit::new(model, &settings_mut.run_args)? + .settings() + .clone(); // recalculate the total const size give nthe new logrows - let total_const_len = self.settings.total_const_size; + let total_const_len = settings_mut.total_const_size; let const_len_logrows = (total_const_len as f64).log2().ceil() as u32; - self.settings.run_args.logrows = - std::cmp::max(self.settings.run_args.logrows, const_len_logrows); + settings_mut.run_args.logrows = + std::cmp::max(settings_mut.run_args.logrows, const_len_logrows); // recalculate the total number of constraints given the new logrows - let min_rows_from_constraints = (self.settings.num_constraints as f32 + let min_rows_from_constraints = (settings_mut.num_constraints as f32 + reserved_blinding_rows) .log2() .ceil() as u32; - self.settings.run_args.logrows = - std::cmp::max(self.settings.run_args.logrows, min_rows_from_constraints); + settings_mut.run_args.logrows = + std::cmp::max(settings_mut.run_args.logrows, min_rows_from_constraints); info!( "setting bits to: {}, setting logrows to: {}", - self.settings.run_args.bits, self.settings.run_args.logrows + self.settings().run_args.bits, + self.settings().run_args.logrows ); Ok(()) @@ -772,7 +825,8 @@ impl GraphCircuit { let res = self.forward(&mut input.to_vec())?; let blinding_offset = (ASSUMED_BLINDING_FACTORS as f64 / 2.0).ceil() + 1.0; - let max_range = 2i128.pow(self.settings.run_args.bits as u32 - 1) - blinding_offset as i128; + let max_range = + 2i128.pow(self.settings().run_args.bits as u32 - 1) - blinding_offset as i128; if res.max_lookup_inputs > max_range { let recommended_bits = (res.max_lookup_inputs as f64 + blinding_offset) @@ -800,7 +854,7 @@ impl GraphCircuit { ) -> Result> { let original_inputs = inputs.to_vec(); - let visibility = VarVisibility::from_args(&self.settings.run_args)?; + let visibility = VarVisibility::from_args(&self.settings().run_args)?; let mut processed_inputs = None; let mut processed_params = None; let mut processed_outputs = None; @@ -814,7 +868,7 @@ impl GraphCircuit { } if visibility.params.requires_processing() { - let params = self.model.get_all_params(); + let params = self.model().get_all_params(); if !params.is_empty() { let flattened_params = Tensor::new(Some(¶ms), &[params.len()])?.combine()?; processed_params = Some(GraphModules::forward( @@ -824,7 +878,7 @@ impl GraphCircuit { } } - let model_results = self.model.forward(inputs)?; + let model_results = self.model().forward(inputs)?; if visibility.output.requires_processing() { processed_outputs = Some(GraphModules::forward( @@ -868,18 +922,6 @@ impl GraphCircuit { Self::new(model, run_args) } - /// - pub fn preprocessed_from_run_args( - run_args: &RunArgs, - model_path: &std::path::Path, - ) -> Result> { - let model = Model::load(model_path.to_path_buf()).map_err(|e| { - error!("failed to deserialize compiled model. have you called compile-model ?"); - e - })?; - Self::new(model, run_args) - } - /// Create a new circuit from a set of input data and [GraphSettings]. #[cfg(not(target_arch = "wasm32"))] pub fn from_settings( @@ -891,19 +933,6 @@ impl GraphCircuit { Self::new_from_settings(model, params.clone(), check_mode) } - /// Create a new circuit from a set of input data and [GraphSettings]. - pub fn preprocessed_from_settings( - params: &GraphSettings, - model_path: &std::path::Path, - check_mode: CheckMode, - ) -> Result> { - let model = Model::load(model_path.to_path_buf()).map_err(|e| { - error!("failed to deserialize compiled model. have you called compile-model ?"); - e - })?; - Self::new_from_settings(model, params.clone(), check_mode) - } - /// #[cfg(not(target_arch = "wasm32"))] pub async fn populate_on_chain_test_data( @@ -918,7 +947,7 @@ impl GraphCircuit { TestDataSource::OnChain ) { // if not public then fail - if !self.settings.run_args.input_visibility.is_public() { + if !self.settings().run_args.input_visibility.is_public() { return Err("Cannot use on-chain data source as private data".into()); } @@ -931,11 +960,11 @@ impl GraphCircuit { }; // Get the flatten length of input_data let length = input_data.iter().map(|x| x.len()).sum(); - let scales = vec![self.settings.run_args.input_scale; length]; + let scales = vec![self.settings().run_args.input_scale; length]; let datam: (Vec>, OnChainSource) = OnChainSource::test_from_file_data( input_data, scales, - self.model.graph.input_shapes(), + self.model().graph.input_shapes(), test_on_chain_data.rpc.as_deref(), ) .await?; @@ -946,7 +975,7 @@ impl GraphCircuit { TestDataSource::OnChain ) { // if not public then fail - if !self.settings.run_args.output_visibility.is_public() { + if !self.settings().run_args.output_visibility.is_public() { return Err("Cannot use on-chain data source as private data".into()); } @@ -960,8 +989,8 @@ impl GraphCircuit { }; let datum: (Vec>, OnChainSource) = OnChainSource::test_from_file_data( output_data, - self.model.graph.get_output_scales(), - self.model.graph.output_shapes(), + self.model().graph.get_output_scales(), + self.model().graph.output_shapes(), test_on_chain_data.rpc.as_deref(), ) .await?; @@ -1018,7 +1047,7 @@ impl Circuit for GraphCircuit { fn params(&self) -> Self::Params { // safe to clone because the model is Arc'd - self.settings.clone() + self.settings().clone() } fn configure_with_params(cs: &mut ConstraintSystem, params: Self::Params) -> Self::Config { @@ -1098,18 +1127,18 @@ impl Circuit for GraphCircuit { &mut layouter, &config.module_configs, &mut inputs, - self.settings.run_args.input_visibility, + self.settings().run_args.input_visibility, &mut instance_offset, &self.module_settings.input, )?; // now we need to assign the flattened params to the model - let mut model = self.model.clone(); - let param_visibility = self.settings.run_args.param_visibility; + let mut model = self.model().clone(); + let param_visibility = self.settings().run_args.param_visibility; trace!("running params module layout"); - if !self.model.get_all_params().is_empty() && param_visibility.requires_processing() { + if !self.model().get_all_params().is_empty() && param_visibility.requires_processing() { // now we need to flatten the params - let consts = self.model.get_all_params(); + let consts = self.model().get_all_params(); let mut flattened_params = { let mut t = Tensor::new(Some(&consts), &[consts.len()]) @@ -1136,7 +1165,7 @@ impl Circuit for GraphCircuit { &self.module_settings.params, )?; - let shapes = self.model.const_shapes(); + let shapes = self.model().const_shapes(); trace!("replacing processed consts"); let split_params = split_valtensor(&flattened_params[0], shapes).map_err(|_| { log::error!("failed to split params"); @@ -1154,7 +1183,7 @@ impl Circuit for GraphCircuit { .layout( config.model_config.clone(), &mut layouter, - &self.settings.run_args, + &self.settings().run_args, &inputs, &config.model_config.vars, ) @@ -1169,7 +1198,7 @@ impl Circuit for GraphCircuit { &mut layouter, &config.module_configs, &mut outputs, - self.settings.run_args.output_visibility, + self.settings().run_args.output_visibility, &mut instance_offset, &self.module_settings.output, )?; diff --git a/src/graph/model.rs b/src/graph/model.rs index c511d353..dab2e1d9 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -523,7 +523,10 @@ impl Model { "------------ output node int {}: {} \n ------------ float: {}", idx, res.output.map(crate::fieldutils::felt_to_i32).show(), - res.output.map(|x| crate::fieldutils::felt_to_f64(x) / scale_to_multiplier(n.out_scale)).show() + res.output + .map(|x| crate::fieldutils::felt_to_f64(x) + / scale_to_multiplier(n.out_scale)) + .show() ); results.insert(idx, vec![res.output]); } diff --git a/src/python.rs b/src/python.rs index 816034fa..7ad99081 100644 --- a/src/python.rs +++ b/src/python.rs @@ -652,22 +652,11 @@ fn calibrate_settings( data, model, output, - settings_path, ))] -fn gen_witness( - data: PathBuf, - model: PathBuf, - output: Option, - settings_path: PathBuf, -) -> PyResult { +fn gen_witness(data: PathBuf, model: PathBuf, output: Option) -> PyResult { let output = Runtime::new() .unwrap() - .block_on(crate::execute::gen_witness( - model, - data, - output, - settings_path, - )) + .block_on(crate::execute::gen_witness(model, data, output)) .map_err(|e| { let err_str = format!("Failed to run generate witness: {}", e); PyRuntimeError::new_err(err_str) @@ -679,12 +668,11 @@ fn gen_witness( #[pyfunction(signature = ( witness, model, - settings_path, ))] -fn mock(witness: PathBuf, model: PathBuf, settings_path: PathBuf) -> PyResult { +fn mock(witness: PathBuf, model: PathBuf) -> PyResult { Runtime::new() .unwrap() - .block_on(crate::execute::mock(model, witness, settings_path)) + .block_on(crate::execute::mock(model, witness)) .map_err(|e| { let err_str = format!("Failed to run mock: {}", e); PyRuntimeError::new_err(err_str) @@ -713,16 +701,14 @@ fn mock_aggregate(aggregation_snarks: Vec, logrows: u32) -> PyResult Result { - crate::execute::setup(model, srs_path, settings_path, vk_path, pk_path).map_err(|e| { + crate::execute::setup(model, srs_path, vk_path, pk_path).map_err(|e| { let err_str = format!("Failed to run setup: {}", e); PyRuntimeError::new_err(err_str) })?; @@ -739,7 +725,6 @@ fn setup( srs_path, transcript, strategy, - settings_path, ))] fn prove( witness: PathBuf, @@ -749,7 +734,6 @@ fn prove( srs_path: PathBuf, transcript: TranscriptType, strategy: StrategyType, - settings_path: PathBuf, ) -> PyResult { let snark = Runtime::new() .unwrap() @@ -761,7 +745,6 @@ fn prove( srs_path, transcript, strategy, - settings_path, CheckMode::UNSAFE, )) .map_err(|e| { @@ -819,15 +802,15 @@ fn setup_aggregate( #[pyfunction(signature = ( model, - compiled_model, + compiled_circuit, settings_path, ))] -fn compile_model( +fn compile_circuit( model: PathBuf, - compiled_model: PathBuf, + compiled_circuit: PathBuf, settings_path: PathBuf, ) -> Result { - crate::execute::compile_model(model, compiled_model, settings_path).map_err(|e| { + crate::execute::compile_circuit(model, compiled_circuit, settings_path).map_err(|e| { let err_str = format!("Failed to setup aggregate: {}", e); PyRuntimeError::new_err(err_str) })?; @@ -1124,7 +1107,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(aggregate, m)?)?; m.add_function(wrap_pyfunction!(mock_aggregate, m)?)?; m.add_function(wrap_pyfunction!(setup_aggregate, m)?)?; - m.add_function(wrap_pyfunction!(compile_model, m)?)?; + m.add_function(wrap_pyfunction!(compile_circuit, m)?)?; m.add_function(wrap_pyfunction!(verify_aggr, m)?)?; m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?; m.add_function(wrap_pyfunction!(deploy_evm, m)?)?; diff --git a/src/wasm.rs b/src/wasm.rs index 3e480e63..9e2e43af 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -214,24 +214,13 @@ pub fn elgamalDecrypt( #[wasm_bindgen] #[allow(non_snake_case)] pub fn genWitness( - compiled_model: wasm_bindgen::Clamped>, + compiled_circuit: wasm_bindgen::Clamped>, input: wasm_bindgen::Clamped>, - settings: wasm_bindgen::Clamped>, ) -> Result, JsError> { - let compiled_model: crate::graph::Model = - bincode::deserialize(&compiled_model[..]).map_err(|e| JsError::new(&format!("{}", e)))?; + let mut circuit: crate::graph::GraphCircuit = + bincode::deserialize(&compiled_circuit[..]).map_err(|e| JsError::new(&format!("{}", e)))?; let input: crate::graph::input::GraphData = serde_json::from_slice(&input[..]).map_err(|e| JsError::new(&format!("{}", e)))?; - let circuit_settings: crate::graph::GraphSettings = - serde_json::from_slice(&settings[..]).map_err(|e| JsError::new(&format!("{}", e)))?; - - // read in circuit - let mut circuit = GraphCircuit::new_from_settings( - compiled_model, - circuit_settings, - crate::circuit::CheckMode::UNSAFE, - ) - .map_err(|e| JsError::new(&format!("{}", e)))?; let mut input = circuit .load_graph_input(&input) @@ -286,8 +275,7 @@ pub fn verify( pub fn prove( witness: wasm_bindgen::Clamped>, pk: wasm_bindgen::Clamped>, - compiled_model: wasm_bindgen::Clamped>, - settings: wasm_bindgen::Clamped>, + compiled_circuit: wasm_bindgen::Clamped>, srs: wasm_bindgen::Clamped>, ) -> Result, JsError> { #[cfg(feature = "det-prove")] @@ -300,31 +288,20 @@ pub fn prove( halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) .map_err(|e| JsError::new(&format!("{}", e)))?; + // read in circuit + let mut circuit: crate::graph::GraphCircuit = + bincode::deserialize(&compiled_circuit[..]).map_err(|e| JsError::new(&format!("{}", e)))?; + // read in model input let data: crate::graph::GraphWitness = serde_json::from_slice(&witness[..]).map_err(|e| JsError::new(&format!("{}", e)))?; - // read in circuit params - let circuit_settings: GraphSettings = - serde_json::from_slice(&settings[..]).map_err(|e| JsError::new(&format!("{}", e)))?; - // read in proving key let mut reader = std::io::BufReader::new(&pk[..]); let pk = ProvingKey::::read::<_, GraphCircuit>( &mut reader, halo2_proofs::SerdeFormat::RawBytes, - circuit_settings.clone(), - ) - .map_err(|e| JsError::new(&format!("{}", e)))?; - - // read in circuit - let compiled_model: crate::graph::Model = - bincode::deserialize(&compiled_model[..]).map_err(|e| JsError::new(&format!("{}", e)))?; - - let mut circuit = GraphCircuit::new_from_settings( - compiled_model, - circuit_settings, - crate::circuit::CheckMode::UNSAFE, + circuit.settings().clone(), ) .map_err(|e| JsError::new(&format!("{}", e)))?; diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 3445d6d0..3d0e7753 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -237,7 +237,7 @@ mod native_tests { "hummingbird_decision_tree", "oh_decision_tree", "linear_svc", - "gather_elements", + "gather_elements", "less", "xgboost_reg", ]; @@ -1059,10 +1059,10 @@ mod native_tests { "--bin", "ezkl", "--", - "compile-model", + "compile-circuit", "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - "--compiled-model", + "--compiled-circuit", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), &format!( "--settings-path={}/{}/settings.json", @@ -1121,10 +1121,6 @@ mod native_tests { &format!("{}/{}/network.onnx", test_dir, example_name), "-O", &format!("{}/{}/witness.json", test_dir, example_name), - &format!( - "--settings-path={}/{}/settings.json", - test_dir, example_name - ), ]) .status() .expect("failed to execute process"); @@ -1162,10 +1158,10 @@ mod native_tests { let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "compile-model", + "compile-circuit", "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - "--compiled-model", + "--compiled-circuit", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), &format!( "--settings-path={}/{}/settings.json", @@ -1185,10 +1181,6 @@ mod native_tests { &format!("{}/{}/network.onnx", test_dir, example_name), "-O", &format!("{}/{}/witness.json", test_dir, example_name), - &format!( - "--settings-path={}/{}/settings.json", - test_dir, example_name - ), ]) .status() .expect("failed to execute process"); @@ -1201,10 +1193,6 @@ mod native_tests { format!("{}/{}/witness.json", test_dir, counter_example).as_str(), "-M", format!("{}/{}/network.compiled", test_dir, example_name).as_str(), - &format!( - "--settings-path={}/{}/settings.json", - test_dir, example_name - ), ]) .status() .expect("failed to execute process"); @@ -1281,10 +1269,10 @@ mod native_tests { let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "compile-model", + "compile-circuit", "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - "--compiled-model", + "--compiled-circuit", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), &format!( "--settings-path={}/{}/settings.json", @@ -1304,10 +1292,6 @@ mod native_tests { &format!("{}/{}/network.onnx", test_dir, example_name), "-O", &format!("{}/{}/witness_mock.json", test_dir, example_name), - &format!( - "--settings-path={}/{}/settings.json", - test_dir, example_name - ), ]) .status() .expect("failed to execute process"); @@ -1320,10 +1304,6 @@ mod native_tests { format!("{}/{}/witness_mock.json", test_dir, example_name).as_str(), "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - &format!( - "--settings-path={}/{}/settings.json", - test_dir, example_name - ), ]) .status() .expect("failed to execute process"); @@ -1379,10 +1359,10 @@ mod native_tests { let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "compile-model", + "compile-circuit", "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - "--compiled-model", + "--compiled-circuit", format!("{}/{}/network.compiled", test_dir, example_name).as_str(), &format!( "--settings-path={}/{}/settings.json", @@ -1403,10 +1383,6 @@ mod native_tests { &format!("{}/{}/network.compiled", test_dir, example_name), "-O", &format!("{}/{}/witness.json", test_dir, example_name), - &format!( - "--settings-path={}/{}/settings.json", - test_dir, example_name - ), ]) .stdout(std::process::Stdio::null()) .status() @@ -1469,10 +1445,10 @@ mod native_tests { let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "compile-model", + "compile-circuit", "-M", format!("{}/tutorial/network.onnx", test_dir).as_str(), - "--compiled-model", + "--compiled-circuit", format!("{}/tutorial/network.onnx", test_dir).as_str(), &format!("--settings-path={}/tutorial/settings.json", test_dir), ]) @@ -1489,7 +1465,6 @@ mod native_tests { &format!("{}/tutorial/network.onnx", test_dir), "-O", &format!("{}/tutorial/witness_tutorial.json", test_dir), - &format!("--settings-path={}/tutorial/settings.json", test_dir), ]) .status() .expect("failed to execute process"); @@ -1502,7 +1477,6 @@ mod native_tests { format!("{}/tutorial/witness_tutorial.json", test_dir).as_str(), "-M", format!("{}/tutorial/network.onnx", test_dir).as_str(), - &format!("--settings-path={}/tutorial/settings.json", test_dir), ]) .status() .expect("failed to execute process"); @@ -1545,10 +1519,10 @@ mod native_tests { let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "compile-model", + "compile-circuit", "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - "--compiled-model", + "--compiled-circuit", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), &format!( "--settings-path={}/{}/settings.json", @@ -1566,10 +1540,6 @@ mod native_tests { format!("{}/{}/input.json", test_dir, example_name).as_str(), "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - &format!( - "--settings-path={}/{}/settings.json", - test_dir, example_name - ), "-O", format!("{}/{}/input.json", test_dir, example_name).as_str(), ]) @@ -1590,10 +1560,6 @@ mod native_tests { "--vk-path", &format!("{}/{}/key.vk", test_dir, example_name), &srs_path, - &format!( - "--settings-path={}/{}/settings.json", - test_dir, example_name - ), ]) .status() .expect("failed to execute process"); @@ -1612,10 +1578,6 @@ mod native_tests { &srs_path, "--transcript=poseidon", "--strategy=accum", - &format!( - "--settings-path={}/{}/settings.json", - test_dir, example_name - ), ]) .status() .expect("failed to execute process"); @@ -1668,10 +1630,10 @@ mod native_tests { let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "compile-model", + "compile-circuit", "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - "--compiled-model", + "--compiled-circuit", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), &format!( "--settings-path={}/{}/settings.json", @@ -1689,10 +1651,6 @@ mod native_tests { format!("{}/{}/input.json", test_dir, example_name).as_str(), "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - &format!( - "--settings-path={}/{}/settings.json", - test_dir, example_name - ), "-O", format!("{}/{}/input.json", test_dir, example_name).as_str(), ]) @@ -1713,10 +1671,6 @@ mod native_tests { "--vk-path", &format!("{}/{}/key.vk", test_dir, example_name), &srs_path, - &format!( - "--settings-path={}/{}/settings.json", - test_dir, example_name - ), ]) .status() .expect("failed to execute process"); @@ -1735,10 +1689,6 @@ mod native_tests { &srs_path, "--transcript=poseidon", "--strategy=accum", - &format!( - "--settings-path={}/{}/settings.json", - test_dir, example_name - ), ]) .status() .expect("failed to execute process"); @@ -1835,10 +1785,10 @@ mod native_tests { let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "compile-model", + "compile-circuit", "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - "--compiled-model", + "--compiled-circuit", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), &format!( "--settings-path={}/{}/settings.json", @@ -1856,10 +1806,6 @@ mod native_tests { format!("{}/{}/input.json", test_dir, example_name).as_str(), "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - &format!( - "--settings-path={}/{}/settings.json", - test_dir, example_name - ), "-O", format!("{}/{}/input.json", test_dir, example_name).as_str(), ]) @@ -1880,10 +1826,6 @@ mod native_tests { "--pk-path", &format!("{}/{}/evm.pk", test_dir, example_name), &srs_path, - &format!( - "--settings-path={}/{}/settings.json", - test_dir, example_name - ), ]) .status() .expect("failed to execute process"); @@ -1901,10 +1843,6 @@ mod native_tests { "--pk-path", &format!("{}/{}/evm.pk", test_dir, example_name), &srs_path, - &format!( - "--settings-path={}/{}/settings.json", - test_dir, example_name - ), "--transcript=poseidon", "--strategy=accum", ]) @@ -2088,10 +2026,10 @@ mod native_tests { let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "compile-model", + "compile-circuit", "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - "--compiled-model", + "--compiled-circuit", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), &format!( "--settings-path={}/{}/settings.json", @@ -2112,7 +2050,6 @@ mod native_tests { format!("{}/{}/input.json", test_dir, example_name).as_str(), "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - format!("--settings-path={}", settings_path).as_str(), "-O", format!("{}/{}/input.json", test_dir, example_name).as_str(), ]) @@ -2130,7 +2067,6 @@ mod native_tests { "--vk-path", &format!("{}/{}/key.vk", test_dir, example_name), &srs_path, - format!("--settings-path={}", settings_path).as_str(), ]) .status() .expect("failed to execute process"); @@ -2149,7 +2085,6 @@ mod native_tests { &srs_path, "--transcript=evm", "--strategy=single", - format!("--settings-path={}", settings_path).as_str(), &format!("--check-mode={}", checkmode), ]) .status() @@ -2203,6 +2138,8 @@ mod native_tests { format!("{}/{}/settings_fuzz.json", test_dir, example_name).as_str(), &format!("--input-scale={}", scale), &format!("--param-scale={}", scale), + &format!("--bits={}", bits), + &format!("--logrows={}", logrows), ]) .status() .expect("failed to execute process"); @@ -2210,10 +2147,10 @@ mod native_tests { let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "compile-model", + "compile-circuit", "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - "--compiled-model", + "--compiled-circuit", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), &format!( "--settings-path={}/{}/settings_fuzz.json", @@ -2231,8 +2168,6 @@ mod native_tests { format!("{}/{}/input.json", test_dir, example_name).as_str(), "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - "--settings-path", - format!("{}/{}/settings_fuzz.json", test_dir, example_name).as_str(), "-O", format!("{}/{}/witness_fuzz.json", test_dir, example_name).as_str(), ]) @@ -2247,10 +2182,6 @@ mod native_tests { format!("{}/{}/witness_fuzz.json", test_dir, example_name).as_str(), "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - &format!("--bits={}", bits), - &format!("--logrows={}", logrows), - &format!("--input-scale={}", scale), - &format!("--param-scale={}", scale), &format!("--num-runs={}", 5), &format!("--transcript={}", transcript), ]) @@ -2300,10 +2231,10 @@ mod native_tests { let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "compile-model", + "compile-circuit", "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - "--compiled-model", + "--compiled-circuit", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), (format!("--settings-path={}", settings_path).as_str()), ]) @@ -2321,7 +2252,6 @@ mod native_tests { format!("{}/{}/input.json", test_dir, example_name).as_str(), "-M", format!("{}/{}/network.onnx", test_dir, example_name).as_str(), - format!("--settings-path={}", settings_path).as_str(), "-O", format!("{}/{}/input.json", test_dir, example_name).as_str(), ]) @@ -2339,7 +2269,6 @@ mod native_tests { "--vk-path", &format!("{}/{}/key.vk", test_dir, example_name), &srs_path, - format!("--settings-path={}", settings_path).as_str(), ]) .status() .expect("failed to execute process"); @@ -2359,7 +2288,6 @@ mod native_tests { &srs_path, "--transcript=evm", "--strategy=single", - format!("--settings-path={}", settings_path).as_str(), ]) .status() .expect("failed to execute process"); @@ -2478,10 +2406,10 @@ mod native_tests { let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "compile-model", + "compile-circuit", "-M", &model_path, - "--compiled-model", + "--compiled-circuit", &model_path, &format!( "--settings-path={}/{}/settings.json", @@ -2499,7 +2427,6 @@ mod native_tests { data_path.as_str(), "-M", &model_path, - format!("--settings-path={}", settings_path).as_str(), "--test-data", test_on_chain_data_path.as_str(), rpc_arg.as_str(), @@ -2517,7 +2444,6 @@ mod native_tests { test_on_chain_data_path.as_str(), "-M", &model_path, - format!("--settings-path={}", settings_path).as_str(), "-O", &witness_path, ]) @@ -2535,7 +2461,6 @@ mod native_tests { "--vk-path", &format!("{}/{}/key.vk", test_dir, example_name), &srs_path, - format!("--settings-path={}", settings_path).as_str(), ]) .status() .expect("failed to execute process"); @@ -2555,7 +2480,6 @@ mod native_tests { &srs_path, "--transcript=evm", "--strategy=single", - format!("--settings-path={}", settings_path).as_str(), ]) .status() .expect("failed to execute process"); diff --git a/tests/output_comparison.py b/tests/output_comparison.py index 76499bbf..9b540d55 100644 --- a/tests/output_comparison.py +++ b/tests/output_comparison.py @@ -63,7 +63,7 @@ def compare_outputs(zk_output, onnx_output): res = [] contains_sublist = any(isinstance(sub, list) for sub in zk_output) - + print("zk ", zk_output) if contains_sublist: try: if len(onnx_output) == 1: diff --git a/tests/python/binding_tests.py b/tests/python/binding_tests.py index 46bc9055..dbbac895 100644 --- a/tests/python/binding_tests.py +++ b/tests/python/binding_tests.py @@ -254,7 +254,7 @@ def test_model_compile(): folder_path, 'settings.json' ) - res = ezkl.compile_model(model_path, compiled_model_path, settings_path) + res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path) assert res == True @@ -276,13 +276,8 @@ def test_forward(): folder_path, 'witness.json' ) - settings_path = os.path.join( - folder_path, - 'settings.json' - ) - res = ezkl.gen_witness(data_path, model_path, - output_path, settings_path=settings_path) + res = ezkl.gen_witness(data_path, model_path, output_path) with open(output_path, "r") as f: data = json.load(f) @@ -329,8 +324,7 @@ def test_mock(): settings_path = os.path.join(folder_path, 'settings.json') - res = ezkl.mock(data_path, model_path, - settings_path) + res = ezkl.mock(data_path, model_path) assert res == True @@ -358,7 +352,6 @@ def test_setup(): vk_path, pk_path, srs_path, - settings_path, ) assert res == True assert os.path.isfile(vk_path) @@ -382,19 +375,16 @@ def test_setup_evm(): pk_path = os.path.join(folder_path, 'test_evm.pk') vk_path = os.path.join(folder_path, 'test_evm.vk') - settings_path = os.path.join(folder_path, 'settings.json') res = ezkl.setup( model_path, vk_path, pk_path, srs_path, - settings_path, ) assert res == True assert os.path.isfile(vk_path) assert os.path.isfile(pk_path) - assert os.path.isfile(settings_path) def test_prove_and_verify(): @@ -414,7 +404,6 @@ def test_prove_and_verify(): pk_path = os.path.join(folder_path, 'test.pk') proof_path = os.path.join(folder_path, 'test.pf') - settings_path = os.path.join(folder_path, 'settings.json') res = ezkl.prove( data_path, @@ -424,11 +413,11 @@ def test_prove_and_verify(): srs_path, "poseidon", "single", - settings_path, ) assert res['transcript_type'] == 'Poseidon' assert os.path.isfile(proof_path) + settings_path = os.path.join(folder_path, 'settings.json') vk_path = os.path.join(folder_path, 'test.vk') res = ezkl.verify(proof_path, settings_path, vk_path, srs_path) @@ -453,7 +442,6 @@ def test_prove_evm(): pk_path = os.path.join(folder_path, 'test_evm.pk') proof_path = os.path.join(folder_path, 'test_evm.pf') - settings_path = os.path.join(folder_path, 'settings.json') res = ezkl.prove( data_path, model_path, @@ -462,7 +450,6 @@ def test_prove_evm(): srs_path, "evm", "single", - settings_path, ) assert res['transcript_type'] == 'EVM' assert os.path.isfile(proof_path) @@ -575,7 +562,7 @@ async def aggregate_and_verify_aggr(): assert res == True assert os.path.isfile(settings_path) - res = ezkl.compile_model(model_path, compiled_model_path, settings_path) + res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path) assert res == True ezkl.setup( @@ -583,7 +570,6 @@ async def aggregate_and_verify_aggr(): vk_path, pk_path, srs_path, - settings_path, ) proof_path = os.path.join(folder_path, '1l_relu.pf') @@ -594,7 +580,7 @@ async def aggregate_and_verify_aggr(): ) res = ezkl.gen_witness(data_path, compiled_model_path, - output_path, settings_path=settings_path) + output_path) ezkl.prove( output_path, @@ -604,7 +590,6 @@ async def aggregate_and_verify_aggr(): srs_path, "poseidon", "accum", - settings_path, ) # mock aggregate @@ -694,7 +679,7 @@ async def evm_aggregate_and_verify_aggr(): 'compiled_relu.onnx' ) - res = ezkl.compile_model(model_path, compiled_model_path, settings_path) + res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path) assert res == True ezkl.setup( @@ -702,7 +687,6 @@ async def evm_aggregate_and_verify_aggr(): vk_path, pk_path, srs_path, - settings_path, ) proof_path = os.path.join(folder_path, '1l_relu.pf') @@ -713,7 +697,7 @@ async def evm_aggregate_and_verify_aggr(): ) res = ezkl.gen_witness(data_path, compiled_model_path, - output_path, settings_path=settings_path) + output_path) ezkl.prove( output_path, @@ -723,7 +707,6 @@ async def evm_aggregate_and_verify_aggr(): srs_path, "poseidon", "accum", - settings_path, ) aggregate_proof_path = os.path.join(folder_path, 'aggr_evm_1l_relu.pf') diff --git a/tests/wasm.rs b/tests/wasm.rs index 3c2e0f54..ade9ba0c 100644 --- a/tests/wasm.rs +++ b/tests/wasm.rs @@ -182,7 +182,6 @@ mod wasm32 { let witness = genWitness( wasm_bindgen::Clamped(NETWORK.to_vec()), wasm_bindgen::Clamped(INPUT.to_vec()), - wasm_bindgen::Clamped(CIRCUIT_PARAMS.to_vec()), ) .map_err(|_| "failed") .unwrap(); diff --git a/tests/wasm/settings.json b/tests/wasm/settings.json index b2d815ec..aab22c42 100644 --- a/tests/wasm/settings.json +++ b/tests/wasm/settings.json @@ -4,9 +4,9 @@ "val": 0.0, "scale": 1.0 }, - "input_scale": 0, - "param_scale": 0, - "scale_rebase_multiplier": 2, + "input_scale": 20, + "param_scale": 20, + "scale_rebase_multiplier": 10, "bits": 5, "logrows": 7, "variables": [ diff --git a/tests/wasm/testWasm.test.ts b/tests/wasm/testWasm.test.ts index ba336859..461e38b3 100644 --- a/tests/wasm/testWasm.test.ts +++ b/tests/wasm/testWasm.test.ts @@ -35,7 +35,7 @@ describe('Generate witness, prove and verify', () => { circuit_settings_ser = await readEzklArtifactsFile(path, example, 'settings.json'); params_ser = await readEzklSrsFile(path, example); const startTimeProve = Date.now(); - result = wasmFunctions.prove(witness, pk, circuit_ser, circuit_settings_ser, params_ser); + result = wasmFunctions.prove(witness, pk, circuit_ser, params_ser); const endTimeProve = Date.now(); proof_ser = new Uint8ClampedArray(result.buffer); proveTime = endTimeProve - startTimeProve; diff --git a/tests/wasm/test_network.compiled b/tests/wasm/test_network.compiled index b2fe3fc3b00c32877a1a87276015a3eccc2bef60..1b2d1e96ad8e4ff01c46949631504ca491f40926 100644 GIT binary patch delta 202 zcmdnU*(|tWBg@19Hq!=s1}IZ$%0Jc zlb13{aKaR_PjnQYn8P#qFSEeJ0)fd2OdQG}DG=s_nqi;^gu3aOC2-T3CkrwLPgY>m U2dWnV(k?)($;bi}btgsy04b>;T>t<8 delta 7 OcmZn`+{n3MBMSfv5dyRT