Files
ezkl/examples/notebooks/univ3-da.ipynb
dante bf9cf14ab7 refactor!: rpc url should be required (#965)
BREAKING CHANGE: in python the order of arguments for evm related functions has changed
2025-04-22 12:45:36 +01:00

686 lines
30 KiB
Plaintext

{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# univ3-da-ezkl\n",
"\n",
"Here's an example leveraging EZKL whereby the inputs to the model are read and attested to from an on-chain source. For this setup we make a single call to a view function that returns an array of UniV3 historical TWAP price data that we will attest to on-chain. \n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"First we import the necessary dependencies and set up logging to be as informative as possible. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"\n",
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import logging\n",
"\n",
"# uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define our model. It is a very simple PyTorch model that has just one layer, an average pooling 2D layer. "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"# Defines the model\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
" self.layer = nn.AvgPool2d(2, 1, (1, 1))\n",
"\n",
" def forward(self, x):\n",
" return self.layer(x)[0]\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# this is where you'd train your model"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We omit training for purposes of this demonstration. We've marked where training would happen in the cell above. \n",
"Now we export the model to onnx and create a corresponding (randomly generated) input. This input data will eventually be stored on chain and read from according to the call_data field in the graph input.\n",
"\n",
"You can replace the random `x` with real data if you so wish. "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"x = 0.1*torch.rand(1,*[3, 2, 2], requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"network.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w' ))\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now define a function that will create a new anvil instance which we will deploy our test contract too. This contract will contain in its storage the data that we will read from and attest to. In production you would not need to set up a local anvil instance. Instead you would replace RPC_URL with the actual RPC endpoint of the chain you are deploying your verifiers too, reading from the data on said chain."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import time\n",
"import threading\n",
"\n",
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"RPC_URL = \"http://localhost:3030\"\n",
"\n",
"# Save process globally\n",
"anvil_process = None\n",
"\n",
"def start_anvil():\n",
" global anvil_process\n",
" if anvil_process is None:\n",
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--fork-url\", \"https://arb1.arbitrum.io/rpc\", \"--code-size-limit=41943040\"])\n",
" if anvil_process.returncode is not None:\n",
" raise Exception(\"failed to start anvil process\")\n",
" time.sleep(3)\n",
"\n",
"def stop_anvil():\n",
" global anvil_process\n",
" if anvil_process is not None:\n",
" anvil_process.terminate()\n",
" anvil_process = None\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We define our `PyRunArgs` objects which contains the visibility parameters for out model. \n",
"- `input_visibility` defines the visibility of the model inputs\n",
"- `param_visibility` defines the visibility of the model weights and constants and parameters \n",
"- `output_visibility` defines the visibility of the model outputs\n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"public\"\n",
"- `param_visibility`: \"private\"\n",
"- `output_visibility`: public\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"data_path = os.path.join('input.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"public\"\n",
"run_args.param_visibility = \"private\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.decomp_legs=5\n",
"run_args.num_inner_cols = 1\n",
"run_args.variables = [(\"batch_size\", 1)]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a settings file. This file basically instantiates a bunch of parameters that determine their circuit shape, size etc... Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file as some data can fall out of range of these lookups.\n",
"\n",
"You can pass a dataset for calibration that will be representative of real inputs you might find if and when you deploy the prover. Here we create a dummy calibration dataset for demonstration purposes. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# generate a bunch of dummy calibration data\n",
"cal_data = {\n",
" \"input_data\": [(0.1*torch.rand(2, *[3, 2, 2])).flatten().tolist()],\n",
"}\n",
"\n",
"cal_path = os.path.join('val_data.json')\n",
"# save as json file\n",
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The graph input for on chain data sources is formatted completely differently compared to file based data sources.\n",
"\n",
"- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL (aka field elements :-D). \n",
"Here is what the schema for an on-chain data source graph input file should look like for a single call data source:\n",
" \n",
"```json\n",
"{\n",
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"call\": {\n",
" \"call_data\": \"1f3be514000000000000000000000000c6962004f452be9203591991d15f6b388e09e8d00000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000070000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns an array of on-chain data points we are attesting to. \n",
" \"decimals\": 0, // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
" \"address\": \"9A213F53334279C128C37DA962E5472eCD90554f\", // The address of the contract that we are calling to get the data. \n",
" \"len\": 12 // The number of data points returned by the view function (the length of the array)\n",
" }\n",
" }\n",
"}\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from web3 import Web3, HTTPProvider\n",
"from solcx import compile_standard\n",
"from decimal import Decimal\n",
"import json\n",
"import os\n",
"import torch\n",
"import requests\n",
"\n",
"def count_decimal_places(num):\n",
" num_str = str(num)\n",
" if '.' in num_str:\n",
" return len(num_str) - 1 - num_str.index('.')\n",
" else:\n",
" return 0\n",
"\n",
"w3 = Web3(HTTPProvider(RPC_URL)) \n",
"\n",
"def on_chain_data(tensor):\n",
" data = tensor.view(-1).tolist()\n",
" secondsAgo = [len(data) - 1 - i for i in range(len(data))]\n",
"\n",
" contract_source_code = '''\n",
" // SPDX-License-Identifier: MIT\n",
" pragma solidity ^0.8.20;\n",
"\n",
" interface IUniswapV3PoolDerivedState {\n",
" function observe(\n",
" uint32[] calldata secondsAgos\n",
" ) external view returns (\n",
" int56[] memory tickCumulatives,\n",
" uint160[] memory secondsPerLiquidityCumulativeX128s\n",
" );\n",
" }\n",
"\n",
" contract UniTickAttestor {\n",
" int256[] private cachedTicks;\n",
"\n",
" function consult(\n",
" IUniswapV3PoolDerivedState pool,\n",
" uint32[] memory secondsAgo\n",
" ) public view returns (int256[] memory tickCumulatives) {\n",
" tickCumulatives = new int256[](secondsAgo.length);\n",
" (int56[] memory _ticks,) = pool.observe(secondsAgo);\n",
" for (uint256 i = 0; i < secondsAgo.length; i++) {\n",
" tickCumulatives[i] = int256(_ticks[i]);\n",
" }\n",
" }\n",
"\n",
" function cache_price(\n",
" IUniswapV3PoolDerivedState pool,\n",
" uint32[] memory secondsAgo\n",
" ) public {\n",
" (int56[] memory _ticks,) = pool.observe(secondsAgo);\n",
" cachedTicks = new int256[](_ticks.length);\n",
" for (uint256 i = 0; i < _ticks.length; i++) {\n",
" cachedTicks[i] = int256(_ticks[i]);\n",
" }\n",
" }\n",
"\n",
" function readPriceCache() public view returns (int256[] memory) {\n",
" return cachedTicks;\n",
" }\n",
" }\n",
" '''\n",
"\n",
" compiled_sol = compile_standard({\n",
" \"language\": \"Solidity\",\n",
" \"sources\": {\"UniTickAttestor.sol\": {\"content\": contract_source_code}},\n",
" \"settings\": {\"outputSelection\": {\"*\": {\"*\": [\"metadata\", \"evm.bytecode\", \"abi\"]}}}\n",
" })\n",
"\n",
" bytecode = compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['evm']['bytecode']['object']\n",
" abi = json.loads(compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['metadata'])['output']['abi']\n",
"\n",
" # Deploy contract\n",
" UniTickAttestor = w3.eth.contract(abi=abi, bytecode=bytecode)\n",
" tx_hash = UniTickAttestor.constructor().transact()\n",
" tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)\n",
" contract = w3.eth.contract(address=tx_receipt['contractAddress'], abi=abi)\n",
"\n",
" # Step 4: Store data via cache_price transaction\n",
" tx_hash = contract.functions.cache_price(\n",
" \"0xC6962004f452bE9203591991D15f6b388e09E8D0\",\n",
" secondsAgo\n",
" ).transact()\n",
" tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)\n",
"\n",
" # Step 5: Prepare calldata for readPriceCache\n",
" call = contract.functions.readPriceCache().build_transaction()\n",
" calldata = call['data'][2:]\n",
"\n",
" # Get stored data\n",
" result = contract.functions.readPriceCache().call()\n",
" print(f'Cached ticks: {result}')\n",
"\n",
" decimals = [0] * len(data)\n",
"\n",
" call_to_account = {\n",
" 'call_data': calldata,\n",
" 'decimals': decimals,\n",
" 'address': contract.address[2:],\n",
" }\n",
"\n",
" return call_to_account\n",
"\n",
"start_anvil()\n",
"call_to_account = on_chain_data(x)\n",
"\n",
"data = dict(input_data = {'rpc': RPC_URL, 'call': call_to_account })\n",
"json.dump(data, open(\"input.json\", 'w'))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
"\n",
"These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = await ezkl.get_srs( settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# !export RUST_BACKTRACE=1\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a full proof. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"And verify it as a sanity check. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now create and then deploy a vanilla evm verifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
")\n",
"\n",
"assert res == True"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"With the vanilla verifier deployed, we can now create the data attestation contract, which will read in the instances from the calldata to the verifier, attest to them, call the verifier and then return the result. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can deploy the data attest verifier contract. For security reasons, this binding will only deploy to a local anvil instance, using accounts generated by anvil. \n",
"So should only be used for testing purposes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" )\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we need to regenerate the witness, prove and then verify all within the same cell. This is because we want to reduce the amount of latency between reading on-chain state and verifying it on-chain. This is because the attest input values read from the oracle are time sensitive (their values are derived from computing on block.timestamp) and can change between the time of reading and the time of verifying.\n",
"\n",
"Call the view only verify method on the contract to verify the proof. Since it is a view function this is safe to use in production since you don't have to pass your private key."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# !export RUST_BACKTRACE=1\n",
"\n",
"# print(res)\n",
"assert os.path.isfile(proof_path)\n",
"# read the verifier address\n",
"addr_verifier = None\n",
"with open(addr_path_verifier, 'r') as f:\n",
" addr = f.read()\n",
"#read the data attestation address\n",
"addr_da = None\n",
"with open(addr_path_da, 'r') as f:\n",
" addr_da = f.read()\n",
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}