refactor: lower our conv2d to custom linalg named op

we can now generate linalg named op with custom operation for add/mul to
handle our types
This commit is contained in:
youben11
2022-06-23 10:06:48 +01:00
committed by Ayoub Benaissa
parent 83f2095af5
commit f1f1db923d

View File

@@ -1602,11 +1602,17 @@ struct FHELinalgConv2dToLinalgConv2d
auto stridesAttr = rewriter.getI64VectorAttr(stridesInts);
auto dilationsAttr = rewriter.getI64VectorAttr(dilationsInts);
rewriter.replaceOpWithNewOp<
mlir::concretelang::FHELinalg::FhelinalgConv2DNchwFchwOp>(
auto addOpAttr = rewriter.getNamedAttr(
"add", rewriter.getStringAttr(
mlir::concretelang::FHE::AddEintOp::getOperationName()));
auto mulOpAttr = rewriter.getNamedAttr(
"mul", rewriter.getStringAttr(
mlir::concretelang::FHE::MulEintIntOp::getOperationName()));
rewriter.replaceOpWithNewOp<mlir::linalg::Conv2DNchwFchwOp>(
conv2dOp, biasInitTensor.getType(),
mlir::ValueRange{paddedInput, weight}, biasInitTensor, stridesAttr,
dilationsAttr);
dilationsAttr,
llvm::ArrayRef<mlir::NamedAttribute>({addOpAttr, mulOpAttr}));
return mlir::success();
};
};
@@ -1630,9 +1636,6 @@ void FHETensorOpsToLinalg::runOnOperation() {
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
target.addIllegalOp<mlir::concretelang::FHELinalg::Dot>();
target.addIllegalDialect<mlir::concretelang::FHELinalg::FHELinalgDialect>();
// TODO: this should be removed when we no longer need a custom generated op
// for conv that works on tensors of custom types
target.addLegalOp<mlir::concretelang::FHELinalg::FhelinalgConv2DNchwFchwOp>();
target.addLegalOp<bufferization::AllocTensorOp>();