From 60b2cfd9b7b069186a3a22c3c864039aa5e51d07 Mon Sep 17 00:00:00 2001
From: youben11
Date: Fri, 10 Dec 2021 10:36:04 +0100
Subject: [PATCH] feat: support more dtypes for scalars/tensors

dtypes supported now: uint8, uint16, uint32, uint64
---
 .../zamalang-c/Support/CompilerEngine.h       | 10 ++-
 .../lib/Bindings/Python/CompilerAPIModule.cpp | 17 +++-
 .../lib/Bindings/Python/zamalang/compiler.py  | 17 ++--
 compiler/lib/CAPI/Support/CompilerEngine.cpp  | 28 ++++++-
 compiler/tests/python/test_compiler_engine.py | 83 ++++++++++++++++++-
 5 files changed, 143 insertions(+), 12 deletions(-)

diff --git a/compiler/include/zamalang-c/Support/CompilerEngine.h b/compiler/include/zamalang-c/Support/CompilerEngine.h
index d813406cb..e49fec44a 100644
--- a/compiler/include/zamalang-c/Support/CompilerEngine.h
+++ b/compiler/include/zamalang-c/Support/CompilerEngine.h
@@ -47,9 +47,15 @@ MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
 MLIR_CAPI_EXPORTED lambdaArgument invokeLambda(lambda l,
                                                executionArguments args);
 
-// Create a lambdaArgument from a tensor
-MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensor(
+// Create a lambdaArgument from a tensor of different data types
+MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8(
     std::vector<uint8_t> data, std::vector<int64_t> dimensions);
+MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU16(
+    std::vector<uint16_t> data, std::vector<int64_t> dimensions);
+MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU32(
+    std::vector<uint32_t> data, std::vector<int64_t> dimensions);
+MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU64(
+    std::vector<uint64_t> data, std::vector<int64_t> dimensions);
 // Create a lambdaArgument from a scalar
 MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromScalar(uint64_t scalar);
 // Check if a lambdaArgument holds a tensor
diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp
index 8b6b54ef2..3a06ed907 100644
--- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp
+++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp
@@ -45,7 +45,22 @@ void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
       });
 
   pybind11::class_<lambdaArgument>(m, "LambdaArgument")
-      .def_static("from_tensor", lambdaArgumentFromTensor)
+      .def_static("from_tensor",
+                  [](std::vector<uint8_t> tensor, std::vector<int64_t> dims) {
+                    return lambdaArgumentFromTensorU8(tensor, dims);
+                  })
+      .def_static("from_tensor",
+                  [](std::vector<uint16_t> tensor, std::vector<int64_t> dims) {
+                    return lambdaArgumentFromTensorU16(tensor, dims);
+                  })
+      .def_static("from_tensor",
+                  [](std::vector<uint32_t> tensor, std::vector<int64_t> dims) {
+                    return lambdaArgumentFromTensorU32(tensor, dims);
+                  })
+      .def_static("from_tensor",
+                  [](std::vector<uint64_t> tensor, std::vector<int64_t> dims) {
+                    return lambdaArgumentFromTensorU64(tensor, dims);
+                  })
       .def_static("from_scalar", lambdaArgumentFromScalar)
       .def("is_tensor",
           [](lambdaArgument &lambda_arg) {
diff --git a/compiler/lib/Bindings/Python/zamalang/compiler.py b/compiler/lib/Bindings/Python/zamalang/compiler.py
index 87e8fbaa9..3de549301 100644
--- a/compiler/lib/Bindings/Python/zamalang/compiler.py
+++ b/compiler/lib/Bindings/Python/zamalang/compiler.py
@@ -10,6 +10,11 @@ from mlir._mlir_libs._zamalang._compiler import library as _library
 
 import numpy as np
 
+ACCEPTED_NUMPY_UINTS = (np.uint8, np.uint16, np.uint32, np.uint64)
+ACCEPTED_INTS = (int,) + ACCEPTED_NUMPY_UINTS
+ACCEPTED_TYPES = (np.ndarray,) + ACCEPTED_INTS
+
+
 def _lookup_runtime_lib() -> str:
     """Try to find the absolute path to the runtime library.
@@ -95,10 +100,10 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument
     Returns:
         _LambdaArgument: lambda argument holding the appropriate value
     """
-    if not isinstance(value, (int, np.ndarray, np.uint8)):
-        raise TypeError("value of execution argument must be either int, numpy.array or numpy.uint8")
-    if isinstance(value, (int, np.uint8)):
-        if not (0 <= value < (2 ** 64 - 1)):
+    if not isinstance(value, ACCEPTED_TYPES):
+        raise TypeError("value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}")
+    if isinstance(value, ACCEPTED_INTS):
+        if isinstance(value, int) and not (0 <= value <= np.iinfo(np.uint64).max):
             raise TypeError(
                 "single integer must be in the range [0, 2**64 - 1] (uint64)"
             )
@@ -107,8 +112,8 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument
     assert isinstance(value, np.ndarray)
     if value.shape == ():
         return _LambdaArgument.from_scalar(value)
-    if value.dtype != np.uint8:
-        raise TypeError("numpy.array must be of dtype uint8")
+    if value.dtype not in ACCEPTED_NUMPY_UINTS:
+        raise TypeError("numpy.array must be of dtype uint{8,16,32,64}")
     return _LambdaArgument.from_tensor(value.flatten().tolist(), value.shape)
diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp
index 1f13206a9..515476103 100644
--- a/compiler/lib/CAPI/Support/CompilerEngine.cpp
+++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp
@@ -143,14 +143,38 @@ uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg) {
   return arg->getValue();
 }
 
-lambdaArgument lambdaArgumentFromTensor(std::vector<uint8_t> data,
-                                        std::vector<int64_t> dimensions) {
+lambdaArgument lambdaArgumentFromTensorU8(std::vector<uint8_t> data,
+                                          std::vector<int64_t> dimensions) {
   lambdaArgument tensor_arg{
       std::make_shared<mlir::zamalang::TensorLambdaArgument<
           mlir::zamalang::IntLambdaArgument<uint8_t>>>(data, dimensions)};
   return tensor_arg;
 }
 
+lambdaArgument lambdaArgumentFromTensorU16(std::vector<uint16_t> data,
+                                           std::vector<int64_t> dimensions) {
+  lambdaArgument tensor_arg{
+      std::make_shared<mlir::zamalang::TensorLambdaArgument<
+          mlir::zamalang::IntLambdaArgument<uint16_t>>>(data, dimensions)};
+  return tensor_arg;
+}
+
+lambdaArgument lambdaArgumentFromTensorU32(std::vector<uint32_t> data,
+                                           std::vector<int64_t> dimensions) {
+  lambdaArgument tensor_arg{
+      std::make_shared<mlir::zamalang::TensorLambdaArgument<
+          mlir::zamalang::IntLambdaArgument<uint32_t>>>(data, dimensions)};
+  return tensor_arg;
+}
+
+lambdaArgument lambdaArgumentFromTensorU64(std::vector<uint64_t> data,
+                                           std::vector<int64_t> dimensions) {
+  lambdaArgument tensor_arg{
+      std::make_shared<mlir::zamalang::TensorLambdaArgument<
+          mlir::zamalang::IntLambdaArgument<uint64_t>>>(data, dimensions)};
+  return tensor_arg;
+}
+
 lambdaArgument lambdaArgumentFromScalar(uint64_t scalar) {
   lambdaArgument scalar_arg{
       std::make_shared<mlir::zamalang::IntLambdaArgument<uint64_t>>(scalar)};
diff --git a/compiler/tests/python/test_compiler_engine.py b/compiler/tests/python/test_compiler_engine.py
index 1eca9b4be..b720ca94d 100644
--- a/compiler/tests/python/test_compiler_engine.py
+++ b/compiler/tests/python/test_compiler_engine.py
@@ -43,21 +43,102 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
         8,
         id="add_eint_int_with_np_uint8_as_scalar",
     ),
+    pytest.param(
+        """
+        func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> {
+            %1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>)
+            return %1: !HLFHE.eint<7>
+        }
+        """,
+        (np.uint16(3), np.uint16(5)),
+        8,
+        id="add_eint_int_with_np_uint16_as_scalar",
+    ),
+    pytest.param(
+        """
+        func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> {
+            %1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>)
+            return %1: !HLFHE.eint<7>
+        }
+        """,
+        (np.uint32(3), np.uint32(5)),
+        8,
+        id="add_eint_int_with_np_uint32_as_scalar",
+    ),
+    pytest.param(
+        """
+        func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> {
+            %1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>)
+            return %1: !HLFHE.eint<7>
+        }
+        """,
+        (np.uint64(3), np.uint64(5)),
+        8,
+        id="add_eint_int_with_np_uint64_as_scalar",
+    ),
     pytest.param(
         """
         func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7>
         {
             %ret = "HLFHELinalg.dot_eint_int"(%arg0, %arg1) :
                 (tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
             return %ret : !HLFHE.eint<7>
         }
         """,
         (
             np.array([1, 2, 3, 4], dtype=np.uint8),
             np.array([4, 3, 2, 1], dtype=np.uint8),
         ),
         20,
-        id="dot_eint_int",
+        id="dot_eint_int_uint8",
     ),
+    pytest.param(
+        """
+        func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7>
+        {
+            %ret = "HLFHELinalg.dot_eint_int"(%arg0, %arg1) :
+                (tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
+            return %ret : !HLFHE.eint<7>
+        }
+        """,
+        (
+            np.array([1, 2, 3, 4], dtype=np.uint16),
+            np.array([4, 3, 2, 1], dtype=np.uint16),
+        ),
+        20,
+        id="dot_eint_int_uint16",
+    ),
+    pytest.param(
+        """
+        func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7>
+        {
+            %ret = "HLFHELinalg.dot_eint_int"(%arg0, %arg1) :
+                (tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
+            return %ret : !HLFHE.eint<7>
+        }
+        """,
+        (
+            np.array([1, 2, 3, 4], dtype=np.uint32),
+            np.array([4, 3, 2, 1], dtype=np.uint32),
+        ),
+        20,
+        id="dot_eint_int_uint32",
+    ),
+    pytest.param(
+        """
+        func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7>
+        {
+            %ret = "HLFHELinalg.dot_eint_int"(%arg0, %arg1) :
+                (tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
+            return %ret : !HLFHE.eint<7>
+        }
+        """,
+        (
+            np.array([1, 2, 3, 4], dtype=np.uint64),
+            np.array([4, 3, 2, 1], dtype=np.uint64),
+        ),
+        20,
+        id="dot_eint_int_uint64",
+    ),
     pytest.param(
         """
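
Editor's usage note (not part of the patch): a minimal sketch of how the
create_execution_argument helper from zamalang/compiler.py above is expected
to behave after this change. Only create_execution_argument appears in the
diff; the rejected np.int8 call below is an illustrative assumption about the
TypeError path, not a case taken from the test suite.

    import numpy as np
    from zamalang.compiler import create_execution_argument

    # Scalars: plain Python ints (range-checked against uint64) and any
    # numpy unsigned integer scalar are wrapped via LambdaArgument.from_scalar.
    create_execution_argument(7)
    create_execution_argument(np.uint32(5))

    # Tensors: ndarrays of dtype uint8/16/32/64 are flattened and handed to
    # the native LambdaArgument together with their shape.
    create_execution_argument(np.array([1, 2, 3, 4], dtype=np.uint64))

    # Any other type is rejected up front.
    create_execution_argument(np.int8(-1))  # raises TypeError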