// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include "concretelang/Common/Values.h" #include "capnp/common.h" #include "capnp/list.h" #include "concrete-protocol.capnp.h" #include "concretelang/Common/Error.h" #include "concretelang/Common/Protocol.h" #include #include #include #include using concretelang::error::Result; using concretelang::error::StringError; using concretelang::protocol::dimensionsToProtoShape; using concretelang::protocol::Message; using concretelang::protocol::protoPayloadToVector; using concretelang::protocol::protoShapeToDimensions; using concretelang::protocol::vectorToProtoPayload; namespace concretelang { namespace values { Value Value::fromRawTransportValue(TransportValue transportVal) { Value output; auto integerPrecision = transportVal.asReader().getRawInfo().getIntegerPrecision(); auto isSigned = transportVal.asReader().getRawInfo().getIsSigned(); auto dimensions = protoShapeToDimensions(transportVal.asReader().getRawInfo().getShape()); auto data = transportVal.asReader().getPayload(); if (integerPrecision == 8 && isSigned) { auto values = protoPayloadToVector(data); output.inner = Tensor{values, dimensions}; } else if (integerPrecision == 16 && isSigned) { auto values = protoPayloadToVector(data); output.inner = Tensor{values, dimensions}; } else if (integerPrecision == 32 && isSigned) { auto values = protoPayloadToVector(data); output.inner = Tensor{values, dimensions}; } else if (integerPrecision == 64 && isSigned) { auto values = protoPayloadToVector(data); output.inner = Tensor{values, dimensions}; } else if (integerPrecision == 8 && !isSigned) { auto values = protoPayloadToVector(data); output.inner = Tensor{values, dimensions}; } else if (integerPrecision == 16 && !isSigned) { auto values = protoPayloadToVector(data); output.inner = Tensor{values, dimensions}; } else if (integerPrecision == 32 && !isSigned) { auto values = protoPayloadToVector(data); output.inner = Tensor{values, dimensions}; } else if (integerPrecision == 64 && !isSigned) { auto values = protoPayloadToVector(data); output.inner = Tensor{values, dimensions}; } else { assert(false); } return output; } TransportValue Value::intoRawTransportValue() const { auto output = Message(); auto rawInfo = output.asBuilder().initRawInfo(); rawInfo.setShape(intoProtoShape().asReader()); rawInfo.setIntegerPrecision(getIntegerPrecision()); rawInfo.setIsSigned(isSigned()); output.asBuilder().setPayload(intoProtoPayload().asReader()); return output; } uint32_t Value::getIntegerPrecision() const { if (hasElementType() || hasElementType()) { return 8; } else if (hasElementType() || hasElementType()) { return 16; } else if (hasElementType() || hasElementType()) { return 32; } else if (hasElementType() || hasElementType()) { return 64; } else { assert(false); } } bool Value::isSigned() const { if (hasElementType() || hasElementType() || hasElementType() || hasElementType()) { return false; } else if (hasElementType() || hasElementType() || hasElementType() || hasElementType()) { return true; } else { assert(false); } } Message Value::intoProtoPayload() const { if (hasElementType()) { return vectorToProtoPayload(std::get>(inner).values); } else if (hasElementType()) { return vectorToProtoPayload(std::get>(inner).values); } else if (hasElementType()) { return vectorToProtoPayload(std::get>(inner).values); } else if (hasElementType()) { return vectorToProtoPayload(std::get>(inner).values); } else if (hasElementType()) { return vectorToProtoPayload(std::get>(inner).values); } else if (hasElementType()) { return vectorToProtoPayload(std::get>(inner).values); } else if (hasElementType()) { return vectorToProtoPayload(std::get>(inner).values); } else if (hasElementType()) { return vectorToProtoPayload(std::get>(inner).values); } else { assert(false); } } Message Value::intoProtoShape() const { return dimensionsToProtoShape(getDimensions()); } std::vector Value::getDimensions() const { if (auto tensor = getTensor(); tensor) { return tensor.value().dimensions; } else if (auto tensor = getTensor(); tensor) { return tensor.value().dimensions; } else if (auto tensor = getTensor(); tensor) { return tensor.value().dimensions; } else if (auto tensor = getTensor(); tensor) { return tensor.value().dimensions; } else if (auto tensor = getTensor(); tensor) { return tensor.value().dimensions; } else if (auto tensor = getTensor(); tensor) { return tensor.value().dimensions; } else if (auto tensor = getTensor(); tensor) { return tensor.value().dimensions; } else if (auto tensor = getTensor(); tensor) { return tensor.value().dimensions; } else { assert(false); } } size_t Value::getLength() const { if (auto tensor = getTensor(); tensor) { return tensor.value().values.size(); } else if (auto tensor = getTensor(); tensor) { return tensor.value().values.size(); } else if (auto tensor = getTensor(); tensor) { return tensor.value().values.size(); } else if (auto tensor = getTensor(); tensor) { return tensor.value().values.size(); } else if (auto tensor = getTensor(); tensor) { return tensor.value().values.size(); } else if (auto tensor = getTensor(); tensor) { return tensor.value().values.size(); } else if (auto tensor = getTensor(); tensor) { return tensor.value().values.size(); } else if (auto tensor = getTensor(); tensor) { return tensor.value().values.size(); } else { assert(false); } } bool Value::isCompatibleWithShape( const Message &shape) const { auto dimensions = getDimensions(); if ((uint32_t)shape.asReader().getDimensions().size() != dimensions.size()) { return false; } for (uint32_t i = 0; i < dimensions.size(); i++) { if (shape.asReader().getDimensions()[i] != dimensions[i]) { return false; } } return true; } bool Value::operator==(const Value &b) const { if (auto tensor = getTensor(); tensor) { return tensor == b.getTensor(); } else if (auto tensor = getTensor(); tensor) { return tensor == b.getTensor(); } else if (auto tensor = getTensor(); tensor) { return tensor == b.getTensor(); } else if (auto tensor = getTensor(); tensor) { return tensor == b.getTensor(); } else if (auto tensor = getTensor(); tensor) { return tensor == b.getTensor(); } else if (auto tensor = getTensor(); tensor) { return tensor == b.getTensor(); } else if (auto tensor = getTensor(); tensor) { return tensor == b.getTensor(); } else if (auto tensor = getTensor(); tensor) { return tensor == b.getTensor(); } else { assert(false); } } bool Value::isScalar() const { if (auto tensor = getTensor(); tensor) { return tensor.value().isScalar(); } else if (auto tensor = getTensor(); tensor) { return tensor.value().isScalar(); } else if (auto tensor = getTensor(); tensor) { return tensor.value().isScalar(); } else if (auto tensor = getTensor(); tensor) { return tensor.value().isScalar(); } else if (auto tensor = getTensor(); tensor) { return tensor.value().isScalar(); } else if (auto tensor = getTensor(); tensor) { return tensor.value().isScalar(); } else if (auto tensor = getTensor(); tensor) { return tensor.value().isScalar(); } else if (auto tensor = getTensor(); tensor) { return tensor.value().isScalar(); } else { assert(false); } } Value Value::toUnsigned() const { if (!this->isSigned()) { return *this; } else if (auto tensor = getTensor(); tensor) { return Value((Tensor)tensor.value()); } else if (auto tensor = getTensor(); tensor) { return Value((Tensor)tensor.value()); } else if (auto tensor = getTensor(); tensor) { return Value((Tensor)tensor.value()); } else if (auto tensor = getTensor(); tensor) { return Value((Tensor)tensor.value()); } else { assert(false); } } Value Value::toSigned() const { if (!this->isSigned()) { return *this; } else if (auto tensor = getTensor(); tensor) { return Value((Tensor)tensor.value()); } else if (auto tensor = getTensor(); tensor) { return Value((Tensor)tensor.value()); } else if (auto tensor = getTensor(); tensor) { return Value((Tensor)tensor.value()); } else if (auto tensor = getTensor(); tensor) { return Value((Tensor)tensor.value()); } else { assert(false); } } size_t getCorrespondingPrecision(size_t originalPrecision) { if (originalPrecision <= 8) { return 8; } if (originalPrecision <= 16) { return 16; } if (originalPrecision <= 32) { return 32; } if (originalPrecision <= 64) { return 64; } assert(false); } } // namespace values } // namespace concretelang