feat(compiler): add matmul eint eint op

This commit is contained in:
Andrei Stoian
2023-04-12 17:33:14 +02:00
committed by Andrei Stoian
parent a5c679f0dc
commit 817ee6b637
23 changed files with 1522 additions and 235 deletions

3
.gitignore vendored
View File

@@ -5,3 +5,6 @@
# Jetbrains tools
.idea/
# HPX library
compilers/concrete-compiler/compiler/hpx*

View File

@@ -50,6 +50,12 @@ include_directories(${PROJECT_BINARY_DIR}/include)
link_directories(${LLVM_BUILD_LIBRARY_DIR})
add_definitions(${LLVM_DEFINITIONS})
if(DEFINED LLVM_USE_LINKER AND (NOT ${LLVM_USE_LINKER} STREQUAL ""))
message(INFO " Using custom Linker: ${CMAKE_LINKER}")
else()
message(INFO " Using standard linker")
endif()
# Custom doc generation function
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
include(AddConcretelangDoc)

View File

@@ -1,7 +1,9 @@
LLVM_PROJECT_DIR=../llvm-project
BUILD_TYPE?=Release
BUILD_DIR?=./build
# Build dir is `build` for Release builds and `build_BUILD_TYPE` for
# other build types (e.g. Debug, ReleaseWithDebugInfo, etc.)
BUILD_DIR?=./$(shell ./get_build_dir.sh "$(BUILD_TYPE)")
Python3_EXECUTABLE?=$(shell which python3)
BINDINGS_PYTHON_ENABLED=ON
DATAFLOW_EXECUTION_ENABLED=OFF
@@ -86,6 +88,15 @@ else
CXX_COMPILER_OPTION=
endif
# If the build type is Debug, and the lld linked is installed
# then use it
CUSTOM_LINKER_OPTS=
ifneq ($(shell which lld),)
ifeq ($(BUILD_TYPE),Debug)
CUSTOM_LINKER_OPTS=-DLLVM_USE_LINKER=lld
endif
endif
# don't run parallel python tests if compiler doesn't support it
ifeq ($(DATAFLOW_EXECUTION_ENABLED),ON)
PYTHON_TESTS_MARKER=""
@@ -119,6 +130,7 @@ $(BUILD_DIR)/configured.stamp:
$(CMAKE_CCACHE_OPTIONS) \
$(CC_COMPILER_OPTION) \
$(CXX_COMPILER_OPTION) \
$(CUSTOM_LINKER_OPTS) \
-DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" \
-DLLVM_BUILD_EXAMPLES=OFF \
-DLLVM_TARGETS_TO_BUILD="host" \
@@ -138,6 +150,10 @@ $(BUILD_DIR)/configured.stamp:
build-initialized: $(BUILD_DIR)/configured.stamp
reconfigure:
rm $(BUILD_DIR)/configured.stamp
rm $(BUILD_DIR)/CMakeCache.txt
doc: build-initialized
cmake --build $(BUILD_DIR) --target mlir-doc
@@ -263,7 +279,14 @@ $(FIXTURE_CPU_DIR)/%.yaml: tests/end_to_end_fixture/%_gen.py FORCE
$(FIXTURE_CPU_DIR)/bug_report.yaml:
unzip -o $(FIXTURE_CPU_DIR)/bug_report.zip -d $(FIXTURE_CPU_DIR)
generate-cpu-tests: $(FIXTURE_CPU_DIR)/end_to_end_leveled.yaml $(FIXTURE_CPU_DIR)/end_to_end_apply_lookup_table.yaml $(FIXTURE_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml $(FIXTURE_CPU_DIR)/bug_report.yaml $(FIXTURE_CPU_DIR)/end_to_end_round.yaml $(FIXTURE_CPU_DIR)/end_to_end_multi_precision.yaml
generate-cpu-tests: \
$(FIXTURE_CPU_DIR)/end_to_end_leveled.yaml \
$(FIXTURE_CPU_DIR)/end_to_end_apply_lookup_table.yaml \
$(FIXTURE_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml \
$(FIXTURE_CPU_DIR)/bug_report.yaml \
$(FIXTURE_CPU_DIR)/end_to_end_round.yaml \
$(FIXTURE_CPU_DIR)/end_to_end_multi_precision.yaml \
$(FIXTURE_CPU_DIR)/end_to_end_linalg_enc_enc_matmul_dot.yaml
SECURITY_TO_TEST=128
OPTIMIZATION_STRATEGY_TO_TEST=dag-mono dag-multi
@@ -275,7 +298,8 @@ run-end-to-end-tests: $(GTEST_PARALLEL_PY) build-end-to-end-tests generate-cpu-t
$(GTEST_PARALLEL_CMD) $(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/$(TEST);)
$(foreach optimizer_strategy,$(OPTIMIZATION_STRATEGY_TO_TEST), $(foreach security,$(SECURITY_TO_TEST), \
$(GTEST_PARALLEL_CMD) $(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/end_to_end_test \
$(GTEST_PARALLEL_SEPARATOR) --backend=cpu --security-level=$(security) --optimizer-strategy=$(optimizer_strategy) --jit $(FIXTURE_CPU_DIR)/*.yaml;))
$(GTEST_PARALLEL_SEPARATOR) --backend=cpu --security-level=$(security) \
--optimizer-strategy=$(optimizer_strategy) --jit $(FIXTURE_CPU_DIR)/*.yaml;))
### end-to-end-tests GPU
@@ -542,4 +566,5 @@ FORCE:
opt \
mlir-opt \
mlir-cpu-runner \
mlir-translate
mlir-translate \
reconfigure

View File

@@ -70,6 +70,12 @@ Run the compiler
./build-Release/bin/concretecompiler
```
#### Debug build and custom linker
To build a debug version of the project, you can set `BUILD_TYPE=Debug` in the `Makefile`. In `Debug`
the build system will detect if the `lld` linker is installed on the system and use it. `lld` is much faster
than the default `ld` linker. Release builds with `lld` can also be enabled by modifying the `Makefile`.
### Installation from source
You can install libs, bins, and include files into a specific directory by running:

View File

@@ -0,0 +1,10 @@
#!/bin/env bash
BUILD_TYPE=$1
if [[ ${BUILD_TYPE,,} = "release" ]]; then
echo "build"
else
echo "build_${BUILD_TYPE}"
fi

View File

@@ -25,6 +25,11 @@ bool verifyEncryptedIntegerAndIntegerInputsConsistency(Operation &op,
FheIntegerInterface &a,
IntegerType &b);
// Checks the consistency between two integer inputs of an operation
bool verifyEncryptedIntegerInputsConsistency(mlir::Operation &op,
FheIntegerInterface &a,
FheIntegerInterface &b);
/// Shared error message for all ApplyLookupTable variant Op (several Dialect)
/// E.g. FHE.apply_lookup_table(input, lut)
/// Message when the lut tensor has an invalid size,

View File

@@ -587,12 +587,37 @@ def FHELinalg_Dot : FHELinalg_Op<"dot_eint_int"> {
let hasVerifier = 1;
}
def FHELinalg_DotEint : FHELinalg_Op<"dot_eint_eint"> {
let summary = "Returns the encrypted dot product between two vectors of encrypted integers.";
let description = [{
Performs a dot product between two vectors of encrypted integers.
Examples:
```mlir
// Returns the dot product of `%a0` with `%a1`
"FHELinalg.dot_eint_eint"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4x!FHE.eint<4>>) -> !FHE.eint<4>
```
}];
let arguments = (ins
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$lhs,
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$rhs);
let results = (outs FHE_AnyEncryptedInteger:$out);
let hasVerifier = 1;
}
def FHELinalg_MatMulEintIntOp : FHELinalg_Op<"matmul_eint_int", [TensorBinaryEintInt]> {
let summary = "Returns a tensor that contains the result of the matrix multiplication of a matrix of encrypted integers and a matrix of clear integers.";
let description = [{
Performs a matrix multiplication of a matrix of encrypted integers and a matrix of clear integers.
The width of the clear integers must be less than or equals to the witdh of encrypted integers.
The width of the clear integers must be less than or equals to the width of encrypted integers.
The behavior depends on the arguments in the following way:
@@ -730,7 +755,7 @@ def FHELinalg_MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [TensorBinaryInt
let description = [{
Performs a matrix multiplication of a matrix of clear integers and a matrix of encrypted integers.
The width of the clear integers must be less than or equals to the witdh of encrypted integers.
The width of the clear integers must be less than or equals to the width of encrypted integers.
The behavior depends on the arguments in the following way:
@@ -863,6 +888,144 @@ def FHELinalg_MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [TensorBinaryInt
let hasVerifier = 1;
}
def FHELinalg_MatMulEintEintOp : FHELinalg_Op<"matmul_eint_eint", [TensorBinaryEint]> {
let summary = "Returns a tensor that contains the result of the matrix multiplication of a matrix of encrypted integers and a second matrix of encrypted integers.";
let description = [{
Performs a matrix multiplication of a matrix of encrypted integers and a second matrix of encrypted integers.
The behavior depends on the arguments in the following way:
- If both arguments are 2-D,
they are multiplied like conventional matrices.
e.g.,
arg0: tensor<MxN> = [...]
arg1: tensor<NxP> = [...]
result: tensor<MxP> = [...]
- If the first argument is a vector (1-D),
it is treated as a matrix with a single row and standard matrix multiplication is performed.
After standard matrix multiplication,
the first dimension is removed from the result.
e.g.,
arg0: tensor<3> = [x, y, z]
arg1: tensor<3xM> = [
[_, _, ..., _, _],
[_, _, ..., _, _],
[_, _, ..., _, _],
]
is treated as
arg0: tensor<1x3> = [
[x, y, z]
]
arg1: tensor<3xM> = [
[_, _, ..., _, _],
[_, _, ..., _, _],
[_, _, ..., _, _],
]
and matrix multiplication is performed with the following form (1x3 @ 3xM -> 1xM)
result: tensor<1xM> = [[_, _, ..., _, _]]
finally, the first dimension is removed by definition so the result has the following form
result: tensor<M> = [_, _, ..., _, _]
- If the second argument is 1-D,
it is treated as a matrix with a single column and standard matrix multiplication is performed.
After standard matrix multiplication,
the last dimension is removed from the result.
e.g.,
arg0: tensor<Mx3> = [
[_, _, _],
[_, _, _],
...,
[_, _, _],
[_, _, _],
]
arg1: tensor<3> = [x, y, z]
is treated as
arg0: tensor<Mx3> = [
[_, _, _],
[_, _, _],
...,
[_, _, _],
[_, _, _],
]
arg1: tensor<3x1> = [
[x],
[y],
[z],
]
and matrix multiplication is performed with the following form (Mx3 @ 3x1 -> Mx1)
result: tensor<Mx1> = [
[_],
[_],
...,
[_],
[_],
]
finally, the last dimension is removed by definition so the result has the following form
result: tensor<M> = [_, _, _]
- If either argument is N-D where N > 2,
the operation is treated as a collection of matrices residing in the last two indices and broadcasted accordingly.
arg0: tensor<Kx1MxN> = [...]
arg1: tensor<LxNxP> = [...]
result: tensor<KxLxMxP> = [...]
```mlir
"FHELinalg.matmul_eint_eint(%a, %b) : (tensor<MxNx!FHE.eint<p>>, tensor<NxPx!FHE.eint<p>'>) -> tensor<MxPx!FHE.eint<p>>"
"FHELinalg.matmul_eint_eint(%a, %b) : (tensor<KxLxMxNx!FHE.eint<p>>, tensor<KxLxNxPx!FHE.eint<p>'>) -> tensor<KxLxMxPx!FHE.eint<p>>"
"FHELinalg.matmul_eint_eint(%a, %b) : (tensor<MxNx!FHE.eint<p>>, tensor<Nx!FHE.eint<p>'>) -> tensor<Mx!FHE.eint<p>>"
"FHELinalg.matmul_eint_eint(%a, %b) : (tensor<Nx!FHE.eint<p>>, tensor<NxPx!FHE.eint<p>'>) -> tensor<Px!FHE.eint<p>>"
```
Examples:
```mlir
// Returns the matrix multiplication of a 3x2 matrix of encrypted integers and a 2x3 matrix of integers.
// [ 1, 2, 3]
// [ 2, 3, 4]
// *
// [1,2] [ 5, 8,11]
// [3,4] = [11,18,25]
// [5,6] [17,28,39]
//
"FHELinalg.matmul_eint_eint"(%a, %b) : (tensor<3x2x!FHE.eint<6>>, tensor<2x3x!FHE.eint<6>>) -> tensor<3x3x!FHE.eint<12>>
```
}];
let arguments = (ins
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$rhs
);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let hasVerifier = 1;
}
def FHELinalg_SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> {
let summary = "Returns the sum of elements of a tensor of encrypted integers along specified axes.";

View File

@@ -45,11 +45,16 @@ inline void forwardOptimizerID(mlir::Operation *source,
destination->setAttr("TFHE.OId", optimizerIdAttr);
}
struct DotToLinalgGeneric
: public ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::Dot> {
DotToLinalgGeneric(::mlir::MLIRContext *context)
: ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::Dot>(
context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
template <typename DotOp, typename FHEMulOp>
struct DotToLinalgGeneric : public ::mlir::OpRewritePattern<DotOp> {
DotToLinalgGeneric(
::mlir::MLIRContext *context,
std::function<FHEMulOp(mlir::OpBuilder &, mlir::Location, mlir::Type,
mlir::Value, mlir::Value)>
createMulOp)
: ::mlir::OpRewritePattern<DotOp>(
context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT),
createMulOp(createMulOp) {}
/// This rewrite pattern transforms any instance of
/// `FHELinalg.dot_eint_int` to an instance of `linalg.generic` with an
@@ -87,7 +92,7 @@ struct DotToLinalgGeneric
/// %o = tensor.extract %1[%c0] : tensor<1x!FHE.eint<0>>
///
::mlir::LogicalResult
matchAndRewrite(::mlir::concretelang::FHELinalg::Dot dotOp,
matchAndRewrite(DotOp dotOp,
::mlir::PatternRewriter &rewriter) const override {
auto zeroTensorOp = rewriter.create<mlir::concretelang::FHE::ZeroTensorOp>(
@@ -111,9 +116,9 @@ struct DotToLinalgGeneric
auto regBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
mlir::concretelang::FHE::MulEintIntOp mul =
nestedBuilder.create<mlir::concretelang::FHE::MulEintIntOp>(
dotOp.getLoc(), blockArgs[0], blockArgs[1]);
auto mul = this->createMulOp(nestedBuilder, dotOp.getLoc(),
dotOp.getResult().getType(), blockArgs[0],
blockArgs[1]);
forwardOptimizerID(dotOp, mul);
mlir::concretelang::FHE::AddEintOp add =
nestedBuilder.create<mlir::concretelang::FHE::AddEintOp>(
@@ -140,6 +145,11 @@ struct DotToLinalgGeneric
return ::mlir::success();
};
private:
std::function<FHEMulOp(mlir::OpBuilder &, mlir::Location, mlir::Type,
mlir::Value, mlir::Value)>
createMulOp;
};
mlir::AffineMap
@@ -826,13 +836,13 @@ struct FHELinalgNegEintToLinalgGeneric
/// linalg.yield %e : !FHE.eint<p>
/// }
///
template <typename FHELinalgMatmulOp>
template <typename FHELinalgMatmulOp, typename FHEMulOp>
struct FHELinalgMatmulToLinalgGeneric
: public mlir::OpRewritePattern<FHELinalgMatmulOp> {
FHELinalgMatmulToLinalgGeneric(
mlir::MLIRContext *context,
std::function<FHE::MulEintIntOp(mlir::OpBuilder &, mlir::Location,
mlir::Type, mlir::Value, mlir::Value)>
std::function<FHEMulOp(mlir::OpBuilder &, mlir::Location, mlir::Type,
mlir::Value, mlir::Value)>
createMulOp,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
@@ -1091,8 +1101,8 @@ struct FHELinalgMatmulToLinalgGeneric
};
private:
std::function<FHE::MulEintIntOp(mlir::OpBuilder &, mlir::Location, mlir::Type,
mlir::Value, mlir::Value)>
std::function<FHEMulOp(mlir::OpBuilder &, mlir::Location, mlir::Type,
mlir::Value, mlir::Value)>
createMulOp;
};
@@ -2195,7 +2205,21 @@ void FHETensorOpsToLinalg::runOnOperation() {
target.addLegalOp<bufferization::AllocTensorOp>();
mlir::RewritePatternSet patterns(&getContext());
patterns.insert<DotToLinalgGeneric>(&getContext());
patterns.insert<DotToLinalgGeneric<mlir::concretelang::FHELinalg::Dot,
mlir::concretelang::FHE::MulEintIntOp>>(
&getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
return builder.create<mlir::concretelang::FHE::MulEintIntOp>(
loc, type, arg0, arg1);
});
patterns.insert<DotToLinalgGeneric<mlir::concretelang::FHELinalg::DotEint,
mlir::concretelang::FHE::MulEintOp>>(
&getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
return builder.create<mlir::concretelang::FHE::MulEintOp>(loc, type,
arg0, arg1);
});
patterns.insert<
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::AddEintOp,
mlir::concretelang::FHE::AddEintOp>>(
@@ -2227,19 +2251,29 @@ void FHETensorOpsToLinalg::runOnOperation() {
patterns.insert<FHELinalgApplyLookupTableToLinalgGeneric>(&getContext());
patterns.insert<FHELinalgNegEintToLinalgGeneric>(&getContext());
patterns.insert<FHELinalgMatmulToLinalgGeneric<
mlir::concretelang::FHELinalg::MatMulEintIntOp>>(
mlir::concretelang::FHELinalg::MatMulEintIntOp,
mlir::concretelang::FHE::MulEintIntOp>>(
&getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
return builder.create<mlir::concretelang::FHE::MulEintIntOp>(
loc, type, arg0, arg1);
});
patterns.insert<FHELinalgMatmulToLinalgGeneric<
mlir::concretelang::FHELinalg::MatMulIntEintOp>>(
mlir::concretelang::FHELinalg::MatMulIntEintOp,
mlir::concretelang::FHE::MulEintIntOp>>(
&getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
return builder.create<mlir::concretelang::FHE::MulEintIntOp>(
loc, type, arg1, arg0);
});
patterns.insert<FHELinalgMatmulToLinalgGeneric<
mlir::concretelang::FHELinalg::MatMulEintEintOp,
mlir::concretelang::FHE::MulEintOp>>(
&getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
return builder.create<mlir::concretelang::FHE::MulEintOp>(loc, type,
arg1, arg0);
});
patterns.insert<FHELinalgApplyMultiLookupTableToLinalgGeneric>(&getContext());
patterns.insert<FHELinalgApplyMappedLookupTableToLinalgGeneric>(
&getContext());

View File

@@ -159,6 +159,10 @@ struct FunctionToDag {
DEBUG("Replace Dot by LevelledOp on " << op);
index = addLevelledOp(dag, op, encrypted_inputs);
}
} else if (auto dot = asDotEint(op)) {
index = addDotEint(dag, dot, encrypted_inputs, precision);
// The above function call sets the OIds, can return right away
return;
} else if (auto mul = asMul(op)) {
// special case as mul are rewritten in several optimizer nodes
addMul(dag, mul, encrypted_inputs, precision);
@@ -175,6 +179,10 @@ struct FunctionToDag {
// special case as max are rewritten in several optimizer nodes
addMaxpool2d(dag, maxpool2d, encrypted_inputs, precision);
return;
} else if (auto matmulEintEint = asMatmulEintEint(op)) {
index =
addEncMatMulTensor(dag, matmulEintEint, encrypted_inputs, precision);
return;
} else {
index = addLevelledOp(dag, op, encrypted_inputs);
}
@@ -352,7 +360,7 @@ struct FunctionToDag {
dag->add_levelled_op(slice(subInputs), lweDimCostFactor, fixedCost,
tluSubManp, slice(resultShape), comment);
index[result] = resultNode;
// Set attribute on the MLIR node
mlir::Builder builder(mulOp.getContext());
mlir::SmallVector<int32_t, 5> operatorIndexes = {
(int32_t)addNode.index, (int32_t)lhsTluNode.index,
@@ -362,10 +370,193 @@ struct FunctionToDag {
// We push that at the end by convention
operatorIndexes.push_back(lhsCorrectionNode.value().index);
}
if (setOptimizerID)
mulOp->setAttr("TFHE.OId", builder.getDenseI32ArrayAttr(operatorIndexes));
}
template <typename InnerProductOp>
concrete_optimizer::dag::OperatorIndex
addTensorInnerProductEncEnc(optimizer::Dag &dag,
InnerProductOp &innerProductOp, Inputs &inputs,
int precision) {
mlir::Value result = innerProductOp.getResult();
const std::vector<uint64_t> resultShape = getShape(result);
// We assume a first tensorized matmul step
// is the construction of matrices of:
// - sums of all pairs
// - differences of all pairs
// where pairs are pairs of values that are to be multiplied together
// Compute the number of elements in each of the
// matrices of pairs
auto lhsType = ((mlir::Type)innerProductOp.getLhs().getType())
.cast<mlir::RankedTensorType>();
auto rhsType = ((mlir::Type)innerProductOp.getRhs().getType())
.cast<mlir::RankedTensorType>();
std::vector<int64_t> lhsShape = lhsType.getShape();
std::vector<int64_t> rhsShape = rhsType.getShape();
if (rhsShape.size() == 1)
rhsShape.push_back(1);
if (lhsShape.size() == 1)
lhsShape.emplace(lhsShape.begin(), 1);
int64_t rhsDims = (int64_t)rhsShape.size();
int64_t lhsDims = (int64_t)lhsShape.size();
// Suppose lhsDims is (5, 3, 2) -> 5 matrices of size 3x2 (2 is the
// reduction dimension) and rhsDims is (3, 5, 2, 3) -> 3x 5 matrices of
// size 2x3 the pair matrix would have size (3, 5, 3, 3, 2) this is the
// shape of the matrix onto which we apply the TLUs that compute the
// multiplication of all pairs of values
// The RHS can be a (N,) matrix, the outer dimension is supposed to be 1
int64_t rhsOuterDim = rhsShape[rhsDims - 1];
int64_t lhsOuterDim = lhsShape[lhsDims - 2];
std::vector<uint64_t> pairMatrixShape;
// Compute the output matrix dimension
// Corresponding dimensions that are considered "compatible" are "N, 1", "1,
// N", "N, N"
int64_t rhsDimIter = rhsDims - 3, lhsDimIter = lhsDims - 3;
if (rhsDimIter >= 0 && lhsDimIter >= 0) {
while (rhsDimIter >= 0 && lhsDimIter >= 0 &&
(lhsShape[lhsDimIter] == rhsShape[rhsDimIter] ||
lhsShape[lhsDimIter] == 1 || rhsShape[rhsDimIter] == 1)) {
pairMatrixShape.push_back(
std::max(rhsShape[rhsDimIter], lhsShape[lhsDimIter]));
--lhsDimIter;
--rhsDimIter;
}
}
assert((lhsDimIter < 0 || rhsDimIter < 0) &&
"Bad dimensions given to matmul or dot");
while (lhsDimIter >= 0) {
pairMatrixShape.push_back(lhsShape[lhsDimIter]);
--lhsDimIter;
}
while (rhsDimIter >= 0) {
pairMatrixShape.push_back(rhsShape[rhsDimIter]);
--rhsDimIter;
}
// Add the outer dimensions of the individual matrices
pairMatrixShape.push_back(lhsOuterDim);
pairMatrixShape.push_back(rhsOuterDim);
// Add the reduction dimension
// The number of elements in the dot product
// is the number of cells on the reduction axis (aka "destroyed dimension")
int64_t reductionDimSize = rhsShape[rhsDims - 2];
assert(lhsShape[lhsDims - 1] == reductionDimSize);
pairMatrixShape.push_back(reductionDimSize);
// Compute the manp of the various steps
// in the matmul of enc x enc:
// 1. (x + y) and (x - y) -> supposing broadcasting is used
// to tensorize this operation
Operation *xOp = innerProductOp.getLhs().getDefiningOp();
Operation *yOp = innerProductOp.getRhs().getDefiningOp();
const double fixedCost = NEGLIGIBLE_COMPLEXITY;
const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;
llvm::APInt xSmanp = llvm::APInt{1, 1, false};
if (xOp != nullptr) {
const auto xSmanpAttr = xOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
assert(xSmanpAttr && "Missing SMANP value on a crypto operation");
xSmanp = xSmanpAttr.getValue();
}
llvm::APInt ySmanp = llvm::APInt{1, 1, false};
if (yOp != nullptr) {
const auto ySmanpAttr = yOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
assert(ySmanpAttr && "Missing SMANP value on a crypto operation");
ySmanp = ySmanpAttr.getValue();
}
auto loc = loc_to_string(innerProductOp.getLoc());
auto comment =
std::string(innerProductOp->getName().getStringRef()) + " " + loc;
// (x + y) and (x - y)
const double addSubManp =
sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble());
auto addNode =
dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost,
addSubManp, slice(pairMatrixShape), comment);
auto subNode =
dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost,
addSubManp, slice(pairMatrixShape), comment);
// tlu(x - y), tlu(x + y)
const std::vector<std::uint64_t> unknownFunction;
auto lhsTluNode = dag->add_lut(addNode, slice(unknownFunction), precision);
auto rhsTluNode = dag->add_lut(subNode, slice(unknownFunction), precision);
// 3. Sum(tlu(x + y) - tlu(x - y))
// Create a leveled op that simulates concatenation. It takes
// as inputs all the intermediary dot product results and produces
// the output tensor
// Default complexity is negligible
double fixed_cost = NEGLIGIBLE_COMPLEXITY;
double lwe_dim_cost_factor = NEGLIGIBLE_COMPLEXITY;
// For the output of the operation, take the MANP from the MANP pass
mlir::Operation *op = innerProductOp.getOperation();
mlir::IntegerAttr smanp_int = op->getAttrOfType<mlir::IntegerAttr>("SMANP");
assert(smanp_int && "Missing manp value on a crypto operation");
Inputs tluOpIndices{lhsTluNode, rhsTluNode};
// TODO: use APIFloat.sqrt when it's available
double manp = sqrt(smanp_int.getValue().roundToDouble());
index[result] =
dag->add_levelled_op(slice(tluOpIndices), lwe_dim_cost_factor,
fixed_cost, manp, slice(resultShape), comment);
mlir::Builder builder(innerProductOp.getContext());
mlir::SmallVector<int32_t, 5> operatorIndexes = {
(int32_t)addNode.index, (int32_t)lhsTluNode.index,
(int32_t)subNode.index, (int32_t)rhsTluNode.index,
(int32_t)index[result].index};
if (setOptimizerID)
innerProductOp->setAttr("TFHE.OId",
builder.getDenseI32ArrayAttr(operatorIndexes));
return index[result];
}
concrete_optimizer::dag::OperatorIndex
addEncMatMulTensor(optimizer::Dag &dag, FHELinalg::MatMulEintEintOp &matmulOp,
Inputs &inputs, int precision) {
return addTensorInnerProductEncEnc<FHELinalg::MatMulEintEintOp>(
dag, matmulOp, inputs, precision);
}
concrete_optimizer::dag::OperatorIndex addDotEint(optimizer::Dag &dag,
FHELinalg::DotEint &dotOp,
Inputs &inputs,
int precision) {
return addTensorInnerProductEncEnc<FHELinalg::DotEint>(dag, dotOp, inputs,
precision);
}
void addMax(optimizer::Dag &dag, FHE::MaxEintOp &maxOp, Inputs &inputs,
int precision) {
mlir::Value result = maxOp.getResult();
@@ -547,6 +738,10 @@ struct FunctionToDag {
return llvm::dyn_cast<mlir::concretelang::FHELinalg::Dot>(op);
}
mlir::concretelang::FHELinalg::DotEint asDotEint(mlir::Operation &op) {
return llvm::dyn_cast<mlir::concretelang::FHELinalg::DotEint>(op);
}
mlir::concretelang::FHE::MulEintOp asMul(mlir::Operation &op) {
return llvm::dyn_cast<mlir::concretelang::FHE::MulEintOp>(op);
}
@@ -563,6 +758,11 @@ struct FunctionToDag {
return llvm::dyn_cast<mlir::concretelang::FHELinalg::Maxpool2dOp>(op);
}
mlir::concretelang::FHELinalg::MatMulEintEintOp
asMatmulEintEint(mlir::Operation &op) {
return llvm::dyn_cast<mlir::concretelang::FHELinalg::MatMulEintEintOp>(op);
}
bool isReturn(mlir::Operation &op) {
return llvm::isa<mlir::func::ReturnOp>(op);
}

View File

@@ -341,6 +341,48 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::Dot op,
}
}
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHELinalg.dot_eint_eint` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::DotEint op,
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
assert(operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().has_value() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
"operands");
llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().value();
llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().value();
auto rhsType =
((mlir::Type)op.getRhs().getType()).cast<mlir::RankedTensorType>();
llvm::ArrayRef<int64_t> rhsShape = rhsType.getShape();
int64_t rhsDims = (int64_t)rhsShape.size();
assert(rhsDims == 1 && "In MANP computation dot product RHS expected to have "
"a single dimension");
int64_t N = rhsShape[0];
// Compute output MANP:
// Tlu output MANP is 1
llvm::APInt tlu = {1, 1, false};
// The element-wise multiplication is given by the
// subtraction of two TLU outputs. The MANP of the multiplication is thus
// the sum of the TLU MANPs
llvm::APInt elemMulNorm = APIntWidthExtendUAdd(tlu, tlu);
llvm::APInt accNorm = llvm::APInt{1, 0, false};
// For the total Dot product MANP, take the manp of the sum of products
for (int64_t i = 0; i < N; i++) {
accNorm = APIntWidthExtendUAdd(elemMulNorm, accNorm);
}
return accNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of an
/// `FHE.add_eint_int` operation.
static llvm::APInt getSqMANP(mlir::concretelang::FHE::AddEintIntOp op,
@@ -522,12 +564,12 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHE::MulEintOp op,
const llvm::APInt x = operandMANPs[0]->getValue().getMANP().value();
const llvm::APInt y = operandMANPs[1]->getValue().getMANP().value();
const llvm::APInt beforeTLUs = APIntWidthExtendUAdd(x, y);
// The MANP of this operation is simply the MANP after the TLUs
// which is equal to the sum of outputs of 2 TLUs
const llvm::APInt tlu = {1, 1, false};
const llvm::APInt result = APIntWidthExtendUAdd(tlu, tlu);
// this is not optimal as it can increase the resulting noise unnecessarily
return APIntUMax(beforeTLUs, result);
return result;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
@@ -742,12 +784,12 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MulEintOp op,
const llvm::APInt x = operandMANPs[0]->getValue().getMANP().value();
const llvm::APInt y = operandMANPs[1]->getValue().getMANP().value();
const llvm::APInt beforeTLUs = APIntWidthExtendUAdd(x, y);
// The MANP of this operation is simply the MANP after the TLUs
// which is equal to the sum of outputs of 2 TLUs
const llvm::APInt tlu = {1, 1, false};
const llvm::APInt result = APIntWidthExtendUAdd(tlu, tlu);
// this is not optimal as it can increase the resulting noise unnecessarily
return APIntUMax(beforeTLUs, result);
return result;
}
static llvm::APInt computeVectorNorm(
@@ -755,7 +797,10 @@ static llvm::APInt computeVectorNorm(
mlir::DenseIntElementsAttr denseValues, llvm::APInt encryptedOperandNorm,
llvm::SmallVector<uint64_t, /*size-hint=*/4> &elementSelector) {
llvm::APInt accumulationNorm = llvm::APInt{1, 1, false};
// The accumulator is initialized with 0s in all bits (not the encrypted 0)
// the there is no initial noise in the accumulator, so its
// MANP is initialized to 0
llvm::APInt accumulationNorm = llvm::APInt{1, 0, false};
for (int64_t i = 0; i < shape[axis]; i++) {
elementSelector[axis] = i;
@@ -845,7 +890,7 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulEintIntOp op,
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().value();
llvm::APInt accNorm = llvm::APInt{1, 1, false};
llvm::APInt accNorm = llvm::APInt{1, 0, false};
mlir::arith::ConstantOp cstOp =
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
@@ -866,7 +911,10 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulEintIntOp op,
int64_t P = rhsShape[1];
for (int64_t m = 0; m < M; m++) {
for (int64_t p = 0; p < P; p++) {
llvm::APInt tmpNorm = llvm::APInt{1, 1, false};
// The accumulator is initialized with 0s in all bits (not the
// encrypted 0) the there is no initial noise in the accumulator, so
// its MANP is initialized to 0
llvm::APInt tmpNorm = llvm::APInt{1, 0, false};
for (int64_t n = 0; n < N; n++) {
llvm::APInt cst = denseValsAP[{(uint64_t)n, (uint64_t)p}];
llvm::APInt rhsNorm = APIntWidthExtendSqForConstant(cst);
@@ -943,7 +991,7 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulIntEintOp op,
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().value();
llvm::APInt accNorm = llvm::APInt{1, 1, false};
llvm::APInt accNorm = llvm::APInt{1, 0, false};
mlir::arith::ConstantOp cstOp =
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
@@ -1018,6 +1066,46 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulIntEintOp op,
return accNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a matmul
/// operation
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulEintEintOp op,
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
auto rhsType =
((mlir::Type)op.getRhs().getType()).cast<mlir::RankedTensorType>();
llvm::ArrayRef<int64_t> rhsShape = rhsType.getShape();
int64_t rhsDims = (int64_t)rhsShape.size();
assert(
operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().has_value() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().value();
llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().value();
int64_t N = rhsDims <= 2 ? rhsShape[0] : rhsShape[rhsDims - 2];
// Compute MANP of a single matrix cell x matrix cell multiplication
// This is used later to compute the MANP of an entire dot product
llvm::APInt tlu = {1, 1, false};
llvm::APInt elemMulNorm = APIntWidthExtendUAdd(tlu, tlu);
llvm::APInt accNorm = llvm::APInt{1, 0, false};
// For the total MatMul MANP, take the MANP of a single
// column-row dot-product
// All such dot-products produce the same MANP, there
// is no need to take the maximum over the dot-products
for (int64_t i = 0; i < N; i++) {
accNorm = APIntWidthExtendUAdd(elemMulNorm, accNorm);
}
return accNorm;
}
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::TransposeOp op,
llvm::ArrayRef<const MANPLattice *> operandMANPs) {
@@ -1373,6 +1461,9 @@ public:
else if (auto dotOp =
llvm::dyn_cast<mlir::concretelang::FHELinalg::Dot>(op)) {
norm2SqEquiv = getSqMANP(dotOp, operands);
} else if (auto dotEintOp =
llvm::dyn_cast<mlir::concretelang::FHELinalg::DotEint>(op)) {
norm2SqEquiv = getSqMANP(dotEintOp, operands);
} else if (auto addEintIntOp =
llvm::dyn_cast<mlir::concretelang::FHELinalg::AddEintIntOp>(
op)) {
@@ -1419,6 +1510,9 @@ public:
} else if (auto matmulIntEintOp = llvm::dyn_cast<
mlir::concretelang::FHELinalg::MatMulIntEintOp>(op)) {
norm2SqEquiv = getSqMANP(matmulIntEintOp, operands);
} else if (auto matmulEintEintOp = llvm::dyn_cast<
mlir::concretelang::FHELinalg::MatMulEintEintOp>(op)) {
norm2SqEquiv = getSqMANP(matmulEintEintOp, operands);
} else if (llvm::isa<
mlir::concretelang::FHELinalg::ApplyLookupTableEintOp,
mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp,

View File

@@ -401,35 +401,73 @@ mlir::LogicalResult ApplyMappedLookupTableEintOp::verify() {
verifyLutsSize(*this, t, luts).succeeded());
}
::mlir::LogicalResult Dot::verify() {
if (::mlir::failed(mlir::verifyCompatibleShape(this->getLhs().getType(),
this->getRhs().getType()))) {
return this->emitOpError("arguments have incompatible shapes");
}
auto lhsEltType = this->getLhs()
.getType()
.cast<mlir::TensorType>()
.getElementType()
.dyn_cast<FHE::FheIntegerInterface>();
auto rhsEltType = this->getRhs()
.getType()
.cast<mlir::TensorType>()
.getElementType()
.cast<mlir::IntegerType>();
auto resultType =
this->getResult().getType().dyn_cast<FHE::FheIntegerInterface>();
if (!mlir::concretelang::FHE::
verifyEncryptedIntegerAndIntegerInputsConsistency(
*this->getOperation(), lhsEltType, rhsEltType)) {
mlir::LogicalResult
verifyDotInputsOutputsConsistency(mlir::concretelang::FHELinalg::DotEint &op,
FHE::FheIntegerInterface &lhsEltType,
FHE::FheIntegerInterface &rhsEltType,
FHE::FheIntegerInterface &resultType) {
if (!mlir::concretelang::FHE::verifyEncryptedIntegerInputsConsistency(
*op.getOperation(), lhsEltType, rhsEltType)) {
return ::mlir::failure();
}
if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(
*this->getOperation(), lhsEltType, resultType)) {
*op.getOperation(), lhsEltType, resultType)) {
return ::mlir::failure();
}
return ::mlir::success();
}
mlir::LogicalResult
verifyDotInputsOutputsConsistency(mlir::concretelang::FHELinalg::Dot &op,
FHE::FheIntegerInterface &lhsEltType,
mlir::IntegerType &rhsEltType,
FHE::FheIntegerInterface &resultType) {
if (!mlir::concretelang::FHE::
verifyEncryptedIntegerAndIntegerInputsConsistency(
*op.getOperation(), lhsEltType, rhsEltType)) {
return ::mlir::failure();
}
if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(
*op.getOperation(), lhsEltType, resultType)) {
return ::mlir::failure();
}
return ::mlir::success();
}
// Verify a dot product operation:
// - check that the shapes are compatible
// - check that the widths of the inputs and result is the same
template <typename DotOp, typename RHSElementType>
mlir::LogicalResult verifyDot(DotOp &op) {
if (::mlir::failed(mlir::verifyCompatibleShape(op.getLhs().getType(),
op.getRhs().getType()))) {
return op.emitOpError("arguments have incompatible shapes");
}
auto lhsEltType = ((mlir::Type)op.getLhs().getType())
.cast<mlir::TensorType>()
.getElementType()
.dyn_cast<FHE::FheIntegerInterface>();
auto rhsEltType = ((mlir::Type)op.getRhs().getType())
.cast<mlir::TensorType>()
.getElementType()
.cast<RHSElementType>();
auto resultType = ((mlir::Type)op.getResult().getType())
.dyn_cast<FHE::FheIntegerInterface>();
return verifyDotInputsOutputsConsistency(op, lhsEltType, rhsEltType,
resultType);
}
::mlir::LogicalResult Dot::verify() {
return ::mlir::concretelang::FHELinalg::verifyDot<
mlir::concretelang::FHELinalg::Dot, mlir::IntegerType>(*this);
}
::mlir::LogicalResult DotEint::verify() {
return ::mlir::concretelang::FHELinalg::verifyDot<
mlir::concretelang::FHELinalg::DotEint, FHE::FheIntegerInterface>(*this);
}
llvm::SmallVector<int64_t, 3>
verifySumCalculateActualOutputShape(mlir::Type outputType) {
auto actualOutputShape = llvm::SmallVector<int64_t, 3>{};
@@ -784,6 +822,11 @@ mlir::LogicalResult MatMulIntEintOp::verify() {
mlir::concretelang::FHELinalg::MatMulIntEintOp>(*this);
}
mlir::LogicalResult MatMulEintEintOp::verify() {
return ::mlir::concretelang::FHELinalg::verifyMatmul<
mlir::concretelang::FHELinalg::MatMulEintEintOp>(*this);
}
mlir::SmallVector<int64_t, 4>
getPaddingFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
mlir::SmallVector<int64_t, 4> paddingInts;

View File

@@ -335,3 +335,24 @@ func.func @main(%x: tensor<3x4xi6>, %y: tensor<4x2x!FHE.eint<5>>) -> tensor<3x2x
%0 = "FHELinalg.matmul_int_eint"(%x, %y): (tensor<3x4xi6>, tensor<4x2x!FHE.eint<5>>) -> tensor<3x2x!FHE.eint<5>>
return %0 : tensor<3x2x!FHE.eint<5>>
}
// -----
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-NEXT: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-NEXT: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK: func.func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<5>>, %[[a1:.*]]: tensor<4x2x!FHE.eint<5>>) -> tensor<3x2x!FHE.eint<5>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<3x2x!FHE.eint<5>>
// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[m0]], #[[m1]], #[[m2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[a0]], %[[a1]] : tensor<3x4x!FHE.eint<5>>, tensor<4x2x!FHE.eint<5>>) outs(%[[v0]] : tensor<3x2x!FHE.eint<5>>) {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<5>, %[[aa1:.*]]: !FHE.eint<5>, %[[aa2:.*]]: !FHE.eint<5>):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.mul_eint"(%[[aa1]], %[[aa0]]) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
// CHECK-NEXT: %[[vv1:.*]] = "FHE.add_eint"(%[[aa2]], %[[vv0]]) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
// CHECK-NEXT: linalg.yield %[[vv1]] : !FHE.eint<5>
// CHECK-NEXT: } -> tensor<3x2x!FHE.eint<5>>
// CHECK-NEXT: return %[[v1]] : tensor<3x2x!FHE.eint<5>>
// CHECK-NEXT: }
func.func @main(%x: tensor<3x4x!FHE.eint<5>>, %y: tensor<4x2x!FHE.eint<5>>) -> tensor<3x2x!FHE.eint<5>> {
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<3x4x!FHE.eint<5>>, tensor<4x2x!FHE.eint<5>>) -> tensor<3x2x!FHE.eint<5>>
return %0 : tensor<3x2x!FHE.eint<5>>
}

View File

@@ -281,11 +281,12 @@ func.func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>)
func.func @single_dyn_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !FHE.eint<2>
{
// sqrt((2^2-1)^2*1) = sqrt(9) = 3
// FIXME: the dynamic clear value MANP computation is wrong, update the MANP to the correct one when it's fixed
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 7 : ui{{[0-9]+}}}
%0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
// sqrt(4*(2^2-1)^2*9) = sqrt(324) = 18
// CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[I]]) {MANP = 42 : ui{{[[0-9]+}}}
// CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[I:.*]]) {MANP = 42 : ui{{[0-9]+}}}
%1 = "FHELinalg.dot_eint_int"(%0, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2>
return %1 : !FHE.eint<2>
@@ -299,26 +300,28 @@ func.func @single_dyn_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>)
func.func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!FHE.eint<2>>, %arg1: tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>> {
// p = 0
// acc = manp(0) = 1
// acc = manp(0) = 0
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
// manp(add_eint(mul, acc)) = 9 + 1 = 10
// ceil(sqrt(65)) = 4
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 8 : ui{{[0-9]+}}}
%1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>>
return %1 : tensor<3x2x!FHE.eint<2>>
// manp(add_eint(mul, acc)) = 9
// ceil(sqrt(9)) = 3
// FIXME: the dynamic clear value MANP computation is wrong, update the MANP to the correct one when it's fixed
// CHECK: %[[V0:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 7 : ui{{[0-9]+}}}
%0 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>>
return %0 : tensor<3x2x!FHE.eint<2>>
}
// -----
func.func @matmul_eint_int_dyn_p_2(%arg0: tensor<3x2x!FHE.eint<2>>, %arg1: tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>> {
// p = 0
// acc = manp(0) = 1
// acc = manp(0) = 0
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
// manp(add_eint(mul, acc)) = 9 + 1 = 10
// manp(add_eint(mul, acc)) = 9
// p = 1
// manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
// manp(add_eint(mul, acc)) = 10 + 9 = 19
// ceil(sqrt(19)) = 5
// manp(add_eint(mul, acc)) = 9 + 9 = 18
// ceil(sqrt(18)) = 5
// FIXME: the dynamic clear value MANP computation is wrong, update the MANP to the correct one when it's fixed
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 10 : ui{{[0-9]+}}}
%1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x2x!FHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>>
return %1 : tensor<3x2x!FHE.eint<2>>
@@ -329,11 +332,11 @@ func.func @matmul_eint_int_dyn_p_2(%arg0: tensor<3x2x!FHE.eint<2>>, %arg1: tenso
func.func @matmul_eint_int_cst_p_1(%arg0: tensor<3x1x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> {
%0 = arith.constant dense<[[3, 1]]> : tensor<1x2xi3>
// c(m,n) = a(m,p) * b(p,n) the max cst is used for n = 0
// acc = manp(0) = 1
// acc = manp(0) = 0
// mul = manp(mul_eint_int(eint<2>, 3) = 1 * 3^2 = 9
// manp(add_eint(mul, acc)) = 9 + 1 = 10
// ceil(sqrt(10)) = 4
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}}
// manp(add_eint(mul, acc)) = 9
// ceil(sqrt(10)) = 3
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 3 : ui{{[0-9]+}}}
%1 = "FHELinalg.matmul_eint_int"(%arg0, %0): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>>
return %1 : tensor<3x2x!FHE.eint<2>>
}
@@ -348,10 +351,10 @@ func.func @matmul_eint_int_cst_p_2_n_0(%arg0: tensor<3x2x!FHE.eint<2>>) -> tenso
// mul = manp(mul_eint_int(eint<2>, 3) = 1 * 3^2 = 9
// manp(add_eint(mul, acc)) = 9 + 1 = 10
// p = 1
// mul = manp(mul_eint_int(eint<2>, 4) = 1 * 4^2 = 17
// manp(add_eint(mul, acc)) = 17 + 9 = 26
// ceil(sqrt(26)) = 6
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 6 : ui{{[0-9]+}}}
// mul = manp(mul_eint_int(eint<2>, 4) = 1 * 4^2 = 16
// manp(add_eint(mul, acc)) = 16 + 9 = 25
// ceil(sqrt(25)) = 5
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}}
%1 = "FHELinalg.matmul_eint_int"(%arg0, %0): (tensor<3x2x!FHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>>
return %1 : tensor<3x2x!FHE.eint<2>>
}
@@ -362,13 +365,13 @@ func.func @matmul_eint_int_cst_p_2_n_1(%arg0: tensor<3x2x!FHE.eint<2>>) -> tenso
%0 = arith.constant dense<[[1, 4],[3, 1]]> : tensor<2x2xi3>
// c(m,n) = a(m,p) * b(p,n) the max csts [4,1] are used for n = 1
// p = 0
// acc = manp(0) = 1
// acc = manp(0) = 0
// mul = manp(mul_eint_int(eint<2>, 4) = 1 * 4^2 = 16
// manp(add_eint(mul, acc)) = 16 + 1 = 17
// manp(add_eint(mul, acc)) = 16
// p = 1
// mul = manp(mul_eint_int(eint<2>, 1) = 1 * 1^2 = 1
// manp(add_eint(mul, acc)) = 1 + 17 = 18
// ceil(sqrt(18)) = 5
// manp(add_eint(mul, acc)) = 1 + 16 = 17
// ceil(sqrt(17)) = 5
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}}
%1 = "FHELinalg.matmul_eint_int"(%arg0, %0): (tensor<3x2x!FHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>>
return %1 : tensor<3x2x!FHE.eint<2>>
@@ -382,7 +385,7 @@ func.func @matmul_eint_int_cst() -> tensor<4x3x!FHE.eint<7>> {
// ===============================
%1 = arith.constant dense<
// ceil(sqrt(2^2 + 1^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(2^2 + 1^2 + 5^2)) = ceil(sqrt(30)) = 6
[2, 1, 5]
> : tensor<3xi8>
@@ -392,8 +395,8 @@ func.func @matmul_eint_int_cst() -> tensor<4x3x!FHE.eint<7>> {
// ===============================
%3 = arith.constant dense<
// ceil(sqrt(2^2 + 3^2 + 5^2 + 1)) = ceil(sqrt(39)) = 7
// ceil(sqrt(3^2 + 2^2 + 6^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(2^2 + 3^2 + 5^2)) = ceil(sqrt(39)) = 7
// ceil(sqrt(3^2 + 2^2 + 6^2)) = ceil(sqrt(49)) = 7
[
[2, 3],
[3, 2],
@@ -401,47 +404,47 @@ func.func @matmul_eint_int_cst() -> tensor<4x3x!FHE.eint<7>> {
]
> : tensor<3x2xi8>
// CHECK: MANP = 8 : ui{{[0-9]+}}
// CHECK: MANP = 7 : ui{{[0-9]+}}
%4 = "FHELinalg.matmul_eint_int"(%0, %3) : (tensor<4x3x!FHE.eint<7>>, tensor<3x2xi8>) -> tensor<4x2x!FHE.eint<7>>
// ===============================
%5 = arith.constant dense<
[
// ceil(sqrt(1^2 + 4^2 + 6^2 + 1)) = ceil(sqrt(54)) = 8
// ceil(sqrt(6^2 + 3^2 + 2^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(1^2 + 4^2 + 6^2)) = ceil(sqrt(53)) = 8
// ceil(sqrt(6^2 + 3^2 + 2^2)) = ceil(sqrt(49)) = 7
[
[1, 6],
[4, 3],
[6, 2]
],
// ceil(sqrt(5^2 + 3^2 + 5^2 + 1)) = ceil(sqrt(60)) = 8
// ceil(sqrt(3^2 + 2^2 + 6^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(5^2 + 3^2 + 5^2)) = ceil(sqrt(59)) = 8
// ceil(sqrt(3^2 + 2^2 + 6^2)) = ceil(sqrt(49)) = 7
[
[5, 3],
[3, 2],
[5, 6]
],
// ceil(sqrt(5^2 + 5^2 + 3^2 + 1)) = ceil(sqrt(60)) = 8
// ceil(sqrt(3^2 + 6^2 + 3^2 + 1)) = ceil(sqrt(55)) = 8
// ceil(sqrt(5^2 + 5^2 + 3^2)) = ceil(sqrt(59)) = 8
// ceil(sqrt(3^2 + 6^2 + 3^2)) = ceil(sqrt(54)) = 8
[
[5, 3],
[5, 6],
[3, 3]
],
// ceil(sqrt(6^2 + 1^2 + 4^2 + 1)) = ceil(sqrt(54)) = 8
// ceil(sqrt(3^2 + 4^2 + 3^2 + 1)) = ceil(sqrt(35)) = 6
// ceil(sqrt(6^2 + 1^2 + 4^2)) = ceil(sqrt(53)) = 8
// ceil(sqrt(3^2 + 4^2 + 3^2)) = ceil(sqrt(34)) = 6
[
[6, 3],
[1, 4],
[4, 3]
],
// ceil(sqrt(1^2 + 6^2 + 6^2 + 1)) = ceil(sqrt(74)) = 9
// ceil(sqrt(2^2 + 1^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(1^2 + 6^2 + 6^2)) = ceil(sqrt(73)) = 9
// ceil(sqrt(2^2 + 1^2 + 5^2)) = ceil(sqrt(30)) = 6
[
[1, 2],
[6, 1],
@@ -458,40 +461,40 @@ func.func @matmul_eint_int_cst() -> tensor<4x3x!FHE.eint<7>> {
%7 = arith.constant dense<
[
[
// ceil(sqrt(1^2 + 4^2 + 6^2 + 1)) = ceil(sqrt(54)) = 8
// ceil(sqrt(6^2 + 3^2 + 2^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(1^2 + 4^2 + 6^2)) = ceil(sqrt(53)) = 8
// ceil(sqrt(6^2 + 3^2 + 2^2)) = ceil(sqrt(49)) = 7
[
[1, 6],
[4, 3],
[6, 2]
],
// ceil(sqrt(5^2 + 3^2 + 5^2 + 1)) = ceil(sqrt(60)) = 8
// ceil(sqrt(3^2 + 2^2 + 6^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(5^2 + 3^2 + 5^2)) = ceil(sqrt(59)) = 8
// ceil(sqrt(3^2 + 2^2 + 6^2)) = ceil(sqrt(47)) = 7
[
[5, 3],
[3, 2],
[5, 6]
],
// ceil(sqrt(5^2 + 5^2 + 3^2 + 1)) = ceil(sqrt(60)) = 8
// ceil(sqrt(3^2 + 6^2 + 3^2 + 1)) = ceil(sqrt(55)) = 8
// ceil(sqrt(5^2 + 5^2 + 3^2)) = ceil(sqrt(59)) = 8
// ceil(sqrt(3^2 + 6^2 + 3^2)) = ceil(sqrt(54)) = 8
[
[5, 3],
[5, 6],
[3, 3]
],
// ceil(sqrt(6^2 + 1^2 + 4^2 + 1)) = ceil(sqrt(54)) = 8
// ceil(sqrt(3^2 + 4^2 + 3^2 + 1)) = ceil(sqrt(35)) = 6
// ceil(sqrt(6^2 + 1^2 + 4^2)) = ceil(sqrt(53)) = 8
// ceil(sqrt(3^2 + 4^2 + 3^2)) = ceil(sqrt(34)) = 6
[
[6, 3],
[1, 4],
[4, 3]
],
// ceil(sqrt(1^2 + 6^2 + 6^2 + 1)) = ceil(sqrt(74)) = 9
// ceil(sqrt(2^2 + 1^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(1^2 + 6^2 + 6^2)) = ceil(sqrt(73)) = 9
// ceil(sqrt(2^2 + 1^2 + 5^2)) = ceil(sqrt(30)) = 6
[
[1, 2],
[6, 1],
@@ -499,40 +502,40 @@ func.func @matmul_eint_int_cst() -> tensor<4x3x!FHE.eint<7>> {
]
],
[
// ceil(sqrt(6^2 + 1^2 + 3^2 + 1)) = ceil(sqrt(47)) = 7
// ceil(sqrt(5^2 + 6^2 + 6^2 + 1)) = ceil(sqrt(98)) = 10
// ceil(sqrt(6^2 + 1^2 + 3^2)) = ceil(sqrt(46)) = 7
// ceil(sqrt(5^2 + 6^2 + 6^2)) = ceil(sqrt(97)) = 10
[
[6, 5],
[1, 6],
[3, 6]
],
// ceil(sqrt(1^2 + 2^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(6^2 + 3^2 + 1^2 + 1)) = ceil(sqrt(47)) = 7
// ceil(sqrt(1^2 + 2^2 + 5^2)) = ceil(sqrt(30)) = 6
// ceil(sqrt(6^2 + 3^2 + 1^2)) = ceil(sqrt(46)) = 7
[
[1, 6],
[2, 3],
[5, 1]
],
// ceil(sqrt(4^2 + 3^2 + 6^2 + 1)) = ceil(sqrt(62)) = 8
// ceil(sqrt(1^2 + 5^2 + 2^2 + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(4^2 + 3^2 + 6^2)) = ceil(sqrt(61)) = 8
// ceil(sqrt(1^2 + 5^2 + 2^2)) = ceil(sqrt(30)) = 6
[
[4, 1],
[3, 5],
[6, 2]
],
// ceil(sqrt(2^2 + 3^2 + 3^2 + 1)) = ceil(sqrt(23)) = 5
// ceil(sqrt(2^2 + 2^2 + 1^2 + 1)) = ceil(sqrt(10)) = 4
// ceil(sqrt(2^2 + 3^2 + 3^2)) = ceil(sqrt(22)) = 5
// ceil(sqrt(2^2 + 2^2 + 1^2)) = ceil(sqrt(9)) = 3
[
[2, 2],
[3, 2],
[3, 1]
],
// ceil(sqrt(6^2 + 2^2 + 3^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(2^2 + 4^2 + 2^2 + 1)) = ceil(sqrt(25)) = 5
// ceil(sqrt(6^2 + 2^2 + 3^2)) = ceil(sqrt(49)) = 7
// ceil(sqrt(2^2 + 4^2 + 2^2)) = ceil(sqrt(24)) = 5
[
[6, 2],
[2, 4],
@@ -563,7 +566,7 @@ func.func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x3x!FHE.eint<
// ===============================
%1 = arith.constant dense<
// ceil(sqrt(1 * (2^2 + 1^2 + 5^2) + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(1 * (2^2 + 1^2 + 5^2))) = ceil(sqrt(30)) = 6
[2, 1, 5]
> : tensor<3xi8>
@@ -583,11 +586,11 @@ func.func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x3x!FHE.eint<
func.func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> {
// p = 0
// acc = manp(0) = 1
// acc = manp(0) = 0
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
// manp(add_eint(mul, acc)) = 64 + 1 = 10
// ceil(sqrt(65)) = 4
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 8 : ui{{[0-9]+}}}
// manp(add_eint(mul, acc)) = 64
// ceil(sqrt(64)) = 8
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 7 : ui{{[0-9]+}}}
%1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x1xi3>, tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>>
return %1 : tensor<3x2x!FHE.eint<2>>
}
@@ -613,9 +616,9 @@ func.func @matmul_int_eint_dyn_p_2(%arg0: tensor<3x2xi3>, %arg1: tensor<2x2x!FHE
func.func @matmul_int_eint_cst_p_1(%arg0: tensor<1x3x!FHE.eint<2>>) -> tensor<2x3x!FHE.eint<2>> {
%0 = arith.constant dense<[[3], [1]]> : tensor<2x1xi3>
// c(m,n) = a(m,p) * b(p,n) the max cst is used for m = 0
// acc = manp(0) = 1
// mul = manp(mul_eint_int(eint<2>, 3) = 1 * 3^2 = 9
// manp(add_eint(mul, acc)) = 9 + 1 = 10
// acc = manp(0) = 0
// mul = manp(mul_eint_int(eint<2>, 3) = 1^2 + 3^2 = 10
// manp(add_eint(mul, acc)) = 10
// ceil(sqrt(10)) = 4
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}}
%1 = "FHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x1xi3>, tensor<1x3x!FHE.eint<2>>) -> tensor<2x3x!FHE.eint<2>>
@@ -666,7 +669,7 @@ func.func @matmul_int_eint_cst() -> tensor<3x2x!FHE.eint<7>> {
// ===============================
%1 = arith.constant dense<
// ceil(sqrt(2^2 + 1^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(2^2 + 1^2 + 5^2)) = ceil(sqrt(30)) = 6
[2, 1, 5]
> : tensor<3xi8>
@@ -676,8 +679,8 @@ func.func @matmul_int_eint_cst() -> tensor<3x2x!FHE.eint<7>> {
// ===============================
%3 = arith.constant dense<
// ceil(sqrt(2^2 + 3^2 + 5^2 + 1)) = ceil(sqrt(39)) = 7
// ceil(sqrt(3^2 + 2^2 + 6^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(2^2 + 3^2 + 5^2)) = ceil(sqrt(38)) = 7
// ceil(sqrt(3^2 + 2^2 + 6^2)) = ceil(sqrt(49)) = 7
[
[2, 3, 5],
[3, 2, 6]
@@ -691,36 +694,36 @@ func.func @matmul_int_eint_cst() -> tensor<3x2x!FHE.eint<7>> {
%5 = arith.constant dense<
[
// ceil(sqrt(1^2 + 4^2 + 6^2 + 1)) = ceil(sqrt(54)) = 8
// ceil(sqrt(6^2 + 3^2 + 2^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(1^2 + 4^2 + 6^2)) = ceil(sqrt(53)) = 8
// ceil(sqrt(6^2 + 3^2 + 2^2)) = ceil(sqrt(49)) = 7
[
[1, 4, 6],
[6, 3, 2]
],
// ceil(sqrt(5^2 + 3^2 + 5^2 + 1)) = ceil(sqrt(60)) = 8
// ceil(sqrt(3^2 + 2^2 + 6^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(5^2 + 3^2 + 5^2)) = ceil(sqrt(59)) = 8
// ceil(sqrt(3^2 + 2^2 + 6^2)) = ceil(sqrt(49)) = 7
[
[5, 3, 5],
[3, 2, 6]
],
// ceil(sqrt(5^2 + 5^2 + 3^2 + 1)) = ceil(sqrt(60)) = 8
// ceil(sqrt(3^2 + 6^2 + 3^2 + 1)) = ceil(sqrt(55)) = 8
// ceil(sqrt(5^2 + 5^2 + 3^2)) = ceil(sqrt(59)) = 8
// ceil(sqrt(3^2 + 6^2 + 3^2)) = ceil(sqrt(54)) = 8
[
[5, 5, 3],
[3, 6, 3]
],
// ceil(sqrt(6^2 + 1^2 + 4^2 + 1)) = ceil(sqrt(54)) = 8
// ceil(sqrt(3^2 + 4^2 + 3^2 + 1)) = ceil(sqrt(35)) = 6
// ceil(sqrt(6^2 + 1^2 + 4^2)) = ceil(sqrt(53)) = 8
// ceil(sqrt(3^2 + 4^2 + 3^2)) = ceil(sqrt(34)) = 6
[
[6, 1, 4],
[3, 4, 3]
],
// ceil(sqrt(1^2 + 6^2 + 6^2 + 1)) = ceil(sqrt(74)) = 9
// ceil(sqrt(2^2 + 1^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(1^2 + 6^2 + 6^2)) = ceil(sqrt(73)) = 9
// ceil(sqrt(2^2 + 1^2 + 5^2)) = ceil(sqrt(30)) = 6
[
[1, 6, 6],
[2, 1, 5]
@@ -736,72 +739,72 @@ func.func @matmul_int_eint_cst() -> tensor<3x2x!FHE.eint<7>> {
%7 = arith.constant dense<
[
[
// ceil(sqrt(1^2 + 4^2 + 6^2 + 1)) = ceil(sqrt(54)) = 8
// ceil(sqrt(6^2 + 3^2 + 2^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(1^2 + 4^2 + 6^2)) = ceil(sqrt(53)) = 8
// ceil(sqrt(6^2 + 3^2 + 2^2)) = ceil(sqrt(49)) = 7
[
[1, 4, 6],
[6, 3, 2]
],
// ceil(sqrt(5^2 + 3^2 + 5^2 + 1)) = ceil(sqrt(60)) = 8
// ceil(sqrt(3^2 + 2^2 + 6^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(5^2 + 3^2 + 5^2)) = ceil(sqrt(59)) = 8
// ceil(sqrt(3^2 + 2^2 + 6^2)) = ceil(sqrt(49)) = 7
[
[5, 3, 5],
[3, 2, 6]
],
// ceil(sqrt(5^2 + 5^2 + 3^2 + 1)) = ceil(sqrt(60)) = 8
// ceil(sqrt(3^2 + 6^2 + 3^2 + 1)) = ceil(sqrt(55)) = 8
// ceil(sqrt(5^2 + 5^2 + 3^2)) = ceil(sqrt(59)) = 8
// ceil(sqrt(3^2 + 6^2 + 3^2)) = ceil(sqrt(54)) = 8
[
[5, 5, 3],
[3, 6, 3]
],
// ceil(sqrt(6^2 + 1^2 + 4^2 + 1)) = ceil(sqrt(54)) = 8
// ceil(sqrt(3^2 + 4^2 + 3^2 + 1)) = ceil(sqrt(35)) = 6
// ceil(sqrt(6^2 + 1^2 + 4^2)) = ceil(sqrt(53)) = 8
// ceil(sqrt(3^2 + 4^2 + 3^2)) = ceil(sqrt(34)) = 6
[
[6, 1, 4],
[3, 4, 3]
],
// ceil(sqrt(1^2 + 6^2 + 6^2 + 1)) = ceil(sqrt(74)) = 9
// ceil(sqrt(2^2 + 1^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(1^2 + 6^2 + 6^2)) = ceil(sqrt(73)) = 9
// ceil(sqrt(2^2 + 1^2 + 5^2)) = ceil(sqrt(30)) = 6
[
[1, 6, 6],
[2, 1, 5]
]
],
[
// ceil(sqrt(6^2 + 1^2 + 3^2 + 1)) = ceil(sqrt(47)) = 7
// ceil(sqrt(5^2 + 6^2 + 6^2 + 1)) = ceil(sqrt(98)) = 10
// ceil(sqrt(6^2 + 1^2 + 3^2)) = ceil(sqrt(46)) = 7
// ceil(sqrt(5^2 + 6^2 + 6^2)) = ceil(sqrt(97)) = 10
[
[6, 1, 3],
[5, 6, 6]
],
// ceil(sqrt(1^2 + 2^2 + 5^2 + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(6^2 + 3^2 + 1^2 + 1)) = ceil(sqrt(47)) = 7
// ceil(sqrt(1^2 + 2^2 + 5^2)) = ceil(sqrt(30)) = 6
// ceil(sqrt(6^2 + 3^2 + 1^2)) = ceil(sqrt(46)) = 7
[
[1, 2, 5],
[6, 3, 1]
],
// ceil(sqrt(4^2 + 3^2 + 6^2 + 1)) = ceil(sqrt(62)) = 8
// ceil(sqrt(1^2 + 5^2 + 2^2 + 1)) = ceil(sqrt(31)) = 6
// ceil(sqrt(4^2 + 3^2 + 6^2)) = ceil(sqrt(61)) = 8
// ceil(sqrt(1^2 + 5^2 + 2^2)) = ceil(sqrt(30)) = 6
[
[4, 3, 6],
[1, 5, 2]
],
// ceil(sqrt(2^2 + 3^2 + 3^2 + 1)) = ceil(sqrt(23)) = 5
// ceil(sqrt(2^2 + 2^2 + 1^2 + 1)) = ceil(sqrt(10)) = 4
// ceil(sqrt(2^2 + 3^2 + 3^2)) = ceil(sqrt(22)) = 5
// ceil(sqrt(2^2 + 2^2 + 1^2)) = ceil(sqrt(9)) = 3
[
[2, 3, 3],
[2, 2, 1]
],
// ceil(sqrt(6^2 + 2^2 + 3^2 + 1)) = ceil(sqrt(50)) = 8
// ceil(sqrt(2^2 + 4^2 + 2^2 + 1)) = ceil(sqrt(25)) = 5
// ceil(sqrt(6^2 + 2^2 + 3^2)) = ceil(sqrt(49)) = 7
// ceil(sqrt(2^2 + 4^2 + 2^2)) = ceil(sqrt(24)) = 5
[
[6, 2, 3],
[2, 4, 2]

View File

@@ -8,6 +8,16 @@ func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<2x3xi3>) -> tensor<4x2x
return %0 : tensor<4x2x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<2x3x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op should have the same size on dimension #1 of operand #0 and dimension #0 of operand #1}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<4x3x!FHE.eint<2>>, tensor<2x3x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>>
return %0 : tensor<4x2x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x3xi3>, %y: tensor<4x3x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> {
@@ -24,6 +34,16 @@ func.func @main(%x: tensor<2x4x3x5x!FHE.eint<2>>, %y: tensor<4x3x2xi3>) -> tenso
return %0 : tensor<2x4x3x2x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x4x3x5x!FHE.eint<2>>, %y: tensor<4x3x2x!FHE.eint<2>>) -> tensor<2x4x3x2x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op should have the same size on dimension #3 of operand #0 and dimension #1 of operand #1}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x5x!FHE.eint<2>>, tensor<4x3x2x!FHE.eint<2>>) -> tensor<2x4x3x2x!FHE.eint<2>>
return %0 : tensor<2x4x3x2x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x4x3x5xi3>, %y: tensor<4x3x2x!FHE.eint<2>>) -> tensor<2x4x3x2x!FHE.eint<2>> {
@@ -40,6 +60,16 @@ func.func @main(%x: tensor<2x4x3x5x!FHE.eint<2>>, %y: tensor<10x5x2xi3>) -> tens
return %0 : tensor<2x4x3x2x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x4x3x5x!FHE.eint<2>>, %y: tensor<10x5x2x!FHE.eint<2>>) -> tensor<2x4x3x2x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op should have the same size or size of 1 on dimension #1 of operand #0 and dimension #0 of operand #1}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x5x!FHE.eint<2>>, tensor<10x5x2x!FHE.eint<2>>) -> tensor<2x4x3x2x!FHE.eint<2>>
return %0 : tensor<2x4x3x2x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x4x3x5xi3>, %y: tensor<10x5x2x!FHE.eint<2>>) -> tensor<2x4x3x2x!FHE.eint<2>> {
@@ -56,6 +86,16 @@ func.func @main(%x: tensor<2x!FHE.eint<2>>, %y: tensor<5x3x4x2xi3>) -> tensor<5x
return %0 : tensor<5x3x2x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x!FHE.eint<2>>, %y: tensor<5x3x4x2x!FHE.eint<2>>) -> tensor<5x3x2x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op should have the same size on dimension #0 of operand #0 and dimension #2 of operand #1}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x!FHE.eint<2>>, tensor<5x3x4x2x!FHE.eint<2>>) -> tensor<5x3x2x!FHE.eint<2>>
return %0 : tensor<5x3x2x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2xi3>, %y: tensor<5x3x4x2x!FHE.eint<2>>) -> tensor<5x3x2x!FHE.eint<2>> {
@@ -72,6 +112,16 @@ func.func @main(%x: tensor<5x3x4x2x!FHE.eint<2>>, %y: tensor<4xi3>) -> tensor<5x
return %0 : tensor<5x3x4x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<5x3x4x2x!FHE.eint<2>>, %y: tensor<4x!FHE.eint<2>>) -> tensor<5x3x4x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op should have the same size on dimension #3 of operand #0 and dimension #0 of operand #1}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<5x3x4x2x!FHE.eint<2>>, tensor<4x!FHE.eint<2>>) -> tensor<5x3x4x!FHE.eint<2>>
return %0 : tensor<5x3x4x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<5x3x4x2xi3>, %y: tensor<4x!FHE.eint<2>>) -> tensor<5x3x4x!FHE.eint<2>> {
@@ -88,6 +138,16 @@ func.func @main(%x: tensor<4x!FHE.eint<2>>, %y: tensor<4xi3>) -> tensor<1x!FHE.e
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<4x!FHE.eint<2>>, %y: tensor<4x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op should have at least one multi dimensional tensor as an operand}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<4x!FHE.eint<2>>, tensor<4x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<4xi3>, %y: tensor<4x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -104,6 +164,16 @@ func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<3x2xi3>) -> tensor<1x!F
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<4x3x!FHE.eint<2>>, tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<4x3xi3>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -120,6 +190,16 @@ func.func @main(%x: tensor<3x!FHE.eint<2>>, %y: tensor<3x2xi3>) -> tensor<1x!FHE
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<3x!FHE.eint<2>>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<3x!FHE.eint<2>>, tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<3xi3>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -136,6 +216,16 @@ func.func @main(%x: tensor<3x!FHE.eint<2>>, %y: tensor<4x3x2xi3>) -> tensor<1x!F
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<3x!FHE.eint<2>>, %y: tensor<4x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<3x!FHE.eint<2>>, tensor<4x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<3xi3>, %y: tensor<4x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -152,6 +242,16 @@ func.func @main(%x: tensor<3x!FHE.eint<2>>, %y: tensor<4x5x3x2xi3>) -> tensor<1x
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<3x!FHE.eint<2>>, %y: tensor<4x5x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <4x5x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<3x!FHE.eint<2>>, tensor<4x5x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<3xi3>, %y: tensor<4x5x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -168,6 +268,16 @@ func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<3xi3>) -> tensor<1x!FHE
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <4>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<4x3x!FHE.eint<2>>, tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<4x3xi3>, %y: tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -184,6 +294,16 @@ func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<3xi3>) -> tensor<1x!F
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <2x4>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x!FHE.eint<2>>, tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x4x3xi3>, %y: tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -200,6 +320,16 @@ func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<3xi3>) -> tensor<1x
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <5x2x4>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<5x2x4x3x!FHE.eint<2>>, tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<5x2x4x3xi3>, %y: tensor<3x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -216,6 +346,16 @@ func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<2x3x2xi3>) -> tensor<
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x!FHE.eint<2>>, tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x4x3xi3>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -232,6 +372,16 @@ func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<1x3x2xi3>) -> tensor<
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<1x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x!FHE.eint<2>>, tensor<1x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x4x3xi3>, %y: tensor<1x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -248,6 +398,16 @@ func.func @main(%x: tensor<1x4x3x!FHE.eint<2>>, %y: tensor<2x3x2xi3>) -> tensor<
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<1x4x3x!FHE.eint<2>>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<1x4x3x!FHE.eint<2>>, tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<1x4x3xi3>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -264,6 +424,16 @@ func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<2x3x2xi3>) -> tensor<
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x!FHE.eint<2>>, tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<4x3xi3>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -280,6 +450,16 @@ func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<3x2xi3>) -> tensor<1x
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x!FHE.eint<2>>, tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x4x3xi3>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -296,6 +476,16 @@ func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<5x2x3x2xi3>) -> ten
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <5x2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<5x2x4x3x!FHE.eint<2>>, tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<5x2x4x3xi3>, %y: tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -312,6 +502,16 @@ func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<2x3x2xi3>) -> tenso
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <5x2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<5x2x4x3x!FHE.eint<2>>, tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<5x2x4x3xi3>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -328,6 +528,16 @@ func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<3x2xi3>) -> tensor<
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<5x2x4x3x!FHE.eint<2>>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <5x2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<5x2x4x3x!FHE.eint<2>>, tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<5x2x4x3xi3>, %y: tensor<3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -344,6 +554,16 @@ func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<5x2x3x2xi3>) -> tenso
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x4x3x!FHE.eint<2>>, %y: tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <5x2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<2x4x3x!FHE.eint<2>>, tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<2x4x3xi3>, %y: tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -360,6 +580,16 @@ func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<5x2x3x2xi3>) -> tensor<
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<4x3x!FHE.eint<2>>, %y: tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <5x2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<4x3x!FHE.eint<2>>, tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<4x3xi3>, %y: tensor<5x2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
@@ -376,10 +606,15 @@ func.func @main(%x: tensor<5x1x4x3x!FHE.eint<2>>, %y: tensor<2x3x2xi3>) -> tenso
return %0 : tensor<1x!FHE.eint<2>>
}
// -----
func.func @main(%x: tensor<5x1x4x3xi3>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_int_eint' op does not have the proper output shape of <5x2x4x2>}}
%0 = "FHELinalg.matmul_int_eint"(%x, %y): (tensor<5x1x4x3xi3>, tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
func.func @main(%x: tensor<5x1x4x3x!FHE.eint<2>>, %y: tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.matmul_eint_eint' op does not have the proper output shape of <5x2x4x2>}}
%0 = "FHELinalg.matmul_eint_eint"(%x, %y): (tensor<5x1x4x3x!FHE.eint<2>>, tensor<2x3x2x!FHE.eint<2>>) -> tensor<1x!FHE.eint<2>>
return %0 : tensor<1x!FHE.eint<2>>
}
// -----

View File

@@ -484,6 +484,23 @@ func.func @dot_eint_int(%arg0: tensor<2x!FHE.eint<2>>,
return %ret : !FHE.eint<2>
}
/////////////////////////////////////////////////
// FHELinalg.dot_eint_eint
/////////////////////////////////////////////////
// CHECK-LABEL: func.func @dot_eint_eint(%arg0: tensor<2x!FHE.eint<2>>, %arg1: tensor<2x!FHE.eint<2>>) -> !FHE.eint<2>
func.func @dot_eint_eint(%arg0: tensor<2x!FHE.eint<2>>,
%arg1: tensor<2x!FHE.eint<2>>) -> !FHE.eint<2>
{
// CHECK-NEXT: %[[RET:.*]] = "FHELinalg.dot_eint_eint"(%arg0, %arg1) : (tensor<2x!FHE.eint<2>>, tensor<2x!FHE.eint<2>>) -> !FHE.eint<2>
%ret = "FHELinalg.dot_eint_eint"(%arg0, %arg1) :
(tensor<2x!FHE.eint<2>>, tensor<2x!FHE.eint<2>>) -> !FHE.eint<2>
//CHECK-NEXT: return %[[RET]] : !FHE.eint<2>
return %ret : !FHE.eint<2>
}
/////////////////////////////////////////////////
// FHELinalg.conv2d
/////////////////////////////////////////////////

View File

@@ -0,0 +1,119 @@
import argparse
import numpy as np
PRECISIONS_TO_BENCH = [(6, 2), (16, 7)]
SHAPES = [((2, 3, 4), (2, 4, 2)), ((3, 4), (4, 2)), ((3,), (3,)), ((3,), (3, 2)), ((3,), (4, 3, 2)), ((3,4), (4,)), ((2,3,4), (4,)), ((2, 1, 3, 4), (5, 4, 2))]
P_ERROR = 1.0 / 1e6
def format_shape(shape):
shape_str = "x".join(map(str, shape))
if len(shape):
shape_str += "x"
else:
shape_str = "1x"
return shape_str
def flatten_and_to_str(data, is_tensor=True):
payload = ", ".join(map(str, data.reshape((-1,))))
if is_tensor:
return "[" + payload + "]"
return payload
def generate(op):
assert(op in {"matmul_eint_eint", "dot_eint_eint"})
print("# /!\ DO NOT EDIT MANUALLY THIS FILE MANUALLY")
print("# /!\ THIS FILE HAS BEEN GENERATED")
for p, p_inputs in PRECISIONS_TO_BENCH:
for shapes in SHAPES:
for signed in [False, True]:
min_value = 0
max_value = (2 ** p_inputs) - 1
inp_0 = np.random.randint(min_value, max_value+1, size=shapes[0])
inp_1 = np.random.randint(min_value, max_value+1, size=shapes[1])
expected_result = inp_0 @ inp_1
assert(np.all(expected_result < 2**p))
assert(np.all(expected_result >= 0))
out_shape = expected_result.shape
if len(out_shape) < 2 and op == "matmul_eint_eint":
# Matmul only works for matmuls on operands
# that produce at least 2-dimensional outputs
continue
elif len(out_shape) >= 1 and op == "dot_eint_eint":
# Dot will only be tested when the output is
# a scalar
continue
shape_0_str = format_shape(shapes[0])
shape_1_str = format_shape(shapes[1])
out_shape_str = format_shape(out_shape)
dtype = "esint" if signed else "eint"
op_outputs_scalar = op == "dot_eint_eint" and len(out_shape) == 0
out_dtype_str = f"tensor<{out_shape_str}!FHE.{dtype}<{p}>>" if not op_outputs_scalar else f"!FHE.{dtype}<{p}>"
program = (f"description: {op}_{p}bits_{'s' if signed else 'u'}_{shape_0_str}_{shape_1_str}\n"
f"program: |\n"
f" func.func @main(%x: tensor<{shape_0_str}!FHE.{dtype}<{p}>>, "
f"%y: tensor<{shape_1_str}!FHE.{dtype}<{p}>>) -> {out_dtype_str} {{\n"
f" %0 = \"FHELinalg.{op}\"(%x, %y): (tensor<{shape_0_str}!FHE.{dtype}<{p}>>, "
f"tensor<{shape_1_str}!FHE.{dtype}<{p}>>) -> {out_dtype_str}\n"
f" return %0 : {out_dtype_str}\n"
f" }}\n"
)
inp_0_str = flatten_and_to_str(inp_0)
inp_1_str = flatten_and_to_str(inp_1)
expected_str = flatten_and_to_str(expected_result, is_tensor=not op_outputs_scalar)
shape_0_str_yaml = ",".join(map(str, shapes[0]))
shape_1_str_yaml = ",".join(map(str, shapes[1]))
expected_shape_yaml = ",".join(map(str, out_shape))
program += (f"p-error: {P_ERROR}\n"
"tests:\n"
" - inputs: \n"
f" - tensor: {inp_0_str}\n"
f" shape: [{shape_0_str_yaml}]\n"
f" - tensor: {inp_1_str}\n"
f" shape: [{shape_1_str_yaml}]\n"
f" outputs:\n"
)
if op_outputs_scalar:
program += (
f" - scalar: {expected_str}\n"
)
else:
program += (
f" - tensor: {expected_str}\n"
f" shape: [{expected_shape_yaml}]\n"
)
if signed:
program += (
f" signed: True\n"
)
program += f"---"
print(program)
if __name__ == "__main__":
CLI = argparse.ArgumentParser()
CLI.add_argument(
"--minimal",
help="Specify whether to generate minimal tests only",
type=bool,
default=False,
)
args = CLI.parse_args()
generate("matmul_eint_eint")
generate("dot_eint_eint")

View File

@@ -1,10 +1,21 @@
# Python Frontend
## Installation for end-users
End-users should install `concrete-python` using `pip`:
```shell
pip install concrete-python
```
## Setup for development
Developers that want to contribute to the Concrete-Python project can use the following
approach to setup their environment.
```shell
# clone the repository
git clone https://github.com/zama-ai/concrete.git
git clone https://github.com/zama-ai/concrete.git --recursive
cd concrete
# create virtual environment
@@ -19,6 +30,8 @@ cd ../../compilers/concrete-compiler/compiler
make python-bindings
# set bindings build directory as an environment variable
# *** NOTE ***: You must use the Release build of the compiler!
# For now, the Debug build is not compatible with concrete-python
export COMPILER_BUILD_DIRECTORY=$(pwd)/build
echo "export COMPILER_BUILD_DIRECTORY=$(pwd)/build" >> ~/.bashrc
@@ -26,3 +39,17 @@ echo "export COMPILER_BUILD_DIRECTORY=$(pwd)/build" >> ~/.bashrc
cd ../../../frontends/concrete-python
make pytest
```
### VSCode setup
Alternatively you can use VSCode to develop Concrete-Python:
Suppose the compiler bindings were built in `/home/zama/concrete/compilers/concrete-compiler/compiler/build`:
- Create a `.env` file in the concrete-python root directory
- Determine the absolute path of the local compiler repository, e.g. `/home/zama/concrete`. Replace this with your
path in the following two lines
- Add to it `PYTHONPATH=$(PYTHON_PATH):/home/zama/concrete/compilers/concrete-compiler/compiler/build/tools/concretelang/python_packages/concretelang_core/`
- Add to it `LD_PRELOAD=/home/zama/concrete/compilers/concrete-compiler/compiler/build/lib/libConcretelangRuntime.so`
You can now configure `pytest` in VScode and run the tests using the graphical interface.

View File

@@ -764,28 +764,40 @@ class Context:
}
self.error(highlights)
if x.is_encrypted and y.is_encrypted:
highlights = {
x.origin: "lhs is encrypted",
y.origin: (
"rhs is encrypted" if x.origin is not y.origin else "operand is encrypted"
),
self.converting: "but encrypted-encrypted dot products are not supported",
}
self.error(highlights)
assert self.is_bit_width_compatible(resulting_type, x, y)
if x.is_scalar or y.is_scalar:
return self.mul(resulting_type, x, y)
x = self.to_signedness(x, of=resulting_type)
y = self.to_signedness(y, of=resulting_type)
operation = fhelinalg.DotEint if x.is_encrypted and y.is_encrypted else fhelinalg.Dot
if x.is_clear:
x, y = y, x
return self.operation(fhelinalg.Dot, resulting_type, x.result, y.result)
if (x.is_signed or y.is_signed) and resulting_type.is_unsigned:
x = self.to_signed(x)
y = self.to_signed(y)
signed_resulting_type = self.typeof(
Value(
dtype=Integer(is_signed=True, bit_width=resulting_type.bit_width),
shape=resulting_type.shape,
is_encrypted=resulting_type.is_encrypted,
)
)
intermediate_result = self.operation(
operation,
signed_resulting_type,
x.result,
y.result,
)
return self.to_unsigned(intermediate_result)
x = self.to_signedness(x, of=resulting_type)
y = self.to_signedness(y, of=resulting_type)
return self.operation(operation, resulting_type, x.result, y.result)
def encrypt(self, resulting_type: ConversionType, x: Conversion) -> Conversion:
assert self.is_bit_width_compatible(resulting_type, x)
@@ -1293,29 +1305,41 @@ class Context:
}
self.error(highlights)
if x.is_encrypted and y.is_encrypted:
highlights = {
x.origin: "lhs is encrypted",
y.origin: (
"rhs is encrypted" if x.origin is not y.origin else "operand is encrypted"
),
self.converting: "but encrypted-encrypted matrix multiplications are not supported",
}
self.error(highlights)
assert self.is_bit_width_compatible(resulting_type, x, y)
x = self.to_signedness(x, of=resulting_type)
y = self.to_signedness(y, of=resulting_type)
if resulting_type.shape == ():
if x.is_clear:
x, y = y, x
operation = fhelinalg.Dot
operation = fhelinalg.DotEint if x.is_encrypted and y.is_encrypted else fhelinalg.Dot
elif x.is_encrypted and y.is_encrypted:
operation = fhelinalg.MatMulEintEintOp
else:
operation = fhelinalg.MatMulEintIntOp if x.is_encrypted else fhelinalg.MatMulIntEintOp
if (x.is_signed or y.is_signed) and resulting_type.is_unsigned:
x = self.to_signed(x)
y = self.to_signed(y)
signed_resulting_type = self.typeof(
Value(
dtype=Integer(is_signed=True, bit_width=resulting_type.bit_width),
shape=resulting_type.shape,
is_encrypted=resulting_type.is_encrypted,
)
)
intermediate_result = self.operation(
operation,
signed_resulting_type,
x.result,
y.result,
)
return self.to_unsigned(intermediate_result)
x = self.to_signedness(x, of=resulting_type)
y = self.to_signedness(y, of=resulting_type)
return self.operation(operation, resulting_type, x.result, y.result)
def maxpool2d(

View File

@@ -65,8 +65,8 @@ def assign_precisions_1_node(node: Node, output_p: int, inputs_p: int):
CHUNKED_COMPARISON = {"greater", "greater_equal", "less", "less_equal"}
CHUNKED_COMPARISON_MIN_BITWIDTH = 4
MAX_POOLS = {"maxpool1d", "maxpool2d", "maxpool3d"}
MULTIPLY = {"multiply"}
ROUNDING = {"round_bit_pattern"}
MULTIPLY = {"multiply", "matmul"}
def max_encrypted_bitwidth_node(node: Node):
@@ -88,13 +88,38 @@ def max_encrypted_bitwidth_node(node: Node):
return normal_p + 1
if name in MULTIPLY and all(value.is_encrypted for value in node.inputs):
return normal_p + 1
# For operations that use multiply, an additional bit
# needs to be added to the bitwidths of the inputs.
# For single precision circuits the max of the input / output
# precisions will be taken in required_encrypted_bitwidth. For
# multi-precision, the circuit partitions will handle the
# input and output precisions separately.
all_inp_bitwidths = []
# Need a loop here to allow typechecking and make mypy happy
for inp in node.inputs:
dtype_inp = inp.dtype
assert isinstance(dtype_inp, Integer)
all_inp_bitwidths.append(dtype_inp.bit_width)
normal_p = max(all_inp_bitwidths)
# FIXME: This probably does not work well with multi-precision!
return max(normal_p + 1, node.output.dtype.bit_width)
return normal_p
def required_encrypted_bitwidth(nodes: Iterable[Node]) -> int:
"""Give the minimal precision to implement all the nodes."""
"""Give the minimal precision to implement all the nodes.
This function is called for both single-precision (for the whole circuit)
and for multi-precision circuits (for circuit partitions).
Ops for which the compiler introduces TLUs need to be handled explicitly
in `max_encrypted_bitwidth_node`. The maximum
of all precisions of the various operations is returned.
"""
bitwidths = map(max_encrypted_bitwidth_node, nodes)
return max(bitwidths, default=-1)

View File

@@ -23,7 +23,7 @@ def test_dot(size, helpers):
cst = np.random.randint(0, bound, size=(size,))
@fhe.compiler({"x": "encrypted"})
def left_function(x):
def dot_enc_enc_function(x):
return np.dot(x, cst)
@fhe.compiler({"x": "encrypted"})
@@ -36,12 +36,55 @@ def test_dot(size, helpers):
inputset = [np.random.randint(0, bound, size=(size,)) for i in range(100)]
left_function_circuit = left_function.compile(inputset, configuration)
dot_enc_enc_function_circuit = dot_enc_enc_function.compile(inputset, configuration)
right_function_circuit = right_function.compile(inputset, configuration)
method_circuit = method.compile(inputset, configuration)
sample = np.random.randint(0, bound, size=(size,))
helpers.check_execution(left_function_circuit, left_function, sample)
helpers.check_execution(dot_enc_enc_function_circuit, dot_enc_enc_function, sample)
helpers.check_execution(right_function_circuit, right_function, sample)
helpers.check_execution(method_circuit, method, sample)
@pytest.mark.parametrize(
"size",
[1, 10],
)
@pytest.mark.parametrize(
"bitwidth",
[2, 6],
)
@pytest.mark.parametrize("signed", [True, False])
@pytest.mark.parametrize("negative_only", [True, False])
def test_dot_enc_enc(size, bitwidth, negative_only, signed, helpers):
"""
Test dot.
"""
configuration = helpers.configuration()
minv = 0 if not signed else -(2 ** (bitwidth - 1))
# +1 since randint max is not inclusive
maxv = 2**bitwidth if not signed else 2 ** (bitwidth - 1)
if negative_only:
maxv = 1
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def dot_enc_enc_function(x, y):
return np.dot(x, y)
inputset = [
(np.random.randint(minv, maxv, size=(size,)), np.random.randint(minv, maxv, size=(size,)))
for i in range(100)
]
dot_enc_enc_function_circuit = dot_enc_enc_function.compile(inputset, configuration)
sample = [
np.random.randint(minv, maxv, size=(size,)),
np.random.randint(minv, maxv, size=(size,)),
]
helpers.check_execution(dot_enc_enc_function_circuit, dot_enc_enc_function, sample)

View File

@@ -16,6 +16,11 @@ from concrete import fhe
(2, 3),
(0, 3),
),
pytest.param(
(3, 2),
(2, 3),
(0, 127),
),
pytest.param(
(1, 2),
(2, 1),
@@ -46,6 +51,11 @@ from concrete import fhe
(5, 5),
(0, 3),
),
pytest.param(
(5,),
(5, 5),
(-127, 127),
),
pytest.param(
(5,),
(5, 3),
@@ -59,7 +69,7 @@ from concrete import fhe
pytest.param(
(5,),
(4, 5, 3),
(0, 5),
(-5, 5),
),
pytest.param(
(4, 5, 3),
@@ -74,7 +84,7 @@ from concrete import fhe
pytest.param(
(2, 4, 5, 3),
(3,),
(0, 5),
(-1, 5),
),
pytest.param(
(5, 4, 3),
@@ -156,3 +166,200 @@ def test_matmul(lhs_shape, rhs_shape, bounds, helpers):
helpers.check_execution(rhs_operator_circuit, rhs_operator, rhs_sample)
helpers.check_execution(lhs_function_circuit, lhs_function, lhs_sample)
helpers.check_execution(rhs_function_circuit, rhs_function, rhs_sample)
@pytest.mark.parametrize(
"lhs_shape,rhs_shape,bounds",
[
pytest.param(
(3, 2),
(2, 3),
(0, 3),
),
pytest.param(
(3, 2),
(2, 3),
(0, 127),
),
pytest.param(
(1, 2),
(2, 1),
(0, 3),
),
pytest.param(
(3, 3),
(3, 3),
(0, 3),
),
pytest.param(
(2, 1),
(1, 2),
(0, 7),
),
pytest.param(
(2,),
(2,),
(0, 7),
),
pytest.param(
(5, 5),
(5,),
(0, 3),
),
pytest.param(
(5,),
(5, 5),
(0, 3),
),
pytest.param(
(5,),
(5, 5),
(-63, 63),
),
pytest.param(
(2,),
(2, 7),
(-63, 0),
),
pytest.param(
(5,),
(5, 3),
(0, 3),
),
pytest.param(
(5, 3),
(3,),
(0, 3),
),
pytest.param(
(5,),
(4, 5, 3),
(-5, 5),
),
pytest.param(
(4, 5, 3),
(3,),
(0, 5),
),
pytest.param(
(5,),
(2, 4, 5, 3),
(0, 5),
),
pytest.param(
(2, 4, 5, 3),
(3,),
(-1, 5),
),
pytest.param(
(5, 4, 3),
(3, 2),
(0, 5),
),
pytest.param(
(4, 3),
(5, 3, 2),
(0, 5),
),
pytest.param(
(2, 5, 4, 3),
(3, 2),
(0, 5),
),
pytest.param(
(5, 4, 3),
(1, 3, 2),
(0, 5),
),
pytest.param(
(1, 4, 3),
(5, 3, 2),
(0, 5),
),
pytest.param(
(5, 4, 3),
(2, 1, 3, 2),
(0, 5),
),
pytest.param(
(2, 1, 4, 3),
(5, 3, 2),
(0, 5),
),
],
)
def test_matmul_enc_enc_and_clear(lhs_shape, rhs_shape, bounds, helpers):
"""
Test matmul.
"""
configuration = helpers.configuration()
minimum, maximum = bounds
# Matmul of clear values and encrypted matrices
@fhe.compiler({"x": "encrypted", "y": "clear"})
def lhs_operator_clear(x, y):
return x @ y
# Matmul of two encrypted matrices
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def enc_function_xy(x, y):
return np.matmul(x, y)
# Put all the dual operand functions in a list
# FIXME: add lhs_operator_clear to this list to
# re-enable the testing with clear values
dual_operand_functions = [enc_function_xy]
# Compile each dual operand function and test it on random data
for func in dual_operand_functions:
dual_op_inputset = [
(
np.random.randint(minimum, maximum, size=lhs_shape),
np.random.randint(minimum, maximum, size=rhs_shape),
)
for i in range(100)
]
dual_op_circuit = func.compile(dual_op_inputset, configuration)
lhs_sample, rhs_sample = np.random.randint(
minimum, maximum, size=lhs_shape
), np.random.randint(minimum, maximum, size=rhs_shape)
helpers.check_execution(dual_op_circuit, func, [lhs_sample, rhs_sample])
@pytest.mark.parametrize("bitwidth", [4, 10])
@pytest.mark.parametrize("signed", [True, False])
def test_matmul_zero(bitwidth, signed, helpers):
"""
Test matmul.
"""
lhs_shape = (2, 1)
rhs_shape = (1, 2)
range_lhs = (-(2 ** (bitwidth - 1)), 2 ** (bitwidth - 1) - 1) if signed else (0, 2**bitwidth)
configuration = helpers.configuration()
# Matmul of two encrypted matrices
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def enc_function_xy(x, y):
return x * y
dual_op_inputset = [
(
np.random.randint(range_lhs[0], range_lhs[1], size=lhs_shape),
np.zeros(rhs_shape, dtype=np.int64),
)
for i in range(100)
]
dual_op_circuit = enc_function_xy.compile(dual_op_inputset, configuration)
lhs_sample, rhs_sample = np.random.randint(
range_lhs[0], range_lhs[1], size=lhs_shape
), np.zeros(rhs_shape, dtype=np.int64)
helpers.check_execution(dual_op_circuit, enc_function_xy, [lhs_sample, rhs_sample])

View File

@@ -152,6 +152,26 @@ def test_constant_mul(function, parameters, helpers):
"x": {"range": [-10, 10], "status": "encrypted", "shape": (3, 2)},
"y": {"range": [-10, 10], "status": "encrypted", "shape": (3, 2)},
},
{
"x": {"range": [-10, 10], "status": "encrypted", "shape": (3, 1)},
"y": {"range": [0, 0], "status": "encrypted", "shape": (3, 1)},
},
{
"x": {"range": [10, 20], "status": "encrypted", "shape": (3, 1)},
"y": {"range": [0, 0], "status": "encrypted", "shape": (1, 3)},
},
{
"x": {"range": [2**12, 2**13 - 1], "status": "encrypted", "shape": (3, 1)},
"y": {"range": [0, 0], "status": "encrypted", "shape": (1, 3)},
},
{
"x": {"range": [2**12, 2**13 - 1], "status": "encrypted", "shape": (3, 1)},
"y": {"range": [0, 2 * 3 - 1], "status": "encrypted", "shape": (1, 3)},
},
{
"x": {"range": [-(2**7), 2**7 - 1], "status": "encrypted", "shape": (3, 1)},
"y": {"range": [-(2**7), 2**7 - 1], "status": "encrypted", "shape": (1, 3)},
},
],
)
def test_mul(function, parameters, helpers):

