diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index 2a0e30e90..564daf445 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -853,6 +853,10 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { mlir::scf::ForallOp>, mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::tensor::EmptyOp>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::tensor::FromElementsOp>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::tensor::DimOp>, mlir::concretelang::TypeConvertingReinstantiationPattern< mlir::tensor::ParallelInsertSliceOp, true>>(&getContext(), converter); @@ -875,6 +879,8 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { target, converter); mlir::concretelang::addDynamicallyLegalTypeOp( target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< mlir::tensor::ParallelInsertSliceOp>(target, converter); @@ -910,6 +916,9 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { mlir::concretelang::addDynamicallyLegalTypeOp( target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); + mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target, converter); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp index 00f27ce9a..4f9942f3f 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp @@ -348,6 +348,17 @@ void TFHEGlobalParametrizationPass::runOnOperation() { mlir::concretelang::addDynamicallyLegalTypeOp( target, converter); + patterns.add< + mlir::concretelang::GenericTypeConverterPattern>( + &getContext(), converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); + + patterns.add>(&getContext(), converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); + patterns.add>( &getContext(), converter); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp index 8b739718d..f8288d7c5 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEKeyNormalization/TFHEKeyNormalization.cpp @@ -371,6 +371,12 @@ void TFHEKeyNormalizationPass::runOnOperation() { mlir::concretelang::addDynamicallyLegalTypeOp( target, typeConverter); + patterns.add< + mlir::concretelang::GenericTypeConverterPattern>( + &getContext(), typeConverter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, typeConverter); + patterns.add>( &getContext(), typeConverter); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp index 7e57cc6c8..4b35e8ef4 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp @@ -926,11 +926,11 @@ void TFHEToConcretePass::runOnOperation() { mlir::tensor::ExtractSliceOp, mlir::tensor::ExtractOp, mlir::tensor::InsertSliceOp, mlir::tensor::ParallelInsertSliceOp, mlir::tensor::ExpandShapeOp, mlir::tensor::CollapseShapeOp, - mlir::tensor::EmptyOp, mlir::bufferization::AllocTensorOp>( - [&](mlir::Operation *op) { - return converter.isLegal(op->getResultTypes()) && - converter.isLegal(op->getOperandTypes()); - }); + mlir::tensor::EmptyOp, mlir::tensor::FromElementsOp, mlir::tensor::DimOp, + mlir::bufferization::AllocTensorOp>([&](mlir::Operation *op) { + return converter.isLegal(op->getResultTypes()) && + converter.isLegal(op->getOperandTypes()); + }); // rewrite scf for loops if working on illegal types patterns.add, mlir::concretelang::TypeConvertingReinstantiationPattern< - mlir::tensor::EmptyOp, true>>(&getContext(), converter); + mlir::tensor::EmptyOp, true>, + mlir::concretelang::TypeConvertingReinstantiationPattern< + mlir::tensor::DimOp>>(&getContext(), converter); mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target, converter);