// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete/blob/main/LICENSE.txt // for license information. #include #include #include #include #include #include "boost/outcome.h" #include "concrete-protocol.capnp.h" #include "concretelang/Common/Error.h" #include "concretelang/Common/Keysets.h" #include "concretelang/Common/Protocol.h" #include "concretelang/Common/Transformers.h" #include "concretelang/Common/Values.h" #include "concretelang/Runtime/context.h" #include "concretelang/ServerLib/ServerLib.h" #include "concretelang/Support/CompilerEngine.h" #include "llvm/ADT/ArrayRef.h" using concretelang::keysets::ServerKeyset; using concretelang::transformers::ArgTransformer; using concretelang::transformers::ReturnTransformer; using concretelang::transformers::TransformerFactory; using concretelang::values::Value; using mlir::concretelang::CompilerEngine; using mlir::concretelang::RuntimeContext; namespace concretelang { namespace serverlib { // Depending on the strides of the memref, iteration may not be linear in the // memory space (i.e. it may contain jumps). For this reason we have to compute // a memory index from the linear index of the iteration space. This structure // does just that. struct MultiDimIndexer { std::vector multiDimensionalIndex; size_t offset; const std::vector &sizes; const std::vector &strides; MultiDimIndexer(size_t offset, const std::vector &sizes, const std::vector &strides) : sizes(sizes), strides(strides) { size_t rank = sizes.size(); this->multiDimensionalIndex.resize(rank); for (size_t i = 0; i < rank; i++) { this->multiDimensionalIndex[i] = 0; } // this->sizes = sizes; // this->strides = sizes; this->offset = offset; } /// Increments the index. void increment() { size_t rank = sizes.size(); for (int r = rank - 1; r >= 0; r--) { if (multiDimensionalIndex[r] < sizes[r] - 1) { multiDimensionalIndex[r]++; return; } multiDimensionalIndex[r] = 0; } } /// Returns the current index. size_t currentIndex() { size_t rank = sizes.size(); size_t g_index = offset; size_t default_stride = 1; for (int r = rank - 1; r >= 0; r--) { g_index += multiDimensionalIndex[r] * ((strides[r] == 0) ? default_stride : strides[r]); default_stride *= sizes[r]; } return g_index; } }; // A type representing the memref description of a tensor. struct MemRefDescriptor { size_t precision; bool isSigned; void *allocated; void *aligned; size_t offset; std::vector sizes; std::vector strides; /// Creates a memref descriptor referencing the data contained in a tensor. template static MemRefDescriptor fromTensor(Tensor &input) { std::vector strides; size_t stride = input.values.size(); for (size_t dim : input.dimensions) { stride = (dim == 0 ? 0 : (stride / dim)); strides.push_back(stride); } return MemRefDescriptor{sizeof(T) * 8, std::is_signed(), (void *)nullptr, (void *)input.values.data(), 0, input.dimensions, strides}; } /// Creates a memref descriptor from a vector of uint64_t, which is the way to /// represent outputs in the current calling convention. static MemRefDescriptor fromU64s(llvm::ArrayRef raw, size_t precision, bool isSigned) { auto rank = (raw.size() - 3) / 2; void *allocated = (void *)raw[0]; void *aligned = (void *)raw[1]; size_t offset = (size_t)raw[2]; std::vector sizes(rank); for (size_t i = 0; i < rank; i++) { sizes[i] = (size_t)raw[3 + i]; } std::vector strides(rank); for (size_t i = 0; i < rank; i++) { strides[i] = (size_t)raw[3 + rank + i]; } return MemRefDescriptor{ precision, isSigned, allocated, aligned, offset, sizes, strides, }; } /// Returns the number of elements of the memref. size_t getLength() { size_t output = 1; for (size_t i = 0; i < sizes.size(); i++) { output *= sizes[i]; } return output; } // Allocates a new tensor, and copy the values referenced by a memref // descriptor. template Tensor intoTensor() { assert(sizeof(T) * 8 == precision); assert(std::is_signed() == isSigned); // We create the indexer. auto indexer = MultiDimIndexer(offset, sizes, strides); // We fill a vector of vales to construct the std::vector values(getLength()); for (size_t i = 0; i < values.size(); i++) { T *memrefAligned = reinterpret_cast(aligned); auto index = indexer.currentIndex(); values[i] = memrefAligned[index]; indexer.increment(); } return Tensor{values, sizes}; } void intoOpaquePtrs(llvm::MutableArrayRef &opaquePtrs) { opaquePtrs[0] = allocated; opaquePtrs[1] = aligned; opaquePtrs[2] = (void *)offset; for (size_t i = 0; i < sizes.size(); i++) { opaquePtrs[3 + i] = (void *)sizes[i]; } for (size_t i = 0; i < strides.size(); i++) { opaquePtrs[3 + sizes.size() + i] = (void *)strides[i]; } } }; struct ScalarDescriptor { size_t precision; bool isSigned; uint64_t val; template static ScalarDescriptor fromTensor(Tensor &input) { T value = input.values[0]; size_t width = sizeof(T) * 8; if (width == 64) { return ScalarDescriptor{sizeof(T) * 8, std::is_signed(), (uint64_t)value}; } // Todo : Verify if this is really necessary. uint64_t mask = ((uint64_t)1 << width) - 1; uint64_t val = ((uint64_t)value) & mask; return ScalarDescriptor{sizeof(T) * 8, std::is_signed(), val}; } static ScalarDescriptor fromU64s(llvm::ArrayRef raw, size_t precision, bool isSigned) { return ScalarDescriptor{precision, isSigned, raw[0]}; } template Tensor intoTensor() { assert(sizeof(T) * 8 == precision); assert(std::is_signed() == isSigned); std::vector values{(T)val}; std::vector sizes(0); return Tensor(values, sizes); } void intoOpaquePtrs(llvm::MutableArrayRef &opaquePtrs) { opaquePtrs[0] = (void *)val; } }; /// A type representing an argument used in the invocation of a circuit /// function. struct InvocationDescriptor { /// An argument can be a memref descriptor, if the argument is a tensor, or a /// scalar descriptor, if the argument is a scalar. std::variant inner; static InvocationDescriptor fromValue(Value &value) { if (value.hasElementType()) { return fromTensor(*value.getTensorPtr()); } else if (value.hasElementType()) { return fromTensor(*value.getTensorPtr()); } else if (value.hasElementType()) { return fromTensor(*value.getTensorPtr()); } else if (value.hasElementType()) { return fromTensor(*value.getTensorPtr()); } else if (value.hasElementType()) { return fromTensor(*value.getTensorPtr()); } else if (value.hasElementType()) { return fromTensor(*value.getTensorPtr()); } else if (value.hasElementType()) { return fromTensor(*value.getTensorPtr()); } else if (value.hasElementType()) { return fromTensor(*value.getTensorPtr()); } assert(false); } Value intoValue() { if (getIsSigned()) { if (getPrecision() == 8) { return Value{intoTensor()}; } else if (getPrecision() == 16) { return Value{intoTensor()}; } else if (getPrecision() == 32) { return Value{intoTensor()}; } else if (getPrecision() == 64) { return Value{intoTensor()}; } } else { if (getPrecision() == 8) { return Value{intoTensor()}; } else if (getPrecision() == 16) { return Value{intoTensor()}; } else if (getPrecision() == 32) { return Value{intoTensor()}; } else if (getPrecision() == 64) { return Value{intoTensor()}; } } assert(false); } static InvocationDescriptor fromU64s(llvm::ArrayRef raw, size_t precision, bool isSigned) { if (raw.size() == 1) { return InvocationDescriptor{ ScalarDescriptor::fromU64s(raw, precision, isSigned)}; } else { return InvocationDescriptor{ MemRefDescriptor::fromU64s(raw, precision, isSigned)}; } } void intoOpaquePtrs(llvm::MutableArrayRef &opaquePtrs) { if (std::holds_alternative(inner)) { std::get(inner).intoOpaquePtrs(opaquePtrs); } else { std::get(inner).intoOpaquePtrs(opaquePtrs); } } // Structure used to free memory allocated by the circuit after invocation. struct Liberator { // Insert in the dropper for further freeing. void insert(const InvocationDescriptor &desc) { if (std::holds_alternative(desc.inner)) { ptrs.insert(std::get(desc.inner).allocated); } } // Free the memory. void tryFree() { for (void *ptr : ptrs) { if (ptr != nullptr && !isReferenceToMLIRGlobalMemory(ptr)) { ::free(ptr); } } } private: llvm::SmallSet ptrs; static inline bool isReferenceToMLIRGlobalMemory(void *ptr) { return reinterpret_cast(ptr) == 0xdeadbeef; } }; private: template static InvocationDescriptor fromTensor(Tensor &tensor) { if (tensor.isScalar()) { return InvocationDescriptor{ScalarDescriptor::fromTensor(tensor)}; } else { return InvocationDescriptor{MemRefDescriptor::fromTensor(tensor)}; } } template Tensor intoTensor() { if (std::holds_alternative(inner)) { return std::get(inner).intoTensor(); } else { return std::get(inner).intoTensor(); } } size_t getPrecision() { if (std::holds_alternative(inner)) { return std::get(inner).precision; } else { return std::get(inner).precision; } } bool getIsSigned() { if (std::holds_alternative(inner)) { return std::get(inner).isSigned; } else { return std::get(inner).isSigned; } } }; DynamicModule::~DynamicModule() { if (libraryHandle != nullptr) { dlclose(libraryHandle); } } Result> DynamicModule::open(const std::string &sharedLibPath) { std::shared_ptr module = std::make_shared(); module->libraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY); if (!module->libraryHandle) { return StringError("Cannot open shared library ") << dlerror(); } return module; } size_t getGateDescriptionSize(const Message &gateInfo, bool useSimulation) { auto shapeToSize = [](concreteprotocol::Shape::Reader shape) -> size_t { if (shape.getDimensions().size() == 0) { return 1; } else { return 3 + 2 * shape.getDimensions().size(); } }; auto typeInfo = gateInfo.asReader().getTypeInfo(); if (typeInfo.hasIndex()) { return shapeToSize(typeInfo.getIndex().getShape()); } else if (typeInfo.hasPlaintext()) { return shapeToSize(typeInfo.getPlaintext().getShape()); } else if (typeInfo.hasLweCiphertext()) { if (useSimulation) { if (typeInfo.getLweCiphertext() .getConcreteShape() .getDimensions() .size() == 1) { // Initially it was just one ciphertext in native mode. Only an integer // will be passed... return 1; } else { // This is either a tensor in native encoding mode, or a tensor in crt // mode or whatever. A tensor will be passed, but with the lwe dimension // removed basically (hence the -2). return shapeToSize(typeInfo.getLweCiphertext().getConcreteShape()) - 2; } } else { return shapeToSize(typeInfo.getLweCiphertext().getConcreteShape()); } } else { assert(false); } } size_t getGateIntegerPrecision(const Message &gateInfo) { if (gateInfo.asReader().getTypeInfo().hasIndex()) { return gateInfo.asReader().getTypeInfo().getIndex().getIntegerPrecision(); } else if (gateInfo.asReader().getTypeInfo().hasPlaintext()) { return gateInfo.asReader() .getTypeInfo() .getPlaintext() .getIntegerPrecision(); } else if (gateInfo.asReader().getTypeInfo().hasLweCiphertext()) { return gateInfo.asReader() .getTypeInfo() .getLweCiphertext() .getIntegerPrecision(); } assert(false); } bool getGateIsSigned(const Message &gateInfo) { if (gateInfo.asReader().getTypeInfo().hasIndex()) { return gateInfo.asReader().getTypeInfo().getIndex().getIsSigned(); } else if (gateInfo.asReader().getTypeInfo().hasPlaintext()) { return gateInfo.asReader().getTypeInfo().getPlaintext().getIsSigned(); } else if (gateInfo.asReader().getTypeInfo().hasLweCiphertext()) { return false; } assert(false); } Result> ServerCircuit::call(const ServerKeyset &serverKeyset, std::vector &args) { if (args.size() != argsBuffer.size()) { return StringError("Called circuit with wrong number of arguments"); } // We load the processed arguments in the args buffer. for (size_t i = 0; i < argsBuffer.size(); i++) { OUTCOME_TRY(argsBuffer[i], argTransformers[i](args[i])); } // The arguments has been pushed in the arg buffer, we are now ready to // invoke the circuit function. invoke(serverKeyset); // We process the return values to turn them into transport values. std::vector returns(returnsBuffer.size()); for (size_t i = 0; i < returnsBuffer.size(); i++) { OUTCOME_TRY(returns[i], returnTransformers[i](returnsBuffer[i])); } return returns; } Result> ServerCircuit::simulate(std::vector &args) { ServerKeyset emptyKeyset; return call(emptyKeyset, args); } std::string ServerCircuit::getName() { return circuitInfo.asReader().getName(); } Result ServerCircuit::fromDynamicModule( const Message &circuitInfo, std::shared_ptr dynamicModule, bool useSimulation = false) { ServerCircuit output; output.circuitInfo = circuitInfo; output.useSimulation = useSimulation; output.dynamicModule = dynamicModule; output.func = (void (*)(void *, ...))dlsym( dynamicModule->libraryHandle, (std::string("_mlir_concrete_") + std::string(circuitInfo.asReader().getName().cStr())) .c_str()); if (auto err = dlerror()) { return StringError("Circuit symbol not found in dynamic module: ") << std::string(err); } // We prepare the args transformers used to transform transport values into // arg values. for (auto gateInfo : circuitInfo.asReader().getInputs()) { ArgTransformer transformer; if (gateInfo.getTypeInfo().hasIndex()) { OUTCOME_TRY(transformer, TransformerFactory::getIndexArgTransformer(gateInfo)); } else if (gateInfo.getTypeInfo().hasPlaintext()) { OUTCOME_TRY(transformer, TransformerFactory::getPlaintextArgTransformer(gateInfo)); } else if (gateInfo.getTypeInfo().hasLweCiphertext()) { OUTCOME_TRY(transformer, TransformerFactory::getLweCiphertextArgTransformer( gateInfo, useSimulation)); } else { return StringError("Malformed input gate info."); } output.argTransformers.push_back(transformer); } // We prepare the return transformers used to transform return values into // transport values. for (auto gateInfo : circuitInfo.asReader().getOutputs()) { ReturnTransformer transformer; if (gateInfo.getTypeInfo().hasIndex()) { OUTCOME_TRY(transformer, TransformerFactory::getIndexReturnTransformer(gateInfo)); } else if (gateInfo.getTypeInfo().hasPlaintext()) { OUTCOME_TRY(transformer, TransformerFactory::getPlaintextReturnTransformer(gateInfo)); } else if (gateInfo.getTypeInfo().hasLweCiphertext()) { OUTCOME_TRY(transformer, TransformerFactory::getLweCiphertextReturnTransformer( gateInfo, useSimulation)); } else { return StringError("Malformed input gate info."); } output.returnTransformers.push_back(transformer); } output.argsBuffer = std::vector(circuitInfo.asReader().getInputs().size()); output.returnsBuffer = std::vector(circuitInfo.asReader().getOutputs().size()); output.argRawSize = 0; for (auto gateInfo : circuitInfo.asReader().getInputs()) { auto descriptorSize = getGateDescriptionSize(gateInfo, useSimulation); output.argDescriptorSizes.push_back(descriptorSize); output.argRawSize += descriptorSize; } output.returnRawSize = 0; for (auto gateInfo : circuitInfo.asReader().getOutputs()) { auto descriptorSize = getGateDescriptionSize(gateInfo, useSimulation); output.returnDescriptorSizes.push_back(descriptorSize); output.returnRawSize += descriptorSize; } return output; } void ServerCircuit::invoke(const ServerKeyset &serverKeyset) { // We create a runtime context from the keyset, and place a pointer to it in // the structure. RuntimeContext runtimeContext = RuntimeContext(serverKeyset); RuntimeContext *_runtimeContextPtr = &runtimeContext; auto _argRaws = std::vector(this->argRawSize); auto _argRawMaps = std::vector>(); size_t currentRawIndex = 0; for (auto descriptorSize : this->argDescriptorSizes) { auto map = llvm::MutableArrayRef(&_argRaws[currentRawIndex], descriptorSize); _argRawMaps.push_back(map); currentRawIndex += descriptorSize; } auto _returnRaws = std::vector(this->returnRawSize); auto _returnRawMaps = std::vector>(); currentRawIndex = 0; for (auto descriptorSize : this->returnDescriptorSizes) { auto map = llvm::ArrayRef(&_returnRaws[currentRawIndex], descriptorSize); _returnRawMaps.push_back(map); currentRawIndex += descriptorSize; } auto _invocationRaws = std::vector(); for (auto &arg : _argRaws) { _invocationRaws.push_back(&arg); } _invocationRaws.push_back((void *)(&_runtimeContextPtr)); _invocationRaws.push_back(reinterpret_cast(_returnRaws.data())); // We load the argument descriptors in the _argRaws for (unsigned int i = 0; i < circuitInfo.asReader().getInputs().size(); i++) { // We construct a descriptor from the input value. InvocationDescriptor descriptor = InvocationDescriptor::fromValue(argsBuffer[i]); // We write the descriptor in the _argRaws via the maps. descriptor.intoOpaquePtrs(_argRawMaps[i]); } func(_invocationRaws.data()); // The circuit has been executed, we can load the results from the // _returnRaws. // // Note that, the addition of multi outputs made it possible to have aliased // outputs. We must then deduplicate the output descriptors before freeing // their memory to prevent constructing corrupted outputs and double-freeing. auto liberator = InvocationDescriptor::Liberator(); for (unsigned int i = 0; i < circuitInfo.asReader().getOutputs().size(); i++) { // We read the descriptor from the _returnRaws via the maps. size_t precision = getGateIntegerPrecision(circuitInfo.asReader().getOutputs()[i]); bool isSigned = getGateIsSigned(circuitInfo.asReader().getOutputs()[i]); InvocationDescriptor descriptor = InvocationDescriptor::fromU64s(_returnRawMaps[i], precision, isSigned); // We generate a value from the descriptor which we store in the // returnsBuffer. returnsBuffer[i] = descriptor.intoValue(); // // We push the descriptor into the output set for later freeing. liberator.insert(descriptor); } // We (eventually) free the memory allocated for this result by the circuit. liberator.tryFree(); } Result ServerProgram::load(const Message &programInfo, const std::string &sharedLibPath, bool useSimulation) { ServerProgram output; OUTCOME_TRY(auto dynamicModule, DynamicModule::open(sharedLibPath)); auto sharedDynamicModule = std::shared_ptr(dynamicModule); std::vector serverCircuits; for (auto circuitInfo : programInfo.asReader().getCircuits()) { OUTCOME_TRY(auto serverCircuit, ServerCircuit::fromDynamicModule( circuitInfo, sharedDynamicModule, useSimulation)); serverCircuits.push_back(serverCircuit); } output.serverCircuits = serverCircuits; return output; } Result ServerProgram::getServerCircuit(const std::string &circuitName) { for (auto serverCircuit : serverCircuits) { if (serverCircuit.getName() == circuitName) { return serverCircuit; } } return StringError("Tried to get unknown server circuit: `" + circuitName + "`"); } } // namespace serverlib } // namespace concretelang