From 0d4e10169bd5d9676f55c4463a15c3399dc7bf19 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Wed, 20 Oct 2021 13:56:47 +0200 Subject: [PATCH] feat(compiler): Introduce the HLFHELinalg dialect and a first operator HLFHELinalg.add_eint_int --- .../include/zamalang/Dialect/CMakeLists.txt | 1 + .../Dialect/HLFHELinalg/CMakeLists.txt | 1 + .../Dialect/HLFHELinalg/IR/CMakeLists.txt | 9 ++ .../HLFHELinalg/IR/HLFHELinalgDialect.h | 10 ++ .../HLFHELinalg/IR/HLFHELinalgDialect.td | 15 ++ .../Dialect/HLFHELinalg/IR/HLFHELinalgOps.h | 56 ++++++++ .../Dialect/HLFHELinalg/IR/HLFHELinalgOps.td | 70 +++++++++ .../Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h | 11 ++ .../HLFHELinalg/IR/HLFHELinalgTypes.td | 11 ++ compiler/lib/Dialect/CMakeLists.txt | 1 + .../lib/Dialect/HLFHELinalg/CMakeLists.txt | 1 + .../lib/Dialect/HLFHELinalg/IR/CMakeLists.txt | 14 ++ .../HLFHELinalg/IR/HLFHELinalgDialect.cpp | 22 +++ .../Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp | 136 ++++++++++++++++++ compiler/lib/Support/CMakeLists.txt | 1 + compiler/src/main.cpp | 2 + .../Dialect/HLFHELinalg/ops.invalid.mlir | 39 +++++ compiler/tests/Dialect/HLFHELinalg/ops.mlir | 55 +++++++ 18 files changed, 455 insertions(+) create mode 100644 compiler/include/zamalang/Dialect/HLFHELinalg/CMakeLists.txt create mode 100644 compiler/include/zamalang/Dialect/HLFHELinalg/IR/CMakeLists.txt create mode 100644 compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h create mode 100644 compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.td create mode 100644 compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h create mode 100644 compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td create mode 100644 compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h create mode 100644 compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.td create mode 100644 compiler/lib/Dialect/HLFHELinalg/CMakeLists.txt create mode 100644 compiler/lib/Dialect/HLFHELinalg/IR/CMakeLists.txt create mode 100644 compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.cpp create mode 100644 compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp create mode 100644 compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir create mode 100644 compiler/tests/Dialect/HLFHELinalg/ops.mlir diff --git a/compiler/include/zamalang/Dialect/CMakeLists.txt b/compiler/include/zamalang/Dialect/CMakeLists.txt index a69c38e3d..249b735ae 100644 --- a/compiler/include/zamalang/Dialect/CMakeLists.txt +++ b/compiler/include/zamalang/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(HLFHE) +add_subdirectory(HLFHELinalg) add_subdirectory(MidLFHE) add_subdirectory(LowLFHE) diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/CMakeLists.txt b/compiler/include/zamalang/Dialect/HLFHELinalg/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/CMakeLists.txt b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/CMakeLists.txt new file mode 100644 index 000000000..7ee5bbed6 --- /dev/null +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/CMakeLists.txt @@ -0,0 +1,9 @@ +set(LLVM_TARGET_DEFINITIONS HLFHELinalgOps.td) +mlir_tablegen(HLFHELinalgOps.h.inc -gen-op-decls) +mlir_tablegen(HLFHELinalgOps.cpp.inc -gen-op-defs) +mlir_tablegen(HLFHELinalgOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=HLFHELinalg) 
+mlir_tablegen(HLFHELinalgOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=HLFHELinalg)
+mlir_tablegen(HLFHELinalgOpsDialect.h.inc -gen-dialect-decls -dialect=HLFHELinalg)
+mlir_tablegen(HLFHELinalgOpsDialect.cpp.inc -gen-dialect-defs -dialect=HLFHELinalg)
+add_public_tablegen_target(MLIRHLFHELinalgOpsIncGen)
+add_dependencies(mlir-headers MLIRHLFHELinalgOpsIncGen)
\ No newline at end of file
diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h
new file mode 100644
index 000000000..7f913b3a9
--- /dev/null
+++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h
@@ -0,0 +1,10 @@
+#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgDIALECT_H
+#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgDIALECT_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+
+#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsDialect.h.inc"
+
+#endif
diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.td b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.td
new file mode 100644
index 000000000..e88d461bf
--- /dev/null
+++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.td
@@ -0,0 +1,15 @@
+#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_DIALECT
+#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_DIALECT
+
+include "mlir/IR/OpBase.td"
+
+def HLFHELinalg_Dialect : Dialect {
+  let name = "HLFHELinalg";
+  let summary = "High Level Fully Homomorphic Encryption Linalg dialect";
+  let description = [{
+    A dialect for the representation of high level linalg operations on fully homomorphic encryption ciphertexts.
+  }];
+  let cppNamespace = "::mlir::zamalang::HLFHELinalg";
+}
+
+#endif
diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h
new file mode 100644
index 000000000..692a7c2de
--- /dev/null
+++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h
@@ -0,0 +1,56 @@
+#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgOPS_H
+#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgOPS_H
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
+#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h"
+#include
+#include
+
+namespace mlir {
+namespace OpTrait {
+
+namespace impl {
+LogicalResult verifyTensorBroadcastingRules(mlir::Operation *op);
+LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op);
+} // namespace impl
+
+/// TensorBroadcastingRules is a trait for operators that must respect the
+/// broadcasting rules. Every operand must be a RankedTensorType, and there
+/// must be exactly one result, also a RankedTensorType. The operand shapes
+/// are compatible if, comparing dimensions from the right to the left, each
+/// pair of dimensions is either equal or one of them is equal to one. If one
+/// shape is shorter than the others, its missing dimensions are treated as
+/// one. The result shape must have the rank of the largest operand shape, and
+/// each dimension `i` of the result must be equal to the maximum of the
+/// dimensions `i` of the operands.
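+///
+/// As an illustrative sketch of these rules (the shapes here are arbitrary
+/// examples, not taken from this patch's tests): operand shapes 3x1x5 and
+/// 4x1 broadcast to the result shape 3x4x5, since aligning from the right
+/// gives (5, 1) -> 5, (1, 4) -> 4 and (3, missing) -> 3.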
+template <typename ConcreteType>
+class TensorBroadcastingRules
+    : public mlir::OpTrait::TraitBase<ConcreteType, TensorBroadcastingRules> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyTensorBroadcastingRules(op);
+  }
+};
+
+/// TensorBinaryEintInt verifies that the operation matches the following
+/// signature:
+/// `(tensor<...x!HLFHE.eint<$p>>, tensor<...xi$p'>) ->
+/// tensor<...x!HLFHE.eint<$p>>` where `$p' <= $p + 1`.
+template <typename ConcreteType>
+class TensorBinaryEintInt
+    : public mlir::OpTrait::TraitBase<ConcreteType, TensorBinaryEintInt> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyTensorBinaryEintInt(op);
+  }
+};
+
+} // namespace OpTrait
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h.inc"
+
+#endif
diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td
new file mode 100644
index 000000000..e4fd26256
--- /dev/null
+++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td
@@ -0,0 +1,70 @@
+#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_OPS
+#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_OPS
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+
+include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.td"
+include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.td"
+
+class HLFHELinalg_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<HLFHELinalg_Dialect, mnemonic, traits>;
+
+// TensorBroadcastingRules verifies that the operands and the result respect the broadcasting rules
+def TensorBroadcastingRules : NativeOpTrait<"TensorBroadcastingRules">;
+def TensorBinaryEintInt : NativeOpTrait<"TensorBinaryEintInt">;
+
+def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> {
+  let summary = "Returns a tensor that contains the addition of a tensor of encrypted integers and a tensor of clear integers.";
+
+  let description = [{
+    Performs an addition, following the broadcasting rules, between a tensor of encrypted integers and a tensor of clear integers.
+    The width of the clear integers must be less than or equal to the width of the encrypted integers plus one.
+
+    Examples:
+    ```mlir
+    // Returns the term-to-term addition of `%a0` with `%a1`
+    "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>>
+
+    // Returns the term-to-term addition of `%a0` with `%a1`, where dimensions equal to one are stretched.
+    "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
+
+    // Returns the addition of a 3x3 matrix of encrypted integers and a 3x1 matrix (a column) of integers.
+    //
+    // [1,2,3]   [1]   [2,3,4]
+    // [4,5,6] + [2] = [6,7,8]
+    // [7,8,9]   [3]   [10,11,12]
+    //
+    // The dimension #1 of operand #2 is stretched as it is equal to 1.
+    "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!HLFHE.eint<4>>
+
+    // Returns the addition of a 3x3 matrix of encrypted integers and a 1x3 matrix (a row) of integers.
+    //
+    // [1,2,3]             [2,4,6]
+    // [4,5,6] + [1,2,3] = [5,7,9]
+    // [7,8,9]             [8,10,12]
+    //
+    // The dimension #2 of operand #2 is stretched as it is equal to 1.
+    "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
+
+    // Same behavior as the previous example, but dimension #2 of operand #2 is missing.
+ "HLFHELinalg.add_eint_int(%a0, %a1)" : (tensor<3x4x!HLFHE.eint<4>>, tensor<3xi5>) -> tensor<4x4x4x!HLFHE.eint<4>> + + ``` + }]; + + let arguments = (ins + Type.predicate, HasStaticShapePred]>>:$lhs, + Type.predicate, HasStaticShapePred]>>:$rhs + ); + + let results = (outs Type.predicate, HasStaticShapePred]>>); + + let builders = [ + OpBuilder<(ins "Value":$rhs, "Value":$lhs), [{ + build($_builder, $_state, rhs.getType(), rhs, lhs); + }]> + ]; +} + +#endif diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h new file mode 100644 index 000000000..e26a43b85 --- /dev/null +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h @@ -0,0 +1,11 @@ +#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgTYPES_H +#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgTYPES_H + +#include +#include +#include + +#define GET_TYPEDEF_CLASSES +#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsTypes.h.inc" + +#endif diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.td b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.td new file mode 100644 index 000000000..27c307aed --- /dev/null +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.td @@ -0,0 +1,11 @@ +#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_TYPES +#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_TYPES + +include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.td" +include "mlir/IR/BuiltinTypes.td" +include "zamalang/Dialect/HLFHE/IR/HLFHETypes.td" + +class HLFHELinalg_Type traits = []> : + TypeDef { } + +#endif diff --git a/compiler/lib/Dialect/CMakeLists.txt b/compiler/lib/Dialect/CMakeLists.txt index a69c38e3d..ef6fdf7df 100644 --- a/compiler/lib/Dialect/CMakeLists.txt +++ b/compiler/lib/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(HLFHELinalg) add_subdirectory(HLFHE) add_subdirectory(MidLFHE) add_subdirectory(LowLFHE) diff --git a/compiler/lib/Dialect/HLFHELinalg/CMakeLists.txt b/compiler/lib/Dialect/HLFHELinalg/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/compiler/lib/Dialect/HLFHELinalg/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/compiler/lib/Dialect/HLFHELinalg/IR/CMakeLists.txt b/compiler/lib/Dialect/HLFHELinalg/IR/CMakeLists.txt new file mode 100644 index 000000000..e86a8d305 --- /dev/null +++ b/compiler/lib/Dialect/HLFHELinalg/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(HLFHELinalgDialect + HLFHELinalgDialect.cpp + HLFHELinalgOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/HLFHELinalg + + DEPENDS + MLIRHLFHELinalgOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR) + +target_link_libraries(HLFHELinalgDialect PUBLIC MLIRIR) diff --git a/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.cpp b/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.cpp new file mode 100644 index 000000000..04cc09861 --- /dev/null +++ b/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.cpp @@ -0,0 +1,22 @@ +#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h" +#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h" +#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h" + +#define GET_TYPEDEF_CLASSES +#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsTypes.cpp.inc" + +#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsDialect.cpp.inc" + +using namespace mlir::zamalang::HLFHELinalg; + +void 
+  addOperations<
+#define GET_OP_LIST
+#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp.inc"
+      >();
+
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsTypes.cpp.inc"
+      >();
+}
diff --git a/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp b/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp
new file mode 100644
index 000000000..ba38f1a5c
--- /dev/null
+++ b/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp
@@ -0,0 +1,136 @@
+
+#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h"
+#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h"
+
+namespace mlir {
+namespace OpTrait {
+namespace impl {
+
+LogicalResult verifyTensorBroadcastingRules(
+    mlir::Operation *op, llvm::SmallVector<mlir::RankedTensorType> operands,
+    mlir::RankedTensorType result) {
+  llvm::SmallVector<llvm::ArrayRef<int64_t>> operandsShapes;
+  size_t maxOperandsDim = 0;
+  auto resultShape = result.getShape();
+  for (size_t i = 0; i < operands.size(); i++) {
+    auto shape = operands[i].getShape();
+    operandsShapes.push_back(shape);
+    maxOperandsDim = std::max(shape.size(), maxOperandsDim);
+  }
+  // Check that the result has the same number of dimensions as the operand
+  // with the highest number of dimensions
+  if (resultShape.size() != maxOperandsDim) {
+    op->emitOpError()
+        << "should have the number of dimensions of the result equal to the "
+           "highest number of dimensions of operands"
+        << ", got " << result.getShape().size() << " expect " << maxOperandsDim;
+    return mlir::failure();
+  }
+
+  // For each dimension
+  for (size_t i = 0; i < maxOperandsDim; i++) {
+    int64_t expectedResultDim = 1;
+
+    // Check that the operand dimensions are compatible, i.e. equal to each
+    // other or equal to 1
+    for (size_t j = 0; j < operandsShapes.size(); j++) {
+      if (i < maxOperandsDim - operandsShapes[j].size()) {
+        continue;
+      }
+      auto k = i - (maxOperandsDim - operandsShapes[j].size());
+      auto operandDim = operandsShapes[j][k];
+      if (expectedResultDim != 1 && operandDim != 1 &&
+          operandDim != expectedResultDim) {
+        op->emitOpError() << "has the dimension #"
+                          << (operandsShapes[j].size() - k)
+                          << " of the operand #" << j
+                          << " incompatible with other operands"
+                          << ", got " << operandDim << " expect 1 or "
+                          << expectedResultDim;
+        return mlir::failure();
+      }
+
+      expectedResultDim = std::max(operandDim, expectedResultDim);
+    }
+
+    // Check that the result dimension is compatible with the operand
+    // dimensions
+    if (resultShape[i] != expectedResultDim) {
+      op->emitOpError() << "has the dimension #" << (maxOperandsDim - i)
+                        << " of the result incompatible with operands dimension"
+                        << ", got " << resultShape[i] << " expect "
+                        << expectedResultDim;
+      return mlir::failure();
+    }
+  }
+
+  return mlir::success();
+}
+
+LogicalResult verifyTensorBroadcastingRules(mlir::Operation *op) {
+  // Check that the operand types are ranked tensors
+  llvm::SmallVector<mlir::RankedTensorType> tensorOperands;
+  unsigned i = 0;
+  for (auto opType : op->getOperandTypes()) {
+    auto tensorType = opType.dyn_cast_or_null<mlir::RankedTensorType>();
+    if (tensorType == nullptr) {
+      op->emitOpError() << "should have a ranked tensor as operand #" << i;
+      return mlir::failure();
+    }
+    tensorOperands.push_back(tensorType);
+    i++;
+  }
+  // Check that there is exactly one result
+  if (op->getNumResults() != 1) {
+    op->emitOpError() << "should have exactly 1 result, got "
+                      << op->getNumResults();
+    return mlir::failure();
+  }
+  auto tensorResult =
+      op->getResult(0).getType().dyn_cast_or_null<mlir::RankedTensorType>();
+  if (tensorResult == nullptr) {
+    op->emitOpError(llvm::Twine("should have a ranked tensor as result"));
+    return mlir::failure();
+  }
+  return verifyTensorBroadcastingRules(op, tensorOperands, tensorResult);
+}
+
+LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op) {
+  if (op->getNumOperands() != 2) {
+    op->emitOpError() << "should have exactly 2 operands";
+    return mlir::failure();
+  }
+  auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null<mlir::RankedTensorType>();
+  auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null<mlir::RankedTensorType>();
+  if (op0Ty == nullptr || op1Ty == nullptr) {
+    op->emitOpError() << "should have both operands as tensor";
+    return mlir::failure();
+  }
+  auto el0Ty = op0Ty.getElementType()
+                   .dyn_cast_or_null<mlir::zamalang::HLFHE::EncryptedIntegerType>();
+  if (el0Ty == nullptr) {
+    op->emitOpError() << "should have a !HLFHE.eint as the element type of the "
+                         "tensor of operand #0";
+    return mlir::failure();
+  }
+  auto el1Ty = op1Ty.getElementType().dyn_cast_or_null<mlir::IntegerType>();
+  if (el1Ty == nullptr) {
+    op->emitOpError() << "should have an integer as the element type of the "
+                         "tensor of operand #1";
+    return mlir::failure();
+  }
+  // Check that the width of the clear integers does not exceed the width of
+  // the encrypted integers plus one
+  if (el1Ty.getWidth() > el0Ty.getWidth() + 1) {
+    op->emitOpError()
+        << "should have the width of integer values less or equals "
+           "than the width of encrypted values + 1";
+    return mlir::failure();
+  }
+  return mlir::success();
+}
+} // namespace impl
+
+} // namespace OpTrait
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp.inc"
diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt
index 82c3077ca..5e1b2a6ac 100644
--- a/compiler/lib/Support/CMakeLists.txt
+++ b/compiler/lib/Support/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_library(ZamalangSupport
   MLIRConversionPassIncGen
 
   LINK_LIBS PUBLIC
+  HLFHELinalgDialect
   HLFHETensorOpsToLinalg
   HLFHEToMidLFHE
   LowLFHEUnparametrize
diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp
index afc0abf6a..6fb8f2833 100644
--- a/compiler/src/main.cpp
+++ b/compiler/src/main.cpp
@@ -18,6 +18,7 @@
 #include "zamalang/Conversion/Utils/GlobalFHEContext.h"
 #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
 #include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
+#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h"
 #include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
 #include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
 #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
@@ -485,6 +486,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
   }
 
   // Load our Dialect in this MLIR Context.
+  context.getOrLoadDialect<mlir::zamalang::HLFHELinalg::HLFHELinalgDialect>();
   context.getOrLoadDialect();
   context.getOrLoadDialect();
   context.getOrLoadDialect();
diff --git a/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir b/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir
new file mode 100644
index 000000000..17ce71261
--- /dev/null
+++ b/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir
@@ -0,0 +1,39 @@
+// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=hlfhe --action=roundtrip %s
+
+/////////////////////////////////////////////////
+// HLFHELinalg.add_eint_int
+/////////////////////////////////////////////////
+
+// Incompatible operand dimensions
+func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4xi3>) -> tensor<2x2x3x4x!HLFHE.eint<2>> {
+  // expected-error @+1 {{'HLFHELinalg.add_eint_int' op has the dimension #2 of the operand #1 incompatible with other operands, got 2 expect 1 or 3}}
+  %1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4xi3>) -> tensor<2x2x3x4x!HLFHE.eint<2>>
+  return %1 : tensor<2x2x3x4x!HLFHE.eint<2>>
+}
+
+// -----
+
+// Incompatible result dimension
+func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4xi3>) -> tensor<2x10x3x4x!HLFHE.eint<2>> {
+  // expected-error @+1 {{'HLFHELinalg.add_eint_int' op has the dimension #3 of the result incompatible with operands dimension, got 10 expect 2}}
+  %1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4xi3>) -> tensor<2x10x3x4x!HLFHE.eint<2>>
+  return %1 : tensor<2x10x3x4x!HLFHE.eint<2>>
+}
+
+// -----
+
+// Incompatible number of dimensions between operands and result
+func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4xi3>) -> tensor<2x3x4x!HLFHE.eint<2>> {
+  // expected-error @+1 {{'HLFHELinalg.add_eint_int' op should have the number of dimensions of the result equal to the highest number of dimensions of operands, got 3 expect 4}}
+  %1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4xi3>) -> tensor<2x3x4x!HLFHE.eint<2>>
+  return %1 : tensor<2x3x4x!HLFHE.eint<2>>
+}
+
+// -----
+
+// Incompatible width between clear and encrypted integers
+func @main(%a0: tensor<2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x3x4xi4>) -> tensor<2x3x4x!HLFHE.eint<2>> {
+  // expected-error @+1 {{'HLFHELinalg.add_eint_int' op should have the width of integer values less or equals than the width of encrypted values + 1}}
+  %1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x3x4x!HLFHE.eint<2>>, tensor<2x3x4xi4>) -> tensor<2x3x4x!HLFHE.eint<2>>
+  return %1 : tensor<2x3x4x!HLFHE.eint<2>>
+}
\ No newline at end of file
diff --git a/compiler/tests/Dialect/HLFHELinalg/ops.mlir b/compiler/tests/Dialect/HLFHELinalg/ops.mlir
new file mode 100644
index 000000000..ec1a765e9
--- /dev/null
+++ b/compiler/tests/Dialect/HLFHELinalg/ops.mlir
@@ -0,0 +1,55 @@
+// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
+
+/////////////////////////////////////////////////
+// HLFHELinalg.add_eint_int
+/////////////////////////////////////////////////
+
+// 1D tensor
+// CHECK: func @add_eint_int_1D(%[[a0:.*]]: tensor<4x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>> {
+// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint_int"(%[[a0]], %[[a1]]) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>>
+// CHECK-NEXT: return %[[V0]] : tensor<4x!HLFHE.eint<2>>
+// CHECK-NEXT: }
+func @add_eint_int_1D(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>> {
+  %1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>>
+  return %1: tensor<4x!HLFHE.eint<2>>
+}
+
+// 2D tensor
+// CHECK: func @add_eint_int_2D(%[[a0:.*]]: tensor<2x4x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<2x4xi3>) -> tensor<2x4x!HLFHE.eint<2>> {
+// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint_int"(%[[a0]], %[[a1]]) : (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4xi3>) -> tensor<2x4x!HLFHE.eint<2>>
+// CHECK-NEXT: return %[[V0]] : tensor<2x4x!HLFHE.eint<2>>
+// CHECK-NEXT: }
+func @add_eint_int_2D(%a0: tensor<2x4x!HLFHE.eint<2>>, %a1: tensor<2x4xi3>) -> tensor<2x4x!HLFHE.eint<2>> {
+  %1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4xi3>) -> tensor<2x4x!HLFHE.eint<2>>
+  return %1: tensor<2x4x!HLFHE.eint<2>>
+}
+
+// 10D tensor
+// CHECK: func @add_eint_int_10D(%[[a0:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> {
+// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint_int"(%[[a0]], %[[a1]]) : (tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>
+// CHECK-NEXT: return %[[V0]] : tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>
+// CHECK-NEXT: }
+func @add_eint_int_10D(%a0: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, %a1: tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> {
+  %1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>
+  return %1: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>
+}
+
+// Broadcasting with tensors that have dimensions equal to one
+// CHECK: func @add_eint_int_broadcast_1(%[[a0:.*]]: tensor<1x4x5x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<3x4x1xi3>) -> tensor<3x4x5x!HLFHE.eint<2>> {
+// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint_int"(%[[a0]], %[[a1]]) : (tensor<1x4x5x!HLFHE.eint<2>>, tensor<3x4x1xi3>) -> tensor<3x4x5x!HLFHE.eint<2>>
+// CHECK-NEXT: return %[[V0]] : tensor<3x4x5x!HLFHE.eint<2>>
+// CHECK-NEXT: }
+func @add_eint_int_broadcast_1(%a0: tensor<1x4x5x!HLFHE.eint<2>>, %a1: tensor<3x4x1xi3>) -> tensor<3x4x5x!HLFHE.eint<2>> {
+  %1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<1x4x5x!HLFHE.eint<2>>, tensor<3x4x1xi3>) -> tensor<3x4x5x!HLFHE.eint<2>>
+  return %1: tensor<3x4x5x!HLFHE.eint<2>>
+}
+
+// Broadcasting with a tensor that has fewer dimensions than the other
+// CHECK: func @add_eint_int_broadcast_2(%[[a0:.*]]: tensor<4x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>> {
+// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint_int"(%[[a0]], %[[a1]]) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>>
+// CHECK-NEXT: return %[[V0]] : tensor<3x4x!HLFHE.eint<2>>
+// CHECK-NEXT: }
+func @add_eint_int_broadcast_2(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>> {
+  %1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>>
+  return %1: tensor<3x4x!HLFHE.eint<2>>
+}
\ No newline at end of file