From bb4412499936431943a05d10eb8255440c97c8e7 Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Tue, 5 Oct 2021 14:45:56 +0100 Subject: [PATCH] feat(dfr): add the DFR (DataFlow Runtime). --- .github/workflows/conformance.yml | 28 ++ .github/workflows/docker-hpx.yml | 29 ++ .github/workflows/docker-zamalang-df.yml | 35 ++ builders/Dockerfile.hpx-env | 22 ++ builders/Dockerfile.zamalang-df-env | 43 +++ compiler/CMakeLists.txt | 17 + compiler/Makefile | 12 +- .../include/zamalang/Runtime/DFRuntime.hpp | 116 ++++++ .../zamalang/Runtime/dfr_debug_interface.h | 13 + .../distributed_generic_task_server.hpp | 273 ++++++++++++++ .../include/zamalang/Runtime/key_manager.hpp | 153 ++++++++ .../include/zamalang/Runtime/runtime_api.h | 34 ++ compiler/lib/Runtime/CMakeLists.txt | 10 +- compiler/lib/Runtime/DFRuntime.cpp | 253 +++++++++++++ compiler/src/CMakeLists.txt | 52 ++- compiler/tests/CMakeLists.txt | 2 +- compiler/tests/unittest/CMakeLists.txt | 20 + compiler/tests/unittest/end_to_end_jit_dfr.cc | 348 ++++++++++++++++++ 18 files changed, 1444 insertions(+), 16 deletions(-) create mode 100644 .github/workflows/docker-hpx.yml create mode 100644 .github/workflows/docker-zamalang-df.yml create mode 100644 builders/Dockerfile.hpx-env create mode 100644 builders/Dockerfile.zamalang-df-env create mode 100644 compiler/include/zamalang/Runtime/DFRuntime.hpp create mode 100644 compiler/include/zamalang/Runtime/dfr_debug_interface.h create mode 100644 compiler/include/zamalang/Runtime/distributed_generic_task_server.hpp create mode 100644 compiler/include/zamalang/Runtime/key_manager.hpp create mode 100644 compiler/include/zamalang/Runtime/runtime_api.h create mode 100644 compiler/lib/Runtime/DFRuntime.cpp create mode 100644 compiler/tests/unittest/end_to_end_jit_dfr.cc diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index 6496f159b..c09272326 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -52,3 +52,31 @@ 
jobs: echo "Debug: ccache statistics (after the build):" ccache -s chmod -R ugo+rwx /tmp/KeySetCache + BuildAndTestDF: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + submodules: recursive + + - name: Build and test compiler (dataflow) + uses: addnab/docker-run-action@v3 + with: + registry: ghcr.io + image: ghcr.io/zama-ai/zamalang-df-compiler:latest + username: ${{ secrets.GHCR_LOGIN }} + password: ${{ secrets.GHCR_PASSWORD }} + options: -v ${{ github.workspace }}/compiler:/compiler + shell: bash + run: | + set -e + echo "Debug: ccache statistics (prior to the build):" + ccache -s + cd /compiler + pip install pytest + rm -rf /build + export PYTHONPATH="" + make PARALLEL_EXECUTION_ENABLED=ON CCACHE=ON BUILD_DIR=/build test test-dataflow + echo "Debug: ccache statistics (after the build):" + ccache -s + chmod -R ugo+rwx /tmp/KeySetCache diff --git a/.github/workflows/docker-hpx.yml b/.github/workflows/docker-hpx.yml new file mode 100644 index 000000000..15d5bd2e9 --- /dev/null +++ b/.github/workflows/docker-hpx.yml @@ -0,0 +1,29 @@ +name: Docker image (HPX build) + +on: + push: + paths: + - builders/Dockerfile.hpx-env + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + build_publish: + name: Build & Publish the Docker image + runs-on: ubuntu-latest + env: + IMAGE: ghcr.io/zama-ai/hpx + + steps: + - uses: actions/checkout@v2 + + - name: build + run: docker build -t $IMAGE -f builders/Dockerfile.hpx-env . 
+ + - name: login + run: echo "${{ secrets.GHCR_PASSWORD }}" | docker login -u ${{ secrets.GHCR_LOGIN }} --password-stdin ghcr.io + + - name: tag and publish + run: | + docker push $IMAGE:latest diff --git a/.github/workflows/docker-zamalang-df.yml b/.github/workflows/docker-zamalang-df.yml new file mode 100644 index 000000000..d9f1b1345 --- /dev/null +++ b/.github/workflows/docker-zamalang-df.yml @@ -0,0 +1,35 @@ +name: Docker image (dataflow) + +on: + push: + branches: + - master + release: + types: [created] + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + build_publish: + name: Build & Publish the Docker image (dataflow) + runs-on: ubuntu-latest + env: + IMAGE: ghcr.io/zama-ai/zamalang-df-compiler + + steps: + - uses: actions/checkout@v2 + with: + submodules: recursive + + - name: login + run: echo "${{ secrets.GHCR_PASSWORD }}" | docker login -u ${{ secrets.GHCR_LOGIN }} --password-stdin ghcr.io + + - name: build + run: docker image build --no-cache --label "commit-sha=${{ github.sha }}" -t $IMAGE -f builders/Dockerfile.zamalang-df-env . 
+ + - name: tag and publish + run: | + docker image tag $IMAGE $IMAGE:${{ github.sha }} + docker image push $IMAGE:latest + docker image push $IMAGE:${{ github.sha }} diff --git a/builders/Dockerfile.hpx-env b/builders/Dockerfile.hpx-env new file mode 100644 index 000000000..32959f4d4 --- /dev/null +++ b/builders/Dockerfile.hpx-env @@ -0,0 +1,22 @@ +FROM ubuntu:latest + +RUN apt-get update --fix-missing +RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y curl cmake g++ build-essential python3 python3-pip python3-setuptools ninja-build git libboost-filesystem-dev libhwloc-dev +RUN pip install numpy pybind11==2.6.2 PyYAML +RUN mkdir /cmake-build +ADD https://github.com/Kitware/CMake/releases/download/v3.22.0/cmake-3.22.0-linux-x86_64.tar.gz /cmake-build/cmake.tar.gz +RUN cd /cmake-build && tar xzf cmake.tar.gz +ENV PATH=/cmake-build/cmake-3.22.0-linux-x86_64/bin:${PATH} +RUN git clone https://github.com/STEllAR-GROUP/hpx.git +ENV HPX=/hpx +RUN cd ${HPX} && git checkout 1.7.1 +RUN mkdir ${HPX}/build +RUN cd ${HPX}/build && cmake \ + -DHPX_WITH_FETCH_ASIO=on \ + -DHPX_FILESYSTEM_WITH_BOOST_FILESYSTEM_COMPATIBILITY=ON \ + -DHPX_WITH_MALLOC=system .. 
+RUN cd ${HPX}/build && make -j2 + +FROM ubuntu:latest +COPY --from=0 /hpx/ /hpx/ +COPY --from=0 /cmake-build/ /cmake/ diff --git a/builders/Dockerfile.zamalang-df-env b/builders/Dockerfile.zamalang-df-env new file mode 100644 index 000000000..f35440109 --- /dev/null +++ b/builders/Dockerfile.zamalang-df-env @@ -0,0 +1,43 @@ +FROM ubuntu:latest + +RUN apt-get update +RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y curl cmake g++ \ + build-essential python3 python3-pip \ + python3-setuptools ninja-build git \ + zlib1g-dev ccache libboost-filesystem-dev libhwloc-dev +# setup ccache with an unlimited amount of files and storage +RUN ccache -M 0 +RUN ccache -F 0 +RUN pip install numpy pybind11==2.6.2 PyYAML +# Setup Concrete +COPY --from=ghcr.io/zama-ai/concrete-api-env:latest /target/release /concrete/target/release +ENV CONCRETE_PROJECT=/concrete +# Setup HPX +COPY --from=ghcr.io/zama-ai/hpx:latest /hpx /hpx +ENV HPX_INSTALL_DIR=/hpx/build +# Setup Cmake +COPY --from=ghcr.io/zama-ai/hpx:latest /cmake /cmake +ENV PATH=/cmake/cmake-3.22.0-linux-x86_64/bin:${PATH} +# Setup LLVM +COPY /llvm-project /llvm-project +# Setup and build compiler +COPY /compiler /compiler +WORKDIR /compiler +RUN mkdir -p /build +RUN make PARALLEL_EXECUTION_ENABLED=ON BUILD_DIR=/build CCACHE=ON \ + zamacompiler python-bindings && \ + mv /build/tools/zamalang/python_packages/zamalang_core /zamalang_core && \ + mv /build/bin/zamacompiler /bin && \ + mv /build/lib/libZamalangRuntime.so /lib && \ + mv /build/lib/libDFRuntime.so /lib && \ + rm -rf /build && \ + mkdir -p /build/tools/zamalang/python_packages/ && \ + mkdir -p /build/bin && \ + mkdir -p /build/lib && \ + mv /zamalang_core /build/tools/zamalang/python_packages/ && \ + mv /bin/zamacompiler /build/bin && \ + mv /lib/libZamalangRuntime.so /build/lib && \ + mv /lib/libDFRuntime.so /build/lib +ENV PYTHONPATH "$PYTHONPATH:/build/tools/zamalang/python_packages/zamalang_core" +ENV PATH "$PATH:/build/bin" +ENV RT_LIB 
"/build/lib/libZamalangRuntime.so" diff --git a/compiler/CMakeLists.txt b/compiler/CMakeLists.txt index 1de40aa21..788354d63 100644 --- a/compiler/CMakeLists.txt +++ b/compiler/CMakeLists.txt @@ -84,6 +84,23 @@ else() message(STATUS "ZamaLang Python bindings are disabled.") endif() +#------------------------------------------------------------------------------- +# DFR - parallel execution configuration +#------------------------------------------------------------------------------- + +option(ZAMALANG_PARALLEL_EXECUTION_ENABLED "Enables parallel execution for ZamaLang." ON) + +if(ZAMALANG_PARALLEL_EXECUTION_ENABLED) + message(STATUS "ZamaLang parallel execution enabled.") + + find_package(HPX REQUIRED CONFIG) + include_directories(SYSTEM ${HPX_INCLUDE_DIRS}) + list(APPEND CMAKE_MODULE_PATH "${HPX_CMAKE_DIR}") + +else() + message(STATUS "ZamaLang parallel execution disabled.") +endif() + #------------------------------------------------------------------------------- # Unit tests #------------------------------------------------------------------------------- diff --git a/compiler/Makefile b/compiler/Makefile index 75ba847f2..11ac83d0a 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -1,6 +1,7 @@ BUILD_DIR=./build Python3_EXECUTABLE= BINDINGS_PYTHON_ENABLED=ON +PARALLEL_EXECUTION_ENABLED=OFF ifeq ($(shell which ccache),) CCACHE=OFF @@ -24,7 +25,9 @@ $(BUILD_DIR)/configured.stamp: -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_ENABLE_BINDINGS_PYTHON=$(BINDINGS_PYTHON_ENABLED) \ -DZAMALANG_BINDINGS_PYTHON_ENABLED=$(BINDINGS_PYTHON_ENABLED) \ + -DZAMALANG_PARALLEL_EXECUTION_ENABLED=$(PARALLEL_EXECUTION_ENABLED) \ -DCONCRETE_FFI_RELEASE=${CONCRETE_PROJECT}/target/release \ + -DHPX_DIR=${HPX_INSTALL_DIR}/lib/cmake/HPX \ -DLLVM_EXTERNAL_PROJECTS=zamalang \ -DLLVM_EXTERNAL_ZAMALANG_SOURCE_DIR=. 
\ -DPython3_EXECUTABLE=${Python3_EXECUTABLE} @@ -47,6 +50,8 @@ test-python: python-bindings zamacompiler test: test-check test-end-to-end-jit test-python +test-dataflow: test-end-to-end-jit-dfr + # Unittests build-end-to-end-jit-test: build-initialized @@ -64,8 +69,12 @@ build-end-to-end-jit-hlfhelinalg: build-initialized build-end-to-end-jit-lambda: build-initialized cmake --build $(BUILD_DIR) --target end_to_end_jit_lambda +build-end-to-end-jit-dfr: build-initialized + cmake --build $(BUILD_DIR) --target end_to_end_jit_dfr + build-end-to-end-jit: build-end-to-end-jit-test build-end-to-end-jit-clear-tensor build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-hlfhelinalg + test-end-to-end-jit-test: build-end-to-end-jit-test $(BUILD_DIR)/bin/end_to_end_jit_test @@ -81,7 +90,8 @@ test-end-to-end-jit-hlfhelinalg: build-end-to-end-jit-hlfhelinalg test-end-to-end-jit-lambda: build-initialized build-end-to-end-jit-lambda $(BUILD_DIR)/bin/end_to_end_jit_lambda - +test-end-to-end-jit-dfr: build-end-to-end-jit-dfr + $(BUILD_DIR)/bin/end_to_end_jit_dfr test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-hlfhelinalg diff --git a/compiler/include/zamalang/Runtime/DFRuntime.hpp b/compiler/include/zamalang/Runtime/DFRuntime.hpp new file mode 100644 index 000000000..55fe86cc3 --- /dev/null +++ b/compiler/include/zamalang/Runtime/DFRuntime.hpp @@ -0,0 +1,116 @@ +#ifndef ZAMALANG_DFR_DFRUNTIME_HPP +#define ZAMALANG_DFR_DFRUNTIME_HPP + +#include +#include +#include + +#include "zamalang/Runtime/runtime_api.h" + +/* Debug interface. 
*/ +#include "zamalang/Runtime/dfr_debug_interface.h" + +extern void *dl_handle; +struct WorkFunctionRegistry; +extern WorkFunctionRegistry *node_level_work_function_registry; + +// Recover the name of the work function +static inline const char * +_dfr_get_function_name_from_address(void *fn) +{ + Dl_info info; + + if (!dladdr(fn, &info) || info.dli_sname == nullptr) + HPX_THROW_EXCEPTION(hpx::no_success, + "_dfr_get_function_name_from_address", + "Error recovering work function name from address."); + return info.dli_sname; +} + +static inline wfnptr +_dfr_get_function_pointer_from_name(const char *fn_name) +{ + auto ptr = dlsym(dl_handle, fn_name); + + if (ptr == nullptr) + HPX_THROW_EXCEPTION(hpx::no_success, + "_dfr_get_function_pointer_from_name", + "Error recovering work function pointer from name."); + return (wfnptr) ptr; +} + +// Determine where new task should run. For now just round-robin +// distribution - TODO: optimise. +static inline size_t +_dfr_find_next_execution_locality() +{ + static size_t num_nodes = hpx::get_num_localities().get(); + static std::atomic next_locality{0}; + + size_t next_loc = ++next_locality; + + return next_loc % num_nodes; +} + +static inline bool +_dfr_is_root_node() +{ + return hpx::find_here() == hpx::find_root_locality(); +} + +struct WorkFunctionRegistry +{ + WorkFunctionRegistry() + { + node_level_work_function_registry = this; + } + + wfnptr getWorkFunctionPointer(const std::string &name) + { + std::lock_guard guard(registry_guard); + + auto fnptrit = name_to_ptr_registry.find(name); + if (fnptrit != name_to_ptr_registry.end()) + return (wfnptr) fnptrit->second; + + auto ptr = dlsym(dl_handle, name.c_str()); + if (ptr == nullptr) + HPX_THROW_EXCEPTION(hpx::no_success, + "WorkFunctionRegistry::getWorkFunctionPointer", + "Error recovering work function pointer from name."); + ptr_to_name_registry.insert(std::pair(ptr, name)); + name_to_ptr_registry.insert(std::pair(name, ptr)); + return (wfnptr) ptr; + } + + 
std::string getWorkFunctionName(const void *fn) + { + std::lock_guard guard(registry_guard); + + auto fnnameit = ptr_to_name_registry.find(fn); + if (fnnameit != ptr_to_name_registry.end()) + return fnnameit->second; + + Dl_info info; + std::string ret; + // Assume that if we can't find the name, there is no dynamic + // library to find it in. TODO: fix this to distinguish JIT/binary + // and in case of distributed exec. + if (!dladdr(fn, &info) || info.dli_sname == nullptr) + { + static std::atomic fnid{0}; + ret = "_dfr_jit_wfnname_" + std::to_string(fnid++); + } else { + ret = info.dli_sname; + } + ptr_to_name_registry.insert(std::pair(fn, ret)); + name_to_ptr_registry.insert(std::pair(ret, fn)); + return ret; + } +private: + std::mutex registry_guard; + std::map ptr_to_name_registry; + std::map name_to_ptr_registry; +}; + +#endif diff --git a/compiler/include/zamalang/Runtime/dfr_debug_interface.h b/compiler/include/zamalang/Runtime/dfr_debug_interface.h new file mode 100644 index 000000000..48ca0dcc4 --- /dev/null +++ b/compiler/include/zamalang/Runtime/dfr_debug_interface.h @@ -0,0 +1,13 @@ +#ifndef ZAMALANG_DRF_DEBUG_INTERFACE_H +#define ZAMALANG_DRF_DEBUG_INTERFACE_H + +#include +#include + +extern "C" { +size_t _dfr_debug_get_node_id(); +size_t _dfr_debug_get_worker_id(); +void _dfr_debug_print_task(const char *name, int inputs, int outputs); +void _dfr_print_debug(size_t val); +} +#endif diff --git a/compiler/include/zamalang/Runtime/distributed_generic_task_server.hpp b/compiler/include/zamalang/Runtime/distributed_generic_task_server.hpp new file mode 100644 index 000000000..6d47970b6 --- /dev/null +++ b/compiler/include/zamalang/Runtime/distributed_generic_task_server.hpp @@ -0,0 +1,273 @@ +#ifndef ZAMALANG_DFR_DISTRIBUTED_GENERIC_TASK_SERVER_HPP +#define ZAMALANG_DFR_DISTRIBUTED_GENERIC_TASK_SERVER_HPP + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include 
+#include +#include +#include + +#include "zamalang/Runtime/key_manager.hpp" +#include "zamalang/Runtime/DFRuntime.hpp" + +extern WorkFunctionRegistry *node_level_work_function_registry; + +using namespace hpx::naming; +using namespace hpx::components; +using namespace hpx::collectives; + + +struct OpaqueInputData +{ + OpaqueInputData() = default; + + OpaqueInputData(std::string wfn_name, + std::vector params, + std::vector param_sizes, + std::vector output_sizes, + bool alloc_p = false) : + wfn_name(wfn_name), + params(std::move(params)), + param_sizes(std::move(param_sizes)), + output_sizes(std::move(output_sizes)), + alloc_p(alloc_p) + {} + + OpaqueInputData(const OpaqueInputData &oid) : + wfn_name(std::move(oid.wfn_name)), + params(std::move(oid.params)), + param_sizes(std::move(oid.param_sizes)), + output_sizes(std::move(oid.output_sizes)), + alloc_p(oid.alloc_p) + {} + + friend class hpx::serialization::access; + template + void load(Archive &ar, const unsigned int version) + { + ar & wfn_name; + ar & param_sizes; + ar & output_sizes; + for (auto p : param_sizes) + { + char *param = new char[p]; + // TODO: Optimise these serialisation operations + for (size_t i = 0; i < p; ++i) + ar & param[i]; + params.push_back((void *)param); + } + alloc_p = true; + } + template + void save(Archive &ar, const unsigned int version) const + { + ar & wfn_name; + ar & param_sizes; + ar & output_sizes; + for (size_t p = 0; p < params.size(); ++p) + for (size_t i = 0; i < param_sizes[p]; ++i) + ar & static_cast(params[p])[i]; + } + HPX_SERIALIZATION_SPLIT_MEMBER() + + std::string wfn_name; + std::vector params; + std::vector param_sizes; + std::vector output_sizes; + bool alloc_p = false; +}; + +struct OpaqueOutputData +{ + OpaqueOutputData() = default; + OpaqueOutputData(std::vector outputs, + std::vector output_sizes, + bool alloc_p = false) : + outputs(std::move(outputs)), + output_sizes(std::move(output_sizes)), + alloc_p(alloc_p) + {} + OpaqueOutputData(const 
OpaqueOutputData &ood) : + outputs(std::move(ood.outputs)), + output_sizes(std::move(ood.output_sizes)), + alloc_p(ood.alloc_p) + {} + + friend class hpx::serialization::access; + template + void load(Archive &ar, const unsigned int version) + { + ar & output_sizes; + for (auto p : output_sizes) + { + char *output = new char[p]; + for (size_t i = 0; i < p; ++i) + ar & output[i]; + outputs.push_back((void *)output); + } + alloc_p = true; + } + template + void save(Archive &ar, const unsigned int version) const + { + ar & output_sizes; + for (size_t p = 0; p < outputs.size(); ++p) + { + for (size_t i = 0; i < output_sizes[p]; ++i) + ar & static_cast(outputs[p])[i]; + // TODO: investigate if HPX is automatically deallocating + //these. Here it could be safely assumed that these would no + //longer be live. + //delete ((char*)outputs[p]); + } + } + HPX_SERIALIZATION_SPLIT_MEMBER() + + std::vector outputs; + std::vector output_sizes; + bool alloc_p = false; +}; + +struct GenericComputeServer : component_base +{ + GenericComputeServer () = default; + + // Component actions exposed + OpaqueOutputData execute_task (const OpaqueInputData &inputs) + { + auto wfn = node_level_work_function_registry->getWorkFunctionPointer(inputs.wfn_name); + std::vector outputs; + + switch (inputs.output_sizes.size()) { + case 1: + { + void *output = (void *)(new char[inputs.output_sizes[0]]); + switch (inputs.params.size()) { + case 0: + wfn(output); + break; + case 1: + wfn(inputs.params[0], output); + break; + case 2: + wfn(inputs.params[0], inputs.params[1], output); + break; + case 3: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], output); + break; + default: + HPX_THROW_EXCEPTION(hpx::no_success, + "GenericComputeServer::execute_task", + "Error: number of task parameters not supported."); + } + outputs = {output}; + break; + } + case 2: + { + void *output1 = (void *)(new char[inputs.output_sizes[0]]); + void *output2 = (void *)(new char[inputs.output_sizes[1]]); + switch 
(inputs.params.size()) { + case 0: + wfn(output1, output2); + break; + case 1: + wfn(inputs.params[0], output1, output2); + break; + case 2: + wfn(inputs.params[0], inputs.params[1], output1, output2); + break; + case 3: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1, output2); + break; + default: + HPX_THROW_EXCEPTION(hpx::no_success, + "GenericComputeServer::execute_task", + "Error: number of task parameters not supported."); + } + outputs = {output1, output2}; + break; + } + case 3: + { + void *output1 = (void *)(new char[inputs.output_sizes[0]]); + void *output2 = (void *)(new char[inputs.output_sizes[1]]); + void *output3 = (void *)(new char[inputs.output_sizes[2]]); + switch (inputs.params.size()) { + case 0: + wfn(output1, output2, output3); + break; + case 1: + wfn(inputs.params[0], output1, output2, output3); + break; + case 2: + wfn(inputs.params[0], inputs.params[1], output1, output2, output3); + break; + case 3: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1, output2, output3); + break; + default: + HPX_THROW_EXCEPTION(hpx::no_success, + "GenericComputeServer::execute_task", + "Error: number of task parameters not supported."); + } + outputs = {output1, output2, output3}; + break; + } + default: + HPX_THROW_EXCEPTION(hpx::no_success, + "GenericComputeServer::execute_task", + "Error: number of task outputs not supported."); + } + + if (inputs.alloc_p) + for (auto p : inputs.params) + delete((char*)p); + + return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes), inputs.alloc_p); + } + + HPX_DEFINE_COMPONENT_ACTION(GenericComputeServer, execute_task); +}; + +HPX_REGISTER_ACTION_DECLARATION(GenericComputeServer::execute_task_action, + GenericComputeServer_execute_task_action) + +HPX_REGISTER_COMPONENT_MODULE() +HPX_REGISTER_COMPONENT(hpx::components::component, + GenericComputeServer) + +HPX_REGISTER_ACTION(GenericComputeServer::execute_task_action, + GenericComputeServer_execute_task_action) + + 
+struct GenericComputeClient : client_base +{ + typedef client_base base_type; + + GenericComputeClient() = default; + GenericComputeClient(id_type id) : base_type(std::move(id)) {} + + hpx::future + execute_task(const OpaqueInputData &inputs) + { + typedef GenericComputeServer::execute_task_action action_type; + return hpx::async(this->get_id(), inputs); + } +}; + +#endif diff --git a/compiler/include/zamalang/Runtime/key_manager.hpp b/compiler/include/zamalang/Runtime/key_manager.hpp new file mode 100644 index 000000000..4ce73cc88 --- /dev/null +++ b/compiler/include/zamalang/Runtime/key_manager.hpp @@ -0,0 +1,153 @@ +#ifndef ZAMALANG_DFR_KEY_MANAGER_HPP +#define ZAMALANG_DFR_KEY_MANAGER_HPP + +#include +#include + +#include +#include + +#include "zamalang/Runtime/DFRuntime.hpp" + +struct PbsKeyManager; +extern PbsKeyManager *node_level_key_manager; + + +struct PbsKeyWrapper +{ + std::shared_ptr key; + size_t key_id; + size_t size; + + PbsKeyWrapper() {} + + PbsKeyWrapper(void *key, size_t key_id, size_t size) : + key(std::make_shared(key)), key_id(key_id), size(size) {} + + PbsKeyWrapper(std::shared_ptr key, size_t key_id, size_t size) : + key(key), key_id(key_id), size(size) {} + + PbsKeyWrapper(PbsKeyWrapper &&moved) noexcept : + key(moved.key), key_id(moved.key_id), size(moved.size) {} + + PbsKeyWrapper(const PbsKeyWrapper &pbsk) : + key(pbsk.key), key_id(pbsk.key_id), size(pbsk.size) {} + + friend class hpx::serialization::access; + template + void save(Archive &ar, const unsigned int version) const + { + char *_key_ = static_cast(*key); + ar & key_id & size; + for (size_t i = 0; i < size; ++i) + ar & _key_[i]; + } + + template + void load(Archive &ar, const unsigned int version) + { + ar & key_id & size; + char *_key_ = (char *) malloc(size); + for (size_t i = 0; i < size; ++i) + ar & _key_[i]; + key = std::make_shared(_key_); + } + HPX_SERIALIZATION_SPLIT_MEMBER() +}; + +inline bool operator==(const PbsKeyWrapper &lhs, const PbsKeyWrapper &rhs) +{ + return 
lhs.key_id == rhs.key_id; +} + + +PbsKeyWrapper _dfr_fetch_key(size_t); +HPX_PLAIN_ACTION(_dfr_fetch_key, _dfr_fetch_key_action) + +struct PbsKeyManager +{ + // The initial keys registered on the root node and whether to push + // them is TBD. + + PbsKeyManager() + { + node_level_key_manager = this; + } + + PbsKeyWrapper get_key(const size_t key_id) + { + keystore_guard.lock(); + auto keyit = keystore.find(key_id); + keystore_guard.unlock(); + + if (keyit == keystore.end()) + { + _dfr_fetch_key_action fet; + PbsKeyWrapper &&pkw = fet(hpx::find_root_locality(), key_id); + if (pkw.size == 0) + { + // Maybe retry or try other nodes... but for now it's an error. + HPX_THROW_EXCEPTION(hpx::no_success, + "_dfr_get_key", + "Error: key not found on remote node."); + } + else + { + std::lock_guard guard(keystore_guard); + keyit = keystore.insert(std::pair(key_id, pkw)).first; + } + } + return keyit->second; + } + + // To be used only for remote requests + PbsKeyWrapper fetch_key(const size_t key_id) + { + std::lock_guard guard(keystore_guard); + + auto keyit = keystore.find(key_id); + if (keyit != keystore.end()) + return keyit->second; + // If this node does not contain this key, return an empty wrapper + return PbsKeyWrapper(nullptr, 0, 0); + } + + void register_key(void *key, size_t key_id, size_t size) + { + std::lock_guard guard(keystore_guard); + auto keyit = + keystore.insert( + std::pair(key_id, + PbsKeyWrapper(key, key_id, size))).first; + if (keyit == keystore.end()) + { + HPX_THROW_EXCEPTION(hpx::no_success, + "_dfr_register_key", + "Error: could not register new key."); + } + } + + void broadcast_keys() + { + std::lock_guard guard(keystore_guard); + if (_dfr_is_root_node()) + hpx::collectives::broadcast_to("keystore", this->keystore).get(); + else + keystore = std::move( + hpx::collectives::broadcast_from> + ("keystore").get()); + } + +private: + std::mutex keystore_guard; + std::map keystore; +}; + + +PbsKeyWrapper +_dfr_fetch_key(size_t key_id) +{ + return 
node_level_key_manager->fetch_key(key_id); +} + +#endif diff --git a/compiler/include/zamalang/Runtime/runtime_api.h b/compiler/include/zamalang/Runtime/runtime_api.h new file mode 100644 index 000000000..6c43d413d --- /dev/null +++ b/compiler/include/zamalang/Runtime/runtime_api.h @@ -0,0 +1,34 @@ +/** + Define the API exposed to the compiler for code generation. + */ + +#ifndef ZAMALANG_DFR_RUNTIME_API_H +#define ZAMALANG_DFR_RUNTIME_API_H +#include +#include + +extern "C" { + +typedef void (*wfnptr)(...); + +void *_dfr_make_ready_future(void *); +void _dfr_create_async_task(wfnptr, size_t, size_t, ...); +void *_dfr_await_future(void *); + +/* Keys can have node-local copies which can be retrieved. This + should only be called on the node where the key is required. */ +void _dfr_register_key(void *, size_t, size_t); +void _dfr_broadcast_keys(); +void *_dfr_get_key(size_t); + +/* Memory management: + _dfr_make_ready_future allocates the future, not the underlying storage. + _dfr_create_async_task allocates both future and storage for outputs. */ +void _dfr_deallocate_future(void *); +void _dfr_deallocate_future_data(void *); + +/* Initialisation & termination. 
*/ +void _dfr_start(); +void _dfr_stop(); +} +#endif diff --git a/compiler/lib/Runtime/CMakeLists.txt b/compiler/lib/Runtime/CMakeLists.txt index ca71ffbee..9022735f1 100644 --- a/compiler/lib/Runtime/CMakeLists.txt +++ b/compiler/lib/Runtime/CMakeLists.txt @@ -6,4 +6,12 @@ add_library(ZamalangRuntime SHARED target_link_libraries(ZamalangRuntime Concrete pthread m dl) install(TARGETS ZamalangRuntime EXPORT ZamalangRuntime) -install(EXPORT ZamalangRuntime DESTINATION "./") \ No newline at end of file +install(EXPORT ZamalangRuntime DESTINATION "./") + +if(ZAMALANG_PARALLEL_EXECUTION_ENABLED) + add_library(DFRuntime SHARED DFRuntime.cpp) + target_link_libraries(DFRuntime PUBLIC pthread m dl HPX::hpx HPX::iostreams_component -rdynamic) + + install(TARGETS DFRuntime EXPORT DFRuntime) + install(EXPORT DFRuntime DESTINATION "./") +endif() diff --git a/compiler/lib/Runtime/DFRuntime.cpp b/compiler/lib/Runtime/DFRuntime.cpp new file mode 100644 index 000000000..658da0440 --- /dev/null +++ b/compiler/lib/Runtime/DFRuntime.cpp @@ -0,0 +1,253 @@ +/** + This file implements the dataflow runtime. It encapsulates all of + the underlying communication, parallelism, etc. and only exposes a + simplified interface for code generation in runtime_api.h + + This hides the details of implementation, including of the HPX + framework currently used, from the code generation side. 
+ */ + +#include +#include +#include + +#include "zamalang/Runtime/DFRuntime.hpp" +#include "zamalang/Runtime/distributed_generic_task_server.hpp" +#include "zamalang/Runtime/runtime_api.h" + +std::vector gcc; +void *dl_handle; +PbsKeyManager *node_level_key_manager; +WorkFunctionRegistry *node_level_work_function_registry; + +using namespace hpx; + +void *_dfr_make_ready_future(void *in) { + return static_cast( + new hpx::shared_future(hpx::make_ready_future(in))); +} + +void *_dfr_await_future(void *in) { + return static_cast *>(in)->get(); +} + +void _dfr_deallocate_future_data(void *in) { + delete[] (char *)(static_cast *>(in)->get()); /* task outputs are allocated with new char[] */ +} + +void _dfr_deallocate_future(void *in) { + delete (static_cast *>(in)); +} + +// Runtime generic async_task. Each first NUM_PARAMS pairs of +// arguments in the variadic list corresponds to a void* pointer on a +// hpx::future and the size of data within the future. After +// that come NUM_OUTPUTS pairs of hpx::future* and size_t for +// the returns. +void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, + ...) { + std::vector params; + std::vector outputs; + std::vector param_sizes; + std::vector output_sizes; + + va_list args; + va_start(args, num_outputs); + for (size_t i = 0; i < num_params; ++i) { + params.push_back(va_arg(args, void *)); + param_sizes.push_back(va_arg(args, size_t)); + } + for (size_t i = 0; i < num_outputs; ++i) { + outputs.push_back(va_arg(args, void *)); + output_sizes.push_back(va_arg(args, size_t)); + } + va_end(args); + + // We pass functions by name - which is not strictly necessary in + // shared memory as pointers suffice, but is needed in the + // distributed case where the functions need to be located/loaded on + // the node. 
+ auto wfnname = + node_level_work_function_registry->getWorkFunctionName((void *)wfn); + hpx::future> oodf; + + // In order to allow complete dataflow semantics for + // communication/synchronization, we split tasks in two parts: an + // execution body that is scheduled once all input dependences are + // satisfied, which generates a future on a tuple of outputs, which + // is then further split into a tuple of futures and provide + // individual synchronization for each return independently. + switch (num_params) { + case 0: + oodf = std::move( + hpx::dataflow([wfnname, param_sizes, + output_sizes]() -> hpx::future { + std::vector params = {}; + OpaqueInputData oid(wfnname, params, param_sizes, output_sizes); + return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + })); + break; + + case 1: + oodf = std::move(hpx::dataflow( + [wfnname, param_sizes, output_sizes](hpx::shared_future param0) + -> hpx::future { + std::vector params = {param0.get()}; + OpaqueInputData oid(wfnname, params, param_sizes, output_sizes); + return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + }, + *(hpx::shared_future *)params[0])); + break; + + case 2: + oodf = std::move(hpx::dataflow( + [wfnname, param_sizes, output_sizes](hpx::shared_future param0, + hpx::shared_future param1) + -> hpx::future { + std::vector params = {param0.get(), param1.get()}; + OpaqueInputData oid(wfnname, params, param_sizes, output_sizes); + return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + }, + *(hpx::shared_future *)params[0], + *(hpx::shared_future *)params[1])); + break; + + case 3: + oodf = std::move(hpx::dataflow( + [wfnname, param_sizes, output_sizes](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2) + -> hpx::future { + std::vector params = {param0.get(), param1.get(), + param2.get()}; + OpaqueInputData oid(wfnname, params, param_sizes, output_sizes); + return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + }, 
+ *(hpx::shared_future *)params[0], + *(hpx::shared_future *)params[1], + *(hpx::shared_future *)params[2])); + break; + + default: + HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_create_async_task", + "Error: number of task parameters not supported."); + } + + switch (num_outputs) { + case 1: + *((void **)outputs[0]) = new hpx::shared_future(hpx::dataflow( + [](hpx::future oodf_in) -> void * { + return oodf_in.get().outputs[0]; + }, + oodf)); + break; + + case 2: { + hpx::future> &&ft = hpx::dataflow( + [](hpx::future oodf_in) + -> hpx::tuple { + std::vector outputs = std::move(oodf_in.get().outputs); + return hpx::make_tuple<>(outputs[0], outputs[1]); + }, + oodf); + hpx::tuple, hpx::future> &&tf = + hpx::split_future(std::move(ft)); + *((void **)outputs[0]) = + (void *)new hpx::shared_future(std::move(hpx::get<0>(tf))); + *((void **)outputs[1]) = + (void *)new hpx::shared_future(std::move(hpx::get<1>(tf))); + break; + } + + case 3: { + hpx::future> &&ft = hpx::dataflow( + [](hpx::future oodf_in) + -> hpx::tuple { + std::vector outputs = std::move(oodf_in.get().outputs); + return hpx::make_tuple<>(outputs[0], outputs[1], outputs[2]); + }, + oodf); + hpx::tuple, hpx::future, hpx::future> + &&tf = hpx::split_future(std::move(ft)); + *((void **)outputs[0]) = + (void *)new hpx::shared_future(std::move(hpx::get<0>(tf))); + *((void **)outputs[1]) = + (void *)new hpx::shared_future(std::move(hpx::get<1>(tf))); + *((void **)outputs[2]) = + (void *)new hpx::shared_future(std::move(hpx::get<2>(tf))); + break; + } + default: + HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_create_async_task", + "Error: number of task outputs not supported."); + } +} + +/* Distributed key management. 
*/ +void _dfr_register_key(void *key, size_t key_id, size_t size) { + node_level_key_manager->register_key(key, key_id, size); +} + +void _dfr_broadcast_keys() { node_level_key_manager->broadcast_keys(); } + +void *_dfr_get_key(size_t key_id) { + return *node_level_key_manager->get_key(key_id).key.get(); +} + +/* Runtime initialization and finalization. */ +static inline void _dfr_stop_impl() { + hpx::apply([]() { hpx::finalize(); }); + hpx::stop(); +} + +static inline void _dfr_start_impl(int argc, char *argv[]) { + dl_handle = dlopen(nullptr, RTLD_NOW); + if (argc == 0) { + char *_argv[1] = {const_cast("__dummy_dfr_HPX_program_name__")}; + int _argc = 1; + hpx::start(nullptr, _argc, _argv); + } else { + hpx::start(nullptr, argc, argv); + } + + new PbsKeyManager(); + new WorkFunctionRegistry(); + + if (!_dfr_is_root_node()) { + _dfr_stop_impl(); + exit(EXIT_SUCCESS); + } + + // Create compute server components on each node and the + // corresponding compute client. + auto num_nodes = hpx::get_num_localities().get(); + gcc = hpx::new_( + hpx::default_layout(hpx::find_all_localities()), num_nodes) + .get(); +} + +// TODO: we need a better way to wrap main. For now loader --wrap and +// main's constructor/destructor are not functional, but that should +// replace the current, inefficient calls to _dfr_start/stop generated +// in each compiled function. +void _dfr_start() { _dfr_start_impl(0, nullptr); } +void _dfr_stop() { _dfr_stop_impl(); } + +/* Debug interface. 
*/ +size_t _dfr_debug_get_node_id() { return hpx::get_locality_id(); } + +size_t _dfr_debug_get_worker_id() { return hpx::get_worker_thread_num(); } + +void _dfr_debug_print_task(const char *name, int inputs, int outputs) { + // clang-format off + hpx::cout << "Task \"" << name << "\"" + << " [" << inputs << " inputs, " << outputs << " outputs]" + << " Executing on Node/Worker: " << _dfr_debug_get_node_id() + << " / " << _dfr_debug_get_worker_id() << "\n" << std::flush; + // clang-format on +} + +// Generic utility function for printing debug info +void _dfr_print_debug(size_t val) { + hpx::cout << "_dfr_print_debug : " << val << "\n" << std::flush; +} diff --git a/compiler/src/CMakeLists.txt b/compiler/src/CMakeLists.txt index 1a29b8b3c..617919b1d 100644 --- a/compiler/src/CMakeLists.txt +++ b/compiler/src/CMakeLists.txt @@ -5,21 +5,47 @@ llvm_update_compile_flags(zamacompiler) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -target_link_libraries(zamacompiler - PRIVATE - ${dialect_libs} - ${conversion_libs} - MLIRTransforms - LowLFHEDialect - MidLFHEDialect - HLFHEDialect +if(ZAMALANG_PARALLEL_EXECUTION_ENABLED) + target_link_libraries(zamacompiler + PRIVATE + ${dialect_libs} + ${conversion_libs} - MLIRIR - MLIRLLVMIR - MLIRLLVMToLLVMIRTranslation + MLIRTransforms + LowLFHEDialect + MidLFHEDialect + HLFHEDialect + + MLIRIR + MLIRLLVMIR + MLIRLLVMToLLVMIRTranslation + + ZamalangSupport + + -Wl,-rpath,${CMAKE_BINARY_DIR}/lib/Runtime + -Wl,-rpath,${HPX_DIR}/../../ + -Wl,--no-as-needed + DFRuntime + ) +else() + target_link_libraries(zamacompiler + PRIVATE + ${dialect_libs} + ${conversion_libs} + + MLIRTransforms + LowLFHEDialect + MidLFHEDialect + HLFHEDialect + + MLIRIR + MLIRLLVMIR + MLIRLLVMToLLVMIRTranslation + + ZamalangSupport + ) +endif() - ZamalangSupport -) mlir_check_all_link_libraries(zamacompiler) diff --git a/compiler/tests/CMakeLists.txt b/compiler/tests/CMakeLists.txt index 
7266ea592..be6027493 100644 --- a/compiler/tests/CMakeLists.txt +++ b/compiler/tests/CMakeLists.txt @@ -1,3 +1,3 @@ if (ZAMALANG_UNIT_TESTS) add_subdirectory(unittest) -endif() \ No newline at end of file +endif() diff --git a/compiler/tests/unittest/CMakeLists.txt b/compiler/tests/unittest/CMakeLists.txt index b6ca3f966..438d9954c 100644 --- a/compiler/tests/unittest/CMakeLists.txt +++ b/compiler/tests/unittest/CMakeLists.txt @@ -79,3 +79,23 @@ gtest_discover_tests(end_to_end_jit_encrypted_tensor) gtest_discover_tests(end_to_end_jit_hlfhelinalg) gtest_discover_tests(end_to_end_jit_lambda) +if(ZAMALANG_PARALLEL_EXECUTION_ENABLED) + add_executable( + end_to_end_jit_dfr + end_to_end_jit_dfr.cc + ) + set_source_files_properties( + end_to_end_jit_dfr.cc + PROPERTIES COMPILE_FLAGS "-fno-rtti" + ) + target_link_libraries( + end_to_end_jit_dfr + gtest_main + ZamalangSupport + -Wl,-rpath,${CMAKE_BINARY_DIR}/lib/Runtime + -Wl,-rpath,${HPX_DIR}/../../ + -Wl,--no-as-needed + DFRuntime + ) + gtest_discover_tests(end_to_end_jit_dfr) +endif() diff --git a/compiler/tests/unittest/end_to_end_jit_dfr.cc b/compiler/tests/unittest/end_to_end_jit_dfr.cc new file mode 100644 index 000000000..6cc3af452 --- /dev/null +++ b/compiler/tests/unittest/end_to_end_jit_dfr.cc @@ -0,0 +1,348 @@ + +#include +#include +#include + +#include "end_to_end_jit_test.h" + +const mlir::zamalang::V0FHEConstraint defaultV0Constraints{10, 7}; + +TEST(CompileAndRunDFR, start_stop) { + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func private @_dfr_stop() +func private @_dfr_start() +func @main() -> i64{ + call @_dfr_start() : () -> () + %1 = arith.constant 7 : i64 + call @_dfr_stop() : () -> () + return %1 : i64 +} +)XXX", "main", true); + ASSERT_EXPECTED_VALUE(lambda(), 7); +} + +TEST(CompileAndRunDFR, 0in1out_task) { + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + llvm.func @_dfr_await_future(!llvm.ptr) -> !llvm.ptr> attributes {sym_visibility = "private"} + 
llvm.func @_dfr_create_async_task(...) attributes {sym_visibility = "private"} + llvm.func @_dfr_stop() + llvm.func @_dfr_start() + func @main() -> i64 { + %0 = llvm.mlir.addressof @_dfr_DFT_work_function__main0 : !llvm.ptr)>> + %1 = llvm.mlir.constant(0 : i64) : i64 + %2 = llvm.mlir.constant(1 : i64) : i64 + %3 = llvm.mlir.constant(8 : i64) : i64 + llvm.call @_dfr_start() : () -> () + %4 = llvm.mlir.constant(1 : i64) : i64 + %5 = llvm.alloca %4 x !llvm.ptr : (i64) -> !llvm.ptr> + llvm.call @_dfr_create_async_task(%0, %1, %2, %5, %3) : (!llvm.ptr)>>, i64, i64, !llvm.ptr>, i64) -> () + %6 = llvm.load %5 : !llvm.ptr> + %7 = llvm.call @_dfr_await_future(%6) : (!llvm.ptr) -> !llvm.ptr> + %8 = llvm.bitcast %7 : !llvm.ptr> to !llvm.ptr + %9 = llvm.load %8 : !llvm.ptr + llvm.call @_dfr_stop() : () -> () + return %9 : i64 + } + llvm.func @_dfr_DFT_work_function__main0(%arg0: !llvm.ptr) { + %0 = llvm.mlir.constant(4 : i64) : i64 + %1 = llvm.mlir.constant(3 : i64) : i64 + llvm.br ^bb1 + ^bb1: // pred: ^bb0 + %2 = llvm.add %0, %1 : i64 + llvm.store %2, %arg0 : !llvm.ptr + llvm.return + } +)XXX", "main", true); + ASSERT_EXPECTED_VALUE(lambda(), 7); +} + +TEST(CompileAndRunDFR, 1in1out_task) { + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + llvm.func @_dfr_await_future(!llvm.ptr) -> !llvm.ptr> attributes {sym_visibility = "private"} + llvm.func @_dfr_create_async_task(...) attributes {sym_visibility = "private"} + llvm.func @malloc(i64) -> !llvm.ptr + llvm.func @_dfr_make_ready_future(...) 
-> !llvm.ptr attributes {sym_visibility = "private"} + llvm.func @_dfr_stop() + llvm.func @_dfr_start() + func @main(%arg0: i64) -> i64 { + %0 = llvm.mlir.addressof @_dfr_DFT_work_function__main0 : !llvm.ptr, ptr)>> + %1 = llvm.mlir.constant(1 : i64) : i64 + %2 = llvm.mlir.constant(8 : i64) : i64 + llvm.call @_dfr_start() : () -> () + %3 = llvm.mlir.null : !llvm.ptr + %4 = llvm.mlir.constant(1 : index) : i64 + %5 = llvm.getelementptr %3[%4] : (!llvm.ptr, i64) -> !llvm.ptr + %6 = llvm.ptrtoint %5 : !llvm.ptr to i64 + %7 = llvm.call @malloc(%6) : (i64) -> !llvm.ptr + %8 = llvm.bitcast %7 : !llvm.ptr to !llvm.ptr + llvm.store %arg0, %8 : !llvm.ptr + %9 = llvm.call @_dfr_make_ready_future(%8) : (!llvm.ptr) -> !llvm.ptr + %10 = llvm.mlir.constant(1 : i64) : i64 + %11 = llvm.alloca %10 x !llvm.ptr : (i64) -> !llvm.ptr> + llvm.call @_dfr_create_async_task(%0, %1, %1, %9, %2, %11, %2) : (!llvm.ptr, ptr)>>, i64, i64, !llvm.ptr, i64, !llvm.ptr>, i64) -> () + %12 = llvm.load %11 : !llvm.ptr> + %13 = llvm.call @_dfr_await_future(%12) : (!llvm.ptr) -> !llvm.ptr> + %14 = llvm.bitcast %13 : !llvm.ptr> to !llvm.ptr + %15 = llvm.load %14 : !llvm.ptr + llvm.call @_dfr_stop() : () -> () + return %15 : i64 + } + llvm.func @_dfr_DFT_work_function__main0(%arg0: !llvm.ptr, %arg1: !llvm.ptr) { + %0 = llvm.mlir.constant(2 : i64) : i64 + %1 = llvm.load %arg0 : !llvm.ptr + llvm.br ^bb1 + ^bb1: // pred: ^bb0 + %2 = llvm.add %1, %0 : i64 + llvm.store %2, %arg1 : !llvm.ptr + llvm.return + } +)XXX", "main", true); + + ASSERT_EXPECTED_VALUE(lambda(5_u64), 7); +} + +TEST(CompileAndRunDFR, 2in1out_task) { + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + llvm.func @_dfr_await_future(!llvm.ptr) -> !llvm.ptr> attributes {sym_visibility = "private"} + llvm.func @_dfr_create_async_task(...) attributes {sym_visibility = "private"} + llvm.func @malloc(i64) -> !llvm.ptr + llvm.func @_dfr_make_ready_future(...) 
-> !llvm.ptr attributes {sym_visibility = "private"} + llvm.func @_dfr_stop() + llvm.func @_dfr_start() + func @main(%arg0: i64, %arg1: i64) -> i64 { + %0 = llvm.mlir.addressof @_dfr_DFT_work_function__main0 : !llvm.ptr, ptr, ptr)>> + %1 = llvm.mlir.constant(2 : i64) : i64 + %2 = llvm.mlir.constant(1 : i64) : i64 + %3 = llvm.mlir.constant(8 : i64) : i64 + llvm.call @_dfr_start() : () -> () + %4 = llvm.mlir.null : !llvm.ptr + %5 = llvm.mlir.constant(1 : index) : i64 + %6 = llvm.getelementptr %4[%5] : (!llvm.ptr, i64) -> !llvm.ptr + %7 = llvm.ptrtoint %6 : !llvm.ptr to i64 + %8 = llvm.call @malloc(%7) : (i64) -> !llvm.ptr + %9 = llvm.bitcast %8 : !llvm.ptr to !llvm.ptr + llvm.store %arg0, %9 : !llvm.ptr + %10 = llvm.call @_dfr_make_ready_future(%9) : (!llvm.ptr) -> !llvm.ptr + %11 = llvm.mlir.null : !llvm.ptr + %12 = llvm.mlir.constant(1 : index) : i64 + %13 = llvm.getelementptr %11[%12] : (!llvm.ptr, i64) -> !llvm.ptr + %14 = llvm.ptrtoint %13 : !llvm.ptr to i64 + %15 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr + %16 = llvm.bitcast %15 : !llvm.ptr to !llvm.ptr + llvm.store %arg1, %16 : !llvm.ptr + %17 = llvm.call @_dfr_make_ready_future(%16) : (!llvm.ptr) -> !llvm.ptr + %18 = llvm.mlir.constant(1 : i64) : i64 + %19 = llvm.alloca %18 x !llvm.ptr : (i64) -> !llvm.ptr> + llvm.call @_dfr_create_async_task(%0, %1, %2, %10, %3, %17, %3, %19, %3) : (!llvm.ptr, ptr, ptr)>>, i64, i64, !llvm.ptr, i64, !llvm.ptr, i64, !llvm.ptr>, i64) -> () + %20 = llvm.load %19 : !llvm.ptr> + %21 = llvm.call @_dfr_await_future(%20) : (!llvm.ptr) -> !llvm.ptr> + %22 = llvm.bitcast %21 : !llvm.ptr> to !llvm.ptr + %23 = llvm.load %22 : !llvm.ptr + llvm.call @_dfr_stop() : () -> () + return %23 : i64 + } + llvm.func @_dfr_DFT_work_function__main0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { + %0 = llvm.load %arg0 : !llvm.ptr + %1 = llvm.load %arg1 : !llvm.ptr + llvm.br ^bb1 + ^bb1: // pred: ^bb0 + %2 = llvm.add %0, %1 : i64 + llvm.store %2, %arg2 : !llvm.ptr + llvm.return + } +)XXX", 
"main", true); + + ASSERT_EXPECTED_VALUE(lambda(1_u64, 6_u64), 7); +} + + + +TEST(CompileAndRunDFR, taskgraph) { + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + llvm.func @_dfr_await_future(!llvm.ptr) -> !llvm.ptr> attributes {sym_visibility = "private"} + llvm.func @_dfr_create_async_task(...) attributes {sym_visibility = "private"} + llvm.func @malloc(i64) -> !llvm.ptr + llvm.func @_dfr_make_ready_future(...) -> !llvm.ptr attributes {sym_visibility = "private"} + llvm.func @_dfr_stop() + llvm.func @_dfr_start() + func @main(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 { + %0 = llvm.mlir.constant(7 : i64) : i64 + %1 = llvm.mlir.addressof @_dfr_DFT_work_function__main0 : !llvm.ptr, ptr, ptr)>> + %2 = llvm.mlir.constant(2 : i64) : i64 + %3 = llvm.mlir.constant(1 : i64) : i64 + %4 = llvm.mlir.constant(8 : i64) : i64 + %5 = llvm.mlir.addressof @_dfr_DFT_work_function__main1 : !llvm.ptr, ptr, ptr)>> + %6 = llvm.mlir.addressof @_dfr_DFT_work_function__main2 : !llvm.ptr, ptr, ptr)>> + %7 = llvm.mlir.addressof @_dfr_DFT_work_function__main3 : !llvm.ptr, ptr, ptr)>> + %8 = llvm.mlir.addressof @_dfr_DFT_work_function__main4 : !llvm.ptr, ptr, ptr)>> + %9 = llvm.mlir.addressof @_dfr_DFT_work_function__main5 : !llvm.ptr, ptr, ptr)>> + %10 = llvm.mlir.addressof @_dfr_DFT_work_function__main6 : !llvm.ptr, ptr, ptr)>> + %11 = llvm.mlir.addressof @_dfr_DFT_work_function__main7 : !llvm.ptr, ptr, ptr)>> + llvm.call @_dfr_start() : () -> () + %12 = llvm.mlir.null : !llvm.ptr + %13 = llvm.mlir.constant(1 : index) : i64 + %14 = llvm.getelementptr %12[%13] : (!llvm.ptr, i64) -> !llvm.ptr + %15 = llvm.ptrtoint %14 : !llvm.ptr to i64 + %16 = llvm.call @malloc(%15) : (i64) -> !llvm.ptr + %17 = llvm.bitcast %16 : !llvm.ptr to !llvm.ptr + llvm.store %arg0, %17 : !llvm.ptr + %18 = llvm.call @_dfr_make_ready_future(%17) : (!llvm.ptr) -> !llvm.ptr + %19 = llvm.mlir.null : !llvm.ptr + %20 = llvm.mlir.constant(1 : index) : i64 + %21 = llvm.getelementptr %19[%20] : 
(!llvm.ptr, i64) -> !llvm.ptr + %22 = llvm.ptrtoint %21 : !llvm.ptr to i64 + %23 = llvm.call @malloc(%22) : (i64) -> !llvm.ptr + %24 = llvm.bitcast %23 : !llvm.ptr to !llvm.ptr + llvm.store %arg1, %24 : !llvm.ptr + %25 = llvm.call @_dfr_make_ready_future(%24) : (!llvm.ptr) -> !llvm.ptr + %26 = llvm.mlir.constant(1 : i64) : i64 + %27 = llvm.alloca %26 x !llvm.ptr : (i64) -> !llvm.ptr> + llvm.call @_dfr_create_async_task(%1, %2, %3, %18, %4, %25, %4, %27, %4) : (!llvm.ptr, ptr, ptr)>>, i64, i64, !llvm.ptr, i64, !llvm.ptr, i64, !llvm.ptr>, i64) -> () + %28 = llvm.load %27 : !llvm.ptr> + %29 = llvm.mlir.null : !llvm.ptr + %30 = llvm.mlir.constant(1 : index) : i64 + %31 = llvm.getelementptr %29[%30] : (!llvm.ptr, i64) -> !llvm.ptr + %32 = llvm.ptrtoint %31 : !llvm.ptr to i64 + %33 = llvm.call @malloc(%32) : (i64) -> !llvm.ptr + %34 = llvm.bitcast %33 : !llvm.ptr to !llvm.ptr + llvm.store %arg2, %34 : !llvm.ptr + %35 = llvm.call @_dfr_make_ready_future(%34) : (!llvm.ptr) -> !llvm.ptr + %36 = llvm.mlir.constant(1 : i64) : i64 + %37 = llvm.alloca %36 x !llvm.ptr : (i64) -> !llvm.ptr> + llvm.call @_dfr_create_async_task(%5, %2, %3, %18, %4, %35, %4, %37, %4) : (!llvm.ptr, ptr, ptr)>>, i64, i64, !llvm.ptr, i64, !llvm.ptr, i64, !llvm.ptr>, i64) -> () + %38 = llvm.load %37 : !llvm.ptr> + %39 = llvm.mlir.constant(1 : i64) : i64 + %40 = llvm.alloca %39 x !llvm.ptr : (i64) -> !llvm.ptr> + llvm.call @_dfr_create_async_task(%6, %2, %3, %25, %4, %35, %4, %40, %4) : (!llvm.ptr, ptr, ptr)>>, i64, i64, !llvm.ptr, i64, !llvm.ptr, i64, !llvm.ptr>, i64) -> () + %41 = llvm.load %40 : !llvm.ptr> + %42 = llvm.mul %arg0, %0 : i64 + %43 = llvm.mul %arg1, %0 : i64 + %44 = llvm.mul %arg2, %0 : i64 + %45 = llvm.mlir.null : !llvm.ptr + %46 = llvm.mlir.constant(1 : index) : i64 + %47 = llvm.getelementptr %45[%46] : (!llvm.ptr, i64) -> !llvm.ptr + %48 = llvm.ptrtoint %47 : !llvm.ptr to i64 + %49 = llvm.call @malloc(%48) : (i64) -> !llvm.ptr + %50 = llvm.bitcast %49 : !llvm.ptr to !llvm.ptr + 
llvm.store %42, %50 : !llvm.ptr + %51 = llvm.call @_dfr_make_ready_future(%50) : (!llvm.ptr) -> !llvm.ptr + %52 = llvm.mlir.constant(1 : i64) : i64 + %53 = llvm.alloca %52 x !llvm.ptr : (i64) -> !llvm.ptr> + llvm.call @_dfr_create_async_task(%7, %2, %3, %28, %4, %51, %4, %53, %4) : (!llvm.ptr, ptr, ptr)>>, i64, i64, !llvm.ptr, i64, !llvm.ptr, i64, !llvm.ptr>, i64) -> () + %54 = llvm.load %53 : !llvm.ptr> + %55 = llvm.mlir.null : !llvm.ptr + %56 = llvm.mlir.constant(1 : index) : i64 + %57 = llvm.getelementptr %55[%56] : (!llvm.ptr, i64) -> !llvm.ptr + %58 = llvm.ptrtoint %57 : !llvm.ptr to i64 + %59 = llvm.call @malloc(%58) : (i64) -> !llvm.ptr + %60 = llvm.bitcast %59 : !llvm.ptr to !llvm.ptr + llvm.store %43, %60 : !llvm.ptr + %61 = llvm.call @_dfr_make_ready_future(%60) : (!llvm.ptr) -> !llvm.ptr + %62 = llvm.mlir.constant(1 : i64) : i64 + %63 = llvm.alloca %62 x !llvm.ptr : (i64) -> !llvm.ptr> + llvm.call @_dfr_create_async_task(%8, %2, %3, %38, %4, %61, %4, %63, %4) : (!llvm.ptr, ptr, ptr)>>, i64, i64, !llvm.ptr, i64, !llvm.ptr, i64, !llvm.ptr>, i64) -> () + %64 = llvm.load %63 : !llvm.ptr> + %65 = llvm.mlir.null : !llvm.ptr + %66 = llvm.mlir.constant(1 : index) : i64 + %67 = llvm.getelementptr %65[%66] : (!llvm.ptr, i64) -> !llvm.ptr + %68 = llvm.ptrtoint %67 : !llvm.ptr to i64 + %69 = llvm.call @malloc(%68) : (i64) -> !llvm.ptr + %70 = llvm.bitcast %69 : !llvm.ptr to !llvm.ptr + llvm.store %44, %70 : !llvm.ptr + %71 = llvm.call @_dfr_make_ready_future(%70) : (!llvm.ptr) -> !llvm.ptr + %72 = llvm.mlir.constant(1 : i64) : i64 + %73 = llvm.alloca %72 x !llvm.ptr : (i64) -> !llvm.ptr> + llvm.call @_dfr_create_async_task(%9, %2, %3, %41, %4, %71, %4, %73, %4) : (!llvm.ptr, ptr, ptr)>>, i64, i64, !llvm.ptr, i64, !llvm.ptr, i64, !llvm.ptr>, i64) -> () + %74 = llvm.load %73 : !llvm.ptr> + %75 = llvm.mlir.constant(1 : i64) : i64 + %76 = llvm.alloca %75 x !llvm.ptr : (i64) -> !llvm.ptr> + llvm.call @_dfr_create_async_task(%10, %2, %3, %54, %4, %64, %4, %76, %4) : 
(!llvm.ptr, ptr, ptr)>>, i64, i64, !llvm.ptr, i64, !llvm.ptr, i64, !llvm.ptr>, i64) -> () + %77 = llvm.load %76 : !llvm.ptr> + %78 = llvm.mlir.constant(1 : i64) : i64 + %79 = llvm.alloca %78 x !llvm.ptr : (i64) -> !llvm.ptr> + llvm.call @_dfr_create_async_task(%11, %2, %3, %77, %4, %74, %4, %79, %4) : (!llvm.ptr, ptr, ptr)>>, i64, i64, !llvm.ptr, i64, !llvm.ptr, i64, !llvm.ptr>, i64) -> () + %80 = llvm.load %79 : !llvm.ptr> + %81 = llvm.call @_dfr_await_future(%80) : (!llvm.ptr) -> !llvm.ptr> + %82 = llvm.bitcast %81 : !llvm.ptr> to !llvm.ptr + %83 = llvm.load %82 : !llvm.ptr + llvm.call @_dfr_stop() : () -> () + return %83 : i64 + } + llvm.func @_dfr_DFT_work_function__main0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) attributes {_dfr_work_function_attribute} { + %0 = llvm.load %arg0 : !llvm.ptr + %1 = llvm.load %arg1 : !llvm.ptr + llvm.br ^bb1 + ^bb1: // pred: ^bb0 + %2 = llvm.add %0, %1 : i64 + llvm.store %2, %arg2 : !llvm.ptr + llvm.return + } + llvm.func @_dfr_DFT_work_function__main1(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) attributes {_dfr_work_function_attribute} { + %0 = llvm.load %arg0 : !llvm.ptr + %1 = llvm.load %arg1 : !llvm.ptr + llvm.br ^bb1 + ^bb1: // pred: ^bb0 + %2 = llvm.add %0, %1 : i64 + llvm.store %2, %arg2 : !llvm.ptr + llvm.return + } + llvm.func @_dfr_DFT_work_function__main2(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) attributes {_dfr_work_function_attribute} { + %0 = llvm.load %arg0 : !llvm.ptr + %1 = llvm.load %arg1 : !llvm.ptr + llvm.br ^bb1 + ^bb1: // pred: ^bb0 + %2 = llvm.add %0, %1 : i64 + llvm.store %2, %arg2 : !llvm.ptr + llvm.return + } + llvm.func @_dfr_DFT_work_function__main3(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) attributes {_dfr_work_function_attribute} { + %0 = llvm.load %arg0 : !llvm.ptr + %1 = llvm.load %arg1 : !llvm.ptr + llvm.br ^bb1 + ^bb1: // pred: ^bb0 + %2 = llvm.add %0, %1 : i64 + llvm.store %2, %arg2 : !llvm.ptr + llvm.return + } + llvm.func 
@_dfr_DFT_work_function__main4(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) attributes {_dfr_work_function_attribute} { + %0 = llvm.load %arg0 : !llvm.ptr + %1 = llvm.load %arg1 : !llvm.ptr + llvm.br ^bb1 + ^bb1: // pred: ^bb0 + %2 = llvm.add %0, %1 : i64 + llvm.store %2, %arg2 : !llvm.ptr + llvm.return + } + llvm.func @_dfr_DFT_work_function__main5(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) attributes {_dfr_work_function_attribute} { + %0 = llvm.load %arg0 : !llvm.ptr + %1 = llvm.load %arg1 : !llvm.ptr + llvm.br ^bb1 + ^bb1: // pred: ^bb0 + %2 = llvm.add %0, %1 : i64 + llvm.store %2, %arg2 : !llvm.ptr + llvm.return + } + llvm.func @_dfr_DFT_work_function__main6(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) attributes {_dfr_work_function_attribute} { + %0 = llvm.load %arg0 : !llvm.ptr + %1 = llvm.load %arg1 : !llvm.ptr + llvm.br ^bb1 + ^bb1: // pred: ^bb0 + %2 = llvm.add %0, %1 : i64 + llvm.store %2, %arg2 : !llvm.ptr + llvm.return + } + llvm.func @_dfr_DFT_work_function__main7(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) attributes {_dfr_work_function_attribute} { + %0 = llvm.load %arg0 : !llvm.ptr + %1 = llvm.load %arg1 : !llvm.ptr + llvm.br ^bb1 + ^bb1: // pred: ^bb0 + %2 = llvm.add %0, %1 : i64 + llvm.store %2, %arg2 : !llvm.ptr + llvm.return + } +)XXX", "main", true); + + ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64, 3_u64), 54); + ASSERT_EXPECTED_VALUE(lambda(2_u64, 5_u64, 1_u64), 72); + ASSERT_EXPECTED_VALUE(lambda(3_u64, 1_u64, 7_u64), 99); +}