fix: make cli work again with new computation ui

This commit is contained in:
mhchia
2024-01-26 22:07:15 +08:00
parent e3e10168b6
commit dd7973aa3d
4 changed files with 17 additions and 18 deletions

View File

@@ -147,7 +147,7 @@
"source": [
"## Step 2\n",
"- User defines their computation in a function with signature `computation(state: State, x: list[torch.Tensor])`.\n",
"- Prover calls `create_model(computation)` to derive the actual model.\n",
"- Prover calls `computation_to_model(computation)` to derive the actual model.\n",
"- Prover calls `prover_gen_settings`: export onnx file and compute the settings required by `ezkl.calibrate_settings`"
]
},
@@ -181,7 +181,7 @@
}
],
"source": [
"from zkstats.computation import State, create_model\n",
"from zkstats.computation import State, computation_to_model\n",
"\n",
"\n",
"def computation(state: State, x: list[torch.Tensor]):\n",
@@ -190,7 +190,7 @@
" out_1 = state.median(x_0)\n",
" return state.mean(torch.tensor([out_0, out_1]).reshape(1,-1,1))\n",
"\n",
"_, prover_model = create_model(computation)\n",
"_, prover_model = computation_to_model(computation)\n",
"prover_gen_settings([data_path], comb_data_path, prover_model, prover_model_path, \"default\", \"resources\", settings_path)\n"
]
},

View File

@@ -2,7 +2,7 @@ import statistics
import torch
import torch
from zkstats.computation import State, create_model
from zkstats.computation import State, computation_to_model
from zkstats.ops import Mean, Median
from .helpers import compute
@@ -15,7 +15,7 @@ def computation(state: State, x: list[torch.Tensor]):
def test_computation(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float):
state, model = create_model(computation, error)
state, model = computation_to_model(computation, error)
compute(tmp_path, [column_0, column_1], model)
assert state.current_op_index == 3

View File

@@ -7,6 +7,7 @@ import click
import torch
from .core import prover_gen_proof, prover_gen_settings, verifier_setup, verifier_verify, gen_data_commitment
from .computation import computation_to_model, State
cwd = os.getcwd()
# TODO: Should make this configurable
@@ -18,7 +19,6 @@ pk_path = f"{output_dir}/model.pk"
vk_path = f"{output_dir}/model.vk"
proof_path = f"{output_dir}/model.pf"
settings_path = f"{output_dir}/settings.json"
srs_path = f"{output_dir}/kzg.srs"
witness_path = f"{output_dir}/witness.json"
comb_data_path = f"{output_dir}/comb_data.json"
@@ -29,11 +29,11 @@ def cli():
@click.command()
@click.argument('model_path')
@click.argument('computation_path')
@click.argument('data_path')
def prove(model_path: str, data_path: str):
model = load_model(model_path)
print("Loaded model:", model)
def prove(computation_path: str, data_path: str):
computation = load_computation(computation_path)
_, model = computation_to_model(computation)
prover_gen_settings(
[data_path],
comb_data_path,
@@ -44,7 +44,7 @@ def prove(model_path: str, data_path: str):
settings_path,
)
verifier_setup(
model_path,
model_onnx_path,
compiled_model_path,
settings_path,
vk_path,
@@ -89,13 +89,13 @@ def main():
cli()
def load_model(module_path: str) -> Type[torch.nn.Module]:
def load_computation(module_path: str) -> Type[torch.nn.Module]:
"""
Load a model from a Python module.
"""
# FIXME: This is unsafe since malicious code can be executed
model_name = "Model"
model_name = "computation"
module_name = os.path.splitext(os.path.basename(module_path))[0]
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
@@ -103,10 +103,9 @@ def load_model(module_path: str) -> Type[torch.nn.Module]:
spec.loader.exec_module(module)
try:
cls = getattr(module, model_name)
return getattr(module, model_name)
except AttributeError:
raise ImportError(f"class {model_name} does not exist in {module_name}")
return cls
raise ImportError(f"{model_name=} does not exist in {module_name=}")
# Register commands

View File

@@ -101,10 +101,10 @@ class IModel(nn.Module):
# out_0 = state.median(x[0])
# out_1 = state.median(x[1])
# return state.mean(torch.tensor([out_0, out_1]).reshape(1,-1,1))
TComputation = Callable[[State, list[torch.Tensor]], tuple[IsResultPrecise, torch.Tensor]]
TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]
def create_model(computation: TComputation, error: float = DEFAULT_ERROR) -> tuple[State, Type[IModel]]:
def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR) -> tuple[State, Type[IModel]]:
"""
Create a torch model from a `computation` function defined by user
"""