enhance(compiler): Full lowering of hlfhe.zero, make dot_eint_int fully lowerable

This commit is contained in:
Quentin Bourgerie
2021-08-25 17:52:22 +02:00
parent de7129fe8e
commit 1077c9167c
11 changed files with 124 additions and 4 deletions

View File

@@ -19,6 +19,20 @@ convertTypeEncryptedIntegerToGLWE(mlir::MLIRContext *context,
return GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth());
}
mlir::Value createZeroGLWEOpFromHLFHE(mlir::PatternRewriter rewriter,
mlir::Location loc,
mlir::OpResult result) {
mlir::SmallVector<mlir::Value> args{};
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
auto eint =
result.getType().cast<mlir::zamalang::HLFHE::EncryptedIntegerType>();
mlir::SmallVector<mlir::Type, 1> resTypes{
convertTypeEncryptedIntegerToGLWE(rewriter.getContext(), eint)};
MidLFHE::ZeroGLWEOp op =
rewriter.create<MidLFHE::ZeroGLWEOp>(loc, resTypes, args, attrs);
return op.getODSResults(0).front();
}
template <class Operator>
mlir::Value createGLWEOpFromHLFHE(mlir::PatternRewriter rewriter,
mlir::Location loc, mlir::Value arg0,

View File

@@ -4,6 +4,12 @@
include "zamalang/Dialect/HLFHE/IR/HLFHEOps.td"
include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td"
def createZeroGLWEOp : NativeCodeCall<"mlir::zamalang::createZeroGLWEOpFromHLFHE($_builder, $_loc, $0)">;
def ZeroEintPattern : Pat<
(ZeroEintOp:$result),
(createZeroGLWEOp $result)>;
def createAddGLWEIntOp : NativeCodeCall<"mlir::zamalang::createGLWEOpFromHLFHE<mlir::zamalang::MidLFHE::AddGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
def AddEintIntPattern : Pat<

View File

@@ -75,6 +75,19 @@ CleartextType convertCleartextTypeFromType(mlir::MLIRContext *context,
assert(false && "expect glwe or lwe");
}
mlir::Value createZeroLWEOpFromMidLFHE(mlir::PatternRewriter rewriter,
mlir::Location loc,
mlir::OpResult result) {
mlir::SmallVector<mlir::Value> args{};
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
auto glwe = result.getType().cast<GLWECipherTextType>();
mlir::SmallVector<mlir::Type, 1> resTypes{
convertTypeToLWE(rewriter.getContext(), glwe)};
LowLFHE::ZeroLWEOp op =
rewriter.create<LowLFHE::ZeroLWEOp>(loc, resTypes, args, attrs);
return op.getODSResults(0).front();
}
template <class Operator>
mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter rewriter,
mlir::Location loc, mlir::Value arg0,

View File

@@ -5,6 +5,12 @@ include "mlir/Dialect/StandardOps/IR/Ops.td"
include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.td"
include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td"
def createZeroLWEOp : NativeCodeCall<"mlir::zamalang::createZeroLWEOpFromMidLFHE($_builder, $_loc, $0)">;
def ZeroGLWEPattern : Pat<
(ZeroGLWEOp:$result),
(createZeroLWEOp $result)>;
def createAddLWEOp : NativeCodeCall<"mlir::zamalang::createLowLFHEOpFromMidLFHE<mlir::zamalang::LowLFHE::AddLweCiphertextsOp>($_builder, $_loc, $0, $1, $2)">;
def AddGLWEPattern : Pat<

View File

@@ -19,7 +19,7 @@ class HLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
Op<HLFHE_Dialect, mnemonic, traits>;
// Generates an encrypted zero constant
def ZeroOp : HLFHE_Op<"zero"> {
def ZeroEintOp : HLFHE_Op<"zero"> {
let arguments = (ins);
let results = (outs EncryptedIntegerType:$out);
}

View File

@@ -10,6 +10,11 @@ include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.td"
class LowLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
Op<LowLFHE_Dialect, mnemonic, traits>;
def ZeroLWEOp : LowLFHE_Op<"zero"> {
let arguments = (ins);
let results = (outs LweCiphertextType:$out);
}
def AddLweCiphertextsOp : LowLFHE_Op<"add_lwe_ciphertexts"> {
let arguments = (ins LweCiphertextType:$lhs, LweCiphertextType:$rhs);
let results = (outs LweCiphertextType:$result);

View File

@@ -18,6 +18,11 @@ include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.td"
class MidLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
Op<MidLFHE_Dialect, mnemonic, traits>;
def ZeroGLWEOp : MidLFHE_Op<"zero"> {
let arguments = (ins);
let results = (outs GLWECipherTextType:$out);
}
def AddGLWEIntOp : MidLFHE_Op<"add_glwe_int"> {
let arguments = (ins GLWECipherTextType:$a, AnyInteger:$b);
let results = (outs GLWECipherTextType);

View File

@@ -62,7 +62,7 @@ struct DotToLinalgGeneric : public ::mlir::RewritePattern {
::llvm::dyn_cast_or_null<::mlir::zamalang::HLFHE::Dot>(op0);
// Zero value to initialize accumulator
mlir::Value zeroCst = rewriter.create<mlir::zamalang::HLFHE::ZeroOp>(
mlir::Value zeroCst = rewriter.create<mlir::zamalang::HLFHE::ZeroEintOp>(
dotOp.getLoc(),
dotOp.lhs().getType().cast<mlir::ShapedType>().getElementType());

View File

@@ -65,8 +65,7 @@ mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
/// ```
/// to
/// ```
/// err = memref.alloc() : memref<index>
/// out = _allocate_(err);
/// err = constant 0 : i64
/// call_op(err, out, arg0, arg1);
/// ```
template <typename Op>
@@ -136,6 +135,46 @@ private:
std::string allocName;
};
struct LowLFHEZeroOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::ZeroLWEOp> {
LowLFHEZeroOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::ZeroLWEOp>(context,
benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::zamalang::LowLFHE::ZeroLWEOp op,
mlir::PatternRewriter &rewriter) const override {
auto allocName = "allocate_lwe_ciphertext_u64";
auto errType = mlir::IndexType::get(rewriter.getContext());
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {errType, rewriter.getIndexType()},
{op->getResultTypes().front()});
if (insertForwardDeclaration(op, rewriter, allocName, funcType)
.failed()) {
return mlir::failure();
}
}
// Replace the operation with a call to the `funcName`
{
mlir::Type resultType = op->getResultTypes().front();
auto lweResultType =
resultType.cast<mlir::zamalang::LowLFHE::LweCiphertextType>();
// Create the err value
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
// Add the call to the allocation
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(lweResultType.getSize()));
mlir::SmallVector<mlir::Value> allocOperands{errOp, lweSizeOp};
auto alloc = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, allocName, op.getType(), allocOperands);
}
return mlir::success();
};
};
struct LowLFHEEncodeIntOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::EncodeIntOp> {
LowLFHEEncodeIntOpPattern(mlir::MLIRContext *context,
@@ -197,6 +236,7 @@ void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
"allocate_lwe_ciphertext_u64");
patterns.add<LowLFHEEncodeIntOpPattern>(patterns.getContext());
patterns.add<LowLFHEIntToCleartextOpPattern>(patterns.getContext());
patterns.add<LowLFHEZeroOpPattern>(patterns.getContext());
}
namespace {

View File

@@ -162,6 +162,8 @@ void populateWithMidLFHEOpTypeConversionPatterns(
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
mlir::TypeConverter &typeConverter,
mlir::zamalang::V0Parameter &v0Parameter) {
populateWithMidLFHEOpTypeConversionPattern<
mlir::zamalang::MidLFHE::ZeroGLWEOp>(patterns, target, typeConverter);
populateWithMidLFHEOpTypeConversionPattern<
mlir::zamalang::MidLFHE::AddGLWEIntOp>(patterns, target, typeConverter);
populateWithMidLFHEOpTypeConversionPattern<

View File

@@ -381,4 +381,33 @@ func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: !HLFHE.ei
uint64_t res;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, 76);
}
TEST(CompileAndRunTensorEncrypted, dot_eint_int_7) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%arg0: tensor<4x!HLFHE.eint<7>>,
%arg1: tensor<4xi8>) -> !HLFHE.eint<7>
{
%ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
(tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
return %ret : !HLFHE.eint<7>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set arg0, arg1, acc
const size_t in_size = 4;
uint8_t arg0[in_size] = {0, 1, 2, 3};
ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size));
uint8_t arg1[in_size] = {0, 1, 2, 3};
ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, 14);
}