From 4f2b44c9d847bd7ef0c3b5ffef9d8e1878e15d2c Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 13 Jul 2023 10:43:00 +0100 Subject: [PATCH] feat(compiler): support compilation of CRT in simulation --- .../concretelang/Conversion/Utils/Utils.h | 27 +++ .../include/concretelang/Runtime/simulation.h | 44 ++++- .../compiler/lib/Conversion/CMakeLists.txt | 1 + .../ConcreteToCAPI/ConcreteToCAPI.cpp | 48 ++---- .../Conversion/SimulateTFHE/SimulateTFHE.cpp | 162 +++++++++++++++++- .../compiler/lib/Conversion/Utils/Utils.cpp | 60 +++++++ .../compiler/lib/Runtime/simulation.cpp | 54 ++++++ 7 files changed, 352 insertions(+), 44 deletions(-) create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/Utils.h create mode 100644 compilers/concrete-compiler/compiler/lib/Conversion/Utils/Utils.cpp diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/Utils.h b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/Utils.h new file mode 100644 index 000000000..9494efe09 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/Utils/Utils.h @@ -0,0 +1,27 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_CONVERSION_UTILS_H_ +#define CONCRETELANG_CONVERSION_UTILS_H_ + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace concretelang { + +mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter, + size_t rank); + +// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>` +mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value); + +mlir::Value globalMemrefFromArrayAttr(mlir::RewriterBase &rewriter, + mlir::Location loc, + mlir::ArrayAttr arrAttr); + +} // namespace concretelang +} // namespace mlir +#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h index cc9544eb0..a19298add 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h @@ -57,13 +57,34 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated, uint32_t level, uint32_t base_log, uint32_t glwe_dim); +/// simulate a WoP PBS +void sim_wop_pbs_crt( + // Output 1D memref + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size, uint64_t out_stride, + // Input 1D memref + uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset, + uint64_t in_size, uint64_t in_stride, + // clear text lut 2D memref + uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned, + uint64_t lut_ct_offset, uint64_t lut_ct_size0, uint64_t lut_ct_size1, + uint64_t lut_ct_stride0, uint64_t lut_ct_stride1, + // CRT decomposition 1D memref + uint64_t *crt_decomp_allocated, uint64_t *crt_decomp_aligned, + uint64_t crt_decomp_offset, uint64_t crt_decomp_size, + uint64_t crt_decomp_stride, + // Additional crypto parameters + uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log, + uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count, + uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log, + uint32_t polynomial_size); + void sim_encode_expand_lut_for_boostrap( uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset, uint64_t in_size, uint64_t in_stride, uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t out_stride, uint32_t poly_size, uint32_t output_bits, bool is_signed); -} void sim_encode_plaintext_with_crt(uint64_t *output_allocated, uint64_t *output_aligned, @@ -74,4 +95,25 @@ void sim_encode_plaintext_with_crt(uint64_t *output_allocated, uint64_t mods_size, uint64_t mods_stride, uint64_t mods_product); +void sim_encode_lut_for_crt_woppbs( + // Output encoded/expanded lut + uint64_t *output_lut_allocated, uint64_t *output_lut_aligned, + uint64_t output_lut_offset, uint64_t output_lut_size0, + uint64_t output_lut_size1, uint64_t output_lut_stride0, + uint64_t output_lut_stride1, + // Input lut + uint64_t *input_lut_allocated, uint64_t *input_lut_aligned, + uint64_t input_lut_offset, uint64_t input_lut_size, + uint64_t input_lut_stride, + // Crt coprimes + uint64_t *crt_decomposition_allocated, uint64_t *crt_decomposition_aligned, + uint64_t crt_decomposition_offset, uint64_t crt_decomposition_size, + uint64_t crt_decomposition_stride, + // Crt number of bits + uint64_t *crt_bits_allocated, uint64_t *crt_bits_aligned, + uint64_t crt_bits_offset, uint64_t crt_bits_size, uint64_t crt_bits_stride, + // Crypto parameters + uint32_t modulus_product, bool is_signed); +} + #endif diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Conversion/CMakeLists.txt index 5ca515084..45db37741 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Conversion/CMakeLists.txt @@ -18,6 +18,7 @@ add_mlir_library( Tools.cpp Utils/Dialects/SCF.cpp Utils/Dialects/Tensor.cpp + Utils/Utils.cpp LINK_LIBS PUBLIC MLIRIR) diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp index 5d02485f2..34bf93d28 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp @@ -8,6 +8,7 @@ #include "concretelang/Conversion/Passes.h" #include "concretelang/Conversion/Tools.h" +#include "concretelang/Conversion/Utils/Utils.h" #include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" #include "concretelang/Dialect/RT/IR/RTOps.h" #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" @@ -66,38 +67,13 @@ char memref_encode_expand_lut_for_bootstrap[] = char memref_encode_lut_for_crt_woppbs[] = "memref_encode_lut_for_crt_woppbs"; char memref_trace[] = "memref_trace"; -mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter, - size_t rank) { - std::vector shape(rank, mlir::ShapedType::kDynamic); - mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0); - for (size_t i = 0; i < rank; i++) { - expr = expr + - (rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1)); - } - return mlir::MemRefType::get( - shape, rewriter.getI64Type(), - mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext())); -} - -// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>` -mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value) { - mlir::Type valueType = value.getType(); - - if (auto memrefTy = valueType.dyn_cast_or_null()) { - return rewriter.create( - value.getLoc(), - getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()), - value); - } else { - return value; - } -} - mlir::LogicalResult insertForwardDeclarationOfTheCAPI( mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) { - auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1); - auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2); + auto memref1DType = + mlir::concretelang::getDynamicMemrefWithUnknownOffset(rewriter, 1); + auto memref2DType = + mlir::concretelang::getDynamicMemrefWithUnknownOffset(rewriter, 2); auto futureType = mlir::concretelang::RT::FutureType::get(rewriter.getIndexType()); auto contextType = @@ -282,7 +258,8 @@ struct ConcreteToCAPICallPattern : public mlir::OpRewritePattern { if (!type.isa()) { operands.push_back(operand.get()); } else { - operands.push_back(getCastedMemRef(rewriter, operand.get())); + operands.push_back( + mlir::concretelang::getCastedMemRef(rewriter, operand.get())); } } @@ -372,7 +349,7 @@ void wopPBSAddOperands(Concrete::WopPBSCRTLweBufferOp op, auto globalRef = rewriter.create( op.getLoc(), (*globalMemref).getType(), (*globalMemref).getName()); - operands.push_back(getCastedMemRef(rewriter, globalRef)); + operands.push_back(mlir::concretelang::getCastedMemRef(rewriter, globalRef)); // lwe_small_size operands.push_back(rewriter.create( @@ -441,7 +418,8 @@ void encodePlaintextWithCrtAddOperands( auto modsGlobalRef = rewriter.create( op.getLoc(), (*modsGlobalMemref).getType(), (*modsGlobalMemref).getName()); - operands.push_back(getCastedMemRef(rewriter, modsGlobalRef)); + operands.push_back( + mlir::concretelang::getCastedMemRef(rewriter, modsGlobalRef)); // mods_prod operands.push_back(rewriter.create( @@ -484,7 +462,8 @@ void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op, auto crtDecompositionGlobalRef = rewriter.create( op.getLoc(), (*crtDecompositionGlobalMemref).getType(), (*crtDecompositionGlobalMemref).getName()); - operands.push_back(getCastedMemRef(rewriter, crtDecompositionGlobalRef)); + operands.push_back( + mlir::concretelang::getCastedMemRef(rewriter, crtDecompositionGlobalRef)); // crt_bits mlir::Type crtBitsType = mlir::RankedTensorType::get( @@ -503,7 +482,8 @@ void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op, auto crtBitsGlobalRef = rewriter.create( op.getLoc(), (*crtBitsGlobalMemref).getType(), (*crtBitsGlobalMemref).getName()); - operands.push_back(getCastedMemRef(rewriter, crtBitsGlobalRef)); + operands.push_back( + mlir::concretelang::getCastedMemRef(rewriter, crtBitsGlobalRef)); // modulus_product operands.push_back(rewriter.create( op.getLoc(), op.getModulusProductAttr())); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp index bddf222c1..dc98cec3f 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp @@ -14,6 +14,7 @@ #include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" #include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h" #include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" +#include "concretelang/Conversion/Utils/Utils.h" #include "concretelang/Dialect/RT/IR/RTOps.h" #include "concretelang/Dialect/TFHE/IR/TFHEOps.h" #include "concretelang/Support/Constants.h" @@ -177,6 +178,62 @@ struct EncodeExpandLutForBootstrapOpPattern } }; +struct EncodeLutForCrtWopPBSOpPattern + : public mlir::OpConversionPattern { + + EncodeLutForCrtWopPBSOpPattern(mlir::MLIRContext *context, + mlir::TypeConverter &typeConverter) + : mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(TFHE::EncodeLutForCrtWopPBSOp encodeOp, + TFHE::EncodeLutForCrtWopPBSOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + const std::string funcName = "sim_encode_lut_for_crt_woppbs"; + + mlir::Value modulusProductCst = rewriter.create( + encodeOp.getLoc(), encodeOp.getModulusProduct(), 32); + mlir::Value isSignedCst = rewriter.create( + encodeOp.getLoc(), encodeOp.getIsSigned(), 1); + + mlir::Value outputBuffer = + rewriter.create( + encodeOp.getLoc(), + encodeOp.getResult().getType().cast(), + mlir::ValueRange{}); + + auto crtDecompValue = mlir::concretelang::globalMemrefFromArrayAttr( + rewriter, encodeOp.getLoc(), encodeOp.getCrtDecompositionAttr()); + auto crtBitsValue = mlir::concretelang::globalMemrefFromArrayAttr( + rewriter, encodeOp.getLoc(), encodeOp.getCrtBitsAttr()); + + if (insertForwardDeclaration( + encodeOp, rewriter, funcName, + rewriter.getFunctionType( + {encodeOp.getResult().getType(), + encodeOp.getInputLookupTable().getType(), + crtDecompValue.getType(), crtBitsValue.getType(), + rewriter.getIntegerType(32), rewriter.getIntegerType(1)}, + {})) + .failed()) { + return mlir::failure(); + } + + rewriter.create( + encodeOp.getLoc(), funcName, mlir::TypeRange{}, + mlir::ValueRange({outputBuffer, adaptor.getInputLookupTable(), + crtDecompValue, crtBitsValue, modulusProductCst, + isSignedCst})); + + rewriter.replaceOp(encodeOp, outputBuffer); + + return mlir::success(); + } +}; + struct EncodePlaintextWithCrtOpPattern : public mlir::OpConversionPattern { @@ -202,20 +259,23 @@ struct EncodePlaintextWithCrtOpPattern epOp.getResult().getType().cast(), mlir::ValueRange{}); - // TODO: add mods + auto ModsValue = mlir::concretelang::globalMemrefFromArrayAttr( + rewriter, epOp.getLoc(), epOp.getModsAttr()); + if (insertForwardDeclaration( epOp, rewriter, funcName, - rewriter.getFunctionType({epOp.getResult().getType(), - epOp.getInput().getType() /*, mods here*/, - rewriter.getI64Type()}, - {})) + rewriter.getFunctionType( + {epOp.getResult().getType(), epOp.getInput().getType(), + ModsValue.getType(), rewriter.getI64Type()}, + {})) .failed()) { return mlir::failure(); } rewriter.create( epOp.getLoc(), funcName, mlir::TypeRange{}, - mlir::ValueRange({outputBuffer, adaptor.getInput(), modsProductCst})); + mlir::ValueRange( + {outputBuffer, adaptor.getInput(), ModsValue, modsProductCst})); rewriter.replaceOp(epOp, outputBuffer); @@ -223,6 +283,88 @@ struct EncodePlaintextWithCrtOpPattern } }; +struct WopPBSGLWEOpPattern + : public mlir::OpConversionPattern { + + WopPBSGLWEOpPattern(mlir::MLIRContext *context, + mlir::TypeConverter &typeConverter) + : mlir::OpConversionPattern( + typeConverter, context, + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + ::mlir::LogicalResult + matchAndRewrite(TFHE::WopPBSGLWEOp wopPbs, + TFHE::WopPBSGLWEOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + const std::string funcName = "sim_wop_pbs_crt"; + + auto resultType = wopPbs.getType().cast(); + auto inputType = + wopPbs.getCiphertexts().getType().cast(); + + mlir::Value outputBuffer = + rewriter.create( + wopPbs.getLoc(), + this->getTypeConverter() + ->convertType(resultType) + .cast(), + mlir::ValueRange{}); + + auto lweDimCst = rewriter.create( + wopPbs.getLoc(), adaptor.getPksk().getInputLweDim(), 32); + auto cbsLevelCountCst = rewriter.create( + wopPbs.getLoc(), adaptor.getCbsLevels(), 32); + auto cbsBaseLogCst = rewriter.create( + wopPbs.getLoc(), adaptor.getCbsBaseLog(), 32); + auto kskLevelCountCst = rewriter.create( + wopPbs.getLoc(), adaptor.getKsk().getLevels(), 32); + auto kskBaseLogCst = rewriter.create( + wopPbs.getLoc(), adaptor.getKsk().getBaseLog(), 32); + auto bskLevelCountCst = rewriter.create( + wopPbs.getLoc(), adaptor.getBsk().getLevels(), 32); + auto bskBaseLogCst = rewriter.create( + wopPbs.getLoc(), adaptor.getBsk().getBaseLog(), 32); + auto fpkskLevelCountCst = rewriter.create( + wopPbs.getLoc(), adaptor.getPksk().getLevels(), 32); + auto fpkskBaseLogCst = rewriter.create( + wopPbs.getLoc(), adaptor.getPksk().getBaseLog(), 32); + auto polySizeCst = rewriter.create( + wopPbs.getLoc(), adaptor.getPksk().getOutputPolySize(), 32); + + auto crtDecompValue = mlir::concretelang::globalMemrefFromArrayAttr( + rewriter, wopPbs.getLoc(), wopPbs.getCrtDecompositionAttr()); + + if (insertForwardDeclaration( + wopPbs, rewriter, funcName, + rewriter.getFunctionType( + {this->getTypeConverter()->convertType(resultType), + this->getTypeConverter()->convertType(inputType), + wopPbs.getLookupTable().getType(), crtDecompValue.getType(), + rewriter.getIntegerType(32), rewriter.getIntegerType(32), + rewriter.getIntegerType(32), rewriter.getIntegerType(32), + rewriter.getIntegerType(32), rewriter.getIntegerType(32), + rewriter.getIntegerType(32), rewriter.getIntegerType(32), + rewriter.getIntegerType(32), rewriter.getIntegerType(32)}, + {})) + .failed()) { + return mlir::failure(); + } + + rewriter.create( + wopPbs.getLoc(), funcName, mlir::TypeRange{}, + mlir::ValueRange({outputBuffer, adaptor.getCiphertexts(), + adaptor.getLookupTable(), crtDecompValue, lweDimCst, + cbsLevelCountCst, cbsBaseLogCst, kskLevelCountCst, + kskBaseLogCst, bskLevelCountCst, bskBaseLogCst, + fpkskLevelCountCst, fpkskBaseLogCst, polySizeCst})); + + rewriter.replaceOp(wopPbs, outputBuffer); + + return mlir::success(); + } +}; + struct BootstrapGLWEOpPattern : public mlir::OpConversionPattern { @@ -399,8 +541,8 @@ void SimulateTFHEPass::runOnOperation() { SimulateTFHETypeConverter converter; target.addLegalDialect(); - target.addLegalOp(); + target.addLegalOp(); // Make sure that no ops from `TFHE` remain after the lowering target.addIllegalDialect(); @@ -465,7 +607,9 @@ void SimulateTFHEPass::runOnOperation() { }); patterns.insert(&getContext(), converter); patterns.insert(&getContext()); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Utils.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Utils.cpp new file mode 100644 index 000000000..141bde198 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Utils.cpp @@ -0,0 +1,60 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Conversion/Utils/Utils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" + +namespace mlir { +namespace concretelang { + +mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter, + size_t rank) { + std::vector shape(rank, mlir::ShapedType::kDynamic); + mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0); + for (size_t i = 0; i < rank; i++) { + expr = expr + + (rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1)); + } + return mlir::MemRefType::get( + shape, rewriter.getI64Type(), + mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext())); +} + +// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>` +mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value) { + mlir::Type valueType = value.getType(); + + if (auto memrefTy = valueType.dyn_cast_or_null()) { + return rewriter.create( + value.getLoc(), + getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()), + value); + } else { + return value; + } +} + +mlir::Value globalMemrefFromArrayAttr(mlir::RewriterBase &rewriter, + mlir::Location loc, + mlir::ArrayAttr arrAttr) { + mlir::Type type = + mlir::RankedTensorType::get({(int)arrAttr.size()}, rewriter.getI64Type()); + std::vector values; + for (auto a : arrAttr) { + values.push_back(a.cast().getValue().getZExtValue()); + } + auto denseAttr = rewriter.getI64TensorAttr(values); + auto cstOp = rewriter.create(loc, denseAttr, type); + auto globalMemref = mlir::bufferization::getGlobalFor(cstOp, 0); + rewriter.eraseOp(cstOp); + assert(!mlir::failed(globalMemref)); + auto globalRef = rewriter.create( + loc, (*globalMemref).getType(), (*globalMemref).getName()); + return mlir::concretelang::getCastedMemRef(rewriter, globalRef); +} + +} // namespace concretelang +} // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp index 3813fc256..08cfbf6d9 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp @@ -93,6 +93,29 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated, return out + gaussian_noise(0, variance); } +void sim_wop_pbs_crt( + // Output 1D memref + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size, uint64_t out_stride, + // Input 1D memref + uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset, + uint64_t in_size, uint64_t in_stride, + // clear text lut 2D memref + uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned, + uint64_t lut_ct_offset, uint64_t lut_ct_size0, uint64_t lut_ct_size1, + uint64_t lut_ct_stride0, uint64_t lut_ct_stride1, + // CRT decomposition 1D memref + uint64_t *crt_decomp_allocated, uint64_t *crt_decomp_aligned, + uint64_t crt_decomp_offset, uint64_t crt_decomp_size, + uint64_t crt_decomp_stride, + // Additional crypto parameters + uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log, + uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count, + uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log, + uint32_t polynomial_size) { + // TODO +} + uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; } void sim_encode_expand_lut_for_boostrap( @@ -120,3 +143,34 @@ void sim_encode_plaintext_with_crt(uint64_t *output_allocated, output_stride, input, mods_allocated, mods_aligned, mods_offset, mods_size, mods_stride, mods_product); } + +void sim_encode_lut_for_crt_woppbs( + // Output encoded/expanded lut + uint64_t *output_lut_allocated, uint64_t *output_lut_aligned, + uint64_t output_lut_offset, uint64_t output_lut_size0, + uint64_t output_lut_size1, uint64_t output_lut_stride0, + uint64_t output_lut_stride1, + // Input lut + uint64_t *input_lut_allocated, uint64_t *input_lut_aligned, + uint64_t input_lut_offset, uint64_t input_lut_size, + uint64_t input_lut_stride, + // Crt coprimes + uint64_t *crt_decomposition_allocated, uint64_t *crt_decomposition_aligned, + uint64_t crt_decomposition_offset, uint64_t crt_decomposition_size, + uint64_t crt_decomposition_stride, + // Crt number of bits + uint64_t *crt_bits_allocated, uint64_t *crt_bits_aligned, + uint64_t crt_bits_offset, uint64_t crt_bits_size, uint64_t crt_bits_stride, + // Crypto parameters + uint32_t modulus_product, bool is_signed) { + return memref_encode_lut_for_crt_woppbs( + output_lut_allocated, output_lut_aligned, output_lut_offset, + output_lut_size0, output_lut_size1, output_lut_stride0, + output_lut_stride1, input_lut_allocated, input_lut_aligned, + input_lut_offset, input_lut_size, input_lut_stride, + crt_decomposition_allocated, crt_decomposition_aligned, + crt_decomposition_offset, crt_decomposition_size, + crt_decomposition_stride, crt_bits_allocated, crt_bits_aligned, + crt_bits_offset, crt_bits_size, crt_bits_stride, modulus_product, + is_signed); +}