diff --git a/compiler/CMakeLists.txt b/compiler/CMakeLists.txt index bb462a7f1..db1a936fa 100644 --- a/compiler/CMakeLists.txt +++ b/compiler/CMakeLists.txt @@ -4,6 +4,7 @@ project(zamacompiler LANGUAGES CXX) set(CMAKE_CXX_STANDARD 14) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-rtti") find_package(MLIR REQUIRED CONFIG) message(STATUS "Using MLIR cmake file from: ${MLIR_DIR}") diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.h b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.h index bb4c40fa1..3a60ba09d 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.h +++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.h @@ -3,6 +3,7 @@ #include #include +#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" #define GET_OP_CLASSES #include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h.inc" diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td index 1db0bef3e..095471d51 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td +++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td @@ -10,8 +10,78 @@ #define ZAMALANG_DIALECT_HLFHE_IR_HLFHE_OPS include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.td" +include "zamalang/Dialect/HLFHE/IR/HLFHETypes.td" class HLFHE_Op traits = []> : Op; +def AddEintIntOp : HLFHE_Op<"add_eint_int"> { + let arguments = (ins EncryptedIntegerType:$a, AnyInteger:$b); + let results = (outs EncryptedIntegerType); + + let builders = [ + OpBuilder<(ins "Value":$a, "Value":$b), [{ + build($_builder, $_state, a.getType(), a, b); + }]> + ]; +} + +def AddEintOp : HLFHE_Op<"add_eint"> { + let arguments = (ins EncryptedIntegerType:$a, EncryptedIntegerType:$b); + let results = (outs EncryptedIntegerType); + + let builders = [ + OpBuilder<(ins "Value":$a, "Value":$b), [{ + build($_builder, $_state, a.getType(), a, b); + }]> + ]; +} + +def NegEintOp : HLFHE_Op<"neg_eint"> { + let arguments = (ins EncryptedIntegerType:$a); + let results = (outs EncryptedIntegerType); + + let builders = [ + OpBuilder<(ins "Value":$a), [{ + build($_builder, $_state, a.getType(), a); + }]> + ]; +} + + +def MulEintIntOp : HLFHE_Op<"mul_eint_int"> { + let arguments = (ins EncryptedIntegerType:$a, AnyInteger:$b); + let results = (outs EncryptedIntegerType); + + let builders = [ + OpBuilder<(ins "Value":$a, "Value":$b), [{ + build($_builder, $_state, a.getType(), a, b); + }]> + ]; +} + +def MulEintOp : HLFHE_Op<"mul_eint"> { + let arguments = (ins EncryptedIntegerType:$a, EncryptedIntegerType:$b); + let results = (outs EncryptedIntegerType); + + let builders = [ + OpBuilder<(ins "Value":$a, "Value":$b), [{ + build($_builder, $_state, a.getType(), a, b); + }]> + ]; +} + +def ApplyUnivariateOp : HLFHE_Op<"apply_univariate"> { + // TODO: express a functionLike? + let arguments = (ins EncryptedIntegerType:$a); + let results = (outs EncryptedIntegerType); + + let builders = [ + OpBuilder<(ins "Value":$a), [{ + build($_builder, $_state, a.getType(), a); + }]> + ]; +} + + #endif diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHETypes.h b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHETypes.h new file mode 100644 index 000000000..51b52e1ff --- /dev/null +++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHETypes.h @@ -0,0 +1,12 @@ +#ifndef ZAMALANG_DIALECT_HLFHE_HLFHE_TYPES_H +#define ZAMALANG_DIALECT_HLFHE_HLFHE_TYPES_H + +#include "llvm/ADT/TypeSwitch.h" +#include +#include +#include + +#define GET_TYPEDEF_CLASSES +#include "zamalang/Dialect/HLFHE/IR/HLFHEOpsTypes.h.inc" + +#endif diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHETypes.td b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHETypes.td new file mode 100644 index 000000000..a1dff49d4 --- /dev/null +++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHETypes.td @@ -0,0 +1,37 @@ +#ifndef ZAMALANG_DIALECT_HLFHE_IR_HLFHE_TYPES +#define ZAMALANG_DIALECT_HLFHE_IR_HLFHE_TYPES + +include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.td" + +class HLFHE_Type : TypeDef { } + +def EncryptedIntegerType : HLFHE_Type<"EncryptedInteger"> { + let mnemonic = "eint"; + + let summary = "An encrypted integer"; + + let description = [{ + An encrypted integer with clear precision of width. + }]; + + let parameters = (ins "unsigned":$width); + + // We define the printer inline. + let printer = [{ + $_printer << "eint<" << getImpl()->width << ">"; + }]; + + // The parser is defined here also. + let parser = [{ + if ($_parser.parseLess()) + return Type(); + int width; + if ($_parser.parseInteger(width)) + return Type(); + if ($_parser.parseGreater()) + return Type(); + return get($_ctxt, width); + }]; +} + +#endif diff --git a/compiler/lib/Dialect/HLFHE/IR/HLFHEDialect.cpp b/compiler/lib/Dialect/HLFHE/IR/HLFHEDialect.cpp index a5a8b998e..d02b15743 100644 --- a/compiler/lib/Dialect/HLFHE/IR/HLFHEDialect.cpp +++ b/compiler/lib/Dialect/HLFHE/IR/HLFHEDialect.cpp @@ -1,5 +1,9 @@ #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h" +#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" + +#define GET_TYPEDEF_CLASSES +#include "zamalang/Dialect/HLFHE/IR/HLFHEOpsTypes.cpp.inc" using namespace mlir::zamalang::HLFHE; @@ -8,4 +12,29 @@ void HLFHEDialect::initialize() { #define GET_OP_LIST #include "zamalang/Dialect/HLFHE/IR/HLFHEOps.cpp.inc" >(); + + addTypes< + #define GET_TYPEDEF_LIST + #include "zamalang/Dialect/HLFHE/IR/HLFHEOpsTypes.cpp.inc" + >(); } + +::mlir::Type HLFHEDialect::parseType(::mlir::DialectAsmParser &parser) const +{ + if(parser.parseKeyword("eint").failed()) + return ::mlir::Type(); + + return EncryptedIntegerType::parse(this->getContext(), parser); +} + +void HLFHEDialect::printType(::mlir::Type type, + ::mlir::DialectAsmPrinter &printer) const +{ + mlir::zamalang::HLFHE::EncryptedIntegerType eint = type.dyn_cast_or_null(); + if (eint != nullptr) { + eint.print(printer); + return; + } + // TODO - What should be done here? + printer << "unknwontype"; +} \ No newline at end of file diff --git a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp index 93863a07d..ea8115355 100644 --- a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp +++ b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp @@ -1,4 +1,5 @@ #include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h" +#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" #define GET_OP_CLASSES #include "zamalang/Dialect/HLFHE/IR/HLFHEOps.cpp.inc" diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 46a94162c..66121817b 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -1,24 +1,54 @@ -#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" -#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h" - -#include "llvm/Support/SourceMgr.h" - #include #include +#include "llvm/Support/SourceMgr.h" +#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" +#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h" +#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" + +mlir::FuncOp buildFunction(mlir::OpBuilder &builder) { + mlir::FunctionType func_type = builder.getFunctionType({ mlir::zamalang::HLFHE::EncryptedIntegerType::get(builder.getContext(), 32) }, llvm::None); + + mlir::FuncOp funcOp = + mlir::FuncOp::create(builder.getUnknownLoc(), "hlfhe", func_type); + + mlir::FuncOp function(funcOp); + mlir::Block &entryBlock = *function.addEntryBlock(); + builder.setInsertionPointToStart(&entryBlock); + + mlir::Value v1 = builder.create( + builder.getUnknownLoc(), + llvm::APFloat(llvm::APFloat::IEEEsingle(), "1.0"), + builder.getF32Type()); + + // TODO: create v2 as EncryptedInteger and add it with v1 + + // mlir::Value v2 = + // builder.create( + // builder.getUnknownLoc()); + + mlir::Value c1 = builder.create( + builder.getUnknownLoc(), v1, v1); + + builder.create(builder.getUnknownLoc()); + + return funcOp; +} int main(int argc, char **argv) { - mlir::MLIRContext context; + mlir::MLIRContext context; - // Load our Dialect in this MLIR Context. - context.getOrLoadDialect(); - context.getOrLoadDialect(); + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect(); + context.getOrLoadDialect(); - mlir::OpBuilder builder(&context); + mlir::OpBuilder builder(&context); - mlir::ModuleOp module = mlir::ModuleOp::create(builder.getUnknownLoc()); + mlir::ModuleOp module = mlir::ModuleOp::create(builder.getUnknownLoc()); - module.dump(); + module.push_back(buildFunction(builder)); - return 0; + module.dump(); + + return 0; }