mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-18 00:21:36 -05:00
feat(compiler): Add support for scf.forall and associated ops in TFHE parametrization
This commit is contained in:
@@ -353,6 +353,9 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::tensor::EmptyOp>(
|
||||
target, converter);
|
||||
|
||||
patterns.add<RegionOpTypeConverterPattern<
|
||||
mlir::scf::InParallelOp, TFHEGlobalParametrizationTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<
|
||||
mlir::linalg::GenericOp, TFHEGlobalParametrizationTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
@@ -362,6 +365,9 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
|
||||
patterns.add<RegionOpTypeConverterPattern<
|
||||
mlir::scf::ForOp, TFHEGlobalParametrizationTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<
|
||||
mlir::scf::ForallOp, TFHEGlobalParametrizationTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<
|
||||
mlir::func::ReturnOp, TFHEGlobalParametrizationTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
@@ -372,6 +378,8 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
|
||||
&getContext(), converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::linalg::YieldOp>(
|
||||
target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::tensor::ParallelInsertSliceOp>(target, converter);
|
||||
|
||||
mlir::concretelang::populateWithTensorTypeConverterPatterns(
|
||||
patterns, target, converter);
|
||||
@@ -389,6 +397,9 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
|
||||
mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::tensor::ParallelInsertSliceOp>(&getContext(), converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
|
||||
Reference in New Issue
Block a user