feat(compiler): support compilation of CRT in simulation

This commit is contained in:
youben11
2023-07-13 10:43:00 +01:00
committed by Ayoub Benaissa
parent a70d2f0b83
commit 4f2b44c9d8
7 changed files with 352 additions and 44 deletions

View File

@@ -0,0 +1,27 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CONVERSION_UTILS_H_
#define CONCRETELANG_CONVERSION_UTILS_H_
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace concretelang {
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
size_t rank);
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value);
mlir::Value globalMemrefFromArrayAttr(mlir::RewriterBase &rewriter,
mlir::Location loc,
mlir::ArrayAttr arrAttr);
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -57,13 +57,34 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated,
uint32_t level, uint32_t base_log,
uint32_t glwe_dim);
/// simulate a WoP PBS
void sim_wop_pbs_crt(
// Output 1D memref
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride,
// Input 1D memref
uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset,
uint64_t in_size, uint64_t in_stride,
// clear text lut 2D memref
uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned,
uint64_t lut_ct_offset, uint64_t lut_ct_size0, uint64_t lut_ct_size1,
uint64_t lut_ct_stride0, uint64_t lut_ct_stride1,
// CRT decomposition 1D memref
uint64_t *crt_decomp_allocated, uint64_t *crt_decomp_aligned,
uint64_t crt_decomp_offset, uint64_t crt_decomp_size,
uint64_t crt_decomp_stride,
// Additional crypto parameters
uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log,
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
uint32_t polynomial_size);
void sim_encode_expand_lut_for_boostrap(
uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset,
uint64_t in_size, uint64_t in_stride, uint64_t *out_allocated,
uint64_t *out_aligned, uint64_t out_offset, uint64_t out_size,
uint64_t out_stride, uint32_t poly_size, uint32_t output_bits,
bool is_signed);
}
void sim_encode_plaintext_with_crt(uint64_t *output_allocated,
uint64_t *output_aligned,
@@ -74,4 +95,25 @@ void sim_encode_plaintext_with_crt(uint64_t *output_allocated,
uint64_t mods_size, uint64_t mods_stride,
uint64_t mods_product);
void sim_encode_lut_for_crt_woppbs(
// Output encoded/expanded lut
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
uint64_t output_lut_offset, uint64_t output_lut_size0,
uint64_t output_lut_size1, uint64_t output_lut_stride0,
uint64_t output_lut_stride1,
// Input lut
uint64_t *input_lut_allocated, uint64_t *input_lut_aligned,
uint64_t input_lut_offset, uint64_t input_lut_size,
uint64_t input_lut_stride,
// Crt coprimes
uint64_t *crt_decomposition_allocated, uint64_t *crt_decomposition_aligned,
uint64_t crt_decomposition_offset, uint64_t crt_decomposition_size,
uint64_t crt_decomposition_stride,
// Crt number of bits
uint64_t *crt_bits_allocated, uint64_t *crt_bits_aligned,
uint64_t crt_bits_offset, uint64_t crt_bits_size, uint64_t crt_bits_stride,
// Crypto parameters
uint32_t modulus_product, bool is_signed);
}
#endif

View File

@@ -18,6 +18,7 @@ add_mlir_library(
Tools.cpp
Utils/Dialects/SCF.cpp
Utils/Dialects/Tensor.cpp
Utils/Utils.cpp
LINK_LIBS
PUBLIC
MLIRIR)

View File

