mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-04-25 03:01:17 -04:00
fix!: make calibrate-settings sync in python (#616)
BREAKING CHANGE: calibrate settings is no longer async
This commit is contained in:
@@ -2,7 +2,6 @@ import ezkl
|
||||
import os
|
||||
import pytest
|
||||
import json
|
||||
import asyncio
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
@@ -156,7 +155,8 @@ def test_gen_srs():
|
||||
assert os.path.isfile(params_k20_path)
|
||||
|
||||
|
||||
async def calibrate_over_user_range():
|
||||
|
||||
def test_calibrate_over_user_range():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
@@ -183,20 +183,14 @@ async def calibrate_over_user_range():
|
||||
model_path, output_path, py_run_args=run_args)
|
||||
assert res == True
|
||||
|
||||
res = await ezkl.calibrate_settings(
|
||||
res = ezkl.calibrate_settings(
|
||||
data_path, model_path, output_path, "resources", [0, 1, 2])
|
||||
assert res == True
|
||||
assert os.path.isfile(output_path)
|
||||
|
||||
|
||||
def test_calibrate_calibrate_over_user_range():
|
||||
"""
|
||||
Test for calibrate
|
||||
"""
|
||||
asyncio.run(calibrate_over_user_range())
|
||||
|
||||
|
||||
async def calibrate():
|
||||
def test_calibrate():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
@@ -223,19 +217,12 @@ async def calibrate():
|
||||
model_path, output_path, py_run_args=run_args)
|
||||
assert res == True
|
||||
|
||||
res = await ezkl.calibrate_settings(
|
||||
res = ezkl.calibrate_settings(
|
||||
data_path, model_path, output_path, "resources")
|
||||
assert res == True
|
||||
assert os.path.isfile(output_path)
|
||||
|
||||
|
||||
def test_calibrate():
|
||||
"""
|
||||
Test for calibrate
|
||||
"""
|
||||
asyncio.run(calibrate())
|
||||
|
||||
|
||||
def test_model_compile():
|
||||
"""
|
||||
Test for model compilation/serialization
|
||||
@@ -559,7 +546,7 @@ def test_verify_evm():
|
||||
assert res == True
|
||||
|
||||
|
||||
async def aggregate_and_verify_aggr():
|
||||
def test_aggregate_and_verify_aggr():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
@@ -588,7 +575,7 @@ async def aggregate_and_verify_aggr():
|
||||
res = ezkl.gen_settings(model_path, settings_path)
|
||||
assert res == True
|
||||
|
||||
res = await ezkl.calibrate_settings(
|
||||
res = ezkl.calibrate_settings(
|
||||
data_path, model_path, settings_path, "resources")
|
||||
assert res == True
|
||||
assert os.path.isfile(settings_path)
|
||||
@@ -665,14 +652,7 @@ async def aggregate_and_verify_aggr():
|
||||
assert res == True
|
||||
|
||||
|
||||
def test_aggregate_and_verify_aggr():
|
||||
"""
|
||||
Tests for aggregated proof and verifying aggregate proof
|
||||
"""
|
||||
asyncio.run(aggregate_and_verify_aggr())
|
||||
|
||||
|
||||
async def evm_aggregate_and_verify_aggr():
|
||||
def test_evm_aggregate_and_verify_aggr():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
@@ -697,7 +677,7 @@ async def evm_aggregate_and_verify_aggr():
|
||||
settings_path,
|
||||
)
|
||||
|
||||
await ezkl.calibrate_settings(
|
||||
ezkl.calibrate_settings(
|
||||
data_path,
|
||||
model_path,
|
||||
settings_path,
|
||||
@@ -807,25 +787,21 @@ async def evm_aggregate_and_verify_aggr():
|
||||
# assert res == True
|
||||
|
||||
|
||||
def test_evm_aggregate_and_verify_aggr():
|
||||
"""
|
||||
Tests for aggregated proof and verifying aggregate proof
|
||||
"""
|
||||
asyncio.run(evm_aggregate_and_verify_aggr())
|
||||
|
||||
def get_examples():
|
||||
EXAMPLES_OMIT = [
|
||||
# these are too large
|
||||
'mobilenet_large',
|
||||
'mobilenet',
|
||||
'doodles',
|
||||
'nanoGPT',
|
||||
# these fails for some reason
|
||||
"self_attention",
|
||||
'multihead_attention',
|
||||
'large_op_graph',
|
||||
'1l_instance_norm',
|
||||
'variable_cnn',
|
||||
'accuracy',
|
||||
'linear_regression'
|
||||
'linear_regression',
|
||||
"mnist_gan",
|
||||
]
|
||||
examples = []
|
||||
for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):
|
||||
@@ -851,9 +827,15 @@ def test_all_examples(model_file, input_file):
|
||||
witness_path = os.path.join(folder_path, 'witness.json')
|
||||
proof_path = os.path.join(folder_path, 'proof.json')
|
||||
|
||||
print("Testing example: ", model_file)
|
||||
res = ezkl.gen_settings(model_file, settings_path)
|
||||
assert res
|
||||
|
||||
res = ezkl.calibrate_settings(
|
||||
input_file, model_file, settings_path, "resources")
|
||||
assert res
|
||||
|
||||
print("Compiling example: ", model_file)
|
||||
res = ezkl.compile_circuit(model_file, compiled_model_path, settings_path)
|
||||
assert res
|
||||
|
||||
@@ -865,8 +847,10 @@ def test_all_examples(model_file, input_file):
|
||||
|
||||
# generate the srs file if the path does not exist
|
||||
if not os.path.exists(srs_path):
|
||||
print("Generating srs file: ", srs_path)
|
||||
ezkl.gen_srs(os.path.join(folder_path, srs_path), logrows)
|
||||
|
||||
print("Setting up example: ", model_file)
|
||||
res = ezkl.setup(
|
||||
compiled_model_path,
|
||||
vk_path,
|
||||
@@ -877,9 +861,11 @@ def test_all_examples(model_file, input_file):
|
||||
assert os.path.isfile(vk_path)
|
||||
assert os.path.isfile(pk_path)
|
||||
|
||||
print("Generating witness for example: ", model_file)
|
||||
res = ezkl.gen_witness(input_file, compiled_model_path, witness_path)
|
||||
assert os.path.isfile(witness_path)
|
||||
|
||||
print("Proving example: ", model_file)
|
||||
ezkl.prove(
|
||||
witness_path,
|
||||
compiled_model_path,
|
||||
@@ -890,6 +876,8 @@ def test_all_examples(model_file, input_file):
|
||||
)
|
||||
|
||||
assert os.path.isfile(proof_path)
|
||||
|
||||
print("Verifying example: ", model_file)
|
||||
res = ezkl.verify(
|
||||
proof_path,
|
||||
settings_path,
|
||||
|
||||
Reference in New Issue
Block a user