fix!: make calibrate-settings sync in python (#616)

BREAKING CHANGE: calibrate settings is no longer async
This commit is contained in:
dante
2023-11-18 01:08:10 +03:00
committed by GitHub
parent 679b59794a
commit 13dae3392f
40 changed files with 7089 additions and 7094 deletions

View File

@@ -2,7 +2,6 @@ import ezkl
import os
import pytest
import json
import asyncio
import subprocess
import time
@@ -156,7 +155,8 @@ def test_gen_srs():
assert os.path.isfile(params_k20_path)
async def calibrate_over_user_range():
def test_calibrate_over_user_range():
data_path = os.path.join(
examples_path,
'onnx',
@@ -183,20 +183,14 @@ async def calibrate_over_user_range():
model_path, output_path, py_run_args=run_args)
assert res == True
res = await ezkl.calibrate_settings(
res = ezkl.calibrate_settings(
data_path, model_path, output_path, "resources", [0, 1, 2])
assert res == True
assert os.path.isfile(output_path)
def test_calibrate_calibrate_over_user_range():
"""
Test for calibrate
"""
asyncio.run(calibrate_over_user_range())
async def calibrate():
def test_calibrate():
data_path = os.path.join(
examples_path,
'onnx',
@@ -223,19 +217,12 @@ async def calibrate():
model_path, output_path, py_run_args=run_args)
assert res == True
res = await ezkl.calibrate_settings(
res = ezkl.calibrate_settings(
data_path, model_path, output_path, "resources")
assert res == True
assert os.path.isfile(output_path)
def test_calibrate():
"""
Test for calibrate
"""
asyncio.run(calibrate())
def test_model_compile():
"""
Test for model compilation/serialization
@@ -559,7 +546,7 @@ def test_verify_evm():
assert res == True
async def aggregate_and_verify_aggr():
def test_aggregate_and_verify_aggr():
data_path = os.path.join(
examples_path,
'onnx',
@@ -588,7 +575,7 @@ async def aggregate_and_verify_aggr():
res = ezkl.gen_settings(model_path, settings_path)
assert res == True
res = await ezkl.calibrate_settings(
res = ezkl.calibrate_settings(
data_path, model_path, settings_path, "resources")
assert res == True
assert os.path.isfile(settings_path)
@@ -665,14 +652,7 @@ async def aggregate_and_verify_aggr():
assert res == True
def test_aggregate_and_verify_aggr():
"""
Tests for aggregated proof and verifying aggregate proof
"""
asyncio.run(aggregate_and_verify_aggr())
async def evm_aggregate_and_verify_aggr():
def test_evm_aggregate_and_verify_aggr():
data_path = os.path.join(
examples_path,
'onnx',
@@ -697,7 +677,7 @@ async def evm_aggregate_and_verify_aggr():
settings_path,
)
await ezkl.calibrate_settings(
ezkl.calibrate_settings(
data_path,
model_path,
settings_path,
@@ -807,25 +787,21 @@ async def evm_aggregate_and_verify_aggr():
# assert res == True
def test_evm_aggregate_and_verify_aggr():
"""
Tests for aggregated proof and verifying aggregate proof
"""
asyncio.run(evm_aggregate_and_verify_aggr())
def get_examples():
EXAMPLES_OMIT = [
# these are too large
'mobilenet_large',
'mobilenet',
'doodles',
'nanoGPT',
# these fails for some reason
"self_attention",
'multihead_attention',
'large_op_graph',
'1l_instance_norm',
'variable_cnn',
'accuracy',
'linear_regression'
'linear_regression',
"mnist_gan",
]
examples = []
for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):
@@ -851,9 +827,15 @@ def test_all_examples(model_file, input_file):
witness_path = os.path.join(folder_path, 'witness.json')
proof_path = os.path.join(folder_path, 'proof.json')
print("Testing example: ", model_file)
res = ezkl.gen_settings(model_file, settings_path)
assert res
res = ezkl.calibrate_settings(
input_file, model_file, settings_path, "resources")
assert res
print("Compiling example: ", model_file)
res = ezkl.compile_circuit(model_file, compiled_model_path, settings_path)
assert res
@@ -865,8 +847,10 @@ def test_all_examples(model_file, input_file):
# generate the srs file if the path does not exist
if not os.path.exists(srs_path):
print("Generating srs file: ", srs_path)
ezkl.gen_srs(os.path.join(folder_path, srs_path), logrows)
print("Setting up example: ", model_file)
res = ezkl.setup(
compiled_model_path,
vk_path,
@@ -877,9 +861,11 @@ def test_all_examples(model_file, input_file):
assert os.path.isfile(vk_path)
assert os.path.isfile(pk_path)
print("Generating witness for example: ", model_file)
res = ezkl.gen_witness(input_file, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)
print("Proving example: ", model_file)
ezkl.prove(
witness_path,
compiled_model_path,
@@ -890,6 +876,8 @@ def test_all_examples(model_file, input_file):
)
assert os.path.isfile(proof_path)
print("Verifying example: ", model_file)
res = ezkl.verify(
proof_path,
settings_path,