mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-17 16:11:26 -05:00
feat(compiler): adds support for dynamic luts in fhelinalg
This commit is contained in:
@@ -283,11 +283,13 @@ mlir::LogicalResult ApplyLookupTableEintOp::verify() {
|
||||
// Check the shape of lut argument
|
||||
auto tEltwidth = tEltTy.getWidth();
|
||||
mlir::SmallVector<int64_t, 1> expectedShape{1 << tEltwidth};
|
||||
if (!lutTy.hasStaticShape(expectedShape) || !lutEltTy.isInteger(64)) {
|
||||
this->emitOpError()
|
||||
<< "should have as operand #2 a tensor<2^pxi64>, where p is the width "
|
||||
"of the encrypted integer of the operand #1,"
|
||||
<< "expect tensor <" << expectedShape[0] << "xi64>";
|
||||
if (!lutTy.hasStaticShape(expectedShape) || !lutEltTy.isSignlessInteger() ||
|
||||
lutEltTy.getIntOrFloatBitWidth() > 64) {
|
||||
this->emitOpError() << "should have as operand #2 a "
|
||||
"tensor<2^pxi{8,16,32,64}>, where p is the width "
|
||||
"of the encrypted integer of the operand #1,"
|
||||
<< "expect tensor <" << expectedShape[0]
|
||||
<< "xi{8,16,32,64}>";
|
||||
return mlir::failure();
|
||||
}
|
||||
if (!resultTy.hasStaticShape(tTy.getShape())) {
|
||||
@@ -308,12 +310,14 @@ mlir::LogicalResult ApplyMultiLookupTableEintOp::verify() {
|
||||
// Check the shape of luts argument
|
||||
auto lut_size = lutTy.getShape()[lutTy.getShape().size() - 1];
|
||||
auto expected_lut_size = 1 << tEltTy.getWidth();
|
||||
if (lut_size != expected_lut_size || !lutEltTy.isInteger(64)) {
|
||||
this->emitOpError() << "should have as operand #2 a "
|
||||
"tensor<DMx...xD1X2^pxi64>, where p is the width "
|
||||
"of the encrypted integer of the operand #1,"
|
||||
<< "expect tensor <DMx...xD1X" << expected_lut_size
|
||||
<< "xi64>";
|
||||
if (lut_size != expected_lut_size || !lutEltTy.isSignlessInteger() ||
|
||||
lutEltTy.getIntOrFloatBitWidth() > 64) {
|
||||
this->emitOpError()
|
||||
<< "should have as operand #2 a "
|
||||
"tensor<DMx...xD1X2^pxi{8,16,32,64}>, where p is the width "
|
||||
"of the encrypted integer of the operand #1,"
|
||||
<< "expect tensor <DMx...xD1X" << expected_lut_size
|
||||
<< "xi{8,16,32,64}>";
|
||||
return mlir::failure();
|
||||
}
|
||||
if (!resultTy.hasStaticShape(tTy.getShape())) {
|
||||
@@ -380,9 +384,14 @@ mlir::LogicalResult verifyLutsSize(ApplyMappedLookupTableEintOp &op,
|
||||
|
||||
mlir::LogicalResult ApplyMappedLookupTableEintOp::verify() {
|
||||
auto t = this->getT();
|
||||
auto tTy = this->getT().getType().cast<mlir::RankedTensorType>();
|
||||
auto tEltTy =
|
||||
tTy.getElementType().cast<mlir::concretelang::FHE::FheIntegerInterface>();
|
||||
auto luts = this->getLuts();
|
||||
auto map = this->getMap();
|
||||
auto result = this->getResult();
|
||||
auto lutTy = this->getLuts().getType().cast<mlir::RankedTensorType>();
|
||||
auto lutEltTy = lutTy.getElementType().cast<mlir::IntegerType>();
|
||||
|
||||
auto t_shape = getTensorType(t).getShape();
|
||||
if (!getTensorType(result).hasStaticShape(t_shape)) {
|
||||
@@ -397,6 +406,17 @@ mlir::LogicalResult ApplyMappedLookupTableEintOp::verify() {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
auto expected_lut_size = 1 << tEltTy.getWidth();
|
||||
if (!lutEltTy.isSignlessInteger() || lutEltTy.getIntOrFloatBitWidth() > 64) {
|
||||
this->emitOpError()
|
||||
<< "should have as operand #2 a "
|
||||
"tensor<DMx...xD1X2^pxi{8,16,32,64}>, where p is the width "
|
||||
"of the encrypted integer of the operand #1,"
|
||||
<< "expect tensor <DMx...xD1X" << expected_lut_size
|
||||
<< "xi{8,16,32,64}>";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return mlir::success(verifyMapHasRightShape(*this, t, map).succeeded() &&
|
||||
verifyLutsSize(*this, t, luts).succeeded());
|
||||
}
|
||||
|
||||
@@ -164,16 +164,16 @@ func.func @main(%a0: tensor<2x3x4x!FHE.eint<2>>, %a1: tensor<2x3x4x!FHE.eint<3>>
|
||||
// FHELinalg.apply_lookup_table
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<4xi32>) -> tensor<2x3x4x!FHE.eint<2>> {
|
||||
// expected-error @+1 {{'FHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi64>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi64>}}
|
||||
%1 = "FHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<4xi32>) -> (tensor<2x3x4x!FHE.eint<2>>)
|
||||
func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<4xi65>) -> tensor<2x3x4x!FHE.eint<2>> {
|
||||
// expected-error @+1 {{'FHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi{8,16,32,64}>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi{8,16,32,64}>}}
|
||||
%1 = "FHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<4xi65>) -> (tensor<2x3x4x!FHE.eint<2>>)
|
||||
return %1: tensor<2x3x4x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<12xi64>) -> tensor<2x3x4x!FHE.eint<2>> {
|
||||
// expected-error @+1 {{'FHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi64>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi64>}}
|
||||
// expected-error @+1 {{'FHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi{8,16,32,64}>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi{8,16,32,64}>}}
|
||||
%1 = "FHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<12xi64>) -> (tensor<2x3x4x!FHE.eint<2>>)
|
||||
return %1: tensor<2x3x4x!FHE.eint<2>>
|
||||
}
|
||||
@@ -193,13 +193,21 @@ func.func @apply_lookup_table(%arg0: tensor<3x4x!FHE.eint<2>>, %arg1: tensor<4xi
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
func.func @apply_multi_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<2x6xi64>) -> tensor<2x3x4x!FHE.eint<2>> {
|
||||
// expected-error @+1 {{'FHELinalg.apply_multi_lookup_table' op should have as operand #2 a tensor<DMx...xD1X2^pxi64>, where p is the width of the encrypted integer of the operand #1,expect tensor <DMx...xD1X4xi64>}}
|
||||
// expected-error @+1 {{'FHELinalg.apply_multi_lookup_table' op should have as operand #2 a tensor<DMx...xD1X2^pxi{8,16,32,64}>, where p is the width of the encrypted integer of the operand #1,expect tensor <DMx...xD1X4xi{8,16,32,64}>}}
|
||||
%1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<2x6xi64>) -> (tensor<2x3x4x!FHE.eint<2>>)
|
||||
return %1: tensor<2x3x4x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @apply_multi_lookup_table_bad_prec(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<2x4xi65>) -> tensor<2x3x4x!FHE.eint<2>> {
|
||||
// expected-error @+1 {{'FHELinalg.apply_multi_lookup_table' op should have as operand #2 a tensor<DMx...xD1X2^pxi{8,16,32,64}>, where p is the width of the encrypted integer of the operand #1,expect tensor <DMx...xD1X4xi{8,16,32,64}>}}
|
||||
%1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<2x4xi65>) -> (tensor<2x3x4x!FHE.eint<2>>)
|
||||
return %1: tensor<2x3x4x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// FHELinalg.apply_mapped_lookup_table
|
||||
@@ -240,6 +248,18 @@ func.func @apply_mapped_lookup_table_bad_map_elmt_type(
|
||||
|
||||
// -----
|
||||
|
||||
func.func @apply_mapped_lookup_table_bad_lut_prec(
|
||||
%input: tensor<2x3x4x!FHE.eint<7>>,
|
||||
%luts: tensor<128xi65>,
|
||||
%map: tensor<2x3x4xindex>
|
||||
) -> tensor<2x3x4x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.apply_mapped_lookup_table' op should have as operand #2 a tensor<DMx...xD1X2^pxi{8,16,32,64}>, where p is the width of the encrypted integer of the operand #1,expect tensor <DMx...xD1X128xi{8,16,32,64}>}}
|
||||
%0 = "FHELinalg.apply_mapped_lookup_table"(%input, %luts, %map): (tensor<2x3x4x!FHE.eint<7>>, tensor<128xi65>, tensor<2x3x4xindex>) -> (tensor<2x3x4x!FHE.eint<7>>)
|
||||
return %0: tensor<2x3x4x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// FHELinalg.conv2d
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
@@ -7,6 +7,18 @@ from end_to_end_linalg_leveled_gen import P_ERROR
|
||||
|
||||
PRECISION_FORCE_CRT = 9
|
||||
|
||||
def get_lut_integer_type(p):
|
||||
if p <= 8:
|
||||
return "i8"
|
||||
if p <= 16:
|
||||
return "i16"
|
||||
if p <= 32:
|
||||
return "i32"
|
||||
if p <= 64:
|
||||
return "i64"
|
||||
else:
|
||||
raise Exception("Unexpected precision")
|
||||
|
||||
def generate(args):
|
||||
print("# /!\ DO NOT EDIT MANUALLY THIS FILE MANUALLY")
|
||||
print("# /!\ THIS FILE HAS BEEN GENERATED")
|
||||
@@ -16,15 +28,15 @@ def generate(args):
|
||||
for n_lut in args.n_lut:
|
||||
max_value = (2 ** p) - 1
|
||||
random_lut = np.random.randint(max_value+1, size=2**p)
|
||||
itype = get_lut_integer_type(p)
|
||||
# identity_apply_lookup_table
|
||||
print(f"description: apply_lookup_table_{p}bits_{n_ct}ct_{n_lut}layer")
|
||||
print("program: |")
|
||||
print(
|
||||
f" func.func @main(%0: tensor<{n_ct}x!FHE.eint<{p}>>) -> tensor<{n_ct}x!FHE.eint<{p}>> {{")
|
||||
print(f" %tlu = arith.constant dense<[{','.join(map(str, random_lut))}]> : tensor<{2**p}xi64>")
|
||||
f" func.func @main(%0: tensor<{n_ct}x!FHE.eint<{p}>>, %tlu: tensor<{2**p}x{itype}>) -> tensor<{n_ct}x!FHE.eint<{p}>> {{")
|
||||
for i in range(0, n_lut):
|
||||
print(f" %{i+1} = \"FHELinalg.apply_lookup_table\"(%{i}, %tlu):")
|
||||
print(f" (tensor<{n_ct}x!FHE.eint<{p}>>, tensor<{2**p}xi64>) -> (tensor<{n_ct}x!FHE.eint<{p}>>)")
|
||||
print(f" (tensor<{n_ct}x!FHE.eint<{p}>>, tensor<{2**p}x{itype}>) -> (tensor<{n_ct}x!FHE.eint<{p}>>)")
|
||||
print(f" return %{n_lut}: tensor<{n_ct}x!FHE.eint<{p}>>")
|
||||
print(" }")
|
||||
if p >= PRECISION_FORCE_CRT:
|
||||
@@ -35,6 +47,8 @@ def generate(args):
|
||||
print(" - inputs:")
|
||||
print(f" - tensor: [{','.join(map(str, random_input))}]")
|
||||
print(f" shape: [{n_ct}]")
|
||||
print(f" - tensor: [{','.join(map(str, random_lut))}]")
|
||||
print(f" shape: [{2**p}]")
|
||||
outputs = random_input
|
||||
for i in range(0, n_lut):
|
||||
outputs = [random_lut[v] for v in outputs]
|
||||
|
||||
@@ -1034,8 +1034,8 @@ program: |
|
||||
// [3,0,1] lut [1,3,5,7] = [7,1,3]
|
||||
// [2,3,0] [5,7,1]
|
||||
func.func @main(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.eint<3>> {
|
||||
%lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64>
|
||||
%res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!FHE.eint<3>>
|
||||
%lut = arith.constant dense<[1,3,5,7]> : tensor<4xi8>
|
||||
%res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi8>) -> tensor<3x3x!FHE.eint<3>>
|
||||
return %res : tensor<3x3x!FHE.eint<3>>
|
||||
}
|
||||
tests:
|
||||
@@ -1050,8 +1050,8 @@ tests:
|
||||
description: apply_lookup_table_batched
|
||||
program: |
|
||||
func.func @main(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.eint<3>> {
|
||||
%lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64>
|
||||
%res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!FHE.eint<3>>
|
||||
%lut = arith.constant dense<[1,3,5,7]> : tensor<4xi8>
|
||||
%res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi8>) -> tensor<3x3x!FHE.eint<3>>
|
||||
return %res : tensor<3x3x!FHE.eint<3>>
|
||||
}
|
||||
tests:
|
||||
@@ -1066,8 +1066,8 @@ tests:
|
||||
description: apply_multi_lookup_table
|
||||
program: |
|
||||
// Returns the lookup of 3x3 matrix of encrypted indices of width 2 on a 3x3 matrix of tables of size 4=2² of clear integers.
|
||||
func.func @main(%arg0: tensor<3x3x!FHE.eint<2>>, %arg1: tensor<3x3x4xi64>) -> tensor<3x3x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!FHE.eint<2>>, tensor<3x3x4xi64>) -> tensor<3x3x!FHE.eint<2>>
|
||||
func.func @main(%arg0: tensor<3x3x!FHE.eint<2>>, %arg1: tensor<3x3x4xi8>) -> tensor<3x3x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!FHE.eint<2>>, tensor<3x3x4xi8>) -> tensor<3x3x!FHE.eint<2>>
|
||||
return %1: tensor<3x3x!FHE.eint<2>>
|
||||
}
|
||||
tests:
|
||||
@@ -1084,8 +1084,8 @@ tests:
|
||||
description: apply_multi_lookup_table_with_boradcast
|
||||
program: |
|
||||
// Returns the lookup of 3x3 matrix of encrypted indices of width 2 on a vector of 3 tables of size 4=2² of clear integers.
|
||||
func.func @main(%arg0: tensor<3x3x!FHE.eint<2>>, %arg1: tensor<3x4xi64>) -> tensor<3x3x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<3x3x!FHE.eint<2>>
|
||||
func.func @main(%arg0: tensor<3x3x!FHE.eint<2>>, %arg1: tensor<3x4xi8>) -> tensor<3x3x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!FHE.eint<2>>, tensor<3x4xi8>) -> tensor<3x3x!FHE.eint<2>>
|
||||
return %1: tensor<3x3x!FHE.eint<2>>
|
||||
}
|
||||
tests:
|
||||
@@ -1103,9 +1103,9 @@ tests:
|
||||
description: apply_mapped_lookup_table_sequential
|
||||
program: |
|
||||
// Returns the lookup of 3x3 matrix of encrypted indices of width 2 of a 3x3 matrix of tables of size 4=2² of clear integers.
|
||||
func.func @main(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<9x4xi64>, %map: tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> {
|
||||
func.func @main(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<9x4xi8>, %map: tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) :
|
||||
(tensor<3x3x!FHE.eint<2>>, tensor<9x4xi64>, tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>>
|
||||
(tensor<3x3x!FHE.eint<2>>, tensor<9x4xi8>, tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>>
|
||||
return %1: tensor<3x3x!FHE.eint<2>>
|
||||
}
|
||||
tests:
|
||||
@@ -1124,9 +1124,9 @@ tests:
|
||||
description: apply_mapped_lookup_table_same_lut
|
||||
program: |
|
||||
// Returns the lookup of 3x3 matrix of encrypted indices of width 2 of a 3x3 matrix of tables of size 4=2² of clear integers.
|
||||
func.func @main(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<9x4xi64>, %map: tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> {
|
||||
func.func @main(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<9x4xi8>, %map: tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) :
|
||||
(tensor<3x3x!FHE.eint<2>>, tensor<9x4xi64>, tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>>
|
||||
(tensor<3x3x!FHE.eint<2>>, tensor<9x4xi8>, tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>>
|
||||
return %1: tensor<3x3x!FHE.eint<2>>
|
||||
}
|
||||
tests:
|
||||
|
||||
Reference in New Issue
Block a user