feat(dfr): add the DFR (DataFlow Runtime).

This commit is contained in:
Antoniu Pop
2021-10-05 14:45:56 +01:00
committed by Antoniu Pop
parent 5773310215
commit bb44124999
18 changed files with 1444 additions and 16 deletions

View File

@@ -52,3 +52,31 @@ jobs:
echo "Debug: ccache statistics (after the build):"
ccache -s
chmod -R ugo+rwx /tmp/KeySetCache
BuildAndTestDF:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
with:
submodules: recursive
- name: Build and test compiler (dataflow)
uses: addnab/docker-run-action@v3
with:
registry: ghcr.io
image: ghcr.io/zama-ai/zamalang-df-compiler:latest
username: ${{ secrets.GHCR_LOGIN }}
password: ${{ secrets.GHCR_PASSWORD }}
options: -v ${{ github.workspace }}/compiler:/compiler
shell: bash
run: |
set -e
echo "Debug: ccache statistics (prior to the build):"
ccache -s
cd /compiler
pip install pytest
rm -rf /build
export PYTHONPATH=""
make PARALLEL_EXECUTION_ENABLED=ON CCACHE=ON BUILD_DIR=/build test test-dataflow
echo "Debug: ccache statistics (after the build):"
ccache -s
chmod -R ugo+rwx /tmp/KeySetCache

29
.github/workflows/docker-hpx.yml vendored Normal file
View File

@@ -0,0 +1,29 @@
# Builds and publishes the HPX builder image whenever its Dockerfile changes.
name: Docker image (HPX build)
on:
  push:
    paths:
      - builders/Dockerfile.hpx-env
  # Allows you to run this workflow manually from the Actions tab
  workflow_dispatch:
jobs:
  build_publish:
    name: Build & Publish the Docker image
    runs-on: ubuntu-latest
    env:
      IMAGE: ghcr.io/zama-ai/hpx
    steps:
      - uses: actions/checkout@v2
      - name: build
        run: docker build -t $IMAGE -f builders/Dockerfile.hpx-env .
      - name: login
        run: echo "${{ secrets.GHCR_PASSWORD }}" | docker login -u ${{ secrets.GHCR_LOGIN }} --password-stdin ghcr.io
      - name: tag and publish
        run: |
          docker push $IMAGE:latest

View File

@@ -0,0 +1,35 @@
# Builds and publishes the dataflow compiler image on pushes to master,
# on releases, or on manual dispatch. The image is tagged both :latest
# and with the triggering commit SHA.
name: Docker image (dataflow)
on:
  push:
    branches:
      - master
  release:
    types: [created]
  # Allows you to run this workflow manually from the Actions tab
  workflow_dispatch:
jobs:
  build_publish:
    name: Build & Publish the Docker image (dataflow)
    runs-on: ubuntu-latest
    env:
      IMAGE: ghcr.io/zama-ai/zamalang-df-compiler
    steps:
      - uses: actions/checkout@v2
        with:
          submodules: recursive
      - name: login
        run: echo "${{ secrets.GHCR_PASSWORD }}" | docker login -u ${{ secrets.GHCR_LOGIN }} --password-stdin ghcr.io
      - name: build
        run: docker image build --no-cache --label "commit-sha=${{ github.sha }}" -t $IMAGE -f builders/Dockerfile.zamalang-df-env .
      - name: tag and publish
        run: |
          docker image tag $IMAGE $IMAGE:${{ github.sha }}
          docker image push $IMAGE:latest
          docker image push $IMAGE:${{ github.sha }}

View File

@@ -0,0 +1,22 @@
# Stage 0: build HPX 1.7.1 from source.
FROM ubuntu:latest
RUN apt-get update --fix-missing
RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y curl cmake g++ build-essential python3 python3-pip python3-setuptools ninja-build git libboost-filesystem-dev libhwloc-dev
RUN pip install numpy pybind11==2.6.2 PyYAML
# A CMake 3.22 binary release is fetched explicitly — presumably the
# distro CMake is too old for this HPX build; confirm before removing.
RUN mkdir /cmake-build
ADD https://github.com/Kitware/CMake/releases/download/v3.22.0/cmake-3.22.0-linux-x86_64.tar.gz /cmake-build/cmake.tar.gz
RUN cd /cmake-build && tar xzf cmake.tar.gz
ENV PATH=/cmake-build/cmake-3.22.0-linux-x86_64/bin:${PATH}
# Pin HPX to the 1.7.1 release tag.
RUN git clone https://github.com/STEllAR-GROUP/hpx.git
ENV HPX=$PWD/hpx
RUN cd ${HPX} && git checkout 1.7.1
RUN mkdir ${HPX}/build
RUN cd ${HPX}/build && cmake \
-DHPX_WITH_FETCH_ASIO=on \
-DHPX_FILESYSTEM_WITH_BOOST_FILESYSTEM_COMPATIBILITY=ON \
-DHPX_WITH_MALLOC=system ..
RUN cd ${HPX}/build && make -j2
# Stage 1: ship only the built HPX tree and the CMake binaries so
# downstream images can COPY them without the build toolchain.
FROM ubuntu:latest
COPY --from=0 /hpx/ /hpx/
COPY --from=0 /cmake-build/ /cmake/

View File

@@ -0,0 +1,43 @@
# Builder image for the dataflow-enabled compiler: assembles Concrete,
# HPX, CMake and the LLVM/compiler sources, then builds the compiler
# with parallel execution enabled.
FROM ubuntu:latest
RUN apt-get update
RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y curl cmake g++ \
build-essential python3 python3-pip \
python3-setuptools ninja-build git \
zlib1g-dev ccache libboost-filesystem-dev libhwloc-dev
# setup ccache with an unlimited amount of files and storage
RUN ccache -M 0
RUN ccache -F 0
RUN pip install numpy pybind11==2.6.2 PyYAML
# Setup Concrete
COPY --from=ghcr.io/zama-ai/concrete-api-env:latest /target/release /concrete/target/release
ENV CONCRETE_PROJECT=/concrete
# Setup HPX
COPY --from=ghcr.io/zama-ai/hpx:latest /hpx /hpx
ENV HPX_INSTALL_DIR=/hpx/build
# Setup Cmake
COPY --from=ghcr.io/zama-ai/hpx:latest /cmake /cmake
ENV PATH=/cmake/cmake-3.22.0-linux-x86_64/bin:${PATH}
# Setup LLVM
COPY /llvm-project /llvm-project
# Setup and build compiler
COPY /compiler /compiler
WORKDIR /compiler
RUN mkdir -p /build
# Build, then shuffle the artifacts out of /build, wipe the (large)
# build tree, and recreate a minimal /build layout holding only the
# compiler binary, runtime libraries and Python bindings.
RUN make PARALLEL_EXECUTION_ENABLED=ON BUILD_DIR=/build CCACHE=ON \
zamacompiler python-bindings && \
mv /build/tools/zamalang/python_packages/zamalang_core /zamalang_core && \
mv /build/bin/zamacompiler /bin && \
mv /build/lib/libZamalangRuntime.so /lib && \
mv /build/lib/libDFRuntime.so /lib && \
rm -rf /build && \
mkdir -p /build/tools/zamalang/python_packages/ && \
mkdir -p /build/bin && \
mkdir -p /build/lib && \
mv /zamalang_core /build/tools/zamalang/python_packages/ && \
mv /bin/zamacompiler /build/bin && \
mv /lib/libZamalangRuntime.so /build/lib && \
mv /lib/libDFRuntime.so /build/lib
# Runtime environment for consumers of this image.
ENV PYTHONPATH "$PYTHONPATH:/build/tools/zamalang/python_packages/zamalang_core"
ENV PATH "$PATH:/build/bin"
ENV RT_LIB "/build/lib/libZamalangRuntime.so"

View File

@@ -84,6 +84,23 @@ else()
message(STATUS "ZamaLang Python bindings are disabled.")
endif()
#-------------------------------------------------------------------------------
# DFR - parallel execution configuration
#-------------------------------------------------------------------------------
option(ZAMALANG_PARALLEL_EXECUTION_ENABLED "Enables parallel execution for ZamaLang." ON)
if(ZAMALANG_PARALLEL_EXECUTION_ENABLED)
  message(STATUS "ZamaLang parallel execution enabled.")
  # HPX supplies the futures/dataflow runtime that backs the DFR.
  find_package(HPX REQUIRED CONFIG)
  # NOTE(review): directory-scoped include_directories leaks HPX headers
  # into every target below; prefer target_include_directories on the
  # DFR targets once they are in scope.
  include_directories(SYSTEM ${HPX_INCLUDE_DIRS})
  list(APPEND CMAKE_MODULE_PATH "${HPX_CMAKE_DIR}")
else()
  message(STATUS "ZamaLang parallel execution disabled.")
endif()
#-------------------------------------------------------------------------------
# Unit tests
#-------------------------------------------------------------------------------

View File

