diff --git a/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp b/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp index f5525df33..1920b4b7e 100644 --- a/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp +++ b/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp @@ -59,12 +59,14 @@ void HLFHEToMidLFHEPass::runOnOperation() { target.addIllegalDialect(); // Make sure that no ops `linalg.generic` that have illegal types - target.addDynamicallyLegalOp( - [&](mlir::linalg::GenericOp op) { - return (converter.isLegal(op.getOperandTypes()) && - converter.isLegal(op.getResultTypes()) && + target + .addDynamicallyLegalOp( + [&](mlir::Operation *op) { + return ( + converter.isLegal(op->getOperandTypes()) && + converter.isLegal(op->getResultTypes()) && converter.isLegal(op->getRegion(0).front().getArgumentTypes())); - }); + }); // Make sure that func has legal signature target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { @@ -79,6 +81,9 @@ void HLFHEToMidLFHEPass::runOnOperation() { patterns.add>( &getContext(), converter); + patterns.add>( + &getContext(), converter); mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target, converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); diff --git a/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp b/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp index f1099a544..4fb6638f0 100644 --- a/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp +++ b/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp @@ -82,15 +82,20 @@ void LowLFHEUnparametrizePass::runOnOperation() { LowLFHEUnparametrizeTypeConverter converter; // Conversion of linalg.generic operation - target.addDynamicallyLegalOp( - [&](mlir::linalg::GenericOp op) { - return (converter.isLegal(op.getOperandTypes()) && - converter.isLegal(op.getResultTypes()) && + target + .addDynamicallyLegalOp( + [&](mlir::Operation *op) { + return ( + converter.isLegal(op->getOperandTypes()) && + converter.isLegal(op->getResultTypes()) && converter.isLegal(op->getRegion(0).front().getArgumentTypes())); - }); + }); patterns.add>( &getContext(), converter); + patterns.add>( + &getContext(), converter); // Conversion of function signature and arguments target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { diff --git a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp index 35ee5eb08..9a2b277ff 100644 --- a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp @@ -288,6 +288,9 @@ void MidLFHEGlobalParametrizationPass::runOnOperation() { patterns.add>( &getContext(), converter); + patterns.add>( + &getContext(), converter); mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target, converter); diff --git a/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp index 2ed9929f5..4ad0a3375 100644 --- a/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp +++ b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp @@ -58,12 +58,14 @@ void MidLFHEToLowLFHEPass::runOnOperation() { target.addIllegalDialect(); // Make sure that no ops `linalg.generic` that have illegal types - target.addDynamicallyLegalOp( - [&](mlir::linalg::GenericOp op) { - return (converter.isLegal(op.getOperandTypes()) && - converter.isLegal(op.getResultTypes()) && + target + .addDynamicallyLegalOp( + [&](mlir::Operation *op) { + return ( + converter.isLegal(op->getOperandTypes()) && + converter.isLegal(op->getResultTypes()) && converter.isLegal(op->getRegion(0).front().getArgumentTypes())); - }); + }); // Make sure that func has legal signature target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { @@ -78,6 +80,9 @@ void MidLFHEToLowLFHEPass::runOnOperation() { patterns.add>( &getContext(), converter); + patterns.add>( + &getContext(), converter); mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target, converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter);