mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-17 08:01:20 -05:00
feat(compiler): support compilation of CRT in simulation
This commit is contained in:
@@ -0,0 +1,27 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_UTILS_H_
|
||||
#define CONCRETELANG_CONVERSION_UTILS_H_
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
|
||||
size_t rank);
|
||||
|
||||
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
|
||||
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value);
|
||||
|
||||
mlir::Value globalMemrefFromArrayAttr(mlir::RewriterBase &rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::ArrayAttr arrAttr);
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
#endif
|
||||
@@ -57,13 +57,34 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated,
|
||||
uint32_t level, uint32_t base_log,
|
||||
uint32_t glwe_dim);
|
||||
|
||||
/// simulate a WoP PBS
|
||||
void sim_wop_pbs_crt(
|
||||
// Output 1D memref
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride,
|
||||
// Input 1D memref
|
||||
uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset,
|
||||
uint64_t in_size, uint64_t in_stride,
|
||||
// clear text lut 2D memref
|
||||
uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned,
|
||||
uint64_t lut_ct_offset, uint64_t lut_ct_size0, uint64_t lut_ct_size1,
|
||||
uint64_t lut_ct_stride0, uint64_t lut_ct_stride1,
|
||||
// CRT decomposition 1D memref
|
||||
uint64_t *crt_decomp_allocated, uint64_t *crt_decomp_aligned,
|
||||
uint64_t crt_decomp_offset, uint64_t crt_decomp_size,
|
||||
uint64_t crt_decomp_stride,
|
||||
// Additional crypto parameters
|
||||
uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log,
|
||||
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
|
||||
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
|
||||
uint32_t polynomial_size);
|
||||
|
||||
void sim_encode_expand_lut_for_boostrap(
|
||||
uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset,
|
||||
uint64_t in_size, uint64_t in_stride, uint64_t *out_allocated,
|
||||
uint64_t *out_aligned, uint64_t out_offset, uint64_t out_size,
|
||||
uint64_t out_stride, uint32_t poly_size, uint32_t output_bits,
|
||||
bool is_signed);
|
||||
}
|
||||
|
||||
void sim_encode_plaintext_with_crt(uint64_t *output_allocated,
|
||||
uint64_t *output_aligned,
|
||||
@@ -74,4 +95,25 @@ void sim_encode_plaintext_with_crt(uint64_t *output_allocated,
|
||||
uint64_t mods_size, uint64_t mods_stride,
|
||||
uint64_t mods_product);
|
||||
|
||||
void sim_encode_lut_for_crt_woppbs(
|
||||
// Output encoded/expanded lut
|
||||
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
|
||||
uint64_t output_lut_offset, uint64_t output_lut_size0,
|
||||
uint64_t output_lut_size1, uint64_t output_lut_stride0,
|
||||
uint64_t output_lut_stride1,
|
||||
// Input lut
|
||||
uint64_t *input_lut_allocated, uint64_t *input_lut_aligned,
|
||||
uint64_t input_lut_offset, uint64_t input_lut_size,
|
||||
uint64_t input_lut_stride,
|
||||
// Crt coprimes
|
||||
uint64_t *crt_decomposition_allocated, uint64_t *crt_decomposition_aligned,
|
||||
uint64_t crt_decomposition_offset, uint64_t crt_decomposition_size,
|
||||
uint64_t crt_decomposition_stride,
|
||||
// Crt number of bits
|
||||
uint64_t *crt_bits_allocated, uint64_t *crt_bits_aligned,
|
||||
uint64_t crt_bits_offset, uint64_t crt_bits_size, uint64_t crt_bits_stride,
|
||||
// Crypto parameters
|
||||
uint32_t modulus_product, bool is_signed);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -18,6 +18,7 @@ add_mlir_library(
|
||||
Tools.cpp
|
||||
Utils/Dialects/SCF.cpp
|
||||
Utils/Dialects/Tensor.cpp
|
||||
Utils/Utils.cpp
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRIR)
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Tools.h"
|
||||
#include "concretelang/Conversion/Utils/Utils.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTOps.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
|
||||
@@ -66,38 +67,13 @@ char memref_encode_expand_lut_for_bootstrap[] =
|
||||
char memref_encode_lut_for_crt_woppbs[] = "memref_encode_lut_for_crt_woppbs";
|
||||
char memref_trace[] = "memref_trace";
|
||||
|
||||
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
|
||||
size_t rank) {
|
||||
std::vector<int64_t> shape(rank, mlir::ShapedType::kDynamic);
|
||||
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
|
||||
for (size_t i = 0; i < rank; i++) {
|
||||
expr = expr +
|
||||
(rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1));
|
||||
}
|
||||
return mlir::MemRefType::get(
|
||||
shape, rewriter.getI64Type(),
|
||||
mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext()));
|
||||
}
|
||||
|
||||
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
|
||||
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value) {
|
||||
mlir::Type valueType = value.getType();
|
||||
|
||||
if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
|
||||
return rewriter.create<mlir::memref::CastOp>(
|
||||
value.getLoc(),
|
||||
getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
|
||||
value);
|
||||
} else {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) {
|
||||
|
||||
auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1);
|
||||
auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2);
|
||||
auto memref1DType =
|
||||
mlir::concretelang::getDynamicMemrefWithUnknownOffset(rewriter, 1);
|
||||
auto memref2DType =
|
||||
mlir::concretelang::getDynamicMemrefWithUnknownOffset(rewriter, 2);
|
||||
auto futureType =
|
||||
mlir::concretelang::RT::FutureType::get(rewriter.getIndexType());
|
||||
auto contextType =
|
||||
@@ -282,7 +258,8 @@ struct ConcreteToCAPICallPattern : public mlir::OpRewritePattern<ConcreteOp> {
|
||||
if (!type.isa<mlir::MemRefType>()) {
|
||||
operands.push_back(operand.get());
|
||||
} else {
|
||||
operands.push_back(getCastedMemRef(rewriter, operand.get()));
|
||||
operands.push_back(
|
||||
mlir::concretelang::getCastedMemRef(rewriter, operand.get()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -372,7 +349,7 @@ void wopPBSAddOperands(Concrete::WopPBSCRTLweBufferOp op,
|
||||
|
||||
auto globalRef = rewriter.create<memref::GetGlobalOp>(
|
||||
op.getLoc(), (*globalMemref).getType(), (*globalMemref).getName());
|
||||
operands.push_back(getCastedMemRef(rewriter, globalRef));
|
||||
operands.push_back(mlir::concretelang::getCastedMemRef(rewriter, globalRef));
|
||||
|
||||
// lwe_small_size
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
@@ -441,7 +418,8 @@ void encodePlaintextWithCrtAddOperands(
|
||||
auto modsGlobalRef = rewriter.create<memref::GetGlobalOp>(
|
||||
op.getLoc(), (*modsGlobalMemref).getType(),
|
||||
(*modsGlobalMemref).getName());
|
||||
operands.push_back(getCastedMemRef(rewriter, modsGlobalRef));
|
||||
operands.push_back(
|
||||
mlir::concretelang::getCastedMemRef(rewriter, modsGlobalRef));
|
||||
|
||||
// mods_prod
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
@@ -484,7 +462,8 @@ void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op,
|
||||
auto crtDecompositionGlobalRef = rewriter.create<memref::GetGlobalOp>(
|
||||
op.getLoc(), (*crtDecompositionGlobalMemref).getType(),
|
||||
(*crtDecompositionGlobalMemref).getName());
|
||||
operands.push_back(getCastedMemRef(rewriter, crtDecompositionGlobalRef));
|
||||
operands.push_back(
|
||||
mlir::concretelang::getCastedMemRef(rewriter, crtDecompositionGlobalRef));
|
||||
|
||||
// crt_bits
|
||||
mlir::Type crtBitsType = mlir::RankedTensorType::get(
|
||||
@@ -503,7 +482,8 @@ void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op,
|
||||
auto crtBitsGlobalRef = rewriter.create<memref::GetGlobalOp>(
|
||||
op.getLoc(), (*crtBitsGlobalMemref).getType(),
|
||||
(*crtBitsGlobalMemref).getName());
|
||||
operands.push_back(getCastedMemRef(rewriter, crtBitsGlobalRef));
|
||||
operands.push_back(
|
||||
mlir::concretelang::getCastedMemRef(rewriter, crtBitsGlobalRef));
|
||||
// modulus_product
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.getModulusProductAttr()));
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
|
||||
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "concretelang/Conversion/Utils/Utils.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
||||
#include "concretelang/Support/Constants.h"
|
||||
@@ -177,6 +178,62 @@ struct EncodeExpandLutForBootstrapOpPattern
|
||||
}
|
||||
};
|
||||
|
||||
struct EncodeLutForCrtWopPBSOpPattern
|
||||
: public mlir::OpConversionPattern<TFHE::EncodeLutForCrtWopPBSOp> {
|
||||
|
||||
EncodeLutForCrtWopPBSOpPattern(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &typeConverter)
|
||||
: mlir::OpConversionPattern<TFHE::EncodeLutForCrtWopPBSOp>(
|
||||
typeConverter, context,
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(TFHE::EncodeLutForCrtWopPBSOp encodeOp,
|
||||
TFHE::EncodeLutForCrtWopPBSOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
const std::string funcName = "sim_encode_lut_for_crt_woppbs";
|
||||
|
||||
mlir::Value modulusProductCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
encodeOp.getLoc(), encodeOp.getModulusProduct(), 32);
|
||||
mlir::Value isSignedCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
encodeOp.getLoc(), encodeOp.getIsSigned(), 1);
|
||||
|
||||
mlir::Value outputBuffer =
|
||||
rewriter.create<mlir::bufferization::AllocTensorOp>(
|
||||
encodeOp.getLoc(),
|
||||
encodeOp.getResult().getType().cast<mlir::RankedTensorType>(),
|
||||
mlir::ValueRange{});
|
||||
|
||||
auto crtDecompValue = mlir::concretelang::globalMemrefFromArrayAttr(
|
||||
rewriter, encodeOp.getLoc(), encodeOp.getCrtDecompositionAttr());
|
||||
auto crtBitsValue = mlir::concretelang::globalMemrefFromArrayAttr(
|
||||
rewriter, encodeOp.getLoc(), encodeOp.getCrtBitsAttr());
|
||||
|
||||
if (insertForwardDeclaration(
|
||||
encodeOp, rewriter, funcName,
|
||||
rewriter.getFunctionType(
|
||||
{encodeOp.getResult().getType(),
|
||||
encodeOp.getInputLookupTable().getType(),
|
||||
crtDecompValue.getType(), crtBitsValue.getType(),
|
||||
rewriter.getIntegerType(32), rewriter.getIntegerType(1)},
|
||||
{}))
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
rewriter.create<mlir::func::CallOp>(
|
||||
encodeOp.getLoc(), funcName, mlir::TypeRange{},
|
||||
mlir::ValueRange({outputBuffer, adaptor.getInputLookupTable(),
|
||||
crtDecompValue, crtBitsValue, modulusProductCst,
|
||||
isSignedCst}));
|
||||
|
||||
rewriter.replaceOp(encodeOp, outputBuffer);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
struct EncodePlaintextWithCrtOpPattern
|
||||
: public mlir::OpConversionPattern<TFHE::EncodePlaintextWithCrtOp> {
|
||||
|
||||
@@ -202,20 +259,23 @@ struct EncodePlaintextWithCrtOpPattern
|
||||
epOp.getResult().getType().cast<mlir::RankedTensorType>(),
|
||||
mlir::ValueRange{});
|
||||
|
||||
// TODO: add mods
|
||||
auto ModsValue = mlir::concretelang::globalMemrefFromArrayAttr(
|
||||
rewriter, epOp.getLoc(), epOp.getModsAttr());
|
||||
|
||||
if (insertForwardDeclaration(
|
||||
epOp, rewriter, funcName,
|
||||
rewriter.getFunctionType({epOp.getResult().getType(),
|
||||
epOp.getInput().getType() /*, mods here*/,
|
||||
rewriter.getI64Type()},
|
||||
{}))
|
||||
rewriter.getFunctionType(
|
||||
{epOp.getResult().getType(), epOp.getInput().getType(),
|
||||
ModsValue.getType(), rewriter.getI64Type()},
|
||||
{}))
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
rewriter.create<mlir::func::CallOp>(
|
||||
epOp.getLoc(), funcName, mlir::TypeRange{},
|
||||
mlir::ValueRange({outputBuffer, adaptor.getInput(), modsProductCst}));
|
||||
mlir::ValueRange(
|
||||
{outputBuffer, adaptor.getInput(), ModsValue, modsProductCst}));
|
||||
|
||||
rewriter.replaceOp(epOp, outputBuffer);
|
||||
|
||||
@@ -223,6 +283,88 @@ struct EncodePlaintextWithCrtOpPattern
|
||||
}
|
||||
};
|
||||
|
||||
struct WopPBSGLWEOpPattern
|
||||
: public mlir::OpConversionPattern<TFHE::WopPBSGLWEOp> {
|
||||
|
||||
WopPBSGLWEOpPattern(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &typeConverter)
|
||||
: mlir::OpConversionPattern<TFHE::WopPBSGLWEOp>(
|
||||
typeConverter, context,
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(TFHE::WopPBSGLWEOp wopPbs,
|
||||
TFHE::WopPBSGLWEOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
const std::string funcName = "sim_wop_pbs_crt";
|
||||
|
||||
auto resultType = wopPbs.getType().cast<mlir::RankedTensorType>();
|
||||
auto inputType =
|
||||
wopPbs.getCiphertexts().getType().cast<mlir::RankedTensorType>();
|
||||
|
||||
mlir::Value outputBuffer =
|
||||
rewriter.create<mlir::bufferization::AllocTensorOp>(
|
||||
wopPbs.getLoc(),
|
||||
this->getTypeConverter()
|
||||
->convertType(resultType)
|
||||
.cast<mlir::RankedTensorType>(),
|
||||
mlir::ValueRange{});
|
||||
|
||||
auto lweDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getPksk().getInputLweDim(), 32);
|
||||
auto cbsLevelCountCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getCbsLevels(), 32);
|
||||
auto cbsBaseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getCbsBaseLog(), 32);
|
||||
auto kskLevelCountCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getKsk().getLevels(), 32);
|
||||
auto kskBaseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getKsk().getBaseLog(), 32);
|
||||
auto bskLevelCountCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getBsk().getLevels(), 32);
|
||||
auto bskBaseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getBsk().getBaseLog(), 32);
|
||||
auto fpkskLevelCountCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getPksk().getLevels(), 32);
|
||||
auto fpkskBaseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getPksk().getBaseLog(), 32);
|
||||
auto polySizeCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getPksk().getOutputPolySize(), 32);
|
||||
|
||||
auto crtDecompValue = mlir::concretelang::globalMemrefFromArrayAttr(
|
||||
rewriter, wopPbs.getLoc(), wopPbs.getCrtDecompositionAttr());
|
||||
|
||||
if (insertForwardDeclaration(
|
||||
wopPbs, rewriter, funcName,
|
||||
rewriter.getFunctionType(
|
||||
{this->getTypeConverter()->convertType(resultType),
|
||||
this->getTypeConverter()->convertType(inputType),
|
||||
wopPbs.getLookupTable().getType(), crtDecompValue.getType(),
|
||||
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
|
||||
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
|
||||
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
|
||||
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
|
||||
rewriter.getIntegerType(32), rewriter.getIntegerType(32)},
|
||||
{}))
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
rewriter.create<mlir::func::CallOp>(
|
||||
wopPbs.getLoc(), funcName, mlir::TypeRange{},
|
||||
mlir::ValueRange({outputBuffer, adaptor.getCiphertexts(),
|
||||
adaptor.getLookupTable(), crtDecompValue, lweDimCst,
|
||||
cbsLevelCountCst, cbsBaseLogCst, kskLevelCountCst,
|
||||
kskBaseLogCst, bskLevelCountCst, bskBaseLogCst,
|
||||
fpkskLevelCountCst, fpkskBaseLogCst, polySizeCst}));
|
||||
|
||||
rewriter.replaceOp(wopPbs, outputBuffer);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
struct BootstrapGLWEOpPattern
|
||||
: public mlir::OpConversionPattern<TFHE::BootstrapGLWEOp> {
|
||||
|
||||
@@ -399,8 +541,8 @@ void SimulateTFHEPass::runOnOperation() {
|
||||
SimulateTFHETypeConverter converter;
|
||||
|
||||
target.addLegalDialect<mlir::arith::ArithDialect>();
|
||||
target.addLegalOp<mlir::func::CallOp, mlir::bufferization::AllocTensorOp,
|
||||
mlir::tensor::CastOp>();
|
||||
target.addLegalOp<mlir::func::CallOp, mlir::memref::GetGlobalOp,
|
||||
mlir::bufferization::AllocTensorOp, mlir::tensor::CastOp>();
|
||||
// Make sure that no ops from `TFHE` remain after the lowering
|
||||
target.addIllegalDialect<TFHE::TFHEDialect>();
|
||||
|
||||
@@ -465,7 +607,9 @@ void SimulateTFHEPass::runOnOperation() {
|
||||
});
|
||||
|
||||
patterns.insert<ZeroOpPattern, ZeroTensorOpPattern, KeySwitchGLWEOpPattern,
|
||||
BootstrapGLWEOpPattern, EncodeExpandLutForBootstrapOpPattern,
|
||||
BootstrapGLWEOpPattern, WopPBSGLWEOpPattern,
|
||||
EncodeExpandLutForBootstrapOpPattern,
|
||||
EncodeLutForCrtWopPBSOpPattern,
|
||||
EncodePlaintextWithCrtOpPattern, NegOpPattern>(&getContext(),
|
||||
converter);
|
||||
patterns.insert<SubIntGLWEOpPattern>(&getContext());
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Conversion/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
|
||||
size_t rank) {
|
||||
std::vector<int64_t> shape(rank, mlir::ShapedType::kDynamic);
|
||||
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
|
||||
for (size_t i = 0; i < rank; i++) {
|
||||
expr = expr +
|
||||
(rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1));
|
||||
}
|
||||
return mlir::MemRefType::get(
|
||||
shape, rewriter.getI64Type(),
|
||||
mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext()));
|
||||
}
|
||||
|
||||
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
|
||||
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value) {
|
||||
mlir::Type valueType = value.getType();
|
||||
|
||||
if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
|
||||
return rewriter.create<mlir::memref::CastOp>(
|
||||
value.getLoc(),
|
||||
getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
|
||||
value);
|
||||
} else {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
mlir::Value globalMemrefFromArrayAttr(mlir::RewriterBase &rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::ArrayAttr arrAttr) {
|
||||
mlir::Type type =
|
||||
mlir::RankedTensorType::get({(int)arrAttr.size()}, rewriter.getI64Type());
|
||||
std::vector<int64_t> values;
|
||||
for (auto a : arrAttr) {
|
||||
values.push_back(a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
|
||||
}
|
||||
auto denseAttr = rewriter.getI64TensorAttr(values);
|
||||
auto cstOp = rewriter.create<mlir::arith::ConstantOp>(loc, denseAttr, type);
|
||||
auto globalMemref = mlir::bufferization::getGlobalFor(cstOp, 0);
|
||||
rewriter.eraseOp(cstOp);
|
||||
assert(!mlir::failed(globalMemref));
|
||||
auto globalRef = rewriter.create<mlir::memref::GetGlobalOp>(
|
||||
loc, (*globalMemref).getType(), (*globalMemref).getName());
|
||||
return mlir::concretelang::getCastedMemRef(rewriter, globalRef);
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -93,6 +93,29 @@ uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t *tlu_allocated,
|
||||
return out + gaussian_noise(0, variance);
|
||||
}
|
||||
|
||||
void sim_wop_pbs_crt(
|
||||
// Output 1D memref
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride,
|
||||
// Input 1D memref
|
||||
uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset,
|
||||
uint64_t in_size, uint64_t in_stride,
|
||||
// clear text lut 2D memref
|
||||
uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned,
|
||||
uint64_t lut_ct_offset, uint64_t lut_ct_size0, uint64_t lut_ct_size1,
|
||||
uint64_t lut_ct_stride0, uint64_t lut_ct_stride1,
|
||||
// CRT decomposition 1D memref
|
||||
uint64_t *crt_decomp_allocated, uint64_t *crt_decomp_aligned,
|
||||
uint64_t crt_decomp_offset, uint64_t crt_decomp_size,
|
||||
uint64_t crt_decomp_stride,
|
||||
// Additional crypto parameters
|
||||
uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log,
|
||||
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
|
||||
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
|
||||
uint32_t polynomial_size) {
|
||||
// TODO
|
||||
}
|
||||
|
||||
uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; }
|
||||
|
||||
void sim_encode_expand_lut_for_boostrap(
|
||||
@@ -120,3 +143,34 @@ void sim_encode_plaintext_with_crt(uint64_t *output_allocated,
|
||||
output_stride, input, mods_allocated, mods_aligned, mods_offset,
|
||||
mods_size, mods_stride, mods_product);
|
||||
}
|
||||
|
||||
void sim_encode_lut_for_crt_woppbs(
|
||||
// Output encoded/expanded lut
|
||||
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
|
||||
uint64_t output_lut_offset, uint64_t output_lut_size0,
|
||||
uint64_t output_lut_size1, uint64_t output_lut_stride0,
|
||||
uint64_t output_lut_stride1,
|
||||
// Input lut
|
||||
uint64_t *input_lut_allocated, uint64_t *input_lut_aligned,
|
||||
uint64_t input_lut_offset, uint64_t input_lut_size,
|
||||
uint64_t input_lut_stride,
|
||||
// Crt coprimes
|
||||
uint64_t *crt_decomposition_allocated, uint64_t *crt_decomposition_aligned,
|
||||
uint64_t crt_decomposition_offset, uint64_t crt_decomposition_size,
|
||||
uint64_t crt_decomposition_stride,
|
||||
// Crt number of bits
|
||||
uint64_t *crt_bits_allocated, uint64_t *crt_bits_aligned,
|
||||
uint64_t crt_bits_offset, uint64_t crt_bits_size, uint64_t crt_bits_stride,
|
||||
// Crypto parameters
|
||||
uint32_t modulus_product, bool is_signed) {
|
||||
return memref_encode_lut_for_crt_woppbs(
|
||||
output_lut_allocated, output_lut_aligned, output_lut_offset,
|
||||
output_lut_size0, output_lut_size1, output_lut_stride0,
|
||||
output_lut_stride1, input_lut_allocated, input_lut_aligned,
|
||||
input_lut_offset, input_lut_size, input_lut_stride,
|
||||
crt_decomposition_allocated, crt_decomposition_aligned,
|
||||
crt_decomposition_offset, crt_decomposition_size,
|
||||
crt_decomposition_stride, crt_bits_allocated, crt_bits_aligned,
|
||||
crt_bits_offset, crt_bits_size, crt_bits_stride, modulus_product,
|
||||
is_signed);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user