@@ -1,6 +1,7 @@
BUILD_DIR=./build
Python3_EXECUTABLE=
BINDINGS_PYTHON_ENABLED=ON
PARALLEL_EXECUTION_ENABLED=OFF
ifeq ($(shell which ccache),)
CCACHE=OFF
@@ -24,7 +25,9 @@ $(BUILD_DIR)/configured.stamp:
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_ENABLE_BINDINGS_PYTHON=$(BINDINGS_PYTHON_ENABLED) \
-DZAMALANG_BINDINGS_PYTHON_ENABLED=$(BINDINGS_PYTHON_ENABLED) \
-DZAMALANG_PARALLEL_EXECUTION_ENABLED=$(PARALLEL_EXECUTION_ENABLED) \
-DCONCRETE_FFI_RELEASE=${CONCRETE_PROJECT}/target/release \
-DHPX_DIR=${HPX_INSTALL_DIR}/lib/cmake/HPX \
-DLLVM_EXTERNAL_PROJECTS=zamalang \
-DLLVM_EXTERNAL_ZAMALANG_SOURCE_DIR=. \
-DPython3_EXECUTABLE=${Python3_EXECUTABLE}
@@ -47,6 +50,8 @@ test-python: python-bindings zamacompiler
test: test-check test-end-to-end-jit test-python
test-dataflow: test-end-to-end-jit-dfr
# Unittests
build-end-to-end-jit-test: build-initialized
@@ -64,8 +69,12 @@ build-end-to-end-jit-hlfhelinalg: build-initialized
build-end-to-end-jit-lambda: build-initialized
cmake --build $(BUILD_DIR) --target end_to_end_jit_lambda
build-end-to-end-jit-dfr: build-initialized
cmake --build $(BUILD_DIR) --target end_to_end_jit_dfr
build-end-to-end-jit: build-end-to-end-jit-test build-end-to-end-jit-clear-tensor build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-hlfhelinalg
test-end-to-end-jit-test: build-end-to-end-jit-test
$(BUILD_DIR)/bin/end_to_end_jit_test
@@ -81,7 +90,8 @@ test-end-to-end-jit-hlfhelinalg: build-end-to-end-jit-hlfhelinalg
test-end-to-end-jit-lambda: build-initialized build-end-to-end-jit-lambda
$(BUILD_DIR)/bin/end_to_end_jit_lambda
test-end-to-end-jit-dfr: build-end-to-end-jit-dfr
$(BUILD_DIR)/bin/end_to_end_jit_dfr
test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-hlfhelinalg

View File

@@ -0,0 +1,116 @@
#ifndef ZAMALANG_DFR_DFRUNTIME_HPP
#define ZAMALANG_DFR_DFRUNTIME_HPP
#include <utility>
#include <memory>
#include <dlfcn.h>
#include "zamalang/Runtime/runtime_api.h"
/* Debug interface. */
#include "zamalang/Runtime/dfr_debug_interface.h"
extern void *dl_handle;
struct WorkFunctionRegistry;
extern WorkFunctionRegistry *node_level_work_function_registry;
// Recover the name of the work function
// Resolve the symbol name of a work function from its address using
// dladdr; throws through HPX when the address has no visible symbol.
static inline const char *
_dfr_get_function_name_from_address(void *fn) {
  Dl_info symbol_info;
  const bool resolved =
      dladdr(fn, &symbol_info) != 0 && symbol_info.dli_sname != nullptr;
  if (resolved)
    return symbol_info.dli_sname;
  HPX_THROW_EXCEPTION(hpx::no_success,
                      "_dfr_get_function_name_from_address",
                      "Error recovering work function name from address.");
  return nullptr; // unreachable - HPX_THROW_EXCEPTION throws
}
// Look a work function up by symbol name in the running program image
// (dl_handle); throws through HPX when the symbol cannot be found.
static inline wfnptr
_dfr_get_function_pointer_from_name(const char *fn_name) {
  void *symbol = dlsym(dl_handle, fn_name);
  if (symbol != nullptr)
    return (wfnptr) symbol;
  HPX_THROW_EXCEPTION(hpx::no_success,
                      "_dfr_get_function_pointer_from_name",
                      "Error recovering work function pointer from name.");
  return nullptr; // unreachable - HPX_THROW_EXCEPTION throws
}
// Determine where new task should run. For now just round-robin
// distribution - TODO: optimise.
static inline size_t
_dfr_find_next_execution_locality() {
  static size_t num_nodes = hpx::get_num_localities().get();
  static std::atomic<std::size_t> next_locality{0};
  // fetch_add returns the previous value; +1 reproduces the original
  // pre-increment semantics exactly.
  size_t ticket = next_locality.fetch_add(1) + 1;
  return ticket % num_nodes;
}
// True iff this process runs on the root locality of the HPX cluster.
static inline bool
_dfr_is_root_node() {
  return hpx::find_root_locality() == hpx::find_here();
}
// Per-node registry mapping work-function pointers to symbol names and
// back. Names are what travel across localities in the distributed
// case; each node resolves them locally through dlsym. Thread-safe via
// a single internal mutex.
struct WorkFunctionRegistry
{
  // Registers this instance as the node-level singleton.
  WorkFunctionRegistry()
  {
    node_level_work_function_registry = this;
  }

  // Resolve a name to a work-function pointer, caching the result in
  // both direction maps. Throws through HPX when dlsym fails.
  wfnptr getWorkFunctionPointer(const std::string &name)
  {
    std::lock_guard<std::mutex> guard(registry_guard);
    auto fnptrit = name_to_ptr_registry.find(name);
    if (fnptrit != name_to_ptr_registry.end())
      return (wfnptr) fnptrit->second;
    auto ptr = dlsym(dl_handle, name.c_str());
    if (ptr == nullptr)
      HPX_THROW_EXCEPTION(hpx::no_success,
                          "WorkFunctionRegistry::getWorkFunctionPointer",
                          "Error recovering work function pointer from name.");
    ptr_to_name_registry.insert(std::pair<const void *, std::string>(ptr, name));
    name_to_ptr_registry.insert(std::pair<std::string, const void *>(name, ptr));
    return (wfnptr) ptr;
  }

  // Resolve a pointer to a name, caching the result. When dladdr cannot
  // find a symbol (e.g. JIT-generated code), a synthetic unique name is
  // fabricated — note such names cannot be dlsym'd on other nodes.
  std::string getWorkFunctionName(const void *fn)
  {
    std::lock_guard<std::mutex> guard(registry_guard);
    auto fnnameit = ptr_to_name_registry.find(fn);
    if (fnnameit != ptr_to_name_registry.end())
      return fnnameit->second;
    Dl_info info;
    std::string ret;
    // Assume that if we can't find the name, there is no dynamic
    // library to find it in. TODO: fix this to distinguish JIT/binary
    // and in case of distributed exec.
    if (!dladdr(fn, &info) || info.dli_sname == nullptr)
    {
      static std::atomic<unsigned int> fnid{0};
      ret = "_dfr_jit_wfnname_" + std::to_string(fnid++);
    } else {
      ret = info.dli_sname;
    }
    ptr_to_name_registry.insert(std::pair<const void *, std::string>(fn, ret));
    name_to_ptr_registry.insert(std::pair<std::string, const void *>(ret, fn));
    return ret;
  }

private:
  std::mutex registry_guard;                           // protects both maps
  std::map<const void *, std::string> ptr_to_name_registry;
  std::map<std::string, const void *> name_to_ptr_registry;
};
#endif

View File

@@ -0,0 +1,13 @@
#ifndef ZAMALANG_DRF_DEBUG_INTERFACE_H
#define ZAMALANG_DRF_DEBUG_INTERFACE_H

#include <stdint.h>
#include <unistd.h>

/* Debug interface of the dataflow runtime.
   Fix: the bare `extern "C"` made this .h header unusable from a C
   translation unit; guard it with __cplusplus so both languages can
   include it. */
