mirror of https://github.com/MPCStats/zk-stats-lib.git (synced 2026-01-10 05:57:55 -05:00)

fix: make cli work again with new computation ui
@@ -147,7 +147,7 @@
 "source": [
 "## Step 2\n",
 "- User defines their computation in a function with signature `computation(state: State, x: list[torch.Tensor])`.\n",
-"- Prover calls `create_model(computation)` to derive the actual model.\n",
+"- Prover calls `computation_to_model(computation)` to derive the actual model.\n",
 "- Prover calls `prover_gen_settings`: export onnx file and compute the settings required by `ezkl.calibrate_settings`"
 ]
 },
@@ -181,7 +181,7 @@
 }
 ],
 "source": [
-"from zkstats.computation import State, create_model\n",
+"from zkstats.computation import State, computation_to_model\n",
 "\n",
 "\n",
 "def computation(state: State, x: list[torch.Tensor]):\n",
@@ -190,7 +190,7 @@
 " out_1 = state.median(x_0)\n",
 " return state.mean(torch.tensor([out_0, out_1]).reshape(1,-1,1))\n",
 "\n",
-"_, prover_model = create_model(computation)\n",
+"_, prover_model = computation_to_model(computation)\n",
 "prover_gen_settings([data_path], comb_data_path, prover_model, prover_model_path, \"default\", \"resources\", settings_path)\n"
 ]
 },
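
For orientation, the notebook cells above boil down to the flow sketched below. This is a reconstruction, not the notebook itself: the definitions of x_0 and out_0 do not appear in the hunks, so they are filled in as assumptions, the file paths are hypothetical placeholders, and prover_gen_settings is imported from zkstats.core here, mirroring the import in the CLI module further down.

import torch

from zkstats.computation import State, computation_to_model
from zkstats.core import prover_gen_settings

# Hypothetical placeholder paths; the notebook defines these in earlier cells.
data_path = "data.json"
comb_data_path = "comb_data.json"
prover_model_path = "model.onnx"
settings_path = "settings.json"


def computation(state: State, x: list[torch.Tensor]):
    # Only out_1 and the return line are visible in the hunk; x_0 and out_0 are assumed here.
    x_0 = x[0]
    out_0 = state.mean(x_0)
    out_1 = state.median(x_0)
    return state.mean(torch.tensor([out_0, out_1]).reshape(1, -1, 1))


# Renamed entry point: computation_to_model (formerly create_model) derives the actual model.
_, prover_model = computation_to_model(computation)
prover_gen_settings([data_path], comb_data_path, prover_model, prover_model_path, "default", "resources", settings_path)
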
@@ -2,7 +2,7 @@ import statistics
 import torch
 import torch

-from zkstats.computation import State, create_model
+from zkstats.computation import State, computation_to_model
 from zkstats.ops import Mean, Median

 from .helpers import compute
@@ -15,7 +15,7 @@ def computation(state: State, x: list[torch.Tensor]):


 def test_computation(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float):
-    state, model = create_model(computation, error)
+    state, model = computation_to_model(computation, error)
     compute(tmp_path, [column_0, column_1], model)
     assert state.current_op_index == 3

@@ -7,6 +7,7 @@ import click
 import torch

 from .core import prover_gen_proof, prover_gen_settings, verifier_setup, verifier_verify, gen_data_commitment
+from .computation import computation_to_model, State

 cwd = os.getcwd()
 # TODO: Should make this configurable
@@ -18,7 +19,6 @@ pk_path = f"{output_dir}/model.pk"
 vk_path = f"{output_dir}/model.vk"
 proof_path = f"{output_dir}/model.pf"
 settings_path = f"{output_dir}/settings.json"
-srs_path = f"{output_dir}/kzg.srs"
 witness_path = f"{output_dir}/witness.json"
 comb_data_path = f"{output_dir}/comb_data.json"

@@ -29,11 +29,11 @@ def cli():


 @click.command()
-@click.argument('model_path')
+@click.argument('computation_path')
 @click.argument('data_path')
-def prove(model_path: str, data_path: str):
-    model = load_model(model_path)
-    print("Loaded model:", model)
+def prove(computation_path: str, data_path: str):
+    computation = load_computation(computation_path)
+    _, model = computation_to_model(computation)
     prover_gen_settings(
         [data_path],
         comb_data_path,
@@ -44,7 +44,7 @@ def prove(model_path: str, data_path: str):
         settings_path,
     )
     verifier_setup(
-        model_path,
+        model_onnx_path,
         compiled_model_path,
         settings_path,
         vk_path,
@@ -89,13 +89,13 @@ def main():
     cli()


-def load_model(module_path: str) -> Type[torch.nn.Module]:
+def load_computation(module_path: str) -> Type[torch.nn.Module]:
     """
     Load a model from a Python module.
     """
     # FIXME: This is unsafe since malicious code can be executed

-    model_name = "Model"
+    model_name = "computation"
     module_name = os.path.splitext(os.path.basename(module_path))[0]
     spec = importlib.util.spec_from_file_location(module_name, module_path)
     module = importlib.util.module_from_spec(spec)
@@ -103,10 +103,9 @@ def load_model(module_path: str) -> Type[torch.nn.Module]:
     spec.loader.exec_module(module)

     try:
-        cls = getattr(module, model_name)
+        return getattr(module, model_name)
     except AttributeError:
-        raise ImportError(f"class {model_name} does not exist in {module_name}")
-    return cls
+        raise ImportError(f"{model_name=} does not exist in {module_name=}")


 # Register commands
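
The practical effect of the load_model -> load_computation change for CLI users: the first argument to prove is now a path to a Python file defining a top-level function literally named computation (the attribute load_computation fetches via getattr), rather than a file defining a Model class. A minimal sketch of such a file, with a hypothetical name and a deliberately trivial body:

# my_computation.py -- hypothetical user-supplied file passed to prove as computation_path
import torch

from zkstats.computation import State


# load_computation() looks up this exact attribute name ("computation") in the module.
def computation(state: State, x: list[torch.Tensor]) -> torch.Tensor:
    return state.mean(x[0])

prove then converts it with `_, model = computation_to_model(computation)` before generating settings, as shown in the hunks above.
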
@@ -101,10 +101,10 @@ class IModel(nn.Module):
 # out_0 = state.median(x[0])
 # out_1 = state.median(x[1])
 # return state.mean(torch.tensor([out_0, out_1]).reshape(1,-1,1))
-TComputation = Callable[[State, list[torch.Tensor]], tuple[IsResultPrecise, torch.Tensor]]
+TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]


-def create_model(computation: TComputation, error: float = DEFAULT_ERROR) -> tuple[State, Type[IModel]]:
+def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR) -> tuple[State, Type[IModel]]:
     """
     Create a torch model from a `computation` function defined by user
     """
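
The TComputation change above is the user-visible contract shift: a computation was previously typed as returning a tuple[IsResultPrecise, torch.Tensor], and now returns only the result tensor. A short sketch of calling the renamed helper under the new contract; the computation body follows the commented example in the same file, and error=0.01 is an arbitrary illustrative value (DEFAULT_ERROR applies when omitted).

import torch

from zkstats.computation import State, computation_to_model


def computation(state: State, x: list[torch.Tensor]) -> torch.Tensor:
    # New contract: return only the result tensor, not an (IsResultPrecise, result) pair.
    out_0 = state.median(x[0])
    out_1 = state.median(x[1])
    return state.mean(torch.tensor([out_0, out_1]).reshape(1, -1, 1))


# Returns the State (which records the ops used; the test above checks state.current_op_index)
# and a torch model class derived from `computation`.
state, model = computation_to_model(computation, error=0.01)
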