From 4e663df5afc5f830545fce0cad8734723d9feae2 Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 25 May 2021 12:45:25 +0100 Subject: [PATCH] enhance(compiler): Add a region to the HLFHE.apply_univariate for the expression of the univariate closure --- .../zamalang/Dialect/HLFHE/IR/HLFHEOps.h | 3 ++ .../zamalang/Dialect/HLFHE/IR/HLFHEOps.td | 30 +++++++++++++++---- compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp | 23 ++++++++++++++ compiler/tests/Dialect/HLFHE/ops.mlir | 20 +++++++++++++ 4 files changed, 71 insertions(+), 5 deletions(-) diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.h b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.h index 3a60ba09d..f5f6f9d50 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.h +++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.h @@ -3,6 +3,9 @@ #include #include +#include +#include + #include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" #define GET_OP_CLASSES diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td index 095471d51..2797f6d1f 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td +++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td @@ -9,12 +9,21 @@ #ifndef ZAMALANG_DIALECT_HLFHE_IR_HLFHE_OPS #define ZAMALANG_DIALECT_HLFHE_IR_HLFHE_OPS +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" + include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.td" include "zamalang/Dialect/HLFHE/IR/HLFHETypes.td" class HLFHE_Op traits = []> : Op; + +def ApplyUnivariateRegion : Region< + CPred<"::mlir::zamalang::predApplyUnivariateRegion($_self)">, + "apply_univariate region needs one block with one any integer argument">; + + def AddEintIntOp : HLFHE_Op<"add_eint_int"> { let arguments = (ins EncryptedIntegerType:$a, AnyInteger:$b); let results = (outs EncryptedIntegerType); @@ -72,16 +81,27 @@ def MulEintOp : HLFHE_Op<"mul_eint"> { } def ApplyUnivariateOp : HLFHE_Op<"apply_univariate"> { - // TODO: express a functionLike? - let arguments = (ins EncryptedIntegerType:$a); - let results = (outs EncryptedIntegerType); + let arguments = (ins EncryptedIntegerType:$x); + let results = (outs EncryptedIntegerType:$result); let builders = [ - OpBuilder<(ins "Value":$a), [{ - build($_builder, $_state, a.getType(), a); + OpBuilder<(ins "Value": $x), [{ + build($_builder, $_state, x.getType(), x); }]> ]; + + + let regions = (region ApplyUnivariateRegion:$body); + // let assemblyFormat = "$x `:` type($x) $body attr-dict `:` type($result)"; } +def ReturnOp : HLFHE_Op<"apply_univariate_return", [NoSideEffect, ReturnLike, Terminator]> { + let summary = "terminator of apply_univariate block"; + let arguments = (ins AnyInteger); + let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; +} + + + #endif diff --git a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp index ea8115355..13c487a60 100644 --- a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp +++ b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp @@ -1,5 +1,28 @@ +#include "mlir/IR/Region.h" + #include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h" #include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" + +namespace mlir{ +namespace zamalang{ + bool predApplyUnivariateRegion(::mlir::Region ®ion){ + if (region.getBlocks().size() != 1) { + return false; + } + auto args = region.getBlocks().front().getArguments(); + if (args.size() != 1) { + return false; + } + if (! args.front().getType().isa()){ + return false; + } + //TODO: need to handle when there is no terminator + auto terminator = region.getBlocks().front().getTerminator(); + return terminator->getName().getStringRef().equals("HLFHE.apply_univariate_return"); + } +} +} + #define GET_OP_CLASSES #include "zamalang/Dialect/HLFHE/IR/HLFHEOps.cpp.inc" diff --git a/compiler/tests/Dialect/HLFHE/ops.mlir b/compiler/tests/Dialect/HLFHE/ops.mlir index 1f604d3e0..036d73b48 100644 --- a/compiler/tests/Dialect/HLFHE/ops.mlir +++ b/compiler/tests/Dialect/HLFHE/ops.mlir @@ -48,3 +48,23 @@ func @neg_eint(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0> { %1 = "HLFHE.neg_eint"(%arg0): (!HLFHE.eint<0>) -> (!HLFHE.eint<0>) return %1: !HLFHE.eint<0> } + +// CHECK-LABEL: func @apply_univariate(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0> +func @apply_univariate(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0> { + // CHECK-NEXT: %[[V1:.*]] = "HLFHE.apply_univariate"(%arg0) ( { + // CHECK-NEXT: ^bb0(%[[V2:.*]]: i32): + // CHECK-NEXT: %[[CST:.*]] = constant 5 : i32 + // CHECK-NEXT: %[[V3:.*]] = muli %[[V2]], %[[CST]] : i32 + // CHECK-NEXT: "HLFHE.apply_univariate_return"(%[[V3]]) : (i32) -> () + // CHECK-NEXT: }) : (!HLFHE.eint<0>) -> !HLFHE.eint<0> + // CHECK-NEXT: return %[[V1]] : !HLFHE.eint<0> + + %0 = "HLFHE.apply_univariate"(%arg0)({ + ^bb0(%a: i32): + %cst = constant 5: i32 + %res = std.muli %a, %cst : i32 + "HLFHE.apply_univariate_return"(%res): (i32) -> () + }) : (!HLFHE.eint<0>) -> !HLFHE.eint<0> + + return %0: !HLFHE.eint<0> +}