#ifdef __cplusplus
extern "C" {
#endif
/* Locality (node) id of the calling process. */
size_t _dfr_debug_get_node_id();
/* Worker-thread number of the calling thread. */
size_t _dfr_debug_get_worker_id();
/* Print a one-line trace of a task and where it executes. */
void _dfr_debug_print_task(const char *name, int inputs, int outputs);
/* Print a generic debug value. */
void _dfr_print_debug(size_t val);
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,273 @@
#ifndef ZAMALANG_DFR_DISTRIBUTED_GENERIC_TASK_SERVER_HPP
#define ZAMALANG_DFR_DISTRIBUTED_GENERIC_TASK_SERVER_HPP
#include <cstdarg>
#include <string>
#include <cstdlib>
#include <hpx/include/actions.hpp>
#include <hpx/include/lcos.hpp>
#include <hpx/include/parallel_algorithm.hpp>
#include <hpx/include/parallel_numeric.hpp>
#include <hpx/include/util.hpp>
#include <hpx/iostream.hpp>
#include <hpx/serialization/detail/serialize_collection.hpp>
#include <hpx/serialization/serialization_fwd.hpp>
#include <hpx/serialization/serialize.hpp>
#include <hpx/async_colocated/get_colocation_id.hpp>
#include <hpx/async_colocated/get_colocation_id.hpp>
#include <hpx/include/client.hpp>
#include <hpx/include/runtime.hpp>
#include <hpx/modules/collectives.hpp>
#include "zamalang/Runtime/key_manager.hpp"
#include "zamalang/Runtime/DFRuntime.hpp"
extern WorkFunctionRegistry *node_level_work_function_registry;
using namespace hpx::naming;
using namespace hpx::components;
using namespace hpx::collectives;
// Serializable argument pack for a remote task invocation: the work
// function's symbol name plus opaque input buffers and the byte sizes
// of inputs and expected outputs.
struct OpaqueInputData
{
  OpaqueInputData() = default;

  OpaqueInputData(std::string wfn_name,
                  std::vector<void *> params,
                  std::vector<size_t> param_sizes,
                  std::vector<size_t> output_sizes,
                  bool alloc_p = false) :
    wfn_name(wfn_name),
    params(std::move(params)),
    param_sizes(std::move(param_sizes)),
    output_sizes(std::move(output_sizes)),
    alloc_p(alloc_p)
  {}

  // NOTE(review): std::move on members of a const reference degrades to
  // a copy, so this is an ordinary copy constructor despite appearances.
  OpaqueInputData(const OpaqueInputData &oid) :
    wfn_name(std::move(oid.wfn_name)),
    params(std::move(oid.params)),
    param_sizes(std::move(oid.param_sizes)),
    output_sizes(std::move(oid.output_sizes)),
    alloc_p(oid.alloc_p)
  {}

  friend class hpx::serialization::access;

  // Deserialize: parameter buffers are materialized as new char[]
  // arrays on the receiving node; alloc_p records that this side owns
  // them (they are freed after the task executes).
  template <class Archive>
  void load(Archive &ar, const unsigned int version)
  {
    ar & wfn_name;
    ar & param_sizes;
    ar & output_sizes;
    for (auto p : param_sizes)
    {
      char *param = new char[p];
      // TODO: Optimise these serialisation operations
      for (size_t i = 0; i < p; ++i)
        ar & param[i];
      params.push_back((void *)param);
    }
    alloc_p = true;
  }

  // Serialize: stream each parameter buffer byte-by-byte.
  template <class Archive>
  void save(Archive &ar, const unsigned int version) const
  {
    ar & wfn_name;
    ar & param_sizes;
    ar & output_sizes;
    for (size_t p = 0; p < params.size(); ++p)
      for (size_t i = 0; i < param_sizes[p]; ++i)
        ar & static_cast<char *>(params[p])[i];
  }
  HPX_SERIALIZATION_SPLIT_MEMBER()

  std::string wfn_name;              // symbol name of the work function
  std::vector<void *> params;        // opaque input buffers
  std::vector<size_t> param_sizes;   // byte size of each input
  std::vector<size_t> output_sizes;  // byte size of each expected output
  bool alloc_p = false;              // true when this side allocated params
};
// Serializable result pack of a remote task: opaque output buffers and
// their byte sizes.
struct OpaqueOutputData
{
  OpaqueOutputData() = default;

  OpaqueOutputData(std::vector<void *> outputs,
                   std::vector<size_t> output_sizes,
                   bool alloc_p = false) :
    outputs(std::move(outputs)),
    output_sizes(std::move(output_sizes)),
    alloc_p(alloc_p)
  {}

  // NOTE(review): std::move on members of a const reference degrades to
  // a copy, so this is an ordinary copy constructor despite appearances.
  OpaqueOutputData(const OpaqueOutputData &ood) :
    outputs(std::move(ood.outputs)),
    output_sizes(std::move(ood.output_sizes)),
    alloc_p(ood.alloc_p)
  {}

  friend class hpx::serialization::access;

  // Deserialize: output buffers are materialized as new char[] arrays
  // on the receiving node.
  template <class Archive>
  void load(Archive &ar, const unsigned int version)
  {
    ar & output_sizes;
    for (auto p : output_sizes)
    {
      char *output = new char[p];
      for (size_t i = 0; i < p; ++i)
        ar & output[i];
      outputs.push_back((void *)output);
    }
    alloc_p = true;
  }

  // Serialize: stream each output buffer byte-by-byte.
  template <class Archive>
  void save(Archive &ar, const unsigned int version) const
  {
    ar & output_sizes;
    for (size_t p = 0; p < outputs.size(); ++p)
    {
      for (size_t i = 0; i < output_sizes[p]; ++i)
        ar & static_cast<char *>(outputs[p])[i];
      // TODO: investigate if HPX is automatically deallocating
      //these. Here it could be safely assumed that these would no
      //longer be live.
      //delete ((char*)outputs[p]);
    }
  }
  HPX_SERIALIZATION_SPLIT_MEMBER()

  std::vector<void *> outputs;       // opaque output buffers
  std::vector<size_t> output_sizes;  // byte size of each output
  bool alloc_p = false;              // true when this side allocated outputs
};
// HPX component that executes opaque work functions on behalf of
// (possibly remote) callers. The work function is identified by its
// symbol name and resolved through the node-level WorkFunctionRegistry.
struct GenericComputeServer : component_base<GenericComputeServer>
{
  GenericComputeServer () = default;

  // Component action exposed to callers: resolve the work function
  // named in `inputs`, allocate its output buffers (new char[]), invoke
  // it, and return the outputs. Supports up to 3 inputs and 3 outputs;
  // anything else throws.
  OpaqueOutputData execute_task (const OpaqueInputData &inputs)
  {
    auto wfn = node_level_work_function_registry->getWorkFunctionPointer(inputs.wfn_name);
    std::vector<void *> outputs;
    switch (inputs.output_sizes.size()) {
    case 1:
      {
        void *output = (void *)(new char[inputs.output_sizes[0]]);
        switch (inputs.params.size()) {
        case 0:
          wfn(output);
          break;
        case 1:
          wfn(inputs.params[0], output);
          break;
        case 2:
          wfn(inputs.params[0], inputs.params[1], output);
          break;
        case 3:
          wfn(inputs.params[0], inputs.params[1], inputs.params[2], output);
          break;
        default:
          HPX_THROW_EXCEPTION(hpx::no_success,
                              "GenericComputeServer::execute_task",
                              "Error: number of task parameters not supported.");
        }
        outputs = {output};
        break;
      }
    case 2:
      {
        void *output1 = (void *)(new char[inputs.output_sizes[0]]);
        void *output2 = (void *)(new char[inputs.output_sizes[1]]);
        switch (inputs.params.size()) {
        case 0:
          wfn(output1, output2);
          break;
        case 1:
          wfn(inputs.params[0], output1, output2);
          break;
        case 2:
          wfn(inputs.params[0], inputs.params[1], output1, output2);
          break;
        case 3:
          wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1, output2);
          break;
        default:
          HPX_THROW_EXCEPTION(hpx::no_success,
                              "GenericComputeServer::execute_task",
                              "Error: number of task parameters not supported.");
        }
        outputs = {output1, output2};
        break;
      }
    case 3:
      {
        void *output1 = (void *)(new char[inputs.output_sizes[0]]);
        void *output2 = (void *)(new char[inputs.output_sizes[1]]);
        void *output3 = (void *)(new char[inputs.output_sizes[2]]);
        switch (inputs.params.size()) {
        case 0:
          wfn(output1, output2, output3);
          break;
        case 1:
          wfn(inputs.params[0], output1, output2, output3);
          break;
        case 2:
          wfn(inputs.params[0], inputs.params[1], output1, output2, output3);
          break;
        case 3:
          wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1, output2, output3);
          break;
        default:
          HPX_THROW_EXCEPTION(hpx::no_success,
                              "GenericComputeServer::execute_task",
                              "Error: number of task parameters not supported.");
        }
        outputs = {output1, output2, output3};
        break;
      }
    default:
      HPX_THROW_EXCEPTION(hpx::no_success,
                          "GenericComputeServer::execute_task",
                          "Error: number of task outputs not supported.");
    }
    // Parameters deserialized on this node were allocated with
    // `new char[p]` in OpaqueInputData::load (which sets alloc_p), so
    // they must be released with the matching delete[]; the previous
    // plain `delete (char*)` on new[]'d storage was undefined behaviour.
    if (inputs.alloc_p)
      for (auto p : inputs.params)
        delete[] static_cast<char *>(p);
    return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes), inputs.alloc_p);
  }

  HPX_DEFINE_COMPONENT_ACTION(GenericComputeServer, execute_task);
};
// HPX boilerplate: declare and register the component and its
// execute_task action so they can be invoked across localities.
HPX_REGISTER_ACTION_DECLARATION(GenericComputeServer::execute_task_action,
                                GenericComputeServer_execute_task_action)
HPX_REGISTER_COMPONENT_MODULE()
HPX_REGISTER_COMPONENT(hpx::components::component<GenericComputeServer>,
                       GenericComputeServer)
HPX_REGISTER_ACTION(GenericComputeServer::execute_task_action,
                    GenericComputeServer_execute_task_action)
// Client-side handle to a GenericComputeServer component instance.
struct GenericComputeClient
    : client_base<GenericComputeClient, GenericComputeServer>
{
  typedef client_base<GenericComputeClient, GenericComputeServer> base_type;

  GenericComputeClient() = default;
  GenericComputeClient(id_type id) : base_type(std::move(id)) {}

  // Asynchronously run the work function described by `inputs` on the
  // server this client is bound to.
  hpx::future<OpaqueOutputData>
  execute_task(const OpaqueInputData &inputs)
  {
    return hpx::async<GenericComputeServer::execute_task_action>(
        this->get_id(), inputs);
  }
};
#endif

View File

@@ -0,0 +1,153 @@
#ifndef ZAMALANG_DFR_KEY_MANAGER_HPP
#define ZAMALANG_DFR_KEY_MANAGER_HPP
#include <utility>
#include <memory>
#include <hpx/include/runtime.hpp>
#include <hpx/modules/collectives.hpp>
#include "zamalang/Runtime/DFRuntime.hpp"
struct PbsKeyManager;
extern PbsKeyManager *node_level_key_manager;
// Serializable wrapper around an opaque PBS key blob. `key` holds a
// heap cell containing the raw buffer pointer, `size` the buffer's byte
// length, and `key_id` identifies the key cluster-wide.
struct PbsKeyWrapper
{
  std::shared_ptr<void *> key;
  size_t key_id;
  size_t size;

  // Fix: default-constructed wrappers previously left key_id/size
  // uninitialized even though save()/load() read them; zero-initialize.
  PbsKeyWrapper() : key(nullptr), key_id(0), size(0) {}
  PbsKeyWrapper(void *key, size_t key_id, size_t size) :
    key(std::make_shared<void *>(key)), key_id(key_id), size(size) {}
  PbsKeyWrapper(std::shared_ptr<void *> key, size_t key_id, size_t size) :
    key(key), key_id(key_id), size(size) {}
  PbsKeyWrapper(PbsKeyWrapper &&moved) noexcept :
    key(moved.key), key_id(moved.key_id), size(moved.size) {}
  PbsKeyWrapper(const PbsKeyWrapper &pbsk) :
    key(pbsk.key), key_id(pbsk.key_id), size(pbsk.size) {}

