mirror of
https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-09 13:38:02 -05:00
feat: support prove
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -52,7 +52,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -71,16 +71,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run -i ../../core.py"
|
||||
"%run -i ../../zkstats/core.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -119,7 +119,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -182,7 +182,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -224,7 +224,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -273,7 +273,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Here verifier & prover can concurrently call setup since all params are public to get pk. \n",
|
||||
"# Here verifier & prover can concurrently call setup since all params are public to get pk.\n",
|
||||
"# Here write as verifier function to emphasize that verifier must calculate its own vk to be sure\n",
|
||||
"verifier_setup(verifier_model_path, verifier_compiled_model_path, settings_path, srs_path,vk_path, pk_path )\n",
|
||||
"\n",
|
||||
@@ -285,7 +285,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
||||
@@ -75,7 +75,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run -i ../../core.py"
|
||||
"%run -i ../../zkstats/core.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -75,7 +75,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run -i ../../core.py"
|
||||
"%run -i ../../zkstats/core.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -75,7 +75,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run -i ../../core.py"
|
||||
"%run -i ../../zkstats/core.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -75,7 +75,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run -i ../../core.py"
|
||||
"%run -i ../../zkstats/core.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -75,7 +75,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run -i ../../core.py"
|
||||
"%run -i ../../zkstats/core.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -75,7 +75,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run -i ../../core.py"
|
||||
"%run -i ../../zkstats/core.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -75,7 +75,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run -i ../../core.py"
|
||||
"%run -i ../../zkstats/core.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -75,7 +75,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run -i ../../core.py"
|
||||
"%run -i ../../zkstats/core.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -75,7 +75,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run -i ../../core.py"
|
||||
"%run -i ../../zkstats/core.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -75,7 +75,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run -i ../../core.py"
|
||||
"%run -i ../../zkstats/core.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
74
zkstats/cli.py
Normal file
74
zkstats/cli.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import os
|
||||
import click
|
||||
|
||||
from .core import prover_gen_proof, prover_setup, load_model
|
||||
|
||||
# Directory layout: everything under shared/ is public to both prover and
# verifier; everything under prover/ is private to the prover.
# (os.path.dirname('shared/') was a roundabout way to say 'shared', and
# single-argument os.path.join(...) was a no-op — both simplified.)
os.makedirs('shared', exist_ok=True)
os.makedirs('prover', exist_ok=True)

verifier_model_path = os.path.join('shared', 'verifier.onnx')
prover_model_path = os.path.join('prover', 'prover.onnx')
verifier_compiled_model_path = os.path.join('shared', 'verifier.compiled')
prover_compiled_model_path = os.path.join('prover', 'prover.compiled')
pk_path = os.path.join('shared', 'test.pk')
vk_path = os.path.join('shared', 'test.vk')
proof_path = os.path.join('shared', 'test.pf')
settings_path = os.path.join('shared', 'settings.json')
srs_path = os.path.join('shared', 'kzg.srs')
witness_path = os.path.join('prover', 'witness.json')
# this is private to prover since it contains actual data
data_path = 'data.json'
comb_data_path = os.path.join('prover', 'comb_data.json')
|
||||
|
||||
|
||||
@click.group()
def cli():
    # Root command group; subcommands (prove, verify) are attached via
    # cli.add_command(...) at module bottom.
    pass
|
||||
|
||||
@click.command()
@click.argument('model_path')
def prove(model_path: str):
    """Run the full prover pipeline for the model defined at MODEL_PATH.

    MODEL_PATH is a Python file that defines a `Model` class.  The command
    loads it, runs setup (data prep, ONNX export, settings, keygen) and then
    generates a proof, writing artifacts to the shared/ and prover/ paths
    configured above.
    """
    # Removed leftover template output ("Hello, ...") and debug prints.
    prover_model = load_model(model_path)
    prover_setup(
        [data_path],
        comb_data_path,
        prover_model,
        prover_model_path,
        prover_compiled_model_path,
        "default",      # scale
        "resources",    # calibration mode
        settings_path,
        srs_path,
        vk_path,
        pk_path,
    )
    prover_gen_proof(
        prover_model_path,
        comb_data_path,
        witness_path,
        prover_compiled_model_path,
        settings_path,
        proof_path,
        pk_path,
        srs_path,
    )
