Mirror of https://github.com/zama-ai/concrete.git
feat(CAPI): use const ptrs to create tensor lambda arguments
@@ -258,18 +258,14 @@ MLIR_CAPI_EXPORTED void evaluationKeysDestroy(EvaluationKeys evaluationKeys);
 
 MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromScalar(uint64_t value);
 
-MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data,
-                                                             int64_t *dims,
-                                                             size_t rank);
-MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data,
-                                                              int64_t *dims,
-                                                              size_t rank);
-MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data,
-                                                              int64_t *dims,
-                                                              size_t rank);
-MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data,
-                                                              int64_t *dims,
-                                                              size_t rank);
+MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU8(
+    const uint8_t *data, const int64_t *dims, size_t rank);
+MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU16(
+    const uint16_t *data, const int64_t *dims, size_t rank);
+MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU32(
+    const uint32_t *data, const int64_t *dims, size_t rank);
+MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64(
+    const uint64_t *data, const int64_t *dims, size_t rank);
 
 MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(LambdaArgument lambdaArg);
 MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg);
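Note on the new declarations: with const-qualified data and dims, callers holding read-only buffers can pass them directly instead of casting away const. A minimal caller sketch, assuming the CAPI header containing the declarations above is on the include path (its exact name is not shown in this diff):

#include <cstddef>
#include <cstdint>

// Assumes the concrete CAPI header declaring LambdaArgument and
// lambdaArgumentFromTensorU64 (not named in this diff) is included here.

int main() {
  // Read-only buffers can now be passed as-is; the old signatures demanded
  // non-const pointers even though the data is only copied, never written.
  const uint64_t data[] = {1, 2, 3, 73};
  const int64_t dims[] = {2, 2};
  LambdaArgument tensor_arg = lambdaArgumentFromTensorU64(data, dims, /*rank=*/2);
  (void)tensor_arg; // encrypt / invoke as usual from here
  return 0;
}

The Rust tests below mirror this: the mut bindings disappear and as_mut_ptr() becomes as_ptr().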
@@ -780,10 +780,10 @@ mod test {
     #[test]
     fn test_tensor_lambda_argument() {
         unsafe {
-            let mut tensor_data = [1, 2, 3, 73u64];
-            let mut tensor_dims = [2, 2i64];
+            let tensor_data = [1, 2, 3, 73u64];
+            let tensor_dims = [2, 2i64];
             let tensor_arg =
-                lambdaArgumentFromTensorU64(tensor_data.as_mut_ptr(), tensor_dims.as_mut_ptr(), 2);
+                lambdaArgumentFromTensorU64(tensor_data.as_ptr(), tensor_dims.as_ptr(), 2);
             assert!(!lambdaArgumentIsNull(tensor_arg));
             assert!(!lambdaArgumentIsScalar(tensor_arg));
             assert!(lambdaArgumentIsTensor(tensor_arg));
@@ -836,8 +836,8 @@ mod test {
             let eval_keys = keySetGetEvaluationKeys(key_set);
             // build lambda arguments from scalar and encrypt them
             let args = [
-                lambdaArgumentFromTensorU8([1, 2, 3, 4, 5, 6].as_mut_ptr(), [2, 3].as_mut_ptr(), 2),
-                lambdaArgumentFromTensorU8([1, 4, 7, 4, 2, 9].as_mut_ptr(), [2, 3].as_mut_ptr(), 2),
+                lambdaArgumentFromTensorU8([1, 2, 3, 4, 5, 6].as_ptr(), [2, 3].as_ptr(), 2),
+                lambdaArgumentFromTensorU8([1, 4, 7, 4, 2, 9].as_ptr(), [2, 3].as_ptr(), 2),
             ];
             let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap();
             // execute the compiled function on the encrypted arguments
@@ -408,7 +408,7 @@ LambdaArgument lambdaArgumentFromScalar(uint64_t value) {
   return wrap(new mlir::concretelang::IntLambdaArgument<uint64_t>(value));
 }
 
-int64_t getSizeFromRankAndDims(size_t rank, int64_t *dims) {
+int64_t getSizeFromRankAndDims(size_t rank, const int64_t *dims) {
   if (rank == 0) // not a tensor
     return 1;
   auto size = dims[0];
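The hunk above cuts off right after `auto size = dims[0];`, so the rest of getSizeFromRankAndDims is not shown. As a reading aid, a self-contained sketch of what the helper computes, with the elided remainder filled in as an assumption (total element count as the product of all dimensions, 1 for a scalar):

#include <cstddef>
#include <cstdint>

// Sketch only: the loop over dims[1..rank) is an assumption, since the diff
// context ends right after the first dimension is read.
int64_t getSizeFromRankAndDims(size_t rank, const int64_t *dims) {
  if (rank == 0) // not a tensor
    return 1;
  int64_t size = dims[0];
  for (size_t i = 1; i < rank; i++)
    size *= dims[i]; // element count = product of all dimensions
  return size;
}

The helper only ever reads dims, which is why const-qualifying the parameter is the entire change here.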
@@ -417,8 +417,8 @@ int64_t getSizeFromRankAndDims(size_t rank, int64_t *dims) {
   return size;
 }
 
-LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data, int64_t *dims,
-                                          size_t rank) {
+LambdaArgument lambdaArgumentFromTensorU8(const uint8_t *data,
+                                          const int64_t *dims, size_t rank) {
 
   std::vector<uint8_t> data_vector(data,
                                    data + getSizeFromRankAndDims(rank, dims));
@@ -428,8 +428,8 @@ LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data, int64_t *dims,
                                                              dims_vector));
 }
 
-LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data, int64_t *dims,
-                                           size_t rank) {
+LambdaArgument lambdaArgumentFromTensorU16(const uint16_t *data,
+                                           const int64_t *dims, size_t rank) {
 
   std::vector<uint16_t> data_vector(data,
                                     data + getSizeFromRankAndDims(rank, dims));
@@ -439,8 +439,8 @@ LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data, int64_t *dims,
                                                              dims_vector));
 }
 
-LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data, int64_t *dims,
-                                           size_t rank) {
+LambdaArgument lambdaArgumentFromTensorU32(const uint32_t *data,
+                                           const int64_t *dims, size_t rank) {
 
   std::vector<uint32_t> data_vector(data,
                                     data + getSizeFromRankAndDims(rank, dims));
@@ -450,8 +450,8 @@ LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data, int64_t *dims,
                                                              dims_vector));
 }
 
-LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, int64_t *dims,
-                                           size_t rank) {
+LambdaArgument lambdaArgumentFromTensorU64(const uint64_t *data,
+                                           const int64_t *dims, size_t rank) {
 
   std::vector<uint64_t> data_vector(data,
                                     data + getSizeFromRankAndDims(rank, dims));
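Across the four FromTensor functions only the signatures change; the bodies shown in the context lines compile unchanged because they merely read through data and dims while copying into std::vector, and the vector range constructor accepts const sources. A standalone illustration of that pattern (the helper name below is made up for the example, not taken from the file):

#include <cstdint>
#include <vector>

// Illustrative only: std::vector's iterator-range constructor reads through
// the pointers without writing, so const-qualified sources work unchanged.
std::vector<uint8_t> copyBuffer(const uint8_t *data, int64_t count) {
  return std::vector<uint8_t>(data, data + count);
}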