  friend class hpx::serialization::access;

  // Serialize: id, size, then the raw key bytes.
  template <class Archive>
  void save(Archive &ar, const unsigned int version) const
  {
    char *_key_ = static_cast<char *>(*key);
    ar & key_id & size;
    for (size_t i = 0; i < size; ++i)
      ar & _key_[i];
  }

  // Deserialize into a freshly malloc'd buffer.
  // NOTE(review): the shared_ptr owns only the pointer cell, not the
  // buffer, so deserialized key bytes are never freed — acceptable for
  // process-lifetime keys, revisit if keys are ever rotated.
  template <class Archive>
  void load(Archive &ar, const unsigned int version)
  {
    ar & key_id & size;
    char *_key_ = (char *) malloc(size);
    for (size_t i = 0; i < size; ++i)
      ar & _key_[i];
    key = std::make_shared<void *>(_key_);
  }
  HPX_SERIALIZATION_SPLIT_MEMBER()
};
// Two wrappers denote the same key iff their ids match; the key bytes
// themselves are not compared.
inline bool operator==(const PbsKeyWrapper &lhs, const PbsKeyWrapper &rhs)
{
  const bool same_id = (rhs.key_id == lhs.key_id);
  return same_id;
}
PbsKeyWrapper _dfr_fetch_key(size_t);
HPX_PLAIN_ACTION(_dfr_fetch_key, _dfr_fetch_key_action)
// Per-node store of PBS keys addressed by key_id. Keys registered on
// the root node are either broadcast eagerly (broadcast_keys) or
// fetched lazily from the root on first use (get_key).
struct PbsKeyManager
{
  // The initial keys registered on the root node and whether to push
  // them is TBD.
  // Registers this instance as the node-level singleton.
  PbsKeyManager()
  {
    node_level_key_manager = this;
  }

  // Return the key for key_id, fetching it from the root locality on a
  // cache miss. The find is done under the lock, the remote fetch
  // outside it; two threads may race and both fetch, in which case the
  // second insert is a no-op (benign duplication). Map iterators stay
  // valid across the unlock because nothing ever erases from keystore.
  PbsKeyWrapper get_key(const size_t key_id)
  {
    keystore_guard.lock();
    auto keyit = keystore.find(key_id);
    keystore_guard.unlock();
    if (keyit == keystore.end())
    {
      _dfr_fetch_key_action fet;
      PbsKeyWrapper &&pkw = fet(hpx::find_root_locality(), key_id);
      if (pkw.size == 0)
      {
        // Maybe retry or try other nodes... but for now it's an error.
        HPX_THROW_EXCEPTION(hpx::no_success,
                            "_dfr_get_key",
                            "Error: key not found on remote node.");
      }
      else
      {
        std::lock_guard<std::mutex> guard(keystore_guard);
        keyit = keystore.insert(std::pair<size_t, PbsKeyWrapper>(key_id, pkw)).first;
      }
    }
    return keyit->second;
  }

  // To be used only for remote requests
  PbsKeyWrapper fetch_key(const size_t key_id)
  {
    std::lock_guard<std::mutex> guard(keystore_guard);
    auto keyit = keystore.find(key_id);
    if (keyit != keystore.end())
      return keyit->second;
    // If this node does not contain this key, return an empty wrapper
    return PbsKeyWrapper(nullptr, 0, 0);
  }

  // Store a key under key_id.
  // NOTE(review): std::map::insert never returns end(), so the error
  // branch below is unreachable; insert also KEEPS the old value when
  // key_id is already present — confirm that silently ignoring
  // re-registration is the intended behaviour.
  void register_key(void *key, size_t key_id, size_t size)
  {
    std::lock_guard<std::mutex> guard(keystore_guard);
    auto keyit =
      keystore.insert(
        std::pair<size_t, PbsKeyWrapper>(key_id,
                                         PbsKeyWrapper(key, key_id, size))).first;
    if (keyit == keystore.end())
    {
      HPX_THROW_EXCEPTION(hpx::no_success,
                          "_dfr_register_key",
                          "Error: could not register new key.");
    }
  }

  // Collective exchange: the root pushes its keystore and every other
  // locality replaces its own with the broadcast copy. Must be invoked
  // on all localities for the collective to complete.
  void broadcast_keys()
  {
    std::lock_guard<std::mutex> guard(keystore_guard);
    if (_dfr_is_root_node())
      hpx::collectives::broadcast_to("keystore", this->keystore).get();
    else
      keystore = std::move(
        hpx::collectives::broadcast_from<std::map<size_t, PbsKeyWrapper>>
        ("keystore").get());
  }

private:
  std::mutex keystore_guard;                 // protects keystore
  std::map<size_t, PbsKeyWrapper> keystore;
};
// Plain-action entry point: serve a key request coming from a remote
// node out of this node's key manager.
PbsKeyWrapper
_dfr_fetch_key(size_t key_id)
{
  PbsKeyManager *manager = node_level_key_manager;
  return manager->fetch_key(key_id);
}
#endif

View File

@@ -0,0 +1,34 @@
/**
Define the API exposed to the compiler for code generation.
*/
#ifndef ZAMALANG_DFR_RUNTIME_API_H
#define ZAMALANG_DFR_RUNTIME_API_H
#include <cstddef>
#include <cstdlib>
extern "C" {
// Opaque variadic work-function pointer as emitted by the compiler.
typedef void (*wfnptr)(...);
// Wrap an already-computed value in a future; returns the future.
void *_dfr_make_ready_future(void *);
// Launch a task: work function, number of inputs, number of outputs,
// followed by (pointer, byte-size) pairs for each input then output.
void _dfr_create_async_task(wfnptr, size_t, size_t, ...);
// Block until the future resolves and return its value.
void *_dfr_await_future(void *);
/* Keys can have node-local copies which can be retrieved. This
should only be called on the node where the key is required. */
void _dfr_register_key(void *, size_t, size_t);
void _dfr_broadcast_keys();
void *_dfr_get_key(size_t);
/* Memory management:
_dfr_make_ready_future allocates the future, not the underlying storage.
_dfr_create_async_task allocates both future and storage for outputs. */
void _dfr_deallocate_future(void *);
void _dfr_deallocate_future_data(void *);
/* Initialisation & termination. */
void _dfr_start();
void _dfr_stop();
}
#endif

View File

