diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp
index 8904bfe26..d375f24ec 100644
--- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp
+++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp
@@ -227,19 +227,19 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
       });
 
   pybind11::class_<lambdaArgument>(m, "LambdaArgument")
-      .def_static("from_tensor",
+      .def_static("from_tensor_8",
                   [](std::vector<uint8_t> tensor, std::vector<int64_t> dims) {
                     return lambdaArgumentFromTensorU8(tensor, dims);
                   })
-      .def_static("from_tensor",
+      .def_static("from_tensor_16",
                   [](std::vector<uint16_t> tensor, std::vector<int64_t> dims) {
                     return lambdaArgumentFromTensorU16(tensor, dims);
                   })
-      .def_static("from_tensor",
+      .def_static("from_tensor_32",
                   [](std::vector<uint32_t> tensor, std::vector<int64_t> dims) {
                     return lambdaArgumentFromTensorU32(tensor, dims);
                   })
-      .def_static("from_tensor",
+      .def_static("from_tensor_64",
                   [](std::vector<uint64_t> tensor, std::vector<int64_t> dims) {
                     return lambdaArgumentFromTensorU64(tensor, dims);
                   })
diff --git a/compiler/lib/Bindings/Python/concrete/compiler/client_support.py b/compiler/lib/Bindings/Python/concrete/compiler/client_support.py
index 5ad785572..b98f0cae8 100644
--- a/compiler/lib/Bindings/Python/concrete/compiler/client_support.py
+++ b/compiler/lib/Bindings/Python/concrete/compiler/client_support.py
@@ -188,4 +188,12 @@ class ClientSupport(WrapperCpp):
             value = value.max()
             # should be a single uint here
             return LambdaArgument.from_scalar(value)
-        return LambdaArgument.from_tensor(value.flatten().tolist(), value.shape)
+        if value.dtype == np.uint8:
+            return LambdaArgument.from_tensor_8(value.flatten().tolist(), value.shape)
+        if value.dtype == np.uint16:
+            return LambdaArgument.from_tensor_16(value.flatten().tolist(), value.shape)
+        if value.dtype == np.uint32:
+            return LambdaArgument.from_tensor_32(value.flatten().tolist(), value.shape)
+        if value.dtype == np.uint64:
+            return LambdaArgument.from_tensor_64(value.flatten().tolist(), value.shape)
+        raise TypeError("numpy.array must be of dtype uint{8,16,32,64}")
diff --git a/compiler/lib/Bindings/Python/concrete/compiler/lambda_argument.py b/compiler/lib/Bindings/Python/concrete/compiler/lambda_argument.py
index 6359786e1..870a269ae 100644
--- a/compiler/lib/Bindings/Python/concrete/compiler/lambda_argument.py
+++ b/compiler/lib/Bindings/Python/concrete/compiler/lambda_argument.py
@@ -63,7 +63,7 @@ class LambdaArgument(WrapperCpp):
         return LambdaArgument.wrap(_LambdaArgument.from_scalar(scalar))
 
     @staticmethod
-    def from_tensor(data: List[int], shape: List[int]) -> "LambdaArgument":
+    def from_tensor_8(data: List[int], shape: List[int]) -> "LambdaArgument":
         """Build a LambdaArgument containing the given tensor.
 
         Args:
@@ -73,7 +73,46 @@ class LambdaArgument(WrapperCpp):
         Returns:
             LambdaArgument
         """
-        return LambdaArgument.wrap(_LambdaArgument.from_tensor(data, shape))
+        return LambdaArgument.wrap(_LambdaArgument.from_tensor_8(data, shape))
+
+    @staticmethod
+    def from_tensor_16(data: List[int], shape: List[int]) -> "LambdaArgument":
+        """Build a LambdaArgument containing the given tensor.
+
+        Args:
+            data (List[int]): flattened tensor data
+            shape (List[int]): shape of original tensor before flattening
+
+        Returns:
+            LambdaArgument
+        """
+        return LambdaArgument.wrap(_LambdaArgument.from_tensor_16(data, shape))
+
+    @staticmethod
+    def from_tensor_32(data: List[int], shape: List[int]) -> "LambdaArgument":
+        """Build a LambdaArgument containing the given tensor.
+
+        Args:
+            data (List[int]): flattened tensor data
+            shape (List[int]): shape of original tensor before flattening
+
+        Returns:
+            LambdaArgument
+        """
+        return LambdaArgument.wrap(_LambdaArgument.from_tensor_32(data, shape))
+
+    @staticmethod
+    def from_tensor_64(data: List[int], shape: List[int]) -> "LambdaArgument":
+        """Build a LambdaArgument containing the given tensor.
+
+        Args:
+            data (List[int]): flattened tensor data
+            shape (List[int]): shape of original tensor before flattening
+
+        Returns:
+            LambdaArgument
+        """
+        return LambdaArgument.wrap(_LambdaArgument.from_tensor_64(data, shape))
 
     def is_scalar(self) -> bool:
         """Check if the contained argument is a scalar.
diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp
index bd21ee53b..b40f66bdb 100644
--- a/compiler/lib/CAPI/Support/CompilerEngine.cpp
+++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp
@@ -262,44 +262,95 @@ std::string roundTrip(const char *module) {
 
 bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) {
   return lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
-             mlir::concretelang::IntLambdaArgument<uint64_t>>>();
+             mlir::concretelang::IntLambdaArgument<uint8_t>>>() ||
+         lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
+             mlir::concretelang::IntLambdaArgument<uint16_t>>>() ||
+         lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
+             mlir::concretelang::IntLambdaArgument<uint32_t>>>() ||
+         lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
+             mlir::concretelang::IntLambdaArgument<uint64_t>>>();
 }
 
-std::vector<uint64_t> lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
-  mlir::concretelang::TensorLambdaArgument<
-      mlir::concretelang::IntLambdaArgument<uint64_t>> *arg =
-      lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
-          mlir::concretelang::IntLambdaArgument<uint64_t>>>();
-  if (arg == nullptr) {
-    throw std::invalid_argument(
-        "LambdaArgument isn't a tensor, should "
-        "be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
-  }
-
-  llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
-  if (!sizeOrErr) {
+template <typename T>
+std::vector<uint64_t> copyTensorLambdaArgumentTo64bitsvector(
+    mlir::concretelang::TensorLambdaArgument<
+        mlir::concretelang::IntLambdaArgument<T>> *tensor) {
+  auto numElements = tensor->getNumElements();
+  if (!numElements) {
     std::string backingString;
     llvm::raw_string_ostream os(backingString);
     os << "Couldn't get size of tensor: "
-       << llvm::toString(std::move(sizeOrErr.takeError()));
+       << llvm::toString(std::move(numElements.takeError()));
     throw std::runtime_error(os.str());
   }
-  std::vector<uint64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
-  return data;
+  std::vector<uint64_t> res;
+  res.reserve(*numElements);
+  T *data = tensor->getValue();
+  for (size_t i = 0; i < *numElements; i++) {
+    res.push_back(data[i]);
+  }
+  return res;
+}
+
+std::vector<uint64_t> lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
+    llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
+    if (!sizeOrErr) {
+      std::string backingString;
+      llvm::raw_string_ostream os(backingString);
+      os << "Couldn't get size of tensor: "
+         << llvm::toString(std::move(sizeOrErr.takeError()));
+      throw std::runtime_error(os.str());
+    }
+    std::vector<uint64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
+    return data;
+  }
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
+    return copyTensorLambdaArgumentTo64bitsvector(arg);
+  }
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
+    return copyTensorLambdaArgumentTo64bitsvector(arg);
+  }
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
+    return copyTensorLambdaArgumentTo64bitsvector(arg);
+  }
+  throw std::invalid_argument(
+      "LambdaArgument isn't a tensor or has an unsupported bitwidth");
 }
 
 std::vector<int64_t>
 lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) {
-  mlir::concretelang::TensorLambdaArgument<
-      mlir::concretelang::IntLambdaArgument<uint64_t>> *arg =
-      lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
-          mlir::concretelang::IntLambdaArgument<uint64_t>>>();
-  if (arg == nullptr) {
-    throw std::invalid_argument(
-        "LambdaArgument isn't a tensor, should "
-        "be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
+    return arg->getDimensions();
   }
-  return arg->getDimensions();
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
+    return arg->getDimensions();
+  }
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
+    return arg->getDimensions();
+  }
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
+    return arg->getDimensions();
+  }
+  throw std::invalid_argument(
+      "LambdaArgument isn't a tensor, should "
+      "be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
 }
 
 bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) {
diff --git a/compiler/tests/python/test_argument_support.py b/compiler/tests/python/test_argument_support.py
index 75e0d8811..08eb77f80 100644
--- a/compiler/tests/python/test_argument_support.py
+++ b/compiler/tests/python/test_argument_support.py
@@ -43,30 +43,30 @@ def test_accepted_ints(value):
 
 
 # TODO: #495
-# @pytest.mark.parametrize(
-#     "dtype",
-#     [
-#         pytest.param(np.uint8, id="uint8"),
-#         pytest.param(np.uint16, id="uint16"),
-#         pytest.param(np.uint32, id="uint32"),
-#         pytest.param(np.uint64, id="uint64"),
-#     ],
-# )
-# def test_accepted_ndarray(dtype):
-#     value = np.array([0, 1, 2], dtype=dtype)
-#     try:
-#         arg = ClientSupport._create_lambda_argument(value)
-#     except Exception:
-#         pytest.fail(f"value of type {type(value)} should be supported")
+@pytest.mark.parametrize(
+    "dtype, maxvalue",
+    [
+        pytest.param(np.uint8, 2**8 - 1, id="uint8"),
+        pytest.param(np.uint16, 2**16 - 1, id="uint16"),
+        pytest.param(np.uint32, 2**32 - 1, id="uint32"),
+        pytest.param(np.uint64, 2**64 - 1, id="uint64"),
+    ],
+)
+def test_accepted_ndarray(dtype, maxvalue):
+    value = np.array([0, 1, 2, maxvalue], dtype=dtype)
+    try:
+        arg = ClientSupport._create_lambda_argument(value)
+    except Exception:
+        pytest.fail(f"value of type {type(value)} should be supported")
 
-#     assert arg.is_tensor(), "should have been a tensor"
-#     assert np.all(np.equal(arg.get_tensor_shape(), value.shape))
-#     assert np.all(
-#         np.equal(
-#             value,
-#             np.array(arg.get_tensor_data()).reshape(arg.get_tensor_shape()),
-#         )
-#     )
+    assert arg.is_tensor(), "should have been a tensor"
+    assert np.all(np.equal(arg.get_tensor_shape(), value.shape))
+    assert np.all(
+        np.equal(
+            value,
+            np.array(arg.get_tensor_data()).reshape(arg.get_tensor_shape()),
+        )
+    )
 
 
 def test_accepted_array_as_scalar():
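
A quick interactive check of the new per-width dispatch, for reviewers. This is a minimal sketch, not part of the patch: it assumes a built `concrete.compiler` package importable as below, and that multi-element arrays of unsupported dtypes fall through to the new `TypeError` in `client_support.py`.

```python
import numpy as np
from concrete.compiler import ClientSupport

# A uint32 array now dispatches to LambdaArgument.from_tensor_32 under the
# hood, so the full 32-bit value range is accepted.
value = np.array([0, 1, 2**32 - 1], dtype=np.uint32)
arg = ClientSupport._create_lambda_argument(value)
assert arg.is_tensor(), "should have been a tensor"

# lambdaArgumentGetTensorData widens every element to 64 bits on the way out,
# so values round-trip losslessly whatever the input width.
assert np.array_equal(np.array(arg.get_tensor_data()), value)

# Anything other than uint{8,16,32,64} is rejected explicitly.
try:
    ClientSupport._create_lambda_argument(np.array([0, 1], dtype=np.int32))
except TypeError as err:
    print(err)  # numpy.array must be of dtype uint{8,16,32,64}
```

The suffixed `from_tensor_{8,16,32,64}` names let the Python side pick the binding from the array's actual dtype, rather than leaving it to pybind11 overload resolution across four identically named `from_tensor` bindings, which selects the first integer width the data happens to fit.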