From b3368027d05abe2e21b26aea16aa0570907f4257 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Fri, 11 Feb 2022 13:34:40 +0100 Subject: [PATCH] refactor(compiler): Move FHELinalg.zero to FHE.zero_tensor and add zero and zero_tensor in TFHE and Concrete dialects --- .../Utils/GenericOpTypeConversionPattern.h | 26 ++++ .../Dialect/Concrete/IR/ConcreteOps.td | 7 + .../concretelang/Dialect/FHE/IR/FHEOps.td | 18 +++ .../Dialect/FHELinalg/IR/FHELinalgOps.td | 17 --- .../concretelang/Dialect/TFHE/IR/TFHEOps.td | 7 + .../TensorOpsToLinalg.cpp | 111 +++----------- .../lib/Conversion/FHEToTFHE/FHEToTFHE.cpp | 6 +- .../TFHEGlobalParametrization.cpp | 3 + .../TFHEToConcrete/TFHEToConcrete.cpp | 6 +- compiler/lib/Dialect/FHE/Analysis/MANP.cpp | 2 +- .../Dialect/FHELinalg/Transforms/Tiling.cpp | 4 +- .../FHELinalgToLinalg/concat.mlir | 54 ++----- .../FHELinalgToLinalg/matmul_eint_int.mlir | 6 +- .../FHELinalgToLinalg/sum.mlir | 144 +++--------------- .../FHELinalgToLinalg/zero.mlir | 14 -- .../tests/Dialect/FHE/FHE/Analysis/MANP.mlir | 10 ++ .../Dialect/FHE/FHE/Analysis/MANP_linalg.mlir | 20 +-- .../tests/Dialect/FHE/FHE/ops.invalid.mlir | 15 ++ compiler/tests/Dialect/FHE/FHE/ops.mlir | 18 +++ .../FHELinalg/FHELinalg/ops.invalid.mlir | 20 --- .../Dialect/FHELinalg/FHELinalg/ops.mlir | 22 --- .../FHELinalg/tensor-ops-to-linalg.mlir | 34 ++--- .../Dialect/FHELinalg/FHELinalg/tiling.mlir | 4 +- compiler/tests/unittest/end_to_end_jit_fhe.cc | 37 +++++ .../unittest/end_to_end_jit_fhelinalg.cc | 38 ----- 25 files changed, 227 insertions(+), 416 deletions(-) delete mode 100644 compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/zero.mlir create mode 100644 compiler/tests/Dialect/FHE/FHE/ops.invalid.mlir diff --git a/compiler/include/concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h b/compiler/include/concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h index da3b46297..edbe6c8c8 100644 --- 
a/compiler/include/concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h +++ b/compiler/include/concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h @@ -51,6 +51,32 @@ private: mlir::TypeConverter &converter; }; +template +struct GenericTypeAndOpConverterPattern : public mlir::OpRewritePattern { + GenericTypeAndOpConverterPattern(mlir::MLIRContext *context, + mlir::TypeConverter &converter, + mlir::PatternBenefit benefit = 100) + : mlir::OpRewritePattern(context, benefit), converter(converter) {} + + mlir::LogicalResult + matchAndRewrite(OldOp oldOp, mlir::PatternRewriter &rewriter) const override { + // Rewrite results + mlir::SmallVector resultTypes(oldOp->getNumResults()); + { + for (unsigned i = 0; i < oldOp->getNumResults(); i++) { + auto result = oldOp->getResult(i); + resultTypes[i] = converter.convertType(result.getType()); + } + } + rewriter.replaceOpWithNewOp(oldOp, resultTypes, + oldOp->getOperands()); + return mlir::success(); + } + +private: + mlir::TypeConverter &converter; +}; + template void addDynamicallyLegalTypeOp(mlir::ConversionTarget &target, mlir::TypeConverter &typeConverter) { diff --git a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td index 1e18e0b7d..6a36396a3 100644 --- a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td +++ b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td @@ -17,6 +17,13 @@ def ZeroLWEOp : Concrete_Op<"zero"> { let results = (outs LweCiphertextType:$out); } +def ZeroTensorLWEOp : Concrete_Op<"zero_tensor"> { + let summary = "Returns a trivial encyption of 0"; + + let arguments = (ins); + let results = (outs Type.predicate, HasStaticShapePred]>>:$tensor); +} + def AddLweCiphertextsOp : Concrete_Op<"add_lwe_ciphertexts"> { let summary = "Returns the sum of 2 lwe ciphertexts"; diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td 
b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index 643372cd7..0665832de 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -34,6 +34,24 @@ def ZeroEintOp : FHE_Op<"zero", [NoSideEffect]> { let results = (outs EncryptedIntegerType:$out); } + +def ZeroTensorOp : FHE_Op<"zero_tensor", []> { + let summary = "Creates a new tensor with all elements initialized to an encrypted zero."; + + let description = [{ + Creates a new tensor with the shape specified in the result type and initializes its elements with an encrypted zero. + + Example: + ```mlir + %tensor = "FHE.zero_tensor"() : () -> tensor<5x!FHE.eint<4>> + ``` + }]; + + let arguments = (ins); + + let results = (outs Type.predicate, HasStaticShapePred]>>:$tensor); +} + def AddEintIntOp : FHE_Op<"add_eint_int"> { let summary = "Adds an encrypted integer and a clear integer"; diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index f83617429..e82d86cbe 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -499,23 +499,6 @@ def MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> { }]; } -def ZeroOp : FHELinalg_Op<"zero", []> { - let summary = "Creates a new tensor with all elements initialized to an encrypted zero."; - - let description = [{ - Creates a new tensor with the shape specified in the result type and initializes its elements with an encrypted zero. 
- - Example: - ```mlir - %tensor = "FHELinalg.zero"() : () -> tensor<5x!FHE.eint<4>> - ``` - }]; - - let arguments = (ins); - - let results = (outs Type.predicate, HasStaticShapePred]>>:$aggregate); -} - def SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> { let summary = "Returns the sum of elements of a tensor of encrypted integers along specified axes."; diff --git a/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td b/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td index 264567f4c..cb60fae20 100644 --- a/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td +++ b/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td @@ -25,6 +25,13 @@ def ZeroGLWEOp : TFHE_Op<"zero"> { let results = (outs GLWECipherTextType:$out); } +def ZeroTensorGLWEOp : TFHE_Op<"zero_tensor"> { + let summary = "Returns a tensor of trivial encyption of 0"; + + let arguments = (ins); + let results = (outs Type.predicate, HasStaticShapePred]>>:$tensor); +} + def AddGLWEIntOp : TFHE_Op<"add_glwe_int"> { let summary = "Returns the sum of a clear integer and a lwe ciphertext"; diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp index 13cefce59..710f6751f 100644 --- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -52,14 +52,13 @@ struct DotToLinalgGeneric // // becomes: // - // %0 = "FHE.zero"() : () -> !FHE.eint<0> - // %1 = tensor.from_elements %0 : tensor<1x!FHE.eint<0>> - // %2 = linalg.generic { + // %0 = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<0>> + // %1 = linalg.generic { // indexing_maps = [#map0, #map0, #map1], // iterator_types = ["reduction"] // } // ins(%arg0, %arg1 : tensor<2x!FHE.eint<0>>, tensor<2xi32>) - // outs(%1 : tensor<1x!FHE.eint<0>>) { + // outs(%0 : tensor<1x!FHE.eint<0>>) { // ^bb0(%arg2: !FHE.eint<0>, %arg3: i32, %arg4: !FHE.eint<0>): // %4 = 
"FHE.mul_eint_int"(%arg2, %arg3) : // (!FHE.eint<0>, i32) -> !FHE.eint<0> @@ -71,28 +70,19 @@ struct DotToLinalgGeneric // } -> tensor<1x!FHE.eint<0>> // // %c0 = constant 0 : index - // %o = tensor.extract %2[%c0] : tensor<1x!FHE.eint<0>> + // %o = tensor.extract %1[%c0] : tensor<1x!FHE.eint<0>> // ::mlir::LogicalResult matchAndRewrite(::mlir::concretelang::FHELinalg::Dot dotOp, ::mlir::PatternRewriter &rewriter) const override { - // Zero value to initialize accumulator - mlir::Value zeroCst = rewriter.create( - dotOp.getLoc(), - dotOp.lhs().getType().cast().getElementType()); - // Create one-dimensional accumulator with a single element - // (`tensor.from_elements` does not allow for the creation of 0d - // tensors) - mlir::tensor::FromElementsOp feOp = - rewriter.create(dotOp.getLoc(), zeroCst); - - mlir::Value accu = feOp.getResult(); + auto zeroTensorOp = rewriter.create( + dotOp.getLoc(), mlir::RankedTensorType::get({1}, dotOp.getType())); // Create `linalg.generic` op - llvm::SmallVector resTypes{accu.getType()}; + llvm::SmallVector resTypes{zeroTensorOp.getType()}; llvm::SmallVector ins{dotOp.lhs(), dotOp.rhs()}; - llvm::SmallVector outs{accu}; + llvm::SmallVector outs{zeroTensorOp}; llvm::SmallVector maps{ mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()), mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()), @@ -334,7 +324,8 @@ llvm::SmallVector parallelIteratorType(int n) { // %e3 = tensor.extract %arg5[%lut_idx, %i3] : tensor<5x4xi64> // %lut = tensor.from_elements %e0, ..., %e3 : tensor<4xi64> // %res = "TFHE.apply_lookup_table"(%arg3, %[[LUT]]) -// {baseLogBS = -1 : i32, baseLogKS = -1 : i32, k = -1 : i32, +// {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension +// = -1 : i32, // levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = // -1 : i32, polynomialSize = -1 : i32} // : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> @@ -808,11 +799,7 @@ struct FHELinalgNegEintToLinalgGeneric // indexing_maps = #maps_0, // 
iterator_types = ["parallel", "parallel", "reduction"] // } -// %init = linalg.generate { -// ^bb0(%i : index, %j : index, %k : index): -// %z = "FHE.zero" : () -> !FHE.eint<2> -// linalg.yield %z -// }: tensor> +// %init = FHE.zero_tensor : tensor> // linalg.generic #attributes_0 // ins(%A, %B : tensor>, // tensor) @@ -847,20 +834,9 @@ struct FHELinalgMatmulToLinalgGeneric ((mlir::Type)matmulOp->getResult(0).getType()) .cast(); mlir::Type resultElementTy = resultTy.getElementType(); - // Create tensor.generate for initial value - auto generateBody = [&](mlir::OpBuilder &nestedBuilder, - mlir::Location nestedLoc, - mlir::ValueRange blockArgs) { - // %z = "FHE.zero" : () -> !FHE.eint<2> - mlir::concretelang::FHE::ZeroEintOp zeroOp = - nestedBuilder.create( - matmulLoc, resultElementTy); - // linalg.yield %z : !FHE.eint

- nestedBuilder.create(matmulLoc, - zeroOp.getResult()); - }; - mlir::tensor::GenerateOp init = rewriter.create( - matmulLoc, (mlir::Type)resultTy, mlir::ValueRange{}, generateBody); + // Create the initial value, `FHE.zero_tensor` + auto init = rewriter.create( + matmulLoc, resultTy); // Create the affine #maps_0 llvm::SmallVector maps{ // (m, n, p) -> (m, p), @@ -922,54 +898,6 @@ private: createMulOp; }; -// This rewrite pattern transforms any instance of operators -// `FHELinalg.zero` to an instance of `linalg.generate` with an -// appropriate region yielding a zero value. -// -// Example: -// -// %out = "FHELinalg.zero"() : () -> tensor> -// -// becomes: -// -// %0 = tensor.generate { -// ^bb0(%arg2: index, %arg3: index): -// %zero = "FHE.zero"() : () -> !FHE.eint

