mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
enhance(compiler): Full lowering of hlfhe.zero, make dot_eint_int fully lowerable
This commit is contained in:
@@ -19,6 +19,20 @@ convertTypeEncryptedIntegerToGLWE(mlir::MLIRContext *context,
|
||||
return GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth());
|
||||
}
|
||||
|
||||
mlir::Value createZeroGLWEOpFromHLFHE(mlir::PatternRewriter rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::OpResult result) {
|
||||
mlir::SmallVector<mlir::Value> args{};
|
||||
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
|
||||
auto eint =
|
||||
result.getType().cast<mlir::zamalang::HLFHE::EncryptedIntegerType>();
|
||||
mlir::SmallVector<mlir::Type, 1> resTypes{
|
||||
convertTypeEncryptedIntegerToGLWE(rewriter.getContext(), eint)};
|
||||
MidLFHE::ZeroGLWEOp op =
|
||||
rewriter.create<MidLFHE::ZeroGLWEOp>(loc, resTypes, args, attrs);
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
template <class Operator>
|
||||
mlir::Value createGLWEOpFromHLFHE(mlir::PatternRewriter rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
|
||||
@@ -4,6 +4,12 @@
|
||||
include "zamalang/Dialect/HLFHE/IR/HLFHEOps.td"
|
||||
include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td"
|
||||
|
||||
def createZeroGLWEOp : NativeCodeCall<"mlir::zamalang::createZeroGLWEOpFromHLFHE($_builder, $_loc, $0)">;
|
||||
|
||||
def ZeroEintPattern : Pat<
|
||||
(ZeroEintOp:$result),
|
||||
(createZeroGLWEOp $result)>;
|
||||
|
||||
def createAddGLWEIntOp : NativeCodeCall<"mlir::zamalang::createGLWEOpFromHLFHE<mlir::zamalang::MidLFHE::AddGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def AddEintIntPattern : Pat<
|
||||
|
||||
@@ -75,6 +75,19 @@ CleartextType convertCleartextTypeFromType(mlir::MLIRContext *context,
|
||||
assert(false && "expect glwe or lwe");
|
||||
}
|
||||
|
||||
mlir::Value createZeroLWEOpFromMidLFHE(mlir::PatternRewriter rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::OpResult result) {
|
||||
mlir::SmallVector<mlir::Value> args{};
|
||||
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
|
||||
auto glwe = result.getType().cast<GLWECipherTextType>();
|
||||
mlir::SmallVector<mlir::Type, 1> resTypes{
|
||||
convertTypeToLWE(rewriter.getContext(), glwe)};
|
||||
LowLFHE::ZeroLWEOp op =
|
||||
rewriter.create<LowLFHE::ZeroLWEOp>(loc, resTypes, args, attrs);
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
template <class Operator>
|
||||
mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
|
||||
@@ -5,6 +5,12 @@ include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.td"
|
||||
include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td"
|
||||
|
||||
def createZeroLWEOp : NativeCodeCall<"mlir::zamalang::createZeroLWEOpFromMidLFHE($_builder, $_loc, $0)">;
|
||||
|
||||
def ZeroGLWEPattern : Pat<
|
||||
(ZeroGLWEOp:$result),
|
||||
(createZeroLWEOp $result)>;
|
||||
|
||||
def createAddLWEOp : NativeCodeCall<"mlir::zamalang::createLowLFHEOpFromMidLFHE<mlir::zamalang::LowLFHE::AddLweCiphertextsOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def AddGLWEPattern : Pat<
|
||||
|
||||
@@ -19,7 +19,7 @@ class HLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<HLFHE_Dialect, mnemonic, traits>;
|
||||
|
||||
// Generates an encrypted zero constant
|
||||
def ZeroOp : HLFHE_Op<"zero"> {
|
||||
def ZeroEintOp : HLFHE_Op<"zero"> {
|
||||
let arguments = (ins);
|
||||
let results = (outs EncryptedIntegerType:$out);
|
||||
}
|
||||
|
||||
@@ -10,6 +10,11 @@ include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.td"
|
||||
class LowLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<LowLFHE_Dialect, mnemonic, traits>;
|
||||
|
||||
def ZeroLWEOp : LowLFHE_Op<"zero"> {
|
||||
let arguments = (ins);
|
||||
let results = (outs LweCiphertextType:$out);
|
||||
}
|
||||
|
||||
def AddLweCiphertextsOp : LowLFHE_Op<"add_lwe_ciphertexts"> {
|
||||
let arguments = (ins LweCiphertextType:$lhs, LweCiphertextType:$rhs);
|
||||
let results = (outs LweCiphertextType:$result);
|
||||
|
||||
@@ -18,6 +18,11 @@ include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.td"
|
||||
class MidLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<MidLFHE_Dialect, mnemonic, traits>;
|
||||
|
||||
def ZeroGLWEOp : MidLFHE_Op<"zero"> {
|
||||
let arguments = (ins);
|
||||
let results = (outs GLWECipherTextType:$out);
|
||||
}
|
||||
|
||||
def AddGLWEIntOp : MidLFHE_Op<"add_glwe_int"> {
|
||||
let arguments = (ins GLWECipherTextType:$a, AnyInteger:$b);
|
||||
let results = (outs GLWECipherTextType);
|
||||
|
||||
@@ -62,7 +62,7 @@ struct DotToLinalgGeneric : public ::mlir::RewritePattern {
|
||||
::llvm::dyn_cast_or_null<::mlir::zamalang::HLFHE::Dot>(op0);
|
||||
|
||||
// Zero value to initialize accumulator
|
||||
mlir::Value zeroCst = rewriter.create<mlir::zamalang::HLFHE::ZeroOp>(
|
||||
mlir::Value zeroCst = rewriter.create<mlir::zamalang::HLFHE::ZeroEintOp>(
|
||||
dotOp.getLoc(),
|
||||
dotOp.lhs().getType().cast<mlir::ShapedType>().getElementType());
|
||||
|
||||
|
||||
@@ -65,8 +65,7 @@ mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
|
||||
/// ```
|
||||
/// to
|
||||
/// ```
|
||||
/// err = memref.alloc() : memref<index>
|
||||
/// out = _allocate_(err);
|
||||
/// err = constant 0 : i64
|
||||
/// call_op(err, out, arg0, arg1);
|
||||
/// ```
|
||||
template <typename Op>
|
||||
@@ -136,6 +135,46 @@ private:
|
||||
std::string allocName;
|
||||
};
|
||||
|
||||
struct LowLFHEZeroOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::ZeroLWEOp> {
|
||||
LowLFHEZeroOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::ZeroLWEOp>(context,
|
||||
benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::zamalang::LowLFHE::ZeroLWEOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto allocName = "allocate_lwe_ciphertext_u64";
|
||||
auto errType = mlir::IndexType::get(rewriter.getContext());
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {errType, rewriter.getIndexType()},
|
||||
{op->getResultTypes().front()});
|
||||
if (insertForwardDeclaration(op, rewriter, allocName, funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Replace the operation with a call to the `funcName`
|
||||
{
|
||||
mlir::Type resultType = op->getResultTypes().front();
|
||||
auto lweResultType =
|
||||
resultType.cast<mlir::zamalang::LowLFHE::LweCiphertextType>();
|
||||
// Create the err value
|
||||
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
|
||||
rewriter.getIndexAttr(0));
|
||||
// Add the call to the allocation
|
||||
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(lweResultType.getSize()));
|
||||
mlir::SmallVector<mlir::Value> allocOperands{errOp, lweSizeOp};
|
||||
auto alloc = rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, allocName, op.getType(), allocOperands);
|
||||
}
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct LowLFHEEncodeIntOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::EncodeIntOp> {
|
||||
LowLFHEEncodeIntOpPattern(mlir::MLIRContext *context,
|
||||
@@ -197,6 +236,7 @@ void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
|
||||
"allocate_lwe_ciphertext_u64");
|
||||
patterns.add<LowLFHEEncodeIntOpPattern>(patterns.getContext());
|
||||
patterns.add<LowLFHEIntToCleartextOpPattern>(patterns.getContext());
|
||||
patterns.add<LowLFHEZeroOpPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -162,6 +162,8 @@ void populateWithMidLFHEOpTypeConversionPatterns(
|
||||
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
|
||||
mlir::TypeConverter &typeConverter,
|
||||
mlir::zamalang::V0Parameter &v0Parameter) {
|
||||
populateWithMidLFHEOpTypeConversionPattern<
|
||||
mlir::zamalang::MidLFHE::ZeroGLWEOp>(patterns, target, typeConverter);
|
||||
populateWithMidLFHEOpTypeConversionPattern<
|
||||
mlir::zamalang::MidLFHE::AddGLWEIntOp>(patterns, target, typeConverter);
|
||||
populateWithMidLFHEOpTypeConversionPattern<
|
||||
|
||||
@@ -381,4 +381,33 @@ func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: !HLFHE.ei
|
||||
uint64_t res;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, 76);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, dot_eint_int_7) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%arg0: tensor<4x!HLFHE.eint<7>>,
|
||||
%arg1: tensor<4xi8>) -> !HLFHE.eint<7>
|
||||
{
|
||||
%ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
|
||||
(tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
|
||||
return %ret : !HLFHE.eint<7>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set arg0, arg1, acc
|
||||
const size_t in_size = 4;
|
||||
uint8_t arg0[in_size] = {0, 1, 2, 3};
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size));
|
||||
uint8_t arg1[in_size] = {0, 1, 2, 3};
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, 14);
|
||||
}
|
||||
Reference in New Issue
Block a user