From d1960a2a7fb1875491a8183d33f4296df7a200b2 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Tue, 9 Nov 2021 14:23:02 +0100 Subject: [PATCH] test(compiler): Add test for lambdas returning tensors of different integer types Add tests with lambdas returning constant n-dimensional tensors composed of different integer types (uint8_t, uint16_t, uint32_t, and uint64_t). --- .../unittest/end_to_end_jit_clear_tensor.cc | 70 ++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/compiler/tests/unittest/end_to_end_jit_clear_tensor.cc b/compiler/tests/unittest/end_to_end_jit_clear_tensor.cc index e046c0aa4..2cc78c1f0 100644 --- a/compiler/tests/unittest/end_to_end_jit_clear_tensor.cc +++ b/compiler/tests/unittest/end_to_end_jit_clear_tensor.cc @@ -329,4 +329,72 @@ func @main(%t0: tensor<2x10xi64>, %t1: tensor<2x2xi64>) -> tensor<2x10xi64> { } } } -} \ No newline at end of file +} + +template +void checkResultTensor( + bool &status, + llvm::Expected> &res) { + status = false; + + ASSERT_TRUE((*res) + ->isa>>()); + + mlir::zamalang::TensorLambdaArgument> + &resp = (*res) + ->cast>>(); + + ASSERT_EQ(resp.getDimensions().size(), (size_t)3); + ASSERT_EQ(resp.getDimensions().at(0), 5); + ASSERT_EQ(resp.getDimensions().at(1), 3); + ASSERT_EQ(resp.getDimensions().at(2), 2); + + ASSERT_EXPECTED_VALUE(resp.getNumElements(), 5 * 3 * 2); + + for (size_t i = 0; i < 5 * 3 * 2; i++) { + ASSERT_EQ(resp.getValue()[i], 1_u64); + } + + status = true; +} + +class ReturnTensorWithPrecision : public ::testing::TestWithParam {}; + +TEST_P(ReturnTensorWithPrecision, return_tensor) { + uint64_t precision = GetParam(); + std::ostringstream mlirProgram; + + mlirProgram << "func @main() -> tensor<5x3x2xi" << precision << "> {\n" + << " %res = arith.constant dense<1> : tensor<5x3x2xi" + << precision << ">\n" + << " return %res : tensor<5x3x2xi" << precision << ">\n" + << "}"; + + mlir::zamalang::JitCompilerEngine::Lambda lambda = + checkedJit(mlirProgram.str(), "main", true); + + llvm::Expected> res = + lambda.operator()>({}); + + ASSERT_EXPECTED_SUCCESS(res); + bool status; + + if (precision > 64) + GTEST_FATAL_FAILURE_("Cannot handle precision > 64 bits"); + else if (precision > 32) + checkResultTensor(status, res); + else if (precision > 16) + checkResultTensor(status, res); + else if (precision > 8) + checkResultTensor(status, res); + else + checkResultTensor(status, res); + + ASSERT_TRUE(status); +} + +INSTANTIATE_TEST_SUITE_P(ReturnTensor, ReturnTensorWithPrecision, + ::testing::Values(1, 7, 8, 9, 15, 16, 17, 31, 32, 33, + 63, 64));