-// tensor.yield %zero : !FHE.eint

-// } : tensor> -// -struct FHELinalgZeroToLinalgGenerate - : public mlir::OpRewritePattern { - FHELinalgZeroToLinalgGenerate(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = - mlir::concretelang::DEFAULT_PATTERN_BENEFIT) - : ::mlir::OpRewritePattern( - context, benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(mlir::concretelang::FHELinalg::ZeroOp zeroOp, - ::mlir::PatternRewriter &rewriter) const override { - mlir::RankedTensorType resultTy = - zeroOp->getResult(0).getType().cast(); - - auto generateBody = [&](mlir::OpBuilder &nestedBuilder, - mlir::Location nestedLoc, - mlir::ValueRange blockArgs) { - mlir::Value zeroScalar = - nestedBuilder.create( - zeroOp.getLoc(), resultTy.getElementType()); - nestedBuilder.create(zeroOp.getLoc(), zeroScalar); - }; - mlir::tensor::GenerateOp generateOp = - rewriter.create( - zeroOp.getLoc(), resultTy, mlir::ValueRange{}, generateBody); - - rewriter.replaceOp(zeroOp, {generateOp.getResult()}); - - return ::mlir::success(); - }; -}; - // This rewrite pattern transforms any instance of operators // `FHELinalg.sum` to an instance of `linalg.generic`. 
// @@ -983,7 +911,7 @@ struct FHELinalgZeroToLinalgGenerate // #map0 = affine_map<(i0, i1, ..., iN) -> (i0, i1, ..., iN)> // #map1 = affine_map<(i0, i1, ..., iN) -> (0)> // -// %accumulator = "FHELinalg.zero"() : () -> tensor<1x!FHE.eint<7>> +// %accumulator = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<7>> // %accumulation = linalg.generic // { // indexing_maps = [#map0, #map1], @@ -1028,7 +956,7 @@ struct SumToLinalgGeneric if (size == 0) { mlir::Value result; if (outputIsTensor) { - result = rewriter.create(location, outputType) + result = rewriter.create(location, outputType) .getResult(); } else { result = rewriter.create(location, outputType) @@ -1058,7 +986,7 @@ struct SumToLinalgGeneric } mlir::Value accumulator = - rewriter.create(location, accumulatorType) + rewriter.create(location, accumulatorType) .getResult(); auto ins = llvm::SmallVector{input}; @@ -1137,7 +1065,7 @@ struct SumToLinalgGeneric // // becomes: // -// %empty = "FHELinalg.zero"() : () -> tensor<2x7x!FHE.eint<4>> +// %empty = "FHE.zero_tensor"() : () -> tensor<2x7x!FHE.eint<4>> // // %x_copied = tensor.insert_slice %x into %empty[0, 0] [2, 3] [1, 1] // : tensor<2x3x!FHE.eint<4>> into tensor<2x7x!FHE.eint<4>> @@ -1165,7 +1093,7 @@ struct ConcatRewritePattern size_t outputDimensions = outputShape.size(); mlir::Value result = - rewriter.create(location, outputType).getResult(); + rewriter.create(location, outputType).getResult(); auto offsets = llvm::SmallVector{}; auto sizes = llvm::SmallVector{}; @@ -1299,7 +1227,6 @@ void FHETensorOpsToLinalg::runOnFunction() { patterns.insert(&getContext()); patterns.insert( &getContext()); - patterns.insert(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); diff --git a/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp b/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp index fb80652b5..fb713925d 100644 --- a/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp +++ b/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp @@ -88,12 +88,12 
@@ void FHEToTFHEPass::runOnOperation() { patterns.add>( &getContext(), converter); - patterns.add>( - &getContext(), converter); patterns.add< RegionOpTypeConverterPattern>( &getContext(), converter); + patterns.add>(&getContext(), converter); mlir::concretelang::populateWithTensorTypeConverterPatterns(patterns, target, converter); diff --git a/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp b/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp index 20be37eb7..e2a1a5a97 100644 --- a/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp @@ -262,6 +262,9 @@ void populateWithTFHEOpTypeConversionPatterns( mlir::concretelang::V0Parameter &v0Parameter) { populateWithTFHEOpTypeConversionPattern( patterns, target, typeConverter); + populateWithTFHEOpTypeConversionPattern< + mlir::concretelang::TFHE::ZeroTensorGLWEOp>(patterns, target, + typeConverter); populateWithTFHEOpTypeConversionPattern< mlir::concretelang::TFHE::AddGLWEIntOp>(patterns, target, typeConverter); populateWithTFHEOpTypeConversionPattern( diff --git a/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp index 221d0de27..fa72a5def 100644 --- a/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp +++ b/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp @@ -81,12 +81,12 @@ void TFHEToConcretePass::runOnOperation() { mlir::OwningRewritePatternList patterns(&getContext()); populateWithGeneratedTFHEToConcrete(patterns); + patterns.add>(&getContext(), converter); patterns.add>( &getContext(), converter); - patterns.add>( - &getContext(), converter); patterns.add>( &getContext(), converter); diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 4e9c8d07f..977c4ed83 100644 --- 
a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -926,7 +926,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(mulEintIntOp, operands); } else if (llvm::isa(op) || - llvm::isa(op) || + llvm::isa(op) || llvm::isa(op)) { norm2SqEquiv = llvm::APInt{1, 1, false}; } diff --git a/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp b/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp index eadfb1fca..46dc2a282 100644 --- a/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp +++ b/compiler/lib/Dialect/FHELinalg/Transforms/Tiling.cpp @@ -189,8 +189,8 @@ public: mlir::Value B = op.getOperand(1); // Initialization of the output matrix with zeros - mlir::concretelang::FHELinalg::ZeroOp Cinit = - rewriter.create( + mlir::concretelang::FHE::ZeroTensorOp Cinit = + rewriter.create( origLoc, op.getResult().getType()); mlir::TensorType ATTy = A.getType().cast(); diff --git a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/concat.mlir b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/concat.mlir index fa6f9e5c9..a9349272d 100644 --- a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/concat.mlir +++ b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/concat.mlir @@ -3,11 +3,7 @@ // ----- // CHECK: func @main(%[[a0:.*]]: tensor<3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index): -// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<7x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<7x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0] [3] [1] : tensor<3x!FHE.eint<7>> into tensor<7x!FHE.eint<7>> // CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into 
%[[v1]][3] [4] [1] : tensor<4x!FHE.eint<7>> into tensor<7x!FHE.eint<7>> // CHECK-NEXT: return %[[v2]] : tensor<7x!FHE.eint<7>> @@ -20,11 +16,7 @@ func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x! // ----- // CHECK: func @main(%[[a0:.*]]: tensor<3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index): -// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<7x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<7x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0] [3] [1] : tensor<3x!FHE.eint<7>> into tensor<7x!FHE.eint<7>> // CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][3] [4] [1] : tensor<4x!FHE.eint<7>> into tensor<7x!FHE.eint<7>> // CHECK-NEXT: return %[[v2]] : tensor<7x!FHE.eint<7>> @@ -37,11 +29,7 @@ func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x! 
// ----- // CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index): -// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<7x4x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<7x4x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0] [3, 4] [1, 1] : tensor<3x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>> // CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][3, 0] [4, 4] [1, 1] : tensor<4x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>> // CHECK-NEXT: return %[[v2]] : tensor<7x4x!FHE.eint<7>> @@ -54,11 +42,7 @@ func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor // ----- // CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index): -// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<7x4x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<7x4x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0] [3, 4] [1, 1] : tensor<3x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>> // CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][3, 0] [4, 4] [1, 1] : tensor<4x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>> // CHECK-NEXT: return %[[v2]] : tensor<7x4x!FHE.eint<7>> @@ -72,11 +56,7 @@ func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor // ----- // CHECK: func @main(%[[a0:.*]]: tensor<4x3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] 
= tensor.generate { -// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index): -// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<4x7x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x7x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0] [4, 3] [1, 1] : tensor<4x3x!FHE.eint<7>> into tensor<4x7x!FHE.eint<7>> // CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][0, 3] [4, 4] [1, 1] : tensor<4x4x!FHE.eint<7>> into tensor<4x7x!FHE.eint<7>> // CHECK-NEXT: return %[[v2]] : tensor<4x7x!FHE.eint<7>> @@ -89,11 +69,7 @@ func @main(%x: tensor<4x3x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor // ----- // CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index): -// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<4x3x4x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x3x4x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>> // CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][2, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>> // CHECK-NEXT: return %[[v2]] : tensor<4x3x4x!FHE.eint<7>> @@ -106,11 +82,7 @@ func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> te // ----- // CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index): -// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> 
!FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<4x3x4x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x3x4x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>> // CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][2, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>> // CHECK-NEXT: return %[[v2]] : tensor<4x3x4x!FHE.eint<7>> @@ -123,11 +95,7 @@ func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> te // ----- // CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index): -// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<2x6x4x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x6x4x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x6x4x!FHE.eint<7>> // CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][0, 3, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x6x4x!FHE.eint<7>> // CHECK-NEXT: return %[[v2]] : tensor<2x6x4x!FHE.eint<7>> @@ -140,11 +108,7 @@ func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> te // ----- // CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index): -// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7> -// CHECK-NEXT: } 
: tensor<2x3x8x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x3x8x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x3x8x!FHE.eint<7>> // CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][0, 0, 4] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x3x8x!FHE.eint<7>> // CHECK-NEXT: return %[[v2]] : tensor<2x3x8x!FHE.eint<7>> diff --git a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/matmul_eint_int.mlir b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/matmul_eint_int.mlir index cd14ab63f..e8fc7c689 100644 --- a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/matmul_eint_int.mlir +++ b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/matmul_eint_int.mlir @@ -5,11 +5,7 @@ // CHECK-NEXT: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-NEXT: module { // CHECK-NEXT: func @matmul_eint_int(%arg0: tensor<3x4x!FHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!FHE.eint<2>> { -// CHECK-NEXT: %0 = tensor.generate { -// CHECK-NEXT: ^bb0(%arg2: index, %arg3: index): // no predecessors -// CHECK-NEXT: %2 = "FHE.zero"() : () -> !FHE.eint<2> -// CHECK-NEXT: tensor.yield %2 : !FHE.eint<2> -// CHECK-NEXT: } : tensor<3x2x!FHE.eint<2>> +// CHECK-NEXT: %0 = "FHE.zero_tensor"() : () -> tensor<3x2x!FHE.eint<2>> // CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<3x4x!FHE.eint<2>>, tensor<4x2xi3>) outs(%0 : tensor<3x2x!FHE.eint<2>>) { // CHECK-NEXT: ^bb0(%arg2: !FHE.eint<2>, %arg3: i3, %arg4: !FHE.eint<2>): // no predecessors // CHECK-NEXT: %2 = "FHE.mul_eint_int"(%arg2, %arg3) : (!FHE.eint<2>, i3) -> !FHE.eint<2> diff --git a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum.mlir 
b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum.mlir index dcc89dfb7..b0e85f263 100644 --- a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum.mlir +++ b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum.mlir @@ -25,11 +25,7 @@ func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> !FHE.eint<7> { // ----- // CHECK: func @main(%[[a0:.*]]: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<3x4x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<3x4x!FHE.eint<7>> // CHECK-NEXT: return %[[v0]] : tensor<3x4x!FHE.eint<7>> // CHECK-NEXT: } func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> { @@ -40,11 +36,7 @@ func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> { // ----- // CHECK: func @main(%[[a0:.*]]: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x1x4x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<3x1x4x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<3x1x4x!FHE.eint<7>> // CHECK-NEXT: return %[[v0]] : tensor<3x1x4x!FHE.eint<7>> // CHECK-NEXT: } func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x1x4x!FHE.eint<7>> { @@ -55,11 +47,7 @@ func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x1x4x!FHE.eint<7>> { // ----- // CHECK: func @main(%[[a0:.*]]: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] 
: !FHE.eint<7> -// CHECK-NEXT: } : tensor<3x0x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<3x0x!FHE.eint<7>> // CHECK-NEXT: return %[[v0]] : tensor<3x0x!FHE.eint<7>> // CHECK-NEXT: } func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x!FHE.eint<7>> { @@ -70,11 +58,7 @@ func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x!FHE.eint<7>> { // ----- // CHECK: func @main(%[[a0:.*]]: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x1x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<3x0x1x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<3x0x1x!FHE.eint<7>> // CHECK-NEXT: return %[[v0]] : tensor<3x0x1x!FHE.eint<7>> // CHECK-NEXT: } func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x1x!FHE.eint<7>> { @@ -88,11 +72,7 @@ func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x1x!FHE.eint<7>> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)> // CHECK: func @main(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction"]} ins(%[[a0]] : tensor<4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -113,11 +93,7 @@ func @main(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { // 
CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)> // CHECK: func @main(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction"]} ins(%[[a0]] : tensor<4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -138,11 +114,7 @@ func @main(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)> // CHECK: func @main(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction"]} ins(%[[a0]] : tensor<4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -161,11 +133,7 @@ func @main(%arg0: tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)> // CHECK: func @main(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> { -// 
CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction"]} ins(%[[a0]] : tensor<4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -184,11 +152,7 @@ func @main(%arg0: tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -209,11 +173,7 @@ func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, 0)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] 
= "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<1x1x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x1x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -232,11 +192,7 @@ func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<4x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<4x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -255,11 +211,7 @@ func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<1x4x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield 
%[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<1x4x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x4x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x4x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -278,11 +230,7 @@ func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<1x4x!FHE.eint<7>> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<3x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<3x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<3x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -301,11 +249,7 @@ func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, 0)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<3x1x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : 
tensor<3x1x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<3x1x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<3x1x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -324,11 +268,7 @@ func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<3x1x!FHE.eint<7>> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -349,11 +289,7 @@ func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, 0)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<1x1x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> 
tensor<1x1x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x1x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -372,11 +308,7 @@ func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -397,11 +329,7 @@ func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0, 0, 0)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<1x1x1x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x1x!FHE.eint<7>> // CHECK-NEXT: 
%[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x1x1x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -420,11 +348,7 @@ func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<3x2x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<3x2x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<3x2x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -443,11 +367,7 @@ func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x1x2x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<3x1x2x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<3x1x2x!FHE.eint<7>> // 
CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<3x1x2x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -466,11 +386,7 @@ func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x1x2x!FHE.eint<7>> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<4x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<4x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -489,11 +405,7 @@ func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0, d1, 0)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x4x1x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<1x4x1x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x4x1x!FHE.eint<7>> // 
CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x4x1x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -512,11 +424,7 @@ func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x4x1x!FHE.eint<7>> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> @@ -537,11 +445,7 @@ func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0, 0, 0)> // CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> -// CHECK-NEXT: } : tensor<1x1x1x!FHE.eint<7>> +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x1x!FHE.eint<7>> // CHECK-NEXT: %[[v1:.*]] = 
linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x1x1x!FHE.eint<7>>) { // CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): // CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> diff --git a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/zero.mlir b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/zero.mlir deleted file mode 100644 index 9eb002577..000000000 --- a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/zero.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s - -// CHECK: func @zero(%arg0: !FHE.eint<2>) -> tensor<3x2x!FHE.eint<2>> { -// CHECK-NEXT: %[[v0:.*]] = tensor.generate { -// CHECK-NEXT: ^bb0(%arg1: index, %arg2: index): // no predecessors -// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<2> -// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<2> -// CHECK-NEXT: } : tensor<3x2x!FHE.eint<2>> -// CHECK-NEXT: return %[[v0]] : tensor<3x2x!FHE.eint<2>> -// CHECK-NEXT: } -func @zero(%arg0: !FHE.eint<2>) -> tensor<3x2x!FHE.eint<2>> { - %1 = "FHELinalg.zero"(): () -> tensor<3x2x!FHE.eint<2>> - return %1 : tensor<3x2x!FHE.eint<2>> -} diff --git a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir index 4da67cdf5..ec5243deb 100644 --- a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir +++ b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir @@ -10,6 +10,16 @@ func @single_zero() -> !FHE.eint<2> // ----- +func @zero() -> tensor<8x!FHE.eint<2>> +{ + // CHECK: %[[ret:.*]] = "FHE.zero_tensor"() {MANP = 1 : ui{{[0-9]+}}} : () -> tensor<8x!FHE.eint<2>> + %0 = "FHE.zero_tensor"() : () -> tensor<8x!FHE.eint<2>> + + return %0 : tensor<8x!FHE.eint<2>> +} + +// ----- + func 
@single_cst_add_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant 3 : i3 diff --git a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir index 814a062e2..d15eb2fb2 100644 --- a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir +++ b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir @@ -383,18 +383,8 @@ func @matmul_int_eint_cst_p_2_n_1(%arg0: tensor<2x3x!FHE.eint<2>>) -> tensor<2x3 // ----- -func @zero() -> tensor<8x!FHE.eint<2>> -{ - // CHECK: %[[ret:.*]] = "FHELinalg.zero"() {MANP = 1 : ui{{[0-9]+}}} : () -> tensor<8x!FHE.eint<2>> - %0 = "FHELinalg.zero"() : () -> tensor<8x!FHE.eint<2>> - - return %0 : tensor<8x!FHE.eint<2>> -} - -// ----- - func @sum() -> !FHE.eint<7> { - %0 = "FHELinalg.zero"() : () -> tensor<5x3x4x2x!FHE.eint<7>> + %0 = "FHE.zero_tensor"() : () -> tensor<5x3x4x2x!FHE.eint<7>> // CHECK: MANP = 11 : ui{{[0-9]+}} %1 = "FHELinalg.sum"(%0) : (tensor<5x3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> @@ -494,7 +484,7 @@ func @sum() -> !FHE.eint<7> { // =============================== - %35 = "FHELinalg.zero"() : () -> tensor<2x0x3x!FHE.eint<7>> + %35 = "FHE.zero_tensor"() : () -> tensor<2x0x3x!FHE.eint<7>> // CHECK: MANP = 1 : ui{{[0-9]+}} %36 = "FHELinalg.sum"(%35) : (tensor<2x0x3x!FHE.eint<7>>) -> !FHE.eint<7> @@ -550,15 +540,15 @@ func @sum() -> !FHE.eint<7> { // ----- func @concat() -> tensor<3x!FHE.eint<7>> { - %0 = "FHELinalg.zero"() : () -> tensor<4x!FHE.eint<7>> + %0 = "FHE.zero_tensor"() : () -> tensor<4x!FHE.eint<7>> // CHECK: MANP = 2 : ui{{[0-9]+}} %1 = "FHELinalg.sum"(%0) { keep_dims = true } : (tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> - %2 = "FHELinalg.zero"() : () -> tensor<5x!FHE.eint<7>> + %2 = "FHE.zero_tensor"() : () -> tensor<5x!FHE.eint<7>> // CHECK: MANP = 3 : ui{{[0-9]+}} %3 = "FHELinalg.sum"(%2) { keep_dims = true } : (tensor<5x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> - %4 = "FHELinalg.zero"() : () -> tensor<10x!FHE.eint<7>> + %4 = 
"FHE.zero_tensor"() : () -> tensor<10x!FHE.eint<7>> // CHECK: MANP = 4 : ui{{[0-9]+}} %5 = "FHELinalg.sum"(%4) { keep_dims = true } : (tensor<10x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> diff --git a/compiler/tests/Dialect/FHE/FHE/ops.invalid.mlir b/compiler/tests/Dialect/FHE/FHE/ops.invalid.mlir new file mode 100644 index 000000000..71f94afae --- /dev/null +++ b/compiler/tests/Dialect/FHE/FHE/ops.invalid.mlir @@ -0,0 +1,15 @@ +// RUN: concretecompiler --split-input-file --verify-diagnostics --action=roundtrip %s + +func @zero_1D_scalar() -> tensor<4x!FHE.eint<2>> { + // expected-error @+1 {{'FHE.zero_tensor' op}} + %0 = "FHE.zero_tensor"() : () -> !FHE.eint<2> + return %0 : !FHE.eint<2> +} + +// ----- + +func @zero_plaintext() -> tensor<4x9xi32> { + // expected-error @+1 {{'FHE.zero_tensor' op}} + %0 = "FHE.zero_tensor"() : () -> tensor<4x9xi32> + return %0 : tensor<4x9xi32> +} \ No newline at end of file diff --git a/compiler/tests/Dialect/FHE/FHE/ops.mlir b/compiler/tests/Dialect/FHE/FHE/ops.mlir index 8f28a851b..6508bd7ad 100644 --- a/compiler/tests/Dialect/FHE/FHE/ops.mlir +++ b/compiler/tests/Dialect/FHE/FHE/ops.mlir @@ -9,6 +9,24 @@ func @zero() -> !FHE.eint<2> { return %1: !FHE.eint<2> } +// CHECK: func @zero_1D() -> tensor<4x!FHE.eint<2>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x!FHE.eint<2>> +// CHECK-NEXT: return %[[v0]] : tensor<4x!FHE.eint<2>> +// CHECK-NEXT: } +func @zero_1D() -> tensor<4x!FHE.eint<2>> { + %0 = "FHE.zero_tensor"() : () -> tensor<4x!FHE.eint<2>> + return %0 : tensor<4x!FHE.eint<2>> +} + +// CHECK: func @zero_2D() -> tensor<4x9x!FHE.eint<2>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x9x!FHE.eint<2>> +// CHECK-NEXT: return %[[v0]] : tensor<4x9x!FHE.eint<2>> +// CHECK-NEXT: } +func @zero_2D() -> tensor<4x9x!FHE.eint<2>> { + %0 = "FHE.zero_tensor"() : () -> tensor<4x9x!FHE.eint<2>> + return %0 : tensor<4x9x!FHE.eint<2>> +} + // CHECK-LABEL: func @add_eint_int(%arg0: !FHE.eint<2>) -> 
!FHE.eint<2> func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3 diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.invalid.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.invalid.mlir index 3f537748b..6fd4d88df 100644 --- a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.invalid.mlir +++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.invalid.mlir @@ -270,23 +270,3 @@ func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!FHE.eint<2>>) -> %1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>> return %1 : tensor<4x2x!FHE.eint<2>> } - -// ----- - -///////////////////////////////////////////////// -// FHELinalg.zero -///////////////////////////////////////////////// - -func @zero_1D_scalar() -> tensor<4x!FHE.eint<2>> { - // expected-error @+1 {{'FHELinalg.zero' op}} - %0 = "FHELinalg.zero"() : () -> !FHE.eint<2> - return %0 : !FHE.eint<2> -} - -// ----- - -func @zero_plaintext() -> tensor<4x9xi32> { - // expected-error @+1 {{'FHELinalg.zero' op}} - %0 = "FHELinalg.zero"() : () -> tensor<4x9xi32> - return %0 : tensor<4x9xi32> -} diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir index f718d8f1f..1d4577f50 100644 --- a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir +++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir @@ -345,25 +345,3 @@ func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!FHE.eint<2>>) -> %1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> return %1 : tensor<3x2x!FHE.eint<2>> } - -///////////////////////////////////////////////// -// FHELinalg.zero -///////////////////////////////////////////////// - -// CHECK: func @zero_1D() -> tensor<4x!FHE.eint<2>> { -// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.zero"() : () -> tensor<4x!FHE.eint<2>> -// CHECK-NEXT: return 
%[[v0]] : tensor<4x!FHE.eint<2>> -// CHECK-NEXT: } -func @zero_1D() -> tensor<4x!FHE.eint<2>> { - %0 = "FHELinalg.zero"() : () -> tensor<4x!FHE.eint<2>> - return %0 : tensor<4x!FHE.eint<2>> -} - -// CHECK: func @zero_2D() -> tensor<4x9x!FHE.eint<2>> { -// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.zero"() : () -> tensor<4x9x!FHE.eint<2>> -// CHECK-NEXT: return %[[v0]] : tensor<4x9x!FHE.eint<2>> -// CHECK-NEXT: } -func @zero_2D() -> tensor<4x9x!FHE.eint<2>> { - %0 = "FHELinalg.zero"() : () -> tensor<4x9x!FHE.eint<2>> - return %0 : tensor<4x9x!FHE.eint<2>> -} diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/tensor-ops-to-linalg.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/tensor-ops-to-linalg.mlir index 4d0290dac..7f2be19f2 100644 --- a/compiler/tests/Dialect/FHELinalg/FHELinalg/tensor-ops-to-linalg.mlir +++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/tensor-ops-to-linalg.mlir @@ -1,22 +1,22 @@ // RUN: concretecompiler %s --action=dump-tfhe 2>&1 | FileCheck %s -//CHECK: #map0 = affine_map<(d0) -> (d0)> -//CHECK-NEXT: #map1 = affine_map<(d0) -> (0)> -//CHECK-NEXT: module { -//CHECK-NEXT: func @dot_eint_int(%arg0: tensor<2x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>) -> !TFHE.glwe<{_,_,_}{2}> { -//CHECK-NEXT: %0 = "TFHE.zero"() : () -> !TFHE.glwe<{_,_,_}{2}> -//CHECK-NEXT: %1 = tensor.from_elements %0 : tensor<1x!TFHE.glwe<{_,_,_}{2}>> -//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!TFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%1 : tensor<1x!TFHE.glwe<{_,_,_}{2}>>) { -//CHECK-NEXT: ^bb0(%arg2: !TFHE.glwe<{_,_,_}{2}>, %arg3: i3, %arg4: !TFHE.glwe<{_,_,_}{2}>): // no predecessors -//CHECK-NEXT: %4 = "TFHE.mul_glwe_int"(%arg2, %arg3) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}> -//CHECK-NEXT: %5 = "TFHE.add_glwe"(%4, %arg4) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}> -//CHECK-NEXT: linalg.yield %5 : !TFHE.glwe<{_,_,_}{2}> -//CHECK-NEXT: 
} -> tensor<1x!TFHE.glwe<{_,_,_}{2}>> -//CHECK-NEXT: %c0 = arith.constant 0 : index -//CHECK-NEXT: %3 = tensor.extract %2[%c0] : tensor<1x!TFHE.glwe<{_,_,_}{2}>> -//CHECK-NEXT: return %3 : !TFHE.glwe<{_,_,_}{2}> -//CHECK-NEXT: } -//CHECK-NEXT: } +// CHECK: #map0 = affine_map<(d0) -> (d0)> +// CHECK-NEXT: #map1 = affine_map<(d0) -> (0)> +// CHECK-NEXT: module { +// CHECK-NEXT: func @dot_eint_int(%arg0: tensor<2x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>) -> !TFHE.glwe<{_,_,_}{2}> { +// CHECK-NEXT: %0 = "TFHE.zero_tensor"() : () -> tensor<1x!TFHE.glwe<{_,_,_}{2}>> +// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!TFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%0 : tensor<1x!TFHE.glwe<{_,_,_}{2}>>) { +// CHECK-NEXT: ^bb0(%arg2: !TFHE.glwe<{_,_,_}{2}>, %arg3: i3, %arg4: !TFHE.glwe<{_,_,_}{2}>): // no predecessors +// CHECK-NEXT: %3 = "TFHE.mul_glwe_int"(%arg2, %arg3) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}> +// CHECK-NEXT: %4 = "TFHE.add_glwe"(%3, %arg4) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}> +// CHECK-NEXT: linalg.yield %4 : !TFHE.glwe<{_,_,_}{2}> +// CHECK-NEXT: } -> tensor<1x!TFHE.glwe<{_,_,_}{2}>> +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %2 = tensor.extract %1[%c0] : tensor<1x!TFHE.glwe<{_,_,_}{2}>> +// CHECK-NEXT: return %2 : !TFHE.glwe<{_,_,_}{2}> +// CHECK-NEXT: } +// CHECK-NEXT: } + func @dot_eint_int(%arg0: tensor<2x!FHE.eint<2>>, %arg1: tensor<2xi3>) -> !FHE.eint<2> { diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/tiling.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/tiling.mlir index 4d74250e2..c936a3c43 100644 --- a/compiler/tests/Dialect/FHELinalg/FHELinalg/tiling.mlir +++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/tiling.mlir @@ -5,7 +5,7 @@ // CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index // CHECK-NEXT: %[[Vc8:.*]] = arith.constant 8 : index // CHECK-NEXT: %[[Vc4:.*]] = 
arith.constant 4 : index -// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.zero"() : () -> tensor<8x2x!FHE.eint<6>> +// CHECK-NEXT: %[[V0:.*]] = "FHE.zero_tensor"() : () -> tensor<8x2x!FHE.eint<6>> // CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc8]] step %[[Vc2]] iter_args(%[[Varg3:.*]] = %[[V0]]) -> (tensor<8x2x!FHE.eint<6>>) { // CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg4:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc2]] iter_args(%[[Varg5:.*]] = %[[Varg3]]) -> (tensor<8x2x!FHE.eint<6>>) { // CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg6:.*]] = %[[Vc0]] to %[[Vc2]] step %[[Vc2]] iter_args(%[[Varg7:.*]] = %[[Varg5]]) -> (tensor<8x2x!FHE.eint<6>>) { @@ -35,7 +35,7 @@ func @tiled_2x2(%a: tensor<8x4x!FHE.eint<6>>, %b: tensor<4x2xi7>) -> tensor<8x2x // CHECK-NEXT: %[[Vc4:.*]] = arith.constant 4 : index // CHECK-NEXT: %[[Vc2:.*]] = arith.constant 2 : index // CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.zero"() : () -> tensor<8x2x!FHE.eint<6>> +// CHECK-NEXT: %[[V0:.*]] = "FHE.zero_tensor"() : () -> tensor<8x2x!FHE.eint<6>> // CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc8]] step %[[Vc8]] iter_args(%[[Varg3:.*]] = %[[V0]]) -> (tensor<8x2x!FHE.eint<6>>) { // CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg4:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc4]] iter_args(%[[Varg5:.*]] = %[[Varg3]]) -> (tensor<8x2x!FHE.eint<6>>) { // CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg6:.*]] = %[[Vc0]] to %[[Vc2]] step %[[Vc2]] iter_args(%[[Varg7:.*]] = %[[Varg5]]) -> (tensor<8x2x!FHE.eint<6>>) { diff --git a/compiler/tests/unittest/end_to_end_jit_fhe.cc b/compiler/tests/unittest/end_to_end_jit_fhe.cc index 82e9f03a4..8484553bf 100644 --- a/compiler/tests/unittest/end_to_end_jit_fhe.cc +++ b/compiler/tests/unittest/end_to_end_jit_fhe.cc @@ -23,6 +23,43 @@ func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { ASSERT_EXPECTED_VALUE(lambda(8_u64), 8); } +// FHE.zero_tensor //////////////////////////////////////////////////////////// + 
+TEST(End2EndJit_FHE, zero_tensor) { + mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main() -> tensor<2x2x4x!FHE.eint<6>> { + %0 = "FHE.zero_tensor"() : () -> tensor<2x2x4x!FHE.eint<6>> + return %0 : tensor<2x2x4x!FHE.eint<6>> +} +)XXX"); + + llvm::Expected> res = + lambda.operator()>(); + + ASSERT_EXPECTED_SUCCESS(res); + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument<>> &resp = + (*res) + ->cast>>(); + + ASSERT_EQ(resp.getDimensions().size(), (size_t)3); + ASSERT_EQ(resp.getDimensions().at(0), 2); + ASSERT_EQ(resp.getDimensions().at(1), 2); + ASSERT_EQ(resp.getDimensions().at(2), 4); + ASSERT_EXPECTED_VALUE(resp.getNumElements(), 2 * 2 * 4); + + for (size_t i = 0; i < 2; i++) { + for (size_t j = 0; j < 2; j++) { + for (size_t k = 0; k < 4; k++) { + EXPECT_EQ(resp.getValue()[i * 8 + j * 4 + k], 0) + << ", at pos(" << i << "," << j << "," << k << ")"; + } + } + } +} + // FHE.add_eint_int ///////////////////////////////////////////////////////// TEST(End2EndJit_FHE, add_eint_int_cst) { diff --git a/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc index 2df3c8c4c..f82647ac9 100644 --- a/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc @@ -1541,44 +1541,6 @@ func @main(%a: tensor<2x8x!FHE.eint<6>>) -> tensor<2x2x4x!FHE.eint<6>> { } } -/////////////////////////////////////////////////////////////////////////////// -// FHELinalg.zero /////////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////// - -TEST(End2EndJit_Linalg, zero) { - mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( -func @main() -> tensor<2x2x4x!FHE.eint<6>> { - %0 = "FHELinalg.zero"() : () -> tensor<2x2x4x!FHE.eint<6>> - return %0 : tensor<2x2x4x!FHE.eint<6>> -} -)XXX"); - - llvm::Expected> res = - lambda.operator()>(); - - 
ASSERT_EXPECTED_SUCCESS(res); - - mlir::concretelang::TensorLambdaArgument> - &resp = (*res) - ->cast>>(); - - ASSERT_EQ(resp.getDimensions().size(), (size_t)3); - ASSERT_EQ(resp.getDimensions().at(0), 2); - ASSERT_EQ(resp.getDimensions().at(1), 2); - ASSERT_EQ(resp.getDimensions().at(2), 4); - ASSERT_EXPECTED_VALUE(resp.getNumElements(), 2 * 2 * 4); - - for (size_t i = 0; i < 2; i++) { - for (size_t j = 0; j < 2; j++) { - for (size_t k = 0; k < 4; k++) { - EXPECT_EQ(resp.getValue()[i * 8 + j * 4 + k], 0) - << ", at pos(" << i << "," << j << "," << k << ")"; - } - } - } -} - /////////////////////////////////////////////////////////////////////////////// // FHELinalg sum ///////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////