mirror of https://github.com/zama-ai/concrete.git
synced 2026-02-17 16:11:26 -05:00
fix(compiler): Use FHE.zero_tensor instead of bufferization.alloc_tensor, as alloc_tensor explicitly has an alloc semantic and therefore cannot be eliminated by DCE
committed by Quentin Bourgerie
parent f4099936e2
commit d71201ff8c
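Why the swap helps (a minimal sketch, not taken from the patch itself): bufferization.alloc_tensor models an explicit allocation, so dead-code elimination must conservatively keep it even when its result goes unused after later rewrites, whereas FHE.zero_tensor is a plain value-producing op that DCE is free to drop. Assuming a toy function whose init tensor ends up dead:

    func.func @dead_init() {
      // Survives DCE: the op carries explicit alloc semantics.
      %a = bufferization.alloc_tensor() : tensor<4x!FHE.eint<2>>
      // Eliminated by DCE: result unused and the op has no side effects.
      %z = "FHE.zero_tensor"() : () -> tensor<4x!FHE.eint<2>>
      return
    }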
@@ -265,7 +265,7 @@ struct FHELinalgOpToLinalgGeneric : public mlir::OpRewritePattern<FHELinalgOp> {
     mlir::RankedTensorType rhsTy = ((mlir::Type)linalgOp.getRhs().getType())
                                        .cast<mlir::RankedTensorType>();
     // linalg.init_tensor for initial value
-    mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
+    mlir::Value init = rewriter.create<FHE::ZeroTensorOp>(
         linalgOp.getLoc(), resultTy, mlir::ValueRange{});

     // Create the affine #maps_0
@@ -424,8 +424,8 @@ struct FHELinalgApplyMappedLookupTableToLinalgGeneric
       nestedBuilder.create<linalg::YieldOp>(loc, lookup.getResult());
     };

-    auto output = rewriter.create<bufferization::AllocTensorOp>(
-        loc, resultTy, mlir::ValueRange{});
+    auto output =
+        rewriter.create<FHE::ZeroTensorOp>(loc, resultTy, mlir::ValueRange{});

     // Create the `linalg.generic` op
     Types resTys{resultTy};
@@ -508,7 +508,7 @@ struct FHELinalgApplyMultiLookupTableToLinalgGeneric
     mlir::RankedTensorType lutsTy = getRankedTensorType(luts);
     auto lutElmtTy = lutsTy.getElementType();
     // linalg.init_tensor for initial value
-    mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
+    mlir::Value init = rewriter.create<FHE::ZeroTensorOp>(
         fheLinalgLutOp.getLoc(), resultTy, mlir::ValueRange{});

     auto lutsShape = lutsTy.getShape();
@@ -655,7 +655,7 @@ struct FHELinalgApplyLookupTableToLinalgGeneric
         ((mlir::Type)lutOp.getT().getType()).cast<mlir::RankedTensorType>();

     // linalg.init_tensor for initial value
-    mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
+    mlir::Value init = rewriter.create<FHE::ZeroTensorOp>(
         lutOp.getLoc(), resultTy, mlir::ValueRange{});

     // Create the affine #maps_0
@@ -756,7 +756,7 @@ struct FHELinalgNegEintToLinalgGeneric
             .cast<mlir::RankedTensorType>();

     // linalg.init_tensor for initial value
-    mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
+    mlir::Value init = rewriter.create<FHE::ZeroTensorOp>(
         negEintOp.getLoc(), resultTy, mlir::ValueRange{});

     // Create the affine #maps_0
@@ -1985,8 +1985,8 @@ struct FHELinalgToSignedToLinalgGeneric
     mlir::RankedTensorType resultTy =
         op->getResult(0).getType().cast<mlir::RankedTensorType>();

-    mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
-        op.getLoc(), resultTy, mlir::ValueRange{});
+    mlir::Value init = rewriter.create<FHE::ZeroTensorOp>(op.getLoc(), resultTy,
+                                                          mlir::ValueRange{});

     llvm::SmallVector<mlir::AffineMap, 2> maps{
         mlir::AffineMap::getMultiDimIdentityMap(inputTy.getShape().size(),
@@ -2074,8 +2074,8 @@ struct FHELinalgToUnsignedToLinalgGeneric
     mlir::RankedTensorType resultTy =
         op->getResult(0).getType().cast<mlir::RankedTensorType>();

-    mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
-        op.getLoc(), resultTy, mlir::ValueRange{});
+    mlir::Value init = rewriter.create<FHE::ZeroTensorOp>(op.getLoc(), resultTy,
+                                                          mlir::ValueRange{});

     llvm::SmallVector<mlir::AffineMap, 2> maps{
         mlir::AffineMap::getMultiDimIdentityMap(inputTy.getShape().size(),
@@ -2161,8 +2161,8 @@ struct FHELinalgRoundToLinalgGeneric
     auto inputTy = op.getInput().getType().cast<mlir::RankedTensorType>();
     auto outputTy = op.getOutput().getType().cast<mlir::RankedTensorType>();

-    auto buffer = rewriter.create<bufferization::AllocTensorOp>(
-        loc, outputTy, mlir::ValueRange{});
+    auto buffer =
+        rewriter.create<FHE::ZeroTensorOp>(loc, outputTy, mlir::ValueRange{});

     auto maps = llvm::SmallVector<mlir::AffineMap, 2>{
         mlir::AffineMap::getMultiDimIdentityMap(inputTy.getShape().size(),
@@ -2222,8 +2222,6 @@ void FHETensorOpsToLinalg::runOnOperation() {
   target.addIllegalOp<mlir::concretelang::FHELinalg::Dot>();
   target.addIllegalDialect<mlir::concretelang::FHELinalg::FHELinalgDialect>();

-  target.addLegalOp<bufferization::AllocTensorOp>();
-
   mlir::RewritePatternSet patterns(&getContext());

   patterns.insert<DotToLinalgGeneric<mlir::concretelang::FHELinalg::Dot,
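Note on the hunk above (our reading, not stated in the patch): the lowering no longer materializes bufferization::AllocTensorOp, so the conversion target no longer needs to whitelist it; FHE::ZeroTensorOp is presumably covered by the legality rules already in place for the FHE dialect. A hypothetical minimal shape of the target setup after this change:

    mlir::ConversionTarget target(getContext());
    target.addIllegalDialect<mlir::concretelang::FHELinalg::FHELinalgDialect>();
    // No addLegalOp<bufferization::AllocTensorOp>() needed anymore:
    // the rewrite patterns emit FHE::ZeroTensorOp instead.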
@@ -3,7 +3,7 @@
 // CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-NEXT: module {
 // CHECK-NEXT: func.func @apply_lookup_table(%[[Varg0:.*]]: tensor<2x3x4x!FHE.eint<2>>, %[[Varg1:.*]]: tensor<4xi64>) -> tensor<2x3x4x!FHE.eint<2>> {
-// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
+// CHECK-NEXT: %[[V0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x3x4x!FHE.eint<2>>
 // CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = {{\[}}#map, #map{{\], iterator}}_types = {{\[}}"parallel", "parallel", "parallel"{{\]}}} ins(%[[Varg0]] : tensor<2x3x4x!FHE.eint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.eint<2>>) {
 // CHECK-NEXT: ^bb0(%[[Varg2:.*]]: !FHE.eint<2>, %[[Varg3:.*]]: !FHE.eint<2>):
 // CHECK-NEXT: %[[V2:.*]] = "FHE.apply_lookup_table"(%[[Varg2]], %[[Varg1]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
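A note on the CHECK updates (our reading, not part of the patch): bufferization.alloc_tensor prints with its custom assembly format, while FHE.zero_tensor appears here in MLIR's generic operation form, quoted and with an explicit functional type:

    %0 = "FHE.zero_tensor"() : () -> tensor<2x3x4x!FHE.eint<2>>

so each test below updates both the op spelling and the () -> ... signature on the same line.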
@@ -3,7 +3,7 @@

 //CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
 //CHECK: func.func @multi_lut(%arg0: tensor<4x4x!FHE.eint<2>>, %[[LUTS:.*]]: tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> {
-//CHECK: %[[MEM:.*]] = bufferization.alloc_tensor() : tensor<4x4x!FHE.eint<2>>
+//CHECK: %[[MEM:.*]] = "FHE.zero_tensor"() : () -> tensor<4x4x!FHE.eint<2>>
 //CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<4x4x!FHE.eint<2>>) outs(%[[MEM]] : tensor<4x4x!FHE.eint<2>>) {
 //CHECK: ^bb0(%[[IN:.*]]: !FHE.eint<2>, %[[UNUSED:.*]]: !FHE.eint<2>):
 //CHECK: %[[INDEXA:.*]] = linalg.index 0 : index
@@ -2,7 +2,7 @@

 //CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
 //CHECK: func.func @multi_lut(%arg0: tensor<4x3x!FHE.eint<2>>, %[[LUTS:.*]]: tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> {
-//CHECK: %[[MEM:.*]] = bufferization.alloc_tensor() : tensor<4x3x!FHE.eint<2>>
+//CHECK: %[[MEM:.*]] = "FHE.zero_tensor"() : () -> tensor<4x3x!FHE.eint<2>>
 //CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<4x3x!FHE.eint<2>>) outs(%[[MEM]] : tensor<4x3x!FHE.eint<2>>) {
 //CHECK: ^bb0(%[[IN:.*]]: !FHE.eint<2>, %[[UNUSED:.*]]: !FHE.eint<2>):
 //CHECK: %[[INDEX:.*]] = linalg.index 1 : index
@@ -22,7 +22,7 @@ func.func @main(%arg0: tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.ein
 // CHECK: func.func @main(%[[a0:.*]]: tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>> {
 // CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x5x3x!FHE.esint<6>>
 // CHECK-NEXT: %[[v1:.*]] = arith.constant dense<16> : tensor<1xi7>
-// CHECK-NEXT: %[[v2:.*]] = bufferization.alloc_tensor() : tensor<1x1x5x3x!FHE.esint<6>>
+// CHECK-NEXT: %[[v2:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x5x3x!FHE.esint<6>>
 // CHECK-NEXT: %[[v3:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[v0]], %[[v1]] : tensor<1x1x5x3x!FHE.esint<6>>, tensor<1xi7>) outs(%[[v2]] : tensor<1x1x5x3x!FHE.esint<6>>) {
 // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.esint<6>, %[[aa1:.*]]: i7, %[[aa2:.*]]: !FHE.esint<6>):
 // CHECK-NEXT: %[[vv0:.*]] = "FHE.sub_eint_int"(%[[aa0]], %[[aa1]]) : (!FHE.esint<6>, i7) -> !FHE.esint<6>
@@ -3,7 +3,7 @@
 // CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-NEXT: module {
 // CHECK-NEXT: func.func @neg_eint(%[[Varg0:.*]]: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
-// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
+// CHECK-NEXT: %[[V0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x3x4x!FHE.eint<2>>
 // CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = {{\[}}#map, #map{{\], iterator}}_types = {{\[}}"parallel", "parallel", "parallel"{{\]}}} ins(%[[Varg0]] : tensor<2x3x4x!FHE.eint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.eint<2>>) {
 // CHECK-NEXT: ^bb0(%[[Varg1:.*]]: !FHE.eint<2>, %[[Varg2:.*]]: !FHE.eint<2>):
 // CHECK-NEXT: %[[V2:.*]] = "FHE.neg_eint"(%[[Varg1]]) : (!FHE.eint<2>) -> !FHE.eint<2>
@@ -5,7 +5,7 @@
 // CHECK: #[[m0:.*]] = affine_map<(d0) -> (d0)>

 // CHECK: func.func @main(%[[a0:.*]]: tensor<5x!FHE.eint<8>>) -> tensor<5x!FHE.eint<6>> {
-// CHECK-NEXT: %[[v0:.*]] = bufferization.alloc_tensor() : tensor<5x!FHE.eint<6>>
+// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<5x!FHE.eint<6>>
 // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[m0]], #[[m0]]], iterator_types = ["parallel"]} ins(%[[a0]] : tensor<5x!FHE.eint<8>>) outs(%[[v0]] : tensor<5x!FHE.eint<6>>) {
 // CHECK-NEXT: ^bb0(%[[i0:.*]]: !FHE.eint<8>, %[[o0:.*]]: !FHE.eint<6>):
 // CHECK-NEXT: %[[vv0:.*]] = "FHE.round"(%[[i0]]) : (!FHE.eint<8>) -> !FHE.eint<6>
@@ -23,7 +23,7 @@ func.func @main(%arg0: tensor<5x!FHE.eint<8>>) -> tensor<5x!FHE.eint<6>> {
 // CHECK: #[[m0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

 // CHECK: func.func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<8>>) -> tensor<2x3x4x!FHE.eint<6>> {
-// CHECK-NEXT: %[[v0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<6>>
+// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x3x4x!FHE.eint<6>>
 // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[m0]], #[[m0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<2x3x4x!FHE.eint<8>>) outs(%[[v0]] : tensor<2x3x4x!FHE.eint<6>>) {
 // CHECK-NEXT: ^bb0(%[[i0:.*]]: !FHE.eint<8>, %[[o0:.*]]: !FHE.eint<6>):
 // CHECK-NEXT: %[[vv0:.*]] = "FHE.round"(%[[i0]]) : (!FHE.eint<8>) -> !FHE.eint<6>
@@ -41,7 +41,7 @@ func.func @main(%arg0: tensor<2x3x4x!FHE.eint<8>>) -> tensor<2x3x4x!FHE.eint<6>>
 // CHECK: #[[m0:.*]] = affine_map<(d0) -> (d0)>

 // CHECK: func.func @main(%[[a0:.*]]: tensor<5x!FHE.esint<8>>) -> tensor<5x!FHE.esint<6>> {
-// CHECK-NEXT: %[[v0:.*]] = bufferization.alloc_tensor() : tensor<5x!FHE.esint<6>>
+// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<5x!FHE.esint<6>>
 // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[m0]], #[[m0]]], iterator_types = ["parallel"]} ins(%[[a0]] : tensor<5x!FHE.esint<8>>) outs(%[[v0]] : tensor<5x!FHE.esint<6>>) {
 // CHECK-NEXT: ^bb0(%[[i0:.*]]: !FHE.esint<8>, %[[o0:.*]]: !FHE.esint<6>):
 // CHECK-NEXT: %[[vv0:.*]] = "FHE.round"(%[[i0]]) : (!FHE.esint<8>) -> !FHE.esint<6>
@@ -59,7 +59,7 @@ func.func @main(%arg0: tensor<5x!FHE.esint<8>>) -> tensor<5x!FHE.esint<6>> {
 // CHECK: #[[m0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

 // CHECK: func.func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.esint<8>>) -> tensor<2x3x4x!FHE.esint<6>> {
-// CHECK-NEXT: %[[v0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.esint<6>>
+// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x3x4x!FHE.esint<6>>
 // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[m0]], #[[m0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<2x3x4x!FHE.esint<8>>) outs(%[[v0]] : tensor<2x3x4x!FHE.esint<6>>) {
 // CHECK-NEXT: ^bb0(%[[i0:.*]]: !FHE.esint<8>, %[[o0:.*]]: !FHE.esint<6>):
 // CHECK-NEXT: %[[vv0:.*]] = "FHE.round"(%[[i0]]) : (!FHE.esint<8>) -> !FHE.esint<6>
@@ -3,7 +3,7 @@
 // CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-NEXT: module {
 // CHECK-NEXT: func.func @main(%[[Varg0:.*]]: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.esint<2>> {
-// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.esint<2>>
+// CHECK-NEXT: %[[V0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x3x4x!FHE.esint<2>>
 // CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[Varg0]] : tensor<2x3x4x!FHE.eint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.esint<2>>) {
 // CHECK-NEXT: ^bb0(%[[Varg1:.*]]: !FHE.eint<2>, %[[Varg2:.*]]: !FHE.esint<2>):
 // CHECK-NEXT: %[[V2:.*]] = "FHE.to_signed"(%[[Varg1]]) : (!FHE.eint<2>) -> !FHE.esint<2>
@@ -3,7 +3,7 @@
 // CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-NEXT: module {
 // CHECK-NEXT: func.func @main(%[[Varg0:.*]]: tensor<2x3x4x!FHE.esint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
-// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
+// CHECK-NEXT: %[[V0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x3x4x!FHE.eint<2>>
 // CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[Varg0]] : tensor<2x3x4x!FHE.esint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.eint<2>>) {
 // CHECK-NEXT: ^bb0(%[[Varg1:.*]]: !FHE.esint<2>, %[[Varg2:.*]]: !FHE.eint<2>):
 // CHECK-NEXT: %[[V2:.*]] = "FHE.to_unsigned"(%[[Varg1]]) : (!FHE.esint<2>) -> !FHE.eint<2>