@@ -6,4 +6,12 @@ add_library(ZamalangRuntime SHARED
target_link_libraries(ZamalangRuntime Concrete pthread m dl)
install(TARGETS ZamalangRuntime EXPORT ZamalangRuntime)
install(EXPORT ZamalangRuntime DESTINATION "./")
install(EXPORT ZamalangRuntime DESTINATION "./")
if(ZAMALANG_PARALLEL_EXECUTION_ENABLED)
  # Dataflow runtime library; HPX is located by the top-level configuration.
  add_library(DFRuntime SHARED DFRuntime.cpp)
  # -rdynamic exports the executable's symbols so work functions can be
  # resolved by name through dlsym at runtime.
  target_link_libraries(DFRuntime PUBLIC pthread m dl HPX::hpx HPX::iostreams_component -rdynamic)
  install(TARGETS DFRuntime EXPORT DFRuntime)
  install(EXPORT DFRuntime DESTINATION "./")
endif()

View File

@@ -0,0 +1,253 @@
/**
This file implements the dataflow runtime. It encapsulates all of
the underlying communication, parallelism, etc. and only exposes a
simplified interface for code generation in runtime_api.h
This hides the details of implementation, including of the HPX
framework currently used, from the code generation side.
*/
#include <hpx/future.hpp>
#include <hpx/hpx_start.hpp>
#include <hpx/hpx_suspend.hpp>
#include "zamalang/Runtime/DFRuntime.hpp"
#include "zamalang/Runtime/distributed_generic_task_server.hpp"
#include "zamalang/Runtime/runtime_api.h"
std::vector<GenericComputeClient> gcc;
void *dl_handle;
PbsKeyManager *node_level_key_manager;
WorkFunctionRegistry *node_level_work_function_registry;
using namespace hpx;
// Wrap an already-available value in a heap-allocated shared future.
// Only the future is allocated here, never the value's storage.
void *_dfr_make_ready_future(void *in) {
  auto *fut = new hpx::shared_future<void *>(hpx::make_ready_future(in));
  return static_cast<void *>(fut);
}
// Block until the future is ready and return the value it carries.
void *_dfr_await_future(void *in) {
  auto *fut = static_cast<hpx::shared_future<void *> *>(in);
  return fut->get();
}
// Free the storage a future points at. Task outputs are allocated as
// `new char[...]` (see GenericComputeServer::execute_task), so the
// matching release is delete[] on a char*; the previous plain `delete`
// through a void* was undefined behaviour.
// NOTE(review): assumes this is only called on futures whose payload
// the runtime allocated (task outputs), per the runtime_api.h contract
// — confirm callers never pass ready-futures over caller-owned data.
void _dfr_deallocate_future_data(void *in) {
  delete[] static_cast<char *>(
      static_cast<hpx::shared_future<void *> *>(in)->get());
}
void _dfr_deallocate_future(void *in) {
delete (static_cast<hpx::shared_future<void *> *>(in));
}
// Runtime generic async_task. Each first NUM_PARAMS pairs of
// arguments in the variadic list corresponds to a void* pointer on a
// hpx::future<void*> and the size of data within the future. After
// that come NUM_OUTPUTS pairs of hpx::future<void*>* and size_t for
// the returns.
void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
                            ...) {
  std::vector<void *> params;
  std::vector<void *> outputs;
  std::vector<size_t> param_sizes;
  std::vector<size_t> output_sizes;
  // Unpack the variadic (pointer, size) pairs: inputs first, then outputs.
  va_list args;
  va_start(args, num_outputs);
  for (size_t i = 0; i < num_params; ++i) {
    params.push_back(va_arg(args, void *));
    param_sizes.push_back(va_arg(args, size_t));
  }
  for (size_t i = 0; i < num_outputs; ++i) {
    outputs.push_back(va_arg(args, void *));
    output_sizes.push_back(va_arg(args, size_t));
  }
  va_end(args);
  // We pass functions by name - which is not strictly necessary in
  // shared memory as pointers suffice, but is needed in the
  // distributed case where the functions need to be located/loaded on
  // the node.
  auto wfnname =
      node_level_work_function_registry->getWorkFunctionName((void *)wfn);
  hpx::future<hpx::future<OpaqueOutputData>> oodf;
  // In order to allow complete dataflow semantics for
  // communication/synchronization, we split tasks in two parts: an
  // execution body that is scheduled once all input dependences are
  // satisfied, which generates a future on a tuple of outputs, which
  // is then further split into a tuple of futures and provide
  // individual synchronization for each return independently.
  // Each params[i] is a pointer to a shared_future<void*>; hpx::dataflow
  // defers the body until every such future is ready.
  switch (num_params) {
  case 0:
    oodf = std::move(
        hpx::dataflow([wfnname, param_sizes,
                       output_sizes]() -> hpx::future<OpaqueOutputData> {
          std::vector<void *> params = {};
          OpaqueInputData oid(wfnname, params, param_sizes, output_sizes);
          return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
        }));
    break;
  case 1:
    oodf = std::move(hpx::dataflow(
        [wfnname, param_sizes, output_sizes](hpx::shared_future<void *> param0)
            -> hpx::future<OpaqueOutputData> {
          std::vector<void *> params = {param0.get()};
          OpaqueInputData oid(wfnname, params, param_sizes, output_sizes);
          return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
        },
        *(hpx::shared_future<void *> *)params[0]));
    break;
  case 2:
    oodf = std::move(hpx::dataflow(
        [wfnname, param_sizes, output_sizes](hpx::shared_future<void *> param0,
                                             hpx::shared_future<void *> param1)
            -> hpx::future<OpaqueOutputData> {
          std::vector<void *> params = {param0.get(), param1.get()};
          OpaqueInputData oid(wfnname, params, param_sizes, output_sizes);
          return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
        },
        *(hpx::shared_future<void *> *)params[0],
        *(hpx::shared_future<void *> *)params[1]));
    break;
  case 3:
    oodf = std::move(hpx::dataflow(
        [wfnname, param_sizes, output_sizes](hpx::shared_future<void *> param0,
                                             hpx::shared_future<void *> param1,
                                             hpx::shared_future<void *> param2)
            -> hpx::future<OpaqueOutputData> {
          std::vector<void *> params = {param0.get(), param1.get(),
                                        param2.get()};
          OpaqueInputData oid(wfnname, params, param_sizes, output_sizes);
          return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
        },
        *(hpx::shared_future<void *> *)params[0],
        *(hpx::shared_future<void *> *)params[1],
        *(hpx::shared_future<void *> *)params[2]));
    break;
  default:
    HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_create_async_task",
                        "Error: number of task parameters not supported.");
  }
  // Split the single future-of-outputs into one shared_future per output
  // and store each (heap-allocated) future through the caller's output
  // slots; the caller synchronizes on each independently.
  switch (num_outputs) {
  case 1:
    *((void **)outputs[0]) = new hpx::shared_future<void *>(hpx::dataflow(
        [](hpx::future<OpaqueOutputData> oodf_in) -> void * {
          return oodf_in.get().outputs[0];
        },
        oodf));
    break;
  case 2: {
    hpx::future<hpx::tuple<void *, void *>> &&ft = hpx::dataflow(
        [](hpx::future<OpaqueOutputData> oodf_in)
            -> hpx::tuple<void *, void *> {
          std::vector<void *> outputs = std::move(oodf_in.get().outputs);
          return hpx::make_tuple<>(outputs[0], outputs[1]);
        },
        oodf);
    hpx::tuple<hpx::future<void *>, hpx::future<void *>> &&tf =
        hpx::split_future(std::move(ft));
    *((void **)outputs[0]) =
        (void *)new hpx::shared_future<void *>(std::move(hpx::get<0>(tf)));
    *((void **)outputs[1]) =
        (void *)new hpx::shared_future<void *>(std::move(hpx::get<1>(tf)));
    break;
  }
  case 3: {
    hpx::future<hpx::tuple<void *, void *, void *>> &&ft = hpx::dataflow(
        [](hpx::future<OpaqueOutputData> oodf_in)
            -> hpx::tuple<void *, void *, void *> {
          std::vector<void *> outputs = std::move(oodf_in.get().outputs);
          return hpx::make_tuple<>(outputs[0], outputs[1], outputs[2]);
        },
        oodf);
    hpx::tuple<hpx::future<void *>, hpx::future<void *>, hpx::future<void *>>
        &&tf = hpx::split_future(std::move(ft));
    *((void **)outputs[0]) =
        (void *)new hpx::shared_future<void *>(std::move(hpx::get<0>(tf)));
    *((void **)outputs[1]) =
        (void *)new hpx::shared_future<void *>(std::move(hpx::get<1>(tf)));
    *((void **)outputs[2]) =
        (void *)new hpx::shared_future<void *>(std::move(hpx::get<2>(tf)));
    break;
  }
  default:
    HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_create_async_task",
                        "Error: number of task outputs not supported.");
  }
}
/* Distributed key management. */
// Thin C-ABI wrappers over the node-level PbsKeyManager singleton.
void _dfr_register_key(void *key, size_t key_id, size_t size) {
  node_level_key_manager->register_key(key, key_id, size);
}
void _dfr_broadcast_keys() { node_level_key_manager->broadcast_keys(); }
// Return the raw key buffer for key_id, fetching from the root node on
// a local miss (see PbsKeyManager::get_key).
void *_dfr_get_key(size_t key_id) {
  return *node_level_key_manager->get_key(key_id).key.get();
}
/* Runtime initialization and finalization. */
// Schedule HPX finalisation on the runtime, then block until it stops.
static inline void _dfr_stop_impl() {
  auto request_finalize = []() { hpx::finalize(); };
  hpx::apply(request_finalize);
  hpx::stop();
}
// Bring up the HPX runtime and the DFR node-level services. On
// non-root localities this never returns to user code: the locality
// serves remote tasks until the root shuts it down, then exits.
static inline void _dfr_start_impl(int argc, char *argv[]) {
  // Handle to the running program image so work functions can later be
  // resolved by name through dlsym.
  dl_handle = dlopen(nullptr, RTLD_NOW);
  if (argc == 0) {
    // HPX requires a program name; synthesize an argv when none given.
    char *_argv[1] = {const_cast<char *>("__dummy_dfr_HPX_program_name__")};
    int _argc = 1;
    hpx::start(nullptr, _argc, _argv);
  } else {
    hpx::start(nullptr, argc, argv);
  }
  // Node-level singletons: their constructors register them in the
  // corresponding globals; they live for the process lifetime
  // (intentionally never freed).
  new PbsKeyManager();
  new WorkFunctionRegistry();
  // Non-root nodes only serve tasks: wait for shutdown, then exit the
  // whole process here.
  if (!_dfr_is_root_node()) {
    _dfr_stop_impl();
    exit(EXIT_SUCCESS);
  }
  // Create compute server components on each node and the
  // corresponding compute client.
  auto num_nodes = hpx::get_num_localities().get();
  gcc = hpx::new_<GenericComputeClient[]>(
            hpx::default_layout(hpx::find_all_localities()), num_nodes)
            .get();
}
// TODO: we need a better way to wrap main. For now loader --wrap and
// main's constructor/destructor are not functional, but that should
// replace the current, inefficient calls to _dfr_start/stop generated
// in each compiled function.
// Start the runtime with a synthesized argv (the argc == 0 path).
void _dfr_start() { _dfr_start_impl(0, nullptr); }
// Shut the runtime down.
void _dfr_stop() { _dfr_stop_impl(); }
/* Debug interface. */
// Locality (node) id of the calling process.
size_t _dfr_debug_get_node_id() { return hpx::get_locality_id(); }
// HPX worker-thread number of the calling thread.
size_t _dfr_debug_get_worker_id() { return hpx::get_worker_thread_num(); }
// Print a one-line trace identifying a task and where it executes.
void _dfr_debug_print_task(const char *name, int inputs, int outputs) {
  // clang-format off
  hpx::cout << "Task \"" << name << "\""
      << " [" << inputs << " inputs, " << outputs << " outputs]"
      << " Executing on Node/Worker: " << _dfr_debug_get_node_id()
      << " / " << _dfr_debug_get_worker_id() << "\n" << std::flush;
  // clang-format on
}
// Generic utility function for printing debug info
void _dfr_print_debug(size_t val) {
  hpx::cout << "_dfr_print_debug : " << val << "\n" << std::flush;
}

View File

@@ -5,21 +5,47 @@ llvm_update_compile_flags(zamacompiler)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)

