Merge pull request #55 from ZKStats/feat/access-data-with-key

use key to access column in computation
This commit is contained in:
Kevin Mai-Husan Chia
2024-07-04 23:43:40 +08:00
committed by GitHub
33 changed files with 138 additions and 106 deletions

View File

@@ -50,11 +50,11 @@ def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
For example, we have two columns of data and we want to compute the mean of the medians of the two columns:
```python
def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
def user_computation(s: State, data: Args) -> torch.Tensor:
# Compute the median of the first column
median1 = s.median(data[0])
median1 = s.median(data['column1'])
# Compute the median of the second column
median2 = s.median(data[1])
median2 = s.median(data['column2'])
# Compute the mean of the medians
return s.mean(torch.cat((median1.unsqueeze(0), median2.unsqueeze(0))).reshape(1,-1,1))
```
@@ -74,9 +74,9 @@ TODO: We should have a list for all supported PyTorch functions.
Although we cannot filter data into any arbitrary shape using just condition + index (e.g. `X[X > 0]`), we implemented State.where operation that allows users to filter data by their own choice of condition as follows.
```python
def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
def user_computation(s: State, data: Args) -> torch.Tensor:
# Compute the mean of the absolute values
x = data[0]
x = data['x']
# Here condition can be chained as shown below, and can have many variables if we have more than just x: e.g. filter = torch.logical_and(x>20, y<2) in case of regression for example.
filter = torch.logical_and(x > 20, x<50)
# call our where function
@@ -116,9 +116,9 @@ Note here, that we can also just let prover generate model, and then send that m
```python
from zkstats.core import computation_to_model
# For prover: generate prover_model, and write to precal_witness file
_, prover_model = computation_to_model(user_computation, precal_witness_path, True, error)
_, prover_model = computation_to_model(user_computation, precal_witness_path, True, selected_columns, error)
# For verifier, generate verifier model (which is same as prover_model) by reading precal_witness file
_, verifier_model = computation_to_model(user_computation, precal_witness_path, False, error)
_, verifier_model = computation_to_model(user_computation, precal_witness_path, False, selected_columns, error)
```
#### Data Provider: generate settings

View File

@@ -181,7 +181,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -176,7 +176,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -194,7 +194,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -157,7 +157,7 @@
"\n",
"error = 0.01\n",
"# Prover/ data owner side\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n"
]
},
@@ -167,7 +167,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -153,7 +153,7 @@
"\n",
"error = 0.01\n",
"# Prover/ data owner side\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n"
]
},
@@ -163,7 +163,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -154,7 +154,7 @@
" return s.geometric_mean(x)\n",
"\n",
"error = 0.01\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)"
]
},
@@ -164,7 +164,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -150,7 +150,7 @@
"error = 0.1\n",
"\n",
"# Prover/ data owner side\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n",
"\n"
]
@@ -161,7 +161,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -166,7 +166,7 @@
"\n",
"\n",
"# Prover/ data owner side\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n",
"\n"
]
@@ -186,7 +186,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -154,7 +154,7 @@
"# error here in Median only matters in determining the median value in case it doesnt exist in dataset. (Avg of 2 middle values)\n",
"error = 0.01\n",
"# Prover/ data owner side\n",
"_, prover_model = computation_to_model(computation,precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model,prover_model_path, scales, \"resources\", settings_path)"
]
@@ -165,7 +165,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path, verifier_model, verifier_model_path)"
]

View File

@@ -157,7 +157,7 @@
"error = 0\n",
"\n",
"# Prover/ data owner side\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n",
"\n"
]
@@ -168,7 +168,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -135,7 +135,7 @@
"error = 0.01\n",
"\n",
"# Prover/ data owner side\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n",
"\n"
]
@@ -146,7 +146,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -135,7 +135,7 @@
"error = 0.01\n",
"\n",
"# Prover/ data owner side\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n",
"\n"
]
@@ -147,7 +147,7 @@
"outputs": [],
"source": [
"# Prover/ data owner side\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -142,7 +142,7 @@
" return s.linear_regression(x, y)\n",
"\n",
"error = 0.05\n",
"_, prover_model = computation_to_model(computation,precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model,prover_model_path, scales, \"resources\", settings_path)"
]
@@ -153,7 +153,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path, verifier_model, verifier_model_path)"
]

View File

@@ -134,7 +134,7 @@
"\n",
"error = 0.01\n",
"# Prover/ data owner side\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n",
"\n"
]
@@ -145,7 +145,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -134,7 +134,7 @@
"\n",
"error = 0.01\n",
"# Prover/ data owner side\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n",
"\n"
]
@@ -145,7 +145,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -167,7 +167,7 @@
" return s.correlation(filtered_x, filtered_y)\n",
"\n",
"error = 0.01\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n"
]
},
@@ -177,7 +177,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -159,7 +159,7 @@
"\n",
"error = 0.01\n",
"\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n",
"\n"
]
@@ -170,7 +170,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -159,7 +159,7 @@
"\n",
"error = 0.01\n",
"\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n"
]
},
@@ -169,7 +169,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -153,7 +153,7 @@
"\n",
"error = 0.1\n",
"\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n"
]
},
@@ -163,7 +163,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -145,7 +145,7 @@
"\n",
"error = 0.01\n",
"\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n",
"\n"
]
@@ -156,7 +156,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -154,7 +154,7 @@
" return s.median(filtered_x)\n",
"\n",
"error = 0.01\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n",
"\n"
]
@@ -165,7 +165,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -183,7 +183,7 @@
" return s.mode(filtered_x)\n",
"\n",
"error = 0.01\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n",
"\n"
]
@@ -194,7 +194,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -149,7 +149,7 @@
" return s.pstdev(filtered_x)\n",
"\n",
"error = 0.01\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n",
"\n"
]
@@ -160,7 +160,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns,error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -149,7 +149,7 @@
" return s.pvariance(filtered_x)\n",
"\n",
"error = 0.01\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n"
]
},
@@ -159,7 +159,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns,error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -150,7 +150,7 @@
" y = data[1]\n",
"\n",
" filter = (y < 20)\n",
" \n",
"\n",
" filtered_x = s.where(filter, x)\n",
" filtered_y = s.where(filter, y)\n",
" return s.linear_regression(filtered_x,filtered_y)\n",
@@ -158,7 +158,7 @@
"\n",
"\n",
"error = 0.05\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation, precal_witness_path, True, selected_columns, error)\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)\n"
]
},
@@ -168,7 +168,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False,error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns,error)\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},

View File

@@ -150,7 +150,7 @@
" return s.stdev(filtered_x)\n",
"\n",
"error = 0.05\n",
"_, prover_model = computation_to_model(computation,precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation,precal_witness_path, True, selected_columns, error)\n",
"\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model,prover_model_path, scales, \"resources\", settings_path)"
]
@@ -161,7 +161,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path, verifier_model, verifier_model_path)"
]

View File

@@ -149,7 +149,7 @@
" return s.variance(filtered_x)\n",
"\n",
"error = 0.01\n",
"_, prover_model = computation_to_model(computation,precal_witness_path, True, error)\n",
"_, prover_model = computation_to_model(computation,precal_witness_path, True, selected_columns, error)\n",
"\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model,prover_model_path, scales, \"resources\", settings_path)"
]
@@ -160,7 +160,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, error)\n",
"_, verifier_model = computation_to_model(computation, precal_witness_path, False, selected_columns, error)\n",
"\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path, verifier_model, verifier_model_path)"
]

View File

@@ -4,8 +4,8 @@ from pathlib import Path
import torch
from zkstats.core import create_dummy,prover_gen_settings, setup, prover_gen_proof, verifier_verify, generate_data_commitment, verifier_define_calculation
from zkstats.computation import IModel, State, computation_to_model
from zkstats.core import prover_gen_settings, setup, prover_gen_proof, verifier_verify, generate_data_commitment, verifier_define_calculation
from zkstats.computation import computation_to_model, TComputation, State, IModel
DEFAULT_POSSIBLE_SCALES = list(range(20))
@@ -27,16 +27,15 @@ def data_to_json_file(data_path: Path, data: list[torch.Tensor]) -> dict[str, li
json.dump(column_to_data, f)
return column_to_data
TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]
def compute(
def compute_model(
basepath: Path,
data: list[torch.Tensor],
model: Type[IModel],
# computation: TComputation,
model: IModel,
scales_params: Optional[Sequence[int]] = None,
selected_columns_params: Optional[list[str]] = None,
# error:float = 1.0
) -> None:
):
sel_data_path = basepath / "comb_data.json"
model_path = basepath / "model.onnx"
settings_path = basepath / "settings.json"
@@ -65,12 +64,12 @@ def compute(
scales_for_commitments = scales_params
# create_dummy((data_path), (dummy_data_path))
generate_data_commitment((data_path), scales_for_commitments, (data_commitment_path))
# _, prover_model = computation_to_model(computation, (precal_witness_path), True, error)
# _, prover_model = computation_to_model(computation, (precal_witness_path), True, selected_columns, error)
prover_gen_settings((data_path), selected_columns, (sel_data_path), model, (model_path), scales, "resources", (settings_path))
# No need, since verifier & prover share the same onnx
# _, verifier_model = computation_to_model(computation, (precal_witness_path), False,error)
# _, verifier_model = computation_to_model(computation, (precal_witness_path), False, selected_columns, error)
# verifier_define_calculation((dummy_data_path), selected_columns, (sel_dummy_data_path),verifier_model, (verifier_model_path))
setup((model_path), (compiled_model_path), (settings_path),(vk_path), (pk_path ))
@@ -80,6 +79,29 @@ def compute(
verifier_verify((proof_path), (settings_path), (vk_path), selected_columns, (data_commitment_path))
def compute(
    basepath: Path,
    data: list[torch.Tensor],
    computation: TComputation,
    scales_params: Optional[Sequence[int]] = None,
    selected_columns_params: Optional[list[str]] = None,
) -> State:
    """Trace `computation` into a model and run the full prove/verify pipeline.

    Writes `data` to a JSON file under `basepath`, derives the prover model from
    `computation`, then delegates settings generation, setup, proof generation,
    and verification to `compute_model`.

    :param basepath: directory that receives all intermediate artifact files
    :param data: one tensor per data column
    :param computation: user computation taking (State, Args) and returning a tensor
    :param scales_params: optional fixed-point scales; defaults handled downstream
    :param selected_columns_params: columns to expose to the computation;
        None selects every column found in the data file
    :return: the State populated while tracing `computation`
    """
    data_path = basepath / "data.json"
    precal_witness_path = basepath / "precal_witness_path.json"
    column_to_data = data_to_json_file(data_path, data)
    # If selected_columns_params is None, select all columns
    if selected_columns_params is None:
        selected_columns = list(column_to_data.keys())
    else:
        selected_columns = selected_columns_params
    # BUG FIX: the fifth positional parameter of `computation_to_model` is
    # `error` (a float tolerance), but `selected_columns` was mistakenly passed
    # twice — sending a list[str] where a float is expected. Omit it so the
    # default error tolerance applies.
    state, model = computation_to_model(computation, precal_witness_path, True, selected_columns)
    compute_model(basepath, data, model, scales_params, selected_columns_params)
    return state
# Error tolerance between zkstats python implementation and python statistics module
ERROR_ZKSTATS_STATISTICS = 0.0001

View File

@@ -4,7 +4,7 @@ import torch
import pytest
from zkstats.computation import State, computation_to_model
from zkstats.computation import State, Args, computation_to_model
from zkstats.ops import (
Mean,
Median,
@@ -24,10 +24,10 @@ from zkstats.ops import (
from .helpers import assert_result, compute, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED
def nested_computation(state: State, args: list[torch.Tensor]):
x = args[0]
y = args[1]
z = args[2]
def nested_computation(state: State, args: Args):
x = args['columns_0']
y = args['columns_1']
z = args['columns_2']
out_0 = state.median(x)
out_1 = state.geometric_mean(y)
out_2 = state.harmonic_mean(x)
@@ -63,12 +63,8 @@ def nested_computation(state: State, args: list[torch.Tensor]):
[ERROR_CIRCUIT_DEFAULT],
)
def test_nested_computation(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, column_2: torch.Tensor, error, scales):
precal_witness_path = tmp_path / "precal_witness_path.json"
state, model = computation_to_model(nested_computation, precal_witness_path,True, error)
x, y, z = column_0, column_1, column_2
compute(tmp_path, [x, y, z], model, scales)
# There are 11 ops in the computation
state = compute(tmp_path, [x, y, z], nested_computation, scales)
assert state.current_op_index == 12
ops = state.ops
@@ -156,12 +152,10 @@ def test_computation_with_where_1d(tmp_path, error, column_0, op_type: Callable[
def condition(_x: torch.Tensor):
return _x < 4
def where_and_op(state: State, args: list[torch.Tensor]):
x = args[0]
def where_and_op(state: State, args: Args):
x = args['columns_0']
return op_type(state, state.where(condition(x), x))
precal_witness_path = tmp_path / "precal_witness_path.json"
state, model = computation_to_model(where_and_op, precal_witness_path,True, error)
compute(tmp_path, [column], model, scales)
state = compute(tmp_path, [column], where_and_op, scales)
res_op = state.ops[-1]
filtered = column[condition(column)]
@@ -180,16 +174,14 @@ def test_computation_with_where_2d(tmp_path, error, column_0, column_1, op_type:
def condition_0(_x: torch.Tensor):
return _x > 4
def where_and_op(state: State, args: list[torch.Tensor]):
x = args[0]
y = args[1]
def where_and_op(state: State, args: Args):
x = args['columns_0']
y = args['columns_1']
condition_x = condition_0(x)
filtered_x = state.where(condition_x, x)
filtered_y = state.where(condition_x, y)
return op_type(state, filtered_x, filtered_y)
precal_witness_path = tmp_path / "precal_witness_path.json"
state, model = computation_to_model(where_and_op, precal_witness_path, True ,error)
compute(tmp_path, [column_0, column_1], model, scales)
state = compute(tmp_path, [column_0, column_1], where_and_op, scales)
res_op = state.ops[-1]
condition_x = condition_0(column_0)

View File

@@ -69,12 +69,11 @@ def test_integration_select_partial_columns(tmp_path, column_0, column_1, error,
# Select only the first column from two columns
selected_columns = [columns[0]]
def simple_computation(state, x):
return state.mean(x[0])
precal_witness_path = tmp_path / "precal_witness_path.json"
_, model = computation_to_model(simple_computation,precal_witness_path, True, error)
def simple_computation(state, args):
x = args['columns_0']
return state.mean(x)
# gen settings, setup, prove, verify
compute(tmp_path, [column_0, column_1], model, scales, selected_columns)
compute(tmp_path, [column_0, column_1], simple_computation, scales, selected_columns)
def test_csv_data(tmp_path, column_0, column_1, error, scales):
@@ -85,8 +84,9 @@ def test_csv_data(tmp_path, column_0, column_1, error, scales):
selected_columns = list(data_json.keys())
def simple_computation(state, x):
return state.mean(x[0])
def simple_computation(state, args):
x = args['columns_0']
return state.mean(x)
sel_data_path = tmp_path / "comb_data.json"
model_path = tmp_path / "model.onnx"
@@ -98,7 +98,7 @@ def test_csv_data(tmp_path, column_0, column_1, error, scales):
generate_data_commitment(data_csv_path, scales, data_commitment_path)
# Test: `prover_gen_settings` works with csv
_, model_for_proving = computation_to_model(simple_computation, precal_witness_path, True,error)
_, model_for_proving = computation_to_model(simple_computation, precal_witness_path, True, selected_columns, error)
prover_gen_settings(
data_path=data_csv_path,
selected_columns=selected_columns,
@@ -112,7 +112,7 @@ def test_csv_data(tmp_path, column_0, column_1, error, scales):
# Test: `prover_gen_settings` works with csv
# Instantiate the model for verification since the state of `model_for_proving` is changed after `prover_gen_settings`
_, model_for_verification = computation_to_model(simple_computation, precal_witness_path, False,error)
_, model_for_verification = computation_to_model(simple_computation, precal_witness_path, False, selected_columns, error)
verifier_define_calculation(data_csv_path, selected_columns, str(sel_data_path), model_for_verification, str(model_path))
def json_file_to_csv(data_json_path, data_csv_path):

View File

@@ -7,7 +7,7 @@ import torch
from zkstats.ops import Mean, Median, GeometricMean, HarmonicMean, Mode, PStdev, PVariance, Stdev, Variance, Covariance, Correlation, Operation, Regression
from zkstats.computation import IModel, IsResultPrecise, State, computation_to_model
from .helpers import compute, assert_result, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED
from .helpers import compute_model, assert_result, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED
@pytest.mark.parametrize(
@@ -58,7 +58,7 @@ def test_linear_regression(tmp_path, column_0: torch.Tensor, column_1: torch.Ten
class Model(IModel):
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
return regression.ezkl(x), regression.result
compute(tmp_path, columns, Model, scales)
compute_model(tmp_path, columns, Model, scales)
@@ -70,4 +70,4 @@ def run_test_ops(tmp_path, op_type: Type[Operation], expected_func: Callable[[li
class Model(IModel):
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
return op.ezkl(x), op.result
compute(tmp_path, columns, Model, scales)
compute_model(tmp_path, columns, Model, scales)

View File

@@ -270,8 +270,24 @@ class IModel(nn.Module):
# return state.mean(torch.tensor([out_0, out_1]).reshape(1,-1,1))
TComputation = Callable[[State, list[torch.Tensor]], torch.Tensor]
class Args:
def __init__(
self,
columns: list[str],
data: list[torch.Tensor],
):
if len(columns) != len(data):
raise ValueError("columns and data must have the same length")
self.data_dict = {
column_name: d
for column_name, d in zip(columns, data)
}
def computation_to_model(computation: TComputation, precal_witness_path:str, isProver:bool ,error: float = DEFAULT_ERROR ) -> tuple[State, Type[IModel]]:
def __getitem__(self, key: str) -> torch.Tensor:
return self.data_dict[key]
def computation_to_model(computation: TComputation, precal_witness_path: str, isProver:bool, selected_columns: list[str], error: float = DEFAULT_ERROR ) -> tuple[State, Type[IModel]]:
"""
Create a torch model from a `computation` function defined by user
:param computation: A function that takes a State and a list of torch.Tensor, and returns a torch.Tensor
@@ -281,7 +297,7 @@ def computation_to_model(computation: TComputation, precal_witness_path:str, isP
"""
state = State(error)
state.precal_witness_path= precal_witness_path
state.precal_witness_path = precal_witness_path
state.isProver = isProver
class Model(IModel):
@@ -291,14 +307,16 @@ def computation_to_model(computation: TComputation, precal_witness_path:str, isP
"""
# In the preprocess step, the operations are calculated and the results are stored in the state.
# So we don't need to get the returned result
computation(state, x)
args = Args(selected_columns, x)
computation(state, args)
state.set_ready_for_exporting_onnx()
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
"""
Called by torch.onnx.export.
"""
result = computation(state, x)
args = Args(selected_columns, x)
result = computation(state, args)
is_computation_result_accurate = state.bools[0]()
for op_precise_check in state.bools[1:]:
is_op_result_accurate = op_precise_check()