chore: Move to the mono repo layout
1
frontends/concrete-python/.dockerignore
Normal file
@@ -0,0 +1 @@
**
15
frontends/concrete-python/.editorconfig
Normal file
@@ -0,0 +1,15 @@
# EditorConfig is awesome: https://EditorConfig.org

# top-most EditorConfig file
root = true

# Unix-style newlines with a newline ending every file
[*]
end_of_line = lf
insert_final_newline = true

# 4 space indentation
[*.py]
charset = utf-8
indent_style = space
indent_size = 4
1
frontends/concrete-python/.gitbook.yaml
Normal file
@@ -0,0 +1 @@
root: ./docs
28
frontends/concrete-python/LICENSE
Normal file
@@ -0,0 +1,28 @@
BSD 3-Clause Clear License

Copyright © 2022 ZAMA.
All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the documentation and/or other
materials provided with the distribution.

3. Neither the name of ZAMA nor the names of its contributors may be used to endorse
or promote products derived from this software without specific prior written permission.

NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE.
THIS SOFTWARE IS PROVIDED BY THE ZAMA AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
ZAMA OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
313
frontends/concrete-python/Makefile
Normal file
@@ -0,0 +1,313 @@
SHELL:=/bin/bash

DEV_DOCKER_IMG:=concrete-numpy-dev
DEV_DOCKERFILE:=docker/Dockerfile.dev
DEV_CONTAINER_VENV_VOLUME:=concrete-numpy-internal-venv
DEV_CONTAINER_CACHE_VOLUME:=concrete-numpy-internal-cache
SRC_DIR:=concrete

.PHONY: setup_env # Set up the environment
setup_env:
	poetry run python -m pip install -U pip wheel
	poetry run python -m pip install -U --force-reinstall setuptools
	if [[ $$(uname) != "Linux" ]] && [[ $$(uname) != "Darwin" ]]; then \
		poetry install --only dev; \
	else \
		poetry install; \
	fi

.PHONY: sync_env # Synchronise the environment
sync_env:
	if [[ $$(uname) != "Linux" ]] && [[ $$(uname) != "Darwin" ]]; then \
		poetry install --remove-untracked --only dev; \
	else \
		poetry install --remove-untracked; \
	fi
	$(MAKE) setup_env

.PHONY: python_format # Apply python formatting
python_format:
	poetry run env bash ./script/source_format/format_python.sh \
		--dir $(SRC_DIR) --dir tests --dir script

.PHONY: check_python_format # Check python format
check_python_format:
	poetry run env bash ./script/source_format/format_python.sh \
		--dir $(SRC_DIR) --dir tests --dir script --check

.PHONY: check_finalize_nb # Sanitize notebooks
check_finalize_nb:
	poetry run python ./script/nbmake_utils/notebook_finalize.py docs --check

.PHONY: pylint # Run pylint
pylint:
	$(MAKE) --keep-going pylint_src pylint_tests pylint_script

.PHONY: pylint_src # Run pylint on sources
pylint_src:
	poetry run pylint --rcfile=pylintrc $(SRC_DIR)

.PHONY: pylint_tests # Run pylint on tests
pylint_tests:
	@# Disable duplicate code detection (R0801) in tests
	@# Disable unnecessary lambda (W0108) for tests
	find ./tests/ -type f -name "*.py" | xargs poetry run pylint --disable=R0801,W0108 --rcfile=pylintrc

.PHONY: pylint_script # Run pylint on scripts
pylint_script:
	find ./script/ -type f -name "*.py" | xargs poetry run pylint --rcfile=pylintrc

.PHONY: flake8 # Run flake8
flake8:
	poetry run flake8 --max-line-length 100 --per-file-ignores="__init__.py:F401" \
		$(SRC_DIR)/ tests/ script/

.PHONY: ruff # Run ruff
ruff:
	poetry run ruff $(SRC_DIR)/ tests/ script/

.PHONY: python_linting # Run python linters
python_linting: pylint flake8 ruff

.PHONY: conformance # Run commands to fix some conformance issues automatically
conformance: finalize_nb python_format supported_functions licenses

.PHONY: pcc # Run pre-commit checks
pcc:
	@$(MAKE) --keep-going --jobs $$(./script/make_utils/ncpus.sh) --output-sync=recurse \
		--no-print-directory pcc_internal

PCC_DEPS := check_python_format check_finalize_nb python_linting mypy_ci pydocstyle shell_lint
PCC_DEPS += check_supported_functions # check_licenses

# Not commented on purpose for make help, since internal
.PHONY: pcc_internal
pcc_internal: $(PCC_DEPS)

# A pytest run can be reproduced thanks to the --randomly-seed value
# reported by pytest-randomly
.PHONY: pytest # Run pytest
pytest:
	poetry run pytest -svv \
		--global-coverage=.global-coverage.json \
		-n $$(./script/make_utils/ncpus.sh) \
		--cov=$(SRC_DIR) --cov-fail-under=100 \
		--randomly-dont-reorganize \
		--cov-report=term-missing:skip-covered tests/

# Not a huge fan of ignoring missing imports, but some packages do not have typing stubs
.PHONY: mypy # Run mypy
mypy:
	poetry run mypy -p $(SRC_DIR) --ignore-missing-imports

# Friendly target to run mypy without ignoring missing stubs and still have error messages
# Allows us to see which stubs we are missing
.PHONY: mypy_ns # Run mypy (without ignoring missing stubs)
mypy_ns:
	poetry run mypy -p $(SRC_DIR)

.PHONY: mypy_test # Run mypy on test files
mypy_test:
	find ./tests/ -name "*.py" | xargs poetry run mypy --ignore-missing-imports

.PHONY: mypy_script # Run mypy on scripts
mypy_script:
	find ./script/ -name "*.py" | xargs poetry run mypy --ignore-missing-imports

# The plus indicates that make will be called by the command and allows it to share the context
# with the parent make execution. We serialize calls to these targets as they may overwrite each
# other's cache, which can cause issues.
.PHONY: mypy_ci # Run all mypy checks for CI
mypy_ci:
	$(MAKE) --keep-going mypy mypy_test mypy_script

.PHONY: docker_build # Build dev docker
docker_build:
	BUILD_ARGS=; \
	if [[ $$(uname) == "Linux" ]]; then \
		BUILD_ARGS="--build-arg BUILD_UID=$$(id -u) --build-arg BUILD_GID=$$(id -g)"; \
	fi; \
	DOCKER_BUILDKIT=1 docker build $${BUILD_ARGS:+$$BUILD_ARGS} \
		--pull -t $(DEV_DOCKER_IMG) -f $(DEV_DOCKERFILE) .

.PHONY: docker_rebuild # Rebuild docker
docker_rebuild: docker_clean_volumes
	BUILD_ARGS=; \
	if [[ $$(uname) == "Linux" ]]; then \
		BUILD_ARGS="--build-arg BUILD_UID=$$(id -u) --build-arg BUILD_GID=$$(id -g)"; \
	fi; \
	DOCKER_BUILDKIT=1 docker build $${BUILD_ARGS:+$$BUILD_ARGS} \
		--pull --no-cache -t $(DEV_DOCKER_IMG) -f $(DEV_DOCKERFILE) .

.PHONY: docker_start # Launch docker
docker_start:
	@# the slash before pwd is for Windows
	docker run --rm -it \
		-p 8888:8888 \
		--env DISPLAY=host.docker.internal:0 \
		--volume /"$$(pwd)":/src \
		--volume $(DEV_CONTAINER_VENV_VOLUME):/home/dev_user/dev_venv \
		--volume $(DEV_CONTAINER_CACHE_VOLUME):/home/dev_user/.cache \
		$(DEV_DOCKER_IMG)

.PHONY: docker_build_and_start # Docker build and start
docker_build_and_start: docker_build docker_start

.PHONY: docker_bas # Docker build and start
docker_bas: docker_build_and_start

.PHONY: docker_clean_volumes # Docker clean volumes
docker_clean_volumes:
	docker volume rm -f $(DEV_CONTAINER_VENV_VOLUME)
	docker volume rm -f $(DEV_CONTAINER_CACHE_VOLUME)

.PHONY: docker_cv # Docker clean volumes
docker_cv: docker_clean_volumes

.PHONY: pydocstyle # Launch syntax checker on source code documentation
pydocstyle:
	@# From http://www.pydocstyle.org/en/stable/error_codes.html
	poetry run pydocstyle $(SRC_DIR) --convention google --add-ignore=D1,D200,D202,D212,D402,D417 --add-select=D401

.PHONY: finalize_nb # Sanitize notebooks
finalize_nb:
	poetry run python ./script/nbmake_utils/notebook_finalize.py docs

# A warning in a package unrelated to the project made pytest fail with notebooks
# Run notebook tests without warnings as sources are already tested with warnings treated as errors
.PHONY: pytest_nb # Launch notebook tests
pytest_nb:
	find docs -name "*.ipynb" | grep -v _build | grep -v .ipynb_checkpoints | xargs poetry run pytest -Wignore --nbmake

.PHONY: jupyter # Launch jupyter notebook
jupyter:
	poetry run jupyter notebook --allow-root --no-browser --ip=0.0.0.0

.PHONY: release_docker # Build a docker release image
release_docker:
	./docker/build_release_image.sh

.PHONY: upgrade_py_deps # Upgrade python dependencies
upgrade_py_deps:
	./script/make_utils/upgrade_deps.sh

# Keeping this target as it proved useful before we had a proper package; it allowed us to run
# code that pytest-codeblocks was failing to execute if not installed as a pip package.
# This is done by hand as pytest-codeblocks was failing with our native extensions.
# See the refused PR on the project here: https://github.com/nschloe/pytest-codeblocks/pull/58
# Test code blocks using a custom python script in the documentation
.PHONY: test_codeblocks
test_codeblocks:
	poetry run python ./script/make_utils/test_md_python_code.py --md_dir docs/

.PHONY: pytest_codeblocks # Test code blocks using pytest in the documentation
pytest_codeblocks:
	poetry run pytest --codeblocks -svv -n $$(./script/make_utils/ncpus.sh) \
		--randomly-dont-reorganize docs/

# From https://stackoverflow.com/a/63523300 for the find command
.PHONY: shell_lint # Lint all bash scripts
shell_lint:
	find \( -path "./.venv" -o -path "./.docker_venv" \) -prune -o -type f -name "*.sh" -print | \
		xargs shellcheck

.PHONY: set_version_no_commit # Dry run for set_version
set_version_no_commit:
	@if [[ "$$VERSION" == "" ]]; then \
		echo "VERSION env variable is empty. Please set to desired version."; \
		exit 1; \
	fi && \
	poetry run python ./script/make_utils/version_utils.py set-version --version "$${VERSION}"

.PHONY: set_version # Generate a new version number and update all files with it accordingly
set_version:
	@if [[ "$$VERSION" == "" ]]; then \
		echo "VERSION env variable is empty. Please set to desired version."; \
		exit 1; \
	fi && \
	STASH_COUNT="$$(git stash list | wc -l)" && \
	git stash && \
	poetry run python ./script/make_utils/version_utils.py set-version --version "$${VERSION}" && \
	git add -u && \
	git commit -m "chore: bump version to $${VERSION}" && \
	NEW_STASH_COUNT="$$(git stash list | wc -l)" && \
	if [[ "$$NEW_STASH_COUNT" != "$$STASH_COUNT" ]]; then \
		git stash pop; \
	fi

.PHONY: changelog # Generate a changelog
changelog:
	PROJECT_VER=($$(poetry version)) && \
	PROJECT_VER="$${PROJECT_VER[1]}" && \
	poetry run python ./script/make_utils/changelog_helper.py > "CHANGELOG_$${PROJECT_VER}.md"

.PHONY: release # Create a new release
release:
	@PROJECT_VER=($$(poetry version)) && \
	PROJECT_VER="$${PROJECT_VER[1]}" && \
	TAG_NAME="v$${PROJECT_VER}" && \
	git fetch --tags --force && \
	git tag -s -a -m "$${TAG_NAME} release" "$${TAG_NAME}" && \
	git push origin "refs/tags/$${TAG_NAME}"


.PHONY: show_scope # Show the accepted types and optional scopes (for git conventional commits)
show_scope:
	@echo "Accepted types and optional scopes:"
	@cat .github/workflows/continuous-integration.yaml | grep feat | grep pattern | cut -f 2- -d ":" | cut -f 2- -d " "

.PHONY: show_type # Show the accepted types and optional scopes (for git conventional commits)
show_type: show_scope

# grep recursively, ignore binary files, print file line, print file name
# exclude dot dirs, exclude pylintrc (would match the notes)
# exclude notebooks (sometimes matches in svg text), match the notes in this directory
.PHONY: todo # List all todos left in the code
todo:
	@NOTES_ARGS=$$(poetry run python ./script/make_utils/get_pylintrc_notes.py \
		--pylintrc-path pylintrc) && \
	grep -rInH --exclude-dir='.[^.]*' --exclude=pylintrc --exclude='*.ipynb' "$${NOTES_ARGS}" .


.PHONY: supported_functions # Update docs with supported functions
supported_functions:
	poetry run python script/doc_utils/gen_supported_ufuncs.py docs/getting-started/compatibility.md

.PHONY: check_supported_functions # Check supported functions (for the doc)
check_supported_functions:
	poetry run python script/doc_utils/gen_supported_ufuncs.py docs/getting-started/compatibility.md --check

.PHONY: licenses # Generate the list of licenses of dependencies
licenses:
	@./script/make_utils/licenses.sh

.PHONY: check_licenses # Check if the licenses of dependencies have changed
check_licenses:
	@TMP_OUT="$$(mktemp)" && \
	if ! poetry run env bash ./script/make_utils/licenses.sh --check > "$${TMP_OUT}"; then \
		cat "$${TMP_OUT}"; \
		rm -f "$${TMP_OUT}"; \
		echo "Error while checking licenses, see log above."; \
		echo "Consider re-running 'make licenses'"; \
		exit 1; \
	else \
		echo "Licenses check OK"; \
	fi
.PHONY: check_licenses

.PHONY: help # Generate list of targets with descriptions
help:
	@grep '^.PHONY: .* #' Makefile | sed 's/\.PHONY: \(.*\) # \(.*\)/\1\t\2/' | expand -t30 | sort

.PHONY: pip_audit # Run pip-audit and check if there are known vulnerabilities in our dependencies
pip_audit:
	poetry run pip-audit

.PHONY: clean_local_git # Tell the user how to delete local git branches, except main
clean_local_git:
	@git fetch --all --prune
	@echo "Consider doing: "
	@echo
	@# Don't consider deleting `main` or current branches
	@git branch | grep -v "^*" | grep -v main | xargs echo "git branch -D "
	@echo
145
frontends/concrete-python/README.md
Normal file
@@ -0,0 +1,145 @@
<p align="center">
  <!-- product name logo -->
  <img width=600 src="https://user-images.githubusercontent.com/5758427/193612313-6b1124c7-8e3e-4e23-8b8c-57fd43b17d4f.png">
</p>

<p align="center">
  <!-- Version badge using shields.io -->
  <a href="https://github.com/zama-ai/concrete-numpy/releases">
    <img src="https://img.shields.io/github/v/release/zama-ai/concrete-numpy?style=flat-square">
  </a>
  <!-- Link to docs badge using shields.io -->
  <a href="https://docs.zama.ai/concrete-numpy/">
    <img src="https://img.shields.io/badge/read-documentation-yellow?style=flat-square">
  </a>
  <!-- Community forum badge using shields.io -->
  <a href="https://community.zama.ai/c/concrete-numpy">
    <img src="https://img.shields.io/badge/community%20forum-online-brightgreen?style=flat-square">
  </a>
  <!-- Open source badge using shields.io -->
  <a href="https://docs.zama.ai/concrete-numpy/developer/contributing">
    <img src="https://img.shields.io/badge/we're%20open%20source-contributing.md-blue?style=flat-square">
  </a>
  <!-- Follow on twitter badge using shields.io -->
  <a href="https://twitter.com/zama_fhe">
    <img src="https://img.shields.io/badge/follow-zama_fhe-blue?logo=twitter&style=flat-square">
  </a>
</p>


**Concrete Numpy** is an open-source library that simplifies the use of fully homomorphic encryption (FHE) in Python.

FHE is a powerful cryptographic tool that allows computation to be performed directly on encrypted data, without needing to decrypt it first.

With FHE, you can build services that preserve user privacy. FHE is also a strong defense against data breaches: since everything is computed on encrypted data, no sensitive data is leaked even if the server is compromised.

## Main features

- Ability to compile Python functions (which may use NumPy within) to their FHE equivalents, to operate on encrypted data
- Support for a [large collection of operators](https://docs.zama.ai/concrete-numpy/getting-started/compatibility)
- Partial support for floating points
- Support for table lookups on integers (see the sketch after this list)
- Support for integration with client/server architectures
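
As a small taste of the table lookup feature, here is a minimal sketch; it relies only on the `LookupTable` and `compiler` exports of this package (the table lookup tutorial below is the authoritative reference):

```python
import concrete.numpy as cnp

# a lookup table applied homomorphically: `table[x]` is evaluated under encryption
table = cnp.LookupTable([2, 1, 3, 0])

@cnp.compiler({"x": "encrypted"})
def f(x):
    return table[x]

circuit = f.compile(range(4))
assert circuit.encrypt_run_decrypt(2) == 3
```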

## Installation

| OS / HW                                    | Available on Docker | Available on PyPI |
| :----------------------------------------: | :-----------------: | :---------------: |
| Linux                                      | Yes                 | Yes               |
| Windows                                    | Yes                 | Coming soon       |
| Windows Subsystem for Linux                | Yes                 | Yes               |
| macOS (Intel)                              | Yes                 | Yes               |
| macOS (Apple Silicon, i.e., M1, M2, etc.)  | Yes (Rosetta)       | Coming soon       |


The preferred way to install Concrete Numpy is through PyPI:

```shell
pip install concrete-numpy
```

You can also get Concrete Numpy by pulling the latest Docker image:

```shell
docker pull zamafhe/concrete-numpy:v0.10.0
```

You can find more detailed installation instructions in [installing.md](docs/getting-started/installing.md).

## Getting started

```python
import concrete.numpy as cnp

def add(x, y):
    return x + y

compiler = cnp.Compiler(add, {"x": "encrypted", "y": "encrypted"})
inputset = [(2, 3), (0, 0), (1, 6), (7, 7), (7, 1), (3, 2), (6, 1), (1, 7), (4, 5), (5, 4)]

print("Compiling...")
circuit = compiler.compile(inputset)

print("Generating keys...")
circuit.keygen()

examples = [(3, 4), (1, 2), (7, 7), (0, 0)]
for example in examples:
    encrypted_example = circuit.encrypt(*example)
    encrypted_result = circuit.run(encrypted_example)
    result = circuit.decrypt(encrypted_result)
    print(f"Evaluation of {' + '.join(map(str, example))} homomorphically = {result}")
```

Or, if you have a simple function that you can decorate and you don't care about the explicit steps of key generation, encryption, evaluation, and decryption:

```python
import concrete.numpy as cnp

@cnp.compiler({"x": "encrypted", "y": "encrypted"})
def add(x, y):
    return x + y

inputset = [(2, 3), (0, 0), (1, 6), (7, 7), (7, 1), (3, 2), (6, 1), (1, 7), (4, 5), (5, 4)]

print("Compiling...")
circuit = add.compile(inputset)

examples = [(3, 4), (1, 2), (7, 7), (0, 0)]
for example in examples:
    result = circuit.encrypt_run_decrypt(*example)
    print(f"Evaluation of {' + '.join(map(str, example))} homomorphically = {result}")
```

## Documentation

Full, comprehensive documentation is available at [https://docs.zama.ai/concrete-numpy](https://docs.zama.ai/concrete-numpy).

## Target users

Concrete Numpy is a generic library that supports a variety of use cases. Because of this flexibility, it doesn't provide primitives for specific use cases.

If you have a specific use case, or a specific field of computation, you may want to build abstractions on top of Concrete Numpy.

One such example is [Concrete ML](https://github.com/zama-ai/concrete-ml), which is built on top of Concrete Numpy to simplify Machine Learning oriented use cases.

## Tutorials

Various tutorials are proposed in the documentation to help you start writing homomorphic programs:

- How to use Concrete Numpy with [Decorators](https://docs.zama.ai/concrete-numpy/tutorials/decorator)
- Partial support of [Floating Points](https://docs.zama.ai/concrete-numpy/tutorials/floating_points)
- How to perform [Table Lookup](https://docs.zama.ai/concrete-numpy/tutorials/table_lookup)

More generally, if you have built awesome projects using Concrete Numpy, feel free to let us know and we'll link to them!

## Need support?

<a target="_blank" href="https://community.zama.ai">
  <img src="https://user-images.githubusercontent.com/5758427/191792238-b132e413-05f9-4fee-bee3-1371f3d81c28.png">
</a>

## License

This software is distributed under the BSD-3-Clause-Clear license. If you have any questions, please contact us at hello@zama.ai.
7
frontends/concrete-python/concrete/__init__.py
Normal file
@@ -0,0 +1,7 @@
"""
Set up the `concrete` module to be enlarged with the `numpy` module.
"""

# Do not modify, this is to have a compatible namespace package
# https://packaging.python.org/en/latest/guides/packaging-namespace-packages/#pkg-resources-style-namespace-packages
__import__("pkg_resources").declare_namespace(__name__)  # pragma: no cover
166
frontends/concrete-python/concrete/numpy/__init__.py
Normal file
@@ -0,0 +1,166 @@
"""
Export everything that users might need.
"""

from concrete.compiler import EvaluationKeys, PublicArguments, PublicResult

from .compilation import (
    DEFAULT_GLOBAL_P_ERROR,
    DEFAULT_P_ERROR,
    Circuit,
    Client,
    ClientSpecs,
    Compiler,
    Configuration,
    DebugArtifacts,
    EncryptionStatus,
    Server,
)
from .compilation.decorators import circuit, compiler
from .extensions import (
    AutoRounder,
    LookupTable,
    array,
    one,
    ones,
    round_bit_pattern,
    tag,
    univariate,
    zero,
    zeros,
)
from .mlir.utils import MAXIMUM_TLU_BIT_WIDTH
from .representation import Graph
from .tracing.typing import (
    f32,
    f64,
    int1,
    int2,
    int3,
    int4,
    int5,
    int6,
    int7,
    int8,
    int9,
    int10,
    int11,
    int12,
    int13,
    int14,
    int15,
    int16,
    int17,
    int18,
    int19,
    int20,
    int21,
    int22,
    int23,
    int24,
    int25,
    int26,
    int27,
    int28,
    int29,
    int30,
    int31,
    int32,
    int33,
    int34,
    int35,
    int36,
    int37,
    int38,
    int39,
    int40,
    int41,
    int42,
    int43,
    int44,
    int45,
    int46,
    int47,
    int48,
    int49,
    int50,
    int51,
    int52,
    int53,
    int54,
    int55,
    int56,
    int57,
    int58,
    int59,
    int60,
    int61,
    int62,
    int63,
    int64,
    tensor,
    uint1,
    uint2,
    uint3,
    uint4,
    uint5,
    uint6,
    uint7,
    uint8,
    uint9,
    uint10,
    uint11,
    uint12,
    uint13,
    uint14,
    uint15,
    uint16,
    uint17,
    uint18,
    uint19,
    uint20,
    uint21,
    uint22,
    uint23,
    uint24,
    uint25,
    uint26,
    uint27,
    uint28,
    uint29,
    uint30,
    uint31,
    uint32,
    uint33,
    uint34,
    uint35,
    uint36,
    uint37,
    uint38,
    uint39,
    uint40,
    uint41,
    uint42,
    uint43,
    uint44,
    uint45,
    uint46,
    uint47,
    uint48,
    uint49,
    uint50,
    uint51,
    uint52,
    uint53,
    uint54,
    uint55,
    uint56,
    uint57,
    uint58,
    uint59,
    uint60,
    uint61,
    uint62,
    uint63,
    uint64,
)
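These typing exports (`int1` through `uint64`, `tensor`, `f32`, `f64`) are what power direct circuit definition, where bit widths come from annotations instead of an inputset. A minimal hedged sketch, assuming the exported `circuit` decorator accepts the same encryption-status mapping as `compiler`:

```python
import concrete.numpy as cnp

# hedged sketch: the annotation fixes the bit width, so no inputset is needed
@cnp.circuit({"x": "encrypted"})
def f(x: cnp.uint3):
    return x + 1

assert f.encrypt_run_decrypt(3) == 4
```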
11
frontends/concrete-python/concrete/numpy/compilation/__init__.py
Normal file
@@ -0,0 +1,11 @@
"""
Glue the compilation process together.
"""

from .artifacts import DebugArtifacts
from .circuit import Circuit
from .client import Client
from .compiler import Compiler, EncryptionStatus
from .configuration import DEFAULT_GLOBAL_P_ERROR, DEFAULT_P_ERROR, Configuration
from .server import Server
from .specs import ClientSpecs
189
frontends/concrete-python/concrete/numpy/compilation/artifacts.py
Normal file
@@ -0,0 +1,189 @@
"""
Declaration of `DebugArtifacts` class.
"""

import inspect
import platform
import shutil
import subprocess
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union

from ..representation import Graph

DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts")


class DebugArtifacts:
    """
    DebugArtifacts class, to export information about the compilation process.
    """

    output_directory: Path

    source_code: Optional[str]
    parameter_encryption_statuses: Dict[str, str]
    textual_representations_of_graphs: Dict[str, List[str]]
    final_graph: Optional[Graph]
    mlir_to_compile: Optional[str]
    client_parameters: Optional[bytes]

    def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY):
        self.output_directory = Path(output_directory)

        self.source_code = None
        self.parameter_encryption_statuses = {}
        self.textual_representations_of_graphs = {}
        self.final_graph = None
        self.mlir_to_compile = None
        self.client_parameters = None

    def add_source_code(self, function: Union[str, Callable]):
        """
        Add source code of the function being compiled.

        Args:
            function (Union[str, Callable]):
                either the source code of the function or the function itself
        """

        try:
            self.source_code = (
                function if isinstance(function, str) else inspect.getsource(function)
            )
        except OSError:  # pragma: no cover
            self.source_code = "unavailable"

    def add_parameter_encryption_status(self, name: str, encryption_status: str):
        """
        Add encryption status of a parameter of the function being compiled.

        Args:
            name (str):
                name of the parameter

            encryption_status (str):
                encryption status of the parameter
        """

        self.parameter_encryption_statuses[name] = encryption_status

    def add_graph(self, name: str, graph: Graph):
        """
        Add a representation of the function being compiled.

        Args:
            name (str):
                name of the graph (e.g., initial, optimized, final)

            graph (Graph):
                a representation of the function being compiled
        """

        if name not in self.textual_representations_of_graphs:
            self.textual_representations_of_graphs[name] = []

        textual_representation = graph.format()
        self.textual_representations_of_graphs[name].append(textual_representation)

        self.final_graph = graph

    def add_mlir_to_compile(self, mlir: str):
        """
        Add textual representation of the resulting MLIR.

        Args:
            mlir (str):
                textual representation of the resulting MLIR
        """

        self.mlir_to_compile = mlir

    def add_client_parameters(self, client_parameters: bytes):
        """
        Add client parameters used.

        Args:
            client_parameters (bytes): client parameters
        """

        self.client_parameters = client_parameters

    def export(self):
        """
        Export the collected information to `self.output_directory`.
        """

        # pylint: disable=too-many-branches

        output_directory = self.output_directory
        if output_directory.exists():
            shutil.rmtree(output_directory)
        output_directory.mkdir(parents=True)

        with open(output_directory.joinpath("environment.txt"), "w", encoding="utf-8") as f:
            f.write(f"{platform.platform()} {platform.version()}\n")
            f.write(f"Python {platform.python_version()}\n")

        with open(output_directory.joinpath("requirements.txt"), "w", encoding="utf-8") as f:
            # example `pip list` output

            # Package                       Version
            # ----------------------------- ---------
            # alabaster                     0.7.12
            # appdirs                       1.4.4
            # ...                           ...
            # ...                           ...
            # wrapt                         1.12.1
            # zipp                          3.5.0

            pip_process = subprocess.run(
                ["pip", "--disable-pip-version-check", "list"], stdout=subprocess.PIPE, check=True
            )
            dependencies = iter(pip_process.stdout.decode("utf-8").split("\n"))

            # skip 'Package ... Version' line
            next(dependencies)

            # skip '------- ... -------' line
            next(dependencies)

            for dependency in dependencies:
                tokens = [token for token in dependency.split(" ") if token != ""]  # noqa: S105
                if len(tokens) == 0:
                    continue

                name = tokens[0]
                version = tokens[1]

                f.write(f"{name}=={version}\n")

        if self.source_code is not None:
            with open(output_directory.joinpath("function.txt"), "w", encoding="utf-8") as f:
                f.write(self.source_code)

        if len(self.parameter_encryption_statuses) > 0:
            with open(output_directory.joinpath("parameters.txt"), "w", encoding="utf-8") as f:
                for name, parameter in self.parameter_encryption_statuses.items():
                    f.write(f"{name} :: {parameter}\n")

        identifier = 0

        textual_representations = self.textual_representations_of_graphs.items()
        for name, representations in textual_representations:
            for representation in representations:
                identifier += 1
                output_path = output_directory.joinpath(f"{identifier}.{name}.graph.txt")
                with open(output_path, "w", encoding="utf-8") as f:
                    f.write(f"{representation}\n")

        if self.mlir_to_compile is not None:
            assert self.final_graph is not None
            with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f:
                f.write(f"{self.mlir_to_compile}\n")

        if self.client_parameters is not None:
            with open(output_directory.joinpath("client_parameters.json"), "wb") as f:
                f.write(self.client_parameters)

        # pylint: enable=too-many-branches
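For orientation, a hedged sketch of how `DebugArtifacts` is meant to be driven by the `Compiler` defined later in this diff (the function and inputset are illustrative):

```python
from concrete.numpy.compilation import Compiler, DebugArtifacts

def double(x):
    return x + x

compiler = Compiler(double, {"x": "encrypted"})
artifacts = DebugArtifacts(".debug-artifacts")

# compilation fills the artifacts (source code, graphs, MLIR, client parameters)...
circuit = compiler.compile(inputset=range(8), artifacts=artifacts)

# ...and `export` writes environment.txt, requirements.txt, *.graph.txt, mlir.txt, etc.
artifacts.export()
```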
218
frontends/concrete-python/concrete/numpy/compilation/circuit.py
Normal file
@@ -0,0 +1,218 @@
"""
Declaration of `Circuit` class.
"""

from typing import Any, Optional, Tuple, Union, cast

import numpy as np
from concrete.compiler import PublicArguments, PublicResult

from ..dtypes import Integer
from ..internal.utils import assert_that
from ..representation import Graph
from .client import Client
from .configuration import Configuration
from .server import Server


class Circuit:
    """
    Circuit class, to combine computation graph, MLIR, client, and server into a single object.
    """

    configuration: Configuration

    graph: Graph
    mlir: str

    client: Client
    server: Server

    def __init__(self, graph: Graph, mlir: str, configuration: Optional[Configuration] = None):
        self.configuration = configuration if configuration is not None else Configuration()

        self.graph = graph
        self.mlir = mlir

        self._initialize_client_and_server()

    def _initialize_client_and_server(self):
        input_signs = []
        for i in range(len(self.graph.input_nodes)):  # pylint: disable=consider-using-enumerate
            input_value = self.graph.input_nodes[i].output
            assert_that(isinstance(input_value.dtype, Integer))
            input_dtype = cast(Integer, input_value.dtype)
            input_signs.append(input_dtype.is_signed)

        output_signs = []
        for i in range(len(self.graph.output_nodes)):  # pylint: disable=consider-using-enumerate
            output_value = self.graph.output_nodes[i].output
            assert_that(isinstance(output_value.dtype, Integer))
            output_dtype = cast(Integer, output_value.dtype)
            output_signs.append(output_dtype.is_signed)

        self.server = Server.create(self.mlir, input_signs, output_signs, self.configuration)

        keyset_cache_directory = None
        if self.configuration.use_insecure_key_cache:
            assert_that(self.configuration.enable_unsafe_features)
            assert_that(self.configuration.insecure_key_cache_location is not None)
            keyset_cache_directory = self.configuration.insecure_key_cache_location

        self.client = Client(self.server.client_specs, keyset_cache_directory)

    def __str__(self):
        return self.graph.format()

    def simulate(self, *args: Any) -> Any:
        """
        Simulate execution of the circuit.

        Args:
            *args (Any):
                inputs to the circuit

        Returns:
            Any:
                result of the simulation
        """

        return self.graph(*args, p_error=self.p_error)

    def keygen(self, force: bool = False):
        """
        Generate keys required for homomorphic evaluation.

        Args:
            force (bool, default = False):
                whether to generate new keys even if keys are already generated
        """

        self.client.keygen(force)

    def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
        """
        Prepare inputs to be run on the circuit.

        Args:
            *args (Union[int, numpy.ndarray]):
                inputs to the circuit

        Returns:
            PublicArguments:
                encrypted and plain arguments as well as public keys
        """

        return self.client.encrypt(*args)

    def run(self, args: PublicArguments) -> PublicResult:
        """
        Evaluate the circuit using encrypted arguments.

        Args:
            args (PublicArguments):
                arguments to the circuit (can be obtained with the `encrypt` method of `Circuit`)

        Returns:
            PublicResult:
                encrypted result of homomorphic evaluation
        """

        self.keygen(force=False)
        return self.server.run(args, self.client.evaluation_keys)

    def decrypt(
        self,
        result: PublicResult,
    ) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
        """
        Decrypt the result of homomorphic evaluation.

        Args:
            result (PublicResult):
                encrypted result of homomorphic evaluation

        Returns:
            Union[int, numpy.ndarray]:
                clear result of homomorphic evaluation
        """

        return self.client.decrypt(result)

    def encrypt_run_decrypt(self, *args: Any) -> Any:
        """
        Encrypt inputs, run the circuit, and decrypt the outputs in one go.

        Args:
            *args (Union[int, numpy.ndarray]):
                inputs to the circuit

        Returns:
            Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
                clear result of homomorphic evaluation
        """

        return self.decrypt(self.run(self.encrypt(*args)))

    def cleanup(self):
        """
        Clean up the temporary library output directory.
        """

        self.server.cleanup()

    @property
    def complexity(self) -> float:
        """
        Get complexity of the circuit.
        """
        return self.server.complexity

    @property
    def size_of_secret_keys(self) -> int:
        """
        Get size of the secret keys of the circuit.
        """
        return self.server.size_of_secret_keys

    @property
    def size_of_bootstrap_keys(self) -> int:
        """
        Get size of the bootstrap keys of the circuit.
        """
        return self.server.size_of_bootstrap_keys

    @property
    def size_of_keyswitch_keys(self) -> int:
        """
        Get size of the key switch keys of the circuit.
        """
        return self.server.size_of_keyswitch_keys

    @property
    def size_of_inputs(self) -> int:
        """
        Get size of the inputs of the circuit.
        """
        return self.server.size_of_inputs

    @property
    def size_of_outputs(self) -> int:
        """
        Get size of the outputs of the circuit.
        """
        return self.server.size_of_outputs

    @property
    def p_error(self) -> float:
        """
        Get the probability of error for each simple TLU (on a scalar).
        """
        return self.server.p_error

    @property
    def global_p_error(self) -> float:
        """
        Get the probability of having at least one simple TLU error during the entire execution.
        """
        return self.server.global_p_error
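A short sketch of the inspection surface `Circuit` exposes; everything used here appears in the class above, and the function and inputset are illustrative:

```python
import concrete.numpy as cnp

@cnp.compiler({"x": "encrypted", "y": "encrypted"})
def add(x, y):
    return x + y

circuit = add.compile([(2, 3), (0, 0), (7, 7), (4, 5)])

# inspecting the compiled circuit without touching any encryption
print(circuit)                 # textual computation graph (graph.format())
print(circuit.complexity)      # cost estimate reported by the server
print(circuit.p_error)         # per-TLU probability of error
print(circuit.simulate(3, 4))  # evaluate on the graph with the circuit's p_error
```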
271
frontends/concrete-python/concrete/numpy/compilation/client.py
Normal file
@@ -0,0 +1,271 @@
"""
Declaration of `Client` class.
"""

import json
import shutil
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from concrete.compiler import (
    ClientSupport,
    EvaluationKeys,
    KeySet,
    KeySetCache,
    PublicArguments,
    PublicResult,
)

from ..dtypes.integer import SignedInteger, UnsignedInteger
from ..internal.utils import assert_that
from ..values.value import Value
from .specs import ClientSpecs


class Client:
    """
    Client class, which can be used to manage keys, encrypt arguments, and decrypt results.
    """

    specs: ClientSpecs

    _keyset: Optional[KeySet]
    _keyset_cache: Optional[KeySetCache]

    def __init__(
        self,
        client_specs: ClientSpecs,
        keyset_cache_directory: Optional[Union[str, Path]] = None,
    ):
        self.specs = client_specs

        self._keyset = None
        self._keyset_cache = None

        if keyset_cache_directory is not None:
            self._keyset_cache = KeySetCache.new(str(keyset_cache_directory))

    def save(self, path: Union[str, Path]):
        """
        Save the client into the given path in zip format.

        Args:
            path (Union[str, Path]):
                path to save the client
        """

        with tempfile.TemporaryDirectory() as tmp_dir:
            with open(Path(tmp_dir) / "client.specs.json", "w", encoding="utf-8") as f:
                f.write(self.specs.serialize())

            path = str(path)
            if path.endswith(".zip"):
                path = path[: len(path) - 4]

            shutil.make_archive(path, "zip", tmp_dir)

    @staticmethod
    def load(
        path: Union[str, Path],
        keyset_cache_directory: Optional[Union[str, Path]] = None,
    ) -> "Client":
        """
        Load the client from the given path in zip format.

        Args:
            path (Union[str, Path]):
                path to load the client from

            keyset_cache_directory (Optional[Union[str, Path]], default = None):
                keyset cache directory to use

        Returns:
            Client:
                client loaded from the filesystem
        """

        with tempfile.TemporaryDirectory() as tmp_dir:
            shutil.unpack_archive(path, tmp_dir, "zip")
            with open(Path(tmp_dir) / "client.specs.json", "r", encoding="utf-8") as f:
                client_specs = ClientSpecs.unserialize(f.read())

        return Client(client_specs, keyset_cache_directory)

    def keygen(self, force: bool = False):
        """
        Generate keys required for homomorphic evaluation.

        Args:
            force (bool, default = False):
                whether to generate new keys even if keys are already generated
        """

        if self._keyset is None or force:
            self._keyset = ClientSupport.key_set(self.specs.client_parameters, self._keyset_cache)

    def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
        """
        Prepare inputs to be run on the circuit.

        Args:
            *args (Union[int, numpy.ndarray]):
                inputs to the circuit

        Returns:
            PublicArguments:
                encrypted and plain arguments as well as public keys
        """

        client_parameters_json = json.loads(self.specs.client_parameters.serialize())
        assert_that("inputs" in client_parameters_json)
        input_specs = client_parameters_json["inputs"]

        if len(args) != len(input_specs):
            message = f"Expected {len(input_specs)} inputs but got {len(args)}"
            raise ValueError(message)

        sanitized_args: Dict[int, Union[int, np.ndarray]] = {}
        for index, spec in enumerate(input_specs):
            arg = args[index]
            if isinstance(arg, list):
                arg = np.array(arg)

            is_valid = isinstance(arg, (int, np.integer)) or (
                isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer)
            )

            width = spec["shape"]["width"]
            shape = tuple(spec["shape"]["dimensions"])
            is_encrypted = spec["encryption"] is not None

            expected_dtype = (
                SignedInteger(width) if self.specs.input_signs[index] else UnsignedInteger(width)
            )
            expected_value = Value(expected_dtype, shape, is_encrypted)
            if is_valid:
                expected_min = expected_dtype.min()
                expected_max = expected_dtype.max()

                actual_min = arg if isinstance(arg, int) else arg.min()
                actual_max = arg if isinstance(arg, int) else arg.max()
                actual_shape = () if isinstance(arg, int) else arg.shape

                is_valid = (
                    actual_min >= expected_min
                    and actual_max <= expected_max
                    and actual_shape == expected_value.shape
                )

                if is_valid:
                    is_signed = self.specs.input_signs[index]
                    sanitizer = 0 if not is_signed else 2 ** (width - 1)

                    if isinstance(arg, int):
                        sanitized_args[index] = arg + sanitizer
                    else:
                        sanitized_args[index] = (arg + sanitizer).astype(np.uint64)

            if not is_valid:
                actual_value = Value.of(arg, is_encrypted=is_encrypted)
                message = (
                    f"Expected argument {index} to be {expected_value} but it's {actual_value}"
                )
                raise ValueError(message)

        self.keygen(force=False)
        return ClientSupport.encrypt_arguments(
            self.specs.client_parameters,
            self._keyset,
            [sanitized_args[i] for i in range(len(sanitized_args))],
        )

    def decrypt(
        self,
        result: PublicResult,
    ) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
        """
        Decrypt the result of homomorphic evaluation.

        Args:
            result (PublicResult):
                encrypted result of homomorphic evaluation

        Returns:
            Union[int, numpy.ndarray]:
                clear result of homomorphic evaluation
        """

        self.keygen(force=False)
        outputs = ClientSupport.decrypt_result(self.specs.client_parameters, self._keyset, result)
        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        sanitized_outputs: List[Union[int, np.ndarray]] = []

        client_parameters_json = json.loads(self.specs.client_parameters.serialize())
        assert_that("outputs" in client_parameters_json)
        output_specs = client_parameters_json["outputs"]

        for index, output in enumerate(outputs):
            is_signed = self.specs.output_signs[index]
            crt_decomposition = (
                output_specs[index].get("encryption", {}).get("encoding", {}).get("crt", [])
            )

            if is_signed:
                if crt_decomposition:
                    if isinstance(output, int):
                        sanitized_output = (
                            output
                            if output < (int(np.prod(crt_decomposition)) // 2)
                            else -int(np.prod(crt_decomposition)) + output
                        )
                    else:
                        output = output.astype(np.longlong)  # to prevent overflows in numpy
                        sanitized_output = np.where(
                            output < (np.prod(crt_decomposition) // 2),
                            output,
                            -np.prod(crt_decomposition) + output,
                        ).astype(
                            np.int64
                        )  # type: ignore

                    sanitized_outputs.append(sanitized_output)

                else:
                    n = output_specs[index]["shape"]["width"]
                    output %= 2**n
                    if isinstance(output, int):
                        sanitized_output = output if output < (2 ** (n - 1)) else output - (2**n)
                        sanitized_outputs.append(sanitized_output)
                    else:
                        output = output.astype(np.longlong)  # to prevent overflows in numpy
                        sanitized_output = np.where(
                            output < (2 ** (n - 1)), output, output - (2**n)
                        ).astype(
                            np.int64
                        )  # type: ignore
                        sanitized_outputs.append(sanitized_output)
            else:
                sanitized_outputs.append(
                    output if isinstance(output, int) else output.astype(np.uint64)
                )

        return sanitized_outputs[0] if len(sanitized_outputs) == 1 else tuple(sanitized_outputs)

    @property
    def evaluation_keys(self) -> EvaluationKeys:
        """
        Get evaluation keys for encrypted computation.

        Returns:
            EvaluationKeys:
                evaluation keys for encrypted computation
        """

        self.keygen(force=False)

        assert self._keyset is not None
        return self._keyset.get_evaluation_keys()
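Putting `Client` together with the `Server` held by `Circuit`, a hedged sketch of a split deployment (the paths and inputset are illustrative):

```python
import concrete.numpy as cnp
from concrete.numpy.compilation import Client

@cnp.compiler({"x": "encrypted", "y": "encrypted"})
def add(x, y):
    return x + y

circuit = add.compile([(2, 3), (0, 0), (7, 7), (4, 5)])

# development machine: ship the client specs alongside the server
circuit.client.save("client.zip")

# user machine: restore the client, generate keys, encrypt
client = Client.load("client.zip", keyset_cache_directory=".keys")
client.keygen()
args = client.encrypt(3, 4)

# server: compute on encrypted data using only the evaluation keys
result = circuit.server.run(args, client.evaluation_keys)

# user machine: decrypt the encrypted result
print(client.decrypt(result))
```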
550
frontends/concrete-python/concrete/numpy/compilation/compiler.py
Normal file
@@ -0,0 +1,550 @@
|
||||
"""
|
||||
Declaration of `Compiler` class.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from enum import Enum, unique
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..extensions import AutoRounder
|
||||
from ..mlir import GraphConverter
|
||||
from ..representation import Graph
|
||||
from ..tracing import Tracer
|
||||
from ..values import Value
|
||||
from .artifacts import DebugArtifacts
|
||||
from .circuit import Circuit
|
||||
from .configuration import Configuration
|
||||
from .utils import fuse
|
||||
|
||||
|
||||
@unique
|
||||
class EncryptionStatus(str, Enum):
|
||||
"""
|
||||
EncryptionStatus enum, to represent encryption status of parameters.
|
||||
"""
|
||||
|
||||
CLEAR = "clear"
|
||||
ENCRYPTED = "encrypted"
|
||||
|
||||
|
||||
class Compiler:
|
||||
"""
|
||||
Compiler class, to glue the compilation pipeline.
|
||||
"""
|
||||
|
||||
function: Callable
|
||||
parameter_encryption_statuses: Dict[str, EncryptionStatus]
|
||||
|
||||
configuration: Configuration
|
||||
artifacts: Optional[DebugArtifacts]
|
||||
|
||||
inputset: List[Any]
|
||||
graph: Optional[Graph]
|
||||
|
||||
_is_direct: bool
|
||||
_parameter_values: Dict[str, Value]
|
||||
|
||||
@staticmethod
|
||||
def assemble(
|
||||
function: Callable,
|
||||
parameter_values: Dict[str, Value],
|
||||
configuration: Optional[Configuration] = None,
|
||||
artifacts: Optional[DebugArtifacts] = None,
|
||||
**kwargs,
|
||||
) -> Circuit:
|
||||
"""
|
||||
Assemble a circuit from the raw parameter values, used in direct circuit definition.
|
||||
|
||||
Args:
|
||||
function (Callable):
|
||||
function to convert to a circuit
|
||||
|
||||
parameter_values (Dict[str, Value]):
|
||||
parameter values of the function
|
||||
|
||||
configuration(Optional[Configuration], default = None):
|
||||
configuration to use
|
||||
|
||||
artifacts (Optional[DebugArtifacts], default = None):
|
||||
artifacts to store information about the process
|
||||
|
||||
kwargs (Dict[str, Any]):
|
||||
configuration options to overwrite
|
||||
|
||||
Returns:
|
||||
Circuit:
|
||||
assembled circuit
|
||||
"""
|
||||
|
||||
compiler = Compiler(
|
||||
function,
|
||||
{
|
||||
name: "encrypted" if value.is_encrypted else "clear"
|
||||
for name, value in parameter_values.items()
|
||||
},
|
||||
)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
compiler._is_direct = True
|
||||
compiler._parameter_values = parameter_values
|
||||
# pylint: enable=protected-access
|
||||
|
||||
return compiler.compile(None, configuration, artifacts, **kwargs)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
function: Callable,
|
||||
parameter_encryption_statuses: Dict[str, Union[str, EncryptionStatus]],
|
||||
):
|
||||
signature = inspect.signature(function)
|
||||
|
||||
missing_args = list(signature.parameters)
|
||||
for arg in parameter_encryption_statuses.keys():
|
||||
if arg in signature.parameters:
|
||||
missing_args.remove(arg)
|
||||
|
||||
if len(missing_args) != 0:
|
||||
parameter_str = repr(missing_args[0])
|
||||
for arg in missing_args[1:-1]:
|
||||
parameter_str += f", {repr(arg)}"
|
||||
if len(missing_args) != 1:
|
||||
parameter_str += f" and {repr(missing_args[-1])}"
|
||||
|
||||
message = (
|
||||
f"Encryption status{'es' if len(missing_args) > 1 else ''} "
|
||||
f"of parameter{'s' if len(missing_args) > 1 else ''} "
|
||||
f"{parameter_str} of function '{function.__name__}' "
|
||||
f"{'are' if len(missing_args) > 1 else 'is'} not provided"
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
additional_args = list(parameter_encryption_statuses)
|
||||
for arg in signature.parameters.keys():
|
||||
if arg in parameter_encryption_statuses:
|
||||
additional_args.remove(arg)
|
||||
|
||||
if len(additional_args) != 0:
|
||||
parameter_str = repr(additional_args[0])
|
||||
for arg in additional_args[1:-1]:
|
||||
parameter_str += f", {repr(arg)}"
|
||||
if len(additional_args) != 1:
|
||||
parameter_str += f" and {repr(additional_args[-1])}"
|
||||
|
||||
message = (
|
||||
f"Encryption status{'es' if len(additional_args) > 1 else ''} "
|
||||
f"of {parameter_str} {'are' if len(additional_args) > 1 else 'is'} provided but "
|
||||
f"{'they are' if len(additional_args) > 1 else 'it is'} not a parameter "
|
||||
f"of function '{function.__name__}'"
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
self.function = function # type: ignore
|
||||
self.parameter_encryption_statuses = {
|
||||
param: EncryptionStatus(status.lower())
|
||||
for param, status in parameter_encryption_statuses.items()
|
||||
}
|
||||
|
||||
self.configuration = Configuration()
|
||||
self.artifacts = None
|
||||
|
||||
self.inputset = []
|
||||
self.graph = None
|
||||
|
||||
self._is_direct = False
|
||||
self._parameter_values = {}
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Union[
|
||||
np.bool_,
|
||||
np.integer,
|
||||
np.floating,
|
||||
np.ndarray,
|
||||
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
|
||||
]:
|
||||
if len(kwargs) != 0:
|
||||
message = f"Calling function '{self.function.__name__}' with kwargs is not supported"
|
||||
raise RuntimeError(message)
|
||||
|
||||
sample = args[0] if len(args) == 1 else args
|
||||
|
||||
if self.graph is None:
|
||||
self._trace(sample)
|
||||
assert self.graph is not None
|
||||
|
||||
self.inputset.append(sample)
|
||||
        return self.graph(*args)

    def _trace(self, sample: Union[Any, Tuple[Any, ...]]):
        """
        Trace the function with a sample input and fuse the resulting graph.

        Args:
            sample (Union[Any, Tuple[Any, ...]]):
                sample to use for tracing
        """

        if self.artifacts is not None:
            self.artifacts.add_source_code(self.function)
            for param, encryption_status in self.parameter_encryption_statuses.items():
                self.artifacts.add_parameter_encryption_status(param, encryption_status)

        parameters = {
            param: Value.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED))
            for arg, (param, status) in zip(
                sample if len(self.parameter_encryption_statuses) > 1 else (sample,),
                self.parameter_encryption_statuses.items(),
            )
        }

        self.graph = Tracer.trace(self.function, parameters)
        if self.artifacts is not None:
            self.artifacts.add_graph("initial", self.graph)

        fuse(self.graph, self.artifacts)

    def _evaluate(
        self,
        action: str,
        inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]],
    ):
        """
        Trace, fuse, measure bounds, and update values in the resulting graph in one go.

        Args:
            action (str):
                action being performed (e.g., "trace", "compile")

            inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                optional inputset to extend accumulated inputset before bounds measurement
        """

        if self._is_direct:

            self.graph = Tracer.trace(self.function, self._parameter_values, is_direct=True)
            if self.artifacts is not None:
                self.artifacts.add_graph("initial", self.graph)  # pragma: no cover

            fuse(self.graph, self.artifacts)
            if self.artifacts is not None:
                self.artifacts.add_graph("final", self.graph)  # pragma: no cover

            return

        if inputset is not None:
            previous_inputset_length = len(self.inputset)
            for index, sample in enumerate(iter(inputset)):
                self.inputset.append(sample)

                if not isinstance(sample, tuple):
                    sample = (sample,)

                if len(sample) != len(self.parameter_encryption_statuses):
                    self.inputset = self.inputset[:previous_inputset_length]

                    expected = (
                        "a single value"
                        if len(self.parameter_encryption_statuses) == 1
                        else f"a tuple of {len(self.parameter_encryption_statuses)} values"
                    )
                    actual = (
                        "a single value" if len(sample) == 1 else f"a tuple of {len(sample)} values"
                    )

                    message = (
                        f"Input #{index} of your inputset is not well formed "
                        f"(expected {expected} got {actual})"
                    )
                    raise ValueError(message)

        if self.configuration.auto_adjust_rounders:
            AutoRounder.adjust(self.function, self.inputset)

        if self.graph is None:
            try:
                first_sample = next(iter(self.inputset))
            except StopIteration as error:
                message = (
                    f"{action} function '{self.function.__name__}' "
                    f"without an inputset is not supported"
                )
                raise RuntimeError(message) from error

            self._trace(first_sample)
            assert self.graph is not None

        bounds = self.graph.measure_bounds(self.inputset)
        self.graph.update_with_bounds(bounds)

        if self.artifacts is not None:
            self.artifacts.add_graph("final", self.graph)

    def trace(
        self,
        inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
        configuration: Optional[Configuration] = None,
        artifacts: Optional[DebugArtifacts] = None,
        **kwargs,
    ) -> Graph:
        """
        Trace the function using an inputset.

        Args:
            inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                optional inputset to extend accumulated inputset before bounds measurement

            configuration (Optional[Configuration], default = None):
                configuration to use

            artifacts (Optional[DebugArtifacts], default = None):
                artifacts to store information about the process

            kwargs (Dict[str, Any]):
                configuration options to overwrite

        Returns:
            Graph:
                computation graph representing the function prior to MLIR conversion
        """

        old_configuration = deepcopy(self.configuration)
        old_artifacts = deepcopy(self.artifacts)

        if configuration is not None:
            self.configuration = configuration

        if len(kwargs) != 0:
            self.configuration = self.configuration.fork(**kwargs)

        self.artifacts = (
            artifacts
            if artifacts is not None
            else DebugArtifacts()
            if self.configuration.dump_artifacts_on_unexpected_failures
            else None
        )

        try:

            self._evaluate("Tracing", inputset)
            assert self.graph is not None

            if self.configuration.verbose or self.configuration.show_graph:
                graph = self.graph.format()
                longest_line = max([len(line) for line in graph.split("\n")])

                try:  # pragma: no cover

                    # this branch cannot be covered
                    # because `os.get_terminal_size()`
                    # raises an exception during tests

                    columns, _ = os.get_terminal_size()
                    if columns == 0:
                        columns = min(longest_line, 80)
                    else:
                        columns = min(longest_line, columns)
                except OSError:  # pragma: no cover
                    columns = min(longest_line, 80)

                print()

                print("Computation Graph")
                print("-" * columns)
                print(graph)
                print("-" * columns)

                print()

            return self.graph

        except Exception:  # pragma: no cover

            # this branch is reserved for unexpected issues and hence it shouldn't be tested
            # if it could be tested, we would have fixed the underlying issue

            # if the user desires so,
            # we need to export all the information we have about the compilation

            if self.configuration.dump_artifacts_on_unexpected_failures:
                assert self.artifacts is not None
                self.artifacts.export()

                traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
                with open(traceback_path, "w", encoding="utf-8") as f:
                    f.write(traceback.format_exc())

            raise

        finally:

            self.configuration = old_configuration
            self.artifacts = old_artifacts

    # pylint: disable=too-many-branches,too-many-statements

    def compile(
        self,
        inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
        configuration: Optional[Configuration] = None,
        artifacts: Optional[DebugArtifacts] = None,
        **kwargs,
    ) -> Circuit:
        """
        Compile the function using an inputset.

        Args:
            inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                optional inputset to extend accumulated inputset before bounds measurement

            configuration (Optional[Configuration], default = None):
                configuration to use

            artifacts (Optional[DebugArtifacts], default = None):
                artifacts to store information about the process

            kwargs (Dict[str, Any]):
                configuration options to overwrite

        Returns:
            Circuit:
                compiled circuit
        """

        old_configuration = deepcopy(self.configuration)
        old_artifacts = deepcopy(self.artifacts)

        if configuration is not None:
            self.configuration = configuration

        if len(kwargs) != 0:
            self.configuration = self.configuration.fork(**kwargs)

        self.artifacts = (
            artifacts
            if artifacts is not None
            else DebugArtifacts()
            if self.configuration.dump_artifacts_on_unexpected_failures
            else None
        )

        try:
            self._evaluate("Compiling", inputset)
            assert self.graph is not None

            mlir = GraphConverter.convert(self.graph)
            if self.artifacts is not None:
                self.artifacts.add_mlir_to_compile(mlir)

            show_graph = (
                self.configuration.show_graph
                if self.configuration.show_graph is not None
                else self.configuration.verbose
            )
            show_mlir = (
                self.configuration.show_mlir
                if self.configuration.show_mlir is not None
                else self.configuration.verbose
            )
            show_optimizer = (
                self.configuration.show_optimizer
                if self.configuration.show_optimizer is not None
                else self.configuration.verbose
            )

            columns = 0
            if show_graph or show_mlir or show_optimizer:

                graph = (
                    self.graph.format()
                    if self.configuration.verbose or self.configuration.show_graph
                    else ""
                )

                longest_graph_line = max([len(line) for line in graph.split("\n")])
                longest_mlir_line = max([len(line) for line in mlir.split("\n")])
                longest_line = max(longest_graph_line, longest_mlir_line)

                try:  # pragma: no cover

                    # this branch cannot be covered
                    # because `os.get_terminal_size()`
                    # raises an exception during tests

                    columns, _ = os.get_terminal_size()
                    if columns == 0:
                        columns = min(longest_line, 80)
                    else:
                        columns = min(longest_line, columns)
                except OSError:  # pragma: no cover
                    columns = min(longest_line, 80)

                if show_graph:
                    print()

                    print("Computation Graph")
                    print("-" * columns)
                    print(graph)
                    print("-" * columns)

                    print()

                if show_mlir:
                    print("\n" if not show_graph else "", end="")

                    print("MLIR")
                    print("-" * columns)
                    print(mlir)
                    print("-" * columns)

                    print()

                if show_optimizer:
                    print("\n" if not (show_graph or show_mlir) else "", end="")

                    print("Optimizer")
                    print("-" * columns)

            circuit = Circuit(self.graph, mlir, self.configuration)

            client_parameters = circuit.client.specs.client_parameters
            if self.artifacts is not None:
                self.artifacts.add_client_parameters(client_parameters.serialize())

            if show_optimizer:
                print("-" * columns)
                print()

            return circuit

        except Exception:  # pragma: no cover

            # this branch is reserved for unexpected issues and hence it shouldn't be tested
            # if it could be tested, we would have fixed the underlying issue

            # if the user desires so,
            # we need to export all the information we have about the compilation

            if self.configuration.dump_artifacts_on_unexpected_failures:
                assert self.artifacts is not None
                self.artifacts.export()

                traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
                with open(traceback_path, "w", encoding="utf-8") as f:
                    f.write(traceback.format_exc())

            raise

        finally:

            self.configuration = old_configuration
            self.artifacts = old_artifacts

    # pylint: enable=too-many-branches,too-many-statements
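
A minimal sketch of how the `trace` and `compile` methods above might be driven end to end; the function `f`, the `cnp` alias, and the inputset are illustrative, assuming `Compiler` is constructed from a function and its parameter encryption statuses as in the `compiler` decorator later in this diff:

import concrete.numpy as cnp

def f(x):
    return x + 42

compiler = cnp.Compiler(f, {"x": "encrypted"})
graph = compiler.trace(inputset=range(10))      # computation graph, prior to MLIR conversion
circuit = compiler.compile(inputset=range(10))  # compiled homomorphic circuit

Keyword arguments such as `show_mlir=True` fork the configuration for the duration of the call and restore it afterwards, as the `finally` blocks above guarantee.
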
@@ -0,0 +1,162 @@
"""
Declaration of `Configuration` class.
"""

from copy import deepcopy
from pathlib import Path
from typing import Optional, Union, get_type_hints

DEFAULT_P_ERROR = None
DEFAULT_GLOBAL_P_ERROR = 1 / 100_000


class Configuration:
    """
    Configuration class, to allow the compilation process to be customized.
    """

    # pylint: disable=too-many-instance-attributes

    verbose: bool
    show_graph: Optional[bool]
    show_mlir: Optional[bool]
    show_optimizer: Optional[bool]
    dump_artifacts_on_unexpected_failures: bool
    enable_unsafe_features: bool
    use_insecure_key_cache: bool
    loop_parallelize: bool
    dataflow_parallelize: bool
    auto_parallelize: bool
    jit: bool
    p_error: Optional[float]
    global_p_error: Optional[float]
    insecure_key_cache_location: Optional[str]
    auto_adjust_rounders: bool

    # pylint: enable=too-many-instance-attributes

    def _validate(self):
        """
        Validate configuration.
        """

        if not self.enable_unsafe_features:

            if self.use_insecure_key_cache:
                message = "Insecure key cache cannot be used without enabling unsafe features"
                raise RuntimeError(message)

        if self.use_insecure_key_cache and self.insecure_key_cache_location is None:
            message = "Insecure key cache cannot be enabled without specifying its location"
            raise RuntimeError(message)

    # pylint: disable=too-many-arguments

    def __init__(
        self,
        verbose: bool = False,
        show_graph: Optional[bool] = None,
        show_mlir: Optional[bool] = None,
        show_optimizer: Optional[bool] = None,
        dump_artifacts_on_unexpected_failures: bool = True,
        enable_unsafe_features: bool = False,
        use_insecure_key_cache: bool = False,
        insecure_key_cache_location: Optional[Union[Path, str]] = None,
        loop_parallelize: bool = True,
        dataflow_parallelize: bool = True,
        auto_parallelize: bool = False,
        jit: bool = False,
        p_error: Optional[float] = None,
        global_p_error: Optional[float] = None,
        auto_adjust_rounders: bool = False,
    ):
        self.verbose = verbose
        self.show_graph = show_graph
        self.show_mlir = show_mlir
        self.show_optimizer = show_optimizer
        self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures
        self.enable_unsafe_features = enable_unsafe_features
        self.use_insecure_key_cache = use_insecure_key_cache
        self.insecure_key_cache_location = (
            str(insecure_key_cache_location) if insecure_key_cache_location is not None else None
        )
        self.loop_parallelize = loop_parallelize
        self.dataflow_parallelize = dataflow_parallelize
        self.auto_parallelize = auto_parallelize
        self.jit = jit
        self.p_error = p_error
        self.global_p_error = global_p_error
        self.auto_adjust_rounders = auto_adjust_rounders

        self._validate()

    # pylint: enable=too-many-arguments

    def fork(self, **kwargs) -> "Configuration":
        """
        Get a new configuration from another one with specified changes.

        Args:
            **kwargs:
                changes to make

        Returns:
            Configuration:
                configuration that is forked from self and updated using kwargs
        """

        # pylint: disable=too-many-branches

        result = deepcopy(self)

        hints = get_type_hints(Configuration)
        for name, value in kwargs.items():
            if name not in hints:
                message = f"Unexpected keyword argument '{name}'"
                raise TypeError(message)

            hint = hints[name]
            expected = None
            is_correctly_typed = True

            if name == "insecure_key_cache_location":
                if not (value is None or isinstance(value, str)):
                    is_correctly_typed = False
                    expected = "Optional[str]"

            elif name == "p_error":
                if not (value is None or isinstance(value, float)):
                    is_correctly_typed = False
                    expected = "Optional[float]"

            elif name == "global_p_error":
                if not (value is None or isinstance(value, float)):
                    is_correctly_typed = False
                    expected = "Optional[float]"

            elif name in ["show_graph", "show_mlir", "show_optimizer"]:
                if not (value is None or isinstance(value, bool)):
                    is_correctly_typed = False
                    expected = "Optional[bool]"

            elif not isinstance(value, hint):  # type: ignore
                is_correctly_typed = False

            if not is_correctly_typed:
                if expected is None:
                    expected = hint.__name__ if hasattr(hint, "__name__") else str(hint)
                message = (
                    f"Unexpected type for keyword argument '{name}' "
                    f"(expected '{expected}', got '{type(value).__name__}')"
                )
                raise TypeError(message)

            setattr(result, name, value)

        # pylint: disable=protected-access
        result._validate()
        # pylint: enable=protected-access

        return result

        # pylint: enable=too-many-branches
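
A minimal sketch of how `fork` behaves, assuming the class above (the values are illustrative):

config = Configuration(show_graph=True)
forked = config.fork(show_mlir=True, p_error=0.001)  # `config` itself is left untouched

# both the name and the type of each change are checked against the type hints above:
# config.fork(p_errror=0.001)  raises TypeError: Unexpected keyword argument 'p_errror'
# config.fork(p_error="high")  raises TypeError (expected 'Optional[float]', got 'str')
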
@@ -0,0 +1,163 @@
"""
Declaration of `circuit` and `compiler` decorators.
"""

import inspect
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union

from ..representation import Graph
from ..tracing.typing import ScalarAnnotation
from ..values import Value
from .artifacts import DebugArtifacts
from .circuit import Circuit
from .compiler import Compiler, EncryptionStatus
from .configuration import Configuration


def circuit(
    parameters: Mapping[str, Union[str, EncryptionStatus]],
    configuration: Optional[Configuration] = None,
    artifacts: Optional[DebugArtifacts] = None,
    **kwargs,
):
    """
    Provide a direct interface for compilation.

    Args:
        parameters (Mapping[str, Union[str, EncryptionStatus]]):
            encryption statuses of the parameters of the function to compile

        configuration (Optional[Configuration], default = None):
            configuration to use

        artifacts (Optional[DebugArtifacts], default = None):
            artifacts to store information about the process

        kwargs (Dict[str, Any]):
            configuration options to overwrite
    """

    def decoration(function: Callable):
        signature = inspect.signature(function)

        parameter_values: Dict[str, Value] = {}
        for name, details in signature.parameters.items():
            if name not in parameters:
                continue

            annotation = details.annotation

            is_value = isinstance(annotation, Value)
            is_scalar_annotation = isinstance(annotation, type) and issubclass(
                annotation, ScalarAnnotation
            )

            if not (is_value or is_scalar_annotation):
                message = (
                    f"Annotation {annotation} for argument '{name}' is not valid "
                    f"(please use a cnp type such as "
                    f"`cnp.uint4` or 'cnp.tensor[cnp.uint4, 3, 2]')"
                )
                raise ValueError(message)

            parameter_values[name] = (
                annotation if is_value else Value(annotation.dtype, shape=(), is_encrypted=False)
            )

            status = EncryptionStatus(parameters[name].lower())
            parameter_values[name].is_encrypted = status == "encrypted"

        return Compiler.assemble(function, parameter_values, configuration, artifacts, **kwargs)

    return decoration


def compiler(parameters: Mapping[str, Union[str, EncryptionStatus]]):
    """
    Provide an easy interface for compilation.

    Args:
        parameters (Mapping[str, Union[str, EncryptionStatus]]):
            encryption statuses of the parameters of the function to compile
    """

    def decoration(function: Callable):
        class Compilable:
            """
            Compilable class, to wrap a function and provide methods to trace and compile it.
            """

            function: Callable
            compiler: Compiler

            def __init__(self, function: Callable):
                self.function = function  # type: ignore
                self.compiler = Compiler(self.function, dict(parameters))

            def __call__(self, *args, **kwargs) -> Any:
                self.compiler(*args, **kwargs)
                return self.function(*args, **kwargs)

            def trace(
                self,
                inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
                configuration: Optional[Configuration] = None,
                artifacts: Optional[DebugArtifacts] = None,
                **kwargs,
            ) -> Graph:
                """
                Trace the function into a computation graph.

                Args:
                    inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                        optional inputset to extend accumulated inputset before bounds measurement

                    configuration (Optional[Configuration], default = None):
                        configuration to use

                    artifacts (Optional[DebugArtifacts], default = None):
                        artifacts to store information about the process

                    kwargs (Dict[str, Any]):
                        configuration options to overwrite

                Returns:
                    Graph:
                        computation graph representing the function prior to MLIR conversion
                """

                return self.compiler.trace(inputset, configuration, artifacts, **kwargs)

            def compile(
                self,
                inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
                configuration: Optional[Configuration] = None,
                artifacts: Optional[DebugArtifacts] = None,
                **kwargs,
            ) -> Circuit:
                """
                Compile the function into a circuit.

                Args:
                    inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                        optional inputset to extend accumulated inputset before bounds measurement

                    configuration (Optional[Configuration], default = None):
                        configuration to use

                    artifacts (Optional[DebugArtifacts], default = None):
                        artifacts to store information about the process

                    kwargs (Dict[str, Any]):
                        configuration options to overwrite

                Returns:
                    Circuit:
                        compiled circuit
                """

                return self.compiler.compile(inputset, configuration, artifacts, **kwargs)

        return Compilable(function)

    return decoration
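
A minimal sketch of the two decorators above; the `cnp` alias and the function bodies are illustrative:

import concrete.numpy as cnp

# `compiler` wraps the function into a `Compilable`, deferring compilation to an inputset:
@cnp.compiler({"x": "encrypted"})
def f(x):
    return x ** 2

circuit = f.compile(inputset=range(16))

# `circuit` compiles directly from scalar annotations instead, with no inputset:
@cnp.circuit({"x": "encrypted"})
def g(x: cnp.uint4):
    return x + 1
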
350
frontends/concrete-python/concrete/numpy/compilation/server.py
Normal file
@@ -0,0 +1,350 @@
"""
Declaration of `Server` class.
"""

import json
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional, Union

import concrete.compiler
from concrete.compiler import (
    CompilationFeedback,
    CompilationOptions,
    EvaluationKeys,
    JITCompilationResult,
    JITLambda,
    JITSupport,
    LibraryCompilationResult,
    LibraryLambda,
    LibrarySupport,
    PublicArguments,
    PublicResult,
)

from ..internal.utils import assert_that
from .configuration import DEFAULT_GLOBAL_P_ERROR, DEFAULT_P_ERROR, Configuration
from .specs import ClientSpecs


class Server:
    """
    Server class, which can be used to perform homomorphic computation.
    """

    client_specs: ClientSpecs

    _output_dir: Optional[tempfile.TemporaryDirectory]
    _support: Union[JITSupport, LibrarySupport]
    _compilation_result: Union[JITCompilationResult, LibraryCompilationResult]
    _compilation_feedback: CompilationFeedback
    _server_lambda: Union[JITLambda, LibraryLambda]

    _mlir: Optional[str]
    _configuration: Optional[Configuration]

    def __init__(
        self,
        client_specs: ClientSpecs,
        output_dir: Optional[tempfile.TemporaryDirectory],
        support: Union[JITSupport, LibrarySupport],
        compilation_result: Union[JITCompilationResult, LibraryCompilationResult],
        server_lambda: Union[JITLambda, LibraryLambda],
    ):
        self.client_specs = client_specs

        self._output_dir = output_dir
        self._support = support
        self._compilation_result = compilation_result
        self._compilation_feedback = self._support.load_compilation_feedback(compilation_result)
        self._server_lambda = server_lambda
        self._mlir = None

        assert_that(
            support.load_client_parameters(compilation_result).serialize()
            == client_specs.client_parameters.serialize()
        )

    @staticmethod
    def create(
        mlir: str,
        input_signs: List[bool],
        output_signs: List[bool],
        configuration: Configuration,
    ) -> "Server":
        """
        Create a server using MLIR and input/output sign information.

        Args:
            mlir (str):
                mlir to compile

            input_signs (List[bool]):
                sign status of the inputs

            output_signs (List[bool]):
                sign status of the outputs

            configuration (Configuration):
                configuration to use
        """

        options = CompilationOptions.new("main")

        options.set_loop_parallelize(configuration.loop_parallelize)
        options.set_dataflow_parallelize(configuration.dataflow_parallelize)
        options.set_auto_parallelize(configuration.auto_parallelize)

        if configuration.auto_parallelize or configuration.dataflow_parallelize:
            concrete.compiler.init_dfr()

        global_p_error_is_set = configuration.global_p_error is not None
        p_error_is_set = configuration.p_error is not None

        if global_p_error_is_set and p_error_is_set:  # pragma: no cover
            options.set_global_p_error(configuration.global_p_error)
            options.set_p_error(configuration.p_error)

        elif global_p_error_is_set:  # pragma: no cover
            options.set_global_p_error(configuration.global_p_error)
            options.set_p_error(1.0)

        elif p_error_is_set:  # pragma: no cover
            options.set_global_p_error(1.0)
            options.set_p_error(configuration.p_error)

        else:  # pragma: no cover
            if DEFAULT_GLOBAL_P_ERROR is not None:
                options.set_global_p_error(DEFAULT_GLOBAL_P_ERROR)
            else:
                options.set_global_p_error(1.0)

            if DEFAULT_P_ERROR is not None:
                options.set_p_error(DEFAULT_P_ERROR)
            else:
                options.set_p_error(1.0)

        show_optimizer = (
            configuration.show_optimizer
            if configuration.show_optimizer is not None
            else configuration.verbose
        )
        options.set_display_optimizer_choice(show_optimizer)

        if configuration.jit:

            output_dir = None

            support = JITSupport.new()
            compilation_result = support.compile(mlir, options)
            server_lambda = support.load_server_lambda(compilation_result)

        else:

            # pylint: disable=consider-using-with
            output_dir = tempfile.TemporaryDirectory()
            output_dir_path = Path(output_dir.name)
            # pylint: enable=consider-using-with

            support = LibrarySupport.new(
                str(output_dir_path), generateCppHeader=False, generateStaticLib=False
            )
            compilation_result = support.compile(mlir, options)
            server_lambda = support.load_server_lambda(compilation_result)

        client_parameters = support.load_client_parameters(compilation_result)
        client_specs = ClientSpecs(input_signs, client_parameters, output_signs)

        result = Server(client_specs, output_dir, support, compilation_result, server_lambda)

        # pylint: disable=protected-access
        result._mlir = mlir
        result._configuration = configuration
        # pylint: enable=protected-access

        return result

    def save(self, path: Union[str, Path], via_mlir: bool = False):
        """
        Save the server into the given path in zip format.

        Args:
            path (Union[str, Path]):
                path to save the server

            via_mlir (bool, default = False):
                export using the MLIR code of the program,
                this will make the export cross-platform
        """

        path = str(path)
        if path.endswith(".zip"):
            path = path[: len(path) - 4]

        if via_mlir:
            if self._mlir is None or self._configuration is None:
                message = "Loaded server objects cannot be saved again via MLIR"
                raise RuntimeError(message)

            with tempfile.TemporaryDirectory() as tmp:

                with open(Path(tmp) / "circuit.mlir", "w", encoding="utf-8") as f:
                    f.write(self._mlir)

                with open(Path(tmp) / "input_signs.json", "w", encoding="utf-8") as f:
                    f.write(json.dumps(self.client_specs.input_signs))

                with open(Path(tmp) / "output_signs.json", "w", encoding="utf-8") as f:
                    f.write(json.dumps(self.client_specs.output_signs))

                with open(Path(tmp) / "configuration.json", "w", encoding="utf-8") as f:
                    f.write(json.dumps(self._configuration.__dict__))

                shutil.make_archive(path, "zip", tmp)

            return

        if self._output_dir is None:
            message = "Just-in-Time compilation cannot be saved"
            raise RuntimeError(message)

        with open(Path(self._output_dir.name) / "client.specs.json", "w", encoding="utf-8") as f:
            f.write(self.client_specs.serialize())

        shutil.make_archive(path, "zip", self._output_dir.name)

    @staticmethod
    def load(path: Union[str, Path]) -> "Server":
        """
        Load the server from the given path in zip format.

        Args:
            path (Union[str, Path]):
                path to load the server from

        Returns:
            Server:
                server loaded from the filesystem
        """

        # pylint: disable=consider-using-with
        output_dir = tempfile.TemporaryDirectory()
        output_dir_path = Path(output_dir.name)
        # pylint: enable=consider-using-with

        shutil.unpack_archive(path, str(output_dir_path), "zip")

        if (output_dir_path / "circuit.mlir").exists():
            with open(output_dir_path / "circuit.mlir", "r", encoding="utf-8") as f:
                mlir = f.read()

            with open(output_dir_path / "input_signs.json", "r", encoding="utf-8") as f:
                input_signs = json.load(f)
                assert_that(isinstance(input_signs, list))
                assert_that(all(isinstance(sign, bool) for sign in input_signs))

            with open(output_dir_path / "output_signs.json", "r", encoding="utf-8") as f:
                output_signs = json.load(f)
                assert_that(isinstance(output_signs, list))
                assert_that(all(isinstance(sign, bool) for sign in output_signs))

            with open(output_dir_path / "configuration.json", "r", encoding="utf-8") as f:
                configuration = Configuration().fork(**json.load(f))

            return Server.create(mlir, input_signs, output_signs, configuration)

        with open(output_dir_path / "client.specs.json", "r", encoding="utf-8") as f:
            client_specs = ClientSpecs.unserialize(f.read())

        support = LibrarySupport.new(
            str(output_dir_path),
            generateCppHeader=False,
            generateStaticLib=False,
        )
        compilation_result = support.reload("main")
        server_lambda = support.load_server_lambda(compilation_result)

        return Server(client_specs, output_dir, support, compilation_result, server_lambda)

    def run(self, args: PublicArguments, evaluation_keys: EvaluationKeys) -> PublicResult:
        """
        Evaluate using encrypted arguments.

        Args:
            args (PublicArguments):
                encrypted arguments of the computation

            evaluation_keys (EvaluationKeys):
                evaluation keys for encrypted computation

        Returns:
            PublicResult:
                encrypted result of the computation
        """

        return self._support.server_call(self._server_lambda, args, evaluation_keys)

    def cleanup(self):
        """
        Clean up the temporary library output directory.
        """

        if self._output_dir is not None:
            self._output_dir.cleanup()

    @property
    def complexity(self) -> float:
        """
        Get complexity of the compiled program.
        """
        return self._compilation_feedback.complexity

    @property
    def size_of_secret_keys(self) -> int:
        """
        Get size of the secret keys of the compiled program.
        """
        return self._compilation_feedback.total_secret_keys_size

    @property
    def size_of_bootstrap_keys(self) -> int:
        """
        Get size of the bootstrap keys of the compiled program.
        """
        return self._compilation_feedback.total_bootstrap_keys_size

    @property
    def size_of_keyswitch_keys(self) -> int:
        """
        Get size of the key switch keys of the compiled program.
        """
        return self._compilation_feedback.total_keyswitch_keys_size

    @property
    def size_of_inputs(self) -> int:
        """
        Get size of the inputs of the compiled program.
        """
        return self._compilation_feedback.total_inputs_size

    @property
    def size_of_outputs(self) -> int:
        """
        Get size of the outputs of the compiled program.
        """
        return self._compilation_feedback.total_output_size

    @property
    def p_error(self) -> float:
        """
        Get the probability of error for each simple TLU (on a scalar).
        """
        return self._compilation_feedback.p_error

    @property
    def global_p_error(self) -> float:
        """
        Get the probability of having at least one simple TLU error during the entire execution.
        """
        return self._compilation_feedback.global_p_error
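
A minimal sketch of the `Server` lifecycle above; `mlir`, `input_signs`, `output_signs`, and `config` are assumed to come out of a prior compilation, as in `compiler.py`:

server = Server.create(mlir, input_signs, output_signs, config)

server.save("server.zip", via_mlir=True)  # cross-platform export; recompiled on load
loaded = Server.load("server.zip")

# encrypted evaluation; `args` and `evaluation_keys` are produced on the client side
# result = loaded.run(args, evaluation_keys)

loaded.cleanup()
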
127
frontends/concrete-python/concrete/numpy/compilation/specs.py
Normal file
@@ -0,0 +1,127 @@
"""
Declaration of `ClientSpecs` class.
"""

import json
from typing import List

from concrete.compiler import ClientParameters, PublicArguments, PublicResult


class ClientSpecs:
    """
    ClientSpecs class, to create Client objects.
    """

    input_signs: List[bool]
    client_parameters: ClientParameters
    output_signs: List[bool]

    def __init__(
        self,
        input_signs: List[bool],
        client_parameters: ClientParameters,
        output_signs: List[bool],
    ):
        self.input_signs = input_signs
        self.client_parameters = client_parameters
        self.output_signs = output_signs

    def serialize(self) -> str:
        """
        Serialize client specs into a string representation.

        Returns:
            str:
                string representation of the client specs
        """

        client_parameters_json = json.loads(self.client_parameters.serialize())
        return json.dumps(
            {
                "input_signs": self.input_signs,
                "client_parameters": client_parameters_json,
                "output_signs": self.output_signs,
            }
        )

    @staticmethod
    def unserialize(serialized_client_specs: str) -> "ClientSpecs":
        """
        Create client specs from its string representation.

        Args:
            serialized_client_specs (str):
                client specs to unserialize

        Returns:
            ClientSpecs:
                unserialized client specs
        """

        raw_specs = json.loads(serialized_client_specs)

        client_parameters_bytes = json.dumps(raw_specs["client_parameters"]).encode("utf-8")
        client_parameters = ClientParameters.unserialize(client_parameters_bytes)

        return ClientSpecs(raw_specs["input_signs"], client_parameters, raw_specs["output_signs"])

    def serialize_public_args(self, args: PublicArguments) -> bytes:  # pylint: disable=no-self-use
        """
        Serialize public arguments to bytes.

        Args:
            args (PublicArguments):
                public arguments to serialize

        Returns:
            bytes:
                serialized public arguments
        """

        return args.serialize()

    def unserialize_public_args(self, serialized_args: bytes) -> PublicArguments:
        """
        Unserialize public arguments from bytes.

        Args:
            serialized_args (bytes):
                serialized public arguments

        Returns:
            PublicArguments:
                unserialized public arguments
        """

        return PublicArguments.unserialize(self.client_parameters, serialized_args)

    def serialize_public_result(self, result: PublicResult) -> bytes:  # pylint: disable=no-self-use
        """
        Serialize public result to bytes.

        Args:
            result (PublicResult):
                public result to serialize

        Returns:
            bytes:
                serialized public result
        """

        return result.serialize()

    def unserialize_public_result(self, serialized_result: bytes) -> PublicResult:
        """
        Unserialize public result from bytes.

        Args:
            serialized_result (bytes):
                serialized public result

        Returns:
            PublicResult:
                unserialized public result
        """

        return PublicResult.unserialize(self.client_parameters, serialized_result)
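
A minimal round-trip sketch, assuming a `client_specs` instance produced during compilation:

serialized = client_specs.serialize()  # JSON string carrying signs and client parameters
restored = ClientSpecs.unserialize(serialized)

assert restored.input_signs == client_specs.input_signs
assert restored.output_signs == client_specs.output_signs
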
685
frontends/concrete-python/concrete/numpy/compilation/utils.py
Normal file
@@ -0,0 +1,685 @@
|
||||
"""
|
||||
Declaration of various functions and constants related to compilation.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..dtypes import Float, Integer
|
||||
from ..representation import Graph, Node, Operation
|
||||
from .artifacts import DebugArtifacts
|
||||
|
||||
# ruff: noqa: ERA001
|
||||
|
||||
|
||||
def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
|
||||
"""
|
||||
Fuse appropriate subgraphs in a graph to a single Operation.Generic node.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
graph to search and update
|
||||
|
||||
artifacts (Optional[DebugArtifacts], default = None):
|
||||
compilation artifacts to store information about the fusing process
|
||||
|
||||
Raises:
|
||||
RuntimeError:
|
||||
if there is a subgraph which needs to be fused cannot be fused
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
processed_terminal_nodes: Set[Node] = set()
|
||||
|
||||
fusing_floats = True
|
||||
while True:
|
||||
subgraph_to_fuse = (
|
||||
find_float_subgraph_with_unique_terminal_node(
|
||||
graph,
|
||||
processed_terminal_nodes,
|
||||
)
|
||||
if fusing_floats
|
||||
else find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_ancestor(
|
||||
graph,
|
||||
processed_terminal_nodes,
|
||||
)
|
||||
)
|
||||
|
||||
if subgraph_to_fuse is None:
|
||||
if fusing_floats:
|
||||
fusing_floats = False
|
||||
processed_terminal_nodes.clear()
|
||||
continue
|
||||
break
|
||||
|
||||
all_nodes, start_nodes, terminal_node = subgraph_to_fuse
|
||||
processed_terminal_nodes.add(terminal_node)
|
||||
|
||||
fused_node, node_before_subgraph = convert_subgraph_to_subgraph_node(
|
||||
graph,
|
||||
all_nodes,
|
||||
start_nodes,
|
||||
terminal_node,
|
||||
)
|
||||
nx_graph.add_node(fused_node)
|
||||
|
||||
if terminal_node in graph.output_nodes.values():
|
||||
output_node_to_idx: Dict[Node, List[int]] = {
|
||||
out_node: [] for out_node in graph.output_nodes.values()
|
||||
}
|
||||
for output_idx, output_node in graph.output_nodes.items():
|
||||
output_node_to_idx[output_node].append(output_idx)
|
||||
|
||||
for output_idx in output_node_to_idx.get(terminal_node, []):
|
||||
graph.output_nodes[output_idx] = fused_node
|
||||
|
||||
terminal_node_succ = list(nx_graph.successors(terminal_node))
|
||||
for succ in terminal_node_succ:
|
||||
succ_edge_data = deepcopy(nx_graph.get_edge_data(terminal_node, succ))
|
||||
for edge_key, edge_data in succ_edge_data.items():
|
||||
nx_graph.remove_edge(terminal_node, succ, key=edge_key)
|
||||
new_edge_data = deepcopy(edge_data)
|
||||
nx_graph.add_edge(fused_node, succ, key=edge_key, **new_edge_data)
|
||||
|
||||
nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0)
|
||||
|
||||
graph.prune_useless_nodes()
|
||||
if artifacts is not None:
|
||||
artifacts.add_graph("after-fusing", graph)
|
||||
|
||||
|
||||
def find_float_subgraph_with_unique_terminal_node(
|
||||
graph: Graph,
|
||||
processed_terminal_nodes: Set[Node],
|
||||
) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
|
||||
"""
|
||||
Find a subgraph with float computations that end with an integer output.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
graph to search
|
||||
|
||||
processed_terminal_nodes (Set[Node]):
|
||||
set of terminal nodes which have already been searched for float subgraphs
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
|
||||
None if there are no such subgraphs,
|
||||
tuple containing all nodes in the subgraph, start nodes of the subgraph,
|
||||
and terminal node of the subgraph otherwise
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
terminal_nodes = (
|
||||
node
|
||||
for node in nx_graph.nodes()
|
||||
if (
|
||||
node not in processed_terminal_nodes
|
||||
and any(isinstance(input.dtype, Float) for input in node.inputs)
|
||||
and isinstance(node.output.dtype, Integer)
|
||||
)
|
||||
)
|
||||
try:
|
||||
terminal_node = next(terminal_nodes)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
all_nodes: Dict[Node, None] = {}
|
||||
|
||||
start_single_int_output_nodes_search_from = terminal_node
|
||||
while True:
|
||||
all_nodes, start_nodes = find_closest_integer_output_nodes(
|
||||
graph,
|
||||
[start_single_int_output_nodes_search_from],
|
||||
all_nodes,
|
||||
)
|
||||
|
||||
variable_start_nodes = [
|
||||
start_node for start_node in start_nodes if start_node.operation != Operation.Constant
|
||||
]
|
||||
if len(variable_start_nodes) == 1:
|
||||
break
|
||||
|
||||
# find a common ancestor as we need a single variable input node
|
||||
# lca == lowest common ancestor
|
||||
lca = find_single_lca(graph, variable_start_nodes)
|
||||
|
||||
# if subgraph cannot be fused because there is no way to find a common ancestor, break
|
||||
if lca is None:
|
||||
break
|
||||
|
||||
# add the nodes from the `start_nodes` to `lca`, to `all_nodes`
|
||||
all_nodes = add_nodes_from_to(graph, start_nodes, {lca: None}, all_nodes)
|
||||
|
||||
# if `lca` is a valid starting node for fusing break
|
||||
if isinstance(lca.output.dtype, Integer):
|
||||
# `lca` is the new start node
|
||||
start_nodes = {lca: None}
|
||||
break
|
||||
|
||||
# otherwise, push a little further
|
||||
# (e.g., if there is a node just before, which has an integer output)
|
||||
start_single_int_output_nodes_search_from = lca
|
||||
|
||||
return all_nodes, start_nodes, terminal_node
|
||||
|
||||
|
||||
def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_ancestor(
|
||||
graph: Graph,
|
||||
processed_terminal_nodes: Set[Node],
|
||||
) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
|
||||
"""
|
||||
Find a subgraph with a tlu computation that has multiple variable inputs \
|
||||
where all variable inputs share a common ancestor.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
graph to search
|
||||
|
||||
processed_terminal_nodes (Set[Node]):
|
||||
set of terminal nodes which have already been searched for tlu subgraphs
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
|
||||
None if there are no such subgraphs,
|
||||
tuple containing all nodes in the subgraph, start nodes of the subgraph,
|
||||
and terminal node of the subgraph otherwise
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
terminal_nodes = (
|
||||
node
|
||||
for node in nx_graph.nodes()
|
||||
if (
|
||||
node not in processed_terminal_nodes
|
||||
and node.converted_to_table_lookup
|
||||
and all(isinstance(input.dtype, Integer) for input in node.inputs)
|
||||
and isinstance(node.output.dtype, Integer)
|
||||
and len(
|
||||
[
|
||||
pred
|
||||
for pred in nx_graph.predecessors(node)
|
||||
if pred.operation != Operation.Constant
|
||||
]
|
||||
)
|
||||
> 1
|
||||
)
|
||||
)
|
||||
try:
|
||||
terminal_node = next(terminal_nodes)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
all_nodes: Dict[Node, None] = {}
|
||||
|
||||
while True:
|
||||
variable_start_nodes = list(nx_graph.predecessors(terminal_node))
|
||||
|
||||
# find a common ancestor as we need a single variable input node
|
||||
# lca == lowest common ancestor
|
||||
lca = find_single_lca(graph, variable_start_nodes)
|
||||
|
||||
# if subgraph cannot be fused because there is no way to find a common ancestor, break
|
||||
if lca is None:
|
||||
start_nodes = {node: None for node in variable_start_nodes}
|
||||
all_nodes = {node: None for node in variable_start_nodes + [terminal_node]}
|
||||
break
|
||||
|
||||
# add the nodes from the `start_nodes` to `lca`, to `all_nodes`
|
||||
all_nodes = add_nodes_from_to(
|
||||
graph,
|
||||
list(nx_graph.predecessors(terminal_node)),
|
||||
{lca: None},
|
||||
all_nodes,
|
||||
)
|
||||
all_nodes[terminal_node] = None
|
||||
|
||||
# if `lca` is a valid starting node for fusing break
|
||||
if isinstance(lca.output.dtype, Integer):
|
||||
# `lca` is the new start node
|
||||
start_nodes = {lca: None}
|
||||
break
|
||||
|
||||
return all_nodes, start_nodes, terminal_node
|
||||
|
||||
|
||||
def find_single_lca(graph: Graph, nodes: List[Node]) -> Optional[Node]:
|
||||
"""
|
||||
Find the single lowest common ancestor of a list of nodes.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
graph to search for single lca
|
||||
|
||||
nodes (List[Node]):
|
||||
nodes to find the single lca of
|
||||
|
||||
Returns
|
||||
Optional[Node]:
|
||||
single lca if it exists, None otherwise
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
|
||||
# find all ancestors of `nodes`
|
||||
# nodes themselves need to be in this set because the single lca can be within `nodes`
|
||||
all_ancestors = [set(list(nx.ancestors(nx_graph, node)) + [node]) for node in nodes]
|
||||
|
||||
# find common ancestors among `nodes`
|
||||
# if the single lca exists, it's in this set
|
||||
common_ancestors = {
|
||||
node
|
||||
for node in nx_graph.nodes()
|
||||
if node.operation != Operation.Constant
|
||||
and all(node in ancestors for ancestors in all_ancestors)
|
||||
}
|
||||
|
||||
# iterate over every node in the graph reversed topological order
|
||||
# this is to ensure result, if found, is the single "lowest" common ancestor
|
||||
for candidate in reversed(list(nx.topological_sort(nx_graph))):
|
||||
# check if node is a common ancestor of all `nodes`
|
||||
if candidate not in common_ancestors:
|
||||
# if not, it cannot be the single lca
|
||||
continue
|
||||
|
||||
# check if node is a single common ancestor of `nodes`
|
||||
if is_single_common_ancestor(graph, candidate, nodes):
|
||||
# if so, it's the single lca of `nodes`
|
||||
# so return it
|
||||
return candidate
|
||||
|
||||
# if none of the nodes in `common_ancestors` is the single lca
|
||||
# there is no single lca of this set of nodes, so return None
|
||||
return None
|
||||
|
||||
|
||||
def is_single_common_ancestor(
|
||||
graph: Graph,
|
||||
candidate: Node,
|
||||
nodes: List[Node],
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if a node is the single common ancestor of a list of nodes.
|
||||
|
||||
Note that this function doesn't care about `lowest` property of `lca`.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
graph to perform the check
|
||||
|
||||
candidate (Node):
|
||||
node to determine single common ancestor status
|
||||
|
||||
nodes (List[Node]):
|
||||
nodes to determine single common ancestor status against
|
||||
|
||||
Returns
|
||||
bool:
|
||||
True if `candidate` is a single common ancestor of `nodes`, False otherwise
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
|
||||
# create a subgraph with `candidate` node
|
||||
subgraph = nx.DiGraph()
|
||||
subgraph.add_node(candidate)
|
||||
|
||||
# iterate over `nodes` to add them to the subgraph
|
||||
# along with every path from `candidate` to them
|
||||
for node in nodes:
|
||||
subgraph.add_node(node)
|
||||
for path in nx.all_simple_paths(nx_graph, source=candidate, target=node):
|
||||
nx.add_path(subgraph, path)
|
||||
|
||||
# iterate over the nodes of the subgraph
|
||||
for node in subgraph.nodes():
|
||||
# the condition below doesn't apply to `candidate`
|
||||
# as its predecessors are not in the subgraph
|
||||
if node == candidate:
|
||||
continue
|
||||
|
||||
# find number of predecessors in the subgraph and in the original graph
|
||||
# except constant nodes in the original graph as
|
||||
# - they are not in the subgraph
|
||||
# - they don't affect fusability status
|
||||
predecessor_count_in_subgraph = len(list(subgraph.predecessors(node)))
|
||||
predecessor_count_in_nx_graph = len(
|
||||
[pred for pred in nx_graph.predecessors(node) if pred.operation != Operation.Constant]
|
||||
)
|
||||
|
||||
# see if number of predecessors are different
|
||||
if predecessor_count_in_subgraph != predecessor_count_in_nx_graph:
|
||||
# if so, `candidate` cannot be a single common ancestor
|
||||
# reasoning for is explained below
|
||||
return False
|
||||
|
||||
# if every node in the subgraph has the same number of predecessors
|
||||
# as in the original graph `candidate` is in fact a single common ancestor
|
||||
return True
|
||||
|
||||
# Here is why this function works.
|
||||
#
|
||||
# Legend:
|
||||
# - /|\- = Edge
|
||||
# - (...) = Intermediate Node
|
||||
# - {...} = Candidate Node
|
||||
# - [...] = Node of which single common ancestor is searched
|
||||
# - {[...]} = Both Candidate Node and Node of which single common ancestor is searched
|
||||
#
|
||||
# Consider the folowing graph:
|
||||
#
|
||||
# (3) (x) (2)
|
||||
# \ / \ /
|
||||
# [{*}] (/)
|
||||
# \ /
|
||||
# [+]
|
||||
#
|
||||
# - Operation: (x * 3) + (x / 2)
|
||||
# - Candidate: {*}
|
||||
# - Nodes: [*] and [+]
|
||||
#
|
||||
# So we want to know if multiplication node is a single common ancestor of
|
||||
# multiplication and addition nodes. The result is no in this case for our purposes.
|
||||
#
|
||||
# Once you apply the subgraph creation above, you'll get the following graph:
|
||||
#
|
||||
# (*)
|
||||
# |
|
||||
# (+)
|
||||
#
|
||||
# In this subgraph, addition node only have a single predecessor,
|
||||
# which means there is path leading to the addition node and that path doesn't include
|
||||
# the multiplication node, so we conclude multiplication node is not a single common ancestor
|
||||
#
|
||||
# Now, consider the folowing graph:
|
||||
#
|
||||
# (3) {x} (2)
|
||||
# \ / \ /
|
||||
# [*] (/)
|
||||
# \ /
|
||||
# [+]
|
||||
#
|
||||
# - Operation: (x * 3) + (x / 2)
|
||||
# - Candidate: {x}
|
||||
# - Nodes: [*] and [+]
|
||||
#
|
||||
# So we want to know if the input node 'x' is the single common ancestor of
|
||||
# multiplication and addition nodes. The result is yes in this case.
|
||||
#
|
||||
# Once you apply the subgraph creation above, you'll get the following graph:
|
||||
#
|
||||
# {x}
|
||||
# / \
|
||||
# [*] (/)
|
||||
# \ /
|
||||
# [+]
|
||||
#
|
||||
# In this subgraph, every node except the candidate node
|
||||
# will keep all of their non-constant predecessors,
|
||||
# which means all of their non-constant predecessors originated
|
||||
# from the `candidate`, so it's a single common anscestor.
|
||||
#
|
||||
# When you think about it, this implementation makes a lot of sense for our purposes
|
||||
# It basically determines if `nodes` "solely" depend on the `candidate`,
|
||||
# which is the condition for fusing.
|
||||
|
||||
|
||||
def find_closest_integer_output_nodes(
|
||||
graph: Graph,
|
||||
start_nodes: List[Node],
|
||||
all_nodes: Dict[Node, None],
|
||||
) -> Tuple[Dict[Node, None], Dict[Node, None]]:
|
||||
"""
|
||||
Find the closest upstream integer output nodes to a set of start nodes in a graph.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
graph to search
|
||||
|
||||
start_nodes (List[Node]):
|
||||
nodes from which to start the search
|
||||
|
||||
all_nodes (Dict[Node, None]):
|
||||
set of nodes to be extended with visited nodes during the search
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[Node, None], Dict[Node, None]]:
|
||||
tuple containing extended `all_nodes` and integer output nodes closest to `start_nodes`
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
|
||||
closest_integer_output_nodes: Dict[Node, None] = {}
|
||||
visited_nodes: Set[Node] = set()
|
||||
|
||||
current_nodes = {start_node: None for start_node in start_nodes}
|
||||
while current_nodes:
|
||||
next_nodes: Dict[Node, None] = {}
|
||||
for node in current_nodes:
|
||||
if node not in visited_nodes:
|
||||
visited_nodes.add(node)
|
||||
|
||||
all_nodes.update({node: None})
|
||||
for pred in nx_graph.predecessors(node):
|
||||
if isinstance(pred.output.dtype, Integer):
|
||||
closest_integer_output_nodes.update({pred: None})
|
||||
all_nodes.update({pred: None})
|
||||
else:
|
||||
next_nodes.update({pred: None})
|
||||
current_nodes = next_nodes
|
||||
|
||||
return all_nodes, closest_integer_output_nodes
|
||||
|
||||
|
||||
def add_nodes_from_to(
|
||||
graph: Graph,
|
||||
from_nodes: Iterable[Node],
|
||||
to_nodes: Dict[Node, None],
|
||||
all_nodes: Dict[Node, None],
|
||||
) -> Dict[Node, None]:
|
||||
"""
|
||||
Add nodes from `from_nodes` to `to_nodes`, to `all_nodes`.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
graph to traverse
|
||||
|
||||
from_nodes (Iterable[Node]):
|
||||
nodes from which extending `all_nodes` start
|
||||
|
||||
to_nodes (Dict[Node, None]):
|
||||
nodes to which extending `all_nodes` stop
|
||||
|
||||
all_nodes (Dict[Node, None]):
|
||||
nodes to be extended
|
||||
|
||||
Returns:
|
||||
Dict[Node, None]:
|
||||
extended `all_nodes`
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
|
||||
all_nodes.update(to_nodes)
|
||||
visited_nodes: Set[Node] = set()
|
||||
|
||||
current_nodes = {from_node: None for from_node in from_nodes}
|
||||
while current_nodes:
|
||||
next_nodes: Dict[Node, None] = {}
|
||||
for node in current_nodes:
|
||||
if node not in visited_nodes:
|
||||
visited_nodes.add(node)
|
||||
|
||||
all_nodes.update({node: None})
|
||||
if node not in to_nodes:
|
||||
predecessors = nx_graph.predecessors(node)
|
||||
next_nodes.update({pred: None for pred in predecessors if pred not in to_nodes})
|
||||
current_nodes = next_nodes
|
||||
|
||||
return all_nodes
|
||||
|
||||
|
||||
def convert_subgraph_to_subgraph_node(
|
||||
graph: Graph,
|
||||
all_nodes: Dict[Node, None],
|
||||
start_nodes: Dict[Node, None],
|
||||
terminal_node: Node,
|
||||
) -> Tuple[Node, Node]:
|
||||
"""
|
||||
Convert a subgraph to Operation.Generic node.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
orginal graph
|
||||
|
||||
all_nodes (Dict[Node, None]):
|
||||
all nodes in the subgraph
|
||||
|
||||
start_nodes (Dict[Node, None]):
|
||||
start nodes of the subgraph
|
||||
|
||||
terminal_node (Node):
|
||||
terminal node of the subgraph
|
||||
|
||||
Raises:
|
||||
RuntimeError:
|
||||
if subgraph is not fusable
|
||||
|
||||
Returns:
|
||||
Tuple[Node, Node]:
|
||||
None if the subgraph cannot be fused,
|
||||
subgraph node and its predecessor otherwise
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
|
||||
variable_input_nodes = [node for node in start_nodes if node.operation != Operation.Constant]
|
||||
if len(variable_input_nodes) != 1:
|
||||
base_highlighted_nodes = {
|
||||
node: ["within this subgraph", node.location] for node in all_nodes
|
||||
}
|
||||
for variable_input_node in variable_input_nodes:
|
||||
base_highlighted_nodes[variable_input_node] = [
|
||||
"this is one of the input nodes",
|
||||
variable_input_node.location,
|
||||
]
|
||||
|
||||
raise RuntimeError(
|
||||
"A subgraph within the function you are trying to compile cannot be fused "
|
||||
"because it has multiple input nodes\n\n"
|
||||
+ graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False)
|
||||
)
|
||||
|
||||
variable_input_node = variable_input_nodes[0]
|
||||
check_subgraph_fusability(graph, all_nodes, variable_input_node)
|
||||
|
||||
nx_subgraph = nx.MultiDiGraph(nx_graph)
|
||||
nodes_to_remove = [node for node in nx_subgraph.nodes() if node not in all_nodes]
|
||||
nx_subgraph.remove_nodes_from(nodes_to_remove)
|
||||
|
||||
subgraph_variable_input_node = Node.input("input", deepcopy(variable_input_node.output))
|
||||
nx_subgraph.add_node(subgraph_variable_input_node)
|
||||
|
||||
subgraph_variable_input_node.location = variable_input_node.location
|
||||
subgraph_variable_input_node.tag = variable_input_node.tag
|
||||
subgraph_variable_input_node.created_at = variable_input_node.created_at
|
||||
|
||||
variable_input_node_successors = {
|
||||
node: None for node in all_nodes if node in nx_graph.succ[variable_input_node]
|
||||
}
|
||||
for successor in variable_input_node_successors:
|
||||
edges = deepcopy(nx_subgraph.get_edge_data(variable_input_node, successor))
|
||||
for edge_key, edge_data in edges.items():
|
||||
nx_subgraph.remove_edge(variable_input_node, successor, key=edge_key)
|
||||
new_edge_data = deepcopy(edge_data)
|
||||
nx_subgraph.add_edge(
|
||||
subgraph_variable_input_node,
|
||||
successor,
|
||||
key=edge_key,
|
||||
**new_edge_data,
|
||||
)
|
||||
|
||||
original_location = terminal_node.location
|
||||
original_tag = terminal_node.tag
|
||||
original_created_at = terminal_node.created_at
|
||||
|
||||
subgraph = Graph(nx_subgraph, {0: subgraph_variable_input_node}, {0: terminal_node})
|
||||
subgraph_node = Node.generic(
|
||||
"subgraph",
|
||||
subgraph_variable_input_node.inputs,
|
||||
terminal_node.output,
|
||||
lambda x, subgraph, terminal_node: subgraph.evaluate(x)[terminal_node],
|
||||
kwargs={
|
||||
"subgraph": subgraph,
|
||||
"terminal_node": terminal_node,
|
||||
},
|
||||
)
|
||||
|
||||
subgraph_node.location = original_location
|
||||
subgraph_node.tag = original_tag
|
||||
subgraph_node.created_at = original_created_at
|
||||
|
||||
return subgraph_node, variable_input_node
|
||||
|
||||
|
||||
def check_subgraph_fusability(
    graph: Graph,
    all_nodes: Dict[Node, None],
    variable_input_node: Node,
):
    """
    Determine if a subgraph can be fused.

    e.g.,

    shuffling or reshaping a tensor makes fusing impossible as there should be a one-to-one
    mapping between each cell of the input and each cell of the output for table lookups

    Args:
        graph (Graph):
            original graph

        all_nodes (Dict[Node, None]):
            all nodes in the subgraph

        variable_input_node (Node):
            variable input node to the subgraph

    Raises:
        RuntimeError:
            if subgraph is not fusable
    """

    base_highlighted_nodes = {node: ["within this subgraph", node.location] for node in all_nodes}
    base_highlighted_nodes[variable_input_node] = [
        "with this input node",
        variable_input_node.location,
    ]

    non_constant_nodes = (node for node in all_nodes if node.operation != Operation.Constant)
    for node in non_constant_nodes:
        if node == variable_input_node:
            continue

        if not node.is_fusable:
            base_highlighted_nodes[node] = ["this node is not fusable", node.location]
            raise RuntimeError(
                "A subgraph within the function you are trying to compile cannot be fused "
                "because of a node, which is marked explicitly as non-fusable\n\n"
                + graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False)
            )

        if node.output.shape != variable_input_node.output.shape:
            base_highlighted_nodes[node] = [
                "this node has a different shape than the input node",
                node.location,
            ]
            raise RuntimeError(
                "A subgraph within the function you are trying to compile cannot be fused "
                "because of a node, which has a different shape than the input node\n\n"
                + graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False)
            )

    return True
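To make the shape rule concrete, here is a minimal sketch (assuming the package is importable as `concrete.numpy`, matching the `cnp` alias used in docstrings later in this diff): the first function keeps every intermediate node at the input's shape, so its float subgraph can be fused into a single table lookup; the second reshapes mid-subgraph, breaking the one-to-one cell mapping, and `check_subgraph_fusability` raises `RuntimeError` for it.

    import numpy as np

    def fusable(x):
        # every intermediate has the same shape as `x`, so the float
        # subgraph collapses into one "subgraph" node (a table lookup)
        return np.rint(np.sin(x) * 10).astype(np.int64)

    def not_fusable(x):
        # the reshape changes the shape relative to the variable input
        # node, so fusing is rejected with the error constructed above
        return np.rint(np.sin(x.reshape(-1, 1)) * 10).astype(np.int64)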
7
frontends/concrete-python/concrete/numpy/dtypes/__init__.py
Normal file
@@ -0,0 +1,7 @@
"""
Define available data types and their semantics.
"""

from .base import BaseDataType
from .float import Float
from .integer import Integer, SignedInteger, UnsignedInteger
17
frontends/concrete-python/concrete/numpy/dtypes/base.py
Normal file
@@ -0,0 +1,17 @@
"""
Declaration of `BaseDataType` abstract class.
"""

from abc import ABC, abstractmethod


class BaseDataType(ABC):
    """BaseDataType abstract class, to form a basis for data types."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        pass  # pragma: no cover

    @abstractmethod
    def __str__(self) -> str:
        pass  # pragma: no cover
31
frontends/concrete-python/concrete/numpy/dtypes/float.py
Normal file
@@ -0,0 +1,31 @@
"""
Declaration of `Float` class.
"""

from .base import BaseDataType


class Float(BaseDataType):
    """
    Float class, to represent floating point numbers.
    """

    bit_width: int

    def __init__(self, bit_width: int):
        super().__init__()

        if bit_width not in [16, 32, 64]:
            message = (
                f"Float({repr(bit_width)}) is not supported "
                f"(bit width must be one of 16, 32 or 64)"
            )
            raise ValueError(message)

        self.bit_width = bit_width

    def __eq__(self, other: object) -> bool:
        return isinstance(other, self.__class__) and self.bit_width == other.bit_width

    def __str__(self) -> str:
        return f"float{self.bit_width}"
155
frontends/concrete-python/concrete/numpy/dtypes/integer.py
Normal file
@@ -0,0 +1,155 @@
"""
Declaration of `Integer` class.
"""

import math
from functools import partial
from typing import Any

import numpy as np

from .base import BaseDataType


class Integer(BaseDataType):
    """
    Integer class, to represent integers.
    """

    is_signed: bool
    bit_width: int

    @staticmethod
    def that_can_represent(value: Any, force_signed: bool = False) -> "Integer":
        """
        Get the minimal `Integer` that can represent `value`.

        Args:
            value (Any):
                value that needs to be represented

            force_signed (bool, default = False):
                whether to force signed integers or not

        Returns:
            Integer:
                minimal `Integer` that can represent `value`

        Raises:
            ValueError:
                if `value` cannot be represented by `Integer`
        """

        lower_bound: int
        upper_bound: int

        if isinstance(value, list):
            try:
                value = np.array(value)
            except Exception:  # pylint: disable=broad-except
                # here we try our best to convert the list to np.ndarray
                # if it fails we raise the exception at the else branch below
                pass

        if isinstance(value, (int, np.integer)):
            lower_bound = int(value)
            upper_bound = int(value)
        elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.integer):
            lower_bound = int(value.min())
            upper_bound = int(value.max())
        else:
            message = f"Integer cannot represent {repr(value)}"
            raise ValueError(message)

        def bits_to_represent_int(value: int, force_signed: bool) -> int:
            bits: int

            if value == 0:
                return 1

            if value < 0:
                bits = int(math.ceil(math.log2(abs(value)))) + 1
            else:
                bits = int(math.ceil(math.log2(value + 1)))
                if force_signed:
                    bits += 1

            return bits

        is_signed = force_signed or lower_bound < 0
        bit_width = (
            bits_to_represent_int(lower_bound, is_signed)
            if lower_bound == upper_bound
            else max(
                bits_to_represent_int(lower_bound, is_signed),
                bits_to_represent_int(upper_bound, is_signed),
            )
        )

        return Integer(is_signed, bit_width)

    def __init__(self, is_signed: bool, bit_width: int):
        super().__init__()

        if not isinstance(bit_width, int) or bit_width <= 0:
            integer_str = "SignedInteger" if is_signed else "UnsignedInteger"
            message = (
                f"{integer_str}({repr(bit_width)}) is not supported "
                f"(bit width must be a positive integer)"
            )
            raise ValueError(message)

        self.is_signed = is_signed
        self.bit_width = bit_width

    def __eq__(self, other: Any) -> bool:
        return (
            isinstance(other, self.__class__)
            and self.is_signed == other.is_signed
            and self.bit_width == other.bit_width
        )

    def __str__(self) -> str:
        return f"{('int' if self.is_signed else 'uint')}{self.bit_width}"

    def min(self) -> int:
        """
        Get the minimum value that can be represented by the `Integer`.

        Returns:
            int:
                minimum value that can be represented by the `Integer`
        """

        return 0 if not self.is_signed else -(2 ** (self.bit_width - 1))

    def max(self) -> int:
        """
        Get the maximum value that can be represented by the `Integer`.

        Returns:
            int:
                maximum value that can be represented by the `Integer`
        """

        return (2**self.bit_width) - 1 if not self.is_signed else (2 ** (self.bit_width - 1)) - 1

    def can_represent(self, value: int) -> bool:
        """
        Get whether `value` can be represented by the `Integer` or not.

        Args:
            value (int):
                value to check representability

        Returns:
            bool:
                True if `value` is representable by the `Integer`, False otherwise
        """

        return self.min() <= value <= self.max()


SignedInteger = partial(Integer, True)

UnsignedInteger = partial(Integer, False)
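For intuition, a few values of the helper above (a sketch using only this file's API; the assertions follow directly from the bit counting in `bits_to_represent_int`):

    from concrete.numpy.dtypes import Integer

    assert str(Integer.that_can_represent(0)) == "uint1"    # zero still needs one bit
    assert str(Integer.that_can_represent(7)) == "uint3"    # ceil(log2(7 + 1)) == 3
    assert str(Integer.that_can_represent(8)) == "uint4"    # 8 does not fit in 3 bits
    assert str(Integer.that_can_represent(-8)) == "int4"    # int4 covers [-8, 7]
    assert str(Integer.that_can_represent(7, force_signed=True)) == "int4"  # sign bit added
    assert str(Integer.that_can_represent([-2, 5])) == "int4"  # int3 is [-4, 3], too small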
74
frontends/concrete-python/concrete/numpy/dtypes/utils.py
Normal file
@@ -0,0 +1,74 @@
"""
Declaration of various functions and constants related to data types.
"""

from typing import List

from ..internal.utils import assert_that
from .base import BaseDataType
from .float import Float
from .integer import Integer, SignedInteger, UnsignedInteger


def combine_dtypes(dtypes: List[BaseDataType]) -> BaseDataType:
    """
    Get the 'BaseDataType' that can represent a set of 'BaseDataType's.

    Args:
        dtypes (List[BaseDataType]):
            dtypes to combine

    Returns:
        BaseDataType:
            dtype that can hold all the given dtypes (potentially lossy)
    """

    assert_that(len(dtypes) != 0)
    assert_that(all(isinstance(dtype, (Integer, Float)) for dtype in dtypes))

    def combine_2_dtypes(dtype1: BaseDataType, dtype2: BaseDataType) -> BaseDataType:
        result: BaseDataType = dtype1

        if isinstance(dtype1, Integer) and isinstance(dtype2, Integer):
            max_bits = max(dtype1.bit_width, dtype2.bit_width)

            if dtype1.is_signed and dtype2.is_signed:
                result = SignedInteger(max_bits)

            elif not dtype1.is_signed and not dtype2.is_signed:
                result = UnsignedInteger(max_bits)

            elif dtype1.is_signed and not dtype2.is_signed:
                # if dtype2 has the bigger bit_width,
                # we need a signed integer that can hold
                # it, so add 1 bit of sign to its bit_width
                if dtype2.bit_width >= dtype1.bit_width:
                    new_bit_width = dtype2.bit_width + 1
                    result = SignedInteger(new_bit_width)
                else:
                    result = SignedInteger(dtype1.bit_width)

            elif not dtype1.is_signed and dtype2.is_signed:
                # same as above, with dtype1 and dtype2 switched around
                if dtype1.bit_width >= dtype2.bit_width:
                    new_bit_width = dtype1.bit_width + 1
                    result = SignedInteger(new_bit_width)
                else:
                    result = SignedInteger(dtype2.bit_width)

        elif isinstance(dtype1, Float) and isinstance(dtype2, Float):
            max_bits = max(dtype1.bit_width, dtype2.bit_width)
            result = Float(max_bits)

        elif isinstance(dtype1, Float):
            result = dtype1

        elif isinstance(dtype2, Float):
            result = dtype2

        return result

    result = dtypes[0]
    for other in dtypes[1:]:
        result = combine_2_dtypes(result, other)
    return result
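A few combinations to illustrate the rules above (a sketch; the mixed-signedness cases are where the extra sign bit shows up):

    from concrete.numpy.dtypes import Float, SignedInteger, UnsignedInteger
    from concrete.numpy.dtypes.utils import combine_dtypes

    # same signedness: keep it and take the wider bit width
    assert str(combine_dtypes([UnsignedInteger(3), UnsignedInteger(5)])) == "uint5"

    # unsigned side at least as wide: add a sign bit so its full range still fits
    assert str(combine_dtypes([SignedInteger(4), UnsignedInteger(5)])) == "int6"

    # signed side strictly wider: the unsigned value already fits
    assert str(combine_dtypes([SignedInteger(6), UnsignedInteger(3)])) == "int6"

    # floats absorb integers (potentially lossy, as the docstring warns)
    assert str(combine_dtypes([UnsignedInteger(3), Float(32)])) == "float32"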
11
frontends/concrete-python/concrete/numpy/extensions/__init__.py
Normal file
@@ -0,0 +1,11 @@
"""
Provide additional features that are not present in numpy.
"""

from .array import array
from .ones import one, ones
from .round_bit_pattern import AutoRounder, round_bit_pattern
from .table import LookupTable
from .tag import tag
from .univariate import univariate
from .zeros import zero, zeros
59
frontends/concrete-python/concrete/numpy/extensions/array.py
Normal file
@@ -0,0 +1,59 @@
"""
Declaration of `array` function, to simplify creation of encrypted arrays.
"""

from typing import Any, Union

import numpy as np

from ..dtypes.utils import combine_dtypes
from ..representation import Node
from ..tracing import Tracer
from ..values import Value


def array(values: Any) -> Union[np.ndarray, Tracer]:
    """
    Create an encrypted array from either encrypted or clear values.

    Args:
        values (Any):
            array-like object compatible with numpy to construct the resulting encrypted array

    Returns:
        Union[np.ndarray, Tracer]:
            Tracer that represents the operation during tracing
            ndarray with values otherwise
    """

    # pylint: disable=protected-access
    is_tracing = Tracer._is_tracing
    # pylint: enable=protected-access

    if not isinstance(values, np.ndarray):
        values = np.array(values)

    if not is_tracing:
        return values

    shape = values.shape
    values = values.flatten()

    for i, value in enumerate(values):
        if not isinstance(value, Tracer):
            values[i] = Tracer.sanitize(value)

        if not values[i].output.is_scalar:
            message = "Encrypted arrays can only be created from scalars"
            raise ValueError(message)

    dtype = combine_dtypes([value.output.dtype for value in values])
    is_encrypted = True

    computation = Node.generic(
        "array",
        [value.output for value in values],
        Value(dtype, shape, is_encrypted),
        lambda *args: np.array(args).reshape(shape),
    )
    return Tracer(computation, values)
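A usage sketch (hedged: it assumes the extensions are re-exported at the package root and that the usual `compiler`/`compile` entry points exist, as the `cnp.univariate` reference later in this diff suggests):

    import numpy as np

    import concrete.numpy as cnp

    @cnp.compiler({"x": "encrypted", "y": "encrypted"})
    def f(x, y):
        # evaluated eagerly outside tracing; during tracing this becomes a
        # single "array" node packing the two encrypted scalars together
        return cnp.array([x + 1, y + 2])

    inputset = [(3, 0), (0, 7), (7, 7)]
    circuit = f.compile(inputset)
    assert np.array_equal(circuit.encrypt_run_decrypt(3, 4), np.array([4, 6]))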
56
frontends/concrete-python/concrete/numpy/extensions/ones.py
Normal file
@@ -0,0 +1,56 @@
"""
Declaration of `ones` and `one` functions, to simplify creation of encrypted ones.
"""

from typing import Tuple, Union

import numpy as np

from ..representation import Node
from ..tracing import Tracer
from ..values import Value


def ones(shape: Union[int, Tuple[int, ...]]) -> Union[np.ndarray, Tracer]:
    """
    Create an encrypted array of ones.

    Args:
        shape (Tuple[int, ...]):
            shape of the array

    Returns:
        Union[np.ndarray, Tracer]:
            Tracer that represents the operation during tracing
            ndarray filled with ones otherwise
    """

    # pylint: disable=protected-access
    is_tracing = Tracer._is_tracing
    # pylint: enable=protected-access

    numpy_ones = np.ones(shape, dtype=np.int64)

    if is_tracing:
        computation = Node.generic(
            "ones",
            [],
            Value.of(numpy_ones, is_encrypted=True),
            lambda: np.ones(shape, dtype=np.int64),
        )
        return Tracer(computation, [])

    return numpy_ones


def one() -> Union[np.ndarray, Tracer]:
    """
    Create an encrypted scalar with the value of one.

    Returns:
        Union[np.ndarray, Tracer]:
            Tracer that represents the operation during tracing
            ndarray with one otherwise
    """

    return ones(())
246
frontends/concrete-python/concrete/numpy/extensions/round_bit_pattern.py
Normal file
@@ -0,0 +1,246 @@
"""
Declaration of `round_bit_pattern` function, to provide an interface for rounded table lookups.
"""

import threading
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Tuple, Union

import numpy as np

from ..dtypes import Integer
from ..mlir.utils import MAXIMUM_TLU_BIT_WIDTH
from ..representation import Node
from ..tracing import Tracer
from ..values import Value

local = threading.local()

# pylint: disable=protected-access
local._is_adjusting = False
# pylint: enable=protected-access


class Adjusting(BaseException):
    """
    Adjusting class, to be used as early stop signal during adjustment.
    """

    rounder: "AutoRounder"
    input_min: int
    input_max: int

    def __init__(self, rounder: "AutoRounder", input_min: int, input_max: int):
        super().__init__()
        self.rounder = rounder
        self.input_min = input_min
        self.input_max = input_max


class AutoRounder:
    """
    AutoRounder class, to optimize the number of msbs to keep during the round bit pattern operation.
    """

    target_msbs: int

    is_adjusted: bool
    input_min: int
    input_max: int
    input_bit_width: int
    lsbs_to_remove: int

    def __init__(self, target_msbs: int = MAXIMUM_TLU_BIT_WIDTH):
        # pylint: disable=protected-access
        if local._is_adjusting:
            message = (
                "AutoRounders cannot be constructed during adjustment, "
                "please construct AutoRounders outside the function and reference it"
            )
            raise RuntimeError(message)
        # pylint: enable=protected-access

        self.target_msbs = target_msbs

        self.is_adjusted = False
        self.input_min = 0
        self.input_max = 0
        self.input_bit_width = 0
        self.lsbs_to_remove = 0

    @staticmethod
    def adjust(function: Callable, inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]]):
        """
        Adjust AutoRounders in a function using an inputset.
        """

        # pylint: disable=protected-access,too-many-branches

        try:  # extract underlying function for decorators
            function = function.function  # type: ignore
            assert callable(function)
        except AttributeError:
            pass

        if local._is_adjusting:
            message = "AutoRounders cannot be adjusted recursively"
            raise RuntimeError(message)

        try:
            local._is_adjusting = True
            while True:
                rounder = None

                for sample in inputset:
                    if not isinstance(sample, tuple):
                        sample = (sample,)

                    try:
                        function(*sample)
                    except Adjusting as adjuster:
                        rounder = adjuster.rounder

                        rounder.input_min = min(rounder.input_min, adjuster.input_min)
                        rounder.input_max = max(rounder.input_max, adjuster.input_max)

                        input_value = Value.of([rounder.input_min, rounder.input_max])
                        assert isinstance(input_value.dtype, Integer)
                        rounder.input_bit_width = input_value.dtype.bit_width

                        if rounder.input_bit_width - rounder.lsbs_to_remove > rounder.target_msbs:
                            rounder.lsbs_to_remove = rounder.input_bit_width - rounder.target_msbs
                    else:
                        return

                if rounder is None:
                    message = "AutoRounders cannot be adjusted with an empty inputset"
                    raise ValueError(message)

                rounder.is_adjusted = True

        finally:
            local._is_adjusting = False

        # pylint: enable=protected-access,too-many-branches


def round_bit_pattern(
    x: Union[int, np.integer, List, np.ndarray, Tracer],
    lsbs_to_remove: Union[int, AutoRounder],
) -> Union[int, np.integer, List, np.ndarray, Tracer]:
    """
    Round the bit pattern of an integer.

    If `lsbs_to_remove` is an `AutoRounder`:
        the corresponding integer value will be determined by the adjustment process.

    x = 0b_0000_0000 , lsbs_to_remove = 3 => 0b_0000_0000
    x = 0b_0000_0001 , lsbs_to_remove = 3 => 0b_0000_0000
    x = 0b_0000_0010 , lsbs_to_remove = 3 => 0b_0000_0000
    x = 0b_0000_0011 , lsbs_to_remove = 3 => 0b_0000_0000
    x = 0b_0000_0100 , lsbs_to_remove = 3 => 0b_0000_1000
    x = 0b_0000_0101 , lsbs_to_remove = 3 => 0b_0000_1000
    x = 0b_0000_0110 , lsbs_to_remove = 3 => 0b_0000_1000
    x = 0b_0000_0111 , lsbs_to_remove = 3 => 0b_0000_1000

    x = 0b_1010_0000 , lsbs_to_remove = 3 => 0b_1010_0000
    x = 0b_1010_0001 , lsbs_to_remove = 3 => 0b_1010_0000
    x = 0b_1010_0010 , lsbs_to_remove = 3 => 0b_1010_0000
    x = 0b_1010_0011 , lsbs_to_remove = 3 => 0b_1010_0000
    x = 0b_1010_0100 , lsbs_to_remove = 3 => 0b_1010_1000
    x = 0b_1010_0101 , lsbs_to_remove = 3 => 0b_1010_1000
    x = 0b_1010_0110 , lsbs_to_remove = 3 => 0b_1010_1000
    x = 0b_1010_0111 , lsbs_to_remove = 3 => 0b_1010_1000

    x = 0b_1010_1000 , lsbs_to_remove = 3 => 0b_1010_1000
    x = 0b_1010_1001 , lsbs_to_remove = 3 => 0b_1010_1000
    x = 0b_1010_1010 , lsbs_to_remove = 3 => 0b_1010_1000
    x = 0b_1010_1011 , lsbs_to_remove = 3 => 0b_1010_1000
    x = 0b_1010_1100 , lsbs_to_remove = 3 => 0b_1011_0000
    x = 0b_1010_1101 , lsbs_to_remove = 3 => 0b_1011_0000
    x = 0b_1010_1110 , lsbs_to_remove = 3 => 0b_1011_0000
    x = 0b_1010_1111 , lsbs_to_remove = 3 => 0b_1011_0000

    x = 0b_1011_1000 , lsbs_to_remove = 3 => 0b_1011_1000
    x = 0b_1011_1001 , lsbs_to_remove = 3 => 0b_1011_1000
    x = 0b_1011_1010 , lsbs_to_remove = 3 => 0b_1011_1000
    x = 0b_1011_1011 , lsbs_to_remove = 3 => 0b_1011_1000
    x = 0b_1011_1100 , lsbs_to_remove = 3 => 0b_1100_0000
    x = 0b_1011_1101 , lsbs_to_remove = 3 => 0b_1100_0000
    x = 0b_1011_1110 , lsbs_to_remove = 3 => 0b_1100_0000
    x = 0b_1011_1111 , lsbs_to_remove = 3 => 0b_1100_0000

    Args:
        x (Union[int, np.integer, np.ndarray, Tracer]):
            input to round

        lsbs_to_remove (Union[int, AutoRounder]):
            number of the least significant bits to remove
            or an auto rounder object which will be used to determine the integer value

    Returns:
        Union[int, np.integer, np.ndarray, Tracer]:
            Tracer that represents the operation during tracing
            rounded value(s) otherwise
    """

    # pylint: disable=protected-access,too-many-branches

    if isinstance(lsbs_to_remove, AutoRounder):
        if local._is_adjusting:
            if not lsbs_to_remove.is_adjusted:
                raise Adjusting(lsbs_to_remove, int(np.min(x)), int(np.max(x)))  # type: ignore

        elif not lsbs_to_remove.is_adjusted:
            message = (
                "AutoRounders cannot be used before adjustment, "
                "please call AutoRounder.adjust with the function that will be compiled "
                "and provide the exact inputset that will be used for compilation"
            )
            raise RuntimeError(message)

        lsbs_to_remove = lsbs_to_remove.lsbs_to_remove

    assert isinstance(lsbs_to_remove, int)

    def evaluator(
        x: Union[int, np.integer, np.ndarray],
        lsbs_to_remove: int,
    ) -> Union[int, np.integer, np.ndarray]:
        if lsbs_to_remove == 0:
            return x

        unit = 1 << lsbs_to_remove
        half = 1 << (lsbs_to_remove - 1)
        rounded = (x + half) // unit
        return rounded * unit

    if isinstance(x, Tracer):
        computation = Node.generic(
            "round_bit_pattern",
            [x.output],
            deepcopy(x.output),
            evaluator,
            kwargs={"lsbs_to_remove": lsbs_to_remove},
        )
        return Tracer(computation, [x])

    if isinstance(x, list):  # pragma: no cover
        try:
            x = np.array(x)
        except Exception:  # pylint: disable=broad-except
            pass

    if isinstance(x, np.ndarray):
        if not np.issubdtype(x.dtype, np.integer):
            message = (
                f"Expected input elements to be integers but they are {type(x.dtype).__name__}"
            )
            raise TypeError(message)
    elif not isinstance(x, (int, np.integer)):
        message = f"Expected input to be an int or a numpy array but it's {type(x).__name__}"
        raise TypeError(message)

    return evaluator(x, lsbs_to_remove)

    # pylint: enable=protected-access,too-many-branches
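A usage sketch for both forms of `lsbs_to_remove` (assuming package-root re-exports, as above). With a fixed integer, values are rounded to the nearest multiple of `2**lsbs_to_remove`; with an `AutoRounder`, the adjustment loop above measures the input range and picks the value that keeps `target_msbs` most significant bits:

    import concrete.numpy as cnp

    # fixed amount: matches the bit pattern table in the docstring
    assert cnp.round_bit_pattern(0b_0000_0110, lsbs_to_remove=3) == 0b_0000_1000

    # automatic amount: measure the inputset, then keep the top 4 bits
    rounder = cnp.AutoRounder(target_msbs=4)

    def f(x):
        return cnp.round_bit_pattern(x, lsbs_to_remove=rounder) ** 2

    cnp.AutoRounder.adjust(f, range(200))  # inputs need 8 bits, so 8 - 4 = 4
    assert rounder.lsbs_to_remove == 4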
134
frontends/concrete-python/concrete/numpy/extensions/table.py
Normal file
@@ -0,0 +1,134 @@
"""
Declaration of `LookupTable` class.
"""

from copy import deepcopy
from typing import Any, Union

import numpy as np

from ..dtypes import BaseDataType, Integer
from ..representation import Node
from ..tracing import Tracer


class LookupTable:
    """
    LookupTable class, to provide a way to do direct table lookups.
    """

    table: np.ndarray
    output_dtype: BaseDataType

    def __init__(self, table: Any):
        is_valid = True
        try:
            self.table = table if isinstance(table, np.ndarray) else np.array(table)
        except Exception:  # pragma: no cover  # pylint: disable=broad-except
            # here we try our best to convert the table to np.ndarray
            # if it fails we raise the exception at the end of the function
            is_valid = False

        if is_valid:
            is_valid = self.table.size > 0

        if is_valid:
            minimum: int = 0
            maximum: int = 0

            if np.issubdtype(self.table.dtype, np.integer):
                minimum = int(self.table.min())
                maximum = int(self.table.max())
                if self.table.ndim != 1:
                    is_valid = False
            else:
                is_valid = all(isinstance(item, LookupTable) for item in self.table.flat)
                if is_valid:
                    minimum = int(self.table.flat[0].table.min())
                    maximum = int(self.table.flat[0].table.max())
                    for item in self.table.flat:
                        minimum = min(minimum, item.table.min())
                        maximum = max(maximum, item.table.max())

            self.output_dtype = Integer.that_can_represent([minimum, maximum])

        if not is_valid:
            message = f"LookupTable cannot be constructed with {repr(table)}"
            raise ValueError(message)

    def __repr__(self):
        return str(list(self.table))

    def __getitem__(self, key: Union[int, np.integer, np.ndarray, Tracer]):
        if not isinstance(key, Tracer):
            return LookupTable.apply(key, self.table)

        if not isinstance(key.output.dtype, Integer):
            message = f"LookupTable cannot be looked up with {key.output}"
            raise ValueError(message)

        table = self.table
        if not np.issubdtype(self.table.dtype, np.integer):
            try:
                table = np.broadcast_to(table, key.output.shape)
            except Exception as error:
                message = (
                    f"LookupTable of shape {self.table.shape} "
                    f"cannot be looked up with {key.output}"
                )
                raise ValueError(message) from error

        output = deepcopy(key.output)
        output.dtype = self.output_dtype

        computation = Node.generic(
            "tlu",
            [key.output],
            output,
            LookupTable.apply,
            kwargs={"table": table},
        )
        return Tracer(computation, [key])

    @staticmethod
    def apply(
        key: Union[int, np.integer, np.ndarray],
        table: np.ndarray,
    ) -> Union[int, np.integer, np.ndarray]:
        """
        Apply lookup table.

        Args:
            key (Union[int, np.integer, np.ndarray]):
                lookup key

            table (np.ndarray):
                lookup table

        Returns:
            Union[int, np.integer, np.ndarray]:
                lookup result

        Raises:
            ValueError:
                if `table` cannot be looked up with `key`
        """

        if not isinstance(key, (int, np.integer, np.ndarray)) or (
            isinstance(key, np.ndarray) and not np.issubdtype(key.dtype, np.integer)
        ):
            message = f"LookupTable cannot be looked up with {key}"
            raise ValueError(message)

        if np.issubdtype(table.dtype, np.integer):
            return table[key]

        if not isinstance(key, np.ndarray) or key.shape != table.shape:
            message = f"LookupTable of shape {table.shape} cannot be looked up with {key}"
            raise ValueError(message)

        flat_result = np.fromiter(
            (lt.table[k] for lt, k in zip(table.flat, key.flat)),
            dtype=np.longlong,
        )
        return flat_result.reshape(table.shape)
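A usage sketch (same re-export assumption). Outside tracing a `LookupTable` is just fancy indexing; inside a compiled function the lookup becomes a single `tlu` node:

    import numpy as np

    import concrete.numpy as cnp

    squares = cnp.LookupTable([x ** 2 for x in range(8)])
    assert squares[3] == 9                                    # scalar key
    assert np.array_equal(squares[np.array([1, 2])], [1, 4])  # tensor key

    @cnp.compiler({"x": "encrypted"})
    def f(x):
        return squares[x]  # traced into a "tlu" node with uint6 output

    circuit = f.compile(range(8))
    assert circuit.encrypt_run_decrypt(5) == 25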
24
frontends/concrete-python/concrete/numpy/extensions/tag.py
Normal file
@@ -0,0 +1,24 @@
"""
Declaration of `tag` context manager, to allow tagging certain nodes.
"""

import threading
from contextlib import contextmanager

tag_context = threading.local()
tag_context.stack = []


@contextmanager
def tag(name: str):
    """
    Introduce a new tag to the tag stack.

    Can be nested, and the resulting tag will be `tag1.tag2`.
    """

    tag_context.stack.append(name)
    try:
        yield
    finally:
        tag_context.stack.pop()
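For example (a sketch; nested tags concatenate with dots, so nodes created in the inner block end up tagged `preprocessing.scaling`):

    import concrete.numpy as cnp

    def f(x):
        with cnp.tag("preprocessing"):
            x = x + 1                  # tagged "preprocessing"
            with cnp.tag("scaling"):
                x = x * 2              # tagged "preprocessing.scaling"
        return x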
89
frontends/concrete-python/concrete/numpy/extensions/univariate.py
Normal file
@@ -0,0 +1,89 @@
"""
Declaration of `univariate` function.
"""

from typing import Any, Callable, Optional, Type, Union

import numpy as np

from ..dtypes import BaseDataType, Float
from ..representation import Node
from ..tracing import ScalarAnnotation, Tracer
from ..values import Value


def univariate(
    function: Callable[[Any], Any],
    outputs: Optional[Union[BaseDataType, Type[ScalarAnnotation]]] = None,
) -> Callable[[Union[Tracer, Any]], Union[Tracer, Any]]:
    """
    Wrap a univariate function so that it is traced into a single generic node.

    Args:
        function (Callable[[Any], Any]):
            univariate function to wrap

        outputs (Optional[Union[BaseDataType, Type[ScalarAnnotation]]], default = None):
            data type of the result, unused during compilation, required for direct definition

    Returns:
        Callable[[Union[Tracer, Any]], Union[Tracer, Any]]:
            another univariate function that can be called with a Tracer as well
    """

    def wrapper(x: Union[Tracer, Any]) -> Union[Tracer, Any]:
        """
        Evaluate or trace wrapped univariate function.

        Args:
            x (Union[Tracer, Any]):
                input of the function

        Returns:
            Union[Tracer, Any]:
                result of tracing or evaluation
        """

        if isinstance(x, Tracer):
            dtype = (
                {64: np.float64, 32: np.float32, 16: np.float16}[x.output.dtype.bit_width]
                if isinstance(x.output.dtype, Float)
                else np.int64
            )

            if x.output.shape == ():
                sample = dtype(1)  # type: ignore
            else:
                sample = np.ones(x.output.shape, dtype=dtype)
            evaluation = function(sample)

            output_value = Value.of(evaluation, is_encrypted=x.output.is_encrypted)
            if output_value.shape != x.output.shape:
                message = f"Function {function.__name__} cannot be used with cnp.univariate"
                raise ValueError(message)

            # pylint: disable=protected-access
            is_direct = Tracer._is_direct
            # pylint: enable=protected-access

            if is_direct:
                if outputs is None:
                    message = (
                        "Univariate extension requires "
                        "`outputs` argument for direct circuit definition "
                        "(e.g., cnp.univariate(function, outputs=cnp.uint4)(x))"
                    )
                    raise ValueError(message)
                output_value.dtype = outputs if isinstance(outputs, BaseDataType) else outputs.dtype

            computation = Node.generic(
                function.__name__,
                [x.output],
                output_value,
                lambda x: function(x),  # pylint: disable=unnecessary-lambda
            )
            return Tracer(computation, [x])

        return function(x)

    return wrapper
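A usage sketch (same assumptions as the earlier examples). Any single-input function that preserves the input shape can be wrapped; the whole body is traced as one generic node and compiled into a single table lookup:

    import numpy as np

    import concrete.numpy as cnp

    def complex_univariate(x):
        return np.where(x < 5, x * 10, x - 3)  # shape-preserving, single input

    @cnp.compiler({"x": "encrypted"})
    def f(x):
        return cnp.univariate(complex_univariate)(x)

    circuit = f.compile(range(10))
    assert circuit.encrypt_run_decrypt(2) == 20
    assert circuit.encrypt_run_decrypt(7) == 4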
56
frontends/concrete-python/concrete/numpy/extensions/zeros.py
Normal file
@@ -0,0 +1,56 @@
"""
Declaration of `zeros` and `zero` functions, to simplify creation of encrypted zeros.
"""

from typing import Tuple, Union

import numpy as np

from ..representation import Node
from ..tracing import Tracer
from ..values import Value


def zeros(shape: Union[int, Tuple[int, ...]]) -> Union[np.ndarray, Tracer]:
    """
    Create an encrypted array of zeros.

    Args:
        shape (Tuple[int, ...]):
            shape of the array

    Returns:
        Union[np.ndarray, Tracer]:
            Tracer that represents the operation during tracing
            ndarray filled with zeros otherwise
    """

    # pylint: disable=protected-access
    is_tracing = Tracer._is_tracing
    # pylint: enable=protected-access

    numpy_zeros = np.zeros(shape, dtype=np.int64)

    if is_tracing:
        computation = Node.generic(
            "zeros",
            [],
            Value.of(numpy_zeros, is_encrypted=True),
            lambda: np.zeros(shape, dtype=np.int64),
        )
        return Tracer(computation, [])

    return numpy_zeros


def zero() -> Union[np.ndarray, Tracer]:
    """
    Create an encrypted scalar with the value of zero.

    Returns:
        Union[np.ndarray, Tracer]:
            Tracer that represents the operation during tracing
            ndarray with zero otherwise
    """

    return zeros(())
3
frontends/concrete-python/concrete/numpy/internal/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""
Export functions that are used internally by other modules for common things (e.g., assertions).
"""
32
frontends/concrete-python/concrete/numpy/internal/utils.py
Normal file
@@ -0,0 +1,32 @@
"""
Declaration of various functions and constants related to the entire project.
"""


def assert_that(condition: bool, message: str = ""):
    """
    Assert a condition.

    Args:
        condition (bool):
            condition to assert

        message (str):
            message to give to `AssertionError` if the condition does not hold

    Raises:
        AssertionError:
            if the condition does not hold
    """

    if not condition:
        raise AssertionError(message)


def unreachable():
    """
    Raise a RuntimeError to indicate unreachable code is entered.
    """

    message = "Entered unreachable code"
    raise RuntimeError(message)
6
frontends/concrete-python/concrete/numpy/mlir/__init__.py
Normal file
@@ -0,0 +1,6 @@
"""
Provide `computation graph` to `mlir` functionality.
"""

from .graph_converter import GraphConverter
from .node_converter import NodeConverter
739
frontends/concrete-python/concrete/numpy/mlir/graph_converter.py
Normal file
@@ -0,0 +1,739 @@
|
||||
"""
|
||||
Declaration of `GraphConverter` class.
|
||||
"""
|
||||
|
||||
# pylint: disable=no-member,no-name-in-module
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import concrete.lang as concretelang
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from concrete.lang.dialects import fhe, fhelinalg
|
||||
from mlir.dialects import arith, func
|
||||
from mlir.ir import (
|
||||
Attribute,
|
||||
Context,
|
||||
InsertionPoint,
|
||||
IntegerAttr,
|
||||
IntegerType,
|
||||
Location,
|
||||
Module,
|
||||
OpResult,
|
||||
RankedTensorType,
|
||||
)
|
||||
|
||||
from ..dtypes import Integer, SignedInteger
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Graph, Node, Operation
|
||||
from ..values import ClearScalar, EncryptedScalar
|
||||
from .node_converter import NodeConverter
|
||||
from .utils import MAXIMUM_TLU_BIT_WIDTH
|
||||
|
||||
# pylint: enable=no-member,no-name-in-module
|
||||
|
||||
|
||||
class GraphConverter:
|
||||
"""
|
||||
GraphConverter class, to convert computation graphs to their MLIR equivalent.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _check_node_convertibility(graph: Graph, node: Node) -> Optional[str]:
|
||||
"""
|
||||
Check node convertibility to MLIR.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph of the node
|
||||
|
||||
node (Node):
|
||||
node to be checked
|
||||
|
||||
Returns:
|
||||
Optional[str]:
|
||||
None if node is convertible to MLIR, the reason for inconvertibility otherwise
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-branches,too-many-return-statements,too-many-statements
|
||||
|
||||
inputs = node.inputs
|
||||
output = node.output
|
||||
|
||||
if node.operation == Operation.Constant:
|
||||
assert_that(len(inputs) == 0)
|
||||
if not isinstance(output.dtype, Integer):
|
||||
return "only integer constants are supported"
|
||||
|
||||
elif node.operation == Operation.Input:
|
||||
assert_that(len(inputs) == 1)
|
||||
assert_that(inputs[0] == output)
|
||||
if not isinstance(output.dtype, Integer):
|
||||
return "only integer inputs are supported"
|
||||
if output.dtype.is_signed and output.is_clear:
|
||||
return "only encrypted signed integer inputs are supported"
|
||||
|
||||
else:
|
||||
assert_that(node.operation == Operation.Generic)
|
||||
|
||||
if not isinstance(output.dtype, Integer):
|
||||
return "only integer operations are supported"
|
||||
|
||||
name = node.properties["name"]
|
||||
|
||||
if name == "add":
|
||||
assert_that(len(inputs) == 2)
|
||||
|
||||
elif name == "array":
|
||||
assert_that(len(inputs) > 0)
|
||||
assert_that(all(input.is_scalar for input in inputs))
|
||||
|
||||
elif name == "assign.static":
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only assignment to encrypted tensors are supported"
|
||||
|
||||
elif name in ["bitwise_and", "bitwise_or", "bitwise_xor", "left_shift", "right_shift"]:
|
||||
assert_that(len(inputs) == 2)
|
||||
if all(value.is_encrypted for value in node.inputs):
|
||||
pred_nodes = graph.ordered_preds_of(node)
|
||||
if (
|
||||
name in ["left_shift", "right_shift"]
|
||||
and cast(Integer, pred_nodes[1].output.dtype).bit_width > 4
|
||||
):
|
||||
return "only up to 4-bit shifts are supported"
|
||||
|
||||
for pred_node in pred_nodes:
|
||||
assert isinstance(pred_node.output.dtype, Integer)
|
||||
if pred_node.output.dtype.is_signed:
|
||||
return "only unsigned bitwise operations are supported"
|
||||
|
||||
elif name == "broadcast_to":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted broadcasting is supported"
|
||||
|
||||
elif name == "concatenate":
|
||||
if not all(input.is_encrypted for input in inputs):
|
||||
return "only all encrypted concatenate is supported"
|
||||
|
||||
elif name in ["conv1d", "conv2d", "conv3d"]:
|
||||
assert_that(len(inputs) == 2 or len(inputs) == 3)
|
||||
if not (inputs[0].is_encrypted and inputs[1].is_clear):
|
||||
return f"only {name} with encrypted input and clear weight is supported"
|
||||
|
||||
elif name == "dot":
|
||||
assert_that(len(inputs) == 2)
|
||||
if inputs[0].is_encrypted and inputs[1].is_encrypted:
|
||||
return "only dot product between encrypted and clear is supported"
|
||||
|
||||
elif name in ["equal", "greater", "greater_equal", "less", "less_equal", "not_equal"]:
|
||||
assert_that(len(inputs) == 2)
|
||||
|
||||
elif name == "expand_dims":
|
||||
assert_that(len(inputs) == 1)
|
||||
|
||||
elif name == "index.static":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted indexing supported"
|
||||
|
||||
elif name == "matmul":
|
||||
assert_that(len(inputs) == 2)
|
||||
if inputs[0].is_encrypted and inputs[1].is_encrypted:
|
||||
return "only matrix multiplication between encrypted and clear is supported"
|
||||
|
||||
elif name == "maxpool":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted maxpool is supported"
|
||||
|
||||
elif name == "multiply":
|
||||
assert_that(len(inputs) == 2)
|
||||
if inputs[0].is_encrypted and inputs[1].is_encrypted:
|
||||
return "only multiplication between encrypted and clear is supported"
|
||||
|
||||
elif name == "negative":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted negation is supported"
|
||||
|
||||
elif name == "ones":
|
||||
assert_that(len(inputs) == 0)
|
||||
|
||||
elif name == "reshape":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted reshape is supported"
|
||||
|
||||
elif name == "squeeze":
|
||||
assert_that(len(inputs) == 1)
|
||||
|
||||
elif name == "subtract":
|
||||
assert_that(len(inputs) == 2)
|
||||
|
||||
elif name == "sum":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted sum is supported"
|
||||
|
||||
elif name == "transpose":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted transpose is supported"
|
||||
|
||||
elif name == "zeros":
|
||||
assert_that(len(inputs) == 0)
|
||||
|
||||
else:
|
||||
assert_that(node.converted_to_table_lookup)
|
||||
variable_input_indices = [
|
||||
idx
|
||||
for idx, pred in enumerate(graph.ordered_preds_of(node))
|
||||
if not pred.operation == Operation.Constant
|
||||
]
|
||||
assert_that(len(variable_input_indices) == 1)
|
||||
|
||||
if len(inputs) > 0 and all(input.is_clear for input in inputs):
|
||||
return "one of the operands must be encrypted"
|
||||
|
||||
return None
|
||||
|
||||
# pylint: enable=too-many-branches,too-many-return-statements,too-many-statements
|
||||
|
||||
@staticmethod
|
||||
def _check_graph_convertibility(graph: Graph):
|
||||
"""
|
||||
Check graph convertibility to MLIR.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to be checked
|
||||
|
||||
Raises:
|
||||
RuntimeError:
|
||||
if `graph` is not convertible to MLIR
|
||||
"""
|
||||
|
||||
offending_nodes = {}
|
||||
|
||||
if len(graph.output_nodes) > 1:
|
||||
offending_nodes.update(
|
||||
{
|
||||
node: ["only a single output is supported", node.location]
|
||||
for node in graph.output_nodes.values()
|
||||
}
|
||||
)
|
||||
|
||||
if len(offending_nodes) == 0:
|
||||
for node in graph.graph.nodes:
|
||||
reason = GraphConverter._check_node_convertibility(graph, node)
|
||||
if reason is not None:
|
||||
offending_nodes[node] = [reason, node.location]
|
||||
|
||||
if len(offending_nodes) != 0:
|
||||
message = (
|
||||
"Function you are trying to compile cannot be converted to MLIR\n\n"
|
||||
+ graph.format(highlighted_nodes=offending_nodes)
|
||||
)
|
||||
raise RuntimeError(message)
|
||||
|
||||
@staticmethod
|
||||
def _update_bit_widths(graph: Graph):
|
||||
"""
|
||||
Update bit-widths in a computation graph to be convertible to MLIR.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to be updated
|
||||
"""
|
||||
|
||||
offending_nodes: Dict[Node, List[str]] = {}
|
||||
|
||||
max_bit_width = 0
|
||||
max_bit_width_node = None
|
||||
|
||||
first_tlu_node = None
|
||||
first_signed_node = None
|
||||
|
||||
for node in nx.lexicographical_topological_sort(graph.graph):
|
||||
dtype = node.output.dtype
|
||||
assert_that(isinstance(dtype, Integer))
|
||||
|
||||
current_node_bit_width = (
|
||||
dtype.bit_width - 1 if node.output.is_clear else dtype.bit_width
|
||||
)
|
||||
if (
|
||||
all(value.is_encrypted for value in node.inputs)
|
||||
and node.operation == Operation.Generic
|
||||
and node.properties["name"]
|
||||
in [
|
||||
"greater",
|
||||
"greater_equal",
|
||||
"less",
|
||||
"less_equal",
|
||||
]
|
||||
):
|
||||
# implementation of these operators require at least 4 bits
|
||||
current_node_bit_width = max(current_node_bit_width, 4)
|
||||
|
||||
if max_bit_width < current_node_bit_width:
|
||||
max_bit_width = current_node_bit_width
|
||||
max_bit_width_node = node
|
||||
|
||||
if node.converted_to_table_lookup and first_tlu_node is None:
|
||||
first_tlu_node = node
|
||||
|
||||
if dtype.is_signed and first_signed_node is None:
|
||||
first_signed_node = node
|
||||
|
||||
if first_tlu_node is not None:
|
||||
if max_bit_width > MAXIMUM_TLU_BIT_WIDTH:
|
||||
assert max_bit_width_node is not None
|
||||
offending_nodes[max_bit_width_node] = [
|
||||
(
|
||||
{
|
||||
Operation.Input: f"this input is {max_bit_width}-bits",
|
||||
Operation.Constant: f"this constant is {max_bit_width}-bits",
|
||||
Operation.Generic: f"this operation results in {max_bit_width}-bits",
|
||||
}[max_bit_width_node.operation]
|
||||
),
|
||||
max_bit_width_node.location,
|
||||
]
|
||||
offending_nodes[first_tlu_node] = [
|
||||
f"table lookups are only supported on circuits with "
|
||||
f"up to {MAXIMUM_TLU_BIT_WIDTH}-bits",
|
||||
first_tlu_node.location,
|
||||
]
|
||||
|
||||
if len(offending_nodes) != 0:
|
||||
raise RuntimeError(
|
||||
"Function you are trying to compile cannot be converted to MLIR:\n\n"
|
||||
+ graph.format(highlighted_nodes=offending_nodes)
|
||||
)
|
||||
|
||||
for node in nx.topological_sort(graph.graph):
|
||||
assert isinstance(node.output.dtype, Integer)
|
||||
node.properties["original_bit_width"] = node.output.dtype.bit_width
|
||||
|
||||
for value in node.inputs + [node.output]:
|
||||
dtype = value.dtype
|
||||
assert_that(isinstance(dtype, Integer))
|
||||
dtype.bit_width = max_bit_width + 1 if value.is_clear else max_bit_width
|
||||
|
||||
@staticmethod
|
||||
def _offset_negative_lookup_table_inputs(graph: Graph):
|
||||
"""
|
||||
Offset negative table lookup inputs to be convertible to MLIR.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to apply offset
|
||||
"""
|
||||
|
||||
# ugly hack to add an offset before entering a TLU
|
||||
# if its variable input node has a signed output.
|
||||
# this makes hardcoded assumptions about the way bit widths are handled in MLIR.
|
||||
# this does not update the TLU input values to allow for proper table generation.
|
||||
|
||||
nx_graph = graph.graph
|
||||
for node in list(nx_graph.nodes):
|
||||
if node.operation == Operation.Generic:
|
||||
if not node.converted_to_table_lookup:
|
||||
continue
|
||||
|
||||
variable_input_index = -1
|
||||
|
||||
preds = graph.ordered_preds_of(node)
|
||||
for index, pred in enumerate(preds):
|
||||
if pred.operation != Operation.Constant:
|
||||
variable_input_index = index
|
||||
break
|
||||
|
||||
variable_input_node = preds[variable_input_index]
|
||||
|
||||
variable_input_value = variable_input_node.output
|
||||
variable_input_dtype = variable_input_value.dtype
|
||||
|
||||
assert_that(isinstance(variable_input_dtype, Integer))
|
||||
variable_input_dtype = cast(Integer, variable_input_dtype)
|
||||
|
||||
if not variable_input_dtype.is_signed:
|
||||
continue
|
||||
|
||||
variable_input_bit_width = variable_input_dtype.bit_width
|
||||
offset_constant_dtype = SignedInteger(variable_input_bit_width + 1)
|
||||
|
||||
offset_constant_value = abs(variable_input_dtype.min())
|
||||
|
||||
offset_constant = Node.constant(offset_constant_value)
|
||||
offset_constant.output.dtype = offset_constant_dtype
|
||||
|
||||
original_bit_width = Integer.that_can_represent(offset_constant_value).bit_width
|
||||
offset_constant.properties["original_bit_width"] = original_bit_width
|
||||
|
||||
add_offset = Node.generic(
|
||||
"add",
|
||||
[variable_input_value, ClearScalar(offset_constant_dtype)],
|
||||
variable_input_value,
|
||||
np.add,
|
||||
)
|
||||
|
||||
original_bit_width = variable_input_node.properties["original_bit_width"]
|
||||
add_offset.properties["original_bit_width"] = original_bit_width
|
||||
|
||||
nx_graph.remove_edge(variable_input_node, node)
|
||||
|
||||
nx_graph.add_edge(variable_input_node, add_offset, input_idx=0)
|
||||
nx_graph.add_edge(offset_constant, add_offset, input_idx=1)
|
||||
|
||||
nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)
|
||||
|
||||
@staticmethod
|
||||
def _broadcast_assignments(graph: Graph):
|
||||
"""
|
||||
Broadcast assignments.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to transform
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
for node in list(nx_graph.nodes):
|
||||
if node.operation == Operation.Generic and node.properties["name"] == "assign.static":
|
||||
shape = node.inputs[0].shape
|
||||
index = node.properties["kwargs"]["index"]
|
||||
|
||||
assert_that(isinstance(index, tuple))
|
||||
while len(index) < len(shape):
|
||||
index = (*index, slice(None, None, None))
|
||||
|
||||
required_value_shape_list = []
|
||||
|
||||
for i, indexing_element in enumerate(index):
|
||||
if isinstance(indexing_element, slice):
|
||||
n = len(np.zeros(shape[i])[indexing_element])
|
||||
required_value_shape_list.append(n)
|
||||
else:
|
||||
required_value_shape_list.append(1)
|
||||
|
||||
required_value_shape = tuple(required_value_shape_list)
|
||||
actual_value_shape = node.inputs[1].shape
|
||||
|
||||
if required_value_shape != actual_value_shape:
|
||||
preds = graph.ordered_preds_of(node)
|
||||
pred_to_modify = preds[1]
|
||||
|
||||
modified_value = deepcopy(pred_to_modify.output)
|
||||
modified_value.shape = required_value_shape
|
||||
|
||||
try:
|
||||
np.broadcast_to(np.zeros(actual_value_shape), required_value_shape)
|
||||
modified_value.is_encrypted = True
|
||||
modified_value.dtype = node.output.dtype
|
||||
modified_pred = Node.generic(
|
||||
"broadcast_to",
|
||||
[pred_to_modify.output],
|
||||
modified_value,
|
||||
np.broadcast_to,
|
||||
kwargs={"shape": required_value_shape},
|
||||
)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
np.reshape(np.zeros(actual_value_shape), required_value_shape)
|
||||
modified_pred = Node.generic(
|
||||
"reshape",
|
||||
[pred_to_modify.output],
|
||||
modified_value,
|
||||
np.reshape,
|
||||
kwargs={"newshape": required_value_shape},
|
||||
)
|
||||
|
||||
modified_pred.properties["original_bit_width"] = pred_to_modify.properties[
|
||||
"original_bit_width"
|
||||
]
|
||||
|
||||
nx_graph.add_edge(pred_to_modify, modified_pred, input_idx=0)
|
||||
|
||||
nx_graph.remove_edge(pred_to_modify, node)
|
||||
nx_graph.add_edge(modified_pred, node, input_idx=1)
|
||||
|
||||
node.inputs[1] = modified_value
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_clear_assignments(graph: Graph):
|
||||
"""
|
||||
Encrypt clear assignments.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to transform
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
for node in list(nx_graph.nodes):
|
||||
if node.operation == Operation.Generic and node.properties["name"] == "assign.static":
|
||||
assigned_value = node.inputs[1]
|
||||
if assigned_value.is_clear:
|
||||
preds = graph.ordered_preds_of(node)
|
||||
assigned_pred = preds[1]
|
||||
|
||||
new_assigned_pred_value = deepcopy(assigned_value)
|
||||
new_assigned_pred_value.is_encrypted = True
|
||||
new_assigned_pred_value.dtype = preds[0].output.dtype
|
||||
|
||||
zero = Node.generic(
|
||||
"zeros",
|
||||
[],
|
||||
EncryptedScalar(new_assigned_pred_value.dtype),
|
||||
lambda: np.zeros((), dtype=np.int64),
|
||||
)
|
||||
|
||||
original_bit_width = 1
|
||||
zero.properties["original_bit_width"] = original_bit_width
|
||||
|
||||
new_assigned_pred = Node.generic(
|
||||
"add",
|
||||
[assigned_pred.output, zero.output],
|
||||
new_assigned_pred_value,
|
||||
np.add,
|
||||
)
|
||||
|
||||
original_bit_width = assigned_pred.properties["original_bit_width"]
|
||||
new_assigned_pred.properties["original_bit_width"] = original_bit_width
|
||||
|
||||
nx_graph.remove_edge(preds[1], node)
|
||||
|
||||
nx_graph.add_edge(preds[1], new_assigned_pred, input_idx=0)
|
||||
nx_graph.add_edge(zero, new_assigned_pred, input_idx=1)
|
||||
|
||||
nx_graph.add_edge(new_assigned_pred, node, input_idx=1)
|
||||
|
||||
@staticmethod
|
||||
def _tensorize_scalars_for_fhelinalg(graph: Graph):
|
||||
"""
|
||||
Tensorize scalars if they are used within fhelinalg operations.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to update
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
OPS_TO_TENSORIZE = [
|
||||
"add",
|
||||
"bitwise_and",
|
||||
"bitwise_or",
|
||||
"bitwise_xor",
|
||||
"broadcast_to",
|
||||
"dot",
|
||||
"equal",
|
||||
"greater",
|
||||
"greater_equal",
|
||||
"left_shift",
|
||||
"less",
|
||||
"less_equal",
|
||||
"multiply",
|
||||
"not_equal",
|
||||
"right_shift",
|
||||
"subtract",
|
||||
]
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
tensorized_scalars: Dict[Node, Node] = {}
|
||||
|
||||
nx_graph = graph.graph
|
||||
for node in list(nx_graph.nodes):
|
||||
if node.operation == Operation.Generic and node.properties["name"] in OPS_TO_TENSORIZE:
|
||||
assert len(node.inputs) in {1, 2}
|
||||
|
||||
if len(node.inputs) == 2:
|
||||
if {inp.is_scalar for inp in node.inputs} != {True, False}:
|
||||
continue
|
||||
else:
|
||||
if not node.inputs[0].is_scalar:
|
||||
continue
|
||||
|
||||
# for bitwise and comparison operators that can have constants
|
||||
# we don't need broadcasting here
|
||||
if node.converted_to_table_lookup:
|
||||
continue
|
||||
|
||||
pred_to_tensorize: Optional[Node] = None
|
||||
pred_to_tensorize_index = 0
|
||||
|
||||
preds = graph.ordered_preds_of(node)
|
||||
for index, pred in enumerate(preds):
|
||||
if pred.output.is_scalar:
|
||||
pred_to_tensorize = pred
|
||||
pred_to_tensorize_index = index
|
||||
break
|
||||
|
||||
assert pred_to_tensorize is not None
|
||||
|
||||
tensorized_pred = tensorized_scalars.get(pred_to_tensorize)
|
||||
if tensorized_pred is None:
|
||||
tensorized_value = deepcopy(pred_to_tensorize.output)
|
||||
tensorized_value.shape = (1,)
|
||||
|
||||
tensorized_pred = Node.generic(
|
||||
"array",
|
||||
[pred_to_tensorize.output],
|
||||
tensorized_value,
|
||||
lambda *args: np.array(args),
|
||||
)
|
||||
|
||||
original_bit_width = pred_to_tensorize.properties["original_bit_width"]
|
||||
tensorized_pred.properties["original_bit_width"] = original_bit_width
|
||||
|
||||
original_shape = ()
|
||||
tensorized_pred.properties["original_shape"] = original_shape
|
||||
|
||||
nx_graph.add_edge(pred_to_tensorize, tensorized_pred, input_idx=0)
|
||||
tensorized_scalars[pred_to_tensorize] = tensorized_pred
|
||||
|
||||
assert tensorized_pred is not None
|
||||
|
||||
nx_graph.remove_edge(pred_to_tensorize, node)
|
||||
nx_graph.add_edge(tensorized_pred, node, input_idx=pred_to_tensorize_index)
|
||||
|
||||
new_input_value = deepcopy(node.inputs[pred_to_tensorize_index])
|
||||
new_input_value.shape = (1,)
|
||||
node.inputs[pred_to_tensorize_index] = new_input_value
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_signed_inputs(graph: Graph, args: List[Any], ctx: Context) -> List[Any]:
|
||||
"""
|
||||
Use subtraction to sanitize signed inputs.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph being converted
|
||||
|
||||
args (List[Any]):
|
||||
list of arguments from mlir main
|
||||
|
||||
ctx (Context):
|
||||
mlir context where the conversion is being performed
|
||||
|
||||
Returns:
|
||||
Tuple[List[str], List[Any]]:
|
||||
sanitized args and name of the sanitized variables in MLIR
|
||||
"""
|
||||
|
||||
sanitized_args = []
|
||||
for i, arg in enumerate(args):
|
||||
input_node = graph.input_nodes[i]
|
||||
input_value = input_node.output
|
||||
|
||||
assert_that(isinstance(input_value.dtype, Integer))
|
||||
input_dtype = cast(Integer, input_value.dtype)
|
||||
|
||||
if input_dtype.is_signed:
|
||||
assert_that(input_value.is_encrypted)
|
||||
n = input_dtype.bit_width
|
||||
|
||||
sanitizer_type = IntegerType.get_signless(n + 1)
|
||||
sanitizer = 2 ** (n - 1)
|
||||
|
||||
if input_value.is_scalar:
|
||||
sanitizer_attr = IntegerAttr.get(sanitizer_type, sanitizer)
|
||||
else:
|
||||
sanitizer_type = RankedTensorType.get((1,), sanitizer_type)
|
||||
sanitizer_attr = Attribute.parse(f"dense<[{sanitizer}]> : {sanitizer_type}")
|
||||
|
||||
# pylint: disable=too-many-function-args
|
||||
sanitizer_cst = arith.ConstantOp(sanitizer_type, sanitizer_attr)
|
||||
# pylint: enable=too-many-function-args
|
||||
|
||||
resulting_type = NodeConverter.value_to_mlir_type(ctx, input_value)
|
||||
if input_value.is_scalar:
|
||||
sanitized = fhe.SubEintIntOp(resulting_type, arg, sanitizer_cst).result
|
||||
else:
|
||||
sanitized = fhelinalg.SubEintIntOp(resulting_type, arg, sanitizer_cst).result
|
||||
|
||||
sanitized_args.append(sanitized)
|
||||
else:
|
||||
sanitized_args.append(arg)
|
||||
|
||||
return sanitized_args
|
||||
|
||||
@staticmethod
|
||||
def convert(graph: Graph) -> str:
|
||||
"""
|
||||
Convert a computation graph to its corresponding MLIR representation.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to be converted
|
||||
|
||||
Returns:
|
||||
str:
|
||||
textual MLIR representation corresponding to `graph`
|
||||
"""
|
||||
|
||||
graph = deepcopy(graph)
|
||||
|
||||
GraphConverter._check_graph_convertibility(graph)
|
||||
GraphConverter._update_bit_widths(graph)
|
||||
GraphConverter._offset_negative_lookup_table_inputs(graph)
|
||||
GraphConverter._broadcast_assignments(graph)
|
||||
GraphConverter._encrypt_clear_assignments(graph)
|
||||
GraphConverter._tensorize_scalars_for_fhelinalg(graph)
|
||||
|
||||
from_elements_operations: Dict[OpResult, List[OpResult]] = {}
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
concretelang.register_dialects(ctx)
|
||||
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
parameters = [
|
||||
NodeConverter.value_to_mlir_type(ctx, input_node.output)
|
||||
for input_node in graph.ordered_inputs()
|
||||
]
|
||||
|
||||
@func.FuncOp.from_py_func(*parameters)
|
||||
def main(*args):
|
||||
sanitized_args = GraphConverter._sanitize_signed_inputs(graph, args, ctx)
|
||||
|
||||
ir_to_mlir = {}
|
||||
for arg_num, node in graph.input_nodes.items():
|
||||
ir_to_mlir[node] = sanitized_args[arg_num]
|
||||
|
||||
constant_cache = {}
|
||||
for node in nx.topological_sort(graph.graph):
|
||||
if node.operation == Operation.Input:
|
||||
continue
|
||||
|
||||
preds = [ir_to_mlir[pred] for pred in graph.ordered_preds_of(node)]
|
||||
node_converter = NodeConverter(
|
||||
ctx,
|
||||
graph,
|
||||
node,
|
||||
preds,
|
||||
constant_cache,
|
||||
from_elements_operations,
|
||||
)
|
||||
ir_to_mlir[node] = node_converter.convert()
|
||||
|
||||
results = (ir_to_mlir[output_node] for output_node in graph.ordered_outputs())
|
||||
return results
|
||||
|
||||
direct_replacements = {}
|
||||
for placeholder, elements in from_elements_operations.items():
|
||||
element_names = [NodeConverter.mlir_name(element) for element in elements]
|
||||
actual_value = f"tensor.from_elements {', '.join(element_names)} : {placeholder.type}"
|
||||
direct_replacements[NodeConverter.mlir_name(placeholder)] = actual_value
|
||||
|
||||
module_lines_after_hacks_are_applied = []
|
||||
for line in str(module).split("\n"):
|
||||
mlir_name = line.split("=")[0].strip()
|
||||
if mlir_name not in direct_replacements:
|
||||
module_lines_after_hacks_are_applied.append(line)
|
||||
continue
|
||||
|
||||
new_value = direct_replacements[mlir_name]
|
||||
module_lines_after_hacks_are_applied.append(f" {mlir_name} = {new_value}")
|
||||
|
||||
return "\n".join(module_lines_after_hacks_are_applied).strip()
1866
frontends/concrete-python/concrete/numpy/mlir/node_converter.py
Normal file
171
frontends/concrete-python/concrete/numpy/mlir/utils.py
Normal file
@@ -0,0 +1,171 @@
"""
Declaration of various functions and constants related to MLIR conversion.
"""

from collections import defaultdict, deque
from copy import deepcopy
from itertools import product
from typing import Any, DefaultDict, List, Optional, Tuple, Union, cast

import numpy as np

from ..dtypes import Integer
from ..internal.utils import assert_that
from ..representation import Node, Operation

MAXIMUM_TLU_BIT_WIDTH = 16


class HashableNdarray:
    """
    HashableNdarray class, to use numpy arrays in dictionaries.
    """

    array: np.ndarray

    def __init__(self, array: np.ndarray):
        self.array = array

    def __eq__(self, other: object) -> bool:
        return isinstance(other, HashableNdarray) and np.array_equal(self.array, other.array)

    def __hash__(self) -> int:
        return hash(self.array.tobytes())


def flood_replace_none_values(table: list):
    """
    Use flooding algorithm to replace `None` values.

    Args:
        table (list):
            the list in which there are `None` values that need to be replaced
            with copies of the closest non `None` data from the list
    """

    assert_that(any(value is not None for value in table))

    not_none_values_idx = deque(idx for idx, value in enumerate(table) if value is not None)
    while not_none_values_idx:
        current_idx = not_none_values_idx.popleft()
        current_value = table[current_idx]

        previous_idx = current_idx - 1
        next_idx = current_idx + 1

        if previous_idx >= 0 and table[previous_idx] is None:
            table[previous_idx] = deepcopy(current_value)
            not_none_values_idx.append(previous_idx)

        if next_idx < len(table) and table[next_idx] is None:
            table[next_idx] = deepcopy(current_value)
            not_none_values_idx.append(next_idx)

    assert_that(all(value is not None for value in table))
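
# A minimal sketch of the flooding behavior on a plain list: `None` gaps are
# filled with copies of the nearest non-`None` value:
#
#     table = [None, 2, None, None, 7, None]
#     flood_replace_none_values(table)
#     assert table == [2, 2, 2, 7, 7, 7]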


def construct_table(node: Node, preds: List[Node]) -> List[Any]:
    """
    Construct the lookup table for an Operation.Generic node.

    Args:
        node (Node):
            Operation.Generic to construct the table

        preds (List[Node]):
            ordered predecessors to `node`

    Returns:
        List[Any]:
            lookup table corresponding to `node` and its input value
    """

    variable_input_index = -1
    for index, pred in enumerate(preds):
        if pred.operation != Operation.Constant:
            variable_input_index = index
            break
    assert_that(variable_input_index != -1)

    variable_input_dtype = node.inputs[variable_input_index].dtype
    variable_input_shape = node.inputs[variable_input_index].shape

    assert_that(isinstance(variable_input_dtype, Integer))
    variable_input_dtype = cast(Integer, variable_input_dtype)

    inputs: List[Any] = [pred() if pred.operation == Operation.Constant else None for pred in preds]

    table: List[Optional[Union[np.bool_, np.integer, np.floating, np.ndarray]]] = []
    for value in range(variable_input_dtype.min(), variable_input_dtype.max() + 1):
        try:
            inputs[variable_input_index] = np.ones(variable_input_shape, dtype=np.int64) * value
            table.append(node(*inputs))
        except Exception:  # pylint: disable=broad-except
            # here we try our best to fill the table
            # if it fails, we append None and let the flooding algorithm replace None values below
            table.append(None)

    flood_replace_none_values(table)

    return table
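
# Conceptually, for a variable input of dtype uint3 and a node computing
# x**2 mod 8, the constructed table is just the operation evaluated on every
# representable input value (a sketch, outside the Node machinery):
#
#     table = [(x * x) % 8 for x in range(0, 8)]  # dtype.min() .. dtype.max()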


def construct_deduplicated_tables(
    node: Node,
    preds: List[Node],
) -> Tuple[Tuple[np.ndarray, List[Tuple[int, ...]]], ...]:
    """
    Construct lookup tables for each cell of the input for an Operation.Generic node.

    Args:
        node (Node):
            Operation.Generic to construct the table

        preds (List[Node]):
            ordered predecessors to `node`

    Returns:
        Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]:
            tuple containing tuples of 2 for
            - constructed table
            - list of indices of the input that use the constructed table

            e.g.,

            .. code-block:: python

                (
                    (np.array([3, 1, 2, 4]), [(1, 0), (2, 1)]),
                    (np.array([5, 8, 6, 7]), [(0, 0), (0, 1), (1, 1), (2, 0)]),
                )

            means the lookup on 3x2 input will result in

            .. code-block:: python

                [ [5, 8, 6, 7][input[0, 0]] , [5, 8, 6, 7][input[0, 1]] ]
                [ [3, 1, 2, 4][input[1, 0]] , [5, 8, 6, 7][input[1, 1]] ]
                [ [5, 8, 6, 7][input[2, 0]] , [3, 1, 2, 4][input[2, 1]] ]
    """

    node_complete_table = np.concatenate(
        tuple(np.expand_dims(array, -1) for array in construct_table(node, preds)),
        axis=-1,
    )

    all_cells_idx = product(*tuple(range(max_val) for max_val in node_complete_table.shape[:-1]))
    tables_to_cell_idx: DefaultDict[HashableNdarray, List[Tuple[int, ...]]] = defaultdict(list)

    idx: Tuple[int, ...]
    all_idx_set = set()
    for idx in all_cells_idx:
        hashable_array = HashableNdarray(node_complete_table[idx])
        tables_to_cell_idx[hashable_array].append(idx)
        all_idx_set.add(idx)

    assert_that(len(all_idx_set) == np.prod(node_complete_table.shape[:-1]))

    return tuple(
        (hashable_array.array, indices) for hashable_array, indices in tables_to_cell_idx.items()
    )
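
# A standalone sketch of the deduplication step, assuming a complete table of
# shape (cells..., table_size); identical per-cell tables share one entry:
#
#     import numpy as np
#     from collections import defaultdict
#     from itertools import product
#
#     complete = np.array([[[5, 8], [3, 1]], [[5, 8], [5, 8]]])
#     groups = defaultdict(list)
#     for idx in product(*(range(n) for n in complete.shape[:-1])):
#         groups[HashableNdarray(complete[idx])].append(idx)
#     # -> {[5, 8]: [(0, 0), (1, 0), (1, 1)], [3, 1]: [(0, 1)]}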
@@ -0,0 +1,7 @@
"""
Define structures used to represent computation.
"""

from .graph import Graph
from .node import Node
from .operation import Operation
@@ -0,0 +1,52 @@
"""
Declaration of various `Evaluator` classes, to make graphs picklable.
"""

# ruff: noqa: ARG002


class ConstantEvaluator:
    """
    ConstantEvaluator class, to evaluate Operation.Constant nodes.
    """

    def __init__(self, properties):
        self.properties = properties

    def __call__(self, *args, **kwargs):
        return self.properties["constant"]


class InputEvaluator:
    """
    InputEvaluator class, to evaluate Operation.Input nodes.
    """

    def __call__(self, *args, **kwargs):
        return args[0]


class GenericEvaluator:
    """
    GenericEvaluator class, to evaluate Operation.Generic nodes.
    """

    def __init__(self, operation, properties):
        self.operation = operation
        self.properties = properties

    def __call__(self, *args, **kwargs):
        return self.operation(*args, *self.properties["args"], **self.properties["kwargs"])


class GenericTupleEvaluator:
    """
    GenericTupleEvaluator class, to evaluate Operation.Generic nodes where args are packed in a tuple.
    """

    def __init__(self, operation, properties):
        self.operation = operation
        self.properties = properties

    def __call__(self, *args, **kwargs):
        return self.operation(tuple(args), *self.properties["args"], **self.properties["kwargs"])
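
# A minimal usage sketch: evaluators close over `properties` so that graphs
# stay picklable (no lambdas are stored on nodes):
#
#     import numpy as np
#
#     evaluator = GenericEvaluator(np.add, {"args": (), "kwargs": {}})
#     assert evaluator(2, 3) == 5
#
#     constant = ConstantEvaluator({"constant": np.int64(42)})
#     assert constant() == 42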
692
frontends/concrete-python/concrete/numpy/representation/graph.py
Normal file
@@ -0,0 +1,692 @@
"""
Declaration of `Graph` class.
"""

import math
import re
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import networkx as nx
import numpy as np
import scipy.special

from ..dtypes import Float, Integer, UnsignedInteger
from .node import Node
from .operation import Operation

P_ERROR_PER_ERROR_SIZE_CACHE: Dict[float, Dict[int, float]] = {}


class Graph:
    """
    Graph class, to represent computation graphs.
    """

    graph: nx.MultiDiGraph

    input_nodes: Dict[int, Node]
    output_nodes: Dict[int, Node]

    input_indices: Dict[Node, int]

    is_direct: bool

    def __init__(
        self,
        graph: nx.MultiDiGraph,
        input_nodes: Dict[int, Node],
        output_nodes: Dict[int, Node],
        is_direct: bool = False,
    ):
        self.graph = graph

        self.input_nodes = input_nodes
        self.output_nodes = output_nodes

        self.input_indices = {node: index for index, node in input_nodes.items()}

        self.is_direct = is_direct

        self.prune_useless_nodes()

    def __call__(
        self,
        *args: Any,
        p_error: Optional[float] = None,
    ) -> Union[
        np.bool_,
        np.integer,
        np.floating,
        np.ndarray,
        Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
    ]:
        evaluation = self.evaluate(*args, p_error=p_error)
        result = tuple(evaluation[node] for node in self.ordered_outputs())
        return result if len(result) > 1 else result[0]

    def evaluate(
        self,
        *args: Any,
        p_error: Optional[float] = None,
    ) -> Dict[Node, Union[np.bool_, np.integer, np.floating, np.ndarray]]:
        r"""
        Perform the computation `Graph` represents and get resulting values for all nodes.

        Args:
            *args (List[Any]):
                inputs to the computation

            p_error (Optional[float]):
                probability of error for table lookups

        Returns:
            Dict[Node, Union[np.bool\_, np.integer, np.floating, np.ndarray]]:
                nodes and their values during computation
        """

        # pylint: disable=no-member,too-many-nested-blocks,too-many-branches,too-many-statements

        if p_error is None:
            p_error = 0.0

        assert isinstance(p_error, float)

        node_results: Dict[Node, Union[np.bool_, np.integer, np.floating, np.ndarray]] = {}
        for node in nx.topological_sort(self.graph):
            if node.operation == Operation.Input:
                node_results[node] = node(args[self.input_indices[node]])
                continue

            pred_results = [node_results[pred] for pred in self.ordered_preds_of(node)]

            if p_error > 0.0 and node.converted_to_table_lookup:
                variable_input_indices = [
                    idx
                    for idx, pred in enumerate(self.ordered_preds_of(node))
                    if not pred.operation == Operation.Constant
                ]

                for index in variable_input_indices:
                    pred_node = self.ordered_preds_of(node)[index]
                    if pred_node.operation != Operation.Input:
                        dtype = node.inputs[index].dtype
                        if isinstance(dtype, Integer):
                            # see https://github.com/zama-ai/concrete-numpy/blob/main/docs/_static/p_error_simulation.pdf # noqa: E501 # pylint: disable=line-too-long
                            # to learn more about the distribution of error

                            if p_error not in P_ERROR_PER_ERROR_SIZE_CACHE:
                                std_score = math.sqrt(2) * scipy.special.erfcinv(p_error)
                                p_error_per_error_size = {}

                                error_size = 1
                                last_p = 1 - p_error
                                while last_p != 1.0 or error_size == 1:
                                    new_std_score = (2 * error_size + 1) * std_score
                                    new_p = scipy.special.erf(new_std_score / math.sqrt(2))

                                    p_error_per_error_size[error_size] = new_p - last_p

                                    last_p = new_p
                                    error_size += 1

                                # ordering of `p_error_per_error_size` is relied on
                                # during the introduction of the error below
                                # thus we explicitly sort it to make sure it's ordered
                                p_error_per_error_size = dict(
                                    sorted(p_error_per_error_size.items())
                                )

                                P_ERROR_PER_ERROR_SIZE_CACHE[p_error] = p_error_per_error_size
                            else:  # pragma: no cover
                                p_error_per_error_size = P_ERROR_PER_ERROR_SIZE_CACHE[p_error]

                            error = np.random.rand(*pred_results[index].shape)

                            accumulated_p_error = 0.0
                            for error_size, p_error_for_size in p_error_per_error_size.items():
                                accumulated_p_error += p_error_for_size
                                error = np.where(error < accumulated_p_error, error_size, error)

                            error = np.where(error < 1, 0, error).astype(np.int64)

                            error_sign = np.random.rand(*pred_results[index].shape)
                            error_sign = np.where(error_sign < 0.5, 1, -1).astype(np.int64)

                            new_result = pred_results[index] + (error * error_sign)

                            if new_result.shape == ():  # pragma: no cover
                                if new_result < dtype.min():
                                    new_result = dtype.max() - (dtype.min() - new_result) + 1
                                elif new_result > dtype.max():
                                    new_result = dtype.min() - (new_result - dtype.max()) - 1

                            else:
                                underflow_indices = np.where(new_result < dtype.min())
                                new_result[underflow_indices] = (
                                    dtype.max() - (dtype.min() - new_result[underflow_indices]) + 1
                                )

                                overflow_indices = np.where(new_result > dtype.max())
                                new_result[overflow_indices] = (
                                    dtype.min() + (new_result[overflow_indices] - dtype.max()) - 1
                                )

                            pred_results[index] = new_result

            try:
                node_results[node] = node(*pred_results)
            except Exception as error:
                raise RuntimeError(
                    "Evaluation of the graph failed\n\n"
                    + self.format(
                        highlighted_nodes={node: ["evaluation of this node failed"]},
                        show_bounds=False,
                    )
                ) from error

        return node_results
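
    # A standalone sketch of the error-size distribution computed above: the
    # probability of an off-by-k lookup result is derived from Gaussian tails
    # (same math as the cached loop, runnable with scipy only):
    #
    #     import math
    #     import scipy.special
    #
    #     p_error = 0.01
    #     std_score = math.sqrt(2) * scipy.special.erfcinv(p_error)
    #     last_p, error_size, dist = 1 - p_error, 1, {}
    #     while last_p != 1.0 or error_size == 1:
    #         new_p = scipy.special.erf((2 * error_size + 1) * std_score / math.sqrt(2))
    #         dist[error_size] = new_p - last_p
    #         last_p, error_size = new_p, error_size + 1
    #     # dist[k] approximates the probability of an off-by-k lookup result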

    def format(
        self,
        maximum_constant_length: int = 25,
        highlighted_nodes: Optional[Dict[Node, List[str]]] = None,
        show_types: bool = True,
        show_bounds: bool = True,
        show_tags: bool = True,
        show_locations: bool = False,
    ) -> str:
        """
        Get the textual representation of the `Graph`.

        Args:
            maximum_constant_length (int, default = 25):
                maximum length of formatted constants

            highlighted_nodes (Optional[Dict[Node, List[str]]], default = None):
                nodes to be highlighted and their corresponding messages

            show_types (bool, default = True):
                whether to show types of nodes

            show_bounds (bool, default = True):
                whether to show bounds of nodes

            show_tags (bool, default = True):
                whether to show tags of nodes

            show_locations (bool, default = False):
                whether to show line information of nodes

        Returns:
            str:
                textual representation of the `Graph`
        """

        # pylint: disable=too-many-branches,too-many-locals,too-many-statements
        # ruff: noqa: ERA001

        if self.is_direct:
            show_bounds = False

        # node -> identifier
        # e.g., id_map[node1] = 2
        # means line for node1 is in this form %2 = node1.format(...)
        id_map: Dict[Node, int] = {}

        # lines that will be merged at the end
        lines: List[str] = []

        # metadata to add to each line
        # (for alignment, this is done after lines are determined)
        line_metadata: List[Dict[str, str]] = []

        # default highlighted nodes is empty
        highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {}

        # highlight information for lines, this is required because highlights are added to lines
        # after their type information is added, and we only have line numbers, not nodes
        highlighted_lines: Dict[int, List[str]] = {}

        # subgraphs to format after the main graph is formatted
        subgraphs: Dict[str, Graph] = {}

        # format nodes
        for node in nx.lexicographical_topological_sort(self.graph):
            # assign a unique id to outputs of node
            id_map[node] = len(id_map)

            # remember highlights of the node
            if node in highlighted_nodes:
                highlighted_lines[len(lines)] = highlighted_nodes[node]

            # extract predecessors and their ids
            predecessors = []
            for predecessor in self.ordered_preds_of(node):
                predecessors.append(f"%{id_map[predecessor]}")

            # start to build the line for the node
            line = ""

            # add output information to the line
            line += f"%{id_map[node]}"

            # add node information to the line
            line += " = "
            line += node.format(predecessors, maximum_constant_length)

            # append line to list of lines
            lines.append(line)

            # if exists, save the subgraph
            if node.operation == Operation.Generic and "subgraph" in node.properties["kwargs"]:
                subgraphs[line] = node.properties["kwargs"]["subgraph"]

            # get formatted bounds
            bounds = ""
            if node.bounds is not None:
                bounds += "∈ ["

                lower, upper = node.bounds
                assert type(lower) == type(upper)  # pylint: disable=unidiomatic-typecheck

                if isinstance(lower, (float, np.float32, np.float64)):
                    bounds += f"{round(lower, 6)}, {round(upper, 6)}"
                else:
                    bounds += f"{int(lower)}, {int(upper)}"

                bounds += "]"

            # remember metadata of the node
            line_metadata.append(
                {
                    "type": f"# {node.output}",
                    "bounds": bounds,
                    "tag": (f"@ {node.tag}" if node.tag != "" else ""),
                    "location": node.location,
                },
            )

        # align = signs
        #
        # e.g.,
        #
        #  %1 = ...
        #  %2 = ...
        #  ...
        #  %8 = ...
        #  %9 = ...
        # %10 = ...
        # %11 = ...
        # ...
        longest_length_before_equals_sign = max(len(line.split("=")[0]) for line in lines)
        for i, line in enumerate(lines):
            length_before_equals_sign = len(line.split("=")[0])
            lines[i] = (
                " " * (longest_length_before_equals_sign - length_before_equals_sign)
            ) + line

        # determine which metadata to show
        shown_metadata_keys = []
        if show_types:
            shown_metadata_keys.append("type")
        if show_bounds:
            shown_metadata_keys.append("bounds")
        if show_tags:
            shown_metadata_keys.append("tag")
        if show_locations:
            shown_metadata_keys.append("location")

        # show requested metadata
        indent = 8
        for metadata_key in shown_metadata_keys:
            longest_line_length = max(len(line) for line in lines)
            lines = [
                line + (" " * ((longest_line_length - len(line)) + indent)) + metadata[metadata_key]
                for line, metadata in zip(lines, line_metadata)
            ]

        # strip whitespaces
        lines = [line.rstrip() for line in lines]

        # add highlights (this is done in reverse to keep indices consistent)
        for i in reversed(range(len(lines))):
            if i in highlighted_lines:
                for j, message in enumerate(highlighted_lines[i]):
                    highlight = "^" if j == 0 else " "
                    lines.insert(i + 1 + j, f"{highlight * len(lines[i])} {message}")

        # add return information
        # (if there is a single return, it's in the form `return %id`)
        # (otherwise, it's in the form `return (%id1, %id2, ..., %idN)`)
        returns: List[str] = []
        for node in self.output_nodes.values():
            returns.append(f"%{id_map[node]}")
        lines.append("return " + (returns[0] if len(returns) == 1 else f"({', '.join(returns)})"))

        # format subgraphs after the actual graph
        result = "\n".join(lines)
        if len(subgraphs) > 0:
            result += "\n\n"
            result += "Subgraphs:"
            for line, subgraph in subgraphs.items():
                subgraph_lines = subgraph.format(
                    maximum_constant_length=maximum_constant_length,
                    highlighted_nodes={},
                    show_types=show_types,
                    show_bounds=False,  # doesn't make sense as we don't measure bounds in subgraphs
                    show_tags=show_tags,
                    show_locations=show_locations,
                ).split("\n")

                result += "\n\n"
                result += f" {line}:\n\n"
                result += "\n".join(f" {line}" for line in subgraph_lines)

        return result

    # pylint: enable=too-many-branches,too-many-locals,too-many-statements
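
    # For orientation, the formatted output of a small graph looks roughly like
    # the following (type annotations depend on `Value.__str__`, so treat this
    # as an illustration rather than exact output):
    #
    #     %0 = x                  # EncryptedScalar<uint3>
    #     %1 = 42                 # ClearScalar<uint6>
    #     %2 = add(%0, %1)        # EncryptedScalar<uint7>
    #     return %2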

    def measure_bounds(
        self,
        inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
    ) -> Dict[Node, Dict[str, Union[np.integer, np.floating]]]:
        """
        Evaluate the `Graph` using an inputset and measure bounds.

        inputset is either an iterable of anything
        for a single parameter

        or

        an iterable of tuples of anything (of rank number of parameters)
        for multiple parameters

        e.g.,

        .. code-block:: python

            inputset = [1, 3, 5, 2, 4]
            def f(x):
                ...

            inputset = [(1, 2), (2, 4), (3, 1), (2, 2)]
            def g(x, y):
                ...

        Args:
            inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]):
                inputset to use

        Returns:
            Dict[Node, Dict[str, Union[np.integer, np.floating]]]:
                bounds of each node in the `Graph`
        """

        bounds = {}

        inputset_iterator = iter(inputset)

        sample = next(inputset_iterator)
        if not isinstance(sample, tuple):
            sample = (sample,)

        index = 0
        try:
            evaluation = self.evaluate(*sample)
            for node, value in evaluation.items():
                bounds[node] = {
                    "min": value.min(),
                    "max": value.max(),
                }

            for sample in inputset_iterator:
                index += 1
                if not isinstance(sample, tuple):
                    sample = (sample,)

                evaluation = self.evaluate(*sample)
                for node, value in evaluation.items():
                    bounds[node] = {
                        "min": np.minimum(bounds[node]["min"], value.min()),
                        "max": np.maximum(bounds[node]["max"], value.max()),
                    }

        except Exception as error:
            message = f"Bound measurement using inputset[{index}] failed"
            raise RuntimeError(message) from error

        return bounds

    def update_with_bounds(self, bounds: Dict[Node, Dict[str, Union[np.integer, np.floating]]]):
        """
        Update `Value`s within the `Graph` according to measured bounds.

        Args:
            bounds (Dict[Node, Dict[str, Union[np.integer, np.floating]]]):
                bounds of each node in the `Graph`
        """

        for node in self.graph.nodes():
            if node in bounds:
                min_bound = bounds[node]["min"]
                max_bound = bounds[node]["max"]

                node.bounds = (min_bound, max_bound)

                new_value = deepcopy(node.output)

                if isinstance(min_bound, np.integer):
                    new_value.dtype = Integer.that_can_represent(np.array([min_bound, max_bound]))
                else:
                    new_value.dtype = {
                        np.bool_: UnsignedInteger(1),
                        np.float64: Float(64),
                        np.float32: Float(32),
                        np.float16: Float(16),
                    }[type(min_bound)]

                node.output = new_value

                if node.operation == Operation.Input:
                    node.inputs[0] = new_value

                for successor in self.graph.successors(node):
                    edge_data = self.graph.get_edge_data(node, successor)
                    for edge in edge_data.values():
                        input_idx = edge["input_idx"]
                        successor.inputs[input_idx] = node.output

    def ordered_inputs(self) -> List[Node]:
        """
        Get the input nodes of the `Graph`, ordered by their indices.

        Returns:
            List[Node]:
                ordered input nodes
        """

        return [self.input_nodes[idx] for idx in range(len(self.input_nodes))]

    def ordered_outputs(self) -> List[Node]:
        """
        Get the output nodes of the `Graph`, ordered by their indices.

        Returns:
            List[Node]:
                ordered output nodes
        """

        return [self.output_nodes[idx] for idx in range(len(self.output_nodes))]

    def ordered_preds_of(self, node: Node) -> List[Node]:
        """
        Get predecessors of `node`, ordered by their indices.

        Args:
            node (Node):
                node whose predecessors are requested

        Returns:
            List[Node]:
                ordered predecessors of `node`
        """

        idx_to_pred: Dict[int, Node] = {}
        for pred in self.graph.predecessors(node):
            edge_data = self.graph.get_edge_data(pred, node)
            idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values())
        return [idx_to_pred[i] for i in range(len(idx_to_pred))]

    def prune_useless_nodes(self):
        """
        Remove unreachable nodes from the graph.
        """

        useful_nodes: Dict[Node, None] = {}

        current_nodes = {node: None for node in self.ordered_outputs()}
        while current_nodes:
            useful_nodes.update(current_nodes)

            next_nodes: Dict[Node, None] = {}
            for node in current_nodes:
                next_nodes.update({node: None for node in self.graph.predecessors(node)})

            current_nodes = next_nodes

        useless_nodes = [node for node in self.graph.nodes() if node not in useful_nodes]
        self.graph.remove_nodes_from(useless_nodes)

    def query_nodes(
        self,
        tag_filter: Optional[Union[str, List[str], re.Pattern]] = None,
        operation_filter: Optional[Union[str, List[str], re.Pattern]] = None,
    ) -> List[Node]:
        """
        Query nodes within the graph.

        Filters work like so:
            str -> nodes without an exact match are skipped
            List[str] -> nodes without an exact match with one of the strings in the list are skipped
            re.Pattern -> nodes without a pattern match are skipped

        Args:
            tag_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
                filter for tags

            operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
                filter for operations

        Returns:
            List[Node]:
                filtered nodes
        """

        def match_text_filter(text_filter, text):
            if text_filter is None:
                return True

            if isinstance(text_filter, str):
                return text == text_filter

            if isinstance(text_filter, re.Pattern):
                return text_filter.match(text)

            return any(text == alternative for alternative in text_filter)

        def get_operation_name(node):
            result: str

            if node.operation == Operation.Input:
                result = "input"
            elif node.operation == Operation.Constant:
                result = "constant"
            else:
                result = node.properties["name"]

            return result

        return [
            node
            for node in self.graph.nodes()
            if (
                match_text_filter(tag_filter, node.tag)
                and match_text_filter(operation_filter, get_operation_name(node))
            )
        ]

    def maximum_integer_bit_width(
        self,
        tag_filter: Optional[Union[str, List[str], re.Pattern]] = None,
        operation_filter: Optional[Union[str, List[str], re.Pattern]] = None,
    ) -> int:
        """
        Get maximum integer bit-width within the graph.

        Only nodes after filtering will be used to calculate the result.

        Args:
            tag_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
                filter for tags

            operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
                filter for operations

        Returns:
            int:
                maximum integer bit-width within the graph
                if there are no integer nodes matching the query, result is -1
        """

        filtered_bit_widths = (
            node.output.dtype.bit_width
            for node in self.query_nodes(tag_filter, operation_filter)
            if isinstance(node.output.dtype, Integer)
        )
        return max(filtered_bit_widths, default=-1)

    def integer_range(
        self,
        tag_filter: Optional[Union[str, List[str], re.Pattern]] = None,
        operation_filter: Optional[Union[str, List[str], re.Pattern]] = None,
    ) -> Optional[Tuple[int, int]]:
        """
        Get the integer range of the graph.

        Only nodes after filtering will be used to calculate the result.

        Args:
            tag_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
                filter for tags

            operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
                filter for operations

        Returns:
            Optional[Tuple[int, int]]:
                minimum and maximum integer value observed during inputset evaluation
                if there are no integer nodes matching the query, result is None
        """

        result: Optional[Tuple[int, int]] = None

        if not self.is_direct:
            filtered_bounds = (
                node.bounds
                for node in self.query_nodes(tag_filter, operation_filter)
                if isinstance(node.output.dtype, Integer) and node.bounds is not None
            )
            for min_bound, max_bound in filtered_bounds:
                assert isinstance(min_bound, np.integer) and isinstance(max_bound, np.integer)

                if result is None:
                    result = (int(min_bound), int(max_bound))
                else:
                    old_min_bound, old_max_bound = result  # pylint: disable=unpacking-non-sequence
                    result = (
                        min(old_min_bound, int(min_bound)),
                        max(old_max_bound, int(max_bound)),
                    )

        return result
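
    # A minimal usage sketch for the filters (the tag and operation names below
    # are hypothetical):
    #
    #     import re
    #
    #     graph.query_nodes(operation_filter="matmul")
    #     graph.query_nodes(tag_filter=["dense1", "dense2"])
    #     graph.query_nodes(tag_filter=re.compile(r"dense\..*"), operation_filter="add")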
434
frontends/concrete-python/concrete/numpy/representation/node.py
Normal file
@@ -0,0 +1,434 @@
"""
Declaration of `Node` class.
"""

import os
import time
import traceback
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np

from ..internal.utils import assert_that
from ..values import Value
from .evaluator import ConstantEvaluator, GenericEvaluator, GenericTupleEvaluator, InputEvaluator
from .operation import Operation
from .utils import KWARGS_IGNORED_IN_FORMATTING, format_constant, format_indexing_element


class Node:
    """
    Node class, to represent computation in a computation graph.
    """

    inputs: List[Value]
    output: Value

    operation: Operation
    evaluator: Callable

    bounds: Optional[Tuple[Union[int, float], Union[int, float]]]
    properties: Dict[str, Any]

    location: str
    tag: str
    created_at: float

    @staticmethod
    def constant(constant: Any) -> "Node":
        """
        Create an Operation.Constant node.

        Args:
            constant (Any):
                constant to represent

        Returns:
            Node:
                node representing constant

        Raises:
            ValueError:
                if the constant is not representable
        """

        try:
            value = Value.of(constant)
        except Exception as error:
            message = f"Constant {repr(constant)} is not supported"
            raise ValueError(message) from error

        properties = {"constant": np.array(constant)}
        return Node([], value, Operation.Constant, ConstantEvaluator(properties), properties)
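
    # A minimal usage sketch: constants are stored as numpy arrays and the node
    # is evaluated with no arguments:
    #
    #     node = Node.constant(42)
    #     assert node.operation == Operation.Constant
    #     assert node() == np.array(42)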

    @staticmethod
    def generic(
        name: str,
        inputs: List[Value],
        output: Value,
        operation: Callable,
        args: Optional[Tuple[Any, ...]] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        attributes: Optional[Dict[str, Any]] = None,
    ):
        """
        Create an Operation.Generic node.

        Args:
            name (str):
                name of the operation

            inputs (List[Value]):
                inputs to the operation

            output (Value):
                output of the operation

            operation (Callable):
                operation itself

            args (Optional[Tuple[Any, ...]]):
                args to pass to operation during evaluation

            kwargs (Optional[Dict[str, Any]]):
                kwargs to pass to operation during evaluation

            attributes (Optional[Dict[str, Any]]):
                attributes of the operation

        Returns:
            Node:
                node representing operation
        """

        properties = {
            "name": name,
            "args": args if args is not None else (),
            "kwargs": kwargs if kwargs is not None else {},
            "attributes": attributes if attributes is not None else {},
        }

        return Node(
            inputs,
            output,
            Operation.Generic,
            (
                GenericTupleEvaluator(operation, properties)  # type: ignore
                if name in ["concatenate"]
                else GenericEvaluator(operation, properties)  # type: ignore
            ),
            properties,
        )

    @staticmethod
    def input(name: str, value: Value) -> "Node":
        """
        Create an Operation.Input node.

        Args:
            name (str):
                name of the input

            value (Value):
                value of the input

        Returns:
            Node:
                node representing input
        """

        return Node([value], value, Operation.Input, InputEvaluator(), {"name": name})

    def __init__(
        self,
        inputs: List[Value],
        output: Value,
        operation: Operation,
        evaluator: Callable,
        properties: Optional[Dict[str, Any]] = None,
    ):
        self.inputs = inputs
        self.output = output

        self.operation = operation
        self.evaluator = evaluator  # type: ignore

        self.bounds = None
        self.properties = properties if properties is not None else {}

        # pylint: disable=cyclic-import,import-outside-toplevel

        import concrete.numpy as cnp

        cnp_directory = os.path.dirname(cnp.__file__)

        import concrete.onnx as coonx

        coonx_directory = os.path.dirname(coonx.__file__)

        # pylint: enable=cyclic-import,import-outside-toplevel

        for frame in reversed(traceback.extract_stack()):
            if frame.filename == "<__array_function__ internals>":
                continue

            if frame.filename.startswith(cnp_directory):
                continue

            if frame.filename.startswith(coonx_directory):
                continue

            self.location = f"{frame.filename}:{frame.lineno}"
            break

        # pylint: disable=cyclic-import,import-outside-toplevel

        from ..extensions.tag import tag_context

        self.tag = ".".join(tag_context.stack)

        # pylint: enable=cyclic-import,import-outside-toplevel

        self.created_at = time.time()

    def __call__(self, *args: List[Any]) -> Union[np.bool_, np.integer, np.floating, np.ndarray]:
        def generic_error_message() -> str:
            result = f"Evaluation of {self.operation.value} '{self.label()}' node"
            if len(args) != 0:
                result += f" using {', '.join(repr(arg) for arg in args)}"
            return result

        if len(args) != len(self.inputs):
            message = f"{generic_error_message()} failed because of invalid number of arguments"
            raise ValueError(message)

        for arg, input_ in zip(args, self.inputs):
            try:
                arg_value = Value.of(arg)
            except Exception as error:
                arg_str = "the argument" if len(args) == 1 else f"argument {repr(arg)}"
                message = f"{generic_error_message()} failed because {arg_str} is not valid"
                raise ValueError(message) from error

            if input_.shape != arg_value.shape:
                arg_str = "the argument" if len(args) == 1 else f"argument {repr(arg)}"
                message = (
                    f"{generic_error_message()} failed because "
                    f"{arg_str} does not have the expected "
                    f"shape of {input_.shape}"
                )
                raise ValueError(message)

        result = self.evaluator(*args)

        if isinstance(result, int) and -(2**63) < result < (2**63) - 1:
            result = np.int64(result)
        if isinstance(result, float):
            result = np.float64(result)

        if isinstance(result, list):
            try:
                np_result = np.array(result)
                result = np_result
            except Exception:  # pylint: disable=broad-except
                # here we try our best to convert the list to np.ndarray
                # if it fails we raise the exception below
                pass

        if not isinstance(result, (np.bool_, np.integer, np.floating, np.ndarray)):
            message = (
                f"{generic_error_message()} resulted in {repr(result)} "
                f"of type {result.__class__.__name__} "
                f"which is not acceptable either because of the type or because of overflow"
            )
            raise ValueError(message)

        if isinstance(result, np.ndarray):
            dtype = result.dtype
            if (
                not np.issubdtype(dtype, np.integer)
                and not np.issubdtype(dtype, np.floating)
                and not np.issubdtype(dtype, np.bool_)
            ):
                message = (
                    f"{generic_error_message()} resulted in {repr(result)} "
                    f"of type np.ndarray and of underlying type '{type(dtype).__name__}' "
                    f"which is not acceptable because of the underlying type"
                )
                raise ValueError(message)

            if result.shape != self.output.shape:
                message = (
                    f"{generic_error_message()} resulted in {repr(result)} "
                    f"which does not have the expected "
                    f"shape of {self.output.shape}"
                )
                raise ValueError(message)

        return result

    def format(self, predecessors: List[str], maximum_constant_length: int = 45) -> str:
        """
        Get the textual representation of the `Node` (dependent on preds).

        Args:
            predecessors (List[str]):
                predecessor names to this node

            maximum_constant_length (int, default = 45):
                maximum length of formatted constants

        Returns:
            str:
                textual representation of the `Node` (dependent on preds)
        """

        if self.operation == Operation.Constant:
            return format_constant(self(), maximum_constant_length)

        if self.operation == Operation.Input:
            return self.properties["name"]

        assert_that(self.operation == Operation.Generic)

        name = self.properties["name"]

        if name == "index.static":
            index = self.properties["kwargs"]["index"]
            elements = [format_indexing_element(element) for element in index]
            return f"{predecessors[0]}[{', '.join(elements)}]"

        if name == "assign.static":
            index = self.properties["kwargs"]["index"]
            elements = [format_indexing_element(element) for element in index]
            return f"({predecessors[0]}[{', '.join(elements)}] = {predecessors[1]})"

        if name == "concatenate":
            args = [f"({', '.join(predecessors)})"]
        else:
            args = deepcopy(predecessors)

        if name == "array":
            values = str(np.array(predecessors).reshape(self.output.shape).tolist()).replace(
                "'", ""
            )
            return f"array({format_constant(values, maximum_constant_length)})"

        args.extend(
            format_constant(value, maximum_constant_length) for value in self.properties["args"]
        )
        args.extend(
            f"{name}={format_constant(value, maximum_constant_length)}"
            for name, value in self.properties["kwargs"].items()
            if name not in KWARGS_IGNORED_IN_FORMATTING
        )

        return f"{name}({', '.join(args)})"

    def label(self) -> str:
        """
        Get the textual representation of the `Node` (independent of preds).

        Returns:
            str:
                textual representation of the `Node` (independent of preds)
        """

        if self.operation == Operation.Constant:
            return format_constant(self(), maximum_length=45, keep_newlines=True)

        if self.operation == Operation.Input:
            return self.properties["name"]

        assert_that(self.operation == Operation.Generic)

        name = self.properties["name"]

        if name == "index.static":
            name = self.format(["□"])

        if name == "assign.static":
            name = self.format(["□", "□"])[1:-1]

        return name

    @property
    def converted_to_table_lookup(self) -> bool:
        """
        Get whether the node is converted to a table lookup during MLIR conversion.

        Returns:
            bool:
                True if the node is converted to a table lookup, False otherwise
        """

        if (
            all(value.is_encrypted for value in self.inputs)
            and self.operation == Operation.Generic
            and self.properties["name"]
            in [
                "bitwise_and",
                "bitwise_or",
                "bitwise_xor",
                "equal",
                "greater",
                "greater_equal",
                "left_shift",
                "less",
                "less_equal",
                "not_equal",
                "right_shift",
            ]
        ):
            return False

        return self.operation == Operation.Generic and self.properties["name"] not in [
            "add",
            "array",
            "assign.static",
            "broadcast_to",
            "concatenate",
            "conv1d",
            "conv2d",
            "conv3d",
            "dot",
            "expand_dims",
            "index.static",
            "matmul",
            "maxpool",
            "multiply",
            "negative",
            "ones",
            "reshape",
            "squeeze",
            "subtract",
            "sum",
            "transpose",
            "zeros",
        ]

    @property
    def is_fusable(self) -> bool:
        """
        Get whether the node can be fused into a table lookup.

        Returns:
            bool:
                True if the node can be fused into a table lookup, False otherwise
        """

        if self.converted_to_table_lookup:
            return True

        return self.operation != Operation.Generic or self.properties["name"] in [
            "add",
            "multiply",
            "negative",
            "ones",
            "subtract",
            "zeros",
        ]

    def __lt__(self, other) -> bool:
        return self.created_at < other.created_at
@@ -0,0 +1,29 @@
"""
Declaration of `Operation` enum.
"""

from enum import Enum


class Operation(Enum):
    """
    Operation enum, to distinguish nodes within a computation graph.
    """

    # pylint: disable=invalid-name

    Constant = "constant"
    Generic = "generic"
    Input = "input"

    # pylint: enable=invalid-name


# https://graphviz.org/doc/info/colors.html#svg

OPERATION_COLOR_MAPPING = {
    Operation.Constant: "grey",
    Operation.Generic: "black",
    Operation.Input: "crimson",
    "output": "gold",
}
114
frontends/concrete-python/concrete/numpy/representation/utils.py
Normal file
@@ -0,0 +1,114 @@
"""
Declaration of various functions and constants related to representation of computation.
"""

from typing import Any, Dict, Hashable, Set, Union

import numpy as np

from ..internal.utils import assert_that

KWARGS_IGNORED_IN_FORMATTING: Set[str] = {
    "subgraph",
    "terminal_node",
}

SPECIAL_OBJECT_MAPPING: Dict[Any, str] = {
    np.float16: "float16",
    np.float32: "float32",
    np.float64: "float64",
    np.int8: "int8",
    np.int16: "int16",
    np.int32: "int32",
    np.int64: "int64",
    np.uint8: "uint8",
    np.uint16: "uint16",
    np.uint32: "uint32",
    np.uint64: "uint64",
    np.byte: "byte",
    np.short: "short",
    np.intc: "intc",
    np.int_: "int_",
    np.longlong: "longlong",
    np.ubyte: "ubyte",
    np.ushort: "ushort",
    np.uintc: "uintc",
    np.uint: "uint",
    np.ulonglong: "ulonglong",
}


def format_constant(constant: Any, maximum_length: int = 45, keep_newlines: bool = False) -> str:
    """
    Get the textual representation of a constant.

    Args:
        constant (Any):
            constant to format

        maximum_length (int, default = 45):
            maximum length of the resulting string

        keep_newlines (bool, default = False):
            whether to keep newlines or not

    Returns:
        str:
            textual representation of `constant`
    """

    if isinstance(constant, Hashable) and constant in SPECIAL_OBJECT_MAPPING:
        return SPECIAL_OBJECT_MAPPING[constant]

    # maximum_length should not be smaller than 7 characters because
    # the constant will be formatted to `x ... y`
    # where x and y are part of the constant, and they are at least 1 character
    assert_that(maximum_length >= 7)

    result = str(constant)
    if not keep_newlines:
        result = result.replace("\n", "")

    if len(result) > maximum_length:
        from_start = (maximum_length - 5) // 2
        from_end = (maximum_length - 5) - from_start

        if keep_newlines and "\n" in result:
            result = f"{result[:from_start]}\n...\n{result[-from_end:]}"
        else:
            result = f"{result[:from_start]} ... {result[-from_end:]}"

    return result
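
# A minimal sketch of the truncation behavior:
#
#     assert format_constant("a" * 100, maximum_length=11) == "aaa ... aaa"
#     assert format_constant(np.int64, maximum_length=7) == "int64"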


def format_indexing_element(indexing_element: Union[int, np.integer, slice]):
    """
    Format an indexing element.

    This is required mainly for slices. The reason is that string representations of slices
    are very long and verbose. To give an example, `x[:, 2:]` will have the following index
    `[slice(None, None, None), slice(2, None, None)]` if printed naively. With this helper,
    it will be formatted as `[:, 2:]`.

    Args:
        indexing_element (Union[int, np.integer, slice]):
            indexing element to format

    Returns:
        str:
            textual representation of `indexing_element`
    """

    result = ""
    if isinstance(indexing_element, slice):
        if indexing_element.start is not None:
            result += str(indexing_element.start)
        result += ":"
        if indexing_element.stop is not None:
            result += str(indexing_element.stop)
        if indexing_element.step is not None:
            result += ":"
            result += str(indexing_element.step)
    else:
        result += str(indexing_element)
    return result.replace("\n", " ")
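
# A minimal sketch: slices are shortened to their python-literal form:
#
#     assert format_indexing_element(slice(None, None, None)) == ":"
#     assert format_indexing_element(slice(2, None, None)) == "2:"
#     assert format_indexing_element(slice(None, 5, 2)) == ":5:2"
#     assert format_indexing_element(3) == "3"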
@@ -0,0 +1,5 @@
"""
Provide `function` to `computation graph` functionality.
"""

from .tracer import ScalarAnnotation, TensorAnnotation, Tracer
867
frontends/concrete-python/concrete/numpy/tracing/tracer.py
Normal file
@@ -0,0 +1,867 @@
"""
Declaration of `Tracer` class.
"""

import inspect
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, cast

import networkx as nx
import numpy as np
from numpy.typing import DTypeLike

from ..dtypes import BaseDataType, Float, Integer
from ..internal.utils import assert_that
from ..representation import Graph, Node, Operation
from ..representation.utils import format_indexing_element
from ..values import Value


class Tracer:
    """
    Tracer class, to create computation graphs from python functions.
    """

    computation: Node
    input_tracers: List["Tracer"]
    output: Value

    # property to keep track of assignments
    last_version: Optional["Tracer"] = None

    # variables to control the behavior of certain functions
    _is_tracing: bool = False
    _is_direct: bool = False

    @staticmethod
    def trace(function: Callable, parameters: Dict[str, Value], is_direct: bool = False) -> Graph:
        """
        Trace `function` and create the `Graph` that represents it.

        Args:
            function (Callable):
                function to trace

            parameters (Dict[str, Value]):
                parameters of function to trace
                e.g. parameter x is an EncryptedScalar holding a 7-bit UnsignedInteger

            is_direct (bool, default = False):
                whether the tracing is done on actual parameters or placeholders

        Returns:
            Graph:
                computation graph corresponding to `function`
        """

        # pylint: disable=too-many-statements

        signature = inspect.signature(function)

        missing_args = list(signature.parameters)
        for arg in parameters.keys():
            missing_args.remove(arg)
        assert_that(len(missing_args) == 0)

        arguments = {}
        input_indices = {}

        for index, param in enumerate(signature.parameters.keys()):
            node = Node.input(param, parameters[param])
            arguments[param] = Tracer(node, [])
            input_indices[node] = index

        Tracer._is_direct = is_direct

        Tracer._is_tracing = True
        output_tracers: Any = function(**arguments)
        Tracer._is_tracing = False

        if not isinstance(output_tracers, tuple):
            output_tracers = (output_tracers,)

        output_tracer_list = list(output_tracers)
        for i, output_tracer in enumerate(output_tracer_list):
            if isinstance(output_tracer, Tracer) and output_tracer.last_version is not None:
                output_tracer_list[i] = output_tracer.last_version
        output_tracers = tuple(output_tracer_list)

        sanitized_tracers = []
        for tracer in output_tracers:
            if isinstance(tracer, Tracer):
                sanitized_tracers.append(tracer)
                continue

            try:
                sanitized_tracers.append(Tracer.sanitize(tracer))
            except Exception as error:
                message = (
                    f"Function '{function.__name__}' "
                    f"returned '{tracer}', "
                    f"which is not supported"
                )
                raise ValueError(message) from error

        output_tracers = tuple(sanitized_tracers)

        def create_graph_from_output_tracers(
            arguments: Dict[str, Tracer],
            output_tracers: Tuple[Tracer, ...],
        ) -> nx.MultiDiGraph:
            graph = nx.MultiDiGraph()

            visited_tracers: Set[Tracer] = set()
            current_tracers = {tracer: None for tracer in output_tracers}

            while current_tracers:
                next_tracers: Dict[Tracer, None] = {}
                for tracer in current_tracers:
                    if tracer not in visited_tracers:
                        current_node = tracer.computation
                        graph.add_node(current_node)

                        for input_idx, input_tracer in enumerate(tracer.input_tracers):
                            pred_node = input_tracer.computation

                            graph.add_node(pred_node)
                            graph.add_edge(
                                pred_node,
                                current_node,
                                input_idx=input_idx,
                            )

                            if input_tracer not in visited_tracers:
                                next_tracers.update({input_tracer: None})

                        visited_tracers.add(tracer)

                current_tracers = next_tracers

            assert_that(nx.algorithms.dag.is_directed_acyclic_graph(graph))

            unique_edges = {
                (pred, succ, tuple((k, v) for k, v in edge_data.items()))
                for pred, succ, edge_data in graph.edges(data=True)
            }
            assert_that(len(unique_edges) == len(graph.edges))

            for tracer in arguments.values():
                graph.add_node(tracer.computation)

            return graph

        graph = create_graph_from_output_tracers(arguments, output_tracers)
        input_nodes = {
            input_indices[node]: node
            for node in graph.nodes()
            if len(graph.pred[node]) == 0 and node.operation == Operation.Input
        }
        output_nodes = {
            output_idx: tracer.computation for output_idx, tracer in enumerate(output_tracers)
        }

        return Graph(graph, input_nodes, output_nodes, is_direct)

        # pylint: enable=too-many-statements
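
    # A minimal usage sketch, assuming a `Value` describing each parameter is
    # available (`encrypted_uint7` is a hypothetical placeholder):
    #
    #     def f(x):
    #         return (2 * x) + 1
    #
    #     graph = Tracer.trace(f, {"x": encrypted_uint7})
    #     print(graph.format())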

    def __init__(self, computation: Node, input_tracers: List["Tracer"]):
        self.computation = computation
        self.input_tracers = input_tracers
        self.output = computation.output

        for i, tracer in enumerate(self.input_tracers):
            self.input_tracers[i] = tracer if tracer.last_version is None else tracer.last_version

    def __hash__(self) -> int:
        return id(self)

    def __bool__(self) -> bool:
        # pylint: disable=invalid-bool-returned

        message = "Branching within circuits is not possible"
        raise RuntimeError(message)

    @staticmethod
    def sanitize(value: Any) -> Any:
        """
        Try to create a tracer from a value.

        Args:
            value (Any):
                value to use

        Returns:
            Any:
                resulting tracer
        """

        if isinstance(value, tuple):
            return tuple(Tracer.sanitize(item) for item in value)

        if isinstance(value, Tracer):
            return value

        computation = Node.constant(value)
        return Tracer(computation, [])

    SUPPORTED_NUMPY_OPERATORS: Set[Any] = {
        np.abs,
        np.absolute,
        np.add,
        np.arccos,
        np.arccosh,
        np.arcsin,
        np.arcsinh,
        np.arctan,
        np.arctan2,
        np.arctanh,
        np.around,
        np.bitwise_and,
        np.bitwise_or,
        np.bitwise_xor,
        np.broadcast_to,
        np.cbrt,
        np.ceil,
        np.clip,
        np.concatenate,
        np.copysign,
        np.cos,
        np.cosh,
        np.deg2rad,
        np.degrees,
        np.divide,
        np.dot,
        np.equal,
        np.exp,
        np.exp2,
        np.expand_dims,
        np.expm1,
        np.fabs,
        np.float_power,
        np.floor,
        np.floor_divide,
        np.fmax,
        np.fmin,
        np.fmod,
        np.gcd,
        np.greater,
        np.greater_equal,
        np.heaviside,
        np.hypot,
        np.invert,
        np.isfinite,
        np.isinf,
        np.isnan,
        np.lcm,
        np.ldexp,
        np.left_shift,
        np.less,
        np.less_equal,
        np.log,
        np.log10,
        np.log1p,
        np.log2,
        np.logaddexp,
        np.logaddexp2,
        np.logical_and,
        np.logical_not,
        np.logical_or,
        np.logical_xor,
        np.matmul,
        np.maximum,
        np.minimum,
        np.mod,
        np.multiply,
        np.negative,
        np.nextafter,
        np.not_equal,
        np.ones_like,
        np.positive,
        np.power,
        np.rad2deg,
        np.radians,
        np.reciprocal,
        np.remainder,
        np.reshape,
        np.right_shift,
        np.rint,
        np.round_,
        np.sign,
        np.signbit,
        np.sin,
        np.sinh,
        np.spacing,
        np.sqrt,
        np.square,
        np.squeeze,
        np.subtract,
        np.sum,
        np.tan,
        np.tanh,
        np.transpose,
        np.true_divide,
        np.trunc,
        np.where,
        np.zeros_like,
    }

    SUPPORTED_KWARGS: Dict[Any, Set[str]] = {
        np.around: {
            "decimals",
        },
        np.broadcast_to: {
            "shape",
        },
        np.concatenate: {
            "axis",
        },
        np.expand_dims: {
            "axis",
        },
        np.ones_like: {
            "dtype",
        },
        np.reshape: {
            "newshape",
        },
        np.round_: {
            "decimals",
        },
        np.squeeze: {
            "axis",
        },
        np.sum: {
            "axis",
            "keepdims",
        },
        np.transpose: {
            "axes",
        },
        np.zeros_like: {
            "dtype",
        },
    }
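
    # Only the kwargs listed above survive tracing; anything else is rejected in
    # `_trace_numpy_operation` below, e.g. (a sketch, with `x` a traced tensor):
    #
    #     np.sum(x, axis=0)          # fine, `axis` is supported for np.sum
    #     np.sum(x, dtype=np.int64)  # RuntimeError: ... kwarg 'dtype'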
@staticmethod
|
||||
def _trace_numpy_operation(operation: Callable, *args, **kwargs) -> "Tracer":
|
||||
"""
|
||||
Trace an arbitrary numpy operation into an Operation.Generic node.
|
||||
|
||||
Args:
|
||||
operation (Callable):
|
||||
operation to trace
|
||||
|
||||
args (List[Any]):
|
||||
args of the arbitrary computation
|
||||
|
||||
kwargs (Dict[str, Any]):
|
||||
kwargs of the arbitrary computation
|
||||
|
||||
Returns:
|
||||
Tracer:
|
||||
tracer representing the arbitrary computation
|
||||
"""
|
||||
|
||||
if operation not in Tracer.SUPPORTED_NUMPY_OPERATORS:
|
||||
message = f"Function 'np.{operation.__name__}' is not supported"
|
||||
raise RuntimeError(message)
|
||||
|
||||
supported_kwargs = Tracer.SUPPORTED_KWARGS.get(operation, set())
|
||||
for kwarg in kwargs:
|
||||
if kwarg not in supported_kwargs:
|
||||
message = (
|
||||
f"Function 'np.{operation.__name__}' is not supported with kwarg '{kwarg}'"
|
||||
)
|
||||
raise RuntimeError(message)
|
||||
|
||||
if operation == np.ones_like: # pylint: disable=comparison-with-callable
|
||||
dtype = kwargs.get("dtype", np.int64)
|
||||
return Tracer(Node.constant(np.ones(args[0].shape, dtype=dtype)), [])
|
||||
|
||||
if operation == np.zeros_like: # pylint: disable=comparison-with-callable
|
||||
dtype = kwargs.get("dtype", np.int64)
|
||||
return Tracer(Node.constant(np.zeros(args[0].shape, dtype=dtype)), [])
|
||||
|
||||
def sampler(arg: Any) -> Any:
|
||||
if isinstance(arg, tuple):
|
||||
return tuple(sampler(item) for item in arg)
|
||||
|
||||
output = arg.output
|
||||
assert_that(isinstance(output.dtype, (Float, Integer)))
|
||||
|
||||
dtype: Any = np.int64
|
||||
if isinstance(output.dtype, Float):
|
||||
assert_that(output.dtype.bit_width in [16, 32, 64])
|
||||
dtype = {64: np.float64, 32: np.float32, 16: np.float16}[output.dtype.bit_width]
|
||||
|
||||
if output.shape == ():
|
||||
return dtype(1)
|
||||
|
||||
return np.ones(output.shape, dtype=dtype)
|
||||
|
||||
sample = [sampler(arg) for arg in args]
|
||||
evaluation = operation(*sample, **kwargs)
|
||||
|
||||
def extract_tracers(arg: Any, tracers: List[Tracer]):
|
||||
if isinstance(arg, tuple):
|
||||
for item in arg:
|
||||
extract_tracers(item, tracers)
|
||||
|
||||
if isinstance(arg, Tracer):
|
||||
tracers.append(arg)
|
||||
|
||||
tracers: List[Tracer] = []
|
||||
for arg in args:
|
||||
extract_tracers(arg, tracers)
|
||||
|
||||
output_value = Value.of(evaluation)
|
||||
output_value.is_encrypted = any(tracer.output.is_encrypted for tracer in tracers)
|
||||
|
||||
if Tracer._is_direct and isinstance(output_value.dtype, Integer):
|
||||
|
||||
assert all(isinstance(tracer.output.dtype, Integer) for tracer in tracers)
|
||||
dtypes = cast(List[Integer], [tracer.output.dtype for tracer in tracers])
|
||||
|
||||
output_value.dtype.bit_width = max(dtype.bit_width for dtype in dtypes)
|
||||
output_value.dtype.is_signed = any(dtype.is_signed for dtype in dtypes)
|
||||
|
||||
computation = Node.generic(
|
||||
operation.__name__,
|
||||
[tracer.output for tracer in tracers],
|
||||
output_value,
|
||||
operation,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
return Tracer(computation, tracers)
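
    # Editorial sketch (not part of the original diff): tracing evaluates the
    # operation on a sample of ones with matching shapes/dtypes to discover the
    # output value, then records an Operation.Generic node. Assuming `x` traces
    # an encrypted (3, 2) integer tensor:
    #
    #     y = Tracer._trace_numpy_operation(np.add, x, Tracer.sanitize(1))
    #     y.output.shape         # (3, 2), from evaluating np.add on the sample
    #     y.output.is_encrypted  # True, inherited from `x`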

    def __array_ufunc__(self, ufunc, method, *args, **kwargs):
        """
        Numpy ufunc hook.

        (https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch)
        """

        if method == "__call__":
            sanitized_args = [self.sanitize(arg) for arg in args]
            return Tracer._trace_numpy_operation(ufunc, *sanitized_args, **kwargs)

        message = "Only __call__ hook is supported for numpy ufuncs"
        raise RuntimeError(message)

    def __array_function__(self, func, _types, args, kwargs):
        """
        Numpy function hook.

        (https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch)
        """

        if func is np.broadcast_to:
            sanitized_args = [self.sanitize(args[0])]
            if len(args) > 1:
                kwargs["shape"] = args[1]
        elif func is np.reshape:
            sanitized_args = [self.sanitize(args[0])]
            if len(args) > 1:
                kwargs["newshape"] = args[1]
        elif func is np.transpose:
            sanitized_args = [self.sanitize(args[0])]
            if len(args) > 1:
                kwargs["axes"] = args[1]
        else:
            sanitized_args = [self.sanitize(arg) for arg in args]

        return Tracer._trace_numpy_operation(func, *sanitized_args, **kwargs)
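
    # Editorial sketch (not part of the original diff): with these hooks, plain
    # numpy calls on tracers are intercepted, e.g. for `x` tracing a (2, 3)
    # tensor:
    #
    #     np.add(x, 1)           # ufunc __call__ routed via __array_ufunc__
    #     np.reshape(x, (3, 2))  # positional shape moved into kwargs["newshape"]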

    def __add__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.add, self, self.sanitize(other))

    def __radd__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.add, self.sanitize(other), self)

    def __sub__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.subtract, self, self.sanitize(other))

    def __rsub__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.subtract, self.sanitize(other), self)

    def __mul__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.multiply, self, self.sanitize(other))

    def __rmul__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.multiply, self.sanitize(other), self)

    def __truediv__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.true_divide, self, self.sanitize(other))

    def __rtruediv__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.true_divide, self.sanitize(other), self)

    def __floordiv__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.floor_divide, self, self.sanitize(other))

    def __rfloordiv__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.floor_divide, self.sanitize(other), self)

    def __pow__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.power, self, self.sanitize(other))

    def __rpow__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.power, self.sanitize(other), self)

    def __mod__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.mod, self, self.sanitize(other))

    def __rmod__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.mod, self.sanitize(other), self)

    def __matmul__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.matmul, self, self.sanitize(other))

    def __rmatmul__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.matmul, self.sanitize(other), self)

    def __neg__(self) -> "Tracer":
        return Tracer._trace_numpy_operation(np.negative, self)

    def __pos__(self) -> "Tracer":
        return Tracer._trace_numpy_operation(np.positive, self)

    def __abs__(self):
        return Tracer._trace_numpy_operation(np.absolute, self)

    def __round__(self, ndigits=None):
        if ndigits is None:
            result = Tracer._trace_numpy_operation(np.around, self)
            if self._is_direct:
                message = (
                    "'round(x)' cannot be used in direct definition (you may use np.around instead)"
                )
                raise RuntimeError(message)
            return result.astype(np.int64)

        return Tracer._trace_numpy_operation(np.around, self, decimals=ndigits)

    def __invert__(self):
        return Tracer._trace_numpy_operation(np.invert, self)

    def __and__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.bitwise_and, self, self.sanitize(other))

    def __rand__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.bitwise_and, self.sanitize(other), self)

    def __or__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.bitwise_or, self, self.sanitize(other))

    def __ror__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.bitwise_or, self.sanitize(other), self)

    def __xor__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.bitwise_xor, self, self.sanitize(other))

    def __rxor__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.bitwise_xor, self.sanitize(other), self)

    def __lshift__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.left_shift, self, self.sanitize(other))

    def __rlshift__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.left_shift, self.sanitize(other), self)

    def __rshift__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.right_shift, self, self.sanitize(other))

    def __rrshift__(self, other: Any) -> "Tracer":
        return Tracer._trace_numpy_operation(np.right_shift, self.sanitize(other), self)

    def __gt__(self, other: Any) -> "Tracer":  # type: ignore
        return Tracer._trace_numpy_operation(np.greater, self, self.sanitize(other))

    def __ge__(self, other: Any) -> "Tracer":  # type: ignore
        return Tracer._trace_numpy_operation(np.greater_equal, self, self.sanitize(other))

    def __lt__(self, other: Any) -> "Tracer":  # type: ignore
        return Tracer._trace_numpy_operation(np.less, self, self.sanitize(other))

    def __le__(self, other: Any) -> "Tracer":  # type: ignore
        return Tracer._trace_numpy_operation(np.less_equal, self, self.sanitize(other))

    def __eq__(self, other: Any) -> Union[bool, "Tracer"]:  # type: ignore
        return (
            self is other
            if not self._is_tracing
            else Tracer._trace_numpy_operation(np.equal, self, self.sanitize(other))
        )

    def __ne__(self, other: Any) -> Union[bool, "Tracer"]:  # type: ignore
        return (
            self is not other
            if not self._is_tracing
            else Tracer._trace_numpy_operation(np.not_equal, self, self.sanitize(other))
        )

    def astype(self, dtype: Union[DTypeLike, Type["ScalarAnnotation"]]) -> "Tracer":
        """
        Trace numpy.ndarray.astype(dtype).
        """

        if Tracer._is_direct:
            output_value = deepcopy(self.output)

            if isinstance(dtype, type) and issubclass(dtype, ScalarAnnotation):
                output_value.dtype = dtype.dtype
            else:
                message = (
                    "`astype` method must be called with a concrete.numpy type "
                    "for direct circuit definition (e.g., value.astype(cnp.uint4))"
                )
                raise ValueError(message)

            computation = Node.generic(
                "astype",
                [self.output],
                output_value,
                lambda x: x,  # unused for direct definition
            )
            return Tracer(computation, [self])

        if isinstance(dtype, type) and issubclass(dtype, ScalarAnnotation):
            message = (
                "`astype` method must be called with a "
                "numpy type for compilation (e.g., value.astype(np.int64))"
            )
            raise ValueError(message)

        dtype = np.dtype(dtype).type
        if np.issubdtype(dtype, np.integer) and dtype != np.int64:
            print(
                "Warning: When using `value.astype(newtype)` "
                "with an integer newtype, "
                "only use `np.int64` as the newtype "
                "to avoid unexpected overflows "
                "during inputset evaluation"
            )

        output_value = deepcopy(self.output)
        output_value.dtype = Value.of(dtype(0)).dtype  # type: ignore

        if np.issubdtype(dtype, np.integer):

            def evaluator(x, dtype):
                if np.any(np.isnan(x)):
                    message = "A `NaN` value cannot be converted to an integer"
                    raise ValueError(message)
                if np.any(np.isinf(x)):
                    message = "An `Inf` value cannot be converted to an integer"
                    raise ValueError(message)
                return x.astype(dtype)

        else:

            def evaluator(x, dtype):
                return x.astype(dtype)

        computation = Node.generic(
            "astype",
            [self.output],
            output_value,
            evaluator,
            kwargs={"dtype": dtype},
        )
        return Tracer(computation, [self])
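
    # Editorial sketch (not part of the original diff): the two astype branches
    # in use. In direct definition (hypothetical `cnp` alias for concrete.numpy):
    #
    #     value.astype(cnp.uint4)  # a ScalarAnnotation subclass
    #
    # and during compilation with an inputset:
    #
    #     value.astype(np.int64)   # a numpy type; np.int64 avoids overflows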

    def clip(self, minimum: Any, maximum: Any) -> "Tracer":
        """
        Trace numpy.ndarray.clip().
        """

        return Tracer._trace_numpy_operation(
            np.clip, self, self.sanitize(minimum), self.sanitize(maximum)
        )

    def dot(self, other: Any) -> "Tracer":
        """
        Trace numpy.ndarray.dot().
        """

        return Tracer._trace_numpy_operation(np.dot, self, self.sanitize(other))

    def flatten(self) -> "Tracer":
        """
        Trace numpy.ndarray.flatten().
        """

        return Tracer._trace_numpy_operation(np.reshape, self, newshape=(self.output.size,))

    def reshape(self, newshape: Tuple[Any, ...]) -> "Tracer":
        """
        Trace numpy.ndarray.reshape(newshape).
        """

        return Tracer._trace_numpy_operation(np.reshape, self, newshape=newshape)

    def round(self, decimals: int = 0) -> "Tracer":
        """
        Trace numpy.ndarray.round().
        """

        return Tracer._trace_numpy_operation(np.around, self, decimals=decimals)

    def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> "Tracer":
        """
        Trace numpy.ndarray.transpose().
        """

        if axes is None:
            return Tracer._trace_numpy_operation(np.transpose, self)

        return Tracer._trace_numpy_operation(np.transpose, self, axes=axes)

    def __getitem__(
        self,
        index: Union[int, np.integer, slice, Tuple[Union[int, np.integer, slice], ...]],
    ) -> "Tracer":
        if not isinstance(index, tuple):
            index = (index,)

        for indexing_element in index:
            valid = isinstance(indexing_element, (int, np.integer, slice))

            if isinstance(indexing_element, slice):
                if (
                    not (
                        indexing_element.start is None
                        or isinstance(indexing_element.start, (int, np.integer))
                    )
                    or not (
                        indexing_element.stop is None
                        or isinstance(indexing_element.stop, (int, np.integer))
                    )
                    or not (
                        indexing_element.step is None
                        or isinstance(indexing_element.step, (int, np.integer))
                    )
                ):
                    valid = False

            if not valid:
                message = (
                    f"Indexing with '{format_indexing_element(indexing_element)}' is not supported"
                )
                raise ValueError(message)

        output_value = deepcopy(self.output)
        output_value.shape = np.zeros(output_value.shape)[index].shape

        computation = Node.generic(
            "index.static",
            [self.output],
            output_value,
            lambda x, index: x[index],
            kwargs={"index": index},
        )
        return Tracer(computation, [self])
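
    # Editorial sketch (not part of the original diff): only integers and
    # int-valued slices are accepted as static indices. For `x` tracing a
    # (4, 4) tensor:
    #
    #     x[0], x[1:3], x[::2, 1]  # traced as "index.static" nodes
    #     x[np.array([0, 1])]      # rejected: fancy indexing is unsupported here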

    def __setitem__(
        self,
        index: Union[int, np.integer, slice, Tuple[Union[int, np.integer, slice], ...]],
        value: Any,
    ):
        if not isinstance(index, tuple):
            index = (index,)

        for indexing_element in index:
            valid = isinstance(indexing_element, (int, np.integer, slice))

            if isinstance(indexing_element, slice):
                if (
                    not (
                        indexing_element.start is None
                        or isinstance(indexing_element.start, (int, np.integer))
                    )
                    or not (
                        indexing_element.stop is None
                        or isinstance(indexing_element.stop, (int, np.integer))
                    )
                    or not (
                        indexing_element.step is None
                        or isinstance(indexing_element.step, (int, np.integer))
                    )
                ):
                    valid = False

            if not valid:
                message = (
                    f"Assigning to '{format_indexing_element(indexing_element)}' is not supported"
                )
                raise ValueError(message)

        # dry run the assignment on a dummy array so that invalid indices raise here
        np.zeros(self.output.shape)[index] = 1

        def assign(x, value, index):
            x[index] = value
            return x

        sanitized_value = self.sanitize(value)
        computation = Node.generic(
            "assign.static",
            [self.output, sanitized_value.output],
            self.output,
            assign,
            kwargs={"index": index},
        )
        new_version = Tracer(computation, [self, sanitized_value])

        self.last_version = new_version
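
    # Editorial sketch (not part of the original diff): traced assignment does
    # not mutate anything; it builds an "assign.static" node and records it as
    # `self.last_version`, which the tracing machinery is assumed to consult so
    # that code written after `x[0] = 4` sees the assigned version of `x`.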

    @property
    def shape(self) -> Tuple[int, ...]:
        """
        Trace numpy.ndarray.shape.
        """

        return self.output.shape

    @property
    def ndim(self) -> int:
        """
        Trace numpy.ndarray.ndim.
        """

        return self.output.ndim

    @property
    def size(self) -> int:
        """
        Trace numpy.ndarray.size.
        """

        return self.output.size

    @property
    def T(self) -> "Tracer":  # pylint: disable=invalid-name  # noqa: N802
        """
        Trace numpy.ndarray.T.
        """

        return Tracer._trace_numpy_operation(np.transpose, self)


class Annotation(Tracer):
    """
    Base annotation for direct definition.
    """


class ScalarAnnotation(Annotation):
    """
    Base scalar annotation for direct definition.
    """

    dtype: BaseDataType


class TensorAnnotation(Annotation):
    """
    Base tensor annotation for direct definition.
    """
1223
frontends/concrete-python/concrete/numpy/tracing/typing.py
Normal file
7
frontends/concrete-python/concrete/numpy/values/__init__.py
Normal file
@@ -0,0 +1,7 @@
"""
Define the available values and their semantics.
"""

from .scalar import ClearScalar, EncryptedScalar
from .tensor import ClearTensor, EncryptedTensor
from .value import Value
44
frontends/concrete-python/concrete/numpy/values/scalar.py
Normal file
@@ -0,0 +1,44 @@
"""
Declaration of `ClearScalar` and `EncryptedScalar` wrappers.
"""

from ..dtypes import BaseDataType
from .value import Value


def clear_scalar_builder(dtype: BaseDataType) -> Value:
    """
    Build a clear scalar value.

    Args:
        dtype (BaseDataType):
            dtype of the value

    Returns:
        Value:
            clear scalar value with given dtype
    """

    return Value(dtype=dtype, shape=(), is_encrypted=False)


ClearScalar = clear_scalar_builder


def encrypted_scalar_builder(dtype: BaseDataType) -> Value:
    """
    Build an encrypted scalar value.

    Args:
        dtype (BaseDataType):
            dtype of the value

    Returns:
        Value:
            encrypted scalar value with given dtype
    """

    return Value(dtype=dtype, shape=(), is_encrypted=True)


EncryptedScalar = encrypted_scalar_builder
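

# Editorial sketch (not part of the original diff): the builders are thin
# wrappers over `Value`, e.g. with `UnsignedInteger` from the dtypes module:
#
#     EncryptedScalar(UnsignedInteger(3))
#     # -> Value(dtype=UnsignedInteger(3), shape=(), is_encrypted=True)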
52
frontends/concrete-python/concrete/numpy/values/tensor.py
Normal file
@@ -0,0 +1,52 @@
"""
Declaration of `ClearTensor` and `EncryptedTensor` wrappers.
"""

from typing import Tuple

from ..dtypes import BaseDataType
from .value import Value


def clear_tensor_builder(dtype: BaseDataType, shape: Tuple[int, ...]) -> Value:
    """
    Build a clear tensor value.

    Args:
        dtype (BaseDataType):
            dtype of the value

        shape (Tuple[int, ...]):
            shape of the value

    Returns:
        Value:
            clear tensor value with given dtype and shape
    """

    return Value(dtype=dtype, shape=shape, is_encrypted=False)


ClearTensor = clear_tensor_builder


def encrypted_tensor_builder(dtype: BaseDataType, shape: Tuple[int, ...]) -> Value:
    """
    Build an encrypted tensor value.

    Args:
        dtype (BaseDataType):
            dtype of the value

        shape (Tuple[int, ...]):
            shape of the value

    Returns:
        Value:
            encrypted tensor value with given dtype and shape
    """

    return Value(dtype=dtype, shape=shape, is_encrypted=True)


EncryptedTensor = encrypted_tensor_builder
162
frontends/concrete-python/concrete/numpy/values/value.py
Normal file
@@ -0,0 +1,162 @@
"""
Declaration of `Value` class.
"""

from typing import Any, Tuple

import numpy as np

from ..dtypes import BaseDataType, Float, Integer, UnsignedInteger


class Value:
    """
    Value class, to combine data type, shape, and encryption status into a single object.
    """

    dtype: BaseDataType
    shape: Tuple[int, ...]
    is_encrypted: bool

    @staticmethod
    def of(value: Any, is_encrypted: bool = False) -> "Value":  # pylint: disable=invalid-name
        """
        Get the `Value` that can represent `value`.

        Args:
            value (Any):
                value that needs to be represented

            is_encrypted (bool, default = False):
                whether the resulting `Value` is encrypted or not

        Returns:
            Value:
                `Value` that can represent `value`

        Raises:
            ValueError:
                if `value` cannot be represented by `Value`
        """

        # pylint: disable=too-many-branches,too-many-return-statements

        if isinstance(value, (bool, np.bool_)):
            return Value(dtype=UnsignedInteger(1), shape=(), is_encrypted=is_encrypted)

        if isinstance(value, (int, np.integer)):
            return Value(
                dtype=Integer.that_can_represent(value),
                shape=(),
                is_encrypted=is_encrypted,
            )

        if isinstance(value, (float, np.float64)):
            return Value(dtype=Float(64), shape=(), is_encrypted=is_encrypted)

        if isinstance(value, np.float32):
            return Value(dtype=Float(32), shape=(), is_encrypted=is_encrypted)

        if isinstance(value, np.float16):
            return Value(dtype=Float(16), shape=(), is_encrypted=is_encrypted)

        if isinstance(value, list):
            try:
                value = np.array(value)
            except Exception:  # pylint: disable=broad-except
                # here we try our best to convert the list to np.ndarray
                # if it fails we raise the exception at the end of the function
                pass

        if isinstance(value, np.ndarray):

            if np.issubdtype(value.dtype, np.bool_):
                return Value(dtype=UnsignedInteger(1), shape=value.shape, is_encrypted=is_encrypted)

            if np.issubdtype(value.dtype, np.integer):
                return Value(
                    dtype=Integer.that_can_represent(value),
                    shape=value.shape,
                    is_encrypted=is_encrypted,
                )

            if np.issubdtype(value.dtype, np.float64):
                return Value(dtype=Float(64), shape=value.shape, is_encrypted=is_encrypted)

            if np.issubdtype(value.dtype, np.float32):
                return Value(dtype=Float(32), shape=value.shape, is_encrypted=is_encrypted)

            if np.issubdtype(value.dtype, np.float16):
                return Value(dtype=Float(16), shape=value.shape, is_encrypted=is_encrypted)

        message = f"Value cannot represent {repr(value)}"
        raise ValueError(message)

        # pylint: enable=too-many-branches,too-many-return-statements
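
    # Editorial sketch (not part of the original diff): a few representative
    # results of `Value.of`, following the branches above:
    #
    #     Value.of(True)             # dtype=UnsignedInteger(1), shape=()
    #     Value.of(np.ones((2, 3)))  # dtype=Float(64), shape=(2, 3)
    #     Value.of([1, 2, 3], is_encrypted=True)
    #                                # integer tensor of shape (3,), encrypted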

    def __init__(self, dtype: BaseDataType, shape: Tuple[int, ...], is_encrypted: bool):
        self.dtype = dtype
        self.shape = shape
        self.is_encrypted = is_encrypted

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, Value)
            and self.dtype == other.dtype
            and self.shape == other.shape
            and self.is_encrypted == other.is_encrypted
        )

    def __str__(self) -> str:
        encrypted_or_clear_str = "Encrypted" if self.is_encrypted else "Clear"
        scalar_or_tensor_str = "Scalar" if self.is_scalar else "Tensor"
        shape_str = f", shape={self.shape}" if not self.is_scalar else ""
        return f"{encrypted_or_clear_str}{scalar_or_tensor_str}<{str(self.dtype)}{shape_str}>"
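
    # Editorial sketch (not part of the original diff): __str__ renders values
    # like "EncryptedScalar<uint3>" or "ClearTensor<uint3, shape=(2, 3)>",
    # assuming str(UnsignedInteger(3)) renders as "uint3".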

    @property
    def is_clear(self) -> bool:
        """
        Get whether the value is clear or not.

        Returns:
            bool:
                True if value is not encrypted, False otherwise
        """

        return not self.is_encrypted

    @property
    def is_scalar(self) -> bool:
        """
        Get whether the value is scalar or not.

        Returns:
            bool:
                True if shape of the value is (), False otherwise
        """

        return self.shape == ()

    @property
    def ndim(self) -> int:
        """
        Get number of dimensions of the value.

        Returns:
            int:
                number of dimensions of the value
        """

        return len(self.shape)

    @property
    def size(self) -> int:
        """
        Get number of elements in the value.

        Returns:
            int:
                number of elements in the value
        """

        return int(np.prod(self.shape))
6
frontends/concrete-python/concrete/onnx/__init__.py
Normal file
@@ -0,0 +1,6 @@
"""
Implement machine learning operations as specified by ONNX.
"""

from .convolution import conv
from .maxpool import maxpool
683
frontends/concrete-python/concrete/onnx/convolution.py
Normal file
@@ -0,0 +1,683 @@
"""
Convolution operations' tracing and evaluation.
"""

import math
from typing import Callable, List, Optional, Tuple, Union, cast

import numpy as np
import torch

from ..numpy.internal.utils import assert_that
from ..numpy.representation import Node
from ..numpy.tracing import Tracer
from ..numpy.values import EncryptedTensor

SUPPORTED_AUTO_PAD = {
    "NOTSET",
}


# pylint: disable=too-many-branches,too-many-statements


def conv(
    x: Union[np.ndarray, Tracer],
    weight: Union[np.ndarray, Tracer],
    bias: Optional[Union[np.ndarray, Tracer]] = None,
    pads: Optional[Union[Tuple[int, ...], List[int]]] = None,
    strides: Optional[Union[Tuple[int, ...], List[int]]] = None,
    dilations: Optional[Union[Tuple[int, ...], List[int]]] = None,
    kernel_shape: Optional[Union[Tuple[int, ...], List[int]]] = None,
    group: int = 1,
    auto_pad: str = "NOTSET",
) -> Union[np.ndarray, Tracer]:
    """
    Trace and evaluate convolution operations.

    Refer to https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv for more info.

    Args:
        x (Union[np.ndarray, Tracer]): input of shape (N, C, D1, ..., DN)
        weight (Union[np.ndarray, Tracer]): kernel of shape (F, C / group, K1, ..., KN)
        bias (Optional[Union[np.ndarray, Tracer]], optional): bias of shape (F,). Defaults to None.
        pads (Optional[Union[Tuple[int, ...], List[int]]], optional):
            padding for the beginning and ending along each spatial axis
            (D1_begin, D2_begin, ..., D1_end, D2_end, ...).
            Will be set to 0 along each spatial axis if not set.
        strides (Optional[Union[Tuple[int, ...], List[int]]], optional):
            stride along each spatial axis. Will be set to 1 along each spatial axis if not set.
        dilations (Optional[Union[Tuple[int, ...], List[int]]], optional):
            dilation along each spatial axis. Will be set to 1 along each spatial axis if not set.
        kernel_shape (Optional[Union[Tuple[int, ...], List[int]]], optional):
            shape of the convolution kernel. Inferred from input weight if not present
        group (int, optional):
            number of groups input channels and output channels are divided into. Defaults to 1.
        auto_pad (str, optional): padding strategy. Defaults to "NOTSET".

    Raises:
        ValueError: if arguments are not appropriate
        TypeError: unexpected types
        NotImplementedError: a convolution that we don't support

    Returns:
        Union[np.ndarray, Tracer]: evaluation result or traced computation
    """
    if kernel_shape is not None and (
        (weight.ndim - 2) != len(kernel_shape) or not np.all(weight.shape[2:] == kernel_shape)
    ):
        message = f"expected kernel_shape to be {weight.shape[2:]}, but got {kernel_shape}"
        raise ValueError(message)

    if isinstance(x, np.ndarray):
        if not isinstance(weight, np.ndarray):
            message = "expected weight to be of same type as x"
            raise TypeError(message)
        if bias is not None and not isinstance(bias, np.ndarray):
            message = "expected bias to be of same type as x"
            raise TypeError(message)
    elif isinstance(x, Tracer):
        if not isinstance(weight, (Tracer, np.ndarray)):
            message = "expected weight to be of type Tracer or ndarray"
            raise TypeError(message)
        if bias is not None and not isinstance(bias, (Tracer, np.ndarray)):
            message = "expected bias to be of type Tracer or ndarray"
            raise TypeError(message)

    if x.ndim <= 2:
        message = (
            f"expected input x to have at least 3 dimensions (N, C, D1, ...), but got {x.ndim}"
        )
        raise ValueError(message)

    if weight.ndim <= 2:
        message = (
            f"expected weight to have at least 3 dimensions (F, C / group, K1, ...), but got "
            f"{weight.ndim}"
        )
        raise ValueError(message)

    if bias is not None and bias.ndim != 1:
        message = f"expected bias to have a single dimension (F,), but got {bias.ndim}"
        raise ValueError(message)

    if not isinstance(group, int) or group <= 0:
        message = f"expected group to be an integer > 0, but got {group}"
        raise ValueError(message)

    if auto_pad not in SUPPORTED_AUTO_PAD:
        message = f"auto_pad should be in {SUPPORTED_AUTO_PAD}, but got {repr(auto_pad)}"
        raise ValueError(message)

    n_channels = x.shape[1]
    if weight.shape[1] != n_channels / group:
        message = (
            f"expected number of channels in weight to be {n_channels / group} (C / group), "
            f"but got {weight.shape[1]}"
        )
        raise ValueError(message)

    if weight.shape[0] % group != 0:
        message = (
            f"expected number of feature maps ({weight.shape[0]}) to be a multiple of group "
            f"({group})"
        )
        raise ValueError(message)

    dims = x.ndim - 2
    if dims == 1:
        pads = (0, 0) if pads is None else pads
        strides = (1,) if strides is None else strides
        dilations = (1,) if dilations is None else dilations
        return _conv1d(
            x,
            weight,
            bias=bias,
            pads=pads,
            strides=strides,
            dilations=dilations,
            group=group,
            auto_pad=auto_pad,
        )
    if dims == 2:
        pads = (0, 0, 0, 0) if pads is None else pads
        strides = (1, 1) if strides is None else strides
        dilations = (1, 1) if dilations is None else dilations
        return _conv2d(
            x,
            weight,
            bias=bias,
            pads=pads,
            strides=strides,
            dilations=dilations,
            group=group,
            auto_pad=auto_pad,
        )
    if dims == 3:
        pads = (0, 0, 0, 0, 0, 0) if pads is None else pads
        strides = (1, 1, 1) if strides is None else strides
        dilations = (1, 1, 1) if dilations is None else dilations
        return _conv3d(
            x,
            weight,
            bias=bias,
            pads=pads,
            strides=strides,
            dilations=dilations,
            group=group,
            auto_pad=auto_pad,
        )

    message = "only 1D, 2D, and 3D convolutions are supported"
    raise NotImplementedError(message)
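

# Editorial sketch (not part of the original diff): a hypothetical eager call,
# since `conv` evaluates on ndarrays and traces on Tracers:
#
#     x = np.ones((1, 1, 4, 4), dtype=np.int64)  # (N, C, H, W)
#     w = np.ones((1, 1, 2, 2), dtype=np.int64)  # (F, C / group, Kh, Kw)
#     y = conv(x, w)                             # shape (1, 1, 3, 3), every entry 4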


# pylint: enable=too-many-branches,too-many-statements


def _conv1d(
    x: Union[np.ndarray, Tracer],
    weight: Union[np.ndarray, Tracer],
    bias: Optional[Union[np.ndarray, Tracer]],
    pads: Union[Tuple[int, ...], List[int]],
    strides: Union[Tuple[int, ...], List[int]],
    dilations: Union[Tuple[int, ...], List[int]],
    group: int,
    auto_pad: str,  # pylint: disable=unused-argument
) -> Union[np.ndarray, Tracer]:
    """
    Trace or evaluate 1D convolution.

    Args:
        x (Union[np.ndarray, Tracer]): input of shape (N, C, D)
        weight (Union[np.ndarray, Tracer]): kernel of shape (F, C, D)
        bias (Optional[Union[np.ndarray, Tracer]]): bias of shape (F,)
        pads (Union[Tuple[int, ...], List[int]]):
            padding over dimension D (D_beg, D_end)
        strides (Union[Tuple[int, ...], List[int]]): stride over dimension D
        dilations (Union[Tuple[int, ...], List[int]]): dilation over dimension D
        group (int, optional):
            number of groups input channels and output channels are divided into.
        auto_pad (str, optional): padding strategy.

    Raises:
        ValueError: if arguments are not appropriate

    Returns:
        Union[np.ndarray, Tracer]: evaluation result or traced computation
    """

    assert_that(
        x.ndim == 3,
        f"expected input x to be of shape (N, C, D) when performing 1D convolution, but "
        f"got {x.shape}",
    )

    assert_that(
        weight.ndim == 3,
        f"expected weight to be of shape (F, C, D) when performing 1D convolution, but "
        f"got {weight.shape}",
    )

    if len(pads) != 2:
        message = (
            f"pads should be of form "
            f"(D_begin_pad, D_end_pad) when performing "
            f"1D convolution, but it's {pads}"
        )
        raise ValueError(message)
    if len(strides) != 1:
        message = (
            f"strides should be of form (D_stride,) when performing 1D "
            f"convolution, but it's {strides}"
        )
        raise ValueError(message)
    if len(dilations) != 1:
        message = (
            f"dilations should be of form (D_dilation,) when performing 1D "
            f"convolution, but it's {dilations}"
        )
        raise ValueError(message)

    return _trace_or_eval(x, weight, bias, pads, strides, dilations, group)


def _conv2d(
    x: Union[np.ndarray, Tracer],
    weight: Union[np.ndarray, Tracer],
    bias: Optional[Union[np.ndarray, Tracer]],
    pads: Union[Tuple[int, ...], List[int]],
    strides: Union[Tuple[int, ...], List[int]],
    dilations: Union[Tuple[int, ...], List[int]],
    group: int,
    auto_pad: str,  # pylint: disable=unused-argument
) -> Union[np.ndarray, Tracer]:
    """
    Trace or evaluate 2D convolution.

    Args:
        x (Union[np.ndarray, Tracer]): input of shape (N, C, H, W)
        weight (Union[np.ndarray, Tracer]): kernel of shape (F, C, H, W)
        bias (Optional[Union[np.ndarray, Tracer]]): bias of shape (F,)
        pads (Union[Tuple[int, ...], List[int]]):
            padding over each height and width (H_beg, W_beg, H_end, W_end)
        strides (Union[Tuple[int, ...], List[int]]): stride over height and width
        dilations (Union[Tuple[int, ...], List[int]]): dilation over height and width
        group (int, optional):
            number of groups input channels and output channels are divided into.
        auto_pad (str, optional): padding strategy.

    Raises:
        ValueError: if arguments are not appropriate

    Returns:
        Union[np.ndarray, Tracer]: evaluation result or traced computation
    """

    assert_that(
        x.ndim == 4,
        f"expected input x to be of shape (N, C, H, W) when performing 2D convolution, but "
        f"got {x.shape}",
    )

    assert_that(
        weight.ndim == 4,
        f"expected weight to be of shape (F, C, H, W) when performing 2D convolution, but "
        f"got {weight.shape}",
    )

    if len(pads) != 4:
        message = (
            f"pads should be of form "
            f"(height_begin_pad, width_begin_pad, height_end_pad, width_end_pad) when performing "
            f"2D convolution, but it's {pads}"
        )
        raise ValueError(message)
    if len(strides) != 2:
        message = (
            f"strides should be of form (height_stride, width_stride) when performing 2D "
            f"convolution, but it's {strides}"
        )
        raise ValueError(message)
    if len(dilations) != 2:
        message = (
            f"dilations should be of form (height_dilation, width_dilation) when performing 2D "
            f"convolution, but it's {dilations}"
        )
        raise ValueError(message)

    return _trace_or_eval(x, weight, bias, pads, strides, dilations, group)


def _conv3d(
    x: Union[np.ndarray, Tracer],
    weight: Union[np.ndarray, Tracer],
    bias: Optional[Union[np.ndarray, Tracer]],
    pads: Union[Tuple[int, ...], List[int]],
    strides: Union[Tuple[int, ...], List[int]],
    dilations: Union[Tuple[int, ...], List[int]],
    group: int,
    auto_pad: str,  # pylint: disable=unused-argument
) -> Union[np.ndarray, Tracer]:
    """
    Trace or evaluate 3D convolution.

    Args:
        x (Union[np.ndarray, Tracer]): input of shape (N, C, D, H, W)
        weight (Union[np.ndarray, Tracer]): kernel of shape (F, C, D, H, W)
        bias (Optional[Union[np.ndarray, Tracer]]): bias of shape (F,)
        pads (Union[Tuple[int, ...], List[int]]):
            padding over each spatial axis (D_beg, H_beg, W_beg, D_end, H_end, W_end)
        strides (Union[Tuple[int, ...], List[int]]): stride over each spatial axis
        dilations (Union[Tuple[int, ...], List[int]]): dilation over each spatial axis
        group (int, optional):
            number of groups input channels and output channels are divided into.
        auto_pad (str, optional): padding strategy.

    Raises:
        ValueError: if arguments are not appropriate

    Returns:
        Union[np.ndarray, Tracer]: evaluation result or traced computation
    """

    assert_that(
        x.ndim == 5,
        f"expected input x to be of shape (N, C, D, H, W) when performing 3D convolution, but "
        f"got {x.shape}",
    )

    assert_that(
        weight.ndim == 5,
        f"expected weight to be of shape (F, C, D, H, W) when performing 3D convolution, but "
        f"got {weight.shape}",
    )

    if len(pads) != 6:
        message = (
            f"pads should be of form "
            f"(D_begin_pad, height_begin_pad, width_begin_pad, "
            f"D_end_pad, height_end_pad, width_end_pad) when performing "
            f"3D convolution, but it's {pads}"
        )
        raise ValueError(message)
    if len(strides) != 3:
        message = (
            f"strides should be of form (D_stride, height_stride, width_stride) when performing "
            f"3D convolution, but it's {strides}"
        )
        raise ValueError(message)
    if len(dilations) != 3:
        message = (
            f"dilations should be of form (D_dilation, height_dilation, width_dilation) when "
            f"performing 3D convolution, but it's {dilations}"
        )
        raise ValueError(message)

    return _trace_or_eval(x, weight, bias, pads, strides, dilations, group)


def _trace_or_eval(
    x: Union[np.ndarray, Tracer],
    weight: Union[np.ndarray, Tracer],
    bias: Optional[Union[np.ndarray, Tracer]],
    pads: Union[Tuple[int, ...], List[int]],
    strides: Union[Tuple[int, ...], List[int]],
    dilations: Union[Tuple[int, ...], List[int]],
    group: int,
) -> Union[np.ndarray, Tracer]:
    """
    Trace or evaluate convolution.

    Args:
        x (Union[np.ndarray, Tracer]): input of shape (N, C, D1, ..., DN)
        weight (Union[np.ndarray, Tracer]): kernel of shape (F, C / group, K1, ..., KN)
        bias (Optional[Union[np.ndarray, Tracer]]): bias of shape (F,)
        pads (Union[Tuple[int, ...], List[int]]):
            padding for the beginning and ending along each spatial axis
            (D1_begin, D2_begin, ..., D1_end, D2_end, ...).
        strides (Union[Tuple[int, ...], List[int]]): stride along each spatial axis.
        dilations (Union[Tuple[int, ...], List[int]]): dilation along each spatial axis.
        group (int, optional):
            number of groups input channels and output channels are divided into.

    Returns:
        Union[np.ndarray, Tracer]: evaluation result or traced computation
    """
    assert_that(x.ndim in [3, 4, 5], "only support 1D, 2D, and 3D conv")
    if x.ndim == 3:
        conv_func = "conv1d"
    elif x.ndim == 4:
        conv_func = "conv2d"
    else:  # x.ndim == 5
        conv_func = "conv3d"

    if isinstance(x, Tracer):
        return _trace_conv(x, weight, bias, pads, strides, dilations, group, conv_func)

    assert isinstance(x, np.ndarray)
    assert isinstance(weight, np.ndarray)

    dtype = (
        np.float64
        if np.issubdtype(x.dtype, np.floating) or np.issubdtype(weight.dtype, np.floating)
        else np.int64
    )
    bias = np.zeros(weight.shape[0], dtype=dtype) if bias is None else bias

    assert isinstance(bias, np.ndarray)

    return _evaluate_conv(x, weight, bias, pads, strides, dilations, group, conv_func)


def _trace_conv(
    x: Tracer,
    weight: Union[np.ndarray, Tracer],
    bias: Optional[Union[np.ndarray, Tracer]],
    pads: Union[Tuple[int, ...], List[int]],
    strides: Union[Tuple[int, ...], List[int]],
    dilations: Union[Tuple[int, ...], List[int]],
    group: int,
    conv_func: str,
) -> Tracer:
    """
    Trace convolution.

    Args:
        x (Tracer): input of shape (N, C, D1, ..., DN)
        weight (Union[np.ndarray, Tracer]): kernel of shape (F, C / group, K1, ..., KN)
        bias (Optional[Union[np.ndarray, Tracer]]): bias of shape (F,)
        pads (Union[Tuple[int, ...], List[int]]):
            padding for the beginning and ending along each spatial axis
            (D1_begin, D2_begin, ..., D1_end, D2_end, ...).
        strides (Union[Tuple[int, ...], List[int]]): stride along each spatial axis.
        dilations (Union[Tuple[int, ...], List[int]]): dilation along each spatial axis.
        group (int, optional):
            number of groups input channels and output channels are divided into.
        conv_func (str): convolution to apply, should be one of {conv1d,conv2d,conv3d}

    Returns:
        Tracer:
            traced computation
    """

    conv_eval_funcs = {
        "conv1d": _evaluate_conv1d,
        "conv2d": _evaluate_conv2d,
        "conv3d": _evaluate_conv3d,
    }
    eval_func = conv_eval_funcs.get(conv_func, None)
    assert_that(
        eval_func is not None,
        f"expected conv_func to be one of {list(conv_eval_funcs.keys())}, but got {conv_func}",
    )
    eval_func = cast(Callable, eval_func)

    weight = weight if isinstance(weight, Tracer) else Tracer(Node.constant(weight), [])

    input_values = [x.output, weight.output]
    inputs = [x, weight]

    if bias is not None:
        bias = bias if isinstance(bias, Tracer) else Tracer(Node.constant(bias), [])
        input_values.append(bias.output)
        inputs.append(bias)

    batch_size = x.output.shape[0]
    n_filters = weight.output.shape[0]

    n_dim = x.ndim - 2  # remove batch_size and channel dims
    total_pads_per_dim = []
    for dim in range(n_dim):
        total_pads_per_dim.append(pads[dim] + pads[n_dim + dim])

    output_shape = [batch_size, n_filters]
    for dim in range(n_dim):
        input_dim_at_dim = x.output.shape[dim + 2]
        weight_dim_at_dim = weight.output.shape[dim + 2]
        output_shape.append(
            math.floor(
                (
                    input_dim_at_dim
                    + total_pads_per_dim[dim]
                    - dilations[dim] * (weight_dim_at_dim - 1)
                    - 1
                )
                / strides[dim]
            )
            + 1
        )
    output_value = EncryptedTensor(dtype=x.output.dtype, shape=tuple(output_shape))

    computation = Node.generic(
        conv_func,  # "conv1d" or "conv2d" or "conv3d"
        input_values,
        output_value,
        eval_func,
        args=() if bias is not None else (np.zeros(n_filters, dtype=np.int64),),
        kwargs={"pads": pads, "strides": strides, "dilations": dilations, "group": group},
    )
    return Tracer(computation, inputs)
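

# Editorial sketch (not part of the original diff): the output-shape loop in
# _trace_conv computes, per spatial axis,
#
#     floor((D + pad_begin + pad_end - dilation * (K - 1) - 1) / stride) + 1
#
# e.g. D = 7, pads = (1, 1), K = 3, dilation = 1, stride = 2 gives
# floor((7 + 2 - 2 - 1) / 2) + 1 = 4.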


def _evaluate_conv1d(
    x: np.ndarray,
    weight: np.ndarray,
    bias: np.ndarray,
    pads: Union[Tuple[int, ...], List[int]],
    strides: Union[Tuple[int, ...], List[int]],
    dilations: Union[Tuple[int, ...], List[int]],
    group: int,
) -> np.ndarray:
    """
    Evaluate 1D convolution.

    Args:
        x (np.ndarray): input of shape (N, C, D)
        weight (np.ndarray): kernel of shape (F, C / group, D)
        bias (np.ndarray): bias of shape (F,)
        pads (Union[Tuple[int, ...], List[int]]):
            padding over each axis (D_beg, D_end)
        strides (Union[Tuple[int, ...], List[int]]): stride over dimension D
        dilations (Union[Tuple[int, ...], List[int]]): dilation over dimension D
        group (int, optional):
            number of groups input channels and output channels are divided into.

    Returns:
        np.ndarray: result of the convolution
    """
    return _evaluate_conv(x, weight, bias, pads, strides, dilations, group, "conv1d")


def _evaluate_conv2d(
    x: np.ndarray,
    weight: np.ndarray,
    bias: np.ndarray,
    pads: Union[Tuple[int, ...], List[int]],
    strides: Union[Tuple[int, ...], List[int]],
    dilations: Union[Tuple[int, ...], List[int]],
    group: int,
) -> np.ndarray:
    """
    Evaluate 2D convolution.

    Args:
        x (np.ndarray): input of shape (N, C, H, W)
        weight (np.ndarray): kernel of shape (F, C / group, H, W)
        bias (np.ndarray): bias of shape (F,)
        pads (Union[Tuple[int, ...], List[int]]):
            padding over each axis (H_beg, W_beg, H_end, W_end)
        strides (Union[Tuple[int, ...], List[int]]): stride over height and width
        dilations (Union[Tuple[int, ...], List[int]]): dilation over height and width
        group (int, optional):
            number of groups input channels and output channels are divided into.

    Returns:
        np.ndarray: result of the convolution
    """
    return _evaluate_conv(x, weight, bias, pads, strides, dilations, group, "conv2d")


def _evaluate_conv3d(
    x: np.ndarray,
    weight: np.ndarray,
    bias: np.ndarray,
    pads: Union[Tuple[int, ...], List[int]],
    strides: Union[Tuple[int, ...], List[int]],
    dilations: Union[Tuple[int, ...], List[int]],
    group: int,
) -> np.ndarray:
    """
    Evaluate 3D convolution.

    Args:
        x (np.ndarray): input of shape (N, C, D, H, W)
        weight (np.ndarray): kernel of shape (F, C / group, D, H, W)
        bias (np.ndarray): bias of shape (F,)
        pads (Union[Tuple[int, ...], List[int]]):
            padding over each axis (D_beg, H_beg, W_beg, D_end, H_end, W_end)
        strides (Union[Tuple[int, ...], List[int]]): stride over D, height, and width
        dilations (Union[Tuple[int, ...], List[int]]): dilation over D, height, and width
        group (int, optional):
            number of groups input channels and output channels are divided into.

    Returns:
        np.ndarray: result of the convolution
    """
    return _evaluate_conv(x, weight, bias, pads, strides, dilations, group, "conv3d")


def _evaluate_conv(
    x: np.ndarray,
    weight: np.ndarray,
    bias: np.ndarray,
    pads: Union[Tuple[int, ...], List[int]],
    strides: Union[Tuple[int, ...], List[int]],
    dilations: Union[Tuple[int, ...], List[int]],
    group: int,
    conv_func: str,
) -> np.ndarray:
    """
    Evaluate convolution.

    Args:
        x (np.ndarray): input of shape (N, C, D1, ..., DN)
        weight (np.ndarray): kernel of shape (F, C / group, K1, ..., KN)
        bias (np.ndarray): bias of shape (F,)
        pads (Union[Tuple[int, ...], List[int]]):
            padding for the beginning and ending along each spatial axis
            (D1_begin, D2_begin, ..., D1_end, D2_end, ...).
        strides (Union[Tuple[int, ...], List[int]]): stride along each spatial axis.
        dilations (Union[Tuple[int, ...], List[int]]): dilation along each spatial axis.
        group (int, optional):
            number of groups input channels and output channels are divided into.
        conv_func (str): convolution to apply, should be one of {conv1d,conv2d,conv3d}

    Returns:
        np.ndarray: result of the convolution
    """

    # pylint: disable=no-member
    conv_funcs = {
        "conv1d": torch.conv1d,
        "conv2d": torch.conv2d,
        "conv3d": torch.conv3d,
    }

    torch_conv_func = conv_funcs.get(conv_func, None)
    assert_that(
        torch_conv_func is not None,
        f"expected conv_func to be one of {list(conv_funcs.keys())}, but got {conv_func}",
    )
    torch_conv_func = cast(Callable, torch_conv_func)

    n_dim = x.ndim - 2  # remove batch_size and channel dims
    torch_padding = []
    for dim in range(n_dim):
        if pads[dim] != pads[n_dim + dim]:
            message = (
                f"padding should be the same for the beginning of the dimension and its end, but "
                f"got {pads[dim]} in the beginning, and {pads[n_dim + dim]} at the end for "
                f"dimension {dim}"
            )
            raise ValueError(message)
        torch_padding.append(pads[dim])

    dtype = (
        torch.float64
        if np.issubdtype(x.dtype, np.floating)
        or np.issubdtype(weight.dtype, np.floating)
        or np.issubdtype(bias.dtype, np.floating)
        else torch.long
    )
    return torch_conv_func(
        torch.tensor(x, dtype=dtype),
        torch.tensor(weight, dtype=dtype),
        torch.tensor(bias, dtype=dtype),
        stride=strides,
        padding=torch_padding,
        dilation=dilations,
        groups=group,
    ).numpy()

    # pylint: enable=no-member
336
frontends/concrete-python/concrete/onnx/maxpool.py
Normal file
@@ -0,0 +1,336 @@
"""
Tracing and evaluation of maxpool function.
"""

from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..numpy.internal.utils import assert_that
from ..numpy.representation import Node
from ..numpy.tracing import Tracer
from ..numpy.values import Value

# pylint: disable=too-many-branches,too-many-statements


AVAILABLE_AUTO_PAD = {
    "NOTSET",
    "SAME_UPPER",
    "SAME_LOWER",
    "VALID",
}

AVAILABLE_CEIL_MODE = {
    0,
    1,
}

AVAILABLE_STORAGE_ORDER = {
    0,
    1,
}


SUPPORTED_AUTO_PAD = {
    "NOTSET",
}

SUPPORTED_CEIL_MODE = {
    0,
}

SUPPORTED_STORAGE_ORDER = {
    0,
}


# pylint: disable=no-member

_EVALUATORS = {
    1: torch.max_pool1d,
    2: torch.max_pool2d,
    3: torch.max_pool3d,
}

# pylint: enable=no-member


def maxpool(
    x: Union[np.ndarray, Tracer],
    kernel_shape: Union[Tuple[int, ...], List[int]],
    strides: Optional[Union[Tuple[int, ...], List[int]]] = None,
    auto_pad: str = "NOTSET",
    pads: Optional[Union[Tuple[int, ...], List[int]]] = None,
    dilations: Optional[Union[Tuple[int, ...], List[int]]] = None,
    ceil_mode: int = 0,
    storage_order: int = 0,
) -> Union[np.ndarray, Tracer]:
    """
    Evaluate or trace MaxPool operation.

    Refer to https://github.com/onnx/onnx/blob/main/docs/Operators.md#maxpool for more info.

    Args:
        x (Union[np.ndarray, Tracer]):
            input of shape (N, C, D1, ..., DN)

        kernel_shape (Union[Tuple[int, ...], List[int]]):
            shape of the kernel

        strides (Optional[Union[Tuple[int, ...], List[int]]]):
            stride along each spatial axis
            set to 1 along each spatial axis if not set

        auto_pad (str, default = "NOTSET"):
            padding strategy

        pads (Optional[Union[Tuple[int, ...], List[int]]]):
            padding for the beginning and ending along each spatial axis
            (D1_begin, D2_begin, ..., D1_end, D2_end, ...)
            set to 0 along each spatial axis if not set

        dilations (Optional[Union[Tuple[int, ...], List[int]]]):
            dilation along each spatial axis
            set to 1 along each spatial axis if not set

        ceil_mode (int, default = 0):
            ceiling mode

        storage_order (int, default = 0):
            storage order, 0 for row major, 1 for column major

    Raises:
        TypeError:
            if arguments are inappropriately typed

        ValueError:
            if arguments are inappropriate

        NotImplementedError:
            if desired operation is not supported yet

    Returns:
        Union[np.ndarray, Tracer]:
            maxpool over the input or traced computation
    """

    def check_value_is_a_tuple_or_list_of_ints_of_size(value_name, value, size) -> Tuple[int, ...]:
        if isinstance(value, list):
            value = tuple(value)

        if not isinstance(value, tuple):
            message = (
                f"Expected {value_name} to be a tuple or a list but it's {type(value).__name__}"
            )
            raise TypeError(message)

        for element in value:
            if not isinstance(element, int):
                message = (
                    f"Expected {value_name} to consist of integers "
                    f"but it has an element of type {type(element).__name__}"
                )
                raise TypeError(message)

        if len(value) != size:
            message = f"Expected {value_name} to have {size} elements but it has {len(value)}"
            raise ValueError(message)

        return value

    # check x

    if isinstance(x, list):  # pragma: no cover
        try:
            x = np.array(x)
        except Exception:  # pylint: disable=broad-except
            pass

    if isinstance(x, np.ndarray):
        if not (
            np.issubdtype(x.dtype, np.integer)
            or np.issubdtype(x.dtype, np.floating)
            or np.issubdtype(x.dtype, np.bool_)
        ):
            message = (
                f"Expected input elements to be of type np.integer, np.floating, or np.bool_ "
                f"but it's {type(x.dtype).__name__}"
            )
            raise TypeError(message)
    elif not isinstance(x, Tracer):
        message = (
            f"Expected input to be of type np.ndarray or Tracer "
            f"but it's {type(x).__name__}"
        )
        raise TypeError(message)

    if x.ndim < 3:
        message = (
            f"Expected input to have at least 3 dimensions (N, C, D1, ...) "
            f"but it only has {x.ndim}"
        )
        raise ValueError(message)

    if x.ndim > 5:
        message = f"{x.ndim - 2}D maximum pooling is not supported yet"
        raise NotImplementedError(message)

    # check kernel_shape

    kernel_shape = check_value_is_a_tuple_or_list_of_ints_of_size(
        "kernel_shape", kernel_shape, x.ndim - 2
    )

    # check strides

    if strides is None:
        strides = (1,) * (x.ndim - 2)

    strides = check_value_is_a_tuple_or_list_of_ints_of_size("strides", strides, x.ndim - 2)

    # check auto_pad

    if not isinstance(auto_pad, str):
        message = f"Expected auto_pad to be of type str but it's {type(auto_pad).__name__}"
        raise TypeError(message)

    if auto_pad not in AVAILABLE_AUTO_PAD:
        message = (
            f"Expected auto_pad to be one of "
            f"{', '.join(sorted(AVAILABLE_AUTO_PAD))} "
            f"but it's {auto_pad}"
        )
        raise ValueError(message)

    if auto_pad not in SUPPORTED_AUTO_PAD:
        message = f"Desired auto_pad of {auto_pad} is not supported yet"
        raise NotImplementedError(message)

    # check pads

    if pads is None:
        pads = (0,) * (2 * (x.ndim - 2))

    pads = check_value_is_a_tuple_or_list_of_ints_of_size("pads", pads, 2 * (x.ndim - 2))

    for i in range(len(pads) // 2):
        pad_begin = pads[i]
        pad_end = pads[i + len(pads) // 2]
        if pad_begin != pad_end:
            message = f"Desired pads of {pads} is not supported yet because of uneven padding"
            raise NotImplementedError(message)

    # check dilations

    if dilations is None:
        dilations = (1,) * (x.ndim - 2)

    dilations = check_value_is_a_tuple_or_list_of_ints_of_size("dilations", dilations, x.ndim - 2)

    # check ceil_mode

    if not isinstance(ceil_mode, int):
        message = f"Expected ceil_mode to be of type int but it's {type(ceil_mode).__name__}"
        raise TypeError(message)

    if ceil_mode not in AVAILABLE_CEIL_MODE:
        message = (
            f"Expected ceil_mode to be one of "
            f"{', '.join(sorted(str(x) for x in AVAILABLE_CEIL_MODE))} "
            f"but it's {ceil_mode}"
        )
        raise ValueError(message)

    if ceil_mode not in SUPPORTED_CEIL_MODE:
        message = f"Desired ceil_mode of {ceil_mode} is not supported yet"
        raise NotImplementedError(message)

    # check storage_order

    if not isinstance(storage_order, int):
        message = (
            f"Expected storage_order to be of type int but it's {type(storage_order).__name__}"
        )
        raise TypeError(message)

    if storage_order not in AVAILABLE_STORAGE_ORDER:
        message = (
            f"Expected storage_order to be one of "
            f"{', '.join(sorted(str(x) for x in AVAILABLE_STORAGE_ORDER))} "
            f"but it's {storage_order}"
        )
        raise ValueError(message)

    if storage_order not in SUPPORTED_STORAGE_ORDER:
        message = f"Desired storage_order of {storage_order} is not supported yet"
        raise NotImplementedError(message)

    # trace or evaluate
    return _trace_or_evaluate(x, kernel_shape, strides, pads, dilations, ceil_mode == 1)
|
||||
|
||||
|
||||
def _trace_or_evaluate(
|
||||
x: Union[np.ndarray, Tracer],
|
||||
kernel_shape: Tuple[int, ...],
|
||||
strides: Tuple[int, ...],
|
||||
pads: Tuple[int, ...],
|
||||
dilations: Tuple[int, ...],
|
||||
ceil_mode: bool,
|
||||
):
|
||||
if not isinstance(x, Tracer):
|
||||
return _evaluate(x, kernel_shape, strides, pads, dilations, ceil_mode)
|
||||
|
||||
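# evaluate on a dummy all-zeros array of the input's shape to derive the output shape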
result = _evaluate(np.zeros(x.shape), kernel_shape, strides, pads, dilations, ceil_mode)
|
||||
resulting_value = Value.of(result)
|
||||
|
||||
resulting_value.is_encrypted = x.output.is_encrypted
|
||||
resulting_value.dtype = x.output.dtype
|
||||
|
||||
computation = Node.generic(
|
||||
"maxpool",
|
||||
[x.output],
|
||||
resulting_value,
|
||||
_evaluate,
|
||||
kwargs={
|
||||
"kernel_shape": kernel_shape,
|
||||
"strides": strides,
|
||||
"pads": pads,
|
||||
"dilations": dilations,
|
||||
"ceil_mode": ceil_mode,
|
||||
},
|
||||
)
|
||||
return Tracer(computation, [x])
|
||||
|
||||
|
||||
def _evaluate(
|
||||
x: np.ndarray,
|
||||
kernel_shape: Tuple[int, ...],
|
||||
strides: Tuple[int, ...],
|
||||
pads: Tuple[int, ...],
|
||||
dilations: Tuple[int, ...],
|
||||
ceil_mode: bool,
|
||||
) -> np.ndarray:
|
||||
# pylint: disable=no-member
|
||||
|
||||
dims = x.ndim - 2
|
||||
assert_that(dims in {1, 2, 3})
|
||||
|
||||
evaluator = _EVALUATORS[dims]
|
||||
result = (
|
||||
evaluator(
|
||||
torch.from_numpy(x.astype(np.float64)), # torch only supports float maxpools
|
||||
kernel_shape,
|
||||
strides,
|
||||
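# torch applies each pad symmetrically, so only the begin pads are passed (uneven padding was rejected above)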
pads[: len(pads) // 2],
|
||||
dilations,
|
||||
ceil_mode,
|
||||
)
|
||||
.numpy()
|
||||
.astype(x.dtype)
|
||||
)
|
||||
|
||||
# pylint: enable=no-member
|
||||
|
||||
return result
|
||||
53
frontends/concrete-python/docker/Dockerfile.dev
Normal file
@@ -0,0 +1,53 @@
|
||||
FROM ubuntu:22.04
|
||||
|
||||
ENV TZ=Europe/Paris
|
||||
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
|
||||
|
||||
# Replace default archive.ubuntu.com with fr mirror
|
||||
# original archive showed performance issues and is farther away
|
||||
RUN sed -i 's|^deb http://archive|deb http://fr.archive|g' /etc/apt/sources.list
|
||||
|
||||
COPY ./script/make_utils/setup_os_deps.sh ./setup_os_deps.sh
|
||||
RUN ./setup_os_deps.sh --linux-install-python && rm ./setup_os_deps.sh
|
||||
|
||||
ENV SRC_DIR=/src
|
||||
|
||||
# Default to Ubuntu default uid for first user
|
||||
ARG BUILD_GID=1000
|
||||
ARG BUILD_UID=1000
|
||||
|
||||
# Get sudo for our future user
|
||||
RUN apt-get update && \
|
||||
apt-get install --no-install-recommends -y sudo && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# From https://dev.to/emmanuelnk/using-sudo-without-password-prompt-as-non-root-docker-user-52bg
|
||||
# Create dev_user and add it to relevant groups
|
||||
# Create /src and make the dev user own it
|
||||
# Ensure sudo group users are not asked for a password when using
|
||||
# sudo command by amending the sudoers file
|
||||
RUN groupadd -g "${BUILD_GID}" dev_user && \
|
||||
adduser --disabled-password \
|
||||
--uid "${BUILD_UID}" --gid "${BUILD_GID}" --shell /bin/bash --gecos "" dev_user && \
|
||||
usermod -aG sudo dev_user && \
|
||||
mkdir -p "${SRC_DIR}" && \
|
||||
chown dev_user "${SRC_DIR}" && \
|
||||
echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers
|
||||
|
||||
# Now switch to the newly created user
|
||||
USER dev_user
|
||||
|
||||
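# On each login: try to activate the dev venv; if it doesn't exist yet, create it, activate it, and set up the environment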
RUN echo "source ~/dev_venv/bin/activate" >> ~/.bashrc && \
|
||||
echo "if [[ \"\$?\" != \"0\" ]]; then" >> ~/.bashrc && \
|
||||
echo " python3 -m venv ~/dev_venv" >> ~/.bashrc && \
|
||||
echo " source ~/dev_venv/bin/activate" >> ~/.bashrc && \
|
||||
echo " cd ${SRC_DIR}/ && make setup_env" >> ~/.bashrc && \
|
||||
echo "fi" >> ~/.bashrc && \
|
||||
echo "export MPLBACKEND=TkAgg" >> ~/.bashrc && \
|
||||
touch ~/.sudo_as_admin_successful && \
|
||||
mkdir -p ~/dev_venv && \
|
||||
mkdir -p ~/.cache
|
||||
|
||||
WORKDIR ${SRC_DIR}
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -0,0 +1 @@
|
||||
!script/make_utils/setup_os_deps.sh
|
||||
36
frontends/concrete-python/docker/Dockerfile.release
Normal file
@@ -0,0 +1,36 @@
|
||||
FROM ubuntu:22.04
|
||||
|
||||
ENV TZ=Europe/Paris
|
||||
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
|
||||
|
||||
# Replace default archive.ubuntu.com with fr mirror
|
||||
# original archive showed performance issues and is farther away
|
||||
RUN sed -i 's|^deb http://archive|deb http://fr.archive|g' /etc/apt/sources.list
|
||||
|
||||
RUN mkdir /pkg && mkdir /app
|
||||
WORKDIR /pkg
|
||||
COPY docker/release_resources/release_requirements.txt .
|
||||
COPY ./pkg/*.whl .
|
||||
|
||||
RUN apt-get update && apt-get upgrade --no-install-recommends -y && \
|
||||
apt-get install --no-install-recommends -y \
|
||||
build-essential \
|
||||
python3-pip \
|
||||
python3 \
|
||||
python3-dev \
|
||||
python3-tk \
|
||||
python-is-python3 && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
python3 -m pip install --no-cache-dir --upgrade pip wheel setuptools && \
|
||||
echo "export MPLBACKEND=TkAgg" >> /root/.bashrc && \
|
||||
python3 -m pip install --no-cache-dir "$(ls ./*.whl)" && \
|
||||
python3 -m pip install --no-cache-dir -r release_requirements.txt
|
||||
|
||||
WORKDIR /app
|
||||
COPY docker/release_resources/entry_point.sh ./entry_point.sh
|
||||
RUN mkdir /data
|
||||
|
||||
WORKDIR /data
|
||||
VOLUME [ "/data" ]
|
||||
|
||||
CMD ["/bin/bash", "-i", "/app/entry_point.sh"]
|
||||
@@ -0,0 +1,6 @@
|
||||
# Not our sources
|
||||
!docker/release_resources/entry_point.sh
|
||||
!docker/release_resources/release_requirements.txt
|
||||
|
||||
!pkg/
|
||||
!pkg/**
|
||||
5
frontends/concrete-python/docker/build_release_image.sh
Executable file
@@ -0,0 +1,5 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
CURR_DIR=$(dirname "$0")
|
||||
DOCKER_BUILDKIT=1 docker build --pull --no-cache -f "$CURR_DIR/Dockerfile.release" \
|
||||
-t concrete-numpy-release "$CURR_DIR/.."
|
||||
3
frontends/concrete-python/docker/release_resources/entry_point.sh
Executable file
@@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
python3 -m jupyter notebook --ip=0.0.0.0 --allow-root --no-browser
|
||||
@@ -0,0 +1 @@
|
||||
jupyter~=1.0.0
|
||||
@@ -0,0 +1,41 @@
|
||||
import random
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
def main():
|
||||
def function_to_compile(x):
|
||||
return x + 42
|
||||
|
||||
n_bits = 3
|
||||
|
||||
compiler = cnp.Compiler(
|
||||
function_to_compile,
|
||||
{"x": "encrypted"},
|
||||
)
|
||||
|
||||
print("Compiling...")
|
||||
|
||||
engine = compiler.compile(range(2 ** n_bits))
|
||||
|
||||
inputs = []
|
||||
labels = []
|
||||
for _ in range(4):
|
||||
sample_x = random.randint(0, 2 ** n_bits - 1)
|
||||
|
||||
inputs.append([sample_x])
|
||||
labels.append(function_to_compile(*inputs[-1]))
|
||||
|
||||
correct = 0
|
||||
for idx, (input_i, label_i) in enumerate(zip(inputs, labels), 1):
|
||||
print(f"Inference #{idx}")
|
||||
result_i = engine.encrypt_run_decrypt(*input_i)
|
||||
|
||||
if result_i == label_i:
|
||||
correct += 1
|
||||
|
||||
print(f"{correct}/{len(inputs)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
24
frontends/concrete-python/docs/README.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# What is Concrete Numpy?
|
||||
|
||||
[<mark style="background-color:yellow;">⭐️ Star the repo on Github</mark>](https://github.com/zama-ai/concrete-numpy) <mark style="background-color:yellow;">| 🗣</mark> [<mark style="background-color:yellow;">Community support forum</mark>](https://community.zama.ai/c/concrete-numpy) <mark style="background-color:yellow;">| 📁</mark> [<mark style="background-color:yellow;">Contribute to the project</mark>](dev/contributing.md)
|
||||
|
||||
<figure><img src="_static/zama_home_docs.png" alt=""><figcaption></figcaption></figure>
|
||||
|
||||
**Concrete-Numpy** is an open-source library which simplifies the use of fully homomorphic encryption (FHE).
|
||||
|
||||
FHE is a powerful cryptographic tool, which allows computation to be performed directly on encrypted data without needing to decrypt it first. With FHE, you can build services that preserve privacy for all users. FHE is also a great defense against data breaches, as everything is computed on encrypted data: even if the server is compromised, no sensitive data is leaked.
|
||||
|
||||
## Organization of this documentation
|
||||
|
||||
This documentation is split into several sections:
|
||||
|
||||
* **Getting Started** gives you the basics,
|
||||
* **Tutorials** gives you some essential examples on various features of the library,
|
||||
* **How to** helps you perform specific tasks,
|
||||
* and **Developer** explains the inner workings of the library and everything related to contributing to the project.
|
||||
|
||||
## Looking for support? Ask our team!
|
||||
|
||||
* Support forum: [https://community.zama.ai](https://community.zama.ai) (we answer in less than 24 hours).
|
||||
* Live discussion on the FHE.org discord server: [https://discord.fhe.org](https://discord.fhe.org) (inside the #**concrete** channel).
|
||||
* Do you have a question about Zama? You can write us on [Twitter](https://twitter.com/zama\_fhe) or send us an email at: **hello@zama.ai**
|
||||
40
frontends/concrete-python/docs/SUMMARY.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# Table of contents
|
||||
|
||||
* [What is Concrete Numpy?](README.md)
|
||||
|
||||
## Getting Started
|
||||
|
||||
* [Installation](getting-started/installing.md)
|
||||
* [Quick Start](getting-started/quick\_start.md)
|
||||
* [Compatibility](getting-started/compatibility.md)
|
||||
* [Exactness](getting-started/exactness.md)
|
||||
* [Performance](getting-started/performance.md)
|
||||
|
||||
## Tutorials
|
||||
|
||||
* [Decorator](tutorial/decorator.md)
|
||||
* [Formatting](tutorial/formatting.md)
|
||||
* [Tagging](tutorial/tagging.md)
|
||||
* [Extensions](tutorial/extensions.md)
|
||||
* [Table Lookups](tutorial/table\_lookups.md)
|
||||
* [Rounded Table Lookups](tutorial/rounded\_table\_lookups.md)
|
||||
* [Floating Points](tutorial/floating\_points.md)
|
||||
* [Simulation](tutorial/simulation.md)
|
||||
* [Direct Circuits](tutorial/direct\_circuits.md)
|
||||
* [Key Value Database](tutorial/key\_value\_database.md)
|
||||
|
||||
## How To
|
||||
|
||||
* [Configure](howto/configure.md)
|
||||
* [Debug](howto/debug.md)
|
||||
* [Deploy](howto/deploy.md)
|
||||
|
||||
## Developer
|
||||
|
||||
* [Project Setup](dev/project\_setup.md)
|
||||
* [Docker Setup](dev/docker.md)
|
||||
* [Contribute](dev/contributing.md)
|
||||
* [Terminology and Structure](dev/terminology\_and\_structure.md)
|
||||
* [Compilation](dev/compilation.md)
|
||||
* [Fusing](dev/fusing.md)
|
||||
* [MLIR](dev/mlir.md)
|
||||
BIN
frontends/concrete-python/docs/_static/basics/compiling_and_executing_example_graph.png
vendored
Normal file
|
After Width: | Height: | Size: 42 KiB |
BIN
frontends/concrete-python/docs/_static/compilation-pipeline/two_x_plus_three.png
vendored
Normal file
|
After Width: | Height: | Size: 22 KiB |
BIN
frontends/concrete-python/docs/_static/float_fusing_example/after.png
vendored
Normal file
|
After Width: | Height: | Size: 16 KiB |
BIN
frontends/concrete-python/docs/_static/float_fusing_example/after_bigger_search.png
vendored
Normal file
|
After Width: | Height: | Size: 28 KiB |
BIN
frontends/concrete-python/docs/_static/float_fusing_example/before.png
vendored
Normal file
|
After Width: | Height: | Size: 59 KiB |
BIN
frontends/concrete-python/docs/_static/float_fusing_example/before_bigger_search.png
vendored
Normal file
|
After Width: | Height: | Size: 28 KiB |
BIN
frontends/concrete-python/docs/_static/float_fusing_example/subgraph.png
vendored
Normal file
|
After Width: | Height: | Size: 61 KiB |
BIN
frontends/concrete-python/docs/_static/float_fusing_example/subgraph_bigger_search.png
vendored
Normal file
|
After Width: | Height: | Size: 38 KiB |
BIN
frontends/concrete-python/docs/_static/mlir/MLIR_conversion.png
vendored
Normal file
|
After Width: | Height: | Size: 22 KiB |
BIN
frontends/concrete-python/docs/_static/p_error_simulation.pdf
vendored
Normal file
BIN
frontends/concrete-python/docs/_static/rounded-tlu/10-bits-removed.png
vendored
Normal file
|
After Width: | Height: | Size: 32 KiB |
BIN
frontends/concrete-python/docs/_static/rounded-tlu/12-bits-removed.png
vendored
Normal file
|
After Width: | Height: | Size: 27 KiB |
BIN
frontends/concrete-python/docs/_static/rounded-tlu/4-bits-kept.png
vendored
Normal file
|
After Width: | Height: | Size: 23 KiB |
BIN
frontends/concrete-python/docs/_static/rounded-tlu/6-bits-kept.png
vendored
Normal file
|
After Width: | Height: | Size: 27 KiB |
BIN
frontends/concrete-python/docs/_static/rounded-tlu/relu.png
vendored
Normal file
|
After Width: | Height: | Size: 39 KiB |
BIN
frontends/concrete-python/docs/_static/tutorials/artifacts/auto/1.initial.graph.png
vendored
Normal file
|
After Width: | Height: | Size: 8.4 KiB |
BIN
frontends/concrete-python/docs/_static/tutorials/artifacts/auto/2.final.graph.png
vendored
Normal file
|
After Width: | Height: | Size: 8.4 KiB |
BIN
frontends/concrete-python/docs/_static/tutorials/artifacts/manual/1.initial.graph.png
vendored
Normal file
|
After Width: | Height: | Size: 42 KiB |
BIN
frontends/concrete-python/docs/_static/tutorials/artifacts/manual/2.after-fusing.graph.png
vendored
Normal file
|
After Width: | Height: | Size: 21 KiB |
BIN
frontends/concrete-python/docs/_static/tutorials/artifacts/manual/3.final.graph.png
vendored
Normal file
|
After Width: | Height: | Size: 21 KiB |
BIN
frontends/concrete-python/docs/_static/tutorials/table-lookup/1.initial.graph.png
vendored
Normal file
|
After Width: | Height: | Size: 33 KiB |
BIN
frontends/concrete-python/docs/_static/tutorials/table-lookup/3.final.graph.png
vendored
Normal file
|
After Width: | Height: | Size: 22 KiB |
BIN
frontends/concrete-python/docs/_static/zama_home_docs.png
vendored
Normal file
|
After Width: | Height: | Size: 377 KiB |
1
frontends/concrete-python/docs/conftest.py
Symbolic link
@@ -0,0 +1 @@
|
||||
../tests/conftest.py
|
||||
148
frontends/concrete-python/docs/dev/compilation.md
Normal file
@@ -0,0 +1,148 @@
|
||||
# Compilation
|
||||
|
||||
The compilation journey begins with tracing to get an easy-to-manipulate representation of the function. We call this representation a `Computation Graph`, which is basically a Directed Acyclic Graph (DAG) containing nodes representing the computations done in the function. Graphs are convenient to work with because they have been studied extensively over the years, and there are many algorithms to manipulate them. Internally, we use [networkx](https://networkx.org), an excellent graph library for Python.
|
||||
|
||||
The next step in the compilation is transforming the computation graph. There are many transformations we perform, and they will be discussed in their own sections. In any case, the result of transformations is just another computation graph.
|
||||
|
||||
After transformations are applied, we need to determine the bounds (i.e., the minimum and the maximum values) of each intermediate node. This is required because FHE currently allows a limited precision for computations. Bounds measurement is how we determine the precision required for the function.
|
||||
|
||||
The final step is to transform the computation graph to equivalent `MLIR` code. How this is done will be explained in detail in its own chapter.
|
||||
|
||||
Once the MLIR is generated, we send it to the **Concrete-Compiler**, and it completes the compilation process.
|
||||
|
||||
## Tracing
|
||||
|
||||
Given a Python function `f` such as this one:
|
||||
|
||||
```python
|
||||
def f(x):
|
||||
    return (2 * x) + 3
|
||||
```
|
||||
|
||||
...the goal of tracing is to create the following computation graph without needing any change from the user.
|
||||
|
||||

|
||||
|
||||
(Note that the edge labels are for non-commutative operations. To give an example, a subtraction node represents `(predecessor with edge label 0) - (predecessor with edge label 1)`)
|
||||
|
||||
To do this, we make use of `Tracer`s, which are objects that record the operation performed during their creation. We create a `Tracer` for each argument of the function and call the function with those tracers. `Tracer`s make use of the operator overloading feature of Python to achieve their goal:
|
||||
|
||||
```python
|
||||
def f(x, y):
|
||||
    return x + 2 * y
|
||||
|
||||
x = Tracer(computation=Input("x"))
|
||||
y = Tracer(computation=Input("y"))
|
||||
|
||||
resulting_tracer = f(x, y)
|
||||
```
|
||||
|
||||
`2 * y` will be performed first, and `*` is overloaded for `Tracer` to return another tracer: `Tracer(computation=Multiply(Constant(2), self.computation))`, which is equal to `Tracer(computation=Multiply(Constant(2), Input("y")))`
|
||||
|
||||
`x + (2 * y)` will be performed next, and `+` is overloaded for `Tracer` to return another tracer: `Tracer(computation=Add(self.computation, (2 * y).computation))`, which is equal to `Tracer(computation=Add(Input("x"), Multiply(Constant(2), Input("y"))))`.
|
||||
|
||||
In the end, we will have output tracers that can be used to create the computation graph. The implementation is a bit more complex than this, but the idea is the same.
|
||||
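Here is a minimal, self-contained sketch of the idea (not the actual implementation), showing how operator overloading records the computation:

```python
from dataclasses import dataclass
from typing import Any

def _lift(value):
    # wrap plain constants so everything in the trace is a computation
    return value.computation if isinstance(value, Tracer) else ("constant", value)

@dataclass
class Tracer:
    computation: Any  # e.g., ("input", "x") or ("add", lhs, rhs)

    def __add__(self, other):
        return Tracer(("add", self.computation, _lift(other)))

    def __radd__(self, other):
        return Tracer(("add", _lift(other), self.computation))

    def __mul__(self, other):
        return Tracer(("multiply", self.computation, _lift(other)))

    def __rmul__(self, other):
        return Tracer(("multiply", _lift(other), self.computation))

def f(x, y):
    return x + 2 * y

trace = f(Tracer(("input", "x")), Tracer(("input", "y")))
print(trace.computation)
# ('add', ('input', 'x'), ('multiply', ('constant', 2), ('input', 'y')))
```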
|
||||
Tracing is also responsible for indicating whether the values in the node would be encrypted or not, and the rule for that is if a node has an encrypted predecessor, it is encrypted as well.
|
||||
|
||||
## Topological transforms
|
||||
|
||||
The goal of topological transforms is to make more functions compilable.
|
||||
|
||||
With the current version of **Concrete-Numpy**, floating-point inputs and floating-point outputs are not supported. However, if the floating-point operations are intermediate operations, they can sometimes be fused into a single table lookup from integer to integer, thanks to some specific transforms.
|
||||
|
||||
Let's take a closer look at the transforms we can currently perform.
|
||||
|
||||
### Fusing.
|
||||
|
||||
We have allocated a whole new chapter to explaining fusing. You can find it after this chapter.
|
||||
|
||||
## Bounds measurement
|
||||
|
||||
Given a computation graph, the goal of the bounds measurement step is to assign the minimal data type to each node in the graph.
|
||||
|
||||
Let's say we have an encrypted input that is always between `0` and `10`. We should assign the type `Encrypted<uint4>` to the node of this input as `Encrypted<uint4>` is the minimal encrypted integer that supports all values between `0` and `10`.
|
||||
|
||||
If there were negative values in the range, we could have used `intX` instead of `uintX`.
|
||||
|
||||
Bounds measurement is necessary because FHE supports limited precision, and we don't want unexpected behaviour while evaluating the compiled functions.
|
||||
|
||||
Let's take a closer look at how we perform bounds measurement.
|
||||
|
||||
### Inputset evaluation.
|
||||
|
||||
This is a simple approach that requires an inputset to be provided by the user.
|
||||
|
||||
The inputset is not to be confused with a dataset, which is classical in ML, as it doesn't require labels. Rather, it is a set of values that are typical inputs of the function.
|
||||
|
||||
The idea is to evaluate each input in the inputset and record the result of each operation in the computation graph. Then we compare the evaluation results with the current minimum/maximum values of each node and update the minimum/maximum accordingly. After the entire inputset is evaluated, we assign a data type to each node using the minimum and the maximum values it contains.
|
||||
|
||||
Here is an example, given this computation graph where `x` is encrypted:
|
||||
|
||||

|
||||
|
||||
and this inputset:
|
||||
|
||||
```python
|
||||
[2, 3, 1]
|
||||
```
|
||||
|
||||
Evaluation Result of `2`:
|
||||
|
||||
* `x`: 2
|
||||
* `2`: 2
|
||||
* `*`: 4
|
||||
* `3`: 3
|
||||
* `+`: 7
|
||||
|
||||
New Bounds:
|
||||
|
||||
* `x`: \[**2**, **2**]
|
||||
* `2`: \[**2**, **2**]
|
||||
* `*`: \[**4**, **4**]
|
||||
* `3`: \[**3**, **3**]
|
||||
* `+`: \[**7**, **7**]
|
||||
|
||||
Evaluation Result of `3`:
|
||||
|
||||
* `x`: 3
|
||||
* `2`: 2
|
||||
* `*`: 6
|
||||
* `3`: 3
|
||||
* `+`: 9
|
||||
|
||||
New Bounds:
|
||||
|
||||
* `x`: \[2, **3**]
|
||||
* `2`: \[2, 2]
|
||||
* `*`: \[4, **6**]
|
||||
* `3`: \[3, 3]
|
||||
* `+`: \[7, **9**]
|
||||
|
||||
Evaluation Result of `1`:
|
||||
|
||||
* `x`: 1
|
||||
* `2`: 2
|
||||
* `*`: 2
|
||||
* `3`: 3
|
||||
* `+`: 5
|
||||
|
||||
New Bounds:
|
||||
|
||||
* `x`: \[**1**, 3]
|
||||
* `2`: \[2, 2]
|
||||
* `*`: \[**2**, 6]
|
||||
* `3`: \[3, 3]
|
||||
* `+`: \[**5**, 9]
|
||||
|
||||
Assigned Data Types:
|
||||
|
||||
* `x`: Encrypted<**uint2**>
|
||||
* `2`: Clear<**uint2**>
|
||||
* `*`: Encrypted<**uint3**>
|
||||
* `3`: Clear<**uint2**>
|
||||
* `+`: Encrypted<**uint4**>
|
||||
|
||||
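As a sanity check, the assigned bit widths above can be recomputed from the bounds. Here is a minimal sketch (a plain Python helper for illustration, not part of the library) for unsigned values:

```python
import math

def required_unsigned_bit_width(low: int, high: int) -> int:
    # the minimal uintX that can represent every value in [low, high], with low >= 0
    assert 0 <= low <= high
    return max(1, math.ceil(math.log2(high + 1)))

assert required_unsigned_bit_width(1, 3) == 2  # `x` -> Encrypted<uint2>
assert required_unsigned_bit_width(2, 6) == 3  # `*` -> Encrypted<uint3>
assert required_unsigned_bit_width(5, 9) == 4  # `+` -> Encrypted<uint4>
```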
## MLIR conversion
|
||||
|
||||
The actual compilation will be done by the **Concrete-Compiler**, which is expecting an MLIR input. The MLIR conversion goes from a computation graph to its MLIR equivalent. You can read more about it [here](mlir.md).
|
||||
99
frontends/concrete-python/docs/dev/contributing.md
Normal file
@@ -0,0 +1,99 @@
|
||||
# Contribute
|
||||
|
||||
{% hint style="info" %}
|
||||
There are two ways to contribute to **Concrete-Numpy** or to **Concrete** tools in general:
|
||||
|
||||
* You can open issues to report bugs and typos and to suggest ideas.
|
||||
* You can ask to become an official contributor by emailing hello@zama.ai. Only approved contributors can send pull requests (PRs), so please make sure to get in touch before you do!
|
||||
{% endhint %}
|
||||
|
||||
Now, let's go over some other important items that you need to know.
|
||||
|
||||
## Creating a new branch
|
||||
|
||||
We are using a consistent branch naming scheme, and you are expected to follow it as well. Here is the format:
|
||||
|
||||
```shell
|
||||
git checkout -b {feat|fix|refactor|test|benchmark|doc|style|chore}/short-description
|
||||
```
|
||||
|
||||
...and here are some examples:
|
||||
|
||||
```shell
|
||||
git checkout -b feat/direct-tlu
|
||||
git checkout -b fix/tracing-indexing
|
||||
```
|
||||
|
||||
## Before committing
|
||||
|
||||
### Conformance.
|
||||
|
||||
Each commit to **Concrete-Numpy** should conform to the standards decided by the team. Conformance can be checked using the following command:
|
||||
|
||||
```shell
|
||||
make pcc
|
||||
```
|
||||
|
||||
### Testing.
|
||||
|
||||
On top of conformance, all tests must pass with 100% code coverage across the codebase:
|
||||
|
||||
```shell
|
||||
make pytest
|
||||
```
|
||||
|
||||
{% hint style="info" %}
|
||||
There may be cases where covering 100% of the code is not possible (e.g., exceptions that cannot be triggered in normal execution circumstances). In those cases, you may be allowed to disable coverage for some specific lines. This should be the exception rather than the rule. Reviewers may ask why some lines are not covered and, if it appears they can be covered, then the PR won't be accepted in that state.
|
||||
{% endhint %}
|
||||
|
||||
## Committing
|
||||
|
||||
We are using a consistent commit naming scheme, and you are expected to follow it as well. Commit messages look like `type(scope): description`, where the scope is optional; you can list the accepted scopes with:
|
||||
|
||||
```shell
|
||||
make show_scope
|
||||
```
|
||||
|
||||
...and some examples:
|
||||
|
||||
```shell
|
||||
git commit -m "feat: implement bounds checking"
|
||||
git commit -m "feat(debugging): add an helper function to print intermediate representation"
|
||||
git commit -m "fix(tracing): fix a bug that crashed pytorch tracer"
|
||||
```
|
||||
|
||||
To learn more about conventional commits, check [this](https://www.conventionalcommits.org/en/v1.0.0/) page.
|
||||
|
||||
## Before creating a pull request
|
||||
|
||||
{% hint style="info" %}
|
||||
We remind you that only official contributors can send pull requests. To become an official contributor, please email hello@zama.ai.
|
||||
{% endhint %}
|
||||
|
||||
You should rebase on top of the `main` branch before you create your pull request. We don't allow merge commits, so rebasing on `main` before pushing gives you the best chance of avoiding rewriting parts of your PR later if conflicts arise with other PRs being merged. After you commit your changes to your new branch, you can use the following commands to rebase:
|
||||
|
||||
```shell
|
||||
# fetch the list of active remote branches
|
||||
git fetch --all --prune
|
||||
|
||||
# checkout to main
|
||||
git checkout main
|
||||
|
||||
# pull the latest changes to main (--ff-only is there to prevent accidental commits to main)
|
||||
git pull --ff-only
|
||||
|
||||
# checkout back to your branch
|
||||
git checkout $YOUR_BRANCH
|
||||
|
||||
# rebase on top of main branch
|
||||
git rebase main
|
||||
|
||||
# If there are conflicts during the rebase, resolve them
|
||||
# and continue the rebase with the following command
|
||||
git rebase --continue
|
||||
|
||||
# push the latest version of the local branch to remote
|
||||
git push --force
|
||||
```
|
||||
|
||||
You can learn more about rebasing [here](https://git-scm.com/docs/git-rebase).
|
||||
45
frontends/concrete-python/docs/dev/docker.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# Docker Setup
|
||||
|
||||
## Installation
|
||||
|
||||
Before you start this section, go ahead and install Docker. You can follow [this](https://docs.docker.com/engine/install/) official guide if you need help.
|
||||
|
||||
## X forwarding
|
||||
|
||||
### Linux.
|
||||
|
||||
You can use this xhost command:
|
||||
|
||||
```shell
|
||||
xhost +localhost
|
||||
```
|
||||
|
||||
### macOS.
|
||||
|
||||
To use X forwarding on macOS:
|
||||
|
||||
* Install XQuartz
|
||||
* Open the XQuartz.app application and make sure in the application parameters that `authorize network connections` is set (currently in the Security settings)
|
||||
* Open a new terminal within XQuartz.app and type:
|
||||
|
||||
```shell
|
||||
xhost +127.0.0.1
|
||||
```
|
||||
|
||||
X server should be all set for Docker in the regular terminal.
|
||||
|
||||
## Building
|
||||
|
||||
You can use the dedicated target in the makefile to build the docker image:
|
||||
|
||||
```shell
|
||||
make docker_build
|
||||
```
|
||||
|
||||
## Starting
|
||||
|
||||
You can use the dedicated target in the makefile to start the docker session:
|
||||
|
||||
```shell
|
||||
make docker_start
|
||||
```
|
||||
37
frontends/concrete-python/docs/dev/fusing.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Fusing
|
||||
|
||||
Fusing is the act of combining multiple nodes into a single node, which is converted to a table lookup.
|
||||
|
||||
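For instance, here is a sketch (using the `cnp.compiler` decorator shown elsewhere in this documentation) of a function whose floating-point intermediates can be fused into a single table lookup:

```python
import concrete.numpy as cnp
import numpy as np

@cnp.compiler({"x": "encrypted"})
def f(x):
    # np.sin and the multiplication produce floating-point intermediates;
    # fusing collapses this chain into one integer-to-integer table lookup
    return (10 * np.sin(x)).astype(np.int64)

circuit = f.compile(range(2 ** 4))
```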
## How is it done?
|
||||
|
||||
Code related to fusing is in the `concrete/numpy/compilation/utils.py` file. Fusing can be performed using the `fuse` function.
|
||||
|
||||
Within `fuse`:
|
||||
|
||||
1. We loop until there are no more subgraphs to fuse.
|
||||
2. <mark style="background-color:yellow;">Within each iteration:</mark>
|
||||
2.1. We find a subgraph to fuse.
|
||||
|
||||
2.2. We search for a terminal node that is appropriate for fusing.
|
||||
|
||||
2.3. We crawl backwards to find the closest integer nodes to this node.
|
||||
|
||||
2.4. If there is a single node as such, we return the subgraph from this node to the terminal node.
|
||||
|
||||
2.5. Otherwise, we try to find the lowest common ancestor (lca) of this list of nodes.
|
||||
|
||||
2.6. If an lca doesn't exist, we say this particular terminal node is not fusable, and we go back to search for another subgraph.
|
||||
|
||||
2.7. Otherwise, we use this lca as the input of the subgraph and continue with `subgraph` node creation below.
|
||||
|
||||
2.8. We convert the subgraph into a `subgraph` node, checking the fusability status of each node of the subgraph in this step.
|
||||
|
||||
2.9. We substitute the `subgraph` node into the original graph.
|
||||
|
||||
## Limitations
|
||||
|
||||
With the current implementation, we cannot fuse subgraphs that depend on multiple encrypted values where those values don't have a common lca (e.g., `np.round(np.sin(x) + np.cos(y))`).
|
||||
|
||||
{% hint style="info" %}
|
||||
[Kolmogorov–Arnold representation theorem](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Arnold\_representation\_theorem) states that every multivariate continuous function can be represented as a superposition of continuous functions of one variable. Therefore, the case above could be handled in future versions of **Concrete-Numpy**.
|
||||
{% endhint %}
|
||||
17
frontends/concrete-python/docs/dev/mlir.md
Normal file
@@ -0,0 +1,17 @@
|
||||
# MLIR
|
||||
|
||||
The MLIR project is a sub-project of the LLVM project. It's designed to simplify building domain-specific compilers such as our **Concrete-Compiler**.
|
||||
|
||||
**Concrete-Compiler** accepts MLIR as an input and emits compiled assembly code for a target architecture.
|
||||
|
||||
**Concrete-Numpy** performs the MLIR generation from the computation graph. Code related to this conversion is in the `concrete/numpy/mlir` folder.
|
||||
|
||||
The conversion can be performed using the `convert` method of the `GraphConverter` class.
|
||||
|
||||
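As a rough sketch (exact import paths and signatures may differ between versions), the conversion looks like this:

```python
# a sketch; assumes `graph` is a traced and transformed computation graph
from concrete.numpy.mlir import GraphConverter

mlir_code = GraphConverter.convert(graph)
print(mlir_code)  # textual MLIR, ready to be passed to the Concrete-Compiler
```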
Within the `convert` method of `GraphConverter`:
|
||||
|
||||
* MLIR compatibility of the graph is checked;
|
||||
* bit width constraints are checked;
|
||||
* negative lookup tables are offset;
|
||||
* the computation graph is traversed and each node is converted to its corresponding MLIR representation using the `NodeConverter` class;
|
||||
* and the string representation of the resulting MLIR is returned.
|
||||
91
frontends/concrete-python/docs/dev/project_setup.md
Normal file
@@ -0,0 +1,91 @@
|
||||
# Project Setup
|
||||
|
||||
{% hint style="info" %}
|
||||
It is **strongly** recommended to use Docker as your development environment. However, you can set the project up on bare Linux or macOS as long as you have the required dependencies. You can see the required dependencies in `Dockerfile.dev` under the `docker` directory.
|
||||
{% endhint %}
|
||||
|
||||
## Installing `Python`
|
||||
|
||||
**Concrete-Numpy** is a `Python` library, so `Python` should be installed to develop it. `v3.8` and `v3.9` are currently the only supported versions.
|
||||
|
||||
You probably have Python already, but in case you don't, or in case you have an unsupported version, you can google `how to install python 3.8` and follow one of the results.
|
||||
|
||||
## Installing `Poetry`
|
||||
|
||||
`Poetry` is our package manager. It drastically simplifies dependency and environment management.
|
||||
|
||||
You can follow [this](https://python-poetry.org/docs/#installation) official guide to install it.
|
||||
|
||||
## Installing `make`
|
||||
|
||||
`make` is used to launch various commands such as formatting and testing.
|
||||
|
||||
On Linux, you can install `make` using the package manager of your distribution.
|
||||
|
||||
On macOS, you can install `gmake` via brew:
|
||||
|
||||
```shell
|
||||
brew install make
|
||||
```
|
||||
|
||||
{% hint style="info" %}
|
||||
In the following sections, be sure to use the proper `make` tool for your system (i.e., `make`, `gmake`, etc).
|
||||
{% endhint %}
|
||||
|
||||
## Cloning the repository
|
||||
|
||||
Now, it's time to get the source code of **Concrete-Numpy**.
|
||||
|
||||
Clone the git repository from GitHub using the protocol of your choice (ssh or https).
|
||||
|
||||
## Setting up the environment
|
||||
|
||||
Virtual environments are utilized to keep the project isolated from other `Python` projects in the system.
|
||||
|
||||
To create a new virtual environment and install dependencies, use the command:
|
||||
|
||||
```shell
|
||||
make setup_env
|
||||
```
|
||||
|
||||
## Activating the environment
|
||||
|
||||
To activate the newly created environment, use:
|
||||
|
||||
```shell
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
## Syncing the environment
|
||||
|
||||
From time to time, new dependencies will be added to the project and old ones will be removed. The command below will make sure the project has the proper environment, so run it regularly.
|
||||
|
||||
```shell
|
||||
make sync_env
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### In native setups.
|
||||
|
||||
If you are having issues in a native setup, you can try to re-create your environment like this:
|
||||
|
||||
```shell
|
||||
deactivate
|
||||
rm -rf .venv
|
||||
make setup_env
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
If the problem persists, you should consider using Docker. If you are working on a platform-specific feature and Docker is not an option, you should create an issue so that we can take a look at your problem.
|
||||
|
||||
### In docker setups.
|
||||
|
||||
If you are having issues in a docker setup, you can try to re-build the docker image:
|
||||
|
||||
```shell
|
||||
make docker_rebuild
|
||||
make docker_start
|
||||
```
|
||||
|
||||
If the problem persists, you should contact us for help.
|
||||
17
frontends/concrete-python/docs/dev/releasing.md
Normal file
@@ -0,0 +1,17 @@
|
||||
# Release process
|
||||
|
||||
## Release candidate cycle
|
||||
|
||||
Throughout the quarter, many release candidates are released. Those candidates are published to a private package repository. At the end of the quarter, we take the latest release candidate and release it on PyPI without the `rcX` tag.
|
||||
|
||||
## Release flow
|
||||
|
||||
* Check out the commit that you want to release (this commit and everything before it will be included in the release)
|
||||
* Run `make release`
|
||||
* Wait for CI to complete
|
||||
* Check out the `chore/version` branch
|
||||
* Run `VERSION=a.b.c-rcX make set_version` with the appropriate version
|
||||
* Push the branch to origin
|
||||
* Create a PR to merge it to main
|
||||
* Wait for CI to finish and get approval in the meantime
|
||||
* Merge the version update to main
|
||||
@@ -0,0 +1,26 @@
|
||||
# Terminology and Structure
|
||||
|
||||
## Terminology
|
||||
|
||||
Some terms used throughout the project include:
|
||||
|
||||
* computation graph - a data structure to represent a computation. This is basically a directed acyclic graph in which nodes are either inputs, constants or operations on other nodes.
|
||||
* tracing - the technique that takes a Python function from the user and generates the corresponding computation graph in an easy-to-read format.
|
||||
* bounds - before a computation graph is converted to MLIR, we need to know which type each node will output (e.g., uint3 vs. euint5). To determine that, we evaluate the graph on representative inputs and remember the minimum and maximum values seen at each node; these are what we call bounds, and they are used to pick the appropriate type for each node.
|
||||
* circuit - the result of compilation. A circuit is made of client and server components and exposes methods for everything from printing to evaluation.
|
||||
|
||||
## Module structure
|
||||
|
||||
In this section, we will briefly discuss the module structure of **Concrete-Numpy**. You are encouraged to check individual `.py` files to learn more.
|
||||
|
||||
* Concrete
|
||||
* Numpy
|
||||
* dtypes - data type specifications
|
||||
* values - value specifications (i.e., data type + shape + encryption status)
|
||||
* representation - representation of computation
|
||||
* tracing - tracing of Python functions
|
||||
* extensions - custom functionality which is not available in NumPy (e.g., direct table lookups)
|
||||
* MLIR - MLIR conversion
|
||||
* compilation - compilation from a Python function to a circuit, client/server architecture
|
||||
* ONNX
|
||||
* convolution - custom convolution operations that follow the behavior of ONNX
|
||||
177
frontends/concrete-python/docs/getting-started/compatibility.md
Normal file
@@ -0,0 +1,177 @@
|
||||
# Compatibility
|
||||
|
||||
## Supported operations
|
||||
|
||||
Here are the operations you can use inside the function you are compiling.
|
||||
|
||||
{% hint style="info" %}
|
||||
Some of these operations are not supported between two encrypted values. A detailed error will be raised if you try to do something that is not supported.
|
||||
{% endhint %}
|
||||
|
||||
### Supported Python operators.
|
||||
|
||||
* [\_\_abs\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_abs\_\_)
|
||||
* [\_\_add\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_add\_\_)
|
||||
* [\_\_and\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_and\_\_)
|
||||
* [\_\_eq\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_eq\_\_)
|
||||
* [\_\_floordiv\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_floordiv\_\_)
|
||||
* [\_\_ge\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_ge\_\_)
|
||||
* [\_\_getitem\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_getitem\_\_)
|
||||
* [\_\_gt\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_gt\_\_)
|
||||
* [\_\_invert\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_invert\_\_)
|
||||
* [\_\_le\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_le\_\_)
|
||||
* [\_\_lshift\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_lshift\_\_)
|
||||
* [\_\_lt\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_lt\_\_)
|
||||
* [\_\_matmul\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_matmul\_\_)
|
||||
* [\_\_mod\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_mod\_\_)
|
||||
* [\_\_mul\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_mul\_\_)
|
||||
* [\_\_ne\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_ne\_\_)
|
||||
* [\_\_neg\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_neg\_\_)
|
||||
* [\_\_or\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_or\_\_)
|
||||
* [\_\_pos\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_pos\_\_)
|
||||
* [\_\_pow\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_pow\_\_)
|
||||
* [\_\_radd\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_radd\_\_)
|
||||
* [\_\_rand\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rand\_\_)
|
||||
* [\_\_rfloordiv\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rfloordiv\_\_)
|
||||
* [\_\_rlshift\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rlshift\_\_)
|
||||
* [\_\_rmatmul\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rmatmul\_\_)
|
||||
* [\_\_rmod\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rmod\_\_)
|
||||
* [\_\_rmul\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rmul\_\_)
|
||||
* [\_\_ror\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_ror\_\_)
|
||||
* [\_\_round\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_round\_\_)
|
||||
* [\_\_rpow\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rpow\_\_)
|
||||
* [\_\_rrshift\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rrshift\_\_)
|
||||
* [\_\_rshift\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rshift\_\_)
|
||||
* [\_\_rsub\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rsub\_\_)
|
||||
* [\_\_rtruediv\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rtruediv\_\_)
|
||||
* [\_\_rxor\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rxor\_\_)
|
||||
* [\_\_sub\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_sub\_\_)
|
||||
* [\_\_truediv\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_truediv\_\_)
|
||||
* [\_\_xor\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_xor\_\_)
|
||||
|
||||
### Supported NumPy functions.
|
||||
|
||||
* [np.absolute](https://numpy.org/doc/stable/reference/generated/numpy.absolute.html)
|
||||
* [np.add](https://numpy.org/doc/stable/reference/generated/numpy.add.html)
|
||||
* [np.arccos](https://numpy.org/doc/stable/reference/generated/numpy.arccos.html)
|
||||
* [np.arccosh](https://numpy.org/doc/stable/reference/generated/numpy.arccosh.html)
|
||||
* [np.arcsin](https://numpy.org/doc/stable/reference/generated/numpy.arcsin.html)
|
||||
* [np.arcsinh](https://numpy.org/doc/stable/reference/generated/numpy.arcsinh.html)
|
||||
* [np.arctan](https://numpy.org/doc/stable/reference/generated/numpy.arctan.html)
|
||||
* [np.arctan2](https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html)
|
||||
* [np.arctanh](https://numpy.org/doc/stable/reference/generated/numpy.arctanh.html)
|
||||
* [np.around](https://numpy.org/doc/stable/reference/generated/numpy.around.html)
|
||||
* [np.bitwise\_and](https://numpy.org/doc/stable/reference/generated/numpy.bitwise\_and.html)
|
||||
* [np.bitwise\_or](https://numpy.org/doc/stable/reference/generated/numpy.bitwise\_or.html)
|
||||
* [np.bitwise\_xor](https://numpy.org/doc/stable/reference/generated/numpy.bitwise\_xor.html)
|
||||
* [np.broadcast\_to](https://numpy.org/doc/stable/reference/generated/numpy.broadcast\_to.html)
|
||||
* [np.cbrt](https://numpy.org/doc/stable/reference/generated/numpy.cbrt.html)
|
||||
* [np.ceil](https://numpy.org/doc/stable/reference/generated/numpy.ceil.html)
|
||||
* [np.clip](https://numpy.org/doc/stable/reference/generated/numpy.clip.html)
|
||||
* [np.concatenate](https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html)
|
||||
* [np.copysign](https://numpy.org/doc/stable/reference/generated/numpy.copysign.html)
|
||||
* [np.cos](https://numpy.org/doc/stable/reference/generated/numpy.cos.html)
|
||||
* [np.cosh](https://numpy.org/doc/stable/reference/generated/numpy.cosh.html)
|
||||
* [np.deg2rad](https://numpy.org/doc/stable/reference/generated/numpy.deg2rad.html)
|
||||
* [np.degrees](https://numpy.org/doc/stable/reference/generated/numpy.degrees.html)
|
||||
* [np.dot](https://numpy.org/doc/stable/reference/generated/numpy.dot.html)
|
||||
* [np.equal](https://numpy.org/doc/stable/reference/generated/numpy.equal.html)
|
||||
* [np.exp](https://numpy.org/doc/stable/reference/generated/numpy.exp.html)
|
||||
* [np.exp2](https://numpy.org/doc/stable/reference/generated/numpy.exp2.html)
|
||||
* [np.expand\_dims](https://numpy.org/doc/stable/reference/generated/numpy.expand\_dims.html)
|
||||
* [np.expm1](https://numpy.org/doc/stable/reference/generated/numpy.expm1.html)
|
||||
* [np.fabs](https://numpy.org/doc/stable/reference/generated/numpy.fabs.html)
|
||||
* [np.float\_power](https://numpy.org/doc/stable/reference/generated/numpy.float\_power.html)
|
||||
* [np.floor](https://numpy.org/doc/stable/reference/generated/numpy.floor.html)
|
||||
* [np.floor\_divide](https://numpy.org/doc/stable/reference/generated/numpy.floor\_divide.html)
|
||||
* [np.fmax](https://numpy.org/doc/stable/reference/generated/numpy.fmax.html)
|
||||
* [np.fmin](https://numpy.org/doc/stable/reference/generated/numpy.fmin.html)
|
||||
* [np.fmod](https://numpy.org/doc/stable/reference/generated/numpy.fmod.html)
|
||||
* [np.gcd](https://numpy.org/doc/stable/reference/generated/numpy.gcd.html)
|
||||
* [np.greater](https://numpy.org/doc/stable/reference/generated/numpy.greater.html)
|
||||
* [np.greater\_equal](https://numpy.org/doc/stable/reference/generated/numpy.greater\_equal.html)
|
||||
* [np.heaviside](https://numpy.org/doc/stable/reference/generated/numpy.heaviside.html)
|
||||
* [np.hypot](https://numpy.org/doc/stable/reference/generated/numpy.hypot.html)
|
||||
* [np.invert](https://numpy.org/doc/stable/reference/generated/numpy.invert.html)
|
||||
* [np.isfinite](https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html)
|
||||
* [np.isinf](https://numpy.org/doc/stable/reference/generated/numpy.isinf.html)
|
||||
* [np.isnan](https://numpy.org/doc/stable/reference/generated/numpy.isnan.html)
|
||||
* [np.lcm](https://numpy.org/doc/stable/reference/generated/numpy.lcm.html)
|
||||
* [np.ldexp](https://numpy.org/doc/stable/reference/generated/numpy.ldexp.html)
|
||||
* [np.left\_shift](https://numpy.org/doc/stable/reference/generated/numpy.left\_shift.html)
|
||||
* [np.less](https://numpy.org/doc/stable/reference/generated/numpy.less.html)
|
||||
* [np.less\_equal](https://numpy.org/doc/stable/reference/generated/numpy.less\_equal.html)
|
||||
* [np.log](https://numpy.org/doc/stable/reference/generated/numpy.log.html)
|
||||
* [np.log10](https://numpy.org/doc/stable/reference/generated/numpy.log10.html)
|
||||
* [np.log1p](https://numpy.org/doc/stable/reference/generated/numpy.log1p.html)
|
||||
* [np.log2](https://numpy.org/doc/stable/reference/generated/numpy.log2.html)
|
||||
* [np.logaddexp](https://numpy.org/doc/stable/reference/generated/numpy.logaddexp.html)
|
||||
* [np.logaddexp2](https://numpy.org/doc/stable/reference/generated/numpy.logaddexp2.html)
|
||||
* [np.logical\_and](https://numpy.org/doc/stable/reference/generated/numpy.logical\_and.html)
|
||||
* [np.logical\_not](https://numpy.org/doc/stable/reference/generated/numpy.logical\_not.html)
|
||||
* [np.logical\_or](https://numpy.org/doc/stable/reference/generated/numpy.logical\_or.html)
|
||||
* [np.logical\_xor](https://numpy.org/doc/stable/reference/generated/numpy.logical\_xor.html)
|
||||
* [np.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html)
|
||||
* [np.maximum](https://numpy.org/doc/stable/reference/generated/numpy.maximum.html)
|
||||
* [np.minimum](https://numpy.org/doc/stable/reference/generated/numpy.minimum.html)
|
||||
* [np.multiply](https://numpy.org/doc/stable/reference/generated/numpy.multiply.html)
|
||||
* [np.negative](https://numpy.org/doc/stable/reference/generated/numpy.negative.html)
|
||||
* [np.nextafter](https://numpy.org/doc/stable/reference/generated/numpy.nextafter.html)
|
||||
* [np.not\_equal](https://numpy.org/doc/stable/reference/generated/numpy.not\_equal.html)
|
||||
* [np.ones\_like](https://numpy.org/doc/stable/reference/generated/numpy.ones\_like.html)
|
||||
* [np.positive](https://numpy.org/doc/stable/reference/generated/numpy.positive.html)
|
||||
* [np.power](https://numpy.org/doc/stable/reference/generated/numpy.power.html)
|
||||
* [np.rad2deg](https://numpy.org/doc/stable/reference/generated/numpy.rad2deg.html)
|
||||
* [np.radians](https://numpy.org/doc/stable/reference/generated/numpy.radians.html)
|
||||
* [np.reciprocal](https://numpy.org/doc/stable/reference/generated/numpy.reciprocal.html)
|
||||
* [np.remainder](https://numpy.org/doc/stable/reference/generated/numpy.remainder.html)
|
||||
* [np.reshape](https://numpy.org/doc/stable/reference/generated/numpy.reshape.html)
|
||||
* [np.right\_shift](https://numpy.org/doc/stable/reference/generated/numpy.right\_shift.html)
|
||||
* [np.rint](https://numpy.org/doc/stable/reference/generated/numpy.rint.html)
|
||||
* [np.round\_](https://numpy.org/doc/stable/reference/generated/numpy.round\_.html)
|
||||
* [np.sign](https://numpy.org/doc/stable/reference/generated/numpy.sign.html)
|
||||
* [np.signbit](https://numpy.org/doc/stable/reference/generated/numpy.signbit.html)
|
||||
* [np.sin](https://numpy.org/doc/stable/reference/generated/numpy.sin.html)
|
||||
* [np.sinh](https://numpy.org/doc/stable/reference/generated/numpy.sinh.html)
|
||||
* [np.spacing](https://numpy.org/doc/stable/reference/generated/numpy.spacing.html)
|
||||
* [np.sqrt](https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html)
|
||||
* [np.square](https://numpy.org/doc/stable/reference/generated/numpy.square.html)
|
||||
* [np.subtract](https://numpy.org/doc/stable/reference/generated/numpy.subtract.html)
|
||||
* [np.sum](https://numpy.org/doc/stable/reference/generated/numpy.sum.html)
|
||||
* [np.tan](https://numpy.org/doc/stable/reference/generated/numpy.tan.html)
|
||||
* [np.tanh](https://numpy.org/doc/stable/reference/generated/numpy.tanh.html)
|
||||
* [np.transpose](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html)
|
||||
* [np.true\_divide](https://numpy.org/doc/stable/reference/generated/numpy.true\_divide.html)
|
||||
* [np.trunc](https://numpy.org/doc/stable/reference/generated/numpy.trunc.html)
|
||||
* [np.where](https://numpy.org/doc/stable/reference/generated/numpy.where.html)
|
||||
* [np.zeros\_like](https://numpy.org/doc/stable/reference/generated/numpy.zeros\_like.html)
|
||||
|
||||
### Supported `ndarray` methods.
|
||||
|
||||
* [np.ndarray.astype](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.astype.html)
|
||||
* [np.ndarray.clip](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.clip.html)
|
||||
* [np.ndarray.dot](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.dot.html)
|
||||
* [np.ndarray.flatten](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flatten.html)
|
||||
* [np.ndarray.reshape](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.reshape.html)
|
||||
* [np.ndarray.transpose](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.transpose.html)
|
||||
|
||||
### Supported `ndarray` properties.
|
||||
|
||||
* [np.ndarray.shape](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.shape.html)
|
||||
* [np.ndarray.ndim](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.ndim.html)
|
||||
* [np.ndarray.size](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.size.html)
|
||||
* [np.ndarray.T](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.T.html)
|
||||
|
||||
## Limitations
|
||||
|
||||
### Control flow constraints.
|
||||
|
||||
Some Python control flow statements are not supported. For example, you cannot have an `if` statement or a `while` statement for which the condition depends on an encrypted value. However, such statements are supported with constant values (e.g., `for i in range(SOME_CONSTANT)`, `if os.environ.get("SOME_FEATURE") == "ON":`).
|
||||
|
||||
### Type constraints.
|
||||
|
||||
Another constraint is that you cannot have floating-point inputs or floating-point outputs. You can have floating-point intermediate values as long as they can be converted to an integer Table Lookup (e.g., `(60 * np.sin(x)).astype(np.int64)`).
|
||||
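For example, the following sketch compiles, because the floating-point intermediate values are fused into an integer table lookup:

```python
import concrete.numpy as cnp
import numpy as np

@cnp.compiler({"x": "encrypted"})
def f(x):
    # floating-point intermediates with an integer output: fused into a table lookup
    return (60 * np.sin(x)).astype(np.int64)

circuit = f.compile(range(2 ** 3))
```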
|
||||
### Bit width constraints.
|
||||
|
||||
There is a limit on the bit width of encrypted values. We are constantly working on increasing this bit width. If you go above the limit, you will get an error.
|
||||
27
frontends/concrete-python/docs/getting-started/exactness.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# Exactness
|
||||
|
||||
One of the most common operations in **Concrete-Numpy** is `Table Lookups` (TLUs). TLUs are performed with an FHE operation called `Programmable Bootstrapping` (PBS). PBSes have a certain probability of error, which, when triggered, results in an inaccurate output.
|
||||
|
||||
Let's say you have the table:
|
||||
|
||||
```python
|
||||
[0, 1, 4, 9, 16, 25, 36, 49, 64]
|
||||
```
|
||||
|
||||
And you perform a table lookup with input `4`. The result should be `16`, but because of the possibility of error, you can sometimes get `9` or `25`, or even `4` or `36` if the probability of error is high.
|
||||
|
||||
The probability of this error can be configured through the `p_error` and `global_p_error` configuration options. The difference between these two options is that `p_error` applies to each individual TLU, while `global_p_error` applies to the whole circuit.
|
||||
|
||||
Here is an example: if you set `p_error` to `0.01`, every TLU in the circuit will have a 1% chance of not being exact and a 99% chance of being exact. If you have a single TLU in the circuit, `global_p_error` would be 1% as well. But if you have 2 TLUs, for example, `global_p_error` would be almost 2% (`1 - (0.99 * 0.99)`).
|
||||
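More generally, assuming independent TLUs each with error probability `p` (the same assumption as the computation above), the chance that at least one of `N` lookups is inexact is `1 - (1 - p) ** N`. A quick sanity check in plain Python:

```python
p = 0.01  # per-TLU error probability
for n_tlus in [1, 2, 10]:
    global_p = 1 - (1 - p) ** n_tlus
    print(f"{n_tlus} TLU(s) -> global error probability ~ {global_p:.4f}")
# 1 TLU(s) -> ~0.0100, 2 TLU(s) -> ~0.0199, 10 TLU(s) -> ~0.0956
```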
|
||||
However, if you set `global_p_error` to `0.01`, the whole circuit will have a 1% probability of not being exact, no matter how many table lookups there are.
|
||||
|
||||
If you set both of them, both will be satisfied. Essentially, the stricter one will be used.
|
||||
|
||||
By default, both `p_error` and `global_p_error` are set to `None`, which results in a `global_p_error` of `1 / 100_000` being used. Feel free to play with these configuration options to pick the one best suited for your needs! For example, in some machine learning use cases, off-by-one or off-by-two errors don't affect the result much; in such cases, `p_error` could be increased to improve performance without losing accuracy.
|
||||
|
||||
See [How to Configure](../howto/configure.md) to learn how you can set a custom `p_error` and/or `global_p_error`.
|
||||
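As a quick sketch (the configuration guide has the authoritative list of options, and these fields are assumed to exist on `cnp.Configuration` in your version):

```python
import concrete.numpy as cnp

# hypothetical value; tune it for your accuracy/performance trade-off
configuration = cnp.Configuration(p_error=0.001)
# circuit = f.compile(inputset, configuration)
```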
|
||||
{% hint style="info" %}
|
||||
Configuring either of those variables would affect computation time (compilation, keys generation, circuit execution) and space requirements (size of the keys on disk and in memory). Lower error probability would result in longer computation time and larger space requirements.
|
||||
{% endhint %}
|
||||
46
frontends/concrete-python/docs/getting-started/installing.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# Installation
|
||||
|
||||
**Concrete-Numpy** is natively supported on Linux and macOS for Python 3.8 and 3.9. If your platform supports Docker, you can also use the Docker image to run **Concrete-Numpy**.
|
||||
|
||||
## Using PyPI
|
||||
|
||||
You can install **Concrete-Numpy** from PyPI:
|
||||
|
||||
```shell
|
||||
pip install -U pip wheel setuptools
|
||||
pip install concrete-numpy
|
||||
```
|
||||
|
||||
{% hint style="warning" %}
|
||||
Apple silicon users must use the Docker installation (explained below), as some of our dependencies don't have ARM versions for the time being.
|
||||
{% endhint %}
|
||||
|
||||
## Using Docker
|
||||
|
||||
You can also get the **Concrete-Numpy** docker image:
|
||||
|
||||
```shell
|
||||
docker pull zamafhe/concrete-numpy:v1.0.0
|
||||
```
|
||||
|
||||
### Starting a Jupyter server.
|
||||
|
||||
By default, the entry point of the **Concrete-Numpy** docker image is a Jupyter server that you can access from your browser:
|
||||
|
||||
```shell
|
||||
docker run --rm -it -p 8888:8888 zamafhe/concrete-numpy:v1.0.0
|
||||
```
|
||||
|
||||
To save notebooks on the host, you can use a local volume:
|
||||
|
||||
```shell
|
||||
docker run --rm -it -p 8888:8888 -v /path/to/notebooks:/data zamafhe/concrete-numpy:v1.0.0
|
||||
```
|
||||
|
||||
### Starting a Bash session.
|
||||
|
||||
Alternatively, you can launch a Bash session:
|
||||
|
||||
```shell
|
||||
docker run --rm -it zamafhe/concrete-numpy:latest /bin/bash
|
||||
```
|
||||
104
frontends/concrete-python/docs/getting-started/performance.md
Normal file
@@ -0,0 +1,104 @@
|
||||
# Performance
|
||||
|
||||
The most important operation in Concrete-Numpy is the table lookup. Apart from addition, subtraction, multiplication by non-encrypted values, and a few operations built from those primitives (e.g., matmul, conv), all operations are converted to table lookups under the hood:
|
||||
|
||||
```python
|
||||
import concrete.numpy as cnp
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def f(x):
|
||||
return x ** 2
|
||||
|
||||
inputset = range(2 ** 4)
|
||||
circuit = f.compile(inputset)
|
||||
```
|
||||
|
||||
is exactly the same as
|
||||
|
||||
```python
import concrete.numpy as cnp

table = cnp.LookupTable([x ** 2 for x in range(2 ** 4)])

@cnp.compiler({"x": "encrypted"})
def f(x):
    return table[x]

inputset = range(2 ** 4)
circuit = f.compile(inputset)
```

Table lookups are very flexible, and they allow Concrete-Numpy to support many operations, but they are expensive! Therefore, you should try to avoid them as much as possible. In most cases it's not possible to avoid them completely, but you can reduce the number of TLUs or replace some of them with primitive operations.

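For example, multiplication by a clear constant is one of the primitive operations listed above, so computing `3 * x` directly avoids the table lookup that an equivalent `LookupTable` would introduce (a sketch using the same API as the snippets above):

```python
import concrete.numpy as cnp

# TLU version: the multiplication is performed through a 4-bit table lookup.
table = cnp.LookupTable([3 * x for x in range(2 ** 4)])

@cnp.compiler({"x": "encrypted"})
def with_tlu(x):
    return table[x]

# Primitive version: multiplication by a clear constant needs no TLU.
@cnp.compiler({"x": "encrypted"})
def without_tlu(x):
    return 3 * x

inputset = range(2 ** 4)
tlu_circuit = with_tlu.compile(inputset)
primitive_circuit = without_tlu.compile(inputset)
```
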
The exact cost depends on many variables (machine configuration, error probability, etc.), but you can develop some intuition for single-threaded CPU execution performance using:

```python
import time

import concrete.numpy as cnp
import numpy as np

WARMUP = 3
SAMPLES = 8
BITWIDTHS = range(1, 15)
CONFIGURATION = cnp.Configuration(
    enable_unsafe_features=True,
    use_insecure_key_cache=True,
    insecure_key_cache_location=".keys",
)

timings = {}
for n in BITWIDTHS:
    # Baseline circuit: identity, to measure the cost of everything but the TLU.
    @cnp.compiler({"x": "encrypted"})
    def base(x):
        return x

    table = cnp.LookupTable([np.sqrt(x).round().astype(np.int64) for x in range(2 ** n)])

    # TLU circuit: a single n-bit table lookup.
    @cnp.compiler({"x": "encrypted"})
    def tlu(x):
        return table[x]

    # The two extreme values are enough to fix the bit width to n.
    inputset = [0, 2**n - 1]

    base_circuit = base.compile(inputset, CONFIGURATION)
    tlu_circuit = tlu.compile(inputset, CONFIGURATION)

    print()
    print(f"Generating keys for n={n}...")

    base_circuit.keygen()
    tlu_circuit.keygen()

    timings[n] = []
    for i in range(SAMPLES + WARMUP):
        sample = np.random.randint(0, 2 ** n)

        encrypted_sample = base_circuit.encrypt(sample)
        start = time.time()
        encrypted_result = base_circuit.run(encrypted_sample)
        end = time.time()
        assert base_circuit.decrypt(encrypted_result) == sample

        base_time = end - start

        encrypted_sample = tlu_circuit.encrypt(sample)
        start = time.time()
        encrypted_result = tlu_circuit.run(encrypted_sample)
        end = time.time()
        assert tlu_circuit.decrypt(encrypted_result) == np.sqrt(sample).round().astype(np.int64)

        tlu_time = end - start

        # Discard warm-up runs; keep the TLU overhead of the remaining samples.
        if i >= WARMUP:
            timings[n].append(tlu_time - base_time)
            print(f"Sample #{i - WARMUP + 1} took {timings[n][-1] * 1000:.3f}ms")

print()
for n, times in timings.items():
    print(f"{n}-bits -> {np.mean(times) * 1000:.3f}ms")
```

{% hint style="info" %}
Concrete-Numpy automatically parallelizes execution if TLUs are applied to tensors.
{% endhint %}

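For instance, applying a TLU to a tensor input gives the runtime the opportunity to process the cells in parallel (a sketch; the actual speedup depends on your machine):

```python
import concrete.numpy as cnp
import numpy as np

table = cnp.LookupTable([x ** 2 for x in range(2 ** 4)])

@cnp.compiler({"x": "encrypted"})
def f(x):
    # The lookup is applied to every cell of the tensor.
    return table[x]

# Each sample is a 3x2 tensor of 4-bit values.
inputset = [np.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)]
circuit = f.compile(inputset)
```
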
@@ -0,0 +1,90 @@
# Quick Start

To compute on encrypted data, you first need to define the function that you want to compute, then compile it into a Concrete-Numpy `Circuit`, which you can use to perform homomorphic evaluation.

Here is the full example that we will walk through:

```python
import concrete.numpy as cnp

def add(x, y):
    return x + y

compiler = cnp.Compiler(add, {"x": "encrypted", "y": "clear"})

inputset = [(2, 3), (0, 0), (1, 6), (7, 7), (7, 1)]
circuit = compiler.compile(inputset)

x = 4
y = 4

clear_evaluation = add(x, y)
homomorphic_evaluation = circuit.encrypt_run_decrypt(x, y)

print(x, "+", y, "=", clear_evaluation, "=", homomorphic_evaluation)
```

## Importing the library

Everything you need to perform homomorphic evaluation is included in a single module:

<!--pytest-codeblocks:skip-->
```python
import concrete.numpy as cnp
```

## Defining the function to compile

In this example, we will compile a simple addition function:

<!--pytest-codeblocks:skip-->
```python
def add(x, y):
    return x + y
```

## Creating a compiler

To compile the function, you need to create a `Compiler` by specifying the function to compile and the encryption status of its inputs:

<!--pytest-codeblocks:skip-->
```python
compiler = cnp.Compiler(add, {"x": "encrypted", "y": "clear"})
```

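Alternatively, as in the other examples in this documentation, the same compiler can be created with the `compiler` decorator, after which the decorated function itself exposes `compile`:

<!--pytest-codeblocks:skip-->
```python
@cnp.compiler({"x": "encrypted", "y": "clear"})
def add(x, y):
    return x + y
```
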
## Defining an inputset

An inputset is a collection representing the typical inputs to the function. It is used to determine the bit widths and shapes of the variables within the function.

It should be an iterable, yielding tuples of the same length as the number of arguments of the function being compiled:

<!--pytest-codeblocks:skip-->
```python
inputset = [(2, 3), (0, 0), (1, 6), (7, 7), (7, 1)]
```

{% hint style="warning" %}
All inputs in the inputset will be evaluated in the graph, which takes time. If you're experiencing long compilation times, consider providing a smaller inputset.
{% endhint %}

## Compiling the function

You can use the `compile` method of a `Compiler` class with an inputset to perform the compilation and get the resulting circuit back:

<!--pytest-codeblocks:skip-->
```python
circuit = compiler.compile(inputset)
```

## Performing homomorphic evaluation

You can use the `encrypt_run_decrypt` method of a `Circuit` class to perform homomorphic evaluation:

<!--pytest-codeblocks:skip-->
```python
homomorphic_evaluation = circuit.encrypt_run_decrypt(4, 4)
```

{% hint style="info" %}
`circuit.encrypt_run_decrypt(*args)` is just a convenient way to do everything at once. It is implemented as `circuit.decrypt(circuit.run(circuit.encrypt(*args)))`.
{% endhint %}

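Expanding on the hint above, the same evaluation can be performed step by step, reusing the circuit compiled earlier:

<!--pytest-codeblocks:skip-->
```python
encrypted = circuit.encrypt(4, 4)
encrypted_result = circuit.run(encrypted)
result = circuit.decrypt(encrypted_result)

assert result == add(4, 4)
```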