mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(testing-dfr): fix startup and ternimation call parameters to allow sequential test execution on root node when the runtime is instantiated for distributed execution.
This commit is contained in:
@@ -26,9 +26,9 @@ void _dfr_deallocate_future(void *);
|
||||
void _dfr_deallocate_future_data(void *);
|
||||
|
||||
/* Initialisation & termination. */
|
||||
void _dfr_start_c(void *);
|
||||
void _dfr_start(int);
|
||||
void _dfr_stop(int);
|
||||
void _dfr_start_c(int64_t, void *);
|
||||
void _dfr_start(int64_t);
|
||||
void _dfr_stop(int64_t);
|
||||
|
||||
void _dfr_terminate();
|
||||
}
|
||||
|
||||
@@ -468,14 +468,21 @@ struct LowerDataflowTasksPass
|
||||
(dfr::_dfr_is_root_node())
|
||||
? mlir::FunctionType::get(
|
||||
entryPoint->getContext(),
|
||||
{entryPoint.getArgument(ctxIndex).getType()}, {})
|
||||
: mlir::FunctionType::get(entryPoint->getContext(), {}, {});
|
||||
{useDFRVal.getType(),
|
||||
entryPoint.getArgument(ctxIndex).getType()},
|
||||
{})
|
||||
: mlir::FunctionType::get(entryPoint->getContext(),
|
||||
{useDFRVal.getType()}, {});
|
||||
(void)insertForwardDeclaration(entryPoint, builder, "_dfr_start_c",
|
||||
startFunTy);
|
||||
builder.create<mlir::func::CallOp>(
|
||||
entryPoint.getLoc(), "_dfr_start_c", mlir::TypeRange(),
|
||||
(dfr::_dfr_is_root_node()) ? entryPoint.getArgument(ctxIndex)
|
||||
: mlir::ValueRange());
|
||||
(dfr::_dfr_is_root_node())
|
||||
? builder.create<mlir::func::CallOp>(
|
||||
entryPoint.getLoc(), "_dfr_start_c", mlir::TypeRange(),
|
||||
mlir::ValueRange(
|
||||
{useDFRVal, entryPoint.getArgument(ctxIndex)}))
|
||||
: builder.create<mlir::func::CallOp>(entryPoint.getLoc(),
|
||||
"_dfr_start_c",
|
||||
mlir::TypeRange(), useDFRVal);
|
||||
} else {
|
||||
auto startFunTy = mlir::FunctionType::get(entryPoint->getContext(),
|
||||
{useDFRVal.getType()}, {});
|
||||
|
||||
@@ -1131,7 +1131,7 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
|
||||
/* Start/stop functions to be called from within user code (or during
|
||||
JIT invocation). These serve to pause/resume the runtime
|
||||
scheduler and to clean up used resources. */
|
||||
void _dfr_start(int use_dfr_p) {
|
||||
void _dfr_start(int64_t use_dfr_p) {
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::whole_timer);
|
||||
if (use_dfr_p) {
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::init_timer);
|
||||
@@ -1167,30 +1167,32 @@ void _dfr_start(int use_dfr_p) {
|
||||
}
|
||||
|
||||
// Startup entry point when a RuntimeContext is used
|
||||
void _dfr_start_c(void *ctx) {
|
||||
void _dfr_start_c(int64_t use_dfr_p, void *ctx) {
|
||||
_dfr_start(2);
|
||||
|
||||
if (mlir::concretelang::dfr::num_nodes > 1) {
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::broadcast_timer);
|
||||
new mlir::concretelang::dfr::RuntimeContextManager();
|
||||
mlir::concretelang::dfr::_dfr_node_level_runtime_context_manager
|
||||
->setContext(ctx);
|
||||
if (use_dfr_p) {
|
||||
if (mlir::concretelang::dfr::num_nodes > 1) {
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::broadcast_timer);
|
||||
new mlir::concretelang::dfr::RuntimeContextManager();
|
||||
mlir::concretelang::dfr::_dfr_node_level_runtime_context_manager
|
||||
->setContext(ctx);
|
||||
|
||||
// If this is not JIT, then the remote nodes never reach _dfr_stop,
|
||||
// so root should not instantiate this barrier.
|
||||
if (mlir::concretelang::dfr::_dfr_is_root_node() &&
|
||||
mlir::concretelang::dfr::_dfr_is_jit())
|
||||
mlir::concretelang::dfr::_dfr_startup_barrier->wait();
|
||||
END_TIME(&mlir::concretelang::dfr::broadcast_timer, "Key broadcasting");
|
||||
// If this is not JIT, then the remote nodes never reach _dfr_stop,
|
||||
// so root should not instantiate this barrier.
|
||||
if (mlir::concretelang::dfr::_dfr_is_root_node() &&
|
||||
mlir::concretelang::dfr::_dfr_is_jit())
|
||||
mlir::concretelang::dfr::_dfr_startup_barrier->wait();
|
||||
END_TIME(&mlir::concretelang::dfr::broadcast_timer, "Key broadcasting");
|
||||
}
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::compute_timer);
|
||||
}
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::compute_timer);
|
||||
}
|
||||
|
||||
// This function cannot be used to terminate the runtime as it is
|
||||
// non-decidable if another computation phase will follow. Instead the
|
||||
// _dfr_terminate function provides this facility and is normally
|
||||
// called on exit from "main" when not using the main wrapper library.
|
||||
void _dfr_stop(int use_dfr_p) {
|
||||
void _dfr_stop(int64_t use_dfr_p) {
|
||||
if (use_dfr_p) {
|
||||
if (mlir::concretelang::dfr::num_nodes > 1) {
|
||||
// Non-root nodes synchronize here with the root to mark the point
|
||||
@@ -1310,11 +1312,11 @@ bool _dfr_use_omp() { return use_omp_p; }
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
void _dfr_start(int use_dfr_p) {
|
||||
void _dfr_start(int64_t use_dfr_p) {
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::compute_timer);
|
||||
}
|
||||
void _dfr_start_c(void *ctx) { _dfr_start(2); }
|
||||
void _dfr_stop(int use_dfr_p) {
|
||||
void _dfr_start_c(int64_t use_dfr_p, void *ctx) { _dfr_start(2); }
|
||||
void _dfr_stop(int64_t use_dfr_p) {
|
||||
END_TIME(&mlir::concretelang::dfr::compute_timer, "Compute");
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user