feat: support more dtype for scalars/tensors

dtype supported now: uint8, uint16, uint32, uint64
This commit is contained in:
youben11
2021-12-10 10:36:04 +01:00
committed by Ayoub Benaissa
parent 550318f67e
commit 60b2cfd9b7
5 changed files with 143 additions and 12 deletions

View File

@@ -143,14 +143,38 @@ uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg) {
return arg->getValue();
}
lambdaArgument lambdaArgumentFromTensor(std::vector<uint8_t> data,
std::vector<int64_t> dimensions) {
lambdaArgument lambdaArgumentFromTensorU8(std::vector<uint8_t> data,
std::vector<int64_t> dimensions) {
lambdaArgument tensor_arg{
std::make_shared<mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint8_t>>>(data, dimensions)};
return tensor_arg;
}
lambdaArgument lambdaArgumentFromTensorU16(std::vector<uint16_t> data,
std::vector<int64_t> dimensions) {
lambdaArgument tensor_arg{
std::make_shared<mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint16_t>>>(data, dimensions)};
return tensor_arg;
}
lambdaArgument lambdaArgumentFromTensorU32(std::vector<uint32_t> data,
std::vector<int64_t> dimensions) {
lambdaArgument tensor_arg{
std::make_shared<mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint32_t>>>(data, dimensions)};
return tensor_arg;
}
lambdaArgument lambdaArgumentFromTensorU64(std::vector<uint64_t> data,
std::vector<int64_t> dimensions) {
lambdaArgument tensor_arg{
std::make_shared<mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint64_t>>>(data, dimensions)};
return tensor_arg;
}
lambdaArgument lambdaArgumentFromScalar(uint64_t scalar) {
lambdaArgument scalar_arg{
std::make_shared<mlir::zamalang::IntLambdaArgument<uint64_t>>(scalar)};