@@ -8,6 +8,7 @@
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Conversion/Utils/Utils.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
@@ -66,38 +67,13 @@ char memref_encode_expand_lut_for_bootstrap[] =
char memref_encode_lut_for_crt_woppbs[] = "memref_encode_lut_for_crt_woppbs";
char memref_trace[] = "memref_trace";
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
size_t rank) {
std::vector<int64_t> shape(rank, mlir::ShapedType::kDynamic);
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
for (size_t i = 0; i < rank; i++) {
expr = expr +
(rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1));
}
return mlir::MemRefType::get(
shape, rewriter.getI64Type(),
mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext()));
}
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value) {
mlir::Type valueType = value.getType();
if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
return rewriter.create<mlir::memref::CastOp>(
value.getLoc(),
getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
value);
} else {
return value;
}
}
mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) {
auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1);
auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2);
auto memref1DType =
mlir::concretelang::getDynamicMemrefWithUnknownOffset(rewriter, 1);
auto memref2DType =
mlir::concretelang::getDynamicMemrefWithUnknownOffset(rewriter, 2);
auto futureType =
mlir::concretelang::RT::FutureType::get(rewriter.getIndexType());
auto contextType =
@@ -282,7 +258,8 @@ struct ConcreteToCAPICallPattern : public mlir::OpRewritePattern<ConcreteOp> {
if (!type.isa<mlir::MemRefType>()) {
operands.push_back(operand.get());
} else {
operands.push_back(getCastedMemRef(rewriter, operand.get()));
operands.push_back(
mlir::concretelang::getCastedMemRef(rewriter, operand.get()));
}
}
@@ -372,7 +349,7 @@ void wopPBSAddOperands(Concrete::WopPBSCRTLweBufferOp op,
auto globalRef = rewriter.create<memref::GetGlobalOp>(
op.getLoc(), (*globalMemref).getType(), (*globalMemref).getName());
operands.push_back(getCastedMemRef(rewriter, globalRef));
operands.push_back(mlir::concretelang::getCastedMemRef(rewriter, globalRef));
// lwe_small_size
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
@@ -441,7 +418,8 @@ void encodePlaintextWithCrtAddOperands(
auto modsGlobalRef = rewriter.create<memref::GetGlobalOp>(
op.getLoc(), (*modsGlobalMemref).getType(),
(*modsGlobalMemref).getName());
operands.push_back(getCastedMemRef(rewriter, modsGlobalRef));
operands.push_back(
mlir::concretelang::getCastedMemRef(rewriter, modsGlobalRef));
// mods_prod
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
@@ -484,7 +462,8 @@ void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op,
auto crtDecompositionGlobalRef = rewriter.create<memref::GetGlobalOp>(
op.getLoc(), (*crtDecompositionGlobalMemref).getType(),
(*crtDecompositionGlobalMemref).getName());
operands.push_back(getCastedMemRef(rewriter, crtDecompositionGlobalRef));
operands.push_back(
mlir::concretelang::getCastedMemRef(rewriter, crtDecompositionGlobalRef));
// crt_bits
mlir::Type crtBitsType = mlir::RankedTensorType::get(
@@ -503,7 +482,8 @@ void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op,
auto crtBitsGlobalRef = rewriter.create<memref::GetGlobalOp>(
op.getLoc(), (*crtBitsGlobalMemref).getType(),
(*crtBitsGlobalMemref).getName());
operands.push_back(getCastedMemRef(rewriter, crtBitsGlobalRef));
operands.push_back(
mlir::concretelang::getCastedMemRef(rewriter, crtBitsGlobalRef));
// modulus_product
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.getModulusProductAttr()));

View File

@@ -14,6 +14,7 @@
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Conversion/Utils/Utils.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Support/Constants.h"
@@ -177,6 +178,62 @@ struct EncodeExpandLutForBootstrapOpPattern
}
};
struct EncodeLutForCrtWopPBSOpPattern
: public mlir::OpConversionPattern<TFHE::EncodeLutForCrtWopPBSOp> {
EncodeLutForCrtWopPBSOpPattern(mlir::MLIRContext *context,
mlir::TypeConverter &typeConverter)
: mlir::OpConversionPattern<TFHE::EncodeLutForCrtWopPBSOp>(
typeConverter, context,
mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
::mlir::LogicalResult
matchAndRewrite(TFHE::EncodeLutForCrtWopPBSOp encodeOp,
TFHE::EncodeLutForCrtWopPBSOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
const std::string funcName = "sim_encode_lut_for_crt_woppbs";
mlir::Value modulusProductCst = rewriter.create<mlir::arith::ConstantIntOp>(
encodeOp.getLoc(), encodeOp.getModulusProduct(), 32);
mlir::Value isSignedCst = rewriter.create<mlir::arith::ConstantIntOp>(
encodeOp.getLoc(), encodeOp.getIsSigned(), 1);
mlir::Value outputBuffer =
rewriter.create<mlir::bufferization::AllocTensorOp>(
encodeOp.getLoc(),
encodeOp.getResult().getType().cast<mlir::RankedTensorType>(),
mlir::ValueRange{});
auto crtDecompValue = mlir::concretelang::globalMemrefFromArrayAttr(
rewriter, encodeOp.getLoc(), encodeOp.getCrtDecompositionAttr());
auto crtBitsValue = mlir::concretelang::globalMemrefFromArrayAttr(
rewriter, encodeOp.getLoc(), encodeOp.getCrtBitsAttr());
if (insertForwardDeclaration(
encodeOp, rewriter, funcName,
rewriter.getFunctionType(
{encodeOp.getResult().getType(),
encodeOp.getInputLookupTable().getType(),
crtDecompValue.getType(), crtBitsValue.getType(),
rewriter.getIntegerType(32), rewriter.getIntegerType(1)},
{}))
.failed()) {
return mlir::failure();
}
rewriter.create<mlir::func::CallOp>(
encodeOp.getLoc(), funcName, mlir::TypeRange{},
mlir::ValueRange({outputBuffer, adaptor.getInputLookupTable(),
crtDecompValue, crtBitsValue, modulusProductCst,
isSignedCst}));
rewriter.replaceOp(encodeOp, outputBuffer);
return mlir::success();
}
};
struct EncodePlaintextWithCrtOpPattern
: public mlir::OpConversionPattern<TFHE::EncodePlaintextWithCrtOp> {
@@ -202,20 +259,23 @@ struct EncodePlaintextWithCrtOpPattern
epOp.getResult().getType().cast<mlir::RankedTensorType>(),
mlir::ValueRange{});
// TODO: add mods
auto ModsValue = mlir::concretelang::globalMemrefFromArrayAttr(
rewriter, epOp.getLoc(), epOp.getModsAttr());
if (insertForwardDeclaration(
epOp, rewriter, funcName,
rewriter.getFunctionType({epOp.getResult().getType(),
epOp.getInput().getType() /*, mods here*/,
rewriter.getI64Type()},
{}))
rewriter.getFunctionType(
{epOp.getResult().getType(), epOp.getInput().getType(),
ModsValue.getType(), rewriter.getI64Type()},
{}))
.failed()) {
return mlir::failure();
}
rewriter.create<mlir::func::CallOp>(
epOp.getLoc(), funcName, mlir::TypeRange{},
mlir::ValueRange({outputBuffer, adaptor.getInput(), modsProductCst}));
mlir::ValueRange(
{outputBuffer, adaptor.getInput(), ModsValue, modsProductCst}));
rewriter.replaceOp(epOp, outputBuffer);
@@ -223,6 +283,88 @@ struct EncodePlaintextWithCrtOpPattern
}
};
struct WopPBSGLWEOpPattern
: public mlir::OpConversionPattern<TFHE::WopPBSGLWEOp> {
WopPBSGLWEOpPattern(mlir::MLIRContext *context,
mlir::TypeConverter &typeConverter)
: mlir::OpConversionPattern<TFHE::WopPBSGLWEOp>(
typeConverter, context,
mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
::mlir::LogicalResult
matchAndRewrite(TFHE::WopPBSGLWEOp wopPbs,
TFHE::WopPBSGLWEOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
const std::string funcName = "sim_wop_pbs_crt";
auto resultType = wopPbs.getType().cast<mlir::RankedTensorType>();
auto inputType =
wopPbs.getCiphertexts().getType().cast<mlir::RankedTensorType>();
mlir::Value outputBuffer =
rewriter.create<mlir::bufferization::AllocTensorOp>(
wopPbs.getLoc(),
this->getTypeConverter()
->convertType(resultType)
.cast<mlir::RankedTensorType>(),
mlir::ValueRange{});
auto lweDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getPksk().getInputLweDim(), 32);
auto cbsLevelCountCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getCbsLevels(), 32);
auto cbsBaseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getCbsBaseLog(), 32);
auto kskLevelCountCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getKsk().getLevels(), 32);
auto kskBaseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getKsk().getBaseLog(), 32);
auto bskLevelCountCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getBsk().getLevels(), 32);
auto bskBaseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getBsk().getBaseLog(), 32);
auto fpkskLevelCountCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getPksk().getLevels(), 32);
auto fpkskBaseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getPksk().getBaseLog(), 32);
auto polySizeCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getPksk().getOutputPolySize(), 32);
auto crtDecompValue = mlir::concretelang::globalMemrefFromArrayAttr(
rewriter, wopPbs.getLoc(), wopPbs.getCrtDecompositionAttr());
if (insertForwardDeclaration(
wopPbs, rewriter, funcName,
rewriter.getFunctionType(
{this->getTypeConverter()->convertType(resultType),
this->getTypeConverter()->convertType(inputType),
wopPbs.getLookupTable().getType(), crtDecompValue.getType(),
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
rewriter.getIntegerType(32), rewriter.getIntegerType(32)},
{}))
.failed()) {
return mlir::failure();
}
rewriter.create<mlir::func::CallOp>(
wopPbs.getLoc(), funcName, mlir::TypeRange{},
mlir::ValueRange({outputBuffer, adaptor.getCiphertexts(),
adaptor.getLookupTable(), crtDecompValue, lweDimCst,
cbsLevelCountCst, cbsBaseLogCst, kskLevelCountCst,
kskBaseLogCst, bskLevelCountCst, bskBaseLogCst,
fpkskLevelCountCst, fpkskBaseLogCst, polySizeCst}));
rewriter.replaceOp(wopPbs, outputBuffer);
return mlir::success();
}
};
struct BootstrapGLWEOpPattern
: public mlir::OpConversionPattern<TFHE::BootstrapGLWEOp> {
@@ -399,8 +541,8 @@ void SimulateTFHEPass::runOnOperation() {
SimulateTFHETypeConverter converter;
target.addLegalDialect<mlir::arith::ArithDialect>();
target.addLegalOp<mlir::func::CallOp, mlir::bufferization::AllocTensorOp,
mlir::tensor::CastOp>();
target.addLegalOp<mlir::func::CallOp, mlir::memref::GetGlobalOp,
mlir::bufferization::AllocTensorOp, mlir::tensor::CastOp>();
// Make sure that no ops from `TFHE` remain after the lowering
target.addIllegalDialect<TFHE::TFHEDialect>();
@@ -465,7 +607,9 @@ void SimulateTFHEPass::runOnOperation() {
});
patterns.insert<ZeroOpPattern, ZeroTensorOpPattern, KeySwitchGLWEOpPattern,
BootstrapGLWEOpPattern, EncodeExpandLutForBootstrapOpPattern,
BootstrapGLWEOpPattern, WopPBSGLWEOpPattern,
EncodeExpandLutForBootstrapOpPattern,
EncodeLutForCrtWopPBSOpPattern,
EncodePlaintextWithCrtOpPattern, NegOpPattern>(&getContext(),
converter);
patterns.insert<SubIntGLWEOpPattern>(&getContext());

View File

@@ -0,0 +1,60 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/Conversion/Utils/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
namespace mlir {
namespace concretelang {
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
size_t rank) {
std::vector<int64_t> shape(rank, mlir::ShapedType::kDynamic);
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
for (size_t i = 0; i < rank; i++) {
expr = expr +
(rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1));
}
return mlir::MemRefType::get(
shape, rewriter.getI64Type(),
mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext()));
}
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value) {
mlir::Type valueType = value.getType();
if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
return rewriter.create<mlir::memref::CastOp>(
value.getLoc(),
getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
value);
} else {
return value;
}
}
mlir::Value globalMemrefFromArrayAttr(mlir::RewriterBase &rewriter,
mlir::Location loc,
mlir::ArrayAttr arrAttr) {
mlir::Type type =
mlir::RankedTensorType::get({(int)arrAttr.size()}, rewriter.getI64Type());
std::vector<int64_t> values;
for (auto a : arrAttr) {
values.push_back(a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
}
auto denseAttr = rewriter.getI64TensorAttr(values);
auto cstOp = rewriter.create<mlir::arith::ConstantOp>(loc, denseAttr, type);
auto globalMemref = mlir::bufferization::getGlobalFor(cstOp, 0);
rewriter.eraseOp(cstOp);
assert(!mlir::failed(globalMemref));
auto globalRef = rewriter.create<mlir::memref::GetGlobalOp>(
loc, (*globalMemref).getType(), (*globalMemref).getName());
return mlir::concretelang::getCastedMemRef(rewriter, globalRef);
}
} // namespace concretelang
} // namespace mlir