# Link the compiler against the MLIR and Zamalang libraries. When dataflow
# (parallel) execution is enabled we additionally link DFRuntime, embed
# rpaths for the runtime and HPX shared libraries, and pass --no-as-needed
# so the linker keeps DFRuntime even though the compiler binary does not
# reference its symbols directly (they are resolved at JIT time).
if(ZAMALANG_PARALLEL_EXECUTION_ENABLED)
  target_link_libraries(zamacompiler
    PRIVATE
    ${dialect_libs}
    ${conversion_libs}
    MLIRTransforms
    LowLFHEDialect
    MidLFHEDialect
    HLFHEDialect
    MLIRIR
    MLIRLLVMIR
    MLIRLLVMToLLVMIRTranslation
    ZamalangSupport
    -Wl,-rpath,${CMAKE_BINARY_DIR}/lib/Runtime
    -Wl,-rpath,${HPX_DIR}/../../
    -Wl,--no-as-needed
    DFRuntime
  )
else()
  target_link_libraries(zamacompiler
    PRIVATE
    ${dialect_libs}
    ${conversion_libs}
    MLIRTransforms
    LowLFHEDialect
    MidLFHEDialect
    HLFHEDialect
    MLIRIR
    MLIRLLVMIR
    MLIRLLVMToLLVMIRTranslation
    ZamalangSupport
  )
endif()

mlir_check_all_link_libraries(zamacompiler)

View File

@@ -1,3 +1,3 @@
# Unit tests are only built when explicitly enabled.
# (The duplicated endif() left over from a merged diff made this block
# unbalanced; a single endif() closes the conditional.)
if (ZAMALANG_UNIT_TESTS)
  add_subdirectory(unittest)
endif()

View File

@@ -79,3 +79,23 @@ gtest_discover_tests(end_to_end_jit_encrypted_tensor)
gtest_discover_tests(end_to_end_jit_hlfhelinalg)
gtest_discover_tests(end_to_end_jit_lambda)
# Dataflow-runtime end-to-end JIT tests: built only when parallel (dataflow)
# execution support is enabled.
if(ZAMALANG_PARALLEL_EXECUTION_ENABLED)
add_executable(
end_to_end_jit_dfr
end_to_end_jit_dfr.cc
)
# NOTE(review): -fno-rtti presumably matches the no-RTTI LLVM/MLIR build the
# test links against -- confirm against the toolchain configuration.
set_source_files_properties(
end_to_end_jit_dfr.cc
PROPERTIES COMPILE_FLAGS "-fno-rtti"
)
# rpaths locate the runtime and HPX shared libraries at test run time;
# --no-as-needed keeps DFRuntime linked even though the test binary does not
# reference its symbols directly (they are resolved by the JIT).
target_link_libraries(
end_to_end_jit_dfr
gtest_main
ZamalangSupport
-Wl,-rpath,${CMAKE_BINARY_DIR}/lib/Runtime
-Wl,-rpath,${HPX_DIR}/../../
-Wl,--no-as-needed
DFRuntime
)
gtest_discover_tests(end_to_end_jit_dfr)
endif()

View File

