feat(compiler): Introduce HLFHEDialect

This commit is contained in:
Quentin Bourgerie
2021-05-18 16:26:50 +02:00
parent f7c11a0c4e
commit 32340d3ec0
8 changed files with 194 additions and 13 deletions

View File

@@ -4,6 +4,7 @@ project(zamacompiler LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-rtti")
find_package(MLIR REQUIRED CONFIG)
message(STATUS "Using MLIR cmake file from: ${MLIR_DIR}")

View File

@@ -3,6 +3,7 @@
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#define GET_OP_CLASSES
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h.inc"

View File

@@ -10,8 +10,78 @@
#define ZAMALANG_DIALECT_HLFHE_IR_HLFHE_OPS
include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.td"
include "zamalang/Dialect/HLFHE/IR/HLFHETypes.td"
class HLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
Op<HLFHE_Dialect, mnemonic, traits>;
def AddEintIntOp : HLFHE_Op<"add_eint_int"> {
let arguments = (ins EncryptedIntegerType:$a, AnyInteger:$b);
let results = (outs EncryptedIntegerType);
let builders = [
OpBuilder<(ins "Value":$a, "Value":$b), [{
build($_builder, $_state, a.getType(), a, b);
}]>
];
}
def AddEintOp : HLFHE_Op<"add_eint"> {
let arguments = (ins EncryptedIntegerType:$a, EncryptedIntegerType:$b);
let results = (outs EncryptedIntegerType);
let builders = [
OpBuilder<(ins "Value":$a, "Value":$b), [{
build($_builder, $_state, a.getType(), a, b);
}]>
];
}
def NegEintOp : HLFHE_Op<"neg_eint"> {
let arguments = (ins EncryptedIntegerType:$a);
let results = (outs EncryptedIntegerType);
let builders = [
OpBuilder<(ins "Value":$a), [{
build($_builder, $_state, a.getType(), a);
}]>
];
}
def MulEintIntOp : HLFHE_Op<"mul_eint_int"> {
let arguments = (ins EncryptedIntegerType:$a, AnyInteger:$b);
let results = (outs EncryptedIntegerType);
let builders = [
OpBuilder<(ins "Value":$a, "Value":$b), [{
build($_builder, $_state, a.getType(), a, b);
}]>
];
}
def MulEintOp : HLFHE_Op<"mul_eint"> {
let arguments = (ins EncryptedIntegerType:$a, EncryptedIntegerType:$b);
let results = (outs EncryptedIntegerType);
let builders = [
OpBuilder<(ins "Value":$a, "Value":$b), [{
build($_builder, $_state, a.getType(), a, b);
}]>
];
}
def ApplyUnivariateOp : HLFHE_Op<"apply_univariate"> {
// TODO: express a functionLike?
let arguments = (ins EncryptedIntegerType:$a);
let results = (outs EncryptedIntegerType);
let builders = [
OpBuilder<(ins "Value":$a), [{
build($_builder, $_state, a.getType(), a);
}]>
];
}
#endif

View File

@@ -0,0 +1,12 @@
#ifndef ZAMALANG_DIALECT_HLFHE_HLFHE_TYPES_H
#define ZAMALANG_DIALECT_HLFHE_HLFHE_TYPES_H
#include "llvm/ADT/TypeSwitch.h"
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h>
#define GET_TYPEDEF_CLASSES
#include "zamalang/Dialect/HLFHE/IR/HLFHEOpsTypes.h.inc"
#endif

View File

@@ -0,0 +1,37 @@
#ifndef ZAMALANG_DIALECT_HLFHE_IR_HLFHE_TYPES
#define ZAMALANG_DIALECT_HLFHE_IR_HLFHE_TYPES
include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.td"
class HLFHE_Type<string name> : TypeDef<HLFHE_Dialect, name> { }
def EncryptedIntegerType : HLFHE_Type<"EncryptedInteger"> {
let mnemonic = "eint";
let summary = "An encrypted integer";
let description = [{
An encrypted integer with clear precision of width.
}];
let parameters = (ins "unsigned":$width);
// We define the printer inline.
let printer = [{
$_printer << "eint<" << getImpl()->width << ">";
}];
// The parser is defined here also.
let parser = [{
if ($_parser.parseLess())
return Type();
int width;
if ($_parser.parseInteger(width))
return Type();
if ($_parser.parseGreater())
return Type();
return get($_ctxt, width);
}];
}
#endif

View File

@@ -1,5 +1,9 @@
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#define GET_TYPEDEF_CLASSES
#include "zamalang/Dialect/HLFHE/IR/HLFHEOpsTypes.cpp.inc"
using namespace mlir::zamalang::HLFHE;
@@ -8,4 +12,29 @@ void HLFHEDialect::initialize() {
#define GET_OP_LIST
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
#include "zamalang/Dialect/HLFHE/IR/HLFHEOpsTypes.cpp.inc"
>();
}
::mlir::Type HLFHEDialect::parseType(::mlir::DialectAsmParser &parser) const
{
if(parser.parseKeyword("eint").failed())
return ::mlir::Type();
return EncryptedIntegerType::parse(this->getContext(), parser);
}
void HLFHEDialect::printType(::mlir::Type type,
::mlir::DialectAsmPrinter &printer) const
{
mlir::zamalang::HLFHE::EncryptedIntegerType eint = type.dyn_cast_or_null<mlir::zamalang::HLFHE::EncryptedIntegerType>();
if (eint != nullptr) {
eint.print(printer);
return;
}
// TODO - What should be done here?
printer << "unknwontype";
}

View File

@@ -1,4 +1,5 @@
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#define GET_OP_CLASSES
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.cpp.inc"

View File

@@ -1,24 +1,54 @@
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
#include "llvm/Support/SourceMgr.h"
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/Builders.h>
#include "llvm/Support/SourceMgr.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
mlir::FuncOp buildFunction(mlir::OpBuilder &builder) {
mlir::FunctionType func_type = builder.getFunctionType({ mlir::zamalang::HLFHE::EncryptedIntegerType::get(builder.getContext(), 32) }, llvm::None);
mlir::FuncOp funcOp =
mlir::FuncOp::create(builder.getUnknownLoc(), "hlfhe", func_type);
mlir::FuncOp function(funcOp);
mlir::Block &entryBlock = *function.addEntryBlock();
builder.setInsertionPointToStart(&entryBlock);
mlir::Value v1 = builder.create<mlir::ConstantFloatOp>(
builder.getUnknownLoc(),
llvm::APFloat(llvm::APFloat::IEEEsingle(), "1.0"),
builder.getF32Type());
// TODO: create v2 as EncryptedInteger and add it with v1
// mlir::Value v2 =
// builder.create<mlir::zamalang::HLFHE::EncryptedIntegerType>(
// builder.getUnknownLoc());
mlir::Value c1 = builder.create<mlir::zamalang::HLFHE::AddEintIntOp>(
builder.getUnknownLoc(), v1, v1);
builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
return funcOp;
}
int main(int argc, char **argv) {
mlir::MLIRContext context;
mlir::MLIRContext context;
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context.getOrLoadDialect<mlir::StandardOpsDialect>();
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context.getOrLoadDialect<mlir::StandardOpsDialect>();
mlir::OpBuilder builder(&context);
mlir::OpBuilder builder(&context);
mlir::ModuleOp module = mlir::ModuleOp::create(builder.getUnknownLoc());
mlir::ModuleOp module = mlir::ModuleOp::create(builder.getUnknownLoc());
module.dump();
module.push_back(buildFunction(builder));
return 0;
module.dump();
return 0;
}