From 16d0502f568515ca9efbd9866362f6b0bb6c5d96 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Fri, 10 Dec 2021 15:51:12 +0100 Subject: [PATCH] fix(compiler): Initialize strides of memref parameters when JIT-invoking a function Upon invocation of a function with memref arguments, the strides for all dimensions are currently set to 0. This causes dynamic offsets to be calculated incorrectly in the function body. This patch replaces the placeholder values with the actual strides for each dimension and adds a test with parametric slice extraction from a tensor that triggers dynamic indexing. --- compiler/lib/Support/Jit.cpp | 17 ++- .../end_to_end_jit_encrypted_tensor.cc | 102 ++++++++++++++++++ 2 files changed, 114 insertions(+), 5 deletions(-) diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 4fe1140dc..14f205adb 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -311,12 +311,19 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, rawArg[offset] = &inputs[offset]; offset++; } - // strides is an array of size equals to numDim - for (size_t i = 0; i < shape.size(); i++) { - inputs[offset] = (void *)0; - rawArg[offset] = &inputs[offset]; - offset++; + + // Set the stride for each dimension, equal to the product of the + // following dimensions. + int64_t stride = 1; + + for (ssize_t i = shape.size() - 1; i >= 0; i--) { + inputs[offset + i] = (void *)stride; + rawArg[offset + i] = &inputs[offset + i]; + stride *= shape[i]; } + + offset += shape.size(); + return llvm::Error::success(); } diff --git a/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc b/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc index 72004ed6b..7551f9eb6 100644 --- a/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc +++ b/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc @@ -88,6 +88,108 @@ func @main(%t: tensor<2x10x!HLFHE.eint<6>>) -> tensor<1x5x!HLFHE.eint<6>> { } } +TEST(End2EndJit_EncryptedTensor_2D, extract_slice_parametric_2x2) { + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main(%t: tensor<8x4x!HLFHE.eint<6>>, %y: index, %x: index) -> tensor<2x2x!HLFHE.eint<6>> { + %r = tensor.extract_slice %t[%y, %x][2, 2][1, 1] : tensor<8x4x!HLFHE.eint<6>> to tensor<2x2x!HLFHE.eint<6>> + return %r : tensor<2x2x!HLFHE.eint<6>> +} +)XXX"); + const size_t rows = 8; + const size_t cols = 4; + const size_t tileSize = 2; + const uint8_t A[rows][cols] = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 0, 1, 2}, + {3, 4, 5, 6}, {7, 8, 9, 0}, {1, 2, 3, 4}, + {5, 6, 7, 8}, {9, 0, 1, 2}}; + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> + argT(llvm::ArrayRef((const uint8_t *)A, rows * cols), + {rows, cols}); + + for (uint64_t y = 0; y <= rows - tileSize; y += tileSize) { + for (uint64_t x = 0; x <= cols - tileSize; x += tileSize) { + mlir::zamalang::IntLambdaArgument argY(y); + mlir::zamalang::IntLambdaArgument argX(x); + + llvm::Expected> res = + lambda.operator()>({&argT, &argY, &argX}); + + ASSERT_EXPECTED_SUCCESS(res); + ASSERT_EQ(res->size(), tileSize * tileSize); + ASSERT_EQ((*res)[0], A[y][x]); + ASSERT_EQ((*res)[1], A[y][x + 1]); + ASSERT_EQ((*res)[2], A[y + 1][x]); + ASSERT_EQ((*res)[3], A[y + 1][x + 1]); + } + } +} + +// Extracts 4D tiles from a 4D tensor +TEST(End2EndJit_EncryptedTensor_4D, extract_slice_parametric_2x2) { + const int64_t dimSizes[4] = {8, 4, 5, 3}; + + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main(%t: tensor<8x4x5x3x!HLFHE.eint<6>>, %d0: index, %d1: index, %d2: index, %d3: index) -> tensor<2x2x2x2x!HLFHE.eint<6>> { + %r = tensor.extract_slice %t[%d0, %d1, %d2, %d3][2, 2, 2, 2][1, 1, 1, 1] : tensor<8x4x5x3x!HLFHE.eint<6>> to tensor<2x2x2x2x!HLFHE.eint<6>> + return %r : tensor<2x2x2x2x!HLFHE.eint<6>> +} +)XXX"); + uint8_t A[dimSizes[0]][dimSizes[1]][dimSizes[2]][dimSizes[3]]; + + // Fill with some reproducible pattern + for (size_t d0 = 0; d0 < dimSizes[0]; d0++) { + for (size_t d1 = 0; d1 < dimSizes[1]; d1++) { + for (size_t d2 = 0; d2 < dimSizes[2]; d2++) { + for (size_t d3 = 0; d3 < dimSizes[3]; d3++) { + A[d0][d1][d2][d3] = d0 + d1 + d2 + d3; + } + } + } + } + + const size_t ncoords = 5; + const size_t coords[ncoords][4] = { + {0, 0, 0, 0}, {1, 1, 1, 1}, {6, 2, 0, 1}, {3, 1, 2, 0}, {3, 1, 2, 1}}; + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> + argT(llvm::ArrayRef((const uint8_t *)A, + dimSizes[0] * dimSizes[1] * dimSizes[2] * + dimSizes[3]), + dimSizes); + + for (uint64_t i = 0; i < ncoords; i++) { + size_t d0 = coords[i][0]; + size_t d1 = coords[i][1]; + size_t d2 = coords[i][2]; + size_t d3 = coords[i][3]; + + mlir::zamalang::IntLambdaArgument argD0(d0); + mlir::zamalang::IntLambdaArgument argD1(d1); + mlir::zamalang::IntLambdaArgument argD2(d2); + mlir::zamalang::IntLambdaArgument argD3(d3); + + llvm::Expected> res = + lambda.operator()>( + {&argT, &argD0, &argD1, &argD2, &argD3}); + + ASSERT_EXPECTED_SUCCESS(res); + ASSERT_EQ(res->size(), 2 * 2 * 2 * 2); + + for (size_t rd0 = 0; rd0 < 2; rd0++) { + for (size_t rd1 = 0; rd1 < 2; rd1++) { + for (size_t rd2 = 0; rd2 < 2; rd2++) { + for (size_t rd3 = 0; rd3 < 2; rd3++) { + ASSERT_EQ((*res)[rd0 * 8 + rd1 * 4 + rd2 * 2 + rd3], + A[d0 + rd0][d1 + rd1][d2 + rd2][d3 + rd3]); + } + } + } + } + } +} + TEST(End2EndJit_EncryptedTensor_2D, extract_slice_stride) { mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(