mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
enhance(compiler): Add a region to the HLFHE.apply_univariate for the expression of the univariate closure
This commit is contained in:
committed by
Quentin Bourgerie
parent
183fe7f6fa
commit
4e663df5af
@@ -3,6 +3,9 @@
|
||||
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/Interfaces/SideEffectInterfaces.h>
|
||||
#include <mlir/Interfaces/ControlFlowInterfaces.h>
|
||||
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
||||
@@ -9,12 +9,21 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHE_IR_HLFHE_OPS
|
||||
#define ZAMALANG_DIALECT_HLFHE_IR_HLFHE_OPS
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
|
||||
include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.td"
|
||||
include "zamalang/Dialect/HLFHE/IR/HLFHETypes.td"
|
||||
|
||||
class HLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<HLFHE_Dialect, mnemonic, traits>;
|
||||
|
||||
|
||||
def ApplyUnivariateRegion : Region<
|
||||
CPred<"::mlir::zamalang::predApplyUnivariateRegion($_self)">,
|
||||
"apply_univariate region needs one block with one any integer argument">;
|
||||
|
||||
|
||||
def AddEintIntOp : HLFHE_Op<"add_eint_int"> {
|
||||
let arguments = (ins EncryptedIntegerType:$a, AnyInteger:$b);
|
||||
let results = (outs EncryptedIntegerType);
|
||||
@@ -72,16 +81,27 @@ def MulEintOp : HLFHE_Op<"mul_eint"> {
|
||||
}
|
||||
|
||||
def ApplyUnivariateOp : HLFHE_Op<"apply_univariate"> {
|
||||
// TODO: express a functionLike?
|
||||
let arguments = (ins EncryptedIntegerType:$a);
|
||||
let results = (outs EncryptedIntegerType);
|
||||
let arguments = (ins EncryptedIntegerType:$x);
|
||||
let results = (outs EncryptedIntegerType:$result);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$a), [{
|
||||
build($_builder, $_state, a.getType(), a);
|
||||
OpBuilder<(ins "Value": $x), [{
|
||||
build($_builder, $_state, x.getType(), x);
|
||||
}]>
|
||||
];
|
||||
|
||||
|
||||
let regions = (region ApplyUnivariateRegion:$body);
|
||||
// let assemblyFormat = "$x `:` type($x) $body attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
|
||||
def ReturnOp : HLFHE_Op<"apply_univariate_return", [NoSideEffect, ReturnLike, Terminator]> {
|
||||
let summary = "terminator of apply_univariate block";
|
||||
let arguments = (ins AnyInteger);
|
||||
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
|
||||
}
|
||||
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,5 +1,28 @@
|
||||
#include "mlir/IR/Region.h"
|
||||
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
|
||||
|
||||
namespace mlir{
|
||||
namespace zamalang{
|
||||
bool predApplyUnivariateRegion(::mlir::Region ®ion){
|
||||
if (region.getBlocks().size() != 1) {
|
||||
return false;
|
||||
}
|
||||
auto args = region.getBlocks().front().getArguments();
|
||||
if (args.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
if (! args.front().getType().isa<mlir::IntegerType>()){
|
||||
return false;
|
||||
}
|
||||
//TODO: need to handle when there is no terminator
|
||||
auto terminator = region.getBlocks().front().getTerminator();
|
||||
return terminator->getName().getStringRef().equals("HLFHE.apply_univariate_return");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.cpp.inc"
|
||||
|
||||
@@ -48,3 +48,23 @@ func @neg_eint(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0> {
|
||||
%1 = "HLFHE.neg_eint"(%arg0): (!HLFHE.eint<0>) -> (!HLFHE.eint<0>)
|
||||
return %1: !HLFHE.eint<0>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @apply_univariate(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0>
|
||||
func @apply_univariate(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "HLFHE.apply_univariate"(%arg0) ( {
|
||||
// CHECK-NEXT: ^bb0(%[[V2:.*]]: i32):
|
||||
// CHECK-NEXT: %[[CST:.*]] = constant 5 : i32
|
||||
// CHECK-NEXT: %[[V3:.*]] = muli %[[V2]], %[[CST]] : i32
|
||||
// CHECK-NEXT: "HLFHE.apply_univariate_return"(%[[V3]]) : (i32) -> ()
|
||||
// CHECK-NEXT: }) : (!HLFHE.eint<0>) -> !HLFHE.eint<0>
|
||||
// CHECK-NEXT: return %[[V1]] : !HLFHE.eint<0>
|
||||
|
||||
%0 = "HLFHE.apply_univariate"(%arg0)({
|
||||
^bb0(%a: i32):
|
||||
%cst = constant 5: i32
|
||||
%res = std.muli %a, %cst : i32
|
||||
"HLFHE.apply_univariate_return"(%res): (i32) -> ()
|
||||
}) : (!HLFHE.eint<0>) -> !HLFHE.eint<0>
|
||||
|
||||
return %0: !HLFHE.eint<0>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user