mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
test(compiler): Add test for lambdas returning tensors of different integer types
Add tests with lambdas returning constant n-dimensional tensors composed of different integer types (uint8_t, uint16_t, uint32_t, and uint64_t).
This commit is contained in:
committed by
Quentin Bourgerie
parent
f54c0dd8d8
commit
d1960a2a7f
@@ -329,4 +329,72 @@ func @main(%t0: tensor<2x10xi64>, %t1: tensor<2x2xi64>) -> tensor<2x10xi64> {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void checkResultTensor(
|
||||
bool &status,
|
||||
llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> &res) {
|
||||
status = false;
|
||||
|
||||
ASSERT_TRUE((*res)
|
||||
->isa<mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<T>>>());
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<mlir::zamalang::IntLambdaArgument<T>>
|
||||
&resp = (*res)
|
||||
->cast<mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<T>>>();
|
||||
|
||||
ASSERT_EQ(resp.getDimensions().size(), (size_t)3);
|
||||
ASSERT_EQ(resp.getDimensions().at(0), 5);
|
||||
ASSERT_EQ(resp.getDimensions().at(1), 3);
|
||||
ASSERT_EQ(resp.getDimensions().at(2), 2);
|
||||
|
||||
ASSERT_EXPECTED_VALUE(resp.getNumElements(), 5 * 3 * 2);
|
||||
|
||||
for (size_t i = 0; i < 5 * 3 * 2; i++) {
|
||||
ASSERT_EQ(resp.getValue()[i], 1_u64);
|
||||
}
|
||||
|
||||
status = true;
|
||||
}
|
||||
|
||||
class ReturnTensorWithPrecision : public ::testing::TestWithParam<int> {};
|
||||
|
||||
TEST_P(ReturnTensorWithPrecision, return_tensor) {
|
||||
uint64_t precision = GetParam();
|
||||
std::ostringstream mlirProgram;
|
||||
|
||||
mlirProgram << "func @main() -> tensor<5x3x2xi" << precision << "> {\n"
|
||||
<< " %res = arith.constant dense<1> : tensor<5x3x2xi"
|
||||
<< precision << ">\n"
|
||||
<< " return %res : tensor<5x3x2xi" << precision << ">\n"
|
||||
<< "}";
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda =
|
||||
checkedJit(mlirProgram.str(), "main", true);
|
||||
|
||||
llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> res =
|
||||
lambda.operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>({});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
bool status;
|
||||
|
||||
if (precision > 64)
|
||||
GTEST_FATAL_FAILURE_("Cannot handle precision > 64 bits");
|
||||
else if (precision > 32)
|
||||
checkResultTensor<uint64_t>(status, res);
|
||||
else if (precision > 16)
|
||||
checkResultTensor<uint32_t>(status, res);
|
||||
else if (precision > 8)
|
||||
checkResultTensor<uint16_t>(status, res);
|
||||
else
|
||||
checkResultTensor<uint8_t>(status, res);
|
||||
|
||||
ASSERT_TRUE(status);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(ReturnTensor, ReturnTensorWithPrecision,
|
||||
::testing::Values(1, 7, 8, 9, 15, 16, 17, 31, 32, 33,
|
||||
63, 64));
|
||||
|
||||
Reference in New Issue
Block a user