test(compiler): Add test for lambdas returning tensors of different integer types

Add tests with lambdas returning constant n-dimensional tensors
composed of different integer types (uint8_t, uint16_t, uint32_t, and
uint64_t).
This commit is contained in:
Andi Drebes
2021-11-09 14:23:02 +01:00
committed by Quentin Bourgerie
parent f54c0dd8d8
commit d1960a2a7f

View File

@@ -329,4 +329,72 @@ func @main(%t0: tensor<2x10xi64>, %t1: tensor<2x2xi64>) -> tensor<2x10xi64> {
}
}
}
}
}
template <typename T>
void checkResultTensor(
bool &status,
llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> &res) {
status = false;
ASSERT_TRUE((*res)
->isa<mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<T>>>());
mlir::zamalang::TensorLambdaArgument<mlir::zamalang::IntLambdaArgument<T>>
&resp = (*res)
->cast<mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<T>>>();
ASSERT_EQ(resp.getDimensions().size(), (size_t)3);
ASSERT_EQ(resp.getDimensions().at(0), 5);
ASSERT_EQ(resp.getDimensions().at(1), 3);
ASSERT_EQ(resp.getDimensions().at(2), 2);
ASSERT_EXPECTED_VALUE(resp.getNumElements(), 5 * 3 * 2);
for (size_t i = 0; i < 5 * 3 * 2; i++) {
ASSERT_EQ(resp.getValue()[i], 1_u64);
}
status = true;
}
class ReturnTensorWithPrecision : public ::testing::TestWithParam<int> {};
TEST_P(ReturnTensorWithPrecision, return_tensor) {
uint64_t precision = GetParam();
std::ostringstream mlirProgram;
mlirProgram << "func @main() -> tensor<5x3x2xi" << precision << "> {\n"
<< " %res = arith.constant dense<1> : tensor<5x3x2xi"
<< precision << ">\n"
<< " return %res : tensor<5x3x2xi" << precision << ">\n"
<< "}";
mlir::zamalang::JitCompilerEngine::Lambda lambda =
checkedJit(mlirProgram.str(), "main", true);
llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> res =
lambda.operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>({});
ASSERT_EXPECTED_SUCCESS(res);
bool status;
if (precision > 64)
GTEST_FATAL_FAILURE_("Cannot handle precision > 64 bits");
else if (precision > 32)
checkResultTensor<uint64_t>(status, res);
else if (precision > 16)
checkResultTensor<uint32_t>(status, res);
else if (precision > 8)
checkResultTensor<uint16_t>(status, res);
else
checkResultTensor<uint8_t>(status, res);
ASSERT_TRUE(status);
}
INSTANTIATE_TEST_SUITE_P(ReturnTensor, ReturnTensorWithPrecision,
::testing::Values(1, 7, 8, 9, 15, 16, 17, 31, 32, 33,
63, 64));