From 95d49f465753a99891c9f6bc7acbeaccdcc7fd52 Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 17 Jan 2023 11:16:09 +0100 Subject: [PATCH] feat: add boolean types/ops in FHE dialect --- .../concretelang/Dialect/FHE/IR/FHEOps.td | 176 ++++++++++++++++++ .../concretelang/Dialect/FHE/IR/FHETypes.td | 11 ++ compiler/lib/Dialect/FHE/IR/FHEOps.cpp | 24 +++ .../check_tests/Dialect/FHE/ops.invalid.mlir | 24 +++ .../tests/check_tests/Dialect/FHE/ops.mlir | 81 ++++++++ 5 files changed, 316 insertions(+) diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index ca6930f31..d626ac809 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -361,4 +361,180 @@ def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [NoSideEffect]> { let hasVerifier = 1; } +// FHE Boolean Operations + +def FHE_GenGateOp : FHE_Op<"gen_gate", [NoSideEffect]> { + + let summary = "Applies a truth table based on two boolean inputs"; + + let description = [{ + Applies a truth table based on two boolean inputs. + + truth table must be a tensor of 4 boolean values. + + Example: + ```mlir + // ok + "FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<4xi1>) -> (!FHE.ebool) + + // error + "FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<4xi4>) -> (!FHE.ebool) + "FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<7xi1>) -> (!FHE.ebool) + ``` + }]; + + let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right, TensorOf<[I1]>:$truth_table); + let results = (outs FHE_EncryptedBooleanType); + let hasVerifier = 1; +} + +def FHE_MuxOp : FHE_Op<"mux", [NoSideEffect]> { + + let summary = "Multiplexer for two encrypted boolean inputs, based on an encrypted condition"; + + let description = [{ + Mutex between two encrypted boolean inputs, based on an encrypted condition. + + Example: + ```mlir + "FHE.mux"(%cond, %c1, %c2): (!FHE.ebool, !FHE.ebool, !FHE.ebool) -> (!FHE.ebool) + ``` + }]; + + let arguments = (ins FHE_EncryptedBooleanType:$cond, FHE_EncryptedBooleanType:$c1, FHE_EncryptedBooleanType:$c2); + let results = (outs FHE_EncryptedBooleanType); +} + + + +def FHE_BoolAndOp : FHE_Op<"and", [NoSideEffect]> { + + let summary = "Applies an AND gate to two encrypted boolean values"; + + let description = [{ + Applies an AND gate to two encrypted boolean values. + + Example: + ```mlir + "FHE.and"(%a, %b): (!FHE.ebool, !FHE.ebool) -> (!FHE.ebool) + ``` + }]; + + let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right); + let results = (outs FHE_EncryptedBooleanType); +} + +def FHE_BoolOrOp : FHE_Op<"or", [NoSideEffect]> { + + let summary = "Applies an OR gate to two encrypted boolean values"; + + let description = [{ + Applies an OR gate to two encrypted boolean values. + + Example: + ```mlir + "FHE.or"(%a, %b): (!FHE.ebool, !FHE.ebool) -> (!FHE.ebool) + ``` + }]; + + let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right); + let results = (outs FHE_EncryptedBooleanType); +} + +def FHE_BoolNandOp : FHE_Op<"nand", [NoSideEffect]> { + + let summary = "Applies a NAND gate to two encrypted boolean values"; + + let description = [{ + Applies a NAND gate to two encrypted boolean values. + + Example: + ```mlir + "FHE.nand"(%a, %b): (!FHE.ebool, !FHE.ebool) -> (!FHE.ebool) + ``` + }]; + + let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right); + let results = (outs FHE_EncryptedBooleanType); +} + +def FHE_BoolXorOp : FHE_Op<"xor", [NoSideEffect]> { + + let summary = "Applies a XOR gate to two encrypted boolean values"; + + let description = [{ + Applies a XOR gate to two encrypted boolean values. + + Example: + ```mlir + "FHE.xor"(%a, %b): (!FHE.ebool, !FHE.ebool) -> (!FHE.ebool) + ``` + }]; + + let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right); + let results = (outs FHE_EncryptedBooleanType); +} + +def FHE_BoolNotOp : FHE_Op<"not", [NoSideEffect]> { + + let summary = "Applies a NOT gate to an encrypted boolean value"; + + let description = [{ + Applies a NOT gate to an encrypted boolean value. + + Example: + ```mlir + "FHE.not"(%a): (!FHE.ebool) -> (!FHE.ebool) + ``` + }]; + + let arguments = (ins FHE_EncryptedBooleanType:$value); + let results = (outs FHE_EncryptedBooleanType); +} + +def FHE_ToBoolOp : FHE_Op<"to_bool", [NoSideEffect]> { + let summary = "Cast an unsigned integer to a boolean"; + + let description = [{ + Cast an unsigned integer to a boolean. + + The input must necessarily be of width 1, in order to put that single + bit into a boolean. + + Examples: + ```mlir + // ok + "FHE.to_bool"(%x) : (!FHE.eint<1>) -> !FHE.ebool + + // error + "FHE.to_bool"(%x) : (!FHE.eint<2>) -> !FHE.ebool + ``` + }]; + + let arguments = (ins FHE_EncryptedIntegerType:$input); + let results = (outs FHE_EncryptedBooleanType); + + let hasVerifier = 1; +} + +def FHE_FromBoolOp : FHE_Op<"from_bool", [NoSideEffect]> { + let summary = "Cast a boolean to an unsigned integer"; + + let description = [{ + Cast a boolean to an unsigned integer. + + Examples: + ```mlir + "FHE.from_bool"(%x) : (!FHE.ebool) -> !FHE.eint<1> + "FHE.from_bool"(%x) : (!FHE.ebool) -> !FHE.eint<2> + "FHE.from_bool"(%x) : (!FHE.ebool) -> !FHE.eint<4> + ``` + }]; + + let arguments = (ins FHE_EncryptedBooleanType:$input); + let results = (outs FHE_EncryptedIntegerType); +} + + + #endif diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td b/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td index 815d72ed4..50f5138aa 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td @@ -69,4 +69,15 @@ def FHE_AnyEncryptedInteger : Type>; +def FHE_EncryptedBooleanType : FHE_Type<"EncryptedBoolean", + [MemRefElementTypeInterface]> { + let mnemonic = "ebool"; + + let summary = "An encrypted boolean"; + + let description = [{ + An encrypted boolean. + }]; +} + #endif diff --git a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index dd26a83d0..cfe13d8d6 100644 --- a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -203,6 +203,30 @@ mlir::LogicalResult ToUnsignedOp::verify() { return mlir::success(); } +mlir::LogicalResult ToBoolOp::verify() { + auto input = this->input().getType().cast(); + + if (input.getWidth() != 1) { + this->emitOpError( + "should have 1 as the width of encrypted input to cast to a boolean"); + return mlir::failure(); + } + + return mlir::success(); +} + +mlir::LogicalResult GenGateOp::verify() { + auto truth_table = this->truth_table().getType().cast(); + + mlir::SmallVector expectedShape{4}; + if (!truth_table.hasStaticShape(expectedShape)) { + this->emitOpError("truth table should be a tensor of 4 boolean values"); + return mlir::failure(); + } + + return mlir::success(); +} + ::mlir::LogicalResult ApplyLookupTableEintOp::verify() { auto ct = this->a().getType().cast(); auto lut = this->lut().getType().cast(); diff --git a/compiler/tests/check_tests/Dialect/FHE/ops.invalid.mlir b/compiler/tests/check_tests/Dialect/FHE/ops.invalid.mlir index 4ce6bd5cd..af4aae43e 100644 --- a/compiler/tests/check_tests/Dialect/FHE/ops.invalid.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/ops.invalid.mlir @@ -13,3 +13,27 @@ func.func @zero_plaintext() -> tensor<4x9xi32> { %0 = "FHE.zero_tensor"() : () -> tensor<4x9xi32> return %0 : tensor<4x9xi32> } + +// ----- + +func.func @to_bool(%arg0: !FHE.eint<2>) -> !FHE.ebool { + // expected-error @+1 {{'FHE.to_bool' op}} + %1 = "FHE.to_bool"(%arg0): (!FHE.eint<2>) -> (!FHE.ebool) + return %1: !FHE.ebool +} + +// ----- + +func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<5xi1>) -> !FHE.ebool { + // expected-error @+1 {{'FHE.gen_gate' op}} + %1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<5xi1>) -> !FHE.ebool + return %1: !FHE.ebool +} + +// ----- + +func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi2>) -> !FHE.ebool { + // expected-error @+1 {{'FHE.gen_gate' op}} + %1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi2>) -> !FHE.ebool + return %1: !FHE.ebool +} diff --git a/compiler/tests/check_tests/Dialect/FHE/ops.mlir b/compiler/tests/check_tests/Dialect/FHE/ops.mlir index 517b6d9f8..567902bc2 100644 --- a/compiler/tests/check_tests/Dialect/FHE/ops.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/ops.mlir @@ -213,3 +213,84 @@ func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE %1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<2>) return %1: !FHE.eint<2> } + +// CHECK-LABEL: func.func @to_bool(%arg0: !FHE.eint<1>) -> !FHE.ebool +func.func @to_bool(%arg0: !FHE.eint<1>) -> !FHE.ebool { + // CHECK-NEXT: %[[V1:.*]] = "FHE.to_bool"(%arg0) : (!FHE.eint<1>) -> !FHE.ebool + // CHECK-NEXT: return %[[V1]] : !FHE.ebool + + %1 = "FHE.to_bool"(%arg0): (!FHE.eint<1>) -> (!FHE.ebool) + return %1: !FHE.ebool +} + +// CHECK-LABEL: func.func @from_bool(%arg0: !FHE.ebool) -> !FHE.eint<1> +func.func @from_bool(%arg0: !FHE.ebool) -> !FHE.eint<1> { + // CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<1> + // CHECK-NEXT: return %[[V1]] : !FHE.eint<1> + + %1 = "FHE.from_bool"(%arg0): (!FHE.ebool) -> (!FHE.eint<1>) + return %1: !FHE.eint<1> +} + +// CHECK-LABEL: func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi1>) -> !FHE.ebool +func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi1>) -> !FHE.ebool { + // CHECK-NEXT: %[[V1:.*]] = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi1>) -> !FHE.ebool + // CHECK-NEXT: return %[[V1]] : !FHE.ebool + + %1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi1>) -> !FHE.ebool + return %1: !FHE.ebool +} + +// CHECK-LABEL: func.func @mux(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: !FHE.ebool) -> !FHE.ebool +func.func @mux(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: !FHE.ebool) -> !FHE.ebool { + // CHECK-NEXT: %[[V1:.*]] = "FHE.mux"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, !FHE.ebool) -> !FHE.ebool + // CHECK-NEXT: return %[[V1]] : !FHE.ebool + + %1 = "FHE.mux"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, !FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool +} + +// CHECK-LABEL: func.func @and(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool +func.func @and(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { + // CHECK-NEXT: %[[V1:.*]] = "FHE.and"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + // CHECK-NEXT: return %[[V1]] : !FHE.ebool + + %1 = "FHE.and"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool +} + +// CHECK-LABEL: func.func @or(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool +func.func @or(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { + // CHECK-NEXT: %[[V1:.*]] = "FHE.or"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + // CHECK-NEXT: return %[[V1]] : !FHE.ebool + + %1 = "FHE.or"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool +} + +// CHECK-LABEL: func.func @nand(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool +func.func @nand(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { + // CHECK-NEXT: %[[V1:.*]] = "FHE.nand"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + // CHECK-NEXT: return %[[V1]] : !FHE.ebool + + %1 = "FHE.nand"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool +} + +// CHECK-LABEL: func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool +func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { + // CHECK-NEXT: %[[V1:.*]] = "FHE.xor"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + // CHECK-NEXT: return %[[V1]] : !FHE.ebool + + %1 = "FHE.xor"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool +} + +// CHECK-LABEL: func.func @not(%arg0: !FHE.ebool) -> !FHE.ebool +func.func @not(%arg0: !FHE.ebool) -> !FHE.ebool { + // CHECK-NEXT: %[[V1:.*]] = "FHE.not"(%arg0) : (!FHE.ebool) -> !FHE.ebool + // CHECK-NEXT: return %[[V1]] : !FHE.ebool + + %1 = "FHE.not"(%arg0) : (!FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool +}