enhance(compiler): Add a region to the HLFHE.apply_univariate for the expression of the univariate closure

This commit is contained in:
youben11
2021-05-25 12:45:25 +01:00
committed by Quentin Bourgerie
parent 183fe7f6fa
commit 4e663df5af
4 changed files with 71 additions and 5 deletions

View File

@@ -3,6 +3,9 @@
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Interfaces/ControlFlowInterfaces.h>
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#define GET_OP_CLASSES

View File

@@ -9,12 +9,21 @@
#ifndef ZAMALANG_DIALECT_HLFHE_IR_HLFHE_OPS
#define ZAMALANG_DIALECT_HLFHE_IR_HLFHE_OPS
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.td"
include "zamalang/Dialect/HLFHE/IR/HLFHETypes.td"
class HLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
Op<HLFHE_Dialect, mnemonic, traits>;
def ApplyUnivariateRegion : Region<
CPred<"::mlir::zamalang::predApplyUnivariateRegion($_self)">,
"apply_univariate region needs one block with one any integer argument">;
def AddEintIntOp : HLFHE_Op<"add_eint_int"> {
let arguments = (ins EncryptedIntegerType:$a, AnyInteger:$b);
let results = (outs EncryptedIntegerType);
@@ -72,16 +81,27 @@ def MulEintOp : HLFHE_Op<"mul_eint"> {
}
def ApplyUnivariateOp : HLFHE_Op<"apply_univariate"> {
// TODO: express a functionLike?
let arguments = (ins EncryptedIntegerType:$a);
let results = (outs EncryptedIntegerType);
let arguments = (ins EncryptedIntegerType:$x);
let results = (outs EncryptedIntegerType:$result);
let builders = [
OpBuilder<(ins "Value":$a), [{
build($_builder, $_state, a.getType(), a);
OpBuilder<(ins "Value": $x), [{
build($_builder, $_state, x.getType(), x);
}]>
];
let regions = (region ApplyUnivariateRegion:$body);
// let assemblyFormat = "$x `:` type($x) $body attr-dict `:` type($result)";
}
def ReturnOp : HLFHE_Op<"apply_univariate_return", [NoSideEffect, ReturnLike, Terminator]> {
let summary = "terminator of apply_univariate block";
let arguments = (ins AnyInteger);
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
}
#endif

View File

@@ -1,5 +1,28 @@
#include "mlir/IR/Region.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
namespace mlir{
namespace zamalang{
bool predApplyUnivariateRegion(::mlir::Region &region){
if (region.getBlocks().size() != 1) {
return false;
}
auto args = region.getBlocks().front().getArguments();
if (args.size() != 1) {
return false;
}
if (! args.front().getType().isa<mlir::IntegerType>()){
return false;
}
//TODO: need to handle when there is no terminator
auto terminator = region.getBlocks().front().getTerminator();
return terminator->getName().getStringRef().equals("HLFHE.apply_univariate_return");
}
}
}
#define GET_OP_CLASSES
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.cpp.inc"

View File

@@ -48,3 +48,23 @@ func @neg_eint(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0> {
%1 = "HLFHE.neg_eint"(%arg0): (!HLFHE.eint<0>) -> (!HLFHE.eint<0>)
return %1: !HLFHE.eint<0>
}
// CHECK-LABEL: func @apply_univariate(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0>
func @apply_univariate(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0> {
// CHECK-NEXT: %[[V1:.*]] = "HLFHE.apply_univariate"(%arg0) ( {
// CHECK-NEXT: ^bb0(%[[V2:.*]]: i32):
// CHECK-NEXT: %[[CST:.*]] = constant 5 : i32
// CHECK-NEXT: %[[V3:.*]] = muli %[[V2]], %[[CST]] : i32
// CHECK-NEXT: "HLFHE.apply_univariate_return"(%[[V3]]) : (i32) -> ()
// CHECK-NEXT: }) : (!HLFHE.eint<0>) -> !HLFHE.eint<0>
// CHECK-NEXT: return %[[V1]] : !HLFHE.eint<0>
%0 = "HLFHE.apply_univariate"(%arg0)({
^bb0(%a: i32):
%cst = constant 5: i32
%res = std.muli %a, %cst : i32
"HLFHE.apply_univariate_return"(%res): (i32) -> ()
}) : (!HLFHE.eint<0>) -> !HLFHE.eint<0>
return %0: !HLFHE.eint<0>
}