mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
chore: Integrate concrete-compiler to the mono-repo
This commit is contained in:
55
compilers/concrete-compiler/.gitignore
vendored
Normal file
55
compilers/concrete-compiler/.gitignore
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
## C++
|
||||
|
||||
# Prerequisites
|
||||
*.d
|
||||
|
||||
# Compiled Object files
|
||||
*.slo
|
||||
*.lo
|
||||
*.o
|
||||
*.obj
|
||||
|
||||
# Precompiled Headers
|
||||
*.gch
|
||||
*.pch
|
||||
|
||||
# Compiled Dynamic libraries
|
||||
*.so
|
||||
*.dylib
|
||||
*.dll
|
||||
|
||||
# Fortran module files
|
||||
*.mod
|
||||
*.smod
|
||||
|
||||
# Compiled Static libraries
|
||||
*.lai
|
||||
*.la
|
||||
*.a
|
||||
*.lib
|
||||
|
||||
# Executables
|
||||
*.exe
|
||||
*.out
|
||||
*.app
|
||||
|
||||
# VSCODE
|
||||
.vscode/
|
||||
|
||||
# Jetbrains tools
|
||||
.idea/
|
||||
|
||||
# Python cache
|
||||
__pycache__/
|
||||
|
||||
# Sphinx
|
||||
_build/
|
||||
.venv
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
|
||||
|
||||
compiler/tests/TestLib/out/
|
||||
compiler/lib/Bindings/Rust/target/
|
||||
compiler/lib/Bindings/Rust/Cargo.lock
|
||||
1
compilers/concrete-compiler/README.md
Symbolic link
1
compilers/concrete-compiler/README.md
Symbolic link
@@ -0,0 +1 @@
|
||||
./compiler/README.md
|
||||
@@ -0,0 +1,49 @@
|
||||
FROM quay.io/pypa/manylinux_2_28_x86_64:2022-11-19-1b19e81
|
||||
|
||||
# epel-release is for install ccache
|
||||
# clang is needed for rust bindings
|
||||
RUN dnf install -y epel-release
|
||||
RUN dnf update -y
|
||||
RUN dnf install -y ninja-build hwloc-devel ccache clang ncurses-devel
|
||||
RUN dnf install -y openssh-clients
|
||||
RUN dnf clean all
|
||||
RUN mkdir -p ~/.ssh/ && ssh-keyscan -t ecdsa github.com >> ~/.ssh/known_hosts
|
||||
# setup ccache with an unlimited amount of files and storage
|
||||
RUN ccache -M 0
|
||||
RUN ccache -F 0
|
||||
# Install Rust
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
ENV PATH=/root/.cargo/bin:$PATH
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
# Install boost
|
||||
ADD https://boostorg.jfrog.io/artifactory/main/release/1.71.0/source/boost_1_71_0.tar.gz /boost_1_71_0.tar.gz
|
||||
RUN tar -xzvf /boost_1_71_0.tar.gz
|
||||
WORKDIR /boost_1_71_0
|
||||
RUN ./bootstrap.sh && ./b2 --with-filesystem install
|
||||
# Setup HPX
|
||||
COPY --from=ghcr.io/zama-ai/hpx:latest /hpx /hpx
|
||||
ENV HPX_INSTALL_DIR=/hpx/build
|
||||
# Setup CUDA
|
||||
COPY --from=ghcr.io/zama-ai/cuda:11-7 /usr/local/cuda-11.7/ /usr/local/cuda-11.7/
|
||||
COPY --from=ghcr.io/zama-ai/cuda:11-7 /usr/lib64/libcuda.so* /usr/lib64/
|
||||
ENV PATH "$PATH:/usr/local/cuda-11.7/bin"
|
||||
# Set the python path. Options: [cp37-cp37m, cp38-cp38, cp39-cp39, cp310-cp310]
|
||||
# Links and env would be available to use the appropriate python version
|
||||
ARG python_tag=cp38-cp38
|
||||
RUN ln -s /opt/python/${python_tag}/bin/pip /bin/pip
|
||||
RUN ln -s /opt/python/${python_tag}/bin/python /bin/python
|
||||
ENV PYTHON_EXEC=/opt/python/${python_tag}/bin/python
|
||||
# Install python deps
|
||||
RUN pip install numpy pybind11==2.8 PyYAML pytest wheel auditwheel
|
||||
# Setup LLVM
|
||||
COPY /llvm-project /llvm-project
|
||||
# Setup and build compiler
|
||||
COPY /compiler /compiler
|
||||
WORKDIR /compiler
|
||||
RUN mkdir -p /build
|
||||
RUN --mount=type=ssh make DATAFLOW_EXECUTION_ENABLED=ON BUILD_DIR=/build CCACHE=ON \
|
||||
Python3_EXECUTABLE=${PYTHON_EXEC} \
|
||||
concretecompiler python-bindings rust-bindings
|
||||
ENV PYTHONPATH "$PYTHONPATH:/build/tools/concretelang/python_packages/concretelang_core"
|
||||
ENV PATH "$PATH:/build/bin"
|
||||
RUN ccache -z
|
||||
14
compilers/concrete-compiler/builders/Dockerfile.cuda-env
Normal file
14
compilers/concrete-compiler/builders/Dockerfile.cuda-env
Normal file
@@ -0,0 +1,14 @@
|
||||
FROM quay.io/pypa/manylinux_2_28_x86_64:2022-11-19-1b19e81
|
||||
|
||||
RUN dnf install -y kernel-devel kernel-headers
|
||||
RUN curl https://developer.download.nvidia.com/compute/cuda/11.7.1/local_installers/cuda-repo-rhel8-11-7-local-11.7.1_515.65.01-1.x86_64.rpm -o cuda-repo-rhel8-11-7-local-11.7.1_515.65.01-1.x86_64.rpm
|
||||
RUN rpm -i cuda-repo-rhel8-11-7-local-11.7.1_515.65.01-1.x86_64.rpm
|
||||
RUN dnf clean all
|
||||
RUN dnf install -y epel-release
|
||||
RUN dnf update -y
|
||||
RUN dnf -y module install nvidia-driver:latest-dkms
|
||||
RUN dnf -y install cuda
|
||||
|
||||
FROM scratch
|
||||
COPY --from=0 /usr/local/cuda-11.7/ /usr/local/cuda-11.7/
|
||||
COPY --from=0 /usr/lib64/libcuda.so* /usr/lib64/
|
||||
25
compilers/concrete-compiler/builders/Dockerfile.hpx-env
Normal file
25
compilers/concrete-compiler/builders/Dockerfile.hpx-env
Normal file
@@ -0,0 +1,25 @@
|
||||
FROM quay.io/pypa/manylinux_2_28_x86_64:2022-11-19-1b19e81
|
||||
|
||||
RUN dnf update -y
|
||||
RUN dnf install -y ninja-build hwloc-devel
|
||||
# Install boost
|
||||
ADD https://boostorg.jfrog.io/artifactory/main/release/1.71.0/source/boost_1_71_0.tar.gz /boost_1_71_0.tar.gz
|
||||
RUN tar -xzvf /boost_1_71_0.tar.gz
|
||||
WORKDIR /boost_1_71_0
|
||||
RUN ./bootstrap.sh && ./b2 --with-filesystem install
|
||||
# Build HPX
|
||||
RUN git clone https://github.com/STEllAR-GROUP/hpx.git /hpx
|
||||
WORKDIR /hpx
|
||||
RUN git checkout 1.7.1
|
||||
RUN mkdir build
|
||||
# empty HPX_WITH_MAX_CPU_COUNT = dynamic
|
||||
# ref https://github.com/STEllAR-GROUP/hpx/blob/1.7.1/CMakeLists.txt#L759
|
||||
RUN cd build && cmake \
|
||||
-DHPX_WITH_MAX_CPU_COUNT="" \
|
||||
-DHPX_WITH_FETCH_ASIO=on \
|
||||
-DHPX_FILESYSTEM_WITH_BOOST_FILESYSTEM_COMPATIBILITY=ON \
|
||||
-DHPX_WITH_MALLOC=system ..
|
||||
RUN cd build && make -j2
|
||||
|
||||
FROM scratch
|
||||
COPY --from=0 /hpx/ /hpx/
|
||||
@@ -0,0 +1,2 @@
|
||||
FROM alpine:latest
|
||||
COPY KeySetCache /KeySetCache
|
||||
103
compilers/concrete-compiler/ci/benchmark_parser.py
Normal file
103
compilers/concrete-compiler/ci/benchmark_parser.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
benchmark_parser
|
||||
----------------
|
||||
|
||||
Parse benchmark raw results.
|
||||
"""
|
||||
import argparse
|
||||
import pathlib
|
||||
import json
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('results_path',
|
||||
help=('Location of raw benchmark results,'
|
||||
' could be either a file or a directory.'
|
||||
'In a case of a directory, this script will attempt to parse all the'
|
||||
'files containing a .json extension'))
|
||||
parser.add_argument('output_file', help='File storing parsed results')
|
||||
parser.add_argument('-d', '--database', dest='database', required=True,
|
||||
help='Name of the database used to store results')
|
||||
parser.add_argument('-w', '--hardware', dest='hardware', required=True,
|
||||
help='Hardware reference used to perform benchmark')
|
||||
parser.add_argument('-V', '--project-version', dest='project_version', required=True,
|
||||
help='Commit hash reference')
|
||||
parser.add_argument('-b', '--branch', dest='branch', required=True,
|
||||
help='Git branch name on which benchmark was performed')
|
||||
parser.add_argument('--commit-date', dest='commit_date', required=True,
|
||||
help='Timestamp of commit hash used in project_version')
|
||||
parser.add_argument('--bench-date', dest='bench_date', required=True,
|
||||
help='Timestamp when benchmark was run')
|
||||
|
||||
|
||||
def parse_results(raw_results):
|
||||
"""
|
||||
Parse raw benchmark results.
|
||||
|
||||
:param raw_results: path to file that contains raw results as :class:`pathlib.Path`
|
||||
|
||||
:return: :class:`list` of data points
|
||||
"""
|
||||
raw_results = json.loads(raw_results.read_text())
|
||||
return [
|
||||
{"value": res["cpu_time"], "test": res["name"]}
|
||||
for res in raw_results["benchmarks"]
|
||||
]
|
||||
|
||||
|
||||
def recursive_parse(directory):
|
||||
"""
|
||||
Parse all the benchmark results in a directory. It will attempt to parse all the files having a
|
||||
.json extension at the top-level of this directory.
|
||||
|
||||
:param directory: path to directory that contains raw results as :class:`pathlib.Path`
|
||||
|
||||
:return: :class:`list` of data points
|
||||
"""
|
||||
result_values = []
|
||||
for file in directory.glob('*.json'):
|
||||
try:
|
||||
result_values.extend(parse_results(file))
|
||||
except KeyError as err:
|
||||
print(f"Failed to parse '{file.resolve()}': {repr(err)}")
|
||||
|
||||
return result_values
|
||||
|
||||
|
||||
def dump_results(parsed_results, filename, input_args):
|
||||
"""
|
||||
Dump parsed results formatted as JSON to file.
|
||||
|
||||
:param parsed_results: :class:`list` of data points
|
||||
:param filename: filename for dump file as :class:`pathlib.Path`
|
||||
:param input_args: CLI input arguments
|
||||
"""
|
||||
filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
series = {
|
||||
"database": input_args.database,
|
||||
"hardware": input_args.hardware,
|
||||
"project_version": input_args.project_version,
|
||||
"branch": input_args.branch,
|
||||
"insert_date": input_args.bench_date,
|
||||
"commit_date": input_args.commit_date,
|
||||
"points": parsed_results,
|
||||
}
|
||||
filename.write_text(json.dumps(series))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
results_path = pathlib.Path(args.results_path)
|
||||
print("Parsing benchmark results... ")
|
||||
if results_path.is_dir():
|
||||
results = recursive_parse(results_path)
|
||||
else:
|
||||
results = parse_results(results_path)
|
||||
print("Parsing results done")
|
||||
|
||||
output_file = pathlib.Path(args.output_file)
|
||||
print(f"Dump parsed results into '{output_file.resolve()}' ... ", end="")
|
||||
dump_results(results, output_file, args)
|
||||
|
||||
print("Done")
|
||||
45
compilers/concrete-compiler/ci/slab.toml
Normal file
45
compilers/concrete-compiler/ci/slab.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
[profile.m6i]
|
||||
region = "eu-west-3"
|
||||
image_id = "ami-0a24aaee029d1295c" # Based on Ubuntu 22.4
|
||||
instance_type = "m6i.metal"
|
||||
subnet_id = "subnet-a886b4c1"
|
||||
security_group= ["sg-0bf1c1d79c97bc88f", ]
|
||||
|
||||
[profile.m6i-old]
|
||||
region = "eu-west-3"
|
||||
image_id = "ami-05e4c0e628378ad6d" # Based on Ubuntu 20.4
|
||||
instance_type = "m6i.metal"
|
||||
subnet_id = "subnet-a886b4c1"
|
||||
security_group= ["sg-0bf1c1d79c97bc88f", ]
|
||||
|
||||
[profile.gpu]
|
||||
region = "us-east-1"
|
||||
image_id = "ami-0c4773f5626d919b6"
|
||||
instance_type = "p3.2xlarge"
|
||||
subnet_id = "subnet-8123c9e7"
|
||||
security_group= ["sg-0f8b52622a2669491", ]
|
||||
|
||||
# Trigger CPU build
|
||||
[command.cpu-build]
|
||||
workflow = "aws_build_cpu.yml"
|
||||
profile = "m6i-old"
|
||||
check_run_name = "AWS CPU build (Slab)"
|
||||
|
||||
# Trigger GPU build
|
||||
[command.gpu-build]
|
||||
workflow = "aws_build_gpu.yml"
|
||||
profile = "gpu"
|
||||
check_run_name = "AWS GPU build (Slab)"
|
||||
|
||||
# Trigger Docker images build
|
||||
[command.docker-images-build]
|
||||
workflow = "publish_docker_images.yml"
|
||||
profile = "m6i-old"
|
||||
check_run_name = "AWS Docker images build & publish (Slab)"
|
||||
|
||||
# Trigger ML benchmarks by running each use cases subset in parallel.
|
||||
[command.ml-bench]
|
||||
workflow = "ml_benchmark_subset.yml"
|
||||
profile = "m6i"
|
||||
matrix = [0,1,2,3,4,5,6,7,8,9,10]
|
||||
max_parallel_jobs = 2
|
||||
1
compilers/concrete-compiler/compiler/.clang-format
Normal file
1
compilers/concrete-compiler/compiler/.clang-format
Normal file
@@ -0,0 +1 @@
|
||||
BasedOnStyle: LLVM
|
||||
18
compilers/concrete-compiler/compiler/.clang-tidy
Normal file
18
compilers/concrete-compiler/compiler/.clang-tidy
Normal file
@@ -0,0 +1,18 @@
|
||||
Checks: '-*,clang-diagnostic-*,llvm-*,misc-*,-misc-unused-parameters,-misc-non-private-member-variables-in-classes,readability-identifier-naming'
|
||||
CheckOptions:
|
||||
- key: readability-identifier-naming.ClassCase
|
||||
value: CamelCase
|
||||
- key: readability-identifier-naming.EnumCase
|
||||
value: CamelCase
|
||||
- key: readability-identifier-naming.FunctionCase
|
||||
value: camelBack
|
||||
- key: readability-identifier-naming.MemberCase
|
||||
value: camelBack
|
||||
- key: readability-identifier-naming.ParameterCase
|
||||
value: camelBack
|
||||
- key: readability-identifier-naming.UnionCase
|
||||
value: CamelCase
|
||||
- key: readability-identifier-naming.VariableCase
|
||||
value: camelBack
|
||||
- key: readability-identifier-naming.IgnoreMainLikeFunctions
|
||||
value: 1
|
||||
11
compilers/concrete-compiler/compiler/.cmake-format-config.py
Normal file
11
compilers/concrete-compiler/compiler/.cmake-format-config.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# -----------------------------
|
||||
# Options effecting formatting.
|
||||
# -----------------------------
|
||||
with section("format"):
|
||||
|
||||
# How wide to allow formatted cmake files
|
||||
line_width = 120
|
||||
|
||||
# How many spaces to tab for indent
|
||||
tab_size = 2
|
||||
|
||||
13
compilers/concrete-compiler/compiler/.gitignore
vendored
Normal file
13
compilers/concrete-compiler/compiler/.gitignore
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
# Build dirs
|
||||
build*/
|
||||
|
||||
*.mlir.script
|
||||
*.lit_test_times.txt
|
||||
|
||||
# Test-generated artifacts
|
||||
concrete-compiler_compilation_artifacts/
|
||||
py_test_lib_compile_and_run_custom_perror/
|
||||
tests/end_to_end_fixture/end_to_end_linalg_apply_lookup_table.yaml
|
||||
tests/end_to_end_fixture/end_to_end_linalg_leveled.yaml
|
||||
tests/end_to_end_fixture/end_to_end_linalg_2_apply_lookup_table.yaml
|
||||
tests/end_to_end_fixture/bug_report.yaml
|
||||
162
compilers/concrete-compiler/compiler/CMakeLists.txt
Normal file
162
compilers/concrete-compiler/compiler/CMakeLists.txt
Normal file
@@ -0,0 +1,162 @@
|
||||
cmake_minimum_required(VERSION 3.17)
|
||||
|
||||
project(concretecompiler LANGUAGES C CXX)
|
||||
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 14)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
# Needed on linux with clang 15 and on MacOS because cxx emits dollars in the optimizer C++ API
|
||||
add_definitions("-Wno-dollar-in-identifier-extension")
|
||||
|
||||
add_definitions("-Wall ")
|
||||
add_definitions("-Werror ")
|
||||
add_definitions("-Wfatal-errors")
|
||||
|
||||
# If we are trying to build the compiler with LLVM/MLIR as libraries
|
||||
if(NOT DEFINED LLVM_EXTERNAL_CONCRETELANG_SOURCE_DIR)
|
||||
message(FATAL_ERROR "Concrete compiler requires a unified build with LLVM/MLIR")
|
||||
endif()
|
||||
|
||||
# CMake library generation settings.
|
||||
set(BUILD_SHARED_LIBS
|
||||
OFF
|
||||
CACHE BOOL "Default to building a static mondo-lib")
|
||||
set(CMAKE_PLATFORM_NO_VERSIONED_SONAME
|
||||
ON
|
||||
CACHE BOOL "Python soname linked libraries are bad")
|
||||
set(CMAKE_VISIBILITY_INLINES_HIDDEN
|
||||
ON
|
||||
CACHE BOOL "Hide inlines")
|
||||
|
||||
# The -fvisibility=hidden option only works for static builds.
|
||||
if(BUILD_SHARED_LIBS AND (CMAKE_CXX_VISIBILITY_PRESET STREQUAL "hidden"))
|
||||
message(FATAL_ERROR "CMAKE_CXX_VISIBILITY_PRESET=hidden is incompatible \
|
||||
with BUILD_SHARED_LIBS.")
|
||||
endif()
|
||||
|
||||
set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir) # --src-root
|
||||
set(MLIR_INCLUDE_DIR ${MLIR_MAIN_SRC_DIR}/include) # --includedir
|
||||
set(MLIR_TABLEGEN_OUTPUT_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
|
||||
set(MLIR_TABLEGEN_EXE $<TARGET_FILE:mlir-tblgen>)
|
||||
include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
|
||||
include_directories(SYSTEM ${MLIR_TABLEGEN_OUTPUT_DIR})
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${MLIR_MAIN_SRC_DIR}/cmake/modules")
|
||||
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
include_directories(${PROJECT_BINARY_DIR}/include)
|
||||
link_directories(${LLVM_BUILD_LIBRARY_DIR})
|
||||
add_definitions(${LLVM_DEFINITIONS})
|
||||
|
||||
# Custom doc generation function
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
|
||||
include(AddConcretelangDoc)
|
||||
set(CONCRETELANG_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
# Concrete Security curves Configuration
|
||||
# -------------------------------------------------------------------------------
|
||||
include_directories(${PROJECT_SOURCE_DIR}/parameter-curves/concrete-security-curves-cpp/include)
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
# Concrete CPU Configuration
|
||||
# -------------------------------------------------------------------------------
|
||||
set(CONCRETE_CPU_STATIC_LIB "${PROJECT_SOURCE_DIR}/concrete-cpu/target/release/libconcrete_cpu.a")
|
||||
ExternalProject_Add(
|
||||
concrete_cpu_rust
|
||||
DOWNLOAD_COMMAND ""
|
||||
CONFIGURE_COMMAND "" OUTPUT "${CONCRETE_CPU_STATIC_LIB}"
|
||||
BUILD_COMMAND cargo build
|
||||
COMMAND cargo build --release
|
||||
BINARY_DIR "${PROJECT_SOURCE_DIR}/concrete-cpu"
|
||||
INSTALL_COMMAND ""
|
||||
LOG_BUILD ON)
|
||||
add_library(concrete_cpu STATIC IMPORTED)
|
||||
# TODO - Change that to a location in the release dir
|
||||
set(CONCRETE_CPU_INCLUDE_DIR "${PROJECT_SOURCE_DIR}/concrete-cpu/concrete-cpu")
|
||||
set_target_properties(concrete_cpu PROPERTIES IMPORTED_LOCATION "${CONCRETE_CPU_STATIC_LIB}")
|
||||
add_dependencies(concrete_cpu concrete_cpu_rust)
|
||||
|
||||
# --------------------------------------------------------------------------------
|
||||
# Concrete Cuda Configuration
|
||||
# --------------------------------------------------------------------------------
|
||||
option(CONCRETELANG_CUDA_SUPPORT "Support Concrete CUDA Execution." OFF)
|
||||
|
||||
if(CONCRETELANG_CUDA_SUPPORT)
|
||||
if(NOT DEFINED CONCRETE_CORE_PATH)
|
||||
message(FATAL_ERROR "Compiling with CUDA support requires setting CONCRETE_CORE_PATH")
|
||||
endif()
|
||||
remove_definitions("-Werror ")
|
||||
message(STATUS "Building with Concrete CUDA execution support")
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
message(STATUS "Found CUDA version: ${CUDAToolkit_VERSION}")
|
||||
message(STATUS "Found CUDA library dir: ${CUDAToolkit_LIBRARY_DIR}")
|
||||
link_directories(${CUDAToolkit_LIBRARY_DIR})
|
||||
add_subdirectory(${CONCRETE_CORE_PATH}/concrete-cuda/cuda)
|
||||
include_directories(${CONCRETE_CORE_PATH}/concrete-cuda/cuda/include)
|
||||
include_directories(${CUDAToolkit_INCLUDE_DIRS})
|
||||
add_compile_options(-DCONCRETELANG_CUDA_SUPPORT)
|
||||
endif()
|
||||
|
||||
# --------------------------------------------------------------------------------
|
||||
# Python Configuration
|
||||
# -------------------------------------------------------------------------------
|
||||
option(CONCRETELANG_BINDINGS_PYTHON_ENABLED "Enables ConcreteLang Python bindings." ON)
|
||||
|
||||
if(CONCRETELANG_BINDINGS_PYTHON_ENABLED)
|
||||
message(STATUS "ConcreteLang Python bindings are enabled.")
|
||||
|
||||
include(MLIRDetectPythonEnv)
|
||||
mlir_configure_python_dev_packages()
|
||||
set(CONCRETELANG_PYTHON_PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/python_packages)
|
||||
else()
|
||||
message(STATUS "ConcreteLang Python bindings are disabled.")
|
||||
endif()
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
# DFR - parallel execution configuration
|
||||
# -------------------------------------------------------------------------------
|
||||
option(CONCRETELANG_DATAFLOW_EXECUTION_ENABLED "Enables dataflow execution for ConcreteLang." ON)
|
||||
option(CONCRETELANG_TIMING_ENABLED "Enables execution timing." ON)
|
||||
|
||||
if(CONCRETELANG_DATAFLOW_EXECUTION_ENABLED)
|
||||
message(STATUS "ConcreteLang dataflow execution enabled.")
|
||||
|
||||
find_package(HPX REQUIRED CONFIG)
|
||||
list(APPEND CMAKE_MODULE_PATH "${HPX_CMAKE_DIR}")
|
||||
add_compile_options(-DCONCRETELANG_DATAFLOW_EXECUTION_ENABLED
|
||||
-DHPX_DEFAULT_CONFIG_FILE="${PROJECT_SOURCE_DIR}/hpx.ini")
|
||||
|
||||
else()
|
||||
message(STATUS "ConcreteLang dataflow execution disabled.")
|
||||
endif()
|
||||
|
||||
if(CONCRETELANG_TIMING_ENABLED)
|
||||
add_compile_options(-DCONCRETELANG_TIMING_ENABLED)
|
||||
else()
|
||||
message(STATUS "ConcreteLang execution timing disabled.")
|
||||
endif()
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
# Unit tests
|
||||
# -------------------------------------------------------------------------------
|
||||
option(CONCRETELANG_UNIT_TESTS "Enables the build of unittests" ON)
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
# Benchmarks
|
||||
# -------------------------------------------------------------------------------
|
||||
option(CONCRETELANG_BENCHMARK "Enables the build of benchmarks" ON)
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
# Handling sub dirs
|
||||
# -------------------------------------------------------------------------------
|
||||
include_directories(${CONCRETE_OPTIMIZER_DIR}/concrete-optimizer-cpp/src/cpp)
|
||||
|
||||
add_subdirectory(include)
|
||||
add_subdirectory(lib)
|
||||
add_subdirectory(src)
|
||||
add_subdirectory(tests)
|
||||
|
||||
add_subdirectory(${CONCRETE_OPTIMIZER_DIR}/concrete-optimizer-cpp/cmake-utils)
|
||||
523
compilers/concrete-compiler/compiler/Makefile
Normal file
523
compilers/concrete-compiler/compiler/Makefile
Normal file
@@ -0,0 +1,523 @@
|
||||
BUILD_TYPE?=Release
|
||||
BUILD_DIR?=./build
|
||||
Python3_EXECUTABLE?=$(shell which python3)
|
||||
BINDINGS_PYTHON_ENABLED=ON
|
||||
DATAFLOW_EXECUTION_ENABLED=OFF
|
||||
TIMING_ENABLED=OFF
|
||||
CC_COMPILER=
|
||||
CXX_COMPILER=
|
||||
CUDA_SUPPORT?=OFF
|
||||
CONCRETE_CORE_PATH?= $(shell pwd)/concrete-core
|
||||
INSTALL_PREFIX?=$(abspath $(BUILD_DIR))/install
|
||||
INSTALL_PATH=$(abspath $(INSTALL_PREFIX))/concretecompiler/
|
||||
MAKEFILE_ROOT_DIR=$(shell pwd)
|
||||
|
||||
CONCRETE_OPTIMIZER_DIR ?= $(shell pwd)/concrete-optimizer
|
||||
|
||||
KEYSETCACHEDEV=/tmp/KeySetCache
|
||||
KEYSETCACHECI ?= ../KeySetCache
|
||||
KEYSETCACHENAME ?= KeySetCacheV4
|
||||
|
||||
HPX_VERSION?=1.7.1
|
||||
HPX_URL=https://github.com/STEllAR-GROUP/hpx/archive/refs/tags/$(HPX_VERSION).tar.gz
|
||||
HPX_TARBALL=$(shell pwd)/hpx-$(HPX_VERSION).tar.gz
|
||||
HPX_LOCAL_DIR=$(shell pwd)/hpx-$(HPX_VERSION)
|
||||
HPX_INSTALL_DIR?=$(HPX_LOCAL_DIR)/build
|
||||
|
||||
ML_BENCH_SUBSET_ID=
|
||||
|
||||
# Find OS
|
||||
OS=undefined
|
||||
ifeq ($(shell uname), Linux)
|
||||
OS=linux
|
||||
else ifeq ($(shell uname), Darwin)
|
||||
OS=darwin
|
||||
endif
|
||||
|
||||
# Setup find arguments for MacOS
|
||||
ifeq ($(OS), darwin)
|
||||
FIND_EXECUTABLE_ARG=-perm +111
|
||||
else
|
||||
FIND_EXECUTABLE_ARG=-executable
|
||||
endif
|
||||
|
||||
ARCHITECTURE=undefined
|
||||
ifeq ($(shell uname -m), arm64)
|
||||
ARCHITECTURE=aarch64
|
||||
else
|
||||
ARCHITECTURE=amd64
|
||||
endif
|
||||
|
||||
export PATH := $(abspath $(BUILD_DIR))/bin:$(PATH)
|
||||
|
||||
ifeq ($(shell which ccache),)
|
||||
CCACHE=OFF
|
||||
else
|
||||
CCACHE=ON
|
||||
endif
|
||||
|
||||
ifeq ($(CCACHE),ON)
|
||||
CMAKE_CCACHE_OPTIONS=-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
else
|
||||
CMAKE_CCACHE_OPTIONS=
|
||||
endif
|
||||
|
||||
ifneq ($(CC_COMPILER),)
|
||||
CC_COMPILER_OPTION=-DCMAKE_C_COMPILER=$(CC_COMPILER)
|
||||
else
|
||||
CC_COMPILER_OPTION=
|
||||
endif
|
||||
|
||||
ifneq ($(CXX_COMPILER),)
|
||||
CXX_COMPILER_OPTION=-DCMAKE_CXX_COMPILER=$(CXX_COMPILER)
|
||||
else
|
||||
CXX_COMPILER_OPTION=
|
||||
endif
|
||||
|
||||
# don't run parallel python tests if compiler doesn't support it
|
||||
ifeq ($(DATAFLOW_EXECUTION_ENABLED),ON)
|
||||
PYTHON_TESTS_MARKER=""
|
||||
else
|
||||
PYTHON_TESTS_MARKER="not parallel"
|
||||
endif
|
||||
|
||||
all: concretecompiler python-bindings build-tests build-benchmarks build-mlbench doc rust-bindings
|
||||
|
||||
# concrete-optimizer ######################################
|
||||
|
||||
LIB_CONCRETE_OPTIMIZER_CPP = $(CONCRETE_OPTIMIZER_DIR)/target/libconcrete_optimizer_cpp.a
|
||||
|
||||
concrete-optimizer-lib:
|
||||
make -C $(CONCRETE_OPTIMIZER_DIR)/concrete-optimizer-cpp $(LIB_CONCRETE_OPTIMIZER_CPP)
|
||||
|
||||
# HPX #####################################################
|
||||
|
||||
install-hpx-from-source: $(HPX_LOCAL_DIR)
|
||||
mkdir -p $(HPX_LOCAL_DIR)/build
|
||||
cd $(HPX_LOCAL_DIR)/build && cmake \
|
||||
-DHPX_WITH_MAX_CPU_COUNT="" \
|
||||
-DHPX_WITH_FETCH_ASIO=on \
|
||||
-DHPX_FILESYSTEM_WITH_BOOST_FILESYSTEM_COMPATIBILITY=ON \
|
||||
-DHPX_WITH_MALLOC=system ..
|
||||
cd $(HPX_LOCAL_DIR)/build && make -j2
|
||||
|
||||
$(HPX_TARBALL):
|
||||
curl -L $(HPX_URL) -o $(HPX_TARBALL)
|
||||
|
||||
$(HPX_LOCAL_DIR): $(HPX_TARBALL)
|
||||
tar xzvf $(HPX_TARBALL)
|
||||
|
||||
$(BUILD_DIR)/configured.stamp:
|
||||
mkdir -p $(BUILD_DIR)
|
||||
cmake -B $(BUILD_DIR) -GNinja ../llvm-project/llvm/ \
|
||||
$(CMAKE_CCACHE_OPTIONS) \
|
||||
$(CC_COMPILER_OPTION) \
|
||||
$(CXX_COMPILER_OPTION) \
|
||||
-DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" \
|
||||
-DLLVM_BUILD_EXAMPLES=OFF \
|
||||
-DLLVM_TARGETS_TO_BUILD="host" \
|
||||
-DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DMLIR_ENABLE_BINDINGS_PYTHON=$(BINDINGS_PYTHON_ENABLED) \
|
||||
-DCONCRETELANG_BINDINGS_PYTHON_ENABLED=$(BINDINGS_PYTHON_ENABLED) \
|
||||
-DCONCRETELANG_DATAFLOW_EXECUTION_ENABLED=$(DATAFLOW_EXECUTION_ENABLED) \
|
||||
-DCONCRETELANG_TIMING_ENABLED=$(TIMING_ENABLED) \
|
||||
-DHPX_DIR=${HPX_INSTALL_DIR}/lib/cmake/HPX \
|
||||
-DLLVM_EXTERNAL_PROJECTS=concretelang \
|
||||
-DLLVM_EXTERNAL_CONCRETELANG_SOURCE_DIR=. \
|
||||
-DPython3_EXECUTABLE=${Python3_EXECUTABLE} \
|
||||
-DCONCRETE_OPTIMIZER_DIR=${CONCRETE_OPTIMIZER_DIR} \
|
||||
-DCONCRETE_CORE_PATH=$(CONCRETE_CORE_PATH) \
|
||||
-DCONCRETELANG_CUDA_SUPPORT=${CUDA_SUPPORT} \
|
||||
-DCUDAToolkit_ROOT=$(CUDA_PATH)
|
||||
touch $@
|
||||
|
||||
build-initialized: concrete-optimizer-lib $(BUILD_DIR)/configured.stamp
|
||||
|
||||
doc: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target mlir-doc
|
||||
|
||||
concretecompiler: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target concretecompiler
|
||||
|
||||
python-bindings: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target ConcretelangMLIRPythonModules
|
||||
cmake --build $(BUILD_DIR) --target ConcretelangPythonModules
|
||||
|
||||
rust-bindings: install
|
||||
cd lib/Bindings/Rust && \
|
||||
CONCRETE_COMPILER_INSTALL_DIR=$(INSTALL_PATH) \
|
||||
cargo build --release
|
||||
|
||||
CAPI:
|
||||
cmake --build $(BUILD_DIR) --target CONCRETELANGCAPIFHE CONCRETELANGCAPIFHELINALG CONCRETELANGCAPISupport
|
||||
|
||||
clientlib: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target ConcretelangClientLib
|
||||
|
||||
serverlib: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target ConcretelangServerLib
|
||||
|
||||
|
||||
|
||||
GITHUB_URL=https://api.github.com/repos/zama-ai/concrete-compiler-internal
|
||||
GITHUB_URL_LIST_ARTIFACTS="${GITHUB_URL}/actions/artifacts?name=${KEYSETCACHENAME}&per_page=1"
|
||||
CURL=curl -H"Accept: application/vnd.github.v3+json" -H"authorization: Bearer ${GITHUB_TOKEN}"
|
||||
keysetcache.zip: REDIRECT_URL = $(shell ${CURL} -s ${GITHUB_URL_LIST_ARTIFACTS} | grep archive_download_url | grep -o 'http[^"]\+')
|
||||
keysetcache.zip:
|
||||
${CURL} --location -o keysetcache.zip ${REDIRECT_URL}
|
||||
du -h keysetcache.zip
|
||||
|
||||
keysetcache_ci_populated: keysetcache.zip
|
||||
unzip keysetcache.zip -d ${KEYSETCACHECI}
|
||||
du -sh ${KEYSETCACHECI}
|
||||
rm keysetcache.zip
|
||||
|
||||
keysetcache_populated: keysetcache.zip
|
||||
unzip keysetcache.zip -d ${KEYSETCACHEDEV}
|
||||
du -sh ${KEYSETCACHEDEV}
|
||||
rm keysetcache.zip
|
||||
|
||||
|
||||
# test
|
||||
|
||||
build-tests: build-unit-tests build-end-to-end-tests
|
||||
|
||||
run-tests: run-check-tests run-unit-tests run-end-to-end-tests run-python-tests
|
||||
|
||||
## check-tests
|
||||
|
||||
run-check-tests: concretecompiler file-check not
|
||||
$(BUILD_DIR)/bin/llvm-lit -v tests/check_tests
|
||||
|
||||
## unit-tests
|
||||
|
||||
build-unit-tests: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target ConcretelangUnitTests
|
||||
|
||||
run-unit-tests: build-unit-tests
|
||||
find $(BUILD_DIR)/tools/concretelang/tests/unit_tests -name unit_tests_concretelang* $(FIND_EXECUTABLE_ARG) -type f | xargs -n1 ./run_test_bin.sh
|
||||
|
||||
## python-tests
|
||||
|
||||
run-python-tests: python-bindings concretecompiler
|
||||
PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/concretelang/python_packages/concretelang_core LD_PRELOAD=$(BUILD_DIR)/lib/libConcretelangRuntime.so pytest -vs -m $(PYTHON_TESTS_MARKER) tests/python
|
||||
|
||||
test-compiler-file-output: concretecompiler
|
||||
pytest -vs tests/test_compiler_file_output
|
||||
|
||||
|
||||
## rust-tests
|
||||
run-rust-tests: rust-bindings
|
||||
cd lib/Bindings/Rust && \
|
||||
CONCRETE_COMPILER_INSTALL_DIR=$(INSTALL_PATH) \
|
||||
LD_LIBRARY_PATH=$(INSTALL_PATH)/lib \
|
||||
cargo test --release
|
||||
|
||||
## end-to-end-tests
|
||||
|
||||
build-end-to-end-jit-chunked-int: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_chunked_int
|
||||
|
||||
build-end-to-end-jit-test: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_test
|
||||
|
||||
build-end-to-end-test: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_test
|
||||
|
||||
build-end-to-end-jit-encrypted-tensor: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_encrypted_tensor
|
||||
|
||||
build-end-to-end-jit-fhelinalg: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_fhelinalg
|
||||
|
||||
build-end-to-end-jit-lambda: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_lambda
|
||||
|
||||
build-end-to-end-tests: build-end-to-end-jit-chunked-int build-end-to-end-jit-test build-end-to-end-test build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-fhelinalg build-end-to-end-jit-lambda
|
||||
|
||||
### end-to-end-tests CPU
|
||||
|
||||
FIXTURE_CPU_DIR=tests/end_to_end_fixture/tests_cpu
|
||||
|
||||
$(FIXTURE_CPU_DIR)/%.yaml: tests/end_to_end_fixture/%_gen.py
|
||||
mkdir -p $(FIXTURE_CPU_DIR)
|
||||
$(Python3_EXECUTABLE) $< > $@
|
||||
|
||||
$(FIXTURE_CPU_DIR)/bug_report.yaml:
|
||||
unzip -o $(FIXTURE_CPU_DIR)/bug_report.zip -d $(FIXTURE_CPU_DIR)
|
||||
|
||||
generate-cpu-tests: $(FIXTURE_CPU_DIR)/end_to_end_leveled.yaml $(FIXTURE_CPU_DIR)/end_to_end_apply_lookup_table.yaml $(FIXTURE_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml $(FIXTURE_CPU_DIR)/bug_report.yaml $(FIXTURE_CPU_DIR)/end_to_end_round.yaml
|
||||
|
||||
SECURITY_TO_TEST=80 128
|
||||
run-end-to-end-tests: build-end-to-end-tests generate-cpu-tests
|
||||
$(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/end_to_end_jit_test
|
||||
$(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/end_to_end_jit_encrypted_tensor
|
||||
$(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/end_to_end_jit_fhelinalg
|
||||
$(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/end_to_end_jit_lambda
|
||||
$(foreach security,$(SECURITY_TO_TEST),$(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/end_to_end_test \
|
||||
--backend=cpu --security-level=$(security) --jit $(FIXTURE_CPU_DIR)/*.yaml;)
|
||||
|
||||
### end-to-end-tests GPU
|
||||
|
||||
FIXTURE_GPU_DIR=tests/end_to_end_fixture/tests_gpu
|
||||
|
||||
$(FIXTURE_GPU_DIR):
|
||||
mkdir -p $(FIXTURE_GPU_DIR)
|
||||
|
||||
$(FIXTURE_GPU_DIR)/end_to_end_apply_lookup_table.yaml: tests/end_to_end_fixture/end_to_end_apply_lookup_table_gen.py
|
||||
$(Python3_EXECUTABLE) $< --bitwidth 1 2 3 4 5 6 7 > $@
|
||||
|
||||
$(FIXTURE_GPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml: tests/end_to_end_fixture/end_to_end_linalg_apply_lookup_table_gen.py
|
||||
$(Python3_EXECUTABLE) $< --bitwidth 1 2 3 4 5 6 7 > $@
|
||||
|
||||
|
||||
generate-gpu-tests: $(FIXTURE_GPU_DIR) $(FIXTURE_GPU_DIR)/end_to_end_apply_lookup_table.yaml $(FIXTURE_GPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml
|
||||
|
||||
run-end-to-end-tests-gpu: build-end-to-end-test generate-gpu-tests
|
||||
$(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/end_to_end_test \
|
||||
--backend=gpu --library /tmp/concrete_compiler/gpu_tests/ \
|
||||
$(FIXTURE_GPU_DIR)/*.yaml
|
||||
|
||||
## end-to-end-dataflow-tests
|
||||
|
||||
build-end-to-end-dataflow-tests: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_auto_parallelization
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_distributed
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_aes_short
|
||||
|
||||
run-end-to-end-dataflow-tests: build-end-to-end-dataflow-tests
|
||||
$(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/end_to_end_jit_auto_parallelization
|
||||
$(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/end_to_end_jit_distributed
|
||||
|
||||
# benchmark
|
||||
|
||||
build-benchmarks: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_benchmark
|
||||
|
||||
## benchmark CPU
|
||||
|
||||
BENCHMARK_CPU_DIR=tests/end_to_end_fixture/benchmarks_cpu
|
||||
|
||||
$(BENCHMARK_CPU_DIR):
|
||||
mkdir -p $@
|
||||
|
||||
$(BENCHMARK_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml: tests/end_to_end_fixture/end_to_end_linalg_apply_lookup_table_gen.py
|
||||
$(Python3_EXECUTABLE) $< --n-ct 64 128 1024 > $@
|
||||
|
||||
$(BENCHMARK_CPU_DIR)/%.yaml: tests/end_to_end_fixture/%_gen.py
|
||||
$(Python3_EXECUTABLE) $< > $@
|
||||
|
||||
generate-cpu-benchmarks: $(BENCHMARK_CPU_DIR) $(BENCHMARK_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml $(BENCHMARK_CPU_DIR)/end_to_end_apply_lookup_table.yaml
|
||||
|
||||
SECURITY_TO_BENCH=128
|
||||
run-cpu-benchmarks: build-benchmarks generate-cpu-benchmarks
|
||||
$(foreach security,$(SECURITY_TO_BENCH),$(BUILD_DIR)/bin/end_to_end_benchmark \
|
||||
--backend=cpu --security-level=$(security)\
|
||||
--benchmark_out=benchmarks_results.json --benchmark_out_format=json \
|
||||
$(BENCHMARK_CPU_DIR)/*.yaml;)
|
||||
|
||||
FIXTURE_APPLICATION_DIR=tests/end_to_end_fixture/application/
|
||||
|
||||
run-cpu-benchmarks-application:
|
||||
unzip $(FIXTURE_APPLICATION_DIR)/*.zip -d $(FIXTURE_APPLICATION_DIR)
|
||||
$(BUILD_DIR)/bin/end_to_end_benchmark \
|
||||
--backend=cpu --benchmark_out=benchmarks_results.json --benchmark_out_format=json \
|
||||
$(FIXTURE_APPLICATION_DIR)*.yaml
|
||||
|
||||
## benchmark GPU
|
||||
|
||||
BENCHMARK_GPU_DIR=tests/end_to_end_fixture/benchmarks_gpu
|
||||
|
||||
$(BENCHMARK_GPU_DIR):
|
||||
mkdir -p $@
|
||||
|
||||
$(BENCHMARK_GPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml: tests/end_to_end_fixture/end_to_end_linalg_apply_lookup_table_gen.py
|
||||
$(Python3_EXECUTABLE) $< \
|
||||
--bitwidth 1 2 3 4 5 6 7 --n-ct 1 128 1024 2048 8192
|
||||
|
||||
|
||||
generate-gpu-benchmarks: $(BENCHMARK_GPU_DIR) $(BENCHMARK_GPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml
|
||||
|
||||
run-gpu-benchmarks: build-benchmarks generate-cpu-benchmarks
|
||||
$(BUILD_DIR)/bin/end_to_end_benchmark \
|
||||
--backend=gpu \
|
||||
--benchmark_out=benchmarks_results.json --benchmark_out_format=json \
|
||||
$(BENCHMARK_CPU_DIR)/*.yaml
|
||||
|
||||
|
||||
|
||||
build-mlbench: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_mlbench
|
||||
|
||||
generate-mlbench:
|
||||
mkdir -p tests/end_to_end_benchmarks/mlbench
|
||||
rm -rf tests/end_to_end_benchmarks/mlbench/*
|
||||
unzip tests/end_to_end_benchmarks/mlbench.zip -d tests/end_to_end_benchmarks/mlbench
|
||||
rm -f tests/end_to_end_benchmarks/mlbench/**/*\=*
|
||||
find tests/end_to_end_benchmarks/mlbench -name "*.mlir" -exec sed -e '1d' -e 's/ func / func.func /g' -e 's/ linalg.tensor_/ tensor./g' -e '$$d' -i {} \;
|
||||
$(Python3_EXECUTABLE) tests/end_to_end_benchmarks/generate_bench_yaml.py tests/end_to_end_benchmarks/mlbench tests/end_to_end_benchmarks/mlbench/end_to_end_mlbench
|
||||
|
||||
run-mlbench: build-mlbench generate-mlbench
|
||||
tests/end_to_end_benchmarks/end_to_end_mlbench.sh tests/end_to_end_benchmarks/mlbench/ $(BUILD_DIR)/bin/end_to_end_mlbench
|
||||
|
||||
run-mlbench-subset: build-mlbench generate-mlbench
|
||||
@[ "${ML_BENCH_SUBSET_ID}" ] || ( echo "ML_BENCH_SUBSET_ID is not set"; exit 1 )
|
||||
tests/end_to_end_benchmarks/end_to_end_mlbench.sh tests/end_to_end_benchmarks/mlbench/end_to_end_mlbench_$(ML_BENCH_SUBSET_ID).yaml $(BUILD_DIR)/bin/end_to_end_mlbench
|
||||
|
||||
show-stress-tests-summary:
|
||||
@echo '------ Stress tests summary ------'
|
||||
@echo
|
||||
@echo 'Rates:'
|
||||
@cd tests/stress_tests/trace && grep success_rate -R
|
||||
@echo
|
||||
@echo 'Parameters issues:'
|
||||
@cd tests/stress_tests/trace && grep BAD -R || echo 'No issues'
|
||||
|
||||
stress-tests: concretecompiler
|
||||
pytest -vs tests/stress_tests
|
||||
|
||||
# useful for faster cache generation, need pytest-parallel
|
||||
stress-tests-fast-cache: concretecompiler
|
||||
pytest --workers auto -vs tests/stress_tests
|
||||
|
||||
# LLVM/MLIR dependencies
|
||||
|
||||
all-deps: file-check not
|
||||
|
||||
file-check: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target FileCheck
|
||||
|
||||
not: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target not
|
||||
|
||||
mlir-cpu-runner: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target mlir-cpu-runner
|
||||
|
||||
opt: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target opt
|
||||
|
||||
mlir-opt: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target mlir-opt
|
||||
|
||||
mlir-translate: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target mlir-translate
|
||||
|
||||
update-python-version:
|
||||
echo "__version__ = \"`git describe --tags --abbrev=0 | grep -e '[0-9].*' -o`\"" > lib/Bindings/Python/version.txt
|
||||
|
||||
check-python-format:
|
||||
black --check tests/python/ lib/Bindings/Python/concrete/
|
||||
|
||||
python-format:
|
||||
black tests/python/ lib/Bindings/Python/concrete/
|
||||
|
||||
python-lint:
|
||||
pylint --rcfile=../pylintrc lib/Bindings/Python/concrete/compiler
|
||||
|
||||
check-rust-format:
|
||||
cd lib/Bindings/Rust && cargo fmt --check
|
||||
|
||||
rust-format:
|
||||
cd lib/Bindings/Rust && cargo fmt
|
||||
|
||||
# libraries we want to have in the installation that aren't already a deps of other targets
|
||||
install-deps:
|
||||
cmake --build $(BUILD_DIR) --target MLIRCAPIRegistration
|
||||
|
||||
ifeq ($(OS), darwin)
|
||||
# rsync should normally come pre-installed on macOS
|
||||
# and the --parents only exists for GNU's cp not BSD's cp
|
||||
HIERARCHY_PRESERVING_COPY=rsync -R
|
||||
else
|
||||
HIERARCHY_PRESERVING_COPY=cp --parents
|
||||
endif
|
||||
|
||||
ifeq ($(OS),Windows_NT)
|
||||
detected_OS := Windows
|
||||
else
|
||||
detected_OS := $(shell sh -c 'uname 2>/dev/null || echo Unknown')
|
||||
endif
|
||||
|
||||
PIP=$(Python3_EXECUTABLE) -m pip
|
||||
PIP_WHEEL=$(PIP) wheel --no-deps -w $(BUILD_DIR)/wheels .
|
||||
AUDIT_WHEEL_REPAIR=$(Python3_EXECUTABLE) -m auditwheel repair -w $(BUILD_DIR)/wheels
|
||||
|
||||
linux-python-package:
|
||||
$(PIP) install wheel auditwheel
|
||||
# We need to run it twice: the first will generate the directories, so that
|
||||
# the second run can find the packages via find_namespace_packages
|
||||
$(PIP_WHEEL)
|
||||
$(PIP_WHEEL)
|
||||
GLIBC_VER=$(shell ldd --version | head -n 1 | grep -o '[^ ]*$$'|head|tr '.' '_'); \
|
||||
for PLATFORM in manylinux_$${GLIBC_VER}_x86_64 linux_x86_64; do \
|
||||
if $(AUDIT_WHEEL_REPAIR) $(BUILD_DIR)/wheels/*.whl --plat $$PLATFORM; then \
|
||||
echo Success for $$PLATFORM; \
|
||||
break; \
|
||||
else \
|
||||
echo No repair with $$PLATFORM; \
|
||||
fi \
|
||||
done
|
||||
|
||||
darwin-python-package:
|
||||
$(PIP) install wheel delocate
|
||||
$(PIP_WHEEL)
|
||||
delocate-wheel -v $(BUILD_DIR)/wheels/*macosx*.whl
|
||||
|
||||
python-package: python-bindings $(OS)-python-package
|
||||
@echo The python package is: $(BUILD_DIR)/wheels/*.whl
|
||||
|
||||
install: concretecompiler concrete-optimizer-lib CAPI install-deps
|
||||
$(info Install prefix set to $(INSTALL_PREFIX))
|
||||
$(info Installing under $(INSTALL_PATH))
|
||||
mkdir -p $(INSTALL_PATH)/include
|
||||
cp -R $(abspath $(BUILD_DIR))/bin $(INSTALL_PATH)
|
||||
cp -R $(abspath $(BUILD_DIR))/lib $(INSTALL_PATH)
|
||||
cp $(LIB_CONCRETE_OPTIMIZER_CPP) $(INSTALL_PATH)/lib/
|
||||
cp $(CONCRETE_OPTIMIZER_DIR)/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp $(INSTALL_PATH)/include
|
||||
|
||||
# Doing find + grep + while loop is a way to have portable behaviour between macOS and GNU/Linux
|
||||
# as with `find . -regex "regex"`, the regex language is not the same / to have the same language, the
|
||||
# command changes (macOs: `find -E . -regex`, GNU: `find . -regextype posix-extended "regex")
|
||||
cd $(MAKEFILE_ROOT_DIR)/include && \
|
||||
find . | \
|
||||
grep "^.*\.\(h\|hpp\|td\)$$" | \
|
||||
while read filepath; do $(HIERARCHY_PRESERVING_COPY) $$filepath $(INSTALL_PATH)/include; done
|
||||
cd $(MAKEFILE_ROOT_DIR)/../llvm-project/llvm/include && \
|
||||
find . | \
|
||||
grep "^.*\.\(h\|hpp\|td\)$$" | \
|
||||
while read filepath; do $(HIERARCHY_PRESERVING_COPY) $$filepath $(INSTALL_PATH)/include; done
|
||||
cd $(MAKEFILE_ROOT_DIR)/../llvm-project/mlir/include && \
|
||||
find . | \
|
||||
grep "^.*\.\(h\|hpp\|td\)$$" | \
|
||||
while read filepath; do $(HIERARCHY_PRESERVING_COPY) $$filepath $(INSTALL_PATH)/include; done
|
||||
|
||||
cd $(abspath $(BUILD_DIR))/include && find . -iname '*.inc' -exec $(HIERARCHY_PRESERVING_COPY) {} $(INSTALL_PATH)/include \;
|
||||
cd $(abspath $(BUILD_DIR))/tools/concretelang/include && find . -iname '*.inc' -exec $(HIERARCHY_PRESERVING_COPY) {} $(INSTALL_PATH)/include \;
|
||||
cd $(abspath $(BUILD_DIR))/tools/mlir/include && find . -iname '*.inc' -exec $(HIERARCHY_PRESERVING_COPY) {} $(INSTALL_PATH)/include \;
|
||||
|
||||
.PHONY: build-initialized \
|
||||
build-end-to-end-jit \
|
||||
concretecompiler \
|
||||
python-bindings \
|
||||
add-deps \
|
||||
file-check \
|
||||
not \
|
||||
update-python-version \
|
||||
python-lint \
|
||||
python-format \
|
||||
check-python-format \
|
||||
concrete-optimizer-lib \
|
||||
build-tests \
|
||||
run-tests \
|
||||
run-check-tests \
|
||||
build-unit-tests \
|
||||
run-unit-tests \
|
||||
run-python-tests \
|
||||
build-end-to-end-tests \
|
||||
build-end-to-end-dataflow-tests \
|
||||
run-end-to-end-dataflow-tests \
|
||||
opt \
|
||||
mlir-opt \
|
||||
mlir-cpu-runner \
|
||||
mlir-translate
|
||||
142
compilers/concrete-compiler/compiler/README.md
Normal file
142
compilers/concrete-compiler/compiler/README.md
Normal file
@@ -0,0 +1,142 @@
|
||||
# Concrete Compiler
|
||||
|
||||
The Concrete Compiler is a set of tools that allows the compilation and from an high-level and crypto free representation of an arithmetic circuit of operations on encrypted integers.
|
||||
This compiler is based on the [MLIR project](https://mlir.llvm.org/) it use the framework, the standard dialects exposed by MLIR and define new fhe specific dialects and passes to lower the high-level fhe dialects to standard MLIR dialects.
|
||||
|
||||
## Getting started
|
||||
|
||||
The source of the project is located in the `compiler` directory.
|
||||
|
||||
```sh
|
||||
cd compiler
|
||||
```
|
||||
|
||||
### Prerequisite: Building HPX and enable dataflow parallelism (optional)
|
||||
|
||||
In order to implement the dataflow parallelism and the distribution of the computation we use the [HPX Standard Library](https://hpx-docs.stellar-group.org/). You can else use your own HPX installation by set the `HPX_INSTALL_DIR` environment variable or you can install HPX on the default path of our build system thanks the following command:
|
||||
|
||||
```sh
|
||||
make install-hpx-from-source
|
||||
```
|
||||
|
||||
This may fail on some systems when dependencies are missing. Some recent packages required are Cmake, HWLOC and BOOST. For full details see [HPX Quickstart guide](https://hpx-docs.stellar-group.org/tags/1.7.1/html/quickstart.html).
|
||||
Once you have a proper installation of HPX to enable the dataflow parallelism set the `DATAFLOW_EXECUTION_ENABLED=ON`.
|
||||
|
||||
### Prerequisite: Fetch git submodules
|
||||
|
||||
This project rely on `llvm-project` and `concrete-optimizer` as git submodules so you need to initialize and update the git submodules.
|
||||
|
||||
```sh
|
||||
git submodule init
|
||||
git submodule update
|
||||
```
|
||||
|
||||
### Prerequisite: python packages
|
||||
|
||||
Install MLIR python requirements in your dev python environment:
|
||||
|
||||
```bash
|
||||
# From repo root
|
||||
pip install -r ./llvm-project/mlir/python/requirements.txt
|
||||
# From compiler dir
|
||||
pip install -r ../llvm-project/mlir/python/requirements.txt
|
||||
```
|
||||
|
||||
### Build from source
|
||||
|
||||
We use cmake as the main build system but in order to initialize the build system and define straightforward target for the main artifacts of the project. You can initialize and build all the main artifacts thanks the following command:
|
||||
|
||||
```sh
|
||||
make all
|
||||
```
|
||||
|
||||
or in several steps:
|
||||
|
||||
Generate the compiler build system, in a `build-*` directory
|
||||
|
||||
```sh
|
||||
make build-initialized
|
||||
```
|
||||
|
||||
Build the compiler
|
||||
|
||||
```sh
|
||||
make concretecompiler
|
||||
```
|
||||
|
||||
Run the compiler
|
||||
|
||||
```sh
|
||||
./build-Release/bin/concretecompiler
|
||||
```
|
||||
|
||||
### Installation from source
|
||||
|
||||
You can install libs, bins, and include files into a specific directory by running:
|
||||
|
||||
```sh
|
||||
make INSTALL_PREFIX=/your/directory install
|
||||
```
|
||||
|
||||
You will then find `lib`, `bin`, and `include` under `/your/directory/concretecompiler`.
|
||||
|
||||
### Tests
|
||||
|
||||
You can build all the tests with the following command:
|
||||
|
||||
```sh
|
||||
make build-tests
|
||||
```
|
||||
|
||||
and run them with:
|
||||
|
||||
```sh
|
||||
make run-tests
|
||||
```
|
||||
|
||||
### Benchmarks
|
||||
|
||||
You can build all the benchmarks with the following command:
|
||||
|
||||
```sh
|
||||
make build-benchmarks
|
||||
```
|
||||
|
||||
and run them with:
|
||||
|
||||
```sh
|
||||
make run-benchmarks
|
||||
```
|
||||
|
||||
## Build releases
|
||||
|
||||
### Build tarball
|
||||
|
||||
You can create a tarball containing libs, bins, and include files for the tools of the compiler, by following previous steps of [installation from source](#installation-from-source), then creating a tar archive from the installation directory.
|
||||
|
||||
### Build the Python Package
|
||||
|
||||
Currently supported platforms:
|
||||
- Linux x86_64 for python 3.7, 3.8, 3.9, and 3.10
|
||||
|
||||
pybind11 is required to build the python package, you can install it in your current environment with:
|
||||
|
||||
```bash
|
||||
$ pip install pybind11
|
||||
```
|
||||
|
||||
To specify which python executable to target you can specify the `Python3_EXECUTABLE` environment variable.
|
||||
|
||||
#### Build wheels in your environment
|
||||
|
||||
Building the wheels is actually simple.
|
||||
|
||||
```bash
|
||||
$ pip wheel --no-deps -w ../wheels .
|
||||
```
|
||||
|
||||
Depending on the platform you are using (specially Linux), you might need to use `auditwheel` to specify the platform this wheel is targeting. For example, in our build of the package for Linux x86_64 and GLIBC 2.24, we also run:
|
||||
|
||||
```bash
|
||||
$ auditwheel repair ../wheels/*.whl --plat manylinux_2_24_x86_64 -w ../wheels
|
||||
```
|
||||
3
compilers/concrete-compiler/compiler/RELEASE_README.md
Normal file
3
compilers/concrete-compiler/compiler/RELEASE_README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Concrete Compiler
|
||||
|
||||
The Concrete Compiler takes a high level computation model and produces a programs that evaluate the model in an homomorphic way.
|
||||
@@ -0,0 +1,9 @@
|
||||
include(AddMLIR)
|
||||
|
||||
function(add_concretelang_doc doc_filename output_file output_directory command)
|
||||
set(SAVED_MLIR_BINARY_DIR ${MLIR_BINARY_DIR})
|
||||
set(MLIR_BINARY_DIR ${CONCRETELANG_BINARY_DIR})
|
||||
add_mlir_doc(${doc_filename} ${output_file} ${output_directory} ${command} ${ARGN})
|
||||
set(MLIR_BINARY_DIR ${SAVED_MLIR_BINARY_DIR})
|
||||
unset(SAVED_MLIR_BINARY_DIR)
|
||||
endfunction()
|
||||
Submodule compilers/concrete-compiler/compiler/concrete-core added at bf79f5db63
Submodule compilers/concrete-compiler/compiler/concrete-cpu added at db262714cd
Submodule compilers/concrete-compiler/compiler/concrete-optimizer added at 85abbeadae
27
compilers/concrete-compiler/compiler/hpx.ini
Normal file
27
compilers/concrete-compiler/compiler/hpx.ini
Normal file
@@ -0,0 +1,27 @@
|
||||
[hpx]
|
||||
location = ${HPX_LOCATION:$[system.prefix]}
|
||||
component_path = $[hpx.location]/lib/hpx:$[system.executable_prefix]/lib/hpx:$[system.executable_prefix]/../lib/hpx
|
||||
master_ini_path = $[hpx.location]/share/hpx-<version>:$[system.executable_prefix]/share/hpx-<version>:$[system.executable_prefix]/../share/hpx-<version>
|
||||
ini_path = $[hpx.master_ini_path]/ini
|
||||
os_threads = 2
|
||||
localities = 1
|
||||
program_name =
|
||||
cmd_line =
|
||||
lock_detection = ${HPX_LOCK_DETECTION:0}
|
||||
throw_on_held_lock = ${HPX_THROW_ON_HELD_LOCK:1}
|
||||
minimal_deadlock_detection = <debug>
|
||||
spinlock_deadlock_detection = <debug>
|
||||
spinlock_deadlock_detection_limit = ${HPX_SPINLOCK_DEADLOCK_DETECTION_LIMIT:1000000}
|
||||
max_background_threads = ${HPX_MAX_BACKGROUND_THREADS:$[hpx.os_threads]}
|
||||
max_idle_loop_count = ${HPX_MAX_IDLE_LOOP_COUNT:<hpx_idle_loop_count_max>}
|
||||
max_busy_loop_count = ${HPX_MAX_BUSY_LOOP_COUNT:<hpx_busy_loop_count_max>}
|
||||
max_idle_backoff_time = ${HPX_MAX_IDLE_BACKOFF_TIME:<hpx_idle_backoff_time_max>}
|
||||
exception_verbosity = ${HPX_EXCEPTION_VERBOSITY:1}
|
||||
default_stack_size = 0x20000000
|
||||
|
||||
[hpx.stacks]
|
||||
small_size = 0x8000000
|
||||
medium_size = 0x10000000
|
||||
large_size = 0x20000000
|
||||
huge_size = 0x40000000
|
||||
use_guard_pages = ${HPX_THREAD_GUARD_PAGE:3}
|
||||
@@ -0,0 +1 @@
|
||||
add_subdirectory(concretelang)
|
||||
File diff suppressed because it is too large
Load Diff
14
compilers/concrete-compiler/compiler/include/boost/outcome.h
Normal file
14
compilers/concrete-compiler/compiler/include/boost/outcome.h
Normal file
@@ -0,0 +1,14 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_BOOST_OUTCOME_H
|
||||
#define CONCRETELANG_BOOST_OUTCOME_H
|
||||
|
||||
// https://github.com/ned14/outcome/raw/master/single-header/outcome.hpp
|
||||
#include "boost-single-header/outcome.hpp"
|
||||
|
||||
namespace outcome = outcome_v2_e261cebd;
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,49 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_C_DIALECT_FHE_H
|
||||
#define CONCRETELANG_C_DIALECT_FHE_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/// \brief structure to return an MlirType or report that there was an error
|
||||
/// during type creation.
|
||||
typedef struct {
|
||||
MlirType type;
|
||||
bool isError;
|
||||
} MlirTypeOrError;
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FHE, fhe);
|
||||
|
||||
/// Creates an encrypted integer type of `width` bits
|
||||
MLIR_CAPI_EXPORTED MlirTypeOrError
|
||||
fheEncryptedIntegerTypeGetChecked(MlirContext context, unsigned width);
|
||||
|
||||
/// If the type is an EncryptedInteger
|
||||
MLIR_CAPI_EXPORTED bool fheTypeIsAnEncryptedIntegerType(MlirType);
|
||||
|
||||
/// Creates an encrypted signed integer type of `width` bits
|
||||
MLIR_CAPI_EXPORTED MlirTypeOrError
|
||||
fheEncryptedSignedIntegerTypeGetChecked(MlirContext context, unsigned width);
|
||||
|
||||
/// If the type is an EncryptedSignedInteger
|
||||
MLIR_CAPI_EXPORTED bool fheTypeIsAnEncryptedSignedIntegerType(MlirType);
|
||||
|
||||
/// \brief Get bitwidth of the encrypted integer type.
|
||||
///
|
||||
/// \return bitwidth of the encrypted integer or 0 if it's not an encrypted
|
||||
/// integer
|
||||
MLIR_CAPI_EXPORTED unsigned fheTypeIntegerWidthGet(MlirType);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // CONCRETELANG_C_DIALECT_FHE_H
|
||||
@@ -0,0 +1,22 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_C_DIALECT_FHELINALG_H
|
||||
#define CONCRETELANG_C_DIALECT_FHELINALG_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FHELinalg, fhelinalg);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // CONCRETELANG_C_DIALECT_FHELINALG_H
|
||||
@@ -0,0 +1,399 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_C_SUPPORT_COMPILER_ENGINE_H
|
||||
#define CONCRETELANG_C_SUPPORT_COMPILER_ENGINE_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/// The CAPI should be really careful about memory allocation. Every pointer
|
||||
/// returned should points to a new buffer allocated for the purpose of the
|
||||
/// CAPI, and should have a respective destructor function.
|
||||
|
||||
/// Opaque type declarations. Inspired from
|
||||
/// llvm-project/mlir/include/mlir-c/IR.h
|
||||
///
|
||||
/// Adds an error pointer to an allocated buffer holding the error message if
|
||||
/// any.
|
||||
#define DEFINE_C_API_STRUCT(name, storage) \
|
||||
struct name { \
|
||||
storage *ptr; \
|
||||
const char *error; \
|
||||
}; \
|
||||
typedef struct name name
|
||||
|
||||
DEFINE_C_API_STRUCT(CompilerEngine, void);
|
||||
DEFINE_C_API_STRUCT(CompilationContext, void);
|
||||
DEFINE_C_API_STRUCT(CompilationResult, void);
|
||||
DEFINE_C_API_STRUCT(Library, void);
|
||||
DEFINE_C_API_STRUCT(LibraryCompilationResult, void);
|
||||
DEFINE_C_API_STRUCT(LibrarySupport, void);
|
||||
DEFINE_C_API_STRUCT(CompilationOptions, void);
|
||||
DEFINE_C_API_STRUCT(OptimizerConfig, void);
|
||||
DEFINE_C_API_STRUCT(ServerLambda, void);
|
||||
DEFINE_C_API_STRUCT(Encoding, void);
|
||||
DEFINE_C_API_STRUCT(EncryptionGate, void);
|
||||
DEFINE_C_API_STRUCT(CircuitGate, void);
|
||||
DEFINE_C_API_STRUCT(ClientParameters, void);
|
||||
DEFINE_C_API_STRUCT(KeySet, void);
|
||||
DEFINE_C_API_STRUCT(KeySetCache, void);
|
||||
DEFINE_C_API_STRUCT(EvaluationKeys, void);
|
||||
DEFINE_C_API_STRUCT(LambdaArgument, void);
|
||||
DEFINE_C_API_STRUCT(PublicArguments, void);
|
||||
DEFINE_C_API_STRUCT(PublicResult, void);
|
||||
DEFINE_C_API_STRUCT(CompilationFeedback, void);
|
||||
|
||||
#undef DEFINE_C_API_STRUCT
|
||||
|
||||
/// NULL Pointer checkers. Generate functions to check if the struct contains a
|
||||
/// null pointer.
|
||||
#define DEFINE_NULL_PTR_CHECKER(funcname, storage) \
|
||||
bool funcname(storage s) { return s.ptr == NULL; }
|
||||
|
||||
DEFINE_NULL_PTR_CHECKER(compilerEngineIsNull, CompilerEngine)
|
||||
DEFINE_NULL_PTR_CHECKER(compilationContextIsNull, CompilationContext)
|
||||
DEFINE_NULL_PTR_CHECKER(compilationResultIsNull, CompilationResult)
|
||||
DEFINE_NULL_PTR_CHECKER(libraryIsNull, Library)
|
||||
DEFINE_NULL_PTR_CHECKER(libraryCompilationResultIsNull,
|
||||
LibraryCompilationResult)
|
||||
DEFINE_NULL_PTR_CHECKER(librarySupportIsNull, LibrarySupport)
|
||||
DEFINE_NULL_PTR_CHECKER(compilationOptionsIsNull, CompilationOptions)
|
||||
DEFINE_NULL_PTR_CHECKER(optimizerConfigIsNull, OptimizerConfig)
|
||||
DEFINE_NULL_PTR_CHECKER(serverLambdaIsNull, ServerLambda)
|
||||
DEFINE_NULL_PTR_CHECKER(circuitGateIsNull, CircuitGate)
|
||||
DEFINE_NULL_PTR_CHECKER(encodingIsNull, Encoding)
|
||||
DEFINE_NULL_PTR_CHECKER(encryptionGateIsNull, EncryptionGate)
|
||||
DEFINE_NULL_PTR_CHECKER(clientParametersIsNull, ClientParameters)
|
||||
DEFINE_NULL_PTR_CHECKER(keySetIsNull, KeySet)
|
||||
DEFINE_NULL_PTR_CHECKER(keySetCacheIsNull, KeySetCache)
|
||||
DEFINE_NULL_PTR_CHECKER(evaluationKeysIsNull, EvaluationKeys)
|
||||
DEFINE_NULL_PTR_CHECKER(lambdaArgumentIsNull, LambdaArgument)
|
||||
DEFINE_NULL_PTR_CHECKER(publicArgumentsIsNull, PublicArguments)
|
||||
DEFINE_NULL_PTR_CHECKER(publicResultIsNull, PublicResult)
|
||||
DEFINE_NULL_PTR_CHECKER(compilationFeedbackIsNull, CompilationFeedback)
|
||||
|
||||
#undef DEFINE_NULL_PTR_CHECKER
|
||||
|
||||
/// Each struct has a creator function that allocates memory for the underlying
|
||||
/// Cpp object referenced, and a destroy function that does free this allocated
|
||||
/// memory.
|
||||
|
||||
/// ********** Utilities *******************************************************
|
||||
|
||||
/// Destroy string references created by the compiler.
|
||||
///
|
||||
/// This is not supposed to destroy any string ref, but only the ones we have
|
||||
/// allocated memory for and know how to free.
|
||||
MLIR_CAPI_EXPORTED void mlirStringRefDestroy(MlirStringRef str);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirStringRefIsNull(MlirStringRef str) {
|
||||
return str.data == NULL;
|
||||
}
|
||||
|
||||
/// ********** BufferRef CAPI **************************************************
|
||||
|
||||
/// A struct for binary buffers.
|
||||
///
|
||||
/// Contraty to MlirStringRef, it doesn't assume the pointer point to a null
|
||||
/// terminated string and the data should be considered as is in binary form.
|
||||
/// Useful for serialized objects.
|
||||
typedef struct BufferRef {
|
||||
const char *data;
|
||||
size_t length;
|
||||
const char *error;
|
||||
} BufferRef;
|
||||
|
||||
MLIR_CAPI_EXPORTED void bufferRefDestroy(BufferRef buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool bufferRefIsNull(BufferRef buffer) {
|
||||
return buffer.data == NULL;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED BufferRef bufferRefCreate(const char *buffer, size_t length);
|
||||
|
||||
/// ********** CompilationTarget CAPI ******************************************
|
||||
|
||||
enum CompilationTarget {
|
||||
ROUND_TRIP,
|
||||
FHE,
|
||||
TFHE,
|
||||
CONCRETE,
|
||||
STD,
|
||||
LLVM,
|
||||
LLVM_IR,
|
||||
OPTIMIZED_LLVM_IR,
|
||||
LIBRARY
|
||||
};
|
||||
typedef enum CompilationTarget CompilationTarget;
|
||||
|
||||
/// ********** CompilationOptions CAPI *****************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreate(
|
||||
MlirStringRef funcName, bool autoParallelize, bool batchConcreteOps,
|
||||
bool dataflowParallelize, bool emitGPUOps, bool loopParallelize,
|
||||
bool optimizeTFHE, OptimizerConfig optimizerConfig, bool verifyDiagnostics);
|
||||
|
||||
MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreateDefault();
|
||||
|
||||
MLIR_CAPI_EXPORTED void compilationOptionsDestroy(CompilationOptions options);
|
||||
|
||||
/// ********** OptimizerConfig CAPI ********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED OptimizerConfig
|
||||
optimizerConfigCreate(bool display, double fallback_log_norm_woppbs,
|
||||
double global_p_error, double p_error, uint64_t security,
|
||||
bool strategy_v0, bool use_gpu_constraints);
|
||||
|
||||
MLIR_CAPI_EXPORTED OptimizerConfig optimizerConfigCreateDefault();
|
||||
|
||||
MLIR_CAPI_EXPORTED void optimizerConfigDestroy(OptimizerConfig config);
|
||||
|
||||
/// ********** CompilerEngine CAPI *********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED CompilerEngine compilerEngineCreate();
|
||||
|
||||
MLIR_CAPI_EXPORTED void compilerEngineDestroy(CompilerEngine engine);
|
||||
|
||||
MLIR_CAPI_EXPORTED CompilationResult compilerEngineCompile(
|
||||
CompilerEngine engine, MlirStringRef module, CompilationTarget target);
|
||||
|
||||
MLIR_CAPI_EXPORTED void
|
||||
compilerEngineCompileSetOptions(CompilerEngine engine,
|
||||
CompilationOptions options);
|
||||
|
||||
/// ********** CompilationResult CAPI ******************************************
|
||||
|
||||
/// Get a string reference holding the textual representation of the compiled
|
||||
/// module. The returned `MlirStringRef` should be destroyed using
|
||||
/// `mlirStringRefDestroy` to free memory.
|
||||
MLIR_CAPI_EXPORTED MlirStringRef
|
||||
compilationResultGetModuleString(CompilationResult result);
|
||||
|
||||
MLIR_CAPI_EXPORTED void compilationResultDestroy(CompilationResult result);
|
||||
|
||||
/// ********** Library CAPI ****************************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED Library libraryCreate(MlirStringRef outputDirPath,
|
||||
MlirStringRef runtimeLibraryPath,
|
||||
bool cleanUp);
|
||||
|
||||
MLIR_CAPI_EXPORTED void libraryDestroy(Library lib);
|
||||
|
||||
/// ********** LibraryCompilationResult CAPI ***********************************
|
||||
|
||||
MLIR_CAPI_EXPORTED void
|
||||
libraryCompilationResultDestroy(LibraryCompilationResult result);
|
||||
|
||||
/// ********** LibrarySupport CAPI *********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED LibrarySupport
|
||||
librarySupportCreate(MlirStringRef outputDirPath,
|
||||
MlirStringRef runtimeLibraryPath, bool generateSharedLib,
|
||||
bool generateStaticLib, bool generateClientParameters,
|
||||
bool generateCompilationFeedback, bool generateCppHeader);
|
||||
|
||||
MLIR_CAPI_EXPORTED LibrarySupport librarySupportCreateDefault(
|
||||
MlirStringRef outputDirPath, MlirStringRef runtimeLibraryPath) {
|
||||
return librarySupportCreate(outputDirPath, runtimeLibraryPath, true, true,
|
||||
true, true, true);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED LibraryCompilationResult librarySupportCompile(
|
||||
LibrarySupport support, MlirStringRef module, CompilationOptions options);
|
||||
|
||||
MLIR_CAPI_EXPORTED ServerLambda librarySupportLoadServerLambda(
|
||||
LibrarySupport support, LibraryCompilationResult result);
|
||||
|
||||
MLIR_CAPI_EXPORTED ClientParameters librarySupportLoadClientParameters(
|
||||
LibrarySupport support, LibraryCompilationResult result);
|
||||
|
||||
MLIR_CAPI_EXPORTED CompilationFeedback librarySupportLoadCompilationFeedback(
|
||||
LibrarySupport support, LibraryCompilationResult result);
|
||||
|
||||
MLIR_CAPI_EXPORTED PublicResult
|
||||
librarySupportServerCall(LibrarySupport support, ServerLambda server,
|
||||
PublicArguments args, EvaluationKeys evalKeys);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirStringRef
|
||||
librarySupportGetSharedLibPath(LibrarySupport support);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirStringRef
|
||||
librarySupportGetClientParametersPath(LibrarySupport support);
|
||||
|
||||
MLIR_CAPI_EXPORTED void librarySupportDestroy(LibrarySupport support);
|
||||
|
||||
/// ********** ServerLamda CAPI ************************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED void serverLambdaDestroy(ServerLambda server);
|
||||
|
||||
/// ********** ClientParameters CAPI *******************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED BufferRef clientParametersSerialize(ClientParameters params);
|
||||
|
||||
MLIR_CAPI_EXPORTED ClientParameters
|
||||
clientParametersUnserialize(BufferRef buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED ClientParameters
|
||||
clientParametersCopy(ClientParameters params);
|
||||
|
||||
MLIR_CAPI_EXPORTED void clientParametersDestroy(ClientParameters params);
|
||||
|
||||
/// Returns the number of output circuit gates
|
||||
MLIR_CAPI_EXPORTED size_t clientParametersOutputsSize(ClientParameters params);
|
||||
|
||||
/// Returns the number of input circuit gates
|
||||
MLIR_CAPI_EXPORTED size_t clientParametersInputsSize(ClientParameters params);
|
||||
|
||||
/// Returns the output circuit gate corresponding to the index
|
||||
///
|
||||
/// - `index` must be valid.
|
||||
MLIR_CAPI_EXPORTED CircuitGate
|
||||
clientParametersOutputCircuitGate(ClientParameters params, size_t index);
|
||||
|
||||
/// Returns the input circuit gate corresponding to the index
|
||||
///
|
||||
/// - `index` must be valid.
|
||||
MLIR_CAPI_EXPORTED CircuitGate
|
||||
clientParametersInputCircuitGate(ClientParameters params, size_t index);
|
||||
|
||||
/// Returns the EncryptionGate of the circuit gate.
|
||||
///
|
||||
/// - The returned gate will be null if the gate does not represent encrypted
|
||||
/// data
|
||||
MLIR_CAPI_EXPORTED EncryptionGate
|
||||
circuitGateEncryptionGate(CircuitGate circuit_gate);
|
||||
|
||||
/// Returns the variance of the encryption gate
|
||||
MLIR_CAPI_EXPORTED double
|
||||
encryptionGateVariance(EncryptionGate encryption_gate);
|
||||
|
||||
/// Returns the Encoding of the encryption gate.
|
||||
MLIR_CAPI_EXPORTED Encoding
|
||||
encryptionGateEncoding(EncryptionGate encryption_gate);
|
||||
|
||||
/// Returns the precision (bit width) of the encoding
|
||||
MLIR_CAPI_EXPORTED uint64_t encodingPrecision(Encoding encoding);
|
||||
|
||||
MLIR_CAPI_EXPORTED void circuitGateDestroy(CircuitGate gate);
|
||||
MLIR_CAPI_EXPORTED void encryptionGateDestroy(EncryptionGate gate);
|
||||
MLIR_CAPI_EXPORTED void encodingDestroy(Encoding encoding);
|
||||
|
||||
/// ********** KeySet CAPI *****************************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED KeySet keySetGenerate(ClientParameters params,
|
||||
uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
MLIR_CAPI_EXPORTED EvaluationKeys keySetGetEvaluationKeys(KeySet keySet);
|
||||
|
||||
MLIR_CAPI_EXPORTED void keySetDestroy(KeySet keySet);
|
||||
|
||||
/// ********** KeySetCache CAPI ************************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED KeySetCache keySetCacheCreate(MlirStringRef cachePath);
|
||||
|
||||
MLIR_CAPI_EXPORTED KeySet
|
||||
keySetCacheLoadOrGenerateKeySet(KeySetCache cache, ClientParameters params,
|
||||
uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
MLIR_CAPI_EXPORTED void keySetCacheDestroy(KeySetCache keySetCache);
|
||||
|
||||
/// ********** EvaluationKeys CAPI *********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED BufferRef evaluationKeysSerialize(EvaluationKeys keys);
|
||||
|
||||
MLIR_CAPI_EXPORTED EvaluationKeys evaluationKeysUnserialize(BufferRef buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED void evaluationKeysDestroy(EvaluationKeys evaluationKeys);
|
||||
|
||||
/// ********** LambdaArgument CAPI *********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromScalar(uint64_t value);
|
||||
|
||||
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU8(
|
||||
const uint8_t *data, const int64_t *dims, size_t rank);
|
||||
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU16(
|
||||
const uint16_t *data, const int64_t *dims, size_t rank);
|
||||
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU32(
|
||||
const uint32_t *data, const int64_t *dims, size_t rank);
|
||||
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64(
|
||||
const uint64_t *data, const int64_t *dims, size_t rank);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(LambdaArgument lambdaArg);
|
||||
MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(LambdaArgument lambdaArg);
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentGetTensorData(LambdaArgument lambdaArg,
|
||||
uint64_t *buffer);
|
||||
MLIR_CAPI_EXPORTED size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg);
|
||||
MLIR_CAPI_EXPORTED int64_t
|
||||
lambdaArgumentGetTensorDataSize(LambdaArgument lambdaArg);
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentGetTensorDims(LambdaArgument lambdaArg,
|
||||
int64_t *buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED PublicArguments
|
||||
lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs, size_t argNumber,
|
||||
ClientParameters params, KeySet keySet);
|
||||
|
||||
MLIR_CAPI_EXPORTED void lambdaArgumentDestroy(LambdaArgument lambdaArg);
|
||||
|
||||
/// ********** PublicArguments CAPI ********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED BufferRef publicArgumentsSerialize(PublicArguments args);
|
||||
|
||||
MLIR_CAPI_EXPORTED PublicArguments
|
||||
publicArgumentsUnserialize(BufferRef buffer, ClientParameters params);
|
||||
|
||||
MLIR_CAPI_EXPORTED void publicArgumentsDestroy(PublicArguments publicArgs);
|
||||
|
||||
/// ********** PublicResult CAPI ***********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED LambdaArgument publicResultDecrypt(PublicResult publicResult,
|
||||
KeySet keySet);
|
||||
|
||||
MLIR_CAPI_EXPORTED BufferRef publicResultSerialize(PublicResult result);
|
||||
|
||||
MLIR_CAPI_EXPORTED PublicResult
|
||||
publicResultUnserialize(BufferRef buffer, ClientParameters params);
|
||||
|
||||
MLIR_CAPI_EXPORTED void publicResultDestroy(PublicResult publicResult);
|
||||
|
||||
/// ********** CompilationFeedback CAPI ****************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED double
|
||||
compilationFeedbackGetComplexity(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED double
|
||||
compilationFeedbackGetPError(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED double
|
||||
compilationFeedbackGetGlobalPError(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED uint64_t
|
||||
compilationFeedbackGetTotalSecretKeysSize(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED uint64_t
|
||||
compilationFeedbackGetTotalBootstrapKeysSize(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED uint64_t
|
||||
compilationFeedbackGetTotalKeyswitchKeysSize(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED uint64_t
|
||||
compilationFeedbackGetTotalInputsSize(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED uint64_t
|
||||
compilationFeedbackGetTotalOutputsSize(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED void
|
||||
compilationFeedbackDestroy(CompilationFeedback feedback);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // CONCRETELANG_C_SUPPORT_COMPILER_ENGINE_H
|
||||
@@ -0,0 +1,21 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_BINDINGS_PYTHON_COMPILER_API_MODULE_H
|
||||
#define CONCRETELANG_BINDINGS_PYTHON_COMPILER_API_MODULE_H
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace python {
|
||||
|
||||
void populateCompilerAPISubmodule(pybind11::module &m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif // CONCRETELANG_BINDINGS_PYTHON_COMPILER_API_MODULE_H
|
||||
@@ -0,0 +1,182 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_BINDINGS_PYTHON_COMPILER_ENGINE_H
|
||||
#define CONCRETELANG_BINDINGS_PYTHON_COMPILER_ENGINE_H
|
||||
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/JITSupport.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/LibrarySupport.h"
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
/// MLIR_CAPI_EXPORTED is used here throughout the API, because of the way the
|
||||
/// python extension is built using MLIR cmake functions, which will cause
|
||||
/// undefined symbols during runtime if those aren't present.
|
||||
|
||||
/// Wrapper of the mlir::concretelang::LambdaArgument
|
||||
struct lambdaArgument {
|
||||
std::shared_ptr<mlir::concretelang::LambdaArgument> ptr;
|
||||
};
|
||||
typedef struct lambdaArgument lambdaArgument;
|
||||
|
||||
/// Hold a list of lambdaArgument to represent execution arguments
|
||||
struct executionArguments {
|
||||
lambdaArgument *data;
|
||||
size_t size;
|
||||
};
|
||||
typedef struct executionArguments executionArguments;
|
||||
|
||||
// JIT Support bindings ///////////////////////////////////////////////////////
|
||||
|
||||
struct JITSupport_Py {
|
||||
mlir::concretelang::JITSupport support;
|
||||
};
|
||||
typedef struct JITSupport_Py JITSupport_Py;
|
||||
|
||||
MLIR_CAPI_EXPORTED JITSupport_Py jit_support(std::string runtimeLibPath);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
|
||||
jit_compile(JITSupport_Py support, const char *module,
|
||||
mlir::concretelang::CompilationOptions options);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
jit_load_client_parameters(JITSupport_Py support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
|
||||
jit_load_compilation_feedback(JITSupport_Py support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::shared_ptr<mlir::concretelang::JITLambda>
|
||||
jit_load_server_lambda(JITSupport_Py support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
jit_server_call(JITSupport_Py support, mlir::concretelang::JITLambda &lambda,
|
||||
concretelang::clientlib::PublicArguments &args,
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys);
|
||||
|
||||
// Library Support bindings ///////////////////////////////////////////////////
|
||||
|
||||
struct LibrarySupport_Py {
|
||||
mlir::concretelang::LibrarySupport support;
|
||||
};
|
||||
typedef struct LibrarySupport_Py LibrarySupport_Py;
|
||||
|
||||
MLIR_CAPI_EXPORTED LibrarySupport_Py
|
||||
library_support(const char *outputPath, const char *runtimeLibraryPath,
|
||||
bool generateSharedLib, bool generateStaticLib,
|
||||
bool generateClientParameters, bool generateCompilationFeedback,
|
||||
bool generateCppHeader);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
|
||||
library_compile(LibrarySupport_Py support, const char *module,
|
||||
mlir::concretelang::CompilationOptions options);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
library_load_client_parameters(LibrarySupport_Py support,
|
||||
mlir::concretelang::LibraryCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
|
||||
library_load_compilation_feedback(
|
||||
LibrarySupport_Py support, mlir::concretelang::LibraryCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda
|
||||
library_load_server_lambda(LibrarySupport_Py support,
|
||||
mlir::concretelang::LibraryCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
library_server_call(LibrarySupport_Py support,
|
||||
concretelang::serverlib::ServerLambda lambda,
|
||||
concretelang::clientlib::PublicArguments &args,
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
library_get_shared_lib_path(LibrarySupport_Py support);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
library_get_client_parameters_path(LibrarySupport_Py support);
|
||||
|
||||
// Client Support bindings ///////////////////////////////////////////////////
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
|
||||
key_set(concretelang::clientlib::ClientParameters clientParameters,
|
||||
llvm::Optional<concretelang::clientlib::KeySetCache> cache);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
|
||||
encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args);
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument
|
||||
decrypt_result(concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::PublicResult &publicResult);
|
||||
|
||||
// Serialization ////////////////////////////////////////////////////////////
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
clientParametersUnserialize(const std::string &json);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
clientParametersSerialize(mlir::concretelang::ClientParameters ¶ms);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
|
||||
publicArgumentsUnserialize(
|
||||
mlir::concretelang::ClientParameters &clientParameters,
|
||||
const std::string &buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize(
|
||||
concretelang::clientlib::PublicArguments &publicArguments);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters,
|
||||
const std::string &buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
publicResultSerialize(concretelang::clientlib::PublicResult &publicResult);
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys
|
||||
evaluationKeysUnserialize(const std::string &buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize(
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys);
|
||||
|
||||
/// Parse then print a textual representation of an MLIR module
|
||||
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
|
||||
|
||||
/// Terminate/Init dataflow parallelization
|
||||
MLIR_CAPI_EXPORTED void terminateDataflowParallelization();
|
||||
MLIR_CAPI_EXPORTED void initDataflowParallelization();
|
||||
|
||||
/// Create a lambdaArgument from a tensor of different data types
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8(
|
||||
std::vector<uint8_t> data, std::vector<int64_t> dimensions);
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU16(
|
||||
std::vector<uint16_t> data, std::vector<int64_t> dimensions);
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU32(
|
||||
std::vector<uint32_t> data, std::vector<int64_t> dimensions);
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU64(
|
||||
std::vector<uint64_t> data, std::vector<int64_t> dimensions);
|
||||
/// Create a lambdaArgument from a scalar
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromScalar(uint64_t scalar);
|
||||
/// Check if a lambdaArgument holds a tensor
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg);
|
||||
/// Get tensor data from lambdaArgument
|
||||
MLIR_CAPI_EXPORTED std::vector<uint64_t>
|
||||
lambdaArgumentGetTensorData(lambdaArgument &lambda_arg);
|
||||
/// Get tensor dimensions from lambdaArgument
|
||||
MLIR_CAPI_EXPORTED std::vector<int64_t>
|
||||
lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg);
|
||||
/// Check if a lambdaArgument holds a scalar
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg);
|
||||
/// Get scalar value from lambdaArgument
|
||||
MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg);
|
||||
|
||||
/// Compile the textual representation of MLIR modules to a library.
|
||||
MLIR_CAPI_EXPORTED std::string library(std::string libraryPath,
|
||||
std::vector<std::string> modules);
|
||||
|
||||
#endif // CONCRETELANG_BINDINGS_PYTHON_COMPILER_ENGINE_H
|
||||
@@ -0,0 +1,21 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_BINDINGS_PYTHON_DIALECTMODULES_H
|
||||
#define CONCRETELANG_BINDINGS_PYTHON_DIALECTMODULES_H
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace python {
|
||||
|
||||
void populateDialectFHESubmodule(pybind11::module &m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif // CONCRETELANG_BINDINGS_PYTHON_DIALECTMODULES_H
|
||||
@@ -0,0 +1,70 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CAPI_WRAPPERS_H
|
||||
#define CONCRETELANG_CAPI_WRAPPERS_H
|
||||
|
||||
#include "concretelang-c/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/LibrarySupport.h"
|
||||
|
||||
/// Add a mechanism to go from Cpp objects to C-struct, with the ability to
|
||||
/// represent errors. Also the other way arround.
|
||||
#define DEFINE_C_API_PTR_METHODS_WITH_ERROR(name, cpptype) \
|
||||
static inline name wrap(cpptype *cpp) { return name{cpp, (char *)NULL}; } \
|
||||
static inline name wrap(cpptype *cpp, std::string errorStr) { \
|
||||
char *error = new char[errorStr.size()]; \
|
||||
strcpy(error, errorStr.c_str()); \
|
||||
return name{(cpptype *)NULL, error}; \
|
||||
} \
|
||||
static inline cpptype *unwrap(name c) { \
|
||||
return static_cast<cpptype *>(c.ptr); \
|
||||
} \
|
||||
static inline const char *getErrorPtr(name c) { return c.error; }
|
||||
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(CompilerEngine,
|
||||
mlir::concretelang::CompilerEngine)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(CompilationContext,
|
||||
mlir::concretelang::CompilationContext)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(
|
||||
CompilationResult, mlir::concretelang::CompilerEngine::CompilationResult)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(Library,
|
||||
mlir::concretelang::CompilerEngine::Library)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(
|
||||
LibraryCompilationResult, mlir::concretelang::LibraryCompilationResult)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(LibrarySupport,
|
||||
mlir::concretelang::LibrarySupport)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(CompilationOptions,
|
||||
mlir::concretelang::CompilationOptions)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(OptimizerConfig,
|
||||
mlir::concretelang::optimizer::Config)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(ServerLambda,
|
||||
mlir::concretelang::serverlib::ServerLambda)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(
|
||||
ClientParameters, mlir::concretelang::clientlib::ClientParameters)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(KeySet,
|
||||
mlir::concretelang::clientlib::KeySet)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(KeySetCache,
|
||||
mlir::concretelang::clientlib::KeySetCache)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(
|
||||
EvaluationKeys, mlir::concretelang::clientlib::EvaluationKeys)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(LambdaArgument,
|
||||
mlir::concretelang::LambdaArgument)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(
|
||||
PublicArguments, mlir::concretelang::clientlib::PublicArguments)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(PublicResult,
|
||||
mlir::concretelang::clientlib::PublicResult)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(CompilationFeedback,
|
||||
mlir::concretelang::CompilationFeedback)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(Encoding,
|
||||
mlir::concretelang::clientlib::Encoding)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(
|
||||
EncryptionGate, mlir::concretelang::clientlib::EncryptionGate)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(CircuitGate,
|
||||
mlir::concretelang::clientlib::CircuitGate)
|
||||
|
||||
#undef DEFINE_C_API_PTR_METHODS_WITH_ERROR
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,4 @@
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Transforms)
|
||||
add_subdirectory(Interfaces)
|
||||
@@ -0,0 +1,46 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_CRT_H_
|
||||
#define CONCRETELANG_CLIENTLIB_CRT_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
namespace crt {
|
||||
|
||||
/// Compute the product of the moduli of the crt decomposition.
|
||||
///
|
||||
/// \param moduli The moduli of the crt decomposition
|
||||
/// \returns The product of moduli
|
||||
uint64_t productOfModuli(std::vector<int64_t> moduli);
|
||||
|
||||
/// Compute the crt decomposition of a `val` according the given `moduli`.
|
||||
///
|
||||
/// \param moduli The moduli to compute the decomposition.
|
||||
/// \param val The value to decompose.
|
||||
/// \returns The remainders.
|
||||
std::vector<int64_t> crt(std::vector<int64_t> moduli, uint64_t val);
|
||||
|
||||
/// Compute the inverse of the crt decomposition.
|
||||
///
|
||||
/// \param moduli The moduli used to compute the inverse decomposition.
|
||||
/// \param remainders The remainders of the decomposition.
|
||||
uint64_t iCrt(std::vector<int64_t> moduli, std::vector<int64_t> remainders);
|
||||
|
||||
/// Encode the plaintext with the given modulus and the product of moduli of the
|
||||
/// crt decomposition
|
||||
uint64_t encode(int64_t plaintext, uint64_t modulus, uint64_t product);
|
||||
|
||||
/// Decode follow the crt encoding
|
||||
uint64_t decode(uint64_t val, uint64_t modulus);
|
||||
|
||||
} // namespace crt
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,136 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_CLIENT_LAMBDA_H
|
||||
#define CONCRETELANG_CLIENTLIB_CLIENT_LAMBDA_H
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EncryptedArguments.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/PublicArguments.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
using scalar_in = uint8_t;
|
||||
using scalar_out = uint64_t;
|
||||
using tensor1_in = std::vector<scalar_in>;
|
||||
using tensor2_in = std::vector<std::vector<scalar_in>>;
|
||||
using tensor3_in = std::vector<std::vector<std::vector<scalar_in>>>;
|
||||
using tensor1_out = std::vector<scalar_out>;
|
||||
using tensor2_out = std::vector<std::vector<scalar_out>>;
|
||||
using tensor3_out = std::vector<std::vector<std::vector<scalar_out>>>;
|
||||
|
||||
/// Low-level class to create the client side view of a FHE function.
|
||||
class ClientLambda {
|
||||
public:
|
||||
virtual ~ClientLambda() = default;
|
||||
|
||||
/// Construct a ClientLambda from a ClientParameter file.
|
||||
static outcome::checked<ClientLambda, StringError> load(std::string funcName,
|
||||
std::string jsonPath);
|
||||
|
||||
/// Generate or get from cache a KeySet suitable for this ClientLambda
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
keySet(std::shared_ptr<KeySetCache> optionalCache, uint64_t seed_msb,
|
||||
uint64_t seed_lsb);
|
||||
|
||||
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
decryptReturnedValues(KeySet &keySet, PublicResult &result);
|
||||
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
decryptReturnedScalar(KeySet &keySet, PublicResult &result);
|
||||
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
decryptReturnedTensor1(KeySet &keySet, PublicResult &result);
|
||||
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
decryptReturnedTensor2(KeySet &keySet, PublicResult &result);
|
||||
|
||||
outcome::checked<decrypted_tensor_3_t, StringError>
|
||||
decryptReturnedTensor3(KeySet &keySet, PublicResult &result);
|
||||
|
||||
public:
|
||||
ClientParameters clientParameters;
|
||||
};
|
||||
|
||||
template <typename Result>
|
||||
outcome::checked<Result, StringError>
|
||||
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
|
||||
PublicResult &result);
|
||||
|
||||
template <typename Result, typename... Args>
|
||||
class TypedClientLambda : public ClientLambda {
|
||||
|
||||
public:
|
||||
static outcome::checked<TypedClientLambda<Result, Args...>, StringError>
|
||||
load(std::string funcName, std::string jsonPath) {
|
||||
OUTCOME_TRY(auto lambda, ClientLambda::load(funcName, jsonPath));
|
||||
return TypedClientLambda(lambda);
|
||||
}
|
||||
|
||||
/// Emit a call on this lambda to a binary ostream.
|
||||
/// The ostream is responsible for transporting the call to a
|
||||
/// ServerLambda::real_call_write function. ostream must be in binary mode
|
||||
/// std::ios_base::openmode::binary
|
||||
outcome::checked<void, StringError>
|
||||
serializeCall(Args... args, KeySet &keySet, std::ostream &ostream) {
|
||||
OUTCOME_TRY(auto publicArguments, publicArguments(args..., keySet));
|
||||
return publicArguments->serialize(ostream);
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
publicArguments(Args... args, KeySet &keySet) {
|
||||
OUTCOME_TRY(auto clientArguments,
|
||||
EncryptedArguments::create(keySet, args...));
|
||||
|
||||
return clientArguments->exportPublicArguments(clientParameters);
|
||||
}
|
||||
|
||||
outcome::checked<Result, StringError> decryptResult(KeySet &keySet,
|
||||
PublicResult &result) {
|
||||
return topLevelDecryptResult<Result>((*this), keySet, result);
|
||||
}
|
||||
|
||||
TypedClientLambda(ClientLambda &lambda) : ClientLambda(lambda) {
|
||||
// TODO: check parameter types
|
||||
// TODO: add static check on types vs lambda inputs/outpus
|
||||
}
|
||||
|
||||
protected:
|
||||
// Workaround, gcc 6 does not support partial template specialisation in class
|
||||
template <typename Result_>
|
||||
friend outcome::checked<Result_, StringError>
|
||||
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
|
||||
PublicResult &result);
|
||||
};
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
topLevelDecryptResult<decrypted_scalar_t>(ClientLambda &lambda, KeySet &keySet,
|
||||
PublicResult &result);
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_1_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
PublicResult &result);
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_2_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
PublicResult &result);
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,347 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
|
||||
#define CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
|
||||
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
#include <llvm/Support/JSON.h>
|
||||
|
||||
namespace concretelang {
|
||||
|
||||
inline size_t bitWidthAsWord(size_t exactBitWidth) {
|
||||
if (exactBitWidth <= 8)
|
||||
return 8;
|
||||
if (exactBitWidth <= 16)
|
||||
return 16;
|
||||
if (exactBitWidth <= 32)
|
||||
return 32;
|
||||
if (exactBitWidth <= 64)
|
||||
return 64;
|
||||
assert(false && "Bit witdh > 64 not supported");
|
||||
}
|
||||
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
const uint64_t SMALL_KEY = 1;
|
||||
const uint64_t BIG_KEY = 0;
|
||||
|
||||
const std::string CLIENT_PARAMETERS_EXT = ".concrete.params.json";
|
||||
|
||||
typedef uint64_t DecompositionLevelCount;
|
||||
typedef uint64_t DecompositionBaseLog;
|
||||
typedef uint64_t PolynomialSize;
|
||||
typedef uint64_t Precision;
|
||||
typedef double Variance;
|
||||
typedef std::vector<int64_t> CRTDecomposition;
|
||||
|
||||
typedef uint64_t LweDimension;
|
||||
typedef uint64_t GlweDimension;
|
||||
|
||||
typedef uint64_t LweSecretKeyID;
|
||||
struct LweSecretKeyParam {
|
||||
LweDimension dimension;
|
||||
|
||||
void hash(size_t &seed);
|
||||
inline uint64_t lweDimension() { return dimension; }
|
||||
inline uint64_t lweSize() { return dimension + 1; }
|
||||
inline uint64_t byteSize() { return lweSize() * 8; }
|
||||
};
|
||||
static bool operator==(const LweSecretKeyParam &lhs,
|
||||
const LweSecretKeyParam &rhs) {
|
||||
return lhs.dimension == rhs.dimension;
|
||||
}
|
||||
|
||||
typedef uint64_t BootstrapKeyID;
|
||||
struct BootstrapKeyParam {
|
||||
LweSecretKeyID inputSecretKeyID;
|
||||
LweSecretKeyID outputSecretKeyID;
|
||||
DecompositionLevelCount level;
|
||||
DecompositionBaseLog baseLog;
|
||||
GlweDimension glweDimension;
|
||||
Variance variance;
|
||||
PolynomialSize polynomialSize;
|
||||
LweDimension inputLweDimension;
|
||||
|
||||
void hash(size_t &seed);
|
||||
|
||||
uint64_t byteSize(uint64_t inputLweSize, uint64_t outputLweSize) {
|
||||
return inputLweSize * level * (glweDimension + 1) * (glweDimension + 1) *
|
||||
outputLweSize * 8;
|
||||
}
|
||||
};
|
||||
static inline bool operator==(const BootstrapKeyParam &lhs,
|
||||
const BootstrapKeyParam &rhs) {
|
||||
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
|
||||
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
|
||||
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
|
||||
lhs.glweDimension == rhs.glweDimension && lhs.variance == rhs.variance;
|
||||
}
|
||||
|
||||
typedef uint64_t KeyswitchKeyID;
|
||||
struct KeyswitchKeyParam {
|
||||
LweSecretKeyID inputSecretKeyID;
|
||||
LweSecretKeyID outputSecretKeyID;
|
||||
DecompositionLevelCount level;
|
||||
DecompositionBaseLog baseLog;
|
||||
Variance variance;
|
||||
|
||||
void hash(size_t &seed);
|
||||
|
||||
size_t byteSize(size_t inputLweSize, size_t outputLweSize) {
|
||||
return level * inputLweSize * outputLweSize * 8;
|
||||
}
|
||||
};
|
||||
static inline bool operator==(const KeyswitchKeyParam &lhs,
|
||||
const KeyswitchKeyParam &rhs) {
|
||||
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
|
||||
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
|
||||
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
|
||||
lhs.variance == rhs.variance;
|
||||
}
|
||||
|
||||
typedef uint64_t PackingKeyswitchKeyID;
|
||||
struct PackingKeyswitchKeyParam {
|
||||
LweSecretKeyID inputSecretKeyID;
|
||||
LweSecretKeyID outputSecretKeyID;
|
||||
DecompositionLevelCount level;
|
||||
DecompositionBaseLog baseLog;
|
||||
GlweDimension glweDimension;
|
||||
PolynomialSize polynomialSize;
|
||||
LweDimension inputLweDimension;
|
||||
Variance variance;
|
||||
|
||||
void hash(size_t &seed);
|
||||
};
|
||||
static inline bool operator==(const PackingKeyswitchKeyParam &lhs,
|
||||
const PackingKeyswitchKeyParam &rhs) {
|
||||
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
|
||||
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
|
||||
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
|
||||
lhs.glweDimension == rhs.glweDimension &&
|
||||
lhs.polynomialSize == rhs.polynomialSize &&
|
||||
lhs.variance == lhs.variance &&
|
||||
lhs.inputLweDimension == rhs.inputLweDimension;
|
||||
}
|
||||
|
||||
struct Encoding {
|
||||
Precision precision;
|
||||
CRTDecomposition crt;
|
||||
bool isSigned;
|
||||
};
|
||||
static inline bool operator==(const Encoding &lhs, const Encoding &rhs) {
|
||||
return lhs.precision == rhs.precision && lhs.isSigned == rhs.isSigned;
|
||||
}
|
||||
|
||||
struct EncryptionGate {
|
||||
LweSecretKeyID secretKeyID;
|
||||
Variance variance;
|
||||
Encoding encoding;
|
||||
};
|
||||
static inline bool operator==(const EncryptionGate &lhs,
|
||||
const EncryptionGate &rhs) {
|
||||
return lhs.secretKeyID == rhs.secretKeyID && lhs.variance == rhs.variance &&
|
||||
lhs.encoding == rhs.encoding;
|
||||
}
|
||||
|
||||
struct CircuitGateShape {
|
||||
/// Width of the scalar value
|
||||
uint64_t width;
|
||||
/// Dimensions of the tensor, empty if scalar
|
||||
std::vector<int64_t> dimensions;
|
||||
/// Size of the buffer containing the tensor
|
||||
uint64_t size;
|
||||
// Indicated whether elements are signed
|
||||
bool sign;
|
||||
};
|
||||
static inline bool operator==(const CircuitGateShape &lhs,
|
||||
const CircuitGateShape &rhs) {
|
||||
return lhs.width == rhs.width && lhs.dimensions == rhs.dimensions &&
|
||||
lhs.size == rhs.size;
|
||||
}
|
||||
|
||||
struct ChunkInfo {
|
||||
/// total number of bits used for the chunk including the carry.
|
||||
/// size should be at least width + 1
|
||||
unsigned int size;
|
||||
/// number of bits used for the chunk excluding the carry
|
||||
unsigned int width;
|
||||
};
|
||||
static inline bool operator==(const ChunkInfo &lhs, const ChunkInfo &rhs) {
|
||||
return lhs.width == rhs.width && lhs.size == rhs.size;
|
||||
}
|
||||
|
||||
struct CircuitGate {
|
||||
llvm::Optional<EncryptionGate> encryption;
|
||||
CircuitGateShape shape;
|
||||
llvm::Optional<ChunkInfo> chunkInfo;
|
||||
|
||||
bool isEncrypted() { return encryption.hasValue(); }
|
||||
|
||||
/// byteSize returns the size in bytes for this gate.
|
||||
size_t byteSize(std::vector<LweSecretKeyParam> secretKeys) {
|
||||
auto width = shape.width;
|
||||
auto numElts = shape.size == 0 ? 1 : shape.size;
|
||||
if (isEncrypted()) {
|
||||
assert(encryption->secretKeyID < secretKeys.size());
|
||||
auto skParam = secretKeys[encryption->secretKeyID];
|
||||
return 8 * skParam.lweSize() * numElts;
|
||||
}
|
||||
width = bitWidthAsWord(width) / 8;
|
||||
return width * numElts;
|
||||
}
|
||||
};
|
||||
static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) {
|
||||
return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape &&
|
||||
lhs.chunkInfo == rhs.chunkInfo;
|
||||
}
|
||||
|
||||
struct ClientParameters {
|
||||
std::vector<LweSecretKeyParam> secretKeys;
|
||||
std::vector<BootstrapKeyParam> bootstrapKeys;
|
||||
std::vector<KeyswitchKeyParam> keyswitchKeys;
|
||||
std::vector<PackingKeyswitchKeyParam> packingKeyswitchKeys;
|
||||
std::vector<CircuitGate> inputs;
|
||||
std::vector<CircuitGate> outputs;
|
||||
std::string functionName;
|
||||
|
||||
size_t hash();
|
||||
|
||||
static outcome::checked<std::vector<ClientParameters>, StringError>
|
||||
load(std::string path);
|
||||
|
||||
static std::string getClientParametersPath(std::string path);
|
||||
|
||||
outcome::checked<CircuitGate, StringError> input(size_t pos) {
|
||||
if (pos >= inputs.size()) {
|
||||
return StringError("input gate ") << pos << " didn't exists";
|
||||
}
|
||||
return inputs[pos];
|
||||
}
|
||||
|
||||
outcome::checked<CircuitGate, StringError> ouput(size_t pos) {
|
||||
if (pos >= outputs.size()) {
|
||||
return StringError("output gate ") << pos << " didn't exists";
|
||||
}
|
||||
return outputs[pos];
|
||||
}
|
||||
|
||||
outcome::checked<LweSecretKeyParam, StringError>
|
||||
lweSecretKeyParam(CircuitGate gate) {
|
||||
if (!gate.encryption.hasValue()) {
|
||||
return StringError("gate is not encrypted");
|
||||
}
|
||||
assert(gate.encryption->secretKeyID < secretKeys.size());
|
||||
auto secretKey = secretKeys[gate.encryption->secretKeyID];
|
||||
return secretKey;
|
||||
}
|
||||
|
||||
/// bufferSize returns the size of the whole buffer of a gate.
|
||||
int64_t bufferSize(CircuitGate gate) {
|
||||
if (!gate.encryption.hasValue()) {
|
||||
// Value is not encrypted just returns the tensor size
|
||||
return gate.shape.size;
|
||||
}
|
||||
auto shapeSize = gate.shape.size == 0 ? 1 : gate.shape.size;
|
||||
// Size of the ciphertext
|
||||
return shapeSize * lweBufferSize(gate);
|
||||
}
|
||||
|
||||
/// lweBufferSize returns the size of one ciphertext of a gate.
|
||||
int64_t lweBufferSize(CircuitGate gate) {
|
||||
assert(gate.encryption.hasValue());
|
||||
auto nbBlocks = gate.encryption->encoding.crt.size();
|
||||
nbBlocks = nbBlocks == 0 ? 1 : nbBlocks;
|
||||
|
||||
auto param = lweSecretKeyParam(gate);
|
||||
assert(param.has_value());
|
||||
return param.value().lweSize() * nbBlocks;
|
||||
}
|
||||
|
||||
/// bufferShape returns the shape of the tensor for the given gate. It returns
|
||||
/// the shape used at low-level, i.e. contains the dimensions for ciphertexts.
|
||||
std::vector<int64_t> bufferShape(CircuitGate gate) {
|
||||
if (!gate.encryption.hasValue()) {
|
||||
// Value is not encrypted just returns the tensor shape
|
||||
return gate.shape.dimensions;
|
||||
}
|
||||
auto lweSecreteKeyParam = lweSecretKeyParam(gate);
|
||||
assert(lweSecreteKeyParam.has_value());
|
||||
|
||||
// Copy the shape
|
||||
std::vector<int64_t> shape(gate.shape.dimensions);
|
||||
|
||||
auto crt = gate.encryption->encoding.crt;
|
||||
|
||||
// CRT case: Add one dimension equals to the number of blocks
|
||||
if (!crt.empty()) {
|
||||
shape.push_back(crt.size());
|
||||
}
|
||||
// Add one dimension for the size of ciphertext(s)
|
||||
shape.push_back(lweSecreteKeyParam.value().lweSize());
|
||||
return shape;
|
||||
}
|
||||
};
|
||||
|
||||
static inline bool operator==(const ClientParameters &lhs,
|
||||
const ClientParameters &rhs) {
|
||||
return lhs.secretKeys == rhs.secretKeys &&
|
||||
lhs.bootstrapKeys == rhs.bootstrapKeys &&
|
||||
lhs.keyswitchKeys == rhs.keyswitchKeys && lhs.inputs == lhs.inputs &&
|
||||
lhs.outputs == lhs.outputs;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const LweSecretKeyParam &);
|
||||
bool fromJSON(const llvm::json::Value, LweSecretKeyParam &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const BootstrapKeyParam &);
|
||||
bool fromJSON(const llvm::json::Value, BootstrapKeyParam &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const KeyswitchKeyParam &);
|
||||
bool fromJSON(const llvm::json::Value, KeyswitchKeyParam &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const PackingKeyswitchKeyParam &);
|
||||
bool fromJSON(const llvm::json::Value, PackingKeyswitchKeyParam &,
|
||||
llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const Encoding &);
|
||||
bool fromJSON(const llvm::json::Value, Encoding &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const EncryptionGate &);
|
||||
bool fromJSON(const llvm::json::Value, EncryptionGate &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const CircuitGateShape &);
|
||||
bool fromJSON(const llvm::json::Value, CircuitGateShape &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const CircuitGate &);
|
||||
bool fromJSON(const llvm::json::Value, CircuitGate &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const ClientParameters &);
|
||||
bool fromJSON(const llvm::json::Value, ClientParameters &, llvm::json::Path);
|
||||
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
ClientParameters cp) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(cp));
|
||||
}
|
||||
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_string_ostream &OS,
|
||||
ClientParameters cp) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(cp));
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,244 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H
|
||||
#define CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H
|
||||
|
||||
#include <ostream>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "../Common/Error.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Common/BitsSize.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
class PublicArguments;
|
||||
|
||||
/// Temporary object used to hold and encrypt parameters before calling a
|
||||
/// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...).
|
||||
/// Otherwise convert it to a PublicArguments and use
|
||||
/// serializeCall(PublicArguments, KeySet).
|
||||
class EncryptedArguments {
|
||||
|
||||
public:
|
||||
EncryptedArguments() : currentPos(0) {}
|
||||
|
||||
/// Encrypts args thanks the given KeySet and pack the encrypted arguments to
|
||||
/// an EncryptedArguments
|
||||
template <typename... Args>
|
||||
static outcome::checked<std::unique_ptr<EncryptedArguments>, StringError>
|
||||
create(KeySet &keySet, Args... args) {
|
||||
auto encryptedArgs = std::make_unique<EncryptedArguments>();
|
||||
OUTCOME_TRYV(encryptedArgs->pushArgs(keySet, args...));
|
||||
return std::move(encryptedArgs);
|
||||
}
|
||||
|
||||
template <typename ArgT>
|
||||
static outcome::checked<std::unique_ptr<EncryptedArguments>, StringError>
|
||||
create(KeySet &keySet, const llvm::ArrayRef<ArgT> args) {
|
||||
auto encryptedArgs = EncryptedArguments::empty();
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
OUTCOME_TRYV(encryptedArgs->pushArg(args[i], keySet));
|
||||
}
|
||||
OUTCOME_TRYV(encryptedArgs->checkAllArgs(keySet));
|
||||
return std::move(encryptedArgs);
|
||||
}
|
||||
|
||||
static std::unique_ptr<EncryptedArguments> empty() {
|
||||
return std::make_unique<EncryptedArguments>();
|
||||
}
|
||||
|
||||
/// Export encrypted arguments as public arguments, reset the encrypted
|
||||
/// arguments, i.e. move all buffers to the PublicArguments and reset the
|
||||
/// positional counter.
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
exportPublicArguments(ClientParameters clientParameters);
|
||||
|
||||
/// Check that all arguments as been pushed.
|
||||
// TODO: Remove public method here
|
||||
outcome::checked<void, StringError> checkAllArgs(KeySet &keySet);
|
||||
|
||||
public:
|
||||
/// Add a uint64_t scalar argument.
|
||||
outcome::checked<void, StringError> pushArg(uint64_t arg, KeySet &keySet);
|
||||
|
||||
/// Add a vector-tensor argument.
|
||||
outcome::checked<void, StringError> pushArg(std::vector<uint8_t> arg,
|
||||
KeySet &keySet) {
|
||||
return pushArg((uint8_t *)arg.data(),
|
||||
llvm::ArrayRef<int64_t>{(int64_t)arg.size()}, keySet);
|
||||
}
|
||||
|
||||
/// Add a 1D tensor argument with data and size of the dimension.
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError> pushArg(const T *data, int64_t dim1,
|
||||
KeySet &keySet) {
|
||||
return pushArg(std::vector<uint8_t>(data, data + dim1), keySet);
|
||||
}
|
||||
|
||||
/// Add a 1D tensor argument.
|
||||
template <size_t size>
|
||||
outcome::checked<void, StringError> pushArg(std::array<uint8_t, size> arg,
|
||||
KeySet &keySet) {
|
||||
return pushArg((uint8_t *)arg.data(), llvm::ArrayRef<int64_t>{size},
|
||||
keySet);
|
||||
}
|
||||
|
||||
/// Add a 2D tensor argument.
|
||||
template <size_t size0, size_t size1>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(std::array<std::array<uint8_t, size1>, size0> arg, KeySet &keySet) {
|
||||
return pushArg((uint8_t *)arg.data(), llvm::ArrayRef<int64_t>{size0, size1},
|
||||
keySet);
|
||||
}
|
||||
|
||||
/// Add a 3D tensor argument.
|
||||
template <size_t size0, size_t size1, size_t size2>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(std::array<std::array<std::array<uint8_t, size2>, size1>, size0> arg,
|
||||
KeySet &keySet) {
|
||||
return pushArg((uint8_t *)arg.data(),
|
||||
llvm::ArrayRef<int64_t>{size0, size1, size2}, keySet);
|
||||
}
|
||||
|
||||
// Generalize by computing shape by template recursion
|
||||
|
||||
/// Set a argument at the given pos as a 1D tensor of T.
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError> pushArg(T *data, int64_t dim1,
|
||||
KeySet &keySet) {
|
||||
return pushArg<T>(data, llvm::ArrayRef<int64_t>(&dim1, 1), keySet);
|
||||
}
|
||||
|
||||
/// Set a argument at the given pos as a tensor of T.
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
|
||||
return pushArg(static_cast<const T *>(data), shape, keySet);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(const T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
|
||||
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
|
||||
auto pos = currentPos;
|
||||
CircuitGate input = keySet.inputGate(pos);
|
||||
// Check the width of data
|
||||
if (input.shape.width > 64) {
|
||||
return StringError("argument #")
|
||||
<< pos << " width > 64 bits is not supported";
|
||||
}
|
||||
// Check the shape of tensor
|
||||
if (input.shape.dimensions.empty()) {
|
||||
return StringError("argument #") << pos << "is not a tensor";
|
||||
}
|
||||
if (shape.size() != input.shape.dimensions.size()) {
|
||||
return StringError("argument #")
|
||||
<< pos << "has not the expected number of dimension, got "
|
||||
<< shape.size() << " expected " << input.shape.dimensions.size();
|
||||
}
|
||||
|
||||
// Check shape
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
if (shape[i] != input.shape.dimensions[i]) {
|
||||
return StringError("argument #")
|
||||
<< pos << " has not the expected dimension #" << i << " , got "
|
||||
<< shape[i] << " expected " << input.shape.dimensions[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Set sizes
|
||||
std::vector<int64_t> sizes = keySet.clientParameters().bufferShape(input);
|
||||
|
||||
if (input.encryption.hasValue()) {
|
||||
TensorData td(sizes, EncryptedScalarElementType,
|
||||
EncryptedScalarElementWidth);
|
||||
|
||||
auto lweSize = keySet.clientParameters().lweBufferSize(input);
|
||||
|
||||
for (size_t i = 0, offset = 0; i < input.shape.size;
|
||||
i++, offset += lweSize) {
|
||||
OUTCOME_TRYV(keySet.encrypt_lwe(
|
||||
pos, td.getElementPointer<uint64_t>(offset), data[i]));
|
||||
}
|
||||
ciphertextBuffers.push_back(std::move(td));
|
||||
} else {
|
||||
auto bitsPerValue = bitWidthAsWord(input.shape.width);
|
||||
|
||||
TensorData td(sizes, bitsPerValue, input.shape.sign);
|
||||
llvm::ArrayRef<T> values(data, TensorData::getNumElements(sizes));
|
||||
td.bulkAssign(values);
|
||||
ciphertextBuffers.push_back(std::move(td));
|
||||
}
|
||||
TensorData &td = ciphertextBuffers.back().getTensor();
|
||||
|
||||
// allocated
|
||||
preparedArgs.push_back(nullptr);
|
||||
// aligned
|
||||
preparedArgs.push_back(td.getValuesAsOpaquePointer());
|
||||
// offset
|
||||
preparedArgs.push_back((void *)0);
|
||||
// sizes
|
||||
for (size_t size : td.getDimensions()) {
|
||||
preparedArgs.push_back((void *)size);
|
||||
}
|
||||
|
||||
// Set the stride for each dimension, equal to the product of the
|
||||
// following dimensions.
|
||||
int64_t stride = td.getNumElements();
|
||||
for (size_t size : td.getDimensions()) {
|
||||
stride = (size == 0 ? 0 : (stride / size));
|
||||
preparedArgs.push_back((void *)stride);
|
||||
}
|
||||
currentPos++;
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
/// Recursive case for scalars: extract first scalar argument from
|
||||
/// parameter pack and forward rest
|
||||
template <typename Arg0, typename... OtherArgs>
|
||||
outcome::checked<void, StringError> pushArgs(KeySet &keySet, Arg0 arg0,
|
||||
OtherArgs... others) {
|
||||
OUTCOME_TRYV(pushArg(arg0, keySet));
|
||||
return pushArgs(keySet, others...);
|
||||
}
|
||||
|
||||
/// Recursive case for tensors: extract pointer and size from
|
||||
/// parameter pack and forward rest
|
||||
template <typename Arg0, typename... OtherArgs>
|
||||
outcome::checked<void, StringError>
|
||||
pushArgs(KeySet &keySet, Arg0 *arg0, size_t size, OtherArgs... others) {
|
||||
OUTCOME_TRYV(pushArg(arg0, size, keySet));
|
||||
return pushArgs(keySet, others...);
|
||||
}
|
||||
|
||||
/// Terminal case of pushArgs
|
||||
outcome::checked<void, StringError> pushArgs(KeySet &keySet) {
|
||||
return checkAllArgs(keySet);
|
||||
}
|
||||
|
||||
private:
|
||||
outcome::checked<void, StringError> checkPushTooManyArgs(KeySet &keySet);
|
||||
|
||||
private:
|
||||
/// Position of the next pushed argument
|
||||
size_t currentPos;
|
||||
std::vector<void *> preparedArgs;
|
||||
|
||||
/// Store buffers of ciphertexts
|
||||
std::vector<ScalarOrTensorData> ciphertextBuffers;
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,194 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
|
||||
#define CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
|
||||
|
||||
#include <cassert>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
struct Csprng;
|
||||
struct CsprngVtable;
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
class CSPRNG {
|
||||
public:
|
||||
struct Csprng *ptr;
|
||||
const struct CsprngVtable *vtable;
|
||||
|
||||
CSPRNG() = delete;
|
||||
CSPRNG(CSPRNG &) = delete;
|
||||
|
||||
CSPRNG(CSPRNG &&other) : ptr(other.ptr), vtable(other.vtable) {
|
||||
assert(ptr != nullptr);
|
||||
other.ptr = nullptr;
|
||||
};
|
||||
|
||||
CSPRNG(Csprng *ptr, const CsprngVtable *vtable) : ptr(ptr), vtable(vtable){};
|
||||
};
|
||||
|
||||
class ConcreteCSPRNG : public CSPRNG {
|
||||
public:
|
||||
ConcreteCSPRNG(__uint128_t seed);
|
||||
ConcreteCSPRNG() = delete;
|
||||
ConcreteCSPRNG(ConcreteCSPRNG &) = delete;
|
||||
ConcreteCSPRNG(ConcreteCSPRNG &&other);
|
||||
~ConcreteCSPRNG();
|
||||
};
|
||||
|
||||
/// @brief LweSecretKey implements tools for manipulating lwe secret key on
|
||||
/// client.
|
||||
class LweSecretKey {
|
||||
std::shared_ptr<std::vector<uint64_t>> _buffer;
|
||||
LweSecretKeyParam _parameters;
|
||||
|
||||
public:
|
||||
LweSecretKey() = delete;
|
||||
LweSecretKey(LweSecretKeyParam ¶meters, CSPRNG &csprng);
|
||||
LweSecretKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
LweSecretKeyParam parameters)
|
||||
: _buffer(buffer), _parameters(parameters){};
|
||||
|
||||
/// @brief Encrypt the plaintext to the lwe ciphertext buffer.
|
||||
void encrypt(uint64_t *ciphertext, uint64_t plaintext, double variance,
|
||||
CSPRNG &csprng) const;
|
||||
|
||||
/// @brief Decrypt the ciphertext to the plaintext
|
||||
void decrypt(const uint64_t *ciphertext, uint64_t &plaintext) const;
|
||||
|
||||
/// @brief Returns the buffer that hold the keyswitch key.
|
||||
const uint64_t *buffer() const { return _buffer->data(); }
|
||||
size_t size() const { return _buffer->size(); }
|
||||
|
||||
/// @brief Returns the parameters of the keyswicth key.
|
||||
LweSecretKeyParam parameters() const { return this->_parameters; }
|
||||
|
||||
/// @brief Returns the lwe dimension of the secret key.
|
||||
size_t dimension() const { return parameters().dimension; }
|
||||
};
|
||||
|
||||
/// @brief LweKeyswitchKey implements tools for manipulating keyswitch key on
|
||||
/// client.
|
||||
class LweKeyswitchKey {
|
||||
private:
|
||||
std::shared_ptr<std::vector<uint64_t>> _buffer;
|
||||
KeyswitchKeyParam _parameters;
|
||||
|
||||
public:
|
||||
LweKeyswitchKey() = delete;
|
||||
LweKeyswitchKey(KeyswitchKeyParam ¶meters, LweSecretKey &inputKey,
|
||||
LweSecretKey &outputKey, CSPRNG &csprng);
|
||||
LweKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
KeyswitchKeyParam parameters)
|
||||
: _buffer(buffer), _parameters(parameters){};
|
||||
|
||||
/// @brief Returns the buffer that hold the keyswitch key.
|
||||
const uint64_t *buffer() const { return _buffer->data(); }
|
||||
size_t size() const { return _buffer->size(); }
|
||||
|
||||
/// @brief Returns the parameters of the keyswicth key.
|
||||
KeyswitchKeyParam parameters() const { return this->_parameters; }
|
||||
};
|
||||
|
||||
/// @brief LweBootstrapKey implements tools for manipulating bootstrap key on
|
||||
/// client.
|
||||
class LweBootstrapKey {
|
||||
private:
|
||||
std::shared_ptr<std::vector<uint64_t>> _buffer;
|
||||
BootstrapKeyParam _parameters;
|
||||
|
||||
public:
|
||||
LweBootstrapKey() = delete;
|
||||
LweBootstrapKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
BootstrapKeyParam ¶meters)
|
||||
: _buffer(buffer), _parameters(parameters){};
|
||||
LweBootstrapKey(BootstrapKeyParam ¶meters, LweSecretKey &inputKey,
|
||||
LweSecretKey &outputKey, CSPRNG &csprng);
|
||||
|
||||
///// @brief Returns the buffer that hold the bootstrap key.
|
||||
const uint64_t *buffer() const { return _buffer->data(); }
|
||||
size_t size() const { return _buffer->size(); }
|
||||
|
||||
/// @brief Returns the parameters of the bootsrap key.
|
||||
BootstrapKeyParam parameters() const { return this->_parameters; }
|
||||
};
|
||||
|
||||
/// @brief PackingKeyswitchKey implements tools for manipulating privat packing
|
||||
/// keyswitch key on client.
|
||||
class PackingKeyswitchKey {
|
||||
private:
|
||||
std::shared_ptr<std::vector<uint64_t>> _buffer;
|
||||
PackingKeyswitchKeyParam _parameters;
|
||||
|
||||
public:
|
||||
PackingKeyswitchKey() = delete;
|
||||
PackingKeyswitchKey(PackingKeyswitchKeyParam ¶meters,
|
||||
LweSecretKey &inputKey, LweSecretKey &outputKey,
|
||||
CSPRNG &csprng);
|
||||
PackingKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
PackingKeyswitchKeyParam parameters)
|
||||
: _buffer(buffer), _parameters(parameters){};
|
||||
|
||||
/// @brief Returns the buffer that hold the keyswitch key.
|
||||
const uint64_t *buffer() const { return _buffer->data(); }
|
||||
size_t size() const { return _buffer->size(); }
|
||||
|
||||
/// @brief Returns the parameters of the keyswicth key.
|
||||
PackingKeyswitchKeyParam parameters() const { return this->_parameters; }
|
||||
};
|
||||
|
||||
// =============================================
|
||||
|
||||
/// Evalution keys required for execution.
|
||||
class EvaluationKeys {
|
||||
private:
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys;
|
||||
std::vector<LweBootstrapKey> bootstrapKeys;
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
|
||||
|
||||
public:
|
||||
EvaluationKeys() = delete;
|
||||
|
||||
EvaluationKeys(const std::vector<LweKeyswitchKey> keyswitchKeys,
|
||||
const std::vector<LweBootstrapKey> bootstrapKeys,
|
||||
const std::vector<PackingKeyswitchKey> packingKeyswitchKeys)
|
||||
: keyswitchKeys(keyswitchKeys), bootstrapKeys(bootstrapKeys),
|
||||
packingKeyswitchKeys(packingKeyswitchKeys) {}
|
||||
|
||||
const LweKeyswitchKey &getKeyswitchKey(size_t id) const {
|
||||
return this->keyswitchKeys[id];
|
||||
}
|
||||
const std::vector<LweKeyswitchKey> getKeyswitchKeys() const {
|
||||
return this->keyswitchKeys;
|
||||
}
|
||||
|
||||
const LweBootstrapKey &getBootstrapKey(size_t id) const {
|
||||
return bootstrapKeys[id];
|
||||
}
|
||||
const std::vector<LweBootstrapKey> getBootstrapKeys() const {
|
||||
return this->bootstrapKeys;
|
||||
}
|
||||
|
||||
const PackingKeyswitchKey &getPackingKeyswitchKey(size_t id) const {
|
||||
return this->packingKeyswitchKeys[id];
|
||||
};
|
||||
|
||||
const std::vector<PackingKeyswitchKey> getPackingKeyswitchKeys() const {
|
||||
return this->packingKeyswitchKeys;
|
||||
}
|
||||
};
|
||||
|
||||
// =============================================
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,128 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_KEYSET_H_
|
||||
#define CONCRETELANG_CLIENTLIB_KEYSET_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
class KeySet {
|
||||
public:
|
||||
KeySet(ClientParameters clientParameters, CSPRNG &&csprng)
|
||||
: csprng(std::move(csprng)), _clientParameters(clientParameters){};
|
||||
KeySet(KeySet &other) = delete;
|
||||
|
||||
/// Generate a KeySet from a ClientParameters specification.
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
generate(ClientParameters clientParameters, CSPRNG &&csprng);
|
||||
|
||||
/// Create a KeySet from a set of given keys
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError> fromKeys(
|
||||
ClientParameters clientParameters, std::vector<LweSecretKey> secretKeys,
|
||||
std::vector<LweBootstrapKey> bootstrapKeys,
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys,
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys, CSPRNG &&csprng);
|
||||
|
||||
/// Returns the ClientParameters associated with the KeySet.
|
||||
ClientParameters clientParameters() { return _clientParameters; }
|
||||
|
||||
// isInputEncrypted return true if the input at the given pos is encrypted.
|
||||
bool isInputEncrypted(size_t pos);
|
||||
|
||||
/// allocate a lwe ciphertext buffer for the argument at argPos, set the size
|
||||
/// of the allocated buffer.
|
||||
outcome::checked<void, StringError>
|
||||
allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size);
|
||||
|
||||
/// encrypt the input to the ciphertext for the argument at argPos.
|
||||
outcome::checked<void, StringError>
|
||||
encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input);
|
||||
|
||||
/// isOuputEncrypted return true if the output at the given pos is encrypted.
|
||||
bool isOutputEncrypted(size_t pos);
|
||||
|
||||
/// decrypt the ciphertext to the output for the argument at argPos.
|
||||
outcome::checked<void, StringError>
|
||||
decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output);
|
||||
|
||||
size_t numInputs() { return inputs.size(); }
|
||||
size_t numOutputs() { return outputs.size(); }
|
||||
|
||||
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
|
||||
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
|
||||
|
||||
/// @brief evaluationKeys returns the evaluation keys associate to this client
|
||||
/// keyset. Those evaluations keys can be safely shared publicly
|
||||
EvaluationKeys evaluationKeys();
|
||||
|
||||
const std::vector<LweSecretKey> &getSecretKeys();
|
||||
|
||||
const std::vector<LweBootstrapKey> &getBootstrapKeys();
|
||||
|
||||
const std::vector<LweKeyswitchKey> &getKeyswitchKeys();
|
||||
|
||||
const std::vector<PackingKeyswitchKey> &getPackingKeyswitchKeys();
|
||||
|
||||
protected:
|
||||
outcome::checked<void, StringError>
|
||||
generateSecretKey(LweSecretKeyParam param);
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
generateBootstrapKey(BootstrapKeyParam param);
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
generateKeyswitchKey(KeyswitchKeyParam param);
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
generatePackingKeyswitchKey(PackingKeyswitchKeyParam param);
|
||||
|
||||
outcome::checked<void, StringError> generateKeysFromParams();
|
||||
|
||||
outcome::checked<void, StringError> setupEncryptionMaterial();
|
||||
|
||||
friend class KeySetCache;
|
||||
|
||||
private:
|
||||
CSPRNG csprng;
|
||||
|
||||
///////////////////////////////////////////////
|
||||
// Keys mappings
|
||||
std::vector<LweSecretKey> secretKeys;
|
||||
std::vector<LweBootstrapKey> bootstrapKeys;
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys;
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
|
||||
|
||||
outcome::checked<LweSecretKey, StringError> findLweSecretKey(LweSecretKeyID);
|
||||
|
||||
///////////////////////////////////////////////
|
||||
// Convenient positional mapping between positional gate en secret key
|
||||
typedef std::vector<std::pair<CircuitGate, llvm::Optional<LweSecretKey>>>
|
||||
SecretKeyGateMapping;
|
||||
outcome::checked<SecretKeyGateMapping, StringError>
|
||||
mapCircuitGateLweSecretKey(std::vector<CircuitGate> gates);
|
||||
|
||||
SecretKeyGateMapping inputs;
|
||||
SecretKeyGateMapping outputs;
|
||||
|
||||
clientlib::ClientParameters _clientParameters;
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,43 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_KEYSETCACHE_H_
|
||||
#define CONCRETELANG_CLIENTLIB_KEYSETCACHE_H_
|
||||
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
class KeySet;
|
||||
|
||||
class KeySetCache {
|
||||
std::string backingDirectoryPath;
|
||||
|
||||
public:
|
||||
KeySetCache(std::string backingDirectoryPath)
|
||||
: backingDirectoryPath(backingDirectoryPath) {}
|
||||
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
generate(std::shared_ptr<KeySetCache> optionalCache, ClientParameters ¶ms,
|
||||
uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
private:
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
loadKeys(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb,
|
||||
std::string folderPath);
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
loadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb);
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,186 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_PUBLIC_ARGUMENTS_H
|
||||
#define CONCRETELANG_CLIENTLIB_PUBLIC_ARGUMENTS_H
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EncryptedArguments.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
class ServerLambda;
|
||||
}
|
||||
} // namespace concretelang
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
class JITLambda;
|
||||
}
|
||||
} // namespace mlir
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
class EncryptedArguments;
|
||||
|
||||
/// PublicArguments will be sended to the server. It includes encrypted
|
||||
/// arguments and public keys.
|
||||
class PublicArguments {
|
||||
public:
|
||||
PublicArguments(const ClientParameters &clientParameters,
|
||||
std::vector<void *> &&preparedArgs,
|
||||
std::vector<ScalarOrTensorData> &&ciphertextBuffers);
|
||||
~PublicArguments();
|
||||
PublicArguments(PublicArguments &other) = delete;
|
||||
PublicArguments(PublicArguments &&other) = delete;
|
||||
|
||||
static outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
unserialize(ClientParameters &expectedParams, std::istream &istream);
|
||||
|
||||
outcome::checked<void, StringError> serialize(std::ostream &ostream);
|
||||
|
||||
private:
|
||||
friend class ::concretelang::serverlib::ServerLambda;
|
||||
friend class ::mlir::concretelang::JITLambda;
|
||||
|
||||
outcome::checked<void, StringError> unserializeArgs(std::istream &istream);
|
||||
|
||||
ClientParameters clientParameters;
|
||||
std::vector<void *> preparedArgs;
|
||||
/// Store buffers of ciphertexts
|
||||
std::vector<ScalarOrTensorData> ciphertextBuffers;
|
||||
};
|
||||
|
||||
/// PublicResult is a result of a ServerLambda call which contains encrypted
|
||||
/// results.
|
||||
struct PublicResult {
|
||||
|
||||
PublicResult(const ClientParameters &clientParameters,
|
||||
std::vector<ScalarOrTensorData> &&buffers = {})
|
||||
: clientParameters(clientParameters), buffers(std::move(buffers)){};
|
||||
|
||||
PublicResult(PublicResult &) = delete;
|
||||
|
||||
/// Create a public result from buffers.
|
||||
static std::unique_ptr<PublicResult>
|
||||
fromBuffers(const ClientParameters &clientParameters,
|
||||
std::vector<ScalarOrTensorData> &&buffers) {
|
||||
return std::make_unique<PublicResult>(clientParameters, std::move(buffers));
|
||||
}
|
||||
|
||||
/// Unserialize from an input stream inplace.
|
||||
outcome::checked<void, StringError> unserialize(std::istream &istream);
|
||||
/// Unserialize from an input stream returning a new PublicResult.
|
||||
static outcome::checked<std::unique_ptr<PublicResult>, StringError>
|
||||
unserialize(ClientParameters &expectedParams, std::istream &istream) {
|
||||
auto publicResult = std::make_unique<PublicResult>(expectedParams);
|
||||
OUTCOME_TRYV(publicResult->unserialize(istream));
|
||||
return std::move(publicResult);
|
||||
}
|
||||
/// Serialize into an output stream.
|
||||
outcome::checked<void, StringError> serialize(std::ostream &ostream);
|
||||
|
||||
/// Get the original integer that was decomposed into chunks of `chunkWidth`
|
||||
/// bits each
|
||||
uint64_t fromChunks(std::vector<uint64_t> chunks, unsigned int chunkWidth) {
|
||||
uint64_t value = 0;
|
||||
uint64_t mask = (1 << chunkWidth) - 1;
|
||||
for (size_t i = 0; i < chunks.size(); i++) {
|
||||
auto chunk = chunks[i] & mask;
|
||||
value += chunk << (chunkWidth * i);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
/// Get the result at `pos` as a scalar. Decryption happens if the
|
||||
/// result is encrypted.
|
||||
template <typename T>
|
||||
outcome::checked<T, StringError> asClearTextScalar(KeySet &keySet,
|
||||
size_t pos) {
|
||||
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
|
||||
if (!gate.isEncrypted())
|
||||
return buffers[pos].getScalar().getValue<T>();
|
||||
|
||||
// Chunked integers are represented as tensors at a lower level, so we need
|
||||
// to deal with them as tensors, then build the resulting scalar out of the
|
||||
// tensor values
|
||||
if (gate.chunkInfo.hasValue()) {
|
||||
OUTCOME_TRY(std::vector<uint64_t> decryptedChunks,
|
||||
this->asClearTextVector<uint64_t>(keySet, pos));
|
||||
uint64_t decrypted = fromChunks(decryptedChunks, gate.chunkInfo->width);
|
||||
return (T)decrypted;
|
||||
}
|
||||
|
||||
auto &buffer = buffers[pos].getTensor();
|
||||
|
||||
auto ciphertext = buffer.getOpaqueElementPointer(0);
|
||||
uint64_t decrypted;
|
||||
|
||||
// Convert to uint64_t* as required by `KeySet::decrypt_lwe`
|
||||
// FIXME: this may break alignment restrictions on some
|
||||
// architectures
|
||||
auto ciphertextu64 = reinterpret_cast<uint64_t *>(ciphertext);
|
||||
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertextu64, decrypted));
|
||||
|
||||
return (T)decrypted;
|
||||
}
|
||||
|
||||
/// Get the result at `pos` as a vector. Decryption happens if the
|
||||
/// result is encrypted.
|
||||
template <typename T>
|
||||
outcome::checked<std::vector<T>, StringError>
|
||||
asClearTextVector(KeySet &keySet, size_t pos) {
|
||||
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
|
||||
if (!gate.isEncrypted())
|
||||
return buffers[pos].getTensor().asFlatVector<T>();
|
||||
|
||||
auto &buffer = buffers[pos].getTensor();
|
||||
auto lweSize = clientParameters.lweBufferSize(gate);
|
||||
|
||||
std::vector<T> decryptedValues(buffer.length() / lweSize);
|
||||
for (size_t i = 0; i < decryptedValues.size(); i++) {
|
||||
auto ciphertext = buffer.getOpaqueElementPointer(i * lweSize);
|
||||
uint64_t decrypted;
|
||||
|
||||
// Convert to uint64_t* as required by `KeySet::decrypt_lwe`
|
||||
// FIXME: this may break alignment restrictions on some
|
||||
// architectures
|
||||
auto ciphertextu64 = reinterpret_cast<uint64_t *>(ciphertext);
|
||||
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertextu64, decrypted));
|
||||
decryptedValues[i] = decrypted;
|
||||
}
|
||||
return decryptedValues;
|
||||
}
|
||||
|
||||
/// Return the shape of the clear tensor of a result.
|
||||
outcome::checked<std::vector<int64_t>, StringError>
|
||||
asClearTextShape(size_t pos) {
|
||||
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
|
||||
return gate.shape.dimensions;
|
||||
}
|
||||
|
||||
// private: TODO tmp
|
||||
friend class ::concretelang::serverlib::ServerLambda;
|
||||
ClientParameters clientParameters;
|
||||
std::vector<ScalarOrTensorData> buffers;
|
||||
};
|
||||
|
||||
/// Helper function to convert from MemRefDescriptor to
|
||||
/// TensorData
|
||||
TensorData tensorDataFromMemRef(size_t memref_rank, size_t element_width,
|
||||
bool is_signed, void *allocated, void *aligned,
|
||||
size_t offset, size_t *sizes, size_t *strides);
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,123 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_SERIALIZERS_ARGUMENTS_H
|
||||
#define CONCRETELANG_CLIENTLIB_SERIALIZERS_ARGUMENTS_H
|
||||
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
// integers are not serialized as binary values even on a binary stream
|
||||
// so we cannot rely on << operator directly
|
||||
template <typename Word>
|
||||
std::ostream &writeWord(std::ostream &ostream, Word word) {
|
||||
ostream.write(reinterpret_cast<char *>(&(word)), sizeof(word));
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
|
||||
template <typename Size>
|
||||
std::ostream &writeSize(std::ostream &ostream, Size size) {
|
||||
return writeWord(ostream, size);
|
||||
}
|
||||
|
||||
// for sake of symetry
|
||||
template <typename Word>
|
||||
std::istream &readWord(std::istream &istream, Word &word) {
|
||||
istream.read(reinterpret_cast<char *>(&(word)), sizeof(word));
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
template <typename Word>
|
||||
std::istream &readWords(std::istream &istream, Word *words, size_t numWords) {
|
||||
assert(std::numeric_limits<size_t>::max() / sizeof(*words) > numWords);
|
||||
istream.read(reinterpret_cast<char *>(words), sizeof(*words) * numWords);
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
template <typename Size>
|
||||
std::istream &readSize(std::istream &istream, Size &size) {
|
||||
return readWord(istream, size);
|
||||
}
|
||||
|
||||
template <typename Stream> bool incorrectMode(Stream &stream) {
|
||||
auto binary = stream.flags() && std::ios::binary;
|
||||
if (!binary) {
|
||||
stream.setstate(std::ios::failbit);
|
||||
}
|
||||
return !binary;
|
||||
}
|
||||
|
||||
std::ostream &serializeScalarData(const ScalarData &sd, std::ostream &ostream);
|
||||
|
||||
outcome::checked<ScalarData, StringError>
|
||||
unserializeScalarData(std::istream &istream);
|
||||
|
||||
std::ostream &serializeTensorData(const TensorData &values_and_sizes,
|
||||
std::ostream &ostream);
|
||||
|
||||
template <typename T>
|
||||
std::ostream &serializeTensorDataRaw(const llvm::ArrayRef<size_t> &dimensions,
|
||||
const llvm::ArrayRef<T> &values,
|
||||
std::ostream &ostream) {
|
||||
|
||||
writeWord<uint64_t>(ostream, dimensions.size());
|
||||
|
||||
for (size_t dim : dimensions)
|
||||
writeWord<int64_t>(ostream, dim);
|
||||
|
||||
writeWord<uint64_t>(ostream, sizeof(T) * 8);
|
||||
writeWord<uint8_t>(ostream, std::is_signed<T>());
|
||||
|
||||
for (T val : values)
|
||||
writeWord(ostream, val);
|
||||
|
||||
return ostream;
|
||||
}
|
||||
|
||||
outcome::checked<TensorData, StringError> unserializeTensorData(
|
||||
std::vector<int64_t> &expectedSizes, // includes unsigned to
|
||||
// accomodate non static sizes
|
||||
std::istream &istream);
|
||||
|
||||
std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd,
|
||||
std::ostream &ostream);
|
||||
|
||||
outcome::checked<ScalarOrTensorData, StringError>
|
||||
unserializeScalarOrTensorData(const std::vector<int64_t> &expectedSizes,
|
||||
std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &wrappedKsk);
|
||||
LweSecretKey readLweSecretKey(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweKeyswitchKey &wrappedKsk);
|
||||
LweKeyswitchKey readLweKeyswitchKey(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweBootstrapKey &wrappedBsk);
|
||||
LweBootstrapKey readLweBootstrapKey(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const PackingKeyswitchKey &wrappedKsk);
|
||||
PackingKeyswitchKey readPackingKeyswitchKey(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const EvaluationKeys &evaluationKeys);
|
||||
EvaluationKeys readEvaluationKeys(std::istream &istream);
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,881 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_TYPES_H_
|
||||
#define CONCRETELANG_CLIENTLIB_TYPES_H_
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <stddef.h>
|
||||
#include <vector>
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
template <size_t N> struct MemRefDescriptor {
|
||||
uint64_t *allocated;
|
||||
uint64_t *aligned;
|
||||
size_t offset;
|
||||
size_t sizes[N];
|
||||
size_t strides[N];
|
||||
};
|
||||
|
||||
using decrypted_scalar_t = std::uint64_t;
|
||||
using decrypted_tensor_1_t = std::vector<decrypted_scalar_t>;
|
||||
using decrypted_tensor_2_t = std::vector<decrypted_tensor_1_t>;
|
||||
using decrypted_tensor_3_t = std::vector<decrypted_tensor_2_t>;
|
||||
|
||||
template <size_t Rank> using encrypted_tensor_t = MemRefDescriptor<Rank>;
|
||||
using encrypted_scalar_t = uint64_t *;
|
||||
using encrypted_scalars_t = uint64_t *;
|
||||
|
||||
// Element types for `TensorData`
|
||||
enum class ElementType { u64, i64, u32, i32, u16, i16, u8, i8 };
|
||||
|
||||
// Returns the width in bits of an integer whose width is a power of
|
||||
// two that can hold values with at most `width` bits
|
||||
static inline constexpr size_t getStorageWidth(size_t width) {
|
||||
if (width > 64)
|
||||
assert(false && "Unsupported scalar width");
|
||||
|
||||
if (width > 32) {
|
||||
return 64;
|
||||
} else if (width > 16) {
|
||||
return 32;
|
||||
} else if (width > 8) {
|
||||
return 16;
|
||||
} else {
|
||||
return 8;
|
||||
}
|
||||
}
|
||||
|
||||
// Translates `sign` and `width` into an `ElementType`.
|
||||
static inline ElementType getElementTypeFromWidthAndSign(size_t width,
|
||||
bool sign) {
|
||||
switch (getStorageWidth(width)) {
|
||||
case 64:
|
||||
return (sign) ? ElementType::i64 : ElementType::u64;
|
||||
case 32:
|
||||
return (sign) ? ElementType::i32 : ElementType::u32;
|
||||
case 16:
|
||||
return (sign) ? ElementType::i16 : ElementType::u16;
|
||||
case 8:
|
||||
return (sign) ? ElementType::i8 : ElementType::u8;
|
||||
default:
|
||||
assert(false && "Unsupported scalar width");
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Returns the number of bits for an element type
|
||||
static constexpr size_t getElementTypeWidth(ElementType t) {
|
||||
switch (t) {
|
||||
case ElementType::u64:
|
||||
case ElementType::i64:
|
||||
return 64;
|
||||
case ElementType::u32:
|
||||
case ElementType::i32:
|
||||
return 32;
|
||||
case ElementType::u16:
|
||||
case ElementType::i16:
|
||||
return 16;
|
||||
case ElementType::u8:
|
||||
case ElementType::i8:
|
||||
return 8;
|
||||
}
|
||||
|
||||
// Cannot happen
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Returns `true` if the element type `t` designates a signed type,
|
||||
// otherwise `false`.
|
||||
static constexpr size_t getElementTypeSignedness(ElementType t) {
|
||||
switch (t) {
|
||||
case ElementType::u64:
|
||||
case ElementType::u32:
|
||||
case ElementType::u16:
|
||||
case ElementType::u8:
|
||||
return false;
|
||||
case ElementType::i64:
|
||||
case ElementType::i32:
|
||||
case ElementType::i16:
|
||||
case ElementType::i8:
|
||||
return true;
|
||||
}
|
||||
|
||||
// Cannot happen
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns `true` iff the element type `t` designates the smallest
|
||||
// unsigned / signed (depending on `sign`) integer type that can hold
|
||||
// values of up to `width` bits, otherwise false.
|
||||
static inline bool checkElementTypeForWidthAndSign(ElementType t, size_t width,
|
||||
bool sign) {
|
||||
return getElementTypeFromWidthAndSign(getStorageWidth(width), sign) == t;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Constants for the element types used for tensors representing
|
||||
// encrypted data and data after decryption
|
||||
constexpr ElementType EncryptedScalarElementType = ElementType::u64;
|
||||
constexpr size_t EncryptedScalarElementWidth =
|
||||
getElementTypeWidth(EncryptedScalarElementType);
|
||||
|
||||
using EncryptedScalarElement = uint64_t;
|
||||
|
||||
namespace detail {
|
||||
namespace TensorData {
|
||||
|
||||
// Union used to store the pointer to the actual data of an instance
|
||||
// of `TensorData`. Values are stored contiguously in memory in a
|
||||
// `std::vector` whose element type corresponds to the element type of
|
||||
// the tensor.
|
||||
union value_vector_union {
|
||||
std::vector<uint64_t> *u64;
|
||||
std::vector<int64_t> *i64;
|
||||
std::vector<uint32_t> *u32;
|
||||
std::vector<int32_t> *i32;
|
||||
std::vector<uint16_t> *u16;
|
||||
std::vector<int16_t> *i16;
|
||||
std::vector<uint8_t> *u8;
|
||||
std::vector<int8_t> *i8;
|
||||
};
|
||||
|
||||
// Function templates that would go into the class `TensorData`, but
|
||||
// which need to declared in namespace scope, since specializations of
|
||||
// templates on the return type cannot be done for member functions as
|
||||
// per the C++ standard
|
||||
template <typename T> T begin(union value_vector_union &vec);
|
||||
template <typename T> T end(union value_vector_union &vec);
|
||||
template <typename T> T cbegin(union value_vector_union &vec);
|
||||
template <typename T> T cend(union value_vector_union &vec);
|
||||
template <typename T> T getElements(union value_vector_union &vec);
|
||||
template <typename T> T getConstElements(const union value_vector_union &vec);
|
||||
|
||||
template <typename T>
|
||||
T getElementValue(union value_vector_union &vec, size_t idx,
|
||||
ElementType elementType);
|
||||
template <typename T>
|
||||
T &getElementReference(union value_vector_union &vec, size_t idx,
|
||||
ElementType elementType);
|
||||
template <typename T>
|
||||
T *getElementPointer(union value_vector_union &vec, size_t idx,
|
||||
ElementType elementType);
|
||||
|
||||
// Specializations for the above templates
|
||||
#define TENSORDATA_SPECIALIZE_FOR_ITERATOR(ELTY, SUFFIX) \
|
||||
template <> \
|
||||
inline std::vector<ELTY>::iterator begin(union value_vector_union &vec) { \
|
||||
return vec.SUFFIX->begin(); \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
inline std::vector<ELTY>::iterator end(union value_vector_union &vec) { \
|
||||
return vec.SUFFIX->end(); \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
inline std::vector<ELTY>::const_iterator cbegin( \
|
||||
union value_vector_union &vec) { \
|
||||
return vec.SUFFIX->cbegin(); \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
inline std::vector<ELTY>::const_iterator cend( \
|
||||
union value_vector_union &vec) { \
|
||||
return vec.SUFFIX->cend(); \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
inline std::vector<ELTY> &getElements(union value_vector_union &vec) { \
|
||||
return *vec.SUFFIX; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
inline const std::vector<ELTY> &getConstElements( \
|
||||
const union value_vector_union &vec) { \
|
||||
return *vec.SUFFIX; \
|
||||
}
|
||||
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint64_t, u64)
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int64_t, i64)
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint32_t, u32)
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int32_t, i32)
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint16_t, u16)
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int16_t, i16)
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint8_t, u8)
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int8_t, i8)
|
||||
|
||||
#define TENSORDATA_SPECIALIZE_VALUE_GETTER(ELTY, SUFFIX) \
|
||||
template <> \
|
||||
inline ELTY getElementValue(union value_vector_union &vec, size_t idx, \
|
||||
ElementType elementType) { \
|
||||
assert(elementType == ElementType::SUFFIX); \
|
||||
return (*vec.SUFFIX)[idx]; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
inline ELTY &getElementReference(union value_vector_union &vec, size_t idx, \
|
||||
ElementType elementType) { \
|
||||
assert(elementType == ElementType::SUFFIX); \
|
||||
return (*vec.SUFFIX)[idx]; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
inline ELTY *getElementPointer(union value_vector_union &vec, size_t idx, \
|
||||
ElementType elementType) { \
|
||||
assert(elementType == ElementType::SUFFIX); \
|
||||
return &(*vec.SUFFIX)[idx]; \
|
||||
}
|
||||
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint64_t, u64)
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(int64_t, i64)
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint32_t, u32)
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(int32_t, i32)
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint16_t, u16)
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(int16_t, i16)
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint8_t, u8)
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(int8_t, i8)
|
||||
|
||||
} // namespace TensorData
|
||||
} // namespace detail
|
||||
|
||||
// Representation of a tensor with an arbitrary number of dimensions
|
||||
class TensorData {
|
||||
protected:
|
||||
detail::TensorData::value_vector_union values;
|
||||
ElementType elementType;
|
||||
std::vector<size_t> dimensions;
|
||||
size_t elementWidth;
|
||||
|
||||
/* Multi-dimensional, uninitialized, but preallocated tensor */
|
||||
void initPreallocated(llvm::ArrayRef<size_t> dimensions,
|
||||
ElementType elementType, size_t elementWidth,
|
||||
bool sign) {
|
||||
assert(checkElementTypeForWidthAndSign(elementType, elementWidth, sign) &&
|
||||
"Incoherent parameters for element type, width and sign");
|
||||
|
||||
assert(dimensions.size() != 0);
|
||||
|
||||
size_t n = getNumElements(dimensions);
|
||||
|
||||
switch (elementType) {
|
||||
case ElementType::u64:
|
||||
this->values.u64 = new std::vector<uint64_t>(n);
|
||||
break;
|
||||
case ElementType::i64:
|
||||
this->values.i64 = new std::vector<int64_t>(n);
|
||||
break;
|
||||
case ElementType::u32:
|
||||
this->values.u32 = new std::vector<uint32_t>(n);
|
||||
break;
|
||||
case ElementType::i32:
|
||||
this->values.i32 = new std::vector<int32_t>(n);
|
||||
break;
|
||||
case ElementType::u16:
|
||||
this->values.u16 = new std::vector<uint16_t>(n);
|
||||
break;
|
||||
case ElementType::i16:
|
||||
this->values.i16 = new std::vector<int16_t>(n);
|
||||
break;
|
||||
case ElementType::u8:
|
||||
this->values.u8 = new std::vector<uint8_t>(n);
|
||||
break;
|
||||
case ElementType::i8:
|
||||
this->values.i8 = new std::vector<int8_t>(n);
|
||||
break;
|
||||
}
|
||||
|
||||
this->dimensions.resize(dimensions.size());
|
||||
this->elementWidth = elementWidth;
|
||||
this->elementType = elementType;
|
||||
std::copy(dimensions.begin(), dimensions.end(), this->dimensions.begin());
|
||||
}
|
||||
|
||||
// Creates a vector<size_t> from an ArrayRef<T>
|
||||
template <typename T>
|
||||
static std::vector<size_t> toDimSpec(llvm::ArrayRef<T> dims) {
|
||||
return std::vector<size_t>(dims.begin(), dims.end());
|
||||
}
|
||||
|
||||
public:
|
||||
// Returns the total number of elements of a tensor with the
|
||||
// specified dimensions
|
||||
template <typename T> static size_t getNumElements(T dimensions) {
|
||||
size_t n = 1;
|
||||
for (auto dim : dimensions)
|
||||
n *= dim;
|
||||
|
||||
return n;
|
||||
}
|
||||
|
||||
// Move constructor. Leaves `that` uninitialized.
|
||||
TensorData(TensorData &&that)
|
||||
: elementType(that.elementType), dimensions(std::move(that.dimensions)),
|
||||
elementWidth(that.elementWidth) {
|
||||
switch (that.elementType) {
|
||||
case ElementType::u64:
|
||||
this->values.u64 = that.values.u64;
|
||||
that.values.u64 = nullptr;
|
||||
break;
|
||||
case ElementType::i64:
|
||||
this->values.i64 = that.values.i64;
|
||||
that.values.i64 = nullptr;
|
||||
break;
|
||||
case ElementType::u32:
|
||||
this->values.u32 = that.values.u32;
|
||||
that.values.u32 = nullptr;
|
||||
break;
|
||||
case ElementType::i32:
|
||||
this->values.i32 = that.values.i32;
|
||||
that.values.i32 = nullptr;
|
||||
break;
|
||||
case ElementType::u16:
|
||||
this->values.u16 = that.values.u16;
|
||||
that.values.u16 = nullptr;
|
||||
break;
|
||||
case ElementType::i16:
|
||||
this->values.i16 = that.values.i16;
|
||||
that.values.i16 = nullptr;
|
||||
break;
|
||||
case ElementType::u8:
|
||||
this->values.u8 = that.values.u8;
|
||||
that.values.u8 = nullptr;
|
||||
break;
|
||||
case ElementType::i8:
|
||||
this->values.i8 = that.values.i8;
|
||||
that.values.i8 = nullptr;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Constructor to build a multi-dimensional tensor with the
|
||||
// corresponding element type. All elements are initialized with the
|
||||
// default value of `0`.
|
||||
TensorData(llvm::ArrayRef<size_t> dimensions, ElementType elementType,
|
||||
size_t elementWidth) {
|
||||
initPreallocated(dimensions, elementType, elementWidth,
|
||||
getElementTypeSignedness(elementType));
|
||||
}
|
||||
|
||||
TensorData(llvm::ArrayRef<int64_t> dimensions, ElementType elementType,
|
||||
size_t elementWidth)
|
||||
: TensorData(toDimSpec(dimensions), elementType, elementWidth) {}
|
||||
|
||||
// Constructor to build a multi-dimensional tensor with the element
|
||||
// type corresponding to `elementWidth` and `sign`. All elements are
|
||||
// initialized with the default value of `0`.
|
||||
TensorData(llvm::ArrayRef<size_t> dimensions, size_t elementWidth, bool sign)
|
||||
: TensorData(dimensions,
|
||||
getElementTypeFromWidthAndSign(elementWidth, sign),
|
||||
elementWidth) {}
|
||||
|
||||
TensorData(llvm::ArrayRef<int64_t> dimensions, size_t elementWidth, bool sign)
|
||||
: TensorData(toDimSpec(dimensions), elementWidth, sign) {}
|
||||
|
||||
#define DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(ELTY, SUFFIX) \
|
||||
/* Multi-dimensional, initialized tensor, values copied from */ \
|
||||
/* `values` */ \
|
||||
TensorData(llvm::ArrayRef<ELTY> values, llvm::ArrayRef<size_t> dimensions, \
|
||||
size_t elementWidth) \
|
||||
: dimensions(dimensions.begin(), dimensions.end()) { \
|
||||
assert(checkElementTypeForWidthAndSign(ElementType::SUFFIX, elementWidth, \
|
||||
std::is_signed<ELTY>()) && \
|
||||
"wrong element type for width"); \
|
||||
assert(dimensions.size() != 0); \
|
||||
size_t n = getNumElements(dimensions); \
|
||||
this->values.SUFFIX = new std::vector<ELTY>(n); \
|
||||
this->elementType = ElementType::SUFFIX; \
|
||||
this->bulkAssign(values); \
|
||||
} \
|
||||
\
|
||||
/* One-dimensional, initialized tensor. Values are copied from */ \
|
||||
/* `values` */ \
|
||||
TensorData(llvm::ArrayRef<ELTY> values, size_t width) \
|
||||
: TensorData(values, llvm::SmallVector<size_t, 1>{values.size()}, \
|
||||
width) {}
|
||||
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint64_t, u64)
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int64_t, i64)
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint32_t, u32)
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int32_t, i32)
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint16_t, u16)
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int16_t, i16)
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint8_t, u8)
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int8_t, i8)
|
||||
|
||||
~TensorData() {
|
||||
switch (this->elementType) {
|
||||
case ElementType::u64:
|
||||
delete values.u64;
|
||||
break;
|
||||
case ElementType::i64:
|
||||
delete values.i64;
|
||||
break;
|
||||
case ElementType::u32:
|
||||
delete values.u32;
|
||||
break;
|
||||
case ElementType::i32:
|
||||
delete values.i32;
|
||||
break;
|
||||
case ElementType::u16:
|
||||
delete values.u16;
|
||||
break;
|
||||
case ElementType::i16:
|
||||
delete values.i16;
|
||||
break;
|
||||
case ElementType::u8:
|
||||
delete values.u8;
|
||||
break;
|
||||
case ElementType::i8:
|
||||
delete values.i8;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the total number of elements of the tensor
|
||||
size_t length() const { return getNumElements(this->dimensions); }
|
||||
|
||||
// Returns a vector with the size for each dimension of the tensor
|
||||
const std::vector<size_t> &getDimensions() const { return this->dimensions; }
|
||||
|
||||
template <typename T> const std::vector<T> getDimensionsAs() const {
|
||||
return std::vector<T>(this->dimensions.begin(), this->dimensions.end());
|
||||
}
|
||||
|
||||
// Returns the number of dimensions
|
||||
size_t getRank() const { return this->dimensions.size(); }
|
||||
|
||||
// Multi-dimensional access to a tensor element
|
||||
template <typename T> T &operator[](llvm::ArrayRef<int64_t> index) {
|
||||
// Number of dimensions must match
|
||||
assert(index.size() == dimensions.size());
|
||||
|
||||
int64_t offset = 0;
|
||||
int64_t multiplier = 1;
|
||||
for (int64_t i = index.size() - 1; i > 0; i--) {
|
||||
offset += index[i] * multiplier;
|
||||
multiplier *= this->dimensions[i];
|
||||
}
|
||||
|
||||
return detail::TensorData::getElementReference<T>(values, offset,
|
||||
elementType);
|
||||
}
|
||||
|
||||
// Iterator pointing to the first element of a flat representation
|
||||
// of the tensor.
|
||||
template <typename T> typename std::vector<T>::iterator begin() {
|
||||
return detail::TensorData::begin<typename std::vector<T>::iterator>(values);
|
||||
}
|
||||
|
||||
// Iterator pointing past the last element of a flat representation
|
||||
// of the tensor.
|
||||
template <typename T> typename std::vector<T>::iterator end() {
|
||||
return detail::TensorData::end<typename std::vector<T>::iterator>(values);
|
||||
}
|
||||
|
||||
// Const iterator pointing to the first element of a flat
|
||||
// representation of the tensor.
|
||||
template <typename T> typename std::vector<T>::iterator cbegin() {
|
||||
return detail::TensorData::cbegin<typename std::vector<T>::iterator>(
|
||||
values);
|
||||
}
|
||||
|
||||
// Const iterator pointing past the last element of a flat
|
||||
// representation of the tensor.
|
||||
template <typename T> typename std::vector<T>::iterator cend() {
|
||||
return detail::TensorData::cend<typename std::vector<T>::iterator>(values);
|
||||
}
|
||||
|
||||
// Flat representation of the const tensor
|
||||
template <typename T> const std::vector<T> &getElements() const {
|
||||
return detail::TensorData::getConstElements<const std::vector<T> &>(values);
|
||||
}
|
||||
|
||||
// Flat representation of the tensor
|
||||
template <typename T> const std::vector<T> &getElements() {
|
||||
return detail::TensorData::getElements<std::vector<T> &>(values);
|
||||
}
|
||||
|
||||
// Returns the `index`-th value of a flat representation of the tensor
|
||||
template <typename T> T getElementValue(size_t index) {
|
||||
return detail::TensorData::getElementValue<T>(values, index, elementType);
|
||||
}
|
||||
|
||||
// Returns a reference to the `index`-th value of a flat
|
||||
// representation of the tensor
|
||||
template <typename T> T &getElementReference(size_t index) {
|
||||
return detail::TensorData::getElementReference<T>(values, index,
|
||||
elementType);
|
||||
}
|
||||
|
||||
// Returns a pointer to the `index`-th value of a flat
|
||||
// representation of the tensor
|
||||
template <typename T> T *getElementPointer(size_t index) {
|
||||
return detail::TensorData::getElementPointer<T>(values, index, elementType);
|
||||
}
|
||||
|
||||
// Returns a pointer to the `index`-th value of a flat
|
||||
// representation of the tensor (const version)
|
||||
template <typename T> const T *getElementPointer(size_t index) const {
|
||||
return detail::TensorData::getElementPointer<T>(values, index, elementType);
|
||||
}
|
||||
|
||||
// Returns a void pointer to the `index`-th value of a flat
|
||||
// representation of the tensor
|
||||
void *getOpaqueElementPointer(size_t index) {
|
||||
switch (this->elementType) {
|
||||
case ElementType::u64:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<uint64_t>(values, index,
|
||||
elementType));
|
||||
case ElementType::i64:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<int64_t>(values, index,
|
||||
elementType));
|
||||
case ElementType::u32:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<uint32_t>(values, index,
|
||||
elementType));
|
||||
case ElementType::i32:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<int32_t>(values, index,
|
||||
elementType));
|
||||
case ElementType::u16:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<uint16_t>(values, index,
|
||||
elementType));
|
||||
case ElementType::i16:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<int16_t>(values, index,
|
||||
elementType));
|
||||
case ElementType::u8:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<uint8_t>(values, index,
|
||||
elementType));
|
||||
case ElementType::i8:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<int8_t>(values, index,
|
||||
elementType));
|
||||
}
|
||||
|
||||
assert(false && "Unknown element type");
|
||||
}
|
||||
|
||||
// Returns the element type of the tensor
|
||||
ElementType getElementType() const { return this->elementType; }
|
||||
|
||||
// Returns the actual width in bits of a data element (i.e., the
|
||||
// width specified upon construction and not the storage width of an
|
||||
// element)
|
||||
size_t getElementWidth() const { return this->elementWidth; }
|
||||
|
||||
// Returns the size of a tensor element in bytes (i.e., the storage width in
|
||||
// bytes)
|
||||
size_t getElementSize() const {
|
||||
switch (this->elementType) {
|
||||
case ElementType::u64:
|
||||
case ElementType::i64:
|
||||
return 8;
|
||||
case ElementType::u32:
|
||||
case ElementType::i32:
|
||||
return 4;
|
||||
case ElementType::u16:
|
||||
case ElementType::i16:
|
||||
return 2;
|
||||
case ElementType::u8:
|
||||
case ElementType::i8:
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns `true` if elements are signed, otherwise `false`
|
||||
bool getElementSignedness() const {
|
||||
switch (this->elementType) {
|
||||
case ElementType::u64:
|
||||
case ElementType::u32:
|
||||
case ElementType::u16:
|
||||
case ElementType::u8:
|
||||
return false;
|
||||
case ElementType::i64:
|
||||
case ElementType::i32:
|
||||
case ElementType::i16:
|
||||
case ElementType::i8:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the total number of elements of the tensor
|
||||
size_t getNumElements() const { return getNumElements(this->dimensions); }
|
||||
|
||||
// Copy all elements from `values` to the tensor. Note that this
|
||||
// does not append values to the tensor, but overwrites existing
|
||||
// values.
|
||||
template <typename T> void bulkAssign(llvm::ArrayRef<T> values) {
|
||||
assert(values.size() <= this->getNumElements());
|
||||
|
||||
switch (this->elementType) {
|
||||
case ElementType::u64:
|
||||
std::copy(values.begin(), values.end(), this->values.u64->begin());
|
||||
break;
|
||||
case ElementType::i64:
|
||||
std::copy(values.begin(), values.end(), this->values.i64->begin());
|
||||
break;
|
||||
case ElementType::u32:
|
||||
std::copy(values.begin(), values.end(), this->values.u32->begin());
|
||||
break;
|
||||
case ElementType::i32:
|
||||
std::copy(values.begin(), values.end(), this->values.i32->begin());
|
||||
break;
|
||||
case ElementType::u16:
|
||||
std::copy(values.begin(), values.end(), this->values.u16->begin());
|
||||
break;
|
||||
case ElementType::i16:
|
||||
std::copy(values.begin(), values.end(), this->values.i16->begin());
|
||||
break;
|
||||
case ElementType::u8:
|
||||
std::copy(values.begin(), values.end(), this->values.u8->begin());
|
||||
break;
|
||||
case ElementType::i8:
|
||||
std::copy(values.begin(), values.end(), this->values.i8->begin());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Copies all elements of a flat representation of the tensor to the
|
||||
// positions starting with the iterator `start`.
|
||||
template <typename IT> void copy(IT start) const {
|
||||
switch (this->elementType) {
|
||||
case ElementType::u64:
|
||||
std::copy(this->values.u64->cbegin(), this->values.u64->cend(), start);
|
||||
break;
|
||||
case ElementType::i64:
|
||||
std::copy(this->values.i64->cbegin(), this->values.i64->cend(), start);
|
||||
break;
|
||||
case ElementType::u32:
|
||||
std::copy(this->values.u32->cbegin(), this->values.u32->cend(), start);
|
||||
break;
|
||||
case ElementType::i32:
|
||||
std::copy(this->values.i32->cbegin(), this->values.i32->cend(), start);
|
||||
break;
|
||||
case ElementType::u16:
|
||||
std::copy(this->values.u16->cbegin(), this->values.u16->cend(), start);
|
||||
break;
|
||||
case ElementType::i16:
|
||||
std::copy(this->values.i16->cbegin(), this->values.i16->cend(), start);
|
||||
break;
|
||||
case ElementType::u8:
|
||||
std::copy(this->values.u8->cbegin(), this->values.u8->cend(), start);
|
||||
break;
|
||||
case ElementType::i8:
|
||||
std::copy(this->values.i8->cbegin(), this->values.i8->cend(), start);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a flat representation of the tensor with elements
|
||||
// converted to the type `T`
|
||||
template <typename T> std::vector<T> asFlatVector() const {
|
||||
std::vector<T> ret(getNumElements());
|
||||
this->copy(ret.begin());
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Returns a void pointer to the first element of a flat
|
||||
// representation of the tensor
|
||||
void *getValuesAsOpaquePointer() {
|
||||
switch (this->elementType) {
|
||||
case ElementType::u64:
|
||||
return static_cast<void *>(values.u64->data());
|
||||
case ElementType::i64:
|
||||
return static_cast<void *>(values.i64->data());
|
||||
case ElementType::u32:
|
||||
return static_cast<void *>(values.u32->data());
|
||||
case ElementType::i32:
|
||||
return static_cast<void *>(values.i32->data());
|
||||
case ElementType::u16:
|
||||
return static_cast<void *>(values.u16->data());
|
||||
case ElementType::i16:
|
||||
return static_cast<void *>(values.i16->data());
|
||||
case ElementType::u8:
|
||||
return static_cast<void *>(values.u8->data());
|
||||
case ElementType::i8:
|
||||
return static_cast<void *>(values.i8->data());
|
||||
}
|
||||
|
||||
assert(false && "Unhandled element type");
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
namespace ScalarData {
|
||||
// Union representing a single scalar value
|
||||
union scalar_union {
|
||||
uint64_t u64;
|
||||
int64_t i64;
|
||||
uint32_t u32;
|
||||
int32_t i32;
|
||||
uint16_t u16;
|
||||
int16_t i16;
|
||||
uint8_t u8;
|
||||
int8_t i8;
|
||||
};
|
||||
|
||||
// Template + specializations that should be in ScalarData, but which need to be
|
||||
// in namespace scope
|
||||
template <typename T> T getValue(const union scalar_union &u, ElementType type);
|
||||
|
||||
#define SCALARDATA_SPECIALIZE_VALUE_GETTER(ELTY, SUFFIX) \
|
||||
template <> \
|
||||
inline ELTY getValue(const union scalar_union &u, ElementType type) { \
|
||||
assert(type == ElementType::SUFFIX); \
|
||||
return u.SUFFIX; \
|
||||
}
|
||||
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(uint64_t, u64)
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(int64_t, i64)
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(uint32_t, u32)
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(int32_t, i32)
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(uint16_t, u16)
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(int16_t, i16)
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(uint8_t, u8)
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(int8_t, i8)
|
||||
|
||||
} // namespace ScalarData
|
||||
} // namespace detail
|
||||
|
||||
// Class representing a single scalar value
|
||||
class ScalarData {
|
||||
public:
|
||||
ScalarData(const ScalarData &s)
|
||||
: type(s.type), value(s.value), width(s.width) {}
|
||||
|
||||
// Construction with a specific type and an actual width, but with a value
|
||||
// provided in a generic `uint64_t`
|
||||
ScalarData(uint64_t value, ElementType type, size_t width)
|
||||
: type(type), width(width) {
|
||||
assert(width <= getElementTypeWidth(type));
|
||||
|
||||
switch (type) {
|
||||
case ElementType::u64:
|
||||
this->value.u64 = value;
|
||||
break;
|
||||
case ElementType::i64:
|
||||
this->value.i64 = value;
|
||||
break;
|
||||
case ElementType::u32:
|
||||
this->value.u32 = value;
|
||||
break;
|
||||
case ElementType::i32:
|
||||
this->value.i32 = value;
|
||||
break;
|
||||
case ElementType::u16:
|
||||
this->value.u16 = value;
|
||||
break;
|
||||
case ElementType::i16:
|
||||
this->value.i16 = value;
|
||||
break;
|
||||
case ElementType::u8:
|
||||
this->value.u8 = value;
|
||||
break;
|
||||
case ElementType::i8:
|
||||
this->value.i8 = value;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Construction with a specific type determined by `sign` and
|
||||
// `width`, but value provided in a generic `uint64_t`
|
||||
ScalarData(uint64_t value, bool sign, size_t width)
|
||||
: ScalarData(value, getElementTypeFromWidthAndSign(width, sign), width) {}
|
||||
|
||||
#define DEF_SCALAR_DATA_CONSTRUCTOR(ELTY, SUFFIX) \
|
||||
ScalarData(ELTY value) \
|
||||
: type(ElementType::SUFFIX), \
|
||||
width(getElementTypeWidth(ElementType::SUFFIX)) { \
|
||||
this->value.SUFFIX = value; \
|
||||
}
|
||||
|
||||
// Construction from specific value type
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(uint64_t, u64)
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(int64_t, i64)
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(uint32_t, u32)
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(int32_t, i32)
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(uint16_t, u16)
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(int16_t, i16)
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(uint8_t, u8)
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(int8_t, i8)
|
||||
|
||||
template <typename T> T getValue() const {
|
||||
return detail::ScalarData::getValue<T>(value, type);
|
||||
}
|
||||
|
||||
// Retrieves the value as a generic `uint64_t`
|
||||
uint64_t getValueAsU64() const {
|
||||
size_t width = getElementTypeWidth(type);
|
||||
uint64_t mask = ((uint64_t)1 << width) - 1;
|
||||
uint64_t val = value.u64 & mask;
|
||||
return val;
|
||||
}
|
||||
|
||||
ElementType getType() const { return type; }
|
||||
size_t getWidth() const { return width; }
|
||||
|
||||
protected:
|
||||
ElementType type;
|
||||
union detail::ScalarData::scalar_union value;
|
||||
size_t width;
|
||||
};
|
||||
|
||||
// Variant for TensorData and ScalarData
|
||||
class ScalarOrTensorData {
|
||||
protected:
|
||||
std::unique_ptr<ScalarData> scalar;
|
||||
std::unique_ptr<TensorData> tensor;
|
||||
|
||||
public:
|
||||
ScalarOrTensorData(ScalarOrTensorData &&td)
|
||||
: scalar(std::move(td.scalar)), tensor(std::move(td.tensor)) {}
|
||||
|
||||
ScalarOrTensorData(TensorData &&td)
|
||||
: scalar(nullptr), tensor(std::make_unique<TensorData>(std::move(td))) {}
|
||||
|
||||
ScalarOrTensorData(const ScalarData &s)
|
||||
: scalar(std::make_unique<ScalarData>(s)), tensor(nullptr) {}
|
||||
|
||||
bool isTensor() const { return tensor != nullptr; }
|
||||
bool isScalar() const { return scalar != nullptr; }
|
||||
|
||||
ScalarData &getScalar() {
|
||||
assert(scalar != nullptr &&
|
||||
"Attempt to get a scalar value from variant that is a tensor");
|
||||
return *scalar;
|
||||
}
|
||||
|
||||
const ScalarData &getScalar() const {
|
||||
assert(scalar != nullptr &&
|
||||
"Attempt to get a scalar value from variant that is a tensor");
|
||||
return *scalar;
|
||||
}
|
||||
|
||||
TensorData &getTensor() {
|
||||
assert(tensor != nullptr &&
|
||||
"Attempt to get a tensor value from variant that is a scalar");
|
||||
return *tensor;
|
||||
}
|
||||
|
||||
const TensorData &getTensor() const {
|
||||
assert(tensor != nullptr &&
|
||||
"Attempt to get a tensor value from variant that is a scalar");
|
||||
return *tensor;
|
||||
}
|
||||
};
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,19 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_COMMON_BITS_SIZE_H
|
||||
#define CONCRETELANG_COMMON_BITS_SIZE_H
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
namespace concretelang {
|
||||
namespace common {
|
||||
|
||||
size_t bitWidthAsWord(size_t exactBitWidth);
|
||||
|
||||
}
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,43 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
#ifndef CONCRETELANG_COMMON_ERROR_H
|
||||
#define CONCRETELANG_COMMON_ERROR_H
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace concretelang {
|
||||
namespace error {
|
||||
|
||||
class StringError {
|
||||
public:
|
||||
StringError(std::string mesg) : mesg(mesg){};
|
||||
|
||||
std::string mesg;
|
||||
|
||||
StringError &operator<<(const std::string &v) {
|
||||
mesg += v;
|
||||
return *this;
|
||||
}
|
||||
|
||||
StringError &operator<<(const char *v) {
|
||||
mesg += std::string(v);
|
||||
return *this;
|
||||
}
|
||||
|
||||
StringError &operator<<(char *v) {
|
||||
mesg += std::string(v);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T> inline StringError &operator<<(const T v) {
|
||||
mesg += std::to_string(v);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace error
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,4 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion)
|
||||
add_public_tablegen_target(ConcretelangConversionPassIncGen)
|
||||
add_dependencies(mlir-headers ConcretelangConversionPassIncGen)
|
||||
@@ -0,0 +1,19 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef ZAMALANG_CONVERSION_CONCRETETOCAPI_PASS_H_
|
||||
#define ZAMALANG_CONVERSION_CONCRETETOCAPI_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `Concrete` dialect to CAPI calls.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertConcreteToCAPIPass(bool gpu);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,19 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_EXTRACTSDFGOPS_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_EXTRACTSDFGOPS_PASS_H_
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<mlir::func::FuncOp>>
|
||||
createExtractSDFGOpsPass(bool unroll);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,21 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_FHETENSOROPSTOLINALG_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_FHETENSOROPSTOLINALG_PASS_H_
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `FHE` tensor operators to linal.generic
|
||||
/// operators.
|
||||
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
|
||||
createConvertFHETensorOpsToLinalg();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,48 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_FHETOTFHECRT_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_FHETOTFHECRT_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include <cstddef>
|
||||
#include <list>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
struct CrtLoweringParameters {
|
||||
mlir::SmallVector<int64_t> mods;
|
||||
mlir::SmallVector<int64_t> bits;
|
||||
size_t nMods;
|
||||
size_t modsProd;
|
||||
size_t bitsTotal;
|
||||
size_t singleLutSize;
|
||||
|
||||
CrtLoweringParameters(mlir::SmallVector<int64_t> mods) : mods(mods) {
|
||||
nMods = mods.size();
|
||||
modsProd = 1;
|
||||
bitsTotal = 0;
|
||||
bits.clear();
|
||||
for (auto &mod : mods) {
|
||||
modsProd *= mod;
|
||||
uint64_t nbits =
|
||||
static_cast<uint64_t>(ceil(log2(static_cast<double>(mod))));
|
||||
bits.push_back(nbits);
|
||||
bitsTotal += nbits;
|
||||
}
|
||||
singleLutSize = size_t(1) << bitsTotal;
|
||||
}
|
||||
};
|
||||
|
||||
/// Create a pass to convert `FHE` dialect to `TFHE` dialect with the crt
|
||||
// strategy.
|
||||
std::unique_ptr<OperationPass<mlir::ModuleOp>>
|
||||
createConvertFHEToTFHECrtPass(CrtLoweringParameters lowering);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,28 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_FHETOTFHESCALAR_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_FHETOTFHESCALAR_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include <list>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
struct ScalarLoweringParameters {
|
||||
size_t polynomialSize;
|
||||
ScalarLoweringParameters(size_t polySize) : polynomialSize(polySize){};
|
||||
};
|
||||
|
||||
/// Create a pass to convert `FHE` dialect to `TFHE` dialect with the scalar
|
||||
// strategy.
|
||||
std::unique_ptr<OperationPass<mlir::ModuleOp>>
|
||||
createConvertFHEToTFHEScalarPass(ScalarLoweringParameters loweringParameters);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,18 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef ZAMALANG_CONVERSION_LINALGEXTRAS_PASS_H_
|
||||
#define ZAMALANG_CONVERSION_LINALGEXTRAS_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createLinalgGenericOpWithTensorsToLoopsPass(bool parallelizeLoops);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,20 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_MLIRLOWERABLEDIALECTSTOLLVM_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_MLIRLOWERABLEDIALECTSTOLLVM_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
template <typename T> class OperationPass;
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert MLIR lowerable dialects to LLVM.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertMLIRLowerableDialectsToLLVMPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,36 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_TRANSFORMS_PASSES_H
|
||||
#define CONCRETELANG_TRANSFORMS_PASSES_H
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
#include "concretelang/Conversion/ConcreteToCAPI/Pass.h"
|
||||
#include "concretelang/Conversion/ExtractSDFGOps/Pass.h"
|
||||
#include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h"
|
||||
#include "concretelang/Conversion/FHEToTFHECrt/Pass.h"
|
||||
#include "concretelang/Conversion/FHEToTFHEScalar/Pass.h"
|
||||
#include "concretelang/Conversion/LinalgExtras/Passes.h"
|
||||
#include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
|
||||
#include "concretelang/Conversion/SDFGToStreamEmulator/Pass.h"
|
||||
#include "concretelang/Conversion/TFHEGlobalParametrization/Pass.h"
|
||||
#include "concretelang/Conversion/TFHEToConcrete/Pass.h"
|
||||
#include "concretelang/Conversion/TracingToCAPI/Pass.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingDialect.h"
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "concretelang/Conversion/Passes.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,86 @@
|
||||
#ifndef CONCRETELANG_CONVERSION_PASSES
|
||||
#define CONCRETELANG_CONVERSION_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def FHETensorOpsToLinalg : Pass<"fhe-tensor-ops-to-linalg", "::mlir::func::FuncOp"> {
|
||||
let summary = "Lowers tensor operations of FHE dialect to linalg.generic";
|
||||
let constructor = "mlir::concretelang::createConvertFHETensorOpsToLinalg()";
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect"];
|
||||
}
|
||||
|
||||
def FHEToTFHEScalar : Pass<"fhe-to-tfhe-scalar", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the FHE dialect to TFHE using the scalar strategy.";
|
||||
let description = [{ Lowers operations from the FHE dialect to Std + Math }];
|
||||
let constructor = "mlir::concretelang::createConvertFHEToTFHEScalarPass()";
|
||||
let options = [];
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect"];
|
||||
}
|
||||
|
||||
def FHEToTFHECrt : Pass<"fhe-to-tfhe-crt", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the FHE dialect to TFHE using the crt strategy.";
|
||||
let description = [{ Lowers operations from the FHE dialect to Std + Math }];
|
||||
let constructor = "mlir::concretelang::createConvertFHEToTFHECrtPass()";
|
||||
let options = [];
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect"];
|
||||
}
|
||||
|
||||
def TFHEGlobalParametrization : Pass<"tfhe-global-parametrization", "mlir::ModuleOp"> {
|
||||
let summary = "Inject global fhe parameters to the TFHE dialect";
|
||||
let constructor = "mlir::concretelang::createConvertTFHEToConcretePass()";
|
||||
let options = [];
|
||||
let dependentDialects = ["mlir::concretelang::TFHE::TFHEDialect"];
|
||||
}
|
||||
|
||||
def TFHEToConcrete : Pass<"tfhe-to-concrete", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the TFHE dialect to Concrete";
|
||||
let description = [{ Lowers operations from the TFHE dialect to Concrete }];
|
||||
let constructor = "mlir::concretelang::createConvertTFHEToConcretePass()";
|
||||
let options = [];
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::TFHE::TFHEDialect"];
|
||||
}
|
||||
|
||||
def LinalgGenericOpWithTensorsToLoops : Pass<"linalg-generic-op-with-tensors-to-loops", "mlir::ModuleOp"> {
|
||||
let summary = "Converts linalg.generic ops with tensor inputs / outputs to a loop nest";
|
||||
let description = [{ Converts linalg.generic ops with tensor inputs / outputs to a loop nest }];
|
||||
let constructor = "mlir::createLinalgGenericOpWithTensorsToLoopsPass()";
|
||||
let options = [];
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::scf::SCFDialect"];
|
||||
}
|
||||
|
||||
def ExtractSDFGOps : Pass<"extract-sdfg-ops", "::mlir::func::FuncOp"> {
|
||||
let summary = "Extracts SDFG ops and creates a static data flow graph";
|
||||
let description = [{ Extracts SDFG ops and creates a static data flow graph }];
|
||||
let constructor = "mlir::concretelang::createExtractSDFGOps()";
|
||||
let dependentDialects = ["mlir::concretelang::SDFG::SDFGDialect"];
|
||||
}
|
||||
|
||||
def ConcreteToCAPI : Pass<"concrete-to-capi", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the Concrete dialect to CAPI calls";
|
||||
let description = [{ Lowers operations from the Concrete dialect to CAPI calls }];
|
||||
let constructor = "mlir::concretelang::createConvertConcreteToCAPIPass()";
|
||||
let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect"];
|
||||
}
|
||||
|
||||
def TracingToCAPI : Pass<"tracing-to-capi", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the Tracing dialect to CAPI calls";
|
||||
let description = [{ Lowers operations from the Tracing dialect to CAPI calls }];
|
||||
let constructor = "mlir::concretelang::createConvertTracingToCAPIPass()";
|
||||
let dependentDialects = ["mlir::concretelang::Tracing::TracingDialect"];
|
||||
}
|
||||
|
||||
def SDFGToStreamEmulator : Pass<"sdfg-to-stream-emulator", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the SDFG dialect to Stream Emulator calls";
|
||||
let description = [{ Lowers operations from the SDFG dialect to Stream Emulator calls }];
|
||||
let constructor = "mlir::concretelang::createConvertSDFGToStreamEmulatorPass()";
|
||||
let dependentDialects = ["mlir::concretelang::SDFG::SDFGDialect"];
|
||||
}
|
||||
|
||||
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
|
||||
let constructor = "mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass()";
|
||||
let dependentDialects = ["mlir::func::FuncDialect", "mlir::arith::ArithmeticDialect", "mlir::scf::SCFDialect", "mlir::LLVM::LLVMDialect"];
|
||||
let options = [];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,19 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef ZAMALANG_CONVERSION_SDFGTOSTREAMEMULATOR_PASS_H_
|
||||
#define ZAMALANG_CONVERSION_SDFGTOSTREAMEMULATOR_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `SDFG` dialect to Stream Emulator calls.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertSDFGToStreamEmulatorPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,22 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_TFHEGLOBALPARAMETRIZATION_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_TFHEGLOBALPARAMETRIZATION_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to inject fhe parameters to the TFHE types and operators.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTFHEGlobalParametrizationPass(
|
||||
mlir::concretelang::V0FHEContext &fheContext);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,18 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_TFHETOCONCRETE_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_TFHETOCONCRETE_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `TFHE` dialect to `Concrete` dialect.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTFHEToConcretePass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,17 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
|
||||
mlir::OpBuilder &rewriter,
|
||||
llvm::StringRef funcName,
|
||||
mlir::FunctionType funcType);
|
||||
|
||||
/// \brief Returns the value of the context argument from the enclosing func
|
||||
///
|
||||
/// \param op initial operation to start the search from
|
||||
/// \return mlir::Value the context value
|
||||
mlir::Value getContextArgument(mlir::Operation *op);
|
||||
@@ -0,0 +1,18 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef ZAMALANG_CONVERSION_TRACINGTOCAPI_PASS_H_
|
||||
#define ZAMALANG_CONVERSION_TRACINGTOCAPI_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `Tracing` dialect to CAPI calls.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTracingToCAPIPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,29 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_UTILS_DIALECTS_SCF_H_
|
||||
#define CONCRETELANG_CONVERSION_UTILS_DIALECTS_SCF_H_
|
||||
|
||||
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
//
|
||||
// Specializations for ForOp
|
||||
//
|
||||
|
||||
// Specialization copying attributes omitted
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<scf::ForOp, false>::matchAndRewrite(
|
||||
scf::ForOp oldOp, mlir::OpConversionPattern<scf::ForOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,70 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_UTILS_DIALECTS_TENSOR_H_
|
||||
#define CONCRETELANG_CONVERSION_UTILS_DIALECTS_TENSOR_H_
|
||||
|
||||
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
//
|
||||
// Specializations for CollapseShapeOp
|
||||
//
|
||||
|
||||
// Specialization copying attributes not necessary, as the base
|
||||
// template works correctly
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<tensor::CollapseShapeOp, false>::
|
||||
matchAndRewrite(
|
||||
tensor::CollapseShapeOp oldOp,
|
||||
mlir::OpConversionPattern<tensor::CollapseShapeOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const;
|
||||
//
|
||||
// Specializations for FromElementsOp
|
||||
//
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<mlir::tensor::FromElementsOp, false>::
|
||||
matchAndRewrite(
|
||||
tensor::FromElementsOp oldOp,
|
||||
mlir::OpConversionPattern<mlir::tensor::FromElementsOp>::OpAdaptor
|
||||
adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
//
|
||||
// Specializations for ExpandShapeOp
|
||||
//
|
||||
|
||||
// Specialization copying attributes not necessary, as the base
|
||||
// template works correctly
|
||||
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<tensor::ExpandShapeOp, false>::
|
||||
matchAndRewrite(
|
||||
tensor::ExpandShapeOp oldOp,
|
||||
mlir::OpConversionPattern<tensor::ExpandShapeOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
//
|
||||
// Specializations for GenerateOp
|
||||
//
|
||||
|
||||
// Specialization NOT copying attributes omitted
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<tensor::GenerateOp, true>::matchAndRewrite(
|
||||
tensor::GenerateOp oldOp,
|
||||
mlir::OpConversionPattern<tensor::GenerateOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif // CONCRETELANG_CONVERSION_UTILS_DIALECTS_TENSOR_H_
|
||||
@@ -0,0 +1,66 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/Linalg/IR/Linalg.h>
|
||||
#include <mlir/IR/Operation.h>
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
template <typename TypeConverterType>
|
||||
struct FunctionConstantOpConversion
|
||||
: public mlir::OpRewritePattern<mlir::func::ConstantOp> {
|
||||
FunctionConstantOpConversion(mlir::MLIRContext *ctx,
|
||||
TypeConverterType &converter,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::func::ConstantOp>(ctx, benefit),
|
||||
converter(converter) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::func::ConstantOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto symTab = mlir::SymbolTable::getNearestSymbolTable(op);
|
||||
auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, op.getValue());
|
||||
assert(funcOp &&
|
||||
"Function symbol missing in symbol table for function constant op.");
|
||||
mlir::FunctionType funType = mlir::cast<mlir::func::FuncOp>(funcOp)
|
||||
.getFunctionType()
|
||||
.cast<mlir::FunctionType>();
|
||||
typename TypeConverterType::SignatureConversion result(
|
||||
funType.getNumInputs());
|
||||
mlir::SmallVector<mlir::Type, 1> newResults;
|
||||
if (failed(converter.convertSignatureArgs(funType.getInputs(), result)) ||
|
||||
failed(converter.convertTypes(funType.getResults(), newResults)))
|
||||
return mlir::failure();
|
||||
auto newType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), result.getConvertedTypes(), newResults);
|
||||
rewriter.updateRootInPlace(op, [&] { op.getResult().setType(newType); });
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
static bool isLegal(mlir::func::ConstantOp fun,
|
||||
TypeConverterType &converter) {
|
||||
auto symTab = mlir::SymbolTable::getNearestSymbolTable(fun);
|
||||
auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, fun.getValue());
|
||||
assert(funcOp &&
|
||||
"Function symbol missing in symbol table for function constant op.");
|
||||
mlir::FunctionType funType = mlir::cast<mlir::func::FuncOp>(funcOp)
|
||||
.getFunctionType()
|
||||
.cast<mlir::FunctionType>();
|
||||
typename TypeConverterType::SignatureConversion result(
|
||||
funType.getNumInputs());
|
||||
mlir::SmallVector<mlir::Type, 1> newResults;
|
||||
if (failed(converter.convertSignatureArgs(funType.getInputs(), result)) ||
|
||||
failed(converter.convertTypes(funType.getResults(), newResults)))
|
||||
return false;
|
||||
auto newType = mlir::FunctionType::get(
|
||||
fun.getContext(), result.getConvertedTypes(), newResults);
|
||||
return newType == fun.getType();
|
||||
}
|
||||
|
||||
private:
|
||||
TypeConverterType &converter;
|
||||
};
|
||||
@@ -0,0 +1,102 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_GENERICOPTYPECONVERSIONPATTERN_H_
|
||||
#define CONCRETELANG_CONVERSION_GENERICOPTYPECONVERSIONPATTERN_H_
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include <mlir/Support/LLVM.h>
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
// Converts the type of all operands and the return type of `op` by
|
||||
// invoking `convertType`
|
||||
static inline void convertOperandAndResultTypes(
|
||||
mlir::PatternRewriter &rewriter, mlir::Operation *op,
|
||||
llvm::function_ref<mlir::Type(mlir::MLIRContext *, mlir::Type)>
|
||||
convertType) {
|
||||
rewriter.startRootUpdate(op);
|
||||
// Rewrite arguments
|
||||
{
|
||||
for (unsigned i = 0; i < op->getNumOperands(); i++) {
|
||||
auto operand = op->getOperand(i);
|
||||
mlir::Type type = convertType(rewriter.getContext(), operand.getType());
|
||||
if (type != mlir::Type()) {
|
||||
operand.setType(type);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Rewrite results
|
||||
{
|
||||
for (unsigned i = 0; i < op->getNumResults(); i++) {
|
||||
auto result = op->getResult(i);
|
||||
mlir::Type type = convertType(rewriter.getContext(), result.getType());
|
||||
if (type != mlir::Type()) {
|
||||
result.setType(type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
struct GenericTypeConverterPattern : public mlir::OpRewritePattern<Op> {
|
||||
GenericTypeConverterPattern(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &converter,
|
||||
mlir::PatternBenefit benefit = 100)
|
||||
: mlir::OpRewritePattern<Op>(context, benefit), converter(converter) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
|
||||
auto newOp = rewriter.clone(*op);
|
||||
convertOperandAndResultTypes(rewriter, newOp,
|
||||
[&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
rewriter.replaceOp(op, newOp->getResults());
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
private:
|
||||
mlir::TypeConverter &converter;
|
||||
};
|
||||
|
||||
template <typename OldOp, typename NewOp>
|
||||
struct GenericTypeAndOpConverterPattern : public mlir::OpRewritePattern<OldOp> {
|
||||
GenericTypeAndOpConverterPattern(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &converter,
|
||||
mlir::PatternBenefit benefit = 100)
|
||||
: mlir::OpRewritePattern<OldOp>(context, benefit), converter(converter) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(OldOp oldOp, mlir::PatternRewriter &rewriter) const override {
|
||||
// Rewrite results
|
||||
mlir::SmallVector<mlir::Type> resultTypes(oldOp->getNumResults());
|
||||
{
|
||||
for (unsigned i = 0; i < oldOp->getNumResults(); i++) {
|
||||
auto result = oldOp->getResult(i);
|
||||
resultTypes[i] = converter.convertType(result.getType());
|
||||
}
|
||||
}
|
||||
auto newOp = rewriter.replaceOpWithNewOp<NewOp>(
|
||||
oldOp, resultTypes, oldOp->getOperands(), oldOp->getAttrs());
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, newOp, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
private:
|
||||
mlir::TypeConverter &converter;
|
||||
};
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,74 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_GLOBALFHECONTEXT_H_
|
||||
#define CONCRETELANG_CONVERSION_GLOBALFHECONTEXT_H_
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/ADT/Optional.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
typedef std::vector<int64_t> CRTDecomposition;
|
||||
|
||||
struct V0FHEConstraint {
|
||||
size_t norm2;
|
||||
size_t p;
|
||||
};
|
||||
|
||||
struct PackingKeySwitchParameter {
|
||||
size_t inputLweDimension;
|
||||
size_t outputPolynomialSize;
|
||||
size_t level;
|
||||
size_t baseLog;
|
||||
};
|
||||
|
||||
struct CitcuitBoostrapParameter {
|
||||
size_t level;
|
||||
size_t baseLog;
|
||||
};
|
||||
|
||||
struct WopPBSParameter {
|
||||
PackingKeySwitchParameter packingKeySwitch;
|
||||
CitcuitBoostrapParameter circuitBootstrap;
|
||||
};
|
||||
|
||||
struct LargeIntegerParameter {
|
||||
CRTDecomposition crtDecomposition;
|
||||
WopPBSParameter wopPBS;
|
||||
};
|
||||
|
||||
struct V0Parameter {
|
||||
size_t glweDimension;
|
||||
size_t logPolynomialSize;
|
||||
size_t nSmall;
|
||||
size_t brLevel;
|
||||
size_t brLogBase;
|
||||
size_t ksLevel;
|
||||
size_t ksLogBase;
|
||||
|
||||
llvm::Optional<LargeIntegerParameter> largeInteger;
|
||||
|
||||
// TODO remove the shift when we have true polynomial size
|
||||
size_t getPolynomialSize() { return 1 << logPolynomialSize; }
|
||||
|
||||
size_t getNBigLweDimension() { return glweDimension * getPolynomialSize(); }
|
||||
};
|
||||
|
||||
struct V0FHEContext {
|
||||
V0FHEContext() = delete;
|
||||
V0FHEContext(const V0FHEConstraint &constraint, const V0Parameter ¶meter)
|
||||
: constraint(constraint), parameter(parameter) {}
|
||||
|
||||
V0FHEConstraint constraint;
|
||||
V0Parameter parameter;
|
||||
};
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,26 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_UTILS_LEGALITY_H_
|
||||
#define CONCRETELANG_CONVERSION_UTILS_LEGALITY_H_
|
||||
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
template <typename Op>
|
||||
void addDynamicallyLegalTypeOp(mlir::ConversionTarget &target,
|
||||
mlir::TypeConverter &typeConverter) {
|
||||
target.addDynamicallyLegalOp<Op>([&](Op op) {
|
||||
return typeConverter.isLegal(op->getOperandTypes()) &&
|
||||
typeConverter.isLegal(op->getResultTypes());
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,45 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
/// RegionOpTypeConverterPattern is a rewrite pattern that applies
|
||||
/// `TypeConverter` to an instance of `OpWithRegion`, converting the
|
||||
/// type of all operands, results and arguments of regions according
|
||||
/// to the type converter.
|
||||
template <typename OpWithRegion, typename TypeConverter>
|
||||
struct RegionOpTypeConverterPattern
|
||||
: public mlir::OpRewritePattern<OpWithRegion> {
|
||||
RegionOpTypeConverterPattern(mlir::MLIRContext *context,
|
||||
TypeConverter &converter,
|
||||
mlir::PatternBenefit benefit = 100)
|
||||
: mlir::OpRewritePattern<OpWithRegion>(context, benefit),
|
||||
converter(converter) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(OpWithRegion op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto doConvertType = [&](mlir::Value v) {
|
||||
mlir::Type type = converter.convertType(v.getType());
|
||||
|
||||
if (type != mlir::Type())
|
||||
v.setType(type);
|
||||
};
|
||||
|
||||
rewriter.startRootUpdate(op);
|
||||
llvm::for_each(op->getOperands(), doConvertType);
|
||||
llvm::for_each(op->getResults(), doConvertType);
|
||||
llvm::for_each(op->getRegions(), [&](mlir::Region ®ion) {
|
||||
llvm::for_each(region.front().getArguments(), doConvertType);
|
||||
});
|
||||
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
private:
|
||||
TypeConverter &converter;
|
||||
};
|
||||
@@ -0,0 +1,216 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_UTILS_REINSTANTIATINGOPTYPECONVERSION_H_
|
||||
#define CONCRETELANG_CONVERSION_UTILS_REINSTANTIATINGOPTYPECONVERSION_H_
|
||||
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
// Set of types defining how attributes should be handled when
|
||||
// invocating the build method of an operation upon reinstantiation
|
||||
struct ReinstantiationAttributeHandling {
|
||||
// Copy attributes
|
||||
struct copy {};
|
||||
|
||||
// Completely dismiss attributes by not passing a set of arguments
|
||||
// to the builder at all
|
||||
struct dismiss {};
|
||||
|
||||
// Dismiss attributes by passing an empty set of arguments to the
|
||||
// builder
|
||||
struct pass_empty_vector {};
|
||||
};
|
||||
|
||||
// Template defining how attributes should be dismissed when invoking
|
||||
// the build method of an operation upon reinstantiation. In the
|
||||
// default case, the argument for attributes is simply dismissed.
|
||||
template <typename T> struct ReinstantiationAttributeDismissalStrategy {
|
||||
typedef ReinstantiationAttributeHandling::dismiss strategy;
|
||||
};
|
||||
|
||||
// Template defining how attributes should be copied when invoking the
|
||||
// build method of an operation upon reinstantiation. In the default
|
||||
// case, the argument for attributes is forwarded to the build method.
|
||||
template <typename T> struct ReinstantiationAttributeCopyStrategy {
|
||||
typedef ReinstantiationAttributeHandling::copy strategy;
|
||||
};
|
||||
|
||||
namespace {
|
||||
// Class template that defines the attribute handling strategy for
|
||||
// either dismissal of attributes (if `copyAttrsSwitch` is `false`) or copying
|
||||
// attributes (if `copyAttrsSwitch` is `true`).
|
||||
template <typename T, bool copyAttrsSwitch> struct AttributeHandlingSwitch {};
|
||||
|
||||
template <typename T> struct AttributeHandlingSwitch<T, true> {
|
||||
typedef typename ReinstantiationAttributeCopyStrategy<T>::strategy strategy;
|
||||
};
|
||||
|
||||
template <typename T> struct AttributeHandlingSwitch<T, false> {
|
||||
typedef
|
||||
typename ReinstantiationAttributeDismissalStrategy<T>::strategy strategy;
|
||||
};
|
||||
|
||||
// Simple functor-like template invoking a rewriter with a variable
|
||||
// set of arguments and an op's attributes as the last argument.
|
||||
template <typename NewOpTy, typename... Args>
|
||||
struct ReplaceOpWithNewOpCopyAttrs {
|
||||
static NewOpTy replace(mlir::ConversionPatternRewriter &rewriter,
|
||||
mlir::Operation *op, mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange operands) {
|
||||
return rewriter.replaceOpWithNewOp<NewOpTy>(op, resultTypes, operands,
|
||||
op->getAttrs());
|
||||
}
|
||||
};
|
||||
|
||||
// Simple functor-like template invoking a rewriter with a variable
|
||||
// set of arguments dismissing the attributes passed as the last
|
||||
// argument.
|
||||
template <typename NewOpTy, typename... Args>
|
||||
struct ReplaceOpWithNewOpDismissAttrs {
|
||||
static NewOpTy replace(mlir::ConversionPatternRewriter &rewriter,
|
||||
mlir::Operation *op, mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange operands) {
|
||||
return rewriter.replaceOpWithNewOp<NewOpTy>(op, resultTypes, operands);
|
||||
}
|
||||
};
|
||||
|
||||
// Simple functor-like template invoking a rewriter with a variable
|
||||
// set of arguments dismissing the attributes by passing an empty
|
||||
// set of arguments to the builder.
|
||||
template <typename NewOpTy, typename... Args>
|
||||
struct ReplaceOpWithNewOpEmptyAttrs {
|
||||
static NewOpTy replace(mlir::ConversionPatternRewriter &rewriter,
|
||||
mlir::Operation *op, mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange operands) {
|
||||
llvm::SmallVector<mlir::NamedAttribute> attrs{};
|
||||
return rewriter.replaceOpWithNewOp<NewOpTy>(op, resultTypes, operands,
|
||||
attrs);
|
||||
}
|
||||
};
|
||||
|
||||
// Functor-like template that either forwards to
|
||||
// `ReplaceOpWithNewOpCopyAttrs` or `ReplaceOpWithNewOpDismissAttrs`
|
||||
// depending on the value of `copyAttrs`.
|
||||
template <typename copyAttrsSwitch, typename OpTy, typename... Args>
|
||||
struct ReplaceOpWithNewOpAttrSwitch {};
|
||||
|
||||
// Specialization of `ReplaceOpWithNewOpAttrSwitch` that does copy
|
||||
// attributes.
|
||||
template <typename OpTy, typename... Args>
|
||||
struct ReplaceOpWithNewOpAttrSwitch<ReinstantiationAttributeHandling::copy,
|
||||
OpTy, Args...> {
|
||||
typedef ReplaceOpWithNewOpCopyAttrs<OpTy, Args...> instantiator;
|
||||
};
|
||||
|
||||
// Specialization of `ReplaceOpWithNewOpAttrSwitch` that does NOT copy
|
||||
// attributes by not passing attributes to the builder at all.
|
||||
template <typename OpTy, typename... Args>
|
||||
struct ReplaceOpWithNewOpAttrSwitch<ReinstantiationAttributeHandling::dismiss,
|
||||
OpTy, Args...> {
|
||||
typedef ReplaceOpWithNewOpDismissAttrs<OpTy, Args...> instantiator;
|
||||
};
|
||||
|
||||
// Specialization of `ReplaceOpWithNewOpAttrSwitch` that does NOT copy
|
||||
// attributes by passing an empty set of attributes to the builder.
|
||||
template <typename OpTy, typename... Args>
|
||||
struct ReplaceOpWithNewOpAttrSwitch<
|
||||
ReinstantiationAttributeHandling::pass_empty_vector, OpTy, Args...> {
|
||||
typedef ReplaceOpWithNewOpEmptyAttrs<OpTy, Args...> instantiator;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename OldOp, typename NewOp, bool copyAttrs = false>
|
||||
struct GenericOneToOneOpConversionPatternBase
|
||||
: public mlir::OpConversionPattern<OldOp> {
|
||||
GenericOneToOneOpConversionPatternBase(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &converter,
|
||||
mlir::PatternBenefit benefit = 100)
|
||||
: mlir::OpConversionPattern<OldOp>(converter, context, benefit) {}
|
||||
|
||||
mlir::SmallVector<mlir::Type> convertResultTypes(OldOp oldOp) const {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
|
||||
// Convert result types
|
||||
mlir::SmallVector<mlir::Type> resultTypes(oldOp->getNumResults());
|
||||
|
||||
for (unsigned i = 0; i < oldOp->getNumResults(); i++) {
|
||||
auto result = oldOp->getResult(i);
|
||||
resultTypes[i] = converter->convertType(result.getType());
|
||||
}
|
||||
|
||||
return resultTypes;
|
||||
}
|
||||
|
||||
mlir::Type convertResultType(OldOp oldOp) const {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
return converter->convertType(oldOp->getResult(0).getType());
|
||||
}
|
||||
};
|
||||
|
||||
// Conversion pattern that replaces an instance of an operation of the type
|
||||
// `OldOp` with an instance of the type `NewOp`, taking into account operands,
|
||||
// return types and possible copying attributes (iff copyAttrs is `true`).
|
||||
template <typename OldOp, typename NewOp, bool copyAttrs = false>
|
||||
struct GenericOneToOneOpConversionPattern
|
||||
: public GenericOneToOneOpConversionPatternBase<OldOp, NewOp, copyAttrs> {
|
||||
GenericOneToOneOpConversionPattern(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &converter,
|
||||
mlir::PatternBenefit benefit = 100)
|
||||
: GenericOneToOneOpConversionPatternBase<OldOp, NewOp, copyAttrs>(
|
||||
context, converter, benefit) {}
|
||||
|
||||
virtual mlir::LogicalResult
|
||||
matchAndRewrite(OldOp oldOp,
|
||||
typename mlir::OpConversionPattern<OldOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::SmallVector<mlir::Type> resultTypes = this->convertResultTypes(oldOp);
|
||||
|
||||
ReplaceOpWithNewOpAttrSwitch<
|
||||
typename AttributeHandlingSwitch<NewOp, copyAttrs>::strategy,
|
||||
NewOp>::instantiator::replace(rewriter, oldOp,
|
||||
mlir::TypeRange{resultTypes},
|
||||
mlir::ValueRange{adaptor.getOperands()});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
// Conversion pattern that retrieves the converted operands of an
|
||||
// operation of the type `Op`, converts the types of the results of
|
||||
// the operation and re-instantiates the operation type with the
|
||||
// converted operands and result types.
|
||||
template <typename Op, bool copyAttrs = false>
|
||||
struct TypeConvertingReinstantiationPattern
|
||||
: public GenericOneToOneOpConversionPatternBase<Op, Op, copyAttrs> {
|
||||
TypeConvertingReinstantiationPattern(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &converter,
|
||||
mlir::PatternBenefit benefit = 100)
|
||||
: GenericOneToOneOpConversionPatternBase<Op, Op, copyAttrs>(
|
||||
context, converter, benefit) {}
|
||||
// Simple forward that makes the method specializable out of class
|
||||
// directly for this class rather than for its base
|
||||
virtual mlir::LogicalResult
|
||||
matchAndRewrite(Op op,
|
||||
typename mlir::OpConversionPattern<Op>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::SmallVector<mlir::Type> resultTypes = this->convertResultTypes(op);
|
||||
|
||||
ReplaceOpWithNewOpAttrSwitch<
|
||||
typename AttributeHandlingSwitch<Op, copyAttrs>::strategy,
|
||||
Op>::instantiator::replace(rewriter, op, mlir::TypeRange{resultTypes},
|
||||
mlir::ValueRange{adaptor.getOperands()});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,66 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_TENSOROPTYPECONVERSIONPATTERN_H_
|
||||
#define CONCRETELANG_CONVERSION_TENSOROPTYPECONVERSIONPATTERN_H_
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Conversion/Utils/Dialects/Tensor.h"
|
||||
#include "concretelang/Conversion/Utils/Legality.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
inline void
|
||||
populateWithTensorTypeConverterPatterns(mlir::RewritePatternSet &patterns,
|
||||
mlir::ConversionTarget &target,
|
||||
mlir::TypeConverter &typeConverter) {
|
||||
// ExtractOp
|
||||
patterns.add<TypeConvertingReinstantiationPattern<mlir::tensor::ExtractOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::ExtractOp>(target, typeConverter);
|
||||
|
||||
// ExtractSliceOp
|
||||
patterns.add<
|
||||
TypeConvertingReinstantiationPattern<mlir::tensor::ExtractSliceOp, true>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::ExtractSliceOp>(target,
|
||||
typeConverter);
|
||||
|
||||
// InsertOp
|
||||
patterns.add<TypeConvertingReinstantiationPattern<mlir::tensor::InsertOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::InsertOp>(target, typeConverter);
|
||||
// InsertSliceOp
|
||||
patterns.add<
|
||||
TypeConvertingReinstantiationPattern<mlir::tensor::InsertSliceOp, true>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::InsertSliceOp>(target, typeConverter);
|
||||
|
||||
// FromElementsOp
|
||||
patterns
|
||||
.add<TypeConvertingReinstantiationPattern<mlir::tensor::FromElementsOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::FromElementsOp>(target,
|
||||
typeConverter);
|
||||
// TensorCollapseShapeOp
|
||||
patterns
|
||||
.add<TypeConvertingReinstantiationPattern<mlir::tensor::CollapseShapeOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::CollapseShapeOp>(target,
|
||||
typeConverter);
|
||||
// TensorExpandShapeOp
|
||||
patterns
|
||||
.add<TypeConvertingReinstantiationPattern<mlir::tensor::ExpandShapeOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::ExpandShapeOp>(target, typeConverter);
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,7 @@
|
||||
add_subdirectory(FHE)
|
||||
add_subdirectory(FHELinalg)
|
||||
add_subdirectory(TFHE)
|
||||
add_subdirectory(Concrete)
|
||||
add_subdirectory(RT)
|
||||
add_subdirectory(SDFG)
|
||||
add_subdirectory(Tracing)
|
||||
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
@@ -0,0 +1,13 @@
|
||||
set(LLVM_TARGET_DEFINITIONS ConcreteOps.td)
|
||||
mlir_tablegen(ConcreteOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(ConcreteOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(ConcreteOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=Concrete)
|
||||
mlir_tablegen(ConcreteOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=Concrete)
|
||||
mlir_tablegen(ConcreteOpsDialect.h.inc -gen-dialect-decls -dialect=Concrete)
|
||||
mlir_tablegen(ConcreteOpsDialect.cpp.inc -gen-dialect-defs -dialect=Concrete)
|
||||
add_public_tablegen_target(MLIRConcreteOpsIncGen)
|
||||
add_dependencies(mlir-headers MLIRConcreteOpsIncGen)
|
||||
|
||||
add_concretelang_doc(ConcreteOps ConcreteDialect concretelang/ -gen-dialect-doc -dialect=Concrete)
|
||||
add_concretelang_doc(ConcreteOps ConcreteOps concretelang/ -gen-op-doc)
|
||||
add_concretelang_doc(ConcreteTypes ConcreteTypes concretelang/ -gen-typedef-doc)
|
||||
@@ -0,0 +1,18 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_Concrete_IR_ConcreteDIALECT_H
|
||||
#define CONCRETELANG_DIALECT_Concrete_IR_ConcreteDIALECT_H
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOpsDialect.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,16 @@
|
||||
#ifndef CONCRETELANG_DIALECT_Concrete_IR_Concrete_DIALECT
|
||||
#define CONCRETELANG_DIALECT_Concrete_IR_Concrete_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Concrete_Dialect : Dialect {
|
||||
let name = "Concrete";
|
||||
let summary = "Low Level Fully Homorphic Encryption dialect";
|
||||
let description = [{
|
||||
A dialect for representation of low level operation on fully homomorphic ciphertext.
|
||||
}];
|
||||
let cppNamespace = "::mlir::concretelang::Concrete";
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,21 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_Concrete_Concrete_OPS_H
|
||||
#define CONCRETELANG_DIALECT_Concrete_Concrete_OPS_H
|
||||
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/Interfaces/ControlFlowInterfaces.h>
|
||||
#include <mlir/Interfaces/SideEffectInterfaces.h>
|
||||
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Interfaces/BatchableInterface.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,392 @@
|
||||
#ifndef CONCRETELANG_DIALECT_Concrete_IR_Concrete_OPS
|
||||
#define CONCRETELANG_DIALECT_Concrete_IR_Concrete_OPS
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
||||
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
||||
|
||||
include "concretelang/Dialect/Concrete/IR/ConcreteDialect.td"
|
||||
include "concretelang/Dialect/Concrete/IR/ConcreteTypes.td"
|
||||
include "concretelang/Interfaces/BatchableInterface.td"
|
||||
include "concretelang/Dialect/RT/IR/RTDialect.td"
|
||||
include "concretelang/Dialect/RT/IR/RTTypes.td"
|
||||
|
||||
def Concrete_LweTensor : 1DTensorOf<[I64]>;
|
||||
def Concrete_LutTensor : 1DTensorOf<[I64]>;
|
||||
def Concrete_CrtLutsTensor : 2DTensorOf<[I64]>;
|
||||
def Concrete_CrtPlaintextTensor : 1DTensorOf<[I64]>;
|
||||
def Concrete_LweCRTTensor : 2DTensorOf<[I64]>;
|
||||
def Concrete_BatchLweTensor : 2DTensorOf<[I64]>;
|
||||
|
||||
def Concrete_LweBuffer : MemRefRankOf<[I64], [1]>;
|
||||
def Concrete_LutBuffer : MemRefRankOf<[I64], [1]>;
|
||||
def Concrete_CrtLutsBuffer : MemRefRankOf<[I64], [2]>;
|
||||
def Concrete_CrtPlaintextBuffer : MemRefRankOf<[I64], [1]>;
|
||||
def Concrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>;
|
||||
def Concrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>;
|
||||
|
||||
class Concrete_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<Concrete_Dialect, mnemonic, traits>;
|
||||
|
||||
|
||||
def Concrete_AddLweTensorOp : Concrete_Op<"add_lwe_tensor", [NoSideEffect]> {
|
||||
let summary = "Returns the sum of 2 lwe ciphertexts";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LweTensor:$lhs,
|
||||
Concrete_LweTensor:$rhs
|
||||
);
|
||||
let results = (outs Concrete_LweTensor:$result);
|
||||
}
|
||||
|
||||
def Concrete_AddLweBufferOp : Concrete_Op<"add_lwe_buffer"> {
|
||||
let summary = "Returns the sum of 2 lwe ciphertexts";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LweBuffer:$result,
|
||||
Concrete_LweBuffer:$lhs,
|
||||
Concrete_LweBuffer:$rhs
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_AddPlaintextLweTensorOp : Concrete_Op<"add_plaintext_lwe_tensor", [NoSideEffect]> {
|
||||
let summary = "Returns the sum of a clear integer and an lwe ciphertext";
|
||||
|
||||
let arguments = (ins Concrete_LweTensor:$lhs, I64:$rhs);
|
||||
let results = (outs Concrete_LweTensor:$result);
|
||||
}
|
||||
|
||||
def Concrete_AddPlaintextLweBufferOp : Concrete_Op<"add_plaintext_lwe_buffer"> {
|
||||
let summary = "Returns the sum of a clear integer and an lwe ciphertext";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LweBuffer:$result,
|
||||
Concrete_LweBuffer:$lhs,
|
||||
I64:$rhs
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_MulCleartextLweTensorOp : Concrete_Op<"mul_cleartext_lwe_tensor", [NoSideEffect]> {
|
||||
let summary = "Returns the product of a clear integer and a lwe ciphertext";
|
||||
|
||||
let arguments = (ins Concrete_LweTensor:$lhs, I64:$rhs);
|
||||
let results = (outs Concrete_LweTensor:$result);
|
||||
}
|
||||
|
||||
def Concrete_MulCleartextLweBufferOp : Concrete_Op<"mul_cleartext_lwe_buffer"> {
|
||||
let summary = "Returns the product of a clear integer and a lwe ciphertext";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LweBuffer:$result,
|
||||
Concrete_LweBuffer:$lhs,
|
||||
I64:$rhs
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_NegateLweTensorOp : Concrete_Op<"negate_lwe_tensor", [NoSideEffect]> {
|
||||
let summary = "Negates a lwe ciphertext";
|
||||
|
||||
let arguments = (ins Concrete_LweTensor:$ciphertext);
|
||||
let results = (outs Concrete_LweTensor:$result);
|
||||
}
|
||||
|
||||
def Concrete_NegateLweBufferOp : Concrete_Op<"negate_lwe_buffer"> {
|
||||
let summary = "Negates a lwe ciphertext";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LweBuffer:$result,
|
||||
Concrete_LweBuffer:$ciphertext
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_EncodeExpandLutForBootstrapTensorOp : Concrete_Op<"encode_expand_lut_for_bootstrap_tensor", [NoSideEffect]> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a bootstrap";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LutTensor : $input_lookup_table,
|
||||
I32Attr: $polySize,
|
||||
I32Attr: $outputBits,
|
||||
BoolAttr: $isSigned
|
||||
);
|
||||
|
||||
let results = (outs Concrete_LutTensor : $result);
|
||||
}
|
||||
|
||||
def Concrete_EncodeExpandLutForBootstrapBufferOp : Concrete_Op<"encode_expand_lut_for_bootstrap_buffer"> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a bootstrap";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LutBuffer: $result,
|
||||
Concrete_LutBuffer: $input_lookup_table,
|
||||
I32Attr: $polySize,
|
||||
I32Attr: $outputBits,
|
||||
BoolAttr : $isSigned
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_EncodeLutForCrtWopPBSTensorOp : Concrete_Op<"encode_lut_for_crt_woppbs_tensor", [NoSideEffect]> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a wop pbs";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LutTensor : $input_lookup_table,
|
||||
I64ArrayAttr: $crtDecomposition,
|
||||
I64ArrayAttr: $crtBits,
|
||||
I32Attr : $modulusProduct,
|
||||
BoolAttr: $isSigned
|
||||
);
|
||||
|
||||
let results = (outs Concrete_CrtLutsTensor : $result);
|
||||
}
|
||||
|
||||
def Concrete_EncodeLutForCrtWopPBSBufferOp : Concrete_Op<"encode_lut_for_crt_woppbs_buffer"> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a crt wop pbs";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_CrtLutsBuffer : $result,
|
||||
Concrete_LutBuffer : $input_lookup_table,
|
||||
I64ArrayAttr: $crtDecomposition,
|
||||
I64ArrayAttr: $crtBits,
|
||||
I32Attr : $modulusProduct,
|
||||
BoolAttr: $isSigned
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_EncodePlaintextWithCrtTensorOp : Concrete_Op<"encode_plaintext_with_crt_tensor", [NoSideEffect]> {
|
||||
let summary =
|
||||
"Encodes a plaintext by decomposing it on a crt basis";
|
||||
|
||||
let arguments = (ins
|
||||
I64 : $input,
|
||||
I64ArrayAttr: $mods,
|
||||
I64Attr: $modsProd
|
||||
);
|
||||
|
||||
let results = (outs Concrete_CrtPlaintextTensor : $result);
|
||||
}
|
||||
|
||||
def Concrete_EncodePlaintextWithCrtBufferOp : Concrete_Op<"encode_plaintext_with_crt_buffer"> {
|
||||
let summary =
|
||||
"Encodes a plaintext by decomposing it on a crt basis";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_CrtPlaintextBuffer: $result,
|
||||
I64 : $input,
|
||||
I64ArrayAttr: $mods,
|
||||
I64Attr: $modsProd
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_BootstrapLweTensorOp : Concrete_Op<"bootstrap_lwe_tensor", [NoSideEffect]> {
|
||||
let summary = "Bootstraps an LWE ciphertext with a GLWE trivial encryption of the lookup table";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LweTensor:$input_ciphertext,
|
||||
Concrete_LweTensor:$lookup_table,
|
||||
I32Attr:$inputLweDim,
|
||||
I32Attr:$polySize,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$glweDimension,
|
||||
I32Attr:$outPrecision
|
||||
);
|
||||
let results = (outs Concrete_LweTensor:$result);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::OpOperand& getBatchableOperand() {
|
||||
return getOperation()->getOpOperand(0);
|
||||
}
|
||||
|
||||
::mlir::OperandRange getNonBatchableOperands() {
|
||||
return getOperation()->getOperands().drop_front();
|
||||
}
|
||||
|
||||
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
|
||||
::mlir::Value batchedOperands) {
|
||||
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
|
||||
batchedOperands.getType().cast<::mlir::RankedTensorType>().getShape(),
|
||||
getResult().getType());
|
||||
|
||||
return builder.create<BatchedBootstrapLweTensorOp>(
|
||||
mlir::TypeRange{resType},
|
||||
mlir::ValueRange{batchedOperands, lookup_table()},
|
||||
getOperation()->getAttrs());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Concrete_BootstrapLweBufferOp : Concrete_Op<"bootstrap_lwe_buffer"> {
|
||||
let summary = "Bootstraps a LWE ciphertext with a GLWE trivial encryption of the lookup table";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LweBuffer:$result,
|
||||
Concrete_LweBuffer:$input_ciphertext,
|
||||
Concrete_LutBuffer:$lookup_table,
|
||||
I32Attr:$inputLweDim,
|
||||
I32Attr:$polySize,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$glweDimension,
|
||||
I32Attr:$outPrecision
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_BatchedBootstrapLweTensorOp : Concrete_Op<"batched_bootstrap_lwe_tensor", [NoSideEffect]> {
|
||||
let summary = "Batched version of BootstrapLweOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_BatchLweTensor:$input_ciphertext,
|
||||
Concrete_LutTensor:$lookup_table,
|
||||
I32Attr:$inputLweDim,
|
||||
I32Attr:$polySize,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$glweDimension,
|
||||
I32Attr:$outPrecision
|
||||
);
|
||||
let results = (outs Concrete_BatchLweTensor:$result);
|
||||
}
|
||||
|
||||
def Concrete_BatchedBootstrapLweBufferOp : Concrete_Op<"batched_bootstrap_lwe_buffer"> {
|
||||
let summary = "Batched version of BootstrapLweOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_BatchLweBuffer:$result,
|
||||
Concrete_BatchLweBuffer:$input_ciphertext,
|
||||
Concrete_LutBuffer:$lookup_table,
|
||||
I32Attr:$inputLweDim,
|
||||
I32Attr:$polySize,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$glweDimension,
|
||||
I32Attr:$outPrecision
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_KeySwitchLweTensorOp : Concrete_Op<"keyswitch_lwe_tensor", [NoSideEffect]> {
|
||||
let summary = "Keyswitches an LWE ciphertext";
|
||||
|
||||
let arguments = (ins
|
||||
// LweKeySwitchKeyType:$keyswitch_key,
|
||||
Concrete_LweTensor:$ciphertext,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$lwe_dim_in,
|
||||
I32Attr:$lwe_dim_out
|
||||
);
|
||||
let results = (outs Concrete_LweTensor:$result);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::OpOperand& getBatchableOperand() {
|
||||
return getOperation()->getOpOperand(0);
|
||||
}
|
||||
|
||||
::mlir::OperandRange getNonBatchableOperands() {
|
||||
return getOperation()->getOperands().drop_front();
|
||||
}
|
||||
|
||||
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
|
||||
::mlir::Value batchedOperands) {
|
||||
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
|
||||
batchedOperands.getType().cast<::mlir::RankedTensorType>().getShape(),
|
||||
getResult().getType());
|
||||
|
||||
return builder.create<BatchedKeySwitchLweTensorOp>(
|
||||
mlir::TypeRange{resType},
|
||||
mlir::ValueRange{batchedOperands},
|
||||
getOperation()->getAttrs());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Concrete_KeySwitchLweBufferOp : Concrete_Op<"keyswitch_lwe_buffer"> {
|
||||
let summary = "Keyswitches an LWE ciphertext";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LweBuffer:$result,
|
||||
Concrete_LweBuffer:$ciphertext,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$lwe_dim_in,
|
||||
I32Attr:$lwe_dim_out
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_BatchedKeySwitchLweTensorOp : Concrete_Op<"batched_keyswitch_lwe_tensor", [NoSideEffect]> {
|
||||
let summary = "Batched version of KeySwitchLweOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins
|
||||
// LweKeySwitchKeyType:$keyswitch_key,
|
||||
Concrete_BatchLweTensor:$ciphertext,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$lwe_dim_in,
|
||||
I32Attr:$lwe_dim_out
|
||||
);
|
||||
let results = (outs Concrete_BatchLweTensor:$result);
|
||||
}
|
||||
|
||||
def Concrete_BatchedKeySwitchLweBufferOp : Concrete_Op<"batched_keyswitch_lwe_buffer"> {
|
||||
let summary = "Batched version of KeySwitchLweOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_BatchLweBuffer:$result,
|
||||
Concrete_BatchLweBuffer:$ciphertext,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$lwe_dim_in,
|
||||
I32Attr:$lwe_dim_out
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_WopPBSCRTLweTensorOp : Concrete_Op<"wop_pbs_crt_lwe_tensor", [NoSideEffect]> {
|
||||
let arguments = (ins
|
||||
Concrete_LweCRTTensor:$ciphertext,
|
||||
Concrete_CrtLutsTensor:$lookupTable,
|
||||
// Bootstrap parameters
|
||||
I32Attr : $bootstrapLevel,
|
||||
I32Attr : $bootstrapBaseLog,
|
||||
// Keyswitch parameters
|
||||
I32Attr : $keyswitchLevel,
|
||||
I32Attr : $keyswitchBaseLog,
|
||||
// Packing keyswitch key parameters
|
||||
I32Attr : $packingKeySwitchInputLweDimension,
|
||||
I32Attr : $packingKeySwitchoutputPolynomialSize,
|
||||
I32Attr : $packingKeySwitchLevel,
|
||||
I32Attr : $packingKeySwitchBaseLog,
|
||||
// Circuit bootstrap parameters
|
||||
I32Attr : $circuitBootstrapLevel,
|
||||
I32Attr : $circuitBootstrapBaseLog
|
||||
);
|
||||
let results = (outs Concrete_LweCRTTensor:$result);
|
||||
}
|
||||
|
||||
def Concrete_WopPBSCRTLweBufferOp : Concrete_Op<"wop_pbs_crt_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
Concrete_LweCRTBuffer:$result,
|
||||
Concrete_LweCRTBuffer:$ciphertext,
|
||||
Concrete_CrtLutsBuffer:$lookup_table,
|
||||
// Bootstrap parameters
|
||||
I32Attr : $bootstrapLevel,
|
||||
I32Attr : $bootstrapBaseLog,
|
||||
// Keyswitch parameters
|
||||
I32Attr : $keyswitchLevel,
|
||||
I32Attr : $keyswitchBaseLog,
|
||||
// Packing keyswitch key parameters
|
||||
I32Attr : $packingKeySwitchInputLweDimension,
|
||||
I32Attr : $packingKeySwitchoutputPolynomialSize,
|
||||
I32Attr : $packingKeySwitchLevel,
|
||||
I32Attr : $packingKeySwitchBaseLog,
|
||||
// Circuit bootstrap parameters
|
||||
I32Attr : $circuitBootstrapLevel,
|
||||
I32Attr : $circuitBootstrapBaseLog,
|
||||
I64ArrayAttr:$crtDecomposition
|
||||
);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,17 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_Concrete_IR_ConcreteTYPES_H
|
||||
#define CONCRETELANG_DIALECT_Concrete_IR_ConcreteTYPES_H
|
||||
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/IR/DialectImplementation.h>
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOpsTypes.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,20 @@
|
||||
#ifndef CONCRETELANG_DIALECT_Concrete_IR_Concrete_TYPES
|
||||
#define CONCRETELANG_DIALECT_Concrete_IR_Concrete_TYPES
|
||||
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
|
||||
include "concretelang/Dialect/Concrete/IR/ConcreteDialect.td"
|
||||
|
||||
class Concrete_Type<string name, list<Trait> traits = []> : TypeDef<Concrete_Dialect, name, traits> { }
|
||||
|
||||
def Concrete_Context : Concrete_Type<"Context"> {
|
||||
let mnemonic = "context";
|
||||
|
||||
let summary = "A runtime context";
|
||||
|
||||
let description = [{
|
||||
An abstract runtime context to pass contextual value, like public keys, ...
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,19 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_CONCRETE_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||
#define CONCRETELANG_DIALECT_CONCRETE_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
|
||||
namespace concretelang {
|
||||
namespace Concrete {
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||
} // namespace Concrete
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Concrete)
|
||||
add_public_tablegen_target(ConcreteTransformsIncGen)
|
||||
@@ -0,0 +1,20 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_CONCRETE_TRANSFORMS_PASSES_H_
|
||||
#define CONCRETELANG_DIALECT_CONCRETE_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "concretelang/Dialect/Concrete/Transforms/Passes.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createAddRuntimeContext();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif // CONCRETELANG_DIALECT_CONCRETE_TRANSFORMS_PASSES_H_
|
||||
@@ -0,0 +1,19 @@
|
||||
//===-- Passes.td - pass definition file -------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
|
||||
#define MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def AddRuntimeContext : Pass<"add-runtime-context", "mlir::ModuleOp"> {
|
||||
let summary = "Add the runtime context argument";
|
||||
let constructor = "mlir::concretelang::createAddRuntimeContext()";
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
|
||||
@@ -0,0 +1,13 @@
|
||||
set(LLVM_TARGET_DEFINITIONS MANP.td)
|
||||
mlir_tablegen(MANP.h.inc -gen-pass-decls -name Analysis)
|
||||
mlir_tablegen(MANP.capi.h.inc -gen-pass-capi-header --prefix Analysis)
|
||||
mlir_tablegen(MANP.capi.cpp.inc -gen-pass-capi-impl --prefix Analysis)
|
||||
add_public_tablegen_target(MANPPassIncGen)
|
||||
add_dependencies(mlir-headers MANPPassIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS ConcreteOptimizer.td)
|
||||
mlir_tablegen(ConcreteOptimizer.h.inc -gen-pass-decls -name Analysis)
|
||||
mlir_tablegen(ConcreteOptimizer.capi.h.inc -gen-pass-capi-header --prefix Analysis)
|
||||
mlir_tablegen(ConcreteOptimizer.capi.cpp.inc -gen-pass-capi-impl --prefix Analysis)
|
||||
add_public_tablegen_target(ConcreteOptimizerPassIncGen)
|
||||
add_dependencies(mlir-headers ConcreteOptimizerPassIncGen)
|
||||
@@ -0,0 +1,29 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_CONCRETE_OPTIMIZER_H
|
||||
#define CONCRETELANG_DIALECT_FHE_ANALYSIS_CONCRETE_OPTIMIZER_H
|
||||
|
||||
#include <map>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
#include "concrete-optimizer.hpp"
|
||||
|
||||
#include "concretelang/Support/V0Parameters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
namespace optimizer {
|
||||
using FunctionsDag = std::map<std::string, llvm::Optional<Dag>>;
|
||||
|
||||
std::unique_ptr<mlir::Pass> createDagPass(optimizer::Config config,
|
||||
optimizer::FunctionsDag &dags);
|
||||
|
||||
} // namespace optimizer
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,15 @@
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_CONCRETE_OPTIMIZER
|
||||
#define CONCRETELANG_DIALECT_FHE_ANALYSIS_CONCRETE_OPTIMIZER
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def ConcreteOptimizer : Pass<"ConcreteOptimizer", "::mlir::func::FuncOp"> {
|
||||
let summary = "Call concrete-optimizer";
|
||||
let description = [{
|
||||
The pass calls the concrete-optimizer to provide crypto parameter.
|
||||
It construct a simplified representation of the FHE circuit and send it to the concrete optimizer.
|
||||
It uses on the values from the MANP pass to indicate how noise is propagate and amplified in levelled operations.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,23 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_MANP_H
|
||||
#define CONCRETELANG_DIALECT_FHE_ANALYSIS_MANP_H
|
||||
|
||||
#include <functional>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
bool isEncryptedValue(mlir::Value value);
|
||||
unsigned int getEintPrecision(mlir::Value value);
|
||||
std::unique_ptr<mlir::Pass> createMANPPass(bool debug = false);
|
||||
|
||||
std::unique_ptr<mlir::Pass>
|
||||
createMaxMANPPass(std::function<void(uint64_t, unsigned)> setMax);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,110 @@
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_MANP
|
||||
#define CONCRETELANG_DIALECT_FHE_ANALYSIS_MANP
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def MANP : Pass<"MANP", "::mlir::func::FuncOp"> {
|
||||
let summary = "FHE Minimal Arithmetic Noise Padding Pass";
|
||||
let description = [{
|
||||
This pass calculates the Minimal Arithmetic Noise Padding
|
||||
(MANP) for each operation of a function and stores the result in an
|
||||
integer attribute named "MANP". This metric is identical to the
|
||||
ceiled 2-norm of the constant vector of an equivalent dot product
|
||||
between a vector of encrypted integers resulting directly from an
|
||||
encryption and a vector of plaintext constants.
|
||||
|
||||
The pass supports the following operations:
|
||||
|
||||
- FHELinalg.dot_eint_int
|
||||
- FHE.zero
|
||||
- FHE.add_eint_int
|
||||
- FHE.add_eint
|
||||
- FHE.sub_int_eint
|
||||
- FHE.neg_eint
|
||||
- FHE.mul_eint_int
|
||||
- FHE.apply_lookup_table
|
||||
|
||||
If any other operation is encountered, the pass conservatively
|
||||
fails. The pass further makes the optimistic assumption that all
|
||||
values passed to a function are either the direct result of an
|
||||
encryption of a noise-refreshing operation.
|
||||
|
||||
Conceptually, the pass is equivalent to the three steps below:
|
||||
|
||||
1. Replace all arithmetic operations with an equivalent dot
|
||||
operation
|
||||
|
||||
2. Merge resulting dot operations into a single, equivalent
|
||||
dot operation
|
||||
|
||||
3. Calculate the 2-norm of the vector of plaintext constants
|
||||
of the dot operation
|
||||
|
||||
with the following replacement rules:
|
||||
|
||||
- Function argument a -> FHELinalg.dot_eint_int([a], [1])
|
||||
- FHE.apply_lookup_table -> FHELinalg.dot_eint_int([LUT result], [1])
|
||||
- FHE.zero() -> FHELinalg.dot_eint_int([encrypted 0], [1])
|
||||
- FHE.add_eint_int(e, c) -> FHELinalg.dot_eint_int([e, 1], [1, c])
|
||||
with the encrypted 1 trivialy encrypted, i.e. without noise so 1xc is not take into account
|
||||
- FHE.add_eint(e0, e1) -> FHELinalg.dot_eint_int([e0, e1], [1, 1])
|
||||
- FHE.sub_int_eint(c, e) -> FHELinalg.dot_eint_int([e, c], [1, -1])
|
||||
- FHE.neg_eint(e) -> FHELinalg.dot_eint_int([e], [-1])
|
||||
- FHE.mul_eint_int(e, c) -> FHELinalg.dot_eint_int([e], [c])
|
||||
|
||||
Dependent dot operations, e.g.,
|
||||
|
||||
a = FHELinalg.dot_eint_int([a0, a1, ...], [c0, c1, ...])
|
||||
b = FHELinalg.dot_eint_int([b0, b1, ...], [d0, d1, ...])
|
||||
x = FHELinalg.dot_eint_int([a, b, ...], [f0, f1, ...])
|
||||
|
||||
are merged as follows:
|
||||
|
||||
x = FHELinalg.dot_eint_int([a0, a1, ..., b0, b1, ...],
|
||||
[f0*c0, f0*c1, ..., f1*d0, f1*d1, ...])
|
||||
|
||||
However, the implementation does not explicitly create the
|
||||
equivalent dot operations, but only accumulates the squared 2-norm
|
||||
of the constant vector of the equivalent dot operation along the
|
||||
edges of the data-flow graph composed by the operations in order to
|
||||
calculate the final 2-norm for the final single dot operation above.
|
||||
|
||||
For the example above, this means that the pass calculates the
|
||||
squared 2-norm of x, sqN(x) as:
|
||||
|
||||
sqN(a) = c0*c0 + c1*c1 + ...
|
||||
sqN(b) = d0*d0 + d1*d1 + ...
|
||||
sqN(x) = f0*f0*c0*c0 + f0*f0*c1*c1 + ... + f1*f1*d0*d0 + f1*f1*d1*d1 + ...
|
||||
= f0*f0*sqN(a) + f1*f1*sqN(b)
|
||||
|
||||
This leads to the following rules to calculate the squared 2-norm
|
||||
for the supported operations:
|
||||
|
||||
- Function argument -> 1
|
||||
- FHE.apply_lookup_table -> 1
|
||||
- FHE.zero() -> 1
|
||||
- FHELinalg.dot_eint_int([e0, e1, ...], [c0, c1, ...]) ->
|
||||
c0*c0*sqN(e0) + c1*c1*sqN(e1) + ...
|
||||
- FHE.add_eint_int(e, c) -> 1*1*sqN(e) = sqN(e)
|
||||
- FHE.add_eint(e0, e1) -> 1*1*sqN(e0) + 1*1*sqN(e2) = sqN(e1) + sqN(e2)
|
||||
- FHE.sub_int_eint(c, e) -> 1*1*sqN(e) + c*c*(-1)*(-1) = sqN(e) + c*c
|
||||
- FHE.neg_eint(e) -> (-1)*(-1)*sqN(e) = sqN(e)
|
||||
- FHE.mul_eint_int(e, c) -> c*c*sqN(e)
|
||||
|
||||
The final, non-squared 2-norm of an operation is the square root of the
|
||||
squared value rounded to the next highest integer.
|
||||
}];
|
||||
}
|
||||
|
||||
def MaxMANP : Pass<"MaxMANP", "::mlir::func::FuncOp"> {
|
||||
let summary = "Extract maximum FHE Minimal Arithmetic Noise Padding and "
|
||||
"maximum encrypted integer width";
|
||||
let description = [{
|
||||
This pass calculates the squared Minimal Arithmetic Noise Padding
|
||||
(MANP) for each operation using the MANP pass and extracts the
|
||||
maximum (non-squared) Minimal Arithmetic Noise Padding and the
|
||||
maximum ecrypted integer width from.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,24 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_UTILS_H
|
||||
#define CONCRETELANG_DIALECT_FHE_ANALYSIS_UTILS_H
|
||||
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace fhe {
|
||||
namespace utils {
|
||||
|
||||
bool isEncryptedValue(mlir::Value value);
|
||||
unsigned int getEintPrecision(mlir::Value value);
|
||||
|
||||
} // namespace utils
|
||||
} // namespace fhe
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,3 @@
|
||||
add_subdirectory(Analysis)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
@@ -0,0 +1,17 @@
|
||||
set(LLVM_TARGET_DEFINITIONS FHEInterfaces.td)
|
||||
mlir_tablegen(FHETypesInterfaces.h.inc -gen-type-interface-decls)
|
||||
mlir_tablegen(FHETypesInterfaces.cpp.inc -gen-type-interface-defs)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS FHEOps.td)
|
||||
mlir_tablegen(FHEOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(FHEOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(FHEOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=FHE)
|
||||
mlir_tablegen(FHEOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=FHE)
|
||||
mlir_tablegen(FHEOpsDialect.h.inc -gen-dialect-decls -dialect=FHE)
|
||||
mlir_tablegen(FHEOpsDialect.cpp.inc -gen-dialect-defs -dialect=FHE)
|
||||
add_public_tablegen_target(MLIRFHEOpsIncGen)
|
||||
add_dependencies(mlir-headers MLIRFHEOpsIncGen)
|
||||
|
||||
add_concretelang_doc(FHEOps FHEDialect concretelang/ -gen-dialect-doc -dialect=FHE)
|
||||
add_concretelang_doc(FHEOps FHEOps concretelang/ -gen-op-doc)
|
||||
add_concretelang_doc(FHETypes FHETypes concretelang/ -gen-typedef-doc)
|
||||
@@ -0,0 +1,18 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHEDIALECT_H
|
||||
#define CONCRETELANG_DIALECT_FHE_IR_FHEDIALECT_H
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,24 @@
|
||||
//===- FHEDialect.td - FHE dialect ----------------*- tablegen -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_DIALECT
|
||||
#define CONCRETELANG_DIALECT_FHE_IR_FHE_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def FHE_Dialect : Dialect {
|
||||
let name = "FHE";
|
||||
let summary = "High Level Fully Homorphic Encryption dialect";
|
||||
let description = [{
|
||||
A dialect for representation of high level operation on fully homomorphic ciphertext.
|
||||
}];
|
||||
let cppNamespace = "::mlir::concretelang::FHE";
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,32 @@
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_INTERFACES
|
||||
#define CONCRETELANG_DIALECT_FHE_IR_FHE_INTERFACES
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def FheIntegerInterface : TypeInterface<"FheIntegerInterface"> {
|
||||
let cppNamespace = "mlir::concretelang::FHE";
|
||||
|
||||
let description = [{
|
||||
Interface for encapsulating the common properties of encrypted integer types.
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
/*description=*/"Get bit-width of the integer.",
|
||||
/*retTy=*/"unsigned",
|
||||
/*methodName=*/"getWidth"
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*description=*/"Get whether the integer is signed.",
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"isSigned"
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*description=*/"Get whether the integer is unsigned.",
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"isUnsigned"
|
||||
>
|
||||
];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,50 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHEOPS_H
|
||||
#define CONCRETELANG_DIALECT_FHE_IR_FHEOPS_H
|
||||
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/Interfaces/ControlFlowInterfaces.h>
|
||||
#include <mlir/Interfaces/SideEffectInterfaces.h>
|
||||
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace FHE {
|
||||
|
||||
bool verifyEncryptedIntegerInputAndResultConsistency(
|
||||
Operation &op, FheIntegerInterface &input, FheIntegerInterface &result);
|
||||
|
||||
bool verifyEncryptedIntegerAndIntegerInputsConsistency(Operation &op,
|
||||
FheIntegerInterface &a,
|
||||
IntegerType &b);
|
||||
|
||||
/// Shared error message for all ApplyLookupTable variant Op (several Dialect)
|
||||
/// E.g. FHE.apply_lookup_table(input, lut)
|
||||
/// Message when the lut tensor has an invalid size,
|
||||
/// i.e. it cannot accomodate the input elements bitwidth
|
||||
template <class Op>
|
||||
void emitErrorBadLutSize(Op &op, std::string lutName, std::string inputName,
|
||||
int expectedSize, int bitWidth) {
|
||||
auto s = op.emitOpError();
|
||||
s << ": `" << lutName << "` (operand #2)"
|
||||
<< " inner dimension should have size " << expectedSize << "(=2^"
|
||||
<< bitWidth << ") to match "
|
||||
<< "`" << inputName << "` (operand #1)"
|
||||
<< " elements bitwidth (" << bitWidth << ")";
|
||||
}
|
||||
|
||||
} // namespace FHE
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,643 @@
|
||||
//===- FHEOps.td - High level FHE dialect ops ----------------*- tablegen -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_OPS
|
||||
#define CONCRETELANG_DIALECT_FHE_IR_FHE_OPS
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
|
||||
include "concretelang/Dialect/FHE/IR/FHEDialect.td"
|
||||
include "concretelang/Dialect/FHE/IR/FHETypes.td"
|
||||
|
||||
class FHE_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<FHE_Dialect, mnemonic, traits>;
|
||||
|
||||
def FHE_ZeroEintOp : FHE_Op<"zero", [NoSideEffect]> {
|
||||
let summary = "Returns a trivial encrypted integer of 0";
|
||||
|
||||
let description = [{
|
||||
Returns a trivial encrypted integer of 0
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
"FHE.zero"() : () -> !FHE.eint<2>
|
||||
"FHE.zero"() : () -> !FHE.esint<2>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins);
|
||||
let results = (outs FHE_AnyEncryptedInteger:$out);
|
||||
}
|
||||
|
||||
def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [NoSideEffect]> {
|
||||
let summary = "Creates a new tensor with all elements initialized to an encrypted zero.";
|
||||
|
||||
let description = [{
|
||||
Creates a new tensor with the shape specified in the result type and initializes its elements with an encrypted zero.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
%tensor = "FHE.zero_tensor"() : () -> tensor<5x!FHE.eint<4>>
|
||||
%tensor = "FHE.zero_tensor"() : () -> tensor<5x!FHE.esint<4>>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins);
|
||||
|
||||
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$tensor);
|
||||
}
|
||||
|
||||
def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [NoSideEffect]> {
|
||||
let summary = "Adds an encrypted integer and a clear integer";
|
||||
|
||||
let description = [{
|
||||
Adds an encrypted integer and a clear integer.
|
||||
The clear integer must have at most one more bit than the encrypted integer
|
||||
and the result must have the same width and the same signedness as the encrypted integer.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.add_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
"FHE.add_eint_int"(%a, %i) : (!FHE.esint<2>, i3) -> !FHE.esint<2>
|
||||
|
||||
// error
|
||||
"FHE.add_eint_int"(%a, %i) : (!FHE.eint<2>, i4) -> !FHE.eint<2>
|
||||
"FHE.add_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<3>
|
||||
"FHE.add_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.esint<2>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$a, AnyInteger:$b);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$a, "Value":$b), [{
|
||||
build($_builder, $_state, a.getType(), a, b);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def FHE_AddEintOp : FHE_Op<"add_eint", [NoSideEffect]> {
|
||||
let summary = "Adds two encrypted integers";
|
||||
|
||||
let description = [{
|
||||
Adds two encrypted integers
|
||||
The encrypted integers and the result must have the same width and the same signedness.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.add_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
|
||||
"FHE.add_eint"(%a, %b): (!FHE.esint<2>, !FHE.esint<2>) -> (!FHE.esint<2>)
|
||||
|
||||
// error
|
||||
"FHE.add_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>)
|
||||
"FHE.add_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>)
|
||||
"FHE.add_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.esint<2>)
|
||||
"FHE.add_eint"(%a, %b): (!FHE.esint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$a, FHE_AnyEncryptedInteger:$b);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$a, "Value":$b), [{
|
||||
build($_builder, $_state, a.getType(), a, b);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [NoSideEffect]> {
|
||||
let summary = "Subtract an encrypted integer from a clear integer";
|
||||
|
||||
let description = [{
|
||||
Subtract an encrypted integer from a clear integer.
|
||||
The clear integer must have one more bit than the encrypted integer
|
||||
and the result must have the same width and the same signedness as the encrypted integer.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.sub_int_eint"(%i, %a) : (i3, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
"FHE.sub_int_eint"(%i, %a) : (i3, !FHE.esint<2>) -> !FHE.esint<2>
|
||||
|
||||
// error
|
||||
"FHE.sub_int_eint"(%i, %a) : (i4, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
"FHE.sub_int_eint"(%i, %a) : (i3, !FHE.eint<2>) -> !FHE.eint<3>
|
||||
"FHE.sub_int_eint"(%i, %a) : (i3, !FHE.eint<2>) -> !FHE.esint<2>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyInteger:$a, FHE_AnyEncryptedInteger:$b);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$a, "Value":$b), [{
|
||||
build($_builder, $_state, b.getType(), a, b);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [NoSideEffect]> {
|
||||
let summary = "Subtract a clear integer from an encrypted integer";
|
||||
|
||||
let description = [{
|
||||
Subtract a clear integer from an encrypted integer.
|
||||
The clear integer must have one more bit than the encrypted integer
|
||||
and the result must have the same width and the same signedness as the encrypted integer.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
"FHE.sub_eint_int"(%i, %a) : (!FHE.esint<2>, i3) -> !FHE.esint<2>
|
||||
|
||||
// error
|
||||
"FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i4) -> !FHE.eint<2>
|
||||
"FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i3) -> !FHE.eint<3>
|
||||
"FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i3) -> !FHE.esint<2>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$a, AnyInteger:$b);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$a, "Value":$b), [{
|
||||
build($_builder, $_state, a.getType(), a, b);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def FHE_SubEintOp : FHE_Op<"sub_eint", [NoSideEffect]> {
|
||||
let summary = "Subtract an encrypted integer from an encrypted integer";
|
||||
|
||||
let description = [{
|
||||
Subtract an encrypted integer from an encrypted integer.
|
||||
The encrypted integers and the result must have the same width and the same signedness.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
|
||||
"FHE.sub_eint"(%a, %b): (!FHE.esint<2>, !FHE.esint<2>) -> (!FHE.esint<2>)
|
||||
|
||||
// error
|
||||
"FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>)
|
||||
"FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>)
|
||||
"FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.esint<2>) -> (!FHE.esint<2>)
|
||||
"FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.esint<2>)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$a, FHE_AnyEncryptedInteger:$b);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$a, "Value":$b), [{
|
||||
build($_builder, $_state, a.getType(), a, b);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_NegEintOp : FHE_Op<"neg_eint", [NoSideEffect]> {
|
||||
|
||||
let summary = "Negates an encrypted integer";
|
||||
|
||||
let description = [{
|
||||
Negates an encrypted integer.
|
||||
The result must have the same width and the same signedness as the encrypted integer.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.neg_eint"(%a): (!FHE.eint<2>) -> (!FHE.eint<2>)
|
||||
"FHE.neg_eint"(%a): (!FHE.esint<2>) -> (!FHE.esint<2>)
|
||||
|
||||
// error
|
||||
"FHE.neg_eint"(%a): (!FHE.eint<2>) -> (!FHE.eint<3>)
|
||||
"FHE.neg_eint"(%a): (!FHE.eint<2>) -> (!FHE.esint<2>)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$a);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$a), [{
|
||||
build($_builder, $_state, a.getType(), a);
|
||||
}]>
|
||||
];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [NoSideEffect]> {
|
||||
let summary = "Multiply an encrypted integer with a clear integer";
|
||||
|
||||
let description = [{
|
||||
Multiply an encrypted integer with a clear integer.
|
||||
The clear integer must have one more bit than the encrypted integer
|
||||
and the result must have the same width and the same signedness as the encrypted integer.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.mul_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
"FHE.mul_eint_int"(%a, %i) : (!FHE.esint<2>, i3) -> !FHE.esint<2>
|
||||
|
||||
// error
|
||||
"FHE.mul_eint_int"(%a, %i) : (!FHE.eint<2>, i4) -> !FHE.eint<2>
|
||||
"FHE.mul_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<3>
|
||||
"FHE.mul_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.esint<2>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$a, AnyInteger:$b);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$a, "Value":$b), [{
|
||||
build($_builder, $_state, a.getType(), a, b);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def FHE_MulEintOp : FHE_Op<"mul_eint", [NoSideEffect]> {
|
||||
let summary = "Multiplies two encrypted integers";
|
||||
|
||||
let description = [{
|
||||
Multiplies two encrypted integers.
|
||||
|
||||
The encrypted integers and the result must have the same width and
|
||||
signedness. Also, due to the current implementation, one supplementary
|
||||
bit of width must be provided, in addition to the number of bits needed
|
||||
to encode the largest output value.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
|
||||
"FHE.mul_eint"(%a, %b): (!FHE.eint<3>, !FHE.eint<3>) -> (!FHE.eint<3>)
|
||||
"FHE.mul_eint"(%a, %b): (!FHE.esint<3>, !FHE.esint<3>) -> (!FHE.esint<3>)
|
||||
|
||||
// error
|
||||
"FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>)
|
||||
"FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>)
|
||||
"FHE.mul_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.esint<2>)
|
||||
"FHE.mul_eint"(%a, %b): (!FHE.esint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$a, FHE_AnyEncryptedInteger:$b);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$a, "Value":$b), [{
|
||||
build($_builder, $_state, a.getType(), a, b);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_MaxEintOp : FHE_Op<"max_eint", [NoSideEffect]> {
|
||||
let summary = "Get maximum of two encrypted integers.";
|
||||
|
||||
let description = [{
|
||||
Get maximum of two encrypted integers using the formula, 'max(x, y) == max(x - y, 0) + y'.
|
||||
Type of inputs and the output should be the same.
|
||||
|
||||
If `x - y`` inside the max overflows or underflows, the behavior is undefined.
|
||||
So to support the full range, you should increase the bit-width by 1 manually.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
"FHE.max_eint"(%x, %y) : (!FHE.esint<3>, !FHE.esint<3>) -> !FHE.esint<3>
|
||||
|
||||
// error
|
||||
"FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<3>) -> !FHE.eint<2>
|
||||
"FHE.max_eint"(%x, %y) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.esint<2>
|
||||
"FHE.max_eint"(%x, %y) : (!FHE.esint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$x, FHE_AnyEncryptedInteger:$y);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$x, "Value":$y), [{
|
||||
build($_builder, $_state, x.getType(), x, y);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_ToSignedOp : FHE_Op<"to_signed", [NoSideEffect]> {
|
||||
let summary = "Cast an unsigned integer to a signed one";
|
||||
|
||||
let description = [{
|
||||
Cast an unsigned integer to a signed one.
|
||||
The result must have the same width as the input.
|
||||
|
||||
The behavior is undefined on overflow/underflow.
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.to_signed"(%x) : (!FHE.eint<2>) -> !FHE.esint<2>
|
||||
|
||||
// error
|
||||
"FHE.to_signed"(%x) : (!FHE.eint<2>) -> !FHE.esint<3>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_EncryptedIntegerType:$input);
|
||||
let results = (outs FHE_EncryptedSignedIntegerType);
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [NoSideEffect]> {
|
||||
let summary = "Cast a signed integer to an unsigned one";
|
||||
|
||||
let description = [{
|
||||
Cast a signed integer to an unsigned one.
|
||||
The result must have the same width as the input.
|
||||
|
||||
The behavior is undefined on overflow/underflow.
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.to_unsigned"(%x) : (!FHE.esint<2>) -> !FHE.eint<2>
|
||||
|
||||
// error
|
||||
"FHE.to_unsigned"(%x) : (!FHE.esint<2>) -> !FHE.eint<3>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_EncryptedSignedIntegerType:$input);
|
||||
let results = (outs FHE_EncryptedIntegerType);
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [NoSideEffect]> {
|
||||
|
||||
let summary = "Applies a clear lookup table to an encrypted integer";
|
||||
|
||||
let description = [{
|
||||
Applies a clear lookup table to an encrypted integer, the width of the result can be different than the width of the operand.
|
||||
The lookup table must be a tensor of size equals to `2^p` where `p` is the width of the encrypted integer.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<2>)
|
||||
"FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
|
||||
"FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<3>, tensor<4xi64>) -> (!FHE.eint<2>)
|
||||
|
||||
// error
|
||||
"FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<2>, tensor<8xi64>) -> (!FHE.eint<2>)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$a,
|
||||
TensorOf<[AnyInteger]>:$lut);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_RoundEintOp: FHE_Op<"round", [NoSideEffect]> {
|
||||
|
||||
let summary = "Rounds a ciphertext to a smaller precision.";
|
||||
|
||||
let description = [{
|
||||
Assuming a ciphertext whose message is implemented over `p` bits, this
|
||||
operation rounds it to fit to `q` bits with `p>q`.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.round"(%a): (!FHE.eint<6>) -> (!FHE.eint<5>)
|
||||
"FHE.round"(%a): (!FHE.eint<5>) -> (!FHE.eint<3>)
|
||||
"FHE.round"(%a): (!FHE.eint<3>) -> (!FHE.eint<2>)
|
||||
"FHE.round"(%a): (!FHE.esint<3>) -> (!FHE.esint<2>)
|
||||
|
||||
// error
|
||||
"FHE.round"(%a): (!FHE.eint<6>) -> (!FHE.eint<6>)
|
||||
"FHE.round"(%a): (!FHE.eint<4>) -> (!FHE.eint<5>)
|
||||
"FHE.round"(%a): (!FHE.eint<4>) -> (!FHE.esint<5>)
|
||||
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$input);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
// FHE Boolean Operations
|
||||
|
||||
def FHE_GenGateOp : FHE_Op<"gen_gate", [NoSideEffect]> {
|
||||
|
||||
let summary = "Applies a truth table based on two boolean inputs";
|
||||
|
||||
let description = [{
|
||||
Applies a truth table based on two boolean inputs.
|
||||
|
||||
truth table must be a tensor of 4 boolean values.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> (!FHE.ebool)
|
||||
|
||||
// error
|
||||
"FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<7xi64>) -> (!FHE.ebool)
|
||||
```
|
||||
}];
|
||||
|
||||
// The reason the truth table is of AnyInteger and not I1 is that in lowering passes, the truth_table is meant to be passed
|
||||
// to an LUT operation which requires the table to be of type I64. Whenever lowering passes are no more restrictive, this
|
||||
// can be set to I1 to reflect the boolean logic.
|
||||
let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right, TensorOf<[AnyInteger]>:$truth_table);
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_MuxOp : FHE_Op<"mux", [NoSideEffect]> {
|
||||
|
||||
let summary = "Multiplexer for two encrypted boolean inputs, based on an encrypted condition";
|
||||
|
||||
let description = [{
|
||||
Mutex between two encrypted boolean inputs, based on an encrypted condition.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
"FHE.mux"(%cond, %c1, %c2): (!FHE.ebool, !FHE.ebool, !FHE.ebool) -> (!FHE.ebool)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_EncryptedBooleanType:$cond, FHE_EncryptedBooleanType:$c1, FHE_EncryptedBooleanType:$c2);
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
}
|
||||
|
||||
def FHE_BoolAndOp : FHE_Op<"and", [NoSideEffect]> {
|
||||
|
||||
let summary = "Applies an AND gate to two encrypted boolean values";
|
||||
|
||||
let description = [{
|
||||
Applies an AND gate to two encrypted boolean values.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
"FHE.and"(%a, %b): (!FHE.ebool, !FHE.ebool) -> (!FHE.ebool)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right);
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
}
|
||||
|
||||
def FHE_BoolOrOp : FHE_Op<"or", [NoSideEffect]> {
|
||||
|
||||
let summary = "Applies an OR gate to two encrypted boolean values";
|
||||
|
||||
let description = [{
|
||||
Applies an OR gate to two encrypted boolean values.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
"FHE.or"(%a, %b): (!FHE.ebool, !FHE.ebool) -> (!FHE.ebool)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right);
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
}
|
||||
|
||||
def FHE_BoolNandOp : FHE_Op<"nand", [NoSideEffect]> {
|
||||
|
||||
let summary = "Applies a NAND gate to two encrypted boolean values";
|
||||
|
||||
let description = [{
|
||||
Applies a NAND gate to two encrypted boolean values.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
"FHE.nand"(%a, %b): (!FHE.ebool, !FHE.ebool) -> (!FHE.ebool)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right);
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
}
|
||||
|
||||
def FHE_BoolXorOp : FHE_Op<"xor", [NoSideEffect]> {
|
||||
|
||||
let summary = "Applies a XOR gate to two encrypted boolean values";
|
||||
|
||||
let description = [{
|
||||
Applies a XOR gate to two encrypted boolean values.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
"FHE.xor"(%a, %b): (!FHE.ebool, !FHE.ebool) -> (!FHE.ebool)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right);
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
}
|
||||
|
||||
def FHE_BoolNotOp : FHE_Op<"not", [NoSideEffect]> {
|
||||
|
||||
let summary = "Applies a NOT gate to an encrypted boolean value";
|
||||
|
||||
let description = [{
|
||||
Applies a NOT gate to an encrypted boolean value.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
"FHE.not"(%a): (!FHE.ebool) -> (!FHE.ebool)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_EncryptedBooleanType:$value);
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
}
|
||||
|
||||
def FHE_ToBoolOp : FHE_Op<"to_bool", [NoSideEffect]> {
|
||||
let summary = "Cast an unsigned integer to a boolean";
|
||||
|
||||
let description = [{
|
||||
Cast an unsigned integer to a boolean.
|
||||
|
||||
The input must necessarily be of width 1 or 2. 2 being the current representation
|
||||
of an encrypted boolean, leaving one bit for the carry.
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.to_bool"(%x) : (!FHE.eint<1>) -> !FHE.ebool
|
||||
"FHE.to_bool"(%x) : (!FHE.eint<2>) -> !FHE.ebool
|
||||
|
||||
// error
|
||||
"FHE.to_bool"(%x) : (!FHE.eint<3>) -> !FHE.ebool
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_EncryptedIntegerType:$input);
|
||||
let results = (outs FHE_EncryptedBooleanType);
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def FHE_FromBoolOp : FHE_Op<"from_bool", [NoSideEffect]> {
|
||||
let summary = "Cast a boolean to an unsigned integer";
|
||||
|
||||
let description = [{
|
||||
Cast a boolean to an unsigned integer.
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
"FHE.from_bool"(%x) : (!FHE.ebool) -> !FHE.eint<1>
|
||||
"FHE.from_bool"(%x) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
"FHE.from_bool"(%x) : (!FHE.ebool) -> !FHE.eint<4>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_EncryptedBooleanType:$input);
|
||||
let results = (outs FHE_EncryptedIntegerType);
|
||||
}
|
||||
|
||||
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,19 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHETYPES_H
|
||||
#define CONCRETELANG_DIALECT_FHE_IR_FHETYPES_H
|
||||
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/IR/DialectImplementation.h>
|
||||
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypesInterfaces.h.inc"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOpsTypes.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,88 @@
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_TYPES
|
||||
#define CONCRETELANG_DIALECT_FHE_IR_FHE_TYPES
|
||||
|
||||
include "concretelang/Dialect/FHE/IR/FHEDialect.td"
|
||||
include "concretelang/Dialect/FHE/IR/FHEInterfaces.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
|
||||
class FHE_Type<string name, list<Trait> traits = []> :
|
||||
TypeDef<FHE_Dialect, name, traits> { }
|
||||
|
||||
def FHE_EncryptedIntegerType : FHE_Type<"EncryptedInteger",
|
||||
[MemRefElementTypeInterface, FheIntegerInterface]> {
|
||||
let mnemonic = "eint";
|
||||
|
||||
let summary = "An encrypted integer";
|
||||
|
||||
let description = [{
|
||||
An encrypted integer with `width` bits to performs FHE Operations.
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
!FHE.eint<7>
|
||||
!FHE.eint<6>
|
||||
```
|
||||
}];
|
||||
|
||||
let parameters = (ins "unsigned":$width);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let genVerifyDecl = true;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
bool isSigned() const { return false; }
|
||||
bool isUnsigned() const { return true; }
|
||||
}];
|
||||
}
|
||||
|
||||
def FHE_EncryptedSignedIntegerType : FHE_Type<"EncryptedSignedInteger",
|
||||
[MemRefElementTypeInterface, FheIntegerInterface]> {
|
||||
let mnemonic = "esint";
|
||||
|
||||
let summary = "An encrypted signed integer";
|
||||
|
||||
let description = [{
|
||||
An encrypted signed integer with `width` bits to performs FHE Operations.
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
!FHE.esint<7>
|
||||
!FHE.esint<6>
|
||||
```
|
||||
}];
|
||||
|
||||
let parameters = (ins "unsigned":$width);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let genVerifyDecl = true;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
bool isSigned() const { return true; }
|
||||
bool isUnsigned() const { return false; }
|
||||
}];
|
||||
}
|
||||
|
||||
def FHE_AnyEncryptedInteger : Type<Or<[
|
||||
FHE_EncryptedIntegerType.predicate,
|
||||
FHE_EncryptedSignedIntegerType.predicate
|
||||
]>>;
|
||||
|
||||
def FHE_EncryptedBooleanType : FHE_Type<"EncryptedBoolean",
|
||||
[MemRefElementTypeInterface]> {
|
||||
let mnemonic = "ebool";
|
||||
|
||||
let summary = "An encrypted boolean";
|
||||
|
||||
let description = [{
|
||||
An encrypted boolean.
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Returns the required number of bits to represent an encrypted boolean
|
||||
static size_t getWidth() { return 2; }
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,24 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_FHE_BIGINT_PASS_H
|
||||
#define CONCRETELANG_FHE_BIGINT_PASS_H
|
||||
|
||||
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <concretelang/Dialect/FHE/Transforms/BigInt/BigInt.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<>>
|
||||
createFHEBigIntTransformPass(unsigned int chunkSize, unsigned int chunkWidth);
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,13 @@
|
||||
#ifndef CONCRETELANG_FHE_BIGINT_PASS
|
||||
#define CONCRETELANG_FHE_BIGINT_PASS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def FHEBigIntTransform : Pass<"fhe-big-int-transform"> {
|
||||
let summary = "Transform FHE operations on big integer into operations on chunks of small integer";
|
||||
let constructor = "mlir::concretelang::createFHEBigIntTransformPass()";
|
||||
let options = [];
|
||||
let dependentDialects = [ "mlir::concretelang::FHE::FHEDialect" ];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,4 @@
|
||||
set(LLVM_TARGET_DEFINITIONS BigInt.td)
|
||||
mlir_tablegen(BigInt.h.inc -gen-pass-decls -name Transforms)
|
||||
add_public_tablegen_target(ConcretelangFHEBigIntPassIncGen)
|
||||
add_dependencies(mlir-headers ConcretelangFHEBigIntPassIncGen)
|
||||
@@ -0,0 +1,23 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_FHE_BOOLEAN_PASS_H
|
||||
#define CONCRETELANG_FHE_BOOLEAN_PASS_H
|
||||
|
||||
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <concretelang/Dialect/FHE/Transforms/Boolean/Boolean.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<>> createFHEBooleanTransformPass();
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user