diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 4fe1140dc..14f205adb 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -311,12 +311,19 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, rawArg[offset] = &inputs[offset]; offset++; } - // strides is an array of size equals to numDim - for (size_t i = 0; i < shape.size(); i++) { - inputs[offset] = (void *)0; - rawArg[offset] = &inputs[offset]; - offset++; + + // Set the stride for each dimension, equal to the product of the + // following dimensions. + int64_t stride = 1; + + for (ssize_t i = shape.size() - 1; i >= 0; i--) { + inputs[offset + i] = (void *)stride; + rawArg[offset + i] = &inputs[offset + i]; + stride *= shape[i]; } + + offset += shape.size(); + return llvm::Error::success(); } diff --git a/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc b/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc index 72004ed6b..7551f9eb6 100644 --- a/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc +++ b/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc @@ -88,6 +88,108 @@ func @main(%t: tensor<2x10x!HLFHE.eint<6>>) -> tensor<1x5x!HLFHE.eint<6>> { } } +TEST(End2EndJit_EncryptedTensor_2D, extract_slice_parametric_2x2) { + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main(%t: tensor<8x4x!HLFHE.eint<6>>, %y: index, %x: index) -> tensor<2x2x!HLFHE.eint<6>> { + %r = tensor.extract_slice %t[%y, %x][2, 2][1, 1] : tensor<8x4x!HLFHE.eint<6>> to tensor<2x2x!HLFHE.eint<6>> + return %r : tensor<2x2x!HLFHE.eint<6>> +} +)XXX"); + const size_t rows = 8; + const size_t cols = 4; + const size_t tileSize = 2; + const uint8_t A[rows][cols] = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 0, 1, 2}, + {3, 4, 5, 6}, {7, 8, 9, 0}, {1, 2, 3, 4}, + {5, 6, 7, 8}, {9, 0, 1, 2}}; + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> + argT(llvm::ArrayRef((const uint8_t *)A, rows * cols), + {rows, cols}); + + for (uint64_t y = 0; y <= rows - tileSize; y += tileSize) { + for (uint64_t x = 0; x <= cols - tileSize; x += tileSize) { + mlir::zamalang::IntLambdaArgument argY(y); + mlir::zamalang::IntLambdaArgument argX(x); + + llvm::Expected> res = + lambda.operator()>({&argT, &argY, &argX}); + + ASSERT_EXPECTED_SUCCESS(res); + ASSERT_EQ(res->size(), tileSize * tileSize); + ASSERT_EQ((*res)[0], A[y][x]); + ASSERT_EQ((*res)[1], A[y][x + 1]); + ASSERT_EQ((*res)[2], A[y + 1][x]); + ASSERT_EQ((*res)[3], A[y + 1][x + 1]); + } + } +} + +// Extracts 4D tiles from a 4D tensor +TEST(End2EndJit_EncryptedTensor_4D, extract_slice_parametric_2x2) { + const int64_t dimSizes[4] = {8, 4, 5, 3}; + + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main(%t: tensor<8x4x5x3x!HLFHE.eint<6>>, %d0: index, %d1: index, %d2: index, %d3: index) -> tensor<2x2x2x2x!HLFHE.eint<6>> { + %r = tensor.extract_slice %t[%d0, %d1, %d2, %d3][2, 2, 2, 2][1, 1, 1, 1] : tensor<8x4x5x3x!HLFHE.eint<6>> to tensor<2x2x2x2x!HLFHE.eint<6>> + return %r : tensor<2x2x2x2x!HLFHE.eint<6>> +} +)XXX"); + uint8_t A[dimSizes[0]][dimSizes[1]][dimSizes[2]][dimSizes[3]]; + + // Fill with some reproducible pattern + for (size_t d0 = 0; d0 < dimSizes[0]; d0++) { + for (size_t d1 = 0; d1 < dimSizes[1]; d1++) { + for (size_t d2 = 0; d2 < dimSizes[2]; d2++) { + for (size_t d3 = 0; d3 < dimSizes[3]; d3++) { + A[d0][d1][d2][d3] = d0 + d1 + d2 + d3; + } + } + } + } + + const size_t ncoords = 5; + const size_t coords[ncoords][4] = { + {0, 0, 0, 0}, {1, 1, 1, 1}, {6, 2, 0, 1}, {3, 1, 2, 0}, {3, 1, 2, 1}}; + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> + argT(llvm::ArrayRef((const uint8_t *)A, + dimSizes[0] * dimSizes[1] * dimSizes[2] * + dimSizes[3]), + dimSizes); + + for (uint64_t i = 0; i < ncoords; i++) { + size_t d0 = coords[i][0]; + size_t d1 = coords[i][1]; + size_t d2 = coords[i][2]; + size_t d3 = coords[i][3]; + + mlir::zamalang::IntLambdaArgument argD0(d0); + mlir::zamalang::IntLambdaArgument argD1(d1); + mlir::zamalang::IntLambdaArgument argD2(d2); + mlir::zamalang::IntLambdaArgument argD3(d3); + + llvm::Expected> res = + lambda.operator()>( + {&argT, &argD0, &argD1, &argD2, &argD3}); + + ASSERT_EXPECTED_SUCCESS(res); + ASSERT_EQ(res->size(), 2 * 2 * 2 * 2); + + for (size_t rd0 = 0; rd0 < 2; rd0++) { + for (size_t rd1 = 0; rd1 < 2; rd1++) { + for (size_t rd2 = 0; rd2 < 2; rd2++) { + for (size_t rd3 = 0; rd3 < 2; rd3++) { + ASSERT_EQ((*res)[rd0 * 8 + rd1 * 4 + rd2 * 2 + rd3], + A[d0 + rd0][d1 + rd1][d2 + rd2][d3 + rd3]); + } + } + } + } + } +} + TEST(End2EndJit_EncryptedTensor_2D, extract_slice_stride) { mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(