mirror of
https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-09 13:38:02 -05:00
where+geomean, median, mode, reg, stdv
This commit is contained in:
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"x1": [
|
||||
7.1, 3.2, 8.6, 3.5, 0.1, 9.7, 2.3, 5.7, 2.8, 10.0, 6.0, 6.0, 9.1, 1.7, 9.2,
|
||||
10.0, 7.1, 3.2, 8.6, 3.5, 0.1, 9.7, 2.3, 5.7, 2.8, 6.0, 6.0, 9.1, 1.7, 9.2,
|
||||
0.2, 7.8, 3.7, 7.0, 2.5, 2.8, 5.9, 7.3, 2.9, 2.9, 3.5, 1.0, 9.7, 4.8, 0.9,
|
||||
7.1, 3.6, 8.2, 3.0, 7.6, 4.2, 5.2, 8.1, 6.3, 9.3, 8.8, 8.2, 6.7, 4.9, 5.4,
|
||||
9.8, 5.9, 7.1, 3.9, 9.3
|
||||
@@ -12,7 +12,7 @@
|
||||
1.5, 2.1, 0.4, 4.3, 0.2
|
||||
],
|
||||
"y": [
|
||||
18.5, 5.5, 18.2, 9.0, 4.0, 19.5, 11.7, 17.9, 15.3, 20.8, 12.5, 21.5, 32.5,
|
||||
20.8, 18.5, 5.5, 18.2, 9.0, 4.0, 19.5, 11.7, 17.9, 15.3, 12.5, 21.5, 32.5,
|
||||
18.6, 23.9, 7.0, 16.9, 22.9, 31.0, 15.0, 8.5, 8.7, 28.9, 19.7, 12.5, 17.4,
|
||||
7.2, 25.5, 21.4, 15.7, 15.5, 8.2, 28.2, 19.5, 25.5, 12.5, 20.3, 21.7, 22.1,
|
||||
19.6, 32.2, 22.4, 20.6, 19.7, 20.8, 21.1, 21.8, 17.7, 21.1, 19.4
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,308 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Requirement already satisfied: ezkl==7.0.0 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 1)) (7.0.0)\n",
|
||||
"Requirement already satisfied: torch in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 2)) (2.2.0)\n",
|
||||
"Requirement already satisfied: requests in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 3)) (2.31.0)\n",
|
||||
"Requirement already satisfied: scipy in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 4)) (1.12.0)\n",
|
||||
"Requirement already satisfied: numpy in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 5)) (1.26.3)\n",
|
||||
"Requirement already satisfied: matplotlib in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 6)) (3.8.2)\n",
|
||||
"Requirement already satisfied: statistics in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 7)) (1.0.3.5)\n",
|
||||
"Requirement already satisfied: onnx in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from -r ../../requirements.txt (line 8)) (1.15.0)\n",
|
||||
"Requirement already satisfied: filelock in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (3.13.1)\n",
|
||||
"Requirement already satisfied: typing-extensions>=4.8.0 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (4.9.0)\n",
|
||||
"Requirement already satisfied: sympy in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (1.12)\n",
|
||||
"Requirement already satisfied: networkx in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (3.2.1)\n",
|
||||
"Requirement already satisfied: jinja2 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (3.1.3)\n",
|
||||
"Requirement already satisfied: fsspec in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from torch->-r ../../requirements.txt (line 2)) (2023.12.2)\n",
|
||||
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from requests->-r ../../requirements.txt (line 3)) (3.3.2)\n",
|
||||
"Requirement already satisfied: idna<4,>=2.5 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from requests->-r ../../requirements.txt (line 3)) (3.6)\n",
|
||||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from requests->-r ../../requirements.txt (line 3)) (2.2.0)\n",
|
||||
"Requirement already satisfied: certifi>=2017.4.17 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from requests->-r ../../requirements.txt (line 3)) (2024.2.2)\n",
|
||||
"Requirement already satisfied: contourpy>=1.0.1 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (1.2.0)\n",
|
||||
"Requirement already satisfied: cycler>=0.10 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (0.12.1)\n",
|
||||
"Requirement already satisfied: fonttools>=4.22.0 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (4.47.2)\n",
|
||||
"Requirement already satisfied: kiwisolver>=1.3.1 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (1.4.5)\n",
|
||||
"Requirement already satisfied: packaging>=20.0 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (23.2)\n",
|
||||
"Requirement already satisfied: pillow>=8 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (10.2.0)\n",
|
||||
"Requirement already satisfied: pyparsing>=2.3.1 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (3.1.1)\n",
|
||||
"Requirement already satisfied: python-dateutil>=2.7 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from matplotlib->-r ../../requirements.txt (line 6)) (2.8.2)\n",
|
||||
"Requirement already satisfied: docutils>=0.3 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from statistics->-r ../../requirements.txt (line 7)) (0.20.1)\n",
|
||||
"Requirement already satisfied: protobuf>=3.20.2 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from onnx->-r ../../requirements.txt (line 8)) (4.25.2)\n",
|
||||
"Requirement already satisfied: six>=1.5 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from python-dateutil>=2.7->matplotlib->-r ../../requirements.txt (line 6)) (1.16.0)\n",
|
||||
"Requirement already satisfied: MarkupSafe>=2.0 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from jinja2->torch->-r ../../requirements.txt (line 2)) (2.1.4)\n",
|
||||
"Requirement already satisfied: mpmath>=0.19 in /Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages (from sympy->torch->-r ../../requirements.txt (line 2)) (1.3.0)\n",
|
||||
"\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
|
||||
"Note: you may need to restart the kernel to use updated packages.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pip install -r ../../requirements.txt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import ezkl\n",
|
||||
"import torch\n",
|
||||
"from torch import nn\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"import time\n",
|
||||
"import scipy\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import statistics\n",
|
||||
"import math"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run -i ../../zkstats/core.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# init path\n",
|
||||
"os.makedirs(os.path.dirname('shared/'), exist_ok=True)\n",
|
||||
"os.makedirs(os.path.dirname('prover/'), exist_ok=True)\n",
|
||||
"verifier_model_path = os.path.join('shared/verifier.onnx')\n",
|
||||
"prover_model_path = os.path.join('prover/prover.onnx')\n",
|
||||
"verifier_compiled_model_path = os.path.join('shared/verifier.compiled')\n",
|
||||
"prover_compiled_model_path = os.path.join('prover/prover.compiled')\n",
|
||||
"pk_path = os.path.join('shared/test.pk')\n",
|
||||
"vk_path = os.path.join('shared/test.vk')\n",
|
||||
"proof_path = os.path.join('shared/test.pf')\n",
|
||||
"settings_path = os.path.join('shared/settings.json')\n",
|
||||
"srs_path = os.path.join('shared/kzg.srs')\n",
|
||||
"witness_path = os.path.join('prover/witness.json')\n",
|
||||
"# this is private to prover since it contains actual data\n",
|
||||
"sel_data_path = os.path.join('prover/sel_data.json')\n",
|
||||
"# this is just dummy random value\n",
|
||||
"sel_dummy_data_path = os.path.join('shared/sel_dummy_data.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"======================= ZK-STATS FLOW ======================="
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_path = os.path.join('data.json')\n",
|
||||
"dummy_data_path = os.path.join('shared/dummy_data.json')\n",
|
||||
"\n",
|
||||
"create_dummy(data_path, dummy_data_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"scales = [2]\n",
|
||||
"selected_columns = ['col_name']\n",
|
||||
"commitment_maps = get_data_commitment_maps(data_path, scales)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/mhchia/projects/work/pse/zk-stats-lib/zkstats/computation.py:166: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
|
||||
" is_precise_aggregated = torch.tensor(1.0)\n",
|
||||
"/Users/mhchia/Library/Caches/pypoetry/virtualenvs/zkstats-brXmXluj-py3.12/lib/python3.12/site-packages/torch/onnx/symbolic_opset9.py:2174: FutureWarning: 'torch.onnx.symbolic_opset9._cast_Bool' is deprecated in version 2.0 and will be removed in the future. Please Avoid using this function and create a Cast node instead.\n",
|
||||
" return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False))\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Verifier/ data consumer side: send desired calculation\n",
|
||||
"from zkstats.computation import computation_to_model, State\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:\n",
|
||||
" x = data[0]\n",
|
||||
" # FIXME: should be replaced by `s.where` when it's available. Now the result may be incorrect\n",
|
||||
" filter = (x < 50)\n",
|
||||
" min_x = torch.min(x)\n",
|
||||
" filtered_x = torch.where(filter, x, min_x - 1)\n",
|
||||
" return s.median(filtered_x)\n",
|
||||
"\n",
|
||||
"error = 0.01\n",
|
||||
"_, verifier_model = computation_to_model(computation, error)\n",
|
||||
"\n",
|
||||
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path,verifier_model, verifier_model_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Theory_output: tensor(40., dtype=torch.float64)\n",
|
||||
"==== Generate & Calibrate Setting ====\n",
|
||||
"scale: [2]\n",
|
||||
"setting: {\"run_args\":{\"tolerance\":{\"val\":0.0,\"scale\":1.0},\"input_scale\":2,\"param_scale\":2,\"scale_rebase_multiplier\":10,\"lookup_range\":[-582,1208],\"logrows\":14,\"num_inner_cols\":2,\"variables\":[[\"batch_size\",1]],\"input_visibility\":{\"Hashed\":{\"hash_is_public\":true,\"outlets\":[]}},\"output_visibility\":\"Public\",\"param_visibility\":\"Private\"},\"num_rows\":14432,\"total_assignments\":15928,\"total_const_size\":2126,\"model_instance_shapes\":[[1],[1]],\"model_output_scales\":[0,2],\"model_input_scales\":[2],\"module_sizes\":{\"kzg\":[],\"poseidon\":[14432,[1]],\"elgamal\":[0,[0]]},\"required_lookups\":[\"Abs\",{\"Div\":{\"denom\":2.0}},\"ReLU\",{\"Floor\":{\"scale\":4.0}},{\"GreaterThan\":{\"a\":0.0}},\"KroneckerDelta\"],\"check_mode\":\"UNSAFE\",\"version\":\"7.0.0\",\"num_blinding_factors\":null}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Prover/ data owner side\n",
|
||||
"_, prover_model = computation_to_model(computation, error)\n",
|
||||
"\n",
|
||||
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"spawning module 0\n",
|
||||
"spawning module 2\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"==== setting up ezkl ====\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"spawning module 0\n",
|
||||
"spawning module 2\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Time setup: 4.408194065093994 seconds\n",
|
||||
"=======================================\n",
|
||||
"Theory output: tensor(40., dtype=torch.float64)\n",
|
||||
"==== Generating Witness ====\n",
|
||||
"witness boolean: 1.0\n",
|
||||
"witness result 1 : 40.0\n",
|
||||
"==== Generating Proof ====\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"spawning module 0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Here verifier & prover can concurrently call setup since all params are public to get pk.\n",
|
||||
"# Here write as verifier function to emphasize that verifier must calculate its own vk to be sure\n",
|
||||
"setup(verifier_model_path, verifier_compiled_model_path, settings_path,vk_path, pk_path )\n",
|
||||
"\n",
|
||||
"print(\"=======================================\")\n",
|
||||
"# Prover generates proof\n",
|
||||
"print(\"Theory output: \", theory_output)\n",
|
||||
"prover_gen_proof(prover_model_path, sel_data_path, witness_path, prover_compiled_model_path, settings_path, proof_path, pk_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"num_inputs: 1\n",
|
||||
"prf instances: [[[1780239215148830498, 13236513277824664467, 10913529727158264423, 131860697733488968], [12436184717236109307, 3962172157175319849, 7381016538464732718, 1011752739694698287], [12341676197686541490, 2627393525778350065, 16625494184434727973, 1478518078215075360]]]\n",
|
||||
"proof boolean: 1.0\n",
|
||||
"proof result 1 : 40.0\n",
|
||||
"verified\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Verifier verifies\n",
|
||||
"res = verifier_verify(proof_path, settings_path, vk_path, selected_columns, commitment_maps)\n",
|
||||
"print(\"Verifier gets result:\", res)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.1"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
253
examples/where/where+geomean/where+geomean.ipynb
Normal file
253
examples/where/where+geomean/where+geomean.ipynb
Normal file
File diff suppressed because one or more lines are too long
245
examples/where/where+mean/where+mean.ipynb
Normal file
245
examples/where/where+mean/where+mean.ipynb
Normal file
File diff suppressed because one or more lines are too long
245
examples/where/where+median/where+median.ipynb
Normal file
245
examples/where/where+median/where+median.ipynb
Normal file
File diff suppressed because one or more lines are too long
276
examples/where/where+mode/where+mode.ipynb
Normal file
276
examples/where/where+mode/where+mode.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"x1": [
|
||||
7.1, 3.2, 8.6, 3.5, 0.1, 9.7, 2.3, 5.7, 2.8, 10.0, 6.0, 6.0, 9.1, 1.7, 9.2,
|
||||
10.0, 7.1, 3.2, 8.6, 3.5, 0.1, 9.7, 2.3, 5.7, 2.8, 6.0, 6.0, 9.1, 1.7, 9.2,
|
||||
0.2, 7.8, 3.7, 7.0, 2.5, 2.8, 5.9, 7.3, 2.9, 2.9, 3.5, 1.0, 9.7, 4.8, 0.9,
|
||||
7.1, 3.6, 8.2, 3.0, 7.6, 4.2, 5.2, 8.1, 6.3, 9.3, 8.8, 8.2, 6.7, 4.9, 5.4,
|
||||
9.8, 5.9, 7.1, 3.9, 9.3
|
||||
@@ -12,7 +12,7 @@
|
||||
1.5, 2.1, 0.4, 4.3, 0.2
|
||||
],
|
||||
"y": [
|
||||
18.5, 5.5, 18.2, 9.0, 4.0, 19.5, 11.7, 17.9, 15.3, 20.8, 12.5, 21.5, 32.5,
|
||||
20.8, 18.5, 5.5, 18.2, 9.0, 4.0, 19.5, 11.7, 17.9, 15.3, 12.5, 21.5, 32.5,
|
||||
18.6, 23.9, 7.0, 16.9, 22.9, 31.0, 15.0, 8.5, 8.7, 28.9, 19.7, 12.5, 17.4,
|
||||
7.2, 25.5, 21.4, 15.7, 15.5, 8.2, 28.2, 19.5, 25.5, 12.5, 20.3, 21.7, 22.1,
|
||||
19.6, 32.2, 22.4, 20.6, 19.7, 20.8, 21.1, 21.8, 17.7, 21.1, 19.4
|
||||
291
examples/where/where+regression/where+regression.ipynb
Normal file
291
examples/where/where+regression/where+regression.ipynb
Normal file
File diff suppressed because one or more lines are too long
239
examples/where/where+stdev/where+stdev.ipynb
Normal file
239
examples/where/where+stdev/where+stdev.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -129,6 +129,7 @@ class State:
|
||||
Calculate the linear regression of x and y. The behavior should conform to
|
||||
[statistics.linear_regression](https://docs.python.org/3/library/statistics.html#statistics.linear_regression) in Python standard library.
|
||||
"""
|
||||
# hence support only one x for now
|
||||
return self._call_op([x, y], Regression)
|
||||
|
||||
# WHERE operation
|
||||
|
||||
@@ -45,7 +45,9 @@ def create_dummy(data_path: str, dummy_data_path: str) -> None:
|
||||
dummy_data ={}
|
||||
for col in data:
|
||||
# not use same value for every column to prevent something weird, like singular matrix
|
||||
dummy_data[col] = np.round(np.random.uniform(1,30,len(data[col])),1).tolist()
|
||||
min_col = min(data[col])
|
||||
max_col = max(data[col])
|
||||
dummy_data[col] = np.round(np.random.uniform(min_col,max_col,len(data[col])),1).tolist()
|
||||
|
||||
json.dump(dummy_data, open(dummy_data_path, 'w'))
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ class Where(Operation):
|
||||
return cls(torch.where(x[0],x[1], MagicNumber ),error)
|
||||
def ezkl(self, x:list[torch.Tensor]) -> IsResultPrecise:
|
||||
bool_array = torch.logical_or(x[1]==self.result, torch.logical_and(torch.logical_not(x[0]), self.result==MagicNumber))
|
||||
# print('sellll: ', self.result)
|
||||
return torch.sum(bool_array.float())==x[1].size()[1]
|
||||
|
||||
|
||||
@@ -41,8 +42,9 @@ class Mean(Operation):
|
||||
# return cls(torch.mean(x[0]), error)
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
size = torch.sum((x[0]!=MagicNumber).float())
|
||||
x = torch.where(x[0]==MagicNumber, 0.0, x[0])
|
||||
x = x[0]
|
||||
size = torch.sum((x!=MagicNumber).float())
|
||||
x = torch.where(x==MagicNumber, 0.0, x)
|
||||
return torch.abs(torch.sum(x)-size*self.result)<=torch.abs(self.error*size*self.result)
|
||||
|
||||
|
||||
@@ -64,6 +66,7 @@ class Median(Operation):
|
||||
# we want in our context. However, we tend to have x as a `[1, len(x), 1]`. In this case,
|
||||
# we need to flatten `x` to 1d array to get the correct `lower` and `upper`.
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
super().__init__(torch.tensor(np.median(x_1d)), error)
|
||||
sorted_x = np.sort(x_1d)
|
||||
len_x = len(x_1d)
|
||||
@@ -76,21 +79,25 @@ class Median(Operation):
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
x = x[0]
|
||||
old_size = x.size()[1]
|
||||
size = torch.sum((x!=MagicNumber).float())
|
||||
min_x = torch.min(x)
|
||||
x = torch.where(x==MagicNumber,min_x-1, x)
|
||||
|
||||
# since within 1%, we regard as same value
|
||||
count_less = torch.sum((x < self.result).float())
|
||||
count_less = torch.sum((x < self.result).float())-(old_size-size)
|
||||
count_equal = torch.sum((x==self.result).float())
|
||||
len = x.size()[1]
|
||||
half_len = torch.floor(torch.div(len, 2))
|
||||
half_size = torch.floor(torch.div(size, 2))
|
||||
|
||||
less_cons = count_less<half_len+len%2
|
||||
more_cons = count_less+count_equal>half_len
|
||||
less_cons = count_less<half_size+size%2
|
||||
more_cons = count_less+count_equal>half_size
|
||||
|
||||
# For count_equal == 0
|
||||
lower_exist = torch.sum((x==self.lower).float())>0
|
||||
lower_cons = torch.sum((x>self.lower).float())==half_len
|
||||
lower_cons = torch.sum((x>self.lower).float())==half_size
|
||||
upper_exist = torch.sum((x==self.upper).float())>0
|
||||
upper_cons = torch.sum((x<self.upper).float())==half_len
|
||||
bound = count_less== half_len
|
||||
upper_cons = torch.sum((x<self.upper).float())==half_size
|
||||
bound = count_less== half_size
|
||||
# 0.02 since 2*0.01
|
||||
bound_avg = (torch.abs(self.lower+self.upper-2*self.result)<=torch.abs(2*self.error*self.result))
|
||||
|
||||
@@ -98,17 +105,20 @@ class Median(Operation):
|
||||
median_out_cons = torch.logical_and(torch.logical_and(bound, bound_avg), torch.logical_and(torch.logical_and(lower_cons, upper_cons), torch.logical_and(lower_exist, upper_exist)))
|
||||
return torch.where(count_equal==0, median_out_cons, median_in_cons)
|
||||
|
||||
|
||||
class GeometricMean(Operation):
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float) -> 'GeometricMean':
|
||||
x_1d = to_1d(x[0])
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
result = torch.exp(torch.mean(torch.log(x_1d)))
|
||||
return cls(result, error)
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
# Assume x is [1, n, 1]
|
||||
x = x[0]
|
||||
size = x.size()[1]
|
||||
size = torch.sum((x!=MagicNumber).float())
|
||||
x = torch.where(x==MagicNumber, 1.0, x)
|
||||
return torch.abs((torch.log(self.result)*size)-torch.sum(torch.log(x)))<=size*torch.log(torch.tensor(1+self.error))
|
||||
|
||||
|
||||
@@ -176,7 +186,7 @@ class Mode(Operation):
|
||||
@classmethod
|
||||
def create(cls, x: list[torch.Tensor], error: float) -> 'Mode':
|
||||
x_1d = to_1d(x[0])
|
||||
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
# Here is traditional definition of Mode, can just put this num_error to be 0
|
||||
result = torch.tensor(mode_within(x_1d, 0))
|
||||
return cls(result, error)
|
||||
@@ -184,6 +194,15 @@ class Mode(Operation):
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
# Assume x is [1, n, 1]
|
||||
x = x[0]
|
||||
min_x = torch.min(x)
|
||||
old_size = x.size()[1]
|
||||
x = torch.where(x==MagicNumber, min_x-1, x)
|
||||
count_equal = torch.sum((x==self.result).float())
|
||||
result = torch.tensor([torch.logical_or(torch.sum((x==ele[0]).float())<=count_equal, min_x-1 ==ele[0]) for ele in x[0]])
|
||||
return torch.sum(result) == old_size
|
||||
|
||||
|
||||
|
||||
size = x.size()[1]
|
||||
count_equal = torch.sum((torch.abs(x-self.result)<=torch.abs(self.error*self.result)).float())
|
||||
_result = torch.tensor([
|
||||
@@ -235,6 +254,7 @@ class PVariance(Operation):
|
||||
class Stdev(Operation):
|
||||
def __init__(self, x: torch.Tensor, error: float):
|
||||
x_1d = to_1d(x)
|
||||
x_1d = x_1d[x_1d!=MagicNumber]
|
||||
self.data_mean = torch.nn.Parameter(data=torch.mean(x_1d), requires_grad=False)
|
||||
result = torch.sqrt(torch.var(x_1d, correction = 1))
|
||||
super().__init__(result, error)
|
||||
@@ -245,10 +265,12 @@ class Stdev(Operation):
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
x = x[0]
|
||||
size = x.size()[1]
|
||||
x_mean_cons = torch.abs(torch.sum(x)-size*(self.data_mean))<=torch.abs(self.error*size*self.data_mean)
|
||||
x_for_mean = torch.where(x==MagicNumber, 0.0, x)
|
||||
size = torch.sum((x!=MagicNumber).float())
|
||||
x_mean_cons = torch.abs(torch.sum(x_for_mean)-size*(self.data_mean))<=torch.abs(self.error*size*self.data_mean)
|
||||
x_for_std = torch.where(x==MagicNumber, self.data_mean, x)
|
||||
return torch.logical_and(
|
||||
torch.abs(torch.sum((x-self.data_mean)*(x-self.data_mean))-self.result*self.result*(size - 1))<=torch.abs(2*self.error*self.result*self.result*(size - 1)), x_mean_cons
|
||||
torch.abs(torch.sum((x_for_std-self.data_mean)*(x_for_std-self.data_mean))-self.result*self.result*(size - 1))<=torch.abs(2*self.error*self.result*self.result*(size - 1)), x_mean_cons
|
||||
)
|
||||
|
||||
|
||||
@@ -265,7 +287,7 @@ class Variance(Operation):
|
||||
|
||||
def ezkl(self, x: list[torch.Tensor]) -> IsResultPrecise:
|
||||
x = x[0]
|
||||
size = x.size()[1]
|
||||
size = torch.sum((x!=MagicNumber).float())
|
||||
x_mean_cons = torch.abs(torch.sum(x)-size*(self.data_mean))<=torch.abs(self.error*size*self.data_mean)
|
||||
return torch.logical_and(
|
||||
torch.abs(torch.sum((x-self.data_mean)*(x-self.data_mean))-self.result*(size - 1))<=torch.abs(self.error*self.result*(size - 1)), x_mean_cons
|
||||
@@ -349,12 +371,23 @@ def stacked_x(args: list[float]):
|
||||
|
||||
class Regression(Operation):
|
||||
def __init__(self, xs: list[torch.Tensor], y: torch.Tensor, error: float):
|
||||
x_1ds = [to_1d(i).tolist() for i in xs]
|
||||
y_1d = to_1d(y).tolist()
|
||||
# x_1ds = [to_1d(i).tolist() for i in xs]
|
||||
x_1ds = [to_1d(i) for i in xs]
|
||||
# print('xxxx: ', x_1ds)
|
||||
fil_x_1ds=[]
|
||||
for x_1 in x_1ds:
|
||||
fil_x_1ds.append((x_1[x_1!=MagicNumber]).tolist())
|
||||
x_1ds = fil_x_1ds
|
||||
# print('fil xxx',fil_x_1ds)
|
||||
# y_1d = to_1d(y).tolist()
|
||||
y_1d = to_1d(y)
|
||||
y_1d = (y_1d[y_1d!=MagicNumber]).tolist()
|
||||
# print('yyy: ', y_1d)
|
||||
|
||||
x_one = stacked_x(x_1ds)
|
||||
result_1d = np.matmul(np.matmul(np.linalg.inv(np.matmul(x_one.transpose(), x_one)), x_one.transpose()), y_1d)
|
||||
result = torch.tensor(result_1d, dtype = torch.float32).reshape(1, -1, 1)
|
||||
print('result: ', result)
|
||||
super().__init__(result, error)
|
||||
|
||||
@classmethod
|
||||
@@ -366,7 +399,9 @@ class Regression(Operation):
|
||||
def ezkl(self, args: list[torch.Tensor]) -> IsResultPrecise:
|
||||
# infer y from the last parameter
|
||||
y = args[-1]
|
||||
y = torch.where(y==MagicNumber, torch.tensor(0.0), y)
|
||||
x_one = torch.cat((*args[:-1], torch.ones_like(args[0])), dim=2)
|
||||
x_one = torch.where((x_one[:,:,0] ==MagicNumber).unsqueeze(-1), torch.tensor([0.0]*x_one.size()[2]), x_one)
|
||||
x_t = torch.transpose(x_one, 1, 2)
|
||||
return torch.sum(torch.abs(x_t @ x_one @ self.result - x_t @ y)) <= self.error * torch.sum(torch.abs(x_t @ y))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user