@@ -430,10 +430,6 @@ struct FHELinalgApplyMappedLookupTableToLinalgGeneric
///
/// #maps_0 = [
///    affine_map<(d0, d1) -> (d0, d1)>
///    affine_map<(d0, d1) -> (d1, 0)>
///    affine_map<(d0, d1) -> (d1, 1)>
///    affine_map<(d0, d1) -> (d1, 2)>
///    affine_map<(d0, d1) -> (d1, 3)>
/// ]
/// #attributes_0 {
///    indexing_maps = #maps_0,
@@ -442,18 +438,18 @@ struct FHELinalgApplyMappedLookupTableToLinalgGeneric
/// %init = linalg.init_tensor [4, 3]
///            : tensor<4x3x!FHE.eint<2>>
/// %res = linalg.generic {
///   ins(%t, %luts, %luts, %luts, %luts: tensor<4x3x!FHE.eint<p>>,
///       tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>)
///   ins(%t : tensor<4x3x!FHE.eint<p>>)
///   outs(%init : tensor<4x3x!FHE.eint<2>>)
///   {
///     ^bb0(%arg0: !FHE.eint<2>, %arg1: i64, %arg2: i64, %arg3: i64,
///          %arg4: i64, %arg5: !FHE.eint<2>):
///       %lut = tensor.from_elements %arg1, %arg2, %arg3, %arg4 : tensor<4xi64>
///       %0 = "TFHE.apply_lookup_table"(%arg0, %lut)
///              {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension = -1 : i32,
///               levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1 : i32,
///               polynomialSize = -1 : i32}
///              : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{2}>
///     ^bb0(%arg0: !FHE.eint<2>):
///       %i_lut = linalg.index 0 : index
///       %lut = tensor.extract_slice %luts[%i_lut, 0] [1, lut_size] [1, 1]
///                : ... tensor<4xi64>
///       %0 = "TFHE.apply_lookup_table"(%arg0, %lut)
///              {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension = -1 : i32,
///               levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1 : i32,
///               polynomialSize = -1 : i32}
///              : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{2}>
///       linalg.yield %0 : !FHE.eint<2>
///   }
/// }
@@ -477,55 +473,83 @@ struct FHELinalgApplyMultiLookupTableToLinalgGeneric
                                     .cast<mlir::RankedTensorType>();
    mlir::RankedTensorType tensorTy = ((mlir::Type)fheLinalgLutOp.t().getType())
                                          .cast<mlir::RankedTensorType>();
    mlir::RankedTensorType lutsTy =
        ((mlir::Type)fheLinalgLutOp.luts().getType())
            .cast<mlir::RankedTensorType>();
    auto luts = fheLinalgLutOp.luts();
    mlir::RankedTensorType lutsTy = getRankedTensorType(luts);
    auto lutElmtTy = lutsTy.getElementType();
    // bufferization.alloc_tensor for the initial value
    mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
        fheLinalgLutOp.getLoc(), resultTy, mlir::ValueRange{});

    auto lutsShape = lutsTy.getShape();
    auto lut_size = lutsShape[lutsShape.size() - 1];
    auto indexOfInput = getBroadcastedAffineMap(resultTy, tensorTy, rewriter);
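    // Note (illustrative comment, not in the original source): lut_size is the
    // innermost dimension of the luts tensor, i.e. the number of entries per
    // lookup table; for the !FHE.eint<2> tests below that is 4 entries, since
    // a p-bit encrypted input addresses 2^p table entries.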
    // Create the affine maps
    llvm::SmallVector<mlir::AffineMap> maps{
        // Input tensor map
        getBroadcastedAffineMap(resultTy, tensorTy, rewriter)};
    maps.reserve(lut_size + 1);
    // Create as many affine maps as the size of the lut dimension
    for (int64_t i = 0; i < lut_size; i++)
      maps.push_back(
          getBroadcastedAffineMapMultiLUT(resultTy, lutsTy, i, rewriter));
    // Result map
    maps.push_back(getBroadcastedAffineMap(resultTy, resultTy, rewriter));
    llvm::SmallVector<mlir::AffineMap> maps{indexOfInput, indexOfInput};

    // Create the iterator_types
    llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
                                                     "parallel");
    auto iteratorTypes = parallelIteratorType(resultTy.getShape().size());

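    // Note (illustrative comment, not in the original source): when the input
    // and result shapes match, both entries of `maps` print as the same
    // identity map; in the updated tests below this shows up as a single
    // #map = affine_map<(d0, d1) -> (d0, d1)> used for ins and outs.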
    auto integer = [&](auto v) -> mlir::Attribute {
      return rewriter.getI64IntegerAttr(v);
    };

    // We need to know which linalg.generic index to use for the lut.
    // In the broadcast case the lut index corresponds to the inner dimensions
    // of the tensor index.
    auto tensorShape = tensorTy.getShape();
    auto tensorRank = tensorTy.getShape().size();
    auto lutsRank = lutsShape.size() - 1; // do not count inner dim of luts
    auto lutIndexDimAt = tensorRank - lutsRank;
    llvm::SmallVector<uint> indexLutsToLinalg(lutsRank);
    for (auto lutsIndex = 0u; lutsIndex < lutsRank; lutsIndex++) {
      auto tensorIndex = lutIndexDimAt + lutsIndex;
      if (tensorShape[tensorIndex] != lutsShape[lutsIndex]) {
        llvm::errs() << "ERROR: Broadcast only works by having more outer "
                        "dims.\nConflict: "
                     << tensorShape[tensorIndex] << " (tensor dim "
                     << tensorIndex << ") is not compatible with "
                     << lutsShape[lutsIndex] << " (luts dim " << lutsIndex
                     << ")\n\n";
        return ::mlir::LogicalResult::failure();
      };
      indexLutsToLinalg[lutsIndex] = tensorIndex;
    }

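    // Worked example (illustrative comment, not in the original source): for
    // an input of type tensor<4x3x!FHE.eint<2>> and luts of type
    // tensor<3x4xi64>, lutsRank is 1 (the trailing dimension of size 4 holds
    // the lut entries), lutIndexDimAt = 2 - 1 = 1, and the single luts
    // dimension (size 3) is checked against tensor dimension 1 (also size 3),
    // so indexLutsToLinalg ends up as {1}: the lut to apply is selected by
    // linalg.index 1, as the broadcast test below checks.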
    auto _0_ = integer(0);
    auto _1_ = integer(1);
    auto lutSizeValue = integer(lut_size);
    // Create the body of the `linalg.generic` op
    auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
                           mlir::Location nestedLoc,
                           mlir::ValueRange blockArgs) {
      mlir::tensor::FromElementsOp lut =
          nestedBuilder.create<mlir::tensor::FromElementsOp>(
              fheLinalgLutOp.getLoc(), blockArgs.slice(1, lut_size));
      mlir::concretelang::FHE::ApplyLookupTableEintOp lutOp =
          nestedBuilder.create<mlir::concretelang::FHE::ApplyLookupTableEintOp>(
              fheLinalgLutOp.getLoc(), resultTy.getElementType(), blockArgs[0],
              lut.result());
      auto loc = fheLinalgLutOp.getLoc();
      auto tElmt = blockArgs[0];

      nestedBuilder.create<mlir::linalg::YieldOp>(fheLinalgLutOp.getLoc(),
                                                  lutOp.getResult());
      // %lut = extract_slice %luts[%lutIdx, 0][1, lutSize][1, 1] :
      //          tensor<NxKxi64> to tensor<Kxi64>
      auto sliceArgDim = lutsShape.size();
      using sliceArg = llvm::SmallVector<mlir::OpFoldResult>;
      sliceArg offsets(sliceArgDim, _0_);
      auto lutsIndex = 0;
      for (auto index : indexLutsToLinalg) {
        auto offset = nestedBuilder.create<linalg::IndexOp>(loc, index);
        offsets[lutsIndex++] = (mlir::OpFoldResult)offset;
      }
      sliceArg sizes(sliceArgDim, _1_);
      sizes[sliceArgDim - 1] = lutSizeValue;
      sliceArg strides(sliceArgDim, _1_);
      auto lutTy = mlir::RankedTensorType::get({static_cast<int64_t>(lut_size)},
                                               lutElmtTy);
      mlir::Value lut = nestedBuilder.create<tensor::ExtractSliceOp>(
          loc, lutTy, luts, offsets, sizes, strides);
      auto lutOp = nestedBuilder.create<FHE::ApplyLookupTableEintOp>(
          loc, resultTy.getElementType(), tElmt, lut);

      nestedBuilder.create<mlir::linalg::YieldOp>(loc, lutOp.getResult());
    };

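    // Worked example (illustrative comment, not in the original source): with
    // luts of type tensor<4x4x4xi64>, sliceArgDim is 3, offsets become
    // {linalg.index 0, linalg.index 1, 0}, sizes {1, 1, 4} and strides
    // {1, 1, 1}, which yields
    //   tensor.extract_slice %luts[%i, %j, 0] [1, 1, 4] [1, 1, 1]
    //     : tensor<4x4x4xi64> to tensor<4xi64>
    // exactly as checked in the updated tests below.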
    // Create the `linalg.generic` op
    llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
    llvm::SmallVector<mlir::Value> ins{fheLinalgLutOp.t()};
    ins.reserve(lut_size + 2);
    // We extract one value at a time from one LUT using different maps, so we
    // need to pass the LUT `lut_size` times
    for (auto i = 0; i < lut_size; i++)
      ins.push_back(fheLinalgLutOp.luts());
    llvm::SmallVector<mlir::Value, 1> outs{init};
    llvm::StringRef doc{""};
    llvm::StringRef call{""};

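The hunk ends before the op is actually built. For context, here is a minimal sketch of how the pieces above (resTypes, ins, outs, maps, iteratorTypes, doc, call, bodyBuilder) are typically handed to the linalg.generic builder; the exact call in the pass is not shown in this diff, and the return type of parallelIteratorType is assumed to be compatible with what the builder expects.

    // Sketch only: the GenericOp builder overload taking indexing maps,
    // iterator types, doc/library-call strings and a body-builder callback.
    auto genericOp = rewriter.create<mlir::linalg::GenericOp>(
        fheLinalgLutOp.getLoc(), resTypes, ins, outs, maps, iteratorTypes, doc,
        call, bodyBuilder);
    // Replace the FHELinalg op with the results of the generated generic op.
    rewriter.replaceOp(fheLinalgLutOp, genericOp->getResults());
    return mlir::success();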
@@ -1,21 +1,20 @@
// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s

//CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
//CHECK: #map1 = affine_map<(d0, d1) -> (d0, d1, 0)>
//CHECK: #map2 = affine_map<(d0, d1) -> (d0, d1, 1)>
//CHECK: #map3 = affine_map<(d0, d1) -> (d0, d1, 2)>
//CHECK: #map4 = affine_map<(d0, d1) -> (d0, d1, 3)>
//CHECK: func.func @multi_lut(%[[A0:.*]]: tensor<4x4x!FHE.eint<2>>, %[[A1:.*]]: tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> {
//CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<4x4x!FHE.eint<2>>
//CHECK: %[[V1:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3, #map4, #map0], iterator_types = ["parallel", "parallel"]} ins(%[[A0]], %[[A1]], %arg1, %arg1, %arg1 : tensor<4x4x!FHE.eint<2>>, tensor<4x4x4xi64>, tensor<4x4x4xi64>, tensor<4x4x4xi64>, tensor<4x4x4xi64>) outs(%[[V0]] : tensor<4x4x!FHE.eint<2>>) {
//CHECK: ^bb0(%[[A2:.*]]: !FHE.eint<2>, %[[A3:.*]]: i64, %[[A4:.*]]: i64, %[[A5:.*]]: i64, %[[A6:.*]]: i64, %[[A7:.*]]: !FHE.eint<2>):
//CHECK: %[[V2:.*]] = tensor.from_elements %[[A3]], %[[A4]], %[[A5]], %[[A6]] : tensor<4xi64>
//CHECK: %[[V3:.*]] = "FHE.apply_lookup_table"(%[[A2]], %[[V2]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
//CHECK: linalg.yield %[[V3]] : !FHE.eint<2>

//CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
//CHECK: func.func @multi_lut(%arg0: tensor<4x4x!FHE.eint<2>>, %[[LUTS:.*]]: tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> {
//CHECK: %[[MEM:.*]] = bufferization.alloc_tensor() : tensor<4x4x!FHE.eint<2>>
//CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<4x4x!FHE.eint<2>>) outs(%[[MEM]] : tensor<4x4x!FHE.eint<2>>) {
//CHECK: ^bb0(%[[IN:.*]]: !FHE.eint<2>, %[[UNUSED:.*]]: !FHE.eint<2>):
//CHECK: %[[INDEXA:.*]] = linalg.index 0 : index
//CHECK: %[[INDEXB:.*]] = linalg.index 1 : index
//CHECK: %[[LUT:.*]] = tensor.extract_slice %[[LUTS]][%[[INDEXA]], %[[INDEXB]], 0] [1, 1, 4] [1, 1, 1] : tensor<4x4x4xi64> to tensor<4xi64>
//CHECK: %[[V:.*]] = "FHE.apply_lookup_table"(%[[IN]], %[[LUT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
//CHECK: linalg.yield %[[V]] : !FHE.eint<2>
//CHECK: } -> tensor<4x4x!FHE.eint<2>>
//CHECK: return %[[V1]] : tensor<4x4x!FHE.eint<2>>
//CHECK: return %[[R]] : tensor<4x4x!FHE.eint<2>>
//CHECK: }
func.func @multi_lut(%arg0: tensor<4x4x!FHE.eint<2>>, %arg1: tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> {
  %0 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<4x4x!FHE.eint<2>>, tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>>
func.func @multi_lut(%arg0: tensor<4x4x!FHE.eint<2>>, %luts: tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> {
  %0 = "FHELinalg.apply_multi_lookup_table"(%arg0, %luts): (tensor<4x4x!FHE.eint<2>>, tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>>
  return %0: tensor<4x4x!FHE.eint<2>>
}

@@ -1,21 +1,18 @@
// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s

//CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
//CHECK: #map1 = affine_map<(d0, d1) -> (d1, 0)>
//CHECK: #map2 = affine_map<(d0, d1) -> (d1, 1)>
//CHECK: #map3 = affine_map<(d0, d1) -> (d1, 2)>
//CHECK: #map4 = affine_map<(d0, d1) -> (d1, 3)>
//CHECK: func.func @multi_lut(%[[A0:.*]]: tensor<4x3x!FHE.eint<2>>, %[[A1:.*]]: tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> {
//CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<4x3x!FHE.eint<2>>
//CHECK: %[[V1:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3, #map4, #map0], iterator_types = ["parallel", "parallel"]} ins(%[[A0]], %[[A1]], %arg1, %arg1, %arg1 : tensor<4x3x!FHE.eint<2>>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>) outs(%[[V0]] : tensor<4x3x!FHE.eint<2>>) {
//CHECK: ^bb0(%[[A2:.*]]: !FHE.eint<2>, %[[A3:.*]]: i64, %[[A4:.*]]: i64, %[[A5:.*]]: i64, %[[A6:.*]]: i64, %[[A7:.*]]: !FHE.eint<2>):
//CHECK: %[[V2:.*]] = tensor.from_elements %[[A3]], %[[A4]], %[[A5]], %[[A6]] : tensor<4xi64>
//CHECK: %[[V3:.*]] = "FHE.apply_lookup_table"(%[[A2]], %[[V2]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
//CHECK: linalg.yield %[[V3]] : !FHE.eint<2>
//CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
//CHECK: func.func @multi_lut(%arg0: tensor<4x3x!FHE.eint<2>>, %[[LUTS:.*]]: tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> {
//CHECK: %[[MEM:.*]] = bufferization.alloc_tensor() : tensor<4x3x!FHE.eint<2>>
//CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<4x3x!FHE.eint<2>>) outs(%[[MEM]] : tensor<4x3x!FHE.eint<2>>) {
//CHECK: ^bb0(%[[IN:.*]]: !FHE.eint<2>, %[[UNUSED:.*]]: !FHE.eint<2>):
//CHECK: %[[INDEX:.*]] = linalg.index 1 : index
//CHECK: %[[LUT:.*]] = tensor.extract_slice %arg1[%[[INDEX]], 0] [1, 4] [1, 1] : tensor<3x4xi64> to tensor<4xi64>
//CHECK: %[[V:.*]] = "FHE.apply_lookup_table"(%[[IN]], %[[LUT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
//CHECK: linalg.yield %[[V]] : !FHE.eint<2>
//CHECK: } -> tensor<4x3x!FHE.eint<2>>
//CHECK: return %[[V1]] : tensor<4x3x!FHE.eint<2>>
//CHECK: return %[[R]] : tensor<4x3x!FHE.eint<2>>
//CHECK: }
func.func @multi_lut(%arg0: tensor<4x3x!FHE.eint<2>>, %arg1: tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> {
  %1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<4x3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>>
func.func @multi_lut(%arg0: tensor<4x3x!FHE.eint<2>>, %luts: tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> {
  %1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %luts): (tensor<4x3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>>
  return %1: tensor<4x3x!FHE.eint<2>>
}