@@ -0,0 +1,348 @@
#include <cstdint>
#include <gtest/gtest.h>
#include <type_traits>
#include "end_to_end_jit_test.h"
// Default FHE constraints used by the JIT tests in this file.
// NOTE(review): the meaning of {10, 7} (presumably norm2 / precision bits)
// is defined by V0FHEConstraint -- confirm against its declaration.
const mlir::zamalang::V0FHEConstraint defaultV0Constraints{10, 7};
// Minimal sanity check: a program that only starts and stops the dataflow
// runtime around a constant must still run to completion and return 7.
TEST(CompileAndRunDFR, start_stop) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func private @_dfr_stop()
func private @_dfr_start()
func @main() -> i64{
call @_dfr_start() : () -> ()
%1 = arith.constant 7 : i64
call @_dfr_stop() : () -> ()
return %1 : i64
}
)XXX", "main", true);
ASSERT_EXPECTED_VALUE(lambda(), 7);
}
// Single async task with no inputs and one output: the work function stores
// 4 + 3 into its output future; main awaits the future and returns 7.
TEST(CompileAndRunDFR, 0in1out_task) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
llvm.func @_dfr_await_future(!llvm.ptr<i64>) -> !llvm.ptr<ptr<i64>> attributes {sym_visibility = "private"}
llvm.func @_dfr_create_async_task(...) attributes {sym_visibility = "private"}
llvm.func @_dfr_stop()
llvm.func @_dfr_start()
func @main() -> i64 {
%0 = llvm.mlir.addressof @_dfr_DFT_work_function__main0 : !llvm.ptr<func<void (ptr<i64>)>>
%1 = llvm.mlir.constant(0 : i64) : i64
%2 = llvm.mlir.constant(1 : i64) : i64
%3 = llvm.mlir.constant(8 : i64) : i64
llvm.call @_dfr_start() : () -> ()
%4 = llvm.mlir.constant(1 : i64) : i64
%5 = llvm.alloca %4 x !llvm.ptr<i64> : (i64) -> !llvm.ptr<ptr<i64>>
llvm.call @_dfr_create_async_task(%0, %1, %2, %5, %3) : (!llvm.ptr<func<void (ptr<i64>)>>, i64, i64, !llvm.ptr<ptr<i64>>, i64) -> ()
%6 = llvm.load %5 : !llvm.ptr<ptr<i64>>
%7 = llvm.call @_dfr_await_future(%6) : (!llvm.ptr<i64>) -> !llvm.ptr<ptr<i64>>
%8 = llvm.bitcast %7 : !llvm.ptr<ptr<i64>> to !llvm.ptr<i64>
%9 = llvm.load %8 : !llvm.ptr<i64>
llvm.call @_dfr_stop() : () -> ()
return %9 : i64
}
llvm.func @_dfr_DFT_work_function__main0(%arg0: !llvm.ptr<i64>) {
%0 = llvm.mlir.constant(4 : i64) : i64
%1 = llvm.mlir.constant(3 : i64) : i64
llvm.br ^bb1
^bb1: // pred: ^bb0
%2 = llvm.add %0, %1 : i64
llvm.store %2, %arg0 : !llvm.ptr<i64>
llvm.return
}
)XXX", "main", true);
ASSERT_EXPECTED_VALUE(lambda(), 7);
}
// Single async task with one input and one output: the input argument is
// wrapped in a ready future, the work function adds 2, and main awaits the
// result (5 + 2 == 7).
TEST(CompileAndRunDFR, 1in1out_task) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
llvm.func @_dfr_await_future(!llvm.ptr<i64>) -> !llvm.ptr<ptr<i64>> attributes {sym_visibility = "private"}
llvm.func @_dfr_create_async_task(...) attributes {sym_visibility = "private"}
llvm.func @malloc(i64) -> !llvm.ptr<i8>
llvm.func @_dfr_make_ready_future(...) -> !llvm.ptr<i64> attributes {sym_visibility = "private"}
llvm.func @_dfr_stop()
llvm.func @_dfr_start()
func @main(%arg0: i64) -> i64 {
%0 = llvm.mlir.addressof @_dfr_DFT_work_function__main0 : !llvm.ptr<func<void (ptr<i64>, ptr<i64>)>>
%1 = llvm.mlir.constant(1 : i64) : i64
%2 = llvm.mlir.constant(8 : i64) : i64
llvm.call @_dfr_start() : () -> ()
%3 = llvm.mlir.null : !llvm.ptr<i64>
%4 = llvm.mlir.constant(1 : index) : i64
%5 = llvm.getelementptr %3[%4] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
%6 = llvm.ptrtoint %5 : !llvm.ptr<i64> to i64
%7 = llvm.call @malloc(%6) : (i64) -> !llvm.ptr<i8>
%8 = llvm.bitcast %7 : !llvm.ptr<i8> to !llvm.ptr<i64>
llvm.store %arg0, %8 : !llvm.ptr<i64>
%9 = llvm.call @_dfr_make_ready_future(%8) : (!llvm.ptr<i64>) -> !llvm.ptr<i64>
%10 = llvm.mlir.constant(1 : i64) : i64
%11 = llvm.alloca %10 x !llvm.ptr<i64> : (i64) -> !llvm.ptr<ptr<i64>>
llvm.call @_dfr_create_async_task(%0, %1, %1, %9, %2, %11, %2) : (!llvm.ptr<func<void (ptr<i64>, ptr<i64>)>>, i64, i64, !llvm.ptr<i64>, i64, !llvm.ptr<ptr<i64>>, i64) -> ()
%12 = llvm.load %11 : !llvm.ptr<ptr<i64>>
%13 = llvm.call @_dfr_await_future(%12) : (!llvm.ptr<i64>) -> !llvm.ptr<ptr<i64>>
%14 = llvm.bitcast %13 : !llvm.ptr<ptr<i64>> to !llvm.ptr<i64>
%15 = llvm.load %14 : !llvm.ptr<i64>
llvm.call @_dfr_stop() : () -> ()
return %15 : i64
}
llvm.func @_dfr_DFT_work_function__main0(%arg0: !llvm.ptr<i64>, %arg1: !llvm.ptr<i64>) {
%0 = llvm.mlir.constant(2 : i64) : i64
%1 = llvm.load %arg0 : !llvm.ptr<i64>
llvm.br ^bb1
^bb1: // pred: ^bb0
%2 = llvm.add %1, %0 : i64
llvm.store %2, %arg1 : !llvm.ptr<i64>
llvm.return
}
)XXX", "main", true);
ASSERT_EXPECTED_VALUE(lambda(5_u64), 7);
}
// Single async task with two inputs and one output: both arguments are
// wrapped in ready futures, the work function adds them, and main awaits
// the result (1 + 6 == 7).
TEST(CompileAndRunDFR, 2in1out_task) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
llvm.func @_dfr_await_future(!llvm.ptr<i64>) -> !llvm.ptr<ptr<i64>> attributes {sym_visibility = "private"}
llvm.func @_dfr_create_async_task(...) attributes {sym_visibility = "private"}
llvm.func @malloc(i64) -> !llvm.ptr<i8>
llvm.func @_dfr_make_ready_future(...) -> !llvm.ptr<i64> attributes {sym_visibility = "private"}
llvm.func @_dfr_stop()
llvm.func @_dfr_start()
func @main(%arg0: i64, %arg1: i64) -> i64 {
%0 = llvm.mlir.addressof @_dfr_DFT_work_function__main0 : !llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>
%1 = llvm.mlir.constant(2 : i64) : i64
%2 = llvm.mlir.constant(1 : i64) : i64
%3 = llvm.mlir.constant(8 : i64) : i64
llvm.call @_dfr_start() : () -> ()
%4 = llvm.mlir.null : !llvm.ptr<i64>
%5 = llvm.mlir.constant(1 : index) : i64
%6 = llvm.getelementptr %4[%5] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
%7 = llvm.ptrtoint %6 : !llvm.ptr<i64> to i64
%8 = llvm.call @malloc(%7) : (i64) -> !llvm.ptr<i8>
%9 = llvm.bitcast %8 : !llvm.ptr<i8> to !llvm.ptr<i64>
llvm.store %arg0, %9 : !llvm.ptr<i64>
%10 = llvm.call @_dfr_make_ready_future(%9) : (!llvm.ptr<i64>) -> !llvm.ptr<i64>
%11 = llvm.mlir.null : !llvm.ptr<i64>
%12 = llvm.mlir.constant(1 : index) : i64
%13 = llvm.getelementptr %11[%12] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
%14 = llvm.ptrtoint %13 : !llvm.ptr<i64> to i64
%15 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr<i8>
%16 = llvm.bitcast %15 : !llvm.ptr<i8> to !llvm.ptr<i64>
llvm.store %arg1, %16 : !llvm.ptr<i64>
%17 = llvm.call @_dfr_make_ready_future(%16) : (!llvm.ptr<i64>) -> !llvm.ptr<i64>
%18 = llvm.mlir.constant(1 : i64) : i64
%19 = llvm.alloca %18 x !llvm.ptr<i64> : (i64) -> !llvm.ptr<ptr<i64>>
llvm.call @_dfr_create_async_task(%0, %1, %2, %10, %3, %17, %3, %19, %3) : (!llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>, i64, i64, !llvm.ptr<i64>, i64, !llvm.ptr<i64>, i64, !llvm.ptr<ptr<i64>>, i64) -> ()
%20 = llvm.load %19 : !llvm.ptr<ptr<i64>>
%21 = llvm.call @_dfr_await_future(%20) : (!llvm.ptr<i64>) -> !llvm.ptr<ptr<i64>>
%22 = llvm.bitcast %21 : !llvm.ptr<ptr<i64>> to !llvm.ptr<i64>
%23 = llvm.load %22 : !llvm.ptr<i64>
llvm.call @_dfr_stop() : () -> ()
return %23 : i64
}
llvm.func @_dfr_DFT_work_function__main0(%arg0: !llvm.ptr<i64>, %arg1: !llvm.ptr<i64>, %arg2: !llvm.ptr<i64>) {
%0 = llvm.load %arg0 : !llvm.ptr<i64>
%1 = llvm.load %arg1 : !llvm.ptr<i64>
llvm.br ^bb1
^bb1: // pred: ^bb0
%2 = llvm.add %0, %1 : i64
llvm.store %2, %arg2 : !llvm.ptr<i64>
llvm.return
}
)XXX", "main", true);
ASSERT_EXPECTED_VALUE(lambda(1_u64, 6_u64), 7);
}
// Task graph of eight 2-input/1-output tasks, each adding its inputs, fed by
// a mix of ready futures (arguments and 7*argument values) and other tasks'
// output futures. Tracing the graph, the final task yields
// 9*(arg0+arg1+arg2), which matches the three assertions below
// (9*6 == 54, 9*8 == 72, 9*11 == 99).
TEST(CompileAndRunDFR, taskgraph) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
llvm.func @_dfr_await_future(!llvm.ptr<i64>) -> !llvm.ptr<ptr<i64>> attributes {sym_visibility = "private"}
llvm.func @_dfr_create_async_task(...) attributes {sym_visibility = "private"}
llvm.func @malloc(i64) -> !llvm.ptr<i8>
llvm.func @_dfr_make_ready_future(...) -> !llvm.ptr<i64> attributes {sym_visibility = "private"}
llvm.func @_dfr_stop()
llvm.func @_dfr_start()
func @main(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
%0 = llvm.mlir.constant(7 : i64) : i64
%1 = llvm.mlir.addressof @_dfr_DFT_work_function__main0 : !llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>
%2 = llvm.mlir.constant(2 : i64) : i64
%3 = llvm.mlir.constant(1 : i64) : i64
%4 = llvm.mlir.constant(8 : i64) : i64
%5 = llvm.mlir.addressof @_dfr_DFT_work_function__main1 : !llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>
%6 = llvm.mlir.addressof @_dfr_DFT_work_function__main2 : !llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>
%7 = llvm.mlir.addressof @_dfr_DFT_work_function__main3 : !llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>
%8 = llvm.mlir.addressof @_dfr_DFT_work_function__main4 : !llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>
%9 = llvm.mlir.addressof @_dfr_DFT_work_function__main5 : !llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>
%10 = llvm.mlir.addressof @_dfr_DFT_work_function__main6 : !llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>
%11 = llvm.mlir.addressof @_dfr_DFT_work_function__main7 : !llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>
llvm.call @_dfr_start() : () -> ()
%12 = llvm.mlir.null : !llvm.ptr<i64>
%13 = llvm.mlir.constant(1 : index) : i64
%14 = llvm.getelementptr %12[%13] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
%15 = llvm.ptrtoint %14 : !llvm.ptr<i64> to i64
%16 = llvm.call @malloc(%15) : (i64) -> !llvm.ptr<i8>
%17 = llvm.bitcast %16 : !llvm.ptr<i8> to !llvm.ptr<i64>
llvm.store %arg0, %17 : !llvm.ptr<i64>
%18 = llvm.call @_dfr_make_ready_future(%17) : (!llvm.ptr<i64>) -> !llvm.ptr<i64>
%19 = llvm.mlir.null : !llvm.ptr<i64>
%20 = llvm.mlir.constant(1 : index) : i64
%21 = llvm.getelementptr %19[%20] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
%22 = llvm.ptrtoint %21 : !llvm.ptr<i64> to i64
%23 = llvm.call @malloc(%22) : (i64) -> !llvm.ptr<i8>
%24 = llvm.bitcast %23 : !llvm.ptr<i8> to !llvm.ptr<i64>
llvm.store %arg1, %24 : !llvm.ptr<i64>
%25 = llvm.call @_dfr_make_ready_future(%24) : (!llvm.ptr<i64>) -> !llvm.ptr<i64>
%26 = llvm.mlir.constant(1 : i64) : i64
%27 = llvm.alloca %26 x !llvm.ptr<i64> : (i64) -> !llvm.ptr<ptr<i64>>
llvm.call @_dfr_create_async_task(%1, %2, %3, %18, %4, %25, %4, %27, %4) : (!llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>, i64, i64, !llvm.ptr<i64>, i64, !llvm.ptr<i64>, i64, !llvm.ptr<ptr<i64>>, i64) -> ()
%28 = llvm.load %27 : !llvm.ptr<ptr<i64>>
%29 = llvm.mlir.null : !llvm.ptr<i64>
%30 = llvm.mlir.constant(1 : index) : i64
%31 = llvm.getelementptr %29[%30] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
%32 = llvm.ptrtoint %31 : !llvm.ptr<i64> to i64
%33 = llvm.call @malloc(%32) : (i64) -> !llvm.ptr<i8>
%34 = llvm.bitcast %33 : !llvm.ptr<i8> to !llvm.ptr<i64>
llvm.store %arg2, %34 : !llvm.ptr<i64>
%35 = llvm.call @_dfr_make_ready_future(%34) : (!llvm.ptr<i64>) -> !llvm.ptr<i64>
%36 = llvm.mlir.constant(1 : i64) : i64
%37 = llvm.alloca %36 x !llvm.ptr<i64> : (i64) -> !llvm.ptr<ptr<i64>>
llvm.call @_dfr_create_async_task(%5, %2, %3, %18, %4, %35, %4, %37, %4) : (!llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>, i64, i64, !llvm.ptr<i64>, i64, !llvm.ptr<i64>, i64, !llvm.ptr<ptr<i64>>, i64) -> ()
%38 = llvm.load %37 : !llvm.ptr<ptr<i64>>
%39 = llvm.mlir.constant(1 : i64) : i64
%40 = llvm.alloca %39 x !llvm.ptr<i64> : (i64) -> !llvm.ptr<ptr<i64>>
llvm.call @_dfr_create_async_task(%6, %2, %3, %25, %4, %35, %4, %40, %4) : (!llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>, i64, i64, !llvm.ptr<i64>, i64, !llvm.ptr<i64>, i64, !llvm.ptr<ptr<i64>>, i64) -> ()
%41 = llvm.load %40 : !llvm.ptr<ptr<i64>>
%42 = llvm.mul %arg0, %0 : i64
%43 = llvm.mul %arg1, %0 : i64
%44 = llvm.mul %arg2, %0 : i64
%45 = llvm.mlir.null : !llvm.ptr<i64>
%46 = llvm.mlir.constant(1 : index) : i64
%47 = llvm.getelementptr %45[%46] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
%48 = llvm.ptrtoint %47 : !llvm.ptr<i64> to i64
%49 = llvm.call @malloc(%48) : (i64) -> !llvm.ptr<i8>
%50 = llvm.bitcast %49 : !llvm.ptr<i8> to !llvm.ptr<i64>
llvm.store %42, %50 : !llvm.ptr<i64>
%51 = llvm.call @_dfr_make_ready_future(%50) : (!llvm.ptr<i64>) -> !llvm.ptr<i64>
%52 = llvm.mlir.constant(1 : i64) : i64
%53 = llvm.alloca %52 x !llvm.ptr<i64> : (i64) -> !llvm.ptr<ptr<i64>>
llvm.call @_dfr_create_async_task(%7, %2, %3, %28, %4, %51, %4, %53, %4) : (!llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>, i64, i64, !llvm.ptr<i64>, i64, !llvm.ptr<i64>, i64, !llvm.ptr<ptr<i64>>, i64) -> ()
%54 = llvm.load %53 : !llvm.ptr<ptr<i64>>
%55 = llvm.mlir.null : !llvm.ptr<i64>
%56 = llvm.mlir.constant(1 : index) : i64
%57 = llvm.getelementptr %55[%56] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
%58 = llvm.ptrtoint %57 : !llvm.ptr<i64> to i64
%59 = llvm.call @malloc(%58) : (i64) -> !llvm.ptr<i8>
%60 = llvm.bitcast %59 : !llvm.ptr<i8> to !llvm.ptr<i64>
llvm.store %43, %60 : !llvm.ptr<i64>
%61 = llvm.call @_dfr_make_ready_future(%60) : (!llvm.ptr<i64>) -> !llvm.ptr<i64>
%62 = llvm.mlir.constant(1 : i64) : i64
%63 = llvm.alloca %62 x !llvm.ptr<i64> : (i64) -> !llvm.ptr<ptr<i64>>
llvm.call @_dfr_create_async_task(%8, %2, %3, %38, %4, %61, %4, %63, %4) : (!llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>, i64, i64, !llvm.ptr<i64>, i64, !llvm.ptr<i64>, i64, !llvm.ptr<ptr<i64>>, i64) -> ()
%64 = llvm.load %63 : !llvm.ptr<ptr<i64>>
%65 = llvm.mlir.null : !llvm.ptr<i64>
%66 = llvm.mlir.constant(1 : index) : i64
%67 = llvm.getelementptr %65[%66] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
%68 = llvm.ptrtoint %67 : !llvm.ptr<i64> to i64
%69 = llvm.call @malloc(%68) : (i64) -> !llvm.ptr<i8>
%70 = llvm.bitcast %69 : !llvm.ptr<i8> to !llvm.ptr<i64>
llvm.store %44, %70 : !llvm.ptr<i64>
%71 = llvm.call @_dfr_make_ready_future(%70) : (!llvm.ptr<i64>) -> !llvm.ptr<i64>
%72 = llvm.mlir.constant(1 : i64) : i64
%73 = llvm.alloca %72 x !llvm.ptr<i64> : (i64) -> !llvm.ptr<ptr<i64>>
llvm.call @_dfr_create_async_task(%9, %2, %3, %41, %4, %71, %4, %73, %4) : (!llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>, i64, i64, !llvm.ptr<i64>, i64, !llvm.ptr<i64>, i64, !llvm.ptr<ptr<i64>>, i64) -> ()
%74 = llvm.load %73 : !llvm.ptr<ptr<i64>>
%75 = llvm.mlir.constant(1 : i64) : i64
%76 = llvm.alloca %75 x !llvm.ptr<i64> : (i64) -> !llvm.ptr<ptr<i64>>
llvm.call @_dfr_create_async_task(%10, %2, %3, %54, %4, %64, %4, %76, %4) : (!llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>, i64, i64, !llvm.ptr<i64>, i64, !llvm.ptr<i64>, i64, !llvm.ptr<ptr<i64>>, i64) -> ()
%77 = llvm.load %76 : !llvm.ptr<ptr<i64>>
%78 = llvm.mlir.constant(1 : i64) : i64
%79 = llvm.alloca %78 x !llvm.ptr<i64> : (i64) -> !llvm.ptr<ptr<i64>>
llvm.call @_dfr_create_async_task(%11, %2, %3, %77, %4, %74, %4, %79, %4) : (!llvm.ptr<func<void (ptr<i64>, ptr<i64>, ptr<i64>)>>, i64, i64, !llvm.ptr<i64>, i64, !llvm.ptr<i64>, i64, !llvm.ptr<ptr<i64>>, i64) -> ()
%80 = llvm.load %79 : !llvm.ptr<ptr<i64>>
%81 = llvm.call @_dfr_await_future(%80) : (!llvm.ptr<i64>) -> !llvm.ptr<ptr<i64>>
%82 = llvm.bitcast %81 : !llvm.ptr<ptr<i64>> to !llvm.ptr<i64>
%83 = llvm.load %82 : !llvm.ptr<i64>
llvm.call @_dfr_stop() : () -> ()
return %83 : i64
}
llvm.func @_dfr_DFT_work_function__main0(%arg0: !llvm.ptr<i64>, %arg1: !llvm.ptr<i64>, %arg2: !llvm.ptr<i64>) attributes {_dfr_work_function_attribute} {
%0 = llvm.load %arg0 : !llvm.ptr<i64>
%1 = llvm.load %arg1 : !llvm.ptr<i64>
llvm.br ^bb1
^bb1: // pred: ^bb0
%2 = llvm.add %0, %1 : i64
llvm.store %2, %arg2 : !llvm.ptr<i64>
llvm.return
}
llvm.func @_dfr_DFT_work_function__main1(%arg0: !llvm.ptr<i64>, %arg1: !llvm.ptr<i64>, %arg2: !llvm.ptr<i64>) attributes {_dfr_work_function_attribute} {
%0 = llvm.load %arg0 : !llvm.ptr<i64>
%1 = llvm.load %arg1 : !llvm.ptr<i64>
llvm.br ^bb1
^bb1: // pred: ^bb0
%2 = llvm.add %0, %1 : i64
llvm.store %2, %arg2 : !llvm.ptr<i64>
llvm.return
}
llvm.func @_dfr_DFT_work_function__main2(%arg0: !llvm.ptr<i64>, %arg1: !llvm.ptr<i64>, %arg2: !llvm.ptr<i64>) attributes {_dfr_work_function_attribute} {
%0 = llvm.load %arg0 : !llvm.ptr<i64>
%1 = llvm.load %arg1 : !llvm.ptr<i64>
llvm.br ^bb1
^bb1: // pred: ^bb0
%2 = llvm.add %0, %1 : i64
llvm.store %2, %arg2 : !llvm.ptr<i64>
llvm.return
}
llvm.func @_dfr_DFT_work_function__main3(%arg0: !llvm.ptr<i64>, %arg1: !llvm.ptr<i64>, %arg2: !llvm.ptr<i64>) attributes {_dfr_work_function_attribute} {
%0 = llvm.load %arg0 : !llvm.ptr<i64>
%1 = llvm.load %arg1 : !llvm.ptr<i64>
llvm.br ^bb1
^bb1: // pred: ^bb0
%2 = llvm.add %0, %1 : i64
llvm.store %2, %arg2 : !llvm.ptr<i64>
llvm.return
}
llvm.func @_dfr_DFT_work_function__main4(%arg0: !llvm.ptr<i64>, %arg1: !llvm.ptr<i64>, %arg2: !llvm.ptr<i64>) attributes {_dfr_work_function_attribute} {
%0 = llvm.load %arg0 : !llvm.ptr<i64>
%1 = llvm.load %arg1 : !llvm.ptr<i64>
llvm.br ^bb1
^bb1: // pred: ^bb0
%2 = llvm.add %0, %1 : i64
llvm.store %2, %arg2 : !llvm.ptr<i64>
llvm.return
}
llvm.func @_dfr_DFT_work_function__main5(%arg0: !llvm.ptr<i64>, %arg1: !llvm.ptr<i64>, %arg2: !llvm.ptr<i64>) attributes {_dfr_work_function_attribute} {
%0 = llvm.load %arg0 : !llvm.ptr<i64>
%1 = llvm.load %arg1 : !llvm.ptr<i64>
llvm.br ^bb1
^bb1: // pred: ^bb0
%2 = llvm.add %0, %1 : i64
llvm.store %2, %arg2 : !llvm.ptr<i64>
llvm.return
}
llvm.func @_dfr_DFT_work_function__main6(%arg0: !llvm.ptr<i64>, %arg1: !llvm.ptr<i64>, %arg2: !llvm.ptr<i64>) attributes {_dfr_work_function_attribute} {
%0 = llvm.load %arg0 : !llvm.ptr<i64>
%1 = llvm.load %arg1 : !llvm.ptr<i64>
llvm.br ^bb1
^bb1: // pred: ^bb0
%2 = llvm.add %0, %1 : i64
llvm.store %2, %arg2 : !llvm.ptr<i64>
llvm.return
}
llvm.func @_dfr_DFT_work_function__main7(%arg0: !llvm.ptr<i64>, %arg1: !llvm.ptr<i64>, %arg2: !llvm.ptr<i64>) attributes {_dfr_work_function_attribute} {
%0 = llvm.load %arg0 : !llvm.ptr<i64>
%1 = llvm.load %arg1 : !llvm.ptr<i64>
llvm.br ^bb1
^bb1: // pred: ^bb0
%2 = llvm.add %0, %1 : i64
llvm.store %2, %arg2 : !llvm.ptr<i64>
llvm.return
}
)XXX", "main", true);
ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64, 3_u64), 54);
ASSERT_EXPECTED_VALUE(lambda(2_u64, 5_u64, 1_u64), 72);
ASSERT_EXPECTED_VALUE(lambda(3_u64, 1_u64, 7_u64), 99);
}