mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
7 Commits
ac/patch-m
...
v18.1.6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ddce63684a | ||
|
|
83c4afce3b | ||
|
|
50740a22df | ||
|
|
a2624f6303 | ||
|
|
fc5be4f949 | ||
|
|
d0ba505baa | ||
|
|
f35688917d |
72
.github/workflows/engine.yml
vendored
72
.github/workflows/engine.yml
vendored
@@ -19,6 +19,8 @@ jobs:
|
||||
contents: read
|
||||
packages: write
|
||||
name: publish-wasm-bindings
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
@@ -45,43 +47,39 @@ jobs:
|
||||
curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz | tar xzf -
|
||||
export PATH=$PATH:$PWD/binaryen-version_116/bin
|
||||
wasm-opt --version
|
||||
- name: Build wasm files for both web and nodejs compilation targets
|
||||
run: |
|
||||
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
|
||||
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
|
||||
- name: Create package.json in pkg folder
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
echo '{
|
||||
"name": "@ezkljs/engine",
|
||||
"version": "${RELEASE_TAG}",
|
||||
"dependencies": {
|
||||
"@types/json-bigint": "^1.0.1",
|
||||
"json-bigint": "^1.0.0"
|
||||
},
|
||||
"files": [
|
||||
"nodejs/ezkl_bg.wasm",
|
||||
"nodejs/ezkl.js",
|
||||
"nodejs/ezkl.d.ts",
|
||||
"nodejs/package.json",
|
||||
"nodejs/utils.js",
|
||||
"web/ezkl_bg.wasm",
|
||||
"web/ezkl.js",
|
||||
"web/ezkl.d.ts",
|
||||
"web/snippets/**/*",
|
||||
"web/package.json",
|
||||
"web/utils.js",
|
||||
"ezkl.d.ts"
|
||||
],
|
||||
"main": "nodejs/ezkl.js",
|
||||
"module": "web/ezkl.js",
|
||||
"types": "nodejs/ezkl.d.ts",
|
||||
"sideEffects": [
|
||||
"web/snippets/*"
|
||||
]
|
||||
}' > pkg/package.json
|
||||
cat > pkg/package.json << EOF
|
||||
{
|
||||
"name": "@ezkljs/engine",
|
||||
"version": "$RELEASE_TAG",
|
||||
"dependencies": {
|
||||
"@types/json-bigint": "^1.0.1",
|
||||
"json-bigint": "^1.0.0"
|
||||
},
|
||||
"files": [
|
||||
"nodejs/ezkl_bg.wasm",
|
||||
"nodejs/ezkl.js",
|
||||
"nodejs/ezkl.d.ts",
|
||||
"nodejs/package.json",
|
||||
"nodejs/utils.js",
|
||||
"web/ezkl_bg.wasm",
|
||||
"web/ezkl.js",
|
||||
"web/ezkl.d.ts",
|
||||
"web/snippets/**/*",
|
||||
"web/package.json",
|
||||
"web/utils.js",
|
||||
"ezkl.d.ts"
|
||||
],
|
||||
"main": "nodejs/ezkl.js",
|
||||
"module": "web/ezkl.js",
|
||||
"types": "nodejs/ezkl.d.ts",
|
||||
"sideEffects": [
|
||||
"web/snippets/*"
|
||||
]
|
||||
}
|
||||
EOF
|
||||
|
||||
- name: Replace memory definition in nodejs
|
||||
run: |
|
||||
@@ -195,6 +193,8 @@ jobs:
|
||||
name: publish-in-browser-evm-verifier-package
|
||||
needs: [publish-wasm-bindings]
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -202,10 +202,8 @@ jobs:
|
||||
persist-credentials: false
|
||||
- name: Update version in package.json
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
sed -i "s|\"version\": \".*\"|\"version\": \"${RELEASE_TAG}\"|" in-browser-evm-verifier/package.json
|
||||
sed -i "s|\"version\": \".*\"|\"version\": \"$RELEASE_TAG\"|" in-browser-evm-verifier/package.json
|
||||
- name: Prepare tag and fetch package integrity
|
||||
run: |
|
||||
CLEANED_TAG=${RELEASE_TAG} # Get the tag from ref_name
|
||||
|
||||
4
.github/workflows/pypi-gpu.yml
vendored
4
.github/workflows/pypi-gpu.yml
vendored
@@ -25,6 +25,8 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
target: [x86_64]
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -49,8 +51,6 @@ jobs:
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv Cargo.toml Cargo.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
|
||||
|
||||
109
.github/workflows/pypi.yml
vendored
109
.github/workflows/pypi.yml
vendored
@@ -23,6 +23,8 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
target: [x86_64, universal2-apple-darwin]
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -32,10 +34,14 @@ jobs:
|
||||
python-version: 3.12
|
||||
architecture: x64
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv Cargo.toml Cargo.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
|
||||
@@ -89,6 +95,14 @@ jobs:
|
||||
python-version: 3.12
|
||||
architecture: ${{ matrix.target }}
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
@@ -138,6 +152,14 @@ jobs:
|
||||
python-version: 3.12
|
||||
architecture: x64
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
@@ -148,7 +170,6 @@ jobs:
|
||||
mv Cargo.lock Cargo.lock.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
|
||||
|
||||
|
||||
- name: Install required libraries
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -187,57 +208,6 @@ jobs:
|
||||
name: wheels
|
||||
path: dist
|
||||
|
||||
# There's a problem with the maturin-action toolchain for arm arch leading to failed builds
|
||||
# linux-cross:
|
||||
# runs-on: ubuntu-latest
|
||||
# strategy:
|
||||
# matrix:
|
||||
# target: [aarch64, armv7]
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - uses: actions/setup-python@v4
|
||||
# with:
|
||||
# python-version: 3.12
|
||||
|
||||
# - name: Install cross-compilation tools for aarch64
|
||||
# if: matrix.target == 'aarch64'
|
||||
# run: |
|
||||
# sudo apt-get update
|
||||
# sudo apt-get install -y gcc make gcc-aarch64-linux-gnu binutils-aarch64-linux-gnu libc6-dev-arm64-cross libusb-1.0-0-dev libatomic1-arm64-cross
|
||||
|
||||
# - name: Install cross-compilation tools for armv7
|
||||
# if: matrix.target == 'armv7'
|
||||
# run: |
|
||||
# sudo apt-get update
|
||||
# sudo apt-get install -y gcc make gcc-arm-linux-gnueabihf binutils-arm-linux-gnueabihf libc6-dev-armhf-cross libusb-1.0-0-dev libatomic1-armhf-cross
|
||||
|
||||
# - name: Build wheels
|
||||
# uses: PyO3/maturin-action@v1
|
||||
# with:
|
||||
# target: ${{ matrix.target }}
|
||||
# manylinux: auto
|
||||
# args: --release --out dist --features python-bindings
|
||||
|
||||
# - uses: uraimo/run-on-arch-action@v2.5.0
|
||||
# name: Install built wheel
|
||||
# with:
|
||||
# arch: ${{ matrix.target }}
|
||||
# distro: ubuntu20.04
|
||||
# githubToken: ${{ github.token }}
|
||||
# install: |
|
||||
# apt-get update
|
||||
# apt-get install -y --no-install-recommends python3 python3-pip
|
||||
# pip3 install -U pip
|
||||
# run: |
|
||||
# pip3 install ezkl --no-index --find-links dist/ --force-reinstall
|
||||
# python3 -c "import ezkl"
|
||||
|
||||
# - name: Upload wheels
|
||||
# uses: actions/upload-artifact@v3
|
||||
# with:
|
||||
# name: wheels
|
||||
# path: dist
|
||||
|
||||
musllinux:
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -273,6 +243,7 @@ jobs:
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
|
||||
mv Cargo.lock Cargo.lock.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
|
||||
|
||||
- name: Install required libraries
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -323,6 +294,14 @@ jobs:
|
||||
with:
|
||||
python-version: 3.12
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
@@ -366,8 +345,6 @@ jobs:
|
||||
permissions:
|
||||
id-token: write
|
||||
if: "startsWith(github.ref, 'refs/tags/')"
|
||||
# TODO: Uncomment if linux-cross is working
|
||||
# needs: [ macos, windows, linux, linux-cross, musllinux, musllinux-cross ]
|
||||
needs: [macos, windows, linux, musllinux, musllinux-cross]
|
||||
steps:
|
||||
- uses: actions/download-artifact@v3
|
||||
@@ -375,24 +352,20 @@ jobs:
|
||||
name: wheels
|
||||
- name: List Files
|
||||
run: ls -R
|
||||
|
||||
# Both publish steps will fail if there is no trusted publisher setup
|
||||
# On failure the publish step will then simply continue to the next one
|
||||
|
||||
# # publishes to TestPyPI
|
||||
# - name: Publish package distribution to TestPyPI
|
||||
# uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
# with:
|
||||
# repository-url: https://test.pypi.org/legacy/
|
||||
# packages-dir: ./
|
||||
|
||||
# publishes to PyPI
|
||||
- name: Publish package distributions to PyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
packages-dir: ./
|
||||
|
||||
# publishes to TestPyPI
|
||||
- name: Publish package distribution to TestPyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
packages-dir: ./
|
||||
|
||||
doc-publish:
|
||||
permissions:
|
||||
@@ -409,4 +382,4 @@ jobs:
|
||||
with:
|
||||
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
|
||||
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
|
||||
commit_ref: ${{ github.ref_name }}
|
||||
commit_ref: ${{ github.ref_name }}
|
||||
4
.github/workflows/rust.yml
vendored
4
.github/workflows/rust.yml
vendored
@@ -779,6 +779,8 @@ jobs:
|
||||
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
|
||||
- name: Voice tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
|
||||
- name: Neural bow
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
|
||||
- name: Felt conversion
|
||||
@@ -798,8 +800,6 @@ jobs:
|
||||
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
|
||||
- name: All notebooks
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
|
||||
- name: Voice tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
|
||||
- name: NBEATS tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
|
||||
# - name: Reusable verifier tutorial
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '0.0.0'
|
||||
release = '18.1.6'
|
||||
version = release
|
||||
|
||||
|
||||
|
||||
@@ -592,9 +592,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
// this is 0 if the index is the same as the column index (starting from 1)
|
||||
|
||||
let col_expr = sel.clone()
|
||||
* table
|
||||
* (table
|
||||
.selector_constructor
|
||||
.get_expr_at_idx(col_idx, synthetic_sel);
|
||||
.get_expr_at_idx(col_idx, synthetic_sel));
|
||||
|
||||
let multiplier =
|
||||
table.selector_constructor.get_selector_val_at_idx(col_idx);
|
||||
@@ -626,6 +626,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
res
|
||||
});
|
||||
}
|
||||
|
||||
// add a degree-k custom constraint of the following form to the range check and
|
||||
// static lookup configuration.
|
||||
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
|
||||
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
|
||||
cs.create_gate("range_check_on_sel", |cs| {
|
||||
let synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(1)),
|
||||
_ => match index {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
};
|
||||
|
||||
let range_check_on_synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(0)),
|
||||
_ => {
|
||||
let mut initial_expr = Expression::Constant(F::from(1));
|
||||
for i in 0..len {
|
||||
initial_expr = initial_expr
|
||||
* (synthetic_sel.clone()
|
||||
- Expression::Constant(F::from(i as u64)))
|
||||
}
|
||||
initial_expr
|
||||
}
|
||||
};
|
||||
|
||||
let sel = cs.query_selector(multi_col_selector);
|
||||
|
||||
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
|
||||
});
|
||||
|
||||
self.static_lookups
|
||||
.selectors
|
||||
.insert((nl.clone(), x, y), multi_col_selector);
|
||||
@@ -904,9 +938,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
let default_x = range_check.get_first_element(col_idx);
|
||||
|
||||
let col_expr = sel.clone()
|
||||
* range_check
|
||||
* (range_check
|
||||
.selector_constructor
|
||||
.get_expr_at_idx(col_idx, synthetic_sel);
|
||||
.get_expr_at_idx(col_idx, synthetic_sel));
|
||||
|
||||
let multiplier = range_check
|
||||
.selector_constructor
|
||||
@@ -929,6 +963,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
res
|
||||
});
|
||||
}
|
||||
|
||||
// add a degree-k custom constraint of the following form to the range check and
|
||||
// static lookup configuration.
|
||||
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
|
||||
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
|
||||
cs.create_gate("range_check_on_sel", |cs| {
|
||||
let synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(1)),
|
||||
_ => match index {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
};
|
||||
|
||||
let range_check_on_synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(0)),
|
||||
_ => {
|
||||
let mut initial_expr = Expression::Constant(F::from(1));
|
||||
for i in 0..len {
|
||||
initial_expr = initial_expr
|
||||
* (synthetic_sel.clone()
|
||||
- Expression::Constant(F::from(i as u64)))
|
||||
}
|
||||
initial_expr
|
||||
}
|
||||
};
|
||||
|
||||
let sel = cs.query_selector(multi_col_selector);
|
||||
|
||||
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
|
||||
});
|
||||
|
||||
self.range_checks
|
||||
.selectors
|
||||
.insert((range, x, y), multi_col_selector);
|
||||
|
||||
@@ -75,7 +75,7 @@ fn optimum_convex_function<F: PrimeField + TensorType + PartialOrd + std::hash::
|
||||
region: &mut RegionCtx<F>,
|
||||
x: &ValTensor<F>,
|
||||
f: impl Fn(&BaseConfig<F>, &mut RegionCtx<F>, &ValTensor<F>) -> Result<ValTensor<F>, CircuitError>,
|
||||
) -> Result<(), CircuitError> {
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let one = create_constant_tensor(F::from(1), 1);
|
||||
|
||||
let f_x = f(config, region, x)?;
|
||||
@@ -87,22 +87,17 @@ fn optimum_convex_function<F: PrimeField + TensorType + PartialOrd + std::hash::
|
||||
let f_x_minus_1 = f(config, region, &x_minus_1)?;
|
||||
|
||||
// because the function is convex, the result should be the minimum of the three
|
||||
// not that we offset the x by 1 to get the next value
|
||||
// f(x) <= f(x+1) and f(x) <= f(x-1)
|
||||
// note that we offset the x by 1 to get the next value
|
||||
// f(x) <= f(x+1) and f(x) < f(x-1)
|
||||
// the result is 1 if the function is optimal solely because of the convexity of the function
|
||||
// the distances can be equal but this is only possible if f(x) and f(x+1) are both optimal (or f(x) and f(x-1)).
|
||||
// the distances can be equal but this is only possible if f(x) and f(x+1) are both optimal, but if (f(x) = f(x + 1))
|
||||
// f(x+1) is not smaller than f(x + 1 - 1) = f(x) and thus f(x) is unique
|
||||
let f_x_is_opt_rhs = less_equal(config, region, &[f_x.clone(), f_x_plus_1.clone()])?;
|
||||
let f_x_is_opt_lhs = less_equal(config, region, &[f_x.clone(), f_x_minus_1.clone()])?;
|
||||
let f_x_is_opt_lhs = less(config, region, &[f_x.clone(), f_x_minus_1.clone()])?;
|
||||
|
||||
let is_opt = and(config, region, &[f_x_is_opt_lhs, f_x_is_opt_rhs])?;
|
||||
|
||||
let mut comparison_unit = create_constant_tensor(F::ONE, is_opt.len());
|
||||
comparison_unit.reshape(is_opt.dims())?;
|
||||
|
||||
// assert that the result is 1
|
||||
enforce_equality(config, region, &[is_opt, comparison_unit])?;
|
||||
|
||||
Ok(())
|
||||
Ok(is_opt)
|
||||
}
|
||||
|
||||
/// Err is less than some constant
|
||||
@@ -290,7 +285,14 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
Ok(distance)
|
||||
};
|
||||
|
||||
optimum_convex_function(config, region, &claimed_output, err_func)?;
|
||||
// we need to add 1 to the points where it is zero to ignore the cvx opt conditions at those points
|
||||
let mut is_opt = optimum_convex_function(config, region, &claimed_output, err_func)?;
|
||||
is_opt = pairwise(config, region, &[is_opt, equal_zero_mask], BaseOp::Add)?;
|
||||
|
||||
let mut comparison_unit = create_constant_tensor(F::ONE, is_opt.len());
|
||||
comparison_unit.reshape(is_opt.dims())?;
|
||||
// assert that the result is 1
|
||||
enforce_equality(config, region, &[is_opt, comparison_unit])?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
|
||||
@@ -362,7 +364,13 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
Ok(distance)
|
||||
};
|
||||
|
||||
optimum_convex_function(config, region, &claimed_output, err_func)?;
|
||||
let is_opt = optimum_convex_function(config, region, &claimed_output, err_func)?;
|
||||
|
||||
let mut comparison_unit = create_constant_tensor(F::ONE, is_opt.len());
|
||||
comparison_unit.reshape(is_opt.dims())?;
|
||||
|
||||
// assert that the result is 1
|
||||
enforce_equality(config, region, &[is_opt, comparison_unit])?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
|
||||
|
||||
@@ -132,21 +132,16 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
(first_element, op_f.output[0])
|
||||
}
|
||||
|
||||
///
|
||||
/// calculates the column size given the number of rows and reserved blinding rows
|
||||
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(logrows as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(bits as u32) - reserved_blinding_rows
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
|
||||
// number of cols needed to store the range
|
||||
(range_len / (col_size as IntegerRep)) as usize + 1
|
||||
(range_len / col_size as IntegerRep) as usize + 1
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
@@ -355,16 +350,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
|
||||
}
|
||||
|
||||
///
|
||||
/// calculates the column size
|
||||
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(logrows as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(bits as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
/// get column index given input
|
||||
pub fn get_col_index(&self, input: F) -> F {
|
||||
// range is split up into chunks of size col_size, find the chunk that input is in
|
||||
|
||||
@@ -11,6 +11,12 @@ pub enum GraphError {
|
||||
/// Shape mismatch in circuit construction
|
||||
#[error("invalid dimensions used for node {0} ({1})")]
|
||||
InvalidDims(usize, String),
|
||||
/// Non scalar power
|
||||
#[error("we only support scalar powers")]
|
||||
NonScalarPower,
|
||||
/// Non scalar base for exponentiation
|
||||
#[error("we only support scalar bases for exponentiation")]
|
||||
NonScalarBase,
|
||||
/// Wrong method was called to configure an op
|
||||
#[error("wrong method was called to configure node {0} ({1})")]
|
||||
WrongMethod(usize, String),
|
||||
@@ -143,4 +149,7 @@ pub enum GraphError {
|
||||
/// Invalid RunArg
|
||||
#[error("invalid RunArgs: {0}")]
|
||||
InvalidRunArgs(String),
|
||||
/// Only nearest neighbor interpolation is supported
|
||||
#[error("only nearest neighbor interpolation is supported")]
|
||||
InvalidInterpolation,
|
||||
}
|
||||
|
||||
@@ -44,11 +44,10 @@ use tract_onnx::tract_hir::{
|
||||
tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
|
||||
};
|
||||
|
||||
/// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation.
|
||||
/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
|
||||
/// Arguments
|
||||
///
|
||||
/// * `vec` - the vector to quantize.
|
||||
/// * `dims` - the dimensionality of the resulting [Tensor].
|
||||
/// * `elem` - the element to quantize.
|
||||
/// * `shift` - offset used in the fixed point representation.
|
||||
/// * `scale` - `2^scale` used in the fixed point representation.
|
||||
pub fn quantize_float(
|
||||
@@ -85,7 +84,7 @@ pub fn scale_to_multiplier(scale: crate::Scale) -> f64 {
|
||||
f64::powf(2., scale as f64)
|
||||
}
|
||||
|
||||
/// Converts a scale (log base 2) to a fixed point multiplier.
|
||||
/// Converts a fixed point multiplier to a scale (log base 2).
|
||||
pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
|
||||
mult.log2().round() as crate::Scale
|
||||
}
|
||||
@@ -312,6 +311,9 @@ pub fn new_op_from_onnx(
|
||||
let mut deleted_indices = vec![];
|
||||
let node = match node.op().name().as_ref() {
|
||||
"ShiftLeft" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "shift left".to_string()));
|
||||
};
|
||||
// load shift amount
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
@@ -324,10 +326,13 @@ pub fn new_op_from_onnx(
|
||||
out_scale: Some(input_scales[0] - raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "ShiftLeft".to_string()));
|
||||
return Err(GraphError::OpMismatch(idx, "shift left".to_string()));
|
||||
}
|
||||
}
|
||||
"ShiftRight" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "shift right".to_string()));
|
||||
};
|
||||
// load shift amount
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
@@ -340,7 +345,7 @@ pub fn new_op_from_onnx(
|
||||
out_scale: Some(input_scales[0] + raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "ShiftRight".to_string()));
|
||||
return Err(GraphError::OpMismatch(idx, "shift right".to_string()));
|
||||
}
|
||||
}
|
||||
"MultiBroadcastTo" => {
|
||||
@@ -363,7 +368,10 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(input_ops.len(), 3, "Range requires 3 inputs");
|
||||
if input_ops.len() != 3 {
|
||||
return Err(GraphError::InvalidDims(idx, "range".to_string()));
|
||||
}
|
||||
|
||||
let input_ops = input_ops
|
||||
.iter()
|
||||
.map(|x| x.get_constant().ok_or(GraphError::NonConstantRange))
|
||||
@@ -419,6 +427,10 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
if inputs[0].out_dims().is_empty() || inputs[0].out_dims()[0].len() <= axis {
|
||||
return Err(GraphError::InvalidDims(idx, "gather".to_string()));
|
||||
}
|
||||
|
||||
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::Gather {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| {
|
||||
@@ -447,8 +459,17 @@ pub fn new_op_from_onnx(
|
||||
"Topk" => {
|
||||
let op = load_op::<Topk>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axis = op.axis;
|
||||
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
|
||||
};
|
||||
|
||||
// if param_visibility.is_public() {
|
||||
let k = if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
if c.raw_values.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
|
||||
}
|
||||
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
c.raw_values.map(|x| x as usize)[0]
|
||||
@@ -488,6 +509,10 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "scatter elements".to_string()));
|
||||
}
|
||||
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
@@ -522,6 +547,9 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "scatter nd".to_string()));
|
||||
}
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
})
|
||||
@@ -555,6 +583,9 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "gather nd".to_string()));
|
||||
}
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
|
||||
batch_dims,
|
||||
indices: Some(c.raw_values.map(|x| x as usize)),
|
||||
@@ -589,6 +620,9 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "gather elements".to_string()));
|
||||
}
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
@@ -684,7 +718,9 @@ pub fn new_op_from_onnx(
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes: Vec<usize> = op.axes.into_iter().collect();
|
||||
assert_eq!(axes.len(), 1, "only support argmax over one axis");
|
||||
if axes.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "argmax".to_string()));
|
||||
}
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::ReduceArgMax { dim: axes[0] })
|
||||
}
|
||||
@@ -694,7 +730,9 @@ pub fn new_op_from_onnx(
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes: Vec<usize> = op.axes.into_iter().collect();
|
||||
assert_eq!(axes.len(), 1, "only support argmin over one axis");
|
||||
if axes.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "argmin".to_string()));
|
||||
}
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::ReduceArgMin { dim: axes[0] })
|
||||
}
|
||||
@@ -803,6 +841,9 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
"Recip" => {
|
||||
if inputs.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "recip".to_string()));
|
||||
};
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
// If the input scale is larger than the params scale
|
||||
@@ -846,6 +887,9 @@ pub fn new_op_from_onnx(
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Rsqrt" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "rsqrt".to_string()));
|
||||
};
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
SupportedOp::Hybrid(HybridOp::Rsqrt {
|
||||
@@ -933,7 +977,9 @@ pub fn new_op_from_onnx(
|
||||
let op = load_op::<Cast>(node.op(), idx, node.op().name().to_string())?;
|
||||
let dt = op.to;
|
||||
|
||||
assert_eq!(input_scales.len(), 1);
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "cast".to_string()));
|
||||
};
|
||||
|
||||
match dt {
|
||||
DatumType::Bool
|
||||
@@ -983,6 +1029,11 @@ pub fn new_op_from_onnx(
|
||||
|
||||
if const_idx.len() == 1 {
|
||||
let const_idx = const_idx[0];
|
||||
|
||||
if inputs.len() <= const_idx {
|
||||
return Err(GraphError::InvalidDims(idx, "mul".to_string()));
|
||||
}
|
||||
|
||||
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
|
||||
if c.raw_values.len() == 1 && c.raw_values[0] < 1. {
|
||||
// if not divisible by 2 then we need to add a range check
|
||||
@@ -1057,6 +1108,9 @@ pub fn new_op_from_onnx(
|
||||
return Err(GraphError::OpMismatch(idx, "softmax".to_string()));
|
||||
}
|
||||
};
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "softmax".to_string()));
|
||||
}
|
||||
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
@@ -1096,22 +1150,42 @@ pub fn new_op_from_onnx(
|
||||
pool_dims: kernel_shape.to_vec(),
|
||||
})
|
||||
}
|
||||
"Ceil" => SupportedOp::Hybrid(HybridOp::Ceil {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Floor" => SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Round" => SupportedOp::Hybrid(HybridOp::Round {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"RoundHalfToEven" => SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Ceil" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "ceil".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::Ceil {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"Floor" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "floor".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"Round" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "round".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::Round {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"RoundHalfToEven" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "roundhalftoeven".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"Sign" => SupportedOp::Linear(PolyOp::Sign),
|
||||
"Pow" => {
|
||||
// Extract the slope layer hyperparams from a const
|
||||
@@ -1121,7 +1195,9 @@ pub fn new_op_from_onnx(
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.len() > 1 {
|
||||
unimplemented!("only support scalar pow")
|
||||
return Err(GraphError::NonScalarPower);
|
||||
} else if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
|
||||
}
|
||||
|
||||
let exponent = c.raw_values[0];
|
||||
@@ -1138,7 +1214,9 @@ pub fn new_op_from_onnx(
|
||||
inputs[0].decrement_use();
|
||||
deleted_indices.push(0);
|
||||
if c.raw_values.len() > 1 {
|
||||
unimplemented!("only support scalar base")
|
||||
return Err(GraphError::NonScalarBase);
|
||||
} else if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
|
||||
}
|
||||
|
||||
let base = c.raw_values[0];
|
||||
@@ -1148,10 +1226,14 @@ pub fn new_op_from_onnx(
|
||||
base: base.into(),
|
||||
})
|
||||
} else {
|
||||
unimplemented!("only support constant base or pow for now")
|
||||
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
|
||||
}
|
||||
}
|
||||
"Div" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "div".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
@@ -1159,14 +1241,15 @@ pub fn new_op_from_onnx(
|
||||
.map(|(i, _)| i)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_idx.len() > 1 {
|
||||
if const_idx.len() > 1 || const_idx.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "div".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = const_idx[0];
|
||||
|
||||
if const_idx != 1 {
|
||||
unimplemented!("only support div with constant as second input")
|
||||
return Err(GraphError::MisformedParams(
|
||||
"only support div with constant as second input".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
|
||||
@@ -1180,10 +1263,14 @@ pub fn new_op_from_onnx(
|
||||
denom: denom.into(),
|
||||
})
|
||||
} else {
|
||||
unimplemented!("only support non zero divisors of size 1")
|
||||
return Err(GraphError::MisformedParams(
|
||||
"only support non zero divisors of size 1".to_string(),
|
||||
));
|
||||
}
|
||||
} else {
|
||||
unimplemented!("only support div with constant as second input")
|
||||
return Err(GraphError::MisformedParams(
|
||||
"only support div with constant as second input".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
"Cube" => SupportedOp::Linear(PolyOp::Pow(3)),
|
||||
@@ -1323,7 +1410,7 @@ pub fn new_op_from_onnx(
|
||||
if !resize_node.contains("interpolator: Nearest")
|
||||
&& !resize_node.contains("nearest: Floor")
|
||||
{
|
||||
unimplemented!("Only nearest neighbor interpolation is supported")
|
||||
return Err(GraphError::InvalidInterpolation);
|
||||
}
|
||||
// check if optional scale factor is present
|
||||
if inputs.len() != 2 && inputs.len() != 3 {
|
||||
@@ -1427,6 +1514,10 @@ pub fn new_op_from_onnx(
|
||||
SupportedOp::Linear(PolyOp::Reshape(output_shape))
|
||||
}
|
||||
"Flatten" => {
|
||||
if inputs.len() != 1 || inputs[0].out_dims().is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "flatten".to_string()));
|
||||
};
|
||||
|
||||
let new_dims: Vec<usize> = vec![inputs[0].out_dims()[0].iter().product::<usize>()];
|
||||
SupportedOp::Linear(PolyOp::Flatten(new_dims))
|
||||
}
|
||||
@@ -1546,6 +1637,7 @@ pub fn homogenize_input_scales(
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
/// tests for the utility module
|
||||
pub mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -72,11 +72,10 @@ mod py_tests {
|
||||
"torchtext==0.17.2",
|
||||
"torchvision==0.17.2",
|
||||
"pandas==2.2.1",
|
||||
"numpy==1.26.4",
|
||||
"seaborn==0.13.2",
|
||||
"notebook==7.1.2",
|
||||
"nbconvert==7.16.3",
|
||||
"onnx==1.16.0",
|
||||
"onnx==1.17.0",
|
||||
"kaggle==1.6.8",
|
||||
"py-solc-x==2.0.3",
|
||||
"web3==7.5.0",
|
||||
@@ -90,12 +89,13 @@ mod py_tests {
|
||||
"xgboost==2.0.3",
|
||||
"hummingbird-ml==0.4.11",
|
||||
"lightgbm==4.3.0",
|
||||
"numpy==1.26.4",
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
let status = Command::new("pip")
|
||||
.args(["install", "numpy==1.23"])
|
||||
.args(["install", "numpy==1.26.4"])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
|
||||
|
||||
Reference in New Issue
Block a user