fix(rnd): Disable unused prisma connection on pyro API Server process (#7641)

### Background

Pyro for the API Server is not using Prisma, but is still holding a Prisma connection.
The FastAPI thread is also holding a Prisma connection, leaving Prisma connected in two different event loops within a single process.

### Changes 🏗️

Disable the Prisma connection on the Pyro thread for the API Server process.
Fix test flakiness caused by a concurrency issue.
This commit is contained in:
Zamil Majdy
2024-07-30 19:33:28 +04:00
committed by GitHub
parent 29ba4c2c73
commit 122f544966
4 changed files with 41 additions and 36 deletions

View File

@@ -160,10 +160,12 @@ def _enqueue_next_nodes(
logger.error(f"{prefix} Error, next node {next_node_id} not found.")
return enqueued_executions
# Upserting execution input includes reading the existing input pins in the node
# which then either updating the existing execution input or creating a new one.
# While reading, we should avoid any other process to add input to the same node.
# Multiple node can register the same next node, we need this to be atomic
# To avoid same execution to be enqueued multiple times,
# Or the same input to be consumed multiple times.
with synchronized(api_client, ("upsert_input", next_node_id, graph_exec_id)):
# Add output data to the earliest incomplete execution, or create a new one.
next_node_exec_id, next_node_input = wait(
upsert_execution_input(
node_id=next_node_id,
@@ -173,40 +175,41 @@ def _enqueue_next_nodes(
)
)
# Complete missing static input pins data using the last execution input.
static_link_names = {
link.sink_name
for link in next_node.input_links
if link.is_static and link.sink_name not in next_node_input
}
if static_link_names and (
latest_execution := wait(get_latest_execution(next_node_id, graph_exec_id))
):
for name in static_link_names:
next_node_input[name] = latest_execution.input_data.get(name)
# Complete missing static input pins data using the last execution input.
static_link_names = {
link.sink_name
for link in next_node.input_links
if link.is_static and link.sink_name not in next_node_input
}
if static_link_names and (
latest_execution := wait(
get_latest_execution(next_node_id, graph_exec_id)
)
):
for name in static_link_names:
next_node_input[name] = latest_execution.input_data.get(name)
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
suffix = (
f"{next_output_name}>{next_input_name}~{next_node_exec_id}:{validation_msg}"
)
# Validate the input data for the next node.
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
suffix = f"{next_output_name}>{next_input_name}~{next_node_exec_id}:{validation_msg}"
if not next_node_input:
logger.warning(f"{prefix} Skipped queueing {suffix}")
return enqueued_executions
# Incomplete input data, skip queueing the execution.
if not next_node_input:
logger.warning(f"{prefix} Skipped queueing {suffix}")
return enqueued_executions
# Input is complete, enqueue the execution.
logger.warning(f"{prefix} Enqueued {suffix}")
enqueued_executions.append(
add_enqueued_execution(next_node_exec_id, next_node_id, next_node_input)
)
# Input is complete, enqueue the execution.
logger.warning(f"{prefix} Enqueued {suffix}")
enqueued_executions.append(
add_enqueued_execution(next_node_exec_id, next_node_id, next_node_input)
)
if not node_link.is_static:
return enqueued_executions
# Next execution stops here if the link is not static.
if not node_link.is_static:
return enqueued_executions
# If link is static, there could be some incomplete executions waiting for it.
# Load and complete the input missing input data, and try to re-enqueue them.
# While reading, we should avoid any other process to re-enqueue the same node.
with synchronized(api_client, ("upsert_input", next_node_id, graph_exec_id)):
# If link is static, there could be some incomplete executions waiting for it.
# Load and complete the input missing input data, and try to re-enqueue them.
for iexec in wait(get_incomplete_executions(next_node_id, graph_exec_id)):
idata = iexec.input_data
ineid = iexec.node_exec_id

View File

@@ -44,6 +44,7 @@ class AgentServer(AppService):
event_queue: asyncio.Queue[ExecutionResult] = asyncio.Queue()
manager = ConnectionManager()
mutex = KeyedMutex()
use_db = False
async def event_broadcaster(self):
while True:
@@ -53,8 +54,8 @@ class AgentServer(AppService):
@asynccontextmanager
async def lifespan(self, _: FastAPI):
await db.connect()
self.run_and_wait(block.initialize_blocks())
self.run_and_wait(graph_db.import_packaged_templates())
await block.initialize_blocks()
await graph_db.import_packaged_templates()
asyncio.create_task(self.event_broadcaster())
yield
await db.disconnect()

View File

@@ -40,6 +40,7 @@ class PyroNameServer(AppProcess):
class AppService(AppProcess):
shared_event_loop: asyncio.AbstractEventLoop
use_db: bool = True
@classmethod
@property
@@ -60,7 +61,8 @@ class AppService(AppProcess):
def run(self):
self.shared_event_loop = asyncio.get_event_loop()
self.shared_event_loop.run_until_complete(db.connect())
if self.use_db:
self.shared_event_loop.run_until_complete(db.connect())
# Initialize the async loop.
async_thread = threading.Thread(target=self.__start_async_loop)

View File

@@ -1,7 +1,6 @@
import pytest
from autogpt_server.blocks.basic import ObjectLookupBlock, ValueBlock
from autogpt_server.blocks.if_block import ComparisonOperator, ConditionBlock
from autogpt_server.blocks.maths import MathsBlock, Operation
from autogpt_server.data import execution, graph
from autogpt_server.executor import ExecutionManager