Mirror of https://github.com/zama-ai/concrete.git
feat(compiler): add SDFG op generation for batched operations.
@@ -89,10 +89,22 @@ def ProcessKindMulEintInt : I32EnumAttrCase<"mul_eint_int", 2>;
 def ProcessKindNegEint : I32EnumAttrCase<"neg_eint", 3>;
 def ProcessKindKeyswitch : I32EnumAttrCase<"keyswitch", 4>;
 def ProcessKindBootstrap : I32EnumAttrCase<"bootstrap", 5>;
+def ProcessKindBatchAddEint : I32EnumAttrCase<"batched_add_eint", 6>;
+def ProcessKindBatchAddEintInt : I32EnumAttrCase<"batched_add_eint_int", 7>;
+def ProcessKindBatchAddEintIntCst : I32EnumAttrCase<"batched_add_eint_int_cst", 8>;
+def ProcessKindBatchMulEintInt : I32EnumAttrCase<"batched_mul_eint_int", 9>;
+def ProcessKindBatchMulEintIntCst : I32EnumAttrCase<"batched_mul_eint_int_cst", 10>;
+def ProcessKindBatchNegEint : I32EnumAttrCase<"batched_neg_eint", 11>;
+def ProcessKindBatchKeyswitch : I32EnumAttrCase<"batched_keyswitch", 12>;
+def ProcessKindBatchBootstrap : I32EnumAttrCase<"batched_bootstrap", 13>;

 def ProcessKind : I32EnumAttr<"ProcessKind", "Process kind",
   [ProcessKindAddEint, ProcessKindAddEintInt, ProcessKindMulEintInt,
-   ProcessKindNegEint, ProcessKindKeyswitch, ProcessKindBootstrap]> {
+   ProcessKindNegEint, ProcessKindKeyswitch, ProcessKindBootstrap,
+   ProcessKindBatchAddEint, ProcessKindBatchAddEintInt,
+   ProcessKindBatchAddEintIntCst, ProcessKindBatchMulEintInt,
+   ProcessKindBatchMulEintIntCst, ProcessKindBatchNegEint,
+   ProcessKindBatchKeyswitch, ProcessKindBatchBootstrap]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::concretelang::SDFG";
 }
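For orientation, these I32EnumAttrCase entries are what mlir-tblgen turns into a C++ ProcessKind enum in the ::mlir::concretelang::SDFG namespace. A rough sketch of the generated declaration (the stringify/parse helpers the enum generator also emits are omitted, and the exact shape depends on the MLIR version):

#include <cstdint>

namespace mlir {
namespace concretelang {
namespace SDFG {
// Sketch of the tblgen output for the I32EnumAttr above.
enum class ProcessKind : uint32_t {
  add_eint = 0,
  add_eint_int = 1,
  mul_eint_int = 2,
  neg_eint = 3,
  keyswitch = 4,
  bootstrap = 5,
  batched_add_eint = 6,
  batched_add_eint_int = 7,
  batched_add_eint_int_cst = 8,
  batched_mul_eint_int = 9,
  batched_mul_eint_int_cst = 10,
  batched_neg_eint = 11,
  batched_keyswitch = 12,
  batched_bootstrap = 13,
};
} // namespace SDFG
} // namespace concretelang
} // namespace mlir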
@@ -42,6 +42,27 @@ void stream_emulator_make_memref_bootstrap_lwe_u64_process(
     uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
     uint32_t output_size, uint32_t bsk_index, void *context);

+void stream_emulator_make_memref_batched_add_lwe_ciphertexts_u64_process(
+    void *dfg, void *sin1, void *sin2, void *sout);
+void stream_emulator_make_memref_batched_add_plaintext_lwe_ciphertext_u64_process(
+    void *dfg, void *sin1, void *sin2, void *sout);
+void stream_emulator_make_memref_batched_add_plaintext_cst_lwe_ciphertext_u64_process(
+    void *dfg, void *sin1, void *sin2, void *sout);
+void stream_emulator_make_memref_batched_mul_cleartext_lwe_ciphertext_u64_process(
+    void *dfg, void *sin1, void *sin2, void *sout);
+void stream_emulator_make_memref_batched_mul_cleartext_cst_lwe_ciphertext_u64_process(
+    void *dfg, void *sin1, void *sin2, void *sout);
+void stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process(
+    void *dfg, void *sin1, void *sout);
+void stream_emulator_make_memref_batched_keyswitch_lwe_u64_process(
+    void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
+    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
+    void *context);
+void stream_emulator_make_memref_batched_bootstrap_lwe_u64_process(
+    void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
+    uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
+    uint32_t output_size, void *context);
+
 void *stream_emulator_make_uint64_stream(const char *name, stream_type stype);
 void stream_emulator_put_uint64(void *stream, uint64_t e);
 uint64_t stream_emulator_get_uint64(void *stream);
@@ -53,6 +74,18 @@ void stream_emulator_put_memref(void *stream, uint64_t *allocated,
 void stream_emulator_get_memref(void *stream, uint64_t *out_allocated,
                                 uint64_t *out_aligned, uint64_t out_offset,
                                 uint64_t out_size, uint64_t out_stride);
+
+void *stream_emulator_make_memref_batch_stream(const char *name,
+                                               stream_type stype);
+void stream_emulator_put_memref_batch(void *stream, uint64_t *allocated,
+                                      uint64_t *aligned, uint64_t offset,
+                                      uint64_t size0, uint64_t size1,
+                                      uint64_t stride0, uint64_t stride1);
+void stream_emulator_get_memref_batch(void *stream, uint64_t *out_allocated,
+                                      uint64_t *out_aligned,
+                                      uint64_t out_offset, uint64_t out_size0,
+                                      uint64_t out_size1, uint64_t out_stride0,
+                                      uint64_t out_stride1);
 }

 #endif
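To make the intent of the new declarations concrete, here is a hedged usage sketch, not part of the patch: it calls only entry points declared above, and assumes the dfg handle and the stream_type values come from the existing stream emulator setup code that this excerpt does not show.

// Wire up one batched-addition node: two rank-2 ("batch") ciphertext streams
// in, one out. The stream names are arbitrary labels chosen for the example.
void build_batched_add(void *dfg, stream_type in_kind, stream_type out_kind) {
  void *lhs = stream_emulator_make_memref_batch_stream("lhs", in_kind);
  void *rhs = stream_emulator_make_memref_batch_stream("rhs", in_kind);
  void *out = stream_emulator_make_memref_batch_stream("out", out_kind);

  // One process consuming two batched ciphertext streams, producing one.
  stream_emulator_make_memref_batched_add_lwe_ciphertexts_u64_process(
      dfg, lhs, rhs, out);
}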
@@ -38,6 +38,31 @@ char stream_emulator_make_memref_keyswitch_lwe_u64_process[] =
 char stream_emulator_make_memref_bootstrap_lwe_u64_process[] =
     "stream_emulator_make_memref_bootstrap_lwe_u64_process";

+char stream_emulator_make_memref_batched_add_lwe_ciphertexts_u64_process[] =
+    "stream_emulator_make_memref_batched_add_lwe_ciphertexts_u64_process";
+char
+    stream_emulator_make_memref_batched_add_plaintext_lwe_ciphertext_u64_process
+        [] = "stream_emulator_make_memref_batched_add_plaintext_lwe_ciphertext_"
+             "u64_process";
+char
+    stream_emulator_make_memref_batched_add_plaintext_cst_lwe_ciphertext_u64_process
+        [] = "stream_emulator_make_memref_batched_add_plaintext_cst_lwe_"
+             "ciphertext_u64_process";
+char
+    stream_emulator_make_memref_batched_mul_cleartext_lwe_ciphertext_u64_process
+        [] = "stream_emulator_make_memref_batched_mul_cleartext_lwe_ciphertext_"
+             "u64_process";
+char
+    stream_emulator_make_memref_batched_mul_cleartext_cst_lwe_ciphertext_u64_process
+        [] = "stream_emulator_make_memref_batched_mul_cleartext_cst_lwe_"
+             "ciphertext_u64_process";
+char stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process[] =
+    "stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process";
+char stream_emulator_make_memref_batched_keyswitch_lwe_u64_process[] =
+    "stream_emulator_make_memref_batched_keyswitch_lwe_u64_process";
+char stream_emulator_make_memref_batched_bootstrap_lwe_u64_process[] =
+    "stream_emulator_make_memref_batched_bootstrap_lwe_u64_process";
+
 char stream_emulator_make_memref_stream[] =
     "stream_emulator_make_memref_stream";
 char stream_emulator_put_memref[] = "stream_emulator_put_memref";
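These symbol names are plain namespace-scope char arrays, rather than const char* constants, because the lowering patterns further down take the name as a char const * non-type template parameter, which must refer to an object with linkage. A minimal standalone illustration of that C++ idiom, with hypothetical names:

char demo_process_symbol[] = "demo_process_symbol";

// The array object itself is the template argument; the pattern can then emit
// a forward declaration and a call to the symbol whose name the array holds.
template <char const *funcName> struct EmitsCallTo {
  const char *name() const { return funcName; }
};

using DemoLowering = EmitsCallTo<demo_process_symbol>;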
@@ -46,6 +71,10 @@ char stream_emulator_make_uint64_stream[] =
 char stream_emulator_put_uint64[] = "stream_emulator_put_uint64";
 char stream_emulator_get_uint64[] = "stream_emulator_get_uint64";
+
+char stream_emulator_make_memref_batch_stream[] =
+    "stream_emulator_make_memref_batch_stream";
+char stream_emulator_put_memref_batch[] = "stream_emulator_put_memref_batch";

mlir::Type getDynamicTensor(mlir::OpBuilder &rewriter, size_t rank) {
  std::vector<int64_t> shape(rank, mlir::ShapedType::kDynamic);
  return mlir::RankedTensorType::get(shape, rewriter.getI64Type());
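getDynamicTensor builds a fully dynamic i64 tensor type of the requested rank. Reading it against the rest of the patch (its call sites are outside this excerpt), the two relevant instances would be:

// Hypothetical call sites, not shown in the diff:
mlir::Type ctStreamElt    = getDynamicTensor(rewriter, 1); // tensor<?xi64>
mlir::Type batchStreamElt = getDynamicTensor(rewriter, 2); // tensor<?x?xi64>,
                                                           // presumably [batch, lwe_size]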
@@ -164,7 +193,7 @@ struct LowerSDFGMakeProcess
   ::mlir::LogicalResult
   matchAndRewrite(mlir::concretelang::SDFG::MakeProcess mpOp,
                   ::mlir::PatternRewriter &rewriter) const override {
-    const char *funcName;
+    const char *funcName = nullptr;
     mlir::SmallVector<mlir::Value> operands(mpOp->getOperands());
     switch (mpOp.getType()) {
     case SDFG::ProcessKind::add_eint:
@@ -181,8 +210,12 @@ struct LowerSDFGMakeProcess
     case SDFG::ProcessKind::neg_eint:
       funcName = stream_emulator_make_memref_negate_lwe_ciphertext_u64_process;
       break;
+    case SDFG::ProcessKind::batched_keyswitch:
+      funcName = stream_emulator_make_memref_batched_keyswitch_lwe_u64_process;
+      [[fallthrough]];
     case SDFG::ProcessKind::keyswitch:
-      funcName = stream_emulator_make_memref_keyswitch_lwe_u64_process;
+      if (funcName == nullptr)
+        funcName = stream_emulator_make_memref_keyswitch_lwe_u64_process;
       // level
       operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
           mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("level")));
@@ -206,8 +239,12 @@ struct LowerSDFGMakeProcess
       // context
       operands.push_back(getContextArgument(mpOp));
       break;
+    case SDFG::ProcessKind::batched_bootstrap:
+      funcName = stream_emulator_make_memref_batched_bootstrap_lwe_u64_process;
+      [[fallthrough]];
     case SDFG::ProcessKind::bootstrap:
-      funcName = stream_emulator_make_memref_bootstrap_lwe_u64_process;
+      if (funcName == nullptr)
+        funcName = stream_emulator_make_memref_bootstrap_lwe_u64_process;
       // input_lwe_dim
       operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
           mpOp.getLoc(),
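The batched keyswitch and bootstrap cases deliberately reuse the operand setup of their non-batched counterparts: the batched case sets funcName and falls through, and the shared case only assigns funcName when it is still nullptr (which is why the declaration now initializes it to nullptr). Reduced to a self-contained skeleton:

const char *selectKeyswitchSymbol(SDFG::ProcessKind kind) {
  const char *funcName = nullptr;
  switch (kind) {
  case SDFG::ProcessKind::batched_keyswitch:
    funcName = stream_emulator_make_memref_batched_keyswitch_lwe_u64_process;
    [[fallthrough]]; // share the operand setup with the non-batched case
  case SDFG::ProcessKind::keyswitch:
    if (funcName == nullptr) // only when entered directly, not via fallthrough
      funcName = stream_emulator_make_memref_keyswitch_lwe_u64_process;
    // ... in the real pattern: push level, base_log, dimensions and the
    // context here; this part is identical for both variants ...
    break;
  default:
    break;
  }
  return funcName;
}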
@@ -235,6 +272,30 @@ struct LowerSDFGMakeProcess
       // context
       operands.push_back(getContextArgument(mpOp));
       break;
+    case SDFG::ProcessKind::batched_add_eint:
+      funcName =
+          stream_emulator_make_memref_batched_add_lwe_ciphertexts_u64_process;
+      break;
+    case SDFG::ProcessKind::batched_add_eint_int:
+      funcName =
+          stream_emulator_make_memref_batched_add_plaintext_lwe_ciphertext_u64_process;
+      break;
+    case SDFG::ProcessKind::batched_add_eint_int_cst:
+      funcName =
+          stream_emulator_make_memref_batched_add_plaintext_cst_lwe_ciphertext_u64_process;
+      break;
+    case SDFG::ProcessKind::batched_mul_eint_int:
+      funcName =
+          stream_emulator_make_memref_batched_mul_cleartext_lwe_ciphertext_u64_process;
+      break;
+    case SDFG::ProcessKind::batched_mul_eint_int_cst:
+      funcName =
+          stream_emulator_make_memref_batched_mul_cleartext_cst_lwe_ciphertext_u64_process;
+      break;
+    case SDFG::ProcessKind::batched_neg_eint:
+      funcName =
+          stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process;
+      break;
     }
     if (insertGenericForwardDeclaration(mpOp, rewriter, funcName,
                                         mlir::ValueRange{operands}.getTypes(),
@@ -274,7 +335,13 @@ struct LowerSDFGMakeStream
     assert(sType && "SDFG MakeStream operation should return a stream type");

     if (sType.getElementType().isa<mlir::RankedTensorType>()) {
-      funcName = stream_emulator_make_memref_stream;
+      if (sType.getElementType().dyn_cast<mlir::TensorType>().getRank() == 1)
+        funcName = stream_emulator_make_memref_stream;
+      else if (sType.getElementType().dyn_cast<mlir::TensorType>().getRank() ==
+               2)
+        funcName = stream_emulator_make_memref_batch_stream;
+      else
+        return ::mlir::failure();
     } else {
       assert(sType.getElementType().isa<mlir::IntegerType>() &&
              "SDFG streams only support memrefs and integers.");
@@ -314,7 +381,13 @@ struct LowerSDFGPut
     assert(sType &&
            "SDFG Put operation must take a stream type as first parameter.");
     if (sType.getElementType().isa<mlir::RankedTensorType>()) {
-      funcName = stream_emulator_put_memref;
+      if (sType.getElementType().dyn_cast<mlir::TensorType>().getRank() == 1)
+        funcName = stream_emulator_put_memref;
+      else if (sType.getElementType().dyn_cast<mlir::TensorType>().getRank() ==
+               2)
+        funcName = stream_emulator_put_memref_batch;
+      else
+        return ::mlir::failure();
     } else {
       assert(sType.getElementType().isa<mlir::IntegerType>() &&
              "SDFG streams only support memrefs and integers.");
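Both LowerSDFGMakeStream and LowerSDFGPut now pick the runtime symbol from the rank of the stream's tensor element type: rank 1 keeps the existing single-ciphertext entry points, rank 2 routes to the new batch entry points, and anything else fails the rewrite. A hedged sketch of that rule as a standalone helper (hypothetical, not part of the patch):

static mlir::FailureOr<const char *> selectPutSymbol(mlir::Type elementType) {
  if (auto tensorTy = elementType.dyn_cast<mlir::TensorType>()) {
    if (tensorTy.getRank() == 1)
      return stream_emulator_put_memref;       // one ciphertext per element
    if (tensorTy.getRank() == 2)
      return stream_emulator_put_memref_batch; // a whole batch per element
    return mlir::failure();                    // other ranks are rejected
  }
  // Integer element types presumably use the plain uint64 entry point
  // declared in the stream emulator header above.
  return stream_emulator_put_uint64;
}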
@@ -82,6 +82,22 @@ mlir::LogicalResult MakeProcess::verify() {
     return checkStreams(1, 1);
   case ProcessKind::bootstrap:
     return checkStreams(2, 1);
+  case ProcessKind::batched_add_eint:
+    return checkStreams(2, 1);
+  case ProcessKind::batched_add_eint_int:
+    return checkStreams(2, 1);
+  case ProcessKind::batched_add_eint_int_cst:
+    return checkStreams(2, 1);
+  case ProcessKind::batched_mul_eint_int:
+    return checkStreams(2, 1);
+  case ProcessKind::batched_mul_eint_int_cst:
+    return checkStreams(2, 1);
+  case ProcessKind::batched_neg_eint:
+    return checkStreams(1, 1);
+  case ProcessKind::batched_keyswitch:
+    return checkStreams(1, 1);
+  case ProcessKind::batched_bootstrap:
+    return checkStreams(2, 1);
   }

   return mlir::failure();
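checkStreams is defined earlier in the same file and is not shown in this hunk; the hedged assumption is that checkStreams(nIn, nOut) verifies that the process op consumes nIn input streams and produces nOut output streams. Under that reading, the batched cases above amount to the following table, written out as a hypothetical standalone helper:

#include <utility>

// (inputs, outputs) expected for each batched ProcessKind, mirroring the
// verify() cases above.
std::pair<unsigned, unsigned> expectedBatchedStreams(ProcessKind kind) {
  switch (kind) {
  case ProcessKind::batched_neg_eint:  // unary op
  case ProcessKind::batched_keyswitch: // single ciphertext stream (see header)
    return {1, 1};
  case ProcessKind::batched_add_eint:
  case ProcessKind::batched_add_eint_int:
  case ProcessKind::batched_add_eint_int_cst:
  case ProcessKind::batched_mul_eint_int:
  case ProcessKind::batched_mul_eint_int_cst:
  case ProcessKind::batched_bootstrap:
    return {2, 1}; // ciphertext stream plus a second operand stream
  default:
    return {0, 0}; // non-batched kinds: covered by the pre-existing cases
  }
}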
@@ -69,11 +69,12 @@ mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Location loc,
 }

 char stream_emulator_get_memref[] = "stream_emulator_get_memref";
+char stream_emulator_get_memref_batch[] = "stream_emulator_get_memref_batch";

-template <typename Op, char const *funcName>
+template <typename Op, char const *funcName, char const *funcName_batch>
 struct BufferizableWithCallOpInterface
     : public BufferizableOpInterface::ExternalModel<
-          BufferizableWithCallOpInterface<Op, funcName>, Op> {
+          BufferizableWithCallOpInterface<Op, funcName, funcName_batch>, Op> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
     return true;
@@ -103,6 +104,13 @@ struct BufferizableWithCallOpInterface
     // able to avoid the copy depending on the stream semantics.
     auto resTensorType =
         op->getResultTypes()[0].template cast<mlir::TensorType>();
+    char const *fname;
+    if (resTensorType.getRank() == 1)
+      fname = funcName;
+    else if (resTensorType.getRank() == 2)
+      fname = funcName_batch;
+    else
+      return mlir::failure();
     auto outMemrefType = MemRefType::get(resTensorType.getShape(),
                                          resTensorType.getElementType());
     auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {});
@@ -117,9 +125,9 @@ struct BufferizableWithCallOpInterface
     mlir::FunctionType funcType = mlir::FunctionType::get(
         rewriter.getContext(), mlir::ValueRange{operands}.getTypes(),
         mlir::TypeRange());
-    if (insertForwardDeclaration(op, rewriter, funcName, funcType).failed())
+    if (insertForwardDeclaration(op, rewriter, fname, funcType).failed())
       return ::mlir::failure();
-    rewriter.create<mlir::func::CallOp>(loc, funcName, mlir::TypeRange{},
+    rewriter.create<mlir::func::CallOp>(loc, fname, mlir::TypeRange{},
                                         operands);
     replaceOpWithBufferizedValues(rewriter, op, *outMemref);
@@ -133,7 +141,8 @@ void mlir::concretelang::SDFG::registerBufferizableOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, SDFG::SDFGDialect *dialect) {
     SDFG::Get::attachInterface<
-        BufferizableWithCallOpInterface<SDFG::Get, stream_emulator_get_memref>>(
+        BufferizableWithCallOpInterface<SDFG::Get, stream_emulator_get_memref,
+                                        stream_emulator_get_memref_batch>>(
         *ctx);
   });
 }
@@ -20,6 +20,15 @@ char mul_eint_int[] = "mul_eint_int";
 char neg_eint[] = "neg_eint";
 char keyswitch[] = "keyswitch";
 char bootstrap[] = "bootstrap";
+
+char batched_add_eint[] = "batched_add_eint";
+char batched_add_eint_int[] = "batched_add_eint_int";
+char batched_add_eint_int_cst[] = "batched_add_eint_int_cst";
+char batched_mul_eint_int[] = "batched_mul_eint_int";
+char batched_mul_eint_int_cst[] = "batched_mul_eint_int_cst";
+char batched_neg_eint[] = "batched_neg_eint";
+char batched_keyswitch[] = "batched_keyswitch";
+char batched_bootstrap[] = "batched_bootstrap";
 } // namespace

 template <typename Op, char const *processName, bool copyAttributes = false>
@@ -87,6 +96,39 @@ void registerSDFGConvertibleOpInterfaceExternalModels(
         ReplaceWithProcessSDFGConversionInterface<
             mlir::concretelang::Concrete::BootstrapLweTensorOp, bootstrap,
             true>>(*ctx);
+
+    mlir::concretelang::Concrete::BatchedAddLweTensorOp::attachInterface<
+        ReplaceWithProcessSDFGConversionInterface<
+            mlir::concretelang::Concrete::BatchedAddLweTensorOp,
+            batched_add_eint>>(*ctx);
+    mlir::concretelang::Concrete::BatchedAddPlaintextLweTensorOp::
+        attachInterface<ReplaceWithProcessSDFGConversionInterface<
+            mlir::concretelang::Concrete::BatchedAddPlaintextLweTensorOp,
+            batched_add_eint_int>>(*ctx);
+    mlir::concretelang::Concrete::BatchedAddPlaintextCstLweTensorOp::
+        attachInterface<ReplaceWithProcessSDFGConversionInterface<
+            mlir::concretelang::Concrete::BatchedAddPlaintextCstLweTensorOp,
+            batched_add_eint_int_cst>>(*ctx);
+    mlir::concretelang::Concrete::BatchedMulCleartextLweTensorOp::
+        attachInterface<ReplaceWithProcessSDFGConversionInterface<
+            mlir::concretelang::Concrete::BatchedMulCleartextLweTensorOp,
+            batched_mul_eint_int>>(*ctx);
+    mlir::concretelang::Concrete::BatchedMulCleartextCstLweTensorOp::
+        attachInterface<ReplaceWithProcessSDFGConversionInterface<
+            mlir::concretelang::Concrete::BatchedMulCleartextCstLweTensorOp,
+            batched_mul_eint_int_cst>>(*ctx);
+    mlir::concretelang::Concrete::BatchedNegateLweTensorOp::attachInterface<
+        ReplaceWithProcessSDFGConversionInterface<
+            mlir::concretelang::Concrete::BatchedNegateLweTensorOp,
+            batched_neg_eint>>(*ctx);
+    mlir::concretelang::Concrete::BatchedKeySwitchLweTensorOp::attachInterface<
+        ReplaceWithProcessSDFGConversionInterface<
+            mlir::concretelang::Concrete::BatchedKeySwitchLweTensorOp,
+            batched_keyswitch, true>>(*ctx);
+    mlir::concretelang::Concrete::BatchedBootstrapLweTensorOp::attachInterface<
+        ReplaceWithProcessSDFGConversionInterface<
+            mlir::concretelang::Concrete::BatchedBootstrapLweTensorOp,
+            batched_bootstrap, true>>(*ctx);
   });
 }
 } // namespace SDFG