|
||||
|
||||
|
||||
@click.command()
def verify():
    """Verify a generated proof (not implemented yet)."""
    # The original used an f-string with no placeholders (ruff F541);
    # a plain literal is equivalent.  Message kept byte-identical.
    click.echo("Hello, verify!")
|
||||
|
||||
|
||||
# Attach subcommands to the root group.
cli.add_command(prove)
cli.add_command(verify)


def main():
    # Entry point used both by `python -m` style execution and by a
    # console-script hook: simply dispatch to the click group.
    cli()


if __name__ == "__main__":
    main()
|
||||
@@ -1,12 +1,43 @@
|
||||
import sys
|
||||
import importlib.util
|
||||
from typing import Type
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import ezkl
|
||||
import os
|
||||
import numpy as np
|
||||
import json
|
||||
import time
|
||||
|
||||
|
||||
def load_model(module_path: str) -> Type[torch.nn.Module]:
    """Load the ``Model`` class from the Python file at *module_path*.

    The file is imported as a fresh module (registered in ``sys.modules``
    under its basename) and its ``Model`` attribute is returned.

    :raises ImportError: if the file defines no ``Model`` class.
    """
    # FIXME: This is unsafe since malicious code can be executed
    model_name = "Model"
    module_name, _ = os.path.splitext(os.path.basename(module_path))
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)

    if not hasattr(module, model_name):
        raise ImportError(f"class {model_name} does not exist in {module_name}")
    return getattr(module, model_name)
