fix: apply_mapped_lookup_table, remove costly workaround

Closes #880
rudy
2023-01-13 17:27:28 +01:00
committed by rudy-6-4
parent d1ddd60a23
commit b7668a7256


@@ -299,9 +299,6 @@ llvm::SmallVector<llvm::StringRef> parallelIteratorType(int n) {
/// specification for the iteration dimensions and appropriate operations
/// managing the accumulator of `linalg.generic`.
///
/// The current implementation does not rely on 'tensor.extract_slice'
/// because of a bug in lowering this operation.
///
/// Example:
/// %res = "FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map)
/// : (tensor<2x3x!FHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>)
@@ -317,17 +314,8 @@ llvm::SmallVector<llvm::StringRef> parallelIteratorType(int n) {
/// tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>) {
/// ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %lut_idx: index, %arg5:
/// !TFHE.glwe<{_,_,_}{2}>): // no predecessors
/// // SHOULD BE
/// %lut = tensor.extract_slice %arg1[%[[LUTIDX]], 0] [1,4] [1, 1]
/// : tensor<5x4xi64> to tensor<4xi64>
/// // BUT IS
/// %i0 = arith.constant 0 : index
/// ...
/// %i3 = arith.constant 3 : index
/// %e0 = tensor.extract %arg5[%lut_idx, %i0] : tensor<5x4xi64>
/// ...
/// %e3 = tensor.extract %arg5[%lut_idx, %i3] : tensor<5x4xi64>
/// %lut = tensor.from_elements %e0, ..., %e3 : tensor<4xi64>
/// %res = "TFHE.apply_lookup_table"(%arg3, %[[LUT]])
/// {baseLogBS = -1 : i32, baseLogKS = -1 : i32,
/// glweDimension = -1 : i32,
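The doc-comment example above shows the per-element body the lowering now emits: a rank-reducing tensor.extract_slice that pulls one LUT row out of the %luts tensor, followed by TFHE.apply_lookup_table. As a minimal C++ sketch of how such a slice can be built with the MLIR builder API (the helper name extractLutRow and its parameters are illustrative, not code from this commit, and the exact ExtractSliceOp builder overload depends on the pinned MLIR revision):

// Illustrative sketch only: extract row `lutIdx` of a tensor<NxKxi64> of LUTs
// as a rank-reduced tensor<Kxi64>, mirroring the doc-comment example above.
// `extractLutRow` is a hypothetical helper, not code from this commit.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

static mlir::Value extractLutRow(mlir::OpBuilder &b, mlir::Location loc,
                                 mlir::Value luts, mlir::Value lutIdx,
                                 int64_t lutSize, mlir::Type lutElmtTy) {
  // offsets = [%lutIdx, 0], sizes = [1, lutSize], strides = [1, 1]
  llvm::SmallVector<mlir::OpFoldResult> offsets{lutIdx, b.getIndexAttr(0)};
  llvm::SmallVector<mlir::OpFoldResult> sizes{b.getIndexAttr(1),
                                              b.getIndexAttr(lutSize)};
  llvm::SmallVector<mlir::OpFoldResult> strides{b.getIndexAttr(1),
                                                b.getIndexAttr(1)};
  // Passing an explicit 1-D result type makes the slice rank-reducing:
  // tensor<NxKxi64> -> tensor<Kxi64>.
  auto lutTy = mlir::RankedTensorType::get({lutSize}, lutElmtTy);
  return b.create<mlir::tensor::ExtractSliceOp>(loc, lutTy, luts, offsets,
                                                sizes, strides);
}

The explicit 1-D result type is what gives the rank reduction that the removed workaround had to emulate by hand.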
@@ -383,49 +371,23 @@ struct FHELinalgApplyMappedLookupTableToLinalgGeneric
// Create the body of the `linalg.generic` op
// %arg0 is an element of t (encrypted int)
// %arg1 is an element of map (i64)
// %arg1 is the lut index (i64)
// %arg2 is the output element
auto lambdaBlock = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
auto tElmt = blockArgs[0];
auto lutIdx = blockArgs[1];
auto indexTy = rewriter.getIndexType();
// %lut = extract_slice %luts[%lutIdx, 0][1, lutSize][1, 1] :
// tensor<NxKxi64> to tensor<Kxi64>
mlir::Value lut;
const bool WORKAROUND_EXTRACT_SLICE = true;
if (!WORKAROUND_EXTRACT_SLICE) {
sliceArg offsets{lutIdx, _0_};
sliceArg sizes{_1_, lutSizeValue};
sliceArg strides{_1_, _1_};
auto lutTy = mlir::RankedTensorType::get(
{static_cast<int64_t>(lutSize)}, lutElmtTy);
lut = nestedBuilder.create<tensor::ExtractSliceOp>(
loc, lutTy, luts, offsets, sizes, strides);
} else {
// WORKAROUND BEGIN
// A bug in linalg-bufferize prevents rank reduction in extract_slice.
// Reshaping does not work either, or is too complicated, so let's rebuild
// the tensor from scratch.
llvm::SmallVector<mlir::Value> consts;
llvm::SmallVector<mlir::Value> extracts;
for (int i = 0; i < lutSize; i++) {
consts.push_back(
// %5 = arith.constant(<i> : index) : index
nestedBuilder.create<mlir::arith::ConstantOp>(
loc, indexTy, rewriter.getIndexAttr(i)));
}
for (int i = 0; i < lutSize; i++) {
extracts.push_back(
// %8 = tensor.extract %luts[<lutIdx>, <i>] : ...
nestedBuilder.create<tensor::ExtractOp>(
loc, luts, mlir::ValueRange({lutIdx, consts[i]})));
}
// %12 = tensor.from_elements %8, ... : ...
lut = nestedBuilder.create<tensor::FromElementsOp>(loc, extracts);
} // WORKAROUND END
sliceArg offsets{lutIdx, _0_};
sliceArg sizes{_1_, lutSizeValue};
sliceArg strides{_1_, _1_};
auto lutTy = mlir::RankedTensorType::get({static_cast<int64_t>(lutSize)},
lutElmtTy);
mlir::Value lut = nestedBuilder.create<tensor::ExtractSliceOp>(
loc, lutTy, luts, offsets, sizes, strides);
// %res1 = apply_lookup_table %arg0 %lut
auto lookup = nestedBuilder.create<FHE::ApplyLookupTableEintOp>(
loc, elementTy, tElmt, lut);
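For contrast, the workaround removed by this commit rebuilt the LUT row element by element because of the linalg-bufferize bug noted above. A hedged sketch of that shape, modelled on the deleted code (the helper name and signature are illustrative, and the include paths follow upstream MLIR naming, which may differ for the pinned revision), makes the cost visible: lutSize arith.constant and tensor.extract ops plus a tensor.from_elements for every element of the mapped tensor.

// Illustrative sketch of the removed workaround (not the new code path):
// materialize LUT row `lutIdx` by extracting each of its `lutSize` elements
// and reassembling them with tensor.from_elements.
#include "mlir/Dialect/Arith/IR/Arith.h"   // path may differ on older MLIR
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

static mlir::Value extractLutRowWorkaround(mlir::OpBuilder &b,
                                           mlir::Location loc,
                                           mlir::Value luts,
                                           mlir::Value lutIdx,
                                           int64_t lutSize) {
  auto indexTy = b.getIndexType();
  llvm::SmallVector<mlir::Value> elements;
  for (int64_t i = 0; i < lutSize; ++i) {
    // %ci = arith.constant <i> : index
    mlir::Value col =
        b.create<mlir::arith::ConstantOp>(loc, indexTy, b.getIndexAttr(i));
    // %ei = tensor.extract %luts[%lutIdx, %ci] : tensor<NxKxi64>
    elements.push_back(b.create<mlir::tensor::ExtractOp>(
        loc, luts, mlir::ValueRange{lutIdx, col}));
  }
  // %lut = tensor.from_elements %e0, ..., %e{K-1} : tensor<Kxi64>
  return b.create<mlir::tensor::FromElementsOp>(loc, elements);
}

Since this body runs once per element of the input tensor inside linalg.generic, the single extract_slice form is substantially cheaper, which is what the commit title refers to as the costly workaround.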