View File

@@ -359,30 +359,6 @@ return %3
""", # noqa: E501
),
pytest.param(
lambda x, y: np.dot(x, y),
{"x": "encrypted", "y": "encrypted"},
[
(
np.ones(shape=(3,), dtype=np.int64),
np.ones(shape=(3,), dtype=np.int64),
)
],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint1, shape=(3,)> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is encrypted
%1 = y # EncryptedTensor<uint1, shape=(3,)> ∈ [1, 1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is encrypted
%2 = dot(%0, %1) # EncryptedScalar<uint2> ∈ [3, 3]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but encrypted-encrypted dot products are not supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: x @ y,
{"x": "clear", "y": "clear"},
@@ -398,25 +374,6 @@ Function you are trying to compile cannot be compiled
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is clear
%2 = matmul(%0, %1) # ClearTensor<uint5, shape=(2, 2)> ∈ [5, 20]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear-clear matrix multiplications are not supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: x @ y,
{"x": "encrypted", "y": "encrypted"},
[([[1, 2], [3, 4]], [[4, 3], [2, 1]])],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # EncryptedTensor<uint3, shape=(2, 2)> ∈ [1, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ lhs is encrypted
%1 = y # EncryptedTensor<uint3, shape=(2, 2)> ∈ [1, 4]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ rhs is encrypted
%2 = matmul(%0, %1) # EncryptedTensor<uint5, shape=(2, 2)> ∈ [5, 20]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but encrypted-encrypted matrix multiplications are not supported
return %2
""", # noqa: E501