feat(CAPI): use const ptrs to create tensor lambda arguments

This commit is contained in:
tmontaigu
2022-12-07 13:16:21 +01:00
parent 188642b153
commit f36e1fe882
3 changed files with 22 additions and 26 deletions

View File

@@ -258,18 +258,14 @@ MLIR_CAPI_EXPORTED void evaluationKeysDestroy(EvaluationKeys evaluationKeys);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromScalar(uint64_t value);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data,
int64_t *dims,
size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data,
int64_t *dims,
size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data,
int64_t *dims,
size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data,
int64_t *dims,
size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU8(
const uint8_t *data, const int64_t *dims, size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU16(
const uint16_t *data, const int64_t *dims, size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU32(
const uint32_t *data, const int64_t *dims, size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64(
const uint64_t *data, const int64_t *dims, size_t rank);
MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg);

View File

@@ -780,10 +780,10 @@ mod test {
#[test]
fn test_tensor_lambda_argument() {
unsafe {
let mut tensor_data = [1, 2, 3, 73u64];
let mut tensor_dims = [2, 2i64];
let tensor_data = [1, 2, 3, 73u64];
let tensor_dims = [2, 2i64];
let tensor_arg =
lambdaArgumentFromTensorU64(tensor_data.as_mut_ptr(), tensor_dims.as_mut_ptr(), 2);
lambdaArgumentFromTensorU64(tensor_data.as_ptr(), tensor_dims.as_ptr(), 2);
assert!(!lambdaArgumentIsNull(tensor_arg));
assert!(!lambdaArgumentIsScalar(tensor_arg));
assert!(lambdaArgumentIsTensor(tensor_arg));
@@ -836,8 +836,8 @@ mod test {
let eval_keys = keySetGetEvaluationKeys(key_set);
// build lambda arguments from scalar and encrypt them
let args = [
lambdaArgumentFromTensorU8([1, 2, 3, 4, 5, 6].as_mut_ptr(), [2, 3].as_mut_ptr(), 2),
lambdaArgumentFromTensorU8([1, 4, 7, 4, 2, 9].as_mut_ptr(), [2, 3].as_mut_ptr(), 2),
lambdaArgumentFromTensorU8([1, 2, 3, 4, 5, 6].as_ptr(), [2, 3].as_ptr(), 2),
lambdaArgumentFromTensorU8([1, 4, 7, 4, 2, 9].as_ptr(), [2, 3].as_ptr(), 2),
];
let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap();
// execute the compiled function on the encrypted arguments

View File

@@ -408,7 +408,7 @@ LambdaArgument lambdaArgumentFromScalar(uint64_t value) {
return wrap(new mlir::concretelang::IntLambdaArgument<uint64_t>(value));
}
int64_t getSizeFromRankAndDims(size_t rank, int64_t *dims) {
int64_t getSizeFromRankAndDims(size_t rank, const int64_t *dims) {
if (rank == 0) // not a tensor
return 1;
auto size = dims[0];
@@ -417,8 +417,8 @@ int64_t getSizeFromRankAndDims(size_t rank, int64_t *dims) {
return size;
}
LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data, int64_t *dims,
size_t rank) {
LambdaArgument lambdaArgumentFromTensorU8(const uint8_t *data,
const int64_t *dims, size_t rank) {
std::vector<uint8_t> data_vector(data,
data + getSizeFromRankAndDims(rank, dims));
@@ -428,8 +428,8 @@ LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data, int64_t *dims,
dims_vector));
}
LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data, int64_t *dims,
size_t rank) {
LambdaArgument lambdaArgumentFromTensorU16(const uint16_t *data,
const int64_t *dims, size_t rank) {
std::vector<uint16_t> data_vector(data,
data + getSizeFromRankAndDims(rank, dims));
@@ -439,8 +439,8 @@ LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data, int64_t *dims,
dims_vector));
}
LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data, int64_t *dims,
size_t rank) {
LambdaArgument lambdaArgumentFromTensorU32(const uint32_t *data,
const int64_t *dims, size_t rank) {
std::vector<uint32_t> data_vector(data,
data + getSizeFromRankAndDims(rank, dims));
@@ -450,8 +450,8 @@ LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data, int64_t *dims,
dims_vector));
}
LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, int64_t *dims,
size_t rank) {
LambdaArgument lambdaArgumentFromTensorU64(const uint64_t *data,
const int64_t *dims, size_t rank) {
std::vector<uint64_t> data_vector(data,
data + getSizeFromRankAndDims(rank, dims));