Mirror of https://github.com/zama-ai/concrete.git (synced 2026-01-10 13:27:57 -05:00)
feat(compiler): add matmul eint eint op
Committed by Andrei Stoian · parent a5c679f0dc · commit 817ee6b637
.gitignore (vendored): 3
@@ -5,3 +5,6 @@

# Jetbrains tools
.idea/

# HPX library
compilers/concrete-compiler/compiler/hpx*
@@ -50,6 +50,12 @@ include_directories(${PROJECT_BINARY_DIR}/include)
link_directories(${LLVM_BUILD_LIBRARY_DIR})
add_definitions(${LLVM_DEFINITIONS})

if(DEFINED LLVM_USE_LINKER AND (NOT ${LLVM_USE_LINKER} STREQUAL ""))
  message(INFO " Using custom Linker: ${CMAKE_LINKER}")
else()
  message(INFO " Using standard linker")
endif()

# Custom doc generation function
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
include(AddConcretelangDoc)
@@ -1,7 +1,9 @@
LLVM_PROJECT_DIR=../llvm-project

BUILD_TYPE?=Release
BUILD_DIR?=./build
# Build dir is `build` for Release builds and `build_BUILD_TYPE` for
# other build types (e.g. Debug, ReleaseWithDebugInfo, etc.)
BUILD_DIR?=./$(shell ./get_build_dir.sh "$(BUILD_TYPE)")
Python3_EXECUTABLE?=$(shell which python3)
BINDINGS_PYTHON_ENABLED=ON
DATAFLOW_EXECUTION_ENABLED=OFF
@@ -86,6 +88,15 @@ else
CXX_COMPILER_OPTION=
endif

# If the build type is Debug, and the lld linker is installed,
# then use it
CUSTOM_LINKER_OPTS=
ifneq ($(shell which lld),)
ifeq ($(BUILD_TYPE),Debug)
CUSTOM_LINKER_OPTS=-DLLVM_USE_LINKER=lld
endif
endif

# don't run parallel python tests if compiler doesn't support it
ifeq ($(DATAFLOW_EXECUTION_ENABLED),ON)
PYTHON_TESTS_MARKER=""
@@ -119,6 +130,7 @@ $(BUILD_DIR)/configured.stamp:
		$(CMAKE_CCACHE_OPTIONS) \
		$(CC_COMPILER_OPTION) \
		$(CXX_COMPILER_OPTION) \
		$(CUSTOM_LINKER_OPTS) \
		-DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" \
		-DLLVM_BUILD_EXAMPLES=OFF \
		-DLLVM_TARGETS_TO_BUILD="host" \
@@ -138,6 +150,10 @@ $(BUILD_DIR)/configured.stamp:

build-initialized: $(BUILD_DIR)/configured.stamp

reconfigure:
	rm $(BUILD_DIR)/configured.stamp
	rm $(BUILD_DIR)/CMakeCache.txt

doc: build-initialized
	cmake --build $(BUILD_DIR) --target mlir-doc
@@ -263,7 +279,14 @@ $(FIXTURE_CPU_DIR)/%.yaml: tests/end_to_end_fixture/%_gen.py FORCE
$(FIXTURE_CPU_DIR)/bug_report.yaml:
	unzip -o $(FIXTURE_CPU_DIR)/bug_report.zip -d $(FIXTURE_CPU_DIR)

generate-cpu-tests: $(FIXTURE_CPU_DIR)/end_to_end_leveled.yaml $(FIXTURE_CPU_DIR)/end_to_end_apply_lookup_table.yaml $(FIXTURE_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml $(FIXTURE_CPU_DIR)/bug_report.yaml $(FIXTURE_CPU_DIR)/end_to_end_round.yaml $(FIXTURE_CPU_DIR)/end_to_end_multi_precision.yaml
generate-cpu-tests: \
	$(FIXTURE_CPU_DIR)/end_to_end_leveled.yaml \
	$(FIXTURE_CPU_DIR)/end_to_end_apply_lookup_table.yaml \
	$(FIXTURE_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml \
	$(FIXTURE_CPU_DIR)/bug_report.yaml \
	$(FIXTURE_CPU_DIR)/end_to_end_round.yaml \
	$(FIXTURE_CPU_DIR)/end_to_end_multi_precision.yaml \
	$(FIXTURE_CPU_DIR)/end_to_end_linalg_enc_enc_matmul_dot.yaml

SECURITY_TO_TEST=128
OPTIMIZATION_STRATEGY_TO_TEST=dag-mono dag-multi
@@ -275,7 +298,8 @@ run-end-to-end-tests: $(GTEST_PARALLEL_PY) build-end-to-end-tests generate-cpu-t
	$(GTEST_PARALLEL_CMD) $(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/$(TEST);)
	$(foreach optimizer_strategy,$(OPTIMIZATION_STRATEGY_TO_TEST), $(foreach security,$(SECURITY_TO_TEST), \
	$(GTEST_PARALLEL_CMD) $(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/end_to_end_test \
	$(GTEST_PARALLEL_SEPARATOR) --backend=cpu --security-level=$(security) --optimizer-strategy=$(optimizer_strategy) --jit $(FIXTURE_CPU_DIR)/*.yaml;))
	$(GTEST_PARALLEL_SEPARATOR) --backend=cpu --security-level=$(security) \
	--optimizer-strategy=$(optimizer_strategy) --jit $(FIXTURE_CPU_DIR)/*.yaml;))

### end-to-end-tests GPU
@@ -542,4 +566,5 @@ FORCE:
	opt \
	mlir-opt \
	mlir-cpu-runner \
	mlir-translate
	mlir-translate \
	reconfigure
@@ -70,6 +70,12 @@ Run the compiler
./build-Release/bin/concretecompiler
```

#### Debug build and custom linker

To build a debug version of the project, set `BUILD_TYPE=Debug` in the `Makefile`. In `Debug` mode,
the build system detects whether the `lld` linker is installed on the system and, if so, uses it. `lld` is much faster
than the default `ld` linker. Release builds with `lld` can also be enabled by modifying the `Makefile`.
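For example, assuming the `Makefile` defaults shown in this diff, `make BUILD_TYPE=Debug` should configure the project into `build_Debug` and pass `-DLLVM_USE_LINKER=lld` to CMake whenever `lld` is found on the `PATH`, while a plain `make` keeps the `Release` configuration in `build`.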
### Installation from source

You can install libs, bins, and include files into a specific directory by running:
compilers/concrete-compiler/compiler/get_build_dir.sh (new executable file): 10
@@ -0,0 +1,10 @@
#!/bin/env bash

BUILD_TYPE=$1

if [[ ${BUILD_TYPE,,} = "release" ]]; then
  echo "build"
else
  echo "build_${BUILD_TYPE}"
fi
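For reference, `./get_build_dir.sh Release` (or `release`; the `${BUILD_TYPE,,}` expansion lowercases the value before comparing) prints `build`, while `./get_build_dir.sh Debug` prints `build_Debug`. This is the value the `Makefile` above assigns to `BUILD_DIR`.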
@@ -25,6 +25,11 @@ bool verifyEncryptedIntegerAndIntegerInputsConsistency(Operation &op,
                                                       FheIntegerInterface &a,
                                                       IntegerType &b);

// Checks the consistency between two encrypted integer inputs of an operation
bool verifyEncryptedIntegerInputsConsistency(mlir::Operation &op,
                                             FheIntegerInterface &a,
                                             FheIntegerInterface &b);

/// Shared error message for all ApplyLookupTable variant Op (several Dialect)
/// E.g. FHE.apply_lookup_table(input, lut)
/// Message when the lut tensor has an invalid size,
@@ -587,12 +587,37 @@ def FHELinalg_Dot : FHELinalg_Op<"dot_eint_int"> {
  let hasVerifier = 1;
}

def FHELinalg_DotEint : FHELinalg_Op<"dot_eint_eint"> {
  let summary = "Returns the encrypted dot product between two vectors of encrypted integers.";

  let description = [{
    Performs a dot product between two vectors of encrypted integers.

    Examples:
    ```mlir
    // Returns the dot product of `%a0` with `%a1`
    "FHELinalg.dot_eint_eint"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4x!FHE.eint<4>>) -> !FHE.eint<4>
    ```
  }];

  let arguments = (ins
    Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$lhs,
    Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$rhs);

  let results = (outs FHE_AnyEncryptedInteger:$out);

  let hasVerifier = 1;
}
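As with `dot_eint_int`, this op lowers to a `linalg.generic` reduction; per the pattern changes further down, the only difference is that the element-wise product is built with `FHE.mul_eint` instead of `FHE.mul_eint_int`.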

def FHELinalg_MatMulEintIntOp : FHELinalg_Op<"matmul_eint_int", [TensorBinaryEintInt]> {
  let summary = "Returns a tensor that contains the result of the matrix multiplication of a matrix of encrypted integers and a matrix of clear integers.";

  let description = [{
    Performs a matrix multiplication of a matrix of encrypted integers and a matrix of clear integers.
    The width of the clear integers must be less than or equals to the witdh of encrypted integers.
    The width of the clear integers must be less than or equal to the width of the encrypted integers.

    The behavior depends on the arguments in the following way:
@@ -730,7 +755,7 @@ def FHELinalg_MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [TensorBinaryInt
  let description = [{
    Performs a matrix multiplication of a matrix of clear integers and a matrix of encrypted integers.
    The width of the clear integers must be less than or equals to the witdh of encrypted integers.
    The width of the clear integers must be less than or equal to the width of the encrypted integers.

    The behavior depends on the arguments in the following way:
@@ -863,6 +888,144 @@ def FHELinalg_MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [TensorBinaryInt
  let hasVerifier = 1;
}

def FHELinalg_MatMulEintEintOp : FHELinalg_Op<"matmul_eint_eint", [TensorBinaryEint]> {
  let summary = "Returns a tensor that contains the result of the matrix multiplication of a matrix of encrypted integers and a second matrix of encrypted integers.";

  let description = [{
    Performs a matrix multiplication of a matrix of encrypted integers and a second matrix of encrypted integers.

    The behavior depends on the arguments in the following way:

    - If both arguments are 2-D,
      they are multiplied like conventional matrices.

      e.g.,

      arg0: tensor<MxN> = [...]
      arg1: tensor<NxP> = [...]

      result: tensor<MxP> = [...]

    - If the first argument is a vector (1-D),
      it is treated as a matrix with a single row and standard matrix multiplication is performed.

      After standard matrix multiplication,
      the first dimension is removed from the result.

      e.g.,

      arg0: tensor<3> = [x, y, z]
      arg1: tensor<3xM> = [
        [_, _, ..., _, _],
        [_, _, ..., _, _],
        [_, _, ..., _, _],
      ]

      is treated as

      arg0: tensor<1x3> = [
        [x, y, z]
      ]
      arg1: tensor<3xM> = [
        [_, _, ..., _, _],
        [_, _, ..., _, _],
        [_, _, ..., _, _],
      ]

      and matrix multiplication is performed with the following form (1x3 @ 3xM -> 1xM)

      result: tensor<1xM> = [[_, _, ..., _, _]]

      finally, the first dimension is removed by definition so the result has the following form

      result: tensor<M> = [_, _, ..., _, _]

    - If the second argument is 1-D,
      it is treated as a matrix with a single column and standard matrix multiplication is performed.

      After standard matrix multiplication,
      the last dimension is removed from the result.

      e.g.,

      arg0: tensor<Mx3> = [
        [_, _, _],
        [_, _, _],
        ...,
        [_, _, _],
        [_, _, _],
      ]
      arg1: tensor<3> = [x, y, z]

      is treated as

      arg0: tensor<Mx3> = [
        [_, _, _],
        [_, _, _],
        ...,
        [_, _, _],
        [_, _, _],
      ]
      arg1: tensor<3x1> = [
        [x],
        [y],
        [z],
      ]

      and matrix multiplication is performed with the following form (Mx3 @ 3x1 -> Mx1)

      result: tensor<Mx1> = [
        [_],
        [_],
        ...,
        [_],
        [_],
      ]

      finally, the last dimension is removed by definition so the result has the following form

      result: tensor<M> = [_, _, _]

    - If either argument is N-D where N > 2,
      the operation is treated as a collection of matrices residing in the last two indices and broadcasted accordingly.
      arg0: tensor<Kx1xMxN> = [...]
      arg1: tensor<LxNxP> = [...]

      result: tensor<KxLxMxP> = [...]
    ```mlir
    "FHELinalg.matmul_eint_eint"(%a, %b) : (tensor<MxNx!FHE.eint<p>>, tensor<NxPx!FHE.eint<p'>>) -> tensor<MxPx!FHE.eint<p>>
    "FHELinalg.matmul_eint_eint"(%a, %b) : (tensor<KxLxMxNx!FHE.eint<p>>, tensor<KxLxNxPx!FHE.eint<p'>>) -> tensor<KxLxMxPx!FHE.eint<p>>
    "FHELinalg.matmul_eint_eint"(%a, %b) : (tensor<MxNx!FHE.eint<p>>, tensor<Nx!FHE.eint<p'>>) -> tensor<Mx!FHE.eint<p>>
    "FHELinalg.matmul_eint_eint"(%a, %b) : (tensor<Nx!FHE.eint<p>>, tensor<NxPx!FHE.eint<p'>>) -> tensor<Px!FHE.eint<p>>
    ```
    Examples:
    ```mlir
    // Returns the matrix multiplication of a 3x2 matrix of encrypted integers and a 2x3 matrix of encrypted integers.
    //         [ 1, 2, 3]
    //         [ 2, 3, 4]
    //       *
    // [1,2]   [ 5, 8,11]
    // [3,4] = [11,18,25]
    // [5,6]   [17,28,39]
    //
    "FHELinalg.matmul_eint_eint"(%a, %b) : (tensor<3x2x!FHE.eint<6>>, tensor<2x3x!FHE.eint<6>>) -> tensor<3x3x!FHE.eint<12>>
    ```
  }];

  let arguments = (ins
    Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$lhs,
    Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$rhs
  );

  let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);

  let hasVerifier = 1;
}
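The shape rules above match NumPy's `matmul` conventions, so a quick way to sanity-check them is with NumPy shapes (an illustrative aside, not part of the op definition; the op itself is of course evaluated homomorphically):

```python
import numpy as np

# Shape semantics only; ones() stands in for encrypted values.
a, b = np.ones((3, 2)), np.ones((2, 3))
assert np.matmul(a, b).shape == (3, 3)        # MxN @ NxP -> MxP

v, m = np.ones(3), np.ones((3, 4))
assert np.matmul(v, m).shape == (4,)          # 1x3 @ 3x4 -> 1x4, first dim dropped
assert np.matmul(m.T, v).shape == (4,)        # 4x3 @ 3x1 -> 4x1, last dim dropped

a = np.ones((7, 1, 3, 2))                     # Kx1xMxN
b = np.ones((5, 2, 3))                        # LxNxP
assert np.matmul(a, b).shape == (7, 5, 3, 3)  # batch dims broadcast: KxLxMxP
```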
def FHELinalg_SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> {
  let summary = "Returns the sum of elements of a tensor of encrypted integers along specified axes.";
@@ -45,11 +45,16 @@ inline void forwardOptimizerID(mlir::Operation *source,
  destination->setAttr("TFHE.OId", optimizerIdAttr);
}

struct DotToLinalgGeneric
    : public ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::Dot> {
  DotToLinalgGeneric(::mlir::MLIRContext *context)
      : ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::Dot>(
            context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
template <typename DotOp, typename FHEMulOp>
struct DotToLinalgGeneric : public ::mlir::OpRewritePattern<DotOp> {
  DotToLinalgGeneric(
      ::mlir::MLIRContext *context,
      std::function<FHEMulOp(mlir::OpBuilder &, mlir::Location, mlir::Type,
                             mlir::Value, mlir::Value)>
          createMulOp)
      : ::mlir::OpRewritePattern<DotOp>(
            context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT),
        createMulOp(createMulOp) {}

  /// This rewrite pattern transforms any instance of
  /// `FHELinalg.dot_eint_int` to an instance of `linalg.generic` with an
@@ -87,7 +92,7 @@ struct DotToLinalgGeneric
  /// %o = tensor.extract %1[%c0] : tensor<1x!FHE.eint<0>>
  ///
  ::mlir::LogicalResult
  matchAndRewrite(::mlir::concretelang::FHELinalg::Dot dotOp,
  matchAndRewrite(DotOp dotOp,
                  ::mlir::PatternRewriter &rewriter) const override {

    auto zeroTensorOp = rewriter.create<mlir::concretelang::FHE::ZeroTensorOp>(
@@ -111,9 +116,9 @@ struct DotToLinalgGeneric
    auto regBuilder = [&](mlir::OpBuilder &nestedBuilder,
                          mlir::Location nestedLoc,
                          mlir::ValueRange blockArgs) {
      mlir::concretelang::FHE::MulEintIntOp mul =
          nestedBuilder.create<mlir::concretelang::FHE::MulEintIntOp>(
              dotOp.getLoc(), blockArgs[0], blockArgs[1]);
      auto mul = this->createMulOp(nestedBuilder, dotOp.getLoc(),
                                   dotOp.getResult().getType(), blockArgs[0],
                                   blockArgs[1]);
      forwardOptimizerID(dotOp, mul);
      mlir::concretelang::FHE::AddEintOp add =
          nestedBuilder.create<mlir::concretelang::FHE::AddEintOp>(
@@ -140,6 +145,11 @@ struct DotToLinalgGeneric

    return ::mlir::success();
  };

private:
  std::function<FHEMulOp(mlir::OpBuilder &, mlir::Location, mlir::Type,
                         mlir::Value, mlir::Value)>
      createMulOp;
};

mlir::AffineMap
@@ -826,13 +836,13 @@ struct FHELinalgNegEintToLinalgGeneric
///        linalg.yield %e : !FHE.eint<p>
///   }
///
template <typename FHELinalgMatmulOp>
template <typename FHELinalgMatmulOp, typename FHEMulOp>
struct FHELinalgMatmulToLinalgGeneric
    : public mlir::OpRewritePattern<FHELinalgMatmulOp> {
  FHELinalgMatmulToLinalgGeneric(
      mlir::MLIRContext *context,
      std::function<FHE::MulEintIntOp(mlir::OpBuilder &, mlir::Location,
                                      mlir::Type, mlir::Value, mlir::Value)>
      std::function<FHEMulOp(mlir::OpBuilder &, mlir::Location, mlir::Type,
                             mlir::Value, mlir::Value)>
          createMulOp,
      mlir::PatternBenefit benefit =
          mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
@@ -1091,8 +1101,8 @@ struct FHELinalgMatmulToLinalgGeneric
  };

private:
  std::function<FHE::MulEintIntOp(mlir::OpBuilder &, mlir::Location, mlir::Type,
                                  mlir::Value, mlir::Value)>
  std::function<FHEMulOp(mlir::OpBuilder &, mlir::Location, mlir::Type,
                         mlir::Value, mlir::Value)>
      createMulOp;
};
@@ -2195,7 +2205,21 @@ void FHETensorOpsToLinalg::runOnOperation() {
  target.addLegalOp<bufferization::AllocTensorOp>();

  mlir::RewritePatternSet patterns(&getContext());
  patterns.insert<DotToLinalgGeneric>(&getContext());

  patterns.insert<DotToLinalgGeneric<mlir::concretelang::FHELinalg::Dot,
                                     mlir::concretelang::FHE::MulEintIntOp>>(
      &getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
                        mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
        return builder.create<mlir::concretelang::FHE::MulEintIntOp>(
            loc, type, arg0, arg1);
      });
  patterns.insert<DotToLinalgGeneric<mlir::concretelang::FHELinalg::DotEint,
                                     mlir::concretelang::FHE::MulEintOp>>(
      &getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
                        mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
        return builder.create<mlir::concretelang::FHE::MulEintOp>(loc, type,
                                                                  arg0, arg1);
      });
  patterns.insert<
      FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::AddEintOp,
                                 mlir::concretelang::FHE::AddEintOp>>(
@@ -2227,19 +2251,29 @@ void FHETensorOpsToLinalg::runOnOperation() {
  patterns.insert<FHELinalgApplyLookupTableToLinalgGeneric>(&getContext());
  patterns.insert<FHELinalgNegEintToLinalgGeneric>(&getContext());
  patterns.insert<FHELinalgMatmulToLinalgGeneric<
      mlir::concretelang::FHELinalg::MatMulEintIntOp>>(
      mlir::concretelang::FHELinalg::MatMulEintIntOp,
      mlir::concretelang::FHE::MulEintIntOp>>(
      &getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
                        mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
        return builder.create<mlir::concretelang::FHE::MulEintIntOp>(
            loc, type, arg0, arg1);
      });
  patterns.insert<FHELinalgMatmulToLinalgGeneric<
      mlir::concretelang::FHELinalg::MatMulIntEintOp>>(
      mlir::concretelang::FHELinalg::MatMulIntEintOp,
      mlir::concretelang::FHE::MulEintIntOp>>(
      &getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
                        mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
        return builder.create<mlir::concretelang::FHE::MulEintIntOp>(
            loc, type, arg1, arg0);
      });
  patterns.insert<FHELinalgMatmulToLinalgGeneric<
      mlir::concretelang::FHELinalg::MatMulEintEintOp,
      mlir::concretelang::FHE::MulEintOp>>(
      &getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
                        mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
        return builder.create<mlir::concretelang::FHE::MulEintOp>(loc, type,
                                                                  arg1, arg0);
      });
  patterns.insert<FHELinalgApplyMultiLookupTableToLinalgGeneric>(&getContext());
  patterns.insert<FHELinalgApplyMappedLookupTableToLinalgGeneric>(
      &getContext());
@@ -159,6 +159,10 @@ struct FunctionToDag {
      DEBUG("Replace Dot by LevelledOp on " << op);
      index = addLevelledOp(dag, op, encrypted_inputs);
    }
  } else if (auto dot = asDotEint(op)) {
    index = addDotEint(dag, dot, encrypted_inputs, precision);
    // The above function call sets the OIds, so we can return right away
    return;
  } else if (auto mul = asMul(op)) {
    // special case as mul are rewritten in several optimizer nodes
    addMul(dag, mul, encrypted_inputs, precision);
@@ -175,6 +179,10 @@ struct FunctionToDag {
    // special case as max are rewritten in several optimizer nodes
    addMaxpool2d(dag, maxpool2d, encrypted_inputs, precision);
    return;
  } else if (auto matmulEintEint = asMatmulEintEint(op)) {
    index =
        addEncMatMulTensor(dag, matmulEintEint, encrypted_inputs, precision);
    return;
  } else {
    index = addLevelledOp(dag, op, encrypted_inputs);
  }
@@ -352,7 +360,7 @@ struct FunctionToDag {
    dag->add_levelled_op(slice(subInputs), lweDimCostFactor, fixedCost,
                         tluSubManp, slice(resultShape), comment);
    index[result] = resultNode;
    // Set attribute on the MLIR node

    mlir::Builder builder(mulOp.getContext());
    mlir::SmallVector<int32_t, 5> operatorIndexes = {
        (int32_t)addNode.index, (int32_t)lhsTluNode.index,
@@ -362,10 +370,193 @@ struct FunctionToDag {
      // We push that at the end by convention
      operatorIndexes.push_back(lhsCorrectionNode.value().index);
    }

    if (setOptimizerID)
      mulOp->setAttr("TFHE.OId", builder.getDenseI32ArrayAttr(operatorIndexes));
  }

  template <typename InnerProductOp>
  concrete_optimizer::dag::OperatorIndex
  addTensorInnerProductEncEnc(optimizer::Dag &dag,
                              InnerProductOp &innerProductOp, Inputs &inputs,
                              int precision) {
    mlir::Value result = innerProductOp.getResult();
    const std::vector<uint64_t> resultShape = getShape(result);

    // We assume a first tensorized matmul step
    // is the construction of matrices of:
    // - sums of all pairs
    // - differences of all pairs
    // where pairs are pairs of values that are to be multiplied together

    // Compute the number of elements in each of the
    // matrices of pairs

    auto lhsType = ((mlir::Type)innerProductOp.getLhs().getType())
                       .cast<mlir::RankedTensorType>();

    auto rhsType = ((mlir::Type)innerProductOp.getRhs().getType())
                       .cast<mlir::RankedTensorType>();

    std::vector<int64_t> lhsShape = lhsType.getShape();
    std::vector<int64_t> rhsShape = rhsType.getShape();

    if (rhsShape.size() == 1)
      rhsShape.push_back(1);

    if (lhsShape.size() == 1)
      lhsShape.emplace(lhsShape.begin(), 1);

    int64_t rhsDims = (int64_t)rhsShape.size();
    int64_t lhsDims = (int64_t)lhsShape.size();
    // Suppose lhsShape is (5, 3, 2) -> 5 matrices of size 3x2 (2 is the
    // reduction dimension) and rhsShape is (3, 5, 2, 3) -> 3x5 matrices of
    // size 2x3. The pair matrix would then have size (3, 5, 3, 3, 2): this is
    // the shape of the matrix onto which we apply the TLUs that compute the
    // multiplication of all pairs of values

    // The RHS can be an (N,) matrix; in that case its outer dimension is 1
    int64_t rhsOuterDim = rhsShape[rhsDims - 1];
    int64_t lhsOuterDim = lhsShape[lhsDims - 2];

    std::vector<uint64_t> pairMatrixShape;

    // Compute the output matrix dimension
    // Corresponding dimensions that are considered "compatible" are "N, 1",
    // "1, N", "N, N"
    int64_t rhsDimIter = rhsDims - 3, lhsDimIter = lhsDims - 3;
    if (rhsDimIter >= 0 && lhsDimIter >= 0) {
      while (rhsDimIter >= 0 && lhsDimIter >= 0 &&
             (lhsShape[lhsDimIter] == rhsShape[rhsDimIter] ||
              lhsShape[lhsDimIter] == 1 || rhsShape[rhsDimIter] == 1)) {
        pairMatrixShape.push_back(
            std::max(rhsShape[rhsDimIter], lhsShape[lhsDimIter]));
        --lhsDimIter;
        --rhsDimIter;
      }
    }
    assert((lhsDimIter < 0 || rhsDimIter < 0) &&
           "Bad dimensions given to matmul or dot");
    while (lhsDimIter >= 0) {
      pairMatrixShape.push_back(lhsShape[lhsDimIter]);
      --lhsDimIter;
    }
    while (rhsDimIter >= 0) {
      pairMatrixShape.push_back(rhsShape[rhsDimIter]);
      --rhsDimIter;
    }
    // Add the outer dimensions of the individual matrices
    pairMatrixShape.push_back(lhsOuterDim);
    pairMatrixShape.push_back(rhsOuterDim);

    // Add the reduction dimension
    // The number of elements in the dot product
    // is the number of cells on the reduction axis (aka "destroyed dimension")
    int64_t reductionDimSize = rhsShape[rhsDims - 2];
    assert(lhsShape[lhsDims - 1] == reductionDimSize);

    pairMatrixShape.push_back(reductionDimSize);
    // Compute the manp of the various steps
    // in the matmul of enc x enc:

    // 1. (x + y) and (x - y) -> supposing broadcasting is used
    //    to tensorize this operation

    Operation *xOp = innerProductOp.getLhs().getDefiningOp();
    Operation *yOp = innerProductOp.getRhs().getDefiningOp();

    const double fixedCost = NEGLIGIBLE_COMPLEXITY;
    const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;

    llvm::APInt xSmanp = llvm::APInt{1, 1, false};
    if (xOp != nullptr) {
      const auto xSmanpAttr = xOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
      assert(xSmanpAttr && "Missing SMANP value on a crypto operation");
      xSmanp = xSmanpAttr.getValue();
    }

    llvm::APInt ySmanp = llvm::APInt{1, 1, false};
    if (yOp != nullptr) {
      const auto ySmanpAttr = yOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
      assert(ySmanpAttr && "Missing SMANP value on a crypto operation");
      ySmanp = ySmanpAttr.getValue();
    }

    auto loc = loc_to_string(innerProductOp.getLoc());
    auto comment =
        std::string(innerProductOp->getName().getStringRef()) + " " + loc;

    // (x + y) and (x - y)
    const double addSubManp =
        sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble());

    auto addNode =
        dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost,
                             addSubManp, slice(pairMatrixShape), comment);

    auto subNode =
        dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost,
                             addSubManp, slice(pairMatrixShape), comment);

    // 2. tlu(x + y) and tlu(x - y)
    const std::vector<std::uint64_t> unknownFunction;

    auto lhsTluNode = dag->add_lut(addNode, slice(unknownFunction), precision);

    auto rhsTluNode = dag->add_lut(subNode, slice(unknownFunction), precision);

    // 3. Sum(tlu(x + y) - tlu(x - y))
    // Create a leveled op that simulates concatenation. It takes
    // as inputs all the intermediary dot product results and produces
    // the output tensor
    // Default complexity is negligible
    double fixed_cost = NEGLIGIBLE_COMPLEXITY;
    double lwe_dim_cost_factor = NEGLIGIBLE_COMPLEXITY;

    // For the output of the operation, take the MANP from the MANP pass
    mlir::Operation *op = innerProductOp.getOperation();
    mlir::IntegerAttr smanp_int = op->getAttrOfType<mlir::IntegerAttr>("SMANP");
    assert(smanp_int && "Missing manp value on a crypto operation");

    Inputs tluOpIndices{lhsTluNode, rhsTluNode};

    // TODO: use APIFloat.sqrt when it's available
    double manp = sqrt(smanp_int.getValue().roundToDouble());
    index[result] =
        dag->add_levelled_op(slice(tluOpIndices), lwe_dim_cost_factor,
                             fixed_cost, manp, slice(resultShape), comment);

    mlir::Builder builder(innerProductOp.getContext());
    mlir::SmallVector<int32_t, 5> operatorIndexes = {
        (int32_t)addNode.index, (int32_t)lhsTluNode.index,
        (int32_t)subNode.index, (int32_t)rhsTluNode.index,
        (int32_t)index[result].index};

    if (setOptimizerID)
      innerProductOp->setAttr("TFHE.OId",
                              builder.getDenseI32ArrayAttr(operatorIndexes));

    return index[result];
  }
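A rough Python rendering of the pair-matrix shape computation above may help; this is an illustrative sketch of the logic, not code from the compiler, and dimension order is ignored (the C++ pushes the broadcast dimensions in reverse order, which is presumably fine because the shape only feeds element counts into the optimizer):

```python
def pair_matrix_shape(lhs_shape, rhs_shape):
    """Shape of the tensor of all (x + y) / (x - y) pairs for an
    encrypted-times-encrypted matmul/dot (illustrative sketch)."""
    lhs, rhs = list(lhs_shape), list(rhs_shape)
    if len(rhs) == 1:        # an (N,) RHS acts as a single column
        rhs.append(1)
    if len(lhs) == 1:        # an (N,) LHS acts as a single row
        lhs.insert(0, 1)
    n = rhs[-2]              # reduction ("destroyed") dimension
    assert lhs[-1] == n, "bad dimensions given to matmul or dot"
    lhs_batch, rhs_batch = lhs[:-2], rhs[:-2]
    batch = []
    # Broadcast the trailing batch dimensions pairwise, NumPy-style.
    while lhs_batch and rhs_batch:
        a, b = lhs_batch.pop(), rhs_batch.pop()
        assert a == b or 1 in (a, b), "bad dimensions given to matmul or dot"
        batch.append(max(a, b))
    batch.extend(lhs_batch or rhs_batch)  # leftover leading dimensions
    # Outer dimensions of the matrices, then the reduction dimension.
    return batch + [lhs[-2], rhs[-1], n]

# The example from the comment above: (5, 3, 2) @ (3, 5, 2, 3)
assert sorted(pair_matrix_shape((5, 3, 2), (3, 5, 2, 3))) == sorted((3, 5, 3, 3, 2))
```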
  concrete_optimizer::dag::OperatorIndex
  addEncMatMulTensor(optimizer::Dag &dag, FHELinalg::MatMulEintEintOp &matmulOp,
                     Inputs &inputs, int precision) {
    return addTensorInnerProductEncEnc<FHELinalg::MatMulEintEintOp>(
        dag, matmulOp, inputs, precision);
  }

  concrete_optimizer::dag::OperatorIndex addDotEint(optimizer::Dag &dag,
                                                    FHELinalg::DotEint &dotOp,
                                                    Inputs &inputs,
                                                    int precision) {
    return addTensorInnerProductEncEnc<FHELinalg::DotEint>(dag, dotOp, inputs,
                                                           precision);
  }

  void addMax(optimizer::Dag &dag, FHE::MaxEintOp &maxOp, Inputs &inputs,
              int precision) {
    mlir::Value result = maxOp.getResult();
@@ -547,6 +738,10 @@ struct FunctionToDag {
    return llvm::dyn_cast<mlir::concretelang::FHELinalg::Dot>(op);
  }

  mlir::concretelang::FHELinalg::DotEint asDotEint(mlir::Operation &op) {
    return llvm::dyn_cast<mlir::concretelang::FHELinalg::DotEint>(op);
  }

  mlir::concretelang::FHE::MulEintOp asMul(mlir::Operation &op) {
    return llvm::dyn_cast<mlir::concretelang::FHE::MulEintOp>(op);
  }
@@ -563,6 +758,11 @@ struct FunctionToDag {
    return llvm::dyn_cast<mlir::concretelang::FHELinalg::Maxpool2dOp>(op);
  }

  mlir::concretelang::FHELinalg::MatMulEintEintOp
  asMatmulEintEint(mlir::Operation &op) {
    return llvm::dyn_cast<mlir::concretelang::FHELinalg::MatMulEintEintOp>(op);
  }

  bool isReturn(mlir::Operation &op) {
    return llvm::isa<mlir::func::ReturnOp>(op);
  }
@@ -341,6 +341,48 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::Dot op,
  }
}

/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHELinalg.dot_eint_eint` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::DotEint op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {
  assert(operandMANPs.size() == 2 &&
         operandMANPs[0]->getValue().getMANP().has_value() &&
         "Missing squared Minimal Arithmetic Noise Padding for encrypted "
         "operands");

  llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().value();
  llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().value();

  auto rhsType =
      ((mlir::Type)op.getRhs().getType()).cast<mlir::RankedTensorType>();

  llvm::ArrayRef<int64_t> rhsShape = rhsType.getShape();

  int64_t rhsDims = (int64_t)rhsShape.size();

  assert(rhsDims == 1 && "In MANP computation dot product RHS expected to have "
                         "a single dimension");

  int64_t N = rhsShape[0];

  // Compute output MANP:
  // Tlu output MANP is 1
  llvm::APInt tlu = {1, 1, false};
  // The element-wise multiplication is given by the
  // subtraction of two TLU outputs. The MANP of the multiplication is thus
  // the sum of the TLU MANPs
  llvm::APInt elemMulNorm = APIntWidthExtendUAdd(tlu, tlu);

  llvm::APInt accNorm = llvm::APInt{1, 0, false};

  // For the total Dot product MANP, take the manp of the sum of products
  for (int64_t i = 0; i < N; i++) {
    accNorm = APIntWidthExtendUAdd(elemMulNorm, accNorm);
  }

  return accNorm;
}
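In other words: each encrypted product is formed by subtracting two TLU outputs, each with squared MANP 1, so an element-wise product has squared MANP 1 + 1 = 2, and a length-N encrypted dot product accumulates nu^2 = 2N, i.e. MANP = ceil(sqrt(2N)). (The TLUs presumably evaluate something like a quarter-squares identity, x*y = tlu1(x+y) - tlu2(x-y); the exact tables are not shown in this diff, and the cost model above relies only on the two-TLU structure.)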
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHE.add_eint_int` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHE::AddEintIntOp op,
@@ -522,12 +564,12 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHE::MulEintOp op,
  const llvm::APInt x = operandMANPs[0]->getValue().getMANP().value();
  const llvm::APInt y = operandMANPs[1]->getValue().getMANP().value();

  const llvm::APInt beforeTLUs = APIntWidthExtendUAdd(x, y);
  // The MANP of this operation is simply the MANP after the TLUs
  // which is equal to the sum of outputs of 2 TLUs
  const llvm::APInt tlu = {1, 1, false};
  const llvm::APInt result = APIntWidthExtendUAdd(tlu, tlu);

  // this is not optimal as it can increase the resulting noise unnecessarily
  return APIntUMax(beforeTLUs, result);
  return result;
}

/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
@@ -742,12 +784,12 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MulEintOp op,
  const llvm::APInt x = operandMANPs[0]->getValue().getMANP().value();
  const llvm::APInt y = operandMANPs[1]->getValue().getMANP().value();

  const llvm::APInt beforeTLUs = APIntWidthExtendUAdd(x, y);
  // The MANP of this operation is simply the MANP after the TLUs
  // which is equal to the sum of outputs of 2 TLUs
  const llvm::APInt tlu = {1, 1, false};
  const llvm::APInt result = APIntWidthExtendUAdd(tlu, tlu);

  // this is not optimal as it can increase the resulting noise unnecessarily
  return APIntUMax(beforeTLUs, result);
  return result;
}
@@ -755,7 +797,10 @@ static llvm::APInt computeVectorNorm(
    mlir::DenseIntElementsAttr denseValues, llvm::APInt encryptedOperandNorm,
    llvm::SmallVector<uint64_t, /*size-hint=*/4> &elementSelector) {

  llvm::APInt accumulationNorm = llvm::APInt{1, 1, false};
  // The accumulator is initialized with 0s in all bits (not the encrypted 0)
  // so there is no initial noise in the accumulator, and its
  // MANP is initialized to 0
  llvm::APInt accumulationNorm = llvm::APInt{1, 0, false};
  for (int64_t i = 0; i < shape[axis]; i++) {
    elementSelector[axis] = i;
@@ -845,7 +890,7 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulEintIntOp op,
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().value();
  llvm::APInt accNorm = llvm::APInt{1, 1, false};
  llvm::APInt accNorm = llvm::APInt{1, 0, false};

  mlir::arith::ConstantOp cstOp =
      llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
@@ -866,7 +911,10 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulEintIntOp op,
  int64_t P = rhsShape[1];
  for (int64_t m = 0; m < M; m++) {
    for (int64_t p = 0; p < P; p++) {
      llvm::APInt tmpNorm = llvm::APInt{1, 1, false};
      // The accumulator is initialized with 0s in all bits (not the
      // encrypted 0), so there is no initial noise in the accumulator and
      // its MANP is initialized to 0
      llvm::APInt tmpNorm = llvm::APInt{1, 0, false};
      for (int64_t n = 0; n < N; n++) {
        llvm::APInt cst = denseValsAP[{(uint64_t)n, (uint64_t)p}];
        llvm::APInt rhsNorm = APIntWidthExtendSqForConstant(cst);
@@ -943,7 +991,7 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulIntEintOp op,
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().value();
  llvm::APInt accNorm = llvm::APInt{1, 1, false};
  llvm::APInt accNorm = llvm::APInt{1, 0, false};

  mlir::arith::ConstantOp cstOp =
      llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
@@ -1018,6 +1066,46 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulIntEintOp op,
  return accNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a matmul
/// operation
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulEintEintOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {

  auto rhsType =
      ((mlir::Type)op.getRhs().getType()).cast<mlir::RankedTensorType>();

  llvm::ArrayRef<int64_t> rhsShape = rhsType.getShape();

  int64_t rhsDims = (int64_t)rhsShape.size();

  assert(
      operandMANPs.size() == 2 &&
      operandMANPs[0]->getValue().getMANP().has_value() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().value();
  llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().value();

  int64_t N = rhsDims <= 2 ? rhsShape[0] : rhsShape[rhsDims - 2];

  // Compute MANP of a single matrix cell x matrix cell multiplication
  // This is used later to compute the MANP of an entire dot product

  llvm::APInt tlu = {1, 1, false};
  llvm::APInt elemMulNorm = APIntWidthExtendUAdd(tlu, tlu);
  llvm::APInt accNorm = llvm::APInt{1, 0, false};

  // For the total MatMul MANP, take the MANP of a single
  // column-row dot-product.
  // All such dot-products produce the same MANP, so there
  // is no need to take the maximum over the dot-products
  for (int64_t i = 0; i < N; i++) {
    accNorm = APIntWidthExtendUAdd(elemMulNorm, accNorm);
  }

  return accNorm;
}
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::TransposeOp op,
                             llvm::ArrayRef<const MANPLattice *> operandMANPs) {

@@ -1373,6 +1461,9 @@ public:
    else if (auto dotOp =
                 llvm::dyn_cast<mlir::concretelang::FHELinalg::Dot>(op)) {
      norm2SqEquiv = getSqMANP(dotOp, operands);
    } else if (auto dotEintOp =
                   llvm::dyn_cast<mlir::concretelang::FHELinalg::DotEint>(op)) {
      norm2SqEquiv = getSqMANP(dotEintOp, operands);
    } else if (auto addEintIntOp =
                   llvm::dyn_cast<mlir::concretelang::FHELinalg::AddEintIntOp>(
                       op)) {
@@ -1419,6 +1510,9 @@ public:
    } else if (auto matmulIntEintOp = llvm::dyn_cast<
                   mlir::concretelang::FHELinalg::MatMulIntEintOp>(op)) {
      norm2SqEquiv = getSqMANP(matmulIntEintOp, operands);
    } else if (auto matmulEintEintOp = llvm::dyn_cast<
                   mlir::concretelang::FHELinalg::MatMulEintEintOp>(op)) {
      norm2SqEquiv = getSqMANP(matmulEintEintOp, operands);
    } else if (llvm::isa<
                   mlir::concretelang::FHELinalg::ApplyLookupTableEintOp,
                   mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp,
@@ -401,35 +401,73 @@ mlir::LogicalResult ApplyMappedLookupTableEintOp::verify() {
                      verifyLutsSize(*this, t, luts).succeeded());
}

::mlir::LogicalResult Dot::verify() {
  if (::mlir::failed(mlir::verifyCompatibleShape(this->getLhs().getType(),
                                                 this->getRhs().getType()))) {
    return this->emitOpError("arguments have incompatible shapes");
  }
  auto lhsEltType = this->getLhs()
                        .getType()
                        .cast<mlir::TensorType>()
                        .getElementType()
                        .dyn_cast<FHE::FheIntegerInterface>();
  auto rhsEltType = this->getRhs()
                        .getType()
                        .cast<mlir::TensorType>()
                        .getElementType()
                        .cast<mlir::IntegerType>();
  auto resultType =
      this->getResult().getType().dyn_cast<FHE::FheIntegerInterface>();
  if (!mlir::concretelang::FHE::
          verifyEncryptedIntegerAndIntegerInputsConsistency(
              *this->getOperation(), lhsEltType, rhsEltType)) {
mlir::LogicalResult
verifyDotInputsOutputsConsistency(mlir::concretelang::FHELinalg::DotEint &op,
                                  FHE::FheIntegerInterface &lhsEltType,
                                  FHE::FheIntegerInterface &rhsEltType,
                                  FHE::FheIntegerInterface &resultType) {
  if (!mlir::concretelang::FHE::verifyEncryptedIntegerInputsConsistency(
          *op.getOperation(), lhsEltType, rhsEltType)) {
    return ::mlir::failure();
  }
  if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(
          *this->getOperation(), lhsEltType, resultType)) {
          *op.getOperation(), lhsEltType, resultType)) {
    return ::mlir::failure();
  }
  return ::mlir::success();
}

mlir::LogicalResult
verifyDotInputsOutputsConsistency(mlir::concretelang::FHELinalg::Dot &op,
                                  FHE::FheIntegerInterface &lhsEltType,
                                  mlir::IntegerType &rhsEltType,
                                  FHE::FheIntegerInterface &resultType) {
  if (!mlir::concretelang::FHE::
          verifyEncryptedIntegerAndIntegerInputsConsistency(
              *op.getOperation(), lhsEltType, rhsEltType)) {
    return ::mlir::failure();
  }
  if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(
          *op.getOperation(), lhsEltType, resultType)) {
    return ::mlir::failure();
  }
  return ::mlir::success();
}
// Verify a dot product operation:
// - check that the shapes are compatible
// - check that the widths of the inputs and result are the same
template <typename DotOp, typename RHSElementType>
mlir::LogicalResult verifyDot(DotOp &op) {
  if (::mlir::failed(mlir::verifyCompatibleShape(op.getLhs().getType(),
                                                 op.getRhs().getType()))) {
    return op.emitOpError("arguments have incompatible shapes");
  }
  auto lhsEltType = ((mlir::Type)op.getLhs().getType())
                        .cast<mlir::TensorType>()
                        .getElementType()
                        .dyn_cast<FHE::FheIntegerInterface>();
  auto rhsEltType = ((mlir::Type)op.getRhs().getType())
                        .cast<mlir::TensorType>()
                        .getElementType()
                        .cast<RHSElementType>();
  auto resultType = ((mlir::Type)op.getResult().getType())
                        .dyn_cast<FHE::FheIntegerInterface>();

  return verifyDotInputsOutputsConsistency(op, lhsEltType, rhsEltType,
                                           resultType);
}

::mlir::LogicalResult Dot::verify() {
  return ::mlir::concretelang::FHELinalg::verifyDot<
      mlir::concretelang::FHELinalg::Dot, mlir::IntegerType>(*this);
}

::mlir::LogicalResult DotEint::verify() {
  return ::mlir::concretelang::FHELinalg::verifyDot<
      mlir::concretelang::FHELinalg::DotEint, FHE::FheIntegerInterface>(*this);
}

llvm::SmallVector<int64_t, 3>
verifySumCalculateActualOutputShape(mlir::Type outputType) {
  auto actualOutputShape = llvm::SmallVector<int64_t, 3>{};
@@ -784,6 +822,11 @@ mlir::LogicalResult MatMulIntEintOp::verify() {
      mlir::concretelang::FHELinalg::MatMulIntEintOp>(*this);
}

mlir::LogicalResult MatMulEintEintOp::verify() {
  return ::mlir::concretelang::FHELinalg::verifyMatmul<
      mlir::concretelang::FHELinalg::MatMulEintEintOp>(*this);
}

mlir::SmallVector<int64_t, 4>
getPaddingFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
  mlir::SmallVector<int64_t, 4> paddingInts;
@@ -335,3 +335,24 @@ func.func @main(%x: tensor<3x4xi6>, %y: tensor<4x2x!FHE.eint<5>>) -> tensor<3x2x
  %0 = "FHELinalg.matmul_int_eint"(%x, %y): (tensor<3x4xi6>, tensor<4x2x!FHE.eint<5>>) -> tensor<3x2x!FHE.eint<5>>
  return %0 : tensor<3x2x!FHE.eint<5>>
}

// -----

// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-NEXT: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-NEXT: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK: func.func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<5>>, %[[a1:.*]]: tensor<4x2x!FHE.eint<5>>) -> tensor<3x2x!FHE.eint<5>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<3x2x!FHE.eint<5>>
// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[m0]], #[[m1]], #[[m2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[a0]], %[[a1]] : tensor<3x4x!FHE.eint<5>>, tensor<4x2x!FHE.eint<5>>) outs(%[[v0]] : tensor<3x2x!FHE.eint<5>>) {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<5>, %[[aa1:.*]]: !FHE.eint<5>, %[[aa2:.*]]: !FHE.eint<5>):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.mul_eint"(%[[aa1]], %[[aa0]]) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
// CHECK-NEXT: %[[vv1:.*]] = "FHE.add_eint"(%[[aa2]], %[[vv0]]) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
// CHECK-NEXT: linalg.yield %[[vv1]] : !FHE.eint<5>
// CHECK-NEXT: } -> tensor<3x2x!FHE.eint<5>>
// CHECK-NEXT: return %[[v1]] : tensor<3x2x!FHE.eint<5>>
// CHECK-NEXT: }
func.func @main(%x: tensor<3x4x!FHE.eint<5>>, %y: tensor<4x2x!FHE.eint<5>>) -> tensor<3x2x!FHE.eint<5>> {
  %0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<3x4x!FHE.eint<5>>, tensor<4x2x!FHE.eint<5>>) -> tensor<3x2x!FHE.eint<5>>
  return %0 : tensor<3x2x!FHE.eint<5>>
}
@@ -281,11 +281,12 @@ func.func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>)
func.func @single_dyn_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !FHE.eint<2>
{
  // sqrt((2^2-1)^2*1) = sqrt(9) = 3
  // FIXME: the dynamic clear value MANP computation is wrong, update the MANP to the correct one when it's fixed
  // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 7 : ui{{[0-9]+}}}
  %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>

  // sqrt(4*(2^2-1)^2*9) = sqrt(324) = 18
  // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[I]]) {MANP = 42 : ui{{[[0-9]+}}}
  // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[I:.*]]) {MANP = 42 : ui{{[0-9]+}}}
  %1 = "FHELinalg.dot_eint_int"(%0, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2>

  return %1 : !FHE.eint<2>
@@ -299,26 +300,28 @@ func.func @single_dyn_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>)

func.func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!FHE.eint<2>>, %arg1: tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>> {
  // p = 0
  // acc = manp(0) = 1
  // acc = manp(0) = 0
  // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
  // manp(add_eint(mul, acc)) = 9 + 1 = 10
  // ceil(sqrt(65)) = 4
  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 8 : ui{{[0-9]+}}}
  %1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>>
  return %1 : tensor<3x2x!FHE.eint<2>>
  // manp(add_eint(mul, acc)) = 9
  // ceil(sqrt(9)) = 3
  // FIXME: the dynamic clear value MANP computation is wrong, update the MANP to the correct one when it's fixed
  // CHECK: %[[V0:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 7 : ui{{[0-9]+}}}
  %0 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>>
  return %0 : tensor<3x2x!FHE.eint<2>>
}

// -----
func.func @matmul_eint_int_dyn_p_2(%arg0: tensor<3x2x!FHE.eint<2>>, %arg1: tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>> {
  // p = 0
  // acc = manp(0) = 1
  // acc = manp(0) = 0
  // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
  // manp(add_eint(mul, acc)) = 9 + 1 = 10
  // manp(add_eint(mul, acc)) = 9
  // p = 1
  // manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
  // manp(add_eint(mul, acc)) = 10 + 9 = 19
  // ceil(sqrt(19)) = 5
  // manp(add_eint(mul, acc)) = 9 + 9 = 18
  // ceil(sqrt(18)) = 5
  // FIXME: the dynamic clear value MANP computation is wrong, update the MANP to the correct one when it's fixed
  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 10 : ui{{[0-9]+}}}
  %1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x2x!FHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>>
  return %1 : tensor<3x2x!FHE.eint<2>>
@@ -329,11 +332,11 @@ func.func @matmul_eint_int_dyn_p_2(%arg0: tensor<3x2x!FHE.eint<2>>, %arg1: tenso
func.func @matmul_eint_int_cst_p_1(%arg0: tensor<3x1x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> {
  %0 = arith.constant dense<[[3, 1]]> : tensor<1x2xi3>
  // c(m,n) = a(m,p) * b(p,n) the max cst is used for n = 0
  // acc = manp(0) = 1
  // acc = manp(0) = 0
  // mul = manp(mul_eint_int(eint<2>, 3) = 1 * 3^2 = 9
  // manp(add_eint(mul, acc)) = 9 + 1 = 10
  // ceil(sqrt(10)) = 4
  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}}
  // manp(add_eint(mul, acc)) = 9
  // ceil(sqrt(9)) = 3
  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 3 : ui{{[0-9]+}}}
  %1 = "FHELinalg.matmul_eint_int"(%arg0, %0): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>>
  return %1 : tensor<3x2x!FHE.eint<2>>
}
@@ -348,10 +351,10 @@ func.func @matmul_eint_int_cst_p_2_n_0(%arg0: tensor<3x2x!FHE.eint<2>>) -> tenso
  // mul = manp(mul_eint_int(eint<2>, 3) = 1 * 3^2 = 9
  // manp(add_eint(mul, acc)) = 9 + 1 = 10
  // p = 1
  // mul = manp(mul_eint_int(eint<2>, 4) = 1 * 4^2 = 17
  // manp(add_eint(mul, acc)) = 17 + 9 = 26
  // ceil(sqrt(26)) = 6
  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 6 : ui{{[0-9]+}}}
  // mul = manp(mul_eint_int(eint<2>, 4) = 1 * 4^2 = 16
  // manp(add_eint(mul, acc)) = 16 + 9 = 25
  // ceil(sqrt(25)) = 5
  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}}
  %1 = "FHELinalg.matmul_eint_int"(%arg0, %0): (tensor<3x2x!FHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>>
  return %1 : tensor<3x2x!FHE.eint<2>>
}
@@ -362,13 +365,13 @@ func.func @matmul_eint_int_cst_p_2_n_1(%arg0: tensor<3x2x!FHE.eint<2>>) -> tenso
  %0 = arith.constant dense<[[1, 4],[3, 1]]> : tensor<2x2xi3>
  // c(m,n) = a(m,p) * b(p,n) the max csts [4,1] are used for n = 1
  // p = 0
  // acc = manp(0) = 1
  // acc = manp(0) = 0
  // mul = manp(mul_eint_int(eint<2>, 4) = 1 * 4^2 = 16
  // manp(add_eint(mul, acc)) = 16 + 1 = 17
  // manp(add_eint(mul, acc)) = 16
  // p = 1
  // mul = manp(mul_eint_int(eint<2>, 1) = 1 * 1^2 = 1
  // manp(add_eint(mul, acc)) = 1 + 17 = 18
  // ceil(sqrt(18)) = 5
  // manp(add_eint(mul, acc)) = 1 + 16 = 17
  // ceil(sqrt(17)) = 5
  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}}
  %1 = "FHELinalg.matmul_eint_int"(%arg0, %0): (tensor<3x2x!FHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>>
  return %1 : tensor<3x2x!FHE.eint<2>>
@@ -382,7 +385,7 @@ func.func @matmul_eint_int_cst() -> tensor<4x3x!FHE.eint<7>> {
  // ===============================

  %1 = arith.constant dense<
    // ceil(sqrt(2^2 + 1^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
    // ceil(sqrt(2^2 + 1^2 + 5^2)) = ceil(sqrt(30)) = 6
    [2, 1, 5]
  > : tensor<3xi8>
@@ -392,8 +395,8 @@ func.func @matmul_eint_int_cst() -> tensor<4x3x!FHE.eint<7>> {
  // ===============================

  %3 = arith.constant dense<
    // ceil(sqrt(2^2 + 3^2 + 5^2 + 1)) = ceil(sqrt(39)) = 7
    // ceil(sqrt(3^2 + 2^2 + 6^2 + 1)) = ceil(sqrt(50)) = 8
    // ceil(sqrt(2^2 + 3^2 + 5^2)) = ceil(sqrt(38)) = 7
    // ceil(sqrt(3^2 + 2^2 + 6^2)) = ceil(sqrt(49)) = 7
    [
      [2, 3],
      [3, 2],
@@ -401,47 +404,47 @@ func.func @matmul_eint_int_cst() -> tensor<4x3x!FHE.eint<7>> {
    ]
  > : tensor<3x2xi8>

  // CHECK: MANP = 8 : ui{{[0-9]+}}
  // CHECK: MANP = 7 : ui{{[0-9]+}}
  %4 = "FHELinalg.matmul_eint_int"(%0, %3) : (tensor<4x3x!FHE.eint<7>>, tensor<3x2xi8>) -> tensor<4x2x!FHE.eint<7>>

  // ===============================

  %5 = arith.constant dense<
    [
      // ceil(sqrt(1^2 + 4^2 + 6^2 + 1)) = ceil(sqrt(54)) = 8
      // ceil(sqrt(6^2 + 3^2 + 2^2 + 1)) = ceil(sqrt(50)) = 8
      // ceil(sqrt(1^2 + 4^2 + 6^2)) = ceil(sqrt(53)) = 8
      // ceil(sqrt(6^2 + 3^2 + 2^2)) = ceil(sqrt(49)) = 7
      [
        [1, 6],
        [4, 3],
        [6, 2]
      ],

      // ceil(sqrt(5^2 + 3^2 + 5^2 + 1)) = ceil(sqrt(60)) = 8
      // ceil(sqrt(3^2 + 2^2 + 6^2 + 1)) = ceil(sqrt(50)) = 8
      // ceil(sqrt(5^2 + 3^2 + 5^2)) = ceil(sqrt(59)) = 8
      // ceil(sqrt(3^2 + 2^2 + 6^2)) = ceil(sqrt(49)) = 7
      [
        [5, 3],
        [3, 2],
        [5, 6]
      ],

      // ceil(sqrt(5^2 + 5^2 + 3^2 + 1)) = ceil(sqrt(60)) = 8
      // ceil(sqrt(3^2 + 6^2 + 3^2 + 1)) = ceil(sqrt(55)) = 8
      // ceil(sqrt(5^2 + 5^2 + 3^2)) = ceil(sqrt(59)) = 8
      // ceil(sqrt(3^2 + 6^2 + 3^2)) = ceil(sqrt(54)) = 8
      [
        [5, 3],
        [5, 6],
        [3, 3]
      ],

      // ceil(sqrt(6^2 + 1^2 + 4^2 + 1)) = ceil(sqrt(54)) = 8
      // ceil(sqrt(3^2 + 4^2 + 3^2 + 1)) = ceil(sqrt(35)) = 6
      // ceil(sqrt(6^2 + 1^2 + 4^2)) = ceil(sqrt(53)) = 8
      // ceil(sqrt(3^2 + 4^2 + 3^2)) = ceil(sqrt(34)) = 6
      [
        [6, 3],
        [1, 4],
        [4, 3]
      ],

      // ceil(sqrt(1^2 + 6^2 + 6^2 + 1)) = ceil(sqrt(74)) = 9
      // ceil(sqrt(2^2 + 1^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
      // ceil(sqrt(1^2 + 6^2 + 6^2)) = ceil(sqrt(73)) = 9
      // ceil(sqrt(2^2 + 1^2 + 5^2)) = ceil(sqrt(30)) = 6
      [
        [1, 2],
        [6, 1],
@@ -458,40 +461,40 @@ func.func @matmul_eint_int_cst() -> tensor<4x3x!FHE.eint<7>> {
  %7 = arith.constant dense<
    [
      [
        // ceil(sqrt(1^2 + 4^2 + 6^2 + 1)) = ceil(sqrt(54)) = 8
        // ceil(sqrt(6^2 + 3^2 + 2^2 + 1)) = ceil(sqrt(50)) = 8
        // ceil(sqrt(1^2 + 4^2 + 6^2)) = ceil(sqrt(53)) = 8
        // ceil(sqrt(6^2 + 3^2 + 2^2)) = ceil(sqrt(49)) = 7
        [
          [1, 6],
          [4, 3],
          [6, 2]
        ],

        // ceil(sqrt(5^2 + 3^2 + 5^2 + 1)) = ceil(sqrt(60)) = 8
        // ceil(sqrt(3^2 + 2^2 + 6^2 + 1)) = ceil(sqrt(50)) = 8
        // ceil(sqrt(5^2 + 3^2 + 5^2)) = ceil(sqrt(59)) = 8
        // ceil(sqrt(3^2 + 2^2 + 6^2)) = ceil(sqrt(49)) = 7
        [
          [5, 3],
          [3, 2],
          [5, 6]
        ],

        // ceil(sqrt(5^2 + 5^2 + 3^2 + 1)) = ceil(sqrt(60)) = 8
        // ceil(sqrt(3^2 + 6^2 + 3^2 + 1)) = ceil(sqrt(55)) = 8
        // ceil(sqrt(5^2 + 5^2 + 3^2)) = ceil(sqrt(59)) = 8
        // ceil(sqrt(3^2 + 6^2 + 3^2)) = ceil(sqrt(54)) = 8
        [
          [5, 3],
          [5, 6],
          [3, 3]
        ],
        // ceil(sqrt(6^2 + 1^2 + 4^2 + 1)) = ceil(sqrt(54)) = 8
        // ceil(sqrt(3^2 + 4^2 + 3^2 + 1)) = ceil(sqrt(35)) = 6
        // ceil(sqrt(6^2 + 1^2 + 4^2)) = ceil(sqrt(53)) = 8
        // ceil(sqrt(3^2 + 4^2 + 3^2)) = ceil(sqrt(34)) = 6
        [
          [6, 3],
          [1, 4],
          [4, 3]
        ],

        // ceil(sqrt(1^2 + 6^2 + 6^2 + 1)) = ceil(sqrt(74)) = 9
        // ceil(sqrt(2^2 + 1^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
        // ceil(sqrt(1^2 + 6^2 + 6^2)) = ceil(sqrt(73)) = 9
        // ceil(sqrt(2^2 + 1^2 + 5^2)) = ceil(sqrt(30)) = 6
        [
          [1, 2],
          [6, 1],
@@ -499,40 +502,40 @@ func.func @matmul_eint_int_cst() -> tensor<4x3x!FHE.eint<7>> {
        ]
      ],
      [
        // ceil(sqrt(6^2 + 1^2 + 3^2 + 1)) = ceil(sqrt(47)) = 7
        // ceil(sqrt(5^2 + 6^2 + 6^2 + 1)) = ceil(sqrt(98)) = 10
        // ceil(sqrt(6^2 + 1^2 + 3^2)) = ceil(sqrt(46)) = 7
        // ceil(sqrt(5^2 + 6^2 + 6^2)) = ceil(sqrt(97)) = 10
        [
          [6, 5],
          [1, 6],
          [3, 6]
        ],

        // ceil(sqrt(1^2 + 2^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
        // ceil(sqrt(6^2 + 3^2 + 1^2 + 1)) = ceil(sqrt(47)) = 7
        // ceil(sqrt(1^2 + 2^2 + 5^2)) = ceil(sqrt(30)) = 6
        // ceil(sqrt(6^2 + 3^2 + 1^2)) = ceil(sqrt(46)) = 7
        [
          [1, 6],
          [2, 3],
          [5, 1]
        ],

        // ceil(sqrt(4^2 + 3^2 + 6^2 + 1)) = ceil(sqrt(62)) = 8
        // ceil(sqrt(1^2 + 5^2 + 2^2 + 1)) = ceil(sqrt(31)) = 6
        // ceil(sqrt(4^2 + 3^2 + 6^2)) = ceil(sqrt(61)) = 8
        // ceil(sqrt(1^2 + 5^2 + 2^2)) = ceil(sqrt(30)) = 6
        [
          [4, 1],
          [3, 5],
          [6, 2]
        ],

        // ceil(sqrt(2^2 + 3^2 + 3^2 + 1)) = ceil(sqrt(23)) = 5
        // ceil(sqrt(2^2 + 2^2 + 1^2 + 1)) = ceil(sqrt(10)) = 4
        // ceil(sqrt(2^2 + 3^2 + 3^2)) = ceil(sqrt(22)) = 5
        // ceil(sqrt(2^2 + 2^2 + 1^2)) = ceil(sqrt(9)) = 3
        [
          [2, 2],
          [3, 2],
          [3, 1]
        ],

        // ceil(sqrt(6^2 + 2^2 + 3^2 + 1)) = ceil(sqrt(50)) = 8
        // ceil(sqrt(2^2 + 4^2 + 2^2 + 1)) = ceil(sqrt(25)) = 5
        // ceil(sqrt(6^2 + 2^2 + 3^2)) = ceil(sqrt(49)) = 7
        // ceil(sqrt(2^2 + 4^2 + 2^2)) = ceil(sqrt(24)) = 5
        [
          [6, 2],
          [2, 4],
@@ -563,7 +566,7 @@ func.func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x3x!FHE.eint<
|
||||
// ===============================
|
||||
|
||||
%1 = arith.constant dense<
|
||||
// ceil(sqrt(1 * (2^2 + 1^2 + 5^2) + 1)) = ceil(sqrt(31)) = 6
|
||||
// ceil(sqrt(1 * (2^2 + 1^2 + 5^2))) = ceil(sqrt(30)) = 6
|
||||
[2, 1, 5]
|
||||
> : tensor<3xi8>
|
||||
|
||||
@@ -583,11 +586,11 @@ func.func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x3x!FHE.eint<
|
||||
|
||||
func.func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> {
|
||||
// p = 1
// acc = manp(0) = 1
// acc = manp(0) = 0
// mul = manp(mul_eint_int(eint<2>, i3)) = 1 * (2^3-1)^2 = 49
// manp(add_eint(mul, acc)) = 49 + 1 = 50
// ceil(sqrt(50)) = 8
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 8 : ui{{[0-9]+}}}
// manp(add_eint(mul, acc)) = 49
// ceil(sqrt(49)) = 7
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 7 : ui{{[0-9]+}}}
%1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x1xi3>, tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>>
|
||||
return %1 : tensor<3x2x!FHE.eint<2>>
|
||||
}
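
For readers tracking these MANP annotations, here is a minimal Python sketch of the 2-norm estimate the comments compute; `manp` is a hypothetical helper for illustration, not compiler API:

```python
import math

def manp(weights, acc_init=0):
    # MANP of one output cell: the 2-norm of the per-product noise multipliers,
    # i.e. ceil(sqrt(acc_init + sum of squared weights)).
    return math.ceil(math.sqrt(acc_init + sum(w * w for w in weights)))

# Matches the annotations above: ceil(sqrt(2^2 + 1^2 + 5^2)) = 6, and with the
# old initial-accumulator term, ceil(sqrt(30 + 1)) = 6 as well.
assert manp([2, 1, 5]) == 6
assert manp([2, 1, 5], acc_init=1) == 6
```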
@@ -613,9 +616,9 @@ func.func @matmul_int_eint_dyn_p_2(%arg0: tensor<3x2xi3>, %arg1: tensor<2x2x!FHE
func.func @matmul_int_eint_cst_p_1(%arg0: tensor<1x3x!FHE.eint<2>>) -> tensor<2x3x!FHE.eint<2>> {
%0 = arith.constant dense<[[3], [1]]> : tensor<2x1xi3>
// c(m,n) = a(m,p) * b(p,n) the max cst is used for m = 0
// acc = manp(0) = 1
// mul = manp(mul_eint_int(eint<2>, 3)) = 1 * 3^2 = 9
// manp(add_eint(mul, acc)) = 9 + 1 = 10
// acc = manp(0) = 0
// mul = manp(mul_eint_int(eint<2>, 3)) = 1^2 + 3^2 = 10
// manp(add_eint(mul, acc)) = 10
// ceil(sqrt(10)) = 4
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}}
%1 = "FHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x1xi3>, tensor<1x3x!FHE.eint<2>>) -> tensor<2x3x!FHE.eint<2>>
@@ -666,7 +669,7 @@ func.func @matmul_int_eint_cst() -> tensor<3x2x!FHE.eint<7>> {
// ===============================

%1 = arith.constant dense<
// ceil(sqrt(2^2 + 1^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(2^2 + 1^2 + 5^2)) = ceil(sqrt(30)) = 6
[2, 1, 5]
> : tensor<3xi8>

@@ -676,8 +679,8 @@ func.func @matmul_int_eint_cst() -> tensor<3x2x!FHE.eint<7>> {
// ===============================

%3 = arith.constant dense<
// ceil(sqrt(2^2 + 3^2 + 5^2 + 1)) = ceil(sqrt(39)) = 7
// ceil(sqrt(3^2 + 2^2 + 6^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(2^2 + 3^2 + 5^2)) = ceil(sqrt(38)) = 7
// ceil(sqrt(3^2 + 2^2 + 6^2)) = ceil(sqrt(49)) = 7
[
[2, 3, 5],
[3, 2, 6]
@@ -691,36 +694,36 @@ func.func @matmul_int_eint_cst() -> tensor<3x2x!FHE.eint<7>> {

%5 = arith.constant dense<
[
// ceil(sqrt(1^2 + 4^2 + 6^2 + 1)) = ceil(sqrt(54)) = 8
// ceil(sqrt(6^2 + 3^2 + 2^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(1^2 + 4^2 + 6^2)) = ceil(sqrt(53)) = 8
// ceil(sqrt(6^2 + 3^2 + 2^2)) = ceil(sqrt(49)) = 7
[
[1, 4, 6],
[6, 3, 2]
],

// ceil(sqrt(5^2 + 3^2 + 5^2 + 1)) = ceil(sqrt(60)) = 8
// ceil(sqrt(3^2 + 2^2 + 6^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(5^2 + 3^2 + 5^2)) = ceil(sqrt(59)) = 8
// ceil(sqrt(3^2 + 2^2 + 6^2)) = ceil(sqrt(49)) = 7
[
[5, 3, 5],
[3, 2, 6]
],

// ceil(sqrt(5^2 + 5^2 + 3^2 + 1)) = ceil(sqrt(60)) = 8
// ceil(sqrt(3^2 + 6^2 + 3^2 + 1)) = ceil(sqrt(55)) = 8
// ceil(sqrt(5^2 + 5^2 + 3^2)) = ceil(sqrt(59)) = 8
// ceil(sqrt(3^2 + 6^2 + 3^2)) = ceil(sqrt(54)) = 8
[
[5, 5, 3],
[3, 6, 3]
],

// ceil(sqrt(6^2 + 1^2 + 4^2 + 1)) = ceil(sqrt(54)) = 8
// ceil(sqrt(3^2 + 4^2 + 3^2 + 1)) = ceil(sqrt(35)) = 6
// ceil(sqrt(6^2 + 1^2 + 4^2)) = ceil(sqrt(53)) = 8
// ceil(sqrt(3^2 + 4^2 + 3^2)) = ceil(sqrt(34)) = 6
[
[6, 1, 4],
[3, 4, 3]
],

// ceil(sqrt(1^2 + 6^2 + 6^2 + 1)) = ceil(sqrt(74)) = 9
// ceil(sqrt(2^2 + 1^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(1^2 + 6^2 + 6^2)) = ceil(sqrt(73)) = 9
// ceil(sqrt(2^2 + 1^2 + 5^2)) = ceil(sqrt(30)) = 6
[
[1, 6, 6],
[2, 1, 5]
@@ -736,72 +739,72 @@ func.func @matmul_int_eint_cst() -> tensor<3x2x!FHE.eint<7>> {
%7 = arith.constant dense<
[
[
// ceil(sqrt(1^2 + 4^2 + 6^2 + 1)) = ceil(sqrt(54)) = 8
// ceil(sqrt(6^2 + 3^2 + 2^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(1^2 + 4^2 + 6^2)) = ceil(sqrt(53)) = 8
// ceil(sqrt(6^2 + 3^2 + 2^2)) = ceil(sqrt(49)) = 7
[
[1, 4, 6],
[6, 3, 2]
],

// ceil(sqrt(5^2 + 3^2 + 5^2 + 1)) = ceil(sqrt(60)) = 8
// ceil(sqrt(3^2 + 2^2 + 6^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(5^2 + 3^2 + 5^2)) = ceil(sqrt(59)) = 8
// ceil(sqrt(3^2 + 2^2 + 6^2)) = ceil(sqrt(49)) = 7
[
[5, 3, 5],
[3, 2, 6]
],

// ceil(sqrt(5^2 + 5^2 + 3^2 + 1)) = ceil(sqrt(60)) = 8
// ceil(sqrt(3^2 + 6^2 + 3^2 + 1)) = ceil(sqrt(55)) = 8
// ceil(sqrt(5^2 + 5^2 + 3^2)) = ceil(sqrt(59)) = 8
// ceil(sqrt(3^2 + 6^2 + 3^2)) = ceil(sqrt(54)) = 8
[
[5, 5, 3],
[3, 6, 3]
],

// ceil(sqrt(6^2 + 1^2 + 4^2 + 1)) = ceil(sqrt(54)) = 8
// ceil(sqrt(3^2 + 4^2 + 3^2 + 1)) = ceil(sqrt(35)) = 6
// ceil(sqrt(6^2 + 1^2 + 4^2)) = ceil(sqrt(53)) = 8
// ceil(sqrt(3^2 + 4^2 + 3^2)) = ceil(sqrt(34)) = 6
[
[6, 1, 4],
[3, 4, 3]
],

// ceil(sqrt(1^2 + 6^2 + 6^2 + 1)) = ceil(sqrt(74)) = 9
// ceil(sqrt(2^2 + 1^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(1^2 + 6^2 + 6^2)) = ceil(sqrt(73)) = 9
// ceil(sqrt(2^2 + 1^2 + 5^2)) = ceil(sqrt(30)) = 6
[
[1, 6, 6],
[2, 1, 5]
]
],
[
// ceil(sqrt(6^2 + 1^2 + 3^2 + 1)) = ceil(sqrt(47)) = 7
// ceil(sqrt(5^2 + 6^2 + 6^2 + 1)) = ceil(sqrt(98)) = 10
// ceil(sqrt(6^2 + 1^2 + 3^2)) = ceil(sqrt(46)) = 7
// ceil(sqrt(5^2 + 6^2 + 6^2)) = ceil(sqrt(97)) = 10
[
[6, 1, 3],
[5, 6, 6]
],

// ceil(sqrt(1^2 + 2^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(6^2 + 3^2 + 1^2 + 1)) = ceil(sqrt(47)) = 7
// ceil(sqrt(1^2 + 2^2 + 5^2)) = ceil(sqrt(30)) = 6
// ceil(sqrt(6^2 + 3^2 + 1^2)) = ceil(sqrt(46)) = 7
[
[1, 2, 5],
[6, 3, 1]
],

// ceil(sqrt(4^2 + 3^2 + 6^2 + 1)) = ceil(sqrt(62)) = 8
// ceil(sqrt(1^2 + 5^2 + 2^2 + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(4^2 + 3^2 + 6^2)) = ceil(sqrt(61)) = 8
// ceil(sqrt(1^2 + 5^2 + 2^2)) = ceil(sqrt(30)) = 6
[
[4, 3, 6],
[1, 5, 2]
],

// ceil(sqrt(2^2 + 3^2 + 3^2 + 1)) = ceil(sqrt(23)) = 5
// ceil(sqrt(2^2 + 2^2 + 1^2 + 1)) = ceil(sqrt(10)) = 4
// ceil(sqrt(2^2 + 3^2 + 3^2)) = ceil(sqrt(22)) = 5
// ceil(sqrt(2^2 + 2^2 + 1^2)) = ceil(sqrt(9)) = 3
[
[2, 3, 3],
[2, 2, 1]
],

// ceil(sqrt(6^2 + 2^2 + 3^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(2^2 + 4^2 + 2^2 + 1)) = ceil(sqrt(25)) = 5
// ceil(sqrt(6^2 + 2^2 + 3^2)) = ceil(sqrt(49)) = 7
// ceil(sqrt(2^2 + 4^2 + 2^2)) = ceil(sqrt(24)) = 5
[
[6, 2, 3],
[2, 4, 2]

@@ -8,6 +8,16 @@ func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<2x3xi3>) -> tensor<4x2x
return %0 : tensor<4x2x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<2x3x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op should have the same size on dimension #1 of operand #0 and dimension #0 of operand #1}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<4x3x!FHE.eint<2>>, tensor<2x3x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>>
return %0 : tensor<4x2x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<2x3xi3>, %y: tensor<4x3x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> {
@@ -24,6 +34,16 @@ func.func @main(%x: tensor<2x4x3x5x!FHE.eint<2>>, %y: tensor<4x3x2xi3>) -> tenso
return %0 : tensor<2x4x3x2x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<2x4x3x5x!FHE.eint<2>>, %y: tensor<4x3x2x!FHE.eint<2>>) -> tensor<2x4x3x2x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op should have the same size on dimension #3 of operand #0 and dimension #1 of operand #1}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x5x!FHE.eint<2>>, tensor<4x3x2x!FHE.eint<2>>) -> tensor<2x4x3x2x!FHE.eint<2>>
return %0 : tensor<2x4x3x2x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<2x4x3x5xi3>, %y: tensor<4x3x2x!FHE.eint<2>>) -> tensor<2x4x3x2x!FHE.eint<2>> {
@@ -40,6 +60,16 @@ func.func @main(%x: tensor<2x4x3x5x!FHE.eint<2>>, %y: tensor<10x5x2xi3>) -> tens
return %0 : tensor<2x4x3x2x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<2x4x3x5x!FHE.eint<2>>, %y: tensor<10x5x2x!FHE.eint<2>>) -> tensor<2x4x3x2x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op should have the same size or size of 1 on dimension #1 of operand #0 and dimension #0 of operand #1}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x5x!FHE.eint<2>>, tensor<10x5x2x!FHE.eint<2>>) -> tensor<2x4x3x2x!FHE.eint<2>>
return %0 : tensor<2x4x3x2x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<2x4x3x5xi3>, %y: tensor<10x5x2x!FHE.eint<2>>) -> tensor<2x4x3x2x!FHE.eint<2>> {
@@ -56,6 +86,16 @@ func.func @main(%x: tensor<2x!FHE.eint<2>>, %y: tensor<5x3x4x2xi3>) -> tensor<5x
return %0 : tensor<5x3x2x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<2x!FHE.eint<2>>, %y: tensor<5x3x4x2x!FHE.eint<2>>) -> tensor<5x3x2x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op should have the same size on dimension #0 of operand #0 and dimension #2 of operand #1}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x!FHE.eint<2>>, tensor<5x3x4x2x!FHE.eint<2>>) -> tensor<5x3x2x!FHE.eint<2>>
return %0 : tensor<5x3x2x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<2xi3>, %y: tensor<5x3x4x2x!FHE.eint<2>>) -> tensor<5x3x2x!FHE.eint<2>> {
@@ -72,6 +112,16 @@ func.func @main(%x: tensor<5x3x4x2x!FHE.eint<2>>, %y: tensor<4xi3>) -> tensor<5x
return %0 : tensor<5x3x4x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<5x3x4x2x!FHE.eint<2>>, %y: tensor<4x!FHE.eint<2>>) -> tensor<5x3x4x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op should have the same size on dimension #3 of operand #0 and dimension #0 of operand #1}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<5x3x4x2x!FHE.eint<2>>, tensor<4x!FHE.eint<2>>) -> tensor<5x3x4x!FHE.eint<2>>
return %0 : tensor<5x3x4x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<5x3x4x2xi3>, %y: tensor<4x!FHE.eint<2>>) -> tensor<5x3x4x!FHE.eint<2>> {
@@ -88,6 +138,16 @@ func.func @main(%x: tensor<4x!FHE.eint<2>>, %y: tensor<4xi3>) -> tensor<1x!FHE.e
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<4x!FHE.eint<2>>, %y: tensor<4x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op should have at least one multi dimensional tensor as an operand}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<4x!FHE.eint<2>>, tensor<4x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<4xi3>, %y: tensor<4x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -104,6 +164,16 @@ func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<3x2xi3>) -> tensor<1x!F
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<4x3x!FHE.eint<2>>, tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<4x3xi3>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -120,6 +190,16 @@ func.func @main(%x: tensor<3x!FHE.eint<2>>, %y: tensor<3x2xi3>) -> tensor<1x!FHE
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<3x!FHE.eint<2>>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<3x!FHE.eint<2>>, tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<3xi3>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -136,6 +216,16 @@ func.func @main(%x: tensor<3x!FHE.eint<2>>, %y: tensor<4x3x2xi3>) -> tensor<1x!F
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<3x!FHE.eint<2>>, %y: tensor<4x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<3x!FHE.eint<2>>, tensor<4x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<3xi3>, %y: tensor<4x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -152,6 +242,16 @@ func.func @main(%x: tensor<3x!FHE.eint<2>>, %y: tensor<4x5x3x2xi3>) -> tensor<1x
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<3x!FHE.eint<2>>, %y: tensor<4x5x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <4x5x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<3x!FHE.eint<2>>, tensor<4x5x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<3xi3>, %y: tensor<4x5x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -168,6 +268,16 @@ func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<3xi3>) -> tensor<1x!FHE
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <4>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<4x3x!FHE.eint<2>>, tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<4x3xi3>, %y: tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -184,6 +294,16 @@ func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<3xi3>) -> tensor<1x!F
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <2x4>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x!FHE.eint<2>>, tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<2x4x3xi3>, %y: tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -200,6 +320,16 @@ func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<3xi3>) -> tensor<1x
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <5x2x4>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<5x2x4x3x!FHE.eint<2>>, tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<5x2x4x3xi3>, %y: tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -216,6 +346,16 @@ func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<2x3x2xi3>) -> tensor<
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x!FHE.eint<2>>, tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<2x4x3xi3>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -232,6 +372,16 @@ func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<1x3x2xi3>) -> tensor<
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<1x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x!FHE.eint<2>>, tensor<1x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<2x4x3xi3>, %y: tensor<1x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -248,6 +398,16 @@ func.func @main(%x: tensor<1x4x3x!FHE.eint<2>>, %y: tensor<2x3x2xi3>) -> tensor<
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<1x4x3x!FHE.eint<2>>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<1x4x3x!FHE.eint<2>>, tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<1x4x3xi3>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -264,6 +424,16 @@ func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<2x3x2xi3>) -> tensor<
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x!FHE.eint<2>>, tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<4x3xi3>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -280,6 +450,16 @@ func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<3x2xi3>) -> tensor<1x
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x!FHE.eint<2>>, tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<2x4x3xi3>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -296,6 +476,16 @@ func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<5x2x3x2xi3>) -> ten
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <5x2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<5x2x4x3x!FHE.eint<2>>, tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<5x2x4x3xi3>, %y: tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -312,6 +502,16 @@ func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<2x3x2xi3>) -> tenso
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <5x2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<5x2x4x3x!FHE.eint<2>>, tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<5x2x4x3xi3>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -328,6 +528,16 @@ func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<3x2xi3>) -> tensor<
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <5x2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<5x2x4x3x!FHE.eint<2>>, tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<5x2x4x3xi3>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -344,6 +554,16 @@ func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<5x2x3x2xi3>) -> tenso
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <5x2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x!FHE.eint<2>>, tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<2x4x3xi3>, %y: tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -360,6 +580,16 @@ func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<5x2x3x2xi3>) -> tensor<
return %0 : tensor<1x!FHE.eint<2>>
}


// -----


func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <5x2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<4x3x!FHE.eint<2>>, tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

func.func @main(%x: tensor<4x3xi3>, %y: tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -376,10 +606,15 @@ func.func @main(%x: tensor<5x1x4x3x!FHE.eint<2>>, %y: tensor<2x3x2xi3>) -> tenso
return %0 : tensor<1x!FHE.eint<2>>
}


// -----

func.func @main(%x: tensor<5x1x4x3xi3>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_int_eint' op does not have the proper output shape of <5x2x4x2>}}
%0 = "FHELinalg.matmul_int_eint"(%x, %y): (tensor<5x1x4x3xi3>, tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>

func.func @main(%x: tensor<5x1x4x3x!FHE.eint<2>>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <5x2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<5x1x4x3x!FHE.eint<2>>, tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}

// -----

@@ -484,6 +484,23 @@ func.func @dot_eint_int(%arg0: tensor<2x!FHE.eint<2>>,
return %ret : !FHE.eint<2>
}


/////////////////////////////////////////////////
// FHELinalg.dot_eint_eint
/////////////////////////////////////////////////

// CHECK-LABEL: func.func @dot_eint_eint(%arg0: tensor<2x!FHE.eint<2>>, %arg1: tensor<2x!FHE.eint<2>>) -> !FHE.eint<2>
func.func @dot_eint_eint(%arg0: tensor<2x!FHE.eint<2>>,
%arg1: tensor<2x!FHE.eint<2>>) -> !FHE.eint<2>
{
// CHECK-NEXT: %[[RET:.*]] = "FHELinalg.dot_eint_eint"(%arg0, %arg1) : (tensor<2x!FHE.eint<2>>, tensor<2x!FHE.eint<2>>) -> !FHE.eint<2>
%ret = "FHELinalg.dot_eint_eint"(%arg0, %arg1) :
(tensor<2x!FHE.eint<2>>, tensor<2x!FHE.eint<2>>) -> !FHE.eint<2>

//CHECK-NEXT: return %[[RET]] : !FHE.eint<2>
return %ret : !FHE.eint<2>
}
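
Semantically, the new op is the ordinary inner product; the frontend changes further below map `np.dot` on two encrypted tensors to it. A tiny cleartext sanity check of the expected result:

```python
import numpy as np

x, y = np.array([1, 2]), np.array([3, 1])
# dot_eint_eint(x, y) decrypts to the same value the cleartext dot yields:
assert np.dot(x, y) == 1 * 3 + 2 * 1  # == 5
```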

/////////////////////////////////////////////////
// FHELinalg.conv2d
/////////////////////////////////////////////////

@@ -0,0 +1,119 @@
import argparse
import numpy as np

PRECISIONS_TO_BENCH = [(6, 2), (16, 7)]
SHAPES = [((2, 3, 4), (2, 4, 2)), ((3, 4), (4, 2)), ((3,), (3,)), ((3,), (3, 2)), ((3,), (4, 3, 2)), ((3, 4), (4,)), ((2, 3, 4), (4,)), ((2, 1, 3, 4), (5, 4, 2))]
P_ERROR = 1.0 / 1e6


def format_shape(shape):
    shape_str = "x".join(map(str, shape))
    if len(shape):
        shape_str += "x"
    else:
        shape_str = "1x"
    return shape_str


def flatten_and_to_str(data, is_tensor=True):
    payload = ", ".join(map(str, data.reshape((-1,))))
    if is_tensor:
        return "[" + payload + "]"
    return payload


def generate(op):
    assert op in {"matmul_eint_eint", "dot_eint_eint"}
    print("# /!\ DO NOT EDIT THIS FILE MANUALLY")
    print("# /!\ THIS FILE HAS BEEN GENERATED")
    for p, p_inputs in PRECISIONS_TO_BENCH:
        for shapes in SHAPES:
            for signed in [False, True]:
                min_value = 0
                max_value = (2 ** p_inputs) - 1

                inp_0 = np.random.randint(min_value, max_value + 1, size=shapes[0])
                inp_1 = np.random.randint(min_value, max_value + 1, size=shapes[1])

                expected_result = inp_0 @ inp_1

                assert np.all(expected_result < 2 ** p)
                assert np.all(expected_result >= 0)

                out_shape = expected_result.shape

                if len(out_shape) < 2 and op == "matmul_eint_eint":
                    # Matmul is only tested on operands that
                    # produce at least 2-dimensional outputs
                    continue
                elif len(out_shape) >= 1 and op == "dot_eint_eint":
                    # Dot is only tested when the output is
                    # a scalar
                    continue

                shape_0_str = format_shape(shapes[0])
                shape_1_str = format_shape(shapes[1])
                out_shape_str = format_shape(out_shape)

                dtype = "esint" if signed else "eint"

                op_outputs_scalar = op == "dot_eint_eint" and len(out_shape) == 0
                out_dtype_str = f"tensor<{out_shape_str}!FHE.{dtype}<{p}>>" if not op_outputs_scalar else f"!FHE.{dtype}<{p}>"

                program = (f"description: {op}_{p}bits_{'s' if signed else 'u'}_{shape_0_str}_{shape_1_str}\n"
                           f"program: |\n"
                           f"  func.func @main(%x: tensor<{shape_0_str}!FHE.{dtype}<{p}>>, "
                           f"%y: tensor<{shape_1_str}!FHE.{dtype}<{p}>>) -> {out_dtype_str} {{\n"
                           f"  %0 = \"FHELinalg.{op}\"(%x, %y): (tensor<{shape_0_str}!FHE.{dtype}<{p}>>, "
                           f"tensor<{shape_1_str}!FHE.{dtype}<{p}>>) -> {out_dtype_str}\n"
                           f"  return %0 : {out_dtype_str}\n"
                           f"  }}\n"
                           )

                inp_0_str = flatten_and_to_str(inp_0)
                inp_1_str = flatten_and_to_str(inp_1)
                expected_str = flatten_and_to_str(expected_result, is_tensor=not op_outputs_scalar)

                shape_0_str_yaml = ",".join(map(str, shapes[0]))
                shape_1_str_yaml = ",".join(map(str, shapes[1]))
                expected_shape_yaml = ",".join(map(str, out_shape))

                program += (f"p-error: {P_ERROR}\n"
                            "tests:\n"
                            "  - inputs: \n"
                            f"    - tensor: {inp_0_str}\n"
                            f"      shape: [{shape_0_str_yaml}]\n"
                            f"    - tensor: {inp_1_str}\n"
                            f"      shape: [{shape_1_str_yaml}]\n"
                            f"    outputs:\n"
                            )

                if op_outputs_scalar:
                    program += (
                        f"    - scalar: {expected_str}\n"
                    )
                else:
                    program += (
                        f"    - tensor: {expected_str}\n"
                        f"      shape: [{expected_shape_yaml}]\n"
                    )

                if signed:
                    program += (
                        f"      signed: True\n"
                    )

                program += "---"

                print(program)


if __name__ == "__main__":
    CLI = argparse.ArgumentParser()
    CLI.add_argument(
        "--minimal",
        help="Specify whether to generate minimal tests only",
        # store_true avoids the argparse type=bool pitfall, where any
        # non-empty string (including "False") would parse as True
        action="store_true",
    )
    args = CLI.parse_args()
    generate("matmul_eint_eint")
    generate("dot_eint_eint")
@@ -1,10 +1,21 @@
# Python Frontend

## Installation for end-users

End-users should install `concrete-python` using `pip`:

```shell
pip install concrete-python
```

## Setup for development

Developers who want to contribute to the Concrete-Python project can use the following
approach to set up their environment.

```shell
# clone the repository
git clone https://github.com/zama-ai/concrete.git
git clone https://github.com/zama-ai/concrete.git --recursive
cd concrete

# create virtual environment
@@ -19,6 +30,8 @@ cd ../../compilers/concrete-compiler/compiler
make python-bindings

# set bindings build directory as an environment variable
# *** NOTE ***: You must use the Release build of the compiler!
# For now, the Debug build is not compatible with concrete-python
export COMPILER_BUILD_DIRECTORY=$(pwd)/build
echo "export COMPILER_BUILD_DIRECTORY=$(pwd)/build" >> ~/.bashrc

@@ -26,3 +39,17 @@ echo "export COMPILER_BUILD_DIRECTORY=$(pwd)/build" >> ~/.bashrc
cd ../../../frontends/concrete-python
make pytest
```

### VSCode setup

Alternatively, you can use VSCode to develop Concrete-Python:

Suppose the compiler bindings were built in `/home/zama/concrete/compilers/concrete-compiler/compiler/build`:

- Create a `.env` file in the concrete-python root directory (a complete example is shown after this list)
- Determine the absolute path of the local compiler repository, e.g. `/home/zama/concrete`, and replace it with your
path in the following two lines
- Add to it `PYTHONPATH=$(PYTHON_PATH):/home/zama/concrete/compilers/concrete-compiler/compiler/build/tools/concretelang/python_packages/concretelang_core/`
- Add to it `LD_PRELOAD=/home/zama/concrete/compilers/concrete-compiler/compiler/build/lib/libConcretelangRuntime.so`
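
With the example path above, the resulting `.env` file would contain:

```
PYTHONPATH=$(PYTHON_PATH):/home/zama/concrete/compilers/concrete-compiler/compiler/build/tools/concretelang/python_packages/concretelang_core/
LD_PRELOAD=/home/zama/concrete/compilers/concrete-compiler/compiler/build/lib/libConcretelangRuntime.so
```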

You can now configure `pytest` in VSCode and run the tests using the graphical interface.

@@ -764,28 +764,40 @@ class Context:
            }
            self.error(highlights)

        if x.is_encrypted and y.is_encrypted:
            highlights = {
                x.origin: "lhs is encrypted",
                y.origin: (
                    "rhs is encrypted" if x.origin is not y.origin else "operand is encrypted"
                ),
                self.converting: "but encrypted-encrypted dot products are not supported",
            }
            self.error(highlights)

        assert self.is_bit_width_compatible(resulting_type, x, y)

        if x.is_scalar or y.is_scalar:
            return self.mul(resulting_type, x, y)

        x = self.to_signedness(x, of=resulting_type)
        y = self.to_signedness(y, of=resulting_type)
        operation = fhelinalg.DotEint if x.is_encrypted and y.is_encrypted else fhelinalg.Dot

        if x.is_clear:
            x, y = y, x

        return self.operation(fhelinalg.Dot, resulting_type, x.result, y.result)
        if (x.is_signed or y.is_signed) and resulting_type.is_unsigned:
            x = self.to_signed(x)
            y = self.to_signed(y)

            signed_resulting_type = self.typeof(
                Value(
                    dtype=Integer(is_signed=True, bit_width=resulting_type.bit_width),
                    shape=resulting_type.shape,
                    is_encrypted=resulting_type.is_encrypted,
                )
            )
            intermediate_result = self.operation(
                operation,
                signed_resulting_type,
                x.result,
                y.result,
            )

            return self.to_unsigned(intermediate_result)

        x = self.to_signedness(x, of=resulting_type)
        y = self.to_signedness(y, of=resulting_type)

        return self.operation(operation, resulting_type, x.result, y.result)

    def encrypt(self, resulting_type: ConversionType, x: Conversion) -> Conversion:
        assert self.is_bit_width_compatible(resulting_type, x)
@@ -1293,29 +1305,41 @@ class Context:
            }
            self.error(highlights)

        if x.is_encrypted and y.is_encrypted:
            highlights = {
                x.origin: "lhs is encrypted",
                y.origin: (
                    "rhs is encrypted" if x.origin is not y.origin else "operand is encrypted"
                ),
                self.converting: "but encrypted-encrypted matrix multiplications are not supported",
            }
            self.error(highlights)

        assert self.is_bit_width_compatible(resulting_type, x, y)

        x = self.to_signedness(x, of=resulting_type)
        y = self.to_signedness(y, of=resulting_type)

        if resulting_type.shape == ():
            if x.is_clear:
                x, y = y, x

            operation = fhelinalg.Dot
            operation = fhelinalg.DotEint if x.is_encrypted and y.is_encrypted else fhelinalg.Dot
        elif x.is_encrypted and y.is_encrypted:
            operation = fhelinalg.MatMulEintEintOp
        else:
            operation = fhelinalg.MatMulEintIntOp if x.is_encrypted else fhelinalg.MatMulIntEintOp

        if (x.is_signed or y.is_signed) and resulting_type.is_unsigned:
            x = self.to_signed(x)
            y = self.to_signed(y)

            signed_resulting_type = self.typeof(
                Value(
                    dtype=Integer(is_signed=True, bit_width=resulting_type.bit_width),
                    shape=resulting_type.shape,
                    is_encrypted=resulting_type.is_encrypted,
                )
            )
            intermediate_result = self.operation(
                operation,
                signed_resulting_type,
                x.result,
                y.result,
            )

            return self.to_unsigned(intermediate_result)

        x = self.to_signedness(x, of=resulting_type)
        y = self.to_signedness(y, of=resulting_type)

        return self.operation(operation, resulting_type, x.result, y.result)
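
With this dispatch in place, an encrypted-encrypted matrix product compiles from the Python frontend. A minimal sketch mirroring the tests further below (`encrypt_run_decrypt` is assumed to be the public concrete-python circuit API):

```python
import numpy as np
from concrete import fhe

@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def matmul(x, y):
    return x @ y

# Calibrate bit widths on random 2-bit inputs, then compile.
inputset = [
    (np.random.randint(0, 4, size=(3, 2)), np.random.randint(0, 4, size=(2, 3)))
    for _ in range(100)
]
circuit = matmul.compile(inputset)

x = np.random.randint(0, 4, size=(3, 2))
y = np.random.randint(0, 4, size=(2, 3))
assert np.array_equal(circuit.encrypt_run_decrypt(x, y), x @ y)
```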

    def maxpool2d(

@@ -65,8 +65,8 @@ def assign_precisions_1_node(node: Node, output_p: int, inputs_p: int):
CHUNKED_COMPARISON = {"greater", "greater_equal", "less", "less_equal"}
CHUNKED_COMPARISON_MIN_BITWIDTH = 4
MAX_POOLS = {"maxpool1d", "maxpool2d", "maxpool3d"}
MULTIPLY = {"multiply"}
ROUNDING = {"round_bit_pattern"}
MULTIPLY = {"multiply", "matmul"}


def max_encrypted_bitwidth_node(node: Node):
@@ -88,13 +88,38 @@ def max_encrypted_bitwidth_node(node: Node):
        return normal_p + 1

    if name in MULTIPLY and all(value.is_encrypted for value in node.inputs):
        return normal_p + 1
        # For operations that use multiply, an additional bit
        # needs to be added to the bitwidths of the inputs.
        # For single precision circuits the max of the input / output
        # precisions will be taken in required_encrypted_bitwidth. For
        # multi-precision, the circuit partitions will handle the
        # input and output precisions separately.
        all_inp_bitwidths = []
        # Need a loop here to allow typechecking and make mypy happy
        for inp in node.inputs:
            dtype_inp = inp.dtype
            assert isinstance(dtype_inp, Integer)
            all_inp_bitwidths.append(dtype_inp.bit_width)

        normal_p = max(all_inp_bitwidths)

        # FIXME: This probably does not work well with multi-precision!
        return max(normal_p + 1, node.output.dtype.bit_width)

    return normal_p
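
A worked instance of the rule above, with illustrative numbers only:

```python
# Two encrypted 4-bit inputs feed an encrypted multiply/matmul; the TLUs the
# compiler introduces need one extra bit, and the output itself needs 6 bits,
# so the node requires max(4 + 1, 6) = 6 bits.
inputs_bit_widths = [4, 4]
output_bit_width = 6
required = max(max(inputs_bit_widths) + 1, output_bit_width)
assert required == 6
```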


def required_encrypted_bitwidth(nodes: Iterable[Node]) -> int:
    """Give the minimal precision to implement all the nodes."""
    """Give the minimal precision to implement all the nodes.

    This function is called both for single-precision circuits (on the whole
    circuit) and for multi-precision circuits (on circuit partitions).

    Ops for which the compiler introduces TLUs need to be handled explicitly
    in `max_encrypted_bitwidth_node`. The maximum
    of all precisions of the various operations is returned.
    """

    bitwidths = map(max_encrypted_bitwidth_node, nodes)
    return max(bitwidths, default=-1)

@@ -23,7 +23,7 @@ def test_dot(size, helpers):
    cst = np.random.randint(0, bound, size=(size,))

    @fhe.compiler({"x": "encrypted"})
    def left_function(x):
    def dot_enc_enc_function(x):
        return np.dot(x, cst)

    @fhe.compiler({"x": "encrypted"})
@@ -36,12 +36,55 @@ def test_dot(size, helpers):

    inputset = [np.random.randint(0, bound, size=(size,)) for i in range(100)]

    left_function_circuit = left_function.compile(inputset, configuration)
    dot_enc_enc_function_circuit = dot_enc_enc_function.compile(inputset, configuration)
    right_function_circuit = right_function.compile(inputset, configuration)
    method_circuit = method.compile(inputset, configuration)

    sample = np.random.randint(0, bound, size=(size,))

    helpers.check_execution(left_function_circuit, left_function, sample)
    helpers.check_execution(dot_enc_enc_function_circuit, dot_enc_enc_function, sample)
    helpers.check_execution(right_function_circuit, right_function, sample)
    helpers.check_execution(method_circuit, method, sample)


@pytest.mark.parametrize(
    "size",
    [1, 10],
)
@pytest.mark.parametrize(
    "bitwidth",
    [2, 6],
)
@pytest.mark.parametrize("signed", [True, False])
@pytest.mark.parametrize("negative_only", [True, False])
def test_dot_enc_enc(size, bitwidth, negative_only, signed, helpers):
    """
    Test dot.
    """

    configuration = helpers.configuration()

    minv = 0 if not signed else -(2 ** (bitwidth - 1))

    # the upper bound passed to np.random.randint is exclusive, hence one past the max value
    maxv = 2**bitwidth if not signed else 2 ** (bitwidth - 1)
    if negative_only:
        maxv = 1

    @fhe.compiler({"x": "encrypted", "y": "encrypted"})
    def dot_enc_enc_function(x, y):
        return np.dot(x, y)

    inputset = [
        (np.random.randint(minv, maxv, size=(size,)), np.random.randint(minv, maxv, size=(size,)))
        for i in range(100)
    ]

    dot_enc_enc_function_circuit = dot_enc_enc_function.compile(inputset, configuration)

    sample = [
        np.random.randint(minv, maxv, size=(size,)),
        np.random.randint(minv, maxv, size=(size,)),
    ]

    helpers.check_execution(dot_enc_enc_function_circuit, dot_enc_enc_function, sample)

@@ -16,6 +16,11 @@ from concrete import fhe
            (2, 3),
            (0, 3),
        ),
        pytest.param(
            (3, 2),
            (2, 3),
            (0, 127),
        ),
        pytest.param(
            (1, 2),
            (2, 1),
@@ -46,6 +51,11 @@ from concrete import fhe
            (5, 5),
            (0, 3),
        ),
        pytest.param(
            (5,),
            (5, 5),
            (-127, 127),
        ),
        pytest.param(
            (5,),
            (5, 3),
@@ -59,7 +69,7 @@ from concrete import fhe
        pytest.param(
            (5,),
            (4, 5, 3),
            (0, 5),
            (-5, 5),
        ),
        pytest.param(
            (4, 5, 3),
@@ -74,7 +84,7 @@ from concrete import fhe
        pytest.param(
            (2, 4, 5, 3),
            (3,),
            (0, 5),
            (-1, 5),
        ),
        pytest.param(
            (5, 4, 3),
@@ -156,3 +166,200 @@ def test_matmul(lhs_shape, rhs_shape, bounds, helpers):
    helpers.check_execution(rhs_operator_circuit, rhs_operator, rhs_sample)
    helpers.check_execution(lhs_function_circuit, lhs_function, lhs_sample)
    helpers.check_execution(rhs_function_circuit, rhs_function, rhs_sample)


@pytest.mark.parametrize(
    "lhs_shape,rhs_shape,bounds",
    [
        pytest.param(
            (3, 2),
            (2, 3),
            (0, 3),
        ),
        pytest.param(
            (3, 2),
            (2, 3),
            (0, 127),
        ),
        pytest.param(
            (1, 2),
            (2, 1),
            (0, 3),
        ),
        pytest.param(
            (3, 3),
            (3, 3),
            (0, 3),
        ),
        pytest.param(
            (2, 1),
            (1, 2),
            (0, 7),
        ),
        pytest.param(
            (2,),
            (2,),
            (0, 7),
        ),
        pytest.param(
            (5, 5),
            (5,),
            (0, 3),
        ),
        pytest.param(
            (5,),
            (5, 5),
            (0, 3),
        ),
        pytest.param(
            (5,),
            (5, 5),
            (-63, 63),
        ),
        pytest.param(
            (2,),
            (2, 7),
            (-63, 0),
        ),
        pytest.param(
            (5,),
            (5, 3),
            (0, 3),
        ),
        pytest.param(
            (5, 3),
            (3,),
            (0, 3),
        ),
        pytest.param(
            (5,),
            (4, 5, 3),
            (-5, 5),
        ),
        pytest.param(
            (4, 5, 3),
            (3,),
            (0, 5),
        ),
        pytest.param(
            (5,),
            (2, 4, 5, 3),
            (0, 5),
        ),
        pytest.param(
            (2, 4, 5, 3),
            (3,),
            (-1, 5),
        ),
        pytest.param(
            (5, 4, 3),
            (3, 2),
            (0, 5),
        ),
        pytest.param(
            (4, 3),
            (5, 3, 2),
            (0, 5),
        ),
        pytest.param(
            (2, 5, 4, 3),
            (3, 2),
            (0, 5),
        ),
        pytest.param(
            (5, 4, 3),
            (1, 3, 2),
            (0, 5),
        ),
        pytest.param(
            (1, 4, 3),
            (5, 3, 2),
            (0, 5),
        ),
        pytest.param(
            (5, 4, 3),
            (2, 1, 3, 2),
            (0, 5),
        ),
        pytest.param(
            (2, 1, 4, 3),
            (5, 3, 2),
            (0, 5),
        ),
    ],
)
def test_matmul_enc_enc_and_clear(lhs_shape, rhs_shape, bounds, helpers):
    """
    Test matmul.
    """

    configuration = helpers.configuration()

    minimum, maximum = bounds

    # Matmul of clear values and encrypted matrices
    @fhe.compiler({"x": "encrypted", "y": "clear"})
    def lhs_operator_clear(x, y):
        return x @ y

    # Matmul of two encrypted matrices
    @fhe.compiler({"x": "encrypted", "y": "encrypted"})
    def enc_function_xy(x, y):
        return np.matmul(x, y)

    # Put all the dual operand functions in a list

    # FIXME: add lhs_operator_clear to this list to
    # re-enable the testing with clear values
    dual_operand_functions = [enc_function_xy]

    # Compile each dual operand function and test it on random data
    for func in dual_operand_functions:
        dual_op_inputset = [
            (
                np.random.randint(minimum, maximum, size=lhs_shape),
                np.random.randint(minimum, maximum, size=rhs_shape),
            )
            for i in range(100)
        ]
        dual_op_circuit = func.compile(dual_op_inputset, configuration)

        lhs_sample, rhs_sample = np.random.randint(
            minimum, maximum, size=lhs_shape
        ), np.random.randint(minimum, maximum, size=rhs_shape)

        helpers.check_execution(dual_op_circuit, func, [lhs_sample, rhs_sample])


@pytest.mark.parametrize("bitwidth", [4, 10])
@pytest.mark.parametrize("signed", [True, False])
def test_matmul_zero(bitwidth, signed, helpers):
    """
    Test matmul.
    """

    lhs_shape = (2, 1)
    rhs_shape = (1, 2)
    range_lhs = (-(2 ** (bitwidth - 1)), 2 ** (bitwidth - 1) - 1) if signed else (0, 2**bitwidth)

    configuration = helpers.configuration()

    # Multiply two encrypted tensors; with shapes (2, 1) and (1, 2) the
    # broadcast product is the same outer product a matmul would give
    @fhe.compiler({"x": "encrypted", "y": "encrypted"})
    def enc_function_xy(x, y):
        return x * y

    dual_op_inputset = [
        (
            np.random.randint(range_lhs[0], range_lhs[1], size=lhs_shape),
            np.zeros(rhs_shape, dtype=np.int64),
        )
        for i in range(100)
    ]
    dual_op_circuit = enc_function_xy.compile(dual_op_inputset, configuration)

    lhs_sample, rhs_sample = np.random.randint(
        range_lhs[0], range_lhs[1], size=lhs_shape
    ), np.zeros(rhs_shape, dtype=np.int64)

    helpers.check_execution(dual_op_circuit, enc_function_xy, [lhs_sample, rhs_sample])

@@ -152,6 +152,26 @@ def test_constant_mul(function, parameters, helpers):
            "x": {"range": [-10, 10], "status": "encrypted", "shape": (3, 2)},
            "y": {"range": [-10, 10], "status": "encrypted", "shape": (3, 2)},
        },
        {
            "x": {"range": [-10, 10], "status": "encrypted", "shape": (3, 1)},
            "y": {"range": [0, 0], "status": "encrypted", "shape": (3, 1)},
        },
        {
            "x": {"range": [10, 20], "status": "encrypted", "shape": (3, 1)},
            "y": {"range": [0, 0], "status": "encrypted", "shape": (1, 3)},
        },
        {
            "x": {"range": [2**12, 2**13 - 1], "status": "encrypted", "shape": (3, 1)},
            "y": {"range": [0, 0], "status": "encrypted", "shape": (1, 3)},
        },
        {
            "x": {"range": [2**12, 2**13 - 1], "status": "encrypted", "shape": (3, 1)},
            "y": {"range": [0, 2 * 3 - 1], "status": "encrypted", "shape": (1, 3)},
        },
        {
            "x": {"range": [-(2**7), 2**7 - 1], "status": "encrypted", "shape": (3, 1)},
            "y": {"range": [-(2**7), 2**7 - 1], "status": "encrypted", "shape": (1, 3)},
        },
    ],
)
def test_mul(function, parameters, helpers):

@@ -359,30 +359,6 @@ return %3

        """, # noqa: E501
        ),
        pytest.param(
            lambda x, y: np.dot(x, y),
            {"x": "encrypted", "y": "encrypted"},
            [
                (
                    np.ones(shape=(3,), dtype=np.int64),
                    np.ones(shape=(3,), dtype=np.int64),
                )
            ],
            RuntimeError,
            """

Function you are trying to compile cannot be compiled

%0 = x # EncryptedTensor<uint1, shape=(3,)> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is encrypted
%1 = y # EncryptedTensor<uint1, shape=(3,)> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is encrypted
%2 = dot(%0, %1) # EncryptedScalar<uint2> ∈ [3, 3]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but encrypted-encrypted dot products are not supported
return %2

        """, # noqa: E501
        ),
        pytest.param(
            lambda x, y: x @ y,
            {"x": "clear", "y": "clear"},
@@ -398,25 +374,6 @@ Function you are trying to compile cannot be compiled
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is clear
%2 = matmul(%0, %1) # ClearTensor<uint5, shape=(2, 2)> ∈ [5, 20]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear-clear matrix multiplications are not supported
return %2

        """, # noqa: E501
        ),
        pytest.param(
            lambda x, y: x @ y,
            {"x": "encrypted", "y": "encrypted"},
            [([[1, 2], [3, 4]], [[4, 3], [2, 1]])],
            RuntimeError,
            """

Function you are trying to compile cannot be compiled

%0 = x # EncryptedTensor<uint3, shape=(2, 2)> ∈ [1, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is encrypted
%1 = y # EncryptedTensor<uint3, shape=(2, 2)> ∈ [1, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is encrypted
%2 = matmul(%0, %1) # EncryptedTensor<uint5, shape=(2, 2)> ∈ [5, 20]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but encrypted-encrypted matrix multiplications are not supported
return %2

        """, # noqa: E501