diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp index 4ce0dcebb..abbb9a80f 100644 --- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -430,10 +430,6 @@ struct FHELinalgApplyMappedLookupTableToLinalgGeneric /// /// #maps_0 = [ /// affine_map<(d0, d1) -> (d0, d1)> -/// affine_map<(d0, d1) -> (d1, 0)> -/// affine_map<(d0, d1) -> (d1, 1)> -/// affine_map<(d0, d1) -> (d1, 2)> -/// affine_map<(d0, d1) -> (d1, 3)> /// ] /// #attributes_0 { /// indexing_maps = #maps_0, @@ -442,18 +438,18 @@ struct FHELinalgApplyMappedLookupTableToLinalgGeneric /// %init = linalg.init_tensor [4, 3] /// : tensor<4x3x!FHE.eint<2>> /// %res = linalg.generic { -/// ins(%t, %luts, %luts, %luts, %luts: tensor<4x3x!FHE.eint

<p>>, -/// tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>) +/// ins(%t, %luts: tensor<4x3x!FHE.eint

<p>>) /// outs(%init : tensor<4x3x!FHE.eint<2>>) /// { -/// ^bb0(%arg0: !FHE.eint<2>, %arg1: i64, %arg2: i64, %arg3: i64, -/// %arg4: i64, %arg5: !FHE.eint<2>): -/// %lut = tensor.from_elements %arg1, %arg2, %arg3, %arg4 : -/// tensor<4xi64> %0 = "TFHE.apply_lookup_table"(%arg0, %lut) -/// {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension = -1 -/// : i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1 -/// : i32, polynomialSize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, -/// tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{2}> +/// ^bb0(%arg0: !FHE.eint<2>): +/// %i_lut = linalg.index 0 : index +/// %lut = tensor.extract_slice %luts[%i_lut, 0] [1, lut_size] [1, +/// 1] : ... tensor<4xi64> %0 = "TFHE.apply_lookup_table"(%arg0, +/// %lut) {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension +/// = -1 : i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS +/// = -1 : i32, polynomialSize = -1 : i32} : +/// (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> +/// !TFHE.glwe<{_,_,_}{2}> /// linalg.yield %0 : !FHE.eint<2> /// } /// } @@ -477,55 +473,83 @@ struct FHELinalgApplyMultiLookupTableToLinalgGeneric .cast(); mlir::RankedTensorType tensorTy = ((mlir::Type)fheLinalgLutOp.t().getType()) .cast(); - mlir::RankedTensorType lutsTy = - ((mlir::Type)fheLinalgLutOp.luts().getType()) - .cast(); + auto luts = fheLinalgLutOp.luts(); + mlir::RankedTensorType lutsTy = getRankedTensorType(luts); + auto lutElmtTy = lutsTy.getElementType(); // linalg.init_tensor for initial value mlir::Value init = rewriter.create( fheLinalgLutOp.getLoc(), resultTy, mlir::ValueRange{}); auto lutsShape = lutsTy.getShape(); auto lut_size = lutsShape[lutsShape.size() - 1]; + auto indexOfInput = getBroadcastedAffineMap(resultTy, tensorTy, rewriter); // Create the affine maps - llvm::SmallVector maps{ - // Input tensor map - getBroadcastedAffineMap(resultTy, tensorTy, rewriter)}; - maps.reserve(lut_size + 1); - // Create as much affine maps as the size of the lut dimension - for (int64_t i = 0; i < 
lut_size; i++) - maps.push_back( - getBroadcastedAffineMapMultiLUT(resultTy, lutsTy, i, rewriter)); - // Result map - maps.push_back(getBroadcastedAffineMap(resultTy, resultTy, rewriter)); + llvm::SmallVector maps{indexOfInput, indexOfInput}; // Create the iterator_types - llvm::SmallVector iteratorTypes(resultTy.getShape().size(), - "parallel"); + auto iteratorTypes = parallelIteratorType(resultTy.getShape().size()); + auto integer = [&](auto v) -> mlir::Attribute { + return rewriter.getI64IntegerAttr(v); + }; + + // We need to know which linalg.generic index to use for lut + // In the broadcast case the lut indices are the inner dimensions of the tensor index + auto tensorShape = tensorTy.getShape(); + auto tensorRank = tensorTy.getShape().size(); + auto lutsRank = lutsShape.size() - 1; // do not count inner dim of luts + auto lutIndexDimAt = tensorRank - lutsRank; + llvm::SmallVector indexLutsToLinalg(lutsRank); + for (auto lutsIndex = 0u; lutsIndex < lutsRank; lutsIndex++) { + auto tensorIndex = lutIndexDimAt + lutsIndex; + if (tensorShape[tensorIndex] != lutsShape[lutsIndex]) { + llvm::errs() << "ERROR: Broadcast only works by having more outer " + "dims.\nConflict: " + << tensorShape[tensorIndex] << " (tensor dim " + << tensorIndex << ") is not compatible with " + << lutsShape[lutsIndex] << " (luts dim " << lutsIndex + << ")\n\n"; + return ::mlir::LogicalResult::failure(); + }; + indexLutsToLinalg[lutsIndex] = tensorIndex; + } + + auto _0_ = integer(0); + auto _1_ = integer(1); + auto lutSizeValue = integer(lut_size); // Create the body of the `linalg.generic` op auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, mlir::ValueRange blockArgs) { - mlir::tensor::FromElementsOp lut = - nestedBuilder.create( - fheLinalgLutOp.getLoc(), blockArgs.slice(1, lut_size)); - mlir::concretelang::FHE::ApplyLookupTableEintOp lutOp = - nestedBuilder.create( - fheLinalgLutOp.getLoc(), resultTy.getElementType(), blockArgs[0], - lut.result()); + auto loc = 
fheLinalgLutOp.getLoc(); + auto tElmt = blockArgs[0]; - nestedBuilder.create(fheLinalgLutOp.getLoc(), - lutOp.getResult()); + // %lut = extract_slice %luts[%lutIdx, 0][1, lutSize][1, 1] : + // tensor to tensor + auto sliceArgDim = lutsShape.size(); + using sliceArg = llvm::SmallVector; + sliceArg offsets(sliceArgDim, _0_); + auto lutsIndex = 0; + for (auto index : indexLutsToLinalg) { + auto offset = nestedBuilder.create(loc, index); + offsets[lutsIndex++] = (mlir::OpFoldResult)offset; + } + sliceArg sizes(sliceArgDim, _1_); + sizes[sliceArgDim - 1] = lutSizeValue; + sliceArg strides(sliceArgDim, _1_); + auto lutTy = mlir::RankedTensorType::get({static_cast(lut_size)}, + lutElmtTy); + mlir::Value lut = nestedBuilder.create( + loc, lutTy, luts, offsets, sizes, strides); + auto lutOp = nestedBuilder.create( + loc, resultTy.getElementType(), tElmt, lut); + + nestedBuilder.create(loc, lutOp.getResult()); }; // Create the `linalg.generic` op llvm::SmallVector resTypes{init.getType()}; llvm::SmallVector ins{fheLinalgLutOp.t()}; - ins.reserve(lut_size + 2); - // We extract one value at a time from one LUT using different maps, so we - // need to pass the LUT `lut_size` time - for (auto i = 0; i < lut_size; i++) - ins.push_back(fheLinalgLutOp.luts()); llvm::SmallVector outs{init}; llvm::StringRef doc{""}; llvm::StringRef call{""}; diff --git a/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg.mlir b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg.mlir index 8f0f7a8fb..df21a3c21 100644 --- a/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg.mlir +++ b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg.mlir @@ -1,21 +1,20 @@ // RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s -//CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)> -//CHECK: #map1 = affine_map<(d0, d1) -> (d0, d1, 0)> -//CHECK: #map2 = 
affine_map<(d0, d1) -> (d0, d1, 1)> -//CHECK: #map3 = affine_map<(d0, d1) -> (d0, d1, 2)> -//CHECK: #map4 = affine_map<(d0, d1) -> (d0, d1, 3)> -//CHECK: func.func @multi_lut(%[[A0:.*]]: tensor<4x4x!FHE.eint<2>>, %[[A1:.*]]: tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> { -//CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<4x4x!FHE.eint<2>> -//CHECK: %[[V1:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3, #map4, #map0], iterator_types = ["parallel", "parallel"]} ins(%[[A0]], %[[A1]], %arg1, %arg1, %arg1 : tensor<4x4x!FHE.eint<2>>, tensor<4x4x4xi64>, tensor<4x4x4xi64>, tensor<4x4x4xi64>, tensor<4x4x4xi64>) outs(%[[V0]] : tensor<4x4x!FHE.eint<2>>) { -//CHECK: ^bb0(%[[A2:.*]]: !FHE.eint<2>, %[[A3:.*]]: i64, %[[A4:.*]]: i64, %[[A5:.*]]: i64, %[[A6:.*]]: i64, %[[A7:.*]]: !FHE.eint<2>): -//CHECK: %[[V2:.*]] = tensor.from_elements %[[A3]], %[[A4]], %[[A5]], %[[A6]] : tensor<4xi64> -//CHECK: %[[V3:.*]] = "FHE.apply_lookup_table"(%[[A2]], %[[V2]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> -//CHECK: linalg.yield %[[V3]] : !FHE.eint<2> + +//CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> +//CHECK: func.func @multi_lut(%arg0: tensor<4x4x!FHE.eint<2>>, %[[LUTS:.*]]: tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> { +//CHECK: %[[MEM:.*]] = bufferization.alloc_tensor() : tensor<4x4x!FHE.eint<2>> +//CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<4x4x!FHE.eint<2>>) outs(%[[MEM]] : tensor<4x4x!FHE.eint<2>>) { +//CHECK: ^bb0(%[[IN:.*]]: !FHE.eint<2>, %[[UNUSED:.*]]: !FHE.eint<2>): +//CHECK: %[[INDEXA:.*]] = linalg.index 0 : index +//CHECK: %[[INDEXB:.*]] = linalg.index 1 : index +//CHECK: %[[LUT:.*]] = tensor.extract_slice %[[LUTS]][%[[INDEXA]], %[[INDEXB]], 0] [1, 1, 4] [1, 1, 1] : tensor<4x4x4xi64> to tensor<4xi64> +//CHECK: %[[V:.*]] = "FHE.apply_lookup_table"(%[[IN]], %[[LUT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> +//CHECK: linalg.yield %[[V]] : !FHE.eint<2> 
//CHECK: } -> tensor<4x4x!FHE.eint<2>> -//CHECK: return %[[V1]] : tensor<4x4x!FHE.eint<2>> +//CHECK: return %[[R]] : tensor<4x4x!FHE.eint<2>> //CHECK: } -func.func @multi_lut(%arg0: tensor<4x4x!FHE.eint<2>>, %arg1: tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> { - %0 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<4x4x!FHE.eint<2>>, tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> +func.func @multi_lut(%arg0: tensor<4x4x!FHE.eint<2>>, %luts: tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> { + %0 = "FHELinalg.apply_multi_lookup_table"(%arg0, %luts): (tensor<4x4x!FHE.eint<2>>, tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> return %0: tensor<4x4x!FHE.eint<2>> } diff --git a/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg_broadcast.mlir b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg_broadcast.mlir index dffeb009f..8361c193f 100644 --- a/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg_broadcast.mlir +++ b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg_broadcast.mlir @@ -1,21 +1,18 @@ // RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s -//CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)> -//CHECK: #map1 = affine_map<(d0, d1) -> (d1, 0)> -//CHECK: #map2 = affine_map<(d0, d1) -> (d1, 1)> -//CHECK: #map3 = affine_map<(d0, d1) -> (d1, 2)> -//CHECK: #map4 = affine_map<(d0, d1) -> (d1, 3)> -//CHECK: func.func @multi_lut(%[[A0:.*]]: tensor<4x3x!FHE.eint<2>>, %[[A1:.*]]: tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> { -//CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<4x3x!FHE.eint<2>> -//CHECK: %[[V1:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3, #map4, #map0], iterator_types = ["parallel", "parallel"]} ins(%[[A0]], %[[A1]], %arg1, %arg1, %arg1 : tensor<4x3x!FHE.eint<2>>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>) outs(%[[V0]] : 
tensor<4x3x!FHE.eint<2>>) { -//CHECK: ^bb0(%[[A2:.*]]: !FHE.eint<2>, %[[A3:.*]]: i64, %[[A4:.*]]: i64, %[[A5:.*]]: i64, %[[A6:.*]]: i64, %[[A7:.*]]: !FHE.eint<2>): -//CHECK: %[[V2:.*]] = tensor.from_elements %[[A3]], %[[A4]], %[[A5]], %[[A6]] : tensor<4xi64> -//CHECK: %[[V3:.*]] = "FHE.apply_lookup_table"(%[[A2]], %[[V2]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> -//CHECK: linalg.yield %[[V3]] : !FHE.eint<2> +//CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> +//CHECK: func.func @multi_lut(%arg0: tensor<4x3x!FHE.eint<2>>, %[[LUTS:.*]]: tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> { +//CHECK: %[[MEM:.*]] = bufferization.alloc_tensor() : tensor<4x3x!FHE.eint<2>> +//CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<4x3x!FHE.eint<2>>) outs(%[[MEM]] : tensor<4x3x!FHE.eint<2>>) { +//CHECK: ^bb0(%[[IN:.*]]: !FHE.eint<2>, %[[UNUSED:.*]]: !FHE.eint<2>): +//CHECK: %[[INDEX:.*]] = linalg.index 1 : index +//CHECK: %[[LUT:.*]] = tensor.extract_slice %arg1[%[[INDEX]], 0] [1, 4] [1, 1] : tensor<3x4xi64> to tensor<4xi64> +//CHECK: %[[V:.*]] = "FHE.apply_lookup_table"(%[[IN]], %[[LUT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> +//CHECK: linalg.yield %[[V]] : !FHE.eint<2> //CHECK: } -> tensor<4x3x!FHE.eint<2>> -//CHECK: return %[[V1]] : tensor<4x3x!FHE.eint<2>> +//CHECK: return %[[R]] : tensor<4x3x!FHE.eint<2>> //CHECK: } -func.func @multi_lut(%arg0: tensor<4x3x!FHE.eint<2>>, %arg1: tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> { - %1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<4x3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> +func.func @multi_lut(%arg0: tensor<4x3x!FHE.eint<2>>, %luts: tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> { + %1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %luts): (tensor<4x3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> return %1: tensor<4x3x!FHE.eint<2>> }