chore: Move to the mono repo layout

Quentin Bourgerie
2023-03-08 11:23:21 +01:00
parent 4fb476aaec
commit ce7eddc22d
201 changed files with 0 additions and 0 deletions

View File

@@ -0,0 +1 @@
**

View File

@@ -0,0 +1,15 @@
# EditorConfig is awesome: https://EditorConfig.org
# top-most EditorConfig file
root = true
# Unix-style newlines with a newline ending every file
[*]
end_of_line = lf
insert_final_newline = true
# 4 space indentation
[*.py]
charset = utf-8
indent_style = space
indent_size = 4

View File

@@ -0,0 +1 @@
root: ./docs

View File

@@ -0,0 +1,28 @@
BSD 3-Clause Clear License
Copyright © 2022 ZAMA.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the documentation and/or other
materials provided with the distribution.
3. Neither the name of ZAMA nor the names of its contributors may be used to endorse
or promote products derived from this software without specific prior written permission.
NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE.
THIS SOFTWARE IS PROVIDED BY THE ZAMA AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
ZAMA OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,313 @@
SHELL:=/bin/bash
DEV_DOCKER_IMG:=concrete-numpy-dev
DEV_DOCKERFILE:=docker/Dockerfile.dev
DEV_CONTAINER_VENV_VOLUME:=concrete-numpy-internal-venv
DEV_CONTAINER_CACHE_VOLUME:=concrete-numpy-internal-cache
SRC_DIR:=concrete
.PHONY: setup_env # Set up the environment
setup_env:
poetry run python -m pip install -U pip wheel
poetry run python -m pip install -U --force-reinstall setuptools
if [[ $$(uname) != "Linux" ]] && [[ $$(uname) != "Darwin" ]]; then \
poetry install --only dev; \
else \
poetry install; \
fi
.PHONY: sync_env # Synchronise the environment
sync_env:
if [[ $$(uname) != "Linux" ]] && [[ $$(uname) != "Darwin" ]]; then \
poetry install --remove-untracked --only dev; \
else \
poetry install --remove-untracked; \
fi
$(MAKE) setup_env
.PHONY: python_format # Apply python formatting
python_format:
poetry run env bash ./script/source_format/format_python.sh \
--dir $(SRC_DIR) --dir tests --dir script
.PHONY: check_python_format # Check python format
check_python_format:
poetry run env bash ./script/source_format/format_python.sh \
--dir $(SRC_DIR) --dir tests --dir script --check
.PHONY: check_finalize_nb # Check that notebooks are sanitized
check_finalize_nb:
poetry run python ./script/nbmake_utils/notebook_finalize.py docs --check
.PHONY: pylint # Run pylint
pylint:
$(MAKE) --keep-going pylint_src pylint_tests pylint_script
.PHONY: pylint_src # Run pylint on sources
pylint_src:
poetry run pylint --rcfile=pylintrc $(SRC_DIR)
.PHONY: pylint_tests # Run pylint on tests
pylint_tests:
@# Disable duplicate code detection (R0801) in tests
@# Disable unnecessary lambda (W0108) for tests
find ./tests/ -type f -name "*.py" | xargs poetry run pylint --disable=R0801,W0108 --rcfile=pylintrc
.PHONY: pylint_script # Run pylint on scripts
pylint_script:
find ./script/ -type f -name "*.py" | xargs poetry run pylint --rcfile=pylintrc
.PHONY: flake8 # Run flake8
flake8:
poetry run flake8 --max-line-length 100 --per-file-ignores="__init__.py:F401" \
$(SRC_DIR)/ tests/ script/
.PHONY: ruff
ruff:
poetry run ruff $(SRC_DIR)/ tests/ script/
.PHONY: python_linting # Run python linters
python_linting: pylint flake8 ruff
.PHONY: conformance # Run command to fix some conformance issues automatically
conformance: finalize_nb python_format supported_functions licenses
.PHONY: pcc # Run pre-commit checks
pcc:
@$(MAKE) --keep-going --jobs $$(./script/make_utils/ncpus.sh) --output-sync=recurse \
--no-print-directory pcc_internal
PCC_DEPS := check_python_format check_finalize_nb python_linting mypy_ci pydocstyle shell_lint
PCC_DEPS += check_supported_functions # check_licenses
# Intentionally left without a description for make help, since it is internal
.PHONY: pcc_internal
pcc_internal: $(PCC_DEPS)
# A pytest run can be reproduced by passing the --randomly-seed value
# reported by pytest-randomly
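# For example (illustrative seed value): poetry run pytest --randomly-seed=1234 tests/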
.PHONY: pytest # Run pytest
pytest:
poetry run pytest -svv \
--global-coverage=.global-coverage.json \
-n $$(./script/make_utils/ncpus.sh) \
--cov=$(SRC_DIR) --cov-fail-under=100 \
--randomly-dont-reorganize \
--cov-report=term-missing:skip-covered tests/
# Not a huge fan of ignoring missing imports, but some packages do not have typing stubs
.PHONY: mypy # Run mypy
mypy:
poetry run mypy -p $(SRC_DIR) --ignore-missing-imports
# Friendly target to run mypy without ignoring missing stubs and still have error messages
# Allows us to see which stubs we are missing
.PHONY: mypy_ns # Run mypy (without ignoring missing stubs)
mypy_ns:
poetry run mypy -p $(SRC_DIR)
.PHONY: mypy_test # Run mypy on test files
mypy_test:
find ./tests/ -name "*.py" | xargs poetry run mypy --ignore-missing-imports
.PHONY: mypy_script # Run mypy on scripts
mypy_script:
find ./script/ -name "*.py" | xargs poetry run mypy --ignore-missing-imports
# The plus indicates that make will be called by the command and allows sharing the context with
# the parent make execution. We serialize calls to these targets as they may overwrite each other's
# cache, which can cause issues.
.PHONY: mypy_ci # Run all mypy checks for CI
mypy_ci:
$(MAKE) --keep-going mypy mypy_test mypy_script
.PHONY: docker_build # Build dev docker
docker_build:
BUILD_ARGS=; \
if [[ $$(uname) == "Linux" ]]; then \
BUILD_ARGS="--build-arg BUILD_UID=$$(id -u) --build-arg BUILD_GID=$$(id -g)"; \
fi; \
DOCKER_BUILDKIT=1 docker build $${BUILD_ARGS:+$$BUILD_ARGS} \
--pull -t $(DEV_DOCKER_IMG) -f $(DEV_DOCKERFILE) .
.PHONY: docker_rebuild # Rebuild docker
docker_rebuild: docker_clean_volumes
BUILD_ARGS=; \
if [[ $$(uname) == "Linux" ]]; then \
BUILD_ARGS="--build-arg BUILD_UID=$$(id -u) --build-arg BUILD_GID=$$(id -g)"; \
fi; \
DOCKER_BUILDKIT=1 docker build $${BUILD_ARGS:+$$BUILD_ARGS} \
--pull --no-cache -t $(DEV_DOCKER_IMG) -f $(DEV_DOCKERFILE) .
.PHONY: docker_start # Launch docker
docker_start:
@# the slash before pwd is for Windows
docker run --rm -it \
-p 8888:8888 \
--env DISPLAY=host.docker.internal:0 \
--volume /"$$(pwd)":/src \
--volume $(DEV_CONTAINER_VENV_VOLUME):/home/dev_user/dev_venv \
--volume $(DEV_CONTAINER_CACHE_VOLUME):/home/dev_user/.cache \
$(DEV_DOCKER_IMG)
.PHONY: docker_build_and_start # Docker build and start
docker_build_and_start: docker_build docker_start
.PHONY: docker_bas # Docker build and start
docker_bas: docker_build_and_start
.PHONY: docker_clean_volumes # Docker clean volumes
docker_clean_volumes:
docker volume rm -f $(DEV_CONTAINER_VENV_VOLUME)
docker volume rm -f $(DEV_CONTAINER_CACHE_VOLUME)
.PHONY: docker_cv # Docker clean volumes
docker_cv: docker_clean_volumes
.PHONY: pydocstyle # Launch syntax checker on source code documentation
pydocstyle:
@# From http://www.pydocstyle.org/en/stable/error_codes.html
poetry run pydocstyle $(SRC_DIR) --convention google --add-ignore=D1,D200,D202,D212,D402,D417 --add-select=D401
.PHONY: finalize_nb # Sanitize notebooks
finalize_nb:
poetry run python ./script/nbmake_utils/notebook_finalize.py docs
# A warning in a package unrelated to the project made pytest fail with notebooks,
# so notebook tests are run without warnings; the sources are already tested with warnings treated as errors
.PHONY: pytest_nb # Launch notebook tests
pytest_nb:
find docs -name "*.ipynb" | grep -v _build | grep -v .ipynb_checkpoints | xargs poetry run pytest -Wignore --nbmake
.PHONY: jupyter # Launch jupyter notebook
jupyter:
poetry run jupyter notebook --allow-root --no-browser --ip=0.0.0.0
.PHONY: release_docker # Build a docker release image
release_docker:
./docker/build_release_image.sh
.PHONY: upgrade_py_deps # Upgrade python dependencies
upgrade_py_deps:
./script/make_utils/upgrade_deps.sh
# Keeping this target as it proved useful before we had a proper package: it allowed running code
# that pytest-codeblocks failed to execute when the package was not installed via pip.
# This is done by hand because pytest-codeblocks was failing with our native extensions.
# See the refused PR on the project here: https://github.com/nschloe/pytest-codeblocks/pull/58
# Test code blocks using a custom python script in the documentation
.PHONY: test_codeblocks
test_codeblocks:
poetry run python ./script/make_utils/test_md_python_code.py --md_dir docs/
.PHONY: pytest_codeblocks # Test code blocks using pytest in the documentation
pytest_codeblocks:
poetry run pytest --codeblocks -svv -n $$(./script/make_utils/ncpus.sh) \
--randomly-dont-reorganize docs/
# From https://stackoverflow.com/a/63523300 for the find command
.PHONY: shell_lint # Lint all bash scripts
shell_lint:
find \( -path "./.venv" -o -path "./.docker_venv" \) -prune -o -type f -name "*.sh" -print | \
xargs shellcheck
.PHONY: set_version_no_commit # Dry run for set_version
set_version_no_commit:
@if [[ "$$VERSION" == "" ]]; then \
echo "VERSION env variable is empty. Please set to desired version."; \
exit 1; \
fi && \
poetry run python ./script/make_utils/version_utils.py set-version --version "$${VERSION}"
.PHONY: set_version # Generate a new version number and update all files with it accordingly
set_version:
@if [[ "$$VERSION" == "" ]]; then \
echo "VERSION env variable is empty. Please set to desired version."; \
exit 1; \
fi && \
STASH_COUNT="$$(git stash list | wc -l)" && \
git stash && \
poetry run python ./script/make_utils/version_utils.py set-version --version "$${VERSION}" && \
git add -u && \
git commit -m "chore: bump version to $${VERSION}" && \
NEW_STASH_COUNT="$$(git stash list | wc -l)" && \
if [[ "$$NEW_STASH_COUNT" != "$$STASH_COUNT" ]]; then \
git stash pop; \
fi
.PHONY: changelog # Generate a changelog
changelog:
PROJECT_VER=($$(poetry version)) && \
PROJECT_VER="$${PROJECT_VER[1]}" && \
poetry run python ./script/make_utils/changelog_helper.py > "CHANGELOG_$${PROJECT_VER}.md"
.PHONY: release # Create a new release
release:
@PROJECT_VER=($$(poetry version)) && \
PROJECT_VER="$${PROJECT_VER[1]}" && \
TAG_NAME="v$${PROJECT_VER}" && \
git fetch --tags --force && \
git tag -s -a -m "$${TAG_NAME} release" "$${TAG_NAME}" && \
git push origin "refs/tags/$${TAG_NAME}"
.PHONY: show_scope # Show the accepted types and optional scopes (for git conventional commits)
show_scope:
@echo "Accepted types and optional scopes:"
@cat .github/workflows/continuous-integration.yaml | grep feat | grep pattern | cut -f 2- -d ":" | cut -f 2- -d " "
.PHONY: show_type # Show the accepted types and optional scopes (for git conventional commits)
show_type: show_scope
# grep recursively, ignore binary files, print file line, print file name
# exclude dot dirs, exclude pylintrc (would match the notes)
# exclude notebooks (sometimes matches in svg text), match the notes in this directory
.PHONY: todo # List all todo left in the code
todo:
@NOTES_ARGS=$$(poetry run python ./script/make_utils/get_pylintrc_notes.py \
--pylintrc-path pylintrc) && \
grep -rInH --exclude-dir='.[^.]*' --exclude=pylintrc --exclude='*.ipynb' "$${NOTES_ARGS}" .
.PHONY: supported_functions # Update docs with supported functions
supported_functions:
poetry run python script/doc_utils/gen_supported_ufuncs.py docs/getting-started/compatibility.md
.PHONY: check_supported_functions # Check supported functions (for the doc)
check_supported_functions:
poetry run python script/doc_utils/gen_supported_ufuncs.py docs/getting-started/compatibility.md --check
.PHONY: licenses # Generate the list of licenses of dependencies
licenses:
@./script/make_utils/licenses.sh
.PHONY: check_licenses # Check if the licenses of dependencies have changed
check_licenses:
@TMP_OUT="$$(mktemp)" && \
if ! poetry run env bash ./script/make_utils/licenses.sh --check > "$${TMP_OUT}"; then \
cat "$${TMP_OUT}"; \
rm -f "$${TMP_OUT}"; \
echo "Error while checking licenses, see log above."; \
echo "Consider re-running 'make licenses'"; \
exit 1; \
else \
echo "Licenses check OK"; \
fi
.PHONY: check_licenses
.PHONY: help # Generate list of targets with descriptions
help:
@grep '^.PHONY: .* #' Makefile | sed 's/\.PHONY: \(.*\) # \(.*\)/\1\t\2/' | expand -t30 | sort
.PHONY: pip_audit # Run pip-audit and check if there are known vulnerabilities in our dependencies
pip_audit:
poetry run pip-audit
.PHONY: clean_local_git # Tell the user how to delete local git branches, except main
clean_local_git:
@git fetch --all --prune
@echo "Consider doing: "
@echo
@# Don't consider deleting `main` or the current branch
@git branch | grep -v "^*" | grep -v main | xargs echo "git branch -D "
@echo

View File

@@ -0,0 +1,145 @@
<p align="center">
<!-- product name logo -->
<img width=600 src="https://user-images.githubusercontent.com/5758427/193612313-6b1124c7-8e3e-4e23-8b8c-57fd43b17d4f.png">
</p>
<p align="center">
<!-- Version badge using shields.io -->
<a href="https://github.com/zama-ai/concrete-numpy/releases">
<img src="https://img.shields.io/github/v/release/zama-ai/concrete-numpy?style=flat-square">
</a>
<!-- Link to docs badge using shields.io -->
<a href="https://docs.zama.ai/concrete-numpy/">
<img src="https://img.shields.io/badge/read-documentation-yellow?style=flat-square">
</a>
<!-- Community forum badge using shields.io -->
<a href="https://community.zama.ai/c/concrete-numpy">
<img src="https://img.shields.io/badge/community%20forum-online-brightgreen?style=flat-square">
</a>
<!-- Open source badge using shields.io -->
<a href="https://docs.zama.ai/concrete-numpy/developer/contributing">
<img src="https://img.shields.io/badge/we're%20open%20source-contributing.md-blue?style=flat-square">
</a>
<!-- Follow on twitter badge using shields.io -->
<a href="https://twitter.com/zama_fhe">
<img src="https://img.shields.io/badge/follow-zama_fhe-blue?logo=twitter&style=flat-square">
</a>
</p>
**Concrete Numpy** is an open-source library which simplifies the use of fully homomorphic encryption (FHE) in Python.
FHE is a powerful cryptographic tool, which allows computation to be performed directly on encrypted data without needing to decrypt it first.
With FHE, you can build services that preserve user privacy. FHE also helps against data breaches, since everything is computed on encrypted data: even if the server is compromised, no sensitive data is leaked.
## Main features
- Ability to compile Python functions (that may use NumPy internally) to their FHE equivalents, to operate on encrypted data
- Support for a [large collection of operators](https://docs.zama.ai/concrete-numpy/getting-started/compatibility)
- Partial support for floating points
- Support for table lookups on integers
- Support for integration with Client / Server architectures
## Installation
| OS / HW | Available on Docker | Available on PyPI |
| :----------------------------------: | :-----------------: | :--------------: |
| Linux | Yes | Yes |
| Windows | Yes | Coming soon |
| Windows Subsystem for Linux | Yes | Yes |
| macOS (Intel) | Yes | Yes |
| macOS (Apple Silicon, i.e., M1, M2, etc.) | Yes (Rosetta) | Coming soon |
The preferred way to install Concrete Numpy is through PyPI:
```shell
pip install concrete-numpy
```
Alternatively, you can pull the latest concrete-numpy Docker image:
```shell
docker pull zamafhe/concrete-numpy:v0.10.0
```
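To start an interactive shell inside the image (assuming a shell is available in it; this is only an illustrative invocation), you can run:

```shell
docker run --rm -it zamafhe/concrete-numpy:v0.10.0 /bin/bash
```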
You can find more detailed installation instructions in [installing.md](docs/getting-started/installing.md)
## Getting started
```python
import concrete.numpy as cnp
def add(x, y):
return x + y
compiler = cnp.Compiler(add, {"x": "encrypted", "y": "encrypted"})
inputset = [(2, 3), (0, 0), (1, 6), (7, 7), (7, 1), (3, 2), (6, 1), (1, 7), (4, 5), (5, 4)]
print(f"Compiling...")
circuit = compiler.compile(inputset)
print(f"Generating keys...")
circuit.keygen()
examples = [(3, 4), (1, 2), (7, 7), (0, 0)]
for example in examples:
encrypted_example = circuit.encrypt(*example)
encrypted_result = circuit.run(encrypted_example)
result = circuit.decrypt(encrypted_result)
print(f"Evaluation of {' + '.join(map(str, example))} homomorphically = {result}")
```
or if you have a simple function that you can decorate, and you don't care about explicit steps of key generation, encryption, evaluation and decryption:
```python
import concrete.numpy as cnp
@cnp.compiler({"x": "encrypted", "y": "encrypted"})
def add(x, y):
return x + y
inputset = [(2, 3), (0, 0), (1, 6), (7, 7), (7, 1), (3, 2), (6, 1), (1, 7), (4, 5), (5, 4)]
print(f"Compiling...")
circuit = add.compile(inputset)
examples = [(3, 4), (1, 2), (7, 7), (0, 0)]
for example in examples:
result = circuit.encrypt_run_decrypt(*example)
print(f"Evaluation of {' + '.join(map(str, example))} homomorphically = {result}")
```
## Documentation
Full, comprehensive documentation is available at [https://docs.zama.ai/concrete-numpy](https://docs.zama.ai/concrete-numpy).
## Target users
Concrete Numpy is a generic library that supports a variety of use cases. Because of this flexibility,
it doesn't provide primitives for specific use cases.
If you have a specific use case, or a specific field of computation, you may want to build abstractions on top of Concrete Numpy.
One such example is [Concrete ML](https://github.com/zama-ai/concrete-ml), which is built on top of Concrete Numpy to simplify Machine Learning oriented use cases.
## Tutorials
Various tutorials are provided in the documentation to help you start writing homomorphic programs:
- How to use Concrete Numpy with [Decorators](https://docs.zama.ai/concrete-numpy/tutorials/decorator)
- Partial support of [Floating Points](https://docs.zama.ai/concrete-numpy/tutorials/floating_points)
- How to perform [Table Lookup](https://docs.zama.ai/concrete-numpy/tutorials/table_lookup)
More generally, if you have built awesome projects using Concrete Numpy, feel free to let us know and we'll link to them!
## Need support?
<a target="_blank" href="https://community.zama.ai">
<img src="https://user-images.githubusercontent.com/5758427/191792238-b132e413-05f9-4fee-bee3-1371f3d81c28.png">
</a>
## License
This software is distributed under the BSD-3-Clause-Clear license. If you have any questions, please contact us at hello@zama.ai.

View File

@@ -0,0 +1,7 @@
"""
Set up the concrete module to be extended with the numpy module.
"""
# Do not modify, this is to have a compatible namespace package
# https://packaging.python.org/en/latest/guides/packaging-namespace-packages/#pkg-resources-style-namespace-packages
__import__("pkg_resources").declare_namespace(__name__) # pragma: no cover

View File

@@ -0,0 +1,166 @@
"""
Export everything that users might need.
"""
from concrete.compiler import EvaluationKeys, PublicArguments, PublicResult
from .compilation import (
DEFAULT_GLOBAL_P_ERROR,
DEFAULT_P_ERROR,
Circuit,
Client,
ClientSpecs,
Compiler,
Configuration,
DebugArtifacts,
EncryptionStatus,
Server,
)
from .compilation.decorators import circuit, compiler
from .extensions import (
AutoRounder,
LookupTable,
array,
one,
ones,
round_bit_pattern,
tag,
univariate,
zero,
zeros,
)
from .mlir.utils import MAXIMUM_TLU_BIT_WIDTH
from .representation import Graph
from .tracing.typing import (
f32,
f64,
int1,
int2,
int3,
int4,
int5,
int6,
int7,
int8,
int9,
int10,
int11,
int12,
int13,
int14,
int15,
int16,
int17,
int18,
int19,
int20,
int21,
int22,
int23,
int24,
int25,
int26,
int27,
int28,
int29,
int30,
int31,
int32,
int33,
int34,
int35,
int36,
int37,
int38,
int39,
int40,
int41,
int42,
int43,
int44,
int45,
int46,
int47,
int48,
int49,
int50,
int51,
int52,
int53,
int54,
int55,
int56,
int57,
int58,
int59,
int60,
int61,
int62,
int63,
int64,
tensor,
uint1,
uint2,
uint3,
uint4,
uint5,
uint6,
uint7,
uint8,
uint9,
uint10,
uint11,
uint12,
uint13,
uint14,
uint15,
uint16,
uint17,
uint18,
uint19,
uint20,
uint21,
uint22,
uint23,
uint24,
uint25,
uint26,
uint27,
uint28,
uint29,
uint30,
uint31,
uint32,
uint33,
uint34,
uint35,
uint36,
uint37,
uint38,
uint39,
uint40,
uint41,
uint42,
uint43,
uint44,
uint45,
uint46,
uint47,
uint48,
uint49,
uint50,
uint51,
uint52,
uint53,
uint54,
uint55,
uint56,
uint57,
uint58,
uint59,
uint60,
uint61,
uint62,
uint63,
uint64,
)

View File

@@ -0,0 +1,11 @@
"""
Glue the compilation process together.
"""
from .artifacts import DebugArtifacts
from .circuit import Circuit
from .client import Client
from .compiler import Compiler, EncryptionStatus
from .configuration import DEFAULT_GLOBAL_P_ERROR, DEFAULT_P_ERROR, Configuration
from .server import Server
from .specs import ClientSpecs

View File

@@ -0,0 +1,189 @@
"""
Declaration of `DebugArtifacts` class.
"""
import inspect
import platform
import shutil
import subprocess
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from ..representation import Graph
DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts")
class DebugArtifacts:
"""
DebugArtifacts class, to export information about the compilation process.
"""
output_directory: Path
source_code: Optional[str]
parameter_encryption_statuses: Dict[str, str]
textual_representations_of_graphs: Dict[str, List[str]]
final_graph: Optional[Graph]
mlir_to_compile: Optional[str]
client_parameters: Optional[bytes]
def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY):
self.output_directory = Path(output_directory)
self.source_code = None
self.parameter_encryption_statuses = {}
self.textual_representations_of_graphs = {}
self.final_graph = None
self.mlir_to_compile = None
self.client_parameters = None
def add_source_code(self, function: Union[str, Callable]):
"""
Add source code of the function being compiled.
Args:
function (Union[str, Callable]):
either the source code of the function or the function itself
"""
try:
self.source_code = (
function if isinstance(function, str) else inspect.getsource(function)
)
except OSError: # pragma: no cover
self.source_code = "unavailable"
def add_parameter_encryption_status(self, name: str, encryption_status: str):
"""
Add parameter encryption status of a parameter of the function being compiled.
Args:
name (str):
name of the parameter
encryption_status (str):
encryption status of the parameter
"""
self.parameter_encryption_statuses[name] = encryption_status
def add_graph(self, name: str, graph: Graph):
"""
Add a representation of the function being compiled.
Args:
name (str):
name of the graph (e.g., initial, optimized, final)
graph (Graph):
a representation of the function being compiled
"""
if name not in self.textual_representations_of_graphs:
self.textual_representations_of_graphs[name] = []
textual_representation = graph.format()
self.textual_representations_of_graphs[name].append(textual_representation)
self.final_graph = graph
def add_mlir_to_compile(self, mlir: str):
"""
Add textual representation of the resulting MLIR.
Args:
mlir (str):
textual representation of the resulting MLIR
"""
self.mlir_to_compile = mlir
def add_client_parameters(self, client_parameters: bytes):
"""
Add client parameters used.
Args:
client_parameters (bytes): client parameters
"""
self.client_parameters = client_parameters
def export(self):
"""
Export the collected information to `self.output_directory`.
"""
# pylint: disable=too-many-branches
output_directory = self.output_directory
if output_directory.exists():
shutil.rmtree(output_directory)
output_directory.mkdir(parents=True)
with open(output_directory.joinpath("environment.txt"), "w", encoding="utf-8") as f:
f.write(f"{platform.platform()} {platform.version()}\n")
f.write(f"Python {platform.python_version()}\n")
with open(output_directory.joinpath("requirements.txt"), "w", encoding="utf-8") as f:
# example `pip list` output
# Package Version
# ----------------------------- ---------
# alabaster 0.7.12
# appdirs 1.4.4
# ... ...
# ... ...
# wrapt 1.12.1
# zipp 3.5.0
pip_process = subprocess.run(
["pip", "--disable-pip-version-check", "list"], stdout=subprocess.PIPE, check=True
)
dependencies = iter(pip_process.stdout.decode("utf-8").split("\n"))
# skip 'Package ... Version' line
next(dependencies)
# skip '------- ... -------' line
next(dependencies)
for dependency in dependencies:
tokens = [token for token in dependency.split(" ") if token != ""] # noqa: S105
if len(tokens) == 0:
continue
name = tokens[0]
version = tokens[1]
f.write(f"{name}=={version}\n")
if self.source_code is not None:
with open(output_directory.joinpath("function.txt"), "w", encoding="utf-8") as f:
f.write(self.source_code)
if len(self.parameter_encryption_statuses) > 0:
with open(output_directory.joinpath("parameters.txt"), "w", encoding="utf-8") as f:
for name, parameter in self.parameter_encryption_statuses.items():
f.write(f"{name} :: {parameter}\n")
identifier = 0
textual_representations = self.textual_representations_of_graphs.items()
for name, representations in textual_representations:
for representation in representations:
identifier += 1
output_path = output_directory.joinpath(f"{identifier}.{name}.graph.txt")
with open(output_path, "w", encoding="utf-8") as f:
f.write(f"{representation}\n")
if self.mlir_to_compile is not None:
assert self.final_graph is not None
with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f:
f.write(f"{self.mlir_to_compile}\n")
if self.client_parameters is not None:
with open(output_directory.joinpath("client_parameters.json"), "wb") as f:
f.write(self.client_parameters)
# pylint: enable=too-many-branches

View File

@@ -0,0 +1,218 @@
"""
Declaration of `Circuit` class.
"""
from typing import Any, Optional, Tuple, Union, cast
import numpy as np
from concrete.compiler import PublicArguments, PublicResult
from ..dtypes import Integer
from ..internal.utils import assert_that
from ..representation import Graph
from .client import Client
from .configuration import Configuration
from .server import Server
class Circuit:
"""
Circuit class, to combine computation graph, mlir, client and server into a single object.
"""
configuration: Configuration
graph: Graph
mlir: str
client: Client
server: Server
def __init__(self, graph: Graph, mlir: str, configuration: Optional[Configuration] = None):
self.configuration = configuration if configuration is not None else Configuration()
self.graph = graph
self.mlir = mlir
self._initialize_client_and_server()
def _initialize_client_and_server(self):
input_signs = []
for i in range(len(self.graph.input_nodes)): # pylint: disable=consider-using-enumerate
input_value = self.graph.input_nodes[i].output
assert_that(isinstance(input_value.dtype, Integer))
input_dtype = cast(Integer, input_value.dtype)
input_signs.append(input_dtype.is_signed)
output_signs = []
for i in range(len(self.graph.output_nodes)): # pylint: disable=consider-using-enumerate
output_value = self.graph.output_nodes[i].output
assert_that(isinstance(output_value.dtype, Integer))
output_dtype = cast(Integer, output_value.dtype)
output_signs.append(output_dtype.is_signed)
self.server = Server.create(self.mlir, input_signs, output_signs, self.configuration)
keyset_cache_directory = None
if self.configuration.use_insecure_key_cache:
assert_that(self.configuration.enable_unsafe_features)
assert_that(self.configuration.insecure_key_cache_location is not None)
keyset_cache_directory = self.configuration.insecure_key_cache_location
self.client = Client(self.server.client_specs, keyset_cache_directory)
def __str__(self):
return self.graph.format()
def simulate(self, *args: Any) -> Any:
"""
Simulate execution of the circuit.
Args:
*args (Any):
inputs to the circuit
Returns:
Any:
result of the simulation
"""
return self.graph(*args, p_error=self.p_error)
def keygen(self, force: bool = False):
"""
Generate keys required for homomorphic evaluation.
Args:
force (bool, default = False):
whether to generate new keys even if keys are already generated
"""
self.client.keygen(force)
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
"""
Prepare inputs to be run on the circuit.
Args:
*args (Union[int, numpy.ndarray]):
inputs to the circuit
Returns:
PublicArguments:
encrypted and plain arguments as well as public keys
"""
return self.client.encrypt(*args)
def run(self, args: PublicArguments) -> PublicResult:
"""
Evaluate circuit using encrypted arguments.
Args:
args (PublicArguments):
arguments to the circuit (can be obtained with `encrypt` method of `Circuit`)
Returns:
PublicResult:
encrypted result of homomorphic evaluation
"""
self.keygen(force=False)
return self.server.run(args, self.client.evaluation_keys)
def decrypt(
self,
result: PublicResult,
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
"""
Decrypt result of homomorphic evaluation.
Args:
result (PublicResult):
encrypted result of homomorphic evaluation
Returns:
Union[int, numpy.ndarray]:
clear result of homomorphic evaluation
"""
return self.client.decrypt(result)
def encrypt_run_decrypt(self, *args: Any) -> Any:
"""
Encrypt inputs, run the circuit, and decrypt the outputs in one go.
Args:
*args (Union[int, numpy.ndarray]):
inputs to the circuit
Returns:
Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
clear result of homomorphic evaluation
"""
return self.decrypt(self.run(self.encrypt(*args)))
def cleanup(self):
"""
Cleanup the temporary library output directory.
"""
self.server.cleanup()
@property
def complexity(self) -> float:
"""
Get complexity of the circuit.
"""
return self.server.complexity
@property
def size_of_secret_keys(self) -> int:
"""
Get size of the secret keys of the circuit.
"""
return self.server.size_of_secret_keys
@property
def size_of_bootstrap_keys(self) -> int:
"""
Get size of the bootstrap keys of the circuit.
"""
return self.server.size_of_bootstrap_keys
@property
def size_of_keyswitch_keys(self) -> int:
"""
Get size of the key switch keys of the circuit.
"""
return self.server.size_of_keyswitch_keys
@property
def size_of_inputs(self) -> int:
"""
Get size of the inputs of the circuit.
"""
return self.server.size_of_inputs
@property
def size_of_outputs(self) -> int:
"""
Get size of the outputs of the circuit.
"""
return self.server.size_of_outputs
@property
def p_error(self) -> float:
"""
Get probability of error for each simple TLU (on a scalar).
"""
return self.server.p_error
@property
def global_p_error(self) -> float:
"""
Get the probability of having at least one simple TLU error during the entire execution.
"""
return self.server.global_p_error

View File

@@ -0,0 +1,271 @@
"""
Declaration of `Client` class.
"""
import json
import shutil
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from concrete.compiler import (
ClientSupport,
EvaluationKeys,
KeySet,
KeySetCache,
PublicArguments,
PublicResult,
)
from ..dtypes.integer import SignedInteger, UnsignedInteger
from ..internal.utils import assert_that
from ..values.value import Value
from .specs import ClientSpecs
class Client:
"""
Client class, which can be used to manage keys, encrypt arguments and decrypt results.
"""
specs: ClientSpecs
_keyset: Optional[KeySet]
_keyset_cache: Optional[KeySetCache]
def __init__(
self,
client_specs: ClientSpecs,
keyset_cache_directory: Optional[Union[str, Path]] = None,
):
self.specs = client_specs
self._keyset = None
self._keyset_cache = None
if keyset_cache_directory is not None:
self._keyset_cache = KeySetCache.new(str(keyset_cache_directory))
def save(self, path: Union[str, Path]):
"""
Save the client into the given path in zip format.
Args:
path (Union[str, Path]):
path to save the client
"""
with tempfile.TemporaryDirectory() as tmp_dir:
with open(Path(tmp_dir) / "client.specs.json", "w", encoding="utf-8") as f:
f.write(self.specs.serialize())
path = str(path)
if path.endswith(".zip"):
path = path[: len(path) - 4]
shutil.make_archive(path, "zip", tmp_dir)
@staticmethod
def load(
path: Union[str, Path],
keyset_cache_directory: Optional[Union[str, Path]] = None,
) -> "Client":
"""
Load the client from the given path in zip format.
Args:
path (Union[str, Path]):
path to load the client from
keyset_cache_directory (Optional[Union[str, Path]], default = None):
keyset cache directory to use
Returns:
Client:
client loaded from the filesystem
"""
with tempfile.TemporaryDirectory() as tmp_dir:
shutil.unpack_archive(path, tmp_dir, "zip")
with open(Path(tmp_dir) / "client.specs.json", "r", encoding="utf-8") as f:
client_specs = ClientSpecs.unserialize(f.read())
return Client(client_specs, keyset_cache_directory)
def keygen(self, force: bool = False):
"""
Generate keys required for homomorphic evaluation.
Args:
force (bool, default = False):
whether to generate new keys even if keys are already generated
"""
if self._keyset is None or force:
self._keyset = ClientSupport.key_set(self.specs.client_parameters, self._keyset_cache)
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
"""
Prepare inputs to be run on the circuit.
Args:
*args (Union[int, numpy.ndarray]):
inputs to the circuit
Returns:
PublicArguments:
encrypted and plain arguments as well as public keys
"""
client_parameters_json = json.loads(self.specs.client_parameters.serialize())
assert_that("inputs" in client_parameters_json)
input_specs = client_parameters_json["inputs"]
if len(args) != len(input_specs):
message = f"Expected {len(input_specs)} inputs but got {len(args)}"
raise ValueError(message)
sanitized_args: Dict[int, Union[int, np.ndarray]] = {}
for index, spec in enumerate(input_specs):
arg = args[index]
if isinstance(arg, list):
arg = np.array(arg)
is_valid = isinstance(arg, (int, np.integer)) or (
isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer)
)
width = spec["shape"]["width"]
shape = tuple(spec["shape"]["dimensions"])
is_encrypted = spec["encryption"] is not None
expected_dtype = (
SignedInteger(width) if self.specs.input_signs[index] else UnsignedInteger(width)
)
expected_value = Value(expected_dtype, shape, is_encrypted)
if is_valid:
expected_min = expected_dtype.min()
expected_max = expected_dtype.max()
actual_min = arg if isinstance(arg, int) else arg.min()
actual_max = arg if isinstance(arg, int) else arg.max()
actual_shape = () if isinstance(arg, int) else arg.shape
is_valid = (
actual_min >= expected_min
and actual_max <= expected_max
and actual_shape == expected_value.shape
)
if is_valid:
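# signed inputs are offset by 2 ** (width - 1) below so that they are encrypted as non-negative values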
is_signed = self.specs.input_signs[index]
sanitizer = 0 if not is_signed else 2 ** (width - 1)
if isinstance(arg, int):
sanitized_args[index] = arg + sanitizer
else:
sanitized_args[index] = (arg + sanitizer).astype(np.uint64)
if not is_valid:
actual_value = Value.of(arg, is_encrypted=is_encrypted)
message = (
f"Expected argument {index} to be {expected_value} but it's {actual_value}"
)
raise ValueError(message)
self.keygen(force=False)
return ClientSupport.encrypt_arguments(
self.specs.client_parameters,
self._keyset,
[sanitized_args[i] for i in range(len(sanitized_args))],
)
def decrypt(
self,
result: PublicResult,
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
"""
Decrypt result of homomorphic evaluation.
Args:
result (PublicResult):
encrypted result of homomorphic evaluation
Returns:
Union[int, numpy.ndarray]:
clear result of homomorphic evaluation
"""
self.keygen(force=False)
outputs = ClientSupport.decrypt_result(self.specs.client_parameters, self._keyset, result)
if not isinstance(outputs, tuple):
outputs = (outputs,)
sanitized_outputs: List[Union[int, np.ndarray]] = []
client_parameters_json = json.loads(self.specs.client_parameters.serialize())
assert_that("outputs" in client_parameters_json)
output_specs = client_parameters_json["outputs"]
for index, output in enumerate(outputs):
is_signed = self.specs.output_signs[index]
crt_decomposition = (
output_specs[index].get("encryption", {}).get("encoding", {}).get("crt", [])
)
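# for signed outputs, map the decrypted values back from the unsigned (or CRT) representation to the signed range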
if is_signed:
if crt_decomposition:
if isinstance(output, int):
sanitized_output = (
output
if output < (int(np.prod(crt_decomposition)) // 2)
else -int(np.prod(crt_decomposition)) + output
)
else:
output = output.astype(np.longlong) # to prevent overflows in numpy
sanitized_output = np.where(
output < (np.prod(crt_decomposition) // 2),
output,
-np.prod(crt_decomposition) + output,
).astype(
np.int64
) # type: ignore
sanitized_outputs.append(sanitized_output)
else:
n = output_specs[index]["shape"]["width"]
output %= 2**n
if isinstance(output, int):
sanitized_output = output if output < (2 ** (n - 1)) else output - (2**n)
sanitized_outputs.append(sanitized_output)
else:
output = output.astype(np.longlong) # to prevent overflows in numpy
sanitized_output = np.where(
output < (2 ** (n - 1)), output, output - (2**n)
).astype(
np.int64
) # type: ignore
sanitized_outputs.append(sanitized_output)
else:
sanitized_outputs.append(
output if isinstance(output, int) else output.astype(np.uint64)
)
return sanitized_outputs[0] if len(sanitized_outputs) == 1 else tuple(sanitized_outputs)
@property
def evaluation_keys(self) -> EvaluationKeys:
"""
Get evaluation keys for encrypted computation.
Returns:
EvaluationKeys
evaluation keys for encrypted computation
"""
self.keygen(force=False)
assert self._keyset is not None
return self._keyset.get_evaluation_keys()

View File

@@ -0,0 +1,550 @@
"""
Declaration of `Compiler` class.
"""
import inspect
import os
import traceback
from copy import deepcopy
from enum import Enum, unique
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from ..extensions import AutoRounder
from ..mlir import GraphConverter
from ..representation import Graph
from ..tracing import Tracer
from ..values import Value
from .artifacts import DebugArtifacts
from .circuit import Circuit
from .configuration import Configuration
from .utils import fuse
@unique
class EncryptionStatus(str, Enum):
"""
EncryptionStatus enum, to represent encryption status of parameters.
"""
CLEAR = "clear"
ENCRYPTED = "encrypted"
class Compiler:
"""
Compiler class, to glue the compilation pipeline.
"""
function: Callable
parameter_encryption_statuses: Dict[str, EncryptionStatus]
configuration: Configuration
artifacts: Optional[DebugArtifacts]
inputset: List[Any]
graph: Optional[Graph]
_is_direct: bool
_parameter_values: Dict[str, Value]
@staticmethod
def assemble(
function: Callable,
parameter_values: Dict[str, Value],
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
**kwargs,
) -> Circuit:
"""
Assemble a circuit from the raw parameter values, used in direct circuit definition.
Args:
function (Callable):
function to convert to a circuit
parameter_values (Dict[str, Value]):
parameter values of the function
configuration(Optional[Configuration], default = None):
configuration to use
artifacts (Optional[DebugArtifacts], default = None):
artifacts to store information about the process
kwargs (Dict[str, Any]):
configuration options to overwrite
Returns:
Circuit:
assembled circuit
"""
compiler = Compiler(
function,
{
name: "encrypted" if value.is_encrypted else "clear"
for name, value in parameter_values.items()
},
)
# pylint: disable=protected-access
compiler._is_direct = True
compiler._parameter_values = parameter_values
# pylint: enable=protected-access
return compiler.compile(None, configuration, artifacts, **kwargs)
def __init__(
self,
function: Callable,
parameter_encryption_statuses: Dict[str, Union[str, EncryptionStatus]],
):
signature = inspect.signature(function)
missing_args = list(signature.parameters)
for arg in parameter_encryption_statuses.keys():
if arg in signature.parameters:
missing_args.remove(arg)
if len(missing_args) != 0:
parameter_str = repr(missing_args[0])
for arg in missing_args[1:-1]:
parameter_str += f", {repr(arg)}"
if len(missing_args) != 1:
parameter_str += f" and {repr(missing_args[-1])}"
message = (
f"Encryption status{'es' if len(missing_args) > 1 else ''} "
f"of parameter{'s' if len(missing_args) > 1 else ''} "
f"{parameter_str} of function '{function.__name__}' "
f"{'are' if len(missing_args) > 1 else 'is'} not provided"
)
raise ValueError(message)
additional_args = list(parameter_encryption_statuses)
for arg in signature.parameters.keys():
if arg in parameter_encryption_statuses:
additional_args.remove(arg)
if len(additional_args) != 0:
parameter_str = repr(additional_args[0])
for arg in additional_args[1:-1]:
parameter_str += f", {repr(arg)}"
if len(additional_args) != 1:
parameter_str += f" and {repr(additional_args[-1])}"
message = (
f"Encryption status{'es' if len(additional_args) > 1 else ''} "
f"of {parameter_str} {'are' if len(additional_args) > 1 else 'is'} provided but "
f"{'they are' if len(additional_args) > 1 else 'it is'} not a parameter "
f"of function '{function.__name__}'"
)
raise ValueError(message)
self.function = function # type: ignore
self.parameter_encryption_statuses = {
param: EncryptionStatus(status.lower())
for param, status in parameter_encryption_statuses.items()
}
self.configuration = Configuration()
self.artifacts = None
self.inputset = []
self.graph = None
self._is_direct = False
self._parameter_values = {}
def __call__(
self,
*args: Any,
**kwargs: Any,
) -> Union[
np.bool_,
np.integer,
np.floating,
np.ndarray,
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
]:
if len(kwargs) != 0:
message = f"Calling function '{self.function.__name__}' with kwargs is not supported"
raise RuntimeError(message)
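# calling the compiler traces the function on first use and records each sample into the accumulated inputset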
sample = args[0] if len(args) == 1 else args
if self.graph is None:
self._trace(sample)
assert self.graph is not None
self.inputset.append(sample)
return self.graph(*args)
def _trace(self, sample: Union[Any, Tuple[Any, ...]]):
"""
Trace the function with a sample input and fuse the resulting graph.
Args:
sample (Union[Any, Tuple[Any, ...]]):
sample to use for tracing
"""
if self.artifacts is not None:
self.artifacts.add_source_code(self.function)
for param, encryption_status in self.parameter_encryption_statuses.items():
self.artifacts.add_parameter_encryption_status(param, encryption_status)
parameters = {
param: Value.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED))
for arg, (param, status) in zip(
sample if len(self.parameter_encryption_statuses) > 1 else (sample,),
self.parameter_encryption_statuses.items(),
)
}
self.graph = Tracer.trace(self.function, parameters)
if self.artifacts is not None:
self.artifacts.add_graph("initial", self.graph)
fuse(self.graph, self.artifacts)
def _evaluate(
self,
action: str,
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]],
):
"""
Trace, fuse, measure bounds, and update values in the resulting graph in one go.
Args:
action (str):
action being performed (e.g., "trace", "compile")
inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
optional inputset to extend accumulated inputset before bounds measurement
"""
if self._is_direct:
self.graph = Tracer.trace(self.function, self._parameter_values, is_direct=True)
if self.artifacts is not None:
self.artifacts.add_graph("initial", self.graph) # pragma: no cover
fuse(self.graph, self.artifacts)
if self.artifacts is not None:
self.artifacts.add_graph("final", self.graph) # pragma: no cover
return
if inputset is not None:
previous_inputset_length = len(self.inputset)
for index, sample in enumerate(iter(inputset)):
self.inputset.append(sample)
if not isinstance(sample, tuple):
sample = (sample,)
if len(sample) != len(self.parameter_encryption_statuses):
self.inputset = self.inputset[:previous_inputset_length]
expected = (
"a single value"
if len(self.parameter_encryption_statuses) == 1
else f"a tuple of {len(self.parameter_encryption_statuses)} values"
)
actual = (
"a single value" if len(sample) == 1 else f"a tuple of {len(sample)} values"
)
message = (
f"Input #{index} of your inputset is not well formed "
f"(expected {expected} got {actual})"
)
raise ValueError(message)
if self.configuration.auto_adjust_rounders:
AutoRounder.adjust(self.function, self.inputset)
if self.graph is None:
try:
first_sample = next(iter(self.inputset))
except StopIteration as error:
message = (
f"{action} function '{self.function.__name__}' "
f"without an inputset is not supported"
)
raise RuntimeError(message) from error
self._trace(first_sample)
assert self.graph is not None
bounds = self.graph.measure_bounds(self.inputset)
self.graph.update_with_bounds(bounds)
if self.artifacts is not None:
self.artifacts.add_graph("final", self.graph)
def trace(
self,
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
**kwargs,
) -> Graph:
"""
Trace the function using an inputset.
Args:
inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
optional inputset to extend accumulated inputset before bounds measurement
configuration(Optional[Configuration], default = None):
configuration to use
artifacts (Optional[DebugArtifacts], default = None):
artifacts to store information about the process
kwargs (Dict[str, Any]):
configuration options to overwrite
Returns:
Graph:
computation graph representing the function prior to MLIR conversion
"""
old_configuration = deepcopy(self.configuration)
old_artifacts = deepcopy(self.artifacts)
if configuration is not None:
self.configuration = configuration
if len(kwargs) != 0:
self.configuration = self.configuration.fork(**kwargs)
self.artifacts = (
artifacts
if artifacts is not None
else DebugArtifacts()
if self.configuration.dump_artifacts_on_unexpected_failures
else None
)
try:
self._evaluate("Tracing", inputset)
assert self.graph is not None
if self.configuration.verbose or self.configuration.show_graph:
graph = self.graph.format()
longest_line = max([len(line) for line in graph.split("\n")])
try: # pragma: no cover
# this branch cannot be covered
# because `os.get_terminal_size()`
# raises an exception during tests
columns, _ = os.get_terminal_size()
if columns == 0:
columns = min(longest_line, 80)
else:
columns = min(longest_line, columns)
except OSError: # pragma: no cover
columns = min(longest_line, 80)
print()
print("Computation Graph")
print("-" * columns)
print(graph)
print("-" * columns)
print()
return self.graph
except Exception: # pragma: no cover
# this branch is reserved for unexpected issues and hence it shouldn't be tested
# if it could be tested, we would have fixed the underlying issue
# if the user desires so,
# we need to export all the information we have about the compilation
if self.configuration.dump_artifacts_on_unexpected_failures:
assert self.artifacts is not None
self.artifacts.export()
traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
with open(traceback_path, "w", encoding="utf-8") as f:
f.write(traceback.format_exc())
raise
finally:
self.configuration = old_configuration
self.artifacts = old_artifacts
# pylint: disable=too-many-branches,too-many-statements
def compile(
self,
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
**kwargs,
) -> Circuit:
"""
Compile the function using an inputset.
Args:
inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
optional inputset to extend accumulated inputset before bounds measurement
configuration(Optional[Configuration], default = None):
configuration to use
artifacts (Optional[DebugArtifacts], default = None):
artifacts to store information about the process
kwargs (Dict[str, Any]):
configuration options to overwrite
Returns:
Circuit:
compiled circuit
"""
old_configuration = deepcopy(self.configuration)
old_artifacts = deepcopy(self.artifacts)
if configuration is not None:
self.configuration = configuration
if len(kwargs) != 0:
self.configuration = self.configuration.fork(**kwargs)
self.artifacts = (
artifacts
if artifacts is not None
else DebugArtifacts()
if self.configuration.dump_artifacts_on_unexpected_failures
else None
)
try:
self._evaluate("Compiling", inputset)
assert self.graph is not None
mlir = GraphConverter.convert(self.graph)
if self.artifacts is not None:
self.artifacts.add_mlir_to_compile(mlir)
show_graph = (
self.configuration.show_graph
if self.configuration.show_graph is not None
else self.configuration.verbose
)
show_mlir = (
self.configuration.show_mlir
if self.configuration.show_mlir is not None
else self.configuration.verbose
)
show_optimizer = (
self.configuration.show_optimizer
if self.configuration.show_optimizer is not None
else self.configuration.verbose
)
columns = 0
if show_graph or show_mlir or show_optimizer:
graph = (
self.graph.format()
if self.configuration.verbose or self.configuration.show_graph
else ""
)
longest_graph_line = max([len(line) for line in graph.split("\n")])
longest_mlir_line = max([len(line) for line in mlir.split("\n")])
longest_line = max(longest_graph_line, longest_mlir_line)
try: # pragma: no cover
# this branch cannot be covered
# because `os.get_terminal_size()`
# raises an exception during tests
columns, _ = os.get_terminal_size()
if columns == 0:
columns = min(longest_line, 80)
else:
columns = min(longest_line, columns)
except OSError: # pragma: no cover
columns = min(longest_line, 80)
if show_graph:
print()
print("Computation Graph")
print("-" * columns)
print(graph)
print("-" * columns)
print()
if show_mlir:
print("\n" if not show_graph else "", end="")
print("MLIR")
print("-" * columns)
print(mlir)
print("-" * columns)
print()
if show_optimizer:
print("\n" if not (show_graph or show_mlir) else "", end="")
print("Optimizer")
print("-" * columns)
circuit = Circuit(self.graph, mlir, self.configuration)
client_parameters = circuit.client.specs.client_parameters
if self.artifacts is not None:
self.artifacts.add_client_parameters(client_parameters.serialize())
if show_optimizer:
print("-" * columns)
print()
return circuit
except Exception: # pragma: no cover
# this branch is reserved for unexpected issues and hence it shouldn't be tested
# if it could be tested, we would have fixed the underlying issue
# if the user desires so,
# we need to export all the information we have about the compilation
if self.configuration.dump_artifacts_on_unexpected_failures:
assert self.artifacts is not None
self.artifacts.export()
traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
with open(traceback_path, "w", encoding="utf-8") as f:
f.write(traceback.format_exc())
raise
finally:
self.configuration = old_configuration
self.artifacts = old_artifacts
# pylint: enable=too-many-branches,too-many-statements

View File

@@ -0,0 +1,162 @@
"""
Declaration of `Configuration` class.
"""
from copy import deepcopy
from pathlib import Path
from typing import Optional, Union, get_type_hints
DEFAULT_P_ERROR = None
DEFAULT_GLOBAL_P_ERROR = 1 / 100_000
class Configuration:
"""
Configuration class, to allow the compilation process to be customized.
"""
# pylint: disable=too-many-instance-attributes
verbose: bool
show_graph: Optional[bool]
show_mlir: Optional[bool]
show_optimizer: Optional[bool]
dump_artifacts_on_unexpected_failures: bool
enable_unsafe_features: bool
use_insecure_key_cache: bool
loop_parallelize: bool
dataflow_parallelize: bool
auto_parallelize: bool
jit: bool
p_error: Optional[float]
global_p_error: Optional[float]
insecure_key_cache_location: Optional[str]
auto_adjust_rounders: bool
# pylint: enable=too-many-instance-attributes
def _validate(self):
"""
Validate configuration.
"""
if not self.enable_unsafe_features:
if self.use_insecure_key_cache:
message = "Insecure key cache cannot be used without enabling unsafe features"
raise RuntimeError(message)
if self.use_insecure_key_cache and self.insecure_key_cache_location is None:
message = "Insecure key cache cannot be enabled without specifying its location"
raise RuntimeError(message)
# pylint: disable=too-many-arguments
def __init__(
self,
verbose: bool = False,
show_graph: Optional[bool] = None,
show_mlir: Optional[bool] = None,
show_optimizer: Optional[bool] = None,
dump_artifacts_on_unexpected_failures: bool = True,
enable_unsafe_features: bool = False,
use_insecure_key_cache: bool = False,
insecure_key_cache_location: Optional[Union[Path, str]] = None,
loop_parallelize: bool = True,
dataflow_parallelize: bool = True,
auto_parallelize: bool = False,
jit: bool = False,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
auto_adjust_rounders: bool = False,
):
self.verbose = verbose
self.show_graph = show_graph
self.show_mlir = show_mlir
self.show_optimizer = show_optimizer
self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures
self.enable_unsafe_features = enable_unsafe_features
self.use_insecure_key_cache = use_insecure_key_cache
self.insecure_key_cache_location = (
str(insecure_key_cache_location) if insecure_key_cache_location is not None else None
)
self.loop_parallelize = loop_parallelize
self.dataflow_parallelize = dataflow_parallelize
self.auto_parallelize = auto_parallelize
self.jit = jit
self.p_error = p_error
self.global_p_error = global_p_error
self.auto_adjust_rounders = auto_adjust_rounders
self._validate()
# pylint: enable=too-many-arguments
def fork(self, **kwargs) -> "Configuration":
"""
Get a new configuration from another one with the specified changes.
Args:
**kwargs:
changes to make
Returns:
Configuration:
configuration that is forked from self and updated using kwargs
"""
# pylint: disable=too-many-branches
result = deepcopy(self)
hints = get_type_hints(Configuration)
for name, value in kwargs.items():
if name not in hints:
message = f"Unexpected keyword argument '{name}'"
raise TypeError(message)
hint = hints[name]
expected = None
is_correctly_typed = True
if name == "insecure_key_cache_location":
if not (value is None or isinstance(value, str)):
is_correctly_typed = False
expected = "Optional[str]"
elif name == "p_error":
if not (value is None or isinstance(value, float)):
is_correctly_typed = False
expected = "Optional[float]"
elif name == "global_p_error":
if not (value is None or isinstance(value, float)):
is_correctly_typed = False
expected = "Optional[float]"
elif name in ["show_graph", "show_mlir", "show_optimizer"]:
if not (value is None or isinstance(value, bool)):
is_correctly_typed = False
expected = "Optional[bool]"
elif not isinstance(value, hint): # type: ignore
is_correctly_typed = False
if not is_correctly_typed:
if expected is None:
expected = hint.__name__ if hasattr(hint, "__name__") else str(hint)
message = (
f"Unexpected type for keyword argument '{name}' "
f"(expected '{expected}', got '{type(value).__name__}')"
)
raise TypeError(message)
setattr(result, name, value)
# pylint: disable=protected-access
result._validate()
# pylint: enable=protected-access
return result
# pylint: enable=too-many-branches

View File

@@ -0,0 +1,163 @@
"""
Declaration of `circuit` and `compiler` decorators.
"""
import inspect
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
from ..representation import Graph
from ..tracing.typing import ScalarAnnotation
from ..values import Value
from .artifacts import DebugArtifacts
from .circuit import Circuit
from .compiler import Compiler, EncryptionStatus
from .configuration import Configuration
def circuit(
parameters: Mapping[str, Union[str, EncryptionStatus]],
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
**kwargs,
):
"""
Provide a direct interface for compilation.
Args:
parameters (Mapping[str, Union[str, EncryptionStatus]]):
encryption statuses of the parameters of the function to compile
configuration(Optional[Configuration], default = None):
configuration to use
artifacts (Optional[DebugArtifacts], default = None):
artifacts to store information about the process
kwargs (Dict[str, Any]):
configuration options to overwrite
"""
def decoration(function: Callable):
signature = inspect.signature(function)
parameter_values: Dict[str, Value] = {}
for name, details in signature.parameters.items():
if name not in parameters:
continue
annotation = details.annotation
is_value = isinstance(annotation, Value)
is_scalar_annotation = isinstance(annotation, type) and issubclass(
annotation, ScalarAnnotation
)
if not (is_value or is_scalar_annotation):
message = (
f"Annotation {annotation} for argument '{name}' is not valid "
f"(please use a cnp type such as "
f"`cnp.uint4` or 'cnp.tensor[cnp.uint4, 3, 2]')"
)
raise ValueError(message)
parameter_values[name] = (
annotation if is_value else Value(annotation.dtype, shape=(), is_encrypted=False)
)
status = EncryptionStatus(parameters[name].lower())
parameter_values[name].is_encrypted = status == "encrypted"
return Compiler.assemble(function, parameter_values, configuration, artifacts, **kwargs)
return decoration
def compiler(parameters: Mapping[str, Union[str, EncryptionStatus]]):
"""
Provide an easy interface for compilation.
Args:
parameters (Mapping[str, Union[str, EncryptionStatus]]):
encryption statuses of the parameters of the function to compile
"""
def decoration(function: Callable):
class Compilable:
"""
Compilable class, to wrap a function and provide methods to trace and compile it.
"""
function: Callable
compiler: Compiler
def __init__(self, function: Callable):
self.function = function # type: ignore
self.compiler = Compiler(self.function, dict(parameters))
def __call__(self, *args, **kwargs) -> Any:
self.compiler(*args, **kwargs)
return self.function(*args, **kwargs)
def trace(
self,
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
**kwargs,
) -> Graph:
"""
Trace the function into a computation graph.
Args:
inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
optional inputset to extend accumulated inputset before bounds measurement
configuration(Optional[Configuration], default = None):
configuration to use
artifacts (Optional[DebugArtifacts], default = None):
artifacts to store information about the process
kwargs (Dict[str, Any]):
configuration options to overwrite
Returns:
Graph:
computation graph representing the function prior to MLIR conversion
"""
return self.compiler.trace(inputset, configuration, artifacts, **kwargs)
def compile(
self,
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
**kwargs,
) -> Circuit:
"""
Compile the function into a circuit.
Args:
inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
optional inputset to extend accumulated inputset before bounds measurement
configuration(Optional[Configuration], default = None):
configuration to use
artifacts (Optional[DebugArtifacts], default = None):
artifacts to store information about the process
kwargs (Dict[str, Any]):
configuration options to overwrite
Returns:
Circuit:
compiled circuit
"""
return self.compiler.compile(inputset, configuration, artifacts, **kwargs)
return Compilable(function)
return decoration
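
To make the two entry points above concrete, a minimal sketch (the `cnp` alias, the toy function and the inputset are illustrative; `compiler`, `circuit` and `uint4` are assumed to be re-exported at the package level, as the error messages in this file suggest):

import concrete.numpy as cnp

@cnp.compiler({"x": "encrypted"})
def f(x):
    return x + 42

# `f` is now a Compilable: calling it also forwards the arguments to the underlying Compiler,
# `f.trace(...)` returns the computation graph, and `f.compile(...)` produces a Circuit
circuit = f.compile(inputset=range(10))

# `circuit` compiles directly from type annotations instead of an inputset
@cnp.circuit({"x": "encrypted"})
def g(x: cnp.uint4):
    return x + 42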

View File

@@ -0,0 +1,350 @@
"""
Declaration of `Server` class.
"""
import json
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional, Union
import concrete.compiler
from concrete.compiler import (
CompilationFeedback,
CompilationOptions,
EvaluationKeys,
JITCompilationResult,
JITLambda,
JITSupport,
LibraryCompilationResult,
LibraryLambda,
LibrarySupport,
PublicArguments,
PublicResult,
)
from ..internal.utils import assert_that
from .configuration import DEFAULT_GLOBAL_P_ERROR, DEFAULT_P_ERROR, Configuration
from .specs import ClientSpecs
class Server:
"""
Server class, which can be used to perform homomorphic computation.
"""
client_specs: ClientSpecs
_output_dir: Optional[tempfile.TemporaryDirectory]
_support: Union[JITSupport, LibrarySupport]
_compilation_result: Union[JITCompilationResult, LibraryCompilationResult]
_compilation_feedback: CompilationFeedback
_server_lambda: Union[JITLambda, LibraryLambda]
_mlir: Optional[str]
_configuration: Optional[Configuration]
def __init__(
self,
client_specs: ClientSpecs,
output_dir: Optional[tempfile.TemporaryDirectory],
support: Union[JITSupport, LibrarySupport],
compilation_result: Union[JITCompilationResult, LibraryCompilationResult],
server_lambda: Union[JITLambda, LibraryLambda],
):
self.client_specs = client_specs
self._output_dir = output_dir
self._support = support
self._compilation_result = compilation_result
self._compilation_feedback = self._support.load_compilation_feedback(compilation_result)
self._server_lambda = server_lambda
self._mlir = None
assert_that(
support.load_client_parameters(compilation_result).serialize()
== client_specs.client_parameters.serialize()
)
@staticmethod
def create(
mlir: str,
input_signs: List[bool],
output_signs: List[bool],
configuration: Configuration,
) -> "Server":
"""
Create a server using MLIR and output sign information.
Args:
mlir (str):
mlir to compile
input_signs (List[bool]):
sign status of the inputs
output_signs (List[bool]):
sign status of the outputs
configuration (Configuration):
configuration to use
"""
options = CompilationOptions.new("main")
options.set_loop_parallelize(configuration.loop_parallelize)
options.set_dataflow_parallelize(configuration.dataflow_parallelize)
options.set_auto_parallelize(configuration.auto_parallelize)
if configuration.auto_parallelize or configuration.dataflow_parallelize:
concrete.compiler.init_dfr()
global_p_error_is_set = configuration.global_p_error is not None
p_error_is_set = configuration.p_error is not None
if global_p_error_is_set and p_error_is_set: # pragma: no cover
options.set_global_p_error(configuration.global_p_error)
options.set_p_error(configuration.p_error)
elif global_p_error_is_set: # pragma: no cover
options.set_global_p_error(configuration.global_p_error)
options.set_p_error(1.0)
elif p_error_is_set: # pragma: no cover
options.set_global_p_error(1.0)
options.set_p_error(configuration.p_error)
else: # pragma: no cover
if DEFAULT_GLOBAL_P_ERROR is not None:
options.set_global_p_error(DEFAULT_GLOBAL_P_ERROR)
else:
options.set_global_p_error(1.0)
if DEFAULT_P_ERROR is not None:
options.set_p_error(DEFAULT_P_ERROR)
else:
options.set_p_error(1.0)
show_optimizer = (
configuration.show_optimizer
if configuration.show_optimizer is not None
else configuration.verbose
)
options.set_display_optimizer_choice(show_optimizer)
if configuration.jit:
output_dir = None
support = JITSupport.new()
compilation_result = support.compile(mlir, options)
server_lambda = support.load_server_lambda(compilation_result)
else:
# pylint: disable=consider-using-with
output_dir = tempfile.TemporaryDirectory()
output_dir_path = Path(output_dir.name)
# pylint: enable=consider-using-with
support = LibrarySupport.new(
str(output_dir_path), generateCppHeader=False, generateStaticLib=False
)
compilation_result = support.compile(mlir, options)
server_lambda = support.load_server_lambda(compilation_result)
client_parameters = support.load_client_parameters(compilation_result)
client_specs = ClientSpecs(input_signs, client_parameters, output_signs)
result = Server(client_specs, output_dir, support, compilation_result, server_lambda)
# pylint: disable=protected-access
result._mlir = mlir
result._configuration = configuration
# pylint: enable=protected-access
return result
def save(self, path: Union[str, Path], via_mlir: bool = False):
"""
Save the server into the given path in zip format.
Args:
path (Union[str, Path]):
path to save the server
via_mlir (bool, default = False):
export using the MLIR code of the program,
which makes the export cross-platform
"""
path = str(path)
if path.endswith(".zip"):
path = path[: len(path) - 4]
if via_mlir:
if self._mlir is None or self._configuration is None:
message = "Loaded server objects cannot be saved again via MLIR"
raise RuntimeError(message)
with tempfile.TemporaryDirectory() as tmp:
with open(Path(tmp) / "circuit.mlir", "w", encoding="utf-8") as f:
f.write(self._mlir)
with open(Path(tmp) / "input_signs.json", "w", encoding="utf-8") as f:
f.write(json.dumps(self.client_specs.input_signs))
with open(Path(tmp) / "output_signs.json", "w", encoding="utf-8") as f:
f.write(json.dumps(self.client_specs.output_signs))
with open(Path(tmp) / "configuration.json", "w", encoding="utf-8") as f:
f.write(json.dumps(self._configuration.__dict__))
shutil.make_archive(path, "zip", tmp)
return
if self._output_dir is None:
message = "Just-in-Time compilation cannot be saved"
raise RuntimeError(message)
with open(Path(self._output_dir.name) / "client.specs.json", "w", encoding="utf-8") as f:
f.write(self.client_specs.serialize())
shutil.make_archive(path, "zip", self._output_dir.name)
@staticmethod
def load(path: Union[str, Path]) -> "Server":
"""
Load the server from the given path in zip format.
Args:
path (Union[str, Path]):
path to load the server from
Returns:
Server:
server loaded from the filesystem
"""
# pylint: disable=consider-using-with
output_dir = tempfile.TemporaryDirectory()
output_dir_path = Path(output_dir.name)
# pylint: enable=consider-using-with
shutil.unpack_archive(path, str(output_dir_path), "zip")
if (output_dir_path / "circuit.mlir").exists():
with open(output_dir_path / "circuit.mlir", "r", encoding="utf-8") as f:
mlir = f.read()
with open(output_dir_path / "input_signs.json", "r", encoding="utf-8") as f:
input_signs = json.load(f)
assert_that(isinstance(input_signs, list))
assert_that(all(isinstance(sign, bool) for sign in input_signs))
with open(output_dir_path / "output_signs.json", "r", encoding="utf-8") as f:
output_signs = json.load(f)
assert_that(isinstance(output_signs, list))
assert_that(all(isinstance(sign, bool) for sign in output_signs))
with open(output_dir_path / "configuration.json", "r", encoding="utf-8") as f:
configuration = Configuration().fork(**json.load(f))
return Server.create(mlir, input_signs, output_signs, configuration)
with open(output_dir_path / "client.specs.json", "r", encoding="utf-8") as f:
client_specs = ClientSpecs.unserialize(f.read())
support = LibrarySupport.new(
str(output_dir_path),
generateCppHeader=False,
generateStaticLib=False,
)
compilation_result = support.reload("main")
server_lambda = support.load_server_lambda(compilation_result)
return Server(client_specs, output_dir, support, compilation_result, server_lambda)
def run(self, args: PublicArguments, evaluation_keys: EvaluationKeys) -> PublicResult:
"""
Evaluate using encrypted arguments.
Args:
args (PublicArguments):
encrypted arguments of the computation
evaluation_keys (EvaluationKeys):
evaluation keys for encrypted computation
Returns:
PublicResult:
encrypted result of the computation
"""
return self._support.server_call(self._server_lambda, args, evaluation_keys)
def cleanup(self):
"""
Cleanup the temporary library output directory.
"""
if self._output_dir is not None:
self._output_dir.cleanup()
@property
def complexity(self) -> float:
"""
Get complexity of the compiled program.
"""
return self._compilation_feedback.complexity
@property
def size_of_secret_keys(self) -> int:
"""
Get size of the secret keys of the compiled program.
"""
return self._compilation_feedback.total_secret_keys_size
@property
def size_of_bootstrap_keys(self) -> int:
"""
Get size of the bootstrap keys of the compiled program.
"""
return self._compilation_feedback.total_bootstrap_keys_size
@property
def size_of_keyswitch_keys(self) -> int:
"""
Get size of the key switch keys of the compiled program.
"""
return self._compilation_feedback.total_keyswitch_keys_size
@property
def size_of_inputs(self) -> int:
"""
Get size of the inputs of the compiled program.
"""
return self._compilation_feedback.total_inputs_size
@property
def size_of_outputs(self) -> int:
"""
Get size of the outputs of the compiled program.
"""
return self._compilation_feedback.total_output_size
@property
def p_error(self) -> float:
"""
Get the probability of error for each simple TLU (on a scalar).
"""
return self._compilation_feedback.p_error
@property
def global_p_error(self) -> float:
"""
Get the probability of having at least one simple TLU error during the entire execution.
"""
return self._compilation_feedback.global_p_error
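
A rough sketch of the save/load round trip implemented above; `server`, `public_args` and `evaluation_keys` are assumed to come from a prior compilation and from a client, which live outside this file:

# persist a library-compiled server, or export portably via its MLIR
server.save("server.zip")
server.save("server-portable.zip", via_mlir=True)

# later, possibly on another machine
restored = Server.load("server.zip")
result = restored.run(public_args, evaluation_keys)
restored.cleanup()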

View File

@@ -0,0 +1,127 @@
"""
Declaration of `ClientSpecs` class.
"""
import json
from typing import List
from concrete.compiler import ClientParameters, PublicArguments, PublicResult
class ClientSpecs:
"""
ClientSpecs class, to create Client objects.
"""
input_signs: List[bool]
client_parameters: ClientParameters
output_signs: List[bool]
def __init__(
self,
input_signs: List[bool],
client_parameters: ClientParameters,
output_signs: List[bool],
):
self.input_signs = input_signs
self.client_parameters = client_parameters
self.output_signs = output_signs
def serialize(self) -> str:
"""
Serialize client specs into a string representation.
Returns:
str:
string representation of the client specs
"""
client_parameters_json = json.loads(self.client_parameters.serialize())
return json.dumps(
{
"input_signs": self.input_signs,
"client_parameters": client_parameters_json,
"output_signs": self.output_signs,
}
)
@staticmethod
def unserialize(serialized_client_specs: str) -> "ClientSpecs":
"""
Create client specs from its string representation.
Args:
serialized_client_specs (str):
client specs to unserialize
Returns:
ClientSpecs:
unserialized client specs
"""
raw_specs = json.loads(serialized_client_specs)
client_parameters_bytes = json.dumps(raw_specs["client_parameters"]).encode("utf-8")
client_parameters = ClientParameters.unserialize(client_parameters_bytes)
return ClientSpecs(raw_specs["input_signs"], client_parameters, raw_specs["output_signs"])
def serialize_public_args(self, args: PublicArguments) -> bytes: # pylint: disable=no-self-use
"""
Serialize public arguments to bytes.
Args:
args (PublicArguments):
public arguments to serialize
Returns:
bytes:
serialized public arguments
"""
return args.serialize()
def unserialize_public_args(self, serialized_args: bytes) -> PublicArguments:
"""
Unserialize public arguments from bytes.
Args:
serialized_args (bytes):
serialized public arguments
Returns:
PublicArguments:
unserialized public arguments
"""
return PublicArguments.unserialize(self.client_parameters, serialized_args)
def serialize_public_result(self, result: PublicResult) -> bytes: # pylint: disable=no-self-use
"""
Serialize public result to bytes.
Args:
result (PublicResult):
public result to serialize
Returns:
bytes:
serialized public result
"""
return result.serialize()
def unserialize_public_result(self, serialized_result: bytes) -> PublicResult:
"""
Unserialize public result from bytes.
Args:
serialized_result (bytes):
serialized public result
Returns:
PublicResult:
unserialized public result
"""
return PublicResult.unserialize(self.client_parameters, serialized_result)
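
A small round-trip sketch for the helpers above, assuming `client_specs` and `args` were produced by a compiled circuit and its client:

serialized = client_specs.serialize()            # JSON string: signs + client parameters
restored = ClientSpecs.unserialize(serialized)

wire_bytes = restored.serialize_public_args(args)
args_again = restored.unserialize_public_args(wire_bytes)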

View File

@@ -0,0 +1,685 @@
"""
Declaration of various functions and constants related to compilation.
"""
from copy import deepcopy
from typing import Dict, Iterable, List, Optional, Set, Tuple
import networkx as nx
from ..dtypes import Float, Integer
from ..representation import Graph, Node, Operation
from .artifacts import DebugArtifacts
# ruff: noqa: ERA001
def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
"""
Fuse appropriate subgraphs in a graph to a single Operation.Generic node.
Args:
graph (Graph):
graph to search and update
artifacts (Optional[DebugArtifacts], default = None):
compilation artifacts to store information about the fusing process
Raises:
RuntimeError:
if there is a subgraph which needs to be fused but cannot be fused
"""
nx_graph = graph.graph
processed_terminal_nodes: Set[Node] = set()
fusing_floats = True
while True:
subgraph_to_fuse = (
find_float_subgraph_with_unique_terminal_node(
graph,
processed_terminal_nodes,
)
if fusing_floats
else find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_ancestor(
graph,
processed_terminal_nodes,
)
)
if subgraph_to_fuse is None:
if fusing_floats:
fusing_floats = False
processed_terminal_nodes.clear()
continue
break
all_nodes, start_nodes, terminal_node = subgraph_to_fuse
processed_terminal_nodes.add(terminal_node)
fused_node, node_before_subgraph = convert_subgraph_to_subgraph_node(
graph,
all_nodes,
start_nodes,
terminal_node,
)
nx_graph.add_node(fused_node)
if terminal_node in graph.output_nodes.values():
output_node_to_idx: Dict[Node, List[int]] = {
out_node: [] for out_node in graph.output_nodes.values()
}
for output_idx, output_node in graph.output_nodes.items():
output_node_to_idx[output_node].append(output_idx)
for output_idx in output_node_to_idx.get(terminal_node, []):
graph.output_nodes[output_idx] = fused_node
terminal_node_succ = list(nx_graph.successors(terminal_node))
for succ in terminal_node_succ:
succ_edge_data = deepcopy(nx_graph.get_edge_data(terminal_node, succ))
for edge_key, edge_data in succ_edge_data.items():
nx_graph.remove_edge(terminal_node, succ, key=edge_key)
new_edge_data = deepcopy(edge_data)
nx_graph.add_edge(fused_node, succ, key=edge_key, **new_edge_data)
nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0)
graph.prune_useless_nodes()
if artifacts is not None:
artifacts.add_graph("after-fusing", graph)
def find_float_subgraph_with_unique_terminal_node(
graph: Graph,
processed_terminal_nodes: Set[Node],
) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
"""
Find a subgraph with float computations that end with an integer output.
Args:
graph (Graph):
graph to search
processed_terminal_nodes (Set[Node]):
set of terminal nodes which have already been searched for float subgraphs
Returns:
Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
None if there are no such subgraphs,
tuple containing all nodes in the subgraph, start nodes of the subgraph,
and terminal node of the subgraph otherwise
"""
nx_graph = graph.graph
terminal_nodes = (
node
for node in nx_graph.nodes()
if (
node not in processed_terminal_nodes
and any(isinstance(input.dtype, Float) for input in node.inputs)
and isinstance(node.output.dtype, Integer)
)
)
try:
terminal_node = next(terminal_nodes)
except StopIteration:
return None
all_nodes: Dict[Node, None] = {}
start_single_int_output_nodes_search_from = terminal_node
while True:
all_nodes, start_nodes = find_closest_integer_output_nodes(
graph,
[start_single_int_output_nodes_search_from],
all_nodes,
)
variable_start_nodes = [
start_node for start_node in start_nodes if start_node.operation != Operation.Constant
]
if len(variable_start_nodes) == 1:
break
# find a common ancestor as we need a single variable input node
# lca == lowest common ancestor
lca = find_single_lca(graph, variable_start_nodes)
# if subgraph cannot be fused because there is no way to find a common ancestor, break
if lca is None:
break
# add the nodes from the `start_nodes` to `lca`, to `all_nodes`
all_nodes = add_nodes_from_to(graph, start_nodes, {lca: None}, all_nodes)
# if `lca` is a valid starting node for fusing break
if isinstance(lca.output.dtype, Integer):
# `lca` is the new start node
start_nodes = {lca: None}
break
# otherwise, push a little further
# (e.g., if there is a node just before, which has an integer output)
start_single_int_output_nodes_search_from = lca
return all_nodes, start_nodes, terminal_node
def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_ancestor(
graph: Graph,
processed_terminal_nodes: Set[Node],
) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
"""
Find a subgraph with a tlu computation that has multiple variable inputs \
where all variable inputs share a common ancestor.
Args:
graph (Graph):
graph to search
processed_terminal_nodes (Set[Node]):
set of terminal nodes which have already been searched for tlu subgraphs
Returns:
Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
None if there are no such subgraphs,
tuple containing all nodes in the subgraph, start nodes of the subgraph,
and terminal node of the subgraph otherwise
"""
nx_graph = graph.graph
terminal_nodes = (
node
for node in nx_graph.nodes()
if (
node not in processed_terminal_nodes
and node.converted_to_table_lookup
and all(isinstance(input.dtype, Integer) for input in node.inputs)
and isinstance(node.output.dtype, Integer)
and len(
[
pred
for pred in nx_graph.predecessors(node)
if pred.operation != Operation.Constant
]
)
> 1
)
)
try:
terminal_node = next(terminal_nodes)
except StopIteration:
return None
all_nodes: Dict[Node, None] = {}
while True:
variable_start_nodes = list(nx_graph.predecessors(terminal_node))
# find a common ancestor as we need a single variable input node
# lca == lowest common ancestor
lca = find_single_lca(graph, variable_start_nodes)
# if subgraph cannot be fused because there is no way to find a common ancestor, break
if lca is None:
start_nodes = {node: None for node in variable_start_nodes}
all_nodes = {node: None for node in variable_start_nodes + [terminal_node]}
break
# add the nodes from the `start_nodes` to `lca`, to `all_nodes`
all_nodes = add_nodes_from_to(
graph,
list(nx_graph.predecessors(terminal_node)),
{lca: None},
all_nodes,
)
all_nodes[terminal_node] = None
# if `lca` is a valid starting node for fusing break
if isinstance(lca.output.dtype, Integer):
# `lca` is the new start node
start_nodes = {lca: None}
break
return all_nodes, start_nodes, terminal_node
def find_single_lca(graph: Graph, nodes: List[Node]) -> Optional[Node]:
"""
Find the single lowest common ancestor of a list of nodes.
Args:
graph (Graph):
graph to search for single lca
nodes (List[Node]):
nodes to find the single lca of
Returns:
Optional[Node]:
single lca if it exists, None otherwise
"""
nx_graph = graph.graph
# find all ancestors of `nodes`
# nodes themselves need to be in this set because the single lca can be within `nodes`
all_ancestors = [set(list(nx.ancestors(nx_graph, node)) + [node]) for node in nodes]
# find common ancestors among `nodes`
# if the single lca exists, it's in this set
common_ancestors = {
node
for node in nx_graph.nodes()
if node.operation != Operation.Constant
and all(node in ancestors for ancestors in all_ancestors)
}
# iterate over every node in the graph reversed topological order
# this is to ensure result, if found, is the single "lowest" common ancestor
for candidate in reversed(list(nx.topological_sort(nx_graph))):
# check if node is a common ancestor of all `nodes`
if candidate not in common_ancestors:
# if not, it cannot be the single lca
continue
# check if node is a single common ancestor of `nodes`
if is_single_common_ancestor(graph, candidate, nodes):
# if so, it's the single lca of `nodes`
# so return it
return candidate
# if none of the nodes in `common_ancestors` is the single lca
# there is no single lca of this set of nodes, so return None
return None
def is_single_common_ancestor(
graph: Graph,
candidate: Node,
nodes: List[Node],
) -> bool:
"""
Determine if a node is the single common ancestor of a list of nodes.
Note that this function doesn't care about the `lowest` property of `lca`.
Args:
graph (Graph):
graph to perform the check
candidate (Node):
node to determine single common ancestor status
nodes (List[Node]):
nodes to determine single common ancestor status against
Returns:
bool:
True if `candidate` is a single common ancestor of `nodes`, False otherwise
"""
nx_graph = graph.graph
# create a subgraph with `candidate` node
subgraph = nx.DiGraph()
subgraph.add_node(candidate)
# iterate over `nodes` to add them to the subgraph
# along with every path from `candidate` to them
for node in nodes:
subgraph.add_node(node)
for path in nx.all_simple_paths(nx_graph, source=candidate, target=node):
nx.add_path(subgraph, path)
# iterate over the nodes of the subgraph
for node in subgraph.nodes():
# the condition below doesn't apply to `candidate`
# as its predecessors are not in the subgraph
if node == candidate:
continue
# find number of predecessors in the subgraph and in the original graph
# except constant nodes in the original graph as
# - they are not in the subgraph
# - they don't affect fusability status
predecessor_count_in_subgraph = len(list(subgraph.predecessors(node)))
predecessor_count_in_nx_graph = len(
[pred for pred in nx_graph.predecessors(node) if pred.operation != Operation.Constant]
)
# see if number of predecessors are different
if predecessor_count_in_subgraph != predecessor_count_in_nx_graph:
# if so, `candidate` cannot be a single common ancestor
# the reasoning for this is explained below
return False
# if every node in the subgraph has the same number of predecessors
# as in the original graph `candidate` is in fact a single common ancestor
return True
# Here is why this function works.
#
# Legend:
# - /|\- = Edge
# - (...) = Intermediate Node
# - {...} = Candidate Node
# - [...] = Node of which single common ancestor is searched
# - {[...]} = Both Candidate Node and Node of which single common ancestor is searched
#
# Consider the following graph:
#
# (3) (x) (2)
# \ / \ /
# [{*}] (/)
# \ /
# [+]
#
# - Operation: (x * 3) + (x / 2)
# - Candidate: {*}
# - Nodes: [*] and [+]
#
# So we want to know if multiplication node is a single common ancestor of
# multiplication and addition nodes. The result is no in this case for our purposes.
#
# Once you apply the subgraph creation above, you'll get the following graph:
#
# (*)
# |
# (+)
#
# In this subgraph, the addition node only has a single predecessor,
# which means there is a path leading to the addition node that doesn't include
# the multiplication node, so we conclude the multiplication node is not a single common ancestor
#
# Now, consider the following graph:
#
# (3) {x} (2)
# \ / \ /
# [*] (/)
# \ /
# [+]
#
# - Operation: (x * 3) + (x / 2)
# - Candidate: {x}
# - Nodes: [*] and [+]
#
# So we want to know if the input node 'x' is the single common ancestor of
# multiplication and addition nodes. The result is yes in this case.
#
# Once you apply the subgraph creation above, you'll get the following graph:
#
# {x}
# / \
# [*] (/)
# \ /
# [+]
#
# In this subgraph, every node except the candidate node
# keeps all of its non-constant predecessors,
# which means all of its non-constant predecessors originated
# from the `candidate`, so it's a single common ancestor.
#
# When you think about it, this implementation makes a lot of sense for our purposes.
# It basically determines whether `nodes` "solely" depend on the `candidate`,
# which is the condition for fusing.
def find_closest_integer_output_nodes(
graph: Graph,
start_nodes: List[Node],
all_nodes: Dict[Node, None],
) -> Tuple[Dict[Node, None], Dict[Node, None]]:
"""
Find the closest upstream integer output nodes to a set of start nodes in a graph.
Args:
graph (Graph):
graph to search
start_nodes (List[Node]):
nodes from which to start the search
all_nodes (Dict[Node, None]):
set of nodes to be extended with visited nodes during the search
Returns:
Tuple[Dict[Node, None], Dict[Node, None]]:
tuple containing extended `all_nodes` and integer output nodes closest to `start_nodes`
"""
nx_graph = graph.graph
closest_integer_output_nodes: Dict[Node, None] = {}
visited_nodes: Set[Node] = set()
current_nodes = {start_node: None for start_node in start_nodes}
while current_nodes:
next_nodes: Dict[Node, None] = {}
for node in current_nodes:
if node not in visited_nodes:
visited_nodes.add(node)
all_nodes.update({node: None})
for pred in nx_graph.predecessors(node):
if isinstance(pred.output.dtype, Integer):
closest_integer_output_nodes.update({pred: None})
all_nodes.update({pred: None})
else:
next_nodes.update({pred: None})
current_nodes = next_nodes
return all_nodes, closest_integer_output_nodes
def add_nodes_from_to(
graph: Graph,
from_nodes: Iterable[Node],
to_nodes: Dict[Node, None],
all_nodes: Dict[Node, None],
) -> Dict[Node, None]:
"""
Add nodes from `from_nodes` to `to_nodes`, to `all_nodes`.
Args:
graph (Graph):
graph to traverse
from_nodes (Iterable[Node]):
nodes from which extending `all_nodes` start
to_nodes (Dict[Node, None]):
nodes to which extending `all_nodes` stop
all_nodes (Dict[Node, None]):
nodes to be extended
Returns:
Dict[Node, None]:
extended `all_nodes`
"""
nx_graph = graph.graph
all_nodes.update(to_nodes)
visited_nodes: Set[Node] = set()
current_nodes = {from_node: None for from_node in from_nodes}
while current_nodes:
next_nodes: Dict[Node, None] = {}
for node in current_nodes:
if node not in visited_nodes:
visited_nodes.add(node)
all_nodes.update({node: None})
if node not in to_nodes:
predecessors = nx_graph.predecessors(node)
next_nodes.update({pred: None for pred in predecessors if pred not in to_nodes})
current_nodes = next_nodes
return all_nodes
def convert_subgraph_to_subgraph_node(
graph: Graph,
all_nodes: Dict[Node, None],
start_nodes: Dict[Node, None],
terminal_node: Node,
) -> Tuple[Node, Node]:
"""
Convert a subgraph to Operation.Generic node.
Args:
graph (Graph):
original graph
all_nodes (Dict[Node, None]):
all nodes in the subgraph
start_nodes (Dict[Node, None]):
start nodes of the subgraph
terminal_node (Node):
terminal node of the subgraph
Raises:
RuntimeError:
if subgraph is not fusable
Returns:
Tuple[Node, Node]:
subgraph node and its predecessor
"""
nx_graph = graph.graph
variable_input_nodes = [node for node in start_nodes if node.operation != Operation.Constant]
if len(variable_input_nodes) != 1:
base_highlighted_nodes = {
node: ["within this subgraph", node.location] for node in all_nodes
}
for variable_input_node in variable_input_nodes:
base_highlighted_nodes[variable_input_node] = [
"this is one of the input nodes",
variable_input_node.location,
]
raise RuntimeError(
"A subgraph within the function you are trying to compile cannot be fused "
"because it has multiple input nodes\n\n"
+ graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False)
)
variable_input_node = variable_input_nodes[0]
check_subgraph_fusability(graph, all_nodes, variable_input_node)
nx_subgraph = nx.MultiDiGraph(nx_graph)
nodes_to_remove = [node for node in nx_subgraph.nodes() if node not in all_nodes]
nx_subgraph.remove_nodes_from(nodes_to_remove)
subgraph_variable_input_node = Node.input("input", deepcopy(variable_input_node.output))
nx_subgraph.add_node(subgraph_variable_input_node)
subgraph_variable_input_node.location = variable_input_node.location
subgraph_variable_input_node.tag = variable_input_node.tag
subgraph_variable_input_node.created_at = variable_input_node.created_at
variable_input_node_successors = {
node: None for node in all_nodes if node in nx_graph.succ[variable_input_node]
}
for successor in variable_input_node_successors:
edges = deepcopy(nx_subgraph.get_edge_data(variable_input_node, successor))
for edge_key, edge_data in edges.items():
nx_subgraph.remove_edge(variable_input_node, successor, key=edge_key)
new_edge_data = deepcopy(edge_data)
nx_subgraph.add_edge(
subgraph_variable_input_node,
successor,
key=edge_key,
**new_edge_data,
)
original_location = terminal_node.location
original_tag = terminal_node.tag
original_created_at = terminal_node.created_at
subgraph = Graph(nx_subgraph, {0: subgraph_variable_input_node}, {0: terminal_node})
subgraph_node = Node.generic(
"subgraph",
subgraph_variable_input_node.inputs,
terminal_node.output,
lambda x, subgraph, terminal_node: subgraph.evaluate(x)[terminal_node],
kwargs={
"subgraph": subgraph,
"terminal_node": terminal_node,
},
)
subgraph_node.location = original_location
subgraph_node.tag = original_tag
subgraph_node.created_at = original_created_at
return subgraph_node, variable_input_node
def check_subgraph_fusability(
graph: Graph,
all_nodes: Dict[Node, None],
variable_input_node: Node,
):
"""
Determine if a subgraph can be fused.
e.g.,
shuffling or reshaping a tensor makes fusing impossible, as there should be a one-to-one mapping
between each cell of the input and each cell of the output for table lookups
Args:
graph (Graph):
original graph
all_nodes (Dict[Node, None]):
all nodes in the subgraph
variable_input_node (Node):
variable input node to the subgraph
Raises:
RuntimeError:
if subgraph is not fusable
"""
base_highlighted_nodes = {node: ["within this subgraph", node.location] for node in all_nodes}
base_highlighted_nodes[variable_input_node] = [
"with this input node",
variable_input_node.location,
]
non_constant_nodes = (node for node in all_nodes if node.operation != Operation.Constant)
for node in non_constant_nodes:
if node == variable_input_node:
continue
if not node.is_fusable:
base_highlighted_nodes[node] = ["this node is not fusable", node.location]
raise RuntimeError(
"A subgraph within the function you are trying to compile cannot be fused "
"because of a node, which is marked explicitly as non-fusable\n\n"
+ graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False)
)
if node.output.shape != variable_input_node.output.shape:
base_highlighted_nodes[node] = [
"this node has a different shape than the input node",
node.location,
]
raise RuntimeError(
"A subgraph within the function you are trying to compile cannot be fused "
"because of a node which has a different shape than the input node\n\n"
+ graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False)
)
return True
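
For intuition, a sketch of the kind of user code whose graph this pass rewrites (a hypothetical traced function; the exact NumPy operations are illustrative and assume the tracer supports them): when such a function is traced and compiled, the float subgraph between the encrypted integer input and the final integer output is collapsed into a single Operation.Generic node, which later becomes one table lookup.

import numpy as np

def f(x):
    # x: encrypted integer input
    # sin and the scaling create a float subgraph with a single variable input (x)
    # and a single integer output, so `fuse` can replace it with one subgraph node
    return np.round(np.sin(x) * 100).astype(np.int64)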

View File

@@ -0,0 +1,7 @@
"""
Define available data types and their semantics.
"""
from .base import BaseDataType
from .float import Float
from .integer import Integer, SignedInteger, UnsignedInteger

View File

@@ -0,0 +1,17 @@
"""
Declaration of `BaseDataType` abstract class.
"""
from abc import ABC, abstractmethod
class BaseDataType(ABC):
"""BaseDataType abstract class, to form a basis for data types."""
@abstractmethod
def __eq__(self, other: object) -> bool:
pass # pragma: no cover
@abstractmethod
def __str__(self) -> str:
pass # pragma: no cover

View File

@@ -0,0 +1,31 @@
"""
Declaration of `Float` class.
"""
from .base import BaseDataType
class Float(BaseDataType):
"""
Float class, to represent floating point numbers.
"""
bit_width: int
def __init__(self, bit_width: int):
super().__init__()
if bit_width not in [16, 32, 64]:
message = (
f"Float({repr(bit_width)}) is not supported "
f"(bit width must be one of 16, 32 or 64)"
)
raise ValueError(message)
self.bit_width = bit_width
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.bit_width == other.bit_width
def __str__(self) -> str:
return f"float{self.bit_width}"

View File

@@ -0,0 +1,155 @@
"""
Declaration of `Integer` class.
"""
import math
from functools import partial
from typing import Any
import numpy as np
from .base import BaseDataType
class Integer(BaseDataType):
"""
Integer class, to represent integers.
"""
is_signed: bool
bit_width: int
@staticmethod
def that_can_represent(value: Any, force_signed: bool = False) -> "Integer":
"""
Get the minimal `Integer` that can represent `value`.
Args:
value (Any):
value that needs to be represented
force_signed (bool, default = False):
whether to force signed integers or not
Returns:
Integer:
minimal `Integer` that can represent `value`
Raises:
ValueError:
if `value` cannot be represented by `Integer`
"""
lower_bound: int
upper_bound: int
if isinstance(value, list):
try:
value = np.array(value)
except Exception: # pylint: disable=broad-except
# here we try our best to convert the list to np.ndarray
# if it fails we raise the exception at the else branch below
pass
if isinstance(value, (int, np.integer)):
lower_bound = int(value)
upper_bound = int(value)
elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.integer):
lower_bound = int(value.min())
upper_bound = int(value.max())
else:
message = f"Integer cannot represent {repr(value)}"
raise ValueError(message)
def bits_to_represent_int(value: int, force_signed: bool) -> int:
bits: int
if value == 0:
return 1
if value < 0:
bits = int(math.ceil(math.log2(abs(value)))) + 1
else:
bits = int(math.ceil(math.log2(value + 1)))
if force_signed:
bits += 1
return bits
is_signed = force_signed or lower_bound < 0
bit_width = (
bits_to_represent_int(lower_bound, is_signed)
if lower_bound == upper_bound
else max(
bits_to_represent_int(lower_bound, is_signed),
bits_to_represent_int(upper_bound, is_signed),
)
)
return Integer(is_signed, bit_width)
def __init__(self, is_signed: bool, bit_width: int):
super().__init__()
if not isinstance(bit_width, int) or bit_width <= 0:
integer_str = "SignedInteger" if is_signed else "UnsignedInteger"
message = (
f"{integer_str}({repr(bit_width)}) is not supported "
f"(bit width must be a positive integer)"
)
raise ValueError(message)
self.is_signed = is_signed
self.bit_width = bit_width
def __eq__(self, other: Any) -> bool:
return (
isinstance(other, self.__class__)
and self.is_signed == other.is_signed
and self.bit_width == other.bit_width
)
def __str__(self) -> str:
return f"{('int' if self.is_signed else 'uint')}{self.bit_width}"
def min(self) -> int:
"""
Get the minimum value that can be represented by the `Integer`.
Returns:
int:
minimum value that can be represented by the `Integer`
"""
return 0 if not self.is_signed else -(2 ** (self.bit_width - 1))
def max(self) -> int:
"""
Get the maximum value that can be represented by the `Integer`.
Returns:
int:
maximum value that can be represented by the `Integer`
"""
return (2**self.bit_width) - 1 if not self.is_signed else (2 ** (self.bit_width - 1)) - 1
def can_represent(self, value: int) -> bool:
"""
Get whether `value` can be represented by the `Integer` or not.
Args:
value (int):
value to check representability
Returns:
bool:
True if `value` is representable by the `integer`, False otherwise
"""
return self.min() <= value <= self.max()
SignedInteger = partial(Integer, True)
UnsignedInteger = partial(Integer, False)
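
A few worked values for `that_can_represent`, following the helper above (kept to cases where the bit-width computation is unambiguous):

assert Integer.that_can_represent(0) == UnsignedInteger(1)
assert Integer.that_can_represent(100) == UnsignedInteger(7)                  # 100 < 2**7
assert Integer.that_can_represent(7, force_signed=True) == SignedInteger(4)   # sign bit added
assert str(SignedInteger(4)) == "int4" and SignedInteger(4).min() == -8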

View File

@@ -0,0 +1,74 @@
"""
Declaration of various functions and constants related to data types.
"""
from typing import List
from ..internal.utils import assert_that
from .base import BaseDataType
from .float import Float
from .integer import Integer, SignedInteger, UnsignedInteger
def combine_dtypes(dtypes: List[BaseDataType]) -> BaseDataType:
"""
Get the `BaseDataType` that can represent a set of `BaseDataType`s.
Args:
dtypes (List[BaseDataType]):
dtypes to combine
Returns:
BaseDataType:
dtype that can hold all the given dtypes (potentially lossy)
"""
assert_that(len(dtypes) != 0)
assert_that(all(isinstance(dtype, (Integer, Float)) for dtype in dtypes))
def combine_2_dtypes(dtype1: BaseDataType, dtype2: BaseDataType) -> BaseDataType:
result: BaseDataType = dtype1
if isinstance(dtype1, Integer) and isinstance(dtype2, Integer):
max_bits = max(dtype1.bit_width, dtype2.bit_width)
if dtype1.is_signed and dtype2.is_signed:
result = SignedInteger(max_bits)
elif not dtype1.is_signed and not dtype2.is_signed:
result = UnsignedInteger(max_bits)
elif dtype1.is_signed and not dtype2.is_signed:
# if dtype2 has the bigger bit_width,
# we need a signed integer that can hold
# it, so add 1 bit of sign to its bit_width
if dtype2.bit_width >= dtype1.bit_width:
new_bit_width = dtype2.bit_width + 1
result = SignedInteger(new_bit_width)
else:
result = SignedInteger(dtype1.bit_width)
elif not dtype1.is_signed and dtype2.is_signed:
# Same as above, with dtype1 and dtype2 switched around
if dtype1.bit_width >= dtype2.bit_width:
new_bit_width = dtype1.bit_width + 1
result = SignedInteger(new_bit_width)
else:
result = SignedInteger(dtype2.bit_width)
elif isinstance(dtype1, Float) and isinstance(dtype2, Float):
max_bits = max(dtype1.bit_width, dtype2.bit_width)
result = Float(max_bits)
elif isinstance(dtype1, Float):
result = dtype1
elif isinstance(dtype2, Float):
result = dtype2
return result
result = dtypes[0]
for other in dtypes[1:]:
result = combine_2_dtypes(result, other)
return result
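
A few worked combinations, following the rules above (the result can be lossy, as the docstring notes):

assert combine_dtypes([UnsignedInteger(3), UnsignedInteger(5)]) == UnsignedInteger(5)
assert combine_dtypes([SignedInteger(3), UnsignedInteger(5)]) == SignedInteger(6)   # sign bit added
assert combine_dtypes([UnsignedInteger(3), SignedInteger(5)]) == SignedInteger(5)
assert combine_dtypes([Float(32), UnsignedInteger(7)]) == Float(32)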

View File

@@ -0,0 +1,11 @@
"""
Provide additional features that are not present in numpy.
"""
from .array import array
from .ones import one, ones
from .round_bit_pattern import AutoRounder, round_bit_pattern
from .table import LookupTable
from .tag import tag
from .univariate import univariate
from .zeros import zero, zeros

View File

@@ -0,0 +1,59 @@
"""
Declaration of `array` function, to simplify creation of encrypted arrays.
"""
from typing import Any, Union
import numpy as np
from ..dtypes.utils import combine_dtypes
from ..representation import Node
from ..tracing import Tracer
from ..values import Value
def array(values: Any) -> Union[np.ndarray, Tracer]:
"""
Create an encrypted array from either encrypted or clear values.
Args:
values (Any):
array-like object compatible with numpy to construct the resulting encrypted array
Returns:
Union[np.ndarray, Tracer]:
Tracer that represents the operation during tracing
ndarray with values otherwise
"""
# pylint: disable=protected-access
is_tracing = Tracer._is_tracing
# pylint: enable=protected-access
if not isinstance(values, np.ndarray):
values = np.array(values)
if not is_tracing:
return values
shape = values.shape
values = values.flatten()
for i, value in enumerate(values):
if not isinstance(value, Tracer):
values[i] = Tracer.sanitize(value)
if not values[i].output.is_scalar:
message = "Encrypted arrays can only be created from scalars"
raise ValueError(message)
dtype = combine_dtypes([value.output.dtype for value in values])
is_encrypted = True
computation = Node.generic(
"array",
[value.output for value in values],
Value(dtype, shape, is_encrypted),
lambda *args: np.array(args).reshape(shape),
)
return Tracer(computation, values)
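
A tracing sketch for the extension above (the `cnp` alias, the compiler decorator and the inputset are illustrative, assuming they are available at the package level): encrypted scalars and clear constants are packed into a single encrypted array node.

import concrete.numpy as cnp

@cnp.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
    return cnp.array([x, y, 3]) + 1   # mixes encrypted scalars with a clear scalar

circuit = f.compile(inputset=[(1, 2), (3, 4)])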

View File

@@ -0,0 +1,56 @@
"""
Declaration of `ones` and `one` functions, to simplify creation of encrypted ones.
"""
from typing import Tuple, Union
import numpy as np
from ..representation import Node
from ..tracing import Tracer
from ..values import Value
def ones(shape: Union[int, Tuple[int, ...]]) -> Union[np.ndarray, Tracer]:
"""
Create an encrypted array of ones.
Args:
shape (Tuple[int, ...]):
shape of the array
Returns:
Union[np.ndarray, Tracer]:
Tracer that represents the operation during tracing
ndarray filled with ones otherwise
"""
# pylint: disable=protected-access
is_tracing = Tracer._is_tracing
# pylint: enable=protected-access
numpy_ones = np.ones(shape, dtype=np.int64)
if is_tracing:
computation = Node.generic(
"ones",
[],
Value.of(numpy_ones, is_encrypted=True),
lambda: np.ones(shape, dtype=np.int64),
)
return Tracer(computation, [])
return numpy_ones
def one() -> Union[np.ndarray, Tracer]:
"""
Create an encrypted scalar with the value of one.
Returns:
Union[np.ndarray, Tracer]:
Tracer that represents the operation during tracing
ndarray with one otherwise
"""
return ones(())
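
Outside of tracing these helpers are plain NumPy; during tracing they become encrypted constants. A short sketch (package alias and decorator usage assumed as elsewhere in this commit):

import numpy as np
import concrete.numpy as cnp

assert (cnp.ones((2, 2)) == np.ones((2, 2), dtype=np.int64)).all()
assert cnp.one() == 1

@cnp.compiler({"x": "encrypted"})
def f(x):
    return x + cnp.one()   # traced as an encrypted "ones" node of shape ()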

View File

@@ -0,0 +1,246 @@
"""
Declaration of `round_bit_pattern` function, to provide an interface for rounded table lookups.
"""
import threading
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Tuple, Union
import numpy as np
from ..dtypes import Integer
from ..mlir.utils import MAXIMUM_TLU_BIT_WIDTH
from ..representation import Node
from ..tracing import Tracer
from ..values import Value
local = threading.local()
# pylint: disable=protected-access
local._is_adjusting = False
# pylint: enable=protected-access
class Adjusting(BaseException):
"""
Adjusting class, to be used as an early-stop signal during adjustment.
"""
rounder: "AutoRounder"
input_min: int
input_max: int
def __init__(self, rounder: "AutoRounder", input_min: int, input_max: int):
super().__init__()
self.rounder = rounder
self.input_min = input_min
self.input_max = input_max
class AutoRounder:
"""
AutoRounder class, to optimize the number of msbs to keep during the round bit pattern operation.
"""
target_msbs: int
is_adjusted: bool
input_min: int
input_max: int
input_bit_width: int
lsbs_to_remove: int
def __init__(self, target_msbs: int = MAXIMUM_TLU_BIT_WIDTH):
# pylint: disable=protected-access
if local._is_adjusting:
message = (
"AutoRounders cannot be constructed during adjustment, "
"please construct AutoRounders outside the function and reference it"
)
raise RuntimeError(message)
# pylint: enable=protected-access
self.target_msbs = target_msbs
self.is_adjusted = False
self.input_min = 0
self.input_max = 0
self.input_bit_width = 0
self.lsbs_to_remove = 0
@staticmethod
def adjust(function: Callable, inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]]):
"""
Adjust AutoRounders in a function using an inputset.
"""
# pylint: disable=protected-access,too-many-branches
try: # extract underlying function for decorators
function = function.function # type: ignore
assert callable(function)
except AttributeError:
pass
if local._is_adjusting:
message = "AutoRounders cannot be adjusted recursively"
raise RuntimeError(message)
try:
local._is_adjusting = True
while True:
rounder = None
for sample in inputset:
if not isinstance(sample, tuple):
sample = (sample,)
try:
function(*sample)
except Adjusting as adjuster:
rounder = adjuster.rounder
rounder.input_min = min(rounder.input_min, adjuster.input_min)
rounder.input_max = max(rounder.input_max, adjuster.input_max)
input_value = Value.of([rounder.input_min, rounder.input_max])
assert isinstance(input_value.dtype, Integer)
rounder.input_bit_width = input_value.dtype.bit_width
if rounder.input_bit_width - rounder.lsbs_to_remove > rounder.target_msbs:
rounder.lsbs_to_remove = rounder.input_bit_width - rounder.target_msbs
else:
return
if rounder is None:
message = "AutoRounders cannot be adjusted with an empty inputset"
raise ValueError(message)
rounder.is_adjusted = True
finally:
local._is_adjusting = False
# pylint: enable=protected-access,too-many-branches
def round_bit_pattern(
x: Union[int, np.integer, List, np.ndarray, Tracer],
lsbs_to_remove: Union[int, AutoRounder],
) -> Union[int, np.integer, List, np.ndarray, Tracer]:
"""
Round the bit pattern of an integer.
If `lsbs_to_remove` is an `AutoRounder`:
the corresponding integer value will be determined by the adjustment process.
x = 0b_0000_0000 , lsbs_to_remove = 3 => 0b_0000_0000
x = 0b_0000_0001 , lsbs_to_remove = 3 => 0b_0000_0000
x = 0b_0000_0010 , lsbs_to_remove = 3 => 0b_0000_0000
x = 0b_0000_0011 , lsbs_to_remove = 3 => 0b_0000_0000
x = 0b_0000_0100 , lsbs_to_remove = 3 => 0b_0000_1000
x = 0b_0000_0101 , lsbs_to_remove = 3 => 0b_0000_1000
x = 0b_0000_0110 , lsbs_to_remove = 3 => 0b_0000_1000
x = 0b_0000_0111 , lsbs_to_remove = 3 => 0b_0000_1000
x = 0b_1010_0000 , lsbs_to_remove = 3 => 0b_1010_0000
x = 0b_1010_0001 , lsbs_to_remove = 3 => 0b_1010_0000
x = 0b_1010_0010 , lsbs_to_remove = 3 => 0b_1010_0000
x = 0b_1010_0011 , lsbs_to_remove = 3 => 0b_1010_0000
x = 0b_1010_0100 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_0101 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_0110 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_0111 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_1000 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_1001 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_1010 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_1011 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_1100 , lsbs_to_remove = 3 => 0b_1011_0000
x = 0b_1010_1101 , lsbs_to_remove = 3 => 0b_1011_0000
x = 0b_1010_1110 , lsbs_to_remove = 3 => 0b_1011_0000
x = 0b_1010_1111 , lsbs_to_remove = 3 => 0b_1011_0000
x = 0b_1011_1000 , lsbs_to_remove = 3 => 0b_1011_1000
x = 0b_1011_1001 , lsbs_to_remove = 3 => 0b_1011_1000
x = 0b_1011_1010 , lsbs_to_remove = 3 => 0b_1011_1000
x = 0b_1011_1011 , lsbs_to_remove = 3 => 0b_1011_1000
x = 0b_1011_1100 , lsbs_to_remove = 3 => 0b_1100_0000
x = 0b_1011_1101 , lsbs_to_remove = 3 => 0b_1100_0000
x = 0b_1011_1110 , lsbs_to_remove = 3 => 0b_1100_0000
x = 0b_1011_1111 , lsbs_to_remove = 3 => 0b_1100_0000
Args:
x (Union[int, np.integer, np.ndarray, Tracer]):
input to round
lsbs_to_remove (Union[int, AutoRounder]):
number of the least significant bits to remove
or an auto rounder object which will be used to determine the integer value
Returns:
Union[int, np.integer, np.ndarray, Tracer]:
Tracer that represents the operation during tracing
rounded value(s) otherwise
"""
# pylint: disable=protected-access,too-many-branches
if isinstance(lsbs_to_remove, AutoRounder):
if local._is_adjusting:
if not lsbs_to_remove.is_adjusted:
raise Adjusting(lsbs_to_remove, int(np.min(x)), int(np.max(x))) # type: ignore
elif not lsbs_to_remove.is_adjusted:
message = (
"AutoRounders cannot be used before adjustment, "
"please call AutoRounder.adjust with the function that will be compiled "
"and provide the exact inputset that will be used for compilation"
)
raise RuntimeError(message)
lsbs_to_remove = lsbs_to_remove.lsbs_to_remove
assert isinstance(lsbs_to_remove, int)
def evaluator(
x: Union[int, np.integer, np.ndarray],
lsbs_to_remove: int,
) -> Union[int, np.integer, np.ndarray]:
if lsbs_to_remove == 0:
return x
unit = 1 << lsbs_to_remove
half = 1 << lsbs_to_remove - 1
rounded = (x + half) // unit
return rounded * unit
if isinstance(x, Tracer):
computation = Node.generic(
"round_bit_pattern",
[x.output],
deepcopy(x.output),
evaluator,
kwargs={"lsbs_to_remove": lsbs_to_remove},
)
return Tracer(computation, [x])
if isinstance(x, list): # pragma: no cover
try:
x = np.array(x)
except Exception: # pylint: disable=broad-except
pass
if isinstance(x, np.ndarray):
if not np.issubdtype(x.dtype, np.integer):
message = (
f"Expected input elements to be integers but they are {type(x.dtype).__name__}"
)
raise TypeError(message)
elif not isinstance(x, (int, np.integer)):
message = f"Expected input to be an int or a numpy array but it's {type(x).__name__}"
raise TypeError(message)
return evaluator(x, lsbs_to_remove)
# pylint: enable=protected-access,too-many-branches
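
Two checks of the evaluator above, plus a sketch of the auto-adjustment flow using the functions defined in this file (the inputset is arbitrary; the final value assumes 6-bit inputs as measured from it):

import numpy as np

assert round_bit_pattern(0b_0000_0101, lsbs_to_remove=3) == 0b_0000_1000
assert (round_bit_pattern(np.array([2, 6, 9]), lsbs_to_remove=2) == np.array([4, 8, 8])).all()

# with an AutoRounder, the number of dropped bits is derived from an inputset
rounder = AutoRounder(target_msbs=4)

def f(x):
    return round_bit_pattern(x, lsbs_to_remove=rounder)

AutoRounder.adjust(f, inputset=range(64))
# inputs fit in 6 bits and 4 MSBs are kept, so rounder.lsbs_to_remove should end up as 2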

View File

@@ -0,0 +1,134 @@
"""
Declaration of `LookupTable` class.
"""
from copy import deepcopy
from typing import Any, Union
import numpy as np
from ..dtypes import BaseDataType, Integer
from ..representation import Node
from ..tracing import Tracer
class LookupTable:
"""
LookupTable class, to provide a way to do direct table lookups.
"""
table: np.ndarray
output_dtype: BaseDataType
def __init__(self, table: Any):
is_valid = True
try:
self.table = table if isinstance(table, np.ndarray) else np.array(table)
except Exception: # pragma: no cover # pylint: disable=broad-except
# here we try our best to convert the table to np.ndarray
# if it fails we raise the exception at the end of the function
is_valid = False
if is_valid:
is_valid = self.table.size > 0
if is_valid:
minimum: int = 0
maximum: int = 0
if np.issubdtype(self.table.dtype, np.integer):
minimum = int(self.table.min())
maximum = int(self.table.max())
if self.table.ndim != 1:
is_valid = False
else:
is_valid = all(isinstance(item, LookupTable) for item in self.table.flat)
if is_valid:
minimum = int(self.table.flat[0].table.min())
maximum = int(self.table.flat[0].table.max())
for item in self.table.flat:
minimum = min(minimum, item.table.min())
maximum = max(maximum, item.table.max())
self.output_dtype = Integer.that_can_represent([minimum, maximum])
if not is_valid:
message = f"LookupTable cannot be constructed with {repr(table)}"
raise ValueError(message)
def __repr__(self):
return str(list(self.table))
def __getitem__(self, key: Union[int, np.integer, np.ndarray, Tracer]):
if not isinstance(key, Tracer):
return LookupTable.apply(key, self.table)
if not isinstance(key.output.dtype, Integer):
message = f"LookupTable cannot be looked up with {key.output}"
raise ValueError(message)
table = self.table
if not np.issubdtype(self.table.dtype, np.integer):
try:
table = np.broadcast_to(table, key.output.shape)
except Exception as error:
message = (
f"LookupTable of shape {self.table.shape} "
f"cannot be looked up with {key.output}"
)
raise ValueError(message) from error
output = deepcopy(key.output)
output.dtype = self.output_dtype
computation = Node.generic(
"tlu",
[key.output],
output,
LookupTable.apply,
kwargs={"table": table},
)
return Tracer(computation, [key])
@staticmethod
def apply(
key: Union[int, np.integer, np.ndarray],
table: np.ndarray,
) -> Union[int, np.integer, np.ndarray]:
"""
Apply lookup table.
Args:
key (Union[int, np.integer, np.ndarray]):
lookup key
table (np.ndarray):
lookup table
Returns:
Union[int, np.integer, np.ndarray]:
lookup result
Raises:
ValueError:
if `table` cannot be looked up with `key`
"""
if not isinstance(key, (int, np.integer, np.ndarray)) or (
isinstance(key, np.ndarray) and not np.issubdtype(key.dtype, np.integer)
):
message = f"LookupTable cannot be looked up with {key}"
raise ValueError(message)
if np.issubdtype(table.dtype, np.integer):
return table[key]
if not isinstance(key, np.ndarray) or key.shape != table.shape:
message = f"LookupTable of shape {table.shape} cannot be looked up with {key}"
raise ValueError(message)
flat_result = np.fromiter(
(lt.table[k] for lt, k in zip(table.flat, key.flat)),
dtype=np.longlong,
)
return flat_result.reshape(table.shape)
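
A small sketch of direct (non-traced) lookups with the class above; during tracing the same `table[x]` expression becomes a "tlu" node instead:

import numpy as np

table = LookupTable([2, -1, 3, 0])
assert table.output_dtype == Integer.that_can_represent([-1, 3])   # enough to hold -1..3

# outside of tracing, __getitem__ simply indexes the underlying numpy array
assert table[0] == 2
assert (table[np.array([3, 1])] == np.array([0, -1])).all()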

View File

@@ -0,0 +1,24 @@
"""
Declaration of `tag` context manager, to allow tagging certain nodes.
"""
import threading
from contextlib import contextmanager
tag_context = threading.local()
tag_context.stack = []
@contextmanager
def tag(name: str):
"""
Introduce a new tag to the tag stack.
Can be nested, and the resulting tag will be `tag1.tag2`.
"""
tag_context.stack.append(name)
try:
yield
finally:
tag_context.stack.pop()
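
A nesting sketch, as described in the docstring (the tag names are arbitrary):

with tag("dense_layer"):
    with tag("activation"):
        # nodes created here end up tagged "dense_layer.activation"
        ...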

View File

@@ -0,0 +1,89 @@
"""
Declaration of `univariate` function.
"""
from typing import Any, Callable, Optional, Type, Union
import numpy as np
from ..dtypes import BaseDataType, Float
from ..representation import Node
from ..tracing import ScalarAnnotation, Tracer
from ..values import Value
def univariate(
function: Callable[[Any], Any],
outputs: Optional[Union[BaseDataType, Type[ScalarAnnotation]]] = None,
) -> Callable[[Union[Tracer, Any]], Union[Tracer, Any]]:
"""
Wrap a univariate function so that it is traced into a single generic node.
Args:
function (Callable[[Any], Any]):
univariate function to wrap
outputs (Optional[Union[BaseDataType, Type[ScalarAnnotation]]], default = None):
data type of the result, unused during compilation, required for direct definition
Returns:
Callable[[Union[Tracer, Any]], Union[Tracer, Any]]:
another univariate function that can be called with a Tracer as well
"""
def wrapper(x: Union[Tracer, Any]) -> Union[Tracer, Any]:
"""
Evaluate or trace wrapped univariate function.
Args:
x (Union[Tracer, Any]):
input of the function
Returns:
Union[Tracer, Any]:
result of tracing or evaluation
"""
if isinstance(x, Tracer):
dtype = (
{64: np.float64, 32: np.float32, 16: np.float16}[x.output.dtype.bit_width]
if isinstance(x.output.dtype, Float)
else np.int64
)
if x.output.shape == ():
sample = dtype(1) # type: ignore
else:
sample = np.ones(x.output.shape, dtype=dtype)
evaluation = function(sample)
output_value = Value.of(evaluation, is_encrypted=x.output.is_encrypted)
if output_value.shape != x.output.shape:
message = f"Function {function.__name__} cannot be used with cnp.univariate"
raise ValueError(message)
# pylint: disable=protected-access
is_direct = Tracer._is_direct
# pylint: enable=protected-access
if is_direct:
if outputs is None:
message = (
"Univariate extension requires "
"`outputs` argument for direct circuit definition "
"(e.g., cnp.univariate(function, outputs=cnp.uint4)(x))"
)
raise ValueError(message)
output_value.dtype = outputs if isinstance(outputs, BaseDataType) else outputs.dtype
computation = Node.generic(
function.__name__,
[x.output],
output_value,
lambda x: function(x), # pylint: disable=unnecessary-lambda
)
return Tracer(computation, [x])
return function(x)
return wrapper
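
A sketch of the wrapper in use (the `cnp` alias, the toy function and the inputset are illustrative; in a direct circuit definition the `outputs` argument would be required, e.g. `outputs=cnp.uint6`, which is an assumed alias):

import numpy as np
import concrete.numpy as cnp

def complex_op(x):
    return np.abs(np.round(50 * np.sin(x))).astype(np.int64)

@cnp.compiler({"x": "encrypted"})
def f(x):
    # traced as a single generic node named "complex_op"
    return cnp.univariate(complex_op)(x)

circuit = f.compile(inputset=range(8))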

View File

@@ -0,0 +1,56 @@
"""
Declaration of `zeros` and `zero` functions, to simplify creation of encrypted zeros.
"""
from typing import Tuple, Union
import numpy as np
from ..representation import Node
from ..tracing import Tracer
from ..values import Value
def zeros(shape: Union[int, Tuple[int, ...]]) -> Union[np.ndarray, Tracer]:
"""
Create an encrypted array of zeros.
Args:
shape (Tuple[int, ...]):
shape of the array
Returns:
Union[np.ndarray, Tracer]:
Tracer that represents the operation during tracing
ndarray filled with zeros otherwise
"""
# pylint: disable=protected-access
is_tracing = Tracer._is_tracing
# pylint: enable=protected-access
numpy_zeros = np.zeros(shape, dtype=np.int64)
if is_tracing:
computation = Node.generic(
"zeros",
[],
Value.of(numpy_zeros, is_encrypted=True),
lambda: np.zeros(shape, dtype=np.int64),
)
return Tracer(computation, [])
return numpy_zeros
def zero() -> Union[np.ndarray, Tracer]:
"""
Create an encrypted scalar with the value of zero.
Returns:
Union[np.ndarray, Tracer]:
Tracer that represents the operation during tracing
ndarray with zero otherwise
"""
return zeros(())

View File

@@ -0,0 +1,3 @@
"""
Export functions that are used internally by other modules for common things (e.g., assertions).
"""

View File

@@ -0,0 +1,32 @@
"""
Declaration of various functions and constants related to the entire project.
"""
def assert_that(condition: bool, message: str = ""):
"""
Assert a condition.
Args:
condition (bool):
condition to assert
message (str):
message to give to `AssertionError` if the condition does not hold
Raises:
AssertionError:
if the condition does not hold
"""
if not condition:
raise AssertionError(message)
def unreachable():
"""
Raise a RuntimeError to indicate unreachable code is entered.
"""
message = "Entered unreachable code"
raise RuntimeError(message)

View File

@@ -0,0 +1,6 @@
"""
Provide `computation graph` to `mlir` functionality.
"""
from .graph_converter import GraphConverter
from .node_converter import NodeConverter

View File

@@ -0,0 +1,739 @@
"""
Declaration of `GraphConverter` class.
"""
# pylint: disable=no-member,no-name-in-module
from copy import deepcopy
from typing import Any, Dict, List, Optional, cast
import concrete.lang as concretelang
import networkx as nx
import numpy as np
from concrete.lang.dialects import fhe, fhelinalg
from mlir.dialects import arith, func
from mlir.ir import (
Attribute,
Context,
InsertionPoint,
IntegerAttr,
IntegerType,
Location,
Module,
OpResult,
RankedTensorType,
)
from ..dtypes import Integer, SignedInteger
from ..internal.utils import assert_that
from ..representation import Graph, Node, Operation
from ..values import ClearScalar, EncryptedScalar
from .node_converter import NodeConverter
from .utils import MAXIMUM_TLU_BIT_WIDTH
# pylint: enable=no-member,no-name-in-module
class GraphConverter:
"""
GraphConverter class, to convert computation graphs to their MLIR equivalent.
"""
@staticmethod
def _check_node_convertibility(graph: Graph, node: Node) -> Optional[str]:
"""
Check node convertibility to MLIR.
Args:
graph (Graph):
computation graph of the node
node (Node):
node to be checked
Returns:
Optional[str]:
None if node is convertible to MLIR, the reason for inconvertibility otherwise
"""
# pylint: disable=too-many-branches,too-many-return-statements,too-many-statements
inputs = node.inputs
output = node.output
if node.operation == Operation.Constant:
assert_that(len(inputs) == 0)
if not isinstance(output.dtype, Integer):
return "only integer constants are supported"
elif node.operation == Operation.Input:
assert_that(len(inputs) == 1)
assert_that(inputs[0] == output)
if not isinstance(output.dtype, Integer):
return "only integer inputs are supported"
if output.dtype.is_signed and output.is_clear:
return "only encrypted signed integer inputs are supported"
else:
assert_that(node.operation == Operation.Generic)
if not isinstance(output.dtype, Integer):
return "only integer operations are supported"
name = node.properties["name"]
if name == "add":
assert_that(len(inputs) == 2)
elif name == "array":
assert_that(len(inputs) > 0)
assert_that(all(input.is_scalar for input in inputs))
elif name == "assign.static":
if not inputs[0].is_encrypted:
return "only assignment to encrypted tensors are supported"
elif name in ["bitwise_and", "bitwise_or", "bitwise_xor", "left_shift", "right_shift"]:
assert_that(len(inputs) == 2)
if all(value.is_encrypted for value in node.inputs):
pred_nodes = graph.ordered_preds_of(node)
if (
name in ["left_shift", "right_shift"]
and cast(Integer, pred_nodes[1].output.dtype).bit_width > 4
):
return "only up to 4-bit shifts are supported"
for pred_node in pred_nodes:
assert isinstance(pred_node.output.dtype, Integer)
if pred_node.output.dtype.is_signed:
return "only unsigned bitwise operations are supported"
elif name == "broadcast_to":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
return "only encrypted broadcasting is supported"
elif name == "concatenate":
if not all(input.is_encrypted for input in inputs):
return "only all encrypted concatenate is supported"
elif name in ["conv1d", "conv2d", "conv3d"]:
assert_that(len(inputs) == 2 or len(inputs) == 3)
if not (inputs[0].is_encrypted and inputs[1].is_clear):
return f"only {name} with encrypted input and clear weight is supported"
elif name == "dot":
assert_that(len(inputs) == 2)
if inputs[0].is_encrypted and inputs[1].is_encrypted:
return "only dot product between encrypted and clear is supported"
elif name in ["equal", "greater", "greater_equal", "less", "less_equal", "not_equal"]:
assert_that(len(inputs) == 2)
elif name == "expand_dims":
assert_that(len(inputs) == 1)
elif name == "index.static":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
return "only encrypted indexing supported"
elif name == "matmul":
assert_that(len(inputs) == 2)
if inputs[0].is_encrypted and inputs[1].is_encrypted:
return "only matrix multiplication between encrypted and clear is supported"
elif name == "maxpool":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
return "only encrypted maxpool is supported"
elif name == "multiply":
assert_that(len(inputs) == 2)
if inputs[0].is_encrypted and inputs[1].is_encrypted:
return "only multiplication between encrypted and clear is supported"
elif name == "negative":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
return "only encrypted negation is supported"
elif name == "ones":
assert_that(len(inputs) == 0)
elif name == "reshape":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
return "only encrypted reshape is supported"
elif name == "squeeze":
assert_that(len(inputs) == 1)
elif name == "subtract":
assert_that(len(inputs) == 2)
elif name == "sum":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
return "only encrypted sum is supported"
elif name == "transpose":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
return "only encrypted transpose is supported"
elif name == "zeros":
assert_that(len(inputs) == 0)
else:
assert_that(node.converted_to_table_lookup)
variable_input_indices = [
idx
for idx, pred in enumerate(graph.ordered_preds_of(node))
if not pred.operation == Operation.Constant
]
assert_that(len(variable_input_indices) == 1)
if len(inputs) > 0 and all(input.is_clear for input in inputs):
return "one of the operands must be encrypted"
return None
# pylint: enable=too-many-branches,too-many-return-statements,too-many-statements
@staticmethod
def _check_graph_convertibility(graph: Graph):
"""
Check graph convertibility to MLIR.
Args:
graph (Graph):
computation graph to be checked
Raises:
RuntimeError:
if `graph` is not convertible to MLIR
"""
offending_nodes = {}
if len(graph.output_nodes) > 1:
offending_nodes.update(
{
node: ["only a single output is supported", node.location]
for node in graph.output_nodes.values()
}
)
if len(offending_nodes) == 0:
for node in graph.graph.nodes:
reason = GraphConverter._check_node_convertibility(graph, node)
if reason is not None:
offending_nodes[node] = [reason, node.location]
if len(offending_nodes) != 0:
message = (
"Function you are trying to compile cannot be converted to MLIR\n\n"
+ graph.format(highlighted_nodes=offending_nodes)
)
raise RuntimeError(message)
@staticmethod
def _update_bit_widths(graph: Graph):
"""
Update bit-widths in a computation graph to be convertible to MLIR.
Args:
graph (Graph):
computation graph to be updated
"""
offending_nodes: Dict[Node, List[str]] = {}
max_bit_width = 0
max_bit_width_node = None
first_tlu_node = None
first_signed_node = None
for node in nx.lexicographical_topological_sort(graph.graph):
dtype = node.output.dtype
assert_that(isinstance(dtype, Integer))
current_node_bit_width = (
dtype.bit_width - 1 if node.output.is_clear else dtype.bit_width
)
if (
all(value.is_encrypted for value in node.inputs)
and node.operation == Operation.Generic
and node.properties["name"]
in [
"greater",
"greater_equal",
"less",
"less_equal",
]
):
# implementation of these operators requires at least 4 bits
current_node_bit_width = max(current_node_bit_width, 4)
if max_bit_width < current_node_bit_width:
max_bit_width = current_node_bit_width
max_bit_width_node = node
if node.converted_to_table_lookup and first_tlu_node is None:
first_tlu_node = node
if dtype.is_signed and first_signed_node is None:
first_signed_node = node
if first_tlu_node is not None:
if max_bit_width > MAXIMUM_TLU_BIT_WIDTH:
assert max_bit_width_node is not None
offending_nodes[max_bit_width_node] = [
(
{
Operation.Input: f"this input is {max_bit_width}-bits",
Operation.Constant: f"this constant is {max_bit_width}-bits",
Operation.Generic: f"this operation results in {max_bit_width}-bits",
}[max_bit_width_node.operation]
),
max_bit_width_node.location,
]
offending_nodes[first_tlu_node] = [
f"table lookups are only supported on circuits with "
f"up to {MAXIMUM_TLU_BIT_WIDTH}-bits",
first_tlu_node.location,
]
if len(offending_nodes) != 0:
raise RuntimeError(
"Function you are trying to compile cannot be converted to MLIR:\n\n"
+ graph.format(highlighted_nodes=offending_nodes)
)
for node in nx.topological_sort(graph.graph):
assert isinstance(node.output.dtype, Integer)
node.properties["original_bit_width"] = node.output.dtype.bit_width
for value in node.inputs + [node.output]:
dtype = value.dtype
assert_that(isinstance(dtype, Integer))
dtype.bit_width = max_bit_width + 1 if value.is_clear else max_bit_width
@staticmethod
def _offset_negative_lookup_table_inputs(graph: Graph):
"""
Offset negative table lookup inputs to be convertible to MLIR.
Args:
graph (Graph):
computation graph to apply offset
"""
# ugly hack to add an offset before entering a TLU
# if its variable input node has a signed output.
# this makes hardcoded assumptions about the way bit widths are handled in MLIR.
# this does not update the TLU input values to allow for proper table generation.
nx_graph = graph.graph
for node in list(nx_graph.nodes):
if node.operation == Operation.Generic:
if not node.converted_to_table_lookup:
continue
variable_input_index = -1
preds = graph.ordered_preds_of(node)
for index, pred in enumerate(preds):
if pred.operation != Operation.Constant:
variable_input_index = index
break
variable_input_node = preds[variable_input_index]
variable_input_value = variable_input_node.output
variable_input_dtype = variable_input_value.dtype
assert_that(isinstance(variable_input_dtype, Integer))
variable_input_dtype = cast(Integer, variable_input_dtype)
if not variable_input_dtype.is_signed:
continue
variable_input_bit_width = variable_input_dtype.bit_width
offset_constant_dtype = SignedInteger(variable_input_bit_width + 1)
offset_constant_value = abs(variable_input_dtype.min())
offset_constant = Node.constant(offset_constant_value)
offset_constant.output.dtype = offset_constant_dtype
original_bit_width = Integer.that_can_represent(offset_constant_value).bit_width
offset_constant.properties["original_bit_width"] = original_bit_width
add_offset = Node.generic(
"add",
[variable_input_value, ClearScalar(offset_constant_dtype)],
variable_input_value,
np.add,
)
original_bit_width = variable_input_node.properties["original_bit_width"]
add_offset.properties["original_bit_width"] = original_bit_width
nx_graph.remove_edge(variable_input_node, node)
nx_graph.add_edge(variable_input_node, add_offset, input_idx=0)
nx_graph.add_edge(offset_constant, add_offset, input_idx=1)
nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)
@staticmethod
def _broadcast_assignments(graph: Graph):
"""
Broadcast assignments.
Args:
graph (Graph):
computation graph to transform
"""
nx_graph = graph.graph
for node in list(nx_graph.nodes):
if node.operation == Operation.Generic and node.properties["name"] == "assign.static":
shape = node.inputs[0].shape
index = node.properties["kwargs"]["index"]
assert_that(isinstance(index, tuple))
while len(index) < len(shape):
index = (*index, slice(None, None, None))
required_value_shape_list = []
for i, indexing_element in enumerate(index):
if isinstance(indexing_element, slice):
n = len(np.zeros(shape[i])[indexing_element])
required_value_shape_list.append(n)
else:
required_value_shape_list.append(1)
required_value_shape = tuple(required_value_shape_list)
actual_value_shape = node.inputs[1].shape
if required_value_shape != actual_value_shape:
preds = graph.ordered_preds_of(node)
pred_to_modify = preds[1]
modified_value = deepcopy(pred_to_modify.output)
modified_value.shape = required_value_shape
try:
np.broadcast_to(np.zeros(actual_value_shape), required_value_shape)
modified_value.is_encrypted = True
modified_value.dtype = node.output.dtype
modified_pred = Node.generic(
"broadcast_to",
[pred_to_modify.output],
modified_value,
np.broadcast_to,
kwargs={"shape": required_value_shape},
)
except Exception: # pylint: disable=broad-except
np.reshape(np.zeros(actual_value_shape), required_value_shape)
modified_pred = Node.generic(
"reshape",
[pred_to_modify.output],
modified_value,
np.reshape,
kwargs={"newshape": required_value_shape},
)
modified_pred.properties["original_bit_width"] = pred_to_modify.properties[
"original_bit_width"
]
nx_graph.add_edge(pred_to_modify, modified_pred, input_idx=0)
nx_graph.remove_edge(pred_to_modify, node)
nx_graph.add_edge(modified_pred, node, input_idx=1)
node.inputs[1] = modified_value
@staticmethod
def _encrypt_clear_assignments(graph: Graph):
"""
Encrypt clear assignments.
Args:
graph (Graph):
computation graph to transform
"""
nx_graph = graph.graph
for node in list(nx_graph.nodes):
if node.operation == Operation.Generic and node.properties["name"] == "assign.static":
assigned_value = node.inputs[1]
if assigned_value.is_clear:
preds = graph.ordered_preds_of(node)
assigned_pred = preds[1]
new_assigned_pred_value = deepcopy(assigned_value)
new_assigned_pred_value.is_encrypted = True
new_assigned_pred_value.dtype = preds[0].output.dtype
zero = Node.generic(
"zeros",
[],
EncryptedScalar(new_assigned_pred_value.dtype),
lambda: np.zeros((), dtype=np.int64),
)
original_bit_width = 1
zero.properties["original_bit_width"] = original_bit_width
new_assigned_pred = Node.generic(
"add",
[assigned_pred.output, zero.output],
new_assigned_pred_value,
np.add,
)
original_bit_width = assigned_pred.properties["original_bit_width"]
new_assigned_pred.properties["original_bit_width"] = original_bit_width
nx_graph.remove_edge(preds[1], node)
nx_graph.add_edge(preds[1], new_assigned_pred, input_idx=0)
nx_graph.add_edge(zero, new_assigned_pred, input_idx=1)
nx_graph.add_edge(new_assigned_pred, node, input_idx=1)
@staticmethod
def _tensorize_scalars_for_fhelinalg(graph: Graph):
"""
Tensorize scalars if they are used within fhelinalg operations.
Args:
graph (Graph):
computation graph to update
"""
# pylint: disable=invalid-name
OPS_TO_TENSORIZE = [
"add",
"bitwise_and",
"bitwise_or",
"bitwise_xor",
"broadcast_to",
"dot",
"equal",
"greater",
"greater_equal",
"left_shift",
"less",
"less_equal",
"multiply",
"not_equal",
"right_shift",
"subtract",
]
# pylint: enable=invalid-name
tensorized_scalars: Dict[Node, Node] = {}
nx_graph = graph.graph
for node in list(nx_graph.nodes):
if node.operation == Operation.Generic and node.properties["name"] in OPS_TO_TENSORIZE:
assert len(node.inputs) in {1, 2}
if len(node.inputs) == 2:
if {inp.is_scalar for inp in node.inputs} != {True, False}:
continue
else:
if not node.inputs[0].is_scalar:
continue
# for bitwise and comparison operators that can have constants
# we don't need broadcasting here
if node.converted_to_table_lookup:
continue
pred_to_tensorize: Optional[Node] = None
pred_to_tensorize_index = 0
preds = graph.ordered_preds_of(node)
for index, pred in enumerate(preds):
if pred.output.is_scalar:
pred_to_tensorize = pred
pred_to_tensorize_index = index
break
assert pred_to_tensorize is not None
tensorized_pred = tensorized_scalars.get(pred_to_tensorize)
if tensorized_pred is None:
tensorized_value = deepcopy(pred_to_tensorize.output)
tensorized_value.shape = (1,)
tensorized_pred = Node.generic(
"array",
[pred_to_tensorize.output],
tensorized_value,
lambda *args: np.array(args),
)
original_bit_width = pred_to_tensorize.properties["original_bit_width"]
tensorized_pred.properties["original_bit_width"] = original_bit_width
original_shape = ()
tensorized_pred.properties["original_shape"] = original_shape
nx_graph.add_edge(pred_to_tensorize, tensorized_pred, input_idx=0)
tensorized_scalars[pred_to_tensorize] = tensorized_pred
assert tensorized_pred is not None
nx_graph.remove_edge(pred_to_tensorize, node)
nx_graph.add_edge(tensorized_pred, node, input_idx=pred_to_tensorize_index)
new_input_value = deepcopy(node.inputs[pred_to_tensorize_index])
new_input_value.shape = (1,)
node.inputs[pred_to_tensorize_index] = new_input_value
@staticmethod
def _sanitize_signed_inputs(graph: Graph, args: List[Any], ctx: Context) -> List[Any]:
"""
Use subtraction to sanitize signed inputs.
Args:
graph (Graph):
computation graph being converted
args (List[Any]):
list of arguments from mlir main
ctx (Context):
mlir context where the conversion is being performed
Returns:
List[Any]:
sanitized arguments to be used during the rest of the MLIR conversion
"""
sanitized_args = []
for i, arg in enumerate(args):
input_node = graph.input_nodes[i]
input_value = input_node.output
assert_that(isinstance(input_value.dtype, Integer))
input_dtype = cast(Integer, input_value.dtype)
if input_dtype.is_signed:
assert_that(input_value.is_encrypted)
n = input_dtype.bit_width
sanitizer_type = IntegerType.get_signless(n + 1)
sanitizer = 2 ** (n - 1)
if input_value.is_scalar:
sanitizer_attr = IntegerAttr.get(sanitizer_type, sanitizer)
else:
sanitizer_type = RankedTensorType.get((1,), sanitizer_type)
sanitizer_attr = Attribute.parse(f"dense<[{sanitizer}]> : {sanitizer_type}")
# pylint: disable=too-many-function-args
sanitizer_cst = arith.ConstantOp(sanitizer_type, sanitizer_attr)
# pylint: enable=too-many-function-args
resulting_type = NodeConverter.value_to_mlir_type(ctx, input_value)
if input_value.is_scalar:
sanitized = fhe.SubEintIntOp(resulting_type, arg, sanitizer_cst).result
else:
sanitized = fhelinalg.SubEintIntOp(resulting_type, arg, sanitizer_cst).result
sanitized_args.append(sanitized)
else:
sanitized_args.append(arg)
return sanitized_args
@staticmethod
def convert(graph: Graph) -> str:
"""
Convert a computation graph to its corresponding MLIR representation.
Args:
graph (Graph):
computation graph to be converted
Returns:
str:
textual MLIR representation corresponding to `graph`
"""
graph = deepcopy(graph)
GraphConverter._check_graph_convertibility(graph)
GraphConverter._update_bit_widths(graph)
GraphConverter._offset_negative_lookup_table_inputs(graph)
GraphConverter._broadcast_assignments(graph)
GraphConverter._encrypt_clear_assignments(graph)
GraphConverter._tensorize_scalars_for_fhelinalg(graph)
from_elements_operations: Dict[OpResult, List[OpResult]] = {}
with Context() as ctx, Location.unknown():
concretelang.register_dialects(ctx)
module = Module.create()
with InsertionPoint(module.body):
parameters = [
NodeConverter.value_to_mlir_type(ctx, input_node.output)
for input_node in graph.ordered_inputs()
]
@func.FuncOp.from_py_func(*parameters)
def main(*args):
sanitized_args = GraphConverter._sanitize_signed_inputs(graph, args, ctx)
ir_to_mlir = {}
for arg_num, node in graph.input_nodes.items():
ir_to_mlir[node] = sanitized_args[arg_num]
constant_cache = {}
for node in nx.topological_sort(graph.graph):
if node.operation == Operation.Input:
continue
preds = [ir_to_mlir[pred] for pred in graph.ordered_preds_of(node)]
node_converter = NodeConverter(
ctx,
graph,
node,
preds,
constant_cache,
from_elements_operations,
)
ir_to_mlir[node] = node_converter.convert()
results = (ir_to_mlir[output_node] for output_node in graph.ordered_outputs())
return results
direct_replacements = {}
for placeholder, elements in from_elements_operations.items():
element_names = [NodeConverter.mlir_name(element) for element in elements]
actual_value = f"tensor.from_elements {', '.join(element_names)} : {placeholder.type}"
direct_replacements[NodeConverter.mlir_name(placeholder)] = actual_value
module_lines_after_hacks_are_applied = []
for line in str(module).split("\n"):
mlir_name = line.split("=")[0].strip()
if mlir_name not in direct_replacements:
module_lines_after_hacks_are_applied.append(line)
continue
new_value = direct_replacements[mlir_name]
module_lines_after_hacks_are_applied.append(f" {mlir_name} = {new_value}")
return "\n".join(module_lines_after_hacks_are_applied).strip()

File diff suppressed because it is too large

View File

@@ -0,0 +1,171 @@
"""
Declaration of various functions and constants related to MLIR conversion.
"""
from collections import defaultdict, deque
from copy import deepcopy
from itertools import product
from typing import Any, DefaultDict, List, Optional, Tuple, Union, cast
import numpy as np
from ..dtypes import Integer
from ..internal.utils import assert_that
from ..representation import Node, Operation
MAXIMUM_TLU_BIT_WIDTH = 16
class HashableNdarray:
"""
HashableNdarray class, to use numpy arrays in dictionaries.
"""
array: np.ndarray
def __init__(self, array: np.ndarray):
self.array = array
def __eq__(self, other: object) -> bool:
return isinstance(other, HashableNdarray) and np.array_equal(self.array, other.array)
def __hash__(self) -> int:
return hash(self.array.tobytes())
def flood_replace_none_values(table: list):
"""
Use flooding algorithm to replace `None` values.
Args:
table (list):
the list in which there are `None` values that need to be replaced
with copies of the closest non `None` data from the list
"""
assert_that(any(value is not None for value in table))
not_none_values_idx = deque(idx for idx, value in enumerate(table) if value is not None)
while not_none_values_idx:
current_idx = not_none_values_idx.popleft()
current_value = table[current_idx]
previous_idx = current_idx - 1
next_idx = current_idx + 1
if previous_idx >= 0 and table[previous_idx] is None:
table[previous_idx] = deepcopy(current_value)
not_none_values_idx.append(previous_idx)
if next_idx < len(table) and table[next_idx] is None:
table[next_idx] = deepcopy(current_value)
not_none_values_idx.append(next_idx)
assert_that(all(value is not None for value in table))
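# Worked example (editor's addition): flooding `[None, 7, None, None, 2]` replaces
# each `None` with a copy of the closest non-`None` value, yielding `[7, 7, 7, 2, 2]`.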
def construct_table(node: Node, preds: List[Node]) -> List[Any]:
"""
Construct the lookup table for an Operation.Generic node.
Args:
node (Node):
Operation.Generic to construct the table
preds (List[Node]):
ordered predecessors to `node`
Returns:
List[Any]:
lookup table corresponding to `node` and its input value
"""
variable_input_index = -1
for index, pred in enumerate(preds):
if pred.operation != Operation.Constant:
variable_input_index = index
break
assert_that(variable_input_index != -1)
variable_input_dtype = node.inputs[variable_input_index].dtype
variable_input_shape = node.inputs[variable_input_index].shape
assert_that(isinstance(variable_input_dtype, Integer))
variable_input_dtype = cast(Integer, variable_input_dtype)
inputs: List[Any] = [pred() if pred.operation == Operation.Constant else None for pred in preds]
table: List[Optional[Union[np.bool_, np.integer, np.floating, np.ndarray]]] = []
for value in range(variable_input_dtype.min(), variable_input_dtype.max() + 1):
try:
inputs[variable_input_index] = np.ones(variable_input_shape, dtype=np.int64) * value
table.append(node(*inputs))
except Exception: # pylint: disable=broad-except
# here we try our best to fill the table
# if it fails, we append None and let the flooding algorithm replace None values below
table.append(None)
flood_replace_none_values(table)
return table
def construct_deduplicated_tables(
node: Node,
preds: List[Node],
) -> Tuple[Tuple[np.ndarray, List[Tuple[int, ...]]], ...]:
"""
Construct lookup tables for each cell of the input for an Operation.Generic node.
Args:
node (Node):
Operation.Generic to construct the table
preds (List[Node]):
ordered predecessors to `node`
Returns:
Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]:
tuple containing tuples of 2 for
- constructed table
- list of indices of the input that use the constructed table
e.g.,
.. code-block:: python
(
(np.array([3, 1, 2, 4]), [(1, 0), (2, 1)]),
(np.array([5, 8, 6, 7]), [(0, 0), (0, 1), (1, 1), (2, 0)]),
)
means the lookup on 3x2 input will result in
.. code-block:: python
[ [5, 8, 6, 7][input[0, 0]] , [5, 8, 6, 7][input[0, 1]] ]
[ [3, 1, 2, 4][input[1, 0]] , [5, 8, 6, 7][input[1, 1]] ]
[ [5, 8, 6, 7][input[2, 0]] , [3, 1, 2, 4][input[2, 1]] ]
"""
node_complete_table = np.concatenate(
tuple(np.expand_dims(array, -1) for array in construct_table(node, preds)),
axis=-1,
)
all_cells_idx = product(*tuple(range(max_val) for max_val in node_complete_table.shape[:-1]))
tables_to_cell_idx: DefaultDict[HashableNdarray, List[Tuple[int, ...]]] = defaultdict(list)
idx: Tuple[int, ...]
all_idx_set = set()
for idx in all_cells_idx:
hashable_array = HashableNdarray(node_complete_table[idx])
tables_to_cell_idx[hashable_array].append(idx)
all_idx_set.add(idx)
assert_that(len(all_idx_set) == np.prod(node_complete_table.shape[:-1]))
return tuple(
(hashable_array.array, indices) for hashable_array, indices in tables_to_cell_idx.items()
)

View File

@@ -0,0 +1,7 @@
"""
Define structures used to represent computation.
"""
from .graph import Graph
from .node import Node
from .operation import Operation

View File

@@ -0,0 +1,52 @@
"""
Declaration of various `Evaluator` classes, to make graphs picklable.
"""
# ruff: noqa: ARG002
class ConstantEvaluator:
"""
ConstantEvaluator class, to evaluate Operation.Constant nodes.
"""
def __init__(self, properties):
self.properties = properties
def __call__(self, *args, **kwargs):
return self.properties["constant"]
class InputEvaluator:
"""
InputEvaluator class, to evaluate Operation.Input nodes.
"""
def __call__(self, *args, **kwargs):
return args[0]
class GenericEvaluator:
"""
GenericEvaluator class, to evaluate Operation.Generic nodes.
"""
def __init__(self, operation, properties):
self.operation = operation
self.properties = properties
def __call__(self, *args, **kwargs):
return self.operation(*args, *self.properties["args"], **self.properties["kwargs"])
class GenericTupleEvaluator:
"""
GenericEvaluator class, to evaluate Operation.Generic nodes where args are packed in a tuple.
"""
def __init__(self, operation, properties):
self.operation = operation
self.properties = properties
def __call__(self, *args, **kwargs):
return self.operation(tuple(args), *self.properties["args"], **self.properties["kwargs"])
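# Minimal sketch (editor's addition, assuming `import numpy as np`): an evaluator
# bound to `np.add` with no extra args or kwargs behaves like the operation itself.
#
#     evaluator = GenericEvaluator(np.add, {"args": (), "kwargs": {}})
#     evaluator(3, 4)  # == 7
#
# Because evaluators only hold picklable state (an operation and its properties),
# graphs built from them can be serialized with `pickle`, which is the reason this
# module exists.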

View File

@@ -0,0 +1,692 @@
"""
Declaration of `Graph` class.
"""
import math
import re
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import networkx as nx
import numpy as np
import scipy.special
from ..dtypes import Float, Integer, UnsignedInteger
from .node import Node
from .operation import Operation
P_ERROR_PER_ERROR_SIZE_CACHE: Dict[float, Dict[int, float]] = {}
class Graph:
"""
Graph class, to represent computation graphs.
"""
graph: nx.MultiDiGraph
input_nodes: Dict[int, Node]
output_nodes: Dict[int, Node]
input_indices: Dict[Node, int]
is_direct: bool
def __init__(
self,
graph: nx.MultiDiGraph,
input_nodes: Dict[int, Node],
output_nodes: Dict[int, Node],
is_direct: bool = False,
):
self.graph = graph
self.input_nodes = input_nodes
self.output_nodes = output_nodes
self.input_indices = {node: index for index, node in input_nodes.items()}
self.is_direct = is_direct
self.prune_useless_nodes()
def __call__(
self,
*args: Any,
p_error: Optional[float] = None,
) -> Union[
np.bool_,
np.integer,
np.floating,
np.ndarray,
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
]:
evaluation = self.evaluate(*args, p_error=p_error)
result = tuple(evaluation[node] for node in self.ordered_outputs())
return result if len(result) > 1 else result[0]
def evaluate(
self,
*args: Any,
p_error: Optional[float] = None,
) -> Dict[Node, Union[np.bool_, np.integer, np.floating, np.ndarray]]:
r"""
Perform the computation `Graph` represents and get resulting values for all nodes.
Args:
*args (List[Any]):
inputs to the computation
p_error (Optional[float]):
probability of error for table lookups
Returns:
Dict[Node, Union[np.bool\_, np.integer, np.floating, np.ndarray]]:
nodes and their values during computation
"""
# pylint: disable=no-member,too-many-nested-blocks,too-many-branches,too-many-statements
if p_error is None:
p_error = 0.0
assert isinstance(p_error, float)
node_results: Dict[Node, Union[np.bool_, np.integer, np.floating, np.ndarray]] = {}
for node in nx.topological_sort(self.graph):
if node.operation == Operation.Input:
node_results[node] = node(args[self.input_indices[node]])
continue
pred_results = [node_results[pred] for pred in self.ordered_preds_of(node)]
if p_error > 0.0 and node.converted_to_table_lookup:
variable_input_indices = [
idx
for idx, pred in enumerate(self.ordered_preds_of(node))
if not pred.operation == Operation.Constant
]
for index in variable_input_indices:
pred_node = self.ordered_preds_of(node)[index]
if pred_node.operation != Operation.Input:
dtype = node.inputs[index].dtype
if isinstance(dtype, Integer):
# see https://github.com/zama-ai/concrete-numpy/blob/main/docs/_static/p_error_simulation.pdf # noqa: E501 # pylint: disable=line-too-long
# to learn more about the distribution of error
if p_error not in P_ERROR_PER_ERROR_SIZE_CACHE:
std_score = math.sqrt(2) * scipy.special.erfcinv(p_error)
p_error_per_error_size = {}
error_size = 1
last_p = 1 - p_error
while last_p != 1.0 or error_size == 1:
new_std_score = (2 * error_size + 1) * std_score
new_p = scipy.special.erf(new_std_score / math.sqrt(2))
p_error_per_error_size[error_size] = new_p - last_p
last_p = new_p
error_size += 1
# ordering of `p_error_per_error_size` is relied on
# during the introduction of the error below
# thus we explicitly sort it to make sure it's ordered
p_error_per_error_size = dict(
sorted(p_error_per_error_size.items())
)
P_ERROR_PER_ERROR_SIZE_CACHE[p_error] = p_error_per_error_size
else: # pragma: no cover
p_error_per_error_size = P_ERROR_PER_ERROR_SIZE_CACHE[p_error]
error = np.random.rand(*pred_results[index].shape)
accumulated_p_error = 0.0
for error_size, p_error_for_size in p_error_per_error_size.items():
accumulated_p_error += p_error_for_size
error = np.where(error < accumulated_p_error, error_size, error)
error = np.where(error < 1, 0, error).astype(np.int64)
error_sign = np.random.rand(*pred_results[index].shape)
error_sign = np.where(error_sign < 0.5, 1, -1).astype(np.int64)
new_result = pred_results[index] + (error * error_sign)
if new_result.shape == (): # pragma: no cover
if new_result < dtype.min():
new_result = dtype.max() - (dtype.min() - new_result) + 1
elif new_result > dtype.max():
new_result = dtype.min() - (new_result - dtype.max()) - 1
else:
underflow_indices = np.where(new_result < dtype.min())
new_result[underflow_indices] = (
dtype.max() - (dtype.min() - new_result[underflow_indices]) + 1
)
overflow_indices = np.where(new_result > dtype.max())
new_result[overflow_indices] = (
dtype.min() + (new_result[overflow_indices] - dtype.max()) - 1
)
pred_results[index] = new_result
try:
node_results[node] = node(*pred_results)
except Exception as error:
raise RuntimeError(
"Evaluation of the graph failed\n\n"
+ self.format(
highlighted_nodes={node: ["evaluation of this node failed"]},
show_bounds=False,
)
) from error
return node_results
def format(
self,
maximum_constant_length: int = 25,
highlighted_nodes: Optional[Dict[Node, List[str]]] = None,
show_types: bool = True,
show_bounds: bool = True,
show_tags: bool = True,
show_locations: bool = False,
) -> str:
"""
Get the textual representation of the `Graph`.
Args:
maximum_constant_length (int, default = 25):
maximum length of formatted constants
highlighted_nodes (Optional[Dict[Node, List[str]]], default = None):
nodes to be highlighted and their corresponding messages
show_types (bool, default = True):
whether to show types of nodes
show_bounds (bool, default = True):
whether to show bounds of nodes
show_tags (bool, default = True):
whether to show tags of nodes
show_locations (bool, default = False):
whether to show line information of nodes
Returns:
str:
textual representation of the `Graph`
"""
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
# ruff: noqa: ERA001
if self.is_direct:
show_bounds = False
# node -> identifier
# e.g., id_map[node1] = 2
# means line for node1 is in this form %2 = node1.format(...)
id_map: Dict[Node, int] = {}
# lines that will be merged at the end
lines: List[str] = []
# metadata to add to each line
# (for alignment, this is done after lines are determined)
line_metadata: List[Dict[str, str]] = []
# default highlighted nodes is empty
highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {}
# highlight information for lines, this is required because highlights are added to lines
# after their type information is added, and we only have line numbers, not nodes
highlighted_lines: Dict[int, List[str]] = {}
# subgraphs to format after the main graph is formatted
subgraphs: Dict[str, Graph] = {}
# format nodes
for node in nx.lexicographical_topological_sort(self.graph):
# assign a unique id to outputs of node
id_map[node] = len(id_map)
# remember highlights of the node
if node in highlighted_nodes:
highlighted_lines[len(lines)] = highlighted_nodes[node]
# extract predecessors and their ids
predecessors = []
for predecessor in self.ordered_preds_of(node):
predecessors.append(f"%{id_map[predecessor]}")
# start building the line for the node
line = ""
# add output information to the line
line += f"%{id_map[node]}"
# add node information to the line
line += " = "
line += node.format(predecessors, maximum_constant_length)
# append line to list of lines
lines.append(line)
# if exists, save the subgraph
if node.operation == Operation.Generic and "subgraph" in node.properties["kwargs"]:
subgraphs[line] = node.properties["kwargs"]["subgraph"]
# get formatted bounds
bounds = ""
if node.bounds is not None:
bounds += "∈ ["
lower, upper = node.bounds
assert type(lower) == type(upper) # pylint: disable=unidiomatic-typecheck
if isinstance(lower, (float, np.float32, np.float64)):
bounds += f"{round(lower, 6)}, {round(upper, 6)}"
else:
bounds += f"{int(lower)}, {int(upper)}"
bounds += "]"
# remember metadata of the node
line_metadata.append(
{
"type": f"# {node.output}",
"bounds": bounds,
"tag": (f"@ {node.tag}" if node.tag != "" else ""),
"location": node.location,
},
)
# align = signs
#
# e.g.,
#
# %1 = ...
# %2 = ...
# ...
# %8 = ...
# %9 = ...
# %10 = ...
# %11 = ...
# ...
longest_length_before_equals_sign = max(len(line.split("=")[0]) for line in lines)
for i, line in enumerate(lines):
length_before_equals_sign = len(line.split("=")[0])
lines[i] = (
" " * (longest_length_before_equals_sign - length_before_equals_sign)
) + line
# determine which metadata to show
shown_metadata_keys = []
if show_types:
shown_metadata_keys.append("type")
if show_bounds:
shown_metadata_keys.append("bounds")
if show_tags:
shown_metadata_keys.append("tag")
if show_locations:
shown_metadata_keys.append("location")
# show requested metadata
indent = 8
for metadata_key in shown_metadata_keys:
longest_line_length = max(len(line) for line in lines)
lines = [
line + (" " * ((longest_line_length - len(line)) + indent)) + metadata[metadata_key]
for line, metadata in zip(lines, line_metadata)
]
# strip whitespaces
lines = [line.rstrip() for line in lines]
# add highlights (this is done in reverse to keep indices consistent)
for i in reversed(range(len(lines))):
if i in highlighted_lines:
for j, message in enumerate(highlighted_lines[i]):
highlight = "^" if j == 0 else " "
lines.insert(i + 1 + j, f"{highlight * len(lines[i])} {message}")
# add return information
# (if there is a single return, it's in the form `return %id`,
# otherwise, it's in the form `return (%id1, %id2, ..., %idN)`)
returns: List[str] = []
for node in self.output_nodes.values():
returns.append(f"%{id_map[node]}")
lines.append("return " + (returns[0] if len(returns) == 1 else f"({', '.join(returns)})"))
# format subgraphs after the actual graph
result = "\n".join(lines)
if len(subgraphs) > 0:
result += "\n\n"
result += "Subgraphs:"
for line, subgraph in subgraphs.items():
subgraph_lines = subgraph.format(
maximum_constant_length=maximum_constant_length,
highlighted_nodes={},
show_types=show_types,
show_bounds=False, # doesn't make sense as we don't measure bounds in subgraphs
show_tags=show_tags,
show_locations=show_locations,
).split("\n")
result += "\n\n"
result += f" {line}:\n\n"
result += "\n".join(f" {line}" for line in subgraph_lines)
return result
# pylint: enable=too-many-branches,too-many-locals,too-many-statements
def measure_bounds(
self,
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
) -> Dict[Node, Dict[str, Union[np.integer, np.floating]]]:
"""
Evaluate the `Graph` using an inputset and measure bounds.
inputset is either an iterable of anything
for a single parameter
or
an iterable of tuples of anything (one tuple element per parameter)
for multiple parameters
e.g.,
.. code-block:: python
inputset = [1, 3, 5, 2, 4]
def f(x):
...
inputset = [(1, 2), (2, 4), (3, 1), (2, 2)]
def g(x, y):
...
Args:
inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]):
inputset to use
Returns:
Dict[Node, Dict[str, Union[np.integer, np.floating]]]:
bounds of each node in the `Graph`
"""
bounds = {}
inputset_iterator = iter(inputset)
sample = next(inputset_iterator)
if not isinstance(sample, tuple):
sample = (sample,)
index = 0
try:
evaluation = self.evaluate(*sample)
for node, value in evaluation.items():
bounds[node] = {
"min": value.min(),
"max": value.max(),
}
for sample in inputset_iterator:
index += 1
if not isinstance(sample, tuple):
sample = (sample,)
evaluation = self.evaluate(*sample)
for node, value in evaluation.items():
bounds[node] = {
"min": np.minimum(bounds[node]["min"], value.min()),
"max": np.maximum(bounds[node]["max"], value.max()),
}
except Exception as error:
message = f"Bound measurement using inputset[{index}] failed"
raise RuntimeError(message) from error
return bounds
def update_with_bounds(self, bounds: Dict[Node, Dict[str, Union[np.integer, np.floating]]]):
"""
Update `Value`s within the `Graph` according to measured bounds.
Args:
bounds (Dict[Node, Dict[str, Union[np.integer, np.floating]]]):
bounds of each node in the `Graph`
"""
for node in self.graph.nodes():
if node in bounds:
min_bound = bounds[node]["min"]
max_bound = bounds[node]["max"]
node.bounds = (min_bound, max_bound)
new_value = deepcopy(node.output)
if isinstance(min_bound, np.integer):
new_value.dtype = Integer.that_can_represent(np.array([min_bound, max_bound]))
else:
new_value.dtype = {
np.bool_: UnsignedInteger(1),
np.float64: Float(64),
np.float32: Float(32),
np.float16: Float(16),
}[type(min_bound)]
node.output = new_value
if node.operation == Operation.Input:
node.inputs[0] = new_value
for successor in self.graph.successors(node):
edge_data = self.graph.get_edge_data(node, successor)
for edge in edge_data.values():
input_idx = edge["input_idx"]
successor.inputs[input_idx] = node.output
def ordered_inputs(self) -> List[Node]:
"""
Get the input nodes of the `Graph`, ordered by their indices.
Returns:
List[Node]:
ordered input nodes
"""
return [self.input_nodes[idx] for idx in range(len(self.input_nodes))]
def ordered_outputs(self) -> List[Node]:
"""
Get the output nodes of the `Graph`, ordered by their indices.
Returns:
List[Node]:
ordered output nodes
"""
return [self.output_nodes[idx] for idx in range(len(self.output_nodes))]
def ordered_preds_of(self, node: Node) -> List[Node]:
"""
Get predecessors of `node`, ordered by their indices.
Args:
node (Node):
node whose predecessors are requested
Returns:
List[Node]:
ordered predecessors of `node`.
"""
idx_to_pred: Dict[int, Node] = {}
for pred in self.graph.predecessors(node):
edge_data = self.graph.get_edge_data(pred, node)
idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values())
return [idx_to_pred[i] for i in range(len(idx_to_pred))]
def prune_useless_nodes(self):
"""
Remove unreachable nodes from the graph.
"""
useful_nodes: Dict[Node, None] = {}
current_nodes = {node: None for node in self.ordered_outputs()}
while current_nodes:
useful_nodes.update(current_nodes)
next_nodes: Dict[Node, None] = {}
for node in current_nodes:
next_nodes.update({node: None for node in self.graph.predecessors(node)})
current_nodes = next_nodes
useless_nodes = [node for node in self.graph.nodes() if node not in useful_nodes]
self.graph.remove_nodes_from(useless_nodes)
def query_nodes(
self,
tag_filter: Optional[Union[str, List[str], re.Pattern]] = None,
operation_filter: Optional[Union[str, List[str], re.Pattern]] = None,
) -> List[Node]:
"""
Query nodes within the graph.
Filters work like so:
str -> nodes without an exact match are skipped
List[str] -> nodes without an exact match with one of the strings in the list are skipped
re.Pattern -> nodes without a pattern match are skipped
Args:
tag_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
filter for tags
operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
filter for operations
Returns:
List[Node]:
filtered nodes
"""
def match_text_filter(text_filter, text):
if text_filter is None:
return True
if isinstance(text_filter, str):
return text == text_filter
if isinstance(text_filter, re.Pattern):
return text_filter.match(text)
return any(text == alternative for alternative in text_filter)
def get_operation_name(node):
result: str
if node.operation == Operation.Input:
result = "input"
elif node.operation == Operation.Constant:
result = "constant"
else:
result = node.properties["name"]
return result
return [
node
for node in self.graph.nodes()
if (
match_text_filter(tag_filter, node.tag)
and match_text_filter(operation_filter, get_operation_name(node))
)
]
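# Usage sketch (editor's addition; the tag value is hypothetical):
#
#     graph.query_nodes(tag_filter=re.compile("dense.*"), operation_filter="matmul")
#
# returns every "matmul" node whose tag matches the given regular expression.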
def maximum_integer_bit_width(
self,
tag_filter: Optional[Union[str, List[str], re.Pattern]] = None,
operation_filter: Optional[Union[str, List[str], re.Pattern]] = None,
) -> int:
"""
Get maximum integer bit-width within the graph.
Only nodes after filtering will be used to calculate the result.
Args:
tag_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
filter for tags
operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
filter for operations
Returns:
int:
maximum integer bit-width within the graph
if there are no integer nodes matching the query, result is -1
"""
filtered_bit_widths = (
node.output.dtype.bit_width
for node in self.query_nodes(tag_filter, operation_filter)
if isinstance(node.output.dtype, Integer)
)
return max(filtered_bit_widths, default=-1)
def integer_range(
self,
tag_filter: Optional[Union[str, List[str], re.Pattern]] = None,
operation_filter: Optional[Union[str, List[str], re.Pattern]] = None,
) -> Optional[Tuple[int, int]]:
"""
Get integer range of the graph.
Only nodes after filtering will be used to calculate the result.
Args:
tag_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
filter for tags
operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None):
filter for operations
Returns:
Optional[Tuple[int, int]]:
minimum and maximum integer value observed during inputset evaluation
if there are no integer nodes matching the query, result is None
"""
result: Optional[Tuple[int, int]] = None
if not self.is_direct:
filtered_bounds = (
node.bounds
for node in self.query_nodes(tag_filter, operation_filter)
if isinstance(node.output.dtype, Integer) and node.bounds is not None
)
for min_bound, max_bound in filtered_bounds:
assert isinstance(min_bound, np.integer) and isinstance(max_bound, np.integer)
if result is None:
result = (int(min_bound), int(max_bound))
else:
old_min_bound, old_max_bound = result # pylint: disable=unpacking-non-sequence
result = (
min(old_min_bound, int(min_bound)),
max(old_max_bound, int(max_bound)),
)
return result

View File

@@ -0,0 +1,434 @@
"""
Declaration of `Node` class.
"""
import os
import time
import traceback
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from ..internal.utils import assert_that
from ..values import Value
from .evaluator import ConstantEvaluator, GenericEvaluator, GenericTupleEvaluator, InputEvaluator
from .operation import Operation
from .utils import KWARGS_IGNORED_IN_FORMATTING, format_constant, format_indexing_element
class Node:
"""
Node class, to represent computation in a computation graph.
"""
inputs: List[Value]
output: Value
operation: Operation
evaluator: Callable
bounds: Optional[Tuple[Union[int, float], Union[int, float]]]
properties: Dict[str, Any]
location: str
tag: str
created_at: float
@staticmethod
def constant(constant: Any) -> "Node":
"""
Create an Operation.Constant node.
Args:
constant (Any):
constant to represent
Returns:
Node:
node representing constant
Raises:
ValueError:
if the constant is not representable
"""
try:
value = Value.of(constant)
except Exception as error:
message = f"Constant {repr(constant)} is not supported"
raise ValueError(message) from error
properties = {"constant": np.array(constant)}
return Node([], value, Operation.Constant, ConstantEvaluator(properties), properties)
@staticmethod
def generic(
name: str,
inputs: List[Value],
output: Value,
operation: Callable,
args: Optional[Tuple[Any, ...]] = None,
kwargs: Optional[Dict[str, Any]] = None,
attributes: Optional[Dict[str, Any]] = None,
):
"""
Create an Operation.Generic node.
Args:
name (str):
name of the operation
inputs (List[Value]):
inputs to the operation
output (Value):
output of the operation
operation (Callable):
operation itself
args (Optional[Tuple[Any, ...]]):
args to pass to operation during evaluation
kwargs (Optional[Dict[str, Any]]):
kwargs to pass to operation during evaluation
attributes (Optional[Dict[str, Any]]):
attributes of the operation
Returns:
Node:
node representing operation
"""
properties = {
"name": name,
"args": args if args is not None else (),
"kwargs": kwargs if kwargs is not None else {},
"attributes": attributes if attributes is not None else {},
}
return Node(
inputs,
output,
Operation.Generic,
(
GenericTupleEvaluator(operation, properties) # type: ignore
if name in ["concatenate"]
else GenericEvaluator(operation, properties) # type: ignore
),
properties,
)
@staticmethod
def input(name: str, value: Value) -> "Node":
"""
Create an Operation.Input node.
Args:
name (Any):
name of the input
value (Any):
value of the input
Returns:
Node:
node representing input
"""
return Node([value], value, Operation.Input, InputEvaluator(), {"name": name})
def __init__(
self,
inputs: List[Value],
output: Value,
operation: Operation,
evaluator: Callable,
properties: Optional[Dict[str, Any]] = None,
):
self.inputs = inputs
self.output = output
self.operation = operation
self.evaluator = evaluator # type: ignore
self.bounds = None
self.properties = properties if properties is not None else {}
# pylint: disable=cyclic-import,import-outside-toplevel
import concrete.numpy as cnp
cnp_directory = os.path.dirname(cnp.__file__)
import concrete.onnx as coonx
coonx_directory = os.path.dirname(coonx.__file__)
# pylint: enable=cyclic-import,import-outside-toplevel
for frame in reversed(traceback.extract_stack()):
if frame.filename == "<__array_function__ internals>":
continue
if frame.filename.startswith(cnp_directory):
continue
if frame.filename.startswith(coonx_directory):
continue
self.location = f"{frame.filename}:{frame.lineno}"
break
# pylint: disable=cyclic-import,import-outside-toplevel
from ..extensions.tag import tag_context
self.tag = ".".join(tag_context.stack)
# pylint: enable=cyclic-import,import-outside-toplevel
self.created_at = time.time()
def __call__(self, *args: List[Any]) -> Union[np.bool_, np.integer, np.floating, np.ndarray]:
def generic_error_message() -> str:
result = f"Evaluation of {self.operation.value} '{self.label()}' node"
if len(args) != 0:
result += f" using {', '.join(repr(arg) for arg in args)}"
return result
if len(args) != len(self.inputs):
message = f"{generic_error_message()} failed because of invalid number of arguments"
raise ValueError(message)
for arg, input_ in zip(args, self.inputs):
try:
arg_value = Value.of(arg)
except Exception as error:
arg_str = "the argument" if len(args) == 1 else f"argument {repr(arg)}"
message = f"{generic_error_message()} failed because {arg_str} is not valid"
raise ValueError(message) from error
if input_.shape != arg_value.shape:
arg_str = "the argument" if len(args) == 1 else f"argument {repr(arg)}"
message = (
f"{generic_error_message()} failed because "
f"{arg_str} does not have the expected "
f"shape of {input_.shape}"
)
raise ValueError(message)
result = self.evaluator(*args)
if isinstance(result, int) and -(2**63) < result < (2**63) - 1:
result = np.int64(result)
if isinstance(result, float):
result = np.float64(result)
if isinstance(result, list):
try:
np_result = np.array(result)
result = np_result
except Exception: # pylint: disable=broad-except
# here we try our best to convert the list to np.ndarray
# if it fails we raise the exception below
pass
if not isinstance(result, (np.bool_, np.integer, np.floating, np.ndarray)):
message = (
f"{generic_error_message()} resulted in {repr(result)} "
f"of type {result.__class__.__name__} "
f"which is not acceptable either because of the type or because of overflow"
)
raise ValueError(message)
if isinstance(result, np.ndarray):
dtype = result.dtype
if (
not np.issubdtype(dtype, np.integer)
and not np.issubdtype(dtype, np.floating)
and not np.issubdtype(dtype, np.bool_)
):
message = (
f"{generic_error_message()} resulted in {repr(result)} "
f"of type np.ndarray and of underlying type '{type(dtype).__name__}' "
f"which is not acceptable because of the underlying type"
)
raise ValueError(message)
if result.shape != self.output.shape:
message = (
f"{generic_error_message()} resulted in {repr(result)} "
f"which does not have the expected "
f"shape of {self.output.shape}"
)
raise ValueError(message)
return result
def format(self, predecessors: List[str], maximum_constant_length: int = 45) -> str:
"""
Get the textual representation of the `Node` (dependent to preds).
Args:
predecessors (List[str]):
predecessor names to this node
maximum_constant_length (int, default = 45):
maximum length of formatted constants
Returns:
str:
textual representation of the `Node` (dependent to preds)
"""
if self.operation == Operation.Constant:
return format_constant(self(), maximum_constant_length)
if self.operation == Operation.Input:
return self.properties["name"]
assert_that(self.operation == Operation.Generic)
name = self.properties["name"]
if name == "index.static":
index = self.properties["kwargs"]["index"]
elements = [format_indexing_element(element) for element in index]
return f"{predecessors[0]}[{', '.join(elements)}]"
if name == "assign.static":
index = self.properties["kwargs"]["index"]
elements = [format_indexing_element(element) for element in index]
return f"({predecessors[0]}[{', '.join(elements)}] = {predecessors[1]})"
if name == "concatenate":
args = [f"({', '.join(predecessors)})"]
else:
args = deepcopy(predecessors)
if name == "array":
values = str(np.array(predecessors).reshape(self.output.shape).tolist()).replace(
"'", ""
)
return f"array({format_constant(values, maximum_constant_length)})"
args.extend(
format_constant(value, maximum_constant_length) for value in self.properties["args"]
)
args.extend(
f"{name}={format_constant(value, maximum_constant_length)}"
for name, value in self.properties["kwargs"].items()
if name not in KWARGS_IGNORED_IN_FORMATTING
)
return f"{name}({', '.join(args)})"
def label(self) -> str:
"""
Get the textual representation of the `Node` (independent of preds).
Returns:
str:
textual representation of the `Node` (independent of preds).
"""
if self.operation == Operation.Constant:
return format_constant(self(), maximum_length=45, keep_newlines=True)
if self.operation == Operation.Input:
return self.properties["name"]
assert_that(self.operation == Operation.Generic)
name = self.properties["name"]
if name == "index.static":
name = self.format([""])
if name == "assign.static":
name = self.format(["", ""])[1:-1]
return name
@property
def converted_to_table_lookup(self) -> bool:
"""
Get whether the node is converted to a table lookup during MLIR conversion.
Returns:
bool:
True if the node is converted to a table lookup, False otherwise
"""
if (
all(value.is_encrypted for value in self.inputs)
and self.operation == Operation.Generic
and self.properties["name"]
in [
"bitwise_and",
"bitwise_or",
"bitwise_xor",
"equal",
"greater",
"greater_equal",
"left_shift",
"less",
"less_equal",
"not_equal",
"right_shift",
]
):
return False
return self.operation == Operation.Generic and self.properties["name"] not in [
"add",
"array",
"assign.static",
"broadcast_to",
"concatenate",
"conv1d",
"conv2d",
"conv3d",
"dot",
"expand_dims",
"index.static",
"matmul",
"maxpool",
"multiply",
"negative",
"ones",
"reshape",
"squeeze",
"subtract",
"sum",
"transpose",
"zeros",
]
@property
def is_fusable(self) -> bool:
"""
Get whether the node can be fused into a table lookup.
Returns:
bool:
True if the node can be fused into a table lookup, False otherwise
"""
if self.converted_to_table_lookup:
return True
return self.operation != Operation.Generic or self.properties["name"] in [
"add",
"multiply",
"negative",
"ones",
"subtract",
"zeros",
]
def __lt__(self, other) -> bool:
return self.created_at < other.created_at
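# Illustrative sketch (editor's addition):
#
#     node = Node.constant(np.array([1, 2, 3]))
#     node()         # evaluates the node -> array([1, 2, 3])
#     node.label()   # "[1 2 3]"
#
# Constant nodes take no arguments when called; generic nodes are called with the
# results of their ordered predecessors, as done in `Graph.evaluate`.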

View File

@@ -0,0 +1,29 @@
"""
Declaration of `Operation` enum.
"""
from enum import Enum
class Operation(Enum):
"""
Operation enum, to distinguish nodes within a computation graph.
"""
# pylint: disable=invalid-name
Constant = "constant"
Generic = "generic"
Input = "input"
# pylint: enable=invalid-name
# https://graphviz.org/doc/info/colors.html#svg
OPERATION_COLOR_MAPPING = {
Operation.Constant: "grey",
Operation.Generic: "black",
Operation.Input: "crimson",
"output": "gold",
}

View File

@@ -0,0 +1,114 @@
"""
Declaration of various functions and constants related to representation of computation.
"""
from typing import Any, Dict, Hashable, Set, Union
import numpy as np
from ..internal.utils import assert_that
KWARGS_IGNORED_IN_FORMATTING: Set[str] = {
"subgraph",
"terminal_node",
}
SPECIAL_OBJECT_MAPPING: Dict[Any, str] = {
np.float16: "float16",
np.float32: "float32",
np.float64: "float64",
np.int8: "int8",
np.int16: "int16",
np.int32: "int32",
np.int64: "int64",
np.uint8: "uint8",
np.uint16: "uint16",
np.uint32: "uint32",
np.uint64: "uint64",
np.byte: "byte",
np.short: "short",
np.intc: "intc",
np.int_: "int_",
np.longlong: "longlong",
np.ubyte: "ubyte",
np.ushort: "ushort",
np.uintc: "uintc",
np.uint: "uint",
np.ulonglong: "ulonglong",
}
def format_constant(constant: Any, maximum_length: int = 45, keep_newlines: bool = False) -> str:
"""
Get the textual representation of a constant.
Args:
constant (Any):
constant to format
maximum_length (int, default = 45):
maximum length of the resulting string
keep_newlines (bool, default = False):
whether to keep newlines or not
Returns:
str:
textual representation of `constant`
"""
if isinstance(constant, Hashable) and constant in SPECIAL_OBJECT_MAPPING:
return SPECIAL_OBJECT_MAPPING[constant]
# maximum_length should not be smaller than 7 characters because
# the constant will be formatted to `x ... y`
# where x and y are part of the constant, and they are at least 1 character
assert_that(maximum_length >= 7)
result = str(constant)
if not keep_newlines:
result = result.replace("\n", "")
if len(result) > maximum_length:
from_start = (maximum_length - 5) // 2
from_end = (maximum_length - 5) - from_start
if keep_newlines and "\n" in result:
result = f"{result[:from_start]}\n...\n{result[-from_end:]}"
else:
result = f"{result[:from_start]} ... {result[-from_end:]}"
return result
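# Worked example (editor's addition): with maximum_length=11, the 50-character string
# "0123456789" * 5 is shortened to "012 ... 789": (11 - 5) // 2 = 3 characters are kept
# from the start and the remaining 3 from the end, around a " ... " marker.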
def format_indexing_element(indexing_element: Union[int, np.integer, slice]):
"""
Format an indexing element.
This is required mainly for slices. The reason is that the string representation of slices
is very long and verbose. To give an example, `x[:, 2:]` will have the following index
`[slice(None, None, None), slice(2, None, None)]` if printed naively. With this helper,
it will be formatted as `[:, 2:]`.
Args:
indexing_element (Union[int, np.integer, slice]):
indexing element to format
Returns:
str:
textual representation of `indexing_element`
"""
result = ""
if isinstance(indexing_element, slice):
if indexing_element.start is not None:
result += str(indexing_element.start)
result += ":"
if indexing_element.stop is not None:
result += str(indexing_element.stop)
if indexing_element.step is not None:
result += ":"
result += str(indexing_element.step)
else:
result += str(indexing_element)
return result.replace("\n", " ")

View File

@@ -0,0 +1,5 @@
"""
Provide `function` to `computation graph` functionality.
"""
from .tracer import ScalarAnnotation, TensorAnnotation, Tracer

View File

@@ -0,0 +1,867 @@
"""
Declaration of `Tracer` class.
"""
import inspect
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, cast
import networkx as nx
import numpy as np
from numpy.typing import DTypeLike
from ..dtypes import BaseDataType, Float, Integer
from ..internal.utils import assert_that
from ..representation import Graph, Node, Operation
from ..representation.utils import format_indexing_element
from ..values import Value
class Tracer:
"""
Tracer class, to create computation graphs from python functions.
"""
computation: Node
input_tracers: List["Tracer"]
output: Value
# property to keep track of assignments
last_version: Optional["Tracer"] = None
# variables to control the behavior of certain functions
_is_tracing: bool = False
_is_direct: bool = False
@staticmethod
def trace(function: Callable, parameters: Dict[str, Value], is_direct: bool = False) -> Graph:
"""
Trace `function` and create the `Graph` that represents it.
Args:
function (Callable):
function to trace
parameters (Dict[str, Value]):
parameters of function to trace
e.g. parameter x is an EncryptedScalar holding a 7-bit UnsignedInteger
is_direct (bool, default = False):
whether the tracing is done on actual parameters or placeholders
Returns:
Graph:
computation graph corresponding to `function`
"""
# pylint: disable=too-many-statements
signature = inspect.signature(function)
missing_args = list(signature.parameters)
for arg in parameters.keys():
missing_args.remove(arg)
assert_that(len(missing_args) == 0)
arguments = {}
input_indices = {}
for index, param in enumerate(signature.parameters.keys()):
node = Node.input(param, parameters[param])
arguments[param] = Tracer(node, [])
input_indices[node] = index
Tracer._is_direct = is_direct
Tracer._is_tracing = True
output_tracers: Any = function(**arguments)
Tracer._is_tracing = False
if not isinstance(output_tracers, tuple):
output_tracers = (output_tracers,)
output_tracer_list = list(output_tracers)
for i, output_tracer in enumerate(output_tracer_list):
if isinstance(output_tracer, Tracer) and output_tracer.last_version is not None:
output_tracer_list[i] = output_tracer.last_version
output_tracers = tuple(output_tracer_list)
sanitized_tracers = []
for tracer in output_tracers:
if isinstance(tracer, Tracer):
sanitized_tracers.append(tracer)
continue
try:
sanitized_tracers.append(Tracer.sanitize(tracer))
except Exception as error:
message = (
f"Function '{function.__name__}' "
f"returned '{tracer}', "
f"which is not supported"
)
raise ValueError(message) from error
output_tracers = tuple(sanitized_tracers)
def create_graph_from_output_tracers(
arguments: Dict[str, Tracer],
output_tracers: Tuple[Tracer, ...],
) -> nx.MultiDiGraph:
graph = nx.MultiDiGraph()
visited_tracers: Set[Tracer] = set()
current_tracers = {tracer: None for tracer in output_tracers}
while current_tracers:
next_tracers: Dict[Tracer, None] = {}
for tracer in current_tracers:
if tracer not in visited_tracers:
current_node = tracer.computation
graph.add_node(current_node)
for input_idx, input_tracer in enumerate(tracer.input_tracers):
pred_node = input_tracer.computation
graph.add_node(pred_node)
graph.add_edge(
pred_node,
current_node,
input_idx=input_idx,
)
if input_tracer not in visited_tracers:
next_tracers.update({input_tracer: None})
visited_tracers.add(tracer)
current_tracers = next_tracers
assert_that(nx.algorithms.dag.is_directed_acyclic_graph(graph))
unique_edges = {
(pred, succ, tuple((k, v) for k, v in edge_data.items()))
for pred, succ, edge_data in graph.edges(data=True)
}
assert_that(len(unique_edges) == len(graph.edges))
for tracer in arguments.values():
graph.add_node(tracer.computation)
return graph
graph = create_graph_from_output_tracers(arguments, output_tracers)
input_nodes = {
input_indices[node]: node
for node in graph.nodes()
if len(graph.pred[node]) == 0 and node.operation == Operation.Input
}
output_nodes = {
output_idx: tracer.computation for output_idx, tracer in enumerate(output_tracers)
}
return Graph(graph, input_nodes, output_nodes, is_direct)
# pylint: enable=too-many-statements
def __init__(self, computation: Node, input_tracers: List["Tracer"]):
self.computation = computation
self.input_tracers = input_tracers
self.output = computation.output
for i, tracer in enumerate(self.input_tracers):
self.input_tracers[i] = tracer if tracer.last_version is None else tracer.last_version
def __hash__(self) -> int:
return id(self)
def __bool__(self) -> bool:
# pylint: disable=invalid-bool-returned
message = "Branching within circuits is not possible"
raise RuntimeError(message)
@staticmethod
def sanitize(value: Any) -> Any:
"""
Try to create a tracer from a value.
Args:
value (Any):
value to use
Returns:
Any:
resulting tracer
"""
if isinstance(value, tuple):
return tuple(Tracer.sanitize(item) for item in value)
if isinstance(value, Tracer):
return value
computation = Node.constant(value)
return Tracer(computation, [])
SUPPORTED_NUMPY_OPERATORS: Set[Any] = {
np.abs,
np.absolute,
np.add,
np.arccos,
np.arccosh,
np.arcsin,
np.arcsinh,
np.arctan,
np.arctan2,
np.arctanh,
np.around,
np.bitwise_and,
np.bitwise_or,
np.bitwise_xor,
np.broadcast_to,
np.cbrt,
np.ceil,
np.clip,
np.concatenate,
np.copysign,
np.cos,
np.cosh,
np.deg2rad,
np.degrees,
np.divide,
np.dot,
np.equal,
np.exp,
np.exp2,
np.expand_dims,
np.expm1,
np.fabs,
np.float_power,
np.floor,
np.floor_divide,
np.fmax,
np.fmin,
np.fmod,
np.gcd,
np.greater,
np.greater_equal,
np.heaviside,
np.hypot,
np.invert,
np.isfinite,
np.isinf,
np.isnan,
np.lcm,
np.ldexp,
np.left_shift,
np.less,
np.less_equal,
np.log,
np.log10,
np.log1p,
np.log2,
np.logaddexp,
np.logaddexp2,
np.logical_and,
np.logical_not,
np.logical_or,
np.logical_xor,
np.matmul,
np.maximum,
np.minimum,
np.mod,
np.multiply,
np.negative,
np.nextafter,
np.not_equal,
np.ones_like,
np.positive,
np.power,
np.rad2deg,
np.radians,
np.reciprocal,
np.remainder,
np.reshape,
np.right_shift,
np.rint,
np.round_,
np.sign,
np.signbit,
np.sin,
np.sinh,
np.spacing,
np.sqrt,
np.square,
np.squeeze,
np.subtract,
np.sum,
np.tan,
np.tanh,
np.transpose,
np.true_divide,
np.trunc,
np.where,
np.zeros_like,
}
SUPPORTED_KWARGS: Dict[Any, Set[str]] = {
np.around: {
"decimals",
},
np.broadcast_to: {
"shape",
},
np.concatenate: {
"axis",
},
np.expand_dims: {
"axis",
},
np.ones_like: {
"dtype",
},
np.reshape: {
"newshape",
},
np.round_: {
"decimals",
},
np.squeeze: {
"axis",
},
np.sum: {
"axis",
"keepdims",
},
np.transpose: {
"axes",
},
np.zeros_like: {
"dtype",
},
}
@staticmethod
def _trace_numpy_operation(operation: Callable, *args, **kwargs) -> "Tracer":
"""
Trace an arbitrary numpy operation into an Operation.Generic node.
Args:
operation (Callable):
operation to trace
args (List[Any]):
args of the arbitrary computation
kwargs (Dict[str, Any]):
kwargs of the arbitrary computation
Returns:
Tracer:
tracer representing the arbitrary computation
"""
if operation not in Tracer.SUPPORTED_NUMPY_OPERATORS:
message = f"Function 'np.{operation.__name__}' is not supported"
raise RuntimeError(message)
supported_kwargs = Tracer.SUPPORTED_KWARGS.get(operation, set())
for kwarg in kwargs:
if kwarg not in supported_kwargs:
message = (
f"Function 'np.{operation.__name__}' is not supported with kwarg '{kwarg}'"
)
raise RuntimeError(message)
if operation == np.ones_like: # pylint: disable=comparison-with-callable
dtype = kwargs.get("dtype", np.int64)
return Tracer(Node.constant(np.ones(args[0].shape, dtype=dtype)), [])
if operation == np.zeros_like: # pylint: disable=comparison-with-callable
dtype = kwargs.get("dtype", np.int64)
return Tracer(Node.constant(np.zeros(args[0].shape, dtype=dtype)), [])
def sampler(arg: Any) -> Any:
if isinstance(arg, tuple):
return tuple(sampler(item) for item in arg)
output = arg.output
assert_that(isinstance(output.dtype, (Float, Integer)))
dtype: Any = np.int64
if isinstance(output.dtype, Float):
assert_that(output.dtype.bit_width in [16, 32, 64])
dtype = {64: np.float64, 32: np.float32, 16: np.float16}[output.dtype.bit_width]
if output.shape == ():
return dtype(1)
return np.ones(output.shape, dtype=dtype)
sample = [sampler(arg) for arg in args]
evaluation = operation(*sample, **kwargs)
def extract_tracers(arg: Any, tracers: List[Tracer]):
if isinstance(arg, tuple):
for item in arg:
extract_tracers(item, tracers)
if isinstance(arg, Tracer):
tracers.append(arg)
tracers: List[Tracer] = []
for arg in args:
extract_tracers(arg, tracers)
output_value = Value.of(evaluation)
output_value.is_encrypted = any(tracer.output.is_encrypted for tracer in tracers)
if Tracer._is_direct and isinstance(output_value.dtype, Integer):
assert all(isinstance(tracer.output.dtype, Integer) for tracer in tracers)
dtypes = cast(List[Integer], [tracer.output.dtype for tracer in tracers])
output_value.dtype.bit_width = max(dtype.bit_width for dtype in dtypes)
output_value.dtype.is_signed = any(dtype.is_signed for dtype in dtypes)
computation = Node.generic(
operation.__name__,
[tracer.output for tracer in tracers],
output_value,
operation,
kwargs=kwargs,
)
return Tracer(computation, tracers)
def __array_ufunc__(self, ufunc, method, *args, **kwargs):
"""
Numpy ufunc hook.
(https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch)
"""
if method == "__call__":
sanitized_args = [self.sanitize(arg) for arg in args]
return Tracer._trace_numpy_operation(ufunc, *sanitized_args, **kwargs)
message = "Only __call__ hook is supported for numpy ufuncs"
raise RuntimeError(message)
def __array_function__(self, func, _types, args, kwargs):
"""
Numpy function hook.
(https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch)
"""
if func is np.broadcast_to:
sanitized_args = [self.sanitize(args[0])]
if len(args) > 1:
kwargs["shape"] = args[1]
elif func is np.reshape:
sanitized_args = [self.sanitize(args[0])]
if len(args) > 1:
kwargs["newshape"] = args[1]
elif func is np.transpose:
sanitized_args = [self.sanitize(args[0])]
if len(args) > 1:
kwargs["axes"] = args[1]
else:
sanitized_args = [self.sanitize(arg) for arg in args]
return Tracer._trace_numpy_operation(func, *sanitized_args, **kwargs)
def __add__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.add, self, self.sanitize(other))
def __radd__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.add, self.sanitize(other), self)
def __sub__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.subtract, self, self.sanitize(other))
def __rsub__(self, other) -> "Tracer":
return Tracer._trace_numpy_operation(np.subtract, self.sanitize(other), self)
def __mul__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.multiply, self, self.sanitize(other))
def __rmul__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.multiply, self.sanitize(other), self)
def __truediv__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.true_divide, self, self.sanitize(other))
def __rtruediv__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.true_divide, self.sanitize(other), self)
def __floordiv__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.floor_divide, self, self.sanitize(other))
def __rfloordiv__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.floor_divide, self.sanitize(other), self)
def __pow__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.power, self, self.sanitize(other))
def __rpow__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.power, self.sanitize(other), self)
def __mod__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.mod, self, self.sanitize(other))
def __rmod__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.mod, self.sanitize(other), self)
def __matmul__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.matmul, self, self.sanitize(other))
def __rmatmul__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.matmul, self.sanitize(other), self)
def __neg__(self) -> "Tracer":
return Tracer._trace_numpy_operation(np.negative, self)
def __pos__(self) -> "Tracer":
return Tracer._trace_numpy_operation(np.positive, self)
def __abs__(self):
return Tracer._trace_numpy_operation(np.absolute, self)
def __round__(self, ndigits=None):
if ndigits is None:
result = Tracer._trace_numpy_operation(np.around, self)
if self._is_direct:
message = (
"'round(x)' cannot be used in direct definition (you may use np.around instead)"
)
raise RuntimeError(message)
return result.astype(np.int64)
return Tracer._trace_numpy_operation(np.around, self, decimals=ndigits)
def __invert__(self):
return Tracer._trace_numpy_operation(np.invert, self)
def __and__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.bitwise_and, self, self.sanitize(other))
def __rand__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.bitwise_and, self.sanitize(other), self)
def __or__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.bitwise_or, self, self.sanitize(other))
def __ror__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.bitwise_or, self.sanitize(other), self)
def __xor__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.bitwise_xor, self, self.sanitize(other))
def __rxor__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.bitwise_xor, self.sanitize(other), self)
def __lshift__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.left_shift, self, self.sanitize(other))
def __rlshift__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.left_shift, self.sanitize(other), self)
def __rshift__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.right_shift, self, self.sanitize(other))
def __rrshift__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.right_shift, self.sanitize(other), self)
def __gt__(self, other: Any) -> "Tracer": # type: ignore
return Tracer._trace_numpy_operation(np.greater, self, self.sanitize(other))
def __ge__(self, other: Any) -> "Tracer": # type: ignore
return Tracer._trace_numpy_operation(np.greater_equal, self, self.sanitize(other))
def __lt__(self, other: Any) -> "Tracer": # type: ignore
return Tracer._trace_numpy_operation(np.less, self, self.sanitize(other))
def __le__(self, other: Any) -> "Tracer": # type: ignore
return Tracer._trace_numpy_operation(np.less_equal, self, self.sanitize(other))
def __eq__(self, other: Any) -> Union[bool, "Tracer"]: # type: ignore
return (
self is other
if not self._is_tracing
else Tracer._trace_numpy_operation(np.equal, self, self.sanitize(other))
)
def __ne__(self, other: Any) -> Union[bool, "Tracer"]: # type: ignore
return (
self is not other
if not self._is_tracing
else Tracer._trace_numpy_operation(np.not_equal, self, self.sanitize(other))
)
def astype(self, dtype: Union[DTypeLike, Type["ScalarAnnotation"]]) -> "Tracer":
"""
Trace numpy.ndarray.astype(dtype).
"""
if Tracer._is_direct:
output_value = deepcopy(self.output)
if isinstance(dtype, type) and issubclass(dtype, ScalarAnnotation):
output_value.dtype = dtype.dtype
else:
message = (
"`astype` method must be called with a concrete.numpy type "
"for direct circuit definition (e.g., value.astype(cnp.uint4))"
)
raise ValueError(message)
computation = Node.generic(
"astype",
[self.output],
output_value,
lambda x: x, # unused for direct definition
)
return Tracer(computation, [self])
if isinstance(dtype, type) and issubclass(dtype, ScalarAnnotation):
message = (
"`astype` method must be called with a "
"numpy type for compilation (e.g., value.astype(np.int64))"
)
raise ValueError(message)
dtype = np.dtype(dtype).type
if np.issubdtype(dtype, np.integer) and dtype != np.int64:
print(
"Warning: When using `value.astype(newtype)` "
"with an integer newtype, "
"only use `np.int64` as the newtype "
"to avoid unexpected overflows "
"during inputset evaluation"
)
output_value = deepcopy(self.output)
output_value.dtype = Value.of(dtype(0)).dtype # type: ignore
if np.issubdtype(dtype, np.integer):
def evaluator(x, dtype):
if np.any(np.isnan(x)):
message = "A `NaN` value is tried to be converted to integer"
raise ValueError(message)
if np.any(np.isinf(x)):
message = "An `Inf` value is tried to be converted to integer"
raise ValueError(message)
return x.astype(dtype)
else:
def evaluator(x, dtype):
return x.astype(dtype)
computation = Node.generic(
"astype",
[self.output],
output_value,
evaluator,
kwargs={"dtype": dtype},
)
return Tracer(computation, [self])
def clip(self, minimum: Any, maximum: Any) -> "Tracer":
"""
Trace numpy.ndarray.clip().
"""
return Tracer._trace_numpy_operation(
np.clip, self, self.sanitize(minimum), self.sanitize(maximum)
)
def dot(self, other: Any) -> "Tracer":
"""
Trace numpy.ndarray.dot().
"""
return Tracer._trace_numpy_operation(np.dot, self, self.sanitize(other))
def flatten(self) -> "Tracer":
"""
Trace numpy.ndarray.flatten().
"""
return Tracer._trace_numpy_operation(np.reshape, self, newshape=(self.output.size,))
def reshape(self, newshape: Tuple[Any, ...]) -> "Tracer":
"""
Trace numpy.ndarray.reshape(newshape).
"""
return Tracer._trace_numpy_operation(np.reshape, self, newshape=newshape)
def round(self, decimals: int = 0) -> "Tracer":
"""
Trace numpy.ndarray.round().
"""
return Tracer._trace_numpy_operation(np.around, self, decimals=decimals)
def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> "Tracer":
"""
Trace numpy.ndarray.transpose().
"""
if axes is None:
return Tracer._trace_numpy_operation(np.transpose, self)
return Tracer._trace_numpy_operation(np.transpose, self, axes=axes)
def __getitem__(
self,
index: Union[int, np.integer, slice, Tuple[Union[int, np.integer, slice], ...]],
) -> "Tracer":
if not isinstance(index, tuple):
index = (index,)
for indexing_element in index:
valid = isinstance(indexing_element, (int, np.integer, slice))
if isinstance(indexing_element, slice):
if (
not (
indexing_element.start is None
or isinstance(indexing_element.start, (int, np.integer))
)
or not (
indexing_element.stop is None
or isinstance(indexing_element.stop, (int, np.integer))
)
or not (
indexing_element.step is None
or isinstance(indexing_element.step, (int, np.integer))
)
):
valid = False
if not valid:
message = (
f"Indexing with '{format_indexing_element(indexing_element)}' is not supported"
)
raise ValueError(message)
output_value = deepcopy(self.output)
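# determine the resulting shape by applying the index to a dummy array of the same shape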
output_value.shape = np.zeros(output_value.shape)[index].shape
computation = Node.generic(
"index.static",
[self.output],
output_value,
lambda x, index: x[index],
kwargs={"index": index},
)
return Tracer(computation, [self])
def __setitem__(
self,
index: Union[int, np.integer, slice, Tuple[Union[int, np.integer, slice], ...]],
value: Any,
):
if not isinstance(index, tuple):
index = (index,)
for indexing_element in index:
valid = isinstance(indexing_element, (int, np.integer, slice))
if isinstance(indexing_element, slice):
if (
not (
indexing_element.start is None
or isinstance(indexing_element.start, (int, np.integer))
)
or not (
indexing_element.stop is None
or isinstance(indexing_element.stop, (int, np.integer))
)
or not (
indexing_element.step is None
or isinstance(indexing_element.step, (int, np.integer))
)
):
valid = False
if not valid:
message = (
f"Assigning to '{format_indexing_element(indexing_element)}' is not supported"
)
raise ValueError(message)
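# validate the index by applying it to a dummy array of the same shape (raises if incompatible)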
np.zeros(self.output.shape)[index] = 1
def assign(x, value, index):
x[index] = value
return x
sanitized_value = self.sanitize(value)
computation = Node.generic(
"assign.static",
[self.output, sanitized_value.output],
self.output,
assign,
kwargs={"index": index},
)
new_version = Tracer(computation, [self, sanitized_value])
self.last_version = new_version
@property
def shape(self) -> Tuple[int, ...]:
"""
Trace numpy.ndarray.shape.
"""
return self.output.shape
@property
def ndim(self) -> int:
"""
Trace numpy.ndarray.ndim.
"""
return self.output.ndim
@property
def size(self) -> int:
"""
Trace numpy.ndarray.size.
"""
return self.output.size
@property
def T(self) -> "Tracer": # pylint: disable=invalid-name # noqa: N802
"""
Trace numpy.ndarray.T.
"""
return Tracer._trace_numpy_operation(np.transpose, self)
class Annotation(Tracer):
"""
Base annotation for direct definition.
"""
class ScalarAnnotation(Annotation):
"""
Base scalar annotation for direct definition.
"""
dtype: BaseDataType
class TensorAnnotation(Annotation):
"""
Base tensor annotation for direct definition.
"""

File diff suppressed because it is too large

View File

@@ -0,0 +1,7 @@
"""
Define the available values and their semantics.
"""
from .scalar import ClearScalar, EncryptedScalar
from .tensor import ClearTensor, EncryptedTensor
from .value import Value

View File

@@ -0,0 +1,44 @@
"""
Declaration of `ClearScalar` and `EncryptedScalar` wrappers.
"""
from ..dtypes import BaseDataType
from .value import Value
def clear_scalar_builder(dtype: BaseDataType) -> Value:
"""
Build a clear scalar value.
Args:
dtype (BaseDataType):
dtype of the value
Returns:
Value:
clear scalar value with given dtype
"""
return Value(dtype=dtype, shape=(), is_encrypted=False)
ClearScalar = clear_scalar_builder
def encrypted_scalar_builder(dtype: BaseDataType) -> Value:
"""
Build an encrypted scalar value.
Args:
dtype (BaseDataType):
dtype of the value
Returns:
Value:
encrypted scalar value with given dtype
"""
return Value(dtype=dtype, shape=(), is_encrypted=True)
EncryptedScalar = encrypted_scalar_builder
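# A small usage sketch (runnable as a script), assuming the package is importable as
# `concrete.numpy` and that `UnsignedInteger` lives in its `dtypes` subpackage.
if __name__ == "__main__":
    from concrete.numpy.dtypes import UnsignedInteger

    print(ClearScalar(UnsignedInteger(3)))      # e.g. ClearScalar<uint3>
    print(EncryptedScalar(UnsignedInteger(7)))  # e.g. EncryptedScalar<uint7>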

View File

@@ -0,0 +1,52 @@
"""
Declaration of `ClearTensor` and `EncryptedTensor` wrappers.
"""
from typing import Tuple
from ..dtypes import BaseDataType
from .value import Value
def clear_tensor_builder(dtype: BaseDataType, shape: Tuple[int, ...]) -> Value:
"""
Build a clear tensor value.
Args:
dtype (BaseDataType):
dtype of the value
shape (Tuple[int, ...]):
shape of the value
Returns:
Value:
clear tensor value with given dtype and shape
"""
return Value(dtype=dtype, shape=shape, is_encrypted=False)
ClearTensor = clear_tensor_builder
def encrypted_tensor_builder(dtype: BaseDataType, shape: Tuple[int, ...]) -> Value:
"""
Build an encrypted tensor value.
Args:
dtype (BaseDataType):
dtype of the value
shape (Tuple[int, ...]):
shape of the value
Returns:
Value:
encrypted tensor value with given dtype and shape
"""
return Value(dtype=dtype, shape=shape, is_encrypted=True)
EncryptedTensor = encrypted_tensor_builder
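# A small usage sketch (runnable as a script), assuming the package is importable as
# `concrete.numpy` and that `UnsignedInteger` lives in its `dtypes` subpackage.
if __name__ == "__main__":
    from concrete.numpy.dtypes import UnsignedInteger

    print(ClearTensor(UnsignedInteger(4), shape=(2, 3)))     # e.g. ClearTensor<uint4, shape=(2, 3)>
    print(EncryptedTensor(UnsignedInteger(6), shape=(10,)))  # e.g. EncryptedTensor<uint6, shape=(10,)>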

View File

@@ -0,0 +1,162 @@
"""
Declaration of `Value` class.
"""
from typing import Any, Tuple
import numpy as np
from ..dtypes import BaseDataType, Float, Integer, UnsignedInteger
class Value:
"""
Value class, to combine data type, shape, and encryption status into a single object.
"""
dtype: BaseDataType
shape: Tuple[int, ...]
is_encrypted: bool
@staticmethod
def of(value: Any, is_encrypted: bool = False) -> "Value": # pylint: disable=invalid-name
"""
Get the `Value` that can represent `value`.
Args:
value (Any):
value that needs to be represented
is_encrypted (bool, default = False):
whether the resulting `Value` is encrypted or not
Returns:
Value:
`Value` that can represent `value`
Raises:
ValueError:
if `value` cannot be represented by `Value`
"""
# pylint: disable=too-many-branches,too-many-return-statements
if isinstance(value, (bool, np.bool_)):
return Value(dtype=UnsignedInteger(1), shape=(), is_encrypted=is_encrypted)
if isinstance(value, (int, np.integer)):
return Value(
dtype=Integer.that_can_represent(value),
shape=(),
is_encrypted=is_encrypted,
)
if isinstance(value, (float, np.float64)):
return Value(dtype=Float(64), shape=(), is_encrypted=is_encrypted)
if isinstance(value, np.float32):
return Value(dtype=Float(32), shape=(), is_encrypted=is_encrypted)
if isinstance(value, np.float16):
return Value(dtype=Float(16), shape=(), is_encrypted=is_encrypted)
if isinstance(value, list):
try:
value = np.array(value)
except Exception: # pylint: disable=broad-except
# here we try our best to convert the list to np.ndarray
# if it fails we raise the exception at the end of the function
pass
if isinstance(value, np.ndarray):
if np.issubdtype(value.dtype, np.bool_):
return Value(dtype=UnsignedInteger(1), shape=value.shape, is_encrypted=is_encrypted)
if np.issubdtype(value.dtype, np.integer):
return Value(
dtype=Integer.that_can_represent(value),
shape=value.shape,
is_encrypted=is_encrypted,
)
if np.issubdtype(value.dtype, np.float64):
return Value(dtype=Float(64), shape=value.shape, is_encrypted=is_encrypted)
if np.issubdtype(value.dtype, np.float32):
return Value(dtype=Float(32), shape=value.shape, is_encrypted=is_encrypted)
if np.issubdtype(value.dtype, np.float16):
return Value(dtype=Float(16), shape=value.shape, is_encrypted=is_encrypted)
message = f"Value cannot represent {repr(value)}"
raise ValueError(message)
# pylint: enable=too-many-branches,too-many-return-statements
def __init__(self, dtype: BaseDataType, shape: Tuple[int, ...], is_encrypted: bool):
self.dtype = dtype
self.shape = shape
self.is_encrypted = is_encrypted
def __eq__(self, other: object) -> bool:
return (
isinstance(other, Value)
and self.dtype == other.dtype
and self.shape == other.shape
and self.is_encrypted == other.is_encrypted
)
def __str__(self) -> str:
encrypted_or_clear_str = "Encrypted" if self.is_encrypted else "Clear"
scalar_or_tensor_str = "Scalar" if self.is_scalar else "Tensor"
shape_str = f", shape={self.shape}" if not self.is_scalar else ""
return f"{encrypted_or_clear_str}{scalar_or_tensor_str}<{str(self.dtype)}{shape_str}>"
@property
def is_clear(self) -> bool:
"""
Get whether the value is clear or not.
Returns:
bool:
True if value is not encrypted, False otherwise
"""
return not self.is_encrypted
@property
def is_scalar(self) -> bool:
"""
Get whether the value is scalar or not.
Returns:
bool:
True if shape of the value is (), False otherwise
"""
return self.shape == ()
@property
def ndim(self) -> int:
"""
Get number of dimensions of the value.
Returns:
int:
number of dimensions of the value
"""
return len(self.shape)
@property
def size(self) -> int:
"""
Get number of elements in the value.
Returns:
int:
number of elements in the value
"""
return int(np.prod(self.shape))
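# A short sketch of `Value.of` inference (runnable as a script): the dtype is the smallest
# integer type that can represent the value, or the float precision of the input.
if __name__ == "__main__":
    print(Value.of(4))                                  # e.g. ClearScalar<uint3>
    print(Value.of(-4, is_encrypted=True))              # e.g. EncryptedScalar<int3>
    print(Value.of(np.ones((3, 2), dtype=np.float32)))  # e.g. ClearTensor<float32, shape=(3, 2)>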

View File

@@ -0,0 +1,6 @@
"""
Implement machine learning operations as specified by ONNX.
"""
from .convolution import conv
from .maxpool import maxpool

View File

@@ -0,0 +1,683 @@
"""
Convolution operations' tracing and evaluation.
"""
import math
from typing import Callable, List, Optional, Tuple, Union, cast
import numpy as np
import torch
from ..numpy.internal.utils import assert_that
from ..numpy.representation import Node
from ..numpy.tracing import Tracer
from ..numpy.values import EncryptedTensor
SUPPORTED_AUTO_PAD = {
"NOTSET",
}
# pylint: disable=too-many-branches,too-many-statements
def conv(
x: Union[np.ndarray, Tracer],
weight: Union[np.ndarray, Tracer],
bias: Optional[Union[np.ndarray, Tracer]] = None,
pads: Optional[Union[Tuple[int, ...], List[int]]] = None,
strides: Optional[Union[Tuple[int, ...], List[int]]] = None,
dilations: Optional[Union[Tuple[int, ...], List[int]]] = None,
kernel_shape: Optional[Union[Tuple[int, ...], List[int]]] = None,
group: int = 1,
auto_pad: str = "NOTSET",
) -> Union[np.ndarray, Tracer]:
"""
Trace and evaluate convolution operations.
Refer to https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv for more info.
Args:
x (Union[np.ndarray, Tracer]): input of shape (N, C, D1, ..., DN)
weight (Union[np.ndarray, Tracer]): kernel of shape (F, C / group, K1, ..., KN)
bias (Optional[Union[np.ndarray, Tracer]], optional): bias of shape (F,). Defaults to None.
pads (Optional[Union[Tuple[int, ...], List[int]]], optional):
padding for the beginning and ending along each spatial axis
(D1_begin, D2_begin, ..., D1_end, D2_end, ...).
Will be set to 0 along each spatial axis if not set.
strides (Optional[Union[Tuple[int, ...], List[int]]], optional):
stride along each spatial axis. Will be set to 1 along each spatial axis if not set.
dilations (Optional[Union[Tuple[int, ...], List[int]]], optional):
dilation along each spatial axis. Will be set to 1 along each spatial axis if not set.
kernel_shape (Optional[Union[Tuple[int, ...], List[int]]], optional):
shape of the convolution kernel. Inferred from input weight if not present
group (int, optional):
number of groups input channels and output channels are divided into. Defaults to 1.
auto_pad (str, optional): padding strategy. Defaults to "NOTSET".
Raises:
ValueError: if arguments are not appropriate
TypeError: unexpected types
NotImplementedError: a convolution that we don't support
Returns:
Union[np.ndarray, Tracer]: evaluation result or traced computation
"""
if kernel_shape is not None and (
(weight.ndim - 2) != len(kernel_shape) or not np.all(weight.shape[2:] == kernel_shape)
):
message = f"expected kernel_shape to be {weight.shape[2:]}, but got {kernel_shape}"
raise ValueError(message)
if isinstance(x, np.ndarray):
if not isinstance(weight, np.ndarray):
message = "expected weight to be of same type as x"
raise TypeError(message)
if bias is not None and not isinstance(bias, np.ndarray):
message = "expected bias to be of same type as x"
raise TypeError(message)
elif isinstance(x, Tracer):
if not isinstance(weight, (Tracer, np.ndarray)):
message = "expected weight to be of type Tracer or ndarray"
raise TypeError(message)
if bias is not None and not isinstance(bias, (Tracer, np.ndarray)):
message = "expected bias to be of type Tracer or ndarray"
raise TypeError(message)
if x.ndim <= 2:
message = (
f"expected input x to have at least 3 dimensions (N, C, D1, ...), but got {x.ndim}"
)
raise ValueError(message)
if weight.ndim <= 2:
message = (
f"expected weight to have at least 3 dimensions (F, C / group, K1, ...), but got "
f"{weight.ndim}"
)
raise ValueError(message)
if bias is not None and bias.ndim != 1:
message = f"expected bias to have a single dimension (F,), but got {bias.ndim}"
raise ValueError(message)
if not isinstance(group, int) or group <= 0:
message = f"expected group to be an integer > 0, but got {group}"
raise ValueError(message)
if auto_pad not in SUPPORTED_AUTO_PAD:
message = f"auto_pad should be in {SUPPORTED_AUTO_PAD}, but got {repr(auto_pad)}"
raise ValueError(message)
n_channels = x.shape[1]
if weight.shape[1] != n_channels / group:
message = (
f"expected number of channel in weight to be {n_channels / group} (C / group), but got "
f"{weight.shape[1]}"
)
raise ValueError(message)
if weight.shape[0] % group != 0:
message = (
f"expected number of feature maps ({weight.shape[0]}) to be a multiple of group "
f"({group})"
)
raise ValueError(message)
dims = x.ndim - 2
if dims == 1:
pads = (0, 0) if pads is None else pads
strides = (1,) if strides is None else strides
dilations = (1,) if dilations is None else dilations
return _conv1d(
x,
weight,
bias=bias,
pads=pads,
strides=strides,
dilations=dilations,
group=group,
auto_pad=auto_pad,
)
if dims == 2:
pads = (0, 0, 0, 0) if pads is None else pads
strides = (1, 1) if strides is None else strides
dilations = (1, 1) if dilations is None else dilations
return _conv2d(
x,
weight,
bias=bias,
pads=pads,
strides=strides,
dilations=dilations,
group=group,
auto_pad=auto_pad,
)
if dims == 3:
pads = (0, 0, 0, 0, 0, 0) if pads is None else pads
strides = (1, 1, 1) if strides is None else strides
dilations = (1, 1, 1) if dilations is None else dilations
return _conv3d(
x,
weight,
bias=bias,
pads=pads,
strides=strides,
dilations=dilations,
group=group,
auto_pad=auto_pad,
)
message = "only 1D, 2D, and 3D convolutions are supported"
raise NotImplementedError(message)
# pylint: enable=too-many-branches,too-many-statements
def _conv1d(
x: Union[np.ndarray, Tracer],
weight: Union[np.ndarray, Tracer],
bias: Optional[Union[np.ndarray, Tracer]],
pads: Union[Tuple[int, ...], List[int]],
strides: Union[Tuple[int, ...], List[int]],
dilations: Union[Tuple[int, ...], List[int]],
group: int,
auto_pad: str, # pylint: disable=unused-argument
) -> Union[np.ndarray, Tracer]:
"""
Trace or evaluate 1D convolution.
Args:
x (Union[np.ndarray, Tracer]): input of shape (N, C, D)
weight (Union[np.ndarray, Tracer]): kernel of shape (F, C, D)
bias (Optional[Union[np.ndarray, Tracer]]): bias of shape (F,)
pads (Union[Tuple[int, ...], List[int]]):
padding over dimension D (D_beg, D_end)
strides (Union[Tuple[int, ...], List[int]]): stride over dimension D
dilations (Union[Tuple[int, ...], List[int]]): dilation over dimension D
group (int, optional):
number of groups input channels and output channels are divided into.
auto_pad (str, optional): padding strategy.
Raises:
ValueError: if arguments are not appropriate
Returns:
Union[np.ndarray, Tracer]: evaluation result or traced computation
"""
assert_that(
x.ndim == 3,
f"expected input x to be of shape (N, C, D) when performing 1D convolution, but "
f"got {x.shape}",
)
assert_that(
weight.ndim == 3,
f"expected weight to be of shape (F, C, D) when performing 1D convolution, but "
f"got {weight.shape}",
)
if len(pads) != 2:
message = (
f"pads should be of form "
f"(D_begin_pad, D_end_pad) when performing "
f"1D convolution, but it's {pads}"
)
raise ValueError(message)
if len(strides) != 1:
message = (
f"strides should be of form (D_stride,) when performing 1D "
f"convolution, but it's {strides}"
)
raise ValueError(message)
if len(dilations) != 1:
message = (
f"dilations should be of form (D_dilation,) when performing 1D "
f"convolution, but it's {dilations}"
)
raise ValueError(message)
return _trace_or_eval(x, weight, bias, pads, strides, dilations, group)
def _conv2d(
x: Union[np.ndarray, Tracer],
weight: Union[np.ndarray, Tracer],
bias: Optional[Union[np.ndarray, Tracer]],
pads: Union[Tuple[int, ...], List[int]],
strides: Union[Tuple[int, ...], List[int]],
dilations: Union[Tuple[int, ...], List[int]],
group: int,
auto_pad: str, # pylint: disable=unused-argument
) -> Union[np.ndarray, Tracer]:
"""
Trace or evaluate 2D convolution.
Args:
x (Union[np.ndarray, Tracer]): input of shape (N, C, H, W)
weight (Union[np.ndarray, Tracer]): kernel of shape (F, C, H, W)
bias (Optional[Union[np.ndarray, Tracer]]): bias of shape (F,)
pads (Union[Tuple[int, ...], List[int]]):
padding over each height and width (H_beg, W_beg, H_end, W_end)
strides (Union[Tuple[int, ...], List[int]]): stride over height and width
dilations (Union[Tuple[int, ...], List[int]]): dilation over height and width
group (int, optional):
number of groups input channels and output channels are divided into.
auto_pad (str, optional): padding strategy.
Raises:
ValueError: if arguments are not appropriate
Returns:
Union[np.ndarray, Tracer]: evaluation result or traced computation
"""
assert_that(
x.ndim == 4,
f"expected input x to be of shape (N, C, H, W) when performing 2D convolution, but "
f"got {x.shape}",
)
assert_that(
weight.ndim == 4,
f"expected weight to be of shape (F, C, H, W) when performing 2D convolution, but "
f"got {weight.shape}",
)
if len(pads) != 4:
message = (
f"pads should be of form "
f"(height_begin_pad, width_begin_pad, height_end_pad, width_end_pad) when performing "
f"2D convolution, but it's {pads}"
)
raise ValueError(message)
if len(strides) != 2:
message = (
f"strides should be of form (height_stride, width_stride) when performing 2D "
f"convolution, but it's {strides}"
)
raise ValueError(message)
if len(dilations) != 2:
message = (
f"dilations should be of form (height_dilation, width_dilation) when performing 2D "
f"convolution, but it's {dilations}"
)
raise ValueError(message)
return _trace_or_eval(x, weight, bias, pads, strides, dilations, group)
def _conv3d(
x: Union[np.ndarray, Tracer],
weight: Union[np.ndarray, Tracer],
bias: Optional[Union[np.ndarray, Tracer]],
pads: Union[Tuple[int, ...], List[int]],
strides: Union[Tuple[int, ...], List[int]],
dilations: Union[Tuple[int, ...], List[int]],
group: int,
auto_pad: str, # pylint: disable=unused-argument
) -> Union[np.ndarray, Tracer]:
"""
Trace or evaluate 3D convolution.
Args:
x (Union[np.ndarray, Tracer]): input of shape (N, C, D, H, W)
weight (Union[np.ndarray, Tracer]): kernel of shape (F, C, D, H, W)
bias (Optional[Union[np.ndarray, Tracer]]): bias of shape (F,)
pads (Union[Tuple[int, ...], List[int]]):
padding over each spatial axis (D_beg, H_beg, W_beg, D_end, H_end, W_end)
strides (Union[Tuple[int, ...], List[int]]): stride over each spatial axis
dilations (Union[Tuple[int, ...], List[int]]): dilation over each spatial axis
group (int, optional):
number of groups input channels and output channels are divided into.
auto_pad (str, optional): padding strategy.
Raises:
ValueError: if arguments are not appropriate
Returns:
Union[np.ndarray, Tracer]: evaluation result or traced computation
"""
assert_that(
x.ndim == 5,
f"expected input x to be of shape (N, C, D, H, W) when performing 3D convolution, but "
f"got {x.shape}",
)
assert_that(
weight.ndim == 5,
f"expected weight to be of shape (F, C, D, H, W) when performing 3D convolution, but "
f"got {weight.shape}",
)
if len(pads) != 6:
message = (
f"pads should be of form "
f"(D_begin_pad, height_begin_pad, width_begin_pad, "
f"D_end_pad, height_end_pad, width_end_pad) when performing "
f"3D convolution, but it's {pads}"
)
raise ValueError(message)
if len(strides) != 3:
message = (
f"strides should be of form (D_stride, height_stride, width_stride) when performing "
f"3D convolution, but it's {strides}"
)
raise ValueError(message)
if len(dilations) != 3:
message = (
f"dilations should be of form (D_dilation, height_dilation, width_dilation) when "
f"performing 3D convolution, but it's {dilations}"
)
raise ValueError(message)
return _trace_or_eval(x, weight, bias, pads, strides, dilations, group)
def _trace_or_eval(
x: Union[np.ndarray, Tracer],
weight: Union[np.ndarray, Tracer],
bias: Optional[Union[np.ndarray, Tracer]],
pads: Union[Tuple[int, ...], List[int]],
strides: Union[Tuple[int, ...], List[int]],
dilations: Union[Tuple[int, ...], List[int]],
group: int,
) -> Union[np.ndarray, Tracer]:
"""
Trace or evaluate convolution.
Args:
x (Union[np.ndarray, Tracer]): input of shape (N, C, D1, ..., DN)
weight (Union[np.ndarray, Tracer]): kernel of shape (F, C / group, K1, ..., KN)
bias (Optional[Union[np.ndarray, Tracer]]): bias of shape (F,)
pads (Union[Tuple[int, ...], List[int]]):
padding for the beginning and ending along each spatial axis
(D1_begin, D2_begin, ..., D1_end, D2_end, ...).
strides (Union[Tuple[int, ...], List[int]]): stride along each spatial axis.
dilations (Union[Tuple[int, ...], List[int]]): dilation along each spatial axis.
group (int, optional):
number of groups input channels and output channels are divided into.
Returns:
Union[np.ndarray, Tracer]: evaluation result or traced computation
"""
assert_that(x.ndim in [3, 4, 5], "only support 1D, 2D, and 3D conv")
if x.ndim == 3:
conv_func = "conv1d"
elif x.ndim == 4:
conv_func = "conv2d"
else: # x.ndim == 5
conv_func = "conv3d"
if isinstance(x, Tracer):
return _trace_conv(x, weight, bias, pads, strides, dilations, group, conv_func)
assert isinstance(x, np.ndarray)
assert isinstance(weight, np.ndarray)
dtype = (
np.float64
if np.issubdtype(x.dtype, np.floating) or np.issubdtype(weight.dtype, np.floating)
else np.int64
)
bias = np.zeros(weight.shape[0], dtype=dtype) if bias is None else bias
assert isinstance(bias, np.ndarray)
return _evaluate_conv(x, weight, bias, pads, strides, dilations, group, conv_func)
def _trace_conv(
x: Tracer,
weight: Union[np.ndarray, Tracer],
bias: Optional[Union[np.ndarray, Tracer]],
pads: Union[Tuple[int, ...], List[int]],
strides: Union[Tuple[int, ...], List[int]],
dilations: Union[Tuple[int, ...], List[int]],
group: int,
conv_func: str,
) -> Tracer:
"""
Trace convolution.
Args:
x (Tracer): input of shape (N, C, D1, ..., DN)
weight (Union[np.ndarray, Tracer]): kernel of shape (F, C / group, K1, ..., KN)
bias (Optional[Union[np.ndarray, Tracer]]): bias of shape (F,)
pads (Union[Tuple[int, int, int, int], List[int]]):
padding for the beginning and ending along each spatial axis
(D1_begin, D2_begin, ..., D1_end, D2_end, ...).
strides (Union[Tuple[int, int], List[int]]): stride along each spatial axis.
dilations (Union[Tuple[int, int], List[int]]): dilation along each spatial axis.
group (int, optional):
number of groups input channels and output channels are divided into.
conv_func (str): convolution to apply, should be one of {conv1d,conv2d,conv3d}
Returns:
Tracer:
traced computation
"""
conv_eval_funcs = {
"conv1d": _evaluate_conv1d,
"conv2d": _evaluate_conv2d,
"conv3d": _evaluate_conv3d,
}
eval_func = conv_eval_funcs.get(conv_func, None)
assert_that(
eval_func is not None,
f"expected conv_func to be one of {list(conv_eval_funcs.keys())}, but got {conv_func}",
)
eval_func = cast(Callable, eval_func)
weight = weight if isinstance(weight, Tracer) else Tracer(Node.constant(weight), [])
input_values = [x.output, weight.output]
inputs = [x, weight]
if bias is not None:
bias = bias if isinstance(bias, Tracer) else Tracer(Node.constant(bias), [])
input_values.append(bias.output)
inputs.append(bias)
batch_size = x.output.shape[0]
n_filters = weight.output.shape[0]
n_dim = x.ndim - 2 # remove batch_size and channel dims
total_pads_per_dim = []
for dim in range(n_dim):
total_pads_per_dim.append(pads[dim] + pads[n_dim + dim])
output_shape = [batch_size, n_filters]
for dim in range(n_dim):
input_dim_at_dim = x.output.shape[dim + 2]
weight_dim_at_dim = weight.output.shape[dim + 2]
output_shape.append(
math.floor(
(
input_dim_at_dim
+ total_pads_per_dim[dim]
- dilations[dim] * (weight_dim_at_dim - 1)
- 1
)
/ strides[dim]
)
+ 1
)
output_value = EncryptedTensor(dtype=x.output.dtype, shape=tuple(output_shape))
computation = Node.generic(
conv_func, # "conv1d" or "conv2d" or "conv3d"
input_values,
output_value,
eval_func,
args=() if bias is not None else (np.zeros(n_filters, dtype=np.int64),),
kwargs={"pads": pads, "strides": strides, "dilations": dilations, "group": group},
)
return Tracer(computation, inputs)
def _evaluate_conv1d(
x: np.ndarray,
weight: np.ndarray,
bias: np.ndarray,
pads: Union[Tuple[int, ...], List[int]],
strides: Union[Tuple[int, ...], List[int]],
dilations: Union[Tuple[int, ...], List[int]],
group: int,
) -> np.ndarray:
"""
Evaluate 1D convolution.
Args:
x (np.ndarray): input of shape (N, C, D)
weight (np.ndarray): kernel of shape (F, C / group, D)
bias (np.ndarray): bias of shape (F,)
pads (Union[Tuple[int, ...], List[int]]):
padding over each axis (D_beg, D_end)
strides (Union[Tuple[int, ...], List[int]]): stride over dimension D
dilations (Union[Tuple[int, ...], List[int]]): dilation over dimension D
group (int, optional):
number of groups input channels and output channels are divided into.
Returns:
np.ndarray: result of the convolution
"""
return _evaluate_conv(x, weight, bias, pads, strides, dilations, group, "conv1d")
def _evaluate_conv2d(
x: np.ndarray,
weight: np.ndarray,
bias: np.ndarray,
pads: Union[Tuple[int, ...], List[int]],
strides: Union[Tuple[int, ...], List[int]],
dilations: Union[Tuple[int, ...], List[int]],
group: int,
) -> np.ndarray:
"""
Evaluate 2D convolution.
Args:
x (np.ndarray): input of shape (N, C, H, W)
weight (np.ndarray): kernel of shape (F, C / group, H, W)
bias (np.ndarray): bias of shape (F,)
pads (Union[Tuple[int, ...], List[int]]):
padding over each axis (H_beg, W_beg, H_end, W_end)
strides (Union[Tuple[int, ...], List[int]]): stride over height and width
dilations (Union[Tuple[int, ...], List[int]]): dilation over height and width
group (int, optional):
number of groups input channels and output channels are divided into.
Returns:
np.ndarray: result of the convolution
"""
return _evaluate_conv(x, weight, bias, pads, strides, dilations, group, "conv2d")
def _evaluate_conv3d(
x: np.ndarray,
weight: np.ndarray,
bias: np.ndarray,
pads: Union[Tuple[int, ...], List[int]],
strides: Union[Tuple[int, ...], List[int]],
dilations: Union[Tuple[int, ...], List[int]],
group: int,
) -> np.ndarray:
"""
Evaluate 3D convolution.
Args:
x (np.ndarray): input of shape (N, C, D, H, W)
weight (np.ndarray): kernel of shape (F, C / group, D, H, W)
bias (np.ndarray): bias of shape (F,)
pads (Union[Tuple[int, ...], List[int]]):
padding over each axis (D_beg, H_beg, W_beg, D_end, H_end, W_end)
strides (Union[Tuple[int, ...], List[int]]): stride over D, height, and width
dilations (Union[Tuple[int, ...], List[int]]): dilation over D, height, and width
group (int, optional):
number of groups input channels and output channels are divided into.
Returns:
np.ndarray: result of the convolution
"""
return _evaluate_conv(x, weight, bias, pads, strides, dilations, group, "conv3d")
def _evaluate_conv(
x: np.ndarray,
weight: np.ndarray,
bias: np.ndarray,
pads: Union[Tuple[int, ...], List[int]],
strides: Union[Tuple[int, ...], List[int]],
dilations: Union[Tuple[int, ...], List[int]],
group: int,
conv_func: str,
) -> np.ndarray:
"""
Evaluate a 1D, 2D, or 3D convolution (dispatched via `conv_func`).
Args:
x (np.ndarray): input of shape (N, C, D1, ..., DN)
weight (np.ndarray): kernel of shape (F, C / group, K1, ..., KN)
bias (np.ndarray): bias of shape (F,)
pads (Union[Tuple[int, ...], List[int]]):
padding for the beginning and ending along each spatial axis
(D1_begin, D2_begin, ..., D1_end, D2_end, ...).
strides (Union[Tuple[int, ...], List[int]]): stride along each spatial axis.
dilations (Union[Tuple[int, ...], List[int]]): dilation along each spatial axis.
group (int, optional):
number of groups input channels and output channels are divided into.
conv_func (str): convolution to apply, should be one of {conv1d,conv2d,conv3d}
Returns:
np.ndarray: result of the convolution
"""
# pylint: disable=no-member
conv_funcs = {
"conv1d": torch.conv1d,
"conv2d": torch.conv2d,
"conv3d": torch.conv3d,
}
torch_conv_func = conv_funcs.get(conv_func, None)
assert_that(
torch_conv_func is not None,
f"expected conv_func to be one of {list(conv_funcs.keys())}, but got {conv_func}",
)
torch_conv_func = cast(Callable, torch_conv_func)
n_dim = x.ndim - 2 # remove batch_size and channel dims
torch_padding = []
for dim in range(n_dim):
if pads[dim] != pads[n_dim + dim]:
message = (
f"padding should be the same for the beginning of the dimension and its end, but "
f"got {pads[dim]} in the beginning, and {pads[n_dim + dim]} at the end for "
f"dimension {dim}"
)
raise ValueError(message)
torch_padding.append(pads[dim])
dtype = (
torch.float64
if np.issubdtype(x.dtype, np.floating)
or np.issubdtype(weight.dtype, np.floating)
or np.issubdtype(bias.dtype, np.floating)
else torch.long
)
return torch_conv_func(
torch.tensor(x, dtype=dtype),
torch.tensor(weight, dtype=dtype),
torch.tensor(bias, dtype=dtype),
stride=strides,
padding=torch_padding,
dilation=dilations,
groups=group,
).numpy()
# pylint: enable=no-member
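# A minimal evaluation sketch (runnable as a script): with plain numpy inputs, `conv` takes
# the evaluation path above and returns a numpy array; integer inputs follow the same path
# with torch long tensors.
if __name__ == "__main__":
    image = np.random.rand(1, 1, 5, 5)   # (N, C, H, W)
    kernel = np.random.rand(1, 1, 3, 3)  # (F, C, K_H, K_W)
    result = conv(image, kernel)
    print(result.shape)  # (1, 1, 3, 3) with default pads, strides, and dilations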

View File

@@ -0,0 +1,336 @@
"""
Tracing and evaluation of maxpool function.
"""
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..numpy.internal.utils import assert_that
from ..numpy.representation import Node
from ..numpy.tracing import Tracer
from ..numpy.values import Value
# pylint: disable=too-many-branches,too-many-statements
AVAILABLE_AUTO_PAD = {
"NOTSET",
"SAME_UPPER",
"SAME_LOWER",
"VALID",
}
AVAILABLE_CEIL_MODE = {
0,
1,
}
AVAILABLE_STORAGE_ORDER = {
0,
1,
}
SUPPORTED_AUTO_PAD = {
"NOTSET",
}
SUPPORTED_CEIL_MODE = {
0,
}
SUPPORTED_STORAGE_ORDER = {
0,
}
# pylint: disable=no-member
_EVALUATORS = {
1: torch.max_pool1d,
2: torch.max_pool2d,
3: torch.max_pool3d,
}
# pylint: enable=no-member
def maxpool(
x: Union[np.ndarray, Tracer],
kernel_shape: Union[Tuple[int, ...], List[int]],
strides: Optional[Union[Tuple[int, ...], List[int]]] = None,
auto_pad: str = "NOTSET",
pads: Optional[Union[Tuple[int, ...], List[int]]] = None,
dilations: Optional[Union[Tuple[int, ...], List[int]]] = None,
ceil_mode: int = 0,
storage_order: int = 0,
) -> Union[np.ndarray, Tracer]:
"""
Evaluate or trace MaxPool operation.
Refer to https://github.com/onnx/onnx/blob/main/docs/Operators.md#maxpool for more info.
Args:
x (Union[np.ndarray, Tracer]):
input of shape (N, C, D1, ..., DN)
kernel_shape (Union[Tuple[int, ...], List[int]]):
shape of the kernel
strides (Optional[Union[Tuple[int, ...], List[int]]]):
stride along each spatial axis
set to 1 along each spatial axis if not set
auto_pad (str, default = "NOTSET"):
padding strategy
pads (Optional[Union[Tuple[int, ...], List[int]]]):
padding for the beginning and ending along each spatial axis
(D1_begin, D2_begin, ..., D1_end, D2_end, ...)
set to 0 along each spatial axis if not set
dilations (Optional[Union[Tuple[int, ...], List[int]]]):
dilation along each spatial axis
set to 1 along each spatial axis if not set
ceil_mode (int, default = 0):
ceiling mode
storage_order (int, default = 0):
storage order, 0 for row major, 1 for column major
Raises:
TypeError:
if arguments are inappropriately typed
ValueError:
if arguments are inappropriate
NotImplementedError:
if desired operation is not supported yet
Returns:
Union[np.ndarray, Tracer]:
maxpool over the input or traced computation
"""
def check_value_is_a_tuple_or_list_of_ints_of_size(value_name, value, size) -> Tuple[int, ...]:
if isinstance(value, list):
value = tuple(value)
if not isinstance(value, tuple):
message = (
f"Expected {value_name} to be a tuple or a list but it's {type(value).__name__}"
)
raise TypeError(message)
for element in value:
if not isinstance(element, int):
message = (
f"Expected {value_name} to consist of integers "
f"but it has an element of type {type(element).__name__}"
)
raise TypeError(message)
if len(value) != size:
message = f"Expected {value_name} to have {size} elements but it has {len(value)}"
raise ValueError(message)
return value
# check x
if isinstance(x, list): # pragma: no cover
try:
x = np.array(x)
except Exception: # pylint: disable=broad-except
pass
if isinstance(x, np.ndarray):
if not (
np.issubdtype(x.dtype, np.integer)
or np.issubdtype(x.dtype, np.floating)
or np.issubdtype(x.dtype, np.bool_)
):
message = (
f"Expected input elements to be of type np.integer, np.floating, or np.bool_ "
f"but it's {type(x.dtype).__name__}"
)
raise TypeError(message)
elif not isinstance(x, Tracer):
message = (
f"Expected input to be of type np.ndarray or Tracer "
f"but it's {type(auto_pad).__name__}"
)
raise TypeError(message)
if x.ndim < 3:
message = (
f"Expected input to have at least 3 dimensions (N, C, D1, ...) "
f"but it only has {x.ndim}"
)
raise ValueError(message)
if x.ndim > 5:
message = f"{x.ndim - 2}D maximum pooling is not supported yet"
raise NotImplementedError(message)
# check kernel_shape
kernel_shape = check_value_is_a_tuple_or_list_of_ints_of_size(
"kernel_shape", kernel_shape, x.ndim - 2
)
# check strides
if strides is None:
strides = (1,) * (x.ndim - 2)
strides = check_value_is_a_tuple_or_list_of_ints_of_size("strides", strides, x.ndim - 2)
# check auto_pad
if not isinstance(auto_pad, str):
message = f"Expected auto_pad to be of type str but it's {type(auto_pad).__name__}"
raise TypeError(message)
if auto_pad not in AVAILABLE_AUTO_PAD:
message = (
f"Expected auto_pad to be one of "
f"{', '.join(sorted(AVAILABLE_AUTO_PAD))} "
f"but it's {auto_pad}"
)
raise ValueError(message)
if auto_pad not in SUPPORTED_AUTO_PAD:
message = f"Desired auto_pad of {auto_pad} is not supported yet"
raise NotImplementedError(message)
# check pads
if pads is None:
pads = (0,) * (2 * (x.ndim - 2))
pads = check_value_is_a_tuple_or_list_of_ints_of_size("pads", pads, 2 * (x.ndim - 2))
for i in range(len(pads) // 2):
pad_begin = pads[i]
pad_end = pads[i + len(pads) // 2]
if pad_begin != pad_end:
message = f"Desired pads of {pads} is not supported yet because of uneven padding"
raise NotImplementedError(message)
# check dilations
if dilations is None:
dilations = (1,) * (x.ndim - 2)
dilations = check_value_is_a_tuple_or_list_of_ints_of_size("dilations", dilations, x.ndim - 2)
# check ceil_mode
if not isinstance(ceil_mode, int):
message = f"Expected ceil_mode to be of type int but it's {type(ceil_mode).__name__}"
raise TypeError(message)
if ceil_mode not in AVAILABLE_CEIL_MODE:
message = (
f"Expected ceil_mode to be one of "
f"{', '.join(sorted(str(x) for x in AVAILABLE_CEIL_MODE))} "
f"but it's {ceil_mode}"
)
raise ValueError(message)
if ceil_mode not in SUPPORTED_CEIL_MODE:
message = f"Desired ceil_mode of {ceil_mode} is not supported yet"
raise NotImplementedError(message)
# check storage_order
if not isinstance(storage_order, int):
message = (
f"Expected storage_order to be of type int but it's {type(storage_order).__name__}"
)
raise TypeError(message)
if storage_order not in AVAILABLE_STORAGE_ORDER:
message = (
f"Expected storage_order to be one of "
f"{', '.join(sorted(str(x) for x in AVAILABLE_STORAGE_ORDER))} "
f"but it's {storage_order}"
)
raise ValueError(message)
if storage_order not in SUPPORTED_STORAGE_ORDER:
message = f"Desired storage_order of {storage_order} is not supported yet"
raise NotImplementedError(message)
# trace or evaluate
return _trace_or_evaluate(x, kernel_shape, strides, pads, dilations, ceil_mode == 1)
def _trace_or_evaluate(
x: Union[np.ndarray, Tracer],
kernel_shape: Tuple[int, ...],
strides: Tuple[int, ...],
pads: Tuple[int, ...],
dilations: Tuple[int, ...],
ceil_mode: bool,
):
if not isinstance(x, Tracer):
return _evaluate(x, kernel_shape, strides, pads, dilations, ceil_mode == 1)
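# evaluate on a dummy (all-zero) input to determine the shape of the pooling result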
result = _evaluate(np.zeros(x.shape), kernel_shape, strides, pads, dilations, ceil_mode == 1)
resulting_value = Value.of(result)
resulting_value.is_encrypted = x.output.is_encrypted
resulting_value.dtype = x.output.dtype
computation = Node.generic(
"maxpool",
[x.output],
resulting_value,
_evaluate,
kwargs={
"kernel_shape": kernel_shape,
"strides": strides,
"pads": pads,
"dilations": dilations,
"ceil_mode": ceil_mode,
},
)
return Tracer(computation, [x])
def _evaluate(
x: np.ndarray,
kernel_shape: Tuple[int, ...],
strides: Tuple[int, ...],
pads: Tuple[int, ...],
dilations: Tuple[int, ...],
ceil_mode: bool,
) -> np.ndarray:
# pylint: disable=no-member
dims = x.ndim - 2
assert_that(dims in {1, 2, 3})
evaluator = _EVALUATORS[dims]
result = (
evaluator(
torch.from_numpy(x.astype(np.float64)), # torch only supports float maxpools
kernel_shape,
strides,
pads[: len(pads) // 2],
dilations,
ceil_mode,
)
.numpy()
.astype(x.dtype)
)
# pylint: enable=no-member
return result
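# A minimal evaluation sketch (runnable as a script): with a plain numpy input, `maxpool`
# evaluates directly through the torch pooling functions above.
if __name__ == "__main__":
    tensor = np.arange(16, dtype=np.int64).reshape(1, 1, 4, 4)  # (N, C, H, W)
    pooled = maxpool(tensor, kernel_shape=(2, 2), strides=(2, 2))
    print(pooled.shape)  # (1, 1, 2, 2)
    print(pooled)        # maxima of each 2x2 block: 5, 7, 13, 15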

View File

@@ -0,0 +1,53 @@
FROM ubuntu:22.04
ENV TZ=Europe/Paris
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
# Replace default archive.ubuntu.com with fr mirror
# original archive showed performance issues and is farther away
RUN sed -i 's|^deb http://archive|deb http://fr.archive|g' /etc/apt/sources.list
COPY ./script/make_utils/setup_os_deps.sh ./setup_os_deps.sh
RUN ./setup_os_deps.sh --linux-install-python && rm ./setup_os_deps.sh
ENV SRC_DIR=/src
# Default to Ubuntu default uid for first user
ARG BUILD_GID=1000
ARG BUILD_UID=1000
# Get sudo for our future user
RUN apt-get update && \
apt-get install --no-install-recommends -y sudo && \
rm -rf /var/lib/apt/lists/*
# From https://dev.to/emmanuelnk/using-sudo-without-password-prompt-as-non-root-docker-user-52bg
# Create dev_user and add it to relevant groups
# Create /src and make the dev user own it
# Ensure sudo group users are not asked for a password when using
# sudo command by amending sudoers file
RUN groupadd -g "${BUILD_GID}" dev_user && \
adduser --disabled-password \
--uid "${BUILD_UID}" --gid "${BUILD_GID}" --shell /bin/bash --gecos "" dev_user && \
usermod -aG sudo dev_user && \
mkdir -p "${SRC_DIR}" && \
chown dev_user "${SRC_DIR}" && \
echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers
# Now switch to the newly created user
USER dev_user
RUN echo "source ~/dev_venv/bin/activate" >> ~/.bashrc && \
echo "if [[ \"\$?\" != \"0\" ]]; then" >> ~/.bashrc && \
echo " python3 -m venv ~/dev_venv" >> ~/.bashrc && \
echo " source ~/dev_venv/bin/activate" >> ~/.bashrc && \
echo " cd ${SRC_DIR}/ && make setup_env" >> ~/.bashrc && \
echo "fi" >> ~/.bashrc && \
echo "export MPLBACKEND=TkAgg" >> ~/.bashrc && \
touch ~/.sudo_as_admin_successful && \
mkdir -p ~/dev_venv && \
mkdir -p ~/.cache
WORKDIR ${SRC_DIR}
CMD ["/bin/bash"]

View File

@@ -0,0 +1 @@
!script/make_utils/setup_os_deps.sh

View File

@@ -0,0 +1,36 @@
FROM ubuntu:22.04
ENV TZ=Europe/Paris
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
# Replace default archive.ubuntu.com with fr mirror
# original archive showed performance issues and is farther away
RUN sed -i 's|^deb http://archive|deb http://fr.archive|g' /etc/apt/sources.list
RUN mkdir /pkg && mkdir /app
WORKDIR /pkg
COPY docker/release_resources/release_requirements.txt .
COPY ./pkg/*.whl .
RUN apt-get update && apt-get upgrade --no-install-recommends -y && \
apt-get install --no-install-recommends -y \
build-essential \
python3-pip \
python3 \
python3-dev \
python3-tk \
python-is-python3 && \
rm -rf /var/lib/apt/lists/* && \
python3 -m pip install --no-cache-dir --upgrade pip wheel setuptools && \
echo "export MPLBACKEND=TkAgg" >> /root/.bashrc && \
python3 -m pip install --no-cache-dir "$(ls ./*.whl)" && \
python3 -m pip install --no-cache-dir -r release_requirements.txt
WORKDIR /app
COPY docker/release_resources/entry_point.sh ./entry_point.sh
RUN mkdir /data
WORKDIR /data
VOLUME [ "/data" ]
CMD ["/bin/bash", "-i", "/app/entry_point.sh"]

View File

@@ -0,0 +1,6 @@
# Not our sources
!docker/release_resources/entry_point.sh
!docker/release_resources/release_requirements.txt
!pkg/
!pkg/**

View File

@@ -0,0 +1,5 @@
#!/usr/bin/env bash
CURR_DIR=$(dirname "$0")
DOCKER_BUILDKIT=1 docker build --pull --no-cache -f "$CURR_DIR/Dockerfile.release" \
-t concrete-numpy-release "$CURR_DIR/.."

View File

@@ -0,0 +1,3 @@
#!/bin/bash
python3 -m jupyter notebook --ip=0.0.0.0 --allow-root --no-browser

View File

@@ -0,0 +1 @@
jupyter~=1.0.0

View File

@@ -0,0 +1,41 @@
import random

import concrete.numpy as cnp


def main():
    def function_to_compile(x):
        return x + 42

    n_bits = 3

    compiler = cnp.Compiler(
        function_to_compile,
        {"x": "encrypted"},
    )

    print("Compiling...")
    engine = compiler.compile(range(2 ** n_bits))

    inputs = []
    labels = []
    for _ in range(4):
        sample_x = random.randint(0, 2 ** n_bits - 1)
        inputs.append([sample_x])
        labels.append(function_to_compile(*inputs[-1]))

    correct = 0
    for idx, (input_i, label_i) in enumerate(zip(inputs, labels), 1):
        print(f"Inference #{idx}")

        result_i = engine.encrypt_run_decrypt(*input_i)
        if result_i == label_i:
            correct += 1

    print(f"{correct}/{len(inputs)}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,24 @@
# What is Concrete Numpy?
[<mark style="background-color:yellow;">⭐️ Star the repo on Github</mark>](https://github.com/zama-ai/concrete-numpy) <mark style="background-color:yellow;">| 🗣</mark> [<mark style="background-color:yellow;">Community support forum</mark>](https://community.zama.ai/c/concrete-numpy) <mark style="background-color:yellow;">| 📁</mark> [<mark style="background-color:yellow;">Contribute to the project</mark>](dev/contributing.md)
<figure><img src="_static/zama_home_docs.png" alt=""><figcaption></figcaption></figure>
**Concrete-Numpy** is an open-source library which simplifies the use of fully homomorphic encryption (FHE).
FHE is a powerful cryptographic tool, which allows computation to be performed directly on encrypted data without needing to decrypt it first. With FHE, you can build services that preserve privacy for all users. FHE is also great against data breaches as everything is done on encrypted data. Even if the server is compromised, in the end no sensitive data is leaked.
## Organization of this documentation
This documentation is split into several sections:
* **Getting Started** gives you the basics,
* **Tutorials** gives you some essential examples on various features of the library,
* **How to** helps you perform specific tasks,
* and **Developer** explains the inner workings of the library and everything related to contributing to the project.
## Looking for support? Ask our team!
* Support forum: [https://community.zama.ai](https://community.zama.ai) (we answer in less than 24 hours).
* Live discussion on the FHE.org discord server: [https://discord.fhe.org](https://discord.fhe.org) (inside the #**concrete** channel).
* Do you have a question about Zama? You can write us on [Twitter](https://twitter.com/zama\_fhe) or send us an email at: **hello@zama.ai**

View File

@@ -0,0 +1,40 @@
# Table of contents
* [What is Concrete Numpy?](README.md)
## Getting Started
* [Installation](getting-started/installing.md)
* [Quick Start](getting-started/quick\_start.md)
* [Compatibility](getting-started/compatibility.md)
* [Exactness](getting-started/exactness.md)
* [Performance](getting-started/performance.md)
## Tutorials
* [Decorator](tutorial/decorator.md)
* [Formatting](tutorial/formatting.md)
* [Tagging](tutorial/tagging.md)
* [Extensions](tutorial/extensions.md)
* [Table Lookups](tutorial/table\_lookups.md)
* [Rounded Table Lookups](tutorial/rounded\_table\_lookups.md)
* [Floating Points](tutorial/floating\_points.md)
* [Simulation](tutorial/simulation.md)
* [Direct Circuits](tutorial/direct\_circuits.md)
* [Key Value Database](tutorial/key\_value\_database.md)
## How To
* [Configure](howto/configure.md)
* [Debug](howto/debug.md)
* [Deploy](howto/deploy.md)
## Developer
* [Project Setup](dev/project\_setup.md)
* [Docker Setup](dev/docker.md)
* [Contribute](dev/contributing.md)
* [Terminology and Structure](dev/terminology\_and\_structure.md)
* [Compilation](dev/compilation.md)
* [Fusing](dev/fusing.md)
* [MLIR](dev/mlir.md)

(Binary image files added in this commit are not shown in the diff; sizes range from 8.4 KiB to 377 KiB.)

View File

@@ -0,0 +1 @@
../tests/conftest.py

View File

@@ -0,0 +1,148 @@
# Compilation
The compilation journey begins with tracing to get an easy-to-manipulate representation of the function. We call this representation a `Computation Graph`, which is basically a Directed Acyclic Graph (DAG) containing nodes representing the computations done in the function. Working with graphs is good because they have been studied extensively over the years and there are a lot of algorithms to manipulate them. Internally, we use [networkx](https://networkx.org), which is an excellent graph library for Python.
The next step in the compilation is transforming the computation graph. There are many transformations we perform, and they will be discussed in their own sections. In any case, the result of transformations is just another computation graph.
After transformations are applied, we need to determine the bounds (i.e., the minimum and the maximum values) of each intermediate node. This is required because FHE currently allows a limited precision for computations. Bounds measurement is how we determine the required precision for the function.
The final step is to transform the computation graph to equivalent `MLIR` code. How this is done will be explained in detail in its own chapter.
Once the MLIR is generated, we send it to the **Concrete-Compiler**, and it completes the compilation process.
## Tracing
Given a Python function `f` such as this one:
```
def f(x):
    return (2 * x) + 3
```
...the goal of tracing is to create the following computation graph without needing any change from the user.
![](../\_static/compilation-pipeline/two\_x\_plus\_three.png)
(Note that the edge labels are for non-commutative operations. To give an example, a subtraction node represents `(predecessor with edge label 0) - (predecessor with edge label 1)`)
To do this, we make use of `Tracer`s, which are objects that record the operation performed during their creation. We create a `Tracer` for each argument of the function and call the function with those tracers. `Tracer`s make use of the operator overloading feature of Python to achieve their goal:
```
def f(x, y):
    return x + 2 * y
x = Tracer(computation=Input("x"))
y = Tracer(computation=Input("y"))
resulting_tracer = f(x, y)
```
`2 * y` will be performed first, and `*` is overloaded for `Tracer` to return another tracer: `Tracer(computation=Multiply(Constant(2), self.computation))`, which is equal to `Tracer(computation=Multiply(Constant(2), Input("y")))`
`x + (2 * y)` will be performed next, and `+` is overloaded for `Tracer` to return another tracer: `Tracer(computation=Add(self.computation, (2 * y).computation))`, which is equal to `Tracer(computation=Add(Input("x"), Multiply(Constant(2), Input("y"))))`
In the end, we will have output tracers that can be used to create the computation graph. The implementation is a bit more complex than this, but the idea is the same.
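Here is a minimal, self-contained sketch of that idea (the real `Tracer` is considerably more involved; the names and data structures below are simplified for illustration):

```python
class Tracer:
    """Records the operations performed on it instead of computing them (simplified)."""

    def __init__(self, computation):
        self.computation = computation

    @staticmethod
    def _wrap(value):
        # constants are wrapped so that they become graph nodes as well
        return value if isinstance(value, Tracer) else Tracer(("Constant", value))

    def __add__(self, other):
        return Tracer(("Add", self.computation, Tracer._wrap(other).computation))

    def __mul__(self, other):
        return Tracer(("Multiply", self.computation, Tracer._wrap(other).computation))

    def __rmul__(self, other):
        # `2 * y` ends up here because int.__mul__ doesn't know about Tracer
        return Tracer(("Multiply", Tracer._wrap(other).computation, self.computation))


def f(x, y):
    return x + 2 * y


x = Tracer(("Input", "x"))
y = Tracer(("Input", "y"))

resulting_tracer = f(x, y)
print(resulting_tracer.computation)
# ('Add', ('Input', 'x'), ('Multiply', ('Constant', 2), ('Input', 'y')))
```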
Tracing is also responsible for indicating whether the values in the node would be encrypted or not, and the rule for that is if a node has an encrypted predecessor, it is encrypted as well.
## Topological transforms
The goal of topological transforms is to make more functions compilable.
With the current version of **Concrete-Numpy**, floating-point inputs and floating-point outputs are not supported. However, if the floating-point operations are intermediate operations, they can sometimes be fused into a single table lookup from integer to integer, thanks to some specific transforms.
Let's take a closer look at the transforms we can currently perform.
### Fusing.
We have allocated a whole new chapter to explaining fusing. You can find it after this chapter.
## Bounds measurement
Given a computation graph, the goal of the bound measurement step is to assign the minimal data type to each node in the graph.
Let's say we have an encrypted input that is always between `0` and `10`. We should assign the type `Encrypted<uint4>` to the node of this input as `Encrypted<uint4>` is the minimal encrypted integer that supports all values between `0` and `10`.
If there were negative values in the range, we could have used `intX` instead of `uintX`.
Bounds measurement is necessary because FHE supports limited precision, and we don't want unexpected behaviour while evaluating the compiled functions.
Let's take a closer look at how we perform bounds measurement.
### Inputset evaluation.
This is a simple approach that requires an inputset to be provided by the user.
The inputset is not to be confused with the dataset used in classical ML: it doesn't require labels. Rather, it is a set of values which are typical inputs of the function.
The idea is to evaluate each input in the inputset and record the result of each operation in the computation graph. Then we compare the evaluation results with the current minimum/maximum values of each node and update the minimum/maximum accordingly. After the entire inputset is evaluated, we assign a data type to each node using the minimum and the maximum values it contains.
Here is an example, given this computation graph where `x` is encrypted:
![](../\_static/compilation-pipeline/two\_x\_plus\_three.png)
and this inputset:
```
[2, 3, 1]
```
Evaluation Result of `2`:
* `x`: 2
* `2`: 2
* `*`: 4
* `3`: 3
* `+`: 7
New Bounds:
* `x`: \[**2**, **2**]
* `2`: \[**2**, **2**]
* `*`: \[**4**, **4**]
* `3`: \[**3**, **3**]
* `+`: \[**7**, **7**]
Evaluation Result of `3`:
* `x`: 3
* `2`: 2
* `*`: 6
* `3`: 3
* `+`: 9
New Bounds:
* `x`: \[2, **3**]
* `2`: \[2, 2]
* `*`: \[4, **6**]
* `3`: \[3, 3]
* `+`: \[7, **9**]
Evaluation Result of `1`:
* `x`: 1
* `2`: 2
* `*`: 2
* `3`: 3
* `+`: 5
New Bounds:
* `x`: \[**1**, 3]
* `2`: \[2, 2]
* `*`: \[**2**, 6]
* `3`: \[3, 3]
* `+`: \[**5**, 9]
Assigned Data Types:
* `x`: Encrypted<**uint2**>
* `2`: Clear<**uint2**>
* `*`: Encrypted<**uint3**>
* `3`: Clear<**uint2**>
* `+`: Encrypted<**uint4**>
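The bookkeeping above can be sketched in a few lines of plain Python (an illustration of the idea, not the library's implementation; for simplicity, only unsigned values are handled):

```python
import math


def bits_needed(value):
    """Number of bits needed to represent a non-negative integer."""
    return max(1, math.ceil(math.log2(value + 1)))


def evaluate(x):
    """Evaluate the example graph, returning the value of every node."""
    return {"x": x, "2": 2, "*": 2 * x, "3": 3, "+": (2 * x) + 3}


inputset = [2, 3, 1]

bounds = None
for sample in inputset:
    results = evaluate(sample)
    if bounds is None:
        bounds = {name: [value, value] for name, value in results.items()}
    else:
        for name, value in results.items():
            bounds[name][0] = min(bounds[name][0], value)
            bounds[name][1] = max(bounds[name][1], value)

for name, (low, high) in bounds.items():
    print(f"{name}: [{low}, {high}] -> uint{bits_needed(high)}")
```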
## MLIR conversion
The actual compilation will be done by the **Concrete-Compiler**, which is expecting an MLIR input. The MLIR conversion goes from a computation graph to its MLIR equivalent. You can read more about it [here](mlir.md).

View File

@@ -0,0 +1,99 @@
# Contribute
{% hint style="info" %}
There are two ways to contribute to **Concrete-Numpy** or to **Concrete** tools in general:
* You can open issues to report bugs and typos and to suggest ideas.
* You can ask to become an official contributor by emailing hello@zama.ai. Only approved contributors can send pull requests (PRs), so please make sure to get in touch before you do!
{% endhint %}
Now, let's go over some other important items that you need to know.
## Creating a new branch
We are using a consistent branch naming scheme, and you are expected to follow it as well. Here is the format:
```shell
git checkout -b {feat|fix|refactor|test|benchmark|doc|style|chore}/short-description
```
...and here are some examples:
```shell
git checkout -b feat/direct-tlu
git checkout -b fix/tracing-indexing
```
## Before committing
### Conformance.
Each commit to **Concrete-Numpy** should conform to the standards decided by the team. Conformance can be checked using the following command:
```shell
make pcc
```
### Testing.
On top of conformance, all tests must pass with 100% code coverage across the codebase:
```shell
make pytest
```
{% hint style="info" %}
There may be cases where covering 100% of the code is not possible (e.g., exceptions that cannot be triggered in normal execution circumstances). In those cases, you may be allowed to disable coverage for some specific lines. This should be the exception rather than the rule. Reviewers may ask why some lines are not covered and, if it appears they can be covered, then the PR won't be accepted in that state.
{% endhint %}
## Committing
We are using a consistent commit naming scheme based on conventional commits, and you are expected to follow it as well. The accepted scopes can be listed with:
```shell
make show_scope
```
...and some examples:
```shell
git commit -m "feat: implement bounds checking"
git commit -m "feat(debugging): add an helper function to print intermediate representation"
git commit -m "fix(tracing): fix a bug that crashed pytorch tracer"
```
To learn more about conventional commits, check [this](https://www.conventionalcommits.org/en/v1.0.0/) page.
## Before creating a pull request
{% hint style="info" %}
We remind you that only official contributors can send pull requests. To become an official contributor, please email hello@zama.ai.
{% endhint %}
You should rebase on top of the `main` branch before you create your pull request. We don't allow merge commits, so rebasing on `main` before pushing gives you the best chance of avoiding rewriting parts of your PR later if conflicts arise with other PRs being merged. After you commit your changes to your new branch, you can use the following commands to rebase:
```shell
# fetch the list of active remote branches
git fetch --all --prune
# checkout to main
git checkout main
# pull the latest changes to main (--ff-only is there to prevent accidental commits to main)
git pull --ff-only
# checkout back to your branch
git checkout $YOUR_BRANCH
# rebase on top of main branch
git rebase main
# If there are conflicts during the rebase, resolve them
# and continue the rebase with the following command
git rebase --continue
# push the latest version of the local branch to remote
git push --force
```
You can learn more about rebasing [here](https://git-scm.com/docs/git-rebase).

View File

@@ -0,0 +1,45 @@
# Docker Setup
## Installation
Before you start this section, go ahead and install Docker. You can follow [this](https://docs.docker.com/engine/install/) official guide if you need help.
## X forwarding
### Linux.
You can use this xhost command:
```shell
xhost +localhost
```
### macOS.
To use X forwarding on macOS:
* Install XQuartz
* Open the XQuartz.app application and make sure that `authorize network connections` is set in the application preferences (currently in the Security settings)
* Open a new terminal within XQuartz.app and type:
```shell
xhost +127.0.0.1
```
X server should be all set for Docker in the regular terminal.
## Building
You can use the dedicated target in the makefile to build the docker image:
```shell
make docker_build
```
## Starting
You can use the dedicated target in the makefile to start the docker session:
```shell
make docker_start
```

View File

@@ -0,0 +1,37 @@
# Fusing
Fusing is the act of combining multiple nodes into a single node, which is converted to a table lookup.
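As an example (hedged: whether a given function compiles also depends on the bit widths involved), the function below uses floating-point intermediates, but since its input and output are integers, the whole `np.sin` chain can be fused into a single table lookup:

```python
import numpy as np

import concrete.numpy as cnp


@cnp.compiler({"x": "encrypted"})
def f(x):
    # floating-point intermediates, integer input and integer output
    return (10 * np.sin(x)).astype(np.int64)


inputset = range(2 ** 3)
circuit = f.compile(inputset)

print(circuit.encrypt_run_decrypt(2))  # 9, i.e. int(10 * sin(2)), barring the small error probability
```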
## How is it done?
Code related to fusing is in the `concrete/numpy/compilation/utils.py` file. Fusing can be performed using the `fuse` function.
Within `fuse`:
1. We loop until there are no more subgraphs to fuse.
2. <mark style="background-color:yellow;">Within each iteration:</mark>
2.1. We find a subgraph to fuse.
2.2. We search for a terminal node that is appropriate for fusing.
2.3. We crawl backwards to find the closest integer nodes to this node.
2.4. If there is a single node as such, we return the subgraph from this node to the terminal node.
2.5. Otherwise, we try to find the lowest common ancestor (lca) of this list of nodes.
2.6. If an lca doesn't exist, we say this particular terminal node is not fusable, and we go back to search for another subgraph.
2.7. Otherwise, we use this lca as the input of the subgraph and continue with `subgraph` node creation below.
2.8. We convert the subgraph into a `subgraph` node, checking the fusability status of the subgraph's nodes in this step.
2.9. We substitute the `subgraph` node into the original graph.
## Limitations
With the current implementation, we cannot fuse subgraphs that depend on multiple encrypted values where those values don't have a common lca (e.g., `np.round(np.sin(x) + np.cos(y))`).
{% hint style="info" %}
[Kolmogorov–Arnold representation theorem](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Arnold\_representation\_theorem) states that every multivariate continuous function can be represented as a superposition of continuous functions of one variable. Therefore, the case above could be handled in future versions of **Concrete-Numpy**.
{% endhint %}

View File

@@ -0,0 +1,17 @@
# MLIR
The MLIR project is a sub-project of the LLVM project. It's designed to simplify building domain-specific compilers such as our **Concrete-Compiler**.
**Concrete-Compiler** accepts MLIR as an input and emits compiled assembly code for a target architecture.
**Concrete-Numpy** performs the MLIR generation from the computation graph. Code related to this conversion is in the `concrete/numpy/mlir` folder.
The conversion can be performed using the `convert` method of the `GraphConverter` class.
Within the `convert` method of `GraphConverter`:
* MLIR compatibility of the graph is checked;
* bit width constraints are checked;
* negative lookup tables are offset;
* the computation graph is traversed and each node is converted to its corresponding MLIR representation using the `NodeConverter` class;
* and the string representation of the resulting MLIR is returned.

View File

@@ -0,0 +1,91 @@
# Project Setup
{% hint style="info" %}
It is **strongly** recommended to use Docker as the development environment. However, you can also set the project up on bare Linux or macOS as long as you have the required dependencies. You can see the required dependencies in `Dockerfile.dev` under the `docker` directory.
{% endhint %}
## Installing `Python`
**Concrete-Numpy** is a `Python` library, so `Python` should be installed to develop it. `v3.8` and `v3.9` are, currently, the only supported versions.
You probably have Python already, but in case you don't, or in case you have an unsupported version, you can google `how to install python 3.8` and follow one of the results.
## Installing `Poetry`
`Poetry` is our package manager. It drastically simplifies dependency and environment management.
You can follow [this](https://python-poetry.org/docs/#installation) official guide to install it.
## Installing `make`
`make` is used to launch various commands such as formatting and testing.
On Linux, you can install `make` using the package manager of your distribution.
On macOS, you can install `gmake` via brew:
```shell
brew install make
```
{% hint style="info" %}
In the following sections, be sure to use the proper `make` tool for your system (i.e., `make`, `gmake`, etc.).
{% endhint %}
## Cloning the repository
Now, it's time to get the source code of **Concrete-Numpy**.
Clone the git repository from GitHub using the protocol of your choice (ssh or https).
## Setting up the environment
Virtual environments are utilized to keep the project isolated from other `Python` projects in the system.
To create a new virtual environment and install dependencies, use the command:
```shell
make setup_env
```
## Activating the environment
To activate the newly created environment, use:
```shell
source .venv/bin/activate
```
## Syncing the environment
From time to time, new dependencies will be added to the project and old ones will be removed. The command below will make sure the project has the proper environment, so run it regularly.
```shell
make sync_env
```
## Troubleshooting
### In native setups.
If you are having issues in a native setup, you can try to re-create your environment like this:
```shell
deactivate
rm -rf .venv
make setup_env
source .venv/bin/activate
```
If the problem persists, you should consider using Docker. If you are working on a platform-specific feature and Docker is not an option, you should create an issue so that we can take a look at your problem.
### In docker setups.
If you are having issues in a docker setup, you can try to re-build the docker image:
```shell
make docker_rebuild
make docker_start
```
If the problem persists, you should contact us for help.

View File

@@ -0,0 +1,17 @@
# Release process
## Release candidate cycle
Throughout the quarter, many release candidates are released. Those candidates are published in a private package repository. At the end of the quarter, we take the latest release candidate and release it on PyPI without the `rcX` tag.
## Release flow
* Check out the commit that you want to include in the release (this commit and everything before it will be in the release)
* Run `make release`
* Wait for CI to complete
* Check out the `chore/version` branch
* Run `VERSION=a.b.c-rcX make set_version` with appropriate version
* Push the branch to origin
* Create a PR to merge it to main
* Wait for CI to finish and get approval in the meantime
* Merge the version update to main

View File

@@ -0,0 +1,26 @@
# Terminology and Structure
## Terminology
Some terms used throughout the project include:
* computation graph - a data structure to represent a computation. This is basically a directed acyclic graph in which nodes are either inputs, constants or operations on other nodes.
* tracing - the technique that takes a Python function from the user and generates the corresponding computation graph in an easy-to-read format.
* bounds - before a computation graph is converted to MLIR, we need to know which type each node will output (e.g., uint3 vs euint5). To determine this, the computation graph is evaluated on different inputs, and the minimum and maximum values seen at each node, which is what we call bounds, are used to assign the appropriate type to each node.
* circuit - the result of compilation. A circuit is made of client and server components and has methods for everything from printing to evaluation.
## Module structure
In this section, we will briefly discuss the module structure of **Concrete-Numpy**. You are encouraged to check individual `.py` files to learn more.
* Concrete
* Numpy
* dtypes - data type specifications
* values - value specifications (i.e., data type + shape + encryption status)
* representation - representation of computation
* tracing - tracing of Python functions
* extensions - custom functionality which is not available in NumPy (e.g., direct table lookups)
* MLIR - MLIR conversion
* compilation - compilation from a Python function to a circuit, client/server architecture
* ONNX
* convolution - custom convolution operations that follow the behavior of ONNX

View File

@@ -0,0 +1,177 @@
# Compatibility
## Supported operations
Here are the operations you can use inside the function you are compiling.
{% hint style="info" %}
Some of these operations are not supported between two encrypted values. A detailed error will be raised if you try to do something that is not supported.
{% endhint %}
### Supported Python operators.
* [\_\_abs\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_abs\_\_)
* [\_\_add\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_add\_\_)
* [\_\_and\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_and\_\_)
* [\_\_eq\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_eq\_\_)
* [\_\_floordiv\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_floordiv\_\_)
* [\_\_ge\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_ge\_\_)
* [\_\_getitem\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_getitem\_\_)
* [\_\_gt\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_gt\_\_)
* [\_\_invert\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_invert\_\_)
* [\_\_le\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_le\_\_)
* [\_\_lshift\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_lshift\_\_)
* [\_\_lt\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_lt\_\_)
* [\_\_matmul\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_matmul\_\_)
* [\_\_mod\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_mod\_\_)
* [\_\_mul\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_mul\_\_)
* [\_\_ne\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_ne\_\_)
* [\_\_neg\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_neg\_\_)
* [\_\_or\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_or\_\_)
* [\_\_pos\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_pos\_\_)
* [\_\_pow\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_pow\_\_)
* [\_\_radd\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_radd\_\_)
* [\_\_rand\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rand\_\_)
* [\_\_rfloordiv\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rfloordiv\_\_)
* [\_\_rlshift\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rlshift\_\_)
* [\_\_rmatmul\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rmatmul\_\_)
* [\_\_rmod\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rmod\_\_)
* [\_\_rmul\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rmul\_\_)
* [\_\_ror\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_ror\_\_)
* [\_\_round\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_round\_\_)
* [\_\_rpow\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rpow\_\_)
* [\_\_rrshift\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rrshift\_\_)
* [\_\_rshift\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rshift\_\_)
* [\_\_rsub\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rsub\_\_)
* [\_\_rtruediv\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rtruediv\_\_)
* [\_\_rxor\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_rxor\_\_)
* [\_\_sub\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_sub\_\_)
* [\_\_truediv\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_truediv\_\_)
* [\_\_xor\_\_](https://docs.python.org/3/reference/datamodel.html#object.\_\_xor\_\_)
### Supported NumPy functions.
* [np.absolute](https://numpy.org/doc/stable/reference/generated/numpy.absolute.html)
* [np.add](https://numpy.org/doc/stable/reference/generated/numpy.add.html)
* [np.arccos](https://numpy.org/doc/stable/reference/generated/numpy.arccos.html)
* [np.arccosh](https://numpy.org/doc/stable/reference/generated/numpy.arccosh.html)
* [np.arcsin](https://numpy.org/doc/stable/reference/generated/numpy.arcsin.html)
* [np.arcsinh](https://numpy.org/doc/stable/reference/generated/numpy.arcsinh.html)
* [np.arctan](https://numpy.org/doc/stable/reference/generated/numpy.arctan.html)
* [np.arctan2](https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html)
* [np.arctanh](https://numpy.org/doc/stable/reference/generated/numpy.arctanh.html)
* [np.around](https://numpy.org/doc/stable/reference/generated/numpy.around.html)
* [np.bitwise\_and](https://numpy.org/doc/stable/reference/generated/numpy.bitwise\_and.html)
* [np.bitwise\_or](https://numpy.org/doc/stable/reference/generated/numpy.bitwise\_or.html)
* [np.bitwise\_xor](https://numpy.org/doc/stable/reference/generated/numpy.bitwise\_xor.html)
* [np.broadcast\_to](https://numpy.org/doc/stable/reference/generated/numpy.broadcast\_to.html)
* [np.cbrt](https://numpy.org/doc/stable/reference/generated/numpy.cbrt.html)
* [np.ceil](https://numpy.org/doc/stable/reference/generated/numpy.ceil.html)
* [np.clip](https://numpy.org/doc/stable/reference/generated/numpy.clip.html)
* [np.concatenate](https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html)
* [np.copysign](https://numpy.org/doc/stable/reference/generated/numpy.copysign.html)
* [np.cos](https://numpy.org/doc/stable/reference/generated/numpy.cos.html)
* [np.cosh](https://numpy.org/doc/stable/reference/generated/numpy.cosh.html)
* [np.deg2rad](https://numpy.org/doc/stable/reference/generated/numpy.deg2rad.html)
* [np.degrees](https://numpy.org/doc/stable/reference/generated/numpy.degrees.html)
* [np.dot](https://numpy.org/doc/stable/reference/generated/numpy.dot.html)
* [np.equal](https://numpy.org/doc/stable/reference/generated/numpy.equal.html)
* [np.exp](https://numpy.org/doc/stable/reference/generated/numpy.exp.html)
* [np.exp2](https://numpy.org/doc/stable/reference/generated/numpy.exp2.html)
* [np.expand\_dims](https://numpy.org/doc/stable/reference/generated/numpy.expand\_dims.html)
* [np.expm1](https://numpy.org/doc/stable/reference/generated/numpy.expm1.html)
* [np.fabs](https://numpy.org/doc/stable/reference/generated/numpy.fabs.html)
* [np.float\_power](https://numpy.org/doc/stable/reference/generated/numpy.float\_power.html)
* [np.floor](https://numpy.org/doc/stable/reference/generated/numpy.floor.html)
* [np.floor\_divide](https://numpy.org/doc/stable/reference/generated/numpy.floor\_divide.html)
* [np.fmax](https://numpy.org/doc/stable/reference/generated/numpy.fmax.html)
* [np.fmin](https://numpy.org/doc/stable/reference/generated/numpy.fmin.html)
* [np.fmod](https://numpy.org/doc/stable/reference/generated/numpy.fmod.html)
* [np.gcd](https://numpy.org/doc/stable/reference/generated/numpy.gcd.html)
* [np.greater](https://numpy.org/doc/stable/reference/generated/numpy.greater.html)
* [np.greater\_equal](https://numpy.org/doc/stable/reference/generated/numpy.greater\_equal.html)
* [np.heaviside](https://numpy.org/doc/stable/reference/generated/numpy.heaviside.html)
* [np.hypot](https://numpy.org/doc/stable/reference/generated/numpy.hypot.html)
* [np.invert](https://numpy.org/doc/stable/reference/generated/numpy.invert.html)
* [np.isfinite](https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html)
* [np.isinf](https://numpy.org/doc/stable/reference/generated/numpy.isinf.html)
* [np.isnan](https://numpy.org/doc/stable/reference/generated/numpy.isnan.html)
* [np.lcm](https://numpy.org/doc/stable/reference/generated/numpy.lcm.html)
* [np.ldexp](https://numpy.org/doc/stable/reference/generated/numpy.ldexp.html)
* [np.left\_shift](https://numpy.org/doc/stable/reference/generated/numpy.left\_shift.html)
* [np.less](https://numpy.org/doc/stable/reference/generated/numpy.less.html)
* [np.less\_equal](https://numpy.org/doc/stable/reference/generated/numpy.less\_equal.html)
* [np.log](https://numpy.org/doc/stable/reference/generated/numpy.log.html)
* [np.log10](https://numpy.org/doc/stable/reference/generated/numpy.log10.html)
* [np.log1p](https://numpy.org/doc/stable/reference/generated/numpy.log1p.html)
* [np.log2](https://numpy.org/doc/stable/reference/generated/numpy.log2.html)
* [np.logaddexp](https://numpy.org/doc/stable/reference/generated/numpy.logaddexp.html)
* [np.logaddexp2](https://numpy.org/doc/stable/reference/generated/numpy.logaddexp2.html)
* [np.logical\_and](https://numpy.org/doc/stable/reference/generated/numpy.logical\_and.html)
* [np.logical\_not](https://numpy.org/doc/stable/reference/generated/numpy.logical\_not.html)
* [np.logical\_or](https://numpy.org/doc/stable/reference/generated/numpy.logical\_or.html)
* [np.logical\_xor](https://numpy.org/doc/stable/reference/generated/numpy.logical\_xor.html)
* [np.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html)
* [np.maximum](https://numpy.org/doc/stable/reference/generated/numpy.maximum.html)
* [np.minimum](https://numpy.org/doc/stable/reference/generated/numpy.minimum.html)
* [np.multiply](https://numpy.org/doc/stable/reference/generated/numpy.multiply.html)
* [np.negative](https://numpy.org/doc/stable/reference/generated/numpy.negative.html)
* [np.nextafter](https://numpy.org/doc/stable/reference/generated/numpy.nextafter.html)
* [np.not\_equal](https://numpy.org/doc/stable/reference/generated/numpy.not\_equal.html)
* [np.ones\_like](https://numpy.org/doc/stable/reference/generated/numpy.ones\_like.html)
* [np.positive](https://numpy.org/doc/stable/reference/generated/numpy.positive.html)
* [np.power](https://numpy.org/doc/stable/reference/generated/numpy.power.html)
* [np.rad2deg](https://numpy.org/doc/stable/reference/generated/numpy.rad2deg.html)
* [np.radians](https://numpy.org/doc/stable/reference/generated/numpy.radians.html)
* [np.reciprocal](https://numpy.org/doc/stable/reference/generated/numpy.reciprocal.html)
* [np.remainder](https://numpy.org/doc/stable/reference/generated/numpy.remainder.html)
* [np.reshape](https://numpy.org/doc/stable/reference/generated/numpy.reshape.html)
* [np.right\_shift](https://numpy.org/doc/stable/reference/generated/numpy.right\_shift.html)
* [np.rint](https://numpy.org/doc/stable/reference/generated/numpy.rint.html)
* [np.round\_](https://numpy.org/doc/stable/reference/generated/numpy.round\_.html)
* [np.sign](https://numpy.org/doc/stable/reference/generated/numpy.sign.html)
* [np.signbit](https://numpy.org/doc/stable/reference/generated/numpy.signbit.html)
* [np.sin](https://numpy.org/doc/stable/reference/generated/numpy.sin.html)
* [np.sinh](https://numpy.org/doc/stable/reference/generated/numpy.sinh.html)
* [np.spacing](https://numpy.org/doc/stable/reference/generated/numpy.spacing.html)
* [np.sqrt](https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html)
* [np.square](https://numpy.org/doc/stable/reference/generated/numpy.square.html)
* [np.subtract](https://numpy.org/doc/stable/reference/generated/numpy.subtract.html)
* [np.sum](https://numpy.org/doc/stable/reference/generated/numpy.sum.html)
* [np.tan](https://numpy.org/doc/stable/reference/generated/numpy.tan.html)
* [np.tanh](https://numpy.org/doc/stable/reference/generated/numpy.tanh.html)
* [np.transpose](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html)
* [np.true\_divide](https://numpy.org/doc/stable/reference/generated/numpy.true\_divide.html)
* [np.trunc](https://numpy.org/doc/stable/reference/generated/numpy.trunc.html)
* [np.where](https://numpy.org/doc/stable/reference/generated/numpy.where.html)
* [np.zeros\_like](https://numpy.org/doc/stable/reference/generated/numpy.zeros\_like.html)
### Supported `ndarray` methods.
* [np.ndarray.astype](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.astype.html)
* [np.ndarray.clip](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.clip.html)
* [np.ndarray.dot](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.dot.html)
* [np.ndarray.flatten](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flatten.html)
* [np.ndarray.reshape](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.reshape.html)
* [np.ndarray.transpose](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.transpose.html)
### Supported `ndarray` properties.
* [np.ndarray.shape](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.shape.html)
* [np.ndarray.ndim](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.ndim.html)
* [np.ndarray.size](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.size.html)
* [np.ndarray.T](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.T.html)
## Limitations
### Control flow constraints.
Some Python control flow statements are not supported. For example, you cannot have an `if` statement or a `while` statement for which the condition depends on an encrypted value. However, such statements are supported with constant values (e.g., `for i in range(SOME_CONSTANT)`, `if os.environ.get("SOME_FEATURE") == "ON":`).
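As an illustration (a sketch, not an exhaustive list), looping over a constant range works because the loop is simply unrolled during tracing, while branching on an encrypted value cannot be traced:

```python
import concrete.numpy as cnp


@cnp.compiler({"x": "encrypted"})
def supported(x):
    # the loop bound is a constant, so the loop is unrolled during tracing
    for _ in range(3):
        x = x + 1
    return x


circuit = supported.compile(range(8))
print(circuit.encrypt_run_decrypt(4))  # 7

# Not supported: the condition depends on the encrypted value `x`.
# def unsupported(x):
#     return x if x > 2 else x + 1
```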
### Type constraints.
Another constraint is that you cannot have floating-point inputs or floating-point outputs. You can have floating-point intermediate values as long as they can be converted to an integer Table Lookup (e.g., `(60 * np.sin(x)).astype(np.int64)`).
### Bit width constraints.
There is a limit on the bit width of encrypted values. We are constantly working on increasing this bit width. If you go above the limit, you will get an error.

View File

@@ -0,0 +1,27 @@
# Exactness
One of the most common operations in **Concrete-Numpy** is the `Table Lookup` (TLU). TLUs are performed with an FHE operation called `Programmable Bootstrapping` (PBS). PBSs have a certain probability of error which, when triggered, results in inaccurate results.
Let's say you have the table:
```python
[0, 1, 4, 9, 16, 25, 36, 49, 64]
```
And you perform a table lookup using `4`. The result you should get is `16`, but because of the possibility of error, you can sometimes get `9` or `25`, or even `4` or `36` if the probability of error is high.
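To make this concrete, here is what such a lookup looks like in code (a sketch following the same pattern used elsewhere in this documentation; with the default error probability, the result is almost always exact):

```python
import concrete.numpy as cnp

table = cnp.LookupTable([x ** 2 for x in range(2 ** 4)])

@cnp.compiler({"x": "encrypted"})
def square(x):
    return table[x]

circuit = square.compile(range(2 ** 4))

# the result should be 16; with probability `p_error`, a neighbouring
# entry such as 9 or 25 may be returned instead
print(circuit.encrypt_run_decrypt(4))
```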
The probability of this error can be configured through the `p_error` and `global_p_error` configuration options. The difference between these two options is that `p_error` is for individual TLUs while `global_p_error` is for the whole circuit.
Here is an example: if you set `p_error` to `0.01`, every TLU in the circuit will have a 1% chance of not being exact (and a 99% chance of being exact). If you have a single TLU in the circuit, `global_p_error` would be 1% as well, but if you have 2 TLUs, for example, `global_p_error` would be almost 2% (`1 - (0.99 * 0.99)`).
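The relationship between the two options can be sketched with plain arithmetic, assuming the TLU errors are independent:

```python
p_error = 0.01  # per-TLU probability of not being exact
num_tlus = 2

# probability that at least one TLU in the circuit is not exact
global_p_error = 1 - (1 - p_error) ** num_tlus
print(global_p_error)  # ~0.0199, i.e. almost 2%
```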
However, if you set `global_p_error` to `0.01`, the whole circuit will have a 1% probability of not being exact, no matter how many table lookups there are.
If you set both of them, both will be satisfied. Essentially, the stricter one will be used.
By default, both `p_error` and `global_p_error` are set to `None`, which results in a `global_p_error` of `1 / 100_000` being used. Feel free to play with these configuration options to pick the one best suited for your needs! For example, in some machine learning use cases, off-by-one or off-by-two errors don't affect the result much; in such cases, `p_error` can be set higher to increase performance without losing accuracy.
See [How to Configure](../howto/configure.md) to learn how you can set a custom `p_error` and/or `global_p_error`.
{% hint style="info" %}
Configuring either of those variables would affect computation time (compilation, keys generation, circuit execution) and space requirements (size of the keys on disk and in memory). Lower error probability would result in longer computation time and larger space requirements.
{% endhint %}

View File

@@ -0,0 +1,46 @@
# Installation
**Concrete-Numpy** is natively supported on Linux and macOS for Python 3.8 and 3.9, but if your platform supports Docker, you can also use the docker image to run **Concrete-Numpy**.
## Using PyPI
You can install **Concrete-Numpy** from PyPI:
```shell
pip install -U pip wheel setuptools
pip install concrete-numpy
```
{% hint style="warning" %}
Apple Silicon users must use the docker installation (explained below), as some of our dependencies don't have ARM versions for the time being.
{% endhint %}
## Using Docker
You can also get the **Concrete-Numpy** docker image:
```shell
docker pull zamafhe/concrete-numpy:v1.0.0
```
### Starting a Jupyter server.
By default, the entry point of the **Concrete-Numpy** docker image is a Jupyter server that you can access from your browser:
```shell
docker run --rm -it -p 8888:8888 zamafhe/concrete-numpy:v1.0.0
```
To save notebooks on the host, you can use a local volume:
```shell
docker run --rm -it -p 8888:8888 -v /path/to/notebooks:/data zamafhe/concrete-numpy:v1.0.0
```
### Starting a Bash session.
Alternatively, you can launch a Bash session:
```shell
docker run --rm -it zamafhe/concrete-numpy:latest /bin/bash
```

View File

@@ -0,0 +1,104 @@
# Performance
The most important operation in Concrete-Numpy is the table lookup operation. All operations except addition, subtraction, multiplication with non-encrypted values, and a few operations built with those primitive operations (e.g. matmul, conv) are converted to table lookups under the hood:
```python
import concrete.numpy as cnp
@cnp.compiler({"x": "encrypted"})
def f(x):
    return x ** 2
inputset = range(2 ** 4)
circuit = f.compile(inputset)
```
is exactly the same as
```python
import concrete.numpy as cnp
table = cnp.LookupTable([x ** 2 for x in range(2 ** 4)])
@cnp.compiler({"x": "encrypted"})
def f(x):
    return table[x]
inputset = range(2 ** 4)
circuit = f.compile(inputset)
```
Table lookups are very flexible, and they allow Concrete Numpy to support many operations, but they are expensive! Therefore, you should try to avoid them as much as possible. In most cases, it's not possible to avoid them completely, but you might be able to reduce the number of TLUs or replace some of them with other primitive operations.
The exact cost depends on many variables (machine configuration, error probability, etc.), but you can develop some intuition for single-threaded CPU execution performance using:
```python
import time

import concrete.numpy as cnp
import numpy as np

WARMUP = 3
SAMPLES = 8
BITWIDTHS = range(1, 15)

CONFIGURATION = cnp.Configuration(
    enable_unsafe_features=True,
    use_insecure_key_cache=True,
    insecure_key_cache_location=".keys",
)

timings = {}
for n in BITWIDTHS:
    @cnp.compiler({"x": "encrypted"})
    def base(x):
        return x

    table = cnp.LookupTable([np.sqrt(x).round().astype(np.int64) for x in range(2 ** n)])

    @cnp.compiler({"x": "encrypted"})
    def tlu(x):
        return table[x]

    inputset = [0, 2**n - 1]
    base_circuit = base.compile(inputset, CONFIGURATION)
    tlu_circuit = tlu.compile(inputset, CONFIGURATION)

    print()
    print(f"Generating keys for n={n}...")

    base_circuit.keygen()
    tlu_circuit.keygen()

    timings[n] = []
    for i in range(SAMPLES + WARMUP):
        sample = np.random.randint(0, 2 ** n)

        encrypted_sample = base_circuit.encrypt(sample)
        start = time.time()
        encrypted_result = base_circuit.run(encrypted_sample)
        end = time.time()
        assert base_circuit.decrypt(encrypted_result) == sample
        base_time = end - start

        encrypted_sample = tlu_circuit.encrypt(sample)
        start = time.time()
        encrypted_result = tlu_circuit.run(encrypted_sample)
        end = time.time()
        assert tlu_circuit.decrypt(encrypted_result) == np.sqrt(sample).round().astype(np.int64)
        tlu_time = end - start

        if i >= WARMUP:
            timings[n].append(tlu_time - base_time)
            print(f"Sample #{i - WARMUP + 1} took {timings[n][-1] * 1000:.3f}ms")

print()
for n, times in timings.items():
    print(f"{n}-bits -> {np.mean(times) * 1000:.3f}ms")
```
{% hint style="info" %}
Concrete Numpy automatically parallelizes execution if TLUs are applied to tensors.
{% endhint %}

View File

@@ -0,0 +1,90 @@
# Quick Start
To compute on encrypted data, you first need to define the function that you want to compute, then compile it into a Concrete-Numpy `Circuit`, which you can use to perform homomorphic evaluation.
Here is the full example that we will walk through:
```python
import concrete.numpy as cnp
def add(x, y):
    return x + y
compiler = cnp.Compiler(add, {"x": "encrypted", "y": "clear"})
inputset = [(2, 3), (0, 0), (1, 6), (7, 7), (7, 1)]
circuit = compiler.compile(inputset)
x = 4
y = 4
clear_evaluation = add(x, y)
homomorphic_evaluation = circuit.encrypt_run_decrypt(x, y)
print(x, "+", y, "=", clear_evaluation, "=", homomorphic_evaluation)
```
## Importing the library
Everything you need to perform homomorphic evaluation is included in a single module:
<!--pytest-codeblocks:skip-->
```python
import concrete.numpy as cnp
```
## Defining the function to compile
In this example, we will compile a simple addition function:
<!--pytest-codeblocks:skip-->
```python
def add(x, y):
    return x + y
```
## Creating a compiler
To compile the function, you need to create a `Compiler` by specifying the function to compile and encryption status of its inputs:
<!--pytest-codeblocks:skip-->
```python
compiler = cnp.Compiler(add, {"x": "encrypted", "y": "clear"})
```
## Defining an inputset
An inputset is a collection representing the typical inputs to the function. It is used to determine the bit widths and shapes of the variables within the function.
It should be an iterable, yielding tuples of the same length as the number of arguments of the function being compiled:
<!--pytest-codeblocks:skip-->
```python
inputset = [(2, 3), (0, 0), (1, 6), (7, 7), (7, 1)]
```
{% hint style="warning" %}
All inputs in the inputset will be evaluated in the graph, which takes time. If you're experiencing long compilation times, consider providing a smaller inputset.
{% endhint %}
## Compiling the function
You can use the `compile` method of a `Compiler` class with an inputset to perform the compilation and get the resulting circuit back:
<!--pytest-codeblocks:skip-->
```python
circuit = compiler.compile(inputset)
```
## Performing homomorphic evaluation
You can use the `encrypt_run_decrypt` method of a `Circuit` class to perform homomorphic evaluation:
<!--pytest-codeblocks:skip-->
```python
homomorphic_evaluation = circuit.encrypt_run_decrypt(4, 4)
```
{% hint style="info" %}
`circuit.encrypt_run_decrypt(*args)` is just a convenient way to do everything at once. It is implemented as `circuit.decrypt(circuit.run(circuit.encrypt(*args)))`.
{% endhint %}

Some files were not shown because too many files have changed in this diff.