|
||||
|
||||
|
||||
# Export model
|
||||
def export_onnx(model, data_tensor_array, model_loc):
|
||||
def export_onnx(model: torch.nn.Module, data_tensor_array, model_loc):
|
||||
circuit = model()
|
||||
# Try running `prepare()` if it exists
|
||||
try:
|
||||
circuit.prepare(data_tensor_array)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
@@ -16,6 +47,7 @@ def export_onnx(model, data_tensor_array, model_loc):
|
||||
|
||||
# Flips the neural net into inference mode
|
||||
circuit.eval()
|
||||
print("!@# circuit.eval=", circuit)
|
||||
input_names = []
|
||||
dynamic_axes = {}
|
||||
|
||||
@@ -37,9 +69,9 @@ def export_onnx(model, data_tensor_array, model_loc):
|
||||
input_names = input_names, # the model's input names
|
||||
output_names = ['output'], # the model's output names
|
||||
dynamic_axes=dynamic_axes)
|
||||
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
|
||||
# mode is either "accuracy" or "resources"
|
||||
def gen_settings(comb_data_path, onnx_filename, scale, mode, settings_filename):
|
||||
@@ -67,8 +99,8 @@ def gen_settings(comb_data_path, onnx_filename, scale, mode, settings_filename):
|
||||
print("scale: ", scale)
|
||||
print("setting: ", f_setting.read())
|
||||
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
|
||||
def verifier_define_calculation(verifier_model, verifier_model_path, dummy_data_path_array):
|
||||
# load data from dummy_data_path_array into dummy_data_tensor_array
|
||||
@@ -77,31 +109,42 @@ def verifier_define_calculation(verifier_model, verifier_model_path, dummy_data_
|
||||
dummy_data = np.array(json.loads(open(path, "r").read())["input_data"][0])
|
||||
dummy_data_tensor_array.append(torch.reshape(torch.tensor(dummy_data), (1, len(dummy_data),1 )))
|
||||
# export onnx file
|
||||
export_onnx(verifier_model,dummy_data_tensor_array, verifier_model_path)
|
||||
export_onnx(verifier_model, dummy_data_tensor_array, verifier_model_path)
|
||||
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
|
||||
# we decide to not have comb_data_path as parameter since a bit redundant parameter.
|
||||
def prover_gen_settings(data_path_array, comb_data_path, prover_model,prover_model_path, scale, mode, settings_path):
|
||||
def process_data(data_path_array, comb_data_path) -> list[Tensor]:
    """Load 1-d input arrays from the JSON files in *data_path_array*.

    Each file must contain ``{"input_data": [[...]]}``.  Every array is
    reshaped into a ``(1, n, 1)`` tensor; the raw values are also combined
    and serialized to *comb_data_path* as ``{"input_data": [...]}`` for the
    ezkl pipeline.

    :return: list of ``(1, n, 1)`` tensors, one per input file.
    """
    data_tensor_array = []
    comb_data = []
    for path in data_path_array:
        # Use a context manager so the input file handle is closed
        # deterministically (the original leaked it).
        with open(path, "r") as f:
            data = np.array(json.load(f)["input_data"][0])
        print("!@# data=", data)
        data_tensor = torch.tensor(data)
        data_tensor_array.append(torch.reshape(data_tensor, (1, len(data), 1)))
        comb_data.append(data.tolist())
    # Serialize combined data; `comb_data` comes from `data`.
    with open(comb_data_path, "w") as f:
        json.dump(dict(input_data=comb_data), f)
    return data_tensor_array
|
||||
|
||||
# we decide to not have comb_data_path as parameter since a bit redundant parameter.
def prover_gen_settings(data_path_array, comb_data_path, prover_model, prover_model_path, scale, mode, settings_path):
    """Prepare prover data, export the model to ONNX, and generate settings.

    Thin orchestration over process_data / export_onnx / gen_settings.
    """
    tensors = process_data(data_path_array, comb_data_path)

    # Export the circuit model first, then derive + calibrate settings
    # from the combined data file.
    export_onnx(prover_model, tensors, prover_model_path)
    gen_settings(comb_data_path, prover_model_path, scale, mode, settings_path)
|
||||
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
|
||||
# Here prover can concurrently call this since all params are public to get pk.
|
||||
# Here prover can concurrently call this since all params are public to get pk.
|
||||
# Here write as verifier function to emphasize that verifier must calculate its own vk to be sure
|
||||
def verifier_setup(verifier_model_path, verifier_compiled_model_path, settings_path, srs_path,vk_path, pk_path ):
|
||||
# compile circuit
|
||||
@@ -111,7 +154,7 @@ def verifier_setup(verifier_model_path, verifier_compiled_model_path, settings_p
|
||||
# srs path
|
||||
res = ezkl.get_srs(srs_path, settings_path)
|
||||
|
||||
# setupt vk, pk param for use..... prover can use same pk or can init their own!
|
||||
# setup vk, pk param for use..... prover can use same pk or can init their own!
|
||||
print("==== setting up ezkl ====")
|
||||
start_time = time.time()
|
||||
res = ezkl.setup(
|
||||
@@ -128,43 +171,77 @@ def verifier_setup(verifier_model_path, verifier_compiled_model_path, settings_p
|
||||
assert os.path.isfile(pk_path)
|
||||
assert os.path.isfile(settings_path)
|
||||
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
|
||||
def prover_gen_proof(prover_model_path, comb_data_path, witness_path, prover_compiled_model_path, settings_path, proof_path, pk_path, srs_path):
|
||||
res = ezkl.compile_circuit(prover_model_path, prover_compiled_model_path, settings_path)
|
||||
assert res == True
|
||||
# now generate the witness file
|
||||
print('==== Generating Witness ====')
|
||||
witness = ezkl.gen_witness(comb_data_path, prover_compiled_model_path, witness_path)
|
||||
assert os.path.isfile(witness_path)
|
||||
# print(witness["outputs"])
|
||||
settings = json.load(open(settings_path))
|
||||
output_scale = settings['model_output_scales']
|
||||
print("witness boolean: ", ezkl.vecu64_to_float(witness['outputs'][0][0], output_scale[0]))
|
||||
for i in range(len(witness['outputs'][1])):
|
||||
print("witness result", i+1,":", ezkl.vecu64_to_float(witness['outputs'][1][i], output_scale[1]))
|
||||
# we decide to not have comb_data_path as parameter since a bit redundant parameter.
def prover_setup(
    data_path_array,
    comb_data_path,
    prover_model,
    prover_model_path,
    prover_compiled_model_path,
    scale,
    mode,
    settings_path,
    srs_path,
    vk_path,
    pk_path,
):
    """End-to-end prover-side setup.

    Prepares the input data, exports the model to ONNX, generates and
    calibrates settings, then runs key generation.  Setup is identical for
    prover and verifier, so the verifier routine is reused for the last step.
    """
    tensors = process_data(data_path_array, comb_data_path)

    # export onnx file
    export_onnx(prover_model, tensors, prover_model_path)
    # gen + calibrate setting
    gen_settings(comb_data_path, prover_model_path, scale, mode, settings_path)
    # Same computation the verifier performs to obtain vk/pk.
    verifier_setup(prover_model_path, prover_compiled_model_path, settings_path, srs_path, vk_path, pk_path)
|
||||
|
||||
print("proof: " ,res)
|
||||
end_time = time.time()
|
||||
time_gen_prf = end_time -start_time
|
||||
print(f"Time gen prf: {time_gen_prf} seconds")
|
||||
assert os.path.isfile(proof_path)
|
||||
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
def prover_gen_proof(
    prover_model_path,
    comb_data_path,
    witness_path,
    prover_compiled_model_path,
    settings_path,
    proof_path,
    pk_path,
    srs_path,
):
    """Compile the circuit, generate a witness, and produce a proof.

    Writes the compiled circuit, witness, and proof to the given paths.
    Asserts that each artifact exists after the corresponding ezkl call.
    """
    res = ezkl.compile_circuit(prover_model_path, prover_compiled_model_path, settings_path)
    assert res == True
    # now generate the witness file
    print('==== Generating Witness ====')
    witness = ezkl.gen_witness(comb_data_path, prover_compiled_model_path, witness_path)
    assert os.path.isfile(witness_path)
    # Close the settings file deterministically (the original leaked the handle).
    with open(settings_path) as f:
        settings = json.load(f)
    output_scale = settings['model_output_scales']
    # outputs[0][0] is the boolean statement; outputs[1] holds the result values.
    print("witness boolean: ", ezkl.vecu64_to_float(witness['outputs'][0][0], output_scale[0]))
    for i in range(len(witness['outputs'][1])):
        print("witness result", i + 1, ":", ezkl.vecu64_to_float(witness['outputs'][1][i], output_scale[1]))

    # GENERATE A PROOF
    print("==== Generating Proof ====")
    start_time = time.time()
    res = ezkl.prove(
        witness_path,
        prover_compiled_model_path,
        pk_path,
        proof_path,
        srs_path,
        "single",
    )

    print("proof: ", res)
    end_time = time.time()
    print(f"Time gen prf: {end_time - start_time} seconds")
    assert os.path.isfile(proof_path)
|
||||
|
||||
# ===================================================================================================
|
||||
# ===================================================================================================
|
||||
|
||||
def verifier_verify(proof_path, settings_path, vk_path, srs_path):
|
||||
# enforce boolean statement to be true
|
||||
46
zkstats/models.py
Normal file
46
zkstats/models.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import Any
|
||||
from abc import ABC, abstractmethod
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
|
||||
class BaseZKStatsModel(ABC, nn.Module):
    """Abstract base class for all zk-stats circuit models."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, X: Any) -> Any:
        """
        :param X: a tensor of shape (1, n, 1)
        :return: a tuple of (bool, float)
        """
|
||||
|
||||
|
||||
class NoDivisionModel(BaseZKStatsModel):
    """Base for models whose statement avoids in-circuit division.

    Subclasses store the claimed result as parameter ``w`` (set by
    ``prepare``) and re-verify it in ``forward`` within an error tolerance.
    """

    def __init__(self):
        super().__init__()
        # w represents mean in this case

    @abstractmethod
    def prepare(self, X: Any):
        # BUG FIX: the original signature was `prepare(expected_output)`
        # without `self`, so the instance would have bound to the first
        # parameter when called as a method.  Parameter renamed to X to
        # match the concrete implementation (MeanModel.prepare).
        ...

    @abstractmethod
    def forward(self, X: Any) -> tuple[float, float]:
        # some expression of tolerance to error in the inference
        # must have w first!
        ...
|
||||
|
||||
class MeanModel(NoDivisionModel):
    """Proves that ``w`` equals the mean of the input within 1% relative error."""

    def __init__(self):
        super().__init__()

    def prepare(self, X: Any):
        # w represents mean in this case
        self.w = nn.Parameter(data=torch.mean(X[0]), requires_grad=False)

    def forward(self, X: Any) -> tuple[float, float]:
        # Checks |sum(X) - n*w| < 0.01 * n * w, i.e. w is the mean to
        # within 1% relative error.  Must return w (not X) second!
        # NOTE(review): when the mean is negative the right-hand bound is
        # negative, so the check can never pass — confirm w > 0 is the
        # intended domain.
        n = X.size()[1]
        return (torch.abs(torch.sum(X) - n * self.w) < 0.01 * n * self.w, self.w)
|
||||
Reference in New Issue
Block a user