feat: support prove

This commit is contained in:
mhchia
2023-12-12 17:19:04 +08:00
parent 2ddb71b645
commit 10c400c4fb
16 changed files with 460 additions and 174 deletions

View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -52,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -71,16 +71,16 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%run -i ../../core.py"
"%run -i ../../zkstats/core.py"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -119,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -182,7 +182,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -224,7 +224,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -273,7 +273,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Here verifier & prover can concurrently call setup since all params are public to get pk. \n",
"# Here verifier & prover can concurrently call setup since all params are public to get pk.\n",
"# Here write as verifier function to emphasize that verifier must calculate its own vk to be sure\n",
"verifier_setup(verifier_model_path, verifier_compiled_model_path, settings_path, srs_path,vk_path, pk_path )\n",
"\n",
@@ -285,7 +285,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": null,
"metadata": {},
"outputs": [
{

View File

@@ -75,7 +75,7 @@
"metadata": {},
"outputs": [],
"source": [
"%run -i ../../core.py"
"%run -i ../../zkstats/core.py"
]
},
{

File diff suppressed because one or more lines are too long

View File

@@ -75,7 +75,7 @@
"metadata": {},
"outputs": [],
"source": [
"%run -i ../../core.py"
"%run -i ../../zkstats/core.py"
]
},
{

File diff suppressed because one or more lines are too long

View File

@@ -75,7 +75,7 @@
"metadata": {},
"outputs": [],
"source": [
"%run -i ../../core.py"
"%run -i ../../zkstats/core.py"
]
},
{

View File

@@ -75,7 +75,7 @@
"metadata": {},
"outputs": [],
"source": [
"%run -i ../../core.py"
"%run -i ../../zkstats/core.py"
]
},
{

View File

@@ -75,7 +75,7 @@
"metadata": {},
"outputs": [],
"source": [
"%run -i ../../core.py"
"%run -i ../../zkstats/core.py"
]
},
{

View File

@@ -75,7 +75,7 @@
"metadata": {},
"outputs": [],
"source": [
"%run -i ../../core.py"
"%run -i ../../zkstats/core.py"
]
},
{

View File

@@ -75,7 +75,7 @@
"metadata": {},
"outputs": [],
"source": [
"%run -i ../../core.py"
"%run -i ../../zkstats/core.py"
]
},
{

View File

@@ -75,7 +75,7 @@
"metadata": {},
"outputs": [],
"source": [
"%run -i ../../core.py"
"%run -i ../../zkstats/core.py"
]
},
{

View File

@@ -75,7 +75,7 @@
"metadata": {},
"outputs": [],
"source": [
"%run -i ../../core.py"
"%run -i ../../zkstats/core.py"
]
},
{

View File

@@ -75,7 +75,7 @@
"metadata": {},
"outputs": [],
"source": [
"%run -i ../../core.py"
"%run -i ../../zkstats/core.py"
]
},
{

74
zkstats/cli.py Normal file
View File

@@ -0,0 +1,74 @@
import os
import click
from .core import prover_gen_proof, prover_setup, load_model
# Make sure the two output directories exist before anything is written to them.
os.makedirs('shared', exist_ok=True)
os.makedirs('prover', exist_ok=True)

# Model files (ONNX export and compiled circuit).
verifier_model_path = 'shared/verifier.onnx'
prover_model_path = 'prover/prover.onnx'
verifier_compiled_model_path = 'shared/verifier.compiled'
prover_compiled_model_path = 'prover/prover.compiled'

# Keys, proof, settings and SRS all live under shared/.
pk_path = 'shared/test.pk'
vk_path = 'shared/test.vk'
proof_path = 'shared/test.pf'
settings_path = 'shared/settings.json'
srs_path = 'shared/kzg.srs'

witness_path = 'prover/witness.json'
# this is private to prover since it contains actual data
data_path = 'data.json'
comb_data_path = 'prover/comb_data.json'
# Root click command group. It does nothing by itself; the sub-commands are
# attached to it with ``cli.add_command`` further down in this module.
# NOTE: documented with comments rather than a docstring because click would
# show a docstring as the group's help text.
@click.group()
def cli() -> None:
    pass
# ``prove`` sub-command: load the ``Model`` class from MODEL_PATH, run the
# prover setup (data processing, ONNX export, settings, keys) and then
# generate a proof, using the module-level artifact paths.
# NOTE: kept as comments on purpose -- a docstring would become the click
# help text of the command.
@click.command()
@click.argument('model_path')
def prove(model_path: str):
    click.echo(f"Hello, {model_path}!")
    model_cls = load_model(model_path)
    print("!@# prover_model=", model_cls)

    # Export the model, calibrate settings and create vk/pk.
    prover_setup(
        data_path_array=[data_path],
        comb_data_path=comb_data_path,
        prover_model=model_cls,
        prover_model_path=prover_model_path,
        prover_compiled_model_path=prover_compiled_model_path,
        scale="default",
        mode="resources",
        settings_path=settings_path,
        srs_path=srs_path,
        vk_path=vk_path,
        pk_path=pk_path,
    )

    # Compile the circuit, build the witness and write the proof file.
    prover_gen_proof(
        prover_model_path=prover_model_path,
        comb_data_path=comb_data_path,
        witness_path=witness_path,
        prover_compiled_model_path=prover_compiled_model_path,
        settings_path=settings_path,
        proof_path=proof_path,
        pk_path=pk_path,
        srs_path=srs_path,
    )
# ``verify`` sub-command: currently only prints a greeting.
# NOTE: kept as a comment because a docstring would become the click help text.
@click.command()
def verify():
    greeting = "Hello, verify!"
    click.echo(greeting)
def main() -> None:
    """Entry point: invoke the ``cli`` command group."""
    cli()
# Register commands
# Attach the sub-commands to the ``cli`` group so they can be invoked by name.
cli.add_command(prove)
cli.add_command(verify)

# Only start the CLI when this file is executed as the main module.
if __name__ == "__main__":
    main()

View File

@@ -1,12 +1,43 @@
import sys
import importlib.util
from typing import Type
import torch
from torch import Tensor
import ezkl
import os
import numpy as np
import json
import time
def load_model(module_path: str) -> Type[torch.nn.Module]:
    """
    Load the class named ``Model`` from the Python source file at ``module_path``.

    The file is executed as a module named after its base name (without the
    extension) and registered in ``sys.modules`` under that name.

    :param module_path: path to a Python file that defines a class ``Model``
    :return: the ``Model`` class object itself (not an instance)
    :raises ImportError: if the file cannot be loaded as a Python module, or
        if the loaded module has no attribute ``Model``
    """
    # FIXME: This is unsafe since malicious code can be executed
    model_name = "Model"
    module_name = os.path.splitext(os.path.basename(module_path))[0]
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    # No spec (or no loader) means importlib does not know how to load this
    # file, e.g. because of an unsupported extension.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load a Python module from {module_path}")
    module = importlib.util.module_from_spec(spec)

    # Register before executing, then roll back on failure so a
    # half-initialised module never stays in ``sys.modules``.
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise

    try:
        return getattr(module, model_name)
    except AttributeError as err:
        raise ImportError(
            f"class {model_name} does not exist in {module_name}"
        ) from err
# Export model
def export_onnx(model, data_tensor_array, model_loc):
def export_onnx(model: torch.nn.Module, data_tensor_array, model_loc):
circuit = model()
# Try running `prepare()` if it exists
try:
circuit.prepare(data_tensor_array)
except AttributeError:
pass
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -16,6 +47,7 @@ def export_onnx(model, data_tensor_array, model_loc):
# Flips the neural net into inference mode
circuit.eval()
print("!@# circuit.eval=", circuit)
input_names = []
dynamic_axes = {}
@@ -37,9 +69,9 @@ def export_onnx(model, data_tensor_array, model_loc):
input_names = input_names, # the model's input names
output_names = ['output'], # the model's output names
dynamic_axes=dynamic_axes)
# ===================================================================================================
# ===================================================================================================
# ===================================================================================================
# ===================================================================================================
# mode is either "accuracy" or "resources"
def gen_settings(comb_data_path, onnx_filename, scale, mode, settings_filename):
@@ -67,8 +99,8 @@ def gen_settings(comb_data_path, onnx_filename, scale, mode, settings_filename):
print("scale: ", scale)
print("setting: ", f_setting.read())
# ===================================================================================================
# ===================================================================================================
# ===================================================================================================
# ===================================================================================================
def verifier_define_calculation(verifier_model, verifier_model_path, dummy_data_path_array):
# load data from dummy_data_path_array into dummy_data_tensor_array
@@ -77,31 +109,42 @@ def verifier_define_calculation(verifier_model, verifier_model_path, dummy_data_
dummy_data = np.array(json.loads(open(path, "r").read())["input_data"][0])
dummy_data_tensor_array.append(torch.reshape(torch.tensor(dummy_data), (1, len(dummy_data),1 )))
# export onnx file
export_onnx(verifier_model,dummy_data_tensor_array, verifier_model_path)
export_onnx(verifier_model, dummy_data_tensor_array, verifier_model_path)
# ===================================================================================================
# ===================================================================================================
# ===================================================================================================
# ===================================================================================================
# we decide to not have comb_data_path as parameter since a bit redundant parameter.
def prover_gen_settings(data_path_array, comb_data_path, prover_model,prover_model_path, scale, mode, settings_path):
def process_data(data_path_array, comb_data_path) -> list[Tensor]:
    """
    Load every data file, convert it to a tensor and write the combined data.

    Each file in ``data_path_array`` is read as JSON; the first entry of its
    ``"input_data"`` list is taken as that file's data. The data of all files
    is then written to ``comb_data_path`` as ``{"input_data": [...]}``, one
    entry per input file.

    :param data_path_array: iterable of paths to JSON data files
    :param comb_data_path: path of the combined JSON file to write
    :return: one tensor per input file, each reshaped to ``(1, n, 1)`` where
        ``n`` is the length of that file's data
    """
    # Load data from data_path_array into data_tensor_array
    data_tensor_array = []
    comb_data = []
    for path in data_path_array:
        # Context manager so the handle is closed even if parsing fails.
        with open(path, "r") as data_file:
            data = np.array(json.load(data_file)["input_data"][0])
        print("!@# data=", data)
        data_tensor = torch.tensor(data)
        data_tensor_array.append(torch.reshape(data_tensor, (1, len(data), 1)))
        comb_data.append(data.tolist())

    # Serialize data into file:
    # comb_data comes from `data`
    # Closing the file explicitly guarantees the JSON is flushed to disk
    # before any later step reads comb_data_path.
    with open(comb_data_path, "w") as comb_file:
        json.dump(dict(input_data=comb_data), comb_file)
    return data_tensor_array
# we decide to not have comb_data_path as parameter since a bit redundant parameter.
def prover_gen_settings(
    data_path_array,
    comb_data_path,
    prover_model,
    prover_model_path,
    scale,
    mode,
    settings_path,
):
    """
    Prepare the prover's data, export its model to ONNX and generate settings.

    :param data_path_array: paths of the prover's JSON data files
    :param comb_data_path: where the combined data JSON is written
    :param prover_model: model class handed to ``export_onnx``
    :param prover_model_path: destination of the exported ONNX file
    :param scale: scale argument forwarded to ``gen_settings``
    :param mode: mode argument forwarded to ``gen_settings``
        ("accuracy" or "resources")
    :param settings_path: destination of the generated settings file
    """
    # Step 1: read the data files and write the combined data file.
    tensors = process_data(data_path_array, comb_data_path)

    # Step 2: export onnx file
    export_onnx(prover_model, tensors, prover_model_path)

    # Step 3: gen + calibrate setting
    gen_settings(comb_data_path, prover_model_path, scale, mode, settings_path)
# ===================================================================================================
# ===================================================================================================
# ===================================================================================================
# ===================================================================================================
# Here prover can concurrently call this since all params are public to get pk.
# Here prover can concurrently call this since all params are public to get pk.
# Written as a verifier function to emphasize that the verifier must compute its own vk rather than trust one supplied by the prover.
def verifier_setup(verifier_model_path, verifier_compiled_model_path, settings_path, srs_path,vk_path, pk_path ):
# compile circuit
@@ -111,7 +154,7 @@ def verifier_setup(verifier_model_path, verifier_compiled_model_path, settings_p
# srs path
res = ezkl.get_srs(srs_path, settings_path)
# setupt vk, pk param for use..... prover can use same pk or can init their own!
# setup vk, pk param for use..... prover can use same pk or can init their own!
print("==== setting up ezkl ====")
start_time = time.time()
res = ezkl.setup(
@@ -128,43 +171,77 @@ def verifier_setup(verifier_model_path, verifier_compiled_model_path, settings_p
assert os.path.isfile(pk_path)
assert os.path.isfile(settings_path)
# ===================================================================================================
# ===================================================================================================
# ===================================================================================================
# ===================================================================================================
def prover_gen_proof(prover_model_path, comb_data_path, witness_path, prover_compiled_model_path, settings_path, proof_path, pk_path, srs_path):
res = ezkl.compile_circuit(prover_model_path, prover_compiled_model_path, settings_path)
assert res == True
# now generate the witness file
print('==== Generating Witness ====')
witness = ezkl.gen_witness(comb_data_path, prover_compiled_model_path, witness_path)
assert os.path.isfile(witness_path)
# print(witness["outputs"])
settings = json.load(open(settings_path))
output_scale = settings['model_output_scales']
print("witness boolean: ", ezkl.vecu64_to_float(witness['outputs'][0][0], output_scale[0]))
for i in range(len(witness['outputs'][1])):
print("witness result", i+1,":", ezkl.vecu64_to_float(witness['outputs'][1][i], output_scale[1]))
# we decide to not have comb_data_path as parameter since a bit redundant parameter.
def prover_setup(
data_path_array,
comb_data_path,
prover_model,
prover_model_path,
prover_compiled_model_path,
scale,
mode,
settings_path,
srs_path,
vk_path,
pk_path,
):
data_tensor_array = process_data(data_path_array, comb_data_path)
# GENERATE A PROOF
print("==== Generating Proof ====")
start_time = time.time()
res = ezkl.prove(
witness_path,
prover_compiled_model_path,
pk_path,
proof_path,
srs_path,
"single",
)
# export onnx file
export_onnx(prover_model, data_tensor_array, prover_model_path)
# gen + calibrate setting
gen_settings(comb_data_path, prover_model_path, scale, mode, settings_path)
verifier_setup(prover_model_path, prover_compiled_model_path, settings_path, srs_path, vk_path, pk_path)
print("proof: " ,res)
end_time = time.time()
time_gen_prf = end_time -start_time
print(f"Time gen prf: {time_gen_prf} seconds")
assert os.path.isfile(proof_path)
# ===================================================================================================
# ===================================================================================================
def prover_gen_proof(
    prover_model_path,
    comb_data_path,
    witness_path,
    prover_compiled_model_path,
    settings_path,
    proof_path,
    pk_path,
    srs_path,
):
    """
    Compile the prover's circuit, generate the witness and produce a proof.

    :param prover_model_path: path of the prover's ONNX model
    :param comb_data_path: path of the combined input data JSON
    :param witness_path: where the witness file is written
    :param prover_compiled_model_path: where the compiled circuit is written
    :param settings_path: path of the settings JSON (its
        ``model_output_scales`` entry is read to decode the witness outputs)
    :param proof_path: where the proof file is written
    :param pk_path: path of the proving key
    :param srs_path: path of the SRS file
    :raises AssertionError: if compilation does not report success, or the
        witness or proof file does not exist afterwards
    """
    print("!@# compiled_model exists?", os.path.isfile(prover_compiled_model_path))
    res = ezkl.compile_circuit(prover_model_path, prover_compiled_model_path, settings_path)
    print("!@# compiled_model exists?", os.path.isfile(prover_compiled_model_path))
    assert res == True

    # now generate the witness file
    print('==== Generating Witness ====')
    witness = ezkl.gen_witness(comb_data_path, prover_compiled_model_path, witness_path)
    assert os.path.isfile(witness_path)

    # Read the output scales; the context manager closes the settings file.
    with open(settings_path) as settings_file:
        settings = json.load(settings_file)
    output_scale = settings['model_output_scales']

    # Output 0 is printed as the boolean check and the entries of output 1 as
    # the results, each decoded with the matching output scale.
    outputs = witness['outputs']
    print("witness boolean: ", ezkl.vecu64_to_float(outputs[0][0], output_scale[0]))
    for position, encoded in enumerate(outputs[1], start=1):
        print("witness result", position,":", ezkl.vecu64_to_float(encoded, output_scale[1]))

    # GENERATE A PROOF
    print("==== Generating Proof ====")
    start_time = time.time()
    res = ezkl.prove(
        witness_path,
        prover_compiled_model_path,
        pk_path,
        proof_path,
        srs_path,
        "single",
    )

    print("proof: " ,res)
    end_time = time.time()
    time_gen_prf = end_time -start_time
    print(f"Time gen prf: {time_gen_prf} seconds")
    assert os.path.isfile(proof_path)
# ===================================================================================================
# ===================================================================================================
def verifier_verify(proof_path, settings_path, vk_path, srs_path):
# enforce boolean statement to be true

46
zkstats/models.py Normal file
View File

@@ -0,0 +1,46 @@
from typing import Any
from abc import ABC, abstractmethod
from torch import nn
import torch
class BaseZKStatsModel(ABC, nn.Module):
    """
    Abstract base class of the zkstats models.

    It is both an ``ABC`` and a ``torch.nn.Module``; subclasses must
    implement ``forward``.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def forward(self, X: Any) -> Any:
        """
        :param X: a tensor of shape (1, n, 1)
        :return: a tuple of (bool, float)
        """
class NoDivisionModel(BaseZKStatsModel):
    """
    Abstract model with a ``prepare`` step that runs before ``forward``.

    Subclasses implement both ``prepare`` and ``forward``.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def prepare(self, X: Any):
        """
        Compute and store whatever ``forward`` needs from the data.

        :param X: the input data the model is prepared with
        """
        ...

    @abstractmethod
    def forward(self, X: Any) -> tuple[float, float]:
        # some expression of tolerance to error in the inference
        # must have w first!
        ...
class MeanModel(NoDivisionModel):
    """
    Model that checks a claimed mean ``w`` against the input without dividing.

    ``prepare`` stores the mean of the first data tensor as ``self.w``;
    ``forward`` then checks that ``sum(X)`` is within 1% of ``n * w``.
    """

    def __init__(self):
        super().__init__()

    def prepare(self, X: Any):
        """
        Store the mean of ``X[0]`` as the non-trainable parameter ``self.w``.

        :param X: indexable collection whose first element is the data tensor
        """
        expected_output = torch.mean(X[0])
        # w represents mean in this case
        self.w = nn.Parameter(data = expected_output, requires_grad = False)

    def forward(self, X: Any) -> tuple[float, float]:
        """
        :param X: input tensor; its size along dimension 1 is used as ``n``
        :return: ``(within_tolerance, w)`` where ``within_tolerance`` is the
            result of ``|sum(X) - n*w| < 0.01 * n * |w|``
        """
        # some expression of tolerance to error in the inference
        # must have w first!
        n = X.size()[1]
        deviation = torch.abs(torch.sum(X) - n * (self.w))
        # The tolerance uses |w|: with a negative mean the unsigned form
        # 0.01*n*w would be negative and the check could never succeed.
        tolerance = 0.01 * n * torch.abs(self.w)
        return (deviation < tolerance, self.w)