diff --git a/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.h b/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.h index 1005af379..87f0efe8b 100644 --- a/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.h +++ b/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.h @@ -19,6 +19,20 @@ convertTypeEncryptedIntegerToGLWE(mlir::MLIRContext *context, return GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth()); } +mlir::Value createZeroGLWEOpFromHLFHE(mlir::PatternRewriter rewriter, + mlir::Location loc, + mlir::OpResult result) { + mlir::SmallVector args{}; + mlir::SmallVector attrs; + auto eint = + result.getType().cast(); + mlir::SmallVector resTypes{ + convertTypeEncryptedIntegerToGLWE(rewriter.getContext(), eint)}; + MidLFHE::ZeroGLWEOp op = + rewriter.create(loc, resTypes, args, attrs); + return op.getODSResults(0).front(); +} + template mlir::Value createGLWEOpFromHLFHE(mlir::PatternRewriter rewriter, mlir::Location loc, mlir::Value arg0, diff --git a/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.td b/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.td index cf5594681..ab5ce47bc 100644 --- a/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.td +++ b/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.td @@ -4,6 +4,12 @@ include "zamalang/Dialect/HLFHE/IR/HLFHEOps.td" include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td" +def createZeroGLWEOp : NativeCodeCall<"mlir::zamalang::createZeroGLWEOpFromHLFHE($_builder, $_loc, $0)">; + +def ZeroEintPattern : Pat< + (ZeroEintOp:$result), + (createZeroGLWEOp $result)>; + def createAddGLWEIntOp : NativeCodeCall<"mlir::zamalang::createGLWEOpFromHLFHE($_builder, $_loc, $0, $1, $2)">; def AddEintIntPattern : Pat< diff --git a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h index 48d5b65db..a27c5df49 100644 --- a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h +++ b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h @@ -75,6 +75,19 @@ CleartextType convertCleartextTypeFromType(mlir::MLIRContext *context, assert(false && "expect glwe or lwe"); } +mlir::Value createZeroLWEOpFromMidLFHE(mlir::PatternRewriter rewriter, + mlir::Location loc, + mlir::OpResult result) { + mlir::SmallVector args{}; + mlir::SmallVector attrs; + auto glwe = result.getType().cast(); + mlir::SmallVector resTypes{ + convertTypeToLWE(rewriter.getContext(), glwe)}; + LowLFHE::ZeroLWEOp op = + rewriter.create(loc, resTypes, args, attrs); + return op.getODSResults(0).front(); +} + template mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter rewriter, mlir::Location loc, mlir::Value arg0, diff --git a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.td b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.td index 6157dbce1..fafc648a2 100644 --- a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.td +++ b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.td @@ -5,6 +5,12 @@ include "mlir/Dialect/StandardOps/IR/Ops.td" include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.td" include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td" +def createZeroLWEOp : NativeCodeCall<"mlir::zamalang::createZeroLWEOpFromMidLFHE($_builder, $_loc, $0)">; + +def ZeroGLWEPattern : Pat< + (ZeroGLWEOp:$result), + (createZeroLWEOp $result)>; + def createAddLWEOp : NativeCodeCall<"mlir::zamalang::createLowLFHEOpFromMidLFHE($_builder, $_loc, $0, $1, $2)">; def AddGLWEPattern : Pat< diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td index 60e298bbb..51a67d080 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td +++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td @@ -19,7 +19,7 @@ class HLFHE_Op traits = []> : Op; // Generates an encrypted zero constant -def ZeroOp : HLFHE_Op<"zero"> { +def ZeroEintOp : HLFHE_Op<"zero"> { let arguments = (ins); let results = (outs EncryptedIntegerType:$out); } diff --git a/compiler/include/zamalang/Dialect/LowLFHE/IR/LowLFHEOps.td b/compiler/include/zamalang/Dialect/LowLFHE/IR/LowLFHEOps.td index 380e45924..30685f7fe 100644 --- a/compiler/include/zamalang/Dialect/LowLFHE/IR/LowLFHEOps.td +++ b/compiler/include/zamalang/Dialect/LowLFHE/IR/LowLFHEOps.td @@ -10,6 +10,11 @@ include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.td" class LowLFHE_Op traits = []> : Op; +def ZeroLWEOp : LowLFHE_Op<"zero"> { + let arguments = (ins); + let results = (outs LweCiphertextType:$out); +} + def AddLweCiphertextsOp : LowLFHE_Op<"add_lwe_ciphertexts"> { let arguments = (ins LweCiphertextType:$lhs, LweCiphertextType:$rhs); let results = (outs LweCiphertextType:$result); diff --git a/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td b/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td index d42fa9ec6..faccf0b4a 100644 --- a/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td +++ b/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td @@ -18,6 +18,11 @@ include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.td" class MidLFHE_Op traits = []> : Op; +def ZeroGLWEOp : MidLFHE_Op<"zero"> { + let arguments = (ins); + let results = (outs GLWECipherTextType:$out); +} + def AddGLWEIntOp : MidLFHE_Op<"add_glwe_int"> { let arguments = (ins GLWECipherTextType:$a, AnyInteger:$b); let results = (outs GLWECipherTextType); diff --git a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp index fb333cc03..46013f17e 100644 --- a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -62,7 +62,7 @@ struct DotToLinalgGeneric : public ::mlir::RewritePattern { ::llvm::dyn_cast_or_null<::mlir::zamalang::HLFHE::Dot>(op0); // Zero value to initialize accumulator - mlir::Value zeroCst = rewriter.create( + mlir::Value zeroCst = rewriter.create( dotOp.getLoc(), dotOp.lhs().getType().cast().getElementType()); diff --git a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp index b7887fdf1..9a0c6efe7 100644 --- a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp +++ b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp @@ -65,8 +65,7 @@ mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op, /// ``` /// to /// ``` -/// err = memref.alloc() : memref -/// out = _allocate_(err); +/// err = constant 0 : i64 /// call_op(err, out, arg0, arg1); /// ``` template @@ -136,6 +135,46 @@ private: std::string allocName; }; +struct LowLFHEZeroOpPattern + : public mlir::OpRewritePattern { + LowLFHEZeroOpPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, + benefit) {} + + mlir::LogicalResult + matchAndRewrite(mlir::zamalang::LowLFHE::ZeroLWEOp op, + mlir::PatternRewriter &rewriter) const override { + auto allocName = "allocate_lwe_ciphertext_u64"; + auto errType = mlir::IndexType::get(rewriter.getContext()); + { + auto funcType = mlir::FunctionType::get( + rewriter.getContext(), {errType, rewriter.getIndexType()}, + {op->getResultTypes().front()}); + if (insertForwardDeclaration(op, rewriter, allocName, funcType) + .failed()) { + return mlir::failure(); + } + } + // Replace the operation with a call to the `funcName` + { + mlir::Type resultType = op->getResultTypes().front(); + auto lweResultType = + resultType.cast(); + // Create the err value + auto errOp = rewriter.create(op.getLoc(), + rewriter.getIndexAttr(0)); + // Add the call to the allocation + auto lweSizeOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(lweResultType.getSize())); + mlir::SmallVector allocOperands{errOp, lweSizeOp}; + auto alloc = rewriter.replaceOpWithNewOp( + op, allocName, op.getType(), allocOperands); + } + return mlir::success(); + }; +}; + struct LowLFHEEncodeIntOpPattern : public mlir::OpRewritePattern { LowLFHEEncodeIntOpPattern(mlir::MLIRContext *context, @@ -197,6 +236,7 @@ void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) { "allocate_lwe_ciphertext_u64"); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); } namespace { diff --git a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp index 28ae9d137..cad6e8c9f 100644 --- a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp @@ -162,6 +162,8 @@ void populateWithMidLFHEOpTypeConversionPatterns( mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target, mlir::TypeConverter &typeConverter, mlir::zamalang::V0Parameter &v0Parameter) { + populateWithMidLFHEOpTypeConversionPattern< + mlir::zamalang::MidLFHE::ZeroGLWEOp>(patterns, target, typeConverter); populateWithMidLFHEOpTypeConversionPattern< mlir::zamalang::MidLFHE::AddGLWEIntOp>(patterns, target, typeConverter); populateWithMidLFHEOpTypeConversionPattern< diff --git a/compiler/tests/unittest/end_to_end_jit_test.cc b/compiler/tests/unittest/end_to_end_jit_test.cc index 1334ceab3..f46936daa 100644 --- a/compiler/tests/unittest/end_to_end_jit_test.cc +++ b/compiler/tests/unittest/end_to_end_jit_test.cc @@ -381,4 +381,33 @@ func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: !HLFHE.ei uint64_t res; ASSERT_LLVM_ERROR(argument->getResult(0, res)); ASSERT_EQ(res, 76); +} + +TEST(CompileAndRunTensorEncrypted, dot_eint_int_7) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +func @main(%arg0: tensor<4x!HLFHE.eint<7>>, + %arg1: tensor<4xi8>) -> !HLFHE.eint<7> +{ + %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) : + (tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7> + return %ret : !HLFHE.eint<7> +} +)XXX"; + ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + auto maybeArgument = engine.buildArgument(); + ASSERT_LLVM_ERROR(maybeArgument.takeError()); + auto argument = std::move(maybeArgument.get()); + // Set arg0, arg1, acc + const size_t in_size = 4; + uint8_t arg0[in_size] = {0, 1, 2, 3}; + ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size)); + uint8_t arg1[in_size] = {0, 1, 2, 3}; + ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size)); + // Invoke the function + ASSERT_LLVM_ERROR(engine.invoke(*argument)); + // Get and assert the result + uint64_t res; + ASSERT_LLVM_ERROR(argument->getResult(0, res)); + ASSERT_EQ(res, 14); } \ No newline at end of file