Files
zk-stats-lib/examples/median/median.ipynb
2024-02-19 22:17:08 +08:00

402 lines
29 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install the dependencies into the running kernel's environment.\n",
"# Explicit %pip (rather than a bare `pip`) does not depend on IPython's automagic being enabled.\n",
"%pip install -r ../../requirements.txt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"import torch\n",
"from torch import nn\n",
"import json\n",
"import os\n",
"import time\n",
"import scipy\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import statistics\n",
"import math"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Run zkstats/core.py inside this notebook's namespace (-i). The helpers called below\n",
"# (create_dummy, get_data_commitment_maps, verifier_define_calculation, prover_gen_settings,\n",
"# setup, prover_gen_proof, verifier_verify) are not defined in this notebook and presumably come from it.\n",
"%run -i ../../zkstats/core.py"
]
},
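{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## File paths\n",
"\n",
"Artifacts under `shared/` are used by both verifier and prover (the verifier's model, keys, settings, SRS, proof and dummy data); artifacts under `prover/` stay with the prover, including the selected real data."
]
},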
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Init paths: `shared/` holds artifacts both parties use; `prover/` holds artifacts that stay with the prover.\n",
"SHARED_DIR = 'shared'\n",
"PROVER_DIR = 'prover'\n",
"os.makedirs(SHARED_DIR, exist_ok=True)\n",
"os.makedirs(PROVER_DIR, exist_ok=True)\n",
"\n",
"verifier_model_path = os.path.join(SHARED_DIR, 'verifier.onnx')\n",
"prover_model_path = os.path.join(PROVER_DIR, 'prover.onnx')\n",
"verifier_compiled_model_path = os.path.join(SHARED_DIR, 'verifier.compiled')\n",
"prover_compiled_model_path = os.path.join(PROVER_DIR, 'prover.compiled')\n",
"pk_path = os.path.join(SHARED_DIR, 'test.pk')\n",
"vk_path = os.path.join(SHARED_DIR, 'test.vk')\n",
"proof_path = os.path.join(SHARED_DIR, 'test.pf')\n",
"settings_path = os.path.join(SHARED_DIR, 'settings.json')\n",
"srs_path = os.path.join(SHARED_DIR, 'kzg.srs')\n",
"witness_path = os.path.join(PROVER_DIR, 'witness.json')\n",
"# this is private to prover since it contains actual data\n",
"sel_data_path = os.path.join(PROVER_DIR, 'sel_data.json')\n",
"# this is just dummy random value, safe to share\n",
"sel_dummy_data_path = os.path.join(SHARED_DIR, 'sel_dummy_data.json')"
]
},
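{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## ZK-Stats flow\n",
"\n",
"The verifier defines the median check on dummy data, the prover runs the same check on the real data and generates a proof, and the verifier verifies that proof against the data commitments."
]
},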
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"data_path = 'data.json'\n",
"dummy_data_path = os.path.join('shared', 'dummy_data.json')\n",
"\n",
"\n",
"def _load_column(path, column='col_name'):\n",
"    \"\"\"Return the values stored under `column` in the JSON file at `path` (file is closed afterwards).\"\"\"\n",
"    with open(path, \"r\") as f:\n",
"        return json.load(f)[column]\n",
"\n",
"\n",
"def _median_with_middle_values(values):\n",
"    \"\"\"Return (median, lower, upper) of `values` as tensors.\n",
"\n",
"    lower and upper are the sorted elements at positions len//2 - 1 and len//2,\n",
"    i.e. the two values averaged into the median when the length is even.\n",
"    Raises ValueError for an empty column.\n",
"    \"\"\"\n",
"    if len(values) == 0:\n",
"        raise ValueError(\"cannot compute the median of an empty column\")\n",
"    ordered = np.sort(values)\n",
"    mid = len(values) // 2\n",
"    return (\n",
"        torch.tensor(np.median(values)),\n",
"        torch.tensor(ordered[mid - 1]),\n",
"        torch.tensor(ordered[mid]),\n",
"    )\n",
"\n",
"\n",
"data = _load_column(data_path)\n",
"\n",
"create_dummy(data_path, dummy_data_path)\n",
"dummy_data = _load_column(dummy_data_path)\n",
"\n",
"dummy_theory_output, dummy_lower_to_median, dummy_upper_to_median = _median_with_middle_values(dummy_data)\n",
"theory_output, lower_to_median, upper_to_median = _median_with_middle_values(data)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Fixed-point scale(s) passed to settings generation (the resulting settings use input_scale = param_scale = 8)\n",
"scales = [8]\n",
"# Column(s) of data.json that the computation runs on\n",
"selected_columns = ['col_name']\n",
"# Commitments to the data at each scale; passed to verifier_verify together with the proof\n",
"commitment_maps = get_data_commitment_maps(data_path, scales)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dummy output: tensor(15.8000, dtype=torch.float64)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jernkun/Library/Caches/pypoetry/virtualenvs/zkstats-OJpceffF-py3.11/lib/python3.11/site-packages/torch/onnx/symbolic_opset9.py:2174: FutureWarning: 'torch.onnx.symbolic_opset9._cast_Bool' is deprecated in version 2.0 and will be removed in the future. Please Avoid using this function and create a Cast node instead.\n",
" return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False))\n"
]
}
],
"source": [
"print(\"dummy output: \", dummy_theory_output)\n",
"\n",
"\n",
"# Verifier / data-consumer side: define the desired calculation (the median) on dummy data.\n",
"class verifier_model(nn.Module):\n",
"    \"\"\"Claim that `w` is the median of the selected column, within a 1% tolerance.\n",
"\n",
"    forward(X) returns (is_median, w): a boolean tensor and the claimed median w.\n",
"    Here w, lower and upper are computed from the dummy data.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super(verifier_model, self).__init__()\n",
"        # w is the claimed median (not the mean)\n",
"        self.w = nn.Parameter(data = dummy_theory_output, requires_grad = False)\n",
"        # lower/upper: the sorted elements at positions len//2 - 1 and len//2\n",
"        self.lower = nn.Parameter(data = dummy_lower_to_median, requires_grad = False)\n",
"        self.upper = nn.Parameter(data = dummy_upper_to_median, requires_grad = False)\n",
"\n",
"    def forward(self, X):\n",
"        # Keep these operations identical to prover_model.forward: setup() derives the keys\n",
"        # from the verifier model while the proof is generated from the prover model.\n",
"        # NOTE(review): assumes X is (1, n) with the column values along dim 1 -- confirm.\n",
"        # NOTE(review): the 0.99/1.01 tolerance factors assume positive values; for\n",
"        # negative data the tolerance bands point the wrong way.\n",
"\n",
"        # Values within 1% of w are regarded as equal to w.\n",
"        count_less = torch.sum((X < 0.99*self.w).double())\n",
"        count_equal = torch.sum((torch.abs(X-self.w)<=torch.abs(0.01*self.w)).double())\n",
"        # torch.div/torch.floor need tensor inputs; presumably X.size()[1] is traced as\n",
"        # a tensor during ONNX export.\n",
"        n = X.size()[1]\n",
"        half_len = torch.floor(torch.div(n, 2))\n",
"\n",
"        # Modulo is not supported yet: 2*(n/2 - floor(n/2)) is 1 for odd n, 0 for even n.\n",
"        # Together these say w sits at the middle position(s) of the sorted data.\n",
"        less_cons = count_less<half_len+2*(n/2 - torch.floor(n/2))\n",
"        more_cons = count_less+count_equal>half_len\n",
"\n",
"        # For count_equal == 0 (no value within 1% of w, e.g. even n with w strictly\n",
"        # between the two middle values): lower and upper must occur in X, exactly\n",
"        # half_len values lie above lower and below upper, and w is their average.\n",
"        lower_exist = torch.sum((torch.abs(X-self.lower)<=torch.abs(0.01*self.lower)).double())>0\n",
"        lower_cons = torch.sum((X>1.01*self.lower).double())==half_len\n",
"        upper_exist = torch.sum((torch.abs(X-self.upper)<=torch.abs(0.01*self.upper)).double())>0\n",
"        upper_cons = torch.sum((X<0.99*self.upper).double())==half_len\n",
"        # i.e. count_less == half_len; the doubled form is kept so the graph matches prover_model\n",
"        bound = 2*count_less == 2*half_len\n",
"        # 0.02 since 2*0.01: |lower + upper - 2w| within 2% of w\n",
"        bound_avg = (torch.abs(self.lower+self.upper-2*self.w)<=torch.abs(0.02*self.w))\n",
"\n",
"        median_in_cons = torch.logical_and(less_cons, more_cons)\n",
"        median_out_cons = torch.logical_and(\n",
"            torch.logical_and(bound, bound_avg),\n",
"            torch.logical_and(torch.logical_and(lower_cons, upper_cons), torch.logical_and(lower_exist, upper_exist)),\n",
"        )\n",
"\n",
"        return torch.where(count_equal==0, median_out_cons, median_in_cons), self.w\n",
"\n",
"\n",
"verifier_define_calculation(dummy_data_path, selected_columns, sel_dummy_data_path, verifier_model, verifier_model_path)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Theory_output: tensor(49.5500, dtype=torch.float64)\n",
"lower: tensor(49.3000, dtype=torch.float64)\n",
"upper: tensor(49.8000, dtype=torch.float64)\n",
"==== Generate & Calibrate Setting ====\n",
"scale: [8]\n",
"setting: {\"run_args\":{\"tolerance\":{\"val\":0.0,\"scale\":1.0},\"input_scale\":8,\"param_scale\":8,\"scale_rebase_multiplier\":10,\"lookup_range\":[-25112,24986],\"logrows\":16,\"num_inner_cols\":2,\"variables\":[[\"batch_size\",1]],\"input_visibility\":{\"Hashed\":{\"hash_is_public\":true,\"outlets\":[]}},\"output_visibility\":\"Public\",\"param_visibility\":\"Private\"},\"num_rows\":14432,\"total_assignments\":12052,\"total_const_size\":1815,\"model_instance_shapes\":[[1],[1]],\"model_output_scales\":[0,8],\"model_input_scales\":[8],\"module_sizes\":{\"kzg\":[],\"poseidon\":[14432,[1]],\"elgamal\":[0,[0]]},\"required_lookups\":[\"Abs\",{\"GreaterThan\":{\"a\":0.0}},\"KroneckerDelta\"],\"check_mode\":\"UNSAFE\",\"version\":\"7.0.0\",\"num_blinding_factors\":null}\n"
]
}
],
"source": [
"# Prover side: run the same check on the real data and generate the settings sent to the verifier.\n",
"print(\"Theory_output: \", theory_output)\n",
"print(\"lower: \", lower_to_median)\n",
"print(\"upper: \", upper_to_median)\n",
"\n",
"\n",
"class prover_model(nn.Module):\n",
"    \"\"\"Claim that `w` is the median of the selected column, within a 1% tolerance.\n",
"\n",
"    forward(X) returns (is_median, w): a boolean tensor and the claimed median w.\n",
"    Here w, lower and upper are computed from the real data.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super(prover_model, self).__init__()\n",
"        # w is the claimed median (not the mean)\n",
"        self.w = nn.Parameter(data = theory_output, requires_grad = False)\n",
"        # lower/upper: the sorted elements at positions len//2 - 1 and len//2\n",
"        self.lower = nn.Parameter(data = lower_to_median, requires_grad = False)\n",
"        self.upper = nn.Parameter(data = upper_to_median, requires_grad = False)\n",
"\n",
"    def forward(self, X):\n",
"        # Keep these operations identical to verifier_model.forward: setup() derives the keys\n",
"        # from the verifier model while the proof is generated from the prover model.\n",
"        # NOTE(review): assumes X is (1, n) with the column values along dim 1 -- confirm.\n",
"        # NOTE(review): the 0.99/1.01 tolerance factors assume positive values; for\n",
"        # negative data the tolerance bands point the wrong way.\n",
"\n",
"        # Values within 1% of w are regarded as equal to w.\n",
"        count_less = torch.sum((X < 0.99*self.w).double())\n",
"        count_equal = torch.sum((torch.abs(X-self.w)<=torch.abs(0.01*self.w)).double())\n",
"        # torch.div/torch.floor need tensor inputs; presumably X.size()[1] is traced as\n",
"        # a tensor during ONNX export.\n",
"        n = X.size()[1]\n",
"        half_len = torch.floor(torch.div(n, 2))\n",
"\n",
"        # Modulo is not supported yet: 2*(n/2 - floor(n/2)) is 1 for odd n, 0 for even n.\n",
"        # Together these say w sits at the middle position(s) of the sorted data.\n",
"        less_cons = count_less<half_len+2*(n/2 - torch.floor(n/2))\n",
"        more_cons = count_less+count_equal>half_len\n",
"\n",
"        # For count_equal == 0 (no value within 1% of w, e.g. even n with w strictly\n",
"        # between the two middle values): lower and upper must occur in X, exactly\n",
"        # half_len values lie above lower and below upper, and w is their average.\n",
"        lower_exist = torch.sum((torch.abs(X-self.lower)<=torch.abs(0.01*self.lower)).double())>0\n",
"        lower_cons = torch.sum((X>1.01*self.lower).double())==half_len\n",
"        upper_exist = torch.sum((torch.abs(X-self.upper)<=torch.abs(0.01*self.upper)).double())>0\n",
"        upper_cons = torch.sum((X<0.99*self.upper).double())==half_len\n",
"        # i.e. count_less == half_len; the doubled form is kept so the graph matches verifier_model\n",
"        bound = 2*count_less == 2*half_len\n",
"        # 0.02 since 2*0.01: |lower + upper - 2w| within 2% of w\n",
"        bound_avg = (torch.abs(self.lower+self.upper-2*self.w)<=torch.abs(0.02*self.w))\n",
"\n",
"        median_in_cons = torch.logical_and(less_cons, more_cons)\n",
"        median_out_cons = torch.logical_and(\n",
"            torch.logical_and(bound, bound_avg),\n",
"            torch.logical_and(torch.logical_and(lower_cons, upper_cons), torch.logical_and(lower_exist, upper_exist)),\n",
"        )\n",
"\n",
"        return torch.where(count_equal==0, median_out_cons, median_in_cons), self.w\n",
"\n",
"\n",
"# NOTE(review): \"resources\" is presumably the ezkl calibration target -- confirm against core.py.\n",
"prover_gen_settings(data_path, selected_columns, sel_data_path, prover_model, prover_model_path, scales, \"resources\", settings_path)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"spawning module 0\n",
"spawning module 2\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"==== setting up ezkl ====\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"spawning module 0\n",
"spawning module 2\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time setup: 6.857371807098389 seconds\n",
"=======================================\n",
"Theory output: tensor(49.5500, dtype=torch.float64)\n",
"==== Generating Witness ====\n",
"witness boolean: 1.0\n",
"witness result 1 : 49.55078125\n",
"==== Generating Proof ====\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"spawning module 0\n",
"spawning module 2\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"proof: {'instances': [[[3042937791208075219, 8157070662846698822, 3804781648660056856, 172406108020799675], [12436184717236109307, 3962172157175319849, 7381016538464732718, 1011752739694698287], [18341455175509539295, 12796101019039945164, 1607286914885633240, 1929881192315725821]]], 'proof': '', 'transcript_type': 'EVM'}\n",
"Time gen prf: 9.461018085479736 seconds\n"
]
}
],
"source": [
"# Setup: derive the proving key (pk) and verification key (vk) from the verifier's model.\n",
"# Verifier and prover could each run this step, since all of its inputs are public; it is\n",
"# written from the verifier's side to stress that the verifier must compute its own vk.\n",
"setup(verifier_model_path, verifier_compiled_model_path, settings_path,vk_path, pk_path )\n",
"\n",
"print(\"=======================================\")\n",
"# Prover generates the witness and the proof from its private selected data and the pk\n",
"print(\"Theory output: \", theory_output)\n",
"prover_gen_proof(prover_model_path, sel_data_path, witness_path, prover_compiled_model_path, settings_path, proof_path, pk_path)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"49.55078125"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Verifier checks the proof using its own vk, the settings and the data commitments.\n",
"# The value returned here matches the witness output w (the claimed median).\n",
"verifier_verify(proof_path, settings_path, vk_path, selected_columns, commitment_maps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}