View File

@@ -93,6 +93,29 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated,
return out + gaussian_noise(0, variance);
}
void sim_wop_pbs_crt(
// Output 1D memref
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride,
// Input 1D memref
uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset,
uint64_t in_size, uint64_t in_stride,
// clear text lut 2D memref
uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned,
uint64_t lut_ct_offset, uint64_t lut_ct_size0, uint64_t lut_ct_size1,
uint64_t lut_ct_stride0, uint64_t lut_ct_stride1,
// CRT decomposition 1D memref
uint64_t *crt_decomp_allocated, uint64_t *crt_decomp_aligned,
uint64_t crt_decomp_offset, uint64_t crt_decomp_size,
uint64_t crt_decomp_stride,
// Additional crypto parameters
uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log,
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
uint32_t polynomial_size) {
// TODO
}
uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; }
void sim_encode_expand_lut_for_boostrap(
@@ -120,3 +143,34 @@ void sim_encode_plaintext_with_crt(uint64_t *output_allocated,
output_stride, input, mods_allocated, mods_aligned, mods_offset,
mods_size, mods_stride, mods_product);
}
void sim_encode_lut_for_crt_woppbs(
// Output encoded/expanded lut
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
uint64_t output_lut_offset, uint64_t output_lut_size0,
uint64_t output_lut_size1, uint64_t output_lut_stride0,
uint64_t output_lut_stride1,
// Input lut
uint64_t *input_lut_allocated, uint64_t *input_lut_aligned,
uint64_t input_lut_offset, uint64_t input_lut_size,
uint64_t input_lut_stride,
// Crt coprimes
uint64_t *crt_decomposition_allocated, uint64_t *crt_decomposition_aligned,
uint64_t crt_decomposition_offset, uint64_t crt_decomposition_size,
uint64_t crt_decomposition_stride,
// Crt number of bits
uint64_t *crt_bits_allocated, uint64_t *crt_bits_aligned,
uint64_t crt_bits_offset, uint64_t crt_bits_size, uint64_t crt_bits_stride,
// Crypto parameters
uint32_t modulus_product, bool is_signed) {
return memref_encode_lut_for_crt_woppbs(
output_lut_allocated, output_lut_aligned, output_lut_offset,
output_lut_size0, output_lut_size1, output_lut_stride0,
output_lut_stride1, input_lut_allocated, input_lut_aligned,
input_lut_offset, input_lut_size, input_lut_stride,
crt_decomposition_allocated, crt_decomposition_aligned,
crt_decomposition_offset, crt_decomposition_size,
crt_decomposition_stride, crt_bits_allocated, crt_bits_aligned,
crt_bits_offset, crt_bits_size, crt_bits_stride, modulus_product,
is_signed);
}