From f36e1fe882376032704443d2799cbde70c91db6f Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Wed, 7 Dec 2022 13:16:21 +0100 Subject: [PATCH] feat(CAPI): use const ptrs to create tensor lambda arguments --- .../concretelang-c/Support/CompilerEngine.h | 20 ++++++++----------- compiler/lib/Bindings/Rust/src/compiler.rs | 10 +++++----- compiler/lib/CAPI/Support/CompilerEngine.cpp | 18 ++++++++--------- 3 files changed, 22 insertions(+), 26 deletions(-) diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index cf42ea795..33badbce9 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -258,18 +258,14 @@ MLIR_CAPI_EXPORTED void evaluationKeysDestroy(EvaluationKeys evaluationKeys); MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromScalar(uint64_t value); -MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data, - int64_t *dims, - size_t rank); -MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data, - int64_t *dims, - size_t rank); -MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data, - int64_t *dims, - size_t rank); -MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, - int64_t *dims, - size_t rank); +MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU8( + const uint8_t *data, const int64_t *dims, size_t rank); +MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU16( + const uint16_t *data, const int64_t *dims, size_t rank); +MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU32( + const uint32_t *data, const int64_t *dims, size_t rank); +MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64( + const uint64_t *data, const int64_t *dims, size_t rank); MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(LambdaArgument lambdaArg); MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg); diff --git a/compiler/lib/Bindings/Rust/src/compiler.rs b/compiler/lib/Bindings/Rust/src/compiler.rs index e49bb87c3..477277775 100644 --- a/compiler/lib/Bindings/Rust/src/compiler.rs +++ b/compiler/lib/Bindings/Rust/src/compiler.rs @@ -780,10 +780,10 @@ mod test { #[test] fn test_tensor_lambda_argument() { unsafe { - let mut tensor_data = [1, 2, 3, 73u64]; - let mut tensor_dims = [2, 2i64]; + let tensor_data = [1, 2, 3, 73u64]; + let tensor_dims = [2, 2i64]; let tensor_arg = - lambdaArgumentFromTensorU64(tensor_data.as_mut_ptr(), tensor_dims.as_mut_ptr(), 2); + lambdaArgumentFromTensorU64(tensor_data.as_ptr(), tensor_dims.as_ptr(), 2); assert!(!lambdaArgumentIsNull(tensor_arg)); assert!(!lambdaArgumentIsScalar(tensor_arg)); assert!(lambdaArgumentIsTensor(tensor_arg)); @@ -836,8 +836,8 @@ mod test { let eval_keys = keySetGetEvaluationKeys(key_set); // build lambda arguments from scalar and encrypt them let args = [ - lambdaArgumentFromTensorU8([1, 2, 3, 4, 5, 6].as_mut_ptr(), [2, 3].as_mut_ptr(), 2), - lambdaArgumentFromTensorU8([1, 4, 7, 4, 2, 9].as_mut_ptr(), [2, 3].as_mut_ptr(), 2), + lambdaArgumentFromTensorU8([1, 2, 3, 4, 5, 6].as_ptr(), [2, 3].as_ptr(), 2), + lambdaArgumentFromTensorU8([1, 4, 7, 4, 2, 9].as_ptr(), [2, 3].as_ptr(), 2), ]; let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap(); // execute the compiled function on the encrypted arguments diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 5e9fea2ec..0502aaec1 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -408,7 +408,7 @@ LambdaArgument lambdaArgumentFromScalar(uint64_t value) { return wrap(new mlir::concretelang::IntLambdaArgument(value)); } -int64_t getSizeFromRankAndDims(size_t rank, int64_t *dims) { +int64_t getSizeFromRankAndDims(size_t rank, const int64_t *dims) { if (rank == 0) // not a tensor return 1; auto size = dims[0]; @@ -417,8 +417,8 @@ int64_t getSizeFromRankAndDims(size_t rank, int64_t *dims) { return size; } -LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data, int64_t *dims, - size_t rank) { +LambdaArgument lambdaArgumentFromTensorU8(const uint8_t *data, + const int64_t *dims, size_t rank) { std::vector data_vector(data, data + getSizeFromRankAndDims(rank, dims)); @@ -428,8 +428,8 @@ LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data, int64_t *dims, dims_vector)); } -LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data, int64_t *dims, - size_t rank) { +LambdaArgument lambdaArgumentFromTensorU16(const uint16_t *data, + const int64_t *dims, size_t rank) { std::vector data_vector(data, data + getSizeFromRankAndDims(rank, dims)); @@ -439,8 +439,8 @@ LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data, int64_t *dims, dims_vector)); } -LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data, int64_t *dims, - size_t rank) { +LambdaArgument lambdaArgumentFromTensorU32(const uint32_t *data, + const int64_t *dims, size_t rank) { std::vector data_vector(data, data + getSizeFromRankAndDims(rank, dims)); @@ -450,8 +450,8 @@ LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data, int64_t *dims, dims_vector)); } -LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, int64_t *dims, - size_t rank) { +LambdaArgument lambdaArgumentFromTensorU64(const uint64_t *data, + const int64_t *dims, size_t rank) { std::vector data_vector(data, data + getSizeFromRankAndDims(rank, dims));