mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compiler): Introduce the HLFHELinalg dialect and a first operator HLFHELinalg.add_eint_int
This commit is contained in:
committed by
Andi Drebes
parent
247cc489c5
commit
0d4e10169b
@@ -1,3 +1,4 @@
|
||||
add_subdirectory(HLFHE)
|
||||
add_subdirectory(HLFHELinalg)
|
||||
add_subdirectory(MidLFHE)
|
||||
add_subdirectory(LowLFHE)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
add_subdirectory(IR)
|
||||
@@ -0,0 +1,9 @@
|
||||
set(LLVM_TARGET_DEFINITIONS HLFHELinalgOps.td)
|
||||
mlir_tablegen(HLFHELinalgOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(HLFHELinalgOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(HLFHELinalgOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=HLFHELinalg)
|
||||
mlir_tablegen(HLFHELinalgOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=HLFHELinalg)
|
||||
mlir_tablegen(HLFHELinalgOpsDialect.h.inc -gen-dialect-decls -dialect=HLFHELinalg)
|
||||
mlir_tablegen(HLFHELinalgOpsDialect.cpp.inc -gen-dialect-defs -dialect=HLFHELinalg)
|
||||
add_public_tablegen_target(MLIRHLFHELinalgOpsIncGen)
|
||||
add_dependencies(mlir-headers MLIRHLFHELinalgOpsIncGen)
|
||||
@@ -0,0 +1,10 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgDIALECT_H
|
||||
#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgDIALECT_H
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsDialect.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,15 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_DIALECT
|
||||
#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def HLFHELinalg_Dialect : Dialect {
|
||||
let name = "HLFHELinalg";
|
||||
let summary = "High Level Fully Homorphic Encryption Linalg dialect";
|
||||
let description = [{
|
||||
A dialect for representation of high level linalg operations on fully homomorphic ciphertexts.
|
||||
}];
|
||||
let cppNamespace = "::mlir::zamalang::HLFHELinalg";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,56 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgOPS_H
|
||||
#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgOPS_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h"
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
|
||||
namespace impl {
|
||||
LogicalResult verifyTensorBroadcastingRules(mlir::Operation *op);
|
||||
LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op);
|
||||
} // namespace impl
|
||||
|
||||
/// TensorBroadcastingRules is a trait for operators that should respect the
|
||||
/// broadcasting rules. All of the operands should be a RankedTensorType, the
|
||||
/// result must be unique and be a RankedTensorType. The operands shape are
|
||||
/// considered compatible if we compare dimensions of shapes from the right to
|
||||
/// the left and if dimension are equals, or equals to one. If one of the shape
|
||||
/// are smaller than the others, the missing dimension are considered to be one.
|
||||
/// The result shape should have the size of the largest shape of operands and
|
||||
/// each dimension `i` should be equals to the maximum of dimensions `i` of
|
||||
/// each operands.
|
||||
template <typename ConcreteType>
|
||||
class TensorBroadcastingRules
|
||||
: public mlir::OpTrait::TraitBase<ConcreteType, TensorBroadcastingRules> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifyTensorBroadcastingRules(op);
|
||||
}
|
||||
};
|
||||
|
||||
/// TensorBinaryEintInt verifies that the operation matches the following
|
||||
/// signature
|
||||
/// `(tensor<...x!HLFHE.eint<$p>>, tensor<...xi$p'>) ->
|
||||
/// tensor<...x!HLFHE.eint<$p>>` where `$p <= $p+1`.
|
||||
template <typename ConcreteType>
|
||||
class TensorBinaryEintInt
|
||||
: public mlir::OpTrait::TraitBase<ConcreteType, TensorBinaryEintInt> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifyTensorBinaryEintInt(op);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,70 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_OPS
|
||||
#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_OPS
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
|
||||
include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.td"
|
||||
include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.td"
|
||||
|
||||
class HLFHELinalg_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<HLFHELinalg_Dialect, mnemonic, traits>;
|
||||
|
||||
// TensorBroadcastingRules verify that the operands and result verify the broadcasting rules
|
||||
def TensorBroadcastingRules : NativeOpTrait<"TensorBroadcastingRules">;
|
||||
def TensorBinaryEintInt : NativeOpTrait<"TensorBinaryEintInt">;
|
||||
|
||||
def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> {
|
||||
let summary = "Returns a tensor that contains the addition of a tensor of encrypted integers and a tensor of clear integers.";
|
||||
|
||||
let description = [{
|
||||
Performs an addition follwing the broadcasting rules between a tensor of encrypted integers and a tensor of clear integers.
|
||||
The width of the clear integers should be less or equals than the witdh of encrypted integers.
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
// Returns the term to term addition of `%a0` with `%a1`
|
||||
"HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>>
|
||||
|
||||
// Returns the term to term addition of `%a0` with `%a1`, where dimensions equal to one are stretched.
|
||||
"HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
|
||||
|
||||
// Returns the addition of a 3x3 matrix of encrypted integers and a 3x1 matrix (a column) of integers.
|
||||
//
|
||||
// [1,2,3] [1] [2,3,4]
|
||||
// [4,5,6] + [2] = [6,7,8]
|
||||
// [7,8,9] [3] [10,11,12]
|
||||
//
|
||||
// The dimension #1 of operand #2 is stretched as it is equals to 1.
|
||||
"HLFHELinalg.add_eint_int(%a0, %a1)" : (tensor<3x4x!HLFHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
|
||||
// Returns the addition of a 3x3 matrix of encrypted integers and a 1x3 matrix (a line) of integers.
|
||||
//
|
||||
// [1,2,3] [2,4,6]
|
||||
// [4,5,6] + [1,2,3] = [5,7,9]
|
||||
// [7,8,9] [8,10,12]
|
||||
//
|
||||
// The dimension #2 of operand #2 is stretched as it is equals to 1.
|
||||
"HLFHELinalg.add_eint_int(%a0, %a1)" : (tensor<3x4x!HLFHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
|
||||
// Same behavior than the previous one, but as the dimension #2 is missing of operand #2.
|
||||
"HLFHELinalg.add_eint_int(%a0, %a1)" : (tensor<3x4x!HLFHE.eint<4>>, tensor<3xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
|
||||
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$lhs,
|
||||
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$rhs
|
||||
);
|
||||
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$rhs, "Value":$lhs), [{
|
||||
build($_builder, $_state, rhs.getType(), rhs, lhs);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,11 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgTYPES_H
|
||||
#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalgTYPES_H
|
||||
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/IR/DialectImplementation.h>
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsTypes.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,11 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_TYPES
|
||||
#define ZAMALANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_TYPES
|
||||
|
||||
include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
include "zamalang/Dialect/HLFHE/IR/HLFHETypes.td"
|
||||
|
||||
class HLFHELinalg_Type<string name, list<Trait> traits = []> :
|
||||
TypeDef<HLFHELinalg_Dialect, name, traits> { }
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user