Mirror of https://github.com/zkonduit/ezkl.git (synced 2026-01-13 08:17:57 -05:00)

Compare commits (25 commits)
| SHA1 |
|---|
| 924f7c0420 |
| ae03b6515b |
| bae2e9e22b |
| 4a93d31869 |
| 88dd83dbe5 |
| f05f83481e |
| 8aaf518b5e |
| 1b7b43e073 |
| f78618ec59 |
| 0943e534ee |
| 316a9a3b40 |
| 5389012b68 |
| 48223cca11 |
| 32c3a5e159 |
| ff563e93a7 |
| 5639d36097 |
| 4ec8d13082 |
| 12735aefd4 |
| 7fe179b8d4 |
| 3be988a6a0 |
| 3abb3aff56 |
| 338788cb8f |
| feb3b1b475 |
| e134d86756 |
| 6819a3acf6 |
@@ -1,4 +1,4 @@
-name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
+name: Build and Publish EZKL Engine npm package
 
 on:
   workflow_dispatch:
@@ -22,7 +22,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: jetli/wasm-pack-action@v0.4.0
@@ -30,13 +30,13 @@ jobs:
         run: rustup target add wasm32-unknown-unknown
 
       - name: Add rust-src
-        run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
+        run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
       - name: Install binaryen
         run: |
           set -e
           curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz | tar xzf -
           export PATH=$PATH:$PWD/binaryen-version_116/bin
           wasm-opt --version
       - name: Build wasm files for both web and nodejs compilation targets
         run: |
           wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
@@ -62,7 +62,7 @@ jobs:
             "web/ezkl_bg.wasm",
             "web/ezkl.js",
             "web/ezkl.d.ts",
-            "web/snippets/wasm-bindgen-rayon-7afa899f36665473/src/workerHelpers.js",
+            "web/snippets/**/*",
             "web/package.json",
             "web/utils.js",
             "ezkl.d.ts"
@@ -79,6 +79,10 @@ jobs:
         run: |
           sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" pkg/nodejs/ezkl.js
 
+      - name: Replace `import.meta.url` with `import.meta.resolve` definition in workerHelpers.js
+        run: |
+          find ./pkg/web/snippets -type f -name "*.js" -exec sed -i "s|import.meta.url|import.meta.resolve|" {} +
+
       - name: Add serialize and deserialize methods to nodejs bundle
         run: |
           echo '
@@ -92,7 +96,7 @@ jobs:
             const jsonObject = JSONBig.parse(string);
             return jsonObject;
           }
 
           function serialize(data) { // data is an object // return a Uint8ClampedArray
             // Step 1: Stringify the Object with BigInt support
             if (typeof data === "object") {
@@ -100,11 +104,11 @@ jobs:
             }
             // Step 2: Encode the JSON String
             const uint8Array = new TextEncoder().encode(data);
 
             // Step 3: Convert to Uint8ClampedArray
             return new Uint8ClampedArray(uint8Array.buffer);
           }
 
           module.exports = {
             deserialize,
             serialize
@@ -123,7 +127,7 @@ jobs:
             const jsonObject = parse(string);
             return jsonObject;
           }
 
           export function serialize(data) { // data is an object // return a Uint8ClampedArray
             // Step 1: Stringify the Object with BigInt support
             if (typeof data === "object") {
@@ -131,7 +135,7 @@ jobs:
             }
             // Step 2: Encode the JSON String
             const uint8Array = new TextEncoder().encode(data);
 
             // Step 3: Convert to Uint8ClampedArray
             return new Uint8ClampedArray(uint8Array.buffer);
           }
@@ -174,40 +178,3 @@ jobs:
           npm publish
         env:
           NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
-
-  in-browser-evm-ver-publish:
-    name: publish-in-browser-evm-verifier-package
-    needs: ["publish-wasm-bindings"]
-    runs-on: ubuntu-latest
-    if: startsWith(github.ref, 'refs/tags/')
-    steps:
-      - uses: actions/checkout@v4
-      - name: Update version in package.json
-        shell: bash
-        env:
-          RELEASE_TAG: ${{ github.ref_name }}
-        run: |
-          sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
-      - name: Update @ezkljs/engine version in package.json
-        shell: bash
-        env:
-          RELEASE_TAG: ${{ github.ref_name }}
-        run: |
-          sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
-      - name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
-        run: |
-          sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
-      - name: Set up Node.js
-        uses: actions/setup-node@v3
-        with:
-          node-version: "18.12.1"
-          registry-url: "https://registry.npmjs.org"
-      - name: Publish to npm
-        run: |
-          cd in-browser-evm-verifier
-          npm install
-          npm run build
-          npm ci
-          npm publish
-        env:
-          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
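For reference, the serialize and deserialize helpers appended to the published bundles above round-trip objects through BigInt-aware JSON and a Uint8ClampedArray. A minimal TypeScript sketch of the same round trip, assuming the json-bigint npm package (the helper names mirror the snippet above; the package choice and the useNativeBigInt option are assumptions, not something the workflow pins):

```typescript
import JSONBigFactory from "json-bigint";

// BigInt-aware JSON codec; useNativeBigInt keeps values as native bigint.
const JSONBig = JSONBigFactory({ useNativeBigInt: true });

// Mirrors the serialize helper: stringify with BigInt support,
// encode to UTF-8, then view the bytes as a Uint8ClampedArray.
function serialize(data: object | string): Uint8ClampedArray {
  const text = typeof data === "object" ? JSONBig.stringify(data) : data;
  const bytes = new TextEncoder().encode(text);
  return new Uint8ClampedArray(bytes.buffer);
}

// Mirrors the deserialize helper: decode the bytes back to a string,
// then parse with BigInt support.
function deserialize(clamped: Uint8ClampedArray): unknown {
  const text = new TextDecoder().decode(new Uint8Array(clamped.buffer));
  return JSONBig.parse(text);
}

// Round trip: a field element too large for Number survives intact.
const roundTripped = deserialize(serialize({ instance: 2n ** 200n }));
console.log(roundTripped);
```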
.github/workflows/large-tests.yml (2 changes, vendored)
@@ -11,7 +11,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - name: nanoGPT Mock
.github/workflows/pypi-gpu.yml (2 changes, vendored)
@@ -26,7 +26,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v4
         with:
-          python-version: 3.7
+          python-version: 3.12
           architecture: x64

       - name: Set pyproject.toml version to match github tag
.github/workflows/pypi.yml (43 changes, vendored)
@@ -25,7 +25,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v4
         with:
-          python-version: 3.7
+          python-version: 3.12
           architecture: x64

       - name: Set Cargo.toml version to match github tag
@@ -70,7 +70,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v4
         with:
-          python-version: 3.7
+          python-version: 3.12
           architecture: ${{ matrix.target }}

       - name: Set Cargo.toml version to match github tag
@@ -115,7 +115,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v4
         with:
-          python-version: 3.7
+          python-version: 3.12
           architecture: x64

       - name: Set Cargo.toml version to match github tag
@@ -128,6 +128,7 @@ jobs:
           mv Cargo.lock Cargo.lock.orig
           sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock

+
       - name: Install required libraries
         shell: bash
         run: |
@@ -139,6 +140,20 @@ jobs:
           target: ${{ matrix.target }}
           manylinux: auto
           args: --release --out dist --features python-bindings
+          before-script-linux: |
+            # If we're running on rhel centos, install needed packages.
+            if command -v yum &> /dev/null; then
+              yum update -y && yum install -y perl-core openssl openssl-devel pkgconfig libatomic
+
+              # If we're running on i686 we need to symlink libatomic
+              # in order to build openssl with -latomic flag.
+              if [[ ! -d "/usr/lib64" ]]; then
+                ln -s /usr/lib/libatomic.so.1 /usr/lib/libatomic.so
+              fi
+            else
+              # If we're running on debian-based system.
+              apt update -y && apt-get install -y libssl-dev openssl pkg-config
+            fi

       - name: Install built wheel
         if: matrix.target == 'x86_64'
@@ -162,7 +177,7 @@ jobs:
 #      - uses: actions/checkout@v4
 #      - uses: actions/setup-python@v4
 #        with:
-#          python-version: 3.7
+#          python-version: 3.12

 #      - name: Install cross-compilation tools for aarch64
 #        if: matrix.target == 'aarch64'
@@ -214,7 +229,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v4
         with:
-          python-version: 3.7
+          python-version: 3.12
           architecture: x64

       - name: Set Cargo.toml version to match github tag
@@ -249,7 +264,7 @@ jobs:
           apk add py3-pip
           pip3 install -U pip
           python3 -m venv .venv
           source .venv/bin/activate
           pip3 install ezkl --no-index --find-links /io/dist/ --force-reinstall
           python3 -c "import ezkl"

@@ -273,7 +288,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v4
         with:
-          python-version: 3.7
+          python-version: 3.12

       - name: Set Cargo.toml version to match github tag
         shell: bash
@@ -345,3 +360,17 @@ jobs:
         with:
           repository-url: https://test.pypi.org/legacy/
           packages-dir: ./
+
+  doc-publish:
+    name: Trigger ReadTheDocs Build
+    runs-on: ubuntu-latest
+    needs: pypi-publish
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Trigger RTDs build
+        uses: dfm/rtds-action@v1
+        with:
+          webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
+          webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
+          commit_ref: ${{ github.ref_name }}
.github/workflows/release.yml (16 changes, vendored)
@@ -32,7 +32,7 @@ jobs:
           token: ${{ secrets.RELEASE_TOKEN }}
           tag_name: ${{ env.EZKL_VERSION }}

   build-release-gpu:
     name: build-release-gpu
     needs: ["create-release"]
     runs-on: GPU
@@ -45,7 +45,7 @@ jobs:
     steps:
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - name: Checkout repo
@@ -60,16 +60,15 @@ jobs:
       - name: Set Cargo.toml version to match github tag
         shell: bash
         run: |
           mv Cargo.toml Cargo.toml.orig
           sed "s/0\\.0\\.0/${EZKL_VERSION//v}/" Cargo.toml.orig >Cargo.toml
           mv Cargo.lock Cargo.lock.orig
           sed "s/0\\.0\\.0/${EZKL_VERSION//v}/" Cargo.lock.orig >Cargo.lock
-
       - name: Install dependencies
         shell: bash
         run: |
           sudo apt-get update

       - name: Build release binary
         run: cargo build --release -Z sparse-registry --features icicle
@@ -91,7 +90,6 @@ jobs:
           asset_name: ${{ env.ASSET }}
           asset_content_type: application/octet-stream

-
   build-release:
     name: build-release
     needs: ["create-release"]
.github/workflows/rust.yml (143 changes, vendored)
@@ -26,7 +26,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - name: Build
@@ -38,7 +38,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - name: Docs
@@ -50,7 +50,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -73,7 +73,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -106,7 +106,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -139,7 +139,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -172,7 +172,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -184,12 +184,12 @@ jobs:

   wasm32-tests:
     runs-on: ubuntu-latest
-    # needs: [build, library-tests, docs]
+    needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: jetli/wasm-pack-action@v0.4.0
@@ -199,7 +199,7 @@ jobs:
       - name: Install wasm32-unknown-unknown
         run: rustup target add wasm32-unknown-unknown
       - name: Add rust-src
-        run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
+        run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
       - name: Run wasm verifier tests
         # on mac:
         # AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
@@ -207,12 +207,12 @@ jobs:

   tutorial:
     runs-on: ubuntu-latest
-    needs: [build, library-tests, docs]
+    needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -224,12 +224,12 @@ jobs:

   mock-proving-tests:
     runs-on: non-gpu
-    # needs: [build, library-tests, docs]
+    needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -281,12 +281,12 @@ jobs:

   prove-and-verify-evm-tests:
     runs-on: non-gpu
-    needs: [build, library-tests]
+    needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -303,10 +303,12 @@ jobs:
         with:
           node-version: "18.12.1"
           cache: "pnpm"
+      - name: "Add rust-src"
+        run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
       - name: Install dependencies for js tests and in-browser-evm-verifier package
         run: |
-          pnpm install --no-frozen-lockfile
-          pnpm install --dir ./in-browser-evm-verifier --no-frozen-lockfile
+          pnpm install --frozen-lockfile
+          pnpm install --dir ./in-browser-evm-verifier --frozen-lockfile
         env:
           CI: false
           NODE_ENV: development
@@ -324,7 +326,7 @@ jobs:
       - name: Install solc
         run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
       - name: Install Anvil
-        run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
+        run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
       - name: KZG prove and verify tests (EVM + VK rendered seperately)
         run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
       - name: KZG prove and verify tests (EVM + kzg all)
@@ -352,12 +354,12 @@ jobs:

   prove-and-verify-tests:
     runs-on: non-gpu
-    needs: [build, library-tests]
+    needs: [build, library-tests, docs]
     steps:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: jetli/wasm-pack-action@v0.4.0
@@ -365,7 +367,7 @@ jobs:
         run: rustup target add wasm32-unknown-unknown

       - name: Add rust-src
-        run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
+        run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
       - uses: actions/checkout@v3
       - name: Use pnpm 8
         uses: pnpm/action-setup@v2
@@ -378,7 +380,7 @@ jobs:
           cache: "pnpm"
       - name: Install dependencies for js tests
         run: |
-          pnpm install --no-frozen-lockfile
+          pnpm install --frozen-lockfile
         env:
           CI: false
           NODE_ENV: development
@@ -392,12 +394,18 @@ jobs:
       - name: Replace memory definition in nodejs
         run: |
           sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
+      - name: KZG prove and verify tests (public outputs + column overflow)
+        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w
+      - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
+        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
+      - name: KZG prove and verify tests (hashed inputs + column overflow)
+        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_hashed_inputs_
+      - name: KZG prove and verify tests (public outputs)
+        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_tight_lookup_::t
       - name: IPA prove and verify tests
         run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_::t --test-threads 1
       - name: IPA prove and verify tests (ipa outputs)
         run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_ipa_output
-      - name: KZG prove and verify tests (public outputs + column overflow)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w
       - name: KZG prove and verify tests single inner col
         run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_single_col
       - name: KZG prove and verify tests triple inner col
@@ -408,12 +416,8 @@ jobs:
         run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_octuple_col --test-threads 8
       - name: KZG prove and verify tests (kzg outputs)
         run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output
-      - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
       - name: KZG prove and verify tests (public outputs)
         run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t
-      - name: KZG prove and verify tests (public outputs + column overflow)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t
       - name: KZG prove and verify tests (public inputs)
         run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input
       - name: KZG prove and verify tests (fixed params)
@@ -429,11 +433,11 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - name: Add rust-src
-        run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
+        run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
       - uses: actions/checkout@v3
       - uses: baptiste0928/cargo-install@v1
         with:
@@ -456,15 +460,14 @@ jobs:
       - name: KZG prove and verify tests (hashed outputs)
         run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1

-
   prove-and-verify-mock-aggr-tests:
     runs-on: self-hosted
-    needs: [build, library-tests]
+    needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -482,7 +485,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -494,12 +497,12 @@ jobs:

   prove-and-verify-aggr-tests:
     runs-on: large-self-hosted
-    needs: [build, library-tests]
+    needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -509,16 +512,14 @@ jobs:
       - name: KZG tests
         run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored

-
-
   prove-and-verify-aggr-evm-tests:
     runs-on: large-self-hosted
-    needs: [build, library-tests]
+    needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -528,7 +529,7 @@ jobs:
       - name: Install solc
         run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
       - name: Install Anvil
-        run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
+        run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
       - name: KZG prove and verify aggr tests
         run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored

@@ -539,15 +540,13 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
         with:
           crate: cargo-nextest
           locked: true
-      - name: Download MNIST
-        run: sh data.sh
       - name: Examples
         run: cargo nextest run --release tests_examples

@@ -558,18 +557,20 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v4
         with:
-          python-version: "3.7"
+          python-version: "3.12"
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
+      - name: Install cmake
+        run: sudo apt-get install -y cmake
       - name: Install solc
         run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
       - name: Setup Virtual Env and Install python dependencies
-        run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
+        run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
       - name: Install Anvil
-        run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
+        run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
       - name: Build python ezkl
         run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
       - name: Run pytest
@@ -577,15 +578,15 @@ jobs:

   accuracy-measurement-tests:
     runs-on: ubuntu-latest-32-cores
-    # needs: [build, library-tests, docs]
+    needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v4
         with:
-          python-version: "3.7"
+          python-version: "3.12"
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -593,7 +594,7 @@ jobs:
           crate: cargo-nextest
           locked: true
       - name: Setup Virtual Env and Install python dependencies
-        run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
+        run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
       - name: Build python ezkl
         run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
       - name: Div rebase
@@ -609,14 +610,32 @@ jobs:

   python-integration-tests:
     runs-on: large-self-hosted
+    services:
+      # Label used to access the service container
+      postgres:
+        # Docker Hub image
+        image: postgres
+        env:
+          POSTGRES_USER: ubuntu
+          POSTGRES_HOST_AUTH_METHOD: trust
+        # Set health checks to wait until postgres has started
+        options: >-
+          --health-cmd pg_isready
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+          -v /var/run/postgresql:/var/run/postgresql
+        ports:
+          # Maps tcp port 5432 on service container to the host
+          - 5432:5432
     steps:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v4
         with:
-          python-version: "3.10"
+          python-version: "3.11"
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2024-02-06
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -626,11 +645,17 @@ jobs:
       - name: Install solc
         run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
       - name: Install Anvil
-        run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
+        run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
+      - name: Install pip
+        run: python -m ensurepip --upgrade
       - name: Setup Virtual Env and Install python dependencies
-        run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
+        run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
       - name: Build python ezkl
         run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
+      - name: Postgres tutorials
+        run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
+      - name: Tictactoe tutorials
+        run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
       # - name: authenticate-kaggle-cli
       #   shell: bash
       #   env:
@@ -646,7 +671,3 @@ jobs:
         run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
       - name: NBEATS tutorial
         run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
-      - name: Tictactoe tutorials
-        run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_
-      # - name: Postgres tutorials
-      #   run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1
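Both this workflow and the npm publish workflow above patch line 3 of the generated ezkl.js with a sed one-liner so the wasm module gets a shared, growable memory. A minimal TypeScript sketch of the object that the sed command injects (page counts are in 64 KiB WebAssembly pages, so initial: 20 is roughly 1.25 MiB; note that shared memory additionally requires a cross-origin-isolated context in browsers, which the workflows themselves do not set up):

```typescript
// The object the sed step splices into the generated bundle: a shared,
// growable WebAssembly memory handed to the wasm module's "env" imports.
// initial/maximum are in 64 KiB pages; 65536 pages is the 4 GiB ceiling.
const imports: Record<string, unknown> = {};
imports["env"] = {
  memory: new WebAssembly.Memory({ initial: 20, maximum: 65536, shared: true }),
};
```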
.github/workflows/tagging.yml (36 changes, vendored)
@@ -14,6 +14,40 @@ jobs:
       - uses: actions/checkout@v4
       - name: Bump version and push tag
         id: tag_version
-        uses: mathieudutour/github-tag-action@v6.1
+        uses: mathieudutour/github-tag-action@v6.2
         with:
           github_token: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Set Cargo.toml version to match github tag for docs
+        shell: bash
+        env:
+          RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
+        run: |
+          mv docs/python/src/conf.py docs/python/src/conf.py.orig
+          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/src/conf.py.orig >docs/python/src/conf.py
+          rm docs/python/src/conf.py.orig
+          mv docs/python/requirements-docs.txt docs/python/requirements-docs.txt.orig
+          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/requirements-docs.txt.orig >docs/python/requirements-docs.txt
+          rm docs/python/requirements-docs.txt.orig
+
+      - name: Commit files and create tag
+        env:
+          RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
+        run: |
+          git config --local user.email "github-actions[bot]@users.noreply.github.com"
+          git config --local user.name "github-actions[bot]"
+          git fetch --tags
+          git checkout -b release-$RELEASE_TAG
+          git add .
+          git commit -m "ci: update version string in docs"
+          git tag -d $RELEASE_TAG
+          git tag $RELEASE_TAG
+
+      - name: Push changes
+        uses: ad-m/github-push-action@master
+        env:
+          RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
+        with:
+          branch: release-${{ steps.tag_version.outputs.new_tag }}
+          force: true
+          tags: true
.github/workflows/verify.yml (54 lines, new file, vendored)
@@ -0,0 +1,54 @@
+name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
+
+on:
+  workflow_dispatch:
+    inputs:
+      tag:
+        description: "The tag to release"
+        required: true
+  push:
+    tags:
+      - "*"
+
+defaults:
+  run:
+    working-directory: .
+jobs:
+  in-browser-evm-ver-publish:
+    name: publish-in-browser-evm-verifier-package
+    runs-on: ubuntu-latest
+    if: startsWith(github.ref, 'refs/tags/')
+    steps:
+      - uses: actions/checkout@v4
+      - name: Update version in package.json
+        shell: bash
+        env:
+          RELEASE_TAG: ${{ github.ref_name }}
+        run: |
+          sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
+      - name: Update @ezkljs/engine version in package.json
+        shell: bash
+        env:
+          RELEASE_TAG: ${{ github.ref_name }}
+        run: |
+          sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
+      - name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
+        run: |
+          sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
+      - name: Use pnpm 8
+        uses: pnpm/action-setup@v2
+        with:
+          version: 8
+      - name: Set up Node.js
+        uses: actions/setup-node@v3
+        with:
+          node-version: "18.12.1"
+          registry-url: "https://registry.npmjs.org"
+      - name: Publish to npm
+        run: |
+          cd in-browser-evm-verifier
+          pnpm install --frozen-lockfile
+          pnpm run build
+          pnpm publish
+        env:
+          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
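The sed step in this workflow rewrites the verifier's import so the published package depends on @ezkljs/engine rather than the locally built wasm bundle. A hedged TypeScript sketch of the rewritten import site follows; the declared signature and call are illustrative assumptions, not the package's documented API:

```typescript
// After the sed rewrite, the in-browser verifier pulls the calldata encoder
// from the published npm package instead of the local '../nodejs/ezkl' build.
import { encodeVerifierCalldata } from "@ezkljs/engine";

// Hypothetical call site: pack a serialized proof into EVM calldata.
// Argument and return types here are assumptions for illustration only.
const proof: Uint8ClampedArray = new Uint8ClampedArray();
const calldata = encodeVerifierCalldata(proof, null);
console.log(calldata);
```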
.gitignore (5 changes, vendored)
@@ -1,6 +1,5 @@
 target
 pkg
-data
 *.csv
 !examples/notebooks/eth_price.csv
 *.ipynb_checkpoints
@@ -48,4 +47,6 @@ node_modules
 /dist
 timingData.json
 !tests/wasm/pk.key
 !tests/wasm/vk.key
+docs/python/build
+!tests/wasm/vk_aggr.key
.python-version (1 line, new file)
@@ -0,0 +1 @@
+3.12.1
.readthedocs.yaml (26 lines, new file)
@@ -0,0 +1,26 @@
+# .readthedocs.yaml
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+version: 2
+
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.12"
+
+# Build documentation in the "docs/" directory with Sphinx
+sphinx:
+  configuration: ./docs/python/src/conf.py
+
+# Optionally build your docs in additional formats such as PDF and ePub
+# formats:
+#   - pdf
+#   - epub
+
+# Optional but recommended, declare the Python requirements required
+# to build your documentation
+# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
+python:
+  install:
+    - requirements: ./docs/python/requirements-docs.txt
Cargo.lock (1718 changes, generated): file diff suppressed because it is too large.
Cargo.toml (20 changes)
@@ -15,14 +15,14 @@ crate-type = ["cdylib", "rlib"]

 [dependencies]
-halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "main" }
-halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch = "main" }
+halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
+halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
 halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "9fff22c", features = [
     "derive_serde",
 ] }
 rand = { version = "0.8", default_features = false }
 itertools = { version = "0.10.3", default_features = false }
-clap = { version = "4.3.3", features = ["derive"] }
+clap = { version = "4.5.3", features = ["derive"] }
 serde = { version = "1.0.126", features = ["derive"], optional = true }
 serde_json = { version = "1.0.97", default_features = false, features = [
     "float_roundtrip",
@@ -80,7 +80,7 @@ pyo3-asyncio = { version = "0.20.0", features = [
     "tokio-runtime",
 ], default_features = false, optional = true }
 pyo3-log = { version = "0.9.0", default_features = false, optional = true }
-tract-onnx = { git = "https://github.com/sonos/tract/", rev = "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
+tract-onnx = { git = "https://github.com/sonos/tract/", rev = "681a096f02c9d7d363102d9fb0e446d1710ac2c8", default_features = false, optional = true }
 tabled = { version = "0.12.0", optional = true }

@@ -95,10 +95,10 @@ getrandom = { version = "0.2.8", features = ["js"] }
 instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }

 [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies]
-wasm-bindgen-rayon = { version = "1.0", optional = true }
-wasm-bindgen-test = "0.3.34"
-serde-wasm-bindgen = "0.4"
-wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"] }
+wasm-bindgen-rayon = { version = "1.2.1", optional = true }
+wasm-bindgen-test = "0.3.42"
+serde-wasm-bindgen = "0.6.5"
+wasm-bindgen = { version = "0.2.92", features = ["serde-serialize"] }
 console_error_panic_hook = "0.1.7"
 wasm-bindgen-console-logger = "0.1.1"

@@ -203,5 +203,9 @@ no-banner = []
 [patch.'https://github.com/ingonyama-zk/icicle']
 icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }

+[patch.'https://github.com/zkonduit/halo2']
+halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a", package = "halo2_proofs", branch = "ac/optional-selector-poly" }
+
 [profile.release]
 rustflags = ["-C", "relocation-model=pic"]
@@ -70,8 +70,8 @@ impl Circuit<Fr> for MyCircuit {
                         &mut region,
                         &[self.image.clone(), self.kernel.clone(), self.bias.clone()],
                         Box::new(PolyOp::Conv {
-                            padding: [(0, 0); 2],
-                            stride: (1, 1),
+                            padding: vec![(0, 0)],
+                            stride: vec![1; 2],
                         }),
                     )
                     .unwrap();
@@ -65,9 +65,9 @@ impl Circuit<Fr> for MyCircuit {
                         &mut region,
                         &[self.image.clone()],
                         Box::new(HybridOp::SumPool {
-                            padding: [(0, 0); 2],
-                            stride: (1, 1),
-                            kernel_shape: (2, 2),
+                            padding: vec![(0, 0); 2],
+                            stride: vec![1, 1],
+                            kernel_shape: vec![2, 2],
                             normalized: false,
                         }),
                     )
@@ -1,3 +1,5 @@
+use std::collections::HashMap;
+
 use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
 use ezkl::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
 use ezkl::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
@@ -48,7 +50,7 @@ impl Circuit<Fr> for MyCircuit {
     ) -> Result<(), Error> {
         let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L> =
             PoseidonChip::new(config);
-        chip.layout(&mut layouter, &[self.image.clone()], 0)?;
+        chip.layout(&mut layouter, &[self.image.clone()], 0, &mut HashMap::new())?;
         Ok(())
     }
 }
data.sh (11 lines, deleted)
@@ -1,11 +0,0 @@
-#! /bin/bash
-
-mkdir data
-cd data
-
-wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
-wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
-wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
-wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
-
-gzip -d *.gz
docs/python/build.sh (2 lines, new executable file)
@@ -0,0 +1,2 @@
+#!/bin/sh
+sphinx-build ./src build
docs/python/requirements-docs.txt (4 lines, new file)
@@ -0,0 +1,4 @@
+ezkl==0.0.0
+sphinx
+sphinx-rtd-theme
+sphinxcontrib-napoleon
docs/python/src/conf.py (29 lines, new file)
@@ -0,0 +1,29 @@
+import ezkl
+
+project = 'ezkl'
+release = '0.0.0'
+version = release
+
+
+extensions = [
+    'sphinx.ext.autodoc',
+    'sphinx.ext.autosummary',
+    'sphinx.ext.intersphinx',
+    'sphinx.ext.todo',
+    'sphinx.ext.inheritance_diagram',
+    'sphinx.ext.autosectionlabel',
+    'sphinx.ext.napoleon',
+    'sphinx_rtd_theme',
+]
+
+autosummary_generate = True
+autosummary_imported_members = True
+
+templates_path = ['_templates']
+exclude_patterns = []
+
+# -- Options for HTML output -------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
+
+html_theme = 'sphinx_rtd_theme'
+html_static_path = ['_static']
docs/python/src/index.rst (11 lines, new file)
@@ -0,0 +1,11 @@
+.. extension documentation master file, created by
+   sphinx-quickstart on Mon Jun 19 15:02:05 2023.
+   You can adapt this file completely to your liking, but it should at least
+   contain the root `toctree` directive.
+
+ezkl python bindings
+================================================
+
+.. automodule:: ezkl
+   :members:
+   :undoc-members:
@@ -203,8 +203,8 @@ where
         let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);

         let op = PolyOp::Conv {
-            padding: [(PADDING, PADDING); 2],
-            stride: (STRIDE, STRIDE),
+            padding: vec![(PADDING, PADDING); 2],
+            stride: vec![STRIDE; 2],
         };
         let x = config
             .layer_config
@@ -308,6 +308,7 @@ pub fn runconv() {
         tst_lbl: _,
         ..
     } = MnistBuilder::new()
+        .base_path("examples/data")
         .label_format_digit()
         .training_set_length(50_000)
         .validation_set_length(10_000)
BIN examples/data/t10k-images-idx3-ubyte (new file, binary file not shown)
BIN examples/data/t10k-labels-idx1-ubyte (new file, binary file not shown)
BIN examples/data/train-images-idx3-ubyte (new file, binary file not shown)
BIN examples/data/train-labels-idx1-ubyte (new file, binary file not shown)
@@ -696,10 +696,12 @@
 "for i, value in enumerate(proof[\"instances\"]):\n",
 "    for j, field_element in enumerate(value):\n",
 "        onchain_input_array.append(ezkl.felt_to_big_endian(field_element))\n",
-"        formatted_output += str(onchain_input_array[-1])\n",
+"        formatted_output += '\"' + str(onchain_input_array[-1]) + '\"'\n",
 "        if j != len(value) - 1:\n",
 "            formatted_output += \", \"\n",
 "    formatted_output += \"]\"\n",
 "    if i != len(proof[\"instances\"]) - 1:\n",
 "        formatted_output += \", \"\n",
 "formatted_output += \"]\"\n",
 "\n",
 "# This will be the values you use onchain\n",
 "# copy them over to remix and see if they verify\n",
@@ -67,6 +67,7 @@
 "model.add(Dense(128, activation='relu'))\n",
 "model.add(Dropout(0.5))\n",
 "model.add(Dense(10, activation='softmax'))\n",
+"model.output_names=['output']\n",
 "\n",
 "\n",
 "# Train the model as you like here (skipped for brevity)\n",
@@ -7,9 +7,9 @@
|
||||
"## Mean of ERC20 transfer amounts\n",
|
||||
"\n",
|
||||
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
|
||||
"The first of which is [e2pg](https://github.com/indexsupply/x/tree/main/docs/e2pg), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
|
||||
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
|
||||
"\n",
|
||||
"Make sure you install postgres if needed https://postgresapp.com/. \n",
|
||||
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
@@ -21,23 +21,81 @@
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import getpass\n",
|
||||
"\n",
|
||||
"import json\n",
|
||||
"import time\n",
|
||||
"import subprocess\n",
|
||||
"\n",
|
||||
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
|
||||
"os.system(\"curl -LO https://indexsupply.net/bin/main/linux/amd64/e2pg\")\n",
|
||||
"os.system(\"chmod +x e2pg\")\n",
|
||||
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
|
||||
"os.system(\"chmod +x shovel\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"os.environ[\"PG_URL\"] = \"postgresql://\" + getpass.getuser() + \":@localhost:5432/e2pg\"\n",
|
||||
"os.environ[\"RLPS_URL\"] = \"https://1.rlps.indexsupply.net\"\n",
|
||||
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
|
||||
"\n",
|
||||
"# create a config.json file with the following contents\n",
|
||||
"config = {\n",
|
||||
" \"pg_url\": \"$PG_URL\",\n",
|
||||
" \"eth_sources\": [\n",
|
||||
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
|
||||
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
|
||||
" ],\n",
|
||||
" \"integrations\": [{\n",
|
||||
" \"name\": \"usdc_transfer\",\n",
|
||||
" \"enabled\": True,\n",
|
||||
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
|
||||
" \"table\": {\n",
|
||||
" \"name\": \"usdc\",\n",
|
||||
" \"columns\": [\n",
|
||||
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
|
||||
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"block\": [\n",
|
||||
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
|
||||
" {\n",
|
||||
" \"name\": \"log_addr\",\n",
|
||||
" \"column\": \"log_addr\",\n",
|
||||
" \"filter_op\": \"contains\",\n",
|
||||
" \"filter_arg\": [\n",
|
||||
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
|
||||
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" \"event\": {\n",
|
||||
" \"name\": \"Transfer\",\n",
|
||||
" \"type\": \"event\",\n",
|
||||
" \"anonymous\": False,\n",
|
||||
" \"inputs\": [\n",
|
||||
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
|
||||
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
|
||||
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" }]\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# write the config to a file\n",
|
||||
"with open(\"config.json\", \"w\") as f:\n",
|
||||
" f.write(json.dumps(config))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# print the two env variables\n",
|
||||
"os.system(\"echo $PG_URL\")\n",
|
||||
"os.system(\"echo $RLPS_URL\")\n",
|
||||
"\n",
|
||||
"os.system(\"createdb -h localhost -p 5432 e2pg\")\n",
|
||||
"# equivalent of nohup ./e2pg -reset -e $RLPS_URL -pg $PG_URL &\n",
|
||||
"e2pg_process = os.system(\"nohup ./e2pg -e $RLPS_URL -pg $PG_URL &\")\n",
|
||||
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel is now installed. starting:\")\n",
|
||||
"\n",
|
||||
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
|
||||
"subprocess.Popen(command)\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel started.\")\n",
|
||||
"\n",
|
||||
"time.sleep(5)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
@@ -79,11 +137,13 @@
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"# import logging\n",
|
||||
"# # # uncomment for more descriptive logging \n",
|
||||
"# FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
|
||||
"# logging.basicConfig(format=FORMAT)\n",
|
||||
"# logging.getLogger().setLevel(logging.DEBUG)"
|
||||
"import logging\n",
|
||||
"# # uncomment for more descriptive logging \n",
|
||||
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
|
||||
"logging.basicConfig(format=FORMAT)\n",
|
||||
"logging.getLogger().setLevel(logging.DEBUG)\n",
|
||||
"\n",
|
||||
"print(\"ezkl version: \", ezkl.__version__)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -176,6 +236,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"# make an input.json file from the df above\n",
|
||||
"input_filename = os.path.join('input.json')\n",
|
||||
"\n",
|
||||
@@ -183,9 +244,9 @@
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"e2pg\",\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT value FROM erc20_transfers ORDER BY block_number DESC LIMIT 5\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
@@ -194,7 +255,7 @@
|
||||
"\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump( pg_input_file, open(input_filename, 'w' ))\n"
|
||||
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -210,9 +271,9 @@
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"e2pg\",\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT value FROM erc20_transfers ORDER BY block_number DESC LIMIT 20\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
@@ -229,22 +290,6 @@
|
||||
"**EZKL Workflow**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"onnx_filename = os.path.join('lol.onnx')\n",
|
||||
"compiled_filename = os.path.join('lol.compiled')\n",
|
||||
"settings_filename = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"ezkl.gen_settings(onnx_filename, settings_filename)\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(\n",
|
||||
" input_filename, onnx_filename, settings_filename, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -253,10 +298,21 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# setup kzg params\n",
|
||||
"params_path = os.path.join('kzg.params')\n",
|
||||
"import subprocess\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"res = ezkl.get_srs(params_path, settings_filename)"
|
||||
"onnx_filename = os.path.join('lol.onnx')\n",
|
||||
"compiled_filename = os.path.join('lol.compiled')\n",
|
||||
"settings_filename = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"# Generate settings using ezkl\n",
|
||||
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
|
||||
"\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -306,16 +362,13 @@
|
||||
"source": [
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"params_path = os.path.join('kzg.params')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# setup the proof\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_filename,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path,\n",
|
||||
" params_path,\n",
|
||||
" settings_filename,\n",
|
||||
" pk_path\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
@@ -331,11 +384,14 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
"# generate the witness\n",
|
||||
"res = ezkl.gen_witness(\n",
|
||||
" input_filename,\n",
|
||||
" compiled_filename,\n",
|
||||
" witness_path\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -360,73 +416,14 @@
|
||||
" compiled_filename,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" params_path,\n",
|
||||
" \"single\",\n",
|
||||
" \"single\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"proved\")\n",
|
||||
"\n",
|
||||
"assert os.path.isfile(proof_path)\n",
|
||||
"\n",
|
||||
"# verify\n",
|
||||
"res = ezkl.verify(\n",
|
||||
" proof_path,\n",
|
||||
" settings_filename,\n",
|
||||
" vk_path,\n",
|
||||
" params_path,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "W7tAa-DFAtvS"
|
||||
},
|
||||
"source": [
|
||||
"# Part 2 (Using the ZK Computational Graph Onchain!)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "8Ym91kaVAIB6"
|
||||
},
|
||||
"source": [
|
||||
"**Now How Do We Do It Onchain?????**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 339
|
||||
},
|
||||
"id": "fodkNgwS70FM",
|
||||
"outputId": "827b5efd-f74f-44de-c114-861b3a86daf2"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# first we need to create evm verifier\n",
|
||||
"print(vk_path)\n",
|
||||
"print(params_path)\n",
|
||||
"print(settings_filename)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"abi_path = 'test.abi'\n",
|
||||
"sol_code_path = 'test.sol'\n",
|
||||
"\n",
|
||||
"res = ezkl.create_evm_verifier(\n",
|
||||
" vk_path,\n",
|
||||
" params_path,\n",
|
||||
" settings_filename,\n",
|
||||
" sol_code_path,\n",
|
||||
" abi_path,\n",
|
||||
" )\n",
|
||||
"assert res == True"
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -435,51 +432,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Make sure anvil is running locally first\n",
|
||||
"# run with $ anvil -p 3030\n",
|
||||
"# we use the default anvil node here\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"address_path = os.path.join(\"address.json\")\n",
|
||||
"\n",
|
||||
"res = ezkl.deploy_evm(\n",
|
||||
" address_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"with open(address_path, 'r') as file:\n",
|
||||
" addr = file.read().rstrip()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# read the address from addr_path\n",
|
||||
"addr = None\n",
|
||||
"with open(address_path, 'r') as f:\n",
|
||||
" addr = f.read()\n",
|
||||
"\n",
|
||||
"res = ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" \"http://127.0.0.1:3030\"\n",
|
||||
")\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
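With the verifier contract generated, the onchain half mirrors the cells above: emit Solidity, deploy to the local anvil node on port 3030, then call the contract with the proof. A condensed sketch under the same assumptions (argument order as shown in this notebook):

# onchain flow sketch; run `anvil -p 3030` in another terminal first
assert ezkl.create_evm_verifier(vk_path, settings_filename, sol_code_path, abi_path) == True
assert ezkl.deploy_evm(address_path, sol_code_path, 'http://127.0.0.1:3030') == True

with open(address_path, 'r') as f:
    addr = f.read().rstrip()

assert ezkl.verify_evm(addr, proof_path, 'http://127.0.0.1:3030') == True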
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"os.system(\"killall -9 e2pg\");"
|
||||
"# kill all shovel process \n",
|
||||
"os.system(\"pkill -f shovel\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -501,7 +455,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.13"
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -38,7 +38,7 @@
|
||||
"import logging\n",
|
||||
"\n",
|
||||
"import tensorflow as tf\n",
|
||||
"from tensorflow.keras.optimizers.legacy import Adam\n",
|
||||
"from tensorflow.keras.optimizers import Adam\n",
|
||||
"from tensorflow.keras.layers import *\n",
|
||||
"from tensorflow.keras.models import Model\n",
|
||||
"from tensorflow.keras.datasets import mnist\n",
|
||||
@@ -71,9 +71,11 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"opt = Adam()\n",
|
||||
"ZDIM = 100\n",
|
||||
"\n",
|
||||
"opt = Adam()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# discriminator\n",
|
||||
"# 0 if it's fake, 1 if it's real\n",
|
||||
"x = in1 = Input((28,28))\n",
|
||||
@@ -114,8 +116,11 @@
|
||||
"\n",
|
||||
"gm = Model(in1, x)\n",
|
||||
"gm.compile('adam', 'mse')\n",
|
||||
"gm.output_names=['output']\n",
|
||||
"gm.summary()\n",
|
||||
"\n",
|
||||
"opt = Adam()\n",
|
||||
"\n",
|
||||
"# GAN\n",
|
||||
"dm.trainable = False\n",
|
||||
"x = dm(gm.output)\n",
|
||||
@@ -415,7 +420,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -349,6 +349,8 @@
|
||||
"z_log_var = Dense(ZDIM)(x)\n",
|
||||
"z = Lambda(lambda x: x[0] + K.exp(0.5 * x[1]) * K.random_normal(shape=K.shape(x[0])))([z_mu, z_log_var])\n",
|
||||
"dec = get_decoder()\n",
|
||||
"dec.output_names=['output']\n",
|
||||
"\n",
|
||||
"out = dec(z)\n",
|
||||
"\n",
|
||||
"mse_loss = mse(Reshape((28*28,))(in1), Reshape((28*28,))(out)) * 28 * 28\n",
|
||||
|
||||
@@ -61,11 +61,10 @@
|
||||
"from sklearn.datasets import load_iris\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from sklearn.ensemble import RandomForestClassifier as Rf\n",
|
||||
"import sk2torch\n",
|
||||
"import torch\n",
|
||||
"import ezkl\n",
|
||||
"import os\n",
|
||||
"from torch import nn\n",
|
||||
"from hummingbird.ml import convert\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
@@ -77,28 +76,12 @@
|
||||
"clr.fit(X_train, y_train)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"trees = []\n",
|
||||
"for tree in clr.estimators_:\n",
|
||||
" trees.append(sk2torch.wrap(tree))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class RandomForest(nn.Module):\n",
|
||||
" def __init__(self, trees):\n",
|
||||
" super(RandomForest, self).__init__()\n",
|
||||
" self.trees = nn.ModuleList(trees)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" out = self.trees[0](x)\n",
|
||||
" for tree in self.trees[1:]:\n",
|
||||
" out += tree(x)\n",
|
||||
" return out / len(self.trees)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"torch_rf = RandomForest(trees)\n",
|
||||
"torch_rf = convert(clr, 'torch')\n",
|
||||
"# assert predictions from torch are = to sklearn \n",
|
||||
"diffs = []\n",
|
||||
"for i in range(len(X_test)):\n",
|
||||
" torch_pred = torch_rf(torch.tensor(X_test[i].reshape(1, -1)))\n",
|
||||
" torch_pred = torch_rf.predict(torch.tensor(X_test[i].reshape(1, -1)))\n",
|
||||
" sk_pred = clr.predict(X_test[i].reshape(1, -1))\n",
|
||||
" diffs.append(torch_pred[0].round() - sk_pred[0])\n",
|
||||
"\n",
|
||||
@@ -134,14 +117,12 @@
|
||||
"\n",
|
||||
"# export to onnx format\n",
|
||||
"\n",
|
||||
"torch_rf.eval()\n",
|
||||
"\n",
|
||||
"# Input to the model\n",
|
||||
"shape = X_train.shape[1:]\n",
|
||||
"x = torch.rand(1, *shape, requires_grad=False)\n",
|
||||
"torch_out = torch_rf(x)\n",
|
||||
"torch_out = torch_rf.predict(x)\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(torch_rf, # model being run\n",
|
||||
"torch.onnx.export(torch_rf.model, # model being run\n",
|
||||
" # model input (or a tuple for multiple inputs)\n",
|
||||
" x,\n",
|
||||
" # where to save the model (can be a file or file-like object)\n",
|
||||
@@ -158,7 +139,7 @@
|
||||
"\n",
|
||||
"data = dict(input_shapes=[shape],\n",
|
||||
" input_data=[d],\n",
|
||||
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
|
||||
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
|
||||
"\n",
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(\"input.json\", 'w'))\n"
|
||||
@@ -321,7 +302,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -57,7 +57,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -119,7 +119,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -163,7 +163,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -217,7 +217,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -227,6 +227,10 @@
|
||||
" self.length = self.compute_length(self.file_good)\n",
|
||||
" self.data = self.load_data(self.file_good)\n",
|
||||
"\n",
|
||||
" def __iter__(self):\n",
|
||||
" for i in range(len(self.data)):\n",
|
||||
" yield self.data[i]\n",
|
||||
"\n",
|
||||
" def parse_json_object(self, line):\n",
|
||||
" try:\n",
|
||||
" return json.loads(line)\n",
|
||||
@@ -749,7 +753,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -209,6 +209,11 @@
|
||||
" self.length = self.compute_length(self.file_good, self.file_bad)\n",
|
||||
" self.data = self.load_data(self.file_good, self.file_bad)\n",
|
||||
"\n",
|
||||
" def __iter__(self):\n",
|
||||
" for i in range(len(self.data)):\n",
|
||||
" yield self.data[i]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" def parse_json_object(self, line):\n",
|
||||
" try:\n",
|
||||
" return json.loads(line)\n",
|
||||
@@ -637,7 +642,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
40
examples/onnx/1l_lppool/gen.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
self.layer = nn.LPPool2d(2, 1, (1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return self.layer(x)[0]
|
||||
|
||||
|
||||
circuit = Model()
|
||||
|
||||
x = torch.empty(1, 3, 2, 2).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
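nn.LPPool2d(2, 1, (1, 1)) is power-average pooling: each window contributes (sum x^p)^(1/p), here with p = 2 and a 1x1 kernel, so every output element reduces to the magnitude of its single input (the inputs above are drawn from [0, 1], so sign is not an issue). One window in NumPy (sketch):

# LP pooling over a single window, NumPy sketch
import numpy as np

def lp_pool_window(window, p=2):
    return np.sum(window ** p) ** (1.0 / p)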
|
||||
|
||||
1
examples/onnx/1l_lppool/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.7549541592597961, 0.990360677242279, 0.9473411440849304, 0.3951031565666199, 0.8500555753707886, 0.9352139830589294, 0.11867779493331909, 0.9493132829666138, 0.6588345766067505, 0.1933223009109497, 0.12139874696731567, 0.8547163605690002]]}
|
||||
BIN
examples/onnx/1l_lppool/network.onnx
Normal file
Binary file not shown.
42
examples/onnx/celu/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = nn.CELU()(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
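CELU is the continuously differentiable ELU: celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)), with alpha = 1 by default, which is what nn.CELU() above exports. NumPy equivalent (sketch):

import numpy as np

def celu(x, alpha=1.0):
    return np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x / alpha) - 1))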
|
||||
1
examples/onnx/celu/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.35387128591537476, 0.030473172664642334, 0.08707714080810547, 0.2429301142692566, 0.45228832960128784, 0.496021032333374, 0.13245105743408203, 0.8497090339660645]]}
|
||||
BIN
examples/onnx/celu/network.onnx
Normal file
Binary file not shown.
41
examples/onnx/clip/gen.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.clamp(x, min=0.4, max=0.8)
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 2, 2, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/clip/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.03297048807144165, 0.46362626552581787, 0.6044231057167053, 0.4949902892112732, 0.48823297023773193, 0.6798646450042725, 0.6824942231178284, 0.03491640090942383, 0.19608813524246216, 0.24129939079284668, 0.9769315123558044, 0.6306831240653992, 0.7690497636795044, 0.252221941947937, 0.9167693853378296, 0.3882681131362915, 0.9307044148445129, 0.33559417724609375, 0.7815426588058472, 0.3435332179069519, 0.7871478796005249, 0.12240773439407349, 0.5295405983924866, 0.4874419569969177, 0.08262640237808228, 0.1124718189239502, 0.5834914445877075, 0.30927878618240356, 0.48899340629577637, 0.9376634955406189, 0.21893149614334106, 0.526070773601532]]}
|
||||
24
examples/onnx/clip/network.onnx
Normal file
@@ -0,0 +1,24 @@
|
||||
(mis-encoded text dump of the ONNX protobuf omitted: the graph contains a single Clip node with constant bounds 0.4 and 0.8, produced by pytorch 2.2.1, with a dynamic batch_size axis on input and output)
|
||||
41
examples/onnx/gru/gen.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import random
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import json
|
||||
|
||||
|
||||
model = nn.GRU(3, 3) # Input dim is 3, output dim is 3
|
||||
x = torch.randn(1, 3)  # a single time step with input dimension 3
|
||||
|
||||
print(x)
|
||||
|
||||
# Flips the neural net into inference mode
|
||||
model.eval()
|
||||
model.to('cpu')
|
||||
|
||||
# Export the model
|
||||
torch.onnx.export(model, # model being run
|
||||
# model input (or a tuple for multiple inputs)
|
||||
x,
|
||||
# where to save the model (can be a file or file-like object)
|
||||
"network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=10, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
data_array = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data_json = dict(input_data=[data_array])
|
||||
|
||||
print(data_json)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data_json, open("input.json", 'w'))
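For reference, the recurrence this file exports is PyTorch's GRU cell: a reset gate r, an update gate z, and a candidate state n. One step in NumPy (a sketch with the per-gate bias terms merged; PyTorch keeps separate input and hidden biases, and the weight names here are illustrative):

# one GRU step, NumPy sketch
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def gru_step(x, h, W_ir, W_iz, W_in, W_hr, W_hz, W_hn, b_r, b_z, b_n):
    r = sigmoid(W_ir @ x + W_hr @ h + b_r)         # reset gate
    z = sigmoid(W_iz @ x + W_hz @ h + b_z)         # update gate
    n = np.tanh(W_in @ x + r * (W_hn @ h) + b_n)   # candidate state
    return (1 - z) * n + z * h                     # next hidden state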
|
||||
1
examples/onnx/gru/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.4145222008228302, -0.4043896496295929, 0.7545749545097351]]}
|
||||
BIN
examples/onnx/gru/network.onnx
Normal file
Binary file not shown.
42
examples/onnx/hard_max/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.argmax(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/hard_max/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.5505883693695068, 0.0766521692276001, 0.12006187438964844, 0.9497959017753601, 0.9100563526153564, 0.968717098236084, 0.5978299379348755, 0.9419963359832764]]}
|
||||
BIN
examples/onnx/hard_max/network.onnx
Normal file
Binary file not shown.
@@ -9,7 +9,7 @@ class MyModel(nn.Module):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = nn.Logsoftmax()(x)
|
||||
m = nn.Hardsigmoid()(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
{"input_data": [[0.2971532940864563, 0.3465197682380676, 0.05381882190704346, 0.058654189109802246, 0.014198064804077148, 0.06088751554489136, 0.1723427176475525, 0.5115123987197876]]}
|
||||
{"input_data": [[0.8326942324638367, 0.2796096205711365, 0.600328266620636, 0.3701696991920471, 0.17832040786743164, 0.6247223019599915, 0.501872718334198, 0.6961578726768494]]}
|
||||
@@ -1,4 +1,4 @@
|
||||
pytorch2.1.0 → pytorch2.2.1 (producer version bump; the rest of the HardSigmoid protobuf, including its alpha attribute, is mis-encoded binary and omitted)
|
||||
|
||||
41
examples/onnx/hard_swish/gen.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = nn.Hardswish()(x)
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
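Hardswish, which also appears further down in this diff as the new LookupOp::HardSwish { scale }, is the piecewise gated form x * relu6(x + 3) / 6. NumPy equivalent (sketch):

import numpy as np

def hardswish(x):
    return x * np.clip(x + 3, 0, 6) / 6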
|
||||
1
examples/onnx/hard_swish/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.6996762752532959, 0.42992985248565674, 0.5102168321609497, 0.5540387630462646, 0.8489438891410828, 0.8533616065979004, 0.36736780405044556, 0.5859147310256958]]}
|
||||
15
examples/onnx/hard_swish/network.onnx
Normal file
@@ -0,0 +1,15 @@
|
||||
(mis-encoded ONNX protobuf omitted: a single HardSwish node, produced by pytorch 2.2.1, with a dynamic batch_size axis on input and output)
|
||||
@@ -9,7 +9,7 @@ class MyModel(nn.Module):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = nn.Hardsigmoid()(x)
|
||||
m = nn.LogSoftmax()(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
42
examples/onnx/logsumexp/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.logsumexp(x, dim=1)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 2, 2, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
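torch.logsumexp(x, dim=1) computes log sum exp(x) in the numerically stable form m + log sum exp(x - m), where m is the maximum along the reduced axis. NumPy equivalent (sketch):

import numpy as np

def logsumexp(x, axis):
    m = np.max(x, axis=axis, keepdims=True)
    return (m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))).squeeze(axis)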
|
||||
1
examples/onnx/logsumexp/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.7973018884658813, 0.5245689153671265, 0.34149593114852905, 0.1455438733100891, 0.9482707381248474, 0.4221445322036743, 0.001363217830657959, 0.8736765384674072, 0.42954301834106445, 0.7199509739875793, 0.37641745805740356, 0.5920265316963196, 0.42270803451538086, 0.41761744022369385, 0.603948712348938, 0.7250819802284241, 0.047173500061035156, 0.5115441679954529, 0.3743387460708618, 0.16794061660766602, 0.5352339148521423, 0.037976861000061035, 0.65323406457901, 0.5585184097290039, 0.10559147596359253, 0.07827490568161011, 0.6717077493667603, 0.6480781435966492, 0.9780838489532471, 0.8353415131568909, 0.6491701006889343, 0.6573048233985901]]}
|
||||
BIN
examples/onnx/logsumexp/network.onnx
Normal file
Binary file not shown.
13
examples/onnx/lstm_large/input.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"input_data": [
|
||||
[
|
||||
0.8894134163856506,
|
||||
0.8894201517105103
|
||||
]
|
||||
],
|
||||
"output_data": [
|
||||
[
|
||||
0.8436377
|
||||
]
|
||||
]
|
||||
}
|
||||
BIN
examples/onnx/lstm_large/network.onnx
Normal file
Binary file not shown.
42
examples/onnx/mish/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = nn.Mish()(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
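Mish is x * tanh(softplus(x)), which is exactly the Softplus → Tanh → Mul chain visible in the exported graph below. NumPy equivalent (sketch):

import numpy as np

def mish(x):
    return x * np.tanh(np.log1p(np.exp(x)))  # softplus(x) = log(1 + e^x)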
|
||||
1
examples/onnx/mish/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.18563222885131836, 0.4843214750289917, 0.9991059899330139, 0.02534431219100952, 0.8105666041374207, 0.9658406376838684, 0.681107759475708, 0.5365872979164124]]}
|
||||
19
examples/onnx/mish/network.onnx
Normal file
@@ -0,0 +1,19 @@
|
||||
(mis-encoded ONNX protobuf omitted: a Softplus → Tanh → Mul chain implementing Mish, produced by pytorch 2.2.1, with a dynamic batch_size axis)
|
||||
42
examples/onnx/reducel1/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.norm(x, p=1, dim=1)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 2, 2, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/reducel1/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.02284395694732666, 0.7941043376922607, 0.07971876859664917, 0.8898420929908752, 0.8233054280281067, 0.11066079139709473, 0.4424799084663391, 0.4355071783065796, 0.6723723411560059, 0.6818525195121765, 0.8726171851158142, 0.17742449045181274, 0.054257750511169434, 0.5775953531265259, 0.7758923172950745, 0.8431423306465149, 0.7602444887161255, 0.29686522483825684, 0.22489851713180542, 0.0675363540649414, 0.981339693069458, 0.15771394968032837, 0.5801441669464111, 0.9044001698493958, 0.49266451597213745, 0.42621421813964844, 0.35345613956451416, 0.042848050594329834, 0.6908614039421082, 0.5422852039337158, 0.01975083351135254, 0.5772860050201416]]}
|
||||
BIN
examples/onnx/reducel1/network.onnx
Normal file
Binary file not shown.
42
examples/onnx/reducel2/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.norm(x, p=2, dim=1)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 2, 2, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
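reducel1 above and reducel2 here export torch.norm with p = 1 and p = 2, i.e. the ONNX ReduceL1/ReduceL2 reductions: sum |x| and sqrt(sum x^2) along dim 1. NumPy equivalents (sketch):

import numpy as np

def reduce_l1(x, axis=1):
    return np.sum(np.abs(x), axis=axis)

def reduce_l2(x, axis=1):
    return np.sqrt(np.sum(x * x, axis=axis))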
|
||||
1
examples/onnx/reducel2/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.8709188103675842, 0.11553549766540527, 0.27376580238342285, 0.7518517971038818, 0.7879393100738525, 0.8765475749969482, 0.14315760135650635, 0.8982420563697815, 0.7274006605148315, 0.39007169008255005, 0.729040801525116, 0.11306107044219971, 0.658822774887085, 0.666404664516449, 0.3001367449760437, 0.45343858003616333, 0.7460223436355591, 0.7423691749572754, 0.7544230818748474, 0.5674425959587097, 0.8728761672973633, 0.27062875032424927, 0.1595977544784546, 0.22975260019302368, 0.6711723208427429, 0.8265992403030396, 0.48679041862487793, 0.689740777015686, 0.330846905708313, 0.5630669593811035, 0.8058932423591614, 0.5802426338195801]]}
|
||||
BIN
examples/onnx/reducel2/network.onnx
Normal file
Binary file not shown.
42
examples/onnx/tril/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.triu(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 3, 3).uniform_(0, 5)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/tril/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.4870188236236572, 2.275230646133423, 3.126268148422241, 0.6412187218666077, 0.9967470169067383, 1.9814395904541016, 1.6355383396148682, 0.6397527456283569, 0.7825168967247009]]}
|
||||
BIN
examples/onnx/tril/network.onnx
Normal file
Binary file not shown.
42
examples/onnx/triu/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.tril(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 3, 3).uniform_(0, 5)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/triu/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.2898547053337097, 1.8070811033248901, 0.30266255140304565, 3.00581955909729, 0.5379888415336609, 1.7057424783706665, 2.415961265563965, 0.589233934879303, 0.03824889659881592]]}
|
||||
BIN
examples/onnx/triu/network.onnx
Normal file
Binary file not shown.
@@ -17,19 +17,19 @@
|
||||
"clean": "rm -r dist || true",
|
||||
"build:commonjs": "tsc --project tsconfig.commonjs.json && resolve-tspaths -p tsconfig.commonjs.json",
|
||||
"build:esm": "tsc --project tsconfig.esm.json && resolve-tspaths -p tsconfig.esm.json",
|
||||
"build": "pnpm run clean && pnpm run build:commonjs && pnpm run build:esm"
|
||||
"build": "npm run clean && npm run build:commonjs && npm run build:esm"
|
||||
},
|
||||
"dependencies": {
|
||||
"@ethereumjs/common": "^4.0.0",
|
||||
"@ethereumjs/evm": "^2.0.0",
|
||||
"@ethereumjs/statemanager": "^2.0.0",
|
||||
"@ethereumjs/tx": "^5.0.0",
|
||||
"@ethereumjs/util": "^9.0.0",
|
||||
"@ethereumjs/vm": "^7.0.0",
|
||||
"@ethersproject/abi": "^5.7.0",
|
||||
"@ethereumjs/common": "4.0.0",
|
||||
"@ethereumjs/evm": "2.0.0",
|
||||
"@ethereumjs/statemanager": "2.0.0",
|
||||
"@ethereumjs/tx": "5.0.0",
|
||||
"@ethereumjs/util": "9.0.0",
|
||||
"@ethereumjs/vm": "7.0.0",
|
||||
"@ethersproject/abi": "5.7.0",
|
||||
"@ezkljs/engine": "^9.4.4",
|
||||
"ethers": "^6.7.1",
|
||||
"json-bigint": "^1.0.0"
|
||||
"ethers": "6.7.1",
|
||||
"json-bigint": "1.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.8.3",
|
||||
|
||||
18
in-browser-evm-verifier/pnpm-lock.yaml
generated
@@ -6,34 +6,34 @@ settings:
|
||||
|
||||
dependencies:
|
||||
'@ethereumjs/common':
|
||||
specifier: ^4.0.0
|
||||
specifier: 4.0.0
|
||||
version: 4.0.0
|
||||
'@ethereumjs/evm':
|
||||
specifier: ^2.0.0
|
||||
specifier: 2.0.0
|
||||
version: 2.0.0
|
||||
'@ethereumjs/statemanager':
|
||||
specifier: ^2.0.0
|
||||
specifier: 2.0.0
|
||||
version: 2.0.0
|
||||
'@ethereumjs/tx':
|
||||
specifier: ^5.0.0
|
||||
specifier: 5.0.0
|
||||
version: 5.0.0
|
||||
'@ethereumjs/util':
|
||||
specifier: ^9.0.0
|
||||
specifier: 9.0.0
|
||||
version: 9.0.0
|
||||
'@ethereumjs/vm':
|
||||
specifier: ^7.0.0
|
||||
specifier: 7.0.0
|
||||
version: 7.0.0
|
||||
'@ethersproject/abi':
|
||||
specifier: ^5.7.0
|
||||
specifier: 5.7.0
|
||||
version: 5.7.0
|
||||
'@ezkljs/engine':
|
||||
specifier: ^9.4.4
|
||||
version: 9.4.4
|
||||
ethers:
|
||||
specifier: ^6.7.1
|
||||
specifier: 6.7.1
|
||||
version: 6.7.1
|
||||
json-bigint:
|
||||
specifier: ^1.0.0
|
||||
specifier: 1.0.0
|
||||
version: 1.0.0
|
||||
|
||||
devDependencies:
|
||||
|
||||
@@ -36,7 +36,7 @@ if [ "$(which ezkl)s" != "s" ] && [ "$(which ezkl)" != "$EZKL_DIR/ezkl" ] ; the
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ ":$PATH:" != *":${EZKl_DIR}:"* ]]; then
|
||||
if [[ ":$PATH:" != *":${EZKL_DIR}:"* ]]; then
|
||||
# Add the ezkl directory to the path and ensure the old PATH variables remain.
|
||||
echo >> $PROFILE && echo "export PATH=\"\$PATH:$EZKL_DIR\"" >> $PROFILE
|
||||
fi
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[build-system]
|
||||
requires = ["maturin>=0.14,<0.15"]
|
||||
requires = ["maturin>=1.0,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
attrs==22.2.0
|
||||
exceptiongroup==1.1.1
|
||||
importlib-metadata==6.1.0
|
||||
attrs==23.2.0
|
||||
exceptiongroup==1.2.0
|
||||
importlib-metadata==7.1.0
|
||||
iniconfig==2.0.0
|
||||
maturin==1.0.1
|
||||
packaging==23.0
|
||||
pluggy==1.0.0
|
||||
pytest==7.2.2
|
||||
maturin==1.5.1
|
||||
packaging==24.0
|
||||
pluggy==1.4.0
|
||||
pytest==8.1.1
|
||||
tomli==2.0.1
|
||||
typing-extensions==4.5.0
|
||||
zipp==3.15.0
|
||||
onnx==1.14.1
|
||||
onnxruntime==1.14.1
|
||||
numpy==1.21.6
|
||||
typing-extensions==4.10.0
|
||||
zipp==3.18.1
|
||||
onnx==1.15.0
|
||||
onnxruntime==1.17.1
|
||||
numpy==1.26.4
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
[toolchain]
|
||||
channel = "nightly-2023-08-24"
|
||||
channel = "nightly-2024-02-06"
|
||||
components = ["rustfmt", "clippy"]
|
||||
|
||||
@@ -15,6 +15,8 @@ pub use planner::*;
|
||||
|
||||
use crate::tensor::{TensorType, ValTensor};
|
||||
|
||||
use super::region::ConstantsMap;
|
||||
|
||||
/// Module trait used to extend ezkl functionality
|
||||
pub trait Module<F: PrimeField + TensorType + PartialOrd> {
|
||||
/// Config
|
||||
@@ -39,6 +41,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
|
||||
&self,
|
||||
layouter: &mut impl Layouter<F>,
|
||||
input: &[ValTensor<F>],
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<Self::InputAssignments, Error>;
|
||||
/// Layout
|
||||
fn layout(
|
||||
@@ -46,6 +49,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
|
||||
layouter: &mut impl Layouter<F>,
|
||||
input: &[ValTensor<F>],
|
||||
row_offset: usize,
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<ValTensor<F>, Error>;
|
||||
/// Number of instance values the module uses every time it is applied
|
||||
fn instance_increment_input(&self) -> Vec<usize>;
|
||||
|
||||
@@ -4,6 +4,8 @@ is already implemented in halo2_gadgets, there is no wrapper chip that makes it
|
||||
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
|
||||
*/
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
// This chip adds a set of advice columns to the gadget Chip to store the inputs of the hash
|
||||
use halo2_proofs::halo2curves::bn256::Fr as Fp;
|
||||
use halo2_proofs::poly::commitment::{Blind, CommitmentScheme, Params};
|
||||
@@ -13,6 +15,7 @@ use halo2curves::group::prime::PrimeCurveAffine;
|
||||
use halo2curves::group::Curve;
|
||||
use halo2curves::CurveAffine;
|
||||
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::tensor::{Tensor, ValTensor, ValType, VarTensor};
|
||||
|
||||
use super::Module;
|
||||
@@ -41,12 +44,11 @@ impl PolyCommitChip {
|
||||
/// Commit to the message using the KZG commitment scheme
|
||||
pub fn commit<Scheme: CommitmentScheme<Scalar = Fp, Curve = G1Affine>>(
|
||||
message: Vec<Scheme::Scalar>,
|
||||
degree: u32,
|
||||
num_unusable_rows: u32,
|
||||
params: &Scheme::ParamsProver,
|
||||
) -> Vec<G1Affine> {
|
||||
let k = params.k();
|
||||
let domain = halo2_proofs::poly::EvaluationDomain::new(degree, k);
|
||||
let domain = halo2_proofs::poly::EvaluationDomain::new(2, k);
|
||||
let n = 2_u64.pow(k) - num_unusable_rows as u64;
|
||||
let num_poly = (message.len() / n as usize) + 1;
|
||||
let mut poly = vec![domain.empty_lagrange(); num_poly];
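The commit path above now fixes the Lagrange domain to degree 2 and packs the message into chunks, with n = 2^k - num_unusable_rows usable rows per polynomial. The sizing arithmetic, isolated as a Python sketch (names are illustrative):

# polynomial packing arithmetic from the commit() change (sketch)
def packing(message_len, k, num_unusable_rows):
    n = 2**k - num_unusable_rows       # usable rows per polynomial
    num_poly = message_len // n + 1    # mirrors (message.len() / n) + 1 above
    return n, num_poly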
|
||||
@@ -107,6 +109,7 @@ impl Module<Fp> for PolyCommitChip {
|
||||
&self,
|
||||
_: &mut impl Layouter<Fp>,
|
||||
_: &[ValTensor<Fp>],
|
||||
_: &mut ConstantsMap<Fp>,
|
||||
) -> Result<Self::InputAssignments, Error> {
|
||||
Ok(())
|
||||
}
|
||||
@@ -119,11 +122,24 @@ impl Module<Fp> for PolyCommitChip {
|
||||
layouter: &mut impl Layouter<Fp>,
|
||||
input: &[ValTensor<Fp>],
|
||||
_: usize,
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<ValTensor<Fp>, Error> {
|
||||
assert_eq!(input.len(), 1);
|
||||
|
||||
let local_constants = constants.clone();
|
||||
layouter.assign_region(
|
||||
|| "PolyCommit",
|
||||
|mut region| self.config.inputs.assign(&mut region, 0, &input[0]),
|
||||
|mut region| {
|
||||
let mut local_inner_constants = local_constants.clone();
|
||||
let res = self.config.inputs.assign(
|
||||
&mut region,
|
||||
0,
|
||||
&input[0],
|
||||
&mut local_inner_constants,
|
||||
)?;
|
||||
*constants = local_inner_constants;
|
||||
Ok(res)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -184,7 +200,12 @@ mod tests {
|
||||
mut layouter: impl Layouter<Fp>,
|
||||
) -> Result<(), Error> {
|
||||
let polycommit_chip = PolyCommitChip::new(config);
|
||||
polycommit_chip.layout(&mut layouter, &[self.message.clone()], 0);
|
||||
polycommit_chip.layout(
|
||||
&mut layouter,
|
||||
&[self.message.clone()],
|
||||
0,
|
||||
&mut HashMap::new(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ use maybe_rayon::slice::ParallelSlice;
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::tensor::{Tensor, ValTensor, ValType};
|
||||
|
||||
use super::Module;
|
||||
@@ -172,12 +173,15 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
&self,
|
||||
layouter: &mut impl Layouter<Fp>,
|
||||
message: &[ValTensor<Fp>],
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<Self::InputAssignments, Error> {
|
||||
assert_eq!(message.len(), 1);
|
||||
let message = message[0].clone();
|
||||
|
||||
let start_time = instant::Instant::now();
|
||||
|
||||
let local_constants = constants.clone();
|
||||
|
||||
let res = layouter.assign_region(
|
||||
|| "load message",
|
||||
|mut region| {
|
||||
@@ -199,12 +203,26 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
|
||||
Ok(v.clone())
|
||||
}
|
||||
ValType::Constant(f) => region.assign_advice_from_constant(
|
||||
|| format!("load message_{}", i),
|
||||
self.config.hash_inputs[x],
|
||||
y,
|
||||
*f,
|
||||
),
|
||||
ValType::Constant(f) => {
|
||||
if local_constants.contains_key(f) {
|
||||
Ok(constants.get(f).unwrap().assigned_cell().ok_or({
|
||||
log::error!("constant not previously assigned");
|
||||
Error::Synthesis
|
||||
})?)
|
||||
} else {
|
||||
let res = region.assign_advice_from_constant(
|
||||
|| format!("load message_{}", i),
|
||||
self.config.hash_inputs[x],
|
||||
y,
|
||||
*f,
|
||||
)?;
|
||||
|
||||
constants
|
||||
.insert(*f, ValType::AssignedConstant(res.clone(), *f));
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
e => {
|
||||
log::error!(
|
||||
"wrong input type {:?}, must be previously assigned",
|
||||
@@ -270,8 +288,9 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
layouter: &mut impl Layouter<Fp>,
|
||||
input: &[ValTensor<Fp>],
|
||||
row_offset: usize,
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<ValTensor<Fp>, Error> {
|
||||
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input)?;
|
||||
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
|
||||
// extract the values from the input cells
|
||||
let mut assigned_input: Tensor<ValType<Fp>> =
|
||||
input_cells.iter().map(|e| ValType::from(e.clone())).into();
|
||||
@@ -434,7 +453,7 @@ mod tests {
|
||||
*,
|
||||
};
|
||||
|
||||
use std::marker::PhantomData;
|
||||
use std::{collections::HashMap, marker::PhantomData};
|
||||
|
||||
use halo2_gadgets::poseidon::primitives::Spec;
|
||||
use halo2_proofs::{
|
||||
@@ -477,7 +496,12 @@ mod tests {
|
||||
mut layouter: impl Layouter<Fp>,
|
||||
) -> Result<(), Error> {
|
||||
let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, L> = PoseidonChip::new(config);
|
||||
chip.layout(&mut layouter, &[self.message.clone()], 0)?;
|
||||
chip.layout(
|
||||
&mut layouter,
|
||||
&[self.message.clone()],
|
||||
0,
|
||||
&mut HashMap::new(),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
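The pattern running through these hunks is a ConstantsMap used as a memo table: the first occurrence of a field constant is assigned with assign_advice_from_constant, and every later occurrence reuses the cached assigned cell instead of burning a fresh constant assignment. Reduced to a Python sketch (a dict standing in for ConstantsMap):

# constant-assignment caching, as in the Poseidon layout_inputs change (sketch)
def assign_constant(f, constants, assign_fresh):
    """constants: value -> previously assigned cell; assign_fresh: the real assignment."""
    if f in constants:
        return constants[f]       # reuse the earlier AssignedConstant
    cell = assign_fresh(f)        # region.assign_advice_from_constant(...)
    constants[f] = cell
    return cell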
|
||||
|
||||
@@ -345,7 +345,7 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
|
||||
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
|
||||
Self {
|
||||
@@ -956,20 +956,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
values: &[ValTensor<F>],
|
||||
op: Box<dyn Op<F>>,
|
||||
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
|
||||
let res = op.layout(self, region, values)?;
|
||||
|
||||
if matches!(&self.check_mode, CheckMode::SAFE) && !region.is_dummy() {
|
||||
if let Some(claimed_output) = &res {
|
||||
// during key generation this will be unknown vals so we use this as a flag to check
|
||||
let mut is_assigned = !claimed_output.any_unknowns()?;
|
||||
for val in values.iter() {
|
||||
is_assigned = is_assigned && !val.any_unknowns()?;
|
||||
}
|
||||
if is_assigned {
|
||||
op.safe_mode_check(claimed_output, values)?;
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(res)
|
||||
op.layout(self, region, values)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use super::*;
|
||||
use crate::{
|
||||
circuit::{layouts, utils, Tolerance},
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
fieldutils::i128_to_felt,
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
|
||||
tensor::{self, Tensor, TensorType, ValTensor},
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -29,15 +29,15 @@ pub enum HybridOp {
|
||||
dim: usize,
|
||||
},
|
||||
SumPool {
|
||||
padding: [(usize, usize); 2],
|
||||
stride: (usize, usize),
|
||||
kernel_shape: (usize, usize),
|
||||
padding: Vec<(usize, usize)>,
|
||||
stride: Vec<usize>,
|
||||
kernel_shape: Vec<usize>,
|
||||
normalized: bool,
|
||||
},
|
||||
MaxPool2d {
|
||||
padding: [(usize, usize); 2],
|
||||
stride: (usize, usize),
|
||||
pool_dims: (usize, usize),
|
||||
MaxPool {
|
||||
padding: Vec<(usize, usize)>,
|
||||
stride: Vec<usize>,
|
||||
pool_dims: Vec<usize>,
|
||||
},
|
||||
ReduceMin {
|
||||
axes: Vec<usize>,
|
||||
@@ -46,7 +46,8 @@ pub enum HybridOp {
|
||||
dim: usize,
|
||||
},
|
||||
Softmax {
|
||||
scale: utils::F32,
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
axes: Vec<usize>,
|
||||
},
|
||||
RangeCheck(Tolerance),
|
||||
@@ -70,7 +71,7 @@ pub enum HybridOp {
|
||||
},
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
|
||||
///
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
match self {
|
||||
@@ -84,86 +85,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
/// Matches a [Op] to an operation in the `tensor::ops` module.
|
||||
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
let x = inputs[0].clone().map(|x| felt_to_i128(x));
|
||||
|
||||
let res = match &self {
|
||||
HybridOp::ReduceMax { axes, .. } => tensor::ops::max_axes(&x, axes)?,
|
||||
HybridOp::ReduceMin { axes, .. } => tensor::ops::min_axes(&x, axes)?,
|
||||
HybridOp::Div { denom, .. } => {
|
||||
crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64)
|
||||
}
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
..
|
||||
} => crate::tensor::ops::nonlinearities::recip(
|
||||
&x,
|
||||
input_scale.0 as f64,
|
||||
output_scale.0 as f64,
|
||||
),
|
||||
HybridOp::ReduceArgMax { dim } => tensor::ops::argmax_axes(&x, *dim)?,
|
||||
HybridOp::ReduceArgMin { dim } => tensor::ops::argmin_axes(&x, *dim)?,
|
||||
HybridOp::Gather { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::gather(&x, idx, *dim)?
|
||||
} else {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?
|
||||
}
|
||||
}
|
||||
HybridOp::OneHot { dim, num_classes } => {
|
||||
tensor::ops::one_hot(&x, *num_classes, *dim)?.clone()
|
||||
}
|
||||
|
||||
HybridOp::TopK { dim, k, largest } => tensor::ops::topk_axes(&x, *k, *dim, *largest)?,
|
||||
HybridOp::MaxPool2d {
|
||||
padding,
|
||||
stride,
|
||||
pool_dims,
|
||||
..
|
||||
} => tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
kernel_shape,
|
||||
normalized,
|
||||
} => tensor::ops::sumpool(&x, *padding, *stride, *kernel_shape, *normalized)?,
|
||||
HybridOp::Softmax { scale, axes } => {
|
||||
tensor::ops::nonlinearities::softmax_axes(&x, scale.into(), axes)
|
||||
}
|
||||
HybridOp::RangeCheck(tol) => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
|
||||
}
|
||||
HybridOp::Greater => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::greater(&x, &y)?
|
||||
}
|
||||
HybridOp::GreaterEqual => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::greater_equal(&x, &y)?
|
||||
}
|
||||
HybridOp::Less => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::less(&x, &y)?
|
||||
}
|
||||
HybridOp::LessEqual => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::less_equal(&x, &y)?
|
||||
}
|
||||
HybridOp::Equals => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::equals(&x, &y)?
|
||||
}
|
||||
};
|
||||
|
||||
// convert back to felt
|
||||
let output = res.map(|x| i128_to_felt(x));
|
||||
|
||||
Ok(ForwardResult { output })
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
@@ -193,18 +114,25 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
),
|
||||
HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
|
||||
HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
|
||||
HybridOp::MaxPool2d {
|
||||
HybridOp::MaxPool {
|
||||
padding,
|
||||
stride,
|
||||
pool_dims,
|
||||
} => format!(
|
||||
"MAXPOOL2D (padding={:?}, stride={:?}, pool_dims={:?})",
|
||||
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
|
||||
padding, stride, pool_dims
|
||||
),
|
||||
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
|
||||
HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
|
||||
HybridOp::Softmax { scale, axes } => {
|
||||
format!("SOFTMAX (scale={}, axes={:?})", scale, axes)
|
||||
HybridOp::Softmax {
|
||||
input_scale,
|
||||
output_scale,
|
||||
axes,
|
||||
} => {
|
||||
format!(
|
||||
"SOFTMAX (input_scale={}, output_scale={}, axes={:?})",
|
||||
input_scale, output_scale, axes
|
||||
)
|
||||
}
|
||||
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
|
||||
HybridOp::Greater => "GREATER".into(),
|
||||
@@ -238,9 +166,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
*padding,
|
||||
*stride,
|
||||
*kernel_shape,
|
||||
padding,
|
||||
stride,
|
||||
kernel_shape,
|
||||
*normalized,
|
||||
)?,
|
||||
HybridOp::Recip {
|
||||
@@ -300,17 +228,17 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
}
|
||||
}
|
||||
|
||||
HybridOp::MaxPool2d {
|
||||
HybridOp::MaxPool {
|
||||
padding,
|
||||
stride,
|
||||
pool_dims,
|
||||
} => layouts::max_pool2d(
|
||||
} => layouts::max_pool(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
*padding,
|
||||
*stride,
|
||||
*pool_dims,
|
||||
padding,
|
||||
stride,
|
||||
pool_dims,
|
||||
)?,
|
||||
HybridOp::ReduceMax { axes } => {
|
||||
layouts::max_axes(config, region, values[..].try_into()?, axes)?
|
||||
@@ -324,9 +252,18 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
HybridOp::ReduceArgMin { dim } => {
|
||||
layouts::argmin_axes(config, region, values[..].try_into()?, *dim)?
|
||||
}
|
||||
HybridOp::Softmax { scale, axes } => {
|
||||
layouts::softmax_axes(config, region, values[..].try_into()?, *scale, axes)?
|
||||
}
|
||||
HybridOp::Softmax {
|
||||
input_scale,
|
||||
output_scale,
|
||||
axes,
|
||||
} => layouts::softmax_axes(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
*input_scale,
|
||||
*output_scale,
|
||||
axes,
|
||||
)?,
|
||||
HybridOp::RangeCheck(tol) => layouts::range_check_percent(
|
||||
config,
|
||||
region,
|
||||
@@ -359,8 +296,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
| HybridOp::ReduceArgMax { .. }
|
||||
| HybridOp::OneHot { .. }
|
||||
| HybridOp::ReduceArgMin { .. } => 0,
|
||||
HybridOp::Softmax { .. } => 2 * in_scales[0],
|
||||
HybridOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.0 as f64),
|
||||
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
|
||||
multiplier_to_scale(output_scale.0 as f64)
|
||||
}
|
||||
_ => in_scales[0],
|
||||
};
|
||||
Ok(scale)
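Softmax now carries separate input_scale and output_scale fields, and its out_scale comes from multiplier_to_scale(output_scale) instead of doubling the input scale. In fixed-point terms the op dequantizes by the input multiplier, applies softmax, and requantizes by the output multiplier; an illustrative NumPy sketch of that contract (not the circuit code):

# fixed-point softmax with distinct input/output multipliers (sketch)
import numpy as np

def quantized_softmax(q, input_scale, output_scale, axis=-1):
    x = q / input_scale                                   # dequantize
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))   # stable softmax
    s = e / np.sum(e, axis=axis, keepdims=True)
    return np.round(s * output_scale)                     # requantize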
|
||||
|
||||
File diff suppressed because it is too large
@@ -123,6 +123,9 @@ pub enum LookupOp {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
HardSwish {
|
||||
scale: utils::F32,
|
||||
},
|
||||
}
|
||||
|
||||
impl LookupOp {
|
||||
@@ -132,15 +135,12 @@ impl LookupOp {
|
||||
let range = range as i128;
|
||||
(-range, range)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
/// Returns a reference to the Any trait.
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
/// Matches a [Op] to an operation in the `tensor::ops` module.
|
||||
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
x: &[Tensor<F>],
|
||||
) -> Result<ForwardResult<F>, TensorError> {
|
||||
let x = x[0].clone().map(|x| felt_to_i128(x));
|
||||
let res = match &self {
|
||||
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
|
||||
@@ -223,12 +223,22 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
LookupOp::ATan { scale } => Ok(tensor::ops::nonlinearities::atan(&x, scale.into())),
|
||||
LookupOp::ATanh { scale } => Ok(tensor::ops::nonlinearities::atanh(&x, scale.into())),
|
||||
LookupOp::Tanh { scale } => Ok(tensor::ops::nonlinearities::tanh(&x, scale.into())),
|
||||
LookupOp::HardSwish { scale } => {
|
||||
Ok(tensor::ops::nonlinearities::hardswish(&x, scale.into()))
|
||||
}
|
||||
}?;
|
||||
|
||||
let output = res.map(|x| i128_to_felt(x));
|
||||
|
||||
Ok(ForwardResult { output })
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
|
||||
/// Returns a reference to the Any trait.
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns the name of the operation
|
||||
fn as_string(&self) -> String {
|
||||
@@ -276,6 +286,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
LookupOp::ASin { scale } => format!("ASIN(scale={})", scale),
|
||||
LookupOp::Sinh { scale } => format!("SINH(scale={})", scale),
|
||||
LookupOp::ASinh { scale } => format!("ASINH(scale={})", scale),
|
||||
LookupOp::HardSwish { scale } => format!("HARDSWISH(scale={})", scale),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};

use crate::{
    graph::quantize_tensor,
    tensor::{self, Tensor, TensorError, TensorType, ValTensor},
    tensor::{self, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;

@@ -27,14 +27,14 @@ pub mod region;

/// A struct representing the result of a forward pass.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
    pub(crate) output: Tensor<F>,
}

/// A trait representing operations that can be represented as constraints in a circuit.
pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send + Sync + Any {
    /// Matches a [Op] to an operation in the `tensor::ops` module.
    fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
    std::fmt::Debug + Send + Sync + Any
{
    /// Returns a string representation of the operation.
    fn as_string(&self) -> String;

@@ -69,36 +69,9 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send +

    /// Returns a reference to the Any trait.
    fn as_any(&self) -> &dyn Any;

    /// Safe mode output check
    fn safe_mode_check(
        &self,
        claimed_output: &ValTensor<F>,
        original_values: &[ValTensor<F>],
    ) -> Result<(), TensorError> {
        let felt_evals = original_values
            .iter()
            .map(|v| {
                let mut evals = v.get_felt_evals().map_err(|_| TensorError::FeltError)?;
                evals.reshape(v.dims())?;
                Ok(evals)
            })
            .collect::<Result<Vec<_>, _>>()?;

        let ref_op: Tensor<F> = self.f(&felt_evals)?.output;

        let mut output = claimed_output
            .get_felt_evals()
            .map_err(|_| TensorError::FeltError)?;
        output.reshape(claimed_output.dims())?;

        assert_eq!(output, ref_op);

        Ok(())
    }
}

impl<F: PrimeField + TensorType + PartialOrd> Clone for Box<dyn Op<F>> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
    fn clone(&self) -> Self {
        self.clone_dyn()
    }
@@ -165,7 +138,7 @@ pub struct Input {
    pub datum_type: InputType,
}

impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
    fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
        Ok(self.scale)
    }
@@ -174,12 +147,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
        self
    }

    fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
        Ok(ForwardResult {
            output: x[0].clone(),
        })
    }

    fn as_string(&self) -> String {
        "Input".into()
    }
@@ -226,16 +193,13 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Unknown;

impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknown {
    fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
        Ok(0)
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
        Err(TensorError::WrongMethod)
    }

    fn as_string(&self) -> String {
        "Unknown".into()
@@ -256,7 +220,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {

///
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Constant<F: PrimeField + TensorType + PartialOrd> {
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
    ///
    pub quantized_values: Tensor<F>,
    ///
@@ -266,7 +230,7 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd> {
    pub pre_assigned_val: Option<ValTensor<F>>,
}

impl<F: PrimeField + TensorType + PartialOrd> Constant<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
    ///
    pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
        Self {
@@ -293,17 +257,18 @@ impl<F: PrimeField + TensorType + PartialOrd> Constant<F> {
    }
}

impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F>
    for Constant<F>
impl<
        F: PrimeField
            + TensorType
            + PartialOrd
            + std::hash::Hash
            + Serialize
            + for<'de> Deserialize<'de>,
    > Op<F> for Constant<F>
{
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
        let output = self.quantized_values.clone();

        Ok(ForwardResult { output })
    }

    fn as_string(&self) -> String {
        format!("CONST (scale={})", self.quantized_values.scale().unwrap())

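A quick aside on the recurring `+ std::hash::Hash` bound added throughout this changeset: it is what allows assigned constants to be deduplicated in a map keyed by the field element itself (see the new `ConstantsMap` in region.rs below). A toy sketch of the idea, with a generic `F` standing in for the field type:

use std::collections::HashMap;

// Each distinct constant gets exactly one slot; repeats collapse onto the
// first assignment, which is why the key type must be Hash + Eq.
fn dedup_constants<F: std::hash::Hash + Eq + Clone>(values: &[F]) -> HashMap<F, usize> {
    let mut map = HashMap::new();
    for v in values {
        let next = map.len();
        map.entry(v.clone()).or_insert(next);
    }
    map
}
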
@@ -1,6 +1,5 @@
use crate::{
    circuit::layouts,
    fieldutils::felt_to_i128,
    tensor::{self, Tensor, TensorError},
};

@@ -32,8 +31,8 @@ pub enum PolyOp {
        equation: String,
    },
    Conv {
        padding: [(usize, usize); 2],
        stride: (usize, usize),
        padding: Vec<(usize, usize)>,
        stride: Vec<usize>,
    },
    Downsample {
        axis: usize,
@@ -41,9 +40,9 @@ pub enum PolyOp {
        modulo: usize,
    },
    DeConv {
        padding: [(usize, usize); 2],
        output_padding: (usize, usize),
        stride: (usize, usize),
        padding: Vec<(usize, usize)>,
        output_padding: Vec<usize>,
        stride: Vec<usize>,
    },
    Add,
    Sub,
@@ -58,10 +57,13 @@ pub enum PolyOp {
        destination: usize,
    },
    Flatten(Vec<usize>),
    Pad([(usize, usize); 2]),
    Pad(Vec<(usize, usize)>),
    Sum {
        axes: Vec<usize>,
    },
    MeanOfSquares {
        axes: Vec<usize>,
    },
    Prod {
        axes: Vec<usize>,
        len_prod: usize,
@@ -83,10 +85,20 @@ pub enum PolyOp {
    And,
    Or,
    Xor,
    Trilu {
        upper: bool,
        k: i32,
    },
}

impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F>
    for PolyOp
impl<
        F: PrimeField
            + TensorType
            + PartialOrd
            + std::hash::Hash
            + Serialize
            + for<'de> Deserialize<'de>,
    > Op<F> for PolyOp
{
    /// Returns a reference to the Any trait.
    fn as_any(&self) -> &dyn Any {
@@ -95,10 +107,28 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<

    fn as_string(&self) -> String {
        match &self {
            PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
            PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
            PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
            PolyOp::ScatterND { .. } => "SCATTERND".into(),
            PolyOp::GatherElements { dim, constant_idx } => format!(
                "GATHERELEMENTS (dim={}, constant_idx={})",
                dim,
                constant_idx.is_some()
            ),
            PolyOp::GatherND {
                batch_dims,
                indices,
            } => format!(
                "GATHERND (batch_dims={}, constant_idx={})",
                batch_dims,
                indices.is_some()
            ),
            PolyOp::MeanOfSquares { axes } => format!("MEANOFSQUARES (axes={:?})", axes),
            PolyOp::ScatterElements { dim, constant_idx } => format!(
                "SCATTERELEMENTS (dim={}, constant_idx={})",
                dim,
                constant_idx.is_some()
            ),
            PolyOp::ScatterND { constant_idx } => {
                format!("SCATTERND (constant_idx={})", constant_idx.is_some())
            }
            PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
            PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
            PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -110,15 +140,26 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
            }
            PolyOp::Reshape(shape) => format!("RESHAPE (shape={:?})", shape),
            PolyOp::Flatten(_) => "FLATTEN".into(),
            PolyOp::Pad(_) => "PAD".into(),
            PolyOp::Pad(pads) => format!("PAD (pads={:?})", pads),
            PolyOp::Add => "ADD".into(),
            PolyOp::Mult => "MULT".into(),
            PolyOp::Sub => "SUB".into(),
            PolyOp::Sum { .. } => "SUM".into(),
            PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
            PolyOp::Prod { .. } => "PROD".into(),
            PolyOp::Pow(_) => "POW".into(),
            PolyOp::Conv { .. } => "CONV".into(),
            PolyOp::DeConv { .. } => "DECONV".into(),
            PolyOp::Conv { stride, padding } => {
                format!("CONV (stride={:?}, padding={:?})", stride, padding)
            }
            PolyOp::DeConv {
                stride,
                padding,
                output_padding,
            } => {
                format!(
                    "DECONV (stride={:?}, padding={:?}, output_padding={:?})",
                    stride, padding, output_padding
                )
            }
            PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
            PolyOp::Slice { axis, start, end } => {
                format!("SLICE (axis={}, start={}, end={})", axis, start, end)
@@ -128,148 +169,10 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
            PolyOp::And => "AND".into(),
            PolyOp::Or => "OR".into(),
            PolyOp::Xor => "XOR".into(),
            PolyOp::Trilu { upper, k } => format!("TRILU (upper={}, k={})", upper, k),
        }
    }

    /// Matches a [Op] to an operation in the `tensor::ops` module.
    fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
        let mut inputs = inputs.to_vec();
        let res = match &self {
            PolyOp::MultiBroadcastTo { shape } => {
                if 1 != inputs.len() {
                    return Err(TensorError::DimMismatch(
                        "multibroadcastto inputs".to_string(),
                    ));
                }
                inputs[0].expand(shape)
            }
            PolyOp::And => tensor::ops::and(&inputs[0], &inputs[1]),
            PolyOp::Or => tensor::ops::or(&inputs[0], &inputs[1]),
            PolyOp::Xor => tensor::ops::xor(&inputs[0], &inputs[1]),
            PolyOp::Not => tensor::ops::not(&inputs[0]),
            PolyOp::Downsample {
                axis,
                stride,
                modulo,
            } => tensor::ops::downsample(&inputs[0], *axis, *stride, *modulo),
            PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
            PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]),
            PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs),
            PolyOp::Identity { .. } => Ok(inputs[0].clone()),
            PolyOp::Reshape(new_dims) => {
                let mut t = inputs[0].clone();
                t.reshape(new_dims)?;
                Ok(t)
            }
            PolyOp::MoveAxis {
                source,
                destination,
            } => inputs[0].move_axis(*source, *destination),
            PolyOp::Flatten(new_dims) => {
                let mut t = inputs[0].clone();
                t.reshape(new_dims)?;
                Ok(t)
            }
            PolyOp::Pad(p) => {
                if 1 != inputs.len() {
                    return Err(TensorError::DimMismatch("pad inputs".to_string()));
                }
                tensor::ops::pad(&inputs[0], *p)
            }
            PolyOp::Add => tensor::ops::add(&inputs),
            PolyOp::Neg => tensor::ops::neg(&inputs[0]),
            PolyOp::Sub => tensor::ops::sub(&inputs),
            PolyOp::Mult => tensor::ops::mult(&inputs),
            PolyOp::Conv { padding, stride } => tensor::ops::conv(&inputs, *padding, *stride),
            PolyOp::DeConv {
                padding,
                output_padding,
                stride,
            } => tensor::ops::deconv(&inputs, *padding, *output_padding, *stride),
            PolyOp::Pow(u) => {
                if 1 != inputs.len() {
                    return Err(TensorError::DimMismatch("pow inputs".to_string()));
                }
                inputs[0].pow(*u)
            }
            PolyOp::Sum { axes } => {
                if 1 != inputs.len() {
                    return Err(TensorError::DimMismatch("sum inputs".to_string()));
                }
                tensor::ops::sum_axes(&inputs[0], axes)
            }
            PolyOp::Prod { axes, .. } => {
                if 1 != inputs.len() {
                    return Err(TensorError::DimMismatch("prod inputs".to_string()));
                }
                tensor::ops::prod_axes(&inputs[0], axes)
            }
            PolyOp::Concat { axis } => {
                tensor::ops::concat(&inputs.iter().collect::<Vec<_>>(), *axis)
            }
            PolyOp::Slice { axis, start, end } => {
                if 1 != inputs.len() {
                    return Err(TensorError::DimMismatch("slice inputs".to_string()));
                }
                tensor::ops::slice(&inputs[0], axis, start, end)
            }
            PolyOp::GatherElements { dim, constant_idx } => {
                let x = inputs[0].clone();
                let y = if let Some(idx) = constant_idx {
                    idx.clone()
                } else {
                    inputs[1].clone().map(|x| felt_to_i128(x) as usize)
                };
                tensor::ops::gather_elements(&x, &y, *dim)
            }
            PolyOp::GatherND {
                indices,
                batch_dims,
            } => {
                let x = inputs[0].clone();
                let y = if let Some(idx) = indices {
                    idx.clone()
                } else {
                    inputs[1].clone().map(|x| felt_to_i128(x) as usize)
                };
                tensor::ops::gather_nd(&x, &y, *batch_dims)
            }
            PolyOp::ScatterElements { dim, constant_idx } => {
                let x = inputs[0].clone();

                let idx = if let Some(idx) = constant_idx {
                    idx.clone()
                } else {
                    inputs[1].clone().map(|x| felt_to_i128(x) as usize)
                };

                let src = if constant_idx.is_some() {
                    inputs[1].clone()
                } else {
                    inputs[2].clone()
                };
                tensor::ops::scatter(&x, &idx, &src, *dim)
            }

            PolyOp::ScatterND { constant_idx } => {
                let x = inputs[0].clone();
                let idx = if let Some(idx) = constant_idx {
                    idx.clone()
                } else {
                    inputs[1].clone().map(|x| felt_to_i128(x) as usize)
                };
                let src = if constant_idx.is_some() {
                    inputs[1].clone()
                } else {
                    inputs[2].clone()
                };
                tensor::ops::scatter_nd(&x, &idx, &src)
            }
        }?;

        Ok(ForwardResult { output: res })
    }

    fn layout(
        &self,
        config: &mut crate::circuit::BaseConfig<F>,
@@ -280,6 +183,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
            PolyOp::MultiBroadcastTo { shape } => {
                layouts::expand(config, region, values[..].try_into()?, shape)?
            }
            PolyOp::MeanOfSquares { axes } => {
                layouts::mean_of_squares_axes(config, region, values[..].try_into()?, axes)?
            }
            PolyOp::Xor => layouts::xor(config, region, values[..].try_into()?)?,
            PolyOp::Or => layouts::or(config, region, values[..].try_into()?)?,
            PolyOp::And => layouts::and(config, region, values[..].try_into()?)?,
@@ -306,7 +212,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
                layouts::prod_axes(config, region, values[..].try_into()?, axes)?
            }
            PolyOp::Conv { padding, stride } => {
                layouts::conv(config, region, values[..].try_into()?, *padding, *stride)?
                layouts::conv(config, region, values[..].try_into()?, padding, stride)?
            }
            PolyOp::GatherElements { dim, constant_idx } => {
                if let Some(idx) = constant_idx {
@@ -358,9 +264,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
                config,
                region,
                values[..].try_into()?,
                *padding,
                *output_padding,
                *stride,
                padding,
                output_padding,
                stride,
            )?,
            PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
            PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
@@ -376,7 +282,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
                    )));
                }
                let mut input = values[0].clone();
                input.pad(*p)?;
                input.pad(p.clone(), 0)?;
                input
            }
            PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?,
@@ -384,11 +290,15 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
            PolyOp::Slice { axis, start, end } => {
                layouts::slice(config, region, values[..].try_into()?, axis, start, end)?
            }
            PolyOp::Trilu { upper, k } => {
                layouts::trilu(config, region, values[..].try_into()?, k, upper)?
            }
        }))
    }

    fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
        let scale = match self {
            PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
            PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
            PolyOp::Iff => in_scales[1],
            PolyOp::Einsum { .. } => {

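The shape changes above are worth a second look: `Conv`, `DeConv` and `Pad` move from fixed-size 2D arrays and tuples to `Vec`s, presumably so one enum variant can describe any number of spatial dimensions. A sketch of what call sites look like after the change (mirroring the updated tests further down; assumes `PolyOp` is in scope from this module):

fn example_convs() -> (PolyOp, PolyOp) {
    // 2D conv: one (before, after) padding pair and one stride per axis
    let conv_2d = PolyOp::Conv {
        padding: vec![(1, 1); 2],
        stride: vec![2; 2],
    };
    // 1D conv: the same variant with length-1 vectors
    let conv_1d = PolyOp::Conv {
        padding: vec![(1, 1)],
        stride: vec![2],
    };
    (conv_2d, conv_1d)
}
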
@@ -2,24 +2,28 @@ use crate::{
    circuit::table::Range,
    tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
use halo2_proofs::{
    circuit::Region,
    plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use portable_atomic::AtomicI128 as AtomicInt;
use std::{
    cell::RefCell,
    collections::HashSet,
    collections::{HashMap, HashSet},
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Mutex,
    },
};

use portable_atomic::AtomicI128 as AtomicInt;

use super::lookup::LookupOp;

/// Constants map
pub type ConstantsMap<F> = HashMap<F, ValType<F>>;

/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct DynamicLookupIndex {
@@ -120,12 +124,11 @@ impl From<Box<dyn std::error::Error>> for RegionError {

#[derive(Debug)]
/// A context for a region
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
    region: Option<RefCell<Region<'a, F>>>,
    row: usize,
    linear_coord: usize,
    num_inner_cols: usize,
    total_constants: usize,
    dynamic_lookup_index: DynamicLookupIndex,
    shuffle_index: ShuffleIndex,
    used_lookups: HashSet<LookupOp>,
@@ -133,13 +136,34 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
    max_lookup_inputs: i128,
    min_lookup_inputs: i128,
    max_range_size: i128,
    throw_range_check_error: bool,
    witness_gen: bool,
    assigned_constants: ConstantsMap<F>,
}

impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a, F> {
    #[cfg(not(target_arch = "wasm32"))]
    ///
    pub fn increment_total_constants(&mut self, n: usize) {
        self.total_constants += n;
    pub fn debug_report(&self) {
        log::debug!(
            "(rows={}, coord={}, constants={}, max_lookup_inputs={}, min_lookup_inputs={}, max_range_size={}, dynamic_lookup_col_coord={}, shuffle_col_coord={})",
            self.row().to_string().blue(),
            self.linear_coord().to_string().yellow(),
            self.total_constants().to_string().red(),
            self.max_lookup_inputs().to_string().green(),
            self.min_lookup_inputs().to_string().green(),
            self.max_range_size().to_string().green(),
            self.dynamic_lookup_col_coord().to_string().green(),
            self.shuffle_col_coord().to_string().green());
    }

    ///
    pub fn assigned_constants(&self) -> &ConstantsMap<F> {
        &self.assigned_constants
    }

    ///
    pub fn update_constants(&mut self, constants: ConstantsMap<F>) {
        self.assigned_constants.extend(constants);
    }

    ///
@@ -163,8 +187,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
    }

    ///
    pub fn throw_range_check_error(&self) -> bool {
        self.throw_range_check_error
    pub fn witness_gen(&self) -> bool {
        self.witness_gen
    }

    /// Create a new region context
@@ -177,7 +201,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
            num_inner_cols,
            row,
            linear_coord,
            total_constants: 0,
            dynamic_lookup_index: DynamicLookupIndex::default(),
            shuffle_index: ShuffleIndex::default(),
            used_lookups: HashSet::new(),
@@ -185,9 +208,22 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
            max_lookup_inputs: 0,
            min_lookup_inputs: 0,
            max_range_size: 0,
            throw_range_check_error: false,
            witness_gen: true,
            assigned_constants: HashMap::new(),
        }
    }

    /// Create a new region context
    pub fn new_with_constants(
        region: Region<'a, F>,
        row: usize,
        num_inner_cols: usize,
        constants: ConstantsMap<F>,
    ) -> RegionCtx<'a, F> {
        let mut new_self = Self::new(region, row, num_inner_cols);
        new_self.assigned_constants = constants;
        new_self
    }
    /// Create a new region context from a wrapped region
    pub fn from_wrapped_region(
        region: Option<RefCell<Region<'a, F>>>,
@@ -202,7 +238,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
            num_inner_cols,
            linear_coord,
            row,
            total_constants: 0,
            dynamic_lookup_index,
            shuffle_index,
            used_lookups: HashSet::new(),
@@ -210,16 +245,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
            max_lookup_inputs: 0,
            min_lookup_inputs: 0,
            max_range_size: 0,
            throw_range_check_error: false,
            witness_gen: false,
            assigned_constants: HashMap::new(),
        }
    }

    /// Create a new region context
    pub fn new_dummy(
        row: usize,
        num_inner_cols: usize,
        throw_range_check_error: bool,
    ) -> RegionCtx<'a, F> {
    pub fn new_dummy(row: usize, num_inner_cols: usize, witness_gen: bool) -> RegionCtx<'a, F> {
        let region = None;
        let linear_coord = row * num_inner_cols;

@@ -228,7 +260,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
            num_inner_cols,
            linear_coord,
            row,
            total_constants: 0,
            dynamic_lookup_index: DynamicLookupIndex::default(),
            shuffle_index: ShuffleIndex::default(),
            used_lookups: HashSet::new(),
@@ -236,17 +267,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
            max_lookup_inputs: 0,
            min_lookup_inputs: 0,
            max_range_size: 0,
            throw_range_check_error,
            witness_gen,
            assigned_constants: HashMap::new(),
        }
    }

    /// Create a new region context
    pub fn new_dummy_with_constants(
    pub fn new_dummy_with_linear_coord(
        row: usize,
        linear_coord: usize,
        total_constants: usize,
        num_inner_cols: usize,
        throw_range_check_error: bool,
        witness_gen: bool,
    ) -> RegionCtx<'a, F> {
        let region = None;
        RegionCtx {
@@ -254,7 +285,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
            num_inner_cols,
            linear_coord,
            row,
            total_constants,
            dynamic_lookup_index: DynamicLookupIndex::default(),
            shuffle_index: ShuffleIndex::default(),
            used_lookups: HashSet::new(),
@@ -262,7 +292,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
            max_lookup_inputs: 0,
            min_lookup_inputs: 0,
            max_range_size: 0,
            throw_range_check_error,
            witness_gen,
            assigned_constants: HashMap::new(),
        }
    }

@@ -312,29 +343,27 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
    ) -> Result<(), RegionError> {
        let row = AtomicUsize::new(self.row());
        let linear_coord = AtomicUsize::new(self.linear_coord());
        let constants = AtomicUsize::new(self.total_constants());
        let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
        let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
        let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
        let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
        let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
        let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
        let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));

        *output = output
            .par_enum_map(|idx, _| {
                // we kick off the loop with the current offset
                let starting_offset = row.load(Ordering::SeqCst);
                let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
                let starting_constants = constants.load(Ordering::SeqCst);
                // get inner value of the locked lookups

                // we need to make sure that the region is not shared between threads
                let mut local_reg = Self::new_dummy_with_constants(
                let mut local_reg = Self::new_dummy_with_linear_coord(
                    starting_offset,
                    starting_linear_coord,
                    starting_constants,
                    self.num_inner_cols,
                    self.throw_range_check_error,
                    self.witness_gen,
                );
                let res = inner_loop_function(idx, &mut local_reg);
                // we update the offset and constants
@@ -343,10 +372,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
                    local_reg.linear_coord() - starting_linear_coord,
                    Ordering::SeqCst,
                );
                constants.fetch_add(
                    local_reg.total_constants() - starting_constants,
                    Ordering::SeqCst,
                );

                max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
                min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
@@ -362,11 +387,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
                // update the shuffle index
                let mut shuffle_index = shuffle_index.lock().unwrap();
                shuffle_index.update(&local_reg.shuffle_index);
                // update the constants
                let mut constants = constants.lock().unwrap();
                constants.extend(local_reg.assigned_constants);

                res
            })
            .map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
        self.total_constants = constants.into_inner();
        self.linear_coord = linear_coord.into_inner();
        #[allow(trivial_numeric_casts)]
        {
@@ -410,6 +437,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
            .map_err(|e| {
                RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
            })?;
        self.assigned_constants = Arc::try_unwrap(constants)
            .map_err(|e| {
                RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
            })?
            .into_inner()
            .map_err(|e| {
                RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
            })?;

        Ok(())
    }
@@ -435,7 +470,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
        range: Range,
    ) -> Result<(), Box<dyn std::error::Error>> {
        if range.0 > range.1 {
            return Err("update_max_min_lookup_range: invalid range".into());
            return Err(format!("update_max_min_lookup_range: invalid range {:?}", range).into());
        }

        let range_size = (range.1 - range.0).abs();
@@ -477,7 +512,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {

    /// Get the total number of constants
    pub fn total_constants(&self) -> usize {
        self.total_constants
        self.assigned_constants.len()
    }

    /// Get the dynamic lookup index
@@ -525,26 +560,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
        self.max_range_size
    }

    /// Assign a constant value
    pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
        self.total_constants += 1;
        if let Some(region) = &self.region {
            let cell = var.assign_constant(&mut region.borrow_mut(), self.linear_coord, value)?;
            Ok(cell.into())
        } else {
            Ok(value.into())
        }
    }
    /// Assign a valtensor to a vartensor
    pub fn assign(
        &mut self,
        var: &VarTensor,
        values: &ValTensor<F>,
    ) -> Result<ValTensor<F>, Error> {
        self.total_constants += values.num_constants();
        if let Some(region) = &self.region {
            var.assign(&mut region.borrow_mut(), self.linear_coord, values)
            var.assign(
                &mut region.borrow_mut(),
                self.linear_coord,
                values,
                &mut self.assigned_constants,
            )
        } else {
            if !values.is_instance() {
                let values_map = values.create_constants_map_iterator();
                self.assigned_constants.extend(values_map);
            }
            Ok(values.clone())
        }
    }
@@ -560,14 +593,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
        var: &VarTensor,
        values: &ValTensor<F>,
    ) -> Result<ValTensor<F>, Error> {
        self.total_constants += values.num_constants();
        if let Some(region) = &self.region {
            var.assign(
                &mut region.borrow_mut(),
                self.combined_dynamic_shuffle_coord(),
                values,
                &mut self.assigned_constants,
            )
        } else {
            if !values.is_instance() {
                let values_map = values.create_constants_map_iterator();
                self.assigned_constants.extend(values_map);
            }
            Ok(values.clone())
        }
    }
@@ -594,13 +631,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
                self.linear_coord,
                values,
                ommissions,
                &mut self.assigned_constants,
            )
        } else {
            self.total_constants += values.num_constants();
            let inner_tensor = values.get_inner_tensor().unwrap();
            let mut values_map = values.create_constants_map();

            for o in ommissions {
                self.total_constants -= inner_tensor.get_flat_index(**o).is_constant() as usize;
                if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
                    values_map.remove(&value);
                }
            }

            self.assigned_constants.extend(values_map);

            Ok(values.clone())
        }
    }
@@ -615,24 +659,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
    ) -> Result<(ValTensor<F>, usize), Error> {
        if let Some(region) = &self.region {
            // duplicates every nth element to adjust for column overflow
            let (res, len, total_assigned_constants) = var.assign_with_duplication(
            let (res, len) = var.assign_with_duplication(
                &mut region.borrow_mut(),
                self.row,
                self.linear_coord,
                values,
                check_mode,
                single_inner_col,
                &mut self.assigned_constants,
            )?;
            self.total_constants += total_assigned_constants;
            Ok((res, len))
        } else {
            let (_, len, total_assigned_constants) = var.dummy_assign_with_duplication(
            let (_, len) = var.dummy_assign_with_duplication(
                self.row,
                self.linear_coord,
                values,
                single_inner_col,
                &mut self.assigned_constants,
            )?;
            self.total_constants += total_assigned_constants;
            Ok((values.clone(), len))
        }
    }
@@ -699,9 +743,4 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
        }
        Ok(())
    }

    /// increment constants
    pub fn increment_constants(&mut self, n: usize) {
        self.total_constants += n
    }
}

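Note how dropping the separate `total_constants` counter simplifies the parallel dummy loop above: each worker fills its own `ConstantsMap`, the maps are merged under a mutex, and the total is just the merged map's length. A toy sketch of the invariant, with `u64` standing in for the field element:

use std::collections::HashMap;

fn merge_worker_constants(
    global: &mut HashMap<u64, usize>,
    workers: Vec<HashMap<u64, usize>>,
) -> usize {
    for w in workers {
        // duplicate keys collapse onto one entry, so no double counting
        global.extend(w);
    }
    // equivalent to the new total_constants(): assigned_constants.len()
    global.len()
}
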
@@ -17,8 +17,6 @@ use crate::{

use crate::circuit::lookup::LookupOp;

use super::Op;

/// The range of the lookup table.
pub type Range = (i128, i128);

@@ -98,7 +96,7 @@ pub struct Table<F: PrimeField> {
    _marker: PhantomData<F>,
}

impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
    /// get column index given input
    pub fn get_col_index(&self, input: F) -> F {
        // range is split up into chunks of size col_size, find the chunk that input is in
@@ -113,11 +111,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
        let chunk = chunk as i128;
        // we index from 1 to prevent soundness issues
        let first_element = i128_to_felt(chunk * (self.col_size as i128) + self.range.0);
        let op_f = Op::<F>::f(
            &self.nonlinearity,
            &[Tensor::from(vec![first_element].into_iter())],
        )
        .unwrap();
        let op_f = self
            .nonlinearity
            .f(&[Tensor::from(vec![first_element].into_iter())])
            .unwrap();
        (first_element, op_f.output[0])
    }

@@ -138,7 +135,7 @@ pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
    (range_len / (col_size as i128)) as usize + 1
}

impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
    /// Configures the table.
    pub fn configure(
        cs: &mut ConstraintSystem<F>,
@@ -152,7 +149,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
        // number of cols needed to store the range
        let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);

        log::debug!("table range: {:?}", range);
        debug!("table range: {:?}", range);

        let table_inputs = preexisting_inputs.unwrap_or_else(|| {
            let mut cols = vec![];
@@ -165,7 +162,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
        let num_cols = table_inputs.len();

        if num_cols > 1 {
            debug!("Using {} columns for non-linearity table.", num_cols);
            warn!("Using {} columns for non-linearity table.", num_cols);
        }

        let table_outputs = table_inputs
@@ -205,8 +202,8 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
        let smallest = self.range.0;
        let largest = self.range.1;

        let inputs = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
        let evals = Op::<F>::f(&self.nonlinearity, &[inputs.clone()])?;
        let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
        let evals = self.nonlinearity.f(&[inputs.clone()])?;
        let chunked_inputs = inputs.chunks(self.col_size);

        self.is_assigned = true;
@@ -275,7 +272,7 @@ pub struct RangeCheck<F: PrimeField> {
    _marker: PhantomData<F>,
}

impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
    /// get first_element of column
    pub fn get_first_element(&self, chunk: usize) -> F {
        let chunk = chunk as i128;
@@ -303,7 +300,7 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
    }
}

impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
    /// Configures the table.
    pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
        log::debug!("range check range: {:?}", range);

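For intuition on the column arithmetic in this file: `num_cols_required` rounds up, so a lookup range even one element wider than a column triggers a second column (and the `warn!` above). A worked example using the function verbatim:

fn num_cols_required(range_len: i128, col_size: usize) -> usize {
    (range_len / (col_size as i128)) as usize + 1
}

fn main() {
    // a lookup over [-2^16, 2^16 - 1] with 2^16-row columns spans 2 columns
    let range = (-65536i128, 65535i128);
    let col_size = 65536usize;
    assert_eq!(num_cols_required((range.1 - range.0).abs(), col_size), 2);
}
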
@@ -1048,8 +1048,8 @@ mod conv {
                        &mut region,
                        &self.inputs,
                        Box::new(PolyOp::Conv {
                            padding: [(1, 1); 2],
                            stride: (2, 2),
                            padding: vec![(1, 1); 2],
                            stride: vec![2; 2],
                        }),
                    )
                    .map_err(|_| Error::Synthesis)
@@ -1198,8 +1198,8 @@ mod conv_col_ultra_overflow {
                        &mut region,
                        &[self.image.clone(), self.kernel.clone()],
                        Box::new(PolyOp::Conv {
                            padding: [(1, 1); 2],
                            stride: (2, 2),
                            padding: vec![(1, 1); 2],
                            stride: vec![2; 2],
                        }),
                    )
                    .map_err(|_| Error::Synthesis)
@@ -1343,8 +1343,8 @@ mod conv_relu_col_ultra_overflow {
                        &mut region,
                        &[self.image.clone(), self.kernel.clone()],
                        Box::new(PolyOp::Conv {
                            padding: [(1, 1); 2],
                            stride: (2, 2),
                            padding: vec![(1, 1); 2],
                            stride: vec![2; 2],
                        }),
                    )
                    .map_err(|_| Error::Synthesis);
@@ -1911,6 +1911,8 @@ mod add_with_overflow {

#[cfg(test)]
mod add_with_overflow_and_poseidon {
    use std::collections::HashMap;

    use halo2curves::bn256::Fr;

    use crate::circuit::modules::{
@@ -1969,8 +1971,10 @@ mod add_with_overflow_and_poseidon {
        let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, WIDTH> =
            PoseidonChip::new(config.poseidon.clone());

        let assigned_inputs_a = poseidon_chip.layout(&mut layouter, &self.inputs[0..1], 0)?;
        let assigned_inputs_b = poseidon_chip.layout(&mut layouter, &self.inputs[1..2], 1)?;
        let assigned_inputs_a =
            poseidon_chip.layout(&mut layouter, &self.inputs[0..1], 0, &mut HashMap::new())?;
        let assigned_inputs_b =
            poseidon_chip.layout(&mut layouter, &self.inputs[1..2], 1, &mut HashMap::new())?;

        layouter.assign_region(|| "_new_module", |_| Ok(()))?;

@@ -444,7 +444,7 @@ pub enum Commands {
        disable_selector_compression: bool,
        /// commitment used
        #[arg(long, default_value = DEFAULT_COMMITMENT)]
        commitment: Commitments,
        commitment: Option<Commitments>,
    },
    /// Aggregates proofs :)
    Aggregate {
@@ -479,7 +479,7 @@ pub enum Commands {
        split_proofs: bool,
        /// commitment used
        #[arg(long, default_value = DEFAULT_COMMITMENT)]
        commitment: Commitments,
        commitment: Option<Commitments>,
    },
    /// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements
    CompileCircuit {
@@ -726,7 +726,7 @@ pub enum Commands {
        logrows: u32,
        /// commitment
        #[arg(long, default_value = DEFAULT_COMMITMENT)]
        commitment: Commitments,
        commitment: Option<Commitments>,
    },
    #[cfg(not(target_arch = "wasm32"))]
    /// Deploys an evm verifier that is generated by ezkl

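The CLI flags above become `Option<Commitments>`, and the call sites in src/execute.rs below convert with `.into()`. An assumption-labeled sketch of a conversion that would make an omitted flag fall back to a default scheme (the actual From impl is not shown in this diff):

#[derive(Clone, Copy, Debug, Default)]
enum Commitments {
    #[default]
    KZG,
    IPA,
}

impl From<Option<Commitments>> for Commitments {
    fn from(c: Option<Commitments>) -> Self {
        // hypothetical: None (flag not passed) falls back to the default
        c.unwrap_or_default()
    }
}
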
src/execute.rs
@@ -24,6 +24,8 @@ use crate::pfsys::{
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::{Commitments, RunArgs};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
@@ -194,7 +196,6 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
            vk_path,
            srs_path,
        } => gen_witness(compiled_circuit, data, Some(output), vk_path, srs_path)
            .await
            .map(|e| serde_json::to_string(&e).unwrap()),
        Commands::Mock { model, witness } => mock(model, witness),
        #[cfg(not(target_arch = "wasm32"))]
@@ -337,7 +338,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
            logrows,
            split_proofs,
            disable_selector_compression,
            commitment,
            commitment.into(),
        ),
        Commands::Aggregate {
            proof_path,
@@ -358,7 +359,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
            logrows,
            check_mode,
            split_proofs,
            commitment,
            commitment.into(),
        )
        .map(|e| serde_json::to_string(&e).unwrap()),
        Commands::Verify {
@@ -382,7 +383,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
            srs_path,
            logrows,
            reduced_srs,
            commitment,
            commitment.into(),
        )
        .map(|e| serde_json::to_string(&e).unwrap()),
        #[cfg(not(target_arch = "wasm32"))]
@@ -538,7 +539,7 @@ fn check_srs_hash(
    let path = get_srs_path(logrows, srs_path, commitment);
    let hash = get_file_hash(&path)?;

    let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
    let predefined_hash = match crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) {
        Some(h) => h,
        None => return Err(format!("SRS (k={}) hash not found in public set", logrows).into()),
    };
@@ -584,7 +585,7 @@ pub(crate) async fn get_srs_cmd(
    } else if let Some(settings_p) = settings_path {
        if settings_p.exists() {
            let settings = GraphSettings::load(&settings_p)?;
            settings.run_args.commitment
            settings.run_args.commitment.into()
        } else {
            return Err(err_string.into());
        }
@@ -635,7 +636,7 @@ pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, Box<dyn
    Ok(String::new())
}

pub(crate) async fn gen_witness(
pub(crate) fn gen_witness(
    compiled_circuit_path: PathBuf,
    data: PathBuf,
    output: Option<PathBuf>,
@@ -658,33 +659,29 @@ pub(crate) async fn gen_witness(
    };

    #[cfg(not(target_arch = "wasm32"))]
    let mut input = circuit.load_graph_input(&data).await?;
    let mut input = circuit.load_graph_input(&data)?;
    #[cfg(target_arch = "wasm32")]
    let mut input = circuit.load_graph_input(&data)?;

    // if any of the settings have kzg visibility then we need to load the srs

    let commitment: Commitments = settings.run_args.commitment.into();

    let start_time = Instant::now();
    let witness = if settings.module_requires_polycommit() {
        if get_srs_path(
            settings.run_args.logrows,
            srs_path.clone(),
            settings.run_args.commitment,
        )
        .exists()
        {
            match settings.run_args.commitment {
        if get_srs_path(settings.run_args.logrows, srs_path.clone(), commitment).exists() {
            match Commitments::from(settings.run_args.commitment) {
                Commitments::KZG => {
                    let srs: ParamsKZG<Bn256> = load_params_prover::<KZGCommitmentScheme<Bn256>>(
                        srs_path.clone(),
                        settings.run_args.logrows,
                        settings.run_args.commitment,
                        commitment,
                    )?;
                    circuit.forward::<KZGCommitmentScheme<_>>(
                        &mut input,
                        vk.as_ref(),
                        Some(&srs),
                        false,
                        true,
                    )?
                }
                Commitments::IPA => {
@@ -692,22 +689,22 @@ pub(crate) async fn gen_witness(
                    load_params_prover::<IPACommitmentScheme<G1Affine>>(
                        srs_path.clone(),
                        settings.run_args.logrows,
                        settings.run_args.commitment,
                        commitment,
                    )?;
                    circuit.forward::<IPACommitmentScheme<_>>(
                        &mut input,
                        vk.as_ref(),
                        Some(&srs),
                        false,
                        true,
                    )?
                }
            }
        } else {
            warn!("SRS for poly commit does not exist (will be ignored)");
            circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, false)?
            circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true)?
        }
    } else {
        circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, false)?
        circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true)?
    };

    // print each variable tuple (symbol, value) as symbol=value
@@ -819,7 +816,15 @@ impl AccuracyResults {
        let error = (original.clone() - calibrated.clone())?;
        let abs_error = error.map(|x| x.abs());
        let squared_error = error.map(|x| x.powi(2));
        let percentage_error = error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i]))?;
        let percentage_error = error.enum_map(|i, x| {
            // if everything is 0 then we can't divide by 0 so we just return 0
            let res = if original[i] == 0.0 && x == 0.0 {
                0.0
            } else {
                x / original[i]
            };
            Ok::<f32, TensorError>(res)
        })?;
        let abs_percentage_error = percentage_error.map(|x| x.abs());

        errors.extend(error);

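The percentage-error guard above is small but important for calibration metrics: when both the reference value and the error are zero, the relative error is now defined as zero rather than the NaN produced by 0/0. The same logic as a standalone function:

fn percentage_error(original: f32, error: f32) -> f32 {
    if original == 0.0 && error == 0.0 {
        // a zero error on a zero reference is exact, not undefined
        0.0
    } else {
        error / original
    }
}
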
@@ -888,6 +893,7 @@ pub(crate) fn calibrate(
    only_range_check_rebase: bool,
    max_logrows: Option<u32>,
) -> Result<GraphSettings, Box<dyn Error>> {
    use log::error;
    use std::collections::HashMap;
    use tabled::Table;

@@ -900,9 +906,9 @@ pub(crate) fn calibrate(
    let model = Model::from_run_args(&settings.run_args, &model_path)?;

    let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
    info!("num of calibration batches: {}", chunks.len());
    info!("num calibration batches: {}", chunks.len());

    info!("running onnx predictions...");
    debug!("running onnx predictions...");
    let original_predictions = Model::run_onnx_predictions(
        &settings.run_args,
        &model_path,
@@ -970,10 +976,18 @@ pub(crate) fn calibrate(
    let pb = init_bar(range_grid.len() as u64);
    pb.set_message("calibrating...");

    let mut num_failed = 0;
    let mut num_passed = 0;

    for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
        pb.set_message(format!(
            "input scale: {}, param scale: {}, scale rebase multiplier: {}, div rebasing: {}",
            input_scale, param_scale, scale_rebase_multiplier, div_rebasing
            "i-scale: {}, p-scale: {}, rebase-(x): {}, div-rebase: {}, fail: {}, pass: {}",
            input_scale.to_string().blue(),
            param_scale.to_string().blue(),
            scale_rebase_multiplier.to_string().blue(),
            div_rebasing.to_string().yellow(),
            num_failed.to_string().red(),
            num_passed.to_string().green()
        ));

        let key = (
@@ -1007,7 +1021,9 @@ pub(crate) fn calibrate(
        let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
            Ok(c) => c,
            Err(e) => {
                debug!("circuit creation from run args failed: {:?}", e);
                error!("circuit creation from run args failed: {:?}", e);
                pb.inc(1);
                num_failed += 1;
                continue;
            }
        };
@@ -1039,7 +1055,9 @@ pub(crate) fn calibrate(
            Ok(_) => (),
            // typically errors will be due to the circuit overflowing the i128 limit
            Err(e) => {
                debug!("forward pass failed: {:?}", e);
                error!("forward pass failed: {:?}", e);
                pb.inc(1);
                num_failed += 1;
                continue;
            }
        }
@@ -1104,8 +1122,10 @@ pub(crate) fn calibrate(
                "found settings: \n {}",
                found_settings.as_json()?.to_colored_json_auto()?
            );
            num_passed += 1;
        } else {
            debug!("calibration failed {}", res.err().unwrap());
            error!("calibration failed {}", res.err().unwrap());
            num_failed += 1;
        }

        pb.inc(1);
@@ -1208,22 +1228,14 @@ pub(crate) fn calibrate(
    );

    if matches!(target, CalibrationTarget::Resources { col_overflow: true }) {
        let lookup_log_rows = ((best_params.run_args.lookup_range.1
            - best_params.run_args.lookup_range.0) as f32)
            .log2()
            .ceil() as u32
            + 1;
        let mut reduction = std::cmp::max(
            (best_params
                .model_instance_shapes
                .iter()
                .map(|x| x.iter().product::<usize>())
                .sum::<usize>() as f32)
                .log2()
                .ceil() as u32
                + 1,
            lookup_log_rows,
        );
        let lookup_log_rows = best_params.lookup_log_rows_with_blinding();
        let module_log_row = best_params.module_constraint_logrows_with_blinding();
        let instance_logrows = best_params.log2_total_instances_with_blinding();
        let dynamic_lookup_logrows = best_params.dynamic_lookup_and_shuffle_logrows_with_blinding();

        let mut reduction = std::cmp::max(lookup_log_rows, module_log_row);
        reduction = std::cmp::max(reduction, instance_logrows);
        reduction = std::cmp::max(reduction, dynamic_lookup_logrows);
        reduction = std::cmp::max(reduction, crate::graph::MIN_LOGROWS);

        info!(

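The rewritten `Resources { col_overflow: true }` branch replaces two hand-inlined log2 computations with the new GraphSettings helpers (defined in src/graph/mod.rs below) and widens the floor to cover every sub-argument. The reduction is simply the maximum of the blinded row counts:

// mirrors the chain of std::cmp::max calls above
fn logrows_floor(lookup: u32, module: u32, instance: u32, dynamic_shuffle: u32, min_logrows: u32) -> u32 {
    [lookup, module, instance, dynamic_shuffle, min_logrows]
        .into_iter()
        .max()
        .unwrap()
}
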
@@ -1278,17 +1290,19 @@ pub(crate) fn create_evm_verifier(
    render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
    check_solc_requirement();
    let circuit_settings = GraphSettings::load(&settings_path)?;

    let settings = GraphSettings::load(&settings_path)?;
    let commitment: Commitments = settings.run_args.commitment.into();
    let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
        srs_path,
        circuit_settings.run_args.logrows,
        circuit_settings.run_args.commitment,
        settings.run_args.logrows,
        commitment,
    )?;

    let num_instance = circuit_settings.total_instances();
    let num_instance = settings.total_instances();
    let num_instance: usize = num_instance.iter().sum::<usize>();

    let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, circuit_settings)?;
    let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
    trace!("params computed");

    let generator = halo2_solidity_verifier::SolidityGenerator::new(
@@ -1322,17 +1336,18 @@ pub(crate) fn create_evm_vk(
    abi_path: PathBuf,
) -> Result<String, Box<dyn Error>> {
    check_solc_requirement();
    let circuit_settings = GraphSettings::load(&settings_path)?;
    let settings = GraphSettings::load(&settings_path)?;
    let commitment: Commitments = settings.run_args.commitment.into();
    let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
        srs_path,
        circuit_settings.run_args.logrows,
        circuit_settings.run_args.commitment,
        settings.run_args.logrows,
        commitment,
    )?;

    let num_instance = circuit_settings.total_instances();
    let num_instance = settings.total_instances();
    let num_instance: usize = num_instance.iter().sum::<usize>();

    let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, circuit_settings)?;
    let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
    trace!("params computed");

    let generator = halo2_solidity_verifier::SolidityGenerator::new(
@@ -1601,8 +1616,9 @@ pub(crate) fn setup(
    }

    let logrows = circuit.settings().run_args.logrows;
    let commitment: Commitments = circuit.settings().run_args.commitment.into();

    let pk = match circuit.settings().run_args.commitment {
    let pk = match commitment {
        Commitments::KZG => {
            let params = load_params_prover::<KZGCommitmentScheme<Bn256>>(
                srs_path,
@@ -1711,7 +1727,8 @@ pub(crate) fn prove(
    let transcript: TranscriptType = proof_type.into();
    let proof_split_commits: Option<ProofSplitCommit> = data.into();

    let commitment = circuit_settings.run_args.commitment;
    let commitment = circuit_settings.run_args.commitment.into();
    let logrows = circuit_settings.run_args.logrows;
    // creates and verifies the proof
    let mut snark = match commitment {
        Commitments::KZG => {
@@ -1720,7 +1737,7 @@ pub(crate) fn prove(

            let params = load_params_prover::<KZGCommitmentScheme<Bn256>>(
                srs_path,
                circuit_settings.run_args.logrows,
                logrows,
                Commitments::KZG,
            )?;
            match strategy {
@@ -1879,7 +1896,9 @@ pub(crate) fn mock_aggregate(
        }
        Err(_) => {
            return Err(
                format!("invalid sample commitment type for aggregation, must be KZG").into(),
                "invalid sample commitment type for aggregation, must be KZG"
                    .to_string()
                    .into(),
            );
        }
    }
@@ -1922,7 +1941,9 @@ pub(crate) fn setup_aggregate(
        }
        Err(_) => {
            return Err(
                format!("invalid sample commitment type for aggregation, must be KZG",).into(),
                "invalid sample commitment type for aggregation, must be KZG"
                    .to_string()
                    .into(),
            );
        }
    }
@@ -1983,7 +2004,9 @@ pub(crate) fn aggregate(
        }
        Err(_) => {
            return Err(
                format!("invalid sample commitment type for aggregation, must be KZG").into(),
                "invalid sample commitment type for aggregation, must be KZG"
                    .to_string()
                    .into(),
            );
        }
    }
@@ -2156,8 +2179,9 @@ pub(crate) fn verify(
    let circuit_settings = GraphSettings::load(&settings_path)?;

    let logrows = circuit_settings.run_args.logrows;
    let commitment = circuit_settings.run_args.commitment.into();

    match circuit_settings.run_args.commitment {
    match commitment {
        Commitments::KZG => {
            let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
            let params: ParamsKZG<Bn256> = if reduced_srs {

@@ -21,8 +21,6 @@ use std::io::BufWriter;
use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(not(target_arch = "wasm32"))]
use std::thread;
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_core::{
    tract_data::{prelude::Tensor as TractTensor, TVec},
    value::TValue,
@@ -234,21 +232,15 @@ impl PostgresSource {
            )
        };

        let res: Vec<pg_bigdecimal::PgNumeric> = thread::spawn(move || {
            let mut client = Client::connect(&config, NoTls).unwrap();
            let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
            // extract rows from query
            for row in client.query(&query, &[]).unwrap() {
                // extract features from row
                for i in 0..row.len() {
                    res.push(row.get(i));
                }
        let mut client = Client::connect(&config, NoTls)?;
        let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
        // extract rows from query
        for row in client.query(&query, &[])? {
            // extract features from row
            for i in 0..row.len() {
                res.push(row.get(i));
            }
            res
        })
        .join()
        .map_err(|_| "failed to fetch data from postgres")?;

        }
        Ok(vec![res])
    }

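The PostgresSource change above swaps a spawned thread full of `.unwrap()` calls for plain `?` propagation, so connection and query failures surface as errors instead of panics. The general pattern, on a toy parser:

fn parse_all(rows: &[&str]) -> Result<Vec<i64>, std::num::ParseIntError> {
    let mut res = Vec::new();
    for row in rows {
        // propagate instead of unwrap(): the caller decides how to recover
        res.push(row.parse::<i64>()?);
    }
    Ok(res)
}
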
src/graph/mod.rs
@@ -26,6 +26,7 @@ use self::input::{FileSource, GraphData};
|
||||
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::modules::ModulePlanner;
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
|
||||
use crate::circuit::{CheckMode, InputType};
|
||||
use crate::fieldutils::felt_to_f64;
|
||||
@@ -38,7 +39,7 @@ use halo2_proofs::{
|
||||
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
|
||||
};
|
||||
use halo2curves::bn256::{self, Fr as Fp, G1Affine};
|
||||
use halo2curves::ff::PrimeField;
|
||||
use halo2curves::ff::{Field, PrimeField};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use lazy_static::lazy_static;
|
||||
use log::{debug, error, trace, warn};
|
||||
@@ -155,7 +156,7 @@ use std::cell::RefCell;
|
||||
thread_local!(
|
||||
/// This is a global variable that holds the settings for the graph
|
||||
/// This is used to pass settings to the layouter and other parts of the circuit without needing to heavily modify the Halo2 API in a new fork
|
||||
pub static GLOBAL_SETTINGS: RefCell<Option<GraphSettings>> = RefCell::new(None)
|
||||
pub static GLOBAL_SETTINGS: RefCell<Option<GraphSettings>> = const { RefCell::new(None) }
|
||||
);
|
||||
|
||||
/// Result from a forward pass
|
||||
@@ -482,7 +483,22 @@ pub struct GraphSettings {
|
||||
}
|
||||
|
||||
impl GraphSettings {
|
||||
fn model_constraint_logrows(&self) -> u32 {
|
||||
/// Calc the number of rows required for lookup tables
|
||||
pub fn lookup_log_rows(&self) -> u32 {
|
||||
((self.run_args.lookup_range.1 - self.run_args.lookup_range.0) as f32)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
/// Calc the number of rows required for lookup tables
|
||||
pub fn lookup_log_rows_with_blinding(&self) -> u32 {
|
||||
((self.run_args.lookup_range.1 - self.run_args.lookup_range.0) as f32
|
||||
+ RESERVED_BLINDING_ROWS as f32)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
fn model_constraint_logrows_with_blinding(&self) -> u32 {
|
||||
(self.num_rows as f64 + RESERVED_BLINDING_ROWS as f64)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
@@ -494,16 +510,35 @@ impl GraphSettings {
             .ceil() as u32
     }

+    /// calculate the number of rows required for the dynamic lookup and shuffle
+    pub fn dynamic_lookup_and_shuffle_logrows_with_blinding(&self) -> u32 {
+        (self.total_dynamic_col_size as f64
+            + self.total_shuffle_col_size as f64
+            + RESERVED_BLINDING_ROWS as f64)
+            .log2()
+            .ceil() as u32
+    }
+
     fn dynamic_lookup_and_shuffle_col_size(&self) -> usize {
         self.total_dynamic_col_size + self.total_shuffle_col_size
     }

-    fn module_constraint_logrows(&self) -> u32 {
+    /// calculate the number of rows required for the module constraints
+    pub fn module_constraint_logrows(&self) -> u32 {
         (self.module_sizes.max_constraints() as f64).log2().ceil() as u32
     }

+    /// calculate the number of rows required for the module constraints, including the reserved blinding rows
+    pub fn module_constraint_logrows_with_blinding(&self) -> u32 {
+        (self.module_sizes.max_constraints() as f64 + RESERVED_BLINDING_ROWS as f64)
+            .log2()
+            .ceil() as u32
+    }
+
     fn constants_logrows(&self) -> u32 {
-        (self.total_const_size as f64).log2().ceil() as u32
+        (self.total_const_size as f64 / self.run_args.num_inner_cols as f64)
+            .log2()
+            .ceil() as u32
     }

     /// calculate the total number of instances
@@ -526,6 +561,14 @@ impl GraphSettings {
         std::cmp::max((sum as f64).log2().ceil() as u32, 1)
     }

+    /// calculate the log2 of the total number of instances, including the reserved blinding rows
+    pub fn log2_total_instances_with_blinding(&self) -> u32 {
+        let sum = self.total_instances().iter().sum::<usize>() + RESERVED_BLINDING_ROWS;
+
+        // max between 1 and the log2 of the sums
+        std::cmp::max((sum as f64).log2().ceil() as u32, 1)
+    }
+
     /// save params to file
     pub fn save(&self, path: &std::path::PathBuf) -> Result<(), std::io::Error> {
         // buf writer
@@ -915,7 +958,7 @@ impl GraphCircuit {

     ///
     #[cfg(not(target_arch = "wasm32"))]
-    pub async fn load_graph_input(
+    pub fn load_graph_input(
         &mut self,
         data: &GraphData,
     ) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {

@@ -925,7 +968,6 @@ impl GraphCircuit {
         debug!("input scales: {:?}", scales);

         self.process_data_source(&data.input_data, shapes, scales, input_types)
-            .await
     }

     #[cfg(target_arch = "wasm32")]

@@ -949,7 +991,7 @@ impl GraphCircuit {

     #[cfg(not(target_arch = "wasm32"))]
     /// Process the data source for the model
-    async fn process_data_source(
+    fn process_data_source(
         &mut self,
         data: &DataSource,
         shapes: Vec<Vec<usize>>,

@@ -962,8 +1004,16 @@ impl GraphCircuit {
                 for (i, shape) in shapes.iter().enumerate() {
                     per_item_scale.extend(vec![scales[i]; shape.iter().product::<usize>()]);
                 }
-                self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
-                    .await
+                // start runtime and fetch data
+                let runtime = tokio::runtime::Builder::new_current_thread()
+                    .enable_all()
+                    .build()?;
+
+                runtime.block_on(async {
+                    self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
+                        .await
+                })
             }
             DataSource::File(file_data) => {
                 self.load_file_data(file_data, &shapes, scales, input_types)
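Since load_graph_input and process_data_source are no longer async, the on-chain branch builds a one-off current-thread tokio runtime and blocks on the fetch. A self-contained sketch of that pattern (fetch_remote is a stand-in for load_on_chain_data; assumes the tokio crate with its runtime features enabled):

async fn fetch_remote() -> Result<Vec<u8>, Box<dyn std::error::Error>> {
    Ok(vec![1, 2, 3]) // stand-in for an RPC call
}

// synchronous wrapper: build a single-threaded runtime and block on the future
fn load_blocking() -> Result<Vec<u8>, Box<dyn std::error::Error>> {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()?;
    runtime.block_on(async { fetch_remote().await })
}

fn main() {
    assert_eq!(load_blocking().unwrap(), vec![1, 2, 3]);
}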
@@ -1049,16 +1099,10 @@ impl GraphCircuit {
     }

     fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
-        let mut margin = (
+        (
             lookup_safety_margin * min_max_lookup.0,
             lookup_safety_margin * min_max_lookup.1,
-        );
-        if lookup_safety_margin == 1 {
-            margin.0 += 4;
-            margin.1 += 4;
-        }
-
-        margin
+        )
     }

     fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
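The rewritten calc_safe_lookup_range drops the special `+ 4` widening that was applied when the safety margin was exactly 1; the safe range is now always just the observed min/max scaled by the margin. A small check of the new behavior, mirroring the function above on a bare tuple type:

type Range = (i128, i128);

fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
    (
        lookup_safety_margin * min_max_lookup.0,
        lookup_safety_margin * min_max_lookup.1,
    )
}

fn main() {
    // a margin of 2 doubles both ends of the observed range
    assert_eq!(calc_safe_lookup_range((-100, 150), 2), (-200, 300));
    // with margin 1 the old code returned (-96, 154); now it is the identity
    assert_eq!(calc_safe_lookup_range((-100, 150), 1), (-100, 150));
}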
@@ -1129,7 +1173,7 @@ impl GraphCircuit {
         );

         // These are upper limits, going above these is wasteful, but they are not hard limits
-        let model_constraint_logrows = self.settings().model_constraint_logrows();
+        let model_constraint_logrows = self.settings().model_constraint_logrows_with_blinding();
         let min_bits = self.table_size_logrows(safe_lookup_range, max_range_size)?;
         let constants_logrows = self.settings().constants_logrows();
         max_logrows = std::cmp::min(

@@ -1242,7 +1286,7 @@ impl GraphCircuit {
         inputs: &mut [Tensor<Fp>],
         vk: Option<&VerifyingKey<G1Affine>>,
         srs: Option<&Scheme::ParamsProver>,
-        throw_range_check_error: bool,
+        witness_gen: bool,
     ) -> Result<GraphWitness, Box<dyn std::error::Error>> {
         let original_inputs = inputs.to_vec();

@@ -1291,7 +1335,7 @@ impl GraphCircuit {

         let mut model_results =
             self.model()
-                .forward(inputs, &self.settings().run_args, throw_range_check_error)?;
+                .forward(inputs, &self.settings().run_args, witness_gen)?;

         if visibility.output.requires_processing() {
             let module_outlets = visibility.output.overwrites_inputs();
@@ -1454,7 +1498,8 @@ impl GraphCircuit {
 }

 #[derive(Clone, Debug, Default, Serialize, Deserialize)]
-struct CircuitSize {
+/// The configuration for the graph circuit
+pub struct CircuitSize {
     num_instances: usize,
     num_advice_columns: usize,
     num_fixed: usize,

@@ -1464,7 +1509,8 @@ struct CircuitSize {
 }

 impl CircuitSize {
-    pub fn from_cs(cs: &ConstraintSystem<Fp>, logrows: u32) -> Self {
+    ///
+    pub fn from_cs<F: Field>(cs: &ConstraintSystem<F>, logrows: u32) -> Self {
         CircuitSize {
             num_instances: cs.num_instance_columns(),
             num_advice_columns: cs.num_advice_columns(),
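from_cs is also widened from the concrete Fp to any `F: Field`, which is why the `Field` import appears in the mod.rs hunk earlier. The refactor is the standard generic-over-a-trait move; a toy version with halo2's types replaced by stand-ins:

use std::marker::PhantomData;

trait Field {}
impl Field for () {}

struct ConstraintSystem<F: Field> {
    advice_cols: usize,
    _marker: PhantomData<F>,
}

struct CircuitSize {
    num_advice_columns: usize,
}

impl CircuitSize {
    // before: fn from_cs(cs: &ConstraintSystem<Fp>, ...) -> Self
    fn from_cs<F: Field>(cs: &ConstraintSystem<F>) -> Self {
        // only column counts are read, so no field-specific behavior is needed
        CircuitSize { num_advice_columns: cs.advice_cols }
    }
}

fn main() {
    let cs = ConstraintSystem::<()> { advice_cols: 4, _marker: PhantomData };
    assert_eq!(CircuitSize::from_cs(&cs).num_advice_columns, 4);
}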
@@ -1606,6 +1652,8 @@ impl Circuit<Fp> for GraphCircuit {
         let output_vis = &self.settings().run_args.output_visibility;
         let mut graph_modules = GraphModules::new();

+        let mut constants = ConstantsMap::new();
+
         let mut config = config.clone();

         let mut inputs = self

@@ -1651,6 +1699,7 @@ impl Circuit<Fp> for GraphCircuit {
                 &mut input_outlets,
                 input_visibility,
                 &mut instance_offset,
+                &mut constants,
             )?;
             // replace inputs with the outlets
             for (i, outlet) in outlets.iter().enumerate() {

@@ -1663,6 +1712,7 @@ impl Circuit<Fp> for GraphCircuit {
                 &mut inputs,
                 input_visibility,
                 &mut instance_offset,
+                &mut constants,
             )?;
         }

@@ -1699,6 +1749,7 @@ impl Circuit<Fp> for GraphCircuit {
                 &mut flattened_params,
                 param_visibility,
                 &mut instance_offset,
+                &mut constants,
             )?;

             let shapes = self.model().const_shapes();

@@ -1727,6 +1778,7 @@ impl Circuit<Fp> for GraphCircuit {
                 &inputs,
                 &mut vars,
                 &outputs,
+                &mut constants,
             )
             .map_err(|e| {
                 log::error!("{}", e);

@@ -1751,6 +1803,7 @@ impl Circuit<Fp> for GraphCircuit {
                 &mut output_outlets,
                 &self.settings().run_args.output_visibility,
                 &mut instance_offset,
+                &mut constants,
             )?;

             // replace outputs with the outlets

@@ -1764,6 +1817,7 @@ impl Circuit<Fp> for GraphCircuit {
                 &mut outputs,
                 &self.settings().run_args.output_visibility,
                 &mut instance_offset,
+                &mut constants,
             )?;
         }
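The repeated `&mut constants` argument threads a single ConstantsMap through every layout call in synthesize, so a constant assigned once can be reused rather than re-assigned per call. A toy sketch of the pattern with stand-in types (a HashMap instead of ezkl's ConstantsMap):

use std::collections::HashMap;

type Fp = u64; // stand-in for the field element
type ConstantsMap = HashMap<Fp, usize>; // constant value -> assigned cell

fn layout_step(constants: &mut ConstantsMap, value: Fp) -> usize {
    let next = constants.len();
    // reuse the previously assigned cell if this constant was seen before
    *constants.entry(value).or_insert(next)
}

fn main() {
    let mut constants = ConstantsMap::new();
    let a = layout_step(&mut constants, 7);
    let b = layout_step(&mut constants, 7); // same constant, same cell
    assert_eq!(a, b);
    assert_eq!(constants.len(), 1);
}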
src/graph/model.rs
@@ -5,6 +5,7 @@ use super::vars::*;
 use super::GraphError;
 use super::GraphSettings;
 use crate::circuit::hybrid::HybridOp;
+use crate::circuit::region::ConstantsMap;
 use crate::circuit::region::RegionCtx;
 use crate::circuit::table::Range;
 use crate::circuit::Input;

@@ -404,7 +405,7 @@ impl ParsedNodes {
             .get(input)
             .ok_or(GraphError::MissingNode(*input))?;
         let input_dims = node.out_dims();
-        let input_dim = input_dims.get(0).ok_or(GraphError::MissingNode(*input))?;
+        let input_dim = input_dims.first().ok_or(GraphError::MissingNode(*input))?;
         inputs.push(input_dim.clone());
     }
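The `.get(0)` to `.first()` swap is the clippy-preferred spelling; the two are equivalent:

fn main() {
    let input_dims: Vec<Vec<usize>> = vec![vec![1, 3, 224, 224]];
    assert_eq!(input_dims.first(), input_dims.get(0));
}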
@@ -514,21 +515,24 @@ impl Model {
             instance_shapes.len().to_string().blue(),
             "instances".blue()
         );
-        // this is the total number of variables we will need to allocate
-        // for the circuit
-        let default_value = if !self.visibility.input.is_fixed() {
-            ValType::Value(Value::<Fp>::unknown())
-        } else {
-            ValType::Constant(Fp::ONE)
-        };
-
         let inputs: Vec<ValTensor<Fp>> = self
             .graph
             .input_shapes()?
             .iter()
             .map(|shape| {
-                let mut t: ValTensor<Fp> =
-                    vec![default_value.clone(); shape.iter().product()].into();
+                let len = shape.iter().product();
+                let mut t: ValTensor<Fp> = (0..len)
+                    .map(|_| {
+                        if !self.visibility.input.is_fixed() {
+                            ValType::Value(Value::<Fp>::unknown())
+                        } else {
+                            ValType::Constant(Fp::random(&mut rand::thread_rng()))
+                        }
+                    })
+                    .collect::<Vec<_>>()
+                    .into();

                 t.reshape(shape)?;
                 Ok(t)
             })
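The dummy-input change replaces a shared `Fp::ONE` placeholder with per-cell values: witnesses stay `Value::unknown()` (nothing is known at keygen), while fixed cells get fresh random field elements so no accidental structure is baked in. A simplified model of the branch, with Option standing in for halo2's Value and the rand crate assumed:

#[derive(Clone, Debug, PartialEq)]
enum ValType {
    Value(Option<u64>), // None plays the role of Value::<Fp>::unknown()
    Constant(u64),
}

fn dummy_cells(is_fixed: bool, len: usize) -> Vec<ValType> {
    (0..len)
        .map(|_| {
            if !is_fixed {
                ValType::Value(None) // witness: unknown during keygen
            } else {
                ValType::Constant(rand::random()) // fixed: arbitrary concrete value
            }
        })
        .collect()
}

fn main() {
    assert_eq!(dummy_cells(false, 2), vec![ValType::Value(None); 2]);
}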
@@ -577,13 +581,13 @@ impl Model {
         &self,
         model_inputs: &[Tensor<Fp>],
         run_args: &RunArgs,
-        throw_range_check_error: bool,
+        witness_gen: bool,
     ) -> Result<ForwardResult, Box<dyn Error>> {
         let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
             .iter()
             .map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
             .collect();
-        let res = self.dummy_layout(run_args, &valtensor_inputs, throw_range_check_error)?;
+        let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen)?;
         Ok(res.into())
     }
@@ -799,13 +803,18 @@ impl Model {
         let input_state_idx = input_state_idx(&input_mappings);

         let mut output_mappings = vec![];
-        for mapping in b.output_mapping.iter() {
+        for (i, mapping) in b.output_mapping.iter().enumerate() {
             let mut mappings = vec![];
             if let Some(outlet) = mapping.last_value_slot {
                 mappings.push(OutputMapping::Single {
                     outlet,
                     is_state: mapping.state,
                 });
+            } else if mapping.state {
+                mappings.push(OutputMapping::Single {
+                    outlet: i,
+                    is_state: mapping.state,
+                });
             }
             if let Some(last) = mapping.scan {
                 mappings.push(OutputMapping::Stacked {

@@ -814,6 +823,7 @@ impl Model {
                     is_state: false,
                 });
             }
+
             output_mappings.push(mappings);
         }
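The enumerate change above gives state outputs without an explicit last_value_slot a fallback outlet: their positional index i. A condensed, self-contained version of that loop (field and type names mirror the diff but are simplified):

#[derive(Debug, PartialEq)]
enum OutputMapping {
    Single { outlet: usize, is_state: bool },
}

struct SlotMapping {
    last_value_slot: Option<usize>,
    state: bool,
}

fn map_outputs(output_mapping: &[SlotMapping]) -> Vec<Vec<OutputMapping>> {
    let mut output_mappings = vec![];
    for (i, mapping) in output_mapping.iter().enumerate() {
        let mut mappings = vec![];
        if let Some(outlet) = mapping.last_value_slot {
            mappings.push(OutputMapping::Single { outlet, is_state: mapping.state });
        } else if mapping.state {
            // previously dropped: a state output with no explicit slot
            mappings.push(OutputMapping::Single { outlet: i, is_state: true });
        }
        output_mappings.push(mappings);
    }
    output_mappings
}

fn main() {
    let maps = map_outputs(&[SlotMapping { last_value_slot: None, state: true }]);
    assert_eq!(maps[0], vec![OutputMapping::Single { outlet: 0, is_state: true }]);
}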
@@ -1071,6 +1081,8 @@ impl Model {
     /// * `layouter` - Halo2 Layouter.
     /// * `inputs` - The values to feed into the circuit.
     /// * `vars` - The variables for the circuit.
+    /// * `witnessed_outputs` - The values to compare against.
+    /// * `constants` - The constants for the circuit.
     pub fn layout(
         &self,
         mut config: ModelConfig,

@@ -1079,6 +1091,7 @@ impl Model {
         inputs: &[ValTensor<Fp>],
         vars: &mut ModelVars<Fp>,
         witnessed_outputs: &[ValTensor<Fp>],
+        constants: &mut ConstantsMap<Fp>,
     ) -> Result<Vec<ValTensor<Fp>>, Box<dyn Error>> {
         info!("model layout...");
@@ -1104,14 +1117,12 @@ impl Model {
         config.base.layout_tables(layouter)?;
         config.base.layout_range_checks(layouter)?;

-        let mut num_rows = 0;
-        let mut linear_coord = 0;
-        let mut total_const_size = 0;
+        let original_constants = constants.clone();

         let outputs = layouter.assign_region(
             || "model",
             |region| {
-                let mut thread_safe_region = RegionCtx::new(region, 0, run_args.num_inner_cols);
+                let mut thread_safe_region = RegionCtx::new_with_constants(region, 0, run_args.num_inner_cols, original_constants.clone());
                 // we need to do this as this loop is called multiple times
                 vars.set_instance_idx(instance_idx);

@@ -1157,29 +1168,17 @@ impl Model {
                         error!("{}", e);
                         halo2_proofs::plonk::Error::Synthesis
                     })?;
                 } else if !run_args.output_visibility.is_private() {
                     for output in &outputs {
                         thread_safe_region.increment_total_constants(output.num_constants());
                     }
                 }
-                num_rows = thread_safe_region.row();
-                linear_coord = thread_safe_region.linear_coord();
-                total_const_size = thread_safe_region.total_constants();
+                // Then number of columns in the circuits
+                #[cfg(not(target_arch = "wasm32"))]
+                thread_safe_region.debug_report();
+
+                *constants = thread_safe_region.assigned_constants().clone();

                 Ok(outputs)
             },
         )?;

-        // Then number of columns in the circuits
-        #[cfg(not(target_arch = "wasm32"))]
-        debug!(
-            "{} {} {} (coord={}, constants={})",
-            "model uses".blue(),
-            num_rows.to_string().blue(),
-            "rows".blue(),
-            linear_coord.to_string().yellow(),
-            total_const_size.to_string().red()
-        );

         let duration = start_time.elapsed();
         trace!("model layout took: {:?}", duration);
@@ -1201,6 +1200,20 @@ impl Model {
         .collect();

         for (idx, node) in self.graph.nodes.iter() {
+            debug!("laying out {}: {}", idx, node.as_str(),);
+            // Then number of columns in the circuits
+            #[cfg(not(target_arch = "wasm32"))]
+            region.debug_report();
+            debug!("input indices: {:?}", node.inputs());
+            debug!("output scales: {:?}", node.out_scales());
+            debug!(
+                "input scales: {:?}",
+                node.inputs()
+                    .iter()
+                    .map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
+                    .collect_vec()
+            );

             let mut values: Vec<ValTensor<Fp>> = if !node.is_input() {
                 node.inputs()
                     .iter()
@@ -1212,31 +1225,11 @@ impl Model {
                 // we re-assign inputs, always from the 0 outlet
                 vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
             };

-            debug!("output dims: {:?}", node.out_dims());
-            debug!(
-                "laying out {}: {}, row:{}, coord:{}, total_constants: {}, max_lookup_inputs: {}, min_lookup_inputs: {}",
-                idx,
-                node.as_str(),
-                region.row(),
-                region.linear_coord(),
-                region.total_constants(),
-                region.max_lookup_inputs(),
-                region.min_lookup_inputs()
-            );
+            debug!("dims: {:?}", node.out_dims());
             debug!(
-                "input_dims {:?}",
+                "input dims {:?}",
                 values.iter().map(|v| v.dims()).collect_vec()
             );
-            debug!("output scales: {:?}", node.out_scales());
-            debug!("input indices: {:?}", node.inputs());
-            debug!(
-                "input scales: {:?}",
-                node.inputs()
-                    .iter()
-                    .map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
-                    .collect_vec()
-            );

             match &node {
                 NodeType::Node(n) => {
@@ -1277,8 +1270,8 @@ impl Model {
                     let num_iter = number_of_iterations(&input_mappings, input_dims.collect());

                     debug!(
-                        "{} iteration(s) in a subgraph with inputs {:?} and sources {:?}",
-                        num_iter, inputs, model.graph.inputs
+                        "{} iteration(s) in a subgraph with inputs {:?}, sources {:?}, and outputs {:?}",
+                        num_iter, inputs, model.graph.inputs, model.graph.outputs
                     );

                     let mut full_results: Vec<ValTensor<Fp>> = vec![];
@@ -1310,6 +1303,7 @@ impl Model {
                     let res = model.layout_nodes(config, region, &mut subgraph_results)?;

                     let mut outlets = BTreeMap::new();
+                    let mut stacked_outlets = BTreeMap::new();

                     for (mappings, outlet_res) in output_mappings.iter().zip(res) {
                         for mapping in mappings {

@@ -1322,25 +1316,42 @@ impl Model {
                                 let stacked_res = full_results[*outlet]
                                     .clone()
                                     .concat_axis(outlet_res.clone(), axis)?;

-                                outlets.insert(outlet, stacked_res);
-                            } else {
-                                outlets.insert(outlet, outlet_res.clone());
+                                stacked_outlets.insert(outlet, stacked_res);
                             }
+                            outlets.insert(outlet, outlet_res.clone());
                         }
                     }

-                    full_results = outlets.into_values().collect_vec();
+                    // now extend with stacked elements
+                    let mut pre_stacked_outlets = outlets.clone();
+                    pre_stacked_outlets.extend(stacked_outlets);
+
+                    let outlets = outlets.into_values().collect_vec();
+
+                    full_results = pre_stacked_outlets.into_values().collect_vec();

                     let output_states = output_state_idx(output_mappings);
                     let input_states = input_state_idx(&input_mappings);

-                    assert_eq!(input_states.len(), output_states.len());
+                    assert_eq!(
+                        input_states.len(),
+                        output_states.len(),
+                        "input and output states must be the same length, got {:?} and {:?}",
+                        input_mappings,
+                        output_mappings
+                    );

                     for (input_idx, output_idx) in input_states.iter().zip(output_states) {
-                        values[*input_idx] = full_results[output_idx].clone();
+                        assert_eq!(
+                            values[*input_idx].dims(),
+                            outlets[output_idx].dims(),
+                            "input and output dims must be the same, got {:?} and {:?}",
+                            values[*input_idx].dims(),
+                            outlets[output_idx].dims()
+                        );
+                        values[*input_idx] = outlets[output_idx].clone();
                     }
                 }
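The outlet bookkeeping above keeps two maps so the un-stacked per-outlet results survive for the state loop-back dims check, while the stacked variants override them only in the copy used for full_results. A toy illustration of that merge:

use std::collections::BTreeMap;

fn main() {
    let outlets = BTreeMap::from([(0, "last"), (1, "other")]);
    let stacked_outlets = BTreeMap::from([(0, "stacked")]);

    let mut pre_stacked_outlets = outlets.clone();
    pre_stacked_outlets.extend(stacked_outlets); // stacked value wins for key 0

    assert_eq!(pre_stacked_outlets[&0], "stacked");
    assert_eq!(outlets[&0], "last"); // plain result still available for dims checks
}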
@@ -1380,7 +1391,7 @@ impl Model {
         &self,
         run_args: &RunArgs,
         inputs: &[ValTensor<Fp>],
-        throw_range_check_error: bool,
+        witness_gen: bool,
     ) -> Result<DummyPassRes, Box<dyn Error>> {
         debug!("calculating num of constraints using dummy model layout...");
@@ -1399,29 +1410,31 @@ impl Model {
             vars: ModelVars::new_dummy(),
         };

-        let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, throw_range_check_error);
+        let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen);

         let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;

         if self.visibility.output.is_public() || self.visibility.output.is_fixed() {
-            let default_value = if !self.visibility.output.is_fixed() {
-                ValType::Value(Value::<Fp>::unknown())
-            } else {
-                ValType::Constant(Fp::ONE)
-            };
-
             let output_scales = self.graph.get_output_scales()?;
             let res = outputs
                 .iter()
                 .enumerate()
                 .map(|(i, output)| {
+                    let mut comparator: ValTensor<Fp> = (0..output.len())
+                        .map(|_| {
+                            if !self.visibility.output.is_fixed() {
+                                ValType::Value(Value::<Fp>::unknown())
+                            } else {
+                                ValType::Constant(Fp::random(&mut rand::thread_rng()))
+                            }
+                        })
+                        .collect::<Vec<_>>()
+                        .into();
+                    comparator.reshape(output.dims())?;

                     let mut tolerance = run_args.tolerance;
                     tolerance.scale = scale_to_multiplier(output_scales[i]).into();

-                    let mut comparator: ValTensor<Fp> =
-                        vec![default_value.clone(); output.dims().iter().product::<usize>()].into();
-                    comparator.reshape(output.dims())?;

                     dummy_config.layout(
                         &mut region,
                         &[output.clone(), comparator],
@@ -1432,7 +1445,7 @@ impl Model {
             res?;
         } else if !self.visibility.output.is_private() {
             for output in &outputs {
-                region.increment_total_constants(output.num_constants());
+                region.update_constants(output.create_constants_map());
             }
         }
@@ -1441,14 +1454,7 @@ impl Model {

         // Then number of columns in the circuits
         #[cfg(not(target_arch = "wasm32"))]
-        debug!(
-            "{} {} {} (coord={}, constants={})",
-            "model uses".blue(),
-            region.row().to_string().blue(),
-            "rows".blue(),
-            region.linear_coord().to_string().yellow(),
-            region.total_constants().to_string().red()
-        );
+        region.debug_report();

         let outputs = outputs
             .iter()
Some files were not shown because too many files have changed in this diff