#include "EndToEndFixture.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Jit.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/YAMLParser.h"
#include "llvm/Support/YAMLTraits.h"

#include <algorithm>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

using mlir::concretelang::StreamStringError;

namespace {

/// Tensor lambda argument whose elements are integers of type `T`.
template <typename T>
using IntTensorArgument = mlir::concretelang::TensorLambdaArgument<
    mlir::concretelang::IntLambdaArgument<T>>;

/// Heap-allocates a scalar argument holding `desc.value` converted to `T`.
/// Nothing in this file frees the result; the caller must delete it.
template <typename T>
mlir::concretelang::LambdaArgument *makeScalarArgument(const ScalarDesc &desc) {
  return new mlir::concretelang::IntLambdaArgument<T>(desc.value);
}

/// Heap-allocates a tensor argument with dimensions `desc.shape` whose
/// elements are `desc.values` converted to `T`. Nothing in this file frees
/// the result; the caller must delete it.
template <typename T>
mlir::concretelang::LambdaArgument *
makeTensorArgument(const TensorDescription &desc) {
  return new IntTensorArgument<T>(
      std::vector<T>(desc.values.begin(), desc.values.end()), desc.shape);
}

} // namespace

/// Builds a scalar lambda argument of `desc.width` bits (8, 16, 32 or 64).
/// Returns an error for any other width.
llvm::Expected<mlir::concretelang::LambdaArgument *>
scalarDescToLambdaArgument(ScalarDesc desc) {
  switch (desc.width) {
  case 8:
    return makeScalarArgument<uint8_t>(desc);
  case 16:
    return makeScalarArgument<uint16_t>(desc);
  case 32:
    return makeScalarArgument<uint32_t>(desc);
  case 64:
    return makeScalarArgument<uint64_t>(desc);
  }
  return StreamStringError("unsupported width of scalar value: ")
         << desc.width;
}

/// Builds a tensor lambda argument whose elements are `desc.width`-bit
/// integers (8, 16, 32 or 64). Returns an error for any other width.
llvm::Expected<mlir::concretelang::LambdaArgument *>
TensorDescriptionToLambdaArgument(TensorDescription desc) {
  switch (desc.width) {
  case 8:
    return makeTensorArgument<uint8_t>(desc);
  case 16:
    return makeTensorArgument<uint16_t>(desc);
  case 32:
    return makeTensorArgument<uint32_t>(desc);
  case 64:
    return makeTensorArgument<uint64_t>(desc);
  }
  return StreamStringError("unsupported width of tensor value: ")
         << desc.width;
}

/// Dispatches on `desc.tag` to build a scalar or tensor lambda argument.
llvm::Expected<mlir::concretelang::LambdaArgument *>
valueDescriptionToLambdaArgument(ValueDescription desc) {
  switch (desc.tag) {
  case ValueDescription::SCALAR:
    return scalarDescToLambdaArgument(desc.scalar);
  case ValueDescription::TENSOR:
    return TensorDescriptionToLambdaArgument(desc.tensor);
  }
  return StreamStringError("unsupported value description");
}

/// Checks that `res` is a 64-bit integer scalar equal to `desc.value`.
llvm::Error checkResult(ScalarDesc &desc,
                        mlir::concretelang::LambdaArgument &res) {
  auto *res64 = res.dyn_cast<mlir::concretelang::IntLambdaArgument<uint64_t>>();
  if (res64 == nullptr) {
    return StreamStringError("invocation result is not a scalar");
  }
  if (desc.value != res64->getValue()) {
    return StreamStringError("unexpected result value: got ")
           << res64->getValue() << " expected " << desc.value;
  }
  return llvm::Error::success();
}

/// Checks a tensor result of `T` elements against `desc`.
///
/// `res` is the outcome of a `dyn_cast` and is null when the invocation
/// result is not a tensor of `T`; that case is reported as an error rather
/// than dereferenced. The shape is compared only when `desc.shape` is
/// non-empty; the element count and every element are always compared.
template <typename T>
llvm::Error
checkTensorResult(TensorDescription &desc,
                  mlir::concretelang::TensorLambdaArgument<
                      mlir::concretelang::IntLambdaArgument<T>> *res) {
  if (res == nullptr) {
    return StreamStringError("invocation result is not a tensor of ")
           << sizeof(T) * 8 << "-bit integers";
  }
  if (!desc.shape.empty()) {
    auto resShape = res->getDimensions();
    if (desc.shape.size() != resShape.size()) {
      return StreamStringError("size of shape differs, got ")
             << resShape.size() << " expected " << desc.shape.size();
    }
    for (size_t i = 0; i < desc.shape.size(); i++) {
      if (resShape[i] != desc.shape[i]) {
        return StreamStringError("shape differs at pos ")
               << i << ", got " << resShape[i] << " expected "
               << desc.shape[i];
      }
    }
  }
  auto resValues = res->getValue();
  auto numElts = res->getNumElements();
  if (!numElts) {
    return numElts.takeError();
  }
  if (desc.values.size() != *numElts) {
    return StreamStringError("size of result differs, got ")
           << *numElts << " expected " << desc.values.size();
  }
  for (size_t i = 0; i < *numElts; i++) {
    if (resValues[i] != desc.values[i]) {
      return StreamStringError("result value differ at pos(")
             << i << "), got " << resValues[i] << " expected "
             << desc.values[i];
    }
  }
  return llvm::Error::success();
}

