diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/TypeInference/inference.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/TypeInference/inference.mlir new file mode 100644 index 000000000..732c4c304 --- /dev/null +++ b/compilers/concrete-compiler/compiler/tests/check_tests/TypeInference/inference.mlir @@ -0,0 +1,259 @@ +// RUN: concretecompiler --split-input-file --action=dump-parametrized-tfhe --optimizer-strategy=dag-multi %s 2>&1| FileCheck %s + +// CHECK: func.func @funconly_fwd(%arg0: !TFHE.glwe>) -> !TFHE.glwe> { +// CHECK-NEXT: return %arg0 : !TFHE.glwe> +// CHECK-NEXT: } +func.func @funconly_fwd(%arg0: !TFHE.glwe>) -> !TFHE.glwe { + %a0 = "TypeInference.propagate_downward"(%arg0) : (!TFHE.glwe>) -> (!TFHE.glwe) + return %a0: !TFHE.glwe +} + +// ----- + +// CHECK: func.func @funconly_bwd(%arg0: !TFHE.glwe>) -> !TFHE.glwe> { +// CHECK-NEXT: return %arg0 : !TFHE.glwe> +// CHECK-NEXT: } +func.func @funconly_bwd(%arg0: !TFHE.glwe) -> !TFHE.glwe> { + %a0 = "TypeInference.propagate_upward"(%arg0) : (!TFHE.glwe) -> (!TFHE.glwe>) + return %a0: !TFHE.glwe> +} + +// ----- + +// CHECK: func.func @funconly_fwd_multires(%arg0: !TFHE.glwe>, %arg1: !TFHE.glwe>) -> (!TFHE.glwe>, !TFHE.glwe>) { +// CHECK-NEXT: return %arg0, %arg1 : !TFHE.glwe>, !TFHE.glwe> +// CHECK-NEXT: } +func.func @funconly_fwd_multires(%arg0: !TFHE.glwe>, %arg1: !TFHE.glwe>) -> (!TFHE.glwe, !TFHE.glwe) { + %a0 = "TypeInference.propagate_downward"(%arg0) : (!TFHE.glwe>) -> (!TFHE.glwe) + %a1 = "TypeInference.propagate_downward"(%arg1) : (!TFHE.glwe>) -> (!TFHE.glwe) + return %a0, %a1: !TFHE.glwe, !TFHE.glwe +} + +// ----- + +// CHECK: func.func @fwd1(%arg0: !TFHE.glwe>, %arg1: !TFHE.glwe>) -> !TFHE.glwe> { +// CHECK-NEXT: %0 = "TFHE.add_glwe"(%arg0, %arg1) : (!TFHE.glwe>, !TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: return %0 : !TFHE.glwe> +// CHECK-NEXT: } +func.func @fwd1(%arg0: !TFHE.glwe>, %arg1: !TFHE.glwe>) -> !TFHE.glwe { + %a0 = "TypeInference.propagate_downward"(%arg0) : (!TFHE.glwe>) -> (!TFHE.glwe) + %a1 = "TypeInference.propagate_downward"(%arg1) : (!TFHE.glwe>) -> (!TFHE.glwe) + + %0 = "TFHE.add_glwe"(%a0, %a1): (!TFHE.glwe, !TFHE.glwe) -> (!TFHE.glwe) + return %0: !TFHE.glwe +} + +// ----- + +// CHECK: func.func @fwd2(%arg0: !TFHE.glwe>, %arg1: !TFHE.glwe>, %arg2: !TFHE.glwe>) -> !TFHE.glwe> { +// CHECK-NEXT: %0 = "TFHE.add_glwe"(%arg0, %arg1) : (!TFHE.glwe>, !TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: %1 = "TFHE.add_glwe"(%0, %arg2) : (!TFHE.glwe>, !TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: return %1 : !TFHE.glwe> +// CHECK-NEXT: } +func.func @fwd2(%arg0: !TFHE.glwe>, %arg1: !TFHE.glwe>, %arg2: !TFHE.glwe>) -> !TFHE.glwe { + %a0 = "TypeInference.propagate_downward"(%arg0) : (!TFHE.glwe>) -> (!TFHE.glwe) + %a1 = "TypeInference.propagate_downward"(%arg1) : (!TFHE.glwe>) -> (!TFHE.glwe) + %a2 = "TypeInference.propagate_downward"(%arg2) : (!TFHE.glwe>) -> (!TFHE.glwe) + + %0 = "TFHE.add_glwe"(%a0, %a1): (!TFHE.glwe, !TFHE.glwe) -> (!TFHE.glwe) + %1 = "TFHE.add_glwe"(%0, %a2): (!TFHE.glwe, !TFHE.glwe) -> (!TFHE.glwe) + + return %1: !TFHE.glwe +} + +// ----- + +// CHECK: func.func @for1(%arg0: index) -> !TFHE.glwe> { +// CHECK-NEXT: %c2 = arith.constant 2 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %1 = scf.for %arg1 = %c0 to %c2 step %c1 iter_args(%arg2 = %0) -> (tensor<2x!TFHE.glwe>>) { +// CHECK-NEXT: %2 = "TFHE.zero"() : () -> !TFHE.glwe> +// CHECK-NEXT: %inserted = tensor.insert %2 into %arg2[%arg1] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: scf.yield %inserted : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } +// CHECK-NEXT: %extracted = tensor.extract %1[%c0] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: return %extracted : !TFHE.glwe> +// CHECK-NEXT: } +func.func @for1(%idx: index) -> !TFHE.glwe { + %0 = bufferization.alloc_tensor() : tensor<2x!TFHE.glwe> + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %1 = scf.for %i = %c0 to %c2 step %c1 iter_args(%iterarg = %0) -> (tensor<2x!TFHE.glwe>) { + %2 = "TFHE.zero"(): () -> (!TFHE.glwe>) + %a = "TypeInference.propagate_downward"(%2) : (!TFHE.glwe>) -> (!TFHE.glwe) + %3 = tensor.insert %a into %iterarg[%i] : tensor<2x!TFHE.glwe> + scf.yield %3 : tensor<2x!TFHE.glwe> + } + + %4 = tensor.extract %1[%c0] : tensor<2x!TFHE.glwe> + + return %4: !TFHE.glwe +} + +// ----- + +// CHECK: func.func @for2(%arg0: tensor<2x!TFHE.glwe>>, %arg1: tensor<2x!TFHE.glwe>>) -> tensor<2x!TFHE.glwe>> { +// CHECK-NEXT: %c2 = arith.constant 2 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %1 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0) -> (tensor<2x!TFHE.glwe>>) { +// CHECK-NEXT: %extracted = tensor.extract %arg0[%arg2] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %extracted_0 = tensor.extract %arg1[%arg2] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %2 = "TFHE.add_glwe"(%extracted, %extracted_0) : (!TFHE.glwe>, !TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: %inserted = tensor.insert %2 into %arg3[%arg2] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: scf.yield %inserted : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } +// CHECK-NEXT: return %1 : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } +func.func @for2(%arg0: tensor<2x!TFHE.glwe>>, %arg1: tensor<2x!TFHE.glwe>>) -> tensor<2x!TFHE.glwe> { + %a0 = "TypeInference.propagate_downward"(%arg0) : (tensor<2x!TFHE.glwe>>) -> (tensor<2x!TFHE.glwe>) + %a1 = "TypeInference.propagate_downward"(%arg1) : (tensor<2x!TFHE.glwe>>) -> (tensor<2x!TFHE.glwe>) + + %0 = bufferization.alloc_tensor() : tensor<2x!TFHE.glwe> + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %1 = scf.for %i = %c0 to %c2 step %c1 iter_args(%iterarg = %0) -> (tensor<2x!TFHE.glwe>) { + %2 = tensor.extract %a0[%i] : tensor<2x!TFHE.glwe> + %3 = tensor.extract %a1[%i] : tensor<2x!TFHE.glwe> + %4 = "TFHE.add_glwe"(%2, %3): (!TFHE.glwe, !TFHE.glwe) -> (!TFHE.glwe) + %5 = tensor.insert %4 into %iterarg[%i] : tensor<2x!TFHE.glwe> + scf.yield %5 : tensor<2x!TFHE.glwe> + } + + return %1: tensor<2x!TFHE.glwe> +} + +// ----- + +// CHECK: func.func @for3(%arg0: tensor<2x!TFHE.glwe>>, %arg1: tensor<2x!TFHE.glwe>>) -> tensor<2x!TFHE.glwe>> { +// CHECK-NEXT: %c2 = arith.constant 2 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %1 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0) -> (tensor<2x!TFHE.glwe>>) { +// CHECK-NEXT: %extracted = tensor.extract %arg0[%arg2] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %extracted_0 = tensor.extract %arg1[%arg2] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %2 = "TFHE.add_glwe"(%extracted, %extracted_0) : (!TFHE.glwe>, !TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: %inserted = tensor.insert %2 into %arg3[%arg2] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: scf.yield %inserted : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } +// CHECK-NEXT: return %1 : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } + +func.func @for3(%arg0: tensor<2x!TFHE.glwe>, %arg1: tensor<2x!TFHE.glwe>) -> tensor<2x!TFHE.glwe>> { + %0 = bufferization.alloc_tensor() : tensor<2x!TFHE.glwe> + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %1 = scf.for %i = %c0 to %c2 step %c1 iter_args(%iterarg = %0) -> (tensor<2x!TFHE.glwe>) { + %2 = tensor.extract %arg0[%i] : tensor<2x!TFHE.glwe> + %3 = tensor.extract %arg1[%i] : tensor<2x!TFHE.glwe> + %4 = "TFHE.add_glwe"(%2, %3): (!TFHE.glwe, !TFHE.glwe) -> (!TFHE.glwe) + %5 = tensor.insert %4 into %iterarg[%i] : tensor<2x!TFHE.glwe> + scf.yield %5 : tensor<2x!TFHE.glwe> + } + + %r = "TypeInference.propagate_upward"(%1) : (tensor<2x!TFHE.glwe>) -> (tensor<2x!TFHE.glwe>>) + + return %r: tensor<2x!TFHE.glwe>> +} + +// ----- + +// CHECK: func.func @for4(%arg0: tensor<2x!TFHE.glwe>>, %arg1: tensor<2x!TFHE.glwe>>) -> tensor<2x!TFHE.glwe>> { +// CHECK-NEXT: %c2 = arith.constant 2 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %1 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0) -> (tensor<2x!TFHE.glwe>>) { +// CHECK-NEXT: %2 = scf.for %arg4 = %c0 to %c2 step %c1 iter_args(%arg5 = %arg3) -> (tensor<2x!TFHE.glwe>>) { +// CHECK-NEXT: %extracted = tensor.extract %arg0[%arg2] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %extracted_0 = tensor.extract %arg1[%arg2] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %3 = "TFHE.add_glwe"(%extracted, %extracted_0) : (!TFHE.glwe>, !TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: %inserted = tensor.insert %3 into %arg5[%arg4] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: scf.yield %inserted : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %2 : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } +// CHECK-NEXT: return %1 : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } + +func.func @for4(%arg0: tensor<2x!TFHE.glwe>, %arg1: tensor<2x!TFHE.glwe>) -> tensor<2x!TFHE.glwe>> { + %0 = bufferization.alloc_tensor() : tensor<2x!TFHE.glwe> + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %1 = scf.for %i = %c0 to %c2 step %c1 iter_args(%iterarg0 = %0) -> (tensor<2x!TFHE.glwe>) { + %2 = scf.for %j = %c0 to %c2 step %c1 iter_args(%iterarg1 = %iterarg0) -> (tensor<2x!TFHE.glwe>) { + %3 = tensor.extract %arg0[%i] : tensor<2x!TFHE.glwe> + %4 = tensor.extract %arg1[%i] : tensor<2x!TFHE.glwe> + %5 = "TFHE.add_glwe"(%3, %4): (!TFHE.glwe, !TFHE.glwe) -> (!TFHE.glwe) + %6 = tensor.insert %5 into %iterarg1[%j] : tensor<2x!TFHE.glwe> + scf.yield %6 : tensor<2x!TFHE.glwe> + } + + scf.yield %2 : tensor<2x!TFHE.glwe> + } + + %r = "TypeInference.propagate_upward"(%1) : (tensor<2x!TFHE.glwe>) -> (tensor<2x!TFHE.glwe>>) + + return %r: tensor<2x!TFHE.glwe>> +} + +// ----- + +// CHECK: func.func @for5(%arg0: tensor<2x!TFHE.glwe>>, %arg1: tensor<2x!TFHE.glwe>>) -> tensor<2x!TFHE.glwe>> { +// CHECK-NEXT: %c2 = arith.constant 2 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %1 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0) -> (tensor<2x!TFHE.glwe>>) { +// CHECK-NEXT: %2 = scf.for %arg4 = %c0 to %c2 step %c1 iter_args(%arg5 = %arg3) -> (tensor<2x!TFHE.glwe>>) { +// CHECK-NEXT: %extracted = tensor.extract %arg0[%arg2] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %extracted_0 = tensor.extract %arg1[%arg2] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: %3 = "TFHE.add_glwe"(%extracted, %extracted_0) : (!TFHE.glwe>, !TFHE.glwe>) -> !TFHE.glwe> +// CHECK-NEXT: %inserted = tensor.insert %3 into %arg5[%arg4] : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: scf.yield %inserted : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %2 : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } +// CHECK-NEXT: return %1 : tensor<2x!TFHE.glwe>> +// CHECK-NEXT: } + +func.func @for5(%arg0: tensor<2x!TFHE.glwe>, %arg1: tensor<2x!TFHE.glwe>) -> tensor<2x!TFHE.glwe> { + %0 = bufferization.alloc_tensor() : tensor<2x!TFHE.glwe> + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %1 = scf.for %i = %c0 to %c2 step %c1 iter_args(%iterarg0 = %0) -> (tensor<2x!TFHE.glwe>) { + %2 = scf.for %j = %c0 to %c2 step %c1 iter_args(%iterarg1 = %iterarg0) -> (tensor<2x!TFHE.glwe>) { + %3 = tensor.extract %arg0[%i] : tensor<2x!TFHE.glwe> + %4 = tensor.extract %arg1[%i] : tensor<2x!TFHE.glwe> + %5 = "TFHE.add_glwe"(%3, %4): (!TFHE.glwe, !TFHE.glwe) -> (!TFHE.glwe) + %fivea = "TypeInference.propagate_upward"(%5) : (!TFHE.glwe) -> (!TFHE.glwe>) + %iterarg1a = "TypeInference.propagate_upward"(%iterarg1) : (tensor<2x!TFHE.glwe>) -> (tensor<2x!TFHE.glwe>>) + %6 = tensor.insert %fivea into %iterarg1a[%j] : tensor<2x!TFHE.glwe>> + %sixa = "TypeInference.propagate_downward"(%6) : (tensor<2x!TFHE.glwe>>) -> (tensor<2x!TFHE.glwe>) + scf.yield %sixa : tensor<2x!TFHE.glwe> + } + + scf.yield %2 : tensor<2x!TFHE.glwe> + } + + return %1: tensor<2x!TFHE.glwe> +}