From f1f1db923ddb72d0d7466128445ff73e7a9e27c9 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 23 Jun 2022 10:06:48 +0100 Subject: [PATCH] refactor: lower our conv2d to custom linalg named op we can now generate linalg named op with custom operation for add/mul to handle our types --- .../FHETensorOpsToLinalg/TensorOpsToLinalg.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp index ff5303730..04f984f5b 100644 --- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -1602,11 +1602,17 @@ struct FHELinalgConv2dToLinalgConv2d auto stridesAttr = rewriter.getI64VectorAttr(stridesInts); auto dilationsAttr = rewriter.getI64VectorAttr(dilationsInts); - rewriter.replaceOpWithNewOp< - mlir::concretelang::FHELinalg::FhelinalgConv2DNchwFchwOp>( + auto addOpAttr = rewriter.getNamedAttr( + "add", rewriter.getStringAttr( + mlir::concretelang::FHE::AddEintOp::getOperationName())); + auto mulOpAttr = rewriter.getNamedAttr( + "mul", rewriter.getStringAttr( + mlir::concretelang::FHE::MulEintIntOp::getOperationName())); + rewriter.replaceOpWithNewOp( conv2dOp, biasInitTensor.getType(), mlir::ValueRange{paddedInput, weight}, biasInitTensor, stridesAttr, - dilationsAttr); + dilationsAttr, + llvm::ArrayRef({addOpAttr, mulOpAttr})); return mlir::success(); }; }; @@ -1630,9 +1636,6 @@ void FHETensorOpsToLinalg::runOnOperation() { target.addLegalDialect(); target.addIllegalOp(); target.addIllegalDialect(); - // TODO: this should be removed when we no longer need a custom generated op - // for conv that works on tensors of custom types - target.addLegalOp(); target.addLegalOp();