mirror of
https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-09 13:38:02 -05:00
19
README.md
19
README.md
@@ -65,26 +65,26 @@ def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
|
||||
|
||||
Aside from the ZKStats operations, you can also use PyTorch functions like (`torch.abs`, `torch.max`, ...etc).
|
||||
|
||||
**Caveats**: Not all PyTorch functions are supported. For example, filtering data from a list by `X[X > 0]` is not supported because the zk circuit needs to be of a predetermined size, hence we cannot arbitrarily reshape our X into a new shape based on the filter condition inside the circuit.
|
||||
|
||||
TODO: We should have a list for all supported PyTorch functions.
|
||||
|
||||
**Caveats**: Not all PyTorch functions are supported. For example, filtering data from a list by `X[X > 0]` is not supported because the zk circuit needs to be of a predetermined size, hence we cannot arbitrarily reshape our X into a new shape based on the filter condition inside the circuit. To filter data based on condition, we can use s.where as follows.
|
||||
|
||||
#### Data Filtering
|
||||
|
||||
Since we cannot filter data into any arbitrary shape using just condition + index (e.g. `X[X > 0]`), we need to filter data while still preserving the shape. We use condition + `torch.where` instead.
|
||||
Although we cannot filter data into any arbitrary shape using just condition + index (e.g. `X[X > 0]`), we implemented State.where operation that allows users to filter data by their own choice of condition as follows.
|
||||
|
||||
```python
|
||||
def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
|
||||
# Compute the mean of the absolute values
|
||||
x = data[0]
|
||||
condition = x > 20
|
||||
# Filter out data that is greater than 20. For the data that is greater than 20, we will use 0.0
|
||||
fil_X = torch.where(condition=condition, input=x, other=0.0)
|
||||
return s.mean(fil_X)
|
||||
# Here condition can be chained as shown below, and can have many variables if we have more than just x: e.g. filter = torch.logical_and(x>20, y<2) in case of regression for example.
|
||||
filter = torch.logical_and(x > 20, x<50)
|
||||
# call our where function
|
||||
filtered_x = s.where(filter, x)
|
||||
# Then, can use the stats operation as usual
|
||||
return s.mean(filtered_x)
|
||||
```
|
||||
|
||||
**Caveats**: Currently, this 'where' operation still doesn't work correctly, since we cannot just plug fil_X into our current s.mean() due to incompatible shape of fil_X and X in reality, we will update the compatible implementation of how to do data filtering soon. Keep posted!
|
||||
|
||||
### Proof Generation and Verification
|
||||
|
||||
The flow between data providers and users is as follows:
|
||||
@@ -204,6 +204,7 @@ See our jupyter notebook for [examples](./examples/).
|
||||
## Benchmarks
|
||||
|
||||
See our jupyter notebook for [benchmarks](./benchmark/).
|
||||
TODO: clean benchmark
|
||||
|
||||
## Note
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"x1": [
|
||||
7.1, 3.2, 8.6, 3.5, 0.1, 9.7, 2.3, 5.7, 2.8, 10.0, 6.0, 6.0, 9.1, 1.7, 9.2,
|
||||
10.0, 7.1, 3.2, 8.6, 3.5, 0.1, 9.7, 2.3, 5.7, 2.8, 6.0, 6.0, 9.1, 1.7, 9.2,
|
||||
0.2, 7.8, 3.7, 7.0, 2.5, 2.8, 5.9, 7.3, 2.9, 2.9, 3.5, 1.0, 9.7, 4.8, 0.9,
|
||||
7.1, 3.6, 8.2, 3.0, 7.6, 4.2, 5.2, 8.1, 6.3, 9.3, 8.8, 8.2, 6.7, 4.9, 5.4,
|
||||
9.8, 5.9, 7.1, 3.9, 9.3
|
||||
@@ -12,7 +12,7 @@
|
||||
1.5, 2.1, 0.4, 4.3, 0.2
|
||||
],
|
||||
"y": [
|
||||
18.5, 5.5, 18.2, 9.0, 4.0, 19.5, 11.7, 17.9, 15.3, 20.8, 12.5, 21.5, 32.5,
|
||||
20.8, 18.5, 5.5, 18.2, 9.0, 4.0, 19.5, 11.7, 17.9, 15.3, 12.5, 21.5, 32.5,
|
||||
18.6, 23.9, 7.0, 16.9, 22.9, 31.0, 15.0, 8.5, 8.7, 28.9, 19.7, 12.5, 17.4,
|
||||
7.2, 25.5, 21.4, 15.7, 15.5, 8.2, 28.2, 19.5, 25.5, 12.5, 20.3, 21.7, 22.1,
|
||||
19.6, 32.2, 22.4, 20.6, 19.7, 20.8, 21.1, 21.8, 17.7, 21.1, 19.4
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,308 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Requirement already satisfied: ezkl==7.0.0 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 1)) (7.0.0)\n",
|
||||
"Requirement already satisfied: torch in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 2)) (2.2.0)\n",
|
||||
"Requirement already satisfied: requests in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 3)) (2.31.0)\n",
|
||||
"Requirement already satisfied: scipy in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 4)) (1.12.0)\n",
|
||||
"Requirement already satisfied: numpy in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 5)) (1.26.3)\n",
|
||||
"Requirement already satisfied: matplotlib in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 6)) (3.8.2)\n",
|
||||
"Requirement already satisfied: statistics in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 7)) (1.0.3.5)\n",
|
||||
"Requirement already satisfied: onnx in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 8)) (1.15.0)\n",
|
||||
"Requirement already satisfied: filelock in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (3.13.1)\n",
|
||||
"Requirement already satisfied: typing-extensions>=4.8.0 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (4.9.0)\n",
|
||||
"Requirement already satisfied: sympy in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (1.12)\n",
|
||||
"Requirement already satisfied: networkx in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (3.2.1)\n",
|
||||
"Requirement already satisfied: jinja2 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (3.1.3)\n",
|
||||
"Requirement already satisfied: fsspec in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (2023.12.2)\n",
|
||||
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from requests->-r ../../requirements.txt (line 3)) (3.3.2)\n",
|
||||
"Requirement already satisfied: idna<4,>=2.5 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from requests->-r ../../requirements.txt (line 3)) (3.6)\n",
|
||||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from requests->-r ../../requirements.txt (line 3)) (2.2.0)\n",
|
||||
"Requirement already satisfied: certifi>=2017.4.17 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from requests->-r ../../requirements.txt (line 3)) (2024.2.2)\n",
|
||||
"Requirement already satisfied: contourpy>=1.0.1 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (1.2.0)\n",
|
||||
"Requirement already satisfied: cycler>=0.10 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (0.12.1)\n",
|
||||
"Requirement already satisfied: fonttools>=4.22.0 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (4.47.2)\n",
|
||||
"Requirement already satisfied: kiwisolver>=1.3.1 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (1.4.5)\n",
|
||||
"Requirement already satisfied: packaging>=20.0 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (23.2)\n",
|
||||
"Requirement already satisfied: pillow>=8 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (10.2.0)\n",
|
||||
"Requirement already satisfied: pyparsing>=2.3.1 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (3.1.1)\n",
|
||||
"Requirement already satisfied: python-dateutil>=2.7 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (2.8.2)\n",
|
||||
"Requirement already satisfied: docutils>=0.3 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from statistics->-r ../../requirements.txt (line 7)) (0.20.1)\n",
|
||||
"Requirement already satisfied: protobuf>=3.20.2 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from onnx->-r ../../requirements.txt (line 8)) (4.25.2)\n",
|
||||
"Requirement already satisfied: six>=1.5 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from python-dateutil>=2.7->matplotlib->-r ../../requirements.txt (line 6)) (1.16.0)\n",
|
||||
"Requirement already satisfied: MarkupSafe>=2.0 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from jinja2->torch->-r ../../requirements.txt (line 2)) (2.1.4)\n",
|
||||
"Requirement already satisfied: mpmath>=0.19 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from sympy->torch->-r ../../requirements.txt (line 2)) (1.3.0)\n",
|
||||
"\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
|
||||
"Note: you may need to restart the kernel to use updated packages.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pip install -r ../../requirements.txt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import ezkl\n",
|
||||
"import torch\n",
|
||||
"from torch import nn\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"import time\n",
|
||||
"import scipy\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import statistics\n",
|
||||
"import math"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run -i ../../zkstats/core.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# init path\n",
|
||||
"os.makedirs(os.path.dirname('shared/'), exist_ok=True)\n",
|
||||
"os.makedirs(os.path.dirname('prover/'), exist_ok=True)\n",
|
||||
"verifier_model_path = os.path.join('shared/verifier.onnx')\n",
|
||||
"prover_model_path = os.path.join('prover/prover.onnx')\n",
|
||||
"verifier_compiled_model_path = os.path.join('shared/verifier.compiled')\n",
|
||||
"prover_compiled_model_path = os.path.join('prover/prover.compiled')\n",
|
||||
"pk_path = os.path.join('shared/test.pk')\n",
|
||||
"vk_path = os.path.join('shared/test.vk')\n",
|
||||
"proof_path = os.path.join('shared/test.pf')\n",
|
||||
"settings_path = os.path.join('shared/settings.json')\n",
|
||||
"srs_path = os.path.join('shared/kzg.srs')\n",
|
||||
"witness_path = os.path.join('prover/witness.json')\n",
|
||||
"# this is private to prover since it contains actual data\n",
|
||||
"sel_data_path = os.path.join('prover/sel_data.json')\n",
|
||||
"# this is just dummy random value\n",
|
||||
"sel_dummy_data_path = os.path.join('shared/sel_dummy_data.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"======================= ZK-STATS FLOW ======================="
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_path = os.path.join('data.json')\n",
|
||||
"dummy_data_path = os.path.join('shared/dummy_data.json')\n",
|
||||
"\n",
|
||||
"create_dummy(data_path, dummy_data_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"scales = [2]\n",
|
||||
"selected_columns = ['col_name']\n",
|
||||
"commitment_maps = get_data_commitment_maps(data_path, scales)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/mhchia/projects/work/pse/zk-stats-lib/zkstats/computation.py:166: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
|
||||
" is_precise_aggregated = torch.tensor(1.0)\n",
|
||||
"/Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages/torch/onnx/symbolic_opset9.py:2174: FutureWarning: 'torch.onnx.symbolic_opset9._cast_Bool' is deprecated in version 2.0 and will be removed in the future. Please Avoid using this function and create a Cast node instead.\n",
|
||||
" return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False))\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Verifier/ data consumer side: send desired calculation\n",
|
||||
"from zkstats.computation import computation_to_model, State\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:\n",
|
||||
" x = data[0]\n",
|
||||
" # FIXME: should be replaced by `s.where` when it's available. Now the result may be incorrect\n",
|
||||
" filter = (x < 50)\n",
|
||||
" min_x = torch.min(x)\n",
|
||||
" filtered_x = torch.where(filter, x, min_x - 1)\n",
|
||||
" return s.median(filtered_x)\n",
|
||||
"\n",
|
||||
"error = 0.01\n",
|
||||
"_, verifier_model = computation_to_model(computation, error)\n",
|
||||
"\n",
|
||||
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Theory_output: tensor(40., dtype=torch.float64)\n",
|
||||
"==== Generate & Calibrate Setting ====\n",
|
||||
"scale: [2]\n",
|
||||
"setting: {\"run_args\":{\"tolerance\":{\"val\":0.0,\"scale\":1.0},\"input_scale\":2,\"param_scale\":2,\"scale_rebase_multiplier\":10,\"lookup_range\":[-582,1208],\"logrows\":14,\"num_inner_cols\":2,\"variables\":[[\"batch_size\",1]],\"input_visibility\":{\"Hashed\":{\"hash_is_public\":true,\"outlets\":[]}},\"output_visibility\":\"Public\",\"param_visibility\":\"Private\"},\"num_rows\":14432,\"total_assignments\":15928,\"total_const_size\":2126,\"model_instance_shapes\":[[1],[1]],\"model_output_scales\":[0,2],\"model_input_scales\":[2],\"module_sizes\":{\"kzg\":[],\"poseidon\":[14432,[1]],\"elgamal\":[0,[0]]},\"required_lookups\":[\"Abs\",{\"Div\":{\"denom\":2.0}},\"ReLU\",{\"Floor\":{\"scale\":4.0}},{\"GreaterThan\":{\"a\":0.0}},\"KroneckerDelta\"],\"check_mode\":\"UNSAFE\",\"version\":\"7.0.0\",\"num_blinding_factors\":null}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Prover/ data owner side\n",
|
||||
"_, prover_model = computation_to_model(computation, error)\n",
|
||||
"\n",
|
||||
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"spawning module 0\n",
|
||||
"spawning module 2\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"==== setting up ezkl ====\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"spawning module 0\n",
|
||||
"spawning module 2\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Time setup: 4.408194065093994 seconds\n",
|
||||
"=======================================\n",
|
||||
"Theory output: tensor(40., dtype=torch.float64)\n",
|
||||
"==== Generating Witness ====\n",
|
||||
"witness boolean: 1.0\n",
|
||||
"witness result 1 : 40.0\n",
|
||||
"==== Generating Proof ====\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"spawning module 0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Here verifier & prover can concurrently call setup since all params are public to get pk.\n",
|
||||
"# Here write as verifier function to emphasize that verifier must calculate its own vk to be sure\n",
|
||||
"setup(verifier_model_path, verifier_compiled_model_path, settings_path,vk_path, pk_path )\n",
|
||||
"\n",
|
||||
"print(\"=======================================\")\n",
|
||||
"# Prover generates proof\n",
|
||||
"print(\"Theory output: \", theory_output)\n",
|
||||
"prover_gen_proof(prover_model_path, sel_data_path, witness_path, prover_compiled_model_path, settings_path, proof_path, pk_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"num_inputs: 1\n",
|
||||
"prf instances: [[[1780239215148830498, 13236513277824664467, 10913529727158264423, 131860697733488968], [12436184717236109307, 3962172157175319849, 7381016538464732718, 1011752739694698287], [12341676197686541490, 2627393525778350065, 16625494184434727973, 1478518078215075360]]]\n",
|
||||
"proof boolean: 1.0\n",
|
||||
"proof result 1 : 40.0\n",
|
||||
"verified\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Verifier verifies\n",
|
||||
"res = verifier_verify(proof_path, settings_path, vk_path, selected_columns, commitment_maps)\n",
|
||||
"print(\"Verifier gets result:\", res)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.1"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
13
examples/where/where+correlation/data.json
Normal file
13
examples/where/where+correlation/data.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"x": [
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
|
||||
21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
|
||||
40, 41, 42, 43, 44, 45, 46, 47, 48, 49
|
||||
],
|
||||
"y": [
|
||||
2.0, 5.2, 47.4, 23.6, 24.8, 27.0, 47.2, 50.4, 58.6, 57.8, 60.0, 27.2, 40.4,
|
||||
63.6, 28.8, 19.0, 65.2, 50.4, 63.6, 24.8, 35.0, 41.2, 54.4, 61.6, 57.8,
|
||||
78.0, 63.2, 55.4, 78.6, 72.8, 51.0, 62.2, 42.4, 47.6, 83.8, 62.0, 47.26,
|
||||
90.4, 80.6, 87.8, 82.0, 50.2, 80.4, 86.6, 80.8, 66.0, 95.2, 58.4, 74.6, 95.8
|
||||
]
|
||||
}
|
||||
256
examples/where/where+correlation/where+correlation.ipynb
Normal file
256
examples/where/where+correlation/where+correlation.ipynb
Normal file
File diff suppressed because one or more lines are too long
23
examples/where/where+covariance/data.json
Normal file
23
examples/where/where+covariance/data.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"x": [
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
|
||||
21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
|
||||
40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58,
|
||||
59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
|
||||
78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96,
|
||||
97, 98, 99
|
||||
],
|
||||
"y": [
|
||||
124.09, -37.12, 80.21, 62.13, 0.2, 45.79, 49.2, 105.68, 159.44, 135.97,
|
||||
141.56, 64.51, 9.56, 116.33, 8.94, 70.32, 79.54, 33.52, 123.45, 79.17,
|
||||
185.38, 113.5, 32.74, 204.23, 130.27, 76.63, 106.42, 191.19, 202.93, 212.26,
|
||||
146.64, 161.91, 203.84, 160.09, 124.31, 200.52, 220.17, 117.14, 270.71,
|
||||
232.62, 152.37, 137.63, 145.5, 290.38, 210.03, 305.25, 161.26, 213.92,
|
||||
126.23, 166.79, 174.66, 252.53, 229.94, 307.21, 190.17, 250.22, 215.2,
|
||||
196.0, 195.61, 250.23, 318.48, 235.65, 178.28, 200.54, 293.78, 243.95,
|
||||
319.72, 255.92, 313.52, 376.36, 304.54, 327.65, 317.21, 413.62, 400.09,
|
||||
347.71, 333.18, 302.64, 308.08, 430.22, 268.85, 367.58, 402.2, 274.01,
|
||||
460.27, 442.84, 280.65, 448.93, 345.64, 384.94, 438.06, 457.26, 337.6,
|
||||
303.74, 345.57, 494.38, 450.31, 383.49, 353.11, 392.61
|
||||
]
|
||||
}
|
||||
240
examples/where/where+covariance/where+covariance.ipynb
Normal file
240
examples/where/where+covariance/where+covariance.ipynb
Normal file
File diff suppressed because one or more lines are too long
253
examples/where/where+geomean/where+geomean.ipynb
Normal file
253
examples/where/where+geomean/where+geomean.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"col_name": [
|
||||
33.0, 15.0, 38.0, 38.0, 70.0, 44.0, 34.0, 67.0, 54.0, 78.0, 80.0, 21.0,
|
||||
33.0, 75.0, 38.0, 38.0, 70.0, 44.0, 34.0, 67.0, 54.0, 78.0, 80.0, 21.0,
|
||||
41.0, 47.0, 57.0, 50.0, 65.0, 43.0, 51.0, 54.0, 62.0, 68.0, 45.0, 39.0,
|
||||
51.0, 48.0, 48.0, 42.0, 37.0, 75.0, 40.0, 48.0, 65.0, 26.0, 42.0, 53.0,
|
||||
51.0, 56.0, 74.0, 54.0, 55.0, 15.0, 58.0, 46.0, 64.0, 59.0, 39.0, 36.0,
|
||||
241
examples/where/where+harmomean/harmomean.ipynb
Normal file
241
examples/where/where+harmomean/harmomean.ipynb
Normal file
File diff suppressed because one or more lines are too long
8
examples/where/where+mean/data.json
Normal file
8
examples/where/where+mean/data.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"col_name": [
|
||||
46.2, 40.4, 44.8, 48.1, 51.2, 91.9, 38.2, 36.3, 22.2, 11.5, 17.9, 20.2,
|
||||
99.9, 75.2, 29.8, 19.4, 46.1, 94.8, 6.6, 94.5, 99.7, 1.6, 4.0, 86.7, 28.7,
|
||||
63.0, 66.7, 2.5, 41.4, 35.6, 45.0, 13.7, 9.6, 16.6, 9.8, 20.3, 25.9, 71.9,
|
||||
27.5, 30.9, 62.9, 18.6, 45.7, 2.4, 91.4, 16.2, 61.5, 41.4, 77.1, 53.2
|
||||
]
|
||||
}
|
||||
245
examples/where/where+mean/where+mean.ipynb
Normal file
245
examples/where/where+mean/where+mean.ipynb
Normal file
File diff suppressed because one or more lines are too long
245
examples/where/where+median/where+median.ipynb
Normal file
245
examples/where/where+median/where+median.ipynb
Normal file
File diff suppressed because one or more lines are too long
276
examples/where/where+mode/where+mode.ipynb
Normal file
276
examples/where/where+mode/where+mode.ipynb
Normal file
File diff suppressed because one or more lines are too long
29
examples/where/where+pstdev/data.json
Normal file
29
examples/where/where+pstdev/data.json
Normal file
@@ -0,0 +1,29 @@
|
||||
{
|
||||
"col_name": [
|
||||
33.0, 75.0, 38.0, 38.0, 70.0, 44.0, 34.0, 67.0, 54.0, 78.0, 80.0, 21.0,
|
||||
41.0, 47.0, 57.0, 50.0, 65.0, 43.0, 51.0, 54.0, 62.0, 68.0, 45.0, 39.0,
|
||||
51.0, 48.0, 48.0, 42.0, 37.0, 75.0, 40.0, 48.0, 65.0, 26.0, 42.0, 53.0,
|
||||
51.0, 56.0, 74.0, 54.0, 55.0, 15.0, 58.0, 46.0, 64.0, 59.0, 39.0, 36.0,
|
||||
62.0, 39.0, 72.0, 32.0, 82.0, 76.0, 88.0, 51.0, 44.0, 35.0, 18.0, 53.0,
|
||||
52.0, 45.0, 64.0, 31.0, 32.0, 61.0, 66.0, 59.0, 50.0, 69.0, 44.0, 22.0,
|
||||
45.0, 45.0, 46.0, 42.0, 83.0, 53.0, 53.0, 69.0, 53.0, 33.0, 48.0, 49.0,
|
||||
34.0, 66.0, 29.0, 66.0, 52.0, 45.0, 83.0, 54.0, 53.0, 31.0, 71.0, 60.0,
|
||||
30.0, 33.0, 43.0, 26.0, 55.0, 56.0, 56.0, 54.0, 57.0, 68.0, 58.0, 61.0,
|
||||
62.0, 38.0, 52.0, 74.0, 76.0, 37.0, 42.0, 54.0, 38.0, 38.0, 30.0, 31.0,
|
||||
52.0, 41.0, 69.0, 40.0, 46.0, 69.0, 29.0, 28.0, 66.0, 41.0, 40.0, 36.0,
|
||||
52.0, 58.0, 46.0, 42.0, 85.0, 45.0, 70.0, 49.0, 48.0, 34.0, 18.0, 39.0,
|
||||
64.0, 46.0, 54.0, 42.0, 45.0, 64.0, 46.0, 68.0, 46.0, 54.0, 47.0, 41.0,
|
||||
69.0, 27.0, 61.0, 37.0, 25.0, 66.0, 30.0, 59.0, 67.0, 34.0, 36.0, 40.0,
|
||||
55.0, 58.0, 74.0, 55.0, 66.0, 55.0, 72.0, 40.0, 27.0, 38.0, 74.0, 52.0,
|
||||
45.0, 40.0, 35.0, 46.0, 64.0, 41.0, 50.0, 45.0, 42.0, 22.0, 25.0, 55.0,
|
||||
39.0, 58.0, 56.0, 62.0, 55.0, 65.0, 57.0, 34.0, 44.0, 47.0, 70.0, 60.0,
|
||||
34.0, 50.0, 43.0, 60.0, 66.0, 46.0, 58.0, 76.0, 40.0, 49.0, 64.0, 45.0,
|
||||
22.0, 50.0, 34.0, 44.0, 76.0, 63.0, 59.0, 36.0, 59.0, 47.0, 70.0, 64.0,
|
||||
44.0, 55.0, 50.0, 48.0, 66.0, 40.0, 76.0, 48.0, 75.0, 73.0, 55.0, 41.0,
|
||||
43.0, 50.0, 34.0, 57.0, 50.0, 53.0, 28.0, 35.0, 52.0, 52.0, 49.0, 67.0,
|
||||
41.0, 41.0, 61.0, 24.0, 43.0, 51.0, 40.0, 52.0, 44.0, 25.0, 81.0, 54.0,
|
||||
64.0, 76.0, 37.0, 45.0, 48.0, 46.0, 43.0, 67.0, 28.0, 35.0, 25.0, 71.0,
|
||||
50.0, 31.0, 43.0, 54.0, 40.0, 51.0, 40.0, 49.0, 34.0, 26.0, 46.0, 62.0,
|
||||
40.0, 25.0, 61.0, 58.0, 56.0, 39.0, 46.0, 53.0, 21.0, 57.0, 42.0, 80.0
|
||||
]
|
||||
}
|
||||
244
examples/where/where+pstdev/where+pstdev.ipynb
Normal file
244
examples/where/where+pstdev/where+pstdev.ipynb
Normal file
File diff suppressed because one or more lines are too long
29
examples/where/where+pvariance/data.json
Normal file
29
examples/where/where+pvariance/data.json
Normal file
@@ -0,0 +1,29 @@
|
||||
{
|
||||
"col_name": [
|
||||
33.0, 75.0, 38.0, 38.0, 70.0, 44.0, 34.0, 67.0, 54.0, 78.0, 80.0, 21.0,
|
||||
41.0, 47.0, 57.0, 50.0, 65.0, 43.0, 51.0, 54.0, 62.0, 68.0, 45.0, 39.0,
|
||||
51.0, 48.0, 48.0, 42.0, 37.0, 75.0, 40.0, 48.0, 65.0, 26.0, 42.0, 53.0,
|
||||
51.0, 56.0, 74.0, 54.0, 55.0, 15.0, 58.0, 46.0, 64.0, 59.0, 39.0, 36.0,
|
||||
62.0, 39.0, 72.0, 32.0, 82.0, 76.0, 88.0, 51.0, 44.0, 35.0, 18.0, 53.0,
|
||||
52.0, 45.0, 64.0, 31.0, 32.0, 61.0, 66.0, 59.0, 50.0, 69.0, 44.0, 22.0,
|
||||
45.0, 45.0, 46.0, 42.0, 83.0, 53.0, 53.0, 69.0, 53.0, 33.0, 48.0, 49.0,
|
||||
34.0, 66.0, 29.0, 66.0, 52.0, 45.0, 83.0, 54.0, 53.0, 31.0, 71.0, 60.0,
|
||||
30.0, 33.0, 43.0, 26.0, 55.0, 56.0, 56.0, 54.0, 57.0, 68.0, 58.0, 61.0,
|
||||
62.0, 38.0, 52.0, 74.0, 76.0, 37.0, 42.0, 54.0, 38.0, 38.0, 30.0, 31.0,
|
||||
52.0, 41.0, 69.0, 40.0, 46.0, 69.0, 29.0, 28.0, 66.0, 41.0, 40.0, 36.0,
|
||||
52.0, 58.0, 46.0, 42.0, 85.0, 45.0, 70.0, 49.0, 48.0, 34.0, 18.0, 39.0,
|
||||
64.0, 46.0, 54.0, 42.0, 45.0, 64.0, 46.0, 68.0, 46.0, 54.0, 47.0, 41.0,
|
||||
69.0, 27.0, 61.0, 37.0, 25.0, 66.0, 30.0, 59.0, 67.0, 34.0, 36.0, 40.0,
|
||||
55.0, 58.0, 74.0, 55.0, 66.0, 55.0, 72.0, 40.0, 27.0, 38.0, 74.0, 52.0,
|
||||
45.0, 40.0, 35.0, 46.0, 64.0, 41.0, 50.0, 45.0, 42.0, 22.0, 25.0, 55.0,
|
||||
39.0, 58.0, 56.0, 62.0, 55.0, 65.0, 57.0, 34.0, 44.0, 47.0, 70.0, 60.0,
|
||||
34.0, 50.0, 43.0, 60.0, 66.0, 46.0, 58.0, 76.0, 40.0, 49.0, 64.0, 45.0,
|
||||
22.0, 50.0, 34.0, 44.0, 76.0, 63.0, 59.0, 36.0, 59.0, 47.0, 70.0, 64.0,
|
||||
44.0, 55.0, 50.0, 48.0, 66.0, 40.0, 76.0, 48.0, 75.0, 73.0, 55.0, 41.0,
|
||||
43.0, 50.0, 34.0, 57.0, 50.0, 53.0, 28.0, 35.0, 52.0, 52.0, 49.0, 67.0,
|
||||
41.0, 41.0, 61.0, 24.0, 43.0, 51.0, 40.0, 52.0, 44.0, 25.0, 81.0, 54.0,
|
||||
64.0, 76.0, 37.0, 45.0, 48.0, 46.0, 43.0, 67.0, 28.0, 35.0, 25.0, 71.0,
|
||||
50.0, 31.0, 43.0, 54.0, 40.0, 51.0, 40.0, 49.0, 34.0, 26.0, 46.0, 62.0,
|
||||
40.0, 25.0, 61.0, 58.0, 56.0, 39.0, 46.0, 53.0, 21.0, 57.0, 42.0, 80.0
|
||||
]
|
||||
}
|
||||
238
examples/where/where+pvariance/where+pvariance.ipynb
Normal file
238
examples/where/where+pvariance/where+pvariance.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"x1": [
|
||||
7.1, 3.2, 8.6, 3.5, 0.1, 9.7, 2.3, 5.7, 2.8, 10.0, 6.0, 6.0, 9.1, 1.7, 9.2,
|
||||
10.0, 7.1, 3.2, 8.6, 3.5, 0.1, 9.7, 2.3, 5.7, 2.8, 6.0, 6.0, 9.1, 1.7, 9.2,
|
||||
0.2, 7.8, 3.7, 7.0, 2.5, 2.8, 5.9, 7.3, 2.9, 2.9, 3.5, 1.0, 9.7, 4.8, 0.9,
|
||||
7.1, 3.6, 8.2, 3.0, 7.6, 4.2, 5.2, 8.1, 6.3, 9.3, 8.8, 8.2, 6.7, 4.9, 5.4,
|
||||
9.8, 5.9, 7.1, 3.9, 9.3
|
||||
@@ -12,7 +12,7 @@
|
||||
1.5, 2.1, 0.4, 4.3, 0.2
|
||||
],
|
||||
"y": [
|
||||
18.5, 5.5, 18.2, 9.0, 4.0, 19.5, 11.7, 17.9, 15.3, 20.8, 12.5, 21.5, 32.5,
|
||||
20.8, 18.5, 5.5, 18.2, 9.0, 4.0, 19.5, 11.7, 17.9, 15.3, 12.5, 21.5, 32.5,
|
||||
18.6, 23.9, 7.0, 16.9, 22.9, 31.0, 15.0, 8.5, 8.7, 28.9, 19.7, 12.5, 17.4,
|
||||
7.2, 25.5, 21.4, 15.7, 15.5, 8.2, 28.2, 19.5, 25.5, 12.5, 20.3, 21.7, 22.1,
|
||||
19.6, 32.2, 22.4, 20.6, 19.7, 20.8, 21.1, 21.8, 17.7, 21.1, 19.4
|
||||
291
examples/where/where+regression/where+regression.ipynb
Normal file
291
examples/where/where+regression/where+regression.ipynb
Normal file
File diff suppressed because one or more lines are too long
239
examples/where/where+stdev/where+stdev.ipynb
Normal file
239
examples/where/where+stdev/where+stdev.ipynb
Normal file
File diff suppressed because one or more lines are too long
29
examples/where/where+variance/data.json
Normal file
29
examples/where/where+variance/data.json
Normal file
@@ -0,0 +1,29 @@
|
||||
{
|
||||
"col_name": [
|
||||
33.0, 75.0, 38.0, 38.0, 70.0, 44.0, 34.0, 67.0, 54.0, 78.0, 80.0, 21.0,
|
||||
41.0, 47.0, 57.0, 50.0, 65.0, 43.0, 51.0, 54.0, 62.0, 68.0, 45.0, 39.0,
|
||||
51.0, 48.0, 48.0, 42.0, 37.0, 75.0, 40.0, 48.0, 65.0, 26.0, 42.0, 53.0,
|
||||
51.0, 56.0, 74.0, 54.0, 55.0, 15.0, 58.0, 46.0, 64.0, 59.0, 39.0, 36.0,
|
||||
62.0, 39.0, 72.0, 32.0, 82.0, 76.0, 88.0, 51.0, 44.0, 35.0, 18.0, 53.0,
|
||||
52.0, 45.0, 64.0, 31.0, 32.0, 61.0, 66.0, 59.0, 50.0, 69.0, 44.0, 22.0,
|
||||
45.0, 45.0, 46.0, 42.0, 83.0, 53.0, 53.0, 69.0, 53.0, 33.0, 48.0, 49.0,
|
||||
34.0, 66.0, 29.0, 66.0, 52.0, 45.0, 83.0, 54.0, 53.0, 31.0, 71.0, 60.0,
|
||||
30.0, 33.0, 43.0, 26.0, 55.0, 56.0, 56.0, 54.0, 57.0, 68.0, 58.0, 61.0,
|
||||
62.0, 38.0, 52.0, 74.0, 76.0, 37.0, 42.0, 54.0, 38.0, 38.0, 30.0, 31.0,
|
||||
52.0, 41.0, 69.0, 40.0, 46.0, 69.0, 29.0, 28.0, 66.0, 41.0, 40.0, 36.0,
|
||||
52.0, 58.0, 46.0, 42.0, 85.0, 45.0, 70.0, 49.0, 48.0, 34.0, 18.0, 39.0,
|
||||
64.0, 46.0, 54.0, 42.0, 45.0, 64.0, 46.0, 68.0, 46.0, 54.0, 47.0, 41.0,
|
||||
69.0, 27.0, 61.0, 37.0, 25.0, 66.0, 30.0, 59.0, 67.0, 34.0, 36.0, 40.0,
|
||||
55.0, 58.0, 74.0, 55.0, 66.0, 55.0, 72.0, 40.0, 27.0, 38.0, 74.0, 52.0,
|
||||
45.0, 40.0, 35.0, 46.0, 64.0, 41.0, 50.0, 45.0, 42.0, 22.0, 25.0, 55.0,
|
||||
39.0, 58.0, 56.0, 62.0, 55.0, 65.0, 57.0, 34.0, 44.0, 47.0, 70.0, 60.0,
|
||||
34.0, 50.0, 43.0, 60.0, 66.0, 46.0, 58.0, 76.0, 40.0, 49.0, 64.0, 45.0,
|
||||
22.0, 50.0, 34.0, 44.0, 76.0, 63.0, 59.0, 36.0, 59.0, 47.0, 70.0, 64.0,
|
||||
44.0, 55.0, 50.0, 48.0, 66.0, 40.0, 76.0, 48.0, 75.0, 73.0, 55.0, 41.0,
|
||||
43.0, 50.0, 34.0, 57.0, 50.0, 53.0, 28.0, 35.0, 52.0, 52.0, 49.0, 67.0,
|
||||
41.0, 41.0, 61.0, 24.0, 43.0, 51.0, 40.0, 52.0, 44.0, 25.0, 81.0, 54.0,
|
||||
64.0, 76.0, 37.0, 45.0, 48.0, 46.0, 43.0, 67.0, 28.0, 35.0, 25.0, 71.0,
|
||||
50.0, 31.0, 43.0, 54.0, 40.0, 51.0, 40.0, 49.0, 34.0, 26.0, 46.0, 62.0,
|
||||
40.0, 25.0, 61.0, 58.0, 56.0, 39.0, 46.0, 53.0, 21.0, 57.0, 42.0, 80.0
|
||||
]
|
||||
}
|
||||
233
examples/where/where+variance/where+variance.ipynb
Normal file
233
examples/where/where+variance/where+variance.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -24,4 +24,4 @@ def column_2():
|
||||
|
||||
@pytest.fixture
|
||||
def scales():
|
||||
return [3]
|
||||
return [6]
|
||||
|
||||
@@ -10,6 +10,11 @@ from zkstats.computation import IModel
|
||||
|
||||
DEFAULT_POSSIBLE_SCALES = list(range(20))
|
||||
|
||||
# Error tolerance between circuit and python implementation
|
||||
ERROR_CIRCUIT_DEFAULT = 0.01
|
||||
ERROR_CIRCUIT_STRICT = 0.0001
|
||||
ERROR_CIRCUIT_RELAXED = 0.1
|
||||
|
||||
|
||||
def data_to_file(data_path: Path, data: list[torch.Tensor]) -> dict[str, list]:
|
||||
column_names = [f"columns_{i}" for i in range(len(data))]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import Type, Callable
|
||||
import statistics
|
||||
import torch
|
||||
|
||||
@@ -17,9 +18,10 @@ from zkstats.ops import (
|
||||
Covariance,
|
||||
Correlation,
|
||||
Regression,
|
||||
Operation
|
||||
)
|
||||
|
||||
from .helpers import assert_result, compute
|
||||
from .helpers import assert_result, compute, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED
|
||||
|
||||
|
||||
def nested_computation(state: State, args: list[torch.Tensor]):
|
||||
@@ -129,3 +131,67 @@ def test_nested_computation(tmp_path, column_0: torch.Tensor, column_1: torch.Te
|
||||
assert isinstance(op_11, Mean)
|
||||
out_11 = statistics.mean([out_0, out_1, out_2, out_3, out_4, out_5, out_6, out_7, out_8, out_9, out_10.slope, out_10.intercept])
|
||||
assert_result(torch.tensor(out_11), op_11.result)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"op_type, expected_func, error",
|
||||
[
|
||||
(State.mean, statistics.mean, ERROR_CIRCUIT_DEFAULT),
|
||||
(State.median, statistics.median, ERROR_CIRCUIT_DEFAULT),
|
||||
(State.geometric_mean, statistics.geometric_mean, ERROR_CIRCUIT_DEFAULT),
|
||||
# Be more tolerant for HarmonicMean
|
||||
(State.harmonic_mean, statistics.harmonic_mean, ERROR_CIRCUIT_RELAXED),
|
||||
# Be less tolerant for Mode
|
||||
(State.mode, statistics.mode, ERROR_CIRCUIT_STRICT),
|
||||
(State.pstdev, statistics.pstdev, ERROR_CIRCUIT_DEFAULT),
|
||||
(State.pvariance, statistics.pvariance, ERROR_CIRCUIT_DEFAULT),
|
||||
(State.stdev, statistics.stdev, ERROR_CIRCUIT_DEFAULT),
|
||||
(State.variance, statistics.variance, ERROR_CIRCUIT_DEFAULT),
|
||||
]
|
||||
)
|
||||
def test_computation_with_where_1d(tmp_path, error, column_0, op_type: Callable[[Operation, torch.Tensor], torch.Tensor], expected_func: Callable[[list[float]], float], scales):
|
||||
column = column_0
|
||||
def condition(_x: torch.Tensor):
|
||||
return _x < 4
|
||||
|
||||
def where_and_op(state: State, args: list[torch.Tensor]):
|
||||
x = args[0]
|
||||
return op_type(state, state.where(condition(x), x))
|
||||
|
||||
state, model = computation_to_model(where_and_op, error)
|
||||
compute(tmp_path, [column], model, scales)
|
||||
|
||||
res_op = state.ops[-1]
|
||||
filtered = column[condition(column)]
|
||||
expected_res = expected_func(filtered.tolist())
|
||||
assert_result(res_op.result.data, expected_res)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"op_type, expected_func, error",
|
||||
[
|
||||
(State.covariance, statistics.covariance, ERROR_CIRCUIT_RELAXED),
|
||||
(State.correlation, statistics.correlation, ERROR_CIRCUIT_RELAXED),
|
||||
]
|
||||
)
|
||||
def test_computation_with_where_2d(tmp_path, error, column_0, column_1, op_type: Callable[[Operation, torch.Tensor], torch.Tensor], expected_func: Callable[[list[float]], float], scales):
|
||||
def condition_0(_x: torch.Tensor):
|
||||
return _x > 4
|
||||
|
||||
def where_and_op(state: State, args: list[torch.Tensor]):
|
||||
x = args[0]
|
||||
y = args[1]
|
||||
condition_x = condition_0(x)
|
||||
filtered_x = state.where(condition_x, x)
|
||||
filtered_y = state.where(condition_x, y)
|
||||
return op_type(state, filtered_x, filtered_y)
|
||||
|
||||
state, model = computation_to_model(where_and_op, error)
|
||||
compute(tmp_path, [column_0, column_1], model, scales)
|
||||
|
||||
res_op = state.ops[-1]
|
||||
condition_x = condition_0(column_0)
|
||||
filtered_x = column_0[condition_x]
|
||||
filtered_y = column_1[condition_x]
|
||||
expected_res = expected_func(filtered_x.tolist(), filtered_y.tolist())
|
||||
assert_result(res_op.result.data, expected_res)
|
||||
|
||||
@@ -7,13 +7,7 @@ import torch
|
||||
from zkstats.ops import Mean, Median, GeometricMean, HarmonicMean, Mode, PStdev, PVariance, Stdev, Variance, Covariance, Correlation, Operation, Regression
|
||||
from zkstats.computation import IModel, IsResultPrecise
|
||||
|
||||
from .helpers import compute, assert_result
|
||||
|
||||
|
||||
# Error tolerance between circuit and python implementation
|
||||
ERROR_CIRCUIT_DEFAULT = 0.01
|
||||
ERROR_CIRCUIT_STRICT = 0.0001
|
||||
ERROR_CIRCUIT_RELAXED = 0.1
|
||||
from .helpers import compute, assert_result, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -18,6 +18,7 @@ from .ops import (
|
||||
Covariance,
|
||||
Correlation,
|
||||
Regression,
|
||||
Where,
|
||||
IsResultPrecise,
|
||||
)
|
||||
|
||||
@@ -128,8 +129,20 @@ class State:
|
||||
Calculate the linear regression of x and y. The behavior should conform to
|
||||
[statistics.linear_regression](https://docs.python.org/3/library/statistics.html#statistics.linear_regression) in Python standard library.
|
||||
"""
|
||||
# hence support only one x for now
|
||||
return self._call_op([x, y], Regression)
|
||||
|
||||
# WHERE operation
|
||||
def where(self, filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the where operation of x. The behavior should conform to `torch.where` in PyTorch.
|
||||
|
||||
:param filter: A boolean tensor serves as a filter
|
||||
:param x: A tensor to be filtered
|
||||
:return: filtered tensor
|
||||
"""
|
||||
return self._call_op([filter, x], Where)
|
||||
|
||||
def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
|
||||
if self.current_op_index is None:
|
||||
op = op_type.create(x, self.error)
|
||||
@@ -212,3 +225,4 @@ def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR
|
||||
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
|
||||
return computation(state, x)
|
||||
return state, Model
|
||||
|
||||
|
||||
@@ -45,7 +45,9 @@ def create_dummy(data_path: str, dummy_data_path: str) -> None:
|
||||
dummy_data ={}
|
||||
for col in data:
|
||||
# not use same value for every column to prevent something weird, like singular matrix
|
||||
dummy_data[col] = np.round(np.random.uniform(1,30,len(data[col])),1).tolist()
|
||||
min_col = min(data[col])
|
||||
max_col = max(data[col])
|
||||
dummy_data[col] = np.round(np.random.uniform(min_col,max_col,len(data[col])),1).tolist()
|
||||
|
||||
json.dump(dummy_data, open(dummy_data_path, 'w'))
|
||||
|
||||
@@ -250,12 +252,10 @@ def verifier_verify(proof_path: str, settings_path: str, vk_path: str, selected_
|
||||
# - is a tuple (is_in_error, result)
|
||||
# - is_valid is True
|
||||
# Sanity check
|
||||
# is_in_error = ezkl.vecu64_to_float(outputs[0], output_scales[0])
|
||||
is_in_error = ezkl.felt_to_float(outputs[0], output_scales[0])
|
||||
assert is_in_error == 1.0, f"result is not within error"
|
||||
result_arr = []
|
||||
for index in range(len(outputs)-1):
|
||||
# result_arr.append(ezkl.vecu64_to_float(outputs[index+1], output_scales[1]))
|
||||
result_arr.append(ezkl.felt_to_float(outputs[index+1], output_scales[1]))
|
||||
return result_arr
|
||||
|
||||
|
||||
193
zkstats/ops.py
193
zkstats/ops.py
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
# boolean: either 1.0 or 0.0
|
||||
IsResultPrecise = torch.Tensor
|
||||
MagicNumber = 9999999
|
||||
|
||||
|
||||
class Operation(ABC):
|
||||
@@ -22,15 +23,28 @@ class Operation(ABC):
|
||||
...
|
||||
|
||||
|
||||
class Where(Operation):
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float) -> 'Where':
|
||||
# here error is trivial, but here to conform to other functions
|
||||
return cls(torch.where(x[0],x[1], MagicNumber ),error)
|
||||
def ezkl(self, x:list[torch.Tensor]) -> IsResultPrecise:
|
||||
bool_array = torch.logical_or(x[1]==self.result, torch.logical_and(torch.logical_not(x[0]), self.result==MagicNumber))
|
||||
# print('sellll: ', self.result)
|
||||
return torch.sum(bool_array.float())==x[1].size()[1]
|
||||
|
||||
|
||||
class Mean(Operation):
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float) -> 'Mean':
|
||||
return cls(torch.mean(x[0]), error)
|
||||
# support where statement, hopefully we can use 'nan' once onnx.isnan() is supported
|
||||
return cls(torch.mean(x[0][x[0]!=MagicNumber]), error)
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
x = x[0]
|
||||
size = x.size()
|
||||
return torch.abs(torch.sum(x)-size[1]*self.result)<=torch.abs(self.error*size[1]*self.result)
|
||||
size = torch.sum((x!=MagicNumber).float())
|
||||
x = torch.where(x==MagicNumber, 0.0, x)
|
||||
return torch.abs(torch.sum(x)-size*self.result)<=torch.abs(self.error*self.result*size)
|
||||
|
||||
|
||||
def to_1d(x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -51,6 +65,7 @@ class Median(Operation):
|
||||
# we want in our context. However, we tend to have x as a `[1, len(x), 1]`. In this case,
|
||||
# we need to flatten `x` to 1d array to get the correct `lower` and `upper`.
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
super().__init__(torch.tensor(np.median(x_1d)), error)
|
||||
sorted_x = np.sort(x_1d)
|
||||
len_x = len(x_1d)
|
||||
@@ -63,21 +78,25 @@ class Median(Operation):
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
x = x[0]
|
||||
old_size = x.size()[1]
|
||||
size = torch.sum((x!=MagicNumber).float())
|
||||
min_x = torch.min(x)
|
||||
x = torch.where(x==MagicNumber,min_x-1, x)
|
||||
|
||||
# since within 1%, we regard as same value
|
||||
count_less = torch.sum((x < self.result).float())
|
||||
count_less = torch.sum((x < self.result).float())-(old_size-size)
|
||||
count_equal = torch.sum((x==self.result).float())
|
||||
len = x.size()[1]
|
||||
half_len = torch.floor(torch.div(len, 2))
|
||||
|
||||
less_cons = count_less<half_len+len%2
|
||||
more_cons = count_less+count_equal>half_len
|
||||
half_size = torch.floor(torch.div(size, 2))
|
||||
|
||||
less_cons = count_less<half_size+size%2
|
||||
more_cons = count_less+count_equal>half_size
|
||||
|
||||
# For count_equal == 0
|
||||
lower_exist = torch.sum((x==self.lower).float())>0
|
||||
lower_cons = torch.sum((x>self.lower).float())==half_len
|
||||
lower_cons = torch.sum((x>self.lower).float())==half_size
|
||||
upper_exist = torch.sum((x==self.upper).float())>0
|
||||
upper_cons = torch.sum((x<self.upper).float())==half_len
|
||||
bound = count_less== half_len
|
||||
upper_cons = torch.sum((x<self.upper).float())==half_size
|
||||
bound = count_less== half_size
|
||||
# 0.02 since 2*0.01
|
||||
bound_avg = (torch.abs(self.lower+self.upper-2*self.result)<=torch.abs(2*self.error*self.result))
|
||||
|
||||
@@ -85,17 +104,20 @@ class Median(Operation):
|
||||
median_out_cons = torch.logical_and(torch.logical_and(bound, bound_avg), torch.logical_and(torch.logical_and(lower_cons, upper_cons), torch.logical_and(lower_exist, upper_exist)))
|
||||
return torch.where(count_equal==0, median_out_cons, median_in_cons)
|
||||
|
||||
|
||||
class GeometricMean(Operation):
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float) -> 'GeometricMean':
|
||||
x_1d = to_1d(x[0])
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
result = torch.exp(torch.mean(torch.log(x_1d)))
|
||||
return cls(result, error)
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
# Assume x is [1, n, 1]
|
||||
x = x[0]
|
||||
size = x.size()[1]
|
||||
size = torch.sum((x!=MagicNumber).float())
|
||||
x = torch.where(x==MagicNumber, 1.0, x)
|
||||
return torch.abs((torch.log(self.result)*size)-torch.sum(torch.log(x)))<=size*torch.log(torch.tensor(1+self.error))
|
||||
|
||||
|
||||
@@ -103,13 +125,16 @@ class HarmonicMean(Operation):
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float) -> 'HarmonicMean':
|
||||
x_1d = to_1d(x[0])
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
result = torch.div(1.0,torch.mean(torch.div(1.0, x_1d)))
|
||||
return cls(result, error)
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
# Assume x is [1, n, 1]
|
||||
x = x[0]
|
||||
size = x.size()[1]
|
||||
size = torch.sum((x!=MagicNumber).float())
|
||||
# just make it really big so that 1/x goes to zero for element that gets filtered out
|
||||
x = torch.where(x==MagicNumber, x*x, x)
|
||||
return torch.abs((self.result*torch.sum(torch.div(1.0, x))) - size)<=torch.abs(self.error*size)
|
||||
|
||||
|
||||
@@ -131,15 +156,16 @@ def mode_within(data_array: torch.Tensor, error: float) -> torch.Tensor:
|
||||
max_sum_freq = sum_freq
|
||||
return mode
|
||||
|
||||
# TODO: Add mode_within, different from traditional mode
|
||||
|
||||
# TODO: Add class Mode_within , different from traditional mode
|
||||
# class Mode_(Operation):
|
||||
# @classmethod
|
||||
# def create(cls, x: list[torch.Tensor], error: float) -> 'Mode':
|
||||
# x_1d = to_1d(x[0])
|
||||
# # Mode has no result_error, just num_error which is the
|
||||
# # deviation that two numbers are considered the same. (Make sense because
|
||||
# # Mode has no result_error, just num_error which is the
|
||||
# # deviation that two numbers are considered the same. (Make sense because
|
||||
# # if some dataset has all different data, mode will be trivial if this is not the case)
|
||||
# # This value doesn't depend on any scale, but on the dataset itself, and the intention
|
||||
# # This value doesn't depend on any scale, but on the dataset itself, and the intention
|
||||
# # the evaluator. For example 0.01 means that data is counted as the same within 1% value range.
|
||||
|
||||
# # If wanting the strict definition of Mode, can just put this error to be 0
|
||||
@@ -156,13 +182,13 @@ def mode_within(data_array: torch.Tensor, error: float) -> torch.Tensor:
|
||||
# for ele in x[0]
|
||||
# ], dtype = torch.float32)
|
||||
# return torch.sum(_result) == size
|
||||
|
||||
|
||||
|
||||
class Mode(Operation):
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float) -> 'Mode':
|
||||
x_1d = to_1d(x[0])
|
||||
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
# Here is traditional definition of Mode, can just put this num_error to be 0
|
||||
result = torch.tensor(mode_within(x_1d, 0))
|
||||
return cls(result, error)
|
||||
@@ -170,18 +196,18 @@ class Mode(Operation):
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
# Assume x is [1, n, 1]
|
||||
x = x[0]
|
||||
size = x.size()[1]
|
||||
count_equal = torch.sum((torch.abs(x-self.result)<=torch.abs(self.error*self.result)).float())
|
||||
_result = torch.tensor([
|
||||
torch.sum((torch.abs(x-ele[0])<=torch.abs(self.error*ele[0])).float())<= count_equal
|
||||
for ele in x[0]
|
||||
], dtype = torch.float32)
|
||||
return torch.sum(_result) == size
|
||||
min_x = torch.min(x)
|
||||
old_size = x.size()[1]
|
||||
x = torch.where(x==MagicNumber, min_x-1, x)
|
||||
count_equal = torch.sum((x==self.result).float())
|
||||
result = torch.tensor([torch.logical_or(torch.sum((x==ele[0]).float())<=count_equal, min_x-1 ==ele[0]) for ele in x[0]])
|
||||
return torch.sum(result) == old_size
|
||||
|
||||
|
||||
class PStdev(Operation):
|
||||
def __init__(self, x: torch.Tensor, error: float):
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
|
||||
result = torch.sqrt(torch.var(x_1d, correction = 0))
|
||||
super().__init__(result, error)
|
||||
@@ -192,16 +218,19 @@ class PStdev(Operation):
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
x = x[0]
|
||||
size = x.size()[1]
|
||||
x_mean_cons = torch.abs(torch.sum(x)-size*(self.data_mean))<=torch.abs(self.error*size*self.data_mean)
|
||||
x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
|
||||
size = torch.sum((x!=MagicNumber).float())
|
||||
x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
|
||||
x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
|
||||
return torch.logical_and(
|
||||
torch.abs(torch.sum((x-self.data_mean)*(x-self.data_mean))-self.result*self.result*size)<=torch.abs(2*self.error*self.result*self.result*size),x_mean_cons
|
||||
torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*self.result*size)<=torch.abs(2*self.error*self.result*self.result*size),x_mean_cons
|
||||
)
|
||||
|
||||
|
||||
class PVariance(Operation):
|
||||
def __init__(self, x: torch.Tensor, error: float):
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
|
||||
result = torch.var(x_1d, correction = 0)
|
||||
super().__init__(result, error)
|
||||
@@ -212,16 +241,20 @@ class PVariance(Operation):
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
x = x[0]
|
||||
size = x.size()[1]
|
||||
x_mean_cons = torch.abs(torch.sum(x)-size*(self.data_mean))<=torch.abs(self.error*size*self.data_mean)
|
||||
x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
|
||||
size = torch.sum((x!=MagicNumber).float())
|
||||
x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
|
||||
x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
|
||||
return torch.logical_and(
|
||||
torch.abs(torch.sum((x-self.data_mean)*(x-self.data_mean))-self.result*size)<=torch.abs(self.error*self.result*size), x_mean_cons
|
||||
torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*size)<=torch.abs(self.error*self.result*size), x_mean_cons
|
||||
)
|
||||
|
||||
|
||||
|
||||
class Stdev(Operation):
|
||||
def __init__(self, x: torch.Tensor, error: float):
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
|
||||
result = torch.sqrt(torch.var(x_1d, correction = 1))
|
||||
super().__init__(result, error)
|
||||
@@ -232,16 +265,19 @@ class Stdev(Operation):
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
x = x[0]
|
||||
size = x.size()[1]
|
||||
x_mean_cons = torch.abs(torch.sum(x)-size*(self.data_mean))<=torch.abs(self.error*size*self.data_mean)
|
||||
x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
|
||||
size = torch.sum((x!=MagicNumber).float())
|
||||
x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
|
||||
x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
|
||||
return torch.logical_and(
|
||||
torch.abs(torch.sum((x-self.data_mean)*(x-self.data_mean))-self.result*self.result*(size - 1))<=torch.abs(2*self.error*self.result*self.result*(size - 1)), x_mean_cons
|
||||
torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*self.result*(size - 1))<=torch.abs(2*self.error*self.result*self.result*(size - 1)), x_mean_cons
|
||||
)
|
||||
|
||||
|
||||
class Variance(Operation):
|
||||
def __init__(self, x: torch.Tensor, error: float):
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
|
||||
result = torch.var(x_1d, correction = 1)
|
||||
super().__init__(result, error)
|
||||
@@ -252,17 +288,23 @@ class Variance(Operation):
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
x = x[0]
|
||||
size = x.size()[1]
|
||||
x_mean_cons = torch.abs(torch.sum(x)-size*(self.data_mean))<=torch.abs(self.error*size*self.data_mean)
|
||||
x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
|
||||
size = torch.sum((x!=MagicNumber).float())
|
||||
x_mean_cons = torch.abs(torch.sum(x_fil_0)-size*(self.data_mean))<=torch.abs(self.error*self.data_mean*size)
|
||||
x_fil_mean = torch.where(x==MagicNumber, self.data_mean, x)
|
||||
return torch.logical_and(
|
||||
torch.abs(torch.sum((x-self.data_mean)*(x-self.data_mean))-self.result*(size - 1))<=torch.abs(self.error*self.result*(size - 1)), x_mean_cons
|
||||
torch.abs(torch.sum((x_fil_mean-self.data_mean)*(x_fil_mean-self.data_mean))-self.result*(size - 1))<=torch.abs(self.error*self.result*(size - 1)), x_mean_cons
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
class Covariance(Operation):
|
||||
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float):
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
y_1d = to_1d(y)
|
||||
y_1d = y_1d[y_1d!=MagicNumber]
|
||||
x_1d_list = x_1d.tolist()
|
||||
y_1d_list = y_1d.tolist()
|
||||
|
||||
@@ -278,34 +320,37 @@ class Covariance(Operation):
|
||||
|
||||
def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
|
||||
x, y = args[0], args[1]
|
||||
size_x = x.size()[1]
|
||||
size_y = y.size()[1]
|
||||
x_mean_cons = torch.abs(torch.sum(x)-size_x*(self.x_mean))<=torch.abs(self.error*size_x*self.x_mean)
|
||||
y_mean_cons = torch.abs(torch.sum(y)-size_y*(self.y_mean))<=torch.abs(self.error*size_y*self.y_mean)
|
||||
x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
|
||||
y_fil_0 = torch.where(y==MagicNumber, 0.0, y)
|
||||
size_x = torch.sum((x!=MagicNumber).float())
|
||||
size_y = torch.sum((y!=MagicNumber).float())
|
||||
x_mean_cons = torch.abs(torch.sum(x_fil_0)-size_x*(self.x_mean))<=torch.abs(self.error*self.x_mean*size_x)
|
||||
y_mean_cons = torch.abs(torch.sum(y_fil_0)-size_y*(self.y_mean))<=torch.abs(self.error*self.y_mean*size_y)
|
||||
x_fil_mean = torch.where(x==MagicNumber, self.x_mean, x)
|
||||
# only x_fil_mean is enough, no need for y_fil_mean since it will multiply 0 anyway
|
||||
return torch.logical_and(
|
||||
torch.logical_and(x_mean_cons,y_mean_cons),
|
||||
torch.abs(torch.sum((x-self.x_mean)*(y-self.y_mean))-(size_x-1)*self.result)<self.error*(size_x-1)*self.result
|
||||
torch.logical_and(size_x==size_y,torch.logical_and(x_mean_cons,y_mean_cons)),
|
||||
torch.abs(torch.sum((x_fil_mean-self.x_mean)*(y-self.y_mean))-(size_x-1)*self.result)<self.error*self.result*(size_x-1)
|
||||
)
|
||||
|
||||
|
||||
def stdev(x: torch.Tensor, x_std: torch.Tensor, x_mean: torch.Tensor, error: float) -> torch.Tensor:
|
||||
size_x = x.size()[1]
|
||||
x_mean_cons = torch.abs(torch.sum(x)-size_x*(x_mean))<=torch.abs(error*size_x*x_mean)
|
||||
return (torch.logical_and(torch.abs(torch.sum((x-x_mean)*(x-x_mean))-x_std*x_std*(size_x-1))<=torch.abs(2*error*x_std*x_std*(size_x-1)),x_mean_cons),x_std)
|
||||
|
||||
|
||||
def covariance(x: torch.Tensor, y: torch.Tensor, cov: torch.Tensor, x_mean: torch.Tensor, y_mean: torch.Tensor, error: float) -> torch.Tensor:
|
||||
size_x = x.size()[1]
|
||||
size_y = y.size()[1]
|
||||
x_mean_cons = torch.abs(torch.sum(x)-size_x*(x_mean))<=torch.abs(error*size_x*(x_mean))
|
||||
y_mean_cons = torch.abs(torch.sum(y)-size_y*(y_mean))<=torch.abs(error*size_y*(y_mean))
|
||||
return (torch.logical_and(torch.logical_and(x_mean_cons,y_mean_cons), torch.abs(torch.sum((x-x_mean)*(y-y_mean))-(size_x-1)*(cov))<error*(size_x-1)*(cov)), cov)
|
||||
# refer other constraints to correlation function, not put here since will be repetitive
|
||||
def stdev_for_corr(x_fil_mean:torch.Tensor,size_x:torch.Tensor, x_std: torch.Tensor, x_mean: torch.Tensor, error: float) -> torch.Tensor:
|
||||
return (
|
||||
torch.abs(torch.sum((x_fil_mean-x_mean)*(x_fil_mean-x_mean))-x_std*x_std*(size_x - 1))<=torch.abs(2*error*x_std*x_std*(size_x - 1))
|
||||
, x_std)
|
||||
# refer other constraints to correlation function, not put here since will be repetitive
|
||||
def covariance_for_corr(x_fil_mean: torch.Tensor,y_fil_mean: torch.Tensor,size_x:torch.Tensor, size_y:torch.Tensor, cov: torch.Tensor, x_mean: torch.Tensor, y_mean: torch.Tensor, error: float) -> torch.Tensor:
|
||||
return (
|
||||
torch.abs(torch.sum((x_fil_mean-x_mean)*(y_fil_mean-y_mean))-(size_x-1)*cov)<error*cov*(size_x-1)
|
||||
, cov)
|
||||
|
||||
|
||||
class Correlation(Operation):
|
||||
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float):
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
y_1d = to_1d(y)
|
||||
y_1d = y_1d[y_1d!=MagicNumber]
|
||||
x_1d_list = x_1d.tolist()
|
||||
y_1d_list = y_1d.tolist()
|
||||
self.x_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
|
||||
@@ -323,11 +368,21 @@ class Correlation(Operation):
|
||||
|
||||
def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
|
||||
x, y = args[0], args[1]
|
||||
bool1, cov = covariance(x, y, self.cov, self.x_mean, self.y_mean, self.error)
|
||||
bool2, x_std = stdev(x, self.x_std, self.x_mean, self.error)
|
||||
bool3, y_std = stdev(y, self.y_std, self.y_mean, self.error)
|
||||
x_fil_0 = torch.where(x==MagicNumber, 0.0, x)
|
||||
y_fil_0 = torch.where(y==MagicNumber, 0.0, y)
|
||||
size_x = torch.sum((x!=MagicNumber).float())
|
||||
size_y = torch.sum((y!=MagicNumber).float())
|
||||
x_mean_cons = torch.abs(torch.sum(x_fil_0)-size_x*(self.x_mean))<=torch.abs(self.error*self.x_mean*size_x)
|
||||
y_mean_cons = torch.abs(torch.sum(y_fil_0)-size_y*(self.y_mean))<=torch.abs(self.error*self.y_mean*size_y)
|
||||
x_fil_mean = torch.where(x==MagicNumber, self.x_mean, x)
|
||||
y_fil_mean = torch.where(y==MagicNumber, self.y_mean, y)
|
||||
|
||||
miscel_cons = torch.logical_and(size_x==size_y, torch.logical_and(x_mean_cons, y_mean_cons))
|
||||
bool1, cov = covariance_for_corr(x_fil_mean,y_fil_mean,size_x, size_y, self.cov, self.x_mean, self.y_mean, self.error)
|
||||
bool2, x_std = stdev_for_corr( x_fil_mean, size_x, self.x_std, self.x_mean, self.error)
|
||||
bool3, y_std = stdev_for_corr( y_fil_mean, size_y, self.y_std, self.y_mean, self.error)
|
||||
bool4 = torch.abs(cov - self.result*x_std*y_std)<=self.error*cov
|
||||
return torch.logical_and(torch.logical_and(bool1, bool2),torch.logical_and(bool3, bool4))
|
||||
return torch.logical_and(torch.logical_and(torch.logical_and(bool1, bool2),torch.logical_and(bool3, bool4)), miscel_cons)
|
||||
|
||||
|
||||
def stacked_x(args: list[float]):
|
||||
@@ -336,12 +391,19 @@ def stacked_x(args: list[float]):
|
||||
|
||||
class Regression(Operation):
|
||||
def __init__(self, xs: list[torch.Tensor], y: torch.Tensor, error: float):
|
||||
x_1ds = [to_1d(i).tolist() for i in xs]
|
||||
y_1d = to_1d(y).tolist()
|
||||
x_1ds = [to_1d(i) for i in xs]
|
||||
fil_x_1ds=[]
|
||||
for x_1 in x_1ds:
|
||||
fil_x_1ds.append((x_1[x_1!=MagicNumber]).tolist())
|
||||
x_1ds = fil_x_1ds
|
||||
|
||||
y_1d = to_1d(y)
|
||||
y_1d = (y_1d[y_1d!=MagicNumber]).tolist()
|
||||
|
||||
x_one = stacked_x(x_1ds)
|
||||
result_1d = np.matmul(np.matmul(np.linalg.inv(np.matmul(x_one.transpose(), x_one)), x_one.transpose()), y_1d)
|
||||
result = torch.tensor(result_1d, dtype = torch.float32).reshape(1, -1, 1)
|
||||
print('result: ', result)
|
||||
super().__init__(result, error)
|
||||
|
||||
@classmethod
|
||||
@@ -353,6 +415,9 @@ class Regression(Operation):
|
||||
def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
|
||||
# infer y from the last parameter
|
||||
y = args[-1]
|
||||
y = torch.where(y==MagicNumber, torch.tensor(0.0), y)
|
||||
x_one = torch.cat((*args[:-1], torch.ones_like(args[0])), dim=2)
|
||||
x_one = torch.where((x_one[:,:,0] ==MagicNumber).unsqueeze(-1), torch.tensor([0.0]*x_one.size()[2]), x_one)
|
||||
x_t = torch.transpose(x_one, 1, 2)
|
||||
return torch.sum(torch.abs(x_t @ x_one @ self.result - x_t @ y)) <= self.error * torch.sum(torch.abs(x_t @ y))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user