feat(compiler/bindings): create bindings for value management

This commit is contained in:
Umut
2023-06-09 10:56:20 +02:00
parent f7f94a1663
commit 27d081e255
18 changed files with 513 additions and 47 deletions

View File

@@ -151,6 +151,12 @@ keySetUnserialize(const std::string &buffer);
MLIR_CAPI_EXPORTED std::string
keySetSerialize(concretelang::clientlib::KeySet &keySet);
MLIR_CAPI_EXPORTED concretelang::clientlib::SharedScalarOrTensorData
valueUnserialize(const std::string &buffer);
MLIR_CAPI_EXPORTED std::string
valueSerialize(const concretelang::clientlib::SharedScalarOrTensorData &value);
/// Parse then print a textual representation of an MLIR module
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);

View File

@@ -112,17 +112,15 @@ private:
class PublicArguments {
public:
PublicArguments(const ClientParameters &clientParameters,
std::vector<ScalarOrTensorData> &&ciphertextBuffers);
std::vector<clientlib::SharedScalarOrTensorData> &buffers);
~PublicArguments();
PublicArguments(PublicArguments &other) = delete;
PublicArguments(PublicArguments &&other) = delete;
static outcome::checked<std::unique_ptr<PublicArguments>, StringError>
unserialize(ClientParameters &expectedParams, std::istream &istream);
unserialize(const ClientParameters &expectedParams, std::istream &istream);
outcome::checked<void, StringError> serialize(std::ostream &ostream);
std::vector<ScalarOrTensorData> &getArguments() { return arguments; }
std::vector<SharedScalarOrTensorData> &getArguments() { return arguments; }
ClientParameters &getClientParameters() { return clientParameters; }
friend class ::concretelang::serverlib::ServerLambda;
@@ -133,7 +131,7 @@ private:
ClientParameters clientParameters;
/// Store buffers of ciphertexts
std::vector<ScalarOrTensorData> arguments;
std::vector<SharedScalarOrTensorData> arguments;
};
/// PublicResult is a result of a ServerLambda call which contains encrypted
@@ -141,7 +139,7 @@ private:
struct PublicResult {
PublicResult(const ClientParameters &clientParameters,
std::vector<ScalarOrTensorData> &&buffers = {})
std::vector<SharedScalarOrTensorData> &&buffers = {})
: clientParameters(clientParameters), buffers(std::move(buffers)){};
PublicResult(PublicResult &) = delete;
@@ -150,17 +148,18 @@ struct PublicResult {
/// @param argPos The position of the value in the PublicResult
/// @return Either the value or an error if there are no value at this
/// position
outcome::checked<ScalarOrTensorData, StringError> getValue(size_t argPos) {
outcome::checked<SharedScalarOrTensorData, StringError>
getValue(size_t argPos) {
if (argPos >= buffers.size()) {
return StringError("result #") << argPos << " does not exists";
}
return std::move(buffers[argPos]);
return buffers[argPos];
}
/// Create a public result from buffers.
static std::unique_ptr<PublicResult>
fromBuffers(const ClientParameters &clientParameters,
std::vector<ScalarOrTensorData> &&buffers) {
std::vector<SharedScalarOrTensorData> &&buffers) {
return std::make_unique<PublicResult>(clientParameters, std::move(buffers));
}
@@ -182,7 +181,7 @@ struct PublicResult {
outcome::checked<T, StringError> asClearTextScalar(KeySet &keySet,
size_t pos) {
ValueDecrypter decrypter(keySet, clientParameters);
auto &data = buffers[pos];
auto &data = buffers[pos].get();
return decrypter.template decrypt<T>(data, pos);
}
@@ -192,7 +191,7 @@ struct PublicResult {
outcome::checked<std::vector<T>, StringError>
asClearTextVector(KeySet &keySet, size_t pos) {
ValueDecrypter decrypter(keySet, clientParameters);
return decrypter.template decryptTensor<T>(buffers[pos], pos);
return decrypter.template decryptTensor<T>(buffers[pos].get(), pos);
}
/// Return the shape of the clear tensor of a result.
@@ -205,7 +204,7 @@ struct PublicResult {
// private: TODO tmp
friend class ::concretelang::serverlib::ServerLambda;
ClientParameters clientParameters;
std::vector<ScalarOrTensorData> buffers;
std::vector<SharedScalarOrTensorData> buffers;
};
/// Helper function to convert from MemRefDescriptor to

View File

@@ -96,10 +96,9 @@ std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd,
outcome::checked<ScalarOrTensorData, StringError>
unserializeScalarOrTensorData(std::istream &istream);
std::ostream &
serializeVectorOfScalarOrTensorData(const std::vector<ScalarOrTensorData> &sotd,
std::ostream &ostream);
outcome::checked<std::vector<ScalarOrTensorData>, StringError>
std::ostream &serializeVectorOfScalarOrTensorData(
const std::vector<SharedScalarOrTensorData> &sotd, std::ostream &ostream);
outcome::checked<std::vector<SharedScalarOrTensorData>, StringError>
unserializeVectorOfScalarOrTensorData(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &wrappedKsk);

View File

@@ -689,7 +689,7 @@ public:
// Returns a void pointer to the first element of a flat
// representation of the tensor
void *getValuesAsOpaquePointer() {
void *getValuesAsOpaquePointer() const {
switch (this->elementType) {
case ElementType::u64:
return static_cast<void *>(values.u64->data());
@@ -879,6 +879,19 @@ public:
return *tensor;
}
};
struct SharedScalarOrTensorData {
std::shared_ptr<ScalarOrTensorData> inner;
SharedScalarOrTensorData(std::shared_ptr<ScalarOrTensorData> inner)
: inner{inner} {}
SharedScalarOrTensorData(ScalarOrTensorData &&inner)
: inner{std::make_shared<ScalarOrTensorData>(std::move(inner))} {}
ScalarOrTensorData &get() const { return *this->inner; }
};
} // namespace clientlib
} // namespace concretelang

View File

@@ -75,17 +75,21 @@ invokeRawOnLambda(Lambda *lambda, clientlib::ClientParameters clientParameters,
}
// Store the result to the PublicResult
std::vector<clientlib::ScalarOrTensorData> buffers;
std::vector<clientlib::SharedScalarOrTensorData> buffers;
{
size_t outputOffset = 0;
for (auto &output : clientParameters.outputs) {
auto shape = clientParameters.bufferShape(output);
if (shape.size() == 0) {
// scalar scalar
buffers.push_back(concretelang::clientlib::ScalarOrTensorData(
auto value = concretelang::clientlib::ScalarOrTensorData(
concretelang::clientlib::ScalarData(outputs[outputOffset++],
output.shape.sign,
output.shape.width)));
output.shape.width));
auto sharedValue =
clientlib::SharedScalarOrTensorData(std::move(value));
buffers.push_back(sharedValue);
} else {
// buffer gate
auto rank = shape.size();
@@ -102,14 +106,18 @@ invokeRawOnLambda(Lambda *lambda, clientlib::ClientParameters clientParameters,
: output.shape.width;
bool sign = (output.isEncrypted()) ? false : output.shape.sign;
concretelang::clientlib::TensorData td =
auto value = concretelang::clientlib::ScalarOrTensorData(
clientlib::tensorDataFromMemRef(rank, elementWidth, sign, allocated,
aligned, offset, sizes, strides);
buffers.push_back(
concretelang::clientlib::ScalarOrTensorData(std::move(td)));
aligned, offset, sizes, strides));
auto sharedValue =
clientlib::SharedScalarOrTensorData(std::move(value));
buffers.push_back(sharedValue);
}
}
}
return clientlib::PublicResult::fromBuffers(clientParameters,
std::move(buffers));
}
@@ -120,7 +128,8 @@ invokeRawOnLambda(Lambda *lambda, clientlib::PublicArguments &arguments,
clientlib::EvaluationKeys &evaluationKeys) {
// Prepare arguments with the right calling convention
std::vector<void *> preparedArgs;
for (auto &arg : arguments.getArguments()) {
for (auto &sharedArg : arguments.getArguments()) {
clientlib::ScalarOrTensorData &arg = sharedArg.get();
if (arg.isScalar()) {
auto scalar = arg.getScalar().getValueAsU64();
preparedArgs.push_back((void *)scalar);

View File

@@ -60,6 +60,9 @@ declare_mlir_python_sources(
concrete/compiler/public_result.py
concrete/compiler/evaluation_keys.py
concrete/compiler/utils.py
concrete/compiler/value.py
concrete/compiler/value_decrypter.py
concrete/compiler/value_exporter.py
concrete/compiler/wrapper.py
concrete/__init__.py
concrete/lang/__init__.py

View File

@@ -285,9 +285,98 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def("get_evaluation_keys",
[](clientlib::KeySet &keySet) { return keySet.evaluationKeys(); });
pybind11::class_<clientlib::SharedScalarOrTensorData>(m, "Value")
.def_static("deserialize",
[](const pybind11::bytes &buffer) {
return valueUnserialize(buffer);
})
.def("serialize", [](const clientlib::SharedScalarOrTensorData &value) {
return pybind11::bytes(valueSerialize(value));
});
pybind11::class_<clientlib::ValueExporter>(m, "ValueExporter")
.def_static("create",
[](clientlib::KeySet &keySet,
mlir::concretelang::ClientParameters &clientParameters) {
return clientlib::ValueExporter(keySet, clientParameters);
})
.def("export_scalar",
[](clientlib::ValueExporter &exporter, size_t position,
int64_t value) {
outcome::checked<clientlib::ScalarOrTensorData, StringError>
result = exporter.exportValue(value, position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return clientlib::SharedScalarOrTensorData(
std::move(result.value()));
})
.def("export_tensor", [](clientlib::ValueExporter &exporter,
size_t position, std::vector<int64_t> values,
std::vector<int64_t> shape) {
outcome::checked<clientlib::ScalarOrTensorData, StringError> result =
exporter.exportValue(values.data(), shape, position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return clientlib::SharedScalarOrTensorData(std::move(result.value()));
});
pybind11::class_<clientlib::ValueDecrypter>(m, "ValueDecrypter")
.def_static("create",
[](clientlib::KeySet &keySet,
mlir::concretelang::ClientParameters &clientParameters) {
return clientlib::ValueDecrypter(keySet, clientParameters);
})
.def("get_shape",
[](clientlib::ValueDecrypter &decrypter, size_t position) {
outcome::checked<std::vector<int64_t>, StringError> result =
decrypter.getShape(position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return result.value();
})
.def("decrypt_scalar",
[](clientlib::ValueDecrypter &decrypter, size_t position,
clientlib::SharedScalarOrTensorData &value) {
outcome::checked<int64_t, StringError> result =
decrypter.decrypt<int64_t>(value.get(), position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return result.value();
})
.def("decrypt_tensor",
[](clientlib::ValueDecrypter &decrypter, size_t position,
clientlib::SharedScalarOrTensorData &value) {
outcome::checked<std::vector<int64_t>, StringError> result =
decrypter.decryptTensor<int64_t>(value.get(), position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return result.value();
});
pybind11::class_<clientlib::PublicArguments,
std::unique_ptr<clientlib::PublicArguments>>(
m, "PublicArguments")
.def_static(
"create",
[](const mlir::concretelang::ClientParameters &clientParameters,
std::vector<clientlib::SharedScalarOrTensorData> &buffers) {
return clientlib::PublicArguments(clientParameters, buffers);
})
.def_static("deserialize",
[](mlir::concretelang::ClientParameters &clientParameters,
const pybind11::bytes &buffer) {
@@ -302,9 +391,25 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
const pybind11::bytes &buffer) {
return publicResultUnserialize(clientParameters, buffer);
})
.def("serialize", [](clientlib::PublicResult &publicResult) {
return pybind11::bytes(publicResultSerialize(publicResult));
});
.def("serialize",
[](clientlib::PublicResult &publicResult) {
return pybind11::bytes(publicResultSerialize(publicResult));
})
.def("n_values",
[](const clientlib::PublicResult &publicResult) {
return publicResult.buffers.size();
})
.def("get_value",
[](clientlib::PublicResult &publicResult, size_t position) {
outcome::checked<clientlib::SharedScalarOrTensorData, StringError>
result = publicResult.getValue(position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return result.value();
});
pybind11::class_<clientlib::EvaluationKeys>(m, "EvaluationKeys")
.def_static("deserialize",

View File

@@ -257,6 +257,26 @@ keySetSerialize(concretelang::clientlib::KeySet &keySet) {
return buffer.str();
}
MLIR_CAPI_EXPORTED concretelang::clientlib::SharedScalarOrTensorData
valueUnserialize(const std::string &buffer) {
std::stringstream istream(buffer);
auto value = concretelang::clientlib::unserializeScalarOrTensorData(istream);
if (istream.fail() || value.has_error()) {
throw std::runtime_error("Cannot read data");
}
return concretelang::clientlib::SharedScalarOrTensorData(
std::move(value.value()));
}
MLIR_CAPI_EXPORTED std::string
valueSerialize(const concretelang::clientlib::SharedScalarOrTensorData &value) {
std::ostringstream buffer(std::ios::binary);
serializeScalarOrTensorData(value.get(), buffer);
return buffer.str();
}
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
clientParametersUnserialize(const std::string &json) {
GET_OR_THROW_LLVM_EXPECTED(

View File

@@ -29,6 +29,9 @@ from .client_support import ClientSupport
from .jit_support import JITSupport
from .library_support import LibrarySupport
from .evaluation_keys import EvaluationKeys
from .value import Value
from .value_decrypter import ValueDecrypter
from .value_exporter import ValueExporter
def init_dfr():

View File

@@ -3,6 +3,8 @@
"""PublicArguments."""
from typing import List
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
PublicArguments as _PublicArguments,
@@ -10,6 +12,7 @@ from mlir._mlir_libs._concretelang._compiler import (
# pylint: enable=no-name-in-module,import-error
from .client_parameters import ClientParameters
from .value import Value
from .wrapper import WrapperCpp
@@ -35,6 +38,20 @@ class PublicArguments(WrapperCpp):
)
super().__init__(public_arguments)
@staticmethod
def create(
client_parameters: ClientParameters, values: List[Value]
) -> "PublicArguments":
"""
Create public arguments from individual values.
"""
return PublicArguments(
_PublicArguments.create(
client_parameters.cpp(),
[value.cpp() for value in values],
)
)
def serialize(self) -> bytes:
"""Serialize the PublicArguments.

View File

@@ -10,6 +10,7 @@ from mlir._mlir_libs._concretelang._compiler import (
from .client_parameters import ClientParameters
# pylint: enable=no-name-in-module,import-error
from .value import Value
from .wrapper import WrapperCpp
@@ -31,6 +32,18 @@ class PublicResult(WrapperCpp):
)
super().__init__(public_result)
def n_values(self) -> int:
"""
Get number of values in the result.
"""
return self.cpp().n_values()
def get_value(self, position: int) -> Value:
"""
Get a specific value in the result.
"""
return Value(self.cpp().get_value(position))
def serialize(self) -> bytes:
"""Serialize the PublicResult.

View File

@@ -0,0 +1,68 @@
"""Value."""
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
Value as _Value,
)
from .wrapper import WrapperCpp
# pylint: enable=no-name-in-module,import-error
class Value(WrapperCpp):
"""An encrypted/clear value which can be scalar/tensor."""
def __init__(self, value: _Value):
"""
Wrap the native C++ object.
Args:
value (_Value):
object to wrap
Raises:
TypeError:
if `value` is not of type `_Value`
"""
if not isinstance(value, _Value):
raise TypeError(f"value must be of type _Value, not {type(value)}")
super().__init__(value)
def serialize(self) -> bytes:
"""
Serialize value into bytes.
Returns:
bytes: serialized value
"""
return self.cpp().serialize()
@staticmethod
def deserialize(serialized_value: bytes) -> "Value":
"""
Deserialize value from bytes.
Args:
serialized_value (bytes):
previously serialized value
Returns:
Value:
deserialized value
Raises:
TypeError:
if `serialized_value` is not of type `bytes`
"""
if not isinstance(serialized_value, bytes):
raise TypeError(
f"serialized_value must be of type bytes, not {type(serialized_value)}"
)
return Value.wrap(_Value.deserialize(serialized_value))

View File

@@ -0,0 +1,111 @@
"""ValueDecrypter."""
# pylint: disable=no-name-in-module,import-error
from typing import List, Union
import numpy as np
from mlir._mlir_libs._concretelang._compiler import (
ValueDecrypter as _ValueDecrypter,
)
from .client_parameters import ClientParameters
from .key_set import KeySet
from .value import Value
from .wrapper import WrapperCpp
# pylint: enable=no-name-in-module,import-error
class ValueDecrypter(WrapperCpp):
"""A helper class to decrypt `Value`s."""
def __init__(self, value_decrypter: _ValueDecrypter):
"""
Wrap the native C++ object.
Args:
value_decrypter (_ValueDecrypter):
object to wrap
Raises:
TypeError:
if `value_decrypter` is not of type `_ValueDecrypter`
"""
if not isinstance(value_decrypter, _ValueDecrypter):
raise TypeError(
f"value_decrypter must be of type _ValueDecrypter, not {type(value_decrypter)}"
)
super().__init__(value_decrypter)
@staticmethod
def create(keyset: KeySet, client_parameters: ClientParameters):
"""
Create a value decrypter.
"""
return ValueDecrypter(
_ValueDecrypter.create(keyset.cpp(), client_parameters.cpp())
)
def decrypt(self, position: int, value: Value) -> Union[int, np.ndarray]:
"""
Decrypt value.
Args:
position (int):
position of the argument within the circuit
value (Value):
value to decrypt
Returns:
Union[int, np.ndarray]:
decrypted value
"""
shape = tuple(self.cpp().get_shape(position))
if len(shape) == 0:
return self.decrypt_scalar(position, value)
return np.array(self.decrypt_tensor(position, value), dtype=np.int64).reshape(
shape
)
def decrypt_scalar(self, position: int, value: Value) -> int:
"""
Decrypt scalar.
Args:
position (int):
position of the argument within the circuit
value (Value):
scalar value to decrypt
Returns:
int:
decrypted scalar
"""
return self.cpp().decrypt_scalar(position, value.cpp())
def decrypt_tensor(self, position: int, value: Value) -> List[int]:
"""
Decrypt tensor.
Args:
position (int):
position of the argument within the circuit
value (Value):
tensor value to decrypt
Returns:
List[int]:
decrypted tensor
"""
return self.cpp().decrypt_tensor(position, value.cpp())

View File

@@ -0,0 +1,90 @@
"""ValueExporter."""
# pylint: disable=no-name-in-module,import-error
from typing import List
from mlir._mlir_libs._concretelang._compiler import (
ValueExporter as _ValueExporter,
)
from .client_parameters import ClientParameters
from .key_set import KeySet
from .value import Value
from .wrapper import WrapperCpp
# pylint: enable=no-name-in-module,import-error
class ValueExporter(WrapperCpp):
"""A helper class to create `Value`s."""
def __init__(self, value_exporter: _ValueExporter):
"""
Wrap the native C++ object.
Args:
value_exporter (_ValueExporter):
object to wrap
Raises:
TypeError:
if `value_exporter` is not of type `_ValueExporter`
"""
if not isinstance(value_exporter, _ValueExporter):
raise TypeError(
f"value_exporter must be of type _ValueExporter, not {type(value_exporter)}"
)
super().__init__(value_exporter)
@staticmethod
def create(keyset: KeySet, client_parameters: ClientParameters) -> "ValueExporter":
"""
Create a value exporter.
"""
return ValueExporter(
_ValueExporter.create(keyset.cpp(), client_parameters.cpp())
)
def export_scalar(self, position: int, value: int) -> Value:
"""
Export scalar.
Args:
position (int):
position of the argument within the circuit
value (int):
scalar to export
Returns:
Value:
exported scalar
"""
return Value(self.cpp().export_scalar(position, value))
def export_tensor(
self, position: int, values: List[int], shape: List[int]
) -> Value:
"""
Export tensor.
Args:
position (int):
position of the argument within the circuit
values (List[int]):
tensor elements to export
shape (List[int]):
tensor shape to export
Returns:
Value:
exported tensor
"""
return Value(self.cpp().export_tensor(position, values, shape))

View File

@@ -13,7 +13,14 @@ using StringError = concretelang::error::StringError;
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters) {
return std::make_unique<PublicArguments>(clientParameters, std::move(values));
auto sharedValues = std::vector<SharedScalarOrTensorData>();
sharedValues.reserve(this->values.size());
for (auto &&value : this->values) {
sharedValues.push_back(SharedScalarOrTensorData(std::move(value)));
}
return std::make_unique<PublicArguments>(clientParameters, sharedValues);
}
/// Split the input integer into `size` chunks of `chunkWidth` bits each

View File

@@ -15,10 +15,11 @@ namespace clientlib {
using concretelang::error::StringError;
// TODO: optimize the move
PublicArguments::PublicArguments(const ClientParameters &clientParameters,
std::vector<ScalarOrTensorData> &&arguments_)
PublicArguments::PublicArguments(
const ClientParameters &clientParameters,
std::vector<clientlib::SharedScalarOrTensorData> &buffers)
: clientParameters(clientParameters) {
arguments = std::move(arguments_);
arguments = buffers;
}
PublicArguments::~PublicArguments() {}
@@ -44,11 +45,11 @@ PublicArguments::unserializeArgs(std::istream &istream) {
}
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
PublicArguments::unserialize(ClientParameters &clientParameters,
PublicArguments::unserialize(const ClientParameters &expectedParams,
std::istream &istream) {
std::vector<ScalarOrTensorData> emptyBuffers;
auto sArguments = std::make_unique<PublicArguments>(clientParameters,
std::move(emptyBuffers));
std::vector<SharedScalarOrTensorData> emptyBuffers;
auto sArguments =
std::make_unique<PublicArguments>(expectedParams, emptyBuffers);
OUTCOME_TRYV(sArguments->unserializeArgs(istream));
return std::move(sArguments);
}

View File

@@ -398,8 +398,11 @@ unserializeScalarData(std::istream &istream) {
template <typename T>
static std::istream &unserializeTensorDataElements(TensorData &values_and_sizes,
std::istream &istream) {
readWords(istream, values_and_sizes.getElementPointer<T>(0),
values_and_sizes.getNumElements());
// getElementPointer is not valid if the tensor contains no data
if (values_and_sizes.getNumElements() > 0) {
readWords(istream, values_and_sizes.getElementPointer<T>(0),
values_and_sizes.getNumElements());
}
return istream;
}
@@ -555,26 +558,25 @@ unserializeScalarOrTensorData(std::istream &istream) {
}
}
std::ostream &
serializeVectorOfScalarOrTensorData(const std::vector<ScalarOrTensorData> &v,
std::ostream &ostream) {
std::ostream &serializeVectorOfScalarOrTensorData(
const std::vector<SharedScalarOrTensorData> &v, std::ostream &ostream) {
writeSize(ostream, v.size());
for (auto &sotd : v) {
serializeScalarOrTensorData(sotd, ostream);
serializeScalarOrTensorData(sotd.get(), ostream);
if (!ostream.good()) {
return ostream;
}
}
return ostream;
}
outcome::checked<std::vector<ScalarOrTensorData>, StringError>
outcome::checked<std::vector<SharedScalarOrTensorData>, StringError>
unserializeVectorOfScalarOrTensorData(std::istream &istream) {
uint64_t nbElt;
readSize(istream, nbElt);
std::vector<ScalarOrTensorData> v;
std::vector<SharedScalarOrTensorData> v;
for (uint64_t i = 0; i < nbElt; i++) {
OUTCOME_TRY(auto elt, unserializeScalarOrTensorData(istream));
v.push_back(std::move(elt));
v.push_back(SharedScalarOrTensorData(std::move(elt)));
}
return v;
}

View File

@@ -95,7 +95,7 @@ JITLambda::call(clientlib::PublicArguments &args,
if (auto err = invokeRaw(rawArgs)) {
return std::move(err);
}
std::vector<clientlib::ScalarOrTensorData> buffers;
std::vector<clientlib::SharedScalarOrTensorData> buffers;
return clientlib::PublicResult::fromBuffers(args.clientParameters,
std::move(buffers));
}