Files
ezkl/examples/notebooks/set_membership.ipynb
2025-06-27 22:58:10 +02:00

524 lines
16 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67"
},
"source": [
"## Hash set membership demo"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {
"id": "95613ee9"
},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
"    import google.colab  # only importable inside Colab\n",
"except ImportError:\n",
"    # rely on local installation of ezkl if the notebook is not in colab\n",
"    pass\n",
"else:\n",
"    # in colab: install ezkl and the other packages this notebook uses.\n",
"    # a failed install raises here instead of being silently swallowed.\n",
"    import subprocess\n",
"    import sys\n",
"    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
"    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"pytest\"])\n",
"\n",
"import logging\n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n",
"\n",
"# here we create (and potentially train) a model\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import torch\n",
"\n",
"\n",
"class MyModel(nn.Module):\n",
"    \"\"\"Set-membership circuit with no trainable parameters.\n",
"\n",
"    The product of the differences (x - y) along dim 1 is zero whenever\n",
"    one entry of y equals x, so a zero first output signals membership.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"\n",
"    def forward(self, x, y):\n",
"        \"\"\"Return (prod(x - y, dim=1), y).\n",
"\n",
"        x: the value being tested; y: the set it is tested against.\n",
"        y is passed through unchanged as the second output.\n",
"        \"\"\"\n",
"        diff = (x - y)\n",
"        membership_test = torch.prod(diff, dim=1)\n",
"        return (membership_test, y)\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# Train the model as you like here (skipped for brevity)\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {
"id": "b37637c4"
},
"outputs": [],
"source": [
"# Artifact locations, all relative to the notebook's working directory.\n",
"# circuit artifacts\n",
"model_path = 'network.onnx'\n",
"compiled_model_path = 'network.compiled'\n",
"settings_path = 'settings.json'\n",
"\n",
"# proving / verification keys\n",
"pk_path = 'test.pk'\n",
"vk_path = 'test.vk'\n",
"\n",
"# input data and the witness generated from it\n",
"data_path = 'input.json'\n",
"witness_path = 'witness.json'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c833f08c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c833f08c",
"outputId": "b5c794e1-c787-4b65-e267-c005e661df1b"
},
"outputs": [],
"source": [
"# print the installed pytorch version (the model below is exported to ONNX with it)\n",
"print(torch.__version__)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {
"id": "82db373a"
},
"outputs": [],
"source": [
"# value whose membership is tested: 0.0 is an element of the set below\n",
"x = torch.zeros(1, 1, requires_grad=True)\n",
"\n",
"# the set we test membership against (single source of truth for y and the hashes)\n",
"y_input = [0.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]\n",
"y = torch.tensor(y_input, requires_grad=True).reshape(1, -1)\n",
"\n",
"# Poseidon hash of every set element; these hashes are what goes into the input file\n",
"result = []\n",
"\n",
"for e in y_input:\n",
"    # convert once and reuse; the second argument (7) is presumably the\n",
"    # fixed-point scale -- confirm against the ezkl docs\n",
"    felt = ezkl.float_to_felt(e, 7)\n",
"    print(felt)\n",
"    result.append(ezkl.poseidon_hash([felt])[0])\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
"# Export the model\n",
"torch.onnx.export(circuit,                   # model being run\n",
"                  (x,y),                     # model input (or a tuple for multiple inputs)\n",
"                  model_path,                # where to save the model (can be a file or file-like object)\n",
"                  export_params=True,        # store the trained parameter weights inside the model file\n",
"                  opset_version=14,          # the ONNX version to export the model to\n",
"                  do_constant_folding=True,  # whether to execute constant folding for optimization\n",
"                  input_names = ['input'],   # the model's input names\n",
"                  output_names = ['output'], # the model's output names\n",
"                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes\n",
"                                'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array_x = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"data_array_y = result\n",
"print(data_array_y)\n",
"\n",
"data = dict(input_data = [data_array_x, data_array_y])\n",
"\n",
"print(data)\n",
"\n",
"# Serialize data into file (the context manager flushes and closes the handle):\n",
"with open(data_path, 'w') as f:\n",
"    json.dump(data, f)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {
"id": "d5e374a2"
},
"outputs": [],
"source": [
"# configure the circuit (visibility, scale, size) and write the settings file\n",
"run_args = ezkl.PyRunArgs()\n",
"# \"hashed/private\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
"# NOTE(review): the trailing \"/0\" presumably selects input 0 (x) as the hashed input -- confirm against the ezkl docs\n",
"run_args.input_visibility = \"hashed/private/0\"\n",
"# as the inputs are felts we turn off input range checks\n",
"run_args.ignore_range_check_inputs_outputs = True\n",
"# we set it to fix the set we want to check membership for\n",
"run_args.param_visibility = \"fixed\"\n",
"# the output uses \"fixed\" visibility -- set membership fails if it is not = 0\n",
"run_args.output_visibility = \"fixed\"\n",
"# bind the dynamic batch_size axis declared at ONNX export time to 1\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"# never rebase the scale\n",
"run_args.scale_rebase_multiplier = 1000\n",
"# logrows\n",
"run_args.logrows = 11\n",
"\n",
"# this creates the following sequence of ops:\n",
"# 1. hash the input -> poseidon(x)\n",
"# 2. compute the set difference -> poseidon(x) - set\n",
"# 3. compute the product of the set difference -> prod(poseidon(x) - set)\n",
"\n",
"\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {
"id": "3aa4f090"
},
"outputs": [],
"source": [
"# compile the ONNX model into a circuit using the generated settings\n",
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8b74dcee",
"outputId": "f7b9198c-2b3d-48bb-c67e-8478333cedb5"
},
"outputs": [],
"source": [
"# get the SRS for the circuit described by the settings file\n",
"# (get_srs is awaited, so this cell relies on the notebook's top-level await support)\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {
"id": "18c8b7c7"
},
"outputs": [],
"source": [
"# now generate the witness file from the input data (x = 0.0, which is in the set)\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"# success is checked by the presence of the witness file on disk\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "Y94vCo5Znrim",
"metadata": {
"id": "Y94vCo5Znrim"
},
"outputs": [],
"source": [
"# now generate a faulty input + witness file (x input not in the set)\n",
"\n",
"data_path_faulty = os.path.join('input_faulty.json')\n",
"\n",
"witness_path_faulty = os.path.join('witness_faulty.json')\n",
"\n",
"# 1.0 is not an element of the set [0.0, 2.0, ..., 9.0]\n",
"x = torch.ones(1, 1, requires_grad=True)\n",
"\n",
"data_array_x = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"# the set is supplied as the poseidon hashes computed in the export cell (`result`),\n",
"# so no set tensor needs to be rebuilt here\n",
"data_array_y = result\n",
"print(data_array_y)\n",
"\n",
"data = dict(input_data = [data_array_x, data_array_y])\n",
"\n",
"print(data)\n",
"\n",
"# Serialize data into file (the context manager flushes and closes the handle):\n",
"with open(data_path_faulty, 'w') as f:\n",
"    json.dump(data, f)\n",
"\n",
"res = ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path_faulty)\n",
"assert os.path.isfile(witness_path_faulty)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "FQfGdcUNpvuK",
"metadata": {
"id": "FQfGdcUNpvuK"
},
"outputs": [],
"source": [
"# now generate a truthy input + witness file (x input is in the set)\n",
"import random\n",
"\n",
"# Generate a random integer between 1 and 8, inclusive (used as an index into the set)\n",
"random_value = random.randint(1, 8)\n",
"\n",
"data_path_truthy = os.path.join('input_truthy.json')\n",
"\n",
"witness_path_truthy = os.path.join('witness_truthy.json')\n",
"\n",
"# named member_set rather than `set` so the builtin is not shadowed\n",
"# for the rest of the kernel session\n",
"member_set = [0.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]\n",
"\n",
"# pick one element of the set as the value under test, shaped (1, 1)\n",
"x = torch.tensor([member_set[random_value]]).reshape(1, 1)\n",
"\n",
"data_array_x = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"# the set is supplied as the poseidon hashes computed in the export cell (`result`)\n",
"data_array_y = result\n",
"print(data_array_y)\n",
"\n",
"data = dict(input_data = [data_array_x, data_array_y])\n",
"\n",
"print(data)\n",
"\n",
"# Serialize data into file (the context manager flushes and closes the handle):\n",
"with open(data_path_truthy, 'w') as f:\n",
"    json.dump(data, f)\n",
"\n",
"res = ezkl.gen_witness(data_path_truthy, compiled_model_path, witness_path_truthy)\n",
"assert os.path.isfile(witness_path_truthy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "41fd15a8",
"metadata": {},
"outputs": [],
"source": [
"# load the witness for inspection; the context manager closes the file handle\n",
"with open(witness_path, \"r\") as f:\n",
"    witness = json.load(f)\n",
"witness"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {
"id": "b1c561a8"
},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"# we force the output to be 0 this corresponds to the set membership test being true -- and we set this to a fixed vis output\n",
"# this means that the output is fixed and the verifier can see it but that if the input is not in the set the output will not be 0 and the verifier will reject\n",
"with open(witness_path, \"r\") as f:\n",
"    witness = json.load(f)\n",
"witness[\"outputs\"][0] = [\"0000000000000000000000000000000000000000000000000000000000000000\"]\n",
"# close (and therefore flush) the written file before it is read back below\n",
"with open(witness_path, \"w\") as f:\n",
"    json.dump(witness, f)\n",
"\n",
"# read the file back to confirm the forced output was persisted\n",
"with open(witness_path, \"r\") as f:\n",
"    witness = json.load(f)\n",
"print(witness[\"outputs\"][0])\n",
"\n",
"res = ezkl.setup(\n",
"        compiled_model_path,\n",
"        vk_path,\n",
"        pk_path,\n",
"        witness_path = witness_path,\n",
"    )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {
"id": "c384cbc8"
},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"# (from the witness whose output was forced to 0 in the setup cell)\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"# positional arguments: witness, compiled circuit, proving key, proof output path, \"single\"\n",
"res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, \"single\")\n",
"\n",
"print(res)\n",
"# the proof must have been written to disk\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "XAC73EvtpM-W",
"metadata": {
"id": "XAC73EvtpM-W"
},
"outputs": [],
"source": [
"# GENERATE A FAULTY PROOF\n",
"# (from the witness built with an x that is not in the set)\n",
"\n",
"proof_path_faulty = os.path.join('test_faulty.pf')\n",
"\n",
"# positional arguments: witness, compiled circuit, proving key, proof output path, \"single\"\n",
"res = ezkl.prove(witness_path_faulty, compiled_model_path, pk_path, proof_path_faulty, \"single\")\n",
"\n",
"print(res)\n",
"# a proof file is still produced here; it is the verification step that is expected to reject it\n",
"assert os.path.isfile(proof_path_faulty)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "_x19Q4FUrKb6",
"metadata": {
"id": "_x19Q4FUrKb6"
},
"outputs": [],
"source": [
"# GENERATE A TRUTHY PROOF\n",
"# (from the witness built with an x drawn from the set)\n",
"\n",
"proof_path_truthy = os.path.join('test_truthy.pf')\n",
"\n",
"# positional arguments: witness, compiled circuit, proving key, proof output path, \"single\"\n",
"res = ezkl.prove(witness_path_truthy, compiled_model_path, pk_path, proof_path_truthy, \"single\")\n",
"\n",
"print(res)\n",
"# the proof must have been written to disk\n",
"assert os.path.isfile(proof_path_truthy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {
"id": "76f00d41"
},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"# both the original proof and the truthy proof must verify, in that order\n",
"for verified_proof_path in (proof_path, proof_path_truthy):\n",
"    res = ezkl.verify(verified_proof_path, settings_path, vk_path)\n",
"    assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4nqEx7-qpciQ",
"metadata": {
"id": "4nqEx7-qpciQ"
},
"outputs": [],
"source": [
"import pytest\n",
"\n",
"\n",
"def test_verification():\n",
"    \"\"\"Check that verifying the faulty proof raises the expected RuntimeError.\"\"\"\n",
"    # raw string: the pattern is a regex and \\[ is not a valid Python string escape\n",
"    expected_error = r'Failed to run verify: \\[halo2\\] The constraint system is not satisfied'\n",
"    with pytest.raises(RuntimeError, match=expected_error):\n",
"        ezkl.verify(\n",
"            proof_path_faulty,\n",
"            settings_path,\n",
"            vk_path,\n",
"        )\n",
"\n",
"\n",
"# Run the test function\n",
"test_verification()"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}