mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compiler): Add support for scf.forall and associated ops in TFHE to Concrete pass
This commit is contained in:
@@ -522,7 +522,7 @@ struct InsertSliceOpPattern : public mlir::OpConversionPattern<OpTy> {
|
||||
adaptor.getStaticStrides(), adaptor.getStrides(), rewriter);
|
||||
strides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
|
||||
// replace tensor.insert_slice with the new one
|
||||
// replace insert slice-like operation with the new one
|
||||
rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, adaptor.getSource(),
|
||||
adaptor.getDest(), offsets, sizes,
|
||||
strides);
|
||||
@@ -779,6 +779,20 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
|
||||
});
|
||||
|
||||
target.addDynamicallyLegalOp<mlir::scf::ForallOp>(
|
||||
[&](mlir::scf::ForallOp op) {
|
||||
return (
|
||||
converter.isLegal(op->getOperandTypes()) &&
|
||||
converter.isLegal(op->getResultTypes()) &&
|
||||
converter.isLegal(op->getRegion(0).front().getArgumentTypes()) &&
|
||||
converter.isLegal(op.getOutputs().getTypes()));
|
||||
});
|
||||
|
||||
target.addDynamicallyLegalOp<mlir::scf::InParallelOp>(
|
||||
[&](mlir::scf::InParallelOp op) {
|
||||
return converter.isLegal(&op.getBodyRegion());
|
||||
});
|
||||
|
||||
// Make sure that func has legal signature
|
||||
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
|
||||
[&](mlir::func::FuncOp funcOp) {
|
||||
@@ -858,6 +872,7 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
// types
|
||||
patterns.insert<ExtractSliceOpPattern, ExtractOpPattern,
|
||||
InsertSliceOpPattern<mlir::tensor::InsertSliceOp>,
|
||||
InsertSliceOpPattern<mlir::tensor::ParallelInsertSliceOp>,
|
||||
InsertOpPattern, FromElementsOpPattern>(&getContext(),
|
||||
converter);
|
||||
// Add patterns to rewrite some of tensor ops that were introduced by the
|
||||
@@ -874,19 +889,21 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
target.addDynamicallyLegalOp<
|
||||
mlir::tensor::YieldOp, mlir::scf::YieldOp, mlir::tensor::GenerateOp,
|
||||
mlir::tensor::ExtractSliceOp, mlir::tensor::ExtractOp,
|
||||
mlir::tensor::InsertSliceOp, mlir::tensor::ExpandShapeOp,
|
||||
mlir::tensor::CollapseShapeOp, mlir::bufferization::AllocTensorOp,
|
||||
mlir::tensor::EmptyOp>([&](mlir::Operation *op) {
|
||||
return converter.isLegal(op->getResultTypes()) &&
|
||||
converter.isLegal(op->getOperandTypes());
|
||||
});
|
||||
mlir::tensor::InsertSliceOp, mlir::tensor::ParallelInsertSliceOp,
|
||||
mlir::tensor::ExpandShapeOp, mlir::tensor::CollapseShapeOp,
|
||||
mlir::tensor::EmptyOp, mlir::bufferization::AllocTensorOp>(
|
||||
[&](mlir::Operation *op) {
|
||||
return converter.isLegal(op->getResultTypes()) &&
|
||||
converter.isLegal(op->getOperandTypes());
|
||||
});
|
||||
|
||||
// rewrite scf for loops if working on illegal types
|
||||
patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::scf::ForOp>>(&getContext(), converter);
|
||||
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::scf::ForOp>(target,
|
||||
converter);
|
||||
mlir::scf::ForOp>,
|
||||
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::scf::ForallOp>,
|
||||
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::scf::InParallelOp>>(&getContext(), converter);
|
||||
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(
|
||||
target, converter);
|
||||
|
||||
Reference in New Issue
Block a user