Compare commits

...

5 Commits

Author SHA1 Message Date
github-actions[bot]
85b543dd01 ci: update version string in docs 2024-04-15 16:01:39 +00:00
dante
88dd83dbe5 fix: default compiled model paths in python (#776) 2024-04-15 12:01:21 -04:00
Ethan Cemer
f05f83481e chore: update eth postgres (#769)
---------

Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com>
2024-04-13 08:08:09 -04:00
Ethan Cemer
8aaf518b5e fix: fix @ezkljs/verify etherumjs deps (#765) 2024-04-12 18:24:59 -04:00
katsumata
1b7b43e073 fix: Improve EZKL installation script reliability (#774) 2024-04-09 16:07:39 -04:00
14 changed files with 243 additions and 256 deletions

View File

@@ -1,4 +1,4 @@
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
name: Build and Publish EZKL Engine npm package
on:
workflow_dispatch:
@@ -62,7 +62,7 @@ jobs:
"web/ezkl_bg.wasm",
"web/ezkl.js",
"web/ezkl.d.ts",
"web/snippets/wasm-bindgen-rayon-7afa899f36665473/src/workerHelpers.js",
"web/snippets/**/*",
"web/package.json",
"web/utils.js",
"ezkl.d.ts"
@@ -79,6 +79,10 @@ jobs:
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" pkg/nodejs/ezkl.js
- name: Replace `import.meta.url` with `import.meta.resolve` definition in workerHelpers.js
run: |
find ./pkg/web/snippets -type f -name "*.js" -exec sed -i "s|import.meta.url|import.meta.resolve|" {} +
- name: Add serialize and deserialize methods to nodejs bundle
run: |
echo '
@@ -174,40 +178,3 @@ jobs:
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
in-browser-evm-ver-publish:
name: publish-in-browser-evm-verifier-package
needs: ["publish-wasm-bindings"]
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update @ezkljs/engine version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
run: |
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
- name: Set up Node.js
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
run: |
cd in-browser-evm-verifier
npm install
npm run build
npm ci
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

View File

@@ -307,8 +307,8 @@ jobs:
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- name: Install dependencies for js tests and in-browser-evm-verifier package
run: |
pnpm install --no-frozen-lockfile
pnpm install --dir ./in-browser-evm-verifier --no-frozen-lockfile
pnpm install --frozen-lockfile
pnpm install --dir ./in-browser-evm-verifier --frozen-lockfile
env:
CI: false
NODE_ENV: development
@@ -380,7 +380,7 @@ jobs:
cache: "pnpm"
- name: Install dependencies for js tests
run: |
pnpm install --no-frozen-lockfile
pnpm install --frozen-lockfile
env:
CI: false
NODE_ENV: development
@@ -610,6 +610,24 @@ jobs:
python-integration-tests:
runs-on: large-self-hosted
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres
env:
POSTGRES_USER: ubuntu
POSTGRES_HOST_AUTH_METHOD: trust
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
-v /var/run/postgresql:/var/run/postgresql
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
@@ -634,6 +652,8 @@ jobs:
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Postgres tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
# - name: authenticate-kaggle-cli
@@ -651,5 +671,3 @@ jobs:
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
# - name: Postgres tutorials
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1

54
.github/workflows/verify.yml vendored Normal file
View File

@@ -0,0 +1,54 @@
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
on:
workflow_dispatch:
inputs:
tag:
description: "The tag to release"
required: true
push:
tags:
- "*"
defaults:
run:
working-directory: .
jobs:
in-browser-evm-ver-publish:
name: publish-in-browser-evm-verifier-package
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update @ezkljs/engine version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
run: |
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
- name: Use pnpm 8
uses: pnpm/action-setup@v2
with:
version: 8
- name: Set up Node.js
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
run: |
cd in-browser-evm-verifier
pnpm install --frozen-lockfile
pnpm run build
pnpm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

2
.gitignore vendored
View File

@@ -50,4 +50,4 @@ timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key
docs/python/build
!tests/wasm/vk_aggr.key
!tests/wasm/vk_aggr.key

View File

@@ -1,4 +1,4 @@
ezkl==0.0.0
ezkl==10.3.4
sphinx
sphinx-rtd-theme
sphinxcontrib-napoleon

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '0.0.0'
release = '10.3.4'
version = release

View File

@@ -7,9 +7,9 @@
"## Mean of ERC20 transfer amounts\n",
"\n",
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
"The first of which is [e2pg](https://github.com/indexsupply/x/tree/main/docs/e2pg), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"\n",
"Make sure you install postgres if needed https://postgresapp.com/. \n",
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
"\n"
]
},
@@ -21,23 +21,81 @@
"source": [
"import os\n",
"import getpass\n",
"\n",
"import json\n",
"import time\n",
"import subprocess\n",
"\n",
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
"os.system(\"curl -LO https://indexsupply.net/bin/main/linux/amd64/e2pg\")\n",
"os.system(\"chmod +x e2pg\")\n",
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
"os.system(\"chmod +x shovel\")\n",
"\n",
"\n",
"os.environ[\"PG_URL\"] = \"postgresql://\" + getpass.getuser() + \":@localhost:5432/e2pg\"\n",
"os.environ[\"RLPS_URL\"] = \"https://1.rlps.indexsupply.net\"\n",
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
"\n",
"# create a config.json file with the following contents\n",
"config = {\n",
" \"pg_url\": \"$PG_URL\",\n",
" \"eth_sources\": [\n",
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
" ],\n",
" \"integrations\": [{\n",
" \"name\": \"usdc_transfer\",\n",
" \"enabled\": True,\n",
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
" \"table\": {\n",
" \"name\": \"usdc\",\n",
" \"columns\": [\n",
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
" ]\n",
" },\n",
" \"block\": [\n",
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
" {\n",
" \"name\": \"log_addr\",\n",
" \"column\": \"log_addr\",\n",
" \"filter_op\": \"contains\",\n",
" \"filter_arg\": [\n",
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
" ]\n",
" }\n",
" ],\n",
" \"event\": {\n",
" \"name\": \"Transfer\",\n",
" \"type\": \"event\",\n",
" \"anonymous\": False,\n",
" \"inputs\": [\n",
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
" ]\n",
" }\n",
" }]\n",
"}\n",
"\n",
"# write the config to a file\n",
"with open(\"config.json\", \"w\") as f:\n",
" f.write(json.dumps(config))\n",
"\n",
"\n",
"# print the two env variables\n",
"os.system(\"echo $PG_URL\")\n",
"os.system(\"echo $RLPS_URL\")\n",
"\n",
"os.system(\"createdb -h localhost -p 5432 e2pg\")\n",
"# equivalent of nohup ./e2pg -reset -e $RLPS_URL -pg $PG_URL &\n",
"e2pg_process = os.system(\"nohup ./e2pg -e $RLPS_URL -pg $PG_URL &\")\n",
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
"\n",
"os.system(\"echo shovel is now installed. starting:\")\n",
"\n",
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
"subprocess.Popen(command)\n",
"\n",
"os.system(\"echo shovel started.\")\n",
"\n",
"time.sleep(5)\n",
"\n"
]
},
@@ -79,11 +137,13 @@
"import json\n",
"import os\n",
"\n",
"# import logging\n",
"# # # uncomment for more descriptive logging \n",
"# FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"# logging.basicConfig(format=FORMAT)\n",
"# logging.getLogger().setLevel(logging.DEBUG)"
"import logging\n",
"# # uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n",
"\n",
"print(\"ezkl version: \", ezkl.__version__)"
]
},
{
@@ -176,6 +236,7 @@
},
"outputs": [],
"source": [
"import getpass\n",
"# make an input.json file from the df above\n",
"input_filename = os.path.join('input.json')\n",
"\n",
@@ -183,9 +244,9 @@
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"e2pg\",\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT value FROM erc20_transfers ORDER BY block_number DESC LIMIT 5\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
@@ -194,7 +255,7 @@
"\n",
"\n",
" # Serialize data into file:\n",
"json.dump( pg_input_file, open(input_filename, 'w' ))\n"
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
]
},
{
@@ -210,9 +271,9 @@
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"e2pg\",\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT value FROM erc20_transfers ORDER BY block_number DESC LIMIT 20\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
@@ -229,22 +290,6 @@
"**EZKL Workflow**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"ezkl.gen_settings(onnx_filename, settings_filename)\n",
"\n",
"ezkl.calibrate_settings(\n",
" input_filename, onnx_filename, settings_filename, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -253,10 +298,21 @@
},
"outputs": [],
"source": [
"# setup kzg params\n",
"params_path = os.path.join('kzg.params')\n",
"import subprocess\n",
"import os\n",
"\n",
"res = ezkl.get_srs(params_path, settings_filename)"
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"# Generate settings using ezkl\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
"\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
"\n",
"assert res == True"
]
},
{
@@ -306,16 +362,13 @@
"source": [
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"params_path = os.path.join('kzg.params')\n",
"\n",
"\n",
"# setup the proof\n",
"res = ezkl.setup(\n",
" compiled_filename,\n",
" vk_path,\n",
" pk_path,\n",
" params_path,\n",
" settings_filename,\n",
" pk_path\n",
" )\n",
"\n",
"assert res == True\n",
@@ -331,11 +384,14 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n",
"assert os.path.isfile(witness_path)"
"# generate the witness\n",
"res = ezkl.gen_witness(\n",
" input_filename,\n",
" compiled_filename,\n",
" witness_path\n",
" )\n"
]
},
{
@@ -360,73 +416,14 @@
" compiled_filename,\n",
" pk_path,\n",
" proof_path,\n",
" params_path,\n",
" \"single\",\n",
" \"single\"\n",
" )\n",
"\n",
"\n",
"print(\"proved\")\n",
"\n",
"assert os.path.isfile(proof_path)\n",
"\n",
"# verify\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_filename,\n",
" vk_path,\n",
" params_path,\n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W7tAa-DFAtvS"
},
"source": [
"# Part 2 (Using the ZK Computational Graph Onchain!)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8Ym91kaVAIB6"
},
"source": [
"**Now How Do We Do It Onchain?????**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 339
},
"id": "fodkNgwS70FM",
"outputId": "827b5efd-f74f-44de-c114-861b3a86daf2"
},
"outputs": [],
"source": [
"# first we need to create evm verifier\n",
"print(vk_path)\n",
"print(params_path)\n",
"print(settings_filename)\n",
"\n",
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
" vk_path,\n",
" params_path,\n",
" settings_filename,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True"
"\n"
]
},
{
@@ -435,51 +432,8 @@
"metadata": {},
"outputs": [],
"source": [
"# Make sure anvil is running locally first\n",
"# run with $ anvil -p 3030\n",
"# we use the default anvil node here\n",
"import json\n",
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
"\n",
"with open(address_path, 'r') as file:\n",
" addr = file.read().rstrip()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# read the address from addr_path\n",
"addr = None\n",
"with open(address_path, 'r') as f:\n",
" addr = f.read()\n",
"\n",
"res = ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"os.system(\"killall -9 e2pg\");"
"# kill all shovel process \n",
"os.system(\"pkill -f shovel\")"
]
}
],
@@ -501,7 +455,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -20,16 +20,16 @@
"build": "npm run clean && npm run build:commonjs && npm run build:esm"
},
"dependencies": {
"@ethereumjs/common": "^4.0.0",
"@ethereumjs/evm": "^2.0.0",
"@ethereumjs/statemanager": "^2.0.0",
"@ethereumjs/tx": "^5.0.0",
"@ethereumjs/util": "^9.0.0",
"@ethereumjs/vm": "^7.0.0",
"@ethersproject/abi": "^5.7.0",
"@ethereumjs/common": "4.0.0",
"@ethereumjs/evm": "2.0.0",
"@ethereumjs/statemanager": "2.0.0",
"@ethereumjs/tx": "5.0.0",
"@ethereumjs/util": "9.0.0",
"@ethereumjs/vm": "7.0.0",
"@ethersproject/abi": "5.7.0",
"@ezkljs/engine": "^9.4.4",
"ethers": "^6.7.1",
"json-bigint": "^1.0.0"
"ethers": "6.7.1",
"json-bigint": "1.0.0"
},
"devDependencies": {
"@types/node": "^20.8.3",

View File

@@ -6,34 +6,34 @@ settings:
dependencies:
'@ethereumjs/common':
specifier: ^4.0.0
specifier: 4.0.0
version: 4.0.0
'@ethereumjs/evm':
specifier: ^2.0.0
specifier: 2.0.0
version: 2.0.0
'@ethereumjs/statemanager':
specifier: ^2.0.0
specifier: 2.0.0
version: 2.0.0
'@ethereumjs/tx':
specifier: ^5.0.0
specifier: 5.0.0
version: 5.0.0
'@ethereumjs/util':
specifier: ^9.0.0
specifier: 9.0.0
version: 9.0.0
'@ethereumjs/vm':
specifier: ^7.0.0
specifier: 7.0.0
version: 7.0.0
'@ethersproject/abi':
specifier: ^5.7.0
specifier: 5.7.0
version: 5.7.0
'@ezkljs/engine':
specifier: ^9.4.4
version: 9.4.4
ethers:
specifier: ^6.7.1
specifier: 6.7.1
version: 6.7.1
json-bigint:
specifier: ^1.0.0
specifier: 1.0.0
version: 1.0.0
devDependencies:

View File

@@ -36,7 +36,7 @@ if [ "$(which ezkl)s" != "s" ] && [ "$(which ezkl)" != "$EZKL_DIR/ezkl" ] ; the
exit 1
fi
if [[ ":$PATH:" != *":${EZKl_DIR}:"* ]]; then
if [[ ":$PATH:" != *":${EZKL_DIR}:"* ]]; then
# Add the ezkl directory to the path and ensure the old PATH variables remain.
echo >> $PROFILE && echo "export PATH=\"\$PATH:$EZKL_DIR\"" >> $PROFILE
fi

View File

@@ -196,7 +196,6 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
vk_path,
srs_path,
} => gen_witness(compiled_circuit, data, Some(output), vk_path, srs_path)
.await
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Mock { model, witness } => mock(model, witness),
#[cfg(not(target_arch = "wasm32"))]
@@ -637,7 +636,7 @@ pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, Box<dyn
Ok(String::new())
}
pub(crate) async fn gen_witness(
pub(crate) fn gen_witness(
compiled_circuit_path: PathBuf,
data: PathBuf,
output: Option<PathBuf>,
@@ -660,7 +659,7 @@ pub(crate) async fn gen_witness(
};
#[cfg(not(target_arch = "wasm32"))]
let mut input = circuit.load_graph_input(&data).await?;
let mut input = circuit.load_graph_input(&data)?;
#[cfg(target_arch = "wasm32")]
let mut input = circuit.load_graph_input(&data)?;

View File

@@ -21,8 +21,6 @@ use std::io::BufWriter;
use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(not(target_arch = "wasm32"))]
use std::thread;
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_core::{
tract_data::{prelude::Tensor as TractTensor, TVec},
value::TValue,
@@ -234,21 +232,15 @@ impl PostgresSource {
)
};
let res: Vec<pg_bigdecimal::PgNumeric> = thread::spawn(move || {
let mut client = Client::connect(&config, NoTls).unwrap();
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// extract rows from query
for row in client.query(&query, &[]).unwrap() {
// extract features from row
for i in 0..row.len() {
res.push(row.get(i));
}
let mut client = Client::connect(&config, NoTls)?;
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// extract rows from query
for row in client.query(&query, &[])? {
// extract features from row
for i in 0..row.len() {
res.push(row.get(i));
}
res
})
.join()
.map_err(|_| "failed to fetch data from postgres")?;
}
Ok(vec![res])
}

View File

@@ -918,7 +918,7 @@ impl GraphCircuit {
///
#[cfg(not(target_arch = "wasm32"))]
pub async fn load_graph_input(
pub fn load_graph_input(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
@@ -928,7 +928,6 @@ impl GraphCircuit {
debug!("input scales: {:?}", scales);
self.process_data_source(&data.input_data, shapes, scales, input_types)
.await
}
#[cfg(target_arch = "wasm32")]
@@ -952,7 +951,7 @@ impl GraphCircuit {
#[cfg(not(target_arch = "wasm32"))]
/// Process the data source for the model
async fn process_data_source(
fn process_data_source(
&mut self,
data: &DataSource,
shapes: Vec<Vec<usize>>,
@@ -965,8 +964,16 @@ impl GraphCircuit {
for (i, shape) in shapes.iter().enumerate() {
per_item_scale.extend(vec![scales[i]; shape.iter().product::<usize>()]);
}
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
.await
// start runtime and fetch data
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
runtime.block_on(async {
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
.await
})
}
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)

View File

@@ -942,7 +942,7 @@ fn calibrate_settings(
///
#[pyfunction(signature = (
data=PathBuf::from(DEFAULT_DATA),
model=PathBuf::from(DEFAULT_MODEL),
model=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
output=PathBuf::from(DEFAULT_WITNESS),
vk_path=None,
srs_path=None,
@@ -954,12 +954,8 @@ fn gen_witness(
vk_path: Option<PathBuf>,
srs_path: Option<PathBuf>,
) -> PyResult<PyObject> {
let output = Runtime::new()
.unwrap()
.block_on(crate::execute::gen_witness(
model, data, output, vk_path, srs_path,
))
.map_err(|e| {
let output =
crate::execute::gen_witness(model, data, output, vk_path, srs_path).map_err(|e| {
let err_str = format!("Failed to run generate witness: {}", e);
PyRuntimeError::new_err(err_str)
})?;
@@ -982,7 +978,7 @@ fn gen_witness(
///
#[pyfunction(signature = (
witness=PathBuf::from(DEFAULT_WITNESS),
model=PathBuf::from(DEFAULT_MODEL),
model=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
))]
fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
crate::execute::mock(model, witness).map_err(|e| {
@@ -1054,7 +1050,7 @@ fn mock_aggregate(
/// bool
///
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
model=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
vk_path=PathBuf::from(DEFAULT_VK),
pk_path=PathBuf::from(DEFAULT_PK),
srs_path=None,
@@ -1113,7 +1109,7 @@ fn setup(
///
#[pyfunction(signature = (
witness=PathBuf::from(DEFAULT_WITNESS),
model=PathBuf::from(DEFAULT_MODEL),
model=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
pk_path=PathBuf::from(DEFAULT_PK),
proof_path=None,
proof_type=ProofType::default(),