Merge pull request #23 from ZKStats/feat/where-op

Feat/where op
This commit is contained in:
JernKunpittaya
2024-03-08 18:42:11 +07:00
committed by GitHub
53 changed files with 3742 additions and 3042 deletions

View File

@@ -65,26 +65,26 @@ def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
In addition to the ZKStats operations, you can also use PyTorch functions such as `torch.abs`, `torch.max`, etc.
**Caveats**: Not all PyTorch functions are supported. For example, filtering data from a list by `X[X > 0]` is not supported because the zk circuit needs to be of a predetermined size, hence we cannot arbitrarily reshape our X into a new shape based on the filter condition inside the circuit.
TODO: We should have a list for all supported PyTorch functions.
**Caveats**: Not all PyTorch functions are supported. For example, filtering data from a list by `X[X > 0]` is not supported because the zk circuit needs to be of a predetermined size, hence we cannot arbitrarily reshape our X into a new shape based on the filter condition inside the circuit. To filter data based on a condition, use `s.where` as described below.
#### Data Filtering
Since we cannot filter data into any arbitrary shape using just condition + index (e.g. `X[X > 0]`), we need to filter data while still preserving the shape. We use condition + `torch.where` instead.
Although we cannot filter data into an arbitrary shape using just condition + index (e.g. `X[X > 0]`), we have implemented the `State.where` operation, which lets users filter data by a condition of their own choice, as follows.
```python
def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
# Compute the mean of the filtered values
x = data[0]
condition = x > 20
# Keep only the data that is greater than 20. For the data that is not greater than 20, we will use 0.0
fil_X = torch.where(condition=condition, input=x, other=0.0)
return s.mean(fil_X)
# The condition can be chained as shown below, and can involve several variables when we have more than just x, e.g. filter = torch.logical_and(x > 20, y < 2) in the case of regression.
filter = torch.logical_and(x > 20, x<50)
# call our where function
filtered_x = s.where(filter, x)
# Then we can use the stats operations as usual
return s.mean(filtered_x)
```
**Caveats**: Currently, this 'where' operation still doesn't work correctly: we cannot simply plug fil_X into our current s.mean(), because the shapes of fil_X and X are incompatible in reality. We will update this section with a compatible implementation of data filtering soon. Stay tuned!
### Proof Generation and Verification
The flow between data providers and users is as follows:
@@ -204,6 +204,7 @@ See our jupyter notebook for [examples](./examples/).
## Benchmarks
See our jupyter notebook for [benchmarks](./benchmark/).
TODO: clean benchmark
## Note

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,6 @@
{
"x1": [
7.1, 3.2, 8.6, 3.5, 0.1, 9.7, 2.3, 5.7, 2.8, 10.0, 6.0, 6.0, 9.1, 1.7, 9.2,
10.0, 7.1, 3.2, 8.6, 3.5, 0.1, 9.7, 2.3, 5.7, 2.8, 6.0, 6.0, 9.1, 1.7, 9.2,
0.2, 7.8, 3.7, 7.0, 2.5, 2.8, 5.9, 7.3, 2.9, 2.9, 3.5, 1.0, 9.7, 4.8, 0.9,
7.1, 3.6, 8.2, 3.0, 7.6, 4.2, 5.2, 8.1, 6.3, 9.3, 8.8, 8.2, 6.7, 4.9, 5.4,
9.8, 5.9, 7.1, 3.9, 9.3
@@ -12,7 +12,7 @@
1.5, 2.1, 0.4, 4.3, 0.2
],
"y": [
18.5, 5.5, 18.2, 9.0, 4.0, 19.5, 11.7, 17.9, 15.3, 20.8, 12.5, 21.5, 32.5,
20.8, 18.5, 5.5, 18.2, 9.0, 4.0, 19.5, 11.7, 17.9, 15.3, 12.5, 21.5, 32.5,
18.6, 23.9, 7.0, 16.9, 22.9, 31.0, 15.0, 8.5, 8.7, 28.9, 19.7, 12.5, 17.4,
7.2, 25.5, 21.4, 15.7, 15.5, 8.2, 28.2, 19.5, 25.5, 12.5, 20.3, 21.7, 22.1,
19.6, 32.2, 22.4, 20.6, 19.7, 20.8, 21.1, 21.8, 17.7, 21.1, 19.4

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,308 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: ezkl==7.0.0 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 1)) (7.0.0)\n",
"Requirement already satisfied: torch in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 2)) (2.2.0)\n",
"Requirement already satisfied: requests in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 3)) (2.31.0)\n",
"Requirement already satisfied: scipy in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 4)) (1.12.0)\n",
"Requirement already satisfied: numpy in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 5)) (1.26.3)\n",
"Requirement already satisfied: matplotlib in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 6)) (3.8.2)\n",
"Requirement already satisfied: statistics in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 7)) (1.0.3.5)\n",
"Requirement already satisfied: onnx in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 8)) (1.15.0)\n",
"Requirement already satisfied: filelock in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (3.13.1)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (4.9.0)\n",
"Requirement already satisfied: sympy in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (1.12)\n",
"Requirement already satisfied: networkx in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (3.1.3)\n",
"Requirement already satisfied: fsspec in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (2023.12.2)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from requests->-r ../../requirements.txt (line 3)) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from requests->-r ../../requirements.txt (line 3)) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from requests->-r ../../requirements.txt (line 3)) (2.2.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from requests->-r ../../requirements.txt (line 3)) (2024.2.2)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (1.2.0)\n",
"Requirement already satisfied: cycler>=0.10 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (0.12.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (4.47.2)\n",
"Requirement already satisfied: kiwisolver>=1.3.1 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (1.4.5)\n",
"Requirement already satisfied: packaging>=20.0 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (23.2)\n",
"Requirement already satisfied: pillow>=8 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (10.2.0)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (3.1.1)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (2.8.2)\n",
"Requirement already satisfied: docutils>=0.3 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from statistics->-r ../../requirements.txt (line 7)) (0.20.1)\n",
"Requirement already satisfied: protobuf>=3.20.2 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from onnx->-r ../../requirements.txt (line 8)) (4.25.2)\n",
"Requirement already satisfied: six>=1.5 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from python-dateutil>=2.7->matplotlib->-r ../../requirements.txt (line 6)) (1.16.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from jinja2->torch->-r ../../requirements.txt (line 2)) (2.1.4)\n",
"Requirement already satisfied: mpmath>=0.19 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from sympy->torch->-r ../../requirements.txt (line 2)) (1.3.0)\n",
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"pip install -r ../../requirements.txt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"import torch\n",
"from torch import nn\n",
"import json\n",
"import os\n",
"import time\n",
"import scipy\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import statistics\n",
"import math"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"%run -i ../../zkstats/core.py"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# init path\n",
"os.makedirs(os.path.dirname('shared/'), exist_ok=True)\n",
"os.makedirs(os.path.dirname('prover/'), exist_ok=True)\n",
"verifier_model_path = os.path.join('shared/verifier.onnx')\n",
"prover_model_path = os.path.join('prover/prover.onnx')\n",
"verifier_compiled_model_path = os.path.join('shared/verifier.compiled')\n",
"prover_compiled_model_path = os.path.join('prover/prover.compiled')\n",
"pk_path = os.path.join('shared/test.pk')\n",
"vk_path = os.path.join('shared/test.vk')\n",
"proof_path = os.path.join('shared/test.pf')\n",
"settings_path = os.path.join('shared/settings.json')\n",
"srs_path = os.path.join('shared/kzg.srs')\n",
"witness_path = os.path.join('prover/witness.json')\n",
"# this is private to prover since it contains actual data\n",
"sel_data_path = os.path.join('prover/sel_data.json')\n",
"# this is just dummy random value\n",
"sel_dummy_data_path = os.path.join('shared/sel_dummy_data.json')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"======================= ZK-STATS FLOW ======================="
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"data_path = os.path.join('data.json')\n",
"dummy_data_path = os.path.join('shared/dummy_data.json')\n",
"\n",
"create_dummy(data_path, dummy_data_path)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"scales = [2]\n",
"selected_columns = ['col_name']\n",
"commitment_maps = get_data_commitment_maps(data_path, scales)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/mhchia/projects/work/pse/zk-stats-lib/zkstats/computation.py:166: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
" is_precise_aggregated = torch.tensor(1.0)\n",
"/Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages/torch/onnx/symbolic_opset9.py:2174: FutureWarning: 'torch.onnx.symbolic_opset9._cast_Bool' is deprecated in version 2.0 and will be removed in the future. Please Avoid using this function and create a Cast node instead.\n",
" return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False))\n"
]
}
],
"source": [
"# Verifier/ data consumer side: send desired calculation\n",
"from zkstats.computation import computation_to_model, State\n",
"\n",
"\n",
"def computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:\n",
" x = data[0]\n",
" # FIXME: should be replaced by `s.where` when it's available. Now the result may be incorrect\n",
" filter = (x < 50)\n",
" min_x = torch.min(x)\n",
" filtered_x = torch.where(filter, x, min_x - 1)\n",
" return s.median(filtered_x)\n",
"\n",
"error = 0.01\n",
"_, verifier_model = computation_to_model(computation, error)\n",
"\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Theory_output: tensor(40., dtype=torch.float64)\n",
"==== Generate & Calibrate Setting ====\n",
"scale: [2]\n",
"setting: {\"run_args\":{\"tolerance\":{\"val\":0.0,\"scale\":1.0},\"input_scale\":2,\"param_scale\":2,\"scale_rebase_multiplier\":10,\"lookup_range\":[-582,1208],\"logrows\":14,\"num_inner_cols\":2,\"variables\":[[\"batch_size\",1]],\"input_visibility\":{\"Hashed\":{\"hash_is_public\":true,\"outlets\":[]}},\"output_visibility\":\"Public\",\"param_visibility\":\"Private\"},\"num_rows\":14432,\"total_assignments\":15928,\"total_const_size\":2126,\"model_instance_shapes\":[[1],[1]],\"model_output_scales\":[0,2],\"model_input_scales\":[2],\"module_sizes\":{\"kzg\":[],\"poseidon\":[14432,[1]],\"elgamal\":[0,[0]]},\"required_lookups\":[\"Abs\",{\"Div\":{\"denom\":2.0}},\"ReLU\",{\"Floor\":{\"scale\":4.0}},{\"GreaterThan\":{\"a\":0.0}},\"KroneckerDelta\"],\"check_mode\":\"UNSAFE\",\"version\":\"7.0.0\",\"num_blinding_factors\":null}\n"
]
}
],
"source": [
"# Prover/ data owner side\n",
"_, prover_model = computation_to_model(computation, error)\n",
"\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"spawning module 0\n",
"spawning module 2\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"==== setting up ezkl ====\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"spawning module 0\n",
"spawning module 2\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time setup: 4.408194065093994 seconds\n",
"=======================================\n",
"Theory output: tensor(40., dtype=torch.float64)\n",
"==== Generating Witness ====\n",
"witness boolean: 1.0\n",
"witness result 1 : 40.0\n",
"==== Generating Proof ====\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"spawning module 0\n"
]
}
],
"source": [
"# Here verifier & prover can concurrently call setup since all params are public to get pk.\n",
"# Here write as verifier function to emphasize that verifier must calculate its own vk to be sure\n",
"setup(verifier_model_path, verifier_compiled_model_path, settings_path,vk_path, pk_path )\n",
"\n",
"print(\"=======================================\")\n",
"# Prover generates proof\n",
"print(\"Theory output: \", theory_output)\n",
"prover_gen_proof(prover_model_path, sel_data_path, witness_path, prover_compiled_model_path, settings_path, proof_path, pk_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"num_inputs: 1\n",
"prf instances: [[[1780239215148830498, 13236513277824664467, 10913529727158264423, 131860697733488968], [12436184717236109307, 3962172157175319849, 7381016538464732718, 1011752739694698287], [12341676197686541490, 2627393525778350065, 16625494184434727973, 1478518078215075360]]]\n",
"proof boolean: 1.0\n",
"proof result 1 : 40.0\n",
"verified\n"
]
}
],
"source": [
"# Verifier verifies\n",
"res = verifier_verify(proof_path, settings_path, vk_path, selected_columns, commitment_maps)\n",
"print(\"Verifier gets result:\", res)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.1"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
{
"x": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
40, 41, 42, 43, 44, 45, 46, 47, 48, 49
],
"y": [
2.0, 5.2, 47.4, 23.6, 24.8, 27.0, 47.2, 50.4, 58.6, 57.8, 60.0, 27.2, 40.4,
63.6, 28.8, 19.0, 65.2, 50.4, 63.6, 24.8, 35.0, 41.2, 54.4, 61.6, 57.8,
78.0, 63.2, 55.4, 78.6, 72.8, 51.0, 62.2, 42.4, 47.6, 83.8, 62.0, 47.26,
90.4, 80.6, 87.8, 82.0, 50.2, 80.4, 86.6, 80.8, 66.0, 95.2, 58.4, 74.6, 95.8
]
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,23 @@
{
"x": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58,
59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96,
97, 98, 99
],
"y": [
124.09, -37.12, 80.21, 62.13, 0.2, 45.79, 49.2, 105.68, 159.44, 135.97,
141.56, 64.51, 9.56, 116.33, 8.94, 70.32, 79.54, 33.52, 123.45, 79.17,
185.38, 113.5, 32.74, 204.23, 130.27, 76.63, 106.42, 191.19, 202.93, 212.26,
146.64, 161.91, 203.84, 160.09, 124.31, 200.52, 220.17, 117.14, 270.71,
232.62, 152.37, 137.63, 145.5, 290.38, 210.03, 305.25, 161.26, 213.92,
126.23, 166.79, 174.66, 252.53, 229.94, 307.21, 190.17, 250.22, 215.2,
196.0, 195.61, 250.23, 318.48, 235.65, 178.28, 200.54, 293.78, 243.95,
319.72, 255.92, 313.52, 376.36, 304.54, 327.65, 317.21, 413.62, 400.09,
347.71, 333.18, 302.64, 308.08, 430.22, 268.85, 367.58, 402.2, 274.01,
460.27, 442.84, 280.65, 448.93, 345.64, 384.94, 438.06, 457.26, 337.6,
303.74, 345.57, 494.38, 450.31, 383.49, 353.11, 392.61
]
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,6 @@
{
"col_name": [
33.0, 15.0, 38.0, 38.0, 70.0, 44.0, 34.0, 67.0, 54.0, 78.0, 80.0, 21.0,
33.0, 75.0, 38.0, 38.0, 70.0, 44.0, 34.0, 67.0, 54.0, 78.0, 80.0, 21.0,
41.0, 47.0, 57.0, 50.0, 65.0, 43.0, 51.0, 54.0, 62.0, 68.0, 45.0, 39.0,
51.0, 48.0, 48.0, 42.0, 37.0, 75.0, 40.0, 48.0, 65.0, 26.0, 42.0, 53.0,
51.0, 56.0, 74.0, 54.0, 55.0, 15.0, 58.0, 46.0, 64.0, 59.0, 39.0, 36.0,

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,8 @@
{
"col_name": [
46.2, 40.4, 44.8, 48.1, 51.2, 91.9, 38.2, 36.3, 22.2, 11.5, 17.9, 20.2,
99.9, 75.2, 29.8, 19.4, 46.1, 94.8, 6.6, 94.5, 99.7, 1.6, 4.0, 86.7, 28.7,
63.0, 66.7, 2.5, 41.4, 35.6, 45.0, 13.7, 9.6, 16.6, 9.8, 20.3, 25.9, 71.9,
27.5, 30.9, 62.9, 18.6, 45.7, 2.4, 91.4, 16.2, 61.5, 41.4, 77.1, 53.2
]
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,29 @@
{
"col_name": [
33.0, 75.0, 38.0, 38.0, 70.0, 44.0, 34.0, 67.0, 54.0, 78.0, 80.0, 21.0,
41.0, 47.0, 57.0, 50.0, 65.0, 43.0, 51.0, 54.0, 62.0, 68.0, 45.0, 39.0,
51.0, 48.0, 48.0, 42.0, 37.0, 75.0, 40.0, 48.0, 65.0, 26.0, 42.0, 53.0,
51.0, 56.0, 74.0, 54.0, 55.0, 15.0, 58.0, 46.0, 64.0, 59.0, 39.0, 36.0,
62.0, 39.0, 72.0, 32.0, 82.0, 76.0, 88.0, 51.0, 44.0, 35.0, 18.0, 53.0,
52.0, 45.0, 64.0, 31.0, 32.0, 61.0, 66.0, 59.0, 50.0, 69.0, 44.0, 22.0,
45.0, 45.0, 46.0, 42.0, 83.0, 53.0, 53.0, 69.0, 53.0, 33.0, 48.0, 49.0,
34.0, 66.0, 29.0, 66.0, 52.0, 45.0, 83.0, 54.0, 53.0, 31.0, 71.0, 60.0,
30.0, 33.0, 43.0, 26.0, 55.0, 56.0, 56.0, 54.0, 57.0, 68.0, 58.0, 61.0,
62.0, 38.0, 52.0, 74.0, 76.0, 37.0, 42.0, 54.0, 38.0, 38.0, 30.0, 31.0,
52.0, 41.0, 69.0, 40.0, 46.0, 69.0, 29.0, 28.0, 66.0, 41.0, 40.0, 36.0,
52.0, 58.0, 46.0, 42.0, 85.0, 45.0, 70.0, 49.0, 48.0, 34.0, 18.0, 39.0,
64.0, 46.0, 54.0, 42.0, 45.0, 64.0, 46.0, 68.0, 46.0, 54.0, 47.0, 41.0,
69.0, 27.0, 61.0, 37.0, 25.0, 66.0, 30.0, 59.0, 67.0, 34.0, 36.0, 40.0,
55.0, 58.0, 74.0, 55.0, 66.0, 55.0, 72.0, 40.0, 27.0, 38.0, 74.0, 52.0,
45.0, 40.0, 35.0, 46.0, 64.0, 41.0, 50.0, 45.0, 42.0, 22.0, 25.0, 55.0,
39.0, 58.0, 56.0, 62.0, 55.0, 65.0, 57.0, 34.0, 44.0, 47.0, 70.0, 60.0,
34.0, 50.0, 43.0, 60.0, 66.0, 46.0, 58.0, 76.0, 40.0, 49.0, 64.0, 45.0,
22.0, 50.0, 34.0, 44.0, 76.0, 63.0, 59.0, 36.0, 59.0, 47.0, 70.0, 64.0,
44.0, 55.0, 50.0, 48.0, 66.0, 40.0, 76.0, 48.0, 75.0, 73.0, 55.0, 41.0,
43.0, 50.0, 34.0, 57.0, 50.0, 53.0, 28.0, 35.0, 52.0, 52.0, 49.0, 67.0,
41.0, 41.0, 61.0, 24.0, 43.0, 51.0, 40.0, 52.0, 44.0, 25.0, 81.0, 54.0,
64.0, 76.0, 37.0, 45.0, 48.0, 46.0, 43.0, 67.0, 28.0, 35.0, 25.0, 71.0,
50.0, 31.0, 43.0, 54.0, 40.0, 51.0, 40.0, 49.0, 34.0, 26.0, 46.0, 62.0,
40.0, 25.0, 61.0, 58.0, 56.0, 39.0, 46.0, 53.0, 21.0, 57.0, 42.0, 80.0
]
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,29 @@
{
"col_name": [
33.0, 75.0, 38.0, 38.0, 70.0, 44.0, 34.0, 67.0, 54.0, 78.0, 80.0, 21.0,
41.0, 47.0, 57.0, 50.0, 65.0, 43.0, 51.0, 54.0, 62.0, 68.0, 45.0, 39.0,
51.0, 48.0, 48.0, 42.0, 37.0, 75.0, 40.0, 48.0, 65.0, 26.0, 42.0, 53.0,
51.0, 56.0, 74.0, 54.0, 55.0, 15.0, 58.0, 46.0, 64.0, 59.0, 39.0, 36.0,
62.0, 39.0, 72.0, 32.0, 82.0, 76.0, 88.0, 51.0, 44.0, 35.0, 18.0, 53.0,
52.0, 45.0, 64.0, 31.0, 32.0, 61.0, 66.0, 59.0, 50.0, 69.0, 44.0, 22.0,
45.0, 45.0, 46.0, 42.0, 83.0, 53.0, 53.0, 69.0, 53.0, 33.0, 48.0, 49.0,
34.0, 66.0, 29.0, 66.0, 52.0, 45.0, 83.0, 54.0, 53.0, 31.0, 71.0, 60.0,
30.0, 33.0, 43.0, 26.0, 55.0, 56.0, 56.0, 54.0, 57.0, 68.0, 58.0, 61.0,
62.0, 38.0, 52.0, 74.0, 76.0, 37.0, 42.0, 54.0, 38.0, 38.0, 30.0, 31.0,
52.0, 41.0, 69.0, 40.0, 46.0, 69.0, 29.0, 28.0, 66.0, 41.0, 40.0, 36.0,
52.0, 58.0, 46.0, 42.0, 85.0, 45.0, 70.0, 49.0, 48.0, 34.0, 18.0, 39.0,
64.0, 46.0, 54.0, 42.0, 45.0, 64.0, 46.0, 68.0, 46.0, 54.0, 47.0, 41.0,
69.0, 27.0, 61.0, 37.0, 25.0, 66.0, 30.0, 59.0, 67.0, 34.0, 36.0, 40.0,
55.0, 58.0, 74.0, 55.0, 66.0, 55.0, 72.0, 40.0, 27.0, 38.0, 74.0, 52.0,
45.0, 40.0, 35.0, 46.0, 64.0, 41.0, 50.0, 45.0, 42.0, 22.0, 25.0, 55.0,
39.0, 58.0, 56.0, 62.0, 55.0, 65.0, 57.0, 34.0, 44.0, 47.0, 70.0, 60.0,
34.0, 50.0, 43.0, 60.0, 66.0, 46.0, 58.0, 76.0, 40.0, 49.0, 64.0, 45.0,
22.0, 50.0, 34.0, 44.0, 76.0, 63.0, 59.0, 36.0, 59.0, 47.0, 70.0, 64.0,
44.0, 55.0, 50.0, 48.0, 66.0, 40.0, 76.0, 48.0, 75.0, 73.0, 55.0, 41.0,
43.0, 50.0, 34.0, 57.0, 50.0, 53.0, 28.0, 35.0, 52.0, 52.0, 49.0, 67.0,
41.0, 41.0, 61.0, 24.0, 43.0, 51.0, 40.0, 52.0, 44.0, 25.0, 81.0, 54.0,
64.0, 76.0, 37.0, 45.0, 48.0, 46.0, 43.0, 67.0, 28.0, 35.0, 25.0, 71.0,
50.0, 31.0, 43.0, 54.0, 40.0, 51.0, 40.0, 49.0, 34.0, 26.0, 46.0, 62.0,
40.0, 25.0, 61.0, 58.0, 56.0, 39.0, 46.0, 53.0, 21.0, 57.0, 42.0, 80.0
]
}

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,6 @@
{
"x1": [
7.1, 3.2, 8.6, 3.5, 0.1, 9.7, 2.3, 5.7, 2.8, 10.0, 6.0, 6.0, 9.1, 1.7, 9.2,
10.0, 7.1, 3.2, 8.6, 3.5, 0.1, 9.7, 2.3, 5.7, 2.8, 6.0, 6.0, 9.1, 1.7, 9.2,
0.2, 7.8, 3.7, 7.0, 2.5, 2.8, 5.9, 7.3, 2.9, 2.9, 3.5, 1.0, 9.7, 4.8, 0.9,
7.1, 3.6, 8.2, 3.0, 7.6, 4.2, 5.2, 8.1, 6.3, 9.3, 8.8, 8.2, 6.7, 4.9, 5.4,
9.8, 5.9, 7.1, 3.9, 9.3
@@ -12,7 +12,7 @@
1.5, 2.1, 0.4, 4.3, 0.2
],
"y": [
18.5, 5.5, 18.2, 9.0, 4.0, 19.5, 11.7, 17.9, 15.3, 20.8, 12.5, 21.5, 32.5,
20.8, 18.5, 5.5, 18.2, 9.0, 4.0, 19.5, 11.7, 17.9, 15.3, 12.5, 21.5, 32.5,
18.6, 23.9, 7.0, 16.9, 22.9, 31.0, 15.0, 8.5, 8.7, 28.9, 19.7, 12.5, 17.4,
7.2, 25.5, 21.4, 15.7, 15.5, 8.2, 28.2, 19.5, 25.5, 12.5, 20.3, 21.7, 22.1,
19.6, 32.2, 22.4, 20.6, 19.7, 20.8, 21.1, 21.8, 17.7, 21.1, 19.4

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,29 @@
{
"col_name": [
33.0, 75.0, 38.0, 38.0, 70.0, 44.0, 34.0, 67.0, 54.0, 78.0, 80.0, 21.0,
41.0, 47.0, 57.0, 50.0, 65.0, 43.0, 51.0, 54.0, 62.0, 68.0, 45.0, 39.0,
51.0, 48.0, 48.0, 42.0, 37.0, 75.0, 40.0, 48.0, 65.0, 26.0, 42.0, 53.0,
51.0, 56.0, 74.0, 54.0, 55.0, 15.0, 58.0, 46.0, 64.0, 59.0, 39.0, 36.0,
62.0, 39.0, 72.0, 32.0, 82.0, 76.0, 88.0, 51.0, 44.0, 35.0, 18.0, 53.0,
52.0, 45.0, 64.0, 31.0, 32.0, 61.0, 66.0, 59.0, 50.0, 69.0, 44.0, 22.0,
45.0, 45.0, 46.0, 42.0, 83.0, 53.0, 53.0, 69.0, 53.0, 33.0, 48.0, 49.0,
34.0, 66.0, 29.0, 66.0, 52.0, 45.0, 83.0, 54.0, 53.0, 31.0, 71.0, 60.0,
30.0, 33.0, 43.0, 26.0, 55.0, 56.0, 56.0, 54.0, 57.0, 68.0, 58.0, 61.0,
62.0, 38.0, 52.0, 74.0, 76.0, 37.0, 42.0, 54.0, 38.0, 38.0, 30.0, 31.0,
52.0, 41.0, 69.0, 40.0, 46.0, 69.0, 29.0, 28.0, 66.0, 41.0, 40.0, 36.0,
52.0, 58.0, 46.0, 42.0, 85.0, 45.0, 70.0, 49.0, 48.0, 34.0, 18.0, 39.0,
64.0, 46.0, 54.0, 42.0, 45.0, 64.0, 46.0, 68.0, 46.0, 54.0, 47.0, 41.0,
69.0, 27.0, 61.0, 37.0, 25.0, 66.0, 30.0, 59.0, 67.0, 34.0, 36.0, 40.0,
55.0, 58.0, 74.0, 55.0, 66.0, 55.0, 72.0, 40.0, 27.0, 38.0, 74.0, 52.0,
45.0, 40.0, 35.0, 46.0, 64.0, 41.0, 50.0, 45.0, 42.0, 22.0, 25.0, 55.0,
39.0, 58.0, 56.0, 62.0, 55.0, 65.0, 57.0, 34.0, 44.0, 47.0, 70.0, 60.0,
34.0, 50.0, 43.0, 60.0, 66.0, 46.0, 58.0, 76.0, 40.0, 49.0, 64.0, 45.0,
22.0, 50.0, 34.0, 44.0, 76.0, 63.0, 59.0, 36.0, 59.0, 47.0, 70.0, 64.0,
44.0, 55.0, 50.0, 48.0, 66.0, 40.0, 76.0, 48.0, 75.0, 73.0, 55.0, 41.0,
43.0, 50.0, 34.0, 57.0, 50.0, 53.0, 28.0, 35.0, 52.0, 52.0, 49.0, 67.0,
41.0, 41.0, 61.0, 24.0, 43.0, 51.0, 40.0, 52.0, 44.0, 25.0, 81.0, 54.0,
64.0, 76.0, 37.0, 45.0, 48.0, 46.0, 43.0, 67.0, 28.0, 35.0, 25.0, 71.0,
50.0, 31.0, 43.0, 54.0, 40.0, 51.0, 40.0, 49.0, 34.0, 26.0, 46.0, 62.0,
40.0, 25.0, 61.0, 58.0, 56.0, 39.0, 46.0, 53.0, 21.0, 57.0, 42.0, 80.0
]
}

File diff suppressed because one or more lines are too long

View File

@@ -24,4 +24,4 @@ def column_2():
@pytest.fixture
def scales():
return [3]
return [6]

View File

@@ -10,6 +10,11 @@ from zkstats.computation import IModel
DEFAULT_POSSIBLE_SCALES = list(range(20))
# Error tolerance between circuit and python implementation
ERROR_CIRCUIT_DEFAULT = 0.01
ERROR_CIRCUIT_STRICT = 0.0001
ERROR_CIRCUIT_RELAXED = 0.1
def data_to_file(data_path: Path, data: list[torch.Tensor]) -> dict[str, list]:
column_names = [f"columns_{i}" for i in range(len(data))]

View File

@@ -1,3 +1,4 @@
from typing import Type, Callable
import statistics
import torch
@@ -17,9 +18,10 @@ from zkstats.ops import (
Covariance,
Correlation,
Regression,
Operation
)
from .helpers import assert_result, compute
from .helpers import assert_result, compute, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED
def nested_computation(state: State, args: list[torch.Tensor]):
@@ -129,3 +131,67 @@ def test_nested_computation(tmp_path, column_0: torch.Tensor, column_1: torch.Te
assert isinstance(op_11, Mean)
out_11 = statistics.mean([out_0, out_1, out_2, out_3, out_4, out_5, out_6, out_7, out_8, out_9, out_10.slope, out_10.intercept])
assert_result(torch.tensor(out_11), op_11.result)
@pytest.mark.parametrize(
    "op_type, expected_func, error",
    [
        (State.mean, statistics.mean, ERROR_CIRCUIT_DEFAULT),
        (State.median, statistics.median, ERROR_CIRCUIT_DEFAULT),
        (State.geometric_mean, statistics.geometric_mean, ERROR_CIRCUIT_DEFAULT),
        # Be more tolerant for HarmonicMean
        (State.harmonic_mean, statistics.harmonic_mean, ERROR_CIRCUIT_RELAXED),
        # Be less tolerant for Mode
        (State.mode, statistics.mode, ERROR_CIRCUIT_STRICT),
        (State.pstdev, statistics.pstdev, ERROR_CIRCUIT_DEFAULT),
        (State.pvariance, statistics.pvariance, ERROR_CIRCUIT_DEFAULT),
        (State.stdev, statistics.stdev, ERROR_CIRCUIT_DEFAULT),
        (State.variance, statistics.variance, ERROR_CIRCUIT_DEFAULT),
    ],
)
def test_computation_with_where_1d(
    tmp_path,
    error,
    column_0,
    op_type: Callable[[State, torch.Tensor], torch.Tensor],
    expected_func: Callable[[list[float]], float],
    scales,
):
    """
    Check a single-column `State` operation applied to the output of `State.where`.

    The operation result is compared, via `assert_result`, against the matching
    function from the Python `statistics` module evaluated on only those
    elements of `column_0` that satisfy the condition.

    :param tmp_path: temporary directory passed through to `compute`
    :param error: error value passed to `computation_to_model`
    :param column_0: the data column under test
    :param op_type: unbound `State` method, called as `op_type(state, tensor)`
    :param expected_func: reference implementation taking a list of floats
    :param scales: scales passed through to `compute`
    """
    column = column_0

    def condition(_x: torch.Tensor):
        # Same predicate is used inside the computation and for the reference value.
        return _x < 4

    def where_and_op(state: State, args: list[torch.Tensor]):
        x = args[0]
        return op_type(state, state.where(condition(x), x))

    state, model = computation_to_model(where_and_op, error)
    compute(tmp_path, [column], model, scales)

    # The last entry in `state.ops` is expected to be the operation under test
    # (it is the final State call made in `where_and_op`).
    res_op = state.ops[-1]

    # Reference value: plain boolean indexing, then the `statistics` function.
    filtered = column[condition(column)]
    expected_res = expected_func(filtered.tolist())
    assert_result(res_op.result.data, expected_res)
@pytest.mark.parametrize(
    "op_type, expected_func, error",
    [
        (State.covariance, statistics.covariance, ERROR_CIRCUIT_RELAXED),
        (State.correlation, statistics.correlation, ERROR_CIRCUIT_RELAXED),
    ],
)
def test_computation_with_where_2d(
    tmp_path,
    error,
    column_0,
    column_1,
    op_type: Callable[[State, torch.Tensor, torch.Tensor], torch.Tensor],
    expected_func: Callable[[list[float], list[float]], float],
    scales,
):
    """
    Check a two-column `State` operation applied to the output of `State.where`.

    Both columns are filtered with the same condition, which is computed from
    the first column only. The operation result is compared, via
    `assert_result`, against the matching function from the Python
    `statistics` module evaluated on the boolean-indexed columns.

    :param tmp_path: temporary directory passed through to `compute`
    :param error: error value passed to `computation_to_model`
    :param column_0: first data column; also the source of the filter condition
    :param column_1: second data column
    :param op_type: unbound `State` method, called as `op_type(state, x, y)`
    :param expected_func: reference implementation taking two lists of floats
    :param scales: scales passed through to `compute`
    """
    def condition_0(_x: torch.Tensor):
        # Same predicate is used inside the computation and for the reference value.
        return _x > 4

    def where_and_op(state: State, args: list[torch.Tensor]):
        x = args[0]
        y = args[1]
        # One mask, derived from x, is applied to both columns.
        condition_x = condition_0(x)
        filtered_x = state.where(condition_x, x)
        filtered_y = state.where(condition_x, y)
        return op_type(state, filtered_x, filtered_y)

    state, model = computation_to_model(where_and_op, error)
    compute(tmp_path, [column_0, column_1], model, scales)

    # The last entry in `state.ops` is expected to be the operation under test
    # (it is the final State call made in `where_and_op`).
    res_op = state.ops[-1]

    # Reference value: plain boolean indexing, then the `statistics` function.
    condition_x = condition_0(column_0)
    filtered_x = column_0[condition_x]
    filtered_y = column_1[condition_x]
    expected_res = expected_func(filtered_x.tolist(), filtered_y.tolist())
    assert_result(res_op.result.data, expected_res)

View File

@@ -7,13 +7,7 @@ import torch
from zkstats.ops import Mean, Median, GeometricMean, HarmonicMean, Mode, PStdev, PVariance, Stdev, Variance, Covariance, Correlation, Operation, Regression
from zkstats.computation import IModel, IsResultPrecise
from .helpers import compute, assert_result
# Error tolerance between circuit and python implementation
ERROR_CIRCUIT_DEFAULT = 0.01
ERROR_CIRCUIT_STRICT = 0.0001
ERROR_CIRCUIT_RELAXED = 0.1
from .helpers import compute, assert_result, ERROR_CIRCUIT_DEFAULT, ERROR_CIRCUIT_STRICT, ERROR_CIRCUIT_RELAXED
@pytest.mark.parametrize(

View File

@@ -18,6 +18,7 @@ from .ops import (
Covariance,
Correlation,
Regression,
Where,
IsResultPrecise,
)
@@ -128,8 +129,20 @@ class State:
Calculate the linear regression of x and y. The behavior should conform to
[statistics.linear_regression](https://docs.python.org/3/library/statistics.html#statistics.linear_regression) in Python standard library.
"""
# hence support only one x for now
return self._call_op([x, y], Regression)
# WHERE operation
def where(self, filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Filter `x` with a boolean mask by delegating to the `Where` operation.
    The behavior should conform to `torch.where` in PyTorch.

    :param filter: A boolean tensor that serves as the filter
    :param x: A tensor to be filtered
    :return: the filtered tensor
    """
    # The operation expects the mask first and the data second.
    operands = [filter, x]
    return self._call_op(operands, Where)
def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
if self.current_op_index is None:
op = op_type.create(x, self.error)
@@ -212,3 +225,4 @@ def computation_to_model(computation: TComputation, error: float = DEFAULT_ERROR
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
return computation(state, x)
return state, Model

View File

@@ -45,7 +45,9 @@ def create_dummy(data_path: str, dummy_data_path: str) -> None:
dummy_data ={}
for col in data:
# not use same value for every column to prevent something weird, like singular matrix
dummy_data[col] = np.round(np.random.uniform(1,30,len(data[col])),1).tolist()
min_col = min(data[col])
max_col = max(data[col])
dummy_data[col] = np.round(np.random.uniform(min_col,max_col,len(data[col])),1).tolist()
json.dump(dummy_data, open(dummy_data_path, 'w'))
@@ -250,12 +252,10 @@ def verifier_verify(proof_path: str, settings_path: str, vk_path: str, selected_
# - is a tuple (is_in_error, result)
# - is_valid is True
# Sanity check
# is_in_error = ezkl.vecu64_to_float(outputs[0], output_scales[0])
is_in_error = ezkl.felt_to_float(outputs[0], output_scales[0])
assert is_in_error == 1.0, f"result is not within error"
result_arr = []
for index in range(len(outputs)-1):
# result_arr.append(ezkl.vecu64_to_float(outputs[index+1], output_scales[1]))
result_arr.append(ezkl.felt_to_float(outputs[index+1], output_scales[1]))
return result_arr

View File

@@ -6,6 +6,7 @@ import torch
# boolean: either 1.0 or 0.0
IsResultPrecise = torch.Tensor
# Sentinel written by `Where` into every position rejected by its filter.
# The operations in this module exclude or neutralise entries equal to it,
# so a genuine data value equal to this number is treated as filtered out.
MagicNumber = 9999999
class Operation(ABC):
@@ -22,15 +23,28 @@ class Operation(ABC):
...
class Where(Operation):
    """
    Shape-preserving filter: positions whose mask is False are overwritten with
    `MagicNumber`, all other positions keep the value of the data tensor.
    """

    @classmethod
    def create(cls, x: list[torch.Tensor], error: float) -> 'Where':
        # `error` plays no role here; it is accepted to conform to other operations.
        mask, data = x[0], x[1]
        return cls(torch.where(mask, data, MagicNumber), error)

    def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
        """
        Accept when every position of the stored result either equals the data,
        or is `MagicNumber` at a position where the mask is False.
        """
        mask, data = x[0], x[1]
        unchanged = data == self.result
        masked_out = torch.logical_and(torch.logical_not(mask), self.result == MagicNumber)
        accepted = torch.logical_or(unchanged, masked_out)
        # NOTE(review): a position with a False mask is also accepted when the
        # result still equals the data there -- confirm this leniency is intended.
        # Assumes data is [1, n, 1], so dimension 1 holds the element count.
        return torch.sum(accepted.float()) == data.size()[1]
class Mean(Operation):
    """Arithmetic mean over the entries that are not `MagicNumber`."""

    @classmethod
    def create(cls, x: list[torch.Tensor], error: float) -> 'Mean':
        # support where statement, hopefully we can use 'nan' once onnx.isnan() is supported
        data = x[0]
        return cls(torch.mean(data[data != MagicNumber]), error)

    def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
        """
        Check `|sum(kept) - n * result| <= |error * result * n|`, where `n` counts
        the entries different from `MagicNumber`.
        """
        x = x[0]
        # Number of entries that survived filtering.
        size = torch.sum((x != MagicNumber).float())
        # Filtered entries contribute nothing to the sum.
        x = torch.where(x == MagicNumber, 0.0, x)
        return torch.abs(torch.sum(x) - size*self.result) <= torch.abs(self.error*self.result*size)
def to_1d(x: torch.Tensor) -> torch.Tensor:
@@ -51,6 +65,7 @@ class Median(Operation):
# we want in our context. However, we tend to have x as a `[1, len(x), 1]`. In this case,
# we need to flatten `x` to 1d array to get the correct `lower` and `upper`.
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
super().__init__(torch.tensor(np.median(x_1d)), error)
sorted_x = np.sort(x_1d)
len_x = len(x_1d)
@@ -63,21 +78,25 @@ class Median(Operation):
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
x = x[0]
old_size = x.size()[1]
size = torch.sum((x!=MagicNumber).float())
min_x = torch.min(x)
x = torch.where(x==MagicNumber,min_x-1, x)
# since within 1%, we regard as same value
count_less = torch.sum((x < self.result).float())
count_less = torch.sum((x < self.result).float())-(old_size-size)
count_equal = torch.sum((x==self.result).float())
len = x.size()[1]
half_len = torch.floor(torch.div(len, 2))
less_cons = count_less<half_len+len%2
more_cons = count_less+count_equal>half_len
half_size = torch.floor(torch.div(size, 2))
less_cons = count_less<half_size+size%2
more_cons = count_less+count_equal>half_size
# For count_equal == 0
lower_exist = torch.sum((x==self.lower).float())>0
lower_cons = torch.sum((x>self.lower).float())==half_len
lower_cons = torch.sum((x>self.lower).float())==half_size
upper_exist = torch.sum((x==self.upper).float())>0
upper_cons = torch.sum((x<self.upper).float())==half_len
bound = count_less== half_len
upper_cons = torch.sum((x<self.upper).float())==half_size
bound = count_less== half_size
# 0.02 since 2*0.01
bound_avg = (torch.abs(self.lower+self.upper-2*self.result)<=torch.abs(2*self.error*self.result))
@@ -85,17 +104,20 @@ class Median(Operation):
median_out_cons = torch.logical_and(torch.logical_and(bound, bound_avg), torch.logical_and(torch.logical_and(lower_cons, upper_cons), torch.logical_and(lower_exist, upper_exist)))
return torch.where(count_equal==0, median_out_cons, median_in_cons)
class GeometricMean(Operation):
    """Geometric mean over the entries that are not `MagicNumber`."""

    @classmethod
    def create(cls, x: list[torch.Tensor], error: float) -> 'GeometricMean':
        x_1d = to_1d(x[0])
        # Drop entries that were filtered out upstream.
        x_1d = x_1d[x_1d != MagicNumber]
        result = torch.exp(torch.mean(torch.log(x_1d)))
        return cls(result, error)

    def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
        """
        Check `|n * log(result) - sum(log(kept))| <= n * log(1 + error)`, where `n`
        counts the entries different from `MagicNumber`.
        """
        # Assume x is [1, n, 1]
        x = x[0]
        size = torch.sum((x != MagicNumber).float())
        # log(1.0) == 0, so filtered entries add nothing to the log-sum.
        x = torch.where(x == MagicNumber, 1.0, x)
        return torch.abs((torch.log(self.result)*size) - torch.sum(torch.log(x))) <= size*torch.log(torch.tensor(1+self.error))
@@ -103,13 +125,16 @@ class HarmonicMean(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float) -> 'HarmonicMean':
    """Build the operation from the harmonic mean of the entries that are not `MagicNumber`."""
    flat = to_1d(x[0])
    # Drop entries that were filtered out upstream.
    kept = flat[flat != MagicNumber]
    mean_of_reciprocals = torch.mean(torch.div(1.0, kept))
    return cls(torch.div(1.0, mean_of_reciprocals), error)
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
    """
    Check `|result * sum(1 / kept) - n| <= |error * n|`, where `n` counts the
    entries different from `MagicNumber`.
    """
    # Assume x is [1, n, 1]
    x = x[0]
    size = torch.sum((x != MagicNumber).float())
    # just make it really big so that 1/x goes to zero for element that gets filtered out
    x = torch.where(x == MagicNumber, x*x, x)
    return torch.abs((self.result*torch.sum(torch.div(1.0, x))) - size) <= torch.abs(self.error*size)
@@ -131,15 +156,16 @@ def mode_within(data_array: torch.Tensor, error: float) -> torch.Tensor:
max_sum_freq = sum_freq
return mode
# TODO: Add mode_within, different from traditional mode
# TODO: Add class Mode_within , different from traditional mode
# class Mode_(Operation):
# @classmethod
# def create(cls, x: list[torch.Tensor], error: float) -> 'Mode':
# x_1d = to_1d(x[0])
# # Mode has no result_error, just num_error which is the
# # deviation that two numbers are considered the same. (Make sense because
# # Mode has no result_error, just num_error which is the
# # deviation that two numbers are considered the same. (Make sense because
# # if some dataset has all different data, mode will be trivial if this is not the case)
# # This value doesn't depend on any scale, but on the dataset itself, and the intention
# # This value doesn't depend on any scale, but on the dataset itself, and the intention
# # the evaluator. For example 0.01 means that data is counted as the same within 1% value range.
# # If wanting the strict definition of Mode, can just put this error to be 0
@@ -156,13 +182,13 @@ def mode_within(data_array: torch.Tensor, error: float) -> torch.Tensor:
# for ele in x[0]
# ], dtype = torch.float32)
# return torch.sum(_result) == size
class Mode(Operation):
@classmethod
def create(cls, x: list[torch.Tensor], error: float) -> 'Mode':
    """Build the operation from the mode of the entries that are not `MagicNumber`."""
    flat = to_1d(x[0])
    # Drop entries that were filtered out upstream.
    kept = flat[flat != MagicNumber]
    # Here is traditional definition of Mode, can just put this num_error to be 0
    mode_value = mode_within(kept, 0)
    return cls(torch.tensor(mode_value), error)
@@ -170,18 +196,18 @@ class Mode(Operation):
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
    """
    Accept when no surviving value occurs more often than the stored result.
    Entries equal to `MagicNumber` are treated as filtered out.
    """
    # Assume x is [1, n, 1]
    x = x[0]
    # Taken before the replacement below; used to build a placeholder value.
    min_x = torch.min(x)
    old_size = x.size()[1]
    # Filtered entries become `min_x - 1`.
    x = torch.where(x == MagicNumber, min_x-1, x)
    count_equal = torch.sum((x == self.result).float())
    # An element passes if it is a placeholder, or if it is not more frequent
    # than the claimed mode.
    result = torch.tensor([
        torch.logical_or(torch.sum((x == ele[0]).float()) <= count_equal, min_x-1 == ele[0])
        for ele in x[0]
    ])
    # Every position, filtered or not, must pass.
    return torch.sum(result) == old_size
class PStdev(Operation):
def __init__(self, x: torch.Tensor, error: float):
    """Store the mean and the population standard deviation of the entries that are not `MagicNumber`."""
    flat = to_1d(x)
    # Drop entries that were filtered out upstream.
    kept = flat[flat != MagicNumber]
    self.data_mean = torch.nn.Parameter(data=torch.mean(kept), requires_grad=False)
    # correction = 0 -> population variance
    population_var = torch.var(kept, correction = 0)
    super().__init__(torch.sqrt(population_var), error)
@@ -192,16 +218,19 @@ class PStdev(Operation):
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
    """
    Check both the stored mean and the stored population standard deviation
    against the entries that are not `MagicNumber`.
    """
    x = x[0]
    # Filtered entries contribute 0 to the sum used for the mean check.
    x_fil_0 = torch.where(x == MagicNumber, 0.0, x)
    # Number of entries that survived filtering.
    size = torch.sum((x != MagicNumber).float())
    x_mean_cons = torch.abs(torch.sum(x_fil_0) - size*(self.data_mean)) <= torch.abs(self.error*self.data_mean*size)
    # Filtered entries are set to the mean so their deviation is 0.
    x_fil_mean = torch.where(x == MagicNumber, self.data_mean, x)
    squared_dev = torch.sum((x_fil_mean - self.data_mean)*(x_fil_mean - self.data_mean))
    # Tolerance is 2 * error, presumably because the result is squared here.
    return torch.logical_and(
        torch.abs(squared_dev - self.result*self.result*size) <= torch.abs(2*self.error*self.result*self.result*size),
        x_mean_cons
    )
class PVariance(Operation):
def __init__(self, x: torch.Tensor, error: float):
    """Store the mean and the population variance of the entries that are not `MagicNumber`."""
    flat = to_1d(x)
    # Drop entries that were filtered out upstream.
    kept = flat[flat != MagicNumber]
    self.data_mean = torch.nn.Parameter(data=torch.mean(kept), requires_grad=False)
    # correction = 0 -> population variance
    population_var = torch.var(kept, correction = 0)
    super().__init__(population_var, error)
@@ -212,16 +241,20 @@ class PVariance(Operation):
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
    """
    Check both the stored mean and the stored population variance against the
    entries that are not `MagicNumber`.
    """
    x = x[0]
    # Filtered entries contribute 0 to the sum used for the mean check.
    x_fil_0 = torch.where(x == MagicNumber, 0.0, x)
    # Number of entries that survived filtering.
    size = torch.sum((x != MagicNumber).float())
    x_mean_cons = torch.abs(torch.sum(x_fil_0) - size*(self.data_mean)) <= torch.abs(self.error*self.data_mean*size)
    # Filtered entries are set to the mean so their deviation is 0.
    x_fil_mean = torch.where(x == MagicNumber, self.data_mean, x)
    squared_dev = torch.sum((x_fil_mean - self.data_mean)*(x_fil_mean - self.data_mean))
    return torch.logical_and(
        torch.abs(squared_dev - self.result*size) <= torch.abs(self.error*self.result*size),
        x_mean_cons
    )
class Stdev(Operation):
def __init__(self, x: torch.Tensor, error: float):
    """Store the mean and the sample standard deviation of the entries that are not `MagicNumber`."""
    flat = to_1d(x)
    # Drop entries that were filtered out upstream.
    kept = flat[flat != MagicNumber]
    self.data_mean = torch.nn.Parameter(data=torch.mean(kept), requires_grad=False)
    # correction = 1 -> sample variance
    sample_var = torch.var(kept, correction = 1)
    super().__init__(torch.sqrt(sample_var), error)
@@ -232,16 +265,19 @@ class Stdev(Operation):
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
    """
    Check both the stored mean and the stored sample standard deviation against
    the entries that are not `MagicNumber`.
    """
    x = x[0]
    # Filtered entries contribute 0 to the sum used for the mean check.
    x_fil_0 = torch.where(x == MagicNumber, 0.0, x)
    # Number of entries that survived filtering.
    size = torch.sum((x != MagicNumber).float())
    x_mean_cons = torch.abs(torch.sum(x_fil_0) - size*(self.data_mean)) <= torch.abs(self.error*self.data_mean*size)
    # Filtered entries are set to the mean so their deviation is 0.
    x_fil_mean = torch.where(x == MagicNumber, self.data_mean, x)
    squared_dev = torch.sum((x_fil_mean - self.data_mean)*(x_fil_mean - self.data_mean))
    # Sample statistic: the divisor is size - 1.
    # Tolerance is 2 * error, presumably because the result is squared here.
    return torch.logical_and(
        torch.abs(squared_dev - self.result*self.result*(size - 1)) <= torch.abs(2*self.error*self.result*self.result*(size - 1)),
        x_mean_cons
    )
class Variance(Operation):
def __init__(self, x: torch.Tensor, error: float):
    """Store the mean and the sample variance of the entries that are not `MagicNumber`."""
    flat = to_1d(x)
    # Drop entries that were filtered out upstream.
    kept = flat[flat != MagicNumber]
    self.data_mean = torch.nn.Parameter(data=torch.mean(kept), requires_grad=False)
    # correction = 1 -> sample variance
    sample_var = torch.var(kept, correction = 1)
    super().__init__(sample_var, error)
@@ -252,17 +288,23 @@ class Variance(Operation):
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
    """
    Check both the stored mean and the stored sample variance against the
    entries that are not `MagicNumber`.
    """
    x = x[0]
    # Filtered entries contribute 0 to the sum used for the mean check.
    x_fil_0 = torch.where(x == MagicNumber, 0.0, x)
    # Number of entries that survived filtering.
    size = torch.sum((x != MagicNumber).float())
    x_mean_cons = torch.abs(torch.sum(x_fil_0) - size*(self.data_mean)) <= torch.abs(self.error*self.data_mean*size)
    # Filtered entries are set to the mean so their deviation is 0.
    x_fil_mean = torch.where(x == MagicNumber, self.data_mean, x)
    squared_dev = torch.sum((x_fil_mean - self.data_mean)*(x_fil_mean - self.data_mean))
    # Sample statistic: the divisor is size - 1.
    return torch.logical_and(
        torch.abs(squared_dev - self.result*(size - 1)) <= torch.abs(self.error*self.result*(size - 1)),
        x_mean_cons
    )
class Covariance(Operation):
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float):
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
y_1d = to_1d(y)
y_1d = y_1d[y_1d!=MagicNumber]
x_1d_list = x_1d.tolist()
y_1d_list = y_1d.tolist()
@@ -278,34 +320,37 @@ class Covariance(Operation):
def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
    """
    Check the stored means and the stored sample covariance against the entries
    of `x` and `y` that are not `MagicNumber`.
    """
    x, y = args[0], args[1]
    # Filtered entries contribute 0 to the sums used for the mean checks.
    x_fil_0 = torch.where(x == MagicNumber, 0.0, x)
    y_fil_0 = torch.where(y == MagicNumber, 0.0, y)
    # Number of entries that survived filtering in each column.
    size_x = torch.sum((x != MagicNumber).float())
    size_y = torch.sum((y != MagicNumber).float())
    x_mean_cons = torch.abs(torch.sum(x_fil_0) - size_x*(self.x_mean)) <= torch.abs(self.error*self.x_mean*size_x)
    y_mean_cons = torch.abs(torch.sum(y_fil_0) - size_y*(self.y_mean)) <= torch.abs(self.error*self.y_mean*size_y)
    x_fil_mean = torch.where(x == MagicNumber, self.x_mean, x)
    # only x_fil_mean is enough, no need for y_fil_mean since it will multiply 0 anyway
    # NOTE(review): this relies on x and y being filtered at the same positions;
    # only the surviving counts are compared below -- confirm with callers.
    cross_dev = torch.sum((x_fil_mean - self.x_mean)*(y - self.y_mean))
    return torch.logical_and(
        torch.logical_and(size_x == size_y, torch.logical_and(x_mean_cons, y_mean_cons)),
        # abs() on the tolerance so that a negative covariance can be accepted too
        torch.abs(cross_dev - (size_x-1)*self.result) < torch.abs(self.error*self.result*(size_x-1))
    )
def stdev(x: torch.Tensor, x_std: torch.Tensor, x_mean: torch.Tensor, error: float) -> torch.Tensor:
    """
    Check a claimed mean and sample standard deviation of `x`.

    Returns a pair: the boolean verdict and `x_std` unchanged.
    """
    # Assumes x is [1, n, 1], so dimension 1 holds the element count.
    n = x.size()[1]
    mean_ok = torch.abs(torch.sum(x) - n*(x_mean)) <= torch.abs(error*n*x_mean)
    centered = x - x_mean
    spread_gap = torch.abs(torch.sum(centered*centered) - x_std*x_std*(n-1))
    spread_tol = torch.abs(2*error*x_std*x_std*(n-1))
    return (torch.logical_and(spread_gap <= spread_tol, mean_ok), x_std)
def covariance(x: torch.Tensor, y: torch.Tensor, cov: torch.Tensor, x_mean: torch.Tensor, y_mean: torch.Tensor, error: float) -> torch.Tensor:
    """
    Check claimed means and a claimed sample covariance of `x` and `y`.

    Returns a pair: the boolean verdict and `cov` unchanged.
    """
    # Assumes x and y are [1, n, 1], so dimension 1 holds the element count.
    size_x = x.size()[1]
    size_y = y.size()[1]
    x_mean_cons = torch.abs(torch.sum(x) - size_x*(x_mean)) <= torch.abs(error*size_x*(x_mean))
    y_mean_cons = torch.abs(torch.sum(y) - size_y*(y_mean)) <= torch.abs(error*size_y*(y_mean))
    cross_dev = torch.sum((x - x_mean)*(y - y_mean))
    # abs() on the tolerance so that a negative covariance can be accepted too
    cov_cons = torch.abs(cross_dev - (size_x-1)*(cov)) < torch.abs(error*(size_x-1)*(cov))
    return (torch.logical_and(torch.logical_and(x_mean_cons, y_mean_cons), cov_cons), cov)
# refer other constraints to correlation function, not put here since will be repetitive
def stdev_for_corr(x_fil_mean:torch.Tensor,size_x:torch.Tensor, x_std: torch.Tensor, x_mean: torch.Tensor, error: float) -> torch.Tensor:
    """
    Check a claimed sample standard deviation against data whose filtered
    entries were already replaced by the mean. Only this constraint is checked
    here; the caller is expected to check the remaining ones.

    Returns a pair: the boolean verdict and `x_std` unchanged.
    """
    centered = x_fil_mean - x_mean
    spread_gap = torch.abs(torch.sum(centered*centered) - x_std*x_std*(size_x - 1))
    spread_tol = torch.abs(2*error*x_std*x_std*(size_x - 1))
    return (spread_gap <= spread_tol, x_std)
# refer other constraints to correlation function, not put here since will be repetitive
def covariance_for_corr(x_fil_mean: torch.Tensor,y_fil_mean: torch.Tensor,size_x:torch.Tensor, size_y:torch.Tensor, cov: torch.Tensor, x_mean: torch.Tensor, y_mean: torch.Tensor, error: float) -> torch.Tensor:
    """
    Check a claimed sample covariance against data whose filtered entries were
    already replaced by the respective mean. Only this constraint is checked
    here; the caller is expected to check the remaining ones. `size_y` is
    accepted for signature symmetry and is not used.

    Returns a pair: the boolean verdict and `cov` unchanged.
    """
    cross_dev = torch.sum((x_fil_mean - x_mean)*(y_fil_mean - y_mean))
    # abs() on the tolerance so that a negative covariance can be accepted too
    cov_cons = torch.abs(cross_dev - (size_x-1)*cov) < torch.abs(error*cov*(size_x-1))
    return (cov_cons, cov)
class Correlation(Operation):
def __init__(self, x: torch.Tensor, y: torch.Tensor, error: float):
x_1d = to_1d(x)
x_1d = x_1d[x_1d!=MagicNumber]
y_1d = to_1d(y)
y_1d = y_1d[y_1d!=MagicNumber]
x_1d_list = x_1d.tolist()
y_1d_list = y_1d.tolist()
self.x_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
@@ -323,11 +368,21 @@ class Correlation(Operation):
def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
    """
    Check the stored means, standard deviations, covariance and correlation
    against the entries of `x` and `y` that are not `MagicNumber`.
    """
    x, y = args[0], args[1]
    # Filtered entries contribute 0 to the sums used for the mean checks.
    x_fil_0 = torch.where(x == MagicNumber, 0.0, x)
    y_fil_0 = torch.where(y == MagicNumber, 0.0, y)
    # Number of entries that survived filtering in each column.
    size_x = torch.sum((x != MagicNumber).float())
    size_y = torch.sum((y != MagicNumber).float())
    x_mean_cons = torch.abs(torch.sum(x_fil_0) - size_x*(self.x_mean)) <= torch.abs(self.error*self.x_mean*size_x)
    y_mean_cons = torch.abs(torch.sum(y_fil_0) - size_y*(self.y_mean)) <= torch.abs(self.error*self.y_mean*size_y)
    # Filtered entries are set to the mean so their deviation is 0.
    x_fil_mean = torch.where(x == MagicNumber, self.x_mean, x)
    y_fil_mean = torch.where(y == MagicNumber, self.y_mean, y)
    # Constraints shared by the helper checks below: equal surviving counts and
    # both means within tolerance.
    miscel_cons = torch.logical_and(size_x == size_y, torch.logical_and(x_mean_cons, y_mean_cons))
    bool1, cov = covariance_for_corr(x_fil_mean, y_fil_mean, size_x, size_y, self.cov, self.x_mean, self.y_mean, self.error)
    bool2, x_std = stdev_for_corr(x_fil_mean, size_x, self.x_std, self.x_mean, self.error)
    bool3, y_std = stdev_for_corr(y_fil_mean, size_y, self.y_std, self.y_mean, self.error)
    # correlation * x_std * y_std should reproduce the covariance;
    # abs() on the tolerance so that a negative covariance can be accepted too
    bool4 = torch.abs(cov - self.result*x_std*y_std) <= torch.abs(self.error*cov)
    return torch.logical_and(
        torch.logical_and(torch.logical_and(bool1, bool2), torch.logical_and(bool3, bool4)),
        miscel_cons
    )
def stacked_x(args: list[float]):
@@ -336,12 +391,19 @@ def stacked_x(args: list[float]):
class Regression(Operation):
def __init__(self, xs: list[torch.Tensor], y: torch.Tensor, error: float):
    """
    Fit ordinary least squares on the entries that are not `MagicNumber` and
    store the coefficients as a `[1, k, 1]` float32 tensor.

    :param xs: explanatory variables, one tensor per variable
    :param y: dependent variable
    :param error: tolerance forwarded to the base class
    """
    # Flatten every column and drop entries that were filtered out upstream.
    x_1ds = [to_1d(column) for column in xs]
    x_1ds = [(column[column != MagicNumber]).tolist() for column in x_1ds]
    y_1d = to_1d(y)
    y_1d = (y_1d[y_1d != MagicNumber]).tolist()

    x_one = stacked_x(x_1ds)
    # Normal equations: (X^T X)^-1 X^T y
    result_1d = np.matmul(np.matmul(np.linalg.inv(np.matmul(x_one.transpose(), x_one)), x_one.transpose()), y_1d)
    result = torch.tensor(result_1d, dtype = torch.float32).reshape(1, -1, 1)
    super().__init__(result, error)
@classmethod
@@ -353,6 +415,9 @@ class Regression(Operation):
def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
    """
    Check the stored coefficients against the normal equations:
    `sum(|X^T X result - X^T y|) <= error * sum(|X^T y|)`, with filtered rows
    zeroed out in both `X` and `y`.
    """
    # infer y from the last parameter
    target = torch.where(args[-1] == MagicNumber, torch.tensor(0.0), args[-1])
    # Design matrix: the explanatory columns followed by a column of ones.
    design = torch.cat((*args[:-1], torch.ones_like(args[0])), dim=2)
    # A row counts as filtered when its first column holds MagicNumber;
    # such a row is zeroed entirely, including the column of ones.
    filtered_rows = (design[:,:,0] == MagicNumber).unsqueeze(-1)
    zero_row = torch.tensor([0.0]*design.size()[2])
    design = torch.where(filtered_rows, zero_row, design)
    design_t = torch.transpose(design, 1, 2)
    xty = design_t @ target
    residual = torch.sum(torch.abs(design_t @ design @ self.result - xty))
    return residual <= self.error * torch.sum(torch.abs(xty))