chore: python integration tests (#350)

This commit is contained in:
dante
2023-07-10 22:24:46 +01:00
committed by GitHub
parent da0755c5fc
commit 50635cb884
11 changed files with 2948 additions and 8 deletions

View File

@@ -220,6 +220,7 @@ jobs:
mock-proving-tests,
mock-proving-tests-wasi,
python-tests,
python-integration-tests,
]
steps:
- uses: actions/checkout@v3
@@ -261,6 +262,7 @@ jobs:
mock-proving-tests,
mock-proving-tests-wasi,
python-tests,
python-integration-tests,
]
steps:
- uses: actions/checkout@v3
@@ -294,6 +296,7 @@ jobs:
mock-proving-tests,
mock-proving-tests-wasi,
python-tests,
python-integration-tests,
]
steps:
- uses: actions/checkout@v3
@@ -325,6 +328,7 @@ jobs:
mock-proving-tests,
mock-proving-tests-wasi,
python-tests,
python-integration-tests,
]
steps:
- uses: actions/checkout@v3
@@ -428,3 +432,45 @@ jobs:
run: source .env/bin/activate; maturin develop --features python-bindings --release
- name: Run pytest
run: source .env/bin/activate; pytest
python-integration-tests:
runs-on: ubuntu-latest-32-cores
needs: [build, build-wasm, library-tests, docs]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.7"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --profile local --locked foundry-cli anvil
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
- name: authenticate-kaggle-cli
shell: bash
env:
KAGGLE_API_KEY: ${{ secrets.KAGGLE_API_KEY }}
run: |
mkdir /home/runner/.kaggle
# dump the contents of the secret into a file called kaggle.json
echo $KAGGLE_API_KEY > /home/runner/.kaggle/kaggle.json
chmod 600 /home/runner/.kaggle/kaggle.json
- name: Notebook integration tests
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
- name: Voice tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_

1
.gitignore vendored
View File

@@ -1,5 +1,6 @@
target
data
*.csv
*.ipynb_checkpoints
*.sol
!QuantizeData.sol

View File

@@ -39,10 +39,21 @@ The generated proofs can then be used on-chain to verify computation, only the E
| --- | --- |
| [docs](https://docs.ezkl.xyz) | the official ezkl docs page |
| [colab notebook demo](https://colab.research.google.com/drive/1XuXNKqH7axOelZXyU3gpoTOCvFetIsKu?usp=sharing) | demo of ezkl python bindings on google's colab |
| [tutorial](https://github.com/zkonduit/pyezkl/tree/main/examples/tutorial) | end-to-end tutorial using pytorch and ezkl |
| [notebook](https://github.com/zkonduit/pyezkl/blob/main/examples/ezkl_demo.ipynb) | end-to-end tutorial using pytorch and ezkl in a jupyter notebook |
| `cargo doc --open` | compile and open the docs in your default browser locally |
#### tutorials
You can find a range of python based tutorials in the `examples/notebooks` section. These all assume you have the `ezkl` python library installed. If you want the bleeding edge version of the library, you can install it from the `main` branch with:
```bash
python -m venv .env
source .env/bin/activate
pip install -r requirements.txt
maturin develop --release --features python-bindings
# dependencies specific to tutorials
pip install torch pandas numpy seaborn jupyter onnx kaggle py-solc-x web3 librosa
```
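Then, to run the notebooks (a suggested invocation; adjust the path to taste):
```bash
source .env/bin/activate
jupyter notebook examples/notebooks
```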
----------------------

View File

@@ -0,0 +1,623 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# data-attest-ezkl\n",
"\n",
"Here's an example leveraging EZKL whereby the inputs to the model are read and attested to from an on-chain source.\n",
"\n",
"In this setup:\n",
"- the inputs and outputs are publicly known to the prover and verifier\n",
"- the on chain inputs will be fetched and then fed directly into the circuit\n",
"- the quantization of the on-chain inputs happens within the evm and is replicated at proving time \n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"First we import the necessary dependencies and set up logging to be as informative as possible. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import logging\n",
"\n",
"# uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define our model. It is a very simple PyTorch model that has just one layer, an average pooling 2D layer. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"# Defines the model\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
" self.layer = nn.AvgPool2d(2, 1, (1, 1))\n",
"\n",
" def forward(self, x):\n",
" return self.layer(x)[0]\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# this is where you'd train your model\n",
"\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We omit training for purposes of this demonstration. We've marked where training would happen in the cell above. \n",
"Now we export the model to onnx and create a corresponding (randomly generated) input. This input data will eventually stored on chain and read from accoring to the call_data field in the graph input.\n",
"\n",
"You can replace the random `x` with real data if you so wish. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = 0.1*torch.rand(1,*[3, 2, 2], requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"network.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w' ))\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now define a function that will create a new anvil instance which we will deploy our test contract too. This contract will contain in its storage the data that we will read from and attest to. In production you would not need to set up a local anvil instance. Instead you would replace RPC_URL with the actual RPC endpoint of the chain you are deploying your verifiers too, reading from the data on said chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import time\n",
"import threading\n",
"\n",
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"RPC_URL = \"http://localhost:3030\"\n",
"\n",
"# Save process globally\n",
"anvil_process = None\n",
"\n",
"def start_anvil():\n",
" global anvil_process\n",
" if anvil_process is None:\n",
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\",])\n",
" if anvil_process.returncode is not None:\n",
" raise Exception(\"failed to start anvil process\")\n",
" time.sleep(3)\n",
"\n",
"def stop_anvil():\n",
" global anvil_process\n",
" if anvil_process is not None:\n",
" anvil_process.terminate()\n",
" anvil_process = None\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The graph input for on chain data sources is formatted completely differently compared to file based data sources.\n",
"\n",
"- For file data sources, the raw field elements generated from feeding the value of `input.json` into the circuit are stored. In the case of input data from file this would be the raw quantized values from `input.json` converted into field elements. Whereas for output data, these are the raw outputs (field elements) generated from running a forward pass of the model/circuit with the given `input.json` (and any hashes). \n",
"\n",
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL. \n",
"Here is what the schema for an on-chain data source graph input file should look like:\n",
" \n",
"```json\n",
"{\n",
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"calls\": [\n",
" {\n",
" \"call_data\": [\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to view function that returns on-chain data (only support uint256 returns for now)\n",
" 7 // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only supports integer values.\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000001\",\n",
" 5\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000002\",\n",
" 5\n",
" ]\n",
" ],\n",
" \"address\": \"5fbdb2315678afecb367f032d93f642f64180aa3\" // The address of the contract that we are calling to get the data. \n",
" }\n",
" ]\n",
" }\n",
"}"
]
},
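{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the `decimals` convention concrete, here is a tiny illustrative sketch (ours, not part of ezkl) of how a fetched `uint256` and its decimal count recover the original float:\n",
"\n",
"```python\n",
"# e.g. the on-chain value 1234567 with 7 decimal places represents 0.1234567\n",
"def dequantize(onchain_value: int, decimals: int) -> float:\n",
"    return onchain_value / 10**decimals\n",
"\n",
"assert dequantize(1234567, 7) == 0.1234567\n",
"```"
]
},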
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from web3 import Web3, HTTPProvider\n",
"from solcx import compile_standard\n",
"from decimal import Decimal\n",
"import json\n",
"import os\n",
"import torch\n",
"\n",
"# This function counts the decimal places of a floating point number\n",
"def count_decimal_places(num):\n",
" num_str = str(num)\n",
" if '.' in num_str:\n",
" return len(num_str) - 1 - num_str.index('.')\n",
" else:\n",
" return 0\n",
"\n",
"# setup web3 instance\n",
"w3 = Web3(HTTPProvider(RPC_URL)) # replace with your provider\n",
"\n",
"def test_on_chain_data(tensor):\n",
" # Step 0: Convert the tensor to a flat list\n",
" data = tensor.view(-1).tolist()\n",
"\n",
" # Step 1: Prepare the data\n",
" decimals = [count_decimal_places(x) for x in data]\n",
" scaled_data = [int(x * 10**decimals[i]) for i, x in enumerate(data)]\n",
"\n",
" # Step 2: Prepare and compile the contract.\n",
" # We are using a test contract here but in production you would \n",
" # use whatever contract you are fetching data from.\n",
" contract_source_code = '''\n",
" // SPDX-License-Identifier: UNLICENSED\n",
" pragma solidity ^0.8.17;\n",
"\n",
" contract TestReads {\n",
"\n",
" uint[] public arr;\n",
" constructor(uint256[] memory _numbers) {\n",
" for(uint256 i = 0; i < _numbers.length; i++) {\n",
" arr.push(_numbers[i]);\n",
" }\n",
" }\n",
" }\n",
" '''\n",
"\n",
" compiled_sol = compile_standard({\n",
" \"language\": \"Solidity\",\n",
" \"sources\": {\"testreads.sol\": {\"content\": contract_source_code}},\n",
" \"settings\": {\"outputSelection\": {\"*\": {\"*\": [\"metadata\", \"evm.bytecode\", \"abi\"]}}}\n",
" })\n",
"\n",
" # Get bytecode\n",
" bytecode = compiled_sol['contracts']['testreads.sol']['TestReads']['evm']['bytecode']['object']\n",
"\n",
" # Get ABI\n",
" # In production if you are reading from really large contracts you can just use\n",
" # a stripped down version of the ABI of the contract you are calling, containing only the view functions you will fetch data from.\n",
" abi = json.loads(compiled_sol['contracts']['testreads.sol']['TestReads']['metadata'])['output']['abi']\n",
"\n",
" # Step 3: Deploy the contract\n",
" TestReads = w3.eth.contract(abi=abi, bytecode=bytecode)\n",
" tx_hash = TestReads.constructor(scaled_data).transact()\n",
" tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)\n",
" # If you are deploying to production you can skip the 3 lines of code above and just instantiate the contract like this,\n",
" # passing the address and abi of the contract you are fetching data from.\n",
" contract = w3.eth.contract(address=tx_receipt['contractAddress'], abi=abi)\n",
"\n",
" # Step 4: Interact with the contract\n",
" calldata = []\n",
" for i, _ in enumerate(data):\n",
" call = contract.functions.arr(i).build_transaction()\n",
" calldata.append((call['data'][2:], decimals[i])) # In production you would need to manually decide what the optimal decimal place for each value should be. \n",
" # Here we are just using the number of decimal places in a randomly generated tensor.\n",
"\n",
" # Prepare the calls_to_account object\n",
" # If you were calling view functions across multiple contracts,\n",
" # you would have multiple entries in the calls_to_account array,\n",
" # one for each contract.\n",
" calls_to_account = [{\n",
" 'call_data': calldata,\n",
" 'address': contract.address[2:], # remove the '0x' prefix\n",
" }]\n",
"\n",
" print(f'calls_to_account: {calls_to_account}')\n",
"\n",
" return calls_to_account\n",
"\n",
"# Now let's start the Anvil process. You don't need to do this if you are deploying to a non-local chain.\n",
"start_anvil()\n",
"\n",
"# Now let's call our function, passing in the same input tensor we used to export the model 2 cells above.\n",
"calls_to_account = test_on_chain_data(x)\n",
"\n",
"data = dict(input_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
"\n",
"# Serialize on-chain data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We define our `PyRunArgs` objects which contains the visibility parameters for out model. \n",
"- `input_visibility` defines the visibility of the model inputs\n",
"- `param_visibility` defines the visibility of the model weights and constants and parameters \n",
"- `output_visibility` defines the visibility of the model outputs\n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"public\"\n",
"- `param_visibility`: \"private\"\n",
"- `output_visibility`: public\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"data_path = os.path.join('input.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"public\"\n",
"run_args.param_visibility = \"private\"\n",
"run_args.output_visibility = \"public\"\n",
"\n",
"\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a settings file. This file basically instantiates a bunch of parameters that determine their circuit shape, size etc... Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file as some data can fall out of range of these lookups.\n",
"\n",
"You can pass a dataset for calibration that will be representative of real inputs you might find if and when you deploy the prover. Here we create a dummy calibration dataset for demonstration purposes. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# generate a bunch of dummy calibration data\n",
"cal_data = {\n",
" \"input_data\": [(0.1*torch.rand(2, *[3, 2, 2])).flatten().tolist()],\n",
"}\n",
"\n",
"cal_path = os.path.join('val_data.json')\n",
"# save as json file\n",
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
"\n",
"These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.get_srs(srs_path, settings_path)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!export RUST_BACKTRACE=1\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, model_path, witness_path, settings_path = settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"res = ezkl.setup(\n",
" model_path,\n",
" vk_path,\n",
" pk_path,\n",
" srs_path,\n",
" settings_path,\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a full proof. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" model_path,\n",
" pk_path,\n",
" proof_path,\n",
" srs_path,\n",
" \"evm\",\n",
" \"single\",\n",
" settings_path,\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"And verify it as a sanity check. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" srs_path,\n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now create an EVM / `.sol` verifier that can be deployed on chain to verify submitted proofs and attest to on-chain EZKL inputs using a view function. Make sure to pass the `input_path` instead of the `witness_path` as the former contains the on-chain data that we are attesting to."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = ezkl.create_evm_data_attestation_verifier(\n",
" vk_path,\n",
" srs_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" input_path,\n",
" )"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can deploy the data attest verifier contract. For security reasons, this binding will only deploy to a local anvil instance, using accounts generated by anvil. \n",
"So should only be used for testing purposes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"addr_path = \"addr.txt\"\n",
"\n",
"res = ezkl.deploy_da_evm(\n",
" addr_path,\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Call the view only verify method on the contract to verify the proof. Since it is a view function this is safe to use in production since you don't have to pass your private key."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# read the address from addr_path\n",
"addr = None\n",
"with open(addr_path, 'r') as f:\n",
" addr = f.read()\n",
"\n",
"res = ezkl.verify_evm(\n",
" proof_path,\n",
" addr,\n",
" RPC_URL,\n",
" True\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ezkl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,480 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# hashed-ezkl\n",
"\n",
"Here's an example leveraging EZKL whereby the inputs to the model, and the model params themselves, are hashed inside a circuit.\n",
"\n",
"In this setup:\n",
"- the hashes are publicly known to the prover and verifier\n",
"- the hashes serve as \"public inputs\" (a.k.a instances) to the circuit\n",
"\n",
"We leave the outputs of the model as public as well (known to the verifier and prover). \n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"First we import the necessary dependencies and set up logging to be as informative as possible. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import logging\n",
"\n",
"# uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define our model. It is a humble model with but a conv layer and a $ReLU$ non-linearity, but it is a model nonetheless"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"# Defines the model\n",
"# we got convs, we got relu, \n",
"# What else could one want ????\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
"\n",
" self.conv1 = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=5, stride=4)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = self.relu(x)\n",
"\n",
" return x\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# this is where you'd train your model\n",
"\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We omit training for purposes of this demonstration. We've marked where training would happen in the cell above. \n",
"Now we export the model to onnx and create a corresponding (randomly generated) input file.\n",
"\n",
"You can replace the random `x` with real data if you so wish. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = 0.1*torch.rand(1,*[3, 8, 8], requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"network.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump( data, open(\"input.json\", 'w' ))\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This is where the magic happens. We define our `PyRunArgs` objects which contains the visibility parameters for out model. \n",
"- `input_visibility` defines the visibility of the model inputs\n",
"- `param_visibility` defines the visibility of the model weights and constants and parameters \n",
"- `output_visibility` defines the visibility of the model outputs\n",
"\n",
"There are currently 4 visibility settings:\n",
"- `public`: known to both the verifier and prover (a subtle nuance is that this may not be the case for model parameters but until we have more rigorous theoretical results we don't want to make strong claims as to this). \n",
"- `private`: known only to the prover\n",
"- `hashed`: the hash pre-image is known to the prover, the prover and verifier know the hash. The prover proves that the they know the pre-image to the hash. \n",
"- `encrypted`: the non-encrypted element and the secret key used for decryption are known to the prover. The prover and the verifier know the encrypted element, the public key used to encrypt, and the hash of the decryption hey. The prover proves that they know the pre-image of the hashed decryption key and that this key can in fact decrypt the encrypted message.\n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"hashed\"\n",
"- `param_visibility`: \"hashed\"\n",
"- `output_visibility`: public\n",
"\n",
"We encourage you to play around with other setups :) \n",
"\n",
"Shoutouts: \n",
"\n",
"- [summa-solvency](https://github.com/summa-dev/summa-solvency) for their help with the poseidon hashing chip. \n",
"- [timeofey](https://github.com/timoftime) for providing inspiration in our developement of the el-gamal encryption circuit in Halo2. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"data_path = os.path.join('input.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"hashed\"\n",
"run_args.param_visibility = \"hashed\"\n",
"run_args.output_visibility = \"public\"\n",
"\n",
"\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a settings file. This file basically instantiates a bunch of parameters that determine their circuit shape, size etc... Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file as some data can fall out of range of these lookups.\n",
"\n",
"You can pass a dataset for calibration that will be representative of real inputs you might find if and when you deploy the prover. Here we create a dummy calibration dataset for demonstration purposes. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# generate a bunch of dummy calibration data\n",
"cal_data = {\n",
" \"input_data\": [(0.1*torch.rand(2, *[3, 8, 8])).flatten().tolist()],\n",
"}\n",
"\n",
"cal_path = os.path.join('val_data.json')\n",
"# save as json file\n",
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
"\n",
"These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.get_srs(srs_path, settings_path)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now need to generate the (partial) circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!export RUST_BACKTRACE=1\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, model_path, witness_path, settings_path = settings_path)"
]
},
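{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# optional sanity check (a sketch; the exact witness schema may vary across ezkl versions):\n",
"# pretty-print the start of the witness file, which should include the input/param hashes\n",
"with open(witness_path, 'r') as f:\n",
"    print(json.dumps(json.load(f), indent=2)[:1000])"
]
},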
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check you can \"mock prove\" (i.e check that all the constraints of the circuit match without generate a full proof). "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"res = ezkl.mock(witness_path, model_path, settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"res = ezkl.setup(\n",
" model_path,\n",
" vk_path,\n",
" pk_path,\n",
" srs_path,\n",
" settings_path,\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a full proof. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" model_path,\n",
" pk_path,\n",
" proof_path,\n",
" srs_path,\n",
" \"evm\",\n",
" \"single\",\n",
" settings_path,\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"And verify it as a sanity check. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" srs_path,\n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now create an EVM / `.sol` verifier that can be deployed on chain to verify submitted proofs using a view function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
" vk_path,\n",
" srs_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Verify on the evm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Make sure anvil is running locally first\n",
"# run with $ anvil -p 3030\n",
"# we use the default anvil node here\n",
"import json\n",
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
"\n",
"with open(address_path, 'r') as file:\n",
" addr = file.read().rstrip()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"res = ezkl.verify_evm(\n",
" proof_path,\n",
" addr,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ezkl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,258 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## EZKL Jupyter Notebook Demo \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import torch\n",
"\n",
"\n",
"# Defines the model\n",
"# we got convs, we got relu, we got linear layers\n",
"# What else could one want ????\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
"\n",
" self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=5, stride=2)\n",
" self.conv2 = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=5, stride=2)\n",
"\n",
" self.relu = nn.ReLU()\n",
"\n",
" self.d1 = nn.Linear(48, 48)\n",
" self.d2 = nn.Linear(48, 10)\n",
"\n",
" def forward(self, x):\n",
" # 32x1x28x28 => 32x32x26x26\n",
" x = self.conv1(x)\n",
" x = self.relu(x)\n",
" x = self.conv2(x)\n",
" x = self.relu(x)\n",
"\n",
" # flatten => 32 x (32*26*26)\n",
" x = x.flatten(start_dim = 1)\n",
"\n",
" # 32 x (32*26*26) => 32x128\n",
" x = self.d1(x)\n",
" x = self.relu(x)\n",
"\n",
" # logits => 32x10\n",
" logits = self.d2(x)\n",
"\n",
" return logits\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# Train the model as you like here (skipped for brevity)\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# After training, export to onnx (network.onnx) and create a data file (input.json)\n",
"x = 0.1*torch.rand(1,*[1, 28, 28], requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" model_path, # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump( data, open(data_path, 'w' ))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.gen_srs(srs_path, 17)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, model_path, witness_path, settings_path = settings_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" model_path,\n",
" vk_path,\n",
" pk_path,\n",
" srs_path,\n",
" settings_path,\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" model_path,\n",
" pk_path,\n",
" proof_path,\n",
" srs_path,\n",
" \"evm\",\n",
" \"single\",\n",
" settings_path,\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" srs_path,\n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,29 @@
# download the voice datasets (TESS, RAVDESS, CREMA-D, SAVEE)
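# usage: sh voice_data.sh [target_dir]   (requires an authenticated kaggle CLI)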
# check if first argument has been set
if [ ! -z "$1" ]; then
DATA_DIR=$1/data
else
DATA_DIR=data
fi
echo "Downloading data to $DATA_DIR"
if [ ! -d "$DATA_DIR/TESS" ]; then
kaggle datasets download ejlok1/toronto-emotional-speech-set-tess -p $DATA_DIR --unzip
mv "$DATA_DIR/TESS Toronto emotional speech set data" $DATA_DIR/TESS
rm -r "$DATA_DIR/TESS/TESS Toronto emotional speech set data"
fi
if [ ! -d "$DATA_DIR/RAVDESS_SONG" ]; then
kaggle datasets download uwrfkaggler/ravdess-emotional-song-audio -p $DATA_DIR/RAVDESS_SONG --unzip
fi
if [ ! -d "$DATA_DIR/RAVDESS_SPEECH" ]; then
kaggle datasets download uwrfkaggler/ravdess-emotional-speech-audio -p $DATA_DIR/RAVDESS_SPEECH --unzip
fi
if [ ! -d "$DATA_DIR/CREMA-D" ]; then
kaggle datasets download ejlok1/cremad -p $DATA_DIR --unzip
mv $DATA_DIR/AudioWAV $DATA_DIR/CREMA-D
fi
if [ ! -d "$DATA_DIR/SAVEE" ]; then
kaggle datasets download ejlok1/surrey-audiovisual-expressed-emotion-savee -p $DATA_DIR --unzip
mv $DATA_DIR/ALL $DATA_DIR/SAVEE
fi

View File

@@ -0,0 +1,869 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Voice judgoor\n",
"\n",
"Here we showcase a full-end-to-end flow of:\n",
"1. training a model for a specific task (judging voices)\n",
"2. creating a proof of judgment\n",
"3. creating and deploying and evm verifier\n",
"4. verifying the proof of judgment using the verifier\n",
"\n",
"First we download a few voice related datasets from kaggle, which are all labelled using the same emotion and tone labelling standard.\n",
"\n",
"We have 8 emotions in both speaking and singing datasets: neutral, calm, happy, sad, angry, feat, disgust, surprise.\n",
"\n",
"To download the dataset make sure you have the kaggle cli installed in your local env `pip install kaggle`. Make sure you set up your `kaggle.json` file as detailed [here](https://www.kaggle.com/docs/api#getting-started-installation-&-authentication).\n",
"Then run the associated `voice_data.sh` data download script: `sh voice_data.sh`.\n",
"\n",
"Make sure you set the `VOICE_DATA_DIR` variables to point to the directory the `voice_data.sh` script has downloaded to. This script also accepts an argument to download to a specific directory: `sh voice_data.sh /path/to/voice/data`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"VOICE_DATA_DIR\"] = \"../..\"\n",
"\n",
"import os\n",
"voice_data_dir = os.environ.get('VOICE_DATA_DIR')\n",
"\n",
"# if is none set to \"\" \n",
"if voice_data_dir is None:\n",
" voice_data_dir = \"\"\n",
" \n",
"print(\"voice_data_dir: \", voice_data_dir)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### TESS Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import pandas as pd\n",
"import logging\n",
"\n",
"# read in VOICE_DATA_DIR from environment variable\n",
"\n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.INFO)\n",
"\n",
"\n",
"Tess = os.path.join(voice_data_dir, \"data/TESS/\")\n",
"\n",
"tess = os.listdir(Tess)\n",
"\n",
"emotions = []\n",
"files = []\n",
"\n",
"for item in tess:\n",
" items = os.listdir(Tess + item)\n",
" for file in items:\n",
" part = file.split('.')[0]\n",
" part = part.split('_')[2]\n",
" if part == 'ps':\n",
" emotions.append('surprise')\n",
" else:\n",
" emotions.append(part)\n",
" files.append(Tess + item + '/' + file)\n",
"\n",
"tess_df = pd.concat([pd.DataFrame(emotions, columns=['Emotions']), pd.DataFrame(files, columns=['Files'])], axis=1)\n",
"tess_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### RAVDESS SONG dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Ravdess = os.path.join(voice_data_dir, \"data/RAVDESS_SONG/audio_song_actors_01-24/\")\n",
"\n",
"ravdess_list = os.listdir(Ravdess)\n",
"\n",
"files = []\n",
"emotions = []\n",
"\n",
"for item in ravdess_list:\n",
" actor = os.listdir(Ravdess + item)\n",
" for file in actor:\n",
" name = file.split('.')[0]\n",
" parts = name.split('-')\n",
" emotions.append(int(parts[2]))\n",
" files.append(Ravdess + item + '/' + file)\n",
"\n",
"emotion_data = pd.DataFrame(emotions, columns=['Emotions'])\n",
"files_data = pd.DataFrame(files, columns=['Files'])\n",
"\n",
"ravdess_song_df = pd.concat([emotion_data, files_data], axis=1)\n",
"\n",
"ravdess_song_df.Emotions.replace({1:'neutral', 2:'calm', 3:'happy', 4:'sad', 5:'angry', 6:'fear', 7:'disgust', 8:'surprise'}, inplace=True)\n",
"\n",
"ravdess_song_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### RAVDESS SPEECH Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Ravdess = os.path.join(voice_data_dir, \"data/RAVDESS_SPEECH/audio_speech_actors_01-24/\")\n",
"\n",
"ravdess_list = os.listdir(Ravdess)\n",
"\n",
"files = []\n",
"emotions = []\n",
"\n",
"for item in ravdess_list:\n",
" actor = os.listdir(Ravdess + item)\n",
" for file in actor:\n",
" name = file.split('.')[0]\n",
" parts = name.split('-')\n",
" emotions.append(int(parts[2]))\n",
" files.append(Ravdess + item + '/' + file)\n",
" \n",
"emotion_data = pd.DataFrame(emotions, columns=['Emotions'])\n",
"files_data = pd.DataFrame(files, columns=['Files'])\n",
"\n",
"ravdess_df = pd.concat([emotion_data, files_data], axis=1)\n",
"\n",
"ravdess_df.Emotions.replace({1:'neutral', 2:'calm', 3:'happy', 4:'sad', 5:'angry', 6:'fear', 7:'disgust', 8:'surprise'}, inplace=True)\n",
"\n",
"ravdess_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### CREMA Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Crema = os.path.join(voice_data_dir, \"data/CREMA-D/\")\n",
"\n",
"crema = os.listdir(Crema)\n",
"emotions = []\n",
"files = []\n",
"\n",
"for item in crema:\n",
" files.append(Crema + item)\n",
" \n",
" parts = item.split('_')\n",
" if parts[2] == 'SAD':\n",
" emotions.append('sad')\n",
" elif parts[2] == 'ANG':\n",
" emotions.append('angry')\n",
" elif parts[2] == 'DIS':\n",
" emotions.append('disgust')\n",
" elif parts[2] == 'FEA':\n",
" emotions.append('fear')\n",
" elif parts[2] == 'HAP':\n",
" emotions.append('happy')\n",
" elif parts[2] == 'NEU':\n",
" emotions.append('neutral')\n",
" else :\n",
" emotions.append('unknown')\n",
" \n",
"emotions_data = pd.DataFrame(emotions, columns=['Emotions'])\n",
"files_data = pd.DataFrame(files, columns=['Files'])\n",
"\n",
"crema_df = pd.concat([emotions_data, files_data], axis=1)\n",
"\n",
"crema_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SAVEE Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Savee = os.path.join(voice_data_dir,\"data/SAVEE/\")\n",
"\n",
"savee = os.listdir(Savee)\n",
"\n",
"emotions = []\n",
"files = []\n",
"\n",
"for item in savee:\n",
" files.append(Savee + item)\n",
" part = item.split('_')[1]\n",
" ele = part[:-6]\n",
" if ele == 'a':\n",
" emotions.append('angry')\n",
" elif ele == 'd':\n",
" emotions.append('disgust')\n",
" elif ele == 'f':\n",
" emotions.append('fear')\n",
" elif ele == 'h':\n",
" emotions.append('happy')\n",
" elif ele == 'n':\n",
" emotions.append('neutral')\n",
" elif ele == 'sa':\n",
" emotions.append('sad')\n",
" else:\n",
" emotions.append('surprise')\n",
"\n",
"savee_df = pd.concat([pd.DataFrame(emotions, columns=['Emotions']), pd.DataFrame(files, columns=['Files'])], axis=1)\n",
"savee_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Combining all datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = pd.concat([ravdess_df, ravdess_song_df, crema_df, tess_df, savee_df], axis = 0)\n",
"# relabel indices\n",
"df.index = range(len(df.index))\n",
"df.to_csv(\"df.csv\",index=False)\n",
"df\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import seaborn as sns\n",
"sns.histplot(data=df, x=\"Emotions\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training \n",
"\n",
"Here we convert all audio files into 2D frequency-domain spectrograms so that we can leverage convolutional neural networks, which tend to be more efficient than time-series model like RNNs or LSTMs.\n",
"We thus: \n",
"1. Extract the mel spectrogram from each of the audio recordings. \n",
"2. Rescale each of these to the decibel (DB) scale. \n",
"3. Define the model as the following model: `(x) -> (conv) -> (relu) -> (linear) -> (y)`\n",
"\n",
"\n",
"You may notice that we introduce a second computational graph `(key) -> (key)`. The reasons for this are to do with MEV, and if you are not interested you can skip the following paragraph. \n",
"\n",
"Let's say that obtaining a high score from the judge and then submitting said score to the EVM verifier could result in the issuance of a reward (financial or otherwise). There is an incentive then for MEV bots to scalp any issued valid proof and submit a duplicate transaction with the same proof to the verifier contract in the hopes of obtaining the reward before the original issuer. Here we add `(key) -> (key)` such that the transaction creator's public key / address is both a private input AND a public input to the proof. As such the on-chain verification only succeeds if the key passed in during proof time is also passed in as a public input to the contract. The reward issued by the contract can then be irrevocably tied to that key such that even if the proof is submitted by another actor, the reward would STILL go to the original singer / transaction issuer. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"import librosa\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"#stft extraction from augmented data\n",
"def extract_mel_spec(filename):\n",
" x,sr=librosa.load(filename,duration=3,offset=0.5)\n",
" X = librosa.feature.melspectrogram(y=x, sr=sr)\n",
" Xdb = librosa.power_to_db(X, ref=np.max)\n",
" Xdb = Xdb.reshape(1,128,-1)\n",
" return Xdb\n",
"\n",
"Xdb=df.iloc[:,1].apply(lambda x: extract_mel_spec(x))\n"
]
},
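{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# optional: visualize one mel spectrogram to sanity-check the extracted features\n",
"plt.imshow(Xdb.iloc[0][0], aspect='auto', origin='lower')\n",
"plt.title(f\"{df.iloc[0, 0]}\")\n",
"plt.colorbar()\n",
"plt.show()"
]
},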
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we convert label to a number between 0 and 1 where 1 is pleasant surprised and 0 is disgust and the rest are floats in between. The model loves pleasantly surprised voices and hates disgust ;) "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get max size\n",
"max_size = 0\n",
"for i in range(len(Xdb)):\n",
" if Xdb[i].shape[2] > max_size:\n",
" max_size = Xdb[i].shape[2]\n",
"\n",
"# 0 pad 2nd dim to max size\n",
"Xdb=Xdb.apply(lambda x: np.pad(x,((0,0),(0,0),(0,max_size-x.shape[2]))))\n",
"\n",
"Xdb=pd.DataFrame(Xdb)\n",
"Xdb['label'] = df['Emotions']\n",
"# convert label to a number between 0 and 1 where 1 is pleasant surprised and 0 is disgust and the rest are floats in betwee\n",
"Xdb['label'] = Xdb['label'].apply(lambda x: 1 if x=='surprise' else 0 if x=='disgust' else 0.2 if x=='fear' else 0.4 if x=='happy' else 0.6 if x=='sad' else 0.8)\n",
"\n",
"Xdb.iloc[0,0][0].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"# Defines the model\n",
"# we got convs, we got relu, we got linear layers\n",
"# What else could one want ????\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
"\n",
" self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=5, stride=4)\n",
"\n",
" self.d1 = nn.Linear(992, 1)\n",
"\n",
" self.sigmoid = nn.Sigmoid()\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, key, x):\n",
" # 32x1x28x28 => 32x32x26x26\n",
" x = self.conv1(x)\n",
" x = self.relu(x)\n",
" x = x.flatten(start_dim=1)\n",
" x = self.d1(x)\n",
" x = self.sigmoid(x)\n",
"\n",
" return [key, x]\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"output = circuit(0, torch.tensor(Xdb.iloc[0,0][0].reshape(1,1,128,130)))\n",
"\n",
"output\n",
"\n",
"\n",
"\n"
]
},
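{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sanity check (illustrative): the key passes through the graph untouched,\n",
"# which is what lets it act as a public instance tied to the prover (see the MEV note above)\n",
"key_check = torch.tensor([42.0])\n",
"out = circuit(key_check, torch.tensor(Xdb.iloc[0,0][0].reshape(1,1,128,130)))\n",
"assert torch.equal(out[0], key_check)"
]
},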
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we leverage the classic Adam optimizer, coupled with 0.001 weight decay so as to regularize the model. The weight decay (a.k.a L2 regularization) can also help on the zk-circuit end of things in that it prevents inputs to Halo2 lookup tables from falling out of range (lookup tables are how we represent non-linearities like ReLU and Sigmoid inside our circuits). "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"\n",
"# Train the model using pytorch\n",
"n_epochs = 10 # number of epochs to run\n",
"batch_size = 10 # size of each batch\n",
"\n",
"\n",
"loss_fn = nn.MSELoss() #MSE\n",
"# adds l2 regularization\n",
"optimizer = torch.optim.Adam(circuit.parameters(), lr=0.001, weight_decay=0.001)\n",
"\n",
"# randomly shuffle dataset\n",
"Xdb = Xdb.sample(frac=1).reset_index(drop=True)\n",
"\n",
"# split into train and test and validation sets with 80% train, 10% test, 10% validation\n",
"train = Xdb.iloc[:int(len(Xdb)*0.8)]\n",
"test = Xdb.iloc[int(len(Xdb)*0.8):int(len(Xdb)*0.9)]\n",
"val = Xdb.iloc[int(len(Xdb)*0.9):]\n",
"\n",
"batches_per_epoch = len(train)\n",
"\n",
"\n",
"def get_loss(Xbatch, ybatch):\n",
" y_pred = circuit(0, Xbatch)[1]\n",
" loss = loss_fn(y_pred, ybatch)\n",
" return loss\n",
"\n",
"for epoch in range(n_epochs):\n",
" # X is a torch Variable\n",
" permutation = torch.randperm(batches_per_epoch)\n",
"\n",
" with tqdm(range(batches_per_epoch), unit=\"batch\", mininterval=0) as bar:\n",
" bar.set_description(f\"Epoch {epoch}\")\n",
" for i in bar:\n",
" start = i * batch_size\n",
" # take a batch\n",
" indices = np.random.choice(batches_per_epoch, batch_size)\n",
"\n",
" data = np.concatenate(train.iloc[indices.tolist(),0].values)\n",
" labels = train.iloc[indices.tolist(),1].values.astype(np.float32)\n",
"\n",
" data = data.reshape(batch_size,1,128,130)\n",
" labels = labels.reshape(batch_size,1)\n",
"\n",
" # convert to tensors\n",
" Xbatch = torch.tensor(data)\n",
" ybatch = torch.tensor(labels)\n",
"\n",
" # forward pass\n",
" loss = get_loss(Xbatch, ybatch)\n",
" # backward pass\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" # update weights\n",
" optimizer.step()\n",
"\n",
" bar.set_postfix(\n",
" batch_loss=float(loss),\n",
" )\n",
" # get validation loss\n",
" val_data = np.concatenate(val.iloc[:,0].values)\n",
" val_labels = val.iloc[:,1].values.astype(np.float32)\n",
" val_data = val_data.reshape(len(val),1,128,130)\n",
" val_labels = val_labels.reshape(len(val),1)\n",
" val_loss = get_loss(torch.tensor(val_data), torch.tensor(val_labels))\n",
" print(f\"Validation loss: {val_loss}\")\n",
"\n",
"\n",
"\n",
"# get validation loss\n",
"test_data = np.concatenate(test.iloc[:,0].values)\n",
"test_labels = val.iloc[:,1].values.astype(np.float32)\n",
"test_data = val_data.reshape(len(val),1,128,130)\n",
"test_labels = val_labels.reshape(len(val),1)\n",
"test_loss = get_loss(torch.tensor(test_data), torch.tensor(test_labels))\n",
"print(f\"Test loss: {test_loss}\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#\n",
"val_data = {\n",
" \"input_data\": [np.zeros(100).tolist(), np.concatenate(val.iloc[:100,0].values).flatten().tolist()],\n",
"}\n",
"# save as json file\n",
"with open(\"val_data.json\", \"w\") as f:\n",
" json.dump(val_data, f)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = 0.1*torch.rand(1,*[1, 128, 130], requires_grad=True)\n",
"key = torch.rand(1,*[1], requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" (key, x), # model input (or a tuple for multiple inputs)\n",
" \"network.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'},\n",
" 'input.1' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"key_array = ((key).detach().numpy()).reshape([-1]).tolist()\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [key_array, data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump( data, open(\"input.json\", 'w' ))\n",
"\n",
"\n",
"# ezkl.export(circuit, input_shape = [[1], [1025, 130]], run_gen_witness=False, run_calibrate_settings=False)"
]
},
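{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before handing the model to ezkl, it is worth a quick structural sanity check of the exported graph. The sketch below uses the `onnx` package (assumed to be installed alongside the other dependencies) to validate the file we just wrote and print its input and output names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# structural sanity check of the exported graph (requires the onnx package)\n",
"import onnx\n",
"\n",
"onnx_model = onnx.load(\"network.onnx\")\n",
"onnx.checker.check_model(onnx_model)\n",
"print([i.name for i in onnx_model.graph.input], \"->\", [o.name for o in onnx_model.graph.output])\n"
]
},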
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we set the visibility of the different parts of the circuit, whereby the model params and the outputs of the computational graph (the key and the judgment) are public"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"import os \n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.params')\n",
"data_path = os.path.join('input.json')\n",
"val_data = os.path.join('val_data.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"private\"\n",
"run_args.param_visibility = \"public\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.batch_size = 1\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a settings file. This file basically instantiates a bunch of parameters that determine their circuit shape, size etc... Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file as some data can fall out of range of these lookups.\n",
"\n",
"You can pass a dataset for calibration that will be representative of real inputs you might find if and when you deploy the prover. Here we use the validation dataset we used during training. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True\n",
"\n",
"res = await ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\")\n",
"assert res == True\n",
"print(\"verified\")\n"
]
},
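{
"cell_type": "markdown",
"metadata": {},
"source": [
"It can be helpful to see what calibration actually produced. The snippet below just pretty-prints the generated `settings.json`; the exact keys it contains may vary between ezkl versions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# inspect the calibrated settings (exact keys may vary between ezkl versions)\n",
"import json\n",
"\n",
"with open(settings_path) as f:\n",
"    settings = json.load(f)\n",
"print(json.dumps(settings, indent=2))\n"
]
},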
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
"\n",
"These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.get_srs(srs_path, settings_path)\n",
"\n",
"assert os.path.exists(srs_path)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now need to generate the (partial) circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, model_path, witness_path, settings_path = settings_path)\n",
"assert os.path.isfile(witness_path)"
]
},
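{
"cell_type": "markdown",
"metadata": {},
"source": [
"As with the settings, we can peek at the witness file. It is plain JSON, though its exact schema depends on the ezkl version."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# peek at the generated witness (schema depends on the ezkl version)\n",
"import json\n",
"\n",
"with open(witness_path) as f:\n",
"    witness = json.load(f)\n",
"print(list(witness.keys()))\n"
]
},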
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check we can run a mock proof. This just checks that all the constraints are valid. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"res = ezkl.mock(witness_path, model_path, settings_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!export RUST_LOG=trace\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"res = ezkl.setup(\n",
" model_path,\n",
" vk_path,\n",
" pk_path,\n",
" srs_path,\n",
" settings_path,\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a full proof. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" model_path,\n",
" pk_path,\n",
" proof_path,\n",
" srs_path,\n",
" \"evm\",\n",
" \"single\",\n",
" settings_path,\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And verify it as a sanity check. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" srs_path,\n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"vk_path = os.path.join('test.vk')\n",
"srs_path = os.path.join('kzg.params')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
" vk_path,\n",
" srs_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"\n",
"assert res == True"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Verify if the Verifier Works Locally"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Deploy The Contract"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Make sure anvil is running locally first\n",
"# run with $ anvil -p 3030\n",
"# we use the default anvil node here\n",
"import json\n",
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
"\n",
"with open(address_path, 'r') as file:\n",
" addr = file.read().rstrip()"
]
},
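{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional check that the deployment landed, we can ask the node for the bytecode at the returned address using `web3` (assumed installed, with a web3.py v6-style API); a non-empty result means the verifier contract is live."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# optional: confirm bytecode exists at the deployed address (requires web3)\n",
"from web3 import Web3\n",
"\n",
"w3 = Web3(Web3.HTTPProvider('http://127.0.0.1:3030'))\n",
"code = w3.eth.get_code(Web3.to_checksum_address(addr))\n",
"assert len(code) > 0\n",
"print(f\"verifier deployed with {len(code)} bytes of code\")\n"
]
},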
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"res = ezkl.verify_evm(\n",
" proof_path,\n",
" addr,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1083,10 +1083,7 @@ mod native_tests {
let vk_arg = format!("{}/{}/evm_aggr.vk", test_dir, example_name);
-fn build_args<'a>(
-    base_args: Vec<&'a str>,
-    sol_arg: &'a str
-) -> Vec<&'a str> {
+fn build_args<'a>(base_args: Vec<&'a str>, sol_arg: &'a str) -> Vec<&'a str> {
let mut args = base_args;
args.push("--sol-code-path");
@@ -1593,7 +1590,6 @@ mod native_tests {
let vk_arg = format!("{}/{}/key.vk", test_dir, example_name);
let sol_arg = format!("{}/{}/kzg.sol", test_dir, example_name);
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
@@ -1606,7 +1602,7 @@ mod native_tests {
"--vk-path",
&vk_arg,
"-D",
-test_on_chain_data_path.as_str()
+test_on_chain_data_path.as_str(),
])
.status()
.expect("failed to execute process");

View File

@@ -0,0 +1,172 @@
#[cfg(not(target_arch = "wasm32"))]
#[cfg(test)]
mod py_tests {
use lazy_static::lazy_static;
use std::env::var;
use std::process::Command;
use std::sync::Once;
use tempdir::TempDir;
static COMPILE: Once = Once::new();
static START_ANVIL: Once = Once::new();
static ENV_SETUP: Once = Once::new();
static DOWNLOAD_VOICE_DATA: Once = Once::new();
// make sure these are only initialized once
lazy_static! {
static ref CARGO_TARGET_DIR: String =
var("CARGO_TARGET_DIR").unwrap_or_else(|_| "./target".to_string());
static ref TEST_DIR: TempDir = TempDir::new("example").unwrap();
static ref ANVIL_URL: String = "http://localhost:3030".to_string();
}
fn start_anvil() {
START_ANVIL.call_once(|| {
let _ = Command::new("anvil")
.args(["-p", "3030"])
// .stdout(Stdio::piped())
.spawn()
.expect("failed to start anvil process");
std::thread::sleep(std::time::Duration::from_secs(3));
});
}
fn download_voice_data() {
DOWNLOAD_VOICE_DATA.call_once(|| {
let status = Command::new("bash")
.args([
"examples/notebooks/voice_data.sh",
TEST_DIR.path().to_str().unwrap(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
});
// set VOICE_DATA_DIR environment variable
std::env::set_var("VOICE_DATA_DIR", TEST_DIR.path().to_str().unwrap());
}
fn setup_py_env() {
ENV_SETUP.call_once(|| {
// assumes a virtualenv called .env exists and that the equivalent of the
// following has already been run:
//   python -m venv .env
//   source .env/bin/activate
//   pip install -r requirements.txt
//   maturin develop --release --features python-bindings
// now install the extra python deps the notebooks need
let status = Command::new("pip")
.args([
"install",
"torch",
"pandas",
"numpy",
"seaborn",
"jupyter",
"onnx",
"kaggle",
"py-solc-x",
"web3",
"librosa",
])
.status()
.expect("failed to execute process");
assert!(status.success());
});
}
fn init_binary() {
COMPILE.call_once(|| {
println!("using cargo target dir: {}", *CARGO_TARGET_DIR);
setup_py_env();
});
}
fn mv_test_(test: &str) {
let test_dir = TEST_DIR.path().to_str().unwrap();
let path: std::path::PathBuf = format!("{}/{}", test_dir, test).into();
if !path.exists() {
let status = Command::new("cp")
.args([
"-R",
&format!("./examples/notebooks/{}", test),
&format!("{}/{}", test_dir, test),
])
.status()
.expect("failed to execute process");
assert!(status.success());
}
}
const TESTS: [&str; 4] = [
"hashed_vis.ipynb",
"simple_demo.ipynb",
"data_attest.ipynb",
"variance.ipynb",
];
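// Generates one #[test_case] per notebook in TESTS (via seq!), plus a
// dedicated test for the voice tutorial notebook.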
macro_rules! test_func {
() => {
#[cfg(test)]
mod tests {
use seq_macro::seq;
use crate::py_tests::TESTS;
use test_case::test_case;
use super::*;
seq!(N in 0..=3 {
#(#[test_case(TESTS[N])])*
fn run_notebook_(test: &str) {
crate::py_tests::init_binary();
crate::py_tests::start_anvil();
crate::py_tests::mv_test_(test);
run_notebook(test);
}
#[test]
fn voice_notebook_() {
crate::py_tests::init_binary();
crate::py_tests::start_anvil();
crate::py_tests::download_voice_data();
crate::py_tests::mv_test_("voice_judge.ipynb");
run_notebook("voice_judge.ipynb");
}
});
}
};
}
fn run_notebook(test: &str) {
// sourcing the venv in a child shell cannot change this process's
// environment; the venv must already be active (as it is in CI), so this
// only checks that the venv exists and activates cleanly
let status = Command::new("bash")
.arg("-c")
.arg("source .env/bin/activate")
.status()
.expect("failed to execute process");
assert!(status.success());
let test_dir = TEST_DIR.path().to_str().unwrap();
let path: std::path::PathBuf = format!("{}/{}", test_dir, test).into();
let status = Command::new("jupyter")
.args([
"nbconvert",
"--to",
"notebook",
"--execute",
path.to_str().unwrap(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
}
test_func!();
}