From d28bf3767b11519dcea62fff20aa29acde0e959a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20P=C3=A9r=C3=A9?= Date: Thu, 14 Sep 2023 09:44:38 +0200 Subject: [PATCH] feat(compiler): adds support for dynamic luts in fhelinalg --- .../lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp | 42 ++++++++++++++----- .../Dialect/FHELinalg/ops.invalid.mlir | 30 ++++++++++--- ...nd_to_end_linalg_apply_lookup_table_gen.py | 20 +++++++-- .../tests_cpu/end_to_end_fhelinalg.yaml | 24 +++++------ 4 files changed, 85 insertions(+), 31 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index f5e08169b..8c241f06f 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -283,11 +283,13 @@ mlir::LogicalResult ApplyLookupTableEintOp::verify() { // Check the shape of lut argument auto tEltwidth = tEltTy.getWidth(); mlir::SmallVector expectedShape{1 << tEltwidth}; - if (!lutTy.hasStaticShape(expectedShape) || !lutEltTy.isInteger(64)) { - this->emitOpError() - << "should have as operand #2 a tensor<2^pxi64>, where p is the width " - "of the encrypted integer of the operand #1," - << "expect tensor <" << expectedShape[0] << "xi64>"; + if (!lutTy.hasStaticShape(expectedShape) || !lutEltTy.isSignlessInteger() || + lutEltTy.getIntOrFloatBitWidth() > 64) { + this->emitOpError() << "should have as operand #2 a " + "tensor<2^pxi{8,16,32,64}>, where p is the width " + "of the encrypted integer of the operand #1," + << "expect tensor <" << expectedShape[0] + << "xi{8,16,32,64}>"; return mlir::failure(); } if (!resultTy.hasStaticShape(tTy.getShape())) { @@ -308,12 +310,14 @@ mlir::LogicalResult ApplyMultiLookupTableEintOp::verify() { // Check the shape of luts argument auto lut_size = lutTy.getShape()[lutTy.getShape().size() - 1]; auto expected_lut_size = 1 << tEltTy.getWidth(); - if (lut_size != expected_lut_size || !lutEltTy.isInteger(64)) { - this->emitOpError() << "should have as operand #2 a " - "tensor, where p is the width " - "of the encrypted integer of the operand #1," - << "expect tensor "; + if (lut_size != expected_lut_size || !lutEltTy.isSignlessInteger() || + lutEltTy.getIntOrFloatBitWidth() > 64) { + this->emitOpError() + << "should have as operand #2 a " + "tensor, where p is the width " + "of the encrypted integer of the operand #1," + << "expect tensor "; return mlir::failure(); } if (!resultTy.hasStaticShape(tTy.getShape())) { @@ -380,9 +384,14 @@ mlir::LogicalResult verifyLutsSize(ApplyMappedLookupTableEintOp &op, mlir::LogicalResult ApplyMappedLookupTableEintOp::verify() { auto t = this->getT(); + auto tTy = this->getT().getType().cast(); + auto tEltTy = + tTy.getElementType().cast(); auto luts = this->getLuts(); auto map = this->getMap(); auto result = this->getResult(); + auto lutTy = this->getLuts().getType().cast(); + auto lutEltTy = lutTy.getElementType().cast(); auto t_shape = getTensorType(t).getShape(); if (!getTensorType(result).hasStaticShape(t_shape)) { @@ -397,6 +406,17 @@ mlir::LogicalResult ApplyMappedLookupTableEintOp::verify() { return mlir::failure(); } + auto expected_lut_size = 1 << tEltTy.getWidth(); + if (!lutEltTy.isSignlessInteger() || lutEltTy.getIntOrFloatBitWidth() > 64) { + this->emitOpError() + << "should have as operand #2 a " + "tensor, where p is the width " + "of the encrypted integer of the operand #1," + << "expect tensor "; + return mlir::failure(); + } + return mlir::success(verifyMapHasRightShape(*this, t, map).succeeded() && verifyLutsSize(*this, t, luts).succeeded()); } diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/ops.invalid.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/ops.invalid.mlir index cf85112b4..4ac547ec6 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/ops.invalid.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/ops.invalid.mlir @@ -164,16 +164,16 @@ func.func @main(%a0: tensor<2x3x4x!FHE.eint<2>>, %a1: tensor<2x3x4x!FHE.eint<3>> // FHELinalg.apply_lookup_table ///////////////////////////////////////////////// -func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<4xi32>) -> tensor<2x3x4x!FHE.eint<2>> { - // expected-error @+1 {{'FHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi64>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi64>}} - %1 = "FHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<4xi32>) -> (tensor<2x3x4x!FHE.eint<2>>) +func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<4xi65>) -> tensor<2x3x4x!FHE.eint<2>> { + // expected-error @+1 {{'FHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi{8,16,32,64}>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi{8,16,32,64}>}} + %1 = "FHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<4xi65>) -> (tensor<2x3x4x!FHE.eint<2>>) return %1: tensor<2x3x4x!FHE.eint<2>> } // ----- func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<12xi64>) -> tensor<2x3x4x!FHE.eint<2>> { - // expected-error @+1 {{'FHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi64>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi64>}} + // expected-error @+1 {{'FHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi{8,16,32,64}>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi{8,16,32,64}>}} %1 = "FHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<12xi64>) -> (tensor<2x3x4x!FHE.eint<2>>) return %1: tensor<2x3x4x!FHE.eint<2>> } @@ -193,13 +193,21 @@ func.func @apply_lookup_table(%arg0: tensor<3x4x!FHE.eint<2>>, %arg1: tensor<4xi ///////////////////////////////////////////////// func.func @apply_multi_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<2x6xi64>) -> tensor<2x3x4x!FHE.eint<2>> { - // expected-error @+1 {{'FHELinalg.apply_multi_lookup_table' op should have as operand #2 a tensor, where p is the width of the encrypted integer of the operand #1,expect tensor }} + // expected-error @+1 {{'FHELinalg.apply_multi_lookup_table' op should have as operand #2 a tensor, where p is the width of the encrypted integer of the operand #1,expect tensor }} %1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<2x6xi64>) -> (tensor<2x3x4x!FHE.eint<2>>) return %1: tensor<2x3x4x!FHE.eint<2>> } // ----- +func.func @apply_multi_lookup_table_bad_prec(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<2x4xi65>) -> tensor<2x3x4x!FHE.eint<2>> { + // expected-error @+1 {{'FHELinalg.apply_multi_lookup_table' op should have as operand #2 a tensor, where p is the width of the encrypted integer of the operand #1,expect tensor }} + %1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<2x4xi65>) -> (tensor<2x3x4x!FHE.eint<2>>) + return %1: tensor<2x3x4x!FHE.eint<2>> +} + +// ----- + ///////////////////////////////////////////////// // FHELinalg.apply_mapped_lookup_table @@ -240,6 +248,18 @@ func.func @apply_mapped_lookup_table_bad_map_elmt_type( // ----- +func.func @apply_mapped_lookup_table_bad_lut_prec( + %input: tensor<2x3x4x!FHE.eint<7>>, + %luts: tensor<128xi65>, + %map: tensor<2x3x4xindex> +) -> tensor<2x3x4x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.apply_mapped_lookup_table' op should have as operand #2 a tensor, where p is the width of the encrypted integer of the operand #1,expect tensor }} + %0 = "FHELinalg.apply_mapped_lookup_table"(%input, %luts, %map): (tensor<2x3x4x!FHE.eint<7>>, tensor<128xi65>, tensor<2x3x4xindex>) -> (tensor<2x3x4x!FHE.eint<7>>) + return %0: tensor<2x3x4x!FHE.eint<7>> +} + +// ----- + ///////////////////////////////////////////////// // FHELinalg.conv2d ///////////////////////////////////////////////// diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/end_to_end_linalg_apply_lookup_table_gen.py b/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/end_to_end_linalg_apply_lookup_table_gen.py index d80ba40e3..f969d5f7f 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/end_to_end_linalg_apply_lookup_table_gen.py +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/end_to_end_linalg_apply_lookup_table_gen.py @@ -7,6 +7,18 @@ from end_to_end_linalg_leveled_gen import P_ERROR PRECISION_FORCE_CRT = 9 +def get_lut_integer_type(p): + if p <= 8: + return "i8" + if p <= 16: + return "i16" + if p <= 32: + return "i32" + if p <= 64: + return "i64" + else: + raise Exception("Unexpected precision") + def generate(args): print("# /!\ DO NOT EDIT MANUALLY THIS FILE MANUALLY") print("# /!\ THIS FILE HAS BEEN GENERATED") @@ -16,15 +28,15 @@ def generate(args): for n_lut in args.n_lut: max_value = (2 ** p) - 1 random_lut = np.random.randint(max_value+1, size=2**p) + itype = get_lut_integer_type(p) # identity_apply_lookup_table print(f"description: apply_lookup_table_{p}bits_{n_ct}ct_{n_lut}layer") print("program: |") print( - f" func.func @main(%0: tensor<{n_ct}x!FHE.eint<{p}>>) -> tensor<{n_ct}x!FHE.eint<{p}>> {{") - print(f" %tlu = arith.constant dense<[{','.join(map(str, random_lut))}]> : tensor<{2**p}xi64>") + f" func.func @main(%0: tensor<{n_ct}x!FHE.eint<{p}>>, %tlu: tensor<{2**p}x{itype}>) -> tensor<{n_ct}x!FHE.eint<{p}>> {{") for i in range(0, n_lut): print(f" %{i+1} = \"FHELinalg.apply_lookup_table\"(%{i}, %tlu):") - print(f" (tensor<{n_ct}x!FHE.eint<{p}>>, tensor<{2**p}xi64>) -> (tensor<{n_ct}x!FHE.eint<{p}>>)") + print(f" (tensor<{n_ct}x!FHE.eint<{p}>>, tensor<{2**p}x{itype}>) -> (tensor<{n_ct}x!FHE.eint<{p}>>)") print(f" return %{n_lut}: tensor<{n_ct}x!FHE.eint<{p}>>") print(" }") if p >= PRECISION_FORCE_CRT: @@ -35,6 +47,8 @@ def generate(args): print(" - inputs:") print(f" - tensor: [{','.join(map(str, random_input))}]") print(f" shape: [{n_ct}]") + print(f" - tensor: [{','.join(map(str, random_lut))}]") + print(f" shape: [{2**p}]") outputs = random_input for i in range(0, n_lut): outputs = [random_lut[v] for v in outputs] diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml b/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml index 8bf4edc60..3d9d358e0 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml @@ -1034,8 +1034,8 @@ program: | // [3,0,1] lut [1,3,5,7] = [7,1,3] // [2,3,0] [5,7,1] func.func @main(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.eint<3>> { - %lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64> - %res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!FHE.eint<3>> + %lut = arith.constant dense<[1,3,5,7]> : tensor<4xi8> + %res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi8>) -> tensor<3x3x!FHE.eint<3>> return %res : tensor<3x3x!FHE.eint<3>> } tests: @@ -1050,8 +1050,8 @@ tests: description: apply_lookup_table_batched program: | func.func @main(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.eint<3>> { - %lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64> - %res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!FHE.eint<3>> + %lut = arith.constant dense<[1,3,5,7]> : tensor<4xi8> + %res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi8>) -> tensor<3x3x!FHE.eint<3>> return %res : tensor<3x3x!FHE.eint<3>> } tests: @@ -1066,8 +1066,8 @@ tests: description: apply_multi_lookup_table program: | // Returns the lookup of 3x3 matrix of encrypted indices of width 2 on a 3x3 matrix of tables of size 4=2² of clear integers. - func.func @main(%arg0: tensor<3x3x!FHE.eint<2>>, %arg1: tensor<3x3x4xi64>) -> tensor<3x3x!FHE.eint<2>> { - %1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!FHE.eint<2>>, tensor<3x3x4xi64>) -> tensor<3x3x!FHE.eint<2>> + func.func @main(%arg0: tensor<3x3x!FHE.eint<2>>, %arg1: tensor<3x3x4xi8>) -> tensor<3x3x!FHE.eint<2>> { + %1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!FHE.eint<2>>, tensor<3x3x4xi8>) -> tensor<3x3x!FHE.eint<2>> return %1: tensor<3x3x!FHE.eint<2>> } tests: @@ -1084,8 +1084,8 @@ tests: description: apply_multi_lookup_table_with_boradcast program: | // Returns the lookup of 3x3 matrix of encrypted indices of width 2 on a vector of 3 tables of size 4=2² of clear integers. - func.func @main(%arg0: tensor<3x3x!FHE.eint<2>>, %arg1: tensor<3x4xi64>) -> tensor<3x3x!FHE.eint<2>> { - %1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<3x3x!FHE.eint<2>> + func.func @main(%arg0: tensor<3x3x!FHE.eint<2>>, %arg1: tensor<3x4xi8>) -> tensor<3x3x!FHE.eint<2>> { + %1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!FHE.eint<2>>, tensor<3x4xi8>) -> tensor<3x3x!FHE.eint<2>> return %1: tensor<3x3x!FHE.eint<2>> } tests: @@ -1103,9 +1103,9 @@ tests: description: apply_mapped_lookup_table_sequential program: | // Returns the lookup of 3x3 matrix of encrypted indices of width 2 of a 3x3 matrix of tables of size 4=2² of clear integers. - func.func @main(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<9x4xi64>, %map: tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> { + func.func @main(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<9x4xi8>, %map: tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> { %1 = "FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) : - (tensor<3x3x!FHE.eint<2>>, tensor<9x4xi64>, tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> + (tensor<3x3x!FHE.eint<2>>, tensor<9x4xi8>, tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> return %1: tensor<3x3x!FHE.eint<2>> } tests: @@ -1124,9 +1124,9 @@ tests: description: apply_mapped_lookup_table_same_lut program: | // Returns the lookup of 3x3 matrix of encrypted indices of width 2 of a 3x3 matrix of tables of size 4=2² of clear integers. - func.func @main(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<9x4xi64>, %map: tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> { + func.func @main(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<9x4xi8>, %map: tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> { %1 = "FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) : - (tensor<3x3x!FHE.eint<2>>, tensor<9x4xi64>, tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> + (tensor<3x3x!FHE.eint<2>>, tensor<9x4xi8>, tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> return %1: tensor<3x3x!FHE.eint<2>> } tests: