feat(rust): support execution with tensor args

This commit is contained in:
youben11
2022-11-29 08:58:03 +01:00
committed by Ayoub Benaissa
parent 16f3b0bbf6
commit f05b1bd1ea
3 changed files with 265 additions and 9 deletions

View File

@@ -12,6 +12,10 @@
extern "C" {
#endif
/// The CAPI should be really careful about memory allocation. Every pointer
/// returned should points to a new buffer allocated for the purpose of the
/// CAPI, and should have a respective destructor function.
/// Opaque type declarations. Inspired from
/// llvm-project/mlir/include/mlir-c/IR.h
///
@@ -204,6 +208,15 @@ MLIR_CAPI_EXPORTED void evaluationKeysDestroy(EvaluationKeys evaluationKeys);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromScalar(uint64_t value);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data,
int64_t *dims,
size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data,
int64_t *dims,
size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data,
int64_t *dims,
size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data,
int64_t *dims,
size_t rank);
@@ -212,11 +225,13 @@ MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED uint64_t *
lambdaArgumentGetTensorData(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED bool lambdaArgumentGetTensorData(LambdaArgument lambdaArg,
uint64_t *buffer);
MLIR_CAPI_EXPORTED size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED int64_t *
lambdaArgumentGetTensorDims(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED int64_t
lambdaArgumentGetTensorDataSize(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED bool lambdaArgumentGetTensorDims(LambdaArgument lambdaArg,
int64_t *buffer);
MLIR_CAPI_EXPORTED PublicArguments
lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs, size_t argNumber,