mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(compiler): MidToLowLFHE lowering
This commit is contained in:
@@ -3,4 +3,5 @@ mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion)
|
||||
add_public_tablegen_target(MLIRConversionPassIncGen)
|
||||
|
||||
|
||||
add_subdirectory(HLFHEToMidLFHE)
|
||||
add_subdirectory(HLFHEToMidLFHE)
|
||||
add_subdirectory(MidLFHEToLowLFHE)
|
||||
@@ -0,0 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Patterns.td)
|
||||
mlir_tablegen(Patterns.h.inc -gen-rewriters -name MidLFHE)
|
||||
add_public_tablegen_target(MidLFHEToLowLFHEPatternsIncGen)
|
||||
14
compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Pass.h
Normal file
14
compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Pass.h
Normal file
@@ -0,0 +1,14 @@
|
||||
|
||||
#ifndef ZAMALANG_CONVERSION_MIDLFHETOLOWLFHE_PASS_H_
|
||||
#define ZAMALANG_CONVERSION_MIDLFHETOLOWLFHE_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
/// Create a pass to convert `MidLFHE` dialect to `LowLFHE` dialect.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertMidLFHEToLowLFHEPass();
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
116
compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h
Normal file
116
compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h
Normal file
@@ -0,0 +1,116 @@
|
||||
#ifndef ZAMALANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS_H_
|
||||
#define ZAMALANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS_H_
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
using LowLFHE::CleartextType;
|
||||
using LowLFHE::LweCiphertextType;
|
||||
using LowLFHE::PlaintextType;
|
||||
using MidLFHE::GLWECipherTextType;
|
||||
|
||||
LweCiphertextType convertTypeGLWEToLWE(mlir::MLIRContext *context,
|
||||
GLWECipherTextType &glwe) {
|
||||
return LweCiphertextType::get(context);
|
||||
}
|
||||
|
||||
PlaintextType convertIntToPlaintextType(mlir::MLIRContext *context,
|
||||
IntegerType &type) {
|
||||
return PlaintextType::get(context, type.getWidth());
|
||||
}
|
||||
|
||||
CleartextType convertIntToCleartextType(mlir::MLIRContext *context,
|
||||
IntegerType &type) {
|
||||
return CleartextType::get(context, type.getWidth());
|
||||
}
|
||||
|
||||
template <class Operator>
|
||||
mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1,
|
||||
mlir::OpResult result) {
|
||||
mlir::SmallVector<mlir::Value, 2> args{arg0, arg1};
|
||||
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
|
||||
auto glwe = result.getType().cast<GLWECipherTextType>();
|
||||
mlir::SmallVector<mlir::Type, 1> resTypes{
|
||||
convertTypeGLWEToLWE(rewriter.getContext(), glwe)};
|
||||
Operator op = rewriter.create<Operator>(loc, resTypes, args, attrs);
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
mlir::Value createAddPlainLweCiphertext(mlir::PatternRewriter rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1,
|
||||
mlir::OpResult result) {
|
||||
auto integer_type = arg1.getType().cast<IntegerType>();
|
||||
PlaintextType encoded_type =
|
||||
convertIntToPlaintextType(rewriter.getContext(), integer_type);
|
||||
// encode int into plaintext
|
||||
mlir::Value encoded =
|
||||
rewriter
|
||||
.create<mlir::zamalang::LowLFHE::EncodeIntOp>(loc, encoded_type, arg1)
|
||||
.plaintext();
|
||||
// convert result type
|
||||
GLWECipherTextType glwe_type = result.getType().cast<GLWECipherTextType>();
|
||||
LweCiphertextType lwe_type =
|
||||
convertTypeGLWEToLWE(rewriter.getContext(), glwe_type);
|
||||
// replace op using the encoded plaintext instead of int
|
||||
auto op =
|
||||
rewriter.create<mlir::zamalang::LowLFHE::AddPlaintextLweCiphertextOp>(
|
||||
loc, lwe_type, arg0, encoded);
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1, mlir::OpResult result) {
|
||||
auto arg1_type = arg1.getType().cast<GLWECipherTextType>();
|
||||
auto negated_arg1 =
|
||||
rewriter
|
||||
.create<mlir::zamalang::LowLFHE::NegateLweCiphertextOp>(
|
||||
loc, convertTypeGLWEToLWE(rewriter.getContext(), arg1_type), arg1)
|
||||
.result();
|
||||
return createAddPlainLweCiphertext(rewriter, loc, negated_arg1, arg0, result);
|
||||
}
|
||||
|
||||
mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1,
|
||||
mlir::OpResult result) {
|
||||
auto integer_type = arg1.getType().cast<IntegerType>();
|
||||
CleartextType encoded_type =
|
||||
convertIntToCleartextType(rewriter.getContext(), integer_type);
|
||||
// encode int into plaintext
|
||||
mlir::Value encoded = rewriter
|
||||
.create<mlir::zamalang::LowLFHE::IntToCleartextOp>(
|
||||
loc, encoded_type, arg1)
|
||||
.cleartext();
|
||||
// convert result type
|
||||
GLWECipherTextType glwe_type = result.getType().cast<GLWECipherTextType>();
|
||||
LweCiphertextType lwe_type =
|
||||
convertTypeGLWEToLWE(rewriter.getContext(), glwe_type);
|
||||
// replace op using the encoded plaintext instead of int
|
||||
auto op =
|
||||
rewriter.create<mlir::zamalang::LowLFHE::MulCleartextLweCiphertextOp>(
|
||||
loc, lwe_type, arg0, encoded);
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
namespace {
|
||||
#include "zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h.inc"
|
||||
}
|
||||
|
||||
void populateWithGeneratedMidLFHEToLowLFHE(mlir::RewritePatternSet &patterns) {
|
||||
populateWithGenerated(patterns);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,32 @@
|
||||
#ifndef ZAMALANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS
|
||||
#define ZAMALANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS
|
||||
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.td"
|
||||
include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td"
|
||||
|
||||
def createAddLWEOp : NativeCodeCall<"mlir::zamalang::createLowLFHEOpFromMidLFHE<mlir::zamalang::LowLFHE::AddLweCiphertextsOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def AddGLWEPattern : Pat<
|
||||
(AddGLWEOp:$result $arg0, $arg1),
|
||||
(createAddLWEOp $arg0, $arg1, $result)>;
|
||||
|
||||
def createAddPlainLweOp : NativeCodeCall<"mlir::zamalang::createAddPlainLweCiphertext($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def AddGLWEIntPattern : Pat<
|
||||
(AddGLWEIntOp:$result $arg0, $arg1),
|
||||
(createAddPlainLweOp $arg0, $arg1, $result)>;
|
||||
|
||||
def createMulClearLweOp : NativeCodeCall<"mlir::zamalang::createMulClearLweCiphertext($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def MulGLWEIntPattern : Pat<
|
||||
(MulGLWEIntOp:$result $arg0, $arg1),
|
||||
(createMulClearLweOp $arg0, $arg1, $result)>;
|
||||
|
||||
def createSubIntLweOp : NativeCodeCall<"mlir::zamalang::createSubIntLweCiphertext($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def SubIntGLWEPattern : Pat<
|
||||
(SubIntGLWEOp:$result $arg0, $arg1),
|
||||
(createSubIntLweOp $arg0, $arg1, $result)>;
|
||||
|
||||
#endif
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "zamalang/Conversion/HLFHETensorOpsToLinalg/Pass.h"
|
||||
#include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h"
|
||||
#include "zamalang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
|
||||
#include "zamalang/Conversion/MidLFHEToLowLFHE/Pass.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
|
||||
@@ -17,6 +17,14 @@ def HLFHEToMidLFHE : Pass<"hlfhe-to-midlfhe", "mlir::ModuleOp"> {
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect"];
|
||||
}
|
||||
|
||||
def MidLFHEToLowLFHE : Pass<"midlfhe-to-lowlfhe", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the MidLFHE dialect to LowLFHE";
|
||||
let description = [{ Lowers operations from the MidLFHE dialect to LowLFHE }];
|
||||
let constructor = "mlir::zamalang::createConvertMidLFHEToLowLFHEPass()";
|
||||
let options = [];
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect"];
|
||||
}
|
||||
|
||||
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
|
||||
let constructor = "mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass()";
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
add_subdirectory(HLFHEToMidLFHE)
|
||||
add_subdirectory(MidLFHEToLowLFHE)
|
||||
add_subdirectory(HLFHETensorOpsToLinalg)
|
||||
add_subdirectory(MLIRLowerableDialectsToLLVM)
|
||||
|
||||
17
compiler/lib/Conversion/MidLFHEToLowLFHE/CMakeLists.txt
Normal file
17
compiler/lib/Conversion/MidLFHEToLowLFHE/CMakeLists.txt
Normal file
@@ -0,0 +1,17 @@
|
||||
add_mlir_dialect_library(MidLFHEToLowLFHE
|
||||
MidLFHEToLowLFHE.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/MidLFHE
|
||||
|
||||
DEPENDS
|
||||
MidLFHEDialect
|
||||
LowLFHEDialect
|
||||
MidLFHEToLowLFHEPatternsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
MLIRMath)
|
||||
|
||||
target_link_libraries(MidLFHEToLowLFHE PUBLIC MLIRIR)
|
||||
@@ -0,0 +1,99 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h"
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
|
||||
namespace {
|
||||
struct MidLFHEToLowLFHEPass
|
||||
: public MidLFHEToLowLFHEBase<MidLFHEToLowLFHEPass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
using mlir::zamalang::LowLFHE::LweCiphertextType;
|
||||
using mlir::zamalang::MidLFHE::GLWECipherTextType;
|
||||
|
||||
/// MidLFHEToLowLFHETypeConverter is a TypeConverter that transform
|
||||
/// `MidLFHE.gwle<{_,_,_}{p}>` to LowLFHE.lwe_ciphertext
|
||||
class MidLFHEToLowLFHETypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
MidLFHEToLowLFHETypeConverter() {
|
||||
addConversion([&](GLWECipherTextType type) {
|
||||
return mlir::zamalang::convertTypeGLWEToLWE(type.getContext(), type);
|
||||
});
|
||||
addConversion([&](mlir::MemRefType type) {
|
||||
auto glwe = type.getElementType().dyn_cast_or_null<GLWECipherTextType>();
|
||||
if (glwe == nullptr) {
|
||||
return (mlir::Type)(type);
|
||||
}
|
||||
mlir::Type r = mlir::MemRefType::get(
|
||||
type.getShape(),
|
||||
mlir::zamalang::convertTypeGLWEToLWE(glwe.getContext(), glwe),
|
||||
type.getAffineMaps(), type.getMemorySpace());
|
||||
return r;
|
||||
});
|
||||
// [workaround] need these converters to consider those types legal
|
||||
addConversion([&](mlir::IntegerType type) { return type; });
|
||||
addConversion(
|
||||
[&](mlir::zamalang::LowLFHE::LweCiphertextType type) { return type; });
|
||||
}
|
||||
};
|
||||
|
||||
void MidLFHEToLowLFHEPass::runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
MidLFHEToLowLFHETypeConverter converter;
|
||||
|
||||
// Mark ops from the target dialect as legal operations
|
||||
target.addLegalDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
|
||||
|
||||
// Make sure that no ops from `MidLFHE` remain after the lowering
|
||||
target.addIllegalDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
|
||||
|
||||
// Make sure that no ops `linalg.generic` that have illegal types
|
||||
target.addDynamicallyLegalOp<mlir::linalg::GenericOp>(
|
||||
[&](mlir::linalg::GenericOp op) {
|
||||
return (converter.isLegal(op.getOperandTypes()) &&
|
||||
converter.isLegal(op.getResultTypes()) &&
|
||||
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
|
||||
});
|
||||
|
||||
// Make sure that func has legal signature
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([](mlir::FuncOp funcOp) {
|
||||
MidLFHEToLowLFHETypeConverter converter;
|
||||
return converter.isSignatureLegal(funcOp.getType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
// Add all patterns required to lower all ops from `MidLFHE` to
|
||||
// `LowLFHE`
|
||||
mlir::OwningRewritePatternList patterns(&getContext());
|
||||
|
||||
populateWithGeneratedMidLFHEToLowLFHE(patterns);
|
||||
patterns
|
||||
.add<LinalgGenericTypeConverterPattern<MidLFHEToLowLFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertMidLFHEToLowLFHEPass() {
|
||||
return std::make_unique<MidLFHEToLowLFHEPass>();
|
||||
}
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
@@ -37,6 +37,8 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirLLVMDialect(
|
||||
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(),
|
||||
enablePass);
|
||||
|
||||
10
compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir
Normal file
10
compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir
Normal file
@@ -0,0 +1,10 @@
|
||||
// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext, %arg1: !LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext
|
||||
func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.add_lwe_ciphertexts"(%arg0, %arg1) : (!LowLFHE.lwe_ciphertext, !LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext
|
||||
// CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext
|
||||
|
||||
%0 = "MidLFHE.add_glwe"(%arg0, %arg1): (!MidLFHE.glwe<{1024,12,64}{7}>, !MidLFHE.glwe<{1024,12,64}{7}>) -> (!MidLFHE.glwe<{1024,12,64}{7}>)
|
||||
return %0: !MidLFHE.glwe<{1024,12,64}{7}>
|
||||
}
|
||||
22
compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir
Normal file
22
compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir
Normal file
@@ -0,0 +1,22 @@
|
||||
// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext
|
||||
func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = constant 1 : i8
|
||||
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.encode_int"(%[[V1]]) : (i8) -> !LowLFHE.plaintext<8>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.add_plaintext_lwe_ciphertext"(%arg0, %[[V2]]) : (!LowLFHE.lwe_ciphertext, !LowLFHE.plaintext<8>) -> !LowLFHE.lwe_ciphertext
|
||||
// CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext
|
||||
%0 = constant 1 : i8
|
||||
%1 = "MidLFHE.add_glwe_int"(%arg0, %0): (!MidLFHE.glwe<{1024,1,64}{7}>, i8) -> (!MidLFHE.glwe<{1024,1,64}{7}>)
|
||||
return %1: !MidLFHE.glwe<{1024,1,64}{7}>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @add_glwe_int(%arg0: !LowLFHE.lwe_ciphertext, %arg1: i5) -> !LowLFHE.lwe_ciphertext
|
||||
func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !MidLFHE.glwe<{1024,1,64}{4}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.encode_int"(%arg1) : (i5) -> !LowLFHE.plaintext<5>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.add_plaintext_lwe_ciphertext"(%arg0, %[[V1]]) : (!LowLFHE.lwe_ciphertext, !LowLFHE.plaintext<5>) -> !LowLFHE.lwe_ciphertext
|
||||
// CHECK-NEXT: return %[[V2]] : !LowLFHE.lwe_ciphertext
|
||||
%1 = "MidLFHE.add_glwe_int"(%arg0, %arg1): (!MidLFHE.glwe<{1024,1,64}{4}>, i5) -> (!MidLFHE.glwe<{1024,1,64}{4}>)
|
||||
return %1: !MidLFHE.glwe<{1024,1,64}{4}>
|
||||
}
|
||||
22
compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir
Normal file
22
compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir
Normal file
@@ -0,0 +1,22 @@
|
||||
// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext
|
||||
func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = constant 1 : i8
|
||||
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.int_to_cleartext"(%[[V1]]) : (i8) -> !LowLFHE.cleartext<8>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.mul_cleartext_lwe_ciphertext"(%arg0, %[[V2]]) : (!LowLFHE.lwe_ciphertext, !LowLFHE.cleartext<8>) -> !LowLFHE.lwe_ciphertext
|
||||
// CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext
|
||||
%0 = constant 1 : i8
|
||||
%1 = "MidLFHE.mul_glwe_int"(%arg0, %0): (!MidLFHE.glwe<{1024,1,64}{7}>, i8) -> (!MidLFHE.glwe<{1024,1,64}{7}>)
|
||||
return %1: !MidLFHE.glwe<{1024,1,64}{7}>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @mul_glwe_int(%arg0: !LowLFHE.lwe_ciphertext, %arg1: i5) -> !LowLFHE.lwe_ciphertext
|
||||
func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !MidLFHE.glwe<{1024,1,64}{4}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.int_to_cleartext"(%arg1) : (i5) -> !LowLFHE.cleartext<5>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.mul_cleartext_lwe_ciphertext"(%arg0, %[[V1]]) : (!LowLFHE.lwe_ciphertext, !LowLFHE.cleartext<5>) -> !LowLFHE.lwe_ciphertext
|
||||
// CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext
|
||||
%1 = "MidLFHE.mul_glwe_int"(%arg0, %arg1): (!MidLFHE.glwe<{1024,1,64}{4}>, i5) -> (!MidLFHE.glwe<{1024,1,64}{4}>)
|
||||
return %1: !MidLFHE.glwe<{1024,1,64}{4}>
|
||||
}
|
||||
23
compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir
Normal file
23
compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir
Normal file
@@ -0,0 +1,23 @@
|
||||
// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext
|
||||
func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = constant 1 : i8
|
||||
// CHECK-NEXT: %[[NEG:.*]] = "LowLFHE.negate_lwe_ciphertext"(%arg0) : (!LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext
|
||||
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.encode_int"(%[[V1]]) : (i8) -> !LowLFHE.plaintext<8>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.add_plaintext_lwe_ciphertext"(%[[NEG]], %[[V2]]) : (!LowLFHE.lwe_ciphertext, !LowLFHE.plaintext<8>) -> !LowLFHE.lwe_ciphertext
|
||||
// CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext
|
||||
%0 = constant 1 : i8
|
||||
%1 = "MidLFHE.sub_int_glwe"(%0, %arg0): (i8, !MidLFHE.glwe<{1024,1,64}{7}>) -> (!MidLFHE.glwe<{1024,1,64}{7}>)
|
||||
return %1: !MidLFHE.glwe<{1024,1,64}{7}>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @sub_int_glwe(%arg0: !LowLFHE.lwe_ciphertext, %arg1: i5) -> !LowLFHE.lwe_ciphertext
|
||||
func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !MidLFHE.glwe<{1024,1,64}{4}> {
|
||||
// CHECK-NEXT: %[[NEG:.*]] = "LowLFHE.negate_lwe_ciphertext"(%arg0) : (!LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext
|
||||
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.encode_int"(%arg1) : (i5) -> !LowLFHE.plaintext<5>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.add_plaintext_lwe_ciphertext"(%[[NEG]], %[[V1]]) : (!LowLFHE.lwe_ciphertext, !LowLFHE.plaintext<5>) -> !LowLFHE.lwe_ciphertext
|
||||
// CHECK-NEXT: return %[[V2]] : !LowLFHE.lwe_ciphertext
|
||||
%1 = "MidLFHE.sub_int_glwe"(%arg1, %arg0): (i5, !MidLFHE.glwe<{1024,1,64}{4}>) -> (!MidLFHE.glwe<{1024,1,64}{4}>)
|
||||
return %1: !MidLFHE.glwe<{1024,1,64}{4}>
|
||||
}
|
||||
Reference in New Issue
Block a user