feat(compiler): add SDFG op generation for batched operations.

Author: Antoniu Pop
Date: 2023-03-24 10:44:56 +00:00
Committed by: Antoniu Pop
Parent: 3f230957cb
Commit: 60412f7f61
6 changed files with 196 additions and 11 deletions


@@ -89,10 +89,22 @@ def ProcessKindMulEintInt : I32EnumAttrCase<"mul_eint_int", 2>;
def ProcessKindNegEint : I32EnumAttrCase<"neg_eint", 3>;
def ProcessKindKeyswitch : I32EnumAttrCase<"keyswitch", 4>;
def ProcessKindBootstrap : I32EnumAttrCase<"bootstrap", 5>;
def ProcessKindBatchAddEint : I32EnumAttrCase<"batched_add_eint", 6>;
def ProcessKindBatchAddEintInt : I32EnumAttrCase<"batched_add_eint_int", 7>;
def ProcessKindBatchAddEintIntCst : I32EnumAttrCase<"batched_add_eint_int_cst", 8>;
def ProcessKindBatchMulEintInt : I32EnumAttrCase<"batched_mul_eint_int", 9>;
def ProcessKindBatchMulEintIntCst : I32EnumAttrCase<"batched_mul_eint_int_cst", 10>;
def ProcessKindBatchNegEint : I32EnumAttrCase<"batched_neg_eint", 11>;
def ProcessKindBatchKeyswitch : I32EnumAttrCase<"batched_keyswitch", 12>;
def ProcessKindBatchBootstrap : I32EnumAttrCase<"batched_bootstrap", 13>;
def ProcessKind : I32EnumAttr<"ProcessKind", "Process kind",
[ProcessKindAddEint, ProcessKindAddEintInt, ProcessKindMulEintInt,
ProcessKindNegEint, ProcessKindKeyswitch, ProcessKindBootstrap]> {
ProcessKindNegEint, ProcessKindKeyswitch, ProcessKindBootstrap,
ProcessKindBatchAddEint, ProcessKindBatchAddEintInt,
ProcessKindBatchAddEintIntCst, ProcessKindBatchMulEintInt,
ProcessKindBatchMulEintIntCst, ProcessKindBatchNegEint,
ProcessKindBatchKeyswitch, ProcessKindBatchBootstrap]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::concretelang::SDFG";
}
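
The eight new cases extend the ProcessKind enum that the SDFG lowering switches over; the TableGen symbols ("batched_add_eint", ..., "batched_bootstrap") become the C++ enumerator names under ::mlir::concretelang::SDFG. As a minimal sketch (the helper isBatchedKind below is hypothetical and not part of this commit), client code could distinguish the batched kinds like this:

// Hypothetical helper, for illustration only.
static bool isBatchedKind(mlir::concretelang::SDFG::ProcessKind kind) {
  using PK = mlir::concretelang::SDFG::ProcessKind;
  switch (kind) {
  case PK::batched_add_eint:
  case PK::batched_add_eint_int:
  case PK::batched_add_eint_int_cst:
  case PK::batched_mul_eint_int:
  case PK::batched_mul_eint_int_cst:
  case PK::batched_neg_eint:
  case PK::batched_keyswitch:
  case PK::batched_bootstrap:
    return true;
  default:
    return false;
  }
}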


@@ -42,6 +42,27 @@ void stream_emulator_make_memref_bootstrap_lwe_u64_process(
uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
uint32_t output_size, uint32_t bsk_index, void *context);
void stream_emulator_make_memref_batched_add_lwe_ciphertexts_u64_process(
void *dfg, void *sin1, void *sin2, void *sout);
void stream_emulator_make_memref_batched_add_plaintext_lwe_ciphertext_u64_process(
void *dfg, void *sin1, void *sin2, void *sout);
void stream_emulator_make_memref_batched_add_plaintext_cst_lwe_ciphertext_u64_process(
void *dfg, void *sin1, void *sin2, void *sout);
void stream_emulator_make_memref_batched_mul_cleartext_lwe_ciphertext_u64_process(
void *dfg, void *sin1, void *sin2, void *sout);
void stream_emulator_make_memref_batched_mul_cleartext_cst_lwe_ciphertext_u64_process(
void *dfg, void *sin1, void *sin2, void *sout);
void stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process(
void *dfg, void *sin1, void *sout);
void stream_emulator_make_memref_batched_keyswitch_lwe_u64_process(
void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
void *context);
void stream_emulator_make_memref_batched_bootstrap_lwe_u64_process(
void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
uint32_t output_size, void *context);
void *stream_emulator_make_uint64_stream(const char *name, stream_type stype);
void stream_emulator_put_uint64(void *stream, uint64_t e);
uint64_t stream_emulator_get_uint64(void *stream);
@@ -53,6 +74,18 @@ void stream_emulator_put_memref(void *stream, uint64_t *allocated,
void stream_emulator_get_memref(void *stream, uint64_t *out_allocated,
uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride);
void *stream_emulator_make_memref_batch_stream(const char *name,
stream_type stype);
void stream_emulator_put_memref_batch(void *stream, uint64_t *allocated,
uint64_t *aligned, uint64_t offset,
uint64_t size0, uint64_t size1,
uint64_t stride0, uint64_t stride1);
void stream_emulator_get_memref_batch(void *stream, uint64_t *out_allocated,
uint64_t *out_aligned,
uint64_t out_offset, uint64_t out_size0,
uint64_t out_size1, uint64_t out_stride0,
uint64_t out_stride1);
}
#endif
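
The batched stream entry points use the expanded rank-2 memref calling convention (allocated and aligned pointers, an offset, two sizes, two strides), so a batch is a 2-D buffer of ciphertexts. A minimal sketch of pushing a contiguous, row-major batch of n ciphertexts of lweSize words each (the pushBatch wrapper and the contiguity assumption are illustrative, not part of this commit):

#include <cstdint>
#include <vector>

// Illustrative wrapper: for a contiguous row-major batch, stride0 is the
// row length (lweSize) and stride1 is 1.
void pushBatch(void *stream, std::vector<uint64_t> &batch, uint64_t n,
               uint64_t lweSize) {
  stream_emulator_put_memref_batch(stream, batch.data(), batch.data(),
                                   /*offset=*/0, /*size0=*/n,
                                   /*size1=*/lweSize, /*stride0=*/lweSize,
                                   /*stride1=*/1);
}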


@@ -38,6 +38,31 @@ char stream_emulator_make_memref_keyswitch_lwe_u64_process[] =
char stream_emulator_make_memref_bootstrap_lwe_u64_process[] =
"stream_emulator_make_memref_bootstrap_lwe_u64_process";
char stream_emulator_make_memref_batched_add_lwe_ciphertexts_u64_process[] =
"stream_emulator_make_memref_batched_add_lwe_ciphertexts_u64_process";
char
stream_emulator_make_memref_batched_add_plaintext_lwe_ciphertext_u64_process
[] = "stream_emulator_make_memref_batched_add_plaintext_lwe_ciphertext_"
"u64_process";
char
stream_emulator_make_memref_batched_add_plaintext_cst_lwe_ciphertext_u64_process
[] = "stream_emulator_make_memref_batched_add_plaintext_cst_lwe_"
"ciphertext_u64_process";
char
stream_emulator_make_memref_batched_mul_cleartext_lwe_ciphertext_u64_process
[] = "stream_emulator_make_memref_batched_mul_cleartext_lwe_ciphertext_"
"u64_process";
char
stream_emulator_make_memref_batched_mul_cleartext_cst_lwe_ciphertext_u64_process
[] = "stream_emulator_make_memref_batched_mul_cleartext_cst_lwe_"
"ciphertext_u64_process";
char stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process[] =
"stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process";
char stream_emulator_make_memref_batched_keyswitch_lwe_u64_process[] =
"stream_emulator_make_memref_batched_keyswitch_lwe_u64_process";
char stream_emulator_make_memref_batched_bootstrap_lwe_u64_process[] =
"stream_emulator_make_memref_batched_bootstrap_lwe_u64_process";
char stream_emulator_make_memref_stream[] =
"stream_emulator_make_memref_stream";
char stream_emulator_put_memref[] = "stream_emulator_put_memref";
@@ -46,6 +71,10 @@ char stream_emulator_make_uint64_stream[] =
char stream_emulator_put_uint64[] = "stream_emulator_put_uint64";
char stream_emulator_get_uint64[] = "stream_emulator_get_uint64";
char stream_emulator_make_memref_batch_stream[] =
"stream_emulator_make_memref_batch_stream";
char stream_emulator_put_memref_batch[] = "stream_emulator_put_memref_batch";
mlir::Type getDynamicTensor(mlir::OpBuilder &rewriter, size_t rank) {
std::vector<int64_t> shape(rank, mlir::ShapedType::kDynamic);
return mlir::RankedTensorType::get(shape, rewriter.getI64Type());
@@ -164,7 +193,7 @@ struct LowerSDFGMakeProcess
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::SDFG::MakeProcess mpOp,
::mlir::PatternRewriter &rewriter) const override {
const char *funcName;
const char *funcName = nullptr;
mlir::SmallVector<mlir::Value> operands(mpOp->getOperands());
switch (mpOp.getType()) {
case SDFG::ProcessKind::add_eint:
@@ -181,8 +210,12 @@ struct LowerSDFGMakeProcess
case SDFG::ProcessKind::neg_eint:
funcName = stream_emulator_make_memref_negate_lwe_ciphertext_u64_process;
break;
case SDFG::ProcessKind::batched_keyswitch:
funcName = stream_emulator_make_memref_batched_keyswitch_lwe_u64_process;
[[fallthrough]];
case SDFG::ProcessKind::keyswitch:
funcName = stream_emulator_make_memref_keyswitch_lwe_u64_process;
if (funcName == nullptr)
funcName = stream_emulator_make_memref_keyswitch_lwe_u64_process;
// level
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("level")));
@@ -206,8 +239,12 @@ struct LowerSDFGMakeProcess
// context
operands.push_back(getContextArgument(mpOp));
break;
case SDFG::ProcessKind::batched_bootstrap:
funcName = stream_emulator_make_memref_batched_bootstrap_lwe_u64_process;
[[fallthrough]];
case SDFG::ProcessKind::bootstrap:
funcName = stream_emulator_make_memref_bootstrap_lwe_u64_process;
if (funcName == nullptr)
funcName = stream_emulator_make_memref_bootstrap_lwe_u64_process;
// input_lwe_dim
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(),
@@ -235,6 +272,30 @@ struct LowerSDFGMakeProcess
// context
operands.push_back(getContextArgument(mpOp));
break;
case SDFG::ProcessKind::batched_add_eint:
funcName =
stream_emulator_make_memref_batched_add_lwe_ciphertexts_u64_process;
break;
case SDFG::ProcessKind::batched_add_eint_int:
funcName =
stream_emulator_make_memref_batched_add_plaintext_lwe_ciphertext_u64_process;
break;
case SDFG::ProcessKind::batched_add_eint_int_cst:
funcName =
stream_emulator_make_memref_batched_add_plaintext_cst_lwe_ciphertext_u64_process;
break;
case SDFG::ProcessKind::batched_mul_eint_int:
funcName =
stream_emulator_make_memref_batched_mul_cleartext_lwe_ciphertext_u64_process;
break;
case SDFG::ProcessKind::batched_mul_eint_int_cst:
funcName =
stream_emulator_make_memref_batched_mul_cleartext_cst_lwe_ciphertext_u64_process;
break;
case SDFG::ProcessKind::batched_neg_eint:
funcName =
stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process;
break;
}
if (insertGenericForwardDeclaration(mpOp, rewriter, funcName,
mlir::ValueRange{operands}.getTypes(),
@@ -274,7 +335,13 @@ struct LowerSDFGMakeStream
assert(sType && "SDFG MakeStream operation should return a stream type");
if (sType.getElementType().isa<mlir::RankedTensorType>()) {
funcName = stream_emulator_make_memref_stream;
if (sType.getElementType().dyn_cast<mlir::TensorType>().getRank() == 1)
funcName = stream_emulator_make_memref_stream;
else if (sType.getElementType().dyn_cast<mlir::TensorType>().getRank() ==
2)
funcName = stream_emulator_make_memref_batch_stream;
else
return ::mlir::failure();
} else {
assert(sType.getElementType().isa<mlir::IntegerType>() &&
"SDFG streams only support memrefs and integers.");
@@ -314,7 +381,13 @@ struct LowerSDFGPut
assert(sType &&
"SDFG Put operation must take a stream type as first parameter.");
if (sType.getElementType().isa<mlir::RankedTensorType>()) {
funcName = stream_emulator_put_memref;
if (sType.getElementType().dyn_cast<mlir::TensorType>().getRank() == 1)
funcName = stream_emulator_put_memref;
else if (sType.getElementType().dyn_cast<mlir::TensorType>().getRank() ==
2)
funcName = stream_emulator_put_memref_batch;
else
return ::mlir::failure();
} else {
assert(sType.getElementType().isa<mlir::IntegerType>() &&
"SDFG streams only support memrefs and integers.");


@@ -82,6 +82,22 @@ mlir::LogicalResult MakeProcess::verify() {
return checkStreams(1, 1);
case ProcessKind::bootstrap:
return checkStreams(2, 1);
case ProcessKind::batched_add_eint:
return checkStreams(2, 1);
case ProcessKind::batched_add_eint_int:
return checkStreams(2, 1);
case ProcessKind::batched_add_eint_int_cst:
return checkStreams(2, 1);
case ProcessKind::batched_mul_eint_int:
return checkStreams(2, 1);
case ProcessKind::batched_mul_eint_int_cst:
return checkStreams(2, 1);
case ProcessKind::batched_neg_eint:
return checkStreams(1, 1);
case ProcessKind::batched_keyswitch:
return checkStreams(1, 1);
case ProcessKind::batched_bootstrap:
return checkStreams(2, 1);
}
return mlir::failure();
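
The verifier mirrors the scalar operand shapes: every batched binary kind expects two input streams and one output stream, while batched_neg_eint and batched_keyswitch expect one input and one output. The checkStreams helper is defined outside this hunk; purely as an illustration of what such a check might look like (an assumption, not the actual implementation), it could compare the op's stream operands against the expected counts:

// Assumption: a sketch of a stream-count check, not the code in the repo.
auto checkStreams = [&](unsigned numIn,
                        unsigned numOut) -> mlir::LogicalResult {
  unsigned expected = numIn + numOut;
  if (getNumOperands() != expected)
    return emitOpError() << "expected " << numIn << " input and " << numOut
                         << " output streams (" << expected << " operands)";
  return mlir::success();
};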


@@ -69,11 +69,12 @@ mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Location loc,
}
char stream_emulator_get_memref[] = "stream_emulator_get_memref";
char stream_emulator_get_memref_batch[] = "stream_emulator_get_memref_batch";
template <typename Op, char const *funcName>
template <typename Op, char const *funcName, char const *funcName_batch>
struct BufferizableWithCallOpInterface
: public BufferizableOpInterface::ExternalModel<
BufferizableWithCallOpInterface<Op, funcName>, Op> {
BufferizableWithCallOpInterface<Op, funcName, funcName_batch>, Op> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return true;
@@ -103,6 +104,13 @@ struct BufferizableWithCallOpInterface
// able to avoid the copy depending on the stream semantics.
auto resTensorType =
op->getResultTypes()[0].template cast<mlir::TensorType>();
char const *fname;
if (resTensorType.getRank() == 1)
fname = funcName;
else if (resTensorType.getRank() == 2)
fname = funcName_batch;
else
return mlir::failure();
auto outMemrefType = MemRefType::get(resTensorType.getShape(),
resTensorType.getElementType());
auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {});
@@ -117,9 +125,9 @@ struct BufferizableWithCallOpInterface
mlir::FunctionType funcType = mlir::FunctionType::get(
rewriter.getContext(), mlir::ValueRange{operands}.getTypes(),
mlir::TypeRange());
if (insertForwardDeclaration(op, rewriter, funcName, funcType).failed())
if (insertForwardDeclaration(op, rewriter, fname, funcType).failed())
return ::mlir::failure();
rewriter.create<mlir::func::CallOp>(loc, funcName, mlir::TypeRange{},
rewriter.create<mlir::func::CallOp>(loc, fname, mlir::TypeRange{},
operands);
replaceOpWithBufferizedValues(rewriter, op, *outMemref);
@@ -133,7 +141,8 @@ void mlir::concretelang::SDFG::registerBufferizableOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, SDFG::SDFGDialect *dialect) {
SDFG::Get::attachInterface<
BufferizableWithCallOpInterface<SDFG::Get, stream_emulator_get_memref>>(
BufferizableWithCallOpInterface<SDFG::Get, stream_emulator_get_memref,
stream_emulator_get_memref_batch>>(
*ctx);
});
}
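
BufferizableWithCallOpInterface now carries both the rank-1 and rank-2 getter symbols as template parameters and selects between them from the rank of the bufferized result, so a single external model serves SDFG.Get for element streams and batch streams alike. A reduced sketch of that selection (pickGetter is hypothetical; only the two symbols come from this file):

// Hypothetical reduction of the rank-based symbol selection.
template <char const *rank1Sym, char const *rank2Sym>
static mlir::FailureOr<const char *> pickGetter(mlir::TensorType resTy) {
  switch (resTy.getRank()) {
  case 1:
    return rank1Sym;
  case 2:
    return rank2Sym;
  default:
    return mlir::failure();
  }
}

// e.g. pickGetter<stream_emulator_get_memref,
//                 stream_emulator_get_memref_batch>(resTensorType);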


@@ -20,6 +20,15 @@ char mul_eint_int[] = "mul_eint_int";
char neg_eint[] = "neg_eint";
char keyswitch[] = "keyswitch";
char bootstrap[] = "bootstrap";
char batched_add_eint[] = "batched_add_eint";
char batched_add_eint_int[] = "batched_add_eint_int";
char batched_add_eint_int_cst[] = "batched_add_eint_int_cst";
char batched_mul_eint_int[] = "batched_mul_eint_int";
char batched_mul_eint_int_cst[] = "batched_mul_eint_int_cst";
char batched_neg_eint[] = "batched_neg_eint";
char batched_keyswitch[] = "batched_keyswitch";
char batched_bootstrap[] = "batched_bootstrap";
} // namespace
template <typename Op, char const *processName, bool copyAttributes = false>
@@ -87,6 +96,39 @@ void registerSDFGConvertibleOpInterfaceExternalModels(
ReplaceWithProcessSDFGConversionInterface<
mlir::concretelang::Concrete::BootstrapLweTensorOp, bootstrap,
true>>(*ctx);
mlir::concretelang::Concrete::BatchedAddLweTensorOp::attachInterface<
ReplaceWithProcessSDFGConversionInterface<
mlir::concretelang::Concrete::BatchedAddLweTensorOp,
batched_add_eint>>(*ctx);
mlir::concretelang::Concrete::BatchedAddPlaintextLweTensorOp::
attachInterface<ReplaceWithProcessSDFGConversionInterface<
mlir::concretelang::Concrete::BatchedAddPlaintextLweTensorOp,
batched_add_eint_int>>(*ctx);
mlir::concretelang::Concrete::BatchedAddPlaintextCstLweTensorOp::
attachInterface<ReplaceWithProcessSDFGConversionInterface<
mlir::concretelang::Concrete::BatchedAddPlaintextCstLweTensorOp,
batched_add_eint_int_cst>>(*ctx);
mlir::concretelang::Concrete::BatchedMulCleartextLweTensorOp::
attachInterface<ReplaceWithProcessSDFGConversionInterface<
mlir::concretelang::Concrete::BatchedMulCleartextLweTensorOp,
batched_mul_eint_int>>(*ctx);
mlir::concretelang::Concrete::BatchedMulCleartextCstLweTensorOp::
attachInterface<ReplaceWithProcessSDFGConversionInterface<
mlir::concretelang::Concrete::BatchedMulCleartextCstLweTensorOp,
batched_mul_eint_int_cst>>(*ctx);
mlir::concretelang::Concrete::BatchedNegateLweTensorOp::attachInterface<
ReplaceWithProcessSDFGConversionInterface<
mlir::concretelang::Concrete::BatchedNegateLweTensorOp,
batched_neg_eint>>(*ctx);
mlir::concretelang::Concrete::BatchedKeySwitchLweTensorOp::attachInterface<
ReplaceWithProcessSDFGConversionInterface<
mlir::concretelang::Concrete::BatchedKeySwitchLweTensorOp,
batched_keyswitch, true>>(*ctx);
mlir::concretelang::Concrete::BatchedBootstrapLweTensorOp::attachInterface<
ReplaceWithProcessSDFGConversionInterface<
mlir::concretelang::Concrete::BatchedBootstrapLweTensorOp,
batched_bootstrap, true>>(*ctx);
});
}
} // namespace SDFG
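
Each batched Concrete tensor op is made SDFG-convertible through the same ReplaceWithProcessSDFGConversionInterface template, keyed by the process-name string defined above; batched_keyswitch and batched_bootstrap additionally pass copyAttributes = true so their crypto parameters travel with the process. The registration shape, shown for a purely hypothetical op (BatchedFooLweTensorOp and batched_foo do not exist in the codebase):

// Hypothetical example only: illustrates the registration pattern.
char batched_foo[] = "batched_foo";
mlir::concretelang::Concrete::BatchedFooLweTensorOp::attachInterface<
    ReplaceWithProcessSDFGConversionInterface<
        mlir::concretelang::Concrete::BatchedFooLweTensorOp, batched_foo,
        /*copyAttributes=*/true>>(*ctx);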