fix(server): Fix type checking and propagation issues (#7941)

- fix type propagation through `AppService.run_and_wait(..)`
- fix type propagation through `@expose` and add note
- fix type propagation through `wait(..)` in `.executor.manager.execute_node(..)`
- fix type propagation through `wait(..)` in `.executor.manager._enqueue_next_nodes(..)`
- remove unnecessary null checks for `.data.graph.get_node(..)`
- fix type issue in `ExecutionScheduler`
- reduce usage of `# type: ignore` in `.data.execution`
- reduce usage of `# type: ignore` in `.executor.manager`
- reduce usage of `# type: ignore` in `.server`
- reduce usage of `# type: ignore` in cli.py
- update `pyright` to v1.1.378
This commit is contained in:
Reinier van der Leer
2024-09-02 16:13:56 +02:00
committed by GitHub
parent 956165adf3
commit 71de1a6a5e
8 changed files with 76 additions and 61 deletions

View File

@@ -42,7 +42,7 @@ def write_pid(pid: int):
class MainApp(AppProcess):
def run(self):
app.main(silent=True) # type: ignore
app.main(silent=True)
@click.group()
@@ -66,12 +66,12 @@ def start():
os.remove(get_pid_path())
print("Starting server")
pid = MainApp().start(background=True, silent=True) # type: ignore
pid = MainApp().start(background=True, silent=True)
print(f"Server running in process: {pid}")
write_pid(pid)
print("done")
os._exit(status=0) # type: ignore
os._exit(status=0)
@main.command()

View File

@@ -9,7 +9,11 @@ from prisma.models import (
AgentNodeExecution,
AgentNodeExecutionInputOutput,
)
from prisma.types import AgentGraphExecutionWhereInput
from prisma.types import (
AgentGraphExecutionInclude,
AgentGraphExecutionWhereInput,
AgentNodeExecutionInclude,
)
from pydantic import BaseModel
from autogpt_server.data.block import BlockData, BlockInput, CompletedBlockOutput
@@ -108,13 +112,24 @@ class ExecutionResult(BaseModel):
# --------------------- Model functions --------------------- #
EXECUTION_RESULT_INCLUDE = {
EXECUTION_RESULT_INCLUDE: AgentNodeExecutionInclude = {
"Input": True,
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
}
GRAPH_EXECUTION_INCLUDE: AgentGraphExecutionInclude = {
"AgentNodeExecutions": {
"include": {
"Input": True,
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
}
}
}
async def create_graph_execution(
graph_id: str,
@@ -148,9 +163,7 @@ async def create_graph_execution(
},
"userId": user_id,
},
include={
"AgentNodeExecutions": {"include": EXECUTION_RESULT_INCLUDE} # type: ignore
},
include=GRAPH_EXECUTION_INCLUDE,
)
return result.id, [
@@ -263,7 +276,7 @@ async def update_execution_status(
res = await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
data=data, # type: ignore
include=EXECUTION_RESULT_INCLUDE, # type: ignore
include=EXECUTION_RESULT_INCLUDE,
)
if not res:
raise ValueError(f"Execution {node_exec_id} not found.")
@@ -282,7 +295,7 @@ async def list_executions(graph_id: str, graph_version: int | None = None) -> li
async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
executions = await AgentNodeExecution.prisma().find_many(
where={"agentGraphExecutionId": graph_exec_id},
include=EXECUTION_RESULT_INCLUDE, # type: ignore
include=EXECUTION_RESULT_INCLUDE,
order=[
{"queuedTime": "asc"},
{"addedTime": "asc"}, # Fallback: Incomplete execs has no queuedTime.
@@ -372,14 +385,14 @@ def merge_execution_input(data: BlockInput) -> BlockInput:
async def get_latest_execution(node_id: str, graph_eid: str) -> ExecutionResult | None:
execution = await AgentNodeExecution.prisma().find_first(
where={ # type: ignore
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": {"not": ExecutionStatus.INCOMPLETE},
"executionData": {"not": None},
"executionData": {"not": None}, # type: ignore
},
order={"queuedTime": "desc"},
include=EXECUTION_RESULT_INCLUDE, # type: ignore
include=EXECUTION_RESULT_INCLUDE,
)
if not execution:
return None
@@ -390,11 +403,11 @@ async def get_incomplete_executions(
node_id: str, graph_eid: str
) -> list[ExecutionResult]:
executions = await AgentNodeExecution.prisma().find_many(
where={ # type: ignore
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": ExecutionStatus.INCOMPLETE,
},
include=EXECUTION_RESULT_INCLUDE, # type: ignore
include=EXECUTION_RESULT_INCLUDE,
)
return [ExecutionResult.from_db(execution) for execution in executions]

View File

@@ -276,12 +276,12 @@ AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
# --------------------- Model functions --------------------- #
async def get_node(node_id: str) -> Node | None:
async def get_node(node_id: str) -> Node:
node = await AgentNode.prisma().find_unique_or_raise(
where={"id": node_id},
include=AGENT_NODE_INCLUDE,
)
return Node.from_db(node) if node else None
return Node.from_db(node)
async def get_graphs_meta(

View File

@@ -61,7 +61,7 @@ def execute_node(
asyncio.set_event_loop(loop)
def wait(f: Coroutine[T, Any, T]) -> T:
def wait(f: Coroutine[Any, Any, T]) -> T:
return loop.run_until_complete(f)
def update_execution(status: ExecutionStatus):
@@ -69,11 +69,8 @@ def execute_node(
api_client.send_execution_update(exec_update.model_dump())
node = wait(get_node(node_id))
if not node:
logger.error(f"Node {node_id} not found.")
return
node_block = get_block(node.block_id) # type: ignore
node_block = get_block(node.block_id)
if not node_block:
logger.error(f"Block {node.block_id} not found.")
return
@@ -133,7 +130,7 @@ def _enqueue_next_nodes(
graph_exec_id: str,
prefix: str,
) -> list[NodeExecution]:
def wait(f: Coroutine[T, Any, T]) -> T:
def wait(f: Coroutine[Any, Any, T]) -> T:
return loop.run_until_complete(f)
def add_enqueued_execution(
@@ -161,9 +158,6 @@ def _enqueue_next_nodes(
return enqueued_executions
next_node = wait(get_node(next_node_id))
if not next_node:
logger.error(f"{prefix} Error, next node {next_node_id} not found.")
return enqueued_executions
# Multiple node can register the same next node, we need this to be atomic
# To avoid same execution to be enqueued multiple times,
@@ -264,7 +258,7 @@ def validate_exec(
If the data is valid, the first element will be the resolved input data, and
the second element will be the block name.
"""
node_block: Block | None = get_block(node.block_id) # type: ignore
node_block: Block | None = get_block(node.block_id)
if not node_block:
return None, f"Block for {node.block_id} not found."
@@ -291,7 +285,7 @@ def validate_exec(
data[name] = convert(value, data_type)
# Last validation: Validate the input values against the schema.
if error := node_block.input_schema.validate_data(data): # type: ignore
if error := node_block.input_schema.validate_data(data):
error_message = f"Input data doesn't match {node_block.name}: {error}"
logger.error(error_message)
return None, error_message

View File

@@ -18,7 +18,6 @@ def log(msg, **kwargs):
class ExecutionScheduler(AppService):
def __init__(self, refresh_interval=10):
self.last_check = datetime.min
self.refresh_interval = refresh_interval
@@ -38,7 +37,8 @@ class ExecutionScheduler(AppService):
def __refresh_jobs_from_db(self, scheduler: BackgroundScheduler):
schedules = self.run_and_wait(model.get_active_schedules(self.last_check))
for schedule in schedules:
self.last_check = max(self.last_check, schedule.last_updated)
if schedule.last_updated:
self.last_check = max(self.last_check, schedule.last_updated)
if not schedule.is_enabled:
log(f"Removing recurring job {schedule.id}: {schedule.schedule}")

View File

@@ -6,7 +6,7 @@ from typing import Annotated, Any, Dict
import uvicorn
from autogpt_libs.auth.middleware import auth_middleware
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
@@ -86,12 +86,12 @@ class AgentServer(AppService):
router.add_api_route(
path="/blocks",
endpoint=self.get_graph_blocks, # type: ignore
endpoint=self.get_graph_blocks,
methods=["GET"],
)
router.add_api_route(
path="/blocks/{block_id}/execute",
endpoint=self.execute_graph_block, # type: ignore
endpoint=self.execute_graph_block,
methods=["POST"],
)
router.add_api_route(
@@ -156,7 +156,7 @@ class AgentServer(AppService):
)
router.add_api_route(
path="/graphs/{graph_id}/execute",
endpoint=self.execute_graph, # type: ignore
endpoint=self.execute_graph,
methods=["POST"],
)
router.add_api_route(
@@ -171,7 +171,7 @@ class AgentServer(AppService):
)
router.add_api_route(
path="/graphs/{graph_id}/schedules",
endpoint=self.create_schedule, # type: ignore
endpoint=self.create_schedule,
methods=["POST"],
)
router.add_api_route(
@@ -181,7 +181,7 @@ class AgentServer(AppService):
)
router.add_api_route(
path="/graphs/schedules/{schedule_id}",
endpoint=self.update_schedule, # type: ignore
endpoint=self.update_schedule,
methods=["PUT"],
)
@@ -191,7 +191,7 @@ class AgentServer(AppService):
methods=["POST"],
)
app.add_exception_handler(500, self.handle_internal_error) # type: ignore
app.add_exception_handler(500, self.handle_internal_http_error)
app.include_router(router)
@@ -235,11 +235,11 @@ class AgentServer(AppService):
return get_service_client(ExecutionScheduler)
@classmethod
def handle_internal_error(cls, request, exc): # type: ignore
def handle_internal_http_error(cls, request: Request, exc: Exception):
return JSONResponse(
content={
"message": f"{request.url.path} call failure", # type: ignore
"error": str(exc), # type: ignore
"message": f"{request.method} {request.url.path} failed",
"error": str(exc),
},
status_code=500,
)
@@ -251,13 +251,13 @@ class AgentServer(AppService):
@classmethod
def get_graph_blocks(cls) -> list[dict[Any, Any]]:
return [v.to_dict() for v in block.get_blocks().values()] # type: ignore
return [v.to_dict() for v in block.get_blocks().values()]
@classmethod
def execute_graph_block(
cls, block_id: str, data: BlockInput
) -> CompletedBlockOutput:
obj = block.get_block(block_id) # type: ignore
obj = block.get_block(block_id)
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
@@ -478,14 +478,14 @@ class AgentServer(AppService):
) -> dict[Any, Any]:
execution_scheduler = self.execution_scheduler_client
is_enabled = input_data.get("is_enabled", False)
execution_scheduler.update_schedule(schedule_id, is_enabled, user_id=user_id) # type: ignore
execution_scheduler.update_schedule(schedule_id, is_enabled, user_id=user_id)
return {"id": schedule_id}
def get_execution_schedules(
self, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> dict[str, str]:
execution_scheduler = self.execution_scheduler_client
return execution_scheduler.get_execution_schedules(graph_id, user_id) # type: ignore
return execution_scheduler.get_execution_schedules(graph_id, user_id)
@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
@@ -521,12 +521,12 @@ class AgentServer(AppService):
try:
updated_fields: dict[Any, Any] = {"config": [], "secrets": []}
for key, value in updated_settings.get("config", {}).items():
if hasattr(settings.config, key): # type: ignore
setattr(settings.config, key, value) # type: ignore
if hasattr(settings.config, key):
setattr(settings.config, key, value)
updated_fields["config"].append(key)
for key, value in updated_settings.get("secrets", {}).items():
if hasattr(settings.secrets, key): # type: ignore
setattr(settings.secrets, key, value) # type: ignore
if hasattr(settings.secrets, key):
setattr(settings.secrets, key, value)
updated_fields["secrets"].append(key)
settings.save()
return {

View File

@@ -19,11 +19,21 @@ conn_retry = retry(
stop=stop_after_attempt(30), wait=wait_exponential(multiplier=1, min=1, max=30)
)
T = TypeVar("T")
C = TypeVar("C", bound=Callable)
pyro_host = Config().pyro_host
def expose(func: Callable) -> Callable:
def expose(func: C) -> C:
"""
Decorator to mark a method or class to be exposed for remote calls.
## ⚠️ Gotcha
The types on the exposed function signature are respected **as long as they are
fully picklable**. This is not the case for Pydantic models, so if you really need
to pass a model, try dumping the model and passing the resulting dict instead.
"""
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
@@ -32,7 +42,7 @@ def expose(func: Callable) -> Callable:
logger.exception(msg)
raise Exception(msg, e)
return pyro.expose(wrapper)
return pyro.expose(wrapper) # type: ignore
class PyroNameServer(AppProcess):
@@ -60,10 +70,10 @@ class AppService(AppProcess):
while True:
time.sleep(10)
def __run_async(self, coro: Coroutine[T, Any, T]):
def __run_async(self, coro: Coroutine[Any, Any, T]):
return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop)
def run_and_wait(self, coro: Coroutine[T, Any, T]) -> T:
def run_and_wait(self, coro: Coroutine[Any, Any, T]) -> T:
future = self.__run_async(coro)
return future.result()
@@ -107,7 +117,6 @@ def get_service_client(service_type: Type[AS]) -> AS:
service_name = service_type.service_name
class DynamicClient:
@conn_retry
def __init__(self):
ns = pyro.locate_ns()

View File

@@ -25,7 +25,7 @@ requests = "*"
sentry-sdk = "^1.40.4"
[package.extras]
benchmark = ["agbenchmark @ file:///Users/aarushi/autogpt/AutoGPT/benchmark"]
benchmark = ["agbenchmark"]
[package.source]
type = "directory"
@@ -386,7 +386,7 @@ watchdog = "4.0.0"
webdriver-manager = "^4.0.1"
[package.extras]
benchmark = ["agbenchmark @ file:///Users/aarushi/autogpt/AutoGPT/benchmark"]
benchmark = ["agbenchmark"]
[package.source]
type = "directory"
@@ -2527,9 +2527,6 @@ files = [
{file = "lief-0.14.1-cp312-cp312-manylinux_2_28_x86_64.manylinux_2_27_x86_64.whl", hash = "sha256:497b88f9c9aaae999766ba188744ee35c5f38b4b64016f7dbb7037e9bf325382"},
{file = "lief-0.14.1-cp312-cp312-win32.whl", hash = "sha256:08bad88083f696915f8dcda4042a3bfc514e17462924ec8984085838b2261921"},
{file = "lief-0.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:e131d6158a085f8a72124136816fefc29405c725cd3695ce22a904e471f0f815"},
{file = "lief-0.14.1-cp313-cp313-manylinux_2_28_x86_64.manylinux_2_27_x86_64.whl", hash = "sha256:f9ff9a6959fb6d0e553cca41cd1027b609d27c5073e98d9fad8b774fbb5746c2"},
{file = "lief-0.14.1-cp313-cp313-win32.whl", hash = "sha256:95f295a7cc68f4e14ce7ea4ff8082a04f5313c2e5e63cc2bbe9d059190b7e4d5"},
{file = "lief-0.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:cdc1123c2e27970f8c8353505fd578e634ab33193c8d1dff36dc159e25599a40"},
{file = "lief-0.14.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:df650fa05ca131e4dfeb42c77985e1eb239730af9944bc0aadb1dfac8576e0e8"},
{file = "lief-0.14.1-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:b4e76eeb48ca2925c6ca6034d408582615f2faa855f9bb11482e7acbdecc4803"},
{file = "lief-0.14.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:016e4fac91303466024154dd3c4b599e8b7c52882f72038b62a2be386d98c8f9"},
@@ -3586,6 +3583,8 @@ files = [
{file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:960db0e31c4e52fa0fc3ecbaea5b2d3b58f379e32a95ae6b0ebeaa25b93dfd34"},
{file = "orjson-3.10.6-cp312-none-win32.whl", hash = "sha256:a6ea7afb5b30b2317e0bee03c8d34c8181bc5a36f2afd4d0952f378972c4efd5"},
{file = "orjson-3.10.6-cp312-none-win_amd64.whl", hash = "sha256:874ce88264b7e655dde4aeaacdc8fd772a7962faadfb41abe63e2a4861abc3dc"},
{file = "orjson-3.10.6-cp313-none-win32.whl", hash = "sha256:efdf2c5cde290ae6b83095f03119bdc00303d7a03b42b16c54517baa3c4ca3d0"},
{file = "orjson-3.10.6-cp313-none-win_amd64.whl", hash = "sha256:8e190fe7888e2e4392f52cafb9626113ba135ef53aacc65cd13109eb9746c43e"},
{file = "orjson-3.10.6-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:66680eae4c4e7fc193d91cfc1353ad6d01b4801ae9b5314f17e11ba55e934183"},
{file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caff75b425db5ef8e8f23af93c80f072f97b4fb3afd4af44482905c9f588da28"},
{file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3722fddb821b6036fd2a3c814f6bd9b57a89dc6337b9924ecd614ebce3271394"},
@@ -4462,13 +4461,13 @@ dev = ["pytest (>=8.1.1)"]
[[package]]
name = "pyright"
version = "1.1.373"
version = "1.1.378"
description = "Command line wrapper for pyright"
optional = false
python-versions = ">=3.7"
files = [
{file = "pyright-1.1.373-py3-none-any.whl", hash = "sha256:b805413227f2c209f27b14b55da27fe5e9fb84129c9f1eb27708a5d12f6f000e"},
{file = "pyright-1.1.373.tar.gz", hash = "sha256:f41bcfc8b9d1802b09921a394d6ae1ce19694957b628bc657629688daf8a83ff"},
{file = "pyright-1.1.378-py3-none-any.whl", hash = "sha256:8853776138b01bc284da07ac481235be7cc89d3176b073d2dba73636cb95be79"},
{file = "pyright-1.1.378.tar.gz", hash = "sha256:78a043be2876d12d0af101d667e92c7734f3ebb9db71dccc2c220e7e7eb89ca2"},
]
[package.dependencies]