mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): Introduce the HLFHELinalg dialect and a first operator HLFHELinalg.add_eint_int
This commit is contained in:
committed by
Andi Drebes
parent
247cc489c5
commit
0d4e10169b
@@ -1,3 +1,4 @@
|
||||
add_subdirectory(HLFHE)
|
||||
add_subdirectory(HLFHELinalg)
|
||||
add_subdirectory(MidLFHE)
|
||||
add_subdirectory(LowLFHE)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
add_subdirectory(IR)
|
||||
@@ -0,0 +1,9 @@
|
||||
set(LLVM_TARGET_DEFINITIONS HLFHELinalgOps.td)
|
||||
mlir_tablegen(HLFHELinalgOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(HLFHELinalgOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(HLFHELinalgOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=HLFHELinalg)
|
||||
mlir_tablegen(HLFHELinalgOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=HLFHELinalg)
|
||||
mlir_tablegen(HLFHELinalgOpsDialect.h.inc -gen-dialect-decls -dialect=HLFHELinalg)
|
||||
mlir_tablegen(HLFHELinalgOpsDialect.cpp.inc -gen-dialect-defs -dialect=HLFHELinalg)
|
||||
add_public_tablegen_target(MLIRHLFHELinalgOpsIncGen)
|
||||
add_dependencies(mlir-headers MLIRHLFHELinalgOpsIncGen)
|
||||
@@ -0,0 +1,10 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgDIALECT_H
|
||||
#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgDIALECT_H
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsDialect.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,15 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_DIALECT
|
||||
#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def HLFHELinalg_Dialect : Dialect {
|
||||
let name = "HLFHELinalg";
|
||||
let summary = "High Level Fully Homorphic Encryption Linalg dialect";
|
||||
let description = [{
|
||||
A dialect for representation of high level linalg operations on fully homomorphic ciphertexts.
|
||||
}];
|
||||
let cppNamespace = "::mlir::zamalang::HLFHELinalg";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,56 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgOPS_H
|
||||
#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgOPS_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h"
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
|
||||
namespace impl {
|
||||
LogicalResult verifyTensorBroadcastingRules(mlir::Operation *op);
|
||||
LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op);
|
||||
} // namespace impl
|
||||
|
||||
/// TensorBroadcastingRules is a trait for operators that should respect the
|
||||
/// broadcasting rules. All of the operands should be a RankedTensorType, the
|
||||
/// result must be unique and be a RankedTensorType. The operands shape are
|
||||
/// considered compatible if we compare dimensions of shapes from the right to
|
||||
/// the left and if dimension are equals, or equals to one. If one of the shape
|
||||
/// are smaller than the others, the missing dimension are considered to be one.
|
||||
/// The result shape should have the size of the largest shape of operands and
|
||||
/// each dimension `i` should be equals to the maximum of dimensions `i` of
|
||||
/// each operands.
|
||||
template <typename ConcreteType>
|
||||
class TensorBroadcastingRules
|
||||
: public mlir::OpTrait::TraitBase<ConcreteType, TensorBroadcastingRules> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifyTensorBroadcastingRules(op);
|
||||
}
|
||||
};
|
||||
|
||||
/// TensorBinaryEintInt verifies that the operation matches the following
|
||||
/// signature
|
||||
/// `(tensor<...x!HLFHE.eint<$p>>, tensor<...xi$p'>) ->
|
||||
/// tensor<...x!HLFHE.eint<$p>>` where `$p <= $p+1`.
|
||||
template <typename ConcreteType>
|
||||
class TensorBinaryEintInt
|
||||
: public mlir::OpTrait::TraitBase<ConcreteType, TensorBinaryEintInt> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifyTensorBinaryEintInt(op);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,70 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_OPS
|
||||
#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_OPS
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
|
||||
include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.td"
|
||||
include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.td"
|
||||
|
||||
class HLFHELinalg_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<HLFHELinalg_Dialect, mnemonic, traits>;
|
||||
|
||||
// TensorBroadcastingRules verify that the operands and result verify the broadcasting rules
|
||||
def TensorBroadcastingRules : NativeOpTrait<"TensorBroadcastingRules">;
|
||||
def TensorBinaryEintInt : NativeOpTrait<"TensorBinaryEintInt">;
|
||||
|
||||
def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> {
|
||||
let summary = "Returns a tensor that contains the addition of a tensor of encrypted integers and a tensor of clear integers.";
|
||||
|
||||
let description = [{
|
||||
Performs an addition follwing the broadcasting rules between a tensor of encrypted integers and a tensor of clear integers.
|
||||
The width of the clear integers should be less or equals than the witdh of encrypted integers.
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
// Returns the term to term addition of `%a0` with `%a1`
|
||||
"HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>>
|
||||
|
||||
// Returns the term to term addition of `%a0` with `%a1`, where dimensions equal to one are stretched.
|
||||
"HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
|
||||
|
||||
// Returns the addition of a 3x3 matrix of encrypted integers and a 3x1 matrix (a column) of integers.
|
||||
//
|
||||
// [1,2,3] [1] [2,3,4]
|
||||
// [4,5,6] + [2] = [6,7,8]
|
||||
// [7,8,9] [3] [10,11,12]
|
||||
//
|
||||
// The dimension #1 of operand #2 is stretched as it is equals to 1.
|
||||
"HLFHELinalg.add_eint_int(%a0, %a1)" : (tensor<3x4x!HLFHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
|
||||
// Returns the addition of a 3x3 matrix of encrypted integers and a 1x3 matrix (a line) of integers.
|
||||
//
|
||||
// [1,2,3] [2,4,6]
|
||||
// [4,5,6] + [1,2,3] = [5,7,9]
|
||||
// [7,8,9] [8,10,12]
|
||||
//
|
||||
// The dimension #2 of operand #2 is stretched as it is equals to 1.
|
||||
"HLFHELinalg.add_eint_int(%a0, %a1)" : (tensor<3x4x!HLFHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
|
||||
// Same behavior than the previous one, but as the dimension #2 is missing of operand #2.
|
||||
"HLFHELinalg.add_eint_int(%a0, %a1)" : (tensor<3x4x!HLFHE.eint<4>>, tensor<3xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
|
||||
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$lhs,
|
||||
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$rhs
|
||||
);
|
||||
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$rhs, "Value":$lhs), [{
|
||||
build($_builder, $_state, rhs.getType(), rhs, lhs);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,11 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgTYPES_H
|
||||
#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgTYPES_H
|
||||
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/IR/DialectImplementation.h>
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsTypes.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,11 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_TYPES
|
||||
#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_TYPES
|
||||
|
||||
include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
include "zamalang/Dialect/HLFHE/IR/HLFHETypes.td"
|
||||
|
||||
class HLFHELinalg_Type<string name, list<Trait> traits = []> :
|
||||
TypeDef<HLFHELinalg_Dialect, name, traits> { }
|
||||
|
||||
#endif
|
||||
@@ -1,3 +1,4 @@
|
||||
add_subdirectory(HLFHELinalg)
|
||||
add_subdirectory(HLFHE)
|
||||
add_subdirectory(MidLFHE)
|
||||
add_subdirectory(LowLFHE)
|
||||
|
||||
1
compiler/lib/Dialect/HLFHELinalg/CMakeLists.txt
Normal file
1
compiler/lib/Dialect/HLFHELinalg/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_subdirectory(IR)
|
||||
14
compiler/lib/Dialect/HLFHELinalg/IR/CMakeLists.txt
Normal file
14
compiler/lib/Dialect/HLFHELinalg/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,14 @@
|
||||
add_mlir_dialect_library(HLFHELinalgDialect
|
||||
HLFHELinalgDialect.cpp
|
||||
HLFHELinalgOps.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/HLFHELinalg
|
||||
|
||||
DEPENDS
|
||||
MLIRHLFHELinalgOpsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR)
|
||||
|
||||
target_link_libraries(HLFHELinalgDialect PUBLIC MLIRIR)
|
||||
22
compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.cpp
Normal file
22
compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.cpp
Normal file
@@ -0,0 +1,22 @@
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h"
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h"
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsTypes.cpp.inc"
|
||||
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsDialect.cpp.inc"
|
||||
|
||||
using namespace mlir::zamalang::HLFHELinalg;
|
||||
|
||||
void HLFHELinalgDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp.inc"
|
||||
>();
|
||||
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsTypes.cpp.inc"
|
||||
>();
|
||||
}
|
||||
136
compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp
Normal file
136
compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp
Normal file
@@ -0,0 +1,136 @@
|
||||
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h"
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
namespace impl {
|
||||
|
||||
LogicalResult verifyTensorBroadcastingRules(
|
||||
mlir::Operation *op, llvm::SmallVector<mlir::RankedTensorType> operands,
|
||||
mlir::RankedTensorType result) {
|
||||
llvm::SmallVector<llvm::ArrayRef<int64_t>> operandsShapes;
|
||||
size_t maxOperandsDim = 0;
|
||||
auto resultShape = result.getShape();
|
||||
for (size_t i = 0; i < operands.size(); i++) {
|
||||
auto shape = operands[i].getShape();
|
||||
operandsShapes.push_back(shape);
|
||||
maxOperandsDim = std::max(shape.size(), maxOperandsDim);
|
||||
}
|
||||
// Check the result has the same number of dimension than the highest
|
||||
// dimension of operands
|
||||
if (resultShape.size() != maxOperandsDim) {
|
||||
op->emitOpError()
|
||||
<< "should have the number of dimensions of the result equal to the "
|
||||
"highest number of dimensions of operands"
|
||||
<< ", got " << result.getShape().size() << " expect " << maxOperandsDim;
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
// For all dimension
|
||||
for (size_t i = 0; i < maxOperandsDim; i++) {
|
||||
int64_t expectedResultDim = 1;
|
||||
|
||||
// Check the dimension of operands shape are compatible, i.e. equals or 1
|
||||
for (size_t j = 0; j < operandsShapes.size(); j++) {
|
||||
if (i < maxOperandsDim - operandsShapes[j].size()) {
|
||||
continue;
|
||||
}
|
||||
auto k = i - (maxOperandsDim - operandsShapes[j].size());
|
||||
auto operandDim = operandsShapes[j][k];
|
||||
if (expectedResultDim != 1 && operandDim != 1 &&
|
||||
operandDim != expectedResultDim) {
|
||||
op->emitOpError() << "has the dimension #"
|
||||
<< (operandsShapes[j].size() - k)
|
||||
<< " of the operand #" << j
|
||||
<< " incompatible with other operands"
|
||||
<< ", got " << operandDim << " expect 1 or "
|
||||
<< expectedResultDim;
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
expectedResultDim = std::max(operandDim, expectedResultDim);
|
||||
}
|
||||
|
||||
// Check the dimension of the result is compatible with dimesion of the
|
||||
// operands
|
||||
if (resultShape[i] != expectedResultDim) {
|
||||
op->emitOpError() << "has the dimension #" << (maxOperandsDim - i)
|
||||
<< " of the result incompatible with operands dimension"
|
||||
<< ", got " << resultShape[i] << " expect "
|
||||
<< expectedResultDim;
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
LogicalResult verifyTensorBroadcastingRules(mlir::Operation *op) {
|
||||
// Check operands type are ranked tensor
|
||||
llvm::SmallVector<mlir::RankedTensorType> tensorOperands;
|
||||
unsigned i = 0;
|
||||
for (auto opType : op->getOperandTypes()) {
|
||||
auto tensorType = opType.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
if (tensorType == nullptr) {
|
||||
op->emitOpError() << " should have a ranked tensor as operand #" << i;
|
||||
return mlir::failure();
|
||||
}
|
||||
tensorOperands.push_back(tensorType);
|
||||
i++;
|
||||
}
|
||||
// Check number of result is 1
|
||||
if (op->getNumResults() != 1) {
|
||||
op->emitOpError() << "should have exactly 1 result, got "
|
||||
<< op->getNumResults();
|
||||
}
|
||||
auto tensorResult =
|
||||
op->getResult(0).getType().dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
if (tensorResult == nullptr) {
|
||||
op->emitOpError(llvm::Twine("should have a ranked tensor as result"));
|
||||
return mlir::failure();
|
||||
}
|
||||
return verifyTensorBroadcastingRules(op, tensorOperands, tensorResult);
|
||||
}
|
||||
|
||||
LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op) {
|
||||
if (op->getNumOperands() != 2) {
|
||||
op->emitOpError() << "should have exactly 2 operands";
|
||||
return mlir::failure();
|
||||
}
|
||||
auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null<mlir::TensorType>();
|
||||
auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null<mlir::TensorType>();
|
||||
if (op0Ty == nullptr || op1Ty == nullptr) {
|
||||
op->emitOpError() << "should have both operands as tensor";
|
||||
return mlir::failure();
|
||||
}
|
||||
auto el0Ty =
|
||||
op0Ty.getElementType()
|
||||
.dyn_cast_or_null<mlir::zamalang::HLFHE::EncryptedIntegerType>();
|
||||
if (el0Ty == nullptr) {
|
||||
op->emitOpError() << "should have a !HLFHE.eint as the element type of the "
|
||||
"tensor of operand #0";
|
||||
return mlir::failure();
|
||||
}
|
||||
auto el1Ty = op1Ty.getElementType().dyn_cast_or_null<mlir::IntegerType>();
|
||||
if (el1Ty == nullptr) {
|
||||
op->emitOpError() << "should have an integer as the element type of the "
|
||||
"tensor of operand #1";
|
||||
return mlir::failure();
|
||||
}
|
||||
// llvm::errs() << width << "";
|
||||
if (el1Ty.getWidth() > el0Ty.getWidth() + 1) {
|
||||
op->emitOpError()
|
||||
<< "should have the width of integer values less or equals "
|
||||
"than the width of encrypted values + 1";
|
||||
return mlir::failure();
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp.inc"
|
||||
@@ -16,6 +16,7 @@ add_mlir_library(ZamalangSupport
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
HLFHELinalgDialect
|
||||
HLFHETensorOpsToLinalg
|
||||
HLFHEToMidLFHE
|
||||
LowLFHEUnparametrize
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "zamalang/Conversion/Utils/GlobalFHEContext.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
@@ -485,6 +486,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
}
|
||||
|
||||
// Load our Dialect in this MLIR Context.
|
||||
context.getOrLoadDialect<mlir::zamalang::HLFHELinalg::HLFHELinalgDialect>();
|
||||
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
|
||||
context.getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
|
||||
context.getOrLoadDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
|
||||
|
||||
39
compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir
Normal file
39
compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir
Normal file
@@ -0,0 +1,39 @@
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=hlfhe --action=roundtrip %s
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// HLFHELinalg.add_eint_int
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
// Incompatible dimension of operands
|
||||
func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4xi3>) -> tensor<2x2x3x4x!HLFHE.eint<2>> {
|
||||
// expected-error @+1 {{'HLFHELinalg.add_eint_int' op has the dimension #2 of the operand #1 incompatible with other operands, got 2 expect 1 or 3}}
|
||||
%1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4xi3>) -> tensor<2x2x3x4x!HLFHE.eint<2>>
|
||||
return %1 : tensor<2x2x3x4x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Incompatible dimension of result
|
||||
func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4xi3>) -> tensor<2x10x3x4x!HLFHE.eint<2>> {
|
||||
// expected-error @+1 {{'HLFHELinalg.add_eint_int' op has the dimension #3 of the result incompatible with operands dimension, got 10 expect 2}}
|
||||
%1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4xi3>) -> tensor<2x10x3x4x!HLFHE.eint<2>>
|
||||
return %1 : tensor<2x10x3x4x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Incompatible number of dimension between operands and result
|
||||
func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4xi3>) -> tensor<2x3x4x!HLFHE.eint<2>> {
|
||||
// expected-error @+1 {{'HLFHELinalg.add_eint_int' op should have the number of dimensions of the result equal to the highest number of dimensions of operands, got 3 expect 4}}
|
||||
%1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4xi3>) -> tensor<2x3x4x!HLFHE.eint<2>>
|
||||
return %1 : tensor<2x3x4x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Incompatible width between clear and encrypted witdh
|
||||
func @main(%a0: tensor<2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x3x4xi4>) -> tensor<2x3x4x!HLFHE.eint<2>> {
|
||||
// expected-error @+1 {{'HLFHELinalg.add_eint_int' op should have the width of integer values less or equals than the width of encrypted values + 1}}
|
||||
%1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x3x4x!HLFHE.eint<2>>, tensor<2x3x4xi4>) -> tensor<2x3x4x!HLFHE.eint<2>>
|
||||
return %1 : tensor<2x3x4x!HLFHE.eint<2>>
|
||||
}
|
||||
55
compiler/tests/Dialect/HLFHELinalg/ops.mlir
Normal file
55
compiler/tests/Dialect/HLFHELinalg/ops.mlir
Normal file
@@ -0,0 +1,55 @@
|
||||
// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// HLFHELinalg.add_eint_int
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
// 1D tensor
|
||||
// CHECK: func @add_eint_int_1D(%[[a0:.*]]: tensor<4x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint_int"(%[[a0]], %[[a1]]) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<4x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @add_eint_int_1D(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>> {
|
||||
%1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>>
|
||||
return %1: tensor<4x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// 2D tensor
|
||||
// CHECK: func @add_eint_int_2D(%[[a0:.*]]: tensor<2x4x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<2x4xi3>) -> tensor<2x4x!HLFHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint_int"(%[[a0]], %[[a1]]) : (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4xi3>) -> tensor<2x4x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<2x4x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @add_eint_int_2D(%a0: tensor<2x4x!HLFHE.eint<2>>, %a1: tensor<2x4xi3>) -> tensor<2x4x!HLFHE.eint<2>> {
|
||||
%1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4xi3>) -> tensor<2x4x!HLFHE.eint<2>>
|
||||
return %1: tensor<2x4x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// 10D tensor
|
||||
// CHECK: func @add_eint_int_10D(%[[a0:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint_int"(%[[a0]], %[[a1]]) : (tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @add_eint_int_10D(%a0: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, %a1: tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> {
|
||||
%1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>
|
||||
return %1: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// Broadcasting with tensor with dimensions equals to one
|
||||
// CHECK: func @add_eint_int_broadcast_1(%[[a0:.*]]: tensor<1x4x5x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<3x4x1xi3>) -> tensor<3x4x5x!HLFHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint_int"(%[[a0]], %[[a1]]) : (tensor<1x4x5x!HLFHE.eint<2>>, tensor<3x4x1xi3>) -> tensor<3x4x5x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<3x4x5x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @add_eint_int_broadcast_1(%a0: tensor<1x4x5x!HLFHE.eint<2>>, %a1: tensor<3x4x1xi3>) -> tensor<3x4x5x!HLFHE.eint<2>> {
|
||||
%1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<1x4x5x!HLFHE.eint<2>>, tensor<3x4x1xi3>) -> tensor<3x4x5x!HLFHE.eint<2>>
|
||||
return %1: tensor<3x4x5x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// Broadcasting with a tensor less dimensions of another
|
||||
// CHECK: func @add_eint_int_broadcast_2(%[[a0:.*]]: tensor<4x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint_int"(%[[a0]], %[[a1]]) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<3x4x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @add_eint_int_broadcast_2(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>> {
|
||||
%1 ="HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>>
|
||||
return %1: tensor<3x4x!HLFHE.eint<2>>
|
||||
}
|
||||
Reference in New Issue
Block a user