/// Checks a tensor result, selecting the element type from `desc.width`.
llvm::Error checkResult(TensorDescription &desc,
                        mlir::concretelang::LambdaArgument &res) {
  switch (desc.width) {
  case 8:
    return checkTensorResult<uint8_t>(
        desc, res.dyn_cast<IntTensorArgument<uint8_t>>());
  case 16:
    return checkTensorResult<uint16_t>(
        desc, res.dyn_cast<IntTensorArgument<uint16_t>>());
  case 32:
    return checkTensorResult<uint32_t>(
        desc, res.dyn_cast<IntTensorArgument<uint32_t>>());
  case 64:
    return checkTensorResult<uint64_t>(
        desc, res.dyn_cast<IntTensorArgument<uint64_t>>());
  default:
    return StreamStringError("Unsupported width");
  }
}

/// Checks `res` against a scalar or tensor expectation, per `desc.tag`.
llvm::Error checkResult(ValueDescription &desc,
                        mlir::concretelang::LambdaArgument &res) {
  switch (desc.tag) {
  case ValueDescription::SCALAR:
    return checkResult(desc.scalar, res);
  case ValueDescription::TENSOR:
    return checkResult(desc.tensor, res);
  }
  // Previously `assert(false)`, which fell off the end of a non-void
  // function (undefined behavior) in NDEBUG builds.
  return StreamStringError("unsupported value description");
}

/// YAML mapping for a value: either `{scalar: N, width: W}` or
/// `{tensor: [...], width: W, shape: [...]}`; `width` defaults to 64.
template <> struct llvm::yaml::MappingTraits<ValueDescription> {
  static void mapping(IO &io, ValueDescription &desc) {
    auto keys = io.keys();
    if (std::find(keys.begin(), keys.end(), "scalar") != keys.end()) {
      io.mapRequired("scalar", desc.scalar.value);
      io.mapOptional("width", desc.scalar.width, 64);
      desc.tag = ValueDescription::SCALAR;
      return;
    }
    if (std::find(keys.begin(), keys.end(), "tensor") != keys.end()) {
      io.mapRequired("tensor", desc.tensor.values);
      io.mapOptional("width", desc.tensor.width, 64);
      io.mapRequired("shape", desc.tensor.shape);
      desc.tag = ValueDescription::TENSOR;
      return;
    }
    io.setError("Missing scalar or tensor key");
  }
};

LLVM_YAML_IS_SEQUENCE_VECTOR(ValueDescription);

/// YAML mapping for one test case: optional `inputs` and `outputs` lists.
template <> struct llvm::yaml::MappingTraits<TestDescription> {
  static void mapping(IO &io, TestDescription &desc) {
    io.mapOptional("inputs", desc.inputs);
    io.mapOptional("outputs", desc.outputs);
  }
};

LLVM_YAML_IS_SEQUENCE_VECTOR(TestDescription);

/// YAML mapping for one end-to-end document.
///
/// The optional `v0-parameter` (7 integers) and `v0-constraint` (2 integers)
/// lists are validated before being indexed: a list of the wrong length sets
/// an error and leaves the corresponding field untouched.
template <> struct llvm::yaml::MappingTraits<EndToEndDesc> {
  static void mapping(IO &io, EndToEndDesc &desc) {
    io.mapRequired("description", desc.description);
    io.mapRequired("program", desc.program);
    io.mapRequired("tests", desc.tests);

    // NOTE(review): element type assumed to be int64_t; confirm against the
    // V0Parameter / V0FHEConstraint field types.
    std::vector<int64_t> v0parameter;
    io.mapOptional("v0-parameter", v0parameter);
    if (!v0parameter.empty()) {
      if (v0parameter.size() != 7) {
        io.setError("v0-parameter expect to be a list 7 elements "
                    "[glweDimension, logPolynomialSize, nSmall, brLevel, "
                    "brLogBase, ksLevel, ksLogBase]");
      } else {
        desc.v0Parameter = mlir::concretelang::V0Parameter(
            v0parameter[0], v0parameter[1], v0parameter[2], v0parameter[3],
            v0parameter[4], v0parameter[5], v0parameter[6]);
      }
    }

    std::vector<int64_t> v0constraint;
    io.mapOptional("v0-constraint", v0constraint);
    if (!v0constraint.empty()) {
      if (v0constraint.size() != 2) {
        io.setError("v0-constraint expect to be a list 2 elements "
                    "[p, norm2]");
      } else {
        desc.v0Constraint = mlir::concretelang::V0FHEConstraint();
        desc.v0Constraint->p = v0constraint[0];
        desc.v0Constraint->norm2 = v0constraint[1];
      }
    }
  }
};

LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(EndToEndDesc)

/// Reads the YAML file at `path` and parses every document in it as an
/// `EndToEndDesc`.
///
/// Aborts via `llvm::report_fatal_error` when the file cannot be opened or
/// parsed. Previously an unreadable file silently produced zero tests, and
/// parse errors were only caught by an `assert` that disappears in NDEBUG
/// builds.
std::vector<EndToEndDesc> loadEndToEndDesc(std::string path) {
  std::ifstream file(path);
  if (!file) {
    llvm::report_fatal_error(
        llvm::Twine("cannot open end-to-end description file: ") + path);
  }
  std::string content((std::istreambuf_iterator<char>(file)),
                      (std::istreambuf_iterator<char>()));

  llvm::yaml::Input yin(content);
  std::vector<EndToEndDesc> desc;
  yin >> desc;
  if (yin.error()) {
    llvm::report_fatal_error(
        llvm::Twine("cannot parse end-to-end description file: ") + path);
  }
  return desc;
}