feat(compiler): Add support for scf.forall and associated ops in TFHE to Concrete pass

This commit is contained in:
Andi Drebes
2023-04-05 16:34:50 +02:00
parent 1aa6ac1ff5
commit 99f262cc8d

View File

@@ -522,7 +522,7 @@ struct InsertSliceOpPattern : public mlir::OpConversionPattern<OpTy> {
adaptor.getStaticStrides(), adaptor.getStrides(), rewriter);
strides.push_back(rewriter.getI64IntegerAttr(1));
// replace tensor.insert_slice with the new one
// replace insert slice-like operation with the new one
rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, adaptor.getSource(),
adaptor.getDest(), offsets, sizes,
strides);
@@ -779,6 +779,20 @@ void TFHEToConcretePass::runOnOperation() {
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
});
target.addDynamicallyLegalOp<mlir::scf::ForallOp>(
[&](mlir::scf::ForallOp op) {
return (
converter.isLegal(op->getOperandTypes()) &&
converter.isLegal(op->getResultTypes()) &&
converter.isLegal(op->getRegion(0).front().getArgumentTypes()) &&
converter.isLegal(op.getOutputs().getTypes()));
});
target.addDynamicallyLegalOp<mlir::scf::InParallelOp>(
[&](mlir::scf::InParallelOp op) {
return converter.isLegal(&op.getBodyRegion());
});
// Make sure that func has legal signature
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
[&](mlir::func::FuncOp funcOp) {
@@ -858,6 +872,7 @@ void TFHEToConcretePass::runOnOperation() {
// types
patterns.insert<ExtractSliceOpPattern, ExtractOpPattern,
InsertSliceOpPattern<mlir::tensor::InsertSliceOp>,
InsertSliceOpPattern<mlir::tensor::ParallelInsertSliceOp>,
InsertOpPattern, FromElementsOpPattern>(&getContext(),
converter);
// Add patterns to rewrite some of tensor ops that were introduced by the
@@ -874,19 +889,21 @@ void TFHEToConcretePass::runOnOperation() {
target.addDynamicallyLegalOp<
mlir::tensor::YieldOp, mlir::scf::YieldOp, mlir::tensor::GenerateOp,
mlir::tensor::ExtractSliceOp, mlir::tensor::ExtractOp,
mlir::tensor::InsertSliceOp, mlir::tensor::ExpandShapeOp,
mlir::tensor::CollapseShapeOp, mlir::bufferization::AllocTensorOp,
mlir::tensor::EmptyOp>([&](mlir::Operation *op) {
return converter.isLegal(op->getResultTypes()) &&
converter.isLegal(op->getOperandTypes());
});
mlir::tensor::InsertSliceOp, mlir::tensor::ParallelInsertSliceOp,
mlir::tensor::ExpandShapeOp, mlir::tensor::CollapseShapeOp,
mlir::tensor::EmptyOp, mlir::bufferization::AllocTensorOp>(
[&](mlir::Operation *op) {
return converter.isLegal(op->getResultTypes()) &&
converter.isLegal(op->getOperandTypes());
});
// rewrite scf for loops if working on illegal types
patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::scf::ForOp>>(&getContext(), converter);
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::scf::ForOp>(target,
converter);
mlir::scf::ForOp>,
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::scf::ForallOp>,
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::scf::InParallelOp>>(&getContext(), converter);
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(
target, converter);