fix(compiler): Fix warnings and naming in test for parametric slices

- Rename `extract_slice_parametric_2x2` to
  `extract_slice_parametric_2x2x2x2` to reflect the 4-dimensional
  structure of the tiles.

- Make the array with the specification of the dimensions in
  `extract_slice_parametric_2x2x2x2` a `constexpr` in order to prevent
  the array `A` from being treated as a variable-length array.

- Cast the expression for the expected size of the result to `size_t`
  and change the type of the induction variables of the loop nest
  producing the initial values for the array `A` to `int64_t` to avoid
  warnings about the comparison of integer expressions with different
  signedness.
This commit is contained in:
Andi Drebes
2021-12-16 10:46:07 +01:00
parent 27ca5122bc
commit 77b7aa2f7c

View File

@@ -126,8 +126,8 @@ func @main(%t: tensor<8x4x!HLFHE.eint<6>>, %y: index, %x: index) -> tensor<2x2x!
}
// Extracts 4D tiles from a 4D tensor
TEST(End2EndJit_EncryptedTensor_4D, extract_slice_parametric_2x2) {
const int64_t dimSizes[4] = {8, 4, 5, 3};
TEST(End2EndJit_EncryptedTensor_4D, extract_slice_parametric_2x2x2x2) {
constexpr int64_t dimSizes[4] = {8, 4, 5, 3};
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<8x4x5x3x!HLFHE.eint<6>>, %d0: index, %d1: index, %d2: index, %d3: index) -> tensor<2x2x2x2x!HLFHE.eint<6>> {
@@ -138,10 +138,10 @@ func @main(%t: tensor<8x4x5x3x!HLFHE.eint<6>>, %d0: index, %d1: index, %d2: inde
uint8_t A[dimSizes[0]][dimSizes[1]][dimSizes[2]][dimSizes[3]];
// Fill with some reproducible pattern
for (size_t d0 = 0; d0 < dimSizes[0]; d0++) {
for (size_t d1 = 0; d1 < dimSizes[1]; d1++) {
for (size_t d2 = 0; d2 < dimSizes[2]; d2++) {
for (size_t d3 = 0; d3 < dimSizes[3]; d3++) {
for (int64_t d0 = 0; d0 < dimSizes[0]; d0++) {
for (int64_t d1 = 0; d1 < dimSizes[1]; d1++) {
for (int64_t d2 = 0; d2 < dimSizes[2]; d2++) {
for (int64_t d3 = 0; d3 < dimSizes[3]; d3++) {
A[d0][d1][d2][d3] = d0 + d1 + d2 + d3;
}
}
@@ -175,7 +175,7 @@ func @main(%t: tensor<8x4x5x3x!HLFHE.eint<6>>, %d0: index, %d1: index, %d2: inde
{&argT, &argD0, &argD1, &argD2, &argD3});
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), 2 * 2 * 2 * 2);
ASSERT_EQ(res->size(), (size_t)(2 * 2 * 2 * 2));
for (size_t rd0 = 0; rd0 < 2; rd0++) {
for (size_t rd1 = 0; rd1 < 2; rd1++) {